blob: d38a67a656eb21f2856775ae59ca413747b5833c [file]
//
// Copyright (c) 2023 Apple Inc. All rights reserved.
// Provided subject to the LICENSE file in the top level directory.
//
// All types declared in this schema are generated into the `mpsgraph` namespace.
namespace mpsgraph;
// Update after any BC breaking changes
// FlatBuffers requires this identifier to be exactly 4 characters; it is
// embedded in serialized buffers so a reader can check which schema wrote them.
file_identifier "MP00";
// Element data type for mps-values (used by MPSTensor.datatype and by the
// dtype fields of MPSCast, _MPSFull, MPSArange and
// MPSDequantizePerChannelGroup).
// Stored as a 16-bit `short`. Every enumerator carries an explicit value, so
// the numeric encoding does not depend on declaration order.
enum MPSDataType : short {
// 0 is also what a reader sees when a field of this type was never written.
mps_data_type_invalid = 0,
// Floating point.
mps_data_type_float16 = 1,
mps_data_type_float32 = 2,
mps_data_type_float64 = 3,
mps_data_type_bfloat16 = 4,
// Signed integers.
mps_data_type_int4 = 5,
mps_data_type_int8 = 6,
mps_data_type_int16 = 7,
mps_data_type_int32 = 8,
mps_data_type_int64 = 9,
// Unsigned integers. range: [0, UTYPE_MAX]
mps_data_type_uint4 = 10,
mps_data_type_uint8 = 11,
mps_data_type_uint16 = 12,
mps_data_type_uint32 = 13,
mps_data_type_uint64 = 14,
// Boolean.
mps_data_type_bool = 15,
// Complex floating point.
mps_data_type_complex_float16 = 16,
mps_data_type_complex_float32 = 17,
}
// Type tag stored in MPSGraph.graph_type.
// Ops like index.Tensor and index.put are currently implemented as
// Metal kernels for unsupported MPSGraph cases.
// No explicit values are given, so FlatBuffers numbers the enumerators from 0
// in declaration order (mps_graph = 0, metal_kernel = 1); reordering them
// changes the serialized encoding.
enum OpType : short {
mps_graph,
metal_kernel
}
// Helper classes to define the number of input and output tensors for a node.
// Not meant to be used directly.
// A node with one input and one output.
// Payload type of the unary-op members of MPSNodeUnion.
// NOTE(review): the *_id fields presumably index into MPSGraph.mps_values —
// confirm against the serializer.
table _MPSNode1x1 {
input1_id:int;
output_id:int;
}
// A node with two inputs and one output.
// Payload type of the plain binary, MatMul and comparison members of
// MPSNodeUnion.
// NOTE(review): the *_id fields presumably index into MPSGraph.mps_values —
// confirm against the serializer.
table _MPSNode2x1 {
input1_id:int;
input2_id:int;
output_id:int;
}
// A node with two inputs and one output plus a rounding mode.
// Payload type of MPSDiv, MPSFmod and MPSRemainder in MPSNodeUnion.
table _MPSDivNode2x1 {
input1_id:int;
input2_id:int;
output_id:int;
// NOTE(review): the accepted strings are not defined by this schema;
// presumably they mirror the torch.div `rounding_mode` argument — confirm.
rounding_mode:string;
}
// A node with two inputs and one output plus a float `alpha`.
// Payload type of MPSAdd and MPSSub in MPSNodeUnion.
table _MPSNodeWithAlpha2x1 {
input1_id:int;
input2_id:int;
output_id:int;
// NOTE(review): presumably the multiplier applied to the second input, as in
// torch.add/sub(..., alpha=) — confirm. Reads as 0.0 when not serialized.
alpha:float;
}
// A node with three inputs and one output.
// Payload type of MPSWhere in MPSNodeUnion.
table _MPSNode3x1 {
input1_id:int;
input2_id:int;
input3_id:int;
output_id:int;
}
// A pair of float bounds attached to a node through MPSNode.min_max.
// NOTE(review): presumably consumed by clamp-style nodes (MPSClamp itself
// carries no bounds) — confirm against the runtime.
table MPSMinMax {
min_value:float;
max_value:float;
}
// Parameters of a 2D pooling node: one input, two outputs.
// Payload type of MPSMaxPool2DWithIndices and MPSAvgPool2D in MPSNodeUnion.
table MPSPooling2D {
input1_id:int;
kernel_height:int;
kernel_width:int;
stride_height:int;
stride_width:int;
// Padding is stored per side rather than as a symmetric pair.
padding_left:int;
padding_right:int;
padding_top:int;
padding_bottom:int;
dilation_height:int;
dilation_width:int;
ceil_mode:bool;
// NOTE(review): count_include_pad and divisor_override presumably only
// matter for average pooling — confirm.
count_include_pad:bool;
divisor_override:int;
output1_id:int;
// NOTE(review): presumably the indices output of max pooling and unused for
// average pooling — confirm.
output2_id:int;
}
// Activation ops.
// HardTanh activation node: one input, one output and its own float
// min/max bounds (separate from MPSNode.min_max).
table MPSHardTanh {
input1_id:int;
output_id:int;
min_value:float;
max_value:float;
}
// GELU activation node: one input, one output.
table MPSGELU {
input1_id:int;
output_id:int;
// NOTE(review): the accepted strings are not defined by this schema;
// presumably they mirror torch GELU's `approximate` argument — confirm.
approximate:string;
}
// LeakyReLU activation node: one input, one output and a float slope.
table MPSLeakyReLU {
input1_id:int;
output_id:int;
// Reads as 0.0 when not serialized (no schema default is declared).
negative_slope:float;
}
// Softmax node: one input, one output.
// Also the payload type of MPSLogSoftmax in MPSNodeUnion.
table MPSSoftmax {
input1_id:int;
output_id:int;
dim:int;
// NOTE(review): presumably mirrors the `half_to_float` flag of the ATen
// _softmax op — confirm.
half_to_float:bool;
}
// Clamp ops
// Clamp node: one input, one output. No bounds are stored here.
// NOTE(review): the bounds presumably come from MPSNode.min_max — confirm.
table MPSClamp {
input1_id:int;
output_id:int;
}
// Reduce ops
// Mean reduction node: one input, one output.
table MPSMean {
input1_id:int;
output_id:int;
// NOTE(review): presumably equals the length of `dims` — confirm.
num_dims:int;
dims:[int];
keep_dims:bool;
}
// Indexing ops
// Index-select node: one input, one output, a dimension and the id of the
// index value.
table MPSIndexSelect {
input1_id:int;
output_id:int;
dim:int;
index_id:int;
}
// Embedding node: two inputs, one output.
// NOTE(review): input1_id/input2_id are presumably the weight and the indices
// (ATen embedding argument order) — confirm against the serializer.
table MPSEmbedding {
input1_id:int;
input2_id:int;
output_id:int;
padding_idx:int;
scale_grad_by_freq:bool;
sparse:bool;
}
// index.Tensor node: one input, a list of index value ids, one output.
table MPSIndexTensor {
input1_id:int;
// One id per index value; the list length is not fixed by the schema.
indices_id:[int];
output_id:int;
}
// index.put node: one input, a list of index value ids, the values to write
// (id plus an explicit shape) and one output.
table MPSIndexPut {
input1_id:int;
// One id per index value; the list length is not fixed by the schema.
indices_id:[int];
values_shape:[int];
values_id:int;
output_id:int;
}
// Scatter node: one input, one output, a dimension and the ids of the index
// and source values.
table MPSScatter {
input1_id:int;
output_id:int;
// 64-bit here, unlike the 32-bit `dim` of most other tables in this schema.
dim:long;
idx_id:int;
src_id:int;
}
// Shape ops.
// Permute node: one input, one output and the permutation to apply.
table MPSPermute {
input1_id:int;
output_id:int;
// NOTE(review): presumably equals the length of `perm` — confirm.
num_dims:int;
perm:[int];
}
// View node: one input, one output and a target shape.
// Also the payload type of MPSExpand in MPSNodeUnion.
table MPSView {
input1_id:int;
output_id:int;
// NOTE(review): presumably equals the length of `shape` — confirm.
num_dims:int;
shape:[int];
}
// Concatenation node: a variable-length list of inputs, one output and the
// dimension to concatenate along.
table MPSCat {
input_ids:[int];
output_id:int;
dim:int;
}
// Squeeze node: one input, one output and a list of dimensions.
table MPSSqueeze {
input1_id:int;
output_id:int;
dims:[int];
}
// Unsqueeze node: one input, one output and a single dimension.
table MPSUnsqueeze {
input1_id:int;
output_id:int;
dim:int;
}
// Select node: one input, one output, a dimension and a scalar index along
// that dimension.
table MPSSelect {
input1_id:int;
output_id:int;
dim:int;
// A literal index value, not a value id (compare MPSIndexSelect.index_id).
index:int;
}
// Slice node: one input, one output and the slice parameters.
// dim/start/end/step are 64-bit `long`s, unlike the 32-bit ints used by most
// other tables in this schema.
table MPSSlice {
input1_id:int;
output_id:int;
dim:long;
start:long;
end:long;
step:long;
}
// Pixel-shuffle node: one input, one output and the upscale factor.
table MPSPixelShuffle {
input1_id:int;
output_id:int;
upscale_factor:int;
}
// Split-with-sizes node: one input, a variable-length list of outputs, the
// size of each split and the dimension to split along.
table MPSSplitWithSizes {
input1_id:int;
output_ids:[int];
// NOTE(review): presumably one entry per element of output_ids — confirm.
split_sizes:[int];
dim:int;
}
// Cast node: one input, one output and the target element type.
table MPSCast {
input1_id:int;
output_id:int;
// Reads as mps_data_type_invalid (0) when not serialized.
dtype:MPSDataType;
}
// Linear algebra ops.
// Addmm node: three inputs, one output and two float scale factors.
// NOTE(review): presumably follows torch.addmm, i.e.
// beta * input1 + alpha * (input2 @ input3) — confirm the input order
// against the serializer.
table MPSAddmm {
input1_id:int;
input2_id:int;
input3_id:int;
output_id:int;
// Both read as 0.0 when not serialized (no schema default is declared).
beta:float;
alpha:float;
}
// Constant ops
// Constant-fill node. Payload type of MPSFull and MPSFullLike in MPSNodeUnion.
table _MPSFull {
// NOTE(review): presumably only meaningful for MPSFullLike — confirm.
input1_id:int;
output_id:int;
shape:[int];
// Stored as a 32-bit float regardless of `dtype`.
fill_value: float;
dtype:MPSDataType;
}
// Convolution ops.
// Convolution node: three inputs and one output.
// Payload type of MPSConv2D and MPSDepthwiseConv2D in MPSNodeUnion.
// NOTE(review): input1/input2/input3 are presumably activation, weight and
// bias — confirm against the serializer.
table MPSConv {
input1_id:int;
input2_id:int;
input3_id:int;
output_id:int;
stride_x:int;
stride_y:int;
dilation_x:int;
dilation_y:int;
groups:int;
// Padding is stored per side rather than as a symmetric pair.
padding_left:int;
padding_right:int;
padding_top:int;
padding_bottom:int;
}
// Normalization ops.
// Batch-norm node: input, running statistics, affine parameters and three
// outputs.
// Note the first field is `input_id`, not `input1_id` as in most other tables.
table MPSBatchNorm {
input_id:int;
mean_id:int;
var_id:int;
weight_id:int;
bias_id:int;
momentum:float;
epsilon:float;
// The outputs are declared as 2, 1, 3. No explicit `id:` attributes are
// used, so declaration order fixes each field's slot in the serialized
// table; do not reorder these without treating it as a BC-breaking change.
output2_id:int;
output1_id:int;
output3_id:int;
}
// Layer-norm node: one input, the normalized shape, affine parameters and
// three outputs.
table MPSLayerNorm {
input1_id:int;
normalized_shape:[int];
weight_id:int;
bias_id:int;
eps:float;
// The outputs are declared as 2, 1, 3. No explicit `id:` attributes are
// used, so declaration order fixes each field's slot in the serialized
// table; do not reorder these without treating it as a BC-breaking change.
output2_id:int;
output1_id:int;
output3_id:int;
}
// Pooling ops: see MPSPooling2D, declared earlier in this file.
// Pad ops
// Constant-pad node: one input, one output, the pad amounts and fill value.
table MPSConstantPadND {
input1_id:int;
output_id:int;
// NOTE(review): layout presumably follows torch.nn.functional.pad
// (before/after pairs starting from the last dimension) — confirm.
pad:[int];
value:float;
}
// Range ops
// Arange node: no inputs, one output.
// start/end/step are stored as 32-bit floats regardless of `dtype`.
table MPSArange {
output_id:int;
start:float;
end:float;
step:float;
dtype:MPSDataType;
}
// Quant - Dequant ops
// Group-wise per-channel dequantize node: one input, one output, plus the ids
// of the scales and zero-points values and the quantization parameters.
table MPSDequantizePerChannelGroup {
input1_id:int;
output_id:int;
scales_id:int;
zero_points_id:int;
quant_min:int;
quant_max:int;
// NOTE(review): presumably the quantized (input) element type, with
// output_dtype being the element type of the result — confirm.
dtype:MPSDataType;
group_size:int;
output_dtype:MPSDataType;
}
// Every op a serialized node can hold. `Alias: Table` members reuse a shared
// payload table under an op-specific name.
// The union's type tag is assigned from declaration order, so inserting,
// removing or reordering members is a BC-breaking change.
union MPSNodeUnion {
// Activation ops
MPSHardTanh,
// ReLU is unary, so it uses the one-input helper like the other unary ops.
MPSReLU: _MPSNode1x1,
MPSGELU,
MPSLeakyReLU,
MPSSoftmax,
MPSLogSoftmax: MPSSoftmax,
// Binary ops
MPSAdd: _MPSNodeWithAlpha2x1,
MPSSub: _MPSNodeWithAlpha2x1,
MPSMul: _MPSNode2x1,
MPSDiv: _MPSDivNode2x1,
MPSFmod: _MPSDivNode2x1,
MPSRemainder: _MPSDivNode2x1,
MPSMin: _MPSNode2x1,
MPSMax: _MPSNode2x1,
MPSPow: _MPSNode2x1,
MPSAtan2: _MPSNode2x1,
MPSBitwiseAnd: _MPSNode2x1,
MPSBitwiseOr: _MPSNode2x1,
MPSBitwiseXor: _MPSNode2x1,
MPSMinimum: _MPSNode2x1,
// Unary ops
MPSExp: _MPSNode1x1,
MPSExp2: _MPSNode1x1,
MPSReciprocal: _MPSNode1x1,
MPSSqrt: _MPSNode1x1,
MPSNeg: _MPSNode1x1,
MPSLog: _MPSNode1x1,
MPSLog10: _MPSNode1x1,
MPSLog2: _MPSNode1x1,
MPSErf: _MPSNode1x1,
MPSFloor: _MPSNode1x1,
MPSCeil: _MPSNode1x1,
MPSRsqrt: _MPSNode1x1,
MPSSigmoid: _MPSNode1x1,
MPSSin: _MPSNode1x1,
MPSSign: _MPSNode1x1,
MPSCos: _MPSNode1x1,
MPSTan: _MPSNode1x1,
MPSAbs: _MPSNode1x1,
MPSAsin: _MPSNode1x1,
MPSAcos: _MPSNode1x1,
MPSAtan: _MPSNode1x1,
MPSSinh: _MPSNode1x1,
MPSCosh: _MPSNode1x1,
MPSTanh: _MPSNode1x1,
MPSAsinh: _MPSNode1x1,
MPSAcosh: _MPSNode1x1,
MPSAtanh: _MPSNode1x1,
MPSBitwiseNot: _MPSNode1x1,
MPSIsnan: _MPSNode1x1,
MPSIsinf: _MPSNode1x1,
MPSRound: _MPSNode1x1,
MPSLogicalNot: _MPSNode1x1,
// Linear algebra ops
MPSMatMul: _MPSNode2x1,
MPSAddmm,
// Constant ops
MPSFull: _MPSFull,
MPSFullLike: _MPSFull,
// Clamp ops
MPSClamp,
MPSWhere: _MPSNode3x1,
// Indexing ops
MPSIndexSelect,
MPSEmbedding,
MPSIndexTensor,
MPSIndexPut,
MPSScatter,
// Reduce ops
MPSMean,
// Shape ops
MPSPermute,
MPSView,
MPSExpand: MPSView,
MPSCat,
MPSSqueeze,
MPSUnsqueeze,
MPSSelect,
MPSSlice,
MPSPixelShuffle,
MPSSplitWithSizes,
MPSCast,
// Convolution ops
MPSConv2D: MPSConv,
MPSDepthwiseConv2D: MPSConv,
// Comparison ops
MPSEq: _MPSNode2x1,
MPSNe: _MPSNode2x1,
MPSGe: _MPSNode2x1,
MPSGt: _MPSNode2x1,
MPSLe: _MPSNode2x1,
MPSLt: _MPSNode2x1,
// Normalization ops
MPSBatchNorm,
MPSLayerNorm,
// Pooling ops
MPSMaxPool2DWithIndices: MPSPooling2D,
MPSAvgPool2D: MPSPooling2D,
// Pad ops
MPSConstantPadND,
// Range ops
MPSArange,
// Quant-Dequant ops
MPSDequantizePerChannelGroup,
}
// One node of the serialized graph (element type of MPSGraph.mps_nodes): the
// op-specific payload plus an optional min/max pair.
table MPSNode {
// A union field is stored as two slots: the member type tag and the value.
mpsnode_union:MPSNodeUnion;
// Sub-table; absent (null) when not serialized.
min_max:MPSMinMax;
}
// taken from executorch
// Data buffer abstraction.
// Deprecated; still referenced by MPSTensor.constant_buffer.
table Buffer {
// `force_align: 16` makes the byte vector start on a 16-byte boundary inside
// the serialized buffer.
storage:[ubyte] (force_align: 16);
}
// Description of one value in the graph (element type of MPSGraph.mps_values):
// element type, shape and, for constants, where the data lives.
table MPSTensor {
datatype:MPSDataType;
// NOTE(review): presumably equals the length of `dims` — confirm.
num_dims:int;
dims:[int];
// NOTE(review): presumably the constant's size in bytes — confirm.
constant_buffer_size:uint64;
constant_buffer:Buffer; // deprecated
// NOTE(review): presumably the offset of this constant's data within
// MPSGraph.constant_segment, replacing the deprecated inline buffer — confirm.
segment_offset:uint64;
}
// Byte range of a data segment (used by MPSGraph.constant_segment).
table DataSegment {
// Segment offsets are relative to the segment base offset provided in
// the extended file header. Segments will typically be aligned in a
// way to make it possible to use mmap() to load them.
offset: uint64;
// The size in bytes of valid data starting at the offset. The segment
// data may be followed by padding before the segment that follows it,
// to make it easier to use mmap().
size: uint64;
}
// Root table of the serialized file (see `root_type`): the nodes, the values
// they refer to, and the lists of input, output and constant ids.
table MPSGraph {
// Schema version.
version:string;
mps_nodes:[MPSNode];
mps_values:[MPSTensor];
// NOTE(review): these id lists presumably index into mps_values — confirm
// against the serializer.
input_ids:[int];
output_ids:[int];
constant_ids:[int];
// Reads as mps_graph (0) when not serialized.
graph_type:OpType;
// Location of the constant data segment; offsets are relative to the segment
// base described on DataSegment.
constant_segment:DataSegment;
}
// MPSGraph is the root table of every buffer serialized with this schema.
root_type MPSGraph;