| # Copyright (c) 2024 MediaTek Inc. |
| # |
| # Licensed under the BSD License (the "License"); you may not use this file |
| # except in compliance with the License. See the license file in the root |
| # directory of this source tree for more details. |
| |
| import copy |
| |
| from enum import IntEnum, unique |
| |
| import torch |
| |
| from torch.ao.quantization.fake_quantize import FakeQuantize |
| from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver |
| from torch.ao.quantization.quantizer import QuantizationSpec |
| |
| |
@unique
class Precision(IntEnum):
    """Activation/weight precision pairs accepted by ``get_quant_config``.

    A member named ``A<x>W<y>`` pairs ``x``-bit activations with ``y``-bit
    weights. The integer values are explicit and ``@unique`` rejects any
    accidental aliasing between members.
    """

    A16W16 = 0  # int16 activations, int16 weights
    A16W8 = 1  # int16 activations, int8 weights
    A16W4 = 2  # int16 activations, int8 weights with a requested [-8, 7] range
    A8W8 = 3  # int8 activations, int8 weights
    A8W4 = 4  # int8 activations, int8 weights with a requested [-8, 7] range
| |
| |
| class QuantizationConfig: |
| |
| def __init__( |
| self, activation_spec: QuantizationSpec, weight_spec: QuantizationSpec |
| ): |
| self._activation_spec = activation_spec |
| self._weight_spec = weight_spec |
| |
| @property |
| def activation(self): |
| return copy.deepcopy(self._activation_spec) |
| |
| @property |
| def weight(self): |
| return copy.deepcopy(self._weight_spec) |
| |
| |
def get_quant_config(
    precision: Precision,
    is_per_channel: bool = False,
    is_qat: bool = False,
) -> QuantizationConfig:
    """Return the quantization config registered for ``precision``.

    Args:
        precision: Which activation/weight precision pair to build.
        is_per_channel: Forwarded to the selected builder.
        is_qat: Forwarded to the selected builder.

    Returns:
        The ``QuantizationConfig`` produced by the matching builder.

    Raises:
        RuntimeError: If ``precision`` has no registered builder.
    """
    builders = {
        Precision.A16W16: get_a16w16_quant_config,
        Precision.A16W8: get_a16w8_quant_config,
        Precision.A16W4: get_a16w4_quant_config,
        Precision.A8W8: get_a8w8_quant_config,
        Precision.A8W4: get_a8w4_quant_config,
    }

    # No registered builder is None, so a None result means "key not found".
    build = builders.get(precision)
    if build is None:
        raise RuntimeError("Unrecognized precision setting.")

    return build(is_per_channel, is_qat)
| |
| |
| def _get_activation_qspec( |
| dtype, |
| is_symmetric, |
| is_qat, |
| observer_cls=MinMaxObserver, |
| quant_min=None, |
| quant_max=None, |
| ): |
| if quant_max is None: |
| quant_max = torch.iinfo(dtype).max |
| if quant_min is None: |
| # quant_min = torch.iinfo(dtype).min + 1 if is_symmetric else torch.iinfo(dtype).min |
| quant_min = torch.iinfo(dtype).min |
| |
| qscheme = torch.per_tensor_symmetric if is_symmetric else torch.per_tensor_affine |
| if is_qat: |
| observer_or_fake_quant = FakeQuantize.with_args(observer=observer_cls, eps=1e-6) |
| else: |
| observer_or_fake_quant = observer_cls.with_args(eps=1e-6) |
| |
| return QuantizationSpec( |
| dtype=dtype, |
| quant_min=quant_min, |
| quant_max=quant_max, |
| qscheme=qscheme, |
| observer_or_fake_quant_ctr=observer_or_fake_quant, |
| ) |
| |
| |
| def _get_weight_qspec( |
| dtype, is_symmetric, is_per_channel, is_qat, quant_min=None, quant_max=None |
| ): |
| if not is_per_channel: |
| return _get_activation_qspec( |
| dtype, is_symmetric, is_qat, observer_cls=MinMaxObserver |
| ) |
| |
| if quant_max is None: |
| quant_max = torch.iinfo(dtype).max |
| if quant_min is None: |
| # quant_min = torch.iinfo(dtype).min + 1 if is_symmetric else torch.iinfo(dtype).min |
| quant_min = torch.iinfo(dtype).min |
| |
| qscheme = torch.per_channel_symmetric if is_symmetric else torch.per_channel_affine |
| if is_qat: |
| observer_or_fake_quant = FakeQuantize.with_args( |
| observer=PerChannelMinMaxObserver, eps=1e-6 |
| ) |
| else: |
| observer_or_fake_quant = PerChannelMinMaxObserver.with_args(eps=1e-6) |
| |
| return QuantizationSpec( |
| dtype=dtype, |
| quant_min=quant_min, |
| quant_max=quant_max, |
| qscheme=qscheme, |
| ch_axis=0, |
| observer_or_fake_quant_ctr=observer_or_fake_quant, |
| ) |
| |
| |
def get_a16w16_quant_config(is_per_channel, is_qat) -> QuantizationConfig:
    """Build the config with int16 activations and int16 weights.

    Both specs are requested as symmetric.

    Args:
        is_per_channel: Forwarded to the weight spec builder.
        is_qat: Forwarded to both spec builders.

    Returns:
        A ``QuantizationConfig`` holding the activation and weight specs.
    """
    return QuantizationConfig(
        activation_spec=_get_activation_qspec(torch.int16, True, is_qat),
        weight_spec=_get_weight_qspec(torch.int16, True, is_per_channel, is_qat),
    )
