Skip to content

Instantly share code, notes, and snippets.

@jgoodie
Created May 25, 2024 23:48
Show Gist options
  • Save jgoodie/a0c8f3eb2ee9149d4852e8a58d5ee255 to your computer and use it in GitHub Desktop.
Save jgoodie/a0c8f3eb2ee9149d4852e8a58d5ee255 to your computer and use it in GitHub Desktop.
def training_loop(model, X_train, X_val, y_train, y_val, epochs=1000,
                  weight_decay=0.0, lr=0.001, device='cuda',
                  label_weights=None):
    """Train ``model`` full-batch with Adam and record per-epoch metrics.

    Each epoch runs one forward/backward/optimizer step over the entire
    training set, then evaluates the entire validation set under
    ``torch.inference_mode()``. The loss is (optionally class-weighted)
    cross-entropy. Accuracy uses a torchmetrics multiclass ``Accuracy``
    whose ``num_classes`` comes from ``model.output_features``.

    Args:
        model: Module returning logits with classes along dim 1. It must
            expose an ``output_features`` attribute, which is used as the
            number of classes. NOTE: this function does not move ``model``;
            presumably the caller has already placed it on ``device``.
        X_train, X_val: Input tensors; moved to ``device`` here.
        y_train, y_val: Target tensors; moved to ``device`` here. Assumed to
            hold integer class indices, because they are compared with
            ``argmax`` predictions (TODO confirm with callers).
        epochs: Number of full-batch training epochs.
        weight_decay: Adam weight decay.
        lr: Adam learning rate.
        device: Device the data and the metric are moved to.
        label_weights: Optional per-class weight tensor passed to
            ``nn.CrossEntropyLoss(weight=...)``. When ``None``, a module-level
            ``label_weights`` global is used if one exists (the previous,
            implicit behaviour); otherwise the loss is unweighted.

    Returns:
        Tuple ``(train_losses, train_accs, val_losses, val_accs)``: four
        lists of Python floats with one entry per epoch. Accuracies are
        fractions in [0, 1].
    """
    # Move all data to the target device once, before the loop.
    X_train, y_train = X_train.to(device), y_train.to(device)
    X_val, y_val = X_val.to(device), y_val.to(device)

    if label_weights is None:
        # Backward compatibility: earlier versions read a global named
        # ``label_weights`` implicitly; keep honouring it when it exists.
        label_weights = globals().get("label_weights")
    class_weights = None if label_weights is None else label_weights.to(device)

    accuracy = Accuracy(task="multiclass", num_classes=model.output_features).to(device)
    loss_fn = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    train_losses, train_accs, val_losses, val_accs = [], [], [], []

    for epoch in range(epochs):
        # ---- Training step (whole training set as one batch) ----
        model.train()
        train_logits = model(X_train)
        # argmax(logits) == argmax(softmax(logits)) because softmax is
        # monotonic, so normalising first is unnecessary.
        train_pred = train_logits.argmax(dim=1)
        loss = loss_fn(train_logits, y_train)
        # torchmetrics metrics take (preds, target) in that order.
        acc = accuracy(train_pred, y_train)
        train_losses.append(loss.item())
        train_accs.append(acc.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # ---- Validation pass (no autograd tracking) ----
        model.eval()
        with torch.inference_mode():
            val_logits = model(X_val)
            val_pred = val_logits.argmax(dim=1)
            val_loss = loss_fn(val_logits, y_val)
            val_acc = accuracy(val_pred, y_val)
        val_losses.append(val_loss.item())
        val_accs.append(val_acc.item())

        # Report every 100 epochs and always on the final epoch.
        if epoch % 100 == 0 or epoch == epochs - 1:
            # Accuracies are fractions in [0, 1]; ``:.2%`` scales by 100.
            # (The old ``:.2f}%`` printed e.g. "0.93%" for 93% accuracy.)
            print(
                f"Epoch: {epoch} | Train Loss: {train_losses[-1]:.5f}, "
                f"Train Acc: {train_accs[-1]:.2%} | "
                f"Validation Loss: {val_losses[-1]:.5f}, "
                f"Validation Acc: {val_accs[-1]:.2%}"
            )

    return train_losses, train_accs, val_losses, val_accs
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment