Created
February 1, 2013 04:28
-
-
Save TheDataLeek/4689217 to your computer and use it in GitHub Desktop.
Binary Search Tree implementation and test harness
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python
#
# binary_search_tree.py
#
import sys


class bt_node(object):
    """One node of a BinarySearchTree.

    ``data`` is the stored value; ``left`` and ``right`` are child nodes or
    None.  Calling ``bt_node()`` with no arguments still yields a childless
    node holding 0, matching the old class-attribute defaults.
    """

    def __init__(self, data=0, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right

    def __repr__(self):
        return "bt_node(%r)" % (self.data,)


class BinarySearchTree(object):
    """An unbalanced binary search tree of mutually comparable values.

    Values smaller than a node's data are stored in its left subtree; values
    greater than *or equal to* it go into the right subtree, so duplicates
    are allowed and are placed to the right of the first copy.

    Attributes:
        tree: the root bt_node, or None when the tree is empty.
        tree_list: data values collected by the most recent ``to_array()``.
        tree_nodes: nodes collected by the most recent ``to_array()``.
    """

    def __init__(self):
        self.tree = None
        self.tree_list = []
        self.tree_nodes = []

    def init_node(self, data):
        """Create and return a childless bt_node holding ``data``."""
        return bt_node(data)

    def insert(self, new_node):
        """Insert ``new_node`` (and any subtree it already carries).

        Smaller values descend left; greater-or-equal values descend right.
        """
        if self.tree is None:
            self.tree = new_node
            return
        cursor = self.tree
        while True:
            # A single if/else (instead of two independent tests) guarantees
            # that every pass descends one level.  With two separate `if`s,
            # a value such as NaN -- for which both `<` and `>=` are False --
            # never moved the cursor and looped forever.
            if new_node.data < cursor.data:
                if cursor.left is None:
                    cursor.left = new_node
                    return
                cursor = cursor.left
            else:
                if cursor.right is None:
                    cursor.right = new_node
                    return
                cursor = cursor.right

    def insert_data(self, data):
        """Wrap ``data`` in a new node and insert it at its correct location."""
        self.insert(self.init_node(data))

    def _find_with_parent(self, data):
        """Locate the first node on the search path whose data equals ``data``.

        Returns:
            ``(node, parent)``.  ``node`` is the matching node, or None when
            no node matches.  ``parent`` is the node that ``node`` hangs
            from, or None when ``node`` is the root.  When ``node`` is None,
            ``parent`` carries no meaning.
        """
        parent = None
        cursor = self.tree
        while cursor is not None and cursor.data != data:
            parent = cursor
            cursor = cursor.left if data < cursor.data else cursor.right
        return cursor, parent

    def remove(self, data):
        """Remove one node whose data equals ``data``; do nothing if absent.

        A node with two children stays in place and takes the value of a
        neighbour.  The neighbour is its in-order successor when the tree
        currently holds an even number of nodes, and its predecessor when
        the count is odd.  That neighbour has at most one child and is then
        spliced out.  Any other node, including the root or the last
        remaining node, is spliced out directly.
        """
        target, parent = self._find_with_parent(data)
        if target is None:
            return
        if target.left is not None and target.right is not None:
            # The choice of neighbour is observable in the resulting tree
            # shape, so the original size-parity rule is kept unchanged.
            if self.size() % 2 == 0:
                neighbour, neighbour_parent = self.find_successor(target)
            else:
                neighbour, neighbour_parent = self.find_predecessor(target)
            target.data = neighbour.data
            # find_successor/find_predecessor report a parent of None when
            # the neighbour is target's direct child.
            if neighbour_parent is None:
                neighbour_parent = target
            self.delete_node(neighbour, neighbour_parent)
        else:
            self.delete_node(target, parent)

    def delete_node(self, node, parent):
        """Splice out ``node``, which must have at most one child.

        Internal helper for ``remove()``.  The node's only child (if any)
        takes its place under ``parent``.  Passing ``parent`` as None means
        ``node`` is the root.

        Raises:
            ValueError: if ``node`` has two children, or is not a child of
                ``parent`` (or, with ``parent`` None, is not the root).
        """
        if node.left is not None and node.right is not None:
            raise ValueError(
                "delete_node() cannot splice out a node with two children")
        child = node.left if node.left is not None else node.right
        if parent is None:
            if self.tree is not node:
                raise ValueError("node is not the root of this tree")
            self.tree = child
        elif parent.left is node:
            parent.left = child
        elif parent.right is node:
            parent.right = child
        else:
            raise ValueError("node is not a child of parent")

    def find_successor(self, node):
        """Return ``(successor, successor_parent)`` for ``node``.

        The successor is the leftmost node of ``node.right``.
        ``successor_parent`` is None when the successor is ``node.right``
        itself, i.e. when its parent is ``node``.  ``node.right`` must not
        be None; otherwise AttributeError is raised.
        """
        successor = node.right
        successor_parent = None
        while successor.left is not None:
            successor_parent = successor
            successor = successor.left
        return successor, successor_parent

    def find_predecessor(self, node):
        """Return ``(predecessor, predecessor_parent)`` for ``node``.

        The predecessor is the rightmost node of ``node.left``.
        ``predecessor_parent`` is None when the predecessor is ``node.left``
        itself, i.e. when its parent is ``node``.  ``node.left`` must not be
        None; otherwise AttributeError is raised.
        """
        predecessor = node.left
        predecessor_parent = None
        while predecessor.right is not None:
            predecessor_parent = predecessor
            predecessor = predecessor.right
        return predecessor, predecessor_parent

    def contains(self, data):
        """Return True if some node's data equals ``data``, else False."""
        return self._find_with_parent(data)[0] is not None

    def get_node(self, data):
        """Return the first node on the search path whose data equals ``data``.

        Returns None when the tree holds no such node.
        """
        return self._find_with_parent(data)[0]

    def size(self):
        """Return the number of nodes in the tree (0 when it is empty)."""
        return len(self.to_array())

    def to_array(self):
        """Return the tree's data values in PRE-order (node, left, right).

        This is not an in-order traversal, so the result is generally not
        sorted.  For example, inserting 5, 2, 8 yields ``[5, 2, 8]``.  The
        values and nodes are also cached in ``self.tree_list`` and
        ``self.tree_nodes``; the returned list is ``self.tree_list`` itself.
        """
        self.tree_list = []
        self.tree_nodes = []
        self.search(self.tree)
        return self.tree_list

    def search(self, node):
        """Append the pre-order data/nodes of the subtree rooted at ``node``.

        Results are appended to ``self.tree_list`` / ``self.tree_nodes``
        without resetting them first.  An explicit stack replaces recursion,
        so degenerate (list-shaped) trees deeper than Python's recursion
        limit no longer raise RecursionError.
        """
        stack = [node] if node is not None else []
        while stack:
            current = stack.pop()
            self.tree_list.append(current.data)
            self.tree_nodes.append(current)
            # Push right before left so the left subtree is visited first.
            if current.right is not None:
                stack.append(current.right)
            if current.left is not None:
                stack.append(current.left)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import unittest | |
from binary_search_tree import BinarySearchTree | |
class TestBinarySearchTree(unittest.TestCase):
    """Unit tests for BinarySearchTree.

    ``self.bst`` starts empty.  ``self.mytree`` is pre-populated by inserting
    5, 2, 3, 1, 4, 0, 2.5, 8, 9, 6, 4.5 in that order.  Expected lists are
    the pre-order output of ``to_array()``.  TestCase assertion methods are
    used instead of bare ``assert`` so checks survive ``python -O`` and
    report the mismatching values on failure.
    """

    def setUp(self):
        self.bst = BinarySearchTree()
        self.node_5 = self.bst.init_node(-5)
        self.node_4 = self.bst.init_node(-4)
        self.node_3 = self.bst.init_node(-3)
        self.node_2 = self.bst.init_node(-2)
        self.node_1 = self.bst.init_node(-1)
        self.node0 = self.bst.init_node(0)
        self.node1 = self.bst.init_node(1)
        self.node2 = self.bst.init_node(2)
        self.node3 = self.bst.init_node(3)
        self.node4 = self.bst.init_node(4)
        self.node5 = self.bst.init_node(5)
        self.mytree = BinarySearchTree()
        for value in (5, 2, 3, 1, 4, 0, 2.5, 8, 9, 6, 4.5):
            self.mytree.insert_data(value)

    def test_init_node(self):
        node1 = self.bst.init_node(5)
        node2 = self.bst.init_node(4)
        self.assertEqual(node1.data, 5)
        self.assertIsNone(node1.left)
        self.assertIsNone(node1.right)
        self.assertEqual(node2.data, 4)
        self.assertIsNone(node2.left)
        self.assertIsNone(node2.right)

    def test_insert(self):
        self.assertIsNone(self.bst.tree)
        self.bst.insert(self.node0)
        self.assertEqual(self.bst.tree.data, self.node0.data)
        self.bst.insert(self.node3)
        self.assertEqual(self.bst.tree.right.data, self.node3.data)
        self.bst.insert(self.node5)
        self.assertEqual(self.bst.tree.right.right.data, self.node5.data)
        self.bst.insert(self.node4)
        self.assertEqual(self.bst.tree.right.right.left.data, self.node4.data)
        self.bst.insert(self.node2)
        self.assertEqual(self.bst.tree.right.left.data, self.node2.data)
        self.bst.insert(self.node1)
        self.assertEqual(self.bst.tree.right.left.left.data, self.node1.data)
        self.bst.insert(self.node_4)
        self.assertEqual(self.bst.tree.left.data, self.node_4.data)
        self.bst.insert(self.node_5)
        self.assertEqual(self.bst.tree.left.left.data, self.node_5.data)
        self.bst.insert(self.node_2)
        self.assertEqual(self.bst.tree.left.right.data, self.node_2.data)
        self.bst.insert(self.node_3)
        self.assertEqual(self.bst.tree.left.right.left.data, self.node_3.data)
        self.bst.insert(self.node_1)
        self.assertEqual(self.bst.tree.left.right.right.data, self.node_1.data)

    def test_remove(self):
        self.assertTrue(self.mytree.contains(0))
        self.assertEqual(self.mytree.to_array(),
                         [5, 2, 1, 0, 3, 2.5, 4, 4.5, 8, 6, 9])
        self.mytree.remove(0)
        self.assertFalse(self.mytree.contains(0))
        self.assertEqual(self.mytree.to_array(),
                         [5, 2, 1, 3, 2.5, 4, 4.5, 8, 6, 9])
        self.assertTrue(self.mytree.contains(4.5))
        self.mytree.remove(4.5)
        self.assertFalse(self.mytree.contains(4.5))
        self.assertEqual(self.mytree.to_array(),
                         [5, 2, 1, 3, 2.5, 4, 8, 6, 9])
        self.mytree.insert_data(0)
        self.assertEqual(self.mytree.to_array(),
                         [5, 2, 1, 0, 3, 2.5, 4, 8, 6, 9])
        self.assertTrue(self.mytree.contains(1))
        self.mytree.remove(1)
        self.assertFalse(self.mytree.contains(1))
        self.assertTrue(self.mytree.contains(0))
        self.assertEqual(self.mytree.to_array(),
                         [5, 2, 0, 3, 2.5, 4, 8, 6, 9])

        # RESET: remove nodes with two children, including the root.
        self.setUp()
        self.assertEqual(self.mytree.to_array(),
                         [5, 2, 1, 0, 3, 2.5, 4, 4.5, 8, 6, 9])
        self.mytree.remove(5)
        self.assertIn(self.mytree.to_array(),
                      ([6, 2, 1, 0, 3, 2.5, 4, 4.5, 8, 9],
                       [4.5, 2, 1, 0, 3, 2.5, 4, 8, 6, 9]))
        self.mytree.remove(6)
        self.assertEqual(self.mytree.to_array(),
                         [4.5, 2, 1, 0, 3, 2.5, 4, 8, 9])
        self.mytree.remove(4.5)
        self.assertEqual(self.mytree.to_array(),
                         [4, 2, 1, 0, 3, 2.5, 8, 9])
        self.mytree.remove(4)
        self.assertIn(self.mytree.to_array(),
                      ([3, 2, 1, 0, 2.5, 8, 9],
                       [8, 2, 1, 0, 3, 2.5, 9]))
        self.mytree.remove(3)
        self.assertIn(self.mytree.to_array(),
                      ([2.5, 2, 1, 0, 8, 9],
                       [8, 2, 1, 0, 2.5, 9]))
        self.mytree.remove(2.5)
        self.assertEqual(self.mytree.to_array(), [8, 2, 1, 0, 9])
        self.mytree.remove(8)
        self.assertIn(self.mytree.to_array(),
                      ([9, 2, 1, 0], [2, 1, 0, 9]))
        self.mytree.remove(2)
        self.assertIn(self.mytree.to_array(), ([9, 1, 0], [1, 0, 9]))
        self.mytree.remove(1)
        self.assertEqual(self.mytree.to_array(), [9, 0])
        self.mytree.remove(0)
        self.assertEqual(self.mytree.to_array(), [9])

        # RESET: strip the left side, then remove a root with one child.
        self.setUp()
        for value in (4.5, 4, 2.5, 3, 0, 1, 2, 6):
            self.mytree.remove(value)
        self.assertEqual(self.mytree.to_array(), [5, 8, 9])
        self.mytree.remove(5)
        self.assertEqual(self.mytree.to_array(), [8, 9])
        self.mytree.remove(8)
        self.assertEqual(self.mytree.to_array(), [9])

    def test_contains(self):
        for present in (5, 4, 3, 2, 1, 0, 9):
            self.assertTrue(self.mytree.contains(present))
        for absent in (80, -5, -10, 230, 340):
            self.assertFalse(self.mytree.contains(absent))

    def test_get_node(self):
        acquired_node0 = self.mytree.get_node(2)
        acquired_node1 = self.mytree.get_node(8)
        acquired_node2 = self.mytree.get_node(4.5)
        acquired_node3 = self.mytree.get_node(45)
        self.assertEqual(acquired_node0.data, 2)
        self.assertEqual(acquired_node0.left.data, 1)
        self.assertEqual(acquired_node0.right.data, 3)
        self.assertEqual(acquired_node1.data, 8)
        self.assertEqual(acquired_node1.left.data, 6)
        self.assertEqual(acquired_node1.right.data, 9)
        self.assertEqual(acquired_node2.data, 4.5)
        self.assertIsNone(acquired_node2.left)
        self.assertIsNone(acquired_node2.right)
        self.assertIsNone(acquired_node3)

    def test_to_array(self):
        self.assertEqual(self.bst.to_array(), [])
        for value in (5, 2, 8, 6, 1, 3, 0, 4, 9):
            self.bst.insert_data(value)
        self.assertEqual(self.bst.to_array(), [5, 2, 1, 0, 3, 4, 8, 6, 9])
        self.assertEqual(self.mytree.to_array(),
                         [5, 2, 1, 0, 3, 2.5, 4, 4.5, 8, 6, 9])

    def test_size(self):
        self.assertEqual(self.bst.size(), 0)
        self.assertEqual(self.mytree.size(), 11)
        self.mytree.remove(4.5)
        self.assertEqual(self.mytree.size(), 10)
        self.mytree.remove(4)
        self.assertEqual(self.mytree.size(), 9)
        self.mytree.insert_data(30)
        self.assertEqual(self.mytree.size(), 10)


if __name__ == "__main__":
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment