[quant][pt2e] Fix a bug in reference quantized module (decomposed mode) (#98903)
Summary:
Fixed quant_min/quant_max for per-channel quantized weight in the reference quantized module in decomposed mode:
previously the dtype-default bounds unconditionally overwrote the quant_min/quant_max provided by the qconfig. The bug was triggered while onboarding an internal model.
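For context, a minimal illustrative sketch (not part of this PR; the helper name and local bounds table are made up) of the fallback behavior this change establishes: explicit weight quant_min/quant_max from the qconfig (e.g. -127/127 from default_per_channel_symmetric_qnnpack_qconfig) are kept, and the dtype-default bounds are only used when none were provided.

    import torch

    # Illustrative dtype-default bounds, mirroring _DTYPE_TO_QVALUE_BOUNDS for int8
    _DTYPE_TO_QVALUE_BOUNDS = {torch.int8: (-128, 127)}

    def resolve_weight_bounds(weight_quant_min, weight_quant_max, weight_dtype_=torch.int8):
        # Only fall back to the dtype defaults when the qconfig did not
        # specify explicit bounds; a symmetric per-channel qconfig uses
        # (-127, 127) for the weight and must keep those values.
        if weight_quant_min is None or weight_quant_max is None:
            weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype_]
        return weight_quant_min, weight_quant_max

    assert resolve_weight_bounds(-127, 127) == (-127, 127)   # explicit bounds preserved (the bug overwrote these)
    assert resolve_weight_bounds(None, None) == (-128, 127)  # fall back to int8 defaults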
Test Plan:
python test/test_quantization.py TestQuantizeFx.test__convert_to_reference_decomposed_fx_per_channel_quant_module
Reviewers:
Subscribers:
Tasks:
Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98903
Approved by: https://github.com/andrewor14
diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py
index cc89792..af4b1ea 100644
--- a/test/quantization/fx/test_quantize_fx.py
+++ b/test/quantization/fx/test_quantize_fx.py
@@ -48,6 +48,7 @@
default_reuse_input_qconfig,
default_symmetric_qnnpack_qconfig,
default_symmetric_qnnpack_qat_qconfig,
+ default_per_channel_symmetric_qnnpack_qconfig,
per_channel_dynamic_qconfig,
float16_dynamic_qconfig,
float16_static_qconfig,
@@ -188,7 +189,11 @@
override_quantized_engine,
)
-from torch.testing._internal.common_utils import TemporaryFileName, IS_ARM64
+from torch.testing._internal.common_utils import (
+ TemporaryFileName,
+ IS_ARM64,
+ IS_WINDOWS,
+)
from torch.testing._internal.common_quantization import NodeSpec as ns
@@ -6158,6 +6163,57 @@
res = m(*example_inputs)
self.assertEqual(res, res_ref)
+ @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on Windows")
+ def test__convert_to_reference_decomposed_fx_per_channel_quant_module(self):
+ """ Test the result for per channel weight quant for reference modules
+ """
+ class M(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.conv = torch.nn.Conv2d(3, 3, 3)
+
+ def forward(self, x):
+ return self.conv(x)
+
+ m = M().eval()
+ qconfig_mapping = QConfigMapping().set_global(default_per_channel_symmetric_qnnpack_qconfig)
+ example_inputs = (torch.randn(1, 3, 10, 10),)
+ m = prepare_fx(m, qconfig_mapping, example_inputs, backend_config=get_qnnpack_backend_config())
+ m(*example_inputs)
+ m_ref = copy.deepcopy(m)
+ m_ref = convert_to_reference_fx(m_ref, backend_config=get_qnnpack_backend_config())
+ m = _convert_to_reference_decomposed_fx(m, backend_config=get_qnnpack_backend_config())
+ expected_occurrence = {
+ # for input and output activations
+ ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2,
+ ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 2,
+ # weight is per channel quantized
+ ns.call_function(torch.ops.quantized_decomposed.quantize_per_channel.default): 1,
+ ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel.default): 1,
+ }
+ import torch._dynamo as torchdynamo
+ m, guards = torchdynamo.export(
+ m,
+ *copy.deepcopy(example_inputs),
+ aten_graph=True,
+ tracing_mode="real",
+ )
+ self.checkGraphModuleNodes(
+ m,
+ expected_node_occurrence=expected_occurrence)
+ # make sure it runs
+ res_ref = m_ref(*example_inputs)
+ res = m(*example_inputs)
+ self.assertEqual(res, res_ref)
+ # check the qmin/qmax for per channel quant
+ for n in m.graph.nodes:
+ if n.op == "call_function" and \
+ n.target == torch.ops.quantized_decomposed.quantize_per_channel.default:
+ _QUANT_MIN_INDEX = 4
+ _QUANT_MAX_INDEX = 5
+ self.assertEqual(n.args[_QUANT_MIN_INDEX], -127)
+ self.assertEqual(n.args[_QUANT_MAX_INDEX], 127)
+
def test_change_backend_config_for_fixed_qparam_ops(self):
""" Making sure we can skip validation of qconfigs for fixedqparam ops based
on BackendConfig
diff --git a/torch/ao/nn/quantized/reference/modules/utils.py b/torch/ao/nn/quantized/reference/modules/utils.py
index 422e710..2c1f52c 100644
--- a/torch/ao/nn/quantized/reference/modules/utils.py
+++ b/torch/ao/nn/quantized/reference/modules/utils.py
@@ -165,7 +165,8 @@
# TODO: torch.quint4x2 is not supported
if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
weight_dtype_ = _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE[weight_dtype]
- weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype_]
+ if weight_quant_min is None or weight_quant_max is None:
+ weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype_]
weight = torch.ops.quantized_decomposed.quantize_per_channel(
weight,
weight_scale,