Skip to content

Instantly share code, notes, and snippets.

@davidrpugh
Created March 16, 2016 08:22
Show Gist options
  • Save davidrpugh/c28abfc1dd86076128dc to your computer and use it in GitHub Desktop.
Save davidrpugh/c28abfc1dd86076128dc to your computer and use it in GitHub Desktop.
A simple implementation of a binary search tree.
class BinarySearchTree(object):
    """Unbalanced binary search tree of unique, mutually comparable values.

    Every value in a node's left subtree is strictly less than the node's
    value and every value in its right subtree is strictly greater; duplicate
    values are rejected.

    Notes
    -----
    The tree is not self-balancing, so search, insertion and removal run in
    O(h) time, where h is the tree height: O(log n) when balanced, O(n) when
    values arrive in sorted order. All tree walks are iterative, so very deep
    (degenerate) trees do not hit Python's recursion limit.
    """

    # Class-level defaults; __init__ rebinds both on every instance.
    _length = 0
    _root = None

    class TreeNode(object):
        """A single node of a BinarySearchTree.

        Nodes compare by ``value``. The child setters enforce BST ordering
        relative to this node and raise AttributeError when it is violated.
        """

        # __eq__ compares mutable contents, so nodes are deliberately unhashable.
        __hash__ = None

        def __init__(self, value, left_child, parent, right_child):
            """Create an instance of a TreeNode."""
            self.value = value
            self.left_child = left_child
            self.parent = parent
            self.right_child = right_child

        def __repr__(self):
            return "TreeNode({!r})".format(self.value)

        def __eq__(self, other):
            if not isinstance(other, BinarySearchTree.TreeNode):
                return NotImplemented
            return self.value == other.value

        def __lt__(self, other):
            if not isinstance(other, BinarySearchTree.TreeNode):
                return NotImplemented
            return self.value < other.value

        @property
        def has_left_child(self):
            """Return true if node has a left child; false otherwise."""
            return self.left_child is not None

        @property
        def has_right_child(self):
            """Return true if node has a right child; false otherwise."""
            return self.right_child is not None

        @property
        def is_leaf(self):
            """Return true if node has no children; false otherwise."""
            return self.left_child is None and self.right_child is None

        @property
        def is_left_child(self):
            """Return true if node is the left child of some other node."""
            # Identity, not value comparison: this is correct even while
            # links are being rearranged during a removal.
            return self.parent is not None and self.parent.left_child is self

        @property
        def is_right_child(self):
            """Return true if node is the right child of some other node."""
            return self.parent is not None and self.parent.right_child is self

        @property
        def left_child(self):
            """Return the left child node."""
            return self._left_child

        @left_child.setter
        def left_child(self, node):
            """Set a new left child node."""
            self._left_child = self._validate_left_child(node)

        @property
        def right_child(self):
            """Return the right child node."""
            return self._right_child

        @right_child.setter
        def right_child(self, node):
            """Set a new right child node."""
            self._right_child = self._validate_right_child(node)

        def replace_with(self, node):
            """Put ``node`` into this node's slot under its parent.

            The parent's child link is redirected to ``node`` (which may be
            None) and ``node.parent`` is pointed at this node's parent. This
            node's own links are left untouched.

            Raises
            ------
            AttributeError
                If this node has no parent (a tree root must be replaced via
                the owning tree), or if ``node`` violates the parent's ordering.
            """
            parent = self.parent
            if parent is None:
                raise AttributeError("Cannot replace a node that has no parent.")
            if self.is_left_child:
                parent.left_child = node
            else:
                parent.right_child = node
            if node is not None:
                node.parent = parent

        def _validate_left_child(self, node):
            if node is None or node < self:
                return node
            else:
                msg = "Left child value must be strictly less than parent value."
                raise AttributeError(msg)

        def _validate_right_child(self, node):
            if node is None or node > self:
                return node
            else:
                msg = "Right child value must be strictly greater than parent value."
                raise AttributeError(msg)

    def __init__(self, value=None):
        """Create a BinarySearchTree, optionally seeded with one value.

        Parameters
        ----------
        value : optional
            If not None, inserted as the first (root) value.
        """
        self._root = None
        self._length = 0
        if value is not None:
            self.insert(value)

    def __len__(self):
        return self._length

    def __contains__(self, value):
        """Return True if ``value`` is stored in the tree."""
        return self._find(self._root, value) is not None

    def __iter__(self):
        """Iterate over the stored values in ascending (in-order) order.

        Values are collected up front, so mutating the tree during iteration
        does not disturb it.
        """
        values = []
        self._traverse_nodes_in_order(self._root, values.append)
        return iter(values)

    @property
    def minimum(self):
        """
        Return the minimum value in the tree in O(h) time (h = tree height).

        Raises
        ------
        ValueError
            If the tree is empty.

        Notes
        -----
        Could be made O(1) by updating the minimum value on all
        insert and remove calls.
        """
        if self._root is None:
            raise ValueError("minimum of an empty tree is undefined.")
        return self._find_minimum(self._root).value

    @property
    def maximum(self):
        """
        Return the maximum value in the tree in O(h) time (h = tree height).

        Raises
        ------
        ValueError
            If the tree is empty.

        Notes
        -----
        Could be made O(1) by updating the maximum value on all
        insert and remove calls.
        """
        if self._root is None:
            raise ValueError("maximum of an empty tree is undefined.")
        return self._find_maximum(self._root).value

    @staticmethod
    def _find_minimum(node):
        """Return the leftmost (smallest) node of the non-empty subtree at ``node``."""
        current_node = node
        while current_node.left_child is not None:
            current_node = current_node.left_child
        return current_node

    @staticmethod
    def _find_maximum(node):
        """Return the rightmost (largest) node of the non-empty subtree at ``node``."""
        current_node = node
        while current_node.right_child is not None:
            current_node = current_node.right_child
        return current_node

    @classmethod
    def _find(cls, node, value):
        """Return the node holding ``value`` in the subtree at ``node``, or None."""
        current = node
        while current is not None:
            if current.value == value:
                return current
            current = current.right_child if current.value < value else current.left_child
        return None

    def _insert(self, node, value, parent):
        """Attach a new node holding ``value`` below ``node`` (child of ``parent``).

        Walks down iteratively from ``node``; if the walk starts at an empty
        tree (no node and no parent) the new node becomes the root.

        Raises
        ------
        ValueError
            If ``value`` is already present along the search path.
        """
        current, parent_node = node, parent
        while current is not None:
            if current.value == value:
                raise ValueError("{} is already in the tree!".format(value))
            parent_node = current
            current = current.right_child if current.value < value else current.left_child
        new_node = self.TreeNode(value, None, parent_node, None)
        if parent_node is None:
            self._root = new_node
        elif new_node < parent_node:
            parent_node.left_child = new_node
        else:
            parent_node.right_child = new_node

    def _replace_node(self, node, replacement):
        """Put ``replacement`` (may be None) where ``node`` sits: under its parent or as root."""
        if node.parent is None:
            self._root = replacement
            if replacement is not None:
                replacement.parent = None
        else:
            node.replace_with(replacement)

    def _remove(self, node):
        """Unlink ``node`` from the tree, keeping every parent/child link consistent."""
        if node.left_child is None:
            self._replace_node(node, node.right_child)
        elif node.right_child is None:
            self._replace_node(node, node.left_child)
        else:
            # Two children: the in-order successor (leftmost node of the
            # right subtree) takes node's place.
            successor = self._find_minimum(node.right_child)
            if successor.parent is not node:
                # Detach the successor first: its right subtree (it has no
                # left child) moves up into its old slot.
                self._replace_node(successor, successor.right_child)
                successor.right_child = node.right_child
                successor.right_child.parent = successor
            self._replace_node(node, successor)
            successor.left_child = node.left_child
            successor.left_child.parent = successor
        # Drop the removed node's links so it cannot reach back into the tree.
        node.parent = None
        node.left_child = None
        node.right_child = None
        self._length -= 1

    @classmethod
    def _traverse_nodes_in_order(cls, node, f):
        """O(n) iterative in-order (left, node, right) traversal calling f(value)."""
        stack = []
        current = node
        while stack or current is not None:
            while current is not None:
                stack.append(current)
                current = current.left_child
            current = stack.pop()
            f(current.value)
            current = current.right_child

    @classmethod
    def _traverse_nodes_pre_order(cls, node, f):
        """O(n) iterative pre-order (node, left, right) traversal calling f(value)."""
        stack = [] if node is None else [node]
        while stack:
            current = stack.pop()
            f(current.value)
            # Push right first so the left subtree is visited first.
            if current.right_child is not None:
                stack.append(current.right_child)
            if current.left_child is not None:
                stack.append(current.left_child)

    @classmethod
    def _traverse_nodes_post_order(cls, node, f):
        """O(n) iterative post-order (left, right, node) traversal calling f(value)."""
        stack = []
        current = node
        last_visited = None
        while stack or current is not None:
            if current is not None:
                stack.append(current)
                current = current.left_child
            else:
                top = stack[-1]
                # Descend right only if that subtree has not been emitted yet.
                if top.right_child is not None and top.right_child is not last_visited:
                    current = top.right_child
                else:
                    f(top.value)
                    last_visited = stack.pop()

    def insert(self, value):
        """Insert ``value`` in O(h) time.

        Raises
        ------
        ValueError
            If ``value`` is already in the tree (the length is unchanged).
        """
        self._insert(self._root, value, None)
        # Count only after the insertion actually succeeded.
        self._length += 1

    def remove(self, value):
        """Remove ``value`` in O(h) time.

        Raises
        ------
        ValueError
            If ``value`` is not in the tree.
        """
        node = self._find(self._root, value)
        if node is None:
            raise ValueError("{} is not in the tree!".format(value))
        self._remove(node)

    def traverse(self, f, order="in_order"):
        """Call ``f(value)`` on every stored value in O(n) time.

        Parameters
        ----------
        f : callable
            Invoked once per value.
        order : {'in_order', 'pre_order', 'post_order'}
            Visiting order.

        Raises
        ------
        ValueError
            If ``order`` is not one of the supported names.
        """
        if order == "in_order":
            self._traverse_nodes_in_order(self._root, f)
        elif order == "pre_order":
            self._traverse_nodes_pre_order(self._root, f)
        elif order == "post_order":
            self._traverse_nodes_post_order(self._root, f)
        else:
            raise ValueError("'order' must be one of 'in_order', 'pre_order', or 'post_order'.")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment