Partition modules (#98628)
Added helper functions to match nodes in a graph that were decomposed from their source (a leaf module or a functional op) as a result of dynamo tracing.
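Under the hood, the matcher keys off the `source_fn` metadata that dynamo records on each node; this PR changes that entry from a bare target to a `(unique_name, source)` tuple (see the `output_graph.py` hunk below). A minimal sketch of inspecting it, using the same export flow as the new tests (the `Sequential` module here is just an illustration):
```
import torch

# Assumption for illustration: any small module works; the tests in this PR
# use hand-written nn.Modules exported with aten_graph=True.
mod = torch.nn.Sequential(torch.nn.Linear(3, 3), torch.nn.ReLU())
gm, _ = torch._dynamo.export(mod, torch.randn(3, 3), aten_graph=True)

for node in gm.graph.nodes:
    if (source_fn := node.meta.get("source_fn", None)) is not None:
        # source_fn is now a (unique_name, source) tuple, where source is a
        # function (ex. torch.nn.functional.linear) or a leaf module type
        # (ex. torch.nn.Linear)
        unique_name, source = source_fn
        print(node.name, unique_name, source)
```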
`get_source_partitions(graph: torch.fx.Graph, wanted_sources: List[Any]) -> Dict[Any, List[SourcePartition]]`
Args:
* graph: The graph we want to partition
* wanted_sources: List of sources to match against. A source can be a function (ex. torch.nn.functional.linear) or a leaf module type (ex. torch.nn.Linear)
Returns:
* Dictionary mapping each given source (ex. torch.nn.modules.linear.Linear) to a list of SourcePartitions, one per group of nodes that was decomposed from that source.
```
@dataclass
class SourcePartition():
# Nodes in a particular partition
nodes: List[Node]
# The source these nodes were decomposed from
source: Any
# Nodes in the graph that are needed as inputs to the partition
input_nodes: List[Node] = field(default_factory=list)
# Nodes in the partition that are being used by nodes outside of the partition
output_nodes: List[Node] = field(default_factory=list)
# Parameters that are being used
params: List[str] = field(default_factory=list)
```
Example:
Original:
```
x -> linear -> linear -> relu -> linear
```
Traced graph:
```
graph():
%arg0 : [#users=1] = placeholder[target=arg0]
%_param_constant0 : [#users=1] = get_attr[target=_param_constant0]
%t_default : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%_param_constant0,), kwargs = {})
%_param_constant1 : [#users=1] = get_attr[target=_param_constant1]
%addmm_default : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1, %arg0, %t_default), kwargs = {})
%_param_constant0_1 : [#users=1] = get_attr[target=_param_constant0]
%t_default_1 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%_param_constant0_1,), kwargs = {})
%_param_constant1_1 : [#users=1] = get_attr[target=_param_constant1]
%addmm_default_1 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1_1, %addmm_default, %t_default_1), kwargs = {})
%relu_default : [#users=1] = call_function[target=torch.ops.aten.relu.default](args = (%addmm_default_1,), kwargs = {})
%_param_constant2 : [#users=1] = get_attr[target=_param_constant2]
%t_default_2 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%_param_constant2,), kwargs = {})
%_param_constant3 : [#users=1] = get_attr[target=_param_constant3]
%addmm_default_2 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant3, %relu_default, %t_default_2), kwargs = {})
return [addmm_default_2]
```
Result of `get_source_partitions`:
```
{<class 'torch.nn.modules.linear.Linear'>: [
SourcePartition(nodes=[_param_constant0, t_default, _param_constant1, addmm_default], source=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[arg0], output_nodes=[addmm_default], params=["_param_constant0", "_param_constant1"]),
SourcePartition(nodes=[_param_constant0_1, t_default_1, _param_constant1_1, addmm_default_1], source=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[addmm_default], output_nodes=[addmm_default_1], params=["_param_constant0_1", "_param_constant1_1"]),
SourcePartition(nodes=[_param_constant2, t_default_2, _param_constant3, addmm_default_2], source=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[relu_default], output_nodes=[addmm_default_2], params=["_param_constant2", "_param_constant3"])],
<class 'torch.nn.modules.activation.ReLU'>: [
SourcePartition(nodes=[relu_default], source=<class 'torch.nn.modules.activation.ReLU'>, input_nodes=[addmm_default_1], output_nodes=[relu_default], params=[])]}
```
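For reference, a minimal end-to-end sketch of producing such partitions; the module and export call mirror the tests added in this PR:
```
import torch
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(3, 3)
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(3, 5)

    def forward(self, x):
        # linear -> linear -> relu -> linear, as in the example above
        x = self.linear1(x)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x

gm, _ = torch._dynamo.export(M(), torch.randn(3, 3), aten_graph=True)
gm.graph.eliminate_dead_code()

partitions = get_source_partitions(gm.graph, [torch.nn.Linear, torch.nn.ReLU])
assert len(partitions[torch.nn.Linear]) == 3
assert len(partitions[torch.nn.ReLU]) == 1
```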
Also added a helper function to check whether two source partitions are connected:
`check_subgraphs_connected(subgraph1: SourcePartition, subgraph2: SourcePartition) -> bool`
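Continuing the sketch above (same assumptions), only the second `Linear` partition feeds the `ReLU` partition, i.e. there is a node in the ReLU partition that uses a node in that Linear partition:
```
from torch.fx.passes.utils.source_matcher_utils import check_subgraphs_connected

linears = partitions[torch.nn.Linear]
relus = partitions[torch.nn.ReLU]

assert not check_subgraphs_connected(linears[0], relus[0])  # first linear1 call
assert check_subgraphs_connected(linears[1], relus[0])      # second linear1 call -> relu
assert not check_subgraphs_connected(linears[2], relus[0])  # linear2 comes after relu
```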
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98628
Approved by: https://github.com/cccclai
diff --git a/test/fx/test_source_matcher_utils.py b/test/fx/test_source_matcher_utils.py
new file mode 100644
index 0000000..dd6ccb7
--- /dev/null
+++ b/test/fx/test_source_matcher_utils.py
@@ -0,0 +1,145 @@
+# Owner(s): ["module: fx"]
+
+import os
+import sys
+import unittest
+
+import torch
+
+pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+sys.path.append(pytorch_test_dir)
+from torch._dynamo.eval_frame import is_dynamo_supported
+from torch.fx.passes.utils.source_matcher_utils import get_source_partitions, check_subgraphs_connected
+from torch.testing._internal.jit_utils import JitTestCase
+
+class TestSourceMatcher(JitTestCase):
+ @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
+ def test_module_partitioner_linear_relu_linear(self):
+ class M(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.linear1 = torch.nn.Linear(3, 3)
+ self.relu = torch.nn.ReLU()
+ self.linear2 = torch.nn.Linear(3, 5)
+
+ def forward(self, x):
+ x = self.linear1(x)
+ x = self.linear1(x)
+ x = self.relu(x)
+ x = self.linear2(x)
+ return x
+
+ inputs = (torch.randn(3, 3),)
+ gm, _ = torch._dynamo.export(M(), *inputs, aten_graph=True)
+ gm.graph.eliminate_dead_code()
+
+ module_partitions = get_source_partitions(gm.graph, [torch.nn.Linear, torch.nn.ReLU])
+
+ self.assertEqual(len(module_partitions), 2)
+ self.assertEqual(len(module_partitions[torch.nn.Linear]), 3)
+ self.assertEqual(len(module_partitions[torch.nn.ReLU]), 1)
+
+ self.assertFalse(check_subgraphs_connected(module_partitions[torch.nn.Linear][0], module_partitions[torch.nn.ReLU][0]))
+ self.assertTrue(check_subgraphs_connected(module_partitions[torch.nn.Linear][1], module_partitions[torch.nn.ReLU][0]))
+ self.assertFalse(check_subgraphs_connected(module_partitions[torch.nn.Linear][2], module_partitions[torch.nn.ReLU][0]))
+
+ @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
+ def test_module_partitioner_conv_relu_maxpool(self):
+ class M(torch.nn.Module):
+ def __init__(self, constant_tensor: torch.Tensor) -> None:
+ super().__init__()
+ self.constant_tensor = constant_tensor
+ self.conv1 = torch.nn.Conv2d(
+ in_channels=3, out_channels=16, kernel_size=3, padding=1
+ )
+ self.conv2 = torch.nn.Conv2d(
+ in_channels=16, out_channels=16, kernel_size=3, padding=1
+ )
+ self.conv3 = torch.nn.Conv2d(
+ in_channels=16, out_channels=16, kernel_size=3, padding=1
+ )
+ self.relu = torch.nn.ReLU()
+ self.maxpool = torch.nn.MaxPool2d(kernel_size=3)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ a = self.conv1(x)
+ b = self.conv2(a)
+ c = a + self.constant_tensor
+ z = self.conv3(b + c)
+ return self.maxpool(self.relu(z))
+
+ inputs = (torch.randn(1, 3, 256, 256),)
+ gm, _ = torch._dynamo.export(M(torch.ones(1, 16, 256, 256)), *inputs, aten_graph=True)
+ gm.graph.eliminate_dead_code()
+
+ module_partitions = get_source_partitions(gm.graph, [torch.nn.Conv2d, torch.nn.ReLU, torch.nn.MaxPool2d])
+
+ self.assertEqual(len(module_partitions), 3)
+ self.assertEqual(len(module_partitions[torch.nn.Conv2d]), 3)
+ self.assertEqual(len(module_partitions[torch.nn.ReLU]), 1)
+ self.assertEqual(len(module_partitions[torch.nn.MaxPool2d]), 1)
+
+ self.assertFalse(check_subgraphs_connected(module_partitions[torch.nn.Conv2d][0], module_partitions[torch.nn.ReLU][0]))
+ self.assertFalse(check_subgraphs_connected(module_partitions[torch.nn.Conv2d][1], module_partitions[torch.nn.ReLU][0]))
+ self.assertTrue(check_subgraphs_connected(module_partitions[torch.nn.Conv2d][2], module_partitions[torch.nn.ReLU][0]))
+ self.assertFalse(check_subgraphs_connected(module_partitions[torch.nn.MaxPool2d][0], module_partitions[torch.nn.ReLU][0]))
+ self.assertTrue(check_subgraphs_connected(module_partitions[torch.nn.ReLU][0], module_partitions[torch.nn.MaxPool2d][0]))
+
+ @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
+ def test_module_partitioner_functional_conv_relu_conv(self):
+ class FunctionalConv2d(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.stride = (1, 1)
+ self.padding = (0, 0)
+ self.dilation = (1, 1)
+ self.groups = 1
+
+ def forward(self, x, weight, bias):
+ return torch.nn.functional.conv2d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)
+
+ class M(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.conv1 = FunctionalConv2d()
+ self.conv2 = FunctionalConv2d()
+
+ def forward(self, x, weight, bias):
+ x = self.conv1(x, weight, bias)
+ x = torch.nn.functional.relu(x)
+ x = self.conv2(x, weight, bias)
+ return x
+
+ inputs = (torch.randn(1, 3, 5, 5), torch.rand(3, 3, 3, 3), torch.rand(3))
+ gm, _ = torch._dynamo.export(M(), *inputs, aten_graph=True)
+ gm.graph.eliminate_dead_code()
+
+ module_partitions = get_source_partitions(gm.graph, [torch.nn.functional.conv2d])
+
+ self.assertEqual(len(module_partitions), 1)
+ self.assertEqual(len(module_partitions[torch.nn.functional.conv2d]), 2)
+
+ @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
+ def test_module_partitioner_functional_linear_relu_linear(self):
+ class M(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x, weight, bias):
+ x = torch.nn.functional.linear(x, weight, bias)
+ x = torch.nn.functional.linear(x, weight, bias)
+ x = torch.nn.functional.relu(x)
+ x = torch.nn.functional.linear(x, weight, bias)
+ x = torch.nn.functional.linear(x, weight, bias)
+ x = torch.nn.functional.relu(x)
+ return x
+
+ inputs = (torch.randn(1, 5), torch.rand((5, 5)), torch.zeros(5))
+ gm, _ = torch._dynamo.export(M(), *inputs, aten_graph=True)
+ gm.graph.eliminate_dead_code()
+
+ module_partitions = get_source_partitions(gm.graph, [torch.nn.functional.linear, torch.nn.functional.relu])
+
+ self.assertEqual(len(module_partitions), 2)
+ self.assertEqual(len(module_partitions[torch.nn.functional.linear]), 4)
+ self.assertEqual(len(module_partitions[torch.nn.functional.relu]), 2)
diff --git a/test/test_fx.py b/test/test_fx.py
index 6f8396a..c801756 100644
--- a/test/test_fx.py
+++ b/test/test_fx.py
@@ -46,6 +46,7 @@
from fx.test_common_passes import TestCommonPass # noqa: F401
from fx.test_cse_pass import TestCSEPass # noqa: F401
from fx.test_matcher_utils import TestMatcher # noqa: F401
+from fx.test_source_matcher_utils import TestSourceMatcher # noqa: F401
from fx.test_gradual_type import AnnotationsTest # noqa: F401
from fx.test_gradual_type import TypeCheckerTest # noqa: F401
diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py
index 72ab96f..13bf8dc 100644
--- a/torch/_dynamo/output_graph.py
+++ b/torch/_dynamo/output_graph.py
@@ -1049,12 +1049,15 @@
rv.node.meta["nn_module_stack"] = nn_module_stack.copy()
if kind in {"call_function", "call_method"}:
- rv.node.meta["source_fn"] = target
+ rv.node.meta["source_fn"] = (rv.node.name, target)
elif kind == "call_module":
if self.parent is not None:
unimplemented("Invoking an nn.Module inside HigherOrderOperator")
# For modules we store the class
- rv.node.meta["source_fn"] = rv.node.meta["nn_module_stack"][target][1]
+ rv.node.meta["source_fn"] = (
+ rv.node.name,
+ rv.node.meta["nn_module_stack"][target][1],
+ )
frame_summaries: List[traceback.FrameSummary] = []
while tx:
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index f56c483..aae816f 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -295,10 +295,10 @@
sources = []
for origin in all_origins:
if origin.op == "call_function" and "source_fn" in origin.meta:
- if isinstance(origin.meta["source_fn"], str):
- sources.append(origin.meta["source_fn"])
+ if isinstance(origin.meta["source_fn"][1], str):
+ sources.append(origin.meta["source_fn"][1])
else:
- sources.append(origin.meta["source_fn"].__name__)
+ sources.append(origin.meta["source_fn"][1].__name__)
sources = sorted(set(sources))
elif config.triton.descriptive_names == "inductor_node":
sources = [
diff --git a/torch/fx/passes/utils/source_matcher_utils.py b/torch/fx/passes/utils/source_matcher_utils.py
new file mode 100644
index 0000000..ba3fd66
--- /dev/null
+++ b/torch/fx/passes/utils/source_matcher_utils.py
@@ -0,0 +1,128 @@
+from dataclasses import dataclass, field
+from torch.fx.graph import Graph
+from torch.fx.node import Node
+from torch.fx._compatibility import compatibility
+from typing import Dict, List, Any, Type
+import logging
+import os
+
+
+__all__ = ['get_source_partitions', 'check_subgraphs_connected']
+
+# Set`PYTORCH_MATCHER_LOGLEVEL=INFO` to see debug logs
+def _init_logger():
+ logger = logging.getLogger(__name__)
+
+ level = os.environ.get('PYTORCH_MATCHER_LOGLEVEL', 'WARNING').upper()
+ logger.setLevel(level)
+ console = logging.StreamHandler()
+ formatter = logging.Formatter("%(filename)s > %(message)s")
+ console.setFormatter(formatter)
+ console.setLevel(level)
+ # add the handlers to the logger
+ logger.addHandler(console)
+ logger.propagate = False
+ return logger
+
+logger = _init_logger()
+
+
+@compatibility(is_backward_compatible=False)
+@dataclass
+class SourcePartition():
+ # Nodes in a particular partition
+ nodes: List[Node]
+
+ # The source these nodes decomposed from
+ source: Any
+
+ # Nodes in the graph that are needed as inputs to the partition
+ input_nodes: List[Node] = field(default_factory=list)
+
+ # Nodes in the partition that are being used by nodes outside of the
+ # partition
+ output_nodes: List[Node] = field(default_factory=list)
+
+ # Parameters that are being used
+ params: List[str] = field(default_factory=list)
+
+
+@compatibility(is_backward_compatible=False)
+def get_source_partitions(
+ graph: Graph,
+ wanted_sources: List[Any]
+) -> Dict[Any, List[SourcePartition]]:
+ """
+ Args:
+ graph: The graph we want to partition
+ wanted_sources: List of sources of nodes that were decomposed from this
+ source. This can be a function (ex. torch.nn.functional.linear) or a
+ leaf module type (ex. torch.nn.Linear).
+
+ Returns:
+ Dictionary mapping sources that were given to a list of SourcePartitions
+ that correspond to the list of nodes that were decomposed from the given
+ source.
+ """
+ modules: Dict[Type, Dict[str, List[Node]]] = {}
+
+ for node in graph.nodes:
+ # The metadata source_fn should contain a tuple of a unique name for the
+ # source, and the source function if the node is decomposed from a
+ # function, or the type of module if the node is decomposed from a leaf
+ # module
+
+ if (source_fn := node.meta.get("source_fn", None)) is None:
+ continue
+
+ if source_fn[1] not in wanted_sources:
+ continue
+
+ diff_modules = modules.setdefault(source_fn[1], {})
+ partition = diff_modules.setdefault(source_fn[0], [])
+ partition.append(node)
+
+ def make_partition(nodes: List[Node], module_type: Type) -> SourcePartition:
+ input_nodes = set()
+ output_nodes = set()
+ params = set()
+ for node in nodes:
+ for arg in node.args:
+ if isinstance(arg, Node) and arg not in nodes:
+ input_nodes.add(arg)
+
+ if node.op == "get_attr":
+ params.add(node.target)
+
+ for user in node.users.keys():
+ if user not in nodes:
+ output_nodes.add(node)
+
+ return SourcePartition(
+ nodes,
+ module_type,
+ list(input_nodes),
+ list(output_nodes),
+ list(params), # type: ignore[arg-type]
+ )
+
+ ret: Dict[Type[Any], List[SourcePartition]] = {}
+ for k, v in modules.items():
+ ret[k] = [make_partition(partition, k) for partition in v.values()]
+
+ return ret
+
+
+@compatibility(is_backward_compatible=False)
+def check_subgraphs_connected(subgraph1: SourcePartition, subgraph2: SourcePartition) -> bool:
+ """
+ Given two subgraphs A and B (in the form of a list of nodes), checks if
+ A has nodes connecting to at least one node in B -- aka there exists a node
+ in B that uses a node in A (not the other way around).
+ """
+
+ for node in reversed(subgraph1.nodes):
+ for user in node.users.keys():
+ if user in subgraph2.nodes:
+ return True
+ return False