blob: 6f8aa3913f3906d492bcad0c15a1fa14ef1b3457 [file] [log] [blame]
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
XNNPACKQuantizer,
)
def quantize(model, example_inputs):
"""This is the official recommended flow for quantization in pytorch 2.0 export"""
logging.info(f"Original model: {model}")
quantizer = XNNPACKQuantizer()
# if we set is_per_channel to True, we also need to add out_variant of quantize_per_channel/dequantize_per_channel
operator_config = get_symmetric_quantization_config(is_per_channel=False)
quantizer.set_global(operator_config)
m = prepare_pt2e(model, quantizer)
# calibration
m(*example_inputs)
m = convert_pt2e(m)
logging.info(f"Quantized model: {m}")
# make sure we can export to flat buffer
return m