[AotAutograd] Move mutations hidden from autograd in graph (#113454)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113454
Approved by: https://github.com/bdhirsh
diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py
index 7cedc6f..30d9f7f 100644
--- a/test/dynamo/test_functions.py
+++ b/test/dynamo/test_functions.py
@@ -1918,6 +1918,12 @@
# to be in the metadata, so there might be false negatives
self.assertTrue("aten.copy" not in codes[0])
self.assertTrue("aten.clone" not in codes[0])
+ # The following checks that there are only the tensor output is in
+ # the compiled graph
+ if dynamic and grad:
+ self.assertTrue("return (buf0, s0, )" in codes[0])
+ else:
+ self.assertTrue("return (buf0, )" in codes[0])
@requires_cuda()
@skipIfRocm
diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py
index f022c1f..c09ab6b 100644
--- a/torch/_functorch/aot_autograd.py
+++ b/torch/_functorch/aot_autograd.py
@@ -1156,7 +1156,13 @@
mutates_metadata=mutates_metadata,
mutations_hidden_from_autograd=mutations_hidden_from_autograd,
requires_grad=requires_grad,
- mutation_type=_get_mutation_type(keep_input_mutations, mutates_data, mutates_metadata, requires_grad)
+ mutation_type=_get_mutation_type(
+ keep_input_mutations,
+ mutates_data,
+ mutates_metadata,
+ mutations_hidden_from_autograd,
+ requires_grad
+ )
))
# If a function involves creating a tensor, and returning a view of it, such that its _base is the intermediate,
@@ -1426,6 +1432,7 @@
keep_input_mutations,
mutates_data=info.mutates_data,
mutates_metadata=info.mutates_metadata,
+ mutations_hidden_from_autograd=info.mutations_hidden_from_autograd,
requires_grad=info.requires_grad
# MUTATED_OUT_GRAPH corresponds to any input mutations that happen outside the graph.
# this can also include metadata mutations, and inputs that do not require grad,
@@ -2295,17 +2302,35 @@
return actual_aliased_indices
-def _check_if_mutation_can_be_in_graph(keep_input_mutations: bool, mutates_data, mutates_metadata, requires_grad):
+def _check_if_mutation_can_be_in_graph(
+ keep_input_mutations: bool,
+ mutates_data,
+ mutates_metadata,
+ mutations_hidden_from_autograd,
+ requires_grad
+):
if keep_input_mutations:
- return mutates_data and not mutates_metadata and not requires_grad
+ return mutates_data and ((not mutates_metadata and not requires_grad) or mutations_hidden_from_autograd)
return False
-def _get_mutation_type(keep_input_mutations: bool, mutates_data, mutates_metadata, requires_grad):
+def _get_mutation_type(
+ keep_input_mutations: bool,
+ mutates_data,
+ mutates_metadata,
+ mutations_hidden_from_autograd,
+ requires_grad
+):
if (not mutates_data) and (not mutates_metadata):
return MutationType.NOT_MUTATED
- if _check_if_mutation_can_be_in_graph(keep_input_mutations, mutates_data, mutates_metadata, requires_grad):
+ if _check_if_mutation_can_be_in_graph(
+ keep_input_mutations,
+ mutates_data,
+ mutates_metadata,
+ mutations_hidden_from_autograd,
+ requires_grad
+ ):
return MutationType.MUTATED_IN_GRAPH
return MutationType.MUTATED_OUT_GRAPH
@@ -2619,6 +2644,14 @@
mutates_data = True if len(outer_indices) > 1 else m.input_info[outer_indices[0]].mutates_data
mutates_metadata = False if len(outer_indices) > 1 else m.input_info[outer_indices[0]].mutates_metadata
requires_grad = any(m.input_info[x].requires_grad for x in outer_indices)
+ mutations_hidden_from_autograd = all(m.input_info[x].mutations_hidden_from_autograd for x in outer_indices)
+ mutation_type = _get_mutation_type(
+ m.keep_input_mutations,
+ mutates_data,
+ mutates_metadata,
+ mutations_hidden_from_autograd,
+ requires_grad
+ )
inpt_info = InputAliasInfo(
# If len(outer_indices) > 1, then this input is a synthetic base.
@@ -2627,10 +2660,10 @@
# mutations, they will be hidden from the rest of aot autograd.
mutates_data=mutates_data,
mutates_metadata=mutates_metadata,
- mutations_hidden_from_autograd=all(m.input_info[x].mutations_hidden_from_autograd for x in outer_indices),
+ mutations_hidden_from_autograd=mutations_hidden_from_autograd,
is_leaf=any_leaf,
requires_grad=requires_grad,
- mutation_type=_get_mutation_type(m.keep_input_mutations, mutates_data, mutates_metadata, requires_grad)
+ mutation_type=mutation_type,
)
input_infos.append(inpt_info)