| # Copyright 2024 Arm Limited and/or its affiliates. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| # pyre-unsafe |
| |
| import serializer.tosa_serializer as ts |
| import torch |
| from executorch.backends.arm.operators.node_visitor import ( |
| NodeVisitor, |
| register_node_visitor, |
| ) |
| from executorch.backends.arm.tosa_mapping import TosaArg |
| from executorch.backends.arm.tosa_utils import tosa_shape |
| from serializer.tosa_serializer import TosaOp |
| |
| |
@register_node_visitor
class RepeatVisitor(NodeVisitor):
    """Lower ``aten.repeat.default`` to a TOSA ``TILE`` operator.

    TOSA ``TILE`` requires input and output to have the same rank, whereas
    ``repeat`` may be given more multiples than the input has dimensions.
    In that case a TOSA ``RESHAPE`` is emitted first to prepend size-1
    dimensions to the input so its rank matches the number of multiples.
    """

    target = "aten.repeat.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: list[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Serialize the repeat node into ``tosa_graph``.

        Args:
            node: The FX node being lowered (not read by this visitor).
            tosa_graph: Serializer that the RESHAPE/TILE operators are added to.
            inputs: ``inputs[0]`` is the tensor to repeat; ``inputs[1].special``
                holds the repeat multiples, one per output dimension.
            output: Destination tensor; its ``dim_order`` is used to permute
                the padded shape and the multiples into TOSA layout.
            is_quant_node: When true, the RESHAPE intermediate is typed INT8;
                otherwise it takes the input tensor's dtype.

        Raises:
            AssertionError: If fewer multiples than input dimensions are given.
        """
        item_name = inputs[0].name
        shape = inputs[0].shape
        rank = len(shape)
        multiples = inputs[1].special
        new_rank = len(multiples)

        assert new_rank >= rank, (
            f"repeat needs at least as many multiples as input dims, "
            f"got {new_rank} multiples for a rank-{rank} input"
        )

        # TILE only supports rank(in) == rank(out). To add more dims, we need
        # a reshape first that prepends length-1 dimensions to the input.
        if new_rank > rank:
            expanded_shape = (1,) * (new_rank - rank) + tuple(shape)
            expanded_shape = tosa_shape(expanded_shape, output.dim_order)

            # RESHAPE does not change the element type, so the intermediate
            # must use the input's dtype. Hard-coding FP32 for every
            # non-quantized node broke INT32/BOOL inputs. Quantized nodes keep
            # the INT8 override because their meta dtype is not the
            # serialized type.
            dtype = (
                ts.dtype_str_to_val("INT8") if is_quant_node else inputs[0].dtype
            )

            reshape_out = tosa_graph.addIntermediate(expanded_shape, dtype)
            reshape_attr = ts.TosaSerializerAttribute()
            reshape_attr.ReshapeAttribute(expanded_shape)
            tosa_graph.addOperator(
                TosaOp.Op().RESHAPE, [item_name], [reshape_out.name], reshape_attr
            )
            # The TILE below consumes the reshaped tensor instead of the input.
            item_name = reshape_out.name

        tile_attr = ts.TosaSerializerAttribute()
        tile_attr.TileAttribute(tosa_shape(multiples, output.dim_order))
        tosa_graph.addOperator(
            TosaOp.Op().TILE, [item_name], [output.name], tile_attr
        )