Created
December 12, 2023 10:11
-
-
Save Algomancer/bdbb50c993fdfe36fbce29c288a2782c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os

import imageio
import matplotlib.pyplot as plt
import numpy as np
import torch
import tqdm
from PIL import Image
from torchdiffeq import odeint
# Load the target image and turn it into a float tensor scaled to [0, 1].
# Grayscale conversion is deliberately NOT applied: for an RGB(A) file the
# tensor keeps its channel axis, so there are H * W * C oscillators.
# Append `.convert('L')` to the `Image.open(...)` call for a single-channel target.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

with Image.open('x.png') as source_image:
    # resize() loads the pixel data, so the file can be closed afterwards.
    resized_image = source_image.resize((32, 32))

# Dividing by 255 assumes 8-bit channels (NOTE(review): confirm for 16-bit PNGs).
target_image = torch.from_numpy(np.array(resized_image)).float() / 255.0
target_image = target_image.to(device)

# One oscillator per pixel value (per channel, when the image has channels).
num_oscillators = target_image.numel()
class KuramotoLayer(torch.nn.Module):
    """Kuramoto-model vector field, callable as ``func(t, y)`` by ``odeint``.

    Each oscillator i evolves as

        dθ_i/dt = ω_i + (K / N) * Σ_j sin(θ_j − θ_i)

    where ω (``natural_frequencies``) is learnable, K is ``coupling_strength``
    and N is ``num_oscillators``.
    """

    def __init__(self, num_oscillators, coupling_strength):
        super().__init__()
        self.num_oscillators = num_oscillators
        self.coupling_strength = coupling_strength
        # Learnable natural frequencies ω, one per oscillator, drawn from N(0, 1).
        # The author's note suggests these might later be replaced by attention weights.
        self.natural_frequencies = torch.nn.Parameter(torch.randn(num_oscillators))

    def forward(self, t, phase):
        """Return dθ/dt for the given phases.

        Args:
            t: Current time. Unused (the system is autonomous) but required by
                the ``odeint`` calling convention.
            phase: Tensor whose last axis has length ``num_oscillators``; any
                leading axes are treated as independent batches.

        Returns:
            Tensor with the same shape as ``phase`` holding the phase velocities.
        """
        # Exact identity: Σ_j sin(θ_j − θ_i) = cos θ_i · Σ_j sin θ_j − sin θ_i · Σ_j cos θ_j.
        # This is O(N) and never materialises the N×N pairwise-difference matrix.
        # Autograd would otherwise keep that matrix alive for every solver stage.
        sin_phase = torch.sin(phase)
        cos_phase = torch.cos(phase)
        sin_total = sin_phase.sum(dim=-1, keepdim=True)
        cos_total = cos_phase.sum(dim=-1, keepdim=True)
        interaction_terms = cos_phase * sin_total - sin_phase * cos_total
        return self.natural_frequencies + self.coupling_strength / self.num_oscillators * interaction_terms
# Build the Kuramoto vector field. Every tensor lives on the target image's device
# so the ODE solve and the loss never mix CPU and GPU tensors.
model = KuramotoLayer(num_oscillators, coupling_strength=1).to(target_image.device)

# Initial phases drawn uniformly from [0, 1), and the time grid for the solve
# (100 evenly spaced points on [0, 10]).
initial_phase = torch.rand(num_oscillators).to(target_image.device)
t = torch.linspace(0, 10, 100, device=target_image.device)

# Adam over the model's parameters (the learnable natural frequencies).
optimizer = torch.optim.Adam(model.parameters(), lr=1.0)
# One rendered RGB frame per epoch, later written out as an animated GIF.
frames = []
for epoch in tqdm.tqdm(range(1000)):
    optimizer.zero_grad()

    # Integrate the phases over the time grid `t`; phase[-1] is the state at t[-1].
    phase = odeint(model, initial_phase, t)
    # The final phases are read directly as pixel values of the output image.
    output_image = phase[-1].view(target_image.shape)

    loss = torch.nn.functional.mse_loss(output_image, target_image)
    loss.backward()
    optimizer.step()

    # Render this epoch's output (computed before the parameter update).
    fig, ax = plt.subplots()
    ax.imshow(output_image.detach().cpu().numpy(), cmap='gray')
    ax.set_title(f'Epoch {epoch}, Loss {loss.item()}, num_occ={num_oscillators}')
    fig.canvas.draw()
    # buffer_rgba() replaces tostring_rgb() (deprecated in Matplotlib 3.8, removed
    # in 3.10). It already has the canvas's true (H, W, 4) pixel shape, so no manual
    # reshape can disagree with the buffer size. Drop alpha to keep RGB frames.
    # Copy because the buffer belongs to the figure, which is closed below.
    frame = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    frames.append(frame)
    plt.close(fig)

# NOTE(review): recent imageio versions may interpret GIF `duration` in
# milliseconds rather than seconds — confirm the intended playback speed.
imageio.mimsave('training_process.gif', frames, duration=0.5)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment