Skip to content

Instantly share code, notes, and snippets.

@ianforme
Created November 29, 2020 13:52
Show Gist options
  • Save ianforme/f9a48427b596d549b8c4219c9472f1ac to your computer and use it in GitHub Desktop.
Use clustering algorithms to generate a photo filter
from sklearn.cluster import KMeans, AgglomerativeClustering, DBSCAN, Birch
from sklearn.mixture import GaussianMixture
from PIL import Image
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# Load the picture. Converting to RGB guarantees a 3-channel array even for
# greyscale, palette or RGBA files, which create_image's per-pixel reshape
# into RGB triples requires; the context manager closes the file handle.
with Image.open('./thor4.jpg') as img:
    arr = np.array(img.convert('RGB'))
def create_image(arr, n_center=3, method='kmeans'):
    """Recolour an image so that every pixel takes the mean colour of its cluster.

    The pixels are clustered in RGB space with the chosen scikit-learn
    algorithm, then each pixel is replaced by the average RGB value of the
    cluster it was assigned to, giving a posterised "photo filter" effect.

    Parameters
    ----------
    arr : array-like, shape (H, W, C) with C >= 3
        Image pixels, e.g. ``np.array(pil_image)``. Only the first three
        channels are clustered, so an alpha channel (RGBA) is ignored.
    n_center : int, default 3
        Number of clusters / mixture components. Not passed to ``'dbscan'``,
        which decides the number of clusters itself.
    method : {'kmeans', 'agglo', 'em', 'dbscan', 'birch'}, default 'kmeans'
        Clustering algorithm: ``KMeans``, ``AgglomerativeClustering``,
        ``GaussianMixture`` (EM), ``DBSCAN`` (default parameters) or
        ``Birch``.

    Returns
    -------
    PIL.Image.Image
        A new RGB image with the same width and height as ``arr``.

    Raises
    ------
    ValueError
        If ``method`` is not one of the names above, or ``arr`` is not an
        (H, W, C) array with at least three channels.
    """
    arr = np.asarray(arr)
    if arr.ndim != 3 or arr.shape[2] < 3:
        raise ValueError('expected an image array of shape (H, W, C) with C >= 3')
    height, width = arr.shape[:2]
    # One row per pixel, one column per colour channel.
    pixels = arr[:, :, :3].reshape(height * width, 3)

    if method == 'kmeans':
        cluster_obj = KMeans(n_clusters=n_center, random_state=42)
    elif method == 'agglo':
        cluster_obj = AgglomerativeClustering(n_clusters=n_center)
    elif method == 'em':
        cluster_obj = GaussianMixture(n_components=n_center, random_state=42)
    elif method == 'dbscan':
        cluster_obj = DBSCAN()
    elif method == 'birch':
        cluster_obj = Birch(n_clusters=n_center)
    else:
        raise ValueError('unknown clustering method')
    cluster = cluster_obj.fit_predict(pixels)

    # Cluster labels need not be 0..k-1 (DBSCAN, for instance, labels noise
    # points -1), so map them to compact indices: inverse[i] is pixel i's index.
    labels, inverse = np.unique(cluster, return_inverse=True)
    inverse = inverse.ravel()
    n_labels = len(labels)

    # Per-cluster mean colour: channel-wise sums divided by member counts.
    counts = np.bincount(inverse, minlength=n_labels)
    channel_sums = np.stack(
        [np.bincount(inverse, weights=pixels[:, ch], minlength=n_labels)
         for ch in range(3)],
        axis=1,
    )
    mean_colours = channel_sums / counts[:, np.newaxis]

    # Paint every pixel with its cluster's mean colour; the float means are
    # truncated toward zero by the uint8 cast.
    new_arr = mean_colours[inverse].astype(np.uint8).reshape(height, width, 3)
    # A uint8 (H, W, 3) array is interpreted as an RGB image.
    return Image.fromarray(new_arr)
# Experiments: recolour the same image with different algorithms / cluster counts.
kmeans_3 = create_image(arr, 3, method='kmeans')
kmeans_5 = create_image(arr, 5, method='kmeans')
kmeans_10 = create_image(arr, 10, method='kmeans')
em_3 = create_image(arr, 3, method='em')
em_5 = create_image(arr, 5, method='em')
em_10 = create_image(arr, 10, method='em')
# NOTE: DBSCAN() runs with scikit-learn's default eps=0.5, so on integer RGB
# values only pixels of exactly the same colour are neighbours; the result will
# likely keep most of the original colours, and it may be slow / memory-hungry
# on large images.
dbscan = create_image(arr, method='dbscan')

# Plot the original and every filtered version on a 3x3 grid
# (the top-right cell is intentionally left empty).
panels = [
    ((0, 0), arr, 'Original'),
    ((0, 1), dbscan, 'DBSCAN'),
    ((1, 0), kmeans_3, 'KMeans - 3 centres'),
    ((1, 1), kmeans_5, 'KMeans - 5 centres'),
    ((1, 2), kmeans_10, 'KMeans - 10 centres'),
    ((2, 0), em_3, 'EM - 3 centres'),
    ((2, 1), em_5, 'EM - 5 centres'),
    ((2, 2), em_10, 'EM - 10 centres'),
]
f, axarr = plt.subplots(3, 3, figsize=(15, 12))
for ax in axarr.flat:
    ax.axis('off')
for (row, col), image, title in panels:
    axarr[row, col].imshow(image)
    axarr[row, col].set_title(title)
# Needed to display the figure when run as a plain script (not a notebook).
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment