| from __future__ import absolute_import, division, print_function, unicode_literals |
| |
| import torch |
| import torch.fx |
| import pydot |
| from typing import Dict, Any |
| from torch.fx.node import _get_qualified_name |
| |
# Graphviz fill colour keyed by torch.fx node op code.  Note that the
# placeholder entry embeds its own double quotes while the others do not.
# NOTE(review): "get_param" does not look like a standard torch.fx op code;
# it may be unused — confirm before relying on it.
_COLOR_MAP = dict(
    placeholder='"AliceBlue"',
    call_module="LemonChiffon1",
    call_function="PeachPuff1",
    get_param="Yellow2",
    get_attr="LightGrey",
    call_method="LavenderBlush1",
    output="PowderBlue",
)

# Graphviz node attributes applied to the parameter/buffer nodes that are
# drawn next to call_module nodes.
_WEIGHT_TEMPLATE = dict(
    shape="record",
    fillcolor="Salmon",
    style='"filled,rounded"',
    fontcolor="#000000",
)
| |
| |
class FxGraphDrawer:
    """
    Visualize a torch.fx.Graph with graphviz (through pydot).

    One DOT graph is built for the given GraphModule, plus one extra graph for
    every ``call_module`` target of that module which is itself a
    ``torch.fx.GraphModule``.

    Basic usage:
        g = FxGraphDrawer(symbolic_traced, "resnet18")
        # create_svg() returns bytes, so the file must be opened in binary mode.
        with open("a.svg", "wb") as f:
            f.write(g.get_main_dot_graph().create_svg())
    """

    def __init__(self, graph_module: torch.fx.GraphModule, name: str, ignore_getattr: bool = False):
        """
        Build all DOT graphs eagerly.

        Args:
            graph_module: the traced module to draw.
            name: key/name of the main graph. Submodule graphs are keyed as
                ``f"{name}_{node.target}"``.
            ignore_getattr: when True, ``get_attr`` nodes and their outgoing
                edges are omitted from every graph.
        """
        self._name = name
        self._dot_graphs = {name: self._to_dot(graph_module, name, ignore_getattr)}

        for node in graph_module.graph.nodes:
            if node.op != "call_module":
                continue

            leaf_node = self._get_leaf_node(graph_module, node)

            # Only submodules that are themselves GraphModules have a graph to draw.
            if not isinstance(leaf_node, torch.fx.GraphModule):
                continue

            submod_name = f"{name}_{node.target}"
            # The same submodule may be called from several nodes; draw it once.
            if submod_name in self._dot_graphs:
                continue
            self._dot_graphs[submod_name] = self._to_dot(leaf_node, submod_name, ignore_getattr)

    def get_main_dot_graph(self) -> pydot.Dot:
        """Return the graph of the module passed to the constructor."""
        return self._dot_graphs[self._name]

    def get_submod_dot_graph(self, submod_name) -> pydot.Dot:
        """
        Return the graph of the GraphModule submodule whose call_module target
        is ``submod_name``. Raises KeyError if no such graph was built.
        """
        return self._dot_graphs[f"{self._name}_{submod_name}"]

    def get_all_dot_graphs(self) -> Dict[str, pydot.Dot]:
        """Return every graph built, keyed by graph name (the internal dict, not a copy)."""
        return self._dot_graphs

    def _get_node_style(self, node: torch.fx.Node) -> Dict[str, str]:
        """
        Return the pydot attributes for ``node``: a rounded, filled record whose
        fill colour depends on the node's op code.
        """
        template = {
            "shape": "record",
            "fillcolor": "#CAFFE3",
            "style": '"filled,rounded"',
            "fontcolor": "#000000",
        }
        # Fall back to the template colour for op codes missing from the map
        # instead of raising KeyError.
        template["fillcolor"] = _COLOR_MAP.get(node.op, template["fillcolor"])
        return template

    def _get_leaf_node(
        self, module: torch.nn.Module, node: torch.fx.Node
    ) -> torch.nn.Module:
        """
        Resolve ``node.target`` (a dotted attribute path) against ``module``
        and return the object it names.

        Raises:
            RuntimeError: if a component of the path is not an attribute of
                the object resolved so far.
        """
        py_obj = module
        assert isinstance(node.target, str)
        for atom in node.target.split("."):
            if not hasattr(py_obj, atom):
                raise RuntimeError(
                    str(py_obj) + " does not have attribute " + atom + "!"
                )
            py_obj = getattr(py_obj, atom)
        return py_obj

    def _typename(self, target: Any) -> str:
        """
        Human-readable name of a node target: ``torch.typename`` for modules,
        the string itself for string targets, otherwise the qualified name of
        the callable.
        """
        if isinstance(target, torch.nn.Module):
            return torch.typename(target)

        if isinstance(target, str):
            return target

        return _get_qualified_name(target)

    def _get_node_label(self, module: torch.fx.GraphModule, node: torch.fx.Node) -> str:
        """
        Build the Graphviz record label for ``node``.

        The label holds, in order: the node name and op code; the module type
        and its ``__constants__`` (call_module) or the target name (other ops);
        and dtype/shape/stride from ``node.meta["tensor_meta"]`` when present.
        ``\\l`` ends a left-justified line and ``|`` separates record fields.
        """
        label = "{" + f"{node.name}|op_code={node.op}"

        if node.op == "call_module":
            leaf_module = self._get_leaf_node(module, node)
            label += r"\l" + self._typename(leaf_module) + r"\l|"
            extra = ""
            if hasattr(leaf_module, "__constants__"):
                extra = r"\l".join(
                    [f"{c}: {getattr(leaf_module, c)}" for c in leaf_module.__constants__]  # type: ignore[union-attr]
                )
            label += extra + r"\l"
        else:
            label += "|" + self._typename(node.target) + r"\l"

        tensor_meta = node.meta.get('tensor_meta')
        if tensor_meta:
            # Compare against None rather than truthiness: a 0-dim tensor has an
            # empty shape and stride, which should still be shown.
            if tensor_meta.dtype is not None:
                label += "|dtype=" + str(tensor_meta.dtype) + r"\l"
            if tensor_meta.shape is not None:
                label += "|shape=" + str(tensor_meta.shape) + r"\l"
            if tensor_meta.stride is not None:
                label += "|stride=" + str(tensor_meta.stride) + r"\l"

        return label + "}"

    def _get_tensor_label(self, t: torch.Tensor) -> str:
        """Return ``t``'s dtype followed by its shape as a list, e.g. ``torch.float32[3, 4]``."""
        return str(t.dtype) + str(list(t.shape)) + r"\l"

    def _add_params_or_buffers(
        self,
        dot_graph: pydot.Dot,
        node: torch.fx.Node,
        leaf_module: torch.nn.Module,
        is_param: bool,
    ) -> None:
        """
        Add one node per parameter (``is_param=True``) or buffer of
        ``leaf_module`` to ``dot_graph``, each with an edge into ``node``.
        """
        kind = "parameter" if is_param else "buffer"
        named_tensors = (
            leaf_module.named_parameters() if is_param else leaf_module.named_buffers()
        )
        for pname, ptensor in named_tensors:
            pname1 = node.name + "." + pname
            # The conditional must only pick the kind word: written unparenthesized
            # it bound looser than '+', so buffer labels lost the name and op code
            # and parameter labels lost their trailing line break.
            label1 = pname1 + "|op_code=get_" + kind + r"\l"
            dot_w_node = pydot.Node(
                pname1,
                label="{" + label1 + self._get_tensor_label(ptensor) + "}",
                **_WEIGHT_TEMPLATE,
            )
            dot_graph.add_node(dot_w_node)
            dot_graph.add_edge(pydot.Edge(pname1, node.name))

    def _to_dot(self, graph_module: torch.fx.GraphModule, name: str, ignore_getattr: bool) -> pydot.Dot:
        """
        Actual interface to visualize a fx.Graph. Note that it takes in the
        GraphModule instead of the Graph, because call_module targets are
        resolved against it.

        Args:
            graph_module: module whose ``graph`` is drawn.
            name: name given to the resulting ``pydot.Dot``.
            ignore_getattr: skip ``get_attr`` nodes and their outgoing edges.
        """
        dot_graph = pydot.Dot(name, rankdir="TB")

        for node in graph_module.graph.nodes:
            if ignore_getattr and node.op == "get_attr":
                continue

            style = self._get_node_style(node)
            dot_node = pydot.Node(
                node.name, label=self._get_node_label(graph_module, node), **style
            )
            dot_graph.add_node(dot_node)

            if node.op == "call_module":
                leaf_module = self._get_leaf_node(graph_module, node)
                # GraphModule leaves are not expanded inline; for the top-level
                # module, __init__ draws them as separate graphs instead.
                if not isinstance(leaf_module, torch.fx.GraphModule):
                    self._add_params_or_buffers(dot_graph, node, leaf_module, True)
                    self._add_params_or_buffers(dot_graph, node, leaf_module, False)

        # Edges are added after all nodes so every endpoint already exists.
        for node in graph_module.graph.nodes:
            if ignore_getattr and node.op == "get_attr":
                continue

            for user in node.users:
                dot_graph.add_edge(pydot.Edge(node.name, user.name))

        return dot_graph