Split tensorflow-lite-gpu SDK into an API and an implementation SDK
Add a DelegateFactory interface and a GpuDelegateFactory implementation in the GPU API SDK
Move GpuDelegate.Options into GpuDelegateFactory, but add a subclass in GpuDelegate for backwards compatibility.
PiperOrigin-RevId: 460917028
diff --git a/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/BUILD b/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/BUILD
index 924bcb8..3ff6ce2 100644
--- a/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/BUILD
+++ b/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/BUILD
@@ -5,6 +5,14 @@
filegroup(
name = "gpu_delegate",
srcs = [
+ "GpuDelegateFactory.java",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+filegroup(
+ name = "gpu_delegate_impl",
+ srcs = [
"CompatibilityList.java",
"GpuDelegate.java",
],
diff --git a/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/GpuDelegate.java b/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/GpuDelegate.java
index 13524a4..0eec396 100644
--- a/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/GpuDelegate.java
+++ b/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/GpuDelegate.java
@@ -15,7 +15,6 @@
package org.tensorflow.lite.gpu;
-import java.io.Closeable;
import org.tensorflow.lite.Delegate;
import org.tensorflow.lite.annotations.UsedByReflection;
@@ -30,103 +29,36 @@
* Interpreter.Options.addDelegate()} was called.
*/
@UsedByReflection("TFLiteSupport/model/GpuDelegateProxy")
-public class GpuDelegate implements Delegate, Closeable {
+public class GpuDelegate implements Delegate {
private static final long INVALID_DELEGATE_HANDLE = 0;
private static final String TFLITE_GPU_LIB = "tensorflowlite_gpu_jni";
private long delegateHandle;
- /** Delegate options. */
- public static final class Options {
- public Options() {}
-
- /**
- * Delegate will be used only once, therefore, bootstrap/init time should be taken into account.
- */
- public static final int INFERENCE_PREFERENCE_FAST_SINGLE_ANSWER = 0;
-
- /**
- * Prefer maximizing the throughput. Same delegate will be used repeatedly on multiple inputs.
- */
- public static final int INFERENCE_PREFERENCE_SUSTAINED_SPEED = 1;
-
- /**
- * Sets whether precision loss is allowed.
- *
- * @param precisionLossAllowed When `true` (default), the GPU may quantify tensors, downcast
- * values, process in FP16. When `false`, computations are carried out in 32-bit floating
- * point.
- */
- public Options setPrecisionLossAllowed(boolean precisionLossAllowed) {
- this.precisionLossAllowed = precisionLossAllowed;
- return this;
- }
-
- /**
- * Enables running quantized models with the delegate.
- *
- * <p>WARNING: This is an experimental API and subject to change.
- *
- * @param quantizedModelsAllowed When {@code true} (default), the GPU may run quantized models.
- */
- public Options setQuantizedModelsAllowed(boolean quantizedModelsAllowed) {
- this.quantizedModelsAllowed = quantizedModelsAllowed;
- return this;
- }
-
- /**
- * Sets the inference preference for precision/compilation/runtime tradeoffs.
- *
- * @param preference One of `INFERENCE_PREFERENCE_FAST_SINGLE_ANSWER` (default),
- * `INFERENCE_PREFERENCE_SUSTAINED_SPEED`.
- */
- public Options setInferencePreference(int preference) {
- this.inferencePreference = preference;
- return this;
- }
-
- /**
- * Enables serialization on the delegate. Note non-null {@code serializationDir} and {@code
- * modelToken} are required for serialization.
- *
- * <p>WARNING: This is an experimental API and subject to change.
- *
- * @param serializationDir The directory to use for storing data. Caller is responsible to
- * ensure the model is not stored in a public directory. It's recommended to use {@link
- * android.content.Context#getCodeCacheDir()} to provide a private location for the
- * application on Android.
- * @param modelToken The token to be used to identify the model. Caller is responsible to ensure
- * the token is unique to the model graph and data.
- */
- public Options setSerializationParams(String serializationDir, String modelToken) {
- this.serializationDir = serializationDir;
- this.modelToken = modelToken;
- return this;
- }
-
- boolean precisionLossAllowed = true;
- boolean quantizedModelsAllowed = true;
- int inferencePreference = INFERENCE_PREFERENCE_FAST_SINGLE_ANSWER;
- String serializationDir = null;
- String modelToken = null;
- }
-
- public GpuDelegate(Options options) {
+ public GpuDelegate(GpuDelegateFactory.Options options) {
delegateHandle =
createDelegate(
- options.precisionLossAllowed,
- options.quantizedModelsAllowed,
- options.inferencePreference,
- options.serializationDir,
- options.modelToken);
+ options.isPrecisionLossAllowed(),
+ options.areQuantizedModelsAllowed(),
+ options.getInferencePreference(),
+ options.getSerializationDir(),
+ options.getModelToken());
}
@UsedByReflection("TFLiteSupport/model/GpuDelegateProxy")
public GpuDelegate() {
- this(new Options());
+ this(new GpuDelegateFactory.Options());
}
+ /**
+ * Inherits from {@link GpuDelegateFactory.Options} for compatibility with existing code.
+ *
+ * @deprecated Use {@link GpuDelegateFactory.Options} instead.
+ */
+ @Deprecated
+ public static class Options extends GpuDelegateFactory.Options {}
+
@Override
public long getNativeHandle() {
return delegateHandle;
diff --git a/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/GpuDelegateFactory.java b/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/GpuDelegateFactory.java
new file mode 100644
index 0000000..c4e3d10
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/GpuDelegateFactory.java
@@ -0,0 +1,159 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.gpu;
+
+import java.lang.reflect.Constructor;
+import java.lang.reflect.InvocationTargetException;
+import org.tensorflow.lite.Delegate;
+import org.tensorflow.lite.DelegateFactory;
+import org.tensorflow.lite.RuntimeFlavor;
+
+/** {@link DelegateFactory} for creating a {@link GpuDelegate}. */
+public class GpuDelegateFactory implements DelegateFactory {
+
+ private static final String GPU_DELEGATE_CLASS_NAME = "GpuDelegate";
+
+ private final Options options;
+
+ /** Delegate options. */
+ public static class Options {
+ public Options() {}
+
+ /**
+ * Delegate will be used only once, therefore, bootstrap/init time should be taken into account.
+ */
+ public static final int INFERENCE_PREFERENCE_FAST_SINGLE_ANSWER = 0;
+
+ /**
+ * Prefer maximizing the throughput. Same delegate will be used repeatedly on multiple inputs.
+ */
+ public static final int INFERENCE_PREFERENCE_SUSTAINED_SPEED = 1;
+
+ /**
+ * Sets whether precision loss is allowed.
+ *
+     * @param precisionLossAllowed When `true` (default), the GPU may quantize tensors, downcast
+ * values, process in FP16. When `false`, computations are carried out in 32-bit floating
+ * point.
+ */
+ public Options setPrecisionLossAllowed(boolean precisionLossAllowed) {
+ this.precisionLossAllowed = precisionLossAllowed;
+ return this;
+ }
+
+ /**
+ * Enables running quantized models with the delegate.
+ *
+ * <p>WARNING: This is an experimental API and subject to change.
+ *
+ * @param quantizedModelsAllowed When {@code true} (default), the GPU may run quantized models.
+ */
+ public Options setQuantizedModelsAllowed(boolean quantizedModelsAllowed) {
+ this.quantizedModelsAllowed = quantizedModelsAllowed;
+ return this;
+ }
+
+ /**
+ * Sets the inference preference for precision/compilation/runtime tradeoffs.
+ *
+ * @param preference One of `INFERENCE_PREFERENCE_FAST_SINGLE_ANSWER` (default),
+ * `INFERENCE_PREFERENCE_SUSTAINED_SPEED`.
+ */
+ public Options setInferencePreference(int preference) {
+ this.inferencePreference = preference;
+ return this;
+ }
+
+ /**
+ * Enables serialization on the delegate. Note non-null {@code serializationDir} and {@code
+ * modelToken} are required for serialization.
+ *
+ * <p>WARNING: This is an experimental API and subject to change.
+ *
+ * @param serializationDir The directory to use for storing data. Caller is responsible to
+ * ensure the model is not stored in a public directory. It's recommended to use {@link
+ * android.content.Context#getCodeCacheDir()} to provide a private location for the
+ * application on Android.
+ * @param modelToken The token to be used to identify the model. Caller is responsible to ensure
+ * the token is unique to the model graph and data.
+ */
+ public Options setSerializationParams(String serializationDir, String modelToken) {
+ this.serializationDir = serializationDir;
+ this.modelToken = modelToken;
+ return this;
+ }
+
+ public boolean isPrecisionLossAllowed() {
+ return precisionLossAllowed;
+ }
+
+ public boolean areQuantizedModelsAllowed() {
+ return quantizedModelsAllowed;
+ }
+
+ public int getInferencePreference() {
+ return inferencePreference;
+ }
+
+ public String getSerializationDir() {
+ return serializationDir;
+ }
+
+ public String getModelToken() {
+ return modelToken;
+ }
+
+ private boolean precisionLossAllowed = true;
+ boolean quantizedModelsAllowed = true;
+ int inferencePreference = INFERENCE_PREFERENCE_FAST_SINGLE_ANSWER;
+ String serializationDir = null;
+ String modelToken = null;
+ }
+
+ public GpuDelegateFactory() {
+ this(new Options());
+ }
+
+ public GpuDelegateFactory(Options options) {
+ this.options = options;
+ }
+
+ @Override
+ public Delegate create(RuntimeFlavor runtimeFlavor) {
+ String packageName;
+ switch (runtimeFlavor) {
+ case APPLICATION:
+ packageName = "org.tensorflow.lite.gpu";
+ break;
+ case SYSTEM:
+ packageName = "com.google.android.gms.tflite.gpu";
+ break;
+ default:
+ throw new IllegalArgumentException("Unsupported runtime flavor " + runtimeFlavor);
+ }
+ try {
+ Class<?> delegateClass = Class.forName(packageName + "." + GPU_DELEGATE_CLASS_NAME);
+ Constructor<?> constructor = delegateClass.getDeclaredConstructor(Options.class);
+ return (Delegate) constructor.newInstance(options);
+ } catch (ClassNotFoundException
+ | IllegalAccessException
+ | InstantiationException
+ | NoSuchMethodException
+ | InvocationTargetException e) {
+ throw new IllegalStateException("Error creating GPU delegate", e);
+ }
+ }
+}
diff --git a/tensorflow/lite/delegates/gpu/java/src/main/native/BUILD b/tensorflow/lite/delegates/gpu/java/src/main/native/BUILD
index 919a5f5..4c5d4e5 100644
--- a/tensorflow/lite/delegates/gpu/java/src/main/native/BUILD
+++ b/tensorflow/lite/delegates/gpu/java/src/main/native/BUILD
@@ -10,6 +10,8 @@
licenses = ["notice"],
)
+exports_files(srcs = ["gpu_delegate_jni.cc"])
+
cc_library_with_tflite(
name = "native",
srcs = ["gpu_delegate_jni.cc"],
diff --git a/tensorflow/lite/java/BUILD b/tensorflow/lite/java/BUILD
index b635d1b..594dfc1 100644
--- a/tensorflow/lite/java/BUILD
+++ b/tensorflow/lite/java/BUILD
@@ -55,9 +55,11 @@
JAVA_API_SRCS = [
"src/main/java/org/tensorflow/lite/DataType.java",
"src/main/java/org/tensorflow/lite/Delegate.java",
+ "src/main/java/org/tensorflow/lite/DelegateFactory.java",
"src/main/java/org/tensorflow/lite/InterpreterApi.java",
"src/main/java/org/tensorflow/lite/InterpreterFactory.java",
"src/main/java/org/tensorflow/lite/InterpreterFactoryApi.java",
+ "src/main/java/org/tensorflow/lite/RuntimeFlavor.java",
"src/main/java/org/tensorflow/lite/Tensor.java",
"src/main/java/org/tensorflow/lite/TensorFlowLite.java",
"src/main/java/org/tensorflow/lite/annotations/UsedByReflection.java",
@@ -83,6 +85,7 @@
"src/main/java/org/tensorflow/lite/DataType.java",
"src/main/java/org/tensorflow/lite/DataTypeUtils.java",
"src/main/java/org/tensorflow/lite/Delegate.java",
+ "src/main/java/org/tensorflow/lite/DelegateFactory.java",
"src/main/java/org/tensorflow/lite/InterpreterApi.java",
"src/main/java/org/tensorflow/lite/InterpreterFactory.java",
"src/main/java/org/tensorflow/lite/InterpreterFactoryApi.java",
@@ -90,6 +93,7 @@
"src/main/java/org/tensorflow/lite/InterpreterImpl.java",
"src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java",
"src/main/java/org/tensorflow/lite/NativeSignatureRunnerWrapper.java",
+ "src/main/java/org/tensorflow/lite/RuntimeFlavor.java",
"src/main/java/org/tensorflow/lite/Tensor.java",
"src/main/java/org/tensorflow/lite/TensorFlowLite.java",
"src/main/java/org/tensorflow/lite/TensorImpl.java",
@@ -166,12 +170,19 @@
android_library = ":tensorflowlite_flex",
)
-# EXPERIMENTAL: AAR target for GPU acceleration. Note that this .aar contains
-# *only* the GPU delegate; clients must also include the core `tensorflow-lite`
-# runtime.
+# EXPERIMENTAL: AAR target for the GPU acceleration API, EXCLUDING implementation.
+# Note that this .aar contains *only* the GPU delegate API; clients must also include a GPU delegate
+# implementation, as well as the core `tensorflow-lite` runtime.
+aar_without_jni(
+ name = "tensorflow-lite-gpu-api",
+ android_library = ":tensorflowlite_gpu_api",
+)
+
+# EXPERIMENTAL: AAR target for GPU acceleration API and implementation. Note that this .aar contains
+# *only* the GPU delegate; clients must also include the core `tensorflow-lite` runtime.
aar_with_jni(
name = "tensorflow-lite-gpu",
- android_library = ":tensorflowlite_gpu",
+ android_library = ":tensorflowlite_gpu_impl",
headers = [
"//tensorflow/lite/delegates/gpu:delegate.h",
],
@@ -195,12 +206,12 @@
proguard_specs = ["proguard.flags"],
tflite_deps = [
":tensorflowlite_native",
- ":tensorflowlite_api",
],
- tflite_exports = [
+ exports = [
":tensorflowlite_api",
],
deps = [
+ ":tensorflowlite_api",
"@org_checkerframework_qual",
],
)
@@ -210,7 +221,7 @@
# Maven package "org.tensorflow:tensorflow-lite-api".
# This target does not include the TF Lite Runtime, which nevertheless is
# required and must be provided via a separate dependency.
-android_library_with_tflite(
+android_library(
name = "tensorflowlite_api",
srcs = [":java_api_srcs"],
manifest = "AndroidManifest.xml",
@@ -263,20 +274,41 @@
# contains *only* the GPU delegate and its Java wrapper; clients must also
# include the core `tensorflowlite` runtime.
# Note that AndroidManifestGpu.xml usage requires AGP 4.2.0+.
-android_library(
+alias(
name = "tensorflowlite_gpu",
- # Note that we need to directly includes all the
- # required Java source files directly in "srcs" rather than
- # depending on them via "deps"/"exports"; this is needed when
- # building the AAR file since the current AAR building process
- # doesn't include the transitive Java dependencies.
- srcs = ["//tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu:gpu_delegate"],
- exports_manifest = 1,
- manifest = "AndroidManifest.xml",
- proguard_specs = ["proguard.flags"],
+ actual = "tensorflowlite_gpu_impl",
+)
+
+# EXPERIMENTAL: Android target for the implementation of the GPU acceleration API, including the
+# native library. Note that this library contains *only* the GPU delegate and its Java wrapper;
+# clients must also include the core `tensorflowlite` runtime.
+# Note that AndroidManifestGpu.xml usage requires AGP 4.2.0+.
+android_library(
+ name = "tensorflowlite_gpu_impl",
+    # Note that we need to directly include all the Java source files we intend to ship in
+ # "srcs" rather than depending on them via "deps"/"exports"; this is needed when building the
+ # AAR file since the current AAR building process doesn't include the transitive Java
+ # dependencies. The API target however can be an export, because it is shipped in a different
+ # AAR file.
+ srcs = ["//tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu:gpu_delegate_impl"],
+ exports = [
+ ":tensorflowlite_gpu_api",
+ ":tensorflowlite_gpu_native",
+ ],
deps = [
- ":tensorflowlite_java",
- ":tensorflowlite_native_gpu",
+ ":tensorflowlite_gpu_api",
+ "//tensorflow/lite/java:tensorflowlite_api",
+ ],
+)
+
+# EXPERIMENTAL: Android target for the implementation of the GPU acceleration API,
+# EXCLUDING the native library.
+android_library(
+ name = "tensorflowlite_gpu_impl_java",
+ srcs = ["//tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu:gpu_delegate_impl"],
+ deps = [
+ ":tensorflowlite_api",
+ ":tensorflowlite_gpu_api",
],
)
@@ -302,17 +334,17 @@
],
)
-# EXPERIMENTAL: Android target for GPU acceleration, EXCLUDING NATIVE CODE dependencies.
-# Note that this library contains *only* the GPU delegate and its Java wrapper; clients must also
-# include the core `tensorflowlite` runtime.
+# EXPERIMENTAL: Android target for GPU acceleration API, EXCLUDING implementation.
+# Note that this library contains *only* the GPU delegate API; clients must also include
+# an implementation, as well as the core `tensorflowlite` runtime.
# Note that AndroidManifestGpu.xml usage requires AGP 4.2.0+.
android_library(
- name = "tensorflowlite_gpu_java",
+ name = "tensorflowlite_gpu_api",
srcs = ["//tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu:gpu_delegate"],
exports_manifest = 1,
manifest = "AndroidManifestGpu.xml",
proguard_specs = ["proguard.flags"],
- deps = [":tensorflowlite_java"],
+ deps = [":tensorflowlite_api"],
)
#-----------------------------------------------------------------------------
@@ -854,7 +886,7 @@
)
cc_library_with_tflite(
- name = "tensorflowlite_native_gpu",
+ name = "tensorflowlite_gpu_native",
tflite_jni_binaries = ["libtensorflowlite_gpu_jni.so"],
visibility = ["//visibility:private"],
)
diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Delegate.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Delegate.java
index eaf9ae5..55a5fff 100644
--- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Delegate.java
+++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Delegate.java
@@ -24,7 +24,7 @@
* technically allows sharing of a single delegate instance across multiple interpreter instances,
* the delegate implementation must explicitly support this.
*/
-public interface Delegate {
+public interface Delegate extends AutoCloseable {
/**
* Returns a native handle to the TensorFlow Lite delegate implementation.
*
@@ -40,5 +40,15 @@
* @return The native delegate handle. In C/C++, this should be a pointer to
* 'TfLiteOpaqueDelegate'.
*/
- public long getNativeHandle();
+ long getNativeHandle();
+
+ /**
+ * Closes the delegate and releases any resources associated with it.
+ *
+ * <p>In contrast to the method declared in the base {@link AutoCloseable} interface, this method
+ * does not throw checked exceptions.
+ */
+ @SuppressWarnings("StaticOrDefaultInterfaceMethod")
+ @Override
+ default void close() {}
}
diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/DelegateFactory.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/DelegateFactory.java
new file mode 100644
index 0000000..8d1c606
--- /dev/null
+++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/DelegateFactory.java
@@ -0,0 +1,28 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite;
+
+/** Allows creating delegates for different runtime flavors. */
+public interface DelegateFactory {
+ /**
+ * Create a {@link Delegate} for the given {@link RuntimeFlavor}.
+ *
+ * <p>Note for developers implementing this interface: Currently TF Lite in Google Play Services
+ * does not support external (developer-provided) delegates. Correspondingly, implementations of
+ * this method can expect to be called with {@link RuntimeFlavor#APPLICATION}.
+ */
+ Delegate create(RuntimeFlavor runtimeFlavor);
+}
diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
index 777d8b9..8a84853 100644
--- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
+++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
@@ -81,8 +81,7 @@
/** An options class for controlling runtime interpreter behavior. */
public static class Options extends InterpreterImpl.Options {
- public Options() {
- }
+ public Options() {}
public Options(InterpreterApi.Options options) {
super(options);
@@ -124,6 +123,12 @@
return this;
}
+ @Override
+ public Options addDelegateFactory(DelegateFactory delegateFactory) {
+ super.addDelegateFactory(delegateFactory);
+ return this;
+ }
+
/**
* Advanced: Set if buffer handle output is allowed.
*
diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/InterpreterApi.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/InterpreterApi.java
index ce904f9..ed6fc36 100644
--- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/InterpreterApi.java
+++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/InterpreterApi.java
@@ -23,6 +23,7 @@
import java.util.Map;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.tensorflow.lite.InterpreterApi.Options.TfLiteRuntime;
+import org.tensorflow.lite.nnapi.NnApiDelegate;
/**
* Interface to TensorFlow Lite model interpreter, excluding experimental methods.
@@ -87,8 +88,10 @@
/** An options class for controlling runtime interpreter behavior. */
class Options {
+
public Options() {
this.delegates = new ArrayList<>();
+ this.delegateFactories = new ArrayList<>();
}
public Options(Options other) {
@@ -96,6 +99,7 @@
this.useNNAPI = other.useNNAPI;
this.allowCancellation = other.allowCancellation;
this.delegates = new ArrayList<>(other.delegates);
+ this.delegateFactories = new ArrayList<>(other.delegateFactories);
this.runtime = other.runtime;
}
@@ -166,21 +170,56 @@
return allowCancellation != null && allowCancellation;
}
- /** Adds a {@link Delegate} to be applied during interpreter creation. */
+ /**
+ * Adds a {@link Delegate} to be applied during interpreter creation.
+ *
+ * <p>Delegates added here are applied before any delegates created from a {@link
+ * DelegateFactory} that was added with {@link #addDelegateFactory}.
+ *
+ * <p>Note that TF Lite in Google Play Services (see {@link #setRuntime}) does not support
+ * external (developer-provided) delegates, and adding a {@link Delegate} other than {@link
+ * NnApiDelegate} here is not allowed when using TF Lite in Google Play Services.
+ */
public Options addDelegate(Delegate delegate) {
delegates.add(delegate);
return this;
}
/**
- * Returns the list of delegates intended to be applied during interpreter creation (that have
- * been registered via {@code addDelegate}).
+ * Returns the list of delegates intended to be applied during interpreter creation that have
+ * been registered via {@code addDelegate}.
*/
public List<Delegate> getDelegates() {
return Collections.unmodifiableList(delegates);
}
- /** Enum to represent where to get the TensorFlow Lite runtime implementation from. */
+ /**
+ * Adds a {@link DelegateFactory} which will be invoked to apply its created {@link Delegate}
+ * during interpreter creation.
+ *
+     * <p>Delegates from a delegate factory that was added here are applied after any delegates
+ * added with {@link #addDelegate}.
+ */
+ public Options addDelegateFactory(DelegateFactory delegateFactory) {
+ delegateFactories.add(delegateFactory);
+ return this;
+ }
+
+ /**
+ * Returns the list of delegate factories that have been registered via {@code
+     * addDelegateFactory}.
+ */
+ public List<DelegateFactory> getDelegateFactories() {
+ return Collections.unmodifiableList(delegateFactories);
+ }
+
+ /**
+ * Enum to represent where to get the TensorFlow Lite runtime implementation from.
+ *
+ * <p>The difference between this class and the RuntimeFlavor class: This class specifies a
+ * <em>preference</em> which runtime to use, whereas {@link RuntimeFlavor} specifies which exact
+ * runtime <em>is</em> being used.
+ */
public enum TfLiteRuntime {
/**
* Use a TF Lite runtime implementation that is linked into the application. If there is no
@@ -237,8 +276,10 @@
Boolean useNNAPI;
Boolean allowCancellation;
- // See InterpreterApi.Options#addDelegate(boolean).
+ // See InterpreterApi.Options#addDelegate.
final List<Delegate> delegates;
+ // See InterpreterApi.Options#addDelegateFactory.
+ private final List<DelegateFactory> delegateFactories;
}
/**
@@ -270,8 +311,7 @@
* direct {@code ByteBuffer} of nativeOrder.
*/
@SuppressWarnings("StaticOrDefaultInterfaceMethod")
- static InterpreterApi create(
- @NonNull ByteBuffer byteBuffer, InterpreterApi.Options options) {
+ static InterpreterApi create(@NonNull ByteBuffer byteBuffer, InterpreterApi.Options options) {
TfLiteRuntime runtime = (options == null ? null : options.getRuntime());
InterpreterFactoryApi factory = TensorFlowLite.getFactory(runtime);
return factory.create(byteBuffer, options);
diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/InterpreterFactoryImpl.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/InterpreterFactoryImpl.java
index c893c37..77fdfd6 100644
--- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/InterpreterFactoryImpl.java
+++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/InterpreterFactoryImpl.java
@@ -25,6 +25,7 @@
/** Package-private factory class for constructing InterpreterApi instances. */
@UsedByReflection("InterpreterFactory.java")
class InterpreterFactoryImpl implements InterpreterFactoryApi {
+
public InterpreterFactoryImpl() {}
@Override
@@ -51,12 +52,12 @@
return nativeSchemaVersion();
}
- private static native String nativeRuntimeVersion();
-
- private static native String nativeSchemaVersion();
-
@Override
public NnApiDelegate.PrivateInterface createNnApiDelegateImpl(NnApiDelegate.Options options) {
return new NnApiDelegateImpl(options);
}
+
+ private static native String nativeRuntimeVersion();
+
+ private static native String nativeSchemaVersion();
}
diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
index 9882cbc..503a255 100644
--- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
+++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
@@ -23,6 +23,8 @@
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
+import org.tensorflow.lite.InterpreterApi.Options.TfLiteRuntime;
+import org.tensorflow.lite.InterpreterImpl.Options;
import org.tensorflow.lite.annotations.UsedByReflection;
import org.tensorflow.lite.nnapi.NnApiDelegate;
@@ -37,6 +39,9 @@
*/
class NativeInterpreterWrapper implements AutoCloseable {
+ // This is changed to RuntimeFlavor.SYSTEM for TF Lite in Google Play Services.
+ private static final RuntimeFlavor RUNTIME_FLAVOR = RuntimeFlavor.APPLICATION;
+
NativeInterpreterWrapper(String modelPath) {
this(modelPath, /* options= */ null);
}
@@ -150,12 +155,8 @@
outputsIndexes = null;
isMemoryAllocated = false;
delegates.clear();
- for (AutoCloseable ownedDelegate : ownedDelegates) {
- try {
- ownedDelegate.close();
- } catch (Exception e) {
- System.err.println("Failed to close flex delegate: " + e);
- }
+ for (Delegate ownedDelegate : ownedDelegates) {
+ ownedDelegate.close();
}
ownedDelegates.clear();
}
@@ -504,12 +505,17 @@
if (originalGraphHasUnresolvedFlexOp) {
Delegate optionalFlexDelegate = maybeCreateFlexDelegate(options.getDelegates());
if (optionalFlexDelegate != null) {
- ownedDelegates.add((AutoCloseable) optionalFlexDelegate);
+ ownedDelegates.add(optionalFlexDelegate);
delegates.add(optionalFlexDelegate);
}
}
// Now add the user-supplied delegates.
- delegates.addAll(options.getDelegates());
+ addUserProvidedDelegates(options);
+ for (DelegateFactory delegateFactory : options.getDelegateFactories()) {
+ Delegate delegate = delegateFactory.create(RUNTIME_FLAVOR);
+ ownedDelegates.add(delegate);
+ delegates.add(delegate);
+ }
if (options.getUseNNAPI()) {
NnApiDelegate optionalNnApiDelegate = new NnApiDelegate();
ownedDelegates.add(optionalNnApiDelegate);
@@ -517,6 +523,22 @@
}
}
+ private void addUserProvidedDelegates(Options options) {
+ for (Delegate delegate : options.getDelegates()) {
+ // NnApiDelegate is compatible with both the system and built-in runtimes and therefore can be
+ // added directly even when using TF Lite from the system.
+ if (options.getRuntime() != TfLiteRuntime.FROM_APPLICATION_ONLY
+ && !(delegate instanceof NnApiDelegate)) {
+ throw new IllegalArgumentException(
+ "Instantiated delegates (other than NnApiDelegate) are not allowed when using TF Lite"
+ + " from Google Play Services. Please use"
+                + " InterpreterApi.Options.addDelegateFactory() with an appropriate DelegateFactory"
+ + " instead.");
+ }
+ delegates.add(delegate);
+ }
+ }
+
// Complete the initialization of any delegates that require an InterpreterFactoryApi instance.
void initDelegatesWithInterpreterFactory() {
InterpreterFactoryApi interpreterFactoryApi = new InterpreterFactoryImpl();
@@ -595,7 +617,7 @@
private final List<Delegate> delegates = new ArrayList<>();
// List of owned delegates that must be closed when the interpreter is closed.
- private final List<AutoCloseable> ownedDelegates = new ArrayList<>();
+ private final List<Delegate> ownedDelegates = new ArrayList<>();
private static native void run(long interpreterHandle, long errorHandle);
diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/RuntimeFlavor.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/RuntimeFlavor.java
new file mode 100644
index 0000000..209b4ac
--- /dev/null
+++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/RuntimeFlavor.java
@@ -0,0 +1,30 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite;
+
+import org.tensorflow.lite.InterpreterApi.Options.TfLiteRuntime;
+
+/**
+ * Represents a TFLite runtime. In contrast to {@link TfLiteRuntime}, this enum represents the
+ * actual runtime that is being used, whereas the latter represents a preference for which runtime
+ * should be used.
+ */
+public enum RuntimeFlavor {
+ /** A TFLite runtime built directly into the application. */
+ APPLICATION,
+ /** A TFLite runtime provided by the system (TFLite in Google Play Services). */
+ SYSTEM,
+}
diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/GpuDelegateTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/GpuDelegateTest.java
index 2acc341..7994b67 100644
--- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/GpuDelegateTest.java
+++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/GpuDelegateTest.java
@@ -106,7 +106,7 @@
Interpreter.Options options = new Interpreter.Options();
try (GpuDelegate delegate =
- new GpuDelegate(new GpuDelegate.Options().setQuantizedModelsAllowed(false));
+ new GpuDelegate(new GpuDelegateFactory.Options().setQuantizedModelsAllowed(false));
Interpreter interpreter =
new Interpreter(MOBILENET_QUANTIZED_MODEL_BUFFER, options.addDelegate(delegate))) {
byte[][] output = new byte[1][1001];
@@ -164,7 +164,7 @@
if (enableSerialization) {
options.addDelegate(
new GpuDelegate(
- new GpuDelegate.Options()
+ new GpuDelegateFactory.Options()
.setSerializationParams(serializationDir, "GpuDelegateTest.testModelToken")));
} else {
options.addDelegate(new GpuDelegate());