Fix custom function forward AD internal assert (#71531)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71531
Based on the comment above the original internal assert, the updated condition is the desired check. For the case where a custom function's forward returns one of its inputs as-is, there were two options:
1. Don't error, and automatically make jvp return a view for that tensor output (this is easier than I originally thought: https://github.com/pytorch/pytorch/pull/71531#discussion_r789211877)
2. Error (the behavior prior to this change)

The diff below implements option 1: the value jvp returns for a pass-through output is ignored, and the output's tangent is instead a view of the input's tangent.
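For illustration, here is a minimal sketch of the pattern in question, mirroring the new test below (`PassThrough` is a hypothetical name, not part of this PR):

```python
import torch
import torch.autograd.forward_ad as fwAD

# Hypothetical custom Function whose forward is a no-op:
# it returns its input tensor as-is.
class PassThrough(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x  # the output *is* the input, unchanged

    @staticmethod
    def vjp(ctx, gO):
        return gO

    @staticmethod
    def jvp(ctx, x_t):
        # Whatever is returned here is effectively ignored for the
        # pass-through output; the engine installs a view of the
        # input's tangent instead.
        return x_t

x = torch.tensor(1., dtype=torch.double, requires_grad=True)
t = torch.tensor(1., dtype=torch.double)
with fwAD.dual_level():
    out = PassThrough.apply(fwAD.make_dual(x, t))
    # Before this fix, apply() hit TORCH_INTERNAL_ASSERT(!is_input);
    # after it, the output's tangent is a view whose base is t.
    assert fwAD.unpack_dual(out).tangent._base is t
```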
Test Plan: Imported from OSS
Reviewed By: albanD
Differential Revision: D33695399
Pulled By: soulitzer
fbshipit-source-id: dba49890a55ad1dd59ed5c41faa96bf7cfc9e562
(cherry picked from commit fdb0f266f51e939e122676ab378f4cacba4295aa)
diff --git a/test/test_autograd.py b/test/test_autograd.py
index e7c38e0..ecd9c87 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -5576,6 +5576,37 @@
         gradcheck(MyFn.apply, (1, x.requires_grad_(True), 1, y.requires_grad_(True)), check_forward_ad=True,
                   check_backward_ad=False, check_batched_grad=False)

+    def test_custom_function_forward_mode_forward_is_no_op(self):
+        for jvp_mul_by_2 in (True, False):
+            class MyFn(torch.autograd.Function):
+                @staticmethod
+                def forward(ctx, x, y):
+                    return x + y, x
+
+                @staticmethod
+                def vjp(ctx, gO1, gO2):
+                    return gO1 + gO2, gO1
+
+                @staticmethod
+                def jvp(ctx, x_t, y_t):
+                    if jvp_mul_by_2:
+                        # If the user returns the input as-is, this result
+                        # isn't used!
+                        return x_t + y_t, x_t * 2
+                    else:
+                        return x_t + y_t, x_t
+
+            x = torch.tensor(1., dtype=torch.double, requires_grad=True)
+            t = torch.tensor(1., dtype=torch.double)
+            y = torch.tensor(1., dtype=torch.double, requires_grad=True)
+
+            with fwAD.dual_level():
+                x_dual = fwAD.make_dual(x, t)
+                _, out2 = MyFn.apply(x_dual, y)
+                self.assertTrue(fwAD.unpack_dual(out2).tangent._base is t)
+
+            gradcheck(MyFn.apply, (x, y), check_forward_ad=True)
+
     def test_custom_function_local_inplace(self):
         class MyFn(torch.autograd.Function):
             @staticmethod
diff --git a/torch/csrc/autograd/custom_function.cpp b/torch/csrc/autograd/custom_function.cpp
index be02659..af39227 100644
--- a/torch/csrc/autograd/custom_function.cpp
+++ b/torch/csrc/autograd/custom_function.cpp
@@ -147,7 +147,12 @@
       }
     } else {
       // At this point, outputs[i] cannot be one of the inputs (raw_outputs[i] might be, but it was changed by the backward code)
-      TORCH_INTERNAL_ASSERT(!is_input);
+      TORCH_INTERNAL_ASSERT(inputs_mapping.count(out.unsafeGetTensorImpl()) == 0);
+      if (is_input && !is_modified) {
+        // If the forward returned an input as-is: since the backward code performed a view
+        // without the forward no-grad guard, the tangent propagates and we are done.
+        continue;
+      }
       if (out.is_view() && impl::get_view_autograd_meta(out)->has_fw_view()) {
         // If the output is a view