Fix incorrect index while setting unique output
PiperOrigin-RevId: 294708167
Change-Id: I6a99bb5fadff6a1ce5cd9584cff2e9a0e6be379b
diff --git a/tensorflow/lite/kernels/unique.cc b/tensorflow/lite/kernels/unique.cc
index ea0639c..d0d277e 100644
--- a/tensorflow/lite/kernels/unique.cc
+++ b/tensorflow/lite/kernels/unique.cc
@@ -65,6 +65,7 @@
// increase in the binary size.
std::map<T, int> unique_values;
TfLiteTensor* output_indexes = GetOutput(context, node, 1);
+ std::vector<T> output_values;
I* indexes = GetTensorData<I>(output_indexes);
const T* data = GetTensorData<T>(input);
const int num_elements = NumElements(input);
@@ -77,6 +78,7 @@
const int unique_index = unique_values.size();
unique_values[data[i]] = unique_index;
indexes[i] = unique_index;
+ output_values.push_back(data[i]);
}
}
// Allocate output tensor.
@@ -88,8 +90,8 @@
context->ResizeTensor(context, unique_output, shape.release()));
// Set the values in the output tensor.
T* output_unique_values = GetTensorData<T>(unique_output);
- for (int i = 0; i < unique_values.size(); ++i) {
- output_unique_values[i] = data[indexes[i]];
+ for (int i = 0; i < output_values.size(); ++i) {
+ output_unique_values[i] = output_values[i];
}
return kTfLiteOk;
}
diff --git a/tensorflow/lite/kernels/unique_test.cc b/tensorflow/lite/kernels/unique_test.cc
index 1df5e6b..b18fcbe 100644
--- a/tensorflow/lite/kernels/unique_test.cc
+++ b/tensorflow/lite/kernels/unique_test.cc
@@ -89,6 +89,16 @@
ElementsAreArray({0, 1, 2, 3, 0, 3, 1}));
}
+TEST(UniqueOpModelTest, MultipleElements_RepeatedDuplicates) {
+ UniqueOpModel<float, int32_t> model({TensorType_FLOAT32, {6}},
+ TensorType_FLOAT32, TensorType_INT32);
+ model.PopulateTensor<float>(model.input_tensor_id(),
+ {-1, -1, -2, -2, -3, -3});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({-1, -2, -3}));
+ EXPECT_THAT(model.GetIndexesOutput(), ElementsAreArray({0, 0, 1, 1, 2, 2}));
+}
+
TEST(UniqueOpModelTest, MultipleElements_SomeDuplicates_IndexInt64) {
UniqueOpModel<float, int64_t> model({TensorType_FLOAT32, {7}},
TensorType_FLOAT32, TensorType_INT64);