Created
March 12, 2020 08:01
-
-
Save groverpr/e580b8085f70833a420f802d2adf26a4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_dataloader(dataset,
                   dataset_type="train",  # train/valid/test
                   batch_size=256,
                   bucket_num=5,
                   shuffle=True,  # true for training
                   num_workers=1):
    """Build a Gluon ``DataLoader`` over ``dataset`` for training or evaluation.

    Args:
        dataset: Dataset whose samples are ``(review, label)`` pairs. For
            ``"train"`` it must support ``transform(fn, lazy=False)`` and
            ``len(review)`` must give the sequence length.
        dataset_type: ``"train"``, ``"valid"`` or ``"test"``.
        batch_size: Number of samples per batch.
        bucket_num: Number of length buckets used by the training sampler;
            ignored for ``"valid"``/``"test"``.
        shuffle: For ``"train"`` it is forwarded to ``FixedBucketSampler``;
            for ``"valid"``/``"test"`` it is forwarded to the ``DataLoader``
            unchanged. The default is ``True``, so pass ``shuffle=False``
            explicitly when evaluation order matters.
        num_workers: Forwarded to ``gluon.data.DataLoader``.

    Returns:
        A ``gluon.data.DataLoader`` using the combined pad/stack batchify
        function.

    Raises:
        ValueError: If ``dataset_type`` is not one of the supported values.
    """
    # Reject bad input before building any batchify/sampler objects.
    if dataset_type not in ("train", "valid", "test"):
        raise ValueError(
            "Unknown dataset_type {!r}; pass one of 'train', 'valid' or "
            "'test'".format(dataset_type))

    # Batchify function appends the length of each sequence to feed as
    # additional input
    combined_batchify_fn = nlp.data.batchify.Tuple(
        nlp.data.batchify.Pad(axis=0, ret_length=True),
        nlp.data.batchify.Stack(dtype='float32'))  # stack input samples

    if dataset_type == "train":
        data_lengths = dataset.transform(
            lambda review, label: float(len(review)), lazy=False)
        # We need to shuffle for training data.
        # It's more efficient to shuffle training data such that sequences
        # with similar length come together.
        # This is achieved using FixedBucketSampler
        batch_sampler = nlp.data.sampler.FixedBucketSampler(
            data_lengths,
            batch_size=batch_size,
            num_buckets=bucket_num,
            shuffle=shuffle)
        return gluon.data.DataLoader(
            dataset=dataset,
            batch_sampler=batch_sampler,
            batchify_fn=combined_batchify_fn,
            num_workers=num_workers,
        )

    # "valid" / "test": plain fixed-size batches, no bucketing by length.
    return gluon.data.DataLoader(
        dataset=dataset,
        shuffle=shuffle,
        batch_size=batch_size,
        batchify_fn=combined_batchify_fn,
        num_workers=num_workers)
# Example of creating a dataloader for the training split
train_dataloader = get_dataloader(
    train_dataset,
    dataset_type="train",
    batch_size=64,
    bucket_num=5,  # number of different buckets based on length
    shuffle=True,
    num_workers=0,
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment