Created
February 19, 2024 17:10
-
-
Save botcs/9d61645842c28e4c59902f92c273f7be to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
class MicroCluster:
    """Group of points summarised by running statistics.

    Only the point count (``N``), the per-dimension linear sum (``LS``) and
    the per-dimension sum of squares (``SS``) are kept, so points can be
    added and clusters merged without storing the points themselves.
    """

    def __init__(self, x, label, max_radius=0.1):
        """Start a cluster from a single point.

        Args:
            x: Tensor holding the first point of the cluster.
            label: Label attached to the cluster.
            max_radius: Largest Euclidean distance from the centroid at
                which ``can_accept`` still admits a point.
        """
        self.N = 1
        # Clone so the in-place updates in add_point/merge never modify the
        # tensor owned by the caller.
        self.LS = x.clone()
        self.SS = x.pow(2)
        self.label = label
        self.max_radius = max_radius

    def add_point(self, x):
        """Fold the point ``x`` into the running statistics (in place)."""
        self.N += 1
        self.LS += x
        self.SS += x.pow(2)

    @property
    def centroid(self):
        """Mean of the absorbed points: ``LS / N``."""
        return self.LS / self.N

    @property
    def radius(self):
        """Per-dimension standard deviation, averaged over all dimensions."""
        variance = self.SS / self.N - self.centroid.pow(2)
        # E[x^2] - E[x]^2 can come out slightly negative through rounding in
        # the running sums; clamp so sqrt does not turn that into NaN.
        return torch.sqrt(torch.clamp(variance, min=0)).mean()

    def can_accept(self, x):
        """Tell whether ``x`` lies within ``max_radius`` of the centroid.

        Returns the result of the tensor comparison (a boolean tensor).
        """
        return torch.norm(x - self.centroid) <= self.max_radius

    def merge(self, mc):
        """Absorb the statistics of the micro-cluster ``mc`` (in place).

        The label and ``max_radius`` of this cluster are left unchanged.
        """
        self.N += mc.N
        self.LS += mc.LS
        self.SS += mc.SS
class MClassification:
    """Online classifier backed by a bounded list of labelled micro-clusters.

    Each incoming point is assigned the label of the micro-cluster whose
    centroid is nearest. The point is then either absorbed by that cluster
    or used to start a new cluster carrying the same label.
    """

    def __init__(self, max_radius=0.1, max_clusters=100):
        """Create an empty model.

        Args:
            max_radius: Acceptance radius handed to every new micro-cluster.
            max_clusters: Cluster budget; once it is reached,
                ``predict_and_update`` tries to merge two clusters before
                adding a new one.
        """
        self.max_radius = max_radius
        self.max_clusters = max_clusters
        self.micro_clusters = []

    def fit_initial(self, X, labels):
        """Seed the model with one micro-cluster per ``(x, label)`` pair.

        ``X`` and ``labels`` are iterated in lockstep; ``max_clusters`` is
        not enforced here.
        """
        for x, label in zip(X, labels):
            self.micro_clusters.append(MicroCluster(x, label, self.max_radius))

    def find_closest_clusters(self):
        """Return indices ``(i, j)``, ``i < j``, of the two nearest centroids.

        Returns ``(None, None)`` when fewer than two clusters exist. On
        ties the first pair encountered wins.
        """
        # Compute each centroid once instead of once per pair.
        centroids = [mc.centroid for mc in self.micro_clusters]
        min_distance = float('inf')
        pair = (None, None)
        for i in range(len(centroids)):
            for j in range(i + 1, len(centroids)):
                distance = torch.norm(centroids[i] - centroids[j])
                if distance < min_distance:
                    min_distance = distance
                    pair = (i, j)
        return pair

    def find_farthest_clusters_from_point(self, x, label):
        """Find the two clusters labelled ``label`` that lie farthest from ``x``.

        Returns:
            ``(farthest, second_farthest)`` indices into
            ``self.micro_clusters``. ``(None, None)`` when no cluster has
            that label; ``(farthest, None)`` when only one does.
        """
        distances = [
            (i, torch.norm(x - mc.centroid))
            for i, mc in enumerate(self.micro_clusters)
            if mc.label == label
        ]
        if not distances:
            return None, None
        sorted_distances = sorted(distances, key=lambda d: d[1], reverse=True)
        second = sorted_distances[1][0] if len(sorted_distances) > 1 else None
        return sorted_distances[0][0], second

    def merge_clusters(self, index1, index2):
        """Merge cluster ``index2`` into cluster ``index1``, then drop ``index2``."""
        self.micro_clusters[index1].merge(self.micro_clusters[index2])
        del self.micro_clusters[index2]

    def predict_and_update(self, x):
        """Predict the label of ``x`` and update the clusters with it.

        Returns:
            The label of the nearest micro-cluster, or ``-1`` when the model
            holds no clusters.
        """
        if not self.micro_clusters:
            return -1  # Placeholder for no clusters case
        distances = [torch.norm(x - mc.centroid) for mc in self.micro_clusters]
        nearest_mc_index = min(range(len(distances)), key=lambda i: distances[i])
        nearest_mc = self.micro_clusters[nearest_mc_index]
        if nearest_mc.can_accept(x):
            nearest_mc.add_point(x)
        else:
            # Use >= so the merge happens before the append would exceed the
            # budget; with > the model settled at max_clusters + 1 clusters.
            if len(self.micro_clusters) >= self.max_clusters:
                farthest_i, farthest_j = self.find_farthest_clusters_from_point(x, nearest_mc.label)
                # With fewer than two clusters of this label nothing is
                # merged, so the budget can still be exceeded in that case.
                if farthest_i is not None and farthest_j is not None:
                    self.merge_clusters(farthest_i, farthest_j)
            self.micro_clusters.append(MicroCluster(x, nearest_mc.label, self.max_radius))
        return nearest_mc.label
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment