[Quant] Remove weight from DTypeConfig for non-weighted ops (#86335)
Summary: Weight dtypes should be specified only for weighted
ops like conv and linear. This commit removes weight dtypes
from the DTypeConfigs used in binary ops and fixed qparams ops.
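
For illustration, a minimal before/after sketch of the distinction (a sketch only,
assuming the DTypeConfig import path below; the config names mirror those used in
native.py):

import torch
from torch.ao.quantization.backend_config import DTypeConfig

# Weighted ops such as conv and linear keep a weight dtype:
weighted_op_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,   # quantized activation input
    output_dtype=torch.quint8,  # quantized activation output
    weight_dtype=torch.qint8,   # quantized weight
    bias_dtype=torch.float,
)

# Non-weighted ops (binary ops, fixed qparams ops) now use a config
# with no weight dtype at all:
default_op_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
)
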
Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps

Reviewers: jerryzh168, vkuzo

Subscribers: jerryzh168, vkuzo

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86335
Approved by: https://github.com/vkuzo

diff --git a/torch/ao/quantization/backend_config/fbgemm.py b/torch/ao/quantization/backend_config/fbgemm.py
index 6e9b525..de38272 100644
--- a/torch/ao/quantization/backend_config/fbgemm.py
+++ b/torch/ao/quantization/backend_config/fbgemm.py
@@ -23,7 +23,7 @@
# these will diverge. In particular, for FBGEMM, we will restrict the activation quantized
# values to within [0, 127].
-fbgemm_weighted_op_int8_dtype_config = DTypeConfig(
+fbgemm_weighted_op_quint8_dtype_config = DTypeConfig(
input_dtype=torch.quint8,
output_dtype=torch.quint8,
weight_dtype=torch.qint8,
@@ -79,15 +79,15 @@
"""
Return the `BackendConfig` for PyTorch's native FBGEMM backend.
"""
- conv_dtype_configs = [fbgemm_weighted_op_int8_dtype_config]
+ conv_dtype_configs = [fbgemm_weighted_op_quint8_dtype_config]
linear_dtype_configs = [
- fbgemm_weighted_op_int8_dtype_config,
+ fbgemm_weighted_op_quint8_dtype_config,
fbgemm_default_dynamic_int8_dtype_config,
fbgemm_default_dynamic_float16_dtype_config,
]
- binary_op_dtype_configs = [fbgemm_weighted_op_int8_dtype_config]
+ binary_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config]
default_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config]
- fixed_qparams_op_dtype_configs = [fbgemm_weighted_op_int8_dtype_config]
+ fixed_qparams_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config]
share_qparams_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config]
rnn_op_dtype_configs = [
fbgemm_default_dynamic_int8_dtype_config,
diff --git a/torch/ao/quantization/backend_config/native.py b/torch/ao/quantization/backend_config/native.py
index 3da807f..ad3671f 100644
--- a/torch/ao/quantization/backend_config/native.py
+++ b/torch/ao/quantization/backend_config/native.py
@@ -21,7 +21,7 @@
# weighted op int8 dtype config
# this is config for ops that has quantized weights, like linear, conv
-weighted_op_int8_dtype_config = DTypeConfig(
+weighted_op_quint8_dtype_config = DTypeConfig(
input_dtype=torch.quint8,
output_dtype=torch.quint8,
weight_dtype=torch.qint8,
@@ -91,20 +91,20 @@
"""
Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack) with various additional fp16 ops.
"""
- conv_dtype_configs = [weighted_op_int8_dtype_config]
+ conv_dtype_configs = [weighted_op_quint8_dtype_config]
linear_dtype_configs = [
- weighted_op_int8_dtype_config,
+ weighted_op_quint8_dtype_config,
default_dynamic_int8_dtype_config,
default_dynamic_float16_dtype_config,
default_op_fp16_dtype_config,
]
binary_op_dtype_configs = [
- weighted_op_int8_dtype_config,
+ default_op_quint8_dtype_config,
default_op_fp16_dtype_config,
]
default_op_dtype_configs = [default_op_quint8_dtype_config]
fixed_qparams_op_dtype_configs = [
- weighted_op_int8_dtype_config,
+ default_op_quint8_dtype_config,
default_op_fp16_dtype_config,
]
share_qparams_op_dtype_configs = [
@@ -138,15 +138,15 @@
Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack).
"""
# TODO: express this BackendConfig as a union of the FBGEMM and QNNPACK BackendConfigs
- conv_dtype_configs = [weighted_op_int8_dtype_config]
+ conv_dtype_configs = [weighted_op_quint8_dtype_config]
linear_dtype_configs = [
- weighted_op_int8_dtype_config,
+ weighted_op_quint8_dtype_config,
default_dynamic_int8_dtype_config,
default_dynamic_float16_dtype_config,
]
- binary_op_dtype_configs = [weighted_op_int8_dtype_config]
+ binary_op_dtype_configs = [default_op_quint8_dtype_config]
default_op_dtype_configs = [default_op_quint8_dtype_config]
- fixed_qparams_op_dtype_configs = [weighted_op_int8_dtype_config]
+ fixed_qparams_op_dtype_configs = [default_op_quint8_dtype_config]
share_qparams_op_dtype_configs = [default_op_quint8_dtype_config]
rnn_op_dtype_configs = [
default_dynamic_int8_dtype_config,
diff --git a/torch/ao/quantization/backend_config/qnnpack.py b/torch/ao/quantization/backend_config/qnnpack.py
index e944670..391acf5 100644
--- a/torch/ao/quantization/backend_config/qnnpack.py
+++ b/torch/ao/quantization/backend_config/qnnpack.py
@@ -121,16 +121,16 @@
qnnpack_default_dynamic_float16_dtype_config,
]
binary_op_dtype_configs = [
- qnnpack_weighted_op_qint8_symmetric_dtype_config,
- qnnpack_weighted_op_quint8_dtype_config,
+ qnnpack_default_op_qint8_symmetric_dtype_config,
+ qnnpack_default_op_quint8_dtype_config,
]
default_op_dtype_configs = [
qnnpack_default_op_qint8_symmetric_dtype_config,
qnnpack_default_op_quint8_dtype_config,
]
fixed_qparams_op_dtype_configs = [
- qnnpack_weighted_op_qint8_symmetric_dtype_config,
- qnnpack_weighted_op_quint8_dtype_config,
+ qnnpack_default_op_qint8_symmetric_dtype_config,
+ qnnpack_default_op_quint8_dtype_config,
]
share_qparams_op_dtype_configs = [
qnnpack_default_op_qint8_symmetric_dtype_config,
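
Beyond the test plan above, a minimal usage sketch (not part of this change;
the module, shapes, and qconfig mapping are illustrative) of how the updated
backend config is consumed by FX graph mode quantization for a binary op:

import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.backend_config import get_fbgemm_backend_config
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 3, 1)
        self.conv2 = torch.nn.Conv2d(3, 3, 1)

    def forward(self, x, y):
        # add is a binary op: its dtype config no longer carries a weight dtype
        return self.conv1(x) + self.conv2(y)

model = M().eval()
example_inputs = (torch.randn(1, 3, 4, 4), torch.randn(1, 3, 4, 4))
backend_config = get_fbgemm_backend_config()
prepared = prepare_fx(
    model,
    get_default_qconfig_mapping("fbgemm"),
    example_inputs,
    backend_config=backend_config,
)
prepared(*example_inputs)  # calibrate observers
quantized = convert_fx(prepared, backend_config=backend_config)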