from tqdm import tqdm
import numpy as np

pbar = tqdm(range(50))
for epoch in pbar:  # loop over the dataset multiple times
    running_loss = 0.0
    # sample a batch of timesteps uniformly at random
    Ts = np.random.randint(1, num_diffusion_timesteps, size=training_steps_per_epoch)
    for t in Ts:
        # produce a corrupted sample x_t from the forward process
        q_t = q_sample(x_init, t, list_bar_alphas, device)
        # mean and covariance of the forward-process posterior q(x_{t-1} | x_t, x_0)
        mu_t, cov_t = posterior_q(x_init, q_t, t, alphas, list_bar_alphas, device)
        # the covariance is isotropic, so keep only the first diagonal entry
        sigma_t = cov_t[0][0]
        # zero the parameter gradients
        optimizer.zero_grad()
        # the model predicts the mean of the forward-process posterior
        mu_theta = denoising_model(q_t, t)
        # the loss measures how far the prediction is from the true posterior mean
        loss = criterion(mu_t, mu_theta)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    pbar.set_description('Epoch: {} Loss: {}'.format(epoch, running_loss / training_steps_per_epoch))
print('Finished Training')
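
For reference, here is a minimal sketch of the two helpers the loop assumes, q_sample and posterior_q, following the standard DDPM closed forms (Ho et al., 2020). It assumes list_bar_alphas holds the cumulative products alpha_bar_t and alphas holds the per-step alpha_t, both indexable by integer timestep; the gist's actual implementations may differ.

import torch

# Hypothetical sketches of the helpers used in the training loop above.
# Assumption: list_bar_alphas[t] = alpha_bar_t (cumulative product) and
# alphas[t] = alpha_t for each integer timestep t.

def q_sample(x_0, t, list_bar_alphas, device):
    """Sample x_t ~ q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)."""
    a_bar = torch.as_tensor(list_bar_alphas[t], device=device)
    noise = torch.randn_like(x_0)
    return torch.sqrt(a_bar) * x_0 + torch.sqrt(1.0 - a_bar) * noise

def posterior_q(x_0, x_t, t, alphas, list_bar_alphas, device):
    """Closed-form mean and covariance of the posterior q(x_{t-1} | x_t, x_0)."""
    alpha_t = torch.as_tensor(alphas[t], device=device)
    a_bar_t = torch.as_tensor(list_bar_alphas[t], device=device)
    a_bar_prev = torch.as_tensor(list_bar_alphas[t - 1], device=device)
    beta_t = 1.0 - alpha_t
    # posterior mean is a weighted combination of x_0 and x_t
    mu = (torch.sqrt(a_bar_prev) * beta_t / (1.0 - a_bar_t)) * x_0 \
        + (torch.sqrt(alpha_t) * (1.0 - a_bar_prev) / (1.0 - a_bar_t)) * x_t
    # posterior covariance is isotropic: beta_tilde_t * I
    beta_tilde = (1.0 - a_bar_prev) / (1.0 - a_bar_t) * beta_t
    cov = beta_tilde * torch.eye(x_0.shape[-1], device=device)
    return mu, cov

Note that sigma_t is computed in the loop but does not enter the loss; in DDPM it is the fixed posterior variance, which matters at sampling time rather than when training the mean predictor.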