| # Copyright 2023-2024 Arm Limited and/or its affiliates. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| # pyre-unsafe |
| from typing import List |
| |
| import serializer.tosa_serializer as ts |
| import torch |
| from executorch.backends.arm.operators.node_visitor import ( |
| NodeVisitor, |
| register_node_visitor, |
| ) |
| from executorch.backends.arm.tosa_mapping import TosaArg |
| from executorch.backends.arm.tosa_quant_utils import ( |
| build_rescale_conv_output, |
| get_quant_arg_downstream, |
| get_quant_arg_upstream, |
| ) |
| from executorch.backends.arm.tosa_utils import build_reshape, tosa_shape |
| |
| from serializer.tosa_serializer import TosaOp |
| |
| |
@register_node_visitor
class Conv2dVisitor(NodeVisitor):
    """Lower ``aten.convolution.default`` to TOSA CONV2D / DEPTHWISE_CONV2D.

    For quantized nodes the convolution writes into an INT32 intermediate
    tensor, which is then rescaled into the node's INT8 output.
    """

    target = "aten.convolution.default"

    def __init__(self, *args):
        super().__init__(*args)

    # torch.nn.Conv2d does not require the result of
    # `(input + 2 * pad - dilation * (weight - 1) - 1) / stride`
    # to be an integer, but TOSA currently strictly requires this property.
    # This function adjusts the pad value to meet the requirement.
    def adjust_pad_if_needed(self, input, weight, stride, pad, dilation):
        """Shrink a trailing pad so the convolution output size divides exactly.

        Args:
            input: Input size along one spatial axis.
            weight: Kernel size along the same axis.
            stride: Stride along that axis.
            pad: Pad value along that axis (the formula assumes the same
                value was applied on both edges).
            dilation: Dilation along that axis.

        Returns:
            ``pad`` unchanged if the output-size division is already exact,
            otherwise ``pad`` minus the remainder of that division.

        Raises:
            RuntimeError: If the remainder is larger than ``pad``, i.e. the
                shape cannot be fixed by removing padding alone.
        """
        mod_remainder = (input + 2 * pad - dilation * (weight - 1) - 1) % stride

        # No need to adjust
        if mod_remainder == 0:
            return pad

        if mod_remainder > pad:
            raise RuntimeError(
                "This case should be handled by the SizeAdjustConv2d pass, is it enabled?"
            )
        return pad - mod_remainder

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Emit the TOSA constants, intermediates and operators for ``node``.

        Args:
            node: The ``aten.convolution.default`` FX node being lowered.
            tosa_graph: Serializer that receives the emitted tensors/operators.
            inputs: Lowered arguments of ``node``, in aten.convolution order.
            output: Lowered output tensor of ``node``.
            is_quant_node: Whether ``node`` belongs to a quantized region.

        Raises:
            RuntimeError: For grouped convolutions that are not depthwise, or
                when the padding cannot be adjusted to satisfy TOSA's strict
                output-shape computation.
        """
        # aten.convolution argument order; the two ignored entries are
        # `transposed` and `output_padding`.
        # NOTE(review): `transposed` is never inspected here; presumably
        # transposed convolutions are not routed to this visitor - confirm.
        input_arg, weight, bias, stride, pad, dilation, _, _, group = inputs

        # Given input.shape is (N, Ci, H, W), and weight.shape is (Co, Ci/G, H, W)
        in_channels = input_arg.shape[1]
        out_channels = weight.shape[0]
        is_depthwise = (
            in_channels == group.number and out_channels % in_channels == 0
        )

        # A non-depthwise grouped convolution used to fall through to a plain
        # CONV2D whose weight has Ci/G channels while the input has Ci, which
        # is not a valid lowering. Reject it before touching tosa_graph.
        if group.number > 1 and not is_depthwise:
            raise RuntimeError(
                "Grouped convolution is only supported in the depthwise case; "
                f"got groups={group.number}, in_channels={in_channels}, "
                f"out_channels={out_channels}"
            )

        # Currently only int8 is supported in quantized types.
        actual_out_type = ts.DType.INT8 if is_quant_node else output.dtype

        # Each spatial pad value is duplicated, giving the
        # [top, bottom, left, right] layout of the TOSA pad attribute.
        pad_attr = [val for val in pad.special for _ in (0, 1)]
        stride_attr = stride.special
        dilation_attr = dilation.special

        # Only the trailing (bottom / right) pads are shrunk so the TOSA
        # output-shape computation divides exactly.
        pad_attr[1] = self.adjust_pad_if_needed(
            input_arg.shape[2],
            weight.shape[2],
            stride_attr[0],
            pad_attr[1],
            dilation_attr[0],
        )
        pad_attr[3] = self.adjust_pad_if_needed(
            input_arg.shape[3],
            weight.shape[3],
            stride_attr[1],
            pad_attr[3],
            dilation_attr[1],
        )

        input_zp = (
            get_quant_arg_upstream(node.all_input_nodes[0]).zp if is_quant_node else 0
        )

        attr = ts.TosaSerializerAttribute()
        attr.ConvAttribute(
            pad=pad_attr,
            stride=stride_attr,
            dilation=dilation_attr,
            input_zp=input_zp,
            weight_zp=0,
            local_bound=False,
        )

        # Non-bias case: the operator below always takes a bias operand, so
        # materialize an all-zero bias constant.
        if len(node.all_input_nodes) == 2:
            # Keep the historical "bias<suffix-after-'default'>" name; fall back
            # to the full node name when "default" is absent (this used to
            # raise IndexError).
            _, found, suffix = node.name.partition("default")
            bias_name = "bias" + suffix if found else f"bias_{node.name}"
            bias = tosa_graph.addConst(
                [out_channels],
                ts.DType.INT32 if is_quant_node else output.dtype,
                [0] * out_channels,
                name=bias_name,
            )

        # The output type is int32 when input type is int8; it is rescaled
        # into `output` further down.
        conv2d_res = None
        conv2d_output_name = output.name
        if is_quant_node:
            conv2d_res = tosa_graph.addIntermediate(
                tosa_shape(output.shape, output.dim_order), ts.DType.INT32
            )
            conv2d_output_name = conv2d_res.name

        if is_depthwise:
            # Reshape torch shape format of weight tensor to tosa required
            # format (kH, kW, Ci, M).
            # https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d
            m_length = out_channels // in_channels
            weight_post_shape = (
                weight.shape[2],
                weight.shape[3],
                in_channels,
                m_length,
            )

            weight_reshaped = tosa_graph.addIntermediate(
                weight_post_shape,
                ts.DType.INT8 if is_quant_node else weight.dtype,
            )
            build_reshape(
                tosa_graph, weight.name, weight_post_shape, weight_reshaped.name
            )
            tosa_op = TosaOp.Op().DEPTHWISE_CONV2D
            weight_name = weight_reshaped.name
        else:
            # Regular (group == 1) convolution.
            tosa_op = TosaOp.Op().CONV2D
            weight_name = weight.name

        tosa_graph.addOperator(
            tosa_op,
            [
                input_arg.name,
                weight_name,
                bias.name,
            ],
            [conv2d_output_name],
            attr,
        )

        # For quantized convolution, rescale the output value back to the same
        # integer value domain of the next op. Otherwise the conv wrote
        # directly into `output`.
        if is_quant_node:
            # Get scale_factor from input, weight, and output.
            input_scale = get_quant_arg_upstream(node.all_input_nodes[0]).scale
            weight_scale = get_quant_arg_upstream(node.all_input_nodes[1]).scale
            output_qargs = get_quant_arg_downstream(next(iter(node.users)))

            build_rescale_conv_output(
                tosa_graph,
                conv2d_res,
                output.name,
                actual_out_type,
                input_scale,
                weight_scale,
                output_qargs.scale,
                output_qargs.zp,
            )