| |
| |
def get_a16w8_quant_config(is_per_channel, is_qat) -> QuantizationConfig:
    """Build the config with int16 activations and int8 weights.

    Both specs are requested as symmetric.

    Args:
        is_per_channel: Forwarded to the weight spec builder.
        is_qat: Forwarded to both spec builders.

    Returns:
        A ``QuantizationConfig`` holding the activation and weight specs.
    """
    return QuantizationConfig(
        activation_spec=_get_activation_qspec(torch.int16, True, is_qat),
        weight_spec=_get_weight_qspec(torch.int8, True, is_per_channel, is_qat),
    )
| |
| |
def get_a16w4_quant_config(is_per_channel, is_qat) -> QuantizationConfig:
    """Build the config with int16 activations and 4-bit-range weights.

    Activations are requested as symmetric int16. Weights are requested as
    non-symmetric int8 with an explicit range of [-8, 7].

    Args:
        is_per_channel: Forwarded to the weight spec builder.
        is_qat: Forwarded to both spec builders.

    Returns:
        A ``QuantizationConfig`` holding the activation and weight specs.
    """
    return QuantizationConfig(
        activation_spec=_get_activation_qspec(torch.int16, True, is_qat),
        weight_spec=_get_weight_qspec(
            torch.int8, False, is_per_channel, is_qat, quant_min=-8, quant_max=7
        ),
    )
| |
| |
def get_a8w8_quant_config(is_per_channel, is_qat) -> QuantizationConfig:
    """Build the config with int8 activations and int8 weights.

    Both specs are requested as non-symmetric.

    Args:
        is_per_channel: Forwarded to the weight spec builder.
        is_qat: Forwarded to both spec builders.

    Returns:
        A ``QuantizationConfig`` holding the activation and weight specs.
    """
    return QuantizationConfig(
        activation_spec=_get_activation_qspec(torch.int8, False, is_qat),
        weight_spec=_get_weight_qspec(torch.int8, False, is_per_channel, is_qat),
    )
| |
| |
def get_a8w4_quant_config(is_per_channel, is_qat) -> QuantizationConfig:
    """Build the config with int8 activations and 4-bit-range weights.

    Activations are requested as non-symmetric int8. Weights are requested as
    non-symmetric int8 with an explicit range of [-8, 7].

    Args:
        is_per_channel: Forwarded to the weight spec builder.
        is_qat: Forwarded to both spec builders.

    Returns:
        A ``QuantizationConfig`` holding the activation and weight specs.
    """
    return QuantizationConfig(
        activation_spec=_get_activation_qspec(torch.int8, False, is_qat),
        weight_spec=_get_weight_qspec(
            torch.int8, False, is_per_channel, is_qat, quant_min=-8, quant_max=7
        ),
    )