Skip to content

Instantly share code, notes, and snippets.

@botcs
Created February 19, 2024 17:10
Show Gist options
  • Save botcs/9d61645842c28e4c59902f92c273f7be to your computer and use it in GitHub Desktop.
Save botcs/9d61645842c28e4c59902f92c273f7be to your computer and use it in GitHub Desktop.
import torch
class MicroCluster:
    """Labelled micro-cluster summarised by its count, linear sum and squared sum.

    Only the sufficient statistics are kept:

    * ``N``  -- number of points absorbed,
    * ``LS`` -- element-wise sum of the points,
    * ``SS`` -- element-wise sum of the squared points,

    so points can be added and clusters merged without storing raw samples.

    Args:
        x: First point of the cluster, a torch tensor (presumably a 1-D
            feature vector -- TODO confirm with callers). Integer tensors are
            promoted to float so that later float points can be accumulated
            in place.
        label: Class label carried by the cluster.
        max_radius: Largest Euclidean distance from the centroid at which
            :meth:`can_accept` still admits a new point.
    """

    def __init__(self, x, label, max_radius=0.1):
        if not torch.is_floating_point(x):
            # In-place ``+=`` in add_point/merge cannot cast a float result
            # back into an integer tensor, so keep the statistics in float.
            x = x.float()
        self.N = 1
        self.LS = x.clone()  # clone so in-place updates never alias the caller's tensor
        self.SS = x.pow(2)
        self.label = label
        self.max_radius = max_radius

    def add_point(self, x):
        """Absorb point ``x`` by updating the count and the running sums in place."""
        self.N += 1
        self.LS += x
        self.SS += x.pow(2)

    @property
    def centroid(self):
        """Mean of all absorbed points (element-wise ``LS / N``)."""
        return self.LS / self.N

    @property
    def radius(self):
        """Mean over dimensions of the per-dimension standard deviation.

        Variance is computed as ``SS / N - centroid**2``. This can come out
        slightly negative through floating-point cancellation, so it is
        clamped at zero to keep the square root from producing NaN.
        """
        variance = self.SS / self.N - self.centroid.pow(2)
        return torch.sqrt(variance.clamp(min=0)).mean()

    def can_accept(self, x):
        """Return a boolean tensor: is ``x`` within ``max_radius`` of the centroid?

        Uses the Euclidean distance to the centroid, not :attr:`radius`.
        """
        return torch.norm(x - self.centroid) <= self.max_radius

    def merge(self, mc):
        """Fold another micro-cluster's statistics into this one.

        The other cluster's ``label`` is not checked or used.
        """
        self.N += mc.N
        self.LS += mc.LS
        self.SS += mc.SS
class MClassification:
    """Online nearest-micro-cluster classifier with a cap on the cluster count.

    Every micro-cluster carries a class label. A query is labelled with the
    label of the micro-cluster whose centroid is nearest. The query is then
    either absorbed into that cluster (if within its ``max_radius``) or used
    to seed a new cluster carrying the predicted label.

    Args:
        max_radius: Acceptance radius passed to every created micro-cluster.
        max_clusters: Cluster count at which :meth:`predict_and_update` starts
            merging same-label clusters before creating a new one.
    """

    def __init__(self, max_radius=0.1, max_clusters=100):
        self.max_radius = max_radius
        self.max_clusters = max_clusters
        self.micro_clusters = []

    def fit_initial(self, X, labels):
        """Seed one micro-cluster per ``(sample, label)`` pair.

        ``max_clusters`` is not enforced here.
        """
        for x, label in zip(X, labels):
            self.micro_clusters.append(MicroCluster(x, label, self.max_radius))

    def find_closest_clusters(self):
        """Return indices ``(i, j)`` with ``i < j`` of the two nearest centroids.

        Returns ``(None, None)`` when fewer than two clusters exist.
        """
        clusters = self.micro_clusters
        min_distance = float('inf')
        pair = (None, None)
        for i, first in enumerate(clusters):
            for j in range(i + 1, len(clusters)):
                distance = torch.norm(first.centroid - clusters[j].centroid)
                if distance < min_distance:
                    min_distance = distance
                    pair = (i, j)
        return pair

    def find_farthest_clusters_from_point(self, x, label):
        """Return indices of the two ``label`` clusters farthest from ``x``.

        Returns ``(farthest, second_farthest)``. ``second_farthest`` is None
        when only one cluster has ``label``, and ``(None, None)`` is returned
        when none do.
        """
        distances = [
            (i, torch.norm(x - mc.centroid))
            for i, mc in enumerate(self.micro_clusters)
            if mc.label == label
        ]
        if not distances:
            return None, None
        distances.sort(key=lambda d: d[1], reverse=True)
        second = distances[1][0] if len(distances) > 1 else None
        return distances[0][0], second

    def merge_clusters(self, index1, index2):
        """Merge cluster ``index2`` into cluster ``index1`` and remove ``index2``.

        Deleting ``index2`` shifts every later index down by one.
        """
        self.micro_clusters[index1].merge(self.micro_clusters[index2])
        del self.micro_clusters[index2]

    def predict_and_update(self, x):
        """Predict the label of ``x`` and fold ``x`` into the model.

        Returns:
            The label of the nearest micro-cluster, or ``-1`` when the model
            holds no clusters yet (``x`` is then not stored).
        """
        if not self.micro_clusters:
            return -1  # Placeholder for no clusters case
        distances = [torch.norm(x - mc.centroid) for mc in self.micro_clusters]
        # First index attaining the minimum distance.
        nearest_mc_index = min(range(len(distances)), key=distances.__getitem__)
        nearest_mc = self.micro_clusters[nearest_mc_index]
        if nearest_mc.can_accept(x):
            nearest_mc.add_point(x)
        else:
            # Free a slot *before* appending, so the list ends at no more than
            # max_clusters whenever a same-label pair exists to merge.
            if len(self.micro_clusters) >= self.max_clusters:
                farthest_i, farthest_j = self.find_farthest_clusters_from_point(x, nearest_mc.label)
                if farthest_i is not None and farthest_j is not None:
                    self.merge_clusters(farthest_i, farthest_j)
            self.micro_clusters.append(MicroCluster(x, nearest_mc.label, self.max_radius))
        return nearest_mc.label
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment