| # mypy: allow-untyped-defs |
| |
| import hashlib |
| import torch |
| import torch.fx |
| from typing import Any, Dict, Optional, TYPE_CHECKING |
| from torch.fx.node import _get_qualified_name, _format_arg |
| from torch.fx.graph import _parse_stack_trace |
| from torch.fx.passes.shape_prop import TensorMetadata |
| from torch.fx._compatibility import compatibility |
| from itertools import chain |
| |
| __all__ = ['FxGraphDrawer'] |
| try: |
| import pydot |
| HAS_PYDOT = True |
| except ImportError: |
| HAS_PYDOT = False |
| |
# Fixed graphviz fill colours for specific FX op codes; ops not listed here get
# a colour from _HASH_COLOR_MAP instead.
# NOTE(review): "get_param" does not appear to be an op code torch.fx emits
# (it has "get_attr"); kept so the mapping is unchanged.
_COLOR_MAP = dict(
    placeholder='"AliceBlue"',
    call_module="LemonChiffon1",
    get_param="Yellow2",
    get_attr="LightGrey",
    output="PowderBlue",
)

# Palette indexed by a hash of the node target, so the same target always
# receives the same colour. Order matters: it determines which colour a
# given hash maps to.
_HASH_COLOR_MAP = [
    "CadetBlue1", "Coral", "DarkOliveGreen1", "DarkSeaGreen1",
    "GhostWhite", "Khaki1", "LavenderBlush1", "LightSkyBlue",
    "MistyRose1", "MistyRose2", "PaleTurquoise2", "PeachPuff1",
    "Salmon", "Thistle1", "Thistle3", "Wheat1",
]

