MaxUnpooling converted to generic GPUOperation.

PiperOrigin-RevId: 328193073
Change-Id: Idf41aafe4b095dd2a6ea08fced44e04b3a67566c
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.cc b/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.cc
index 97ee487..0bea5e4 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.cc
@@ -23,76 +23,26 @@
 namespace tflite {
 namespace gpu {
 namespace cl {
-
-MaxUnpooling::MaxUnpooling(const OperationDef& definition,
-                           const MaxUnpooling2DAttributes& attr)
-    : GPUOperation(definition),
-      stride_(attr.strides.w, attr.strides.h, 0, 0),
-      padding_(attr.padding.appended.w, attr.padding.appended.h, 0, 0),
-      kernel_size_(attr.kernel.w, attr.kernel.h, 0, 0) {
-  code_ = GetMaxUnpoolingKernelCode(definition_);
-}
-
-MaxUnpooling::MaxUnpooling(const OperationDef& definition,
-                           const MaxUnpooling3DAttributes& attr)
-    : GPUOperation(definition),
-      stride_(attr.strides.w, attr.strides.h, attr.strides.d, 0),
-      padding_(attr.padding.appended.w, attr.padding.appended.h,
-               attr.padding.appended.d, 0),
-      kernel_size_(attr.kernel.w, attr.kernel.h, attr.kernel.d, 0) {
-  code_ = GetMaxUnpoolingKernelCode(definition_);
-}
-
-MaxUnpooling::MaxUnpooling(MaxUnpooling&& kernel)
-    : GPUOperation(std::move(kernel)),
-      stride_(kernel.stride_),
-      padding_(kernel.padding_),
-      kernel_size_(kernel.kernel_size_) {}
-
-MaxUnpooling& MaxUnpooling::operator=(MaxUnpooling&& kernel) {
-  if (this != &kernel) {
-    std::swap(stride_, kernel.stride_);
-    std::swap(padding_, kernel.padding_);
-    std::swap(kernel_size_, kernel.kernel_size_);
-    GPUOperation::operator=(std::move(kernel));
-  }
-  return *this;
-}
-
-std::string MaxUnpooling::GetMaxUnpoolingKernelCode(
-    const OperationDef& op_def) {
+namespace {
+std::string GetMaxUnpoolingKernelCode(const OperationDef& op_def,
+                                      GPUOperation* op) {
   auto src_desc = op_def.src_tensors[0];
   src_desc.SetTextureAddressMode(TextureAddressMode::ZERO);
   if (op_def.IsBatchSupported()) {
     src_desc.SetStateVar("BatchedWidth", "true");
   }
-  AddSrcTensor("src_tensor", src_desc);
+  op->AddSrcTensor("src_tensor", src_desc);
   auto src_ind_desc = op_def.src_tensors[1];
   src_ind_desc.SetTextureAddressMode(TextureAddressMode::ZERO);
   if (op_def.IsBatchSupported()) {
     src_ind_desc.SetStateVar("BatchedWidth", "true");
   }
-  AddSrcTensor("src_indices", src_ind_desc);
+  op->AddSrcTensor("src_indices", src_ind_desc);
   auto dst_desc = op_def.dst_tensors[0];
   if (op_def.IsBatchSupported()) {
     dst_desc.SetStateVar("BatchedWidth", "true");
   }
-  AddDstTensor("dst_tensor", dst_desc);
-  if (op_def.dst_tensors[0].HasAxis(Axis::WIDTH)) {
-    args_.AddInt("kernel_size_x");
-    args_.AddInt("padding_x");
-    args_.AddInt("stride_x");
-  }
-  if (op_def.dst_tensors[0].HasAxis(Axis::HEIGHT)) {
-    args_.AddInt("kernel_size_y");
-    args_.AddInt("padding_y");
-    args_.AddInt("stride_y");
-  }
-  if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
-    args_.AddInt("kernel_size_z");
-    args_.AddInt("padding_z");
-    args_.AddInt("stride_z");
-  }
+  op->AddDstTensor("dst_tensor", dst_desc);
 
   std::string c = GetCommonDefines(op_def.precision);
   c += "__kernel void main_function(\n";
@@ -115,7 +65,8 @@
     c += "  int linear_id_0 = get_global_id(0);\n";
     c += "  int X0 = linear_id_0 / args.dst_tensor.Batch();\n";
     c += "  int B = linear_id_0 % args.dst_tensor.Batch();\n";
-    c += "  int src_x0 = (X0 + args.padding_x) / args.stride_x;\n";
+    c += "  int src_x0 = (X0 + args.padding_x * args.dst_tensor.Batch()) / "
+         "args.stride_x;\n";
     c += "  int src_x = src_x0 * args.dst_tensor.Batch() + B;\n";
   } else {
     c += "  int src_x = (X + args.padding_x) / args.stride_x;\n";
@@ -145,7 +96,8 @@
         "  int4 ind = convert_int4(args.src_indices.Read(" + src_args + "));\n";
   }
   if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
-    c += "  int t_x = X0 - (src_x0 * args.stride_x - args.padding_x);\n";
+    c += "  int t_x = X0 - (src_x0 * args.stride_x - args.padding_x * "
+         "args.dst_tensor.Batch());\n";
   } else {
     c += "  int t_x = X - (src_x * args.stride_x - args.padding_x);\n";
   }
@@ -172,41 +124,37 @@
 
   return c;
 }
+}  // namespace
 
-absl::Status MaxUnpooling::BindArguments() {
-  if (definition_.dst_tensors[0].HasAxis(Axis::WIDTH)) {
-    RETURN_IF_ERROR(args_.SetInt("stride_x", stride_.x));
-    RETURN_IF_ERROR(args_.SetInt("padding_x", padding_.x * src_[0]->Batch()));
-    RETURN_IF_ERROR(args_.SetInt("kernel_size_x", kernel_size_.x));
-  }
-  if (definition_.dst_tensors[0].HasAxis(Axis::HEIGHT)) {
-    RETURN_IF_ERROR(args_.SetInt("stride_y", stride_.y));
-    RETURN_IF_ERROR(args_.SetInt("padding_y", padding_.y));
-    RETURN_IF_ERROR(args_.SetInt("kernel_size_y", kernel_size_.y));
-  }
-  if (definition_.dst_tensors[0].HasAxis(Axis::DEPTH)) {
-    RETURN_IF_ERROR(args_.SetInt("stride_z", stride_.z));
-    RETURN_IF_ERROR(args_.SetInt("padding_z", padding_.z));
-    RETURN_IF_ERROR(args_.SetInt("kernel_size_z", kernel_size_.z));
-  }
-  return absl::OkStatus();
-}
-
-int3 MaxUnpooling::GetGridSize() const {
-  const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
-  const int grid_y = dst_[0]->Height() * dst_[0]->Depth();
-  const int grid_z = dst_[0]->Slices();
-  return int3(grid_x, grid_y, grid_z);
-}
-
-MaxUnpooling CreateMaxUnpooling(const OperationDef& definition,
+GPUOperation CreateMaxUnpooling(const OperationDef& definition,
                                 const MaxUnpooling2DAttributes& attr) {
-  return MaxUnpooling(definition, attr);
+  GPUOperation op(definition);
+  op.args_.AddInt("kernel_size_x", attr.kernel.w);
+  op.args_.AddInt("padding_x", attr.padding.appended.w);
+  op.args_.AddInt("stride_x", attr.strides.w);
+  op.args_.AddInt("kernel_size_y", attr.kernel.h);
+  op.args_.AddInt("padding_y", attr.padding.appended.h);
+  op.args_.AddInt("stride_y", attr.strides.h);
+  op.code_ = GetMaxUnpoolingKernelCode(definition, &op);
+  op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;
+  return op;
 }
 
-MaxUnpooling CreateMaxUnpooling(const OperationDef& definition,
+GPUOperation CreateMaxUnpooling(const OperationDef& definition,
                                 const MaxUnpooling3DAttributes& attr) {
-  return MaxUnpooling(definition, attr);
+  GPUOperation op(definition);
+  op.args_.AddInt("kernel_size_x", attr.kernel.w);
+  op.args_.AddInt("padding_x", attr.padding.appended.w);
+  op.args_.AddInt("stride_x", attr.strides.w);
+  op.args_.AddInt("kernel_size_y", attr.kernel.h);
+  op.args_.AddInt("padding_y", attr.padding.appended.h);
+  op.args_.AddInt("stride_y", attr.strides.h);
+  op.args_.AddInt("kernel_size_z", attr.kernel.d);
+  op.args_.AddInt("padding_z", attr.padding.appended.d);
+  op.args_.AddInt("stride_z", attr.strides.d);
+  op.code_ = GetMaxUnpoolingKernelCode(definition, &op);
+  op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;
+  return op;
 }
 
 }  // namespace cl
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.h b/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.h
index 0b1420a..c1b6cbf 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.h
@@ -25,34 +25,10 @@
 namespace gpu {
 namespace cl {
 
-class MaxUnpooling : public GPUOperation {
- public:
-  MaxUnpooling(const OperationDef& definition,
-               const MaxUnpooling2DAttributes& attr);
-  MaxUnpooling(const OperationDef& definition,
-               const MaxUnpooling3DAttributes& attr);
-
-  absl::Status BindArguments() override;
-  int3 GetGridSize() const override;
-
-  // Move only
-  MaxUnpooling(MaxUnpooling&& kernel);
-  MaxUnpooling& operator=(MaxUnpooling&& kernel);
-  MaxUnpooling(const MaxUnpooling&) = delete;
-  MaxUnpooling& operator=(const MaxUnpooling&) = delete;
-
- private:
-  std::string GetMaxUnpoolingKernelCode(const OperationDef& op_def);
-
-  int4 stride_;
-  int4 padding_;
-  int4 kernel_size_;
-};
-
-MaxUnpooling CreateMaxUnpooling(const OperationDef& definition,
+GPUOperation CreateMaxUnpooling(const OperationDef& definition,
                                 const MaxUnpooling2DAttributes& attr);
 
-MaxUnpooling CreateMaxUnpooling(const OperationDef& definition,
+GPUOperation CreateMaxUnpooling(const OperationDef& definition,
                                 const MaxUnpooling3DAttributes& attr);
 
 }  // namespace cl
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling_test.cc
index c03cb4f..654b389 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling_test.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling_test.cc
@@ -55,7 +55,7 @@
       op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
       op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
       TensorFloat32 dst_tensor;
-      MaxUnpooling operation = CreateMaxUnpooling(op_def, attr);
+      GPUOperation operation = CreateMaxUnpooling(op_def, attr);
       ASSERT_OK(ExecuteGPUOperation({src_tensor, src_ind_tensor},
                                     creation_context_, &operation,
                                     BHWC(1, 4, 4, 1), &dst_tensor));
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc
index daa052e..497bb85 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc
@@ -252,7 +252,7 @@
     case OperationType::MAX_UNPOOLING_2D: {
       auto attr =
           absl::any_cast<MaxUnpooling2DAttributes>(node.operation.attributes);
-      SelectMaxUnpooling(attr, op_def, gpu_op);
+      *gpu_op = SelectMaxUnpooling(attr, op_def);
       return absl::OkStatus();
     }
     case OperationType::MEAN: {
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc
index db76a0c..4baf8e7 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc
@@ -68,11 +68,9 @@
   *ptr = absl::make_unique<Pooling>(std::move(pooling));
 }
 
-void SelectMaxUnpooling(const MaxUnpooling2DAttributes& attr,
-                        const OperationDef& op_def,
-                        std::unique_ptr<GPUOperation>* ptr) {
-  MaxUnpooling operation = CreateMaxUnpooling(op_def, attr);
-  *ptr = absl::make_unique<MaxUnpooling>(std::move(operation));
+std::unique_ptr<GPUOperation> SelectMaxUnpooling(
+    const MaxUnpooling2DAttributes& attr, const OperationDef& op_def) {
+  return absl::make_unique<GPUOperation>(CreateMaxUnpooling(op_def, attr));
 }
 
 void SelectAdd(const OperationDef& op_def, const std::vector<int>& channels,
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h
index 6e91a4e..efbc305 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h
@@ -41,9 +41,8 @@
 void SelectPooling(const Pooling2DAttributes& attr, const OperationDef& op_def,
                    std::unique_ptr<GPUOperation>* ptr);
 
-void SelectMaxUnpooling(const MaxUnpooling2DAttributes& attr,
-                        const OperationDef& op_def,
-                        std::unique_ptr<GPUOperation>* ptr);
+std::unique_ptr<GPUOperation> SelectMaxUnpooling(
+    const MaxUnpooling2DAttributes& attr, const OperationDef& op_def);
 
 void SelectAdd(const OperationDef& op_def, const std::vector<int>& channels,
                int dst_channels, std::unique_ptr<GPUOperation>* ptr);