Skip to content

Instantly share code, notes, and snippets.

@jmsdnns
Last active August 27, 2024 14:37
Show Gist options
  • Save jmsdnns/895d8cb0892cd13e57c9b589296d770a to your computer and use it in GitHub Desktop.
Save jmsdnns/895d8cb0892cd13e57c9b589296d770a to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
# Output looks like this:
#
# $ python ./kmeans.py
# Centroids: [[14, 17], [17, 7], [3, 15]]
# Done: [[17.5, 16.5], [16.5, 4.166666666666667], [1.8, 1.6]]
# 0: [[15, 14], [18, 14], [20, 20], [17, 18]]
# 1: [[12, 10], [20, 0], [19, 2], [18, 1], [11, 9], [19, 3]]
# 2: [[0, 1], [1, 1], [2, 2], [3, 3], [3, 1]]
import math
from random import randint
# Number of clusters (and therefore centroids) to fit.
K = 3
# Sample points to cluster; each entry is an [x, y] pair.
data = [[0,1], [1,1], [2,2], [3,3], [15,14], [12,10], [18,14], [20,0], [19,2],
[18,1], [11,9], [3,1], [19,3], [20,20], [17,18]]
def distance(p1, p2):
    """Return the Euclidean distance between points *p1* and *p2*.

    Args:
        p1: sequence of numeric coordinates.
        p2: sequence of numeric coordinates with the same length as *p1*.

    Returns:
        The non-negative straight-line distance as a float.

    Raises:
        ValueError: if the two points have a different number of dimensions.
    """
    if len(p1) != len(p2):
        # Pairing coordinates of unequal-length points would drop the extra
        # ones and yield a plausible-looking but wrong distance.
        raise ValueError(f"dimension mismatch: {len(p1)} != {len(p2)}")
    # math.sqrt never returns a negative value, so no abs() is required.
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(p1, p2)))
# Seed K centroids at random integer positions inside the bounding box of
# the data set.
xs = [px for px, _ in data]
ys = [py for _, py in data]
min_x, min_y = min(xs), min(ys)
max_x, max_y = max(xs), max(ys)
centroids = [
    [randint(min_x, max_x), randint(min_y, max_y)]
    for _ in range(K)
]
print(f"Centroids: {centroids}")
def cluster_loop(centroids, points=None):
    """Run one k-means iteration: assign points, then recompute centroids.

    Args:
        centroids: current centroid coordinates, one sequence per cluster.
        points: points to cluster; defaults to the module-level ``data``.

    Returns:
        A ``(new_centroids, clusters)`` tuple. ``clusters[i]`` holds the
        points whose nearest centroid is ``centroids[i]`` and
        ``new_centroids[i]`` is the per-dimension mean of those points. A
        centroid that attracted no points is carried over unchanged.

    Raises:
        ValueError: if ``centroids`` is empty.
    """
    if points is None:
        points = data

    # Size everything from the argument rather than the global K so the
    # function stays correct for any number of centroids.
    k = len(centroids)
    if k == 0:
        raise ValueError("cluster_loop() requires at least one centroid")

    clusters = [[] for _ in range(k)]
    for point in points:
        # min() returns the first minimum, so ties go to the lowest index.
        nearest = min(range(k), key=lambda i: distance(point, centroids[i]))
        clusters[nearest].append(point)

    new_centroids = []
    for old, members in zip(centroids, clusters):
        if not members:
            # An empty cluster has no mean; keep its centroid where it was.
            new_centroids.append(old)
        else:
            size = len(members)
            # zip(*members) groups coordinates by axis, so the mean works
            # for points of any dimensionality.
            new_centroids.append([sum(axis) / size for axis in zip(*members)])
    return new_centroids, clusters
# Keep refining until an update leaves every centroid unchanged.
while True:
    new_centroids, clusters = cluster_loop(centroids)
    done = new_centroids == centroids
    centroids = new_centroids
    if done:
        break

# Report the final centroids, then the members of each cluster.
print(f"Done: {centroids}")
for idx, cluster in enumerate(clusters):
    print(f"{idx}: {cluster}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment