Created
March 29, 2020 05:42
-
-
Save groverpr/b7396c492de13113dab8e00c40c3e4b4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def train(network,
          train_data,
          holdout_data,
          loss,
          epochs,
          ctx,
          lr=1e-2,
          wd=1e-5,
          optimizer='adam'):
    """Train ``network`` with a Gluon ``Trainer`` and print metrics each epoch.

    Metrics are printed once before training starts and once after every
    epoch. Nothing is returned; ``network`` is updated in place.

    Args:
        network: Gluon block to train. It is hybridized by this function and
            called as ``network(data_shard, length_shard)``.
        train_data: Iterable of mini-batches shaped ``((data, length), label)``.
            Also passed to ``evaluate_network`` to report training metrics.
        holdout_data: Validation data passed to ``evaluate_network``.
        loss: Callable invoked as ``loss(prediction, label)``; each result
            must support ``.backward()``.
        epochs: Number of full passes over ``train_data``.
        ctx: Device context(s) handed to ``gluon.utils.split_and_load`` and
            ``evaluate_network`` (presumably a list of GPUs/CPUs -- confirm
            against the caller).
        lr: Learning rate given to the optimizer.
        wd: Weight decay given to the optimizer.
        optimizer: Optimizer name (or object) accepted by ``gluon.Trainer``.

    Note:
        ``evaluate_network`` is defined elsewhere; it is assumed to return a
        4-tuple whose first two items are the loss and the AUC (the remaining
        two are ignored here).
    """
    # Define the optimizer over all of the network's parameters.
    trainer = gluon.Trainer(network.collect_params(), optimizer,
                            {'learning_rate': lr, 'wd': wd})

    # Hybridize network for faster computations. (Symbolic)
    network.hybridize()

    # Print metrics before training starts to establish a baseline.
    _, valid_auc, _, _ = evaluate_network(network, loss, holdout_data, ctx)
    train_loss, train_auc, _, _ = evaluate_network(network, loss, train_data, ctx)
    print("Start \n Training BCE {:.4f}, Train AUC {:.4f}, Valid AUC {:.4f}".format(
        train_loss, train_auc, valid_auc))

    # Train the network.
    for e in range(epochs):
        for (data, length), label in train_data:  # one mini-batch at a time
            # Shard the mini-batch across the devices in ``ctx``;
            # even_split=False tolerates batches that do not divide evenly.
            data_shards = gluon.utils.split_and_load(data, ctx, even_split=False)
            length_shards = gluon.utils.split_and_load(length, ctx, even_split=False)
            label_shards = gluon.utils.split_and_load(label, ctx, even_split=False)

            # The forward pass must run under record() so that the graph is
            # captured for differentiation; record() also switches layers such
            # as dropout into training mode.
            with autograd.record():
                preds = [network(x, x_len)
                         for x, x_len in zip(data_shards, length_shards)]
                losses = [loss(pred, y) for pred, y in zip(preds, label_shards)]

            # Back-propagate each shard's loss (explicit loop: this is done
            # purely for its side effect of populating gradients).
            for shard_loss in losses:
                shard_loss.backward()

            # One optimizer step; the argument is the number of rows in the
            # full mini-batch, which the trainer uses to normalize gradients.
            trainer.step(data.shape[0])

        # Report metrics at the end of every epoch.
        _, valid_auc, _, _ = evaluate_network(network, loss, holdout_data, ctx)
        train_loss, train_auc, _, _ = evaluate_network(network, loss, train_data, ctx)
        print("Epoch [{}], Training BCE {:.4f}, Train AUC {:.4f}, Valid AUC {:.4f}".format(
            e + 1, train_loss, train_auc, valid_auc))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment