[aotinductor] Add example_value metadata to nodes (#112415)

split_cat fx passes expect the `example_value` metadata on every node. However, the graph module produced by `_export_to_torch_ir` does not contain this metadata, causing the split_cat fx passes to not run. So, I added a pass that attaches this metadata to every node in the graph.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112415
Approved by: https://github.com/frank-wei
diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py
index 91d8089..2f0b639 100644
--- a/test/inductor/test_aot_inductor.py
+++ b/test/inductor/test_aot_inductor.py
@@ -11,6 +11,7 @@
 import torch._inductor
 import torch.fx._pytree as fx_pytree
 from torch._dynamo.testing import same
+from torch._dynamo.utils import counters
 from torch._inductor import config
 from torch._inductor.exc import CppWrapperCodeGenError
 from torch._inductor.utils import aot_inductor_launcher
@@ -202,6 +203,21 @@
         )
         self.check_model(Model(), example_inputs)
 
+    def test_simple_split(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.cat(tensors=torch.split(x, 4, dim=1), dim=-2)
+
+        example_inputs = (torch.randn(2, 8, device=self.device),)
+        counters.clear()
+        self.check_model(Model(), example_inputs)
+        self.assertEqual(counters["inductor"]["scmerge_split_removed"], 1)
+        self.assertEqual(counters["inductor"]["scmerge_cat_removed"], 1)
+        self.assertEqual(counters["inductor"]["scmerge_split_sections_removed"], 1)
+
     def test_small_constant(self):
         class Model(torch.nn.Module):
             def __init__(self):
diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py
index ad955cf..891290f 100644
--- a/torch/_export/__init__.py
+++ b/torch/_export/__init__.py
@@ -947,6 +947,10 @@
     # We want to export to Torch IR here to utilize the pre_grad passes in
     # inductor, which run on Torch IR.
     gm = _export_to_torch_ir(f, args, kwargs, constraints)
+
+    from torch._export.passes.fake_tensor_prop import FakeTensorProp
+    FakeTensorProp(gm).run()
+
     flat_example_inputs = pytree.arg_tree_leaves(*args, **kwargs or {})
 
     with torch.no_grad():
diff --git a/torch/_export/passes/fake_tensor_prop.py b/torch/_export/passes/fake_tensor_prop.py
new file mode 100644
index 0000000..b49ca15
--- /dev/null
+++ b/torch/_export/passes/fake_tensor_prop.py
@@ -0,0 +1,18 @@
+from torch.fx.interpreter import Interpreter
+
+
+class FakeTensorProp(Interpreter):
+    def run(self):
+        inp = tuple(
+            node.meta["val"]
+            for node in self.module.graph.nodes
+            if node.op == "placeholder"
+        )
+        super().run(*inp)
+
+    def run_node(self, node):
+        res = super().run_node(node)
+        # split_cat fx passes expect "example_value" metadata on the nodes
+        node.meta["example_value"] = res
+        node.meta["val"] = res
+        return res