| from __future__ import absolute_import, division, print_function, unicode_literals |
| import torch |
| from torch.nn import Module |
| from .observer import MinMaxObserver, _with_args |
| |
class FakeQuantize(Module):
    """Simulate the quantize and dequantize operations in training time.

    The output of this module is given by

        x_out = (clamp(round(x/scale + zero_point), quant_min, quant_max) - zero_point) * scale

    * :attr:`scale` defines the scale factor used for quantization.

    * :attr:`zero_point` specifies the quantized value to which 0 in floating point maps to.

    * :attr:`quant_min` specifies the minimum allowable quantized value.

    * :attr:`quant_max` specifies the maximum allowable quantized value.

    * :attr:`fake_quant_enabled` controls the application of fake quantization on tensors,
      note that statistics can still be updated.

    * :attr:`observer_enabled` controls statistics collection on tensors.

    * :attr:`dtype` specifies the quantized dtype that is being emulated with fake-quantization,
      allowable values are torch.qint8 and torch.quint8. The values of quant_min and quant_max
      should be chosen to be consistent with the dtype.

    Args:
        observer (callable): Class or factory that is called as ``observer(**observer_kwargs)``
            to build the module that observes statistics on input tensors and calculates
            scale and zero-point.
        quant_min (int): The minimum allowable quantized value.
        quant_max (int): The maximum allowable quantized value.
        observer_kwargs (optional): Keyword arguments forwarded to ``observer``.

    Attributes:
        observer (Module): The constructed observer. It collects statistics on the input
            tensor and provides ``calculate_qparams()`` to compute scale and zero-point.
        scale (float or None): Last computed scale; ``None`` until the observer has run
            (or a state dict has been loaded).
        zero_point (int or None): Last computed zero point; ``None`` until the observer
            has run (or a state dict has been loaded).
        dtype: Copied from ``observer.dtype``.
        qscheme: Copied from ``observer.qscheme``.
    """

    def __init__(self, observer=MinMaxObserver, quant_min=0, quant_max=255, **observer_kwargs):
        super(FakeQuantize, self).__init__()
        assert quant_min <= quant_max, \
            'quant_min must be less than or equal to quant_max'
        self.quant_min = quant_min
        self.quant_max = quant_max
        self.fake_quant_enabled = True
        self.observer_enabled = True
        self.observer = observer(**observer_kwargs)
        # The requested range has to fit inside the integer range of the
        # observer's quantized dtype.
        assert torch.iinfo(self.observer.dtype).min <= quant_min, 'quant_min out of bound'
        assert quant_max <= torch.iinfo(self.observer.dtype).max, 'quant_max out of bound'
        # Quantization parameters are unknown until forward() has run with the
        # observer enabled, or until a state dict is loaded.
        self.scale = None
        self.zero_point = None
        self.dtype = self.observer.dtype
        self.qscheme = self.observer.qscheme

    def enable_fake_quant(self, enabled=True):
        """Set whether forward() applies fake quantization; returns ``self``."""
        self.fake_quant_enabled = enabled
        return self

    def disable_fake_quant(self):
        """Stop applying fake quantization in forward(); returns ``self``."""
        return self.enable_fake_quant(False)

    def enable_observer(self, enabled=True):
        """Set whether forward() updates observer statistics; returns ``self``."""
        self.observer_enabled = enabled
        return self

    def disable_observer(self):
        """Stop updating observer statistics in forward(); returns ``self``."""
        return self.enable_observer(False)

    def calculate_qparams(self):
        """Return the quantization parameters computed by the wrapped observer."""
        return self.observer.calculate_qparams()

    def forward(self, X):
        """Optionally observe ``X``, then optionally fake-quantize it.

        When the observer is enabled, ``X`` is fed to it and ``scale`` /
        ``zero_point`` are refreshed from ``calculate_qparams()``. When fake
        quantization is enabled, ``X`` is passed through
        ``torch.fake_quantize_per_tensor_affine`` using the stored parameters;
        otherwise ``X`` is returned unchanged.
        """
        if self.observer_enabled:
            self.observer(X)
            scale, zero_point = self.calculate_qparams()
            # Stored as plain Python scalars.
            self.scale, self.zero_point = float(scale), int(zero_point)
        if self.fake_quant_enabled:
            # NOTE: if the observer has never run and no state dict was loaded,
            # scale and zero_point are still None at this point and are passed
            # through as such.
            X = torch.fake_quantize_per_tensor_affine(X, self.scale, self.zero_point, self.quant_min, self.quant_max)
        return X

    with_args = classmethod(_with_args)

    def extra_repr(self):
        """Return the enable flags and current quantization parameters as text."""
        # Single literal: a backslash continuation inside the string would
        # embed the source indentation in the output.
        return 'fake_quant_enabled={}, observer_enabled={}, scale={}, zero_point={}'.format(
            self.fake_quant_enabled, self.observer_enabled,
            self.scale, self.zero_point)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        # We cannot currently register scalar values as buffers, so need to manually
        # specify serialization here.
        super(FakeQuantize, self)._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + 'scale'] = self.scale
        destination[prefix + 'zero_point'] = self.zero_point

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # scale / zero_point are not registered buffers, so they are removed
        # from state_dict here before delegating to the base implementation.
        # pop() is called without a default, so both keys must be present.
        self.scale = state_dict.pop(prefix + 'scale')
        self.zero_point = state_dict.pop(prefix + 'zero_point')
        # The base implementation is always invoked with strict=False,
        # regardless of the `strict` value passed in.
        super(FakeQuantize, self)._load_from_state_dict(state_dict, prefix, local_metadata, False,
                                                        missing_keys, unexpected_keys, error_msgs)
