[quant][graph] Add a new observer type for dynamic quantization (#35455)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35455

In graph mode we need to observe the activation tensor for dynamic quantization. This observer should behave the same way as the quantization functions called in the dynamic operator.
Currently, for qlinear_dynamic we call quant_utils::ChooseQuantizationParams, which has its own logic for calculating the scale and zero_point.
We mimic those calculations in the new observer.
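
As a minimal sketch of the intended behavior (the tensor and constructor arguments below are arbitrary; only `MinMaxDynamicQuantObserver` and `torch._choose_qparams_per_tensor` come from this change and the existing test utilities), the observer's qparams should match what the dynamic operator would choose at run-time:

```python
import torch
from torch.quantization import MinMaxDynamicQuantObserver

# Arbitrary activation tensor; in graph mode this would be the input of a
# dynamically quantized op such as qlinear_dynamic.
x = torch.randn(4, 8)

obs = MinMaxDynamicQuantObserver(dtype=torch.quint8, reduce_range=False)
obs(x)  # record the min/max of the activation
scale, zero_point = obs.calculate_qparams()

# Reference quantization parameters, chosen the same way the dynamic
# operator does at run-time.
ref_scale, ref_zero_point = torch._choose_qparams_per_tensor(x, False)
```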

Test Plan:
python test/test_quantization.py ObserverTest

Imported from OSS

Differential Revision: D20664586

fbshipit-source-id: e987ea71fff777c21e00c498504e6586e92568a2
diff --git a/test/test_quantization.py b/test/test_quantization.py
index 6481469..fe63b69 100644
--- a/test/test_quantization.py
+++ b/test/test_quantization.py
@@ -14,7 +14,7 @@
     default_dynamic_qconfig, per_channel_dynamic_qconfig, HistogramObserver, MinMaxObserver, \
     PerChannelMinMaxObserver, RecordingObserver, MovingAverageMinMaxObserver, \
     MovingAveragePerChannelMinMaxObserver, QuantWrapper, default_eval_fn, \
-    float16_dynamic_qconfig
+    float16_dynamic_qconfig, MinMaxDynamicQuantObserver
 
 from torch.quantization import QConfig
 from torch.quantization import default_histogram_observer
@@ -1353,6 +1353,26 @@
             self.assertEqual(myobs.max_val, loaded_obs.max_val)
             self.assertEqual(myobs.calculate_qparams(), loaded_obs.calculate_qparams())
 
+
+    @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=2, max_dims=4,
+                                              min_side=1, max_side=10),
+                       qparams=hu.qparams()),
+           reduce_range=st.booleans())
+    def test_per_tensor_dynamic_quant_observers(self, X, reduce_range):
+
+        X, (scale, zero_point, torch_type) = X
+        x = torch.from_numpy(X)
+
+        obs = MinMaxDynamicQuantObserver(dtype=torch.quint8, reduce_range=reduce_range)
+
+        result = obs(x)
+        qparams = obs.calculate_qparams()
+        ref = torch._choose_qparams_per_tensor(x, reduce_range)
+
+        self.assertEqual(ref[0], qparams[0])
+        self.assertEqual(ref[1], qparams[1])
+
+
     @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)),
            qscheme=st.sampled_from((torch.per_channel_affine, torch.per_channel_symmetric)),
            ch_axis=st.sampled_from((0, 1, 2, 3)), reduce_range=st.booleans())
@@ -1448,7 +1468,7 @@
             self.assertEqual(myobs.calculate_qparams(), loaded_obs.calculate_qparams())
 
     def test_observer_scriptable(self):
-        obs_list = [MinMaxObserver(), MovingAverageMinMaxObserver()]
+        obs_list = [MinMaxObserver(), MovingAverageMinMaxObserver(), MinMaxDynamicQuantObserver()]
         for obs in obs_list:
             scripted = torch.jit.script(obs)
 
diff --git a/torch/quantization/observer.py b/torch/quantization/observer.py
index ddd20af..26ec4e0 100644
--- a/torch/quantization/observer.py
+++ b/torch/quantization/observer.py
@@ -439,6 +439,83 @@
         return x_orig
 
 
+class MinMaxDynamicQuantObserver(MinMaxObserver):
+    r"""Observer module for computing the quantization parameters based on the
+    tensor min and max values in dynamic quantization.
+
+    This observer will mimic the quantization steps followed in the operator
+    to compute the activation tensor quantization parameters at run-time.
+
+    Args:
+        dtype: Quantized data type
+        qscheme: Quantization scheme to be used
+        reduce_range: Reduces the range of the quantized data type by 1 bit
+
+    .. warning:: Only works with ``torch.per_tensor_symmetric`` quantization scheme
+
+    .. warning:: :attr:`dtype` can only take ``torch.qint8`` or ``torch.quint8``.
+
+    .. note:: If the running minimum equals the running maximum, the scale
+              and zero_point are set to 0.1 and 0.
+    """
+
+    @torch.jit.export
+    def calculate_qparams(self):
+        r"""Calculates the quantization parameters."""
+
+        if self.max_val.numel() == 0 or self.min_val.numel() == 0:
+            warnings.warn("Must run observer before calling calculate_qparams.\
+                           Returning default scale and zero point.")
+            return torch.tensor([1.0]), torch.tensor([0])
+
+        assert self.min_val <= self.max_val, "min {} should be less than max {}".format(
+            self.min_val, self.max_val
+        )
+
+        if self.dtype == torch.qint8:
+            if self.reduce_range:
+                qmin, qmax = -64, 63
+            else:
+                qmin, qmax = -128, 127
+        else:  # dtype == torch.quint8
+            if self.reduce_range:
+                qmin, qmax = 0, 127
+            else:
+                qmin, qmax = 0, 255
+
+        max_val, min_val = self.max_val.to(dtype=torch.float), self.min_val.to(dtype=torch.float)
+
+        # Extend the min_val and max_val to ensure that it contains 0.
+        min_val = torch.min(min_val, torch.tensor(0.).to(dtype=torch.float))
+        max_val = torch.max(max_val, torch.tensor(0.).to(dtype=torch.float))
+
+        scale = (max_val.to(dtype=torch.double) - min_val) / float(qmax - qmin)
+
+        if scale == 0.0 or torch.isinf(1.0 / scale):
+            scale = torch.tensor(0.1).to(dtype=torch.float)
+            zero_point = 0
+
+        zero_point_from_min = qmin - min_val / scale.to(dtype=torch.double)
+        zero_point_from_max = qmax - max_val / scale.to(dtype=torch.double)
+        zero_point_from_min_error = abs(qmin) - abs(min_val / scale.to(dtype=torch.double))
+        zero_point_from_max_error = abs(qmax) - abs(max_val / scale.to(dtype=torch.double))
+
+        if zero_point_from_min_error < zero_point_from_max_error:
+            initial_zero_point = zero_point_from_min
+        else:
+            initial_zero_point = zero_point_from_max
+
+        nudged_zero_point = 0
+
+        if initial_zero_point < qmin:
+            nudged_zero_point = qmin
+        elif initial_zero_point > qmax:
+            nudged_zero_point = qmax
+        else:
+            nudged_zero_point = int(initial_zero_point.round())
+
+        return scale.to(dtype=torch.float), torch.tensor([nudged_zero_point])
+
 class PerChannelMinMaxObserver(_ObserverBase):
     r"""Observer module for computing the quantization parameters based on the
     running per channel min and max values.
@@ -933,3 +1010,4 @@
 default_weight_observer = MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)
 default_histogram_observer = HistogramObserver.with_args(reduce_range=True)
 default_per_channel_weight_observer = PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric)
+default_dynamic_quant_observer = MinMaxDynamicQuantObserver
diff --git a/torch/quantization/qconfig.py b/torch/quantization/qconfig.py
index 0034a03..dcc1300 100644
--- a/torch/quantization/qconfig.py
+++ b/torch/quantization/qconfig.py
@@ -38,7 +38,7 @@
 default_per_channel_qconfig = QConfig(activation=default_observer,
                                       weight=default_per_channel_weight_observer)
 
-class QConfigDynamic(namedtuple('QConfigDynamic', ['weight'])):
+class QConfigDynamic(namedtuple('QConfigDynamic', ['activation', 'weight'])):
     """
     Describes how to dynamically quantize a layer or a part of the network by providing
     settings (observer classes) for weights.
@@ -54,15 +54,17 @@
 
       my_qconfig = QConfigDynamic(weight=default_observer.with_args(dtype=torch.qint8))
     """
-    def __new__(cls, weight):
+    def __new__(cls, activation=torch.nn.Identity, weight=torch.nn.Identity):
         # catch common mistakes
         if isinstance(weight, nn.Module):
             raise ValueError("QConfigDynamic received observer instance, please pass observer class instead. " +
                              "Use MyObserver.with_args(x=1) to override arguments to constructor if needed")
-        return super(QConfigDynamic, cls).__new__(cls, weight)
+        return super(QConfigDynamic, cls).__new__(cls, activation, weight)
 
-default_dynamic_qconfig = QConfigDynamic(weight=default_weight_observer)
-float16_dynamic_qconfig = QConfigDynamic(weight=NoopObserver.with_args(dtype=torch.float16))
+default_dynamic_qconfig = QConfigDynamic(activation=default_dynamic_quant_observer,
+                                         weight=default_weight_observer)
+float16_dynamic_qconfig = QConfigDynamic(activation=default_dynamic_quant_observer,
+                                         weight=NoopObserver.with_args(dtype=torch.float16))
 per_channel_dynamic_qconfig = QConfigDynamic(weight=default_per_channel_weight_observer)
 
 default_qat_qconfig = QConfig(activation=default_fake_quant,
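
For completeness, a minimal usage sketch of the updated `QConfigDynamic` (the qconfig name below is illustrative; the observers are the defaults added in this diff):

```python
import torch
from torch.quantization import QConfigDynamic
from torch.quantization.observer import (
    MinMaxDynamicQuantObserver,
    default_weight_observer,
)

# QConfigDynamic now carries an activation observer in addition to the weight
# observer, so graph mode can insert an observer for the activation tensor of
# dynamically quantized ops.
my_dynamic_qconfig = QConfigDynamic(
    activation=MinMaxDynamicQuantObserver.with_args(reduce_range=True),
    weight=default_weight_observer,
)
```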