AOTAutograd: support mutations on buffers that happen during the bw (#114953)

Re-land of https://github.com/pytorch/pytorch/pull/112906

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114953
Approved by: https://github.com/zou3519, https://github.com/drisspg
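
As a rough, illustrative sketch (not part of the diff below) of the kind of pattern this is meant to enable: a Float8-style use case where a buffer that does not require grad is updated during the backward. It reuses the `functorch.compile.aot_function`/`nop` helpers that the tests below use; the `RecordGradAmax` Function and the `stats` buffer are hypothetical names, not APIs from this PR.

```python
import torch
from functorch.compile import aot_function, nop


class RecordGradAmax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, stats):
        ctx.save_for_backward(stats)
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        (stats,) = ctx.saved_tensors
        # Data mutation of a graph input during the backward. This PR allows it
        # because `stats` does not require grad, so the mutation can be kept in
        # the traced graph.
        stats.copy_(grad_output.abs().amax())
        return grad_output, None


def f(x, stats):
    return RecordGradAmax.apply(x, stats).sum()


x = torch.ones(4, requires_grad=True)
stats = torch.zeros(())  # buffer-like input, requires_grad=False
aot_function(f, nop)(x, stats).backward()
# `stats` should now hold the amax of the incoming gradient. If it required
# grad instead, the new assertions below would reject the mutation.
```

If the mutated tensor requires grad, if its metadata is mutated, or if a backward-only input (a grad_out) is mutated, the checks added in traced_function_transforms.py raise an assertion; the tests below cover each case.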
diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index 2d57f0e..b27a0f5 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -1463,6 +1463,91 @@
         with self.assertRaisesRegex(AssertionError, "attempted to compile the backward with incorrect subclass metadata"):
             self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True, make_inputs_subclasses=True)
 
+    # Mutations in the backward are allowed as long as the mutated object does not require grad
+    def test_backward_mutation_data(self):
+        class BwMutation(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x):
+                ctx.save_for_backward(x)
+                return x.clone()
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                x, = ctx.saved_tensors
+                # bw mutation
+                x.mul_(2)
+                return grad_output.clone()
+
+        def f(a, b):
+            out = BwMutation.apply(b)
+            return a * out
+
+        inp_no_grad = [
+            torch.ones(3, 3, requires_grad=True),
+            torch.ones(3, 3, requires_grad=False),
+        ]
+
+        # Mutating a buffer that does not require grad during the backward is allowed
+        self.verify_aot_autograd(f, inp_no_grad, test_mutation=True)
+
+        inp_grad = [
+            torch.ones(3, 3, requires_grad=True),
+            torch.ones(3, 3, requires_grad=True),
+        ]
+        with self.assertRaisesRegex(AssertionError, "input that requires_grad and was mutated in the backward"):
+            self.verify_aot_autograd(f, inp_grad, test_mutation=True)
+
+    def test_backward_mutation_metadata(self):
+        class BwMutation(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, a, b):
+                ctx.save_for_backward(b)
+                return a.clone(), b.clone()
+
+            @staticmethod
+            def backward(ctx, grad_a, grad_b):
+                b, = ctx.saved_tensors
+                # bw metadata mutation
+                b.transpose_(1, 0)
+                return grad_a.clone(), grad_b.clone()
+
+        def f(a, b):
+            a_, b_ = BwMutation.apply(a, b)
+            out = a_ * b_
+            return out
+
+        inp_no_grad = [
+            torch.ones(3, 3, requires_grad=True),
+            torch.ones(3, 3, requires_grad=False),
+        ]
+
+        with self.assertRaisesRegex(AssertionError, "input that had its metadata mutated in the backward"):
+            self.verify_aot_autograd(f, inp_no_grad, test_mutation=True)
+
+    def test_backward_mutation_on_grad_out(self):
+        class BwMutation(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x):
+                return x.clone()
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                grad_output.mul_(2)
+                return grad_output.clone()
+
+        def f(a, b):
+            tmp = a * b
+            out = BwMutation.apply(tmp)
+            return out
+
+        inp_grad = [
+            torch.ones(3, 3, requires_grad=True),
+            torch.ones(3, 3, requires_grad=True),
+        ]
+        f_compiled = aot_function(f, nop)
+        with self.assertRaisesRegex(AssertionError, "input to the backward that was mutated during the backward"):
+            out = f_compiled(*inp_grad)
+
     # Partially addresses https://github.com/pytorch/pytorch/issues/106457
     def test_input_mutation_false_aliasing(self):
         def f(a, b):
diff --git a/torch/_functorch/_aot_autograd/traced_function_transforms.py b/torch/_functorch/_aot_autograd/traced_function_transforms.py
index 9d465f6..2dc7ff1 100644
--- a/torch/_functorch/_aot_autograd/traced_function_transforms.py
+++ b/torch/_functorch/_aot_autograd/traced_function_transforms.py
@@ -29,7 +29,14 @@
 
 from .. import config
 from .collect_metadata_analysis import run_functionalized_fw_and_collect_metadata
-from .functional_utils import from_fun, is_fun, sync_functional_tensor, to_fun
+from .functional_utils import (
+    from_fun,
+    has_data_mutation,
+    has_metadata_mutation,
+    is_fun,
+    sync_functional_tensor,
+    to_fun,
+)
 from .logging_utils import setup_stacktrace_preservation_hooks
 from .schemas import (
     AOTConfig,
@@ -347,6 +354,56 @@
             # Run the joint
             f_outs = fn(*f_args)
 
+        if trace_joint:
+            # We support a limited amount of mutation of graph inputs during the backward pass.
+            # (This is used e.g. by Float8, which needs to update buffers during the backward pass)
+            # Here, we perform extra checks for primals that were mutated in the **backward**
+            # We're doing the checks here instead of doing them with the rest of the input mutation handling because:
+            # - We need to detect inputs that were mutated in the backward **separately** from mutations that happened
+            #   during the forward, because the handling is different: some input mutations from the forward
+            #   can only be handled in a fw-only runtime epilogue, and in theory if we wanted to handle those same
+            #   types of mutations in the backward we would need a bw-only runtime epilogue.
+            # - We could in theory have our analysis pass differentiate mutations in the fw from mutations in
+            #   the bw by running our analysis first on the fw-only graph, and then on the joint graph. This would
+            #   require an extra round of tracing though, so it's more efficient to do the checks inline here.
+            assert (
+                isinstance(args, tuple)
+                and len(args) == 2
+                and isinstance(args[0], (list, tuple))
+            )
+            # Only look at mutations that happened to forward inputs (e.g. fw buffers that were saved for bw)
+            primals_before = args[0]
+            primals_after = pytree.tree_map(from_fun, f_args[0])
+            for f_inpt, before, after, inpt_info in zip(
+                f_args[0], primals_before, primals_after, meta.input_info
+            ):
+                # Ban metadata mutations on fw inputs during the bw
+                if not inpt_info.mutates_metadata:
+                    assert not has_metadata_mutation(
+                        f_inpt, before, check_only_storage_mutation=False
+                    ), "Found a graph input that had its metadata mutated in the backward. This is not supported"
+                # Allow data mutations on fw inputs during the bw, but only if they do not require grad,
+                # so we can guarantee that the mutations can be kept in the graph
+                if has_data_mutation(f_inpt) and not inpt_info.mutates_data:
+                    assert (
+                        not inpt_info.requires_grad
+                    ), "Found a graph input that requires_grad and was mutated in the backward. This is not supported"
+                    # Otherwise, put the mutation in the graph
+                    before.copy_(after)
+            # Now that we've covered mutations to *forward* inputs during the backward,
+            # we also need to cover mutations to *backward-only* inputs during the backward (e.g. mutating a grad_out).
+            # For now, we just error in all cases where this happens, unless someone needs us to support it.
+            tangents_before = args[1]
+            tangents_after = pytree.tree_map(from_fun, f_args[1])
+            for f_inpt, before, after in zip(
+                f_args[1], tangents_before, tangents_after
+            ):
+                assert not has_metadata_mutation(
+                    f_inpt, before, check_only_storage_mutation=False
+                ) and not has_data_mutation(
+                    f_inpt
+                ), "Found an input to the backward that was mutated during the backward pass. This is not supported"
+
         if aot_config.keep_inference_input_mutations:
             # Note: This is a bit annoying. There's a layering issue here, where:
             # (1) functionalization needs to operate on **synthetic base** inputs, before unpacking them into the "real" inputs.