Properly support tensors without automatic zero clamp in MaxUnpooling.

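Previously the generated kernel emitted an explicit out-of-bounds guard
only for BUFFER storage. GetMaxUnpoolingKernelCode now takes GpuInfo and
queries SupportsZeroClamp per axis, so any source tensor whose storage
does not clamp out-of-bounds reads to zero gets an explicit coordinate
clamp plus an in-bounds mask. For the width axis alone, the emitted
kernel code looks roughly like this (illustrative sketch, not verbatim
codegen output):

  bool inside_x = src_x >= 0 && src_x < args.src_tensor.Width();
  src_x = clamp(src_x, 0, args.src_tensor.Width() - 1);
  FLT4 src = args.src_tensor.Read(src_x, src_y, S);
  int4 ind = args.src_indices.Read<int>(src_x, src_y, S);
  src *= INIT_FLT(inside_x);  // zero the value if the read was clamped
  ind *= INIT_INT(inside_x);  // zero the gathered indices likewise

Batch handling also moves from the "BatchedWidth" state-var scheme to the
linear_id / SetBatchRef decomposition of GLOBAL_ID_0.
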
PiperOrigin-RevId: 459746642
diff --git a/tensorflow/lite/delegates/gpu/common/selectors/operation_selector.cc b/tensorflow/lite/delegates/gpu/common/selectors/operation_selector.cc
index be78e16..23adc6f 100644
--- a/tensorflow/lite/delegates/gpu/common/selectors/operation_selector.cc
+++ b/tensorflow/lite/delegates/gpu/common/selectors/operation_selector.cc
@@ -612,7 +612,7 @@
     case OperationType::MAX_UNPOOLING_2D: {
       auto attr =
           absl::any_cast<MaxUnpooling2DAttributes>(node.operation.attributes);
-      *gpu_op = SelectMaxUnpooling(attr, op_def);
+      *gpu_op = SelectMaxUnpooling(attr, gpu_info, op_def);
       return absl::OkStatus();
     }
     case OperationType::MEAN: {
diff --git a/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.cc b/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.cc
index d594c3c..f95b372 100644
--- a/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.cc
+++ b/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.cc
@@ -77,8 +77,10 @@
 }
 
 std::unique_ptr<GPUOperation> SelectMaxUnpooling(
-    const MaxUnpooling2DAttributes& attr, const OperationDef& op_def) {
-  return std::make_unique<GPUOperation>(CreateMaxUnpooling(op_def, attr));
+    const MaxUnpooling2DAttributes& attr, const GpuInfo& gpu_info,
+    const OperationDef& op_def) {
+  return std::make_unique<GPUOperation>(
+      CreateMaxUnpooling(gpu_info, op_def, attr));
 }
 
 void SelectAdd(const OperationDef& op_def, const std::vector<int>& channels,
diff --git a/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.h b/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.h
index ae2900c..e4d7c36 100644
--- a/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.h
+++ b/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.h
@@ -42,7 +42,8 @@
                                             const OperationDef& op_def);
 
 std::unique_ptr<GPUOperation> SelectMaxUnpooling(
-    const MaxUnpooling2DAttributes& attr, const OperationDef& op_def);
+    const MaxUnpooling2DAttributes& attr, const GpuInfo& gpu_info,
+    const OperationDef& op_def);
 
 void SelectAdd(const OperationDef& op_def, const std::vector<int>& channels,
                int dst_channels, std::unique_ptr<GPUOperation>* ptr);
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/BUILD b/tensorflow/lite/delegates/gpu/common/tasks/BUILD
index 52e5f8f..d53da31 100644
--- a/tensorflow/lite/delegates/gpu/common/tasks/BUILD
+++ b/tensorflow/lite/delegates/gpu/common/tasks/BUILD
@@ -638,7 +638,6 @@
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:types",
         "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
-        "//tensorflow/lite/delegates/gpu/common/task:work_group_picking",
     ],
 )
 
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/max_unpooling.cc b/tensorflow/lite/delegates/gpu/common/tasks/max_unpooling.cc
index 0c9564a..ddf74a2 100644
--- a/tensorflow/lite/delegates/gpu/common/tasks/max_unpooling.cc
+++ b/tensorflow/lite/delegates/gpu/common/tasks/max_unpooling.cc
@@ -17,38 +17,40 @@
 
 #include <string>
 
-#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
-
 namespace tflite {
 namespace gpu {
-
 namespace {
-std::string GetMaxUnpoolingKernelCode(const OperationDef& op_def,
+void AppendConditionally(const std::string& value, const std::string& delimiter,
+                         std::string* result) {
+  if (!result->empty()) {
+    *result += delimiter;
+  }
+  *result += value;
+}
+
+std::string GetMaxUnpoolingKernelCode(const GpuInfo& gpu_info,
+                                      const OperationDef& op_def,
                                       GPUOperation* op) {
-  auto src_desc = op_def.src_tensors[0];
-  if (op_def.IsBatchSupported()) {
-    src_desc.SetStateVar("BatchedWidth", "true");
-  }
-  op->AddSrcTensor("src_tensor", src_desc);
-  auto src_ind_desc = op_def.src_tensors[1];
-  if (op_def.IsBatchSupported()) {
-    src_ind_desc.SetStateVar("BatchedWidth", "true");
-  }
-  op->AddSrcTensor("src_indices", src_ind_desc);
-  auto dst_desc = op_def.dst_tensors[0];
-  if (op_def.IsBatchSupported()) {
-    dst_desc.SetStateVar("BatchedWidth", "true");
-  }
-  op->AddDstTensor("dst_tensor", dst_desc);
+  op->AddSrcTensor("src_tensor", op_def.src_tensors[0]);
+  op->AddSrcTensor("src_indices", op_def.src_tensors[1]);
+  op->AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
 
   std::string c;
   c += "MAIN_FUNCTION($0) {\n";
-  c += "  int X = GLOBAL_ID_0;\n";
+  if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
+    c += "  int linear_id = GLOBAL_ID_0;\n";
+    c += "  int X = linear_id / args.dst_tensor.Batch();\n";
+    c += "  int B = linear_id % args.dst_tensor.Batch();\n";
+    c += "  args.src_tensor.SetBatchRef(B);\n";
+    c += "  args.src_indices.SetBatchRef(B);\n";
+    c += "  args.dst_tensor.SetBatchRef(B);\n";
+  } else {
+    c += "  int X = GLOBAL_ID_0;\n";
+  }
   if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
     c += "  int linear_id_1 = GLOBAL_ID_1;\n";
     c += "  int Y = linear_id_1 / args.dst_tensor.Depth();\n";
     c += "  int Z = linear_id_1 % args.dst_tensor.Depth();\n";
-    c += "  int src_z = (Z + args.padding_z) / args.stride_z;\n";
   } else {
     c += "  int Y = GLOBAL_ID_1;\n";
   }
@@ -57,72 +59,66 @@
        "S >= args.dst_tensor.Slices()) { \n";
   c += "    return; \n";
   c += "  } \n";
-  if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
-    c += "  int linear_id_0 = GLOBAL_ID_0;\n";
-    c += "  int X0 = linear_id_0 / args.dst_tensor.Batch();\n";
-    c += "  int B = linear_id_0 % args.dst_tensor.Batch();\n";
-    c += "  int src_x0 = (X0 + args.padding_x * args.dst_tensor.Batch()) / "
-         "args.stride_x;\n";
-    c += "  int src_x = src_x0 * args.dst_tensor.Batch() + B;\n";
-  } else {
-    c += "  int src_x = (X + args.padding_x) / args.stride_x;\n";
-  }
+  c += "  int src_x = (X + args.padding_x) / args.stride_x;\n";
+  c += "  int t_x = X - (src_x * args.stride_x - args.padding_x);\n";
   c += "  int src_y = (Y + args.padding_y) / args.stride_y;\n";
-  std::string src_args = op_def.dst_tensors[0].HasAxis(Axis::DEPTH)
-                             ? "src_x, src_y, src_z, S"
-                             : "src_x, src_y, S";
-  if (op_def.src_tensors[0].GetStorageType() == TensorStorageType::BUFFER) {
-    if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
-      c += "  bool outside = src_x < 0 || src_y < 0 || src_z < 0 || src_x >= "
-           "args.src_tensor.Width() || src_y >= args.src_tensor.Height() || "
-           "src_z >= args.src_tensor.Depth();\n";
-    } else {
-      c += "  bool outside = src_x < 0 || src_y < 0 || src_x >= "
-           "args.src_tensor.Width() || src_y >= args.src_tensor.Height();\n";
-    }
-    c += "  FLT4 src = INIT_FLT4(0.0f);\n";
-    c += "  int4 ind = INIT_INT4v4(0, 0, 0, 0);\n";
-    c += "  if (!outside) {\n";
-    c += "    src = args.src_tensor.Read(" + src_args + ");\n";
-    c += "    ind = args.src_indices.Read<int>(" + src_args + ");\n";
-    c += "  }\n";
-  } else {
-    c += "  FLT4 src = args.src_tensor.Read(" + src_args + ");\n";
-    c += "  int4 ind = args.src_indices.Read<int>(" + src_args + ");\n";
-  }
-  if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
-    c += "  int t_x = X0 - (src_x0 * args.stride_x - args.padding_x * "
-         "args.dst_tensor.Batch());\n";
-  } else {
-    c += "  int t_x = X - (src_x * args.stride_x - args.padding_x);\n";
-  }
   c += "  int t_y = Y - (src_y * args.stride_y - args.padding_y);\n";
   if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
+    c += "  int src_z = (Z + args.padding_z) / args.stride_z;\n";
     c += "  int t_z = Z - (src_z * args.stride_z - args.padding_z);\n";
     c += "  int t_index = (t_y * args.kernel_size_x + t_x) * "
          "args.kernel_size_z + t_z;\n";
   } else {
     c += "  int t_index = t_y * args.kernel_size_x + t_x;\n";
   }
-  c += "  FLT4 result;\n";
-  const std::string channels[] = {".x", ".y", ".z", ".w"};
-  for (int i = 0; i < 4; ++i) {
-    const auto& s = channels[i];
-    c += "  result" + s + "= t_index == ind" + s + "? src" + s +
-         ": INIT_FLT(0.0f);\n";
+  std::string inbounds_check;
+  if (!op_def.src_tensors[0].SupportsZeroClamp(Axis::WIDTH, gpu_info) ||
+      !op_def.src_tensors[1].SupportsZeroClamp(Axis::WIDTH, gpu_info)) {
+    c += "  bool inside_x = src_x >= 0 && src_x < args.src_tensor.Width();\n";
+    c += "  src_x = clamp(src_x, 0, args.src_tensor.Width() - 1);\n";
+    AppendConditionally("inside_x", " && ", &inbounds_check);
   }
+  if (!op_def.src_tensors[0].SupportsZeroClamp(Axis::HEIGHT, gpu_info) ||
+      !op_def.src_tensors[1].SupportsZeroClamp(Axis::HEIGHT, gpu_info)) {
+    c += "  bool inside_y = src_y >= 0 && src_y < args.src_tensor.Height();\n";
+    c += "  src_y = clamp(src_y, 0, args.src_tensor.Height() - 1);\n";
+    AppendConditionally("inside_y", " && ", &inbounds_check);
+  }
+  if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
+    if (!op_def.src_tensors[0].SupportsZeroClamp(Axis::DEPTH, gpu_info) ||
+        !op_def.src_tensors[1].SupportsZeroClamp(Axis::DEPTH, gpu_info)) {
+      c += "  bool inside_z = src_z >= 0 && src_z < args.src_tensor.Depth();\n";
+      c += "  src_z = clamp(src_z, 0, args.src_tensor.Depth() - 1);\n";
+      AppendConditionally("inside_z", " && ", &inbounds_check);
+    }
+  }
+  std::string src_args = op_def.dst_tensors[0].HasAxis(Axis::DEPTH)
+                             ? "src_x, src_y, src_z, S"
+                             : "src_x, src_y, S";
+  c +=
+      "  args.src_tensor::type src = args.src_tensor.Read(" + src_args + ");\n";
+  c += "  int4 ind = args.src_indices.Read<int>(" + src_args + ");\n";
+  if (!inbounds_check.empty()) {
+    c += "  src *= INIT_FLT(" + inbounds_check + ");\n";
+    c += "  ind *= INIT_INT(" + inbounds_check + ");\n";
+  }
+  c += "  args.src_tensor::type result;\n";
+  c += "  result.x = t_index == ind.x ? src.x : INIT_FLT(0.0f);\n";
+  c += "  result.y = t_index == ind.y ? src.y : INIT_FLT(0.0f);\n";
+  c += "  result.z = t_index == ind.z ? src.z : INIT_FLT(0.0f);\n";
+  c += "  result.w = t_index == ind.w ? src.w : INIT_FLT(0.0f);\n";
   if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
     c += "  args.dst_tensor.Write(result, X, Y, Z, S);\n";
   } else {
     c += "  args.dst_tensor.Write(result, X, Y, S);\n";
   }
   c += "}\n";
-
   return c;
 }
 }  // namespace
 
-GPUOperation CreateMaxUnpooling(const OperationDef& definition,
+GPUOperation CreateMaxUnpooling(const GpuInfo& gpu_info,
+                                const OperationDef& definition,
                                 const MaxUnpooling2DAttributes& attr) {
   GPUOperation op(definition);
   op.args_.AddInt("kernel_size_x", attr.kernel.w);
@@ -131,12 +127,13 @@
   op.args_.AddInt("kernel_size_y", attr.kernel.h);
   op.args_.AddInt("padding_y", attr.padding.appended.h);
   op.args_.AddInt("stride_y", attr.strides.h);
-  op.code_ = GetMaxUnpoolingKernelCode(definition, &op);
+  op.code_ = GetMaxUnpoolingKernelCode(gpu_info, definition, &op);
   op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;
   return op;
 }
 
-GPUOperation CreateMaxUnpooling(const OperationDef& definition,
+GPUOperation CreateMaxUnpooling(const GpuInfo& gpu_info,
+                                const OperationDef& definition,
                                 const MaxUnpooling3DAttributes& attr) {
   GPUOperation op(definition);
   op.args_.AddInt("kernel_size_x", attr.kernel.w);
@@ -148,7 +145,7 @@
   op.args_.AddInt("kernel_size_z", attr.kernel.d);
   op.args_.AddInt("padding_z", attr.padding.appended.d);
   op.args_.AddInt("stride_z", attr.strides.d);
-  op.code_ = GetMaxUnpoolingKernelCode(definition, &op);
+  op.code_ = GetMaxUnpoolingKernelCode(gpu_info, definition, &op);
   op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;
   return op;
 }
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/max_unpooling.h b/tensorflow/lite/delegates/gpu/common/tasks/max_unpooling.h
index 6e90d37..d6b0bd6 100644
--- a/tensorflow/lite/delegates/gpu/common/tasks/max_unpooling.h
+++ b/tensorflow/lite/delegates/gpu/common/tasks/max_unpooling.h
@@ -24,10 +24,12 @@
 namespace tflite {
 namespace gpu {
 
-GPUOperation CreateMaxUnpooling(const OperationDef& definition,
+GPUOperation CreateMaxUnpooling(const GpuInfo& gpu_info,
+                                const OperationDef& definition,
                                 const MaxUnpooling2DAttributes& attr);
 
-GPUOperation CreateMaxUnpooling(const OperationDef& definition,
+GPUOperation CreateMaxUnpooling(const GpuInfo& gpu_info,
+                                const OperationDef& definition,
                                 const MaxUnpooling3DAttributes& attr);
 
 }  // namespace gpu
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/max_unpooling_test_util.cc b/tensorflow/lite/delegates/gpu/common/tasks/max_unpooling_test_util.cc
index d70ef1a..43d2fdd 100644
--- a/tensorflow/lite/delegates/gpu/common/tasks/max_unpooling_test_util.cc
+++ b/tensorflow/lite/delegates/gpu/common/tasks/max_unpooling_test_util.cc
@@ -50,7 +50,8 @@
       op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
       op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
       TensorFloat32 dst_tensor;
-      GPUOperation operation = CreateMaxUnpooling(op_def, attr);
+      GPUOperation operation =
+          CreateMaxUnpooling(env->GetGpuInfo(), op_def, attr);
       RETURN_IF_ERROR(env->ExecuteGPUOperation(
           {src_tensor, src_ind_tensor},
           std::make_unique<GPUOperation>(std::move(operation)),
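
Caller-side, the only visible change is the extra GpuInfo argument. A
minimal sketch of building the op with the new signature (the attribute
field names mirror the CreateMaxUnpooling body above; the HW(h, w)
initializers and the env test environment are assumptions borrowed from
the surrounding test harness, not part of this change):

  // Hypothetical 2x2, stride-2 unpooling attributes.
  MaxUnpooling2DAttributes attr;
  attr.kernel = HW(2, 2);
  attr.strides = HW(2, 2);
  attr.padding.prepended = HW(0, 0);
  attr.padding.appended = HW(0, 0);
  // GpuInfo is now required so codegen can query SupportsZeroClamp.
  GPUOperation operation =
      CreateMaxUnpooling(env->GetGpuInfo(), op_def, attr);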