blob: 388ca951d4d5c1076b5600b784b236008fd83e21 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/metal/compute_task.h"
#include <Availability.h>
#include <string>
#include <tuple>
#include "tensorflow/lite/delegates/gpu/metal/metal_arguments.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/common.h"
using ::tflite::gpu::AlignByN;
using ::tflite::gpu::BHWC;
using ::tflite::gpu::HalfBits;
using ::tflite::gpu::metal::ComputeTaskDescriptorPtr;
using ::tflite::gpu::metal::CreateComputeProgram;
using ::tflite::gpu::metal::DispatchParamsFunction;
using ::tflite::gpu::CalculationsPrecision;
using ::tflite::gpu::metal::UniformsFunction;
using ::tflite::gpu::uint3;
using ::tflite::gpu::ValueId;
namespace {
// One source tensor of a compute task.
struct InputBuffer {
// Graph value id of the tensor; used as the key into the shape/buffer maps.
ValueId uid;
// Metal buffer bound for this tensor. It is nil right after compile and is set
// by -assignBuffers:... or -updateBuffers:.
id<MTLBuffer> metalHandle;
};
// One destination tensor of a compute task (same layout as InputBuffer).
struct OutputBuffer {
// Graph value id of the tensor.
ValueId uid;
// Metal buffer bound for this tensor; nil until assigned.
id<MTLBuffer> metalHandle;
};
// A small constant block passed to the shader with -setBytes:length:atIndex:.
struct UniformBuffer {
// Raw bytes, recomputed from the current src/dst shapes in -updateParamsWithDevice:.
std::vector<uint8_t> data;
// Produces `data` from (src_shapes, dst_shapes).
UniformsFunction dataFunction;
};
} // namespace
// Metal implementation of a single compute task. It compiles one shader into a
// compute pipeline state, tracks the buffers bound to that shader, and encodes
// the threadgroup dispatch.
//
// -encodeWithEncoder: binds buffers in this order:
//   outputs, inputs, immutable buffers, uniform buffers, then MetalArguments.
@implementation TFLComputeTask {
  id<MTLComputePipelineState> _program;
  std::vector<InputBuffer> _inputBuffers;
  std::vector<OutputBuffer> _outputBuffers;
  std::vector<id<MTLBuffer>> _immutableBuffers;
  std::vector<UniformBuffer> _uniformBuffers;
  uint3 _groupsSize;
  uint3 _groupsCount;
  DispatchParamsFunction _resizeFunction;
  std::string _description;
  tflite::gpu::metal::MetalArguments _metal_args;
}

