| # mypy: allow-untyped-defs |
| from typing import Dict, Tuple |
| |
| from torch.fx._compatibility import compatibility |
| from torch.fx.graph import Graph |
| |
| from torch.fx.graph_module import GraphModule |
| from torch.fx.passes.utils.matcher_utils import SubgraphMatcher |
| from torch.nn import Module |
| |
| |
# Public API of this module.
__all__ = [
    "HolderModule",
    "lift_subgraph_as_module",
    "compare_graphs",
]
| |
| |
@compatibility(is_backward_compatible=False)
class HolderModule(Module):
    """
    Container module whose children are exactly the modules it is built from.

    It is used to mirror the attribute hierarchy of an original module, so that
    a submodule can reach the attributes it uses under the same dotted paths.

    Args:
        d: mapping from child name to the module (or ``None``) registered
            under that name, in iteration order.
    """

    def __init__(self, d):
        super().__init__()
        entries = d.items()
        for child_name, child_module in entries:
            # add_module registers the child so it shows up in
            # named_children()/state_dict(), not just as a plain attribute.
            self.add_module(child_name, child_module)
| |
| |
@compatibility(is_backward_compatible=False)
def lift_subgraph_as_module(
    gm: GraphModule,
    subgraph: Graph,
    comp_name: str = "",
    class_name: str = "GraphModule",
) -> Tuple[GraphModule, Dict[str, str]]:
    """
    Create a GraphModule for subgraph, which copies the necessary attributes from the original parent graph_module.

    Args:
        gm (GraphModule): parent graph module

        subgraph (Graph): a valid subgraph that contains copied nodes from the parent graph

        comp_name (str): name for the new component. It is used as the prefix of
            the fully-qualified names in the returned mapping; when empty, the
            original target is used unchanged (no leading ".").

        class_name (str): name for the submodule

    Returns:
        A tuple ``(module, mapping)``:
        ``module`` is a new GraphModule wrapping ``subgraph``, whose
        ``call_module``/``get_attr`` targets resolve to the same objects as in
        ``gm``. ``mapping`` maps each such target in ``subgraph`` to its
        fully-qualified name under ``comp_name``.

    Raises:
        AttributeError: if a target path does not exist on ``gm``.
    """

    # Loop through all module calls (call_module) and param fetches (get_attr)
    # in this component, creating HolderModules as necessary to match the path.
    # e.g. if in the original module there's a get_attr node that fetches "conv.weight",
    # we create a HolderModule as root -> add a HolderModule named "conv" ->
    # make "weight" an attribute of the "conv" HolderModule, pointing to conv.weight in
    # the original module.
    submodule = HolderModule({})
    orig_to_split_fqn_mapping: Dict[str, str] = {}
    for n in subgraph.nodes:
        if n.op not in ("call_module", "get_attr"):
            continue

        target = n.target
        assert isinstance(target, str)
        target_name_parts = target.split(".")
        curr = submodule
        orig_gm = gm

        # Walk the intermediate path components in lockstep: `curr` in the new
        # holder tree (creating holders on demand), `orig_gm` in the parent module.
        for name in target_name_parts[:-1]:
            # NOTE(review): hasattr also sees nn.Module methods/attributes
            # (e.g. "train"), not only registered children.
            if not hasattr(curr, name):
                curr.add_module(name, HolderModule({}))

            curr = getattr(curr, name)
            orig_gm = getattr(orig_gm, name)

        leaf_node_name = target_name_parts[-1]
        leaf_node = getattr(orig_gm, leaf_node_name)

        # Without a component name there is no prefix, so avoid producing a
        # malformed FQN with a leading "." such as ".conv.weight".
        orig_to_split_fqn_mapping[target] = (
            f"{comp_name}.{target}" if comp_name else target
        )
        # Relies on nn.Module.__setattr__: Module/Parameter values get
        # registered as submodules/parameters rather than plain attributes.
        setattr(curr, leaf_node_name, leaf_node)

    return GraphModule(submodule, subgraph, class_name), orig_to_split_fqn_mapping
| |
| |
@compatibility(is_backward_compatible=False)
def compare_graphs(left: Graph, right: Graph) -> bool:
    """
    Check whether two graphs are identical.

    Two graphs are considered identical when they
    - have the same number of outputs, in the same order,
    - have the same number of inputs, in the same order, and
    - have the same set of nodes with identical connectivity.

    ``left`` is used as the pattern and matched against ``right`` with both
    output and placeholder matching enabled.

    Returns:
        True if the matcher finds at least one match, False otherwise.
    """
    return bool(
        SubgraphMatcher(left, match_output=True, match_placeholder=True).match(right)
    )