blob: 682a905335a4cd2d4ae8dd76a8ad88c00d632154 [file]
//
// Copyright (c) 2023 Apple Inc. All rights reserved.
// Provided subject to the LICENSE file in the top level directory.
//
#pragma once
// Obj-C headers
#include <Foundation/Foundation.h>
#include <Metal/Metal.h>
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
#include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
// Runtime headers
#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
// MPS headers
#include <executorch/backends/apple/mps/runtime/operations/MPSGraphSequoiaOps.h>
#include <executorch/backends/apple/mps/runtime/operations/MPSGraphVenturaOps.h>
#include <executorch/backends/apple/mps/runtime/operations/OperationUtils.h>
#include <executorch/backends/apple/mps/schema_generated.h>
#include <unordered_map>
#include <vector>
namespace executorch {
namespace backends {
namespace mps {
namespace delegate {
// Status type returned by the builder's methods.
using Error = executorch::runtime::Error;
// Data type as stored in the serialized graph. This is distinct from the
// unqualified MPSDataType, which getMPSDataType(DataType) returns.
using DataType = mpsgraph::MPSDataType;
// Read-only pointers to serialized tensor / node records.
// NOTE(review): presumably the FlatBuffer-generated types from
// schema_generated.h -- confirm.
using TensorPtr = const mpsgraph::MPSTensor *;
using NodePtr = const mpsgraph::MPSNode *;
// Declares the builder method for one op kind:
//   Error mps<name>Op(NodePtr nodePtr)
// The expansion intentionally carries no trailing ';'. Every use site is
// written as `_DEFINE_MPS_OP(Foo);` and supplies the terminator itself, so a
// semicolon here would leave an empty member-declaration after each op
// (flagged by -Wextra-semi).
#define _DEFINE_MPS_OP(name) Error mps##name##Op(NodePtr nodePtr)
/**
* Helper class to construct a MPSGraph object from a serialized MPS FlatBuffer model.
* It records all the input placeholders, lifted weights/biases and output feeds.
*
* The public surface is the constructor, compileModel() and two accessors.
* Everything else (one method per supported op, plus lookup helpers and state)
* is private.
*/
class MPSGraphBuilder {
public:
// Creates a builder for the serialized model at `buffer_pointer`, which is
// `num_bytes` long. `mpsGraphTensorToId` is taken by non-const reference and
// the class keeps a reference member of the same type (_mpsGraphTensorToId).
// NOTE(review): presumably the caller's map must outlive the builder --
// confirm against the implementation.
MPSGraphBuilder(const void *buffer_pointer, size_t num_bytes,
std::unordered_map<MPSGraphTensor *, int32_t> &mpsGraphTensorToId);
~MPSGraphBuilder() = default;
// Compilation entry point; returns an Error status.
// NOTE(review): presumably dispatches to compileMPSGraph() or
// compileMetalKernel() -- confirm in the implementation file.
Error compileModel();
// Accessors for the graph and its executable.
// NOTE(review): presumably return _mpsGraph / _mpsGraphExecutable, which
// would only be meaningful after compileModel() -- confirm.
MPSGraph *getMPSGraph();
MPSGraphExecutable *getMPSGraphExecutable();
private:
// Input feeds & constant ops
Error mpsGraphRankedPlaceholder(int32_t id);
Error mpsConstantOp(int32_t id);
// Each _DEFINE_MPS_OP(X) line below declares one private method
// `Error mpsXOp(NodePtr nodePtr)`, grouped by op category.
// Activation ops
_DEFINE_MPS_OP(HardTanh);
_DEFINE_MPS_OP(ReLU);
_DEFINE_MPS_OP(GELU);
_DEFINE_MPS_OP(LeakyReLU);
_DEFINE_MPS_OP(Softmax);
_DEFINE_MPS_OP(LogSoftmax);
// Arithmetic Binary Ops
_DEFINE_MPS_OP(Add);
_DEFINE_MPS_OP(Sub);
_DEFINE_MPS_OP(Mul);
_DEFINE_MPS_OP(Div);
_DEFINE_MPS_OP(Pow);
_DEFINE_MPS_OP(Fmod);
_DEFINE_MPS_OP(Remainder);
_DEFINE_MPS_OP(BitwiseAnd);
_DEFINE_MPS_OP(BitwiseOr);
_DEFINE_MPS_OP(BitwiseXor);
_DEFINE_MPS_OP(Minimum);
// Comparison ops
_DEFINE_MPS_OP(Eq);
_DEFINE_MPS_OP(Ne);
_DEFINE_MPS_OP(Ge);
_DEFINE_MPS_OP(Gt);
_DEFINE_MPS_OP(Le);
_DEFINE_MPS_OP(Lt);
// Unary ops
_DEFINE_MPS_OP(Exp);
_DEFINE_MPS_OP(Exp2);
_DEFINE_MPS_OP(Reciprocal);
_DEFINE_MPS_OP(Sqrt);
_DEFINE_MPS_OP(Neg);
_DEFINE_MPS_OP(Log);
_DEFINE_MPS_OP(Log10);
_DEFINE_MPS_OP(Log2);
_DEFINE_MPS_OP(Erf);
_DEFINE_MPS_OP(Floor);
_DEFINE_MPS_OP(Ceil);
_DEFINE_MPS_OP(Rsqrt);
_DEFINE_MPS_OP(Sigmoid);
_DEFINE_MPS_OP(Sin);
_DEFINE_MPS_OP(Sign);
_DEFINE_MPS_OP(Cos);
_DEFINE_MPS_OP(Tan);
_DEFINE_MPS_OP(Abs);
_DEFINE_MPS_OP(Asin);
_DEFINE_MPS_OP(Acos);
_DEFINE_MPS_OP(Atan);
_DEFINE_MPS_OP(Sinh);
_DEFINE_MPS_OP(Cosh);
_DEFINE_MPS_OP(Tanh);
_DEFINE_MPS_OP(Asinh);
_DEFINE_MPS_OP(Acosh);
_DEFINE_MPS_OP(Atanh);
_DEFINE_MPS_OP(BitwiseNot);
_DEFINE_MPS_OP(Isnan);
_DEFINE_MPS_OP(Isinf);
_DEFINE_MPS_OP(Round);
_DEFINE_MPS_OP(LogicalNot);
_DEFINE_MPS_OP(NormCdf);
// Clamp ops
_DEFINE_MPS_OP(Clamp);
_DEFINE_MPS_OP(Where);
// (Bitwise ops have no section of their own: BitwiseAnd/Or/Xor are declared
// with the arithmetic binary ops and BitwiseNot with the unary ops above.)
// Convolution ops
_DEFINE_MPS_OP(Conv2D);
_DEFINE_MPS_OP(DepthwiseConv2D);
// Indexing ops
_DEFINE_MPS_OP(IndexSelect);
_DEFINE_MPS_OP(Embedding);
_DEFINE_MPS_OP(IndexTensor);
_DEFINE_MPS_OP(IndexPut);
_DEFINE_MPS_OP(Scatter);
// Linear algebra ops
_DEFINE_MPS_OP(MatMul);
_DEFINE_MPS_OP(Addmm);
// Constant ops
_DEFINE_MPS_OP(Full);
_DEFINE_MPS_OP(FullLike);
// Normalization ops
_DEFINE_MPS_OP(BatchNorm);
_DEFINE_MPS_OP(LayerNorm);
// Reduce ops
_DEFINE_MPS_OP(Mean);
// Shape ops
_DEFINE_MPS_OP(Permute);
_DEFINE_MPS_OP(View);
_DEFINE_MPS_OP(Expand);
_DEFINE_MPS_OP(Cat);
_DEFINE_MPS_OP(Squeeze);
_DEFINE_MPS_OP(Unsqueeze);
_DEFINE_MPS_OP(Select);
_DEFINE_MPS_OP(Slice);
_DEFINE_MPS_OP(PixelShuffle);
_DEFINE_MPS_OP(SplitWithSizes);
_DEFINE_MPS_OP(Cast);
// Pooling ops
_DEFINE_MPS_OP(MaxPool2DWithIndices);
_DEFINE_MPS_OP(AvgPool2D);
// Pad ops
_DEFINE_MPS_OP(ConstantPadND);
// Range ops
_DEFINE_MPS_OP(Arange);
// Quant-Dequant ops
_DEFINE_MPS_OP(DequantizePerChannelGroup);
// Helper functions
// NOTE(review): addNodeToMPSGraph presumably routes a serialized node to the
// matching mps<Name>Op method above -- confirm in the implementation.
Error addNodeToMPSGraph(NodePtr nodePtr);
// Per-node overload; a no-argument compileMetalKernel() is declared below.
Error compileMetalKernel(NodePtr nodePtr);
// Shape lookup, either by tensor id or from a serialized dimension vector.
MPSShape *getMPSShape(int32_t id);
MPSShape *getMPSShape(const flatbuffers::Vector<int32_t> *shape);
// NOTE(review): presumably the element count implied by `shape` -- confirm.
int64_t numel(const flatbuffers::Vector<int32_t> *shape);
// MPSDataType lookup, either by tensor id or from the serialized DataType.
MPSDataType getMPSDataType(int32_t id);
MPSDataType getMPSDataType(DataType serializedDataType);
// Lookups keyed by id.
// NOTE(review): getMPSGraphTensor presumably reads _idToMPSGraphTensor and
// getConstantData presumably reads from _constant_data_ptr -- confirm.
MPSGraphTensor *getMPSGraphTensor(int32_t id);
NSData *getConstantData(int32_t id);
// NOTE(review): presumably the (min, max) bounds carried by a node, in that
// order -- confirm.
std::pair<float, float> getMinMaxValues(NodePtr nodePtr);
// Whole-model compilation steps; each returns an Error status.
Error compileMPSGraph();
Error compileMetalKernel();
// Each MPSGraph op results in at least one MPSGraphTensor being
// produced, which will be stored in this structure. Other ops
// can reference the saved tensor by the AOT id (1:1 mapping).
std::vector<MPSGraphTensor *> _idToMPSGraphTensor;
// Tensor -> id map, held by reference; it has the same type as the
// constructor's `mpsGraphTensorToId` parameter.
std::unordered_map<MPSGraphTensor *, int32_t> &_mpsGraphTensorToId;
// FlatBuffer serialized graph containing the nodes from the original model.
const mpsgraph::MPSGraph *_flatBufferGraph;
// FlatBuffer raw bytes of the serialized MPS model.
const void *_buffer_pointer;
size_t _num_bytes;
// NOTE(review): presumably selects the Metal-kernel path instead of MPSGraph
// compilation -- confirm.
bool _metal_kernel;
MPSGraph *_mpsGraph;
MPSGraphExecutable *_mpsGraphExecutable;
// NOTE(review): presumably the input placeholders (with their shaped types)
// and the output tensors recorded while building -- confirm.
NSMutableDictionary<MPSGraphTensor *, MPSGraphShapedType *> *_feeds;
NSMutableArray<MPSGraphTensor *> *_targetTensors;
// NOTE(review): presumably points at the constant (weights) data inside the
// serialized buffer -- confirm.
const uint8_t *_constant_data_ptr;
};
#undef _DEFINE_MPS_OP
} // namespace delegate
} // namespace mps
} // namespace backends
} // namespace executorch