[quant][fx] Move all binary op configs to backend_config_dict (#75241)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75241
A previous PR enabled operator.add in backend_config_dict; this PR moves the
rest of the binary ops (torch.add, operator.mul, torch.mul, and their fused
ReLU variants) to backend_config_dict.
Some ops that are not needed (previously fp16 ops) are left in place for now;
we will move them in a follow-up PR.
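For reference, each entry produced by the new _get_binary_op_configs() helper
has roughly the following shape (illustrative sketch only, mirroring the
native.py diff below; ObservationType and weighted_op_int8_dtype_config are
the existing definitions in that file):

    # One generated entry; the same dict is emitted for operator.add, torch.add,
    # operator.mul and torch.mul, both standalone and fused with
    # torch.nn.ReLU / F.relu / torch.relu.
    {
        "pattern": (torch.nn.ReLU, torch.add),
        "num_tensor_args_to_observation_type": {
            0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
            1: ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT,
            2: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
        },
        "dtype_configs": [weighted_op_int8_dtype_config],
    }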
Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps
python test/test_quantization.py TestFXNumericSuiteCoreAPIs
Imported from OSS
Reviewed By: bdhirsh
Differential Revision: D35403589
fbshipit-source-id: 663703b310944a6b7c5ade6d07a4d938a6ca082b
(cherry picked from commit 5a76ce031872c4fed5fcab5bb3c84a9394b01118)
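Usage sketch (not part of this diff): how a caller would exercise the new
binary op configs, assuming the prototype prepare_fx of this era accepts a
backend_config_dict keyword argument and that get_native_backend_config_dict()
is the accessor exposed by native.py (newer versions of prepare_fx may also
require example_inputs):

    import torch
    from torch.ao.quantization import get_default_qconfig
    from torch.ao.quantization.quantize_fx import prepare_fx
    from torch.ao.quantization.fx.backend_config.native import (
        get_native_backend_config_dict,
    )

    class M(torch.nn.Module):
        def forward(self, x, y):
            # add + relu is now matched through backend_config_dict
            return torch.relu(torch.add(x, y))

    qconfig_dict = {"": get_default_qconfig("fbgemm")}
    prepared = prepare_fx(
        M().eval(), qconfig_dict,
        backend_config_dict=get_native_backend_config_dict())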
diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py
index aed8822..4164e58 100644
--- a/test/quantization/fx/test_quantize_fx.py
+++ b/test/quantization/fx/test_quantize_fx.py
@@ -5000,7 +5000,7 @@
m = M()
expected_node_occurrence = {
- ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 6,
+ ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 5,
}
self._test_quantized_add_mul_qat(m, expected_node_occurrence)
@@ -5016,14 +5016,13 @@
x = torch.mul(x, 1.0)
x = self.conv1(x)
x = torch.mul(x, 1.0)
- # TODO: add support for add + torch.relu?
x = torch.relu(x)
x = self.conv2(x)
return x
m = M()
expected_node_occurrence = {
- ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 6,
+ ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 5,
}
self._test_quantized_add_mul_qat(m, expected_node_occurrence)
diff --git a/torch/ao/ns/fx/mappings.py b/torch/ao/ns/fx/mappings.py
index c796a82..5c3574c 100644
--- a/torch/ao/ns/fx/mappings.py
+++ b/torch/ao/ns/fx/mappings.py
@@ -451,9 +451,9 @@
F.silu,
F.mish,
operator.add,
- # TODO(future PR): implement shadowing for binary ops and
- # uncomment below
- # operator.mul,
+ torch.add,
+ operator.mul,
+ torch.mul,
torch.sum,
])
diff --git a/torch/ao/quantization/fx/backend_config/native.py b/torch/ao/quantization/fx/backend_config/native.py
index 4e8e8c0..b897d46 100644
--- a/torch/ao/quantization/fx/backend_config/native.py
+++ b/torch/ao/quantization/fx/backend_config/native.py
@@ -1,4 +1,5 @@
from collections import namedtuple
+from typing import List, Dict, Any
import operator
import torch
from .observation_type import ObservationType
@@ -289,20 +290,43 @@
})
return conv_configs
-_ADD_CONFIG = {
- "pattern": operator.add,
- "num_tensor_args_to_observation_type": {
+def _get_binary_op_configs():
+ binary_op_configs: List[Dict[str, Any]] = []
+ num_tensor_args_to_observation_type_mapping = {
# TODO: this is not used right now since we have extra check in prepare
# will need to change this to NO_OBSERVER later after we implemented
# Tensor dtype inference properly
0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
1: ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT,
2: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
- },
- "dtype_configs": [
+ }
+ dtype_configs = [
weighted_op_int8_dtype_config,
- ],
-}
+ ]
+ for op_with_quantized_bop_scalar_variant in [
+ operator.add, torch.add, operator.mul, torch.mul]:
+ binary_op_configs.append({
+ "pattern": (torch.nn.ReLU, op_with_quantized_bop_scalar_variant),
+ "num_tensor_args_to_observation_type": num_tensor_args_to_observation_type_mapping,
+ "dtype_configs": dtype_configs,
+ })
+ binary_op_configs.append({
+ "pattern": (torch.nn.functional.relu, op_with_quantized_bop_scalar_variant),
+ "num_tensor_args_to_observation_type": num_tensor_args_to_observation_type_mapping,
+ "dtype_configs": dtype_configs,
+ })
+ binary_op_configs.append({
+ "pattern": (torch.relu, op_with_quantized_bop_scalar_variant),
+ "num_tensor_args_to_observation_type": num_tensor_args_to_observation_type_mapping,
+ "dtype_configs": dtype_configs,
+ })
+ binary_op_configs.append({
+ "pattern": op_with_quantized_bop_scalar_variant,
+ "num_tensor_args_to_observation_type": num_tensor_args_to_observation_type_mapping,
+ "dtype_configs": dtype_configs,
+ })
+ return binary_op_configs
+
_HARDSIGMOID_MODULE_CONFIG = {
"pattern": torch.nn.Hardsigmoid,
@@ -329,7 +353,7 @@
*_DEFAULT_OP_INT8_CONFIGS,
*_get_linear_configs(),
*_get_conv_configs(),
- _ADD_CONFIG,
+ *_get_binary_op_configs(),
_HARDSIGMOID_MODULE_CONFIG,
],
}
diff --git a/torch/ao/quantization/fx/quantization_patterns.py b/torch/ao/quantization/fx/quantization_patterns.py
index 9114664..79cac99 100644
--- a/torch/ao/quantization/fx/quantization_patterns.py
+++ b/torch/ao/quantization/fx/quantization_patterns.py
@@ -117,23 +117,10 @@
return self.is_standalone_module_
@register_quant_pattern(operator.sub)
-@register_quant_pattern(operator.mul)
@register_quant_pattern(operator.truediv)
-@register_quant_pattern(torch.add)
@register_quant_pattern(torch.sub)
-@register_quant_pattern(torch.mul)
@register_quant_pattern(torch.div)
@register_quant_pattern(torch.bmm)
-@register_quant_pattern((torch.nn.ReLU, operator.add))
-@register_quant_pattern((torch.nn.ReLU, operator.mul))
-@register_quant_pattern((torch.nn.ReLU, torch.add))
-@register_quant_pattern((torch.nn.ReLU, torch.mul))
-@register_quant_pattern((torch.nn.functional.relu, operator.add))
-@register_quant_pattern((torch.nn.functional.relu, operator.mul))
-@register_quant_pattern((torch.nn.functional.relu, torch.add))
-@register_quant_pattern((torch.nn.functional.relu, torch.mul))
-@register_quant_pattern((torch.relu, operator.add))
-@register_quant_pattern((torch.relu, operator.mul))
@register_quant_pattern(torch.matmul)
class BinaryOpQuantizeHandler(QuantizeHandler):
def __init__(