Support multiple inputs in data parallel
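
DataParallel.forward and the functional data_parallel helper now take an
arbitrary number of positional inputs instead of a single one. The inputs are
packed into a tuple before being scattered across devices, and parallel_apply
unpacks each per-device argument tuple with module(*input). A single non-tuple
input passed to data_parallel is still accepted and wrapped automatically.

A minimal usage sketch of the new calling convention (the Add module and the
tensor shapes are illustrative only; two visible GPUs are assumed, and
torch.nn.parallel is aliased as dp as in the tests below):

    import torch
    import torch.nn as nn
    import torch.nn.parallel as dp
    from torch.autograd import Variable

    class Add(nn.Module):
        def forward(self, x, y):
            return x + y

    x = Variable(torch.randn(5, 5).float())
    y = Variable(torch.randn(5, 5).float())

    # functional form: a tuple of inputs is scattered across the devices
    out = dp.data_parallel(Add(), (x, y), (0, 1))

    # module form: extra positional arguments are forwarded to each replica
    out = nn.DataParallel(Add())(x, y)
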
diff --git a/test/test_nn.py b/test/test_nn.py
index 3a71bf5..89479ec 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -722,7 +722,7 @@
         i2 = Variable(torch.randn(2, 10).float().cuda(1))
         expected1 = l1(i1).data
         expected2 = l2(i2).data
-        inputs = (i1, i2)
+        inputs = ((i1,), (i2,))
         modules = (l1, l2)
         expected_outputs = (expected1, expected2)
         outputs = dp.parallel_apply(modules, inputs)
@@ -740,6 +740,31 @@
         self.assertFalse(out.is_cuda)
 
     @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
+    def test_data_parallel_multiple_input(self):
+        class TestModule(nn.Module):
+            def forward(self, x, y):
+                return x + y
+
+        m = TestModule()
+        x = Variable(torch.randn(5, 5).float())
+        y = Variable(torch.randn(5, 5).float())
+        expected = m(x, y)
+
+        out = dp.data_parallel(m, (x, y), (0, 1))
+        self.assertEqual(out, expected)
+
+        out = dp.data_parallel(m, (x, y), (0,))
+        self.assertEqual(out, expected)
+
+        dpm = nn.DataParallel(TestModule())
+        out = dpm(x, y)
+        self.assertEqual(out, expected)
+
+        dpm = nn.DataParallel(TestModule(), device_ids=[0])
+        out = dpm(x, y)
+        self.assertEqual(out, expected)
+
+    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_data_parallel_small_back(self):
         l = nn.Linear(10, 5).float().cuda()
         i = Variable(torch.randn(20, 10).float().cuda())
@@ -785,7 +810,7 @@
 
         class Net(nn.Module):
-            def forward(self, input):
+            def forward(self, *input):
                 return fn(input)
 
         i = Variable(torch.randn(20, 3).float().cuda(1))
         input = (i.cos(), (i.sin(), i), i.sin())
diff --git a/torch/nn/parallel/data_parallel.py b/torch/nn/parallel/data_parallel.py
index 94d9d04..97fd8f4 100644
--- a/torch/nn/parallel/data_parallel.py
+++ b/torch/nn/parallel/data_parallel.py
@@ -36,7 +36,7 @@
         if len(self.device_ids) == 1:
             self.module.cuda(device_ids[0])
 
-    def forward(self, input):
+    def forward(self, *inputs):
         def _to_cuda(obj):
             if isinstance(obj, Variable):
                 return obj.cuda()
@@ -44,12 +44,12 @@
 
         if len(self.device_ids) == 1:
             with torch.cuda.device(self.device_ids[0]):
-                inpcuda = _to_cuda(input)
-                return self.module(inpcuda)
+                inputs_cuda = _to_cuda(inputs)
+                return self.module(*inputs_cuda)
         replicas = self.replicate(self.module, self.device_ids)
-        inputs = self.scatter(input, self.device_ids)
-        replicas = replicas[:len(inputs)]
-        outputs = self.parallel_apply(replicas, inputs)
+        scattered = self.scatter(inputs, self.device_ids)
+        replicas = replicas[:len(scattered)]
+        outputs = self.parallel_apply(replicas, scattered)
         return self.gather(outputs, self.output_device)
 
     def replicate(self, module, device_ids):
@@ -65,14 +65,14 @@
         return gather(outputs, output_device)
 
 
-def data_parallel(module, input, device_ids, output_device=None):
+def data_parallel(module, inputs, device_ids, output_device=None):
     """Evaluates module(input) in parallel across the GPUs given in device_ids.
 
     This is the functional version of the DataParallel module.
 
     Args:
         module: the module to evaluate in parallel
-        input: input to the module
+        inputs: inputs to the module
         device_ids: GPU ids on which to replicate module
         output_device: GPU location of the output Use -1 to indicate the CPU.
            (default: device_ids[0])
@@ -80,14 +80,17 @@
         a Variable containing the result of module(input) located on
         output_device
     """
+    if not isinstance(inputs, tuple):
+        inputs = (inputs,)
+
     if not device_ids:
-        return module(input)
+        return module(*inputs)
 
     if output_device is None:
         output_device = device_ids[0]
 
     replicas = replicate(module, device_ids)
-    inputs = scatter(input, device_ids)
-    replicas = replicas[:len(inputs)]
-    outputs = parallel_apply(replicas, inputs)
+    scattered = scatter(inputs, device_ids)
+    replicas = replicas[:len(scattered)]
+    outputs = parallel_apply(replicas, scattered)
     return gather(outputs, output_device)
diff --git a/torch/nn/parallel/parallel_apply.py b/torch/nn/parallel/parallel_apply.py
index 32df2b1..ce51be0 100644
--- a/torch/nn/parallel/parallel_apply.py
+++ b/torch/nn/parallel/parallel_apply.py
@@ -12,7 +12,7 @@
     assert len(modules) == len(inputs)
     # Fast track
     if len(modules) == 1:
-        return (modules[0](inputs[0]),)
+        return (modules[0](*inputs[0]),)
 
     lock = threading.Lock()
     results = {}
@@ -23,7 +23,7 @@
             var_input = var_input[0]
         try:
             with torch.cuda.device_of(var_input):
-                output = module(input)
+                output = module(*input)
             with lock:
                 results[input] = output
         except Exception as e:
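
Note on the parallel_apply calling convention: each element of inputs is now
the tuple of positional arguments for the corresponding module, unpacked as
module(*input) on that module's device. A minimal sketch, assuming two GPUs
and the same dp alias for torch.nn.parallel used in test_nn.py (the Linear
modules and shapes are illustrative only):

    import torch
    import torch.nn as nn
    import torch.nn.parallel as dp
    from torch.autograd import Variable

    l1 = nn.Linear(10, 5).float().cuda(0)
    l2 = nn.Linear(10, 5).float().cuda(1)
    i1 = Variable(torch.randn(2, 10).float().cuda(0))
    i2 = Variable(torch.randn(2, 10).float().cuda(1))

    # one argument tuple per module; a single-argument call becomes (i,)
    outputs = dp.parallel_apply((l1, l2), ((i1,), (i2,)))
    # outputs[0] lives on GPU 0, outputs[1] on GPU 1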