[Executorch][Quantization][BE] Refactor Choose Qparams (#92592)
Summary: Should hopefully be a little faster, and it is definitely cleaner not to create an observer inside the op.
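
A rough usage sketch of the refactored op (illustrative only: it assumes the decomposed lib is registered by importing torch.ao.quantization.fx._decomposed, and the torch.ops call path shown here is my assumption rather than something exercised in this diff):

    import torch
    import torch.ao.quantization.fx._decomposed  # registers the quantized_decomposed library

    x = torch.randn(16, 32)  # float32 input
    # choose_qparams.tensor now computes min/max via torch.aminmax and calls the
    # determine_qparams util directly instead of constructing a MinMaxObserver
    scale, zero_point = torch.ops.quantized_decomposed.choose_qparams.tensor(
        x, 0, 255, torch.uint8)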
Test Plan: ci
Differential Revision: D42154677
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92592
Approved by: https://github.com/jerryzh168
diff --git a/torch/ao/quantization/fx/_decomposed.py b/torch/ao/quantization/fx/_decomposed.py
index e932c28..c659123 100644
--- a/torch/ao/quantization/fx/_decomposed.py
+++ b/torch/ao/quantization/fx/_decomposed.py
@@ -1,8 +1,9 @@
import torch
from torch.library import Library, impl
-from torch.ao.quantization import MinMaxObserver
+from torch.ao.quantization.utils import determine_qparams, validate_qmin_qmax
from typing import Tuple
+
# Note: decomposed means decomposed quantized tensor, using decomposed so that the
# name is not too long
quantized_decomposed_lib = Library("quantized_decomposed", "DEF")
@@ -182,8 +183,8 @@
@impl(quantized_decomposed_lib, "choose_qparams.tensor", "CompositeExplicitAutograd")
def choose_qparams_tensor(
input: torch.Tensor,
- quant_min: int,
- quant_max: int,
+ qmin: int,
+ qmax: int,
dtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
""" Given an input Tensor, derive the per tensor affine quantization parameter
@@ -200,16 +201,14 @@
zero_point (int): quantization parameter for the target quantized Tensor
"""
assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
- assert quant_min < quant_max, f"Expecting quant_min to be smaller than quant_max but received min: {quant_min} max: {quant_max}"
+ validate_qmin_qmax(qmin, qmax)
- # Its weird to create an observer manually just to calculate qparams. I tried refactoring this functionality out of observer
- # into a util and then use that util directly, but I kept running into jit typing errors related to torch.qscheme not
- # being recognized as a type. TODO: properly refactor this out to avoid observer overhead
- tensor_dtype_to_observer_dtype = {torch.uint8: torch.quint8, torch.int8: torch.qint8}
- observer = MinMaxObserver(quant_min=quant_min, quant_max=quant_max, dtype=tensor_dtype_to_observer_dtype[dtype])
- observer(input)
- scale, zero_point = observer.calculate_qparams()
- return (scale, zero_point)
+ min_val, max_val = torch.aminmax(input)
+
+ # Future QSchemes like per_tensor_symmetric will be supported in a different op, 'choose_qparams_symmetric'.
+ # Customized qrange is unused for non-symmetric quant, so it is ignored and set to False here
+ return determine_qparams(
+ min_val, max_val, qmin, qmax, input.dtype, torch.Tensor([torch.finfo(torch.float32).eps]), False)
@impl(quantized_decomposed_lib, "choose_qparams.tensor", "Meta")
def choose_qparams_tensor_meta(
diff --git a/torch/ao/quantization/observer.py b/torch/ao/quantization/observer.py
index 2134f41..d3ce875 100644
--- a/torch/ao/quantization/observer.py
+++ b/torch/ao/quantization/observer.py
@@ -13,7 +13,8 @@
import torch
import torch.nn as nn
from torch.ao.quantization.utils import (
- check_min_max_valid, calculate_qmin_qmax, is_per_tensor, is_per_channel)
+ calculate_qmin_qmax, is_per_tensor, is_per_channel, determine_qparams, QSchemeTSHack, validate_qmin_qmax
+)
__all__ = [
"default_affine_fixed_qparams_observer",
@@ -236,7 +237,7 @@
), "Default Observer only works for qint8, quint8 and quint4x2 data type"
self.has_customized_qrange = (quant_min is not None) and (quant_max is not None)
if self.has_customized_qrange:
- self._validate_qmin_qmax(quant_min, quant_max)
+ validate_qmin_qmax(quant_min, quant_max)
self.quant_min, self.quant_max = \
calculate_qmin_qmax(quant_min, quant_max, self.has_customized_qrange, self.dtype, self.reduce_range)
@@ -269,30 +270,6 @@
)
@torch.jit.export
- def _validate_qmin_qmax(self, quant_min: int, quant_max: int) -> None:
- r"""Validates that the user-specified quantization range is properly initialized
- and within the given bound supported by the observer dtype.
-
- To accommodate lower-bit quantization with respect to the existing torch.qint8 and
- torch.quint8 datatypes, the user can choose to use dynamic quantization range by passing
- in a tuple of initial qmin and qmax values. One use case is these customized qmin and qmax
- values are used to calculate static estimates of the scale and zero point for aggressive lower-bit
- fake quantization. These estimates are compared against parameters learned through backpropagation.
- The related literatures for scale and zero point via backpropagation are as follows:
-
- Learned Step Size Quantization: https://openreview.net/pdf?id=rkgO66VKDS
- Trained Quantization Thresholds: https://arxiv.org/pdf/1903.08066.pdf
- """
- # The variable names are prefixed with "initial" because their values (qmin and qmax) might be adjusted
- # based on whether quantization range is reduced and the datatype (signed/unsigned) used by the observer.
- assert (
- quant_min <= 0 <= quant_max
- ), "Used-specified quantization range must include 0."
- assert (
- quant_min < quant_max
- ), "qmin must be strictly less than qmax for user-specified quantization range."
-
- @torch.jit.export
def _calculate_qparams(
self, min_val: torch.Tensor, max_val: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -307,62 +284,27 @@
scales: Scales tensor of shape (#channels,)
zero_points: Zero points tensor of shape (#channels,)
"""
- if not check_min_max_valid(min_val, max_val):
- return torch.tensor([1.0], device=min_val.device.type), torch.tensor([0], device=min_val.device.type)
+ # See comment in ./utils.py on QSchemeTSHack and why this gross code exists
+ if self.qscheme == torch.per_tensor_affine:
+ qscheme = QSchemeTSHack.per_tensor_affine
- quant_min, quant_max = self.quant_min, self.quant_max
- min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
- max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
+ elif self.qscheme == torch.per_channel_affine:
+ qscheme = QSchemeTSHack.per_channel_affine
- device = min_val_neg.device
- scale = torch.ones(min_val_neg.size(), dtype=torch.float32, device=device)
- zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
+ elif self.qscheme == torch.per_tensor_symmetric:
+ qscheme = QSchemeTSHack.per_tensor_symmetric
- if (
- self.qscheme == torch.per_tensor_symmetric
- or self.qscheme == torch.per_channel_symmetric
- ):
- max_val_pos = torch.max(-min_val_neg, max_val_pos)
- scale = max_val_pos / (float(quant_max - quant_min) / 2)
- scale = torch.max(scale, self.eps)
- if self.dtype == torch.quint8:
- if self.has_customized_qrange:
- # When customized quantization range is used, down-rounded midpoint of the range is chosen.
- zero_point = zero_point.new_full(
- zero_point.size(), (quant_min + quant_max) // 2
- )
- else:
- zero_point = zero_point.new_full(zero_point.size(), 128)
+ elif self.qscheme == torch.per_channel_symmetric:
+ qscheme = QSchemeTSHack.per_channel_symmetric
+
elif self.qscheme == torch.per_channel_affine_float_qparams:
- scale = (max_val - min_val) / float(quant_max - quant_min)
- scale = torch.where(scale > self.eps, scale, torch.ones_like(scale))
- # We use the quantize function
- # xq = Round(Xf * inv_scale + zero_point),
- # setting zero_point to (-1 * min *inv_scale) we get
- # Xq = Round((Xf - min) * inv_scale)
- zero_point = -1 * min_val / scale
+ qscheme = QSchemeTSHack.per_channel_affine_float_qparams
else:
- scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
- scale = torch.max(scale, self.eps)
- zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int)
- zero_point = torch.clamp(zero_point, quant_min, quant_max)
+ raise Exception(f"Unsupported Qscheme {self.qscheme}. Update QSchemeTSHack to support this new qscheme")
- # For scalar values, cast them to Tensors of size 1 to keep the shape
- # consistent with default values in FakeQuantize.
- if len(scale.shape) == 0:
- # TODO: switch to scale.item() after adding JIT support
- scale = torch.tensor([float(scale)], dtype=scale.dtype, device=device)
- if len(zero_point.shape) == 0:
- # TODO: switch to zero_point.item() after adding JIT support
- zero_point = torch.tensor(
- [int(zero_point)], dtype=zero_point.dtype, device=device
- )
- if self.qscheme == torch.per_channel_affine_float_qparams:
- zero_point = torch.tensor(
- [float(zero_point)], dtype=zero_point.dtype, device=device
- )
-
- return scale, zero_point
+ return determine_qparams(
+ min_val, max_val, self.quant_min, self.quant_max, self.dtype, self.eps,
+ self.has_customized_qrange, qscheme)
@torch.jit.export
def reset_min_max_vals(self):
diff --git a/torch/ao/quantization/utils.py b/torch/ao/quantization/utils.py
index a40935b..3e13d82 100644
--- a/torch/ao/quantization/utils.py
+++ b/torch/ao/quantization/utils.py
@@ -1,16 +1,17 @@
"""
Utils shared by different modes of quantization (eager/graph)
"""
-import warnings
import functools
-import torch
-from torch.fx import Node
-from torch.ao.quantization.quant_type import QuantType
-from typing import Tuple, Any, Union, Callable, Dict, Optional
-from torch.nn.utils.parametrize import is_parametrized
+import warnings
from collections import OrderedDict
-from inspect import signature
-from inspect import getfullargspec
+from enum import Enum
+from inspect import getfullargspec, signature
+from typing import Any, Callable, Dict, Optional, Tuple, Union
+
+import torch
+from torch.ao.quantization.quant_type import QuantType
+from torch.fx import Node
+from torch.nn.utils.parametrize import is_parametrized
NodePattern = Union[Tuple[Node, Node], Tuple[Node, Tuple[Node, Node]], Any]
NodePattern.__module__ = "torch.ao.quantization.utils"
@@ -476,6 +477,116 @@
normalized_kwargs[attr] = val
return normalized_kwargs
+def validate_qmin_qmax(quant_min: int, quant_max: int) -> None:
+ r"""Validates that the user-specified quantization range is properly initialized
+ and within the given bound supported by the observer dtype.
+
+ To accommodate lower-bit quantization with respect to the existing torch.qint8 and
+ torch.quint8 datatypes, the user can choose to use dynamic quantization range by passing
+ in a tuple of initial qmin and qmax values. One use case is these customized qmin and qmax
+ values are used to calculate static estimates of the scale and zero point for aggressive lower-bit
+ fake quantization. These estimates are compared against parameters learned through backpropagation.
+ The related literature for scale and zero point via backpropagation is as follows:
+
+ Learned Step Size Quantization: https://openreview.net/pdf?id=rkgO66VKDS
+ Trained Quantization Thresholds: https://arxiv.org/pdf/1903.08066.pdf
+ """
+ # The variable names are prefixed with "initial" because their values (qmin and qmax) might be adjusted
+ # based on whether quantization range is reduced and the datatype (signed/unsigned) used by the observer.
+ assert (
+ quant_min <= 0 <= quant_max
+ ), "Used-specified quantization range must include 0."
+ assert (
+ quant_min < quant_max
+ ), "qmin must be strictly less than qmax for user-specified quantization range."
+
+# As far as I can tell, regular torch.qscheme is not accepted by TorchScript.
+# This is a problem because we want a single source of truth for choose_qparams
+# across different stacks. As a gross hack, this lets us convert torch.qscheme into a normal
+# enum and then pass that as an arg.
+class QSchemeTSHack(Enum):
+ """ Class to allow the passing of QSchemes around on methods that must be torchscriptable,
+ ideally regular torch.qscheme would be torchscriptable and then this could be deleted.
+
+ Must match core/QScheme.h and the generated torch/_C/__init__.pyi python types
+ """
+ per_tensor_affine: int = 0
+ per_channel_affine: int = 1
+ per_tensor_symmetric: int = 2
+ per_channel_symmetric: int = 3
+ per_channel_affine_float_qparams: int = 4
+
+def determine_qparams(
+ min_val: torch.Tensor, max_val: torch.Tensor, quant_min: int, quant_max: int,
+ dtype: torch.dtype, eps: torch.Tensor, has_customized_qrange: bool,
+ qscheme: QSchemeTSHack = QSchemeTSHack.per_tensor_affine) -> Tuple[torch.Tensor, torch.Tensor]:
+ r"""Calculates the quantization parameters, given min and max
+ value tensors. Works for both per tensor and per channel cases
+
+ Args:
+ min_val: Minimum values per channel
+ max_val: Maximum values per channel
+
+ Returns:
+ scales: Scales tensor of shape (#channels,)
+ zero_points: Zero points tensor of shape (#channels,)
+ """
+ if not check_min_max_valid(min_val, max_val):
+ return torch.tensor([1.0], device=min_val.device.type), torch.tensor([0], device=min_val.device.type)
+
+ min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
+ max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
+
+ device = min_val_neg.device
+ scale = torch.ones(min_val_neg.size(), dtype=torch.float32, device=device)
+ zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
+
+ if (
+ qscheme == QSchemeTSHack.per_tensor_symmetric
+ or qscheme == QSchemeTSHack.per_channel_symmetric
+ ):
+ max_val_pos = torch.max(-min_val_neg, max_val_pos)
+ scale = max_val_pos / (float(quant_max - quant_min) / 2)
+ scale = torch.max(scale, eps)
+ if dtype == torch.quint8:
+ if has_customized_qrange:
+ # When customized quantization range is used, down-rounded midpoint of the range is chosen.
+ zero_point = zero_point.new_full(
+ zero_point.size(), (quant_min + quant_max) // 2
+ )
+ else:
+ zero_point = zero_point.new_full(zero_point.size(), 128)
+ elif qscheme == QSchemeTSHack.per_channel_affine_float_qparams:
+ scale = (max_val - min_val) / float(quant_max - quant_min)
+ scale = torch.where(scale > eps, scale, torch.ones_like(scale))
+ # We use the quantize function
+ # xq = Round(Xf * inv_scale + zero_point),
+ # setting zero_point to (-1 * min * inv_scale) we get
+ # Xq = Round((Xf - min) * inv_scale)
+ zero_point = -1 * min_val / scale
+ else:
+ scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
+ scale = torch.max(scale, eps)
+ zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int)
+ zero_point = torch.clamp(zero_point, quant_min, quant_max)
+
+ # For scalar values, cast them to Tensors of size 1 to keep the shape
+ # consistent with default values in FakeQuantize.
+ if len(scale.shape) == 0:
+ # TODO: switch to scale.item() after adding JIT support
+ scale = torch.tensor([float(scale)], dtype=scale.dtype, device=device)
+ if len(zero_point.shape) == 0:
+ # TODO: switch to zero_point.item() after adding JIT support
+ zero_point = torch.tensor(
+ [int(zero_point)], dtype=zero_point.dtype, device=device
+ )
+ if qscheme == QSchemeTSHack.per_channel_affine_float_qparams:
+ zero_point = torch.tensor(
+ [float(zero_point)], dtype=zero_point.dtype, device=device
+ )
+
+ return scale, zero_point
+
def _get_num_pos_args(f: Callable) -> int:
""" Get number of positional args for a function
@@ -662,4 +773,7 @@
"has_no_children_ignoring_parametrizations",
"get_fqn_to_example_inputs",
"to_underlying_dtype",
+ "determine_qparams",
+ "QSchemeTSHack",
+ "validate_qmin_qmax",
]
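
For reference, a minimal sketch of calling the new util directly, outside of any observer (argument values are illustrative and not taken from this diff):

    import torch
    from torch.ao.quantization.utils import determine_qparams, QSchemeTSHack

    x = torch.randn(8, 128)
    min_val, max_val = torch.aminmax(x)
    eps = torch.tensor([torch.finfo(torch.float32).eps])
    # per-tensor affine qparams for a quint8 target, no customized qrange
    scale, zero_point = determine_qparams(
        min_val, max_val, 0, 255, torch.quint8, eps, False,
        QSchemeTSHack.per_tensor_affine)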