blob: 590ede7431981ad7ce094469e4f5975230eca507 [file] [log] [blame] [edit]
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import operator
import warnings
from collections import OrderedDict
from typing import Callable, Dict, FrozenSet, List, Tuple
import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManagerAdaptor
import executorch.exir as exir
import torch
from executorch.backends.qualcomm._passes.annotate_and_quant_scalar import (
AnnotateAndQuantScalar,
)
from executorch.backends.qualcomm._passes.annotate_decomposed import AnnotateDecomposed
from executorch.backends.qualcomm._passes.annotate_quant_attrs import AnnotateQuantAttrs
from executorch.backends.qualcomm._passes.convert_binary_op_with_scalar import (
ConvertBinaryOpsWithScalar,
)
from executorch.backends.qualcomm._passes.convert_bmm_to_matmul import (
ConvertBmmToMatmul,
)
from executorch.backends.qualcomm._passes.convert_interpolate_with_upsample2d import (
ConvertInterpolateWithUpsample2D,
)
from executorch.backends.qualcomm._passes.convert_prelu import ConvertPReLU
from executorch.backends.qualcomm._passes.convert_to_linear import ConvertToLinear
from executorch.backends.qualcomm._passes.expand_broadcast_tensor_shape import (
ExpandBroadcastTensorShape,
)
from executorch.backends.qualcomm._passes.fold_qdq import FoldQDQ
from executorch.backends.qualcomm._passes.i64_to_i32 import I64toI32
from executorch.backends.qualcomm._passes.layout_transform import LayoutTransform
from executorch.backends.qualcomm._passes.recompose_pixel_unshuffle import (
RecomposePixelUnshuffle,
)
from executorch.backends.qualcomm._passes.recompose_rms_norm import RecomposeRmsNorm
from executorch.backends.qualcomm._passes.remove_redundancy import RemoveRedundancy
from executorch.backends.qualcomm._passes.replace_index_put_input import (
ReplaceIndexPutInput,
)
from executorch.backends.qualcomm.builders.node_visitor import (
QNN_QUANT_TYPE_MAP,
QNN_TENSOR_TYPE_MAP,
)
from executorch.backends.qualcomm.builders.qnn_constants import OpContextLoader
from executorch.backends.qualcomm.partition.qnn_partitioner import (
generate_qnn_executorch_option,
QnnPartitioner,
)
from executorch.backends.qualcomm.serialization.qc_schema import (
_soc_info_table,
HtpArch,
QcomChipset,
QnnExecuTorchBackendOptions,
QnnExecuTorchBackendType,
QnnExecuTorchHtpBackendOptions,
QnnExecuTorchHtpPerformanceMode,
QnnExecuTorchHtpPrecision,
QnnExecuTorchLogLevel,
QnnExecuTorchOptions,
QnnExecuTorchProfileLevel,
)
from executorch.backends.qualcomm.serialization.qc_schema_serialize import (
flatbuffer_to_option,
option_to_flatbuffer,
)
from executorch.backends.qualcomm.utils.constants import (
QCOM_PASS_EXPAND_BROADCAST_SHAPE,
QCOM_PASS_SKIP_ADVANCED_REQUANT,
QCOM_QNN_COMPILE_SPEC,
QCOM_QUANTIZED_IO,
)
from executorch.exir import (
EdgeCompileConfig,
ExecutorchProgramManager,
ExirExportedProgram,
to_edge,
)
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.capture import ExecutorchBackendConfig
from executorch.exir.lowered_backend_module import LoweredBackendModule
from executorch.exir.program._program import _get_updated_graph_signature
from torch._decomp import core_aten_decompositions as torch_core_aten_decompositions
from torch.export.exported_program import ExportedProgram
from torch.fx import passes
from torch.fx.passes.operator_support import OperatorSupportBase
from torch.library import Library
class _AnnotationSkipper(OperatorSupportBase):
    """
    Class used to partition out unwanted graph nodes.
    e.g. - nodes are prevented from quantization annotation
         - nodes have been grouped together as a submodule

    Attributes
    ----------
    fp_node_id_set : set
        a set contains nodes' name to be left in fp precision
    fp_node_op_set : set
        a set contains nodes' target (aten dialect) to be left in fp precision
    skip_annotated_submodule : bool
        flag to skip annotated submodule or not

    Methods
    -------
    should_delegate(n: torch.fx.Node)
        identify the residual nodes haven't be lowered with fixed-precision
    should_skip(n: torch.fx.Node)
        identify the nodes should be kept out with fixed-precision or not
    is_node_supported(_, node: torch.fx.Node)
        overridden method for graph partitioning
    """

    def __init__(
        self,
        fp_node_id_set: set = None,
        fp_node_op_set: set = None,
        skip_annotated_submodule: bool = False,
    ):
        # normalize None to empty sets, otherwise should_skip raises
        # TypeError ("argument of type 'NoneType' is not iterable") when
        # the instance is constructed with the default arguments
        self.fp_node_id_set = fp_node_id_set if fp_node_id_set is not None else set()
        self.fp_node_op_set = fp_node_op_set if fp_node_op_set is not None else set()
        self.skip_annotated_submodule = skip_annotated_submodule

    def should_delegate(self, n: torch.fx.Node):
        # getitem nodes are glue generated for multi-output ops and are
        # handled alongside their producer, not delegated on their own
        return n.op == "call_function" and n.target != operator.getitem

    def should_skip(self, n: torch.fx.Node):
        # skip by explicit node name or by operator target
        return n.name in self.fp_node_id_set or n.target in self.fp_node_op_set

    def is_node_supported(self, _, node: torch.fx.Node) -> bool:
        if self.skip_annotated_submodule:
            # in this mode only residual (not-yet-lowered) nodes are collected
            if node.op == "get_attr":
                return all(self.should_delegate(user) for user in node.users)
            return self.should_delegate(node)
        if (
            node.op in ("placeholder", "output")
            or self.should_skip(node)
            # check if parameters belong to fallbacked operator
            or (
                node.op == "get_attr"
                and all(self.should_skip(user) for user in node.users)
            )
        ):
            print(f"[QNN Quantizer Annotation]: {node.name} | Skipped")
            return False
        return True
