Add a new overload of tflite::Verify that doesn't take an OpResolver,
and deprecate the existing tflite::Verify function.
PiperOrigin-RevId: 365904539
Change-Id: I41f5ee929befd64a8b804761d044f8dc9217fc35
diff --git a/tensorflow/lite/interpreter.h b/tensorflow/lite/interpreter.h
index 0329766..bbf1358 100644
--- a/tensorflow/lite/interpreter.h
+++ b/tensorflow/lite/interpreter.h
@@ -75,7 +75,9 @@
/// if (InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) {
/// // Return failure.
/// }
-/// interpreter->AllocateTensors();
+/// if (interpreter->AllocateTensors() != kTfLiteOk) {
+/// // Return failure.
+/// }
///
/// auto input = interpreter->typed_tensor<float>(0);
/// for (int i = 0; i < input_size; i++) {
@@ -428,7 +430,9 @@
// expensive. This *must be* called after the interpreter has been created
// and before running inference (and accessing tensor buffers), and *must be*
// called again if (and only if) an input tensor is resized. Returns status of
- // success or failure.
+ // success or failure. Will fail if any of the ops in the model (other than
+ // those which were rewritten by delegates, if any) are not supported by the
+ // Interpreter's OpResolver.
TfLiteStatus AllocateTensors();
/// Invoke the interpreter (run the whole graph in dependency order).
diff --git a/tensorflow/lite/tools/verifier.cc b/tensorflow/lite/tools/verifier.cc
index 9110621..8c3bd27 100644
--- a/tensorflow/lite/tools/verifier.cc
+++ b/tensorflow/lite/tools/verifier.cc
@@ -718,11 +718,7 @@
return true;
}
-} // namespace
-
-bool Verify(const void* buf, size_t len, const OpResolver& resolver,
- ErrorReporter* error_reporter) {
- const Model* model = VerifyFlatbufferAndGetModel(buf, len);
+bool VerifyModel(const Model* model, ErrorReporter* error_reporter) {
if (model == nullptr) {
ReportError(error_reporter, "Invalid flatbuffer format");
return false;
@@ -737,6 +733,23 @@
if (!VerifyTensors(*model, error_reporter)) {
return false;
}
+ return true;
+}
+
+} // namespace
+
+bool Verify(const void* buf, size_t len, ErrorReporter* error_reporter) {
+ const Model* model = VerifyFlatbufferAndGetModel(buf, len);
+ return VerifyModel(model, error_reporter);
+}
+
+// Deprecated: see comments in header.
+bool Verify(const void* buf, size_t len, const OpResolver& resolver,
+ ErrorReporter* error_reporter) {
+ const Model* model = VerifyFlatbufferAndGetModel(buf, len);
+ if (!VerifyModel(model, error_reporter)) {
+ return false;
+ }
if (!VerifyOps(*model, resolver, error_reporter)) {
return false;
}
diff --git a/tensorflow/lite/tools/verifier.h b/tensorflow/lite/tools/verifier.h
index 35d7e30..6ff5129 100644
--- a/tensorflow/lite/tools/verifier.h
+++ b/tensorflow/lite/tools/verifier.h
@@ -46,9 +46,28 @@
// * The file is following a legit flatbuffer schema.
// * The model is in supported version.
// * All ops used in the model are supported by OpResolver.
+// DEPRECATED:
+// This function is deprecated, because it doesn't take delegates into
+// account, and as a result may report errors if the model contains
+// operators that are not supported by the OpResolver but that would be
+// rewritten by any TfLiteDelegate that you are using.
+// Suggested replacement:
+// Use the version below that doesn't takes an OpResolver (and
+// doesn't check the validity of the ops) instead of this function,
+// and delay verification of the ops until after you have constructed
+// the Interpreter. To verify that the operators in the model are supported
+// by the delegate(s) and/or by the OpResolver, construct the Interpreter,
+// applying the TfLiteDelegate(s) using InterpreterBuilder::AddDelegate,
+// and then just check the return value from Interpreter::AllocateTensors().
bool Verify(const void* buf, size_t len, const OpResolver& resolver,
ErrorReporter* error_reporter);
+// Verifies the integrity of a Tensorflow Lite flatbuffer model file.
+// Currently, it verifies:
+// * The file is following a legit flatbuffer schema.
+// * The model is in supported version.
+bool Verify(const void* buf, size_t len, ErrorReporter* error_reporter);
+
} // namespace tflite
#endif // TENSORFLOW_LITE_TOOLS_VERIFIER_H_
diff --git a/tensorflow/lite/tools/verifier_test.cc b/tensorflow/lite/tools/verifier_test.cc
index 9a892f5..ca16e7a 100644
--- a/tensorflow/lite/tools/verifier_test.cc
+++ b/tensorflow/lite/tools/verifier_test.cc
@@ -132,6 +132,11 @@
bool Verify() {
return tflite::Verify(builder_.GetBufferPointer(), builder_.GetSize(),
+ &mock_reporter_);
+ }
+
+ bool VerifyWithOpResolver() {
+ return tflite::Verify(builder_.GetBufferPointer(), builder_.GetSize(),
resolver_, &mock_reporter_);
}
@@ -165,6 +170,8 @@
::tflite::FinishModelBuffer(builder, model);
MockErrorReporter mock_reporter;
+ ASSERT_FALSE(
+ Verify(builder.GetBufferPointer(), builder.GetSize(), &mock_reporter));
ASSERT_FALSE(Verify(builder.GetBufferPointer(), builder.GetSize(),
MutableOpResolver{}, &mock_reporter));
EXPECT_THAT(mock_reporter.GetAsString(),
@@ -183,6 +190,7 @@
builder.AddTensor({2, 3}, TensorType_INT32, {}, "output");
builder.FinishModel({0, 1}, {3});
ASSERT_TRUE(builder.Verify());
+ ASSERT_TRUE(builder.VerifyWithOpResolver());
}
TEST(VerifyModel, TestSimpleModel) {
@@ -196,6 +204,7 @@
builder.AddTensor({2, 3}, TensorType_INT32, {}, "output");
builder.FinishModel({0, 1}, {2});
ASSERT_TRUE(builder.Verify());
+ ASSERT_TRUE(builder.VerifyWithOpResolver());
EXPECT_EQ("", builder.GetErrorString());
}
@@ -205,6 +214,7 @@
builder.FinishModel(
{}, {2}, TfLiteFlatbufferModelBuilder::kBuilderModeEmptyVectorIsNull);
ASSERT_FALSE(builder.Verify());
+ ASSERT_FALSE(builder.VerifyWithOpResolver());
EXPECT_EQ(builder.GetErrorString(),
"Input tensor 0 to op 0 (CUSTOM) is not produced");
}
@@ -214,6 +224,7 @@
builder.FinishModel(
{0, 1}, {2}, TfLiteFlatbufferModelBuilder::kBuilderModeEmptyVectorIsNull);
ASSERT_FALSE(builder.Verify());
+ ASSERT_FALSE(builder.VerifyWithOpResolver());
EXPECT_THAT(
builder.GetErrorString(),
::testing::ContainsRegex("Missing 'operators' section in subgraph"));
@@ -231,6 +242,7 @@
builder.FinishModel(
{}, {2}, TfLiteFlatbufferModelBuilder::kBuilderModeEmptyVectorIsNull);
ASSERT_TRUE(builder.Verify());
+ ASSERT_TRUE(builder.VerifyWithOpResolver());
EXPECT_EQ("", builder.GetErrorString());
}
@@ -277,6 +289,7 @@
builder.AddTensor({2, 3}, TensorType_UINT8, {1, 2, 3, 4}, "input");
builder.FinishModel({}, {});
ASSERT_FALSE(builder.Verify());
+ ASSERT_FALSE(builder.VerifyWithOpResolver());
EXPECT_THAT(builder.GetErrorString(),
::testing::ContainsRegex("Tensor input requires 6 bytes, but is "
"allocated with 4 bytes buffer"));
@@ -287,6 +300,7 @@
builder.AddTensor({2, 1}, TensorType_UINT8, {1, 2, 3, 4}, "input");
builder.FinishModel({}, {});
ASSERT_FALSE(builder.Verify());
+ ASSERT_FALSE(builder.VerifyWithOpResolver());
EXPECT_THAT(builder.GetErrorString(),
::testing::ContainsRegex("Tensor input requires 2 bytes, but is "
"allocated with 4 bytes buffer"));
@@ -298,6 +312,7 @@
"input");
builder.FinishModel({}, {});
ASSERT_FALSE(builder.Verify());
+ ASSERT_FALSE(builder.VerifyWithOpResolver());
EXPECT_THAT(builder.GetErrorString(),
::testing::ContainsRegex("Tensor input dimension overflow"));
}
@@ -323,6 +338,8 @@
::tflite::FinishModelBuffer(builder, model);
MockErrorReporter mock_reporter;
+ ASSERT_FALSE(
+ Verify(builder.GetBufferPointer(), builder.GetSize(), &mock_reporter));
ASSERT_FALSE(Verify(builder.GetBufferPointer(), builder.GetSize(),
MutableOpResolver{}, &mock_reporter));
EXPECT_THAT(
@@ -335,6 +352,7 @@
builder.AddTensor({2}, TensorType_STRING, {0x00}, "input");
builder.FinishModel({}, {});
ASSERT_FALSE(builder.Verify());
+ ASSERT_FALSE(builder.VerifyWithOpResolver());
EXPECT_EQ(builder.GetErrorString(), "String tensor input is invalid (empty)");
}
@@ -346,6 +364,7 @@
"input");
builder.FinishModel({}, {});
ASSERT_FALSE(builder.Verify());
+ ASSERT_FALSE(builder.VerifyWithOpResolver());
EXPECT_THAT(
builder.GetErrorString(),
::testing::ContainsRegex(
@@ -360,6 +379,7 @@
{2, 0, 0, 0, 12, 0, 0, 0, 17, 0, 0, 0, 18, 0, 0, 0, 'A', 'B'}, "input");
builder.FinishModel({}, {});
ASSERT_FALSE(builder.Verify());
+ ASSERT_FALSE(builder.VerifyWithOpResolver());
EXPECT_THAT(builder.GetErrorString(),
::testing::ContainsRegex(
"String tensor input buffer initial offset must be: 16"));
@@ -372,6 +392,7 @@
{2, 0, 0, 0, 16, 0, 0, 0, 17, 0, 0, 0, 22, 0, 0, 0, 'A', 'B'}, "input");
builder.FinishModel({}, {});
ASSERT_FALSE(builder.Verify());
+ ASSERT_FALSE(builder.VerifyWithOpResolver());
EXPECT_THAT(builder.GetErrorString(),
::testing::ContainsRegex(
"String tensor input buffer is invalid: index 2"));
@@ -385,6 +406,7 @@
"input");
builder.FinishModel({}, {});
ASSERT_FALSE(builder.Verify());
+ ASSERT_FALSE(builder.VerifyWithOpResolver());
EXPECT_THAT(builder.GetErrorString(),
::testing::ContainsRegex(
"String tensor input buffer last offset must be 19"));
@@ -400,6 +422,7 @@
builder.AddOperator({0, 1}, {3}, BuiltinOperator_CUSTOM, "CustomOp");
builder.FinishModel({}, {});
ASSERT_TRUE(builder.Verify());
+ ASSERT_TRUE(builder.VerifyWithOpResolver());
EXPECT_EQ("", builder.GetErrorString());
}
@@ -410,7 +433,9 @@
builder.AddTensor({2, 2}, TensorType_UINT8, {}, "output");
builder.AddOperator({0, 1}, {2}, BuiltinOperator_ADD, nullptr);
builder.FinishModel({}, {});
- ASSERT_FALSE(builder.Verify());
+ ASSERT_TRUE(builder.Verify());
+ EXPECT_EQ("", builder.GetErrorString());
+ ASSERT_FALSE(builder.VerifyWithOpResolver());
EXPECT_THAT(
builder.GetErrorString(),
::testing::ContainsRegex("Unsupported builtin op: ADD, version: 1"));
@@ -423,7 +448,9 @@
builder.AddTensor({2, 2}, TensorType_UINT8, {}, "output");
builder.AddOperator({0, 1}, {2}, BuiltinOperator_CUSTOM, "Not supported");
builder.FinishModel({}, {});
- ASSERT_FALSE(builder.Verify());
+ ASSERT_TRUE(builder.Verify());
+ EXPECT_EQ("", builder.GetErrorString());
+ ASSERT_FALSE(builder.VerifyWithOpResolver());
EXPECT_THAT(builder.GetErrorString(),
::testing::ContainsRegex(
"Unsupported custom op: Not supported, version: 1"));
@@ -436,7 +463,9 @@
builder.AddTensor({2, 2}, TensorType_UINT8, {}, "output");
builder.AddOperator({0, 1}, {2}, BuiltinOperator_CUSTOM, "");
builder.FinishModel({}, {});
- ASSERT_FALSE(builder.Verify());
+ ASSERT_TRUE(builder.Verify());
+ EXPECT_EQ("", builder.GetErrorString());
+ ASSERT_FALSE(builder.VerifyWithOpResolver());
EXPECT_THAT(builder.GetErrorString(),
::testing::ContainsRegex(
"Invalid custom op name, cannot be null/empty."));
@@ -455,6 +484,7 @@
builder.AddTensor({2, 3}, TensorType_INT32, {}, "output");
builder.FinishModel({0, 2}, {3});
ASSERT_FALSE(builder.Verify());
+ ASSERT_FALSE(builder.VerifyWithOpResolver());
EXPECT_EQ("Input tensor 1 to op 0 (CUSTOM) is not produced",
builder.GetErrorString());
}
@@ -469,6 +499,7 @@
builder.AddOperator({0, 1}, {2}, BuiltinOperator_CUSTOM, "CustomOp");
builder.FinishModel({}, {});
ASSERT_FALSE(builder.Verify());
+ ASSERT_FALSE(builder.VerifyWithOpResolver());
EXPECT_EQ(
"Output tensor 2 to op 1 (CUSTOM) is an output from another op. "
"There is a cycle in the graph",
@@ -487,6 +518,7 @@
builder.AddTensor({2, 3}, TensorType_INT32, {1, 2, 3, 4, 5, 6}, "output");
builder.FinishModel({0, 1}, {2});
ASSERT_FALSE(builder.Verify());
+ ASSERT_FALSE(builder.VerifyWithOpResolver());
EXPECT_EQ("Output tensor 2 to op 0 (CUSTOM) is a constant",
builder.GetErrorString());
}
@@ -503,6 +535,7 @@
// Output shouldn't be a subgraph input.
builder.FinishModel({0, 1, 2}, {2});
ASSERT_FALSE(builder.Verify());
+ ASSERT_FALSE(builder.VerifyWithOpResolver());
EXPECT_EQ("Output tensor 2 to op 0 (CUSTOM) is a subgraph input",
builder.GetErrorString());
}
@@ -519,6 +552,7 @@
builder.AddTensor({2, 3}, TensorType_INT32, {}, "output", /*variable*/ true);
builder.FinishModel({0, 1}, {2});
ASSERT_FALSE(builder.Verify());
+ ASSERT_FALSE(builder.VerifyWithOpResolver());
EXPECT_EQ("Output tensor 2 to op 0 (CUSTOM) is a variable",
builder.GetErrorString());
}
@@ -535,6 +569,7 @@
builder.AddTensor({2, 3}, TensorType_INT32, {}, "output");
builder.FinishModel({0, 1}, {2});
ASSERT_TRUE(builder.Verify());
+ ASSERT_TRUE(builder.VerifyWithOpResolver());
EXPECT_EQ("", builder.GetErrorString());
}
@@ -547,6 +582,7 @@
{1, 2, 3, 4}, "input");
builder.FinishModel({}, {});
ASSERT_FALSE(builder.Verify());
+ ASSERT_FALSE(builder.VerifyWithOpResolver());
EXPECT_THAT(
builder.GetErrorString(),
::testing::ContainsRegex("Tensor input requires .* bytes, but is "
@@ -573,6 +609,7 @@
"input");
builder.FinishModel({}, {});
ASSERT_TRUE(builder.Verify());
+ ASSERT_TRUE(builder.VerifyWithOpResolver());
}
}
@@ -591,8 +628,8 @@
MutableOpResolver resolver;
TfLiteRegistration fake_op;
resolver.AddCustom("FakeOp", &fake_op);
- Verify(builder.GetBufferPointer(), builder.GetSize(), resolver,
- &mock_reporter);
+ ASSERT_TRUE(
+ Verify(builder.GetBufferPointer(), builder.GetSize(), &mock_reporter));
ASSERT_TRUE(Verify(builder.GetBufferPointer(), builder.GetSize(), resolver,
&mock_reporter));
}
@@ -615,6 +652,8 @@
MutableOpResolver resolver;
TfLiteRegistration fake_op;
resolver.AddCustom("FakeOp", &fake_op);
+ ASSERT_FALSE(
+ Verify(builder.GetBufferPointer(), builder.GetSize(), &mock_reporter));
ASSERT_FALSE(Verify(builder.GetBufferPointer(), builder.GetSize(), resolver,
&mock_reporter));
EXPECT_THAT(mock_reporter.GetAsString(),
@@ -640,6 +679,8 @@
MutableOpResolver resolver;
TfLiteRegistration fake_op;
resolver.AddCustom("FakeOp", &fake_op);
+ ASSERT_FALSE(
+ Verify(builder.GetBufferPointer(), builder.GetSize(), &mock_reporter));
ASSERT_FALSE(Verify(builder.GetBufferPointer(), builder.GetSize(), resolver,
&mock_reporter));
EXPECT_THAT(mock_reporter.GetAsString(),
@@ -664,6 +705,8 @@
MutableOpResolver resolver;
TfLiteRegistration fake_op;
resolver.AddCustom("FakeOp", &fake_op);
+ ASSERT_FALSE(
+ Verify(builder.GetBufferPointer(), builder.GetSize(), &mock_reporter));
ASSERT_FALSE(Verify(builder.GetBufferPointer(), builder.GetSize(), resolver,
&mock_reporter));
EXPECT_THAT(mock_reporter.GetAsString(),
@@ -690,6 +733,8 @@
MutableOpResolver resolver;
TfLiteRegistration fake_op;
resolver.AddCustom("FakeOp", &fake_op);
+ ASSERT_FALSE(
+ Verify(builder.GetBufferPointer(), builder.GetSize(), &mock_reporter));
ASSERT_FALSE(Verify(builder.GetBufferPointer(), builder.GetSize(), resolver,
&mock_reporter));
EXPECT_THAT(mock_reporter.GetAsString(),
@@ -728,6 +773,8 @@
MutableOpResolver resolver;
TfLiteRegistration fake_op;
resolver.AddCustom("FakeOp", &fake_op);
+ ASSERT_TRUE(
+ Verify(builder.GetBufferPointer(), builder.GetSize(), &mock_reporter));
ASSERT_TRUE(Verify(builder.GetBufferPointer(), builder.GetSize(), resolver,
&mock_reporter));
}