# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

| import logging |
| |
| from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e |
| from torch.ao.quantization.quantizer.xnnpack_quantizer import ( |
| get_symmetric_quantization_config, |
| XNNPACKQuantizer, |
| ) |
| |
| |
| def quantize(model, example_inputs): |
| """This is the official recommended flow for quantization in pytorch 2.0 export""" |
| logging.info(f"Original model: {model}") |
| quantizer = XNNPACKQuantizer() |
| # if we set is_per_channel to True, we also need to add out_variant of quantize_per_channel/dequantize_per_channel |
| operator_config = get_symmetric_quantization_config(is_per_channel=False) |
| quantizer.set_global(operator_config) |
| m = prepare_pt2e(model, quantizer) |
| # calibration |
| m(*example_inputs) |
| m = convert_pt2e(m) |
| logging.info(f"Quantized model: {m}") |
| # make sure we can export to flat buffer |
| return m |