Rollback:
[lite] Add support for 5D select
PiperOrigin-RevId: 450095846
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
index d17c882..e640ec6 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -3052,7 +3052,7 @@
NoSideEffect,
QuantizableResult,
SameOperandsAndResultsScale,
- TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1, 2], 5>,
+ TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1, 2], 4>,
PredOpTrait<"operands have same element type", TFL_TCopVTEtAreSameAt<1, 2>>,
PredOpTrait<"operands and result have same element type",
TFL_TCresVTEtIsSameAsOp<0, 1>>]> {
diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir
index df50164..c1048e3 100644
--- a/tensorflow/compiler/mlir/lite/tests/ops.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir
@@ -2907,7 +2907,7 @@
func.func @select_v2_with_dynamic_shape_not_from_broadcast_args(%arg0: tensor<8x7x6x5x?x3x2x1xi1>, %arg1: tensor<8x7x6x5x?x3x2x1xf32>, %arg2: tensor<?x3x2x1xf32>, %arg3: tensor<8xi64>) -> tensor<8x7x6x5x?x3x2x1xf32> {
%0 = "tfl.broadcast_to"(%arg1, %arg3) : (tensor<8x7x6x5x?x3x2x1xf32>, tensor<8xi64>) -> tensor<8x7x6x5x?x3x2x1xf32>
%1 = "tfl.broadcast_to"(%arg2, %arg3) : (tensor<?x3x2x1xf32>, tensor<8xi64>) -> tensor<8x7x6x5x?x3x2x1xf32>
- // expected-error @+1 {{'tfl.select_v2' op failed to verify that operands do not have the same shape or broadcastable shapes within the rank 5}}
+ // expected-error @+1 {{'tfl.select_v2' op failed to verify that operands do not have the same shape or broadcastable shapes within the rank 4}}
%2 = "tfl.select_v2"(%arg0, %0, %1) : (tensor<8x7x6x5x?x3x2x1xi1>, tensor<8x7x6x5x?x3x2x1xf32>, tensor<8x7x6x5x?x3x2x1xf32>) -> tensor<8x7x6x5x?x3x2x1xf32>
func.return %2 : tensor<8x7x6x5x?x3x2x1xf32>
}
diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h
index 01b2a9d..c5dbe0f 100644
--- a/tensorflow/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h
@@ -791,34 +791,55 @@
}
}
-template <typename D, typename T, int N = 5>
-void BroadcastSelectSlow(const RuntimeShape& input_condition_shape,
- const D* input_condition_data,
- const RuntimeShape& input_x_shape,
- const T* input_x_data,
- const RuntimeShape& input_y_shape,
- const T* input_y_data,
- const RuntimeShape& output_shape, T* output_data) {
- TFLITE_DCHECK_LE(input_condition_shape.DimensionsCount(), N);
- TFLITE_DCHECK_LE(input_x_shape.DimensionsCount(), N);
- TFLITE_DCHECK_LE(input_y_shape.DimensionsCount(), N);
- TFLITE_DCHECK_LE(output_shape.DimensionsCount(), N);
+template <typename D, typename T>
+void BroadcastSelect4DSlow(const RuntimeShape& input_condition_shape,
+ const D* input_condition_data,
+ const RuntimeShape& input_x_shape,
+ const T* input_x_data,
+ const RuntimeShape& input_y_shape,
+ const T* input_y_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ TFLITE_DCHECK_LE(input_condition_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(input_x_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(input_y_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(output_shape.DimensionsCount(), 4);
- NdArrayDesc<N> desc_condition;
- NdArrayDesc<N> desc_x;
- NdArrayDesc<N> desc_y;
- NdArrayDesc<N> output_desc;
- CopyDimsToDesc(RuntimeShape::ExtendedShape(N, output_shape), &output_desc);
+ const RuntimeShape extended_output_shape =
+ RuntimeShape::ExtendedShape(4, output_shape);
+
+ NdArrayDesc<4> desc_condition;
+ NdArrayDesc<4> desc_x;
+ NdArrayDesc<4> desc_y;
NdArrayDescsForElementwiseBroadcast(input_condition_shape, input_x_shape,
input_y_shape, &desc_condition, &desc_x,
&desc_y);
- auto select_func = [&](int indexes[N]) {
- output_data[SubscriptToIndex(output_desc, indexes)] =
- input_condition_data[SubscriptToIndex(desc_condition, indexes)]
- ? input_x_data[SubscriptToIndex(desc_x, indexes)]
- : input_y_data[SubscriptToIndex(desc_y, indexes)];
- };
- NDOpsHelper<N>(output_desc, select_func);
+
+ // In Tensorflow, the dimensions are canonically named (batch_number, row,
+ // col, channel), with extents (batches, height, width, depth), with the
+ // trailing dimension changing most rapidly (channels has the smallest
+ // stride, typically 1 element).
+ //
+ // In generated C code, we store arrays with the dimensions reversed. The
+ // first dimension has smallest stride.
+ //
+ // We name our variables by their Tensorflow convention, but generate C code
+ // nesting loops such that the innermost loop has the smallest stride for
+ // the best cache behavior.
+ for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
+ for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
+ for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
+ for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
+ const int condition_index =
+ SubscriptToIndex(desc_condition, b, y, x, c);
+ const int x_index = SubscriptToIndex(desc_x, b, y, x, c);
+ const int y_index = SubscriptToIndex(desc_y, b, y, x, c);
+ output_data[Offset(extended_output_shape, b, y, x, c)] =
+ input_condition_data[condition_index] ? input_x_data[x_index]
+ : input_y_data[y_index];
+ }
+ }
+ }
+ }
}
template <typename D, typename T>
diff --git a/tensorflow/lite/kernels/select.cc b/tensorflow/lite/kernels/select.cc
index 9a8f020..00b92aa 100644
--- a/tensorflow/lite/kernels/select.cc
+++ b/tensorflow/lite/kernels/select.cc
@@ -182,7 +182,7 @@
if (data->has_low_rank_input_condition) {
TF_LITE_SWITCH(input_x->type, RankOneSelect);
} else if (data->requires_broadcast) {
- TF_LITE_SWITCH(input_x->type, BroadcastSelectSlow);
+ TF_LITE_SWITCH(input_x->type, BroadcastSelect4DSlow);
} else {
TF_LITE_SWITCH(input_x->type, Select);
}
diff --git a/tensorflow/lite/kernels/select_test.cc b/tensorflow/lite/kernels/select_test.cc
index e81d128..514e8a3 100644
--- a/tensorflow/lite/kernels/select_test.cc
+++ b/tensorflow/lite/kernels/select_test.cc
@@ -308,20 +308,6 @@
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 2, 1}));
}
-TEST(SelectV2OpTest,
- BroadcastSelectInt32OneDimensionConditionWithSingleValue5D) {
- SelectV2OpModel model({1}, {1, 2, 2, 2, 1}, {1, 2, 2, 1}, TensorType_INT32);
-
- model.PopulateTensor<bool>(model.input1(), {false});
- model.PopulateTensor<int32_t>(model.input2(), {1, 2, 3, 4, 5, 6, 7, 8});
- model.PopulateTensor<int32_t>(model.input3(), {9, 10, 11, 12});
- model.Invoke();
-
- EXPECT_THAT(model.GetOutput<int32_t>(),
- ElementsAreArray({9, 10, 11, 12, 9, 10, 11, 12}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 2, 2, 1}));
-}
-
TEST(SelectV2OpTest, BroadcastSelectInt32LesserThan4D) {
SelectV2OpModel model({1, 2}, {1, 2, 2}, {1, 2, 2}, TensorType_INT32);
diff --git a/tensorflow/lite/testing/generated_examples_zip_test.cc b/tensorflow/lite/testing/generated_examples_zip_test.cc
index 996d7aa..d046668 100644
--- a/tensorflow/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/lite/testing/generated_examples_zip_test.cc
@@ -66,6 +66,9 @@
// TODO(ahentz): make sure we clean this list up frequently.
const BrokenTestMap& GetKnownBrokenTests() {
static const BrokenTestMap* const kBrokenTests = new BrokenTestMap({
+ // Select kernel doesn't support broadcasting yet.
+ {R"(^\/where.*1,2,3,1)", {"134692786", false}},
+
// TODO(b/194364155): TF and TFLite have different behaviors when output
// nan values in LocalResponseNorm ops.
{R"(^\/local_response_norm.*alpha=-3.*beta=2)", {"194364155", true}},
diff --git a/tensorflow/lite/testing/op_tests/where_v2.py b/tensorflow/lite/testing/op_tests/where_v2.py
index 8b8e431..22f679b 100644
--- a/tensorflow/lite/testing/op_tests/where_v2.py
+++ b/tensorflow/lite/testing/op_tests/where_v2.py
@@ -82,13 +82,6 @@
"input_dtype": [tf.float32, tf.int32],
"input_shape_set": [([1, 2, 2], [1, 2]),],
},
- # Requires kernel supporting broadcasting for 5D case.
- {
- "condition_dtype": [tf.float32],
- "input_condition_shape": [[1, 1, 1, 1, 1]],
- "input_dtype": [tf.float32],
- "input_shape_set": [([], [1, None, 1, 2, 512])],
- },
]
def build_graph(parameters):
@@ -116,21 +109,16 @@
out = tf.where_v2(input_condition, input_value1, input_value2)
return [input_condition, input_value1, input_value2], [out]
- def build_input_shape(input_shape):
- return [1 if v is None else v for v in input_shape]
-
def build_inputs(parameters, sess, inputs, outputs):
input_condition = create_tensor_data(parameters["condition_dtype"],
parameters["input_condition_shape"])
input_value1 = None
input_value2 = None
if parameters["input_dtype"] is not None:
- input_value1 = create_tensor_data(
- parameters["input_dtype"],
- build_input_shape(parameters["input_shape_set"][0]))
- input_value2 = create_tensor_data(
- parameters["input_dtype"],
- build_input_shape(parameters["input_shape_set"][1]))
+ input_value1 = create_tensor_data(parameters["input_dtype"],
+ parameters["input_shape_set"][0])
+ input_value2 = create_tensor_data(parameters["input_dtype"],
+ parameters["input_shape_set"][1])
return [input_condition, input_value1, input_value2], sess.run(
outputs,
feed_dict=dict(
diff --git a/tensorflow/lite/toco/tflite/operator_test.cc b/tensorflow/lite/toco/tflite/operator_test.cc
index 41bf297..5eec312 100644
--- a/tensorflow/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/lite/toco/tflite/operator_test.cc
@@ -1085,33 +1085,6 @@
EXPECT_EQ(base_op->GetVersion(signature), version);
}
-template <typename OpType>
-void SimpleThreeInputsVersioningTest(ArrayDataType data_type, Shape shape1,
- Shape shape2, Shape shape3, int version) {
- OpType op;
- op.inputs = {"input1", "input2", "input3"};
- op.outputs = {"output"};
- auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
- const BaseOperator* base_op = operator_by_type_map.at(op.type).get();
-
- Model model;
- Array& input0 = model.GetOrCreateArray(op.inputs[0]);
- Array& input1 = model.GetOrCreateArray(op.inputs[1]);
- Array& input2 = model.GetOrCreateArray(op.inputs[2]);
- Array& output = model.GetOrCreateArray(op.outputs[0]);
-
- input0.data_type = data_type;
- input0.copy_shape(shape1);
- input1.data_type = data_type;
- input1.copy_shape(shape2);
- input2.data_type = data_type;
- input2.copy_shape(shape3);
- output.data_type = data_type;
-
- OperatorSignature signature = {.op = &op, .model = &model};
- EXPECT_EQ(base_op->GetVersion(signature), version);
-}
-
TEST_F(OperatorTest, VersioningSubTest) {
SimpleTwoInputsVersioningTest<SubOperator>(ArrayDataType::kUint8,
{1, 2, 2, 2}, {1, 2, 2, 2}, 1);
@@ -1153,13 +1126,7 @@
}
TEST_F(OperatorTest, VersioningSelectTest) {
- SimpleThreeInputsVersioningTest<SelectOperator>(
- ArrayDataType::kUint8, {1, 2, 2, 2}, {1, 2, 2, 1}, {1, 2, 2, 1}, 1);
- SimpleThreeInputsVersioningTest<SelectOperator>(
- ArrayDataType::kInt8, {1, 2, 2, 2}, {1, 2, 2, 1}, {1, 2, 2, 1}, 2);
- SimpleThreeInputsVersioningTest<SelectOperator>(
- ArrayDataType::kInt8, {1, 2, 2, 2, 1}, {1, 2, 2, 1, 1}, {1, 2, 2, 1, 1},
- 3);
+ SimpleVersioningTest<SelectOperator>();
}
TEST_F(OperatorTest, VersioningRelu6Test) {
diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc
index b11a133..2a25ce8 100644
--- a/tensorflow/lite/tools/versioning/op_version.cc
+++ b/tensorflow/lite/tools/versioning/op_version.cc
@@ -776,16 +776,6 @@
}
return 1;
- case BuiltinOperator_SELECT: {
- if (op_sig.inputs.at(0).dims.size() == 5 ||
- op_sig.inputs.at(1).dims.size() == 5 ||
- op_sig.inputs.at(2).dims.size() == 5)
- return 3;
- if (op_sig.inputs.at(0).type == kTfLiteInt8) {
- return 2;
- }
- return 1;
- }
case BuiltinOperator_SPACE_TO_DEPTH:
case BuiltinOperator_SPLIT_V:
case BuiltinOperator_SUM:
@@ -795,6 +785,7 @@
case BuiltinOperator_GREATER_EQUAL:
case BuiltinOperator_LESS:
case BuiltinOperator_LESS_EQUAL:
+ case BuiltinOperator_SELECT:
case BuiltinOperator_RSQRT:
case BuiltinOperator_SQUARED_DIFFERENCE:
case BuiltinOperator_DEPTH_TO_SPACE:
diff --git a/tensorflow/lite/tools/versioning/op_version_test.cc b/tensorflow/lite/tools/versioning/op_version_test.cc
index 588fbda..119d439 100644
--- a/tensorflow/lite/tools/versioning/op_version_test.cc
+++ b/tensorflow/lite/tools/versioning/op_version_test.cc
@@ -36,22 +36,6 @@
return tensor_specs;
}
-// Creates vector of OpSignatureTensorSpec with the given TfLiteType vector,
-// each with rank 'rank'
-std::vector<OpSignatureTensorSpec> CreateOpSignatureTensorSpecs(
- const std::vector<TfLiteType>& types, int rank) {
- std::vector<OpSignatureTensorSpec> tensor_specs;
- for (auto type : types) {
- OpSignatureTensorSpec tensor_spec = {};
- tensor_spec.type = type;
- for (int i = 0; i < rank; i++) {
- tensor_spec.dims.push_back(4);
- }
- tensor_specs.push_back(tensor_spec);
- }
- return tensor_specs;
-}
-
// Creates vector of OpSignatureTensorSpec of single tensor spec of TfLiteType.
std::vector<OpSignatureTensorSpec> CreateOpSignatureTensorSpecs(
const TfLiteType type) {
@@ -507,29 +491,7 @@
}
TEST(OpVersionTest, VersioningSelectTest) {
- OpSignature fake_op_sig = {
- .op = BuiltinOperator_SELECT,
- .inputs = CreateOpSignatureTensorSpecs(
- std::vector<TfLiteType>{kTfLiteUInt8, kTfLiteUInt8, kTfLiteUInt8}, 5),
- .outputs = CreateOpSignatureTensorSpecs(kTfLiteUInt8),
- };
- EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
- fake_op_sig = {
- .op = BuiltinOperator_SELECT,
- .inputs = CreateOpSignatureTensorSpecs(
- std::vector<TfLiteType>{kTfLiteInt8, kTfLiteInt8, kTfLiteInt8}, 4),
- .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8),
- };
- EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
- fake_op_sig = {
- .op = BuiltinOperator_SELECT,
- .inputs = CreateOpSignatureTensorSpecs(
- std::vector<TfLiteType>{kTfLiteFloat32, kTfLiteFloat32,
- kTfLiteFloat32},
- 4),
- .outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32),
- };
- EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
+ SimpleVersioningTest(BuiltinOperator_SELECT);
}
TEST(OpVersionTest, VersioningRelu6Test) {