Add support for call_method patterns (#99782)

Summary: This add support for CallMethod patterns in pattern_matcher. Also extends split_cat transforms to normalize tensor.split() type nodes

Test Plan: Unit tests (fb + OSS)

Differential Revision: D45195548

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99782
Approved by: https://github.com/jansel
diff --git a/test/inductor/test_split_cat_fx_passes.py b/test/inductor/test_split_cat_fx_passes.py
index 0168976..b287875 100644
--- a/test/inductor/test_split_cat_fx_passes.py
+++ b/test/inductor/test_split_cat_fx_passes.py
@@ -36,6 +36,21 @@
         def unequal_split(x):
             return [torch.relu(s) for s in torch.split(x, 3, 1)]
 
+        def arg_only_cm(x):
+            return [torch.relu(s) for s in x.split(2, 1)]
+
+        def kwarg1_cm(x):
+            return [torch.relu(s) for s in x.split(2, dim=1)]
+
+        def kwarg2_cm(x):
+            return [torch.relu(s) for s in x.split(split_size=2, dim=1)]
+
+        def multi_split_cm(x):
+            return [s.split(2, 1) for s in x.split(2, 1)]
+
+        def unequal_split_cm(x):
+            return [torch.relu(s) for s in x.split(3, 1)]
+
         args = [
             torch.randn(2, 32),
         ]
@@ -47,6 +62,11 @@
             (no_replace, 0),
             (multi_split, 17),
             (unequal_split, 1),
+            (arg_only_cm, 1),
+            (kwarg1_cm, 1),
+            (kwarg2_cm, 1),
+            (multi_split_cm, 17),
+            (unequal_split_cm, 1),
         ]:
             expected = fn(*args)
             actual = torch.compile(fn, dynamic=True)(*args)
diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py
index 75a4caa..9adf658 100644
--- a/torch/_inductor/fx_passes/split_cat.py
+++ b/torch/_inductor/fx_passes/split_cat.py
@@ -3,7 +3,14 @@
 
 import torch
 from torch._dynamo.utils import counters
-from ..pattern_matcher import Arg, CallFunction, get_arg_value, MULTIPLE, PatternEntry
+from ..pattern_matcher import (
+    Arg,
+    CallFunction,
+    CallMethod,
+    get_arg_value,
+    MULTIPLE,
+    PatternEntry,
+)
 
 log = logging.getLogger(__name__)
 
@@ -54,6 +61,9 @@
             dim=Arg(),
             _users=MULTIPLE,
         ),
+        CallMethod("split", Arg(), Arg(), Arg(), _users=MULTIPLE),
+        CallMethod("split", Arg(), Arg(), dim=Arg(), _users=MULTIPLE),
+        CallMethod("split", Arg(), split_size=Arg(), dim=Arg(), _users=MULTIPLE),
     ]:
         pattern = NormalizeSplit(pattern=pattern, extra_check=lambda arg: True)
         pattern.register(patterns)
diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py
index c1b212c..b7924e0 100644
--- a/torch/_inductor/pattern_matcher.py
+++ b/torch/_inductor/pattern_matcher.py
@@ -184,14 +184,18 @@
         return Match(self, kwargs={self.name: node})  # matches anything
 
 
-class CallFunction(PatternExpr):
+class _BaseNodeMatch(PatternExpr):
     """
-    Matches a call_function node in the FX graphs: `fns[i](*args, **kwargs)`
+    Base class for matching a node in a graph
     """
 
+    op = None
+
     def __init__(self, fns, *args, _users=1, **kwargs):
+        if not self.op:
+            raise NotImplementedError("Shouldn't directly use _BaseNodeMatch")
         super().__init__()
-        fns = [fns] if callable(fns) else list(fns)
+        fns = [fns] if callable(fns) or isinstance(fns, str) else list(fns)
         for fn in list(fns):
             if isinstance(fn, torch._ops.OpOverloadPacket):
                 fns.extend([getattr(fn, overload) for overload in fn.overloads()])
@@ -243,7 +247,7 @@
     def _match(self, node: torch.fx.Node, ctx: MatchContext):
         if (
             not isinstance(node, torch.fx.Node)
-            or node.op != "call_function"
+            or node.op != self.op
             or node.target not in self.fns_set
             or len(node.args) != len(self.args)
             or len(node.kwargs) != len(self.kwargs)
@@ -297,13 +301,29 @@
                     for node in other_node.users:
                         if (
                             node not in searched
-                            and node.op == "call_function"
+                            and node.op == self.op
                             and node.target in self.fns_set
                         ):
                             yield node
                             searched.add(node)
 
 
+class CallFunction(_BaseNodeMatch):
+    """
+    Matches a call_function node in the FX graphs: `fns[i](*args, **kwargs)`
+    """
+
+    op = "call_function"
+
+
+class CallMethod(_BaseNodeMatch):
+    """
+    Matches a call_method node in the FX graphs: `fns[i].method(*args, **kwargs)`
+    """
+
+    op = "call_method"
+
+
 class ListOf(PatternExpr):
     """
     Matches a repeated pattern
@@ -605,7 +625,10 @@
             return 0
         count = 0
         for node in reversed(graph.nodes):
-            if node.op == "call_function" and node.target in self.patterns:
+            if (
+                node.op in ["call_function", "call_method"]
+                and node.target in self.patterns
+            ):
                 # conservatively not applying pattern for cpu input,
                 # since some of the patterns induce codegen and split nodes.
                 # Note: we will only skip cpu compute if disable_cpp_codegen=True