Added batch support for elementwise operations.

PiperOrigin-RevId: 272974762
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/add.cc b/tensorflow/lite/delegates/gpu/cl/kernels/add.cc
index b2ca756..7a1a070 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/add.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/add.cc
@@ -26,13 +26,16 @@
 namespace cl {
 
 std::string Add::GetElementWiseCode(
-    const TensorDescriptor& src_descriptor,
-    const TensorDescriptor& dst_descriptor, CalculationsPrecision precision,
+    const OperationDef& op_def,
     const std::vector<ElementwiseOperation*>& linked_operations) {
-  TensorCodeGenerator src_tensor("src_data", "src_size", src_descriptor);
-  TensorCodeGenerator dst_tensor("dst_data", "dst_size", dst_descriptor);
+  TensorCodeGenerator src_tensor("src_data",
+                                 {"src_size.x", "src_size.y", "src_size.z"},
+                                 op_def.src_tensors[0]);
+  TensorCodeGenerator dst_tensor("dst_data",
+                                 {"dst_size.x", "dst_size.y", "dst_size.z"},
+                                 op_def.dst_tensors[0]);
 
-  std::string c = GetCommonDefines(precision);
+  std::string c = GetCommonDefines(op_def.precision);
 
   c += "__kernel void main_function(\n";
   c += src_tensor.GetDeclaration(AccessType::READ);
@@ -45,7 +48,7 @@
   c += "  int X = get_global_id(0);\n";
   c += "  int Y = get_global_id(1);\n";
   c += "  int Z = get_global_id(2);\n";
-  c += "  if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.w) { \n";
+  c += "  if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) { \n";
   c += "    return; \n";
   c += "  } \n";
   c += "  FLT4 src = (FLT4)(0.0);\n";
@@ -106,8 +109,9 @@
         absl::StrCat("src_data_", link_index_, "_", i);
     const std::string size_name =
         "src_size_" + std::to_string(link_index_) + "_" + std::to_string(i);
-    TensorCodeGenerator src_tensor(tensor_name, size_name,
-                                   definition_.src_tensors[i]);
+    TensorCodeGenerator src_tensor(
+        tensor_name, {size_name + ".x", size_name + ".y", size_name + ".z"},
+        definition_.src_tensors[i]);
     if (src_depthes_[i] != dst_depth_) {
       absl::StrAppend(&result, "  if (", context.z_coord, " < ",
                       src_depthes_[i], ") {\n");
@@ -149,15 +153,13 @@
     RETURN_IF_ERROR(kernel->SetMemoryAuto(src_[i]->GetMemoryPtr()));
   }
   for (int i = 1; i < src_depthes_.size(); ++i) {
-    RETURN_IF_ERROR(kernel->SetBytesAuto(src_[i]->GetSizeWithDepth()));
+    RETURN_IF_ERROR(kernel->SetBytesAuto(src_[i]->GetWBatchedHDB()));
   }
   return OkStatus();
 }
 
 Status Add::Compile(const CreationContext& creation_context) {
-  const auto code =
-      GetElementWiseCode(definition_.src_tensors[0], definition_.dst_tensors[0],
-                         definition_.precision, linked_operations_);
+  const auto code = GetElementWiseCode(definition_, linked_operations_);
   return creation_context.cache->GetOrCreateCLKernel(
       code, "main_function", *creation_context.context,
       *creation_context.device, &kernel_);
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/add.h b/tensorflow/lite/delegates/gpu/cl/kernels/add.h
index cad591b..ac6243c 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/add.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/add.h
@@ -51,8 +51,7 @@
 
  private:
   std::string GetElementWiseCode(
-      const TensorDescriptor& src_descriptor,
-      const TensorDescriptor& dst_descriptor, CalculationsPrecision precision,
+      const OperationDef& op_def,
       const std::vector<ElementwiseOperation*>& linked_operations);
 
   int link_index_;
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc b/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc
index a3305f8..085c4e9 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc
@@ -25,14 +25,16 @@
 namespace {
 
 std::string GetElementWiseCode(
-    const TensorDescriptor& src_descriptor,
-    const TensorDescriptor& dst_descriptor, CalculationsPrecision precision,
-    const ElementwiseOperation& op,
+    const OperationDef& op_def, const ElementwiseOperation& op,
     const std::vector<ElementwiseOperation*>& linked_operations) {
-  TensorCodeGenerator src_tensor("src_data", "src_size", src_descriptor);
-  TensorCodeGenerator dst_tensor("dst_data", "dst_size", dst_descriptor);
+  TensorCodeGenerator src_tensor("src_data",
+                                 {"src_size.x", "src_size.y", "src_size.z"},
+                                 op_def.src_tensors[0]);
+  TensorCodeGenerator dst_tensor("dst_data",
+                                 {"dst_size.x", "dst_size.y", "dst_size.z"},
+                                 op_def.dst_tensors[0]);
 
-  std::string c = GetCommonDefines(precision);
+  std::string c = GetCommonDefines(op_def.precision);
 
   c += "__kernel void main_function(\n";
   c += src_tensor.GetDeclaration(AccessType::READ);
@@ -45,7 +47,7 @@
   c += "  int X = get_global_id(0);\n";
   c += "  int Y = get_global_id(1);\n";
   c += "  int Z = get_global_id(2);\n";
-  c += "  if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.w) { \n";
+  c += "  if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) { \n";
   c += "    return; \n";
   c += "  } \n";
   c += "  FLT4 src = " +
@@ -144,22 +146,20 @@
   RETURN_IF_ERROR(BindArguments(&kernel_));
   RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
-  RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetSizeWithDepth()));
-  RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetSizeWithDepth()));
+  RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHDB()));
+  RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHDB()));
   return OkStatus();
 }
 
 int3 ElementwiseOperation::GetGridSize() const {
-  const int grid_x = dst_[0]->Width();
+  const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
   const int grid_y = dst_[0]->Height();
   const int grid_z = dst_[0]->Depth();
   return int3(grid_x, grid_y, grid_z);
 }
 
 Status ElementwiseOperation::Compile(const CreationContext& creation_context) {
-  const auto code =
-      GetElementWiseCode(definition_.src_tensors[0], definition_.dst_tensors[0],
-                         definition_.precision, *this, linked_operations_);
+  const auto code = GetElementWiseCode(definition_, *this, linked_operations_);
   return creation_context.cache->GetOrCreateCLKernel(
       code, "main_function", *creation_context.context,
       *creation_context.device, &kernel_);