Gradcheck forward AD respects requires_grad but runs with requires_grad=False (#72309)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72309
Fixes: https://github.com/pytorch/pytorch/issues/72113
Test Plan: Imported from OSS
Reviewed By: albanD
Differential Revision: D33991570
Pulled By: soulitzer
fbshipit-source-id: 610de162e9848d2d3b12e0fb039860fd9dee844f
(cherry picked from commit a7ecb13610a4e01d91a2ecff107ac1cf6cd94cba)
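
In short: gradcheck's forward-AD checks now detach inputs before wrapping them in dual tensors, so the function under test sees requires_grad=False on its inputs, while gradcheck still uses the original requires_grad flags to decide which inputs are differentiable. Native ops that previously gated their differentiable path only on requires_grad() (e.g. linalg_eigvalsh, _embedding_bag, max_pool1d, mkldnn pooling) now also check for a defined forward grad. A minimal sketch of the resulting behavior, adapted from the new test in test_autograd.py below (the Identity name is illustrative):

import torch
import torch.autograd.forward_ad as fwAD
from torch.autograd import Function, gradcheck

class Identity(Function):
    @staticmethod
    def forward(ctx, x):
        # Inside gradcheck's forward-AD path the input is detached, so
        # requires_grad is False here even though the tensor passed to
        # gradcheck had requires_grad=True.
        if fwAD._current_level >= 0:
            assert not x.requires_grad
        return x.clone()

    @staticmethod
    def jvp(ctx, x_t):
        return x_t

x = torch.rand(2, dtype=torch.double, requires_grad=True)
gradcheck(Identity.apply, (x,), check_forward_ad=True, check_backward_ad=False,
          check_undefined_grad=False, check_batched_grad=False,
          check_batched_forward_grad=False)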
diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp
index ddf5f44..225985d 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -792,6 +792,11 @@
}
}
+bool _requires_fw_or_bw_grad(const Tensor& input) {
+ return ((at::GradMode::is_enabled() && input.requires_grad())
+ || input._fw_grad(/*level */ 0).defined());
+}
+
// Below of the definitions of the functions operating on a batch that are going to be dispatched
// in the main helper functions for the linear algebra operations
@@ -2382,7 +2387,7 @@
Tensor linalg_eigvalsh(const Tensor& input, c10::string_view uplo) {
// if input requires grad we must compute the eigenvectors to make this function differentiable
// the eigenvectors are not exposed to the user
- if (at::GradMode::is_enabled() && input.requires_grad()) {
+ if (_requires_fw_or_bw_grad(input)) {
Tensor values;
std::tie(values, std::ignore) = at::linalg_eigh(input, uplo);
return values;
@@ -2878,7 +2883,7 @@
Tensor linalg_eigvals(const Tensor& input) {
// if input requires grad we must compute the eigenvectors to make this function differentiable
// the eigenvectors are not exposed to the user
- if (at::GradMode::is_enabled() && input.requires_grad()) {
+ if (_requires_fw_or_bw_grad(input)) {
return std::get<0>(at::linalg_eig(input));
}
@@ -3063,10 +3068,7 @@
}
Tensor linalg_svdvals(const Tensor& A) {
- const bool A_requires_grad = (at::GradMode::is_enabled() && A.requires_grad())
- || A._fw_grad(/*level */ 0).defined()
- || isTensorSubclassLike(A);
- return std::get<1>(at::_linalg_svd(A, /*full_matrices=*/false, /*comptue_uv=*/A_requires_grad));
+ return std::get<1>(at::_linalg_svd(A, /*full_matrices=*/false, /*comptue_uv=*/_requires_fw_or_bw_grad(A)));
}
std::tuple<Tensor&, Tensor&, Tensor&> svd_out(const Tensor& self, bool some, bool compute_uv, Tensor& U, Tensor& S, Tensor& V) {
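
Note on the hunk above: with gradcheck detaching inputs, requires_grad() alone no longer signals that a differentiable path is needed, so _requires_fw_or_bw_grad also checks for a defined forward grad. From Python, the effect is that an op like torch.linalg.eigvalsh takes its differentiable (eigh-based) path when the input carries a forward-mode tangent even though the primal has requires_grad=False. A rough sketch (illustrative; symmetric input assumed):

import torch
import torch.autograd.forward_ad as fwAD

a = torch.rand(3, 3, dtype=torch.double)
a = a + a.transpose(-2, -1)            # symmetric primal, requires_grad=False
da = torch.rand(3, 3, dtype=torch.double)
da = da + da.transpose(-2, -1)         # symmetric tangent

with fwAD.dual_level():
    dual = fwAD.make_dual(a, da)
    w = torch.linalg.eigvalsh(dual)    # must hit the differentiable path
    _, w_t = fwAD.unpack_dual(w)       # forward grad of the eigenvalues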
diff --git a/aten/src/ATen/native/EmbeddingBag.cpp b/aten/src/ATen/native/EmbeddingBag.cpp
index 42fab2e..e6f88f5 100644
--- a/aten/src/ATen/native/EmbeddingBag.cpp
+++ b/aten/src/ATen/native/EmbeddingBag.cpp
@@ -798,7 +798,7 @@
padding_idx = maybe_wrap_dim(padding_idx, weight.size(0));
}
std::tuple<Tensor, Tensor, Tensor, Tensor> out;
- if (!weight.requires_grad()) {
+ if (!weight.requires_grad() && !weight._fw_grad(/*level=*/0).defined()) {
out = at::_embedding_bag_forward_only(
weight, indices.contiguous(), offsets.contiguous(), scale_grad_by_freq,
mode, sparse, per_sample_weights, include_last_offset, padding_idx);
diff --git a/aten/src/ATen/native/MaxPooling.cpp b/aten/src/ATen/native/MaxPooling.cpp
index 53f2a10..3e615d7 100644
--- a/aten/src/ATen/native/MaxPooling.cpp
+++ b/aten/src/ATen/native/MaxPooling.cpp
@@ -102,6 +102,7 @@
self, kernel_size, stride, padding, dilation, ceil_mode);
}
if ((self.requires_grad() && at::GradMode::is_enabled()) ||
+ self._fw_grad(/*level */ 0).defined() ||
!self.device().is_cpu()) {
// Needs indices for grad and with_indices defines CUDA dispatch
return std::get<0>(at::max_pool1d_with_indices(
diff --git a/aten/src/ATen/native/mkldnn/Pooling.cpp b/aten/src/ATen/native/mkldnn/Pooling.cpp
index fac656f..5800bd2 100644
--- a/aten/src/ATen/native/mkldnn/Pooling.cpp
+++ b/aten/src/ATen/native/mkldnn/Pooling.cpp
@@ -252,7 +252,7 @@
// for inference, don't need the indices, set aprop_kind to prop_kind::forward_inference
// can reduce the memory use.
if (ideep::algorithm::pooling_max == algo
- && !(input.requires_grad() && at::GradMode::is_enabled())) {
+ && !((input.requires_grad() && at::GradMode::is_enabled()) || input._fw_grad(/*level */ 0).defined())) {
aprop_kind = ideep::prop_kind::forward_inference;
}
diff --git a/test/test_autograd.py b/test/test_autograd.py
index ad77b0c..2fe584a 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -4064,6 +4064,82 @@
with self.assertRaisesRegex(RuntimeError, err_msg):
gradcheck(bad_fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)
+ def test_gradcheck_forward_ad_runs_with_no_requires_grad(self):
+ # Currently requires_grad is used as an easy way for gradcheck to know
+ # which inputs of the function are meant to be differentiable.
+ # This test checks that the inputs passed to the function do not have
+ # requires_grad=True, even though they may have requires_grad=True when
+ # passed to gradcheck.
+ class UserFn(Function):
+ @staticmethod
+ def forward(ctx, x, y):
+ if fwAD._current_level >= 0:
+ self.assertFalse(x.requires_grad)
+ self.assertFalse(y.requires_grad)
+ return x.clone(), y.clone()
+
+ @staticmethod
+ def jvp(ctx, x_t, y_t):
+ return x_t, y_t
+
+ x = torch.rand(2, dtype=torch.double, requires_grad=True)
+ y = torch.rand(2, dtype=torch.double, requires_grad=True)
+
+ gradcheck(UserFn.apply, (x, y), check_forward_ad=True, check_undefined_grad=False, check_backward_ad=False,
+ check_batched_grad=False, check_batched_forward_grad=False)
+
+ gradcheck(UserFn.apply, (x, y), check_forward_ad=True, check_undefined_grad=True, check_backward_ad=False,
+ check_batched_grad=False, check_batched_forward_grad=False)
+
+ gradcheck(UserFn.apply, (x, y), check_forward_ad=True, check_undefined_grad=True, check_backward_ad=False,
+ check_batched_grad=False, check_batched_forward_grad=True)
+
+ x = torch.rand(2, dtype=torch.double, requires_grad=True)
+ y = torch.rand(2, dtype=torch.double, requires_grad=False)
+ gradcheck(UserFn.apply, (x, y), check_forward_ad=True, check_undefined_grad=True, check_backward_ad=False,
+ check_batched_grad=False, check_batched_forward_grad=True)
+
+ def test_gradcheck_forward_ad_respects_requires_grad(self):
+ # Currently requires_grad is used as an easy way for gradcheck to know
+ # which inputs of the function are meant to be differentiable.
+ jvp_count = [0]
+
+ class UserFn(Function):
+ @staticmethod
+ def forward(ctx, x, y):
+ return x.clone(), y.clone()
+
+ @staticmethod
+ def jvp(ctx, x_t, y_t):
+ jvp_count[0] += 1
+ return x_t, y_t
+
+ x = torch.rand(2, dtype=torch.double, requires_grad=True)
+ y = torch.rand(2, dtype=torch.double, requires_grad=True)
+ gradcheck(UserFn.apply, (x, y), check_forward_ad=True, check_undefined_grad=False, check_backward_ad=False,
+ check_batched_grad=False, check_batched_forward_grad=False)
+ self.assertEqual(jvp_count[0], 2) # (2) once per input
+ jvp_count = [0]
+
+ gradcheck(UserFn.apply, (x, y), check_forward_ad=True, check_undefined_grad=True, check_backward_ad=False,
+ check_batched_grad=False, check_batched_forward_grad=False)
+ self.assertEqual(jvp_count[0], 6) # (+4): (once with normal ZT (+1), once with efficient ZT (+1)) for each input (x2)
+ jvp_count = [0]
+
+ gradcheck(UserFn.apply, (x, y), check_forward_ad=True, check_undefined_grad=True, check_backward_ad=False,
+ check_batched_grad=False, check_batched_forward_grad=True)
+ self.assertEqual(jvp_count[0], 12) # (+6): (compute batch of 2 with vmap (+1), with a loop (+2)) for each input (x2)
+ jvp_count = [0]
+
+ # Repeat the previous test, except we mark one input with requires_grad=False.
+ # NB: _test_undefined_forward_mode contributes only (+1) when the function has a single
+ # differentiable input, not (+2)! The other counts are halved.
+ x = torch.rand(2, dtype=torch.double, requires_grad=True)
+ y = torch.rand(2, dtype=torch.double, requires_grad=False)
+ gradcheck(UserFn.apply, (x, y), check_forward_ad=True, check_undefined_grad=True, check_backward_ad=False,
+ check_batched_grad=False, check_batched_forward_grad=True)
+ self.assertEqual(jvp_count[0], 5) # 1 + 1 + 3
+
def test_gradcheck_check_forward_or_backward_only(self):
"""Depending on settings for check_forward_ad and check_backward_ad, the
correct codepaths should be reached (or not reached)
diff --git a/test/test_overrides.py b/test/test_overrides.py
index eb46b99..da013d3 100644
--- a/test/test_overrides.py
+++ b/test/test_overrides.py
@@ -851,6 +851,7 @@
'new_zeros',
'numel',
'requires_grad',
+ 'requires_grad_',
'retain_grad',
'size',
'stride',
@@ -867,6 +868,7 @@
torch.Tensor.numel,
torch.Tensor.retain_grad,
torch.Tensor.stride,
+ torch.Tensor.requires_grad_,
torch.autograd.grad,
torch.add,
}
diff --git a/torch/autograd/gradcheck.py b/torch/autograd/gradcheck.py
index 1b61d01..3aac0fc 100644
--- a/torch/autograd/gradcheck.py
+++ b/torch/autograd/gradcheck.py
@@ -329,7 +329,7 @@
if inp.layout == torch._mkldnn: # type: ignore[attr-defined]
raise ValueError("MKLDNN inputs are not support for forward AD gradcheck.")
- inp = fwAD.make_dual(inp, torch.zeros_like(inp))
+ inp = fwAD.make_dual(inp.detach(), torch.zeros_like(inp))
# If inp is a differentiable view, the dual might not be the tangent given to
# make_dual, so read it explicitly from the dual tensor
fw_grads.append(fwAD.unpack_dual(inp)[1])
@@ -760,10 +760,14 @@
assert isinstance(inputs, tuple)
for input_idx, current_input in enumerate(inputs):
+ if not (is_tensor_like(current_input) and current_input.requires_grad):
+ continue
+
def jvp(tangent: torch.Tensor):
with fwAD.dual_level():
- dual = fwAD.make_dual(current_input, tangent)
- inputs_with_dual = tuple(dual if idx == input_idx else inp for idx, inp in enumerate(inputs))
+ dual = fwAD.make_dual(current_input.detach(), tangent)
+ inputs_with_dual = tuple(dual if idx == input_idx else (inp.detach() if is_tensor_like(inp) else inp)
+ for idx, inp in enumerate(inputs))
dual_outputs = _as_tuple(func(*inputs_with_dual))
ret = []
for dual_output in dual_outputs:
@@ -888,7 +892,7 @@
if inp.layout == torch._mkldnn: # type: ignore[attr-defined]
raise ValueError("MKLDNN inputs are not support for forward AD gradcheck.")
- inp = fwAD.make_dual(inp, torch.zeros_like(inp))
+ inp = fwAD.make_dual(inp.detach(), torch.zeros_like(inp))
# If inp is a differentiable view, the dual might not be the tangent given to
# make_dual, so read it explicitly from the dual tensor
fw_grads.append(fwAD.unpack_dual(inp)[1])
@@ -904,12 +908,12 @@
dual_inp_obj = dual_inputs[idx]
# case 1 (Materialized Zero Tensor Tangent)
- dual_inputs[idx] = fwAD.make_dual(inp, torch.zeros_like(inp))
+ dual_inputs[idx] = fwAD.make_dual(inp.detach(), torch.zeros_like(inp))
raw_outputs = _as_tuple(func(*dual_inputs))
dual_outputs1 = filter(_is_float_or_complex_tensor, raw_outputs)
# case 2 (Efficient Zero Tensor Tangent since we don't make a dual object and pass a regular tensor)
- dual_inputs[idx] = inp
+ dual_inputs[idx] = inp.detach()
raw_outputs = _as_tuple(func(*dual_inputs))
dual_outputs2 = filter(_is_float_or_complex_tensor, raw_outputs)
@@ -1532,13 +1536,18 @@
num_outputs = len(tupled_grad_outputs)
+ # NB: We need to save the requires_grad information about the inputs here because gradcheck detaches inputs
+ # before running forward mode AD
+ diff_input_args_indices = set(i for i, x in enumerate(tupled_inputs) if is_tensor_like(x) and x.requires_grad)
+ diff_grad_output_indices = set(i for i, x in enumerate(tupled_grad_outputs) if x.requires_grad)
+
def new_func(*args):
- input_args = args[:-num_outputs]
- grad_outputs = args[-num_outputs:]
+ # Restore the requires_grad information
+ input_args = tuple(x.requires_grad_() if i in diff_input_args_indices else x for i, x in enumerate(args[:-num_outputs]))
outputs = _differentiable_outputs(func(*input_args))
- input_args = tuple(x for x in input_args
- if is_tensor_like(x) and x.requires_grad)
- grad_inputs = torch.autograd.grad(outputs, input_args, grad_outputs, create_graph=True,
+ grad_outputs = tuple(x.requires_grad_() if i in diff_grad_output_indices else x for i, x in enumerate(args[-num_outputs:]))
+ diff_input_args = tuple(x for i, x in enumerate(input_args) if i in diff_input_args_indices)
+ grad_inputs = torch.autograd.grad(outputs, diff_input_args, grad_outputs, create_graph=True,
allow_unused=True)
grad_inputs = tuple(g for g in grad_inputs if g is not None)
return grad_inputs
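
For reference, the requires_grad bookkeeping added to new_func above reduces to the following pattern (a standalone sketch; the wrapper name and argument handling are illustrative, not the actual gradcheck code):

import torch

def restore_requires_grad(fn, example_inputs):
    # Record up front which tensor inputs are meant to be differentiable,
    # because gradcheck later detaches inputs for its forward-AD checks
    # and detaching drops the requires_grad flag.
    diff_idx = {i for i, t in enumerate(example_inputs)
                if isinstance(t, torch.Tensor) and t.requires_grad}

    def wrapped(*args):
        # Re-mark the recorded inputs in place before calling fn.
        args = tuple(a.requires_grad_() if i in diff_idx else a
                     for i, a in enumerate(args))
        return fn(*args)

    return wrapped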