Refactor test to enable support for other variants of LSTM.

PiperOrigin-RevId: 282605030
Change-Id: I407fce1f3a419d6180cf494c7387dc6cbf8389dd
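
The test now gives the LSTM op its full 24-entry input list, with -1 marking
optional inputs that are left out (here indices 9-11, the peephole weights),
and creates one float tensor per input that is actually present. Covering
another LSTM variant then only requires supplying a different input list. As
a rough sketch (not part of this change, and assuming positions 16 and 17
hold the optional projection weights and bias), a no-projection variant could
look like:

  // Hypothetical input list for a variant test: peephole and projection
  // inputs omitted (-1); the tensor-creation loop in the diff below skips
  // the -1 entries automatically.
  lstm_op->inputs = {0, 1,  2,  3,  4,  5,  6,  7,  8,  -1, -1, -1,
                     9, 10, 11, 12, -1, -1, 13, 14, 15, 16, 17, 18};
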
diff --git a/tensorflow/lite/tools/optimize/quantization_wrapper_utils_test.cc b/tensorflow/lite/tools/optimize/quantization_wrapper_utils_test.cc
index 704a711..d3836e3 100644
--- a/tensorflow/lite/tools/optimize/quantization_wrapper_utils_test.cc
+++ b/tensorflow/lite/tools/optimize/quantization_wrapper_utils_test.cc
@@ -30,23 +30,30 @@
   // Create a model with 1 lstm layer.
   auto model = absl::make_unique<ModelT>();
   auto subgraph = absl::make_unique<tflite::SubGraphT>();
-  auto tensor = absl::make_unique<TensorT>();
   auto buffer = absl::make_unique<tflite::BufferT>();
   auto lstm_op_code = absl::make_unique<OperatorCodeT>();
   auto lstm_op = absl::make_unique<OperatorT>();
 
-  tensor->name = "lstm_tensor0";
-  tensor->shape = {2, 3, 4};
-  tensor->type = TensorType_FLOAT32;
   lstm_op_code->builtin_code = BuiltinOperator_LSTM;
   lstm_op_code->version = 2;
   lstm_op->opcode_index = 0;
-  lstm_op->inputs = {0};
-  lstm_op->outputs = {0};
+  lstm_op->inputs = {0, 1,  2,  3,  4,  5,  6,  7,  8,  -1, -1, -1,
+                     9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20};
+  lstm_op->outputs = {24};
 
   model->subgraphs.push_back(std::move(subgraph));
+  for (int i = 0; i < lstm_op->inputs.size(); ++i) {
+    const int index = lstm_op->inputs[i];
+    if (index == -1) {
+      continue;
+    }
+    auto tensor = absl::make_unique<TensorT>();
+    tensor->name = "lstm_tensor" + std::to_string(index);
+    tensor->shape = {2, 3, 4};
+    tensor->type = TensorType_FLOAT32;
+    model->subgraphs[0]->tensors.push_back(std::move(tensor));
+  }
   model->subgraphs[0]->operators.push_back(std::move(lstm_op));
-  model->subgraphs[0]->tensors.push_back(std::move(tensor));
   model->operator_codes.push_back(std::move(lstm_op_code));
   model->buffers.push_back(std::move(buffer));
 
@@ -58,21 +65,24 @@
   EXPECT_EQ(model->operator_codes.size(), 1);
   EXPECT_EQ(model->subgraphs.size(), 1);
   EXPECT_EQ(model->subgraphs[0]->operators.size(), 1);
-  EXPECT_EQ(model->subgraphs[0]->tensors.size(), 6);
+  EXPECT_EQ(model->subgraphs[0]->tensors.size(), 26);
   EXPECT_EQ(model->buffers.size(), 1);
 
   EXPECT_EQ(model->operator_codes[0]->builtin_code, BuiltinOperator_LSTM);
   EXPECT_EQ(model->subgraphs[0]->tensors[0]->name, "lstm_tensor0");
-  EXPECT_EQ(model->subgraphs[0]->tensors[1]->name, "intermediate_0_0");
-  EXPECT_EQ(model->subgraphs[0]->tensors[2]->name, "intermediate_0_1");
-  EXPECT_EQ(model->subgraphs[0]->tensors[3]->name, "intermediate_0_2");
-  EXPECT_EQ(model->subgraphs[0]->tensors[4]->name, "intermediate_0_3");
-  EXPECT_EQ(model->subgraphs[0]->tensors[5]->name, "intermediate_0_4");
-  EXPECT_THAT(model->subgraphs[0]->operators[0]->inputs, ElementsAreArray({0}));
+  EXPECT_EQ(model->subgraphs[0]->tensors[21]->name, "intermediate_0_0");
+  EXPECT_EQ(model->subgraphs[0]->tensors[22]->name, "intermediate_0_1");
+  EXPECT_EQ(model->subgraphs[0]->tensors[23]->name, "intermediate_0_2");
+  EXPECT_EQ(model->subgraphs[0]->tensors[24]->name, "intermediate_0_3");
+  EXPECT_EQ(model->subgraphs[0]->tensors[25]->name, "intermediate_0_4");
+  EXPECT_THAT(
+      model->subgraphs[0]->operators[0]->inputs,
+      ElementsAreArray({0, 1,  2,  3,  4,  5,  6,  7,  8,  -1, -1, -1,
+                        9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}));
   EXPECT_THAT(model->subgraphs[0]->operators[0]->outputs,
-              ElementsAreArray({0}));
+              ElementsAreArray({24}));
   EXPECT_THAT(model->subgraphs[0]->operators[0]->intermediates,
-              ElementsAreArray({1, 2, 3, 4, 5}));
+              ElementsAreArray({21, 22, 23, 24, 25}));
 
   // Call AddIntemediateTensorsToFusedOp again and expect no change in model.
   tflite::optimize::AddIntemediateTensorsToFusedOp(&builder, model.get());
@@ -81,21 +91,24 @@
   EXPECT_EQ(model->operator_codes.size(), 1);
   EXPECT_EQ(model->subgraphs.size(), 1);
   EXPECT_EQ(model->subgraphs[0]->operators.size(), 1);
-  EXPECT_EQ(model->subgraphs[0]->tensors.size(), 6);
+  EXPECT_EQ(model->subgraphs[0]->tensors.size(), 26);
   EXPECT_EQ(model->buffers.size(), 1);
 
   EXPECT_EQ(model->operator_codes[0]->builtin_code, BuiltinOperator_LSTM);
   EXPECT_EQ(model->subgraphs[0]->tensors[0]->name, "lstm_tensor0");
-  EXPECT_EQ(model->subgraphs[0]->tensors[1]->name, "intermediate_0_0");
-  EXPECT_EQ(model->subgraphs[0]->tensors[2]->name, "intermediate_0_1");
-  EXPECT_EQ(model->subgraphs[0]->tensors[3]->name, "intermediate_0_2");
-  EXPECT_EQ(model->subgraphs[0]->tensors[4]->name, "intermediate_0_3");
-  EXPECT_EQ(model->subgraphs[0]->tensors[5]->name, "intermediate_0_4");
-  EXPECT_THAT(model->subgraphs[0]->operators[0]->inputs, ElementsAreArray({0}));
+  EXPECT_EQ(model->subgraphs[0]->tensors[21]->name, "intermediate_0_0");
+  EXPECT_EQ(model->subgraphs[0]->tensors[22]->name, "intermediate_0_1");
+  EXPECT_EQ(model->subgraphs[0]->tensors[23]->name, "intermediate_0_2");
+  EXPECT_EQ(model->subgraphs[0]->tensors[24]->name, "intermediate_0_3");
+  EXPECT_EQ(model->subgraphs[0]->tensors[25]->name, "intermediate_0_4");
+  EXPECT_THAT(
+      model->subgraphs[0]->operators[0]->inputs,
+      ElementsAreArray({0, 1,  2,  3,  4,  5,  6,  7,  8,  -1, -1, -1,
+                        9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}));
   EXPECT_THAT(model->subgraphs[0]->operators[0]->outputs,
-              ElementsAreArray({0}));
+              ElementsAreArray({24}));
   EXPECT_THAT(model->subgraphs[0]->operators[0]->intermediates,
-              ElementsAreArray({1, 2, 3, 4, 5}));
+              ElementsAreArray({21, 22, 23, 24, 25}));
 }
 
 }  // namespace