# blob: 269a9ee11bc18dab68be5e99559d2f116dc1e7b8
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import unittest
import torch
import torchvision
from executorch.exir import EdgeCompileConfig, to_edge
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from executorch.exir.passes.spec_prop_pass import SpecPropPass
from torch.ao.ns.fx.utils import compute_sqnr
from torch.ao.quantization import QConfigMapping # @manual
from torch.ao.quantization.backend_config import get_executorch_backend_config
from torch.ao.quantization.qconfig import default_per_channel_symmetric_qnnpack_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx
from torch.ao.quantization.quantize_pt2e import (
_convert_to_reference_decomposed_fx,
convert_pt2e,
prepare_pt2e,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
XNNPACKQuantizer,
)
from torch.export import export
from torch.testing import FileCheck
from torch.testing._internal.common_quantized import override_quantized_engine
# load executorch out variant ops
torch.ops.load_library("//executorch/kernels/quantized:custom_ops_generated_lib")
class TestQuantization(unittest.TestCase):
    """prepare_pt2e and convert_pt2e are OSS APIs, the rest are all meta-only
    APIs for now, but we plan to open source them in the future
    """

    def test_resnet(self) -> None:
        """End-to-end PT2E quantization of resnet18, cross-checked against the
        FX-graph-mode reference quantization flow.

        Pipeline under test:
          1. export_for_training -> prepare_pt2e -> convert_pt2e (PT2E flow)
          2. to_edge + QuantFusionPass/SpecPropPass (ExecuTorch lowering)
          3. prepare_fx -> _convert_to_reference_decomposed_fx (FX reference flow)
        The prepared models must match exactly; the converted models may differ
        slightly (different quant/dequant implementations) but must stay within
        a small elementwise bound and a minimum SQNR.
        """
        import copy

        with override_quantized_engine("qnnpack"):
            torch.backends.quantized.engine = "qnnpack"
            example_inputs = (torch.randn(1, 3, 224, 224),)
            m = torchvision.models.resnet18().eval()
            # Keep an untouched copy for the FX reference flow below.
            m_copy = copy.deepcopy(m)
            # program capture
            m = torch.export.export_for_training(
                m, copy.deepcopy(example_inputs)
            ).module()

            quantizer = XNNPACKQuantizer()
            operator_config = get_symmetric_quantization_config(is_per_channel=True)
            quantizer.set_global(operator_config)
            m = prepare_pt2e(m, quantizer)  # pyre-fixme[6]
            # Observers 2 and 3 are expected to be shared (same object), e.g.
            # across an add whose inputs must use one quantization scale.
            self.assertEqual(
                id(m.activation_post_process_3), id(m.activation_post_process_2)
            )
            # Exported-module call returns a tuple; [0] extracts the tensor.
            after_prepare_result = m(*example_inputs)[0]
            m = convert_pt2e(m)

            # TODO: conv, conv_relu, linear delegation
            # quantized ops to implement: add_relu
            compile_config = EdgeCompileConfig(
                _check_ir_validity=False,
            )
            m = to_edge(
                export(m, example_inputs), compile_config=compile_config
            ).transform([QuantFusionPass(), SpecPropPass()])

            after_quant_result = m.exported_program().module()(*example_inputs)[0]
            # The edge graph must contain fused q/dq ops after QuantFusionPass.
            FileCheck().check(
                "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor"
            ).check(
                "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor"
            ).run(
                m.exported_program().graph_module.code
            )
            # after_quant_fusion_result = m(*example_inputs)[0]

            # TODO: implement torch.ops.quantized_decomposed.add_relu.out
            # m = m.to_executorch().dump_graph_module()
            # after_to_executorch = m(*example_inputs)[0]
            # test the result before and after to_executorch matches
            # TODO: debug why this is a mismatch
            # self.assertTrue(torch.equal(after_quant_fusion_result, after_to_executorch))
            # self.assertEqual(compute_sqnr(after_quant_fusion_result, after_to_executorch), torch.tensor(float("inf")))

            # comparing with existing fx graph mode quantization reference flow
            qconfig = default_per_channel_symmetric_qnnpack_qconfig
            qconfig_mapping = QConfigMapping().set_global(qconfig)
            backend_config = get_executorch_backend_config()
            m_fx = prepare_fx(
                m_copy, qconfig_mapping, example_inputs, backend_config=backend_config
            )
            after_prepare_result_fx = m_fx(*example_inputs)
            m_fx = _convert_to_reference_decomposed_fx(
                m_fx, backend_config=backend_config
            )
            after_quant_result_fx = m_fx(*example_inputs)

            # the result matches exactly after prepare
            self.assertTrue(
                torch.allclose(after_prepare_result, after_prepare_result_fx, atol=1e-6)
            )
            # there are slight differences after convert due to different implementations
            # of quant/dequant
            # NOTE: bound the *absolute* difference — the raw signed max would
            # silently pass when the PT2E result is far below the FX result.
            self.assertTrue(
                torch.max(torch.abs(after_quant_result - after_quant_result_fx)) < 1e-1
            )
            self.assertTrue(
                compute_sqnr(after_quant_result, after_quant_result_fx) > 35
            )