[Executorch][Quantization][BE] Refactor Choose Qparams (#92592)

Summary: Should hopefully be a little faster. It's definitely cleaner not to create an observer inside the op.
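
As a rough usage sketch of the decomposed op this change refactors (illustrative only, not part of this patch; assumes the ops are registered by importing torch.ao.quantization.fx._decomposed):

    import torch
    import torch.ao.quantization.fx._decomposed  # importing registers the quantized_decomposed ops

    x = torch.randn(16, 32)
    # Per-tensor affine qparams for an unsigned 8-bit [0, 255] target range.
    scale, zero_point = torch.ops.quantized_decomposed.choose_qparams.tensor(
        x, 0, 255, torch.uint8
    )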

Test Plan: ci

Differential Revision: D42154677

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92592
Approved by: https://github.com/jerryzh168
diff --git a/torch/ao/quantization/fx/_decomposed.py b/torch/ao/quantization/fx/_decomposed.py
index e932c28..c659123 100644
--- a/torch/ao/quantization/fx/_decomposed.py
+++ b/torch/ao/quantization/fx/_decomposed.py
@@ -1,8 +1,9 @@
 import torch
 from torch.library import Library, impl
-from torch.ao.quantization import MinMaxObserver
+from torch.ao.quantization.utils import determine_qparams, validate_qmin_qmax
 from typing import Tuple
 
+
 # Note: decomposed means decomposed quantized tensor, using decomposed so that the
 # name is not too long
 quantized_decomposed_lib = Library("quantized_decomposed", "DEF")
@@ -182,8 +183,8 @@
 @impl(quantized_decomposed_lib, "choose_qparams.tensor", "CompositeExplicitAutograd")
 def choose_qparams_tensor(
         input: torch.Tensor,
-        quant_min: int,
-        quant_max: int,
+        qmin: int,
+        qmax: int,
         dtype: torch.dtype
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """ Given an input Tensor, derive the per tensor affine quantization parameter
@@ -200,16 +201,14 @@
        zero_point (int): quantization parameter for the target quantized Tensor
     """
     assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
-    assert quant_min < quant_max, f"Expecting quant_min to be smaller than quant_max but received min: {quant_min} max: {quant_max}"
+    validate_qmin_qmax(qmin, qmax)
 
-    # Its weird to create an observer manually just to calculate qparams. I tried refactoring this functionality out of observer
-    # into a util and then use that util directly, but I kept running into jit typing errors related to torch.qscheme not
-    # being recognized as a type. TODO: properly refactor this out to avoid observer overhead
-    tensor_dtype_to_observer_dtype = {torch.uint8: torch.quint8, torch.int8: torch.qint8}
-    observer = MinMaxObserver(quant_min=quant_min, quant_max=quant_max, dtype=tensor_dtype_to_observer_dtype[dtype])
-    observer(input)
-    scale, zero_point = observer.calculate_qparams()
-    return (scale, zero_point)
+    min_val, max_val = torch.aminmax(input)
+
+    # Future QSchemes like per_tensor_symmetric will be supported in a separate op, 'choose_qparams_symmetric'.
+    # A customized qrange is unused for non-symmetric quant, so just ignore it and set it to False here.
+    return determine_qparams(
+        min_val, max_val, qmin, qmax, input.dtype, torch.Tensor([torch.finfo(torch.float32).eps]), False)
 
 @impl(quantized_decomposed_lib, "choose_qparams.tensor", "Meta")
 def choose_qparams_tensor_meta(
diff --git a/torch/ao/quantization/observer.py b/torch/ao/quantization/observer.py
index 2134f41..d3ce875 100644
--- a/torch/ao/quantization/observer.py
+++ b/torch/ao/quantization/observer.py
@@ -13,7 +13,8 @@
 import torch
 import torch.nn as nn
 from torch.ao.quantization.utils import (
-    check_min_max_valid, calculate_qmin_qmax, is_per_tensor, is_per_channel)
+    calculate_qmin_qmax, is_per_tensor, is_per_channel, determine_qparams, QSchemeTSHack, validate_qmin_qmax
+)
 
 __all__ = [
     "default_affine_fixed_qparams_observer",
@@ -236,7 +237,7 @@
         ), "Default Observer only works for qint8, quint8 and quint4x2 data type"
         self.has_customized_qrange = (quant_min is not None) and (quant_max is not None)
         if self.has_customized_qrange:
-            self._validate_qmin_qmax(quant_min, quant_max)
+            validate_qmin_qmax(quant_min, quant_max)
         self.quant_min, self.quant_max = \
             calculate_qmin_qmax(quant_min, quant_max, self.has_customized_qrange, self.dtype, self.reduce_range)
 
@@ -269,30 +270,6 @@
         )
 
     @torch.jit.export
-    def _validate_qmin_qmax(self, quant_min: int, quant_max: int) -> None:
-        r"""Validates that the user-specified quantization range is properly initialized
-        and within the given bound supported by the observer dtype.
-
-        To accommodate lower-bit quantization with respect to the existing torch.qint8 and
-        torch.quint8 datatypes, the user can choose to use dynamic quantization range by passing
-        in a tuple of initial qmin and qmax values. One use case is these customized qmin and qmax
-        values are used to calculate static estimates of the scale and zero point for aggressive lower-bit
-        fake quantization. These estimates are compared against parameters learned through backpropagation.
-        The related literatures for scale and zero point via backpropagation are as follows:
-
-        Learned Step Size Quantization: https://openreview.net/pdf?id=rkgO66VKDS
-        Trained Quantization Thresholds: https://arxiv.org/pdf/1903.08066.pdf
-        """
-        # The variable names are prefixed with "initial" because their values (qmin and qmax) might be adjusted
-        # based on whether quantization range is reduced and the datatype (signed/unsigned) used by the observer.
-        assert (
-            quant_min <= 0 <= quant_max
-        ), "Used-specified quantization range must include 0."
-        assert (
-            quant_min < quant_max
-        ), "qmin must be strictly less than qmax for user-specified quantization range."
-
-    @torch.jit.export
     def _calculate_qparams(
         self, min_val: torch.Tensor, max_val: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -307,62 +284,27 @@
             scales: Scales tensor of shape (#channels,)
             zero_points: Zero points tensor of shape (#channels,)
         """
-        if not check_min_max_valid(min_val, max_val):
-            return torch.tensor([1.0], device=min_val.device.type), torch.tensor([0], device=min_val.device.type)
+        # See comment in ./utils.py on QSchemeTSHack and why this gross code exists
+        if self.qscheme == torch.per_tensor_affine:
+            qscheme = QSchemeTSHack.per_tensor_affine
 
-        quant_min, quant_max = self.quant_min, self.quant_max
-        min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
-        max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
+        elif self.qscheme == torch.per_channel_affine:
+            qscheme = QSchemeTSHack.per_channel_affine
 
-        device = min_val_neg.device
-        scale = torch.ones(min_val_neg.size(), dtype=torch.float32, device=device)
-        zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
+        elif self.qscheme == torch.per_tensor_symmetric:
+            qscheme = QSchemeTSHack.per_tensor_symmetric
 
-        if (
-            self.qscheme == torch.per_tensor_symmetric
-            or self.qscheme == torch.per_channel_symmetric
-        ):
-            max_val_pos = torch.max(-min_val_neg, max_val_pos)
-            scale = max_val_pos / (float(quant_max - quant_min) / 2)
-            scale = torch.max(scale, self.eps)
-            if self.dtype == torch.quint8:
-                if self.has_customized_qrange:
-                    # When customized quantization range is used, down-rounded midpoint of the range is chosen.
-                    zero_point = zero_point.new_full(
-                        zero_point.size(), (quant_min + quant_max) // 2
-                    )
-                else:
-                    zero_point = zero_point.new_full(zero_point.size(), 128)
+        elif self.qscheme == torch.per_channel_symmetric:
+            qscheme = QSchemeTSHack.per_channel_symmetric
+
         elif self.qscheme == torch.per_channel_affine_float_qparams:
-            scale = (max_val - min_val) / float(quant_max - quant_min)
-            scale = torch.where(scale > self.eps, scale, torch.ones_like(scale))
-            # We use the quantize function
-            # xq = Round(Xf * inv_scale + zero_point),
-            # setting zero_point to (-1 * min *inv_scale) we get
-            # Xq = Round((Xf - min) * inv_scale)
-            zero_point = -1 * min_val / scale
+            qscheme = QSchemeTSHack.per_channel_affine_float_qparams
         else:
-            scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
-            scale = torch.max(scale, self.eps)
-            zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int)
-            zero_point = torch.clamp(zero_point, quant_min, quant_max)
+            raise Exception(f"Unsupported Qscheme {self.qscheme}. Update QSchemeTSHack to support this new qscheme")
 
-        # For scalar values, cast them to Tensors of size 1 to keep the shape
-        # consistent with default values in FakeQuantize.
-        if len(scale.shape) == 0:
-            # TODO: switch to scale.item() after adding JIT support
-            scale = torch.tensor([float(scale)], dtype=scale.dtype, device=device)
-        if len(zero_point.shape) == 0:
-            # TODO: switch to zero_point.item() after adding JIT support
-            zero_point = torch.tensor(
-                [int(zero_point)], dtype=zero_point.dtype, device=device
-            )
-            if self.qscheme == torch.per_channel_affine_float_qparams:
-                zero_point = torch.tensor(
-                    [float(zero_point)], dtype=zero_point.dtype, device=device
-                )
-
-        return scale, zero_point
+        return determine_qparams(
+            min_val, max_val, self.quant_min, self.quant_max, self.dtype, self.eps,
+            self.has_customized_qrange, qscheme)
 
     @torch.jit.export
     def reset_min_max_vals(self):
diff --git a/torch/ao/quantization/utils.py b/torch/ao/quantization/utils.py
index a40935b..3e13d82 100644
--- a/torch/ao/quantization/utils.py
+++ b/torch/ao/quantization/utils.py
@@ -1,16 +1,17 @@
 """
 Utils shared by different modes of quantization (eager/graph)
 """
-import warnings
 import functools
-import torch
-from torch.fx import Node
-from torch.ao.quantization.quant_type import QuantType
-from typing import Tuple, Any, Union, Callable, Dict, Optional
-from torch.nn.utils.parametrize import is_parametrized
+import warnings
 from collections import OrderedDict
-from inspect import signature
-from inspect import getfullargspec
+from enum import Enum
+from inspect import getfullargspec, signature
+from typing import Any, Callable, Dict, Optional, Tuple, Union
+
+import torch
+from torch.ao.quantization.quant_type import QuantType
+from torch.fx import Node
+from torch.nn.utils.parametrize import is_parametrized
 
 NodePattern = Union[Tuple[Node, Node], Tuple[Node, Tuple[Node, Node]], Any]
 NodePattern.__module__ = "torch.ao.quantization.utils"
@@ -476,6 +477,116 @@
             normalized_kwargs[attr] = val
     return normalized_kwargs
 
+def validate_qmin_qmax(quant_min: int, quant_max: int) -> None:
+    r"""Validates that the user-specified quantization range is properly initialized
+    and within the given bound supported by the observer dtype.
+
+    To accommodate lower-bit quantization with respect to the existing torch.qint8 and
+    torch.quint8 datatypes, the user can choose to use dynamic quantization range by passing
+    in a tuple of initial qmin and qmax values. One use case is these customized qmin and qmax
+    values are used to calculate static estimates of the scale and zero point for aggressive lower-bit
+    fake quantization. These estimates are compared against parameters learned through backpropagation.
+    The related literature for scale and zero point via backpropagation is as follows:
+
+    Learned Step Size Quantization: https://openreview.net/pdf?id=rkgO66VKDS
+    Trained Quantization Thresholds: https://arxiv.org/pdf/1903.08066.pdf
+    """
+    # The values of qmin and qmax might later be adjusted based on whether the quantization
+    # range is reduced and on the (signed/unsigned) dtype used by the observer.
+    assert (
+        quant_min <= 0 <= quant_max
+    ), "Used-specified quantization range must include 0."
+    assert (
+        quant_min < quant_max
+    ), "qmin must be strictly less than qmax for user-specified quantization range."
+
+# As far as I can tell, regular torch.qscheme is not accepted by TorchScript.
+# This is a problem because we want a single source of truth for choose_qparams
+# across different stacks. As a gross hack, this lets us convert torch.qscheme into
+# a normal enum and then pass that as an arg.
+class QSchemeTSHack(Enum):
+    """ Class to allow the passing of QSchemes around on methods that must be torchscriptable,
+    ideally regular torch.qscheme would be torchscriptable and then this could be deleted.
+
+    Must match core/QScheme.h and the generated torch/_C/__init__.pyi python types
+    """
+    per_tensor_affine: int = 0
+    per_channel_affine: int = 1
+    per_tensor_symmetric: int = 2
+    per_channel_symmetric: int = 3
+    per_channel_affine_float_qparams: int = 4
+
+def determine_qparams(
+        min_val: torch.Tensor, max_val: torch.Tensor, quant_min: int, quant_max: int,
+        dtype: torch.dtype, eps: torch.Tensor, has_customized_qrange: bool,
+        qscheme: QSchemeTSHack = QSchemeTSHack.per_tensor_affine) -> Tuple[torch.Tensor, torch.Tensor]:
+    r"""Calculates the quantization parameters, given min and max
+    value tensors. Works for both per tensor and per channel cases.
+
+    Args:
+        min_val: Minimum values per channel
+        max_val: Maximum values per channel
+
+    Returns:
+        scales: Scales tensor of shape (#channels,)
+        zero_points: Zero points tensor of shape (#channels,)
+    """
+    if not check_min_max_valid(min_val, max_val):
+        return torch.tensor([1.0], device=min_val.device.type), torch.tensor([0], device=min_val.device.type)
+
+    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
+    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
+
+    device = min_val_neg.device
+    scale = torch.ones(min_val_neg.size(), dtype=torch.float32, device=device)
+    zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
+
+    if (
+        qscheme == QSchemeTSHack.per_tensor_symmetric
+        or qscheme == QSchemeTSHack.per_channel_symmetric
+    ):
+        max_val_pos = torch.max(-min_val_neg, max_val_pos)
+        scale = max_val_pos / (float(quant_max - quant_min) / 2)
+        scale = torch.max(scale, eps)
+        if dtype == torch.quint8:
+            if has_customized_qrange:
+                # When customized quantization range is used, down-rounded midpoint of the range is chosen.
+                zero_point = zero_point.new_full(
+                    zero_point.size(), (quant_min + quant_max) // 2
+                )
+            else:
+                zero_point = zero_point.new_full(zero_point.size(), 128)
+    elif qscheme == QSchemeTSHack.per_channel_affine_float_qparams:
+        scale = (max_val - min_val) / float(quant_max - quant_min)
+        scale = torch.where(scale > eps, scale, torch.ones_like(scale))
+        # We use the quantize function
+        # xq = Round(Xf * inv_scale + zero_point),
+        # setting zero_point to (-1 * min *inv_scale) we get
+        # Xq = Round((Xf - min) * inv_scale)
+        zero_point = -1 * min_val / scale
+    else:
+        scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
+        scale = torch.max(scale, eps)
+        zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int)
+        zero_point = torch.clamp(zero_point, quant_min, quant_max)
+
+    # For scalar values, cast them to Tensors of size 1 to keep the shape
+    # consistent with default values in FakeQuantize.
+    if len(scale.shape) == 0:
+        # TODO: switch to scale.item() after adding JIT support
+        scale = torch.tensor([float(scale)], dtype=scale.dtype, device=device)
+    if len(zero_point.shape) == 0:
+        # TODO: switch to zero_point.item() after adding JIT support
+        zero_point = torch.tensor(
+            [int(zero_point)], dtype=zero_point.dtype, device=device
+        )
+        if qscheme == QSchemeTSHack.per_channel_affine_float_qparams:
+            zero_point = torch.tensor(
+                [float(zero_point)], dtype=zero_point.dtype, device=device
+            )
+
+    return scale, zero_point
+
 def _get_num_pos_args(f: Callable) -> int:
     """ Get number of positional args for a function
 
@@ -662,4 +773,7 @@
     "has_no_children_ignoring_parametrizations",
     "get_fqn_to_example_inputs",
     "to_underlying_dtype",
+    "determine_qparams",
+    "QSchemeTSHack",
+    "validate_qmin_qmax",
 ]
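
For reference, a minimal sketch of calling the shared utility directly with the signature added above (example values are illustrative):

    import torch
    from torch.ao.quantization.utils import QSchemeTSHack, determine_qparams

    x = torch.randn(16, 32)
    min_val, max_val = torch.aminmax(x)
    eps = torch.tensor([torch.finfo(torch.float32).eps])
    # Per-tensor affine qparams for a uint8 [0, 255] target range; qrange is not customized.
    scale, zero_point = determine_qparams(
        min_val, max_val, 0, 255, torch.uint8, eps,
        has_customized_qrange=False, qscheme=QSchemeTSHack.per_tensor_affine,
    )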