| # Copyright 2023-2024 Arm Limited and/or its affiliates. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| # pyre-unsafe |
| |
# Utility functions for TOSA quantized lowerings
| |
| import math |
| from typing import Callable, cast, NamedTuple, Sequence |
| |
| import numpy as np |
| |
| import serializer.tosa_serializer as ts |
| import torch.fx |
| import tosa.Op as TosaOp |
| from executorch.backends.arm.tosa_mapping import TosaArg |
| from executorch.exir.dialects._ops import ops as exir_ops |
| from serializer.tosa_serializer import TosaSerializerTensor |
| from torch.fx import Node |
| |
| |
# The quantize/dequantize-per-tensor ops produced by the quantizer; a node
# whose target is one of these carries explicit quantization parameters
# (scale, zero point, qmin, qmax, dtype) in its args.
q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
dq_q_ops = (q_op, dq_op)
# Ops that do not alter quantization parameters: the up-/downstream searches
# for QuantArgs are allowed to pass through these nodes.
passable_ops = [
    exir_ops.edge.aten.view_copy.default,
    exir_ops.edge.aten.permute_copy.default,
    exir_ops.edge.aten.squeeze_copy.dims,
    exir_ops.edge.aten.unsqueeze_copy.default,
    exir_ops.edge.aten.split_with_sizes_copy.default,
    exir_ops.edge.aten.repeat.default,
    exir_ops.edge.aten.clone.default,
    exir_ops.edge.aten.slice_copy.Tensor,
    exir_ops.edge.aten.cat.default,
]
| |
| |
def register_passable_op(op):
    """Register a custom op (e.g. tosa_transpose) as passable after its creation.

    Needed because such ops do not exist when the module-level
    ``passable_ops`` list is first built.
    """
    passable_ops.extend([op])
| |
| |
class QuantArgs(NamedTuple):
    """Per-tensor quantization parameters.

    Quantization maps x -> clip(round(x / scale) + zp, qmin, qmax), cast to
    ``dtype``; dequantization is the (lossy) inverse.
    """

    scale: float
    zp: int
    qmin: int
    qmax: int
    dtype: torch.dtype

    def quantize_value(self, x):
        """Quantize ``x`` (scalar or tensor) into a tensor of ``self.dtype``."""
        x_t = x if isinstance(x, torch.Tensor) else torch.Tensor([x])
        rounded = torch.round(x_t / self.scale) + self.zp
        return torch.clip(rounded, self.qmin, self.qmax).to(self.dtype)

    def dequantize_value(self, qx: int) -> float:
        """Map a quantized value back into the float domain."""
        return (qx - self.zp) * self.scale
| |
| |
def quantize_value(x, qargs: QuantArgs, dtype=np.int8):
    """Quantize ``x`` with numpy: clip(round(x / scale) + zp, qmin, qmax) as ``dtype``."""
    rounded = np.round(x / qargs.scale) + qargs.zp
    return np.clip(rounded, qargs.qmin, qargs.qmax).astype(dtype)
| |
| |
def dequantize_value(qx, qargs: QuantArgs):
    """Map the quantized value(s) ``qx`` back into the float domain (numpy)."""
    return qargs.scale * (qx - qargs.zp)
| |
| |
def qargs_from_qnode(node: torch.fx.Node):
    """Build a QuantArgs tuple from a quantize/dequantize node.

    The q/dq per-tensor ops carry (input, scale, zero_point, qmin, qmax, dtype)
    as their arguments.
    """
    assert node.target in dq_q_ops, f"Op {node} is not a quant node."

    scale, zp, qmin, qmax, dtype = node.args[1:6]
    return QuantArgs(
        scale=cast(float, scale),
        zp=cast(int, zp),
        qmin=cast(int, qmin),
        qmax=cast(int, qmax),
        dtype=cast(torch.dtype, dtype),
    )
| |
| |
def get_neighbour_quant_args(
    node: torch.fx.Node,
) -> tuple[list[QuantArgs], list[QuantArgs]]:
    """Collect quantization parameters surrounding ``node``.

    Searches downstream from every user and upstream from every input;
    neighbours without quantization parameters are skipped.

    Returns:
        (user_q_args, input_q_args) — QuantArgs found per direction.
    """
    user_q_args = [
        qa for qa in (search_quant_arg_downstream(u) for u in node.users) if qa
    ]
    input_q_args = [
        qa
        for qa in (search_quant_arg_upstream(i) for i in node.all_input_nodes)
        if qa
    ]
    return user_q_args, input_q_args
