Training loop for fixed-window model
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim

from tqdm import tqdm

def train_fixed_window(n, n_epochs=1, batch_size=3200, lr=1e-2):
    # Vectorize the data
    train_x, train_y = vectorize_fixed_window(wikitext.train, n)
    valid_x, valid_y = vectorize_fixed_window(wikitext.valid, n)

    # Initialize the model
    model = FixedWindowModel(n, len(wikitext.vocab), embedding_dim=50, hidden_dim=50).to(device)
    # nn.init.normal_(model.embedding.weight, mean=0, std=1e-1)

    # Initialize the optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Lowest validation perplexity seen so far
    min_ppl = float('inf')

    for t in range(n_epochs):
        # Training
        model.train()
        with tqdm(total=len(train_x)) as pbar:
            pbar.set_description(f'Epoch {t+1}')
            for bx, by in batchify(train_x, train_y, batch_size):
                optimizer.zero_grad()
                output = model(bx)
                loss = F.cross_entropy(output, by)
                loss.backward()
                optimizer.step()
                pbar.set_postfix(loss=loss.item(), ppl=np.exp(loss.item()))
                pbar.update(len(bx))

        # Evaluation
        model.eval()
        with torch.no_grad():
            losses = []
            for bx, by in batchify(valid_x, valid_y, batch_size):
                output = model(bx)
                losses.append(F.cross_entropy(output, by).item())
            ppl = np.exp(sum(losses) / len(losses))
            print(f'Perplexity after epoch {t+1}: {ppl}', flush=True)

        # Stop early if the validation perplexity has not improved by at least 2
        if ppl <= min_ppl - 2:
            min_ppl = ppl
        else:
            break

    return model
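
For context, a minimal usage sketch. It assumes the surrounding notebook defines wikitext, device, FixedWindowModel, vectorize_fixed_window, and batchify (none of which are included in this snippet); the window size n=3 is purely illustrative.

# Hypothetical usage; wikitext, device, FixedWindowModel,
# vectorize_fixed_window, and batchify must be defined elsewhere.
model = train_fixed_window(n=3, n_epochs=5, batch_size=3200, lr=1e-2)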