# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from dataclasses import dataclass

import torch

from torch.ao.quantization.quantizer import (
    FixedQParamsQuantizationSpec,
    QuantizationSpec,
)


@dataclass(eq=True, frozen=True)
class QuantizationConfig:
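    """Bundles the QuantizationSpecs for an operator's input activations,
    output activations, weight, and bias. Any field may be None, in which
    case no quantization spec is provided for that tensor.
    """
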
    input_activation: QuantizationSpec | None
    output_activation: QuantizationSpec | None
    weight: QuantizationSpec | None
    bias: QuantizationSpec | None

    def get_input_act_qspec(self) -> QuantizationSpec | None:
        """Returns the 'input_activation' QuantizationSpec after asserting that its qscheme is valid."""
        if self.input_activation is None:
            return None
        assert self.input_activation.qscheme in [
            torch.per_tensor_affine,
            torch.per_tensor_symmetric,
        ], f"Unsupported quantization_spec {self.input_activation} for input_activation."
        return self.input_activation

    def get_output_act_qspec(self) -> QuantizationSpec | None:
        """Returns the 'output_activation' QuantizationSpec after asserting that its qscheme is valid."""
        if self.output_activation is None:
            return None
        assert self.output_activation.qscheme in [
            torch.per_tensor_affine,
            torch.per_tensor_symmetric,
        ], f"Unsupported quantization_spec {self.output_activation} for output_activation."
        return self.output_activation

    def get_weight_qspec(self) -> QuantizationSpec | None:
        """Returns the 'weight' QuantizationSpec after asserting that its qscheme is valid."""
        if self.weight is None:
            return None
        assert self.weight.qscheme in [
            torch.per_tensor_symmetric,
            torch.per_channel_symmetric,
        ], f"Unsupported quantization_spec {self.weight} for weight."
        return self.weight

    def get_bias_qspec(self) -> QuantizationSpec | None:
        """Returns the 'bias' QuantizationSpec after asserting that its dtype is torch.float."""
        if self.bias is None:
            return None
        assert (
            self.bias.dtype == torch.float
        ), "Only float dtype is supported for bias right now"
        return self.bias

    def get_fixed_qspec(
        self,
        scale: float,
        zp: int,
        dtype: torch.dtype = torch.int8,
        quant_min: int = -128,
        quant_max: int = 127,
    ) -> FixedQParamsQuantizationSpec:
        """Returns a new FixedQParamsQuantizationSpec with the given parameters."""
        return FixedQParamsQuantizationSpec(
            dtype=dtype,
            qscheme=torch.per_tensor_affine,
            scale=scale,
            zero_point=zp,
            quant_min=quant_min,
            quant_max=quant_max,
        )
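

# A minimal usage sketch, kept under a __main__ guard so it does not run on
# import. The observer choices and the concrete quantization parameters below
# are illustrative assumptions, not something this module prescribes.
if __name__ == "__main__":
    from torch.ao.quantization.observer import (
        HistogramObserver,
        PerChannelMinMaxObserver,
        PlaceholderObserver,
    )

    # int8 per-tensor spec for activations; satisfies the qscheme checks in
    # get_input_act_qspec / get_output_act_qspec.
    act_qspec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-128,
        quant_max=127,
        qscheme=torch.per_tensor_affine,
        is_dynamic=False,
        observer_or_fake_quant_ctr=HistogramObserver,
    )
    # int8 per-channel symmetric spec for weights (channel axis 0).
    weight_qspec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-127,
        quant_max=127,
        qscheme=torch.per_channel_symmetric,
        ch_axis=0,
        is_dynamic=False,
        observer_or_fake_quant_ctr=PerChannelMinMaxObserver,
    )
    # Bias stays in float, as get_bias_qspec requires.
    bias_qspec = QuantizationSpec(
        dtype=torch.float,
        observer_or_fake_quant_ctr=PlaceholderObserver,
    )

    config = QuantizationConfig(
        input_activation=act_qspec,
        output_activation=act_qspec,
        weight=weight_qspec,
        bias=bias_qspec,
    )

    # Each getter validates its spec before returning it.
    print(config.get_input_act_qspec())
    print(config.get_weight_qspec())
    print(config.get_bias_qspec())

    # Fixed qparams for, e.g., an int8 sigmoid output: the range [0, 1]
    # maps to scale 1/256 with zero point -128.
    print(config.get_fixed_qspec(scale=1.0 / 256.0, zp=-128))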