Add support for call_method patterns (#99782)
Summary: This adds support for CallMethod patterns in pattern_matcher. It also extends the split_cat transforms to normalize tensor.split() method-call (call_method) nodes in the same way as torch.split() calls.
Test Plan: Unit tests (fb + OSS)
Differential Revision: D45195548
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99782
Approved by: https://github.com/jansel
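Example (illustrative sketch, not part of the diff below): registering a CallMethod pattern for Tensor.split(), assuming only the pattern_matcher API added in this PR; the variable name split_method_pattern is hypothetical. For a call_method node, args[0] is the tensor the method is called on, so the first Arg() captures the input tensor.

    from torch._inductor.pattern_matcher import Arg, CallMethod, MULTIPLE

    # Matches FX call_method nodes of the form `x.split(split_size, dim)`;
    # the first Arg() binds the input tensor, the remaining Arg()s bind
    # the split size and dimension.
    split_method_pattern = CallMethod("split", Arg(), Arg(), Arg(), _users=MULTIPLE)

As in split_cat.py below, such a pattern is then wrapped in NormalizeSplit(pattern=..., extra_check=...) and registered via pattern.register(patterns).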
diff --git a/test/inductor/test_split_cat_fx_passes.py b/test/inductor/test_split_cat_fx_passes.py
index 0168976..b287875 100644
--- a/test/inductor/test_split_cat_fx_passes.py
+++ b/test/inductor/test_split_cat_fx_passes.py
@@ -36,6 +36,21 @@
def unequal_split(x):
return [torch.relu(s) for s in torch.split(x, 3, 1)]
+ def arg_only_cm(x):
+ return [torch.relu(s) for s in x.split(2, 1)]
+
+ def kwarg1_cm(x):
+ return [torch.relu(s) for s in x.split(2, dim=1)]
+
+ def kwarg2_cm(x):
+ return [torch.relu(s) for s in x.split(split_size=2, dim=1)]
+
+ def multi_split_cm(x):
+ return [s.split(2, 1) for s in x.split(2, 1)]
+
+ def unequal_split_cm(x):
+ return [torch.relu(s) for s in x.split(3, 1)]
+
args = [
torch.randn(2, 32),
]
@@ -47,6 +62,11 @@
(no_replace, 0),
(multi_split, 17),
(unequal_split, 1),
+ (arg_only_cm, 1),
+ (kwarg1_cm, 1),
+ (kwarg2_cm, 1),
+ (multi_split_cm, 17),
+ (unequal_split_cm, 1),
]:
expected = fn(*args)
actual = torch.compile(fn, dynamic=True)(*args)
diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py
index 75a4caa..9adf658 100644
--- a/torch/_inductor/fx_passes/split_cat.py
+++ b/torch/_inductor/fx_passes/split_cat.py
@@ -3,7 +3,14 @@
import torch
from torch._dynamo.utils import counters
-from ..pattern_matcher import Arg, CallFunction, get_arg_value, MULTIPLE, PatternEntry
+from ..pattern_matcher import (
+ Arg,
+ CallFunction,
+ CallMethod,
+ get_arg_value,
+ MULTIPLE,
+ PatternEntry,
+)
log = logging.getLogger(__name__)
@@ -54,6 +61,9 @@
dim=Arg(),
_users=MULTIPLE,
),
+ CallMethod("split", Arg(), Arg(), Arg(), _users=MULTIPLE),
+ CallMethod("split", Arg(), Arg(), dim=Arg(), _users=MULTIPLE),
+ CallMethod("split", Arg(), split_size=Arg(), dim=Arg(), _users=MULTIPLE),
]:
pattern = NormalizeSplit(pattern=pattern, extra_check=lambda arg: True)
pattern.register(patterns)
diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py
index c1b212c..b7924e0 100644
--- a/torch/_inductor/pattern_matcher.py
+++ b/torch/_inductor/pattern_matcher.py
@@ -184,14 +184,18 @@
return Match(self, kwargs={self.name: node}) # matches anything
-class CallFunction(PatternExpr):
+class _BaseNodeMatch(PatternExpr):
"""
- Matches a call_function node in the FX graphs: `fns[i](*args, **kwargs)`
+ Base class for matching a node in a graph
"""
+ op = None
+
def __init__(self, fns, *args, _users=1, **kwargs):
+ if not self.op:
+ raise NotImplementedError("Shouldn't directly use _BaseNodeMatch")
super().__init__()
- fns = [fns] if callable(fns) else list(fns)
+ fns = [fns] if callable(fns) or isinstance(fns, str) else list(fns)
for fn in list(fns):
if isinstance(fn, torch._ops.OpOverloadPacket):
fns.extend([getattr(fn, overload) for overload in fn.overloads()])
@@ -243,7 +247,7 @@
def _match(self, node: torch.fx.Node, ctx: MatchContext):
if (
not isinstance(node, torch.fx.Node)
- or node.op != "call_function"
+ or node.op != self.op
or node.target not in self.fns_set
or len(node.args) != len(self.args)
or len(node.kwargs) != len(self.kwargs)
@@ -297,13 +301,29 @@
for node in other_node.users:
if (
node not in searched
- and node.op == "call_function"
+ and node.op == self.op
and node.target in self.fns_set
):
yield node
searched.add(node)
+class CallFunction(_BaseNodeMatch):
+ """
+ Matches a call_function node in the FX graphs: `fns[i](*args, **kwargs)`
+ """
+
+ op = "call_function"
+
+
+class CallMethod(_BaseNodeMatch):
+ """
+ Matches a call_method node in the FX graphs: `fns[i].method(*args, **kwargs)`
+ """
+
+ op = "call_method"
+
+
class ListOf(PatternExpr):
"""
Matches a repeated pattern
@@ -605,7 +625,10 @@
return 0
count = 0
for node in reversed(graph.nodes):
- if node.op == "call_function" and node.target in self.patterns:
+ if (
+ node.op in ["call_function", "call_method"]
+ and node.target in self.patterns
+ ):
# conservatively not applying pattern for cpu input,
# since some of the patterns induce codegen and split nodes.
# Note: we will only skip cpu compute if disable_cpp_codegen=True