Allow XNNPack application to fail in Java API if another static delegate has been applied
PiperOrigin-RevId: 328838712
Change-Id: Iefd56d264bb75cccc0a6fedb5e7d680c387d20f5
diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
index 59afc0c..ba0f569 100644
--- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
+++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
@@ -152,7 +152,7 @@
* <ul>
* <li>Startup time and resize time may increase.
* <li>Baseline memory consumption may increase.
- * <li>Compatibility with other delegates (e.g., GPU) has not been fully validated.
+ * <li>May be ignored if another delegate (eg NNAPI) have been applied.
* <li>Quantized models will not see any benefit.
* </ul>
*
diff --git a/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
index 2d1844f..fc0857f 100644
--- a/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
+++ b/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
@@ -367,8 +367,14 @@
}
tflite_api_dispatcher::Interpreter::TfLiteDelegatePtr delegate(
xnnpack_create(&options), xnnpack_delete);
- if (interpreter->ModifyGraphWithDelegate(std::move(delegate)) !=
- kTfLiteOk) {
+ auto delegation_status =
+ interpreter->ModifyGraphWithDelegate(std::move(delegate));
+ // kTfLiteApplicationError occurs in cases where delegation fails but
+ // the runtime is invokable (eg. another delegate has already been applied).
+ // We don't throw an Exception in that case.
+ // TODO(b/166483905): Add support for multiple delegates when model allows.
+ if (delegation_status != kTfLiteOk &&
+ delegation_status != kTfLiteApplicationError) {
ThrowException(env, kIllegalArgumentException,
"Internal error: Failed to apply XNNPACK delegate: %s",
error_reporter->CachedErrorMessage());
diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/nnapi/NnApiDelegateTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/nnapi/NnApiDelegateTest.java
index 45d66e2..fc9038c 100644
--- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/nnapi/NnApiDelegateTest.java
+++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/nnapi/NnApiDelegateTest.java
@@ -57,6 +57,25 @@
}
@Test
+ public void testInterpreterWithNnApiAndXNNPack() throws Exception {
+ Interpreter.Options options = new Interpreter.Options();
+ options.setUseXNNPACK(true);
+
+ try (NnApiDelegate delegate = new NnApiDelegate();
+ Interpreter interpreter = new Interpreter(MODEL_BUFFER, options.addDelegate(delegate))) {
+ float[] oneD = {1.23f, 6.54f, 7.81f};
+ float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ float[][][][] fourD = {threeD, threeD};
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ interpreter.run(fourD, parsedOutputs);
+ float[] outputOneD = parsedOutputs[0][0][0];
+ float[] expected = {3.69f, 19.62f, 23.43f};
+ assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
+ }
+ }
+
+ @Test
public void testInterpreterWithNnApiAllowFp16() throws Exception {
Interpreter.Options options = new Interpreter.Options();
NnApiDelegate.Options nnApiOptions = new NnApiDelegate.Options();