| # Copyright 2024 Arm Limited and/or its affiliates. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| # pyre-unsafe |
| |
| from typing import List |
| |
| import serializer.tosa_serializer as ts |
| from executorch.backends.arm.operators.node_visitor import ( |
| NodeVisitor, |
| register_node_visitor, |
| ) |
| |
| from executorch.backends.arm.tosa_mapping import TosaArg |
| |
| from executorch.backends.arm.tosa_utils import build_reshape, tosa_shape |
| from serializer.tosa_serializer import TosaOp |
| from torch.fx import Node |
| |
| |
@register_node_visitor
class SelectVisitor(NodeVisitor):
    """Lower ``aten.select_copy.int`` to TOSA.

    ``select_copy`` picks a single ``index`` along dimension ``dim`` and
    drops that dimension, so the output has rank ``rank(input) - 1``. TOSA
    requires rank(in) == rank(out) here, so the lowering is a SLICE of
    extent 1 along ``dim`` into a same-rank intermediate, followed by a
    reshape to the output's shape.
    """

    target = "aten.select_copy.int"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Emit the SLICE + reshape pair implementing ``select_copy``.

        Args:
            node: The FX node being lowered (not read by this visitor).
            tosa_graph: Serializer the TOSA operators are appended to.
            inputs: ``[input, dim, index]``. ``dim`` and ``index`` are read
                via ``.number`` and may be negative (Python-style).
            output: Destination tensor; its ``shape`` is the reshape target
                and its ``name`` names the final tensor.
            is_quant_node: When true, the intermediate tensor is created as
                INT8 instead of ``output.dtype``.
        """
        assert len(inputs) == 3
        input_node, dim_arg, index_arg = inputs
        shape = input_node.shape
        rank = len(shape)

        # A negative dim counts back from the input's rank.
        dim = dim_arg.number % rank if dim_arg.number < 0 else dim_arg.number
        # A negative index counts back from the end of dimension ``dim``, so
        # it must wrap by that dimension's size (not by the rank). This uses
        # the already-normalized ``dim`` above.
        index = (
            index_arg.number % shape[dim]
            if index_arg.number < 0
            else index_arg.number
        )

        # aten.select_copy yields rank(input) - 1 dims, but TOSA needs
        # rank(in) == rank(out): slice into a same-rank intermediate whose
        # ``dim`` extent is 1, then reshape to the output shape.
        expanded_shape = tuple(1 if i == dim else shape[i] for i in range(rank))
        expanded_shape = tosa_shape(expanded_shape, input_node.dim_order)

        output_reshaped = tosa_graph.addIntermediate(
            expanded_shape, ts.DType.INT8 if is_quant_node else output.dtype
        )

        attr_slice = ts.TosaSerializerAttribute()

        # start/size are listed in ``input_node.dim_order``; presumably the
        # same ordering tosa_shape applied to expanded_shape above — confirm
        # against tosa_utils if dim_order handling changes.
        start_attr = [index if i == dim else 0 for i in input_node.dim_order]
        size_attr = [1 if i == dim else shape[i] for i in input_node.dim_order]

        attr_slice.SliceAttribute(start_attr, size_attr)

        tosa_graph.addOperator(
            TosaOp.Op().SLICE, [input_node.name], [output_reshaped.name], attr_slice
        )

        # Reshape back to the (lower) rank of the output.
        build_reshape(tosa_graph, output_reshaped.name, output.shape, output.name)