| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_COMPUTE_TASK_H_ |
| #define TENSORFLOW_LITE_DELEGATES_GPU_METAL_COMPUTE_TASK_H_ |
| |
| #import <Metal/Metal.h> |
| |
| #include <map> |
| #include <set> |
| #include <string> |
| #include <vector> |
| |
| #include "tensorflow/lite/delegates/gpu/common/model.h" |
| #include "tensorflow/lite/delegates/gpu/common/precision.h" |
| #include "tensorflow/lite/delegates/gpu/common/shape.h" |
| #include "tensorflow/lite/delegates/gpu/common/status.h" |
| #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" |
| |
| @interface TFLComputeTask : NSObject |
| |
| /// Returns empty string or error if shader can't be compiled. |
| - (absl::Status)compileWithDevice:(id<MTLDevice>)device |
| taskDescriptor:(const tflite::gpu::metal::NodeDescriptor&)desc |
| precision:(tflite::gpu::CalculationsPrecision)precision; |
| |
| /// Updates parameters for inputs/outputs/intermediate tensors |
| - (absl::Status)updateParamsWithDevice:(id<MTLDevice>)device |
| tensorShapes:(const std::map<tflite::gpu::ValueId, tflite::gpu::BHWC>&) |
| tensorShapes; |
| |
| /// Updates buffers for intermediate tensors only. Returns error if out of memory or a buffer is |
| /// larger than MTLDevice can support. |
| /// @param buffers is a map from intermediate tensors' ValueId to metal handles with corresponding |
| /// buffers. |
| /// @param outputIDs must match the output of added operations. |
| /// @param usageRecordIds is a map from intermediate tensors' ValueId to corresponding tensor usage |
| /// records ids. |
| /// @param sharedBufferIds contain shared buffer id for each tensor usage record id. |
| /// @param sharedBuffers contain metal handles to the allocated buffers for each shared buffer id. |
| /// TODO(ypisarchyk): probably we can decrease the number of parameters here |
| - (absl::Status)assignBuffers:(std::map<::tflite::gpu::ValueId, id<MTLBuffer>>*)buffers |
| outputIds:(const std::vector<::tflite::gpu::ValueId>&)outputIds |
| usageRecordIds:(const std::map<::tflite::gpu::ValueId, size_t>&)usageRecordIds |
| sharedBufferIds:(const std::vector<size_t>&)sharedBufferIds |
| sharedBuffers:(const std::vector<id<MTLBuffer>>&)sharedBuffers; |
| |
| - (bool)hasInOutIds:(const std::set<::tflite::gpu::ValueId>&)ids; |
| |
| - (void)updateBuffers:(const std::map<::tflite::gpu::ValueId, id<MTLBuffer>>&)inputOutputBuffers; |
| |
| - (void)encodeWithEncoder:(id<MTLComputeCommandEncoder>)encoder; |
| |
| - (std::vector<tflite::gpu::ValueId>)getOutputIds; |
| - (std::vector<tflite::gpu::ValueId>)getInputIds; |
| |
| - (void)setDescription:(const std::string&)description; |
| |
| @end |
| |
| #endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_COMPUTE_TASK_H_ |