[ao] fuser_method_mappings.py fixing public v private (#87516)
Summary: Made the following helpers in torch/ao/quantization/fuser_method_mappings.py
private by renaming them with a leading underscore: _get_valid_patterns,
_DEFAULT_PATTERN_TO_FUSER_METHOD, _reverse3, _reverse2, _reverse_sequential_wrapper2,
_DEFAULT_OP_LIST_TO_FUSER_METHOD, and _sequential_wrapper2. All call sites, tests,
and docs are updated accordingly.
Test Plan: python test/test_public_bindings.py
Differential Revision: [D40709281](https://our.internmc.facebook.com/intern/diff/D40709281)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87516
Approved by: https://github.com/jcaip
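For context, the helpers made private here are small higher-order functions that build
fuser methods for BackendPatternConfig. Below is a minimal sketch of their behavior,
mirroring the definitions visible in the diff that follows; the names fuse_linear_relu,
ReLU(), and Linear(4, 4) in the usage lines are illustrative assumptions, not part of
this change.

```python
# Sketch of the now-private helpers, mirroring torch/ao/quantization/fuser_method_mappings.py
# at this revision. The usage example at the bottom is hypothetical.
import torch.nn as nn
import torch.nn.intrinsic as nni

def _sequential_wrapper2(sequential):
    # Returns a fuser method that ignores is_qat and wraps (m1, m2) in `sequential`.
    def fuser_method(is_qat, m1, m2):
        return sequential(m1, m2)
    return fuser_method

def _reverse_sequential_wrapper2(sequential):
    # Same idea, but for patterns written in reverse execution order, so the two
    # modules are swapped before being wrapped.
    def fuser_method(is_qat, m1, m2):
        return sequential(m2, m1)
    return fuser_method

def _reverse2(f):
    # Adapts a two-module fuser method (e.g. fuse_linear_bn) to a reversed pattern.
    def reversed_fuser(is_qat, x, y):
        return f(is_qat, y, x)
    return reversed_fuser

def _reverse3(f):
    # Adapts a three-module fuser method (e.g. fuse_conv_bn_relu) to a reversed
    # nested pattern such as (ReLU, (BatchNorm2d, Conv2d)).
    def reversed_fuser(is_qat, x, w):
        y, z = w
        return f(is_qat, z, y, x)
    return reversed_fuser

# Hypothetical usage: build a fuser method for the (nn.ReLU, nn.Linear) pattern.
fuse_linear_relu = _reverse_sequential_wrapper2(nni.LinearReLU)
fused = fuse_linear_relu(False, nn.ReLU(), nn.Linear(4, 4))  # -> nni.LinearReLU(linear, relu)
```

Since these helpers are implementation details of the default backend configs, the
rename signals that external code should not depend on them directly.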
diff --git a/test/quantization/ao_migration/test_quantization.py b/test/quantization/ao_migration/test_quantization.py
index 52b8f63..2617e7a 100644
--- a/test/quantization/ao_migration/test_quantization.py
+++ b/test/quantization/ao_migration/test_quantization.py
@@ -225,7 +225,7 @@
"get_fuser_method",
]
dict_list = [
- "DEFAULT_OP_LIST_TO_FUSER_METHOD"
+ "_DEFAULT_OP_LIST_TO_FUSER_METHOD"
]
self._test_function_import('fuser_method_mappings', function_list)
self._test_dict_import('fuser_method_mappings', dict_list)
diff --git a/test/quantization/core/test_backend_config.py b/test/quantization/core/test_backend_config.py
index e1e7067..aa9de64 100644
--- a/test/quantization/core/test_backend_config.py
+++ b/test/quantization/core/test_backend_config.py
@@ -14,7 +14,7 @@
ObservationType,
)
from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize
-from torch.ao.quantization.fuser_method_mappings import reverse_sequential_wrapper2
+from torch.ao.quantization.fuser_method_mappings import _reverse_sequential_wrapper2
from torch.ao.quantization.fx.quantization_patterns import _default_root_node_getter
from torch.ao.quantization.observer import default_fixed_qparams_range_0to1_observer
@@ -106,7 +106,7 @@
# BackendPatternConfig
# ======================
- _fuser_method = reverse_sequential_wrapper2(nni.LinearReLU)
+ _fuser_method = _reverse_sequential_wrapper2(nni.LinearReLU)
_num_tensor_args_to_observation_type = {
0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
diff --git a/torch/ao/quantization/__init__.py b/torch/ao/quantization/__init__.py
index abc0bd2..2e8390c 100644
--- a/torch/ao/quantization/__init__.py
+++ b/torch/ao/quantization/__init__.py
@@ -114,7 +114,6 @@
"get_quantized_operator",
"get_static_quant_module_class",
"get_unique_devices_",
- "get_valid_patterns",
"is_activation_post_process",
"load_observer_state_dict",
"no_observer_set",
@@ -132,12 +131,8 @@
"quantize_jit",
"quantize_qat",
"register_activation_post_process_hook",
- "reverse2",
- "reverse3",
- "reverse_sequential_wrapper2",
"script_qconfig",
"script_qconfig_dict",
- "sequential_wrapper2",
"swap_module",
"weight_observer_range_neg_127_to_127",
]
diff --git a/torch/ao/quantization/backend_config/_common_operator_config_utils.py b/torch/ao/quantization/backend_config/_common_operator_config_utils.py
index bc6f678..c2f0f72 100644
--- a/torch/ao/quantization/backend_config/_common_operator_config_utils.py
+++ b/torch/ao/quantization/backend_config/_common_operator_config_utils.py
@@ -15,9 +15,9 @@
)
from ..fake_quantize import FixedQParamsFakeQuantize
from ..fuser_method_mappings import (
- reverse_sequential_wrapper2,
- reverse2,
- reverse3,
+ _reverse_sequential_wrapper2,
+ _reverse2,
+ _reverse3,
fuse_conv_bn,
fuse_conv_bn_relu,
fuse_linear_bn,
@@ -115,13 +115,13 @@
linear_configs.append(
BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear))
.set_dtype_configs(dtype_configs) # noqa: E131
- .set_fuser_method(reverse_sequential_wrapper2(nni.LinearReLU))
+ .set_fuser_method(_reverse_sequential_wrapper2(nni.LinearReLU))
.set_fused_module(nni.LinearReLU))
# linear relu, linear module + functional relu
linear_configs.append(
BackendPatternConfig((torch.nn.functional.relu, torch.nn.Linear))
.set_dtype_configs(dtype_configs) # noqa: E131
- .set_fuser_method(reverse_sequential_wrapper2(nni.LinearReLU))
+ .set_fuser_method(_reverse_sequential_wrapper2(nni.LinearReLU))
.set_fused_module(nni.LinearReLU))
# 2.2 linear module + relu, fused module configs
@@ -158,7 +158,7 @@
linear_configs.append(
BackendPatternConfig((nn.BatchNorm1d, nn.Linear))
.set_dtype_configs(dtype_configs) # noqa: E131
- .set_fuser_method(reverse2(fuse_linear_bn))
+ .set_fuser_method(_reverse2(fuse_linear_bn))
.set_fused_module(nni.LinearBn1d))
# 3.2 linear bn fused
@@ -218,13 +218,13 @@
conv_configs.append(
BackendPatternConfig((torch.nn.ReLU, convs.root))
.set_dtype_configs(dtype_configs) # noqa: E131
- .set_fuser_method(reverse_sequential_wrapper2(convs.fused_conv_relu))
+ .set_fuser_method(_reverse_sequential_wrapper2(convs.fused_conv_relu))
.set_fused_module(convs.fused_conv_relu))
# conv relu fusion, conv module + functional relu
conv_configs.append(
BackendPatternConfig((F.relu, convs.root))
.set_dtype_configs(dtype_configs) # noqa: E131
- .set_fuser_method(reverse_sequential_wrapper2(convs.fused_conv_relu))
+ .set_fuser_method(_reverse_sequential_wrapper2(convs.fused_conv_relu))
.set_fused_module(convs.fused_conv_relu))
# 2.2 conv module + relu fused module configs
# conv relu, fused module
@@ -273,20 +273,20 @@
conv_configs.append(
BackendPatternConfig((convs.bn, convs.root))
.set_dtype_configs(dtype_configs) # noqa: E131
- .set_fuser_method(reverse2(fuse_conv_bn))
+ .set_fuser_method(_reverse2(fuse_conv_bn))
.set_fused_module(convs.fused_conv_bn))
# conv + bn + relu module fusion
conv_configs.append(
BackendPatternConfig((nn.ReLU, (convs.bn, convs.root)))
.set_dtype_configs(dtype_configs) # noqa: E131
- .set_fuser_method(reverse3(fuse_conv_bn_relu))
+ .set_fuser_method(_reverse3(fuse_conv_bn_relu))
.set_fused_module(convs.fused_conv_bn_relu))
# conv + bn + relu functional fusion
conv_configs.append(
BackendPatternConfig((F.relu, (convs.bn, convs.root)))
.set_dtype_configs(dtype_configs) # noqa: E131
.set_root_module(convs.root)
- .set_fuser_method(reverse3(fuse_conv_bn_relu))
+ .set_fuser_method(_reverse3(fuse_conv_bn_relu))
.set_fused_module(convs.fused_conv_bn_relu))
# TODO: we can add fusion for torch.relu as well
@@ -330,7 +330,7 @@
conv_configs.append(
BackendPatternConfig((convs.bn, convs.transpose))
.set_dtype_configs(dtype_configs) # noqa: E131
- .set_fuser_method(reverse2(fuse_convtranspose_bn))
+ .set_fuser_method(_reverse2(fuse_convtranspose_bn))
.set_root_module(convs.transpose)
.set_reference_quantized_module(convs.transpose_reference))
@@ -497,13 +497,13 @@
bn_configs.append(
BackendPatternConfig((torch.nn.ReLU, bn))
.set_dtype_configs(dtype_configs) # noqa: E131
- .set_fuser_method(reverse_sequential_wrapper2(fused_bn))
+ .set_fuser_method(_reverse_sequential_wrapper2(fused_bn))
.set_fused_module(fused_bn))
# bn module + F.relu fusion config
bn_configs.append(
BackendPatternConfig((torch.nn.functional.relu, bn))
.set_dtype_configs(dtype_configs) # noqa: E131
- .set_fuser_method(reverse_sequential_wrapper2(bn_to_fused_bn[bn]))
+ .set_fuser_method(_reverse_sequential_wrapper2(bn_to_fused_bn[bn]))
.set_fused_module(fused_bn))
bn_configs.append(
BackendPatternConfig(bn)
diff --git a/torch/ao/quantization/backend_config/backend_config.py b/torch/ao/quantization/backend_config/backend_config.py
index 2f491b1..1305c32 100644
--- a/torch/ao/quantization/backend_config/backend_config.py
+++ b/torch/ao/quantization/backend_config/backend_config.py
@@ -229,7 +229,7 @@
import torch
from torch.ao.quantization.backend_config import BackendConfig, BackendPatternConfig, DTypeConfig, ObservationType
- from torch.ao.quantization.fuser_method_mappings import reverse_sequential_wrapper2
+ from torch.ao.quantization.fuser_method_mappings import _reverse_sequential_wrapper2
weighted_int8_dtype_config = DTypeConfig(
input_dtype=torch.quint8,
@@ -248,7 +248,7 @@
.set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
.add_dtype_config(weighted_int8_dtype_config) \
.set_fused_module(torch.nn.intrinsic.ConvReLU2d) \
- .set_fuser_method(reverse_sequential_wrapper2(torch.nn.intrinsic.ConvReLU2d))
+ .set_fuser_method(_reverse_sequential_wrapper2(torch.nn.intrinsic.ConvReLU2d))
backend_config = BackendConfig("my_backend") \
.set_backend_pattern_config(linear_config) \
diff --git a/torch/ao/quantization/backend_config/executorch.py b/torch/ao/quantization/backend_config/executorch.py
index 4c0f2a4..3c72932 100644
--- a/torch/ao/quantization/backend_config/executorch.py
+++ b/torch/ao/quantization/backend_config/executorch.py
@@ -7,7 +7,7 @@
import torch.nn.quantized._reference as nnqr
from .backend_config import BackendConfig, BackendPatternConfig, DTypeConfig, ObservationType
from ._common_operator_config_utils import _Conv2dMetadata
-from ..fuser_method_mappings import reverse_sequential_wrapper2
+from ..fuser_method_mappings import _reverse_sequential_wrapper2
__all__ = [
@@ -105,13 +105,13 @@
conv_configs.append(
BackendPatternConfig((torch.nn.ReLU, convs.root))
.set_dtype_configs(dtype_configs) # noqa: E131
- .set_fuser_method(reverse_sequential_wrapper2(convs.fused_conv_relu))
+ .set_fuser_method(_reverse_sequential_wrapper2(convs.fused_conv_relu))
.set_fused_module(convs.fused_conv_relu))
# conv module + functional relu
conv_configs.append(
BackendPatternConfig((F.relu, convs.root))
.set_dtype_configs(dtype_configs) # noqa: E131
- .set_fuser_method(reverse_sequential_wrapper2(convs.fused_conv_relu))
+ .set_fuser_method(_reverse_sequential_wrapper2(convs.fused_conv_relu))
.set_fused_module(convs.fused_conv_relu))
# fused conv relu module
conv_configs.append(
diff --git a/torch/ao/quantization/fuser_method_mappings.py b/torch/ao/quantization/fuser_method_mappings.py
index 2e39f87..db4cc9a 100644
--- a/torch/ao/quantization/fuser_method_mappings.py
+++ b/torch/ao/quantization/fuser_method_mappings.py
@@ -10,13 +10,7 @@
"fuse_conv_bn_relu",
"fuse_linear_bn",
"fuse_convtranspose_bn",
- "sequential_wrapper2",
"get_fuser_method",
- "reverse_sequential_wrapper2",
- "reverse2",
- "reverse3",
- "DEFAULT_PATTERN_TO_FUSER_METHOD",
- "get_valid_patterns",
"get_fuser_method_new",
]
@@ -156,7 +150,7 @@
else:
return nn.utils.fusion.fuse_conv_bn_eval(convt, bn, transpose=True)
-def sequential_wrapper2(sequential):
+def _sequential_wrapper2(sequential):
""" Given a sequential class for two modules, return a function that takes
is_qat, and then two modules as argument, that ignores the is_qat flag
and always returns the sequential that combines the two input modules
@@ -165,20 +159,20 @@
return sequential(m1, m2)
return fuser_method
-DEFAULT_OP_LIST_TO_FUSER_METHOD: Dict[Tuple, Union[nn.Sequential, Callable]] = {
+_DEFAULT_OP_LIST_TO_FUSER_METHOD: Dict[Tuple, Union[nn.Sequential, Callable]] = {
(nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn,
(nn.Conv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
(nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn,
(nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu,
(nn.Conv3d, nn.BatchNorm3d): fuse_conv_bn,
(nn.Conv3d, nn.BatchNorm3d, nn.ReLU): fuse_conv_bn_relu,
- (nn.Conv1d, nn.ReLU): sequential_wrapper2(nni.ConvReLU1d),
- (nn.Conv2d, nn.ReLU): sequential_wrapper2(nni.ConvReLU2d),
- (nn.Conv3d, nn.ReLU): sequential_wrapper2(nni.ConvReLU3d),
+ (nn.Conv1d, nn.ReLU): _sequential_wrapper2(nni.ConvReLU1d),
+ (nn.Conv2d, nn.ReLU): _sequential_wrapper2(nni.ConvReLU2d),
+ (nn.Conv3d, nn.ReLU): _sequential_wrapper2(nni.ConvReLU3d),
(nn.Linear, nn.BatchNorm1d): fuse_linear_bn,
- (nn.Linear, nn.ReLU): sequential_wrapper2(nni.LinearReLU),
- (nn.BatchNorm2d, nn.ReLU): sequential_wrapper2(nni.BNReLU2d),
- (nn.BatchNorm3d, nn.ReLU): sequential_wrapper2(nni.BNReLU3d),
+ (nn.Linear, nn.ReLU): _sequential_wrapper2(nni.LinearReLU),
+ (nn.BatchNorm2d, nn.ReLU): _sequential_wrapper2(nni.BNReLU2d),
+ (nn.BatchNorm3d, nn.ReLU): _sequential_wrapper2(nni.BNReLU3d),
(nn.ConvTranspose1d, nn.BatchNorm1d): fuse_convtranspose_bn,
(nn.ConvTranspose2d, nn.BatchNorm2d): fuse_convtranspose_bn,
(nn.ConvTranspose3d, nn.BatchNorm3d): fuse_convtranspose_bn,
@@ -190,13 +184,13 @@
'''
if additional_fuser_method_mapping is None:
additional_fuser_method_mapping = {}
- all_mappings = get_combined_dict(DEFAULT_OP_LIST_TO_FUSER_METHOD,
+ all_mappings = get_combined_dict(_DEFAULT_OP_LIST_TO_FUSER_METHOD,
additional_fuser_method_mapping)
fuser_method = all_mappings.get(op_list, None)
assert fuser_method is not None, "did not find fuser method for: {} ".format(op_list)
return fuser_method
-def reverse_sequential_wrapper2(sequential):
+def _reverse_sequential_wrapper2(sequential):
""" Given a sequential class for two modules, return a function that takes
is_qat, and then two modules as argument, that ignores the is_qat flag
and always returns the sequential that combines the two input modules, with
@@ -206,37 +200,37 @@
return sequential(m2, m1)
return fuser_method
-def reverse2(f):
+def _reverse2(f):
def reversed(is_qat, x, y):
return f(is_qat, y, x)
return reversed
-def reverse3(f):
+def _reverse3(f):
def reversed(is_qat, x, w):
y, z = w
return f(is_qat, z, y, x)
return reversed
-DEFAULT_PATTERN_TO_FUSER_METHOD: Dict[Pattern, Union[nn.Sequential, Callable]] = {
- (nn.BatchNorm1d, nn.Conv1d): reverse2(fuse_conv_bn),
- (nn.ReLU, (nn.BatchNorm1d, nn.Conv1d)): reverse3(fuse_conv_bn_relu),
- (nn.BatchNorm2d, nn.Conv2d): reverse2(fuse_conv_bn),
- (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)): reverse3(fuse_conv_bn_relu),
- (nn.BatchNorm3d, nn.Conv3d): reverse2(fuse_conv_bn),
- (nn.ReLU, (nn.BatchNorm3d, nn.Conv3d)): reverse3(fuse_conv_bn_relu),
- (nn.ReLU, nn.Conv1d): reverse_sequential_wrapper2(nni.ConvReLU1d),
- (nn.ReLU, nn.Conv2d): reverse_sequential_wrapper2(nni.ConvReLU2d),
- (nn.ReLU, nn.Conv3d): reverse_sequential_wrapper2(nni.ConvReLU3d),
- (nn.BatchNorm1d, nn.Linear): reverse2(fuse_linear_bn),
- (nn.ReLU, nn.Linear): reverse_sequential_wrapper2(nni.LinearReLU),
- (nn.ReLU, nn.BatchNorm2d): reverse_sequential_wrapper2(nni.BNReLU2d),
- (nn.ReLU, nn.BatchNorm3d): reverse_sequential_wrapper2(nni.BNReLU3d),
- (nn.BatchNorm1d, nn.ConvTranspose1d): reverse2(fuse_convtranspose_bn),
- (nn.BatchNorm2d, nn.ConvTranspose2d): reverse2(fuse_convtranspose_bn),
- (nn.BatchNorm3d, nn.ConvTranspose3d): reverse2(fuse_convtranspose_bn),
+_DEFAULT_PATTERN_TO_FUSER_METHOD: Dict[Pattern, Union[nn.Sequential, Callable]] = {
+ (nn.BatchNorm1d, nn.Conv1d): _reverse2(fuse_conv_bn),
+ (nn.ReLU, (nn.BatchNorm1d, nn.Conv1d)): _reverse3(fuse_conv_bn_relu),
+ (nn.BatchNorm2d, nn.Conv2d): _reverse2(fuse_conv_bn),
+ (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)): _reverse3(fuse_conv_bn_relu),
+ (nn.BatchNorm3d, nn.Conv3d): _reverse2(fuse_conv_bn),
+ (nn.ReLU, (nn.BatchNorm3d, nn.Conv3d)): _reverse3(fuse_conv_bn_relu),
+ (nn.ReLU, nn.Conv1d): _reverse_sequential_wrapper2(nni.ConvReLU1d),
+ (nn.ReLU, nn.Conv2d): _reverse_sequential_wrapper2(nni.ConvReLU2d),
+ (nn.ReLU, nn.Conv3d): _reverse_sequential_wrapper2(nni.ConvReLU3d),
+ (nn.BatchNorm1d, nn.Linear): _reverse2(fuse_linear_bn),
+ (nn.ReLU, nn.Linear): _reverse_sequential_wrapper2(nni.LinearReLU),
+ (nn.ReLU, nn.BatchNorm2d): _reverse_sequential_wrapper2(nni.BNReLU2d),
+ (nn.ReLU, nn.BatchNorm3d): _reverse_sequential_wrapper2(nni.BNReLU3d),
+ (nn.BatchNorm1d, nn.ConvTranspose1d): _reverse2(fuse_convtranspose_bn),
+ (nn.BatchNorm2d, nn.ConvTranspose2d): _reverse2(fuse_convtranspose_bn),
+ (nn.BatchNorm3d, nn.ConvTranspose3d): _reverse2(fuse_convtranspose_bn),
}
-def get_valid_patterns(op_pattern):
+def _get_valid_patterns(op_pattern):
"""
Returns a list of valid patterns generated from the op_pattern,
since MatchAllNode can match all types of nodes,
@@ -261,7 +255,7 @@
if isinstance(op_pattern, (tuple, list)):
sub_combs = []
for sub_pattern in op_pattern:
- sub_combs.append(get_valid_patterns(sub_pattern))
+ sub_combs.append(_get_valid_patterns(sub_pattern))
result = list(itertools.product(*sub_combs))
else:
result = [op_pattern, MatchAllNode]
@@ -274,9 +268,9 @@
Would like to implement this first and have a separate PR for deprecation
"""
if fuser_method_mapping is None:
- fuser_method_mapping = DEFAULT_PATTERN_TO_FUSER_METHOD
+ fuser_method_mapping = _DEFAULT_PATTERN_TO_FUSER_METHOD
- op_patterns = get_valid_patterns(op_pattern)
+ op_patterns = _get_valid_patterns(op_pattern)
fuser_method = None
for op_pattern in op_patterns:
fuser_method = fuser_method_mapping.get(op_pattern, None)
diff --git a/torch/ao/quantization/fx/README.md b/torch/ao/quantization/fx/README.md
index cba11e9..622acd3 100644
--- a/torch/ao/quantization/fx/README.md
+++ b/torch/ao/quantization/fx/README.md
@@ -81,7 +81,7 @@
```
BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear))
- .set_fuser_method(reverse_sequential_wrapper2(nni.LinearReLU))
+ .set_fuser_method(_reverse_sequential_wrapper2(nni.LinearReLU))
._set_root_node_getter(my_root_node_getter)
._set_extra_inputs_getter(my_extra_inputs_getter)
```
diff --git a/torch/quantization/fuser_method_mappings.py b/torch/quantization/fuser_method_mappings.py
index 50520b3..22f4e63 100644
--- a/torch/quantization/fuser_method_mappings.py
+++ b/torch/quantization/fuser_method_mappings.py
@@ -10,6 +10,6 @@
fuse_conv_bn,
fuse_conv_bn_relu,
fuse_linear_bn,
- DEFAULT_OP_LIST_TO_FUSER_METHOD,
+ _DEFAULT_OP_LIST_TO_FUSER_METHOD,
get_fuser_method,
)