[tflite] Fix import of 0 splat tensors
Splat tensors of 0 have their values omitted entirely by Grappler. The import needs to account for that.
PiperOrigin-RevId: 442555519
diff --git a/tensorflow/lite/toco/import_tensorflow.cc b/tensorflow/lite/toco/import_tensorflow.cc
index e849c97..23cc266 100644
--- a/tensorflow/lite/toco/import_tensorflow.cc
+++ b/tensorflow/lite/toco/import_tensorflow.cc
@@ -357,17 +357,18 @@
} else if (input_tensor.tensor_content().size() ==
input_flat_size * sizeof(T)) {
TensorTraits<T>::CopyFromContent(input_tensor, output_data);
- } else if (num_elements_in_tensor > 0 &&
+ } else if (num_elements_in_tensor >= 0 &&
num_elements_in_tensor < input_flat_size) {
// TODO(b/80208043): use tensorflow::Tensor::FromProto() which is the
// official way to import tensor data. This particular else-if handles a
// grappler optimization where the last few elements in a tensor are
- // omitted if they are repeated.
+ // omitted if they are repeated, and where all elements are omitted if they
+ // are zero.
int i = 0;
for (; i < num_elements_in_tensor; ++i) {
(*output_data)[i] = TensorTraits<T>::get(input_tensor, i);
}
- auto last = (*output_data)[i - 1];
+ auto last = i == 0 ? T(0) : (*output_data)[i - 1];
for (; i < input_flat_size; ++i) {
(*output_data)[i] = last;
}
diff --git a/tensorflow/lite/toco/import_tensorflow_test.cc b/tensorflow/lite/toco/import_tensorflow_test.cc
index 8f34f87..31df791 100644
--- a/tensorflow/lite/toco/import_tensorflow_test.cc
+++ b/tensorflow/lite/toco/import_tensorflow_test.cc
@@ -291,17 +291,6 @@
"Tensor shape is too large\n\t (while processing node 'Node1')");
}
-TEST_P(ShapeImportTest, ValidShapeButZeroElements) {
- NodeDef node;
- BuildConstNode({1, 2, 2, 2}, GetParam(), 0, &node);
- auto status = ImportNode(node);
- EXPECT_THAT(status.error_message(),
- ::testing::MatchesRegex(
- "Neither input_content .0. nor .*_val .0. have the right "
- "dimensions .8. for this .* tensor\n\t .while processing "
- "node 'Node1'."));
-}
-
std::vector<tensorflow::DataType> TestTypes() {
return {DT_FLOAT, DT_INT32, DT_INT64, DT_BOOL, DT_QUINT8, DT_COMPLEX64};
}
@@ -344,6 +333,8 @@
EXPECT_THAT(ImportAndGetData<kType>(node), ElementsAre(1, 2, 3, 4, 5, 5));
RemoveTrailingElements(&node, 4);
EXPECT_THAT(ImportAndGetData<kType>(node), ElementsAre(1, 1, 1, 1, 1, 1));
+ RemoveTrailingElements(&node, 1);
+ EXPECT_THAT(ImportAndGetData<kType>(node), ElementsAre(0, 0, 0, 0, 0, 0));
}
TEST_F(ContentImportTest, Int64) {
@@ -357,6 +348,8 @@
EXPECT_THAT(ImportAndGetData<kType>(node), ElementsAre(1, 2, 3, 4, 5, 5));
RemoveTrailingElements(&node, 4);
EXPECT_THAT(ImportAndGetData<kType>(node), ElementsAre(1, 1, 1, 1, 1, 1));
+ RemoveTrailingElements(&node, 1);
+ EXPECT_THAT(ImportAndGetData<kType>(node), ElementsAre(0, 0, 0, 0, 0, 0));
}
TEST_F(ContentImportTest, Quint8) {
@@ -370,6 +363,8 @@
EXPECT_THAT(ImportAndGetData<kType>(node), ElementsAre(1, 2, 3, 4, 5, 5));
RemoveTrailingElements(&node, 4);
EXPECT_THAT(ImportAndGetData<kType>(node), ElementsAre(1, 1, 1, 1, 1, 1));
+ RemoveTrailingElements(&node, 1);
+ EXPECT_THAT(ImportAndGetData<kType>(node), ElementsAre(0, 0, 0, 0, 0, 0));
}
TEST_F(ContentImportTest, Bool) {
@@ -383,6 +378,8 @@
EXPECT_THAT(ImportAndGetData<kType>(node), ElementsAre(1, 0, 1, 0, 1, 1));
RemoveTrailingElements(&node, 4);
EXPECT_THAT(ImportAndGetData<kType>(node), ElementsAre(1, 1, 1, 1, 1, 1));
+ RemoveTrailingElements(&node, 1);
+ EXPECT_THAT(ImportAndGetData<kType>(node), ElementsAre(0, 0, 0, 0, 0, 0));
}
TEST_F(ContentImportTest, Float) {
@@ -399,6 +396,9 @@
RemoveTrailingElements(&node, 4);
EXPECT_THAT(ImportAndGetData<kType>(node),
ElementsAre(1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000));
+ RemoveTrailingElements(&node, 1);
+ EXPECT_THAT(ImportAndGetData<kType>(node),
+ ElementsAre(0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000));
}
TEST_F(ContentImportTest, Complex64) {
@@ -426,6 +426,13 @@
ElementsAre(std::complex<float>(1.0000, -1.0000), cplx(1.0000, -1.0000),
cplx(1.0000, -1.0000), cplx(1.0000, -1.0000),
cplx(1.0000, -1.0000), cplx(1.0000, -1.0000)));
+
+ RemoveTrailingElements(&node, 1);
+ EXPECT_THAT(
+ ImportAndGetData<kType>(node),
+ ElementsAre(std::complex<float>(0.0000, 0.0000), cplx(0.0000, 0.0000),
+ cplx(0.0000, 0.0000), cplx(0.0000, 0.0000),
+ cplx(0.0000, 0.0000), cplx(0.0000, 0.0000)));
}
std::vector<std::pair<tensorflow::DataType, ArrayDataType>> UnaryTestTypes() {