Created April 13, 2019 20:29
-
-
Save bvarghese1/bf9dabb64533b5040e1dc9cc47b58535 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import time
from abc import ABC, abstractmethod

import numpy as np
class ClusterTemplate(ABC):
    """Template for clustering embeddings and saving L2-normalized centroids.

    Subclasses provide the clustering algorithm by implementing
    :meth:`init_cluster_algo`, which must assign ``self.clustering_algo`` an
    object exposing ``fit(embeddings)`` and, after fitting, a ``labels_``
    attribute with one cluster index per input row. Clients call
    :meth:`cluster`, which runs the fixed init -> train -> normalize -> save
    sequence.
    """

    def __init__(self, path):
        """Store the output directory, creating it if necessary.

        Args:
            path: Directory into which :meth:`save` writes centroid files.
                Created (including parents) if it does not exist.
        """
        self.path = path
        # exist_ok=True already makes this a no-op for an existing directory,
        # so a separate exists() check (and its check-then-create race) is
        # unnecessary.
        os.makedirs(self.path, exist_ok=True)
        self.centroids = None
        self.clustering_algo = None

    @abstractmethod
    def init_cluster_algo(self, num_clusters):
        """Create the clustering algorithm and assign it to ``self.clustering_algo``.

        Args:
            num_clusters: Number of clusters the algorithm should produce.
        """
        raise NotImplementedError("Abstract method 'init_cluster_algo' not implemented")

    def train(self, embeddings):
        """Fit ``self.clustering_algo`` on ``embeddings``."""
        self.clustering_algo.fit(embeddings)

    def should_normalize(self):
        """Hook deciding whether :meth:`cluster` calls :meth:`normalize`.

        The base implementation always returns True; subclasses may override.
        """
        return True

    def normalize(self, embeddings, num_clusters):
        """Compute unit-length per-cluster mean vectors into ``self.centroids``.

        Each centroid is the mean of the ``embeddings`` rows whose fitted label
        equals the cluster index, rescaled to unit L2 norm. Labels outside
        ``[0, num_clusters)`` (e.g. a ``-1`` label) are ignored. A cluster with
        no members, or whose mean has zero norm, gets an all-zero centroid
        instead of NaNs.

        Args:
            embeddings: 2-D array-like, one row per sample; must have as many
                rows as ``self.clustering_algo.labels_`` has entries.
            num_clusters: Number of centroid rows to produce.

        Sets:
            self.centroids: ``float32`` array of shape ``(num_clusters, dim)``.
        """
        embeddings = np.asarray(embeddings)
        labels = np.asarray(self.clustering_algo.labels_)
        dim = embeddings.shape[1]

        # Only labels in [0, num_clusters) belong to a requested cluster;
        # others are dropped, matching a per-cluster `labels == i` test.
        valid = (labels >= 0) & (labels < num_clusters)
        idx = labels[valid].astype(np.intp)

        # Accumulate per-cluster sums and member counts in one pass instead of
        # growing the result with np.vstack per cluster (quadratic copying).
        sums = np.zeros((num_clusters, dim), dtype=np.float64)
        np.add.at(sums, idx, embeddings[valid])
        counts = np.bincount(idx, minlength=num_clusters)

        # Mean per cluster; empty clusters stay zero rather than 0/0 -> NaN.
        means = np.zeros_like(sums)
        nonempty = counts > 0
        means[nonempty] = sums[nonempty] / counts[nonempty][:, None]

        # Row-wise L2 normalization, again skipping zero-norm rows.
        norms = np.linalg.norm(means, axis=1, keepdims=True)
        unit = np.divide(means, norms, out=np.zeros_like(means), where=norms > 0)
        self.centroids = unit.astype(np.float32)

    def save(self, cluster_name):
        """Write ``self.centroids`` to ``<path>/<cluster_name>`` via ``np.save``.

        ``np.save`` appends a ``.npy`` extension if the name lacks one.
        """
        # NOTE(review): if a subclass disables normalization and never sets
        # self.centroids, this saves None as an object array — confirm intent.
        np.save(os.path.join(self.path, cluster_name), self.centroids)

    # Final method that no subclass should override; clients invoke it directly.
    def cluster(self, features, cluster_name, niter=20, num_clusters=100):
        """Run the full pipeline: build algorithm, train, normalize, save.

        Args:
            features: 2-D array-like of embeddings, one row per sample.
            cluster_name: File name for the saved centroids under ``self.path``.
            niter: Accepted for interface compatibility; not used by this
                base class.
            num_clusters: Number of clusters to request and centroids to build.
        """
        self.init_cluster_algo(num_clusters)
        self.train(features)
        if self.should_normalize():
            self.normalize(features, num_clusters)
        self.save(cluster_name)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.