Overloaded Read4D/Write4D to make the transition to batch support easier and the general code generation easier to read.

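The key piece is that an empty batch string now falls back to 3D addressing in GetGlobalAddressNoDeclaration, so each kernel computes a single batch_id string up front and calls Read4D/Write4D unconditionally instead of branching through per-kernel read_src/write_dst lambdas. A minimal, self-contained sketch of that fallback follows; MakeAddress is a hypothetical stand-in for the real address helper, not the actual implementation:

  // Sketch only: simplified stand-in for GetGlobalAddressNoDeclaration.
  #include <iostream>
  #include <string>

  std::string MakeAddress(const std::string& x, const std::string& y,
                          const std::string& z, const std::string& b = "") {
    // An empty batch id means "no batch dimension": emit the 3D address, so
    // callers can pass batch_id unconditionally.
    if (b.empty()) {
      return "(" + x + ", " + y + ", " + z + ")";
    }
    return "(" + x + ", " + y + ", " + z + ", " + b + ")";
  }

  int main() {
    const bool batch_support = true;  // mimics op_def.batch_support
    const std::string batch_id = batch_support ? "batch_id" : "";
    std::cout << MakeAddress("X", "Y", "Z", batch_id) << "\n";
    // Prints "(X, Y, Z, batch_id)"; with batch_support = false it prints
    // "(X, Y, Z)".
    return 0;
  }

With that in place, each kernel only needs the one-liner
  const std::string batch_id = op_def.batch_support ? "batch_id" : "";
and the Read4D/Write4D call sites stay identical in both modes.
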
PiperOrigin-RevId: 272345859
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.cc b/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.cc
index cc49298..d940134 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.cc
@@ -39,17 +39,9 @@
   }
   TensorCodeGenerator dst("dst_data", "dst_size", op_def.dst_tensors[0]);
 
-  auto read_src = [&](const TensorCodeGenerator& tensor, const std::string& x,
-                      const std::string& y, const std::string& z) {
-    if (op_def.batch_support) {
-      return tensor.Read4D(x, y, z, "B");
-    } else {
-      return tensor.Read3D(x, y, z, TextureAddressMode::DONT_CARE);
-    }
-  };
-
   std::string c = GetCommonDefines(op_def.precision);
 
+  const std::string batch_id = op_def.batch_support ? "batch_id" : "";
   c += "__kernel void main_function(\n";
   for (const auto& src : srcs) {
     c += src.GetDeclaration(AccessType::READ) + ",\n";
@@ -72,9 +64,9 @@
   c += "  int X = get_global_id(0);\n";
   c += "  int Y = get_global_id(1);\n";
   if (op_def.batch_support) {
-    c += "  int B = get_global_id(2) / dst_size.w;\n";
-    c += "  int Z = get_global_id(2) - B * dst_size.w;\n";
-    c += "  if (Z >= dst_size.w || B >= BATCH_SIZE) return;\n";
+    c += "  int batch_id = get_global_id(2) / dst_size.w;\n";
+    c += "  int Z = get_global_id(2) - batch_id * dst_size.w;\n";
+    c += "  if (Z >= dst_size.w || batch_id >= BATCH_SIZE) return;\n";
   } else {
     c += "  int Z = get_global_id(2);\n";
     c += "  if (Z >= dst_size.w) return;\n";
@@ -83,16 +75,15 @@
     const std::string offset_name = "dst_offset_" + std::to_string(i);
     const std::string size_name = "src_size_" + std::to_string(i);
     c += "  if (X < " + size_name + ".x && Y < " + size_name + ".y) { \n";
-    c += "    FLT4 result = " + read_src(srcs[i], "X", "Y", "Z") + ";\n";
+    c +=
+        "    FLT4 result = " +
+        srcs[i].Read4D("X", "Y", "Z", batch_id, TextureAddressMode::DONT_CARE) +
+        ";\n";
     c += "    int dst_x = X + " + offset_name + ".x;\n";
     c += "    int dst_y = Y + " + offset_name + ".y;\n";
     const LinkingContext context{"result", "dst_x", "dst_y", "Z"};
     c += PostProcess(linked_operations, context);
-    if (op_def.batch_support) {
-      c += "    " + dst.Write4D("result", "dst_x", "dst_y", "Z", "B");
-    } else {
-      c += "    " + dst.Write3D("result", "dst_x", "dst_y", "Z");
-    }
+    c += "    " + dst.Write4D("result", "dst_x", "dst_y", "Z", batch_id);
     c += "  } \n";
   }
   c += "}\n";
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.cc b/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.cc
index 97e414d..6712379 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/concat_z.cc
@@ -48,27 +48,10 @@
   }
   TensorCodeGenerator dst("dst_data", "dst_size", op_def.dst_tensors[0]);
 
-  auto read_src = [&](const TensorCodeGenerator& tensor, const std::string& x,
-                      const std::string& y, const std::string& z) {
-    if (op_def.batch_support) {
-      return tensor.Read4D(x, y, z, "B");
-    } else {
-      return tensor.Read3D(x, y, z, TextureAddressMode::DONT_CARE);
-    }
-  };
-
-  auto write_dst = [&](const std::string& var_name, const std::string& x,
-                       const std::string& y, const std::string& z) {
-    if (op_def.batch_support) {
-      return dst.Write4D(var_name, x, y, z, "B");
-    } else {
-      return dst.Write3D(var_name, x, y, z);
-    }
-  };
-
   std::string c = GetCommonDefines(op_def.precision);
   const std::string postfix[] = {".x", ".y", ".z", ".w"};
 
+  const std::string batch_id = op_def.batch_support ? "batch_id" : "";
   c += "__kernel void main_function(\n";
   for (const auto& src : srcs) {
     c += src.GetDeclaration(AccessType::READ) + ",\n";
@@ -88,8 +71,8 @@
   c += "  int Y = get_global_id(1);\n";
   c += "  if (X >= dst_size.x || Y >= dst_size.y) return;\n";
   if (op_def.batch_support) {
-    c += "  int B = get_global_id(2);\n";
-    c += "  if (B >= BATCH_SIZE) return;\n";
+    c += "  int batch_id = get_global_id(2);\n";
+    c += "  if (batch_id >= BATCH_SIZE) return;\n";
   }
 
   if (IsAllChannelsX4(channels)) {
@@ -104,22 +87,31 @@
         // We can read more at once inside of loop in case depth % 2 == 0
         // it should be better for reading latency hiding
         c += "  for (int i = 0; i < " + uniform_name + ".w; i += 2) {\n";
-        c += "    FLT4 result0 = " + read_src(srcs[i], "X", "Y", "i") + ";\n";
-        c += "    FLT4 result1 = " + read_src(srcs[i], "X", "Y", "i+1") + ";\n";
+        c += "    FLT4 result0 = " +
+             srcs[i].Read4D("X", "Y", "i", batch_id,
+                            TextureAddressMode::DONT_CARE) +
+             ";\n";
+        c += "    FLT4 result1 = " +
+             srcs[i].Read4D("X", "Y", "i + 1", batch_id,
+                            TextureAddressMode::DONT_CARE) +
+             ";\n";
         const LinkingContext context_0{"result0", "X", "Y", "Z"};
         const LinkingContext context_1{"result1", "X", "Y", "Z + 1"};
         c += PostProcess(linked_operations, context_0);
         c += PostProcess(linked_operations, context_1);
-        c += "    " + write_dst("result0", "X", "Y", "Z");
-        c += "    " + write_dst("result1", "X", "Y", "Z + 1");
+        c += "    " + dst.Write4D("result0", "X", "Y", "Z", batch_id);
+        c += "    " + dst.Write4D("result1", "X", "Y", "Z + 1", batch_id);
         c += "    Z += 2;\n";
         c += "  }\n";
       } else {
         c += "  for (int i = 0; i < " + uniform_name + ".w; ++i) {\n";
-        c += "    FLT4 result = " + read_src(srcs[i], "X", "Y", "i") + ";\n";
+        c += "    FLT4 result = " +
+             srcs[i].Read4D("X", "Y", "i", batch_id,
+                            TextureAddressMode::DONT_CARE) +
+             ";\n";
         const LinkingContext context{"result", "X", "Y", "Z"};
         c += PostProcess(linked_operations, context);
-        c += "    " + write_dst("result", "X", "Y", "Z");
+        c += "    " + dst.Write4D("result", "X", "Y", "Z", batch_id);
         c += "    Z++;\n";
         c += "  }\n";
       }
@@ -135,7 +127,9 @@
         const int channels_in_group = std::min(4, channels[i] - d * 4);
         const std::string temp_name = "t" + std::to_string(read_index);
         c += "  FLT4 " + temp_name + " = " +
-             read_src(srcs[i], "X", "Y", std::to_string(d)) + ";\n";
+             srcs[i].Read4D("X", "Y", std::to_string(d), batch_id,
+                            TextureAddressMode::DONT_CARE) +
+             ";\n";
         for (int ch = 0; ch < channels_in_group; ++ch) {
           c += "  result" + postfix[out_channel] + " = ";
           c += temp_name + postfix[ch] + ";\n";
@@ -145,7 +139,8 @@
             c += "  {\n";
             const LinkingContext context{"result", "X", "Y", std::to_string(z)};
             c += PostProcess(linked_operations, context);
-            c += "  " + write_dst("result", "X", "Y", std::to_string(z));
+            c += "  " +
+                 dst.Write4D("result", "X", "Y", std::to_string(z), batch_id);
             c += "  }\n";
             z++;
           }
@@ -157,7 +152,7 @@
       c += "  {\n";
       const LinkingContext context{"result", "X", "Y", std::to_string(z)};
       c += PostProcess(linked_operations, context);
-      c += "  " + write_dst("result", "X", "Y", std::to_string(z));
+      c += "  " + dst.Write4D("result", "X", "Y", "Z", std::to_string(z));
       c += "  }\n";
     }
   }
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected_batched.cc b/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected_batched.cc
index 44abb02..d451479 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected_batched.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected_batched.cc
@@ -48,7 +48,9 @@
   c += "  if (Z >= dst_size.w || B >= BATCH_SIZE) return;\n";
   c += "  ACCUM_FLT4 s = (ACCUM_FLT4)(0.0f);\n";
   c += "  for (int i = 0; i < src_size.w; ++i) {\n";
-  c += "    FLT4 v = " + src_tensor.Read4D("0", "0", "i", "B") + ";\n";
+  c += "    FLT4 v = " +
+       src_tensor.Read4D("0", "0", "i", "B", TextureAddressMode::DONT_CARE) +
+       ";\n";
   c += "    FLT4 m0 = READ_IMAGE(filters, smp_none, (int2)(Z * 4 + 0, i));\n";
   c += "    FLT4 m1 = READ_IMAGE(filters, smp_none, (int2)(Z * 4 + 1, i));\n";
   c += "    FLT4 m2 = READ_IMAGE(filters, smp_none, (int2)(Z * 4 + 2, i));\n";
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/softmax.cc b/tensorflow/lite/delegates/gpu/cl/kernels/softmax.cc
index 69c18d4..737e203 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/softmax.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/softmax.cc
@@ -32,15 +32,7 @@
   TensorCodeGenerator src_tensor("src_data", "size", op_def.src_tensors[0]);
   TensorCodeGenerator dst_tensor("dst_data", "size", op_def.dst_tensors[0]);
 
-  auto read_src = [&](const std::string& x, const std::string& y,
-                      const std::string& z) {
-    if (op_def.batch_support) {
-      return src_tensor.ReadAsFloat4D(x, y, z, "B");
-    } else {
-      return src_tensor.ReadAsFloat3D(x, y, z, TextureAddressMode::DONT_CARE);
-    }
-  };
-
+  const std::string batch_id = op_def.batch_support ? "batch_id" : "";
   std::string code = GetCommonDefines(op_def.precision);
   code += "__kernel void main_function(\n";
   code += src_tensor.GetDeclaration(AccessType::READ);
@@ -56,29 +48,34 @@
   code += "  int Y = get_global_id(1);\n";
   code += "  if (X >= size.x || Y >= size.y) return; \n";
   if (op_def.batch_support) {
-    code += "  int B = get_global_id(2);\n";
-    code += "  if (B >= BATCH_SIZE) return;\n";
+    code += "  int batch_id = get_global_id(2);\n";
+    code += "  if (batch_id >= BATCH_SIZE) return;\n";
   }
   code += "  float sum = 0.0f;\n";
   code += "  for (int d = 0; d < size.w - 1; ++d) {\n";
-  code += "    float4 t = " + read_src("X", "Y", "d") + ";\n";
+  code += "    float4 t = " +
+          src_tensor.ReadAsFloat4D("X", "Y", "d", batch_id,
+                                   TextureAddressMode::DONT_CARE) +
+          ";\n";
   code += "    sum += dot((float4)(1.0f), exp(t));\n";
   code += "  }\n";
   code += "  {\n";
-  code += "    float4 t = " + read_src("X", "Y", "size.w - 1") + ";\n";
+  code += "    float4 t = " +
+          src_tensor.ReadAsFloat4D("X", "Y", "size.w - 1", batch_id,
+                                   TextureAddressMode::DONT_CARE) +
+          ";\n";
   code += "    sum += dot(mask, exp(t));\n";
   code += "  }\n";
   code += "  for (int d = 0; d < size.w; ++d) {\n";
-  code += "    float4 t = " + read_src("X", "Y", "d") + ";\n";
+  code += "    float4 t = " +
+          src_tensor.ReadAsFloat4D("X", "Y", "d", batch_id,
+                                   TextureAddressMode::DONT_CARE) +
+          ";\n";
   code += "    t = exp(t) / sum;\n";
   code += "    FLT4 result = TO_FLT4(t);\n";
   const LinkingContext context{"result", "X", "Y", "d"};
   code += PostProcess(linked_operations, context);
-  if (op_def.batch_support) {
-    code += "    " + dst_tensor.Write4D("result", "X", "Y", "d", "B");
-  } else {
-    code += "    " + dst_tensor.Write3D("result", "X", "Y", "d");
-  }
+  code += "    " + dst_tensor.Write4D("result", "X", "Y", "d", batch_id);
   code += "  }\n";
   code += "}\n";
   return code;
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.cc b/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.cc
index 7666bb7..5070459 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.cc
@@ -33,15 +33,7 @@
   TensorCodeGenerator dst_tensor("dst_data", "tensor_size",
                                  op_def.dst_tensors[0]);
 
-  auto read_src = [&](const std::string& x, const std::string& y,
-                      const std::string& z) {
-    if (op_def.batch_support) {
-      return src_tensor.ReadAsFloat4D(x, y, z, "B");
-    } else {
-      return src_tensor.ReadAsFloat3D(x, y, z, TextureAddressMode::DONT_CARE);
-    }
-  };
-
+  const std::string batch_id = op_def.batch_support ? "batch_id" : "";
   std::string c = GetCommonDefines(op_def.precision);
   c += "__kernel void main_function(\n";
   c += src_tensor.GetDeclaration(AccessType::READ);
@@ -55,8 +47,8 @@
   c += "    float4 mask\n";
   c += ") {\n";
   if (op_def.batch_support) {
-    c += "  int B = get_global_id(1);\n";
-    c += "  if (B >= BATCH_SIZE) return;\n";
+    c += "  int batch_id = get_global_id(1);\n";
+    c += "  if (batch_id >= BATCH_SIZE) return;\n";
   }
   c += "  int offset = 0;\n";
   c += "  float sum = 0.0f;\n";
@@ -66,7 +58,10 @@
   c += "    int z = offset + tid;\n";
   c += "    if (z < size.x) {\n";
   c += "      float4 mask_temp = z == size.x - 1 ? mask : (float4)(1.0f);\n";
-  c += "      float4 src = " + read_src("0", "0", "z") + ";\n";
+  c += "      float4 src = " +
+       src_tensor.ReadAsFloat4D("0", "0", "z", batch_id,
+                                TextureAddressMode::DONT_CARE) +
+       ";\n";
   c += "      sum += dot(mask_temp, exp(src));\n";
   c += "      offset += 32;\n";
   c += "    }\n";
@@ -96,14 +91,13 @@
   c += "  do {\n";
   c += "    int z = offset + tid;\n";
   c += "    if (z < size.x) {\n";
-  c += "      FLT4 res = TO_FLT4(exp(" + read_src("0", "0", "z") + ")*sum);\n";
+  c += "      FLT4 res = TO_FLT4(exp(" +
+       src_tensor.ReadAsFloat4D("0", "0", "z", batch_id,
+                                TextureAddressMode::DONT_CARE) +
+       ")*sum);\n";
   const LinkingContext context{"res", "0", "0", "z"};
   c += PostProcess(linked_operations, context);
-  if (op_def.batch_support) {
-    c += "    " + dst_tensor.Write4D("res", "0", "0", "z", "B");
-  } else {
-    c += "    " + dst_tensor.Write3D("res", "0", "0", "z");
-  }
+  c += "    " + dst_tensor.Write4D("res", "0", "0", "z", batch_id);
   c += "      offset += 32;\n";
   c += "    }\n";
   c += "    s++;\n";
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/util.cc b/tensorflow/lite/delegates/gpu/cl/kernels/util.cc
index 47f40fa..6c0c549 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/util.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/util.cc
@@ -165,9 +165,9 @@
 std::string TensorCodeGenerator::Read4D(const std::string& x,
                                         const std::string& y,
                                         const std::string& z,
-                                        const std::string& b) const {
-  return Read(GetGlobalAddressNoDeclaration(x, y, z, b),
-              TextureAddressMode::DONT_CARE);
+                                        const std::string& b,
+                                        TextureAddressMode address_mode) const {
+  return Read(GetGlobalAddressNoDeclaration(x, y, z, b), address_mode);
 }
 
 std::string TensorCodeGenerator::ReadAsFloat3D(
@@ -219,6 +219,9 @@
 std::string TensorCodeGenerator::GetGlobalAddressNoDeclaration(
     const std::string& x, const std::string& y, const std::string& z,
     const std::string& b) const {
+  if (b.empty()) {
+    return GetGlobalAddressNoDeclaration(x, y, z);
+  }
   switch (descriptor_.storage_type) {
     case TensorStorageType::BUFFER:
     case TensorStorageType::IMAGE_BUFFER:
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/util.h b/tensorflow/lite/delegates/gpu/cl/kernels/util.h
index 85d7bad..e5ed20f 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/util.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/util.h
@@ -65,8 +65,10 @@
       TextureAddressMode address_mode = TextureAddressMode::ZERO) const;
 
   // Read4D supports BUFFER and IMAGE_BUFFER storage types.
-  std::string Read4D(const std::string& x, const std::string& y,
-                     const std::string& z, const std::string& b) const;
+  std::string Read4D(
+      const std::string& x, const std::string& y, const std::string& z,
+      const std::string& b,
+      TextureAddressMode address_mode = TextureAddressMode::ZERO) const;
 
   // Optimization for textures, so as in opencl we can use read_imagef for any
   // texture type.