Automated g4 rollback of changelist 321508374.
PiperOrigin-RevId: 321601571
Change-Id: If3b349d65d9030dca3a6ddaddc78d05ea9845a05
diff --git a/tensorflow/lite/experimental/writer/option_writer_generator.cc b/tensorflow/lite/experimental/writer/option_writer_generator.cc
index e484c5b..a565422 100644
--- a/tensorflow/lite/experimental/writer/option_writer_generator.cc
+++ b/tensorflow/lite/experimental/writer/option_writer_generator.cc
@@ -265,32 +265,6 @@
" }\n break;\n");
}
-// Reshape Op infers output shape either from Parameter or from shape tensor
-// that's is an additional input. When we have this additional shape tensor as
-// input we don't have the parameter present in this layer. In case of more than
-// one input we import an empty vector for the parameters.
-void GenerateImportForReshapeOp(FILE* fp) {
- fprintf(fp,
- " case BuiltinOperator_RESHAPE: {\n"
- " const auto* params = reinterpret_cast<const "
- "TfLiteReshapeParams*>(builtin_op_data);\n"
- " flatbuffers::Offset<void> union_type;\n"
- " if ((node.inputs->size > 1) &&\n"
- " (params->num_dimensions < 0 ||\n"
- " params->num_dimensions >= "
- "TFLITE_RESHAPE_PARAMS_MAX_DIMENSION_COUNT)) {\n"
- " union_type = CreateReshapeOptions(*fbb).Union();\n"
- " } else {\n"
- " auto val0 = fbb->CreateVector(std::vector<int>(params->shape, "
- "params->shape + params->num_dimensions));\n"
- " union_type = CreateReshapeOptions(*fbb, "
- "val0).Union();\n"
- " }\n"
- " return std::make_pair(BuiltinOptions_ReshapeOptions, "
- "union_type);\n"
- " }\n break;\n");
-}
-
void GenerateImportForOp(FILE* fp, const std::string& op_name,
const std::string& option_name,
const std::string& option_type,
@@ -302,13 +276,6 @@
return;
}
- // Special case Reshape that may have 'new_shape' field missing from the
- // parameters.
- if (struct_name == "TfLiteReshapeParams") {
- GenerateImportForReshapeOp(fp);
- return;
- }
-
fprintf(fp, " case BuiltinOperator_%s: {\n", op_name.c_str());
if (options->num_elems != 0) {
fprintf(fp,
diff --git a/tensorflow/lite/experimental/writer/writer_lib.cc b/tensorflow/lite/experimental/writer/writer_lib.cc
index 2c71919..85f5752 100644
--- a/tensorflow/lite/experimental/writer/writer_lib.cc
+++ b/tensorflow/lite/experimental/writer/writer_lib.cc
@@ -31,7 +31,7 @@
std::pair<BuiltinOptions, flatbuffers::Offset<void>> CreateBuiltinUnion(
flatbuffers::FlatBufferBuilder* fbb, enum BuiltinOperator op,
- void* builtin_op_data, const TfLiteNode& node) {
+ void* builtin_op_data) {
switch (op) {
#include "tensorflow/lite/experimental/writer/option_writer_generated.h"
}
@@ -82,7 +82,7 @@
// builtin
auto builtin_options_and_type = CreateBuiltinUnion(
fbb, static_cast<enum BuiltinOperator>(registration.builtin_code),
- node.builtin_data, node);
+ node.builtin_data);
builtin_options = builtin_options_and_type.second;
builtin_options_type = builtin_options_and_type.first;
} else {
diff --git a/tensorflow/lite/experimental/writer/writer_lib_test.cc b/tensorflow/lite/experimental/writer/writer_lib_test.cc
index 4cab27e..41cca88 100644
--- a/tensorflow/lite/experimental/writer/writer_lib_test.cc
+++ b/tensorflow/lite/experimental/writer/writer_lib_test.cc
@@ -15,8 +15,6 @@
#include "tensorflow/lite/experimental/writer/writer_lib.h"
-#include <numeric>
-
#include <gtest/gtest.h>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/interpreter.h"
@@ -186,79 +184,6 @@
CHECK_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);
}
-struct ReshapeTestPattern {
- int num_inputs;
- bool is_param_valid;
-};
-
-class ReshapeLayerTest : public ::testing::TestWithParam<ReshapeTestPattern> {};
-
-TEST_P(ReshapeLayerTest, ReshapeLayerTest) {
- const auto param = GetParam();
- Interpreter interpreter;
- const int total_tensors = param.num_inputs + 1;
- interpreter.AddTensors(total_tensors);
- int output_shape[] = {1, 2, 3};
- interpreter.SetTensorParametersReadWrite(/*tensor_index=*/0, kTfLiteFloat32,
- /*name=*/"a", /*dims=*/{6},
- TfLiteQuantization());
- ASSERT_LE(param.num_inputs, 2);
- if (param.num_inputs == 2) {
- interpreter.SetTensorParametersReadOnly(
- /*tensor_index=*/1, kTfLiteInt32, /*name=*/"b", /*dims=*/{3},
- TfLiteQuantization(), reinterpret_cast<char*>(output_shape),
- sizeof(output_shape));
- }
- interpreter.SetTensorParametersReadWrite(/*tensor_index=*/total_tensors - 1,
- kTfLiteFloat32, /*name=*/"c",
- /*dims=*/{3}, TfLiteQuantization());
-
- std::vector<int> input_tensors(param.num_inputs);
- std::iota(input_tensors.begin(), input_tensors.end(), 0);
-
- interpreter.SetInputs(input_tensors);
- interpreter.SetOutputs({total_tensors - 1});
- const char* initial_data = "";
- tflite::ops::builtin::BuiltinOpResolver resolver;
- TfLiteReshapeParams* builtin_data = reinterpret_cast<TfLiteReshapeParams*>(
- malloc(sizeof(TfLiteReshapeParams)));
- if (param.is_param_valid) {
- builtin_data->num_dimensions = 3;
- for (int dim = 0; dim < builtin_data->num_dimensions; ++dim) {
- builtin_data->shape[dim] = output_shape[dim];
- }
- }
- const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_RESHAPE, 1);
- interpreter.AddNodeWithParameters(input_tensors,
- /*outputs=*/{total_tensors - 1},
- initial_data, /*init_data_size=*/0,
- reinterpret_cast<void*>(builtin_data), reg);
-
- SubgraphWriter writer(&interpreter.primary_subgraph());
- std::string filename = absl::StrCat("/tmp/test_reshape_", param.num_inputs,
- "_", param.is_param_valid, ".tflite");
- writer.Write(filename);
- std::unique_ptr<FlatBufferModel> model =
- FlatBufferModel::BuildFromFile(filename.c_str());
- InterpreterBuilder builder(*model, resolver);
- std::unique_ptr<Interpreter> new_interpreter;
- builder(&new_interpreter);
- ASSERT_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);
-}
-
-INSTANTIATE_TEST_SUITE_P(
- Writer, ReshapeLayerTest,
- ::testing::Values(ReshapeTestPattern{/*num_inputs=*/2,
- /*is_param_valid=*/true},
- ReshapeTestPattern{/*num_inputs=*/2,
- /*is_param_valid=*/false},
- ReshapeTestPattern{/*num_inputs=*/1,
- /*is_param_valid=*/true}),
- [](const ::testing::TestParamInfo<ReshapeLayerTest::ParamType>& info) {
- std::string name = absl::StrCat("num_inputs_", info.param.num_inputs,
- "_isvalid_", info.param.is_param_valid);
- return name;
- });
} // namespace tflite
int main(int argc, char** argv) {