Created
November 17, 2017 12:15
-
-
Save spyhi/2ce0ab2d008a4a1b46e41b9e2dcf04e3 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
import matplotlib.cm as cm | |
from sklearn.cluster import KMeans | |
from sklearn.utils import check_random_state | |
from sklearn.datasets.samples_generator import make_blobs | |
# Number of independent datasets to generate and evaluate.
RUNS = 1
# Number of blob centers to generate, and of clusters k-means is asked to find.
_CLUSTERS = 4
# Values of KMeans' ``n_init`` to sweep: 1..12 and 14..20.
# NOTE(review): 13 is absent from the sweep -- confirm whether that is intended.
T_INIT_RANGE = np.array(list(range(1, 13)) + list(range(14, 21)))
def _collect_inertia(runs, n_clusters, init_counts):
    """Fit k-means once per ``n_init`` value on ``runs`` fresh blob datasets.

    Returns a tuple ``(inertia, points, labels, model)`` where ``inertia`` has
    shape ``(len(init_counts), runs)`` and the remaining three items are the
    samples, the generating-blob labels and the fitted estimator of the very
    last fit (they feed the decision-region figure).
    """
    inertia = np.empty((len(init_counts), runs))
    points = labels = model = None
    for run in range(runs):
        # A new random dataset per run; every n_init value sees the same data.
        points, labels = make_blobs(n_samples=200, n_features=2,
                                    centers=n_clusters, cluster_std=1)
        for row, n_init in enumerate(init_counts):
            model = KMeans(n_clusters=n_clusters, init='random',
                           n_init=n_init, max_iter=300, algorithm='full')
            model.fit(points)
            inertia[row, run] = model.inertia_
    return inertia, points, labels, model


def _plot_inertia(init_counts, mean_inertia, std_inertia, runs):
    """Draw mean inertia against ``n_init`` with std error bars on a new figure."""
    plt.figure()
    plt.errorbar(init_counts, mean_inertia, std_inertia)
    plt.xlabel('n_init')
    plt.ylabel('inertia')
    plt.title("Mean inertia for various k-means init across %d runs" % runs)


def _plot_decision_regions(model, points, labels, mesh_step=.02):
    """Draw the regions ``model`` predicts, the samples and the centroids."""
    # Pad the bounding box by one unit so no sample sits on the image border.
    x_min, x_max = points[:, 0].min() - 1, points[:, 0].max() + 1
    y_min, y_max = points[:, 1].min() - 1, points[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, mesh_step),
                         np.arange(y_min, y_max, mesh_step))
    # Predict a cluster for every mesh node, then fold back to the grid shape.
    regions = model.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
    centroids = model.cluster_centers_

    plt.figure()
    plt.subplot(111)
    plt.imshow(regions, interpolation='nearest',
               extent=(xx.min(), xx.max(), yy.min(), yy.max()),
               aspect='auto', origin='lower', cmap='Pastel2')
    # Samples are coloured by the blob that generated them, not by the
    # cluster k-means assigned.
    plt.scatter(points[:, 0], points[:, 1], marker='o', c=labels, s=25,
                cmap='Set1')
    plt.scatter(centroids[:, 0], centroids[:, 1], marker="x", c="w", s=150,
                linewidths=3)


def kmeansinittest():
    """Run the test of how the number of random inits affects k-means inertia.

    For each of ``RUNS`` generated blob datasets, k-means is fitted once per
    value in ``T_INIT_RANGE`` (used as ``n_init``) and the resulting inertia is
    recorded. The per-``n_init`` mean and standard deviation across runs are
    printed and drawn as an error-bar plot; a second figure shows the decision
    regions, samples and centroids of the last model fitted. Blocks in
    ``plt.show()`` until the figures are closed.

    Raises:
        ValueError: if ``RUNS`` is below 1 or ``T_INIT_RANGE`` is empty, since
            there would be no fitted model or dataset to plot.
    """
    if RUNS < 1 or len(T_INIT_RANGE) == 0:
        raise ValueError("RUNS must be >= 1 and T_INIT_RANGE must not be empty")

    inertia, points, labels, model = _collect_inertia(RUNS, _CLUSTERS,
                                                      T_INIT_RANGE)
    mean_inertia = inertia.mean(axis=1)
    std_inertia = inertia.std(axis=1)

    _plot_inertia(T_INIT_RANGE, mean_inertia, std_inertia, RUNS)
    print(mean_inertia, std_inertia)
    _plot_decision_regions(model, points, labels)
    plt.show()
# Run the experiment at module level: this executes whenever the file is run
# or imported.
kmeansinittest()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Outputs the means and standard deviations to the console for inspection; the matplotlib figures should look something like this (inertia error bars vs. n_init, and the final plot/fit).