def qnn_capture_config():
    """Return the CaptureConfig used for QNN ahead-of-time graph capture."""
    capture_config = exir.CaptureConfig(enable_aot=True)
    return capture_config
def qnn_edge_config() -> exir.EdgeCompileConfig:
    """Return the EdgeCompileConfig used when lowering to the edge dialect for QNN."""
    # TODO(T182928844): Delegate dim order op to backend.
    edge_config = exir.EdgeCompileConfig(
        _check_ir_validity=False,
        _skip_dim_order=True,
    )
    return edge_config
def convert_linear_to_conv2d(module: torch.nn.Module):
    """
    Recursively replace every torch.nn.Linear inside 'module' with a
    numerically-equivalent 1x1 Conv2d wrapper (Linear is rewritten as a
    pointwise convolution over a rank-4 view of the input).

    Args:
        module: The module to transform in place.

    Returns:
        The same module instance with all Linear layers replaced.
    """

    class Conv2D(torch.nn.Module):
        def __init__(self, weight, bias=None):
            super().__init__()
            use_bias = bias is not None
            # Linear weight layout is (out_features, in_features); Conv2d
            # expects in_channels == in_features, out_channels == out_features.
            # (The original code had these two swapped — harmless only because
            # the parameters are overwritten below, but the declared channel
            # metadata and the pre-overwrite bias size were wrong for
            # non-square weights.)
            self.conv = torch.nn.Conv2d(
                in_channels=weight.shape[1],
                out_channels=weight.shape[0],
                kernel_size=1,
                padding=0,
                bias=use_bias,
            )
            # reshape to Conv2d weight layout (out_ch, in_ch, 1, 1)
            self.conv.weight = torch.nn.Parameter(weight.reshape(*weight.shape, 1, 1))
            if use_bias:
                self.conv.bias = torch.nn.Parameter(bias)

        def forward(self, x):
            rank = x.dim()
            # lift rank-3 input via a trailing singleton dim, rank-2 input via
            # a leading batch dim and trailing singleton, to reach rank 4
            x = x.unsqueeze(-1) if rank == 3 else x.reshape(1, *x.shape, 1)
            # move the feature dim to the channel position expected by conv
            x = torch.transpose(x, 1, 2)
            res = self.conv(x)
            res = torch.transpose(res, 1, 2)
            # restore the original rank
            res = res.squeeze(-1) if rank == 3 else res.reshape(*res.shape[1:3])
            return res

    def replace_linear(module: torch.nn.Module):
        attr_strs = dir(module)
        # ModuleList children are reachable only via numeric-string names
        if isinstance(module, torch.nn.ModuleList):
            attr_strs += [str(i) for i in range(len(module))]

        for attr_str in attr_strs:
            target_attr = getattr(module, attr_str)
            if isinstance(target_attr, torch.nn.Linear):
                setattr(module, attr_str, Conv2D(target_attr.weight, target_attr.bias))

        # recurse into children to catch nested Linear layers
        for _, sub_module in module.named_children():
            sub_module = replace_linear(sub_module)
        return module

    return replace_linear(module)
