[quant][be] Easier way to override default in QConfigMapping (#99888)
Summary: This commit adds a private helper function to override
the default QConfig in the default QConfigMapping. Previously we
needed to override all the object_types manually while skipping
the fixed qparams ops. This led to duplicate code every time
someone wanted a new default QConfig. After this commit, we can
just call the same helper function instead.
Test Plan:
python test/test_quantization.py TestQuantizeFx
Reviewers: jerryzh168, vkuzo
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99888
Approved by: https://github.com/vkuzo, https://github.com/jerryzh168
diff --git a/torch/ao/quantization/qconfig_mapping.py b/torch/ao/quantization/qconfig_mapping.py
index 76bdd2c..0e50ed9 100644
--- a/torch/ao/quantization/qconfig_mapping.py
+++ b/torch/ao/quantization/qconfig_mapping.py
@@ -135,31 +135,40 @@
"""
return _get_default_qconfig_mapping(True, backend, version)
-def _get_symmetric_qnnpack_qconfig_mapping():
+def _get_symmetric_qnnpack_qconfig_mapping() -> QConfigMapping:
"""
Return a QConfigMapping that uses `torch.ao.quantization.default_symmetric_qnnpack_qconfig`
as the default QConfig.
"""
- qconfig_mapping = get_default_qconfig_mapping("qnnpack") \
- .set_global(default_symmetric_qnnpack_qconfig)
- for pattern in qconfig_mapping.object_type_qconfigs.keys():
- if pattern not in _FIXED_QPARAMS_OP_TO_OBSERVER:
- qconfig_mapping.set_object_type(pattern, default_symmetric_qnnpack_qconfig)
- return qconfig_mapping
+ default_qconfig = default_symmetric_qnnpack_qconfig
+ return _get_default_qconfig_mapping_with_default_qconfig(False, "qnnpack", default_qconfig)
-def _get_symmetric_qnnpack_qat_qconfig_mapping():
+def _get_symmetric_qnnpack_qat_qconfig_mapping() -> QConfigMapping:
"""
Return a QConfigMapping that uses `torch.ao.quantization.default_symmetric_qnnpack_qat_qconfig`
as the default QConfig.
"""
- qconfig_mapping = get_default_qconfig_mapping("qnnpack") \
- .set_global(default_symmetric_qnnpack_qat_qconfig)
+ default_qconfig = default_symmetric_qnnpack_qat_qconfig
+ return _get_default_qconfig_mapping_with_default_qconfig(True, "qnnpack", default_qconfig)
+
+def _get_default_qconfig_mapping_with_default_qconfig(
+ is_qat: bool,
+ backend: str,
+ default_qconfig: QConfig,
+) -> QConfigMapping:
+ """
+ Return a QConfigMapping that uses the provided qconfig as the default QConfig.
+ """
+ if is_qat:
+ qconfig_mapping = get_default_qat_qconfig_mapping(backend)
+ else:
+ qconfig_mapping = get_default_qconfig_mapping(backend)
+ qconfig_mapping.set_global(default_qconfig)
for pattern in qconfig_mapping.object_type_qconfigs.keys():
if pattern not in _FIXED_QPARAMS_OP_TO_OBSERVER:
- qconfig_mapping.set_object_type(pattern, default_symmetric_qnnpack_qat_qconfig)
+ qconfig_mapping.set_object_type(pattern, default_qconfig)
return qconfig_mapping
-
_QCONFIG_STYLE_ORDER: List[str] = [
"global_qconfig",
"object_type_qconfigs",