Created
April 15, 2024 05:01
-
-
Save airspeedswift/6e0c037ad5dd1b763d353f358791eef7 to your computer and use it in GitHub Desktop.
Red-black tree in Swift 5.10
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// A binary tree whose nodes carry a red/black color tag alongside an element.
///
/// The enum is `indirect`, so a node can hold other `Tree` values as children.
indirect enum Tree<Element: Comparable> {
    /// Color tag stored in every node: `R` is red, `B` is black.
    enum Color {
        case R
        case B
    }

    /// A tree with no elements.
    case empty

    /// A node made of, in order: its color, left subtree, element, right subtree.
    case node(Color, Tree<Element>, Element, Tree<Element>)

    /// Creates an empty tree.
    init() {
        self = .empty
    }

    /// Creates a single node holding `x`.
    ///
    /// - Parameters:
    ///   - x: The element stored in the node.
    ///   - color: The node's color; black when omitted.
    ///   - left: The left subtree; empty when omitted.
    ///   - right: The right subtree; empty when omitted.
    init(
        _ x: Element,
        color: Color = .B,
        left: Tree<Element> = .empty,
        right: Tree<Element> = .empty
    ) {
        self = Self.node(color, left, x, right)
    }
}
extension Tree {
    /// Reports whether `x` is stored somewhere in the tree.
    ///
    /// Walks down from the root, going left when `x` orders before the
    /// node's element and right when it orders after, using only `<`.
    ///
    /// - Parameter x: The element to look for.
    /// - Returns: `true` if a node holds an element that is neither less
    ///   than nor greater than `x`; `false` once an empty subtree is reached.
    func contains(_ x: Element) -> Bool {
        var cursor = self
        while case let .node(_, lhs, value, rhs) = cursor {
            if x < value {
                cursor = lhs
            } else if value < x {
                cursor = rhs
            } else {
                return true
            }
        }
        return false
    }
}
extension Tree {
    /// Rewrites a black node that has a red child with a red grandchild.
    ///
    /// Each of the four shapes (left-left, left-right, right-left,
    /// right-right) is rebuilt into the same arrangement: a red root `y`
    /// with two black children `x` and `z`.
    ///
    ///           z: black         -->         y: red
    ///          /        \                  /        \
    ///       x: red       d            x: black     z: black
    ///      /      \                    /    \       /    \
    ///     a      y: red               a      b     c      d
    ///           /      \
    ///          b        c
    ///
    /// (Only the left-right shape is drawn.) Any other tree is returned unchanged.
    private func balance() -> Self {
        // All four shapes bind the same seven pieces, in in-order sequence.
        func rebuilt(
            _ a: Self, _ x: Element,
            _ b: Self, _ y: Element,
            _ c: Self, _ z: Element,
            _ d: Self
        ) -> Self {
            .node(.R, .node(.B, a, x, b), y, .node(.B, c, z, d))
        }

        switch self {
        case let .node(.B, .node(.R, .node(.R, a, x, b), y, c), z, d):
            return rebuilt(a, x, b, y, c, z, d)
        case let .node(.B, .node(.R, a, x, .node(.R, b, y, c)), z, d):
            return rebuilt(a, x, b, y, c, z, d)
        case let .node(.B, a, x, .node(.R, .node(.R, b, y, c), z, d)):
            return rebuilt(a, x, b, y, c, z, d)
        case let .node(.B, a, x, .node(.R, b, y, .node(.R, c, z, d))):
            return rebuilt(a, x, b, y, c, z, d)
        default:
            return self
        }
    }

    /// Recursive insertion step.
    ///
    /// An empty tree becomes a single red node holding `x`. Otherwise `x`
    /// is inserted into the left or right subtree and the rebuilt node is
    /// rebalanced. If `x` is already present the tree is returned as is.
    private func ins(_ x: Element) -> Self {
        switch self {
        case .empty:
            return .node(.R, .empty, x, .empty)
        case let .node(color, lhs, value, rhs):
            if x < value {
                return Self.node(color, lhs.ins(x), value, rhs).balance()
            }
            if value < x {
                return Self.node(color, lhs, value, rhs.ins(x)).balance()
            }
            return self
        }
    }