def update_spill_fill_size(
    exported_program: ExportedProgram | List[LoweredBackendModule],
):
    """
    Propagate the maximum HTP spill-fill buffer size across all lowered
    partitions so multi-context graphs can share one spill-fill allocation.

    Args:
        exported_program: A lowered ExportedProgram, or a list of
            LoweredBackendModule instances. Compile specs are rewritten
            in place; nothing is returned.
    """

    # check if user specifies to use multi_contexts
    # this is a generic approach in case there exists multiple backends
    def get_program_info(program):
        # returns (max spill-fill size, {lowered module: parsed options})
        def process_exported_program(prog):
            max_sf_buf_size, module_map = 0, {}
            for _, m in prog.graph_module._modules.items():
                # currently only 1 compile spec is expected in each partition
                options = flatbuffer_to_option(m.compile_specs[0].value)
                if (
                    options.backend_options.backend_type
                    == QnnExecuTorchBackendType.kHtpBackend
                    and options.backend_options.htp_options.use_multi_contexts
                ):
                    # query the context binary for its spill-fill requirement
                    qnn_mgr = PyQnnManagerAdaptor.QnnManager(
                        m.compile_specs[0].value, m.processed_bytes
                    )
                    assert qnn_mgr.Init().value == 0, "failed to load context binary"
                    max_sf_buf_size = max(
                        max_sf_buf_size, qnn_mgr.GetSpillFillBufferSize()
                    )
                    module_map[m] = options
                    qnn_mgr.Destroy()
            return max_sf_buf_size, module_map

        def process_lowered_module(module):
            qnn_mgr = PyQnnManagerAdaptor.QnnManager(
                module.compile_specs[0].value, module.processed_bytes
            )
            assert qnn_mgr.Init().value == 0, "failed to load context binary"
            spill_fill_size = qnn_mgr.GetSpillFillBufferSize()
            qnn_mgr.Destroy()
            return spill_fill_size, {
                module: flatbuffer_to_option(module.compile_specs[0].value)
            }

        # dispatch on the concrete program type (exact type, not isinstance)
        dispatch = {
            ExportedProgram: process_exported_program,
            LoweredBackendModule: process_lowered_module,
        }
        return dispatch[type(program)](program)

    def update_program(max_sf_buf_size, module_map):
        # rewrite each module's compile spec with the shared buffer size
        def set_spec(module, options):
            spec = CompileSpec(QCOM_QNN_COMPILE_SPEC, option_to_flatbuffer(options))
            if isinstance(module, ExportedProgram):
                module.compile_specs[0] = spec
            else:
                module._compile_specs[0] = spec

        for module, options in module_map.items():
            options.backend_options.htp_options.max_sf_buf_size = max_sf_buf_size
            set_spec(module, options)

    if isinstance(exported_program, list):
        # take the maximum requirement over every program in the list
        max_sf_size, modules_map = 0, {}
        for prog in exported_program:
            max_sf_buf_size, module_map = get_program_info(prog)
            max_sf_size = max(max_sf_size, max_sf_buf_size)
            modules_map.update(module_map)
        update_program(max_sf_size, modules_map)
    else:
        update_program(*get_program_info(exported_program))
def get_decomp_table() -> Dict[torch._ops.OperatorBase, Callable]:
source_decompositions = torch_core_aten_decompositions()
# The below super ops are supported by QNN
remove_decompositions = [
torch.ops.aten.pixel_shuffle.default,
torch.ops.aten.pixel_unshuffle.default,
torch.ops.aten.hardsigmoid.default,
torch.ops.aten.hardswish.default,
torch.ops.aten._safe_softmax.default,
]
for key in remove_decompositions:
source_decompositions.pop(key)
return source_decompositions
def _transform(
    edge_program: ExportedProgram, custom_pass_config: FrozenSet[str] = frozenset()
) -> ExportedProgram:
    """
    Run the QNN-specific graph passes over 'edge_program' in place.

    Args:
        edge_program: Edge-dialect program to mutate.
        custom_pass_config: Set of QCOM_PASS_* flags toggling optional passes.

    Returns:
        The same edge_program, mutated and re-validated.
    """
    # currently ExirExportedProgram.transform does not accept
    # changes of input number which was caused by FoldQDQ
    # apply passes one by one here to avoid IR capture failure
    graph_module = edge_program.graph_module
    RemoveRedundancy()(graph_module)
    RecomposePixelUnshuffle()(graph_module)
    RecomposeRmsNorm()(graph_module)
    ConvertToLinear()(graph_module)
    ConvertPReLU(edge_program)(graph_module)
    ConvertBmmToMatmul()(graph_module)
    ConvertInterpolateWithUpsample2D()(graph_module)
    I64toI32(edge_program)(graph_module)
    AnnotateQuantAttrs(
        edge_program, QCOM_PASS_SKIP_ADVANCED_REQUANT in custom_pass_config
    )(graph_module)
    AnnotateAndQuantScalar(edge_program)(graph_module)
    AnnotateDecomposed(edge_program)(graph_module)
    FoldQDQ()(graph_module)
    # this pass is not necessary for network without layout-sensitive ops
    # enable defaultly will introduce overhead from extra view_copy nodes
    if QCOM_PASS_EXPAND_BROADCAST_SHAPE in custom_pass_config:
        ExpandBroadcastTensorShape()(graph_module)
    LayoutTransform(edge_program)(graph_module)
    ReplaceIndexPutInput(edge_program)(graph_module)

    # Since QDQ nodes are stripped, update graph signature again to validate program
    edge_program._graph_signature = _get_updated_graph_signature(
        edge_program.graph_signature,
        edge_program.graph_module,
    )
    edge_program._validate()
    return edge_program
def capture_program(
    module: torch.nn.Module,
    inputs: Tuple[torch.Tensor],
    custom_pass_config: FrozenSet[str] = frozenset(),
) -> exir.ExirExportedProgram:
    """
    Export 'module', run decompositions, lower to edge dialect and apply
    the QNN transformation passes.

    Args:
        module: Module to export.
        inputs: Sample inputs for torch.export.export.
        custom_pass_config: QCOM_PASS_* flags forwarded to _transform.

    Returns:
        The edge-dialect ExirExportedProgram ready for partitioning.
    """
    ep = torch.export.export(module, inputs)
    decomposed_ep = ep.run_decompositions(get_decomp_table())
    # We choose call_operator by target in ConvertBinaryOpsWithScalar
    # because it is the same source_fn_stack for MultiheadAttention
    # TODO: Should modify the scalar op in the op builder instead of
    # using transformation
    core_ep = ExirExportedProgram(decomposed_ep, False)
    core_ep.transform(ConvertBinaryOpsWithScalar())
    edge_ep = core_ep.to_edge(qnn_edge_config())
    _transform(edge_ep.exported_program, custom_pass_config)
    return edge_ep
