/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/metal/kernels/resize.h"
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
namespace tflite {
namespace gpu {
namespace metal {
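
// Generates the Metal Shading Language kernel for bilinear resize.
// $0, $1 and $2 are placeholders substituted later by the compute task
// builder (shader definitions, kernel arguments, and linked-operation code).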
std::string GetResizeBilinearCode(bool half_pixel_centers) {
  std::string code = R"(
#include <metal_stdlib>
using namespace metal;
$0
kernel void ComputeFunction(
                            $1
                            uint3 gid[[thread_position_in_grid]]) {
  if (int(gid.x) >= size.z || int(gid.y) >= size.w) {
    return;
  })";
  if (half_pixel_centers) {
    code += "const float2 tex_coord = (float2(gid.xy) + 0.5f) * scale - 0.5f;";
  } else {
    code += "const float2 tex_coord = float2(gid.xy) * scale;";
  }
code += R"(
const float2 tex_coord_floor = floor(tex_coord);
const int2 itex_coord_floor = int2(tex_coord_floor);
const int2 borders = size.xy - int2(1, 1);
int4 st;
st.xy = max(itex_coord_floor, int2(0, 0));
st.zw = min(itex_coord_floor + int2(1, 1), borders);
const float2 t = tex_coord - tex_coord_floor; // interpolating factors
const int src_index0 = (gid.z * size.y + st.y) * size.x + st.x;
const int src_index1 = (gid.z * size.y + st.y) * size.x + st.z;
const int src_index2 = (gid.z * size.y + st.w) * size.x + st.x;
const int src_index3 = (gid.z * size.y + st.w) * size.x + st.z;
FLT4 tex11 = src_buffer[src_index0];
FLT4 tex21 = src_buffer[src_index1];
FLT4 tex12 = src_buffer[src_index2];
FLT4 tex22 = src_buffer[src_index3];
// bilinear interpolation
FLT4 value = mix(mix(tex11, tex21, static_cast<FLT>(t.x)),
mix(tex12, tex22, static_cast<FLT>(t.x)), static_cast<FLT>(t.y));
const int linear_index = (gid.z * size.w + gid.y) * size.z + gid.x;
$2
output_buffer[linear_index] = value;
}
)";
return code;
}
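
// Generates the MSL kernel for nearest-neighbor resize: each output pixel
// copies the source texel at the scaled coordinate.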
std::string GetResizeNearestCode() {
  return R"(
#include <metal_stdlib>
using namespace metal;
$0
kernel void ComputeFunction(
                            $1
                            uint3 gid[[thread_position_in_grid]]) {
  if (int(gid.x) >= size.z || int(gid.y) >= size.w) {
    return;
  }
  const int2 coord = int2(float2(gid.xy) * scale);
  const int src_index = (gid.z * size.y + coord.y) * size.x + coord.x;
  FLT4 value = src_buffer[src_index];
  const int linear_index = (gid.z * size.w + gid.y) * size.z + gid.x;
  $2
  output_buffer[linear_index] = value;
}
)";
}
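
// Builds the compute task descriptor for a Resize2D operation, selecting the
// bilinear or nearest-neighbor shader based on attr.type.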
std::vector<ComputeTaskDescriptorPtr> Resize(int id, ValueId input_id,
                                             ValueId output_id,
                                             const Resize2DAttributes& attr) {
  auto desc = std::make_shared<ComputeTaskDescriptor>();
  desc->id = id;
  desc->is_linkable = false;
  switch (attr.type) {
    case SamplingType::BILINEAR:
      desc->shader_source = GetResizeBilinearCode(attr.half_pixel_centers);
      break;
    case SamplingType::NEAREST:
      desc->shader_source = GetResizeNearestCode();
      break;
    default:
      // Unknown sampling type.
      return {};
  }
  desc->input_buffers = {
      {input_id, "device FLT4* const src_buffer"},
  };

  desc->output_buffer = {
      output_id, "device FLT4* output_buffer",
      [input_id, attr](const std::map<ValueId, BHWC>& buffers) {
        return CalculateOutputShape(buffers.find(input_id)->second, attr);
      }};
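
  // Uniforms consumed by the shaders: `size` packs
  // (input_width, input_height, output_width, output_height).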
  desc->uniform_buffers = {
      {"constant int4& size",
       [input_id, output_id](const std::map<ValueId, BHWC>& buffers) {
         const auto& dimension = buffers.find(input_id)->second;
         const auto& output_dimension = buffers.find(output_id)->second;
         std::vector<int> sizes = {
             dimension.w,
             dimension.h,
             output_dimension.w,
             output_dimension.h,
         };
         return GetByteBuffer(sizes);
       }},
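      // `scale` maps output coordinates back to input coordinates per axis;
      // CalculateResizeScale applies the align_corners adjustment when set.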
{"constant float2& scale",
[input_id, output_id, attr](const std::map<ValueId, BHWC>& buffers) {
const auto& input_dimensions = buffers.find(input_id)->second;
const auto& output_dimensions = buffers.find(output_id)->second;
std::vector<float> sizes = {
CalculateResizeScale(input_dimensions.w, output_dimensions.w,
attr),
CalculateResizeScale(input_dimensions.h, output_dimensions.h,
attr),
};
return GetByteBuffer(sizes);
}},
};
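
  // Dispatch one thread per output pixel per slice of 4 channels, using
  // 16x16x1 threadgroups.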
  desc->resize_function = [output_id](const std::map<ValueId, BHWC>& buffers) {
    const uint3 groups_size{16, 16, 1};
    const auto& dst_dim = buffers.find(output_id)->second;
    int groups_x = IntegralDivideRoundUp(dst_dim.w, groups_size.x);
    int groups_y = IntegralDivideRoundUp(dst_dim.h, groups_size.y);
    const int dst_layers = IntegralDivideRoundUp(dst_dim.c, 4);
    int groups_z = IntegralDivideRoundUp(dst_layers, groups_size.z);
    return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z});
  };
  return {desc};
}

}  // namespace metal
}  // namespace gpu
}  // namespace tflite