Support fp8 quantization (#123161)

This commit enables the float8_e5m2 and float8_e4m3fn dtypes in FX graph mode quantization and PT2E.

Motivation for using fp8 quantization instead of int8:
- running inference in the same datatype the model was trained with works better,
- fp8 handles outliers better, outliers being one of the known problems in LLM activations.

The numerical recipe we target is fp8 inference (a sketch of such a configuration follows the list below):
- bgemms/gemms run in float8_e4m3fn,
- per-tensor quantization/scaling,
- an amax observer for measurement, with input_backoff and weight_backoff.
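
A minimal sketch of how such a recipe could be expressed with the PT2E `QuantizationSpec` API, mirroring the parametrized test added in this diff. This is illustrative only, not code from the PR; the backoff factors mentioned above are part of the intended recipe and are not shown here.

```python
import torch
from torch.ao.quantization import observer
from torch.ao.quantization.quantizer import QuantizationSpec

# Static, per-tensor activation spec in float8_e4m3fn. quant_min/quant_max are
# derived from torch.finfo, matching how the test and _decomposed.py compute them.
act_dtype = torch.float8_e4m3fn
act_qspec = QuantizationSpec(
    dtype=act_dtype,
    quant_min=int(torch.finfo(act_dtype).min),  # -448 for e4m3fn
    quant_max=int(torch.finfo(act_dtype).max),  # 448 for e4m3fn
    qscheme=torch.per_tensor_affine,
    is_dynamic=False,
    observer_or_fake_quant_ctr=observer.default_observer,
)
```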
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123161
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
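
At the op level, the change boils down to extending `_DTYPE_TO_QVALUE_BOUNDS` and the supported-dtype lists so the decomposed per-tensor ops accept fp8. The snippet below is a rough sketch of exercising them (not taken from this PR's tests), assuming the existing `quantize_per_tensor(input, scale, zero_point, quant_min, quant_max, dtype)` signature; the scale and zero_point values are placeholders.

```python
import torch

# fp8 bounds come from torch.finfo, as in the updated _DTYPE_TO_QVALUE_BOUNDS.
dtype = torch.float8_e5m2
qmin, qmax = int(torch.finfo(dtype).min), int(torch.finfo(dtype).max)

x = torch.randn(4, 4)          # float32 input
scale, zero_point = 0.1, 0     # illustrative per-tensor qparams

# Quantize to float8_e5m2 and dequantize back to float32.
xq = torch.ops.quantized_decomposed.quantize_per_tensor(
    x, scale, zero_point, qmin, qmax, dtype
)
xdq = torch.ops.quantized_decomposed.dequantize_per_tensor(
    xq, scale, zero_point, qmin, qmax, dtype
)
```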
diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py
index e3b7ead..3c759fc 100644
--- a/test/quantization/pt2e/test_quantize_pt2e.py
+++ b/test/quantization/pt2e/test_quantize_pt2e.py
@@ -49,7 +49,11 @@
     skipIfNoQNNPACK,
     TestHelperModules,
 )
-from torch.testing._internal.common_utils import TemporaryFileName
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    parametrize,
+    TemporaryFileName,
+)
 
 
 @skipIfNoQNNPACK
@@ -1175,14 +1179,15 @@
         self.assertIsNot(observers[0], observers[2])
         self.assertIsNot(observers[1], observers[2])
 
-    def test_int16(self):
-        class Int16ActQuantizer(Quantizer):
+    @parametrize("dtype", (torch.int16, torch.float8_e5m2, torch.float8_e4m3fn))
+    def test_quantization_dtype(self, dtype):
+        class DtypeActQuantizer(Quantizer):
             def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
-                # using int32 to simulate int16
-                int16_qspec = QuantizationSpec(
-                    dtype=torch.int16,
-                    quant_min=-(2**15),
-                    quant_max=2**15 - 1,
+                info_fun = torch.iinfo if dtype == torch.int16 else torch.finfo
+                activate_qspec = QuantizationSpec(
+                    dtype=dtype,
+                    quant_min=int(info_fun(dtype).min),
+                    quant_max=int(info_fun(dtype).max),
                     qscheme=torch.per_tensor_affine,
                     is_dynamic=False,
                     observer_or_fake_quant_ctr=observer.default_observer,
@@ -1196,10 +1201,10 @@
                     observer_or_fake_quant_ctr=observer.default_weight_observer,
                 )
                 quantization_config = QuantizationConfig(
-                    input_activation=int16_qspec,
+                    input_activation=activate_qspec,
                     weight=int8_qspec,
                     bias=None,
-                    output_activation=int16_qspec,
+                    output_activation=activate_qspec,
                 )
                 OP_TO_ANNOTATOR["conv"](model, quantization_config)
 
@@ -1214,7 +1219,7 @@
             def forward(self, x):
                 return self.conv(x)
 
-        quantizer = Int16ActQuantizer()
+        quantizer = DtypeActQuantizer()
         node_occurrence = {
             # one for input of the first conv, one for output for the first conv
             torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
@@ -1230,7 +1235,7 @@
         self._test_quantizer(
             M().eval(),
             example_inputs,
-            Int16ActQuantizer(),
+            quantizer,
             node_occurrence,
             node_list,
         )
@@ -2248,3 +2253,6 @@
             node_occurrence,
             node_list,
         )
+
+
+instantiate_parametrized_tests(TestQuantizePT2E)
diff --git a/torch/ao/quantization/fx/_decomposed.py b/torch/ao/quantization/fx/_decomposed.py
index c54a304..8feafaf 100644
--- a/torch/ao/quantization/fx/_decomposed.py
+++ b/torch/ao/quantization/fx/_decomposed.py
@@ -10,12 +10,11 @@
 # name is not too long
 quantized_decomposed_lib = Library("quantized_decomposed", "DEF")
 
-_DTYPE_TO_QVALUE_BOUNDS = {
-    torch.uint8: (0, 255),
-    torch.int8: (-128, 127),
-    torch.int16: (-(2**15), 2**15 - 1),
-    torch.int32: (-(2**31), 2**31 - 1),
-}
+_INTEGER_DTYPES = [torch.uint8, torch.int8, torch.int16, torch.int32]
+_FLOAT_DTYPES = [torch.float8_e5m2, torch.float8_e4m3fn]
+
+_DTYPE_TO_QVALUE_BOUNDS = {k : (torch.iinfo(k).min, torch.iinfo(k).max) for k in _INTEGER_DTYPES}
+_DTYPE_TO_QVALUE_BOUNDS.update({k : (int(torch.finfo(k).min), int(torch.finfo(k).max)) for k in _FLOAT_DTYPES})
 
 # Helper to check the passed in quant min and max are valid for the dtype
 def _quant_min_max_bounds_check(quant_min, quant_max, dtype):
diff --git a/torch/ao/quantization/fx/convert.py b/torch/ao/quantization/fx/convert.py
index 023abff..ef90f8b 100644
--- a/torch/ao/quantization/fx/convert.py
+++ b/torch/ao/quantization/fx/convert.py
@@ -84,6 +84,18 @@
     "convert_weighted_module",
 ]
 
+SUPPORTED_QDTYPES = [
+    torch.quint8,
+    torch.qint8,
+    torch.qint32,
+    torch.uint8,
+    torch.int8,
+    torch.int16,
+    torch.int32,
+    torch.float8_e5m2,
+    torch.float8_e4m3fn,
+]
+
 _QSCHEME_TO_CHOOSE_QPARAMS_OP = {
     torch.per_tensor_affine: torch.ops.quantized_decomposed.choose_qparams.tensor,
     torch.per_tensor_symmetric: torch.ops.quantized_decomposed.choose_qparams_symmetric.tensor,
@@ -136,8 +148,7 @@
     if hasattr(activation_post_process, "is_dynamic"):
         is_dynamic = activation_post_process.is_dynamic  # type: ignore[assignment]
 
-    if dtype in [torch.quint8, torch.qint8, torch.qint32, torch.uint8, torch.int8, torch.int16, torch.int32] and \
-            (not is_dynamic):
+    if dtype in SUPPORTED_QDTYPES and (not is_dynamic):
         # TODO: probably should cleanup this condition check, it's hard
         # to reason about this if and the following elif
 
@@ -372,7 +383,7 @@
     if hasattr(activation_post_process, "is_dynamic"):
         is_dynamic = activation_post_process.is_dynamic  # type: ignore[attr-defined, assignment]
 
-    if dtype in [torch.quint8, torch.qint8, torch.qint32] and \
+    if dtype in [torch.quint8, torch.qint8, torch.qint32, torch.float8_e5m2, torch.float8_e4m3fn] and \
             (not is_dynamic):
         # TODO: probably should cleanup this condition check, it's hard
         # to reason about this if and the following elif
@@ -477,15 +488,7 @@
         is_dynamic = activation_post_process.is_dynamic  # type: ignore[attr-defined, assignment]
 
     return (
-        (dtype in [
-            torch.quint8,
-            torch.qint8,
-            torch.qint32,
-            torch.uint8,
-            torch.int8,
-            torch.int16,
-            torch.int32
-        ] and (not is_dynamic)) or  # type: ignore[return-value]
+        (dtype in SUPPORTED_QDTYPES and (not is_dynamic)) or  # type: ignore[return-value]
         is_dynamic or
         dtype == torch.float16
     )
diff --git a/torch/ao/quantization/fx/prepare.py b/torch/ao/quantization/fx/prepare.py
index 6a4ae0b..9ca91ec 100644
--- a/torch/ao/quantization/fx/prepare.py
+++ b/torch/ao/quantization/fx/prepare.py
@@ -138,7 +138,9 @@
     torch.uint8,
     torch.int8,
     torch.int16,
-    torch.int32
+    torch.int32,
+    torch.float8_e5m2,
+    torch.float8_e4m3fn,
 ]
 
 _DEFAULT_FP32_OBS_OR_FQ_CTR = PlaceholderObserver.with_args(dtype=torch.float)
diff --git a/torch/ao/quantization/observer.py b/torch/ao/quantization/observer.py
index 718dc7d..5f075df 100644
--- a/torch/ao/quantization/observer.py
+++ b/torch/ao/quantization/observer.py
@@ -244,6 +244,8 @@
             torch.uint8,
             torch.int16,
             torch.int32,
+            torch.float8_e5m2,
+            torch.float8_e4m3fn,
         )
 
         assert self.dtype in _ALLOWED_DTYPES, f"Default Observer only works for {_ALLOWED_DTYPES} data type"
diff --git a/torch/ao/quantization/utils.py b/torch/ao/quantization/utils.py
index 2a225d1..70b45b9 100644
--- a/torch/ao/quantization/utils.py
+++ b/torch/ao/quantization/utils.py
@@ -151,6 +151,8 @@
         torch.int8: torch.int8,
         torch.int16: torch.int16,
         torch.int32: torch.int32,
+        torch.float8_e5m2: torch.float8_e5m2,
+        torch.float8_e4m3fn: torch.float8_e4m3fn,
     }
     assert qdtype in DTYPE_MAPPING, "Unsupported dtype: " + str(qdtype)
     return DTYPE_MAPPING[qdtype]
@@ -231,7 +233,9 @@
             torch.uint8,
             torch.int8,
             torch.int16,
-            torch.int32
+            torch.int32,
+            torch.float8_e5m2,
+            torch.float8_e4m3fn,
         ]
         and (not activation_is_dynamically_quantized(qconfig))
     )
@@ -269,7 +273,9 @@
         torch.uint8,
         torch.int8,
         torch.int16,
-        torch.int32
+        torch.int32,
+        torch.float8_e5m2,
+        torch.float8_e4m3fn,
     ]
 
 def weight_is_statically_quantized(qconfig):
@@ -305,7 +311,18 @@
     assert qconfig is not None
     activation = qconfig.activation()
     weight = qconfig.weight()
-    static_dtypes = [torch.quint8, torch.qint8, torch.quint4x2, torch.qint32, torch.uint8, torch.int8, torch.int16, torch.int32]
+    static_dtypes = [
+        torch.quint8,
+        torch.qint8,
+        torch.quint4x2,
+        torch.qint32,
+        torch.uint8,
+        torch.int8,
+        torch.int16,
+        torch.int32,
+        torch.float8_e5m2,
+        torch.float8_e4m3fn
+    ]
     if weight.dtype in static_dtypes:
         if hasattr(activation, 'is_dynamic') and activation.is_dynamic:
             return QuantType.DYNAMIC