| # |
| # Copyright (c) 2023 Apple Inc. All rights reserved. |
| # Provided subject to the LICENSE file in the top level directory. |
| # |
| |
| from typing import cast, Optional, Union |
| |
| import torch |
| from executorch.backends.apple.mps.serialization.mps_graph_schema import MPSDataType |
| from executorch.exir import ExportedProgram |
| from torch._export.utils import get_buffer, get_param, is_buffer, is_param |
| |
| |
def get_input_node(node: torch.fx.Node, input_index: int) -> Union[torch.fx.Node, None]:
    """Fetch the positional argument of ``node`` at ``input_index``.

    Args:
        node: The FX node whose ``args`` are inspected. ``None`` is tolerated.
        input_index: Position inside ``node.args`` to read.

    Returns:
        ``None`` when ``node`` itself is ``None``; otherwise the entry of
        ``node.args`` at ``input_index``. The ``cast`` is purely for static
        type checkers -- the value is returned unchanged at runtime.
    """
    if node is None:
        return None
    operand = node.args[input_index]
    return cast(torch.fx.Node, operand)
| |
| |
def get_scalar_val(node: torch.fx.Node, input_index: int) -> Union[float, int]:
    """Read the positional argument of ``node`` at ``input_index``.

    Args:
        node: The FX node whose ``args`` are inspected.
        input_index: Position inside ``node.args`` to read.

    Returns:
        The entry of ``node.args`` at ``input_index``, returned as-is. No check
        is made here that the value actually is a float or an int.
    """
    scalar = node.args[input_index]
    return scalar
| |
| |
def edge_dtype_to_mps_dtype(dtype: torch.dtype):
    """Translate a ``torch.dtype`` into the matching ``MPSDataType`` entry.

    The lookup table is built on the first call and stored on the function
    object as the ``map`` attribute, so later calls reuse it.

    Args:
        dtype: The torch dtype to translate.

    Returns:
        The ``MPSDataType`` value registered for ``dtype``. Note that
        ``torch.float64`` is mapped to the float32 entry.

    Raises:
        RuntimeError: If ``dtype`` has no entry in the table.
    """
    try:
        table = edge_dtype_to_mps_dtype.map
    except AttributeError:
        # First call: build the table and cache it on the function itself.
        table = {
            torch.float16: MPSDataType.mps_data_type_float16,
            torch.float32: MPSDataType.mps_data_type_float32,
            torch.float64: MPSDataType.mps_data_type_float32,
            torch.bfloat16: MPSDataType.mps_data_type_bfloat16,
            torch.int8: MPSDataType.mps_data_type_int8,
            torch.int16: MPSDataType.mps_data_type_int16,
            torch.int32: MPSDataType.mps_data_type_int32,
            torch.int64: MPSDataType.mps_data_type_int64,
            torch.uint8: MPSDataType.mps_data_type_uint8,
            torch.bool: MPSDataType.mps_data_type_bool,
            torch.cfloat: MPSDataType.mps_data_type_complex_float32,
            torch.chalf: MPSDataType.mps_data_type_complex_float16,
        }
        edge_dtype_to_mps_dtype.map = table

    try:
        mps_dtype = table[dtype]
    except KeyError:
        raise RuntimeError(f"Invalid data type: {dtype}")
    return mps_dtype
| |
| |
| def get_param_tensor( |
| exp_prog: ExportedProgram, node: torch.fx.Node |
| ) -> Optional[torch.Tensor]: |
| if node is None: |
| return None |
| elif is_param(exp_prog, node): |
| return get_param(exp_prog, node) |
| elif is_buffer(exp_prog, node): |
| return get_buffer(exp_prog, node) |
| elif is_get_attr(node): |
| # Support both lifted and unlifted graph |
| try: |
| # Unlifted graph (coming from old exir.capture API) |
| return getattr(node.graph.owning_module, node.target) |
| except AttributeError: |
| return getattr(exp_prog.graph_module, node.target) |
| raise RuntimeError(f"unsupported param type, {node.op}.") |
| |
| |
def is_get_attr(node: torch.fx.Node):
    """
    Tell whether ``node`` is an FX node whose op is "get_attr".

    Anything that is not a ``torch.fx.Node`` (including ``None``) yields False.
    """
    if not isinstance(node, torch.fx.Node):
        return False
    return node.op == "get_attr"
| |
| |
def is_parameter(exp_prog: torch.export.ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Check whether a node refers to static data of the program (e.g. weights
    and biases), whether or not that data has been lifted to a graph input.

    Args:
        exp_prog (torch.export.ExportedProgram): The exported program that
            owns the graph containing ``node``.
        node (torch.fx.Node): The node to classify.

    Returns:
        bool: True if ``node`` is a "get_attr" node, or if ``is_param`` or
            ``is_buffer`` reports it as a parameter or buffer of ``exp_prog``;
            False otherwise.
    """
    # Cheapest check first; the later lookups only run if it fails.
    if is_get_attr(node):
        return True
    if is_param(exp_prog, node):
        return True
    return is_buffer(exp_prog, node)