// Source blob: 73e9d81c76fec35f820b3d717f2e457270cb6d47 (appears to be a repository web-viewer header line, not part of the original file)
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_COMPUTE_TASK_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_COMPUTE_TASK_H_
#import <Metal/Metal.h>
#include <map>
#include <set>
#include <string>
#include <vector>
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/precision.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
@interface TFLComputeTask : NSObject

/// Compiles the task for the given device from its task descriptor and calculation precision.
/// @return an error status if the shader can't be compiled.
- (absl::Status)compileWithDevice:(id<MTLDevice>)device
                   taskDescriptor:(const tflite::gpu::metal::NodeDescriptor&)desc
                        precision:(tflite::gpu::CalculationsPrecision)precision;

/// Updates parameters for input, output and intermediate tensors.
/// @param tensorShapes is a map from tensors' ValueId to their BHWC shapes.
- (absl::Status)updateParamsWithDevice:(id<MTLDevice>)device
                          tensorShapes:(const std::map<tflite::gpu::ValueId, tflite::gpu::BHWC>&)
                                           tensorShapes;

/// Updates buffers for intermediate tensors only.
/// @param buffers is a map from intermediate tensors' ValueId to metal handles with corresponding
///     buffers.
/// @param outputIds must match the output of added operations.
/// @param usageRecordIds is a map from intermediate tensors' ValueId to corresponding tensor usage
///     records ids.
/// @param sharedBufferIds contain shared buffer id for each tensor usage record id.
/// @param sharedBuffers contain metal handles to the allocated buffers for each shared buffer id.
/// @return an error if out of memory or a buffer is larger than MTLDevice can support.
/// TODO(ypisarchyk): probably we can decrease the number of parameters here
- (absl::Status)assignBuffers:(std::map<::tflite::gpu::ValueId, id<MTLBuffer>>*)buffers
                    outputIds:(const std::vector<::tflite::gpu::ValueId>&)outputIds
               usageRecordIds:(const std::map<::tflite::gpu::ValueId, size_t>&)usageRecordIds
              sharedBufferIds:(const std::vector<size_t>&)sharedBufferIds
                sharedBuffers:(const std::vector<id<MTLBuffer>>&)sharedBuffers;

/// NOTE(review): presumably reports whether any of the given ids is one of this task's input or
/// output ids; the implementation is not visible from this header, so confirm before relying on it.
- (bool)hasInOutIds:(const std::set<::tflite::gpu::ValueId>&)ids;

/// Takes a map from tensors' ValueId to metal buffer handles.
/// NOTE(review): presumably rebinds the task's input/output buffers to the supplied handles;
/// confirm against the implementation.
- (void)updateBuffers:(const std::map<::tflite::gpu::ValueId, id<MTLBuffer>>&)inputOutputBuffers;

/// NOTE(review): presumably records this task's compute work into the given command encoder;
/// confirm against the implementation.
- (void)encodeWithEncoder:(id<MTLComputeCommandEncoder>)encoder;

/// Returns a vector of ValueIds; judging by the name, these identify the task's output tensors.
- (std::vector<tflite::gpu::ValueId>)getOutputIds;

/// Returns a vector of ValueIds; judging by the name, these identify the task's input tensors.
- (std::vector<tflite::gpu::ValueId>)getInputIds;

/// Sets a textual description for this task.
/// NOTE(review): how the description is used (e.g. debugging or labeling) is not visible here.
- (void)setDescription:(const std::string&)description;

@end
#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_COMPUTE_TASK_H_