Skip to content

Instantly share code, notes, and snippets.

@groverpr
Created March 29, 2020 05:42
Show Gist options
  • Save groverpr/b7396c492de13113dab8e00c40c3e4b4 to your computer and use it in GitHub Desktop.
def train(network,
          train_data,
          holdout_data,
          loss,
          epochs,
          ctx,
          lr=1e-2,
          wd=1e-5,
          optimizer='adam'):
    """Train `network`, printing BCE loss and AUC on the training and
    holdout sets before training and after every epoch.

    Parameters
    ----------
    network : gluon.Block
        Model to train; hybridized in place for faster (symbolic) execution.
    train_data : iterable
        Mini-batches shaped ((data, length), label) — presumably a DataLoader
        over padded sequences; confirm against the caller.
    holdout_data : iterable
        Validation batches in the same format.
    loss : callable
        Loss function (printed as "BCE", so presumably a sigmoid BCE loss).
    epochs : int
        Number of full passes over `train_data`.
    ctx : list
        Device contexts across which each mini-batch is split.
    lr : float
        Learning rate.
    wd : float
        Weight decay.
    optimizer : str
        Optimizer name passed to gluon.Trainer.
    """
    # Define the optimizer over all trainable parameters.
    trainer = gluon.Trainer(network.collect_params(), optimizer,
                            {'learning_rate': lr, 'wd': wd})

    # Hybridize network for faster computations. (Symbolic)
    network.hybridize()

    # Report loss/AUC before training starts, as a baseline.
    valid_loss, valid_auc, _, _ = evaluate_network(network, loss, holdout_data, ctx)
    train_loss, train_auc, _, _ = evaluate_network(network, loss, train_data, ctx)
    print("Start \n Training BCE {:.4f}, Train AUC {:.4f}, Valid AUC {:.4f}".format(
        train_loss, train_auc, valid_auc))

    # Train the network.
    for e in range(epochs):
        for (data, length), label in train_data:  # for each mini-batch
            # Split each tensor across the available devices.
            X_ = gluon.utils.split_and_load(data, ctx, even_split=False)
            X_l_ = gluon.utils.split_and_load(length, ctx, even_split=False)
            y_ = gluon.utils.split_and_load(label, ctx, even_split=False)

            # Forward pass must run inside .record() so autograd builds the
            # graph; record mode also switches layers like dropout into
            # training behavior.
            with autograd.record():
                preds = [network(x_, x_l_) for x_, x_l_ in zip(X_, X_l_)]
                losses = [loss(p, y) for p, y in zip(preds, y_)]

            # Backward pass (gradients via chain rule) on each device's loss.
            # Plain loop, not a throwaway list comprehension.
            for per_device_loss in losses:
                per_device_loss.backward()

            # One gradient-descent step, normalized by rows in the mini-batch.
            trainer.step(data.shape[0])

        # End-of-epoch evaluation on both splits.
        valid_loss, valid_auc, _, _ = evaluate_network(network, loss, holdout_data, ctx)
        train_loss, train_auc, _, _ = evaluate_network(network, loss, train_data, ctx)
        print("Epoch [{}], Training BCE {:.4f}, Train AUC {:.4f}, Valid AUC {:.4f}".format(
            e+1, train_loss, train_auc, valid_auc))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment