Release saved variable from DifferentiableGraphBackward (#42994)
Summary:
When backward ops execute via the autograd engine's evaluate_function(), fn.release_variables() is called to release the SavedVariables. For eager-mode ops, this releases the saved inputs that were required by the backward grad function. With TorchScript, however, we get a DifferentiableGraph, and DifferentiableGraphBackward doesn't implement release_variables(), so the SavedVariables stay alive longer than necessary. This change implements release_variables() for DifferentiableGraphBackward so these SavedVariables are released early as well.
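
To illustrate the eager-mode behavior that this change aligns TorchScript with, here is a minimal sketch (not part of the patch; the function and variable names are illustrative only, loosely mirroring the new test): once a first backward() call releases its SavedVariables, a second backward() without retain_graph=True fails with an error asking to specify retain_graph=True, which is exactly what the new test_script_backward_twice test now also expects from a scripted DifferentiableGraph.

    import torch

    def f(a, b):
        # ops that save their inputs for backward (mul, abs, atan)
        return torch.atan(torch.abs(torch.mul(a, b))).sum()

    a = torch.randn(2, 2, requires_grad=True)
    b = torch.randn(2, 2, requires_grad=True)

    out = f(a, b)
    out.backward(retain_graph=True)  # SavedVariables are kept alive
    out.backward()                   # ok: graph buffers still present

    out = f(a, b)
    out.backward()                   # SavedVariables released here
    try:
        out.backward()               # second call fails: buffers already freed
    except RuntimeError as e:
        print(e)                     # message asks to specify retain_graph=True
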
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42994
Reviewed By: izdeby
Differential Revision: D23503172
Pulled By: albanD
fbshipit-source-id: d87127498cfa72883ae6bb31d0e6c7056c4c36d4
diff --git a/test/jit/test_profiler.py b/test/jit/test_profiler.py
index f9a38cc..ac21379 100644
--- a/test/jit/test_profiler.py
+++ b/test/jit/test_profiler.py
@@ -93,7 +93,7 @@
y = torch.ones([1], requires_grad=True)
broadcast_f(x, y)
b = broadcast_f(x, y)
- b.backward(torch.ones([2, 2], dtype=torch.float))
+ b.backward(torch.ones([2, 2], dtype=torch.float), retain_graph=True)
b.backward(torch.ones([2, 2], dtype=torch.float))
# warmup_backward(b, torch.ones([2, 2], dtype=torch.float))
g = torch.jit.last_executed_optimized_graph()
diff --git a/test/test_jit.py b/test/test_jit.py
index f070e7b..27b7dc2 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -950,6 +950,38 @@
checkBackwardScript(test_torch_autograd_backward, (inp,))
checkBackwardScript(test_torch_autograd_backward_with_grad_tensors, (inp,))
+ def test_script_backward_twice(self):
+ def checkBackwardTwiceScript(fn, inputs, retain_graph_=False):
+ torch._C._jit_set_profiling_executor(False)
+
+ with torch.jit.optimized_execution(True):
+ scripted_fn = torch.jit.script(fn, inputs)
+ FileCheck().check("prim::DifferentiableGraph").run(scripted_fn.graph_for(*inputs))
+
+ result = scripted_fn(*inputs)
+ result.sum().backward(retain_graph=retain_graph_)
+ if not retain_graph_:
+ self.assertRaisesRegex(RuntimeError, 'Specify retain_graph=True',
+ lambda: result.sum().backward())
+ else:
+ result.sum().backward()
+
+ def test_script_backward_twice_with_saved_values(input1, input2):
+ # type: (Tensor, Tensor) -> Tensor
+ tmp1 = torch.mul(input1, input2)
+ tmp2 = torch.abs(tmp1)
+ if torch.equal(input1, input2):
+ tmp2 = torch.acos(tmp2)
+ else:
+ tmp2 = torch.atan(tmp2)
+ result = torch.add(tmp2, input2)
+ return result
+
+ inp1 = torch.randn(2, 2, requires_grad=True)
+ inp2 = torch.randn(2, 2, requires_grad=True)
+ checkBackwardTwiceScript(test_script_backward_twice_with_saved_values, (inp1, inp2), False)
+ checkBackwardTwiceScript(test_script_backward_twice_with_saved_values, (inp1, inp2), True)
+
def test_diff_subgraph_clones_constants(self):
@torch.jit.script
def f(x, y):
diff --git a/torch/csrc/jit/runtime/graph_executor.cpp b/torch/csrc/jit/runtime/graph_executor.cpp
index b7237ee..36dd8b4 100644
--- a/torch/csrc/jit/runtime/graph_executor.cpp
+++ b/torch/csrc/jit/runtime/graph_executor.cpp
@@ -159,6 +159,12 @@
}
}
+ void release_variables() {
+ for (auto& var_capture_ : var_captures_) {
+ var_capture_.reset_data();
+ }
+ }
+
private:
enum Capture : uint8_t {
CAPTURE_TENSOR,
@@ -311,6 +317,10 @@
}
}
+ void release_variables() override {
+ captures_.release_variables();
+ }
+
private:
void produceOutput(size_t i, at::Tensor output, variable_list& outputs) {
if (should_compute_output(i)) {