Removed unused ConvolutionTransposed3x3.
PiperOrigin-RevId: 303397293
Change-Id: I62b6dc3add95edf19e1cbf41c7282b23669d7b49
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.cc b/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.cc
index 9c3f91d..fd3f32f 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.cc
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.cc
@@ -275,618 +275,6 @@
src_local_size_x, src_local_size_y, workgroup_x, workgroup_y);
}
-struct GridParams {
- uint rect_offsets[4];
- uint widths[4];
- short2 origins[4];
- uint elements_count;
-};
-
-struct Params3x3 {
- short2 inner_size;
- short2 src_offset;
- short2 dst_offset;
-};
-
-void Init3x3(const ConvolutionTransposedAttributes& attr, const int2& src_size,
- const int2& dst_size, GridParams* grid_params,
- Params3x3* params3x3) {
- short2 src_size_scaled;
- src_size_scaled.x = (src_size.x - 1) * 2;
- src_size_scaled.y = (src_size.y - 1) * 2;
- short2 top_left_src, bottom_right_src;
- top_left_src.x = 1 - attr.padding.prepended.w;
- top_left_src.y = 1 - attr.padding.prepended.h;
- bottom_right_src.x = top_left_src.x + src_size_scaled.x;
- bottom_right_src.y = top_left_src.y + src_size_scaled.y;
- short2 top_left_inner, bottom_right_inner;
- if (top_left_src.x >= 0) {
- top_left_inner.x = top_left_src.x;
- } else {
- top_left_inner.x = std::abs(top_left_src.x % 2);
- }
- if (top_left_src.y >= 0) {
- top_left_inner.y = top_left_src.y;
- } else {
- top_left_inner.y = std::abs(top_left_src.y % 2);
- }
-
- if (bottom_right_src.x <= dst_size.x) {
- bottom_right_inner.x = bottom_right_src.x;
- } else {
- bottom_right_inner.x = dst_size.x;
- }
- if (top_left_src.x % 2 == 0) {
- bottom_right_inner.x -= bottom_right_inner.x % 2;
- } else {
- if (bottom_right_inner.x % 2 == 0) {
- bottom_right_inner.x -= 1;
- }
- }
- bottom_right_inner.x -= 1;
-
- if (bottom_right_src.y <= dst_size.y) {
- bottom_right_inner.y = bottom_right_src.y;
- } else {
- bottom_right_inner.y = dst_size.y;
- }
- if (top_left_src.y % 2 == 0) {
- bottom_right_inner.y -= bottom_right_inner.y % 2;
- } else {
- if (bottom_right_inner.y % 2 == 0) {
- bottom_right_inner.y -= 1;
- }
- }
- bottom_right_inner.y -= 1;
-
- params3x3->dst_offset = top_left_inner;
- params3x3->src_offset.x = (top_left_inner.x - top_left_src.x) / 2;
- params3x3->src_offset.y = (top_left_inner.y - top_left_src.y) / 2;
- params3x3->inner_size.x =
- std::max(0, bottom_right_inner.x - top_left_inner.x + 1) / 2;
- params3x3->inner_size.y =
- std::max(0, bottom_right_inner.y - top_left_inner.y + 1) / 2;
-
- short2 top_rect, bottom_rect, left_rect, right_rect;
-
- top_rect.x = dst_size.x;
- top_rect.y = top_left_inner.y;
-
- bottom_rect.x = dst_size.x;
- bottom_rect.y = dst_size.y - bottom_right_inner.y - 1;
-
- left_rect.x = top_left_inner.x;
- left_rect.y = dst_size.y - top_rect.y - bottom_rect.y;
-
- right_rect.x = dst_size.x - bottom_right_inner.x - 1;
- right_rect.y = left_rect.y;
-
- grid_params->widths[0] = top_rect.x;
- grid_params->widths[1] = left_rect.x;
- grid_params->widths[2] = right_rect.x;
- grid_params->widths[3] = bottom_rect.x;
-
- grid_params->rect_offsets[0] = 0;
- grid_params->rect_offsets[1] =
- grid_params->rect_offsets[0] + top_rect.x * top_rect.y;
- grid_params->rect_offsets[2] =
- grid_params->rect_offsets[1] + left_rect.x * left_rect.y;
- grid_params->rect_offsets[3] =
- grid_params->rect_offsets[2] + right_rect.x * right_rect.y;
- grid_params->elements_count =
- grid_params->rect_offsets[3] + bottom_rect.x * bottom_rect.y;
-
- grid_params->origins[0] = short2(0, 0);
- grid_params->origins[1] = short2(int16_t(0), int16_t(top_rect.y));
- grid_params->origins[2] =
- short2(int16_t(dst_size.x - right_rect.x), int16_t(top_rect.y));
- grid_params->origins[3] = short2(0, dst_size.y - bottom_rect.y);
-}
-
-std::string GetDeconvolutionBorder(
- const ConvolutionTransposedAttributes& attr) {
- std::string constant_args = R"(
- constant short2 padding = {$0, $1};
- constant short2 stride = {$2, $3};
- constant short2 kernel_size = {$4, $5};
- constant short2 inner_size = {$6, $7};
- constant short2 kernel_offset = {$8, $9};
- )";
- std::string shader_source = R"(
- #include <metal_stdlib>
- using namespace metal;
-
- struct FilterStripe {
- FLT4 vals[$0];
- };
-
- constant int src_depth = $1;
- constant int dst_depth = $2;
- constant int dst_channels = $3;
- constant int dst_channels_aligned = $4;
-
- $5
-
- struct uniforms {
- int2 src_size;
- int2 dst_size;
- uint rect_offsets[4];
- uint widths[4];
- short2 origins[4];
- uint elements_count;
- };
-
- short2 GetGridIdByLinearId(uint linear_id, constant uniforms& params);
-
- short2 GetGridIdByLinearId(uint linear_id, constant uniforms& params) {
- int index = 0;
- index = linear_id >= params.rect_offsets[0] ? 0 : index;
- index = linear_id >= params.rect_offsets[1] ? 1 : index;
- index = linear_id >= params.rect_offsets[2] ? 2 : index;
- index = linear_id >= params.rect_offsets[3] ? 3 : index;
-
- const uint rect_index = linear_id - params.rect_offsets[index];
-
- const uint rect_width = params.widths[index];
- const short2 offset = short2(rect_index % rect_width, rect_index / rect_width);
- return params.origins[index] + offset;
- }
-
- $$0
- kernel void ComputeFunction(
- $$1
- uint linear_id[[thread_position_in_grid]]) {
- if (linear_id >= params.elements_count) {
- return;
- }
- short2 gid_sh = GetGridIdByLinearId(linear_id, params);
-
- float out[$4];
- for (short l = 0; l < dst_depth * 4; ++l) {
- out[l] = float(0.0f);
- }
-
- short2 offset = gid_sh + padding - kernel_offset;
- offset.x = offset.x % stride.x;
- offset.y = offset.y % stride.y;
- offset += stride;
- offset.x = offset.x % stride.x;
- offset.y = offset.y % stride.y;
- short2 f_offset;
- f_offset.x = offset.x == 0 ? 0 : stride.x - offset.x;
- f_offset.y = offset.y == 0 ? 0 : stride.y - offset.y;
- for (int ky = 0; ky < inner_size.y; ++ky) {
- for (int kx = 0; kx < inner_size.x; ++kx) {
- short2 index = short2(kx, ky) * stride + f_offset;
- bool inside_kernel = index.x < kernel_size.x && index.y < kernel_size.y;
- const short2 src_coord = (gid_sh + index + padding - kernel_offset) / stride;
- index = kernel_size - short2(1, 1) - index;
- bool outside = src_coord.x < 0 || src_coord.y < 0 ||
- src_coord.x >= params.src_size.x || src_coord.y >= params.src_size.y;
- const int kernel_index = index.y * kernel_size.x + index.x;
- bool belong = inside_kernel && !outside;
- if (belong) {
- for (int l = 0; l < src_depth; ++l) {
- const int src_index = (l * params.src_size.y + src_coord.y) *
- params.src_size.x + src_coord.x;
- FLT4 srcColor = src_buffer[src_index];
- for (int k = 0; k < dst_channels; ++k) {
- out[k] += dot(srcColor, filters[kernel_index].vals[l * dst_channels_aligned + k]);
- }
- }
- }
- }
- }
-
- for (short l = 0; l < dst_depth; ++l) {
- FLT4 value = FLT4(out[l * 4], out[l * 4 + 1], out[l * 4 + 2], out[l * 4 + 3]) + biases[l];
- const int linear_index = (l * params.dst_size.y + int(gid_sh.y)) *
- params.dst_size.x + int(gid_sh.x);
- uint3 gid = uint3(uint(gid_sh.x), uint(gid_sh.y), uint(l));
- $$2
- dst_buffer[linear_index] = value;
- }
- }
- )";
- const int kernel_x = attr.weights.shape.w;
- const int kernel_y = attr.weights.shape.h;
- const int inner_size_x = (kernel_x - 1) / attr.stride.w + 1;
- const int inner_size_y = (kernel_y - 1) / attr.stride.h + 1;
- std::string constant_args_inplaced = absl::Substitute(
- constant_args, attr.padding.prepended.w, attr.padding.prepended.h,
- attr.stride.w, attr.stride.h, kernel_x, kernel_y, inner_size_x,
- inner_size_y, kernel_x - 1, kernel_y - 1);
- const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4);
- const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4);
- const int dst_channels_aligned = AlignByN(attr.weights.shape.o, 4);
- return absl::Substitute(shader_source, src_depth * dst_channels_aligned,
- src_depth, dst_depth, attr.weights.shape.o,
- dst_channels_aligned, constant_args_inplaced);
-}
-
-std::string GetDeconvolution3x3(const ConvolutionTransposedAttributes& attr) {
- std::string shader_source = R"(
- #include <metal_stdlib>
- using namespace metal;
-
- struct FilterStripe {
- FLT4 vals[$0];
- };
-
- constant int src_depth = $1;
- constant int dst_depth = $2;
- constant int dst_channels = $3;
- constant int dst_channels_aligned = $4;
-
- struct uniforms {
- int2 src_size;
- int2 dst_size;
- short2 inner_size;
- short2 src_offset;
- short2 dst_offset;
- };
-
- $$0
- kernel void ComputeFunction(
- $$1
- uint tid[[thread_index_in_threadgroup]],
- uint2 ugid[[thread_position_in_grid]]) {
- if (static_cast<int>(ugid.x) >= params.inner_size.x ||
- static_cast<int>(ugid.y) >= params.inner_size.y) {
- return;
- }
-
- float out[$4];
- short2 src_coord_0 = short2(ugid) + params.src_offset;
- short2 dst_coord = short2(ugid) * 2 + params.dst_offset;
-
- for (short l = 0; l < dst_depth * 4; ++l) {
- out[l] = float(0.0f);
- }
-
- for (int l = 0; l < src_depth; ++l) {
- const int src_index_0 = (l * params.src_size.y + src_coord_0.y) *
- params.src_size.x + src_coord_0.x;
- FLT4 srcColor_0 = src_buffer[src_index_0];
- for (int k = 0; k < dst_channels; ++k) {
- out[k] += dot(srcColor_0, filters[4].vals[l * dst_channels_aligned + k]);
- }
- }
-
- for (short l = 0; l < dst_depth; ++l) {
- FLT4 value = FLT4(out[l * 4], out[l * 4 + 1], out[l * 4 + 2], out[l * 4 + 3]) + biases[l];
- const int linear_index = (l * params.dst_size.y + int(dst_coord.y)) *
- params.dst_size.x + int(dst_coord.x);
- uint3 gid = uint3(uint(dst_coord.x), uint(dst_coord.y), uint(l));
- $$2
- dst_buffer[linear_index] = value;
- }
-
- short2 src_coord_1 = src_coord_0 + short2(1, 0);
- dst_coord += short2(1, 0);
-
- for (short l = 0; l < dst_depth * 4; ++l) {
- out[l] = float(0.0f);
- }
-
- for (int l = 0; l < src_depth; ++l) {
- const int src_index_0 = (l * params.src_size.y + src_coord_0.y) *
- params.src_size.x + src_coord_0.x;
- const int src_index_1 = (l * params.src_size.y + src_coord_1.y) *
- params.src_size.x + src_coord_1.x;
- FLT4 srcColor_0 = src_buffer[src_index_0];
- FLT4 srcColor_1 = src_buffer[src_index_1];
- for (int k = 0; k < dst_channels; ++k) {
- out[k] += dot(srcColor_0, filters[5].vals[l * dst_channels_aligned + k]);
- out[k] += dot(srcColor_1, filters[3].vals[l * dst_channels_aligned + k]);
- }
- }
-
- for (short l = 0; l < dst_depth; ++l) {
- FLT4 value = FLT4(out[l * 4], out[l * 4 + 1], out[l * 4 + 2], out[l * 4 + 3]) + biases[l];
- const int linear_index = (l * params.dst_size.y + int(dst_coord.y)) *
- params.dst_size.x + int(dst_coord.x);
- uint3 gid = uint3(uint(dst_coord.x), uint(dst_coord.y), uint(l));
- $$2
- dst_buffer[linear_index] = value;
- }
-
- short2 src_coord_2 = src_coord_0 + short2(0, 1);
- dst_coord += short2(-1, 1);
-
- for (short l = 0; l < dst_depth * 4; ++l) {
- out[l] = float(0.0f);
- }
-
- for (int l = 0; l < src_depth; ++l) {
- const int src_index_0 = (l * params.src_size.y + src_coord_0.y) *
- params.src_size.x + src_coord_0.x;
- const int src_index_2 = (l * params.src_size.y + src_coord_2.y) *
- params.src_size.x + src_coord_2.x;
- FLT4 srcColor_0 = src_buffer[src_index_0];
- FLT4 srcColor_2 = src_buffer[src_index_2];
- for (int k = 0; k < dst_channels; ++k) {
- out[k] += dot(srcColor_0, filters[7].vals[l * dst_channels_aligned + k]);
- out[k] += dot(srcColor_2, filters[1].vals[l * dst_channels_aligned + k]);
- }
- }
-
- for (short l = 0; l < dst_depth; ++l) {
- FLT4 value = FLT4(out[l * 4], out[l * 4 + 1], out[l * 4 + 2], out[l * 4 + 3]) + biases[l];
- const int linear_index = (l * params.dst_size.y + int(dst_coord.y)) *
- params.dst_size.x + int(dst_coord.x);
- uint3 gid = uint3(uint(dst_coord.x), uint(dst_coord.y), uint(l));
- $$2
- dst_buffer[linear_index] = value;
- }
-
- short2 src_coord_3 = src_coord_0 + short2(1, 1);
- dst_coord += short2(1, 0);
-
- for (short l = 0; l < dst_depth * 4; ++l) {
- out[l] = float(0.0f);
- }
-
- for (int l = 0; l < src_depth; ++l) {
- const int src_index_0 = (l * params.src_size.y + src_coord_0.y) *
- params.src_size.x + src_coord_0.x;
- const int src_index_1 = (l * params.src_size.y + src_coord_1.y) *
- params.src_size.x + src_coord_1.x;
- const int src_index_2 = (l * params.src_size.y + src_coord_2.y) *
- params.src_size.x + src_coord_2.x;
- const int src_index_3 = (l * params.src_size.y + src_coord_3.y) *
- params.src_size.x + src_coord_3.x;
- FLT4 srcColor_0 = src_buffer[src_index_0];
- FLT4 srcColor_1 = src_buffer[src_index_1];
- FLT4 srcColor_2 = src_buffer[src_index_2];
- FLT4 srcColor_3 = src_buffer[src_index_3];
- for (int k = 0; k < dst_channels; ++k) {
- out[k] += dot(srcColor_0, filters[8].vals[l * dst_channels_aligned + k]);
- out[k] += dot(srcColor_1, filters[6].vals[l * dst_channels_aligned + k]);
- out[k] += dot(srcColor_2, filters[2].vals[l * dst_channels_aligned + k]);
- out[k] += dot(srcColor_3, filters[0].vals[l * dst_channels_aligned + k]);
- }
- }
-
- for (short l = 0; l < dst_depth; ++l) {
- FLT4 value = FLT4(out[l * 4], out[l * 4 + 1], out[l * 4 + 2], out[l * 4 + 3]) + biases[l];
- const int linear_index = (l * params.dst_size.y + int(dst_coord.y)) *
- params.dst_size.x + int(dst_coord.x);
- uint3 gid = uint3(uint(dst_coord.x), uint(dst_coord.y), uint(l));
- $$2
- dst_buffer[linear_index] = value;
- }
- }
- )";
-
- const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4);
- const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4);
- const int dst_channels_aligned = AlignByN(attr.weights.shape.o, 4);
- return absl::Substitute(shader_source, src_depth * dst_channels_aligned,
- src_depth, dst_depth, attr.weights.shape.o,
- dst_channels_aligned);
-}
-
-std::string GetDeconvolutionShared3x3(
- const ConvolutionTransposedAttributes& attr) {
- std::string shader_source = R"(
- #include <metal_stdlib>
- using namespace metal;
-
- struct FilterStripe {
- FLT4 vals[$0];
- };
-
- constant int src_depth = $1;
- constant int dst_depth = $2;
- constant int dst_channels = $3;
- constant int dst_channels_aligned = $4;
-
- struct uniforms {
- int2 src_size;
- int2 dst_size;
- short2 inner_size;
- short2 src_offset;
- short2 dst_offset;
- };
-
- $$0
- kernel void ComputeFunction(
- $$1
- uint tid[[thread_index_in_threadgroup]],
- uint2 ugid[[thread_position_in_grid]]) {
-
- float out[$4];
- for (short l = 0; l < dst_depth * 4; ++l) {
- out[l] = float(0.0f);
- }
-
- threadgroup FilterStripe stripes[4];
- threadgroup_barrier(mem_flags::mem_none);
- if (tid < dst_channels) {
- for (int l = 0; l < src_depth; ++l) {
- stripes[0].vals[l * dst_channels_aligned + tid]
- = filters[4].vals[l * dst_channels_aligned + tid];
- }
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- bool inside_grid = (static_cast<int>(ugid.x) < params.inner_size.x)
- && (static_cast<int>(ugid.y) < params.inner_size.y);
-
- short2 src_coord_0 = short2(ugid) + params.src_offset;
- short2 dst_coord = short2(ugid) * 2 + params.dst_offset;
-
- if (inside_grid) {
- for (short l = 0; l < dst_depth * 4; ++l) {
- out[l] = float(0.0f);
- }
-
- for (int l = 0; l < src_depth; ++l) {
- const int src_index_0 = (l * params.src_size.y + src_coord_0.y) *
- params.src_size.x + src_coord_0.x;
- FLT4 srcColor_0 = src_buffer[src_index_0];
- for (int k = 0; k < dst_channels; ++k) {
- out[k] += dot(srcColor_0, stripes[0].vals[l * dst_channels_aligned + k]);
- }
- }
-
- for (short l = 0; l < dst_depth; ++l) {
- FLT4 value = FLT4(out[l * 4], out[l * 4 + 1], out[l * 4 + 2], out[l * 4 + 3]) + biases[l];
- const int linear_index = (l * params.dst_size.y + int(dst_coord.y)) *
- params.dst_size.x + int(dst_coord.x);
- uint3 gid = uint3(ugid.x, ugid.y, uint(l));
- $$2
- dst_buffer[linear_index] = value;
- }
- }
-
- short2 src_coord_1 = src_coord_0 + short2(1, 0);
- dst_coord += short2(1, 0);
-
- threadgroup_barrier(mem_flags::mem_none);
- if (tid < dst_channels) {
- for (int l = 0; l < src_depth; ++l) {
- stripes[0].vals[l * dst_channels_aligned + tid]
- = filters[5].vals[l * dst_channels_aligned + tid];
- stripes[1].vals[l * dst_channels_aligned + tid]
- = filters[3].vals[l * dst_channels_aligned + tid];
- }
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
- if (inside_grid) {
- for (short l = 0; l < dst_depth * 4; ++l) {
- out[l] = float(0.0f);
- }
-
- for (int l = 0; l < src_depth; ++l) {
- const int src_index_0 = (l * params.src_size.y + src_coord_0.y) *
- params.src_size.x + src_coord_0.x;
- const int src_index_1 = (l * params.src_size.y + src_coord_1.y) *
- params.src_size.x + src_coord_1.x;
- FLT4 srcColor_0 = src_buffer[src_index_0];
- FLT4 srcColor_1 = src_buffer[src_index_1];
- for (int k = 0; k < dst_channels; ++k) {
- out[k] += dot(srcColor_0, stripes[0].vals[l * dst_channels_aligned + k]);
- out[k] += dot(srcColor_1, stripes[1].vals[l * dst_channels_aligned + k]);
- }
- }
-
- for (short l = 0; l < dst_depth; ++l) {
- FLT4 value = FLT4(out[l * 4], out[l * 4 + 1], out[l * 4 + 2], out[l * 4 + 3]) + biases[l];
- const int linear_index = (l * params.dst_size.y + int(dst_coord.y)) *
- params.dst_size.x + int(dst_coord.x);
- uint3 gid = uint3(ugid.x, ugid.y, uint(l));
- $$2
- dst_buffer[linear_index] = value;
- }
- }
-
- short2 src_coord_2 = src_coord_0 + short2(0, 1);
- dst_coord += short2(-1, 1);
-
- threadgroup_barrier(mem_flags::mem_none);
- if (tid < dst_channels) {
- for (int l = 0; l < src_depth; ++l) {
- stripes[0].vals[l * dst_channels_aligned + tid]
- = filters[7].vals[l * dst_channels_aligned + tid];
- stripes[1].vals[l * dst_channels_aligned + tid]
- = filters[1].vals[l * dst_channels_aligned + tid];
- }
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
- if (inside_grid) {
- for (short l = 0; l < dst_depth * 4; ++l) {
- out[l] = float(0.0f);
- }
-
- for (int l = 0; l < src_depth; ++l) {
- const int src_index_0 = (l * params.src_size.y + src_coord_0.y) *
- params.src_size.x + src_coord_0.x;
- const int src_index_2 = (l * params.src_size.y + src_coord_2.y) *
- params.src_size.x + src_coord_2.x;
- FLT4 srcColor_0 = src_buffer[src_index_0];
- FLT4 srcColor_2 = src_buffer[src_index_2];
- for (int k = 0; k < dst_channels; ++k) {
- out[k] += dot(srcColor_0, stripes[0].vals[l * dst_channels_aligned + k]);
- out[k] += dot(srcColor_2, stripes[1].vals[l * dst_channels_aligned + k]);
- }
- }
-
- for (short l = 0; l < dst_depth; ++l) {
- FLT4 value = FLT4(out[l * 4], out[l * 4 + 1], out[l * 4 + 2], out[l * 4 + 3]) + biases[l];
- const int linear_index = (l * params.dst_size.y + int(dst_coord.y)) *
- params.dst_size.x + int(dst_coord.x);
- uint3 gid = uint3(ugid.x, ugid.y, uint(l));
- $$2
- dst_buffer[linear_index] = value;
- }
- }
-
- short2 src_coord_3 = src_coord_0 + short2(1, 1);
- dst_coord += short2(1, 0);
-
- threadgroup_barrier(mem_flags::mem_none);
- if (tid < dst_channels) {
- for (int l = 0; l < src_depth; ++l) {
- stripes[0].vals[l * dst_channels_aligned + tid]
- = filters[8].vals[l * dst_channels_aligned + tid];
- stripes[1].vals[l * dst_channels_aligned + tid]
- = filters[6].vals[l * dst_channels_aligned + tid];
- stripes[2].vals[l * dst_channels_aligned + tid]
- = filters[2].vals[l * dst_channels_aligned + tid];
- stripes[3].vals[l * dst_channels_aligned + tid]
- = filters[0].vals[l * dst_channels_aligned + tid];
- }
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
- if (inside_grid) {
- for (short l = 0; l < dst_depth * 4; ++l) {
- out[l] = float(0.0f);
- }
-
- for (int l = 0; l < src_depth; ++l) {
- const int src_index_0 = (l * params.src_size.y + src_coord_0.y) *
- params.src_size.x + src_coord_0.x;
- const int src_index_1 = (l * params.src_size.y + src_coord_1.y) *
- params.src_size.x + src_coord_1.x;
- const int src_index_2 = (l * params.src_size.y + src_coord_2.y) *
- params.src_size.x + src_coord_2.x;
- const int src_index_3 = (l * params.src_size.y + src_coord_3.y) *
- params.src_size.x + src_coord_3.x;
- FLT4 srcColor_0 = src_buffer[src_index_0];
- FLT4 srcColor_1 = src_buffer[src_index_1];
- FLT4 srcColor_2 = src_buffer[src_index_2];
- FLT4 srcColor_3 = src_buffer[src_index_3];
- for (int k = 0; k < dst_channels; ++k) {
- out[k] += dot(srcColor_0, stripes[0].vals[l * dst_channels_aligned + k]);
- out[k] += dot(srcColor_1, stripes[1].vals[l * dst_channels_aligned + k]);
- out[k] += dot(srcColor_2, stripes[2].vals[l * dst_channels_aligned + k]);
- out[k] += dot(srcColor_3, stripes[3].vals[l * dst_channels_aligned + k]);
- }
- }
-
- for (short l = 0; l < dst_depth; ++l) {
- FLT4 value = FLT4(out[l * 4], out[l * 4 + 1], out[l * 4 + 2], out[l * 4 + 3]) + biases[l];
- const int linear_index = (l * params.dst_size.y + int(dst_coord.y)) *
- params.dst_size.x + int(dst_coord.x);
- uint3 gid = uint3(ugid.x, ugid.y, uint(l));
- $$2
- dst_buffer[linear_index] = value;
- }
- }
- }
- )";
- const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4);
- const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4);
- const int dst_channels_aligned = AlignByN(attr.weights.shape.o, 4);
- return absl::Substitute(shader_source, src_depth * dst_channels_aligned,
- src_depth, dst_depth, attr.weights.shape.o,
- dst_channels_aligned);
-}
-
std::string GetDeconvolution4x4(const int2& block_size, bool use_local_mem) {
std::string c = R"(
#include <metal_stdlib>
@@ -1152,200 +540,6 @@
return {desc};
}
-std::vector<ComputeTaskDescriptorPtr> ConvolutionTransposed3x3(
- int id, ValueId input_id, ValueId output_id,
- const ConvolutionTransposedAttributes& params,
- const RuntimeOptions& options) {
- const int kThreadGroupWidth = 16;
- const int kThreadGroupHeight = 4;
-
- auto border_desc = std::make_shared<ComputeTaskDescriptor>();
- border_desc->id = id;
- border_desc->is_linkable = false;
-
- border_desc->shader_source = GetDeconvolutionBorder(params);
-
- border_desc->input_buffers = {
- {input_id, "device FLT4* const src_buffer"},
- };
-
- border_desc->output_buffer = {
- output_id, "device FLT4* dst_buffer",
- [input_id, params](const std::map<ValueId, BHWC>& buffers) {
- const auto& src_shape = buffers.find(input_id)->second;
- BHWC dst_shape = CalculateOutputShape(src_shape, params);
- return BHWC{src_shape.b, dst_shape.h, dst_shape.w, dst_shape.c};
- }};
-
- const int src_depth = IntegralDivideRoundUp(params.weights.shape.i, 4);
- const int src_ch_aligned = AlignByN(params.weights.shape.i, 4);
- const int dst_ch_aligned = AlignByN(params.weights.shape.o, 4);
- const int kernel_x = params.weights.shape.w;
- const int kernel_y = params.weights.shape.h;
- const int filters_aligned_size =
- src_ch_aligned * dst_ch_aligned * kernel_x * kernel_y;
- std::vector<float> filters_reordered(filters_aligned_size);
-
- int counter = 0;
- for (int y = 0; y < kernel_y; ++y) {
- for (int x = 0; x < kernel_x; ++x) {
- for (int ch = 0; ch < src_depth; ++ch) {
- for (int f = 0; f < dst_ch_aligned; ++f) {
- for (int i = 0; i < 4; ++i) {
- if (ch * 4 + i >= params.weights.shape.i ||
- f >= params.weights.shape.o) {
- filters_reordered[counter++] = 0.0f;
- } else {
- const int f_index =
- params.weights.shape.LinearIndex({f, y, x, ch * 4 + i});
- filters_reordered[counter++] = params.weights.data[f_index];
- }
- }
- }
- }
- }
- }
-
- auto filters =
- GetByteBufferConverted(filters_reordered, options.storage_precision);
- auto biases = GetByteBufferConvertedResized(
- params.bias.data, options.storage_precision, params.weights.shape.o);
- border_desc->immutable_buffers = {
- {"device FilterStripe* const filters", filters},
- {"constant FLT4* const biases", biases},
- };
-
- border_desc->uniform_buffers = {
- {"constant uniforms& params",
- [input_id, output_id, params](const std::map<ValueId, BHWC>& buffers) {
- const auto& src_dim = buffers.find(input_id)->second;
- const auto& dst_dim = buffers.find(output_id)->second;
- GridParams grid_params;
- Params3x3 params3x3;
- Init3x3(params, int2(src_dim.w, src_dim.h), int2(dst_dim.w, dst_dim.h),
- &grid_params, ¶ms3x3);
- int* ptr = reinterpret_cast<int*>(&grid_params);
- std::vector<int> uniform_params{
- src_dim.w,
- src_dim.h,
- dst_dim.w,
- dst_dim.h,
- /*uint GridParams.rect_offsets[4]*/
- ptr[0],
- ptr[1],
- ptr[2],
- ptr[3],
- /*uint GridParams.widths[4]*/
- ptr[4],
- ptr[5],
- ptr[6],
- ptr[7],
- /*short2 GridParams.origins[4]*/
- ptr[8],
- ptr[9],
- ptr[10],
- ptr[11],
- /*uint GridParams.elements_count*/
- ptr[12],
- };
- return GetByteBuffer(uniform_params);
- }},
- };
-
- border_desc->resize_function =
- [input_id, params](const std::map<ValueId, BHWC>& buffers) {
- const uint3 groups_size{kThreadGroupWidth * kThreadGroupHeight, 1, 1};
- const auto& src_shape = buffers.find(input_id)->second;
- BHWC dst_shape = CalculateOutputShape(src_shape, params);
- GridParams grid_params;
- Params3x3 params3x3;
- Init3x3(params, int2(src_shape.w, src_shape.h),
- int2(dst_shape.w, dst_shape.h), &grid_params, ¶ms3x3);
- if (grid_params.elements_count == 0) {
- return std::make_pair(groups_size, uint3{0, 0, 0});
- }
- int groups_x =
- IntegralDivideRoundUp(grid_params.elements_count, groups_size.x);
- int groups_y = 1;
- int groups_z = 1;
- return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z});
- };
-
- auto desc = std::make_shared<ComputeTaskDescriptor>();
- desc->id = id;
- desc->is_linkable = false;
-
- const int shared_size = sizeof(float) * 4 * src_depth * dst_ch_aligned * 4;
- auto gpu_type = GetGpuType();
- if (shared_size < (1024 * 16 - 32) &&
- (gpu_type == GpuType::kA7 || gpu_type == GpuType::kA8) &&
- dst_ch_aligned <= kThreadGroupWidth * kThreadGroupHeight) {
- desc->shader_source = GetDeconvolutionShared3x3(params);
- } else {
- desc->shader_source = GetDeconvolution3x3(params);
- }
-
- desc->input_buffers = {
- {input_id, "device FLT4* const src_buffer"},
- };
-
- desc->output_buffer = {
- output_id, "device FLT4* dst_buffer",
- [input_id, params](const std::map<ValueId, BHWC>& buffers) {
- const auto& src_shape = buffers.find(input_id)->second;
- BHWC dst_shape = CalculateOutputShape(src_shape, params);
- return BHWC{src_shape.b, dst_shape.h, dst_shape.w, dst_shape.c};
- }};
-
- desc->immutable_buffers = {
- {"device FilterStripe* const filters", filters},
- {"constant FLT4* const biases", biases},
- };
-
- desc->uniform_buffers = {
- {"constant uniforms& params",
- [input_id, output_id, params](const std::map<ValueId, BHWC>& buffers) {
- const auto& src_shape = buffers.find(input_id)->second;
- const auto& dst_shape = buffers.find(output_id)->second;
- GridParams grid_params;
- Params3x3 params3x3;
- Init3x3(params, int2(src_shape.w, src_shape.h),
- int2(dst_shape.w, dst_shape.h), &grid_params, ¶ms3x3);
- int* ptr = reinterpret_cast<int*>(¶ms3x3);
- std::vector<int> uniform_params{
- src_shape.w,
- src_shape.h,
- dst_shape.w,
- dst_shape.h,
- /*short2 Params3x3.inner_size*/ ptr[0],
- /*short2 Params3x3.src_offset*/ ptr[1],
- /*short2 Params3x3.dst_offset*/ ptr[2],
- };
- return GetByteBuffer(uniform_params);
- }},
- };
-
- desc->resize_function = [input_id,
- params](const std::map<ValueId, BHWC>& buffers) {
- const uint3 groups_size{kThreadGroupWidth, kThreadGroupHeight, 1};
- const auto& src_shape = buffers.find(input_id)->second;
- BHWC dst_shape = CalculateOutputShape(src_shape, params);
- GridParams grid_params;
- Params3x3 params3x3;
- Init3x3(params, int2(src_shape.w, src_shape.h),
- int2(dst_shape.w, dst_shape.h), &grid_params, ¶ms3x3);
- if (params3x3.inner_size.x * params3x3.inner_size.y == 0) {
- return std::make_pair(groups_size, uint3{0, 0, 0});
- }
- int groups_x = IntegralDivideRoundUp(params3x3.inner_size.x, groups_size.x);
- int groups_y = IntegralDivideRoundUp(params3x3.inner_size.y, groups_size.y);
- int groups_z = 1;
- return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z});
- };
-
- return {border_desc, desc};
-}
-
std::vector<ComputeTaskDescriptorPtr> ConvolutionTransposed4x4(
int id, ValueId input_id, ValueId output_id,
const ConvolutionTransposedAttributes& params,
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.h b/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.h
index cffab3c..0fa19db 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.h
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.h
@@ -32,11 +32,6 @@
const ConvolutionTransposedAttributes& params,
const RuntimeOptions& options);
-std::vector<ComputeTaskDescriptorPtr> ConvolutionTransposed3x3(
- int id, ValueId input_id, ValueId output_id,
- const ConvolutionTransposedAttributes& params,
- const RuntimeOptions& options);
-
std::vector<ComputeTaskDescriptorPtr> ConvolutionTransposed4x4(
int id, ValueId input_id, ValueId output_id,
const ConvolutionTransposedAttributes& params,