| # Copyright 2024 Arm Limited and/or its affiliates. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| # pyre-unsafe |
| |
| from typing import List |
| |
| import serializer.tosa_serializer as ts |
| from executorch.backends.arm.operators.node_visitor import ( |
| NodeVisitor, |
| register_node_visitor, |
| ) |
| from executorch.backends.arm.tosa_mapping import TosaArg |
| from serializer.tosa_serializer import TosaOp |
| from torch.fx import Node |
| |
| |
@register_node_visitor
class SliceVisitor(NodeVisitor):
    """Lower ``aten.slice_copy.Tensor`` to a single TOSA ``SLICE`` operator."""

    target = "aten.slice_copy.Tensor"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Append a TOSA SLICE operator equivalent to the aten slice.

        Args:
            node: The FX node being lowered (not read by this visitor).
            tosa_graph: Serializer that receives the SLICE operator.
            inputs: Exactly four arguments: the input tensor, the dimension
                to slice, the start index and the end index.
            output: Destination tensor of the SLICE operator.
            is_quant_node: Not read by this visitor.

        Raises:
            AssertionError: If ``inputs`` does not hold four arguments or the
                normalised slice is empty.
            IndexError: If the slicing dimension is outside the input's rank.
        """
        # aten.slice_copy slices along one dimension at a time.
        # The arguments are dimension of slicing, start index and end index.
        assert len(inputs) == 4, (
            f"{self.target} expects 4 arguments (input, dim, start, end), "
            f"got {len(inputs)}"
        )
        input_node, dim_arg, start_arg, end_arg = inputs

        # Translate and check parameters in PyTorch dim order.
        shape = input_node.shape
        rank = len(shape)

        # A negative dim must be made non-negative: it is compared against
        # the (non-negative) entries of dim_order below, and would otherwise
        # never match, silently producing a full copy instead of a slice.
        dim = dim_arg.number
        if dim < 0:
            dim += rank
        if not 0 <= dim < rank:
            raise IndexError(
                f"Slice dimension {dim_arg.number} is out of range for a "
                f"tensor of rank {rank}"
            )
        extent = shape[dim]

        # Negative indices count from the end of the dimension; afterwards
        # clamp to the valid range, as Python/PyTorch slicing does.
        start_index = start_arg.number
        if start_index < 0:
            start_index += extent
        start_index = max(start_index, 0)

        end_index = end_arg.number
        if end_index < 0:
            end_index += extent
        end_index = min(max(end_index, 0), extent)

        size = end_index - start_index
        assert 0 < size <= extent, (
            f"Slice [{start_arg.number}:{end_arg.number}] along dim {dim} "
            f"of extent {extent} yields invalid size {size}"
        )

        # Convert aten args to TOSA's start and size attributes, in TOSA dim order.
        dim_order = input_node.dim_order
        start_attr = [start_index if i == dim else 0 for i in dim_order]
        size_attr = [size if i == dim else shape[i] for i in dim_order]

        attr = ts.TosaSerializerAttribute()
        attr.SliceAttribute(start_attr, size_attr)

        tosa_graph.addOperator(
            TosaOp.Op().SLICE, [input_node.name], [output.name], attr
        )