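ConcatZ reverted and rewritten to support batch implicitly.

Batch is no longer passed to the kernel as an explicit BATCH_SIZE argument
with a separate batch_id grid axis. Instead it is folded into the X grid
dimension (Width * Batch), and the generated kernel uses the 3D read/write
helpers with per-source depth arguments.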

PiperOrigin-RevId: 272790908
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.cc b/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.cc
index 6712379..ca03ecc 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.cc
@@ -36,22 +36,27 @@
   return true;
 }
 
+std::string GetSrcDepthSizeVar(int src_index) {
+  return "src_size_" + std::to_string(src_index) + "_depth";
+}
+
 std::string GetConcatKernelCode(
     const OperationDef& op_def, const std::vector<int>& channels,
     const std::vector<ElementwiseOperation*>& linked_operations) {
   std::vector<TensorCodeGenerator> srcs(channels.size());
   for (int i = 0; i < channels.size(); ++i) {
     const std::string tensor_name = "src_data_" + std::to_string(i);
-    const std::string uniform_name = "src_size_" + std::to_string(i);
-    srcs[i] =
-        TensorCodeGenerator(tensor_name, uniform_name, op_def.src_tensors[i]);
+    srcs[i] = TensorCodeGenerator(
+        tensor_name, {"dst_size.x", "dst_size.y", GetSrcDepthSizeVar(i)},
+        op_def.src_tensors[i]);
   }
-  TensorCodeGenerator dst("dst_data", "dst_size", op_def.dst_tensors[0]);
+  TensorCodeGenerator dst("dst_data",
+                          {"dst_size.x", "dst_size.y", "dst_size.z"},
+                          op_def.dst_tensors[0]);
 
   std::string c = GetCommonDefines(op_def.precision);
   const std::string postfix[] = {".x", ".y", ".z", ".w"};
 
-  const std::string batch_id = op_def.batch_support ? "batch_id" : "";
   c += "__kernel void main_function(\n";
   for (const auto& src : srcs) {
     c += src.GetDeclaration(AccessType::READ) + ",\n";
@@ -59,21 +64,13 @@
   c += dst.GetDeclaration(AccessType::WRITE);
   c += GetArgsDeclaration(linked_operations);
   for (int i = 0; i < channels.size(); ++i) {
-    const std::string uniform_name = "src_size_" + std::to_string(i);
-    c += "    int4 " + uniform_name + ",\n";
-  }
-  if (op_def.batch_support) {
-    c += "    int BATCH_SIZE,  \n";
+    c += "    int " + GetSrcDepthSizeVar(i) + ",\n";
   }
   c += "    int4 dst_size\n";
   c += ") {\n";
   c += "  int X = get_global_id(0);\n";
   c += "  int Y = get_global_id(1);\n";
-  c += "  if (X >= dst_size.x || Y >= dst_size.y) return;\n";
-  if (op_def.batch_support) {
-    c += "  int batch_id = get_global_id(2);\n";
-    c += "  if (batch_id >= BATCH_SIZE) return;\n";
-  }
+  c += "  if (X >= dst_size.x || Y >= dst_size.y) return; \n";
 
   if (IsAllChannelsX4(channels)) {
     // When all channels % 4 == 0 we can read/assign/write FLT4 elements easily.
@@ -81,37 +78,35 @@
     // generation.
     c += "  int Z = 0;\n";
     for (int i = 0; i < channels.size(); ++i) {
-      const std::string uniform_name = "src_size_" + std::to_string(i);
       const int depth = IntegralDivideRoundUp(channels[i], 4);
       if (depth % 2 == 0) {
         // We can read more at once inside of loop in case depth % 2 == 0
         // it should be better for reading latency hiding
-        c += "  for (int i = 0; i < " + uniform_name + ".w; i += 2) {\n";
+        c += "  for (int i = 0; i < " + GetSrcDepthSizeVar(i) + "; i += 2) {\n";
         c += "    FLT4 result0 = " +
-             srcs[i].Read4D("X", "Y", "i", batch_id,
-                            TextureAddressMode::DONT_CARE) +
+             srcs[i].Read3D("X", "Y", "i", TextureAddressMode::DONT_CARE) +
              ";\n";
         c += "    FLT4 result1 = " +
-             srcs[i].Read4D("X", "Y", "i + 1", batch_id,
-                            TextureAddressMode::DONT_CARE) +
+             srcs[i].Read3D("X", "Y", "i + 1", TextureAddressMode::DONT_CARE) +
              ";\n";
+        c += "    " + dst.GetAddress("dst_adr0", "X", "Y", "Z") + "\n";
+        c += "    " + dst.GetAddress("dst_adr1", "X", "Y", "Z + 1") + "\n";
         const LinkingContext context_0{"result0", "X", "Y", "Z"};
         const LinkingContext context_1{"result1", "X", "Y", "Z + 1"};
         c += PostProcess(linked_operations, context_0);
         c += PostProcess(linked_operations, context_1);
-        c += "    " + dst.Write4D("result0", "X", "Y", "Z", batch_id);
-        c += "    " + dst.Write4D("result1", "X", "Y", "Z + 1", batch_id);
+        c += "    " + dst.Write3D("result0", "X", "Y", "Z");
+        c += "    " + dst.Write3D("result1", "X", "Y", "Z + 1");
         c += "    Z += 2;\n";
         c += "  }\n";
       } else {
-        c += "  for (int i = 0; i < " + uniform_name + ".w; ++i) {\n";
+        c += "  for (int i = 0; i < " + GetSrcDepthSizeVar(i) + "; ++i) {\n";
         c += "    FLT4 result = " +
-             srcs[i].Read4D("X", "Y", "i", batch_id,
-                            TextureAddressMode::DONT_CARE) +
+             srcs[i].Read3D("X", "Y", "i", TextureAddressMode::DONT_CARE) +
              ";\n";
         const LinkingContext context{"result", "X", "Y", "Z"};
         c += PostProcess(linked_operations, context);
-        c += "    " + dst.Write4D("result", "X", "Y", "Z", batch_id);
+        c += "    " + dst.Write3D("result", "X", "Y", "Z");
         c += "    Z++;\n";
         c += "  }\n";
       }
@@ -126,8 +121,8 @@
       for (int d = 0; d < depth; ++d) {
         const int channels_in_group = std::min(4, channels[i] - d * 4);
         const std::string temp_name = "t" + std::to_string(read_index);
-        c += "  FLT4 " + temp_name + " = " +
-             srcs[i].Read4D("X", "Y", std::to_string(d), batch_id,
+        c += "  FLT4 " + temp_name + " = ";
+        c += srcs[i].Read3D("X", "Y", std::to_string(d),
                             TextureAddressMode::DONT_CARE) +
              ";\n";
         for (int ch = 0; ch < channels_in_group; ++ch) {
@@ -139,8 +134,7 @@
             c += "  {\n";
             const LinkingContext context{"result", "X", "Y", std::to_string(z)};
             c += PostProcess(linked_operations, context);
-            c += "  " +
-                 dst.Write4D("result", "X", "Y", std::to_string(z), batch_id);
+            c += "  " + dst.Write3D("result", "X", "Y", std::to_string(z));
             c += "  }\n";
             z++;
           }
@@ -152,7 +146,7 @@
       c += "  {\n";
       const LinkingContext context{"result", "X", "Y", std::to_string(z)};
       c += PostProcess(linked_operations, context);
-      c += "  " + dst.Write4D("result", "X", "Y", "Z", std::to_string(z));
+      c += "  " + dst.Write3D("result", "X", "Y", std::to_string(z));
       c += "  }\n";
     }
   }
@@ -199,21 +193,16 @@
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
   RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
   for (int i = 0; i < channels_.size(); ++i) {
-    int4 size(src_[i]->Width(), src_[i]->Height(), channels_[i],
-              IntegralDivideRoundUp(channels_[i], 4));
-    RETURN_IF_ERROR(kernel_.SetBytesAuto(size));
+    RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[i]->Depth()));
   }
-  if (definition_.batch_support) {
-    RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->Batch()));
-  }
-  RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetSizeWithDepth()));
+  RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHDB()));
   return OkStatus();
 }
 
 int3 ConcatZ::GetGridSize() const {
-  const int grid_x = dst_[0]->Width();
+  const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
   const int grid_y = dst_[0]->Height();
-  const int grid_z = dst_[0]->Batch();
+  const int grid_z = 1;
   return int3(grid_x, grid_y, grid_z);
 }