# Base graphviz attributes for parameter/buffer nodes (the node shape is
# supplied per drawer).
_WEIGHT_TEMPLATE = dict(
    fillcolor="Salmon",
    style='"filled,rounded"',
    fontcolor="#000000",
)
| |
| if HAS_PYDOT: |
| @compatibility(is_backward_compatible=False) |
| class FxGraphDrawer: |
| """ |
| Visualize a torch.fx.Graph with graphviz |
| Basic usage: |
| g = FxGraphDrawer(symbolic_traced, "resnet18") |
| g.get_dot_graph().write_svg("a.svg") |
| """ |
| |
| def __init__( |
| self, |
| graph_module: torch.fx.GraphModule, |
| name: str, |
| ignore_getattr: bool = False, |
| ignore_parameters_and_buffers: bool = False, |
| skip_node_names_in_args: bool = True, |
| parse_stack_trace: bool = False, |
| dot_graph_shape: Optional[str] = None, |
| ): |
| self._name = name |
| self.dot_graph_shape = ( |
| dot_graph_shape if dot_graph_shape is not None else "record" |
| ) |
| _WEIGHT_TEMPLATE["shape"] = self.dot_graph_shape |
| |
| self._dot_graphs = { |
| name: self._to_dot( |
| graph_module, name, ignore_getattr, ignore_parameters_and_buffers, skip_node_names_in_args, parse_stack_trace |
| ) |
| } |
| |
| for node in graph_module.graph.nodes: |
| if node.op != "call_module": |
| continue |
| |
| leaf_node = self._get_leaf_node(graph_module, node) |
| |
| if not isinstance(leaf_node, torch.fx.GraphModule): |
| continue |
| |
| |
| self._dot_graphs[f"{name}_{node.target}"] = self._to_dot( |
| leaf_node, |
| f"{name}_{node.target}", |
| ignore_getattr, |
| ignore_parameters_and_buffers, |
| skip_node_names_in_args, |
| parse_stack_trace, |
| ) |
| |
| def get_dot_graph(self, submod_name=None) -> pydot.Dot: |
| """ |
| Visualize a torch.fx.Graph with graphviz |
| Example: |
| >>> # xdoctest: +REQUIRES(module:pydot) |
| >>> # xdoctest: +REQUIRES(module:ubelt) |
| >>> # define module |
| >>> class MyModule(torch.nn.Module): |
| >>> def __init__(self): |
| >>> super().__init__() |
| >>> self.linear = torch.nn.Linear(4, 5) |
| >>> def forward(self, x): |
| >>> return self.linear(x).clamp(min=0.0, max=1.0) |
| >>> module = MyModule() |
| >>> # trace the module |
| >>> symbolic_traced = torch.fx.symbolic_trace(module) |
| >>> # setup output file |
| >>> import ubelt as ub |
| >>> dpath = ub.Path.appdir('torch/tests/FxGraphDrawer').ensuredir() |
| >>> fpath = dpath / 'linear.svg' |
| >>> # draw the graph |
| >>> g = FxGraphDrawer(symbolic_traced, "linear") |
| >>> g.get_dot_graph().write_svg(fpath) |
| """ |
| if submod_name is None: |
| return self.get_main_dot_graph() |
| else: |
| return self.get_submod_dot_graph(submod_name) |
| |
| def get_main_dot_graph(self) -> pydot.Dot: |
| return self._dot_graphs[self._name] |
| |
| def get_submod_dot_graph(self, submod_name) -> pydot.Dot: |
| return self._dot_graphs[f"{self._name}_{submod_name}"] |
| |
| def get_all_dot_graphs(self) -> Dict[str, pydot.Dot]: |
| return self._dot_graphs |
| |
| def _get_node_style(self, node: torch.fx.Node) -> Dict[str, str]: |
| |
| template = { |
| "shape": self.dot_graph_shape, |
| "fillcolor": "#CAFFE3", |
| "style": '"filled,rounded"', |
| "fontcolor": "#000000", |
| } |
| if node.op in _COLOR_MAP: |
| template["fillcolor"] = _COLOR_MAP[node.op] |
| else: |
| # Use a random color for each node; based on its name so it's stable. |
| target_name = node._pretty_print_target(node.target) |
| target_hash = int(hashlib.md5(target_name.encode()).hexdigest()[:8], 16) |
| template["fillcolor"] = _HASH_COLOR_MAP[target_hash % len(_HASH_COLOR_MAP)] |
| return template |
| |
| def _get_leaf_node( |
| self, module: torch.nn.Module, node: torch.fx.Node |
| ) -> torch.nn.Module: |
| py_obj = module |
| assert isinstance(node.target, str) |
| atoms = node.target.split(".") |
| for atom in atoms: |
| if not hasattr(py_obj, atom): |
| raise RuntimeError( |
| str(py_obj) + " does not have attribute " + atom + "!" |
| ) |
| py_obj = getattr(py_obj, atom) |
| return py_obj |
| |
| def _typename(self, target: Any) -> str: |
| if isinstance(target, torch.nn.Module): |
| ret = torch.typename(target) |
| elif isinstance(target, str): |
| ret = target |
| else: |
| ret = _get_qualified_name(target) |
| |
| # Escape "{" and "}" to prevent dot files like: |
| # https://gist.github.com/SungMinCho/1a017aab662c75d805c5954d62c5aabc |
| # which triggers `Error: bad label format (...)` from dot |
| return ret.replace("{", r"\{").replace("}", r"\}") |
| |
| # shorten path to avoid drawing long boxes |
| # for full path = '/home/weif/pytorch/test.py' |
| # return short path = 'pytorch/test.py' |
| def _shorten_file_name( |
| self, |
| full_file_name: str, |
| truncate_to_last_n: int = 2, |
| ): |
| splits = full_file_name.split('/') |
| if len(splits) >= truncate_to_last_n: |
| return '/'.join(splits[-truncate_to_last_n:]) |
| return full_file_name |
| |
| |
| def _get_node_label( |
| self, |
| module: torch.fx.GraphModule, |
| node: torch.fx.Node, |
| skip_node_names_in_args: bool, |
| parse_stack_trace: bool, |
| ) -> str: |
| def _get_str_for_args_kwargs(arg): |
| if isinstance(arg, tuple): |
| prefix, suffix = r"|args=(\l", r",\n)\l" |
| arg_strs_list = [_format_arg(a, max_list_len=8) for a in arg] |
| elif isinstance(arg, dict): |
| prefix, suffix = r"|kwargs={\l", r",\n}\l" |
| arg_strs_list = [ |
| f"{k}: {_format_arg(v, max_list_len=8)}" |
| for k, v in arg.items() |
| ] |
| else: # Fall back to nothing in unexpected case. |
| return "" |
| |
| # Strip out node names if requested. |
| if skip_node_names_in_args: |
| arg_strs_list = [a for a in arg_strs_list if "%" not in a] |
| if len(arg_strs_list) == 0: |
| return "" |
| arg_strs = prefix + r",\n".join(arg_strs_list) + suffix |
| if len(arg_strs_list) == 1: |
| arg_strs = arg_strs.replace(r"\l", "").replace(r"\n", "") |
| return arg_strs.replace("{", r"\{").replace("}", r"\}") |
| |
| |
| label = "{" + f"name=%{node.name}|op_code={node.op}\n" |
| |
| if node.op == "call_module": |
| leaf_module = self._get_leaf_node(module, node) |
| label += r"\n" + self._typename(leaf_module) + r"\n|" |
| extra = "" |
| if hasattr(leaf_module, "__constants__"): |
| extra = r"\n".join( |
| [f"{c}: {getattr(leaf_module, c)}" for c in leaf_module.__constants__] # type: ignore[union-attr] |
| ) |
| label += extra + r"\n" |
| else: |
| label += f"|target={self._typename(node.target)}" + r"\n" |
| if len(node.args) > 0: |
| label += _get_str_for_args_kwargs(node.args) |
| if len(node.kwargs) > 0: |
| label += _get_str_for_args_kwargs(node.kwargs) |
| label += f"|num_users={len(node.users)}" + r"\n" |
| |
| tensor_meta = node.meta.get('tensor_meta') |
| label += self._tensor_meta_to_label(tensor_meta) |
| |
| # for original fx graph |
| # print buf=buf0, n_origin=6 |
| buf_meta = node.meta.get('buf_meta', None) |
| if buf_meta is not None: |
| label += f"|buf={buf_meta.name}" + r"\n" |
| label += f"|n_origin={buf_meta.n_origin}" + r"\n" |
| |
| # for original fx graph |
| # print file:lineno code |
| if parse_stack_trace and node.stack_trace is not None: |
| parsed_stack_trace = _parse_stack_trace(node.stack_trace) |
| fname = self._shorten_file_name(parsed_stack_trace.file) |
| label += f"|file={fname}:{parsed_stack_trace.lineno} {parsed_stack_trace.code}" + r"\n" |
| |
| |
| return label + "}" |
| |
| def _tensor_meta_to_label(self, tm) -> str: |
| if tm is None: |
| return "" |
| elif isinstance(tm, TensorMetadata): |
| return self._stringify_tensor_meta(tm) |
| elif isinstance(tm, list): |
| result = "" |
| for item in tm: |
| result += self._tensor_meta_to_label(item) |
| return result |
| elif isinstance(tm, dict): |
| result = "" |
| for v in tm.values(): |
| result += self._tensor_meta_to_label(v) |
| return result |
| elif isinstance(tm, tuple): |
| result = "" |
| for item in tm: |
| result += self._tensor_meta_to_label(item) |
| return result |
| else: |
| raise RuntimeError(f"Unsupported tensor meta type {type(tm)}") |
| |
| def _stringify_tensor_meta(self, tm: TensorMetadata) -> str: |
| result = "" |
| if not hasattr(tm, "dtype"): |
| print("tm", tm) |
| result += "|" + "dtype" + "=" + str(tm.dtype) + r"\n" |
| result += "|" + "shape" + "=" + str(tuple(tm.shape)) + r"\n" |
| result += "|" + "requires_grad" + "=" + str(tm.requires_grad) + r"\n" |
| result += "|" + "stride" + "=" + str(tm.stride) + r"\n" |
| if tm.is_quantized: |
| assert tm.qparams is not None |
| assert "qscheme" in tm.qparams |
| qscheme = tm.qparams["qscheme"] |
| if qscheme in { |
| torch.per_tensor_affine, |
| torch.per_tensor_symmetric, |
| }: |
| result += "|" + "q_scale" + "=" + str(tm.qparams["scale"]) + r"\n" |
| result += "|" + "q_zero_point" + "=" + str(tm.qparams["zero_point"]) + r"\n" |
| elif qscheme in { |
| torch.per_channel_affine, |
| torch.per_channel_symmetric, |
| torch.per_channel_affine_float_qparams, |
| }: |
| result += "|" + "q_per_channel_scale" + "=" + str(tm.qparams["scale"]) + r"\n" |
| result += "|" + "q_per_channel_zero_point" + "=" + str(tm.qparams["zero_point"]) + r"\n" |
| result += "|" + "q_per_channel_axis" + "=" + str(tm.qparams["axis"]) + r"\n" |
| else: |
| raise RuntimeError(f"Unsupported qscheme: {qscheme}") |
| result += "|" + "qscheme" + "=" + str(tm.qparams["qscheme"]) + r"\n" |
| return result |
| |
| def _get_tensor_label(self, t: torch.Tensor) -> str: |
| return str(t.dtype) + str(list(t.shape)) + r"\n" |
| |
| # when parse_stack_trace=True |
| # print file:lineno code |
| def _to_dot( |
| self, |
| graph_module: torch.fx.GraphModule, |
| name: str, |
| ignore_getattr: bool, |
| ignore_parameters_and_buffers: bool, |
| skip_node_names_in_args: bool, |
| parse_stack_trace: bool, |
| ) -> pydot.Dot: |
| """ |
| Actual interface to visualize a fx.Graph. Note that it takes in the GraphModule instead of the Graph. |
| If ignore_parameters_and_buffers is True, the parameters and buffers |
| created with the module will not be added as nodes and edges. |
| """ |
| |
| # "TB" means top-to-bottom rank direction in layout |
| dot_graph = pydot.Dot(name, rankdir="TB") |
| |
| |
| buf_name_to_subgraph = {} |
| |
| for node in graph_module.graph.nodes: |
| if ignore_getattr and node.op == "get_attr": |
| continue |
| |
| style = self._get_node_style(node) |
| dot_node = pydot.Node( |
| node.name, label=self._get_node_label(graph_module, node, skip_node_names_in_args, parse_stack_trace), **style |
| ) |
| |
| current_graph = dot_graph |
| |
| buf_meta = node.meta.get('buf_meta', None) |
| if buf_meta is not None and buf_meta.n_origin > 1: |
| buf_name = buf_meta.name |
| if buf_name not in buf_name_to_subgraph: |
| buf_name_to_subgraph[buf_name] = pydot.Cluster(buf_name, label=buf_name) |
| current_graph = buf_name_to_subgraph.get(buf_name) |
| |
| current_graph.add_node(dot_node) |
| |
| def get_module_params_or_buffers(): |
| for pname, ptensor in chain( |
| leaf_module.named_parameters(), leaf_module.named_buffers() |
| ): |
| pname1 = node.name + "." + pname |
| label1 = ( |
| pname1 + "|op_code=get_" + "parameter" |
| if isinstance(ptensor, torch.nn.Parameter) |
| else "buffer" + r"\l" |
| ) |
| dot_w_node = pydot.Node( |
| pname1, |
| label="{" + label1 + self._get_tensor_label(ptensor) + "}", |
| **_WEIGHT_TEMPLATE, |
| ) |
| dot_graph.add_node(dot_w_node) |
| dot_graph.add_edge(pydot.Edge(pname1, node.name)) |
| |
| if node.op == "call_module": |
| leaf_module = self._get_leaf_node(graph_module, node) |
| |
| if not ignore_parameters_and_buffers and not isinstance(leaf_module, torch.fx.GraphModule): |
| get_module_params_or_buffers() |
| |
| for subgraph in buf_name_to_subgraph.values(): |
| subgraph.set('color', 'royalblue') |
| subgraph.set('penwidth', '2') |
| dot_graph.add_subgraph(subgraph) |
| |
| for node in graph_module.graph.nodes: |
| if ignore_getattr and node.op == "get_attr": |
| continue |
| |
| for user in node.users: |
| dot_graph.add_edge(pydot.Edge(node.name, user.name)) |
| |
| return dot_graph |
| |
| else: |
| if not TYPE_CHECKING: |
| @compatibility(is_backward_compatible=False) |
| class FxGraphDrawer: |
| def __init__( |
| self, |
| graph_module: torch.fx.GraphModule, |
| name: str, |
| ignore_getattr: bool = False, |
| ignore_parameters_and_buffers: bool = False, |
| skip_node_names_in_args: bool = True, |
| parse_stack_trace: bool = False, |
| dot_graph_shape: Optional[str] = None, |
| ): |
| raise RuntimeError('FXGraphDrawer requires the pydot package to be installed. Please install ' |
| 'pydot through your favorite Python package manager.') |