[ao] qconfig_mapping_utils.py fixing public v private (#87517)

Summary: made _get_object_type_qconfig, _get_module_name_regex_qconfig,
_get_module_name_qconfig, _maybe_adjust_qconfig_for_module_type_or_name,
_get_flattened_qconfig_dict _update_qconfig_for_qat private

Test Plan: python test/test_public_bindings.py

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D40709279](https://our.internmc.facebook.com/intern/diff/D40709279)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87517
Approved by: https://github.com/jcaip
diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py
index 236a558..8c75658 100644
--- a/test/quantization/fx/test_quantize_fx.py
+++ b/test/quantization/fx/test_quantize_fx.py
@@ -99,9 +99,9 @@
 )
 
 from torch.ao.quantization.qconfig_mapping_utils import (
-    get_object_type_qconfig,
-    get_module_name_qconfig,
-    get_module_name_regex_qconfig,
+    _get_object_type_qconfig,
+    _get_module_name_qconfig,
+    _get_module_name_regex_qconfig,
 )
 
 from torch.ao.quantization.fx.pattern_utils import (
@@ -1876,9 +1876,9 @@
         qconfig_mapping.set_object_type(torch.nn.Linear, qconfig3)
         self.assertEqual(qconfig_mapping.object_type_qconfigs[torch.nn.Linear], qconfig3)
         self.assertEqual(qconfig_mapping.object_type_qconfigs[torch.nn.ReLU], qconfig2)
-        self.assertEqual(get_object_type_qconfig(qconfig_mapping, torch.nn.Linear, None), qconfig3)
-        self.assertEqual(get_object_type_qconfig(qconfig_mapping, torch.nn.ReLU, None), qconfig2)
-        self.assertEqual(get_object_type_qconfig(qconfig_mapping, "nomatch", None), None)
+        self.assertEqual(_get_object_type_qconfig(qconfig_mapping, torch.nn.Linear, None), qconfig3)
+        self.assertEqual(_get_object_type_qconfig(qconfig_mapping, torch.nn.ReLU, None), qconfig2)
+        self.assertEqual(_get_object_type_qconfig(qconfig_mapping, "nomatch", None), None)
 
     def test_qconfig_mapping_set_module_name_regex(self):
         qconfig1 = get_default_qconfig()
@@ -1898,11 +1898,11 @@
         qconfig_mapping.set_module_name_regex("foo.*bar", qconfig3)
         self.assertEqual(qconfig_mapping.module_name_regex_qconfigs["foo.*bar"], qconfig3)
         self.assertEqual(qconfig_mapping.module_name_regex_qconfigs["foo.*"], qconfig2)
-        self.assertEqual(get_module_name_regex_qconfig(qconfig_mapping, "foo123bar", None), qconfig3)
-        self.assertEqual(get_module_name_regex_qconfig(qconfig_mapping, "foobar", None), qconfig3)
-        self.assertEqual(get_module_name_regex_qconfig(qconfig_mapping, "foobaz", None), qconfig2)
-        self.assertEqual(get_module_name_regex_qconfig(qconfig_mapping, "foo", None), qconfig2)
-        self.assertEqual(get_module_name_regex_qconfig(qconfig_mapping, "nomatch", None), None)
+        self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foo123bar", None), qconfig3)
+        self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foobar", None), qconfig3)
+        self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foobaz", None), qconfig2)
+        self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foo", None), qconfig2)
+        self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "nomatch", None), None)
 
     def test_qconfig_mapping_set_module_name(self):
         qconfig1 = get_default_qconfig()
@@ -1922,9 +1922,9 @@
         qconfig_mapping.set_module_name("mod1", qconfig3)
         self.assertEqual(qconfig_mapping.module_name_qconfigs["mod1"], qconfig3)
         self.assertEqual(qconfig_mapping.module_name_qconfigs["mod2"], qconfig2)
-        self.assertEqual(get_module_name_qconfig(qconfig_mapping, "mod1", None), qconfig3)
-        self.assertEqual(get_module_name_qconfig(qconfig_mapping, "mod2", None), qconfig2)
-        self.assertEqual(get_module_name_qconfig(qconfig_mapping, "nomatch", None), None)
+        self.assertEqual(_get_module_name_qconfig(qconfig_mapping, "mod1", None), qconfig3)
+        self.assertEqual(_get_module_name_qconfig(qconfig_mapping, "mod2", None), qconfig2)
+        self.assertEqual(_get_module_name_qconfig(qconfig_mapping, "nomatch", None), None)
 
     def test_qconfig_mapping_set_module_name_object_type_order(self):
         qconfig1 = get_default_qconfig()
