Skip to content

Instantly share code, notes, and snippets.

@abhishek-ghose
Created July 22, 2024 19:23
Show Gist options
  • Save abhishek-ghose/7cf2cc8d9f6fb7f9b19855883f86c69c to your computer and use it in GitHub Desktop.
Save abhishek-ghose/7cf2cc8d9f6fb7f9b19855883f86c69c to your computer and use it in GitHub Desktop.
Sample code for Optimal Transport
import itertools
import os
import time
from datetime import datetime

import numpy as np
import ot
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib.collections import LineCollection
from scipy.stats import multivariate_normal as mvn, norm
from sklearn.cluster import KMeans

sns.set()  # apply seaborn's default plot theme globally
# Common output directory for every figure / CSV produced by this script; created on import if missing.
OP_DIR = r'./generated/optimal_transport'
if os.path.isdir(OP_DIR):
    print(f"The common output directory is {OP_DIR}.")
else:
    # isdir() is False both when the path is missing and when it is not a directory
    print(f"Output dir. {OP_DIR} doesn't exist! Will be created...")
    os.makedirs(OP_DIR)
def get_plan_lines(opt_plan, X_source, X_target, threshold=0.0):
    """
    Convert a transport plan into line segments joining each source point to the target points it sends mass to.

    Plotting helper: entry (i, j) of the plan counts as an assignment when it is strictly greater than `threshold`,
    and yields the segment (X_source[i, :], X_target[j, :]). The returned list can be passed directly to
    matplotlib.collections.LineCollection.
    :param opt_plan: 2D array of shape (n_source, n_target); entry (i, j) is the mass moved from source i to target j.
    :param X_source: array whose i-th row holds the coordinates of source point i.
    :param X_target: array whose j-th row holds the coordinates of target point j.
    :param threshold: entries <= threshold are ignored. The default 0 keeps every non-zero entry, which suits exact
        (EMD) plans; entropic (Sinkhorn) plans are dense, so a small positive threshold avoids drawing every pair.
    :return: list of (source_point, target_point) tuples, in row-major order of the plan.
    """
    rows, cols = np.nonzero(np.asarray(opt_plan) > threshold)
    return [(X_source[i, :], X_target[j, :]) for i, j in zip(rows, cols)]
def basic_2d():
    """
    Solve and visualize exact optimal transport (EMD) between two nearly matching columns of 2D points.

    N source points sit on the vertical line x=2 at sorted random heights in [0, 5); the N target points sit on
    x=5 at the same heights plus Gaussian noise (std 0.2). Both sets get uniform weights. Saves a three-panel
    figure to OP_DIR/almost_bipartite.png: the points with the plan drawn as dashed red segments, the Euclidean
    distance matrix, and the optimal plan matrix.
    """
    N = 10
    min_y, max_y = 0, 5
    x1, x2 = 2, 5
    fig = plt.figure(figsize=(24, 8))
    ax1, ax2, ax3 = [fig.add_subplot(130 + i) for i in range(1, 4)]
    # source: N points at x=x1 with sorted random heights in [min_y, max_y)
    temp = sorted(min_y + np.random.random(N) * (max_y - min_y))
    X_source = np.concatenate((x1 * np.ones(N), temp)).reshape(2, -1).T
    # target: the same heights at x=x2, each shifted by a small perturbation
    perturb = norm.rvs(0, 0.2, N)
    X_target = np.concatenate((x2 * np.ones(N), temp + perturb)).reshape(2, -1).T
    w_source, w_target = np.ones((N,)) / N, np.ones((N,)) / N  # uniform distribution on samples
    # rows index source points, columns index target points
    dist_mat = ot.dist(X_source, X_target, 'euclidean')

    # plot the points
    sns.scatterplot(x=X_source[:, 0], y=X_source[:, 1], ax=ax1).set(xlabel='x1', ylabel='x2')
    sns.scatterplot(x=X_target[:, 0], y=X_target[:, 1], ax=ax1)
    ax1.set_title("Points with optimal mapping (EMD).")

    # plot the distance matrix; heatmap draws rows on the y-axis and columns on the x-axis,
    # so the source index is on y and the target index is on x
    sns.heatmap(dist_mat, annot=False, ax=ax2).set(xlabel='X_target', ylabel='X_source')
    ax2.invert_yaxis()
    ax2.set_title("Distance Matrix")

    # get the optimal plan (same row/column layout as dist_mat) and plot that matrix
    opt_plan = ot.emd(w_source, w_target, dist_mat)
    sns.heatmap(opt_plan, annot=False, ax=ax3).set(xlabel='X_target', ylabel='X_source')
    ax3.invert_yaxis()
    ax3.set_title("Optimal Transport Plan (EMD)")

    # add the transport plan to the scatter plot
    L = LineCollection(get_plan_lines(opt_plan, X_source, X_target), linewidths=0.7, colors='r', linestyles='--')
    ax1.add_collection(L)
    plt.savefig(f"{OP_DIR}/almost_bipartite.png", bbox_inches='tight')
