Skip to content

Instantly share code, notes, and snippets.

@jmquintana79
Last active August 31, 2024 17:59
Show Gist options
  • Save jmquintana79/d475a193f82693eadff2de3b1d228a4a to your computer and use it in GitHub Desktop.
Save jmquintana79/d475a193f82693eadff2de3b1d228a4a to your computer and use it in GitHub Desktop.
from sklearn.metrics import silhouette_samples
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
## quantification of clustering quality via silhouette plot
def plot_quantification_clustering_quality(X: np.ndarray, y_km: np.ndarray, ax=None, show: bool = True):
    """
    Quantification of clustering quality via silhouette analysis plot.

    Draws one horizontal bar per sample, grouped by cluster and sorted by
    silhouette coefficient inside each cluster, plus a dashed red vertical
    line at the mean silhouette coefficient over all samples.

    X -- Array of features used to estimate the clustering.
    y_km -- Labels returned by the clustering method to be evaluated
            (array-like, one label per row of X).
    ax -- Matplotlib Axes to draw on (default, None: the current pyplot axes).
    show -- Call ``plt.show()`` once the plot is finished (default, True).
    return -- The Axes the plot was drawn on.
    """
    # Accept plain lists as well as arrays: the per-cluster mask below needs an
    # element-wise comparison, which ``list == value`` does not provide.
    y_km = np.asarray(y_km)
    # labels of clusters (unique, sorted) and their count
    cluster_labels = np.unique(y_km)
    n_clusters = cluster_labels.shape[0]
    # silhouette coefficient of every sample
    silhouette_vals = silhouette_samples(X, y_km, metric='euclidean')
    # only touch pyplot state once the inputs have been validated above
    if ax is None:
        ax = plt.gca()
    y_ax_lower, y_ax_upper = 0, 0
    yticks = []
    for i, c in enumerate(cluster_labels):
        # silhouette values of this cluster, sorted so the bars form a profile
        c_silhouette_vals = np.sort(silhouette_vals[y_km == c])
        y_ax_upper += len(c_silhouette_vals)
        color = cm.jet(float(i) / n_clusters)
        ax.barh(range(y_ax_lower, y_ax_upper), c_silhouette_vals, height=1.0,
                edgecolor='none', color=color)
        # tick at the vertical middle of this cluster's band of bars
        yticks.append((y_ax_lower + y_ax_upper) / 2.)
        y_ax_lower = y_ax_upper
    # overall average silhouette coefficient
    silhouette_avg = np.mean(silhouette_vals)
    ax.axvline(silhouette_avg, color="red", linestyle="--")
    # Label each band with the cluster's actual label (not label + 1), so the
    # plot agrees with y_km and also works for non-numeric labels.
    ax.set_yticks(yticks)
    ax.set_yticklabels([str(c) for c in cluster_labels])
    ax.set_ylabel('Cluster')
    ax.set_xlabel('Silhouette coefficient')
    ax.figure.tight_layout()
    if show:
        plt.show()
    return ax
from sklearn.metrics import silhouette_samples
import numpy as np
## quantification of clustering quality via silhouette metric
def quantification_clustering_quality(X: np.ndarray, y_km: np.ndarray, verbose: bool = False) -> list:
    """
    Quantification of clustering quality via silhouette metric.

    X -- Array of features used to estimate the clustering.
    y_km -- Labels returned by the clustering method to be evaluated
            (array-like, one label per row of X).
    verbose -- Display or not extra information (default, False).
    return -- List of ``[name, mean, std]`` rows of silhouette coefficients:
              one row per cluster (``name`` is ``"C<label>"``, in sorted label
              order) followed by a final ``"ALL"`` row over every sample.
              ``std`` is numpy's default population standard deviation
              (ddof=0).
    """
    # A plain list of labels would make ``y_km == c`` a single bool instead of
    # an element-wise mask, so normalise to an array first.
    y_km = np.asarray(y_km)
    # silhouette coefficient of every sample
    silhouette_vals = silhouette_samples(X, y_km, metric='euclidean')
    # (name, values) groups: each cluster in sorted label order, then all samples
    groups = [(f"C{c}", silhouette_vals[y_km == c]) for c in np.unique(y_km)]
    groups.append(("ALL", silhouette_vals))
    statistics = []
    for name, vals in groups:
        # compute each statistic once and reuse it for both output and display
        row = [name, np.mean(vals), np.std(vals)]
        statistics.append(row)
        if verbose:
            print(*row)
    return statistics
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment