[PT2][Inductor][3/n] Customize pre grad and post grad patterns (#121915)
Summary: Currently, only the group batch fusion customization is enabled; this change also enables the split cat customization.
Test Plan:
```
buck2 run mode/opt //scripts/jackiexu0313/pt2:local_model_with_pt2 -- --test_mode batch-split --model_type "cmf" --flow_id 524546542
```
P1196013839
Differential Revision: D54861682
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121915
Approved by: https://github.com/jackiexu1992
diff --git a/test/inductor/test_split_cat_fx_passes.py b/test/inductor/test_split_cat_fx_passes.py
index de0e38a..6457547 100644
--- a/test/inductor/test_split_cat_fx_passes.py
+++ b/test/inductor/test_split_cat_fx_passes.py
@@ -739,7 +739,7 @@
torch.testing.assert_close(actual, expected)
self.assertEqual(
- counters["inductor"]["split_squeeze_replaced"],
+ counters["inductor"]["split_cat_pass"],
split_squeeze_replaced,
)
counters.clear()
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index c1da004..bf0b948 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -138,13 +138,6 @@
"batch_tanh": {},
"batch_relu": {},
"batch_sigmoid": {},
- "normalization_pass": {},
- "remove_split_with_size_one_pass": {},
- "merge_getitem_cat_pass": {},
- "merge_stack_tahn_unbind_pass": {},
- "merge_splits_pass": {},
- "mutate_cat_pass": {},
- "split_cat_pass": {},
}
# Post grad fusion and options, set to empty dict to disable fusion.
diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py
index 0d793ec..222c916 100644
--- a/torch/_inductor/fx_passes/post_grad.py
+++ b/torch/_inductor/fx_passes/post_grad.py
@@ -12,7 +12,7 @@
import torch.utils._pytree as pytree
from torch import fx
from torch._decomp import register_decomposition
-from torch._dynamo.utils import optimus_scuba_log
+from torch._dynamo.utils import counters, optimus_scuba_log
from torch._prims_common import is_boolean_dtype, is_expandable_to, is_integer_dtype
@@ -45,13 +45,16 @@
from ..virtualized import V
from .ddp_fusion import fuse_ddp_communication
from .group_batch_fusion import group_batch_fusion_passes
+from .pre_grad import is_same_dict, save_inductor_dict
from .reinplace import reinplace_inplaceable_ops
+from .split_cat import POST_GRAD_PATTERNS
log = logging.getLogger(__name__)
aten = torch.ops.aten
prims = torch.ops.prims
+pattern_matcher_passes = POST_GRAD_PATTERNS.values()
# First pass_patterns[0] are applied, then [1], then [2]
pass_patterns = [
PatternMatcherPass(),
@@ -89,6 +92,15 @@
remove_noop_ops(gm.graph)
for patterns in pass_patterns:
patterns.apply(gm.graph) # type: ignore[arg-type]
+ for pattern_matcher_pass in pattern_matcher_passes:
+ inductor_before_change = save_inductor_dict(
+ [pattern_matcher_pass.pass_name]
+ )
+ pattern_matcher_pass.apply(gm.graph) # type: ignore[arg-type]
+ if not is_same_dict(counters["inductor"], inductor_before_change):
+ optimus_scuba_log[
+ f"{pattern_matcher_pass.pass_name}_post_grad"
+ ] = upload_graph(gm.graph)
if is_inference:
inference_patterns.apply(gm.graph) # type: ignore[arg-type]
decompose_mm_pass.apply(gm.graph) # type: ignore[arg-type]
diff --git a/torch/_inductor/fx_passes/pre_grad.py b/torch/_inductor/fx_passes/pre_grad.py
index dbec853..df9dd0b 100644
--- a/torch/_inductor/fx_passes/pre_grad.py
+++ b/torch/_inductor/fx_passes/pre_grad.py
@@ -1,6 +1,6 @@
import copy
import logging
-from typing import Dict, List, Optional
+from typing import List, Optional
import torch
import torch.nn as nn
@@ -25,6 +25,7 @@
from ..utils import is_cpu_device, pass_execution_and_save
from .group_batch_fusion import group_batch_fusion_passes
from .misc_patterns import numpy_compat_normalization
+from .split_cat import PRE_GRAD_PATTERNS
log = logging.getLogger(__name__)
@@ -75,10 +76,6 @@
return True
-def construct_pattern_matcher_pass(pass_name):
- return PatternMatcherPass(prevent_match_across_mutations=True, pass_name=pass_name)
-
-
def fuse_parallel_linear_pass(graph):
return None
@@ -87,20 +84,6 @@
return None
-PRE_GRAD_PATTERNS: Dict[str, PatternMatcherPass] = dict()
-pass_names = [
- "normalization_pass",
- "remove_split_with_size_one_pass",
- "merge_getitem_cat_pass",
- "merge_stack_tahn_unbind_pass",
- "merge_splits_pass",
- "mutate_cat_pass",
- "split_cat_pass",
- "unbind_stack_pass",
-]
-
-for pass_name in pass_names:
- PRE_GRAD_PATTERNS[pass_name] = construct_pattern_matcher_pass(pass_name)
# split_cat related fusions
pattern_matcher_passes = list(PRE_GRAD_PATTERNS.values())
# non-split_cat related fusions
diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py
index b2c31d7..4ebefd8 100644
--- a/torch/_inductor/fx_passes/split_cat.py
+++ b/torch/_inductor/fx_passes/split_cat.py
@@ -1,19 +1,19 @@
import itertools
import logging
import operator
-from typing import Any, Callable, List, Optional, Sequence, Set, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
from typing_extensions import TypeAlias
import torch
from torch._dynamo.utils import counters
+from .. import config
from ..pattern_matcher import (
Arg,
CallFunction,
CallFunctionVarArgs,
CallMethodVarArgs,
- config_flag,
FailedMatch,
get_arg_value,
Ignored,
@@ -23,11 +23,11 @@
MatchContext,
MULTIPLE,
PatternExpr,
+ PatternMatcherPass,
register_graph_pattern,
RepeatedExpr,
)
-from .group_batch_fusion import is_node_meta_valid
-from .pre_grad import PRE_GRAD_PATTERNS
+from .group_batch_fusion import is_node_meta_valid, POST_GRAD_FUSIONS, PRE_GRAD_FUSIONS
log = logging.getLogger(__name__)
@@ -41,6 +41,64 @@
_Range: TypeAlias = Tuple[int, int]
+PRE_GRAD_PATTERNS: Dict[str, PatternMatcherPass] = dict()
+POST_GRAD_PATTERNS: Dict[str, PatternMatcherPass] = dict()
+
+# TODO: read the pass_names from the config after the frontend change
+pass_names = [
+ "normalization_pass",
+ "remove_split_with_size_one_pass",
+ "merge_getitem_cat_pass",
+ "merge_stack_tahn_unbind_pass",
+ "merge_splits_pass",
+ "mutate_cat_pass",
+ "split_cat_pass",
+ "unbind_stack_pass",
+]
+
+for pass_name in pass_names:
+    # Exclude all passes that belong to the group batch fusion:
+    # they do not use the pattern matcher.
+ if pass_name in PRE_GRAD_FUSIONS or pass_name in POST_GRAD_FUSIONS:
+ continue
+ PRE_GRAD_PATTERNS[pass_name] = PatternMatcherPass(
+ prevent_match_across_mutations=True,
+ pass_name=pass_name,
+ )
+
+
+def construct_pattern_matcher_pass(pass_name: str) -> PatternMatcherPass:
+ """
+ Return the specific pattern_matcher_pass given the pass name.
+ """
+ if pass_name in PRE_GRAD_PATTERNS:
+ return PRE_GRAD_PATTERNS[pass_name]
+ elif pass_name in POST_GRAD_PATTERNS:
+ return POST_GRAD_PATTERNS[pass_name]
+ else:
+        # a pattern that is not in the config will
+        # not be conducted in the optimization
+ return PatternMatcherPass(
+ prevent_match_across_mutations=True,
+ pass_name=pass_name,
+ )
+
+
+def get_config_flag(pass_name: str, flag="split_cat_fx_passes", pre_grad=True):
+ def flag_check(match):
+ # TODO: remove the flag config check after we have the front end change
+        # currently, pre_grad_fusion_options and post_grad_fusion_options only have batch fusion
+        # options controlled by the batch_fusion flag; after we extend them to include other fusions,
+        # we can simply check whether the pass_name is in the config
+ if pre_grad:
+            # avoid disturbing models when the config flag is turned off
+ return getattr(config, flag)
+ else:
+ return getattr(config, flag)
+
+ return flag_check
+
+
def _get_split_args_default(split_node):
input_kwarg = "tensor"
split_size_kwarg = "split_size_or_sections"
@@ -148,13 +206,13 @@
@register_graph_pattern(
CallFunctionVarArgs(torch.split, users=MULTIPLE),
- pass_dict=PRE_GRAD_PATTERNS["normalization_pass"],
- extra_check=config_flag("split_cat_fx_passes"),
+ pass_dict=construct_pattern_matcher_pass("normalization_pass"),
+ extra_check=get_config_flag("normalization_pass"),
)
@register_graph_pattern(
CallMethodVarArgs("split", users=MULTIPLE),
- pass_dict=PRE_GRAD_PATTERNS["normalization_pass"],
- extra_check=config_flag("split_cat_fx_passes"),
+ pass_dict=construct_pattern_matcher_pass("normalization_pass"),
+ extra_check=get_config_flag("normalization_pass"),
)
def normalize_split_default(match: Match, *args, **kwargs):
return normalize_split_base(match, _get_split_args_default)
@@ -162,13 +220,13 @@
@register_graph_pattern(
CallFunctionVarArgs(torch.split, users=MULTIPLE),
- pass_dict=PRE_GRAD_PATTERNS["remove_split_with_size_one_pass"],
- extra_check=config_flag("split_cat_fx_passes"),
+ pass_dict=construct_pattern_matcher_pass("remove_split_with_size_one_pass"),
+ extra_check=get_config_flag("remove_split_with_size_one_pass"),
)
@register_graph_pattern(
CallMethodVarArgs("split", users=MULTIPLE),
- pass_dict=PRE_GRAD_PATTERNS["remove_split_with_size_one_pass"],
- extra_check=config_flag("split_cat_fx_passes"),
+ pass_dict=construct_pattern_matcher_pass("remove_split_with_size_one_pass"),
+ extra_check=get_config_flag("remove_split_with_size_one_pass"),
)
def remove_split_with_size_one(match: Match, *args, **kwargs):
graph = match.graph
@@ -202,13 +260,13 @@
@register_graph_pattern(
CallFunctionVarArgs(torch.unbind, users=MULTIPLE),
- pass_dict=PRE_GRAD_PATTERNS["normalization_pass"],
- extra_check=config_flag("split_cat_fx_passes"),
+ pass_dict=construct_pattern_matcher_pass("normalization_pass"),
+ extra_check=get_config_flag("normalization_pass"),
)
@register_graph_pattern(
CallMethodVarArgs("unbind", users=MULTIPLE),
- pass_dict=PRE_GRAD_PATTERNS["normalization_pass"],
- extra_check=config_flag("split_cat_fx_passes"),
+ pass_dict=construct_pattern_matcher_pass("normalization_pass"),
+ extra_check=get_config_flag("normalization_pass"),
)
def normalize_unbind_default(match: Match, *args, **kwargs):
node = match.nodes[0]
@@ -244,8 +302,8 @@
@register_graph_pattern(
CallFunctionVarArgs(torch.cat, users=MULTIPLE),
- pass_dict=PRE_GRAD_PATTERNS["normalization_pass"],
- extra_check=config_flag("split_cat_fx_passes"),
+ pass_dict=construct_pattern_matcher_pass("normalization_pass"),
+ extra_check=get_config_flag("normalization_pass"),
)
def normalize_cat_default(match: Match, *args, **kwargs):
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
@@ -306,8 +364,8 @@
@register_graph_pattern(
CallFunctionVarArgs(torch.stack, users=MULTIPLE),
- pass_dict=PRE_GRAD_PATTERNS["normalization_pass"],
- extra_check=config_flag("split_cat_fx_passes"),
+ pass_dict=construct_pattern_matcher_pass("normalization_pass"),
+ extra_check=get_config_flag("normalization_pass"),
)
def normalize_stack_default(match: Match, *args, **kwargs):
node = match.nodes[0]
@@ -352,8 +410,8 @@
@register_graph_pattern(
CallMethodVarArgs("squeeze", users=MULTIPLE),
- pass_dict=PRE_GRAD_PATTERNS["normalization_pass"],
- extra_check=config_flag("split_cat_fx_passes"),
+ pass_dict=construct_pattern_matcher_pass("normalization_pass"),
+ extra_check=get_config_flag("normalization_pass"),
)
def normalize_squeeze_default(match: Match, *args, **kwargs):
squeeze_node = match.nodes[0]
@@ -435,8 +493,8 @@
),
KeywordArg("next_split_sections"),
),
- pass_dict=PRE_GRAD_PATTERNS["merge_splits_pass"],
- extra_check=config_flag("split_cat_fx_passes"),
+ pass_dict=construct_pattern_matcher_pass("merge_splits_pass"),
+ extra_check=get_config_flag("merge_splits_pass"),
)
def merge_splits(
match: Match,
@@ -1041,8 +1099,8 @@
_users=MULTIPLE,
),
),
- pass_dict=PRE_GRAD_PATTERNS["split_cat_pass"],
- extra_check=config_flag("split_cat_fx_passes"),
+ pass_dict=construct_pattern_matcher_pass("split_cat_pass"),
+ extra_check=get_config_flag("split_cat_pass"),
)
@register_graph_pattern(
RepeatedExpr(
@@ -1059,8 +1117,8 @@
_users=MULTIPLE,
)
),
- pass_dict=PRE_GRAD_PATTERNS["split_cat_pass"],
- extra_check=config_flag("split_cat_fx_passes"),
+ pass_dict=construct_pattern_matcher_pass("split_cat_pass"),
+ extra_check=get_config_flag("split_cat_pass"),
)
def merge_split_squeeze(
match: Match, split_input: torch.fx.Node, split_sizes: List[int], dim: int
@@ -1093,7 +1151,7 @@
graph.erase_node(squeeze)
graph.erase_node(getitem_node)
graph.erase_node(split)
- counters["inductor"]["split_squeeze_replaced"] += 1
+ counters["inductor"]["split_cat_pass"] += 1
getitem_unbind = ListOf(
@@ -1113,22 +1171,22 @@
@register_graph_pattern(
CallFunction([torch.stack, torch.cat], getitem_unbind, Ignored(), _users=MULTIPLE),
- pass_dict=PRE_GRAD_PATTERNS["unbind_stack_pass"],
- extra_check=config_flag("split_cat_fx_passes"),
+ pass_dict=construct_pattern_matcher_pass("unbind_stack_pass"),
+ extra_check=get_config_flag("unbind_stack_pass"),
)
@register_graph_pattern(
CallFunction(
[torch.stack, torch.cat], getitem_unbind, dim=Ignored(), _users=MULTIPLE
),
- pass_dict=PRE_GRAD_PATTERNS["unbind_stack_pass"],
- extra_check=config_flag("split_cat_fx_passes"),
+ pass_dict=construct_pattern_matcher_pass("unbind_stack_pass"),
+ extra_check=get_config_flag("unbind_stack_pass"),
)
@register_graph_pattern(
CallFunction(
[torch.stack, torch.cat], tensors=getitem_unbind, dim=Ignored(), _users=MULTIPLE
),
- pass_dict=PRE_GRAD_PATTERNS["unbind_stack_pass"],
- extra_check=config_flag("split_cat_fx_passes"),
+ pass_dict=construct_pattern_matcher_pass("unbind_stack_pass"),
+ extra_check=get_config_flag("unbind_stack_pass"),
)
def merge_unbind_stack(match: Match, unbind_input: torch.fx.Node, dim: int):
unbind_node = next(node for node in match.nodes if node.target == torch.unbind)
@@ -1156,8 +1214,8 @@
dim=Ignored(),
_users=MULTIPLE,
),
- pass_dict=PRE_GRAD_PATTERNS["split_cat_pass"],
- extra_check=config_flag("split_cat_fx_passes"),
+ pass_dict=construct_pattern_matcher_pass("split_cat_pass"),
+ extra_check=get_config_flag("split_cat_pass"),
)
@register_graph_pattern(
CallFunction(
@@ -1166,8 +1224,8 @@
dim=Ignored(),
_users=MULTIPLE,
),
- pass_dict=PRE_GRAD_PATTERNS["split_cat_pass"],
- extra_check=config_flag("split_cat_fx_passes"),
+ pass_dict=construct_pattern_matcher_pass("split_cat_pass"),
+ extra_check=get_config_flag("split_cat_pass"),
)
@register_graph_pattern(
CallFunction(
@@ -1176,8 +1234,8 @@
Ignored(),
_users=MULTIPLE,
),
- pass_dict=PRE_GRAD_PATTERNS["split_cat_pass"],
- extra_check=config_flag("split_cat_fx_passes"),
+ pass_dict=construct_pattern_matcher_pass("split_cat_pass"),
+ extra_check=get_config_flag("split_cat_pass"),
)
def simplify_split_cat(match: Match, split_sections: List[int], dim: int):
if not isinstance(split_sections, (list, tuple)): # Unnormalized split
@@ -1261,8 +1319,8 @@
dim=Ignored(),
_users=MULTIPLE,
),
- pass_dict=PRE_GRAD_PATTERNS["merge_getitem_cat_pass"],
- extra_check=config_flag("split_cat_fx_passes"),
+ pass_dict=construct_pattern_matcher_pass("merge_getitem_cat_pass"),
+ extra_check=get_config_flag("merge_getitem_cat_pass"),
)
def merge_getitem_cat(match: Match, split_sections: List[int], dim: int):
if not isinstance(split_sections, (list, tuple)): # Unnormalized split
@@ -1369,8 +1427,8 @@
dim=Ignored(),
_users=MULTIPLE,
),
- pass_dict=PRE_GRAD_PATTERNS["mutate_cat_pass"],
- extra_check=config_flag("split_cat_fx_passes"),
+ pass_dict=construct_pattern_matcher_pass("mutate_cat_pass"),
+ extra_check=get_config_flag("mutate_cat_pass"),
)
def mutate_cat_node(match: Match, split_sections: List[int], dim: int):
if not isinstance(split_sections, (list, tuple)): # Unnormalized split
@@ -1465,8 +1523,8 @@
dim=Ignored(),
),
),
- pass_dict=PRE_GRAD_PATTERNS["merge_getitem_cat_pass"],
- extra_check=config_flag("split_cat_fx_passes"),
+ pass_dict=construct_pattern_matcher_pass("merge_stack_tahn_unbind_pass"),
+ extra_check=get_config_flag("merge_stack_tahn_unbind_pass"),
)
@register_graph_pattern(
CallFunction(
@@ -1477,8 +1535,8 @@
dim=Ignored(),
),
),
- pass_dict=PRE_GRAD_PATTERNS["merge_getitem_cat_pass"],
- extra_check=config_flag("split_cat_fx_passes"),
+ pass_dict=construct_pattern_matcher_pass("merge_stack_tahn_unbind_pass"),
+ extra_check=get_config_flag("merge_stack_tahn_unbind_pass"),
)
@register_graph_pattern(
CallFunction(
@@ -1489,8 +1547,8 @@
Ignored(),
),
),
- pass_dict=PRE_GRAD_PATTERNS["merge_stack_tahn_unbind_pass"],
- extra_check=config_flag("split_cat_fx_passes"),
+ pass_dict=construct_pattern_matcher_pass("merge_stack_tahn_unbind_pass"),
+ extra_check=get_config_flag("merge_stack_tahn_unbind_pass"),
)
def merge_stack_tahn_unbind(match: Match, split_sections: List[int], dim: int):
if not isinstance(split_sections, (list, tuple)): # Unnormalized split