| // |
| // Copyright (c) 2023 Apple Inc. All rights reserved. |
| // Provided subject to the LICENSE file in the top level directory. |
| // |
| |
| namespace mpsgraph; |
| |
| // Update after any BC breaking changes |
| file_identifier "MP00"; |
| |
| // datatype for mps-values |
| enum MPSDataType : short { |
| mps_data_type_invalid = 0, |
| mps_data_type_float16 = 1, |
| mps_data_type_float32 = 2, |
| mps_data_type_float64 = 3, |
| mps_data_type_bfloat16 = 4, |
| |
| // Signed integers. |
| mps_data_type_int4 = 5, |
| mps_data_type_int8 = 6, |
| mps_data_type_int16 = 7, |
| mps_data_type_int32 = 8, |
| mps_data_type_int64 = 9, |
| |
| |
| // Unsigned integers. range: [0, UTYPE_MAX] |
| mps_data_type_uint4 = 10, |
| mps_data_type_uint8 = 11, |
| mps_data_type_uint16 = 12, |
| mps_data_type_uint32 = 13, |
| mps_data_type_uint64 = 14, |
| |
| mps_data_type_bool = 15, |
| |
| mps_data_type_complex_float16 = 16, |
| mps_data_type_complex_float32 = 17, |
| } |
| |
| // ops like index.Tensor and index.put are currentely implemented as |
| // Metal kernels for unsupported MPSGraph cases. |
| enum OpType : short { |
| mps_graph, |
| metal_kernel |
| } |
| |
| // Helper classes to define the number of input and output tensors for a node. |
| // Not meant to be used directly. |
| |
| // A node with one input and one output. |
| table _MPSNode1x1 { |
| input1_id:int; |
| output_id:int; |
| } |
| |
| // A node with two inputs and one output. |
| table _MPSNode2x1 { |
| input1_id:int; |
| input2_id:int; |
| output_id:int; |
| } |
| |
| table _MPSDivNode2x1 { |
| input1_id:int; |
| input2_id:int; |
| output_id:int; |
| rounding_mode:string; |
| } |
| |
| table _MPSNodeWithAlpha2x1 { |
| input1_id:int; |
| input2_id:int; |
| output_id:int; |
| alpha:float; |
| } |
| |
| // A node with three inputs and one output. |
| table _MPSNode3x1 { |
| input1_id:int; |
| input2_id:int; |
| input3_id:int; |
| output_id:int; |
| } |
| |
| table MPSMinMax { |
| min_value:float; |
| max_value:float; |
| } |
| |
| table MPSPooling2D { |
| input1_id:int; |
| kernel_height:int; |
| kernel_width:int; |
| stride_height:int; |
| stride_width:int; |
| padding_left:int; |
| padding_right:int; |
| padding_top:int; |
| padding_bottom:int; |
| dilation_height:int; |
| dilation_width:int; |
| ceil_mode:bool; |
| count_include_pad:bool; |
| divisor_override:int; |
| output1_id:int; |
| output2_id:int; |
| } |
| |
| // Activation ops. |
| table MPSHardTanh { |
| input1_id:int; |
| output_id:int; |
| min_value:float; |
| max_value:float; |
| } |
| |
| table MPSGELU { |
| input1_id:int; |
| output_id:int; |
| approximate:string; |
| } |
| |
| table MPSLeakyReLU { |
| input1_id:int; |
| output_id:int; |
| negative_slope:float; |
| } |
| |
| table MPSSoftmax { |
| input1_id:int; |
| output_id:int; |
| dim:int; |
| half_to_float:bool; |
| } |
| |
| // Clamp ops |
| table MPSClamp { |
| input1_id:int; |
| output_id:int; |
| } |
| |
| // Reduce ops |
| table MPSMean { |
| input1_id:int; |
| output_id:int; |
| num_dims:int; |
| dims:[int]; |
| keep_dims:bool; |
| } |
| |
| // Indexing ops |
| table MPSIndexSelect { |
| input1_id:int; |
| output_id:int; |
| dim:int; |
| index_id:int; |
| } |
| |
| table MPSEmbedding { |
| input1_id:int; |
| input2_id:int; |
| output_id:int; |
| padding_idx:int; |
| scale_grad_by_freq:bool; |
| sparse:bool; |
| } |
| |
| table MPSIndexTensor { |
| input1_id:int; |
| indices_id:[int]; |
| output_id:int; |
| } |
| |
| table MPSIndexPut { |
| input1_id:int; |
| indices_id:[int]; |
| values_shape:[int]; |
| values_id:int; |
| output_id:int; |
| } |
| |
| table MPSScatter { |
| input1_id:int; |
| output_id:int; |
| dim:long; |
| idx_id:int; |
| src_id:int; |
| } |
| |
| // Shape ops. |
| table MPSPermute { |
| input1_id:int; |
| output_id:int; |
| num_dims:int; |
| perm:[int]; |
| } |
| |
| table MPSView { |
| input1_id:int; |
| output_id:int; |
| num_dims:int; |
| shape:[int]; |
| } |
| |
| table MPSCat { |
| input_ids:[int]; |
| output_id:int; |
| dim:int; |
| } |
| |
| table MPSSqueeze { |
| input1_id:int; |
| output_id:int; |
| dims:[int]; |
| } |
| |
| table MPSUnsqueeze { |
| input1_id:int; |
| output_id:int; |
| dim:int; |
| } |
| |
| table MPSSelect { |
| input1_id:int; |
| output_id:int; |
| dim:int; |
| index:int; |
| } |
| |
| table MPSSlice { |
| input1_id:int; |
| output_id:int; |
| dim:long; |
| start:long; |
| end:long; |
| step:long; |
| } |
| |
| table MPSPixelShuffle { |
| input1_id:int; |
| output_id:int; |
| upscale_factor:int; |
| } |
| |
| table MPSSplitWithSizes { |
| input1_id:int; |
| output_ids:[int]; |
| split_sizes:[int]; |
| dim:int; |
| } |
| |
| table MPSCast { |
| input1_id:int; |
| output_id:int; |
| dtype:MPSDataType; |
| } |
| |
| // Linear algebra ops. |
| table MPSAddmm { |
| input1_id:int; |
| input2_id:int; |
| input3_id:int; |
| output_id:int; |
| beta:float; |
| alpha:float; |
| } |
| |
| // Constant ops |
| table _MPSFull { |
| input1_id:int; |
| output_id:int; |
| shape:[int]; |
| fill_value: float; |
| dtype:MPSDataType; |
| } |
| |
| // Convolution ops. |
| table MPSConv { |
| input1_id:int; |
| input2_id:int; |
| input3_id:int; |
| output_id:int; |
| stride_x:int; |
| stride_y:int; |
| dilation_x:int; |
| dilation_y:int; |
| groups:int; |
| padding_left:int; |
| padding_right:int; |
| padding_top:int; |
| padding_bottom:int; |
| } |
| |
| // Normalization ops. |
| table MPSBatchNorm { |
| input_id:int; |
| mean_id:int; |
| var_id:int; |
| weight_id:int; |
| bias_id:int; |
| momentum:float; |
| epsilon:float; |
| output2_id:int; |
| output1_id:int; |
| output3_id:int; |
| } |
| |
| table MPSLayerNorm { |
| input1_id:int; |
| normalized_shape:[int]; |
| weight_id:int; |
| bias_id:int; |
| eps:float; |
| output2_id:int; |
| output1_id:int; |
| output3_id:int; |
| } |
| |
| // Pooling ops |
| |
| // Pad ops |
| table MPSConstantPadND { |
| input1_id:int; |
| output_id:int; |
| pad:[int]; |
| value:float; |
| } |
| |
| // Range ops |
| table MPSArange { |
| output_id:int; |
| start:float; |
| end:float; |
| step:float; |
| dtype:MPSDataType; |
| } |
| |
| // Quant - Dequant ops |
| table MPSDequantizePerChannelGroup { |
| input1_id:int; |
| output_id:int; |
| scales_id:int; |
| zero_points_id:int; |
| quant_min:int; |
| quant_max:int; |
| dtype:MPSDataType; |
| group_size:int; |
| output_dtype:MPSDataType; |
| } |
| |
| union MPSNodeUnion { |
| // Activation ops |
| MPSHardTanh, |
| MPSReLU: _MPSNode2x1, |
| MPSGELU, |
| MPSLeakyReLU, |
| MPSSoftmax, |
| MPSLogSoftmax: MPSSoftmax, |
| |
| // Binary ops |
| MPSAdd: _MPSNodeWithAlpha2x1, |
| MPSSub: _MPSNodeWithAlpha2x1, |
| MPSMul: _MPSNode2x1, |
| MPSDiv: _MPSDivNode2x1, |
| MPSFmod: _MPSDivNode2x1, |
| MPSRemainder: _MPSDivNode2x1, |
| MPSMin: _MPSNode2x1, |
| MPSMax: _MPSNode2x1, |
| MPSPow: _MPSNode2x1, |
| MPSAtan2: _MPSNode2x1, |
| MPSBitwiseAnd: _MPSNode2x1, |
| MPSBitwiseOr: _MPSNode2x1, |
| MPSBitwiseXor: _MPSNode2x1, |
| MPSMinimum: _MPSNode2x1, |
| |
| // Unary ops |
| MPSExp: _MPSNode1x1, |
| MPSExp2: _MPSNode1x1, |
| MPSReciprocal: _MPSNode1x1, |
| MPSSqrt: _MPSNode1x1, |
| MPSNeg: _MPSNode1x1, |
| MPSLog: _MPSNode1x1, |
| MPSLog10: _MPSNode1x1, |
| MPSLog2: _MPSNode1x1, |
| MPSErf: _MPSNode1x1, |
| MPSFloor: _MPSNode1x1, |
| MPSCeil: _MPSNode1x1, |
| MPSRsqrt: _MPSNode1x1, |
| MPSSigmoid: _MPSNode1x1, |
| MPSSin: _MPSNode1x1, |
| MPSSign: _MPSNode1x1, |
| MPSCos: _MPSNode1x1, |
| MPSTan: _MPSNode1x1, |
| MPSAbs: _MPSNode1x1, |
| MPSAsin: _MPSNode1x1, |
| MPSAcos: _MPSNode1x1, |
| MPSAtan: _MPSNode1x1, |
| MPSSinh: _MPSNode1x1, |
| MPSCosh: _MPSNode1x1, |
| MPSTanh: _MPSNode1x1, |
| MPSAsinh: _MPSNode1x1, |
| MPSAcosh: _MPSNode1x1, |
| MPSAtanh: _MPSNode1x1, |
| MPSBitwiseNot: _MPSNode1x1, |
| MPSIsnan: _MPSNode1x1, |
| MPSIsinf: _MPSNode1x1, |
| MPSRound: _MPSNode1x1, |
| MPSLogicalNot: _MPSNode1x1, |
| |
| // Linear algebra ops |
| MPSMatMul: _MPSNode2x1, |
| MPSAddmm, |
| |
| // Constant ops |
| MPSFull: _MPSFull, |
| MPSFullLike: _MPSFull, |
| |
| // Clamp ops, |
| MPSClamp, |
| MPSWhere: _MPSNode3x1, |
| |
| // Indexing ops |
| MPSIndexSelect, |
| MPSEmbedding, |
| MPSIndexTensor, |
| MPSIndexPut, |
| MPSScatter, |
| |
| // Reduce ops |
| MPSMean, |
| |
| // Shape ops |
| MPSPermute, |
| MPSView, |
| MPSExpand: MPSView, |
| MPSCat, |
| MPSSqueeze, |
| MPSUnsqueeze, |
| MPSSelect, |
| MPSSlice, |
| MPSPixelShuffle, |
| MPSSplitWithSizes, |
| MPSCast, |
| |
| // Convolution ops |
| MPSConv2D: MPSConv, |
| MPSDepthwiseConv2D: MPSConv, |
| |
| // Comparasion ops |
| MPSEq: _MPSNode2x1, |
| MPSNe: _MPSNode2x1, |
| MPSGe: _MPSNode2x1, |
| MPSGt: _MPSNode2x1, |
| MPSLe: _MPSNode2x1, |
| MPSLt: _MPSNode2x1, |
| |
| // Normalization ops |
| MPSBatchNorm, |
| MPSLayerNorm, |
| |
| // Pooling ops |
| MPSMaxPool2DWithIndices: MPSPooling2D, |
| MPSAvgPool2D: MPSPooling2D, |
| |
| // Pad ops |
| MPSConstantPadND, |
| |
| // Range ops |
| MPSArange, |
| |
| // Quant-Dequant ops |
| MPSDequantizePerChannelGroup, |
| } |
| |
| table MPSNode { |
| mpsnode_union:MPSNodeUnion; |
| min_max:MPSMinMax; |
| } |
| |
| // taken from executorch |
| // Data buffer abstraction. |
| // Deprecated |
| table Buffer { |
| storage:[ubyte] (force_align: 16); |
| } |
| |
| table MPSTensor { |
| datatype:MPSDataType; |
| num_dims:int; |
| dims:[int]; |
| constant_buffer_size:uint64; |
| constant_buffer:Buffer; // deprecated |
| segment_offset:uint64; |
| } |
| |
| table DataSegment { |
| // Segment offsets are relative to the segment base offset provided in |
| // the extended file header. Segments will typically be aligned in a |
| // way to make it possible to use mmap() to load them. |
| offset: uint64; |
| |
| // The size in bytes of valid data starting at the offset. The segment |
| // data may be followed by padding before the segment that follows it, |
| // to make it easier to use mmap(). |
| size: uint64; |
| } |
| |
| table MPSGraph { |
| // Schema version. |
| version:string; |
| mps_nodes:[MPSNode]; |
| mps_values:[MPSTensor]; |
| |
| input_ids:[int]; |
| output_ids:[int]; |
| constant_ids:[int]; |
| |
| graph_type:OpType; |
| |
| constant_segment:DataSegment; |
| } |
| |
| root_type MPSGraph; |