Skip to content

Instantly share code, notes, and snippets.

@groverpr
Created March 12, 2020 08:01
Show Gist options
  • Save groverpr/e580b8085f70833a420f802d2adf26a4 to your computer and use it in GitHub Desktop.
def get_dataloader(dataset,
                   dataset_type="train",  # one of: "train", "valid", "test"
                   batch_size=256,
                   bucket_num=5,
                   shuffle=True,          # honored for training data only
                   num_workers=1):
    """Build a gluon ``DataLoader`` for *dataset*.

    Training data is batched with a ``FixedBucketSampler`` so that sequences
    of similar length end up in the same batch (less padding, more efficient
    than naive shuffling). Valid/test data is batched in stable order and is
    never shuffled.

    Parameters
    ----------
    dataset : gluon-style dataset of (review, label) pairs.
        NOTE(review): assumed shape of samples inferred from the transform
        lambda below — confirm against the caller.
    dataset_type : str
        "train" enables length-bucketed sampling; "valid"/"test" use plain
        sequential batching.
    batch_size : int
        Number of samples per batch.
    bucket_num : int
        Number of length-based buckets (training only).
    shuffle : bool
        Whether to shuffle training batches. Ignored for valid/test.
    num_workers : int
        Worker processes for the DataLoader.

    Returns
    -------
    gluon.data.DataLoader

    Raises
    ------
    Exception
        If ``dataset_type`` is not one of "train", "valid", "test".
    """
    # Batchify function: Pad(ret_length=True) appends each sequence's length
    # as an additional input; Stack collects the float labels.
    combined_batchify_fn = nlp.data.batchify.Tuple(
        nlp.data.batchify.Pad(axis=0, ret_length=True),
        nlp.data.batchify.Stack(dtype='float32'))  # stack input samples

    if dataset_type == "train":
        data_lengths = dataset.transform(
            lambda review, label: float(len(review)), lazy=False)
        # Bucketing by length groups similar-length sequences together,
        # which is more efficient than uniform shuffling.
        batch_sampler = nlp.data.sampler.FixedBucketSampler(
            data_lengths,
            batch_size=batch_size,
            num_buckets=bucket_num,
            shuffle=shuffle)
        dataloader = gluon.data.DataLoader(
            dataset=dataset,
            batch_sampler=batch_sampler,
            batchify_fn=combined_batchify_fn,
            num_workers=num_workers,
        )
    elif dataset_type in ("valid", "test"):
        # Evaluation data must keep a deterministic order: always disable
        # shuffling here, regardless of the `shuffle` argument (previously
        # the argument leaked through, contradicting this intent).
        dataloader = gluon.data.DataLoader(
            dataset=dataset,
            shuffle=False,
            batch_size=batch_size,
            batchify_fn=combined_batchify_fn,
            num_workers=num_workers)
    else:
        # Message fixed: "dev" was listed but never accepted by the branches.
        raise Exception("Pass dataset type from train, valid or test")
    return dataloader
# Example of creating a dataloader for the training set.
# (Fixed: the original call was missing its closing parenthesis.)
train_dataloader = get_dataloader(train_dataset,
                                  batch_size=64,
                                  bucket_num=5,  # number of length-based buckets
                                  shuffle=True,
                                  num_workers=0,
                                  dataset_type="train")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment