from __future__ import absolute_import, division, print_function, unicode_literals
from functools import partial

import torch
import torch.nn as nn

class Observer(nn.Module):
r"""Default Observer Module
A default implementation of the observer module, only works for
`per_tensor_affine` quantization scheme.
The module will record the running average of max and min value of the
observed Tensor and calulate_qparams will calculate the scale and zero_point
Other types of Observers should follow the same API, it can take arbitrary
number of keyward arguments. In forward, it will update the statistics of
the observed Tensor. And it should provide a `calculate_qparam` function
that computes the quantization parameters given the collected statistics.
TODO: Maybe add an abstract Observer class that enforces these rules?
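
    Example (a minimal usage sketch; ``calculate_qparams`` returns a single
    tensor holding ``[scale, zero_point]``)::

        obs = Observer(dtype=torch.quint8, qscheme=torch.per_tensor_affine)
        obs(torch.tensor([-1.0, 0.0, 1.0]))   # records min/max statistics
        scale, zero_point = obs.calculate_qparams()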
"""
    def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine):
        super(Observer, self).__init__()
        self.dtype = dtype
        self.qscheme = qscheme
        assert self.qscheme in (torch.per_tensor_affine, torch.per_tensor_symmetric), \
            'Default Observer only works for per_tensor_affine and ' \
            'per_tensor_symmetric quantization schemes'
        assert self.dtype in (torch.qint8, torch.quint8), \
            'Default Observer only works for qint8 and quint8 data types'
        self.min_val = None
        self.max_val = None

    def forward(self, x):
        # Track the running min/max over all tensors observed so far
        if self.min_val is None or self.max_val is None:
            self.min_val = torch.min(x)
            self.max_val = torch.max(x)
        else:
            self.min_val = torch.min(torch.min(x), self.min_val)
            self.max_val = torch.max(torch.max(x), self.max_val)
        # Return the input unchanged so the observer can be used in-line
        return x

    def calculate_qparams(self):
        if self.max_val is None or self.min_val is None:
            raise RuntimeError('must run the observer before calling calculate_qparams!')

        if self.dtype == torch.qint8:
            qmin, qmax = -128, 127
        else:
            qmin, qmax = 0, 255
        n_levels = 255.0

        max_val, min_val = self.max_val.item(), self.min_val.item()
        if max_val == min_val:
            # Degenerate range (e.g. a constant tensor): fall back to
            # identity quantization parameters
            scale = 1.0
            zero_point = 0
        else:
            if self.qscheme == torch.per_tensor_symmetric:
                # Symmetric: only the magnitude of the range matters;
                # the zero_point is fixed by the dtype
                max_val = max(-min_val, max_val)
                scale = max_val / 127.0
                scale = max(scale, torch.finfo(torch.float32).eps)
                zero_point = 0 if self.dtype == torch.qint8 else 128
            else:
                # Affine: spread the observed range over all quantized
                # levels and clamp the zero_point into [qmin, qmax]
                scale = (max_val - min_val) / n_levels
                scale = max(scale, torch.finfo(torch.float32).eps)
                zero_point = qmin - round(min_val / scale)
                zero_point = max(qmin, zero_point)
                zero_point = min(qmax, zero_point)
        return torch.tensor([scale, zero_point])

def observer(observer_cls, **kwargs):
    # Return a no-argument constructor with the given kwargs bound
    return partial(observer_cls, **kwargs)


def default_observer(**kwargs):
    return observer(Observer, **kwargs)


def default_weight_observer(**kwargs):
    # Weights default to symmetric qint8 quantization
    kwargs.setdefault('dtype', torch.qint8)
    kwargs.setdefault('qscheme', torch.per_tensor_symmetric)
    return observer(Observer, **kwargs)
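

# --- Usage sketch (an illustration, not part of the original API) ---
# Shows the intended workflow when this file is run directly: build a
# constructor with default_observer(), feed it tensors to collect
# statistics, then read back (scale, zero_point). Wrapped in a __main__
# guard so importing this module stays side-effect free.
if __name__ == '__main__':
    obs = default_observer()()            # partial(Observer)() -> Observer()
    obs(torch.randn(4, 4))                # update running min/max
    obs(torch.randn(4, 4))
    scale, zero_point = obs.calculate_qparams()
    print('scale = %f, zero_point = %d' % (scale.item(), int(zero_point.item())))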