[Quant][fx][bc-breaking] Add simpler BackendConfig pattern format (#90698)
Summary: The existing BackendConfig fusion pattern
specification uses a "reversed nested tuple" format
that is highly unintuitive. For example:
```
linear-relu -> (nn.ReLU, nn.Linear)
conv-bn-relu -> (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d))
```
This pattern format also complicates the signatures
of the user-specified "fuser methods", which must
accept their arguments in reversed nested order to
match the patterns:
```
def fuse_linear_relu(is_qat, relu, linear):
    ...

def fuse_conv_bn_relu(is_qat, relu, bn_conv):
    (bn, conv) = bn_conv
    ...
```
Instead, this commit introduces a new pattern format that
simply specifies the ops in forward order with no nesting:
```
linear-relu -> (nn.Linear, nn.ReLU)
conv-bn-relu -> (nn.Conv2d, nn.BatchNorm2d, nn.ReLU)
def fuse_linear_relu(is_qat, linear, relu):
    ...

def fuse_conv_bn_relu(is_qat, conv, bn, relu):
    ...
```
Note that the legacy "reversed nested tuple" format
is still used internally since it is more general.
In the future, we should replace it with the format
used by the subgraph rewriter in `torch.fx`, and
simplify the existing pattern matching code to handle
the new format added in this commit.
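For reference, converting the new simple format to the
legacy internal format is a small transformation. The
sketch below mirrors the `_get_pattern_in_reversed_nested_tuple_format`
helper added in `torch/ao/quantization/backend_config/utils.py`
in this diff; the standalone name is hypothetical, and the
handling of configs that already use the complex format is omitted:
```
def _to_reversed_nested_tuple(pattern):
    # Hypothetical standalone helper, for illustration only.
    # Single ops are returned unchanged, e.g. nn.Linear -> nn.Linear
    if not isinstance(pattern, tuple):
        return pattern
    # (nn.Linear, nn.ReLU) -> (nn.ReLU, nn.Linear)
    if len(pattern) == 2:
        (a, b) = pattern
        return (b, a)
    # (nn.Conv2d, nn.BatchNorm2d, nn.ReLU) -> (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d))
    if len(pattern) == 3:
        (a, b, c) = pattern
        return (c, (b, a))
    raise ValueError("Expected a tuple with 2 or 3 elements, got: ", pattern)
```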
BC-breaking Notes:
Before:
```
import torch.nn as nn
import torch.ao.nn.intrinsic as nni
from torch.ao.quantization.backend_config import BackendPatternConfig

def fuse_conv_bn_relu(is_qat, relu, bn_conv):
    (bn, conv) = bn_conv
    return nni.ConvBnReLU2d(conv, bn, relu)

config = BackendPatternConfig((nn.ReLU, (nn.BatchNorm2d, nn.Conv2d))) \
    .set_dtype_configs(...) \
    .set_fuser_method(fuse_conv_bn_relu) \
    .set_fused_module(nni.ConvBnReLU2d)
```
After:
```
def fuse_conv_bn_relu(is_qat, conv, bn, relu):
    return nni.ConvBnReLU2d(conv, bn, relu)

config = BackendPatternConfig((nn.Conv2d, nn.BatchNorm2d, nn.ReLU)) \
    .set_dtype_configs(...) \
    .set_fuser_method(fuse_conv_bn_relu) \
    .set_fused_module(nni.ConvBnReLU2d)
```
OR (for backward compatibility):
```
def fuse_conv_bn_relu(is_qat, relu, bn_conv):
    (bn, conv) = bn_conv
    return nni.ConvBnReLU2d(conv, bn, relu)

config = BackendPatternConfig() \
    ._set_pattern_complex_format((nn.ReLU, (nn.BatchNorm2d, nn.Conv2d))) \
    .set_dtype_configs(...) \
    .set_fuser_method(fuse_conv_bn_relu) \
    .set_fused_module(nni.ConvBnReLU2d)
```
Before:
```
backend_config.configs # returns Dict[Pattern, BackendPatternConfig]
```
After:
```
backend_config.configs # returns List[BackendPatternConfig]
```
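Callers that previously looked up configs by pattern key
will need a small migration. A minimal sketch, assuming each
config's public `pattern` attribute is set (the `pattern_to_config`
name is illustrative; configs registered through the deprecated
complex format are skipped here):
```
# Rebuild a pattern-keyed mapping from the new list-valued `configs` property.
pattern_to_config = {}
for config in backend_config.configs:  # now a List[BackendPatternConfig]
    if config.pattern is not None:
        pattern_to_config[config.pattern] = config
```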
Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps
python test/test_quantization.py TestBackendConfig
Reviewers: jerryzh168, vkuzo
Subscribers: jerryzh168, vkuzo
Differential Revision: [D41954553](https://our.internmc.facebook.com/intern/diff/D41954553)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90698
Approved by: https://github.com/vkuzo, https://github.com/jerryzh168
diff --git a/test/quantization/core/test_backend_config.py b/test/quantization/core/test_backend_config.py
index 6cf8f3d..7f44809 100644
--- a/test/quantization/core/test_backend_config.py
+++ b/test/quantization/core/test_backend_config.py
@@ -13,7 +13,7 @@
DTypeWithConstraints,
ObservationType,
)
-from torch.ao.quantization.fuser_method_mappings import _reverse_sequential_wrapper2
+from torch.ao.quantization.fuser_method_mappings import _sequential_wrapper2
from torch.ao.quantization.fx.quantize_handler import _default_root_node_getter
@@ -104,7 +104,7 @@
# BackendPatternConfig
# ======================
- _fuser_method = _reverse_sequential_wrapper2(nni.LinearReLU)
+ _fuser_method = _sequential_wrapper2(nni.LinearReLU)
_num_tensor_args_to_observation_type = {
0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
@@ -121,7 +121,7 @@
return (torch.rand(3, 3),)
def _get_backend_op_config1(self):
- return BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear)) \
+ return BackendPatternConfig((torch.nn.Linear, torch.nn.ReLU)) \
.set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
.add_dtype_config(self.dtype_config1) \
.add_dtype_config(self.dtype_config2) \
@@ -142,7 +142,7 @@
def _get_backend_pattern_config_dict1(self):
return {
- "pattern": (torch.nn.ReLU, torch.nn.Linear),
+ "pattern": (torch.nn.Linear, torch.nn.ReLU),
"observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
"dtype_configs": [self.dtype_config_dict1, self.dtype_config_dict2],
"root_module": torch.nn.Linear,
@@ -198,19 +198,19 @@
self.assertEqual(conf.reference_quantized_module, nnqr.Linear)
def test_backend_op_config_set_fused_module(self):
- conf = BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear))
+ conf = BackendPatternConfig((torch.nn.Linear, torch.nn.ReLU))
self.assertTrue(conf.fused_module is None)
conf.set_fused_module(nni.LinearReLU)
self.assertEqual(conf.fused_module, nni.LinearReLU)
def test_backend_op_config_set_fuser_method(self):
- conf = BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear))
+ conf = BackendPatternConfig((torch.nn.Linear, torch.nn.ReLU))
self.assertTrue(conf.fuser_method is None)
conf.set_fuser_method(self._fuser_method)
self.assertEqual(conf.fuser_method, self._fuser_method)
def test_backend_op_config_set_root_node_getter(self):
- conf = BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear))
+ conf = BackendPatternConfig((torch.nn.Linear, torch.nn.ReLU))
self.assertTrue(conf._root_node_getter is None)
conf._set_root_node_getter(_default_root_node_getter)
self.assertEqual(conf._root_node_getter, _default_root_node_getter)
@@ -242,7 +242,7 @@
def test_backend_op_config_from_dict(self):
conf_dict1 = self._get_backend_pattern_config_dict1()
conf1 = BackendPatternConfig.from_dict(conf_dict1)
- self.assertEqual(conf1.pattern, (torch.nn.ReLU, torch.nn.Linear))
+ self.assertEqual(conf1.pattern, (torch.nn.Linear, torch.nn.ReLU))
self.assertEqual(conf1.observation_type, ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
self.assertEqual(conf1.root_module, torch.nn.Linear)
self.assertEqual(conf1.qat_module, nnqat.Linear)
@@ -294,11 +294,11 @@
backend_op_config1 = self._get_backend_op_config1()
backend_op_config2 = self._get_backend_op_config2()
conf.set_backend_pattern_config(backend_op_config1)
- self.assertEqual(conf.configs, {
+ self.assertEqual(conf._pattern_complex_format_to_config, {
(torch.nn.ReLU, torch.nn.Linear): backend_op_config1,
})
conf.set_backend_pattern_config(backend_op_config2)
- self.assertEqual(conf.configs, {
+ self.assertEqual(conf._pattern_complex_format_to_config, {
(torch.nn.ReLU, torch.nn.Linear): backend_op_config1,
torch.add: backend_op_config2
})
@@ -317,10 +317,10 @@
self.assertEqual(len(conf.configs), 2)
key1 = (torch.nn.ReLU, torch.nn.Linear)
key2 = torch.add
- self.assertTrue(key1 in conf.configs)
- self.assertTrue(key2 in conf.configs)
- self.assertEqual(conf.configs[key1].to_dict(), op_dict1)
- self.assertEqual(conf.configs[key2].to_dict(), op_dict2)
+ self.assertTrue(key1 in conf._pattern_complex_format_to_config)
+ self.assertTrue(key2 in conf._pattern_complex_format_to_config)
+ self.assertEqual(conf._pattern_complex_format_to_config[key1].to_dict(), op_dict1)
+ self.assertEqual(conf._pattern_complex_format_to_config[key2].to_dict(), op_dict2)
def test_backend_config_to_dict(self):
op1 = self._get_backend_op_config1()
diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py
index f00412a..df57225 100644
--- a/test/quantization/fx/test_quantize_fx.py
+++ b/test/quantization/fx/test_quantize_fx.py
@@ -546,9 +546,11 @@
bn, conv = bn_pattern
return conv
- conv_bn_res_relu_config1 = BackendPatternConfig((nn.ReLU, (torch.add, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d)))) \
+ conv_bn_res_relu_config1 = BackendPatternConfig() \
+ ._set_pattern_complex_format((nn.ReLU, (torch.add, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d)))) \
.set_fuser_method(fuse_conv_bn_relu)
- conv_bn_res_relu_config2 = BackendPatternConfig((nn.ReLU, (operator.add, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d)))) \
+ conv_bn_res_relu_config2 = BackendPatternConfig() \
+ ._set_pattern_complex_format((nn.ReLU, (operator.add, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d)))) \
.set_fuser_method(fuse_conv_bn_relu)
backend_config = BackendConfig() \
.set_backend_pattern_config(conv_bn_res_relu_config1) \
@@ -606,7 +608,8 @@
bn, conv = bn_pattern
return [extra_input]
- conv_bn_res_relu_config = BackendPatternConfig((nn.ReLU, (torch.add, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode))) \
+ conv_bn_res_relu_config = BackendPatternConfig() \
+ ._set_pattern_complex_format((nn.ReLU, (torch.add, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode))) \
.set_fuser_method(fuse_conv_bn_relu) \
._set_root_node_getter(conv_bn_res_relu_root_node_getter) \
._set_extra_inputs_getter(conv_bn_res_relu_extra_inputs_getter)
@@ -654,7 +657,7 @@
m = M().eval()
- def fuse_conv_relu(is_qat, relu, conv):
+ def fuse_conv_relu(is_qat, conv, relu):
return conv
def fuse_conv_res_relu(is_qat, relu, add_pattern):
@@ -669,9 +672,10 @@
relu, (_, _, extra_input) = pattern
return [extra_input]
- conv_relu_config = BackendPatternConfig((nn.ReLU, nn.Conv2d)) \
+ conv_relu_config = BackendPatternConfig((nn.Conv2d, nn.ReLU)) \
.set_fuser_method(fuse_conv_relu)
- conv_res_relu_config = BackendPatternConfig((nn.ReLU, (torch.add, nn.Conv2d, MatchAllNode))) \
+ conv_res_relu_config = BackendPatternConfig() \
+ ._set_pattern_complex_format((nn.ReLU, (torch.add, nn.Conv2d, MatchAllNode))) \
.set_fuser_method(fuse_conv_res_relu) \
._set_root_node_getter(conv_res_relu_root_node_getter) \
._set_extra_inputs_getter(conv_res_relu_extra_inputs_getter)
@@ -5545,10 +5549,12 @@
return transpose
backend_pattern_configs.append(
- BackendPatternConfig((torch.reshape, torch.transpose, MatchAllNode))
- .set_observation_type(observation_type) # noqa: E131
+ BackendPatternConfig()
+ ._set_pattern_complex_format((torch.reshape, torch.transpose, MatchAllNode)) # noqa: E131
+ .set_observation_type(observation_type)
.set_dtype_configs(dtype_configs)
- ._set_root_node_getter(root_node_getter))
+ ._set_root_node_getter(root_node_getter)
+ )
return backend_pattern_configs
backend_config = BackendConfig().set_backend_pattern_configs(_get_pattern_configs())
diff --git a/torch/ao/ns/fx/mappings.py b/torch/ao/ns/fx/mappings.py
index 321ec09..381bce2 100644
--- a/torch/ao/ns/fx/mappings.py
+++ b/torch/ao/ns/fx/mappings.py
@@ -13,18 +13,18 @@
import torch.nn.intrinsic as nni
import torch.ao.nn.qat as nnqat
import torch.ao.nn.qat.dynamic as nnqatd
-from torch.ao.quantization.backend_config import get_native_backend_config_dict
+from torch.ao.quantization.backend_config import get_native_backend_config
import torch.ao.quantization.fx._lower_to_native_backend as \
_lower_to_native_backend
import torch.ao.quantization.quantization_mappings as quantization_mappings
from .ns_types import NSNodeTargetType
-from typing import Set, Dict, List, Optional
+from typing import Callable, Dict, List, Optional, Set, Tuple
def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]:
- # note: this set is modified below by items from backend_config_dict
+ # note: this set is modified below by items from backend_config
sets_of_related_ops: List[Set[NSNodeTargetType]] = [
# conv modules
set([
@@ -327,42 +327,36 @@
]
# for each floating point op, add versions of the op added by
- # backend_config_dict
- backend_config_dict = get_native_backend_config_dict()
+ # backend_config
+ backend_config = get_native_backend_config()
- new_connections = [
+ new_connections: List[Tuple[Callable, Callable]] = [
# technical debt edge case
(nn.Linear, nn.modules.linear.NonDynamicallyQuantizableLinear),
]
- for config in backend_config_dict['configs']:
+ for pattern, config in backend_config._pattern_complex_format_to_config.items():
- if 'pattern' not in config:
- continue
-
- # format: (c, (b, a))
- pattern = config['pattern']
+ # pattern format: (c, (b, a))
first_element = pattern
# look from the end, because pattern is in reverse order
while isinstance(first_element, (list, tuple)):
first_element = first_element[-1]
- if 'fused_module' in config:
+ if config.fused_module is not None:
# case 1: pattern fuses a pattern of ops into an op
# example: nn.Conv1d, nn.ReLU fused into nni.ConvReLU1d
- new_connections.append((first_element, config['fused_module']))
+ new_connections.append((first_element, config.fused_module))
- if 'qat_module' in config:
+ if config.qat_module is not None:
# case 2: pattern swaps a module into a QAT module
# example: nni.ConvReLU1d swapped into nniqat.ConvReLU1d
- new_connections.append((first_element, config['qat_module']))
+ new_connections.append((first_element, config.qat_module))
- if 'reference_quantized_module_for_root' in config:
+ if config.reference_quantized_module is not None:
# case 3: reference version of floating point module, such as
# nn.Conv2d and nnqr.Conv2d
- new_connections.append(
- (first_element, config['reference_quantized_module_for_root'])
- )
+ new_connections.append((first_element, config.reference_quantized_module))
#
# Add reference module swaps from default lowering path
@@ -413,7 +407,7 @@
new_connections.append((source, target))
- # add the new connections from backend_config_dict
+ # add the new connections from backend_config
for item1, item2 in new_connections:
for set_of_related_ops in sets_of_related_ops:
if item1 in set_of_related_ops or item2 in set_of_related_ops:
diff --git a/torch/ao/quantization/backend_config/README.md b/torch/ao/quantization/backend_config/README.md
index 5d37fce..5e63af1 100644
--- a/torch/ao/quantization/backend_config/README.md
+++ b/torch/ao/quantization/backend_config/README.md
@@ -22,7 +22,19 @@
## Pattern Specification
-The operator patterns used in BackendConfig are float modules, functional operators and pytorch operators specified in reverse order:
+The operator patterns used in BackendConfig are float modules, functional operators, pytorch operators, or a tuple combination of the above. For example:
+* torch.nn.Linear
+* torch.nn.functional.linear
+* torch.add
+* operator.add
+* (torch.nn.functional.linear, torch.nn.functional.relu)
+* (torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU)
+
+Tuple patterns are treated as sequential patterns, and currently only tuples of 2 or 3 elements are supported.
+
+### Advanced Pattern Specification
+
+The above format should satisfy the vast majority of use cases. However, it does not handle more complex scenarios such as graph patterns. For these use cases, the BackendConfig API offers an alternative "reverse nested tuple" pattern format, enabled through `BackendPatternConfig()._set_pattern_complex_format(...)`. Note that this format is deprecated and will be replaced in a future version of PyTorch.
```
operator = module_type | functional | torch op | native op | MatchAllNode
Pattern = (operator, Pattern, Pattern, ...) | operator
@@ -62,7 +74,7 @@
weight_dtype=torch.qint8,
bias_dtype=torch.float)
-def fuse_conv2d_relu(is_qat, relu, conv):
+def fuse_conv2d_relu(is_qat, conv, relu):
"""Return a fused ConvReLU2d from individual conv and relu modules."""
return torch.ao.nn.intrinsic.ConvReLU2d(conv, relu)
@@ -75,7 +87,7 @@
.set_reference_quantized_module(torch.ao.nn.quantized.reference.Linear)
# For fusing Conv2d + ReLU into ConvReLU2d
-conv_relu_config = BackendPatternConfig((torch.nn.ReLU, torch.nn.Conv2d)) \
+conv_relu_config = BackendPatternConfig((torch.nn.Conv2d, torch.nn.ReLU)) \
.set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
.add_dtype_config(weighted_int8_dtype_config) \
.set_fused_module(torch.ao.nn.intrinsic.ConvReLU2d) \
@@ -118,7 +130,7 @@
* `_set_root_node_getter`
* `_set_extra_inputs_getter`
-As an optimization, operator patterns such as (`torch.nn.ReLU`, `torch.nn.Linear`) may be fused into `nni.LinearReLU`. This is performed during the prepare phase according to the function specified in `set_fuser_method`, which replaces the pattern with the fused module. During the convert phase, these fused modules (identified by `set_fused_module`) will then be converted to the reference quantized versions of the modules.
+As an optimization, operator patterns such as (`torch.nn.Linear`, `torch.nn.ReLU`) may be fused into `nni.LinearReLU`. This is performed during the prepare phase according to the function specified in `set_fuser_method`, which replaces the pattern with the fused module. During the convert phase, these fused modules (identified by `set_fused_module`) will then be converted to the reference quantized versions of the modules.
In FX graph mode quantization, we replace the corresponding nodes in the graph using two helper functions set by the user: `root_node_getter`, which returns the root node (typically the weighted module in the pattern like `torch.nn.Linear`) to replace the matched pattern in the graph, and `extra_inputs_getter`, which returns a list of extra input arguments that will be appended to the existing arguments of the fused module (copied over from the root node). See [this snippet](https://gist.github.com/jerryzh168/8bea7180a8ba3c279f2c9b050f2a69a6) for an example usage.
diff --git a/torch/ao/quantization/backend_config/_common_operator_config_utils.py b/torch/ao/quantization/backend_config/_common_operator_config_utils.py
index 3d95b8b..fa06c5a 100644
--- a/torch/ao/quantization/backend_config/_common_operator_config_utils.py
+++ b/torch/ao/quantization/backend_config/_common_operator_config_utils.py
@@ -16,9 +16,7 @@
ObservationType,
)
from ..fuser_method_mappings import (
- _reverse_sequential_wrapper2,
- _reverse2,
- _reverse3,
+ _sequential_wrapper2,
fuse_conv_bn,
fuse_conv_bn_relu,
fuse_linear_bn,
@@ -94,9 +92,9 @@
}
for op_with_quantized_bop_scalar_variant in [operator.add, torch.add, operator.mul, torch.mul]:
bop_patterns = [
- (torch.nn.ReLU, op_with_quantized_bop_scalar_variant),
- (torch.nn.functional.relu, op_with_quantized_bop_scalar_variant),
- (torch.relu, op_with_quantized_bop_scalar_variant),
+ (op_with_quantized_bop_scalar_variant, nn.ReLU),
+ (op_with_quantized_bop_scalar_variant, F.relu),
+ (op_with_quantized_bop_scalar_variant, torch.relu),
op_with_quantized_bop_scalar_variant
]
for bop_pattern in bop_patterns:
@@ -147,15 +145,15 @@
# 2.1 linear module + relu fusion config
# linear relu, linear module + relu module
linear_configs.append(
- BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear))
+ BackendPatternConfig((torch.nn.Linear, torch.nn.ReLU))
.set_dtype_configs(dtype_configs) # noqa: E131
- .set_fuser_method(_reverse_sequential_wrapper2(nni.LinearReLU))
+ .set_fuser_method(_sequential_wrapper2(nni.LinearReLU))
.set_fused_module(nni.LinearReLU))
# linear relu, linear module + functional relu
linear_configs.append(
- BackendPatternConfig((torch.nn.functional.relu, torch.nn.Linear))
+ BackendPatternConfig((torch.nn.Linear, torch.nn.functional.relu))
.set_dtype_configs(dtype_configs) # noqa: E131
- .set_fuser_method(_reverse_sequential_wrapper2(nni.LinearReLU))
+ .set_fuser_method(_sequential_wrapper2(nni.LinearReLU))
.set_fused_module(nni.LinearReLU))
# 2.2 linear module + relu, fused module configs
@@ -177,12 +175,12 @@
# 2.3 functional linear + relu configs
# linear relu, functional linear + relu module
linear_configs.append(
- BackendPatternConfig((torch.nn.ReLU, F.linear))
+ BackendPatternConfig((F.linear, torch.nn.ReLU))
.set_observation_type(observation_type) # noqa: E131
.set_dtype_configs(dtype_configs))
# linear relu, functional linear + functional relu
linear_configs.append(
- BackendPatternConfig((F.relu, F.linear))
+ BackendPatternConfig((F.linear, F.relu))
.set_observation_type(observation_type) # noqa: E131
.set_dtype_configs(dtype_configs))
@@ -190,9 +188,9 @@
# ------------------------
# 3.1 linear bn fusion
linear_configs.append(
- BackendPatternConfig((nn.BatchNorm1d, nn.Linear))
+ BackendPatternConfig((nn.Linear, nn.BatchNorm1d))
.set_dtype_configs(dtype_configs) # noqa: E131
- .set_fuser_method(_reverse2(fuse_linear_bn))
+ .set_fuser_method(fuse_linear_bn)
.set_fused_module(nni.LinearBn1d))
# 3.2 linear bn fused
@@ -250,15 +248,15 @@
# 2.1 conv module + relu fusion configs
# conv relu fusion, conv module + relu module
conv_configs.append(
- BackendPatternConfig((torch.nn.ReLU, convs.root))
+ BackendPatternConfig((convs.root, torch.nn.ReLU))
.set_dtype_configs(dtype_configs) # noqa: E131
- .set_fuser_method(_reverse_sequential_wrapper2(convs.fused_conv_relu))
+ .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu))
.set_fused_module(convs.fused_conv_relu))
# conv relu fusion, conv module + functional relu
conv_configs.append(
- BackendPatternConfig((F.relu, convs.root))
+ BackendPatternConfig((convs.root, F.relu))
.set_dtype_configs(dtype_configs) # noqa: E131
- .set_fuser_method(_reverse_sequential_wrapper2(convs.fused_conv_relu))
+ .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu))
.set_fused_module(convs.fused_conv_relu))
# 2.2 conv module + relu fused module configs
# conv relu, fused module
@@ -279,12 +277,12 @@
# 2.3 functional conv + relu configs
# conv relu, functional conv + relu module
conv_configs.append(
- BackendPatternConfig((torch.nn.ReLU, convs.func))
+ BackendPatternConfig((convs.func, torch.nn.ReLU))
.set_observation_type(observation_type) # noqa: E131
.set_dtype_configs(dtype_configs))
# conv relu, functional conv + functional relu
conv_configs.append(
- BackendPatternConfig((F.relu, convs.func))
+ BackendPatternConfig((convs.func, F.relu))
.set_observation_type(observation_type) # noqa: E131
.set_dtype_configs(dtype_configs))
@@ -305,22 +303,22 @@
# 3.1 conv bn fusion configs
# conv + bn fusion
conv_configs.append(
- BackendPatternConfig((convs.bn, convs.root))
+ BackendPatternConfig((convs.root, convs.bn))
.set_dtype_configs(dtype_configs) # noqa: E131
- .set_fuser_method(_reverse2(fuse_conv_bn))
+ .set_fuser_method(fuse_conv_bn)
.set_fused_module(convs.fused_conv_bn))
# conv + bn + relu module fusion
conv_configs.append(
- BackendPatternConfig((nn.ReLU, (convs.bn, convs.root)))
+ BackendPatternConfig((convs.root, convs.bn, nn.ReLU))
.set_dtype_configs(dtype_configs) # noqa: E131
- .set_fuser_method(_reverse3(fuse_conv_bn_relu))
+ .set_fuser_method(fuse_conv_bn_relu)
.set_fused_module(convs.fused_conv_bn_relu))
# conv + bn + relu functional fusion
conv_configs.append(
- BackendPatternConfig((F.relu, (convs.bn, convs.root)))
+ BackendPatternConfig((convs.root, convs.bn, F.relu))
.set_dtype_configs(dtype_configs) # noqa: E131
.set_root_module(convs.root)
- .set_fuser_method(_reverse3(fuse_conv_bn_relu))
+ .set_fuser_method(fuse_conv_bn_relu)
.set_fused_module(convs.fused_conv_bn_relu))
# TODO: we can add fusion for torch.relu as well
@@ -362,9 +360,9 @@
# 4.2 conv transpose + bn fusion
conv_configs.append(
- BackendPatternConfig((convs.bn, convs.transpose))
+ BackendPatternConfig((convs.transpose, convs.bn))
.set_dtype_configs(dtype_configs) # noqa: E131
- .set_fuser_method(_reverse2(fuse_convtranspose_bn))
+ .set_fuser_method(fuse_convtranspose_bn)
.set_root_module(convs.transpose)
.set_reference_quantized_module(convs.transpose_reference))
@@ -553,15 +551,15 @@
fused_bn = bn_to_fused_bn[bn]
# bn module + relu module fusion config
bn_configs.append(
- BackendPatternConfig((torch.nn.ReLU, bn))
+ BackendPatternConfig((bn, nn.ReLU))
.set_dtype_configs(dtype_configs) # noqa: E131
- .set_fuser_method(_reverse_sequential_wrapper2(fused_bn))
+ .set_fuser_method(_sequential_wrapper2(fused_bn))
.set_fused_module(fused_bn))
# bn module + F.relu fusion config
bn_configs.append(
- BackendPatternConfig((torch.nn.functional.relu, bn))
+ BackendPatternConfig((bn, F.relu))
.set_dtype_configs(dtype_configs) # noqa: E131
- .set_fuser_method(_reverse_sequential_wrapper2(bn_to_fused_bn[bn]))
+ .set_fuser_method(_sequential_wrapper2(fused_bn))
.set_fused_module(fused_bn))
bn_configs.append(
BackendPatternConfig(bn)
diff --git a/torch/ao/quantization/backend_config/backend_config.py b/torch/ao/quantization/backend_config/backend_config.py
index 4b3d4d3..3ec05fe 100644
--- a/torch/ao/quantization/backend_config/backend_config.py
+++ b/torch/ao/quantization/backend_config/backend_config.py
@@ -29,6 +29,7 @@
# BackendPatternConfig dict keys
PATTERN_DICT_KEY = "pattern"
+PATTERN_COMPLEX_FORMAT_DICT_KEY = "pattern_complex_format"
OBSERVATION_TYPE_DICT_KEY = "observation_type"
DTYPE_CONFIGS_DICT_KEY = "dtype_configs"
ROOT_MODULE_DICT_KEY = "root_module"
@@ -241,7 +242,7 @@
weight_dtype=torch.qint8,
bias_dtype=torch.float)
- def fuse_conv2d_relu(is_qat, relu, conv):
+ def fuse_conv2d_relu(is_qat, conv, relu):
return torch.ao.nn.intrinsic.ConvReLU2d(conv, relu)
# For quantizing Linear
@@ -253,7 +254,7 @@
.set_reference_quantized_module(torch.ao.nn.quantized.reference.Linear)
# For fusing Conv2d + ReLU into ConvReLU2d
- conv_relu_config = BackendPatternConfig((torch.nn.ReLU, torch.nn.Conv2d)) \
+ conv_relu_config = BackendPatternConfig((torch.nn.Conv2d, torch.nn.ReLU)) \
.set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
.add_dtype_config(weighted_int8_dtype_config) \
.set_fused_module(torch.ao.nn.intrinsic.ConvReLU2d) \
@@ -275,7 +276,11 @@
"""
def __init__(self, name: str = ""):
self.name = name
- self.configs: Dict[Pattern, BackendPatternConfig] = {}
+ # Store all BackendPatternConfigs in a map to handle duplicates
+ # Note: the key in this map uses the complex reversed tuple format.
+ # This is intended only for internal use; users who wish to access
+ # the original patterns should go through `self.configs` instead.
+ self._pattern_complex_format_to_config: Dict[Pattern, BackendPatternConfig] = {}
def set_name(self, name: str) -> BackendConfig:
"""
@@ -289,7 +294,10 @@
Set the config for an pattern that can be run on the target backend.
This overrides any existing config for the given pattern.
"""
- self.configs[config.pattern] = config
+ # Avoid circular dependencies
+ pattern_complex_format = torch.ao.quantization.backend_config.utils \
+ ._get_pattern_in_reversed_nested_tuple_format(config) # type: ignore[attr-defined]
+ self._pattern_complex_format_to_config[pattern_complex_format] = config
return self
def set_backend_pattern_configs(self, configs: List[BackendPatternConfig]) -> BackendConfig:
@@ -301,6 +309,13 @@
self.set_backend_pattern_config(conf)
return self
+ @property
+ def configs(self) -> List[BackendPatternConfig]:
+ """
+ Return a copy of the list of configs set in this `BackendConfig`.
+ """
+ return list(self._pattern_complex_format_to_config.values())
+
@classmethod
def from_dict(cls, backend_config_dict: Dict[str, Any]) -> BackendConfig:
"""
@@ -328,7 +343,7 @@
"""
return {
NAME_DICT_KEY: self.name,
- CONFIGS_DICT_KEY: [c.to_dict() for c in self.configs.values()],
+ CONFIGS_DICT_KEY: [c.to_dict() for c in self.configs],
}
@@ -338,8 +353,8 @@
For a detailed example usage, see :class:`~torch.ao.quantization.backend_config.BackendConfig`.
"""
- def __init__(self, pattern: Pattern):
- self.pattern = pattern
+ def __init__(self, pattern: Optional[Pattern] = None):
+ self.pattern: Optional[Pattern] = pattern
self.observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
self.dtype_configs: List[DTypeConfig] = []
self.root_module: Optional[Type[torch.nn.Module]] = None
@@ -354,6 +369,20 @@
self._num_tensor_args_to_observation_type: Dict[int, ObservationType] = {}
self._input_type_to_index: Dict[str, int] = {}
self._input_output_observed: Optional[bool] = None
+ self._pattern_complex_format: Optional[Pattern] = None
+
+ def set_pattern(self, pattern: Pattern) -> BackendPatternConfig:
+ """
+ Set the pattern to configure.
+
+ The pattern can be a float module, functional operator, pytorch operator, or a tuple
+ combination of the above. Tuple patterns are treated as sequential patterns, and
+ currently only tuples of 2 or 3 elements are supported.
+ """
+ if self._pattern_complex_format is not None:
+ raise ValueError("Only one of 'pattern' or 'pattern_complex_format' can be set")
+ self.pattern = pattern
+ return self
def set_observation_type(self, observation_type: ObservationType) -> BackendPatternConfig:
"""
@@ -421,9 +450,13 @@
Set the function that specifies how to fuse the pattern for this pattern.
The first argument of this function should be `is_qat`, and the rest of the arguments
- should be the items in the tuple pattern, e.g. (`torch.nn.ReLU`, `torch.nn.Linear`)
- will have a function with three arguments, `is_qat`, `relu`, and `linear`.
- The return value of this function should be the resulting fused module.
+ should be the items in the tuple pattern. The return value of this function should be
+ the resulting fused module.
+
+ For example, the fuser method for the pattern `(torch.nn.Linear, torch.nn.ReLU)` can be:
+
+ def fuse_linear_relu(is_qat, linear, relu):
+ return torch.ao.nn.intrinsic.LinearReLU(linear, relu)
"""
self.fuser_method = fuser_method
return self
@@ -449,6 +482,18 @@
self._input_output_observed = input_output_observed
return self
+ def _set_pattern_complex_format(self, pattern: Pattern) -> BackendPatternConfig:
+ """
+ Set the pattern to configure, using the reversed nested tuple format.
+
+ See the BackendConfig README for more detail:
+ https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/backend_config/README.md#advanced-pattern-specification
+ """
+ if self.pattern is not None:
+ raise ValueError("Only one of 'pattern' or 'pattern_complex_format' can be set")
+ self._pattern_complex_format = pattern
+ return self
+
@classmethod
def from_dict(cls, backend_pattern_config_dict: Dict[str, Any]) -> BackendPatternConfig:
"""
@@ -464,6 +509,7 @@
implementation for this pattern's root module.
"fused_module": a :class:`torch.nn.Module` that represents the fused implementation for this pattern
"fuser_method": a function that specifies how to fuse the pattern for this pattern
+ "pattern_complex_format": the pattern specified in the reversed nested tuple format (deprecated)
"""
def _get_dtype_config(obj: Any) -> DTypeConfig:
@@ -477,9 +523,9 @@
raise ValueError("Expected a list of DTypeConfigs in backend_pattern_config_dict[\"%s\"], got '%s'" %
(DTYPE_CONFIGS_DICT_KEY, type(obj)))
- if PATTERN_DICT_KEY not in backend_pattern_config_dict:
- raise ValueError("backend_pattern_config_dict must contain '%s'" % PATTERN_DICT_KEY)
- conf = cls(backend_pattern_config_dict[PATTERN_DICT_KEY])
+ conf = cls()
+ if PATTERN_DICT_KEY in backend_pattern_config_dict:
+ conf.set_pattern(backend_pattern_config_dict[PATTERN_DICT_KEY])
if OBSERVATION_TYPE_DICT_KEY in backend_pattern_config_dict:
conf.set_observation_type(backend_pattern_config_dict[OBSERVATION_TYPE_DICT_KEY])
for d in backend_pattern_config_dict.get(DTYPE_CONFIGS_DICT_KEY, []):
@@ -495,6 +541,8 @@
backend_pattern_config_dict.get(NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY, {}))
conf._set_input_type_to_index(backend_pattern_config_dict.get(INPUT_TYPE_TO_INDEX_DICT_KEY, {}))
conf._set_input_output_observed(backend_pattern_config_dict.get(INPUT_OUTPUT_OBSERVED_DICT_KEY, None))
+ if PATTERN_COMPLEX_FORMAT_DICT_KEY in backend_pattern_config_dict:
+ conf._set_pattern_complex_format(backend_pattern_config_dict[PATTERN_COMPLEX_FORMAT_DICT_KEY])
return conf
def to_dict(self) -> Dict[str, Any]:
@@ -503,10 +551,11 @@
:func:`~torch.ao.quantization.backend_config.BackendPatternConfig.from_dict`.
"""
backend_pattern_config_dict: Dict[str, Any] = {
- PATTERN_DICT_KEY: self.pattern,
OBSERVATION_TYPE_DICT_KEY: self.observation_type,
DTYPE_CONFIGS_DICT_KEY: [c.to_dict() for c in self.dtype_configs],
}
+ if self.pattern is not None:
+ backend_pattern_config_dict[PATTERN_DICT_KEY] = self.pattern
if self.root_module is not None:
backend_pattern_config_dict[ROOT_MODULE_DICT_KEY] = self.root_module
if self.qat_module is not None:
@@ -527,4 +576,6 @@
backend_pattern_config_dict[INPUT_TYPE_TO_INDEX_DICT_KEY] = self._input_type_to_index
if self._input_output_observed is not None:
backend_pattern_config_dict[INPUT_OUTPUT_OBSERVED_DICT_KEY] = self._input_output_observed
+ if self._pattern_complex_format is not None:
+ backend_pattern_config_dict[PATTERN_COMPLEX_FORMAT_DICT_KEY] = self._pattern_complex_format
return backend_pattern_config_dict
diff --git a/torch/ao/quantization/backend_config/executorch.py b/torch/ao/quantization/backend_config/executorch.py
index fcccec6..f497e38 100644
--- a/torch/ao/quantization/backend_config/executorch.py
+++ b/torch/ao/quantization/backend_config/executorch.py
@@ -16,7 +16,7 @@
qnnpack_default_op_qint8_symmetric_dtype_config
)
from ._common_operator_config_utils import _Conv2dMetadata
-from ..fuser_method_mappings import _reverse_sequential_wrapper2
+from ..fuser_method_mappings import _sequential_wrapper2
__all__ = [
@@ -115,15 +115,15 @@
._set_input_type_to_index({"weight": 1, "bias": 2}))
# conv module + relu module
conv_configs.append(
- BackendPatternConfig((torch.nn.ReLU, convs.root))
+ BackendPatternConfig((convs.root, nn.ReLU))
.set_dtype_configs(dtype_configs) # noqa: E131
- .set_fuser_method(_reverse_sequential_wrapper2(convs.fused_conv_relu))
+ .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu))
.set_fused_module(convs.fused_conv_relu))
# conv module + functional relu
conv_configs.append(
- BackendPatternConfig((F.relu, convs.root))
+ BackendPatternConfig((convs.root, F.relu))
.set_dtype_configs(dtype_configs) # noqa: E131
- .set_fuser_method(_reverse_sequential_wrapper2(convs.fused_conv_relu))
+ .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu))
.set_fused_module(convs.fused_conv_relu))
# fused conv relu module
conv_configs.append(
@@ -135,12 +135,12 @@
.set_qat_module(convs.relu_qat))
# functional conv + relu module
conv_configs.append(
- BackendPatternConfig((torch.nn.ReLU, convs.func))
+ BackendPatternConfig((convs.func, nn.ReLU))
.set_observation_type(observation_type) # noqa: E131
.set_dtype_configs(dtype_configs))
# functional conv + functional relu
conv_configs.append(
- BackendPatternConfig((F.relu, convs.func))
+ BackendPatternConfig((convs.func, F.relu))
.set_observation_type(observation_type) # noqa: E131
.set_dtype_configs(dtype_configs))
return conv_configs
diff --git a/torch/ao/quantization/backend_config/utils.py b/torch/ao/quantization/backend_config/utils.py
index fc7e9ac..2e73822 100644
--- a/torch/ao/quantization/backend_config/utils.py
+++ b/torch/ao/quantization/backend_config/utils.py
@@ -3,8 +3,16 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from .backend_config import BackendConfig, DTypeConfig
+from .backend_config import (
+ BackendConfig,
+ BackendPatternConfig,
+ DTypeConfig,
+)
from ..utils import Pattern
+from ..fuser_method_mappings import (
+ _reverse2,
+ _reverse3,
+)
__all__ = [
"get_pattern_to_dtype_configs",
@@ -23,48 +31,52 @@
def get_pattern_to_dtype_configs(backend_config: BackendConfig) -> Dict[Pattern, List[DTypeConfig]]:
pattern_to_dtype_configs: Dict[Pattern, List[DTypeConfig]] = {}
- for pattern, config in backend_config.configs.items():
+ for pattern, config in backend_config._pattern_complex_format_to_config.items():
pattern_to_dtype_configs[pattern] = config.dtype_configs
return pattern_to_dtype_configs
def get_qat_module_classes(backend_config: BackendConfig) -> Tuple[type, ...]:
qat_module_classes = []
- for config in backend_config.configs.values():
+ for config in backend_config.configs:
if config.qat_module is not None:
qat_module_classes.append(config.qat_module)
return tuple(set(qat_module_classes))
def get_fused_module_classes(backend_config: BackendConfig) -> Tuple[type, ...]:
fused_module_classes = []
- for config in backend_config.configs.values():
+ for config in backend_config.configs:
if config.fused_module is not None:
fused_module_classes.append(config.fused_module)
return tuple(set(fused_module_classes))
def get_pattern_to_input_type_to_index(backend_config: BackendConfig) -> Dict[Pattern, Dict[str, int]]:
pattern_to_input_type_to_index: Dict[Pattern, Dict[str, int]] = {}
- for pattern, config in backend_config.configs.items():
+ for pattern, config in backend_config._pattern_complex_format_to_config.items():
pattern_to_input_type_to_index[pattern] = config._input_type_to_index
return pattern_to_input_type_to_index
def get_root_module_to_quantized_reference_module(
backend_config: BackendConfig) -> Dict[Type[torch.nn.Module], Type[torch.nn.Module]]:
mapping: Dict[Type[torch.nn.Module], Type[torch.nn.Module]] = {}
- for config in backend_config.configs.values():
+ for config in backend_config.configs:
if config.root_module is not None and config.reference_quantized_module is not None:
mapping[config.root_module] = config.reference_quantized_module
return mapping
def get_fuser_method_mapping(backend_config: BackendConfig) -> Dict[Pattern, Union[nn.Sequential, Callable]]:
fuser_method_mapping : Dict[Pattern, Union[nn.Sequential, Callable]] = {}
- for pattern, config in backend_config.configs.items():
+ for pattern, config in backend_config._pattern_complex_format_to_config.items():
if config.fuser_method is not None:
- fuser_method_mapping[pattern] = config.fuser_method
+ # Note: both the fuser method and the pattern are specified in forward order in the
+ # BackendConfig, but the internal pattern matching code uses the reversed nested tuple
+ # format, so we need to convert both to the internal format
+ fuser_method = _get_fuser_method_in_reversed_nested_tuple_format(config)
+ fuser_method_mapping[pattern] = fuser_method
return fuser_method_mapping
def get_module_to_qat_module(backend_config: BackendConfig) -> Dict[Pattern, Type[torch.nn.Module]]:
module_to_qat_module: Dict[Pattern, Type[torch.nn.Module]] = {}
- for pattern, config in backend_config.configs.items():
+ for pattern, config in backend_config._pattern_complex_format_to_config.items():
if config.qat_module is not None:
module_to_qat_module[pattern] = config.qat_module
return module_to_qat_module
@@ -80,7 +92,7 @@
e.g. (torch.add, MatchAllNode, (torch.ReLU, torch.Conv2d))
"""
root_node_getter_mapping: Dict[Pattern, Callable] = {}
- for pattern, config in backend_config.configs.items():
+ for pattern, config in backend_config._pattern_complex_format_to_config.items():
if config._root_node_getter is not None:
root_node_getter_mapping[pattern] = config._root_node_getter
return root_node_getter_mapping
@@ -100,7 +112,7 @@
return [extra_input]
"""
extra_inputs_getter_mapping: Dict[Pattern, Callable] = {}
- for pattern, config in backend_config.configs.items():
+ for pattern, config in backend_config._pattern_complex_format_to_config.items():
if config._extra_inputs_getter is not None:
extra_inputs_getter_mapping[pattern] = config._extra_inputs_getter
return extra_inputs_getter_mapping
@@ -188,3 +200,80 @@
s += "}"
return s
+
+def _get_pattern_in_reversed_nested_tuple_format(config: BackendPatternConfig) -> Pattern:
+ """
+ Return the pattern specified in the given config in the reversed nested tuple format
+ used internally in the quantization pattern matching code.
+
+ If the pattern is not a tuple, or the pattern is already specified in the reversed
+ nested tuple format, return the pattern as is. Otherwise:
+
+ For 2-tuples (a, b), return (b, a).
+ For 3-tuples (a, b, c), return (c, (b, a)).
+
+ For example:
+ * Given nn.Linear, return nn.Linear
+ * Given (nn.Linear, nn.ReLU), return (nn.ReLU, nn.Linear)
+ * Given (nn.Conv2d, nn.BatchNorm2d, nn.ReLU), return
+ (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d))
+
+ For context, the reason why this is needed is the user-facing BackendConfig
+ API accepts the flat 2-or-3-tuple format in forward order. While this simple
+ format handles the vast majority of use cases, it does not handle the more
+ complex ones, and so the internal pattern matching code for quantization uses
+ the following, more general reversed nested tuple format instead:
+
+ operator = module_type | functional | torch op | native op | MatchAllNode
+ Pattern = (operator, Pattern, Pattern, ...) | operator
+
+ In the future, we expect to replace the above complex format with the one used
+ by the subgraph rewriter in torch.fx, so we don't have to maintain our own
+ complex pattern matching code. Then we won't need this helper function anymore.
+ """
+ if config._pattern_complex_format is not None:
+ return config._pattern_complex_format
+ if config.pattern is None:
+ raise ValueError("Either 'pattern' or 'pattern_complex_format' must be specified")
+ if not isinstance(config.pattern, tuple):
+ return config.pattern
+
+ # Pattern is specified in the simple tuple format, need to convert
+ if len(config.pattern) == 2:
+ (a, b) = config.pattern
+ return (b, a)
+ elif len(config.pattern) == 3:
+ (a, b, c) = config.pattern
+ return (c, (b, a))
+ else:
+ raise ValueError("Expected a tuple with 2 or 3 elements, got: ", config.pattern)
+
+def _get_fuser_method_in_reversed_nested_tuple_format(config: BackendPatternConfig) -> Callable:
+ """
+ Return the fuser method specified in the given config in the reversed nested
+ tuple format used internally in the quantization pattern matching code.
+
+ If pattern is specified in the reversed nested tuple format, we assume the
+ fuser method is also specified in this format and simply return it as is.
+ Otherwise, we convert the fuser method as follows:
+
+ * Given f(is_qat, conv, relu), return f'(is_qat, relu, conv)
+ * Given f(is_qat, conv, bn, relu), return f'(is_qat, relu, bn_conv),
+ where bn_conv is a 2-tuple (bn, conv)
+
+ The first argument of a fuser method is always `is_qat` and is not affected
+ in the conversion. We currently only support functions with 3 or 4 arguments.
+ """
+ assert config.fuser_method is not None
+ if config._pattern_complex_format is not None:
+ return config.fuser_method
+ if not isinstance(config.pattern, tuple):
+ raise ValueError("Expected pattern to be a tuple, got: ", config.pattern)
+
+ # Pattern is specified in the simple tuple format, need to convert
+ if len(config.pattern) == 2:
+ return _reverse2(config.fuser_method)
+ elif len(config.pattern) == 3:
+ return _reverse3(config.fuser_method)
+ else:
+ raise ValueError("Expected a tuple with 2 or 3 elements, got: ", config.pattern)
diff --git a/torch/ao/quantization/fuser_method_mappings.py b/torch/ao/quantization/fuser_method_mappings.py
index db4cc9a..9d6455d 100644
--- a/torch/ao/quantization/fuser_method_mappings.py
+++ b/torch/ao/quantization/fuser_method_mappings.py
@@ -190,16 +190,6 @@
assert fuser_method is not None, "did not find fuser method for: {} ".format(op_list)
return fuser_method
-def _reverse_sequential_wrapper2(sequential):
- """ Given a sequential class for two modules, return a function that takes
- is_qat, and then two modules as argument, that ignores the is_qat flag
- and always returns the sequential that combines the two input modules, with
- the order of two inputs reversed
- """
- def fuser_method(is_qat, m1, m2):
- return sequential(m2, m1)
- return fuser_method
-
def _reverse2(f):
def reversed(is_qat, x, y):
return f(is_qat, y, x)
@@ -211,25 +201,6 @@
return f(is_qat, z, y, x)
return reversed
-_DEFAULT_PATTERN_TO_FUSER_METHOD: Dict[Pattern, Union[nn.Sequential, Callable]] = {
- (nn.BatchNorm1d, nn.Conv1d): _reverse2(fuse_conv_bn),
- (nn.ReLU, (nn.BatchNorm1d, nn.Conv1d)): _reverse3(fuse_conv_bn_relu),
- (nn.BatchNorm2d, nn.Conv2d): _reverse2(fuse_conv_bn),
- (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)): _reverse3(fuse_conv_bn_relu),
- (nn.BatchNorm3d, nn.Conv3d): _reverse2(fuse_conv_bn),
- (nn.ReLU, (nn.BatchNorm3d, nn.Conv3d)): _reverse3(fuse_conv_bn_relu),
- (nn.ReLU, nn.Conv1d): _reverse_sequential_wrapper2(nni.ConvReLU1d),
- (nn.ReLU, nn.Conv2d): _reverse_sequential_wrapper2(nni.ConvReLU2d),
- (nn.ReLU, nn.Conv3d): _reverse_sequential_wrapper2(nni.ConvReLU3d),
- (nn.BatchNorm1d, nn.Linear): _reverse2(fuse_linear_bn),
- (nn.ReLU, nn.Linear): _reverse_sequential_wrapper2(nni.LinearReLU),
- (nn.ReLU, nn.BatchNorm2d): _reverse_sequential_wrapper2(nni.BNReLU2d),
- (nn.ReLU, nn.BatchNorm3d): _reverse_sequential_wrapper2(nni.BNReLU3d),
- (nn.BatchNorm1d, nn.ConvTranspose1d): _reverse2(fuse_convtranspose_bn),
- (nn.BatchNorm2d, nn.ConvTranspose2d): _reverse2(fuse_convtranspose_bn),
- (nn.BatchNorm3d, nn.ConvTranspose3d): _reverse2(fuse_convtranspose_bn),
-}
-
def _get_valid_patterns(op_pattern):
"""
Returns a list of valid patterns generated from the op_pattern,
@@ -263,13 +234,10 @@
def get_fuser_method_new(
op_pattern: Pattern,
- fuser_method_mapping: Optional[Dict[Pattern, Union[nn.Sequential, Callable]]] = None):
- """ This will be made defult after we deparate the get_fuser_method
+ fuser_method_mapping: Dict[Pattern, Union[nn.Sequential, Callable]]):
+ """ This will be made defult after we deprecate the get_fuser_method
Would like to implement this first and have a separate PR for deprecation
"""
- if fuser_method_mapping is None:
- fuser_method_mapping = _DEFAULT_PATTERN_TO_FUSER_METHOD
-
op_patterns = _get_valid_patterns(op_pattern)
fuser_method = None
for op_pattern in op_patterns:
diff --git a/torch/ao/quantization/fx/README.md b/torch/ao/quantization/fx/README.md
index 7816247..08f69ed 100644
--- a/torch/ao/quantization/fx/README.md
+++ b/torch/ao/quantization/fx/README.md
@@ -80,8 +80,11 @@
`backend_config` configurations relevant to this step are:
```
-BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear))
- .set_fuser_method(_reverse_sequential_wrapper2(nni.LinearReLU))
+def fuse_linear_relu(is_qat, linear, relu):
+ return nni.LinearReLU(linear, relu)
+
+BackendPatternConfig((torch.nn.Linear, torch.nn.ReLU))
+ .set_fuser_method(fuse_linear_relu)
._set_root_node_getter(my_root_node_getter)
._set_extra_inputs_getter(my_extra_inputs_getter)
```
diff --git a/torch/ao/quantization/fx/fuse.py b/torch/ao/quantization/fx/fuse.py
index 6eaaff5..241803f 100644
--- a/torch/ao/quantization/fx/fuse.py
+++ b/torch/ao/quantization/fx/fuse.py
@@ -123,8 +123,9 @@
return model
def _find_matches(
- root: GraphModule, graph: Graph,
- patterns: Dict[Pattern, Callable]
+ root: GraphModule,
+ graph: Graph,
+ pattern_to_fuse_handler_cls: Dict[Pattern, Callable],
) -> Dict[str, Tuple[Node, Pattern, NodePattern, FuseHandler, Dict[Node, Any]]]:
modules = dict(root.named_modules())
# node name -> (root_node, match_value)
@@ -155,10 +156,10 @@
for node in reversed(graph.nodes):
if node.name not in match_map:
- for pattern, value in patterns.items():
+ for pattern, fuse_handler_cls in pattern_to_fuse_handler_cls.items():
matched_node_pattern: List[Node] = []
if _is_match(modules, node, pattern):
- apply_match(pattern, node, (node, pattern, value(node)), matched_node_pattern, node_to_subpattern)
+ apply_match(pattern, node, (node, pattern, fuse_handler_cls(node)), matched_node_pattern, node_to_subpattern)
break
return match_map
diff --git a/torch/ao/quantization/fx/fuse_handler.py b/torch/ao/quantization/fx/fuse_handler.py
index 2106dc4..2706f96 100644
--- a/torch/ao/quantization/fx/fuse_handler.py
+++ b/torch/ao/quantization/fx/fuse_handler.py
@@ -4,7 +4,7 @@
from ..utils import _parent_name, NodePattern, Pattern
from ..fuser_method_mappings import get_fuser_method_new
from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, Optional, Union, List
+from typing import Any, Callable, Dict, List, Union
from .custom_config import FuseCustomConfig
from .match_utils import MatchAllNode
from torch.nn.utils.parametrize import type_before_parametrizations
@@ -35,7 +35,7 @@
extra_inputs: List[Any],
matched_node_pattern: NodePattern,
fuse_custom_config: FuseCustomConfig,
- fuser_method_mapping: Optional[Dict[Pattern, Union[torch.nn.Sequential, Callable]]],
+ fuser_method_mapping: Dict[Pattern, Union[torch.nn.Sequential, Callable]],
is_qat: bool) -> Node:
pass
@@ -53,7 +53,7 @@
extra_inputs: List[Any],
matched_node_pattern: NodePattern,
fuse_custom_config: FuseCustomConfig,
- fuser_method_mapping: Optional[Dict[Pattern, Union[torch.nn.Sequential, Callable]]],
+ fuser_method_mapping: Dict[Pattern, Union[torch.nn.Sequential, Callable]],
is_qat: bool) -> Node:
assert root_node.op == "call_module", "Expecting module node to be a call_module Node"
root_module = named_modules[str(root_node.target)]
@@ -112,7 +112,7 @@
def _get_fusion_pattern_to_fuse_handler_cls(
backend_config: BackendConfig) -> Dict[Pattern, Callable]:
fusion_pattern_to_fuse_handlers: Dict[Pattern, Callable] = {}
- for pattern, config in backend_config.configs.items():
+ for pattern, config in backend_config._pattern_complex_format_to_config.items():
if config.fuser_method is not None:
# TODO: is this logic right?
fusion_pattern_to_fuse_handlers[pattern] = DefaultFuseHandler
diff --git a/torch/ao/quantization/fx/quantize_handler.py b/torch/ao/quantization/fx/quantize_handler.py
index 8670eee..473cc0d9 100644
--- a/torch/ao/quantization/fx/quantize_handler.py
+++ b/torch/ao/quantization/fx/quantize_handler.py
@@ -152,7 +152,7 @@
new path, this is not exposed to backend developers
"""
pattern_to_quantize_handlers = {}
- for pattern, config in backend_config.configs.items():
+ for pattern, config in backend_config._pattern_complex_format_to_config.items():
observation_type = config.observation_type
dtype_configs = config.dtype_configs
num_tensor_args_to_observation_type = config._num_tensor_args_to_observation_type
diff --git a/torch/ao/quantization/fx/utils.py b/torch/ao/quantization/fx/utils.py
index 242e189..c215649 100644
--- a/torch/ao/quantization/fx/utils.py
+++ b/torch/ao/quantization/fx/utils.py
@@ -70,8 +70,9 @@
def node_arg_is_weight(node: Node, arg: Any, backend_config: BackendConfig) -> bool:
"""Returns if node arg is weight"""
- if isinstance(node, Node) and node.op == "call_function" and node.target in backend_config.configs:
- weight_index = backend_config.configs[node.target]._input_type_to_index.get("weight")
+ if isinstance(node, Node) and node.op == "call_function" and \
+ node.target in backend_config._pattern_complex_format_to_config:
+ weight_index = backend_config._pattern_complex_format_to_config[node.target]._input_type_to_index.get("weight")
if weight_index is not None and weight_index < len(node.args) and node.args[weight_index] is arg:
return True
return node.kwargs.get("weight") is arg
@@ -79,8 +80,9 @@
def node_arg_is_bias(node: Node, arg: Any, backend_config: BackendConfig) -> bool:
"""Returns if node arg is bias"""
- if isinstance(node, Node) and node.op == "call_function" and node.target in backend_config.configs:
- bias_index = backend_config.configs[node.target]._input_type_to_index.get("bias")
+ if isinstance(node, Node) and node.op == "call_function" and \
+ node.target in backend_config._pattern_complex_format_to_config:
+ bias_index = backend_config._pattern_complex_format_to_config[node.target]._input_type_to_index.get("bias")
if bias_index is not None and bias_index < len(node.args) and node.args[bias_index] is arg:
return True
return node.kwargs.get("bias") is arg