# Source: GitHub gist d03222d8dd7dbbc37aa8fc40826aca6e by azkalot1,
# created February 12, 2019.
# Compare several scikit-learn clustering algorithms on `embedding`, modelled on
# https://scikit-learn.org/stable/auto_examples/cluster/plot_cluster_comparison.html
# (the original note here was cut off after "but our data is reall...").
# NOTE(review): `embedding` is defined earlier in the script; it is presumably an
# (n_samples, 2) array, since only columns 0 and 1 are plotted later — confirm.
import inspect

# Hyperparameters shared by the estimators built below.
params = {
    'quantile': .3,       # quantile passed to estimate_bandwidth (MeanShift bandwidth)
    'eps': .3,            # DBSCAN neighbourhood radius
    'damping': .9,        # AffinityPropagation damping factor
    'preference': -200,   # AffinityPropagation preference
    'n_neighbors': 10,    # k for the kNN connectivity graph
    'n_clusters': 5,      # cluster count for every algorithm that takes one
    'random_state': 42,   # seed for SpectralClustering / GaussianMixture so plots are reproducible
}

bandwidth = estimate_bandwidth(embedding, quantile=params['quantile'])

# kNN graph passed as `connectivity` to both AgglomerativeClustering models below.
connectivity = kneighbors_graph(
    embedding, n_neighbors=params['n_neighbors'], include_self=False)

# scikit-learn 1.2 renamed AgglomerativeClustering's `affinity` argument to
# `metric`, and 1.4 removed `affinity`; use whichever this installed version accepts.
_distance_kwarg = (
    'metric'
    if 'metric' in inspect.signature(AgglomerativeClustering).parameters
    else 'affinity')

ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
ward = AgglomerativeClustering(
    n_clusters=params['n_clusters'], linkage='ward',
    connectivity=connectivity)
# SpectralClustering still names this parameter `affinity`; it was not renamed.
spectral = SpectralClustering(
    n_clusters=params['n_clusters'], eigen_solver='arpack',
    affinity="nearest_neighbors", random_state=params['random_state'])
dbscan = DBSCAN(eps=params['eps'])
affinity_propagation = AffinityPropagation(
    damping=params['damping'], preference=params['preference'])
average_linkage = AgglomerativeClustering(
    linkage="average", n_clusters=params['n_clusters'],
    connectivity=connectivity, **{_distance_kwarg: "cityblock"})
birch = Birch(n_clusters=params['n_clusters'])
gmm = GaussianMixture(
    n_components=params['n_clusters'], covariance_type='full',
    random_state=params['random_state'])

# (display name, estimator) pairs, in the order they are fitted and plotted.
clustering_algorithms = (
    ('AffinityPropagation', affinity_propagation),
    ('MeanShift', ms),
    ('SpectralClustering', spectral),
    ('Ward', ward),
    ('AgglomerativeClustering', average_linkage),
    ('DBSCAN', dbscan),
    ('Birch', birch),
    ('GaussianMixture', gmm),
)
# Fit every algorithm in `clustering_algorithms` on `embedding` and draw its
# labelling in its own panel (4 panels per row).
import math

# Colors reused cyclically for cluster ids 0, 1, 2, ...
_PALETTE = ('#377eb8', '#ff7f00', '#4daf4a',
            '#f781bf', '#a65628', '#984ea3',
            '#999999', '#e41a1c', '#dede00')
# Appended last so that label -1 (outliers, e.g. DBSCAN noise) indexes black.
_NOISE_COLOR = '#000000'
_N_COLS = 4

# Grid grows with the number of algorithms; 8 algorithms -> 2x4 at 20x15 inches.
_n_rows = max(1, math.ceil(len(clustering_algorithms) / _N_COLS))
# squeeze=False keeps `ax` 2-D even when there is only one row.
f, ax = plt.subplots(_n_rows, _N_COLS, figsize=(5 * _N_COLS, 7.5 * _n_rows),
                     squeeze=False)

for idx, (name, algorithm) in enumerate(clustering_algorithms):
    algorithm.fit(embedding)
    if hasattr(algorithm, 'labels_'):
        # `np.int` was removed in NumPy 1.24; builtin `int` gives the same dtype.
        y_pred = algorithm.labels_.astype(int)
    else:
        # Estimators without `labels_` (presumably GaussianMixture here) label via predict.
        y_pred = algorithm.predict(embedding)

    # One color per cluster id, plus the trailing noise color. Built as a list
    # of strings so an all-noise result (max label -1 -> zero clusters) still
    # yields a string array instead of mixing an empty float array with str.
    n_labels = int(np.max(y_pred)) + 1
    colors = np.array(
        [_PALETTE[i % len(_PALETTE)] for i in range(n_labels)] + [_NOISE_COLOR])

    panel = ax[idx // _N_COLS, idx % _N_COLS]
    panel.scatter(embedding[:, 0], embedding[:, 1], s=2, color=colors[y_pred])
    panel.set_xticks(())
    panel.set_yticks(())
    panel.set_title(name)

# Hide grid cells left empty when the algorithm count is not a multiple of 4.
for spare in ax.ravel()[len(clustering_algorithms):]:
    spare.set_visible(False)

plt.tight_layout()