TFL: Selective registration for C++ target.
Usage:
- Create a tflite_custom_cc_library rule in the BUILD file with the targeted model.
- Call tflite::CreateOpResolver to get the slimmed op resolver.
PiperOrigin-RevId: 316849510
Change-Id: I3e7d75da6a9f2876b3fbefe1962e5ae09ebadb33
diff --git a/tensorflow/lite/build_def.bzl b/tensorflow/lite/build_def.bzl
index 285824a..5e48739 100644
--- a/tensorflow/lite/build_def.bzl
+++ b/tensorflow/lite/build_def.bzl
@@ -732,3 +732,49 @@
] + if_non_eager,
if_none = [] + if_none,
)
+
+def tflite_custom_cc_library(name, models = [], srcs = [], deps = [], visibility = ["//visibility:private"]):
+ """Generates a tflite cc library, stripping off unused operators.
+
+ This library includes the TfLite runtime as well as all operators needed for the given models.
+ Op resolver can be retrieved using tflite::CreateOpResolver method.
+
+ Args:
+ name: Str, name of the target.
+ models: List of models. This TFLite build will only include
+ operators used in these models. If the list is empty, all builtin
+ operators are included.
+ srcs: List of files implementing custom operators if any.
+ deps: Additional dependencies to build all the custom operators.
+ visibility: Visibility setting for the generated target. Default to private.
+ """
+ real_srcs = []
+ real_srcs.extend(srcs)
+ real_deps = []
+ real_deps.extend(deps)
+
+ if models:
+ gen_selected_ops(
+ name = "%s_registration" % name,
+ model = models[0],
+ )
+ real_srcs.append(":%s_registration" % name)
+ real_deps.append("//tensorflow/lite/java/src/main/native:selected_ops_jni")
+ else:
+ # Support all operators if `models` not specified.
+ real_deps.append("//tensorflow/lite/java/src/main/native")
+
+ native.cc_library(
+ name = name,
+ srcs = real_srcs,
+ copts = tflite_copts(),
+ linkopts = [
+ "-lm",
+ "-ldl",
+ ],
+ deps = depset([
+ "//tensorflow/lite:framework",
+ "//tensorflow/lite/kernels:builtin_ops",
+ ] + real_deps),
+ visibility = visibility,
+ )
diff --git a/tensorflow/lite/java/src/main/native/BUILD b/tensorflow/lite/java/src/main/native/BUILD
index 0d3535b..fdbbc9d 100644
--- a/tensorflow/lite/java/src/main/native/BUILD
+++ b/tensorflow/lite/java/src/main/native/BUILD
@@ -45,14 +45,27 @@
srcs = [
"builtin_ops_jni.cc",
],
+ hdrs = ["op_resolver.h"],
copts = tflite_copts(),
deps = [
":native_framework_only",
+ "//tensorflow/lite:framework",
"//tensorflow/lite/kernels:builtin_ops",
],
alwayslink = 1,
)
+# TODO(b/153652701): Generate this target to give CreateOpResolver a custom namespace.
+cc_library(
+ name = "selected_ops_jni",
+ srcs = ["selected_ops_jni.cc"],
+ hdrs = ["op_resolver.h"],
+ copts = tflite_copts(),
+ deps = [
+ "//tensorflow/lite:framework",
+ ],
+)
+
exports_files(
[
"exported_symbols.lds",
diff --git a/tensorflow/lite/java/src/main/native/op_resolver.h b/tensorflow/lite/java/src/main/native/op_resolver.h
new file mode 100644
index 0000000..ba9c1bf
--- /dev/null
+++ b/tensorflow/lite/java/src/main/native/op_resolver.h
@@ -0,0 +1,26 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_JAVA_SRC_MAIN_NATIVE_OP_RESOLVER_H_
+#define TENSORFLOW_LITE_JAVA_SRC_MAIN_NATIVE_OP_RESOLVER_H_
+
+#include "tensorflow/lite/op_resolver.h"
+
+namespace tflite {
+
+std::unique_ptr<OpResolver> CreateOpResolver();
+
+}
+
+#endif // TENSORFLOW_LITE_JAVA_SRC_MAIN_NATIVE_OP_RESOLVER_H_
diff --git a/tensorflow/lite/java/src/main/native/selected_ops_jni.cc b/tensorflow/lite/java/src/main/native/selected_ops_jni.cc
new file mode 100644
index 0000000..d8eb233
--- /dev/null
+++ b/tensorflow/lite/java/src/main/native/selected_ops_jni.cc
@@ -0,0 +1,36 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/java/src/main/native/op_resolver.h"
+#include "tensorflow/lite/mutable_op_resolver.h"
+
+// This method is generated by `gen_selected_ops`.
+// TODO(b/153652701): Instead of relying on a global method, make
+// `gen_selected_ops` generating a header file with custom namespace.
+void RegisterSelectedOps(::tflite::MutableOpResolver* resolver);
+
+namespace tflite {
+// This interface is the unified entry point for creating op resolver
+// regardless if selective registration is being used. C++ client will call
+// this method directly and Java client will call this method indirectly via
+// JNI code in interpreter_jni.cc.
+std::unique_ptr<OpResolver> CreateOpResolver() {
+ std::unique_ptr<MutableOpResolver> resolver =
+ std::unique_ptr<MutableOpResolver>(new MutableOpResolver());
+ RegisterSelectedOps(resolver.get());
+ return std::move(resolver);
+}
+
+} // namespace tflite