blob: 9fa627bcac233f2cd057ec3d9ae9c7ffa838d213 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.h"
#include <map>
#include <memory>
#include <utility>
#include <vector>
#include "absl/strings/substitute.h"
#include "absl/types/span.h"
#include "tensorflow/lite/delegates/gpu/common/convert.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
namespace tflite {
namespace gpu {
namespace metal {
namespace {
// Returns the Metal source for the specialized 3x3 depthwise convolution with
// unit stride and unit dilation.
//
// Each thread produces a 2x2 block of output pixels for one FLT4 slice:
//   r0 -> (gid_x,     gid_y)        l0 -> (gid_x,     gid_y + 1)
//   t0 -> (gid_x + 1, gid_y)        b0 -> (gid_x + 1, gid_y + 1)
// It loads the 4x4 input window one column at a time. Every loaded value is
// reused for all outputs that need it. Out-of-range input taps are clamped to
// a valid address and then multiplied by zero.
//
// `filters` holds 10 FLT4 per slice: nine taps in x-major order
// (index = x * 3 + y), then the bias at index 9.
//
// $0, $1 and $2 are placeholders substituted elsewhere (not visible here).
// $0 precedes the kernel, $1 sits in the kernel parameter list, and $2 is
// emitted right before each store to `dst_buffer`, with `value`,
// `linear_index` and `gid` in scope.
std::string GetKernelDepthWiseConv3x3Stride1x1() {
  static constexpr char kShaderSource[] = R"(
#include <metal_stdlib>
using namespace metal;
struct uniforms {
int4 src_size;
int4 dst_size;
int2 padding;
int2 dummy0; // for alignment
int4 dummy1; // for alignment
};
$0
kernel void ComputeFunction(
$1
uint3 ugid[[thread_position_in_grid]])
{
int gid_x = ugid.x * 2;
int gid_y = ugid.y * 2;
int gid_z = ugid.z;
if (gid_x >= params.dst_size.x || gid_y >= params.dst_size.y) {
return;
}
ACCUM_FLT4 r0 = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);
ACCUM_FLT4 l0 = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);
ACCUM_FLT4 t0 = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);
ACCUM_FLT4 b0 = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);
int x0 = gid_x + params.padding.x;
int x1 = gid_x + params.padding.x + 1;
int x2 = gid_x + params.padding.x + 2;
int x3 = gid_x + params.padding.x + 3;
int y0 = gid_y + params.padding.y;
int y1 = gid_y + params.padding.y + 1;
int y2 = gid_y + params.padding.y + 2;
int y3 = gid_y + params.padding.y + 3;
bool x0_out = x0 < 0 || x0 >= params.src_size.x;
bool x1_out = x1 < 0 || x1 >= params.src_size.x;
bool x2_out = x2 < 0 || x2 >= params.src_size.x;
bool x3_out = x3 < 0 || x3 >= params.src_size.x;
bool y0_out = y0 < 0 || y0 >= params.src_size.y;
bool y1_out = y1 < 0 || y1 >= params.src_size.y;
bool y2_out = y2 < 0 || y2 >= params.src_size.y;
bool y3_out = y3 < 0 || y3 >= params.src_size.y;
x0 = clamp(x0, 0, params.src_size.x - 1);
x1 = clamp(x1, 0, params.src_size.x - 1);
x2 = clamp(x2, 0, params.src_size.x - 1);
x3 = clamp(x3, 0, params.src_size.x - 1);
y0 = clamp(y0, 0, params.src_size.y - 1);
y1 = clamp(y1, 0, params.src_size.y - 1);
y2 = clamp(y2, 0, params.src_size.y - 1);
y3 = clamp(y3, 0, params.src_size.y - 1);
device FLT4* src_loc = src_buffer + gid_z * params.src_size.z;
device FLT4* filters_loc = filters + gid_z * 10;
FLT4 s0 = src_loc[y0 * params.src_size.x + x0] * FLT(!(x0_out || y0_out));
FLT4 s1 = src_loc[y1 * params.src_size.x + x0] * FLT(!(x0_out || y1_out));
FLT4 s2 = src_loc[y2 * params.src_size.x + x0] * FLT(!(x0_out || y2_out));
FLT4 s3 = src_loc[y3 * params.src_size.x + x0] * FLT(!(x0_out || y3_out));
r0 += TO_ACCUM4_TYPE(s0 * filters_loc[0]);
r0 += TO_ACCUM4_TYPE(s1 * filters_loc[1]);
r0 += TO_ACCUM4_TYPE(s2 * filters_loc[2]);
l0 += TO_ACCUM4_TYPE(s1 * filters_loc[0]);
l0 += TO_ACCUM4_TYPE(s2 * filters_loc[1]);
l0 += TO_ACCUM4_TYPE(s3 * filters_loc[2]);
s0 = src_loc[y0 * params.src_size.x + x1] * FLT(!(x1_out || y0_out));
s1 = src_loc[y1 * params.src_size.x + x1] * FLT(!(x1_out || y1_out));
s2 = src_loc[y2 * params.src_size.x + x1] * FLT(!(x1_out || y2_out));
s3 = src_loc[y3 * params.src_size.x + x1] * FLT(!(x1_out || y3_out));
r0 += TO_ACCUM4_TYPE(s0 * filters_loc[3]);
r0 += TO_ACCUM4_TYPE(s1 * filters_loc[4]);
r0 += TO_ACCUM4_TYPE(s2 * filters_loc[5]);
l0 += TO_ACCUM4_TYPE(s1 * filters_loc[3]);
l0 += TO_ACCUM4_TYPE(s2 * filters_loc[4]);
l0 += TO_ACCUM4_TYPE(s3 * filters_loc[5]);
t0 += TO_ACCUM4_TYPE(s0 * filters_loc[0]);
t0 += TO_ACCUM4_TYPE(s1 * filters_loc[1]);
t0 += TO_ACCUM4_TYPE(s2 * filters_loc[2]);
b0 += TO_ACCUM4_TYPE(s1 * filters_loc[0]);
b0 += TO_ACCUM4_TYPE(s2 * filters_loc[1]);
b0 += TO_ACCUM4_TYPE(s3 * filters_loc[2]);
s0 = src_loc[y0 * params.src_size.x + x2] * FLT(!(x2_out || y0_out));
s1 = src_loc[y1 * params.src_size.x + x2] * FLT(!(x2_out || y1_out));
s2 = src_loc[y2 * params.src_size.x + x2] * FLT(!(x2_out || y2_out));
s3 = src_loc[y3 * params.src_size.x + x2] * FLT(!(x2_out || y3_out));
r0 += TO_ACCUM4_TYPE(s0 * filters_loc[6]);
r0 += TO_ACCUM4_TYPE(s1 * filters_loc[7]);
r0 += TO_ACCUM4_TYPE(s2 * filters_loc[8]);
l0 += TO_ACCUM4_TYPE(s1 * filters_loc[6]);
l0 += TO_ACCUM4_TYPE(s2 * filters_loc[7]);
l0 += TO_ACCUM4_TYPE(s3 * filters_loc[8]);
t0 += TO_ACCUM4_TYPE(s0 * filters_loc[3]);
t0 += TO_ACCUM4_TYPE(s1 * filters_loc[4]);
t0 += TO_ACCUM4_TYPE(s2 * filters_loc[5]);
b0 += TO_ACCUM4_TYPE(s1 * filters_loc[3]);
b0 += TO_ACCUM4_TYPE(s2 * filters_loc[4]);
b0 += TO_ACCUM4_TYPE(s3 * filters_loc[5]);
s0 = src_loc[y0 * params.src_size.x + x3] * FLT(!(x3_out || y0_out));
s1 = src_loc[y1 * params.src_size.x + x3] * FLT(!(x3_out || y1_out));
s2 = src_loc[y2 * params.src_size.x + x3] * FLT(!(x3_out || y2_out));
s3 = src_loc[y3 * params.src_size.x + x3] * FLT(!(x3_out || y3_out));
t0 += TO_ACCUM4_TYPE(s0 * filters_loc[6]);
t0 += TO_ACCUM4_TYPE(s1 * filters_loc[7]);
t0 += TO_ACCUM4_TYPE(s2 * filters_loc[8]);
b0 += TO_ACCUM4_TYPE(s1 * filters_loc[6]);
b0 += TO_ACCUM4_TYPE(s2 * filters_loc[7]);
b0 += TO_ACCUM4_TYPE(s3 * filters_loc[8]);
r0 += TO_ACCUM4_TYPE(filters_loc[9]);
l0 += TO_ACCUM4_TYPE(filters_loc[9]);
t0 += TO_ACCUM4_TYPE(filters_loc[9]);
b0 += TO_ACCUM4_TYPE(filters_loc[9]);
const int offset_0 = gid_z * params.dst_size.z + gid_y * params.dst_size.x + gid_x;
const int offset_1 = offset_0 + params.dst_size.x;
const int offset_2 = offset_0 + 1;
const int offset_3 = offset_0 + params.dst_size.x + 1;
bool x0_in = gid_x < params.dst_size.x;
bool x1_in = gid_x + 1 < params.dst_size.x;
bool y0_in = gid_y < params.dst_size.y;
bool y1_in = gid_y + 1 < params.dst_size.y;
if (y0_in && x0_in) {
int linear_index = offset_0;
FLT4 value = FLT4(r0);
uint3 gid = uint3(gid_x, gid_y, gid_z);
$2
dst_buffer[linear_index] = value;
}
if (y1_in && x0_in) {
int linear_index = offset_1;
FLT4 value = FLT4(l0);
uint3 gid = uint3(gid_x, gid_y + 1, gid_z);
$2
dst_buffer[linear_index] = value;
}
if (y0_in && x1_in) {
int linear_index = offset_2;
FLT4 value = FLT4(t0);
uint3 gid = uint3(gid_x + 1, gid_y, gid_z);
$2
dst_buffer[linear_index] = value;
}
if (y1_in && x1_in) {
int linear_index = offset_3;
FLT4 value = FLT4(b0);
uint3 gid = uint3(gid_x + 1, gid_y + 1, gid_z);
$2
dst_buffer[linear_index] = value;
}
}
)";
  return std::string(kShaderSource);
}
// Packs the 3x3 depthwise weights and the bias into the single `filters`
// buffer read by the DepthWiseConv3x3Stride1x1 shader.
//
// For every slice of four channels the output contains ten groups of four
// floats (ten FLT4 once converted):
//   * nine kernel taps in x-major order (x outer, y inner; tap = x * 3 + y),
//     which matches the column-by-column order in which the shader consumes
//     them;
//   * the bias for the four channels of the slice.
// Lanes past the real channel count, or past the bias length, are zero-filled
// so the last slice is always complete.
std::vector<float> ReorderWeightsDepthWiseConv3x3Stride1x1(
    const DepthwiseConvolution2DAttributes& attr) {
  const int kernel_width = 3;
  const int kernel_height = 3;
  const int channel_count = attr.weights.shape.i;
  const int slice_count = IntegralDivideRoundUp(channel_count, 4);

  std::vector<float> packed;
  packed.reserve((kernel_width * kernel_height + 1) * slice_count * 4);

  for (int slice = 0; slice < slice_count; ++slice) {
    const int first_channel = slice * 4;
    for (int kx = 0; kx < kernel_width; ++kx) {
      for (int ky = 0; ky < kernel_height; ++ky) {
        for (int lane = 0; lane < 4; ++lane) {
          const int channel = first_channel + lane;
          packed.push_back(
              channel < channel_count
                  ? attr.weights.data[attr.weights.shape.LinearIndex(
                        {0, ky, kx, channel})]
                  : 0.0f);
        }
      }
    }
    for (int lane = 0; lane < 4; ++lane) {
      const int channel = first_channel + lane;
      packed.push_back(channel < attr.bias.shape.v ? attr.bias.data[channel]
                                                   : 0.0f);
    }
  }
  return packed;
}
static std::vector<uint8_t> GetUniformBufferDepthWiseConv3x3Stride1x1(
const BHWC& src_size, const BHWC& dst_size,
const DepthwiseConvolution2DAttributes& params) {
std::vector<int> uniform_params = {
src_size.w,
src_size.h,
src_size.w * src_size.h,
IntegralDivideRoundUp(src_size.c, 4),
dst_size.w,
dst_size.h,
dst_size.w * dst_size.h,
IntegralDivideRoundUp(dst_size.c, 4),
-params.padding.prepended.w,
-params.padding.prepended.h,
0, // dummy, for alignment
0, // dummy, for alignment
0, // dummy, for alignment
0, // dummy, for alignment
0, // dummy, for alignment
0, // dummy, for alignment
};
return GetByteBuffer(uniform_params);
}
// Returns the Metal source for the specialized 3x3 depthwise convolution with
// vertical stride 2.
//
// Each thread computes two vertically adjacent output pixels of one FLT4
// slice: r0 at row gid_y and l0 at row gid_y + 1. The vertical stride (2) and
// vertical dilation (1) are hard-coded in the row arithmetic. The horizontal
// stride and dilation come from `params.stride.x` / `params.dilation.x`.
// Out-of-range input taps are clamped to a valid address and then multiplied
// by zero.
//
// `filters` holds 10 FLT4 per slice: nine taps in row-major order
// (index = y * 3 + x), then the bias at index 9.
//
// $0, $1 and $2 are placeholders substituted elsewhere (not visible here).
// $0 precedes the kernel, $1 sits in the kernel parameter list, and $2 is
// emitted right before each store to `dst_buffer`, where `value`,
// `linear_index` and `gid` are in scope. `gid` must therefore name the pixel
// actually being written. The second store writes l0 to
// offset_1 = offset_0 + dst_size.x, i.e. row gid_y + 1, so its gid uses
// gid_y + 1. Previously it reused gid_y, which gave linked code the wrong
// coordinate.
std::string GetKernelDepthWiseConv3x3Stride2() {
  std::string code = R"(
#include <metal_stdlib>
using namespace metal;
struct uniforms {
int4 src_size;
int4 dst_size;
int2 padding;
int2 stride;
int2 dilation;
int2 dummy0; // for alignment
};
$0
kernel void ComputeFunction(
$1
uint3 ugid[[thread_position_in_grid]])
{
int gid_x = ugid.x;
int gid_y = ugid.y * 2;
int gid_z = ugid.z;
if (gid_x >= params.dst_size.x || gid_y >= params.dst_size.y) {
return;
}
device FLT4* src_loc = src_buffer + gid_z * params.src_size.z;
device FLT4* filters_loc = filters + gid_z * 10;
ACCUM_FLT4 r0 = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);
ACCUM_FLT4 l0 = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);
int x0 = gid_x * params.stride.x + params.padding.x;
int x1 = gid_x * params.stride.x + params.padding.x + params.dilation.x;
int x2 = gid_x * params.stride.x + params.padding.x + 2 * params.dilation.x;
int y0 = gid_y * 2 + params.padding.y;
int y1 = gid_y * 2 + params.padding.y + 1;
int y2 = gid_y * 2 + params.padding.y + 2;
int y3 = gid_y * 2 + params.padding.y + 3;
int y4 = gid_y * 2 + params.padding.y + 4;
bool x0_out = x0 < 0 || x0 >= params.src_size.x;
bool x1_out = x1 < 0 || x1 >= params.src_size.x;
bool x2_out = x2 < 0 || x2 >= params.src_size.x;
bool y0_out = y0 < 0 || y0 >= params.src_size.y;
bool y1_out = y1 < 0 || y1 >= params.src_size.y;
bool y2_out = y2 < 0 || y2 >= params.src_size.y;
bool y3_out = y3 < 0 || y3 >= params.src_size.y;
bool y4_out = y4 < 0 || y4 >= params.src_size.y;
x0 = clamp(x0, 0, params.src_size.x - 1);
x1 = clamp(x1, 0, params.src_size.x - 1);
x2 = clamp(x2, 0, params.src_size.x - 1);
y0 = clamp(y0, 0, params.src_size.y - 1);
y1 = clamp(y1, 0, params.src_size.y - 1);
y2 = clamp(y2, 0, params.src_size.y - 1);
y3 = clamp(y3, 0, params.src_size.y - 1);
y4 = clamp(y4, 0, params.src_size.y - 1);
FLT4 s0 = src_loc[y0 * params.src_size.x + x0] * FLT(!(x0_out || y0_out));
FLT4 s1 = src_loc[y0 * params.src_size.x + x1] * FLT(!(x1_out || y0_out));
FLT4 s2 = src_loc[y0 * params.src_size.x + x2] * FLT(!(x2_out || y0_out));
r0 += TO_ACCUM4_TYPE(s0 * filters_loc[0]);
r0 += TO_ACCUM4_TYPE(s1 * filters_loc[1]);
r0 += TO_ACCUM4_TYPE(s2 * filters_loc[2]);
s0 = src_loc[y1 * params.src_size.x + x0] * FLT(!(x0_out || y1_out));
s1 = src_loc[y1 * params.src_size.x + x1] * FLT(!(x1_out || y1_out));
s2 = src_loc[y1 * params.src_size.x + x2] * FLT(!(x2_out || y1_out));
r0 += TO_ACCUM4_TYPE(s0 * filters_loc[3]);
r0 += TO_ACCUM4_TYPE(s1 * filters_loc[4]);
r0 += TO_ACCUM4_TYPE(s2 * filters_loc[5]);
s0 = src_loc[y2 * params.src_size.x + x0] * FLT(!(x0_out || y2_out));
s1 = src_loc[y2 * params.src_size.x + x1] * FLT(!(x1_out || y2_out));
s2 = src_loc[y2 * params.src_size.x + x2] * FLT(!(x2_out || y2_out));
r0 += TO_ACCUM4_TYPE(s0 * filters_loc[6]);
r0 += TO_ACCUM4_TYPE(s1 * filters_loc[7]);
r0 += TO_ACCUM4_TYPE(s2 * filters_loc[8]);
l0 += TO_ACCUM4_TYPE(s0 * filters_loc[0]);
l0 += TO_ACCUM4_TYPE(s1 * filters_loc[1]);
l0 += TO_ACCUM4_TYPE(s2 * filters_loc[2]);
s0 = src_loc[y3 * params.src_size.x + x0] * FLT(!(x0_out || y3_out));
s1 = src_loc[y3 * params.src_size.x + x1] * FLT(!(x1_out || y3_out));
s2 = src_loc[y3 * params.src_size.x + x2] * FLT(!(x2_out || y3_out));
l0 += TO_ACCUM4_TYPE(s0 * filters_loc[3]);
l0 += TO_ACCUM4_TYPE(s1 * filters_loc[4]);
l0 += TO_ACCUM4_TYPE(s2 * filters_loc[5]);
s0 = src_loc[y4 * params.src_size.x + x0] * FLT(!(x0_out || y4_out));
s1 = src_loc[y4 * params.src_size.x + x1] * FLT(!(x1_out || y4_out));
s2 = src_loc[y4 * params.src_size.x + x2] * FLT(!(x2_out || y4_out));
l0 += TO_ACCUM4_TYPE(s0 * filters_loc[6]);
l0 += TO_ACCUM4_TYPE(s1 * filters_loc[7]);
l0 += TO_ACCUM4_TYPE(s2 * filters_loc[8]);
r0 += TO_ACCUM4_TYPE(filters_loc[9]);
l0 += TO_ACCUM4_TYPE(filters_loc[9]);
const int offset_0 = gid_z * params.dst_size.z
+ gid_y * params.dst_size.x + gid_x;
const int offset_1 = offset_0 + params.dst_size.x;
bool y0_in = gid_y < params.dst_size.y;
bool y1_in = gid_y + 1 < params.dst_size.y;
if (y0_in) {
int linear_index = offset_0;
FLT4 value = FLT4(r0);
uint3 gid = uint3(gid_x, gid_y, gid_z);
$2
dst_buffer[linear_index] = value;
}
if (y1_in) {
int linear_index = offset_1;
FLT4 value = FLT4(l0);
uint3 gid = uint3(gid_x, gid_y + 1, gid_z);
$2
dst_buffer[linear_index] = value;
}
}
)";
  return code;
}
// Packs the 3x3 depthwise weights and the bias into the single `filters`
// buffer read by the DepthWiseConv3x3Stride2 shader.
//
// For every slice of four channels the output contains ten groups of four
// floats (ten FLT4 once converted):
//   * nine kernel taps in row-major order (y outer, x inner; tap = y * 3 + x),
//     which matches the row-by-row order in which the shader consumes them;
//   * the bias for the four channels of the slice.
// Lanes past the real channel count, or past the bias length, are zero-filled
// so the last slice is always complete.
std::vector<float> ReorderWeightsDepthWiseConv3x3Stride2(
    const DepthwiseConvolution2DAttributes& attr) {
  const int kernel_width = 3;
  const int kernel_height = 3;
  const int channel_count = attr.weights.shape.i;
  const int slice_count = IntegralDivideRoundUp(channel_count, 4);

  std::vector<float> packed;
  packed.reserve((kernel_width * kernel_height + 1) * slice_count * 4);

  for (int slice = 0; slice < slice_count; ++slice) {
    const int first_channel = slice * 4;
    for (int ky = 0; ky < kernel_height; ++ky) {
      for (int kx = 0; kx < kernel_width; ++kx) {
        for (int lane = 0; lane < 4; ++lane) {
          const int channel = first_channel + lane;
          packed.push_back(
              channel < channel_count
                  ? attr.weights.data[attr.weights.shape.LinearIndex(
                        {0, ky, kx, channel})]
                  : 0.0f);
        }
      }
    }
    for (int lane = 0; lane < 4; ++lane) {
      const int channel = first_channel + lane;
      packed.push_back(channel < attr.bias.shape.v ? attr.bias.data[channel]
                                                   : 0.0f);
    }
  }
  return packed;
}
static std::vector<uint8_t> GetUniformBufferDepthWiseConv3x3Stride2(
const BHWC& src_size, const BHWC& dst_size,
const DepthwiseConvolution2DAttributes& attr) {
std::vector<int> uniform_params = {
src_size.w,
src_size.h,
src_size.w * src_size.h,
IntegralDivideRoundUp(src_size.c, 4),
dst_size.w,
dst_size.h,
dst_size.w * dst_size.h,
IntegralDivideRoundUp(dst_size.c, 4),
-attr.padding.prepended.w,
-attr.padding.prepended.h,
attr.strides.w,
attr.strides.h,
attr.dilations.w,
attr.dilations.h,
0, // dummy, for alignment
0, // dummy, for alignment
};
return GetByteBuffer(uniform_params);
}
} // namespace
std::vector<ComputeTaskDescriptorPtr> DepthWiseConvolution(
int id, ValueId input_id, ValueId output_id,
const DepthwiseConvolution2DAttributes& attr,
const RuntimeOptions& options) {
int channels_multiplier = attr.weights.shape.o;
auto desc = std::make_shared<ComputeTaskDescriptor>();
desc->id = id;
desc->is_linkable = false;
std::string shader_source = R"(
#include <metal_stdlib>
using namespace metal;
constant int kernel_x = $0;
constant int kernel_y = $1;
struct uniforms {
int4 stride;
int4 padding;
int4 dilation;
int4 size;
int4 channel_multiplier;
};
$$0
kernel void ComputeFunction(
$$1
uint tid[[thread_index_in_threadgroup]],
uint3 gid[[thread_position_in_grid]]) {
const bool outside = static_cast<int>(gid.x) >= params.size.z ||
static_cast<int>(gid.y) >= params.size.w;
if (outside) {
return;
}
device FLT4* temp = filters + gid.z * kernel_y * kernel_x;
float4 sum0 = float4(0.0f, 0.0f, 0.0f, 0.0f);
for(int ky = 0; ky < kernel_y; ++ky) {
for(int kx = 0; kx < kernel_x; ++kx) {
int2 coords = int2(gid.xy) * params.stride.xy + int2(kx, ky) * params.dilation.xy -
params.padding.xy;
const bool outside = coords.x < 0 || coords.y < 0 ||
coords.x >= params.size.x || coords.y >= params.size.y;
if (outside) continue;
)";
if (channels_multiplier == 1) {
shader_source += R"(
const int src_layer = gid.z;
const int src_index = (src_layer * params.size.y + coords.y) * params.size.x + coords.x;
const FLT4 src_modified = src_buffer[src_index];
)";
} else if (channels_multiplier == 2) {
shader_source += R"(
const int src_layer = gid.z / 2;
const int src_index = (src_layer * params.size.y + coords.y) * params.size.x + coords.x;
const FLT4 src = src_buffer[src_index];
const FLT2 t0 = gid.z % 2 == 0 ? src.xy : src.zw;
const FLT4 src_modified = FLT4(t0.x, t0.x, t0.y, t0.y);
)";
} else if (channels_multiplier == 4) {
shader_source += R"(
const int src_layer = gid.z / 4;
const int src_index = (src_layer * params.size.y + coords.y) * params.size.x + coords.x;
const FLT4 src = src_buffer[src_index];
const FLT t0 = src[gid.z % 4];
const FLT4 src_modified = FLT4(t0, t0, t0, t0);
)";
} else {
shader_source += R"(
const int src_layer = gid.z / params.channel_multiplier.x;
const int src_index = (src_layer * params.size.y + coords.y) * params.size.x + coords.x;
const FLT4 src = src_buffer[src_index];
FLT4 src_modified;
const int src_layer_offset = (gid.z % params.channel_multiplier.x) * 4;
src_modified.x = src[(src_layer_offset + 0) / params.channel_multiplier.x];
src_modified.y = src[(src_layer_offset + 1) / params.channel_multiplier.x];
src_modified.z = src[(src_layer_offset + 2) / params.channel_multiplier.x];
src_modified.w = src[(src_layer_offset + 3) / params.channel_multiplier.x];
)";
}
shader_source += R"(
sum0 += float4(src_modified * temp[ky * kernel_x + kx]);
}
}
FLT4 res = FLT4(sum0 + float4(biases[gid.z]));
const int linear_index = (gid.z * params.size.w + int(gid.y)) * params.size.z + int(gid.x);
FLT4 value = res;
$$2
output_buffer[linear_index] = value;
}
)";
desc->shader_source = absl::Substitute(shader_source, attr.weights.shape.w,
attr.weights.shape.h);
desc->input_buffers = {
{input_id, "device FLT4* const src_buffer"},
};
desc->output_buffer = {
output_id, "device FLT4* output_buffer",
[input_id, attr](const std::map<ValueId, BHWC>& buffers) {
auto out_shape =
CalculateOutputShape(buffers.find(input_id)->second, attr);
return out_shape;
}};
const int output_channels_count = attr.weights.shape.i * attr.weights.shape.o;
desc->immutable_buffers = {
{"device FLT4* const filters",
GetByteBufferConverted(ConvertToPIOHW4(attr.weights),
options.storage_precision)},
{"device FLT4* const biases",
GetByteBufferConvertedResized(attr.bias.data, options.storage_precision,
output_channels_count)},
};
desc->uniform_buffers = {
{"constant uniforms& params",
[input_id, output_id, attr](const std::map<ValueId, BHWC>& buffers) {
const auto& dimension = buffers.find(input_id)->second;
const auto& output_dimension = buffers.find(output_id)->second;
std::vector<int> uniform_params{
attr.strides.w,
attr.strides.h,
1,
1,
attr.padding.prepended.w,
attr.padding.prepended.h,
1,
1,
attr.dilations.w,
attr.dilations.h,
1,
1,
dimension.w,
dimension.h,
output_dimension.w,
output_dimension.h,
attr.weights.shape.o,
0,
0,
0,
};
return GetByteBuffer(uniform_params);
}},
};
desc->resize_function = [output_id](const std::map<ValueId, BHWC>& buffers) {
const auto& dimension = buffers.find(output_id)->second;
uint3 groups_size{8, 4, 1};
uint3 groups_count{IntegralDivideRoundUp(dimension.w, groups_size.x),
IntegralDivideRoundUp(dimension.h, groups_size.y),
IntegralDivideRoundUp(dimension.c, 4)};
return std::make_pair(groups_size, groups_count);
};
return {desc};
}
std::vector<ComputeTaskDescriptorPtr> DepthWiseConv3x3Stride1x1(
int id, ValueId input_id, ValueId output_id,
const DepthwiseConvolution2DAttributes& attr,
const RuntimeOptions& options) {
auto desc = std::make_shared<ComputeTaskDescriptor>();
desc->id = id;
desc->is_linkable = false;
desc->shader_source = GetKernelDepthWiseConv3x3Stride1x1();
desc->input_buffers = {
{input_id, "device FLT4* const src_buffer"},
};
desc->output_buffer = {
output_id, "device FLT4* dst_buffer",
[input_id, attr](const std::map<ValueId, BHWC>& buffers) {
auto out_shape =
CalculateOutputShape(buffers.find(input_id)->second, attr);
return out_shape;
}};
// For this operation we keep weights and biases in one buffer
auto weights_reordered = ReorderWeightsDepthWiseConv3x3Stride1x1(attr);
desc->immutable_buffers = {
{"device FLT4* const filters",
GetByteBufferConverted(weights_reordered, options.storage_precision)},
};
desc->uniform_buffers = {
{"constant uniforms& params",
[input_id, output_id, attr](const std::map<ValueId, BHWC>& buffers) {
const auto& input_dimensions = buffers.find(input_id)->second;
const auto& output_dimensions = buffers.find(output_id)->second;
return GetUniformBufferDepthWiseConv3x3Stride1x1(
input_dimensions, output_dimensions, attr);
}},
};
desc->resize_function = [output_id](const std::map<ValueId, BHWC>& buffers) {
const auto& dimension = buffers.find(output_id)->second;
const int grid_x = IntegralDivideRoundUp(dimension.w, 2);
const int grid_y = IntegralDivideRoundUp(dimension.h, 2);
const int grid_z = IntegralDivideRoundUp(dimension.c, 4);
uint3 group_size{8, 4, 1};
if (grid_x <= 4) {
group_size.x = 4;
group_size.z = grid_z % 2 == 0 ? 2 : 1;
}
const int groups_x = IntegralDivideRoundUp(grid_x, group_size.x);
const int groups_y = IntegralDivideRoundUp(grid_y, group_size.y);
const int groups_z = IntegralDivideRoundUp(grid_z, group_size.z);
return std::make_pair(group_size, uint3(groups_x, groups_y, groups_z));
};
return {desc};
}
// Returns true when DepthWiseConv3x3Stride1x1 can handle `attr`. Its shader
// hard-codes a 3x3 window with unit stride and unit dilation on both axes,
// and supports only channel multiplier 1.
bool CheckDepthWiseConv3x3Stride1x1Support(
    const DepthwiseConvolution2DAttributes& attr) {
  const bool single_multiplier = attr.weights.shape.o == 1;
  const bool kernel_is_3x3 =
      attr.weights.shape.h == 3 && attr.weights.shape.w == 3;
  const bool unit_stride = attr.strides.h == 1 && attr.strides.w == 1;
  const bool unit_dilation = attr.dilations.h == 1 && attr.dilations.w == 1;
  return single_multiplier && kernel_is_3x3 && unit_stride && unit_dilation;
}
std::vector<ComputeTaskDescriptorPtr> DepthWiseConv3x3Stride2(
int id, ValueId input_id, ValueId output_id,
const DepthwiseConvolution2DAttributes& attr,
const RuntimeOptions& options) {
auto desc = std::make_shared<ComputeTaskDescriptor>();
desc->id = id;
desc->is_linkable = false;
desc->shader_source = GetKernelDepthWiseConv3x3Stride2();
desc->input_buffers = {
{input_id, "device FLT4* const src_buffer"},
};
desc->output_buffer = {
output_id, "device FLT4* dst_buffer",
[input_id, attr](const std::map<ValueId, BHWC>& buffers) {
auto out_shape =
CalculateOutputShape(buffers.find(input_id)->second, attr);
return out_shape;
}};
// For this operation we keep weights and biases in one buffer
auto weights_reordered = ReorderWeightsDepthWiseConv3x3Stride2(attr);
desc->immutable_buffers = {
{"device FLT4* const filters",
GetByteBufferConverted(weights_reordered, options.storage_precision)},
};
desc->uniform_buffers = {
{"constant uniforms& params",
[input_id, output_id, attr](const std::map<ValueId, BHWC>& buffers) {
const auto& input_dimensions = buffers.find(input_id)->second;
const auto& output_dimensions = buffers.find(output_id)->second;
return GetUniformBufferDepthWiseConv3x3Stride2(
input_dimensions, output_dimensions, attr);
}},
};
desc->resize_function = [output_id](const std::map<ValueId, BHWC>& buffers) {
const auto& dimension = buffers.find(output_id)->second;
const int grid_x = dimension.w;
const int grid_y = IntegralDivideRoundUp(dimension.h, 2);
const int grid_z = IntegralDivideRoundUp(dimension.c, 4);
const uint3 group_size{8, 4, 1};
const int groups_x = IntegralDivideRoundUp(grid_x, group_size.x);
const int groups_y = IntegralDivideRoundUp(grid_y, group_size.y);
const int groups_z = IntegralDivideRoundUp(grid_z, group_size.z);
return std::make_pair(group_size, uint3(groups_x, groups_y, groups_z));
};
return {desc};
}
// Returns true when DepthWiseConv3x3Stride2 can handle `attr`. Its shader
// hard-codes a 3x3 window, vertical stride 2 and vertical dilation 1, and
// supports only channel multiplier 1. The horizontal stride and dilation are
// read from uniforms, so they are not restricted here.
bool CheckDepthWiseConv3x3Stride2Support(
    const DepthwiseConvolution2DAttributes& attr) {
  const bool single_multiplier = attr.weights.shape.o == 1;
  const bool kernel_is_3x3 =
      attr.weights.shape.h == 3 && attr.weights.shape.w == 3;
  const bool vertical_geometry_fits =
      attr.strides.h == 2 && attr.dilations.h == 1;
  return single_multiplier && kernel_is_3x3 && vertical_geometry_fits;
}
} // namespace metal
} // namespace gpu
} // namespace tflite