| # Copyright 2023-2024 Arm Limited and/or its affiliates. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| # pyre-unsafe |
| |
| from typing import List |
| |
| import executorch.backends.arm.tosa_quant_utils as tqutils |
| import executorch.backends.arm.tosa_utils as tutils |
| |
| import serializer.tosa_serializer as ts |
| import torch |
| from executorch.backends.arm.operators.node_visitor import ( |
| NodeVisitor, |
| register_node_visitor, |
| ) |
| from executorch.backends.arm.tosa_mapping import TosaArg |
| from executorch.backends.arm.tosa_specification import TosaSpecification |
| from serializer.tosa_serializer import TosaOp |
| from torch.fx import Node |
| |
| |
@register_node_visitor
class AddVisitor_080_BI(NodeVisitor):
    """Node visitor that lowers ``aten.add.Tensor`` to a TOSA ``ADD`` operator.

    Registered for the ``TOSA-0.80.0+BI`` specification. The ADD itself is
    always emitted on INT32 tensors; operands that are not already INT32 are
    rescaled to INT32 first and the result is rescaled back afterwards.
    """

    target = "aten.add.Tensor"

    tosa_specs = [
        TosaSpecification.create_from_string("TOSA-0.80.0+BI"),
    ]

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Serialize the add described by ``node`` into ``tosa_graph``.

        Args:
            node: The FX node being lowered; its two inputs are fetched with
                ``tutils.get_two_inputs``.
            tosa_graph: Serializer that receives the emitted operators.
            inputs: TOSA arguments for the two operands; used directly only
                when no rescaling is required.
            output: TOSA argument describing the result tensor.
            is_quant_node: Whether the node is treated as quantized.

        Raises:
            RuntimeError: If the node is not quantized and any input dtype is
                something other than ``torch.int8`` or ``torch.int32``.
        """
        operands = tutils.get_two_inputs(node)

        def operand_dtypes():
            # Lazy on purpose: ``meta["val"]`` is only read for as many
            # operands as the consuming any()/all() actually inspects.
            return (operand.meta["val"].dtype for operand in operands)

        if not is_quant_node:
            accepted = (torch.int8, torch.int32)
            if any(dtype not in accepted for dtype in operand_dtypes()):
                raise RuntimeError(
                    f"Unexpected non quantized {AddVisitor_080_BI.target} node."
                )

        # The add can be emitted as-is only when both operands and the node's
        # own result are already INT32.
        already_int32 = (
            all(dtype == torch.int32 for dtype in operand_dtypes())
            and node.meta["val"].dtype == torch.int32
        )

        if already_int32:
            # Plain INT32 add straight into the final output tensor.
            tosa_graph.addOperator(
                TosaOp.Op().ADD,
                [inputs[0].name, inputs[1].name],
                [output.name],
                None,
            )
            return

        # Bring both operands to 32 bit; ``scale`` is needed later to rescale
        # the sum back.
        int32_operands, scale = tqutils.rescale_nodes_to_int32(operands, tosa_graph)

        # Intermediate INT32 tensor, shaped like the output, that holds the sum.
        sum_shape = tutils.tosa_shape(output.shape, output.dim_order)
        int32_sum = tosa_graph.addIntermediate(sum_shape, ts.DType.INT32)

        # The INT32 add itself.
        tosa_graph.addOperator(
            TosaOp.Op().ADD,
            [int32_operands[0].name, int32_operands[1].name],
            [int32_sum.name],
            None,
        )

        # Scale the INT32 sum back to 8 bit.
        tqutils.rescale_node_back_to_int8(node, int32_sum, scale, tosa_graph)
| |
| |
@register_node_visitor
class AddVisitor_080_MI(AddVisitor_080_BI):
    """Node visitor for ``aten.add.Tensor`` under the ``TOSA-0.80.0+MI`` spec.

    Quantized nodes are delegated to the integer lowering inherited from
    ``AddVisitor_080_BI``; any other node is emitted as a direct ADD on the
    given operands.
    """

    # 'target' is inherited from the BI class.

    tosa_specs = [
        TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
    ]

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Serialize the add described by ``node`` into ``tosa_graph``.

        Args:
            node: The FX node being lowered.
            tosa_graph: Serializer that receives the emitted operator(s).
            inputs: TOSA arguments for the two operands.
            output: TOSA argument describing the result tensor.
            is_quant_node: When true, the inherited integer lowering is used.
        """
        if not is_quant_node:
            # FP32 add lowering: a single ADD on the operands as given.
            operand_names = [inputs[0].name, inputs[1].name]
            tosa_graph.addOperator(
                TosaOp.Op().ADD,
                operand_names,
                [output.name],
                None,
            )
            return

        # Integer handling lives in the inherited define_node.
        super().define_node(node, tosa_graph, inputs, output, is_quant_node)