def _partition_graph_into_submodules(gm, subgm_tag, subgm_cb, ptn):
    """
    Fuse the node groups proposed by partitioner 'ptn' into named
    submodules of 'gm'.

    Args:
        gm: Graph module to restructure.
        subgm_tag: Meta key / name prefix identifying each submodule group.
        subgm_cb: Callback (submodule, name) -> submodule, applied to each
            fused submodule before it is inserted back into the graph.
        ptn: Partitioner providing propose_partitions().

    Returns:
        The recompiled graph module containing the fused submodules
        named f"{subgm_tag}_{i}".
    """
    from torch.fx.passes.utils.fuser_utils import (
        erase_nodes,
        fuse_as_graphmodule,
        insert_subgm,
        legalize_graph,
        topo_sort,
    )

    partitions = ptn.propose_partitions()
    # insert meta for each partition group
    for i, partition in enumerate(partitions):
        for node in partition.nodes:
            node.meta[subgm_tag] = i

    for i in range(len(partitions)):
        # find nodes with same group id in current graph
        node_list = [
            node for node in gm.graph.nodes if node.meta.get(subgm_tag, "") == i
        ]
        # fuse group nodes into submodule
        sorted_nodes = topo_sort(node_list)
        submodule_name = f"{subgm_tag}_{i}"
        subgm, orig_inputs, orig_outputs = fuse_as_graphmodule(
            gm, sorted_nodes, submodule_name
        )
        # insert submodule & trim group nodes
        gm = insert_subgm(
            gm,
            subgm_cb(subgm, submodule_name),
            orig_inputs,
            orig_outputs,
        )
        erase_nodes(gm, sorted_nodes)
        legalize_graph(gm)

    gm.recompile()
    return gm
def _canonicalize_graph_with_lowered_module(gm, subgm_tag, ptn):
    """
    Lower each tagged submodule of 'gm' with partitioner 'ptn' and splice
    the lowered modules back into the graph.

    Args:
        gm: Graph module whose submodules were produced by
            _partition_graph_into_submodules.
        subgm_tag: Name prefix identifying the submodules to lower.
        ptn: Partitioner forwarded to to_backend.

    Returns:
        Tuple (recompiled graph module, list of lowered exported programs).
    """
    from executorch.exir.backend.backend_api import to_backend

    # return lowered program for user to debug
    exported_progs = []
    # partition each submodule which went through convert_pt2e
    for node in gm.graph.nodes:
        if node.op == "call_module" and subgm_tag in node.name:
            # obtain sample inputs through meta
            subgm_input = [
                torch.ones(arg.meta["val"].shape, dtype=arg.meta["val"].dtype)
                for arg in node.args
            ]
            # program meets QNN backend requirement
            sub_prog = capture_program(gm.get_submodule(node.name), tuple(subgm_input))
            # start lowering with given partitioner
            exported_progs.append(to_backend(sub_prog.exported_program, ptn))
            # replace submodule with lowered module
            gm.set_submodule(
                node.name,
                exported_progs[-1].graph_module,
            )
            # if node has multiple outputs, getitems will be default generated
            if all(n.target != operator.getitem for n in node.users):
                # single-output case: reroute users through an explicit getitem
                with gm.graph.inserting_after(node):
                    getitem_node = gm.graph.call_function(
                        operator.getitem,
                        (node, 0),
                    )
                    getitem_node.meta = node.meta
                    node.replace_all_uses_with(
                        replace_with=getitem_node,
                        delete_user_cb=lambda user: user.target != operator.getitem,
                    )

    gm.recompile()
    return gm, exported_progs
