Revert "[test] AOTAutograd: support mutations on buffers that happen during th bw (#112906)"
This reverts commit c8974d649d684a33a5c02a0b112a6e0743201d97.
Reverted https://github.com/pytorch/pytorch/pull/112906 on behalf of https://github.com/huydhn due to lots of failures after this change (https://hud.pytorch.org/pytorch/pytorch/commit/c8974d649d684a33a5c02a0b112a6e0743201d97); this is probably a landrace ([comment](https://github.com/pytorch/pytorch/pull/112906#issuecomment-1831016362))
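For context, the pattern the reverted change was meant to support is a custom autograd.Function that updates a saved, non-requires-grad buffer inside its backward (the Float8-style use case mentioned in the removed comments). Below is a minimal sketch, adapted from the test removed in this revert and assuming the `functorch.compile` entry points `aot_function`/`nop` that the test file already uses; whether the compiled call is accepted depends on whether #112906 is applied.

```python
import torch
from functorch.compile import aot_function, nop

class BwMutation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        x.mul_(2)  # mutation of a saved buffer during the backward pass
        return grad_output.clone()

def f(a, b):
    return a * BwMutation.apply(b)

a = torch.ones(3, 3, requires_grad=True)
b = torch.ones(3, 3)  # buffer: requires_grad=False

f_compiled = aot_function(f, nop)
f_compiled(a, b).sum().backward()
```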
diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index e2edfbc..ab310c2 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -1438,91 +1438,6 @@
with self.assertRaisesRegex(AssertionError, "attempted to compile the backward with incorrect subclass metadata"):
self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True, make_inputs_subclasses=True)
- # Mutations in the backward are allowed as long as the mutated object does not require grad
- def test_backward_mutation_data(self):
- class BwMutation(torch.autograd.Function):
- @staticmethod
- def forward(ctx, x):
- ctx.save_for_backward(x)
- return x.clone()
-
- @staticmethod
- def backward(ctx, grad_output):
- x, = ctx.saved_tensors
- # bw mutation
- x.mul_(2)
- return grad_output.clone()
-
- def f(a, b):
- out = BwMutation.apply(b)
- return a * out
-
- inp_no_grad = [
- torch.ones(3, 3, requires_grad=True),
- torch.ones(3, 3, requires_grad=False),
- ]
-
- # Mutation on buffer that does not require grad during the backward is allowed
- self.verify_aot_autograd(f, inp_no_grad, test_mutation=True)
-
- inp_grad = [
- torch.ones(3, 3, requires_grad=True),
- torch.ones(3, 3, requires_grad=True),
- ]
- with self.assertRaisesRegex(AssertionError, "input that requires_grad and was mutated in the backward"):
- self.verify_aot_autograd(f, inp_grad, test_mutation=True)
-
- def test_backward_mutation_metadata(self):
- class BwMutation(torch.autograd.Function):
- @staticmethod
- def forward(ctx, a, b):
- ctx.save_for_backward(b)
- return a.clone(), b.clone()
-
- @staticmethod
- def backward(ctx, grad_a, grad_b):
- b, = ctx.saved_tensors
- # bw metadata mutation
- b.transpose_(1, 0)
- return grad_a.clone(), grad_b.clone()
-
- def f(a, b):
- a_, b_ = BwMutation.apply(a, b)
- out = a_ * b_
- return out
-
- inp_no_grad = [
- torch.ones(3, 3, requires_grad=True),
- torch.ones(3, 3, requires_grad=False),
- ]
-
- with self.assertRaisesRegex(AssertionError, "input that had its metadata mutated in the backward"):
- self.verify_aot_autograd(f, inp_no_grad, test_mutation=True)
-
- def test_backward_mutation_on_grad_out(self):
- class BwMutation(torch.autograd.Function):
- @staticmethod
- def forward(ctx, x):
- return x.clone()
-
- @staticmethod
- def backward(ctx, grad_output):
- grad_output.mul_(2)
- return grad_output.clone()
-
- def f(a, b):
- tmp = a * b
- out = BwMutation.apply(tmp)
- return out
-
- inp_grad = [
- torch.ones(3, 3, requires_grad=True),
- torch.ones(3, 3, requires_grad=True),
- ]
- f_compiled = aot_function(f, nop)
- with self.assertRaisesRegex(AssertionError, "input to the backward that was mutated during the backward"):
- out = f_compiled(*inp_grad)
-
# Partially addresses https://github.com/pytorch/pytorch/issues/106457
def test_input_mutation_false_aliasing(self):
def f(a, b):
diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py
index f1c1d02..1bfd0c4 100644
--- a/torch/_functorch/aot_autograd.py
+++ b/torch/_functorch/aot_autograd.py
@@ -1910,43 +1910,6 @@
# Run the joint
f_outs = fn(*f_args)
- if trace_joint:
- # We support a limited amount of mutation of graph inputs during the backward pass.
- # (This is used e.g. by Float8, which needs to update buffers during the backward pass)
- # Here, we perform extra checks for primals that were mutated in the **backward**
- # We're doing the checks here instead of doing them with the rest of the input mutation handling because:
- # - We need to detect inputs that were mutated in the backward **separately** from mutations that happened
- # during the forward, because the handling is different: some input mutations from the forward
- # can only be handled in a fw-only runtime epilogue, and in theory if we wanted to handle those same
- # types of mutations in the backward we would need a bw-only runtime epilogue.
- # - We could in theory have our analysis pass differentiate mutations in the fw from mutations in
- # the bw by running our analysis first on the fw-only graph, and then on the joint graph. This would
- # require an extra round of tracing though, so it's more efficient to do in-line here.
- assert isinstance(args, tuple) and len(args) == 2 and isinstance(args[0], (list, tuple))
- # Only look at mutations that happened to forward inputs (e.g. fw buffers that were saved for bw)
- primals_before = args[0]
- primals_after = pytree.tree_map(from_fun, f_args[0])
- for before, after, inpt_info in zip(primals_before, primals_after, meta.input_info):
- # Ban metadata mutations on fw inputs during the bw
- if not inpt_info.mutates_metadata:
- assert not was_metadata_updated(before, after), \
- "Found a graph input that had its metadata mutated in the backward. This is not supported"
- # Allow data mutations on fw inputs during the bw, but only if they do not require grad
- # So we can guarantee that we can keep the mutations in the graph
- if was_updated(before, after) and not inpt_info.mutates_data:
- assert not inpt_info.requires_grad, \
- "Found a graph input that requires_grad and was mutated in the backward. This is not supported"
- # Otherwise, put the mutation in the graph
- before.copy_(after)
- # Now that we covered mutations to *forward* inputs during the backward,
- # we also need to cover mutations to *backward-only* inputs during the backward (e.g. mutation to a grad_out).
- # Today, we will just error in all cases of this happening unless someone needs us to support it.
- tangents_before = args[1]
- tangents_after = pytree.tree_map(from_fun, f_args[1])
- for before, after in zip(tangents_before, tangents_after):
- assert not was_metadata_updated(before, after) and not was_updated(before, after), \
- "Found an input to the backward that was mutated during the backward pass. This is not supported"
-
if aot_config.keep_inference_input_mutations:
# Note: This is a bit annoying. There's a layering issue here, where:
# (1) functionalization needs to operate on **synthetic base** inputs, before unpacking them into the "real" inputs.
@@ -2230,7 +2193,7 @@
if n.op == "placeholder":
placeholders.add(n)
if isinstance(n.target, torch._ops.OpOverload):
- if n.target is aten.copy_.default:
+ if n.target is aten.copy_.default and allow_input_mutations:
suffix = True
# Can only copy_ into an input, and can only do so once
assert n.args[0] in placeholders
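
The hunk above is shown truncated; the restored condition only tolerates a trailing aten.copy_ when input mutations are allowed, when it writes into a graph placeholder, and (per the comment) at most once per input. A standalone sketch of that invariant follows; the function name and the `seen` set are illustrative bookkeeping, not the patch's actual implementation.

```python
import torch
from torch import fx

aten = torch.ops.aten

def check_trailing_copies(gm: fx.GraphModule, allow_input_mutations: bool) -> None:
    placeholders = set()
    seen = set()
    for n in gm.graph.nodes:
        if n.op == "placeholder":
            placeholders.add(n)
        if isinstance(n.target, torch._ops.OpOverload):
            if n.target is aten.copy_.default and allow_input_mutations:
                # a trailing copy_ may only write into a graph input ...
                assert n.args[0] in placeholders
                # ... and each input may be mutated at most once
                assert n.args[0] not in seen
                seen.add(n.args[0])
```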