Add per channel observer (#25887)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25887

ghstack-source-id: 90383258

Add a per-channel observer (PerChannelMinMaxObserver) that computes the qparams (scale and zero point) for each channel.

Test Plan:
buck test mode/dev caffe2/test:quantization -- 'test_per_channel_minmax_observer'

buck test mode/dev caffe2/test:quantization -- 'test_per_channel_minmax_observer_scriptable'

Differential Revision: D17137226

fbshipit-source-id: 0b1c93e3cbcda86f5c4e30f7cd94c670f2665063
diff --git a/test/test_quantization.py b/test/test_quantization.py
index 9c3712a..9f5bdcf 100644
--- a/test/test_quantization.py
+++ b/test/test_quantization.py
@@ -1,4 +1,5 @@
 import unittest
+import math
 import torch
 import torch.nn as nn
 import torch.nn.quantized as nnq
@@ -9,7 +10,7 @@
     QConfig_dynamic, default_weight_observer, dump_tensor,\
     quantize, prepare, convert, prepare_qat, quantize_qat, fuse_modules, \
     quantize_dynamic, default_qconfig, default_debug_qconfig, default_qat_qconfig, \
-    default_dynamic_qconfig, HistogramObserver, MinMaxObserver, TensorObserver, QuantWrapper
+    default_dynamic_qconfig, HistogramObserver, MinMaxObserver, PerChannelMinMaxObserver, TensorObserver, QuantWrapper
 
 from common_utils import run_tests
 from common_quantization import QuantizationTestCase, SingleLayerLinearModel, \
@@ -776,6 +777,65 @@
         self.assertEqual(qparams[1].item(), ref_zero_point)
         self.assertAlmostEqual(qparams[0].item(), ref_scale, delta=1e-5)
 
+    @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)),
+           qscheme=st.sampled_from((torch.per_channel_affine, torch.per_channel_symmetric)),
+           ch_axis=st.sampled_from((0, 1, 2, 3)), reduce_range=st.booleans())
+    def test_per_channel_minmax_observer(self, qdtype, qscheme, ch_axis, reduce_range):
+        # reduce_range cannot be true for symmetric quantization with quint8
+        if qdtype == torch.quint8 and qscheme == torch.per_channel_symmetric:
+            reduce_range = False
+        myobs = PerChannelMinMaxObserver(reduce_range=reduce_range, ch_axis=ch_axis, dtype=qdtype, qscheme=qscheme)
+        x = torch.tensor(
+            [
+                [[[1.0, 2.0], [2.0, 2.5]], [[3.0, 4.0], [4.5, 6.0]]],
+                [[[-4.0, -3.0], [5.0, 5.0]], [[6.0, 3.0], [7.0, 8.0]]],
+            ]
+        )
+        result = myobs(x)
+        self.assertEqual(result, x)
+        qparams = myobs.calculate_qparams()
+        ref_min_vals = [[1.0, -4.0], [-4.0, 3.0], [-4.0, 2.0], [-4.0, -3.0]]
+        ref_max_vals = [[6.0, 8.0], [5.0, 8.0], [6.0, 8.0], [7.0, 8.0]]
+        per_channel_symmetric_ref_scales = [
+            [0.04705882, 0.06274509],
+            [0.03921569, 0.0627451],
+            [0.04705882, 0.0627451],
+            [0.05490196, 0.0627451],
+        ]
+        per_channel_affine_ref_scales = [
+            [0.02352941, 0.04705882],
+            [0.03529412, 0.03137255],
+            [0.03921569, 0.03137255],
+            [0.04313726, 0.04313726],
+        ]
+        per_channel_affine_qint8_zp = [
+            [-128, -43],
+            [-15, -128],
+            [-26, -128],
+            [-35, -58],
+        ]
+        per_channel_affine_quint8_zp = [[0, 85], [113, 0], [102, 0], [93, 70]]
+
+        self.assertEqual(myobs.min_vals, ref_min_vals[ch_axis])
+        self.assertEqual(myobs.max_vals, ref_max_vals[ch_axis])
+        if qscheme == torch.per_channel_symmetric:
+            ref_scales = per_channel_symmetric_ref_scales[ch_axis]
+            ref_zero_points = [0, 0] if qdtype is torch.qint8 else [128, 128]
+        else:
+            ref_scales = per_channel_affine_ref_scales[ch_axis]
+            ref_zero_points = (
+                per_channel_affine_qint8_zp[ch_axis]
+                if qdtype is torch.qint8
+                else per_channel_affine_quint8_zp[ch_axis]
+            )
+
+        if reduce_range:
+            ref_scales = [s * 255 / 127 for s in ref_scales]
+            ref_zero_points = [math.floor(z / 2) for z in ref_zero_points]
+
+        self.assertTrue(torch.allclose(qparams[0], torch.tensor(ref_scales, dtype=qparams[0].dtype)))
+        self.assertTrue(torch.allclose(qparams[1], torch.tensor(ref_zero_points, dtype=qparams[1].dtype)))
+
     def test_observer_scriptable(self):
         obs = torch.quantization.default_observer()()
         scripted = torch.jit.script(obs)
