blob: cf2daa2c2d3df2a7d841bf5972d1c8dddce95303 [file] [log] [blame] [edit]
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import copy
import re
import reprlib
from dataclasses import fields
from enum import IntEnum
from typing import Any, List, Optional, TextIO
import torch
from executorch.exir.error import ExportError, ExportErrorType, InternalError
from executorch.exir.schema import (
Bool,
BoolList,
DelegateCall,
Double,
DoubleList,
EValue,
Frame,
FrameList,
FreeCall,
Int,
IntList,
JumpFalseCall,
KernelCall,
MoveCall,
Null,
OptionalTensorList,
Program,
ScalarType,
String,
Tensor,
TensorList,
TensorShapeDynamism,
)
def _scalar_type_str(scalar_type: ScalarType) -> str:
    """
    Return the short dtype tag used in program dumps for ``scalar_type``.

    Raises:
        RuntimeError: if ``scalar_type`` has no registered abbreviation.
    """
    abbreviations = {
        ScalarType.BYTE: "bt",
        ScalarType.CHAR: "c",
        ScalarType.SHORT: "s",
        ScalarType.INT: "i",
        ScalarType.LONG: "l",
        ScalarType.HALF: "h",
        ScalarType.FLOAT: "f",
        ScalarType.DOUBLE: "d",
        ScalarType.COMPLEX32: "c32",
        ScalarType.COMPLEX64: "c64",
        ScalarType.COMPLEX128: "c128",
        ScalarType.BOOL: "b",
        ScalarType.QINT8: "qi8",
        ScalarType.QUINT8: "qui8",
        ScalarType.QINT32: "qi32",
        ScalarType.BFLOAT16: "bf16",
        ScalarType.QUINT4x2: "qui4x2",
        ScalarType.QUINT2x4: "qui2x4",
    }
    abbreviation = abbreviations.get(scalar_type)
    if not abbreviation:
        raise RuntimeError(f"Unrecognized scalar_type: {scalar_type}")
    return abbreviation
def _is_dynamic_shape_tensor(tensor: Tensor) -> bool:
    """Return True unless ``tensor`` is declared with STATIC shape dynamism."""
    is_static = tensor.shape_dynamism == TensorShapeDynamism.STATIC
    return not is_static
def _format_evalue(  # noqa: C901
    evalue: EValue, show_meminfo: bool, mark_dynamic_shape_tensor: bool
) -> str:
    """
    Render a single EValue as a compact, blue-colored tag for program dumps.

    Tags produced (see ``print_program`` for the full grammar):
      * Tensor: ``CT`` when ``data_buffer_idx > 0`` (constant tensor), else
        ``T``. The ``T`` is prefixed with ``UB`` / ``DU`` for DYNAMIC_BOUND /
        DYNAMIC_UNBOUND tensors when ``mark_dynamic_shape_tensor`` is set.
        When ``show_meminfo`` is set, this is followed by
        ``m<memory_id>.<memory_offset>`` (or ``m.`` without allocation info).
        The sizes and the dtype abbreviation come last.
      * Lists: ``TL``, ``OTL``, ``IL``, ``DL`` or ``BL`` followed by the items.
      * Scalars: ``I<int>``, ``D<double>``, ``B<0|1>``, ``S<string>``, and
        ``N`` for null.

    Raises:
        RuntimeError: if the EValue holds an unrecognized payload type.
    """
    payload = evalue.val
    pieces = ["\033[34m"]

    if isinstance(payload, Tensor):
        if payload.data_buffer_idx > 0:
            # A positive data_buffer_idx marks a constant tensor; constants are
            # required to be static-shape and to carry no allocation info.
            assert not _is_dynamic_shape_tensor(
                payload
            ), "A constant tensor can not be dynamic shape"
            pieces.append("CT")
            assert payload.allocation_info is None
        else:
            if mark_dynamic_shape_tensor:
                dynamism = payload.shape_dynamism
                if dynamism == TensorShapeDynamism.DYNAMIC_BOUND:
                    pieces.append("UB")  # shown as 'UBT'
                elif dynamism == TensorShapeDynamism.DYNAMIC_UNBOUND:
                    pieces.append("DU")  # shown as 'DUT'
            pieces.append("T")
        if show_meminfo:
            alloc = payload.allocation_info
            pieces.append(
                f"m{alloc.memory_id}.{alloc.memory_offset}" if alloc else "m."
            )
        pieces.append(f"{payload.sizes}{_scalar_type_str(payload.scalar_type)}")
    else:
        # Ordered (type, tag) pairs for the list-like payloads, which all
        # render as their tag followed by str(items).
        list_tags = (
            (TensorList, "TL"),
            (OptionalTensorList, "OTL"),
            (IntList, "IL"),
            (DoubleList, "DL"),
            (BoolList, "BL"),
        )
        list_tag = next(
            (tag for list_cls, tag in list_tags if isinstance(payload, list_cls)),
            None,
        )
        if list_tag is not None:
            pieces.append(list_tag)
            # pyre-ignore
            pieces.append(str(payload.items))
        elif isinstance(payload, Int):
            pieces.append(f"I{payload.int_val}")
        elif isinstance(payload, Double):
            pieces.append(f"D{payload.double_val}")
        elif isinstance(payload, Bool):
            # 0/1 is shorter than False/True.
            pieces.append(f"B{int(payload.bool_val)}")
        elif isinstance(payload, String):
            pieces.append(f"S{payload.string_val}")
        elif isinstance(payload, Null):
            pieces.append("N")
        else:
            raise RuntimeError(f"Unrecognized type of evalue: {evalue}")

    pieces.append("\033[0m")
    return "".join(pieces)
def print_program(  # noqa: C901
    program: Program,
    show_meminfo: bool = True,
    mark_dynamic_shape_tensor: bool = False,
    out: Optional[TextIO] = None,
) -> None:
    """
    Dump the instruction list of a program in a more human readable fashion.

    Only the first execution plan, and the first chain of that plan, are dumped.

    The dump follows the BNF-like grammar below. Some regex syntax is mixed in
    to keep the grammar short. The grammar is not strict; its main purpose is
    to help people read the dump.
    ```
    PROGRAM: (INSTRUCTION)+
    INSTRUCTION: SEQUENCE_NO ':' (CALL_KERNEL | CALL_DELEGATE | JUMP_FALSE | MOVE | FREE)
    CALL_KERNEL: OVERLOADED_OP_NAME ARGS
    CALL_DELEGATE: BACKEND_ID ARGS
    JUMP_FALSE: 'JF' '(' EVALUE ')' '->' TARGET_SEQUENCE_NO
    MOVE: 'MOVE' EVALUE '->' EVALUE
    FREE: 'FREE' EVALUE
    ARGS: EVALUE | ARGS ',' EVALUE
    EVALUE: EVALUE_IDX (IO_MARKER)* (TENSOR | INT | BOOL | ...)
    IO_MARKER: 'I' POSITION_IN_PLAN_INPUTS | 'O' POSITION_IN_PLAN_OUTPUTS
    INT: 'I' ACTUAL_INT_VALUE
    BOOL: 'B' ZERO_OR_ONE
    CONST_TENSOR_PREFIX: 'CT'
    DYNAMIC_SHAPE_PREFIX: 'UB' | 'DU'    (only when mark_dynamic_shape_tensor=True)
    TENSOR: ((DYNAMIC_SHAPE_PREFIX)? 'T' | CONST_TENSOR_PREFIX) (MEM_ALLOCATION_INFO)? TENSOR_SHAPE TENSOR_DTYPE
    TENSOR_SHAPE: '[' dim0_size, dim1_size, ..., last_dim_size ']'
    MEM_ALLOCATION_INFO: PLANNED_MEM_INFO | UNPLANNED_MEM_INFO    (only when show_meminfo=True)
    PLANNED_MEM_INFO: 'm' MEM_LAYER_ID '.' MEM_LAYER_OFFSET
    UNPLANNED_MEM_INFO: 'm.'
    ```
    To make the dump easier to read, it's colored as follows:
    1. input/output markers (``I<n>`` / ``O<n>``) are red
    2. EValue types (or more specifically tensor types with size and dtype)
       are blue

    Args:
        program: The program to dump.
        show_meminfo: Include memory-planning info for non-constant tensors.
        mark_dynamic_shape_tensor: Prefix dynamic-shape tensors with UB/DU.
        out: Stream to print to; ``None`` means ``print``'s default (stdout).

    Raises:
        InternalError: if an instruction has an unsupported type.
    """
    execution_plan = program.execution_plan[0]
    operators = execution_plan.operators
    delegates = execution_plan.delegates
    chain = execution_plan.chains[0]
    instructions = chain.instructions
    inputs: List[int] = execution_plan.inputs
    outputs: List[int] = execution_plan.outputs
    values: List[EValue] = execution_plan.values

    # Map each EValue index to its position in the plan's inputs/outputs.
    # Iterating in reverse lets the earliest position win when an index is
    # listed more than once.
    input_positions = {
        evalue_idx: pos for pos, evalue_idx in reversed(list(enumerate(inputs)))
    }
    output_positions = {
        evalue_idx: pos for pos, evalue_idx in reversed(list(enumerate(outputs)))
    }

    def _format_arg(evalue_idx: int) -> str:
        """Format one argument: index, red I/O markers, then the blue EValue tag."""
        argstr = str(evalue_idx)
        if (input_idx := input_positions.get(evalue_idx)) is not None:
            argstr += f"\033[31mI{input_idx}\033[0m"
        if (output_idx := output_positions.get(evalue_idx)) is not None:
            argstr += f"\033[31mO{output_idx}\033[0m"
        return argstr + _format_evalue(
            values[evalue_idx], show_meminfo, mark_dynamic_shape_tensor
        )

    print(
        f"The program contains the following {len(instructions)} instructions", file=out
    )
    for idx, instr in enumerate(instructions):
        print(f"{idx:3}: ", end="", file=out)
        if isinstance(instr.instr_args, KernelCall):
            kernel = instr.instr_args
            op = operators[kernel.op_index]
            opname = f"{op.name}.{op.overload}" if op.overload else op.name
            argstr = ",".join(map(_format_arg, kernel.args))
            print(f"{opname} {argstr}", file=out)
        elif isinstance(instr.instr_args, DelegateCall):
            delegate = instr.instr_args
            backend = delegates[delegate.delegate_index]
            backend_id = f"{backend.id}"
            argstr = ",".join(map(_format_arg, delegate.args))
            print(f"{backend_id} {argstr}", file=out)
        elif isinstance(instr.instr_args, JumpFalseCall):
            jfcall = instr.instr_args
            print(
                f"JF ({_format_arg(jfcall.cond_value_index)}) -> {jfcall.destination_instruction}",
                file=out,
            )
        elif isinstance(instr.instr_args, MoveCall):
            move_call = instr.instr_args
            print(
                f"MOVE {_format_arg(move_call.move_from)} -> {_format_arg(move_call.move_to)}",
                file=out,
            )
        elif isinstance(instr.instr_args, FreeCall):
            print(f"FREE {_format_arg(instr.instr_args.value_index)}", file=out)
        else:
            raise InternalError(f"Unsupport instruction type {instr}")
# pyre-ignore
def pretty_print(obj: Any, indent: int = 0, out: Optional[TextIO] = None) -> None:
    """
    Pretty print a Program, or any dataclass or value it is built from.

    Rendering rules:
      * IntEnum values print as their integer value.
      * int / str / bool / float / None print as-is.
      * bytes print via ``reprlib`` with ``maxother=1024``, so long payloads
        are abbreviated.
      * Lists of fewer than 10 ints print inline. Other lists print one
        element per line, each followed by ``(index=N),``.
      * Anything else is treated as a dataclass and printed as
        ``TypeName(field=value, ...)``. This is on one line when every field
        value is a primitive, otherwise one field per line. A top-level
        (``indent == 0``) dataclass is terminated by a newline.

    Args:
        obj: The object to print.
        indent: Nesting depth, used for indentation of multi-line output.
        out: Stream to print to; ``None`` means ``print``'s default (stdout).

    Raises:
        ExportError: if ``obj`` is a ``torch.fx.GraphModule``.
        TypeError: raised by ``dataclasses.fields`` for unsupported non-dataclass objects.
    """
    if isinstance(obj, torch.fx.GraphModule):
        raise ExportError(
            ExportErrorType.INVALID_INPUT_TYPE,
            "pretty_print() does not accept GraphModule as input.",
        )

    primitives = (int, str, bool, float, type(None))

    # Instruction types are IntEnum objects. Check them before the primitive
    # case (IntEnum subclasses int) so they always print as a plain integer.
    if isinstance(obj, IntEnum):
        print(int(obj), end="", file=out)
        return

    if isinstance(obj, primitives):
        print(obj, end="", file=out)
        return

    if isinstance(obj, bytes):
        abbreviator = reprlib.Repr()
        abbreviator.maxother = 1024
        print(abbreviator.repr(obj), end="", file=out)
        return

    if isinstance(obj, list):
        if len(obj) < 10 and all(isinstance(item, int) for item in obj):
            print(obj, end="", file=out)
            return
        print("[", file=out)
        child_pad = " " * (indent + 1)
        for position, item in enumerate(obj):
            print(child_pad, end="", file=out)
            pretty_print(item, indent + 1, out=out)
            print(f"(index={position}),", file=out)
        print(" " * indent + "]", end="", file=out)
        return

    obj_fields = fields(obj)
    inline = all(
        isinstance(getattr(obj, field_.name), primitives) for field_ in obj_fields
    )
    line_end = "" if inline else "\n"
    last_position = len(obj_fields) - 1

    print(f"{type(obj).__name__}(", end=line_end, file=out)
    for position, field_ in enumerate(obj_fields):
        if not inline:
            print(" " * (indent + 1), end="", file=out)
        print(field_.name + "=", end="", file=out)
        pretty_print(getattr(obj, field_.name), indent + 1, out=out)
        if position < last_position:
            print(", ", end="", file=out)
        print("", end=line_end, file=out)
    if not inline:
        print(" " * indent, end="", file=out)
    print(")", end="" if indent else "\n", file=out)
