Revert "[test] AOTAutograd: support mutations on buffers that happen during the bw (#112906)"

This reverts commit c8974d649d684a33a5c02a0b112a6e0743201d97.

Reverted https://github.com/pytorch/pytorch/pull/112906 on behalf of https://github.com/huydhn due to lots of failures after this change (https://hud.pytorch.org/pytorch/pytorch/commit/c8974d649d684a33a5c02a0b112a6e0743201d97); this is probably a land race ([comment](https://github.com/pytorch/pytorch/pull/112906#issuecomment-1831016362))
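
The reverted tests exercise the pattern below: a custom autograd.Function whose backward mutates a tensor that was saved in the forward, which the reverted change allowed only when the mutated input does not require grad. A minimal sketch, assuming the functorch.compile.aot_function / nop entry points used in test_aotdispatch.py:

```python
# Minimal sketch of the backward-mutation pattern covered by the reverted tests.
# Assumes functorch.compile.aot_function / nop, as used in test_aotdispatch.py.
import torch
from functorch.compile import aot_function, nop

class BwMutation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        x.mul_(2)  # data mutation that happens during the backward pass
        return grad_output.clone()

def f(a, b):
    return a * BwMutation.apply(b)

a = torch.ones(3, 3, requires_grad=True)
b = torch.ones(3, 3, requires_grad=False)  # the mutated buffer must not require grad
out = aot_function(f, nop)(a, b)
out.sum().backward()
```
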
diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index e2edfbc..ab310c2 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -1438,91 +1438,6 @@
         with self.assertRaisesRegex(AssertionError, "attempted to compile the backward with incorrect subclass metadata"):
             self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True, make_inputs_subclasses=True)
 
-    # Mutations in the backward are allowed as long as the mutated object does not require grad
-    def test_backward_mutation_data(self):
-        class BwMutation(torch.autograd.Function):
-            @staticmethod
-            def forward(ctx, x):
-                ctx.save_for_backward(x)
-                return x.clone()
-
-            @staticmethod
-            def backward(ctx, grad_output):
-                x, = ctx.saved_tensors
-                # bw mutation
-                x.mul_(2)
-                return grad_output.clone()
-
-        def f(a, b):
-            out = BwMutation.apply(b)
-            return a * out
-
-        inp_no_grad = [
-            torch.ones(3, 3, requires_grad=True),
-            torch.ones(3, 3, requires_grad=False),
-        ]
-
-        # Mutation on a buffer that does not require grad during the backward is allowed
-        self.verify_aot_autograd(f, inp_no_grad, test_mutation=True)
-
-        inp_grad = [
-            torch.ones(3, 3, requires_grad=True),
-            torch.ones(3, 3, requires_grad=True),
-        ]
-        with self.assertRaisesRegex(AssertionError, "input that requires_grad and was mutated in the backward"):
-            self.verify_aot_autograd(f, inp_grad, test_mutation=True)
-
-    def test_backward_mutation_metadata(self):
-        class BwMutation(torch.autograd.Function):
-            @staticmethod
-            def forward(ctx, a, b):
-                ctx.save_for_backward(b)
-                return a.clone(), b.clone()
-
-            @staticmethod
-            def backward(ctx, grad_a, grad_b):
-                b, = ctx.saved_tensors
-                # bw metadata mutation
-                b.transpose_(1, 0)
-                return grad_a.clone(), grad_b.clone()
-
-        def f(a, b):
-            a_, b_ = BwMutation.apply(a, b)
-            out = a_ * b_
-            return out
-
-        inp_no_grad = [
-            torch.ones(3, 3, requires_grad=True),
-            torch.ones(3, 3, requires_grad=False),
-        ]
-
-        with self.assertRaisesRegex(AssertionError, "input that had its metadata mutated in the backward"):
-            self.verify_aot_autograd(f, inp_no_grad, test_mutation=True)
-
-    def test_backward_mutation_on_grad_out(self):
-        class BwMutation(torch.autograd.Function):
-            @staticmethod
-            def forward(ctx, x):
-                return x.clone()
-
-            @staticmethod
-            def backward(ctx, grad_output):
-                grad_output.mul_(2)
-                return grad_output.clone()
-
-        def f(a, b):
-            tmp = a * b
-            out = BwMutation.apply(tmp)
-            return out
-
-        inp_grad = [
-            torch.ones(3, 3, requires_grad=True),
-            torch.ones(3, 3, requires_grad=True),
-        ]
-        f_compiled = aot_function(f, nop)
-        with self.assertRaisesRegex(AssertionError, "input to the backward that was mutated during the backward"):
-            out = f_compiled(*inp_grad)
-
     # Partially addresses https://github.com/pytorch/pytorch/issues/106457
     def test_input_mutation_false_aliasing(self):
         def f(a, b):
diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py
index f1c1d02..1bfd0c4 100644
--- a/torch/_functorch/aot_autograd.py
+++ b/torch/_functorch/aot_autograd.py
@@ -1910,43 +1910,6 @@
             # Run the joint
             f_outs = fn(*f_args)
 
-        if trace_joint:
-            # We support a limited amount of mutation of graph inputs during the backward pass.
-            # (This is used e.g. by Float8, which needs to update buffers during the backward pass)
-            # Here, we perform extra checks for primals that were mutated in the **backward**
-            # We're doing the checks here instead of doing them with the rest of the input mutation handling because:
-            # - We need to detect inputs that were mutated in the backward **separately** from mutations that happened
-            #   during the forward, because the handling is different: some input mutations from the forward
-            #   can only be handled in a fw-only runtime epilogue, and in theory if we wanted to handle those same
-            #   types of mutations in the backward we would need a bw-only runtime epilogue.
-            # - We could in theory have our analysis pass differentiate mutations in the fw from mutations in
-            #   the bw by running our analysis first on the fw-only graph, and then on the joint graph. This would
-            #   require an extra round of tracing though, so it's more efficient to do in-line here.
-            assert isinstance(args, tuple) and len(args) == 2 and isinstance(args[0], (list, tuple))
-            # Only look at mutations that happened to forward inputs (e.g. fw buffers that were saved for bw)
-            primals_before = args[0]
-            primals_after = pytree.tree_map(from_fun, f_args[0])
-            for before, after, inpt_info in zip(primals_before, primals_after, meta.input_info):
-                # Ban metadata mutations on fw inputs during the bw
-                if not inpt_info.mutates_metadata:
-                    assert not was_metadata_updated(before, after), \
-                        "Found a graph input that had its metadata mutated in the backward. This is not supported"
-                # Allow data mutations on fw inputs during the bw, but only if they do not require grad
-                # So we can guarantee that we can keep the mutations in the graph
-                if was_updated(before, after) and not inpt_info.mutates_data:
-                    assert not inpt_info.requires_grad, \
-                        "Found a graph input that requires_grad and was mutated in the backward. This is not supported"
-                    # Otherwise, put the mutation in the graph
-                    before.copy_(after)
-            # Now that we covered mutations to *forward* inputs during the backward,
-            # we also need to cover mutations to *backward-only* inputs during the backward (e.g. mutation to a grad_out).
-            # Today, we will just error in all cases of this happening unless someone needs us to support it.
-            tangents_before = args[1]
-            tangents_after = pytree.tree_map(from_fun, f_args[1])
-            for before, after in zip(tangents_before, tangents_after):
-                assert not was_metadata_updated(before, after) and not was_updated(before, after), \
-                    "Found an input to the backward that was mutated during the backward pass. This is not supported"
-
         if aot_config.keep_inference_input_mutations:
             # Note: This is a bit annoying. There's a layering issue here, where:
             # (1) functionalization needs to operate on **synthetic base** inputs, before unpacking them into the "real" inputs.
@@ -2230,7 +2193,7 @@
         if n.op == "placeholder":
             placeholders.add(n)
         if isinstance(n.target, torch._ops.OpOverload):
-            if n.target is aten.copy_.default:
+            if n.target is aten.copy_.default and allow_input_mutations:
                 suffix = True
                 # Can only copy_ into an input, and can only do so once
                 assert n.args[0] in placeholders