Using Tensor objects in Mean kernel.
PiperOrigin-RevId: 350397926
Change-Id: I808cb6b5c7f150015126a3d02b70aef9c4f48efa
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
index 45cd523..459e35e 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
@@ -872,6 +872,7 @@
"elementwise_test.mm",
"fully_connected_test.mm",
"max_unpooling_test.mm",
+ "mean_test.mm",
"padding_test.mm",
"pooling_test.mm",
"prelu_test.mm",
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/mean.cc b/tensorflow/lite/delegates/gpu/metal/kernels/mean.cc
index 9757e32..bda0967 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/mean.cc
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/mean.cc
@@ -43,7 +43,6 @@
#include <metal_stdlib>
using namespace metal;
struct uniforms {
- int4 src_size;
float4 inv_multipliers;
};
@@ -57,18 +56,16 @@
int local_y = static_cast<int>(tid3d.y);
int local_id = static_cast<int>(tid);
int S = static_cast<int>(gid.z);
- if (S >= params.src_size.z) return;
+ if (S >= args.dst_tensor.Slices()) return;
)";
c += " threadgroup float4 accum[" +
std::to_string(work_group_size.x * work_group_size.y) + "];\n";
c += " accum[local_id] = float4(0.0f);\n";
- c += " int src_offset = S * params.src_size.x * params.src_size.y;\n";
- c += " for (int s_y = local_y; s_y < params.src_size.y; s_y += " + wg_y +
- ") {\n";
- c += " for (int s_x = local_x; s_x < params.src_size.x; s_x += " + wg_x +
- ") {\n";
- c += " int src_index = src_offset + s_y * params.src_size.x + s_x;\n";
- c += " accum[local_id] += float4(src_tensor[src_index]);\n";
+ c += " for (int s_y = local_y; s_y < args.src_tensor.Height(); s_y += " +
+ wg_y + ") {\n";
+ c += " for (int s_x = local_x; s_x < args.src_tensor.Width(); s_x += " +
+ wg_x + ") {\n";
+ c += " accum[local_id] += float4(args.src_tensor.Read(s_x, s_y, S));\n";
c += " }\n";
c += " }\n";
c += " accum[local_id] *= params.inv_multipliers.x;\n";
@@ -95,7 +92,7 @@
c += R"(
const int linear_index = static_cast<int>(gid.z);
$2
- dst_tensor[linear_index] = value;
+ args.dst_tensor.Write(value, 0, 0, gid.z);
}
)";
return c;
@@ -111,6 +108,7 @@
const int3 work_group_size = int3(16, 16, 1);
ComputeTaskDescriptor desc(definition);
+ desc.tensors_as_args = true;
std::string code = GetMeanCode(work_group_size);
desc.shader_source = code;
@@ -121,21 +119,15 @@
{"constant uniforms& params",
[work_group_size](const std::vector<BHWC>& src_shapes,
const std::vector<BHWC>& dst_shapes) {
- const auto& src_shape = src_shapes[0];
- const int src_slices = DivideRoundUp(src_shape.c, 4);
- struct uniforms {
- int4 src_size;
- float4 inv_multipliers;
- };
- uniforms params;
- params.src_size = {src_shape.w, src_shape.h, src_slices, 0};
- const double total_size = src_shape.w * src_shape.h;
+ float4 inv_multipliers;
+ const double total_size = src_shapes[0].w * src_shapes[0].h;
const double size_0 = work_group_size.x * work_group_size.y;
const double size_1 = total_size / size_0;
- params.inv_multipliers.x = 1.0 / size_1;
- params.inv_multipliers.y = 1.0 / size_0;
- const uint8_t* ptr = reinterpret_cast<const uint8_t*>(¶ms);
- return std::vector<uint8_t>(ptr, ptr + sizeof(uniforms));
+ inv_multipliers.x = 1.0 / size_1;
+ inv_multipliers.y = 1.0 / size_0;
+ const uint8_t* ptr =
+ reinterpret_cast<const uint8_t*>(&inv_multipliers);
+ return std::vector<uint8_t>(ptr, ptr + sizeof(float4));
}},
};