Add support for requires_grad in JIT's AD (#4898)

diff --git a/test/expect/TestJit.test_cpp.expect b/test/expect/TestJit.test_cpp.expect
index 8685e47..b089006 100644
--- a/test/expect/TestJit.test_cpp.expect
+++ b/test/expect/TestJit.test_cpp.expect
@@ -44,8 +44,30 @@
   %6 : Float(2, 3, 4) = add[alpha={1}](%4, %5)
   %7 : Float(2, 3, 4) = mul(%3, %2)
   %8 : Float(2, 3, 4) = mul(%6, %1)
-  %9 : Float(2, 3, 4) = add[alpha={1}](%7, %8)
-  %10 : Float(2, 3, 4) = mul(%6, %0)
-  %11 : Float(2, 3, 4) = add[alpha={1}](%3, %10)
-  return (%9, %11);
+  %9 : Float(2, 3, 4) = mul(%6, %0)
+  %10 : Float(2, 3, 4) = add[alpha={1}](%7, %8)
+  %11 : Float(2, 3, 4) = add[alpha={1}](%3, %9)
+  return (%10, %11);
 }
+
+testDifferentiateWithRequiresGrad
+graph(%0 : Float(2, 3, 4)
+      %1 : Float(2, 3, 4)) {
+  %2 : Float(2, 3, 4) = mul(%1, %1)
+  %3 : Float(2, 3, 4) = add[alpha={1}](%2, %1)
+  %4 : Float(2, 3, 4) = add[alpha={1}](%3, %0)
+  %5 : Float(2, 3, 4) = mul(%4, %0)
+  %6 : Float(2, 3, 4) = add[alpha={1}](%5, %1)
+  return (%3, %6, %4);
+}
+graph(%0 : Float(2, 3, 4)
+      %1 : Float(2, 3, 4)
+      %2 : Float(2, 3, 4)
+      %3 : Float(2, 3, 4)) {
+  %4 : Float(2, 3, 4) = mul(%2, %0)
+  %5 : Float(2, 3, 4) = add[alpha={1}](%3, %4)
+  %6 : Float(2, 3, 4) = mul(%2, %1)
+  %7 : Float(2, 3, 4) = add[alpha={1}](%6, %5)
+  return (%7);
+}
+
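Two things change in the expect file. The first hunk only reorders testDifferentiate's df: addValues (see autodiff.cpp below) appends the accumulation node at the end of the graph instead of inserting it right after the freshly created vjp, so a mul and an add swap places while the computed values stay the same. The second pair of graphs belongs to the new testDifferentiateWithRequiresGrad case; a short hand-derivation of its df (de, dt and t are shorthand introduced here, not names that appear in the graphs):

    Forward:  d = b*b + b,  t = d + a,  e = t*a + b,  only a requires grad
    Reverse:  grad_t = dt + de*a      (incoming vjp for the temporary t, plus the term from e = t*a + b)
              grad_a = de*t + grad_t  (term from e = t*a + b, plus the pass-through from t = d + a)

This is exactly what the four nodes of the second graph compute, reading its inputs as %0 = a, %1 = t (the two captures) and %2 = de, %3 = dt (the two incoming vjps). No gradient is emitted for b, which is the point of the patch.
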
diff --git a/torch/csrc/jit/autodiff.cpp b/torch/csrc/jit/autodiff.cpp
index 6b48611..53a42e8 100644
--- a/torch/csrc/jit/autodiff.cpp
+++ b/torch/csrc/jit/autodiff.cpp
@@ -1,18 +1,26 @@
 #include "torch/csrc/jit/autodiff.h"
 
+#include "torch/csrc/jit/passes/dead_code_elimination.h"
 #include "torch/csrc/jit/symbolic_variable.h"
 #include "torch/csrc/utils/functional.h"
+#include "torch/csrc/utils/auto_gpu.h"
 
 namespace torch { namespace jit {
 
 using value_map = std::unordered_map<Value*, Value*>;
 using value_set = std::unordered_set<Value*>;
 
-Value* addAndPutAfter(Value *a, Value *b, Node *node) {
-  Graph *graph = node->owningGraph();
+// Creates a node for a + b and puts it after the given node.
+// If node is a null pointer, the new node is appended at the end of the graph's node list.
+Value* addValues(Value *a, Value *b, Node *node = nullptr) {
+  Graph *graph = a->node()->owningGraph();
   Node *add_node = graph->create(kadd, {a, b})
                         ->t_(kalpha, at::Scalar(1).toTensor());
-  add_node->insertAfter(node);
+  if (node) {
+    add_node->insertAfter(node);
+  } else {
+    graph->appendNode(add_node);
+  }
   Value *add_output = add_node->output();
   add_output->setType(a->typeOption());
   return add_output;
@@ -45,6 +53,55 @@
   return fmap(sym_grads, [](const SymbolicVariable &v) { return v.value(); });
 }
 
+static value_set findAllRequiresGradNodes(
+        Graph& graph, const std::vector<bool>& input_requires_grad) {
+  JIT_ASSERT(graph.inputs().size() == input_requires_grad.size());
+
+  std::unordered_set<Value*> requires_grad_set;
+  const auto requires_grad = [&](Value *v) { return requires_grad_set.count(v) > 0; };
+
+  auto inputs = graph.inputs();
+  for (std::size_t i = 0, num_inputs = inputs.size(); i < num_inputs; ++i) {
+    if (!input_requires_grad[i]) continue;
+    requires_grad_set.emplace(inputs[i]);
+  }
+
+  for (Node * node : graph.nodes()) {
+    if (std::none_of(node->inputs().begin(), node->inputs().end(), requires_grad)) continue;
+    for (Value * output : node->outputs())
+      requires_grad_set.emplace(output);
+  }
+
+  return requires_grad_set;
+}
+
+static Value* allocZerosLike(Value *v) {
+  static const Symbol constant_sym = "constant"_sym;
+  static const Symbol is_zero_sym  = "is_zero"_sym;
+  static const Symbol value_sym    = "value"_sym;
+  JIT_EXPECTM(v->hasType(), "can't allocate zero gradient for a value without a type");
+  Graph *graph = v->owningGraph();
+  auto type = v->type()->expect<TensorType>();
+  AutoGPU gpu_guard(type->device());
+
+  auto & at_type = type->device() == -1 ? at::CPU(type->scalarType()) : at::CUDA(type->scalarType());
+  auto zeros = at_type.zeros({1}).expand(type->sizes());
+  Node *constant = graph->create(constant_sym)
+                        ->t_(value_sym, zeros)
+                        ->i_(is_zero_sym, 1);
+  graph->appendNode(constant);
+  return constant->output();
+}
+
+struct ReverseDetails {
+  ReverseDetails(value_map&& grad_map, value_set&& requires_grad_set)
+    : grad_map(std::move(grad_map))
+    , requires_grad_set(std::move(requires_grad_set)) {}
+
+  value_map grad_map;
+  value_set requires_grad_set;
+};
+
 // Before:
 //   - graph has only stage 0
 //   - grad_desc doesn't have any fields initialized
@@ -52,15 +109,25 @@
 //   - graph has stage 0 and stage 1 that computes its vjp
 //   - grad_desc has df_input_vjps and df_output_vjps set
 //     (but df_input_vjps will be modified later as well)
-static value_map addReverseInline(Graph& graph, Gradient& grad_desc) {
+static ReverseDetails addReverseInline(Graph& graph, Gradient& grad_desc,
+                                  const std::vector<bool>& input_requires_grad) {
   JIT_ASSERT(graph.stage() == 0);
   graph.advanceStage();
 
+  auto requires_grad_set = findAllRequiresGradNodes(graph, input_requires_grad);
+  const auto requires_grad = [&](Value *v) { return requires_grad_set.count(v) > 0; };
+
   value_map grad_map; // x -> dx mapping
-  const auto get_grad = [&](Value* v) { return grad_map.at(v); };
+  const auto get_grad = [&](Value* v) -> Value* {
+    auto it = grad_map.find(v);
+    if (it == grad_map.end()) {
+      std::tie(it, std::ignore) = grad_map.emplace(v, allocZerosLike(v));
+    }
+    return it->second;
+  };
   const auto set_grad = [&](Value *x, Value *dx) {
     if (Value * prev_grad = grad_map[x]) {
-      Value * new_grad = addAndPutAfter(prev_grad, dx, dx->node());
+      Value * new_grad = addValues(prev_grad, dx);
       grad_map[x] = new_grad;
     } else {
       grad_map[x] = dx;
@@ -70,6 +137,7 @@
   auto outputs = graph.outputs();
   for (std::size_t i = 0, num_outputs = outputs.size(); i < num_outputs; ++i) {
     Value * output = outputs[i];
+    if (!requires_grad(output)) continue;
     Value * output_grad = graph.addInput()->setType(output->typeOption());
     set_grad(output, output_grad);
     grad_desc.df_input_vjps.push_back(i);
@@ -78,6 +146,7 @@
   for (auto it = graph.rbegin(), end = graph.rend(); it != end; ++it) {
     Node *node = *it;
     auto inputs = node->inputs();
+    if (std::none_of(inputs.begin(), inputs.end(), requires_grad)) continue;
     value_list grad_inputs = gradientForNode(node, fmap(node->outputs(), get_grad));
     JIT_ASSERT(grad_inputs.size() == node->inputs().size());
     for (std::size_t i = 0, num_inputs = grad_inputs.size(); i < num_inputs; ++i) {
@@ -89,11 +158,12 @@
   for (std::size_t i = 0, num_inputs = inputs.size(); i < num_inputs; ++i) {
     Value * input = inputs[i];
     if (input->stage() > 0) break;
+    if (!requires_grad(input)) continue;
     graph.registerOutput(get_grad(input));
     grad_desc.df_output_vjps.push_back(i);
   }
 
-  return grad_map;
+  return ReverseDetails(std::move(grad_map), std::move(requires_grad_set));
 }
 
 // This function will take the graph and return a new one that:
@@ -138,7 +208,9 @@
 // detailed description see Note [Gradient graphs] in autodiff.h.
 // This function also initializes the fields in grad_desc that were undefined after
 // `addReverseInline` (and modifies `df_input_vjps`).
-static void lambdaLiftReverse(Graph& graph, value_map& grad_map, Gradient& grad_desc) {
+static void lambdaLiftReverse(Graph& graph,
+                              ReverseDetails& rev_info,
+                              Gradient& grad_desc) {
   static const auto is_stage_0 = [](Value *v) { return v->stage() == 0; };
   static const auto is_stage_1 = [](Value *v) { return v->stage() == 1; };
   // NOTE: in the comments inside this function first stage is stage 0
@@ -227,13 +299,14 @@
   JIT_ASSERT(graph.stage() == 1); // We will be adding inputs to stage 1
   for (std::size_t i = grad_desc.f_real_outputs; i < primal_outputs.size(); ++i) {
     Value * tmp = primal_outputs.at(i);
+    // Add VJP inputs only for intermediates that actually required grad.
+    if (rev_info.requires_grad_set.count(tmp) == 0) continue;
     Value * tmp_vjp_in = graph.addInput()->setType(tmp->typeOption());
-    if (grad_map.count(tmp) == 0) continue; // This gradient wasn't even used.
-    Value * tmp_vjp_prev = grad_map.at(tmp);
+    Value * tmp_vjp_prev = rev_info.grad_map.at(tmp);
     // This is quite weird because we can't first make a sum and then replace all uses
     // of tmp_vjp_prev (that would replace its use in the sum too!), so we create an
     // incorrect sum that doesn't use prev vjp, replace uses, and fix the sum.
-    Value * new_vjp = addAndPutAfter(tmp_vjp_in, tmp_vjp_in, tmp_vjp_prev->node());
+    Value * new_vjp = addValues(tmp_vjp_in, tmp_vjp_in, tmp_vjp_prev->node());
     tmp_vjp_prev->replaceAllUsesWith(new_vjp);
     new_vjp->node()->replaceInput(1, tmp_vjp_prev);
     grad_desc.df_input_vjps.emplace_back(i);
@@ -257,7 +330,7 @@
   grad_desc.df = splitOffStage(graph, 1, reverse_inputs, reverse_outputs);
 }
 
-Gradient differentiate(std::shared_ptr<Graph>& _graph) {
+Gradient differentiate(std::shared_ptr<Graph>& _graph, const std::vector<bool>& requires_grad) {
   // Take ownership of the graph
   std::shared_ptr<Graph> graph;
   JIT_ASSERTM(_graph.use_count() == 1,
@@ -267,10 +340,14 @@
   // XXX: Take care when handling outputs - they can be duplicated!
   Gradient grad_desc;
   // Fills in df_input_vjps and df_output_vjps
-  auto grad_map = addReverseInline(*graph, grad_desc);
+  auto rev_info = addReverseInline(*graph, grad_desc, requires_grad);
+  // addReverseInline has to call gradientForNode if *any* of a node's inputs
+  // require grad, but gradientForNode emits vjps for *all* of its inputs.
+  // Use DCE to remove the unnecessary nodes.
+  EliminateDeadCode(graph);
   // Fills in f, df, f_real_outputs, df_input_captures,
   // modifies df_input_vjps (new vjps are added for temporaries)
-  lambdaLiftReverse(*graph, grad_map, grad_desc);
+  lambdaLiftReverse(*graph, rev_info, grad_desc);
   return grad_desc;
 }
 
diff --git a/torch/csrc/jit/autodiff.h b/torch/csrc/jit/autodiff.h
index 9ee9c99..90014fa 100644
--- a/torch/csrc/jit/autodiff.h
+++ b/torch/csrc/jit/autodiff.h
@@ -80,7 +80,7 @@
   //   - Interpret df
   //   - Wrap outputs of df into Variables (that don't require grad)
 };
-Gradient differentiate(std::shared_ptr<Graph>& graph);
+Gradient differentiate(std::shared_ptr<Graph>& graph, const std::vector<bool>& requires_grad);
 
 // can we take a derivative of this node symbolically?
 bool isDifferentiable(Node * n);
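A call-site sketch for the new signature, condensed from testDifferentiateWithRequiresGrad below (illustrative only; one bool per graph input, false meaning no vjp is produced for that input):

    auto graph = std::make_shared<Graph>();
    auto type = std::shared_ptr<TensorType>(
        new TensorType(at::ScalarType::Float, -1, {2, 3, 4}, {12, 4, 1}));
    auto a = SymbolicVariable::asNewInput(*graph, type);
    auto b = SymbolicVariable::asNewInput(*graph, type);
    graph->registerOutput((a * b + b).value());

    // Only `a` requires grad, so grad_spec.df_output_vjps should contain
    // just {0} and df returns a single vjp.
    auto grad_spec = differentiate(graph, {true, false});

differentiate still takes ownership of the graph (the use_count() == 1 assertion in autodiff.cpp), so the mask is the only new requirement on callers.
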
diff --git a/torch/csrc/jit/test_jit.cpp b/torch/csrc/jit/test_jit.cpp
index 1227870..5af14c4 100644
--- a/torch/csrc/jit/test_jit.cpp
+++ b/torch/csrc/jit/test_jit.cpp
@@ -545,7 +545,7 @@
 
     // Trace and differentiate the op
     auto graph = trace(test, vars_in);
-    auto grad_spec = differentiate(graph);
+    auto grad_spec = differentiate(graph, std::vector<bool>(vars_in.size(), true));
 
     // Get outputs from the interpreter
     auto tensors_in                = fmap(vars_in, unwrap);
@@ -582,7 +582,7 @@
   auto c = a * b * a + b;
   graph->registerOutput(c.value());
 
-  auto grad_spec = differentiate(graph);
+  auto grad_spec = differentiate(graph, {true, true});
   std::vector<Capture> expected_captures = {
     {Capture::Kind::Input, 0},
     {Capture::Kind::Input, 1},
@@ -597,6 +597,37 @@
   out << "testDifferentiate\n";
   out << *grad_spec.f;
   out << *grad_spec.df;
+  out << "\n";
+}
+
+void testDifferentiateWithRequiresGrad(std::ostream & out) {
+  auto graph = std::make_shared<Graph>();
+  at::ScalarType s = at::ScalarType::Float;
+  auto type = std::shared_ptr<TensorType>(new TensorType(s, -1, {2, 3, 4}, {12, 4, 1}));
+
+  // Build up a fake graph
+  auto a = SymbolicVariable::asNewInput(*graph, type);
+  auto b = SymbolicVariable::asNewInput(*graph, type);
+  auto d = b * b + b;
+  auto e = (d + a) * a + b;
+  graph->registerOutput(d.value());
+  graph->registerOutput(e.value());
+
+  auto grad_spec = differentiate(graph, {true, false});
+  std::vector<Capture> expected_captures = {
+    {Capture::Kind::Input, 0},
+    {Capture::Kind::Output, 2},
+  };
+  std::vector<std::size_t> expected_input_vjps = {1, 2};  // for e and %4 = (d + a)
+  std::vector<std::size_t> expected_output_vjps = {0};    // only a requires grad
+  JIT_ASSERT(grad_spec.f_real_outputs == 2);              // d and e are real outputs; %4 = (d + a) is only a temporary
+  JIT_ASSERT(grad_spec.df_input_captures == expected_captures);
+  JIT_ASSERT(grad_spec.df_input_vjps == expected_input_vjps);
+  JIT_ASSERT(grad_spec.df_output_vjps == expected_output_vjps);
+  out << "testDifferentiateWithRequiresGrad\n";
+  out << *grad_spec.f;
+  out << *grad_spec.df;
+  out << "\n";
 }
 
 void testCreateAutodiffSubgraphs(std::ostream & out) {
@@ -706,6 +737,7 @@
   std::stringstream out;
   testCreateAutodiffSubgraphs(out);
   testDifferentiate(out);
+  testDifferentiateWithRequiresGrad(out);
   testADFormulas();
   interpTest();
   interpStageTest();
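
Both call sites in test_jit.cpp hard-code the mask. A caller starting from autograd Variables would presumably derive it from the inputs instead; a hypothetical helper, not part of this patch:

    // Hypothetical: build the mask for differentiate() from autograd inputs.
    // Assumes variable_list (std::vector<Variable>) and Variable::requires_grad().
    static std::vector<bool> requiresGradMask(const variable_list& inputs) {
      return fmap(inputs, [](const Variable& v) { return v.requires_grad(); });
    }

    // e.g.  auto grad_spec = differentiate(graph, requiresGradMask(vars_in));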