[ao] fuser_method_mappings.py: fix public vs. private (#87516)

Summary: Made the following helpers in fuser_method_mappings.py private by prefixing
them with an underscore and removing them from __all__: _get_valid_patterns,
_DEFAULT_PATTERN_TO_FUSER_METHOD, _reverse3, _reverse2, _reverse_sequential_wrapper2,
_DEFAULT_OP_LIST_TO_FUSER_METHOD, _sequential_wrapper2. All in-tree call sites,
tests, and docs are updated to the new names.
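For context, a minimal sketch of how the renamed wrapper is consumed by the backend
configs after this change, mirroring the updated docstring example in
backend_config.py (these helpers are internal now, so out-of-tree code should not
rely on them; the snippet below is illustrative, not a public API):

```python
import torch
import torch.nn.intrinsic as nni
from torch.ao.quantization.backend_config import BackendPatternConfig, ObservationType
from torch.ao.quantization.fuser_method_mappings import _reverse_sequential_wrapper2

# Fuse the (ReLU, Linear) pattern into nni.LinearReLU. The "reverse" wrapper
# builds the fused module with the two matched modules in reversed argument
# order, per its definition in fuser_method_mappings.py.
linear_relu_config = (
    BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear))
    .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
    .set_fused_module(nni.LinearReLU)
    .set_fuser_method(_reverse_sequential_wrapper2(nni.LinearReLU))
)
```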

Test Plan: python test/test_public_bindings.py

Differential Revision: [D40709281](https://our.internmc.facebook.com/intern/diff/D40709281)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87516
Approved by: https://github.com/jcaip
diff --git a/test/quantization/ao_migration/test_quantization.py b/test/quantization/ao_migration/test_quantization.py
index 52b8f63..2617e7a 100644
--- a/test/quantization/ao_migration/test_quantization.py
+++ b/test/quantization/ao_migration/test_quantization.py
@@ -225,7 +225,7 @@
             "get_fuser_method",
         ]
         dict_list = [
-            "DEFAULT_OP_LIST_TO_FUSER_METHOD"
+            "_DEFAULT_OP_LIST_TO_FUSER_METHOD"
         ]
         self._test_function_import('fuser_method_mappings', function_list)
         self._test_dict_import('fuser_method_mappings', dict_list)
diff --git a/test/quantization/core/test_backend_config.py b/test/quantization/core/test_backend_config.py
index e1e7067..aa9de64 100644
--- a/test/quantization/core/test_backend_config.py
+++ b/test/quantization/core/test_backend_config.py
@@ -14,7 +14,7 @@
     ObservationType,
 )
 from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize
-from torch.ao.quantization.fuser_method_mappings import reverse_sequential_wrapper2
+from torch.ao.quantization.fuser_method_mappings import _reverse_sequential_wrapper2
 from torch.ao.quantization.fx.quantization_patterns import _default_root_node_getter
 from torch.ao.quantization.observer import default_fixed_qparams_range_0to1_observer
 
@@ -106,7 +106,7 @@
     #  BackendPatternConfig
     # ======================
 
-    _fuser_method = reverse_sequential_wrapper2(nni.LinearReLU)
+    _fuser_method = _reverse_sequential_wrapper2(nni.LinearReLU)
 
     _num_tensor_args_to_observation_type = {
         0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
diff --git a/torch/ao/quantization/__init__.py b/torch/ao/quantization/__init__.py
index abc0bd2..2e8390c 100644
--- a/torch/ao/quantization/__init__.py
+++ b/torch/ao/quantization/__init__.py
@@ -114,7 +114,6 @@
     "get_quantized_operator",
     "get_static_quant_module_class",
     "get_unique_devices_",
-    "get_valid_patterns",
     "is_activation_post_process",
     "load_observer_state_dict",
     "no_observer_set",
@@ -132,12 +131,8 @@
     "quantize_jit",
     "quantize_qat",
     "register_activation_post_process_hook",
-    "reverse2",
-    "reverse3",
-    "reverse_sequential_wrapper2",
     "script_qconfig",
     "script_qconfig_dict",
-    "sequential_wrapper2",
     "swap_module",
     "weight_observer_range_neg_127_to_127",
 ]
diff --git a/torch/ao/quantization/backend_config/_common_operator_config_utils.py b/torch/ao/quantization/backend_config/_common_operator_config_utils.py
index bc6f678..c2f0f72 100644
--- a/torch/ao/quantization/backend_config/_common_operator_config_utils.py
+++ b/torch/ao/quantization/backend_config/_common_operator_config_utils.py
@@ -15,9 +15,9 @@
 )
 from ..fake_quantize import FixedQParamsFakeQuantize
 from ..fuser_method_mappings import (
-    reverse_sequential_wrapper2,
-    reverse2,
-    reverse3,
+    _reverse_sequential_wrapper2,
+    _reverse2,
+    _reverse3,
     fuse_conv_bn,
     fuse_conv_bn_relu,
     fuse_linear_bn,
@@ -115,13 +115,13 @@
     linear_configs.append(
         BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear))
             .set_dtype_configs(dtype_configs)  # noqa: E131
-            .set_fuser_method(reverse_sequential_wrapper2(nni.LinearReLU))
+            .set_fuser_method(_reverse_sequential_wrapper2(nni.LinearReLU))
             .set_fused_module(nni.LinearReLU))
     # linear relu, linear module + functional relu
     linear_configs.append(
         BackendPatternConfig((torch.nn.functional.relu, torch.nn.Linear))
             .set_dtype_configs(dtype_configs)  # noqa: E131
-            .set_fuser_method(reverse_sequential_wrapper2(nni.LinearReLU))
+            .set_fuser_method(_reverse_sequential_wrapper2(nni.LinearReLU))
             .set_fused_module(nni.LinearReLU))
 
     # 2.2 linear module + relu, fused module configs
@@ -158,7 +158,7 @@
     linear_configs.append(
         BackendPatternConfig((nn.BatchNorm1d, nn.Linear))
             .set_dtype_configs(dtype_configs)  # noqa: E131
-            .set_fuser_method(reverse2(fuse_linear_bn))
+            .set_fuser_method(_reverse2(fuse_linear_bn))
             .set_fused_module(nni.LinearBn1d))
 
     # 3.2 linear bn fused
@@ -218,13 +218,13 @@
         conv_configs.append(
             BackendPatternConfig((torch.nn.ReLU, convs.root))
                 .set_dtype_configs(dtype_configs)  # noqa: E131
-                .set_fuser_method(reverse_sequential_wrapper2(convs.fused_conv_relu))
+                .set_fuser_method(_reverse_sequential_wrapper2(convs.fused_conv_relu))
                 .set_fused_module(convs.fused_conv_relu))
         # conv relu fusion, conv module + functional relu
         conv_configs.append(
             BackendPatternConfig((F.relu, convs.root))
                 .set_dtype_configs(dtype_configs)  # noqa: E131
-                .set_fuser_method(reverse_sequential_wrapper2(convs.fused_conv_relu))
+                .set_fuser_method(_reverse_sequential_wrapper2(convs.fused_conv_relu))
                 .set_fused_module(convs.fused_conv_relu))
         # 2.2 conv module + relu fused module configs
         # conv relu, fused module
@@ -273,20 +273,20 @@
         conv_configs.append(
             BackendPatternConfig((convs.bn, convs.root))
                 .set_dtype_configs(dtype_configs)  # noqa: E131
-                .set_fuser_method(reverse2(fuse_conv_bn))
+                .set_fuser_method(_reverse2(fuse_conv_bn))
                 .set_fused_module(convs.fused_conv_bn))
         # conv + bn + relu module fusion
         conv_configs.append(
             BackendPatternConfig((nn.ReLU, (convs.bn, convs.root)))
                 .set_dtype_configs(dtype_configs)  # noqa: E131
-                .set_fuser_method(reverse3(fuse_conv_bn_relu))
+                .set_fuser_method(_reverse3(fuse_conv_bn_relu))
                 .set_fused_module(convs.fused_conv_bn_relu))
         # conv + bn + relu functional fusion
         conv_configs.append(
             BackendPatternConfig((F.relu, (convs.bn, convs.root)))
                 .set_dtype_configs(dtype_configs)  # noqa: E131
                 .set_root_module(convs.root)
-                .set_fuser_method(reverse3(fuse_conv_bn_relu))
+                .set_fuser_method(_reverse3(fuse_conv_bn_relu))
                 .set_fused_module(convs.fused_conv_bn_relu))
         # TODO: we can add fusion for torch.relu as well
 
@@ -330,7 +330,7 @@
         conv_configs.append(
             BackendPatternConfig((convs.bn, convs.transpose))
                 .set_dtype_configs(dtype_configs)  # noqa: E131
-                .set_fuser_method(reverse2(fuse_convtranspose_bn))
+                .set_fuser_method(_reverse2(fuse_convtranspose_bn))
                 .set_root_module(convs.transpose)
                 .set_reference_quantized_module(convs.transpose_reference))
 
@@ -497,13 +497,13 @@
         bn_configs.append(
             BackendPatternConfig((torch.nn.ReLU, bn))
                 .set_dtype_configs(dtype_configs)  # noqa: E131
-                .set_fuser_method(reverse_sequential_wrapper2(fused_bn))
+                .set_fuser_method(_reverse_sequential_wrapper2(fused_bn))
                 .set_fused_module(fused_bn))
         # bn module + F.relu fusion config
         bn_configs.append(
             BackendPatternConfig((torch.nn.functional.relu, bn))
                 .set_dtype_configs(dtype_configs)  # noqa: E131
-                .set_fuser_method(reverse_sequential_wrapper2(bn_to_fused_bn[bn]))
+                .set_fuser_method(_reverse_sequential_wrapper2(bn_to_fused_bn[bn]))
                 .set_fused_module(fused_bn))
         bn_configs.append(
             BackendPatternConfig(bn)
diff --git a/torch/ao/quantization/backend_config/backend_config.py b/torch/ao/quantization/backend_config/backend_config.py
index 2f491b1..1305c32 100644
--- a/torch/ao/quantization/backend_config/backend_config.py
+++ b/torch/ao/quantization/backend_config/backend_config.py
@@ -229,7 +229,7 @@
 
         import torch
         from torch.ao.quantization.backend_config import BackendConfig, BackendPatternConfig, DTypeConfig, ObservationType
-        from torch.ao.quantization.fuser_method_mappings import reverse_sequential_wrapper2
+        from torch.ao.quantization.fuser_method_mappings import _reverse_sequential_wrapper2
 
         weighted_int8_dtype_config = DTypeConfig(
             input_dtype=torch.quint8,
@@ -248,7 +248,7 @@
             .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
             .add_dtype_config(weighted_int8_dtype_config) \
             .set_fused_module(torch.nn.intrinsic.ConvReLU2d) \
-            .set_fuser_method(reverse_sequential_wrapper2(torch.nn.intrinsic.ConvReLU2d))
+            .set_fuser_method(_reverse_sequential_wrapper2(torch.nn.intrinsic.ConvReLU2d))
 
         backend_config = BackendConfig("my_backend") \
             .set_backend_pattern_config(linear_config) \
diff --git a/torch/ao/quantization/backend_config/executorch.py b/torch/ao/quantization/backend_config/executorch.py
index 4c0f2a4..3c72932 100644
--- a/torch/ao/quantization/backend_config/executorch.py
+++ b/torch/ao/quantization/backend_config/executorch.py
@@ -7,7 +7,7 @@
 import torch.nn.quantized._reference as nnqr
 from .backend_config import BackendConfig, BackendPatternConfig, DTypeConfig, ObservationType
 from ._common_operator_config_utils import _Conv2dMetadata
-from ..fuser_method_mappings import reverse_sequential_wrapper2
+from ..fuser_method_mappings import _reverse_sequential_wrapper2
 
 
 __all__ = [
@@ -105,13 +105,13 @@
         conv_configs.append(
             BackendPatternConfig((torch.nn.ReLU, convs.root))
                 .set_dtype_configs(dtype_configs)  # noqa: E131
-                .set_fuser_method(reverse_sequential_wrapper2(convs.fused_conv_relu))
+                .set_fuser_method(_reverse_sequential_wrapper2(convs.fused_conv_relu))
                 .set_fused_module(convs.fused_conv_relu))
         # conv module + functional relu
         conv_configs.append(
             BackendPatternConfig((F.relu, convs.root))
                 .set_dtype_configs(dtype_configs)  # noqa: E131
-                .set_fuser_method(reverse_sequential_wrapper2(convs.fused_conv_relu))
+                .set_fuser_method(_reverse_sequential_wrapper2(convs.fused_conv_relu))
                 .set_fused_module(convs.fused_conv_relu))
         # fused conv relu module
         conv_configs.append(
diff --git a/torch/ao/quantization/fuser_method_mappings.py b/torch/ao/quantization/fuser_method_mappings.py
index 2e39f87..db4cc9a 100644
--- a/torch/ao/quantization/fuser_method_mappings.py
+++ b/torch/ao/quantization/fuser_method_mappings.py
@@ -10,13 +10,7 @@
     "fuse_conv_bn_relu",
     "fuse_linear_bn",
     "fuse_convtranspose_bn",
-    "sequential_wrapper2",
     "get_fuser_method",
-    "reverse_sequential_wrapper2",
-    "reverse2",
-    "reverse3",
-    "DEFAULT_PATTERN_TO_FUSER_METHOD",
-    "get_valid_patterns",
     "get_fuser_method_new",
 ]
 
@@ -156,7 +150,7 @@
     else:
         return nn.utils.fusion.fuse_conv_bn_eval(convt, bn, transpose=True)
 
-def sequential_wrapper2(sequential):
+def _sequential_wrapper2(sequential):
     """ Given a sequential class for two modules, return a function that takes
     is_qat, and then two modules as argument, that ignores the is_qat flag
     and always returns the sequential that combines the two input modules
@@ -165,20 +159,20 @@
         return sequential(m1, m2)
     return fuser_method
 
-DEFAULT_OP_LIST_TO_FUSER_METHOD: Dict[Tuple, Union[nn.Sequential, Callable]] = {
+_DEFAULT_OP_LIST_TO_FUSER_METHOD: Dict[Tuple, Union[nn.Sequential, Callable]] = {
     (nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn,
     (nn.Conv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
     (nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn,
     (nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu,
     (nn.Conv3d, nn.BatchNorm3d): fuse_conv_bn,
     (nn.Conv3d, nn.BatchNorm3d, nn.ReLU): fuse_conv_bn_relu,
-    (nn.Conv1d, nn.ReLU): sequential_wrapper2(nni.ConvReLU1d),
-    (nn.Conv2d, nn.ReLU): sequential_wrapper2(nni.ConvReLU2d),
-    (nn.Conv3d, nn.ReLU): sequential_wrapper2(nni.ConvReLU3d),
+    (nn.Conv1d, nn.ReLU): _sequential_wrapper2(nni.ConvReLU1d),
+    (nn.Conv2d, nn.ReLU): _sequential_wrapper2(nni.ConvReLU2d),
+    (nn.Conv3d, nn.ReLU): _sequential_wrapper2(nni.ConvReLU3d),
     (nn.Linear, nn.BatchNorm1d): fuse_linear_bn,
-    (nn.Linear, nn.ReLU): sequential_wrapper2(nni.LinearReLU),
-    (nn.BatchNorm2d, nn.ReLU): sequential_wrapper2(nni.BNReLU2d),
-    (nn.BatchNorm3d, nn.ReLU): sequential_wrapper2(nni.BNReLU3d),
+    (nn.Linear, nn.ReLU): _sequential_wrapper2(nni.LinearReLU),
+    (nn.BatchNorm2d, nn.ReLU): _sequential_wrapper2(nni.BNReLU2d),
+    (nn.BatchNorm3d, nn.ReLU): _sequential_wrapper2(nni.BNReLU3d),
     (nn.ConvTranspose1d, nn.BatchNorm1d): fuse_convtranspose_bn,
     (nn.ConvTranspose2d, nn.BatchNorm2d): fuse_convtranspose_bn,
     (nn.ConvTranspose3d, nn.BatchNorm3d): fuse_convtranspose_bn,
@@ -190,13 +184,13 @@
     '''
     if additional_fuser_method_mapping is None:
         additional_fuser_method_mapping = {}
-    all_mappings = get_combined_dict(DEFAULT_OP_LIST_TO_FUSER_METHOD,
+    all_mappings = get_combined_dict(_DEFAULT_OP_LIST_TO_FUSER_METHOD,
                                      additional_fuser_method_mapping)
     fuser_method = all_mappings.get(op_list, None)
     assert fuser_method is not None, "did not find fuser method for: {} ".format(op_list)
     return fuser_method
 
-def reverse_sequential_wrapper2(sequential):
+def _reverse_sequential_wrapper2(sequential):
     """ Given a sequential class for two modules, return a function that takes
     is_qat, and then two modules as argument, that ignores the is_qat flag
     and always returns the sequential that combines the two input modules, with
@@ -206,37 +200,37 @@
         return sequential(m2, m1)
     return fuser_method
 
-def reverse2(f):
+def _reverse2(f):
     def reversed(is_qat, x, y):
         return f(is_qat, y, x)
     return reversed
 
-def reverse3(f):
+def _reverse3(f):
     def reversed(is_qat, x, w):
         y, z = w
         return f(is_qat, z, y, x)
     return reversed
 
-DEFAULT_PATTERN_TO_FUSER_METHOD: Dict[Pattern, Union[nn.Sequential, Callable]] = {
-    (nn.BatchNorm1d, nn.Conv1d): reverse2(fuse_conv_bn),
-    (nn.ReLU, (nn.BatchNorm1d, nn.Conv1d)): reverse3(fuse_conv_bn_relu),
-    (nn.BatchNorm2d, nn.Conv2d): reverse2(fuse_conv_bn),
-    (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)): reverse3(fuse_conv_bn_relu),
-    (nn.BatchNorm3d, nn.Conv3d): reverse2(fuse_conv_bn),
-    (nn.ReLU, (nn.BatchNorm3d, nn.Conv3d)): reverse3(fuse_conv_bn_relu),
-    (nn.ReLU, nn.Conv1d): reverse_sequential_wrapper2(nni.ConvReLU1d),
-    (nn.ReLU, nn.Conv2d): reverse_sequential_wrapper2(nni.ConvReLU2d),
-    (nn.ReLU, nn.Conv3d): reverse_sequential_wrapper2(nni.ConvReLU3d),
-    (nn.BatchNorm1d, nn.Linear): reverse2(fuse_linear_bn),
-    (nn.ReLU, nn.Linear): reverse_sequential_wrapper2(nni.LinearReLU),
-    (nn.ReLU, nn.BatchNorm2d): reverse_sequential_wrapper2(nni.BNReLU2d),
-    (nn.ReLU, nn.BatchNorm3d): reverse_sequential_wrapper2(nni.BNReLU3d),
-    (nn.BatchNorm1d, nn.ConvTranspose1d): reverse2(fuse_convtranspose_bn),
-    (nn.BatchNorm2d, nn.ConvTranspose2d): reverse2(fuse_convtranspose_bn),
-    (nn.BatchNorm3d, nn.ConvTranspose3d): reverse2(fuse_convtranspose_bn),
+_DEFAULT_PATTERN_TO_FUSER_METHOD: Dict[Pattern, Union[nn.Sequential, Callable]] = {
+    (nn.BatchNorm1d, nn.Conv1d): _reverse2(fuse_conv_bn),
+    (nn.ReLU, (nn.BatchNorm1d, nn.Conv1d)): _reverse3(fuse_conv_bn_relu),
+    (nn.BatchNorm2d, nn.Conv2d): _reverse2(fuse_conv_bn),
+    (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)): _reverse3(fuse_conv_bn_relu),
+    (nn.BatchNorm3d, nn.Conv3d): _reverse2(fuse_conv_bn),
+    (nn.ReLU, (nn.BatchNorm3d, nn.Conv3d)): _reverse3(fuse_conv_bn_relu),
+    (nn.ReLU, nn.Conv1d): _reverse_sequential_wrapper2(nni.ConvReLU1d),
+    (nn.ReLU, nn.Conv2d): _reverse_sequential_wrapper2(nni.ConvReLU2d),
+    (nn.ReLU, nn.Conv3d): _reverse_sequential_wrapper2(nni.ConvReLU3d),
+    (nn.BatchNorm1d, nn.Linear): _reverse2(fuse_linear_bn),
+    (nn.ReLU, nn.Linear): _reverse_sequential_wrapper2(nni.LinearReLU),
+    (nn.ReLU, nn.BatchNorm2d): _reverse_sequential_wrapper2(nni.BNReLU2d),
+    (nn.ReLU, nn.BatchNorm3d): _reverse_sequential_wrapper2(nni.BNReLU3d),
+    (nn.BatchNorm1d, nn.ConvTranspose1d): _reverse2(fuse_convtranspose_bn),
+    (nn.BatchNorm2d, nn.ConvTranspose2d): _reverse2(fuse_convtranspose_bn),
+    (nn.BatchNorm3d, nn.ConvTranspose3d): _reverse2(fuse_convtranspose_bn),
 }
 
-def get_valid_patterns(op_pattern):
+def _get_valid_patterns(op_pattern):
     """
     Returns a list of valid patterns generated from the op_pattern,
     since MatchAllNode can match all types of nodes,
@@ -261,7 +255,7 @@
     if isinstance(op_pattern, (tuple, list)):
         sub_combs = []
         for sub_pattern in op_pattern:
-            sub_combs.append(get_valid_patterns(sub_pattern))
+            sub_combs.append(_get_valid_patterns(sub_pattern))
         result = list(itertools.product(*sub_combs))
     else:
         result = [op_pattern, MatchAllNode]
@@ -274,9 +268,9 @@
     Would like to implement this first and have a separate PR for deprecation
     """
     if fuser_method_mapping is None:
-        fuser_method_mapping = DEFAULT_PATTERN_TO_FUSER_METHOD
+        fuser_method_mapping = _DEFAULT_PATTERN_TO_FUSER_METHOD
 
-    op_patterns = get_valid_patterns(op_pattern)
+    op_patterns = _get_valid_patterns(op_pattern)
     fuser_method = None
     for op_pattern in op_patterns:
         fuser_method = fuser_method_mapping.get(op_pattern, None)
diff --git a/torch/ao/quantization/fx/README.md b/torch/ao/quantization/fx/README.md
index cba11e9..622acd3 100644
--- a/torch/ao/quantization/fx/README.md
+++ b/torch/ao/quantization/fx/README.md
@@ -81,7 +81,7 @@
 
 ```
 BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear))
-    .set_fuser_method(reverse_sequential_wrapper2(nni.LinearReLU))
+    .set_fuser_method(_reverse_sequential_wrapper2(nni.LinearReLU))
     ._set_root_node_getter(my_root_node_getter)
     ._set_extra_inputs_getter(my_extra_inputs_getter)
 ```
diff --git a/torch/quantization/fuser_method_mappings.py b/torch/quantization/fuser_method_mappings.py
index 50520b3..22f4e63 100644
--- a/torch/quantization/fuser_method_mappings.py
+++ b/torch/quantization/fuser_method_mappings.py
@@ -10,6 +10,6 @@
     fuse_conv_bn,
     fuse_conv_bn_relu,
     fuse_linear_bn,
-    DEFAULT_OP_LIST_TO_FUSER_METHOD,
+    _DEFAULT_OP_LIST_TO_FUSER_METHOD,
     get_fuser_method,
 )