[Quant][fx] Add get_default_qconfig_mapping

Summary: This follows https://github.com/pytorch/pytorch/pull/78452,
which replaced the qconfig_dict argument with QConfigMapping in the
FX graph mode quantization APIs. This PR additionally replaces the
get_default_*qconfig_dict functions with get_default_*qconfig_mapping
equivalents. For backward compatibility, the old functions are
deprecated rather than removed: they now warn and delegate to the new
functions, returning the mapping converted back to a dict via
to_dict().
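
For illustration, a minimal sketch of the intended migration (the toy
model and example inputs here are placeholders, not part of this PR):

    import torch
    from torch.ao.quantization import get_default_qconfig_mapping
    from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()).eval()
    example_inputs = (torch.randn(1, 4),)

    # Old (now deprecated, returns a raw dict):
    #   qconfig_dict = torch.ao.quantization.get_default_qconfig_dict("fbgemm")
    # New (returns a QConfigMapping):
    qconfig_mapping = get_default_qconfig_mapping("fbgemm")

    prepared = prepare_fx(model, qconfig_mapping, example_inputs=example_inputs)
    prepared(*example_inputs)  # calibrate
    quantized = convert_fx(prepared)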

Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps

Reviewers: jerryzh168, vkuzo

Subscribers: jerryzh168, vkuzo, supriyar

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79618

Approved by: https://github.com/jerryzh168
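
Usage note: the returned QConfigMapping keeps the chainable setter API
from #78452 (the setters return self, as the builder chain in
qconfig_mapping.py below relies on), so callers can start from the
defaults and override individual entries. A hedged sketch; the Sigmoid
override is an arbitrary example, not something this PR configures:

    import torch
    from torch.ao.quantization import get_default_qat_qconfig_mapping

    # Start from the QAT defaults and opt one op type out of quantization.
    qconfig_mapping = get_default_qat_qconfig_mapping("fbgemm") \
        .set_object_type(torch.nn.Sigmoid, None)  # None = skip quantizing Sigmoid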
diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py
index ad3b614..f2d09e7 100644
--- a/test/quantization/fx/test_quantize_fx.py
+++ b/test/quantization/fx/test_quantize_fx.py
@@ -50,8 +50,8 @@
     float_qparams_weight_only_qconfig_4bit,
     get_default_qconfig,
     get_default_qat_qconfig,
-    get_default_qconfig_dict,
-    get_default_qat_qconfig_dict,
+    get_default_qconfig_mapping,
+    get_default_qat_qconfig_mapping,
     fuse_modules,
     fuse_modules_qat,
     prepare,
@@ -1854,7 +1854,7 @@
         for model in [LinearReLUModel, ConvReLUModel, ConvBnReLUModel]:
             for relu in [torch.nn.ReLU(), torch.nn.functional.relu, torch.relu]:
                 m = model(relu).eval()
-                qconfig_dict = torch.ao.quantization.get_default_qconfig_dict("fbgemm")
+                qconfig_dict = torch.ao.quantization.get_default_qconfig_mapping("fbgemm")
                 # should not crash as in https://github.com/pytorch/pytorch/issues/75825
                 prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 3, 3),))
 
@@ -4388,7 +4388,7 @@
         for M, is_qat in options:
             m = M1().eval()
             example_inputs = (torch.randn(1, 3, 3, 3),)
-            m = prepare_fx(m, get_default_qconfig_dict(), example_inputs=example_inputs)
+            m = prepare_fx(m, get_default_qconfig_mapping(), example_inputs=example_inputs)
             m = convert_fx(m)
             node_list = [
                 ns.call_function(torch.quantize_per_tensor),
@@ -4401,7 +4401,7 @@
                 expected_node_list=node_list)
 
             m = M2().eval()
-            m = prepare_fx(m, get_default_qconfig_dict(), example_inputs=example_inputs)
+            m = prepare_fx(m, get_default_qconfig_mapping(), example_inputs=example_inputs)
             m = convert_fx(m)
             node_occurrence = {
                 ns.call_function(torch.quantize_per_tensor): 0,
@@ -4426,7 +4426,7 @@
                 return x
 
         m = M().eval()
-        mp = prepare_fx(m, get_default_qconfig_dict(), example_inputs=(torch.randn(1, 1),))
+        mp = prepare_fx(m, get_default_qconfig_mapping(), example_inputs=(torch.randn(1, 1),))
 
         found_stack_trace = False
         for n in mp.graph.nodes:
@@ -4541,7 +4541,7 @@
                 return x
 
         backends = ["qnnpack", "fbgemm"]
-        for func in [get_default_qconfig_dict, get_default_qat_qconfig_dict]:
+        for func in [get_default_qconfig_mapping, get_default_qat_qconfig_mapping]:
             for backend in backends:
                 m = M().eval()
                 qconfig_dict = func(backend)
@@ -4581,8 +4581,8 @@
             prepare_fn(m2, qconfig_dict, example_inputs=example_inputs)
 
         # Ensure prepare_fx and prepare_qat_fx work in both training and eval modes
-        _test(prepare_fx, get_default_qconfig_dict())
-        _test(prepare_qat_fx, get_default_qat_qconfig_dict())
+        _test(prepare_fx, get_default_qconfig_mapping())
+        _test(prepare_qat_fx, get_default_qat_qconfig_mapping())
 
 @skipIfNoFBGEMM
 class TestQuantizeFxOps(QuantizationTestCase):
diff --git a/torch/ao/quantization/qconfig.py b/torch/ao/quantization/qconfig.py
index c093d71..363b01b 100644
--- a/torch/ao/quantization/qconfig.py
+++ b/torch/ao/quantization/qconfig.py
@@ -335,52 +335,17 @@
                                                        eps=2 ** -12),
     weight=fused_per_channel_wt_fake_quant_range_neg_127_to_127)
 
-def _get_default_qconfig_dict_helper(qconfig, qconfig_transpose):
-    return {
-        "": qconfig,
-        "object_type": [("reshape", default_reuse_input_qconfig),
-                        (torch.nn.Conv1d, qconfig),
-                        (torch.nn.Conv2d, qconfig),
-                        (torch.nn.Conv3d, qconfig),
-                        (torch.nn.ConvTranspose1d, qconfig_transpose),
-                        (torch.nn.ConvTranspose2d, qconfig_transpose),
-                        (torch.nn.ConvTranspose3d, qconfig_transpose),
-                        (torch.nn.Linear, qconfig),
-                        (torch.nn.functional.conv1d, qconfig),
-                        (torch.nn.functional.conv2d, qconfig),
-                        (torch.nn.functional.conv3d, qconfig),
-                        (torch.nn.functional.conv_transpose1d, qconfig_transpose),
-                        (torch.nn.functional.conv_transpose2d, qconfig_transpose),
-                        (torch.nn.functional.conv_transpose3d, qconfig_transpose),
-                        (torch.nn.functional.linear, qconfig),
-                        (torch.nn.ReLU, qconfig),
-                        (torch.nn.functional.relu, qconfig),
-                        (torch.relu, qconfig),
-                        (torch.nn.BatchNorm1d, qconfig),
-                        (torch.nn.BatchNorm2d, qconfig),
-                        (torch.nn.BatchNorm3d, qconfig)]}
-
 def get_default_qconfig_dict(backend='fbgemm', version=0):
-    qconfig = get_default_qconfig(backend, version)
-    qconfig_transpose = qconfig
-    # default_per_channel_weight_observer is not currently compatible with fbgemm backend
-    # so we have to modify the weight observer to default_weight_observer or another
-    # per tensor supported observer.
-    # see https://github.com/pytorch/pytorch/issues/47535
-    if backend == "fbgemm":
-        qconfig_transpose = QConfig(activation=qconfig.activation, weight=default_weight_observer)
-    return _get_default_qconfig_dict_helper(qconfig, qconfig_transpose)
+    warnings.warn(
+        "torch.ao.quantization.get_default_qconfig_dict is deprecated and will be removed in "
+        "a future version. Please use torch.ao.quantization.get_default_qconfig_mapping instead.")
+    return torch.ao.quantization.get_default_qconfig_mapping(backend, version).to_dict()
 
 def get_default_qat_qconfig_dict(backend='fbgemm', version=1):
-    qconfig = get_default_qat_qconfig(backend, version)
-    qconfig_transpose = qconfig
-    # default_per_channel_weight_observer is not currently compatible with fbgemm backend
-    # so we have to modify the weight observer to default_weight_observer or another
-    # per tensor supported observer
-    # see https://github.com/pytorch/pytorch/issues/47535
-    if backend == "fbgemm":
-        qconfig_transpose = QConfig(activation=qconfig.activation, weight=default_weight_fake_quant)
-    return _get_default_qconfig_dict_helper(qconfig, qconfig_transpose)
+    warnings.warn(
+        "torch.ao.quantization.get_default_qat_qconfig_dict is deprecated and will be removed in "
+        "a future version. Please use torch.ao.quantization.get_default_qat_qconfig_mapping instead.")
+    return torch.ao.quantization.get_default_qat_qconfig_mapping(backend, version).to_dict()
 
 def assert_valid_qconfig(qconfig: Optional[QConfig],
                          mod: torch.nn.Module) -> None:
diff --git a/torch/ao/quantization/qconfig_mapping.py b/torch/ao/quantization/qconfig_mapping.py
index 3d9351c..7caa6f9 100644
--- a/torch/ao/quantization/qconfig_mapping.py
+++ b/torch/ao/quantization/qconfig_mapping.py
@@ -2,10 +2,22 @@
 from collections import OrderedDict
 from typing import Any, Callable, Dict, Tuple, Union
 
-from .qconfig import QConfigAny
+import torch
+
+from .fake_quantize import default_weight_fake_quant
+from .observer import default_weight_observer
+from .qconfig import (
+    default_reuse_input_qconfig,
+    get_default_qconfig,
+    get_default_qat_qconfig,
+    QConfig,
+    QConfigAny
+)
 
 
 __all__ = [
+    "get_default_qconfig_mapping",
+    "get_default_qat_qconfig_mapping",
     "QConfigMapping",
 ]
 
@@ -18,6 +30,62 @@
 MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY = "module_name_object_type_order"
 
 
+def _get_default_qconfig_mapping(is_qat: bool, backend: str, version: int):
+    """
+    Return the default QConfigMapping for the given quantization type and backend.
+    """
+    if is_qat:
+        qconfig = get_default_qat_qconfig(backend, version)
+    else:
+        qconfig = get_default_qconfig(backend, version)
+
+    # default_per_channel_weight_observer is not currently compatible with fbgemm backend
+    # so we have to modify the weight observer to default_weight_observer or another
+    # per tensor supported observer.
+    # see https://github.com/pytorch/pytorch/issues/47535
+    if backend == "fbgemm":
+        default_weight = default_weight_fake_quant if is_qat else default_weight_observer
+        qconfig_transpose = QConfig(activation=qconfig.activation, weight=default_weight)
+    else:
+        qconfig_transpose = qconfig
+
+    return QConfigMapping() \
+        .set_global(qconfig) \
+        .set_object_type("reshape", default_reuse_input_qconfig) \
+        .set_object_type(torch.nn.Conv1d, qconfig) \
+        .set_object_type(torch.nn.Conv2d, qconfig) \
+        .set_object_type(torch.nn.Conv3d, qconfig) \
+        .set_object_type(torch.nn.ConvTranspose1d, qconfig_transpose) \
+        .set_object_type(torch.nn.ConvTranspose2d, qconfig_transpose) \
+        .set_object_type(torch.nn.ConvTranspose3d, qconfig_transpose) \
+        .set_object_type(torch.nn.Linear, qconfig) \
+        .set_object_type(torch.nn.functional.conv1d, qconfig) \
+        .set_object_type(torch.nn.functional.conv2d, qconfig) \
+        .set_object_type(torch.nn.functional.conv3d, qconfig) \
+        .set_object_type(torch.nn.functional.conv_transpose1d, qconfig_transpose) \
+        .set_object_type(torch.nn.functional.conv_transpose2d, qconfig_transpose) \
+        .set_object_type(torch.nn.functional.conv_transpose3d, qconfig_transpose) \
+        .set_object_type(torch.nn.functional.linear, qconfig) \
+        .set_object_type(torch.nn.ReLU, qconfig) \
+        .set_object_type(torch.nn.functional.relu, qconfig) \
+        .set_object_type(torch.relu, qconfig) \
+        .set_object_type(torch.nn.BatchNorm1d, qconfig) \
+        .set_object_type(torch.nn.BatchNorm2d, qconfig) \
+        .set_object_type(torch.nn.BatchNorm3d, qconfig)
+
+def get_default_qconfig_mapping(backend="fbgemm", version=0):
+    """
+    Return the default QConfigMapping for post training quantization.
+    """
+    return _get_default_qconfig_mapping(False, backend, version)
+
+def get_default_qat_qconfig_mapping(backend="fbgemm", version=1):
+    """
+    Return the default QConfigMapping for quantization aware training.
+    """
+    return _get_default_qconfig_mapping(True, backend, version)
+
+
 class QConfigMapping:
     """
     Mapping from model ops to :class:`torch.ao.quantization.QConfig`s.