# blob: da7b0174c029d9e8a9cde3fca1fde532a5ede9f4 [file] [log] [blame] [edit]
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from enum import IntEnum, unique
from typing import Callable, Optional, Sequence, Set
import torch
from executorch.backends.qualcomm._passes.decompose_einsum import DecomposeEinsum
from executorch.backends.qualcomm._passes.decompose_silu import DecomposeSilu
from executorch.backends.qualcomm._passes.recompose_pixel_unshuffle import (
RecomposePixelUnshuffle,
)
from executorch.backends.qualcomm._passes.reduce_dynamic_range import ReduceDynamicRange
from executorch.backends.qualcomm._passes.replace_inf_buffer import ReplaceInfBuffer
from executorch.backends.transforms.decompose_sdpa import (
DecomposeScaledDotProductAttention,
)
from torch._ops import OpOverload
from torch.ao.quantization.quantizer import Quantizer
from torch.fx import GraphModule
from .annotators import OP_ANNOTATOR
from .qconfig import (
get_16a16w_qnn_ptq_config,
get_16a4w_qnn_ptq_config,
get_16a4w_qnn_qat_config,
get_16a8w_qnn_ptq_config,
get_8a8w_qnn_ptq_config,
get_8a8w_qnn_qat_config,
get_ptq_per_channel_quant_config,
get_qat_per_channel_quant_config,
QuantizationConfig,
)
# To bypass the meta internal test error
get_default_16bit_qnn_ptq_config = get_16a16w_qnn_ptq_config

# Public API of this module. Includes the backward-compat alias above so
# `from ... import *` exposes it alongside the canonical names.
__all__ = [
    "QnnQuantizer",
    "QuantDtype",
    "get_16a4w_qnn_ptq_config",
    "get_16a8w_qnn_ptq_config",
    "get_16a16w_qnn_ptq_config",
    "get_default_16bit_qnn_ptq_config",
    "get_8a8w_qnn_ptq_config",
    "get_8a8w_qnn_qat_config",
    "get_16a4w_qnn_qat_config",
]
@unique
class QuantDtype(IntEnum):
    """
    Activation/weight bit-width scheme, named "<act bits>a<weight bits>w"
    (e.g. use_16a4w means 16-bit activations with 4-bit weights).
    """
    use_16a16w = 0  # 16-bit activations, 16-bit weights
    use_16a8w = 1  # 16-bit activations, 8-bit weights
    use_16a4w = 2  # 16-bit activations, 4-bit weights
    use_8a8w = 3  # 8-bit activations, 8-bit weights
# Lookup table keyed by (QuantDtype, is_qat). Each value is a pair of
# (tensor-wise quant-config factory, per-channel quant config); consumed
# by QnnQuantizer.set_quant_config.
quant_config_dict = {}

# Post-training quantization (PTQ) schemes.
quant_config_dict[(QuantDtype.use_16a16w, False)] = (
    get_16a16w_qnn_ptq_config,
    get_ptq_per_channel_quant_config(torch.uint16, torch.int16),
)
quant_config_dict[(QuantDtype.use_16a8w, False)] = (
    get_16a8w_qnn_ptq_config,
    get_ptq_per_channel_quant_config(torch.uint16, torch.int8),
)
quant_config_dict[(QuantDtype.use_16a4w, False)] = (
    get_16a4w_qnn_ptq_config,
    get_ptq_per_channel_quant_config(torch.uint16, "int4"),
)
quant_config_dict[(QuantDtype.use_8a8w, False)] = (
    get_8a8w_qnn_ptq_config,
    get_ptq_per_channel_quant_config(),
)

