blob: 9ed24781cc04997cdc89110deb5f4c3015de08d9 [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail;
import java.io.File;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.util.HashMap;
import java.util.Map;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.tensorflow.lite.InterpreterApi.Options.TfLiteRuntime;
/** Unit tests for {@link org.tensorflow.lite.InterpreterApi}. */
@RunWith(JUnit4.class)
public final class InterpreterApiTest {
private static final String MODEL_PATH = "tensorflow/lite/java/src/testdata/add.bin";
private static final String MULTIPLE_INPUTS_MODEL_PATH =
"tensorflow/lite/testdata/multi_add.bin";
private static final String FLEX_MODEL_PATH =
"tensorflow/lite/testdata/multi_add_flex.bin";
private static final String UNKNOWN_DIMS_MODEL_PATH =
"tensorflow/lite/java/src/testdata/add_unknown_dimensions.bin";
private static final String DYNAMIC_SHAPES_MODEL_PATH =
"tensorflow/lite/testdata/dynamic_shapes.bin";
private static final String BOOL_MODEL =
"tensorflow/lite/java/src/testdata/tile_with_bool_input.bin";
private static final ByteBuffer MODEL_BUFFER = TestUtils.getTestFileAsBuffer(MODEL_PATH);
private static final ByteBuffer MULTIPLE_INPUTS_MODEL_BUFFER =
TestUtils.getTestFileAsBuffer(MULTIPLE_INPUTS_MODEL_PATH);
private static final ByteBuffer FLEX_MODEL_BUFFER =
TestUtils.getTestFileAsBuffer(FLEX_MODEL_PATH);
private static final ByteBuffer UNKNOWN_DIMS_MODEL_PATH_BUFFER =
TestUtils.getTestFileAsBuffer(UNKNOWN_DIMS_MODEL_PATH);
private static final ByteBuffer DYNAMIC_SHAPES_MODEL_BUFFER =
TestUtils.getTestFileAsBuffer(DYNAMIC_SHAPES_MODEL_PATH);
private static final ByteBuffer BOOL_MODEL_BUFFER = TestUtils.getTestFileAsBuffer(BOOL_MODEL);
// We want to run these tests both with the TF Lite runtime library linked in,
// and also using the system TF Lite runtime library if the client for that is linked in.
// So we need to use a runtime setting that will work with both scenarios.
private static final InterpreterApi.Options TEST_OPTIONS =
new InterpreterApi.Options().setRuntime(TfLiteRuntime.PREFER_SYSTEM_OVER_APPLICATION);
@Before
public void setUp() {
TestInit.init();
}
@Test
public void testInterpreter() throws Exception {
try (InterpreterApi interpreter = new InterpreterFactory().create(MODEL_BUFFER, TEST_OPTIONS)) {
assertThat(interpreter).isNotNull();
assertThat(interpreter.getInputTensorCount()).isEqualTo(1);
assertThat(interpreter.getInputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
assertThat(interpreter.getOutputTensorCount()).isEqualTo(1);
assertThat(interpreter.getOutputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
}
}
@Test
public void testInterpreterWithOptions() throws Exception {
InterpreterApi.Options options = new InterpreterApi.Options(TEST_OPTIONS);
try (InterpreterApi interpreter =
new InterpreterFactory().create(MODEL_BUFFER, options.setNumThreads(2).setUseNNAPI(true))) {
assertThat(interpreter).isNotNull();
assertThat(interpreter.getInputTensorCount()).isEqualTo(1);
assertThat(interpreter.getInputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
assertThat(interpreter.getOutputTensorCount()).isEqualTo(1);
assertThat(interpreter.getOutputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
}
}
@Test
public void testInterpreterWithNullOptions() throws Exception {
try (InterpreterApi interpreter = new InterpreterFactory().create(MODEL_BUFFER, null)) {
assertThat(interpreter).isNotNull();
assertThat(interpreter.getInputTensorCount()).isEqualTo(1);
assertThat(interpreter.getInputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
assertThat(interpreter.getOutputTensorCount()).isEqualTo(1);
assertThat(interpreter.getOutputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
} catch (IllegalStateException e) {
// This can occur when this code is not linked against the TF Lite runtime.
// Verify that the error message has some hints about how to link
// against the runtime ("org.tensorflow:tensorflow-lite:<version>").
assertThat(e).hasMessageThat().contains("org.tensorflow");
assertThat(e).hasMessageThat().contains("tensorflow-lite");
}
}
@Test
public void testRuntimeFromApplicationOnly() throws Exception {
InterpreterApi.Options options =
new InterpreterApi.Options().setRuntime(TfLiteRuntime.FROM_APPLICATION_ONLY);
try (InterpreterApi interpreter = new InterpreterFactory().create(MODEL_BUFFER, options)) {
assertThat(interpreter).isNotNull();
assertThat(interpreter.getInputTensorCount()).isEqualTo(1);
assertThat(interpreter.getInputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
assertThat(interpreter.getOutputTensorCount()).isEqualTo(1);
assertThat(interpreter.getOutputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
} catch (IllegalStateException e) {
// This can occur when this code is not linked against the TF Lite runtime.
// Verify that the error message has some hints about how to link
// against the runtime ("org.tensorflow:tensorflow-lite:<version>").
assertThat(e).hasMessageThat().contains("org.tensorflow");
assertThat(e).hasMessageThat().contains("tensorflow-lite");
}
}
@Test
public void testRuntimeFromSystemOnly() throws Exception {
InterpreterApi.Options options =
new InterpreterApi.Options().setRuntime(TfLiteRuntime.FROM_SYSTEM_ONLY);
try (InterpreterApi interpreter = new InterpreterFactory().create(MODEL_BUFFER, options)) {
assertThat(interpreter).isNotNull();
assertThat(interpreter.getInputTensorCount()).isEqualTo(1);
assertThat(interpreter.getInputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
assertThat(interpreter.getOutputTensorCount()).isEqualTo(1);
assertThat(interpreter.getOutputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
} catch (IllegalStateException e) {
// This can occur when this code is not linked against the TF Lite runtime.
// Verify that the error message has some hints about how to link in the
// client library for TF Lite in Google Play Services
// ("com.google.android.gms:play-services-tflite-java:<version>").
assertThat(e).hasMessageThat().contains("com.google.android.gms");
assertThat(e).hasMessageThat().contains("play-services-tflite-java");
}
}
@Test
public void testRunWithFileModel() throws Exception {
if (!TestUtils.supportsFilePaths()) {
System.err.println("Not testing with file model, since file paths aren't supported.");
return;
}
try (InterpreterApi interpreter =
new InterpreterFactory().create(new File(MODEL_PATH), TEST_OPTIONS)) {
float[] oneD = {1.23f, 6.54f, 7.81f};
float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
float[][][][] parsedOutputs = new float[2][8][8][3];
interpreter.run(fourD, parsedOutputs);
float[] outputOneD = parsedOutputs[0][0][0];
float[] expected = {3.69f, 19.62f, 23.43f};
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
}
}
@Test
public void testRunWithDirectByteBufferModel() throws Exception {
ByteBuffer byteBuffer = ByteBuffer.allocateDirect(MODEL_BUFFER.capacity());
byteBuffer.order(ByteOrder.nativeOrder());
byteBuffer.put(MODEL_BUFFER.duplicate()); // Use duplicate to avoid updating MODEL_BUFFER.
try (InterpreterApi interpreter = new InterpreterFactory().create(byteBuffer, TEST_OPTIONS)) {
float[] oneD = {1.23f, 6.54f, 7.81f};
float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
float[][][][] parsedOutputs = new float[2][8][8][3];
interpreter.run(fourD, parsedOutputs);
float[] outputOneD = parsedOutputs[0][0][0];
float[] expected = {3.69f, 19.62f, 23.43f};
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
}
}
@Test
public void testRunWithInvalidByteBufferModel() throws Exception {
ByteBuffer byteBuffer = ByteBuffer.allocate(MODEL_BUFFER.capacity());
byteBuffer.order(ByteOrder.nativeOrder());
byteBuffer.put(MODEL_BUFFER.duplicate()); // Use duplicate to avoid updating MODEL_BUFFER.
try {
new InterpreterFactory().create(byteBuffer, TEST_OPTIONS);
fail();
} catch (IllegalArgumentException e) {
assertThat(e)
.hasMessageThat()
.contains(
"Model ByteBuffer should be either a MappedByteBuffer"
+ " of the model file, or a direct ByteBuffer using ByteOrder.nativeOrder()");
}
}
@Test
public void testRun() {
try (InterpreterApi interpreter = new InterpreterFactory().create(MODEL_BUFFER, TEST_OPTIONS)) {
Float[] oneD = {1.23f, 6.54f, 7.81f};
Float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
Float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
Float[][][][] fourD = {threeD, threeD};
Float[][][][] parsedOutputs = new Float[2][8][8][3];
try {
interpreter.run(fourD, parsedOutputs);
fail();
} catch (IllegalArgumentException e) {
assertThat(e).hasMessageThat().contains("cannot resolve DataType of [[[[Ljava.lang.Float;");
}
}
}
@Test
public void testRunWithBoxedInputs() {
try (InterpreterApi interpreter = new InterpreterFactory().create(MODEL_BUFFER, TEST_OPTIONS)) {
float[] oneD = {1.23f, 6.54f, 7.81f};
float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
float[][][][] parsedOutputs = new float[2][8][8][3];
interpreter.run(fourD, parsedOutputs);
float[] outputOneD = parsedOutputs[0][0][0];
float[] expected = {3.69f, 19.62f, 23.43f};
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
}
}
@Test
public void testRunForMultipleInputsOutputs() {
try (InterpreterApi interpreter =
new InterpreterFactory().create(MULTIPLE_INPUTS_MODEL_BUFFER, TEST_OPTIONS)) {
assertThat(interpreter.getInputTensorCount()).isEqualTo(4);
assertThat(interpreter.getInputTensor(0).index()).isGreaterThan(-1);
assertThat(interpreter.getInputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
assertThat(interpreter.getInputTensor(1).dataType()).isEqualTo(DataType.FLOAT32);
assertThat(interpreter.getInputTensor(2).dataType()).isEqualTo(DataType.FLOAT32);
assertThat(interpreter.getInputTensor(3).dataType()).isEqualTo(DataType.FLOAT32);
assertThat(interpreter.getOutputTensorCount()).isEqualTo(2);
assertThat(interpreter.getOutputTensor(0).index()).isGreaterThan(-1);
assertThat(interpreter.getOutputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
assertThat(interpreter.getOutputTensor(1).dataType()).isEqualTo(DataType.FLOAT32);
float[] input0 = {1.23f};
float[] input1 = {2.43f};
Object[] inputs = {input0, input1, input0, input1};
float[] parsedOutput0 = new float[1];
float[] parsedOutput1 = new float[1];
Map<Integer, Object> outputs = new HashMap<>();
outputs.put(0, parsedOutput0);
outputs.put(1, parsedOutput1);
interpreter.runForMultipleInputsOutputs(inputs, outputs);
float[] expected0 = {4.89f};
float[] expected1 = {6.09f};
assertThat(parsedOutput0).usingTolerance(0.1f).containsExactly(expected0).inOrder();
assertThat(parsedOutput1).usingTolerance(0.1f).containsExactly(expected1).inOrder();
}
}
@Test
public void testRunWithByteBufferOutput() {
float[] oneD = {1.23f, 6.54f, 7.81f};
float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
ByteBuffer parsedOutput =
ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder());
try (InterpreterApi interpreter = new InterpreterFactory().create(MODEL_BUFFER, TEST_OPTIONS)) {
interpreter.run(fourD, parsedOutput);
}
float[] outputOneD = {
parsedOutput.getFloat(0), parsedOutput.getFloat(4), parsedOutput.getFloat(8)
};
float[] expected = {3.69f, 19.62f, 23.43f};
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
}
@Test
public void testRunWithScalarInput() {
FloatBuffer parsedOutput = FloatBuffer.allocate(1);
try (InterpreterApi interpreter = new InterpreterFactory().create(MODEL_BUFFER, TEST_OPTIONS)) {
interpreter.run(2.37f, parsedOutput);
}
assertThat(parsedOutput.get(0)).isWithin(0.1f).of(7.11f);
}
@Test
public void testResizeInput() {
try (InterpreterApi interpreter = new InterpreterFactory().create(MODEL_BUFFER, TEST_OPTIONS)) {
int[] inputDims = {1};
interpreter.resizeInput(0, inputDims);
assertThat(interpreter.getInputTensor(0).shape()).isEqualTo(inputDims);
ByteBuffer input = ByteBuffer.allocateDirect(4).order(ByteOrder.nativeOrder());
ByteBuffer output = ByteBuffer.allocateDirect(4).order(ByteOrder.nativeOrder());
interpreter.run(input, output);
assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(inputDims);
}
}
@Test
public void testAllocateTensors() {
try (InterpreterApi interpreter = new InterpreterFactory().create(MODEL_BUFFER, TEST_OPTIONS)) {
// Redundant allocateTensors() should have no effect.
interpreter.allocateTensors();
// allocateTensors() should propagate resizes.
int[] inputDims = {1};
assertThat(interpreter.getOutputTensor(0).shape()).isNotEqualTo(inputDims);
interpreter.resizeInput(0, inputDims);
assertThat(interpreter.getOutputTensor(0).shape()).isNotEqualTo(inputDims);
interpreter.allocateTensors();
assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(inputDims);
// Additional redundant calls should have no effect.
interpreter.allocateTensors();
assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(inputDims);
// Execution should succeed as expected.
ByteBuffer input = ByteBuffer.allocateDirect(4).order(ByteOrder.nativeOrder());
ByteBuffer output = ByteBuffer.allocateDirect(4).order(ByteOrder.nativeOrder());
interpreter.run(input, output);
assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(inputDims);
}
}
@Test
public void testUnknownDims() {
try (InterpreterApi interpreter =
new InterpreterFactory().create(UNKNOWN_DIMS_MODEL_PATH_BUFFER, TEST_OPTIONS)) {
int[] inputDims = {1, 1, 3, 3};
int[] inputDimsSignature = {1, -1, 3, 3};
assertThat(interpreter.getInputTensor(0).shape()).isEqualTo(inputDims);
assertThat(interpreter.getInputTensor(0).shapeSignature()).isEqualTo(inputDimsSignature);
// Resize tensor with strict checking. Try invalid resize.
inputDims[2] = 5;
try {
interpreter.resizeInput(0, inputDims, true);
fail();
} catch (IllegalArgumentException e) {
assertThat(e)
.hasMessageThat()
.contains(
"ResizeInputTensorStrict only allows mutating unknown dimensions identified by -1");
}
inputDims[2] = 3;
// Set the dimension of the unknown dimension to the expected dimension and ensure shape
// signature doesn't change.
inputDims[1] = 3;
interpreter.resizeInput(0, inputDims, true);
assertThat(interpreter.getInputTensor(0).shape()).isEqualTo(inputDims);
assertThat(interpreter.getInputTensor(0).shapeSignature()).isEqualTo(inputDimsSignature);
ByteBuffer input =
ByteBuffer.allocateDirect(1 * 3 * 3 * 3 * 4).order(ByteOrder.nativeOrder());
ByteBuffer output =
ByteBuffer.allocateDirect(1 * 3 * 3 * 3 * 4).order(ByteOrder.nativeOrder());
interpreter.run(input, output);
assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(inputDims);
}
}
@Test
public void testRunWithWrongInputType() {
try (InterpreterApi interpreter = new InterpreterFactory().create(MODEL_BUFFER, TEST_OPTIONS)) {
int[] oneD = {4, 3, 9};
int[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
int[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
int[][][][] fourD = {threeD, threeD};
float[][][][] parsedOutputs = new float[2][8][8][3];
try {
interpreter.run(fourD, parsedOutputs);
fail();
} catch (IllegalArgumentException e) {
assertThat(e)
.hasMessageThat()
.contains(
"Cannot convert between a TensorFlowLite tensor with type "
+ "FLOAT32 and a Java object of type [[[[I (which is compatible with the"
+ " TensorFlowLite type INT32)");
}
}
}
@Test
public void testRunWithUnsupportedInputType() {
DoubleBuffer doubleBuffer = DoubleBuffer.allocate(10);
float[][][][] parsedOutputs = new float[2][8][8][3];
try (InterpreterApi interpreter = new InterpreterFactory().create(MODEL_BUFFER, TEST_OPTIONS)) {
interpreter.run(doubleBuffer, parsedOutputs);
fail();
} catch (IllegalArgumentException e) {
assertThat(e).hasMessageThat().contains("DataType error: cannot resolve DataType of");
}
}
@Test
public void testRunWithWrongOutputType() {
try (InterpreterApi interpreter = new InterpreterFactory().create(MODEL_BUFFER, TEST_OPTIONS)) {
float[] oneD = {1.23f, 6.54f, 7.81f};
float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
int[][][][] parsedOutputs = new int[2][8][8][3];
try {
interpreter.run(fourD, parsedOutputs);
fail();
} catch (IllegalArgumentException e) {
assertThat(e)
.hasMessageThat()
.contains(
"Cannot convert between a TensorFlowLite tensor with type "
+ "FLOAT32 and a Java object of type [[[[I (which is compatible with the"
+ " TensorFlowLite type INT32)");
}
}
}
@Test
public void testGetInputIndex() {
try (InterpreterApi interpreter = new InterpreterFactory().create(MODEL_BUFFER, TEST_OPTIONS)) {
try {
interpreter.getInputIndex("WrongInputName");
fail();
} catch (IllegalArgumentException e) {
assertThat(e)
.hasMessageThat()
.contains(
"'WrongInputName' is not a valid name for any input. Names of inputs and their "
+ "indexes are {input=0}");
}
int index = interpreter.getInputIndex("input");
assertThat(index).isEqualTo(0);
}
}
@Test
public void testGetOutputIndex() {
try (InterpreterApi interpreter = new InterpreterFactory().create(MODEL_BUFFER, TEST_OPTIONS)) {
try {
interpreter.getOutputIndex("WrongOutputName");
fail();
} catch (IllegalArgumentException e) {
assertThat(e)
.hasMessageThat()
.contains(
"'WrongOutputName' is not a valid name for any output. Names of outputs and their"
+ " indexes are {output=0}");
}
int index = interpreter.getOutputIndex("output");
assertThat(index).isEqualTo(0);
}
}
@Test
public void testTurnOnNNAPI() throws Exception {
InterpreterApi.Options options = new InterpreterApi.Options(TEST_OPTIONS).setUseNNAPI(true);
try (InterpreterApi interpreter = new InterpreterFactory().create(MODEL_BUFFER, options)) {
float[] oneD = {1.23f, 6.54f, 7.81f};
float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
float[][][][] parsedOutputs = new float[2][8][8][3];
interpreter.run(fourD, parsedOutputs);
float[] outputOneD = parsedOutputs[0][0][0];
float[] expected = {3.69f, 19.62f, 23.43f};
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
}
}
@Test
public void testRedundantClose() throws Exception {
try (InterpreterApi interpreter = new InterpreterFactory().create(MODEL_BUFFER, TEST_OPTIONS)) {
interpreter.close();
interpreter.close();
} // Implicitly calls interpreter.close() for a third time.
}
@Test
public void testNullInputs() throws Exception {
try (InterpreterApi interpreter = new InterpreterFactory().create(MODEL_BUFFER, TEST_OPTIONS)) {
try {
interpreter.run(null, new float[2][8][8][3]);
fail();
} catch (IllegalArgumentException e) {
// Expected failure.
}
}
}
@Test
public void testNullOutputs() throws Exception {
try (InterpreterApi interpreter = new InterpreterFactory().create(MODEL_BUFFER, TEST_OPTIONS)) {
float[] input = {1.f};
interpreter.run(input, null);
float output = interpreter.getOutputTensor(0).asReadOnlyBuffer().getFloat(0);
assertThat(output).isEqualTo(3.f);
}
}
// Smoke test validating that flex model loading fails when the flex delegate is not linked.
@Test
public void testFlexModel() throws Exception {
try {
new InterpreterFactory().create(FLEX_MODEL_BUFFER, TEST_OPTIONS);
fail();
} catch (IllegalStateException e) {
// Expected failure.
} catch (IllegalArgumentException e) {
// As we could apply some TfLite delegate by default, the flex ops preparation could fail if
// the flex delegate isn't applied first, in which case this type of exception is thrown.
// Expected failure.
}
}
@Test
public void testBoolModel() throws Exception {
boolean[][][] inputs = {{{true, false}, {false, true}}, {{true, true}, {false, true}}};
int[] multipliers = {1, 1, 2};
boolean[][][] parsedOutputs = new boolean[2][2][4];
try (InterpreterApi interpreter =
new InterpreterFactory().create(BOOL_MODEL_BUFFER, TEST_OPTIONS)) {
assertThat(interpreter.getInputTensor(0).dataType()).isEqualTo(DataType.BOOL);
Object[] inputsArray = {inputs, multipliers};
Map<Integer, Object> outputsMap = new HashMap<>();
outputsMap.put(0, parsedOutputs);
interpreter.runForMultipleInputsOutputs(inputsArray, outputsMap);
boolean[][][] expectedOutputs = {
{{true, false, true, false}, {false, true, false, true}},
{{true, true, true, true}, {false, true, false, true}}
};
assertThat(parsedOutputs).isEqualTo(expectedOutputs);
}
}
private static FloatBuffer fill(FloatBuffer buffer, float value) {
while (buffer.hasRemaining()) {
buffer.put(value);
}
buffer.rewind();
return buffer;
}
// Regression test case to ensure that graphs with dynamically computed shapes work properly.
// Historically, direct ByteBuffer addresses would overwrite the arena-allocated tensor input
// pointers. Normally this works fine, but for dynamic graphs, the original input tensor pointers
// may be "restored" at invocation time by the arena allocator, resetting the direct ByteBuffer
// address and leading to stale input data being used.
@Test
public void testDynamicShapesWithDirectBufferInputs() {
try (InterpreterApi interpreter =
new InterpreterFactory().create(DYNAMIC_SHAPES_MODEL_BUFFER, TEST_OPTIONS)) {
ByteBuffer input0 =
ByteBuffer.allocateDirect(8 * 42 * 1024 * 4).order(ByteOrder.nativeOrder());
ByteBuffer input1 =
ByteBuffer.allocateDirect(1 * 90 * 1024 * 4).order(ByteOrder.nativeOrder());
ByteBuffer input2 = ByteBuffer.allocateDirect(1 * 4).order(ByteOrder.nativeOrder());
Object[] inputs = {input0, input1, input2};
fill(input0.asFloatBuffer(), 2.0f);
fill(input1.asFloatBuffer(), 0.5f);
// Note that the value of this input dictates the shape of the output.
fill(input2.asFloatBuffer(), 1.0f);
FloatBuffer output = FloatBuffer.allocate(8 * 1 * 1024);
Map<Integer, Object> outputs = new HashMap<>();
outputs.put(0, output);
interpreter.runForMultipleInputsOutputs(inputs, outputs);
FloatBuffer expected = fill(FloatBuffer.allocate(8 * 1 * 1024), 2.0f);
assertThat(output.array()).usingTolerance(0.1f).containsExactly(expected.array()).inOrder();
}
}
@Test
public void testDynamicShapesWithEmptyOutputs() {
try (InterpreterApi interpreter =
new InterpreterFactory().create(DYNAMIC_SHAPES_MODEL_BUFFER, TEST_OPTIONS)) {
ByteBuffer input0 =
ByteBuffer.allocateDirect(8 * 42 * 1024 * 4).order(ByteOrder.nativeOrder());
ByteBuffer input1 =
ByteBuffer.allocateDirect(1 * 90 * 1024 * 4).order(ByteOrder.nativeOrder());
ByteBuffer input2 = ByteBuffer.allocateDirect(1 * 4).order(ByteOrder.nativeOrder());
Object[] inputs = {input0, input1, input2};
fill(input0.asFloatBuffer(), 2.0f);
fill(input1.asFloatBuffer(), 0.5f);
fill(input2.asFloatBuffer(), 1.0f);
// Use an empty output map; the output data will be retrieved directly from the tensor.
Map<Integer, Object> outputs = new HashMap<>();
interpreter.runForMultipleInputsOutputs(inputs, outputs);
FloatBuffer output = FloatBuffer.allocate(8 * 1 * 1024);
output.put(interpreter.getOutputTensor(0).asReadOnlyBuffer().asFloatBuffer());
FloatBuffer expected = fill(FloatBuffer.allocate(8 * 1 * 1024), 2.0f);
assertThat(output.array()).usingTolerance(0.1f).containsExactly(expected.array()).inOrder();
}
}
}