Acyclic partition patch (#86511)

Fixes #86159 and #86108

Refactored the graph partitioner to check for cyclic dependencies on each partition merge, instead of relying on a pre-built dependency map.

The previous implementation did not update the dependencies of existing partitions. When a fusion happens, the updated dependency map needs to be propagated to all nodes in the graph so that every node in a partition shares an identical dependency set; because this propagation was missing, the old implementation failed to identify the cyclic dependency reported in #86159.

The updated implementation instead runs a cycle check on the partitioned graph before merging two partitions, as sketched below.

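As a rough sketch of the idea (not the exact code in this PR; the helper name `has_cycle_after_merge` and the free-standing `merged_nodes` argument are illustrative, and it assumes `torch.fx` nodes where `node.users` iterates over a node's consumers): starting from the edges that leave the would-be merged partition, a DFS that re-enters the merged node set means the fusion would create a cycle, so the merge is rejected.

```python
from typing import Set

from torch.fx.node import Node


def has_cycle_after_merge(merged_nodes: Set[Node]) -> bool:
    # True if fusing `merged_nodes` into a single partition would
    # introduce a cycle in the resulting graph.
    visited: Set[Node] = set()

    def dfs(node: Node) -> bool:
        if node in visited:
            return False
        if node in merged_nodes:
            # A path that leaves the fused group and re-enters it
            # means the fusion would create a cycle.
            return True
        visited.add(node)
        return any(dfs(user) for user in node.users)

    # Start from every edge that leaves the fused group.
    return any(
        dfs(user)
        for node in merged_nodes
        for user in node.users
        if user not in merged_nodes
    )
```

The patch's actual `maybe_merge_partition` adds one more wrinkle: when the DFS hits a node that already belongs to a proposed partition, it traverses the users of every node in that partition, since the proposed partitions have not been fused into the graph yet.
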
- [x] added a Python repro (`TestFXGraphPasses.forward12`) that produces a cyclic dependency after partitioning
- [x] fixed the dependency handling by replacing the pre-built dependency map with a merge-time cycle check

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86511
Approved by: https://github.com/SherlockNoMad
diff --git a/test/test_fx_passes.py b/test/test_fx_passes.py
index 652285f..5e3e939 100644
--- a/test/test_fx_passes.py
+++ b/test/test_fx_passes.py
@@ -160,6 +160,21 @@
         out = torch.stack([add_1, add_2, add_3])
         return out
 
+    @staticmethod
+    def forward12(a, b, c):
+        b0 = a + 1.0
+        c0 = a + 1.5
+        x0 = b0.relu()
+        x1 = c0.relu()
+        b1 = b0 + x1
+        c1 = c0 + 1.2
+        # c2 has dependency on x0 & b0, when we merge {c0, c1, c2}
+        # this dependency should be updated to the fusion group and reflected
+        # on the decision to not fuse b0 & b1, which forms a cyclic dependency in
+        # the new graph
+        c2 = x0 + c0
+        return b1, c2
+
 # A mock OperatorSupport class, where only operator.add is supported
 class MockOperatorSupport(OperatorSupport):
     def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
@@ -173,6 +188,10 @@
         (TestPartitionFunctions.forward1, [["add_7", "add_6"], ["add_5", "add_4", "add_3"], ["add_2", "add_1", "add"]]),
         (TestPartitionFunctions.forward2, [["add_3", "add_2"], ["add_1", "add"]]),
 
+        # 1 horizontal fusion with common producer
+        (TestPartitionFunctions.forward3, [["add_2", "add_1", "add"]]),
+        (TestPartitionFunctions.forward4, [["add_2", "add_1", "add"]]),
+
         # 2 branches cases
         (TestPartitionFunctions.forward5, [["add_1", "add"]]),
         (TestPartitionFunctions.forward6, [["add"]]),
@@ -183,6 +202,9 @@
         (TestPartitionFunctions.forward9, [['add_3', 'add_2', 'add_1', 'add']]),
         (TestPartitionFunctions.forward10, [['add_3', 'add_2', 'add', 'add_1']]),
         (TestPartitionFunctions.forward11, [['add_1'], ['add']]),
+
+        # 4 not necessarily the only valid partition, just to verify that there's no cyclic dependency after partitioning
+        (TestPartitionFunctions.forward12, [["add_2"], ["add_3", "add_4", "add_1"], ["add"]]),
     ])
     def test_partitioner(self, fn, expected_partition):
         traced = symbolic_trace(fn)
@@ -204,24 +226,6 @@
         result = fused_graph(a, b, c)
         torch.testing.assert_close(expected, result)
 
-
-    @parametrize("fn, expected_partition", [
-        # horizontal fusion without a common downstream node, not supported yet
-        (TestPartitionFunctions.forward3, [["add_2", "add_1", "add"]]),
-        # horizontal fusion with a common downstream node, not supported yet
-        (TestPartitionFunctions.forward4, [["add_2", "add_1", "add"]]),
-    ])
-    def test_partitioner_xfail(self, fn, expected_partition):
-        traced = symbolic_trace(fn)
-
-        supported_ops = MockOperatorSupport()
-        partitioner = CapabilityBasedPartitioner(traced, supported_ops, allows_single_node_partition=True)
-        partitions = partitioner.propose_partitions()
-
-        partitions_name = [[node.name for node in partition.nodes] for partition in partitions]
-        with self.assertRaises(Exception):
-            assert len(partitions_name) == len(expected_partition)
-
     @parametrize("partition", [
         [['add', 'add_1'], ['add_5', 'add_6']],
         [['add', 'add_1', 'add_2']],  # vertical fusion
diff --git a/torch/fx/passes/infra/partitioner.py b/torch/fx/passes/infra/partitioner.py
index 18a665b..dc50ea4 100644
--- a/torch/fx/passes/infra/partitioner.py
+++ b/torch/fx/passes/infra/partitioner.py
@@ -1,15 +1,14 @@
-from typing import Dict, List, Set, Iterable, Optional
+from typing import Dict, List, Set, Iterable
 
 from torch.fx.passes.utils.fuser_utils import fuse_by_partitions
-from torch.fx.passes.tools_common import NodeList
 
 from torch.fx.graph_module import GraphModule
 from torch.fx.node import Node, _get_qualified_name
 from torch.fx.passes.operator_support import OperatorSupportBase
 
-from collections import defaultdict
 import logging
 import itertools
+from copy import copy
 
 logging.basicConfig(level=logging.WARNING)
 logger = logging.getLogger(__name__)
@@ -42,137 +41,107 @@
         self.operator_support = operator_support
         self.allows_single_node_partition = allows_single_node_partition
 
-        # map of node to it's upstream dependency nodes
-        # if A is found in dependency_map[B], then B depends on A (or a is an upstream depedency of b)
-        self.dependency_map = self.__build_dependency_map()
-
-    def __build_dependency_map(self) -> Dict[Node, Set[Node]]:
-        dependency_map = defaultdict(set)
-
-        # assumptions: nodes in graph are sorted in topological order
-        for node in self.graph_module.graph.nodes:
-            for input_node in node.all_input_nodes:
-                # add input_node and input_node's upstream dependency
-                dependency_map[node].add(input_node)
-                dependency_map[node].update(dependency_map[input_node])
-
-        return dependency_map
-
-    def __node_depends_on(self, a: Node, b: Node) -> int:
-        # Returns
-        # 1 if b depends on a (,or equivalently a is an upstream depedency of b)
-        # -1 if a depends on b (,or equivalently b is an upstream depedency of a)
-        # 0 if a and b doesn't have dependency between each other
-
-        if a in self.dependency_map[b]:
-            return 1
-        elif b in self.dependency_map[a]:
-            return -1
-        else:
-            return 0
-
-    def __partition_depends_on(self, partition_a: Partition, partition_b: Partition) -> int:
-        # Returns
-        # 1 if b depends on a (,or equivalently a is an upstream depedency of b)
-        # -1 if a depends on b (,or equivalently b is an upstream depedency of a)
-        # 0 if a and b doesn't have dependency between each other
-
-        # TODO: build a cache here to speedup the query
-
-        for node_a in partition_a.nodes:
-            for node_b in partition_b.nodes:
-                dependency = self.__node_depends_on(node_a, node_b)
-                if dependency != 0:
-                    return dependency
-        return 0
-
-    def __get_supported_nodes(self) -> NodeList:
-        logging.debug("Collecting supported nodes...")
-        supported_nodes = []
-        for node in self.graph_module.graph.nodes:
-            if self.operator_support.is_node_supported(dict(self.graph_module.named_modules()), node):
-                supported_nodes.append(node)
-        return supported_nodes
+    def __is_node_supported(self, node: Node) -> bool:
+        # TODO: reject 'getitem' node since they are special cased in partitioning.
+        return self.operator_support.is_node_supported(dict(self.graph_module.named_modules()), node)
 
     def propose_partitions(self) -> List[Partition]:
-        candidates: NodeList = self.__get_supported_nodes()
-
         # assumptions: nodes in candidate list is sorted in topological order
         assignment: Dict[Node, int] = {}   # maping from node to partition_id
         partitions_by_id: Dict[int, Partition] = {}  # mapping from partition_id to partition
         new_partition_id = itertools.count()
 
-        def assign(node: Node, id: Optional[int] = None):
-            # If id is None, remove the node from original assigment
+        # try to merge partition other_id into partition self_id
+        # merge only happens if the resulting graph doesn't contain a cyclic dependency
+        # returns `True` when merge happens, `False` otherwise.
+        def maybe_merge_partition(self_id: int, other_id: int):
+            # merged_nodes is the union of the nodes in the two partitions to be merged
+            merged_nodes = copy(partitions_by_id[self_id].nodes)
+            merged_nodes.update(partitions_by_id[other_id].nodes)
 
-            # node has been assigned before, clean up and re-assign
-            if node in assignment:
-                original_id = assignment[node]
-                del assignment[node]
-                partitions_by_id[original_id].remove_node(node)
-                if partitions_by_id[original_id].size() == 0:
-                    del partitions_by_id[original_id]
+            visited: Set[Node] = set()
 
-            if id is not None:
-                assignment[node] = id
-                if id not in partitions_by_id:
-                    partitions_by_id[id] = Partition(id=id, nodes=[node])
+            def dfs_find_cycle(node):
+                if node in visited:
+                    return False
+                if node in merged_nodes:
+                    return True  # found cycle, return
+
+                visited.add(node)
+                # branching on hitting partition or not
+                if node in assignment:
+                    # Since partitions are not merged into the graph yet, when we
+                    # hit a node in a partition through DFS, we need to
+                    # traverse all nodes in that partition to properly reflect
+                    # the dependencies after the fusion
+                    for p_node in partitions_by_id[assignment[node]].nodes:
+                        for user_node in p_node.users:
+                            if dfs_find_cycle(user_node):
+                                return True
                 else:
-                    partitions_by_id[id].add_node(node)
+                    for user_node in node.users:
+                        if dfs_find_cycle(user_node):
+                            return True
+                return False
+
+            # check if merge would create cyclic dependency.
+            for node in merged_nodes:
+                for user_node in node.users:
+                    if user_node not in merged_nodes and dfs_find_cycle(user_node):
+                        # return false indicating cyclic dependency found and
+                        # merge is aborted
+                        return False
+
+            # no cyclic dependency found, move forward with the merge
+            # updating partition nodes
+            partitions_by_id[self_id].nodes = merged_nodes
+            # updating assignment map
+            for node in partitions_by_id[other_id].nodes:
+                assignment[node] = self_id
+            # delete other partition
+            del partitions_by_id[other_id]
+
+            return True
+
+        def merge_single_node(node: Node, id: int):
+            assert node not in assignment
+
+            assignment[node] = id
+            if id not in partitions_by_id:
+                partitions_by_id[id] = Partition(id=id, nodes=[node])
+            else:
+                partitions_by_id[id].add_node(node)
 
         logging.debug("Proposing partitions...")
 
-        # visit candidates in reversed topological order
-        for node in reversed(candidates):
+        for node in reversed(self.graph_module.graph.nodes):
             # use Dict as an ordered set to ensure deterministic partitioning result, don't care value
-            user_partitions: Dict[Partition, None] = {}
+            merge_candidates: Dict[int, None] = {}
+
+            # Note that a limited form of horizontal fusion is enabled:
+            #   when `node` is not supported, the code below still attempts to fuse the consumers of `node`.
+            #
+            # I don't see a need to add a knob to disable horizontal fusion yet; we can short-cut
+            # the fusion by adding an `else` block here to skip horizontal fusion.
+            if self.__is_node_supported(node) and node not in assignment:
+                partition_id = next(new_partition_id)
+                merge_single_node(node, partition_id)
+                merge_candidates[partition_id] = None
+
             for user_node in node.users:
                 if user_node in assignment:
-                    id = assignment[user_node]
-                    user_partitions[partitions_by_id[id]] = None
-                else:
-                    user_partitions[Partition(nodes=[user_node])] = None
+                    merge_candidates[assignment[user_node]] = None
 
             # Filter out all the partitions that has dependency on other users
             # TODO: find a better way to do this, rather than pair-wise comparision
-            user_partitions_list = list(user_partitions.keys())
-            for i in range(len(user_partitions_list)):
-                for j in range(i + 1, len(user_partitions_list)):
-                    pi = user_partitions_list[i]
-                    pj = user_partitions_list[j]
-                    dependency = self.__partition_depends_on(pi, pj)
-                    if dependency == 1 and pj in user_partitions:
-                        del user_partitions[pj]
-                    elif dependency == -1 and pi in user_partitions:
-                        del user_partitions[pi]
-
-            # We use the following rules for partition assignment:
-            # 1. If none of the candidates has been assigned to a partition, create a new partition
-            # 2. If there is one partition candidate, assign to the partition
-            # 3. If there are more than one partition candidates, assign current node to the first partition and
-            #    merge the other partitions with first partition, since user_partitions doesn't have depedency between
-            #    each other.
-
-            assigned_candidate_partition_ids = [partition.id for partition in user_partitions if partition.id is not None]
-
-            if len(assigned_candidate_partition_ids) == 0:
-                # create a new partition
-                assign(node, next(new_partition_id))
-            elif len(assigned_candidate_partition_ids) == 1:
-                id = assigned_candidate_partition_ids[0]
-                assign(node, id)
-            else:
-                # users are assigned to more than one partition, since user_partitions doesn't have
-                # dependency on each other, they can be fused into a single partition
-                id = assigned_candidate_partition_ids[0]
-                assign(node, id)
-
-                reassignment: Dict[Node, int] = {}
-                for other_id in assigned_candidate_partition_ids[1:]:
-                    for other_node in partitions_by_id[other_id].nodes:
-                        reassignment[other_node] = id
-                for other_node in reassignment:
-                    assign(other_node, id)
+            merge_candidates_list = list(merge_candidates.keys())
+            if len(merge_candidates_list) > 1:
+                self_id = merge_candidates_list[0]
+                for other_id in merge_candidates_list[1:]:
+                    # note: merge partition `other_id` into partition `self_id` if
+                    # it doesn't create a cyclic dependency in the graph; otherwise,
+                    # this is a no-op
+                    maybe_merge_partition(self_id, other_id)
 
         # post processing to re-assign "getitem" nodes into upstream partition
         logger.debug("Reassigning getitem nodes to its producer node's partition...")
@@ -190,9 +159,9 @@
                 id = assignment.get(node, None)     # type: ignore[arg-type]
                 for user in node.users:
                     if assignment.get(user, None) != id:    # type: ignore[arg-type]
-                        nodes_reassignment[user] = id
+                        nodes_reassignment[user] = id  # type: ignore[assignment]
         for node, id in nodes_reassignment.items():
-            assign(node, id)
+            merge_single_node(node, id)
 
         # filter out single node partitions
         if not self.allows_single_node_partition: