[AotAutograd] Move mutations hidden from autograd in graph (#113454)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113454
Approved by: https://github.com/bdhirsh
diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py
index 7cedc6f..30d9f7f 100644
--- a/test/dynamo/test_functions.py
+++ b/test/dynamo/test_functions.py
@@ -1918,6 +1918,12 @@
         # to be in the metadata, so there might be false negatives
         self.assertTrue("aten.copy" not in codes[0])
         self.assertTrue("aten.clone" not in codes[0])
+        # The following checks that there are only the tensor output is in
+        # the compiled graph
+        if dynamic and grad:
+            self.assertTrue("return (buf0, s0, )" in codes[0])
+        else:
+            self.assertTrue("return (buf0, )" in codes[0])
 
     @requires_cuda()
     @skipIfRocm
diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py
index f022c1f..c09ab6b 100644
--- a/torch/_functorch/aot_autograd.py
+++ b/torch/_functorch/aot_autograd.py
@@ -1156,7 +1156,13 @@
                 mutates_metadata=mutates_metadata,
                 mutations_hidden_from_autograd=mutations_hidden_from_autograd,
                 requires_grad=requires_grad,
-                mutation_type=_get_mutation_type(keep_input_mutations, mutates_data, mutates_metadata, requires_grad)
+                mutation_type=_get_mutation_type(
+                    keep_input_mutations,
+                    mutates_data,
+                    mutates_metadata,
+                    mutations_hidden_from_autograd,
+                    requires_grad
+                )
             ))
 
         # If a function involves creating a tensor, and returning a view of it, such that its _base is the intermediate,
@@ -1426,6 +1432,7 @@
                 keep_input_mutations,
                 mutates_data=info.mutates_data,
                 mutates_metadata=info.mutates_metadata,
+                mutations_hidden_from_autograd=info.mutations_hidden_from_autograd,
                 requires_grad=info.requires_grad
                 # MUTATED_OUT_GRAPH corresponds to any input mutations that happen outside the graph.
                 # this can also include metadata mutations, and inputs that do not require grad,
@@ -2295,17 +2302,35 @@
     return actual_aliased_indices
 
 
-def _check_if_mutation_can_be_in_graph(keep_input_mutations: bool, mutates_data, mutates_metadata, requires_grad):
+def _check_if_mutation_can_be_in_graph(
+    keep_input_mutations: bool,
+    mutates_data,
+    mutates_metadata,
+    mutations_hidden_from_autograd,
+    requires_grad
+):
     if keep_input_mutations:
-        return mutates_data and not mutates_metadata and not requires_grad
+        return mutates_data and ((not mutates_metadata and not requires_grad) or mutations_hidden_from_autograd)
     return False
 
 
-def _get_mutation_type(keep_input_mutations: bool, mutates_data, mutates_metadata, requires_grad):
+def _get_mutation_type(
+    keep_input_mutations: bool,
+    mutates_data,
+    mutates_metadata,
+    mutations_hidden_from_autograd,
+    requires_grad
+):
     if (not mutates_data) and (not mutates_metadata):
         return MutationType.NOT_MUTATED
 
-    if _check_if_mutation_can_be_in_graph(keep_input_mutations, mutates_data, mutates_metadata, requires_grad):
+    if _check_if_mutation_can_be_in_graph(
+        keep_input_mutations,
+        mutates_data,
+        mutates_metadata,
+        mutations_hidden_from_autograd,
+        requires_grad
+    ):
         return MutationType.MUTATED_IN_GRAPH
 
     return MutationType.MUTATED_OUT_GRAPH
@@ -2619,6 +2644,14 @@
         mutates_data = True if len(outer_indices) > 1 else m.input_info[outer_indices[0]].mutates_data
         mutates_metadata = False if len(outer_indices) > 1 else m.input_info[outer_indices[0]].mutates_metadata
         requires_grad = any(m.input_info[x].requires_grad for x in outer_indices)
+        mutations_hidden_from_autograd = all(m.input_info[x].mutations_hidden_from_autograd for x in outer_indices)
+        mutation_type = _get_mutation_type(
+            m.keep_input_mutations,
+            mutates_data,
+            mutates_metadata,
+            mutations_hidden_from_autograd,
+            requires_grad
+        )
 
         inpt_info = InputAliasInfo(
             # If len(outer_indices) > 1, then this input is a synthetic base.
@@ -2627,10 +2660,10 @@
             # mutations, they will be hidden from the rest of aot autograd.
             mutates_data=mutates_data,
             mutates_metadata=mutates_metadata,
-            mutations_hidden_from_autograd=all(m.input_info[x].mutations_hidden_from_autograd for x in outer_indices),
+            mutations_hidden_from_autograd=mutations_hidden_from_autograd,
             is_leaf=any_leaf,
             requires_grad=requires_grad,
-            mutation_type=_get_mutation_type(m.keep_input_mutations, mutates_data, mutates_metadata, requires_grad)
+            mutation_type=mutation_type,
         )
         input_infos.append(inpt_info)