| // |
| // Copyright (c) 2023 Apple Inc. All rights reserved. |
| // Provided subject to the LICENSE file in the top level directory. |
| // |
| |
| #pragma once |
| |
| // Obj-C headers |
| #include <Foundation/Foundation.h> |
| #include <Metal/Metal.h> |
| #include <MetalPerformanceShaders/MetalPerformanceShaders.h> |
| #include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h> |
| |
| // Runtime headers |
| #include <executorch/runtime/backend/interface.h> |
| #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h> |
| |
| // MPS headers |
| #include <executorch/backends/apple/mps/runtime/operations/MPSGraphSequoiaOps.h> |
| #include <executorch/backends/apple/mps/runtime/operations/MPSGraphVenturaOps.h> |
| #include <executorch/backends/apple/mps/runtime/operations/OperationUtils.h> |
| #include <executorch/backends/apple/mps/schema_generated.h> |
| |
| #include <unordered_map> |
| #include <vector> |
| |
| namespace executorch { |
| namespace backends { |
| namespace mps { |
| namespace delegate { |
| |
| using Error = executorch::runtime::Error; |
| using DataType = mpsgraph::MPSDataType; |
| using TensorPtr = const mpsgraph::MPSTensor *; |
| using NodePtr = const mpsgraph::MPSNode *; |
| |
| #define _DEFINE_MPS_OP(name) Error mps##name##Op(NodePtr nodePtr); |
| |
| /** |
| * Helper class to construct a MPSGraph object from a serialized MPS FlatBuffer model. |
| * It records all the input placeholders, lifted weights/biases and output feeds. |
| */ |
| class MPSGraphBuilder { |
| public: |
| MPSGraphBuilder(const void *buffer_pointer, size_t num_bytes, |
| std::unordered_map<MPSGraphTensor *, int32_t> &mpsGraphTensorToId); |
| ~MPSGraphBuilder() = default; |
| |
| Error compileModel(); |
| MPSGraph *getMPSGraph(); |
| MPSGraphExecutable *getMPSGraphExecutable(); |
| |
| private: |
| // Input feeds & constant ops |
| Error mpsGraphRankedPlaceholder(int32_t id); |
| Error mpsConstantOp(int32_t id); |
| // Activation ops |
| _DEFINE_MPS_OP(HardTanh); |
| _DEFINE_MPS_OP(ReLU); |
| _DEFINE_MPS_OP(GELU); |
| _DEFINE_MPS_OP(LeakyReLU); |
| _DEFINE_MPS_OP(Softmax); |
| _DEFINE_MPS_OP(LogSoftmax); |
| // Arithmetic Binary Ops |
| _DEFINE_MPS_OP(Add); |
| _DEFINE_MPS_OP(Sub); |
| _DEFINE_MPS_OP(Mul); |
| _DEFINE_MPS_OP(Div); |
| _DEFINE_MPS_OP(Pow); |
| _DEFINE_MPS_OP(Fmod); |
| _DEFINE_MPS_OP(Remainder); |
| _DEFINE_MPS_OP(BitwiseAnd); |
| _DEFINE_MPS_OP(BitwiseOr); |
| _DEFINE_MPS_OP(BitwiseXor); |
| _DEFINE_MPS_OP(Minimum); |
| // Comparison ops |
| _DEFINE_MPS_OP(Eq); |
| _DEFINE_MPS_OP(Ne); |
| _DEFINE_MPS_OP(Ge); |
| _DEFINE_MPS_OP(Gt); |
| _DEFINE_MPS_OP(Le); |
| _DEFINE_MPS_OP(Lt); |
| // Unary ops |
| _DEFINE_MPS_OP(Exp); |
| _DEFINE_MPS_OP(Exp2); |
| _DEFINE_MPS_OP(Reciprocal); |
| _DEFINE_MPS_OP(Sqrt); |
| _DEFINE_MPS_OP(Neg); |
| _DEFINE_MPS_OP(Log); |
| _DEFINE_MPS_OP(Log10); |
| _DEFINE_MPS_OP(Log2); |
| _DEFINE_MPS_OP(Erf); |
| _DEFINE_MPS_OP(Floor); |
| _DEFINE_MPS_OP(Ceil); |
| _DEFINE_MPS_OP(Rsqrt); |
| _DEFINE_MPS_OP(Sigmoid); |
| _DEFINE_MPS_OP(Sin); |
| _DEFINE_MPS_OP(Sign); |
| _DEFINE_MPS_OP(Cos); |
| _DEFINE_MPS_OP(Tan); |
| _DEFINE_MPS_OP(Abs); |
| _DEFINE_MPS_OP(Asin); |
| _DEFINE_MPS_OP(Acos); |
| _DEFINE_MPS_OP(Atan); |
| _DEFINE_MPS_OP(Sinh); |
| _DEFINE_MPS_OP(Cosh); |
| _DEFINE_MPS_OP(Tanh); |
| _DEFINE_MPS_OP(Asinh); |
| _DEFINE_MPS_OP(Acosh); |
| _DEFINE_MPS_OP(Atanh); |
| _DEFINE_MPS_OP(BitwiseNot); |
| _DEFINE_MPS_OP(Isnan); |
| _DEFINE_MPS_OP(Isinf); |
| _DEFINE_MPS_OP(Round); |
| _DEFINE_MPS_OP(LogicalNot); |
| _DEFINE_MPS_OP(NormCdf); |
| // Clamp ops |
| _DEFINE_MPS_OP(Clamp); |
| _DEFINE_MPS_OP(Where); |
| // BitWise ops |
| // Convolution ops |
| _DEFINE_MPS_OP(Conv2D); |
| _DEFINE_MPS_OP(DepthwiseConv2D); |
| // Indexing ops |
| _DEFINE_MPS_OP(IndexSelect); |
| _DEFINE_MPS_OP(Embedding); |
| _DEFINE_MPS_OP(IndexTensor); |
| _DEFINE_MPS_OP(IndexPut); |
| _DEFINE_MPS_OP(Scatter); |
| // Linear algebra ops |
| _DEFINE_MPS_OP(MatMul); |
| _DEFINE_MPS_OP(Addmm); |
| // Constant ops |
| _DEFINE_MPS_OP(Full); |
| _DEFINE_MPS_OP(FullLike); |
| // Normalization ops |
| _DEFINE_MPS_OP(BatchNorm); |
| _DEFINE_MPS_OP(LayerNorm); |
| // Reduce ops |
| _DEFINE_MPS_OP(Mean); |
| // Shape ops |
| _DEFINE_MPS_OP(Permute); |
| _DEFINE_MPS_OP(View); |
| _DEFINE_MPS_OP(Expand); |
| _DEFINE_MPS_OP(Cat); |
| _DEFINE_MPS_OP(Squeeze); |
| _DEFINE_MPS_OP(Unsqueeze); |
| _DEFINE_MPS_OP(Select); |
| _DEFINE_MPS_OP(Slice); |
| _DEFINE_MPS_OP(PixelShuffle); |
| _DEFINE_MPS_OP(SplitWithSizes); |
| _DEFINE_MPS_OP(Cast); |
| // Pooling ops |
| _DEFINE_MPS_OP(MaxPool2DWithIndices); |
| _DEFINE_MPS_OP(AvgPool2D); |
| // Pad ops |
| _DEFINE_MPS_OP(ConstantPadND); |
| // Range ops |
| _DEFINE_MPS_OP(Arange); |
| // Quant-Dequant ops |
| _DEFINE_MPS_OP(DequantizePerChannelGroup); |
| |
| // Helper functions |
| Error addNodeToMPSGraph(NodePtr nodePtr); |
| Error compileMetalKernel(NodePtr nodePtr); |
| MPSShape *getMPSShape(int32_t id); |
| MPSShape *getMPSShape(const flatbuffers::Vector<int32_t> *shape); |
| int64_t numel(const flatbuffers::Vector<int32_t> *shape); |
| MPSDataType getMPSDataType(int32_t id); |
| MPSDataType getMPSDataType(DataType serializedDataType); |
| MPSGraphTensor *getMPSGraphTensor(int32_t id); |
| NSData *getConstantData(int32_t id); |
| std::pair<float, float> getMinMaxValues(NodePtr nodePtr); |
| Error compileMPSGraph(); |
| Error compileMetalKernel(); |
| |
| // Each MPSGraph op result in at least MPSGraphTensor being |
| // produced, which will be stored in this structure. Other ops |
| // can reference the saved tensor by the AOT id (1:1 mapping). |
| std::vector<MPSGraphTensor *> _idToMPSGraphTensor; |
| std::unordered_map<MPSGraphTensor *, int32_t> &_mpsGraphTensorToId; |
| // FlatBuffer serialized graph containing the nodes from the original model. |
| const mpsgraph::MPSGraph *_flatBufferGraph; |
| // FlatBuffer raw bytes of the serialized MPS model. |
| const void *_buffer_pointer; |
| size_t _num_bytes; |
| |
| bool _metal_kernel; |
| MPSGraph *_mpsGraph; |
| MPSGraphExecutable *_mpsGraphExecutable; |
| NSMutableDictionary<MPSGraphTensor *, MPSGraphShapedType *> *_feeds; |
| NSMutableArray<MPSGraphTensor *> *_targetTensors; |
| |
| const uint8_t *_constant_data_ptr; |
| }; |
| |
| #undef _DEFINE_MPS_OP |
| |
| } // namespace delegate |
| } // namespace mps |
| } // namespace backends |
| } // namespace executorch |