Enable setting an op resolver for the tflite model reader.
PiperOrigin-RevId: 308924400
Change-Id: Ifebe1e5ddd7a54990b9fceac90511f6f35bb08dd
diff --git a/tensorflow/lite/delegates/gpu/cl/testing/BUILD b/tensorflow/lite/delegates/gpu/cl/testing/BUILD
index f7c07ec..723e4cd 100644
--- a/tensorflow/lite/delegates/gpu/cl/testing/BUILD
+++ b/tensorflow/lite/delegates/gpu/cl/testing/BUILD
@@ -12,6 +12,7 @@
"//tensorflow/lite/delegates/gpu/common:model",
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common/testing:tflite_model_reader",
+ "//tensorflow/lite/kernels:builtin_ops",
"@com_google_absl//absl/time",
],
)
diff --git a/tensorflow/lite/delegates/gpu/cl/testing/performance_profiling.cc b/tensorflow/lite/delegates/gpu/cl/testing/performance_profiling.cc
index 8dc5dac..75dcbc1 100644
--- a/tensorflow/lite/delegates/gpu/cl/testing/performance_profiling.cc
+++ b/tensorflow/lite/delegates/gpu/cl/testing/performance_profiling.cc
@@ -23,6 +23,7 @@
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/testing/tflite_model_reader.h"
+#include "tensorflow/lite/kernels/register.h"
namespace tflite {
namespace gpu {
@@ -31,7 +32,8 @@
absl::Status RunModelSample(const std::string& model_name) {
auto flatbuffer = tflite::FlatBufferModel::BuildFromFile(model_name.c_str());
GraphFloat32 graph_cl;
- RETURN_IF_ERROR(BuildFromFlatBuffer(*flatbuffer, &graph_cl));
+ ops::builtin::BuiltinOpResolver op_resolver;
+ RETURN_IF_ERROR(BuildFromFlatBuffer(*flatbuffer, op_resolver, &graph_cl));
Environment env;
RETURN_IF_ERROR(CreateEnvironment(&env));
diff --git a/tensorflow/lite/delegates/gpu/common/testing/BUILD b/tensorflow/lite/delegates/gpu/common/testing/BUILD
index 76394df..a7f97eb 100644
--- a/tensorflow/lite/delegates/gpu/common/testing/BUILD
+++ b/tensorflow/lite/delegates/gpu/common/testing/BUILD
@@ -26,21 +26,12 @@
deps = [
"//tensorflow/lite:framework_lib",
"//tensorflow/lite:kernel_api",
- "//tensorflow/lite:minimal_logging",
"//tensorflow/lite/c:common",
- "//tensorflow/lite/delegates/gpu/cl:api",
- "//tensorflow/lite/delegates/gpu/cl:opencl_wrapper",
- "//tensorflow/lite/delegates/gpu/cl:tensor_type_util",
+ "//tensorflow/lite/core/api",
"//tensorflow/lite/delegates/gpu/common:model",
"//tensorflow/lite/delegates/gpu/common:model_builder",
- "//tensorflow/lite/delegates/gpu/common:model_transformer",
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common/transformations:general_transformations",
- "//tensorflow/lite/delegates/gpu/gl:api2",
"//tensorflow/lite/kernels:builtin_ops",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/status",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:span",
],
)
diff --git a/tensorflow/lite/delegates/gpu/common/testing/tflite_model_reader.cc b/tensorflow/lite/delegates/gpu/common/testing/tflite_model_reader.cc
index f0872bd..0faa621 100644
--- a/tensorflow/lite/delegates/gpu/common/testing/tflite_model_reader.cc
+++ b/tensorflow/lite/delegates/gpu/common/testing/tflite_model_reader.cc
@@ -17,6 +17,7 @@
#include <memory>
#include "tensorflow/lite/builtin_ops.h"
+#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_builder.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
@@ -71,8 +72,8 @@
} // namespace
absl::Status BuildFromFlatBuffer(const tflite::FlatBufferModel& flatbuffer,
+ const tflite::OpResolver& op_resolver,
GraphFloat32* graph) {
- ops::builtin::BuiltinOpResolver op_resolver;
std::unique_ptr<tflite::Interpreter> interpreter;
tflite::InterpreterBuilder interpreter_builder(flatbuffer, op_resolver);
if (interpreter_builder(&interpreter) != kTfLiteOk || !interpreter) {
diff --git a/tensorflow/lite/delegates/gpu/common/testing/tflite_model_reader.h b/tensorflow/lite/delegates/gpu/common/testing/tflite_model_reader.h
index e372507..1a22508 100644
--- a/tensorflow/lite/delegates/gpu/common/testing/tflite_model_reader.h
+++ b/tensorflow/lite/delegates/gpu/common/testing/tflite_model_reader.h
@@ -15,6 +15,7 @@
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TESTING_TFLITE_MODEL_READER_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TESTING_TFLITE_MODEL_READER_H_
+#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/model_builder.h"
@@ -25,6 +26,7 @@
// Generates GraphFloat32 basing on the FlatBufferModel without specifying a
// delegate.
absl::Status BuildFromFlatBuffer(const tflite::FlatBufferModel& flatbuffer,
+ const tflite::OpResolver& op_resolver,
GraphFloat32* graph);
} // namespace gpu
} // namespace tflite