Support multiple inputs in data parallel
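
DataParallel.forward() and the functional data_parallel() now accept multiple
positional inputs: they are scattered together across device_ids and unpacked
into each replica's forward(), and parallel_apply() correspondingly takes one
argument tuple per replica. A minimal usage sketch of the module path
(PairwiseAdd and the 5x5 shapes are illustrative stand-ins for the TestModule
added in the test below; two visible GPUs are assumed):

```python
import torch
import torch.nn as nn
from torch.autograd import Variable


class PairwiseAdd(nn.Module):
    # Stand-in for the TestModule in the new test: any module whose
    # forward() takes more than one positional argument.
    def forward(self, x, y):
        return x + y


x = Variable(torch.randn(5, 5).float())
y = Variable(torch.randn(5, 5).float())

# Both arguments are scattered across the listed devices and each replica is
# called as forward(x_chunk, y_chunk); the gathered result matches the
# single-device PairwiseAdd()(x, y).
dpm = nn.DataParallel(PairwiseAdd(), device_ids=[0, 1])
out = dpm(x, y)
```
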
diff --git a/test/test_nn.py b/test/test_nn.py
index 3a71bf5..89479ec 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -722,7 +722,7 @@
         i2 = Variable(torch.randn(2, 10).float().cuda(1))
         expected1 = l1(i1).data
         expected2 = l2(i2).data
-        inputs = (i1, i2)
+        inputs = ((i1,), (i2,))
         modules = (l1, l2)
         expected_outputs = (expected1, expected2)
         outputs = dp.parallel_apply(modules, inputs)
@@ -740,6 +740,31 @@
         self.assertFalse(out.is_cuda)
 
     @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
+    def test_data_parallel_multiple_input(self):
+        class TestModule(nn.Module):
+            def forward(self, x, y):
+                return x + y
+
+        m = TestModule()
+        x = Variable(torch.randn(5, 5).float())
+        y = Variable(torch.randn(5, 5).float())
+        expected = m(x, y)
+
+        out = dp.data_parallel(m, (x, y), (0, 1))
+        self.assertEqual(out, expected)
+
+        out = dp.data_parallel(m, (x, y), (0,))
+        self.assertEqual(out, expected)
+
+        dpm = nn.DataParallel(TestModule())
+        out = dpm(x, y)
+        self.assertEqual(out, expected)
+
+        dpm = nn.DataParallel(TestModule(), device_ids=[0])
+        out = dpm(x, y)
+        self.assertEqual(out, expected)
+
+    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_data_parallel_small_back(self):
         l = nn.Linear(10, 5).float().cuda()
         i = Variable(torch.randn(20, 10).float().cuda())
@@ -785,7 +810,7 @@
 
         class Net(nn.Module):
 
-            def forward(self, input):
+            def forward(self, *input):
                 return fn(input)
         i = Variable(torch.randn(20, 3).float().cuda(1))
         input = (i.cos(), (i.sin(), i), i.sin())
diff --git a/torch/nn/parallel/data_parallel.py b/torch/nn/parallel/data_parallel.py
index 94d9d04..97fd8f4 100644
--- a/torch/nn/parallel/data_parallel.py
+++ b/torch/nn/parallel/data_parallel.py
@@ -36,7 +36,7 @@
         if len(self.device_ids) == 1:
             self.module.cuda(device_ids[0])
 
-    def forward(self, input):
+    def forward(self, *inputs):
         def _to_cuda(obj):
             if isinstance(obj, Variable):
                 return obj.cuda()
@@ -44,12 +44,12 @@
 
         if len(self.device_ids) == 1:
             with torch.cuda.device(self.device_ids[0]):
-                inpcuda = _to_cuda(input)
-            return self.module(inpcuda)
+                inputs_cuda = _to_cuda(inputs)
+            return self.module(*inputs_cuda)
         replicas = self.replicate(self.module, self.device_ids)
-        inputs = self.scatter(input, self.device_ids)
-        replicas = replicas[:len(inputs)]
-        outputs = self.parallel_apply(replicas, inputs)
+        scattered = self.scatter(inputs, self.device_ids)
+        replicas = replicas[:len(scattered)]
+        outputs = self.parallel_apply(replicas, scattered)
         return self.gather(outputs, self.output_device)
 
     def replicate(self, module, device_ids):
@@ -65,14 +65,14 @@
         return gather(outputs, output_device)
 
 
-def data_parallel(module, input, device_ids, output_device=None):
+def data_parallel(module, inputs, device_ids, output_device=None):
     """Evaluates module(input) in parallel across the GPUs given in device_ids.
 
     This is the functional version of the DataParallel module.
 
     Args:
         module: the module to evaluate in parallel
-        input: input to the module
+        inputs: inputs to the module (a tuple is unpacked into positional arguments)
         device_ids: GPU ids on which to replicate module
         output_device: GPU location of the output  Use -1 to indicate the CPU.
             (default: device_ids[0])
@@ -80,14 +80,17 @@
-        a Variable containing the result of module(input) located on
+        a Variable containing the result of module(inputs) located on
         output_device
     """
+    if not isinstance(inputs, tuple):
+        inputs = (inputs,)
+
     if not device_ids:
-        return module(input)
+        return module(*inputs)
 
     if output_device is None:
         output_device = device_ids[0]
 
     replicas = replicate(module, device_ids)
-    inputs = scatter(input, device_ids)
-    replicas = replicas[:len(inputs)]
-    outputs = parallel_apply(replicas, inputs)
+    scattered = scatter(inputs, device_ids)
+    replicas = replicas[:len(scattered)]
+    outputs = parallel_apply(replicas, scattered)
     return gather(outputs, output_device)
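
For the functional API, a tuple is now unpacked into positional arguments,
while a single non-tuple input is wrapped into a one-element tuple so existing
call sites keep working. A sketch, assuming two visible GPUs (the import path
follows the file above; PairwiseAdd is an illustrative two-input module, not
part of this patch, and the tensor sizes are arbitrary):

```python
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.parallel.data_parallel import data_parallel


class PairwiseAdd(nn.Module):
    # Illustrative two-input module (not part of this patch).
    def forward(self, x, y):
        return x + y


x = Variable(torch.randn(5, 5).float())
y = Variable(torch.randn(5, 5).float())

# A tuple of inputs is unpacked into positional arguments on each replica.
out = data_parallel(PairwiseAdd(), (x, y), (0, 1))

# A single non-tuple input is wrapped into a one-element tuple internally,
# so the original single-input usage is unchanged.
l = nn.Linear(10, 5).float().cuda()
i = Variable(torch.randn(20, 10).float().cuda())
out = data_parallel(l, i, (0, 1))
```
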
diff --git a/torch/nn/parallel/parallel_apply.py b/torch/nn/parallel/parallel_apply.py
index 32df2b1..ce51be0 100644
--- a/torch/nn/parallel/parallel_apply.py
+++ b/torch/nn/parallel/parallel_apply.py
@@ -12,7 +12,7 @@
     assert len(modules) == len(inputs)
     # Fast track
     if len(modules) == 1:
-        return (modules[0](inputs[0]),)
+        return (modules[0](*inputs[0]),)
 
     lock = threading.Lock()
     results = {}
@@ -23,7 +23,7 @@
             var_input = var_input[0]
         try:
             with torch.cuda.device_of(var_input):
-                output = module(input)
+                output = module(*input)
             with lock:
                 results[input] = output
         except Exception as e:
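
At the lowest level, parallel_apply() now expects each element of inputs to be
an argument tuple that is unpacked into the corresponding replica, as exercised
by the updated test above. A sketch, assuming two GPUs and importing from the
file touched here:

```python
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.parallel.parallel_apply import parallel_apply

# One replica per device, mirroring the updated parallel_apply test.
l1 = nn.Linear(10, 5).float().cuda(0)
l2 = nn.Linear(10, 5).float().cuda(1)
i1 = Variable(torch.randn(2, 10).float().cuda(0))
i2 = Variable(torch.randn(2, 10).float().cuda(1))

# Each element of `inputs` is a tuple of positional arguments; single-element
# tuples reproduce the previous one-input behaviour.
outputs = parallel_apply((l1, l2), ((i1,), (i2,)))
```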