Fix custom function forward AD internal assert (#71531)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71531

Based on the comment above the original internal assert, this is the desired check. Two options were considered for the case where the forward returns an input as-is:
1. Don't error, and automatically make jvp return a view for that tensor output (this is easier than I originally thought: https://github.com/pytorch/pytorch/pull/71531#discussion_r789211877)
2. Error (what the code currently does)
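
For context, below is a minimal sketch of the case in question (it mirrors the test added in this PR; the class and variable names are illustrative): the custom function's forward returns one of its inputs as-is, and with this change the tangent of that output ends up being a view of the input's tangent rather than whatever jvp returns for it.

```python
import torch
import torch.autograd.forward_ad as fwAD

class MyPassthroughFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        # The second output is an input returned as-is.
        return x + y, x

    @staticmethod
    def vjp(ctx, gO1, gO2):
        return gO1 + gO2, gO1

    @staticmethod
    def jvp(ctx, x_t, y_t):
        # The value returned for the second output is ignored when the
        # corresponding forward output is an unmodified input.
        return x_t + y_t, x_t

x = torch.tensor(1., dtype=torch.double, requires_grad=True)
t = torch.tensor(1., dtype=torch.double)
y = torch.tensor(1., dtype=torch.double, requires_grad=True)

with fwAD.dual_level():
    x_dual = fwAD.make_dual(x, t)
    _, out2 = MyPassthroughFn.apply(x_dual, y)
    # The tangent of out2 is a view of the input tangent t.
    print(fwAD.unpack_dual(out2).tangent._base is t)  # expected: True
```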

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D33695399

Pulled By: soulitzer

fbshipit-source-id: dba49890a55ad1dd59ed5c41faa96bf7cfc9e562
(cherry picked from commit fdb0f266f51e939e122676ab378f4cacba4295aa)
diff --git a/test/test_autograd.py b/test/test_autograd.py
index e7c38e0..ecd9c87 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -5576,6 +5576,37 @@
         gradcheck(MyFn.apply, (1, x.requires_grad_(True), 1, y.requires_grad_(True)), check_forward_ad=True,
                   check_backward_ad=False, check_batched_grad=False)
 
+    def test_custom_function_forward_mode_forward_is_no_op(self):
+        for jvp_mul_by_2 in (True, False):
+            class MyFn(torch.autograd.Function):
+                @staticmethod
+                def forward(ctx, x, y):
+                    return x + y, x
+
+                @staticmethod
+                def vjp(ctx, gO1, gO2):
+                    return gO1 + gO2, gO1
+
+                @staticmethod
+                def jvp(ctx, x_t, y_t):
+                    if jvp_mul_by_2:
+                        # If the user returns the input as-is, this result
+                        # isn't used!
+                        return x_t + y_t, x_t * 2
+                    else:
+                        return x_t + y_t, x_t
+
+            x = torch.tensor(1., dtype=torch.double, requires_grad=True)
+            t = torch.tensor(1., dtype=torch.double)
+            y = torch.tensor(1., dtype=torch.double, requires_grad=True)
+
+            with fwAD.dual_level():
+                x_dual = fwAD.make_dual(x, t)
+                _, out2 = MyFn.apply(x_dual, y)
+                self.assertTrue(fwAD.unpack_dual(out2).tangent._base is t)
+
+            gradcheck(MyFn.apply, (x, y), check_forward_ad=True)
+
     def test_custom_function_local_inplace(self):
         class MyFn(torch.autograd.Function):
             @staticmethod
diff --git a/torch/csrc/autograd/custom_function.cpp b/torch/csrc/autograd/custom_function.cpp
index be02659..af39227 100644
--- a/torch/csrc/autograd/custom_function.cpp
+++ b/torch/csrc/autograd/custom_function.cpp
@@ -147,7 +147,12 @@
       }
     } else {
       // At this point, outputs[i] cannot be one of the inputs (raw_outputs[i] might be, but it was changed by the backward code)
-      TORCH_INTERNAL_ASSERT(!is_input);
+      TORCH_INTERNAL_ASSERT(inputs_mapping.count(out.unsafeGetTensorImpl()) == 0);
+      if (is_input && !is_modified) {
+        // If the forward returned an input as-is, the backward code already performed a view of it
+        // outside of the forward no-grad guard, so we are done here.
+        continue;
+      }
 
       if (out.is_view() && impl::get_view_autograd_meta(out)->has_fw_view()) {
         // If the output is a view