Using Tensor objects in Softmax kernels.

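The Softmax and Softmax1x1 Metal kernels now go through the args.src_tensor /
args.dst_tensor objects instead of raw buffer indexing, so the manually packed
int4 size uniform is dropped and the uniforms struct shrinks to the float4
channel mask, now produced by GetMaskForLastPlane from common/task:util. Both
ComputeTaskDescriptors set tensors_as_args = true.

As a rough illustration (lines taken from the Softmax kernel hunks below, not a
complete shader), the per-slice read changes from

  int buffer_index = (d * params.size.y + gid.y) * params.size.x + gid.x;
  float4 src = float4(src_tensor[buffer_index]);

to

  float4 src = float4(args.src_tensor.Read(gid.x, gid.y, d));
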
PiperOrigin-RevId: 350094543
Change-Id: I2e9ad146265af116692cd750e00d29014472be00
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
index c66c280..45cd523 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
@@ -677,6 +677,7 @@
         "//tensorflow/lite/delegates/gpu/common:shape",
         "//tensorflow/lite/delegates/gpu/common:types",
         "//tensorflow/lite/delegates/gpu/common:util",
+        "//tensorflow/lite/delegates/gpu/common/task:util",
         "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor",
     ],
 )
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/softmax.cc b/tensorflow/lite/delegates/gpu/metal/kernels/softmax.cc
index 9a3a8ea..8d2ba67 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/softmax.cc
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/softmax.cc
@@ -23,6 +23,7 @@
 #include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
+#include "tensorflow/lite/delegates/gpu/common/task/util.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 #include "tensorflow/lite/delegates/gpu/common/util.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
@@ -40,7 +41,6 @@
 using namespace metal;
 
 struct uniforms {
-  int4 size;
   float4 mask;
 };
 
@@ -51,11 +51,11 @@
                             uint3 ugid[[thread_position_in_grid]])
 {
 
-  float4 maxx4 = float4(src_tensor[0].x);
-  for (int s = int(tid); s < params.size.x; s += 32) {
-    float4 mask_a = s == params.size.x - 1 ? params.mask : float4(1.0f);
+  float4 maxx4 = float4(args.src_tensor.Read(0, 0, 0).x);
+  for (int s = int(tid); s < args.src_tensor.Slices(); s += 32) {
+    float4 mask_a = s == args.src_tensor.Slices() - 1 ? params.mask : float4(1.0f);
     float4 mask_b = float4(1.0f) - mask_a;
-    float4 src = float4(src_tensor[s]);
+    float4 src = float4(args.src_tensor.Read(0, 0, s));
     src = src * mask_a + mask_b * src.x;
     maxx4 = max(maxx4, src);
   }
@@ -89,9 +89,9 @@
   maximum = tmpx1[0];
 
   float sum = 0.0f;
-  for (int s = int(tid); s < params.size.x; s += 32) {
-    float4 mask_temp = s == params.size.x - 1 ? params.mask : float4(1.0f);
-    float4 src = float4(src_tensor[s]) - float4(maximum);
+  for (int s = int(tid); s < args.src_tensor.Slices(); s += 32) {
+    float4 mask_temp = s == args.src_tensor.Slices() - 1 ? params.mask : float4(1.0f);
+    float4 src = float4(args.src_tensor.Read(0, 0, s)) - float4(maximum);
     sum += dot(mask_temp, exp(src));
   }
 
@@ -120,13 +120,13 @@
   sum = tmpx1[0];
 
   int dst_s = int(ugid.x);
-  if (dst_s < params.size.x) {
-    int linear_index = dst_s;
-    float4 src = float4(src_tensor[linear_index]) - float4(maximum);
+  if (dst_s < args.src_tensor.Slices()) {
+    float4 src = float4(args.src_tensor.Read(0, 0, dst_s)) - float4(maximum);
     FLT4 value = FLT4(exp(src) * sum);
-    uint3 gid = uint3(0, 0, linear_index);
+    uint3 gid = uint3(0, 0, dst_s);
+    args.dst_tensor.GetAddress(linear_index, 0, 0, dst_s);
     $2
-    dst_tensor[linear_index] = value;
+    args.dst_tensor.Write(value, 0, 0, dst_s);
   }
 })";
   return code;
@@ -135,28 +135,27 @@
 
 ComputeTaskDescriptor Softmax(const OperationDef& definition) {
   ComputeTaskDescriptor desc(definition);
+  desc.tensors_as_args = true;
   desc.shader_source = R"(
 #include <metal_stdlib>
 using namespace metal;
 
 struct uniforms {
-  int4 size;
   float4 mask;
 };
 $0
 kernel void ComputeFunction(
                             $1
                             uint3 gid[[thread_position_in_grid]]) {
-  if (int(gid.x) >= params.size.x || int(gid.y) >= params.size.y) {
+  if (int(gid.x) >= args.dst_tensor.Width() || int(gid.y) >= args.dst_tensor.Height()) {
     return;
   }
 
-  float maximum = src_tensor[gid.y * params.size.x + gid.x].x;
-  for (int d = 0; d < params.size.z; ++d) {
-    int buffer_index = (d * params.size.y + gid.y) * params.size.x + gid.x;
-    float4 mask_a = d == params.size.z - 1 ? params.mask : float4(1.0f);
+  float maximum = args.src_tensor.Read(gid.x, gid.y, 0).x;
+  for (int d = 0; d < args.dst_tensor.Slices(); ++d) {
+    float4 mask_a = d == args.dst_tensor.Slices() - 1 ? params.mask : float4(1.0f);
     float4 mask_b = float4(1.0f) - mask_a;
-    float4 src = float4(src_tensor[buffer_index]);
+    float4 src = float4(args.src_tensor.Read(gid.x, gid.y, d));
     src = src * mask_a + mask_b * src.x;
     maximum = max(maximum, src.x);
     maximum = max(maximum, src.y);
@@ -165,19 +164,18 @@
   }
 
   float sum = 0.0f;
-  for (int d = 0; d < params.size.z; ++d) {
-    int buffer_index = (d * params.size.y + gid.y) * params.size.x + gid.x;
-    float4 mask_temp = d == params.size.z - 1 ? params.mask : float4(1.0f);
-    float4 src = float4(src_tensor[buffer_index]) - float4(maximum);
+  for (int d = 0; d < args.dst_tensor.Slices(); ++d) {
+    float4 mask_temp = d == args.dst_tensor.Slices() - 1 ? params.mask : float4(1.0f);
+    float4 src = float4(args.src_tensor.Read(gid.x, gid.y, d)) - float4(maximum);
     sum += dot(mask_temp, exp(src));
   }
 
-  for (int d = 0; d < params.size.z; ++d) {
-    const int linear_index = (d * params.size.y + gid.y) * params.size.x + gid.x;
-    float4 src = float4(src_tensor[linear_index]) - float4(maximum);
+  for (int d = 0; d < args.dst_tensor.Slices(); ++d) {
+    float4 src = float4(args.src_tensor.Read(gid.x, gid.y, d)) - float4(maximum);
     FLT4 value = FLT4(exp(src) / sum);
+    args.dst_tensor.GetAddress(linear_index, gid.x, gid.y, d);
     $2
-    dst_tensor[linear_index] = value;
+    args.dst_tensor.Write(value, gid.x, gid.y, d);
   }
 }
   )";
@@ -189,20 +187,9 @@
       {"constant uniforms& params",
        [](const std::vector<BHWC>& src_shapes,
           const std::vector<BHWC>& dst_shapes) {
-         const int dst_depth = DivideRoundUp(dst_shapes[0].c, 4);
-         struct uniforms {
-           int4 size;
-           float4 mask;
-         };
-         uniforms params;
-         params.size = {dst_shapes[0].w, dst_shapes[0].h, dst_depth, 1};
-         params.mask = {0.0f, 0.0f, 0.0f, 0.0f};
-         int reminder = dst_shapes[0].c % 4 == 0 ? 4 : dst_shapes[0].c % 4;
-         for (int i = 0; i < reminder; ++i) {
-           params.mask[i] = 1.0f;
-         }
-         const uint8_t* ptr = reinterpret_cast<const uint8_t*>(&params);
-         return std::vector<uint8_t>(ptr, ptr + sizeof(uniforms));
+         float4 mask = GetMaskForLastPlane(dst_shapes[0].c);
+         const uint8_t* ptr = reinterpret_cast<const uint8_t*>(&mask);
+         return std::vector<uint8_t>(ptr, ptr + sizeof(float4));
        }},
   };
 
@@ -220,6 +207,7 @@
 ComputeTaskDescriptor Softmax1x1(const OperationDef& definition,
                                  const GpuInfo& gpu_info) {
   ComputeTaskDescriptor desc(definition);
+  desc.tensors_as_args = true;
   desc.shader_source = GetSoftmax1x1Code(gpu_info);
 
   desc.AddSrcTensor("src_tensor", definition.src_tensors[0]);
@@ -229,20 +217,9 @@
       {"constant uniforms& params",
        [](const std::vector<BHWC>& src_shapes,
           const std::vector<BHWC>& dst_shapes) {
-         const int src_depth = DivideRoundUp(dst_shapes[0].c, 4);
-         struct uniforms {
-           int4 size;
-           float4 mask;
-         };
-         uniforms params;
-         params.size = {src_depth, DivideRoundUp(src_depth, 32), 1, 1};
-         params.mask = {0.0f, 0.0f, 0.0f, 0.0f};
-         int reminder = dst_shapes[0].c % 4 == 0 ? 4 : dst_shapes[0].c % 4;
-         for (int i = 0; i < reminder; ++i) {
-           params.mask[i] = 1.0f;
-         }
-         const uint8_t* ptr = reinterpret_cast<const uint8_t*>(&params);
-         return std::vector<uint8_t>(ptr, ptr + sizeof(uniforms));
+         float4 mask = GetMaskForLastPlane(dst_shapes[0].c);
+         const uint8_t* ptr = reinterpret_cast<const uint8_t*>(&mask);
+         return std::vector<uint8_t>(ptr, ptr + sizeof(float4));
        }},
   };