AOTAutograd: support mutations on buffers that happen during the bw (#114953)
Re-land of https://github.com/pytorch/pytorch/pull/112906
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114953
Approved by: https://github.com/zou3519, https://github.com/drisspg
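
A minimal sketch of the newly supported pattern, mirroring the `test_backward_mutation_data` test added below (and using the same `aot_function`/`nop` helpers that the test file uses): a custom autograd.Function mutates a saved buffer in its backward, and AOTAutograd now keeps that mutation in the graph as long as the mutated tensor does not require grad.

```python
import torch
from functorch.compile import aot_function, nop

class BwMutation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # data mutation on a graph input during the backward
        x.mul_(2)
        return grad_output.clone()

def f(a, b):
    return a * BwMutation.apply(b)

a = torch.ones(3, 3, requires_grad=True)
b = torch.ones(3, 3, requires_grad=False)  # the mutated buffer must not require grad

f_compiled = aot_function(f, nop)
f_compiled(a, b).sum().backward()  # the mutation of `b` happens during the backward
```

Metadata mutations on forward inputs during the backward, and any mutation of backward-only inputs (e.g. a grad_output), still raise an assertion; see the checks added to traced_function_transforms.py below.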
diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index 2d57f0e..b27a0f5 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -1463,6 +1463,91 @@
with self.assertRaisesRegex(AssertionError, "attempted to compile the backward with incorrect subclass metadata"):
self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True, make_inputs_subclasses=True)
+ # Mutations in the backward are allowed as long as the mutated object does not require grad
+ def test_backward_mutation_data(self):
+ class BwMutation(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x):
+ ctx.save_for_backward(x)
+ return x.clone()
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ x, = ctx.saved_tensors
+ # bw mutation
+ x.mul_(2)
+ return grad_output.clone()
+
+ def f(a, b):
+ out = BwMutation.apply(b)
+ return a * out
+
+ inp_no_grad = [
+ torch.ones(3, 3, requires_grad=True),
+ torch.ones(3, 3, requires_grad=False),
+ ]
+
+ # Mutation on a buffer that does not require grad during the backward is allowed
+ self.verify_aot_autograd(f, inp_no_grad, test_mutation=True)
+
+ inp_grad = [
+ torch.ones(3, 3, requires_grad=True),
+ torch.ones(3, 3, requires_grad=True),
+ ]
+ with self.assertRaisesRegex(AssertionError, "input that requires_grad and was mutated in the backward"):
+ self.verify_aot_autograd(f, inp_grad, test_mutation=True)
+
+ def test_backward_mutation_metadata(self):
+ class BwMutation(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, a, b):
+ ctx.save_for_backward(b)
+ return a.clone(), b.clone()
+
+ @staticmethod
+ def backward(ctx, grad_a, grad_b):
+ b, = ctx.saved_tensors
+ # bw metadata mutation
+ b.transpose_(1, 0)
+ return grad_a.clone(), grad_b.clone()
+
+ def f(a, b):
+ a_, b_ = BwMutation.apply(a, b)
+ out = a_ * b_
+ return out
+
+ inp_no_grad = [
+ torch.ones(3, 3, requires_grad=True),
+ torch.ones(3, 3, requires_grad=False),
+ ]
+
+ with self.assertRaisesRegex(AssertionError, "input that had its metadata mutated in the backward"):
+ self.verify_aot_autograd(f, inp_no_grad, test_mutation=True)
+
+ def test_backward_mutation_on_grad_out(self):
+ class BwMutation(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x):
+ return x.clone()
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ grad_output.mul_(2)
+ return grad_output.clone()
+
+ def f(a, b):
+ tmp = a * b
+ out = BwMutation.apply(tmp)
+ return out
+
+ inp_grad = [
+ torch.ones(3, 3, requires_grad=True),
+ torch.ones(3, 3, requires_grad=True),
+ ]
+ f_compiled = aot_function(f, nop)
+ with self.assertRaisesRegex(AssertionError, "input to the backward that was mutated during the backward"):
+ out = f_compiled(*inp_grad)
+
# Partially addresses https://github.com/pytorch/pytorch/issues/106457
def test_input_mutation_false_aliasing(self):
def f(a, b):
diff --git a/torch/_functorch/_aot_autograd/traced_function_transforms.py b/torch/_functorch/_aot_autograd/traced_function_transforms.py
index 9d465f6..2dc7ff1 100644
--- a/torch/_functorch/_aot_autograd/traced_function_transforms.py
+++ b/torch/_functorch/_aot_autograd/traced_function_transforms.py
@@ -29,7 +29,14 @@
from .. import config
from .collect_metadata_analysis import run_functionalized_fw_and_collect_metadata
-from .functional_utils import from_fun, is_fun, sync_functional_tensor, to_fun
+from .functional_utils import (
+ from_fun,
+ has_data_mutation,
+ has_metadata_mutation,
+ is_fun,
+ sync_functional_tensor,
+ to_fun,
+)
from .logging_utils import setup_stacktrace_preservation_hooks
from .schemas import (
AOTConfig,
@@ -347,6 +354,56 @@
# Run the joint
f_outs = fn(*f_args)
+ if trace_joint:
+ # We support a limited amount of mutation of graph inputs during the backward pass.
+ # (This is used e.g. by Float8, which needs to update buffers during the backward pass)
+ # Here, we perform extra checks for primals that were mutated in the **backward**
+ # We're doing the checks here instead of doing them with the rest of the input mutation handling because:
+ # - We need to detect inputs that were mutated in the backward **separately** from mutations that happened
+ # during the forward, because the handling is different: some input mutations from the forward
+ # can only be handled in a fw-only runtime epilogue, and in theory if we wanted to handle those same
+ # types of mutations in the backward we would need a bw-only runtime epilogue.
+ # - We could in theory have our analysis pass differentiate mutations in the fw from mutations in
+ # the bw by running our analysis first on the fw-only graph, and then on the joint graph. This would
+ # require an extra round of tracing though, so it's more efficient to do it inline here.
+ assert (
+ isinstance(args, tuple)
+ and len(args) == 2
+ and isinstance(args[0], (list, tuple))
+ )
+ # Only look at mutations that happened to forward inputs (e.g. fw buffers that were saved for bw)
+ primals_before = args[0]
+ primals_after = pytree.tree_map(from_fun, f_args[0])
+ for f_inpt, before, after, inpt_info in zip(
+ f_args[0], primals_before, primals_after, meta.input_info
+ ):
+ # Ban metadata mutations on fw inputs during the bw
+ if not inpt_info.mutates_metadata:
+ assert not has_metadata_mutation(
+ f_inpt, before, check_only_storage_mutation=False
+ ), "Found a graph input that had its metadata mutated in the backward. This is not supported"
+ # Allow data mutations on fw inputs during the bw, but only if they do not require grad,
+ # so that we can guarantee we can keep the mutations in the graph
+ if has_data_mutation(f_inpt) and not inpt_info.mutates_data:
+ assert (
+ not inpt_info.requires_grad
+ ), "Found a graph input that requires_grad and was mutated in the backward. This is not supported"
+ # Otherwise, put the mutation in the graph
+ before.copy_(after)
+ # Now that we've covered mutations to *forward* inputs during the backward,
+ # we also need to cover mutations to *backward-only* inputs during the backward (e.g. mutation to a grad_out).
+ # Today, we will just error in all cases of this happening unless someone needs us to support it.
+ tangents_before = args[1]
+ tangents_after = pytree.tree_map(from_fun, f_args[1])
+ for f_inpt, before, after in zip(
+ f_args[1], tangents_before, tangents_after
+ ):
+ assert not has_metadata_mutation(
+ f_inpt, before, check_only_storage_mutation=False
+ ) and not has_data_mutation(
+ f_inpt
+ ), "Found an input to the backward that was mutated during the backward pass. This is not supported"
+
if aot_config.keep_inference_input_mutations:
# Note: This is a bit annoying. There's a layering issue here, where:
# (1) functionalization needs to operate on **synthetic base** inputs, before unpacking them into the "real" inputs.