| # mypy: allow-untyped-defs |
| import sys |
| from typing import Dict, Optional |
| |
| import torch |
| |
| from torch._logging import LazyString |
| |
| |
| def lazy_format_graph_code(name, gm, maybe_id=None, **kwargs): |
| """ |
| Returns a LazyString that formats the graph code. |
| """ |
| |
| def format_name(): |
| if maybe_id is not None: |
| return f"{name} {maybe_id}" |
| else: |
| return name |
| |
| if "print_output" not in kwargs: |
| kwargs["print_output"] = False |
| |
| if "colored" in kwargs and not sys.stdout.isatty(): |
| kwargs["colored"] = False |
| |
| return LazyString( |
| lambda: _format_graph_code( |
| f"===== {format_name()} =====\n", |
| gm.forward.__code__.co_filename, |
| gm.print_readable(**kwargs), |
| ) |
| ) |
| |
| |
| def _format_graph_code(name, filename, graph_str): |
| """ |
| Returns a string that formats the graph code. |
| """ |
| return f"TRACED GRAPH\n {name} {filename} {graph_str}\n" |
| |
| |
def first_call_function_nn_module_stack(graph: torch.fx.Graph) -> Optional[Dict]:
    """
    Find the ``nn_module_stack`` metadata of the earliest ``call_function`` node.

    Args:
        graph: graph whose ``nodes`` are scanned in iteration order.

    Returns:
        ``node.meta["nn_module_stack"]`` for the first node with
        ``op == "call_function"`` whose ``meta`` contains that key, or
        ``None`` when no node qualifies.
    """
    candidates = (
        n.meta["nn_module_stack"]
        for n in graph.nodes
        if n.op == "call_function" and "nn_module_stack" in n.meta
    )
    return next(candidates, None)
| |
| |
def get_node_context(node, num_nodes=2) -> str:
    """
    Return the formatted text of ``node`` and the nodes just before it.

    Starting at ``node``, walks backwards through ``.prev`` and collects
    ``format_node()`` output for at most ``num_nodes`` nodes. The walk stops
    early at a node whose ``op`` is ``"root"``, and that node is not included.

    Args:
        node: the node to end the context at (presumably a ``torch.fx.Node``).
        num_nodes: maximum number of nodes to include; ``<= 0`` yields ``""``.

    Returns:
        The collected lines, earliest node first, joined with ``"\\n"``.
    """
    node_contexts = []
    cur = node
    for _ in range(num_nodes):
        # Check before formatting: in torch.fx the "root" node marks the
        # beginning of the graph's node list rather than being a real
        # operation, so its text is not useful context.
        if cur.op == "root":
            break
        node_contexts.append(cur.format_node())
        cur = cur.prev
    return "\n".join(reversed(node_contexts))