Support fp8 quantization (#123161)
This commit enables the float8_e5m2 and float8_e4m3fn dtypes in FX graph mode quantization and PT2E.
Motivation for using fp8 quantization instead of int8:
- inference works better when run in the same data type the model was trained with,
- fp8 handles outliers better; outlier activations are a known problem in LLMs.
The numerical recipe we want to target is fp8 inference (see the sketch after this list):
- bgemms/gemms running in float8_e4m3fn,
- per-tensor quantization/scaling,
- an amax observer for measurement, with input_backoff and weight_backoff.
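As a minimal sketch of what the enabled dtype support looks like from a PT2E quantizer, mirroring the parametrized test below (import paths are assumed from the public torch.ao.quantization API, and observer.default_observer is taken from that test; the amax/backoff observer mentioned above is not part of this change):

    import torch
    from torch.ao.quantization import observer
    from torch.ao.quantization.quantizer import QuantizationSpec

    fp8 = torch.float8_e4m3fn
    # For float dtypes the quant range comes from torch.finfo (torch.iinfo for integer dtypes).
    act_qspec = QuantizationSpec(
        dtype=fp8,
        quant_min=int(torch.finfo(fp8).min),
        quant_max=int(torch.finfo(fp8).max),
        qscheme=torch.per_tensor_affine,  # per-tensor scaling
        is_dynamic=False,
        observer_or_fake_quant_ctr=observer.default_observer,
    )

Such a spec can then be placed into a QuantizationConfig for input/output activations, exactly as the test does for the int16 case.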
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123161
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
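For reference, with the _DTYPE_TO_QVALUE_BOUNDS change in _decomposed.py below, the quant_min/quant_max bounds for the new float dtypes are derived from torch.finfo rather than torch.iinfo; a small illustration of the resulting ranges:

    import torch

    # Finite ranges of the fp8 formats, i.e. the qvalue bounds registered below:
    #   torch.float8_e4m3fn -> (-448, 448)
    #   torch.float8_e5m2   -> (-57344, 57344)
    for dt in (torch.float8_e4m3fn, torch.float8_e5m2):
        print(dt, int(torch.finfo(dt).min), int(torch.finfo(dt).max))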
diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py
index e3b7ead..3c759fc 100644
--- a/test/quantization/pt2e/test_quantize_pt2e.py
+++ b/test/quantization/pt2e/test_quantize_pt2e.py
@@ -49,7 +49,11 @@
skipIfNoQNNPACK,
TestHelperModules,
)
-from torch.testing._internal.common_utils import TemporaryFileName
+from torch.testing._internal.common_utils import (
+ instantiate_parametrized_tests,
+ parametrize,
+ TemporaryFileName,
+)
@skipIfNoQNNPACK
@@ -1175,14 +1179,15 @@
self.assertIsNot(observers[0], observers[2])
self.assertIsNot(observers[1], observers[2])
- def test_int16(self):
- class Int16ActQuantizer(Quantizer):
+ @parametrize("dtype", (torch.int16, torch.float8_e5m2, torch.float8_e4m3fn))
+ def test_quantization_dtype(self, dtype):
+ class DtypeActQuantizer(Quantizer):
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
- # using int32 to simulate int16
- int16_qspec = QuantizationSpec(
- dtype=torch.int16,
- quant_min=-(2**15),
- quant_max=2**15 - 1,
+ info_fun = torch.iinfo if dtype == torch.int16 else torch.finfo
+ activate_qspec = QuantizationSpec(
+ dtype=dtype,
+ quant_min=int(info_fun(dtype).min),
+ quant_max=int(info_fun(dtype).max),
qscheme=torch.per_tensor_affine,
is_dynamic=False,
observer_or_fake_quant_ctr=observer.default_observer,
@@ -1196,10 +1201,10 @@
observer_or_fake_quant_ctr=observer.default_weight_observer,
)
quantization_config = QuantizationConfig(
- input_activation=int16_qspec,
+ input_activation=activate_qspec,
weight=int8_qspec,
bias=None,
- output_activation=int16_qspec,
+ output_activation=activate_qspec,
)
OP_TO_ANNOTATOR["conv"](model, quantization_config)
@@ -1214,7 +1219,7 @@
def forward(self, x):
return self.conv(x)
- quantizer = Int16ActQuantizer()
+ quantizer = DtypeActQuantizer()
node_occurrence = {
# one for input of the first conv, one for output for the first conv
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
@@ -1230,7 +1235,7 @@
self._test_quantizer(
M().eval(),
example_inputs,
- Int16ActQuantizer(),
+ quantizer,
node_occurrence,
node_list,
)
@@ -2248,3 +2253,6 @@
node_occurrence,
node_list,
)
+
+
+instantiate_parametrized_tests(TestQuantizePT2E)
diff --git a/torch/ao/quantization/fx/_decomposed.py b/torch/ao/quantization/fx/_decomposed.py
index c54a304..8feafaf 100644
--- a/torch/ao/quantization/fx/_decomposed.py
+++ b/torch/ao/quantization/fx/_decomposed.py
@@ -10,12 +10,11 @@
# name is not too long
quantized_decomposed_lib = Library("quantized_decomposed", "DEF")
-_DTYPE_TO_QVALUE_BOUNDS = {
- torch.uint8: (0, 255),
- torch.int8: (-128, 127),
- torch.int16: (-(2**15), 2**15 - 1),
- torch.int32: (-(2**31), 2**31 - 1),
-}
+_INTEGER_DTYPES = [torch.uint8, torch.int8, torch.int16, torch.int32]
+_FLOAT_DTYPES = [torch.float8_e5m2, torch.float8_e4m3fn]
+
+_DTYPE_TO_QVALUE_BOUNDS = {k: (torch.iinfo(k).min, torch.iinfo(k).max) for k in _INTEGER_DTYPES}
+_DTYPE_TO_QVALUE_BOUNDS.update({k: (int(torch.finfo(k).min), int(torch.finfo(k).max)) for k in _FLOAT_DTYPES})
# Helper to check the passed in quant min and max are valid for the dtype
def _quant_min_max_bounds_check(quant_min, quant_max, dtype):
diff --git a/torch/ao/quantization/fx/convert.py b/torch/ao/quantization/fx/convert.py
index 023abff..ef90f8b 100644
--- a/torch/ao/quantization/fx/convert.py
+++ b/torch/ao/quantization/fx/convert.py
@@ -84,6 +84,18 @@
"convert_weighted_module",
]
+SUPPORTED_QDTYPES = [
+ torch.quint8,
+ torch.qint8,
+ torch.qint32,
+ torch.uint8,
+ torch.int8,
+ torch.int16,
+ torch.int32,
+ torch.float8_e5m2,
+ torch.float8_e4m3fn,
+]
+
_QSCHEME_TO_CHOOSE_QPARAMS_OP = {
torch.per_tensor_affine: torch.ops.quantized_decomposed.choose_qparams.tensor,
torch.per_tensor_symmetric: torch.ops.quantized_decomposed.choose_qparams_symmetric.tensor,
@@ -136,8 +148,7 @@
if hasattr(activation_post_process, "is_dynamic"):
is_dynamic = activation_post_process.is_dynamic # type: ignore[assignment]
- if dtype in [torch.quint8, torch.qint8, torch.qint32, torch.uint8, torch.int8, torch.int16, torch.int32] and \
- (not is_dynamic):
+ if dtype in SUPPORTED_QDTYPES and (not is_dynamic):
# TODO: probably should cleanup this condition check, it's hard
# to reason about this if and the following elif
@@ -372,7 +383,7 @@
if hasattr(activation_post_process, "is_dynamic"):
is_dynamic = activation_post_process.is_dynamic # type: ignore[attr-defined, assignment]
- if dtype in [torch.quint8, torch.qint8, torch.qint32] and \
+ if dtype in [torch.quint8, torch.qint8, torch.qint32, torch.float8_e5m2, torch.float8_e4m3fn] and \
(not is_dynamic):
# TODO: probably should cleanup this condition check, it's hard
# to reason about this if and the following elif
@@ -477,15 +488,7 @@
is_dynamic = activation_post_process.is_dynamic # type: ignore[attr-defined, assignment]
return (
- (dtype in [
- torch.quint8,
- torch.qint8,
- torch.qint32,
- torch.uint8,
- torch.int8,
- torch.int16,
- torch.int32
- ] and (not is_dynamic)) or # type: ignore[return-value]
+ (dtype in SUPPORTED_QDTYPES and (not is_dynamic)) or # type: ignore[return-value]
is_dynamic or
dtype == torch.float16
)
diff --git a/torch/ao/quantization/fx/prepare.py b/torch/ao/quantization/fx/prepare.py
index 6a4ae0b..9ca91ec 100644
--- a/torch/ao/quantization/fx/prepare.py
+++ b/torch/ao/quantization/fx/prepare.py
@@ -138,7 +138,9 @@
torch.uint8,
torch.int8,
torch.int16,
- torch.int32
+ torch.int32,
+ torch.float8_e5m2,
+ torch.float8_e4m3fn,
]
_DEFAULT_FP32_OBS_OR_FQ_CTR = PlaceholderObserver.with_args(dtype=torch.float)
diff --git a/torch/ao/quantization/observer.py b/torch/ao/quantization/observer.py
index 718dc7d..5f075df 100644
--- a/torch/ao/quantization/observer.py
+++ b/torch/ao/quantization/observer.py
@@ -244,6 +244,8 @@
torch.uint8,
torch.int16,
torch.int32,
+ torch.float8_e5m2,
+ torch.float8_e4m3fn,
)
assert self.dtype in _ALLOWED_DTYPES, f"Default Observer only works for {_ALLOWED_DTYPES} data type"
diff --git a/torch/ao/quantization/utils.py b/torch/ao/quantization/utils.py
index 2a225d1..70b45b9 100644
--- a/torch/ao/quantization/utils.py
+++ b/torch/ao/quantization/utils.py
@@ -151,6 +151,8 @@
torch.int8: torch.int8,
torch.int16: torch.int16,
torch.int32: torch.int32,
+ torch.float8_e5m2: torch.float8_e5m2,
+ torch.float8_e4m3fn: torch.float8_e4m3fn,
}
assert qdtype in DTYPE_MAPPING, "Unsupported dtype: " + str(qdtype)
return DTYPE_MAPPING[qdtype]
@@ -231,7 +233,9 @@
torch.uint8,
torch.int8,
torch.int16,
- torch.int32
+ torch.int32,
+ torch.float8_e5m2,
+ torch.float8_e4m3fn,
]
and (not activation_is_dynamically_quantized(qconfig))
)
@@ -269,7 +273,9 @@
torch.uint8,
torch.int8,
torch.int16,
- torch.int32
+ torch.int32,
+ torch.float8_e5m2,
+ torch.float8_e4m3fn,
]
def weight_is_statically_quantized(qconfig):
@@ -305,7 +311,18 @@
assert qconfig is not None
activation = qconfig.activation()
weight = qconfig.weight()
- static_dtypes = [torch.quint8, torch.qint8, torch.quint4x2, torch.qint32, torch.uint8, torch.int8, torch.int16, torch.int32]
+ static_dtypes = [
+ torch.quint8,
+ torch.qint8,
+ torch.quint4x2,
+ torch.qint32,
+ torch.uint8,
+ torch.int8,
+ torch.int16,
+ torch.int32,
+ torch.float8_e5m2,
+ torch.float8_e4m3fn
+ ]
if weight.dtype in static_dtypes:
if hasattr(activation, 'is_dynamic') and activation.is_dynamic:
return QuantType.DYNAMIC