[Quant][fx] Add get_default_qconfig_mapping
Summary: This follows https://github.com/pytorch/pytorch/pull/78452,
which replaced the qconfig_dict with QConfigMapping. This PR
additionally replaces get_default_*qconfig_dict with
get_default_*qconfig_mapping. For backward compatibility, we
deprecate the old functions instead of removing them.
Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps
Reviewers: jerryzh168, vkuzo
Subscribers: jerryzh168, vkuzo, supriyar
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79618
Approved by: https://github.com/jerryzh168
diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py
index ad3b614..f2d09e7 100644
--- a/test/quantization/fx/test_quantize_fx.py
+++ b/test/quantization/fx/test_quantize_fx.py
@@ -50,8 +50,8 @@
float_qparams_weight_only_qconfig_4bit,
get_default_qconfig,
get_default_qat_qconfig,
- get_default_qconfig_dict,
- get_default_qat_qconfig_dict,
+ get_default_qconfig_mapping,
+ get_default_qat_qconfig_mapping,
fuse_modules,
fuse_modules_qat,
prepare,
@@ -1854,7 +1854,7 @@
for model in [LinearReLUModel, ConvReLUModel, ConvBnReLUModel]:
for relu in [torch.nn.ReLU(), torch.nn.functional.relu, torch.relu]:
m = model(relu).eval()
- qconfig_dict = torch.ao.quantization.get_default_qconfig_dict("fbgemm")
+ qconfig_dict = torch.ao.quantization.get_default_qconfig_mapping("fbgemm")
# should not crash as in https://github.com/pytorch/pytorch/issues/75825
prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 3, 3),))
@@ -4388,7 +4388,7 @@
for M, is_qat in options:
m = M1().eval()
example_inputs = (torch.randn(1, 3, 3, 3),)
- m = prepare_fx(m, get_default_qconfig_dict(), example_inputs=example_inputs)
+ m = prepare_fx(m, get_default_qconfig_mapping(), example_inputs=example_inputs)
m = convert_fx(m)
node_list = [
ns.call_function(torch.quantize_per_tensor),
@@ -4401,7 +4401,7 @@
expected_node_list=node_list)
m = M2().eval()
- m = prepare_fx(m, get_default_qconfig_dict(), example_inputs=example_inputs)
+ m = prepare_fx(m, get_default_qconfig_mapping(), example_inputs=example_inputs)
m = convert_fx(m)
node_occurrence = {
ns.call_function(torch.quantize_per_tensor): 0,
@@ -4426,7 +4426,7 @@
return x
m = M().eval()
- mp = prepare_fx(m, get_default_qconfig_dict(), example_inputs=(torch.randn(1, 1),))
+ mp = prepare_fx(m, get_default_qconfig_mapping(), example_inputs=(torch.randn(1, 1),))
found_stack_trace = False
for n in mp.graph.nodes:
@@ -4541,7 +4541,7 @@
return x
backends = ["qnnpack", "fbgemm"]
- for func in [get_default_qconfig_dict, get_default_qat_qconfig_dict]:
+ for func in [get_default_qconfig_mapping, get_default_qat_qconfig_mapping]:
for backend in backends:
m = M().eval()
qconfig_dict = func(backend)
@@ -4581,8 +4581,8 @@
prepare_fn(m2, qconfig_dict, example_inputs=example_inputs)
# Ensure prepare_fx and prepare_qat_fx work in both training and eval modes
- _test(prepare_fx, get_default_qconfig_dict())
- _test(prepare_qat_fx, get_default_qat_qconfig_dict())
+ _test(prepare_fx, get_default_qconfig_mapping())
+ _test(prepare_qat_fx, get_default_qat_qconfig_mapping())
@skipIfNoFBGEMM
class TestQuantizeFxOps(QuantizationTestCase):
diff --git a/torch/ao/quantization/qconfig.py b/torch/ao/quantization/qconfig.py
index c093d71..363b01b 100644
--- a/torch/ao/quantization/qconfig.py
+++ b/torch/ao/quantization/qconfig.py
@@ -335,52 +335,17 @@
eps=2 ** -12),
weight=fused_per_channel_wt_fake_quant_range_neg_127_to_127)
-def _get_default_qconfig_dict_helper(qconfig, qconfig_transpose):
- return {
- "": qconfig,
- "object_type": [("reshape", default_reuse_input_qconfig),
- (torch.nn.Conv1d, qconfig),
- (torch.nn.Conv2d, qconfig),
- (torch.nn.Conv3d, qconfig),
- (torch.nn.ConvTranspose1d, qconfig_transpose),
- (torch.nn.ConvTranspose2d, qconfig_transpose),
- (torch.nn.ConvTranspose3d, qconfig_transpose),
- (torch.nn.Linear, qconfig),
- (torch.nn.functional.conv1d, qconfig),
- (torch.nn.functional.conv2d, qconfig),
- (torch.nn.functional.conv3d, qconfig),
- (torch.nn.functional.conv_transpose1d, qconfig_transpose),
- (torch.nn.functional.conv_transpose2d, qconfig_transpose),
- (torch.nn.functional.conv_transpose3d, qconfig_transpose),
- (torch.nn.functional.linear, qconfig),
- (torch.nn.ReLU, qconfig),
- (torch.nn.functional.relu, qconfig),
- (torch.relu, qconfig),
- (torch.nn.BatchNorm1d, qconfig),
- (torch.nn.BatchNorm2d, qconfig),
- (torch.nn.BatchNorm3d, qconfig)]}
-
def get_default_qconfig_dict(backend='fbgemm', version=0):
- qconfig = get_default_qconfig(backend, version)
- qconfig_transpose = qconfig
- # default_per_channel_weight_observer is not currently compatible with fbgemm backend
- # so we have to modify the weight observer to default_weight_observer or another
- # per tensor supported observer.
- # see https://github.com/pytorch/pytorch/issues/47535
- if backend == "fbgemm":
- qconfig_transpose = QConfig(activation=qconfig.activation, weight=default_weight_observer)
- return _get_default_qconfig_dict_helper(qconfig, qconfig_transpose)
+ warnings.warn(
+ "torch.ao.quantization.get_default_qconfig_dict is deprecated and will be removed in "
+ "a future version. Please use torch.ao.quantization.get_default_qconfig_mapping instead.")
+ return torch.ao.quantization.get_default_qconfig_mapping(backend, version).to_dict()
def get_default_qat_qconfig_dict(backend='fbgemm', version=1):
- qconfig = get_default_qat_qconfig(backend, version)
- qconfig_transpose = qconfig
- # default_per_channel_weight_observer is not currently compatible with fbgemm backend
- # so we have to modify the weight observer to default_weight_observer or another
- # per tensor supported observer
- # see https://github.com/pytorch/pytorch/issues/47535
- if backend == "fbgemm":
- qconfig_transpose = QConfig(activation=qconfig.activation, weight=default_weight_fake_quant)
- return _get_default_qconfig_dict_helper(qconfig, qconfig_transpose)
+ warnings.warn(
+ "torch.ao.quantization.get_default_qat_qconfig_dict is deprecated and will be removed in "
+ "a future version. Please use torch.ao.quantization.get_default_qat_qconfig_mapping instead.")
+ return torch.ao.quantization.get_default_qat_qconfig_mapping(backend, version).to_dict()
def assert_valid_qconfig(qconfig: Optional[QConfig],
mod: torch.nn.Module) -> None:
diff --git a/torch/ao/quantization/qconfig_mapping.py b/torch/ao/quantization/qconfig_mapping.py
index 3d9351c..7caa6f9 100644
--- a/torch/ao/quantization/qconfig_mapping.py
+++ b/torch/ao/quantization/qconfig_mapping.py
@@ -2,10 +2,22 @@
from collections import OrderedDict
from typing import Any, Callable, Dict, Tuple, Union
-from .qconfig import QConfigAny
+import torch
+
+from .fake_quantize import default_weight_fake_quant
+from .observer import default_weight_observer
+from .qconfig import (
+ default_reuse_input_qconfig,
+ get_default_qconfig,
+ get_default_qat_qconfig,
+ QConfig,
+ QConfigAny
+)
__all__ = [
+ "get_default_qconfig_mapping",
+ "get_default_qat_qconfig_mapping",
"QConfigMapping",
]
@@ -18,6 +30,62 @@
MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY = "module_name_object_type_order"
+def _get_default_qconfig_mapping(is_qat: bool, backend: str, version: int):
+ """
+ Return the default QConfigMapping for the given quantization type and backend.
+ """
+ if is_qat:
+ qconfig = get_default_qat_qconfig(backend, version)
+ else:
+ qconfig = get_default_qconfig(backend, version)
+
+ # default_per_channel_weight_observer is not currently compatible with fbgemm backend
+ # so we have to modify the weight observer to default_weight_observer or another
+ # per tensor supported observer.
+ # see https://github.com/pytorch/pytorch/issues/47535
+ if backend == "fbgemm":
+ default_weight = default_weight_fake_quant if is_qat else default_weight_observer
+ qconfig_transpose = QConfig(activation=qconfig.activation, weight=default_weight)
+ else:
+ qconfig_transpose = qconfig
+
+ return QConfigMapping() \
+ .set_global(qconfig) \
+ .set_object_type("reshape", default_reuse_input_qconfig) \
+ .set_object_type(torch.nn.Conv1d, qconfig) \
+ .set_object_type(torch.nn.Conv2d, qconfig) \
+ .set_object_type(torch.nn.Conv3d, qconfig) \
+ .set_object_type(torch.nn.ConvTranspose1d, qconfig_transpose) \
+ .set_object_type(torch.nn.ConvTranspose2d, qconfig_transpose) \
+ .set_object_type(torch.nn.ConvTranspose3d, qconfig_transpose) \
+ .set_object_type(torch.nn.Linear, qconfig) \
+ .set_object_type(torch.nn.functional.conv1d, qconfig) \
+ .set_object_type(torch.nn.functional.conv2d, qconfig) \
+ .set_object_type(torch.nn.functional.conv3d, qconfig) \
+ .set_object_type(torch.nn.functional.conv_transpose1d, qconfig_transpose) \
+ .set_object_type(torch.nn.functional.conv_transpose2d, qconfig_transpose) \
+ .set_object_type(torch.nn.functional.conv_transpose3d, qconfig_transpose) \
+ .set_object_type(torch.nn.functional.linear, qconfig) \
+ .set_object_type(torch.nn.ReLU, qconfig) \
+ .set_object_type(torch.nn.functional.relu, qconfig) \
+ .set_object_type(torch.relu, qconfig) \
+ .set_object_type(torch.nn.BatchNorm1d, qconfig) \
+ .set_object_type(torch.nn.BatchNorm2d, qconfig) \
+ .set_object_type(torch.nn.BatchNorm3d, qconfig)
+
+def get_default_qconfig_mapping(backend="fbgemm", version=0):
+ """
+ Return the default QConfigMapping for post training quantization.
+ """
+ return _get_default_qconfig_mapping(False, backend, version)
+
+def get_default_qat_qconfig_mapping(backend="fbgemm", version=1):
+ """
+ Return the default QConfigMapping for quantization aware training.
+ """
+ return _get_default_qconfig_mapping(True, backend, version)
+
+
class QConfigMapping:
"""
Mapping from model ops to :class:`torch.ao.quantization.QConfig`s.