[PT2][Inductor][3/n] Customize pre grad and post grad patterns (#121915)

Summary: Currently, we have only enabled the group batch fusion customization; this change also enables the split cat customization.

Test Plan:
```
buck2 run mode/opt //scripts/jackiexu0313/pt2:local_model_with_pt2 -- --test_mode batch-split --model_type "cmf" --flow_id 524546542
```
P1196013839

Differential Revision: D54861682

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121915
Approved by: https://github.com/jackiexu1992
diff --git a/test/inductor/test_split_cat_fx_passes.py b/test/inductor/test_split_cat_fx_passes.py
index de0e38a..6457547 100644
--- a/test/inductor/test_split_cat_fx_passes.py
+++ b/test/inductor/test_split_cat_fx_passes.py
@@ -739,7 +739,7 @@
 
             torch.testing.assert_close(actual, expected)
             self.assertEqual(
-                counters["inductor"]["split_squeeze_replaced"],
+                counters["inductor"]["split_cat_pass"],
                 split_squeeze_replaced,
             )
             counters.clear()
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index c1da004..bf0b948 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -138,13 +138,6 @@
     "batch_tanh": {},
     "batch_relu": {},
     "batch_sigmoid": {},
-    "normalization_pass": {},
-    "remove_split_with_size_one_pass": {},
-    "merge_getitem_cat_pass": {},
-    "merge_stack_tahn_unbind_pass": {},
-    "merge_splits_pass": {},
-    "mutate_cat_pass": {},
-    "split_cat_pass": {},
 }
 
 # Post grad fusion and options, set to empty dict to disable fusion.
diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py
index 0d793ec..222c916 100644
--- a/torch/_inductor/fx_passes/post_grad.py
+++ b/torch/_inductor/fx_passes/post_grad.py
@@ -12,7 +12,7 @@
 import torch.utils._pytree as pytree
 from torch import fx
 from torch._decomp import register_decomposition
-from torch._dynamo.utils import optimus_scuba_log
+from torch._dynamo.utils import counters, optimus_scuba_log
 
 from torch._prims_common import is_boolean_dtype, is_expandable_to, is_integer_dtype
 
@@ -45,13 +45,16 @@
 from ..virtualized import V
 from .ddp_fusion import fuse_ddp_communication
 from .group_batch_fusion import group_batch_fusion_passes
+from .pre_grad import is_same_dict, save_inductor_dict
 from .reinplace import reinplace_inplaceable_ops
+from .split_cat import POST_GRAD_PATTERNS
 
 
 log = logging.getLogger(__name__)
 aten = torch.ops.aten
 prims = torch.ops.prims
 
+pattern_matcher_passes = POST_GRAD_PATTERNS.values()
 # First pass_patterns[0] are applied, then [1], then [2]
 pass_patterns = [
     PatternMatcherPass(),
@@ -89,6 +92,15 @@
         remove_noop_ops(gm.graph)
         for patterns in pass_patterns:
             patterns.apply(gm.graph)  # type: ignore[arg-type]
+        for pattern_matcher_pass in pattern_matcher_passes:
+            inductor_before_change = save_inductor_dict(
+                [pattern_matcher_pass.pass_name]
+            )
+            pattern_matcher_pass.apply(gm.graph)  # type: ignore[arg-type]
+            if not is_same_dict(counters["inductor"], inductor_before_change):
+                optimus_scuba_log[
+                    f"{pattern_matcher_pass.pass_name}_post_grad"
+                ] = upload_graph(gm.graph)
         if is_inference:
             inference_patterns.apply(gm.graph)  # type: ignore[arg-type]
         decompose_mm_pass.apply(gm.graph)  # type: ignore[arg-type]
diff --git a/torch/_inductor/fx_passes/pre_grad.py b/torch/_inductor/fx_passes/pre_grad.py
index dbec853..df9dd0b 100644
--- a/torch/_inductor/fx_passes/pre_grad.py
+++ b/torch/_inductor/fx_passes/pre_grad.py
@@ -1,6 +1,6 @@
 import copy
 import logging
-from typing import Dict, List, Optional
+from typing import List, Optional
 
 import torch
 import torch.nn as nn
@@ -25,6 +25,7 @@
 from ..utils import is_cpu_device, pass_execution_and_save
 from .group_batch_fusion import group_batch_fusion_passes
 from .misc_patterns import numpy_compat_normalization
+from .split_cat import PRE_GRAD_PATTERNS
 
 log = logging.getLogger(__name__)
 
@@ -75,10 +76,6 @@
     return True
 
 
-def construct_pattern_matcher_pass(pass_name):
-    return PatternMatcherPass(prevent_match_across_mutations=True, pass_name=pass_name)
-
-
 def fuse_parallel_linear_pass(graph):
     return None
 
@@ -87,20 +84,6 @@
     return None
 
 
-PRE_GRAD_PATTERNS: Dict[str, PatternMatcherPass] = dict()
-pass_names = [
-    "normalization_pass",
-    "remove_split_with_size_one_pass",
-    "merge_getitem_cat_pass",
-    "merge_stack_tahn_unbind_pass",
-    "merge_splits_pass",
-    "mutate_cat_pass",
-    "split_cat_pass",
-    "unbind_stack_pass",
-]
-
-for pass_name in pass_names:
-    PRE_GRAD_PATTERNS[pass_name] = construct_pattern_matcher_pass(pass_name)
 # split_cat related fusions
 pattern_matcher_passes = list(PRE_GRAD_PATTERNS.values())
 # non-split_cat related fusions
diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py
index b2c31d7..4ebefd8 100644
--- a/torch/_inductor/fx_passes/split_cat.py
+++ b/torch/_inductor/fx_passes/split_cat.py
@@ -1,19 +1,19 @@
 import itertools
 import logging
 import operator
-from typing import Any, Callable, List, Optional, Sequence, Set, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
 
 from typing_extensions import TypeAlias
 
 import torch
 from torch._dynamo.utils import counters
+from .. import config
 
 from ..pattern_matcher import (
     Arg,
     CallFunction,
     CallFunctionVarArgs,
     CallMethodVarArgs,
-    config_flag,
     FailedMatch,
     get_arg_value,
     Ignored,
@@ -23,11 +23,11 @@
     MatchContext,
     MULTIPLE,
     PatternExpr,
+    PatternMatcherPass,
     register_graph_pattern,
     RepeatedExpr,
 )
-from .group_batch_fusion import is_node_meta_valid
-from .pre_grad import PRE_GRAD_PATTERNS
+from .group_batch_fusion import is_node_meta_valid, POST_GRAD_FUSIONS, PRE_GRAD_FUSIONS
 
 log = logging.getLogger(__name__)
 
@@ -41,6 +41,64 @@
 _Range: TypeAlias = Tuple[int, int]
 
 
+PRE_GRAD_PATTERNS: Dict[str, PatternMatcherPass] = dict()
+POST_GRAD_PATTERNS: Dict[str, PatternMatcherPass] = dict()
+
+# TODO: read the pass_names from the config after the frontend change
+pass_names = [
+    "normalization_pass",
+    "remove_split_with_size_one_pass",
+    "merge_getitem_cat_pass",
+    "merge_stack_tahn_unbind_pass",
+    "merge_splits_pass",
+    "mutate_cat_pass",
+    "split_cat_pass",
+    "unbind_stack_pass",
+]
+
+for pass_name in pass_names:
+    # exclude all passes from the group batch fusion
+    # they do not use pattern matcher
+    if pass_name in PRE_GRAD_FUSIONS or pass_name in POST_GRAD_FUSIONS:
+        continue
+    PRE_GRAD_PATTERNS[pass_name] = PatternMatcherPass(
+        prevent_match_across_mutations=True,
+        pass_name=pass_name,
+    )
+
+
+def construct_pattern_matcher_pass(pass_name: str) -> PatternMatcherPass:
+    """
+    Return the specific pattern_matcher_pass given the pass name.
+    """
+    if pass_name in PRE_GRAD_PATTERNS:
+        return PRE_GRAD_PATTERNS[pass_name]
+    elif pass_name in POST_GRAD_PATTERNS:
+        return POST_GRAD_PATTERNS[pass_name]
+    else:
+        # a pattern that is not in the config will
+        # not be conducted in the optimization
+        return PatternMatcherPass(
+            prevent_match_across_mutations=True,
+            pass_name=pass_name,
+        )
+
+
+def get_config_flag(pass_name: str, flag="split_cat_fx_passes", pre_grad=True):
+    def flag_check(match):
+        # TODO: remove the flag config check after we have the front end change
+        # currently, pre_grad_fusion_options and post_grad_fusion_options only have batch fusion
+        # options controlled by the batch_fusion flag; after we extend it to include other fusions,
+        # we can just check if the pass_name is in the config
+        if pre_grad:
+            # do not disturb models where the config flag is turned off
+            return getattr(config, flag)
+        else:
+            return getattr(config, flag)
+
+    return flag_check
+
+
 def _get_split_args_default(split_node):
     input_kwarg = "tensor"
     split_size_kwarg = "split_size_or_sections"
@@ -148,13 +206,13 @@
 
 @register_graph_pattern(
     CallFunctionVarArgs(torch.split, users=MULTIPLE),
-    pass_dict=PRE_GRAD_PATTERNS["normalization_pass"],
-    extra_check=config_flag("split_cat_fx_passes"),
+    pass_dict=construct_pattern_matcher_pass("normalization_pass"),
+    extra_check=get_config_flag("normalization_pass"),
 )
 @register_graph_pattern(
     CallMethodVarArgs("split", users=MULTIPLE),
-    pass_dict=PRE_GRAD_PATTERNS["normalization_pass"],
-    extra_check=config_flag("split_cat_fx_passes"),
+    pass_dict=construct_pattern_matcher_pass("normalization_pass"),
+    extra_check=get_config_flag("normalization_pass"),
 )
 def normalize_split_default(match: Match, *args, **kwargs):
     return normalize_split_base(match, _get_split_args_default)
@@ -162,13 +220,13 @@
 
 @register_graph_pattern(
     CallFunctionVarArgs(torch.split, users=MULTIPLE),
-    pass_dict=PRE_GRAD_PATTERNS["remove_split_with_size_one_pass"],
-    extra_check=config_flag("split_cat_fx_passes"),
+    pass_dict=construct_pattern_matcher_pass("remove_split_with_size_one_pass"),
+    extra_check=get_config_flag("remove_split_with_size_one_pass"),
 )
 @register_graph_pattern(
     CallMethodVarArgs("split", users=MULTIPLE),
-    pass_dict=PRE_GRAD_PATTERNS["remove_split_with_size_one_pass"],
-    extra_check=config_flag("split_cat_fx_passes"),
+    pass_dict=construct_pattern_matcher_pass("remove_split_with_size_one_pass"),
+    extra_check=get_config_flag("remove_split_with_size_one_pass"),
 )
 def remove_split_with_size_one(match: Match, *args, **kwargs):
     graph = match.graph
@@ -202,13 +260,13 @@
 
 @register_graph_pattern(
     CallFunctionVarArgs(torch.unbind, users=MULTIPLE),
-    pass_dict=PRE_GRAD_PATTERNS["normalization_pass"],
-    extra_check=config_flag("split_cat_fx_passes"),
+    pass_dict=construct_pattern_matcher_pass("normalization_pass"),
+    extra_check=get_config_flag("normalization_pass"),
 )
 @register_graph_pattern(
     CallMethodVarArgs("unbind", users=MULTIPLE),
-    pass_dict=PRE_GRAD_PATTERNS["normalization_pass"],
-    extra_check=config_flag("split_cat_fx_passes"),
+    pass_dict=construct_pattern_matcher_pass("normalization_pass"),
+    extra_check=get_config_flag("normalization_pass"),
 )
 def normalize_unbind_default(match: Match, *args, **kwargs):
     node = match.nodes[0]
@@ -244,8 +302,8 @@
 
 @register_graph_pattern(
     CallFunctionVarArgs(torch.cat, users=MULTIPLE),
-    pass_dict=PRE_GRAD_PATTERNS["normalization_pass"],
-    extra_check=config_flag("split_cat_fx_passes"),
+    pass_dict=construct_pattern_matcher_pass("normalization_pass"),
+    extra_check=get_config_flag("normalization_pass"),
 )
 def normalize_cat_default(match: Match, *args, **kwargs):
     from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
@@ -306,8 +364,8 @@
 
 @register_graph_pattern(
     CallFunctionVarArgs(torch.stack, users=MULTIPLE),
-    pass_dict=PRE_GRAD_PATTERNS["normalization_pass"],
-    extra_check=config_flag("split_cat_fx_passes"),
+    pass_dict=construct_pattern_matcher_pass("normalization_pass"),
+    extra_check=get_config_flag("normalization_pass"),
 )
 def normalize_stack_default(match: Match, *args, **kwargs):
     node = match.nodes[0]
@@ -352,8 +410,8 @@
 
 @register_graph_pattern(
     CallMethodVarArgs("squeeze", users=MULTIPLE),
-    pass_dict=PRE_GRAD_PATTERNS["normalization_pass"],
-    extra_check=config_flag("split_cat_fx_passes"),
+    pass_dict=construct_pattern_matcher_pass("normalization_pass"),
+    extra_check=get_config_flag("normalization_pass"),
 )
 def normalize_squeeze_default(match: Match, *args, **kwargs):
     squeeze_node = match.nodes[0]
@@ -435,8 +493,8 @@
         ),
         KeywordArg("next_split_sections"),
     ),
-    pass_dict=PRE_GRAD_PATTERNS["merge_splits_pass"],
-    extra_check=config_flag("split_cat_fx_passes"),
+    pass_dict=construct_pattern_matcher_pass("merge_splits_pass"),
+    extra_check=get_config_flag("merge_splits_pass"),
 )
 def merge_splits(
     match: Match,
@@ -1041,8 +1099,8 @@
             _users=MULTIPLE,
         ),
     ),
-    pass_dict=PRE_GRAD_PATTERNS["split_cat_pass"],
-    extra_check=config_flag("split_cat_fx_passes"),
+    pass_dict=construct_pattern_matcher_pass("split_cat_pass"),
+    extra_check=get_config_flag("split_cat_pass"),
 )
 @register_graph_pattern(
     RepeatedExpr(
@@ -1059,8 +1117,8 @@
             _users=MULTIPLE,
         )
     ),
-    pass_dict=PRE_GRAD_PATTERNS["split_cat_pass"],
-    extra_check=config_flag("split_cat_fx_passes"),
+    pass_dict=construct_pattern_matcher_pass("split_cat_pass"),
+    extra_check=get_config_flag("split_cat_pass"),
 )
 def merge_split_squeeze(
     match: Match, split_input: torch.fx.Node, split_sizes: List[int], dim: int
@@ -1093,7 +1151,7 @@
             graph.erase_node(squeeze)
             graph.erase_node(getitem_node)
     graph.erase_node(split)
-    counters["inductor"]["split_squeeze_replaced"] += 1
+    counters["inductor"]["split_cat_pass"] += 1
 
 
 getitem_unbind = ListOf(
@@ -1113,22 +1171,22 @@
 
 @register_graph_pattern(
     CallFunction([torch.stack, torch.cat], getitem_unbind, Ignored(), _users=MULTIPLE),
-    pass_dict=PRE_GRAD_PATTERNS["unbind_stack_pass"],
-    extra_check=config_flag("split_cat_fx_passes"),
+    pass_dict=construct_pattern_matcher_pass("unbind_stack_pass"),
+    extra_check=get_config_flag("unbind_stack_pass"),
 )
 @register_graph_pattern(
     CallFunction(
         [torch.stack, torch.cat], getitem_unbind, dim=Ignored(), _users=MULTIPLE
     ),
-    pass_dict=PRE_GRAD_PATTERNS["unbind_stack_pass"],
-    extra_check=config_flag("split_cat_fx_passes"),
+    pass_dict=construct_pattern_matcher_pass("unbind_stack_pass"),
+    extra_check=get_config_flag("unbind_stack_pass"),
 )
 @register_graph_pattern(
     CallFunction(
         [torch.stack, torch.cat], tensors=getitem_unbind, dim=Ignored(), _users=MULTIPLE
     ),
-    pass_dict=PRE_GRAD_PATTERNS["unbind_stack_pass"],
-    extra_check=config_flag("split_cat_fx_passes"),
+    pass_dict=construct_pattern_matcher_pass("unbind_stack_pass"),
+    extra_check=get_config_flag("unbind_stack_pass"),
 )
 def merge_unbind_stack(match: Match, unbind_input: torch.fx.Node, dim: int):
     unbind_node = next(node for node in match.nodes if node.target == torch.unbind)
@@ -1156,8 +1214,8 @@
         dim=Ignored(),
         _users=MULTIPLE,
     ),
-    pass_dict=PRE_GRAD_PATTERNS["split_cat_pass"],
-    extra_check=config_flag("split_cat_fx_passes"),
+    pass_dict=construct_pattern_matcher_pass("split_cat_pass"),
+    extra_check=get_config_flag("split_cat_pass"),
 )
 @register_graph_pattern(
     CallFunction(
@@ -1166,8 +1224,8 @@
         dim=Ignored(),
         _users=MULTIPLE,
     ),
-    pass_dict=PRE_GRAD_PATTERNS["split_cat_pass"],
-    extra_check=config_flag("split_cat_fx_passes"),
+    pass_dict=construct_pattern_matcher_pass("split_cat_pass"),
+    extra_check=get_config_flag("split_cat_pass"),
 )
 @register_graph_pattern(
     CallFunction(
@@ -1176,8 +1234,8 @@
         Ignored(),
         _users=MULTIPLE,
     ),
-    pass_dict=PRE_GRAD_PATTERNS["split_cat_pass"],
-    extra_check=config_flag("split_cat_fx_passes"),
+    pass_dict=construct_pattern_matcher_pass("split_cat_pass"),
+    extra_check=get_config_flag("split_cat_pass"),
 )
 def simplify_split_cat(match: Match, split_sections: List[int], dim: int):
     if not isinstance(split_sections, (list, tuple)):  # Unnormalized split
@@ -1261,8 +1319,8 @@
         dim=Ignored(),
         _users=MULTIPLE,
     ),
-    pass_dict=PRE_GRAD_PATTERNS["merge_getitem_cat_pass"],
-    extra_check=config_flag("split_cat_fx_passes"),
+    pass_dict=construct_pattern_matcher_pass("merge_getitem_cat_pass"),
+    extra_check=get_config_flag("merge_getitem_cat_pass"),
 )
 def merge_getitem_cat(match: Match, split_sections: List[int], dim: int):
     if not isinstance(split_sections, (list, tuple)):  # Unnormalized split
@@ -1369,8 +1427,8 @@
         dim=Ignored(),
         _users=MULTIPLE,
     ),
-    pass_dict=PRE_GRAD_PATTERNS["mutate_cat_pass"],
-    extra_check=config_flag("split_cat_fx_passes"),
+    pass_dict=construct_pattern_matcher_pass("mutate_cat_pass"),
+    extra_check=get_config_flag("mutate_cat_pass"),
 )
 def mutate_cat_node(match: Match, split_sections: List[int], dim: int):
     if not isinstance(split_sections, (list, tuple)):  # Unnormalized split
@@ -1465,8 +1523,8 @@
             dim=Ignored(),
         ),
     ),
-    pass_dict=PRE_GRAD_PATTERNS["merge_getitem_cat_pass"],
-    extra_check=config_flag("split_cat_fx_passes"),
+    pass_dict=construct_pattern_matcher_pass("merge_stack_tahn_unbind_pass"),
+    extra_check=get_config_flag("merge_stack_tahn_unbind_pass"),
 )
 @register_graph_pattern(
     CallFunction(
@@ -1477,8 +1535,8 @@
             dim=Ignored(),
         ),
     ),
-    pass_dict=PRE_GRAD_PATTERNS["merge_getitem_cat_pass"],
-    extra_check=config_flag("split_cat_fx_passes"),
+    pass_dict=construct_pattern_matcher_pass("merge_stack_tahn_unbind_pass"),
+    extra_check=get_config_flag("merge_stack_tahn_unbind_pass"),
 )
 @register_graph_pattern(
     CallFunction(
@@ -1489,8 +1547,8 @@
             Ignored(),
         ),
     ),
-    pass_dict=PRE_GRAD_PATTERNS["merge_stack_tahn_unbind_pass"],
-    extra_check=config_flag("split_cat_fx_passes"),
+    pass_dict=construct_pattern_matcher_pass("merge_stack_tahn_unbind_pass"),
+    extra_check=get_config_flag("merge_stack_tahn_unbind_pass"),
 )
 def merge_stack_tahn_unbind(match: Match, split_sections: List[int], dim: int):
     if not isinstance(split_sections, (list, tuple)):  # Unnormalized split