| # Copyright (c) Meta Platforms, Inc. and affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| # pyre-strict |
| |
| import unittest |
| |
| import torch |
| import torchvision |
| from executorch.exir import EdgeCompileConfig, to_edge |
| from executorch.exir.passes.quant_fusion_pass import QuantFusionPass |
| from executorch.exir.passes.spec_prop_pass import SpecPropPass |
| from torch.ao.ns.fx.utils import compute_sqnr |
| from torch.ao.quantization import QConfigMapping # @manual |
| from torch.ao.quantization.backend_config import get_executorch_backend_config |
| from torch.ao.quantization.qconfig import default_per_channel_symmetric_qnnpack_qconfig |
| from torch.ao.quantization.quantize_fx import prepare_fx |
| from torch.ao.quantization.quantize_pt2e import ( |
| _convert_to_reference_decomposed_fx, |
| convert_pt2e, |
| prepare_pt2e, |
| ) |
| |
| from torch.ao.quantization.quantizer.xnnpack_quantizer import ( |
| get_symmetric_quantization_config, |
| XNNPACKQuantizer, |
| ) |
| from torch.export import export |
| from torch.testing import FileCheck |
| from torch.testing._internal.common_quantized import override_quantized_engine |
| |
| # Load the ExecuTorch quantized custom-ops library so its out-variant ops are available through torch.ops |
| torch.ops.load_library("//executorch/kernels/quantized:custom_ops_generated_lib") |
| |
| |
class TestQuantization(unittest.TestCase):
    """End-to-end PT2E quantization test for the ExecuTorch edge flow.

    ``prepare_pt2e`` and ``convert_pt2e`` are OSS APIs; the rest are all
    Meta-only APIs for now, but we plan to open source them in the future.
    """

    def test_resnet(self) -> None:
        """Quantize ResNet-18 with PT2E and compare it to the FX reference flow.

        The float model is captured, prepared and converted with
        ``XNNPACKQuantizer`` (symmetric, per-channel), lowered with ``to_edge``
        and transformed by ``QuantFusionPass`` and ``SpecPropPass``. A deep
        copy of the same float model goes through ``prepare_fx`` /
        ``_convert_to_reference_decomposed_fx``. The outputs of both flows are
        compared after prepare and again after convert.
        """
        import copy

        # override_quantized_engine sets the quantized engine for the duration
        # of the block and restores the previous one on exit, so no manual
        # assignment to torch.backends.quantized.engine is needed.
        with override_quantized_engine("qnnpack"):
            example_inputs = (torch.randn(1, 3, 224, 224),)
            m = torchvision.models.resnet18().eval()
            # keep an untouched float copy for the FX reference flow below
            m_copy = copy.deepcopy(m)
            # program capture
            m = torch.export.export_for_training(
                m, copy.deepcopy(example_inputs)
            ).module()

            quantizer = XNNPACKQuantizer()
            operator_config = get_symmetric_quantization_config(is_per_channel=True)
            quantizer.set_global(operator_config)
            m = prepare_pt2e(m, quantizer)  # pyre-fixme[6]
            # both attributes are expected to be the very same observer object
            self.assertIs(m.activation_post_process_3, m.activation_post_process_2)
            after_prepare_result = m(*example_inputs)[0]
            m = convert_pt2e(m)

            # TODO: conv, conv_relu, linear delegation
            # quantized ops to implement: add_relu
            compile_config = EdgeCompileConfig(
                _check_ir_validity=False,
            )
            m = to_edge(
                export(m, example_inputs), compile_config=compile_config
            ).transform([QuantFusionPass(), SpecPropPass()])

            after_quant_result = m.exported_program().module()(*example_inputs)[0]
            # the lowered graph must still contain edge-dialect quantize and
            # dequantize per-tensor ops, in that order
            FileCheck().check(
                "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor"
            ).check(
                "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor"
            ).run(
                m.exported_program().graph_module.code
            )
            # after_quant_fusion_result = m(*example_inputs)[0]

            # TODO: implement torch.ops.quantized_decomposed.add_relu.out
            # m = m.to_executorch().dump_graph_module()
            # after_to_executorch = m(*example_inputs)[0]
            # test the result before and after to_executorch matches
            # TODO: debug why this is a mismatch
            # self.assertTrue(torch.equal(after_quant_fusion_result, after_to_executorch))
            # self.assertEqual(compute_sqnr(after_quant_fusion_result, after_to_executorch), torch.tensor(float("inf")))

            # comparing with existing fx graph mode quantization reference flow
            qconfig = default_per_channel_symmetric_qnnpack_qconfig
            qconfig_mapping = QConfigMapping().set_global(qconfig)
            backend_config = get_executorch_backend_config()
            m_fx = prepare_fx(
                m_copy, qconfig_mapping, example_inputs, backend_config=backend_config
            )
            after_prepare_result_fx = m_fx(*example_inputs)
            m_fx = _convert_to_reference_decomposed_fx(
                m_fx, backend_config=backend_config
            )
            after_quant_result_fx = m_fx(*example_inputs)

            # the results match after prepare, up to a small float tolerance
            self.assertTrue(
                torch.allclose(after_prepare_result, after_prepare_result_fx, atol=1e-6)
            )

            # there are slight differences after convert due to different
            # implementations of quant/dequant. Bound the *absolute* error:
            # taking the max of the signed difference would let arbitrarily
            # large negative deviations pass unnoticed.
            max_abs_diff = (
                (after_quant_result - after_quant_result_fx).abs().max().item()
            )
            self.assertLess(max_abs_diff, 1e-1)

            sqnr = compute_sqnr(after_quant_result, after_quant_result_fx).item()
            self.assertGreater(sqnr, 35)