blob: c31ebba0e468dab473407b826c01f41119279a99 [file]
#
# Copyright (c) 2023 Apple Inc. All rights reserved.
# Provided subject to the LICENSE file in the top level directory.
#
import operator
from typing import cast, Optional, Union

import torch
from torch._export.utils import get_buffer, get_param, is_buffer, is_param

from executorch.backends.apple.mps.serialization.mps_graph_schema import MPSDataType
from executorch.exir import ExportedProgram
def get_input_node(node: torch.fx.Node, input_index: int) -> Union[torch.fx.Node, None]:
    """
    Fetch the positional argument at ``input_index`` of ``node``, typed as an FX node.

    Args:
        node (torch.fx.Node): Node whose positional arguments are read. May be None.
        input_index (int): Position within ``node.args`` to read.

    Returns:
        Union[torch.fx.Node, None]: None when ``node`` is None, otherwise
        ``node.args[input_index]``. The ``cast`` is a typing hint only; the
        argument is not checked to actually be a ``torch.fx.Node``.
    """
    if node is None:
        return None
    selected_arg = node.args[input_index]
    return cast(torch.fx.Node, selected_arg)
def get_scalar_val(node: torch.fx.Node, input_index: int) -> Union[float, int]:
    """
    Return the raw positional argument at ``input_index`` of ``node``.

    The value is returned exactly as stored in ``node.args``; no conversion or
    validation is done. Per the return annotation, callers presumably use this
    for scalar (int/float) operands — the code itself does not enforce that.
    """
    positional_args = node.args
    return positional_args[input_index]
def edge_dtype_to_mps_dtype(dtype: torch.dtype) -> "MPSDataType":
    """
    Translate a torch dtype into the corresponding MPS serialization data type.

    The lookup table is built lazily on the first call and cached as the
    ``map`` attribute of this function, so subsequent calls reuse it.

    Args:
        dtype (torch.dtype): The dtype to translate.

    Returns:
        MPSDataType: The matching MPS data type. Note that ``torch.float64`` is
        mapped to ``mps_data_type_float32``; there is no float64 entry.

    Raises:
        RuntimeError: If ``dtype`` has no entry in the table.
    """
    if not hasattr(edge_dtype_to_mps_dtype, "map"):
        edge_dtype_to_mps_dtype.map = {
            torch.float16: MPSDataType.mps_data_type_float16,
            torch.float32: MPSDataType.mps_data_type_float32,
            # NOTE(review): float64 is deliberately narrowed to float32,
            # presumably because MPS has no float64 support — confirm.
            torch.float64: MPSDataType.mps_data_type_float32,
            torch.bfloat16: MPSDataType.mps_data_type_bfloat16,
            torch.int8: MPSDataType.mps_data_type_int8,
            torch.int16: MPSDataType.mps_data_type_int16,
            torch.int32: MPSDataType.mps_data_type_int32,
            torch.int64: MPSDataType.mps_data_type_int64,
            torch.uint8: MPSDataType.mps_data_type_uint8,
            torch.bool: MPSDataType.mps_data_type_bool,
            torch.cfloat: MPSDataType.mps_data_type_complex_float32,
            torch.chalf: MPSDataType.mps_data_type_complex_float16,
        }
    try:
        return edge_dtype_to_mps_dtype.map[dtype]
    except KeyError as err:
        # Chain explicitly so the traceback states the KeyError is the cause,
        # rather than "another exception occurred while handling" it.
        raise RuntimeError(f"Invalid data type: {dtype}") from err
def get_param_tensor(
    exp_prog: ExportedProgram, node: torch.fx.Node
) -> Optional[torch.Tensor]:
    """
    Fetch the constant data (parameter, buffer or attribute) that ``node`` refers to.

    Args:
        exp_prog (ExportedProgram): Program whose graph signature and graph
            module are used to resolve the data.
        node (torch.fx.Node): Node to resolve; may be None.

    Returns:
        Optional[torch.Tensor]: None if ``node`` is None. For a lifted parameter
        or buffer of ``exp_prog``, the value returned by ``get_param`` /
        ``get_buffer``. For a ``get_attr`` node, the attribute named by
        ``node.target``.

    Raises:
        AttributeError: If a ``get_attr`` target is found neither on the node's
            owning module nor on ``exp_prog.graph_module``.
        RuntimeError: If ``node`` is none of the supported kinds.
    """
    if node is None:
        return None
    if is_param(exp_prog, node):
        return get_param(exp_prog, node)
    if is_buffer(exp_prog, node):
        return get_buffer(exp_prog, node)
    if is_get_attr(node):
        # get_attr targets are fully qualified names that may be dotted
        # (e.g. "sub.weight"). Plain getattr() cannot follow the dots;
        # operator.attrgetter resolves each component in turn and still raises
        # AttributeError when any component is missing.
        fetch_attr = operator.attrgetter(cast(str, node.target))
        # Support both lifted and unlifted graph
        try:
            # Unlifted graph (coming from old exir.capture API)
            return fetch_attr(node.graph.owning_module)
        except AttributeError:
            return fetch_attr(exp_prog.graph_module)
    raise RuntimeError(f"unsupported param type, {node.op}.")
def is_get_attr(node: torch.fx.Node):
    """
    Report whether ``node`` is an FX node whose opcode is ``"get_attr"``.

    Any value that is not a ``torch.fx.Node`` (including None) yields False.
    The kind of the referenced attribute (tensor or otherwise) is not inspected.
    """
    if not isinstance(node, torch.fx.Node):
        return False
    return node.op == "get_attr"
def is_parameter(exp_prog: torch.export.ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Check whether ``node`` refers to static data such as weights, biases or buffers.

    Both graph styles are recognized. In a lifted graph, parameters and buffers
    are supplied as graph inputs and recorded in ``exp_prog``'s graph
    signature. In an unlifted graph, they are read through ``get_attr`` nodes.

    Args:
        exp_prog (torch.export.ExportedProgram): Program whose graph signature
            is consulted for lifted parameters and buffers.
        node (torch.fx.Node): Candidate node. Values that are not FX nodes
            (e.g. None or a Python scalar taken from ``node.args``) are
            accepted and yield False.

    Returns:
        bool: True if ``node`` is a ``get_attr`` node, a lifted parameter, or a
        lifted buffer of ``exp_prog``; False otherwise.
    """
    if not isinstance(node, torch.fx.Node):
        # is_param / is_buffer are typed to take an FX node; reject anything
        # else up front instead of letting it fail inside those helpers.
        return False
    return is_get_attr(node) or is_param(exp_prog, node) or is_buffer(exp_prog, node)