Skip to content

Instantly share code, notes, and snippets.

@botcs
Created February 19, 2024 17:18
Show Gist options
  • Save botcs/9d290eea4738aea30b2eb98cb32b9191 to your computer and use it in GitHub Desktop.
Save botcs/9d290eea4738aea30b2eb98cb32b9191 to your computer and use it in GitHub Desktop.
import torch
class MicroCluster:
    """A labelled micro-cluster summarised by its statistics (N, LS, SS).

    Only the point count ``N``, the element-wise linear sum ``LS`` and the
    element-wise sum of squares ``SS`` are stored. Points can therefore be
    added, and clusters merged, without keeping the raw points.
    """

    def __init__(self, x, label, max_radius=0.1):
        """Start a cluster that contains the single point ``x``.

        Args:
            x: Point tensor; all statistics are computed element-wise on it.
                Integer/bool tensors are promoted to the default floating
                dtype so that float points can later be accumulated in place.
            label: Label attached to this cluster.
            max_radius: Distance threshold used by :meth:`can_accept`.
        """
        if not (x.is_floating_point() or x.is_complex()):
            # ``LS += float_tensor`` raises when LS is an integer tensor, so
            # promote once here; add_point/merge then work for any input dtype.
            x = x.to(torch.get_default_dtype())
        self.N = 1
        self.LS = x.clone()  # private copy: it is updated in place later
        self.SS = x.pow(2)
        self.label = label
        self.max_radius = max_radius

    def add_point(self, x):
        """Absorb point ``x`` into the cluster statistics (in place)."""
        self.N += 1
        self.LS += x
        self.SS += x.pow(2)

    @property
    def centroid(self):
        """Element-wise mean of the absorbed points (``LS / N``)."""
        return self.LS / self.N

    @property
    def radius(self):
        """Mean over elements of the per-element standard deviation.

        ``SS / N - centroid**2`` is the per-element (population) variance.
        For tight clusters, floating-point cancellation can make it slightly
        negative, and ``sqrt`` would then return NaN. It is therefore clamped
        at zero.
        """
        variance = self.SS / self.N - self.centroid.pow(2)
        return torch.sqrt(variance.clamp(min=0)).mean()

    def can_accept(self, x):
        """Return whether ``x`` lies within ``max_radius`` of the centroid.

        The distance is ``torch.norm(x - centroid)``, i.e. the 2-norm over all
        elements. The result is a 0-dim boolean tensor.
        """
        return torch.norm(x - self.centroid) <= self.max_radius

    def merge(self, mc):
        """Add the statistics of ``mc`` into this cluster.

        ``mc`` itself is not modified, and this cluster keeps its own label.
        """
        self.N += mc.N
        self.LS += mc.LS
        self.SS += mc.SS
class MClassification:
    """Online classifier built from labelled micro-clusters.

    A query point gets the label of the micro-cluster whose centroid is
    nearest to it. The model is then updated: either the point is absorbed
    into that cluster, or (if it is too far away) two existing clusters are
    merged and a new cluster is opened for the point.

    NOTE(review): ``max_clusters`` is stored but never read by this class, so
    the number of clusters is not capped by it.
    """

    def __init__(self, max_radius=0.1, max_clusters=100):
        self.max_radius = max_radius  # passed to every MicroCluster created here
        self.max_clusters = max_clusters
        self.micro_clusters = []

    def fit_initial(self, X, labels):
        """Create one micro-cluster per ``(X item, label)`` pair.

        Items are paired with ``zip``, so any surplus in the longer iterable
        is ignored. Presumably ``X`` is a 2-D tensor whose rows are points;
        confirm against the callers.
        """
        self.micro_clusters.extend(
            MicroCluster(sample, lbl, self.max_radius)
            for sample, lbl in zip(X, labels)
        )

    def find_farthest_mc_from_point_with_label(self, x, label):
        """Return the index of the ``label`` cluster whose centroid is farthest from ``x``.

        Ties resolve to the lowest index. Returns ``None`` when no cluster
        carries ``label``.
        """
        same_label = (
            (idx, cluster)
            for idx, cluster in enumerate(self.micro_clusters)
            if cluster.label == label
        )
        best_idx, best_dist = None, -1
        for idx, cluster in same_label:
            gap = torch.norm(x - cluster.centroid)
            if gap > best_dist:
                best_idx, best_dist = idx, gap
        return best_idx

    def find_closest_mc_to_mc(self, mc_index):
        """Return the index of the cluster whose centroid is nearest to cluster ``mc_index``.

        The cluster at ``mc_index`` itself is skipped. Ties resolve to the
        lowest index. Returns ``None`` when there is no other cluster.
        """
        target = self.micro_clusters[mc_index]
        best_idx, best_dist = None, float('inf')
        for idx, cluster in enumerate(self.micro_clusters):
            if idx == mc_index:
                continue
            gap = torch.norm(target.centroid - cluster.centroid)
            if gap < best_dist:
                best_idx, best_dist = idx, gap
        return best_idx

    def merge_clusters(self, index1, index2):
        """Fold cluster ``index2`` into cluster ``index1``, then drop ``index2`` from the list."""
        survivor = self.micro_clusters[index1]
        survivor.merge(self.micro_clusters[index2])
        del self.micro_clusters[index2]

    def predict_and_update(self, x):
        """Predict a label for ``x`` and update the clusters online.

        Returns ``-1`` when there are no clusters yet. Otherwise returns the
        label of the cluster nearest to ``x`` (ties resolve to the lowest
        index).
        """
        if not self.micro_clusters:
            return -1  # Placeholder for no clusters case

        gaps = [torch.norm(x - cluster.centroid) for cluster in self.micro_clusters]
        nearest = self.micro_clusters[min(range(len(gaps)), key=gaps.__getitem__)]

        if nearest.can_accept(x):
            nearest.add_point(x)
            return nearest.label

        # x does not fit the nearest cluster. Take the cluster with the
        # predicted label that lies farthest from x and absorb its closest
        # neighbour into it (the neighbour's label is not checked). Then open
        # a new cluster for x with the predicted label.
        far_idx = self.find_farthest_mc_from_point_with_label(x, nearest.label)
        if far_idx is not None:
            partner_idx = self.find_closest_mc_to_mc(far_idx)
            if partner_idx is not None:
                self.merge_clusters(far_idx, partner_idx)
        self.micro_clusters.append(MicroCluster(x, nearest.label, self.max_radius))
        return nearest.label
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment