Skip to content

Instantly share code, notes, and snippets.

@isaacmg
Last active February 11, 2020 18:03
Show Gist options
  • Save isaacmg/bf22aead11edb4affa1b4005fef8ee60 to your computer and use it in GitHub Desktop.
Save isaacmg/bf22aead11edb4affa1b4005fef8ee60 to your computer and use it in GitHub Desktop.
def train_epoch_loop(data_loader: DataLoader, opt: torch.optim.Optimizer,
                     model: PyTorchForecast, takes_target: bool,
                     forward_params: dict = None) -> float:
    """Run one training epoch over ``data_loader`` and return the mean batch loss.

    :param data_loader: yields ``(src, trg)`` batch tensor pairs.
    :param opt: optimizer stepped once per batch.
    :param model: wrapper exposing ``.device`` (target device) and ``.model``
        (the underlying ``nn.Module`` that is actually called).
    :param takes_target: when True, the target tensor is also fed to the model's
        forward pass as keyword argument ``"t"`` (e.g. teacher forcing).
    :param forward_params: extra keyword arguments forwarded to
        ``model.model(src, ...)``; defaults to an empty dict.
    :returns: running loss averaged over the number of batches.
    :raises ValueError: if the loss becomes NaN or infinite, or if the loader
        yields no batches.

    NOTE(review): relies on ``criterion`` (the loss function) and ``epoch``
    (current epoch index, used only in the progress print) being defined at
    module level by the surrounding training script — confirm against caller.
    """
    # Fix the shared-mutable-default bug: the original `forward_params={}`
    # default was mutated via forward_params["t"] = trg, so the injected
    # target persisted across calls that relied on the default.
    if forward_params is None:
        forward_params = {}
    num_batches = 0
    running_loss = 0.0
    for src, trg in data_loader:
        opt.zero_grad()
        # Move the batch to whatever device the model lives on (CPU/GPU/TPU).
        src = src.to(model.device)
        trg = trg.to(model.device)
        # TODO figure how to avoid passing the target through forward_params
        if takes_target:
            forward_params["t"] = trg
        output = model.model(src, **forward_params)
        # Loss is computed against the first feature channel of the target.
        labels = trg[:, :, 0]
        loss = criterion(output, labels.float())
        if loss > 100:
            print("Warning: high loss detected")
        # Fail fast BEFORE backward()/step(): the original checked after the
        # optimizer step, so a NaN/inf gradient had already been applied.
        # torch.isinf also catches -inf, which `loss == float('inf')` missed,
        # and raising a proper exception replaces `raise "<str>"` (a TypeError
        # in Python 3).
        if torch.isnan(loss) or torch.isinf(loss):
            raise ValueError(
                "Error infinite or NaN loss detected. Try normalizing data or performing interpolation")
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(s.parameters(), 0.5)
        opt.step()
        running_loss += loss.item()
        num_batches += 1
    # Guard the average: an empty loader previously died with a bare
    # ZeroDivisionError on the division below.
    if num_batches == 0:
        raise ValueError("data_loader yielded no batches; cannot compute an average loss")
    print("The loss for epoch " + str(epoch))
    total_loss = running_loss / float(num_batches)
    print(total_loss)
    return total_loss
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment