blob: 20de9e0846a55f0c801cd5115da052ef04840d88 [file]
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_utils import tosa_shape
from serializer.tosa_serializer import TosaOp
@register_node_visitor
class RepeatVisitor(NodeVisitor):
    """Lower ``aten.repeat.default`` to a TOSA TILE, reshaping first if needed.

    TOSA TILE requires the input and output to have the same rank, while the
    repeat node may carry more multiples than the input has dimensions. In that
    case a RESHAPE that prepends size-1 dimensions is emitted before the TILE
    so that the ranks match.
    """

    target = "aten.repeat.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: list[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Emit the TOSA operator(s) implementing ``repeat`` into ``tosa_graph``.

        Args:
            node: The ``aten.repeat.default`` FX node being lowered.
            tosa_graph: Serializer that receives the emitted operators.
            inputs: ``inputs[0]`` is the tensor to repeat; ``inputs[1].special``
                holds the per-dimension repeat multiples.
            output: TOSA argument describing the node's result tensor.
            is_quant_node: Whether the node is lowered in the quantized domain;
                if so the intermediate reshape tensor is typed INT8.
        """
        item_name = inputs[0].name
        shape = inputs[0].shape
        rank = len(shape)
        multiples = inputs[1].special
        new_rank = len(multiples)

        assert new_rank >= rank, (
            "repeat needs at least as many multiples as input dimensions, got "
            f"{new_rank} multiples for a rank-{rank} input"
        )

        # TILE only supports rank(in) == rank(out). To add more dims, we need
        # a reshape first.
        if new_rank > rank:
            # Prepend length-1 dimensions so the shape matches the multiples.
            expanded_shape = tosa_shape(
                (1,) * (new_rank - rank) + tuple(shape), output.dim_order
            )
            # RESHAPE does not change the element type, so the intermediate
            # must carry the input's dtype; hard-coding FP32 mistyped integer
            # and bool inputs.
            # NOTE(review): relies on TosaArg.dtype holding the TOSA dtype of
            # the input tensor for non-quantized nodes — confirm.
            dtype = (
                ts.dtype_str_to_val("INT8") if is_quant_node else inputs[0].dtype
            )
            reshape_out = tosa_graph.addIntermediate(expanded_shape, dtype)
            reshape_attr = ts.TosaSerializerAttribute()
            reshape_attr.ReshapeAttribute(expanded_shape)
            tosa_graph.addOperator(
                TosaOp.Op().RESHAPE, [item_name], [reshape_out.name], reshape_attr
            )
            item_name = reshape_out.name

        attr = ts.TosaSerializerAttribute()
        attr.TileAttribute(tosa_shape(multiples, output.dim_order))
        tosa_graph.addOperator(TosaOp.Op().TILE, [item_name], [output.name], attr)