blob: 8362774fa978e689d46c96fc16e892909ecc6f6b [file]
#
# Copyright (c) 2023 Apple Inc. All rights reserved.
# Provided subject to the LICENSE file in the top level directory.
#
import logging
from typing import ClassVar, Dict, final, List, Tuple
import torch
from executorch.backends.apple.mps.operators.node_visitor import (
get_node_visitors,
NodeVisitor,
process_output_node,
process_placeholder_nodes,
)
from executorch.backends.apple.mps.serialization.mps_graph_schema import (
Buffer,
DataSegment,
MPSGraph,
MPSTensor,
OpType,
)
from executorch.backends.apple.mps.serialization.mps_graph_serialize import (
convert_to_flatbuffer,
)
from executorch.exir._serialize._program import Cord
from executorch.exir.backend.backend_details import (
BackendDetails,
CompileSpec,
PreprocessResult,
)
from torch.export.exported_program import ExportedProgram
# Module-level side effect: configure root logging once at import time.
# Log lines look like: "[INFO <timestamp> <file>:<line>] <message>".
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)
@final
class MPSBackend(BackendDetails):
    """Ahead-of-time backend that serializes an edge-dialect ExportedProgram
    into the MPS delegate's binary payload.

    The payload produced by ``preprocess`` is laid out as:
        header (EXPECTED_LENGTH bytes) | flatbuffer graph | zero padding |
        constant data segment
    where the header records the offset and size of the constant segment.
    """

    @staticmethod
    def slice_len_max(s):
        """Return the number of bytes covered by slice `s` (step defaults
        to 1), clamped to at least 1.  Used at class-body level below to
        compute the header length from the field slices."""
        assert s.start is not None
        assert s.stop is not None
        step = 1
        if s.step is not None:
            step = s.step
        return max((s.stop - s.start) // step, 1)

    # Byte ranges of the header fields inside the serialized payload.
    MAGIC_IX: ClassVar[slice] = slice(4, 8)
    DATA_SEGMENT_OFFSET_IX: ClassVar[slice] = slice(8, 16)
    DATA_SEGMENT_SIZE_IX: ClassVar[slice] = slice(16, 24)
    # magic bytes that should be at the beginning of the header
    EXPECTED_MAGIC: ClassVar[bytes] = b"MP00"
    # The length of the header in bytes
    # NOTE(review): calling the staticmethod object directly in the class body
    # requires Python >= 3.10 — confirm against the project's minimum version.
    EXPECTED_LENGTH: ClassVar[int] = (
        4
        + slice_len_max(MAGIC_IX)
        + slice_len_max(DATA_SEGMENT_OFFSET_IX)
        + slice_len_max(DATA_SEGMENT_SIZE_IX)
    )

    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
        """Lower `edge_program` to the MPS delegate payload.

        Honors the optional "use_fp16" compile spec (first byte of its value
        is the flag; defaults to True).  Returns a PreprocessResult whose
        processed_bytes hold header + flatbuffer graph + constant segment.
        Raises RuntimeError for any FX node op with no registered handler.
        """
        # The EdgeIR nodes are processed in the following order:
        # 1. Process first the input feeds to the graph (in the same
        # order as args from forward(*args)), and generate a unique
        # id for each input placeholder. Each input id is appended to
        # `input_ids` array from the FlatBuffer schema.
        # 2. Process the nodes the graph (e.g `call_function`). For each
        # EdgeIR node, create an equivalent MPS node in the FlatBuffer,
        # based on which the MPSGraph is constructed at runtime. During
        # this process, any visited constant in the EdgeIR is added to the
        # final MPS FlatBuffer schema. Each constant id is appended to the
        # `constant_ids` FlatBuffer schema.
        # 3. After all the inputs, nodes and constants are added to the
        # FlatBuffer graph, process the `output` nodes and add their id to
        # the `output_ids` array in the schema.
        mps_graph = MPSGraph(
            version="0",
            mps_nodes=[],
            mps_values=[],
            input_ids=[],
            output_ids=[],
            constant_ids=[],
            graph_type=OpType.mps_graph,
            constant_segment=DataSegment(0, 0),
        )

        convert_model_to_fp16 = True
        for spec in compile_specs:
            if spec.key == "use_fp16":
                # spec.value is bytes-like; the first byte carries the flag.
                convert_model_to_fp16 = bool(list(bytes(spec.value))[0])

        logging.debug(f"Convert model to FP16: {convert_model_to_fp16}")
        node_visitors = get_node_visitors(edge_program, convert_model_to_fp16)
        if logging.DEBUG >= logging.root.level:
            edge_program.graph.print_tabular()

        # Step 1: register graph inputs (and visited constants) first.
        process_placeholder_nodes(
            edge_program,
            edge_program.graph_module,
            mps_graph,
            node_visitors["placeholder"],
        )

        # Dispatch table mapping FX node.op to the static handler below.
        op_handler = {
            "call_function": MPSBackend.handle_call_function,
            "placeholder": MPSBackend.handle_placeholder,
            "output": MPSBackend.handle_output,
            "get_attr": MPSBackend.handle_get_attr,
        }

        # Steps 2 and 3: walk every node in the graph in order.
        for node in edge_program.graph_module.graph.nodes:
            if node.op not in op_handler:
                raise RuntimeError(f"{node.op} is not supported in MPS")
            else:
                op_handler[node.op](edge_program, node_visitors, node, mps_graph)

        # Move all constant buffers out of the graph into one flat segment.
        segment_data, mps_graph = _extract_constant_segment(mps_graph)
        if logging.DEBUG >= logging.root.level:
            pretty_print(mps_graph)

        # Add to aggregate segments cord with padding.
        padding_length = _padding_required(len(segment_data), 16)
        if padding_length > 0:
            segment_data.append(b"\x00" * padding_length)

        # Combine mps_graph with segment data
        combined = Cord()
        graph_bytes = convert_to_flatbuffer(mps_graph)
        # The constant segment starts after the header, the graph bytes, and
        # any padding needed to keep the segment 16-byte aligned.
        data_segment_offset: int = MPSBackend.EXPECTED_LENGTH
        data_segment_offset = data_segment_offset + len(graph_bytes)
        graph_padding_length = _padding_required(data_segment_offset, 16)
        data_segment_offset = data_segment_offset + graph_padding_length
        data_segment_size = len(segment_data)
        # Header: 4 zero bytes, the magic, then the segment offset and size
        # as little-endian unsigned 64-bit integers (see the *_IX slices).
        data: bytes = (
            b"\x00\x00\x00\x00"
            + MPSBackend.EXPECTED_MAGIC
            + data_segment_offset.to_bytes(8, byteorder="little")
            + data_segment_size.to_bytes(8, byteorder="little")
        )
        assert len(data) == MPSBackend.EXPECTED_LENGTH

        combined.append(data)
        combined.append(graph_bytes)
        if graph_padding_length > 0:
            combined.append(b"\x00" * graph_padding_length)
        # Append the segment data to the end of the mps graph
        combined.append(segment_data)

        return PreprocessResult(processed_bytes=bytes(combined))

    @staticmethod
    def handle_call_function(
        _: ExportedProgram,
        node_visitors: Dict[str, NodeVisitor],
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Serialize one `call_function` node into `mps_graph` using the
        visitor registered for its target; raises RuntimeError (after dumping
        the graph) if no visitor supports the op."""
        logging.info(f"Visiting: {node}, {node.target.__name__}")

        # A "metal_kernel" delegation tag (set by the MPSPartitioner) switches
        # the whole graph to the metal-kernel execution type.
        if (
            "delegation_tag" in node.meta
            and "metal_kernel" in node.meta["delegation_tag"]
        ):
            logging.info(
                f"Node '{node.target.__name__}' was marked as a Metal kernel by the MPSPartitioner!"
            )
            mps_graph.graph_type = OpType.metal_kernel

        if node.target.__name__ in node_visitors:
            node_visitors[node.target.__name__].define_node(node, mps_graph)
        else:
            pretty_print(mps_graph)
            raise RuntimeError(
                f"For {node}, {node.op}:{node.target.__name__} is not supported in MPS delegate"
            )

    @staticmethod
    def handle_placeholder(
        edge_program: ExportedProgram,
        node_visitors: Dict[str, NodeVisitor],
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """No-op: placeholders were already serialized up front by
        `process_placeholder_nodes` in `preprocess`."""
        # Constants are handled directly when visiting the nodes.
        pass

    @staticmethod
    def handle_output(
        edge_program: ExportedProgram,
        node_visitors: Dict[str, NodeVisitor],
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Append the id of every graph output to `mps_graph.output_ids`.
        Each element of node.args is an iterable of output FX nodes."""
        for output_nodes in node.args:
            for output_node in output_nodes:
                process_output_node(output_node, mps_graph, node_visitors[node.op])

    @staticmethod
    def handle_get_attr(
        edge_program: ExportedProgram,
        node_visitors: Dict[str, NodeVisitor],
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """No-op: `get_attr` nodes require no serialization work here."""
        pass
def _padding_required(offset: int, alignment: int) -> int:
"""Returns the padding required to align `offset` to `alignment`."""
remainder: int = offset % alignment
if remainder != 0:
return alignment - remainder
return 0
def _extract_constant_segment(mps_graph: MPSGraph) -> Tuple[Cord, MPSGraph]:
    """Move every tensor's constant buffer out of `mps_graph` into one flat
    Cord, recording each tensor's offset into that segment, and return the
    segment data together with the (mutated) graph."""
    # NOTE: the start of the segment is not aligned; the caller is
    # responsible for padding.
    segment_data = Cord()
    cursor = 0
    for tensor in mps_graph.mps_values:
        # Tensors without constant data keep their buffer and offset as-is.
        if tensor.constant_buffer_size <= 0:
            continue
        # The buffer storage is already force-aligned, so no per-tensor padding.
        segment_data.append(tensor.constant_buffer.storage)
        # Drop the inline copy and point the tensor at its segment location.
        tensor.constant_buffer = Buffer(storage=b"")
        tensor.segment_offset = cursor
        cursor += tensor.constant_buffer_size
    return segment_data, mps_graph
def tensor_to_str(mps_tensor: MPSTensor):
    """Return a one-line human-readable summary of an MPSTensor."""
    return (
        "MPSTensor("
        f"datatype={mps_tensor.datatype}, "
        f"num_dims={mps_tensor.num_dims}, "
        f"dims={mps_tensor.dims}, "
        f"constant_buffer_size={mps_tensor.constant_buffer_size}, "
        f"segment_offset={mps_tensor.segment_offset}"
        ")"
    )
def pretty_print(mps_graph: MPSGraph):
    """Log a human-readable dump of a serialized MPSGraph at INFO level."""
    logging.info("Serialized MPSGraph:")
    logging.info(f" Version: {mps_graph.version}")
    logging.info(" MPS nodes: ")
    for idx, mps_node in enumerate(mps_graph.mps_nodes):
        logging.info(f" [{idx}]: {mps_node}")
    logging.info(" MPS values: ")
    for idx, mps_value in enumerate(mps_graph.mps_values):
        logging.info(f" [{idx}]: {tensor_to_str(mps_value)}")
    logging.info(" Input ids:")
    for in_id in mps_graph.input_ids:
        logging.info(f" {in_id}")
    logging.info(" Constant ids:")
    for constant_id in mps_graph.constant_ids:
        logging.info(f" {constant_id}")
    logging.info(" Output ids:")
    for out_id in mps_graph.output_ids:
        logging.info(f" {out_id}")
    logging.info(f" Constant segment: {mps_graph.constant_segment}")