Add remaining CheckGpuDelegateCompatibility() logic

Now CheckGpuDelegateCompatibility() contains all of the GPU delegate compatibility logic from model_builder.cc.

PiperOrigin-RevId: 381770247
Change-Id: Ie560c934a402eb4fb190e40a0af120658a31fb30
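
For reference, a minimal sketch of how the checker is meant to be driven. The
caller below is illustrative only: GetOpSignature() stands in for the
op_signature.cc helper that builds the OpSignature (and, with this change,
also carries custom_initial_data), and the surrounding variables and logging
are assumptions.

    // Build the signature for a node, then ask whether the GPU delegate
    // can take it; a non-OK status means the op stays on CPU.
    OpSignature op_sig = GetOpSignature(context, node, registration);
    absl::Status status = CheckGpuDelegateCompatibility(op_sig);
    if (!status.ok()) {
      LOG(INFO) << "GPU delegate incompatible: " << status.message();
    }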
diff --git a/tensorflow/lite/tools/versioning/gpu_compatibility.cc b/tensorflow/lite/tools/versioning/gpu_compatibility.cc
index badca8f..a67f68f 100644
--- a/tensorflow/lite/tools/versioning/gpu_compatibility.cc
+++ b/tensorflow/lite/tools/versioning/gpu_compatibility.cc
@@ -14,6 +14,8 @@
 ==============================================================================*/
 #include "tensorflow/lite/tools/versioning/gpu_compatibility.h"
 
+#include <string>
+
 #include "absl/status/status.h"
 #include "absl/strings/str_cat.h"
 #include "tensorflow/core/platform/logging.h"
@@ -25,6 +27,13 @@
 
 namespace {
 
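+// Returns the name of the op: the custom op name for custom ops, otherwise
+// the builtin operator enum name.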
+std::string GetOpName(const OpSignature& op_sig) {
+  if (op_sig.op == tflite::BuiltinOperator_CUSTOM) {
+    return op_sig.custom_name;
+  }
+  return tflite::EnumNamesBuiltinOperator()[op_sig.op];
+}
+
 // Helper functions from
 // tensorflow/lite/delegates/gpu/common/model_builder_helper.cc
 
@@ -44,6 +53,16 @@
   return absl::OkStatus();
 }
 
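+// Retrieves custom op params of type ParamsT from
+// OpSignature::custom_initial_data, failing if no custom data was attached.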
+template <typename ParamsT>
+absl::Status RetrieveCustomInitialData(const OpSignature& op_sig,
+                                       const ParamsT** tf_options) {
+  *tf_options = static_cast<const ParamsT*>(op_sig.custom_initial_data);
+  if (!*tf_options) {
+    return absl::InternalError("Unable to retrieve custom_initial_data.");
+  }
+  return absl::OkStatus();
+}
+
 absl::Status IsActivationSupported(TfLiteFusedActivation fused_activation) {
   switch (fused_activation) {
     case kTfLiteActNone:
@@ -129,6 +148,28 @@
   return absl::OkStatus();
 }
 
+// Checks if the given OpSignature has the required number of inputs and
+// outputs for convolution operators: either 2 runtime inputs, or 1 runtime
+// and 1 constant input, and exactly one output.
+absl::Status CheckConvolutionInputOutput(const OpSignature& op_sig) {
+  const int runtime_inputs = GetNumberOfRuntimeInputs(op_sig);
+  if (runtime_inputs > 2) {
+    return absl::InternalError(
+        absl::StrCat("Expected 1 or 2 input tensor(s), but node has ",
+                     runtime_inputs, " runtime inputs."));
+  }
+  const int runtime_outputs = op_sig.outputs.size();
+  if (runtime_outputs != 1) {
+    return absl::InternalError(
+        absl::StrCat("Expected 1 output tensor(s), but node has ",
+                     runtime_outputs, " runtime outputs."));
+  }
+  if (runtime_inputs == 1) {
+    RETURN_IF_ERROR(CheckTensorIsAvailable(op_sig, 1));
+  }
+  return absl::OkStatus();
+}
+
 absl::Status CheckStrides(int strides_h, int strides_w) {
   if (strides_h <= 0 || strides_w <= 0) {
     return absl::InvalidArgumentError(
@@ -170,11 +211,26 @@
   return absl::OkStatus();
 }
 
+// Checks if the axes tensor at the given index is an int32 constant tensor.
+absl::Status CheckAxesAreInt32Const(const OpSignature& op_sig, int idx) {
+  auto axes = op_sig.inputs.at(idx);
+  if (!axes.is_const) {
+    return absl::UnimplementedError(GetOpName(op_sig) +
+                                    " is only supported with constant axes.");
+  }
+  if (axes.type != kTfLiteInt32) {
+    return absl::UnimplementedError(absl::StrCat(
+        GetOpName(op_sig), " supports int32 tensors for axes, but node has ",
+        TfLiteTypeGetName(axes.type)));
+  }
+  return absl::OkStatus();
+}
+
 absl::Status CheckPooling2DGpuDelegateCompatibility(const OpSignature& op_sig) {
   const TfLitePoolParams* tf_options;
   if (op_sig.custom_initial_data) {  // custom case with indices as a second
                                      // output
-    tf_options = static_cast<TfLitePoolParams*>(op_sig.custom_initial_data);
+    RETURN_IF_ERROR(RetrieveCustomInitialData(op_sig, &tf_options));
     RETURN_IF_ERROR(CheckInputsOutputs(op_sig,
                                        /*required_runtime_inputs=*/1,
                                        /*required_outputs=*/2));
@@ -190,6 +246,82 @@
   return IsActivationSupported(tf_options->activation);
 }
 
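+// Checks DEPTHWISE_CONV_2D for GPU delegate compatibility: strides, dilation
+// and activation, plus the NHWC shape relations between the input, filter,
+// bias and output tensors.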
+absl::Status CheckDepthwiseConvGpuDelegateCompatibility(
+    const OpSignature& op_sig) {
+  RETURN_IF_ERROR(CheckConvolutionInputOutput(op_sig));
+  const TfLiteDepthwiseConvParams* tf_options;
+  RETURN_IF_ERROR(RetrieveBuiltinData(op_sig, &tf_options));
+  RETURN_IF_ERROR(CheckStridesAndDilation(
+      tf_options->stride_height, tf_options->stride_width,
+      tf_options->dilation_height_factor, tf_options->dilation_width_factor));
+  RETURN_IF_ERROR(IsActivationSupported(tf_options->activation));
+
+  const int depth_multiplier = tf_options->depth_multiplier;
+  const auto* input = &op_sig.inputs[0];
+  const auto* filter = &op_sig.inputs[1];
+  const auto* bias = op_sig.inputs.size() > 2 ? &op_sig.inputs[2] : nullptr;
+  const auto* output = &op_sig.outputs[0];
+  if (input->dims.size() != 4) {
+    return absl::InvalidArgumentError("input.dims.size != 4");
+  }
+  if (filter->dims.size() != 4) {
+    return absl::InvalidArgumentError("filter.dims.size != 4");
+  }
+  if (output->dims.size() != 4) {
+    return absl::InvalidArgumentError("output.dims.size != 4");
+  }
+  if (input->dims[0] != output->dims[0]) {
+    return absl::InvalidArgumentError("input.b != output.b");
+  }
+  const int input_depth = input->dims[3];
+  const int output_depth = output->dims[3];
+  if (filter->dims[3] != output_depth) {
+    return absl::InvalidArgumentError("filter.i != output.c");
+  }
+  if (output_depth != input_depth * depth_multiplier) {
+    return absl::InvalidArgumentError("output.c != input.c * depth_multiplier");
+  }
+  if (bias && bias->dims.size() != output_depth) {
+    return absl::InvalidArgumentError("bias.size != output.c");
+  }
+  if (depth_multiplier != 1 && input_depth != 1) {
+    return absl::UnimplementedError("depth_multiplier != 1 && input.c != 1");
+  }
+  return absl::OkStatus();
+}
+
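+// Checks the custom ops that the GPU delegate supports:
+// Convolution2DTransposeBias, MaxPoolingWithArgmax2D, MaxUnpooling2D and
+// Resampler. Any other custom op is rejected.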
+absl::Status CheckCustomOpsGpuDelegateCompatibility(const OpSignature& op_sig) {
+  if (op_sig.custom_name == "Convolution2DTransposeBias") {
+    RETURN_IF_ERROR(CheckTensorIsAvailable(op_sig, 1));
+    const TfLiteTransposeConvParams* tf_options;
+    RETURN_IF_ERROR(RetrieveCustomInitialData(op_sig, &tf_options));
+    RETURN_IF_ERROR(
+        CheckStrides(tf_options->stride_height, tf_options->stride_width));
+    return absl::OkStatus();
+  }
+  if (op_sig.custom_name == "MaxPoolingWithArgmax2D") {
+    return CheckPooling2DGpuDelegateCompatibility(op_sig);
+  }
+  if (op_sig.custom_name == "MaxUnpooling2D") {
+    RETURN_IF_ERROR(CheckInputsOutputs(op_sig,
+                                       /*required_runtime_inputs=*/2,
+                                       /*required_outputs=*/1));
+    const TfLitePoolParams* tf_options;
+    RETURN_IF_ERROR(RetrieveCustomInitialData(op_sig, &tf_options));
+    RETURN_IF_ERROR(CheckKernelsAndStrides(
+        tf_options->filter_height, tf_options->filter_width,
+        tf_options->stride_height, tf_options->stride_width));
+    return absl::OkStatus();
+  }
+  if (op_sig.custom_name == "Resampler") {
+    return CheckInputsOutputs(op_sig,
+                              /*required_runtime_inputs=*/2,
+                              /*required_outputs=*/1);
+  }
+  return absl::InvalidArgumentError(
+      absl::StrCat("Not supported custom op ", op_sig.custom_name));
+}
+
 }  // namespace
 
 // TODO(b/189917229): Logics are copied from TFLiteOperationParser:IsSupported()
@@ -224,26 +356,14 @@
             TfLiteTypeGetName(op_sig.outputs.at(0).type)));
       }
 
-    case kTfLiteBuiltinConcatenation:
-      // TODO(b/189917229): Implement logic.
+    case kTfLiteBuiltinConcatenation: {
+      const TfLiteConcatenationParams* tf_options;
+      RETURN_IF_ERROR(RetrieveBuiltinData(op_sig, &tf_options));
       return absl::OkStatus();
+    }
 
     case kTfLiteBuiltinConv2d: {
-      const int runtime_inputs = GetNumberOfRuntimeInputs(op_sig);
-      if (runtime_inputs > 2) {
-        return absl::InternalError(
-            absl::StrCat("Expected 1 or 2 input tensor(s), but node has ",
-                         runtime_inputs, " runtime inputs."));
-      }
-      const int runtime_outputs = op_sig.outputs.size();
-      if (runtime_outputs != 1) {
-        return absl::InternalError(
-            absl::StrCat("Expected 1 output tensor(s), but node has ",
-                         runtime_outputs, " runtime outputs."));
-      }
-      if (runtime_inputs == 1) {
-        RETURN_IF_ERROR(CheckTensorIsAvailable(op_sig, 1));
-      }
+      RETURN_IF_ERROR(CheckConvolutionInputOutput(op_sig));
       const TfLiteConvParams* tf_options;
       RETURN_IF_ERROR(RetrieveBuiltinData(op_sig, &tf_options));
       RETURN_IF_ERROR(CheckStridesAndDilation(
@@ -257,65 +377,8 @@
       return CheckInputsOutputs(op_sig, /*required_runtime_inputs=*/0,
                                 /*required_outputs=*/1);
 
-    case kTfLiteBuiltinDepthwiseConv2d: {
-      const int runtime_inputs = GetNumberOfRuntimeInputs(op_sig);
-      if (runtime_inputs > 2) {
-        return absl::InternalError(
-            absl::StrCat("Expected 1 or 2 input tensor(s), but node has ",
-                         runtime_inputs, " runtime inputs."));
-      }
-      const int runtime_outputs = op_sig.outputs.size();
-      if (runtime_outputs != 1) {
-        return absl::InternalError(
-            absl::StrCat("Expected 1 output tensor(s), but node has ",
-                         runtime_outputs, " runtime outputs."));
-      }
-      if (runtime_inputs == 1) {
-        RETURN_IF_ERROR(CheckTensorIsAvailable(op_sig, 1));
-      }
-      const TfLiteDepthwiseConvParams* tf_options;
-      RETURN_IF_ERROR(RetrieveBuiltinData(op_sig, &tf_options));
-      RETURN_IF_ERROR(CheckStridesAndDilation(
-          tf_options->stride_height, tf_options->stride_width,
-          tf_options->dilation_height_factor,
-          tf_options->dilation_width_factor));
-      RETURN_IF_ERROR(IsActivationSupported(tf_options->activation));
-
-      const int depth_multiplier = tf_options->depth_multiplier;
-      const auto* input = &op_sig.inputs[0];
-      const auto* filter = &op_sig.inputs[1];
-      const auto* bias = op_sig.inputs.size() > 2 ? &op_sig.inputs[2] : nullptr;
-      const auto* output = &op_sig.outputs[0];
-      if (input->dims.size() != 4) {
-        return absl::InvalidArgumentError("input.dims.size != 4");
-      }
-      if (filter->dims.size() != 4) {
-        return absl::InvalidArgumentError("filter.dims.size != 4");
-      }
-      if (output->dims.size() != 4) {
-        return absl::InvalidArgumentError("output.dims.size != 4");
-      }
-      if (input->dims[0] != output->dims[0]) {
-        return absl::InvalidArgumentError("input.b != output.b");
-      }
-      const int input_depth = input->dims[3];
-      const int output_depth = output->dims[3];
-      if (filter->dims[3] != output_depth) {
-        return absl::InvalidArgumentError("filter.i != output.c");
-      }
-      if (output_depth != input_depth * depth_multiplier) {
-        return absl::InvalidArgumentError(
-            "output.c != input.c * depth_multiplier");
-      }
-      if (bias && bias->dims.size() != output_depth) {
-        return absl::InvalidArgumentError("bias.size != output.c");
-      }
-      if (depth_multiplier != 1 && input_depth != 1) {
-        return absl::UnimplementedError(
-            "depth_multiplier != 1 && input.c != 1");
-      }
-      return absl::OkStatus();
-    }
+    case kTfLiteBuiltinDepthwiseConv2d:
+      return CheckDepthwiseConvGpuDelegateCompatibility(op_sig);
 
     case kTfLiteBuiltinDepthToSpace: {
       RETURN_IF_ERROR(CheckInputsOutputs(op_sig,
@@ -426,9 +489,13 @@
     case kTfLiteBuiltinMaxPool2d:
       return CheckPooling2DGpuDelegateCompatibility(op_sig);
 
-    case kTfLiteBuiltinMean:
-      // TODO(b/189917229): Implement logic.
-      return absl::OkStatus();
+    case kTfLiteBuiltinMean: {
+      RETURN_IF_ERROR(CheckInputsConstsOutputs(op_sig,
+                                               /*required_runtime_inputs=*/1,
+                                               /*required_const_inputs=*/1,
+                                               /*required_outputs=*/1));
+      return CheckAxesAreInt32Const(op_sig, 1);
+    }
 
     case kTfLiteBuiltinMul: {
       if (op_sig.inputs.size() != 2) {
@@ -466,92 +533,155 @@
     case kTfLiteBuiltinPack:
       return absl::OkStatus();
 
-    case kTfLiteBuiltinReduceMax:
-      // TODO(b/189917229): Implement logic.
-      return absl::OkStatus();
-
-    case kTfLiteBuiltinReduceMin:
-      // TODO(b/189917229): Implement logic.
-      return absl::OkStatus();
-
-    case kTfLiteBuiltinReduceProd:
-      // TODO(b/189917229): Implement logic.
-      return absl::OkStatus();
-
     case kTfLiteBuiltinQuantize:
-      // TODO(b/189917229): Implement logic.
-      return absl::OkStatus();
-
-    case kTfLiteBuiltinRelu:
-      // TODO(b/189917229): Implement logic.
-      return absl::OkStatus();
-
-    case kTfLiteBuiltinRelu6:
-      // TODO(b/189917229): Implement logic.
+      RETURN_IF_ERROR(CheckInputsOutputs(op_sig,
+                                         /*required_runtime_inputs=*/1,
+                                         /*required_outputs=*/1));
       return absl::OkStatus();
 
     case kTfLiteBuiltinReluN1To1:
-      // TODO(b/189917229): Implement logic.
-      return absl::OkStatus();
-
-    case kTfLiteBuiltinLeakyRelu:
-      // TODO(b/189917229): Implement logic.
       return absl::OkStatus();
 
     case kTfLiteBuiltinPrelu:
       return absl::OkStatus();
 
     case kTfLiteBuiltinReshape:
-      // TODO(b/189917229): Implement logic.
+      RETURN_IF_ERROR(CheckInputsOutputs(op_sig,
+                                         /*required_runtime_inputs=*/1,
+                                         /*required_outputs=*/1));
       return absl::OkStatus();
 
-    case kTfLiteBuiltinResizeBilinear:
-      // TODO(b/189917229): Implement logic.
+    case kTfLiteBuiltinSlice: {
+      if (op_sig.inputs.size() < 3) {
+        return absl::UnimplementedError(
+            absl::StrCat("SLICE requires 3 inputs, but node has ",
+                         op_sig.inputs.size(), " inputs."));
+      }
+      const auto& input = op_sig.inputs.at(0);
+      if (input.dims.size() != 3 && input.dims.size() != 4) {
+        return absl::UnimplementedError(absl::StrCat(
+            "SLICE supports for 3 or 4 dimensional tensors only, but node has ",
+            input.dims.size(), " dimensional tensors."));
+      }
       return absl::OkStatus();
+    }
 
-    case kTfLiteBuiltinResizeNearestNeighbor:
-      // TODO(b/189917229): Implement logic.
+    case kTfLiteBuiltinSoftmax: {
+      const TfLiteSoftmaxParams* tf_options;
+      RETURN_IF_ERROR(RetrieveBuiltinData(op_sig, &tf_options));
+      if (tf_options->beta != 1) {
+        return absl::UnimplementedError("Softmax.beta != 1 is not supported.");
+      }
       return absl::OkStatus();
+    }
 
-    case kTfLiteBuiltinSlice:
-      // TODO(b/189917229): Implement logic.
+    case kTfLiteBuiltinSpaceToDepth: {
+      RETURN_IF_ERROR(CheckInputsOutputs(op_sig,
+                                         /*required_runtime_inputs=*/1,
+                                         /*required_outputs=*/1));
+      const TfLiteSpaceToDepthParams* s2d_params;
+      RETURN_IF_ERROR(RetrieveBuiltinData(op_sig, &s2d_params));
+      if (s2d_params->block_size == 1) {
+        return absl::InvalidArgumentError(
+            "SPACE_TO_DEPTH block_size = 1 is a no-op.");
+      }
+      if (s2d_params->block_size < 1) {
+        return absl::InvalidArgumentError(
+            "SPACE_TO_DEPTH block_size must be > 1.");
+      }
       return absl::OkStatus();
-
-    case kTfLiteBuiltinSoftmax:
-      // TODO(b/189917229): Implement logic.
-      return absl::OkStatus();
-
-    case kTfLiteBuiltinSpaceToDepth:
-      // TODO(b/189917229): Implement logic.
-      return absl::OkStatus();
+    }
 
     case kTfLiteBuiltinSplit:
-      // TODO(b/189917229): Implement logic.
       return absl::OkStatus();
 
     case kTfLiteBuiltinSplitV:
-      // TODO(b/189917229): Implement logic.
       return absl::OkStatus();
 
-    case kTfLiteBuiltinStridedSlice:
-      // TODO(b/189917229): Implement logic.
-      return absl::OkStatus();
+    case kTfLiteBuiltinStridedSlice: {
+      const TfLiteStridedSliceParams* tf_options;
+      RETURN_IF_ERROR(RetrieveBuiltinData(op_sig, &tf_options));
+      if (tf_options->ellipsis_mask) {
+        return absl::UnimplementedError(
+            "STRIDED_SLICE does not support ellipsis_mask.");
+      }
+      if (tf_options->new_axis_mask) {
+        return absl::UnimplementedError(
+            "STRIDED_SLICE does not support new_axis_mask.");
+      }
+      if (tf_options->shrink_axis_mask) {
+        return absl::UnimplementedError(
+            "STRIDED_SLICE does not support shrink_axis_mask.");
+      }
 
-    case kTfLiteBuiltinSum:
-      // TODO(b/189917229): Implement logic.
+      if (op_sig.inputs.size() < 4) {
+        return absl::UnimplementedError("STRIDED_SLICE requires 4 inputs.");
+      }
+      const auto& input = op_sig.inputs.at(0);
+      if (input.dims.size() != 3 && input.dims.size() != 4) {
+        return absl::UnimplementedError(
+            "STRIDED_SLICE supports only 3 or 4 dimensional tensors.");
+      }
       return absl::OkStatus();
+    }
 
     case kTfLiteBuiltinTile:
-      // TODO(b/189917229): Implement logic.
+      RETURN_IF_ERROR(CheckInputsOutputs(op_sig,
+                                         /*required_runtime_inputs=*/1,
+                                         /*required_outputs=*/1));
       return absl::OkStatus();
 
     case kTfLiteBuiltinTranspose:
-      // TODO(b/189917229): Implement logic.
+      RETURN_IF_ERROR(CheckInputsOutputs(op_sig,
+                                         /*required_runtime_inputs=*/1,
+                                         /*required_outputs=*/1));
       return absl::OkStatus();
 
-    case kTfLiteBuiltinTransposeConv:
-      // TODO(b/189917229): Implement logic.
+    case kTfLiteBuiltinTransposeConv: {
+      RETURN_IF_ERROR(CheckConvolutionInputOutput(op_sig));
+      const TfLiteTransposeConvParams* tf_options;
+      RETURN_IF_ERROR(RetrieveBuiltinData(op_sig, &tf_options));
+      RETURN_IF_ERROR(
+          CheckStrides(tf_options->stride_height, tf_options->stride_width));
       return absl::OkStatus();
+    }
+
+    case kTfLiteBuiltinResizeBilinear: {
+      RETURN_IF_ERROR(CheckInputsOutputs(op_sig,
+                                         /*required_runtime_inputs=*/1,
+                                         /*required_outputs=*/1));
+      const TfLiteResizeBilinearParams* tf_options;
+      RETURN_IF_ERROR(RetrieveBuiltinData(op_sig, &tf_options));
+      if (tf_options->align_corners && tf_options->half_pixel_centers) {
+        return absl::InternalError(
+            "If half_pixel_centers is True, align_corners must be False.");
+      }
+      return absl::OkStatus();
+    }
+
+    case kTfLiteBuiltinResizeNearestNeighbor: {
+      RETURN_IF_ERROR(CheckInputsOutputs(op_sig,
+                                         /*required_runtime_inputs=*/1,
+                                         /*required_outputs=*/1));
+      const TfLiteResizeNearestNeighborParams* tf_options;
+      RETURN_IF_ERROR(RetrieveBuiltinData(op_sig, &tf_options));
+      return absl::OkStatus();
+    }
+
+    case kTfLiteBuiltinRelu:
+    case kTfLiteBuiltinRelu6:
+    case kTfLiteBuiltinLeakyRelu:
+      return absl::OkStatus();
+
+    case kTfLiteBuiltinReduceMax:
+    case kTfLiteBuiltinReduceMin:
+    case kTfLiteBuiltinReduceProd:
+    case kTfLiteBuiltinSum: {
+      RETURN_IF_ERROR(CheckInputsOutputs(op_sig,
+                                         /*required_runtime_inputs=*/1,
+                                         /*required_outputs=*/1));
+      return CheckAxesAreInt32Const(op_sig, 1);
+    }
 
     case kTfLiteBuiltinPad:
     case kTfLiteBuiltinMirrorPad: {
@@ -642,25 +772,8 @@
       return IsActivationSupported(activation);
     }
 
-    case kTfLiteBuiltinCustom: {
-      if (op_sig.custom_name == "Convolution2DTransposeBias") {
-        // TODO(b/189917229): Implement logic.
-        return absl::OkStatus();
-      }
-      if (op_sig.custom_name == "MaxPoolingWithArgmax2D") {
-        return CheckPooling2DGpuDelegateCompatibility(op_sig);
-      }
-      if (op_sig.custom_name == "MaxUnpooling2D") {
-        // TODO(b/189917229): Implement logic.
-        return absl::OkStatus();
-      }
-      if (op_sig.custom_name == "Resampler") {
-        // TODO(b/189917229): Implement logic.
-        return absl::OkStatus();
-      }
-      return absl::InvalidArgumentError(
-          absl::StrCat("Not supported custom op ", op_sig.custom_name));
-    }
+    case kTfLiteBuiltinCustom:
+      return CheckCustomOpsGpuDelegateCompatibility(op_sig);
 
     default:
       break;
diff --git a/tensorflow/lite/tools/versioning/op_signature.cc b/tensorflow/lite/tools/versioning/op_signature.cc
index bd96ff7..59080d9 100644
--- a/tensorflow/lite/tools/versioning/op_signature.cc
+++ b/tensorflow/lite/tools/versioning/op_signature.cc
@@ -210,6 +210,7 @@
   OpSignature op_sig = {
       static_cast<BuiltinOperator>(registration->builtin_code)};
   op_sig.builtin_data = tflite_node->builtin_data;
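+  // Forward the custom op payload so gpu_compatibility.cc can parse it.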
+  op_sig.custom_initial_data = tflite_node->custom_initial_data;
   std::memset(&op_sig.ext_options, 0, sizeof(op_sig.ext_options));
 
   op_sig.inputs =
diff --git a/tensorflow/lite/tools/versioning/op_signature.h b/tensorflow/lite/tools/versioning/op_signature.h
index ad4f804..f9ed28d 100644
--- a/tensorflow/lite/tools/versioning/op_signature.h
+++ b/tensorflow/lite/tools/versioning/op_signature.h
@@ -35,7 +35,7 @@
   std::vector<OpSignatureTensorSpec> inputs;
   std::vector<OpSignatureTensorSpec> outputs;
   void* builtin_data;
-  void* custom_initial_data;
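+  // Borrowed from TfLiteNode::custom_initial_data.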
+  const void* custom_initial_data;
   std::string custom_name;
   union {
     struct {