| |
| |
def all_q_args_equal(q_arg_list: list[QuantArgs]) -> bool:
    """Return True iff every entry of ``q_arg_list`` equals the first one.

    The list must be non-empty (callers guarantee this).
    """
    reference = q_arg_list[0]
    return all(q_arg == reference for q_arg in q_arg_list)
| |
| |
def is_node_quantized(node: torch.fx.Node) -> bool:
    """Decide whether ``node`` operates on quantized data.

    A node is quantized if it is itself a q/dq op, or if quantization
    parameters are found on any neighbouring node. A passable op must
    additionally see identical parameters on all inputs and outputs.
    """
    if node.target in dq_q_ops:
        return True

    user_q_args, input_q_args = get_neighbour_quant_args(node)

    # No neighbouring quant nodes at all -> the node works on float data.
    if not user_q_args and not input_q_args:
        return False

    if node.target in passable_ops:
        assert all_q_args_equal(
            user_q_args + input_q_args
        ), f"Node {node} needs same quantization parameters on all inputs and outputs."

    return True
| |
| |
def search_quant_arg_downstream(node: torch.fx.Node) -> QuantArgs | None:
    """
    Iterates downward in the graph passing through 'passable_ops' to find and
    return quantization parameters, starting with 'node'.

    Returns None when a node outside passable_ops, or a node without
    consumers, is reached. When a passable node has several consumers, the
    QuantArgs found for all of them must agree.
    """
    if node.target in dq_q_ops:
        return qargs_from_qnode(node)
    if node.target not in passable_ops:
        return None

    consumers = list(node.users)
    if not consumers:
        return None
    if len(consumers) == 1:
        return search_quant_arg_downstream(consumers[0])

    found = [qa for qa in map(search_quant_arg_downstream, consumers) if qa]
    if not found:
        return None
    assert all_q_args_equal(
        found
    ), f"Encountered a op, {node}, in passable_ops with different QuantArgs for different consumers."
    return found[0]
| |
| |
def get_quant_arg_downstream(node: torch.fx.Node) -> QuantArgs:
    """Like search_quant_arg_downstream, but requires that QuantArgs exist."""
    found = search_quant_arg_downstream(node)
    assert found, f"Did not find QuantArgs downstream for node {node}"
    return found
| |
| |
def search_quant_arg_upstream(node: torch.fx.Node) -> QuantArgs | None:
    """
    Iterates upward in the graph passing through 'passable_ops' to find and
    return quantization parameters, starting with 'node'.

    Returns None when a node outside passable_ops, or a node without inputs,
    is reached. When a passable node has several inputs, the QuantArgs found
    for all of them must agree.
    """
    if node.target in dq_q_ops:
        return qargs_from_qnode(node)
    if node.target not in passable_ops:
        return None

    producers = list(node.all_input_nodes)
    if not producers:
        return None
    if len(producers) == 1:
        return search_quant_arg_upstream(producers[0])

    found = [qa for qa in map(search_quant_arg_upstream, producers) if qa]
    if not found:
        return None
    assert all_q_args_equal(
        found
    ), f"Encountered a op, {node}, in passable_ops with different QuantArgs for different inputs."
    return found[0]
| |
| |
def get_quant_arg_upstream(node: torch.fx.Node) -> QuantArgs:
    """Like search_quant_arg_upstream, but requires that QuantArgs exist."""
    found = search_quant_arg_upstream(node)
    assert found, f"Did not find QuantArgs upstream for node {node}"
    return found
| |
| |
def get_quantized_node_output_dtype(node: torch.fx.Node) -> torch.dtype:
    """Return the dtype a quantized node produces.

    Custom tosa_* nodes carry the dtype in their meta; q/dq ops carry it as
    their sixth argument. For any other node the graph neighbourhood is
    searched for quantization parameters.

    Raises:
        RuntimeError: if no quantization parameters can be found.
    """
    if isinstance(node.target, Callable) and "tosa" in node.target.__name__:
        return node.meta["val"].dtype
    if node.target in dq_q_ops:
        return cast(torch.dtype, node.args[5])

    # Neither a tosa node nor a q/dq op: derive the dtype from neighbours.
    user_q_args, input_q_args = get_neighbour_quant_args(node)
    if user_q_args:
        return user_q_args[0].dtype
    if node.target in passable_ops and input_q_args:
        return input_q_args[0].dtype
    raise RuntimeError("No quantized node found in graph")
