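Rollback of "[lite] Add support for 5D select".
Restores the 4D-only broadcast path for select: the op trait rank limit, the BroadcastSelect4DSlow reference kernel, the related tests, and the default op versioning for SELECT.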

PiperOrigin-RevId: 450095846
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
index d17c882..e640ec6 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -3052,7 +3052,7 @@
     NoSideEffect,
     QuantizableResult,
     SameOperandsAndResultsScale,
-    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1, 2], 5>,
+    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1, 2], 4>,
     PredOpTrait<"operands have same element type", TFL_TCopVTEtAreSameAt<1, 2>>,
     PredOpTrait<"operands and result have same element type",
       TFL_TCresVTEtIsSameAsOp<0, 1>>]> {
diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir
index df50164..c1048e3 100644
--- a/tensorflow/compiler/mlir/lite/tests/ops.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir
@@ -2907,7 +2907,7 @@
 func.func @select_v2_with_dynamic_shape_not_from_broadcast_args(%arg0: tensor<8x7x6x5x?x3x2x1xi1>, %arg1: tensor<8x7x6x5x?x3x2x1xf32>, %arg2: tensor<?x3x2x1xf32>, %arg3: tensor<8xi64>) -> tensor<8x7x6x5x?x3x2x1xf32> {
   %0 = "tfl.broadcast_to"(%arg1, %arg3) : (tensor<8x7x6x5x?x3x2x1xf32>, tensor<8xi64>) -> tensor<8x7x6x5x?x3x2x1xf32>
   %1 = "tfl.broadcast_to"(%arg2, %arg3) : (tensor<?x3x2x1xf32>, tensor<8xi64>) -> tensor<8x7x6x5x?x3x2x1xf32>
-  // expected-error @+1 {{'tfl.select_v2' op failed to verify that operands do not have the same shape or broadcastable shapes within the rank 5}}
+  // expected-error @+1 {{'tfl.select_v2' op failed to verify that operands do not have the same shape or broadcastable shapes within the rank 4}}
   %2 = "tfl.select_v2"(%arg0, %0, %1) : (tensor<8x7x6x5x?x3x2x1xi1>, tensor<8x7x6x5x?x3x2x1xf32>, tensor<8x7x6x5x?x3x2x1xf32>) -> tensor<8x7x6x5x?x3x2x1xf32>
   func.return %2 : tensor<8x7x6x5x?x3x2x1xf32>
 }
diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h
index 01b2a9d..c5dbe0f 100644
--- a/tensorflow/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h
@@ -791,34 +791,55 @@
   }
 }
 
-template <typename D, typename T, int N = 5>
-void BroadcastSelectSlow(const RuntimeShape& input_condition_shape,
-                         const D* input_condition_data,
-                         const RuntimeShape& input_x_shape,
-                         const T* input_x_data,
-                         const RuntimeShape& input_y_shape,
-                         const T* input_y_data,
-                         const RuntimeShape& output_shape, T* output_data) {
-  TFLITE_DCHECK_LE(input_condition_shape.DimensionsCount(), N);
-  TFLITE_DCHECK_LE(input_x_shape.DimensionsCount(), N);
-  TFLITE_DCHECK_LE(input_y_shape.DimensionsCount(), N);
-  TFLITE_DCHECK_LE(output_shape.DimensionsCount(), N);
+template <typename D, typename T>
+void BroadcastSelect4DSlow(const RuntimeShape& input_condition_shape,
+                           const D* input_condition_data,
+                           const RuntimeShape& input_x_shape,
+                           const T* input_x_data,
+                           const RuntimeShape& input_y_shape,
+                           const T* input_y_data,
+                           const RuntimeShape& output_shape, T* output_data) {
+  TFLITE_DCHECK_LE(input_condition_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(input_x_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(input_y_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(output_shape.DimensionsCount(), 4);
 
-  NdArrayDesc<N> desc_condition;
-  NdArrayDesc<N> desc_x;
-  NdArrayDesc<N> desc_y;
-  NdArrayDesc<N> output_desc;
-  CopyDimsToDesc(RuntimeShape::ExtendedShape(N, output_shape), &output_desc);
+  const RuntimeShape extended_output_shape =
+      RuntimeShape::ExtendedShape(4, output_shape);
+
+  NdArrayDesc<4> desc_condition;
+  NdArrayDesc<4> desc_x;
+  NdArrayDesc<4> desc_y;
   NdArrayDescsForElementwiseBroadcast(input_condition_shape, input_x_shape,
                                       input_y_shape, &desc_condition, &desc_x,
                                       &desc_y);
-  auto select_func = [&](int indexes[N]) {
-    output_data[SubscriptToIndex(output_desc, indexes)] =
-        input_condition_data[SubscriptToIndex(desc_condition, indexes)]
-            ? input_x_data[SubscriptToIndex(desc_x, indexes)]
-            : input_y_data[SubscriptToIndex(desc_y, indexes)];
-  };
-  NDOpsHelper<N>(output_desc, select_func);
+
+  // In Tensorflow, the dimensions are canonically named (batch_number, row,
+  // col, channel), with extents (batches, height, width, depth), with the
+  // trailing dimension changing most rapidly (channels has the smallest
+  // stride, typically 1 element).
+  //
+  // In generated C code, we store arrays with the dimensions reversed. The
+  // first dimension has smallest stride.
+  //
+  // We name our variables by their Tensorflow convention, but generate C code
+  // nesting loops such that the innermost loop has the smallest stride for
+  // the best cache behavior.
+  for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
+    for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
+      for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
+        for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
+          const int condition_index =
+              SubscriptToIndex(desc_condition, b, y, x, c);
+          const int x_index = SubscriptToIndex(desc_x, b, y, x, c);
+          const int y_index = SubscriptToIndex(desc_y, b, y, x, c);
+          output_data[Offset(extended_output_shape, b, y, x, c)] =
+              input_condition_data[condition_index] ? input_x_data[x_index]
+                                                    : input_y_data[y_index];
+        }
+      }
+    }
+  }
 }
 
 template <typename D, typename T>
diff --git a/tensorflow/lite/kernels/select.cc b/tensorflow/lite/kernels/select.cc
index 9a8f020..00b92aa 100644
--- a/tensorflow/lite/kernels/select.cc
+++ b/tensorflow/lite/kernels/select.cc
@@ -182,7 +182,7 @@
   if (data->has_low_rank_input_condition) {
     TF_LITE_SWITCH(input_x->type, RankOneSelect);
   } else if (data->requires_broadcast) {
-    TF_LITE_SWITCH(input_x->type, BroadcastSelectSlow);
+    TF_LITE_SWITCH(input_x->type, BroadcastSelect4DSlow);
   } else {
     TF_LITE_SWITCH(input_x->type, Select);
   }
diff --git a/tensorflow/lite/kernels/select_test.cc b/tensorflow/lite/kernels/select_test.cc
index e81d128..514e8a3 100644
--- a/tensorflow/lite/kernels/select_test.cc
+++ b/tensorflow/lite/kernels/select_test.cc
@@ -308,20 +308,6 @@
   EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 2, 1}));
 }
 
-TEST(SelectV2OpTest,
-     BroadcastSelectInt32OneDimensionConditionWithSingleValue5D) {
-  SelectV2OpModel model({1}, {1, 2, 2, 2, 1}, {1, 2, 2, 1}, TensorType_INT32);
-
-  model.PopulateTensor<bool>(model.input1(), {false});
-  model.PopulateTensor<int32_t>(model.input2(), {1, 2, 3, 4, 5, 6, 7, 8});
-  model.PopulateTensor<int32_t>(model.input3(), {9, 10, 11, 12});
-  model.Invoke();
-
-  EXPECT_THAT(model.GetOutput<int32_t>(),
-              ElementsAreArray({9, 10, 11, 12, 9, 10, 11, 12}));
-  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 2, 2, 1}));
-}
-
 TEST(SelectV2OpTest, BroadcastSelectInt32LesserThan4D) {
   SelectV2OpModel model({1, 2}, {1, 2, 2}, {1, 2, 2}, TensorType_INT32);
 
diff --git a/tensorflow/lite/testing/generated_examples_zip_test.cc b/tensorflow/lite/testing/generated_examples_zip_test.cc
index 996d7aa..d046668 100644
--- a/tensorflow/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/lite/testing/generated_examples_zip_test.cc
@@ -66,6 +66,9 @@
 // TODO(ahentz): make sure we clean this list up frequently.
 const BrokenTestMap& GetKnownBrokenTests() {
   static const BrokenTestMap* const kBrokenTests = new BrokenTestMap({
+      // Select kernel doesn't support broadcasting yet.
+      {R"(^\/where.*1,2,3,1)", {"134692786", false}},
+
       // TODO(b/194364155): TF and TFLite have different behaviors when output
       // nan values in LocalResponseNorm ops.
       {R"(^\/local_response_norm.*alpha=-3.*beta=2)", {"194364155", true}},
diff --git a/tensorflow/lite/testing/op_tests/where_v2.py b/tensorflow/lite/testing/op_tests/where_v2.py
index 8b8e431..22f679b 100644
--- a/tensorflow/lite/testing/op_tests/where_v2.py
+++ b/tensorflow/lite/testing/op_tests/where_v2.py
@@ -82,13 +82,6 @@
           "input_dtype": [tf.float32, tf.int32],
           "input_shape_set": [([1, 2, 2], [1, 2]),],
       },
-      # Requires kernel supporting broadcasting for 5D case.
-      {
-          "condition_dtype": [tf.float32],
-          "input_condition_shape": [[1, 1, 1, 1, 1]],
-          "input_dtype": [tf.float32],
-          "input_shape_set": [([], [1, None, 1, 2, 512])],
-      },
   ]
 
   def build_graph(parameters):
@@ -116,21 +109,16 @@
     out = tf.where_v2(input_condition, input_value1, input_value2)
     return [input_condition, input_value1, input_value2], [out]
 
-  def build_input_shape(input_shape):
-    return [1 if v is None else v for v in input_shape]
-
   def build_inputs(parameters, sess, inputs, outputs):
     input_condition = create_tensor_data(parameters["condition_dtype"],
                                          parameters["input_condition_shape"])
     input_value1 = None
     input_value2 = None
     if parameters["input_dtype"] is not None:
-      input_value1 = create_tensor_data(
-          parameters["input_dtype"],
-          build_input_shape(parameters["input_shape_set"][0]))
-      input_value2 = create_tensor_data(
-          parameters["input_dtype"],
-          build_input_shape(parameters["input_shape_set"][1]))
+      input_value1 = create_tensor_data(parameters["input_dtype"],
+                                        parameters["input_shape_set"][0])
+      input_value2 = create_tensor_data(parameters["input_dtype"],
+                                        parameters["input_shape_set"][1])
       return [input_condition, input_value1, input_value2], sess.run(
           outputs,
           feed_dict=dict(
diff --git a/tensorflow/lite/toco/tflite/operator_test.cc b/tensorflow/lite/toco/tflite/operator_test.cc
index 41bf297..5eec312 100644
--- a/tensorflow/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/lite/toco/tflite/operator_test.cc
@@ -1085,33 +1085,6 @@
   EXPECT_EQ(base_op->GetVersion(signature), version);
 }
 
-template <typename OpType>
-void SimpleThreeInputsVersioningTest(ArrayDataType data_type, Shape shape1,
-                                     Shape shape2, Shape shape3, int version) {
-  OpType op;
-  op.inputs = {"input1", "input2", "input3"};
-  op.outputs = {"output"};
-  auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
-  const BaseOperator* base_op = operator_by_type_map.at(op.type).get();
-
-  Model model;
-  Array& input0 = model.GetOrCreateArray(op.inputs[0]);
-  Array& input1 = model.GetOrCreateArray(op.inputs[1]);
-  Array& input2 = model.GetOrCreateArray(op.inputs[2]);
-  Array& output = model.GetOrCreateArray(op.outputs[0]);
-
-  input0.data_type = data_type;
-  input0.copy_shape(shape1);
-  input1.data_type = data_type;
-  input1.copy_shape(shape2);
-  input2.data_type = data_type;
-  input2.copy_shape(shape3);
-  output.data_type = data_type;
-
-  OperatorSignature signature = {.op = &op, .model = &model};
-  EXPECT_EQ(base_op->GetVersion(signature), version);
-}
-
 TEST_F(OperatorTest, VersioningSubTest) {
   SimpleTwoInputsVersioningTest<SubOperator>(ArrayDataType::kUint8,
                                              {1, 2, 2, 2}, {1, 2, 2, 2}, 1);
@@ -1153,13 +1126,7 @@
 }
 
 TEST_F(OperatorTest, VersioningSelectTest) {
-  SimpleThreeInputsVersioningTest<SelectOperator>(
-      ArrayDataType::kUint8, {1, 2, 2, 2}, {1, 2, 2, 1}, {1, 2, 2, 1}, 1);
-  SimpleThreeInputsVersioningTest<SelectOperator>(
-      ArrayDataType::kInt8, {1, 2, 2, 2}, {1, 2, 2, 1}, {1, 2, 2, 1}, 2);
-  SimpleThreeInputsVersioningTest<SelectOperator>(
-      ArrayDataType::kInt8, {1, 2, 2, 2, 1}, {1, 2, 2, 1, 1}, {1, 2, 2, 1, 1},
-      3);
+  SimpleVersioningTest<SelectOperator>();
 }
 
 TEST_F(OperatorTest, VersioningRelu6Test) {
diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc
index b11a133..2a25ce8 100644
--- a/tensorflow/lite/tools/versioning/op_version.cc
+++ b/tensorflow/lite/tools/versioning/op_version.cc
@@ -776,16 +776,6 @@
       }
       return 1;
 
-    case BuiltinOperator_SELECT: {
-      if (op_sig.inputs.at(0).dims.size() == 5 ||
-          op_sig.inputs.at(1).dims.size() == 5 ||
-          op_sig.inputs.at(2).dims.size() == 5)
-        return 3;
-      if (op_sig.inputs.at(0).type == kTfLiteInt8) {
-        return 2;
-      }
-      return 1;
-    }
     case BuiltinOperator_SPACE_TO_DEPTH:
     case BuiltinOperator_SPLIT_V:
     case BuiltinOperator_SUM:
@@ -795,6 +785,7 @@
     case BuiltinOperator_GREATER_EQUAL:
     case BuiltinOperator_LESS:
     case BuiltinOperator_LESS_EQUAL:
+    case BuiltinOperator_SELECT:
     case BuiltinOperator_RSQRT:
     case BuiltinOperator_SQUARED_DIFFERENCE:
     case BuiltinOperator_DEPTH_TO_SPACE:
diff --git a/tensorflow/lite/tools/versioning/op_version_test.cc b/tensorflow/lite/tools/versioning/op_version_test.cc
index 588fbda..119d439 100644
--- a/tensorflow/lite/tools/versioning/op_version_test.cc
+++ b/tensorflow/lite/tools/versioning/op_version_test.cc
@@ -36,22 +36,6 @@
   return tensor_specs;
 }
 
-// Creates vector of OpSignatureTensorSpec with the given TfLiteType vector,
-// each with rank 'rank'
-std::vector<OpSignatureTensorSpec> CreateOpSignatureTensorSpecs(
-    const std::vector<TfLiteType>& types, int rank) {
-  std::vector<OpSignatureTensorSpec> tensor_specs;
-  for (auto type : types) {
-    OpSignatureTensorSpec tensor_spec = {};
-    tensor_spec.type = type;
-    for (int i = 0; i < rank; i++) {
-      tensor_spec.dims.push_back(4);
-    }
-    tensor_specs.push_back(tensor_spec);
-  }
-  return tensor_specs;
-}
-
 // Creates vector of OpSignatureTensorSpec of single tensor spec of TfLiteType.
 std::vector<OpSignatureTensorSpec> CreateOpSignatureTensorSpecs(
     const TfLiteType type) {
@@ -507,29 +491,7 @@
 }
 
 TEST(OpVersionTest, VersioningSelectTest) {
-  OpSignature fake_op_sig = {
-      .op = BuiltinOperator_SELECT,
-      .inputs = CreateOpSignatureTensorSpecs(
-          std::vector<TfLiteType>{kTfLiteUInt8, kTfLiteUInt8, kTfLiteUInt8}, 5),
-      .outputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8),
-  };
-  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
-  fake_op_sig = {
-      .op = BuiltinOperator_SELECT,
-      .inputs = CreateOpSignatureTensorSpecs(
-          std::vector<TfLiteType>{kTfLiteInt8, kTfLiteInt8, kTfLiteInt8}, 4),
-      .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8),
-  };
-  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
-  fake_op_sig = {
-      .op = BuiltinOperator_SELECT,
-      .inputs = CreateOpSignatureTensorSpecs(
-          std::vector<TfLiteType>{kTfLiteFloat32, kTfLiteFloat32,
-                                  kTfLiteFloat32},
-          4),
-      .outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32),
-  };
-  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
+  SimpleVersioningTest(BuiltinOperator_SELECT);
 }
 
 TEST(OpVersionTest, VersioningRelu6Test) {