| from typing import Dict, List, Set, Iterable, Optional |
| |
| from torch.fx.passes.utils.fuser_utils import fuse_by_partitions |
| from torch.fx.passes.tools_common import NodeList |
| |
| from torch.fx.graph_module import GraphModule |
| from torch.fx.node import Node, _get_qualified_name |
| from torch.fx.passes.operator_support import OperatorSupportBase |
| |
| from collections import defaultdict |
| import logging |
| import itertools |
| |
# NOTE(review): basicConfig() configures the *root* logger as a side effect of importing this
# module. Libraries normally leave logging configuration to the application - confirm that
# this is intended before relying on it.
logging.basicConfig(level=logging.WARNING)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
| |
class Partition:
    """A mutable group of graph nodes tagged with an optional partition id.

    No ``__eq__``/``__hash__`` is defined, so instances compare and hash by identity.

    Args:
        id: Identifier of the partition, or ``None`` when the partition has no id.
        nodes: Initial members. They are copied into a new set; ``None`` yields an
            empty partition.
    """

    def __init__(self, id: Optional[int] = None, nodes: Optional[Iterable[Node]] = None):
        self.id = id
        # Always build a fresh set so add_node/remove_node never mutate the caller's iterable.
        self.nodes: Set[Node] = set(nodes) if nodes is not None else set()

    def __repr__(self) -> str:
        # The representation is simply that of the underlying node set.
        return str(self.nodes)

    def add_node(self, node: Node) -> None:
        """Add ``node`` to the partition; adding an existing member is a no-op."""
        self.nodes.add(node)

    def remove_node(self, node: Node) -> None:
        """Remove ``node`` from the partition.

        Raises:
            KeyError: if ``node`` is not a member of the partition.
        """
        self.nodes.remove(node)

    def size(self) -> int:
        """Return the number of nodes currently in the partition."""
        return len(self.nodes)
| |
class CapabilityBasedPartitioner:
    """Propose and fuse partitions of operator-supported nodes in an FX graph.

    Nodes accepted by ``operator_support.is_node_supported`` are grouped into
    ``Partition`` objects by ``propose_partitions``; ``fuse_partitions`` hands the
    proposed partitions to ``fuse_by_partitions``.

    Args:
        graph_module: Graph module whose ``graph.nodes`` are partitioned. The nodes are
            assumed to be sorted in topological order.
        operator_support: Decides, node by node, whether a node may join a partition.
        allows_single_node_partition: When ``False`` (the default), proposed partitions
            that contain at most one compute node are dropped.
    """

    def __init__(self,
                 graph_module: GraphModule,
                 operator_support: OperatorSupportBase,
                 allows_single_node_partition: bool = False
                 ) -> None:
        self.graph_module = graph_module
        self.operator_support = operator_support
        self.allows_single_node_partition = allows_single_node_partition

        # map of node to its upstream dependency nodes
        # if A is found in dependency_map[B], then B depends on A (A is an upstream dependency of B)
        self.dependency_map = self.__build_dependency_map()

    def __build_dependency_map(self) -> Dict[Node, Set[Node]]:
        """Map each node to the set of nodes it (transitively) depends on."""
        dependency_map: Dict[Node, Set[Node]] = defaultdict(set)

        # assumption: nodes in the graph are sorted in topological order, so the entry of an
        # input node is already complete when it is merged into the entry of its user
        for node in self.graph_module.graph.nodes:
            for input_node in node.all_input_nodes:
                # add input_node and input_node's upstream dependencies
                dependency_map[node].add(input_node)
                dependency_map[node].update(dependency_map[input_node])

        return dependency_map

    def __node_depends_on(self, a: Node, b: Node) -> int:
        """Compare two nodes by dependency.

        Returns:
            1 if b depends on a (equivalently, a is an upstream dependency of b),
            -1 if a depends on b (equivalently, b is an upstream dependency of a),
            0 if a and b have no dependency on each other.
        """
        if a in self.dependency_map[b]:
            return 1
        elif b in self.dependency_map[a]:
            return -1
        else:
            return 0

    def __partition_depends_on(self, partition_a: Partition, partition_b: Partition) -> int:
        """Compare two partitions by dependency.

        The result is the first non-zero node-level dependency found between a node of
        ``partition_a`` and a node of ``partition_b``.

        Returns:
            1 if b depends on a (equivalently, a is an upstream dependency of b),
            -1 if a depends on b (equivalently, b is an upstream dependency of a),
            0 if a and b have no dependency on each other.
        """
        # TODO: build a cache here to speed up the query

        for node_a in partition_a.nodes:
            for node_b in partition_b.nodes:
                dependency = self.__node_depends_on(node_a, node_b)
                if dependency != 0:
                    return dependency
        return 0

    def __get_supported_nodes(self) -> NodeList:
        """Return the graph nodes accepted by ``operator_support``, in graph order."""
        logger.debug("Collecting supported nodes...")
        # named_modules() does not depend on the node being tested, so build the lookup once
        # instead of once per node.
        submodules = dict(self.graph_module.named_modules())
        return [
            node for node in self.graph_module.graph.nodes
            if self.operator_support.is_node_supported(submodules, node)
        ]

    def propose_partitions(self) -> List[Partition]:
        """Group the supported nodes into partitions.

        Supported nodes are visited in reverse graph order and merged into the partitions
        of their users. Afterwards ``getitem`` users are moved into their producer's
        partition and, unless ``allows_single_node_partition`` is set, partitions with at
        most one compute node are discarded.

        Returns:
            The proposed partitions.
        """
        candidates: NodeList = self.__get_supported_nodes()

        # assumption: nodes in the candidate list are sorted in topological order
        assignment: Dict[Node, int] = {}   # mapping from node to partition_id
        partitions_by_id: Dict[int, Partition] = {}  # mapping from partition_id to partition
        new_partition_id = itertools.count()

        def assign(node: Node, partition_id: Optional[int] = None) -> None:
            # If partition_id is None, the node is only removed from its original assignment.

            # node has been assigned before, clean up and re-assign
            if node in assignment:
                original_id = assignment.pop(node)
                original_partition = partitions_by_id[original_id]
                original_partition.remove_node(node)
                # drop partitions that became empty
                if original_partition.size() == 0:
                    del partitions_by_id[original_id]

            if partition_id is not None:
                assignment[node] = partition_id
                if partition_id not in partitions_by_id:
                    partitions_by_id[partition_id] = Partition(id=partition_id, nodes=[node])
                else:
                    partitions_by_id[partition_id].add_node(node)

        logger.debug("Proposing partitions...")

        # visit candidates in reversed topological order
        for node in reversed(candidates):
            # use Dict as an ordered set to ensure a deterministic partitioning result;
            # the values are irrelevant
            user_partitions: Dict[Partition, None] = {}
            for user_node in node.users:
                if user_node in assignment:
                    user_partitions[partitions_by_id[assignment[user_node]]] = None
                else:
                    # unassigned user: wrap it in a temporary partition without an id
                    user_partitions[Partition(nodes=[user_node])] = None

            # Filter out all the partitions that have a dependency on other users
            # TODO: find a better way to do this, rather than pair-wise comparison
            # The pairs are taken from a snapshot, so entries can be deleted while iterating.
            for pi, pj in itertools.combinations(list(user_partitions), 2):
                dependency = self.__partition_depends_on(pi, pj)
                if dependency == 1 and pj in user_partitions:
                    del user_partitions[pj]
                elif dependency == -1 and pi in user_partitions:
                    del user_partitions[pi]

            # We use the following rules for partition assignment:
            # 1. If none of the candidates has been assigned to a partition, create a new partition
            # 2. If there is one partition candidate, assign to the partition
            # 3. If there is more than one partition candidate, assign the current node to the first
            #    partition and merge the other partitions into it, since the remaining
            #    user_partitions have no dependency on each other.
            assigned_candidate_partition_ids = [
                partition.id for partition in user_partitions if partition.id is not None
            ]

            if not assigned_candidate_partition_ids:
                # create a new partition
                assign(node, next(new_partition_id))
            else:
                target_id = assigned_candidate_partition_ids[0]
                assign(node, target_id)

                # Collect first, then move: assign() mutates the partitions being iterated.
                # With a single candidate partition there is nothing to merge.
                nodes_to_move = [
                    other_node
                    for other_id in assigned_candidate_partition_ids[1:]
                    for other_node in partitions_by_id[other_id].nodes
                ]
                for other_node in nodes_to_move:
                    assign(other_node, target_id)

        # post processing to re-assign "getitem" nodes into the upstream partition
        logger.debug("Reassigning getitem nodes to its producer node's partition...")
        nodes_reassignment: Dict[Node, Optional[int]] = {}
        for node in self.graph_module.graph.nodes:
            # a node is treated as having tuple output when every user is a getitem call
            is_tuple_output = all(
                user.op == "call_function"
                and _get_qualified_name(user.target) == "_operator.getitem"  # type: ignore[arg-type]
                for user in node.users
            )

            # node has tuple outputs, re-assign all following getitem nodes into node's partition
            if is_tuple_output:
                producer_id = assignment.get(node, None)  # type: ignore[arg-type]
                for user in node.users:
                    if assignment.get(user, None) != producer_id:  # type: ignore[arg-type]
                        # producer_id may be None, which removes the user from its partition
                        nodes_reassignment[user] = producer_id
        for node, partition_id in nodes_reassignment.items():
            assign(node, partition_id)

        # filter out single node partitions
        if not self.allows_single_node_partition:
            logger.debug("Filtering out single node partitions...")
            non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"}
            partitions_to_remove: List[int] = []
            for partition_id, partition in partitions_by_id.items():
                compute_node_count = sum(
                    1 for node in partition.nodes
                    if node.op == "call_function"
                    and _get_qualified_name(node.target) not in non_compute_ops  # type: ignore[arg-type]
                )
                if compute_node_count <= 1:
                    partitions_to_remove.append(partition_id)
            for partition_id in partitions_to_remove:
                del partitions_by_id[partition_id]

        logger.debug("Partitions proposed:")
        # Skip building the name lists entirely when debug output is disabled.
        if logger.isEnabledFor(logging.DEBUG):
            for partition_id, partition in partitions_by_id.items():
                logger.debug("partition #%s: %s", partition_id, [node.name for node in partition.nodes])

        return list(partitions_by_id.values())

    def fuse_partitions(self, partitions: List[Partition]) -> GraphModule:
        """Pass ``partitions`` to ``fuse_by_partitions`` and return its result."""
        logger.debug("Fusing partitions...")
        # fuse_by_partitions expects partitions in List[List[Node]]: [ [node0, node1], [node2, node3] ]
        return fuse_by_partitions(self.graph_module, [list(partition.nodes) for partition in partitions])

    def partition_and_fuse(self) -> GraphModule:
        """Propose partitions and fuse them in one step."""
        partitions = self.propose_partitions()
        fused_gm = self.fuse_partitions(partitions)
        return fused_gm