blob: 26d29f5179b239787d717dd8a4c427eeb7083e33 [file]
# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
from typing import cast, List
import executorch.backends.arm.tosa_quant_utils as tqutils
import executorch.backends.arm.tosa_utils as tutils
import serializer.tosa_serializer as ts
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from serializer.tosa_serializer import TosaOp
from torch.fx import Node
@register_node_visitor
class AddVisitor(NodeVisitor):
    """Lower ``aten.sum.dim_IntList`` to a chain of TOSA REDUCE_SUM operators.

    One REDUCE_SUM is emitted per dimension in the dim list, each feeding the
    next, so the reduced dims all end up with size 1 (keepdim semantics).

    NOTE(review): the class name says "Add" although the registered target is
    ``aten.sum.dim_IntList``; the name is left unchanged in case it is
    referenced elsewhere.
    """

    target = "aten.sum.dim_IntList"

    def __init__(self, *args):
        super().__init__(*args)

    @staticmethod
    def _emit_reduce_sums(
        tosa_graph: ts.TosaSerializer,
        source,
        input_shape: List[int],
        dim_order,
        dim_list: List[int],
        dtype,
        final_output=None,
    ):
        """Emit one REDUCE_SUM per entry of ``dim_list``, chaining results.

        Args:
            tosa_graph: Serializer the operators and intermediates are added to.
            source: Tensor-like object (anything with ``.name``) that the first
                REDUCE_SUM reads from.
            input_shape: Logical shape of ``source``; it is copied, not mutated.
            dim_order: Dim order used to map each logical dim to a TOSA axis
                (``dim_order.index(dim)``) and to lay out intermediate shapes.
            dim_list: Logical dims to reduce, already normalized to be
                non-negative.
            dtype: Serializer dtype used for the intermediate tensors.
            final_output: If given, the last REDUCE_SUM writes to
                ``final_output.name`` instead of a fresh intermediate.

        Returns:
            The object holding the final result: ``final_output``, the last
            intermediate created, or ``source`` itself if ``dim_list`` is empty.
        """
        # Copy so the caller's shape list is not mutated as dims collapse.
        reduced_shape = list(input_shape)
        current = source
        last_index = len(dim_list) - 1
        for index, dim in enumerate(dim_list):
            # When reduced, the size of the dim becomes 1.
            reduced_shape[dim] = 1

            attr = ts.TosaSerializerAttribute()
            attr.AxisAttribute(dim_order.index(dim))

            # Decide "last step" by position, not by dim value, so a repeated
            # dim value can never route an intermediate step into the output.
            if index == last_index and final_output is not None:
                result = final_output
            else:
                result = tosa_graph.addIntermediate(
                    tutils.tosa_shape(reduced_shape, dim_order),
                    dtype=dtype,
                )

            tosa_graph.addOperator(
                TosaOp.Op().REDUCE_SUM, [current.name], [result.name], attr
            )
            current = result
        return current

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Serialize the sum node into ``tosa_graph``.

        Args:
            node: The FX node being lowered.
            tosa_graph: Serializer the TOSA operators are appended to.
            inputs: ``[input, dims]`` or ``[input, dims, keepdim]``. The dims
                may be negative and are normalized modulo the input rank.
            output: Tensor written by the final REDUCE_SUM on the
                non-quantized path.
            is_quant_node: If true, the input is rescaled to INT32 with
                ``rescale_nodes_to_int32``, summed in INT32, and rescaled back
                with ``rescale_node_back_to_int8``.

        Raises:
            AssertionError: if keepdim is false or absent (unless Python runs
                with ``-O``). That case is expected to be rewritten earlier by
                InsertSqueezeAfterSumPass.
        """
        input_node = inputs[0]
        input_shape = list(input_node.shape)
        rank = len(input_node.shape)
        # Normalize negative dims (e.g. -1 -> rank - 1).
        dim_list = [dim % rank for dim in cast(list[int], inputs[1].special)]
        keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False)
        assert keep_dim, "This case should be handled by InsertSqueezeAfterSumPass"

        if is_quant_node:
            # Rescale input to 32 bit so the accumulation happens in INT32.
            rescaled_inputs, scale = tqutils.rescale_nodes_to_int32(
                [node.all_input_nodes[0]], tosa_graph
            )
            summed = self._emit_reduce_sums(
                tosa_graph,
                rescaled_inputs[0],
                input_shape,
                input_node.dim_order,
                dim_list,
                ts.DType.INT32,
            )
            tqutils.rescale_node_back_to_int8(node, summed, scale, tosa_graph)
        else:
            # NOTE(review): with an empty dim_list no REDUCE_SUM is emitted
            # here, so `output` is never written; confirm callers never pass
            # an empty dim list.
            self._emit_reduce_sums(
                tosa_graph,
                input_node,
                input_shape,
                input_node.dim_order,
                dim_list,
                ts.DType.FP32,
                final_output=output,
            )