blob: a23007b2853214a3bf4a7e1ea3c9bec8e0d32375 [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import org.tensorflow.lite.nnapi.NnApiDelegate;
/**
 * An internal wrapper that wraps native interpreter and controls model execution.
 *
 * <p><b>WARNING:</b> Resources consumed by the {@code NativeInterpreterWrapper} object must be
 * explicitly freed by invoking the {@link #close()} method when the {@code
 * NativeInterpreterWrapper} object is no longer needed.
 *
 * <p>Note: This class is not thread safe.
 */
final class NativeInterpreterWrapper implements AutoCloseable {
  /** Creates a wrapper for the model file at {@code modelPath} using default options. */
  NativeInterpreterWrapper(String modelPath) {
    this(modelPath, /* options= */ null);
  }

  /** Creates a wrapper for the model held in {@code byteBuffer} using default options. */
  NativeInterpreterWrapper(ByteBuffer byteBuffer) {
    this(byteBuffer, /* options= */ null);
  }

  /**
   * Creates a wrapper for the model file at {@code modelPath}.
   *
   * @param modelPath path of the model file, handed to the native model loader.
   * @param options interpreter options; {@code null} means default options.
   */
  NativeInterpreterWrapper(String modelPath, Interpreter.Options options) {
    TensorFlowLite.init();
    long errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
    long modelHandle = createModel(modelPath, errorHandle);
    init(errorHandle, modelHandle, options);
  }

  /**
   * Creates a wrapper for the model contained in {@code buffer}.
   *
   * <p>A reference to {@code buffer} is kept until {@link #close()}. Presumably this is because
   * the native model reads directly from the buffer's memory; confirm against the JNI layer.
   *
   * @param buffer model bytes; must be a {@link MappedByteBuffer} or a direct buffer in {@link
   *     ByteOrder#nativeOrder()}.
   * @param options interpreter options; {@code null} means default options.
   * @throws IllegalArgumentException if {@code buffer} is null or does not meet the requirements
   *     above.
   */
  NativeInterpreterWrapper(ByteBuffer buffer, Interpreter.Options options) {
    TensorFlowLite.init();
    if (buffer == null
        || (!(buffer instanceof MappedByteBuffer)
            && (!buffer.isDirect() || buffer.order() != ByteOrder.nativeOrder()))) {
      throw new IllegalArgumentException(
          "Model ByteBuffer should be either a MappedByteBuffer of the model file, or a direct "
              + "ByteBuffer using ByteOrder.nativeOrder() which contains bytes of model content.");
    }
    this.modelByteBuffer = buffer;
    long errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
    long modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle);
    init(errorHandle, modelHandle, options);
  }

  /**
   * Creates the native interpreter for an already-loaded model and applies {@code options}.
   *
   * <p>Steps, in order:
   *
   * <ul>
   *   <li>create the interpreter and, if requested, the cancellation flag;
   *   <li>size the lazily populated tensor wrapper arrays;
   *   <li>apply the precision and buffer-handle flags;
   *   <li>apply delegates;
   *   <li>optionally enable XNNPACK;
   *   <li>allocate tensors.
   * </ul>
   *
   * @param errorHandle native error reporter handle.
   * @param modelHandle native model handle.
   * @param options interpreter options; {@code null} means default options.
   */
  private void init(long errorHandle, long modelHandle, Interpreter.Options options) {
    if (options == null) {
      options = new Interpreter.Options();
    }
    this.errorHandle = errorHandle;
    this.modelHandle = modelHandle;
    this.interpreterHandle = createInterpreter(modelHandle, errorHandle, options.numThreads);
    if (options.allowCancellation != null && options.allowCancellation) {
      this.cancellationFlagHandle = createCancellationFlag(interpreterHandle);
    }
    this.inputTensors = new Tensor[getInputCount(interpreterHandle)];
    this.outputTensors = new Tensor[getOutputCount(interpreterHandle)];
    if (options.allowFp16PrecisionForFp32 != null) {
      allowFp16PrecisionForFp32(
          interpreterHandle, options.allowFp16PrecisionForFp32.booleanValue());
    }
    if (options.allowBufferHandleOutput != null) {
      allowBufferHandleOutput(interpreterHandle, options.allowBufferHandleOutput.booleanValue());
    }
    applyDelegates(options);
    // XNNPACK mode: -1 = not specified (default), 0 = explicitly disabled, 1 = explicitly enabled.
    int applyXNNPACKMode = -1;
    if (options.useXNNPACK != null) {
      applyXNNPACKMode = options.useXNNPACK.booleanValue() ? 1 : 0;
    }
    // TODO(b/171856982): uncomment the following when applying XNNPACK delegate by default is
    // enabled for C++ TfLite library on Android platform.
    if (applyXNNPACKMode == 1 /*|| applyXNNPACKMode == -1*/) {
      useXNNPACK(interpreterHandle, errorHandle, applyXNNPACKMode, options.numThreads);
    }
    allocateTensors(interpreterHandle, errorHandle);
    this.isMemoryAllocated = true;
  }

  /** Releases resources associated with this {@code NativeInterpreterWrapper}. */
  @Override
  public void close() {
    // Close the tensors first as they may reference the native interpreter.
    for (int i = 0; i < inputTensors.length; ++i) {
      if (inputTensors[i] != null) {
        inputTensors[i].close();
        inputTensors[i] = null;
      }
    }
    for (int i = 0; i < outputTensors.length; ++i) {
      if (outputTensors[i] != null) {
        outputTensors[i].close();
        outputTensors[i] = null;
      }
    }
    delete(errorHandle, modelHandle, interpreterHandle);
    deleteCancellationFlag(cancellationFlagHandle);
    errorHandle = 0;
    modelHandle = 0;
    interpreterHandle = 0;
    cancellationFlagHandle = 0;
    modelByteBuffer = null;
    // Drop every lazily built lookup map so nothing derived from the old interpreter survives.
    inputsIndexes = null;
    outputsIndexes = null;
    tensorToInputsIndexes = null;
    tensorToOutputsIndexes = null;
    isMemoryAllocated = false;
    delegates.clear();
    // Owned delegates are the auto-created flex delegate and/or the NNAPI delegate. They are
    // closed best-effort so that one failure does not prevent closing the rest.
    for (AutoCloseable ownedDelegate : ownedDelegates) {
      try {
        ownedDelegate.close();
      } catch (Exception e) {
        System.err.println("Failed to close owned delegate: " + e);
      }
    }
    ownedDelegates.clear();
  }

  /**
   * Runs inference for the signature {@code methodName}, addressing inputs and outputs by their
   * signature names.
   *
   * @param inputs map from signature input name to input value; must be non-null and non-empty.
   * @param outputs map from signature output name to destination object; must be non-null and
   *     non-empty.
   * @param methodName name of the SignatureDef method.
   * @throws IllegalArgumentException if either map is null or empty, or a name does not resolve to
   *     an interpreter input/output. Also thrown if the resolved input indexes are not exactly
   *     {@code 0..n-1}; {@link #run} assigns inputs positionally, so a gap would feed values into
   *     the wrong tensors.
   */
  public void runSignature(
      Map<String, Object> inputs, Map<String, Object> outputs, String methodName) {
    if (inputs == null || inputs.isEmpty()) {
      throw new IllegalArgumentException("Input error: Inputs should not be null or empty.");
    }
    if (outputs == null || outputs.isEmpty()) {
      throw new IllegalArgumentException("Input error: Outputs should not be null or empty.");
    }
    initTensorIndexesMaps();
    // Map signature names to interpreter input/output indexes. TreeMap keeps the inputs ordered
    // by input index so they can be flattened into the positional array run() expects.
    Map<Integer, Object> inputsWithInputIndex = new TreeMap<>();
    Map<Integer, Object> outputsWithOutputIndex = new TreeMap<>();
    for (Map.Entry<String, Object> input : inputs.entrySet()) {
      int tensorIndex =
          getInputTensorIndexFromSignature(interpreterHandle, input.getKey(), methodName);
      Integer inputIndex = tensorToInputsIndexes.get(tensorIndex);
      if (inputIndex == null) {
        // A null key would otherwise surface as an opaque NullPointerException from TreeMap.
        throw new IllegalArgumentException(
            String.format(
                "Invalid input tensor name (%s) for signature (%s).", input.getKey(), methodName));
      }
      inputsWithInputIndex.put(inputIndex, input.getValue());
    }
    for (Map.Entry<String, Object> output : outputs.entrySet()) {
      int tensorIndex =
          getOutputTensorIndexFromSignature(interpreterHandle, output.getKey(), methodName);
      Integer outputIndex = tensorToOutputsIndexes.get(tensorIndex);
      if (outputIndex == null) {
        throw new IllegalArgumentException(
            String.format(
                "Invalid output tensor name (%s) for signature (%s).",
                output.getKey(), methodName));
      }
      outputsWithOutputIndex.put(outputIndex, output.getValue());
    }
    // run() feeds element i into input i, so the provided inputs must occupy input indexes
    // 0..n-1 with no gaps; otherwise values after a gap would be silently shifted.
    Object[] inputsList = new Object[inputsWithInputIndex.size()];
    int position = 0;
    for (Map.Entry<Integer, Object> input : inputsWithInputIndex.entrySet()) {
      if (input.getKey() != position) {
        throw new IllegalArgumentException(
            String.format(
                "Input error: Inputs of signature (%s) must map to interpreter input indexes "
                    + "0..%d without gaps, but mapped to %s.",
                methodName, inputsWithInputIndex.size() - 1, inputsWithInputIndex.keySet()));
      }
      inputsList[position++] = input.getValue();
    }
    run(inputsList, outputsWithOutputIndex);
  }

  /**
   * Sets inputs, runs model inference and returns outputs.
   *
   * @param inputs values for input tensors {@code 0..inputs.length-1}, in input-index order.
   * @param outputs map from output index to the destination object the output is copied into.
   * @throws IllegalArgumentException if {@code inputs} or {@code outputs} is null or empty.
   */
  void run(Object[] inputs, Map<Integer, Object> outputs) {
    inferenceDurationNanoseconds = -1;
    if (inputs == null || inputs.length == 0) {
      throw new IllegalArgumentException("Input error: Inputs should not be null or empty.");
    }
    if (outputs == null || outputs.isEmpty()) {
      throw new IllegalArgumentException("Input error: Outputs should not be null or empty.");
    }
    // TODO(b/80431971): Remove implicit resize after deprecating multi-dimensional array inputs.
    // Rather than forcing an immediate resize + allocation if an input's shape differs, we first
    // flush all resizes, avoiding redundant allocations.
    for (int i = 0; i < inputs.length; ++i) {
      Tensor tensor = getInputTensor(i);
      int[] newShape = tensor.getInputShapeIfDifferent(inputs[i]);
      if (newShape != null) {
        resizeInput(i, newShape);
      }
    }
    boolean needsAllocation = !isMemoryAllocated;
    if (needsAllocation) {
      allocateTensors(interpreterHandle, errorHandle);
      isMemoryAllocated = true;
    }
    for (int i = 0; i < inputs.length; ++i) {
      getInputTensor(i).setTo(inputs[i]);
    }
    long inferenceStartNanos = System.nanoTime();
    run(interpreterHandle, errorHandle);
    long durationNanos = System.nanoTime() - inferenceStartNanos;
    // Allocation can trigger dynamic resizing of output tensors, so refresh all output shapes.
    if (needsAllocation) {
      for (int i = 0; i < outputTensors.length; ++i) {
        if (outputTensors[i] != null) {
          outputTensors[i].refreshShape();
        }
      }
    }
    for (Map.Entry<Integer, Object> output : outputs.entrySet()) {
      getOutputTensor(output.getKey()).copyTo(output.getValue());
    }
    // Only record the duration if the entire operation succeeded.
    this.inferenceDurationNanoseconds = durationNanos;
  }

  private static native void run(long interpreterHandle, long errorHandle);

  /** Resizes dimensions of a specific input (non-strict). */
  void resizeInput(int idx, int[] dims) {
    resizeInput(idx, dims, false);
  }

  /**
   * Resizes dimensions of a specific input.
   *
   * @param idx input index.
   * @param dims new dimensions.
   * @param strict forwarded to the native resize call.
   */
  void resizeInput(int idx, int[] dims, boolean strict) {
    if (resizeInput(interpreterHandle, errorHandle, idx, dims, strict)) {
      // Tensor allocation is deferred until either an explicit `allocateTensors()` call or
      // `invoke()`, avoiding redundant allocations if multiple tensors are simultaneously resized.
      isMemoryAllocated = false;
      if (inputTensors[idx] != null) {
        inputTensors[idx].refreshShape();
      }
    }
  }

  private static native boolean resizeInput(
      long interpreterHandle, long errorHandle, int inputIdx, int[] dims, boolean strict);

  /** Triggers explicit allocation of tensors; a no-op if tensors are already allocated. */
  void allocateTensors() {
    if (isMemoryAllocated) {
      return;
    }
    allocateTensors(interpreterHandle, errorHandle);
    // Mark as allocated only after the native call succeeds (matching run()), so a failed
    // allocation is retried on the next call instead of being silently skipped.
    isMemoryAllocated = true;
    for (int i = 0; i < outputTensors.length; ++i) {
      if (outputTensors[i] != null) {
        outputTensors[i].refreshShape();
      }
    }
  }

  private static native long allocateTensors(long interpreterHandle, long errorHandle);

  /** Forwards {@code numThreads} to the native interpreter. */
  void setNumThreads(int numThreads) {
    numThreads(interpreterHandle, numThreads);
  }

  /**
   * Applies {@code delegate} to the interpreter graph. A strong reference is kept until {@link
   * #close()}.
   */
  void modifyGraphWithDelegate(Delegate delegate) {
    applyDelegate(interpreterHandle, errorHandle, delegate.getNativeHandle());
    delegates.add(delegate);
  }

  /** Forwards to the native {@code resetVariableTensors} call. */
  void resetVariableTensors() {
    resetVariableTensors(interpreterHandle, errorHandle);
  }

  /**
   * Gets index of an input given its name.
   *
   * @throws IllegalArgumentException if {@code name} is not an input name.
   */
  int getInputIndex(String name) {
    if (inputsIndexes == null) {
      String[] names = getInputNames(interpreterHandle);
      inputsIndexes = new HashMap<>();
      if (names != null) {
        for (int i = 0; i < names.length; ++i) {
          inputsIndexes.put(names[i], i);
        }
      }
    }
    Integer index = inputsIndexes.get(name);
    if (index != null) {
      return index;
    }
    throw new IllegalArgumentException(
        String.format(
            "Input error: '%s' is not a valid name for any input. Names of inputs and their "
                + "indexes are %s",
            name, inputsIndexes));
  }

  /** Lazily initializes the maps from tensor index to input/output index. */
  private void initTensorIndexesMaps() {
    if (tensorToInputsIndexes != null) {
      return;
    }
    tensorToInputsIndexes = new HashMap<>();
    tensorToOutputsIndexes = new HashMap<>();
    int inputCount = getInputTensorCount();
    for (int i = 0; i < inputCount; ++i) {
      int tensorIndex = getInputTensorIndex(interpreterHandle, i);
      tensorToInputsIndexes.put(tensorIndex, i);
    }
    int outputCount = getOutputTensorCount();
    for (int i = 0; i < outputCount; ++i) {
      int tensorIndex = getOutputTensorIndex(interpreterHandle, i);
      tensorToOutputsIndexes.put(tensorIndex, i);
    }
  }

  /**
   * Gets index of an output given its name.
   *
   * @throws IllegalArgumentException if {@code name} is not an output name.
   */
  int getOutputIndex(String name) {
    if (outputsIndexes == null) {
      String[] names = getOutputNames(interpreterHandle);
      outputsIndexes = new HashMap<>();
      if (names != null) {
        for (int i = 0; i < names.length; ++i) {
          outputsIndexes.put(names[i], i);
        }
      }
    }
    Integer index = outputsIndexes.get(name);
    if (index != null) {
      return index;
    }
    throw new IllegalArgumentException(
        String.format(
            "Input error: '%s' is not a valid name for any output. Names of outputs and their "
                + "indexes are %s",
            name, outputsIndexes));
  }

  /**
   * Gets the last inference duration in nanoseconds. It returns null if there is no previous
   * inference run or the last inference run failed.
   */
  Long getLastNativeInferenceDurationNanoseconds() {
    return (inferenceDurationNanoseconds < 0) ? null : inferenceDurationNanoseconds;
  }

  /** Gets the number of input tensors. */
  int getInputTensorCount() {
    return inputTensors.length;
  }

  /**
   * Gets the input {@link Tensor} for the provided input index, creating its wrapper on first use.
   *
   * @throws IllegalArgumentException if the input index is invalid.
   */
  Tensor getInputTensor(int index) {
    if (index < 0 || index >= inputTensors.length) {
      throw new IllegalArgumentException("Invalid input Tensor index: " + index);
    }
    Tensor inputTensor = inputTensors[index];
    if (inputTensor == null) {
      inputTensor =
          inputTensors[index] =
              Tensor.fromIndex(interpreterHandle, getInputTensorIndex(interpreterHandle, index));
    }
    return inputTensor;
  }

  /**
   * Gets the input {@link Tensor} given the tensor name and method in the signature.
   *
   * @throws IllegalArgumentException if the input name is invalid.
   */
  Tensor getInputTensor(String inputName, String methodName) {
    if (inputName == null) {
      throw new IllegalArgumentException("Invalid input tensor name provided (null)");
    }
    initTensorIndexesMaps();
    int tensorIndex = getInputTensorIndexFromSignature(interpreterHandle, inputName, methodName);
    if (!tensorToInputsIndexes.containsKey(tensorIndex)) {
      throw new IllegalArgumentException(
          String.format(
              "Invalid input tensor name (%s) for signature (%s).", inputName, methodName));
    }
    return getInputTensor(tensorToInputsIndexes.get(tensorIndex));
  }

  /** Gets the list of SignatureDefs available in the model, if any. */
  public String[] getSignatureDefNames() {
    return getSignatureDefNames(interpreterHandle);
  }

  private static native String[] getSignatureDefNames(long interpreterHandle);

  /** Gets the list of SignatureDefs inputs for method {@code methodName}. */
  String[] getSignatureInputs(String methodName) {
    return getSignatureInputs(interpreterHandle, methodName);
  }

  private static native String[] getSignatureInputs(long interpreterHandle, String methodName);

  /** Gets the list of SignatureDefs outputs for method {@code methodName}. */
  String[] getSignatureOutputs(String methodName) {
    return getSignatureOutputs(interpreterHandle, methodName);
  }

  private static native String[] getSignatureOutputs(long interpreterHandle, String methodName);

  /** Gets the number of output tensors. */
  int getOutputTensorCount() {
    return outputTensors.length;
  }

  /**
   * Gets the output {@link Tensor} for the provided output index, creating its wrapper on first
   * use.
   *
   * @throws IllegalArgumentException if the output index is invalid.
   */
  Tensor getOutputTensor(int index) {
    if (index < 0 || index >= outputTensors.length) {
      throw new IllegalArgumentException("Invalid output Tensor index: " + index);
    }
    Tensor outputTensor = outputTensors[index];
    if (outputTensor == null) {
      outputTensor =
          outputTensors[index] =
              Tensor.fromIndex(interpreterHandle, getOutputTensorIndex(interpreterHandle, index));
    }
    return outputTensor;
  }

  /**
   * Gets the output {@link Tensor} given the tensor name and method in the signature.
   *
   * @throws IllegalArgumentException if the output name is invalid.
   */
  Tensor getOutputTensor(String outputName, String methodName) {
    if (outputName == null) {
      throw new IllegalArgumentException("Invalid output tensor name provided (null)");
    }
    initTensorIndexesMaps();
    int tensorIndex = getOutputTensorIndexFromSignature(interpreterHandle, outputName, methodName);
    if (!tensorToOutputsIndexes.containsKey(tensorIndex)) {
      throw new IllegalArgumentException(
          String.format(
              "Invalid output tensor name (%s) for signature (%s).", outputName, methodName));
    }
    return getOutputTensor(tensorToOutputsIndexes.get(tensorIndex));
  }

  /** Gets the number of ops in the execution plan. */
  int getExecutionPlanLength() {
    return getExecutionPlanLength(interpreterHandle);
  }

  /**
   * Sets internal cancellation flag. If it's true, the interpreter will try to interrupt any
   * invocation between ops.
   *
   * @throws IllegalStateException if cancellation was not enabled in the options.
   */
  void setCancelled(boolean value) {
    if (cancellationFlagHandle == 0) {
      throw new IllegalStateException(
          "Cannot cancel the inference. Have you called Interpreter.Options.setCancellable?");
    }
    setCancelled(interpreterHandle, cancellationFlagHandle, value);
  }

  private static native void setCancelled(
      long interpreterHandle, long cancellationFlagHandle, boolean value);

  /**
   * Applies delegates in order.
   *
   * <ol>
   *   <li>An auto-created flex delegate, if the graph has unresolved flex ops.
   *   <li>The user-supplied delegates.
   *   <li>An owned NNAPI delegate, if requested.
   * </ol>
   *
   * <p>Failures after a successful flex resolution are logged rather than thrown.
   */
  private void applyDelegates(Interpreter.Options options) {
    // First apply the flex delegate if necessary. This ensures the graph is fully resolved before
    // applying other delegates.
    boolean originalGraphHasUnresolvedFlexOp = hasUnresolvedFlexOp(interpreterHandle);
    if (originalGraphHasUnresolvedFlexOp) {
      Delegate optionalFlexDelegate = maybeCreateFlexDelegate(options.delegates);
      if (optionalFlexDelegate != null) {
        ownedDelegates.add((AutoCloseable) optionalFlexDelegate);
        applyDelegate(interpreterHandle, errorHandle, optionalFlexDelegate.getNativeHandle());
      }
    }
    // Now apply the user-supplied delegates.
    try {
      for (Delegate delegate : options.delegates) {
        applyDelegate(interpreterHandle, errorHandle, delegate.getNativeHandle());
        delegates.add(delegate);
      }
      if (options.useNNAPI != null && options.useNNAPI.booleanValue()) {
        NnApiDelegate optionalNnApiDelegate = new NnApiDelegate();
        ownedDelegates.add(optionalNnApiDelegate);
        applyDelegate(interpreterHandle, errorHandle, optionalNnApiDelegate.getNativeHandle());
      }
    } catch (IllegalArgumentException e) {
      // Suppress exceptions where a delegate fails to apply after the flex delegate is
      // successfully applied. This can be a common occurrence, as the flex delegate makes the
      // graph dynamic, which is typically unsupported by most delegates (e.g., NNAPI, GPU
      // delegates). We should still log an error to indicate that the delegate application was a
      // no-op.
      // TODO(b/142678372): Fix the flex delegate to not unconditionally mark graphs as dynamic.
      boolean shouldSuppressException =
          originalGraphHasUnresolvedFlexOp && !hasUnresolvedFlexOp(interpreterHandle);
      if (!shouldSuppressException) {
        throw e;
      }
      System.err.println("Ignoring failed delegate application: " + e);
    }
  }

  /**
   * Reflectively creates {@code org.tensorflow.lite.flex.FlexDelegate}.
   *
   * @return the new delegate, or {@code null} in either of these cases:
   *     <ul>
   *       <li>the class is unavailable or cannot be constructed;
   *       <li>{@code delegates} already contains a flex delegate.
   *     </ul>
   */
  private static Delegate maybeCreateFlexDelegate(List<Delegate> delegates) {
    try {
      Class<?> clazz = Class.forName("org.tensorflow.lite.flex.FlexDelegate");
      // No need to create the Flex delegate if one has already been provided.
      for (Delegate delegate : delegates) {
        if (clazz.isInstance(delegate)) {
          return null;
        }
      }
      return (Delegate) clazz.getConstructor().newInstance();
    } catch (Exception e) {
      // Broad catch is deliberate: any reflection failure just means "no flex delegate". The
      // error will propagate when tensors are allocated.
      return null;
    }
  }

  private static native int getOutputDataType(long interpreterHandle, int outputIdx);

  /** Size passed to {@link #createErrorReporter(int)}. */
  private static final int ERROR_BUFFER_SIZE = 512;

  // Native handles; 0 after close().
  private long errorHandle;
  private long interpreterHandle;
  private long modelHandle;
  // Non-zero only when cancellation was enabled in the options.
  private long cancellationFlagHandle = 0;

  // Duration of the last successful run(); -1 when there is none.
  private long inferenceDurationNanoseconds = -1;

  // Retained for the wrapper's lifetime when the model was loaded from a buffer.
  private ByteBuffer modelByteBuffer;

  // Lazily constructed maps of input and output names to input and output Tensor indexes.
  private Map<String, Integer> inputsIndexes;
  private Map<String, Integer> outputsIndexes;

  // Lazily constructed maps of tensor index to index in input and output indexes.
  private Map<Integer, Integer> tensorToInputsIndexes;
  private Map<Integer, Integer> tensorToOutputsIndexes;

  // Lazily constructed and populated arrays of input and output Tensor wrappers.
  private Tensor[] inputTensors;
  private Tensor[] outputTensors;

  // False after a resize until the next allocation.
  private boolean isMemoryAllocated = false;

  // As the Java Delegate owns the native delegate instance, we keep a strong ref to any injected
  // delegates for safety.
  private final List<Delegate> delegates = new ArrayList<>();

  // List of owned delegates that must be closed when the interpreter is closed.
  private final List<AutoCloseable> ownedDelegates = new ArrayList<>();

  private static native boolean hasUnresolvedFlexOp(long interpreterHandle);

  private static native int getInputTensorIndex(long interpreterHandle, int inputIdx);

  private static native int getInputTensorIndexFromSignature(
      long interpreterHandle, String signatureInputName, String methodName);

  private static native int getOutputTensorIndexFromSignature(
      long interpreterHandle, String signatureInputName, String methodName);

  private static native int getOutputTensorIndex(long interpreterHandle, int outputIdx);

  private static native int getInputCount(long interpreterHandle);

  private static native int getOutputCount(long interpreterHandle);

  private static native int getExecutionPlanLength(long interpreterHandle);

  private static native String[] getInputNames(long interpreterHandle);

  private static native String[] getOutputNames(long interpreterHandle);

  private static native void numThreads(long interpreterHandle, int numThreads);

  private static native void allowFp16PrecisionForFp32(long interpreterHandle, boolean allow);

  private static native void allowBufferHandleOutput(long interpreterHandle, boolean allow);

  private static native void useXNNPACK(
      long interpreterHandle, long errorHandle, int state, int numThreads);

  private static native long createErrorReporter(int size);

  private static native long createModel(String modelPathOrBuffer, long errorHandle);

  private static native long createModelWithBuffer(ByteBuffer modelBuffer, long errorHandle);

  private static native long createInterpreter(long modelHandle, long errorHandle, int numThreads);

  private static native void applyDelegate(
      long interpreterHandle, long errorHandle, long delegateHandle);

  private static native void resetVariableTensors(long interpreterHandle, long errorHandle);

  private static native long createCancellationFlag(long interpreterHandle);

  private static native long deleteCancellationFlag(long cancellationFlagHandle);

  private static native void delete(long errorHandle, long modelHandle, long interpreterHandle);
}