Using Tensor objects in Softmax kernels.
PiperOrigin-RevId: 350094543
Change-Id: I2e9ad146265af116692cd750e00d29014472be00
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
index c66c280..45cd523 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
@@ -677,6 +677,7 @@
"//tensorflow/lite/delegates/gpu/common:shape",
"//tensorflow/lite/delegates/gpu/common:types",
"//tensorflow/lite/delegates/gpu/common:util",
+ "//tensorflow/lite/delegates/gpu/common/task:util",
"//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor",
],
)
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/softmax.cc b/tensorflow/lite/delegates/gpu/metal/kernels/softmax.cc
index 9a3a8ea..8d2ba67 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/softmax.cc
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/softmax.cc
@@ -23,6 +23,7 @@
#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
+#include "tensorflow/lite/delegates/gpu/common/task/util.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
@@ -40,7 +41,6 @@
using namespace metal;
struct uniforms {
- int4 size;
float4 mask;
};
@@ -51,11 +51,11 @@
uint3 ugid[[thread_position_in_grid]])
{
- float4 maxx4 = float4(src_tensor[0].x);
- for (int s = int(tid); s < params.size.x; s += 32) {
- float4 mask_a = s == params.size.x - 1 ? params.mask : float4(1.0f);
+ float4 maxx4 = float4(args.src_tensor.Read(0, 0, 0).x);
+ for (int s = int(tid); s < args.src_tensor.Slices(); s += 32) {
+ float4 mask_a = s == args.src_tensor.Slices() - 1 ? params.mask : float4(1.0f);
float4 mask_b = float4(1.0f) - mask_a;
- float4 src = float4(src_tensor[s]);
+ float4 src = float4(args.src_tensor.Read(0, 0, s));
src = src * mask_a + mask_b * src.x;
maxx4 = max(maxx4, src);
}
@@ -89,9 +89,9 @@
maximum = tmpx1[0];
float sum = 0.0f;
- for (int s = int(tid); s < params.size.x; s += 32) {
- float4 mask_temp = s == params.size.x - 1 ? params.mask : float4(1.0f);
- float4 src = float4(src_tensor[s]) - float4(maximum);
+ for (int s = int(tid); s < args.src_tensor.Slices(); s += 32) {
+ float4 mask_temp = s == args.src_tensor.Slices() - 1 ? params.mask : float4(1.0f);
+ float4 src = float4(args.src_tensor.Read(0, 0, s)) - float4(maximum);
sum += dot(mask_temp, exp(src));
}
@@ -120,13 +120,13 @@
sum = tmpx1[0];
int dst_s = int(ugid.x);
- if (dst_s < params.size.x) {
- int linear_index = dst_s;
- float4 src = float4(src_tensor[linear_index]) - float4(maximum);
+ if (dst_s < args.src_tensor.Slices()) {
+ float4 src = float4(args.src_tensor.Read(0, 0, dst_s)) - float4(maximum);
FLT4 value = FLT4(exp(src) * sum);
- uint3 gid = uint3(0, 0, linear_index);
+ uint3 gid = uint3(0, 0, dst_s);
+ args.dst_tensor.GetAddress(linear_index, 0, 0, dst_s);
$2
- dst_tensor[linear_index] = value;
+ args.dst_tensor.Write(value, 0, 0, dst_s);
}
})";
return code;
@@ -135,28 +135,27 @@
ComputeTaskDescriptor Softmax(const OperationDef& definition) {
ComputeTaskDescriptor desc(definition);
+ desc.tensors_as_args = true;
desc.shader_source = R"(
#include <metal_stdlib>
using namespace metal;
struct uniforms {
- int4 size;
float4 mask;
};
$0
kernel void ComputeFunction(
$1
uint3 gid[[thread_position_in_grid]]) {
- if (int(gid.x) >= params.size.x || int(gid.y) >= params.size.y) {
+ if (int(gid.x) >= args.dst_tensor.Width() || int(gid.y) >= args.dst_tensor.Height()) {
return;
}
- float maximum = src_tensor[gid.y * params.size.x + gid.x].x;
- for (int d = 0; d < params.size.z; ++d) {
- int buffer_index = (d * params.size.y + gid.y) * params.size.x + gid.x;
- float4 mask_a = d == params.size.z - 1 ? params.mask : float4(1.0f);
+ float maximum = args.src_tensor.Read(gid.x, gid.y, 0).x;
+ for (int d = 0; d < args.dst_tensor.Slices(); ++d) {
+ float4 mask_a = d == args.dst_tensor.Slices() - 1 ? params.mask : float4(1.0f);
float4 mask_b = float4(1.0f) - mask_a;
- float4 src = float4(src_tensor[buffer_index]);
+ float4 src = float4(args.src_tensor.Read(gid.x, gid.y, d));
src = src * mask_a + mask_b * src.x;
maximum = max(maximum, src.x);
maximum = max(maximum, src.y);
@@ -165,19 +164,18 @@
}
float sum = 0.0f;
- for (int d = 0; d < params.size.z; ++d) {
- int buffer_index = (d * params.size.y + gid.y) * params.size.x + gid.x;
- float4 mask_temp = d == params.size.z - 1 ? params.mask : float4(1.0f);
- float4 src = float4(src_tensor[buffer_index]) - float4(maximum);
+ for (int d = 0; d < args.dst_tensor.Slices(); ++d) {
+ float4 mask_temp = d == args.dst_tensor.Slices() - 1 ? params.mask : float4(1.0f);
+ float4 src = float4(args.src_tensor.Read(gid.x, gid.y, d)) - float4(maximum);
sum += dot(mask_temp, exp(src));
}
- for (int d = 0; d < params.size.z; ++d) {
- const int linear_index = (d * params.size.y + gid.y) * params.size.x + gid.x;
- float4 src = float4(src_tensor[linear_index]) - float4(maximum);
+ for (int d = 0; d < args.dst_tensor.Slices(); ++d) {
+ float4 src = float4(args.src_tensor.Read(gid.x, gid.y, d)) - float4(maximum);
FLT4 value = FLT4(exp(src) / sum);
+ args.dst_tensor.GetAddress(linear_index, gid.x, gid.y, d);
$2
- dst_tensor[linear_index] = value;
+ args.dst_tensor.Write(value, gid.x, gid.y, d);
}
}
)";
@@ -189,20 +187,9 @@
{"constant uniforms& params",
[](const std::vector<BHWC>& src_shapes,
const std::vector<BHWC>& dst_shapes) {
- const int dst_depth = DivideRoundUp(dst_shapes[0].c, 4);
- struct uniforms {
- int4 size;
- float4 mask;
- };
- uniforms params;
- params.size = {dst_shapes[0].w, dst_shapes[0].h, dst_depth, 1};
- params.mask = {0.0f, 0.0f, 0.0f, 0.0f};
- int reminder = dst_shapes[0].c % 4 == 0 ? 4 : dst_shapes[0].c % 4;
- for (int i = 0; i < reminder; ++i) {
- params.mask[i] = 1.0f;
- }
- const uint8_t* ptr = reinterpret_cast<const uint8_t*>(&params);
- return std::vector<uint8_t>(ptr, ptr + sizeof(uniforms));
+ float4 mask = GetMaskForLastPlane(dst_shapes[0].c);
+ const uint8_t* ptr = reinterpret_cast<const uint8_t*>(&mask);
+ return std::vector<uint8_t>(ptr, ptr + sizeof(float4));
}},
};
@@ -220,6 +207,7 @@
ComputeTaskDescriptor Softmax1x1(const OperationDef& definition,
const GpuInfo& gpu_info) {
ComputeTaskDescriptor desc(definition);
+ desc.tensors_as_args = true;
desc.shader_source = GetSoftmax1x1Code(gpu_info);
desc.AddSrcTensor("src_tensor", definition.src_tensors[0]);
@@ -229,20 +217,9 @@
{"constant uniforms& params",
[](const std::vector<BHWC>& src_shapes,
const std::vector<BHWC>& dst_shapes) {
- const int src_depth = DivideRoundUp(dst_shapes[0].c, 4);
- struct uniforms {
- int4 size;
- float4 mask;
- };
- uniforms params;
- params.size = {src_depth, DivideRoundUp(src_depth, 32), 1, 1};
- params.mask = {0.0f, 0.0f, 0.0f, 0.0f};
- int reminder = dst_shapes[0].c % 4 == 0 ? 4 : dst_shapes[0].c % 4;
- for (int i = 0; i < reminder; ++i) {
- params.mask[i] = 1.0f;
- }
- const uint8_t* ptr = reinterpret_cast<const uint8_t*>(&params);
- return std::vector<uint8_t>(ptr, ptr + sizeof(uniforms));
+ float4 mask = GetMaskForLastPlane(dst_shapes[0].c);
+ const uint8_t* ptr = reinterpret_cast<const uint8_t*>(&mask);
+ return std::vector<uint8_t>(ptr, ptr + sizeof(float4));
}},
};