# blob: 0575137cbc34b9814269350d46196783005b6229 [file]
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
from collections import defaultdict
from typing import final, List
import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager
import torch # noqa: F401
from executorch.backends.qualcomm._passes.fuse_consecutive_transpose import (
FuseConsecutiveTranspose,
)
from executorch.backends.qualcomm._passes.insert_io_qdq import InsertIOQDQ
from executorch.backends.qualcomm._passes.insert_requantize import InsertRequantize
from executorch.backends.qualcomm._passes.layout_transform import LayoutTransform
from executorch.backends.qualcomm.builders.node_visitor import get_node_visitors
from executorch.backends.qualcomm.builders.qnn_constants import OpContextLoader
from executorch.backends.qualcomm.partition.utils import generate_qnn_executorch_option
from executorch.exir.backend.backend_details import (
BackendDetails,
CompileSpec,
PreprocessResult,
)
from executorch.exir.passes import PassManager
from torch.export.exported_program import ExportedProgram
# Default/fallback debug handle value (0xFFFF, max unsigned 16-bit).
# NOTE(review): not referenced elsewhere in this file — presumably consumed
# by callers that map delegate nodes to debug handles; confirm before removing.
DEFAULT_DEBUG_HANDLE = 65535
# Module-level logger; DEBUG is forced here at import time, overriding any
# level inherited from the root logger configuration.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@final
class QnnBackend(BackendDetails):
    """Backend that lowers an exported edge program to a QNN context binary
    for execution on the Qualcomm delegate."""

    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
        """Compile ``edge_program`` into a QNN context binary.

        Args:
            edge_program: The exported program to lower.
            compile_specs: Backend compile specs, serialized into the QNN
                ExecuTorch option blob handed to the QNN manager.

        Returns:
            A ``PreprocessResult`` whose ``processed_bytes`` holds the QNN
            context binary. ``debug_handle_map`` is currently unused by QNN
            ExecuTorch and left empty.

        Raises:
            RuntimeError: If the graph contains a node the QNN delegate does
                not support.
        """
        option = generate_qnn_executorch_option(compile_specs)
        qnn_manager = PyQnnManager.QnnManager(option)
        qnn_manager.Init()

        # QNN-delegate-specific passes, applied in order to the edge graph.
        qnn_compiler_passes = PassManager(
            passes=[
                InsertRequantize(edge_program),
                InsertIOQDQ(edge_program),
                LayoutTransform(edge_program, insert_permute=True),
                FuseConsecutiveTranspose(),
            ]
        )

        pass_result = qnn_compiler_passes(edge_program.graph_module)
        assert pass_result is not None

        enable_tensor_dump = qnn_manager.IsTensorDump()
        nodes_to_wrappers = defaultdict(dict)
        node_visitors = get_node_visitors(
            edge_program, enable_tensor_dump=enable_tensor_dump
        )
        py_op_wrapper_list = []
        for node in pass_result.graph_module.graph.nodes:
            if node.op == "call_function":
                # Lazy %-style args: avoids formatting cost when INFO is off.
                logger.info("Visiting: %s, %s", node, node.target.__name__)
                if node.target.__name__ in node_visitors:
                    py_op_wrapper = node_visitors[node.target.__name__].define_node(
                        node, nodes_to_wrappers
                    )
                    if py_op_wrapper is not None:
                        # A visitor may emit a single wrapper or a list of them.
                        if isinstance(py_op_wrapper, list):
                            py_op_wrapper_list.extend(py_op_wrapper)
                        else:
                            py_op_wrapper_list.append(py_op_wrapper)
                else:
                    err_msg = (
                        f"For {node}, {node.op}:{node.target.__name__} "
                        "is not supported in Qnn Delegate"
                    )
                    try:
                        # Resolve the op inside the context-loader namespace.
                        # The eval only needs the name `torch`; give it a
                        # minimal globals dict instead of the original
                        # `globals().update(torch.__dict__)`, which mutated
                        # this module's globals as a side effect. The input
                        # here is a graph-node name, not untrusted user data.
                        context_loader_target = eval(  # noqa: S307
                            f"torch.ops.{OpContextLoader.namespace}.{node.target.__name__}",
                            {"torch": torch},
                        )
                        assert node.target == context_loader_target, err_msg
                        # if graph has context binary loader node, return directly
                        return PreprocessResult(
                            processed_bytes=node.meta[OpContextLoader.meta_ctx_bin],
                            debug_handle_map={},
                        )
                    except Exception as e:
                        # Narrowed from a bare `except:` (which also caught
                        # KeyboardInterrupt/SystemExit) and chained so the
                        # original failure is preserved in the traceback.
                        raise RuntimeError(err_msg) from e
            elif node.op in [
                "get_attr",
                "placeholder",
                "output",
            ]:
                # Handled implicitly by the visitors / runtime; nothing to emit.
                continue
            else:
                raise RuntimeError(f"{node.op} is not supported in Qnn")

        qnn_context_binary = qnn_manager.Compile(
            qnn_manager.GetGraphNames()[0],
            [py_op_wrapper.GetOpWrapper() for py_op_wrapper in py_op_wrapper_list],
        )
        assert len(qnn_context_binary) != 0, "Failed to generate Qnn context binary."
        qnn_manager.Destroy()
        # For now, debug_handle_map is not used by QNN ExecuTorch
        return PreprocessResult(
            processed_bytes=bytes(qnn_context_binary),
            debug_handle_map={},
        )