Refactor test to enable support of other variants of LSTM.
PiperOrigin-RevId: 282605030
Change-Id: I407fce1f3a419d6180cf494c7387dc6cbf8389dd
diff --git a/tensorflow/lite/tools/optimize/quantization_wrapper_utils_test.cc b/tensorflow/lite/tools/optimize/quantization_wrapper_utils_test.cc
index 704a711..d3836e3 100644
--- a/tensorflow/lite/tools/optimize/quantization_wrapper_utils_test.cc
+++ b/tensorflow/lite/tools/optimize/quantization_wrapper_utils_test.cc
@@ -30,23 +30,30 @@
// Create a model with 1 lstm layer.
auto model = absl::make_unique<ModelT>();
auto subgraph = absl::make_unique<tflite::SubGraphT>();
- auto tensor = absl::make_unique<TensorT>();
auto buffer = absl::make_unique<tflite::BufferT>();
auto lstm_op_code = absl::make_unique<OperatorCodeT>();
auto lstm_op = absl::make_unique<OperatorT>();
- tensor->name = "lstm_tensor0";
- tensor->shape = {2, 3, 4};
- tensor->type = TensorType_FLOAT32;
lstm_op_code->builtin_code = BuiltinOperator_LSTM;
lstm_op_code->version = 2;
lstm_op->opcode_index = 0;
- lstm_op->inputs = {0};
- lstm_op->outputs = {0};
+ lstm_op->inputs = {0, 1, 2, 3, 4, 5, 6, 7, 8, -1, -1, -1,
+ 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20};
+ lstm_op->outputs = {24};
model->subgraphs.push_back(std::move(subgraph));
+ for (int i = 0; i < lstm_op->inputs.size(); ++i) {
+ const int index = lstm_op->inputs[i];
+ if (index == -1) {
+ continue;
+ }
+ auto tensor = absl::make_unique<TensorT>();
+ tensor->name = "lstm_tensor" + std::to_string(index);
+ tensor->shape = {2, 3, 4};
+ tensor->type = TensorType_FLOAT32;
+ model->subgraphs[0]->tensors.push_back(std::move(tensor));
+ }
model->subgraphs[0]->operators.push_back(std::move(lstm_op));
- model->subgraphs[0]->tensors.push_back(std::move(tensor));
model->operator_codes.push_back(std::move(lstm_op_code));
model->buffers.push_back(std::move(buffer));
@@ -58,21 +65,24 @@
EXPECT_EQ(model->operator_codes.size(), 1);
EXPECT_EQ(model->subgraphs.size(), 1);
EXPECT_EQ(model->subgraphs[0]->operators.size(), 1);
- EXPECT_EQ(model->subgraphs[0]->tensors.size(), 6);
+ EXPECT_EQ(model->subgraphs[0]->tensors.size(), 26);
EXPECT_EQ(model->buffers.size(), 1);
EXPECT_EQ(model->operator_codes[0]->builtin_code, BuiltinOperator_LSTM);
EXPECT_EQ(model->subgraphs[0]->tensors[0]->name, "lstm_tensor0");
- EXPECT_EQ(model->subgraphs[0]->tensors[1]->name, "intermediate_0_0");
- EXPECT_EQ(model->subgraphs[0]->tensors[2]->name, "intermediate_0_1");
- EXPECT_EQ(model->subgraphs[0]->tensors[3]->name, "intermediate_0_2");
- EXPECT_EQ(model->subgraphs[0]->tensors[4]->name, "intermediate_0_3");
- EXPECT_EQ(model->subgraphs[0]->tensors[5]->name, "intermediate_0_4");
- EXPECT_THAT(model->subgraphs[0]->operators[0]->inputs, ElementsAreArray({0}));
+ EXPECT_EQ(model->subgraphs[0]->tensors[21]->name, "intermediate_0_0");
+ EXPECT_EQ(model->subgraphs[0]->tensors[22]->name, "intermediate_0_1");
+ EXPECT_EQ(model->subgraphs[0]->tensors[23]->name, "intermediate_0_2");
+ EXPECT_EQ(model->subgraphs[0]->tensors[24]->name, "intermediate_0_3");
+ EXPECT_EQ(model->subgraphs[0]->tensors[25]->name, "intermediate_0_4");
+ EXPECT_THAT(
+ model->subgraphs[0]->operators[0]->inputs,
+ ElementsAreArray({0, 1, 2, 3, 4, 5, 6, 7, 8, -1, -1, -1,
+ 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}));
EXPECT_THAT(model->subgraphs[0]->operators[0]->outputs,
- ElementsAreArray({0}));
+ ElementsAreArray({24}));
EXPECT_THAT(model->subgraphs[0]->operators[0]->intermediates,
- ElementsAreArray({1, 2, 3, 4, 5}));
+ ElementsAreArray({21, 22, 23, 24, 25}));
// Call AddIntemediateTensorsToFusedOp again and expect no change in model.
tflite::optimize::AddIntemediateTensorsToFusedOp(&builder, model.get());
@@ -81,21 +91,24 @@
EXPECT_EQ(model->operator_codes.size(), 1);
EXPECT_EQ(model->subgraphs.size(), 1);
EXPECT_EQ(model->subgraphs[0]->operators.size(), 1);
- EXPECT_EQ(model->subgraphs[0]->tensors.size(), 6);
+ EXPECT_EQ(model->subgraphs[0]->tensors.size(), 26);
EXPECT_EQ(model->buffers.size(), 1);
EXPECT_EQ(model->operator_codes[0]->builtin_code, BuiltinOperator_LSTM);
EXPECT_EQ(model->subgraphs[0]->tensors[0]->name, "lstm_tensor0");
- EXPECT_EQ(model->subgraphs[0]->tensors[1]->name, "intermediate_0_0");
- EXPECT_EQ(model->subgraphs[0]->tensors[2]->name, "intermediate_0_1");
- EXPECT_EQ(model->subgraphs[0]->tensors[3]->name, "intermediate_0_2");
- EXPECT_EQ(model->subgraphs[0]->tensors[4]->name, "intermediate_0_3");
- EXPECT_EQ(model->subgraphs[0]->tensors[5]->name, "intermediate_0_4");
- EXPECT_THAT(model->subgraphs[0]->operators[0]->inputs, ElementsAreArray({0}));
+ EXPECT_EQ(model->subgraphs[0]->tensors[21]->name, "intermediate_0_0");
+ EXPECT_EQ(model->subgraphs[0]->tensors[22]->name, "intermediate_0_1");
+ EXPECT_EQ(model->subgraphs[0]->tensors[23]->name, "intermediate_0_2");
+ EXPECT_EQ(model->subgraphs[0]->tensors[24]->name, "intermediate_0_3");
+ EXPECT_EQ(model->subgraphs[0]->tensors[25]->name, "intermediate_0_4");
+ EXPECT_THAT(
+ model->subgraphs[0]->operators[0]->inputs,
+ ElementsAreArray({0, 1, 2, 3, 4, 5, 6, 7, 8, -1, -1, -1,
+ 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}));
EXPECT_THAT(model->subgraphs[0]->operators[0]->outputs,
- ElementsAreArray({0}));
+ ElementsAreArray({24}));
EXPECT_THAT(model->subgraphs[0]->operators[0]->intermediates,
- ElementsAreArray({1, 2, 3, 4, 5}));
+ ElementsAreArray({21, 22, 23, 24, 25}));
}
} // namespace