Add per channel observer (#25887)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25887
ghstack-source-id: 90383258
Add a per-channel observer that computes the quantization parameters (qparams) for each channel of the observed tensor.
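
Example usage (a minimal sketch based on the API added in this diff; the
weight tensor and its shape are illustrative):

    import torch
    from torch.quantization import PerChannelMinMaxObserver

    # Observe a conv weight along the output-channel axis (dim 0)
    obs = PerChannelMinMaxObserver(ch_axis=0, dtype=torch.quint8,
                                   qscheme=torch.per_channel_affine)
    w = torch.randn(8, 4, 3, 3)
    obs(w)  # forward records the running per-channel min/max
    scales, zero_points = obs.calculate_qparams()  # one entry per channel
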
Test Plan:
buck test mode/dev caffe2/test:quantization -- 'test_per_channel_minmax_observer'
buck test mode/dev caffe2/test:quantization -- 'test_per_channel_minmax_observer_scriptable'
Differential Revision: D17137226
fbshipit-source-id: 0b1c93e3cbcda86f5c4e30f7cd94c670f2665063
diff --git a/test/test_quantization.py b/test/test_quantization.py
index 9c3712a..9f5bdcf 100644
--- a/test/test_quantization.py
+++ b/test/test_quantization.py
@@ -1,4 +1,5 @@
import unittest
+import math
import torch
import torch.nn as nn
import torch.nn.quantized as nnq
@@ -9,7 +10,7 @@
QConfig_dynamic, default_weight_observer, dump_tensor,\
quantize, prepare, convert, prepare_qat, quantize_qat, fuse_modules, \
quantize_dynamic, default_qconfig, default_debug_qconfig, default_qat_qconfig, \
- default_dynamic_qconfig, HistogramObserver, MinMaxObserver, TensorObserver, QuantWrapper
+ default_dynamic_qconfig, HistogramObserver, MinMaxObserver, PerChannelMinMaxObserver, TensorObserver, QuantWrapper
from common_utils import run_tests
from common_quantization import QuantizationTestCase, SingleLayerLinearModel, \
@@ -776,6 +777,65 @@
self.assertEqual(qparams[1].item(), ref_zero_point)
self.assertAlmostEqual(qparams[0].item(), ref_scale, delta=1e-5)
+ @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)),
+ qscheme=st.sampled_from((torch.per_channel_affine, torch.per_channel_symmetric)),
+ ch_axis=st.sampled_from((0, 1, 2, 3)), reduce_range=st.booleans())
+ def test_per_channel_minmax_observer(self, qdtype, qscheme, ch_axis, reduce_range):
+        # reduce_range cannot be true for symmetric quantization with quint8
+ if qdtype == torch.quint8 and qscheme == torch.per_channel_symmetric:
+ reduce_range = False
+ myobs = PerChannelMinMaxObserver(reduce_range=reduce_range, ch_axis=ch_axis, dtype=qdtype, qscheme=qscheme)
+ x = torch.tensor(
+ [
+ [[[1.0, 2.0], [2.0, 2.5]], [[3.0, 4.0], [4.5, 6.0]]],
+ [[[-4.0, -3.0], [5.0, 5.0]], [[6.0, 3.0], [7.0, 8.0]]],
+ ]
+ )
+ result = myobs(x)
+ self.assertEqual(result, x)
+ qparams = myobs.calculate_qparams()
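+        # The reference values below follow the per-tensor qparam math applied per channel:
+        #   symmetric: scale = max(-min, max) / ((qmax - qmin) / 2); zero_point = 0 for qint8, 128 for quint8
+        #   affine: min/max are first extended to include 0, then scale = (max - min) / (qmax - qmin)
+        #   and zero_point = qmin - round(min / scale)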
+ ref_min_vals = [[1.0, -4.0], [-4.0, 3.0], [-4.0, 2.0], [-4.0, -3.0]]
+ ref_max_vals = [[6.0, 8.0], [5.0, 8.0], [6.0, 8.0], [7.0, 8.0]]
+ per_channel_symmetric_ref_scales = [
+ [0.04705882, 0.06274509],
+ [0.03921569, 0.0627451],
+ [0.04705882, 0.0627451],
+ [0.05490196, 0.0627451],
+ ]
+ per_channel_affine_ref_scales = [
+ [0.02352941, 0.04705882],
+ [0.03529412, 0.03137255],
+ [0.03921569, 0.03137255],
+ [0.04313726, 0.04313726],
+ ]
+ per_channel_affine_qint8_zp = [
+ [-128, -43],
+ [-15, -128],
+ [-26, -128],
+ [-35, -58],
+ ]
+ per_channel_affine_quint8_zp = [[0, 85], [113, 0], [102, 0], [93, 70]]
+
+ self.assertEqual(myobs.min_vals, ref_min_vals[ch_axis])
+ self.assertEqual(myobs.max_vals, ref_max_vals[ch_axis])
+ if qscheme == torch.per_channel_symmetric:
+ ref_scales = per_channel_symmetric_ref_scales[ch_axis]
+ ref_zero_points = [0, 0] if qdtype is torch.qint8 else [128, 128]
+ else:
+ ref_scales = per_channel_affine_ref_scales[ch_axis]
+ ref_zero_points = (
+ per_channel_affine_qint8_zp[ch_axis]
+ if qdtype is torch.qint8
+ else per_channel_affine_quint8_zp[ch_axis]
+ )
+
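+        # reduce_range halves the quantized range (255 levels -> 127), so the
+        # reference scales grow by 255 / 127 and the zero_points are halved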
+ if reduce_range:
+ ref_scales = [s * 255 / 127 for s in ref_scales]
+ ref_zero_points = [math.floor(z / 2) for z in ref_zero_points]
+
+ self.assertTrue(torch.allclose(qparams[0], torch.tensor(ref_scales, dtype=qparams[0].dtype)))
+ self.assertTrue(torch.allclose(qparams[1], torch.tensor(ref_zero_points, dtype=qparams[1].dtype)))
+
def test_observer_scriptable(self):
obs = torch.quantization.default_observer()()
scripted = torch.jit.script(obs)
diff --git a/torch/quantization/observer.py b/torch/quantization/observer.py
index 8d8872e..6faf479 100644
--- a/torch/quantization/observer.py
+++ b/torch/quantization/observer.py
@@ -35,8 +35,11 @@
assert self.qscheme in (
torch.per_tensor_affine,
torch.per_tensor_symmetric,
-        ), "Default Observer only works for per_tensor_affine and \
-            per_tensor_symmetric quantization scheme"
+            torch.per_channel_affine,
+            torch.per_channel_symmetric,
+        ), "Default Observer only works for per_tensor_affine, \
+            per_tensor_symmetric, per_channel_affine and \
+            per_channel_symmetric quantization schemes"
assert self.dtype in (
torch.qint8,
torch.quint8,
@@ -50,6 +53,35 @@
def calculate_qparams(self, **kwargs):
pass
+ def _calculate_per_channel_qparams(self, min_vals, max_vals):
+ # type: (Optional[Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor]
+ """
+ Given min and max value tensors, this function calculates per channel
+ quantization parameters
+ """
+ if min_vals is None or max_vals is None:
+ warnings.warn(
+                "must run observer before calling calculate_qparams. "
+                "Returning default scale and zero point"
+            )
+ return torch.tensor([1.0]), torch.tensor([0])
+
+ for i in range(len(min_vals)):
+ assert (
+ min_vals[i] <= max_vals[i]
+            ), "min {} should be less than or equal to max {}".format(min_vals[i], max_vals[i])
+
+ scales = torch.ones(min_vals.size())
+        zero_points = torch.zeros(min_vals.size())
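+        # Reuse the per-tensor helper to compute a (scale, zero_point) pair for each channel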
+ for i in range(len(scales)):
+ qparam = self._calculate_qparams(
+ min_vals[i], max_vals[i]
+ )
+ scales[i] = float(qparam[0])
+ zero_points[i] = int(qparam[1])
+
+ return scales, zero_points
+
def _calculate_qparams(self, min_val, max_val):
# type: (Optional[Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor]
"""
@@ -85,7 +117,7 @@
scale = 1.0
zero_point = 0
else:
- if self.qscheme == torch.per_tensor_symmetric:
+ if self.qscheme == torch.per_tensor_symmetric or self.qscheme == torch.per_channel_symmetric:
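+                # symmetric: take the larger magnitude endpoint so the range is centered at zero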
max_val = max(-min_val, max_val)
scale = max_val / ((qmax - qmin) / 2)
scale = max(scale, self.eps)
@@ -156,6 +188,56 @@
return "min_val={}, max_val={}".format(self.min_val, self.max_val)
+class PerChannelMinMaxObserver(ObserverBase):
+    r"""Per Channel Observer Module
+    The module records the running minimum and maximum values for each
+    channel of the observed Tensor, and calculate_qparams will compute
+    the scales and zero_points for each channel
+    """
+
+ def __init__(self, ch_axis=0, **kwargs):
+ super(PerChannelMinMaxObserver, self).__init__(**kwargs)
+ self.ch_axis = ch_axis
+ self.min_vals = None
+ self.max_vals = None
+ if (
+ self.qscheme == torch.per_channel_symmetric
+ and self.reduce_range
+ and self.dtype == torch.quint8
+ ):
+ raise NotImplementedError(
+ "Cannot reduce range for symmetric quantization for quint8"
+ )
+
+ def forward(self, x):
+ with torch.no_grad():
+ min_vals = self.min_vals
+ max_vals = self.max_vals
+ x_dim = x.size()
+
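+            # Swap the channel axis into dim 0 and flatten the remaining
+            # dims, so that each row of y holds the values of one channel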
+ new_axis_list = list(range(len(x_dim)))
+ new_axis_list[self.ch_axis] = 0
+ new_axis_list[0] = self.ch_axis
+ y = x.permute(tuple(new_axis_list))
+ y = torch.flatten(y, start_dim=1)
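+            # First call records min/max directly; later calls keep the
+            # running min/max per channel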
+ if min_vals is None or max_vals is None:
+ min_vals = torch.min(y, 1)[0]
+ max_vals = torch.max(y, 1)[0]
+ else:
+ min_vals = torch.min(torch.min(y, 1)[0], min_vals)
+ max_vals = torch.max(torch.max(y, 1)[0], max_vals)
+ self.min_vals = min_vals
+ self.max_vals = max_vals
+ return x
+
+ def calculate_qparams(self):
+ return self._calculate_per_channel_qparams(self.min_vals, self.max_vals)
+
+ def extra_repr(self):
+ return "min_val={}, max_val={}".format(self.min_vals, self.max_vals)
+
+
class HistogramObserver(ObserverBase):
r"""
The module records the running histogram of tensor values along with