| # Copyright (c) Qualcomm Innovation Center, Inc. |
| # All rights reserved |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| import copy |
| from typing import Any, Dict, Tuple |
| |
| import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper |
| |
| import numpy as np |
| import torch |
| from executorch.backends.qualcomm.utils.constants import ( |
| QCOM_AXIS, |
| QCOM_AXIS_ORDER, |
| QCOM_BITWIDTH, |
| QCOM_DTYPE, |
| QCOM_ENCODING, |
| QCOM_OFFSET, |
| QCOM_QUANT_ATTRS, |
| QCOM_QUANT_MAX, |
| QCOM_QUANT_MIN, |
| QCOM_REQUANTIZE, |
| QCOM_SCALE, |
| QCOM_SCALE_OFFSET, |
| QCOM_SCALES, |
| QCOM_ZERO_POINT, |
| QCOM_ZERO_POINTS, |
| ) |
| |
| from executorch.exir.dialects._ops import ops as exir_ops |
| |
| from .utils import ( |
| deduce_dtype, |
| get_parameter, |
| is_graph_input, |
| is_graph_output, |
| is_parameter, |
| ) |
| |
| |
# Torch storage dtype of a quantized tensor -> QNN fixed-point data type.
# Looked up by NodeVisitor.get_data_type when the tensor has a quant config.
QNN_QUANT_TYPE_MAP = {
    torch.int8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_8,
    torch.int16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_16,
    torch.int32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_32,
    # QNN has no 64-bit *fixed-point* data type, so quantized int64 maps to
    # UNDEFINED here (unquantized int64 uses QNN_DATATYPE_INT_64 in
    # QNN_TENSOR_TYPE_MAP below).
    torch.int64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UNDEFINED,
    torch.uint8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8,
    torch.uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16,
}
# Torch dtype -> QNN data type for tensors without a quant config.
# Looked up by NodeVisitor.get_data_type with the tensor's own dtype.
QNN_TENSOR_TYPE_MAP = {
    torch.bool: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
    torch.float32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
    torch.int8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_8,
    torch.int16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_16,
    torch.int32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32,
    torch.int64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_64,
    torch.uint8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_8,
    torch.uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16,
    # NOTE(review): presumably for plain Python float scalars, which have no
    # torch dtype — confirm against callers.
    float: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
}

# Values of quant_attrs[QCOM_ENCODING] that NodeVisitor.get_quant_encoding_conf
# turns into a per-axis (per-channel) QNN encoding.
PER_CHANNEL_ENCODING = {
    exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
    exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
}

