blob: 0e2c695ab34d7052e447b80e83278f6dce07f49f [file]
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import json
import os
from multiprocessing.connection import Client
import numpy as np
import torch
from executorch.backends.qualcomm.quantizer.annotators import (
QuantizationConfig,
QuantizationSpec,
)
from executorch.backends.qualcomm.quantizer.observers.per_channel_param_observer import (
PerChannelParamObserver,
)
from executorch.backends.qualcomm.quantizer.qconfig import (
_derived_bias_quant_spec,
MovingAverageMinMaxObserver,
)
from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
from executorch.backends.qualcomm.utils.constants import (
QCOM_PASS_EXPAND_BROADCAST_SHAPE,
)
from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d
from executorch.examples.qualcomm.utils import (
build_executorch_binary,
get_imagenet_dataset,
make_output_dir,
make_quantizer,
parse_skip_delegation_node,
setup_common_args_and_variables,
SimpleADB,
topk_accuracy,
)
def get_instance(repo_path: str, checkpoint_path: str):
    """Build a reparameterized FastViT-S12 model and load its checkpoint.

    Args:
        repo_path: Path to a clone of https://github.com/apple/ml-fastvit.
            It is prepended to ``sys.path`` so the repo's ``models`` package
            can be imported.
        checkpoint_path: Path to a checkpoint file whose ``"state_dict"``
            entry matches the *reparameterized* model
            (e.g. ``fastvit_s12_reparam.pth.tar``).

    Returns:
        The reparameterized ``fastvit_s12`` model, in eval mode, with the
        checkpoint weights loaded.
    """
    import sys

    sys.path.insert(0, repo_path)
    # NOTE(review): importing from the repo's ``models`` package presumably
    # also registers "fastvit_s12" with timm, so keep it before create_model.
    from models.modules.mobileone import reparameterize_model
    from timm.models import create_model

    # Map tensors to CPU: the export flow runs on the host CPU, and a
    # checkpoint saved from a CUDA device would otherwise fail to load on a
    # machine without CUDA.
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    model = create_model("fastvit_s12")
    # The weights are for the reparameterized graph, so reparameterize first
    # and only then load the state dict.
    model = reparameterize_model(model).eval()
    model.load_state_dict(checkpoint["state_dict"])
    return model
def _make_fastvit_quantizer():
    """Return an 8a8w quantizer whose configs are tuned for FastViT's outliers.

    There are lots of outliers in FastViT parameters, so both the
    per-channel and the default 8-bit PTQ configs are replaced to saturate
    their impact:
      * activations use ``MovingAverageMinMaxObserver`` with
        ``averaging_constant=0.02`` (uint8, per-tensor affine);
      * per-channel weights use ``PerChannelParamObserver`` with
        ``steps=200, use_mse=True`` (int8, symmetric, range [-127, 127],
        channel axis 0);
      * the default 8-bit config keeps its original weight and bias specs.
    """
    quantizer = make_quantizer(quant_dtype=QuantDtype.use_8a8w)

    act_qspec = QuantizationSpec(
        dtype=torch.uint8,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(
            averaging_constant=0.02
        ),
    )
    weight_qspec = QuantizationSpec(
        dtype=torch.int8,
        # Restricted symmetric range: drop -128 so the range is [-127, 127].
        quant_min=torch.iinfo(torch.int8).min + 1,
        quant_max=torch.iinfo(torch.int8).max,
        qscheme=torch.per_channel_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=PerChannelParamObserver.with_args(
            steps=200, use_mse=True
        ),
    )

    # rewrite default per-channel ptq config
    quantizer.per_channel_quant_config = QuantizationConfig(
        input_activation=act_qspec,
        output_activation=act_qspec,
        weight=weight_qspec,
        bias=_derived_bias_quant_spec,
    )

    # rewrite default ptq config, keeping its weight/bias specs
    default_config = quantizer.bit8_quant_config
    quantizer.bit8_quant_config = QuantizationConfig(
        input_activation=act_qspec,
        output_activation=act_qspec,
        weight=default_config.weight,
        bias=default_config.bias,
    )
    return quantizer


