Move some custom ops back to model builder.

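Re-adds parsers for the "Convolution2DTransposeBias",
"MaxPoolingWithArgmax2D", and "MaxUnpooling2D" custom ops to the GPU
delegate's model builder, dispatching on registration->custom_name.
These ops pass their parameters as a raw params struct in the node's
custom options; the new RetrieveCustomInitialData() helper casts them
back from custom_initial_data.

Illustrative sketch (not part of this change) of the producer side
this assumes; the helper name and field values are examples only:

  #include "tensorflow/lite/c/builtin_op_data.h"

  // Hypothetical producer-side helper: the raw bytes of the returned
  // struct become the op's custom_options in the flatbuffer, so at
  // runtime tflite_node->custom_initial_data points at them and
  // RetrieveCustomInitialData() can cast the pointer back.
  TfLitePoolParams MakeArgmaxPoolParams() {
    TfLitePoolParams params{};
    params.padding = kTfLitePaddingSame;
    params.filter_height = 2;
    params.filter_width = 2;
    params.stride_height = 2;
    params.stride_width = 2;
    return params;
  }
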
PiperOrigin-RevId: 365116594
Change-Id: Ieea3ad01cc6518c5ccd1cb971ef941864fc9db8b
diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc
index af152a5..a9f9bbf 100644
--- a/tensorflow/lite/delegates/gpu/common/model_builder.cc
+++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc
@@ -82,6 +82,18 @@
   return absl::OkStatus();
 }
 
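+// Casts the node's custom_initial_data to the expected params type; these
+// custom ops pass a raw params struct via their custom options.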
+template <typename ParamsT>
+absl::Status RetrieveCustomInitialData(const TfLiteNode* tflite_node,
+                                       const ParamsT** tf_options) {
+  *tf_options = static_cast<const ParamsT*>(tflite_node->custom_initial_data);
+  if (!*tf_options) {
+    return absl::InternalError("Unable to retrieve custom_initial_data.");
+  }
+  return absl::OkStatus();
+}
+
 absl::Status CheckDilation(int dilation_h, int dilation_w) {
   if (dilation_h <= 0 || dilation_w <= 0) {
     return absl::InvalidArgumentError(absl::StrCat(
@@ -1360,10 +1372,17 @@
                            const TfLiteRegistration* registration) final {
     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
     const TfLitePoolParams* tf_options;
-    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
-    RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
-                                       /*runtime_inputs=*/1,
-                                       /*outputs=*/1));
+    auto status = RetrieveCustomInitialData(tflite_node, &tf_options);
+    if (status.ok()) {  // custom case with indices as a second output
+      RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
+                                         /*runtime_inputs=*/1,
+                                         /*outputs=*/2));
+    } else {  // common pooling with 1 output
+      RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
+      RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
+                                         /*runtime_inputs=*/1,
+                                         /*outputs=*/1));
+    }
     RETURN_IF_ERROR(CheckKernelsAndStrides(
         tf_options->filter_height, tf_options->filter_width,
         tf_options->stride_height, tf_options->stride_width));
@@ -1386,12 +1405,28 @@
 
     auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
 
+    // Check whether custom options are encoded; this is the case when the
+    // operation is MaxPoolingWithArgmax2D. There is no way to read
+    // tflite_node->builtin_code here, so simply check whether custom data
+    // is available.
     const TfLitePoolParams* tf_options;
-    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
+    if (!RetrieveCustomInitialData(tflite_node, &tf_options).ok()) {
+      RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
+    }
 
     RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node));
+    // The second output is optional; if it is present, it must be added
+    // after MaybeFuseActivation() is called.
+    reader->AddOutput(node, 1).IgnoreError();
 
-    attr.output_indices = false;
+    // The first output is the result of the pooling operation; the second
+    // output holds the indices used for pooling.
+    auto outputs = graph->FindOutputs(node->id);
+    attr.output_indices = outputs.size() == 2;
+    if (attr.output_indices) {
+      // Fix the output indices' data type; the model stores it as float32.
+      outputs[1]->tensor.type = DataType::INT32;
+    }
     RETURN_IF_ERROR(ParsePoolingAttributes(tf_options, input_shape, &attr));
     node->operation.attributes = attr;
     return absl::OkStatus();
@@ -2172,6 +2207,45 @@
   }
 };
 
+// Custom-op version of TRANSPOSE_CONV; parses "Convolution2DTransposeBias".
+class TransposeConvCustomOperationParser : public TFLiteOperationParser {
+ public:
+  absl::Status IsSupported(const TfLiteContext* context,
+                           const TfLiteNode* tflite_node,
+                           const TfLiteRegistration* registration) final {
+    RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
+    const TfLiteTransposeConvParams* tf_options;
+    RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options));
+    RETURN_IF_ERROR(
+        CheckStrides(tf_options->stride_height, tf_options->stride_width));
+    return absl::OkStatus();
+  }
+
+  absl::Status Parse(const TfLiteNode* tflite_node,
+                     const TfLiteRegistration* registration,
+                     GraphFloat32* graph, ObjectReader* reader) final {
+    auto* node = graph->NewNode();
+    node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED);
+    RETURN_IF_ERROR(reader->AddInput(node, 0));
+    RETURN_IF_ERROR(reader->AddOutputs(node));
+
+    const TfLiteTransposeConvParams* tf_options;
+    auto status = RetrieveCustomInitialData(tflite_node, &tf_options);
+
+    ConvolutionTransposedAttributes attr;
+    attr.stride = status.ok()
+                      ? HW(tf_options->stride_height, tf_options->stride_width)
+                      : HW(1, 1);
+    RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
+    reader->ReadTensor(2, &attr.bias).IgnoreError();  // bias is optional
+
+    UpdatePadding(status.ok() ? tf_options->padding : kTfLitePaddingUnknown,
+                  graph->FindInputs(node->id)[0]->tensor.shape, &attr);
+    node->operation.attributes = std::move(attr);
+    return absl::OkStatus();
+  }
+};
+
 class TransposeOperationParser : public TFLiteOperationParser {
  public:
   absl::Status IsSupported(const TfLiteContext* context,
@@ -2223,6 +2297,48 @@
   }
 };
 
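+// Parses the "MaxUnpooling2D" custom op into MAX_UNPOOLING_2D.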
+class Unpooling2DOperationParser : public TFLiteOperationParser {
+ public:
+  absl::Status IsSupported(const TfLiteContext* context,
+                           const TfLiteNode* tflite_node,
+                           const TfLiteRegistration* registration) final {
+    RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
+                                       /*runtime_inputs=*/2, /*outputs=*/1));
+    const TfLitePoolParams* tf_options;
+    RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options));
+    RETURN_IF_ERROR(CheckKernelsAndStrides(
+        tf_options->filter_height, tf_options->filter_width,
+        tf_options->stride_height, tf_options->stride_width));
+    return absl::OkStatus();
+  }
+
+  absl::Status Parse(const TfLiteNode* tflite_node,
+                     const TfLiteRegistration* registration,
+                     GraphFloat32* graph, ObjectReader* reader) final {
+    Node* node = graph->NewNode();
+    node->operation.type = ToString(OperationType::MAX_UNPOOLING_2D);
+    RETURN_IF_ERROR(reader->AddInput(node, 0));
+    RETURN_IF_ERROR(reader->AddInput(node, 1));
+    RETURN_IF_ERROR(reader->AddOutputs(node));
+    auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
+    MaxUnpooling2DAttributes attr;
+
+    const TfLitePoolParams* tf_options;
+    RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options));
+
+    attr.kernel = ToHW(tf_options->filter_height, tf_options->filter_width);
+    attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width);
+    UpdatePadding(tf_options->padding, input_shape, &attr);
+
+    node->operation.attributes = attr;
+
+    auto output_value = graph->FindOutputs(node->id)[0];
+    output_value->tensor.shape = CalculateOutputShape(input_shape, attr);
+    return absl::OkStatus();
+  }
+};
+
 // TODO(impjdi): BATCH_TO_SPACE/SPACE_TO_BATCH shouldn't be supported.
 class BatchToSpaceOperationParser : public TFLiteOperationParser {
  public:
@@ -2496,8 +2612,19 @@
       return std::make_unique<TransposeOperationParser>();
     case kTfLiteBuiltinTransposeConv:
       return std::make_unique<TransposeConvBuiltinOperationParser>();
-    case kTfLiteBuiltinCustom:
+    case kTfLiteBuiltinCustom: {
+      const absl::string_view custom_name = registration->custom_name;
+      if (custom_name == "Convolution2DTransposeBias") {
+        return std::make_unique<TransposeConvCustomOperationParser>();
+      }
+      if (custom_name == "MaxPoolingWithArgmax2D") {
+        return std::make_unique<Pooling2DOperationParser>(PoolingType::MAX);
+      }
+      if (custom_name == "MaxUnpooling2D") {
+        return std::make_unique<Unpooling2DOperationParser>();
+      }
       return NewCustomOperationParser(registration->custom_name);
+    }
   }
   return std::make_unique<UnsupportedOperationParser>();
 }