Rewrite `unsafe_remove_auto_functionalized_pass` using `decompose_auto_functionalized` (#134831)
`unsafe_remove_auto_functionalized_pass` can be rewritten using `decompose_auto_functionalized`. This way we do not have to update it each time we change `auto_functionalize` (e.g. https://github.com/pytorch/pytorch/pull/134409), and we avoid maintaining duplicate logic implemented in two different ways.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134831
Approved by: https://github.com/zou3519
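
At a high level, the rewritten pass no longer hand-rolls the replacement of each `auto_functionalized` node; it marks every such node to be re-inplaced and defers to Inductor's post-grad decomposition. A condensed sketch of that idea (`_reinplace_all` is an illustrative name, not part of the patch; the full implementation is in the diff below):

```python
import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.fx_passes.post_grad import decompose_auto_functionalized


def _reinplace_all(graph: torch.fx.Graph) -> None:
    # Ask the decomposition to clone nothing, i.e. let every mutation
    # land on the original input tensors ("re-inplace everything").
    for node in graph.nodes:
        if node.op == "call_function" and node.target is auto_functionalized:
            node.meta["only_clone_these_tensors"] = []
    # Reuse the shared Inductor pass instead of a second, hand-written
    # rewrite, then drop trivial aten.copy_(x, x) nodes and dead code.
    decompose_auto_functionalized(graph)
```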
diff --git a/test/export/test_passes.py b/test/export/test_passes.py
index 7bcd50a..624b528 100644
--- a/test/export/test_passes.py
+++ b/test/export/test_passes.py
@@ -1166,20 +1166,19 @@
x = torch.randn([3, 3])
ep = export(mod, (x,))
inplace_ep = unsafe_remove_auto_functionalized_pass(ep)
-
- nodes = inplace_ep.graph.nodes
- getitems = 0
- for node in nodes:
- if node.op == "call_function":
- self.assertFalse(node.target is auto_functionalized)
- if node.target is operator.getitem:
- getitems += 1
- self.assertEqual(getitems, 2) # tuple return of len 2
-
- out_specs = inplace_ep.graph_signature.output_specs
- self.assertEqual(out_specs[0].arg.name, "b_state") # state
- self.assertEqual(out_specs[1].arg.name, "getitem") # tuple return 1
- self.assertEqual(out_specs[2].arg.name, "getitem_1") # tuple return 2
+ graph_text = str(inplace_ep.graph)
+ self.assertExpectedInline(
+ graph_text,
+ """\
+graph():
+ %b_state : [num_users=2] = placeholder[target=b_state]
+ %x : [num_users=1] = placeholder[target=x]
+ %custom_mutator_tuple_default : [num_users=2] = call_function[target=torch.ops.DO_NOT_USE_TEST_ONLY.custom_mutator_tuple.\
+default](args = (%x, %b_state), kwargs = {})
+ %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%custom_mutator_tuple_default, 0), kwargs = {})
+ %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%custom_mutator_tuple_default, 1), kwargs = {})
+ return (b_state, getitem_3, getitem_4)""",
+ )
@unittest.skipIf(not TEST_CUDA, "requires cuda")
def test_move_to_device_pass(self):
diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py
index 9c78d09..c31d7aa 100644
--- a/torch/_inductor/pattern_matcher.py
+++ b/torch/_inductor/pattern_matcher.py
@@ -236,9 +236,13 @@
replacement graph.
"""
- from torch._inductor.virtualized import V
+ from torch._inductor.virtualized import NullHandler, V
- context = V.fake_mode if V.fake_mode is not None else contextlib.nullcontext
+ context = (
+ V.fake_mode
+ if (not isinstance(V.fake_mode, NullHandler) or (V.fake_mode is None))
+ else contextlib.nullcontext()
+ )
with context:
if trace_fn is None:
diff --git a/torch/export/_remove_auto_functionalized_pass.py b/torch/export/_remove_auto_functionalized_pass.py
index 930915f..76e2541 100644
--- a/torch/export/_remove_auto_functionalized_pass.py
+++ b/torch/export/_remove_auto_functionalized_pass.py
@@ -5,64 +5,21 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
-import operator
-from typing import List
import torch
from torch._higher_order_ops.auto_functionalize import (
auto_functionalized,
get_mutable_arg_names,
)
+from torch._inductor.fx_passes.post_grad import decompose_auto_functionalized
from torch.export import ExportedProgram
-def _remove_auto_functionalization_from_graph_helper(ep, auto_functionalize_nodes):
- # Update every use of the HOP
- for node in reversed(auto_functionalize_nodes):
- func = node.args[0]
- original_kwargs = node.kwargs
- assert isinstance(func, torch._ops.OpOverload)
-
- with ep.graph.inserting_before(node):
- # This makes the call_function refer to every arg as a kwarg, this is weird but probably fine?
- new_node = ep.graph.call_function(func, kwargs=node.kwargs)
- for k, v in node.meta.items():
- new_node.meta[k] = v
-
- # Replace auto_functionalize(func, args) with just func(args)
- node.replace_all_uses_with(new_node)
-
- mutable_args_names = get_mutable_arg_names(new_node.target)
-
- # update the users of the auto_func node (the getitem nodes)
- for user in list(new_node.users.keys()):
- assert user.target == operator.getitem
- # getitem corresponding to a mutated input, just replace all uses with the original input
- if user.args[1] >= len(func._schema.returns):
- assert user.args[1] <= len(func._schema.returns) + len(
- mutable_args_names
- )
-
- # If the result of getitem was used in an output node, update the output spec with the correct name
- adjusted_index = user.args[1] - len(func._schema.returns)
- original_arg = original_kwargs[mutable_args_names[adjusted_index]]
-
- # This is a little fragile/implementation dependent, but the order of the mutable args is the same as the order
- # of the getitem calls following the HOP.
- user.replace_all_uses_with(original_arg)
-
- if len(func._schema.returns) == 1:
- # If the function has 1 return then it will just directly return the
- # result -- we don't need a getitem. So we can replace all the
- # getitem(auto_functionalized, 0) with just the note itself.
- for user in list(new_node.users.keys()):
- if user.args[1] == 0:
- user.replace_all_uses_with(new_node)
-
- new_node.meta["val"] = node.meta["val"][: len(func._schema.returns)]
- ep.graph.erase_node(node)
-
- ep.graph.eliminate_dead_code()
+def remove_self_clone(graph: torch.fx.Graph):
+ for node in graph.nodes:
+ if node.target == torch.ops.aten.copy_.default and node.args[0] == node.args[1]:
+ node.replace_all_uses_with(node.args[0])
+ graph.erase_node(node)
def unsafe_remove_auto_functionalized_pass(
@@ -73,15 +30,20 @@
and modifies the calling EP inplace to have the original mutator op.
This pass doesn't perform safety checks to make sure that this inplace mutation is safe.
"""
- auto_functionalize_nodes: List[torch.fx.Node] = []
- for module in ep.graph_module.modules():
- if not isinstance(module, torch.fx.GraphModule):
- continue
- for node in ep.graph.nodes:
- if node.op == "call_function" and node.target is auto_functionalized:
- auto_functionalize_nodes.append(node)
with ep.graph_module._set_replace_hook(ep.graph_signature.get_replace_hook()):
- _remove_auto_functionalization_from_graph_helper(ep, auto_functionalize_nodes)
+ for module in ep.graph_module.modules():
+ if not isinstance(module, torch.fx.GraphModule):
+ continue
+ for node in ep.graph.nodes:
+ if node.op == "call_function" and node.target is auto_functionalized:
+ func = node.args[0]
+ assert isinstance(func, torch._ops.OpOverload)
+ mutable_args_names = get_mutable_arg_names(func)
+ # re-inplace everything
+ node.meta["only_clone_these_tensors"] = []
+ decompose_auto_functionalized(ep.graph)
+ remove_self_clone(ep.graph)
+ ep.graph.eliminate_dead_code()
return ep
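
For reference, a minimal usage sketch of the pass. The custom op and module below are illustrative placeholders (not the `DO_NOT_USE_TEST_ONLY` test ops from the patch) and assume the `torch.library.custom_op` API is available:

```python
import torch
from torch.export import export
from torch.export._remove_auto_functionalized_pass import (
    unsafe_remove_auto_functionalized_pass,
)


# Hypothetical custom op that mutates its second argument in place.
@torch.library.custom_op("demo::add_into", mutates_args={"out"})
def add_into(x: torch.Tensor, out: torch.Tensor) -> None:
    out.add_(x)


@add_into.register_fake
def _(x, out) -> None:
    return None


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("state", torch.zeros(3, 3))

    def forward(self, x):
        torch.ops.demo.add_into(x, self.state)
        return x + self.state


ep = export(M(), (torch.randn(3, 3),))
# Export functionalizes the mutation into an
# auto_functionalized(demo::add_into, ...) node; the pass swaps it back
# for the original mutating op, modifying the ExportedProgram in place.
inplace_ep = unsafe_remove_auto_functionalized_pass(ep)
print(inplace_ep.graph)
```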