blob: 945c8f0ab029a55f0a90eb8df28e04c54b07c2ec [file] [log] [blame]
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Tuple
import executorch.exir as exir
import torch
from executorch.backends.qualcomm.passes.annotate_and_quant_scalar import (
AnnotateAndQuantScalar,
)
from executorch.backends.qualcomm.passes.annotate_decomposed import AnnotateDecomposed
from executorch.backends.qualcomm.passes.annotate_quant_attrs import AnnotateQuantAttrs
from executorch.backends.qualcomm.passes.convert_binary_op_with_scalar import (
ConvertBinaryOpsWithScalar,
)
from executorch.backends.qualcomm.passes.convert_bmm_to_matmul import ConvertBmmToMatmul
from executorch.backends.qualcomm.passes.convert_hardsigmoid import ConvertHardsigmoid
from executorch.backends.qualcomm.passes.convert_hardswish import ConvertHardswish
from executorch.backends.qualcomm.passes.convert_interpolate_with_upsample2d import (
ConvertInterpolateWithUpsample2D,
)
from executorch.backends.qualcomm.passes.convert_to_linear import ConvertToLinear
from executorch.backends.qualcomm.passes.fold_qdq import FoldQDQ
from executorch.backends.qualcomm.passes.i64_to_i32 import I64toI32
from executorch.backends.qualcomm.passes.layout_transform import LayoutTransform
from executorch.backends.qualcomm.passes.remove_clone import RemoveClone
from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import (
_soc_info_table,
QcomChipset,
QnnExecuTorchBackendType,
QnnExecuTorchHtpPdSession,
QnnExecuTorchHtpPerformanceMode,
QnnExecuTorchHtpPrecision,
QnnExecuTorchLogLevel,
QnnExecuTorchOptions,
)
from executorch.backends.qualcomm.serialization.qnn_compile_spec_serialize import (
convert_to_flatbuffer,
)
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch.fx import passes
# Key used to tag the serialized QNN options payload inside a CompileSpec
# entry; generate_qnn_executorch_option looks entries up by this key.
QNN_COMPILE_SPEC = "qnn_compile_spec"
def qnn_capture_config():
    """Return the capture configuration used when tracing modules for QNN."""
    capture_config = exir.CaptureConfig(enable_aot=True)
    return capture_config
def qnn_edge_config() -> exir.EdgeCompileConfig:
    """Return the edge compile config; IR validity checks are turned off
    because the QNN pass pipeline produces graphs outside core ATen."""
    edge_config = exir.EdgeCompileConfig(_check_ir_validity=False)
    return edge_config
def capture_program(
    module: torch.nn.Module,
    inputs: Tuple[torch.Tensor],
) -> exir.ExirExportedProgram:
    """Trace ``module`` with example ``inputs`` and lower it to edge dialect,
    running the QNN-specific pass pipeline over the resulting graph module.

    Args:
        module: Eager-mode model to capture.
        inputs: Example input tensors used for tracing.

    Returns:
        The edge-dialect ``ExirExportedProgram``; its graph module has been
        mutated in place by the passes below.
    """
    exir_exported_program = exir.capture(
        module,
        inputs,
        qnn_capture_config(),
    )
    # We choose call_operator by target in ConvertBinaryOpsWithScalar
    # because it is the same source_fn_stack for MultiheadAttention
    # NOTE(review): the return value of transform() is discarded here —
    # confirm ExirExportedProgram.transform mutates in place; if it returns
    # a new program, this pass is silently dropped.
    exir_exported_program.transform(ConvertBinaryOpsWithScalar())
    ex_prog = exir_exported_program.to_edge(qnn_edge_config())

    # currently ExirExportedProgram.transform does not accept
    # changes of input number which was caused by FoldQDQ
    # apply passes one by one here to avoid IR capture failure
    edge_program = ex_prog.exported_program
    graph_module = edge_program.graph_module
    # Pass order is significant: op-conversion passes run first, then
    # quantization annotation, then QDQ folding, and layout transform last.
    RemoveClone()(graph_module)
    ConvertToLinear()(graph_module)
    ConvertHardsigmoid()(graph_module)
    ConvertHardswish()(graph_module)
    ConvertBmmToMatmul()(graph_module)
    ConvertInterpolateWithUpsample2D()(graph_module)
    I64toI32(edge_program)(graph_module)
    AnnotateQuantAttrs(edge_program)(graph_module)
    AnnotateAndQuantScalar(edge_program)(graph_module)
    AnnotateDecomposed(edge_program)(graph_module)
    FoldQDQ()(graph_module)
    LayoutTransform(edge_program)(graph_module)
    return ex_prog
def draw_graph(title, path, graph_module: torch.fx.GraphModule):
    """Render ``graph_module`` as an SVG file named ``<title>.svg`` under
    the directory ``path``."""
    drawer = passes.graph_drawer.FxGraphDrawer(graph_module, title)
    svg_bytes = drawer.get_dot_graph().create_svg()
    with open(f"{path}/{title}.svg", "wb") as out_file:
        out_file.write(svg_bytes)
def generate_qnn_executorch_option(
    compiler_specs: List[CompileSpec],
) -> bytes:
    """Extract the serialized QNN options buffer from ``compiler_specs``.

    Args:
        compiler_specs: Compile specs produced by
            ``generate_qnn_executorch_compiler_spec``.

    Returns:
        The flatbuffer-serialized options payload stored under
        ``QNN_COMPILE_SPEC``.

    Raises:
        ValueError: If a spec carries an unknown key, or if no
            ``QNN_COMPILE_SPEC`` entry is present at all.
    """
    qnn_compile_spec_buffer = None
    for compiler_spec in compiler_specs:
        if compiler_spec.key == QNN_COMPILE_SPEC:
            qnn_compile_spec_buffer = compiler_spec.value
        else:
            raise ValueError(f"unknown compiler spec key value: {compiler_spec.key}")
    if qnn_compile_spec_buffer is None:
        # Previously an empty spec list fell through to an UnboundLocalError;
        # fail with a clear message instead.
        raise ValueError("no qnn_compile_spec entry found in compiler_specs")
    return qnn_compile_spec_buffer
# TODO: refactor this for supporting other backends
def generate_qnn_executorch_compiler_spec(
    is_fp16: bool,
    soc_model: QcomChipset,
    debug: bool = False,
    saver: bool = False,
    online_prepare: bool = False,
) -> List[CompileSpec]:
    """
    Helper function generating compiler specs for Qualcomm AI Engine Direct

    Args:
        is_fp16: If true, the model is compiled to QNN HTP fp16 runtime.
            Note that not all SoC support QNN HTP fp16. Only premium tier SoC
            like Snapdragon 8 Gen 1 or newer can support HTP fp16.
        soc_model: The SoC you plan to run the compiled model. Please check
            QcomChipset for supported SoC.
            SM8450 (Snapdragon 8 Gen 1)
            SM8475 (Snapdragon 8 Gen 1+)
            SM8550 (Snapdragon 8 Gen 2)
            SM8650 (Snapdragon 8 Gen 3)
        debug: Enable verbose logging. Disclaimer: this option must change in
            the near future.
        saver: Instead of compiling the model, run QNN Saver. Please check
            documents of Qualcomm AI Engine Direct SDK. This feature is usually
            for debugging purpose.
        online_prepare: Compose QNN graph on device if set to True

    Returns:
        List[CompileSpec]: Compiler specs for Qualcomm AI Engine Direct.

    Raises:
        ValueError: The value QcomChipset is currently not supported.
    """
    qnn_executorch_options = QnnExecuTorchOptions()
    qnn_executorch_options.backend_type = QnnExecuTorchBackendType.kHtpBackend
    qnn_executorch_options.graph_name = "executorch"
    qnn_executorch_options.htp_options.pd_session = (
        QnnExecuTorchHtpPdSession.kHtpUnsignedPd
    )
    qnn_executorch_options.htp_options.use_conv_hmx = True
    qnn_executorch_options.htp_options.use_fold_relu = True

    if is_fp16:
        qnn_executorch_options.htp_options.precision = (
            QnnExecuTorchHtpPrecision.kHtpFp16
        )
    else:
        qnn_executorch_options.htp_options.precision = (
            QnnExecuTorchHtpPrecision.kHtpQuantized
        )

    if debug:
        qnn_executorch_options.log_level = QnnExecuTorchLogLevel.kLogLevelDebug
    else:
        qnn_executorch_options.log_level = QnnExecuTorchLogLevel.kLogLevelWarn

    # This actually is not an option which can affect the compiled blob.
    # But we don't have other place to pass this option at execution stage.
    qnn_executorch_options.htp_options.performance_mode = (
        QnnExecuTorchHtpPerformanceMode.kHtpBurst
    )

    # Use a distinct loop variable: the original comprehension reused the
    # name `soc_model`, shadowing the parameter inside the comprehension.
    _supported_soc_models = {chipset.value for chipset in QcomChipset}
    if soc_model not in _supported_soc_models:
        raise ValueError(f"unknown SoC model for QNN: {soc_model}")
    qnn_executorch_options.soc_info = _soc_info_table[soc_model]

    if saver:
        # Redirect QNN to the Saver backend library instead of compiling.
        qnn_executorch_options.library_path = "libQnnSaver.so"

    if online_prepare:
        qnn_executorch_options.online_prepare = True

    return [
        CompileSpec(QNN_COMPILE_SPEC, convert_to_flatbuffer(qnn_executorch_options))
    ]