Release saved variable from DifferentiableGraphBackward (#42994)

Summary:
When backward ops execute via the autograd engine's evaluate_function(), fn.release_variables() is called to release the SavedVariables. For eager-mode ops, this releases the saved inputs that were required by the backward grad function. With TorchScript, however, we get a DifferentiableGraph, and DifferentiableGraphBackward does not implement release_variables(), so the SavedVariables stay alive longer than necessary. This change implements release_variables() for DifferentiableGraphBackward so these SavedVariables are released early.
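
A minimal sketch (not part of this diff) of the user-visible effect. It forces the legacy executor the same way the new test does; the function f and the tensor shapes are illustrative only:

    import torch

    # Force the legacy executor so the first call already produces a
    # prim::DifferentiableGraph, mirroring the setup in test_script_backward_twice.
    torch._C._jit_set_profiling_executor(False)

    @torch.jit.script
    def f(x, y):
        return torch.abs(torch.mul(x, y))

    a = torch.randn(2, 2, requires_grad=True)
    b = torch.randn(2, 2, requires_grad=True)

    with torch.jit.optimized_execution(True):
        out = f(a, b).sum()
        out.backward()       # DifferentiableGraphBackward now drops its SavedVariables here
        try:
            out.backward()   # second backward without retain_graph=True...
        except RuntimeError as e:
            print(e)         # ...fails with 'Specify retain_graph=True', matching eager mode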

Pull Request resolved: https://github.com/pytorch/pytorch/pull/42994

Reviewed By: izdeby

Differential Revision: D23503172

Pulled By: albanD

fbshipit-source-id: d87127498cfa72883ae6bb31d0e6c7056c4c36d4
diff --git a/test/jit/test_profiler.py b/test/jit/test_profiler.py
index f9a38cc..ac21379 100644
--- a/test/jit/test_profiler.py
+++ b/test/jit/test_profiler.py
@@ -93,7 +93,7 @@
         y = torch.ones([1], requires_grad=True)
         broadcast_f(x, y)
         b = broadcast_f(x, y)
-        b.backward(torch.ones([2, 2], dtype=torch.float))
+        b.backward(torch.ones([2, 2], dtype=torch.float), retain_graph=True)
         b.backward(torch.ones([2, 2], dtype=torch.float))
         # warmup_backward(b, torch.ones([2, 2], dtype=torch.float))
         g = torch.jit.last_executed_optimized_graph()
diff --git a/test/test_jit.py b/test/test_jit.py
index f070e7b..27b7dc2 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -950,6 +950,38 @@
         checkBackwardScript(test_torch_autograd_backward, (inp,))
         checkBackwardScript(test_torch_autograd_backward_with_grad_tensors, (inp,))
 
+    def test_script_backward_twice(self):
+        def checkBackwardTwiceScript(fn, inputs, retain_graph_=False):
+            torch._C._jit_set_profiling_executor(False)
+
+            with torch.jit.optimized_execution(True):
+                scripted_fn = torch.jit.script(fn, inputs)
+                FileCheck().check("prim::DifferentiableGraph").run(scripted_fn.graph_for(*inputs))
+
+                result = scripted_fn(*inputs)
+                result.sum().backward(retain_graph=retain_graph_)
+                if not retain_graph_:
+                    self.assertRaisesRegex(RuntimeError, 'Specify retain_graph=True',
+                                           lambda: result.sum().backward())
+                else:
+                    result.sum().backward()
+
+        def test_script_backward_twice_with_saved_values(input1, input2):
+            # type: (Tensor, Tensor) -> Tensor
+            tmp1 = torch.mul(input1, input2)
+            tmp2 = torch.abs(tmp1)
+            if torch.equal(input1, input2):
+                tmp2 = torch.acos(tmp2)
+            else:
+                tmp2 = torch.atan(tmp2)
+            result = torch.add(tmp2, input2)
+            return result
+
+        inp1 = torch.randn(2, 2, requires_grad=True)
+        inp2 = torch.randn(2, 2, requires_grad=True)
+        checkBackwardTwiceScript(test_script_backward_twice_with_saved_values, (inp1, inp2), False)
+        checkBackwardTwiceScript(test_script_backward_twice_with_saved_values, (inp1, inp2), True)
+
     def test_diff_subgraph_clones_constants(self):
         @torch.jit.script
         def f(x, y):
diff --git a/torch/csrc/jit/runtime/graph_executor.cpp b/torch/csrc/jit/runtime/graph_executor.cpp
index b7237ee..36dd8b4 100644
--- a/torch/csrc/jit/runtime/graph_executor.cpp
+++ b/torch/csrc/jit/runtime/graph_executor.cpp
@@ -159,6 +159,12 @@
     }
   }
 
+  void release_variables() {
+    for (auto& var_capture_ : var_captures_) {
+      var_capture_.reset_data();
+    }
+  }
+
  private:
   enum Capture : uint8_t {
     CAPTURE_TENSOR,
@@ -311,6 +317,10 @@
     }
   }
 
+  void release_variables() override {
+    captures_.release_variables();
+  }
+
  private:
   void produceOutput(size_t i, at::Tensor output, variable_list& outputs) {
     if (should_compute_output(i)) {