[reland][quant][fix] Compare resnet with quantizer api with the prepare_fx and decomposed convert flow (#99355)
Summary:
Using a decomposed convert to make sure we get exact match, this means the nodes in resnet are
annotated correctly, reland for https://github.com/pytorch/pytorch/pull/98905
Test Plan:
python test/test_quantization.py TestQuantizePT2EModels.test_resnet18_with_quantizer_api
Reviewers:
Subscribers:
Tasks:
Tags:
Differential Revision: [D45071168](https://our.internmc.facebook.com/intern/diff/D45071168)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99355
Approved by: https://github.com/kimishpatel
diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py
index af4b1ea..370f32d 100644
--- a/test/quantization/fx/test_quantize_fx.py
+++ b/test/quantization/fx/test_quantize_fx.py
@@ -48,7 +48,6 @@
default_reuse_input_qconfig,
default_symmetric_qnnpack_qconfig,
default_symmetric_qnnpack_qat_qconfig,
- default_per_channel_symmetric_qnnpack_qconfig,
per_channel_dynamic_qconfig,
float16_dynamic_qconfig,
float16_static_qconfig,
@@ -192,7 +191,6 @@
from torch.testing._internal.common_utils import (
TemporaryFileName,
IS_ARM64,
- IS_WINDOWS,
)
from torch.testing._internal.common_quantization import NodeSpec as ns
@@ -6163,57 +6161,6 @@
res = m(*example_inputs)
self.assertEqual(res, res_ref)
- @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on Windows")
- def test__convert_to_reference_decomposed_fx_per_channel_quant_module(self):
- """ Test the result for per channel weight quant for reference modules
- """
- class M(torch.nn.Module):
- def __init__(self):
- super().__init__()
- self.conv = torch.nn.Conv2d(3, 3, 3)
-
- def forward(self, x):
- return self.conv(x)
-
- m = M().eval()
- qconfig_mapping = QConfigMapping().set_global(default_per_channel_symmetric_qnnpack_qconfig)
- example_inputs = (torch.randn(1, 3, 10, 10),)
- m = prepare_fx(m, qconfig_mapping, example_inputs, backend_config=get_qnnpack_backend_config())
- m(*example_inputs)
- m_ref = copy.deepcopy(m)
- m_ref = convert_to_reference_fx(m_ref, backend_config=get_qnnpack_backend_config())
- m = _convert_to_reference_decomposed_fx(m, backend_config=get_qnnpack_backend_config())
- expected_occurrence = {
- # for input and output activations
- ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2,
- ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 2,
- # weight is per channel quantized
- ns.call_function(torch.ops.quantized_decomposed.quantize_per_channel.default): 1,
- ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel.default): 1,
- }
- import torch._dynamo as torchdynamo
- m, guards = torchdynamo.export(
- m,
- *copy.deepcopy(example_inputs),
- aten_graph=True,
- tracing_mode="real",
- )
- self.checkGraphModuleNodes(
- m,
- expected_node_occurrence=expected_occurrence)
- # make sure it runs
- res_ref = m_ref(*example_inputs)
- res = m(*example_inputs)
- self.assertEqual(res, res_ref)
- # check the qmin/qmax for per channel quant
- for n in m.graph.nodes:
- if n.op == "call_function" and \
- n.target == torch.ops.quantized_decomposed.quantize_per_channel.default:
- _QUANT_MIN_INDEX = 4
- _QUANT_MAX_INDEX = 5
- self.assertEqual(n.args[_QUANT_MIN_INDEX], -127)
- self.assertEqual(n.args[_QUANT_MAX_INDEX], 127)
-
def test_change_backend_config_for_fixed_qparam_ops(self):
""" Making sure we can skip validation of qconfigs for fixedqparam ops based
on BackendConfig
diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py
index 5470ec4..b9d4a4b 100644
--- a/test/quantization/pt2e/test_quantize_pt2e.py
+++ b/test/quantization/pt2e/test_quantize_pt2e.py
@@ -265,11 +265,5 @@
compute_sqnr(after_prepare_result, after_prepare_result_fx),
torch.tensor(float("inf")),
)
- # there are slight differences after convert due to different implementations
- # of quant/dequant
- self.assertTrue(
- torch.max(after_quant_result - after_quant_result_fx) < 1e-1
- )
- self.assertTrue(
- compute_sqnr(after_quant_result, after_quant_result_fx) > 35
- )
+ self.assertEqual(after_quant_result, after_quant_result_fx)
+ self.assertTrue(compute_sqnr(after_quant_result, after_quant_result_fx) == torch.tensor(float("inf")))
diff --git a/test/quantization/pt2e/test_quantize_pt2e_fx.py b/test/quantization/pt2e/test_quantize_pt2e_fx.py
index 8b93190..1af5100 100644
--- a/test/quantization/pt2e/test_quantize_pt2e_fx.py
+++ b/test/quantization/pt2e/test_quantize_pt2e_fx.py
@@ -7,7 +7,12 @@
import torch.nn as nn
from torch._inductor.compile_fx import compile_fx
from torch.ao.ns.fx.utils import compute_sqnr
-from torch.ao.quantization import get_default_qconfig, observer, QConfigMapping
+from torch.ao.quantization import (
+ get_default_qconfig,
+ observer,
+ QConfigMapping,
+ default_per_channel_symmetric_qnnpack_qconfig,
+)
from torch.ao.quantization._quantize_pt2e import (
convert_pt2e,
prepare_pt2e,
@@ -23,8 +28,13 @@
from torch.ao.quantization.quantize_fx import (
convert_fx,
convert_to_reference_fx,
+ _convert_to_reference_decomposed_fx,
prepare_fx,
)
+
+from torch.testing._internal.common_utils import (
+ IS_WINDOWS,
+)
from torch.testing._internal.common_quantization import (
NodeSpec as ns,
QuantizationTestCase,
@@ -33,8 +43,10 @@
skipIfNoX86,
)
from torch.testing._internal.common_quantized import override_quantized_engine
+import unittest
+# TODO: remove after quantizer API is more mature
@skipIfNoQNNPACK
class TestQuantizePT2EFX(QuantizationTestCase):
def test_qconfig_none(self):
@@ -255,6 +267,57 @@
}
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
+ # TODO(jerryzh168): move all _convert_to_reference_decomposed_fx tests here
+ @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on Windows")
+ def test__convert_to_reference_decomposed_fx_per_channel_quant_module(self):
+ """ Test the result for per channel weight quant for reference modules
+ """
+ class M(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.conv = torch.nn.Conv2d(3, 3, 3)
+
+ def forward(self, x):
+ return self.conv(x)
+
+ m = M().eval()
+ qconfig_mapping = QConfigMapping().set_global(default_per_channel_symmetric_qnnpack_qconfig)
+ example_inputs = (torch.randn(1, 3, 10, 10),)
+ m = prepare_fx(m, qconfig_mapping, example_inputs, backend_config=get_qnnpack_backend_config())
+ m(*example_inputs)
+ m_ref = copy.deepcopy(m)
+ m_ref = convert_to_reference_fx(m_ref, backend_config=get_qnnpack_backend_config())
+ m = _convert_to_reference_decomposed_fx(m, backend_config=get_qnnpack_backend_config())
+ expected_occurrence = {
+ # for input and output activations
+ ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2,
+ ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 2,
+ # weight is per channel quantized
+ ns.call_function(torch.ops.quantized_decomposed.quantize_per_channel.default): 1,
+ ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel.default): 1,
+ }
+ import torch._dynamo as torchdynamo
+ m, guards = torchdynamo.export(
+ m,
+ *copy.deepcopy(example_inputs),
+ aten_graph=True,
+ tracing_mode="real",
+ )
+ self.checkGraphModuleNodes(
+ m,
+ expected_node_occurrence=expected_occurrence)
+ # make sure it runs
+ res_ref = m_ref(*example_inputs)
+ res = m(*example_inputs)
+ self.assertEqual(res, res_ref)
+ # check the qmin/qmax for per channel quant
+ for n in m.graph.nodes:
+ if n.op == "call_function" and \
+ n.target == torch.ops.quantized_decomposed.quantize_per_channel.default:
+ _QUANT_MIN_INDEX = 4
+ _QUANT_MAX_INDEX = 5
+ self.assertEqual(n.args[_QUANT_MIN_INDEX], -127)
+ self.assertEqual(n.args[_QUANT_MAX_INDEX], 127)
@skipIfNoQNNPACK
class TestQuantizePT2EFXX86Inductor(QuantizationTestCase):
@@ -425,7 +488,7 @@
m_copy, qconfig_mapping, example_inputs, backend_config=backend_config
)
after_prepare_result_fx = m_fx(*example_inputs)
- m_fx = convert_to_reference_fx(m_fx, backend_config=backend_config)
+ m_fx = _convert_to_reference_decomposed_fx(m_fx, backend_config=backend_config)
after_quant_result_fx = m_fx(*example_inputs)
@@ -437,9 +500,5 @@
)
# there are slight differences after convert due to different implementations
# of quant/dequant
- self.assertTrue(
- torch.max(after_quant_result - after_quant_result_fx) < 1e-1
- )
- self.assertTrue(
- compute_sqnr(after_quant_result, after_quant_result_fx) > 35
- )
+ self.assertTrue(torch.max(after_quant_result - after_quant_result_fx) < 1e-1)
+ self.assertTrue(compute_sqnr(after_quant_result, after_quant_result_fx) > 35)