ConcatXY reverted and rewritten to support batch implicitly.

PiperOrigin-RevId: 272784208
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.cc b/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.cc
index d940134..789ba83 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.cc
@@ -33,15 +33,17 @@
   std::vector<TensorCodeGenerator> srcs(tensors_count);
   for (int i = 0; i < tensors_count; ++i) {
     const std::string tensor_name = "src_data_" + std::to_string(i);
-    const std::string uniform_name = "src_size_" + std::to_string(i);
-    srcs[i] =
-        TensorCodeGenerator(tensor_name, uniform_name, op_def.src_tensors[i]);
+    const std::string width = "src_size_" + std::to_string(i) + ".x";
+    const std::string height = "src_size_" + std::to_string(i) + ".y";
+    srcs[i] = TensorCodeGenerator(tensor_name, {width, height, "dst_size.z"},
+                                  op_def.src_tensors[i]);
   }
-  TensorCodeGenerator dst("dst_data", "dst_size", op_def.dst_tensors[0]);
+  TensorCodeGenerator dst("dst_data",
+                          {"dst_size.x", "dst_size.y", "dst_size.z"},
+                          op_def.dst_tensors[0]);
 
   std::string c = GetCommonDefines(op_def.precision);
 
-  const std::string batch_id = op_def.batch_support ? "batch_id" : "";
   c += "__kernel void main_function(\n";
   for (const auto& src : srcs) {
     c += src.GetDeclaration(AccessType::READ) + ",\n";
@@ -52,38 +54,22 @@
     const std::string uniform_name = "src_size_" + std::to_string(i);
     c += "    int4 " + uniform_name + ",\n";
   }
-  for (int i = 0; i < tensors_count; ++i) {
-    const std::string uniform_name = "dst_offset_" + std::to_string(i);
-    c += "    int2 " + uniform_name + ",\n";
-  }
-  if (op_def.batch_support) {
-    c += "    int BATCH_SIZE,  \n";
-  }
   c += "    int4 dst_size  \n";
   c += ") {\n";
   c += "  int X = get_global_id(0);\n";
   c += "  int Y = get_global_id(1);\n";
-  if (op_def.batch_support) {
-    c += "  int batch_id = get_global_id(2) / dst_size.w;\n";
-    c += "  int Z = get_global_id(2) - batch_id * dst_size.w;\n";
-    c += "  if (Z >= dst_size.w || batch_id >= BATCH_SIZE) return;\n";
-  } else {
-    c += "  int Z = get_global_id(2);\n";
-    c += "  if (Z >= dst_size.w) return;\n";
-  }
+  c += "  int Z = get_global_id(2);\n";
+  c += "  if (Z >= dst_size.z) return;\n";
   for (int i = 0; i < tensors_count; ++i) {
-    const std::string offset_name = "dst_offset_" + std::to_string(i);
     const std::string size_name = "src_size_" + std::to_string(i);
     c += "  if (X < " + size_name + ".x && Y < " + size_name + ".y) { \n";
-    c +=
-        "    FLT4 result = " +
-        srcs[i].Read4D("X", "Y", "Z", batch_id, TextureAddressMode::DONT_CARE) +
-        ";\n";
-    c += "    int dst_x = X + " + offset_name + ".x;\n";
-    c += "    int dst_y = Y + " + offset_name + ".y;\n";
+    c += "    FLT4 result = " +
+         srcs[i].Read3D("X", "Y", "Z", TextureAddressMode::DONT_CARE) + ";\n";
+    c += "    int dst_x = X + " + size_name + ".z;\n";
+    c += "    int dst_y = Y + " + size_name + ".w;\n";
     const LinkingContext context{"result", "dst_x", "dst_y", "Z"};
     c += PostProcess(linked_operations, context);
-    c += "    " + dst.Write4D("result", "dst_x", "dst_y", "Z", batch_id);
+    c += "    " + dst.Write3D("result", "dst_x", "dst_y", "Z");
     c += "  } \n";
   }
   c += "}\n";
@@ -125,24 +111,17 @@
   }
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
   RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
-  int max_src_width = 0;
-  int max_src_height = 0;
-  for (int i = 0; i < tensors_count_; ++i) {
-    RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[i]->GetSizeWithDepth()));
-    max_src_width = std::max(max_src_width, src_[i]->Width());
-    max_src_height = std::max(max_src_height, src_[i]->Height());
-  }
   int x_offset = 0;
   int y_offset = 0;
   for (int i = 0; i < tensors_count_; ++i) {
-    RETURN_IF_ERROR(kernel_.SetBytesAuto(int2(x_offset, y_offset)));
-    x_offset += attr_.axis == Axis::WIDTH ? src_[i]->Width() : 0;
-    y_offset += attr_.axis == Axis::HEIGHT ? src_[i]->Height() : 0;
+    const int width = src_[i]->Width() * src_[i]->Batch();
+    const int height = src_[i]->Height();
+    RETURN_IF_ERROR(
+        kernel_.SetBytesAuto(int4(width, height, x_offset, y_offset)));
+    x_offset += attr_.axis == Axis::WIDTH ? width : 0;
+    y_offset += attr_.axis == Axis::HEIGHT ? height : 0;
   }
-  if (definition_.batch_support) {
-    RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->Batch()));
-  }
-  RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetSizeWithDepth()));
+  RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHDB()));
   return OkStatus();
 }
 
@@ -154,9 +133,9 @@
     max_src_height = std::max(max_src_height, src_[i]->Height());
   }
 
-  const int grid_x = max_src_width;
+  const int grid_x = max_src_width * dst_[0]->Batch();
   const int grid_y = max_src_height;
-  const int grid_z = dst_[0]->Depth() * dst_[0]->Batch();
+  const int grid_z = dst_[0]->Depth();
 
   return int3(grid_x, grid_y, grid_z);
 }