[export] add pass to remove auto functionalized hop (#122246)

Summary: Adds a pass that blindly removes the auto_functionalized HOP without checking whether doing so is safe. Useful for ExecuTorch today and for other use cases that have additional logic to reason about when this pass can be applied safely.
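
A minimal usage sketch (the module name `ModelWithMutatingCustomOp` is hypothetical; the pass and its import path are the ones added in this PR):

    import torch
    from torch.export import export
    from torch.export._remove_auto_functionalized_pass import (
        unsafe_remove_auto_functionalized_pass,
    )

    # ModelWithMutatingCustomOp stands in for any module whose exported graph
    # contains the auto_functionalized HOP, e.g. because it calls a custom op
    # that mutates one of its inputs (see the unit tests in this PR).
    ep = export(ModelWithMutatingCustomOp(), (torch.randn(3, 3),))

    # The caller must know that re-introducing the in-place mutation is safe;
    # the pass itself performs no such checks.
    inplace_ep = unsafe_remove_auto_functionalized_pass(ep)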

Test Plan: added unit tests

Differential Revision: D55103867

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122246
Approved by: https://github.com/angelayi
diff --git a/test/export/test_passes.py b/test/export/test_passes.py
index 32ea8af..ab78127 100644
--- a/test/export/test_passes.py
+++ b/test/export/test_passes.py
@@ -7,8 +7,8 @@
 import math
 import operator
 import unittest
-from typing import List, Set
 from re import escape
+from typing import List, Set
 
 import torch
 from functorch.experimental.control_flow import cond
@@ -17,22 +17,38 @@
 from torch._export.passes.functionalize_side_effectful_ops_pass import (
     _FunctionalizeSideEffectfulOpsPass,
 )
+from torch._export.passes.replace_set_grad_with_hop_pass import (
+    _is_set_grad_enabled_node,
+    _is_set_grad_enabled_sub_mod,
+)
 from torch._export.passes.replace_view_ops_with_view_copy_ops_pass import (
     get_view_copy_of_view_op,
     is_view_op,
     ReplaceViewOpsWithViewCopyOpsPass,
 )
+from torch._export.utils import (
+    node_inline_,
+    nodes_count,
+    nodes_filter,
+    nodes_map,
+    sequential_split,
+)
+from torch._higher_order_ops.auto_functionalize import auto_functionalized
 from torch.export import export
+from torch.export._remove_auto_functionalized_pass import (
+    unsafe_remove_auto_functionalized_pass,
+)
 from torch.fx.passes.infra.partitioner import Partition
 from torch.fx.passes.operator_support import OperatorSupport
+from torch.library import impl, _scoped_library
 from torch.testing import FileCheck
-from torch.testing._internal.common_utils import run_tests, TestCase, skipIfTorchDynamo, IS_WINDOWS
-from torch.utils import _pytree as pytree
-from torch._export.utils import sequential_split, nodes_filter, nodes_map, node_inline_, nodes_count
-from torch._export.passes.replace_set_grad_with_hop_pass import (
-    _is_set_grad_enabled_node, _is_set_grad_enabled_sub_mod
+from torch.testing._internal.common_utils import (
+    IS_WINDOWS,
+    run_tests,
+    skipIfTorchDynamo,
+    TestCase,
 )
-
+from torch.utils import _pytree as pytree
 
 def count_call_function(graph: torch.fx.Graph, target: torch.ops.OpOverload) -> int:
     count = 0
@@ -620,5 +636,96 @@
             self.assertEqual(before_str, after_inline_str)
             self.assertEqual(gm(*args), new_gm(*args))
 
+    def test_remove_auto_functionalized_pass(self) -> None:
+        with _scoped_library("DO_NOT_USE_TEST_ONLY", "DEF") as lib:
+
+            lib.define("custom_mutator(Tensor x, Tensor(a!) y) -> Tensor")
+
+            @impl(lib, "custom_mutator", "Meta")
+            def custom_mutator_meta(
+                x: torch.Tensor,
+                y: torch.Tensor,
+            ) -> torch.Tensor:
+                return torch.empty_like(x)
+
+
+            @impl(lib, "custom_mutator", "CompositeExplicitAutograd")
+            def custom_mutator(
+                x: torch.Tensor,
+                y: torch.Tensor,
+            ) -> torch.Tensor:
+                return x + y.add_(1)
+
+            class M(torch.nn.Module):
+                def __init__(self):
+                    super().__init__()
+                    self.register_buffer("state", torch.zeros(1))
+
+                def forward(self, x):
+                    return torch.ops.DO_NOT_USE_TEST_ONLY.custom_mutator(x, self.state)
+
+            mod = M()
+            x = torch.randn([3, 3])
+            ep = export(mod, (x,))
+            inplace_ep = unsafe_remove_auto_functionalized_pass(ep)
+            nodes = inplace_ep.graph.nodes
+            for node in nodes:
+                if node.op == "call_function":
+                    self.assertFalse(node.target is auto_functionalized)
+                    self.assertFalse(node.target is operator.getitem)
+
+            for spec in inplace_ep.graph_signature.output_specs:
+                self.assertFalse("getitem" in spec.arg.name)
+
+    def test_remove_auto_functionalized_pass_tuple(self) -> None:
+        with _scoped_library("DO_NOT_USE_TEST_ONLY", "DEF") as lib:
+
+            lib.define("custom_mutator_tuple(Tensor x, Tensor(a!) y) -> (Tensor, Tensor)")
+
+            @impl(lib, "custom_mutator_tuple", "Meta")
+            def custom_mutator_tuple_meta(
+                x: torch.Tensor,
+                y: torch.Tensor,
+            ):
+                return (torch.empty_like(x), torch.empty_like(x))
+
+
+            @impl(lib, "custom_mutator_tuple", "CompositeExplicitAutograd")
+            def custom_mutator_tuple(
+                x: torch.Tensor,
+                y: torch.Tensor,
+            ):
+                return (x, x + y.add_(1))
+
+            class M(torch.nn.Module):
+                def __init__(self):
+                    super().__init__()
+                    self.register_buffer("state", torch.zeros(1))
+
+                def forward(self, x):
+                    return torch.ops.DO_NOT_USE_TEST_ONLY.custom_mutator_tuple(
+                        x, self.state
+                    )
+
+            mod = M()
+            x = torch.randn([3, 3])
+            ep = export(mod, (x,))
+            inplace_ep = unsafe_remove_auto_functionalized_pass(ep)
+
+            nodes = inplace_ep.graph.nodes
+            getitems = 0
+            for node in nodes:
+                if node.op == "call_function":
+                    self.assertFalse(node.target is auto_functionalized)
+                    if node.target is operator.getitem:
+                        getitems += 1
+            self.assertEqual(getitems, 2)  # tuple return of len 2
+
+            out_specs = inplace_ep.graph_signature.output_specs
+            self.assertEqual(out_specs[0].arg.name, "arg0_1")  # state
+            self.assertEqual(out_specs[1].arg.name, "getitem")  # tuple return 1
+            self.assertEqual(out_specs[2].arg.name, "getitem_1")  # tuple return 2
+
+
 if __name__ == '__main__':
     run_tests()
diff --git a/torch/export/_remove_auto_functionalized_pass.py b/torch/export/_remove_auto_functionalized_pass.py
new file mode 100644
index 0000000..7628041
--- /dev/null
+++ b/torch/export/_remove_auto_functionalized_pass.py
@@ -0,0 +1,93 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import operator
+from typing import List
+
+import torch
+from torch._higher_order_ops.auto_functionalize import (
+    auto_functionalized,
+    get_mutable_arg_names,
+)
+from torch.export import ExportedProgram
+
+
+def unsafe_remove_auto_functionalized_pass(
+    ep: ExportedProgram,
+) -> ExportedProgram:
+    """
+    This pass removes instances of the higher-order op 'auto_functionalized'
+    and modifies the calling ExportedProgram in place to call the original
+    mutating op directly. It performs no checks that the in-place mutation is safe.
+    """
+    auto_functionalize_nodes: List[torch.fx.Node] = []
+    for module in ep.graph_module.modules():
+        if not isinstance(module, torch.fx.GraphModule):
+            continue
+        for node in module.graph.nodes:
+            if node.op == "call_function" and node.target is auto_functionalized:
+                auto_functionalize_nodes.append(node)
+
+    # Update every use of the HOP
+    for node in reversed(auto_functionalize_nodes):
+        func = node.args[0]
+        original_kwargs = node.kwargs
+        assert isinstance(func, torch._ops.OpOverload)
+
+        with node.graph.inserting_before(node):
+            # auto_functionalized stores every argument as a kwarg, so the rebuilt call passes them all as kwargs too.
+            new_node = node.graph.call_function(func, kwargs=node.kwargs)
+        for k, v in node.meta.items():
+            new_node.meta[k] = v
+
+        # Replace auto_functionalized(func, args) with a direct call func(args)
+        node.replace_all_uses_with(new_node)
+
+        mutable_args_names = get_mutable_arg_names(new_node.target)
+        output_specs = ep.graph_signature.output_specs
+
+        # update the users of the auto_func node (the getitem nodes)
+        for user in list(new_node.users.keys()):
+            assert user.target == operator.getitem
+            # A getitem index past the op's schema returns corresponds to a mutated input; replace its uses with that input
+            if user.args[1] >= len(func._schema.returns):
+                assert user.args[1] < len(func._schema.returns) + len(
+                    mutable_args_names
+                )
+
+                # If the result of getitem was used in an output node, update the output spec with the correct name
+                adjusted_index = user.args[1] - len(func._schema.returns)
+                original_arg = original_kwargs[mutable_args_names[adjusted_index]]
+                for spec in output_specs:
+                    if spec.arg.name == user.name:
+                        spec.arg.name = original_arg.name  # pyre-ignore
+                        break
+
+                # This is a little fragile/implementation dependent, but the order of the mutable args is the same as the order
+                # of the getitem calls following the HOP.
+                user.replace_all_uses_with(
+                    original_kwargs[mutable_args_names[adjusted_index]]
+                )
+
+        if len(func._schema.returns) == 1:
+            # If the function has a single return, callers use the result
+            # directly rather than through a getitem, so replace every
+            # getitem(auto_functionalized, 0) with the new node itself.
+            for user in list(new_node.users.keys()):
+                if user.args[1] == 0:
+                    user.replace_all_uses_with(new_node)
+
+                    # Same case as above, update the output spec if getitem result used in an output node
+                    for spec in output_specs:
+                        if spec.arg.name == user.name:
+                            spec.arg.name = new_node.name
+                            break
+
+        new_node.meta["val"] = node.meta["val"][: len(func._schema.returns)]
+        node.graph.erase_node(node)
+
+    ep.graph.eliminate_dead_code()
+    return ep