blob: f1c9d75e62a803a1c61f69a7fc9e1b7dd0310d20 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/metal/kernels/winograd.h"
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/strings/str_format.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/common/winograd_util.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
namespace tflite {
namespace gpu {
namespace metal {
namespace {
std::string GetKernelWinograd4x4To36() {
std::string c;
c += R"(
#include <metal_stdlib>
using namespace metal;
struct uniforms {
int4 src_size;
int4 dst_size;
int2 padding;
int2 tiles_count;
};
)";
auto bt_mat = BtMatrixForWinograd4x4To6x6();
c += "constant FLT Bt[36] = {\n";
for (int y = 0; y < 6; ++y) {
c += "\t";
for (int x = 0; x < 6; ++x) {
c += absl::StrFormat("%.10f", bt_mat[y * 6 + x]) + "f, ";
}
c += "\n";
}
c += "};\n";
c += R"(
$0
kernel void ComputeFunction($1
uint3 ugid[[thread_position_in_grid]])
{
int3 gid = int3(ugid.x * 4, ugid.y * 4, ugid.z);
if (ugid.x >= U.tiles_count.x || ugid.y >= U.tiles_count.y) return;
FLT4 I[6][6];
for (int y = 0; y < 6; ++y) {
for (int x = 0; x < 6; ++x) {
I[y][x] = FLT4(0.0f);
}
}
const int src_base = gid.z * U.src_size.y * U.src_size.x;
)";
for (int y = 0; y < 6; ++y) {
const std::string s_y = std::to_string(y);
c += " {\n";
c += " int coord_y = gid.y + " + s_y + " + U.padding.y;\n";
c += " bool in_y = FLT(coord_y >= 0 && coord_y < U.src_size.y);\n";
c += " coord_y = clamp(coord_y, 0, U.src_size.y - 1);\n";
c += " const int src_adress_y = src_base + coord_y * U.src_size.x;\n";
for (int x = 0; x < 6; ++x) {
const std::string s_x = std::to_string(x);
c += " {\n";
c += " int coord_x = gid.x + " + s_x + " + U.padding.x;\n";
c += " bool in_x = FLT(coord_x >= 0 && coord_x < U.src_size.x);\n";
c += " FLT mult = FLT(in_y && in_x);\n";
c += " coord_x = clamp(coord_x, 0, U.src_size.x - 1);\n";
c += " FLT4 src = src_buffer[src_adress_y + coord_x] * mult;\n";
c += " I[0][" + s_x + "] += Bt[" + std::to_string(y) + "] * src;\n";
c += " I[1][" + s_x + "] += Bt[" + std::to_string(y + 6) +
"] * src;\n";
c += " I[2][" + s_x + "] += Bt[" + std::to_string(y + 12) +
"] * src;\n";
c += " I[3][" + s_x + "] += Bt[" + std::to_string(y + 18) +
"] * src;\n";
c += " I[4][" + s_x + "] += Bt[" + std::to_string(y + 24) +
"] * src;\n";
c += " I[5][" + s_x + "] += Bt[" + std::to_string(y + 30) +
"] * src;\n";
c += " }\n";
}
c += " }\n";
}
c += R"(
int dst_x = ugid.y * U.tiles_count.x + ugid.x;
int dst_adress = gid.z * U.dst_size.y * U.dst_size.x + dst_x;
for (int y = 0; y < 6; ++y) {
dst_buffer[dst_adress] = I[y][0] + Bt[2] * I[y][2] + Bt[4] * I[y][4];
dst_adress += U.dst_size.x;
dst_buffer[dst_adress] = Bt[7] * I[y][1] + Bt[8] * I[y][2] + Bt[9] * I[y][3] + Bt[10] * I[y][4];
dst_adress += U.dst_size.x;
dst_buffer[dst_adress] = Bt[13] * I[y][1] + Bt[14] * I[y][2] + Bt[15] * I[y][3] + Bt[16] * I[y][4];
dst_adress += U.dst_size.x;
dst_buffer[dst_adress] = Bt[19] * I[y][1] + Bt[20] * I[y][2] + Bt[21] * I[y][3] + Bt[22] * I[y][4];
dst_adress += U.dst_size.x;
dst_buffer[dst_adress] = Bt[25] * I[y][1] + Bt[26] * I[y][2] + Bt[27] * I[y][3] + Bt[28] * I[y][4];
dst_adress += U.dst_size.x;
dst_buffer[dst_adress] = Bt[31] * I[y][1] + Bt[33] * I[y][3] + I[y][5];
dst_adress += U.dst_size.x;
}
}
)";
return c;
}
std::string GetKernelWinograd36To4x4() {
std::string c;
c += R"(
#include <metal_stdlib>
using namespace metal;
struct uniforms {
int4 src_size;
int4 dst_size;
};
)";
auto at_mat = AtMatrixForWinograd4x4To6x6();
c += "constant FLT At[24] = {\n";
for (int y = 0; y < 4; ++y) {
c += "\t";
for (int x = 0; x < 6; ++x) {
c += absl::StrFormat("%.10f", at_mat[y * 6 + x]) + "f, ";
}
c += "\n";
}
c += "};\n";
c += R"(
$0
kernel void ComputeFunction($1
uint3 global_ids[[thread_position_in_grid]])
{
int tile_id = global_ids.x;
int Z = static_cast<int>(global_ids.z);
int tiles_count_x = (U.dst_size.x + 3) / 4;
int tile_x = (tile_id % tiles_count_x) * 4;
int tile_y = (tile_id / tiles_count_x) * 4;
if (tile_x >= U.dst_size.x || tile_y >= U.dst_size.y) return;
int src_adress = Z * U.src_size.y * U.src_size.x + tile_id;
FLT4 I[4][6];
for (int y = 0; y < 4; ++y) {
for (int x = 0; x < 6; ++x) {
I[y][x] = 0.0f;
}
}
for (int y = 0; y < 6; ++y) {
for (int x = 0; x < 6; ++x, src_adress += U.src_size.x) {
FLT4 src = src_buffer[src_adress];
I[0][x] += src * At[y];
I[1][x] += src * At[y + 6];
I[2][x] += src * At[y + 12];
I[3][x] += src * At[y + 18];
}
}
FLT4 bias_val = biases[Z];
int dst_adress = (Z * U.dst_size.y + tile_y) * U.dst_size.x + tile_x;
for (int y = 0; y < 4 && tile_y + y < U.dst_size.y; ++y) {
FLT4 t0 = I[y][1] + I[y][2];
FLT4 t1 = I[y][3] + I[y][4];
if (tile_x < U.dst_size.x) {
FLT4 value = I[y][0] + t0 + t1 + bias_val;
int linear_index = dst_adress;
uint3 gid = uint3(tile_x, tile_y + y, global_ids.z);
$2
dst_buffer[linear_index] = value;
}
FLT4 t2 = I[y][1] - I[y][2];
FLT4 t3 = I[y][3] - I[y][4];
if (tile_x + 1 < U.dst_size.x) {
FLT4 value = t2 * At[7] + t3 * At[9] + bias_val;
int linear_index = dst_adress + 1;
uint3 gid = uint3(tile_x + 1, tile_y + y, global_ids.z);
$2
dst_buffer[linear_index] = value;
}
if (tile_x + 2 < U.dst_size.x) {
FLT4 value = t0 * At[13] + t1 * At[15] + bias_val;
int linear_index = dst_adress + 2;
uint3 gid = uint3(tile_x + 2, tile_y + y, global_ids.z);
$2
dst_buffer[linear_index] = value;
}
if (tile_x + 3 < U.dst_size.x) {
FLT4 value = t2 * At[19] + t3 * At[21] + I[y][5] + bias_val;
uint3 gid = uint3(tile_x + 3, tile_y + y, global_ids.z);
int linear_index = dst_adress + 3;
$2
dst_buffer[linear_index] = value;
}
dst_adress += U.dst_size.x;
}
}
)";
return c;
}
} // namespace
std::vector<ComputeTaskDescriptorPtr> Winograd4x4To36(
int id, ValueId input_id, ValueId output_id,
const Winograd4x4To36Attributes& attr) {
auto desc = std::make_shared<ComputeTaskDescriptor>();
desc->id = id;
desc->is_linkable = false;
desc->shader_source = GetKernelWinograd4x4To36();
desc->input_buffers = {
{input_id, "device FLT4* const src_buffer"},
};
desc->output_buffer = {
output_id, "device FLT4* dst_buffer",
[input_id, attr](const std::map<ValueId, BHWC>& buffers) {
const auto src_shape = buffers.find(input_id)->second;
int new_width = src_shape.w + attr.padding.prepended.w +
attr.padding.appended.w - 2;
int new_height = src_shape.h + attr.padding.prepended.h +
attr.padding.appended.h - 2;
BHWC dst_shape;
dst_shape.b = src_shape.b;
dst_shape.h = 36;
dst_shape.w = IntegralDivideRoundUp(new_width, 4) *
IntegralDivideRoundUp(new_height, 4);
dst_shape.c = src_shape.c;
return dst_shape;
}};
desc->uniform_buffers = {
{"constant uniforms& U",
[input_id, output_id, attr](const std::map<ValueId, BHWC>& buffers) {
const auto& src_shape = buffers.find(input_id)->second;
const auto& dst_shape = buffers.find(output_id)->second;
int new_width = src_shape.w + attr.padding.prepended.w +
attr.padding.appended.w - 2;
int new_height = src_shape.h + attr.padding.prepended.h +
attr.padding.appended.h - 2;
int tiles_x = IntegralDivideRoundUp(new_width, 4);
int tiles_y = IntegralDivideRoundUp(new_height, 4);
std::vector<int> sizes = {
src_shape.w,
src_shape.h,
IntegralDivideRoundUp(src_shape.c, 4),
0,
dst_shape.w,
dst_shape.h,
IntegralDivideRoundUp(dst_shape.c, 4),
0,
-attr.padding.prepended.w,
-attr.padding.prepended.h,
tiles_x,
tiles_y,
};
return GetByteBuffer(sizes);
}},
};
desc->resize_function = [input_id,
attr](const std::map<ValueId, BHWC>& buffers) {
const uint3 groups_size{8, 4, 1};
const auto& src_shape = buffers.find(input_id)->second;
int new_width =
src_shape.w + attr.padding.prepended.w + attr.padding.appended.w - 2;
int new_height =
src_shape.h + attr.padding.prepended.h + attr.padding.appended.h - 2;
int grid_x = IntegralDivideRoundUp(new_width, 4);
int grid_y = IntegralDivideRoundUp(new_height, 4);
int grid_z = IntegralDivideRoundUp(src_shape.c, 4);
int groups_x = IntegralDivideRoundUp(grid_x, groups_size.x);
int groups_y = IntegralDivideRoundUp(grid_y, groups_size.y);
int groups_z = IntegralDivideRoundUp(grid_z, groups_size.z);
return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z});
};
return {desc};
}
std::vector<ComputeTaskDescriptorPtr> Winograd36To4x4(
int id, ValueId input_id, ValueId output_id, const RuntimeOptions& options,
const Winograd36To4x4Attributes& attr) {
auto desc = std::make_shared<ComputeTaskDescriptor>();
desc->id = id;
desc->is_linkable = false;
desc->shader_source = GetKernelWinograd36To4x4();
desc->input_buffers = {
{input_id, "device FLT4* const src_buffer"},
};
desc->output_buffer = {
output_id, "device FLT4* dst_buffer",
[input_id, attr](const std::map<ValueId, BHWC>& buffers) {
const auto src_shape = buffers.find(input_id)->second;
BHWC dst_shape;
dst_shape.b = src_shape.b;
dst_shape.h = attr.output_shape.h;
dst_shape.w = attr.output_shape.w;
dst_shape.c = src_shape.c;
return dst_shape;
}};
desc->immutable_buffers = {
{"device FLT4* const biases",
GetByteBufferConvertedResized(attr.biases.data,
options.storage_precision,
AlignByN(attr.output_shape.c, 4))},
};
desc->uniform_buffers = {
{"constant uniforms& U",
[input_id, output_id](const std::map<ValueId, BHWC>& buffers) {
const auto& src_shape = buffers.find(input_id)->second;
const auto& dst_shape = buffers.find(output_id)->second;
std::vector<int> sizes = {
src_shape.w, src_shape.h, IntegralDivideRoundUp(src_shape.c, 4), 0,
dst_shape.w, dst_shape.h, IntegralDivideRoundUp(dst_shape.c, 4), 0,
};
return GetByteBuffer(sizes);
}},
};
desc->resize_function = [input_id](const std::map<ValueId, BHWC>& buffers) {
const uint3 groups_size{32, 1, 1};
const auto& src_shape = buffers.find(input_id)->second;
int grid_x = src_shape.w;
int grid_y = 1;
int grid_z = IntegralDivideRoundUp(src_shape.c, 4);
int groups_x = IntegralDivideRoundUp(grid_x, groups_size.x);
int groups_y = IntegralDivideRoundUp(grid_y, groups_size.y);
int groups_z = IntegralDivideRoundUp(grid_z, groups_size.z);
return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z});
};
return {desc};
}
} // namespace metal
} // namespace gpu
} // namespace tflite