def _plot_emd_assignment(ax, X_source, X_target, point_size, title):
    """
    Draw X_source on `ax`, solve exact OT (EMD) from X_source to X_target, and overlay the plan as red segments.

    Both point sets get uniform weights and the ground cost is Euclidean distance. The optimal cost is appended to
    `title` on a second line of the axes title.
    :param ax: matplotlib axes to draw on.
    :param X_source: array of shape (n_source, 2) with the source points.
    :param X_target: array of shape (n_target, 2) with the target points.
    :param point_size: base marker size; source points are drawn 1.5x larger.
    :param title: first line of the axes title.
    :return: the optimal transport cost sum(plan * distances).
    """
    sns.scatterplot(x=X_source[:, 0], y=X_source[:, 1], c='r', s=point_size * 1.5, ax=ax)
    dist_mat = ot.dist(X_source, X_target, 'euclidean')
    # uniform distribution on samples
    w_source = np.ones((len(X_source),)) / len(X_source)
    w_target = np.ones((len(X_target),)) / len(X_target)
    opt_plan = ot.emd(w_source, w_target, dist_mat)
    opt_cost = np.sum(opt_plan * dist_mat)
    ax.add_collection(LineCollection(get_plan_lines(opt_plan, X_source, X_target), linewidths=0.7, color='red'))
    ax.set_title(f"{title}\nOpt cost={opt_cost:.2f}.")
    return opt_cost


