Created
December 4, 2014 23:02
-
-
Save blogle/30e0e88ceb963f6557f7 to your computer and use it in GitHub Desktop.
bvka_nn.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import networkx as nx | |
import heapq | |
from collections import defaultdict | |
def distance(u, v):
    """Return the squared Euclidean distance between two points.

    The square root is deliberately not taken: the result is only used to
    rank candidates against each other, and squaring preserves ordering.

    Args:
        u (array-like): first point.
        v (array-like): second point, broadcastable against ``u``.

    Returns:
        numpy scalar: sum of the squared coordinate differences.
    """
    # asarray lets callers pass plain lists/tuples as well as ndarrays.
    diff = np.asarray(u) - np.asarray(v)
    return np.sum(diff ** 2)
class KDTree(object):
    """Recursive k-d tree supporting nearest-neighbour queries.

    Every instance is a node of the tree; an instance built from no points
    is an empty leaf whose attributes are all ``None``.
    """

    def __init__(self, data, index=None, depth=0):
        """
        Creates a recursive space partitioning DataStructure where each node
        splits the dimension at the median of that axis.

        Args:
            data (np.array): dataset with shape n, k (n obs, k dim).

        Optional
            index (np.array): Row indices of ``data`` that belong to this
                              (sub)tree. If left empty, every row is used.
            depth (int)     : This determines the axis in which to first
                              partition on e.g 0 -> x, 1 -> y, 2 -> z

        Notes:
            http://en.wikipedia.org/wiki/K-d_tree
            http://en.wikipedia.org/wiki/K-d_tree#mediaviewer/File:KDTree-animation.gif
        """
        # Build index at top level
        if index is None:
            index = np.arange(data.shape[0])

        self.n = None         # number of points stored in this subtree
        self.k = None         # dimensionality of the points
        self.idx = None       # row index (into data) of this node's point
        self.node = None      # coordinates of this node's point
        self.axis = None      # axis this node splits on
        self.left = None      # subtree of points sorted before the median
        self.right = None     # subtree of points sorted after the median
        self.children = None  # row indices of every point in this subtree

        self._build(data, index, depth)

    def _build(self, data, index, depth):
        """Recursively builds the child nodes of the KDTree"""
        index = np.asarray(index)
        points = data[index]

        # Nothing to partition: leave this instance as an empty leaf
        if not points.size:
            return

        # Store the dimensions of the data and the axis to partition on
        self.n, self.k = points.shape
        self.axis = depth % self.k

        # indices of all points beneath (and including) this node
        self.children = index

        # Row indices ordered along the current axis, and the median position.
        # Floor division keeps the position an int on Python 3.
        sort_ax = index[np.argsort(points[:, self.axis])].astype(int)
        partition = sort_ax.size // 2

        # Node index and data
        self.idx = sort_ax[partition]
        self.node = data[self.idx]

        # Build the branches, partitioning on the next axis
        self.left = KDTree(data, sort_ax[:partition], depth + 1)
        self.right = KDTree(data, sort_ax[partition + 1:], depth + 1)

    def near_branch(self, point):
        """Returns the branch nearest the input point"""
        if point[self.axis] < self.node[self.axis]:
            return self.left
        return self.right

    def far_branch(self, point):
        """Returns the branch furthest the input point"""
        if self.near_branch(point) is self.left:
            return self.right
        return self.left

    def orthogonal_dist(self, point):
        """Computes the squared distance from a point to the partition.

        The splitting plane is axis-aligned, so only the coordinate on
        ``self.axis`` contributes. The value is squared to stay comparable
        with the squared distances returned by ``distance``.
        """
        return (point[self.axis] - self.node[self.axis]) ** 2

    def query(self, point, best=None):
        """Find the nearest neighbor of point in KDTree

        Args:
            point (np.array): query coordinates with k entries.
            best (tuple): best ``(index, coordinates)`` found so far; used
                          by the recursion, callers normally omit it.

        Returns:
            tuple: ``(index, coordinates)`` of the nearest stored point, or
            ``best`` unchanged (None by default) if the tree is empty.
        """
        # Dead end backtrack up the tree
        if self.node is None:
            return best

        # Adopt the current node if it is the first or the closest seen
        if best is None or distance(self.node, point) < distance(best[1], point):
            best = (self.idx, self.node)

        # continue traversing the tree
        best = self.near_branch(point).query(point, best)

        # traverse the away branch only if the splitting plane is closer
        # than the best match, i.e. a closer point could lie beyond it
        if self.orthogonal_dist(point) < distance(best[1], point):
            best = self.far_branch(point).query(point, best)

        return best

    def query_subset(self, point, subset):
        """Find the nearest neighbor of point in subset

        Args:
            point (np.array): query coordinates with k entries.
            subset (list): row indices that are allowed as answers.

        Returns:
            tuple: ``(index, coordinates)`` of the nearest allowed point, or
            None if the tree is empty or no allowed point exists.
        """
        if self.n is None:
            return None

        # Indicator vector: 1 where the row index is an allowed answer
        subset_vec = np.zeros(self.n)
        subset_vec[subset] = 1
        return self._query_subset(point, subset_vec, None)

    def _query_subset(self, point, subset, best=None):
        """Recursively implements constrained nearest neighbor search

        ``subset`` is an indicator vector addressed by row index; non-zero
        entries mark the points that may be returned.
        """
        # Dead end backtrack up the tree
        if self.node is None:
            return best

        # No allowed point anywhere beneath this node: nothing can improve
        if not np.any(subset[self.children]):
            return best

        # if this node is in the subset, try to update best
        if subset[self.idx]:
            if best is None or distance(self.node, point) < distance(best[1], point):
                best = (self.idx, self.node)

        near = self.near_branch(point)
        far = self.far_branch(point)

        best = near._query_subset(point, subset, best)

        # validate best, by ensuring a closer point doesn't exist just beyond
        # the partition; if best has yet to be found the far branch is the
        # only place left to look
        if best is None or self.orthogonal_dist(point) < distance(best[1], point):
            best = far._query_subset(point, subset, best)

        return best
class PriorityQueue(object):
    """Min-cost priority queue built on ``heapq``."""

    def __init__(self):
        """
        Queue implementing highest-priority-in first-out.

        Note:
            Priority is cost based, therefore smaller values are prioritized
            over larger values.
        """
        self._queue = []
        # Insertion counter stored in each entry: breaks priority ties in
        # insertion order so the items themselves are never compared.
        self._index = 0

    def push(self, item, priority):
        """
        Push an item into the queue.

        Args:
            item (obj): Item to be stored in the queue
            priority (Num): Priority in which item will be retrieved from the queue
        """
        heapq.heappush(self._queue, (priority, self._index, item))
        self._index += 1

    def pop(self):
        """
        Removes the highest priority item from the queue

        Returns:
            obj: item with highest priority

        Raises:
            IndexError: if the queue is empty
        """
        return heapq.heappop(self._queue)[-1]

    def merge(self, other):
        """
        Given another queue, consumes each item in it
        and pushes the item and its priority into its own queue

        Args:
            other (PriorityQueue): Queue to be merged; it is left empty
        """
        while other._queue:
            priority, _, item = heapq.heappop(other._queue)
            self.push(item, priority)

    def top(self):
        """
        Allows peek at top item in the queue without removing it

        Returns:
            obj: if the queue is not empty otherwise None
        """
        if not self._queue:
            return None
        return self._queue[0][-1]
def bvka_mst_edges(G, assume_connected=False, pos='coords'):
    """Return the edges of a Euclidean minimum spanning tree (Boruvka style).

    Candidate edges come from nearest-neighbour searches over the node
    coordinates, not from the edges of ``G``; the graph only supplies the
    nodes and their positions.

    Args:
        G (networkx graph): graph whose nodes carry a coordinate attribute.
        assume_connected (bool): currently unused; kept for compatibility.
        pos (str): name of the node attribute holding the coordinates.

    Returns:
        list: ``(u, v)`` node pairs, one per tree edge. Empty when the graph
        has fewer than two nodes.

    Raises:
        ValueError: if a node lacks the ``pos`` attribute.
    """
    coords = nx.get_node_attributes(G, pos)
    nodes = list(G.nodes(data=False))

    missing = [node for node in nodes if node not in coords]
    if missing:
        raise ValueError(
            "nodes missing the %r attribute: %r" % (pos, missing))

    n = len(nodes)
    # Zero or one node: there is no edge to find (and no neighbour to query)
    if n < 2:
        return []

    # Work on row indices 0..n-1 so any hashable node label is supported;
    # row i of points belongs to nodes[i].
    points = np.vstack([coords[node] for node in nodes])
    V = set(range(n))

    kdtree = KDTree(points)
    subgraphs = nx.utils.UnionFind()
    queues = defaultdict(PriorityQueue)

    # Seed each singleton component with its node's nearest neighbour
    for v in V:
        # Todo restrict this further to connected edges
        vm, _ = kdtree.query_subset(points[v], list(V - {v}))
        vm = int(vm)
        queues[subgraphs[v]].push((v, vm), distance(points[v], points[vm]))

    Et = []
    while len(Et) != n - 1:
        Ep = PriorityQueue()

        for C in {subgraphs[u] for u in V}:
            # Compare roots, not direct parents: paths in the union-find
            # are not guaranteed to be fully compressed.
            component = {u for u in V if subgraphs[u] == C}
            disjoint_nodes = list(V - component)

            (v, vm) = queues[C].top()
            # Entries are refreshed lazily: a cheapest entry pointing inside
            # the component is stale and is replaced by the nearest outsider.
            while vm in component:
                queues[C].pop()
                um, _ = kdtree.query_subset(points[v], disjoint_nodes)
                um = int(um)
                queues[C].push((v, um), distance(points[v], points[um]))
                (v, vm) = queues[C].top()

            dm = distance(points[v], points[vm])
            Ep.push((v, vm, dm), dm)

        while Ep._queue:
            (um, vm, dm) = Ep.pop()
            component_i, component_j = subgraphs[um], subgraphs[vm]
            if component_i != component_j:
                # add the edge and merge the queues
                Et.append((um, vm))
                subgraphs.union(um, vm)
                # whichever root survived the union keeps the merged queue
                if component_i == subgraphs[um]:
                    major, minor = component_i, component_j
                else:
                    minor, major = component_i, component_j
                queues[major].merge(queues[minor])
                del queues[minor]

    # Translate row indices back to the caller's node labels
    return [(nodes[u], nodes[v]) for u, v in Et]
@patrafter1999
5 years later I am stumbling across this message :(
To you and anyone else coming across this gist, please feel free to use the above code as you wish
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi Blogle,
This is awesome work. Much cleaner than scipy.spatial.KDTree. I haven't tested this code yet. Hopefully it's better than the scipy version, as you showed in the comparison graph. I'm trying to modify your code a little so that I can remove some of the nodes as I wish. Do you have any license on this code?
Cheers,
Sean