@takiyu
Last active May 14, 2019 15:36
PyTorch Persistent DataLoader for Windows (PyTorch 1.1.0)
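On PyTorch 1.1.0, _DataLoaderIter spawns a fresh set of worker processes every epoch and shuts them down once the iterator is exhausted. On Windows, where workers are started with the spawn method (re-importing the training script each time), that per-epoch startup cost is substantial. The subclass below creates a single iterator up front, rewinds it in __iter__(), and skips the end-of-epoch shutdown, so the same workers are reused across epochs. For context (not part of this gist): newer PyTorch releases (1.7+) expose a built-in persistent_workers flag with the same effect, so on those versions you can simply write

    loader = DataLoader(dataset, batch_size=10, num_workers=10,
                        persistent_workers=True)

On 1.1.0, where that flag does not exist, the workaround below is needed.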
from torch.utils.data import _utils
from torch.utils.data.dataloader import DataLoader, _DataLoaderIter


class PersistentDataLoader(DataLoader):
    """DataLoader that keeps its worker processes alive across epochs."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # --------------------------- Different here --------------------------
        # Create the iterator (and its worker processes) once, up front.
        self._dataloader_iter = _PersistentDataLoaderIter(self)
        # ----------------------------------------------------------------------

    def __iter__(self):
        # --------------------------- Different here --------------------------
        # Instead of building a fresh iterator, rewind the persistent one.
        if self._dataloader_iter.batches_outstanding == 0:
            # Reset loading status
            self._dataloader_iter.sample_iter = iter(self.batch_sampler)
            # Prime the workers with the usual 2 * num_workers index batches
            for _ in range(2 * self.num_workers):
                self._dataloader_iter._put_indices()
        # ----------------------------------------------------------------------
        return self._dataloader_iter


class _PersistentDataLoaderIter(_DataLoaderIter):
    def __next__(self):
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = _utils.pin_memory.pin_memory_batch(batch)
            return batch

        # check if the next sample has already been generated
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch)

        if self.batches_outstanding == 0:
            # ------------------------- Different here -------------------------
            # Keep the workers alive at the end of the epoch instead of
            # shutting them down as the stock _DataLoaderIter does.
            # self._shutdown_workers()
            # -------------------------------------------------------------------
            raise StopIteration

        while True:
            assert (not self.shutdown and self.batches_outstanding > 0)
            idx, batch = self._get_batch()
            self.batches_outstanding -= 1
            if idx != self.rcvd_idx:
                # store out-of-order samples
                self.reorder_dict[idx] = batch
                continue
            return self._process_next_batch(batch)


if __name__ == '__main__':
    dataset = [i for i in range(1001)]
    # data_loader = DataLoader(dataset, batch_size=10, num_workers=10)
    data_loader = PersistentDataLoader(dataset, batch_size=10, num_workers=10)

    for epoch in range(10):
        for i in data_loader:
            print(f'epoch {epoch}: {str(i)}')
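A minimal timing sketch (not part of the original gist) to make the difference visible, run in place of the demo above and assuming the classes are defined in the same file; the dataset and epoch counts are illustrative only. Each pass is timed separately, so the stock loader pays worker startup every epoch while the persistent one pays it once at construction:

import time

def time_epochs(loader, n_epochs=3):
    # Time each full pass; with the stock DataLoader on Windows the first
    # batches of every epoch wait for freshly spawned workers.
    for epoch in range(n_epochs):
        t0 = time.time()
        for _ in loader:
            pass
        print(f'epoch {epoch}: {time.time() - t0:.2f}s')

if __name__ == '__main__':
    dataset = list(range(1001))
    print('DataLoader (workers respawned each epoch):')
    time_epochs(DataLoader(dataset, batch_size=10, num_workers=10))
    print('PersistentDataLoader (workers reused):')
    time_epochs(PersistentDataLoader(dataset, batch_size=10, num_workers=10))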