# Values of quant_attrs[QCOM_ENCODING] for which
# NodeVisitor.get_quant_tensor_value reads scalar QCOM_SCALE/QCOM_ZERO_POINT.
PER_TENSOR_ENCODING = {
    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
    exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
}
| |
| |
class NodeVisitor:
    """
    Node visitor pattern for visiting nodes in an edge IR graph.

    Subclasses implement ``define_node`` for the op targets they register.
    The helpers here turn fx nodes / torch tensors into
    ``PyQnnWrapper.TensorWrapper`` objects, together with their QNN data
    type, tensor type and quantization encoding.
    """

    def __init__(
        self,
        external_ids,
        edge_program: torch.export.ExportedProgram,
        enable_tensor_dump,
    ) -> None:
        # Graph input/output node -> external index (built by
        # generate_node_to_external_map); a falsy value becomes {}.
        self.external_ids = external_ids or {}
        self.edge_program = edge_program
        self.enable_tensor_dump = enable_tensor_dump

    def get_tensor(self, input_node, op_node, idx=None):
        """
        Get the value of ``input_node`` as consumed by ``op_node``.

        Parameters are fetched from the edge program; other nodes use
        ``meta["val"]``. If ``idx`` is given, that element of the value is
        selected. Non-scalar tensors are permuted into
        ``op_node.meta[QCOM_AXIS_ORDER]`` (when present) and made contiguous.
        """

        def _get_tensor(node, index):
            if index is not None:
                assert isinstance(index, int)
                if is_parameter(node, self.edge_program):
                    return get_parameter(node, self.edge_program)[index]
                return node.meta["val"][index]

            if is_parameter(node, self.edge_program):
                return get_parameter(node, self.edge_program)
            return node.meta["val"]

        tensor = _get_tensor(input_node, idx)
        if len(tensor.shape) != 0 and QCOM_AXIS_ORDER in op_node.meta:
            tensor = tensor.permute(dims=op_node.meta[QCOM_AXIS_ORDER]).contiguous()
        return tensor

    @staticmethod
    def _is_4bit_encoding(quant_config: Dict) -> bool:
        """Return True for int8 configs whose [quant_min, quant_max] span is <= 15."""
        # QCOM_QUANT_MAX/MIN are only read for int8 configs (short-circuit).
        return (
            quant_config[QCOM_DTYPE] == torch.int8
            and quant_config[QCOM_QUANT_MAX] - quant_config[QCOM_QUANT_MIN] <= 15
        )

    def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict):
        """
        Build a QNN per-axis quantization encoding for ``node``.

        Returns ``(encoding, quant_config)``. ``quant_config`` is a deep copy of
        ``quant_attrs`` with QCOM_SCALE_OFFSET (one Qnn_ScaleOffset_t per
        channel) and QCOM_AXIS filled in. It is marked 4-bit (QCOM_BITWIDTH = 4,
        BW_AXIS_SCALE_OFFSET encoding) when it is int8 with a span of <= 15.
        """
        quant_config = copy.deepcopy(quant_attrs)

        scales = quant_attrs[QCOM_SCALES]
        zero_points = quant_attrs[QCOM_ZERO_POINTS]
        assert len(scales) == len(
            zero_points
        ), f"Per channel encoding of node {node}, has different size for scales {len(scales)} and zero_points {len(zero_points)}"

        # QNN stores offset = -zero_point; check Qnn_ScaleOffset_t in
        # QNN/include/QnnTypes.h
        scale_offset = [
            PyQnnWrapper.Qnn_ScaleOffset_t(scale, -zero_point)
            for scale, zero_point in zip(scales, zero_points)
        ]

        user_0 = list(node.users)[0]
        # fx nodes such as the graph output have string targets without
        # __name__; treat those as "not a convolution" instead of crashing.
        user_target_name = getattr(user_0.target, "__name__", "")
        # Memory layout of QNN conv weight always ends in Output. Like conv2d is HWIO
        if "convolution" in user_target_name and user_0.args[1] == node:
            quant_config[QCOM_AXIS] = 3
        else:
            quant_config[QCOM_AXIS] = quant_attrs[QCOM_AXIS]

        quant_config[QCOM_SCALE_OFFSET] = scale_offset
        # special case for 4 bits
        if self._is_4bit_encoding(quant_config):
            quant_config[QCOM_BITWIDTH] = 4
            return (
                PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET,
                quant_config,
            )
        return (
            PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET,
            quant_config,
        )

    def make_qnn_per_tensor_config(self, quant_attrs: Dict):
        """
        Build a QNN per-tensor quantization encoding.

        Returns ``(encoding, quant_config)``. ``quant_config`` is a deep copy of
        ``quant_attrs`` with QCOM_OFFSET = -zero_point. It is marked 4-bit
        (QCOM_BITWIDTH = 4, BW_SCALE_OFFSET encoding) when it is int8 with a
        span of <= 15.
        """
        quant_config = copy.deepcopy(quant_attrs)
        # check Qnn_ScaleOffset_t in QNN/include/QnnTypes.h
        quant_config[QCOM_OFFSET] = -quant_attrs[QCOM_ZERO_POINT]
        # special case for 4 bits
        if self._is_4bit_encoding(quant_config):
            quant_config[QCOM_BITWIDTH] = 4
            return (
                PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET,
                quant_config,
            )
        return (
            PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_SCALE_OFFSET,
            quant_config,
        )

    def get_quant_encoding_conf(
        self, node: torch.fx.Node, is_input_tensor: bool = False
    ) -> Tuple[Any, Dict]:
        """
        Return ``(encoding, quant_config)`` describing how ``node`` is quantized.

        Nodes without QCOM_QUANT_ATTRS get ``(ENCODING_UNDEFINED, {})``.
        When ``is_input_tensor`` is true and the node carries QCOM_REQUANTIZE,
        those attributes are used instead of QCOM_QUANT_ATTRS.
        """
        if not node.meta.get(QCOM_QUANT_ATTRS, None):
            return (
                PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED,
                {},
            )
        quant_attrs = (
            node.meta[QCOM_REQUANTIZE]
            if QCOM_REQUANTIZE in node.meta and is_input_tensor
            else node.meta[QCOM_QUANT_ATTRS]
        )
        if quant_attrs[QCOM_ENCODING] in PER_CHANNEL_ENCODING:
            return self.make_qnn_per_channel_config(node, quant_attrs)

        return self.make_qnn_per_tensor_config(quant_attrs)

    def get_quant_tensor_value(
        self, tensor: torch.Tensor, quant_attrs: Dict, quant_configs: Dict
    ) -> torch.Tensor:
        """
        Quantize a concrete tensor into its integer storage dtype.

        Computes ``clamp(round(tensor / scale + zero_point), quant_min,
        quant_max)`` and casts the result to ``quant_configs[QCOM_DTYPE]``.
        Per-tensor encodings use QCOM_SCALE / QCOM_ZERO_POINT. Every other
        encoding uses QCOM_SCALES / QCOM_ZERO_POINTS, which must already
        broadcast against ``tensor`` (plain broadcasting aligns them with the
        last dimension). For 4-bit configs, only the low nibble of each value
        is kept.
        """
        if quant_attrs[QCOM_ENCODING] in PER_TENSOR_ENCODING:
            scale = quant_attrs[QCOM_SCALE]
            zero_point = quant_attrs[QCOM_ZERO_POINT]
        else:  # per channel case
            scale = quant_attrs[QCOM_SCALES]
            zero_point = quant_attrs[QCOM_ZERO_POINTS]

        dtype = quant_configs[QCOM_DTYPE]

        tensor = tensor.div(scale).add(zero_point).round()
        # Saturate to the quantization range before the integer cast. Values
        # outside [quant_min, quant_max] (e.g. weights clipped by the chosen
        # scale) would otherwise overflow the cast instead of saturating.
        # This is a no-op for in-range values.
        quant_min = quant_configs.get(QCOM_QUANT_MIN)
        quant_max = quant_configs.get(QCOM_QUANT_MAX)
        if quant_min is not None and quant_max is not None:
            tensor = tensor.clamp(quant_min, quant_max)
        tensor = tensor.to(dtype)
        # Make the backends access data correctly: keep the low 4 bits only.
        if quant_configs.get(QCOM_BITWIDTH) == 4:
            mask = torch.full(tensor.size(), 0x0F, dtype=torch.int8)
            tensor = torch.bitwise_and(mask, tensor)
        return tensor

    def get_tensor_type(
        self,
        node: torch.fx.Node,
        tensor_type: PyQnnWrapper.Qnn_TensorType_t,
    ) -> PyQnnWrapper.Qnn_TensorType_t:
        """
        Resolve the QNN tensor type for ``node``.

        Graph inputs become APP_WRITE, graph outputs APP_READ (both must be
        present in ``external_ids``), and parameters become STATIC. When tensor
        dumping is enabled, NATIVE tensors are promoted to APP_READ. Otherwise
        ``tensor_type`` is returned unchanged.
        """
        is_input = is_graph_input(node, self.edge_program)
        is_output = is_graph_output(node)
        # handle logic for input/output tensors
        if is_input or is_output:
            assert (
                node in self.external_ids
            ), f"Node {node}, is_input: {is_input}, is_output: {is_output}, ext_ids: {self.external_ids.keys()}"
        if is_input:
            return PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_APP_WRITE
        if is_output:
            return PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_APP_READ

        if is_parameter(node, self.edge_program):
            return PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC
        # dump all tensor, set to app read, and we only dump native tensors
        if (
            self.enable_tensor_dump
            and tensor_type == PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
        ):
            return PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_APP_READ
        return tensor_type

    def get_data_type(
        self,
        tensor: torch.Tensor,
        quant_config: Dict,
    ) -> PyQnnWrapper.Qnn_DataType_t:
        """
        Return the QNN data type for ``tensor``.

        With a non-empty ``quant_config``, QCOM_DTYPE is overwritten in place
        with ``deduce_dtype(tensor, quant_config)``, and the fixed-point type
        comes from QNN_QUANT_TYPE_MAP. Otherwise the tensor's own dtype is
        looked up in QNN_TENSOR_TYPE_MAP.
        """
        if quant_config:
            quant_config[QCOM_DTYPE] = deduce_dtype(tensor, quant_config)
            return QNN_QUANT_TYPE_MAP[quant_config[QCOM_DTYPE]]

        return QNN_TENSOR_TYPE_MAP[tensor.dtype]

    def define_custom_tensor_wrapper(
        self,
        node_name: str,
        tensor_type: PyQnnWrapper.Qnn_TensorType_t,
        dtype: PyQnnWrapper.Qnn_DataType_t,
        quant_encoding: PyQnnWrapper.Qnn_QuantizationEncoding_t,
        quant_configs: dict,
        dims: torch.Size,
        tensor: torch.Tensor,
        is_fake_tensor: bool,
        nodes_to_wrappers: Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]],
        wrapper_idx: int = 0,
    ) -> PyQnnWrapper.TensorWrapper:
        """
        Create (or reuse) a data-less TensorWrapper from explicit attributes.

        A wrapper already cached under ``nodes_to_wrappers[node_name]
        [wrapper_idx]`` is returned as-is. Only fake tensors are supported:
        for a non-fake tensor, None is returned and nothing is cached.
        """
        if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
            return cached
        if is_fake_tensor:
            tensor_wrapper = PyQnnWrapper.TensorWrapper(
                node_name,
                tensor_type,
                dtype,
                quant_encoding,
                quant_configs,
                len(dims),
                dims,
                np.array([]),
                False,
            )
        else:
            # Can implement non-fake tensor when there is a need
            return None
        nodes_to_wrappers[node_name][wrapper_idx] = tensor_wrapper
        return tensor_wrapper

    def define_tensor(
        self,
        node: torch.fx.Node,
        tensor: torch.Tensor,
        tensor_type: PyQnnWrapper.Qnn_TensorType_t,
        nodes_to_wrappers: Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]],
        is_input_tensor: bool,
        node_name: str = None,
        wrapper_idx: int = 0,
    ) -> PyQnnWrapper.TensorWrapper:
        """
        Convert a torch.Tensor into a (cached) PyQnnWrapper.TensorWrapper.

        Args:
            node: EdgeIR node that produces ``tensor``.
            tensor: EdgeIR tensor. Fake tensors yield a wrapper without data.
                Real tensors are quantized first (when a quant config exists)
                and their data is handed to the wrapper.
            tensor_type: requested QNN tensor type; refined by get_tensor_type.
            nodes_to_wrappers: cache mapping node name -> {wrapper_idx:
                TensorWrapper}; the new wrapper is stored here.
            is_input_tensor: whether the tensor is an input of the op builder
                calling this function. It selects QCOM_REQUANTIZE attributes
                when the node has them.
            node_name: cache key; defaults to ``node.name``.
            wrapper_idx: index of this wrapper under the cache key.
        """
        if node_name is None:
            node_name = node.name

        if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
            return cached

        tensor_name = f"{node.name}_{wrapper_idx}"
        if is_graph_input(node, self.edge_program):
            tensor_name = "input_" + str(self.external_ids[node]) + "_" + tensor_name
        if is_graph_output(node):
            tensor_name = "output_" + tensor_name
        # QNN wants at least one dimension; scalars are described as [1].
        dims = [1] if len(tensor.size()) == 0 else tensor.size()
        tensor_type = self.get_tensor_type(node, tensor_type)
        quant_encoding, quant_configs = self.get_quant_encoding_conf(
            node, is_input_tensor
        )
        dtype = self.get_data_type(tensor, quant_configs)
        if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor):
            tensor_wrapper = PyQnnWrapper.TensorWrapper(
                tensor_name,
                tensor_type,
                dtype,
                quant_encoding,
                quant_configs,
                len(dims),
                dims,
                np.array([]),
                False,
            )
        else:
            if quant_configs:
                # NOTE(review): values are quantized with QCOM_QUANT_ATTRS even
                # when quant_configs came from QCOM_REQUANTIZE — confirm intended.
                tensor = self.get_quant_tensor_value(
                    tensor,
                    node.meta[QCOM_QUANT_ATTRS],
                    quant_configs,
                )
            tensor_wrapper = PyQnnWrapper.TensorWrapper(
                tensor_name,
                tensor_type,
                dtype,
                quant_encoding,
                quant_configs,
                len(dims),
                dims,
                tensor.detach().numpy(),
                True,
            )
        nodes_to_wrappers[node_name][wrapper_idx] = tensor_wrapper
        return tensor_wrapper

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Convert torch.fx.Node to OpWrapper; must be overridden by subclasses."""
        raise NotImplementedError("NodeVisitor must be extended!")
| |
| |
# Maps each registered visitor target to the NodeVisitor subclass handling it.
_node_visitor_dict = {}


def register_node_visitor(visitor):
    """
    Register a NodeVisitor subclass into ``_node_visitor_dict``.

    Every entry of ``visitor.target`` is mapped to ``visitor``. A later
    registration for the same target replaces the earlier one.

    Returns ``visitor`` unchanged, so the function works as a class decorator
    (``@register_node_visitor``) without rebinding the decorated class name to
    ``None``.
    """
    assert (
        isinstance(visitor, type)
        and issubclass(visitor, NodeVisitor)
        and hasattr(visitor, "target")
    ), f"Illformed NodeVisitor subclass, can't register!, got: {visitor}"
    for target in visitor.target:
        _node_visitor_dict[target] = visitor
    return visitor
| |
| |
def generate_node_to_external_map(
    edge_program: torch.export.ExportedProgram,
) -> Dict[torch.fx.Node, int]:
    """
    Number the graph inputs and outputs of ``edge_program`` consecutively.

    One pass over the graph numbers the inputs; a second pass numbers the
    outputs. Each matching node receives the current size of the mapping as
    its id.
    """
    external_ids: Dict[torch.fx.Node, int] = {}
    graph_nodes = edge_program.graph_module.graph.nodes
    # Placeholders are visited in the same order as the forward(*args)
    # signature of this graph module, so their position is the external id
    # used to pick the right arg from *args at runtime.
    passes = (
        lambda candidate: is_graph_input(candidate, edge_program),
        is_graph_output,
    )
    for selects in passes:
        for candidate in graph_nodes:
            if selects(candidate):
                external_ids[candidate] = len(external_ids)
    return external_ids
| |
| |
def get_node_visitors(
    edge_program: torch.export.ExportedProgram,
    enable_tensor_dump=False,
) -> Dict[str, NodeVisitor]:
    """
    Instantiate every registered visitor class for ``edge_program``.

    Returns a dict keyed by the same targets as ``_node_visitor_dict``. Each
    value is a new instance of the registered class, and all instances share
    one external-id map.
    """
    shared_external_ids = generate_node_to_external_map(edge_program)

    def _build(visitor_cls):
        assert callable(
            visitor_cls
        ), f"Expeting a callable class, but got {visitor_cls} of type {type(visitor_cls)}"
        return visitor_cls(shared_external_ids, edge_program, enable_tensor_dump)

    return {
        target: _build(visitor_cls)
        for target, visitor_cls in _node_visitor_dict.items()
    }