| # Copyright (c) Meta Platforms, Inc. and affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| # pyre-strict |
| |
| import copy |
| import re |
| import reprlib |
| from dataclasses import fields |
| from enum import IntEnum |
| from typing import Any, List, Optional, TextIO |
| |
| import torch |
| from executorch.exir.error import ExportError, ExportErrorType, InternalError |
| |
| from executorch.exir.schema import ( |
| Bool, |
| BoolList, |
| DelegateCall, |
| Double, |
| DoubleList, |
| EValue, |
| Frame, |
| FrameList, |
| FreeCall, |
| Int, |
| IntList, |
| JumpFalseCall, |
| KernelCall, |
| MoveCall, |
| Null, |
| OptionalTensorList, |
| Program, |
| ScalarType, |
| String, |
| Tensor, |
| TensorList, |
| TensorShapeDynamism, |
| ) |
| |
| |
def _scalar_type_str(scalar_type: ScalarType) -> str:
    """
    Return the short abbreviation used in program dumps for a scalar type.

    Raises:
        RuntimeError: if ``scalar_type`` has no abbreviation in the table below.
    """
    abbreviations = {
        ScalarType.BYTE: "bt",
        ScalarType.CHAR: "c",
        ScalarType.SHORT: "s",
        ScalarType.INT: "i",
        ScalarType.LONG: "l",
        ScalarType.HALF: "h",
        ScalarType.FLOAT: "f",
        ScalarType.DOUBLE: "d",
        ScalarType.COMPLEX32: "c32",
        ScalarType.COMPLEX64: "c64",
        ScalarType.COMPLEX128: "c128",
        ScalarType.BOOL: "b",
        ScalarType.QINT8: "qi8",
        ScalarType.QUINT8: "qui8",
        ScalarType.QINT32: "qi32",
        ScalarType.BFLOAT16: "bf16",
        ScalarType.QUINT4x2: "qui4x2",
        ScalarType.QUINT2x4: "qui2x4",
    }
    abbreviation = abbreviations.get(scalar_type)
    # Guard clause: every table entry is a non-empty string, so a falsy
    # lookup result means the scalar type is not in the table.
    if not abbreviation:
        raise RuntimeError(f"Unrecognized scalar_type: {scalar_type}")
    return abbreviation
| |
| |
def _is_dynamic_shape_tensor(tensor: Tensor) -> bool:
    """Return True when the tensor's shape dynamism is anything other than STATIC."""
    dynamism = tensor.shape_dynamism
    return dynamism != TensorShapeDynamism.STATIC
| |
| |
def _format_evalue(  # noqa: C901
    evalue: EValue, show_meminfo: bool, mark_dynamic_shape_tensor: bool
) -> str:
    """
    Render the type portion of an EValue as a short, blue-colored string.

    Args:
        evalue: The value to describe; the rendering depends on the type of
            ``evalue.val``.
        show_meminfo: For tensors, append ``m<memory_id>.<memory_offset>`` when
            allocation info is present, or ``m.`` when it is not.
        mark_dynamic_shape_tensor: For non-constant tensors, prefix the ``T``
            marker with ``UB`` (DYNAMIC_BOUND) or ``DU`` (DYNAMIC_UNBOUND).

    Returns:
        The description wrapped in the ANSI escape codes for blue / reset.

    Raises:
        RuntimeError: if ``evalue.val`` is not one of the handled types.
    """
    val = evalue.val
    pieces = ["\033[34m"]

    # List-like values all render as "<tag><items>"; checked in this order.
    list_tags = (
        (TensorList, "TL"),
        (OptionalTensorList, "OTL"),
        (IntList, "IL"),
        (DoubleList, "DL"),
        (BoolList, "BL"),
    )

    if isinstance(val, Tensor):
        if val.data_buffer_idx > 0:
            # A positive data buffer index is rendered as a constant tensor.
            assert not _is_dynamic_shape_tensor(
                val
            ), "A constant tensor can not be dynamic shape"
            pieces.append("CT")
            assert val.allocation_info is None
        else:
            if mark_dynamic_shape_tensor:
                dynamism = val.shape_dynamism
                if dynamism == TensorShapeDynamism.DYNAMIC_BOUND:
                    pieces.append("UB")  # shown as 'UBT' together with the 'T' below
                elif dynamism == TensorShapeDynamism.DYNAMIC_UNBOUND:
                    pieces.append("DU")  # shown as 'DUT' together with the 'T' below
            pieces.append("T")
        if show_meminfo:
            allocation = val.allocation_info
            if allocation:
                pieces.append(f"m{allocation.memory_id}.{allocation.memory_offset}")
            else:
                pieces.append("m.")
        pieces.append(f"{val.sizes}{_scalar_type_str(val.scalar_type)}")
    elif (
        list_tag := next(
            (tag for list_type, tag in list_tags if isinstance(val, list_type)), None
        )
    ) is not None:
        # pyre-ignore
        pieces.append(list_tag + str(val.items))
    elif isinstance(val, Int):
        pieces.append(f"I{val.int_val}")
    elif isinstance(val, Double):
        pieces.append(f"D{val.double_val}")
    elif isinstance(val, Bool):
        # Print 0/1 since it is shorter than false/true.
        pieces.append(f"B{int(val.bool_val)}")
    elif isinstance(val, String):
        pieces.append(f"S{val.string_val}")
    elif isinstance(val, Null):
        pieces.append("N")
    else:
        raise RuntimeError(f"Unrecognized type of evalue: {evalue}")

    pieces.append("\033[0m")
    return "".join(pieces)
| |
| |
def print_program(  # noqa: C901
    program: Program,
    show_meminfo: bool = True,
    mark_dynamic_shape_tensor: bool = False,
    out: Optional[TextIO] = None,
) -> None:
    """
    Dump the instruction list of a program in a more human readable fashion.

    Only the first chain of the first execution plan is printed.

    The dump follows the BNF-like syntax below (some regex syntax is mixed in
    to keep the grammar short; it is not strict, its purpose is only to help
    people read the dump):
    ```
    PROGRAM: (INSTRUCTION)+
    INSTRUCTION: SEQUENCE_NO ':' (CALL_KERNEL | CALL_DELEGATE | JUMP_FALSE | MOVE | FREE)
    JUMP_FALSE: 'JF' '(' EVALUE ')' '->' TARGET_SEQUENCE_NO
    MOVE: 'MOVE' EVALUE '->' EVALUE
    FREE: 'FREE' EVALUE
    CALL_KERNEL: OVERLOADED_OP_NAME ARGS
    CALL_DELEGATE: BACKEND_ID ARGS
    ARGS: EVALUE | ARGS ',' EVALUE
    EVALUE: EVALUE_IDX ( TENSOR | INT | BOOL | ...)
    INT: 'I' ACTUAL_INT_VALUE
    BOOL: 'B' ZERO_OR_ONE
    CONST_TENSOR_PREFIX: 'CT'
    TENSOR: ('T' | CONST_TENSOR_PREFIX) (MEM_ALLOCATION_INFO)? TENSOR_SHAPE TENSOR_DTYPE
    TENSOR_SHAPE: '[' dim0_size, dim1_size, ..., last_dim_size ']'
    MEM_ALLOCATION_INFO: PLANNED_MEM_INFO | UNPLANNED_MEM_INFO
    PLANNED_MEM_INFO: 'm' MEM_LAYER_ID '.' MEM_LAYER_OFFSET
    UNPLANNED_MEM_INFO: 'm.'
    ```

    To make the dump easier to read, it is colored as follows:
    1. input/output EValues are marked as red
    2. EValue types (or more specifically tensor types with size and dtype) are marked as blue

    Args:
        program: The program whose instructions are dumped.
        show_meminfo: Forwarded to the EValue formatter to include memory info.
        mark_dynamic_shape_tensor: Forwarded to the EValue formatter to mark
            dynamic-shape tensors.
        out: Stream to print to; ``None`` uses ``print``'s default.

    Raises:
        InternalError: if an instruction has an unsupported argument type.
    """
    plan = program.execution_plan[0]
    operators = plan.operators
    delegates = plan.delegates
    instructions = plan.chains[0].instructions
    inputs: List[int] = plan.inputs
    outputs: List[int] = plan.outputs
    values: List[EValue] = plan.values

    def _io_position(io_list: List[int], wanted_idx: int) -> int:
        """Position of the first occurrence of ``wanted_idx`` in ``io_list``, or -1."""
        # The lists are short, so a linear scan is adequate.
        return next(
            (pos for pos, candidate in enumerate(io_list) if candidate == wanted_idx),
            -1,
        )

    def _format_arg(evalue_idx: int) -> str:
        """Format one EValue reference: index, red I/O tags, then its type."""
        text = str(evalue_idx)
        input_pos = _io_position(inputs, evalue_idx)
        if input_pos >= 0:
            text += f"\033[31mI{input_pos}\033[0m"
        output_pos = _io_position(outputs, evalue_idx)
        if output_pos >= 0:
            text += f"\033[31mO{output_pos}\033[0m"
        return text + _format_evalue(
            values[evalue_idx], show_meminfo, mark_dynamic_shape_tensor
        )

    print(
        f"The program contains the following {len(instructions)} instructions", file=out
    )
    for seq_no, instruction in enumerate(instructions):
        # The sequence number is emitted first; the rest of the line follows.
        print(f"{seq_no:3}: ", end="", file=out)
        call = instruction.instr_args
        if isinstance(call, KernelCall):
            op = operators[call.op_index]
            opname = op.name if not op.overload else f"{op.name}.{op.overload}"
            argstr = ",".join(_format_arg(arg) for arg in call.args)
            print(f"{opname} {argstr}", file=out)
        elif isinstance(call, DelegateCall):
            backend = delegates[call.delegate_index]
            backend_id = f"{backend.id}"
            argstr = ",".join(_format_arg(arg) for arg in call.args)
            print(f"{backend_id} {argstr}", file=out)
        elif isinstance(call, JumpFalseCall):
            condition = _format_arg(call.cond_value_index)
            print(f"JF ({condition}) -> {call.destination_instruction}", file=out)
        elif isinstance(call, MoveCall):
            source = _format_arg(call.move_from)
            destination = _format_arg(call.move_to)
            print(f"MOVE {source} -> {destination}", file=out)
        elif isinstance(call, FreeCall):
            print(f"FREE {_format_arg(call.value_index)}", file=out)
        else:
            raise InternalError(f"Unsupport instruction type {instruction}")
| |
| |
# pyre-ignore
def pretty_print(obj: Any, indent: int = 0, out: Optional[TextIO] = None) -> None:
    """
    Pretty print a Program, or any value found among its attributes.

    Handled inputs are IntEnum members (printed as their integer value),
    primitives (int, str, bool, float, None), bytes (abbreviated through
    ``reprlib``), lists, and dataclass instances. Anything else is handed to
    ``dataclasses.fields`` and fails there.

    Args:
        obj: The value to print.
        indent: Current nesting depth; also decides whether the closing
            parenthesis of a dataclass is followed by a newline (depth 0 only).
        out: Stream to print to; ``None`` uses ``print``'s default.

    Raises:
        ExportError: if ``obj`` is a ``torch.fx.GraphModule``.
    """
    if isinstance(obj, torch.fx.GraphModule):
        raise ExportError(
            ExportErrorType.INVALID_INPUT_TYPE,
            "pretty_print() does not accept GraphModule as input.",
        )

    def _emit(text: object, end: str = "") -> None:
        """Print to ``out`` without a trailing newline unless ``end`` says so."""
        print(text, end=end, file=out)

    # Instruction types are IntEnum objects; show their numeric value.
    if isinstance(obj, IntEnum):
        _emit(int(obj))
        return

    primitive_types = (int, str, bool, float, type(None))
    if isinstance(obj, primitive_types):
        _emit(obj)
        return

    if isinstance(obj, bytes):
        shortener = reprlib.Repr()
        shortener.maxother = 1024
        _emit(shortener.repr(obj))
        return

    if isinstance(obj, list):
        # Short all-int lists stay on a single line.
        if len(obj) < 10 and all(isinstance(item, int) for item in obj):
            _emit(obj)
            return
        _emit("[", end="\n")
        for position, item in enumerate(obj):
            _emit(" " * (indent + 1))
            pretty_print(item, indent + 1, out=out)
            _emit(f"(index={position}),", end="\n")
        _emit(" " * indent + "]")
        return

    # Everything else is treated as a dataclass instance.
    obj_fields = fields(obj)
    multiline = not all(
        isinstance(getattr(obj, entry.name), primitive_types) for entry in obj_fields
    )
    line_end = "\n" if multiline else ""
    last_position = len(obj_fields) - 1

    _emit(f"{type(obj).__name__}(", end=line_end)
    for position, entry in enumerate(obj_fields):
        if multiline:
            _emit(" " * (indent + 1))
        _emit(entry.name + "=")
        pretty_print(getattr(obj, entry.name), indent + 1, out=out)
        if position < last_position:
            _emit(", ")
        _emit("", end=line_end)
    if multiline:
        _emit(" " * indent)
    # Only the outermost object terminates its line.
    _emit(")", end="" if indent else "\n")
| |
| |
def pretty_print_stacktraces(obj: FrameList) -> str:
    """
    Render the frames of one instruction's traceback as text.

    The result starts with a "Traceback (most recent call last)" header,
    contains one location line and one context line per frame in
    ``obj.items``, and ends with an empty line.
    """
    chunks = ["Traceback (most recent call last): \n"]
    for frame in obj.items:
        chunks.append(
            f' File "{frame.filename}", line {frame.lineno!s}, in {frame.name}\n'
        )
        chunks.append(f"{frame.context} \n")
    chunks.append("\n")
    return "".join(chunks)