def main(args):
    """Compile FastViT-S12 to a QNN ``.pte`` and optionally evaluate it on device.

    The model is quantized (calibrated on up to 100 ImageNet images at
    256x256) and lowered to ``<args.artifact>/fastvit_qnn.pte``. Unless
    ``args.compile_only`` is set, the program and inputs are pushed to the
    device given by ``args.device``, executed, and the pulled outputs are
    scored for top-1 / top-5 accuracy. The scores are sent as JSON to
    ``(args.ip, args.port)`` when configured, otherwise printed.

    Raises:
        RuntimeError: if on-device execution is requested (``compile_only``
            unset) but no ``args.device`` serial was given.
    """
    skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)

    # ensure the working directory exists.
    os.makedirs(args.artifact, exist_ok=True)

    if not args.compile_only and args.device is None:
        raise RuntimeError(
            "device serial is required if not compile only. "
            "Please specify a device serial by -s/--device argument."
        )

    data_num = 100
    inputs, targets, input_list = get_imagenet_dataset(
        dataset_path=f"{args.dataset}",
        data_size=data_num,
        image_shape=(256, 256),
    )

    pte_filename = "fastvit_qnn"
    quantizer = _make_fastvit_quantizer()

    # lower to QNN
    build_executorch_binary(
        convert_linear_to_conv2d(get_instance(args.oss_repo, args.pretrained_weight)),
        inputs[0],
        args.model,
        f"{args.artifact}/{pte_filename}",
        dataset=inputs,
        skip_node_id_set=skip_node_id_set,
        skip_node_op_set=skip_node_op_set,
        quant_dtype=QuantDtype.use_8a8w,
        custom_quantizer=quantizer,
        custom_pass_config={QCOM_PASS_EXPAND_BROADCAST_SHAPE},
        shared_buffer=args.shared_buffer,
    )

    if args.compile_only:
        return

    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=f"{args.artifact}/{pte_filename}.pte",
        workspace=f"/data/local/tmp/executorch/{pte_filename}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
    )
    adb.push(inputs=inputs, input_list=input_list)
    adb.execute()

    # collect output data
    output_data_folder = f"{args.artifact}/outputs"
    make_output_dir(output_data_folder)
    adb.pull(output_path=args.artifact)

    # top-k analysis. Iterate over the samples actually pushed (``inputs``),
    # not ``data_num``: if the dataset yields fewer samples than requested,
    # output files past len(inputs) were never produced.
    predictions = [
        np.fromfile(
            os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32
        )
        for i in range(len(inputs))
    ]
    k_val = [1, 5]
    topk = [topk_accuracy(predictions, targets, k).item() for k in k_val]
    if args.ip and args.port != -1:
        with Client((args.ip, args.port)) as conn:
            conn.send(json.dumps({f"top_{k}": acc for k, acc in zip(k_val, topk)}))
    else:
        for k, acc in zip(k_val, topk):
            print(f"top_{k}->{acc}%")
if __name__ == "__main__":
    # Script entry point: extend the shared Qualcomm example CLI with the
    # FastViT-specific arguments, then run main().
    parser = setup_common_args_and_variables()
    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing generated artifacts by this example. Default ./fastvit",
        default="./fastvit",
        type=str,
    )
    parser.add_argument(
        "-d",
        "--dataset",
        help=(
            "path to the validation folder of ImageNet dataset. "
            "e.g. --dataset imagenet-mini/val "
            "(for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)"
        ),
        type=str,
        required=True,
    )
    parser.add_argument(
        "--oss_repo",
        help="Path to cloned https://github.com/apple/ml-fastvit",
        type=str,
        required=True,
    )
    parser.add_argument(
        "-p",
        "--pretrained_weight",
        # Adjacent literals are concatenated as-is, so each piece carries its
        # own trailing space.
        help=(
            "Location of model pretrained weight. "
            "e.g., -p ./fastvit_s12_reparam.pth.tar "
            "Pretrained model can be found in "
            "https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_s12_reparam.pth.tar"
        ),
        type=str,
        required=True,
    )

    args = parser.parse_args()
    try:
        main(args)
    except Exception as e:
        # When a listener is configured, report the failure to it as JSON
        # instead of raising.
        if args.ip and args.port != -1:
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            # Bare ``raise`` keeps the original exception type and traceback
            # rather than re-wrapping it in a generic ``Exception``.
            raise