    /// Returns a tree that contains `x` in addition to the existing elements.
    ///
    /// The receiver is not modified. The root of the result is always black.
    ///
    /// - Parameter x: The element to add.
    /// - Returns: The new tree; it holds the same elements as the receiver
    ///   when `x` was already present.
    public func insert(_ x: Element) -> Self {
        switch ins(x) {
        case let .node(_, lhs, value, rhs):
            // Force the root black regardless of what `ins` produced.
            return .node(.B, lhs, value, rhs)
        case .empty:
            fatalError("ins should never return an empty tree")
        }
    }
}
extension Tree {
    /// Traversal state for walking a tree.
    struct Iterator {
        /// Trees set aside during traversal, most recent last; starts empty.
        var stack = [Tree]()

        /// The subtree the traversal is currently positioned at.
        var current: Tree
    }
}
extension Tree.Iterator: IteratorProtocol {
    /// Produces the next element of an in-order (left, node, right) walk.
    ///
    /// - Returns: The next element, or `nil` when `current` is empty and
    ///   no node is left on `stack`.
    mutating func next() -> Element? {
        // Follow left children as far as they go, stacking every node passed.
        while case let .node(_, leftChild, _, _) = current {
            stack.append(current)
            current = leftChild
        }

        // Nowhere further left: pop the most recently stacked node, yield
        // its element, and continue with its right subtree on the next call.
        guard let resumed = stack.popLast(),
              case let .node(_, _, value, rightChild) = resumed
        else {
            // Nothing stacked, so the walk is finished.
            return nil
        }
        current = rightChild
        return value
    }
}
extension Tree: Sequence {
    /// Creates an iterator positioned at this tree with an empty stack.
    func makeIterator() -> Iterator {
        return Iterator(stack: [], current: self)
    }
}
extension Tree: Equatable {
    /// Two trees are equal when iterating them yields equal elements in
    /// the same order; node colors and tree shape are not compared.
    static func == (lhs: Self, rhs: Self) -> Bool {
        var left = lhs.makeIterator()
        var right = rhs.makeIterator()
        while true {
            switch (left.next(), right.next()) {
            case let (a?, b?):
                guard a == b else { return false }
            case (nil, nil):
                // Both sequences ended together with no mismatch.
                return true
            default:
                // One sequence is longer than the other.
                return false
            }
        }
    }
}
extension Tree: ExpressibleByArrayLiteral {
    /// Creates a tree by inserting every element of `source`, in order,
    /// starting from an empty tree.
    init(_ source: some Sequence<Element>) {
        var built: Self = .empty
        for element in source {
            built = built.insert(element)
        }
        self = built
    }

    /// Creates a tree from an array literal such as `[3, 1, 2]`.
    init(arrayLiteral elements: Element...) {
        self.init(elements)
    }
}
extension Tree: CustomDebugStringConvertible {
    /// A textual rendering of this tree's elements, in iteration order,
    /// comma-separated and wrapped as `[🌲:…]`.
    ///
    /// The description is built from `self`, so each tree prints its own
    /// contents, and every element is converted with `String(describing:)`
    /// so it works for any `Element`, not only strings.
    var debugDescription: String {
        let rendered = self.map { String(describing: $0) }
        return "[🌲:\(rendered.joined(separator: ","))]"
    }
}
// Fixture data: seventeen distinct names, deliberately unsorted.
let engines = [
    "Daisy", "Salty", "Harold", "Cranky",
    "Thomas", "Henry", "James", "Toby",
    "Belle", "Diesel", "Stepney", "Gordon",
    "Captain", "Percy", "Arry", "Bert",
    "Spencer",
]

// Reference tree built from the names in their listed order.
let t = Tree(engines)

// Build a tree from several insertion orders (as listed, ascending,
// descending, and two random shuffles) and check each behaves like `t`.
for ordering in [
    engines,
    engines.sorted(),
    engines.sorted(by: >),
    engines.shuffled(),
    engines.shuffled(),
] {
    let candidate = Tree(ordering)

    // Membership: a present name is found, an absent one is not.
    assert(candidate.contains("James"))
    assert(!candidate.contains("Fred"))

    // Re-inserting an existing name must not change the element sequence.
    assert(candidate.elementsEqual(t.insert("Thomas")))

    // Every fixture name is reachable in the rebuilt tree.
    assert(engines.allSatisfy { candidate.contains($0) })

    // Insertion order does not affect equality.
    assert(candidate == t)

    print(candidate)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment