Properly move retains_grad hook on in-place over view for base (#117552)
Fixes https://github.com/pytorch/pytorch/issues/117366
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117552
Approved by: https://github.com/albanD
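
For context, a minimal sketch of the scenario this change covers, adapted from the new test added below (names and the printed value mirror that test; the "before the fix" behavior is an assumption based on the linked issue): when a view of `x` is mutated in place, the base's grad_fn is replaced by a CopySlices node, and the retains_grad hook registered via `x.retain_grad()` has to move to that node so that `x.grad` reflects the latest version of `x` rather than a stale one.

```python
import torch

# Sketch adapted from the new test in test_autograd.py (assumes this patch
# is applied). The in-place op on the view rewrites the base's grad_fn to a
# CopySlices node; the retains_grad hook must follow it.
x = torch.tensor([1.], requires_grad=True).clone()
x.retain_grad()            # hook registered on x's current grad_fn
x_view = x[:]
x_view *= 2                # base's grad_fn becomes CopySlices
x *= 2
x.sum().backward()
print(x.grad)              # tensor([1.]) -- grad wrt the latest version of x
```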
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 445a35a..043ac4e 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -1718,6 +1718,24 @@
a.sum().backward()
self.assertEqual(a.grad, torch.tensor([1.]))
+ # When in-place over view is done, the retains_grad hooks should be
+ # moved from base's original grad_fn to the copyslices node.
+ x = torch.tensor([1.], requires_grad=True).clone()
+ x.retain_grad()
+ x_view = x[:]
+ x_view *= 2
+ x *= 2
+ x.sum().backward()
+ # The grad is 1, not 4, because we are computing grad wrt the latest
+ # version of x.
+ self.assertEqual(x.grad, torch.tensor([1.]))
+
+ # If the base did not originally require grad, there should be no hook
+ # to move. Make sure this case runs without error.
+ x = torch.zeros(4)
+ y = x.view(2, 2)
+ y.add_(torch.randn(2, 2, requires_grad=True))
+
def test_retains_grad_inplace_multiple_outputs(self):
class DoubleMul(Function):
@staticmethod
diff --git a/torch/csrc/autograd/variable.cpp b/torch/csrc/autograd/variable.cpp
index 81c0f19..821eea0 100644
--- a/torch/csrc/autograd/variable.cpp
+++ b/torch/csrc/autograd/variable.cpp
@@ -239,6 +239,11 @@
at::TensorGeometry(self),
view_info.view_fn_,
std::move(gradient_edge.function));
+ if (self.requires_grad()) {
+ // If self did not previously require grad, there are no hooks to move
+ torch::autograd::impl::update_tensor_hooks_on_new_gradfn(
+ view_info.base_, view_info.base_.grad_fn(), copy_slices);
+ }
set_gradient_edge(view_info.base_, {std::move(copy_slices), 0});
self.grad_fn(); // trigger an update to the view's grad_fn
return;