| |
| |
def add_cursor_to_graph(graph: torch.fx.Graph, finding_node: torch.fx.Node) -> str:
    """
    Format an fx.Graph as text with a ``-->`` cursor in front of one node.
    e.g:
    # graph():
    #     %x : [#users=1] = placeholder[target=x]
    #     %param : [#users=1] = get_attr[target=param]
    #     %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    # --> %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
    #     %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    #     return clamp

    This is mostly used for error reporting.

    Args:
        graph: The graph to format. It is only read, never modified.
        finding_node: The node to mark. If it is not part of ``graph`` the
            graph is formatted without any cursor.

    Returns:
        The formatted graph, one node per line after a ``graph():`` header.
    """
    # The graph is formatted directly: formatting only reads the nodes, so a
    # deep copy is wasted work and would add a way for this error-reporting
    # helper to fail on graphs that cannot be copied.
    # This is heavily based on the __str__ method of fx.Graph.
    pieces = ["graph():"]
    for node in graph.nodes:
        node_str = node.format_node()
        if not node_str:
            continue
        prefix = "\n--> " if node == finding_node else "\n "
        pieces.append(prefix + node_str)
    return "".join(pieces)
| |
| |
def _stacktrace_to_framelist(stacktrace: str) -> FrameList:
    """
    Create a frame list from a stacktrace string.

    Every line of the form ``File "<filename>", line <lineno>, in <name>``
    that is terminated by a newline becomes one Frame. The frame's context is
    the stripped text of the line that immediately follows the header, or an
    empty string when that line is itself another frame header (i.e. the
    frame has no source line).

    Args:
        stacktrace: The stacktrace text to parse.

    Returns:
        A FrameList with one Frame per recognized header, in order of
        appearance.
    """
    header = re.compile(r'File "(.*?)", line (\d+), in (.*?)\n')
    text_len = len(stacktrace)
    frames = []
    for match in header.finditer(stacktrace):
        filename, lineno, name = match.groups()

        # The header pattern consumes its trailing newline, so the match ends
        # exactly where the following line starts. Locating the context this
        # way, instead of by positional arithmetic over all lines, keeps it
        # correct when the text has a preamble line or frames without source.
        context_start = match.end()
        context_end = stacktrace.find("\n", context_start)
        if context_end == -1:
            context_end = text_len

        # A following line that is itself a frame header is not context.
        next_is_header = header.search(
            stacktrace, context_start, min(context_end + 1, text_len)
        )
        context = (
            "" if next_is_header else stacktrace[context_start:context_end].strip()
        )

        frames.append(
            Frame(
                filename=filename,
                lineno=int(lineno),
                name=name,
                context=context,
            )
        )
    return FrameList(frames)
| |
| |
def inspect_node(graph: torch.fx.Graph, node: torch.fx.Node) -> str:
    """
    Inspect a node by highlighting the node in the graph as well as the stacktrace.

    Args:
        graph: The graph containing the node
        node: The node to be inspected

    Return: A string made of the graph with a cursor on the node, followed by
    the node's ``spec`` metadata (when present) and its stacktrace (when
    present). An example output is:

    Here is the node in the graph module:
    graph():
        %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    --> %_param_constant0 : [num_users=1] = get_attr[target=_param_constant0]
        %_param_constant1 : [num_users=1] = get_attr[target=_param_constant1]
        %aten_convolution_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%arg0_1, %_param_constant0, %_param_constant1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
        %aten_gelu_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.gelu.default](args = (%aten_convolution_default,), kwargs = {})
        return [aten_gelu_default]
    This node _param_constant0 has metadata of:
    The node stacktrace:
    Traceback (most recent call last):
        File "/tmp/ipykernel_1204253/3382880687.py", line 7, in forward
    return self.test_model(x)
        File "/tmp/ipykernel_1204253/712280972.py", line 10, in forward
    a = self.conv1(x)

    """
    graph_str_with_cursor = add_cursor_to_graph(graph, node)
    error_msg = (
        "Here is the node in the graph module:\n"
        f"{graph_str_with_cursor}\n"
        f"This node {node} has metadata of:\n"
    )

    # Node spec error message.
    # node.meta is queried by key (as for "stack_trace" below); an attribute
    # check on it would never find the "spec" entry.
    if "spec" in node.meta:
        error_msg += f"The node spec:\n{node.meta['spec']}\n"

    # Stacktrace error message
    if "stack_trace" in node.meta:
        framelist = _stacktrace_to_framelist(node.meta["stack_trace"])
        error_msg += f"The node stacktrace:\n{pretty_print_stacktraces(framelist)}\n"
    return error_msg