Rewrite `unsafe_remove_auto_functionalized_pass` using `decompose_auto_functionalized` (#134831)

`unsafe_remove_auto_functionalized_pass` can be rewritten on top of `decompose_auto_functionalized`. This way we do not have to update it every time `auto_functionalize` changes (e.g. https://github.com/pytorch/pytorch/pull/134409), and we avoid maintaining the same logic in two different implementations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134831
Approved by: https://github.com/zou3519
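
For context, a minimal usage sketch of the pass being rewritten here. The custom op `mylib::accumulate` and the toy module are hypothetical, made up for illustration only (they are not part of this PR); the real entry points are `torch.library.custom_op`, `torch.export.export`, and `unsafe_remove_auto_functionalized_pass` itself.

```python
import torch
from torch.export import export
from torch.export._remove_auto_functionalized_pass import (
    unsafe_remove_auto_functionalized_pass,
)

# Hypothetical custom op that mutates one of its arguments; export()
# functionalizes such calls into the auto_functionalized HOP.
@torch.library.custom_op("mylib::accumulate", mutates_args={"state"})
def accumulate(state: torch.Tensor, x: torch.Tensor) -> None:
    state.add_(x)

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("state", torch.zeros(3))

    def forward(self, x):
        accumulate(self.state, x)
        return x + 1

ep = export(M(), (torch.randn(3),))
# Before the pass: the graph contains auto_functionalized(mylib.accumulate.default, ...)
# plus getitem nodes for the functionalized outputs.
ep = unsafe_remove_auto_functionalized_pass(ep)
# After the pass: the graph calls mylib.accumulate.default directly, so the mutation
# happens in place -- which is why the pass is labeled "unsafe".
print(ep.graph)
```
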
diff --git a/test/export/test_passes.py b/test/export/test_passes.py
index 7bcd50a..624b528 100644
--- a/test/export/test_passes.py
+++ b/test/export/test_passes.py
@@ -1166,20 +1166,19 @@
             x = torch.randn([3, 3])
             ep = export(mod, (x,))
             inplace_ep = unsafe_remove_auto_functionalized_pass(ep)
-
-            nodes = inplace_ep.graph.nodes
-            getitems = 0
-            for node in nodes:
-                if node.op == "call_function":
-                    self.assertFalse(node.target is auto_functionalized)
-                    if node.target is operator.getitem:
-                        getitems += 1
-            self.assertEqual(getitems, 2)  # tuple return of len 2
-
-            out_specs = inplace_ep.graph_signature.output_specs
-            self.assertEqual(out_specs[0].arg.name, "b_state")  # state
-            self.assertEqual(out_specs[1].arg.name, "getitem")  # tuple return 1
-            self.assertEqual(out_specs[2].arg.name, "getitem_1")  # tuple return 2
+            graph_text = str(inplace_ep.graph)
+            self.assertExpectedInline(
+                graph_text,
+                """\
+graph():
+    %b_state : [num_users=2] = placeholder[target=b_state]
+    %x : [num_users=1] = placeholder[target=x]
+    %custom_mutator_tuple_default : [num_users=2] = call_function[target=torch.ops.DO_NOT_USE_TEST_ONLY.custom_mutator_tuple.\
+default](args = (%x, %b_state), kwargs = {})
+    %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%custom_mutator_tuple_default, 0), kwargs = {})
+    %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%custom_mutator_tuple_default, 1), kwargs = {})
+    return (b_state, getitem_3, getitem_4)""",
+            )
 
     @unittest.skipIf(not TEST_CUDA, "requires cuda")
     def test_move_to_device_pass(self):
diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py
index 9c78d09..c31d7aa 100644
--- a/torch/_inductor/pattern_matcher.py
+++ b/torch/_inductor/pattern_matcher.py
@@ -236,9 +236,13 @@
                 replacement graph.
 
         """
-        from torch._inductor.virtualized import V
+        from torch._inductor.virtualized import NullHandler, V
 
-        context = V.fake_mode if V.fake_mode is not None else contextlib.nullcontext
+        context = (
+            V.fake_mode
+            if (not isinstance(V.fake_mode, NullHandler) or (V.fake_mode is None))
+            else contextlib.nullcontext()
+        )
 
         with context:
             if trace_fn is None:
diff --git a/torch/export/_remove_auto_functionalized_pass.py b/torch/export/_remove_auto_functionalized_pass.py
index 930915f..76e2541 100644
--- a/torch/export/_remove_auto_functionalized_pass.py
+++ b/torch/export/_remove_auto_functionalized_pass.py
@@ -5,64 +5,21 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-import operator
-from typing import List
 
 import torch
 from torch._higher_order_ops.auto_functionalize import (
     auto_functionalized,
     get_mutable_arg_names,
 )
+from torch._inductor.fx_passes.post_grad import decompose_auto_functionalized
 from torch.export import ExportedProgram
 
 
-def _remove_auto_functionalization_from_graph_helper(ep, auto_functionalize_nodes):
-    # Update every use of the HOP
-    for node in reversed(auto_functionalize_nodes):
-        func = node.args[0]
-        original_kwargs = node.kwargs
-        assert isinstance(func, torch._ops.OpOverload)
-
-        with ep.graph.inserting_before(node):
-            # This makes the call_function refer to every arg as a kwarg, this is weird but probably fine?
-            new_node = ep.graph.call_function(func, kwargs=node.kwargs)
-        for k, v in node.meta.items():
-            new_node.meta[k] = v
-
-        # Replace auto_functionalize(func, args) with just func(args)
-        node.replace_all_uses_with(new_node)
-
-        mutable_args_names = get_mutable_arg_names(new_node.target)
-
-        # update the users of the auto_func node (the getitem nodes)
-        for user in list(new_node.users.keys()):
-            assert user.target == operator.getitem
-            # getitem corresponding to a mutated input, just replace all uses with the original input
-            if user.args[1] >= len(func._schema.returns):
-                assert user.args[1] <= len(func._schema.returns) + len(
-                    mutable_args_names
-                )
-
-                # If the result of getitem was used in an output node, update the output spec with the correct name
-                adjusted_index = user.args[1] - len(func._schema.returns)
-                original_arg = original_kwargs[mutable_args_names[adjusted_index]]
-
-                # This is a little fragile/implementation dependent, but the order of the mutable args is the same as the order
-                # of the getitem calls following the HOP.
-                user.replace_all_uses_with(original_arg)
-
-        if len(func._schema.returns) == 1:
-            # If the function has 1 return then it will just directly return the
-            # result -- we don't need a getitem. So we can replace all the
-            # getitem(auto_functionalized, 0) with just the note itself.
-            for user in list(new_node.users.keys()):
-                if user.args[1] == 0:
-                    user.replace_all_uses_with(new_node)
-
-        new_node.meta["val"] = node.meta["val"][: len(func._schema.returns)]
-        ep.graph.erase_node(node)
-
-    ep.graph.eliminate_dead_code()
+def remove_self_clone(graph: torch.fx.Graph):
+    for node in graph.nodes:
+        if node.target == torch.ops.aten.copy_.default and node.args[0] == node.args[1]:
+            node.replace_all_uses_with(node.args[0])
+            graph.erase_node(node)
 
 
 def unsafe_remove_auto_functionalized_pass(
@@ -73,15 +30,20 @@
     and modifies the calling EP inplace to have the original mutator op.
     This pass doesn't perform safety checks to make sure that this inplace mutation is safe.
     """
-    auto_functionalize_nodes: List[torch.fx.Node] = []
-    for module in ep.graph_module.modules():
-        if not isinstance(module, torch.fx.GraphModule):
-            continue
-        for node in ep.graph.nodes:
-            if node.op == "call_function" and node.target is auto_functionalized:
-                auto_functionalize_nodes.append(node)
 
     with ep.graph_module._set_replace_hook(ep.graph_signature.get_replace_hook()):
-        _remove_auto_functionalization_from_graph_helper(ep, auto_functionalize_nodes)
+        for module in ep.graph_module.modules():
+            if not isinstance(module, torch.fx.GraphModule):
+                continue
+            for node in ep.graph.nodes:
+                if node.op == "call_function" and node.target is auto_functionalized:
+                    func = node.args[0]
+                    assert isinstance(func, torch._ops.OpOverload)
+                    mutable_args_names = get_mutable_arg_names(func)
+                    # re-inplace everything
+                    node.meta["only_clone_these_tensors"] = []
+            decompose_auto_functionalized(ep.graph)
+            remove_self_clone(ep.graph)
+            ep.graph.eliminate_dead_code()
 
     return ep