/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite;
import java.io.File;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Map;
import org.checkerframework.checker.nullness.qual.NonNull;
/**
* Driver class to drive model inference with TensorFlow Lite.
*
* <p>Note: If you don't need access to any of the "experimental" API features below, prefer to use
* InterpreterApi and InterpreterFactory rather than using Interpreter directly.
*
* <p>An {@code Interpreter} encapsulates a pre-trained TensorFlow Lite model, in which operations
* are executed for model inference.
*
* <p>For example, if a model takes only one input and returns only one output:
*
* <pre>{@code
* try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
* interpreter.run(input, output);
* }
* }</pre>
*
* <p>If a model takes multiple inputs or outputs:
*
* <pre>{@code
* Object[] inputs = {input0, input1, ...};
* Map<Integer, Object> map_of_indices_to_outputs = new HashMap<>();
* FloatBuffer ith_output =
*     ByteBuffer.allocateDirect(3 * 2 * 4 * 4)  // Float tensor, shape 3x2x4; 4 bytes per float.
*         .order(ByteOrder.nativeOrder())
*         .asFloatBuffer();
* map_of_indices_to_outputs.put(i, ith_output);
* try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
* interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs);
* }
* }</pre>
*
* <p>If a model takes or produces string tensors:
*
* <pre>{@code
* String[] input = {"foo", "bar"}; // Input tensor shape is [2].
* String[][] output = new String[3][2]; // Output tensor shape is [3, 2].
* try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
*   interpreter.run(input, output);
* }
* }</pre>
*
* <p>Orders of inputs and outputs are determined when converting a TensorFlow model to a
* TensorFlow Lite model with Toco, as are the default shapes of the inputs.
*
* <p>When inputs are provided as (multi-dimensional) arrays, the corresponding input tensor(s) will
* be implicitly resized according to that array's shape. When inputs are provided as {@code Buffer}
* types, no implicit resizing is done; the caller must ensure that the {@code Buffer} byte size
* either matches that of the corresponding tensor, or that they first resize the tensor via {@link
* #resizeInput(int, int[])}. Tensor shape and type information can be obtained via the {@link
* Tensor} class, available via {@link #getInputTensor(int)} and {@link #getOutputTensor(int)}.
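*
* <p>For example, a minimal sketch of feeding a direct {@code ByteBuffer} input after an explicit
* resize. The batch size, input shape, and output shape below are illustrative assumptions, not
* properties of any particular model:
*
* <pre>{@code
* try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
*   int batchSize = 4;  // Hypothetical; depends on the model.
*   interpreter.resizeInput(0, new int[] {batchSize, 224, 224, 3});
*   Tensor inputTensor = interpreter.getInputTensor(0);
*   // Allocate a direct buffer whose byte size matches the resized input tensor.
*   ByteBuffer input = ByteBuffer.allocateDirect(inputTensor.numBytes());
*   input.order(ByteOrder.nativeOrder());
*   // ... fill input ...
*   float[][] output = new float[batchSize][1000];  // Assumed output shape.
*   interpreter.run(input, output);
* }
* }</pre>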
*
* <p><b>WARNING:</b> {@code Interpreter} instances are <b>not</b> thread-safe. An {@code
* Interpreter} owns resources that <b>must</b> be explicitly freed by invoking {@link #close()}.
*
* <p>The TFLite library is built against NDK API 19. It may work for Android API levels below 19,
* but this is not guaranteed.
*/
public final class Interpreter extends InterpreterImpl implements InterpreterApi {
/** An options class for controlling runtime interpreter behavior. */
public static class Options extends InterpreterImpl.Options {
public Options() {}
public Options(InterpreterApi.Options options) {
super(options);
}
Options(InterpreterImpl.Options options) {
super(options);
}
@Override
public Options setNumThreads(int numThreads) {
super.setNumThreads(numThreads);
return this;
}
@Override
public Options setUseNNAPI(boolean useNNAPI) {
super.setUseNNAPI(useNNAPI);
return this;
}
/**
* Sets whether to allow float16 precision for FP32 calculation when possible. Defaults to false
* (disallow).
*
* @deprecated Prefer using <a
* href="https://github.com/tensorflow/tensorflow/blob/5dc7f6981fdaf74c8c5be41f393df705841fb7c5/tensorflow/lite/delegates/nnapi/java/src/main/java/org/tensorflow/lite/nnapi/NnApiDelegate.java#L127">NnApiDelegate.Options#setAllowFp16(boolean
* enable)</a>.
*/
@Deprecated
public Options setAllowFp16PrecisionForFp32(boolean allow) {
this.allowFp16PrecisionForFp32 = allow;
return this;
}
@Override
public Options addDelegate(Delegate delegate) {
super.addDelegate(delegate);
return this;
}
@Override
public Options addDelegateFactory(DelegateFactory delegateFactory) {
super.addDelegateFactory(delegateFactory);
return this;
}
/**
* Advanced: Set if buffer handle output is allowed.
*
* <p>When a {@link Delegate} supports hardware acceleration, the interpreter will make the data
* of output tensors available in the CPU-allocated tensor buffers by default. If the client can
* consume the buffer handle directly (e.g. reading output from an OpenGL texture), it can set
* this flag to true, avoiding the copy of data to the CPU buffer. The delegate documentation
* should indicate whether this is supported and how it can be used.
*
* <p>WARNING: This is an experimental interface that is subject to change.
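*
* <p>A minimal sketch, assuming a delegate instance {@code gpuDelegate} that supports buffer
* handles (whether a given delegate does is delegate-specific):
*
* <pre>{@code
* Interpreter.Options options = new Interpreter.Options()
*     .addDelegate(gpuDelegate)           // Hypothetical delegate instance.
*     .setAllowBufferHandleOutput(true);  // Consume output via the buffer handle; skip CPU copy.
* }</pre>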
*/
public Options setAllowBufferHandleOutput(boolean allow) {
this.allowBufferHandleOutput = allow;
return this;
}
@Override
public Options setCancellable(boolean allow) {
super.setCancellable(allow);
return this;
}
/**
* Experimental: Enable or disable an optimized set of CPU kernels (provided by XNNPACK).
*
* <p>Setting this flag to false disables the use of a highly optimized set of CPU kernels
* provided via the XNNPACK delegate. Currently, these kernels are restricted to a subset of
* floating point operations.
* See
* https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/delegates/xnnpack/README.md
* for more details.
*
* <p>WARNING: This is an experimental interface that is subject to change.
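*
* <p>A minimal sketch of opting out of XNNPACK (the thread count is an illustrative choice):
*
* <pre>{@code
* Interpreter.Options options = new Interpreter.Options()
*     .setNumThreads(2)
*     .setUseXNNPACK(false);
* }</pre>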
*/
public Options setUseXNNPACK(boolean useXNNPACK) {
this.useXNNPACK = useXNNPACK;
return this;
}
@Override
public Options setRuntime(InterpreterApi.Options.TfLiteRuntime runtime) {
super.setRuntime(runtime);
return this;
}
}
/**
* Initializes an {@code Interpreter}.
*
* @param modelFile a File of a pre-trained TF Lite model.
* @throws IllegalArgumentException if {@code modelFile} does not encode a valid TensorFlow Lite
* model.
*/
public Interpreter(@NonNull File modelFile) {
this(modelFile, /* options= */ null);
}
/**
* Initializes an {@code Interpreter} and specifies options for customizing interpreter behavior.
*
* @param modelFile a file of a pre-trained TF Lite model
* @param options a set of options for customizing interpreter behavior
* @throws IllegalArgumentException if {@code modelFile} does not encode a valid TensorFlow Lite
* model.
*/
public Interpreter(@NonNull File modelFile, Options options) {
this(new NativeInterpreterWrapperExperimental(modelFile.getAbsolutePath(), options));
}
/**
* Initializes an {@code Interpreter} with a {@code ByteBuffer} of a model file.
*
* <p>The ByteBuffer should not be modified after the construction of an {@code Interpreter}. The
* {@code ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps a model file, or a
* direct {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model.
*
* @throws IllegalArgumentException if {@code byteBuffer} is neither a {@code MappedByteBuffer} nor
*     a direct {@code ByteBuffer} of nativeOrder.
*/
public Interpreter(@NonNull ByteBuffer byteBuffer) {
this(byteBuffer, /* options= */ null);
}
/**
* Initializes an {@code Interpreter} with a {@code ByteBuffer} of a model file and a set of
* custom {@link Interpreter.Options}.
*
* <p>The {@code ByteBuffer} should not be modified after the construction of an {@code
* Interpreter}. The {@code ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps
* a model file, or a direct {@code ByteBuffer} of nativeOrder() that contains the bytes content
* of a model.
*
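* <p>For example, a minimal sketch of memory-mapping a model file (the path is an illustrative
* assumption):
*
* <pre>{@code
* try (FileChannel channel =
*     FileChannel.open(Paths.get("model.tflite"), StandardOpenOption.READ)) {
*   MappedByteBuffer model = channel.map(FileChannel.MapMode.READ_ONLY, 0, channel.size());
*   try (Interpreter interpreter = new Interpreter(model, new Interpreter.Options())) {
*     // ... run inference ...
*   }
* }
* }</pre>
*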
* @throws IllegalArgumentException if {@code byteBuffer} is neither a {@code MappedByteBuffer} nor
*     a direct {@code ByteBuffer} of nativeOrder.
*/
public Interpreter(@NonNull ByteBuffer byteBuffer, Options options) {
this(new NativeInterpreterWrapperExperimental(byteBuffer, options));
}
private Interpreter(NativeInterpreterWrapperExperimental wrapper) {
super(wrapper);
wrapperExperimental = wrapper;
signatureKeyList = getSignatureKeys();
}
/**
* Runs model inference based on the SignatureDef provided through {@code signatureKey}.
*
* <p>See {@link Interpreter#run(Object, Object)} for more details on the allowed input and output
* data types.
*
* <p>WARNING: This is an experimental API and subject to change.
*
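* <p>For example, a minimal sketch for a model with a signature named {@code "serving_default"}
* that takes one input {@code "x"} and produces one output {@code "y"} (all names and shapes
* here are illustrative assumptions):
*
* <pre>{@code
* Map<String, Object> inputs = new HashMap<>();
* inputs.put("x", new float[] {1.f, 2.f, 3.f});
* Map<String, Object> outputs = new HashMap<>();
* outputs.put("y", new float[3]);
* interpreter.runSignature(inputs, outputs, "serving_default");
* }</pre>
*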
* @param inputs A map from input name in the SignatureDef to an input object.
* @param outputs A map from output name in SignatureDef to output data. This may be empty if the
* caller wishes to query the {@link Tensor} data directly after inference (e.g., if the
* output shape is dynamic, or output buffer handles are used).
* @param signatureKey Signature key identifying the SignatureDef.
* @throws IllegalArgumentException if {@code inputs} is null or empty, if {@code outputs} or
* {@code signatureKey} is null, or if an error occurs when running inference.
*/
public void runSignature(
@NonNull Map<String, Object> inputs,
@NonNull Map<String, Object> outputs,
String signatureKey) {
checkNotClosed();
if (signatureKey == null && signatureKeyList.length == 1) {
signatureKey = signatureKeyList[0];
}
if (signatureKey == null) {
throw new IllegalArgumentException(
"Input error: SignatureDef signatureKey should not be null. null is only allowed if the"
+ " model has a single Signature. Available Signatures: "
+ Arrays.toString(signatureKeyList));
}
wrapper.runSignature(inputs, outputs, signatureKey);
}
/**
* Same as {@link #runSignature(Map, Map, String)} but doesn't require passing a signatureKey,
* assuming the model has one SignatureDef. If the model has more than one SignatureDef it will
* throw an exception.
*
* <p>WARNING: This is an experimental API and subject to change.
*/
public void runSignature(
@NonNull Map<String, Object> inputs, @NonNull Map<String, Object> outputs) {
checkNotClosed();
runSignature(inputs, outputs, null);
}
/**
* Gets the Tensor associated with the provided input name and signature method name.
*
* <p>WARNING: This is an experimental API and subject to change.
*
* @param inputName Input name in the signature.
* @param signatureKey Signature key identifying the SignatureDef, can be null if the model has
* one signature.
* @throws IllegalArgumentException if {@code inputName} or {@code signatureKey} is null or empty,
*     or if an invalid name is provided.
*/
public Tensor getInputTensorFromSignature(String inputName, String signatureKey) {
checkNotClosed();
if (signatureKey == null && signatureKeyList.length == 1) {
signatureKey = signatureKeyList[0];
}
if (signatureKey == null) {
throw new IllegalArgumentException(
"Input error: SignatureDef signatureKey should not be null. null is only allowed if the"
+ " model has a single Signature. Available Signatures: "
+ Arrays.toString(signatureKeyList));
}
return wrapper.getInputTensor(inputName, signatureKey);
}
/**
* Gets the list of SignatureDef exported method names available in the model.
*
* <p>WARNING: This is an experimental API and subject to change.
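*
* <p>For example, a minimal sketch of inspecting a model's signatures:
*
* <pre>{@code
* for (String key : interpreter.getSignatureKeys()) {
*   System.out.println(key + " inputs: " + Arrays.toString(interpreter.getSignatureInputs(key)));
*   System.out.println(key + " outputs: " + Arrays.toString(interpreter.getSignatureOutputs(key)));
* }
* }</pre>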
*/
public String[] getSignatureKeys() {
checkNotClosed();
return wrapper.getSignatureKeys();
}
/**
* Gets the list of SignatureDef inputs for the method {@code signatureKey}.
*
* <p>WARNING: This is an experimental API and subject to change.
*/
public String[] getSignatureInputs(String signatureKey) {
checkNotClosed();
return wrapper.getSignatureInputs(signatureKey);
}
/**
* Gets the list of SignatureDef outputs for the method {@code signatureKey}.
*
* <p>WARNING: This is an experimental API and subject to change.
*/
public String[] getSignatureOutputs(String signatureKey) {
checkNotClosed();
return wrapper.getSignatureOutputs(signatureKey);
}
/**
* Gets the Tensor associated with the provided output name in a specific signature method.
*
* <p>Note: Output tensor details (e.g., shape) may not be fully populated until after inference
* is executed. If you need updated details *before* running inference (e.g., after resizing an
* input tensor, which may invalidate output tensor shapes), use {@link #allocateTensors()} to
* explicitly trigger allocation and shape propagation. Note that, for graphs with output shapes
* that are dependent on input *values*, the output shape may not be fully determined until
* running inference.
*
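* <p>For example, a minimal sketch of querying an output shape after inference (the signature key
* and output name are illustrative assumptions):
*
* <pre>{@code
* interpreter.runSignature(inputs, outputs, "serving_default");
* Tensor outputTensor = interpreter.getOutputTensorFromSignature("y", "serving_default");
* int[] shape = outputTensor.shape();  // Fully populated once inference has run.
* }</pre>
*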
* <p>WARNING: This is an experimental API and subject to change.
*
* @param outputName Output name in the signature.
* @param signatureKey Signature key identifying the SignatureDef, can be null if the model has
* one signature.
* @throws IllegalArgumentException if {@code outputName} or {@code signatureKey} is null or
*     empty, or if an invalid name is provided.
*/
public Tensor getOutputTensorFromSignature(String outputName, String signatureKey) {
checkNotClosed();
if (signatureKey == null && signatureKeyList.length == 1) {
signatureKey = signatureKeyList[0];
}
if (signatureKey == null) {
throw new IllegalArgumentException(
"Input error: SignatureDef signatureKey should not be null. null is only allowed if the"
+ " model has a single Signature. Available Signatures: "
+ Arrays.toString(signatureKeyList));
}
return wrapper.getOutputTensor(outputName, signatureKey);
}
/**
* Advanced: Resets all variable tensors to the default value.
*
* <p>If a variable tensor doesn't have an associated buffer, it will be reset to zero.
*
* <p>WARNING: This is an experimental API and subject to change.
*/
public void resetVariableTensors() {
checkNotClosed();
wrapperExperimental.resetVariableTensors();
}
/**
* Advanced: Interrupts inference in the middle of a call to {@link Interpreter#run}.
*
* <p>A cancellation flag will be set to true when this function gets called. The interpreter will
* check the flag between Op invocations, and if it's {@code true}, the interpreter will stop
* execution. The interpreter will remain in a cancelled state until explicitly "uncancelled" by
* {@code setCancelled(false)}.
*
* <p>WARNING: This is an experimental API and subject to change.
*
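* <p>A minimal sketch of cancelling a long-running {@link #run} from another thread (the worker
* thread and the input/output objects are illustrative):
*
* <pre>{@code
* Interpreter interpreter = new Interpreter(
*     file_of_a_tensorflowlite_model, new Interpreter.Options().setCancellable(true));
* Thread worker = new Thread(() -> interpreter.run(input, output));
* worker.start();
* // ... later, from the controlling thread:
* interpreter.setCancelled(true);  // Best-effort; run() stops between op invocations.
* }</pre>
*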
* @param cancelled {@code true} to cancel inference in a best-effort way; {@code false} to
* resume.
* @throws IllegalStateException if the interpreter is not initialized with the cancellable
* option, which is by default off.
* @see Interpreter.Options#setCancellable(boolean)
*/
public void setCancelled(boolean cancelled) {
wrapper.setCancelled(cancelled);
}
NativeInterpreterWrapperExperimental wrapperExperimental;
String[] signatureKeyList;
}