| # Copyright (c) Qualcomm Innovation Center, Inc. |
| # All rights reserved |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| import json |
| import os |
| from multiprocessing.connection import Client |
| |
| import numpy as np |
| import torch |
| from executorch.backends.qualcomm.quantizer.annotators import ( |
| QuantizationConfig, |
| QuantizationSpec, |
| ) |
| from executorch.backends.qualcomm.quantizer.observers.per_channel_param_observer import ( |
| PerChannelParamObserver, |
| ) |
| from executorch.backends.qualcomm.quantizer.qconfig import ( |
| _derived_bias_quant_spec, |
| MovingAverageMinMaxObserver, |
| ) |
| |
| from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype |
| from executorch.backends.qualcomm.utils.constants import ( |
| QCOM_PASS_EXPAND_BROADCAST_SHAPE, |
| ) |
| from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d |
| from executorch.examples.qualcomm.utils import ( |
| build_executorch_binary, |
| get_imagenet_dataset, |
| make_output_dir, |
| make_quantizer, |
| parse_skip_delegation_node, |
| setup_common_args_and_variables, |
| SimpleADB, |
| topk_accuracy, |
| ) |
| |
| |
def get_instance(repo_path: str, checkpoint_path: str):
    """Build a reparameterized FastViT-S12 and load pretrained weights into it.

    Args:
        repo_path: Path to a local clone of the ml-fastvit repository. It is
            prepended to ``sys.path`` so ``models.modules.mobileone`` can be
            imported from it.
        checkpoint_path: Path to a checkpoint file whose ``"state_dict"`` entry
            holds the weights for the reparameterized model.

    Returns:
        The reparameterized model in eval mode with the checkpoint weights loaded.
    """
    import sys

    # The ml-fastvit repo is not an installable package; make it importable.
    sys.path.insert(0, repo_path)

    from models.modules.mobileone import reparameterize_model
    from timm.models import create_model

    # map_location="cpu" remaps any tensors that were saved on a GPU, so the
    # checkpoint also loads on hosts without CUDA.
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    model = create_model("fastvit_s12")
    # Reparameterize before loading so the module structure matches a
    # reparameterized ("*_reparam") checkpoint.
    model = reparameterize_model(model).eval()
    model.load_state_dict(checkpoint["state_dict"])
    return model
| |
| |
def _build_fastvit_quantizer():
    """Return an 8a8w QNN quantizer whose PTQ configs are tuned for FastViT."""
    quantizer = make_quantizer(quant_dtype=QuantDtype.use_8a8w)

    # There are lots of outliers appearing in FastViT parameters, so a special
    # configuration is applied to saturate their impact.
    act_qspec = QuantizationSpec(
        dtype=torch.uint8,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(
            **{"averaging_constant": 0.02}
        ),
    )
    weight_qspec = QuantizationSpec(
        dtype=torch.int8,
        # Restricted symmetric range [-127, 127].
        quant_min=torch.iinfo(torch.int8).min + 1,
        quant_max=torch.iinfo(torch.int8).max,
        qscheme=torch.per_channel_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=PerChannelParamObserver.with_args(
            **{"steps": 200, "use_mse": True}
        ),
    )
    # Rewrite the default per-channel PTQ config.
    quantizer.per_channel_quant_config = QuantizationConfig(
        input_activation=act_qspec,
        output_activation=act_qspec,
        weight=weight_qspec,
        bias=_derived_bias_quant_spec,
    )
    # Rewrite the default 8-bit PTQ config: keep its weight/bias specs and
    # swap in the custom activation spec.
    default_config = quantizer.bit8_quant_config
    quantizer.bit8_quant_config = QuantizationConfig(
        input_activation=act_qspec,
        output_activation=act_qspec,
        weight=default_config.weight,
        bias=default_config.bias,
    )
    return quantizer


def _load_predictions(output_folder, count):
    """Read ``output_{i}_0.raw`` (float32) for i in [0, count) from output_folder."""
    return [
        np.fromfile(
            os.path.join(output_folder, f"output_{i}_0.raw"), dtype=np.float32
        )
        for i in range(count)
    ]


