Allow not using all function outputs in autograd
diff --git a/setup.py b/setup.py
index bab5881..5a537f0 100644
--- a/setup.py
+++ b/setup.py
@@ -194,6 +194,7 @@
         "torch/csrc/cuda/Storage.cpp",
         "torch/csrc/cuda/Stream.cpp",
         "torch/csrc/cuda/Tensor.cpp",
+        "torch/csrc/cuda/AutoGPU.cpp",
         "torch/csrc/cuda/utils.cpp",
         "torch/csrc/cuda/serialization.cpp",
     ]
diff --git a/test/test_autograd.py b/test/test_autograd.py
index d2b04c9..b837cac 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -253,6 +253,34 @@
         self._test_setitem_tensor((5, 5), Variable(mask))
         self._test_setitem_tensor((5,), Variable(mask[0]))
 
+    def test_unused_output(self):
+        x = Variable(torch.randn(10, 10), requires_grad=True)
+        outputs = x.chunk(5)
+        o = outputs[2]
+        o = o * 4 + 2
+        o.sum().backward()
+        expected_grad = torch.zeros(10, 10)
+        expected_grad[4:6] = 4
+        self.assertEqual(x.grad, expected_grad)
+
+        x.grad.zero_()
+        grad_output = torch.randn(2, 10)
+        outputs = x.chunk(5)
+        outputs[0].backward(grad_output)
+        expected_grad = torch.zeros(10, 10)
+        expected_grad[:2] = grad_output
+        self.assertEqual(x.grad, expected_grad)
+
+    @unittest.skipIf(not torch.cuda.is_available() or torch.cuda.device_count() < 2,
+            "CUDA not available or <2 GPUs detected")
+    def test_unused_output_gpu(self):
+        from torch.nn.parallel.functions import Broadcast
+        x = Variable(torch.randn(5, 5).float().cuda(), requires_grad=True)
+        outputs = Broadcast(list(range(torch.cuda.device_count())))(x)
+        y = outputs[-1] * 2
+        y.sum().backward()
+        self.assertEqual(x.grad, torch.ones(5, 5) * 2)
+
     def test_type_conversions(self):
         import torch.cuda
         x = Variable(torch.randn(5, 5))
diff --git a/torch/autograd/engine.py b/torch/autograd/engine.py
index cb9fdae..5b74818 100644
--- a/torch/autograd/engine.py
+++ b/torch/autograd/engine.py
@@ -44,7 +44,9 @@
             variable._do_backward((grad,), retain_variables)
             return
 
-        ready = deque([(variable.creator, (grad,))])
+        initial_grad = [None for _ in range(variable.creator.num_outputs)]
+        initial_grad[variable.output_nr] = grad
+        ready = deque([(variable.creator, initial_grad)])
         not_ready = {}
         need_copy = set()
 
diff --git a/torch/csrc/THP.h b/torch/csrc/THP.h
index 40a5e0b..302bd1b 100644
--- a/torch/csrc/THP.h
+++ b/torch/csrc/THP.h
@@ -25,6 +25,7 @@
 #include "Tensor.h"
 #include "Size.h"
 #include "Module.h"
+#include "Types.h"
 #include "utils.h" // This requires defined Storage and Tensor types
 #include "byte_order.h"
 
diff --git a/torch/csrc/Types.h b/torch/csrc/Types.h
new file mode 100644
index 0000000..03cd78f
--- /dev/null
+++ b/torch/csrc/Types.h
@@ -0,0 +1,40 @@
+#ifndef THP_TYPES_INC
+#define THP_TYPES_INC
+
+#include <Python.h>
+#include <cstddef>
+
+namespace torch {
+
+typedef struct THVoidStorage
+{
+  void *data;
+  ptrdiff_t size;
+  int refcount;
+  char flag;
+  void *allocator;
+  void *allocatorContext;
+  THVoidStorage *view;
+} THVoidStorage;
+
+typedef struct THVoidTensor
+{
+   long *size;
+   long *stride;
+   int nDimension;
+   THVoidStorage *storage;
+   ptrdiff_t storageOffset;
+   int refcount;
+   char flag;
+} THVoidTensor;
+
+struct THPVoidTensor {
+  PyObject_HEAD
+  THVoidTensor *cdata;
+  char device_type;
+  char data_type;
+};
+
+}  // namespace torch
+
+#endif
diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp
index 88ad966..08a0c7e 100644
--- a/torch/csrc/autograd/engine.cpp
+++ b/torch/csrc/autograd/engine.cpp
@@ -122,9 +122,9 @@
   buffer_set_type need_copy;
 
   // Initialize the queue
-  grad_buffer_type buf(next_buf_id++, 1);
+  grad_buffer_type buf(next_buf_id++, ((THPFunction*)variable->creator)->num_outputs);
   Py_INCREF(grad_variable);
-  buf[0] = grad_variable;
+  buf[variable->output_nr] = grad_variable;
   ready.emplace_front((THPFunction*)variable->creator, std::move(buf));
 
   dependencies_type dependencies = THPEngine_compute_dependencies((THPFunction*)variable->creator);
@@ -139,10 +139,14 @@
     THPObjectPtr grad_tuple = PyTuple_New(fn_grad_buffer.size());
     if (!grad_tuple) return NULL;
     for (unsigned int i = 0; i < fn_grad_buffer.size(); i++) {
-      // TODO: allocate correctly sized zero buffer
-      THPUtils_assert(fn_grad_buffer[i], "error no grad buffer - this will be "
-              "fixed in upcoming releases!");
-      PyTuple_SET_ITEM(grad_tuple.get(), i, fn_grad_buffer[i].release());
+      PyObject *_grad;
+      if (fn_grad_buffer[i]) {
+        _grad = fn_grad_buffer[i].release();
+      } else {
+        _grad = Py_None;
+        Py_INCREF(_grad);
+      }
+      PyTuple_SET_ITEM(grad_tuple.get(), i, _grad);
     }
 
     // Call _do_backward and make sure grad_input is sound
@@ -187,7 +191,7 @@
         grad_buffer_type prev_buffer(-1);
         if (not_ready_it == not_ready.end()) {
           // The function is ready and no buffers have been allocated for it.
-          prev_buffer = grad_buffer_type(next_buf_id++, 1);
+          prev_buffer = grad_buffer_type(next_buf_id++, prev_fn->num_outputs);
           Py_INCREF(grad_prev);
           prev_buffer[output_idx] = grad_prev;
         } else {
diff --git a/torch/csrc/autograd/function.cpp b/torch/csrc/autograd/function.cpp
index 2c12ca3..ee0cada 100644
--- a/torch/csrc/autograd/function.cpp
+++ b/torch/csrc/autograd/function.cpp
@@ -5,6 +5,10 @@
 
 #include "THP.h"
 
+#ifdef WITH_CUDA
+#include "cuda/AutoGPU.h"
+#endif
+
 PyObject *THPFunctionClass = NULL;
 
 static void THPFunction_dealloc(THPFunction* self)
@@ -24,6 +28,7 @@
   THPFunctionPtr *previous_functions = self->previous_functions;
   self->previous_functions = NULL;
   delete[] previous_functions;
+  delete self->output_info;
 
   Py_TYPE(self)->tp_free((PyObject*)self);
 }
@@ -71,17 +76,8 @@
   THPFunction *self = (THPFunction*)type->tp_alloc(type, 0);
   if (!self)
     return NULL;
-  self->previous_functions = NULL;
-  self->needs_input_grad = NULL;
-  self->saved_variables = NULL;
-  self->backward_hooks = NULL;
-  self->to_save = NULL;
-  self->shared_pairs = NULL;
-  self->non_differentiable = NULL;
-  self->dirty_tensors = NULL;
-  self->needs_input_grad = 0;
-  self->has_freed_buffers = 0;
-  self->num_inputs = 0;
+  // Python zero-initializes the object memory, so there's no need to initialize
+  // most fields
   self->num_outputs = -1;
   return (PyObject*)self;
 }
@@ -123,6 +119,8 @@
 {
   // Wrap outputs in Variables
   Py_ssize_t num_outputs = PyTuple_GET_SIZE(raw_output);
+  self->output_info = new std::vector<output_info_type>(num_outputs);
+  auto &output_info = *self->output_info;
   for (int i = 0; i < num_outputs; i++) {
     PyObject *output = PyTuple_GET_ITEM(raw_output, i);
     THPVariable *output_var;
@@ -158,6 +156,23 @@
     if (!output_var)
       return false;
 
+    torch::THPVoidTensor *output_obj = (torch::THPVoidTensor*)output_var->data;
+    torch::THVoidTensor *output_tensor = output_obj->cdata;
+    long ndim = output_tensor->nDimension;
+    int device_id = -1;
+    THPObjectPtr is_cuda = PyObject_GetAttrString(output_var->data, "is_cuda");
+    if (is_cuda.get() == Py_True) {
+      THPObjectPtr device_id_obj = PyObject_CallMethod(output_var->data,
+          "get_device", "");
+      THPUtils_assertRet(false, THPUtils_checkLong(device_id_obj), "get_device "
+          "should return an int, but got %s", THPUtils_typename(device_id_obj));
+      device_id = THPUtils_unpackLong(device_id_obj);
+    }
+    output_info[i] = std::make_tuple(
+      (PyObject*)Py_TYPE(output_var->data),
+      device_id,
+      std::vector<long>(output_tensor->size, output_tensor->size + ndim)
+    );
     t2var[output] = output_var;
     output_var->output_nr = i;
     PyTuple_SET_ITEM(outputs, i, (PyObject*)output_var);
@@ -371,7 +386,6 @@
     // Mark non-differentiable outputs as not requiring gradients
   }
 
-
   if (num_outputs == 1) {
     PyObject *output = PyTuple_GET_ITEM(outputs.get(), 0);
     Py_INCREF(output);
@@ -385,17 +399,43 @@
 {
   Py_ssize_t num_args = args ? PyTuple_GET_SIZE(args) : 0;
   THPUtils_assert(num_args == 2, "_do_backward expects exactly two arguments");
-  PyObject *grad_output = PyTuple_GET_ITEM(args, 0);
+  PyObject *raw_grad_output = PyTuple_GET_ITEM(args, 0);
   PyObject *retain_variables = PyTuple_GET_ITEM(args, 1);
-  if (!PyTuple_Check(grad_output) || !PyBool_Check(retain_variables)) {
+  if (!PyTuple_Check(raw_grad_output) || !PyBool_Check(retain_variables)) {
     THPUtils_invalidArguments(args, "_do_backward", 1, "(tuple, bool)");
     return NULL;
   }
 
+  int num_grad_output = PyTuple_GET_SIZE(raw_grad_output);
+  THPObjectPtr grad_output = PyTuple_New(num_grad_output);
+  if (!grad_output) return NULL;
+#ifdef WITH_CUDA
+  THCPAutoGPU gpu_guard(-1);
+#endif
+  for (int i = 0; i < num_grad_output; i++) {
+    PyObject *grad = PyTuple_GET_ITEM(raw_grad_output, i);
+    // If there's no gradient we have to allocate a buffer ourselves
+    if (grad == Py_None) {
+      auto &info = (*self->output_info)[i]; PyObject *tensor_cls = std::get<0>(info);
+#ifdef WITH_CUDA
+      gpu_guard.setDevice(std::get<1>(info));
+#endif
+      std::vector<long> &sizes = std::get<2>(info); THPObjectPtr grad_size = THPSize_New(sizes.size(), sizes.data());
+      THPObjectPtr new_grad = PyObject_CallFunctionObjArgs(tensor_cls, grad_size.get(), NULL);
+      if (!new_grad) return NULL;
+      THPObjectPtr result = PyObject_CallMethod(new_grad.get(), "zero_", "");
+      if (!result) return NULL;
+      grad = new_grad.release();
+    } else {
+      Py_INCREF(grad);
+    }
+    PyTuple_SET_ITEM(grad_output.get(), i, grad);
+  }
+
   THPObjectPtr backward_fn = PyObject_GetAttrString((PyObject*)self, "backward");
   THPUtils_assert(backward_fn.get(), "function %s doesn't implement a required "
       "'backward' method", THPUtils_typename((PyObject*)self));
-  THPObjectPtr grad_input = PyObject_CallObject(backward_fn, grad_output);
+  THPObjectPtr grad_input = PyObject_CallObject(backward_fn, grad_output.get());
   if (!grad_input)
     return NULL;
 
@@ -421,7 +461,7 @@
         "attribute has to be a dictionary");
     while (PyDict_Next(self->backward_hooks, &pos, &key, &value)) {
       THPObjectPtr result = PyObject_CallFunctionObjArgs(value,
-          grad_input.get(), grad_output, NULL);
+          grad_input.get(), grad_output.get(), NULL);
       if (!result)
         return NULL;
     }
diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h
index f6488f1..6f8fb4f 100644
--- a/torch/csrc/autograd/function.h
+++ b/torch/csrc/autograd/function.h
@@ -24,6 +24,9 @@
     int output_nr;
 };
 
+// (class, gpu id, sizes)
+using output_info_type = std::tuple<PyObject *, int, std::vector<long>>;
+
 struct THPFunction {
     PyObject_HEAD
 
@@ -37,6 +40,7 @@
     PyObject *dirty_tensors;
 
     THPFunctionPtr *previous_functions;
+    std::vector<output_info_type> *output_info;
     int num_inputs;
     int num_outputs;
     char requires_grad;
diff --git a/torch/csrc/cuda/AutoGPU.cpp b/torch/csrc/cuda/AutoGPU.cpp
new file mode 100644
index 0000000..6dc93b2
--- /dev/null
+++ b/torch/csrc/cuda/AutoGPU.cpp
@@ -0,0 +1,61 @@
+#include "AutoGPU.h"
+
+#include "THCP.h"
+#include <THC/THC.h>
+
+THCPAutoGPU::THCPAutoGPU(int device_id) {
+  setDevice(device_id);
+}
+
+THCPAutoGPU::THCPAutoGPU(PyObject *args, PyObject *self) {
+  if (self && setObjDevice(self))
+    return;
+
+  if (!args)
+    return;
+  for (int i = 0; i < PyTuple_Size(args); i++) {
+    PyObject *arg = PyTuple_GET_ITEM(args, i);
+    if (setObjDevice(arg)) return;
+  }
+}
+
+bool THCPAutoGPU::setObjDevice(PyObject *obj) {
+  int new_device = -1;
+  PyObject *obj_type = (PyObject*)Py_TYPE(obj);
+  if (obj_type == THCPDoubleTensorClass) {
+    new_device = THCudaDoubleTensor_getDevice(LIBRARY_STATE ((THCPDoubleTensor*)obj)->cdata);
+  } else if (obj_type == THCPFloatTensorClass) {
+    new_device = THCudaTensor_getDevice(LIBRARY_STATE ((THCPFloatTensor*)obj)->cdata);
+  } else if (obj_type == THCPHalfTensorClass) {
+    new_device = THCudaHalfTensor_getDevice(LIBRARY_STATE ((THCPHalfTensor*)obj)->cdata);
+  } else if (obj_type == THCPLongTensorClass) {
+    new_device = THCudaLongTensor_getDevice(LIBRARY_STATE ((THCPLongTensor*)obj)->cdata);
+  } else if (obj_type == THCPIntTensorClass) {
+    new_device = THCudaIntTensor_getDevice(LIBRARY_STATE ((THCPIntTensor*)obj)->cdata);
+  } else if (obj_type == THCPShortTensorClass) {
+    new_device = THCudaShortTensor_getDevice(LIBRARY_STATE ((THCPShortTensor*)obj)->cdata);
+  } else if (obj_type == THCPCharTensorClass) {
+    new_device = THCudaCharTensor_getDevice(LIBRARY_STATE ((THCPCharTensor*)obj)->cdata);
+  } else if (obj_type == THCPByteTensorClass) {
+    new_device = THCudaByteTensor_getDevice(LIBRARY_STATE ((THCPByteTensor*)obj)->cdata);
+  }
+  return setDevice(new_device);
+}
+
+bool THCPAutoGPU::setDevice(int new_device) {
+  if (new_device == -1)
+    return false;
+
+  if (device == -1)
+    THCudaCheck(cudaGetDevice(&device));
+  // 'device' caches only the original device (for the destructor to restore),
+  THCPModule_setDevice(new_device);
+  return true;
+}
+
+// This can throw... But if it does I have no idea how to recover.
+THCPAutoGPU::~THCPAutoGPU() {
+  if (device != -1)
+    THCPModule_setDevice(device);
+}
+
diff --git a/torch/csrc/cuda/AutoGPU.h b/torch/csrc/cuda/AutoGPU.h
new file mode 100644
index 0000000..fa2a04f
--- /dev/null
+++ b/torch/csrc/cuda/AutoGPU.h
@@ -0,0 +1,16 @@
+#ifndef THCP_AUTOGPU_INC
+#define THCP_AUTOGPU_INC
+
+#include <Python.h>
+
+class THCPAutoGPU {
+public:
+  THCPAutoGPU(int device_id=-1);
+  THCPAutoGPU(PyObject *args, PyObject *self=NULL);
+  ~THCPAutoGPU();
+  bool setObjDevice(PyObject *obj);
+  bool setDevice(int new_device);
+  int device = -1;
+};
+
+#endif
diff --git a/torch/csrc/cuda/THCP.h b/torch/csrc/cuda/THCP.h
index 712dafa..e7f8689 100644
--- a/torch/csrc/cuda/THCP.h
+++ b/torch/csrc/cuda/THCP.h
@@ -6,6 +6,7 @@
 
 #include "torch/csrc/THP.h"
 #include "serialization.h"
+#include "AutoGPU.h"
 #include "Module.h"
 #include "Storage.h"
 #include "Tensor.h"
diff --git a/torch/csrc/cuda/Tensor.cpp b/torch/csrc/cuda/Tensor.cpp
index 69210d3..88a0e09 100644
--- a/torch/csrc/cuda/Tensor.cpp
+++ b/torch/csrc/cuda/Tensor.cpp
@@ -10,61 +10,6 @@
 
 #include "override_macros.h"
 
-THCPAutoGPU::THCPAutoGPU(int device_id) {
-  setDevice(device_id);
-}
-
-THCPAutoGPU::THCPAutoGPU(PyObject *args, PyObject *self) {
-  if (self && setObjDevice(self))
-    return;
-
-  if (!args)
-    return;
-  for (int i = 0; i < PyTuple_Size(args); i++) {
-    PyObject *arg = PyTuple_GET_ITEM(args, i);
-    if (setObjDevice(arg)) return;
-  }
-}
-
-bool THCPAutoGPU::setObjDevice(PyObject *obj) {
-  int new_device = -1;
-  PyObject *obj_type = (PyObject*)Py_TYPE(obj);
-  if (obj_type == THCPDoubleTensorClass) {
-    new_device = THCudaDoubleTensor_getDevice(LIBRARY_STATE ((THCPDoubleTensor*)obj)->cdata);
-  } else if (obj_type == THCPFloatTensorClass) {
-    new_device = THCudaTensor_getDevice(LIBRARY_STATE ((THCPFloatTensor*)obj)->cdata);
-  } else if (obj_type == THCPHalfTensorClass) {
-    new_device = THCudaHalfTensor_getDevice(LIBRARY_STATE ((THCPHalfTensor*)obj)->cdata);
-  } else if (obj_type == THCPLongTensorClass) {
-    new_device = THCudaLongTensor_getDevice(LIBRARY_STATE ((THCPLongTensor*)obj)->cdata);
-  } else if (obj_type == THCPIntTensorClass) {
-    new_device = THCudaIntTensor_getDevice(LIBRARY_STATE ((THCPIntTensor*)obj)->cdata);
-  } else if (obj_type == THCPShortTensorClass) {
-    new_device = THCudaShortTensor_getDevice(LIBRARY_STATE ((THCPShortTensor*)obj)->cdata);
-  } else if (obj_type == THCPCharTensorClass) {
-    new_device = THCudaCharTensor_getDevice(LIBRARY_STATE ((THCPCharTensor*)obj)->cdata);
-  } else if (obj_type == THCPByteTensorClass) {
-    new_device = THCudaByteTensor_getDevice(LIBRARY_STATE ((THCPByteTensor*)obj)->cdata);
-  }
-  return setDevice(new_device);
-}
-
-bool THCPAutoGPU::setDevice(int new_device) {
-  if (new_device == -1)
-    return false;
-
-  if (device == -1)
-    THCudaCheck(cudaGetDevice(&device));
-  THCPModule_setDevice(new_device);
-  return true;
-}
-
-// This can throw... But if it does I have no idea how to recover.
-THCPAutoGPU::~THCPAutoGPU() {
-  if (device != -1)
-    THCPModule_setDevice(device);
-}
-
 #define THC_GENERIC_FILE "torch/csrc/generic/Tensor.cpp"
 #include <THC/THCGenerateAllTypes.h>
 
diff --git a/torch/csrc/cuda/Tensor.h b/torch/csrc/cuda/Tensor.h
index 8600100..7320c09 100644
--- a/torch/csrc/cuda/Tensor.h
+++ b/torch/csrc/cuda/Tensor.h
@@ -1,16 +1,6 @@
 #ifndef THCP_TENSOR_INC
 #define THCP_TENSOR_INC
 
-class THCPAutoGPU {
-public:
-  THCPAutoGPU(int device_id=-1);
-  THCPAutoGPU(PyObject *args, PyObject *self=NULL);
-  ~THCPAutoGPU();
-  bool setObjDevice(PyObject *obj);
-  bool setDevice(int new_device);
-  int device = -1;
-};
-
 #define THCPTensor TH_CONCAT_3(THCP,Real,Tensor)
 #define THCPTensorStr TH_CONCAT_STRING_3(torch.cuda.,Real,Tensor)
 #define THCPTensorClass TH_CONCAT_3(THCP,Real,TensorClass)
diff --git a/torch/csrc/cudnn/Conv.h b/torch/csrc/cudnn/Conv.h
index 3d19064..3339536 100644
--- a/torch/csrc/cudnn/Conv.h
+++ b/torch/csrc/cudnn/Conv.h
@@ -4,7 +4,7 @@
 #include <cudnn.h>
 #include "THC/THC.h"
 
-#include "Types.h"
+#include "../Types.h"
 #include "Descriptors.h"
 
 namespace torch { namespace cudnn {
diff --git a/torch/csrc/cudnn/Types.h b/torch/csrc/cudnn/Types.h
index 1f19e02..3abb4bf 100644
--- a/torch/csrc/cudnn/Types.h
+++ b/torch/csrc/cudnn/Types.h
@@ -7,33 +7,6 @@
 
 namespace torch { namespace cudnn {
 
-typedef struct THVoidStorage
-{
-  void *data;
-  ptrdiff_t size;
-  int refcount;
-  char flag;
-  void *allocator;
-  void *allocatorContext;
-  THVoidStorage *view;
-} THVoidStorage;
-
-typedef struct THVoidTensor
-{
-   long *size;
-   long *stride;
-   int nDimension;
-   THVoidStorage *storage;
-   ptrdiff_t storageOffset;
-   int refcount;
-   char flag;
-} THVoidTensor;
-
-struct THPVoidTensor {
-  PyObject_HEAD
-  THVoidTensor *cdata;
-};
-
 PyObject * getTensorClass(PyObject *args);
 cudnnDataType_t getCudnnDataType(PyObject *tensorClass);
 
diff --git a/torch/csrc/cudnn/cuDNN.cwrap b/torch/csrc/cudnn/cuDNN.cwrap
index fbdccbf..50966f9 100644
--- a/torch/csrc/cudnn/cuDNN.cwrap
+++ b/torch/csrc/cudnn/cuDNN.cwrap
@@ -7,6 +7,7 @@
 
 
 using namespace torch::cudnn;
+using namespace torch;
 
 extern THCState* state;