# blob: 2d3a0c2786c93ebc11b489c0256d4ac5742bc6ed [file] [log] [blame] [edit]
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# pyre-unsafe
from typing import cast, Dict
import numpy as np
import serializer.tosa_serializer as ts
import torch
import torch.fx
from executorch.backends.arm.operators.node_visitor import NodeVisitor
from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
from executorch.backends.arm.tosa_quant_utils import (
get_quant_arg_upstream,
get_quantized_node_output_dtype,
is_node_quantized,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.backends.arm.tosa_utils import (
getNodeArgs,
is_bias_node_for_quantized_conv,
tosa_shape,
)
from torch.export.exported_program import ExportedProgram
def process_call_function(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    node_visitors: Dict[str, NodeVisitor],
    tosa_spec: TosaSpecification,
):
    """Serialize a call_function node: declare its output tensor, then emit the op.

    Args:
        node: The FX node being lowered.
        tosa_graph: Serializer that receives the output tensor and the operator.
        node_visitors: Visitors keyed by the ``__name__`` of the node's target.
        tosa_spec: Only used in the error message for an unknown operator.

    Raises:
        RuntimeError: If no visitor is registered for the node's target.
    """
    operands = getNodeArgs(node)
    result = TosaArg(node)

    # Quantized nodes take their output dtype from the quantization info;
    # everything else uses the dtype already carried by the TosaArg.
    quantized = is_node_quantized(node)
    result_dtype = (
        map_dtype(get_quantized_node_output_dtype(node))
        if quantized
        else result.dtype
    )

    # The output tensor is registered before the visitor lookup, so it is
    # already part of the current basic block even if the lookup below fails.
    tosa_graph.currRegion.currBasicBlock.addTensor(
        result.name,
        tosa_shape(result.shape, result.dim_order),
        result_dtype,
    )

    # pyre-ignore[16]: Undefined attribute.
    target_name = node.target.__name__
    if target_name not in node_visitors:
        raise RuntimeError(f"Unknown operator {node.target} for TOSA : {tosa_spec}")

    node_visitors[target_name].define_node(
        node,
        tosa_graph,
        operands,
        result,
        quantized,
    )
def process_inputs(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    tosa_spec: TosaSpecification,
):
    """Serialize a user-input placeholder as a graph input tensor.

    Args:
        node: Placeholder node; ``node.meta["val"]`` must expose ``dim()`` and
            ``dim_order()``.
        tosa_graph: Serializer that receives the input tensor.
        tosa_spec: Not used by this function.

    Raises:
        RuntimeError: If the input's dim_order is not ``(0, 1, ..., dim - 1)``.
    """
    # Inputs must be in the default dim_order (contiguous memory format).
    val = node.meta["val"]
    actual_dim_order = val.dim_order()
    expected_dim_order = tuple(range(val.dim()))
    if actual_dim_order != expected_dim_order:
        raise RuntimeError(
            f"Arm backend only supports contiguous memory format for inputs. "
            f"Expected dim_order: {expected_dim_order}, but got: {actual_dim_order} for node {node.name}"
        )

    arg = TosaArg(node)
    shape = tosa_shape(arg.shape, arg.dim_order)

    if is_node_quantized(node):
        dtype = map_dtype(get_quantized_node_output_dtype(node))
    else:
        dtype = arg.dtype

    # No inline data: the tensor refers to a "<name>.npy" placeholder file.
    input_tensor = ts.TosaSerializerTensor(
        arg.name,
        shape,
        dtype,
        data=None,
        placeholderFilename=arg.name + ".npy",
    )
    tosa_graph.addInputTensor(input_tensor)
def process_quantized_bias(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    parameter_values,
):
    """
    Quantize a bias parameter to int32 and add it to the graph as a constant.

    The bias values are divided by the product of the upstream quantization
    scales of the consumer's input and weight, rounded, and cast to int32.

    Args:
        node: The bias placeholder node; its name becomes the constant's name.
        tosa_graph: Serializer that receives the INT32 constant.
        parameter_values: Bias values to quantize.
    """
    # Only the first user is consulted, and it must have exactly three input
    # nodes: input, weight, and a third one that is ignored here.
    consumer = list(node.users)[0]
    activation_node, weight_node, _ = consumer.all_input_nodes

    activation_scale = get_quant_arg_upstream(activation_node).scale
    weight_scale = get_quant_arg_upstream(weight_node).scale
    bias_scale = activation_scale * weight_scale

    scaled_bias = parameter_values / bias_scale
    quantized_bias = scaled_bias.round().astype(np.int32)

    tosa_graph.addConst(
        quantized_bias.shape,
        ts.DType.INT32,
        quantized_bias,
        name=node.name,
    )
def process_inputs_to_parameters(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    edge_program: ExportedProgram,
    tosa_spec: TosaSpecification,
):
    """Serialize a parameter placeholder (bias or non-quantized weights).

    The parameter tensor is looked up in ``edge_program.state_dict`` via the
    graph signature. A bias feeding a quantized conv is handed to
    ``process_quantized_bias``; any other parameter is permuted by the node's
    dim_order and added to the graph as a constant.

    Args:
        node: The parameter placeholder node.
        tosa_graph: Serializer that receives the constant.
        edge_program: Program providing the graph signature and state dict.
        tosa_spec: Checked for integer/float support before serializing.

    Raises:
        AssertionError: If the stored parameter is not a ``torch.Tensor`` or the
            TOSA specification does not support the required data type.
    """
    arg = TosaArg(node)
    parameter_name = edge_program.graph_signature.inputs_to_parameters[node.name]
    parameter_data = edge_program.state_dict[parameter_name]
    assert isinstance(parameter_data, torch.Tensor), "Expect Attr to be tensor"
    parameter_values = parameter_data.detach().numpy()

    if is_bias_node_for_quantized_conv(node):
        # BI bias
        assert tosa_spec.support_integer(), f"{tosa_spec} doesn't support integer"
        process_quantized_bias(node, tosa_graph, parameter_values)
        return

    # MI weights or bias
    # NOTE(review): arg.dtype is also passed to addConst below as the TOSA
    # dtype; confirm that comparing it against torch.float32 can ever be true,
    # otherwise the float-support check never runs.
    if arg.dtype == torch.float32:
        assert tosa_spec.support_float(), f"{tosa_spec} doesn't support float"

    # Reorder the data so the serialized layout follows the node's dim_order.
    parameter_values = np.transpose(parameter_values, arg.dim_order)
    tosa_graph.addConst(
        parameter_values.shape, arg.dtype, parameter_values, name=node.name
    )
def process_inputs_to_buffers(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    edge_program: ExportedProgram,
):
    """Serialize a buffer placeholder (quantized weights) as a constant.

    Args:
        node: The buffer placeholder node.
        tosa_graph: Serializer that receives the constant.
        edge_program: Program providing the graph signature and state dict.

    Raises:
        AssertionError: If the stored buffer is not a ``torch.Tensor``.
    """
    arg = TosaArg(node)
    state_key = edge_program.graph_signature.inputs_to_buffers[node.name]
    stored = edge_program.state_dict[state_key]
    assert isinstance(stored, torch.Tensor), "Expect Attr to be tensor"
    raw_values = stored.detach().numpy()

    # TODO: fragile code for temporary fix
    # Mean and var tensors are stored as buffers too, with shape (1, ). The
    # transpose is meant for weights, but it is applied to every buffer here.
    permuted = np.transpose(raw_values, arg.dim_order)
    tosa_graph.addConst(permuted.shape, arg.dtype, permuted, name=node.name)
def process_inputs_to_lifted_tensor_constants(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    edge_program: ExportedProgram,
):
    """Serialize a lifted tensor constant placeholder as a constant.

    The tensor is looked up in ``edge_program.tensor_constants`` via the graph
    signature and added to the graph without any permutation.

    Args:
        node: The placeholder node for the lifted constant.
        tosa_graph: Serializer that receives the constant.
        edge_program: Program providing the graph signature and constants.
    """
    arg = TosaArg(node)
    signature = edge_program.graph_signature
    constant_key = signature.inputs_to_lifted_tensor_constants[arg.name]
    constant = edge_program.tensor_constants[constant_key]
    values = constant.detach().numpy()
    tosa_graph.addConst(values.shape, arg.dtype, values, name=arg.name)
def process_placeholder(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    edge_program: ExportedProgram,
    tosa_spec: TosaSpecification,
):
    """Dispatch a placeholder node to the serializer matching its kind.

    The kind is determined by which collection of the program's graph
    signature contains the node name. Checks run in this order: user inputs,
    parameters, buffers, lifted tensor constants, lifted custom objects.

    Raises:
        AssertionError: If name and target differ, or the node has args.
        NotImplementedError: For lifted custom object placeholders.
        RuntimeError: If the placeholder matches none of the known kinds.
    """
    assert node.name == node.target, "Expect placeholder name and target to match"
    assert 0 == len(node.args), "Can't handle default input values"

    signature = edge_program.graph_signature
    name = node.name

    if name in signature.user_inputs:
        process_inputs(node, tosa_graph, tosa_spec)
        return
    if name in signature.inputs_to_parameters:
        process_inputs_to_parameters(node, tosa_graph, edge_program, tosa_spec)
        return
    if name in signature.inputs_to_buffers:
        process_inputs_to_buffers(node, tosa_graph, edge_program)
        return
    if name in signature.inputs_to_lifted_tensor_constants:
        process_inputs_to_lifted_tensor_constants(node, tosa_graph, edge_program)
        return
    if name in signature.inputs_to_lifted_custom_objs:
        raise NotImplementedError(
            "Placeholder is of type 'lifted custom object' which is not supported."
        )
    raise RuntimeError(f"Placeholder '{node.name}' is of unknown type.")
def process_output(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
):
    """Mark each node listed in ``node.args[0]`` as a graph output.

    Each output is resolved by name among the tensors already registered in
    the serializer's current basic block.

    Args:
        node: The output node; ``node.args[0]`` holds the nodes to expose.
        tosa_graph: Serializer whose current basic block holds the tensors.
    """
    output_nodes = cast(tuple[torch.fx.Node, ...], node.args[0])
    for output_node in output_nodes:
        current_block = tosa_graph.currRegion.currBasicBlock
        registered = current_block.tensors[output_node.name]
        tosa_graph.addOutputTensor(registered)