[tf.lite] Adds setSerializationParams() Java API for enabling GPU delegate serialization.
PiperOrigin-RevId: 401822099
Change-Id: I32deb8a95eadb72e4169da3bc328e95ded2900c6
diff --git a/RELEASE.md b/RELEASE.md
index f874104..ede68b9 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -26,6 +26,10 @@
*<SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>
*<IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
*<NOTES SHOULD BE GROUPED PER AREA>
+* `tf.lite`:
+ * GPU
+ * Adds GPU delegate serialization support to the Java API. This improves
+ initialization time by up to 90% when OpenCL is available.
# Thanks to our Contributors
diff --git a/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/GpuDelegate.java b/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/GpuDelegate.java
index e198dad..13524a4 100644
--- a/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/GpuDelegate.java
+++ b/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/GpuDelegate.java
@@ -64,7 +64,7 @@
}
/**
- * Enables running quantized models with the delegate. Defaults to false.
+ * Enables running quantized models with the delegate.
*
* <p>WARNING: This is an experimental API and subject to change.
*
@@ -86,9 +86,30 @@
return this;
}
+ /**
+ * Enables serialization on the delegate. Note that a non-null {@code serializationDir} and
+ * {@code modelToken} are both required to enable serialization.
+ *
+ * <p>WARNING: This is an experimental API and subject to change.
+ *
+ * @param serializationDir The directory to use for storing data. The caller is responsible
+ * for ensuring the model is not stored in a public directory. It's recommended to use {@link
+ * android.content.Context#getCodeCacheDir()} to provide a private location for the
+ * application on Android.
+ * @param modelToken The token to be used to identify the model. The caller is responsible for
+ * ensuring the token is unique to the model graph and data.
+ */
+ public Options setSerializationParams(String serializationDir, String modelToken) {
+ this.serializationDir = serializationDir;
+ this.modelToken = modelToken;
+ return this;
+ }
+
boolean precisionLossAllowed = true;
boolean quantizedModelsAllowed = true;
int inferencePreference = INFERENCE_PREFERENCE_FAST_SINGLE_ANSWER;
+ String serializationDir = null;
+ String modelToken = null;
}
public GpuDelegate(Options options) {
@@ -96,7 +117,9 @@
createDelegate(
options.precisionLossAllowed,
options.quantizedModelsAllowed,
- options.inferencePreference);
+ options.inferencePreference,
+ options.serializationDir,
+ options.modelToken);
}
@UsedByReflection("TFLiteSupport/model/GpuDelegateProxy")
@@ -127,7 +150,11 @@
}
private static native long createDelegate(
- boolean precisionLossAllowed, boolean quantizedModelsAllowed, int preference);
+ boolean precisionLossAllowed,
+ boolean quantizedModelsAllowed,
+ int preference,
+ String serializationDir,
+ String modelToken);
private static native void deleteDelegate(long delegateHandle);
}
diff --git a/tensorflow/lite/delegates/gpu/java/src/main/native/gpu_delegate_jni.cc b/tensorflow/lite/delegates/gpu/java/src/main/native/gpu_delegate_jni.cc
index 62a52f3..8d85c53 100644
--- a/tensorflow/lite/delegates/gpu/java/src/main/native/gpu_delegate_jni.cc
+++ b/tensorflow/lite/delegates/gpu/java/src/main/native/gpu_delegate_jni.cc
@@ -29,7 +29,8 @@
JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_gpu_GpuDelegate_createDelegate(
JNIEnv* env, jclass clazz, jboolean precision_loss_allowed,
- jboolean quantized_models_allowed, jint inference_preference) {
+ jboolean quantized_models_allowed, jint inference_preference,
+ jstring serialization_dir, jstring model_token) {
TfLiteGpuDelegateOptionsV2 options = TfLiteGpuDelegateOptionsV2Default();
if (precision_loss_allowed == JNI_TRUE) {
options.inference_priority1 = TFLITE_GPU_INFERENCE_PRIORITY_MIN_LATENCY;
@@ -42,6 +43,18 @@
options.experimental_flags |= TFLITE_GPU_EXPERIMENTAL_FLAGS_ENABLE_QUANT;
}
options.inference_preference = static_cast<int32_t>(inference_preference);
+ if (serialization_dir) {
+ options.serialization_dir =
+ env->GetStringUTFChars(serialization_dir, /*isCopy=*/nullptr);
+ }
+ if (model_token) {
+ options.model_token =
+ env->GetStringUTFChars(model_token, /*isCopy=*/nullptr);
+ }
+ if (options.serialization_dir && options.model_token) {
+ options.experimental_flags |=
+ TFLITE_GPU_EXPERIMENTAL_FLAGS_ENABLE_SERIALIZATION;
+ }
return reinterpret_cast<jlong>(TfLiteGpuDelegateV2Create(&options));
}
diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/GpuDelegateTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/GpuDelegateTest.java
index 66bc35b..2acc341 100644
--- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/GpuDelegateTest.java
+++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/GpuDelegateTest.java
@@ -16,7 +16,10 @@
package org.tensorflow.lite.gpu;
import static com.google.common.truth.Truth.assertThat;
+import static java.util.concurrent.TimeUnit.MICROSECONDS;
+import com.google.common.base.Stopwatch;
+import java.io.File;
import java.nio.ByteBuffer;
import java.util.AbstractMap;
import java.util.ArrayList;
@@ -24,7 +27,9 @@
import java.util.HashMap;
import java.util.Map;
import java.util.PriorityQueue;
+import org.junit.Rule;
import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.tensorflow.lite.Interpreter;
@@ -41,6 +46,8 @@
TestUtils.getTestFileAsBuffer(
"tensorflow/lite/java/demo/app/src/main/assets/mobilenet_v1_1.0_224_quant.tflite");
+ @Rule public final TemporaryFolder tempDir = new TemporaryFolder();
+
@Test
public void testBasic() throws Exception {
try (GpuDelegate delegate = new GpuDelegate()) {
@@ -113,6 +120,59 @@
}
}
+ @Test
+ public void testDelegateSerialization() throws Exception {
+ ByteBuffer img =
+ TestUtils.getTestImageAsByteBuffer(
+ "tensorflow/lite/java/src/testdata/grace_hopper_224.jpg");
+
+ File serializationFolder = tempDir.newFolder();
+ String serializationDir = serializationFolder.getPath();
+
+ // Create the interpreter with serialization enabled delegate.
+ createInterpreterWithDelegate(/*enableSerialization=*/ true, serializationFolder.getPath());
+
+ // In the second interpreter initialization, delegate reuses the serialization data.
+ Stopwatch stopWatch = Stopwatch.createStarted();
+ Interpreter interpreter =
+ createInterpreterWithDelegate(/*enableSerialization=*/ true, serializationFolder.getPath());
+ stopWatch.stop();
+ long serializedInitTime = stopWatch.elapsed(MICROSECONDS);
+ // Check on the model.
+ byte[][] output = new byte[1][1001];
+ interpreter.run(img, output);
+ // 653 == "military uniform"
+ assertThat(getTopKLabels(output, 3)).contains(653);
+
+ // If OpenCL is available, serialized data will be written to serializationDir and
+ // an initialization-time improvement should be observed.
+ // Otherwise, this test case only verifies that enabling the option does not crash.
+ if (serializationFolder.list().length > 0) {
+ stopWatch.reset();
+ stopWatch.start();
+ // Initialize the interpreter with GpuDelegate serialization disabled.
+ createInterpreterWithDelegate(/*enableSerialization=*/ false, /*serializationDir=*/ null);
+ long notserializedInitTime = stopWatch.elapsed(MICROSECONDS);
+
+ assertThat(serializedInitTime).isLessThan(notserializedInitTime);
+ }
+ }
+
+ private Interpreter createInterpreterWithDelegate(
+ boolean enableSerialization, String serializationDir) {
+ Interpreter.Options options = new Interpreter.Options();
+ if (enableSerialization) {
+ options.addDelegate(
+ new GpuDelegate(
+ new GpuDelegate.Options()
+ .setSerializationParams(serializationDir, "GpuDelegateTest.testModelToken")));
+ } else {
+ options.addDelegate(new GpuDelegate());
+ }
+ Interpreter interpreter = new Interpreter(MOBILENET_QUANTIZED_MODEL_BUFFER, options);
+ return interpreter;
+ }
+
private static ArrayList<Integer> getTopKLabels(byte[][] byteLabels, int k) {
float[][] labels = new float[1][1001];
for (int i = 0; i < byteLabels[0].length; ++i) {