/// Compiles the task's shader and prepares its buffer bookkeeping.
///
/// @param device Metal device used to compile the shader and to allocate the
///        immutable buffers.
/// @param desc Node descriptor. Its task's args and shader source are handed to
///        MetalArguments::Init, and its immutable buffers are zero-padded in place.
/// @param precision Selects the storage/accumulator types substituted for the
///        FLT* / ACCUM_FLT* / TO_ACCUM* macros in the shader.
/// @return OkStatus on success. Returns an error if the descriptor has no
///         destination tensor, if argument setup fails, or if shader
///         compilation fails.
- (absl::Status)compileWithDevice:(id<MTLDevice>)device
                   taskDescriptor:(const tflite::gpu::metal::NodeDescriptor&)desc
                        precision:(CalculationsPrecision)precision {
  // Exactly one output buffer is bound (see the "+ 1" below and the single
  // OutputBuffer created further down), so at least one destination is required.
  if (desc.dst_tensors_ids.empty()) {
    return absl::InvalidArgumentError("Compute task has no destination tensor");
  }

  // Count of buffers bound before MetalArguments in -encodeWithEncoder:
  // inputs + uniforms + immutables + the single output.
  // NOTE(review): this counts src_tensors_names, while _inputBuffers is built
  // from desc.src_tensors_ids. They are presumably the same length; confirm.
  const size_t offset = desc.task->src_tensors_names.size() +
                        desc.task->uniform_buffers.size() +
                        desc.task->immutable_buffers.size() + 1;
  RETURN_IF_ERROR(_metal_args.Init(device, offset, &desc.task->args, &desc.task->shader_source));

  NSString* barrier;
  // simdgroup_barrier is supported on macOS 10.13+ and Metal shading language version 2.0
  if (@available(macOS 10.13, iOS 10.0, tvOS 10.0, *)) {
    barrier = @"simdgroup_barrier";
  } else {
    barrier = @"threadgroup_barrier";
  }

  NSString* storageType;
  NSString* accumulatorType;
  // The TO_ACCUM* conversion macros are empty unless storage (half) and
  // accumulation (float) types differ.
  NSString* toAccumulatorType = @"";
  NSString* toAccumulatorType2 = @"";
  NSString* toAccumulatorType3 = @"";
  NSString* toAccumulatorType4 = @"";
  if (precision == CalculationsPrecision::F32) {
    storageType = @"float";
    accumulatorType = @"float";
  } else {
    // FP16
    storageType = @"half";
    if (precision == CalculationsPrecision::F32_F16) {
      accumulatorType = @"float";
      toAccumulatorType = @"float";
      toAccumulatorType2 = @"float2";
      toAccumulatorType3 = @"float3";
      toAccumulatorType4 = @"float4";
    } else {
      accumulatorType = @"half";
    }
  }
  NSDictionary<NSString*, NSString*>* macros = @{
    @"FLT" : storageType,
    @"FLT2" : [NSString stringWithFormat:@"%@2", storageType],
    @"FLT3" : [NSString stringWithFormat:@"%@3", storageType],
    @"FLT4" : [NSString stringWithFormat:@"%@4", storageType],
    @"ACCUM_FLT" : accumulatorType,
    @"ACCUM_FLT2" : [NSString stringWithFormat:@"%@2", accumulatorType],
    @"ACCUM_FLT3" : [NSString stringWithFormat:@"%@3", accumulatorType],
    @"ACCUM_FLT4" : [NSString stringWithFormat:@"%@4", accumulatorType],
    @"TO_ACCUM_TYPE" : toAccumulatorType,
    @"TO_ACCUM2_TYPE" : toAccumulatorType2,
    @"TO_ACCUM3_TYPE" : toAccumulatorType3,
    @"TO_ACCUM4_TYPE" : toAccumulatorType4,
    @"SIMDGROUP_BARRIER" : barrier,
  };

  // Decode the shader text as UTF-8 rather than the locale-dependent
  // +defaultCStringEncoding, and fail cleanly if decoding does not succeed.
  NSString* code = [NSString stringWithUTF8String:desc.task->shader_source.c_str()];
  if (code == nil) {
    return absl::InvalidArgumentError("Shader source is not valid UTF-8");
  }
  id<MTLComputePipelineState> program;
  RETURN_IF_ERROR(CreateComputeProgram(device, code, @"ComputeFunction", macros, &program));
  if (!program) {
    return absl::InternalError("Unknown shader compilation error");
  }

  // Reset the bookkeeping so that compiling again does not append duplicates,
  // which would shift the bind indices used in -encodeWithEncoder:.
  _inputBuffers.clear();
  _outputBuffers.clear();
  _uniformBuffers.clear();
  _immutableBuffers.clear();

  _inputBuffers.reserve(desc.src_tensors_ids.size());
  for (const auto& tensorId : desc.src_tensors_ids) {
    _inputBuffers.push_back(InputBuffer{tensorId, nil});
  }
  _uniformBuffers.reserve(desc.task->uniform_buffers.size());
  for (const auto& uniform : desc.task->uniform_buffers) {
    _uniformBuffers.push_back(UniformBuffer{{}, uniform.data_function});
  }
  _outputBuffers.push_back(OutputBuffer{desc.dst_tensors_ids[0], nil});

  // Pad each immutable buffer to a multiple of four storage-type elements.
  // The padding is written back into the descriptor's data.
  const bool f32_storage = precision == CalculationsPrecision::F32;
  const size_t padding = 4 * (f32_storage ? sizeof(float) : sizeof(HalfBits));
  _immutableBuffers.reserve(desc.task->immutable_buffers.size());
  for (auto& immutable : desc.task->immutable_buffers) {
    const size_t paddedSize = AlignByN(immutable.data.size(), padding);
    immutable.data.resize(paddedSize);
    id<MTLBuffer> metalBuffer = [device newBufferWithBytes:immutable.data.data()
                                                    length:immutable.data.size()
                                                   options:MTLResourceStorageModeShared];
    _immutableBuffers.push_back(metalBuffer);
  }
  _resizeFunction = desc.task->resize_function;
  _program = program;
  return absl::OkStatus();
}