def representative_points_2d():
    """
    Compare two choices of a few "representative" points for a 2D Gaussian-mixture sample under optimal transport.

    Samples N_target points from num_components Gaussians with random means in a bounding box. The left panel
    solves EMD from the k-means cluster centers; the right panel solves it from the same number of points drawn
    uniformly in the sample's bounding box. Plans and costs are plotted side by side and the figure is saved to
    OP_DIR/demo_representative_points.png.
    """
    num_components = 5  # gaussian components
    N_target = 50
    x_lim, y_lim = (0, 10), (0, 10)  # limits of component means
    # split points evenly; leftovers go to the last component if N_target can't be evenly divided
    points_per_comp = [N_target // num_components] * num_components
    points_per_comp[-1] += N_target - sum(points_per_comp)

    # component means: uniform in [0, 1)^2, then scaled into the bounding box
    mean = np.random.random(size=num_components * 2).reshape(num_components, 2)
    mean[:, 0] = (x_lim[1] - x_lim[0]) * mean[:, 0] + x_lim[0]
    mean[:, 1] = (y_lim[1] - y_lim[0]) * mean[:, 1] + y_lim[0]
    # diagonal covariances (zero cross-terms): the two variances are drawn independently, so the
    # Gaussians are axis-aligned but generally NOT isotropic
    covs = [np.array([[np.random.rand(), 0], [0, np.random.rand()]]) for _ in mean]

    # create points; reshape because mvn.rvs returns a 1-D vector when asked for a single sample
    X_target = np.concatenate([mvn.rvs(mean=m_comp, cov=cov_comp, size=n_comp).reshape(-1, 2)
                               for n_comp, m_comp, cov_comp in zip(points_per_comp, mean, covs)])
    print(f"Created target points, of shape={np.shape(X_target)}.")

    fig = plt.figure(figsize=(14, 8))
    point_size = 100
    axes = [fig.add_subplot(120 + i) for i in range(1, 3)]
    for ax in axes:
        sns.scatterplot(x=X_target[:, 0], y=X_target[:, 1], s=point_size, ax=ax)

    # left: OT from the k-means cluster centers
    X_source = KMeans(n_clusters=num_components).fit(X_target).cluster_centers_
    _plot_emd_assignment(axes[0], X_source, X_target, point_size,
                         f"{num_components} Gauss. components, with {num_components} k-means centers.")

    # right: OT from the same number of points drawn uniformly inside the data's bounding box
    X_source = np.random.random((num_components, 2))
    x_min, x_max = X_target[:, 0].min(), X_target[:, 0].max()
    y_min, y_max = X_target[:, 1].min(), X_target[:, 1].max()
    X_source[:, 0] = (x_max - x_min) * X_source[:, 0] + x_min
    X_source[:, 1] = (y_max - y_min) * X_source[:, 1] + y_min
    _plot_emd_assignment(axes[1], X_source, X_target, point_size,
                         f"{num_components} Gauss. components, with {num_components} random samples.")

    fig.suptitle("Optimal Transport demo")
    plt.savefig(f"{OP_DIR}/demo_representative_points.png", bbox_inches='tight')
def compute_runtimes():
    """
    Benchmark exact OT (EMD) against entropic OT (Sinkhorn) on random point clouds and record the results as CSV.

    For every (trial, dimension, #points, lambda) combination, draws two equally sized point clouds uniformly in
    [0, 1)^dim, builds the Euclidean cost matrix, and solves for the plan with uniform weights. `ot.emd` is used
    when lambda == 0; otherwise `ot.sinkhorn` is used with lambda as its regularization argument. Only the solver
    call is timed. One row per run (trial_idx, dims, num_points, lambda, ot_score, duration_sec) is written to
    OP_DIR/runtimes.csv. The file is rewritten after every run, so partial results survive an interrupted sweep.
    """
    dims = np.array([10, 50, 100, 200, 300, 400, 500], dtype=int)
    num_points = np.array([100, 500, 1000, 1500, 2000], dtype=int)  # this is per dist.
    lambda_param = np.array([0, 0.01, 0.1, 1, 10])
    num_trials = 3
    res_file = f"{OP_DIR}/runtimes.csv"
    columns = ['trial_idx', 'dims', 'num_points', 'lambda', 'ot_score', 'duration_sec']
    # collect plain rows and build the frame from them, instead of pd.concat-ing onto an empty DataFrame
    rows = []
    for trial_idx, curr_dim, curr_num_points, curr_lambda_param in itertools.product(range(num_trials), dims,
                                                                                     num_points, lambda_param):
        print(f"\n{trial_idx, curr_dim, curr_num_points, curr_lambda_param}")
        X_source = np.random.random((curr_num_points, curr_dim))
        X_target = np.random.random((curr_num_points, curr_dim))
        print(f"Created source points, shape={np.shape(X_source)}.")
        print(f"Created target points, shape={np.shape(X_target)}.")
        print("Calculating distances.")
        dist_mat = ot.dist(X_source, X_target, 'euclidean')
        w_source = np.ones((curr_num_points,)) / curr_num_points
        w_target = np.ones((curr_num_points,)) / curr_num_points
        print("Calculating opt plan.")
        # perf_counter is monotonic and high-resolution, unlike the wall-clock datetime.now()
        time_start = time.perf_counter()
        if curr_lambda_param == 0:
            opt_plan = ot.emd(w_source, w_target, dist_mat)
        else:
            opt_plan = ot.sinkhorn(w_source, w_target, dist_mat, curr_lambda_param)
        duration_sec = time.perf_counter() - time_start
        print(f"Opt. transport finding took {duration_sec} sec.")
        # transport cost <plan, distances>. POT's ot.emd2 / ot.sinkhorn2 return this directly, but computing it
        # here keeps it outside the timed region.
        ot_score = np.sum(dist_mat * opt_plan)
        rows.append([trial_idx, curr_dim, curr_num_points, curr_lambda_param, ot_score, duration_sec])
        # checkpoint after every run: the full sweep is long, so keep partial results if it is interrupted
        pd.DataFrame(rows, columns=columns).to_csv(res_file, index=False)
def _plot_runtimes(df):
    """Plot mean solver runtime vs. #points and vs. #dims, one line per lambda; saves OP_DIR/runtimes.png."""
    fig = plt.figure(figsize=(20, 8))
    ax_data_size, ax_dims = fig.add_subplot(121), fig.add_subplot(122)
    # "\\lambda" (escaped backslash): a bare "\l" is an invalid escape sequence (SyntaxWarning on Python >= 3.12).
    # A raw string can't be used here because the leading "\n" must stay a real newline.
    solver_note = "\nEMD is used for $\\lambda=0$, the rest use Sinkhorn."
    panels = ((ax_data_size, 'num_points', "Runtime wrt #points, averaged over #dims."),
              (ax_dims, 'dims', "Runtime wrt #dims, averaged over #points."))
    for ax, x_col, title in panels:
        # mean over every other setting (the other size axis and trials)
        temp_df = df.groupby(by=[x_col, 'lambda'], as_index=False).agg({'duration_sec': 'mean'})
        sns.lineplot(data=temp_df, x=x_col, y='duration_sec', hue='lambda', palette=sns.color_palette("tab10"),
                     marker='o', ax=ax)
        ax.set_title(title + solver_note)
    plt.savefig(f"{OP_DIR}/runtimes.png", bbox_inches='tight')


def _plot_sinkhorn_approx_quality(df):
    """Plot mean Sinkhorn OT score against the EMD score per (dims, num_points) setting, one line per lambda > 0."""
    fig = plt.figure()
    ax = fig.add_subplot(111)
    aggr_df = df.groupby(by=['dims', 'num_points', 'lambda'], as_index=False).agg({'ot_score': 'mean'})
    baseline_df = aggr_df[aggr_df['lambda'] == 0]
    # identity line: where a perfect approximation of the EMD score would lie
    t = sorted(baseline_df['ot_score'].to_numpy().flatten())
    ax.plot(t, t, c='y', linestyle='--', label='no approx.')
    # sorted so lines, colors, and legend entries come out in a deterministic, increasing-lambda order
    other_lambdas = sorted(set(df['lambda']) - {0})
    for lambda_param in other_lambdas:
        print(f"Analyzing data for lambda={lambda_param}.")
        temp_df = aggr_df[aggr_df['lambda'] == lambda_param]
        # join on identical settings; the EMD score ends up in ot_score_x, the Sinkhorn score in ot_score_y
        joined_df = pd.merge(baseline_df, temp_df, on=['dims', 'num_points'], how='inner')
        plot_data = joined_df[['ot_score_x', 'ot_score_y']].sort_values(by='ot_score_x').to_numpy()
        ax.plot(plot_data[:, 0], plot_data[:, 1], label=f"{lambda_param}")
    ax.set_xlabel('OT score baseline (EMD)')
    ax.set_ylabel('OT score approx. (Sinkhorn)')
    ax.legend()
    plt.savefig(f"{OP_DIR}/sinkhorn_approx_quality.png", bbox_inches='tight')


def process_runtime_results(df):
    """
    Turn the benchmark table written by compute_runtimes() into two figures saved under OP_DIR.

    runtimes.png shows the mean runtime vs. #points and vs. #dims for each lambda. sinkhorn_approx_quality.png
    plots the mean Sinkhorn score against the EMD (lambda == 0) score for each shared (dims, num_points) setting.
    :param df: DataFrame with at least the columns 'dims', 'num_points', 'lambda', 'ot_score' and 'duration_sec'.
    """
    _plot_runtimes(df)
    _plot_sinkhorn_approx_quality(df)
if __name__ == "__main__":
    # Experiment toggles: uncomment whichever demos / benchmarks should be (re-)run.
    # By default, only the saved benchmark results are re-plotted.
    # basic_2d()
    # representative_points_2d()
    # compute_runtimes()
    resfile = f"{OP_DIR}/runtimes.csv"
    df = pd.read_csv(resfile)
    process_runtime_results(df)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment