Some cleanups in pattern matcher (#112101)
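
Deduplicates the repeated expected/actual/counter-assertion boilerplate in
test/inductor/test_pattern_matcher.py behind a common() helper, renames
inference_graph -> fwd_only and training_graph -> joint_fwd_bwd, renames
register_replacement()'s pass_dict argument to pass_dicts, moves
scalar_workaround handling into gen_pattern() so callers no longer append
workaround values to their example inputs, raises in check_fn when a pattern
input is missing from match.kwargs, and adds a noop decomposition for
aten.view when the target size matches the input's shape.

Call sites of register_replacement() change mechanically, e.g. (a sketch, not
a line from this diff):

    register_replacement(
        pattern,
        replacement,
        args,            # tensor inputs only; no appended workaround values
        joint_fwd_bwd,   # was: training_graph
        patterns,        # pass_dicts (was: pass_dict)
        scalar_workaround=workaround,
    )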
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112101
Approved by: https://github.com/eellison
ghstack dependencies: #112093
diff --git a/test/inductor/test_pattern_matcher.py b/test/inductor/test_pattern_matcher.py
index a27be82..d1f7abf 100644
--- a/test/inductor/test_pattern_matcher.py
+++ b/test/inductor/test_pattern_matcher.py
@@ -26,6 +26,24 @@
class TestPaternMatcher(TestCase):
+ def common(self, fn, args, expected_matches, expected_nodes):
+ counters.clear()
+ torch.manual_seed(42)
+ expected = fn(*args)
+ torch.manual_seed(42)
+ actual = torch.compile(fn)(*args)
+ torch.testing.assert_close(actual, expected)
+ if inductor_config.cpp_wrapper:
+ # CPP wrapper runs everything twice, so we'll match the pattern twice
+ expected_matches *= 2
+ expected_nodes *= 2
+
+ self.assertEqual(
+ counters["inductor"]["pattern_matcher_count"], expected_matches
+ )
+ self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], expected_nodes)
+ counters.clear()
+
def test_mm_plus_mm(self):
def fn(a, b, c, d):
return torch.add(torch.mm(a, b), torch.mm(c, d))
@@ -58,12 +76,7 @@
),
]
for args in args_list:
- counters.clear()
- expected = fn(*args)
- actual = torch.compile(fn)(*args)
- torch.testing.assert_close(actual, expected)
- self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
- self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 3)
+ self.common(fn, args, 1, 3)
def _test_fused_int_mm_mul_impl(self, fn, args, fused_int_mm_mul_expected=True):
torch._dynamo.reset()
@@ -467,11 +480,7 @@
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
]
- expected = fn(*args)
- actual = torch.compile(fn)(*args)
- torch.testing.assert_close(actual, expected)
- self.assertEqual(counters["inductor"]["pattern_matcher_count"], 2)
- self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 5)
+ self.common(fn, args, 2, 5)
def test_cat_addmm(self):
def fn(a, b, c):
@@ -489,11 +498,7 @@
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
]
- expected = fn(*args)
- actual = torch.compile(fn)(*args)
- torch.testing.assert_close(actual, expected)
- self.assertEqual(counters["inductor"]["pattern_matcher_count"], 2)
- self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 5)
+ self.common(fn, args, 2, 5)
def test_cat_slice_cat(self):
def check_counter(counter, expected):
@@ -513,17 +518,13 @@
torch.randn(2, 32, device="cuda"),
torch.randn(2, 16, device="cuda"),
]
- expected = fn(*args)
- actual = torch.compile(fn)(*args)
- torch.testing.assert_close(actual, expected)
- check_counter(counters["inductor"]["pattern_matcher_count"], 1)
- check_counter(counters["inductor"]["pattern_matcher_nodes"], 3)
+ self.common(fn, args, 1, 3)
- counters.clear()
args = [
torch.randn(2, 8, device="cuda"),
torch.randn(2, 16, device="cuda"),
]
+ counters.clear()
expected = fn(*args)
actual = torch.compile(fn)(*args)
torch.testing.assert_close(actual, expected)
@@ -539,16 +540,11 @@
slice_2 = torch.ops.aten.slice.Tensor(slice_1, 1, 0, -1)
return torch.ops.aten.cat.default([cat_1, slice_2], 1)
- counters.clear()
args = [
torch.randn(2, 8, device="cuda"),
torch.randn(2, 16, device="cuda"),
]
- expected = fn(*args)
- actual = torch.compile(fn)(*args)
- torch.testing.assert_close(actual, expected)
- check_counter(counters["inductor"]["pattern_matcher_count"], 1)
- check_counter(counters["inductor"]["pattern_matcher_nodes"], 3)
+ self.common(fn, args, 1, 3)
def test_pointless_convert(self):
def fn1(x):
@@ -624,12 +620,7 @@
args = [
torch.randn(2, 32, device="cuda"),
]
- expected = fn(*args)
- actual = torch.compile(fn)(*args)
- torch.testing.assert_close(actual, expected)
- self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
- self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 4)
- counters.clear()
+ self.common(fn, args, 1, 4)
# Not all getitems are passed to cat
def fn(a):
@@ -643,12 +634,7 @@
args = [
torch.randn(2, 32, device="cuda"),
]
- expected = fn(*args)
- actual = torch.compile(fn)(*args)
- torch.testing.assert_close(actual, expected)
- self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0)
- self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0)
- counters.clear()
+ self.common(fn, args, 0, 0)
# Different dimensions (TODO this case should be handled by replacing with a reshape)
def fn(a):
@@ -661,11 +647,7 @@
args = [
torch.randn(2, 32, device="cuda"),
]
- expected = fn(*args)
- actual = torch.compile(fn)(*args)
- torch.testing.assert_close(actual, expected)
- self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0)
- self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0)
+ self.common(fn, args, 0, 0)
# https://github.com/pytorch/pytorch/issues/99686.
def fn(a):
@@ -676,11 +658,7 @@
args = [
torch.randn(1, 8, device="cuda"),
]
- expected = fn(*args)
- actual = torch.compile(fn)(*args)
- torch.testing.assert_close(actual, expected)
- self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0)
- self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0)
+ self.common(fn, args, 0, 0)
def test_cat_splitwithsizes(self):
# good case
@@ -696,12 +674,7 @@
torch.randn(2, 3, device="cuda"),
torch.randn(2, 5, device="cuda"),
]
- expected = fn(*args)
- actual = torch.compile(fn)(*args)
- torch.testing.assert_close(actual, expected)
- self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
- self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 2)
- counters.clear()
+ self.common(fn, args, 1, 2)
# cat node has other users
def fn(a, b, c):
@@ -716,12 +689,7 @@
torch.randn(2, 3, device="cuda"),
torch.randn(2, 5, device="cuda"),
]
- expected = fn(*args)
- actual = torch.compile(fn)(*args)
- torch.testing.assert_close(actual, expected)
- self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0)
- self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0)
- counters.clear()
+ self.common(fn, args, 0, 0)
# cat and split dims are different
def fn(a, b, c):
@@ -736,12 +704,7 @@
torch.randn(10, 3, device="cuda"),
torch.randn(10, 5, device="cuda"),
]
- expected = fn(*args)
- actual = torch.compile(fn)(*args)
- torch.testing.assert_close(actual, expected)
- self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0)
- self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0)
- counters.clear()
+ self.common(fn, args, 0, 0)
# cat and split lengths are different
def fn(a, b, c):
@@ -754,12 +717,7 @@
torch.randn(2, 3, device="cuda"),
torch.randn(2, 5, device="cuda"),
]
- expected = fn(*args)
- actual = torch.compile(fn)(*args)
- torch.testing.assert_close(actual, expected)
- self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0)
- self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0)
- counters.clear()
+ self.common(fn, args, 0, 0)
# cat input sizes and split sizes are different
def fn(a, b, c):
@@ -774,12 +732,7 @@
torch.randn(2, 3, device="cuda"),
torch.randn(2, 5, device="cuda"),
]
- expected = fn(*args)
- actual = torch.compile(fn)(*args)
- torch.testing.assert_close(actual, expected)
- self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0)
- self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0)
- counters.clear()
+ self.common(fn, args, 0, 0)
def test_match_with_mutation(self):
from torch._inductor.pattern_matcher import (
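
Note: unlike the inlined code it replaces, common() also seeds the RNG with
torch.manual_seed(42) before both the eager and the compiled run, and doubles
the expected counter values under inductor_config.cpp_wrapper. A minimal
usage sketch (a hypothetical test, with illustrative counts):

    def test_example(self):
        def fn(a, b):
            return torch.ops.aten.cat.default([a, b], 1)

        args = [
            torch.randn(2, 8, device="cuda"),
            torch.randn(2, 8, device="cuda"),
        ]
        # e.g. one pattern match covering three graph nodes
        self.common(fn, args, 1, 3)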
diff --git a/torch/_inductor/fx_passes/freezing_patterns.py b/torch/_inductor/fx_passes/freezing_patterns.py
index dc0eb65..cafd7aa 100644
--- a/torch/_inductor/fx_passes/freezing_patterns.py
+++ b/torch/_inductor/fx_passes/freezing_patterns.py
@@ -8,8 +8,8 @@
from ..pattern_matcher import (
_return_true,
CallFunction,
+ fwd_only,
Ignored,
- inference_graph,
init_once_fakemode,
KeywordArg,
Match,
@@ -144,7 +144,7 @@
matmul_fuse_pattern,
matmul_replacement,
[val(), val(), val(), val()],
- inference_graph,
+ fwd_only,
pass_patterns[0],
extra_check=check_concat_weights,
exclusive_arg_names=("w1", "w2", "w3"),
@@ -162,7 +162,7 @@
matmul_fuse_pattern_two,
matmul_replacement_two,
[val(), val(), val()],
- inference_graph,
+ fwd_only,
pass_patterns[0],
extra_check=check_concat_weights,
exclusive_arg_names=("w1", "w2"),
@@ -184,7 +184,7 @@
addmm_fuse_pattern_second,
addmm_fuse_replacement_second,
[val() for _ in range(7)],
- inference_graph,
+ fwd_only,
pass_patterns[0],
extra_check=check_concat_weights,
exclusive_arg_names=("w1", "w2", "w3", "b1", "b2", "b3"),
diff --git a/torch/_inductor/fx_passes/fuse_attention.py b/torch/_inductor/fx_passes/fuse_attention.py
index 8001d58..207f645 100644
--- a/torch/_inductor/fx_passes/fuse_attention.py
+++ b/torch/_inductor/fx_passes/fuse_attention.py
@@ -7,9 +7,9 @@
from ..._dynamo.utils import counters
from ..pattern_matcher import (
filter_nodes,
- inference_graph,
+ fwd_only,
+ joint_fwd_bwd,
register_replacement,
- training_graph,
)
log = logging.getLogger(__name__)
@@ -513,7 +513,6 @@
# XXX: when adding a new pattern, re-run `gen_attention_patterns` so the pattern
# gets serialized to a python file and does not require tracing at runtime.
assert isinstance(workaround, dict)
- training_args = [*args, *workaround.values()]
name = pattern.__name__
training_name = (
@@ -522,9 +521,9 @@
yield training_name, {
"search_fn": pattern,
"replace_fn": replacement,
- "example_inputs": training_args,
- "trace_fn": training_graph,
- "pass_dict": patterns,
+ "example_inputs": args,
+ "trace_fn": joint_fwd_bwd,
+ "pass_dicts": patterns,
"extra_check": extra_check,
"scalar_workaround": workaround,
}
@@ -547,8 +546,8 @@
"search_fn": pattern,
"replace_fn": replacement,
"example_inputs": args,
- "trace_fn": inference_graph,
- "pass_dict": patterns,
+ "trace_fn": fwd_only,
+ "pass_dicts": patterns,
"extra_check": extra_check,
"scalar_workaround": workaround,
}
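
Note the removed training_args = [*args, *workaround.values()]: callers no
longer append scalar workaround values to the example inputs, because
gen_pattern() (see the pattern_matcher.py hunk below) now splices them into
the positional inputs by parameter name. The same cleanup applies to
pad_mm.py below.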
diff --git a/torch/_inductor/fx_passes/pad_mm.py b/torch/_inductor/fx_passes/pad_mm.py
index 7f01bdc..3b6d6d1 100644
--- a/torch/_inductor/fx_passes/pad_mm.py
+++ b/torch/_inductor/fx_passes/pad_mm.py
@@ -8,12 +8,7 @@
from torch.utils._mode_utils import no_dispatch
from torch.utils._triton import has_triton
-from ..pattern_matcher import (
- inference_graph,
- Match,
- register_replacement,
- training_graph,
-)
+from ..pattern_matcher import fwd_only, joint_fwd_bwd, Match, register_replacement
aten = torch.ops.aten
@@ -453,12 +448,11 @@
),
]:
assert isinstance(workaround, dict) # mypy is unable to infer the type properly
- args = [*args, *workaround.values()]
register_replacement(
pattern,
replacement,
args,
- training_graph,
+ joint_fwd_bwd,
patterns,
extra_check=extra_check,
scalar_workaround=workaround,
@@ -467,7 +461,7 @@
pattern,
replacement,
args,
- inference_graph,
+ fwd_only,
patterns,
extra_check=extra_check,
scalar_workaround=workaround,
diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py
index 1defdab..7b2797f 100644
--- a/torch/_inductor/fx_passes/post_grad.py
+++ b/torch/_inductor/fx_passes/post_grad.py
@@ -566,6 +566,11 @@
return len(inputs) == 1
+@register_noop_decomp(aten.view)
+def view_noop(arg, size):
+ return arg.shape == size
+
+
# Note, we also always have a check for identical metadata, which is why these
# are safe
@register_noop_decomp([aten.copy], nop_arg=1)
@@ -576,9 +581,7 @@
def remove_noop_ops(graph: torch.fx.Graph):
"""
- Removes aten.clone and aten.alias ops from the graph when it's safe.
-
- Other no-ops should be done as decompositions that selectively turn into aten.clone or aten.alias
+    Removes ops that are essentially aten.clone or aten.alias from the graph.
"""
input_storages = set()
output_storages = set()
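
For illustration, the kind of shape-preserving view the new noop decomp is
meant to eliminate (a hypothetical function; removal is still gated by the
identical-metadata check noted in the comment above):

    def fn(x):
        # a view to x's existing shape; view_noop(arg, size) is True here
        return x.view(x.shape) + 1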
diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py
index c90d178..67fb384 100644
--- a/torch/_inductor/pattern_matcher.py
+++ b/torch/_inductor/pattern_matcher.py
@@ -124,7 +124,7 @@
def replace_by_example(self, replacement_fn, args, trace_fn=None):
assert self.ctx
if trace_fn is None:
- trace_fn = inference_graph
+ trace_fn = fwd_only
replacement = trace_fn(
replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"])
)
@@ -842,7 +842,7 @@
replace_fn,
example_inputs: Iterable[Any],
trace_fn: Callable[[Callable[..., Any], Iterable[Any]], torch.fx.GraphModule],
- pass_dict,
+ pass_dicts,
extra_check=_return_true,
scalar_workaround=(),
exclusive_arg_names=(),
@@ -857,10 +857,11 @@
search_fn: traced to give original pattern
replace_fn: traced to give replacement graph
example_inputs: example inputs for initial trace
- trace_fn: inference_graph or training_graph
+ trace_fn: fwd_only or joint_fwd_bwd
- pass_dicts: dict of passes to register to
+ pass_dicts: dict of passes to register to
extra_check: additional check to run on match (using real shapes)
"""
+ argnames = [*inspect.signature(search_fn).parameters.keys()]
def check_fn(match: Match):
"""
@@ -869,17 +870,24 @@
Recheck the match with the correct shapes.
"""
+ for name in argnames:
+ if name not in match.kwargs:
+ raise RuntimeError(
+ f"Not all inputs to pattern found in match.kwargs. Perhaps one "
+ f"of the inputs is unused? argnames={argnames}, match.kwargs={match.kwargs}"
+ )
+
args = list(
torch.fx.map_arg(
[match.kwargs[name] for name in argnames], lambda n: n.meta["val"] # type: ignore[has-type]
)
)
- for i, grad in enumerate(requires_grad):
- if isinstance(args[i], torch.Tensor):
- if grad and is_integer_dtype(args[i].dtype):
- return False
+ with torch._dynamo.utils.detect_fake_mode(args):
+ for i, grad in enumerate(requires_grad):
+ if isinstance(args[i], torch.Tensor):
+ if grad and is_integer_dtype(args[i].dtype):
+ return False
- with torch._dynamo.utils.detect_fake_mode(args):
args[i] = torch.empty_strided(
args[i].size(),
args[i].stride(),
@@ -887,27 +895,32 @@
device=args[i].device,
requires_grad=grad,
)
- specific_graph = trace_fn(search_fn, args)
- specific_pattern = fx_to_pattern(
- specific_graph, argnames=argnames, exclusive_arg_names=exclusive_arg_names # type: ignore[has-type]
- )
- specific_pattern_match = specific_pattern.match(match.output_nodes()[0])
- if specific_pattern_match and extra_check(specific_pattern_match):
- # trace the pattern using the shapes form the user program
- match.replacement_graph = trace_fn(replace_fn, args)
- return True
- return False
+ specific_graph = trace_fn(search_fn, args)
+ specific_pattern = fx_to_pattern(
+ specific_graph,
+ argnames=argnames,
+ exclusive_arg_names=exclusive_arg_names, # type: ignore[has-type]
+ scalar_workaround=scalar_workaround,
+ )
+ specific_pattern_match = specific_pattern.match(match.output_nodes()[0])
+ if specific_pattern_match and extra_check(specific_pattern_match):
+ # trace the pattern using the shapes from the user program
+ match.replacement_graph = trace_fn(replace_fn, args)
+ return True
+ return False
def normalize_args(**kwargs):
args = []
for name in argnames: # type: ignore[has-type]
args.append(kwargs.pop(name))
for i in range(1, len(kwargs) + 1):
+ if f"tangents_{i}" not in kwargs:
+ break
args.append(kwargs.pop(f"tangents_{i}"))
assert not kwargs, f"leftover kwargs: {kwargs!r}"
return args
- if trace_fn is training_graph:
+ if trace_fn is joint_fwd_bwd:
# If inference mode is enabled during compilation, assume that we don't
# want to match on any training graph patterns
if torch.is_inference_mode_enabled():
@@ -915,7 +928,6 @@
# TODO: Revisit the functionalize_rng_ops for lowmem dropout
with functorch_config.patch(functionalize_rng_ops=False):
- argnames = [*inspect.signature(search_fn).parameters.keys()]
requires_grad: List[bool] = [
isinstance(x, torch.Tensor) and x.requires_grad for x in example_inputs
]
@@ -938,7 +950,7 @@
extra_check=check_fn,
normalize_args=normalize_args,
)
- pattern.register(pass_dict)
+ pattern.register(pass_dicts)
return pattern.pattern
@@ -947,7 +959,20 @@
search_fn, example_inputs, trace_fn, scalar_workaround=(), exclusive_arg_names=()
) -> PatternExpr:
argnames = [*inspect.signature(search_fn).parameters.keys()]
- search_gm = trace_fn(search_fn, example_inputs)
+
+ if scalar_workaround == ():
+ scalar_workaround = {}
+ flat_inputs = []
+ input_idx = 0 # Positional arguments index
+
+ for argname in argnames:
+ if argname in scalar_workaround:
+ flat_inputs.append(scalar_workaround[argname])
+ else:
+ flat_inputs.append(example_inputs[input_idx])
+ input_idx += 1
+
+ search_gm = trace_fn(search_fn, flat_inputs)
return fx_to_pattern(
search_gm,
ignore_types=(int, float, list, torch.device, torch.dtype),
@@ -1175,7 +1200,7 @@
@torch.no_grad()
-def inference_graph(fn, args) -> torch.fx.GraphModule:
+def fwd_only(fn, args) -> torch.fx.GraphModule:
"""Build a normalized inference graph, for use with fx_to_pattern"""
# TODO - look into using aot autograd, asserting no mutating ops here
with enable_python_dispatcher():
@@ -1186,7 +1211,7 @@
@torch.enable_grad()
-def training_graph(fn, args) -> torch.fx.GraphModule:
+def joint_fwd_bwd(fn, args) -> torch.fx.GraphModule:
"""Build a normalized training graph, for use with fx_to_pattern"""
gm: Optional[torch.fx.GraphModule] = None
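
With scalar_workaround handling now inside gen_pattern(), a search function
that takes a scalar no longer needs that scalar appended to example_inputs;
it is looked up by parameter name instead. A minimal sketch (hypothetical
search function):

    def search_fn(x, y, scale):
        return (x @ y) * scale

    pattern = gen_pattern(
        search_fn,
        [torch.randn(4, 4), torch.randn(4, 4)],  # tensors only, in order
        fwd_only,
        scalar_workaround={"scale": 0.5},  # spliced in at the 'scale' slot
    )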