

@chaonan99
Last active March 30, 2020 17:22
A PyTorch multiprocessing problem
import os
import torch


class Dataset(torch.utils.data.Dataset):
    arg = {'batch_size': 1}  # class attribute, shared by all instances

    def __init__(self, arg):
        print('__init__')
        self.arg.update(arg)  # mutates the class dict; creates no instance attribute
        # self.arg = self.arg
        print(self.arg)

    def _worker_init_fn(self, *args):
        print('worker init')
        print(self.arg)

    def get_dataloader(self):
        return torch.utils.data.DataLoader(self, batch_size=None,
                                           num_workers=3,
                                           worker_init_fn=self._worker_init_fn,
                                           pin_memory=True,
                                           multiprocessing_context='spawn')

    def __getitem__(self, idx):
        return 0

    def __len__(self):
        return 5


def main():
    dataloader = Dataset({'batch_size': 2}).get_dataloader()
    for _ in dataloader:
        pass


if __name__ == '__main__':
    main()
@chaonan99 (Author)

I want the workers to see {'batch_size': 2}, but they print {'batch_size': 1}.
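
What is most likely happening: with multiprocessing_context='spawn', each worker starts a fresh interpreter, re-imports the module, and receives a pickled copy of the dataset. self.arg.update(arg) mutates the class attribute Dataset.arg, which lives on the class rather than in the instance's __dict__, so it is never pickled; the worker falls back to the freshly imported class default. Below is a minimal sketch that simulates the spawn behavior with a plain pickle round-trip (a standalone class and the names ds/payload/worker_ds are illustrative, not part of the gist):

import pickle

class Dataset:
    arg = {'batch_size': 1}
    def __init__(self, arg):
        self.arg.update(arg)  # mutates the class dict; vars(self) stays empty

ds = Dataset({'batch_size': 2})
payload = pickle.dumps(ds)       # instance __dict__ is empty, so only a class reference is stored

Dataset.arg = {'batch_size': 1}  # simulate the fresh import in a spawned worker
worker_ds = pickle.loads(payload)
print(worker_ds.arg)             # {'batch_size': 1}: lookup falls back to the class attribute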

@chaonan99 (Author) commented Mar 30, 2020

Here is the output:

__init__
{'batch_size': 2}
worker init
{'batch_size': 1}
worker init
{'batch_size': 1}
worker init
{'batch_size': 1}

@chaonan99 (Author)

Strangely enough, if I uncomment line 11 (self.arg = self.arg) it works...
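
That is consistent with the class-attribute explanation: reading self.arg finds the class dict, but the assignment then stores a reference to it in the instance's __dict__, and instance attributes do get pickled and shipped to the spawned workers. A small sketch of the shadowing effect (the name ds is illustrative):

class Dataset:
    arg = {'batch_size': 1}

ds = Dataset()
print('arg' in vars(ds))  # False: lookup falls through to the class attribute
ds.arg = ds.arg           # read from the class, then assign on the instance
print('arg' in vars(ds))  # True: now in the instance __dict__, so pickle keeps it

A more explicit fix, assuming you want per-instance settings, would be to copy the default rather than mutate the shared class dict, e.g. self.arg = dict(Dataset.arg) followed by self.arg.update(arg) in __init__.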