def main(args):
    """Lower FastViT to a QNN .pte and, unless compile-only, evaluate it on device.

    Steps:
    - Loads up to 100 ImageNet images at 256x256.
    - Quantizes and lowers the model with a FastViT-specific quantizer.
    - When ``args.compile_only`` is false, runs the model on the device given by
      ``args.device`` through SimpleADB and computes top-1/top-5 accuracy.
    - Sends the results to ``(args.ip, args.port)`` when both are set;
      otherwise prints them.

    Raises:
        RuntimeError: if no device serial is given while not compile-only, or
            if the dataset yields no images.
    """
    skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)

    # Ensure the working directory exists.
    os.makedirs(args.artifact, exist_ok=True)

    if not args.compile_only and args.device is None:
        raise RuntimeError(
            "device serial is required if not compile only. "
            "Please specify a device serial by -s/--device argument."
        )

    data_num = 100
    inputs, targets, input_list = get_imagenet_dataset(
        dataset_path=f"{args.dataset}",
        data_size=data_num,
        image_shape=(256, 256),
    )
    if not inputs:
        raise RuntimeError(
            f"no images were loaded from dataset path: {args.dataset}"
        )

    pte_filename = "fastvit_qnn"
    quantizer = _build_fastvit_quantizer()

    # Lower to QNN.
    build_executorch_binary(
        convert_linear_to_conv2d(get_instance(args.oss_repo, args.pretrained_weight)),
        inputs[0],
        args.model,
        f"{args.artifact}/{pte_filename}",
        dataset=inputs,
        skip_node_id_set=skip_node_id_set,
        skip_node_op_set=skip_node_op_set,
        quant_dtype=QuantDtype.use_8a8w,
        custom_quantizer=quantizer,
        custom_pass_config={QCOM_PASS_EXPAND_BROADCAST_SHAPE},
        shared_buffer=args.shared_buffer,
    )

    if args.compile_only:
        return

    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=f"{args.artifact}/{pte_filename}.pte",
        workspace=f"/data/local/tmp/executorch/{pte_filename}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
    )
    adb.push(inputs=inputs, input_list=input_list)
    adb.execute()

    # Collect output data.
    output_data_folder = f"{args.artifact}/outputs"
    make_output_dir(output_data_folder)

    adb.pull(output_path=args.artifact)

    # Top-k analysis. Read back one output per sample actually pushed, not the
    # requested data_num: if the dataset yields fewer samples (assumed possible
    # for small datasets), the extra output files were never produced.
    predictions = _load_predictions(output_data_folder, len(inputs))

    k_val = [1, 5]
    topk = [topk_accuracy(predictions, targets, k).item() for k in k_val]
    if args.ip and args.port != -1:
        with Client((args.ip, args.port)) as conn:
            conn.send(json.dumps({f"top_{k}": acc for k, acc in zip(k_val, topk)}))
    else:
        for k, acc in zip(k_val, topk):
            print(f"top_{k}->{acc}%")
| |
| |
if __name__ == "__main__":
    # Script entry: extend the common Qualcomm example CLI with the
    # FastViT-specific arguments, then run main().
    parser = setup_common_args_and_variables()

    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing generated artifacts by this example. Default ./fastvit",
        default="./fastvit",
        type=str,
    )

    parser.add_argument(
        "-d",
        "--dataset",
        help=(
            "path to the validation folder of ImageNet dataset. "
            "e.g. --dataset imagenet-mini/val "
            "for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000"
        ),
        type=str,
        required=True,
    )

    parser.add_argument(
        "--oss_repo",
        help="Path to cloned https://github.com/apple/ml-fastvit",
        type=str,
        required=True,
    )

    parser.add_argument(
        "-p",
        "--pretrained_weight",
        help=(
            "Location of model pretrained weight. "
            "e.g., -p ./fastvit_s12_reparam.pth.tar. "
            "Pretrained model can be found in "
            "https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_s12_reparam.pth.tar"
        ),
        type=str,
        required=True,
    )

    args = parser.parse_args()
    try:
        main(args)
    except Exception as e:
        if args.ip and args.port != -1:
            # Report the failure to the listening client instead of crashing.
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            # Bare raise keeps the original exception type and traceback.
            raise