| |
# Default fake-quant configuration: quint8 emulation over [0, 255] with a
# per-tensor affine scheme and reduce_range=True forwarded to the observer.
# with_args comes from observer._with_args; presumably it returns a factory
# that builds FakeQuantize with these arguments -- confirm there.
default_fake_quant = FakeQuantize.with_args(
    observer=MinMaxObserver,
    quant_min=0,
    quant_max=255,
    dtype=torch.quint8,
    qscheme=torch.per_tensor_affine,
    reduce_range=True)

# Default weight fake-quant configuration: qint8 emulation over [-128, 127]
# with a per-tensor symmetric scheme and reduce_range=False.
default_weight_fake_quant = FakeQuantize.with_args(
    observer=MinMaxObserver,
    quant_min=-128,
    quant_max=127,
    dtype=torch.qint8,
    qscheme=torch.per_tensor_symmetric,
    reduce_range=False)
| |
def disable_fake_quant(mod):
    """Turn off fake quantization on ``mod`` if it is exactly a FakeQuantize.

    Modules of any other type (including subclasses of FakeQuantize) are
    left untouched. Returns None.
    """
    if type(mod) != FakeQuantize:
        return
    mod.disable_fake_quant()
| |
def enable_fake_quant(mod):
    """Turn on fake quantization on ``mod`` if it is exactly a FakeQuantize.

    Modules of any other type (including subclasses of FakeQuantize) are
    left untouched. Returns None.
    """
    if type(mod) != FakeQuantize:
        return
    mod.enable_fake_quant()
| |
def disable_observer(mod):
    """Stop statistics collection on ``mod`` if it is exactly a FakeQuantize.

    Sets ``mod.observer_enabled`` to False directly, so it does not rely on
    the module exposing a dedicated toggle method. Modules of any other type
    (including subclasses of FakeQuantize) are left untouched. Returns None.
    """
    if type(mod) == FakeQuantize:
        mod.observer_enabled = False
| |
def enable_observer(mod):
    """Resume statistics collection on ``mod`` if it is exactly a FakeQuantize.

    Sets ``mod.observer_enabled`` to True directly, so it does not rely on
    the module exposing a dedicated toggle method. Modules of any other type
    (including subclasses of FakeQuantize) are left untouched. Returns None.
    """
    if type(mod) == FakeQuantize:
        # Must enable, not disable, the observer.
        mod.observer_enabled = True