| |
| |
def is_scale32(type):
    """Return whether scale32 mode is used for the given output element type."""
    if type == ts.DType.INT8:
        return True
    return False
| |
| |
# TOSA uses the RESCALE operation to scale between values with differing precision.
# The RESCALE operator is defined using an integer multiply, add, and shift.
# This utility function is for calculating the multiplier and shift given a scale.
# Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling
def compute_multiplier_and_shift(scale, scaleWidth=32):
    """Decompose a float ``scale`` into an integer (multiplier, shift) pair.

    The pair satisfies scale ~= multiplier * 2^(-shift), with the multiplier a
    ``scaleWidth``-bit (16 or 32) fixed-point fraction, as required by the
    TOSA RESCALE operator.

    Returns:
        (multiplier, shift): both ints; shift is clamped to at most 62.

    Raises:
        AssertionError: if scaleWidth is not 16 or 32, or scale is not a float.
    """
    if scaleWidth == 16:
        offset = 15
    elif scaleWidth == 32:
        offset = 31
    else:
        raise AssertionError("unsupported scale width")

    assert isinstance(scale, float)

    mantissa, exponent = math.frexp(scale)
    shift = exponent

    const_2_power_15_or_31 = 1 << offset
    shifted_mantissa = round(mantissa * const_2_power_15_or_31)

    assert shifted_mantissa <= const_2_power_15_or_31

    if shifted_mantissa == const_2_power_15_or_31:
        # The rounded mantissa saturated at 2^offset; halve it and bump the
        # shift instead. Use floor division to keep the multiplier an int —
        # true division would yield a float and break the integer RESCALE
        # encoding.
        shifted_mantissa = shifted_mantissa // 2
        shift += 1

    # TOSA expects right shift to be positive, and embed (1 << offset) into right shift bits.
    shift = offset - shift

    # INT32_MAX, 2^31 - 1
    assert shifted_mantissa <= (const_2_power_15_or_31 - 1)

    multiplier = shifted_mantissa

    if shift > 62:
        # Shifts beyond 62 are not representable; fold the excess (capped at
        # 31 bits) into the multiplier and clamp the shift.
        multiplier = multiplier >> min(31, shift - 62)
        shift = 62
    return multiplier, shift
| |
| |
def build_rescale(
    tosa_fb,
    scale,
    input_node,
    output_name,
    output_type,
    output_shape,
    input_zp,
    output_zp,
    is_double_round=False,
):
    """Add a RESCALE operator from ``input_node`` to ``output_name``.

    The float ``scale`` is encoded as an integer multiplier/shift pair;
    32-bit scaling is used for INT8 outputs, 16-bit otherwise.
    """
    use_scale32 = is_scale32(output_type)
    multiplier, shift = compute_multiplier_and_shift(scale, 32 if use_scale32 else 16)

    rescale_attr = ts.TosaSerializerAttribute()
    rescale_attr.RescaleAttribute(
        input_zp=input_zp,
        output_zp=output_zp,
        multiplier=[multiplier],
        shift=[shift],
        scale32=use_scale32,
        double_round=is_double_round,
        per_channel=False,
        input_unsigned=False,
        output_unsigned=False,
    )

    tosa_fb.addOperator(
        TosaOp.Op().RESCALE, [input_node.name], [output_name], rescale_attr
    )
| |
| |
def build_rescale_to_int32(
    tosa_fb, input, input_zp, rescale_scale, is_scale32=True, is_double_round=False
) -> TosaSerializerTensor:
    """Rescale ``input`` into a fresh INT32 intermediate tensor.

    Adds the RESCALE operator to ``tosa_fb`` and returns the new tensor.
    """
    multiplier, shift = compute_multiplier_and_shift(rescale_scale)

    rescale_attr = ts.TosaSerializerAttribute()
    rescale_attr.RescaleAttribute(
        input_zp=input_zp,
        output_zp=0,
        multiplier=[multiplier],
        shift=[shift],
        scale32=is_scale32,
        double_round=is_double_round,
        per_channel=False,
        input_unsigned=False,
        output_unsigned=False,
    )

    rescaled_to_int32 = tosa_fb.addIntermediate(input.shape, ts.DType.INT32)
    tosa_fb.addOperator(
        TosaOp.Op().RESCALE,
        [input.name],
        [rescaled_to_int32.name],
        rescale_attr,
    )
    return rescaled_to_int32
| |
| |
def build_rescale_from_int32(
    tosa_fb,
    input_name,
    output_name,
    output_zp,
    rescale_scale,
    is_scale32=True,
    is_double_round=False,
) -> None:
    """Add a RESCALE operator bringing an INT32 tensor back to ``output_name``.

    The input is assumed to have zero point 0 (the int32 convention).
    """
    multiplier, shift = compute_multiplier_and_shift(rescale_scale)

    rescale_attr = ts.TosaSerializerAttribute()
    rescale_attr.RescaleAttribute(
        input_zp=0,
        output_zp=output_zp,
        multiplier=[multiplier],
        shift=[shift],
        scale32=is_scale32,
        double_round=is_double_round,
        per_channel=False,
        input_unsigned=False,
        output_unsigned=False,
    )

    tosa_fb.addOperator(
        TosaOp.Op().RESCALE, [input_name], [output_name], rescale_attr
    )
| |
| |
def rescale_nodes_to_int32(
    nodes: Sequence[Node], tosa_graph: ts.TosaSerializer
) -> tuple[list[TosaSerializerTensor], float]:
    """Rescales all 'nodes' to int32, adding suitable RESCALE ops to 'tosa_graph'.

    The scales are adjusted using the smallest scale of all 'nodes'.

    Returns a list of the rescaled nodes and the scale factor used,
    needed by rescale_node_back_to_int8.
    """
    tensors = [TosaArg(node) for node in nodes]

    # Reshape each tensor according to its tosa dim order.
    for tensor in tensors:
        tensor.shape = [tensor.shape[i] for i in tensor.dim_order]

    qargs = [get_quant_arg_upstream(node) for node in nodes]

    # Bring the int8 quantized inputs onto a common scale in the integer
    # domain.
    min_scale = min(qarg.scale for qarg in qargs)
    scales = [qarg.scale / min_scale for qarg in qargs]

    rescaled_nodes: list[TosaSerializerTensor] = [
        build_rescale_to_int32(tosa_graph, tensor, qarg.zp, scale)
        for tensor, qarg, scale in zip(tensors, qargs, scales)
    ]
    return rescaled_nodes, min_scale
| |
| |
def rescale_node_back_to_int8(
    node: Node,
    last_tensor: TosaSerializerTensor,
    scale: float,
    tosa_graph: ts.TosaSerializer,
):
    """Rescales the node back to int8, adding a suitable RESCALE op to 'tosa_graph'.

    Parameters:
        node: the original node that is being handled by the rescales.
        last_tensor: the tosa tensor to rescale back.
        scale: the scaling factor used to rescale to int32,
            from the function 'rescale_nodes_to_int32'.
        tosa_graph: the tosa_graph to manipulate.
    """
    output_qargs = get_quant_arg_downstream(list(node.users)[0])
    back_to_int8_scale = scale / output_qargs.scale

    # Rescale back to INT8.
    build_rescale_from_int32(
        tosa_graph,
        last_tensor.name,
        node.name,
        output_qargs.zp,
        back_to_int8_scale,
    )
| |
| |
| """ Creates a TOSA rescale op based on conv2d parameters. """ |
| |
| |
| def build_rescale_conv_output( |
| tosa_fb, |
| op, |
| output_name, |
| output_type, |
| input_scale, |
| weight_scale, |
| output_scale, |
| output_zp, |
| ): |
| # TODO add check to verify if this is a Per-channel quantization. |
| post_conv2d_scale = (input_scale * weight_scale) / output_scale |
| |
| # Since we assume the input tensor that is being rescaled is int32 date type, zero point must be 0. |
| build_rescale( |
| tosa_fb, |
| post_conv2d_scale, |
| op, |
| output_name, |
| output_type, |
| op.shape, |
| 0, |
| output_zp, |
| ) |
| return |