Partition modules (#98628)

Added helper functions to match nodes in a graph that were decomposed from their source (a leaf module or a functional op) as a result of dynamo tracing.

`get_source_partitions(graph: torch.fx.Graph, wanted_sources: List[Any]) -> Dict[Any, List[SourcePartition]]`

Args:
* graph: The graph we want to partition
* wanted_sources: List of sources that we want to find decomposed nodes for. A source can be a function (ex. torch.nn.functional.linear) or a leaf module type (ex. torch.nn.Linear)

Returns:
* Dictionary mapping each given source (ex. torch.nn.modules.linear.Linear) to a list of SourcePartitions, each containing the nodes that were decomposed from that source.

```
@dataclass
class SourcePartition():
    # Nodes in a particular partition
    nodes: List[Node]
    # The source these nodes decomposed from
    source: Any
    # Nodes in the graph that are needed as inputs to the partition
    input_nodes: List[Node] = field(default_factory=list)
    # Nodes in the partition that are being used by nodes outside of the partition
    output_nodes: List[Node] = field(default_factory=list)
    # Parameters that are being used
    params: List[str] = field(default_factory=list)
```

Example:

Original:
```
x -> linear -> linear -> relu -> linear
```
Traced graph:
```
graph():
    %arg0 : [#users=1] = placeholder[target=arg0]
    %_param_constant0 : [#users=1] = get_attr[target=_param_constant0]
    %t_default : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%_param_constant0,), kwargs = {})
    %_param_constant1 : [#users=1] = get_attr[target=_param_constant1]
    %addmm_default : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1, %arg0, %t_default), kwargs = {})
    %_param_constant0_1 : [#users=1] = get_attr[target=_param_constant0]
    %t_default_1 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%_param_constant0_1,), kwargs = {})
    %_param_constant1_1 : [#users=1] = get_attr[target=_param_constant1]
    %addmm_default_1 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1_1, %addmm_default, %t_default_1), kwargs = {})
    %relu_default : [#users=1] = call_function[target=torch.ops.aten.relu.default](args = (%addmm_default_1,), kwargs = {})
    %_param_constant2 : [#users=1] = get_attr[target=_param_constant2]
    %t_default_2 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%_param_constant2,), kwargs = {})
    %_param_constant3 : [#users=1] = get_attr[target=_param_constant3]
    %addmm_default_2 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant3, %relu_default, %t_default_2), kwargs = {})
    return [addmm_default_2]
```
Result of `get_source_partitions`:
```
{<class 'torch.nn.modules.linear.Linear'>: [
    SourcePartition(nodes=[_param_constant0, t_default, _param_constant1, addmm_default], source=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[arg0], output_nodes=[addmm_default], params=["_param_constant0", "_param_constant1"]),
    SourcePartition(nodes=[_param_constant0_1, t_default_1, _param_constant1_1, addmm_default_1], source=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[addmm_default], output_nodes=[addmm_default_1], params=["_param_constant0_1", "_param_constant1_1"]),
    SourcePartition(nodes=[_param_constant2, t_default_2, _param_constant3, addmm_default_2], source=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[relu_default], output_nodes=[addmm_default_2], params=["_param_constant2", "_param_constant3"])],

 <class 'torch.nn.modules.activation.ReLU'>: [
    SourcePartition(nodes=[relu_default], source=<class 'torch.nn.modules.activation.ReLU'>, input_nodes=[addmm_default_1], output_nodes=[relu_default], params=[])]}
```
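
For reference, a minimal sketch of how the result above can be produced (mirroring the unit test added in this PR; the toy module and tensor shapes are illustrative):

```
import torch
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(3, 3)
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(3, 5)

    def forward(self, x):
        # linear -> linear -> relu -> linear, as in the example above
        x = self.linear1(x)
        x = self.linear1(x)
        x = self.relu(x)
        return self.linear2(x)

# Export with dynamo so that each node carries the "source_fn" metadata
# that the matcher relies on.
gm, _ = torch._dynamo.export(M(), torch.randn(3, 3), aten_graph=True)

partitions = get_source_partitions(gm.graph, [torch.nn.Linear, torch.nn.ReLU])
assert len(partitions[torch.nn.Linear]) == 3
assert len(partitions[torch.nn.ReLU]) == 1
```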

Also added a helper function to check whether two source partitions are connected:
`check_subgraphs_connected(subgraph1: SourcePartition, subgraph2: SourcePartition) -> bool`
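
For example, continuing the sketch above (the indices follow the assertions in the new test; only the second `Linear` partition feeds directly into the `ReLU` partition):

```
from torch.fx.passes.utils.source_matcher_utils import check_subgraphs_connected

linear_parts = partitions[torch.nn.Linear]
relu_parts = partitions[torch.nn.ReLU]

assert not check_subgraphs_connected(linear_parts[0], relu_parts[0])
assert check_subgraphs_connected(linear_parts[1], relu_parts[0])
assert not check_subgraphs_connected(linear_parts[2], relu_parts[0])
```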

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98628
Approved by: https://github.com/cccclai
diff --git a/test/fx/test_source_matcher_utils.py b/test/fx/test_source_matcher_utils.py
new file mode 100644
index 0000000..dd6ccb7
--- /dev/null
+++ b/test/fx/test_source_matcher_utils.py
@@ -0,0 +1,145 @@
+# Owner(s): ["module: fx"]
+
+import os
+import sys
+import unittest
+
+import torch
+
+pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+sys.path.append(pytorch_test_dir)
+from torch._dynamo.eval_frame import is_dynamo_supported
+from torch.fx.passes.utils.source_matcher_utils import get_source_partitions, check_subgraphs_connected
+from torch.testing._internal.jit_utils import JitTestCase
+
+class TestSourceMatcher(JitTestCase):
+    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
+    def test_module_partitioner_linear_relu_linear(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear1 = torch.nn.Linear(3, 3)
+                self.relu = torch.nn.ReLU()
+                self.linear2 = torch.nn.Linear(3, 5)
+
+            def forward(self, x):
+                x = self.linear1(x)
+                x = self.linear1(x)
+                x = self.relu(x)
+                x = self.linear2(x)
+                return x
+
+        inputs = (torch.randn(3, 3),)
+        gm, _ = torch._dynamo.export(M(), *inputs, aten_graph=True)
+        gm.graph.eliminate_dead_code()
+
+        module_partitions = get_source_partitions(gm.graph, [torch.nn.Linear, torch.nn.ReLU])
+
+        self.assertEqual(len(module_partitions), 2)
+        self.assertEqual(len(module_partitions[torch.nn.Linear]), 3)
+        self.assertEqual(len(module_partitions[torch.nn.ReLU]), 1)
+
+        self.assertFalse(check_subgraphs_connected(module_partitions[torch.nn.Linear][0], module_partitions[torch.nn.ReLU][0]))
+        self.assertTrue(check_subgraphs_connected(module_partitions[torch.nn.Linear][1], module_partitions[torch.nn.ReLU][0]))
+        self.assertFalse(check_subgraphs_connected(module_partitions[torch.nn.Linear][2], module_partitions[torch.nn.ReLU][0]))
+
+    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
+    def test_module_partitioner_conv_relu_maxpool(self):
+        class M(torch.nn.Module):
+            def __init__(self, constant_tensor: torch.Tensor) -> None:
+                super().__init__()
+                self.constant_tensor = constant_tensor
+                self.conv1 = torch.nn.Conv2d(
+                    in_channels=3, out_channels=16, kernel_size=3, padding=1
+                )
+                self.conv2 = torch.nn.Conv2d(
+                    in_channels=16, out_channels=16, kernel_size=3, padding=1
+                )
+                self.conv3 = torch.nn.Conv2d(
+                    in_channels=16, out_channels=16, kernel_size=3, padding=1
+                )
+                self.relu = torch.nn.ReLU()
+                self.maxpool = torch.nn.MaxPool2d(kernel_size=3)
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                a = self.conv1(x)
+                b = self.conv2(a)
+                c = a + self.constant_tensor
+                z = self.conv3(b + c)
+                return self.maxpool(self.relu(z))
+
+        inputs = (torch.randn(1, 3, 256, 256),)
+        gm, _ = torch._dynamo.export(M(torch.ones(1, 16, 256, 256)), *inputs, aten_graph=True)
+        gm.graph.eliminate_dead_code()
+
+        module_partitions = get_source_partitions(gm.graph, [torch.nn.Conv2d, torch.nn.ReLU, torch.nn.MaxPool2d])
+
+        self.assertEqual(len(module_partitions), 3)
+        self.assertEqual(len(module_partitions[torch.nn.Conv2d]), 3)
+        self.assertEqual(len(module_partitions[torch.nn.ReLU]), 1)
+        self.assertEqual(len(module_partitions[torch.nn.MaxPool2d]), 1)
+
+        self.assertFalse(check_subgraphs_connected(module_partitions[torch.nn.Conv2d][0], module_partitions[torch.nn.ReLU][0]))
+        self.assertFalse(check_subgraphs_connected(module_partitions[torch.nn.Conv2d][1], module_partitions[torch.nn.ReLU][0]))
+        self.assertTrue(check_subgraphs_connected(module_partitions[torch.nn.Conv2d][2], module_partitions[torch.nn.ReLU][0]))
+        self.assertFalse(check_subgraphs_connected(module_partitions[torch.nn.MaxPool2d][0], module_partitions[torch.nn.ReLU][0]))
+        self.assertTrue(check_subgraphs_connected(module_partitions[torch.nn.ReLU][0], module_partitions[torch.nn.MaxPool2d][0]))
+
+    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
+    def test_module_partitioner_functional_conv_relu_conv(self):
+        class FunctionalConv2d(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.stride = (1, 1)
+                self.padding = (0, 0)
+                self.dilation = (1, 1)
+                self.groups = 1
+
+            def forward(self, x, weight, bias):
+                return torch.nn.functional.conv2d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = FunctionalConv2d()
+                self.conv2 = FunctionalConv2d()
+
+            def forward(self, x, weight, bias):
+                x = self.conv1(x, weight, bias)
+                x = torch.nn.functional.relu(x)
+                x = self.conv2(x, weight, bias)
+                return x
+
+        inputs = (torch.randn(1, 3, 5, 5), torch.rand(3, 3, 3, 3), torch.rand(3))
+        gm, _ = torch._dynamo.export(M(), *inputs, aten_graph=True)
+        gm.graph.eliminate_dead_code()
+
+        module_partitions = get_source_partitions(gm.graph, [torch.nn.functional.conv2d])
+
+        self.assertEqual(len(module_partitions), 1)
+        self.assertEqual(len(module_partitions[torch.nn.functional.conv2d]), 2)
+
+    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
+    def test_module_partitioner_functional_linear_relu_linear(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, weight, bias):
+                x = torch.nn.functional.linear(x, weight, bias)
+                x = torch.nn.functional.linear(x, weight, bias)
+                x = torch.nn.functional.relu(x)
+                x = torch.nn.functional.linear(x, weight, bias)
+                x = torch.nn.functional.linear(x, weight, bias)
+                x = torch.nn.functional.relu(x)
+                return x
+
+        inputs = (torch.randn(1, 5), torch.rand((5, 5)), torch.zeros(5))
+        gm, _ = torch._dynamo.export(M(), *inputs, aten_graph=True)
+        gm.graph.eliminate_dead_code()
+
+        module_partitions = get_source_partitions(gm.graph, [torch.nn.functional.linear, torch.nn.functional.relu])
+
+        self.assertEqual(len(module_partitions), 2)
+        self.assertEqual(len(module_partitions[torch.nn.functional.linear]), 4)
+        self.assertEqual(len(module_partitions[torch.nn.functional.relu]), 2)
diff --git a/test/test_fx.py b/test/test_fx.py
index 6f8396a..c801756 100644
--- a/test/test_fx.py
+++ b/test/test_fx.py
@@ -46,6 +46,7 @@
 from fx.test_common_passes import TestCommonPass  # noqa: F401
 from fx.test_cse_pass import TestCSEPass  # noqa: F401
 from fx.test_matcher_utils import TestMatcher  # noqa: F401
+from fx.test_source_matcher_utils import TestSourceMatcher  # noqa: F401
 
 from fx.test_gradual_type import AnnotationsTest  # noqa: F401
 from fx.test_gradual_type import TypeCheckerTest  # noqa: F401
diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py
index 72ab96f..13bf8dc 100644
--- a/torch/_dynamo/output_graph.py
+++ b/torch/_dynamo/output_graph.py
@@ -1049,12 +1049,15 @@
             rv.node.meta["nn_module_stack"] = nn_module_stack.copy()
 
         if kind in {"call_function", "call_method"}:
-            rv.node.meta["source_fn"] = target
+            rv.node.meta["source_fn"] = (rv.node.name, target)
         elif kind == "call_module":
             if self.parent is not None:
                 unimplemented("Invoking an nn.Module inside HigherOrderOperator")
             # For modules we store the class
-            rv.node.meta["source_fn"] = rv.node.meta["nn_module_stack"][target][1]
+            rv.node.meta["source_fn"] = (
+                rv.node.name,
+                rv.node.meta["nn_module_stack"][target][1],
+            )
 
         frame_summaries: List[traceback.FrameSummary] = []
         while tx:
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index f56c483..aae816f 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -295,10 +295,10 @@
         sources = []
         for origin in all_origins:
             if origin.op == "call_function" and "source_fn" in origin.meta:
-                if isinstance(origin.meta["source_fn"], str):
-                    sources.append(origin.meta["source_fn"])
+                if isinstance(origin.meta["source_fn"][1], str):
+                    sources.append(origin.meta["source_fn"][1])
                 else:
-                    sources.append(origin.meta["source_fn"].__name__)
+                    sources.append(origin.meta["source_fn"][1].__name__)
         sources = sorted(set(sources))
     elif config.triton.descriptive_names == "inductor_node":
         sources = [
diff --git a/torch/fx/passes/utils/source_matcher_utils.py b/torch/fx/passes/utils/source_matcher_utils.py
new file mode 100644
index 0000000..ba3fd66
--- /dev/null
+++ b/torch/fx/passes/utils/source_matcher_utils.py
@@ -0,0 +1,128 @@
+from dataclasses import dataclass, field
+from torch.fx.graph import Graph
+from torch.fx.node import Node
+from torch.fx._compatibility import compatibility
+from typing import Dict, List, Any, Type
+import logging
+import os
+
+
+__all__ = ['get_source_partitions', 'check_subgraphs_connected']
+
+# Set `PYTORCH_MATCHER_LOGLEVEL=INFO` to see debug logs
+def _init_logger():
+    logger = logging.getLogger(__name__)
+
+    level = os.environ.get('PYTORCH_MATCHER_LOGLEVEL', 'WARNING').upper()
+    logger.setLevel(level)
+    console = logging.StreamHandler()
+    formatter = logging.Formatter("%(filename)s > %(message)s")
+    console.setFormatter(formatter)
+    console.setLevel(level)
+    # add the handlers to the logger
+    logger.addHandler(console)
+    logger.propagate = False
+    return logger
+
+logger = _init_logger()
+
+
+@compatibility(is_backward_compatible=False)
+@dataclass
+class SourcePartition():
+    # Nodes in a particular partition
+    nodes: List[Node]
+
+    # The source these nodes decomposed from
+    source: Any
+
+    # Nodes in the graph that are needed as inputs to the partition
+    input_nodes: List[Node] = field(default_factory=list)
+
+    # Nodes in the partition that are being used by nodes outside of the
+    # partition
+    output_nodes: List[Node] = field(default_factory=list)
+
+    # Parameters that are being used
+    params: List[str] = field(default_factory=list)
+
+
+@compatibility(is_backward_compatible=False)
+def get_source_partitions(
+    graph: Graph,
+    wanted_sources: List[Any]
+) -> Dict[Any, List[SourcePartition]]:
+    """
+    Args:
+        graph: The graph we want to partition
+        wanted_sources: List of sources of nodes that were decomposed from this
+            source. This can be a function (ex. torch.nn.functional.linear) or a
+            leaf module type (ex. torch.nn.Linear).
+
+    Returns:
+        Dictionary mapping sources that were given to a list of SourcePartitions
+        that correspond to the list of nodes that were decomposed from the given
+        source.
+    """
+    modules: Dict[Type, Dict[str, List[Node]]] = {}
+
+    for node in graph.nodes:
+        # The metadata source_fn should contain a tuple of a unique name for the
+        # source, and the source function if the node is decomposed from a
+        # function, or the type of module if the node is decomposed from a leaf
+        # module
+
+        if (source_fn := node.meta.get("source_fn", None)) is None:
+            continue
+
+        if source_fn[1] not in wanted_sources:
+            continue
+
+        diff_modules = modules.setdefault(source_fn[1], {})
+        partition = diff_modules.setdefault(source_fn[0], [])
+        partition.append(node)
+
+    def make_partition(nodes: List[Node], module_type: Type) -> SourcePartition:
+        input_nodes = set()
+        output_nodes = set()
+        params = set()
+        for node in nodes:
+            for arg in node.args:
+                if isinstance(arg, Node) and arg not in nodes:
+                    input_nodes.add(arg)
+
+            if node.op == "get_attr":
+                params.add(node.target)
+
+            for user in node.users.keys():
+                if user not in nodes:
+                    output_nodes.add(node)
+
+        return SourcePartition(
+            nodes,
+            module_type,
+            list(input_nodes),
+            list(output_nodes),
+            list(params),  # type: ignore[arg-type]
+        )
+
+    ret: Dict[Type[Any], List[SourcePartition]] = {}
+    for k, v in modules.items():
+        ret[k] = [make_partition(partition, k) for partition in v.values()]
+
+    return ret
+
+
+@compatibility(is_backward_compatible=False)
+def check_subgraphs_connected(subgraph1: SourcePartition, subgraph2: SourcePartition) -> bool:
+    """
+    Given two subgraphs A and B (in the form of a list of nodes), checks if
+    A has nodes connecting to at least one node in B -- aka there exists a node
+    in B that uses a node in A (not the other way around).
+    """
+
+    for node in reversed(subgraph1.nodes):
+        for user in node.users.keys():
+            if user in subgraph2.nodes:
+                return True
+    return False