/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.h"
#include <cstdint>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "absl/strings/substitute.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
#include "tensorflow/lite/delegates/gpu/metal/environment.h"
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
namespace tflite {
namespace gpu {
namespace metal {
namespace {
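
// Generates the Metal shader for a fully connected layer. The layer is
// computed as a matrix-vector product: each thread accumulates dot products
// of FLT4 slices of the input vector with the corresponding weight entries.
// When `shared_memory` is true, the input vector is staged through
// threadgroup memory in slices of 32 FLT4 values so that all threads in a
// group can reuse them.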
std::string GetFullyConnectedCode(bool shared_memory, int src_channels,
                                  int dst_channels) {
  const int src_depth = IntegralDivideRoundUp(src_channels, 4);
  std::stringstream code;
  code << R"(
    #include <metal_stdlib>
    using namespace metal;

    struct uniforms {
      uint src_depth;
      uint dst_channels;
      uint out_channels;
      uint dummy;
    };
    $$0
    kernel void ComputeFunction(
                                $$1
                                uint3 tid[[thread_position_in_threadgroup]],
                                uint tid_index[[thread_index_in_threadgroup]],
                                uint3 ugid[[thread_position_in_grid]]) {
)";
  if (shared_memory) {
    code << R"(
  float summa = 0.0f;
  threadgroup FLT4 local_vector[32];
  for (int j = 0; j < $0; ++j) {
    local_vector[tid_index] = j * 32 + tid_index >= params.src_depth ?
        FLT4(0.0f) : vector[j * 32 + tid_index];
    BARRIER(mem_flags::mem_threadgroup);
    for (uint i = 0, counter = j * 32 + tid.y * 8; i < 8; ++i, ++counter) {
      summa += dot(local_vector[tid.y * 8 + i],
                   matrix[counter * params.dst_channels + ugid.x]);
    }
    BARRIER(mem_flags::mem_none);
  }
)";
  } else {
    code << R"(
  float summa = 0.0f;
  uint counter = ugid.y * $0;
  for (uint i = 0; i < $0; ++i, ++counter) {
)";
    if (src_depth % 4 != 0) {
      code << " if (counter >= params.src_depth) continue;" << std::endl;
    }
    code << " summa += dot(vector[counter], matrix[counter * "
            "params.dst_channels + ugid.x]);"
         << std::endl;
    code << " }" << std::endl;
  }
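  // Reduce the four partial sums per output channel across tid.y, add the
  // bias, and let every fourth thread in x write one FLT4 of results.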
code << R"(
threadgroup float temp[8][4];
temp[tid.x][tid.y] = summa;
BARRIER(mem_flags::mem_threadgroup);
if (tid.y == 0) {
summa += temp[tid.x][1];
summa += temp[tid.x][2];
summa += temp[tid.x][3];
temp[tid.x][0] = summa;
}
BARRIER(mem_flags::mem_threadgroup);
if (tid.y == 0 && tid.x % 4 == 0 && ugid.x < params.out_channels) {
const int linear_index = ugid.x / 4;
FLT4 value = FLT4(temp[tid.x][0], temp[tid.x + 1][0], temp[tid.x + 2][0], temp[tid.x + 3][0]) +
biases[linear_index];
uint3 gid = uint3(1u, 1u, uint(linear_index));
$$2
result[linear_index] = value;
}
}
)";
const int src_depth_sub_groups = shared_memory
? IntegralDivideRoundUp(src_depth, 32)
: IntegralDivideRoundUp(src_depth, 4);
return absl::Substitute(code.str(), src_depth_sub_groups);
}
} // namespace
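
// Builds the ComputeTaskDescriptor for a fully connected layer: the generated
// shader, the reordered weight matrix, bias and uniform buffers, and the
// dispatch sizes. The threadgroup-memory shader variant is selected when
// GetAppleSocVersion() reports 7 or 8 (presumably the A7/A8 GPUs).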
std::vector<ComputeTaskDescriptorPtr> FullyConnected(
    int id, ValueId input_id, ValueId output_id,
    const FullyConnectedAttributes& attr, const RuntimeOptions& options) {
  auto desc = std::make_shared<ComputeTaskDescriptor>();
  desc->id = id;
  desc->is_linkable = false;
  int gpu_type = GetAppleSocVersion();
  bool shared = gpu_type == 7 || gpu_type == 8;
  desc->shader_source =
      GetFullyConnectedCode(shared, attr.weights.shape.i, attr.weights.shape.o);

  desc->input_buffers = {
      {input_id, "device FLT4* const vector"},
  };

  desc->output_buffer = {
      output_id, "device FLT4* result",
      [input_id, attr](const std::map<ValueId, BHWC>& buffers) {
        return CalculateOutputShape(buffers.find(input_id)->second, attr);
      }};
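
  // Reorder the weights so the shader can address them as
  // matrix[counter * dst_channels_aligned + out_channel]: each FLT4 packs 4
  // consecutive input channels for one output channel, and out-of-range
  // entries are zero-padded.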
  const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4);
  const int src_depth_aligned = AlignByN(src_depth, shared ? 32 : 4);
  const int dst_channels_aligned = AlignByN(attr.weights.shape.o, 8);
  int counter = 0;
  std::vector<float> filters_reordered(dst_channels_aligned *
                                       src_depth_aligned * 4);
  for (int j = 0; j < src_depth_aligned; ++j) {
    for (int i = 0; i < dst_channels_aligned; ++i) {
      for (int k = 0; k < 4; ++k) {
        if (j * 4 + k >= attr.weights.shape.i || i >= attr.weights.shape.o) {
          filters_reordered[counter++] = 0.0f;
        } else {
          const int f_index =
              attr.weights.shape.LinearIndex({i, 0, 0, j * 4 + k});
          filters_reordered[counter++] = attr.weights.data[f_index];
        }
      }
    }
  }
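
  // Weights and biases are uploaded once as immutable buffers in the
  // requested storage precision; the bias buffer is resized to the number of
  // output channels.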
  desc->immutable_buffers = {
      {"device FLT4* const matrix",
       GetByteBufferConverted(filters_reordered, options.storage_precision)},
      {"device FLT4* const biases",
       GetByteBufferConvertedResized(attr.bias.data, options.storage_precision,
                                     attr.weights.shape.o)},
  };

  desc->uniform_buffers = {
      {"constant uniforms& params",
       [attr](const std::map<ValueId, BHWC>& buffers) {
         std::vector<uint32_t> uniform_params{
             static_cast<uint32_t>(
                 IntegralDivideRoundUp(attr.weights.shape.i, 4)),
             static_cast<uint32_t>(AlignByN(attr.weights.shape.o, 8)),
             static_cast<uint32_t>(attr.weights.shape.o),
             static_cast<uint32_t>(0),
         };
         return GetByteBuffer(uniform_params);
       }},
  };
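
  // Each threadgroup is 8x4 threads: 8 output channels along x, 4 partial
  // sums along y. The grid covers the aligned output channel count.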
  desc->resize_function = [attr](const std::map<ValueId, BHWC>& buffers) {
    const uint3 groups_size{8, 4, 1};
    const int dst_channels_aligned = AlignByN(attr.weights.shape.o, 8);
    int groups_x = IntegralDivideRoundUp(dst_channels_aligned, groups_size.x);
    return std::make_pair(groups_size, uint3{groups_x, 1, 1});
  };

  return {desc};
}

}  // namespace metal
}  // namespace gpu
}  // namespace tflite