Created
September 28, 2022 22:21
-
-
Save dfaker/664a9cd8d2cc391c42c2a612955a3fbf to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import torch | |
def lerp(theta0, theta1, alpha):
    """Linearly interpolate between ``theta0`` and ``theta1``.

    Args:
        theta0: Value returned when ``alpha`` is 0.
        theta1: Value returned when ``alpha`` is 1.
        alpha: Interpolation weight applied to ``theta1``; ``theta0`` is
            weighted by ``1 - alpha``.

    Returns:
        ``(1 - alpha) * theta0 + alpha * theta1``, evaluated with whatever
        arithmetic the operands support (plain numbers, arrays, tensors).
    """
    return (1 - alpha) * theta0 + alpha * theta1
def slerp(theta0, theta1, alpha):
    """Spherically interpolate between two tensors.

    The angle between the inputs is measured on their normalized, flattened
    directions; the resulting slerp weights are then applied to the original
    (un-normalized) tensors, so magnitudes are blended as well.

    Args:
        theta0: Tensor returned when ``alpha`` is 0.
        theta1: Tensor returned when ``alpha`` is 1; must be broadcastable
            against ``theta0``.
        alpha: Interpolation position between the two tensors.

    Returns:
        A new tensor; neither input is modified. When the directions are
        (anti)parallel, or either input has zero norm, the angle is not
        usable and the result is the plain ``lerp`` of the inputs.
    """
    # Do the angle math in at least float32 so integer and half-precision
    # tensors still get a meaningful norm and dot product.
    calc_dtype = torch.promote_types(
        torch.promote_types(theta0.dtype, theta1.dtype), torch.float32
    )
    v0 = theta0.to(calc_dtype)
    v1 = theta1.to(calc_dtype)

    norm0 = torch.linalg.vector_norm(v0)
    norm1 = torch.linalg.vector_norm(v1)
    # A zero vector has no direction; dividing by its norm would yield NaNs.
    if norm0 == 0 or norm1 == 0:
        return lerp(theta0, theta1, alpha)

    # Cosine of the angle between the two directions.
    dot = torch.sum((v0 / norm0) * (v1 / norm1))
    # Parallel or anti-parallel: sin(angle) is 0, so the slerp weights are
    # undefined. Fall back to linear interpolation.
    if torch.abs(dot) >= 1.0:
        return lerp(theta0, theta1, alpha)

    # Angle between the inputs, and the angle reached at position alpha.
    omega = torch.acos(dot)
    sin_omega = torch.sin(omega)
    omega_t = omega * alpha

    # Standard slerp weights for the two endpoints.
    s0 = torch.sin(omega - omega_t) / sin_omega
    s1 = torch.sin(omega_t) / sin_omega
    return s0 * theta0 + s1 * theta1
# Checkpoint locations. The literals use Windows-style separators exactly as
# the script was written; adjust them when running on another platform.
PRIMARY_CKPT = 'models\\model-aa-base.ckpt'
SECONDARY_CKPT = 'models\\model-aa-waifu.ckpt'
MERGED_CKPT = 'models\\model-aa-base-plus-waifu.ckpt'

# Interpolation position handed to slerp: 0.0 keeps the primary weights,
# 1.0 takes the secondary weights.
MERGE_ALPHA = 1.0 - 0.25

# NOTE(security): torch.load unpickles the file by default, which can execute
# arbitrary code -- only merge checkpoints that come from a trusted source.
primary_model = torch.load(PRIMARY_CKPT, map_location='cpu')
secondary_model = torch.load(SECONDARY_CKPT, map_location='cpu')
theta_0 = primary_model['state_dict']
theta_1 = secondary_model['state_dict']

# Iterate a sorted snapshot of all keys: the snapshot makes it safe to mutate
# both dicts inside the loop, and sorting keeps the printed order stable.
for key in sorted(set(theta_0) | set(theta_1)):
    if 'model' in key and key in theta_1:
        if key in theta_0:
            # Present in both checkpoints: blend the two tensors.
            print(key)
            theta_0[key] = slerp(theta_0[key], theta_1[key], MERGE_ALPHA)
        else:
            # Only the secondary checkpoint has this entry: carry it over.
            theta_0[key] = theta_1[key]
    # Drop the secondary entry once handled. pop() with a default avoids a
    # KeyError for keys that exist only in the primary checkpoint.
    theta_1.pop(key, None)

del secondary_model
# primary_model['state_dict'] is theta_0, so the merged weights are saved
# together with everything else the primary checkpoint contained.
torch.save(primary_model, MERGED_CKPT)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment