Some cleanups in pattern matcher (#112101)

Renames the pattern-matcher trace helpers `inference_graph` and `training_graph` to `fwd_only` and `joint_fwd_bwd`, and renames the `pass_dict` argument of `register_replacement` to `pass_dicts` (it also accepts a sequence of dicts). Moves scalar-workaround handling into `gen_pattern`, so callers pass only tensors as `example_inputs` and supply scalars via `scalar_workaround` instead of appending `workaround.values()` themselves. Also raises a clear error when a pattern input is missing from `match.kwargs`, registers an `aten.view` no-op decomposition, and factors the repeated compile-and-count-matches boilerplate in `test_pattern_matcher.py` into a `common()` helper.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112101
Approved by: https://github.com/eellison
ghstack dependencies: #112093
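
As a usage reference, here is a minimal sketch of `register_replacement` after the renames; the patterns, shapes, and the `scale` value below are invented for illustration and are not taken from this PR:

```python
import torch
from torch._inductor.pattern_matcher import (
    fwd_only,  # formerly inference_graph
    joint_fwd_bwd,  # formerly training_graph
    PatternMatcherPass,
    register_replacement,
)

patterns = PatternMatcherPass()

# Hypothetical rewrite: mm(a, b) + mm(a, c) -> mm(a, b + c).
def pattern(a, b, c):
    return torch.mm(a, b) + torch.mm(a, c)

def replacement(a, b, c):
    return torch.mm(a, b + c)

register_replacement(
    pattern,
    replacement,
    [torch.empty(8, 8), torch.empty(8, 8), torch.empty(8, 8)],  # example_inputs
    fwd_only,  # trace_fn: fwd_only or joint_fwd_bwd
    patterns,  # pass_dicts (renamed from pass_dict)
)

# Scalars now stay out of example_inputs entirely; gen_pattern lines up
# scalar_workaround entries with the pattern's argument names, so callers
# no longer append workaround.values() to the example inputs.
def scaled_pattern(x, scale):
    return torch.mm(x, x) * scale

def scaled_replacement(x, scale):
    return torch.mm(x, x * scale)

register_replacement(
    scaled_pattern,
    scaled_replacement,
    [torch.empty(8, 8)],  # tensors only; `scale` is not appended here
    fwd_only,
    patterns,
    scalar_workaround={"scale": 0.5},
)
```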
diff --git a/test/inductor/test_pattern_matcher.py b/test/inductor/test_pattern_matcher.py
index a27be82..d1f7abf 100644
--- a/test/inductor/test_pattern_matcher.py
+++ b/test/inductor/test_pattern_matcher.py
@@ -26,6 +26,24 @@
 
 
 class TestPaternMatcher(TestCase):
+    def common(self, fn, args, expected_matches, expected_nodes):
+        counters.clear()
+        torch.manual_seed(42)
+        expected = fn(*args)
+        torch.manual_seed(42)
+        actual = torch.compile(fn)(*args)
+        torch.testing.assert_close(actual, expected)
+        if inductor_config.cpp_wrapper:
+            # CPP wrapper runs everything twice, so we'll match the pattern twice
+            expected_matches *= 2
+            expected_nodes *= 2
+
+        self.assertEqual(
+            counters["inductor"]["pattern_matcher_count"], expected_matches
+        )
+        self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], expected_nodes)
+        counters.clear()
+
     def test_mm_plus_mm(self):
         def fn(a, b, c, d):
             return torch.add(torch.mm(a, b), torch.mm(c, d))
@@ -58,12 +76,7 @@
             ),
         ]
         for args in args_list:
-            counters.clear()
-            expected = fn(*args)
-            actual = torch.compile(fn)(*args)
-            torch.testing.assert_close(actual, expected)
-            self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
-            self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 3)
+            self.common(fn, args, 1, 3)
 
     def _test_fused_int_mm_mul_impl(self, fn, args, fused_int_mm_mul_expected=True):
         torch._dynamo.reset()
@@ -467,11 +480,7 @@
             torch.randn(16, 16, device="cuda"),
             torch.randn(16, 16, device="cuda"),
         ]
-        expected = fn(*args)
-        actual = torch.compile(fn)(*args)
-        torch.testing.assert_close(actual, expected)
-        self.assertEqual(counters["inductor"]["pattern_matcher_count"], 2)
-        self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 5)
+        self.common(fn, args, 2, 5)
 
     def test_cat_addmm(self):
         def fn(a, b, c):
@@ -489,11 +498,7 @@
             torch.randn(16, 16, device="cuda"),
             torch.randn(16, 16, device="cuda"),
         ]
-        expected = fn(*args)
-        actual = torch.compile(fn)(*args)
-        torch.testing.assert_close(actual, expected)
-        self.assertEqual(counters["inductor"]["pattern_matcher_count"], 2)
-        self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 5)
+        self.common(fn, args, 2, 5)
 
     def test_cat_slice_cat(self):
         def check_counter(counter, expected):
@@ -513,17 +518,13 @@
             torch.randn(2, 32, device="cuda"),
             torch.randn(2, 16, device="cuda"),
         ]
-        expected = fn(*args)
-        actual = torch.compile(fn)(*args)
-        torch.testing.assert_close(actual, expected)
-        check_counter(counters["inductor"]["pattern_matcher_count"], 1)
-        check_counter(counters["inductor"]["pattern_matcher_nodes"], 3)
+        self.common(fn, args, 1, 3)
 
-        counters.clear()
         args = [
             torch.randn(2, 8, device="cuda"),
             torch.randn(2, 16, device="cuda"),
         ]
+        counters.clear()
         expected = fn(*args)
         actual = torch.compile(fn)(*args)
         torch.testing.assert_close(actual, expected)
@@ -539,16 +540,11 @@
             slice_2 = torch.ops.aten.slice.Tensor(slice_1, 1, 0, -1)
             return torch.ops.aten.cat.default([cat_1, slice_2], 1)
 
-        counters.clear()
         args = [
             torch.randn(2, 8, device="cuda"),
             torch.randn(2, 16, device="cuda"),
         ]
-        expected = fn(*args)
-        actual = torch.compile(fn)(*args)
-        torch.testing.assert_close(actual, expected)
-        check_counter(counters["inductor"]["pattern_matcher_count"], 1)
-        check_counter(counters["inductor"]["pattern_matcher_nodes"], 3)
+        self.common(fn, args, 1, 3)
 
     def test_pointless_convert(self):
         def fn1(x):
@@ -624,12 +620,7 @@
         args = [
             torch.randn(2, 32, device="cuda"),
         ]
-        expected = fn(*args)
-        actual = torch.compile(fn)(*args)
-        torch.testing.assert_close(actual, expected)
-        self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
-        self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 4)
-        counters.clear()
+        self.common(fn, args, 1, 4)
 
         # Not all getitems are passed to cat
         def fn(a):
@@ -643,12 +634,7 @@
         args = [
             torch.randn(2, 32, device="cuda"),
         ]
-        expected = fn(*args)
-        actual = torch.compile(fn)(*args)
-        torch.testing.assert_close(actual, expected)
-        self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0)
-        self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0)
-        counters.clear()
+        self.common(fn, args, 0, 0)
 
         # Different dimensions  (TODO this case should be handled by replacing with a reshape)
         def fn(a):
@@ -661,11 +647,7 @@
         args = [
             torch.randn(2, 32, device="cuda"),
         ]
-        expected = fn(*args)
-        actual = torch.compile(fn)(*args)
-        torch.testing.assert_close(actual, expected)
-        self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0)
-        self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0)
+        self.common(fn, args, 0, 0)
 
         # https://github.com/pytorch/pytorch/issues/99686.
         def fn(a):
@@ -676,11 +658,7 @@
         args = [
             torch.randn(1, 8, device="cuda"),
         ]
-        expected = fn(*args)
-        actual = torch.compile(fn)(*args)
-        torch.testing.assert_close(actual, expected)
-        self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0)
-        self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0)
+        self.common(fn, args, 0, 0)
 
     def test_cat_splitwithsizes(self):
         # good case
@@ -696,12 +674,7 @@
             torch.randn(2, 3, device="cuda"),
             torch.randn(2, 5, device="cuda"),
         ]
-        expected = fn(*args)
-        actual = torch.compile(fn)(*args)
-        torch.testing.assert_close(actual, expected)
-        self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
-        self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 2)
-        counters.clear()
+        self.common(fn, args, 1, 2)
 
         # cat node has other users
         def fn(a, b, c):
@@ -716,12 +689,7 @@
             torch.randn(2, 3, device="cuda"),
             torch.randn(2, 5, device="cuda"),
         ]
-        expected = fn(*args)
-        actual = torch.compile(fn)(*args)
-        torch.testing.assert_close(actual, expected)
-        self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0)
-        self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0)
-        counters.clear()
+        self.common(fn, args, 0, 0)
 
         # cat and split dims are different
         def fn(a, b, c):
@@ -736,12 +704,7 @@
             torch.randn(10, 3, device="cuda"),
             torch.randn(10, 5, device="cuda"),
         ]
-        expected = fn(*args)
-        actual = torch.compile(fn)(*args)
-        torch.testing.assert_close(actual, expected)
-        self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0)
-        self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0)
-        counters.clear()
+        self.common(fn, args, 0, 0)
 
-        # cat and split lenghts are different
+        # cat and split lengths are different
         def fn(a, b, c):
@@ -754,12 +717,7 @@
             torch.randn(2, 3, device="cuda"),
             torch.randn(2, 5, device="cuda"),
         ]
-        expected = fn(*args)
-        actual = torch.compile(fn)(*args)
-        torch.testing.assert_close(actual, expected)
-        self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0)
-        self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0)
-        counters.clear()
+        self.common(fn, args, 0, 0)
 
         # cat input sizes and split sizes are different
         def fn(a, b, c):
@@ -774,12 +732,7 @@
             torch.randn(2, 3, device="cuda"),
             torch.randn(2, 5, device="cuda"),
         ]
-        expected = fn(*args)
-        actual = torch.compile(fn)(*args)
-        torch.testing.assert_close(actual, expected)
-        self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0)
-        self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0)
-        counters.clear()
+        self.common(fn, args, 0, 0)
 
     def test_match_with_mutation(self):
         from torch._inductor.pattern_matcher import (
diff --git a/torch/_inductor/fx_passes/freezing_patterns.py b/torch/_inductor/fx_passes/freezing_patterns.py
index dc0eb65..cafd7aa 100644
--- a/torch/_inductor/fx_passes/freezing_patterns.py
+++ b/torch/_inductor/fx_passes/freezing_patterns.py
@@ -8,8 +8,8 @@
 from ..pattern_matcher import (
     _return_true,
     CallFunction,
+    fwd_only,
     Ignored,
-    inference_graph,
     init_once_fakemode,
     KeywordArg,
     Match,
@@ -144,7 +144,7 @@
         matmul_fuse_pattern,
         matmul_replacement,
         [val(), val(), val(), val()],
-        inference_graph,
+        fwd_only,
         pass_patterns[0],
         extra_check=check_concat_weights,
         exclusive_arg_names=("w1", "w2", "w3"),
@@ -162,7 +162,7 @@
         matmul_fuse_pattern_two,
         matmul_replacement_two,
         [val(), val(), val()],
-        inference_graph,
+        fwd_only,
         pass_patterns[0],
         extra_check=check_concat_weights,
         exclusive_arg_names=("w1", "w2"),
@@ -184,7 +184,7 @@
         addmm_fuse_pattern_second,
         addmm_fuse_replacement_second,
         [val() for _ in range(7)],
-        inference_graph,
+        fwd_only,
         pass_patterns[0],
         extra_check=check_concat_weights,
         exclusive_arg_names=("w1", "w2", "w3", "b1", "b2", "b3"),
diff --git a/torch/_inductor/fx_passes/fuse_attention.py b/torch/_inductor/fx_passes/fuse_attention.py
index 8001d58..207f645 100644
--- a/torch/_inductor/fx_passes/fuse_attention.py
+++ b/torch/_inductor/fx_passes/fuse_attention.py
@@ -7,9 +7,9 @@
 from ..._dynamo.utils import counters
 from ..pattern_matcher import (
     filter_nodes,
-    inference_graph,
+    fwd_only,
+    joint_fwd_bwd,
     register_replacement,
-    training_graph,
 )
 
 log = logging.getLogger(__name__)
@@ -513,7 +513,6 @@
             # XXX: when adding a new pattern, re-run `gen_attention_patterns` so the pattern
             # gets serialized to a python file and does not require tracing at runtime.
             assert isinstance(workaround, dict)
-            training_args = [*args, *workaround.values()]
             name = pattern.__name__
 
             training_name = (
@@ -522,9 +521,9 @@
             yield training_name, {
                 "search_fn": pattern,
                 "replace_fn": replacement,
-                "example_inputs": training_args,
-                "trace_fn": training_graph,
-                "pass_dict": patterns,
+                "example_inputs": args,
+                "trace_fn": joint_fwd_bwd,
+                "pass_dicts": patterns,
                 "extra_check": extra_check,
                 "scalar_workaround": workaround,
             }
@@ -547,8 +546,8 @@
                 "search_fn": pattern,
                 "replace_fn": replacement,
                 "example_inputs": args,
-                "trace_fn": inference_graph,
-                "pass_dict": patterns,
+                "trace_fn": fwd_only,
+                "pass_dicts": patterns,
                 "extra_check": extra_check,
                 "scalar_workaround": workaround,
             }
diff --git a/torch/_inductor/fx_passes/pad_mm.py b/torch/_inductor/fx_passes/pad_mm.py
index 7f01bdc..3b6d6d1 100644
--- a/torch/_inductor/fx_passes/pad_mm.py
+++ b/torch/_inductor/fx_passes/pad_mm.py
@@ -8,12 +8,7 @@
 from torch.utils._mode_utils import no_dispatch
 from torch.utils._triton import has_triton
 
-from ..pattern_matcher import (
-    inference_graph,
-    Match,
-    register_replacement,
-    training_graph,
-)
+from ..pattern_matcher import fwd_only, joint_fwd_bwd, Match, register_replacement
 
 aten = torch.ops.aten
 
@@ -453,12 +448,11 @@
         ),
     ]:
         assert isinstance(workaround, dict)  # mypy is unable to infer the type properly
-        args = [*args, *workaround.values()]
         register_replacement(
             pattern,
             replacement,
             args,
-            training_graph,
+            joint_fwd_bwd,
             patterns,
             extra_check=extra_check,
             scalar_workaround=workaround,
@@ -467,7 +461,7 @@
             pattern,
             replacement,
             args,
-            inference_graph,
+            fwd_only,
             patterns,
             extra_check=extra_check,
             scalar_workaround=workaround,
diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py
index 1defdab..7b2797f 100644
--- a/torch/_inductor/fx_passes/post_grad.py
+++ b/torch/_inductor/fx_passes/post_grad.py
@@ -566,6 +566,11 @@
     return len(inputs) == 1
 
 
+@register_noop_decomp(aten.view)
+def view_noop(arg, size):
+    return arg.shape == size
+
+
 # Note, we also always have a check for identical metadata, which is why these
 # are safe
 @register_noop_decomp([aten.copy], nop_arg=1)
@@ -576,9 +581,7 @@
 
 def remove_noop_ops(graph: torch.fx.Graph):
     """
-    Removes aten.clone and aten.alias ops from the graph when it's safe.
-
-    Other no-ops should be done as decompositions that selectively turn into aten.clone or aten.alias
+    Removes both operations that are essentially aten.clone and operations that are essentially aten.alias from the graph.
     """
     input_storages = set()
     output_storages = set()
diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py
index c90d178..67fb384 100644
--- a/torch/_inductor/pattern_matcher.py
+++ b/torch/_inductor/pattern_matcher.py
@@ -124,7 +124,7 @@
     def replace_by_example(self, replacement_fn, args, trace_fn=None):
         assert self.ctx
         if trace_fn is None:
-            trace_fn = inference_graph
+            trace_fn = fwd_only
         replacement = trace_fn(
             replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"])
         )
@@ -842,7 +842,7 @@
     replace_fn,
     example_inputs: Iterable[Any],
     trace_fn: Callable[[Callable[..., Any], Iterable[Any]], torch.fx.GraphModule],
-    pass_dict,
+    pass_dicts,
     extra_check=_return_true,
     scalar_workaround=(),
     exclusive_arg_names=(),
@@ -857,10 +857,11 @@
         search_fn: traced to give original pattern
         replace_fn: traced to give replacement graph
         example_inputs: example inputs for initial trace
-        trace_fn: inference_graph or training_graph
+        trace_fn: fwd_only or joint_fwd_bwd
-        pass_dict: dict of passes to register to
-        extra_check: additional check to run on match(using real shapes)
+        pass_dicts: dict of passes (or a sequence of such dicts) to register to
+        extra_check: additional check to run on match (using real shapes)
     """
+    argnames = [*inspect.signature(search_fn).parameters.keys()]
 
     def check_fn(match: Match):
         """
@@ -869,17 +870,24 @@
 
         Recheck the match with the correct shapes.
         """
+        for name in argnames:
+            if name not in match.kwargs:
+                raise RuntimeError(
+                    f"Not all inputs to pattern found in match.kwargs. Perhaps one "
+                    f"of the inputs is unused? argnames={argnames}, match.kwargs={match.kwargs}"
+                )
+
         args = list(
             torch.fx.map_arg(
                 [match.kwargs[name] for name in argnames], lambda n: n.meta["val"]  # type: ignore[has-type]
             )
         )
-        for i, grad in enumerate(requires_grad):
-            if isinstance(args[i], torch.Tensor):
-                if grad and is_integer_dtype(args[i].dtype):
-                    return False
+        with torch._dynamo.utils.detect_fake_mode(args):
+            for i, grad in enumerate(requires_grad):
+                if isinstance(args[i], torch.Tensor):
+                    if grad and is_integer_dtype(args[i].dtype):
+                        return False
 
-                with torch._dynamo.utils.detect_fake_mode(args):
                     args[i] = torch.empty_strided(
                         args[i].size(),
                         args[i].stride(),
@@ -887,27 +895,32 @@
                         device=args[i].device,
                         requires_grad=grad,
                     )
-        specific_graph = trace_fn(search_fn, args)
-        specific_pattern = fx_to_pattern(
-            specific_graph, argnames=argnames, exclusive_arg_names=exclusive_arg_names  # type: ignore[has-type]
-        )
-        specific_pattern_match = specific_pattern.match(match.output_nodes()[0])
-        if specific_pattern_match and extra_check(specific_pattern_match):
-            # trace the pattern using the shapes form the user program
-            match.replacement_graph = trace_fn(replace_fn, args)
-            return True
-        return False
+            specific_graph = trace_fn(search_fn, args)
+            specific_pattern = fx_to_pattern(
+                specific_graph,
+                argnames=argnames,
+                exclusive_arg_names=exclusive_arg_names,  # type: ignore[has-type]
+                scalar_workaround=scalar_workaround,
+            )
+            specific_pattern_match = specific_pattern.match(match.output_nodes()[0])
+            if specific_pattern_match and extra_check(specific_pattern_match):
+                # trace the pattern using the shapes from the user program
+                match.replacement_graph = trace_fn(replace_fn, args)
+                return True
+            return False
 
     def normalize_args(**kwargs):
         args = []
         for name in argnames:  # type: ignore[has-type]
             args.append(kwargs.pop(name))
         for i in range(1, len(kwargs) + 1):
+            if f"tangents_{i}" not in kwargs:
+                break
             args.append(kwargs.pop(f"tangents_{i}"))
         assert not kwargs, f"leftover kwargs: {kwargs!r}"
         return args
 
-    if trace_fn is training_graph:
+    if trace_fn is joint_fwd_bwd:
         # If inference mode is enabled during compilation, assume that we don't
         # want to match on any training graph patterns
         if torch.is_inference_mode_enabled():
@@ -915,7 +928,6 @@
 
     # TODO: Revisit the functionalize_rng_ops for lowmem dropout
     with functorch_config.patch(functionalize_rng_ops=False):
-        argnames = [*inspect.signature(search_fn).parameters.keys()]
         requires_grad: List[bool] = [
             isinstance(x, torch.Tensor) and x.requires_grad for x in example_inputs
         ]
@@ -938,7 +950,7 @@
             extra_check=check_fn,
             normalize_args=normalize_args,
         )
-        pattern.register(pass_dict)
+        pattern.register(pass_dicts)
         return pattern.pattern
 
 
@@ -947,7 +959,20 @@
     search_fn, example_inputs, trace_fn, scalar_workaround=(), exclusive_arg_names=()
 ) -> PatternExpr:
     argnames = [*inspect.signature(search_fn).parameters.keys()]
-    search_gm = trace_fn(search_fn, example_inputs)
+
+    if scalar_workaround == ():
+        scalar_workaround = {}
+    flat_inputs = []
+    input_idx = 0  # Positional arguments index
+
+    for argname in argnames:
+        if argname in scalar_workaround:
+            flat_inputs.append(scalar_workaround[argname])
+        else:
+            flat_inputs.append(example_inputs[input_idx])
+            input_idx += 1
+
+    search_gm = trace_fn(search_fn, flat_inputs)
     return fx_to_pattern(
         search_gm,
         ignore_types=(int, float, list, torch.device, torch.dtype),
@@ -1175,7 +1200,7 @@
 
 
 @torch.no_grad()
-def inference_graph(fn, args) -> torch.fx.GraphModule:
+def fwd_only(fn, args) -> torch.fx.GraphModule:
     """Build a normalized inference graph, for use with fx_to_pattern"""
     # TODO - look into using aot autograd, asserting no mutating ops here
     with enable_python_dispatcher():
@@ -1186,7 +1211,7 @@
 
 
 @torch.enable_grad()
-def training_graph(fn, args) -> torch.fx.GraphModule:
+def joint_fwd_bwd(fn, args) -> torch.fx.GraphModule:
     """Build a normalized training graph, for use with fx_to_pattern"""
     gm: Optional[torch.fx.GraphModule] = None