This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Load a trained Transformer from a local checkpoint and translate one
# sample sentence with it.
from fairseq.models.transformer import TransformerModel

zh2en = TransformerModel.from_pretrained(
    '/path/to/checkpoints',  # placeholder path; replace with your own
    checkpoint_file='checkpoint_best.pt',
    data_name_or_path='data-bin/wmt17_zh_en_full',
    # BPE scheme and the codes file handed to from_pretrained.
    bpe='subword_nmt',
    bpe_codes='data-bin/wmt17_zh_en_full/zh.code'
)
zh2en.translate('你好 世界')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch

# List available models
torch.hub.list('pytorch/fairseq')  # [..., 'transformer.wmt16.en-de', ... ]

# Load a transformer trained on WMT'16 En-De
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt16.en-de', tokenizer='moses', bpe='subword_nmt')
en2de.eval()  # disable dropout

# The underlying model is available under the *models* attribute
# (the statement this comment introduces is not part of this excerpt).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@register_task('classification')
class ClassificationTask(FairseqTask):
    # The class body is elided in this excerpt; the `(...)` placeholder from
    # the original snippet is kept so the stub remains valid Python.
    (...)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Excerpt from the body of a training step: `model`, `criterion`, `sample`,
# `optimizer`, `update_num` and `ignore_grad` are supplied by the enclosing
# function, which is not shown here.
model.train()
model.set_num_updates(update_num)
loss, sample_size, logging_output = criterion(model, sample)
if ignore_grad:
    # Zero the loss so the backward pass below still runs but on a zero loss.
    loss *= 0
# NOTE(review): indentation was lost in the source; the backward call is
# placed outside the `if` so it runs on every step -- confirm.
optimizer.backward(loss)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# setup the task (e.g., load dictionaries)
task = fairseq.tasks.setup_task(args)

# build model and criterion
model = task.build_model(args)
criterion = task.build_criterion(args)

# load datasets
task.load_dataset('train')
task.load_dataset('valid')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Outline of the training flow: one pass over the 'train' batches per epoch,
# with an optimizer step and a per-update scheduler step for every batch.
# NOTE(review): indentation was lost in the source; the nesting below is
# reconstructed from which names each call uses (`num_updates` inside the
# batch loop, `epoch` once per epoch) -- confirm.
for epoch in range(num_epochs):
    itr = task.get_batch_iterator(task.dataset('train'))
    for num_updates, batch in enumerate(itr):
        task.train_step(batch, model, criterion, optimizer)
        average_and_clip_gradients()
        optimizer.step()
        lr_scheduler.step_update(num_updates)
    lr_scheduler.step(epoch)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# With --joined-dictionary, the dictionary is loaded from whichever of
# --srcdict / --tgtdict was given and bound to `src_dict`.
# NOTE(review): indentation was lost in the source and the excerpt stops
# after the `else` assert; the nesting below is reconstructed -- confirm.
if args.joined_dictionary:
    # At most one of the two dictionary paths may be supplied.
    assert not args.srcdict or not args.tgtdict, \
        "cannot use both --srcdict and --tgtdict with --joined-dictionary"

    if args.srcdict:
        src_dict = task.load_dictionary(args.srcdict)
    elif args.tgtdict:
        src_dict = task.load_dictionary(args.tgtdict)
    else:
        # Neither dictionary path was given, so --trainpref is required.
        assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def cli_main():
    """Parse the preprocessing command-line arguments and pass them to ``main``."""
    parser = options.get_preprocessing_parser()
    args = parser.parse_args()
    main(args)
# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    cli_main()