| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/lite/tools/list_flex_ops.h" |
| |
| #include <cstdint> |
| |
| #include <gmock/gmock.h> |
| #include <gtest/gtest.h> |
| #include "flatbuffers/flexbuffers.h" // from @flatbuffers |
| #include "tensorflow/core/framework/node_def.pb.h" |
| #include "tensorflow/core/platform/protobuf.h" |
| #include "tensorflow/core/platform/resource_loader.h" |
| #include "tensorflow/lite/kernels/test_util.h" |
| |
| namespace tflite { |
| namespace flex { |
| |
| class FlexOpsListTest : public ::testing::Test { |
| protected: |
| FlexOpsListTest() {} |
| |
| void ReadOps(const string& path) { |
| std::string full_path = tensorflow::GetDataDependencyFilepath(path); |
| auto model = FlatBufferModel::BuildFromFile(full_path.data()); |
| AddFlexOpsFromModel(model->GetModel(), &flex_ops_); |
| output_text_ = OpListToJSONString(flex_ops_); |
| } |
| |
| void ReadOps(const tflite::Model* model) { |
| AddFlexOpsFromModel(model, &flex_ops_); |
| output_text_ = OpListToJSONString(flex_ops_); |
| } |
| |
| std::string output_text_; |
| OpKernelSet flex_ops_; |
| }; |
| |
| TfLiteRegistration* Register_TEST() { |
| static TfLiteRegistration r = {nullptr, nullptr, nullptr, nullptr}; |
| return &r; |
| } |
| |
| std::vector<uint8_t> CreateFlexCustomOptions(std::string nodedef_raw_string) { |
| tensorflow::NodeDef node_def; |
| tensorflow::protobuf::TextFormat::ParseFromString(nodedef_raw_string, |
| &node_def); |
| std::string node_def_str = node_def.SerializeAsString(); |
| auto flex_builder = std::make_unique<flexbuffers::Builder>(); |
| flex_builder->Vector([&]() { |
| flex_builder->String(node_def.op()); |
| flex_builder->String(node_def_str); |
| }); |
| flex_builder->Finish(); |
| return flex_builder->GetBuffer(); |
| } |
| |
| class FlexOpModel : public SingleOpModel { |
| public: |
| FlexOpModel(const std::string& op_name, const TensorData& input1, |
| const TensorData& input2, const TensorType& output, |
| const std::vector<uint8_t>& custom_options) { |
| input1_ = AddInput(input1); |
| input2_ = AddInput(input2); |
| output_ = AddOutput(output); |
| SetCustomOp(op_name, custom_options, Register_TEST); |
| BuildInterpreter({GetShape(input1_), GetShape(input2_)}); |
| } |
| |
| protected: |
| int input1_; |
| int input2_; |
| int output_; |
| }; |
| |
| TEST_F(FlexOpsListTest, TestModelsNoFlex) { |
| ReadOps("tensorflow/lite/testdata/test_model.bin"); |
| EXPECT_EQ(output_text_, "[]"); |
| } |
| |
| TEST_F(FlexOpsListTest, TestBrokenModel) { |
| EXPECT_DEATH_IF_SUPPORTED( |
| ReadOps("tensorflow/lite/testdata/test_model_broken.bin"), ""); |
| } |
| |
| TEST_F(FlexOpsListTest, TestZeroSubgraphs) { |
| ReadOps("tensorflow/lite/testdata/0_subgraphs.bin"); |
| EXPECT_EQ(output_text_, "[]"); |
| } |
| |
| TEST_F(FlexOpsListTest, TestFlexAdd) { |
| ReadOps("tensorflow/lite/testdata/multi_add_flex.bin"); |
| EXPECT_EQ(output_text_, |
| "[[\"Add\", \"BinaryOp<CPUDevice, functor::add<float>>\"]]"); |
| } |
| |
| TEST_F(FlexOpsListTest, TestTwoModel) { |
| ReadOps("tensorflow/lite/testdata/multi_add_flex.bin"); |
| ReadOps("tensorflow/lite/testdata/softplus_flex.bin"); |
| EXPECT_EQ(output_text_, |
| "[[\"Add\", \"BinaryOp<CPUDevice, " |
| "functor::add<float>>\"],\n[\"Softplus\", \"SoftplusOp<CPUDevice, " |
| "float>\"]]"); |
| } |
| |
| TEST_F(FlexOpsListTest, TestDuplicatedOp) { |
| ReadOps("tensorflow/lite/testdata/multi_add_flex.bin"); |
| ReadOps("tensorflow/lite/testdata/multi_add_flex.bin"); |
| EXPECT_EQ(output_text_, |
| "[[\"Add\", \"BinaryOp<CPUDevice, functor::add<float>>\"]]"); |
| } |
| |
| TEST_F(FlexOpsListTest, TestInvalidCustomOptions) { |
| // Using a invalid custom options, expected to fail. |
| std::vector<uint8_t> random_custom_options(20); |
| FlexOpModel max_model("FlexAdd", {TensorType_FLOAT32, {3, 1, 2, 2}}, |
| {TensorType_FLOAT32, {3, 1, 2, 1}}, TensorType_FLOAT32, |
| random_custom_options); |
| EXPECT_DEATH_IF_SUPPORTED( |
| ReadOps(tflite::GetModel(max_model.GetModelBuffer())), |
| "Failed to parse data into a valid NodeDef"); |
| } |
| |
| TEST_F(FlexOpsListTest, TestOpNameEmpty) { |
| // NodeDef with empty opname. |
| std::string nodedef_raw_str = |
| "name: \"node_1\"" |
| "op: \"\"" |
| "input: [ \"b\", \"c\" ]" |
| "attr: { key: \"T\" value: { type: DT_FLOAT } }"; |
| std::string random_fieldname = "random string"; |
| FlexOpModel max_model("FlexAdd", {TensorType_FLOAT32, {3, 1, 2, 2}}, |
| {TensorType_FLOAT32, {3, 1, 2, 1}}, TensorType_FLOAT32, |
| CreateFlexCustomOptions(nodedef_raw_str)); |
| EXPECT_DEATH_IF_SUPPORTED( |
| ReadOps(tflite::GetModel(max_model.GetModelBuffer())), "Invalid NodeDef"); |
| } |
| |
| TEST_F(FlexOpsListTest, TestOpNotFound) { |
| // NodeDef with invalid opname. |
| std::string nodedef_raw_str = |
| "name: \"node_1\"" |
| "op: \"FlexInvalidOp\"" |
| "input: [ \"b\", \"c\" ]" |
| "attr: { key: \"T\" value: { type: DT_FLOAT } }"; |
| |
| FlexOpModel max_model("FlexAdd", {TensorType_FLOAT32, {3, 1, 2, 2}}, |
| {TensorType_FLOAT32, {3, 1, 2, 1}}, TensorType_FLOAT32, |
| CreateFlexCustomOptions(nodedef_raw_str)); |
| EXPECT_DEATH_IF_SUPPORTED( |
| ReadOps(tflite::GetModel(max_model.GetModelBuffer())), |
| "Op FlexInvalidOp not found"); |
| } |
| |
| TEST_F(FlexOpsListTest, TestKernelNotFound) { |
| // NodeDef with non-supported type. |
| std::string nodedef_raw_str = |
| "name: \"node_1\"" |
| "op: \"Add\"" |
| "input: [ \"b\", \"c\" ]" |
| "attr: { key: \"T\" value: { type: DT_BOOL } }"; |
| |
| FlexOpModel max_model("FlexAdd", {TensorType_FLOAT32, {3, 1, 2, 2}}, |
| {TensorType_FLOAT32, {3, 1, 2, 1}}, TensorType_FLOAT32, |
| CreateFlexCustomOptions(nodedef_raw_str)); |
| EXPECT_DEATH_IF_SUPPORTED( |
| ReadOps(tflite::GetModel(max_model.GetModelBuffer())), |
| "Failed to find kernel class for op: Add"); |
| } |
| |
| TEST_F(FlexOpsListTest, TestFlexAddWithSingleOpModel) { |
| std::string nodedef_raw_str = |
| "name: \"node_1\"" |
| "op: \"Add\"" |
| "input: [ \"b\", \"c\" ]" |
| "attr: { key: \"T\" value: { type: DT_FLOAT } }"; |
| |
| FlexOpModel max_model("FlexAdd", {TensorType_FLOAT32, {3, 1, 2, 2}}, |
| {TensorType_FLOAT32, {3, 1, 2, 1}}, TensorType_FLOAT32, |
| CreateFlexCustomOptions(nodedef_raw_str)); |
| ReadOps(tflite::GetModel(max_model.GetModelBuffer())); |
| EXPECT_EQ(output_text_, |
| "[[\"Add\", \"BinaryOp<CPUDevice, functor::add<float>>\"]]"); |
| } |
| } // namespace flex |
| } // namespace tflite |
| |
| int main(int argc, char** argv) { |
| // On Linux, add: FLAGS_logtostderr = true; |
| ::testing::InitGoogleTest(&argc, argv); |
| return RUN_ALL_TESTS(); |
| } |