# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
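"""Ahead-of-time preprocessing for the XNNPACK delegate: lowers an
Edge-dialect ExportedProgram into a serialized XNNGraph payload that the
XNNPACK runtime executes."""
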
import logging
from dataclasses import dataclass
from typing import Dict, final, List
import torch
from executorch.backends.xnnpack.operators.node_visitor import get_node_visitors
from executorch.backends.xnnpack.passes import XNNPACKPassManager
from executorch.backends.xnnpack.passes.convert_to_linear import ConvertToLinearPass
from executorch.backends.xnnpack.passes.tag_implicit_q_dq_pass import TagImplicitQDqPass
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    ConstantDataOffset,
    XNNGraph,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_serialize import (
    serialize_xnnpack_binary,
)
from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
from executorch.backends.xnnpack.utils.utils import is_param_node
from executorch.backends.xnnpack.utils.xnnpack_constants import (
    XNN_VALUE_FLAG_EXTERNAL_INPUT,
    XNN_VALUE_FLAG_EXTERNAL_OUTPUT,
)
from executorch.exir.backend.backend_details import (
    BackendDetails,
    CompileSpec,
    PreprocessResult,
)
from executorch.exir.verification.verifier import EXIREdgeDialectVerifier
from torch.export.exported_program import ExportedProgram

DEFAULT_DEBUG_HANDLE = 65535

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)


@dataclass
class ExternalMeta:
    external_id: int
    io_type: int


def generate_node_to_external_map(
    exported_program: ExportedProgram,
    edge_graph_module: torch.fx.GraphModule,
) -> Dict[torch.fx.Node, ExternalMeta]:
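    """Map each runtime-visible placeholder and each output node to an
    ExternalMeta carrying its external id and I/O type flag.

    Parameters and buffers are excluded because they disappear from the
    forward(*args) signature at runtime.
    """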
    node_to_external_map = {}
    for node in edge_graph_module.graph.nodes:
        # The order in which we visit the placeholder nodes is the same as
        # the *args order for the forward(*args) signature of this gm. Using
        # the order of the nodes as external_id lets us extract the right arg
        # from *args at runtime.
        #
        # Parameters/buffers are skipped since they will disappear from the
        # signature at runtime.
        if node.op == "placeholder" and not is_param_node(exported_program, node):
            node_to_external_map[node] = ExternalMeta(
                external_id=len(node_to_external_map),
                io_type=XNN_VALUE_FLAG_EXTERNAL_INPUT,
            )
    for node in edge_graph_module.graph.nodes:
        if node.op == "output":
            for output_nodes in node.args:
                for output_node in output_nodes:
                    node_to_external_map[output_node] = ExternalMeta(
                        external_id=len(node_to_external_map),
                        io_type=XNN_VALUE_FLAG_EXTERNAL_OUTPUT,
                    )
    return node_to_external_map


def assert_default_dim_order(edge_graph_module: torch.fx.GraphModule) -> None:
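    """Raise a RuntimeError if any placeholder tensor does not use the
    default (contiguous) dim order, which is all the XNNPACK delegate
    supports for inputs."""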
    for node in edge_graph_module.graph.nodes:
        if node.op != "placeholder":
            continue

        # We expect the default dim order for all tensor-like inputs,
        # i.e. inputs, buffers, and params.
        t = node.meta.get("val", None)
        if t is not None and getattr(t, "dim_order", None) is not None:
            default_dim_order = tuple(range(t.dim()))
            if t.dim_order() != default_dim_order:
                raise RuntimeError(
                    "XNNPACK backend only supports contiguous memory format "
                    f"for inputs. Expecting dim_order: {default_dim_order}, "
                    f"but got {node.meta['val'].dim_order()} for placeholder "
                    f"node {node}."
                )


@final
class XnnpackBackend(BackendDetails):
    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
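        """Serialize the given Edge-dialect program into an XNNPACK delegate
        payload: run the XNNPACK-specific passes, map graph I/O to external
        ids, visit every call_function node to build an XNNGraph, and return
        the serialized bytes as a PreprocessResult."""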
        xnnpack_edge_compile_config = get_xnnpack_edge_compile_config()

        # We need to wrap the EP here because XNNPACK applies addmm-to-linear
        # transforms, which make the resulting graph non-ATen-compliant since
        # aten.linear is not a core ATen op.
        # The ideal fix would be an XNNPACK verifier that bypasses most
        # checks, but the base Verifier itself has some strict checks, and to
        # bypass those we would basically have to copy what
        # EdgeDialectVerifier does. So for now, instead of copy-pasting that,
        # we just instantiate EdgeDialectVerifier but disable it.
        # TODO (task link) to implement NullVerifier or something similar
        ep = ExportedProgram(
            root=edge_program.graph_module,
            graph=edge_program.graph,
            graph_signature=edge_program.graph_signature,
            state_dict=edge_program.state_dict,
            range_constraints=edge_program.range_constraints,
            module_call_graph=edge_program.module_call_graph,
            example_inputs=edge_program.example_inputs,
            constants=edge_program.constants,
            verifiers=[
                EXIREdgeDialectVerifier(
                    edge_compile_config=xnnpack_edge_compile_config, class_only=True
                )
            ],
        )

        passes = []
        for spec in compile_specs:
            if spec.key == "dqlinear_partitioner":
                passes.append(ConvertToLinearPass)
                passes.append(TagImplicitQDqPass)
        passes = passes if len(passes) > 0 else None

        # XNNPACK-delegate-specific passes
        ep = XNNPACKPassManager(ep, passes=passes).transform()
        graph_module = ep.graph_module

        node_to_external_map = generate_node_to_external_map(ep, graph_module)

        # Make sure all inputs use contiguous_format, i.e. NCHW, the default
        # dim order.
        assert_default_dim_order(graph_module)

        # TODO: retrace the graph module to lift the new params that may have
        # been added to the graph by the passes
        vals_to_ids = {}

        xnnpack_graph = XNNGraph(
            version="0",
            xnodes=[],
            xvalues=[],
            num_externs=len(node_to_external_map),
            input_ids=[],
            output_ids=[],
            constant_data=[ConstantDataOffset(0, 0)],
        )

        constant_data_bytes = bytearray()
        node_visitors = get_node_visitors(ep, node_to_external_map, constant_data_bytes)

        for node in graph_module.graph.nodes:
            if node.op == "call_function":
                logger.info(f"Visiting: {node}, {node.target.__name__}")
                if node.target.__name__ in node_visitors:
                    node_visitors[node.target.__name__].define_node(
                        node,
                        xnnpack_graph,
                        vals_to_ids,
                        node.meta.get("debug_handle", DEFAULT_DEBUG_HANDLE),
                    )
                else:
                    raise RuntimeError(
                        f"For {node}, {node.op}:{node.target.__name__} "
                        f"is not supported in XNNPACK Delegate"
                    )
            elif node.op in [
                "get_attr",
                "placeholder",
                "output",
            ]:
                continue
            else:
                raise RuntimeError(f"{node.op} is not supported in XNNPACK")

        return PreprocessResult(
            processed_bytes=serialize_xnnpack_binary(
                xnnpack_graph, constant_data_bytes
            ),
            debug_handle_map={},
        )
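

# A minimal usage sketch, assuming the standard ExecuTorch lowering flow (the
# partitioner import path below is an assumption and may vary across versions):
#
#     import torch
#     from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
#         XnnpackPartitioner,
#     )
#     from executorch.exir import to_edge
#
#     edge = to_edge(torch.export.export(model, example_inputs))
#     # to_backend invokes XnnpackBackend.preprocess on each partitioned subgraph
#     edge = edge.to_backend(XnnpackPartitioner())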