Skip to content

Instantly share code, notes, and snippets.

@ispapadakis
Forked from cocodrips/union_find.py
Last active February 16, 2018 23:21
Show Gist options
  • Save ispapadakis/004dfcc0f9081b7f0e341929033008ed to your computer and use it in GitHub Desktop.
Save ispapadakis/004dfcc0f9081b7f0e341929033008ed to your computer and use it in GitHub Desktop.
Union Find in Python with Path Compression and Balancing
'''
Implement a Disjoint Set Data Structure Using Path Compression and Balancing
'''
class disjointSet:
    """Disjoint-set (union-find) forest over the integers 0 .. n-1.

    Uses union by rank ("balancing") and full path compression.

    Attributes:
        parent: parent[i] is the parent of element i; a root is its own parent.
        rank: per-root rank used to decide which root becomes the parent in
              union(); only meaningful for roots.
    """

    def __init__(self, n):
        """Create n singleton sets {0}, {1}, ..., {n-1}."""
        self.parent = list(range(n))
        self.rank = [0] * n

    def find(self, element):
        """Return the root (representative) of the set containing ``element``.

        Every node on the path from ``element`` to its root is re-pointed
        directly at the root (full path compression). This is done with two
        loops instead of recursion, avoiding Python's per-call overhead and
        its recursion limit on long parent chains.
        """
        parent = self.parent
        # First pass: walk up to the root.
        root = element
        while parent[root] != root:
            root = parent[root]
        # Second pass: point every node on the path straight at the root.
        while element != root:
            nxt = parent[element]
            parent[element] = root
            element = nxt
        return root

    def union(self, x, y):
        """Merge the sets containing ``x`` and ``y``; no-op if already joined.

        The root with the lower rank is attached under the root with the
        higher rank. On a tie, ``x``'s root becomes the parent and its rank
        grows by one.
        """
        x = self.find(x)
        y = self.find(y)
        if x == y:
            return
        if self.rank[x] < self.rank[y]:
            self.parent[x] = y
        else:
            self.parent[y] = x
            if self.rank[x] == self.rank[y]:
                self.rank[x] += 1
if __name__ == "__main__":
    # Demo: apply random unions to a small universe, then report the union
    # operations, the parent links, the roots, and how tall the trees are.
    from random import randint, seed

    seed(0)  # fixed seed so the output can be compared with `expected` below
    elementCount = 70
    unitedCount = 45
    dset = disjointSet(elementCount)

    united = []
    for _ in range(unitedCount):
        # Draw a pair of distinct elements to merge.
        i = j = 0
        while i == j:
            i = randint(0, elementCount - 1)
            j = randint(0, elementCount - 1)
        united.append([i, j])
        dset.union(i, j)

    print('Union Operations')
    for u in united:
        print('U({:3d},{:3d})'.format(*u), end='\t')
    print()

    print('Resulting Element Parents')
    for i, p in enumerate(dset.parent):
        print('P({:3d}) = {:3d}'.format(i, p), end='\t')
    print()

    print('Resulting Roots')
    # height[i] = number of parent links from non-root element i to its root.
    # The links are walked read-only (not via find()) so that path
    # compression does not flatten the trees being measured.
    height = dict()
    for i, p in enumerate(dset.parent):
        if i == p:
            print(i, end=' ')
            continue
        depth = 0
        elem = i
        while dset.parent[elem] != elem:
            depth += 1
            elem = dset.parent[elem]
        height[i] = depth
    print()

    print('Disjoint Set Tree Heights')
    countHeights = dict()
    for i in height:
        countHeights.setdefault(height[i], []).append(i)
    # Report heights in ascending order, independent of which element
    # happened to be encountered first at each height.
    for h in sorted(countHeights):
        print('Height {:2d}: {}'.format(h, countHeights[h]))
    print('Not Bad for Union Method Heuristic !')
    print('Expecting')
    expected = '''
Disjoint Set Tree Heights
Height 1: [0, 1, 4, 5, 8, 9, 11, 12, 13, 18, 19, 26, 27, 28, 36, 38, 41, 49, 50, 53, 57, 60, 61, 62, 63, 66, 67, 68, 69]
Height 2: [7, 10, 23, 24, 30, 33, 35, 37, 39, 45, 55, 64]
Height 3: [15, 40]
'''
    print(expected)
@ispapadakis
Copy link
Author

Warning

The recursive definition of `find` above can lead to long processing times in Python.
Use this iterative alternative instead:

class disjointSet:
    """Disjoint-set (union-find) over 0 .. n-1 with union by rank and full
    path compression, implemented without recursion."""

    __slots__ = ['parent', 'rank']

    def __init__(self, n):
        """Create n singleton sets {0}, {1}, ..., {n-1}."""
        self.parent = list(range(n))
        self.rank = [0] * n

    def find(self, element):
        """Return the root of the set containing ``element``.

        Two passes: first locate the root, then re-point every node on the
        path at that root. Re-pointing only the starting node would leave
        intermediate nodes uncompressed, so later finds on them would walk
        the same path again.
        """
        p = self.parent
        root = element
        while root != p[root]:
            root = p[root]
        while element != root:
            nxt = p[element]
            p[element] = root
            element = nxt
        return root

    def union(self, x, y):
        """Merge the sets containing ``x`` and ``y``; no-op if already joined.

        The lower-rank root goes under the higher-rank root. On a tie,
        ``x``'s root becomes the parent and its rank is incremented.
        """
        x = self.find(x)
        y = self.find(y)
        if x == y:
            return

        if self.rank[x] < self.rank[y]:
            self.parent[x] = y
        else:
            self.parent[y] = x
            if self.rank[x] == self.rank[y]:
                self.rank[x] += 1

Or use another programming language!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment