Record shape and type in autograd to validate gradients (#8168)

The check that every gradient is defined is currently disabled because
TestJit.test_ge_optimized would otherwise trigger it (see the FIXME in
validate_outputs).
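
For reference, a minimal sketch of the user-visible effect (it mirrors the new
test_invalid_gradients test added below; BadGrad is a hypothetical example class,
and the exact error wording comes from validate_outputs in engine.cpp):

    import torch
    from torch.autograd import Function

    # A backward() that returns a gradient with the wrong shape. With this
    # change the engine rejects it instead of silently propagating it.
    class BadGrad(Function):
        @staticmethod
        def forward(ctx, x):
            return x * 2

        @staticmethod
        def backward(ctx, grad_output):
            return torch.randn(10)  # wrong shape for a (5, 5) input

    x = torch.randn(5, 5, requires_grad=True)
    BadGrad.apply(x).sum().backward()
    # RuntimeError: ... returned an invalid gradient at index 0 - expected
    # shape [5, 5] but got [10]
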
diff --git a/aten/src/ATen/ATen.h b/aten/src/ATen/ATen.h
index e41c2d9..8456888 100644
--- a/aten/src/ATen/ATen.h
+++ b/aten/src/ATen/ATen.h
@@ -15,3 +15,4 @@
 #include "ATen/TensorOperators.h"
 #include "ATen/TensorMethods.h"
 #include "ATen/Dispatch.h"
+#include "ATen/DimVector.h"
diff --git a/aten/src/ATen/DimVector.h b/aten/src/ATen/DimVector.h
new file mode 100644
index 0000000..aaa4dc9
--- /dev/null
+++ b/aten/src/ATen/DimVector.h
@@ -0,0 +1,11 @@
+#pragma once
+
+#include "SmallVector.h"
+#include <stdint.h>
+
+namespace at {
+
+/// A container for sizes or strides
+using DimVector = SmallVector<int64_t, 5>;
+
+}
diff --git a/aten/src/ATen/SmallVector.h b/aten/src/ATen/SmallVector.h
index 521e460..3a5926a 100644
--- a/aten/src/ATen/SmallVector.h
+++ b/aten/src/ATen/SmallVector.h
@@ -921,6 +921,12 @@
       SmallVectorImpl<T>::operator=(::std::move(RHS));
   }
 
+  template<typename Container>
+  const SmallVector &operator=(const Container &RHS) {
+    this->assign(RHS.begin(), RHS.end());
+    return *this;
+  }
+
   SmallVector(SmallVectorImpl<T> &&RHS) : SmallVectorImpl<T>(N) {
     if (!RHS.empty())
       SmallVectorImpl<T>::operator=(::std::move(RHS));
diff --git a/test/cpp/api/misc.cpp b/test/cpp/api/misc.cpp
index a494e33..0f4fa33 100644
--- a/test/cpp/api/misc.cpp
+++ b/test/cpp/api/misc.cpp
@@ -71,7 +71,7 @@
   }
   SECTION("custom gradient inputs") {
     z.sum().backward(
-        autograd::make_variable(at::ones(at::CPU(at::kFloat), {1}) * 2));
+        autograd::make_variable(at::ones(at::CPU(at::kFloat), {}) * 2));
     REQUIRE(x.grad().allclose(y * 2));
   }
   // Assume everything else is safe from PyTorch tests.
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 0c19e24..4c26e58 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -75,7 +75,7 @@
         x = torch.randn(5, 5, requires_grad=True)
         y = torch.randn(5, 5, requires_grad=True)
         result = cls.apply(x, 2, y)
-        go = torch.ones(1, requires_grad=True)
+        go = torch.ones((), requires_grad=True)
         result.sum().backward(go, create_graph=True)
 
         self.assertEqual(x.grad.data, y.data + torch.ones(5, 5))
@@ -173,6 +173,23 @@
         MyFunction()(y).sum().backward()
         self.assertEqual(v.grad.data, torch.zeros(shape))
 
+    def test_invalid_gradients(self):
+        class MyFunction(Function):
+            @staticmethod
+            def forward(ctx, x):
+                return x * 2
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                return torch.randn(10, dtype=torch.float)
+
+        with self.assertRaisesRegex(RuntimeError, 'expected shape'):
+            input = torch.randn(5, 5, dtype=torch.float, requires_grad=True)
+            MyFunction.apply(input).sum().backward()
+        with self.assertRaisesRegex(RuntimeError, 'expected type'):
+            input = torch.randn(10, dtype=torch.double, requires_grad=True)
+            MyFunction.apply(input).sum().backward()
+
     def test_accumulate_grad(self):
         grad_output = torch.ones(5, 5)
 
@@ -495,7 +512,6 @@
 
     def test_sparse_backward(self):
         class FixedGradientFunction(Function):
-
             def __init__(self, grad):
                 self.grad = grad
 
@@ -524,15 +540,15 @@
         dense_fn = FixedGradientFunction(dense_grad)
 
         # sparse first
-        x = torch.randn(5, 5, requires_grad=True)
+        x = torch.randn(size, requires_grad=True)
         (sparse_fn1(x) + dense_fn(x) + sparse_fn2(x)).sum().backward()
         self.assertEqual(x.grad, dense_grad + sparse_grad1 + sparse_grad2)
         # dense first
-        x = torch.randn(5, 5, requires_grad=True)
+        x = torch.randn(size, requires_grad=True)
         (dense_fn(x) + sparse_fn1(x) + sparse_fn2(x)).sum().backward()
         self.assertEqual(x.grad, dense_grad + sparse_grad1 + sparse_grad2)
         # sparse only
-        x = torch.randn(5, 5, requires_grad=True)
+        x = torch.randn(size, requires_grad=True)
         (sparse_fn1(x) + sparse_fn2(x)).sum().backward()
         self.assertEqual(x.grad, sparse_grad1 + sparse_grad2)
 
diff --git a/test/test_jit.py b/test/test_jit.py
index d56f5ad..4b6ee94 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -1152,7 +1152,7 @@
         out = func(x, y)
         self.assertEqual(func(x, y), x + y)
 
-        grad = torch.randn(2, 3)
+        grad = torch.randn(2, 3, dtype=torch.float)
         out.backward(grad)
         self.assertEqual(x.grad, grad)
         self.assertEqual(y.grad, grad.sum(dim=0))
diff --git a/tools/autograd/gen_autograd_functions.py b/tools/autograd/gen_autograd_functions.py
index a343050..94ed1f8 100644
--- a/tools/autograd/gen_autograd_functions.py
+++ b/tools/autograd/gen_autograd_functions.py
@@ -18,7 +18,7 @@
 struct ${op} : public ${superclass} {
   using ${superclass}::${superclass};
   variable_list apply(const variable_list& grads) override;
-  std::string name() override { return "${op}"; }
+  std::string name() const override { return "${op}"; }
   void release_variables() override {
     ${release_variables}
   }
diff --git a/tools/autograd/templates/VariableType.cpp b/tools/autograd/templates/VariableType.cpp
index 38fa9c2..3dfb633 100644
--- a/tools/autograd/templates/VariableType.cpp
+++ b/tools/autograd/templates/VariableType.cpp
@@ -359,35 +359,35 @@
 
 static void rebase_history(Variable& var, std::shared_ptr<Function> grad_fn) {
   if (grad_fn && var.defined()) {
-    grad_fn->set_num_inputs(1);
+    grad_fn->add_input_metadata(var.type(), var.sizes());
     var.rebase_history({std::move(grad_fn), 0});
   }
 }
 
 static void rebase_history(ArrayRef<Variable> vars, std::shared_ptr<Function> grad_fn) {
   if (grad_fn) {
-    grad_fn->set_num_inputs(vars.size());
-    uint32_t output_nr = 0;
     for (auto& var : vars) {
       if (var.defined()) {
         // TODO: eliminate const_cast
+        auto output_nr = grad_fn->add_input_metadata(var.type(), var.sizes());
         const_cast<Variable&>(var).rebase_history({grad_fn, output_nr});
+      } else {
+        grad_fn->add_input_metadata(Function::undefined_input());
       }
-      output_nr++;
     }
   }
 }
 
 static void set_history(ArrayRef<Variable> vars, std::shared_ptr<Function> grad_fn) {
   if (grad_fn) {
-    grad_fn->set_num_inputs(vars.size());
-    uint32_t output_nr = 0;
     for (auto& var : vars) {
       if (var.defined()) {
         // TODO: eliminate const_cast
+        auto output_nr = grad_fn->add_input_metadata(var.type(), var.sizes());
         const_cast<Variable&>(var).set_gradient_edge({grad_fn, output_nr});
+      } else {
+        grad_fn->add_input_metadata(Function::undefined_input());
       }
-      output_nr++;
     }
   }
 }
@@ -428,7 +428,6 @@
   if (requires_grad) {
     grad_fn = std::make_shared<CopyBackwards>();
     grad_fn->set_next_edges(collect_next_edges(self, src));
-    grad_fn->set_num_inputs(1);
     grad_fn->src_type = &src.type();
     grad_fn->src_device = src.is_cuda() ? src.get_device() : -1;
   }
diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp
index 8f4d723..88aec15 100644
--- a/torch/csrc/autograd/engine.cpp
+++ b/torch/csrc/autograd/engine.cpp
@@ -288,6 +288,49 @@
   return outputs;
 }
 
+static bool is_compatible_type(const at::Type& expected, const at::Type& actual) {
+  // Types are compatible if they exactly match or if the gradient is a sparse
+  // version of the expected type.
+  return expected == actual || (actual.is_sparse() &&
+      expected == actual.toBackend(toDense(actual.backend())));
+}
+
+template<typename F>
+static void validate_outputs(const edge_list& edges, const variable_list& grads, const F& format_error) {
+  if (grads.size() != edges.size()) {
+    std::stringstream ss;
+    ss << "invalid number of gradients - expected ";
+    ss << edges.size() << ", but got " << grads.size();
+    throw std::runtime_error(format_error(ss.str()));
+  }
+  for (size_t i = 0; i < grads.size(); i++) {
+    const auto& edge = edges[i];
+    if (!edge.is_valid()) continue;
+
+    const auto& metadata = edge.function->input_metadata(edge.input_nr);
+    const auto& output = grads[i];
+    if (!output.defined()) {
+      // FIXME: TestJit.test_ge_optimized fails this assertion.
+      // std::stringstream ss;
+      // ss << "undefined gradient at index " << i;
+      // throw std::runtime_error(format_error(ss.str()));
+      continue;
+    }
+    if (!grads[i].sizes().equals(metadata.shape())) {
+      std::stringstream ss;
+      ss << "invalid gradient at index " << i << " - expected shape ";
+      ss << metadata.shape() << " but got " << grads[i].sizes();
+      throw std::runtime_error(format_error(ss.str()));
+    }
+    if (!is_compatible_type(metadata.type(), grads[i].type())) {
+      std::stringstream ss;
+      ss << "invalid gradient at index " << i << " - expected type ";
+      ss << metadata.type() << " but got " << grads[i].type();
+      throw std::runtime_error(format_error(ss.str()));
+    }
+  }
+}
+
 static variable_list call_function(FunctionTask& task) {
   bool prev_checkpoint_valid_state = checkpoint_valid;
   checkpoint_valid = task.base->can_checkpoint() && prev_checkpoint_valid_state;
@@ -298,6 +341,11 @@
     fn.will_release_variables();
   }
   auto outputs = fn(inputs);
+  validate_outputs(fn.next_edges(), outputs, [&](const std::string& msg) {
+    std::ostringstream ss;
+    ss << "Function "  << fn.name() << " returned an " << msg;
+    return ss.str();
+  });
   checkpoint_valid = prev_checkpoint_valid_state;
   return call_post_hooks(fn, std::move(outputs), std::move(inputs));
 }
@@ -323,13 +371,6 @@
     fn.release_variables();
   }
 
-  if (outputs.size() != fn.num_outputs()) {
-    std::stringstream ss;
-    ss << "Function '" << fn.name() << "' returned an invalid number of outputs - expected ";
-    ss << fn.num_outputs() << ", but got " << outputs.size();
-    throw std::runtime_error(ss.str());
-  }
-
   int num_outputs = outputs.size();
   if (num_outputs == 0) return; // Don't even acquire the mutex
   std::lock_guard<std::mutex> lock(task.base->mutex);
@@ -426,6 +467,11 @@
                      bool create_graph,
                      const edge_list& outputs) -> variable_list {
   std::call_once(start_threads_flag, &Engine::start_threads, this);
+
+  validate_outputs(input_roots, inputs, [](const std::string& msg) {
+    return msg;
+  });
+
   // Callbacks are only valid for the duration of this run and should always be cleared
   ClearCallbacks _cb_guard(final_callbacks, post_callbacks_lock);
 
diff --git a/torch/csrc/autograd/function.cpp b/torch/csrc/autograd/function.cpp
index 47b4e06..d116069 100644
--- a/torch/csrc/autograd/function.cpp
+++ b/torch/csrc/autograd/function.cpp
@@ -19,8 +19,8 @@
 
 thread_local uint64_t Function::next_sequence_nr_ = 0;
 
-auto Function::name() -> std::string {
-  return std::string(typeid(*this).name());
+auto Function::name() const -> std::string {
+  return at::demangle(typeid(*this).name());
 }
 
 // This function is analogous to make_trace which operates on PythonOp, but this
diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h
index 77e9b21..d76296a 100644
--- a/torch/csrc/autograd/function.h
+++ b/torch/csrc/autograd/function.h
@@ -5,6 +5,7 @@
 #include "torch/csrc/autograd/grad_mode.h"
 #include "torch/csrc/autograd/profiler.h"
 #include "torch/csrc/autograd/saved_variable.h"
+#include "torch/csrc/autograd/type_and_shape.h"
 #include "torch/csrc/autograd/variable.h"
 #include "torch/csrc/jit/tracer.h"
 #include "torch/csrc/utils/auto_unique_ptr.h"
@@ -91,18 +92,15 @@
   /// in the backward() pass, with higher sequence numbers prioritized
   /// before lower sequence numbers.
   explicit Function(
-      uint32_t num_inputs,
       uint64_t sequence_nr,
       edge_list&& next_edges = edge_list())
       : sequence_nr_(sequence_nr),
-      num_inputs_(num_inputs),
       next_edges_(std::move(next_edges)) {}
 
   explicit Function(
-      uint32_t num_inputs = 0,
       edge_list&& next_edges = edge_list())
-      : Function(num_inputs, next_sequence_nr_++, std::move(next_edges)) {}
-  
+      : Function(next_sequence_nr_++, std::move(next_edges)) {}
+
   /// Functions are neither copyable nor moveable.
   Function(const Function& other) = delete;
   Function(Function&& other) = delete;
@@ -123,20 +121,37 @@
   // Graph Connectivity API
   //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-  // Inputs
+  // Inputs. NOTE: inputs of the grad_fn correspond to Tensor outputs of the
+  // forward function.
 
-  /// Increments the number of inputs of the function and returns the previous
-  /// value.
-  uint32_t bump_inputs() noexcept {
-    return num_inputs_++;
+  // Marker for expected undefined input
+  struct undefined_input {};
+
+  /// Adds the type and shape metadata for a new input. Returns the index
+  /// of the new input.
+  uint32_t add_input_metadata(const at::Type& type, at::IntList shape) noexcept {
+    uint32_t input_nr = input_metadata_.size();
+    input_metadata_.emplace_back(type, shape);
+    return input_nr;
   }
 
-  void set_num_inputs(uint32_t num_inputs) noexcept {
-    num_inputs_ = num_inputs;
+  /// Adds a placeholder for an input that will not be used.
+  uint32_t add_input_metadata(undefined_input u) noexcept {
+    uint32_t input_nr = input_metadata_.size();
+    input_metadata_.emplace_back();
+    return input_nr;
   }
 
   uint32_t num_inputs() const noexcept {
-    return num_inputs_;
+    return input_metadata_.size();
+  }
+
+  const TypeAndShape& input_metadata(size_t index) const {
+    return input_metadata_[index];
+  }
+
+  void clear_input_metadata() {
+    input_metadata_.clear();
   }
 
   // Outputs ("Next Edges")
@@ -185,7 +200,7 @@
   }
 
   /// Returns the name of the dynamic type of the function, for debugging.
-  virtual std::string name();
+  virtual std::string name() const;
 
   /// Returns true if the particular output edge is active, and that particular
   /// output of this function should be computed.
@@ -312,12 +327,12 @@
   // fields.
   const uint64_t sequence_nr_;
 
-  uint32_t num_inputs_;
   edge_list next_edges_;
   PyObject* pyobj_ = nullptr; // weak reference
   std::vector<std::unique_ptr<FunctionPreHook>> pre_hooks_;
   std::vector<std::unique_ptr<FunctionPostHook>> post_hooks_;
   auto_unique_ptr<jit::tracer::FunctionTracingState> tracing_state_;
+  at::SmallVector<TypeAndShape, 2> input_metadata_;
 };
 
 /// See Function::is_traceable() for definition.
@@ -355,13 +370,14 @@
 /// `input_nr` thus equal to `function->num_inputs()`. Additionally, it
 /// increments the `Function`'s number of inputs by one. Approximately
 /// equivalent to `variable.set_gradient_edge(function,
-/// function->bump_inputs())`. If you don't want the `Function`'s `num_inputs`
-/// to be incremented, use `set_gradient_edge` directly.
+/// function->add_input_metadata(variable.type(), variable.sizes()))`.
+/// If you don't want the `Function`'s `num_inputs` to be incremented, use
+/// `set_gradient_edge` directly.
 inline void create_gradient_edge(
     Variable& variable,
     std::shared_ptr<Function> function) {
   // Copy before move.
-  const auto input_nr = function->bump_inputs();
+  const auto input_nr = function->add_input_metadata(variable.type(), variable.sizes());
   variable.set_gradient_edge({std::move(function), input_nr});
 }
 
diff --git a/torch/csrc/autograd/functions/accumulate_grad.cpp b/torch/csrc/autograd/functions/accumulate_grad.cpp
index 31fc243..fdbe54d 100644
--- a/torch/csrc/autograd/functions/accumulate_grad.cpp
+++ b/torch/csrc/autograd/functions/accumulate_grad.cpp
@@ -13,12 +13,13 @@
 
 namespace torch { namespace autograd {
 
-// AccumulateGrad sets sequence_nr to the max value so it's always called 
+// AccumulateGrad sets sequence_nr to the max value so it's always called
 // ASAP during backwards.
 AccumulateGrad::AccumulateGrad(Variable variable_)
-    : Function(/*num_inputs=*/1
-              , /*sequence_nr=*/UINT64_MAX)
-              , variable(std::move(variable_)) {}
+    : Function(/*sequence_nr=*/UINT64_MAX)
+    , variable(std::move(variable_)) {
+  add_input_metadata(variable.type(), variable.sizes());
+}
 
 auto AccumulateGrad::apply(const variable_list& grads) -> variable_list {
   // XXX: this method is not thread-safe!
diff --git a/torch/csrc/autograd/functions/basic_ops.h b/torch/csrc/autograd/functions/basic_ops.h
index a630c3c..7ac4d2a 100644
--- a/torch/csrc/autograd/functions/basic_ops.h
+++ b/torch/csrc/autograd/functions/basic_ops.h
@@ -12,7 +12,7 @@
 
 struct Error : public Function {
   Error(std::string msg, edge_list&& next_edges)
-    : Function(/*num_inputs=*/0, std::move(next_edges))
+    : Function(std::move(next_edges))
     , msg(std::move(msg)) {}
 
   Error(std::string msg)
@@ -35,7 +35,7 @@
 
 struct GraphRoot : public Function {
   GraphRoot(edge_list functions, variable_list inputs)
-      : Function(/*num_inputs=*/0, std::move(functions)),
+      : Function(std::move(functions)),
         outputs(std::move(inputs)) {}
 
   virtual variable_list apply(const variable_list& inputs) {
diff --git a/torch/csrc/autograd/functions/special.cpp b/torch/csrc/autograd/functions/special.cpp
index a5c5a62..88ac969 100644
--- a/torch/csrc/autograd/functions/special.cpp
+++ b/torch/csrc/autograd/functions/special.cpp
@@ -17,7 +17,9 @@
 // Used when an output has multiple uses (there's only one entry
 // in next_edges per output).
 struct Replicate : public Function {
-  Replicate() : Function(/*num_inputs=*/1) {}
+  Replicate(const at::Type& type, at::IntList shape) : Function() {
+    add_input_metadata(type, shape);
+  }
 
   virtual variable_list apply(const variable_list& inputs) {
 		TORCH_ASSERT(inputs.size() == 1);
@@ -236,6 +238,7 @@
     // This detaches the subgraph from the full backward graph.
     for (auto& begin : subgraph.boundary.begins) {
       const auto& edge = begin.function->next_edge(begin.input_nr);
+
       begin.function->set_next_edge(
           begin.input_nr, Edge(ends_to_outputs.at(edge), 0));
     }
@@ -265,7 +268,7 @@
     // the same Variable has been returned multiple times, and
     // is repeated in this list.
     if (output.grad_fn_unsafe() == this) {
-      auto replicate = std::make_shared<Replicate>();
+      auto replicate = std::make_shared<Replicate>(output.type(), output.sizes());
       replicate->add_next_edge({this_shared, output.output_nr()});
       output.set_gradient_edge({std::move(replicate), 0});
       repeated_outputs.emplace(&output);
@@ -274,7 +277,8 @@
     // perform any allocations until we actually see repeated outputs.
     if (repeated_outputs.count(&output) > 0) {
       auto & replicate = output.grad_fn();
-      replicate->add_next_edge({this_shared, num_inputs_++});
+      auto input_nr = add_input_metadata(output.type(), output.sizes());
+      replicate->add_next_edge({this_shared, input_nr});
     } else {
       autograd::create_gradient_edge(output, this_shared);
     }
diff --git a/torch/csrc/autograd/functions/special.h b/torch/csrc/autograd/functions/special.h
index 076a4fa..273b139 100644
--- a/torch/csrc/autograd/functions/special.h
+++ b/torch/csrc/autograd/functions/special.h
@@ -15,7 +15,9 @@
 
 struct EvalOutput : Function {
   explicit EvalOutput(const Edge& next_edge_)
-      : Function(/*num_inputs=*/1), next_edge(next_edge_) {}
+      : Function(), next_edge(next_edge_) {
+    add_input_metadata(undefined_input());
+  }
 
   virtual variable_list apply(const variable_list& inputs) override {
     throw std::logic_error("EvalOutput::apply() called");
diff --git a/torch/csrc/autograd/functions/tensor.cpp b/torch/csrc/autograd/functions/tensor.cpp
index 75ec818..5df72d5 100644
--- a/torch/csrc/autograd/functions/tensor.cpp
+++ b/torch/csrc/autograd/functions/tensor.cpp
@@ -36,12 +36,13 @@
     const Variable& base_var,
     at::TensorGeometry view_,
     std::shared_ptr<Function> fn_)
-    : Function(/*num_inputs=*/1),
+    : Function(),
       base(base_var),
       view(std::move(view_)),
       fn(std::move(fn_)) {
   // Take the next_edges of fn as our own, except for index 0 which goes
   // to base instead of the view.
+  add_input_metadata(base_var.type(), base_var.sizes());
   const auto num_outputs = fn->num_outputs();
   next_edges_.reserve(num_outputs);
   add_next_edge(base_var.gradient_edge());
diff --git a/torch/csrc/autograd/functions/utils.cpp b/torch/csrc/autograd/functions/utils.cpp
index 09939a1..485572d 100644
--- a/torch/csrc/autograd/functions/utils.cpp
+++ b/torch/csrc/autograd/functions/utils.cpp
@@ -29,7 +29,7 @@
         autograd::create_gradient_edge(variable, grad_fn);
         result.push_back(std::move(variable));
       } else {
-        grad_fn->bump_inputs();
+        grad_fn->add_input_metadata(Function::undefined_input());
         result.emplace_back();
       }
     }
diff --git a/torch/csrc/autograd/python_engine.cpp b/torch/csrc/autograd/python_engine.cpp
index e07d88f..e52e06a 100644
--- a/torch/csrc/autograd/python_engine.cpp
+++ b/torch/csrc/autograd/python_engine.cpp
@@ -134,7 +134,7 @@
     }
   }
 
-  edge_list output_edges;
+  std::vector<Edge> output_edges;
   if (inputs != nullptr) {
     int num_inputs = PyTuple_GET_SIZE(inputs);
     output_edges.reserve(num_inputs);
diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp
index 3fb53c9..4672d53 100644
--- a/torch/csrc/autograd/python_function.cpp
+++ b/torch/csrc/autograd/python_function.cpp
@@ -207,7 +207,7 @@
   f->has_freed_buffers = 1;
 }
 
-auto PyFunction::name() -> std::string {
+auto PyFunction::name() const -> std::string {
   AutoGIL gil;
   auto f = (THPFunction*) obj;
   auto name = std::string(Py_TYPE(f)->tp_name);
@@ -245,7 +245,7 @@
 
 static int THPFunction_clear(THPFunction *self)
 {
-  self->cdata.set_num_inputs(0);
+  self->cdata.clear_input_metadata();
 
   Py_CLEAR(self->needs_input_grad);
 
@@ -293,7 +293,6 @@
   new (&self->input_info) std::vector<VariableInfo>();
   new (&self->saved_variables) std::vector<SavedVariable>();
   new (&self->is_variable_input) std::vector<bool>();
-  self->cdata.set_num_inputs(0);
   return obj;
 }
 
@@ -425,6 +424,10 @@
     // Note that output Variables may be repeated. In that case, the last call
     // to set_history wins.
     auto var = as_variable(obj, i);
+    if (cdata) {
+      auto output_nr = cdata->add_input_metadata(var.type(), var.sizes());
+      TORCH_ASSERT(i == (int)output_nr);
+    }
     set_history(var, i, is_input, is_modified, is_differentiable);
 
     if (is_executable) {
@@ -616,7 +619,7 @@
   THPObjectPtr outputs(PyTuple_New(num_outputs));
   if (!outputs) throw python_error();
 
-  grad_fn->cdata.set_num_inputs(num_outputs);
+  grad_fn->cdata.clear_input_metadata();
 
   // Record type, device, and size information about inputs
   if (is_executable) {
diff --git a/torch/csrc/autograd/python_function.h b/torch/csrc/autograd/python_function.h
index 562ab79..529bcaf 100644
--- a/torch/csrc/autograd/python_function.h
+++ b/torch/csrc/autograd/python_function.h
@@ -35,7 +35,7 @@
   variable_list legacy_apply(const variable_list& inputs);
 
   virtual void release_variables() override;
-  virtual std::string name() override;
+  virtual std::string name() const override;
   virtual std::shared_ptr<Function> get_shared_ptr() override;
   virtual bool is_traceable() override;
 
diff --git a/torch/csrc/autograd/python_legacy_variable.cpp b/torch/csrc/autograd/python_legacy_variable.cpp
index 179c54a..33dd281 100644
--- a/torch/csrc/autograd/python_legacy_variable.cpp
+++ b/torch/csrc/autograd/python_legacy_variable.cpp
@@ -57,7 +57,7 @@
   Variable var;
   if (grad_fn) {
     auto grad_fn_ = THPFunction_asFunction((THPFunction*)grad_fn);
-    Edge edge(grad_fn_, grad_fn_->bump_inputs());
+    Edge edge(grad_fn_, grad_fn_->add_input_metadata(tensor.type(), tensor.sizes()));
     var = make_variable(std::move(tensor), std::move(edge));
   } else {
     var = make_variable(std::move(tensor), requires_grad);
diff --git a/torch/csrc/autograd/type_and_shape.h b/torch/csrc/autograd/type_and_shape.h
new file mode 100644
index 0000000..01a62fa
--- /dev/null
+++ b/torch/csrc/autograd/type_and_shape.h
@@ -0,0 +1,34 @@
+#pragma once
+
+#include <ATen/ATen.h>
+#include "torch/csrc/assertions.h"
+
+namespace torch { namespace autograd {
+
+/// A tensor's type and shape. Each Function records the required type and
+/// shape of its inputs. If is_valid() is false, then the corresponding input
+/// is not used and may be an undefined tensor.
+struct TypeAndShape {
+  TypeAndShape() : type_(nullptr) {}
+
+  TypeAndShape(const at::Type& type, at::IntList shape)
+    : type_(&type) , shape_(shape) {}
+
+  bool is_valid() const {
+    return type_ != nullptr;
+  }
+
+  const at::Type& type() const {
+    TORCH_ASSERT(type_);
+    return *type_;
+  }
+
+  at::IntList shape() const {
+    return shape_;
+  }
+
+  const at::Type* type_;
+  at::DimVector shape_;
+};
+
+}}
diff --git a/torch/csrc/autograd/variable.cpp b/torch/csrc/autograd/variable.cpp
index f0b4c1d..48a35f4 100644
--- a/torch/csrc/autograd/variable.cpp
+++ b/torch/csrc/autograd/variable.cpp
@@ -126,10 +126,12 @@
 }
 
 void Variable::Impl::set_data(Tensor new_data) {
-  data_ = std::move(new_data);
-  if (data_.type() != *type_) {
-    type_ = VariableType::getType(data_);
+  if (new_data.type() != data_.type()) {
+    type_ = VariableType::getType(new_data.type());
+    // Clear grad_accumulator if it exists, since it stores the old type info.
+    grad_accumulator_.reset();
   }
+  data_ = std::move(new_data);
 }
 
 Variable::ViewImpl::ViewImpl(Variable base, at::Tensor data, Edge gradient_edge)
@@ -158,7 +160,7 @@
     fn->stride = strides();
     fn->storage_offset = data_.storage_offset();
     fn->set_next_edges(collect_next_edges(base_));
-    fn->set_num_inputs(1);
+    fn->add_input_metadata(base_.type(), sizes());
     grad_fn_ = std::move(fn);
     attr_version = current_version;
   }