Last active: May 14, 2019, 15:36
-
-
Save takiyu/cb9f72495db4b37001d902f7e87a55d2 to your computer and use it in GitHub Desktop.
Pytorch Persistent Dataloader for Windows (pytorch 1.1.0)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch.utils.data.dataloader import DataLoader, _DataLoaderIter | |
class PersistentDataLoader(DataLoader):
    """DataLoader that keeps one iterator (and its workers) alive across epochs.

    The stock pytorch 1.1.0 ``DataLoader`` builds a brand-new ``_DataLoaderIter``
    — and therefore a brand-new set of worker processes — on every call to
    ``__iter__``. This subclass builds the iterator exactly once in
    ``__init__`` and re-arms it each epoch instead.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # --------------------------- Different here --------------------------
        # Create the single, persistent iterator up front.
        self._dataloader_iter = _PersistentDataLoaderIter(self)
        # ---------------------------------------------------------------------

    def __iter__(self):
        # --------------------------- Different here --------------------------
        it = self._dataloader_iter
        if it.batches_outstanding == 0:
            # Previous epoch finished: rewind the batch sampler and prime
            # 2 * num_workers index batches, mirroring what
            # _DataLoaderIter.__init__ does when it is freshly constructed.
            it.sample_iter = iter(self.batch_sampler)
            for _ in range(2 * self.num_workers):
                it._put_indices()
        # ---------------------------------------------------------------------
        return it
class _PersistentDataLoaderIter(_DataLoaderIter):
    """``_DataLoaderIter`` variant that does NOT shut down its workers.

    ``__next__`` is a copy of pytorch 1.1.0's ``_DataLoaderIter.__next__``
    with the ``_shutdown_workers()`` call removed at epoch end, so the same
    worker processes can serve the next epoch (re-armed by
    ``PersistentDataLoader.__iter__``).
    """

    def __next__(self):
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                # Bug fix: the original snippet referenced `_utils` without
                # ever importing it, so this path raised NameError. Import it
                # locally from torch.utils.data (pytorch 1.1.0 layout).
                from torch.utils.data import _utils
                batch = _utils.pin_memory.pin_memory_batch(batch)
            return batch

        # Check if the next sample has already been generated by a worker.
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch)

        if self.batches_outstanding == 0:
            # ------------------------- Different here ------------------------
            # Deliberately skip self._shutdown_workers() so the worker
            # processes stay alive for the next epoch.
            # -----------------------------------------------------------------
            raise StopIteration

        while True:
            assert (not self.shutdown and self.batches_outstanding > 0)
            idx, batch = self._get_batch()
            self.batches_outstanding -= 1
            if idx != self.rcvd_idx:
                # Store out-of-order samples until their index comes up.
                self.reorder_dict[idx] = batch
                continue
            return self._process_next_batch(batch)
if __name__ == '__main__':
    # Smoke test: run several epochs over the same loader; with the
    # persistent variant the worker processes are reused between epochs.
    dataset = list(range(1001))
    # Baseline for comparison (spawns fresh workers each epoch):
    # data_loader = DataLoader(dataset, batch_size=10, num_workers=10)
    data_loader = PersistentDataLoader(dataset, batch_size=10, num_workers=10)
    for epoch in range(10):
        for i in data_loader:
            print(f'epoch {epoch}: {str(i)}')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.