Created June 20, 2017 20:46
-
-
Save arunmallya/5e569e4c23ad0567a64764ad70a393b1 to your computer and use it in GitHub Desktop.
Exposes a bug in DataParallel when dicts are used as input
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
from torch.autograd import Variable | |
class SimpleModel(nn.Module):
    """Minimal model whose ``forward`` takes a dict instead of a tensor.

    It exists to reproduce a DataParallel failure: the only layer reads
    its input from the ``'data'`` key of the dict passed to ``forward``.
    """

    def __init__(self):
        super(SimpleModel, self).__init__()
        # Single linear layer: 10 input features -> 2 outputs.
        self.net = nn.Linear(10, 2)

    def forward(self, inputs):
        """Apply the linear layer to ``inputs['data']``.

        Args:
            inputs: dict holding the input batch under the key ``'data'``.
                Its last dimension must be 10 to match ``nn.Linear(10, 2)``.

        Returns:
            The linear layer's output, whose last dimension is 2.

        Raises:
            KeyError: if ``inputs`` has no ``'data'`` key.
        """
        return self.net(inputs['data'])
# Reproduction: feed a dict (rather than a tensor) through DataParallel.
# With one visible GPU this runs; with two or more it fails (see the
# recorded output below).
net = nn.DataParallel(SimpleModel()).cuda()
# NOTE(review): Variable is the pre-0.4 PyTorch wrapper this repro targets;
# it is kept so the script matches the environment the failure was seen in.
inputs = {'data': Variable(torch.rand(10, 10).cuda())}
outputs = net(inputs)
print(outputs)

"""
# Works fine on single device as DataParallel defaults to simple execution if one device only.
$ CUDA_VISIBLE_DEVICES=0 python bug.py
Variable containing:
-0.1836 0.2654
-0.3584 0.0049
-0.2587 -0.0808
-0.2482 -0.2587
-0.4238 -0.2014
-0.1964 -0.1709
-0.6334 -0.0843
-0.4466 0.1243
-0.5991 -0.2169
-0.3005 -0.0565
[torch.cuda.FloatTensor of size 10x2 (GPU 0)]

# Fails on multiple devices.
$ CUDA_VISIBLE_DEVICES=0,1 python bug.py
Exception in thread Thread-2:
Traceback (most recent call last):
  File "/usr/lib/python3.5/threading.py", line 914, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.5/threading.py", line 862, in run
    self._target(*self._args, **self._kwargs)
  File "venv/lib/python3.5/site-packages/torch/nn/parallel/parallel_apply.py", line 22, in _worker
    var_input = var_input[0]
KeyError: 0

Exception in thread Thread-1:
Traceback (most recent call last):
  File "/usr/lib/python3.5/threading.py", line 914, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.5/threading.py", line 862, in run
    self._target(*self._args, **self._kwargs)
  File "venv/lib/python3.5/site-packages/torch/nn/parallel/parallel_apply.py", line 22, in _worker
    var_input = var_input[0]
KeyError: 0

Traceback (most recent call last):
  File "bug.py", line 16, in <module>
    outputs = net(inputs)
  File "venv/lib/python3.5/site-packages/torch/nn/modules/module.py", line 206, in __call__
    result = self.forward(*input, **kwargs)
  File "venv/lib/python3.5/site-packages/torch/nn/parallel/data_parallel.py", line 61, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "venv/lib/python3.5/site-packages/torch/nn/parallel/data_parallel.py", line 71, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs)
  File "venv/lib/python3.5/site-packages/torch/nn/parallel/parallel_apply.py", line 44, in parallel_apply
    output = results[i]
KeyError: 0
"""
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment