@taoyds
Created March 20, 2020 12:33
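import fairseq.tasks
from fairseq import options

# parse training options (data directory, --task, --arch, --criterion, ...)
parser = options.get_training_parser()
args = options.parse_args_and_arch(parser)

# setup the task (e.g., load dictionaries)
task = fairseq.tasks.setup_task(args)

# build model and criterion
model = task.build_model(args)
criterion = task.build_criterion(args)

# load datasets
task.load_dataset('train')
task.load_dataset('valid')

# iterate over mini-batches of data; get_batch_iterator returns an
# EpochBatchIterator, so request the iterator for one epoch
batch_itr = task.get_batch_iterator(
    task.dataset('train'), max_tokens=4096,
)
for batch in batch_itr.next_epoch_itr(shuffle=True):
    # compute the loss; a fairseq criterion returns the triple
    # (loss, sample_size, logging_output)
    loss, sample_size, logging_output = criterion(model, batch)
    loss.backward()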
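The `max_tokens=4096` argument caps each mini-batch by its token count rather than by a fixed number of sentences. The sketch below is a simplified, hypothetical re-implementation of that idea (`batch_by_tokens` is not a fairseq function): it assumes a batch's cost is the length of its longest example times the number of examples, because every example is padded to the same length. The library's own batching also handles further options, so treat this only as an illustration of the token budget.

```python
from typing import List, Sequence


def batch_by_tokens(lengths: Sequence[int], max_tokens: int) -> List[List[int]]:
    """Group example indices into batches whose padded size stays within max_tokens.

    Hypothetical helper. Padded size = (longest example in the batch) * (number of
    examples), since every example in a batch is padded to the same length.
    """
    # Sort by length so similarly sized examples share a batch and waste less padding.
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batches: List[List[int]] = []
    current: List[int] = []
    current_max = 0
    for idx in order:
        n = lengths[idx]
        if n > max_tokens:
            # This sketch rejects examples that can never fit in any batch.
            raise ValueError(f"example {idx} has {n} tokens, more than max_tokens={max_tokens}")
        new_max = max(current_max, n)
        if current and new_max * (len(current) + 1) > max_tokens:
            # Adding this example would exceed the budget: close the current batch.
            batches.append(current)
            current, new_max = [], n
        current.append(idx)
        current_max = new_max
    if current:
        batches.append(current)
    return batches


lengths = [5, 3, 8, 2, 7, 4]
print(batch_by_tokens(lengths, max_tokens=16))  # [[3, 1, 5], [0, 4], [2]]
```

Sorting by length keeps padding waste low (the three batches above cost 12, 14 and 8 padded tokens); the trade-off is that batches are no longer random, which is why shuffling is usually applied to the list of batches rather than to individual examples.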
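The snippet stops at `loss.backward()`, which only computes gradients; a real training loop also updates the parameters and resets the gradients before the next batch. The pure-Python toy below is a hypothetical stand-in (none of these classes or functions are fairseq APIs) that mirrors the same division of labor: a task owns the data splits and builds the model and criterion, the criterion returns `(loss, sample_size, logging_output)`, and a hand-written `backward` replaces autograd. Dividing the gradient by `sample_size` before the update is an assumption made here to keep the step size independent of batch size.

```python
import random


class ToyTask:
    """Hypothetical task: owns the data splits and builds the model and criterion."""

    def __init__(self, seed=0):
        self.rng = random.Random(seed)
        self.splits = {}

    def build_model(self):
        return {"w": 0.0}  # one learnable weight; the data's true weight is 3.0

    def build_criterion(self):
        return squared_error

    def load_dataset(self, split):
        n = 64 if split == "train" else 16
        xs = [self.rng.uniform(-1.0, 1.0) for _ in range(n)]
        self.splits[split] = [(x, 3.0 * x) for x in xs]

    def dataset(self, split):
        return self.splits[split]

    def get_batch_iterator(self, data, batch_size):
        # Fixed-size batches in dataset order (no token budget in this toy).
        for start in range(0, len(data), batch_size):
            yield data[start:start + batch_size]


def squared_error(model, batch):
    """Return (loss, sample_size, logging_output), the same triple as the snippet."""
    loss = sum((model["w"] * x - y) ** 2 for x, y in batch)
    return loss, len(batch), {"loss": loss, "nsentences": len(batch)}


def backward(model, batch):
    """Hand-written d(loss)/dw for squared_error; stands in for loss.backward()."""
    return sum(2.0 * (model["w"] * x - y) * x for x, y in batch)


def train(epochs=20, lr=0.5, batch_size=8):
    task = ToyTask(seed=0)
    model = task.build_model()
    criterion = task.build_criterion()
    task.load_dataset("train")
    task.load_dataset("valid")
    for _ in range(epochs):
        for batch in task.get_batch_iterator(task.dataset("train"), batch_size):
            loss, sample_size, logging_output = criterion(model, batch)
            grad = backward(model, batch)
            # Update step: normalize by sample_size, then move against the gradient.
            model["w"] -= lr * grad / sample_size
    valid_loss, n, _ = criterion(model, task.dataset("valid"))
    return model["w"], valid_loss / n


w, valid_loss = train()
print(f"learned w = {w:.4f}, mean valid loss = {valid_loss:.2e}")
```

With PyTorch the update and reset would typically be an optimizer's `step()` and `zero_grad()` calls after `loss.backward()`, since `backward()` accumulates into existing gradients rather than overwriting them.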