# Copyright (c) 2024 MediaTek Inc.
#
# Licensed under the BSD License (the "License"); you may not use this file
# except in compliance with the License. See the license file in the root
# directory of this source tree for more details.

import copy
from enum import IntEnum, unique

import torch
from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
from torch.ao.quantization.quantizer import QuantizationSpec


@unique
class Precision(IntEnum):
    """Activation/weight bit-width pairs, named AxWy for x-bit activations
    and y-bit weights (e.g. A16W8 = 16-bit activations, 8-bit weights)."""

    A16W16 = 0
    A16W8 = 1
    A16W4 = 2
    A8W8 = 3
    A8W4 = 4


class QuantizationConfig:
    """Pairs an activation QuantizationSpec with a weight QuantizationSpec."""

    def __init__(
        self, activation_spec: QuantizationSpec, weight_spec: QuantizationSpec
    ):
        self._activation_spec = activation_spec
        self._weight_spec = weight_spec

    @property
    def activation(self):
        # Deep-copied so callers cannot mutate the stored spec in place.
        return copy.deepcopy(self._activation_spec)

    @property
    def weight(self):
        return copy.deepcopy(self._weight_spec)


def get_quant_config(
    precision: Precision,
    is_per_channel: bool = False,
    is_qat: bool = False,
) -> QuantizationConfig:
    """Return the QuantizationConfig for the given precision.

    `is_per_channel` selects per-channel weight quantization; `is_qat`
    selects fake-quantize observers for quantization-aware training.
    """
    precision_mappings = {
        Precision.A16W16: get_a16w16_quant_config,
        Precision.A16W8: get_a16w8_quant_config,
        Precision.A16W4: get_a16w4_quant_config,
        Precision.A8W8: get_a8w8_quant_config,
        Precision.A8W4: get_a8w4_quant_config,
    }
    if precision not in precision_mappings:
        raise RuntimeError(f"Unrecognized precision setting: {precision}.")
    qconfig_fn = precision_mappings[precision]
    return qconfig_fn(is_per_channel, is_qat)
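
# A minimal usage sketch (hypothetical call site; the quantizer that consumes
# the returned config is not part of this file):
#
#     qconfig = get_quant_config(Precision.A8W8, is_per_channel=True)
#     act_spec = qconfig.activation  # per-tensor affine int8 spec
#     wgt_spec = qconfig.weight      # per-channel affine int8 spec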


def _get_activation_qspec(
    dtype,
    is_symmetric,
    is_qat,
    observer_cls=MinMaxObserver,
    quant_min=None,
    quant_max=None,
):
    """Build a per-tensor QuantizationSpec for activations."""
    if quant_max is None:
        quant_max = torch.iinfo(dtype).max
    if quant_min is None:
        # The full signed range is used; the minimum value is not reserved
        # for the symmetric case.
        quant_min = torch.iinfo(dtype).min

    qscheme = torch.per_tensor_symmetric if is_symmetric else torch.per_tensor_affine
    if is_qat:
        observer_or_fake_quant = FakeQuantize.with_args(observer=observer_cls, eps=1e-6)
    else:
        observer_or_fake_quant = observer_cls.with_args(eps=1e-6)

    return QuantizationSpec(
        dtype=dtype,
        quant_min=quant_min,
        quant_max=quant_max,
        qscheme=qscheme,
        observer_or_fake_quant_ctr=observer_or_fake_quant,
    )
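
# For reference (a sketch of what this helper yields): the 16-bit PTQ
# activation spec used below is
#
#     _get_activation_qspec(torch.int16, is_symmetric=True, is_qat=False)
#     # -> QuantizationSpec(dtype=torch.int16, quant_min=-32768,
#     #    quant_max=32767, qscheme=torch.per_tensor_symmetric, ...)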


def _get_weight_qspec(
    dtype, is_symmetric, is_per_channel, is_qat, quant_min=None, quant_max=None
):
    """Build a QuantizationSpec for weights: per-channel on axis 0, or
    per-tensor by delegating to _get_activation_qspec."""
    if not is_per_channel:
        # Forward the range overrides so sub-byte widths (e.g. 4-bit) keep
        # their restricted [-8, 7] range in the per-tensor case as well.
        return _get_activation_qspec(
            dtype,
            is_symmetric,
            is_qat,
            observer_cls=MinMaxObserver,
            quant_min=quant_min,
            quant_max=quant_max,
        )

    if quant_max is None:
        quant_max = torch.iinfo(dtype).max
    if quant_min is None:
        # The full signed range is used; the minimum value is not reserved
        # for the symmetric case.
        quant_min = torch.iinfo(dtype).min

    qscheme = torch.per_channel_symmetric if is_symmetric else torch.per_channel_affine
    if is_qat:
        observer_or_fake_quant = FakeQuantize.with_args(
            observer=PerChannelMinMaxObserver, eps=1e-6
        )
    else:
        observer_or_fake_quant = PerChannelMinMaxObserver.with_args(eps=1e-6)

    return QuantizationSpec(
        dtype=dtype,
        quant_min=quant_min,
        quant_max=quant_max,
        qscheme=qscheme,
        ch_axis=0,
        observer_or_fake_quant_ctr=observer_or_fake_quant,
    )
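
# For reference (a sketch): the 4-bit weight specs used below carry 4-bit
# values in an int8 dtype with an explicit range override:
#
#     _get_weight_qspec(torch.int8, False, True, False, quant_min=-8, quant_max=7)
#     # -> QuantizationSpec(dtype=torch.int8, quant_min=-8, quant_max=7,
#     #    qscheme=torch.per_channel_affine, ch_axis=0, ...)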


def get_a16w16_quant_config(is_per_channel, is_qat) -> QuantizationConfig:
    act_quantization_spec = _get_activation_qspec(torch.int16, True, is_qat)
    wgt_quantization_spec = _get_weight_qspec(torch.int16, True, is_per_channel, is_qat)
    return QuantizationConfig(act_quantization_spec, wgt_quantization_spec)


def get_a16w8_quant_config(is_per_channel, is_qat) -> QuantizationConfig:
    act_quantization_spec = _get_activation_qspec(torch.int16, True, is_qat)
    wgt_quantization_spec = _get_weight_qspec(torch.int8, True, is_per_channel, is_qat)
    return QuantizationConfig(act_quantization_spec, wgt_quantization_spec)


def get_a16w4_quant_config(is_per_channel, is_qat) -> QuantizationConfig:
    act_quantization_spec = _get_activation_qspec(torch.int16, True, is_qat)
    # 4-bit weights are carried in int8 with an explicit [-8, 7] range.
    wgt_quantization_spec = _get_weight_qspec(
        torch.int8, False, is_per_channel, is_qat, quant_min=-8, quant_max=7
    )
    return QuantizationConfig(act_quantization_spec, wgt_quantization_spec)


def get_a8w8_quant_config(is_per_channel, is_qat) -> QuantizationConfig:
    act_quantization_spec = _get_activation_qspec(torch.int8, False, is_qat)
    wgt_quantization_spec = _get_weight_qspec(torch.int8, False, is_per_channel, is_qat)
    return QuantizationConfig(act_quantization_spec, wgt_quantization_spec)


def get_a8w4_quant_config(is_per_channel, is_qat) -> QuantizationConfig:
    act_quantization_spec = _get_activation_qspec(torch.int8, False, is_qat)
    # 4-bit weights are carried in int8 with an explicit [-8, 7] range.
    wgt_quantization_spec = _get_weight_qspec(
        torch.int8, False, is_per_channel, is_qat, quant_min=-8, quant_max=7
    )
    return QuantizationConfig(act_quantization_spec, wgt_quantization_spec)
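

# A quick self-check (a sketch; assumes torch and the torch.ao APIs imported
# above are available). Prints the specs for the A16W4 per-channel config:
if __name__ == "__main__":
    cfg = get_quant_config(Precision.A16W4, is_per_channel=True)
    print(cfg.activation)
    print(cfg.weight)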