[tf.lite] Adds setSerializationParams() Java API for enabling GPU delegate serialization.

PiperOrigin-RevId: 401822099
Change-Id: I32deb8a95eadb72e4169da3bc328e95ded2900c6
diff --git a/RELEASE.md b/RELEASE.md
index f874104..ede68b9 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -26,6 +26,10 @@
 *<SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>
 *<IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
 *<NOTES SHOULD BE GROUPED PER AREA>
+* `tf.lite`:
+  * GPU
+    * Adds serialization support for the GPU delegate to the Java API. This
+      boosts initialization time by up to 90% when OpenCL is available.
 
 # Thanks to our Contributors
 
diff --git a/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/GpuDelegate.java b/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/GpuDelegate.java
index e198dad..13524a4 100644
--- a/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/GpuDelegate.java
+++ b/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/GpuDelegate.java
@@ -64,7 +64,7 @@
     }
 
     /**
-     * Enables running quantized models with the delegate. Defaults to false.
+     * Enables running quantized models with the delegate.
      *
      * <p>WARNING: This is an experimental API and subject to change.
      *
@@ -86,9 +86,30 @@
       return this;
     }
 
+    /**
+     * Enables serialization on the delegate. Note non-null {@code serializationDir} and {@code
+     * modelToken} are required for serialization.
+     *
+     * <p>WARNING: This is an experimental API and subject to change.
+     *
+     * @param serializationDir The directory to use for storing data. Caller is responsible for
+     *     ensuring the model is not stored in a public directory. It's recommended to use {@link
+     *     android.content.Context#getCodeCacheDir()} to provide a private location for the
+     *     application on Android.
+     * @param modelToken The token to be used to identify the model. Caller is responsible for
+     *     ensuring the token is unique to the model graph and data.
+     */
+    public Options setSerializationParams(String serializationDir, String modelToken) {
+      this.serializationDir = serializationDir;
+      this.modelToken = modelToken;
+      return this;
+    }
+
     boolean precisionLossAllowed = true;
     boolean quantizedModelsAllowed = true;
     int inferencePreference = INFERENCE_PREFERENCE_FAST_SINGLE_ANSWER;
+    String serializationDir = null;
+    String modelToken = null;
   }
 
   public GpuDelegate(Options options) {
@@ -96,7 +117,9 @@
         createDelegate(
             options.precisionLossAllowed,
             options.quantizedModelsAllowed,
-            options.inferencePreference);
+            options.inferencePreference,
+            options.serializationDir,
+            options.modelToken);
   }
 
   @UsedByReflection("TFLiteSupport/model/GpuDelegateProxy")
@@ -127,7 +150,11 @@
   }
 
   private static native long createDelegate(
-      boolean precisionLossAllowed, boolean quantizedModelsAllowed, int preference);
+      boolean precisionLossAllowed,
+      boolean quantizedModelsAllowed,
+      int preference,
+      String serializationDir,
+      String modelToken);
 
   private static native void deleteDelegate(long delegateHandle);
 }
diff --git a/tensorflow/lite/delegates/gpu/java/src/main/native/gpu_delegate_jni.cc b/tensorflow/lite/delegates/gpu/java/src/main/native/gpu_delegate_jni.cc
index 62a52f3..8d85c53 100644
--- a/tensorflow/lite/delegates/gpu/java/src/main/native/gpu_delegate_jni.cc
+++ b/tensorflow/lite/delegates/gpu/java/src/main/native/gpu_delegate_jni.cc
@@ -29,7 +29,8 @@
 
 JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_gpu_GpuDelegate_createDelegate(
     JNIEnv* env, jclass clazz, jboolean precision_loss_allowed,
-    jboolean quantized_models_allowed, jint inference_preference) {
+    jboolean quantized_models_allowed, jint inference_preference,
+    jstring serialization_dir, jstring model_token) {
   TfLiteGpuDelegateOptionsV2 options = TfLiteGpuDelegateOptionsV2Default();
   if (precision_loss_allowed == JNI_TRUE) {
     options.inference_priority1 = TFLITE_GPU_INFERENCE_PRIORITY_MIN_LATENCY;
@@ -42,6 +43,18 @@
     options.experimental_flags |= TFLITE_GPU_EXPERIMENTAL_FLAGS_ENABLE_QUANT;
   }
   options.inference_preference = static_cast<int32_t>(inference_preference);
+  if (serialization_dir) {
+    options.serialization_dir =
+        env->GetStringUTFChars(serialization_dir, /*isCopy=*/nullptr);
+  }
+  if (model_token) {
+    options.model_token =
+        env->GetStringUTFChars(model_token, /*isCopy=*/nullptr);
+  }
+  if (options.serialization_dir && options.model_token) {
+    options.experimental_flags |=
+        TFLITE_GPU_EXPERIMENTAL_FLAGS_ENABLE_SERIALIZATION;
+  }
   return reinterpret_cast<jlong>(TfLiteGpuDelegateV2Create(&options));
 }
 
diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/GpuDelegateTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/GpuDelegateTest.java
index 66bc35b..2acc341 100644
--- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/GpuDelegateTest.java
+++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/GpuDelegateTest.java
@@ -16,7 +16,10 @@
 package org.tensorflow.lite.gpu;
 
 import static com.google.common.truth.Truth.assertThat;
+import static java.util.concurrent.TimeUnit.MICROSECONDS;
 
+import com.google.common.base.Stopwatch;
+import java.io.File;
 import java.nio.ByteBuffer;
 import java.util.AbstractMap;
 import java.util.ArrayList;
@@ -24,7 +27,9 @@
 import java.util.HashMap;
 import java.util.Map;
 import java.util.PriorityQueue;
+import org.junit.Rule;
 import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
 import org.tensorflow.lite.Interpreter;
@@ -41,6 +46,8 @@
       TestUtils.getTestFileAsBuffer(
           "tensorflow/lite/java/demo/app/src/main/assets/mobilenet_v1_1.0_224_quant.tflite");
 
+  @Rule public final TemporaryFolder tempDir = new TemporaryFolder();
+
   @Test
   public void testBasic() throws Exception {
     try (GpuDelegate delegate = new GpuDelegate()) {
@@ -113,6 +120,59 @@
     }
   }
 
+  @Test
+  public void testDelegateSerialization() throws Exception {
+    ByteBuffer img =
+        TestUtils.getTestImageAsByteBuffer(
+            "tensorflow/lite/java/src/testdata/grace_hopper_224.jpg");
+
+    File serializationFolder = tempDir.newFolder();
+    String serializationDir = serializationFolder.getPath();
+
+    // Create the interpreter with serialization enabled delegate.
+    createInterpreterWithDelegate(/*enableSerialization=*/ true, serializationFolder.getPath());
+
+    // In the second interpreter initialization, delegate reuses the serialization data.
+    Stopwatch stopWatch = Stopwatch.createStarted();
+    Interpreter interpreter =
+        createInterpreterWithDelegate(/*enableSerialization=*/ true, serializationFolder.getPath());
+    stopWatch.stop();
+    long serializedInitTime = stopWatch.elapsed(MICROSECONDS);
+    // Check on the model.
+    byte[][] output = new byte[1][1001];
+    interpreter.run(img, output);
+    // 653 == "military uniform"
+    assertThat(getTopKLabels(output, 3)).contains(653);
+
+    // If OpenCL is available, serialized data will be written to serializationDir and
+    // initialization time improvement shall be observed.
+    // Otherwise, this test case only verifies that enabling the option does not crash.
+    if (serializationFolder.list().length > 0) {
+      stopWatch.reset();
+      stopWatch.start();
+      // Initialize interpreter with GpuDelegate serialization not enabled.
+      createInterpreterWithDelegate(/*enableSerialization=*/ false, /*serializationDir=*/ null);
+      long notserializedInitTime = stopWatch.elapsed(MICROSECONDS);
+
+      assertThat(serializedInitTime).isLessThan(notserializedInitTime);
+    }
+  }
+
+  private Interpreter createInterpreterWithDelegate(
+      boolean enableSerialization, String serializationDir) {
+    Interpreter.Options options = new Interpreter.Options();
+    if (enableSerialization) {
+      options.addDelegate(
+          new GpuDelegate(
+              new GpuDelegate.Options()
+                  .setSerializationParams(serializationDir, "GpuDelegateTest.testModelToken")));
+    } else {
+      options.addDelegate(new GpuDelegate());
+    }
+    Interpreter interpreter = new Interpreter(MOBILENET_QUANTIZED_MODEL_BUFFER, options);
+    return interpreter;
+  }
+
   private static ArrayList<Integer> getTopKLabels(byte[][] byteLabels, int k) {
     float[][] labels = new float[1][1001];
     for (int i = 0; i < byteLabels[0].length; ++i) {