@nrupatunga
Created April 12, 2022 10:34
```python
# Required packages: torch
# Desired packages: numpy, matplotlib
# Add missing imports:


def main():
    training_images = {x for x in range(1, 100001)}
    priority_images = {x * 100 for x in range(1, 11)}

    # Task:
    # 1. The dataloader needs to return batches of size 10. Each element is a unique element from `training_images`.
    # 2. Elements from `priority_images` must be present in EVERY batch with a 1:1 ratio, meaning that each batch
    #    has 50% of its images from `priority_images` and 50% from `training_images`.
    #    NOTE: elements of `priority_images` are also part of `training_images`.
    # 3. The code needs to be multiprocessing safe.
    # 4. Follow the comments below to add extra functionality.
    batch_size = 10
    num_epochs = 10

    for epoch in range(num_epochs):
        pass


if __name__ == '__main__':
    main()
```
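One possible approach to tasks 1 to 3 is a custom batch sampler. The sketch below is not the gist author's solution, and the class name `PriorityBatchSampler`, its `seed` parameter, and its `set_epoch` method are hypothetical. It assumes that the "regular" half of each batch is drawn from `training_images` minus `priority_images`, so no element can appear twice in one batch; that each epoch uses every regular image exactly once; and that a trailing partial chunk is dropped so every batch keeps the 1:1 ratio. It uses only the standard library and yields lists of indices, which is the form PyTorch's `DataLoader` accepts through its `batch_sampler` argument.

```python
import random
from typing import Iterator, List, Set


class PriorityBatchSampler:
    """Hypothetical sampler: half of each batch from `priority`, half from the rest."""

    def __init__(self, training: Set[int], priority: Set[int],
                 batch_size: int = 10, seed: int = 0) -> None:
        if batch_size % 2 != 0:
            raise ValueError("batch_size must be even for a 1:1 split")
        if not priority <= training:
            raise ValueError("priority must be a subset of training")
        self.half = batch_size // 2
        if len(priority) < self.half:
            raise ValueError("not enough priority elements for one batch")
        # Sorted lists make shuffling reproducible for a given seed.
        self.priority = sorted(priority)
        # Assumption: regular elements exclude priority ones, so batches stay unique.
        self.regular = sorted(training - priority)
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        # Changing the epoch changes the RNG seed, so each epoch gets a new order.
        self.epoch = epoch

    def __len__(self) -> int:
        # Each regular element is used once per epoch; a partial last chunk is dropped.
        return len(self.regular) // self.half

    def __iter__(self) -> Iterator[List[int]]:
        # A fresh local RNG per iteration: no mutable state is shared across processes.
        rng = random.Random(self.seed + self.epoch)
        regular = self.regular[:]
        rng.shuffle(regular)
        for i in range(len(self)):
            chunk = regular[i * self.half:(i + 1) * self.half]
            # Distinct priority elements for this batch, then mix both halves.
            batch = rng.sample(self.priority, self.half) + chunk
            rng.shuffle(batch)
            yield batch
```

In PyTorch, the batch sampler is iterated in the main process and worker processes only receive the index lists, which is why this design avoids shared state between workers. A usage sketch, with `dataset` standing in for a hypothetical map-style dataset indexed by image id: `loader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler, num_workers=4)`, then call `sampler.set_epoch(epoch)` at the top of each iteration of the `for epoch in range(num_epochs)` loop. Note that `batch_sampler` is mutually exclusive with the `batch_size`, `shuffle`, `sampler`, and `drop_last` arguments of `DataLoader`.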