Skip to content

Instantly share code, notes, and snippets.

@yassersouri
Created February 19, 2018 22:35
Show Gist options
  • Save yassersouri/964ac9cbe3b128af59c4c85b7b8a79db to your computer and use it in GitHub Desktop.
Save yassersouri/964ac9cbe3b128af59c4c85b7b8a79db to your computer and use it in GitHub Desktop.
PyTorch Sampler for Intentional Overfitting!
from torch.utils.data.sampler import Sampler
class MySampler(Sampler):
    """Sampler that repeats a small set of indices to fill a whole epoch.

    Only the entries of ``indices`` are ever yielded, but they are cycled
    (in their given order) until exactly ``len(main_source)`` items have been
    produced, so one pass over this sampler has the same length as one pass
    over ``main_source``.

    Args:
        main_source: Sized object whose length defines how many items a
            single pass over the sampler yields.
        indices: Sized iterable of the indices to repeat.

    Raises:
        ValueError: If ``indices`` is empty while ``main_source`` is not,
            since there would be nothing to repeat.
    """

    def __init__(self, main_source, indices):
        self.main_source = main_source
        self.indices = indices

        target_len = len(self.main_source)
        pool = list(self.indices)
        if target_len > 0 and len(pool) == 0:
            raise ValueError(
                "indices must not be empty when main_source is non-empty"
            )

        # Cycle through the pool until the epoch is exactly as long as the
        # main source. Rounding a repeat count instead would make the number
        # of yielded items disagree with __len__ whenever the sizes do not
        # divide evenly, and would yield nothing at all when ``indices`` is
        # much longer than ``main_source``.
        pool_len = len(pool)
        self.to_iter_from = [pool[i % pool_len] for i in range(target_len)]

    def __iter__(self):
        """Return an iterator over the precomputed, repeated index list."""
        return iter(self.to_iter_from)

    def __len__(self):
        """Return the number of indices yielded by one pass of ``__iter__``."""
        return len(self.to_iter_from)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment