blob: 95b147aba518d7516903178ab15644181f09544a [file]
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
from typing import Any, Dict, Tuple
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import (
QCOM_AXIS,
QCOM_AXIS_ORDER,
QCOM_BITWIDTH,
QCOM_DTYPE,
QCOM_ENCODING,
QCOM_OFFSET,
QCOM_QUANT_ATTRS,
QCOM_QUANT_MAX,
QCOM_QUANT_MIN,
QCOM_REQUANTIZE,
QCOM_SCALE,
QCOM_SCALE_OFFSET,
QCOM_SCALES,
QCOM_ZERO_POINT,
QCOM_ZERO_POINTS,
)
from executorch.exir.dialects._ops import ops as exir_ops
from .utils import (
deduce_dtype,
get_parameter,
is_graph_input,
is_graph_output,
is_parameter,
)
# torch dtype -> QNN quantized (fixed-point) data type, used when a tensor
# carries quantization attributes (see NodeVisitor.get_data_type).
QNN_QUANT_TYPE_MAP = {
    torch.int8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_8,
    torch.int16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_16,
    torch.int32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_32,
    # Note that there is no int64 tensor data type in Qnn.
    torch.int64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UNDEFINED,
    torch.uint8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8,
    torch.uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16,
}
# torch dtype (plus builtin float) -> plain QNN tensor data type, used for
# non-quantized tensors.
QNN_TENSOR_TYPE_MAP = {
    torch.bool: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
    torch.float32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
    torch.int8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_8,
    torch.int16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_16,
    torch.int32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32,
    torch.int64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_64,
    torch.uint8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_8,
    torch.uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16,
    float: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
}
# Edge-dialect quantize/dequantize ops whose attrs describe per-channel
# encodings (lists of scales/zero points).
PER_CHANNEL_ENCODING = {
    exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
    exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
}
# Edge-dialect quantize/dequantize ops whose attrs describe a single
# per-tensor scale/zero point.
PER_TENSOR_ENCODING = {
    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
    exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
}
class NodeVisitor:
    """
    Node visitor pattern for visiting nodes in an edge IR graph.

    Shared helpers here convert EdgeIR tensors and quantization metadata into
    QNN tensor wrappers; concrete subclasses implement ``define_node`` for
    their op targets and are registered via ``register_node_visitor``.
    """

    def __init__(
        self,
        external_ids,
        edge_program: torch.export.ExportedProgram,
        enable_tensor_dump,
    ) -> None:
        # Mapping of graph input/output nodes to external ids
        # (see generate_node_to_external_map); empty dict when None is given.
        self.external_ids = external_ids or {}
        self.edge_program = edge_program
        # When truthy, NATIVE tensors are re-typed to APP_READ so they can be
        # read back for debugging (see get_tensor_type).
        self.enable_tensor_dump = enable_tensor_dump

    def get_tensor(self, input_node, op_node, idx=None):
        """
        Get tensor value/shape with axis_order.

        Args:
            input_node: node producing the tensor (parameter or activation)
            op_node: consuming op whose meta may record QCOM_AXIS_ORDER
            idx: optional int index when the node's value is a sequence

        Returns:
            The tensor, permuted to op_node's axis order when one is present.
        """

        def _get_tensor(node, index):
            # Parameters (weights/buffers) are fetched from the exported
            # program; activations come from the node's "val" meta entry.
            if index is not None:
                assert isinstance(index, int)
                if is_parameter(node, self.edge_program):
                    return get_parameter(node, self.edge_program)[index]
                return node.meta["val"][index]

            if is_parameter(node, self.edge_program):
                return get_parameter(node, self.edge_program)
            return node.meta["val"]

        tensor = _get_tensor(input_node, idx)
        # Rank-0 tensors cannot be permuted; otherwise honor the layout
        # transformation recorded on the consuming op.
        if len(tensor.shape) != 0 and QCOM_AXIS_ORDER in op_node.meta:
            tensor = tensor.permute(dims=op_node.meta[QCOM_AXIS_ORDER]).contiguous()
        return tensor

    def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict):
        """Build a per-channel (axis) QNN quantization config from quant_attrs.

        Returns:
            (encoding, quant_config): BW_AXIS_SCALE_OFFSET encoding for the
            4-bit special case, AXIS_SCALE_OFFSET otherwise. quant_config is
            a deep copy of quant_attrs extended with axis/scale-offset info.
        """
        quant_config = copy.deepcopy(quant_attrs)
        scales = quant_attrs[QCOM_SCALES]
        zero_points = quant_attrs[QCOM_ZERO_POINTS]
        assert len(scales) == len(
            zero_points
        ), f"Per channel encoding of node {node}, has different size for scales {len(scales)} and zero_points {len(zero_points)}"

        scale_offset = []
        for i in range(len(scales)):
            # check Qnn_ScaleOffset_t in QNN/include/QnnTypes.h
            # QNN stores the negated zero point as the offset.
            scale_offset.append(
                PyQnnWrapper.Qnn_ScaleOffset_t(scales[i], -zero_points[i])
            )

        # NOTE(review): assumes the node has at least one user; also
        # list(node.users)[0] below re-computes user_0.
        user_0 = list(node.users)[0]
        # Memory layout of QNN conv weight always ends in Output. Like conv2d is HWIO
        if (
            "convolution" in user_0.target.__name__
            and list(node.users)[0].args[1] == node
        ):
            quant_config[QCOM_AXIS] = 3
        else:
            quant_config[QCOM_AXIS] = quant_attrs[QCOM_AXIS]

        quant_config[QCOM_SCALE_OFFSET] = scale_offset
        # special case for 4 bits
        if (
            quant_config[QCOM_DTYPE] == torch.int8
            and quant_config[QCOM_QUANT_MAX] - quant_config[QCOM_QUANT_MIN] <= 15
        ):
            quant_config[QCOM_BITWIDTH] = 4
            return (
                PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET,
                quant_config,
            )
        return (
            PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET,
            quant_config,
        )

    def make_qnn_per_tensor_config(self, quant_attrs: Dict):
        """Build a per-tensor QNN quantization config from quant_attrs.

        Returns:
            (encoding, quant_config): BW_SCALE_OFFSET encoding for the 4-bit
            special case, SCALE_OFFSET otherwise.
        """
        quant_config = copy.deepcopy(quant_attrs)
        # check Qnn_ScaleOffset_t in QNN/include/QnnTypes.h
        # QNN stores the negated zero point as the offset.
        quant_config[QCOM_OFFSET] = -quant_attrs[QCOM_ZERO_POINT]
        # special case for 4 bits
        if (
            quant_config[QCOM_DTYPE] == torch.int8
            and quant_config[QCOM_QUANT_MAX] - quant_config[QCOM_QUANT_MIN] <= 15
        ):
            quant_config[QCOM_BITWIDTH] = 4
            return (
                PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET,
                quant_config,
            )
        return (
            PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_SCALE_OFFSET,
            quant_config,
        )

    def get_quant_encoding_conf(
        self, node: torch.fx.Node, is_input_tensor: bool = False
    ) -> Tuple[Any, Dict]:
        """Select the QNN quantization encoding and config for a node.

        Non-quantized nodes yield (UNDEFINED, {}). For input tensors carrying
        requantize metadata, the requantize attrs take precedence over the
        node's own quant attrs.
        """
        if not node.meta.get(QCOM_QUANT_ATTRS, None):
            return (
                PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED,
                {},
            )
        quant_attrs = (
            node.meta[QCOM_REQUANTIZE]
            if QCOM_REQUANTIZE in node.meta and is_input_tensor
            else node.meta[QCOM_QUANT_ATTRS]
        )
        if quant_attrs[QCOM_ENCODING] in PER_CHANNEL_ENCODING:
            return self.make_qnn_per_channel_config(node, quant_attrs)

        return self.make_qnn_per_tensor_config(quant_attrs)

    def get_quant_tensor_value(
        self, tensor: torch.Tensor, quant_attrs: Dict, quant_configs: Dict
    ) -> torch.Tensor:
        """Quantize a float tensor: round(tensor / scale + zero_point) -> dtype.

        For 4-bit configs only the low nibble of each element is kept.
        """
        if quant_attrs[QCOM_ENCODING] in PER_TENSOR_ENCODING:
            scale = quant_attrs[QCOM_SCALE]
            zero_point = quant_attrs[QCOM_ZERO_POINT]
        else:  # per channel case
            scale = quant_attrs[QCOM_SCALES]
            zero_point = quant_attrs[QCOM_ZERO_POINTS]

        dtype = quant_configs[QCOM_DTYPE]

        # NOTE(review): in the per-channel case scale/zero_point are
        # sequences; this relies on torch broadcasting over the trailing
        # dimension — confirm the channel axis matches after any permute.
        tensor = tensor.div(scale).add(zero_point).round().to(dtype)
        # Make the backends access data correctly
        if quant_configs.get(QCOM_BITWIDTH) == 4:
            # keep only the low nibble so 4-bit values are read back correctly
            mask = torch.full(tensor.size(), 0x0F, dtype=torch.int8)
            tensor = torch.bitwise_and(mask, tensor)
        return tensor

    def get_tensor_type(
        self,
        node: torch.fx.Node,
        tensor_type: PyQnnWrapper.Qnn_TensorType_t,
    ) -> PyQnnWrapper.Qnn_TensorType_t:
        """Resolve the final QNN tensor type for a node.

        Graph inputs become APP_WRITE, graph outputs APP_READ, parameters
        STATIC; otherwise the caller-proposed type is kept, except that
        tensor dumping promotes NATIVE tensors to APP_READ.
        """
        is_input = is_graph_input(node, self.edge_program)
        is_output = is_graph_output(node)
        # handle logic for input/output tensors
        if is_input or is_output:
            assert (
                node in self.external_ids
            ), f"Node {node}, is_input: {is_input}, is_output: {is_output}, ext_ids: {self.external_ids.keys()}"
            if is_input:
                return PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_APP_WRITE
            if is_output:
                return PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_APP_READ

        if is_parameter(node, self.edge_program):
            return PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC

        # dump all tensor, set to app read, and we only dump native tensors
        if (
            self.enable_tensor_dump
            and tensor_type == PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
        ):
            return PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_APP_READ
        return tensor_type

    def get_data_type(
        self,
        tensor: torch.Tensor,
        quant_config: Dict,
    ) -> PyQnnWrapper.Qnn_DataType_t:
        """Map a tensor (and optional quant config) to a QNN data type.

        Side effect: when quant_config is non-empty, its QCOM_DTYPE entry is
        overwritten with the dtype deduced from the tensor/config before the
        quantized-type lookup.
        """
        if quant_config:
            quant_config[QCOM_DTYPE] = deduce_dtype(tensor, quant_config)
            return QNN_QUANT_TYPE_MAP[quant_config[QCOM_DTYPE]]

        return QNN_TENSOR_TYPE_MAP[tensor.dtype]

    def define_custom_tensor_wrapper(
        self,
        node_name: str,
        tensor_type: PyQnnWrapper.Qnn_TensorType_t,
        dtype: PyQnnWrapper.Qnn_DataType_t,
        quant_encoding: PyQnnWrapper.Qnn_QuantizationEncoding_t,
        quant_configs: dict,
        dims: torch.Size,
        tensor: torch.Tensor,
        is_fake_tensor: bool,
        nodes_to_wrappers: Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]],
        wrapper_idx: int = 0,
    ) -> PyQnnWrapper.TensorWrapper:
        """Create (or return the cached) TensorWrapper from explicit fields.

        Only the fake-tensor path is implemented; non-fake tensors return
        None. Results are memoized in nodes_to_wrappers keyed by
        (node_name, wrapper_idx).
        """
        if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
            return cached
        if is_fake_tensor:
            tensor_wrapper = PyQnnWrapper.TensorWrapper(
                node_name,
                tensor_type,
                dtype,
                quant_encoding,
                quant_configs,
                len(dims),
                dims,
                np.array([]),  # fake tensors carry no data
                False,
            )
        else:
            # Can implement non-fake tensor when there is a need
            return None
        nodes_to_wrappers[node_name][wrapper_idx] = tensor_wrapper
        return tensor_wrapper

    def define_tensor(
        self,
        node: torch.fx.Node,
        tensor: torch.Tensor,
        tensor_type: PyQnnWrapper.Qnn_TensorType_t,
        nodes_to_wrappers: Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]],
        is_input_tensor: bool,
        node_name: str = None,
        wrapper_idx: int = 0,
    ) -> PyQnnWrapper.TensorWrapper:
        """
        Convert torch.Tensor to TensorWrapper

        Args:
            node: EdgeIR Node
            tensor: EdgeIR Tensor
            tensor_type: QNN tensor type
            nodes_to_wrappers: Set contains edge_graph values(node targets)
            is_input_tensor: Whether tensor is a fake input tensor relatively to
                the op builder that is calling this function
            node_name: cache key override; defaults to node.name
            wrapper_idx: per-node index for ops emitting several wrappers
        """
        if node_name is None:
            node_name = node.name

        if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
            return cached

        tensor_name = f"{node.name}_{wrapper_idx}"
        # Prefix graph I/O tensor names with their external id so the runtime
        # can match them to forward(*args) positions.
        if is_graph_input(node, self.edge_program):
            tensor_name = "input_" + str(self.external_ids[node]) + "_" + tensor_name
        if is_graph_output(node):
            tensor_name = "output_" + tensor_name
        # Rank-0 tensors are represented as a single-element dimension list.
        dims = [1] if len(tensor.size()) == 0 else tensor.size()
        tensor_type = self.get_tensor_type(node, tensor_type)
        quant_encoding, quant_configs = self.get_quant_encoding_conf(
            node, is_input_tensor
        )
        dtype = self.get_data_type(tensor, quant_configs)
        if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor):
            # Fake tensors register shape/type metadata only — no data copied.
            tensor_wrapper = PyQnnWrapper.TensorWrapper(
                tensor_name,
                tensor_type,
                dtype,
                quant_encoding,
                quant_configs,
                len(dims),
                dims,
                np.array([]),
                False,
            )
        else:
            if quant_configs:
                # Quantize real (parameter) data before handing it to QNN.
                tensor = self.get_quant_tensor_value(
                    tensor,
                    node.meta[QCOM_QUANT_ATTRS],
                    quant_configs,
                )
            tensor_wrapper = PyQnnWrapper.TensorWrapper(
                tensor_name,
                tensor_type,
                dtype,
                quant_encoding,
                quant_configs,
                len(dims),
                dims,
                tensor.detach().numpy(),
                True,
            )
        nodes_to_wrappers[node_name][wrapper_idx] = tensor_wrapper
        return tensor_wrapper

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Convert torch.fx.Node to OpWrapper"""
        raise NotImplementedError("NodeVisitor must be extended!")
# This will hold mapping of all node names to the visitor class.
# Populated by register_node_visitor; consumed by get_node_visitors.
_node_visitor_dict = {}
def register_node_visitor(visitor):
    """Register node visitor into _node_visitor_dict.

    One entry per op name in ``visitor.target`` is stored, later targets
    overwriting earlier registrations for the same name.

    Args:
        visitor: a NodeVisitor subclass exposing a ``target`` iterable of
            op names.

    Raises:
        TypeError: if ``visitor`` is not a NodeVisitor subclass with a
            ``target`` attribute. Raised explicitly (rather than via
            ``assert``) so the validation survives ``python -O``.
    """
    if not (
        isinstance(visitor, type)
        and issubclass(visitor, NodeVisitor)
        and hasattr(visitor, "target")
    ):
        raise TypeError(
            f"Illformed NodeVisitor subclass, can't register!, got: {visitor}"
        )
    for target in visitor.target:
        _node_visitor_dict[target] = visitor
def generate_node_to_external_map(
    edge_program: torch.export.ExportedProgram,
) -> Dict[torch.fx.Node, int]:
    """Assign external ids: all graph inputs first, then graph outputs.

    The order in which the placeholder nodes are visited matches the *args
    order of the graph module's forward(*args) signature, so the running
    size of the map is used as the external_id to extract the right arg
    at runtime.
    """
    ext_map: Dict[torch.fx.Node, int] = {}
    graph_nodes = list(edge_program.graph_module.graph.nodes)
    # Two passes over the graph: every input receives its id before any
    # output does.
    for is_external in (
        lambda n: is_graph_input(n, edge_program),
        lambda n: is_graph_output(n),
    ):
        for node in graph_nodes:
            if is_external(node):
                ext_map[node] = len(ext_map)
    return ext_map
def get_node_visitors(
    edge_program: torch.export.ExportedProgram,
    enable_tensor_dump=False,
) -> Dict[str, NodeVisitor]:
    """Create a new class instance at runtime, and put them in a dict.

    Args:
        edge_program: exported program whose graph will be visited; used to
            compute the node -> external id map shared by all visitors.
        enable_tensor_dump: forwarded to each visitor so native tensors can
            be exposed for debugging.

    Returns:
        Mapping from op target name to an initialized NodeVisitor instance.

    Raises:
        TypeError: if a registered visitor is not callable. Raised explicitly
            (not via ``assert``) so the check survives ``python -O``; the
            original message also had a typo ("Expeting"), fixed here.
    """
    node_to_external_map = generate_node_to_external_map(edge_program)
    node_visitors = {}
    for target, visitor in _node_visitor_dict.items():
        if not callable(visitor):
            raise TypeError(
                f"Expecting a callable class, but got {visitor} of type {type(visitor)}"
            )
        node_visitors[target] = visitor(
            node_to_external_map, edge_program, enable_tensor_dump
        )
    return node_visitors