[export] add pass to remove auto functionalized hop (#122246)
Summary: Adds a pass that blindly removes the functionalize hop without consideration on if its safe. Useful for ExecuTorch today and other usecases that have additional logic that can reason about when this pass is safe to use
Test Plan: added unit test
Differential Revision: D55103867
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122246
Approved by: https://github.com/angelayi
diff --git a/test/export/test_passes.py b/test/export/test_passes.py
index 32ea8af..ab78127 100644
--- a/test/export/test_passes.py
+++ b/test/export/test_passes.py
@@ -7,8 +7,8 @@
import math
import operator
import unittest
-from typing import List, Set
from re import escape
+from typing import List, Set
import torch
from functorch.experimental.control_flow import cond
@@ -17,22 +17,38 @@
from torch._export.passes.functionalize_side_effectful_ops_pass import (
_FunctionalizeSideEffectfulOpsPass,
)
+from torch._export.passes.replace_set_grad_with_hop_pass import (
+ _is_set_grad_enabled_node,
+ _is_set_grad_enabled_sub_mod,
+)
from torch._export.passes.replace_view_ops_with_view_copy_ops_pass import (
get_view_copy_of_view_op,
is_view_op,
ReplaceViewOpsWithViewCopyOpsPass,
)
+from torch._export.utils import (
+ node_inline_,
+ nodes_count,
+ nodes_filter,
+ nodes_map,
+ sequential_split,
+)
+from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch.export import export
+from torch.export._remove_auto_functionalized_pass import (
+ unsafe_remove_auto_functionalized_pass,
+)
from torch.fx.passes.infra.partitioner import Partition
from torch.fx.passes.operator_support import OperatorSupport
+from torch.library import impl, _scoped_library
from torch.testing import FileCheck
-from torch.testing._internal.common_utils import run_tests, TestCase, skipIfTorchDynamo, IS_WINDOWS
-from torch.utils import _pytree as pytree
-from torch._export.utils import sequential_split, nodes_filter, nodes_map, node_inline_, nodes_count
-from torch._export.passes.replace_set_grad_with_hop_pass import (
- _is_set_grad_enabled_node, _is_set_grad_enabled_sub_mod
+from torch.testing._internal.common_utils import (
+ IS_WINDOWS,
+ run_tests,
+ skipIfTorchDynamo,
+ TestCase,
)
-
+from torch.utils import _pytree as pytree
def count_call_function(graph: torch.fx.Graph, target: torch.ops.OpOverload) -> int:
count = 0
@@ -620,5 +636,96 @@
self.assertEqual(before_str, after_inline_str)
self.assertEqual(gm(*args), new_gm(*args))
+ def test_remove_auto_functionalized_pass(self) -> None:
+ with _scoped_library("DO_NOT_USE_TEST_ONLY", "DEF") as lib:
+
+ lib.define("custom_mutator(Tensor x, Tensor(a!) y) -> Tensor")
+
+ @impl(lib, "custom_mutator", "Meta")
+ def custom_mutator_meta(
+ x: torch.Tensor,
+ y: torch.Tensor,
+ ) -> torch.Tensor:
+ return torch.empty_like(x)
+
+
+ @impl(lib, "custom_mutator", "CompositeExplicitAutograd")
+ def custom_mutator(
+ x: torch.Tensor,
+ y: torch.Tensor,
+ ) -> torch.Tensor:
+ return x + y.add_(1)
+
+ class M(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.register_buffer("state", torch.zeros(1))
+
+ def forward(self, x):
+ return torch.ops.DO_NOT_USE_TEST_ONLY.custom_mutator(x, self.state)
+
+ mod = M()
+ x = torch.randn([3, 3])
+ ep = export(mod, (x,))
+ inplace_ep = unsafe_remove_auto_functionalized_pass(ep)
+ nodes = inplace_ep.graph.nodes
+ for node in nodes:
+ if node.op == "call_function":
+ self.assertFalse(node.target is auto_functionalized)
+ self.assertFalse(node.target is operator.getitem)
+
+ for spec in inplace_ep.graph_signature.output_specs:
+ self.assertFalse("getitem" in spec.arg.name)
+
+ def test_remove_auto_functionalized_pass_tuple(self) -> None:
+ with _scoped_library("DO_NOT_USE_TEST_ONLY", "DEF") as lib:
+
+ lib.define("custom_mutator_tuple(Tensor x, Tensor(a!) y) -> (Tensor, Tensor)")
+
+ @impl(lib, "custom_mutator_tuple", "Meta")
+ def custom_mutator_tuple_meta(
+ x: torch.Tensor,
+ y: torch.Tensor,
+ ):
+ return (torch.empty_like(x), torch.empty_like(x))
+
+
+ @impl(lib, "custom_mutator_tuple", "CompositeExplicitAutograd")
+ def custom_mutator_tuple(
+ x: torch.Tensor,
+ y: torch.Tensor,
+ ):
+ return (x, x + y.add_(1))
+
+ class M(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.register_buffer("state", torch.zeros(1))
+
+ def forward(self, x):
+ return torch.ops.DO_NOT_USE_TEST_ONLY.custom_mutator_tuple(
+ x, self.state
+ )
+
+ mod = M()
+ x = torch.randn([3, 3])
+ ep = export(mod, (x,))
+ inplace_ep = unsafe_remove_auto_functionalized_pass(ep)
+
+ nodes = inplace_ep.graph.nodes
+ getitems = 0
+ for node in nodes:
+ if node.op == "call_function":
+ self.assertFalse(node.target is auto_functionalized)
+ if node.target is operator.getitem:
+ getitems += 1
+ self.assertEqual(getitems, 2) # tuple return of len 2
+
+ out_specs = inplace_ep.graph_signature.output_specs
+ self.assertEqual(out_specs[0].arg.name, "arg0_1") # state
+ self.assertEqual(out_specs[1].arg.name, "getitem") # tuple return 1
+ self.assertEqual(out_specs[2].arg.name, "getitem_1") # tuple return 2
+
+
if __name__ == '__main__':
run_tests()
diff --git a/torch/export/_remove_auto_functionalized_pass.py b/torch/export/_remove_auto_functionalized_pass.py
new file mode 100644
index 0000000..7628041
--- /dev/null
+++ b/torch/export/_remove_auto_functionalized_pass.py
@@ -0,0 +1,93 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import operator
+from typing import List
+
+import torch
+from torch._higher_order_ops.auto_functionalize import (
+ auto_functionalized,
+ get_mutable_arg_names,
+)
+from torch.export import ExportedProgram
+
+
+def unsafe_remove_auto_functionalized_pass(
+ ep: ExportedProgram,
+) -> ExportedProgram:
+ """
+ This pass removes an instances of the higher order op 'auto_functionalized',
+ and modifies the calling EP inplace to have the original mutator op.
+ This pass doesn't perform safety checks to make sure that this inplace mutation is safe.
+ """
+ auto_functionalize_nodes: List[torch.fx.Node] = []
+ for module in ep.graph_module.modules():
+ if not isinstance(module, torch.fx.GraphModule):
+ continue
+ for node in ep.graph.nodes:
+ if node.op == "call_function" and node.target is auto_functionalized:
+ auto_functionalize_nodes.append(node)
+
+ # Update every use of the HOP
+ for node in reversed(auto_functionalize_nodes):
+ func = node.args[0]
+ original_kwargs = node.kwargs
+ assert isinstance(func, torch._ops.OpOverload)
+
+ with ep.graph.inserting_before(node):
+ # This makes the call_function refer to every arg as a kwarg, this is weird but probably fine?
+ new_node = ep.graph.call_function(func, kwargs=node.kwargs)
+ for k, v in node.meta.items():
+ new_node.meta[k] = v
+
+ # Replace auto_functionalize(func, args) with just func(args)
+ node.replace_all_uses_with(new_node)
+
+ mutable_args_names = get_mutable_arg_names(new_node.target)
+ output_specs = ep.graph_signature.output_specs
+
+ # update the users of the auto_func node (the getitem nodes)
+ for user in list(new_node.users.keys()):
+ assert user.target == operator.getitem
+ # getitem corresponding to a mutated input, just replace all uses with the original input
+ if user.args[1] >= len(func._schema.returns):
+ assert user.args[1] <= len(func._schema.returns) + len(
+ mutable_args_names
+ )
+
+ # If the result of getitem was used in an output node, update the output spec with the correct name
+ adusted_index = user.args[1] - len(func._schema.returns)
+ original_arg = original_kwargs[mutable_args_names[adusted_index]]
+ for spec in output_specs:
+ if spec.arg.name == user.name:
+ spec.arg.name = original_arg.name # pyre-ignore
+ break
+
+ # This is a little fragile/implementation dependent, but the order of the mutable args is the same as the order
+ # of the getitem calls following the HOP.
+ user.replace_all_uses_with(
+ original_kwargs[mutable_args_names[adusted_index]]
+ )
+
+ if len(func._schema.returns) == 1:
+ # If the function has 1 return then it will just directly return the
+ # result -- we don't need a getitem. So we can replace all the
+ # getitem(auto_functionalized, 0) with just the note itself.
+ for user in list(new_node.users.keys()):
+ if user.args[1] == 0:
+ user.replace_all_uses_with(new_node)
+
+ # Same case as above, update the output spec if getitem result used in an output node
+ for spec in output_specs:
+ if spec.arg.name == user.name:
+ spec.arg.name = new_node.name
+ break
+
+ new_node.meta["val"] = node.meta["val"][: len(func._schema.returns)]
+ ep.graph.erase_node(node)
+
+ ep.graph.eliminate_dead_code()
+ return ep