diff --git a/torch/quantization/observer.py b/torch/quantization/observer.py
index 8d8872e..6faf479 100644
--- a/torch/quantization/observer.py
+++ b/torch/quantization/observer.py
@@ -35,8 +35,11 @@
         assert self.qscheme in (
             torch.per_tensor_affine,
             torch.per_tensor_symmetric,
-        ), "Default Observer only works for per_tensor_affine and \
-                per_tensor_symmetric quantization scheme"
+            torch.per_channel_affine,
+            torch.per_channel_symmetric,
+        ), "Default Observer only works for per_tensor_affine, \
+                per_tensor_symmetric, per_channel_affine and \
+                per_channel_symmetric quantization scheme"
         assert self.dtype in (
             torch.qint8,
             torch.quint8,
@@ -50,6 +53,35 @@
     def calculate_qparams(self, **kwargs):
         pass
 
+    def _calculate_per_channel_qparams(self, min_vals, max_vals):
+        # type: (Optional[Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor]
+        """
+        Given per-channel min and max value tensors, this function calculates
+        the per-channel quantization parameters (scales and zero points).
+        """
+        if min_vals is None or max_vals is None:
+            warnings.warn(
+                "must run observer before calling calculate_qparams.\
+                                    Returning default scale and zero point "
+            )
+            return torch.tensor([1.0]), torch.tensor([0])
+
+        for i in range(len(min_vals)):
+            assert (
+                min_vals[i] <= max_vals[i]
+            ), "min {} should be less than max {}".format(min_vals[i], max_vals[i])
+
+        scales = torch.ones(min_vals.size())
+        zero_points = torch.ones(min_vals.size())
+        for i in range(len(scales)):
+            qparam = self._calculate_qparams(
+                min_vals[i], max_vals[i]
+            )
+            scales[i] = float(qparam[0])
+            zero_points[i] = int(qparam[1])
+
+        return scales, zero_points
+
     def _calculate_qparams(self, min_val, max_val):
         # type: (Optional[Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor]
         """
@@ -85,7 +117,7 @@
             scale = 1.0
             zero_point = 0
         else:
-            if self.qscheme == torch.per_tensor_symmetric:
+            if self.qscheme == torch.per_tensor_symmetric or self.qscheme == torch.per_channel_symmetric:
                 max_val = max(-min_val, max_val)
                 scale = max_val / ((qmax - qmin) / 2)
                 scale = max(scale, self.eps)
@@ -156,6 +188,56 @@
         return "min_val={}, max_val={}".format(self.min_val, self.max_val)
 
 
+class PerChannelMinMaxObserver(ObserverBase):
+    r"""Per Channel Observer Module
+    The module records the running minimum and maximum value for each
+    channel of the observed Tensor, and calculate_qparams computes the
+    scales and zero_points for each channel.
+    """
+
+    def __init__(self, ch_axis=0, **kwargs):
+        super(PerChannelMinMaxObserver, self).__init__(**kwargs)
+        self.ch_axis = ch_axis
+        self.min_vals = None
+        self.max_vals = None
+        if (
+            self.qscheme == torch.per_channel_symmetric
+            and self.reduce_range
+            and self.dtype == torch.quint8
+        ):
+            raise NotImplementedError(
+                "Cannot reduce range for symmetric quantization for quint8"
+            )
+
+    def forward(self, x):
+        with torch.no_grad():
+            min_vals = self.min_vals
+            max_vals = self.max_vals
+            x_dim = x.size()
+
+            new_axis_list = list(range(len(x_dim)))
+            new_axis_list[self.ch_axis] = 0
+            new_axis_list[0] = self.ch_axis
+            y = x.permute(tuple(new_axis_list))
+            y = torch.flatten(y, start_dim=1)
+            if min_vals is None or max_vals is None:
+                min_vals = torch.min(y, 1)[0]
+                max_vals = torch.max(y, 1)[0]
+            else:
+                min_vals = torch.min(torch.min(y, 1)[0], min_vals)
+                max_vals = torch.max(torch.max(y, 1)[0], max_vals)
+            self.min_vals = min_vals
+            self.max_vals = max_vals
+            return x
+
+    def calculate_qparams(self):
+        return self._calculate_per_channel_qparams(self.min_vals, self.max_vals)
+
+    def extra_repr(self):
+        return "min_val={}, max_val={}".format(self.min_vals, self.max_vals)
+
+
+
 class HistogramObserver(ObserverBase):
     r"""
     The module records the running histogram of tensor values along with