Gradcheck forward AD respects requires_grad but runs the function with requires_grad=False inputs (#72309)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72309

Fixes: https://github.com/pytorch/pytorch/issues/72113
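
A minimal sketch of the new behavior (illustrative only, mirroring the added test; `MyFn` and the surrounding script are hypothetical, not part of the test suite): gradcheck still reads requires_grad on the inputs it is given to decide which inputs receive a tangent, but it now detaches them before building duals, so inside the user function the tensors arrive with requires_grad=False.

```python
import torch
import torch.autograd.forward_ad as fwAD
from torch.autograd import Function, gradcheck

class MyFn(Function):  # hypothetical example Function, for illustration only
    @staticmethod
    def forward(ctx, x):
        # Only check inside a dual level; the initial (non-dual) call made by
        # gradcheck still sees the original, requires_grad=True input.
        if fwAD._current_level >= 0:
            assert not x.requires_grad
        return x.clone()

    @staticmethod
    def jvp(ctx, x_t):
        return x_t

x = torch.rand(2, dtype=torch.double, requires_grad=True)
gradcheck(MyFn.apply, (x,), check_forward_ad=True, check_backward_ad=False,
          check_undefined_grad=False, check_batched_grad=False,
          check_batched_forward_grad=False)
```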

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D33991570

Pulled By: soulitzer

fbshipit-source-id: 610de162e9848d2d3b12e0fb039860fd9dee844f
(cherry picked from commit a7ecb13610a4e01d91a2ecff107ac1cf6cd94cba)
diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp
index ddf5f44..225985d 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -792,6 +792,11 @@
   }
 }
 
+bool _requires_fw_or_bw_grad(const Tensor& input) {
+  return ((at::GradMode::is_enabled() && input.requires_grad())
+          || input._fw_grad(/*level */ 0).defined());
+}
+
 // Below of the definitions of the functions operating on a batch that are going to be dispatched
 // in the main helper functions for the linear algebra operations
 
@@ -2382,7 +2387,7 @@
 Tensor linalg_eigvalsh(const Tensor& input, c10::string_view uplo) {
   // if input requires grad we must compute the eigenvectors to make this function differentiable
   // the eigenvectors are not exposed to the user
-  if (at::GradMode::is_enabled() && input.requires_grad()) {
+  if (_requires_fw_or_bw_grad(input)) {
     Tensor values;
     std::tie(values, std::ignore) = at::linalg_eigh(input, uplo);
     return values;
@@ -2878,7 +2883,7 @@
 Tensor linalg_eigvals(const Tensor& input) {
   // if input requires grad we must compute the eigenvectors to make this function differentiable
   // the eigenvectors are not exposed to the user
-  if (at::GradMode::is_enabled() && input.requires_grad()) {
+  if (_requires_fw_or_bw_grad(input)) {
     return std::get<0>(at::linalg_eig(input));
   }
 
@@ -3063,10 +3068,7 @@
 }
 
 Tensor linalg_svdvals(const Tensor& A) {
-  const bool A_requires_grad = (at::GradMode::is_enabled() && A.requires_grad())
-                               || A._fw_grad(/*level */ 0).defined()
-                               || isTensorSubclassLike(A);
-  return std::get<1>(at::_linalg_svd(A, /*full_matrices=*/false, /*comptue_uv=*/A_requires_grad));
+  return std::get<1>(at::_linalg_svd(A, /*full_matrices=*/false, /*compute_uv=*/_requires_fw_or_bw_grad(A)));
 }
 
 std::tuple<Tensor&, Tensor&, Tensor&> svd_out(const Tensor& self, bool some, bool compute_uv, Tensor& U, Tensor& S, Tensor& V) {
diff --git a/aten/src/ATen/native/EmbeddingBag.cpp b/aten/src/ATen/native/EmbeddingBag.cpp
index 42fab2e..e6f88f5 100644
--- a/aten/src/ATen/native/EmbeddingBag.cpp
+++ b/aten/src/ATen/native/EmbeddingBag.cpp
@@ -798,7 +798,7 @@
     padding_idx = maybe_wrap_dim(padding_idx, weight.size(0));
   }
   std::tuple<Tensor, Tensor, Tensor, Tensor> out;
-  if (!weight.requires_grad()) {
+  if (!weight.requires_grad() && !weight._fw_grad(/*level=*/0).defined()) {
     out = at::_embedding_bag_forward_only(
       weight, indices.contiguous(), offsets.contiguous(), scale_grad_by_freq,
       mode, sparse, per_sample_weights, include_last_offset, padding_idx);
diff --git a/aten/src/ATen/native/MaxPooling.cpp b/aten/src/ATen/native/MaxPooling.cpp
index 53f2a10..3e615d7 100644
--- a/aten/src/ATen/native/MaxPooling.cpp
+++ b/aten/src/ATen/native/MaxPooling.cpp
@@ -102,6 +102,7 @@
         self, kernel_size, stride, padding, dilation, ceil_mode);
   }
   if ((self.requires_grad() && at::GradMode::is_enabled()) ||
+      self._fw_grad(/*level */ 0).defined() ||
       !self.device().is_cpu()) {
     // Needs indices for grad and with_indices defines CUDA dispatch
     return std::get<0>(at::max_pool1d_with_indices(
diff --git a/aten/src/ATen/native/mkldnn/Pooling.cpp b/aten/src/ATen/native/mkldnn/Pooling.cpp
index fac656f..5800bd2 100644
--- a/aten/src/ATen/native/mkldnn/Pooling.cpp
+++ b/aten/src/ATen/native/mkldnn/Pooling.cpp
@@ -252,7 +252,7 @@
   // for inference, don't need the indices, set aprop_kind to prop_kind::forward_inference
   // can reduce the memory use.
   if (ideep::algorithm::pooling_max == algo
-      && !(input.requires_grad() && at::GradMode::is_enabled())) {
+      && !((input.requires_grad() && at::GradMode::is_enabled()) || input._fw_grad(/*level */ 0).defined())) {
     aprop_kind = ideep::prop_kind::forward_inference;
   }
 
diff --git a/test/test_autograd.py b/test/test_autograd.py
index ad77b0c..2fe584a 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -4064,6 +4064,82 @@
             with self.assertRaisesRegex(RuntimeError, err_msg):
                 gradcheck(bad_fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)
 
+    def test_gradcheck_forward_ad_runs_with_no_requires_grad(self):
+        # Currently requires_grad is used as an easy way for gradcheck to know
+        # which inputs of the function are meant to be differentiable.
+        # This test checks that the inputs passed to the function do not have
+        # requires_grad=True, even though they may have requires_grad=True when
+        # passed to gradcheck.
+        class UserFn(Function):
+            @staticmethod
+            def forward(ctx, x, y):
+                if fwAD._current_level >= 0:
+                    self.assertFalse(x.requires_grad)
+                    self.assertFalse(y.requires_grad)
+                return x.clone(), y.clone()
+
+            @staticmethod
+            def jvp(ctx, x_t, y_t):
+                return x_t, y_t
+
+        x = torch.rand(2, dtype=torch.double, requires_grad=True)
+        y = torch.rand(2, dtype=torch.double, requires_grad=True)
+
+        gradcheck(UserFn.apply, (x, y), check_forward_ad=True, check_undefined_grad=False, check_backward_ad=False,
+                  check_batched_grad=False, check_batched_forward_grad=False)
+
+        gradcheck(UserFn.apply, (x, y), check_forward_ad=True, check_undefined_grad=True, check_backward_ad=False,
+                  check_batched_grad=False, check_batched_forward_grad=False)
+
+        gradcheck(UserFn.apply, (x, y), check_forward_ad=True, check_undefined_grad=True, check_backward_ad=False,
+                  check_batched_grad=False, check_batched_forward_grad=True)
+
+        x = torch.rand(2, dtype=torch.double, requires_grad=True)
+        y = torch.rand(2, dtype=torch.double, requires_grad=False)
+        gradcheck(UserFn.apply, (x, y), check_forward_ad=True, check_undefined_grad=True, check_backward_ad=False,
+                  check_batched_grad=False, check_batched_forward_grad=True)
+
+    def test_gradcheck_forward_ad_respects_requires_grad(self):
+        # Currently requires_grad is used as an easy way for gradcheck to know
+        # which inputs of the function are meant to be differentiable.
+        jvp_count = [0]
+
+        class UserFn(Function):
+            @staticmethod
+            def forward(ctx, x, y):
+                return x.clone(), y.clone()
+
+            @staticmethod
+            def jvp(ctx, x_t, y_t):
+                jvp_count[0] += 1
+                return x_t, y_t
+
+        x = torch.rand(2, dtype=torch.double, requires_grad=True)
+        y = torch.rand(2, dtype=torch.double, requires_grad=True)
+        gradcheck(UserFn.apply, (x, y), check_forward_ad=True, check_undefined_grad=False, check_backward_ad=False,
+                  check_batched_grad=False, check_batched_forward_grad=False)
+        self.assertEqual(jvp_count[0], 2)  # (2) once per input
+        jvp_count = [0]
+
+        gradcheck(UserFn.apply, (x, y), check_forward_ad=True, check_undefined_grad=True, check_backward_ad=False,
+                  check_batched_grad=False, check_batched_forward_grad=False)
+        self.assertEqual(jvp_count[0], 6)  # (+4): (once with a materialized zero tangent (+1), once with an efficient zero tangent (+1)) for each input (x2)
+        jvp_count = [0]
+
+        gradcheck(UserFn.apply, (x, y), check_forward_ad=True, check_undefined_grad=True, check_backward_ad=False,
+                  check_batched_grad=False, check_batched_forward_grad=True)
+        self.assertEqual(jvp_count[0], 12)  # (+6): (compute batch of 2 with vmap (+1), with a loop (+2)) for each input (x2)
+        jvp_count = [0]
+
+        # Repeat the previous test, except we mark one input with requires_grad=False
+        # NB: when the function has a single differentiable input, _test_undefined_forward_mode
+        #     only adds (+1), not (+2)! The other counts are halved.
+        x = torch.rand(2, dtype=torch.double, requires_grad=True)
+        y = torch.rand(2, dtype=torch.double, requires_grad=False)
+        gradcheck(UserFn.apply, (x, y), check_forward_ad=True, check_undefined_grad=True, check_backward_ad=False,
+                  check_batched_grad=False, check_batched_forward_grad=True)
+        self.assertEqual(jvp_count[0], 5)  # 1 + 1 + 3
+
     def test_gradcheck_check_forward_or_backward_only(self):
         """Depending on settings for check_forward_ad and check_backward_ad, the
         correct codepaths should be reached (or not reached)
diff --git a/test/test_overrides.py b/test/test_overrides.py
index eb46b99..da013d3 100644
--- a/test/test_overrides.py
+++ b/test/test_overrides.py
@@ -851,6 +851,7 @@
                 'new_zeros',
                 'numel',
                 'requires_grad',
+                'requires_grad_',
                 'retain_grad',
                 'size',
                 'stride',
@@ -867,6 +868,7 @@
                 torch.Tensor.numel,
                 torch.Tensor.retain_grad,
                 torch.Tensor.stride,
+                torch.Tensor.requires_grad_,
                 torch.autograd.grad,
                 torch.add,
             }
diff --git a/torch/autograd/gradcheck.py b/torch/autograd/gradcheck.py
index 1b61d01..3aac0fc 100644
--- a/torch/autograd/gradcheck.py
+++ b/torch/autograd/gradcheck.py
@@ -329,7 +329,7 @@
                 if inp.layout == torch._mkldnn:  # type: ignore[attr-defined]
                     raise ValueError("MKLDNN inputs are not support for forward AD gradcheck.")
 
-                inp = fwAD.make_dual(inp, torch.zeros_like(inp))
+                inp = fwAD.make_dual(inp.detach(), torch.zeros_like(inp))
                 # If inp is a differentiable view, the dual might not be the tangent given to
                 # make_dual, so read it explicitly from the dual tensor
                 fw_grads.append(fwAD.unpack_dual(inp)[1])
@@ -760,10 +760,14 @@
     assert isinstance(inputs, tuple)
 
     for input_idx, current_input in enumerate(inputs):
+        if not (is_tensor_like(current_input) and current_input.requires_grad):
+            continue
+
         def jvp(tangent: torch.Tensor):
             with fwAD.dual_level():
-                dual = fwAD.make_dual(current_input, tangent)
-                inputs_with_dual = tuple(dual if idx == input_idx else inp for idx, inp in enumerate(inputs))
+                dual = fwAD.make_dual(current_input.detach(), tangent)
+                inputs_with_dual = tuple(dual if idx == input_idx else (inp.detach() if is_tensor_like(inp) else inp)
+                                         for idx, inp in enumerate(inputs))
                 dual_outputs = _as_tuple(func(*inputs_with_dual))
                 ret = []
                 for dual_output in dual_outputs:
@@ -888,7 +892,7 @@
                 if inp.layout == torch._mkldnn:  # type: ignore[attr-defined]
                     raise ValueError("MKLDNN inputs are not support for forward AD gradcheck.")
 
-                inp = fwAD.make_dual(inp, torch.zeros_like(inp))
+                inp = fwAD.make_dual(inp.detach(), torch.zeros_like(inp))
                 # If inp is a differentiable view, the dual might not be the tangent given to
                 # make_dual, so read it explicitly from the dual tensor
                 fw_grads.append(fwAD.unpack_dual(inp)[1])
@@ -904,12 +908,12 @@
             dual_inp_obj = dual_inputs[idx]
 
             # case 1 (Materialized Zero Tensor Tangent)
-            dual_inputs[idx] = fwAD.make_dual(inp, torch.zeros_like(inp))
+            dual_inputs[idx] = fwAD.make_dual(inp.detach(), torch.zeros_like(inp))
             raw_outputs = _as_tuple(func(*dual_inputs))
             dual_outputs1 = filter(_is_float_or_complex_tensor, raw_outputs)
 
             # case 2 (Efficient Zero Tensor Tangent since we don't make a dual object and pass a regular tensor)
-            dual_inputs[idx] = inp
+            dual_inputs[idx] = inp.detach()
             raw_outputs = _as_tuple(func(*dual_inputs))
             dual_outputs2 = filter(_is_float_or_complex_tensor, raw_outputs)
 
@@ -1532,13 +1536,18 @@
 
     num_outputs = len(tupled_grad_outputs)
 
+    # NB: We need to save the requires_grad information about the inputs here because gradcheck detaches inputs
+    #     before running forward mode AD
+    diff_input_args_indices = set(i for i, x in enumerate(tupled_inputs) if is_tensor_like(x) and x.requires_grad)
+    diff_grad_output_indices = set(i for i, x in enumerate(tupled_grad_outputs) if x.requires_grad)
+
     def new_func(*args):
-        input_args = args[:-num_outputs]
-        grad_outputs = args[-num_outputs:]
+        # Restore the requires_grad information
+        input_args = tuple(x.requires_grad_() if i in diff_input_args_indices else x for i, x in enumerate(args[:-num_outputs]))
         outputs = _differentiable_outputs(func(*input_args))
-        input_args = tuple(x for x in input_args
-                           if is_tensor_like(x) and x.requires_grad)
-        grad_inputs = torch.autograd.grad(outputs, input_args, grad_outputs, create_graph=True,
+        grad_outputs = tuple(x.requires_grad_() if i in diff_grad_output_indices else x for i, x in enumerate(args[-num_outputs:]))
+        diff_input_args = tuple(x for i, x in enumerate(input_args) if i in diff_input_args_indices)
+        grad_inputs = torch.autograd.grad(outputs, diff_input_args, grad_outputs, create_graph=True,
                                           allow_unused=True)
         grad_inputs = tuple(g for g in grad_inputs if g is not None)
         return grad_inputs