# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import collections
import os
import tempfile
import unittest
from typing import Callable, List, Optional, Tuple

import numpy as np
import torch
from executorch import exir
from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner
from executorch.backends.qualcomm.qnn_preprocess import QnnBackend
from executorch.backends.qualcomm.quantizer.quantizer import (
get_default_16bit_qnn_ptq_config,
QnnQuantizer,
)
from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import (
QcomChipset,
)
from executorch.backends.qualcomm.utils.utils import capture_program
from executorch.examples.qualcomm.scripts.utils import SimpleADB
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


class TestQNN(unittest.TestCase):
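    # Shared test configuration; these class attributes are expected to be
    # populated by the test entry point before the tests run.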
rtol: float = 0
atol: float = 0
    host: str = ""
    device: str = ""
    build_folder: str = ""
    model: Optional[QcomChipset] = None
    compiler_specs: Optional[List[CompileSpec]] = None
arch_table = {
"SM8650": QcomChipset.SM8650,
"SM8550": QcomChipset.SM8550,
"SM8475": QcomChipset.SM8475,
"SM8450": QcomChipset.SM8450,
}
    error_only: bool = False
    ip: str = "localhost"
    port: int = 8080
    executorch_root: str = ""
    artifact_dir: str = ""
    image_dataset: str = ""
    pretrained_weight: str = ""
online_prepare: bool = False

    def _assert_outputs_equal(self, model_output, ref_output):
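        # Element-wise comparison of device outputs against eager-mode
        # references, within the configured absolute/relative tolerances.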
        self.assertEqual(len(ref_output), len(model_output))
for i in range(len(ref_output)):
self.assertTrue(
torch.allclose(
model_output[i], ref_output[i], atol=self.atol, rtol=self.rtol
),
msg=f"ref_output:\n{ref_output[i]}\n\nmodel_output:\n{model_output[i]}",
)

    def _save_model_and_expected_output(
        self,
        module: torch.nn.Module,
        buffer: bytes,
        inputs: Tuple[torch.Tensor],
        dir_name: str,
    ) -> Tuple[str, List[torch.Tensor], str]:
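        # Returns (input_list, ref_outputs, pte_fname): the input manifest for
        # the on-device runner, the eager-mode reference outputs, and the path
        # of the serialized .pte program.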
# Save the input data list to be executed
input_list = ""
for idx, _ in enumerate(inputs):
input_name = f"input_0_{idx}.raw"
input_list += input_name + " "
input_list = input_list.strip() + "\n"
ref_output = module(*inputs)
# Save the expected output data to be verified
ref_outputs = []
if isinstance(ref_output, collections.OrderedDict):
ref_outputs.append(ref_output["out"].detach())
elif isinstance(ref_output, tuple):
for output in ref_output:
ref_outputs.append(output.detach())
else:
ref_outputs.append(ref_output.detach())
pte_fname = f"{dir_name}/qnn_executorch_test.pte"
with open(pte_fname, "wb") as file:
file.write(buffer)
return input_list, ref_outputs, pte_fname

    def lower_module_and_test_output(
self,
module: torch.nn.Module,
sample_inputs: Tuple[torch.Tensor],
expected_partitions: int = 1,
assert_output_equal: bool = True,
        skip_node_id_set: Optional[set] = None,
        skip_node_op_set: Optional[set] = None,
):
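        # Lower the module through the QNN partitioner, verify the expected
        # number of partitions were delegated to QnnBackend, and optionally run
        # the program on device to compare its outputs against eager mode.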
qnn_partitioner = QnnPartitioner(
self.compiler_specs, skip_node_id_set, skip_node_op_set
)
delegated_program = capture_program(module, sample_inputs)
delegated_program.exported_program = to_backend(
delegated_program.exported_program, qnn_partitioner
)
exec_prog = delegated_program.to_executorch()
        # Assert the expected number of partitions, each delegated to QnnBackend
self.assertEqual(
len(exec_prog.program.execution_plan[0].delegates), expected_partitions
)
for i in range(expected_partitions):
self.assertEqual(
exec_prog.program.execution_plan[0].delegates[i].id,
QnnBackend.__name__,
)
# Check numerics
if assert_output_equal:
with tempfile.TemporaryDirectory() as tmp_dir:
(
input_list,
ref_outputs,
pte_fname,
) = self._save_model_and_expected_output(
module,
exec_prog.buffer,
sample_inputs,
tmp_dir,
)
device_output_dir = f"{tmp_dir}/outputs"
device_outputs = []

            def post_process():
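                # Reassemble each raw output file pulled from the device into a
                # torch.Tensor with the reference output's dtype and shape.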
                for i, f in enumerate(sorted(os.listdir(device_output_dir))):
filename = os.path.join(device_output_dir, f)
output = np.fromfile(
filename, dtype=ref_outputs[i].numpy().dtype
)
output = torch.from_numpy(output).reshape(ref_outputs[i].shape)
device_outputs.append(output)

            adb = SimpleADB(
qnn_sdk=os.getenv("QNN_SDK_ROOT"),
artifact_path=self.build_folder,
pte_path=pte_fname,
workspace="/data/local/tmp/qnn_executorch_test",
device_id=self.device,
host_id=self.host,
soc_model=self.model,
error_only=self.error_only,
)
adb.push(inputs=[sample_inputs], input_list=input_list)
adb.execute()
adb.pull(output_path=tmp_dir, callback=post_process)
self._assert_outputs_equal(device_outputs, ref_outputs)

    def get_qdq_module(
        self,
        module: torch.nn.Module,
        inputs: Tuple[torch.Tensor],
        is_conv_per_channel: bool = True,
        custom_quant_annotations: Tuple[Callable] = (),
        use_16bit_quant: bool = False,
    ) -> torch.fx.GraphModule:
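        # Quantize the module with the PT2E flow: capture, annotate with the
        # QNN quantizer, calibrate, then convert to a q/dq-annotated module.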
m = torch._export.capture_pre_autograd_graph(module, inputs)
quantizer = QnnQuantizer()
quantizer.add_custom_quant_annotations(custom_quant_annotations)
quantizer.set_per_channel_quant(is_conv_per_channel)
if use_16bit_quant:
quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS)
quantizer.set_bit16_op_quant_config(get_default_16bit_qnn_ptq_config())
prepared = prepare_pt2e(m, quantizer)
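        # Calibration pass: running the sample inputs lets the observers record
        # activation statistics used to compute quantization parameters.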
prepared(*inputs)
quantized_module = convert_pt2e(prepared)
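        # Sanity-check that quantize/dequantize ops were actually inserted.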
nodes = {node.target for node in quantized_module.graph.nodes}
q_and_dq = {
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.quantize_per_channel.default,
torch.ops.quantized_decomposed.dequantize_per_channel.default,
}
self.assertTrue(nodes.intersection(q_and_dq))
return quantized_module