def skip_annotation(
    nn_module: torch.nn.Module,
    quantizer,
    partitioner,
    sample_input: Tuple[torch.Tensor, ...],
    calibration_cb: Callable[[torch.fx.GraphModule], None],
    fp_node_id_set: set = None,
    fp_node_op_set: set = None,
    fallback_to_cpu: bool = True,
):
    r"""
    Exclude specific operators from quantizer annotation.
    Skipped operators will stay in CPU by default, set 'fallback_to_cpu'
    to False for trying to delegate them with FP16 precision.

    e.g.: consider following graph:
    bias_1 weight_1 input_1   bias_2 weight_2 input_2
       | (placeholder) |         | (placeholder) |
        \      |      /           \      |      /
         \     |     /             \     |     /
          \    |    /               \    |    /
           conv2d_1                  conv2d_2
           (torch.ops.aten.conv2d.default)
               \                        /
                \                      /
                 \_______      _______/
                          add_1
                (torch.ops.aten.add.default)
                            |
                          output

    If user wants to skip convolution op by names with
    'skip_node_id_set' = {"conv2d_1"}
    "bias_1 / weight_1 / input_1 / input_2 / conv2d_1"
    will be partitioned out and not annotated / lowered with QNN.

    [Generated graph]
    bias_1 weight_1 input_1   input_2
       | (placeholder) |         |
        \      |      /          |
         \     |     /           |
          \    |    /            |
           conv2d_1              |
               \                /
                \              /
                 \            /
               lowered_module_1
              (QNN fixed precision)
                      |
                    output

    If user wants to skip convolution op by target with
    'skip_node_op_set' = {torch.ops.aten.conv2d.default}
    "bias_1 / weight_1 / input_1 / conv2d_1,
     bias_2 / weight_2 / input_2 / conv2d_2"
    will be partitioned out and not annotated / lowered with QNN.

    [Generated graph]
    bias_1 weight_1 input_1   bias_2 weight_2 input_2
       | (placeholder) |         | (placeholder) |
        \      |      /           \      |      /
         \     |     /             \     |     /
          \    |    /               \    |    /
           conv2d_1                  conv2d_2
           (torch.ops.aten.conv2d.default)
               \                        /
                \                      /
                 \__                __/
                    lowered_module_1
                   (QNN fixed precision)
                            |
                          output

    If user wants to delegate the skipped conv2d from above graph
    with 'fallback_to_cpu' = False:

    [Generated graph]
       input_1         input_2
    (placeholder)   (placeholder)
          |               |
           \             /
          lowered_module_2
         (QNN fp16 precision)
                  |
                  |
          lowered_module_1
         (QNN fixed precision)
                  |
                output

    Args:
        nn_module (torch.nn.Module): The module to be lowered.
        quantizer (QnnQuantizer): Instance of QnnQuantizer.
        partitioner (QnnPartitioner): Instance of QnnPartitioner.
        sample_input ((torch.Tensor, ...)): Sample input tensors for graph exporting.
        calibration_cb (callable): Callback function for user-defined calibration.
        fp_node_id_set ({str, ...}): Set of operator names to be left in fp precision.
        fp_node_op_set ({torch.ops.aten.xxx, ...}): Set of operator targets to be left in fp precision.
        fallback_to_cpu (bool): Whether to lower skipped nodes to fp16 or not.

    Returns:
        exported_programs: List of programs lowered to QnnBackend (quantized graphs only).
    """
    from executorch.backends.qualcomm.serialization.qc_schema import (
        QnnExecuTorchHtpPrecision,
    )
    from executorch.backends.qualcomm.serialization.qc_schema_serialize import (
        flatbuffer_to_option,
    )
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
    from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner

    def prepare_subgm(subgm, subgm_name):
        # prepare current submodule for quantization annotation
        subgm_prepared = prepare_pt2e(subgm, quantizer)
        # overwrite this attribute or name will be set to "GraphModule"
        # we could not identify each submodule if action is not performed
        subgm_prepared.__class__.__name__ = subgm_name
        return subgm_prepared

    fp_node_id_set = fp_node_id_set if fp_node_id_set is not None else set()
    fp_node_op_set = fp_node_op_set if fp_node_op_set is not None else set()

    graph_module = torch.export.export(nn_module, sample_input).module()
    # define node support type
    capability_partitioner = CapabilityBasedPartitioner(
        graph_module,
        _AnnotationSkipper(fp_node_id_set, fp_node_op_set),
        allows_single_node_partition=True,
    )
    subgm_tag = "annotated_group"
    graph_module = _partition_graph_into_submodules(
        gm=graph_module,
        subgm_tag=subgm_tag,
        subgm_cb=prepare_subgm,
        ptn=capability_partitioner,
    )
    # perform calibration
    calibration_cb(graph_module)
    # convert sub modules which went through prepare_pt2e
    for node in graph_module.graph.nodes:
        if node.op == "call_module":
            graph_module.set_submodule(
                node.name, convert_pt2e(graph_module.get_submodule(node.name))
            )
    # canonicalize graph for lowering again
    graph_module, exported_progs = _canonicalize_graph_with_lowered_module(
        gm=graph_module,
        subgm_tag=subgm_tag,
        ptn=partitioner,
    )

    if not fallback_to_cpu:
        try:
            from executorch.exir.backend.partitioner import DelegationSpec

            # change HTP compiler spec for hardware to enable fp16
            qnn_option = generate_qnn_executorch_option(
                partitioner.compiler_specs_snapshot
            )
            compile_option = flatbuffer_to_option(qnn_option)
            htp_options = compile_option.backend_options.htp_options
            htp_options.precision = QnnExecuTorchHtpPrecision.kHtpFp16
            partitioner.delegation_spec = DelegationSpec(
                "QnnBackend",
                [
                    CompileSpec(
                        QCOM_QNN_COMPILE_SPEC, option_to_flatbuffer(compile_option)
                    )
                ],
            )
        # best-effort: fall back to CPU on any failure, but never a bare
        # except — KeyboardInterrupt / SystemExit must still propagate
        except Exception:
            print(
                "Failed to change HTP compiler spec with 'use_fp16' as True,"
                " skipped operators will fallback to cpu,"
            )
            return graph_module, exported_progs

        # try lowering skipped operator into fp16
        capability_partitioner = CapabilityBasedPartitioner(
            graph_module,
            _AnnotationSkipper(skip_annotated_submodule=True),
            allows_single_node_partition=True,
        )
        subgm_tag = "skipped_group"
        graph_module = _partition_graph_into_submodules(
            gm=graph_module,
            subgm_tag=subgm_tag,
            subgm_cb=lambda subgm, _: subgm,
            ptn=capability_partitioner,
        )
        graph_module, exported_progs_fp = _canonicalize_graph_with_lowered_module(
            gm=graph_module,
            subgm_tag=subgm_tag,
            ptn=partitioner,
        )
        exported_progs.extend(exported_progs_fp)

    return graph_module, exported_progs