/// Recomputes the uniform data and dispatch sizes for the given tensor shapes.
///
/// @param device Used to query the per-dimension threadgroup size limits.
/// @param tensorShapes Shapes keyed by value id. Every input and output of this
///        task must be present.
/// @return InvalidArgumentError if a shape is missing or the requested
///         threadgroup size exceeds the device limits. In that case the
///         previous dispatch sizes are left unchanged.
- (absl::Status)updateParamsWithDevice:(id<MTLDevice>)device
                          tensorShapes:(const std::map<tflite::gpu::ValueId, tflite::gpu::BHWC>&)tensorShapes {
  std::vector<BHWC> src_shapes;
  src_shapes.reserve(_inputBuffers.size());
  for (const auto& in_buf : _inputBuffers) {
    const auto it = tensorShapes.find(in_buf.uid);
    if (it == tensorShapes.end()) {
      return absl::InvalidArgumentError("Missing tensor shape");
    }
    src_shapes.push_back(it->second);
  }
  std::vector<BHWC> dst_shapes;
  dst_shapes.reserve(_outputBuffers.size());
  for (const auto& out_buf : _outputBuffers) {
    const auto it = tensorShapes.find(out_buf.uid);
    if (it == tensorShapes.end()) {
      return absl::InvalidArgumentError("Missing tensor shape");
    }
    dst_shapes.push_back(it->second);
  }
  for (auto& uniform : _uniformBuffers) {
    uniform.data = uniform.dataFunction(src_shapes, dst_shapes);
  }

  // Dispatch parameters re-calculation: first = threads per group,
  // second = number of groups.
  const auto workGroups = _resizeFunction(src_shapes, dst_shapes);
  const auto& groupsSize = workGroups.first;
  const MTLSize threadsPerGroup = [device maxThreadsPerThreadgroup];
  if (groupsSize.x > threadsPerGroup.width || groupsSize.y > threadsPerGroup.height ||
      groupsSize.z > threadsPerGroup.depth) {
    std::string error("Threads per working group: ");
    error += std::to_string(groupsSize.x) + ", " + std::to_string(groupsSize.y) + ", " +
             std::to_string(groupsSize.z);
    error += " is larger than the MTLDevice can support: ";
    error += std::to_string(threadsPerGroup.width) + ", " + std::to_string(threadsPerGroup.height) +
             ", " + std::to_string(threadsPerGroup.depth);
    return absl::InvalidArgumentError(error);
  }
  // Commit only after validation, so the group size and count stay consistent.
  _groupsSize = groupsSize;
  _groupsCount = workGroups.second;
  return absl::OkStatus();
}

/// Binds Metal buffers to this task's tensors.
///
/// Outputs not listed in `outputIds` are treated as intermediate. They are
/// backed by `sharedBuffers[sharedBufferIds[usageRecordIds[uid]]]`, and that
/// handle is also published into `buffers`. Inputs are then (re)read from
/// `buffers`.
///
/// @return InternalError if an intermediate tensor has no usage record or the
///         shared-buffer indices are out of range.
- (absl::Status)assignBuffers:(std::map<::tflite::gpu::ValueId, id<MTLBuffer>>*)buffers
                    outputIds:(const std::vector<::tflite::gpu::ValueId>&)outputIds
               usageRecordIds:(const std::map<ValueId, size_t>&)usageRecordIds
              sharedBufferIds:(const std::vector<size_t>&)sharedBufferIds
                sharedBuffers:(const std::vector<id<MTLBuffer>>&)sharedBuffers {
  for (auto& buffer : _outputBuffers) {
    // Buffers listed in outputIds are not touched here; only intermediate
    // tensors get a shared buffer.
    if (std::find(outputIds.begin(), outputIds.end(), buffer.uid) != outputIds.end()) {
      continue;
    }
    const auto usageRecordIt = usageRecordIds.find(buffer.uid);
    if (usageRecordIt == usageRecordIds.end()) {
      return absl::InternalError("TensorUsageRecord for intermediate tensor is not found.");
    }
    // Report bad indices through Status instead of letting .at() throw.
    const size_t usageRecordId = usageRecordIt->second;
    if (usageRecordId >= sharedBufferIds.size()) {
      return absl::InternalError("Shared buffer id for intermediate tensor is out of range.");
    }
    const size_t sharedBufferId = sharedBufferIds[usageRecordId];
    if (sharedBufferId >= sharedBuffers.size()) {
      return absl::InternalError("Shared buffer for intermediate tensor is out of range.");
    }
    buffer.metalHandle = sharedBuffers[sharedBufferId];
    (*buffers)[buffer.uid] = buffer.metalHandle;
  }
  // Re-assign input buffers. operator[] inserts a nil entry for ids that are
  // not yet present, so such inputs get a nil handle.
  for (auto& buffer : _inputBuffers) {
    buffer.metalHandle = (*buffers)[buffer.uid];
  }
  return absl::OkStatus();
}

