# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import ctypes
from typing import final, List

import torch
from executorch.backends.backend_details import BackendDetails, CompileSpec
from executorch.backends.qnnpack.serialization.qnnpack_graph_schema import (
    ConstTensor,
    QNNDynamicLinear,
)
from executorch.backends.qnnpack.serialization.qnnpack_graph_serialize import (
    convert_to_flatbuffer,
)
from executorch.backends.transforms import get_shape
from executorch.exir.dialects._ops import ops as exir_ops
from torch._export.exported_program import ExportedProgram

T_Mm = exir_ops.edge.aten.mm.default
T_Addmm = exir_ops.edge.aten.addmm.default
T_Linear = exir_ops.edge.aten.linear.default


def _copy_buffer(storage: torch.UntypedStorage) -> bytes:
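    # Reinterpret the tensor's untyped storage as a raw char array and copy it
    # into an owned bytes object so it can be embedded in the serialized graph.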
    array_type = ctypes.c_char * storage.nbytes()
    array = ctypes.cast(
        storage.data_ptr(),
        ctypes.POINTER(array_type),
    ).contents
    return bytes(array)


@final
class QnnpackBackend(BackendDetails):
    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> bytes:
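        # Each partition delegated to this backend is expected to contain a
        # single (de)quantized linear pattern; serialization returns on the
        # first match below.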
        for node in edge_program.graph.nodes:
            # TODO(maxren): Follow this up by removing addmm and mm nodes
            if node.op == "call_function":
                # Find the linear node
                if node.target in (T_Mm, T_Addmm, T_Linear):
                    # Pad each serialized constant buffer with 16 bytes of zeros
                    padding = b"\0" * 16
                    bias_tensor = None
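                    # addmm and mm reach the weight through a transpose
                    # (t_default) node, while linear consumes the dequantized
                    # weight directly; the bias, when present, is a graph
                    # constant fetched off the graph module.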
                    if node.target == T_Addmm:
                        op_input = node.args[1]
                        weight = node.args[2]
                        # For addmm, args[0] holds the bias constant
                        bias_tensor = getattr(
                            edge_program.graph_module, node.args[0].target
                        ).contiguous()
                        # t_default node -> dequant node
                        weight_dequant = weight.args[0]
                    elif node.target == T_Mm:
                        op_input = node.args[0]
                        weight = node.args[1]
                        # t_default node -> dequant node
                        weight_dequant = weight.args[0]
                    elif node.target == T_Linear:
                        op_input = node.args[0]
                        weight_dequant = node.args[1]
                        if len(node.args) > 2:
                            bias_tensor = getattr(
                                edge_program.graph_module, node.args[2].target
                            ).contiguous()
                    else:
                        raise RuntimeError(
                            f"Node {node.target.__name__} not supported"
                        )
                    weight_shape = get_shape(weight_dequant)
                    output_channels = weight_shape[0]
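                    # Nodes without an explicit bias get an all-zero bias of
                    # length output_channels so the serialized record is uniform.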
                    if bias_tensor is None:
                        bias_tensor = torch.zeros(output_channels)
                    # input
                    input_shape = get_shape(op_input)

                    # bias
                    op_bias = ConstTensor(
                        shape=list(bias_tensor.shape),  # should be 1d
                        buffer=_copy_buffer(bias_tensor.untyped_storage()) + padding,
                    )

                    # dequant node -> quant node
                    weight_quant = weight_dequant.args[0]
                    # quant node -> tensor_constant
                    weight_const = getattr(
                        edge_program.graph_module, weight_quant.args[0].target
                    )
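                    # Requantize the captured int8 weights into uint8 (zero
                    # points shifted by +128), the unsigned representation the
                    # QNNPACK kernels consume.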
                    if (
                        weight_quant.target.__name__
                        == "quantized_decomposed.quantize_per_channel.default"
                    ):
                        # scale and zero_point are tensors
                        weight_scale = weight_quant.args[1]
                        scale_tensor = getattr(
                            edge_program.graph_module, weight_scale.target
                        )
                        weight_zeropoint = weight_quant.args[2]
                        zp_tensor = (
                            getattr(edge_program.graph_module, weight_zeropoint.target)
                            + 128
                        )
                        axis = weight_quant.args[3]
                        # requantize weight to uint8
                        requantized_weight_tensor = weight_quant.target(
                            weight_const,
                            scale_tensor,
                            zp_tensor,
                            axis,
                            0,
                            255,
                            torch.uint8,
                        )
                    elif (
                        weight_quant.target.__name__
                        == "quantized_decomposed.quantize_per_tensor.default"
                    ):
                        scale = weight_quant.args[1]
                        zeropoint = weight_quant.args[2] + 128
                        # Broadcast the single scale/zero point across all
                        # output channels so serialization is uniformly
                        # per-channel
                        scale_tensor = torch.FloatTensor([scale] * output_channels)
                        zp_tensor = torch.IntTensor([zeropoint] * output_channels)
                        requantized_weight_tensor = weight_quant.target(
                            weight_const,
                            scale,
                            zeropoint,
                            0,
                            255,
                            torch.uint8,
                        )
                    else:
                        raise RuntimeError(
                            f"Unsupported quantization scheme: {weight_quant.target.__name__}"
                        )
                    # Prep tensors for serializing data
                    zp_tensor = zp_tensor.type(torch.uint8).contiguous()
                    scale_tensor = scale_tensor.contiguous()
                    requantized_weight_tensor = requantized_weight_tensor.contiguous()
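                    # .contiguous() above guarantees each tensor's storage
                    # matches its logical element order before the raw bytes
                    # are copied out by _copy_buffer.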
                    # Weights as Tensor
                    op_weight = ConstTensor(
                        # Right now we serialize a shape that is not in
                        # congruence with the weight layout, which is just
                        # wrong. Not changing it here since it would be a
                        # BC-breaking change; follow up so this line becomes
                        #   shape=weight_val.shape,
                        # TODO(maxren)
                        shape=[weight_shape[1], weight_shape[0]],
                        buffer=_copy_buffer(requantized_weight_tensor.untyped_storage())
                        + padding,
                    )
                    # Weight's scales as Tensor
                    weight_scale = ConstTensor(
                        shape=[output_channels],
                        buffer=_copy_buffer(scale_tensor.untyped_storage()) + padding,
                    )

                    # Weight's zero points as Tensor
                    weight_zp = ConstTensor(
                        shape=[output_channels],
                        buffer=_copy_buffer(zp_tensor.untyped_storage()) + padding,
                    )

                    dynamic_linear = QNNDynamicLinear(
                        input_shape=input_shape,
                        bias=op_bias,
                        weights=op_weight,
                        weights_zero_point=weight_zp,
                        weights_scale=weight_scale,
                    )
                    return convert_to_flatbuffer(dynamic_linear)
raise RuntimeError("QNNPACK preprocess failed to lower the partitioned graph")
return b""