[reland][quant][fix] Compare resnet with quantizer api against the prepare_fx and decomposed convert flow (#99355)

Summary:
Use a decomposed convert flow so that the two paths produce an exact match; this verifies that
the nodes in resnet are annotated correctly. Reland of https://github.com/pytorch/pytorch/pull/98905.
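
For reference, a minimal sketch of the decomposed FX flow used as the exact-match baseline
(API names are taken from the diff below; `_convert_to_reference_decomposed_fx` is a private
API and may change, and the use of `torchvision.models.resnet18` here is an assumption to
mirror the test):

```python
import copy
import torch
import torchvision
from torch.ao.quantization import (
    QConfigMapping,
    default_per_channel_symmetric_qnnpack_qconfig,
)
from torch.ao.quantization.backend_config import get_qnnpack_backend_config
from torch.ao.quantization.quantize_fx import (
    prepare_fx,
    _convert_to_reference_decomposed_fx,
)

m = torchvision.models.resnet18().eval()
example_inputs = (torch.randn(1, 3, 224, 224),)

qconfig_mapping = QConfigMapping().set_global(
    default_per_channel_symmetric_qnnpack_qconfig
)
m_fx = prepare_fx(
    copy.deepcopy(m), qconfig_mapping, example_inputs,
    backend_config=get_qnnpack_backend_config(),
)
m_fx(*example_inputs)  # calibrate observers
# Lowers quant/dequant to explicit quantized_decomposed.* ops, the same ops
# the quantizer API flow produces, so the two results can be compared with
# assertEqual instead of an SQNR tolerance.
m_fx = _convert_to_reference_decomposed_fx(
    m_fx, backend_config=get_qnnpack_backend_config()
)
```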

Test Plan:
python test/test_quantization.py TestQuantizePT2EModels.test_resnet18_with_quantizer_api

Differential Revision: [D45071168](https://our.internmc.facebook.com/intern/diff/D45071168)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99355
Approved by: https://github.com/kimishpatel
diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py
index af4b1ea..370f32d 100644
--- a/test/quantization/fx/test_quantize_fx.py
+++ b/test/quantization/fx/test_quantize_fx.py
@@ -48,7 +48,6 @@
     default_reuse_input_qconfig,
     default_symmetric_qnnpack_qconfig,
     default_symmetric_qnnpack_qat_qconfig,
-    default_per_channel_symmetric_qnnpack_qconfig,
     per_channel_dynamic_qconfig,
     float16_dynamic_qconfig,
     float16_static_qconfig,
@@ -192,7 +191,6 @@
 from torch.testing._internal.common_utils import (
     TemporaryFileName,
     IS_ARM64,
-    IS_WINDOWS,
 )
 
 from torch.testing._internal.common_quantization import NodeSpec as ns
@@ -6163,57 +6161,6 @@
         res = m(*example_inputs)
         self.assertEqual(res, res_ref)
 
-    @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on Windows")
-    def test__convert_to_reference_decomposed_fx_per_channel_quant_module(self):
-        """ Test the result for per channel weight quant for reference modules
-        """
-        class M(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.conv = torch.nn.Conv2d(3, 3, 3)
-
-            def forward(self, x):
-                return self.conv(x)
-
-        m = M().eval()
-        qconfig_mapping = QConfigMapping().set_global(default_per_channel_symmetric_qnnpack_qconfig)
-        example_inputs = (torch.randn(1, 3, 10, 10),)
-        m = prepare_fx(m, qconfig_mapping, example_inputs, backend_config=get_qnnpack_backend_config())
-        m(*example_inputs)
-        m_ref = copy.deepcopy(m)
-        m_ref = convert_to_reference_fx(m_ref, backend_config=get_qnnpack_backend_config())
-        m = _convert_to_reference_decomposed_fx(m, backend_config=get_qnnpack_backend_config())
-        expected_occurrence = {
-            # for input and output activations
-            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2,
-            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 2,
-            # weight is per channel quantized
-            ns.call_function(torch.ops.quantized_decomposed.quantize_per_channel.default): 1,
-            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel.default): 1,
-        }
-        import torch._dynamo as torchdynamo
-        m, guards = torchdynamo.export(
-            m,
-            *copy.deepcopy(example_inputs),
-            aten_graph=True,
-            tracing_mode="real",
-        )
-        self.checkGraphModuleNodes(
-            m,
-            expected_node_occurrence=expected_occurrence)
-        # make sure it runs
-        res_ref = m_ref(*example_inputs)
-        res = m(*example_inputs)
-        self.assertEqual(res, res_ref)
-        # check the qmin/qmax for per channel quant
-        for n in m.graph.nodes:
-            if n.op == "call_function" and \
-               n.target == torch.ops.quantized_decomposed.quantize_per_channel.default:
-                _QUANT_MIN_INDEX = 4
-                _QUANT_MAX_INDEX = 5
-                self.assertEqual(n.args[_QUANT_MIN_INDEX], -127)
-                self.assertEqual(n.args[_QUANT_MAX_INDEX], 127)
-
     def test_change_backend_config_for_fixed_qparam_ops(self):
         """ Making sure we can skip validation of qconfigs for fixedqparam ops based
         on BackendConfig
diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py
index 5470ec4..b9d4a4b 100644
--- a/test/quantization/pt2e/test_quantize_pt2e.py
+++ b/test/quantization/pt2e/test_quantize_pt2e.py
@@ -265,11 +265,5 @@
                 compute_sqnr(after_prepare_result, after_prepare_result_fx),
                 torch.tensor(float("inf")),
             )
-            # there are slight differences after convert due to different implementations
-            # of quant/dequant
-            self.assertTrue(
-                torch.max(after_quant_result - after_quant_result_fx) < 1e-1
-            )
-            self.assertTrue(
-                compute_sqnr(after_quant_result, after_quant_result_fx) > 35
-            )
+            self.assertEqual(after_quant_result, after_quant_result_fx)
+            self.assertTrue(compute_sqnr(after_quant_result, after_quant_result_fx) == torch.tensor(float("inf")))
diff --git a/test/quantization/pt2e/test_quantize_pt2e_fx.py b/test/quantization/pt2e/test_quantize_pt2e_fx.py
index 8b93190..1af5100 100644
--- a/test/quantization/pt2e/test_quantize_pt2e_fx.py
+++ b/test/quantization/pt2e/test_quantize_pt2e_fx.py
@@ -7,7 +7,12 @@
 import torch.nn as nn
 from torch._inductor.compile_fx import compile_fx
 from torch.ao.ns.fx.utils import compute_sqnr
-from torch.ao.quantization import get_default_qconfig, observer, QConfigMapping
+from torch.ao.quantization import (
+    get_default_qconfig,
+    observer,
+    QConfigMapping,
+    default_per_channel_symmetric_qnnpack_qconfig,
+)
 from torch.ao.quantization._quantize_pt2e import (
     convert_pt2e,
     prepare_pt2e,
@@ -23,8 +28,13 @@
 from torch.ao.quantization.quantize_fx import (
     convert_fx,
     convert_to_reference_fx,
+    _convert_to_reference_decomposed_fx,
     prepare_fx,
 )
+
+from torch.testing._internal.common_utils import (
+    IS_WINDOWS,
+)
 from torch.testing._internal.common_quantization import (
     NodeSpec as ns,
     QuantizationTestCase,
@@ -33,8 +43,10 @@
     skipIfNoX86,
 )
 from torch.testing._internal.common_quantized import override_quantized_engine
+import unittest
 
 
+# TODO: remove after quantizer API is more mature
 @skipIfNoQNNPACK
 class TestQuantizePT2EFX(QuantizationTestCase):
     def test_qconfig_none(self):
@@ -255,6 +267,57 @@
             }
             self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
 
+    # TODO(jerryzh168): move all _convert_to_reference_decomposed_fx tests here
+    @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on Windows")
+    def test__convert_to_reference_decomposed_fx_per_channel_quant_module(self):
+        """ Test the result for per channel weight quant for reference modules
+        """
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv2d(3, 3, 3)
+
+            def forward(self, x):
+                return self.conv(x)
+
+        m = M().eval()
+        qconfig_mapping = QConfigMapping().set_global(default_per_channel_symmetric_qnnpack_qconfig)
+        example_inputs = (torch.randn(1, 3, 10, 10),)
+        m = prepare_fx(m, qconfig_mapping, example_inputs, backend_config=get_qnnpack_backend_config())
+        m(*example_inputs)
+        m_ref = copy.deepcopy(m)
+        m_ref = convert_to_reference_fx(m_ref, backend_config=get_qnnpack_backend_config())
+        m = _convert_to_reference_decomposed_fx(m, backend_config=get_qnnpack_backend_config())
+        expected_occurrence = {
+            # for input and output activations
+            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2,
+            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 2,
+            # weight is per channel quantized
+            ns.call_function(torch.ops.quantized_decomposed.quantize_per_channel.default): 1,
+            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel.default): 1,
+        }
+        import torch._dynamo as torchdynamo
+        m, guards = torchdynamo.export(
+            m,
+            *copy.deepcopy(example_inputs),
+            aten_graph=True,
+            tracing_mode="real",
+        )
+        self.checkGraphModuleNodes(
+            m,
+            expected_node_occurrence=expected_occurrence)
+        # make sure it runs
+        res_ref = m_ref(*example_inputs)
+        res = m(*example_inputs)
+        self.assertEqual(res, res_ref)
+        # check the qmin/qmax for per channel quant
+        for n in m.graph.nodes:
+            if n.op == "call_function" and \
+               n.target == torch.ops.quantized_decomposed.quantize_per_channel.default:
+                _QUANT_MIN_INDEX = 4
+                _QUANT_MAX_INDEX = 5
+                self.assertEqual(n.args[_QUANT_MIN_INDEX], -127)
+                self.assertEqual(n.args[_QUANT_MAX_INDEX], 127)
 
 @skipIfNoQNNPACK
 class TestQuantizePT2EFXX86Inductor(QuantizationTestCase):
@@ -425,7 +488,7 @@
                 m_copy, qconfig_mapping, example_inputs, backend_config=backend_config
             )
             after_prepare_result_fx = m_fx(*example_inputs)
-            m_fx = convert_to_reference_fx(m_fx, backend_config=backend_config)
+            m_fx = _convert_to_reference_decomposed_fx(m_fx, backend_config=backend_config)
 
             after_quant_result_fx = m_fx(*example_inputs)
 
@@ -437,9 +500,5 @@
             )
             # there are slight differences after convert due to different implementations
             # of quant/dequant
-            self.assertTrue(
-                torch.max(after_quant_result - after_quant_result_fx) < 1e-1
-            )
-            self.assertTrue(
-                compute_sqnr(after_quant_result, after_quant_result_fx) > 35
-            )
+            self.assertTrue(torch.max(after_quant_result - after_quant_result_fx) < 1e-1)
+            self.assertTrue(compute_sqnr(after_quant_result, after_quant_result_fx) > 35)