/// Returns true if any input or output of this task has its id in `ids`.
- (bool)hasInOutIds:(const std::set<::tflite::gpu::ValueId>&)ids {
  for (const auto& buffer : _inputBuffers) {
    if (ids.count(buffer.uid)) {
      return true;
    }
  }
  for (const auto& buffer : _outputBuffers) {
    if (ids.count(buffer.uid)) {
      return true;
    }
  }
  return false;
}

/// Replaces the handles of inputs/outputs whose ids appear in
/// `inputOutputBuffers`. Other handles are left as they are.
- (void)updateBuffers:(const std::map<::tflite::gpu::ValueId, id<MTLBuffer>>&)inputOutputBuffers {
  for (auto& buffer : _inputBuffers) {
    const auto externalBuffer = inputOutputBuffers.find(buffer.uid);
    if (externalBuffer != inputOutputBuffers.end()) {
      buffer.metalHandle = externalBuffer->second;
    }
  }
  for (auto& buffer : _outputBuffers) {
    const auto externalBuffer = inputOutputBuffers.find(buffer.uid);
    if (externalBuffer != inputOutputBuffers.end()) {
      buffer.metalHandle = externalBuffer->second;
    }
  }
}

/// Encodes this task's pipeline state, buffer bindings and threadgroup
/// dispatch into `encoder`. Nothing is encoded if the group count is zero in
/// any dimension.
- (void)encodeWithEncoder:(id<MTLComputeCommandEncoder>)encoder {
  // The dispatch call is intended to be skipped for an empty grid. Each
  // component is tested separately because their 32-bit product can wrap to 0
  // (or to a nonzero value) and misclassify the grid.
  if (_groupsCount.x == 0 || _groupsCount.y == 0 || _groupsCount.z == 0) {
    return;
  }
  [encoder setComputePipelineState:_program];
  // Binding order must match the buffer count passed to MetalArguments::Init
  // in -compileWithDevice:taskDescriptor:precision:.
  int bindIndex = 0;
  for (const auto& buffer : _outputBuffers) {
    [encoder setBuffer:buffer.metalHandle offset:0 atIndex:bindIndex];
    bindIndex++;
  }
  for (const auto& buffer : _inputBuffers) {
    [encoder setBuffer:buffer.metalHandle offset:0 atIndex:bindIndex];
    bindIndex++;
  }
  for (const auto& immutable : _immutableBuffers) {
    [encoder setBuffer:immutable offset:0 atIndex:bindIndex];
    bindIndex++;
  }
  for (const auto& uniform : _uniformBuffers) {
    [encoder setBytes:uniform.data.data() length:uniform.data.size() atIndex:bindIndex];
    bindIndex++;
  }
  _metal_args.Encode(encoder, bindIndex);
  const MTLSize groupsCount = MTLSizeMake(_groupsCount.x, _groupsCount.y, _groupsCount.z);
  const MTLSize groupsSize = MTLSizeMake(_groupsSize.x, _groupsSize.y, _groupsSize.z);
  [encoder dispatchThreadgroups:groupsCount threadsPerThreadgroup:groupsSize];
}

/// Value ids of this task's outputs, in binding order.
- (std::vector<tflite::gpu::ValueId>)getOutputIds {
  std::vector<tflite::gpu::ValueId> result;
  result.reserve(_outputBuffers.size());
  for (const auto& buffer : _outputBuffers) {
    result.push_back(buffer.uid);
  }
  return result;
}

/// Value ids of this task's inputs, in binding order.
- (std::vector<tflite::gpu::ValueId>)getInputIds {
  std::vector<tflite::gpu::ValueId> result;
  result.reserve(_inputBuffers.size());
  for (const auto& buffer : _inputBuffers) {
    result.push_back(buffer.uid);
  }
  return result;
}

/// Stores a human-readable description of the task.
- (void)setDescription:(const std::string&)description {
  _description = description;
}
@end