| """ |
This module implements the modules used to perform fake quantization
during quantization-aware training (QAT).
| """ |
| |
| import torch |
| from torch.nn import Module |
| from torch.ao.quantization.observer import ( |
| MovingAverageMinMaxObserver, |
| HistogramObserver, |
| MovingAveragePerChannelMinMaxObserver, |
| PerChannelMinMaxObserver, |
| _with_args, |
| ) |
| import re |
| from abc import ABC, abstractmethod |
| from typing import Any, Tuple |
| |
| def _is_per_channel(qscheme: 'torch.qscheme') -> bool: |
| return qscheme in [torch.per_channel_symmetric, torch.per_channel_affine, torch.per_channel_affine_float_qparams] |
| |
| def _is_per_tensor(qscheme: 'torch.qscheme') -> bool: |
| return qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine] |
| |
| def _is_symmetric_quant(qscheme: 'torch.qscheme') -> bool: |
| return qscheme in [torch.per_tensor_symmetric, torch.per_channel_symmetric] |
| |
| class FakeQuantizeBase(ABC, Module): |
| r""" Base fake quantize module |
| Any fake quantize implementation should derive from this class. |
| |
    Concrete fake quantize modules should follow the same API. In forward, they will update
| the statistics of the observed Tensor and fake quantize the input. They should also provide a |
| `calculate_qparams` function that computes the quantization parameters given |
| the collected statistics. |
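
    Example (a minimal sketch of a concrete subclass; ``MyFakeQuantize`` and its
    fixed qparams are illustrative, not part of this module)::

        class MyFakeQuantize(FakeQuantizeBase):
            def __init__(self):
                super().__init__()
                self.register_buffer('scale', torch.tensor([0.1]))
                self.register_buffer('zero_point', torch.tensor([0], dtype=torch.int))

            def forward(self, x):
                # only fake quantize when enabled, mirroring the toggles below
                if self.fake_quant_enabled[0] == 1:
                    x = torch.fake_quantize_per_tensor_affine(
                        x, self.scale, self.zero_point, 0, 255)
                return x

            def calculate_qparams(self, **kwargs):
                return self.scale, self.zero_point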
| |
| """ |
| |
| fake_quant_enabled: torch.Tensor |
| observer_enabled: torch.Tensor |
| |
| def __init__(self): |
| super().__init__() |
| # fake_quant_enabled and observer_enabled are buffers to support their |
| # replication in DDP. Data type is uint8 because NCCL does not support |
| # bool tensors. |
| self.register_buffer('fake_quant_enabled', torch.tensor([1], dtype=torch.uint8)) |
| self.register_buffer('observer_enabled', torch.tensor([1], dtype=torch.uint8)) |
| |
| @abstractmethod |
| def forward(self, x): |
| pass |
| |
| @abstractmethod |
| def calculate_qparams(self, **kwargs): |
| pass |
| |
| @torch.jit.export |
| def enable_fake_quant(self, enabled: bool = True) -> None: |
| self.fake_quant_enabled[0] = 1 if enabled else 0 |
| |
| @torch.jit.export |
| def disable_fake_quant(self): |
| self.enable_fake_quant(False) |
| |
| @torch.jit.export |
| def enable_observer(self, enabled: bool = True) -> None: |
| self.observer_enabled[0] = 1 if enabled else 0 |
| |
| @torch.jit.export |
| def disable_observer(self): |
| self.enable_observer(False) |
| |
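    # with_args pre-binds constructor arguments, e.g.
    # FakeQuantize.with_args(observer=..., quant_min=0, quant_max=255);
    # the returned callable builds a new instance on each call (see the
    # module-level defaults below).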
| with_args = classmethod(_with_args) |
| |
| class FakeQuantize(FakeQuantizeBase): |
| r""" Simulate the quantize and dequantize operations in training time. |
| The output of this module is given by:: |
| |
| x_out = ( |
| clamp(round(x/scale + zero_point), quant_min, quant_max) - zero_point |
| ) * scale |
| |
| * :attr:`scale` defines the scale factor used for quantization. |
| |
    * :attr:`zero_point` specifies the quantized value to which 0 in floating point maps.
| |
| * :attr:`quant_min` specifies the minimum allowable quantized value. |
| |
| * :attr:`quant_max` specifies the maximum allowable quantized value. |
| |
    * :attr:`fake_quant_enabled` controls the application of fake quantization on tensors; note that
      statistics can still be collected while fake quantization is disabled.

    * :attr:`observer_enabled` controls statistics collection on tensors.

    * :attr:`dtype` specifies the quantized dtype that is being emulated with fake-quantization,
      allowable values are torch.qint8 and torch.quint8. The values of quant_min and
      quant_max should be chosen to be consistent with the dtype.
| |
| Args: |
| |
| observer (module): Module for observing statistics on input tensors and calculating scale |
| and zero-point. |
| quant_min (int): The minimum allowable quantized value. |
| quant_max (int): The maximum allowable quantized value. |
| observer_kwargs (optional): Arguments for the observer module |
| |
| Attributes: |
| |
| observer (Module): User provided module that collects statistics on the input tensor and |
| provides a method to calculate scale and zero-point. |
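
    Example (an illustrative usage sketch; the resulting scale and zero_point
    depend on the statistics observed from the input)::

        fq = FakeQuantize(observer=MovingAverageMinMaxObserver,
                          quant_min=0, quant_max=255, dtype=torch.quint8)
        x = torch.randn(4, 4)
        y = fq(x)                                  # observe, then fake quantize
        scale, zero_point = fq.calculate_qparams()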
| |
| """ |
| |
| scale: torch.Tensor |
| zero_point: torch.Tensor |
| |
| def __init__(self, observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255, **observer_kwargs): |
| super().__init__() |
| assert quant_min <= quant_max, \ |
| 'quant_min must be less than or equal to quant_max' |
| self.quant_min = quant_min |
| self.quant_max = quant_max |
| self.activation_post_process = observer(**observer_kwargs) |
| assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, 'quant_min out of bound' |
| assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, 'quant_max out of bound' |
| self.register_buffer('scale', torch.tensor([1.0], dtype=torch.float)) |
| self.register_buffer('zero_point', torch.tensor([0], dtype=torch.int)) |
| self.dtype = self.activation_post_process.dtype |
| self.qscheme = self.activation_post_process.qscheme |
| self.ch_axis = self.activation_post_process.ch_axis \ |
| if hasattr(self.activation_post_process, 'ch_axis') else -1 |
| assert _is_per_channel(self.qscheme) or \ |
| _is_per_tensor(self.qscheme), \ |
            'Only per channel and per tensor quantization are supported in fake quantize,' + \
            ' got qscheme: ' + str(self.qscheme)
| self.is_per_channel = _is_per_channel(self.qscheme) |
| |
| @torch.jit.export |
| def calculate_qparams(self): |
| return self.activation_post_process.calculate_qparams() |
| |
| def forward(self, X): |
| if self.observer_enabled[0] == 1: |
| self.activation_post_process(X.detach()) |
| _scale, _zero_point = self.calculate_qparams() |
| _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(self.zero_point.device) |
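            # per-channel observers return qparams with one element per channel,
            # so the (initially single-element) buffers may need resizing here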
| if self.scale.shape != _scale.shape: |
| self.scale.resize_(_scale.shape) |
| self.zero_point.resize_(_zero_point.shape) |
| self.scale.copy_(_scale) |
| self.zero_point.copy_(_zero_point) |
| |
| if self.fake_quant_enabled[0] == 1: |
| if self.is_per_channel: |
| X = torch.fake_quantize_per_channel_affine( |
| X, self.scale, self.zero_point, |
| self.ch_axis, self.quant_min, self.quant_max) |
| else: |
| X = torch.fake_quantize_per_tensor_affine( |
| X, self.scale, self.zero_point, |
| self.quant_min, self.quant_max) |
| return X |
| |
| @torch.jit.export |
| def extra_repr(self): |
| return 'fake_quant_enabled={}, observer_enabled={}, ' \ |
| 'quant_min={}, quant_max={}, dtype={}, qscheme={}, ch_axis={}, ' \ |
| 'scale={}, zero_point={}'.format( |
| self.fake_quant_enabled, self.observer_enabled, |
| self.quant_min, self.quant_max, |
| self.dtype, self.qscheme, self.ch_axis, self.scale, self.zero_point) |
| |
| def _save_to_state_dict(self, destination, prefix, keep_vars): |
        # We cannot currently register scalar values as buffers, so we need to manually
        # specify serialization here.
| super(FakeQuantize, self)._save_to_state_dict(destination, prefix, keep_vars) |
| destination[prefix + 'scale'] = self.scale |
| destination[prefix + 'zero_point'] = self.zero_point |
| |
| def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, |
| missing_keys, unexpected_keys, error_msgs): |
        # Removing this function throws an error that the size of the loaded tensor does not match the original size,
        # i.e., these buffers start out with numel 0 and become numel 1 once they have their first forward pass.
| local_state = ['scale', 'zero_point'] |
| for name in local_state: |
| key = prefix + name |
| if key in state_dict: |
| val = state_dict[key] |
| # Custom handling to allow loading scale and zero_point |
| # of size N into uninitialized buffers of size 0. The |
| # buffers are resized here, and the values are copied in |
| # the default state_dict loading code of the parent. |
| if name == 'scale': |
| self.scale.resize_(val.shape) |
| else: |
| assert name == 'zero_point' |
| self.zero_point.resize_(val.shape) |
                # For torchscript modules we need to update the attributes here since we do not
                # call the `_load_from_state_dict` function defined in module.py
| if torch.jit.is_scripting(): |
| if name == 'scale': |
| self.scale.copy_(val) |
| else: |
| assert name == 'zero_point' |
| self.zero_point.copy_(val) |
| elif strict: |
| missing_keys.append(key) |
| super(FakeQuantize, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict, |
| missing_keys, unexpected_keys, error_msgs) |
| |
| class FixedQParamsFakeQuantize(FakeQuantizeBase): |
| """ Simulate quantize and dequantize with fixed quantization |
| parameters in training time. Only per tensor quantization |
| is supported. |
| |
| Args: |
| |
        `scale` (float): fixed scale for the fake quantize module
        `zero_point` (int): fixed zero point for the fake quantize module
        `dtype`, `qscheme`, `quant_min`, `quant_max`: quantization parameters with the
            same meaning as in :class:`~torch.ao.quantization.FakeQuantize`
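
    Example (an illustrative sketch; these are the same qparams used by
    ``default_affine_fixed_qparams_fake_quant`` below)::

        fq = FixedQParamsFakeQuantize(scale=1.0 / 256.0, zero_point=0)
        y = fq(torch.rand(4))  # values are snapped onto the fixed 256-level grid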
| """ |
| |
| scale: torch.Tensor |
| zero_point: torch.Tensor |
| |
| def __init__(self, |
| scale, |
| zero_point, |
| dtype=torch.quint8, |
| qscheme=torch.per_tensor_affine, |
| quant_min=0, |
| quant_max=255): |
| super().__init__() |
| assert quant_min <= quant_max, 'quant_min should be less than or equal to quant_max' |
| self.quant_min = quant_min |
| self.quant_max = quant_max |
| self.register_buffer('scale', torch.tensor([scale], dtype=torch.float)) |
| self.register_buffer('zero_point', torch.tensor([zero_point], dtype=torch.int)) |
| self.dtype = dtype |
| self.qscheme = qscheme |
        assert _is_per_tensor(self.qscheme), 'Only per tensor quantization is supported in' + \
            ' FixedQParamsFakeQuantize module, got qscheme: ' + str(self.qscheme)
| |
| def forward(self, X): |
| if self.fake_quant_enabled[0] == 1: |
| X = torch.fake_quantize_per_tensor_affine(X, self.scale, |
| self.zero_point, self.quant_min, |
| self.quant_max) |
| return X |
| |
| @torch.jit.export |
| def calculate_qparams(self): |
| return self.scale, self.zero_point |
| |
| @torch.jit.export |
| def extra_repr(self): |
| return 'fake_quant_enabled={}, observer_enabled={}, scale={}, zero_point={}, ' \ |
| 'dtype={}, quant_min={}, quant_max={}, qscheme={}'.format( |
| self.fake_quant_enabled, self.observer_enabled, |
| self.scale, self.zero_point, self.dtype, |
| self.quant_min, self.quant_max, self.qscheme) |
| |
| class FusedMovingAvgObsFakeQuantize(FakeQuantize): |
| r"""Fused module that is used to observe the input tensor (compute min/max), compute |
| scale/zero_point and fake_quantize the tensor. |
| This module uses calculation similar MovingAverageMinMaxObserver for the inputs, |
| to compute the min/max values in order to compute the scale/zero_point. |
| The qscheme input in the observer is used to differentiate between symmetric/affine |
| quantization scheme. |
| |
    The output of this module is given by::

        x_out = (
            clamp(round(x / scale + zero_point), quant_min, quant_max) - zero_point
        ) * scale
| |
    This module is similar to :class:`~torch.ao.quantization.FakeQuantize`, and accepts the same
    attributes as the base class.
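
    Example (an illustrative sketch; observation, qparam computation and fake
    quantization all happen inside the single fused op)::

        fq = FusedMovingAvgObsFakeQuantize()
        y = fq(torch.randn(4, 4))
        scale, zero_point = fq.calculate_qparams()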
| |
| """ |
| |
| def __init__( |
| self, |
| observer: Any = MovingAverageMinMaxObserver, |
| quant_min: int = 0, |
| quant_max: int = 255, |
| **observer_kwargs: Any |
| ) -> None: |
| super().__init__(observer, quant_min, quant_max, **observer_kwargs) |
        assert isinstance(self.activation_post_process, (MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver)),\
            "Fused observer+fake_quant module only works with MovingAverageMinMaxObserver " \
            "and MovingAveragePerChannelMinMaxObserver"
| self.quant_min: int = quant_min |
| self.quant_max: int = quant_max |
| self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.long)) |
| self.register_buffer("observer_enabled", torch.tensor([1], dtype=torch.long)) |
| self.is_symmetric_quant = _is_symmetric_quant(self.activation_post_process.qscheme) |
| |
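        # Note: the observer's quant_min/quant_max unconditionally override the
        # values stored above.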
| self.quant_min, self.quant_max = self.activation_post_process.quant_min, self.activation_post_process.quant_max |
| |
| @torch.jit.export |
| def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]: |
| return self.activation_post_process.calculate_qparams() |
| |
| @torch.jit.export |
| def extra_repr(self) -> str: |
| return ( |
| "fake_quant_enabled={}, observer_enabled={}, scale={}, zero_point={}, " |
| "dtype={}, quant_min={}, quant_max={}, qscheme={}, reduce_range={}".format( |
| self.fake_quant_enabled, |
| self.observer_enabled, |
| self.scale, |
| self.zero_point, |
| self.dtype, |
| self.quant_min, |
| self.quant_max, |
| self.qscheme, |
| self.activation_post_process.reduce_range, |
| ) |
| ) |
| |
| def forward(self, X: torch.Tensor) -> torch.Tensor: |
| return torch.fused_moving_avg_obs_fake_quant( |
| X, |
| self.observer_enabled, |
| self.fake_quant_enabled, |
| self.activation_post_process.min_val, |
| self.activation_post_process.max_val, |
| self.scale, |
| self.zero_point, |
| self.activation_post_process.averaging_constant, |
| self.quant_min, |
| self.quant_max, |
| self.ch_axis, |
| self.is_per_channel, |
| self.is_symmetric_quant, |
| ) |
| |
| default_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255, |
| dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=True) |
| """ |
| Default fake_quant for activations. |
| """ |
| |
| default_weight_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=-128, quant_max=127, |
| dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, reduce_range=False) |
| """ |
| Default fake_quant for weights. |
| """ |
| |
| # TODO(future PR): remove these defaults and enforce activation functions |
| # to explicitly specify their output range |
| default_symmetric_fixed_qparams_fake_quant = FixedQParamsFakeQuantize.with_args( |
| scale=2.0 / 256.0, zero_point=128, dtype=torch.quint8, quant_min=0, quant_max=255) |
| default_affine_fixed_qparams_fake_quant = FixedQParamsFakeQuantize.with_args( |
| scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8, quant_min=0, quant_max=255) |
| |
| default_per_channel_weight_fake_quant = FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver, |
| quant_min=-128, |
| quant_max=127, |
| dtype=torch.qint8, |
| qscheme=torch.per_channel_symmetric, |
| reduce_range=False, |
| ch_axis=0) |
| """ |
| Default fake_quant for per-channel weights. |
| """ |
| |
| default_embedding_fake_quant = FakeQuantize.with_args(observer=PerChannelMinMaxObserver, |
| qscheme=torch.per_channel_affine_float_qparams, |
| dtype=torch.quint8, |
| quant_min=0, |
| quant_max=255, |
| ch_axis=0, |
| memoryless=True) |
| """ |
| Default fake_quant for embeddings. |
| """ |
| |
| default_embedding_fake_quant_4bit = FakeQuantize.with_args(observer=PerChannelMinMaxObserver, |
| qscheme=torch.per_channel_affine_float_qparams, |
| ch_axis=0, |
| dtype=torch.quint4x2, |
| memoryless=True) |
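"""
Default fake_quant for 4-bit embeddings.
"""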
| |
| default_histogram_fake_quant = FakeQuantize.with_args(observer=HistogramObserver, |
| quant_min=0, |
| quant_max=255, |
| dtype=torch.quint8, |
| qscheme=torch.per_tensor_affine, |
| reduce_range=True) |
| """ |
Fake_quant for activations using a histogram.
| """ |
| |
| |
| default_fused_act_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver, |
| quant_min=0, |
| quant_max=255, |
| dtype=torch.quint8,) |
| """ |
| Fused version of `default_fake_quant`, with improved performance. |
| """ |
| |
| |
| default_fused_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver, |
| quant_min=-128, |
| quant_max=127, |
| dtype=torch.qint8, |
| qscheme=torch.per_tensor_symmetric) |
| """ |
| Fused version of `default_weight_fake_quant`, with improved performance. |
| """ |
| |
| default_fused_per_channel_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver, |
| quant_min=-128, |
| quant_max=127, |
| dtype=torch.qint8, |
| qscheme=torch.per_channel_symmetric) |
| """ |
| Fused version of `default_per_channel_weight_fake_quant`, with improved performance. |
| """ |
| |
| def _is_fake_quant_script_module(mod): |
    ''' Returns True if the given mod is an instance of a FakeQuantize script module.
| ''' |
| if isinstance(mod, torch.jit.RecursiveScriptModule): |
| # qualified name looks like '__torch__.torch.ao.quantization.fake_quantize.___torch_mangle_2.FakeQuantize' |
| suffix = mod._c.qualified_name.split('.', 1)[1] |
| name = re.sub(r'\.___torch_mangle_\d+', '', suffix) |
| return name == 'torch.ao.quantization.fake_quantize.FakeQuantize' or \ |
| name == 'torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize' |
| return False |
| |
| def disable_fake_quant(mod): |
| """ |
| Disable fake quantization for this module, if applicable. Example usage:: |
| |
| # model is any PyTorch model |
| model.apply(torch.ao.quantization.disable_fake_quant) |
| |
| """ |
| if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod): |
| mod.disable_fake_quant() |
| |
| def enable_fake_quant(mod): |
| """ |
| Enable fake quantization for this module, if applicable. Example usage:: |
| |
| # model is any PyTorch model |
| model.apply(torch.ao.quantization.enable_fake_quant) |
| |
| """ |
| if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod): |
| mod.enable_fake_quant() |
| |
| def disable_observer(mod): |
| """ |
| Disable observation for this module, if applicable. Example usage:: |
| |
| # model is any PyTorch model |
| model.apply(torch.ao.quantization.disable_observer) |
| |
| """ |
| if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod): |
| mod.disable_observer() |
| |
| def enable_observer(mod): |
| """ |
| Enable observation for this module, if applicable. Example usage:: |
| |
| # model is any PyTorch model |
| model.apply(torch.ao.quantization.enable_observer) |
| |
| """ |
| if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod): |
| mod.enable_observer() |