| # Copyright 2023-2024 Arm Limited and/or its affiliates. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| # pyre-unsafe |
| |
| import logging |
| import os |
| from typing import Any, cast |
| |
| import numpy as np |
| import serializer.tosa_serializer as ts |
| import torch |
| from executorch.backends.arm.tosa_mapping import TosaArg |
| |
| from executorch.backends.arm.tosa_quant_utils import ( |
| get_quant_arg_downstream, |
| get_quant_arg_upstream, |
| q_op, |
| ) |
| from executorch.exir.dialects._ops import ops as exir_ops |
| from serializer.tosa_serializer import TosaOp |
| from torch.fx import Node |
| |
# Module-level logger; defaults to WARNING so debug chatter is off by default.
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
# Setting TOSA_DBG_VERBOSE=1 in the environment raises verbosity to INFO,
# enabling the dbg_* helpers below to emit their output.
TOSA_DBG_VERBOSE = os.environ.get("TOSA_DBG_VERBOSE") == "1"
if TOSA_DBG_VERBOSE:
    logging.basicConfig(level=logging.INFO)
    logger.setLevel(logging.INFO)
| |
| |
def dbg_node(node):
    """Log a human-readable summary of an FX node: op, name, target,
    args, kwargs, and every meta entry (list entries expanded)."""
    logger.info("OP")
    logger.info(f"  op is {node.op}")
    logger.info(f"  name is {node.name}")
    logger.info(f"  node target is {node.target}")
    logger.info(f"  node args is {node.args}")
    logger.info(f"  node kwargs is {node.kwargs}")
    logger.info("  node.meta = ")
    for key, value in node.meta.items():
        logger.info(f"    '{key}' = {value}")
        if isinstance(value, list):
            for entry in value:
                logger.info(f"      {entry} ")
| |
| |
| # Output TOSA flatbuffer and test harness file |
def dbg_tosa_dump(tosa_graph: ts.TosaSerializer, path: str, suffix: str = ""):
    """Write debug artifacts for a TOSA graph into 'path'.

    Emits the serialized flatbuffer as output<suffix>.tosa and the JSON
    descriptor as desc<suffix>.json, creating 'path' if necessary.
    """
    tosa_filename = f"output{suffix}.tosa"

    logger.info(f"Emitting debug output to: {path=}, {suffix=}")
    os.makedirs(path, exist_ok=True)

    flatbuffer_bytes = tosa_graph.serialize()
    json_text = tosa_graph.writeJson(tosa_filename)

    flatbuffer_path = os.path.join(path, tosa_filename)
    with open(flatbuffer_path, "wb") as flatbuffer_file:
        flatbuffer_file.write(flatbuffer_bytes)
    assert os.path.exists(flatbuffer_path), "Failed to write TOSA flatbuffer"

    descriptor_path = os.path.join(path, f"desc{suffix}.json")
    with open(descriptor_path, "w") as descriptor_file:
        descriptor_file.write(json_text)
    assert os.path.exists(descriptor_path), "Failed to write TOSA JSON"
| |
| |
def dbg_fail(node, tosa_graph, path):
    """Dump the TOSA graph and node debug info to 'path', then raise.

    Called when lowering hits a node it cannot handle; always raises
    RuntimeError after capturing the debug output.
    """
    dbg_tosa_dump(tosa_graph, path)
    # logger.warn() is a deprecated alias of warning() and triggers a
    # DeprecationWarning on modern Python; use warning() directly.
    logger.warning("Internal error due to poorly handled node:")
    dbg_node(node)
    logger.warning(f"Debug output captured in '{path}'.")
    raise RuntimeError("TOSA Internal Error on node, enable logging for further info.")
| |
| |
| # Helper function to match TOSA's broadcasting rank requirement |
| # Ref: TOSA 0.80.0 specification - 1.9.3. Data Layouts from |
| # https://www.mlplatform.org/tosa/tosa_spec.html |
def promote_shape(tosa_fb, arg, promoted_shape, out_dtype):
    """Reshape 'arg' to 'promoted_shape' via a TOSA RESHAPE operator.

    'promoted_shape' must contain the same number of elements as arg.shape.
    Returns the new intermediate tensor holding the reshaped result.
    """
    assert np.prod(arg.shape) == np.prod(promoted_shape), "Incompatible promoted shape"
    reshape_attr = ts.TosaSerializerAttribute()
    reshape_attr.ReshapeAttribute(promoted_shape)
    promoted = tosa_fb.addIntermediate(promoted_shape, out_dtype)
    tosa_fb.addOperator(TosaOp.Op().RESHAPE, [arg.name], [promoted.name], reshape_attr)
    return promoted
| |
| |
| # Helper transpose function to match TOSA's shape requirements |
| # E.g., TOSA 0.80.0 specification - 2.3.3 CONV2D shapes: |
| # https://www.mlplatform.org/tosa/tosa_spec.html#_conv2d |
def transpose_helper(tosa_fb, input, new_order, out_dtype):
    """Insert a TOSA TRANSPOSE permuting 'input' dims by 'new_order'.

    'new_order' must be a permutation of input's dimension indices.
    Returns the intermediate tensor holding the transposed result.
    """
    # Check new_order's length is equal to input rank
    assert len(input.shape) == len(new_order), "Wrong shape order length"

    # Check no duplications
    assert len(set(new_order)) == len(new_order), "Contain duplicated dim numbers"

    # Check all dims are valid.
    # BUG FIX: the original used `assert True, ...` in both branches, which
    # can never fail, so invalid dims were silently accepted.
    for idx in new_order:
        assert idx >= 0, "Negative dim number"
        assert idx < len(input.shape), "Dim is greater than input rank"

    input_shape_transposed = [input.shape[i] for i in new_order]
    attr = ts.TosaSerializerAttribute()
    attr.TransposeAttribute(new_order)
    input_transposed = tosa_fb.addIntermediate(input_shape_transposed, out_dtype)
    tosa_fb.addOperator(
        TosaOp.Op().TRANSPOSE, [input.name], [input_transposed.name], attr
    )
    return input_transposed
| |
| |
def getNodeArgs(node: Node) -> list[TosaArg]:
    """Wrap every positional argument of 'node' in a TosaArg."""
    return list(map(TosaArg, node.args))
| |
| |
def get_input_tensor(node: Node) -> TosaArg:
    """Return the first positional argument of 'node' wrapped as a TosaArg."""
    first_arg = node.args[0]
    return TosaArg(first_arg)
| |
| |
def get_output_node(node: Node) -> Node:
    """Return the first consumer (user) of 'node'."""
    consumers = list(node.users)
    return consumers[0]
| |
| |
| """ TOSA reshape returns a tensor with the same type/values as the input. |
| No data conversion happens during a reshape operation. """ |
| |
| |
def build_reshape(tosa_fb, input_name, new_shape, output_name):
    """Insert a TOSA RESHAPE operator mapping input_name to output_name.

    TOSA reshape keeps the same type/values as the input; no data
    conversion happens during the operation.
    """
    reshape_attr = ts.TosaSerializerAttribute()
    reshape_attr.ReshapeAttribute(new_shape)
    tosa_fb.addOperator(TosaOp.Op().RESHAPE, [input_name], [output_name], reshape_attr)
| |
def is_bias_node_for_quantized_conv(node):
    """Return True when 'node' feeds an aten.convolution whose output is
    consumed by a quantize op, i.e. 'node' is the bias of a quantized conv."""
    consumer = list(node.users)[0]
    if consumer.target != exir_ops.edge.aten.convolution.default:
        return False
    # Only inspect the conv's consumer when 'node' actually feeds a conv,
    # preserving the original short-circuit behavior.
    return list(consumer.users)[0].target == q_op
| |
| |
def is_consumer_node_depthwise_conv2d(node):
    """Return True when the consumer of 'node' is a depthwise convolution.

    A convolution is considered depthwise when its group count equals the
    input channel count and the output channels are a whole multiple of
    the input channels.
    """
    consumer = list(node.users)[0]
    if consumer.target != exir_ops.edge.aten.convolution.default:
        return False

    conv_args = getNodeArgs(consumer)
    groups = conv_args[-1]
    in_channels = conv_args[0].shape[1]
    out_channels = conv_args[1].shape[0]
    return in_channels == groups.number and out_channels % in_channels == 0
| |
| |
def build_avg_pool_2d_common(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    input_tensor: TosaArg,
    kernel_size: list,
    stride: list,
    padding: list,
    is_quant_node: bool,
    output: TosaArg,
):
    """Emit a TOSA AVG_POOL2D operator for 'node' into 'tosa_graph'.

    For quantized nodes the accumulator is always INT32 and the zero
    points are taken from the surrounding quantize/dequantize nodes;
    otherwise the accumulator matches the input dtype and both zero
    points are 0.
    """
    if is_quant_node:
        # Accumulator type always is int32 when input tensor is an integer type.
        accumulator_type = ts.DType.INT32
        input_zp = get_quant_arg_upstream(cast(torch.fx.Node, node.args[0])).zp
        output_zp = get_quant_arg_downstream(list(node.users)[0]).zp
    else:
        accumulator_type = input_tensor.dtype
        input_zp = 0
        output_zp = 0

    pool_attr = ts.TosaSerializerAttribute()
    pool_attr.PoolAttribute(
        kernel=kernel_size,
        stride=stride,
        pad=padding,
        input_zp=input_zp,
        output_zp=output_zp,
        accum_dtype=accumulator_type,
    )

    tosa_graph.addOperator(
        TosaOp.Op().AVG_POOL2D,
        [input_tensor.name],
        [output.name],
        pool_attr,
    )
| |
| |
def get_two_inputs(node: Node, check: bool = False) -> tuple[Node, Node]:
    """Returns two input nodes to 'node' in order. If 'node' only has one input,
    it is returned twice.

    Fails if there are no input nodes.
    Fails if there are >2 input nodes and 'check' is True.
    """
    inputs = node.all_input_nodes
    num_inputs = len(inputs)

    assert num_inputs > 0, f"Node '{node.name}' requires >0 input, got {num_inputs}."
    if check:
        assert (
            num_inputs <= 2
        ), f"Node '{node.name}' requires <=2 inputs, got {num_inputs}."

    first = inputs[0]
    # A single-input node contributes the same node for both positions.
    second = inputs[1] if num_inputs > 1 else inputs[0]
    return first, second
| |
| |
def tosa_shape(shape, dim_order):
    """Return 'shape' permuted by 'dim_order' as a tuple."""
    return tuple(shape[axis] for axis in dim_order)
| |
| |
def expand_dims(
    tosa_graph: ts.TosaSerializer,
    input_node: TosaArg,
    dtype: int,
    dim: int,
) -> Any:
    """Insert TOSA operators performing the equivalent of expand_dims
    (a.k.a unsqueeze): a new axis of size 1 is created at position 'dim'.

    Args:
        tosa_graph (ts.TosaSerializer): The TOSA graph to manipulate.
        input_node (TosaArg): The parent node of the expand dim operations.
        dtype (ts.DType): The data type expand dims operations.
        dim (int): The dimension to expand.

    Returns:
        Any: The output tensor of the inserted operation in the TOSA graph.
    """
    unsqueezed_shape = list(input_node.shape)
    unsqueezed_shape.insert(dim, 1)

    result = tosa_graph.addIntermediate(unsqueezed_shape, dtype)
    build_reshape(tosa_graph, input_node.name, unsqueezed_shape, result.name)
    return result
| |
| |
def get_resize_parameters(
    input_size: torch.Tensor,
    output_size: torch.Tensor,
    resize_mode: int,
    align_corners: bool,
):
    """Get the tosa.resize parameters based on the input and output size.

    Args:
        input_size (torch.Tensor): Size of the input
        output_size (torch.Tensor): Size of the output
        resize_mode (tosa.ResizeMode): The TOSA resize mode
        align_corners (bool): Align the corners pixels of the input and output

    Returns:
        scale_n (torch.Tensor), scale_d (torch.Tensor),
        offset (torch.Tensor), border (torch.Tensor)
    """
    assert torch.all(input_size > 0)
    assert torch.all(output_size > 0)

    # With align_corners the scale is measured in (size - 1) intervals,
    # but only when both sizes exceed 1; otherwise the full size is used.
    numerators = []
    denominators = []
    for size_in, size_out in zip(input_size, output_size):
        use_corners = align_corners and size_in > 1 and size_out > 1
        numerators.append(size_out - 1 if use_corners else size_out)
        denominators.append(size_in - 1 if use_corners else size_in)
    scale_n = torch.tensor(numerators)
    scale_d = torch.tensor(denominators)

    # Reduce each per-dimension ratio to lowest terms.
    common = torch.gcd(scale_n, scale_d)
    scale_n = scale_n // common
    scale_d = scale_d // common

    # No half-pixel centre support in PyTorch, no offset needed
    offset = torch.zeros_like(input_size)
    border = scale_d * (output_size - 1) - scale_n * (input_size - 1) + offset

    return scale_n, scale_d, offset, border