# Quantization-aware training (QAT) schemes.
quant_config_dict[(QuantDtype.use_16a4w, True)] = (
    get_16a4w_qnn_qat_config,
    get_qat_per_channel_quant_config(torch.uint16, "int4"),
)
quant_config_dict[(QuantDtype.use_8a8w, True)] = (
    get_8a8w_qnn_qat_config,
    get_qat_per_channel_quant_config(),
)
class QnnQuantizer(Quantizer):
    """
    PT2E quantizer for the Qualcomm QNN backend.

    Walks an FX graph and attaches quantization annotations to each node
    via the per-op annotators registered in OP_ANNOTATOR. Defaults to
    8-bit activation / 8-bit weight PTQ; use set_quant_config to switch
    schemes and set_per_channel_*_quant to enable per-channel weights.
    """

    SUPPORTED_OPS: Set = set(OP_ANNOTATOR.keys())

    def __init__(self):
        super().__init__()
        # Ops that will be annotated; starts as every supported op and is
        # reduced by add_discard_ops.
        self.quant_ops: Set[OpOverload] = self.SUPPORTED_OPS.copy()

        self.is_qat = False
        self.quant_dtype = QuantDtype.use_8a8w
        self.quant_config: QuantizationConfig = get_8a8w_qnn_ptq_config()
        self.per_channel_quant_config = get_ptq_per_channel_quant_config()
        # Ops whose weights use per-channel quantization instead of the
        # tensor-wise config above.
        self.use_per_channel_weight_quant_ops: Set[OpOverload] = set()

        # User-supplied annotation callbacks, run after the built-in pass.
        self.custom_quant_annotations: Sequence[Callable] = []
        # Node names excluded from annotation.
        self.discard_nodes: Set[str] = set()

    def _annotate(self, gm: GraphModule) -> None:
        """Annotate every node that is not discarded and has a quant config."""
        for node in gm.graph.nodes:
            if node.name in self.discard_nodes:
                continue

            quant_config = self._get_quant_config(node.target)
            if quant_config:
                OP_ANNOTATOR[node.target](node, quant_config)

    def _annotate_custom_annotation(self, gm: GraphModule) -> None:
        """Run the user-registered annotation callbacks over the graph."""
        for annotation_func in self.custom_quant_annotations:
            annotation_func(gm)

    def _get_quant_config(self, op: str | OpOverload) -> Optional[QuantizationConfig]:
        """
        Return the config to use for `op`, or None when it is not quantizable.

        Priority:
        1. per-channel config, if op is one of use_per_channel_weight_quant_ops
        2. the current tensor-wise quant config, if op is a known quant op
        """
        if isinstance(op, str):
            # String targets (e.g. graph-level ops) are never quantized.
            return None

        if op in self.use_per_channel_weight_quant_ops:
            return self.per_channel_quant_config

        if op in self.quant_ops:
            return self.quant_config

        print(f"No quant config is implemented for op, {op}")
        return None

    def _update_per_channel_weight_quant_ops(self, ops: Set[OpOverload], enable: bool):
        """Add (enable=True) or remove (enable=False) ops from the per-channel set."""
        if enable:
            self.use_per_channel_weight_quant_ops.update(ops)
        else:
            self.use_per_channel_weight_quant_ops.difference_update(ops)

    def add_custom_quant_annotations(
        self, custom_quant_annotations: Sequence[Callable]
    ) -> None:
        """Register callbacks invoked on the GraphModule after built-in annotation."""
        self.custom_quant_annotations = custom_quant_annotations

    def add_discard_nodes(self, nodes: Sequence[str]) -> None:
        """Exclude the given node names from annotation."""
        self.discard_nodes = set(nodes)

    def add_discard_ops(self, ops: Sequence[OpOverload]) -> None:
        """Exclude the given ops from annotation.

        Uses set.discard so excluding an op that is already absent is a
        no-op instead of raising KeyError.
        """
        for op in ops:
            self.quant_ops.discard(op)

    def annotate(self, model: GraphModule) -> GraphModule:
        """Annotate `model` in place (built-in pass, then custom) and return it."""
        self._annotate(model)
        self._annotate_custom_annotation(model)

        return model

    def get_supported_ops(self) -> Set[OpOverload]:
        """Return the full set of ops this quantizer can annotate."""
        return self.SUPPORTED_OPS

    def set_quant_config(
        self, quant_dtype: QuantDtype, is_qat=False, act_observer=None
    ) -> None:
        """Select the (quant_dtype, is_qat) quantization scheme.

        Raises RuntimeError for unsupported combinations. Validation happens
        before any state is mutated, so a failed call leaves the quantizer
        unchanged.
        """
        if (quant_dtype, is_qat) not in quant_config_dict:
            raise RuntimeError(
                f"the quant config, (quant_dtype: {quant_dtype}, is_qat: {is_qat}) is not supported"
            )

        self.quant_dtype = quant_dtype
        self.is_qat = is_qat
        quant_config_func, self.per_channel_quant_config = quant_config_dict[
            (quant_dtype, is_qat)
        ]
        self.quant_config = (
            quant_config_func(act_observer) if act_observer else quant_config_func()
        )

    def set_per_channel_conv_quant(self, enable: bool) -> None:
        """Toggle per-channel weight quantization for conv1d/conv2d."""
        conv_ops = {torch.ops.aten.conv1d.default, torch.ops.aten.conv2d.default}
        self._update_per_channel_weight_quant_ops(conv_ops, enable)

    def set_per_channel_linear_quant(self, enable: bool) -> None:
        """Toggle per-channel weight quantization for linear."""
        linear_ops = {
            torch.ops.aten.linear.default,
        }
        self._update_per_channel_weight_quant_ops(linear_ops, enable)

    def transform_for_annotation(self, model: GraphModule) -> GraphModule:
        """Apply the source-transform passes required before annotation."""
        model = ReduceDynamicRange()(model).graph_module
        model = RecomposePixelUnshuffle(quantization_capture=True)(model).graph_module
        model = DecomposeScaledDotProductAttention()(model).graph_module
        model = DecomposeSilu()(model).graph_module
        model = DecomposeEinsum()(model).graph_module
        model = ReplaceInfBuffer()(model).graph_module

        return model

    def validate(self, model: GraphModule) -> None:
        """No additional validation is performed for the QNN backend."""
        pass