| # Copyright 2024 Arm Limited and/or its affiliates. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| # |
| |
| # pyre-unsafe |
| from typing import cast, Dict |
| |
| import numpy as np |
| import serializer.tosa_serializer as ts |
| import torch |
| import torch.fx |
| from executorch.backends.arm.operators.node_visitor import NodeVisitor |
| from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg |
| from executorch.backends.arm.tosa_quant_utils import ( |
| get_quant_arg_upstream, |
| get_quantized_node_output_dtype, |
| is_node_quantized, |
| ) |
| from executorch.backends.arm.tosa_specification import TosaSpecification |
| from executorch.backends.arm.tosa_utils import ( |
| getNodeArgs, |
| is_bias_node_for_quantized_conv, |
| tosa_shape, |
| ) |
| from torch.export.exported_program import ExportedProgram |
| |
| |
def process_call_function(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    node_visitors: Dict[str, NodeVisitor],
    tosa_spec: TosaSpecification,
):
    """Lower one call_function node into the TOSA graph.

    Converts the node's arguments and its own output to TosaArgs, registers
    the output tensor on the current basic block, then dispatches to the
    NodeVisitor registered for the node's target.

    Raises:
        RuntimeError: if no visitor is registered for the node's target.
    """
    # Unpack arguments and convert.
    inputs = getNodeArgs(node)

    # Convert output (this node itself).
    output = TosaArg(node)

    is_quant_node = is_node_quantized(node)
    # Quantized nodes carry their serialized dtype in quantization metadata.
    output_dtype = (
        map_dtype(get_quantized_node_output_dtype(node))
        if is_quant_node
        else output.dtype
    )
    tosa_graph.currRegion.currBasicBlock.addTensor(
        output.name,
        tosa_shape(output.shape, output.dim_order),
        output_dtype,
    )

    # Dispatch to the visitor that knows how to define this operator.
    # pyre-ignore[16]: Undefined attribute.
    visitor = node_visitors.get(node.target.__name__)
    if visitor is None:
        raise RuntimeError(f"Unknown operator {node.target} for TOSA : {tosa_spec}")
    visitor.define_node(
        node,
        tosa_graph,
        inputs,
        output,
        is_quant_node,
    )
| |
| |
def process_inputs(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    tosa_spec: TosaSpecification,
):
    """Serialize a user-input placeholder as a TOSA input tensor.

    Raises:
        RuntimeError: if the input tensor is not in contiguous memory format.
    """
    # Inputs must be in default dim_order (contiguous memory format).
    meta = node.meta["val"]
    contiguous_order = tuple(range(meta.dim()))
    if meta.dim_order() != contiguous_order:
        raise RuntimeError(
            f"Arm backend only supports contiguous memory format for inputs. "
            f"Expected dim_order: {tuple(range(meta.dim()))}, but got: {meta.dim_order()} for node {node.name}"
        )
    arg = TosaArg(node)
    # Quantized inputs are serialized with their quantized dtype.
    if is_node_quantized(node):
        dtype = map_dtype(get_quantized_node_output_dtype(node))
    else:
        dtype = arg.dtype
    tensor = ts.TosaSerializerTensor(
        arg.name,
        tosa_shape(arg.shape, arg.dim_order),
        dtype,
        data=None,
        placeholderFilename=arg.name + ".npy",
    )
    tosa_graph.addInputTensor(tensor)
| |
| |
def process_quantized_bias(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    parameter_values,
):
    """
    Serialize bias node that needs to be quantized.

    The bias is requantized to INT32 using the combined scale of the
    consumer convolution's input and weight quantization parameters.
    """
    # Assumes the bias has exactly one consumer whose inputs are
    # (input, weight, bias) — the quantized-conv invariant checked upstream.
    consumer_node = list(node.users)[0]
    input_node, weight_node, _ = consumer_node.all_input_nodes

    combined_scale = (
        get_quant_arg_upstream(input_node).scale
        * get_quant_arg_upstream(weight_node).scale
    )
    quantized_bias = (parameter_values / combined_scale).round().astype(np.int32)

    tosa_graph.addConst(
        quantized_bias.shape,
        ts.DType.INT32,
        quantized_bias,
        name=node.name,
    )
| |
| |
def process_inputs_to_parameters(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    edge_program: ExportedProgram,
    tosa_spec: TosaSpecification,
):
    """Serialize bias and non-quantized weights.

    Looks up the parameter tensor backing this placeholder in the exported
    program's state dict and serializes it either as a quantized conv bias
    (BI flow) or as a float constant (MI flow).

    Raises:
        AssertionError: if the parameter is not a tensor, or the active TOSA
            spec does not support the required (integer/float) profile.
    """
    inputs = [TosaArg(node)]
    parameter_name = edge_program.graph_signature.inputs_to_parameters[node.name]
    parameter_data = edge_program.state_dict[parameter_name]

    assert isinstance(parameter_data, torch.Tensor), "Expect Attr to be tensor"
    parameter_values = parameter_data.detach().numpy()

    if is_bias_node_for_quantized_conv(node):
        # BI bias
        # Fixed typo in assertion message ("doesnt't" -> "doesn't"),
        # matching the float message below.
        assert tosa_spec.support_integer(), f"{tosa_spec} doesn't support integer"
        process_quantized_bias(node, tosa_graph, parameter_values)
    else:
        # MI weights or bias
        if inputs[0].dtype == torch.float32:
            assert tosa_spec.support_float(), f"{tosa_spec} doesn't support float"

        # Permute values into the node's dim_order before serializing.
        parameter_values = np.transpose(parameter_values, inputs[0].dim_order)

        tosa_graph.addConst(
            parameter_values.shape, inputs[0].dtype, parameter_values, name=node.name
        )
| |
| |
def process_inputs_to_buffers(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    edge_program: ExportedProgram,
):
    """Serialize quantized weights"""
    arg = TosaArg(node)
    buffer_name = edge_program.graph_signature.inputs_to_buffers[node.name]
    buffer_data = edge_program.state_dict[buffer_name]

    assert isinstance(buffer_data, torch.Tensor), "Expect Attr to be tensor"

    # TODO: fragile code for temporary fix
    # the mean and var tensors are also stored here but they have shape (1, )
    # we only transpose weights here
    buffer_values = np.transpose(buffer_data.detach().numpy(), arg.dim_order)

    tosa_graph.addConst(
        buffer_values.shape, arg.dtype, buffer_values, name=node.name
    )
| |
| |
def process_inputs_to_lifted_tensor_constants(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    edge_program: ExportedProgram,
):
    """Serialize a lifted tensor constant as a TOSA const tensor."""
    arg = TosaArg(node)
    signature = edge_program.graph_signature
    tensor_name = signature.inputs_to_lifted_tensor_constants[arg.name]
    tensor_data = edge_program.tensor_constants[tensor_name].detach().numpy()

    tosa_graph.addConst(tensor_data.shape, arg.dtype, tensor_data, name=arg.name)
| |
| |
def process_placeholder(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    edge_program: ExportedProgram,
    tosa_spec: TosaSpecification,
):
    """Dispatch a placeholder node to the serializer for its kind.

    A placeholder can be a user input, a parameter, a buffer, or a lifted
    tensor constant; each category is serialized by its own helper.

    Raises:
        NotImplementedError: for lifted custom objects.
        RuntimeError: if the placeholder matches no known category.
    """
    assert node.name == node.target, "Expect placeholder name and target to match"
    assert 0 == len(node.args), "Can't handle default input values"

    signature = edge_program.graph_signature
    # Branch order mirrors the graph-signature categories; first match wins.
    if node.name in signature.user_inputs:
        process_inputs(node, tosa_graph, tosa_spec)
    elif node.name in signature.inputs_to_parameters:
        process_inputs_to_parameters(node, tosa_graph, edge_program, tosa_spec)
    elif node.name in signature.inputs_to_buffers:
        process_inputs_to_buffers(node, tosa_graph, edge_program)
    elif node.name in signature.inputs_to_lifted_tensor_constants:
        process_inputs_to_lifted_tensor_constants(node, tosa_graph, edge_program)
    elif node.name in signature.inputs_to_lifted_custom_objs:
        raise NotImplementedError(
            "Placeholder is of type 'lifted custom object' which is not supported."
        )
    else:
        raise RuntimeError(f"Placeholder '{node.name}' is of unknown type.")
| |
| |
def process_output(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
):
    """Register every tensor feeding the output node as a graph output."""
    block = tosa_graph.currRegion.currBasicBlock
    # args[0] holds the tuple of nodes produced as graph outputs.
    for out_node in cast(tuple[torch.fx.Node, ...], node.args[0]):
        tosa_graph.addOutputTensor(block.tensors[out_node.name])