[quant][graph] Add _choose_qparams function for graph mode (#35235)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35235

For dynamic quantization in graph mode, we need an operator that computes and
returns the quantization parameters (scale, zero point) of a tensor, similar to
what the quantized linear_dynamic op does internally for its input.
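
A minimal usage sketch (hedged): the op name and signature follow the
native_functions.yaml entry below; torch.quantize_per_tensor appears only for
illustration and is not touched by this diff.

```python
import torch

x = torch.randn(4, 8)
# Compute (scale, zero_point) for quantizing x over the quint8 range [0, 255].
scale, zero_point = torch._choose_qparams_per_tensor(x, reduce_range=False)
# The qparams can then be used to quantize the activation on the fly,
# which is what quantized::linear_dynamic does internally.
xq = torch.quantize_per_tensor(x, scale, zero_point, torch.quint8)
```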

Test Plan:
python test/test_quantized_tensor.py TestQuantizedTensor.test_choose_qparams
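
For reference, a condensed, illustrative sketch of the selection logic
implemented by quant_utils::ChooseQuantizationParams and mirrored by the
_calculate_dynamic_qparams helper in the test below (the error-based choice
between the two zero-point candidates is omitted for brevity;
choose_qparams_sketch is not a real API):

```python
import math

def choose_qparams_sketch(x_min, x_max, qmin=0, qmax=255):
    # Extend the range to include 0 so that real 0 is exactly representable.
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    scale = (x_max - x_min) / (qmax - qmin)
    if scale == 0.0 or math.isinf(1.0 / scale):
        # Keep 1/scale finite; fbgemm pre-computes the reciprocal of scale.
        scale = 0.1
    # zero_point is the quantized value that corresponds to real 0,
    # nudged into the representable range [qmin, qmax].
    zero_point = int(round(qmin - x_min / scale))
    return scale, min(max(zero_point, qmin), qmax)
```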

Imported from OSS

Differential Revision: D20608793

fbshipit-source-id: b923b2620421b32d05f4097db0d6153d53198221
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 39a9189..a4a9abe 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3717,6 +3717,10 @@
   use_c10_dispatcher: full
   variants: function
 
+- func: _choose_qparams_per_tensor(Tensor self, bool reduce_range=False) -> (float, int)
+  use_c10_dispatcher: full
+  variants: function
+
 # to(Device) must not exist because all constructors of Device also works for
 # TensorOptions. Otherwise, an ambiguity error is thrown.
 # See NOTE [ TensorOptions Constructors ].
diff --git a/aten/src/ATen/native/quantized/QTensor.cpp b/aten/src/ATen/native/quantized/QTensor.cpp
index 870f6b3..45a8a95 100644
--- a/aten/src/ATen/native/quantized/QTensor.cpp
+++ b/aten/src/ATen/native/quantized/QTensor.cpp
@@ -4,6 +4,7 @@
 #include <ATen/native/cpu/Loops.h>
 #include <ATen/quantized/QTensorImpl.h>
 #include <ATen/quantized/Quantizer.h>
+#include <ATen/native/quantized/cpu/quant_utils.h>
 
 namespace at {
 namespace native {
@@ -218,5 +219,24 @@
   return 0 == memcmp(self_data, other_data, self.numel() * self.element_size());
 }
 
+/* Calculate the quantization params for the activation tensor */
+std::tuple<double, int64_t> _choose_qparams_per_tensor(const Tensor& self, bool reduce_range) {
+  auto input_contig = self.contiguous();
+  float x_min = input_contig.min().item<float>();
+  float x_max = input_contig.max().item<float>();
+
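+  // qmin/qmax below are hard-coded for quint8 activations ([0, 255]).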
+  auto q_params = quant_utils::ChooseQuantizationParams(
+        /*min=*/x_min,
+        /*max=*/x_max,
+        /*qmin=*/0,
+        /*qmax=*/255,
+        /*preserve_sparsity=*/false,
+        /*force_scale_power_of_two=*/false,
+        /*reduce_range=*/reduce_range);
+
+  return std::make_tuple(q_params.scale, q_params.zero_point);
+}
+
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/quantized/cpu/quant_utils.h b/aten/src/ATen/native/quantized/cpu/quant_utils.h
index 408af1c..ffc5f83 100644
--- a/aten/src/ATen/native/quantized/cpu/quant_utils.h
+++ b/aten/src/ATen/native/quantized/cpu/quant_utils.h
@@ -15,7 +15,7 @@
 // to the real value 0, and 'scale' is the difference of real values
 // corresponding to consecutive quantized values.
 struct TensorQuantizationParams {
-  float scale;
+  double scale;
   std::int32_t zero_point;
   int precision;
 };
@@ -57,15 +57,14 @@
 
-  // Use double precision for intermediate computation but use single precision
-  // in final number to reflect the actual number used during quantization.
+  // Use double precision for the intermediate computation and keep the
+  // resulting scale in double to avoid an extra rounding step here.
-  float scale = (static_cast<double>(max) - min) / (qmax - qmin);
+  double scale = (static_cast<double>(max) - min) / (qmax - qmin);
   // If scale is 0 or too small so its reciprocal is infinity, we arbitrary
   // adjust the scale to 0.1 . We want to avoid scale's reciprocal being
   // infinity because some of fbgemm code pre-computes scale's reciprocal to do
   // multiplication instead of division in the time critical part of code.
-  if (scale == 0.0f || isinf(1.0f / scale)) {
+  if (scale == 0.0f || std::isinf(1.0f / scale)) {
     scale = 0.1;
   }
-
   TORCH_CHECK(scale > 0, "quantization scale should be > 0");
 
   if (force_scale_power_of_two) {
diff --git a/test/test_quantized_tensor.py b/test/test_quantized_tensor.py
index 0be19af..dc39706 100644
--- a/test/test_quantized_tensor.py
+++ b/test/test_quantized_tensor.py
@@ -1,10 +1,16 @@
 import numpy as np
-
+import math
 import torch
 import io
 from copy import deepcopy
+from hypothesis import given
+from hypothesis import strategies as st
 
 from torch.testing._internal.common_utils import TestCase, run_tests
+import torch.testing._internal.hypothesis_utils as hu
+
+hu.assert_deadline_disabled()
+
 import tempfile
 
 class Foo(torch.nn.Module):
@@ -12,6 +18,49 @@
         super(Foo, self).__init__()
         self.qscheme = torch.per_tensor_symmetric
 
+def _calculate_dynamic_qparams(X, dtype, reduce_range=False):
+    """Calculate the dynamic quantization parameters (scale, zero_point)
+    according to the min and max element of the tensor"""
+    if isinstance(X, torch.Tensor):
+        X = X.numpy()
+    if dtype == torch.qint8:
+        if reduce_range:
+            qmin, qmax = -64, 63
+        else:
+            qmin, qmax = -128, 127
+    else:  # dtype == torch.quint8
+        if reduce_range:
+            qmin, qmax = 0, 127
+        else:
+            qmin, qmax = 0, 255
+
+    min_val = X.min().astype(dtype=np.float32)
+    max_val = X.max().astype(dtype=np.float32)
+    min_val = min(0.0, min_val)
+    max_val = max(0.0, max_val)
+    scale = (np.float64(max_val) - min_val) / (qmax - qmin)
+    if scale == 0.0 or math.isinf(1.0 / scale):
+        scale = np.float64(0.1)
+
+    zero_point_from_min = qmin - min_val / float(scale)
+    zero_point_from_max = qmax - max_val / float(scale)
+    zero_point_from_min_error = abs(qmin) - abs(min_val / float(scale))
+    zero_point_from_max_error = abs(qmax) - abs(max_val / float(scale))
+    if zero_point_from_min_error < zero_point_from_max_error:
+        initial_zero_point = zero_point_from_min
+    else:
+        initial_zero_point = zero_point_from_max
+    nudged_zero_point = 0
+
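+    # Clamp ("nudge") the zero point into the representable range [qmin, qmax].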
+    if initial_zero_point < qmin:
+        nudged_zero_point = qmin
+    elif initial_zero_point > qmax:
+        nudged_zero_point = qmax
+    else:
+        nudged_zero_point = int(round(initial_zero_point))
+
+    return [scale.astype(np.float32), int(nudged_zero_point)]
 
 class TestQuantizedTensor(TestCase):
     def test_qtensor(self):
@@ -363,6 +412,19 @@
 
         self.assertEqual(f2.qscheme, torch.per_tensor_symmetric)
 
+    @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=2, max_dims=4,
+                                              min_side=1, max_side=10),
+                       qparams=hu.qparams()),
+           reduce_range=st.booleans()
+           )
+    def test_choose_qparams(self, X, reduce_range):
+        X, (scale, zero_point, torch_type) = X
+        X = torch.from_numpy(X)
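+        # The op hard-codes the quint8 range, so the reference also uses torch.quint8.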
+        X_scale, X_zp = _calculate_dynamic_qparams(X, torch.quint8, reduce_range=reduce_range)
+        qparams = torch._choose_qparams_per_tensor(X, reduce_range)
+        np.testing.assert_array_almost_equal(X_scale, qparams[0], decimal=3)
+        self.assertEqual(X_zp, qparams[1])
 
 if __name__ == "__main__":
     run_tests()
diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py
index adf1744..11c5722 100644
--- a/tools/autograd/gen_python_functions.py
+++ b/tools/autograd/gen_python_functions.py
@@ -349,6 +349,7 @@
     'std::tuple<Tensor,Tensor,double,int64_t>',
     'std::tuple<Tensor,Tensor,Tensor,Tensor,int64_t>',
     'std::tuple<Tensor,Tensor,double,Tensor,int64_t>',
+    'std::tuple<double,int64_t>',
     'std::vector<Tensor>',
     'Scalar', 'bool', 'int64_t', 'void*', 'void',
     'QScheme', 'double',
diff --git a/torch/csrc/autograd/utils/wrap_outputs.h b/torch/csrc/autograd/utils/wrap_outputs.h
index c3c536b..9d183fe 100644
--- a/torch/csrc/autograd/utils/wrap_outputs.h
+++ b/torch/csrc/autograd/utils/wrap_outputs.h
@@ -185,4 +185,13 @@
   }
   return r.release();
 }
+
+inline PyObject* wrap(std::tuple<double, int64_t> tensors) {
+  auto r = THPObjectPtr{PyTuple_New(2)};
+  if (!r) throw python_error();
+  PyTuple_SET_ITEM(r.get(), 0, wrap(std::move(std::get<0>(tensors))));
+  PyTuple_SET_ITEM(r.get(), 1, wrap(std::move(std::get<1>(tensors))));
+  return r.release();
+}
+
 }}} // namespace torch::autograd::utils