Acyclic partition patch (#86511)
Fixes #86159 and #86108
Refactored the graph partitioner to check for cyclic dependencies on each partition merge, instead of relying on a pre-baked dependency map.
The previous implementation failed to update the dependency map for existing partitions: when a fusion happens, the updated dependencies need to be propagated to every node in the graph so that all nodes in a partition share an identical dependency set. Because of this, it missed the cyclic dependency reported in issue #86159.
The updated implementation instead runs a cycle check on the partitioned graph before attempting to merge two partitions.
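In spirit, the check treats each existing partition as a single fused node and runs a DFS from every user that escapes the candidate merged set; if the walk re-enters the set, the merge would create a cycle and is rejected. Below is a minimal, standalone sketch of that idea (not the PR's exact code; the string node names and plain-dict graph representation are illustrative assumptions):

```python
from typing import Dict, List, Set

def merge_creates_cycle(
    users: Dict[str, List[str]],      # node -> consumer nodes (a DAG)
    assignment: Dict[str, int],       # node -> id of its current partition
    partitions: Dict[int, Set[str]],  # partition id -> nodes in that partition
    merged: Set[str],                 # union of the two partitions to be merged
) -> bool:
    visited: Set[str] = set()

    def dfs(node: str) -> bool:
        if node in visited:
            return False
        if node in merged:
            return True  # path left the merged set and re-entered it: cycle
        visited.add(node)
        # a node that already belongs to a partition acts as one fused node,
        # so follow the users of every member of that partition
        group = partitions[assignment[node]] if node in assignment else {node}
        return any(dfs(u) for n in group for u in users.get(n, []))

    # start the walk from every user that lies outside the candidate set
    return any(
        dfs(u)
        for n in merged
        for u in users.get(n, [])
        if u not in merged
    )

# Mirrors forward12: with {c0, c1, c2} already fused, merging {b0, b1} would
# create the cycle {b0, b1} -> x0 -> {c0, c1, c2} -> x1 -> {b0, b1}.
users = {"b0": ["x0", "b1"], "c0": ["x1", "c1", "c2"],
         "x0": ["c2"], "x1": ["b1"], "b1": [], "c1": [], "c2": []}
assignment = {"c0": 0, "c1": 0, "c2": 0}
partitions = {0: {"c0", "c1", "c2"}}
assert merge_creates_cycle(users, assignment, partitions, {"b0", "b1"})
```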
- [x] Added a Python repro (`TestPartitionFunctions.forward12`) that produces a cyclic dependency after partitioning
- [x] Replaced the dependency-map approach with an updated implementation that uses a cycle check
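For reference, here is a sketch of how the new test drives the partitioner end to end; the add-only `OperatorSupport` mirrors the tests' `MockOperatorSupport`, whose exact body is an assumption here:

```python
import operator
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport

# mirrors MockOperatorSupport from the tests: only operator.add is supported
class AddOnlySupport(OperatorSupport):
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return node.op == "call_function" and node.target is operator.add

def forward12(a, b, c):
    b0 = a + 1.0
    c0 = a + 1.5
    x0 = b0.relu()
    x1 = c0.relu()
    b1 = b0 + x1
    c1 = c0 + 1.2
    c2 = x0 + c0
    return b1, c2

traced = symbolic_trace(forward12)
partitioner = CapabilityBasedPartitioner(
    traced, AddOnlySupport(), allows_single_node_partition=True)
partitions = partitioner.propose_partitions()
# no proposed partition should introduce a cycle into the fused graph
print([[node.name for node in partition.nodes] for partition in partitions])
```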
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86511
Approved by: https://github.com/SherlockNoMad
diff --git a/test/test_fx_passes.py b/test/test_fx_passes.py
index 652285f..5e3e939 100644
--- a/test/test_fx_passes.py
+++ b/test/test_fx_passes.py
@@ -160,6 +160,21 @@
out = torch.stack([add_1, add_2, add_3])
return out
+ @staticmethod
+ def forward12(a, b, c):
+ b0 = a + 1.0
+ c0 = a + 1.5
+ x0 = b0.relu()
+ x1 = c0.relu()
+ b1 = b0 + x1
+ c1 = c0 + 1.2
+ # c2 depends on x0 & b0; when we merge {c0, c1, c2},
+ # this dependency should be carried over to the fusion group and reflected
+ # in the decision not to fuse b0 & b1, which would form a cyclic dependency
+ # in the new graph
+ c2 = x0 + c0
+ return b1, c2
+
# A mock OperatorSupport class, where only operator.add is supported
class MockOperatorSupport(OperatorSupport):
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
@@ -173,6 +188,10 @@
(TestPartitionFunctions.forward1, [["add_7", "add_6"], ["add_5", "add_4", "add_3"], ["add_2", "add_1", "add"]]),
(TestPartitionFunctions.forward2, [["add_3", "add_2"], ["add_1", "add"]]),
+ # 1 horizontal fusion with common producer
+ (TestPartitionFunctions.forward3, [["add_2", "add_1", "add"]]),
+ (TestPartitionFunctions.forward4, [["add_2", "add_1", "add"]]),
+
# 2 branches cases
(TestPartitionFunctions.forward5, [["add_1", "add"]]),
(TestPartitionFunctions.forward6, [["add"]]),
@@ -183,6 +202,9 @@
(TestPartitionFunctions.forward9, [['add_3', 'add_2', 'add_1', 'add']]),
(TestPartitionFunctions.forward10, [['add_3', 'add_2', 'add', 'add_1']]),
(TestPartitionFunctions.forward11, [['add_1'], ['add']]),
+
+ # 4 not necessarily the only valid partitioning; this just verifies that there's no cyclic dependency after partitioning
+ (TestPartitionFunctions.forward12, [["add_2"], ["add_3", "add_4", "add_1"], ["add"]]),
])
def test_partitioner(self, fn, expected_partition):
traced = symbolic_trace(fn)
@@ -204,24 +226,6 @@
result = fused_graph(a, b, c)
torch.testing.assert_close(expected, result)
-
- @parametrize("fn, expected_partition", [
- # horizontal fusion without a common downstream node, not supported yet
- (TestPartitionFunctions.forward3, [["add_2", "add_1", "add"]]),
- # horizontal fusion with a common downstream node, not supported yet
- (TestPartitionFunctions.forward4, [["add_2", "add_1", "add"]]),
- ])
- def test_partitioner_xfail(self, fn, expected_partition):
- traced = symbolic_trace(fn)
-
- supported_ops = MockOperatorSupport()
- partitioner = CapabilityBasedPartitioner(traced, supported_ops, allows_single_node_partition=True)
- partitions = partitioner.propose_partitions()
-
- partitions_name = [[node.name for node in partition.nodes] for partition in partitions]
- with self.assertRaises(Exception):
- assert len(partitions_name) == len(expected_partition)
-
@parametrize("partition", [
[['add', 'add_1'], ['add_5', 'add_6']],
[['add', 'add_1', 'add_2']], # vertical fusion
diff --git a/torch/fx/passes/infra/partitioner.py b/torch/fx/passes/infra/partitioner.py
index 18a665b..dc50ea4 100644
--- a/torch/fx/passes/infra/partitioner.py
+++ b/torch/fx/passes/infra/partitioner.py
@@ -1,15 +1,14 @@
-from typing import Dict, List, Set, Iterable, Optional
+from typing import Dict, List, Set, Iterable
from torch.fx.passes.utils.fuser_utils import fuse_by_partitions
-from torch.fx.passes.tools_common import NodeList
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node, _get_qualified_name
from torch.fx.passes.operator_support import OperatorSupportBase
-from collections import defaultdict
import logging
import itertools
+from copy import copy
logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)
@@ -42,137 +41,107 @@
self.operator_support = operator_support
self.allows_single_node_partition = allows_single_node_partition
- # map of node to it's upstream dependency nodes
- # if A is found in dependency_map[B], then B depends on A (or a is an upstream depedency of b)
- self.dependency_map = self.__build_dependency_map()
-
- def __build_dependency_map(self) -> Dict[Node, Set[Node]]:
- dependency_map = defaultdict(set)
-
- # assumptions: nodes in graph are sorted in topological order
- for node in self.graph_module.graph.nodes:
- for input_node in node.all_input_nodes:
- # add input_node and input_node's upstream dependency
- dependency_map[node].add(input_node)
- dependency_map[node].update(dependency_map[input_node])
-
- return dependency_map
-
- def __node_depends_on(self, a: Node, b: Node) -> int:
- # Returns
- # 1 if b depends on a (,or equivalently a is an upstream depedency of b)
- # -1 if a depends on b (,or equivalently b is an upstream depedency of a)
- # 0 if a and b doesn't have dependency between each other
-
- if a in self.dependency_map[b]:
- return 1
- elif b in self.dependency_map[a]:
- return -1
- else:
- return 0
-
- def __partition_depends_on(self, partition_a: Partition, partition_b: Partition) -> int:
- # Returns
- # 1 if b depends on a (,or equivalently a is an upstream depedency of b)
- # -1 if a depends on b (,or equivalently b is an upstream depedency of a)
- # 0 if a and b doesn't have dependency between each other
-
- # TODO: build a cache here to speedup the query
-
- for node_a in partition_a.nodes:
- for node_b in partition_b.nodes:
- dependency = self.__node_depends_on(node_a, node_b)
- if dependency != 0:
- return dependency
- return 0
-
- def __get_supported_nodes(self) -> NodeList:
- logging.debug("Collecting supported nodes...")
- supported_nodes = []
- for node in self.graph_module.graph.nodes:
- if self.operator_support.is_node_supported(dict(self.graph_module.named_modules()), node):
- supported_nodes.append(node)
- return supported_nodes
+ def __is_node_supported(self, node: Node) -> bool:
+ # TODO: reject 'getitem' nodes since they are special-cased in partitioning.
+ return self.operator_support.is_node_supported(dict(self.graph_module.named_modules()), node)
def propose_partitions(self) -> List[Partition]:
- candidates: NodeList = self.__get_supported_nodes()
-
# assumptions: nodes in candidate list is sorted in topological order
assignment: Dict[Node, int] = {} # maping from node to partition_id
partitions_by_id: Dict[int, Partition] = {} # mapping from partition_id to partition
new_partition_id = itertools.count()
- def assign(node: Node, id: Optional[int] = None):
- # If id is None, remove the node from original assigment
+ # try to merge partition `other_id` into partition `self_id`;
+ # the merge only happens if the resulting graph doesn't contain a cyclic dependency.
+ # returns `True` when the merge happens, `False` otherwise.
+ def maybe_merge_partition(self_id: int, other_id: int):
+ # merged_nodes is the union of the nodes in the two partitions to be merged
+ merged_nodes = copy(partitions_by_id[self_id].nodes)
+ merged_nodes.update(partitions_by_id[other_id].nodes)
- # node has been assigned before, clean up and re-assign
- if node in assignment:
- original_id = assignment[node]
- del assignment[node]
- partitions_by_id[original_id].remove_node(node)
- if partitions_by_id[original_id].size() == 0:
- del partitions_by_id[original_id]
+ visited: Set[Node] = set()
- if id is not None:
- assignment[node] = id
- if id not in partitions_by_id:
- partitions_by_id[id] = Partition(id=id, nodes=[node])
+ def dfs_find_cycle(node):
+ if node in visited:
+ return False
+ if node in merged_nodes:
+ return True # found cycle, return
+
+ visited.add(node)
+ # branch on whether the node already belongs to a partition
+ if node in assignment:
+ # Since the partition is not merged into the graph yet, when
+ # we hit one of its nodes through DFS, we need to traverse
+ # all nodes in the partition to properly reflect the
+ # dependencies that will exist after the fusion
+ for p_node in partitions_by_id[assignment[node]].nodes:
+ for user_node in p_node.users:
+ if dfs_find_cycle(user_node):
+ return True
else:
- partitions_by_id[id].add_node(node)
+ for user_node in node.users:
+ if dfs_find_cycle(user_node):
+ return True
+ return False
+
+ # check if the merge would create a cyclic dependency.
+ for node in merged_nodes:
+ for user_node in node.users:
+ if user_node not in merged_nodes and dfs_find_cycle(user_node):
+ # return False, indicating a cyclic dependency was found
+ # and the merge is aborted
+ return False
+
+ # no cyclic dependency found, move forward with the merge
+ # updating partition nodes
+ partitions_by_id[self_id].nodes = merged_nodes
+ # updating assignment map
+ for node in partitions_by_id[other_id].nodes:
+ assignment[node] = self_id
+ # delete other partition
+ del partitions_by_id[other_id]
+
+ return True
+
+ def merge_single_node(node: Node, id: int):
+ assert node not in assignment
+
+ assignment[node] = id
+ if id not in partitions_by_id:
+ partitions_by_id[id] = Partition(id=id, nodes=[node])
+ else:
+ partitions_by_id[id].add_node(node)
logging.debug("Proposing partitions...")
- # visit candidates in reversed topological order
- for node in reversed(candidates):
+ for node in reversed(self.graph_module.graph.nodes):
# use Dict as an ordered set to ensure deterministic partitioning result, don't care value
- user_partitions: Dict[Partition, None] = {}
+ merge_candidates: Dict[int, None] = {}
+
+ # Note that a limited form of horizontal fusion is enabled:
+ # when `node` is not supported, the code below attempts to fuse the consumers of `node`.
+ #
+ # I don't see a need for a knob to disable horizontal fusion yet; it can be
+ # short-circuited by adding an `else` block here to skip horizontal fusion.
+ if self.__is_node_supported(node) and node not in assignment:
+ partition_id = next(new_partition_id)
+ merge_single_node(node, partition_id)
+ merge_candidates[partition_id] = None
+
for user_node in node.users:
if user_node in assignment:
- id = assignment[user_node]
- user_partitions[partitions_by_id[id]] = None
- else:
- user_partitions[Partition(nodes=[user_node])] = None
+ merge_candidates[assignment[user_node]] = None
# Filter out all the partitions that has dependency on other users
# TODO: find a better way to do this, rather than pair-wise comparision
- user_partitions_list = list(user_partitions.keys())
- for i in range(len(user_partitions_list)):
- for j in range(i + 1, len(user_partitions_list)):
- pi = user_partitions_list[i]
- pj = user_partitions_list[j]
- dependency = self.__partition_depends_on(pi, pj)
- if dependency == 1 and pj in user_partitions:
- del user_partitions[pj]
- elif dependency == -1 and pi in user_partitions:
- del user_partitions[pi]
-
- # We use the following rules for partition assignment:
- # 1. If none of the candidates has been assigned to a partition, create a new partition
- # 2. If there is one partition candidate, assign to the partition
- # 3. If there are more than one partition candidates, assign current node to the first partition and
- # merge the other partitions with first partition, since user_partitions doesn't have depedency between
- # each other.
-
- assigned_candidate_partition_ids = [partition.id for partition in user_partitions if partition.id is not None]
-
- if len(assigned_candidate_partition_ids) == 0:
- # create a new partition
- assign(node, next(new_partition_id))
- elif len(assigned_candidate_partition_ids) == 1:
- id = assigned_candidate_partition_ids[0]
- assign(node, id)
- else:
- # users are assigned to more than one partition, since user_partitions doesn't have
- # dependency on each other, they can be fused into a single partition
- id = assigned_candidate_partition_ids[0]
- assign(node, id)
-
- reassignment: Dict[Node, int] = {}
- for other_id in assigned_candidate_partition_ids[1:]:
- for other_node in partitions_by_id[other_id].nodes:
- reassignment[other_node] = id
- for other_node in reassignment:
- assign(other_node, id)
+ merge_candidates_list = list(merge_candidates.keys())
+ if len(merge_candidates_list) > 1:
+ self_id = merge_candidates_list[0]
+ for other_id in merge_candidates_list[1:]:
+ # note: merge partition `other_id` into partition `self_id` if
+ # doing so doesn't create a cyclic dependency in the graph;
+ # otherwise this is a no-op
+ maybe_merge_partition(self_id, other_id)
# post processing to re-assign "getitem" nodes into upstream partition
logger.debug("Reassigning getitem nodes to its producer node's partition...")
@@ -190,9 +159,9 @@
id = assignment.get(node, None) # type: ignore[arg-type]
for user in node.users:
if assignment.get(user, None) != id: # type: ignore[arg-type]
- nodes_reassignment[user] = id
+ nodes_reassignment[user] = id # type: ignore[assignment]
for node, id in nodes_reassignment.items():
- assign(node, id)
+ merge_single_node(node, id)
# filter out single node partitions
if not self.allows_single_node_partition: