[ao] qconfig_mapping_utils.py fixing public v private (#87517)
Summary: made _get_object_type_qconfig, _get_module_name_regex_qconfig,
_get_module_name_qconfig, _maybe_adjust_qconfig_for_module_type_or_name,
_get_flattened_qconfig_dict _update_qconfig_for_qat private
Test Plan: python test/test_public_bindings.py
Reviewers:
Subscribers:
Tasks:
Tags:
Differential Revision: [D40709279](https://our.internmc.facebook.com/intern/diff/D40709279)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87517
Approved by: https://github.com/jcaip
diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py
index 236a558..8c75658 100644
--- a/test/quantization/fx/test_quantize_fx.py
+++ b/test/quantization/fx/test_quantize_fx.py
@@ -99,9 +99,9 @@
)
from torch.ao.quantization.qconfig_mapping_utils import (
- get_object_type_qconfig,
- get_module_name_qconfig,
- get_module_name_regex_qconfig,
+ _get_object_type_qconfig,
+ _get_module_name_qconfig,
+ _get_module_name_regex_qconfig,
)
from torch.ao.quantization.fx.pattern_utils import (
@@ -1876,9 +1876,9 @@
qconfig_mapping.set_object_type(torch.nn.Linear, qconfig3)
self.assertEqual(qconfig_mapping.object_type_qconfigs[torch.nn.Linear], qconfig3)
self.assertEqual(qconfig_mapping.object_type_qconfigs[torch.nn.ReLU], qconfig2)
- self.assertEqual(get_object_type_qconfig(qconfig_mapping, torch.nn.Linear, None), qconfig3)
- self.assertEqual(get_object_type_qconfig(qconfig_mapping, torch.nn.ReLU, None), qconfig2)
- self.assertEqual(get_object_type_qconfig(qconfig_mapping, "nomatch", None), None)
+ self.assertEqual(_get_object_type_qconfig(qconfig_mapping, torch.nn.Linear, None), qconfig3)
+ self.assertEqual(_get_object_type_qconfig(qconfig_mapping, torch.nn.ReLU, None), qconfig2)
+ self.assertEqual(_get_object_type_qconfig(qconfig_mapping, "nomatch", None), None)
def test_qconfig_mapping_set_module_name_regex(self):
qconfig1 = get_default_qconfig()
@@ -1898,11 +1898,11 @@
qconfig_mapping.set_module_name_regex("foo.*bar", qconfig3)
self.assertEqual(qconfig_mapping.module_name_regex_qconfigs["foo.*bar"], qconfig3)
self.assertEqual(qconfig_mapping.module_name_regex_qconfigs["foo.*"], qconfig2)
- self.assertEqual(get_module_name_regex_qconfig(qconfig_mapping, "foo123bar", None), qconfig3)
- self.assertEqual(get_module_name_regex_qconfig(qconfig_mapping, "foobar", None), qconfig3)
- self.assertEqual(get_module_name_regex_qconfig(qconfig_mapping, "foobaz", None), qconfig2)
- self.assertEqual(get_module_name_regex_qconfig(qconfig_mapping, "foo", None), qconfig2)
- self.assertEqual(get_module_name_regex_qconfig(qconfig_mapping, "nomatch", None), None)
+ self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foo123bar", None), qconfig3)
+ self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foobar", None), qconfig3)
+ self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foobaz", None), qconfig2)
+ self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foo", None), qconfig2)
+ self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "nomatch", None), None)
def test_qconfig_mapping_set_module_name(self):
qconfig1 = get_default_qconfig()
@@ -1922,9 +1922,9 @@
qconfig_mapping.set_module_name("mod1", qconfig3)
self.assertEqual(qconfig_mapping.module_name_qconfigs["mod1"], qconfig3)
self.assertEqual(qconfig_mapping.module_name_qconfigs["mod2"], qconfig2)
- self.assertEqual(get_module_name_qconfig(qconfig_mapping, "mod1", None), qconfig3)
- self.assertEqual(get_module_name_qconfig(qconfig_mapping, "mod2", None), qconfig2)
- self.assertEqual(get_module_name_qconfig(qconfig_mapping, "nomatch", None), None)
+ self.assertEqual(_get_module_name_qconfig(qconfig_mapping, "mod1", None), qconfig3)
+ self.assertEqual(_get_module_name_qconfig(qconfig_mapping, "mod2", None), qconfig2)
+ self.assertEqual(_get_module_name_qconfig(qconfig_mapping, "nomatch", None), None)
def test_qconfig_mapping_set_module_name_object_type_order(self):
qconfig1 = get_default_qconfig()
diff --git a/torch/ao/quantization/fx/convert.py b/torch/ao/quantization/fx/convert.py
index 74eb8f1..b5e9cf3 100644
--- a/torch/ao/quantization/fx/convert.py
+++ b/torch/ao/quantization/fx/convert.py
@@ -24,7 +24,7 @@
)
from ..qconfig_mapping import QConfigMapping
from ..qconfig_mapping_utils import (
- update_qconfig_for_qat,
+ _update_qconfig_for_qat,
)
from .qconfig_mapping_utils import (
generate_node_name_to_qconfig,
@@ -563,7 +563,7 @@
modules_copy = copy.deepcopy(modules)
if model._is_qat:
- update_qconfig_for_qat(qconfig_mapping, {})
+ _update_qconfig_for_qat(qconfig_mapping, {})
update_qconfig_for_fusion(model, qconfig_mapping)
compare_prepare_convert_qconfig_mappings(prepare_qconfig_mapping, qconfig_mapping) # type: ignore[arg-type]
diff --git a/torch/ao/quantization/fx/prepare.py b/torch/ao/quantization/fx/prepare.py
index 160b80a..281bd96 100644
--- a/torch/ao/quantization/fx/prepare.py
+++ b/torch/ao/quantization/fx/prepare.py
@@ -28,8 +28,8 @@
QConfigMapping,
)
from ..qconfig_mapping_utils import (
- get_flattened_qconfig_dict,
- update_qconfig_for_qat,
+ _get_flattened_qconfig_dict,
+ _update_qconfig_for_qat,
)
from .qconfig_mapping_utils import (
generate_node_name_to_qconfig,
@@ -1587,14 +1587,14 @@
update_qconfig_for_fusion(model, qconfig_mapping)
update_qconfig_for_fusion(model, _equalization_config)
- flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_mapping)
+ flattened_qconfig_dict = _get_flattened_qconfig_dict(qconfig_mapping)
# TODO: support regex as well
propagate_qconfig_(model, flattened_qconfig_dict, prepare_custom_config.to_dict())
if is_qat:
module_to_qat_module = get_module_to_qat_module(backend_config)
qat_swap_modules(model, module_to_qat_module)
- update_qconfig_for_qat(qconfig_mapping, {})
+ _update_qconfig_for_qat(qconfig_mapping, {})
# mapping from fully qualified module name to module instance
# for example,
diff --git a/torch/ao/quantization/fx/qconfig_mapping_utils.py b/torch/ao/quantization/fx/qconfig_mapping_utils.py
index 2abfaf8..66dffd5 100644
--- a/torch/ao/quantization/fx/qconfig_mapping_utils.py
+++ b/torch/ao/quantization/fx/qconfig_mapping_utils.py
@@ -29,8 +29,8 @@
QConfigMapping,
)
from ..qconfig_mapping_utils import (
- get_object_type_qconfig,
- maybe_adjust_qconfig_for_module_type_or_name,
+ _get_object_type_qconfig,
+ _maybe_adjust_qconfig_for_module_type_or_name,
)
@@ -121,17 +121,17 @@
qconfig = None
if node.op == "get_attr":
module_name, _ = _parent_name(node.target)
- qconfig = maybe_adjust_qconfig_for_module_type_or_name(
+ qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
qconfig_mapping, type(modules[module_name]), module_name, global_qconfig)
qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(qconfig, modules.get(node.target, None))
elif node.op == "call_function":
# precedence: module_name_qconfig
# > function_qconfig > global_qconfig
# module_name takes precedence over function qconfig
- function_qconfig = get_object_type_qconfig(
+ function_qconfig = _get_object_type_qconfig(
qconfig_mapping, node.target, global_qconfig)
module_path, module_type = node_name_to_scope[node.name]
- qconfig = maybe_adjust_qconfig_for_module_type_or_name(
+ qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
qconfig_mapping, module_type, module_path, function_qconfig)
cur_object_type_idx = \
@@ -146,11 +146,11 @@
# first use node.target (string) to get the qconfig
# this is to support configs like
# "object_type": [("reshpe", qconfig)]
- qconfig = maybe_adjust_qconfig_for_module_type_or_name(
+ qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
qconfig_mapping, node.target, module_path, global_qconfig)
# if there is no special config for the method, we'll fall back to the
# config for the module that contains the call_method node
- qconfig = maybe_adjust_qconfig_for_module_type_or_name(
+ qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
qconfig_mapping, module_type, module_path, qconfig)
# currently call_method does not support modifying qconfig
# by order, we can add this later if it is needed.
@@ -160,7 +160,7 @@
# if the node is an observer, just continue - don't add it to the qconfig_map
if is_activation_post_process(modules[node.target]):
continue
- qconfig = maybe_adjust_qconfig_for_module_type_or_name(
+ qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
qconfig_mapping, type(modules[node.target]), node.target, global_qconfig)
module_path, module_type = node_name_to_scope[node.name]
diff --git a/torch/ao/quantization/qconfig_mapping_utils.py b/torch/ao/quantization/qconfig_mapping_utils.py
index 09bce4f..0109729 100644
--- a/torch/ao/quantization/qconfig_mapping_utils.py
+++ b/torch/ao/quantization/qconfig_mapping_utils.py
@@ -1,5 +1,5 @@
import re
-from typing import Dict, Callable, Union
+from typing import Dict, Callable, Union, List
from .utils import (
get_combined_dict,
@@ -12,25 +12,18 @@
from .qconfig_mapping import QConfigMapping
-# TODO: revisit this list. Many helper methods shouldn't be public
-__all__ = [
- "get_flattened_qconfig_dict",
- "get_object_type_qconfig",
- "get_module_name_qconfig",
- "get_module_name_regex_qconfig",
- "maybe_adjust_qconfig_for_module_type_or_name",
- "update_qconfig_for_qat",
+__all__: List[str] = [
]
-def get_object_type_qconfig(
+def _get_object_type_qconfig(
qconfig_mapping: QConfigMapping,
object_type: Union[Callable, str],
fallback_qconfig: QConfigAny) -> QConfigAny:
return qconfig_mapping.object_type_qconfigs.get(object_type, fallback_qconfig)
-def get_module_name_regex_qconfig(qconfig_mapping, module_name, fallback_qconfig):
+def _get_module_name_regex_qconfig(qconfig_mapping, module_name, fallback_qconfig):
for regex_pattern, qconfig in qconfig_mapping.module_name_regex_qconfigs.items():
if re.match(regex_pattern, module_name):
# first match wins
@@ -38,7 +31,7 @@
return fallback_qconfig
-def get_module_name_qconfig(qconfig_mapping, module_name, fallback_qconfig):
+def _get_module_name_qconfig(qconfig_mapping, module_name, fallback_qconfig):
if module_name == '':
# module name qconfig not found
return fallback_qconfig
@@ -46,23 +39,23 @@
return qconfig_mapping.module_name_qconfigs[module_name]
else:
parent, _ = _parent_name(module_name)
- return get_module_name_qconfig(qconfig_mapping, parent, fallback_qconfig)
+ return _get_module_name_qconfig(qconfig_mapping, parent, fallback_qconfig)
-def maybe_adjust_qconfig_for_module_type_or_name(qconfig_mapping, module_type, module_name, global_qconfig):
+def _maybe_adjust_qconfig_for_module_type_or_name(qconfig_mapping, module_type, module_name, global_qconfig):
# get qconfig for module_name,
# fallback to module_name_regex_qconfig, module_type_qconfig,
# global_qconfig if necessary
- module_type_qconfig = get_object_type_qconfig(
+ module_type_qconfig = _get_object_type_qconfig(
qconfig_mapping, module_type, global_qconfig)
- module_name_regex_qconfig = get_module_name_regex_qconfig(
+ module_name_regex_qconfig = _get_module_name_regex_qconfig(
qconfig_mapping, module_name, module_type_qconfig)
- module_name_qconfig = get_module_name_qconfig(
+ module_name_qconfig = _get_module_name_qconfig(
qconfig_mapping, module_name, module_name_regex_qconfig)
return module_name_qconfig
-def get_flattened_qconfig_dict(qconfig_mapping: QConfigMapping) -> Dict[Union[Callable, str], QConfigAny]:
+def _get_flattened_qconfig_dict(qconfig_mapping: QConfigMapping) -> Dict[Union[Callable, str], QConfigAny]:
""" flatten the global, object_type and module_name qconfig
to the same qconfig_dict so that it can be used by
propagate_qconfig_ function.
@@ -94,7 +87,7 @@
return flattened
-def update_qconfig_for_qat(
+def _update_qconfig_for_qat(
qconfig_mapping: QConfigMapping,
additional_qat_module_mapping: Dict[Callable, Callable]):
"""