[quant][pt2e] Fix a bug in reference quantized module (decomposed mode) (#98903)

Summary:
Fixed the quant_min/quant_max used for per-channel quantized weights in the reference quantized module in decomposed mode.
This bug was triggered while onboarding an internal model.

Test Plan:
python test/test_quantization.py TestQuantizeFx.test__convert_to_reference_decomposed_fx_per_channel_quant_module

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98903
Approved by: https://github.com/andrewor14
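
Context: before this change, the decomposed weight-quantization path always overwrote the caller-provided weight_quant_min/weight_quant_max with the full bounds of the underlying integer dtype, so a qconfig that restricts qint8 weights to the symmetric range [-127, 127] (as default_per_channel_symmetric_qnnpack_qconfig does) silently became [-128, 127]. Below is a minimal sketch of the corrected logic; the standalone bounds table and the helper name _resolve_weight_qrange are illustrative, not the actual module layout:

import torch

# Full-range bounds per underlying integer dtype (an illustrative copy of
# the _DTYPE_TO_QVALUE_BOUNDS table referenced in the hunk below).
_DTYPE_TO_QVALUE_BOUNDS = {
    torch.int8: (-128, 127),
    torch.uint8: (0, 255),
    torch.int32: (-(2 ** 31), 2 ** 31 - 1),
}

def _resolve_weight_qrange(weight_dtype_, weight_quant_min, weight_quant_max):
    # Fall back to the full dtype range only when the caller did not
    # supply explicit bounds; previously they were overwritten
    # unconditionally, clobbering narrowed symmetric ranges.
    if weight_quant_min is None or weight_quant_max is None:
        weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype_]
    return weight_quant_min, weight_quant_max
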
diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py
index cc89792..af4b1ea 100644
--- a/test/quantization/fx/test_quantize_fx.py
+++ b/test/quantization/fx/test_quantize_fx.py
@@ -48,6 +48,7 @@
     default_reuse_input_qconfig,
     default_symmetric_qnnpack_qconfig,
     default_symmetric_qnnpack_qat_qconfig,
+    default_per_channel_symmetric_qnnpack_qconfig,
     per_channel_dynamic_qconfig,
     float16_dynamic_qconfig,
     float16_static_qconfig,
@@ -188,7 +189,11 @@
     override_quantized_engine,
 )
 
-from torch.testing._internal.common_utils import TemporaryFileName, IS_ARM64
+from torch.testing._internal.common_utils import (
+    TemporaryFileName,
+    IS_ARM64,
+    IS_WINDOWS,
+)
 
 from torch.testing._internal.common_quantization import NodeSpec as ns
 
@@ -6158,6 +6163,57 @@
         res = m(*example_inputs)
         self.assertEqual(res, res_ref)
 
+    @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on Windows")
+    def test__convert_to_reference_decomposed_fx_per_channel_quant_module(self):
+        """ Test the result for per channel weight quant for reference modules
+        """
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv2d(3, 3, 3)
+
+            def forward(self, x):
+                return self.conv(x)
+
+        m = M().eval()
+        qconfig_mapping = QConfigMapping().set_global(default_per_channel_symmetric_qnnpack_qconfig)
+        example_inputs = (torch.randn(1, 3, 10, 10),)
+        m = prepare_fx(m, qconfig_mapping, example_inputs, backend_config=get_qnnpack_backend_config())
+        m(*example_inputs)
+        m_ref = copy.deepcopy(m)
+        m_ref = convert_to_reference_fx(m_ref, backend_config=get_qnnpack_backend_config())
+        m = _convert_to_reference_decomposed_fx(m, backend_config=get_qnnpack_backend_config())
+        expected_occurrence = {
+            # for input and output activations
+            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2,
+            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 2,
+            # weight is per channel quantized
+            ns.call_function(torch.ops.quantized_decomposed.quantize_per_channel.default): 1,
+            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel.default): 1,
+        }
+        import torch._dynamo as torchdynamo
+        m, guards = torchdynamo.export(
+            m,
+            *copy.deepcopy(example_inputs),
+            aten_graph=True,
+            tracing_mode="real",
+        )
+        self.checkGraphModuleNodes(
+            m,
+            expected_node_occurrence=expected_occurrence)
+        # make sure it runs
+        res_ref = m_ref(*example_inputs)
+        res = m(*example_inputs)
+        self.assertEqual(res, res_ref)
+        # check the qmin/qmax for per channel quant
+        for n in m.graph.nodes:
+            if n.op == "call_function" and \
+               n.target == torch.ops.quantized_decomposed.quantize_per_channel.default:
+                _QUANT_MIN_INDEX = 4
+                _QUANT_MAX_INDEX = 5
+                self.assertEqual(n.args[_QUANT_MIN_INDEX], -127)
+                self.assertEqual(n.args[_QUANT_MAX_INDEX], 127)
+
     def test_change_backend_config_for_fixed_qparam_ops(self):
         """ Making sure we can skip validation of qconfigs for fixedqparam ops based
         on BackendConfig
diff --git a/torch/ao/nn/quantized/reference/modules/utils.py b/torch/ao/nn/quantized/reference/modules/utils.py
index 422e710..2c1f52c 100644
--- a/torch/ao/nn/quantized/reference/modules/utils.py
+++ b/torch/ao/nn/quantized/reference/modules/utils.py
@@ -165,7 +165,8 @@
         # TODO: torch.quint4x2 is not supported
         if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
             weight_dtype_ = _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE[weight_dtype]
-            weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype_]
+            if weight_quant_min is None or weight_quant_max is None:
+                weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype_]
             weight = torch.ops.quantized_decomposed.quantize_per_channel(
                 weight,
                 weight_scale,
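
For reference, a small standalone sketch of the decomposed op this path emits, with the bounds now coming from the qconfig rather than the full int8 range (tensor values are illustrative; the op and argument order are the ones used in the hunk above):

import torch

weight = torch.randn(2, 3)
scale = torch.tensor([0.1, 0.2])
zero_point = torch.zeros(2, dtype=torch.int64)
# axis=0 quantizes per output channel; quant_min/quant_max are the
# qconfig-provided [-127, 127] (symmetric qnnpack, qint8 weights)
# instead of the full int8 range [-128, 127].
q = torch.ops.quantized_decomposed.quantize_per_channel(
    weight, scale, zero_point, 0, -127, 127, torch.int8
)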