def pretty_print_stacktraces(obj: FrameList) -> str:
    """
    Format the frames in ``obj`` as a Python-style traceback string.

    The result has a "Traceback (most recent call last): " header. Each frame
    contributes a ``File "<filename>", line <lineno>, in <name>`` line followed
    by its context line. The string ends with an extra blank line.
    """
    pieces = ["Traceback (most recent call last): \n"]
    for frame in obj.items:
        pieces.append(
            f' File "{frame.filename}", '
            f"line {frame.lineno!s}, in {frame.name}\n"
            f"{frame.context} \n"
        )
    pieces.append("\n")
    return "".join(pieces)
def add_cursor_to_graph(graph: torch.fx.Graph, finding_node: torch.fx.Node) -> str:
    """
    Render ``graph`` like ``torch.fx.Graph.__str__``, with a ``-->`` cursor in
    front of ``finding_node``.

    e.g.:
        graph():
            %x : [#users=1] = placeholder[target=x]
            %param : [#users=1] = get_attr[target=param]
            %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
        --> %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
            %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
            return clamp

    Other node lines are padded to the cursor's width so node text stays
    aligned. Nodes whose ``format_node()`` is empty are omitted. If
    ``finding_node`` is not in ``graph``, no cursor is drawn. The graph is only
    read, never modified. This is mostly used for error reporting.
    """
    cursor = "--> "
    padding = " " * len(cursor)
    lines = ["graph():"]
    for node in graph.nodes:
        node_str = node.format_node()
        if not node_str:
            continue
        lines.append((cursor if node is finding_node else padding) + node_str)
    return "\n".join(lines)
def _stacktrace_to_framelist(stacktrace: str) -> FrameList:
    """
    Create a FrameList from a formatted Python stack trace string.

    Each frame is recognized by a ``File "<filename>", line <lineno>, in <name>``
    line terminated by a newline. The line immediately after it, stripped, is
    used as the frame's context. The context is "" when nothing follows.
    """
    # The context is read relative to each match with a lookahead, which does
    # not consume that line. Assuming frame i's context sits at line 2*i + 1
    # breaks when the trace has a header or extra lines such as caret markers.
    pattern = r'File "(.*?)", line (\d+), in (.*?)\n(?=(.*))'
    mapped_frame_list = [
        Frame(
            filename=filename,
            lineno=int(lineno),
            name=name,
            context=context.strip(),
        )
        for filename, lineno, name, context in re.findall(pattern, stacktrace)
    ]
    return FrameList(mapped_frame_list)
def inspect_node(graph: torch.fx.Graph, node: torch.fx.Node) -> str:
    """
    Inspect a node by highlighting the node in the graph as well as the stacktrace.

    Args:
        graph: The graph containing the node
        node: The node to be inspected

    Returns:
        A string with the following parts, in order:
        * the graph rendered with a ``-->`` cursor at ``node``;
        * the node's ``meta["spec"]``, if present;
        * the node's ``meta["stack_trace"]`` reformatted as a traceback, if
          present and not None.

        An example output is:

        Here is the node in the graph module:
        graph():
            %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
        --> %_param_constant0 : [num_users=1] = get_attr[target=_param_constant0]
            %_param_constant1 : [num_users=1] = get_attr[target=_param_constant1]
            %aten_convolution_default : [num_users=2] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%arg0_1, %_param_constant0, %_param_constant1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
            ...
            %aten_gelu_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.gelu.default](args = (%aten_convolution_default_2,), kwargs = {})
            return [aten_gelu_default]
        This node _param_constant0 has metadata of:
        The node stacktrace:
        Traceback (most recent call last):
         File "/tmp/ipykernel_1204253/3382880687.py", line 7, in forward
        return self.test_model(x)
         File "/tmp/ipykernel_1204253/712280972.py", line 10, in forward
        a = self.conv1(x)
    """
    graph_str_with_cursor = add_cursor_to_graph(graph, node)
    error_msg = (
        f"Here is the node in the graph module:\n"
        f"{graph_str_with_cursor}\n"
        f"This node {node} has metadata of:\n"
    )
    # Node spec error message. node.meta is a dict, so test for the key;
    # hasattr() would look for an attribute and never match.
    if "spec" in node.meta:
        error_msg += f"The node spec:\n{node.meta['spec']}\n"
    # Stacktrace error message. A None entry carries nothing to parse.
    stack_trace = node.meta.get("stack_trace")
    if stack_trace is not None:
        framelist = _stacktrace_to_framelist(stack_trace)
        error_msg += f"The node stacktrace:\n{pretty_print_stacktraces(framelist)}\n"
    return error_msg