Last active
August 31, 2024 17:59
-
-
Save jmquintana79/d475a193f82693eadff2de3b1d228a4a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sklearn.metrics import silhouette_samples | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from matplotlib import cm | |
## quantification of clustering quality via silhouette plot
def plot_quantification_clustering_quality(X: np.array, y_km: np.array):
    """
    Quantification of clustering quality via silhouette analysis plot.

    Draws one horizontal bar per sample, grouped by cluster and sorted by
    silhouette coefficient inside each cluster, plus a dashed red vertical
    line at the mean silhouette coefficient over all samples. The figure is
    displayed with ``plt.show()``; nothing is returned.

    X -- Array of features used to estimate the clustering.
    y_km -- Labels returned by the clustering method to be evaluated
            (any array-like; it is converted with ``np.asarray``).
    """
    # Coerce to ndarray so that `y_km == c` below is an element-wise mask
    # even when the caller passes a plain list of labels.
    y_km = np.asarray(y_km)
    # labels of clusters
    cluster_labels = np.unique(y_km)
    # number of clusters
    n_clusters = cluster_labels.shape[0]
    # estimate silhouette values (one per sample)
    silhouette_vals = silhouette_samples(X, y_km, metric='euclidean')

    # running lower edge of the current cluster's band on the y axis
    y_ax_lower = 0
    yticks = []
    # loop of clusters
    for i, c in enumerate(cluster_labels):
        # silhouette values of this cluster, ascending
        c_silhouette_vals = np.sort(silhouette_vals[y_km == c])
        y_ax_upper = y_ax_lower + len(c_silhouette_vals)
        # plot silhouette by cluster
        color = cm.jet(float(i) / n_clusters)
        plt.barh(range(y_ax_lower, y_ax_upper), c_silhouette_vals, height=1.0,
                 edgecolor='none', color=color)
        # tick in the middle of the band
        yticks.append((y_ax_lower + y_ax_upper) / 2.)
        # next cluster starts where this one ended
        y_ax_lower = y_ax_upper

    # estimate final avg
    silhouette_avg = np.mean(silhouette_vals)
    # plot final avg
    plt.axvline(silhouette_avg, color="red", linestyle="--")

    # Tick labels: numeric labels are shown shifted by one (0-based -> 1-based),
    # as before. Labels that do not support `+ 1` (e.g. strings) used to raise
    # a TypeError here; they are now shown unchanged.
    try:
        tick_labels = cluster_labels + 1
    except TypeError:
        tick_labels = cluster_labels

    # customize plot
    plt.yticks(yticks, tick_labels)
    plt.ylabel('Cluster')
    plt.xlabel('Silhouette coefficient')
    # show
    plt.tight_layout()
    plt.show()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sklearn.metrics import silhouette_samples | |
import numpy as np | |
## quantification of clustering quality via silhouette metric
def quantification_clustering_quality(X: np.array, y_km: np.array, verbose: bool = False) -> list:
    """
    Quantification of clustering quality via silhouette metric.

    X -- Array of features used to estimate the clustering.
    y_km -- Labels returned by the clustering method to be evaluated
            (any array-like; it is converted with ``np.asarray``).
    verbose -- Display or not extra information (default, False).
    return -- List of ``[name, mean, std]`` rows: one row named ``"C<label>"``
              per cluster, followed by a final ``"ALL"`` row computed over
              every sample.
    """
    # Coerce to ndarray so that `y_km == c` below is an element-wise mask
    # even when the caller passes a plain list of labels.
    y_km = np.asarray(y_km)
    # estimate silhouette values (for all records)
    silhouette_vals = silhouette_samples(X, y_km, metric='euclidean')

    statistics = []
    # loop of cluster labels
    for c in np.unique(y_km):
        # collect silhouette values per cluster
        c_silhouette_vals = silhouette_vals[y_km == c]
        # compute the statistics once and reuse them for the optional printout
        row = [f"C{c}", np.mean(c_silhouette_vals), np.std(c_silhouette_vals)]
        statistics.append(row)
        if verbose:
            print(*row)

    # final statistics over all samples
    total = ["ALL", np.mean(silhouette_vals), np.std(silhouette_vals)]
    statistics.append(total)
    if verbose:
        print(*total)

    return statistics
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment