from typing import List, Tuple, Union, Dict, Any, Set, Mapping
from dataclasses import dataclass

import torch
import torch.fx
from torch.fx.node import _get_qualified_name
from torch.fx._compatibility import compatibility


Tensors = Union[Tuple[torch.Tensor], List[torch.Tensor]]
TensorOrTensors = Union[torch.Tensor, Tensors]
NodeList = List[torch.fx.Node]
NodeSet = Set[torch.fx.Node]
Names = List[str]
CALLABLE_NODE_OPS = {"call_module", "call_function", "call_method"}

@compatibility(is_backward_compatible=False)
def get_acc_ops_name(k):
    if isinstance(k, str):
        return k
    elif k.__module__ and "acc_ops" in k.__module__:
        return f"acc_ops.{k.__name__}"
    else:
        module = k.__module__
        return f"{module if module else ''}.{k.__name__}"
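
# Illustrative sketch only (not part of the module's API); the exact strings depend on
# each callable's __module__ and __name__ attributes:
#
#   get_acc_ops_name("flatten")      # -> "flatten" (strings pass through unchanged)
#   get_acc_ops_name(torch.nn.ReLU)  # -> e.g. "torch.nn.modules.activation.ReLU"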


@compatibility(is_backward_compatible=False)
def get_node_target(submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node) -> str:
    """
    Given a `node`, return its target typename.

    For a "call_method" node, return node.target, which is the name of the method being
    called. This could potentially lead to conflicts, but should be okay because the
    method is normally called on a tensor.

    For a "call_function" node, return the typename of node.target.

    For a "call_module" node, return the typename of the module that node.target points to.

    If "_VariableFunctionsClass" appears in the target name string, it is replaced by
    "torch", e.g. _VariableFunctionsClass.relu becomes torch.relu.
    """

    assert node.op in CALLABLE_NODE_OPS, (
        "Expect op types of " + ", ".join(CALLABLE_NODE_OPS) + f", but found {node.op}"
    )

    if node.op == "call_module":
        assert isinstance(node.target, str)
        submod = submodules[node.target]
        submod_type = getattr(submod, "_base_class_origin", type(submod))
        return get_acc_ops_name(submod_type)
    elif node.op == "call_function":
        target: Any = node.target
        return (
            f"acc_ops.{target.__name__}"
            if target.__module__ is not None and "acc_ops" in target.__module__
            else _get_qualified_name(target)
        )
    else:
        assert isinstance(node.target, str)
        return node.target
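
# A minimal usage sketch (hedged; `MyModule` is a placeholder for some symbolically
# traceable module, not something defined in this file):
#
#   gm = torch.fx.symbolic_trace(MyModule())
#   submodules = dict(gm.named_modules())
#   targets = [
#       get_node_target(submodules, n)
#       for n in gm.graph.nodes
#       if n.op in CALLABLE_NODE_OPS
#   ]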


@compatibility(is_backward_compatible=False)
def is_node_output_tensor(node: torch.fx.Node) -> bool:
    """Checks whether the node's output is a Tensor.

    NOTE: This requires running `ShapeProp` on the containing fx graph before calling
    this function, because it works by checking the `type` metadata on the node, and
    that metadata is produced by `ShapeProp`.
    """
    type_ = node.meta.get("type", None)
    return type_ is not None and issubclass(type_, torch.Tensor)
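
# Usage sketch (hedged; `gm` and `sample_input` are placeholders for a traced
# GraphModule and an example input). ShapeProp must be run first so that
# node.meta["type"] is populated:
#
#   from torch.fx.passes.shape_prop import ShapeProp
#
#   ShapeProp(gm).propagate(sample_input)
#   tensor_nodes = [n for n in gm.graph.nodes if is_node_output_tensor(n)]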


@compatibility(is_backward_compatible=False)
class FxNetAccFusionsFinder:
    """
    Finds groups of connected ACC nodes that pass non-tensor data between each other.
    Such groups are called fusion groups.
    """

    def __init__(self, module: torch.fx.GraphModule, acc_nodes: NodeSet):
        self.module = module
        self.nodes = list(module.graph.nodes)
        self.acc_nodes = acc_nodes

    @dataclass
    class FusionGroup:
        # The smallest index among this fusion group's nodes, after topologically
        # sorting all the nodes in the model.
        top_node_idx: int

        # Nodes in this fusion group.
        nodes: NodeSet

        # Inputs to this fusion group.
        inputs: NodeSet

        # Nodes in the fusion group that haven't been processed yet.
        nodes_need_process: NodeSet

        def add_node(self, node):
            """
            Add a node to the fusion group.
            """
            if node in self.nodes:
                return

            self.nodes_need_process.add(node)
            self.nodes.add(node)
            self.inputs.discard(node)
            self.inputs.update(
                {
                    n
                    for n in node.all_input_nodes
                    if n.op in CALLABLE_NODE_OPS and n not in self.nodes
                }
            )

    def recursive_add_node(
        self,
        fusion_group: "FxNetAccFusionsFinder.FusionGroup",
        inputs: Union[NodeSet, NodeList],
    ):
        """
        Starting from inputs, walk the graph in reverse topological order. If any
        upstream node is in the fusion group, add all the nodes along that path to the
        fusion group and return True; otherwise return False.
        """
        for arg in inputs:
            # Skip placeholder and get_attr nodes because they won't be in the fusion group.
            if arg.op not in CALLABLE_NODE_OPS:
                continue

            # If the node has a smaller index than the group's top node, it is strictly
            # upstream of the fusion group and doesn't need to be checked any further.
            if self.nodes.index(arg) < fusion_group.top_node_idx:
                continue

            # If the node is in the fusion group, return True.
            if arg in fusion_group.nodes:
                return True

            # Check the upstream nodes of this node; if any of them is in the fusion
            # group, add this node to the fusion group and return True.
            if self.recursive_add_node(fusion_group, arg.all_input_nodes):
                fusion_group.add_node(arg)
                return True

        return False
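
    # Worked example (hypothetical graph, not part of this class): suppose the graph
    # contains the chain a -> b -> c, the fusion group currently holds {a, c}, and b is
    # recorded as one of the group's inputs. recursive_add_node(group, group.inputs)
    # recurses from b into its producer a, finds a inside the group, and therefore pulls
    # b into the group as well. This keeps the fused subgraph closed over paths that
    # leave the group and re-enter it, which could otherwise create a circular
    # dependency between the fused subgraph and the rest of the graph.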

    def __call__(self) -> Dict[torch.fx.Node, NodeSet]:
        result: Dict[torch.fx.Node, NodeSet] = {}
        acc_nodes = list(self.acc_nodes)

        for node in acc_nodes:
            # Only seed fusion groups from ACC nodes that still qualify: skip nodes
            # already assigned to a group, non-callable nodes, nodes whose output is a
            # tensor (they carry "tensor_meta"), and nodes removed from acc_nodes by a
            # previously rejected group.
            if node in result:
                continue
            if node.op not in CALLABLE_NODE_OPS:
                continue
            if "tensor_meta" in node.meta:
                continue
            if node not in self.acc_nodes:
                continue

            fusion_group: "FxNetAccFusionsFinder.FusionGroup" = self.FusionGroup(
                top_node_idx=self.nodes.index(node),
                nodes={node},
                inputs=set(node.all_input_nodes),
                nodes_need_process={node},
            )
            while fusion_group.nodes_need_process:
                node = fusion_group.nodes_need_process.pop()
                self.recursive_add_node(fusion_group, fusion_group.inputs)

                # If this node's output is a non-tensor, all of its callable users must
                # also join the fusion group, because non-tensor data can't cross the
                # boundary of the fused subgraph.
                if "tensor_meta" not in node.meta:
                    for user in node.users:
                        if user.op not in CALLABLE_NODE_OPS:
                            continue
                        if user in fusion_group.nodes:
                            continue

                        fusion_group.add_node(user)
                        self.recursive_add_node(fusion_group, fusion_group.inputs)

                # Likewise, pull in upstream nodes whose outputs are non-tensors, since
                # this node consumes that non-tensor data.
                for arg in node.all_input_nodes:
                    if arg.op not in CALLABLE_NODE_OPS:
                        continue
                    if "tensor_meta" in arg.meta:
                        continue
                    if arg in fusion_group.nodes:
                        continue

                    fusion_group.add_node(arg)
                    fusion_group.top_node_idx = min(
                        fusion_group.top_node_idx, self.nodes.index(arg)
                    )
                    self.recursive_add_node(fusion_group, fusion_group.inputs)

            # If the fusion group ended up containing any non-ACC node, drop the whole
            # group from the ACC set; otherwise record the group for every node in it.
            if not fusion_group.nodes <= self.acc_nodes:
                self.acc_nodes -= fusion_group.nodes
            else:
                for n in fusion_group.nodes:
                    result[n] = fusion_group.nodes

        return result
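
# Usage sketch (hedged; `gm` is a traced GraphModule and `acc_nodes` is the set of
# nodes intended to run on the accelerator backend, both placeholder names):
#
#   finder = FxNetAccFusionsFinder(gm, acc_nodes)
#   fusion_map = finder()  # Dict[Node, NodeSet]: each fused node maps to its group
#
# Note that calling the finder may shrink `acc_nodes` in place: groups that end up
# containing non-ACC nodes are removed from the ACC set.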


@compatibility(is_backward_compatible=False)
def legalize_graph(gm: torch.fx.GraphModule):
    """
    Replace the graph of the given GraphModule with one that contains the same nodes as the
    original, but in topologically sorted order.

    This is used, for example, by the merge_matmul transformation, which disturbs the
    topologically sorted order of its input GraphModule; this pass restores that order
    before further transformation.

    Arguments:
        gm: The graph module to topologically sort. It is modified in-place.

    """
    # Build an adjacency list representation of node dependencies in the graph. This also
    # serves as a list of nodes that still need to be inserted into the new, topologically
    # sorted graph.
    dependencies = {node: node.all_input_nodes.copy() for node in gm.graph.nodes}

    # Construct a new graph that will contain all nodes in topologically sorted order.
    new_graph = torch.fx.Graph()
    value_remap: Dict[torch.fx.Node, torch.fx.Node] = {}

    # Copy over all nodes with no dependencies.
    for node, deps in dependencies.items():
        if not deps:
            value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n])

    # Remove the copied-over nodes from the adjacency list.
    for copied_node in value_remap.keys():
        del dependencies[copied_node]

    # While there are still nodes to insert into the new graph:
    while dependencies:
        copied_this_round = []

        # Copy over all nodes whose dependencies already exist in the new graph.
        for node, deps in dependencies.items():
            all_deps_copied = True
            for dep in deps:
                if dep not in value_remap:
                    all_deps_copied = False
                    break

            if all_deps_copied:
                value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n])
                copied_this_round.append(node)

        # Delete all nodes copied over in this iteration from dependencies.
        for copied_node in copied_this_round:
            del dependencies[copied_node]

    # Replace the old graph with the new, topologically sorted one.
    gm.graph = new_graph
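
# A minimal usage sketch (assumes `gm` is an existing GraphModule whose node order was
# disturbed by some transformation; `gm` is a placeholder name):
#
#   legalize_graph(gm)
#   gm.recompile()   # regenerate the module's Python code from the re-sorted graph
#   gm.graph.lint()  # optional sanity check that the graph is well-formed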