diff --git a/torch/ao/quantization/fx/convert.py b/torch/ao/quantization/fx/convert.py
index 74eb8f1..b5e9cf3 100644
--- a/torch/ao/quantization/fx/convert.py
+++ b/torch/ao/quantization/fx/convert.py
@@ -24,7 +24,7 @@
 )
 from ..qconfig_mapping import QConfigMapping
 from ..qconfig_mapping_utils import (
-    update_qconfig_for_qat,
+    _update_qconfig_for_qat,
 )
 from .qconfig_mapping_utils import (
     generate_node_name_to_qconfig,
@@ -563,7 +563,7 @@
         modules_copy = copy.deepcopy(modules)
 
         if model._is_qat:
-            update_qconfig_for_qat(qconfig_mapping, {})
+            _update_qconfig_for_qat(qconfig_mapping, {})
         update_qconfig_for_fusion(model, qconfig_mapping)
 
         compare_prepare_convert_qconfig_mappings(prepare_qconfig_mapping, qconfig_mapping)  # type: ignore[arg-type]
diff --git a/torch/ao/quantization/fx/prepare.py b/torch/ao/quantization/fx/prepare.py
index 160b80a..281bd96 100644
--- a/torch/ao/quantization/fx/prepare.py
+++ b/torch/ao/quantization/fx/prepare.py
@@ -28,8 +28,8 @@
     QConfigMapping,
 )
 from ..qconfig_mapping_utils import (
-    get_flattened_qconfig_dict,
-    update_qconfig_for_qat,
+    _get_flattened_qconfig_dict,
+    _update_qconfig_for_qat,
 )
 from .qconfig_mapping_utils import (
     generate_node_name_to_qconfig,
@@ -1587,14 +1587,14 @@
 
     update_qconfig_for_fusion(model, qconfig_mapping)
     update_qconfig_for_fusion(model, _equalization_config)
-    flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_mapping)
+    flattened_qconfig_dict = _get_flattened_qconfig_dict(qconfig_mapping)
     # TODO: support regex as well
     propagate_qconfig_(model, flattened_qconfig_dict, prepare_custom_config.to_dict())
 
     if is_qat:
         module_to_qat_module = get_module_to_qat_module(backend_config)
         qat_swap_modules(model, module_to_qat_module)
-        update_qconfig_for_qat(qconfig_mapping, {})
+        _update_qconfig_for_qat(qconfig_mapping, {})
 
     # mapping from fully qualified module name to module instance
     # for example,
diff --git a/torch/ao/quantization/fx/qconfig_mapping_utils.py b/torch/ao/quantization/fx/qconfig_mapping_utils.py
index 2abfaf8..66dffd5 100644
--- a/torch/ao/quantization/fx/qconfig_mapping_utils.py
+++ b/torch/ao/quantization/fx/qconfig_mapping_utils.py
@@ -29,8 +29,8 @@
     QConfigMapping,
 )
 from ..qconfig_mapping_utils import (
-    get_object_type_qconfig,
-    maybe_adjust_qconfig_for_module_type_or_name,
+    _get_object_type_qconfig,
+    _maybe_adjust_qconfig_for_module_type_or_name,
 )
 
 
@@ -121,17 +121,17 @@
         qconfig = None
         if node.op == "get_attr":
             module_name, _ = _parent_name(node.target)
-            qconfig = maybe_adjust_qconfig_for_module_type_or_name(
+            qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
                 qconfig_mapping, type(modules[module_name]), module_name, global_qconfig)
             qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(qconfig, modules.get(node.target, None))
         elif node.op == "call_function":
             # precedence: module_name_qconfig
             # > function_qconfig > global_qconfig
             # module_name takes precedence over function qconfig
-            function_qconfig = get_object_type_qconfig(
+            function_qconfig = _get_object_type_qconfig(
                 qconfig_mapping, node.target, global_qconfig)
             module_path, module_type = node_name_to_scope[node.name]
-            qconfig = maybe_adjust_qconfig_for_module_type_or_name(
+            qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
                 qconfig_mapping, module_type, module_path, function_qconfig)
 
             cur_object_type_idx = \
@@ -146,11 +146,11 @@
             # first use node.target (string) to get the qconfig
             # this is to support configs like
             # "object_type": [("reshpe", qconfig)]
-            qconfig = maybe_adjust_qconfig_for_module_type_or_name(
+            qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
                 qconfig_mapping, node.target, module_path, global_qconfig)
             # if there is no special config for the method, we'll fall back to the
             # config for the module that contains the call_method node
-            qconfig = maybe_adjust_qconfig_for_module_type_or_name(
+            qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
                 qconfig_mapping, module_type, module_path, qconfig)
             # currently call_method does not support modifying qconfig
             # by order, we can add this later if it is needed.
@@ -160,7 +160,7 @@
             # if the node is an observer, just continue - don't add it to the qconfig_map
             if is_activation_post_process(modules[node.target]):
                 continue
-            qconfig = maybe_adjust_qconfig_for_module_type_or_name(
+            qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
                 qconfig_mapping, type(modules[node.target]), node.target, global_qconfig)
 
             module_path, module_type = node_name_to_scope[node.name]
diff --git a/torch/ao/quantization/qconfig_mapping_utils.py b/torch/ao/quantization/qconfig_mapping_utils.py
index 09bce4f..0109729 100644
--- a/torch/ao/quantization/qconfig_mapping_utils.py
+++ b/torch/ao/quantization/qconfig_mapping_utils.py
@@ -1,5 +1,5 @@
 import re
-from typing import Dict, Callable, Union
+from typing import Dict, Callable, Union, List
 
 from .utils import (
     get_combined_dict,
@@ -12,25 +12,18 @@
 from .qconfig_mapping import QConfigMapping
 
 
-# TODO: revisit this list. Many helper methods shouldn't be public
-__all__ = [
-    "get_flattened_qconfig_dict",
-    "get_object_type_qconfig",
-    "get_module_name_qconfig",
-    "get_module_name_regex_qconfig",
-    "maybe_adjust_qconfig_for_module_type_or_name",
-    "update_qconfig_for_qat",
+__all__: List[str] = [
 ]
 
 
-def get_object_type_qconfig(
+def _get_object_type_qconfig(
         qconfig_mapping: QConfigMapping,
         object_type: Union[Callable, str],
         fallback_qconfig: QConfigAny) -> QConfigAny:
     return qconfig_mapping.object_type_qconfigs.get(object_type, fallback_qconfig)
 
 
-def get_module_name_regex_qconfig(qconfig_mapping, module_name, fallback_qconfig):
+def _get_module_name_regex_qconfig(qconfig_mapping, module_name, fallback_qconfig):
     for regex_pattern, qconfig in qconfig_mapping.module_name_regex_qconfigs.items():
         if re.match(regex_pattern, module_name):
             # first match wins
@@ -38,7 +31,7 @@
     return fallback_qconfig
 
 
-def get_module_name_qconfig(qconfig_mapping, module_name, fallback_qconfig):
+def _get_module_name_qconfig(qconfig_mapping, module_name, fallback_qconfig):
     if module_name == '':
         # module name qconfig not found
         return fallback_qconfig
@@ -46,23 +39,23 @@
         return qconfig_mapping.module_name_qconfigs[module_name]
     else:
         parent, _ = _parent_name(module_name)
-        return get_module_name_qconfig(qconfig_mapping, parent, fallback_qconfig)
+        return _get_module_name_qconfig(qconfig_mapping, parent, fallback_qconfig)
 
 
-def maybe_adjust_qconfig_for_module_type_or_name(qconfig_mapping, module_type, module_name, global_qconfig):
+def _maybe_adjust_qconfig_for_module_type_or_name(qconfig_mapping, module_type, module_name, global_qconfig):
     # get qconfig for module_name,
     # fallback to module_name_regex_qconfig, module_type_qconfig,
     # global_qconfig if necessary
-    module_type_qconfig = get_object_type_qconfig(
+    module_type_qconfig = _get_object_type_qconfig(
         qconfig_mapping, module_type, global_qconfig)
-    module_name_regex_qconfig = get_module_name_regex_qconfig(
+    module_name_regex_qconfig = _get_module_name_regex_qconfig(
         qconfig_mapping, module_name, module_type_qconfig)
-    module_name_qconfig = get_module_name_qconfig(
+    module_name_qconfig = _get_module_name_qconfig(
         qconfig_mapping, module_name, module_name_regex_qconfig)
     return module_name_qconfig
 
 
-def get_flattened_qconfig_dict(qconfig_mapping: QConfigMapping) -> Dict[Union[Callable, str], QConfigAny]:
+def _get_flattened_qconfig_dict(qconfig_mapping: QConfigMapping) -> Dict[Union[Callable, str], QConfigAny]:
     """ flatten the global, object_type and module_name qconfig
     to the same qconfig_dict so that it can be used by
     propagate_qconfig_ function.
@@ -94,7 +87,7 @@
     return flattened
 
 
-def update_qconfig_for_qat(
+def _update_qconfig_for_qat(
         qconfig_mapping: QConfigMapping,
         additional_qat_module_mapping: Dict[Callable, Callable]):
     """