| # Copyright (c) Qualcomm Innovation Center, Inc. |
| # All rights reserved |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| from enum import IntEnum, unique |
| from typing import Callable, Optional, Sequence, Set |
| |
| import torch |
| from executorch.backends.qualcomm._passes.decompose_einsum import DecomposeEinsum |
| from executorch.backends.qualcomm._passes.decompose_silu import DecomposeSilu |
| from executorch.backends.qualcomm._passes.recompose_pixel_unshuffle import ( |
| RecomposePixelUnshuffle, |
| ) |
| from executorch.backends.qualcomm._passes.reduce_dynamic_range import ReduceDynamicRange |
| from executorch.backends.qualcomm._passes.replace_inf_buffer import ReplaceInfBuffer |
| from executorch.backends.transforms.decompose_sdpa import ( |
| DecomposeScaledDotProductAttention, |
| ) |
| |
| from torch._ops import OpOverload |
| from torch.ao.quantization.quantizer import Quantizer |
| from torch.fx import GraphModule |
| |
| from .annotators import OP_ANNOTATOR |
| |
| from .qconfig import ( |
| get_16a16w_qnn_ptq_config, |
| get_16a4w_qnn_ptq_config, |
| get_16a4w_qnn_qat_config, |
| get_16a8w_qnn_ptq_config, |
| get_8a8w_qnn_ptq_config, |
| get_8a8w_qnn_qat_config, |
| get_ptq_per_channel_quant_config, |
| get_qat_per_channel_quant_config, |
| QuantizationConfig, |
| ) |
| |
# Alias retained to bypass the Meta-internal test error.
| get_default_16bit_qnn_ptq_config = get_16a16w_qnn_ptq_config |
| |
| __all__ = [ |
| "QnnQuantizer", |
| "QuantDtype", |
| "get_16a4w_qnn_ptq_config", |
| "get_16a8w_qnn_ptq_config", |
| "get_16a16w_qnn_ptq_config", |
| "get_8a8w_qnn_ptq_config", |
| "get_8a8w_qnn_qat_config", |
| "get_16a4w_qnn_qat_config", |
| ] |
| |
| |
| @unique |
| class QuantDtype(IntEnum): |
| """ |
| bits of activation and bits of weight |
| """ |
| |
| use_16a16w = 0 |
| use_16a8w = 1 |
| use_16a4w = 2 |
| use_8a8w = 3 |
| |
| |
# Maps (QuantDtype, is_qat) to a (per-tensor quant config factory,
# per-channel QuantizationConfig) pair.
quant_config_dict = {
| # PTQ |
| (QuantDtype.use_16a16w, False): ( |
| get_16a16w_qnn_ptq_config, |
| get_ptq_per_channel_quant_config(torch.uint16, torch.int16), |
| ), |
| (QuantDtype.use_16a8w, False): ( |
| get_16a8w_qnn_ptq_config, |
| get_ptq_per_channel_quant_config(torch.uint16, torch.int8), |
| ), |
| (QuantDtype.use_16a4w, False): ( |
| get_16a4w_qnn_ptq_config, |
| get_ptq_per_channel_quant_config(torch.uint16, "int4"), |
| ), |
| (QuantDtype.use_8a8w, False): ( |
| get_8a8w_qnn_ptq_config, |
| get_ptq_per_channel_quant_config(), |
| ), |
    # QAT
| (QuantDtype.use_16a4w, True): ( |
| get_16a4w_qnn_qat_config, |
| get_qat_per_channel_quant_config(torch.uint16, "int4"), |
| ), |
| (QuantDtype.use_8a8w, True): ( |
| get_8a8w_qnn_qat_config, |
| get_qat_per_channel_quant_config(), |
| ), |
| } |
| |
| |
| class QnnQuantizer(Quantizer): |
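    """
    Quantizer for the Qualcomm QNN backend, driven by the PT2E annotation flow.

    Example (illustrative sketch; the graph-capture API differs across PyTorch
    versions, so adapt the export call to your install)::

        from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

        quantizer = QnnQuantizer()
        quantizer.set_quant_config(QuantDtype.use_8a8w)
        module = torch.export.export_for_training(model, example_inputs).module()
        prepared = prepare_pt2e(module, quantizer)
        prepared(*example_inputs)  # calibrate on representative inputs
        quantized = convert_pt2e(prepared)
    """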
    SUPPORTED_OPS: Set[OpOverload] = set(OP_ANNOTATOR.keys())
| |
| def __init__(self): |
| super().__init__() |
| self.quant_ops: Set[OpOverload] = self.SUPPORTED_OPS.copy() |
| |
| self.is_qat = False |
| self.quant_dtype = QuantDtype.use_8a8w |
| self.quant_config: QuantizationConfig = get_8a8w_qnn_ptq_config() |
| self.per_channel_quant_config = get_ptq_per_channel_quant_config() |
| self.use_per_channel_weight_quant_ops: Set[OpOverload] = set() |
| |
| self.custom_quant_annotations: Sequence[Callable] = [] |
| self.discard_nodes: Set[str] = set() |
| |
| def _annotate(self, gm: GraphModule) -> None: |
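        """Annotate each supported node with its quant config; node names in
        self.discard_nodes are skipped."""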
| for node in gm.graph.nodes: |
| if node.name in self.discard_nodes: |
| continue |
| |
| quant_config = self._get_quant_config(node.target) |
| if quant_config: |
| OP_ANNOTATOR[node.target](node, quant_config) |
| |
| def _annotate_custom_annotation(self, gm: GraphModule) -> None: |
| for annotation_func in self.custom_quant_annotations: |
| annotation_func(gm) |
| |
| def _get_quant_config(self, op: str | OpOverload) -> Optional[QuantizationConfig]: |
| """ |
| Priority: |
| 1. is one of use_per_channel_weight_quant_ops |
| 2. quant config |
| """ |
| if isinstance(op, str): |
| return |
| |
| if op in self.use_per_channel_weight_quant_ops: |
| return self.per_channel_quant_config |
| |
| if op in self.quant_ops: |
| return self.quant_config |
| |
| print(f"No quant config is implemented for op, {op}") |
| |
| def _update_per_channel_weight_quant_ops(self, ops: Set[OpOverload], enable: bool): |
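        """Toggle per-channel weight quantization for the given ops."""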
| if enable: |
| self.use_per_channel_weight_quant_ops.update(ops) |
| else: |
| self.use_per_channel_weight_quant_ops.difference_update(ops) |
| |
| def add_custom_quant_annotations( |
| self, custom_quant_annotations: Sequence[Callable] |
| ) -> None: |
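        """Replace the custom annotation callbacks applied after the built-in pass."""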
| self.custom_quant_annotations = custom_quant_annotations |
| |
| def add_discard_nodes(self, nodes: Sequence[str]) -> None: |
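        """Skip annotation for nodes whose names appear in `nodes`."""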
| self.discard_nodes = set(nodes) |
| |
| def add_discard_ops(self, ops: Sequence[OpOverload]) -> None: |
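        """Remove ops from the quantizable set so they are left un-annotated."""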
| for op in ops: |
| self.quant_ops.remove(op) |
| |
| def annotate(self, model: GraphModule) -> GraphModule: |
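        """Entry point invoked by prepare_pt2e: run the built-in annotation pass,
        then any user-supplied custom annotations."""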
| self._annotate(model) |
| self._annotate_custom_annotation(model) |
| |
| return model |
| |
| def get_supported_ops(self) -> Set[OpOverload]: |
| return self.SUPPORTED_OPS |
| |
| def set_quant_config( |
| self, quant_dtype: QuantDtype, is_qat=False, act_observer=None |
| ) -> None: |
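        """
        Select the (quant_dtype, is_qat) entry of quant_config_dict that drives
        annotation; an unsupported combination raises RuntimeError. A custom
        act_observer, if given, is forwarded to the per-tensor config factory.

        Example (illustrative)::

            quantizer.set_quant_config(QuantDtype.use_16a4w)
        """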
| self.quant_dtype = quant_dtype |
| self.is_qat = is_qat |
| if (quant_dtype, is_qat) not in quant_config_dict: |
| raise RuntimeError( |
| f"the quant config, (quant_dtype: {quant_dtype}, is_qat: {is_qat}) is not support" |
| ) |
| |
        quant_config_func, self.per_channel_quant_config = quant_config_dict[
            (quant_dtype, is_qat)
        ]
        self.quant_config = (
            quant_config_func(act_observer) if act_observer else quant_config_func()
        )
| |
| def set_per_channel_conv_quant(self, enable: bool) -> None: |
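        """Enable or disable per-channel weight quantization for conv1d/conv2d."""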
| conv_ops = {torch.ops.aten.conv1d.default, torch.ops.aten.conv2d.default} |
| self._update_per_channel_weight_quant_ops(conv_ops, enable) |
| |
| def set_per_channel_linear_quant(self, enable: bool) -> None: |
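        """Enable or disable per-channel weight quantization for linear."""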
| linear_ops = { |
| torch.ops.aten.linear.default, |
| } |
| self._update_per_channel_weight_quant_ops(linear_ops, enable) |
| |
| def transform_for_annotation(self, model: GraphModule) -> GraphModule: |
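        """Run the pre-annotation passes (reduce dynamic range, recompose
        pixel_unshuffle, decompose SDPA/silu/einsum, replace inf buffers) so
        the annotators only see ops they can handle."""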
| model = ReduceDynamicRange()(model).graph_module |
| model = RecomposePixelUnshuffle(quantization_capture=True)(model).graph_module |
| model = DecomposeScaledDotProductAttention()(model).graph_module |
| model = DecomposeSilu()(model).graph_module |
| model = DecomposeEinsum()(model).graph_module |
| model = ReplaceInfBuffer()(model).graph_module |
| return model |
| |
| def validate(self, model: GraphModule) -> None: |
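        """No-op: this quantizer performs no post-annotation validation."""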
| pass |