|  | """ | 
|  | This module implements observers which are used to collect statistics about | 
|  | the values observed during calibration (PTQ) or training (QAT). | 
|  | """ | 
|  |  | 
|  | import re | 
|  | import warnings | 
|  | from abc import ABCMeta, abstractmethod | 
|  | from collections import OrderedDict | 
|  | from functools import partial | 
|  | from typing import Any, List, Tuple, Optional, Dict, Union | 
|  |  | 
|  | import torch | 
|  | import torch.nn as nn | 
|  | from torch.ao.quantization.utils import check_min_max_valid, calculate_qmin_qmax | 
|  |  | 
|  |  | 
|  | class _PartialWrapper(object): | 
|  | def __init__(self, p): | 
|  | self.p = p | 
|  | self.callable_args = {} | 
|  |  | 
|  | def __call__(self, *args, **keywords): | 
# call each arg in callable_args and add the result to keywords, then run with keywords
# skip if arg_name is already in keywords so it's possible to overwrite
|  | for arg_name in self.callable_args: | 
|  | if arg_name not in keywords: | 
|  | keywords = {**keywords, **{arg_name: self.callable_args[arg_name]()}} | 
|  | return self.p(*args, **keywords) | 
|  |  | 
|  | def __repr__(self): | 
|  | return self.p.__repr__() + self.callable_args.__repr__() | 
|  |  | 
|  | def with_args(self, **kwargs): | 
|  | return _with_args(self, **kwargs) | 
|  |  | 
|  | def with_callable_args(self, **kwargs): | 
|  | result = _PartialWrapper(p=self.p) | 
|  | result.callable_args = {**self.callable_args, **kwargs} | 
|  | return result | 
|  |  | 
|  |  | 
|  | def _with_args(cls_or_self, **kwargs): | 
|  | r"""Wrapper that allows creation of class factories. | 
|  |  | 
This can be useful when there is a need to create classes with the same
constructor arguments, but different instances. Can be used in conjunction with
_with_callable_args.
|  |  | 
|  | Example:: | 
|  |  | 
|  | >>> Foo.with_args = classmethod(_with_args) | 
|  | >>> foo_builder = Foo.with_args(a=3, b=4).with_args(answer=42) | 
|  | >>> foo_instance1 = foo_builder() | 
|  | >>> foo_instance2 = foo_builder() | 
|  | >>> id(foo_instance1) == id(foo_instance2) | 
|  | False | 
|  | """ | 
|  | r = _PartialWrapper(partial(cls_or_self, **kwargs)) | 
|  | return r | 
|  |  | 
|  | def _with_callable_args(cls_or_self, **kwargs): | 
|  | r"""Wrapper that allows creation of class factories args that need to be | 
|  | called at construction time. | 
|  |  | 
|  | This can be useful when there is a need to create classes with the same | 
|  | constructor arguments, but different instances and those arguments should only | 
|  | be calculated at construction time. Can be used in conjunction with _with_args | 
|  |  | 
|  | Example:: | 
|  |  | 
|  | >>> Foo.with_callable_args = classmethod(_with_callable_args) | 
|  | >>> Foo.with_args = classmethod(_with_args) | 
|  | >>> foo_builder = Foo.with_callable_args(cur_time=get_time_func).with_args(name="dan") | 
|  | >>> foo_instance1 = foo_builder() | 
>>> # wait 50 seconds so get_time_func() returns a new value
|  | >>> foo_instance2 = foo_builder() | 
|  | >>> id(foo_instance1.creation_time) == id(foo_instance2.creation_time) | 
|  | False | 
|  | """ | 
|  | r = _PartialWrapper(partial(cls_or_self)) | 
|  | return r.with_callable_args(**kwargs) | 
|  |  | 
|  |  | 
ABC: Any = ABCMeta(str("ABC"), (object,), {})  # compatible with Python 2 *and* 3
|  |  | 
|  |  | 
|  | class ObserverBase(ABC, nn.Module): | 
|  | r"""Base observer Module. | 
|  | Any observer implementation should derive from this class. | 
|  |  | 
Concrete observers should follow the same API. In ``forward``, they update
the statistics of the observed Tensor, and they should provide a
``calculate_qparams`` function that computes the quantization parameters given
the collected statistics.
|  |  | 
|  | Args: | 
|  | dtype: Quantized data type | 
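
A minimal sketch of a custom observer (illustrative only; the class name and
the tracked statistic below are not part of this module)::

>>> class MaxAbsObserver(ObserverBase):
...     def __init__(self):
...         super().__init__(dtype=torch.quint8)
...         self.register_buffer("max_abs", torch.tensor(0.0))
...     def forward(self, x):
...         # update the running statistic in place
...         self.max_abs.copy_(torch.max(self.max_abs, x.detach().abs().max()))
...         return x
...     def calculate_qparams(self):
...         scale = torch.clamp(self.max_abs, min=1e-8) / 127.0
...         return scale, torch.zeros_like(scale, dtype=torch.int64)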
|  | """ | 
|  |  | 
|  | def __init__(self, dtype): | 
|  | super(ObserverBase, self).__init__() | 
|  | self.dtype = dtype | 
|  |  | 
|  | @abstractmethod | 
|  | def forward(self, x): | 
|  | pass | 
|  |  | 
|  | @abstractmethod | 
|  | def calculate_qparams(self, **kwargs): | 
|  | pass | 
|  |  | 
|  | with_args = classmethod(_with_args) | 
|  | with_callable_args = classmethod(_with_callable_args) | 
|  |  | 
|  |  | 
|  | class _ObserverBase(ObserverBase): | 
|  | r"""Internal common base for all qint/quint8 observers. | 
|  |  | 
This base class holds parameters that are commonly used by the internal observers.
Users should use :class:`~torch.ao.quantization.observer.ObserverBase` as the base class
for custom observers.
|  |  | 
|  | Args: | 
|  | dtype: Quantized data type. | 
|  | qscheme: Quantization scheme to be used. | 
|  | reduce_range: Reduces the range of the quantized data type by 1 bit. | 
|  | This is sometimes required to avoid instruction overflow. | 
|  | quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. | 
|  | quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup. | 
|  |  | 
|  | .. warning:: | 
|  |  | 
|  | :attr:`dtype` can only take ``torch.qint8`` or ``torch.quint8``. | 
|  |  | 
|  | .. warning:: | 
|  |  | 
|  | :attr:`qscheme` can only take one of the following options: | 
|  |  | 
|  | - ``torch.per_tensor_affine`` | 
|  | - ``torch.per_tensor_symmetric`` | 
|  | - ``torch.per_channel_affine`` | 
|  | - ``torch.per_channel_symmetric`` | 
|  | """ | 
|  |  | 
|  | # Note: the version is shared by all observer types | 
|  | # | 
|  | # Version 1/None | 
|  | #   self | 
|  | # | 
|  | # Version 2 (base class only, does not include child class buffers) | 
|  | #   self | 
|  | #   |--- eps : Tensor | 
|  | # | 
|  | # Version 3 | 
|  | #   for HistogramObserver only, changed the shape of uninitialized | 
|  | #   min_val and max_val buffers from torch.Size([0]) to torch.Size([]) | 
|  | #   for PerChannelObservers, changed the name of the buffers from min_vals | 
|  | #   to min_val and from max_vals to max_val. | 
|  | _version = 3 | 
|  |  | 
|  | eps: torch.Tensor | 
|  |  | 
|  | def __init__( | 
|  | self, | 
|  | dtype=torch.quint8, | 
|  | qscheme=torch.per_tensor_affine, | 
|  | reduce_range=False, | 
|  | quant_min=None, | 
|  | quant_max=None, | 
|  | factory_kwargs=None, | 
|  | ) -> None: | 
|  | factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) | 
|  | super(_ObserverBase, self).__init__(dtype=dtype) | 
|  | self.qscheme = qscheme | 
|  | if reduce_range: | 
|  | warnings.warn( | 
|  | "Please use quant_min and quant_max to specify the range for observers. \ | 
|  | reduce_range will be deprecated in a future release of PyTorch." | 
|  | ) | 
|  | self.reduce_range = reduce_range | 
|  | self.register_buffer( | 
|  | "eps", torch.tensor([torch.finfo(torch.float32).eps], **factory_kwargs) | 
|  | ) | 
|  | assert self.qscheme in ( | 
|  | torch.per_tensor_affine, | 
|  | torch.per_tensor_symmetric, | 
|  | torch.per_channel_affine, | 
|  | torch.per_channel_symmetric, | 
|  | torch.per_channel_affine_float_qparams, | 
|  | ), "Default Observer only works for per_tensor_affine, \ | 
|  | per_tensor_symmetric, per_channel_affine, \ | 
|  | per_channel_symmetric and per_channel_float_qparams quantization scheme" | 
|  | assert self.dtype in ( | 
|  | torch.qint8, | 
|  | torch.quint8, | 
|  | torch.quint4x2, | 
|  | ), "Default Observer only works for qint8, quint8 and quint4x2 data type" | 
|  | self.has_customized_qrange = (quant_min is not None) and (quant_max is not None) | 
|  | if self.has_customized_qrange: | 
|  | self._validate_qmin_qmax(quant_min, quant_max) | 
|  | self.quant_min, self.quant_max = \ | 
|  | calculate_qmin_qmax(quant_min, quant_max, self.has_customized_qrange, self.dtype, self.reduce_range) | 
|  |  | 
|  | def _load_from_state_dict( | 
|  | self, | 
|  | state_dict, | 
|  | prefix, | 
|  | local_metadata, | 
|  | strict, | 
|  | missing_keys, | 
|  | unexpected_keys, | 
|  | error_msgs, | 
|  | ): | 
|  |  | 
|  | version = local_metadata.get("version", None) | 
|  |  | 
|  | if version is None or version == 1: | 
|  | # eps was moved to a buffer in version 2 | 
|  | eps = torch.tensor([torch.finfo(torch.float32).eps]) | 
|  | state_dict[prefix + "eps"] = eps | 
|  |  | 
|  | super(ObserverBase, self)._load_from_state_dict( | 
|  | state_dict, | 
|  | prefix, | 
|  | local_metadata, | 
|  | strict, | 
|  | missing_keys, | 
|  | unexpected_keys, | 
|  | error_msgs, | 
|  | ) | 
|  |  | 
|  | @torch.jit.export | 
|  | def _validate_qmin_qmax(self, quant_min: int, quant_max: int) -> None: | 
|  | r"""Validates that the user-specified quantization range is properly initialized | 
|  | and within the given bound supported by the observer dtype. | 
|  |  | 
|  | To accommodate lower-bit quantization with respect to the existing torch.qint8 and | 
|  | torch.quint8 datatypes, the user can choose to use dynamic quantization range by passing | 
|  | in a tuple of initial qmin and qmax values. One use case is these customized qmin and qmax | 
|  | values are used to calculate static estimates of the scale and zero point for aggressive lower-bit | 
|  | fake quantization. These estimates are compared against parameters learned through backpropagation. | 
The related literature on learning scale and zero point via backpropagation is as follows:
|  |  | 
|  | Learned Step Size Quantization: https://openreview.net/pdf?id=rkgO66VKDS | 
|  | Trained Quantization Thresholds: https://arxiv.org/pdf/1903.08066.pdf | 
|  | """ | 
# The user-specified values are treated as initial ones, because qmin and qmax might later be adjusted
# based on whether the quantization range is reduced and the datatype (signed/unsigned) used by the observer.
|  | assert ( | 
|  | quant_min <= 0 <= quant_max | 
|  | ), "Used-specified quantization range must include 0." | 
|  | assert ( | 
|  | quant_min < quant_max | 
|  | ), "qmin must be strictly less than qmax for user-specified quantization range." | 
|  |  | 
|  | @torch.jit.export | 
|  | def _calculate_qparams( | 
|  | self, min_val: torch.Tensor, max_val: torch.Tensor | 
|  | ) -> Tuple[torch.Tensor, torch.Tensor]: | 
|  | r"""Calculates the quantization parameters, given min and max | 
|  | value tensors. Works for both per tensor and per channel cases | 
|  |  | 
|  | Args: | 
|  | min_val: Minimum values per channel | 
|  | max_val: Maximum values per channel | 
|  |  | 
|  | Returns: | 
|  | scales: Scales tensor of shape (#channels,) | 
|  | zero_points: Zero points tensor of shape (#channels,) | 
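
For example, with an affine scheme, ``quant_min=0``, ``quant_max=255``,
``min_val=-1.0`` and ``max_val=3.0`` (a worked illustration only), this gives
approximately ``scale = (3.0 - (-1.0)) / 255 ≈ 0.0157`` and
``zero_point = 0 - round(-1.0 / 0.0157) = 64``.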
|  | """ | 
|  | if not check_min_max_valid(min_val, max_val): | 
|  | return torch.tensor([1.0], device=min_val.device.type), torch.tensor([0], device=min_val.device.type) | 
|  |  | 
|  | quant_min, quant_max = self.quant_min, self.quant_max | 
|  | min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) | 
|  | max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) | 
|  |  | 
|  | device = min_val_neg.device | 
|  | scale = torch.ones(min_val_neg.size(), dtype=torch.float32, device=device) | 
|  | zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) | 
|  |  | 
|  | if ( | 
|  | self.qscheme == torch.per_tensor_symmetric | 
|  | or self.qscheme == torch.per_channel_symmetric | 
|  | ): | 
|  | max_val_pos = torch.max(-min_val_neg, max_val_pos) | 
|  | scale = max_val_pos / (float(quant_max - quant_min) / 2) | 
|  | scale = torch.max(scale, self.eps) | 
|  | if self.dtype == torch.quint8: | 
|  | if self.has_customized_qrange: | 
|  | # When customized quantization range is used, down-rounded midpoint of the range is chosen. | 
|  | zero_point = zero_point.new_full( | 
|  | zero_point.size(), (quant_min + quant_max) // 2 | 
|  | ) | 
|  | else: | 
|  | zero_point = zero_point.new_full(zero_point.size(), 128) | 
|  | elif self.qscheme == torch.per_channel_affine_float_qparams: | 
|  | scale = (max_val - min_val) / float(quant_max - quant_min) | 
|  | scale = torch.where(scale > self.eps, scale, torch.ones_like(scale)) | 
# We use the quantize function
# Xq = Round(Xf * inv_scale + zero_point);
# setting zero_point to (-1 * min * inv_scale) we get
# Xq = Round((Xf - min) * inv_scale)
|  | zero_point = -1 * min_val / scale | 
|  | else: | 
|  | scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min) | 
|  | scale = torch.max(scale, self.eps) | 
|  | zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int) | 
|  | zero_point = torch.clamp(zero_point, quant_min, quant_max) | 
|  |  | 
|  | # For scalar values, cast them to Tensors of size 1 to keep the shape | 
|  | # consistent with default values in FakeQuantize. | 
|  | if len(scale.shape) == 0: | 
|  | # TODO: switch to scale.item() after adding JIT support | 
|  | scale = torch.tensor([float(scale)], dtype=scale.dtype, device=device) | 
|  | if len(zero_point.shape) == 0: | 
|  | # TODO: switch to zero_point.item() after adding JIT support | 
|  | zero_point = torch.tensor( | 
|  | [int(zero_point)], dtype=zero_point.dtype, device=device | 
|  | ) | 
|  | if self.qscheme == torch.per_channel_affine_float_qparams: | 
|  | zero_point = torch.tensor( | 
|  | [float(zero_point)], dtype=zero_point.dtype, device=device | 
|  | ) | 
|  |  | 
|  | return scale, zero_point | 
|  |  | 
|  | @torch.jit.export | 
|  | def reset_min_max_vals(self): | 
|  | raise NotImplementedError("Cannot reset min/max values in the given observer.") | 
|  |  | 
|  |  | 
|  | class MinMaxObserver(_ObserverBase): | 
|  | r"""Observer module for computing the quantization parameters based on the | 
|  | running min and max values. | 
|  |  | 
|  | This observer uses the tensor min/max statistics to compute the quantization | 
|  | parameters. The module records the running minimum and maximum of incoming | 
|  | tensors, and uses this statistic to compute the quantization parameters. | 
|  |  | 
|  | Args: | 
|  | dtype: Quantized data type | 
|  | qscheme: Quantization scheme to be used | 
|  | reduce_range: Reduces the range of the quantized data type by 1 bit | 
|  | quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. | 
|  | quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup. | 
|  | memoryless: Boolean that controls whether observer removes old data when a new input is seen. | 
|  | This is most useful for simulating dynamic quantization, especially during QAT. | 
|  |  | 
|  | Given running min/max as :math:`x_\text{min}` and :math:`x_\text{max}`, | 
|  | scale :math:`s` and zero point :math:`z` are computed as: | 
|  |  | 
|  | The running minimum/maximum :math:`x_\text{min/max}` is computed as: | 
|  |  | 
|  | .. math:: | 
|  |  | 
|  | \begin{array}{ll} | 
|  | x_\text{min} &= \begin{cases} | 
|  | \min(X) & \text{if~}x_\text{min} = \text{None} \\ | 
|  | \min\left(x_\text{min}, \min(X)\right) & \text{otherwise} | 
|  | \end{cases}\\ | 
|  | x_\text{max} &= \begin{cases} | 
|  | \max(X) & \text{if~}x_\text{max} = \text{None} \\ | 
|  | \max\left(x_\text{max}, \max(X)\right) & \text{otherwise} | 
|  | \end{cases}\\ | 
|  | \end{array} | 
|  |  | 
|  | where :math:`X` is the observed tensor. | 
|  |  | 
|  | The scale :math:`s` and zero point :math:`z` are then computed as: | 
|  |  | 
|  | .. math:: | 
|  |  | 
|  | \begin{aligned} | 
|  | \text{if Symmetric:}&\\ | 
|  | &s = 2 \max(|x_\text{min}|, x_\text{max}) / | 
|  | \left( Q_\text{max} - Q_\text{min} \right) \\ | 
|  | &z = \begin{cases} | 
|  | 0 & \text{if dtype is qint8} \\ | 
|  | 128 & \text{otherwise} | 
|  | \end{cases}\\ | 
|  | \text{Otherwise:}&\\ | 
|  | &s = \left( x_\text{max} - x_\text{min}  \right ) / | 
|  | \left( Q_\text{max} - Q_\text{min} \right ) \\ | 
|  | &z = Q_\text{min} - \text{round}(x_\text{min} / s) | 
|  | \end{aligned} | 
|  |  | 
|  | where :math:`Q_\text{min}` and :math:`Q_\text{max}` are the minimum and | 
|  | maximum of the quantized data type. | 
|  |  | 
|  | .. warning:: Only works with ``torch.per_tensor_symmetric`` quantization scheme | 
|  |  | 
|  | .. warning:: :attr:`dtype` can only take ``torch.qint8`` or ``torch.quint8``. | 
|  |  | 
.. note:: If the running minimum equals the running maximum, the scale
and zero_point are set to 1.0 and 0.
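
Example (a minimal usage sketch; the calibration tensor is arbitrary and the
resulting quantization parameters depend on the observed values)::

>>> obs = MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine)
>>> _ = obs(torch.randn(2, 8))  # observe a calibration batch
>>> scale, zero_point = obs.calculate_qparams()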
|  | """ | 
|  | min_val: torch.Tensor | 
|  | max_val: torch.Tensor | 
|  |  | 
|  | def __init__( | 
|  | self, | 
|  | dtype=torch.quint8, | 
|  | qscheme=torch.per_tensor_affine, | 
|  | reduce_range=False, | 
|  | quant_min=None, | 
|  | quant_max=None, | 
|  | factory_kwargs=None, | 
|  | memoryless=False, | 
|  | ) -> None: | 
|  |  | 
|  | # For x86 quantized kernels, we need to ensure that the vpmaddubsw | 
|  | # instruction does not overflow. We allow for a reduce_range argument to | 
|  | # observers that reduces the quantized range to (0,127) or (-64, 63). | 
|  | # For more details see aten/src/ATen/native/quantized/cpu/qconv.cpp | 
|  | # This is not an optimal choice for non x86 backends as it loses a bit | 
|  | # of precision for activations. | 
|  | super(MinMaxObserver, self).__init__( | 
|  | dtype=dtype, | 
|  | qscheme=qscheme, | 
|  | reduce_range=reduce_range, | 
|  | quant_min=quant_min, | 
|  | quant_max=quant_max, | 
|  | factory_kwargs=factory_kwargs, | 
|  | ) | 
|  | self.memoryless = memoryless | 
|  | factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) | 
|  | self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs)) | 
|  | self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs)) | 
|  | if ( | 
|  | self.qscheme == torch.per_tensor_symmetric | 
|  | and self.reduce_range | 
|  | and self.dtype == torch.quint8 | 
|  | ): | 
|  | raise NotImplementedError( | 
|  | "Cannot reduce range for symmetric \ | 
|  | quantization for quint8" | 
|  | ) | 
|  |  | 
|  | def forward(self, x_orig): | 
|  | r"""Records the running minimum and maximum of ``x``.""" | 
|  | if x_orig.numel() == 0: | 
|  | return x_orig | 
|  | elif self.memoryless: | 
|  | self.reset_min_max_vals() | 
|  | x = x_orig.detach()  # avoid keeping autograd tape | 
|  | x = x.to(self.min_val.dtype) | 
|  | min_val_cur, max_val_cur = torch._aminmax(x) | 
|  | min_val = torch.min(min_val_cur, self.min_val) | 
|  | max_val = torch.max(max_val_cur, self.max_val) | 
|  | self.min_val.copy_(min_val) | 
|  | self.max_val.copy_(max_val) | 
|  | return x_orig | 
|  |  | 
|  | @torch.jit.export | 
|  | def calculate_qparams(self): | 
|  | r"""Calculates the quantization parameters.""" | 
|  | return self._calculate_qparams(self.min_val, self.max_val) | 
|  |  | 
|  | @torch.jit.export | 
|  | def extra_repr(self): | 
|  | return "min_val={}, max_val={}".format(self.min_val, self.max_val) | 
|  |  | 
|  | @torch.jit.export | 
|  | def reset_min_max_vals(self): | 
|  | """Resets the min/max values.""" | 
|  | self.min_val.copy_(torch.tensor(float("inf"))) | 
|  | self.max_val.copy_(torch.tensor(float("-inf"))) | 
|  |  | 
|  | class MovingAverageMinMaxObserver(MinMaxObserver): | 
|  | r"""Observer module for computing the quantization parameters based on the | 
|  | moving average of the min and max values. | 
|  |  | 
|  | This observer computes the quantization parameters based on the moving | 
|  | averages of minimums and maximums of the incoming tensors. The module | 
|  | records the average minimum and maximum of incoming tensors, and uses this | 
|  | statistic to compute the quantization parameters. | 
|  |  | 
|  | Args: | 
|  | averaging_constant: Averaging constant for min/max. | 
|  | dtype: Quantized data type | 
|  | qscheme: Quantization scheme to be used | 
|  | reduce_range: Reduces the range of the quantized data type by 1 bit | 
|  | quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. | 
|  | quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup. | 
|  |  | 
|  | The moving average min/max is computed as follows | 
|  |  | 
|  | .. math:: | 
|  |  | 
|  | \begin{array}{ll} | 
|  | x_\text{min} = \begin{cases} | 
|  | \min(X) & \text{if~}x_\text{min} = \text{None} \\ | 
|  | (1 - c) x_\text{min} + c \min(X) & \text{otherwise} | 
|  | \end{cases}\\ | 
|  | x_\text{max} = \begin{cases} | 
|  | \max(X) & \text{if~}x_\text{max} = \text{None} \\ | 
|  | (1 - c) x_\text{max} + c \max(X) & \text{otherwise} | 
|  | \end{cases}\\ | 
|  | \end{array} | 
|  |  | 
where :math:`x_\text{min/max}` is the running average min/max, :math:`X` is
the incoming tensor, and :math:`c` is the ``averaging_constant``.
|  |  | 
|  | The scale and zero point are then computed as in | 
|  | :class:`~torch.ao.quantization.observer.MinMaxObserver`. | 
|  |  | 
|  | .. note:: Only works with ``torch.per_tensor_affine`` quantization scheme. | 
|  |  | 
.. note:: If the running minimum equals the running maximum, the scale
and zero_point are set to 1.0 and 0.
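
For example, with ``averaging_constant=0.5`` (a worked illustration; the
default is 0.01), a running max of 10.0 and a new batch max of 20.0 give an
updated running max of 10.0 + 0.5 * (20.0 - 10.0) = 15.0.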
|  | """ | 
|  |  | 
|  | def __init__( | 
|  | self, | 
|  | averaging_constant=0.01, | 
|  | dtype=torch.quint8, | 
|  | qscheme=torch.per_tensor_affine, | 
|  | reduce_range=False, | 
|  | quant_min=None, | 
|  | quant_max=None, | 
|  | **kwargs | 
|  | ) -> None: | 
|  | self.averaging_constant = averaging_constant | 
|  | super(MovingAverageMinMaxObserver, self).__init__( | 
|  | dtype=dtype, | 
|  | qscheme=qscheme, | 
|  | reduce_range=reduce_range, | 
|  | quant_min=quant_min, | 
|  | quant_max=quant_max, | 
|  | **kwargs | 
|  | ) | 
|  |  | 
|  | def forward(self, x_orig): | 
|  | if x_orig.numel() == 0: | 
|  | return x_orig | 
|  | x = x_orig.detach()  # avoid keeping autograd tape | 
|  | x = x.to(self.min_val.dtype) | 
|  | min_val = self.min_val | 
|  | max_val = self.max_val | 
|  | if min_val == float("inf") and max_val == float("-inf"): | 
|  | min_val, max_val = torch._aminmax(x) | 
|  | else: | 
|  | min_val_cur, max_val_cur = torch._aminmax(x) | 
|  | min_val = min_val + self.averaging_constant * (min_val_cur - min_val) | 
|  | max_val = max_val + self.averaging_constant * (max_val_cur - max_val) | 
|  | self.min_val.copy_(min_val) | 
|  | self.max_val.copy_(max_val) | 
|  | return x_orig | 
|  |  | 
|  |  | 
|  | class PerChannelMinMaxObserver(_ObserverBase): | 
|  | r"""Observer module for computing the quantization parameters based on the | 
|  | running per channel min and max values. | 
|  |  | 
|  | This observer uses the tensor min/max statistics to compute the per channel | 
|  | quantization parameters. The module records the running minimum and maximum | 
|  | of incoming tensors, and uses this statistic to compute the quantization | 
|  | parameters. | 
|  |  | 
|  | Args: | 
|  | ch_axis: Channel axis | 
|  | dtype: Quantized data type | 
|  | qscheme: Quantization scheme to be used | 
|  | reduce_range: Reduces the range of the quantized data type by 1 bit | 
|  | quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. | 
|  | quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup. | 
|  | memoryless: Boolean that controls whether observer removes old data when a new input is seen. | 
|  | This is most useful for simulating dynamic quantization, especially during QAT. | 
|  |  | 
|  | The quantization parameters are computed the same way as in | 
|  | :class:`~torch.ao.quantization.observer.MinMaxObserver`, with the difference | 
|  | that the running min/max values are stored per channel. | 
|  | Scales and zero points are thus computed per channel as well. | 
|  |  | 
.. note:: If the running minimum equals the running maximum, the scales
and zero_points are set to 1.0 and 0.
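
Example (a minimal usage sketch; the tensor below stands in for a conv weight
with 4 output channels, and ``ch_axis=0`` assumes channels come first)::

>>> obs = PerChannelMinMaxObserver(ch_axis=0, dtype=torch.qint8,
...                                qscheme=torch.per_channel_symmetric)
>>> _ = obs(torch.randn(4, 3, 3, 3))
>>> scales, zero_points = obs.calculate_qparams()
>>> scales.shape
torch.Size([4])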
|  | """ | 
|  | min_val: torch.Tensor | 
|  | max_val: torch.Tensor | 
|  |  | 
|  | def __init__( | 
|  | self, | 
|  | ch_axis=0, | 
|  | dtype=torch.quint8, | 
|  | qscheme=torch.per_channel_affine, | 
|  | reduce_range=False, | 
|  | quant_min=None, | 
|  | quant_max=None, | 
|  | factory_kwargs=None, | 
|  | memoryless=False, | 
|  | ) -> None: | 
|  | super(PerChannelMinMaxObserver, self).__init__( | 
|  | dtype=dtype, | 
|  | qscheme=qscheme, | 
|  | reduce_range=reduce_range, | 
|  | quant_min=quant_min, | 
|  | quant_max=quant_max, | 
|  | factory_kwargs=factory_kwargs, | 
|  | ) | 
|  | self.memoryless = memoryless | 
|  | factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) | 
|  | self.ch_axis = ch_axis | 
|  | self.register_buffer("min_val", torch.tensor([], **factory_kwargs)) | 
|  | self.register_buffer("max_val", torch.tensor([], **factory_kwargs)) | 
|  | if ( | 
|  | self.qscheme == torch.per_channel_symmetric | 
|  | and self.reduce_range | 
|  | and self.dtype == torch.quint8 | 
|  | ): | 
|  | raise NotImplementedError( | 
|  | "Cannot reduce range for symmetric quantization for quint8" | 
|  | ) | 
|  |  | 
|  | def forward(self, x_orig): | 
|  | return self._forward(x_orig) | 
|  |  | 
|  | def _forward(self, x_orig): | 
|  | if x_orig.numel() == 0: | 
|  | return x_orig | 
|  | x = x_orig.detach()  # avoid keeping autograd tape | 
|  | min_val = self.min_val | 
|  | max_val = self.max_val | 
|  | x_dim = x.size() | 
|  |  | 
|  | new_axis_list = [i for i in range(len(x_dim))]  # noqa: C416 | 
|  | new_axis_list[self.ch_axis] = 0 | 
|  | new_axis_list[0] = self.ch_axis | 
|  | y = x.permute(new_axis_list) | 
|  | # Need to match dtype of min/max because the updates to buffers | 
|  | # are done in place and types need to match for comparisons | 
|  | y = y.to(self.min_val.dtype) | 
|  | y = torch.flatten(y, start_dim=1) | 
|  | if min_val.numel() == 0 or max_val.numel() == 0 or self.memoryless: | 
|  | min_val, max_val = torch._aminmax(y, 1) | 
|  | else: | 
|  | min_val_cur, max_val_cur = torch._aminmax(y, 1) | 
|  | min_val = torch.min(min_val_cur, min_val) | 
|  | max_val = torch.max(max_val_cur, max_val) | 
|  | self.min_val.resize_(min_val.shape) | 
|  | self.max_val.resize_(max_val.shape) | 
|  | self.min_val.copy_(min_val) | 
|  | self.max_val.copy_(max_val) | 
|  | return x_orig | 
|  |  | 
|  | @torch.jit.export | 
|  | def calculate_qparams(self): | 
|  | return self._calculate_qparams(self.min_val, self.max_val) | 
|  |  | 
|  | def extra_repr(self): | 
|  | return "min_val={}, max_val={}".format(self.min_val, self.max_val) | 
|  |  | 
|  | def _load_from_state_dict( | 
|  | self, | 
state_dict: Dict[str, torch.Tensor],
|  | prefix: str, | 
|  | local_metadata: Dict[str, torch.Tensor], | 
|  | strict: bool, | 
|  | missing_keys: List[str], | 
|  | unexpected_keys: List[str], | 
|  | error_msgs: List[str], | 
|  | ): | 
|  | version = local_metadata.get("version", None) | 
|  | if version is None or version < 3: | 
|  | local_state = ["min_vals", "max_vals"] | 
|  | expected_min_name = "min_vals" | 
|  | expected_max_name = "max_vals" | 
|  | else: | 
|  | local_state = ["min_val", "max_val"] | 
|  | expected_min_name = "min_val" | 
|  | expected_max_name = "max_val" | 
|  | for name in local_state: | 
|  | key = prefix + name | 
|  | if key in state_dict: | 
|  | val = state_dict[key] | 
|  | # Custom handling to allow loading min_val or max_val | 
|  | # of size N into uninitialized buffers of size 0. The | 
|  | # buffers are resized here, and the values are copied in | 
|  | # the default state_dict loading code of the parent. | 
|  | if name == expected_min_name: | 
|  | self.min_val.resize_(val.shape) | 
|  | elif name == expected_max_name: | 
|  | self.max_val.resize_(val.shape) | 
|  | else: | 
|  | warnings.warn("Observer load_from_state_dict got unexpected name {}".format(name)) | 
|  | # For torchscript module we need to update the attributes here since we do not | 
# call the `_load_from_state_dict` function defined in module.py
|  | if torch.jit.is_scripting(): | 
|  | if name == expected_min_name: | 
|  | self.min_val.copy_(val) | 
|  | elif name == expected_max_name: | 
|  | self.max_val.copy_(val) | 
|  | else: | 
|  | warnings.warn("Observer load_from_state_dict got unexpected name {}".format(name)) | 
|  | elif strict: | 
|  | missing_keys.append(key) | 
|  |  | 
|  | if not torch.jit.is_scripting(): | 
|  | super(PerChannelMinMaxObserver, self)._load_from_state_dict( | 
|  | state_dict, | 
|  | prefix, | 
|  | local_metadata, | 
|  | False, | 
|  | missing_keys, | 
|  | unexpected_keys, | 
|  | error_msgs, | 
|  | ) | 
|  |  | 
|  | def _load_from_state_dict_script( | 
|  | self, | 
state_dict: Dict[str, torch.Tensor],
|  | prefix: str, | 
|  | local_metadata: Dict[str, torch.Tensor], | 
|  | strict: bool, | 
|  | missing_keys: List[str], | 
|  | unexpected_keys: List[str], | 
|  | error_msgs: List[str], | 
|  | ): | 
|  |  | 
|  | self._load_from_state_dict( | 
|  | state_dict, | 
|  | prefix, | 
|  | local_metadata, | 
|  | strict, | 
|  | missing_keys, | 
|  | unexpected_keys, | 
|  | error_msgs, | 
|  | ) | 
|  |  | 
|  | @torch.jit.export | 
|  | def reset_min_max_vals(self): | 
|  | """Resets the min/max values.""" | 
|  | self.min_val = torch.tensor([]) | 
|  | self.max_val = torch.tensor([]) | 
|  |  | 
|  |  | 
|  | class MovingAveragePerChannelMinMaxObserver(PerChannelMinMaxObserver): | 
|  | r"""Observer module for computing the quantization parameters based on the | 
|  | running per channel min and max values. | 
|  |  | 
|  | This observer uses the tensor min/max statistics to compute the per channel | 
|  | quantization parameters. The module records the running minimum and maximum | 
|  | of incoming tensors, and uses this statistic to compute the quantization | 
|  | parameters. | 
|  |  | 
|  | Args: | 
|  | averaging_constant: Averaging constant for min/max. | 
|  | ch_axis: Channel axis | 
|  | dtype: Quantized data type | 
|  | qscheme: Quantization scheme to be used | 
|  | reduce_range: Reduces the range of the quantized data type by 1 bit | 
|  | quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. | 
|  | quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup. | 
|  |  | 
|  | The quantization parameters are computed the same way as in | 
|  | :class:`~torch.ao.quantization.observer.MovingAverageMinMaxObserver`, with the | 
|  | difference that the running min/max values are stored per channel. | 
|  | Scales and zero points are thus computed per channel as well. | 
|  |  | 
.. note:: If the running minimum equals the running maximum, the scales
and zero_points are set to 1.0 and 0.
|  | """ | 
|  |  | 
|  | def __init__( | 
|  | self, | 
|  | averaging_constant=0.01, | 
|  | ch_axis=0, | 
|  | dtype=torch.quint8, | 
|  | qscheme=torch.per_channel_affine, | 
|  | reduce_range=False, | 
|  | quant_min=None, | 
|  | quant_max=None, | 
|  | **kwargs | 
|  | ) -> None: | 
|  | super(MovingAveragePerChannelMinMaxObserver, self).__init__( | 
|  | ch_axis=ch_axis, | 
|  | dtype=dtype, | 
|  | qscheme=qscheme, | 
|  | reduce_range=reduce_range, | 
|  | quant_min=quant_min, | 
|  | quant_max=quant_max, | 
|  | **kwargs | 
|  | ) | 
|  | self.averaging_constant = averaging_constant | 
|  |  | 
|  | def forward(self, x_orig): | 
|  | if x_orig.numel() == 0: | 
|  | return x_orig | 
|  | x = x_orig.detach()  # avoid keeping autograd tape | 
|  | x = x.to(self.min_val.dtype) | 
|  | min_val = self.min_val | 
|  | max_val = self.max_val | 
|  | x_dim = x.size() | 
|  |  | 
|  | new_axis_list = [i for i in range(len(x_dim))]  # noqa: C416 | 
|  | new_axis_list[self.ch_axis] = 0 | 
|  | new_axis_list[0] = self.ch_axis | 
|  | y = x.permute(new_axis_list) | 
|  | y = torch.flatten(y, start_dim=1) | 
|  | if min_val.numel() == 0 or max_val.numel() == 0: | 
|  | min_val, max_val = torch._aminmax(y, 1) | 
|  | else: | 
|  | min_val_cur, max_val_cur = torch._aminmax(y, 1) | 
|  | min_val = min_val + self.averaging_constant * (min_val_cur - min_val) | 
|  | max_val = max_val + self.averaging_constant * (max_val_cur - max_val) | 
|  | self.min_val.resize_(min_val.shape) | 
|  | self.max_val.resize_(max_val.shape) | 
|  | self.min_val.copy_(min_val) | 
|  | self.max_val.copy_(max_val) | 
|  | return x_orig | 
|  |  | 
|  |  | 
|  | class HistogramObserver(_ObserverBase): | 
|  | r""" | 
|  | The module records the running histogram of tensor values along with | 
|  | min/max values. ``calculate_qparams`` will calculate scale and zero_point. | 
|  |  | 
|  | Args: | 
|  | bins: Number of bins to use for the histogram | 
upsample_rate: Factor by which the histograms are upsampled; this is
used to interpolate histograms with varying ranges across observations
|  | dtype: Quantized data type | 
|  | qscheme: Quantization scheme to be used | 
|  | reduce_range: Reduces the range of the quantized data type by 1 bit | 
|  |  | 
|  | The scale and zero point are computed as follows: | 
|  |  | 
|  | 1. Create the histogram of the incoming inputs. | 
|  | The histogram is computed continuously, and the ranges per bin change | 
|  | with every new tensor observed. | 
|  | 2. Search the distribution in the histogram for optimal min/max values. | 
|  | The search for the min/max values ensures the minimization of the | 
|  | quantization error with respect to the floating point model. | 
|  | 3. Compute the scale and zero point the same way as in the | 
|  | :class:`~torch.ao.quantization.MinMaxObserver` | 
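
Example (a minimal usage sketch; in practice the observer is fed many
calibration batches before ``calculate_qparams`` is called)::

>>> obs = HistogramObserver(bins=2048, reduce_range=True)
>>> for _ in range(4):  # illustrative calibration loop
...     _ = obs(torch.randn(16, 8))
>>> scale, zero_point = obs.calculate_qparams()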
|  | """ | 
|  | histogram: torch.Tensor | 
|  | min_val: torch.Tensor | 
|  | max_val: torch.Tensor | 
|  |  | 
|  | def __init__( | 
|  | self, | 
|  | bins: int = 2048, | 
|  | upsample_rate: int = 128, | 
|  | dtype: torch.dtype = torch.quint8, | 
|  | qscheme=torch.per_tensor_affine, | 
|  | reduce_range=False, | 
|  | factory_kwargs=None, | 
|  | ) -> None: | 
|  | # bins: The number of bins used for histogram calculation. | 
|  | super(HistogramObserver, self).__init__( | 
|  | dtype=dtype, | 
|  | qscheme=qscheme, | 
|  | reduce_range=reduce_range, | 
|  | factory_kwargs=factory_kwargs, | 
|  | ) | 
|  | factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) | 
|  | self.bins = bins | 
|  | self.register_buffer("histogram", torch.zeros(self.bins, **factory_kwargs)) | 
|  | self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs)) | 
|  | self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs)) | 
|  | self.dst_nbins = 2 ** torch.iinfo(self.dtype).bits | 
|  | self.upsample_rate = upsample_rate | 
|  |  | 
|  | def _get_norm( | 
|  | self, delta_begin: torch.Tensor, delta_end: torch.Tensor, density: torch.Tensor | 
|  | ) -> torch.Tensor: | 
|  | r""" | 
Compute the norm of the values uniformly distributed between
|  | delta_begin and delta_end. | 
|  | Currently only L2 norm is supported. | 
|  |  | 
|  | norm = density * (integral_{begin, end} x^2) | 
|  | = density * (end^3 - begin^3) / 3 | 
|  | """ | 
|  | norm = ( | 
|  | delta_end * delta_end * delta_end - delta_begin * delta_begin * delta_begin | 
|  | ) / 3 | 
|  | return density * norm | 
|  |  | 
|  | def _compute_quantization_error(self, next_start_bin: int, next_end_bin: int): | 
|  | r""" | 
|  | Compute the quantization error if we use start_bin to end_bin as the | 
|  | min and max to do the quantization. | 
|  | """ | 
|  | bin_width = (self.max_val.item() - self.min_val.item()) / self.bins | 
|  |  | 
|  | dst_bin_width = bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins | 
|  | if dst_bin_width == 0.0: | 
|  | return 0.0 | 
|  |  | 
|  | src_bin = torch.arange(self.bins, device=self.histogram.device) | 
|  | # distances from the beginning of first dst_bin to the beginning and | 
|  | # end of src_bin | 
|  | src_bin_begin = (src_bin - next_start_bin) * bin_width | 
|  | src_bin_end = src_bin_begin + bin_width | 
|  |  | 
|  | # which dst_bins the beginning and end of src_bin belong to? | 
|  | dst_bin_of_begin = torch.clamp( | 
|  | src_bin_begin // dst_bin_width, 0, self.dst_nbins - 1 | 
|  | ) | 
|  | dst_bin_of_begin_center = (dst_bin_of_begin + 0.5) * dst_bin_width | 
|  |  | 
|  | dst_bin_of_end = torch.clamp( | 
|  | src_bin_end // dst_bin_width, 0, self.dst_nbins - 1 | 
|  | ) | 
|  | dst_bin_of_end_center = (dst_bin_of_end + 0.5) * dst_bin_width | 
|  |  | 
|  | density = self.histogram / bin_width | 
|  |  | 
|  | norm = torch.zeros(self.bins, device=self.histogram.device) | 
|  |  | 
|  | delta_begin = src_bin_begin - dst_bin_of_begin_center | 
|  | delta_end = dst_bin_width / 2 | 
|  | norm += self._get_norm(delta_begin, | 
|  | torch.ones(self.bins, device=self.histogram.device) * delta_end, | 
|  | density) | 
|  |  | 
|  | norm += (dst_bin_of_end - dst_bin_of_begin - 1) * self._get_norm( | 
|  | torch.tensor(-dst_bin_width / 2), torch.tensor(dst_bin_width / 2), density | 
|  | ) | 
|  |  | 
|  | dst_bin_of_end_center = dst_bin_of_end * dst_bin_width + dst_bin_width / 2 | 
|  |  | 
|  | delta_begin = -dst_bin_width / 2 | 
|  | delta_end = src_bin_end - dst_bin_of_end_center | 
|  | norm += self._get_norm(torch.tensor(delta_begin), delta_end, density) | 
|  |  | 
|  | return norm.sum().item() | 
|  |  | 
|  | def _non_linear_param_search(self) -> Tuple[torch.Tensor, torch.Tensor]: | 
|  | r"""Non-linear parameter search. | 
|  |  | 
|  | An approximation for L2 error minimization for selecting min/max. | 
|  | By selecting new min/max, we filter out outliers in input distribution. | 
|  | This follows the implementation of NormMinimization::NonlinearQuantizationParamsSearch in | 
|  | caffe2/quantization/server/norm_minimization.cc | 
|  | """ | 
assert self.histogram.size()[0] == self.bins, "bins mismatch"
|  | bin_width = (self.max_val - self.min_val) / self.bins | 
|  |  | 
|  | # cumulative sum | 
|  | total = torch.sum(self.histogram).item() | 
|  | cSum = torch.cumsum(self.histogram, dim=0) | 
|  |  | 
|  | stepsize = 1e-5  # granularity | 
|  | alpha = 0.0  # lower bound | 
|  | beta = 1.0  # upper bound | 
|  | start_bin = 0 | 
|  | end_bin = self.bins - 1 | 
|  | norm_min = float("inf") | 
|  |  | 
|  | while alpha < beta: | 
|  | # Find the next step | 
|  | next_alpha = alpha + stepsize | 
|  | next_beta = beta - stepsize | 
|  |  | 
|  | # find the left and right bins between the quantile bounds | 
|  | l = start_bin | 
|  | r = end_bin | 
|  | while l < end_bin and cSum[l] < next_alpha * total: | 
|  | l = l + 1 | 
|  | while r > start_bin and cSum[r] > next_beta * total: | 
|  | r = r - 1 | 
|  |  | 
|  | # decide the next move | 
|  | next_start_bin = start_bin | 
|  | next_end_bin = end_bin | 
|  | if (l - start_bin) > (end_bin - r): | 
|  | # move the start bin | 
|  | next_start_bin = l | 
|  | alpha = next_alpha | 
|  | else: | 
|  | # move the end bin | 
|  | next_end_bin = r | 
|  | beta = next_beta | 
|  |  | 
|  | if next_start_bin == start_bin and next_end_bin == end_bin: | 
|  | continue | 
|  |  | 
|  | # calculate the quantization error using next_start_bin and next_end_bin | 
|  | norm = self._compute_quantization_error(next_start_bin, next_end_bin) | 
|  |  | 
|  | if norm > norm_min: | 
|  | break | 
|  | norm_min = norm | 
|  | start_bin = next_start_bin | 
|  | end_bin = next_end_bin | 
|  |  | 
|  | new_min = self.min_val + bin_width * start_bin | 
|  | new_max = self.min_val + bin_width * (end_bin + 1) | 
|  | return new_min, new_max | 
|  |  | 
|  | def _adjust_min_max( | 
|  | self, combined_min: torch.Tensor, combined_max: torch.Tensor, upsample_rate: int | 
|  | ) -> Tuple[torch.Tensor, torch.Tensor, int, int]: | 
|  | # We ensure that: | 
|  | # (combined_max - combined_min)/(downsample_rate*Nbins) = (max - min)/(upsample_rate*Nbins) | 
|  | # This allows us to have a common grid of resolution s, where we can align | 
|  | # the input histogram | 
|  | # start_idx maps min_val to the histogram bin index. | 
|  |  | 
|  | hist_bin_width = (self.max_val - self.min_val) / (self.bins * upsample_rate) | 
|  | downsample_rate = int( | 
|  | torch.ceil( | 
|  | (combined_max - combined_min) / (self.bins * hist_bin_width) | 
|  | ).item() | 
|  | ) | 
|  | e = downsample_rate * (self.bins * hist_bin_width) - ( | 
|  | combined_max - combined_min | 
|  | ) | 
|  | # Relax only the max, not the min, so that for one sided distributions, min stays at zero | 
|  | combined_max = combined_max + e | 
|  | combined_min = combined_min | 
|  | start_idx = int( | 
|  | torch.round((self.min_val - combined_min) / hist_bin_width).item() | 
|  | ) | 
|  | return combined_min, combined_max, downsample_rate, start_idx | 
|  |  | 
|  | def _combine_histograms( | 
|  | self, | 
|  | orig_hist: torch.Tensor, | 
|  | new_hist: torch.Tensor, | 
|  | upsample_rate: int, | 
|  | downsample_rate: int, | 
|  | start_idx: int, | 
|  | Nbins: int, | 
|  | ) -> torch.Tensor: | 
# First up-sample the histogram with new data by a factor of upsample_rate.
# This creates an approximate probability density that is piecewise constant.
|  | upsampled_histogram = new_hist.repeat_interleave(upsample_rate) | 
|  | # Now insert the upsampled histogram into the output | 
|  | # histogram, which is initialized with zeros. | 
|  | # The offset at which the histogram is introduced is determined | 
|  | # by the start index as the output histogram can cover a wider range | 
|  | histogram_with_output_range = torch.zeros( | 
|  | (Nbins * downsample_rate), device=orig_hist.device | 
|  | ) | 
|  | histogram_with_output_range[ | 
|  | start_idx : Nbins * upsample_rate + start_idx | 
|  | ] = upsampled_histogram | 
# Compute the integral histogram; double precision is needed to ensure
# that there are no overflows
|  | integral_histogram = torch.cumsum( | 
|  | histogram_with_output_range, 0, dtype=torch.double | 
|  | )[downsample_rate - 1 :: downsample_rate] | 
|  | # Finally perform interpolation | 
|  | shifted_integral_histogram = torch.zeros((Nbins), device=orig_hist.device) | 
|  | shifted_integral_histogram[1:Nbins] = integral_histogram[0:-1] | 
|  | interpolated_histogram = ( | 
|  | integral_histogram - shifted_integral_histogram | 
|  | ) / upsample_rate | 
|  | orig_hist = orig_hist + interpolated_histogram.to(torch.float) | 
|  | return orig_hist | 
|  |  | 
|  | def forward(self, x_orig: torch.Tensor) -> torch.Tensor: | 
|  | if x_orig.numel() == 0: | 
|  | return x_orig | 
|  | x = x_orig.detach() | 
|  | min_val = self.min_val | 
|  | max_val = self.max_val | 
|  | same_values = min_val.item() == max_val.item() | 
|  | is_uninitialized = min_val == float("inf") and max_val == float("-inf") | 
|  | if is_uninitialized or same_values: | 
|  | min_val, max_val = torch._aminmax(x) | 
|  | self.min_val.resize_(min_val.shape) | 
|  | self.min_val.copy_(min_val) | 
|  | self.max_val.resize_(max_val.shape) | 
|  | self.max_val.copy_(max_val) | 
|  | assert ( | 
|  | min_val.numel() == 1 and max_val.numel() == 1 | 
|  | ), "histogram min/max values must be scalar." | 
|  | torch.histc( | 
|  | x, self.bins, min=int(min_val), max=int(max_val), out=self.histogram | 
|  | ) | 
|  | else: | 
|  | new_min, new_max = torch._aminmax(x) | 
|  | combined_min = torch.min(new_min, min_val) | 
|  | combined_max = torch.max(new_max, max_val) | 
|  | # combine the existing histogram and new histogram into 1 histogram | 
|  | # We do this by first upsampling the histogram to a dense grid | 
|  | # and then downsampling the histogram efficiently | 
|  | ( | 
|  | combined_min, | 
|  | combined_max, | 
|  | downsample_rate, | 
|  | start_idx, | 
|  | ) = self._adjust_min_max(combined_min, combined_max, self.upsample_rate) | 
|  | assert ( | 
|  | combined_min.numel() == 1 and combined_max.numel() == 1 | 
|  | ), "histogram min/max values must be scalar." | 
|  | combined_histogram = torch.histc( | 
|  | x, self.bins, min=int(combined_min), max=int(combined_max) | 
|  | ) | 
|  | if combined_min == min_val and combined_max == max_val: | 
|  | combined_histogram += self.histogram | 
|  | else: | 
|  | combined_histogram = self._combine_histograms( | 
|  | combined_histogram, | 
|  | self.histogram, | 
|  | self.upsample_rate, | 
|  | downsample_rate, | 
|  | start_idx, | 
|  | self.bins, | 
|  | ) | 
|  |  | 
|  | self.histogram.detach_().resize_(combined_histogram.shape) | 
|  | self.histogram.copy_(combined_histogram) | 
|  | self.min_val.detach_().resize_(combined_min.shape) | 
|  | self.min_val.copy_(combined_min) | 
|  | self.max_val.detach_().resize_(combined_max.shape) | 
|  | self.max_val.copy_(combined_max) | 
|  | return x_orig | 
|  |  | 
|  | @torch.jit.export | 
|  | def calculate_qparams(self): | 
|  | is_uninitialized = self.min_val == float("inf") and self.max_val == float( | 
|  | "-inf" | 
|  | ) | 
|  | if is_uninitialized: | 
|  | warnings.warn( | 
|  | "must run observer before calling calculate_qparams.\ | 
|  | Returning default scale and zero point " | 
|  | ) | 
|  | return torch.tensor([1.0], device=self.min_val.device.type), torch.tensor([0], device=self.min_val.device.type) | 
|  | assert self.bins == len(self.histogram), ( | 
|  | "The number of bins in histogram should be equal to the number of bins " | 
|  | "supplied while making this observer" | 
|  | ) | 
|  |  | 
|  | new_min, new_max = self._non_linear_param_search() | 
|  |  | 
|  | return self._calculate_qparams(new_min, new_max) | 
|  |  | 
|  | def _save_to_state_dict(self, destination, prefix, keep_vars): | 
|  | super(HistogramObserver, self)._save_to_state_dict( | 
|  | destination, prefix, keep_vars | 
|  | ) | 
|  | destination[prefix + "min_val"] = self.min_val | 
|  | destination[prefix + "max_val"] = self.max_val | 
|  |  | 
|  | def _load_from_state_dict( | 
|  | self, | 
|  | state_dict, | 
|  | prefix, | 
|  | local_metadata, | 
|  | strict, | 
|  | missing_keys, | 
|  | unexpected_keys, | 
|  | error_msgs, | 
|  | ): | 
|  | version = local_metadata.get("version", None) | 
|  |  | 
|  | if version is None or version < 3: | 
|  | # if min_val and max_val are not initialized, update their shape | 
|  | # to account for the differences between v2 and v3 | 
|  | min_val_name, max_val_name = prefix + "min_val", prefix + "max_val" | 
|  | if min_val_name in state_dict: | 
|  | if state_dict[min_val_name].shape == torch.Size([0]): | 
|  | state_dict[min_val_name] = torch.tensor(float("inf")) | 
|  | if max_val_name in state_dict: | 
|  | if state_dict[max_val_name].shape == torch.Size([0]): | 
|  | state_dict[max_val_name] = torch.tensor(float("-inf")) | 
|  |  | 
|  | local_state = ["min_val", "max_val"] | 
|  | for name in local_state: | 
|  | key = prefix + name | 
|  | if key in state_dict: | 
|  | val = state_dict[key] | 
|  | setattr(self, name, val) | 
|  | elif strict: | 
|  | missing_keys.append(key) | 
|  | super(HistogramObserver, self)._load_from_state_dict( | 
|  | state_dict, | 
|  | prefix, | 
|  | local_metadata, | 
|  | strict, | 
|  | missing_keys, | 
|  | unexpected_keys, | 
|  | error_msgs, | 
|  | ) | 
|  |  | 
|  |  | 
|  | class PlaceholderObserver(ObserverBase): | 
|  | r""" | 
|  | Observer that doesn't do anything and just passes its configuration to the | 
|  | quantized module's ``.from_float()``. | 
|  |  | 
|  | Can be used for quantization to float16 which doesn't require determining | 
|  | ranges. | 
|  |  | 
|  | Args: | 
|  | dtype: Quantized data type | 
|  | custom_op_name: (temporary) specify this observer for an operator that doesn't require any observation | 
|  | (Can be used in Graph Mode Passes for special case ops). | 
|  | """ | 
|  |  | 
|  | def __init__( | 
|  | self, dtype=torch.float32, custom_op_name="", compute_dtype=None | 
|  | ) -> None: | 
|  | super(PlaceholderObserver, self).__init__(dtype=dtype) | 
|  | # dtype of input of the target operator, e.g. for dynamic quantization | 
|  | # ops, the dtype will be float32 | 
|  | self.dtype = dtype | 
|  | self.custom_op = custom_op_name | 
|  | # used for configuration of computation type for dynamic quantization | 
|  | if compute_dtype: | 
|  | self.compute_dtype = compute_dtype | 
|  |  | 
|  | def forward(self, x): | 
|  | return x | 
|  |  | 
|  | @torch.jit.export | 
|  | def calculate_qparams(self): | 
|  | raise Exception( | 
|  | "calculate_qparams should not be called for PlaceholderObserver" | 
|  | ) | 
|  |  | 
|  |  | 
|  | class RecordingObserver(_ObserverBase): | 
|  | r""" | 
The module is mainly for debugging and records the tensor values during runtime.
|  |  | 
|  | Args: | 
|  | dtype: Quantized data type | 
|  | qscheme: Quantization scheme to be used | 
|  | reduce_range: Reduces the range of the quantized data type by 1 bit | 
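
Example (a minimal debugging sketch; the recorded values are simply clones of
the inputs seen so far)::

>>> obs = RecordingObserver(dtype=torch.quint8)
>>> _ = obs(torch.tensor([1.0, 2.0]))
>>> _ = obs(torch.tensor([3.0]))
>>> len(obs.get_tensor_value())
2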
|  | """ | 
|  | __annotations__ = {"tensor_val": List[Optional[torch.Tensor]]} | 
|  |  | 
|  | def __init__(self, **kwargs): | 
|  | super(RecordingObserver, self).__init__(**kwargs) | 
|  | self.tensor_val = [] | 
|  |  | 
|  | def forward(self, x): | 
|  | self.tensor_val.append(x.clone()) | 
|  | return x | 
|  |  | 
|  | @torch.jit.export | 
|  | def calculate_qparams(self): | 
|  | raise Exception("calculate_qparams should not be called for RecordingObserver") | 
|  |  | 
|  | @torch.jit.export | 
|  | def get_tensor_value(self): | 
|  | return self.tensor_val | 
|  |  | 
|  |  | 
|  | class NoopObserver(ObserverBase): | 
|  | r""" | 
|  | Observer that doesn't do anything and just passes its configuration to the | 
|  | quantized module's ``.from_float()``. | 
|  |  | 
|  | Primarily used for quantization to float16 which doesn't require determining | 
|  | ranges. | 
|  |  | 
|  | Args: | 
|  | dtype: Quantized data type | 
|  | custom_op_name: (temporary) specify this observer for an operator that doesn't require any observation | 
|  | (Can be used in Graph Mode Passes for special case ops). | 
|  | """ | 
|  |  | 
|  | def __init__(self, dtype=torch.float16, custom_op_name="") -> None: | 
|  | super(NoopObserver, self).__init__(dtype=dtype) | 
|  | self.dtype = dtype | 
|  | self.custom_op = custom_op_name | 
|  |  | 
|  | def forward(self, x): | 
|  | return x | 
|  |  | 
|  | @torch.jit.export | 
|  | def calculate_qparams(self): | 
|  | raise Exception("calculate_qparams should not be called for NoopObserver") | 
|  |  | 
|  |  | 
|  | def _is_observer_script_module(mod, obs_type_name): | 
|  | """Returns true if given mod is an instance of Observer script module.""" | 
|  | if isinstance(mod, torch.jit.RecursiveScriptModule): | 
|  | # qualified name looks like '__torch__.torch.ao.quantization.observer.___torch_mangle_2.MinMaxObserver' | 
|  | suffix = mod._c.qualified_name.split(".", 1)[1] | 
|  | name = re.sub(r"\.___torch_mangle_\d+", "", suffix) | 
|  | return obs_type_name in name | 
|  | return False | 
|  |  | 
|  |  | 
|  | def _is_activation_post_process(module): | 
|  | return ( | 
|  | isinstance(module, torch.ao.quantization.ObserverBase) | 
|  | or isinstance(module, torch.ao.quantization.FakeQuantize) | 
|  | or _is_observer_script_module(module, "quantization.observer") | 
|  | ) | 
|  |  | 
|  |  | 
|  | def _is_per_channel_script_obs_instance(module): | 
|  | if isinstance(module, torch.jit.RecursiveScriptModule): | 
|  | return _is_observer_script_module( | 
|  | module, "quantization.observer.PerChannelMinMaxObserver" | 
|  | ) or _is_observer_script_module( | 
|  | module, "quantization.observer.MovingAveragePerChannelMinMaxObserver" | 
|  | ) | 
|  | return False | 
|  |  | 
|  |  | 
|  | def get_observer_state_dict(mod): | 
|  | r""" | 
|  | Returns the state dict corresponding to the observer stats. | 
Traverse the model state_dict and extract the observer stats.
|  | """ | 
|  | od = OrderedDict() | 
|  | if isinstance(mod, torch.jit.RecursiveScriptModule): | 
|  | for k, v in mod.state_dict().items(): | 
|  | if "observer" in k: | 
|  | od[k] = v | 
|  | else: | 
|  | # path for GraphModule and nn.Module (eager mode) | 
|  | for k, v in mod.state_dict().items(): | 
|  | if "activation_post_process" in k: | 
|  | od[k] = v | 
|  | od._metadata = mod.state_dict()._metadata  # type: ignore[attr-defined] | 
|  | return od | 
|  |  | 
|  |  | 
|  | def load_observer_state_dict(mod, obs_dict): | 
|  | r""" | 
Given an input model and a state_dict containing model observer stats,
load the stats back into the model. The observer state_dict can be saved
using ``torch.ao.quantization.get_observer_state_dict``.
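
A typical flow (an illustrative sketch; ``calibrated_model`` and ``fresh_model``
are assumed to be two instances of the same prepared model)::

>>> obs_dict = get_observer_state_dict(calibrated_model)
>>> load_observer_state_dict(fresh_model, obs_dict)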
|  | """ | 
|  | missing_keys: List[str] = [] | 
|  | unexpected_keys: List[str] = [] | 
|  | for name, module in mod.named_modules(): | 
|  | prefix = name + "." | 
|  | if _is_activation_post_process(module): | 
|  | if _is_per_channel_script_obs_instance(module): | 
|  | # For per-channel observers we need to call a custom load_from_state_dict to resize the tensor. | 
|  | # However this is not called when the module is scripted and we end up calling the default one in module.py | 
|  | module._load_from_state_dict_script( | 
|  | obs_dict, prefix, {}, True, missing_keys, unexpected_keys, [] | 
|  | ) | 
|  | else: | 
|  | module._load_from_state_dict( | 
|  | obs_dict, prefix, {}, False, missing_keys, unexpected_keys, [] | 
|  | ) | 
|  | for k in missing_keys: | 
|  | if "observer" in k or "activation_post_process" in k: | 
|  | raise Exception("Missing keys for observer {} in state_dict".format(k)) | 
|  | for k in unexpected_keys: | 
|  | if "observer" in k or "activation_post_process" in k: | 
|  | raise Exception("Unexpected keys for observer {} in state_dict".format(k)) | 
|  |  | 
|  |  | 
|  | # Restrict activations to be in the range (0,127) | 
|  | default_observer = MinMaxObserver.with_args(reduce_range=True) | 
|  | """ | 
|  | Default observer for static quantization, usually used for debugging. | 
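
For example (an illustrative sketch), calling the factory produced by
``with_args`` creates a fresh observer instance::

>>> obs = default_observer()
>>> _ = obs(torch.randn(4))
>>> scale, zero_point = obs.calculate_qparams()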
|  | """ | 
|  |  | 
|  | default_placeholder_observer = PlaceholderObserver | 
|  | """ | 
|  | Default placeholder observer, usually used for quantization to torch.float16. | 
|  | """ | 
|  |  | 
|  | default_debug_observer = RecordingObserver | 
|  | """ | 
|  | Default debug-only observer. | 
|  | """ | 
|  |  | 
|  | default_weight_observer = MinMaxObserver.with_args( | 
|  | dtype=torch.qint8, qscheme=torch.per_tensor_symmetric | 
|  | ) | 
|  | """ | 
|  | Default weight observer. | 
|  | """ | 
|  |  | 
|  | default_histogram_observer = HistogramObserver.with_args(reduce_range=True) | 
|  | """ | 
|  | Default histogram observer, usually used for PTQ. | 
|  | """ | 
|  |  | 
|  | default_per_channel_weight_observer = PerChannelMinMaxObserver.with_args( | 
|  | dtype=torch.qint8, qscheme=torch.per_channel_symmetric | 
|  | ) | 
|  | """ | 
|  | Default per-channel weight observer, usually used on backends where per-channel | 
|  | weight quantization is supported, such as `fbgemm`. | 
|  | """ | 
|  |  | 
|  | default_dynamic_quant_observer = PlaceholderObserver.with_args( | 
|  | dtype=torch.float, compute_dtype=torch.quint8 | 
|  | ) | 
|  | """ | 
|  | Default observer for dynamic quantization. | 
|  | """ | 
|  |  | 
|  | default_float_qparams_observer = PerChannelMinMaxObserver.with_args( | 
|  | dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0 | 
|  | ) | 
|  | """ | 
|  | Default observer for a floating point zero-point. | 
|  | """ | 
|  |  | 
|  | default_float_qparams_observer_4bit = PerChannelMinMaxObserver.with_args( | 
|  | dtype=torch.quint4x2, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0 | 
|  | ) | 
|  | """ | 
|  | Default observer for a floating point zero-point and 4 bit activations. | 
|  | """ |