def from_context_binary(  # noqa: C901
    ctx_path: str | bytes,
    op_name: str,
    soc_model: QcomChipset = QcomChipset.SM8650,
    custom_info: Dict = None,
):
    """
    Wrap a pre-built QNN context binary into an ExecuTorch program by
    declaring a custom op that loads it at runtime.

    Args:
        ctx_path: Path to a context binary file, or the raw binary bytes.
        op_name: Name for the generated custom op.
        soc_model: Target chipset used to build the dummy compile specs.
        custom_info: Optional dict with 'graph_inputs', 'graph_outputs'
            and 'graph_name' describing the binary's IO, for binaries
            that cannot be opened on host.

    Returns:
        Dict bundling the custom op, mimic module, exported program,
        IO tensors and the lowered edge program manager.
    """
    from pathlib import Path

    def implement_op(custom_op, op_name, outputs):
        # meta-device implementation returning zero tensors shaped like the
        # binary's outputs, so torch.export can trace through the custom op
        @torch.library.impl(
            custom_op, str(op_name), dispatch_key="CompositeExplicitAutograd"
        )
        def op_impl(inputs: List[torch.Tensor]):
            return tuple(
                torch.zeros(tuple(v.shape), device="meta", dtype=v.dtype)
                for v in outputs.values()
            )

    def build_graph(inputs, outputs):
        # custom op declaration
        inputs_str = "Tensor[] inputs"
        func_proto = f"{op_name}({inputs_str}) -> Any"
        custom_op = Library(OpContextLoader.namespace, "FRAGMENT")
        custom_op.define(func_proto)
        # custom op implementation
        implement_op(custom_op, op_name, outputs)

        # model architecture mimicking context binary
        class Model(torch.nn.Module):
            def forward(self, *inputs):
                return getattr(
                    getattr(torch.ops, OpContextLoader.namespace), op_name
                ).default(inputs)

        model = Model()
        prog = torch.export.export(model, tuple(inputs.values()))
        # bookkeeping for variables' life cycle
        return {
            "custom_op": custom_op,
            "custom_module": model,
            "exported_program": prog,
        }

    def build_tensor(tensors, dtype_map):
        # translate QNN tensor descriptors into zero-filled torch tensors
        ret = OrderedDict()
        for t in tensors:
            dtype = t.GetDataType()
            dtype_torch = dtype_map.get(dtype, None)
            assert dtype_torch is not None, f"unknown qnn data type {dtype}"
            ret[t.GetName()] = torch.zeros(tuple(t.GetDims()), dtype=dtype_torch)
        return ret

    def preprocess_binary(ctx_bin, compiler_specs):
        qnn_mgr = PyQnnManagerAdaptor.QnnManager(
            generate_qnn_executorch_option(compiler_specs),
        )
        return bytes(qnn_mgr.MakeBinaryInfo(ctx_bin))

    # dummy compiler spec would be fine, since we're not compiling
    backend_options = generate_htp_compiler_spec(use_fp16=False)
    compiler_specs = generate_qnn_executorch_compiler_spec(
        soc_model=soc_model,
        backend_options=backend_options,
        is_from_context_binary=True,
    )

    # raw bytes are used as-is; a path is read and wrapped with binary info
    ctx_bin = (
        ctx_path
        if not isinstance(ctx_path, str)
        else preprocess_binary(Path(f"{ctx_path}").read_bytes(), compiler_specs)
    )

    # reverse lookup torch dtype <- qnn type; quant types win via setdefault
    dtype_map = {}
    for type_map in (QNN_QUANT_TYPE_MAP, QNN_TENSOR_TYPE_MAP):
        for k, v in type_map.items():
            dtype_map.setdefault(v, k)

    if custom_info is not None:
        # since some context binaries might fail to open on host
        # if they are compiled with special flags:
        # e.g. weight sharing
        # use custom information here instead
        inputs = build_tensor(custom_info["graph_inputs"], dtype_map)
        outputs = build_tensor(custom_info["graph_outputs"], dtype_map)
        graph_name = custom_info["graph_name"]
    else:
        # get context-binary io tensor info through qnn manager
        qnn_mgr = PyQnnManagerAdaptor.QnnManager(
            generate_qnn_executorch_option(compiler_specs),
            ctx_bin,
        )
        assert qnn_mgr.Init().value == 0, "failed to load context binary"
        # assume we only have one graph in current context
        graph_name = qnn_mgr.GetGraphNames()[0]
        qnn_mgr.AllocateTensor(graph_name)
        inputs = build_tensor(qnn_mgr.GetGraphInputs(graph_name), dtype_map)
        outputs = build_tensor(qnn_mgr.GetGraphOutputs(graph_name), dtype_map)
        qnn_mgr.Destroy()

    # generate graph specific for loading context
    bundle_prog = build_graph(inputs, outputs)
    bundle_prog.update({"inputs": inputs, "outputs": outputs})
    edge_prog_mgr = to_edge(
        programs={graph_name: bundle_prog["exported_program"]},
        # do not alter name for custom op
        compile_config=EdgeCompileConfig(_use_edge_ops=False),
    )
    # update meta with context binary
    for n in edge_prog_mgr._edge_programs[graph_name].graph.nodes:
        if n.op == "call_function" and OpContextLoader.namespace in str(n.target):
            n.meta[OpContextLoader.meta_ctx_bin] = ctx_bin
            break

    bundle_prog["edge_program_manager"] = edge_prog_mgr.to_backend(
        QnnPartitioner(compiler_specs)
    )
    return bundle_prog
