[aotinductor] Add example_value metadata to nodes (#112415)
The split_cat fx passes expect `example_value` metadata on every node. However, the graph module produced by `_export_to_torch_ir` does not carry this metadata, so the split_cat fx passes never run. I added a `FakeTensorProp` pass that populates this metadata on every node in the graph.
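For context, the pre-grad split_cat passes read each node's `example_value` to reason about concrete shapes before deciding whether a split/cat pair can be merged. A minimal sketch of that kind of consumer (hypothetical helper, not the actual pass) looks roughly like:

import torch
import torch.fx

def split_nodes_with_known_shapes(gm: torch.fx.GraphModule):
    # Hypothetical sketch: collect torch.split calls whose input shape is
    # recoverable from the `example_value` metadata that FakeTensorProp fills in.
    found = []
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target is torch.split:
            inp = node.args[0]
            if isinstance(inp, torch.fx.Node):
                example = inp.meta.get("example_value")  # missing before this change
                if example is not None:
                    found.append((node, tuple(example.shape)))
    return found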
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112415
Approved by: https://github.com/frank-wei
diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py
index 91d8089..2f0b639 100644
--- a/test/inductor/test_aot_inductor.py
+++ b/test/inductor/test_aot_inductor.py
@@ -11,6 +11,7 @@
import torch._inductor
import torch.fx._pytree as fx_pytree
from torch._dynamo.testing import same
+from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.exc import CppWrapperCodeGenError
from torch._inductor.utils import aot_inductor_launcher
@@ -202,6 +203,21 @@
)
self.check_model(Model(), example_inputs)
+ def test_simple_split(self):
+ class Model(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ return torch.cat(tensors=torch.split(x, 4, dim=1), dim=-2)
+
+ example_inputs = (torch.randn(2, 8, device=self.device),)
+ counters.clear()
+ self.check_model(Model(), example_inputs)
+ self.assertEqual(counters["inductor"]["scmerge_split_removed"], 1)
+ self.assertEqual(counters["inductor"]["scmerge_cat_removed"], 1)
+ self.assertEqual(counters["inductor"]["scmerge_split_sections_removed"], 1)
+
def test_small_constant(self):
class Model(torch.nn.Module):
def __init__(self):
diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py
index ad955cf..891290f 100644
--- a/torch/_export/__init__.py
+++ b/torch/_export/__init__.py
@@ -947,6 +947,10 @@
# We want to export to Torch IR here to utilize the pre_grad passes in
# inductor, which run on Torch IR.
gm = _export_to_torch_ir(f, args, kwargs, constraints)
+
+ from torch._export.passes.fake_tensor_prop import FakeTensorProp
+ FakeTensorProp(gm).run()
+
flat_example_inputs = pytree.arg_tree_leaves(*args, **kwargs or {})
with torch.no_grad():
diff --git a/torch/_export/passes/fake_tensor_prop.py b/torch/_export/passes/fake_tensor_prop.py
new file mode 100644
index 0000000..b49ca15
--- /dev/null
+++ b/torch/_export/passes/fake_tensor_prop.py
@@ -0,0 +1,18 @@
+from torch.fx.interpreter import Interpreter
+
+
+class FakeTensorProp(Interpreter):
+ def run(self):
+ inp = tuple(
+ node.meta["val"]
+ for node in self.module.graph.nodes
+ if node.op == "placeholder"
+ )
+ super().run(*inp)
+
+ def run_node(self, node):
+ res = super().run_node(node)
+ # split_cat fx passes expect "example_value" metadata on the nodes
+ node.meta["example_value"] = res
+ node.meta["val"] = res
+ return res