[tflite] Fix import of 0 splat tensors

Splat tensors of 0 have their values omitted entirely by Grappler. The import needs to account for that.

PiperOrigin-RevId: 442555519
diff --git a/tensorflow/lite/toco/import_tensorflow.cc b/tensorflow/lite/toco/import_tensorflow.cc
index e849c97..23cc266 100644
--- a/tensorflow/lite/toco/import_tensorflow.cc
+++ b/tensorflow/lite/toco/import_tensorflow.cc
@@ -357,17 +357,18 @@
   } else if (input_tensor.tensor_content().size() ==
              input_flat_size * sizeof(T)) {
     TensorTraits<T>::CopyFromContent(input_tensor, output_data);
-  } else if (num_elements_in_tensor > 0 &&
+  } else if (num_elements_in_tensor >= 0 &&
              num_elements_in_tensor < input_flat_size) {
     // TODO(b/80208043): use tensorflow::Tensor::FromProto() which is the
     // official way to import tensor data. This particular else-if handles a
     // grappler optimization where the last few elements in a tensor are
-    // omitted if they are repeated.
+    // omitted if they are repeated, and where all elements are omitted if they
+    // are zero.
     int i = 0;
     for (; i < num_elements_in_tensor; ++i) {
       (*output_data)[i] = TensorTraits<T>::get(input_tensor, i);
     }
-    auto last = (*output_data)[i - 1];
+    auto last = i == 0 ? T(0) : (*output_data)[i - 1];
     for (; i < input_flat_size; ++i) {
       (*output_data)[i] = last;
     }
diff --git a/tensorflow/lite/toco/import_tensorflow_test.cc b/tensorflow/lite/toco/import_tensorflow_test.cc
index 8f34f87..31df791 100644
--- a/tensorflow/lite/toco/import_tensorflow_test.cc
+++ b/tensorflow/lite/toco/import_tensorflow_test.cc
@@ -291,17 +291,6 @@
             "Tensor shape is too large\n\t (while processing node 'Node1')");
 }
 
-TEST_P(ShapeImportTest, ValidShapeButZeroElements) {
-  NodeDef node;
-  BuildConstNode({1, 2, 2, 2}, GetParam(), 0, &node);
-  auto status = ImportNode(node);
-  EXPECT_THAT(status.error_message(),
-              ::testing::MatchesRegex(
-                  "Neither input_content .0. nor .*_val .0. have the right "
-                  "dimensions .8. for this .* tensor\n\t .while processing "
-                  "node 'Node1'."));
-}
-
 std::vector<tensorflow::DataType> TestTypes() {
   return {DT_FLOAT, DT_INT32, DT_INT64, DT_BOOL, DT_QUINT8, DT_COMPLEX64};
 }
@@ -344,6 +333,8 @@
   EXPECT_THAT(ImportAndGetData<kType>(node), ElementsAre(1, 2, 3, 4, 5, 5));
   RemoveTrailingElements(&node, 4);
   EXPECT_THAT(ImportAndGetData<kType>(node), ElementsAre(1, 1, 1, 1, 1, 1));
+  RemoveTrailingElements(&node, 1);
+  EXPECT_THAT(ImportAndGetData<kType>(node), ElementsAre(0, 0, 0, 0, 0, 0));
 }
 
 TEST_F(ContentImportTest, Int64) {
@@ -357,6 +348,8 @@
   EXPECT_THAT(ImportAndGetData<kType>(node), ElementsAre(1, 2, 3, 4, 5, 5));
   RemoveTrailingElements(&node, 4);
   EXPECT_THAT(ImportAndGetData<kType>(node), ElementsAre(1, 1, 1, 1, 1, 1));
+  RemoveTrailingElements(&node, 1);
+  EXPECT_THAT(ImportAndGetData<kType>(node), ElementsAre(0, 0, 0, 0, 0, 0));
 }
 
 TEST_F(ContentImportTest, Quint8) {
@@ -370,6 +363,8 @@
   EXPECT_THAT(ImportAndGetData<kType>(node), ElementsAre(1, 2, 3, 4, 5, 5));
   RemoveTrailingElements(&node, 4);
   EXPECT_THAT(ImportAndGetData<kType>(node), ElementsAre(1, 1, 1, 1, 1, 1));
+  RemoveTrailingElements(&node, 1);
+  EXPECT_THAT(ImportAndGetData<kType>(node), ElementsAre(0, 0, 0, 0, 0, 0));
 }
 
 TEST_F(ContentImportTest, Bool) {
@@ -383,6 +378,8 @@
   EXPECT_THAT(ImportAndGetData<kType>(node), ElementsAre(1, 0, 1, 0, 1, 1));
   RemoveTrailingElements(&node, 4);
   EXPECT_THAT(ImportAndGetData<kType>(node), ElementsAre(1, 1, 1, 1, 1, 1));
+  RemoveTrailingElements(&node, 1);
+  EXPECT_THAT(ImportAndGetData<kType>(node), ElementsAre(0, 0, 0, 0, 0, 0));
 }
 
 TEST_F(ContentImportTest, Float) {
@@ -399,6 +396,9 @@
   RemoveTrailingElements(&node, 4);
   EXPECT_THAT(ImportAndGetData<kType>(node),
               ElementsAre(1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000));
+  RemoveTrailingElements(&node, 1);
+  EXPECT_THAT(ImportAndGetData<kType>(node),
+              ElementsAre(0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000));
 }
 
 TEST_F(ContentImportTest, Complex64) {
@@ -426,6 +426,13 @@
       ElementsAre(std::complex<float>(1.0000, -1.0000), cplx(1.0000, -1.0000),
                   cplx(1.0000, -1.0000), cplx(1.0000, -1.0000),
                   cplx(1.0000, -1.0000), cplx(1.0000, -1.0000)));
+
+  RemoveTrailingElements(&node, 1);
+  EXPECT_THAT(
+      ImportAndGetData<kType>(node),
+      ElementsAre(std::complex<float>(0.0000, 0.0000), cplx(0.0000, 0.0000),
+                  cplx(0.0000, 0.0000), cplx(0.0000, 0.0000),
+                  cplx(0.0000, 0.0000), cplx(0.0000, 0.0000)));
 }
 
 std::vector<std::pair<tensorflow::DataType, ArrayDataType>> UnaryTestTypes() {