def draw_graph(title, path, graph_module: torch.fx.GraphModule):
    """Render 'graph_module' as an SVG named '<title>.svg' under 'path'."""
    drawer = passes.graph_drawer.FxGraphDrawer(graph_module, title)
    with open(f"{path}/{title}.svg", "wb") as svg_file:
        svg_file.write(drawer.get_dot_graph().create_svg())
def generate_multi_graph_program(
    compiler_specs: List[CompileSpec],
    processed_bytes: List[bytes],
    backend_config: ExecutorchBackendConfig = None,
) -> ExecutorchProgramManager:
    """
    Compile multiple QCIR graphs into one context binary and emit a single
    multi-method ExecuTorch program.

    Args:
        compiler_specs: Compile specs shared by all graphs.
        processed_bytes: QCIR payloads, one per graph.
        backend_config: Optional config forwarded to to_executorch.

    Returns:
        ExecutorchProgramManager holding one method per graph.
    """
    # compile multiple graphs in qcir into single context binary
    graph_inputs, graph_outputs = {}, {}
    qnn_mgr = PyQnnManagerAdaptor.QnnManager(
        generate_qnn_executorch_option(compiler_specs), processed_bytes
    )
    assert qnn_mgr.Init().value == 0, "failed to load processed bytes"
    binary_info = bytes(qnn_mgr.Compile())
    assert len(binary_info) != 0, "failed to generate QNN context binary"
    graph_names = qnn_mgr.GetGraphNames()
    for graph_name in graph_names:
        graph_inputs[graph_name] = qnn_mgr.GetGraphInputs(graph_name)
        graph_outputs[graph_name] = qnn_mgr.GetGraphOutputs(graph_name)
    qnn_mgr.Destroy()

    # build custom ops with different graph signatures
    compiler_options = flatbuffer_to_option(compiler_specs[0].value)
    bundle_progs = [
        from_context_binary(
            ctx_path=binary_info,
            op_name=f"loader_{graph_name}",
            soc_model=compiler_options.soc_info.soc_model,
            custom_info={
                "graph_inputs": graph_inputs[graph_name],
                "graph_outputs": graph_outputs[graph_name],
                "graph_name": graph_name,
            },
        )
        for graph_name in graph_names
    ]
    # leverage ExecutorchProgramManager for generating pte with multi-methods
    edge_prog_mgr = to_edge(
        programs={
            graph_name: bundle_prog["exported_program"]
            for graph_name, bundle_prog in zip(graph_names, bundle_progs)
        },
        # do not alter name for custom op
        compile_config=EdgeCompileConfig(_use_edge_ops=False),
    )
    # restore meta lost in generating EdgeProgramManager
    for graph_name in graph_names:
        for n in edge_prog_mgr._edge_programs[graph_name].graph.nodes:
            if graph_name in n.name:
                n.meta[OpContextLoader.meta_ctx_bin] = binary_info
                break
    return edge_prog_mgr.to_backend(QnnPartitioner(compiler_specs)).to_executorch(
        config=backend_config or ExecutorchBackendConfig()
    )
def generate_htp_compiler_spec(
    use_fp16: bool,
    use_dlbc: bool = False,
    use_multi_contexts: bool = False,
) -> QnnExecuTorchBackendOptions:
    """
    Helper function generating backend options for QNN HTP

    Args:
        use_fp16: If true, the model is compiled to QNN HTP fp16 runtime.
            Note that not all SoC support QNN HTP fp16. Only premium tier SoC
            like Snapdragon 8 Gen 1 or newer can support HTP fp16.
        use_dlbc: Deep Learning Bandwidth Compression allows inputs to be
            compressed, such that the processing bandwidth can be lowered.
        use_multi_contexts: When multiple contexts are generated inside the same
            pte, it is possible to reserve a single spill-fill allocation that
            could be re-used across all the splits.

    Returns:
        QnnExecuTorchBackendOptions: backend options for QNN HTP.
    """
    htp_options = QnnExecuTorchHtpBackendOptions()
    htp_options.precision = (
        QnnExecuTorchHtpPrecision.kHtpFp16
        if use_fp16
        else QnnExecuTorchHtpPrecision.kHtpQuantized
    )
    # This actually is not an option which can affect the compiled blob.
    # But we don't have other place to pass this option at execution stage.
    # TODO: enable voting mechanism in runtime and make this as an option
    htp_options.performance_mode = QnnExecuTorchHtpPerformanceMode.kHtpBurst
    htp_options.use_multi_contexts = use_multi_contexts
    htp_options.use_dlbc = use_dlbc
    return QnnExecuTorchBackendOptions(
        backend_type=QnnExecuTorchBackendType.kHtpBackend,
        htp_options=htp_options,
    )
def generate_qnn_executorch_compiler_spec(
    soc_model: QcomChipset,
    backend_options: QnnExecuTorchBackendOptions,
    debug: bool = False,
    saver: bool = False,
    online_prepare: bool = False,
    dump_intermediate_outputs: bool = False,
    profile: bool = False,
    optrace: bool = False,
    shared_buffer: bool = False,
    is_from_context_binary: bool = False,
    multiple_graphs: bool = False,
    graph_name: str = "forward",
) -> List[CompileSpec]:
    """
    Helper function generating compiler specs for Qualcomm AI Engine Direct

    Args:
        soc_model: The SoC you plan to run the compiled model. Please check
            QcomChipset for supported SoC.
            SM8450 (Snapdragon 8 Gen 1)
            SM8475 (Snapdragon 8 Gen 1+)
            SM8550 (Snapdragon 8 Gen 2)
            SM8650 (Snapdragon 8 Gen 3)
        backend_options: Options required by different backends.
        debug: Enable verbose logging. Disclaimer: this option must change in
            the near future.
        online_prepare: Compose QNN graph on device if set to True
        saver: Instead of compiling the model, run QNN Saver. Please check
            documents of Qualcomm AI Engine Direct SDK. This feature is usually
            for debugging purpose.
        dump_intermediate_outputs: If tensor dump is enabled, all intermediate tensors output will be dumped.
            This option exists for debugging accuracy issues
        profile: Enable profile the performance of per operator.
            Note that for now only support kProfileDetailed to
            profile the performance of each operator with cycle unit.
        optrace: Enable QNN optrace profiling; takes precedence over 'profile'.
        shared_buffer: Enables usage of shared buffer between application
            and backend for graph I/O.
        is_from_context_binary: True if current graph comes from pre-built context binary.
        multiple_graphs: True if multiple methods are expected to have in single .pte file.
            Please see test cases for post-processing example.
        graph_name: Assign unique graph name if 'multiple_graphs' is used.

    Returns:
        List[CompileSpec]: Compiler specs for Qualcomm AI Engine Direct.

    Raises:
        ValueError: The value QcomChipset is currently not supported.
        ValueError: Conflict between compiler specs.
    """
    # use a distinct loop variable — the original comprehension shadowed
    # the 'soc_model' parameter, which was confusing to read
    _supported_soc_models = {chipset.value for chipset in QcomChipset}
    if soc_model not in _supported_soc_models:
        raise ValueError(f"unknown SoC model for QNN: {soc_model}")

    if profile and dump_intermediate_outputs:
        warnings.warn(
            "It is not recommended to turn on both profiling and dump_intermediate_outputs the same time"
            ", because dump_intermediate_outputs will cause performance drop.",
            stacklevel=1,
        )

    qnn_executorch_options = QnnExecuTorchOptions(
        _soc_info_table[soc_model], backend_options
    )
    qnn_executorch_options.graph_name = graph_name
    qnn_executorch_options.log_level = (
        QnnExecuTorchLogLevel.kLogLevelDebug
        if debug
        else QnnExecuTorchLogLevel.kLogLevelWarn
    )

    qnn_executorch_options.dump_intermediate_outputs = dump_intermediate_outputs

    if saver:
        # route all QNN API calls through the Saver backend library
        qnn_executorch_options.library_path = "libQnnSaver.so"

    # optrace > profile > off
    if optrace:
        qnn_executorch_options.profile_level = QnnExecuTorchProfileLevel.kProfileOptrace
    elif profile:
        qnn_executorch_options.profile_level = (
            QnnExecuTorchProfileLevel.kProfileDetailed
        )
    else:
        qnn_executorch_options.profile_level = QnnExecuTorchProfileLevel.kProfileOff

    if (
        online_prepare
        and backend_options.backend_type == QnnExecuTorchBackendType.kHtpBackend
        and backend_options.htp_options.use_multi_contexts
    ):
        raise ValueError(
            "'use_multi_context' could not function in online prepare mode, "
            "please set 'online_prepare' to False"
        )

    qnn_executorch_options.shared_buffer = shared_buffer
    qnn_executorch_options.online_prepare = online_prepare
    qnn_executorch_options.is_from_context_binary = is_from_context_binary
    qnn_executorch_options.multiple_graphs = multiple_graphs

    if multiple_graphs:
        # enable weight sharing mechanism if multiple graphs appear
        if backend_options.backend_type == QnnExecuTorchBackendType.kHtpBackend:
            backend_options.htp_options.use_weight_sharing = True

    return [
        CompileSpec(QCOM_QNN_COMPILE_SPEC, option_to_flatbuffer(qnn_executorch_options))
    ]
def get_soc_to_arch_map():
    """Return mapping from SoC model name to its HTP architecture version."""
    soc_arch_pairs = (
        ("SSG2115P", HtpArch.V73),
        ("SM8650", HtpArch.V75),
        ("SM8550", HtpArch.V73),
        ("SM8475", HtpArch.V69),
        ("SM8450", HtpArch.V69),
        ("SA8295", HtpArch.V68),
    )
    return dict(soc_arch_pairs)
def get_soc_to_chipset_map():
    """Return mapping from SoC model name to its QcomChipset enum member."""
    soc_chipset_pairs = (
        ("SSG2115P", QcomChipset.SSG2115P),
        ("SM8650", QcomChipset.SM8650),
        ("SM8550", QcomChipset.SM8550),
        ("SM8475", QcomChipset.SM8475),
        ("SM8450", QcomChipset.SM8450),
        ("SA8295", QcomChipset.SA8295),
    )
    return dict(soc_chipset_pairs)
def tag_quant_io(gm: torch.fx.GraphModule, get_quant_io_dtype_fn: Callable):
    """
    Tag io nodes which get/output quantized tensor. No need to insert q/dq in qnn_preprocess
    """
    for graph_node in gm.graph.nodes:
        quant_dtype = get_quant_io_dtype_fn(graph_node)
        if quant_dtype:
            graph_node.meta[QCOM_QUANTIZED_IO] = quant_dtype