| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| package org.tensorflow.lite; |
| |
| import java.util.ArrayList; |
| import java.util.Collections; |
| import java.util.List; |
| import java.util.Map; |
| import org.checkerframework.checker.nullness.qual.NonNull; |
| |
| /** |
| * Interface to TensorFlow Lite model interpreter, excluding experimental methods. |
| * |
| * <p>An {@code InterpreterApi} instance encapsulates a pre-trained TensorFlow Lite model, in which |
| * operations are executed for model inference. |
| * |
| * <p>For example, if a model takes only one input and returns only one output: |
| * |
| * <pre>{@code |
| * try (InterpreterApi interpreter = |
| * new InterpreterFactory().create(file_of_a_tensorflowlite_model)) { |
| * interpreter.run(input, output); |
| * } |
| * }</pre> |
| * |
| * <p>If a model takes multiple inputs or outputs: |
| * |
| * <pre>{@code |
| * Object[] inputs = {input0, input1, ...}; |
| * Map<Integer, Object> map_of_indices_to_outputs = new HashMap<>(); |
| * // Direct float buffer for a float tensor of shape 3x2x4 (4 bytes per float). |
| * FloatBuffer ith_output = |
| *     ByteBuffer.allocateDirect(3 * 2 * 4 * 4).order(ByteOrder.nativeOrder()).asFloatBuffer(); |
| * map_of_indices_to_outputs.put(i, ith_output); |
| * try (InterpreterApi interpreter = |
| * new InterpreterFactory().create(file_of_a_tensorflowlite_model)) { |
| * interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs); |
| * } |
| * }</pre> |
| * |
| * <p>If a model takes or produces string tensors: |
| * |
| * <pre>{@code |
| * String[] input = {"foo", "bar"}; // Input tensor shape is [2]. |
| * String[][] output = new String[3][2]; // Output tensor shape is [3, 2]. |
| * try (InterpreterApi interpreter = |
| * new InterpreterFactory().create(file_of_a_tensorflowlite_model)) { |
| * interpreter.run(input, output); |
| * } |
| * }</pre> |
| * |
| * <p>Orders of inputs and outputs are determined when converting the TensorFlow model to the |
| * TensorFlow Lite model with Toco, as are the default shapes of the inputs. |
| * |
| * <p>When inputs are provided as (multi-dimensional) arrays, the corresponding input tensor(s) will |
| * be implicitly resized according to that array's shape. When inputs are provided as {@link |
| * java.nio.Buffer} types, no implicit resizing is done; the caller must ensure that the {@link |
| * java.nio.Buffer} byte size either matches that of the corresponding tensor, or that they first |
| * resize the tensor via {@link #resizeInput(int, int[])}. Tensor shape and type information can be |
| * obtained via the {@link Tensor} class, available via {@link #getInputTensor(int)} and {@link |
| * #getOutputTensor(int)}. |
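| * |
| * <p>For example, a minimal sketch of explicitly resizing an input and sizing direct buffers |
| * from the tensors. The shape, the single float input/output, and the file placeholder are |
| * illustrative assumptions, not properties of any particular model: |
| * |
| * <pre>{@code |
| * try (InterpreterApi interpreter = |
| *     new InterpreterFactory().create(file_of_a_tensorflowlite_model)) { |
| *   interpreter.resizeInput(0, new int[] {1, 224, 224, 3}); |
| *   interpreter.allocateTensors(); |
| *   ByteBuffer input = ByteBuffer.allocateDirect(interpreter.getInputTensor(0).numBytes()); |
| *   input.order(ByteOrder.nativeOrder()); |
| *   // Populate input... |
| *   ByteBuffer output = ByteBuffer.allocateDirect(interpreter.getOutputTensor(0).numBytes()); |
| *   output.order(ByteOrder.nativeOrder()); |
| *   interpreter.run(input, output); |
| * } |
| * }</pre> |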
| * |
| * <p><b>WARNING:</b> {@code InterpreterApi} instances are <b>not</b> thread-safe. |
| * |
| * <p><b>WARNING:</b> An {@code InterpreterApi} instance owns resources that <b>must</b> be |
| * explicitly freed by invoking {@link #close()}. |
| * |
| * <p>The TFLite library is built against NDK API 19. It may work for Android API levels below 19, |
| * but this is not guaranteed. |
| * |
| * @see InterpreterFactory |
| */ |
| public interface InterpreterApi extends AutoCloseable { |
| |
| /** |
| * An options class for controlling runtime interpreter behavior. |
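| * |
| * <p>For example, a minimal sketch of a typical configuration. The thread count is an |
| * illustrative assumption, and {@code file_of_a_tensorflowlite_model}, {@code input} and |
| * {@code output} are the placeholders used in the examples above: |
| * |
| * <pre>{@code |
| * InterpreterApi.Options options = |
| *     new InterpreterApi.Options().setNumThreads(4).setUseNNAPI(true); |
| * try (InterpreterApi interpreter = |
| *     new InterpreterFactory().create(file_of_a_tensorflowlite_model, options)) { |
| *   interpreter.run(input, output); |
| * } |
| * }</pre> |
| */ |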
| public static class Options { |
| public Options() { |
| this.delegates = new ArrayList<>(); |
| } |
| |
| public Options(Options other) { |
| this.numThreads = other.numThreads; |
| this.useNNAPI = other.useNNAPI; |
| this.allowCancellation = other.allowCancellation; |
| this.delegates = new ArrayList<>(other.delegates); |
| this.runtime = other.runtime; |
| } |
| |
| /** |
| * Sets the number of threads to be used for ops that support multi-threading. |
| * |
| * <p>{@code numThreads} should be {@code >= -1}. Setting {@code numThreads} to 0 has the effect |
| * of disabling multithreading, which is equivalent to setting {@code numThreads} to 1. If |
| * unspecified, or set to the value -1, the number of threads used will be |
| * implementation-defined and platform-dependent. |
| */ |
| public Options setNumThreads(int numThreads) { |
| this.numThreads = numThreads; |
| return this; |
| } |
| |
| /** |
| * Returns the number of threads to be used for ops that support multi-threading. |
| * |
| * <p>{@code numThreads} should be {@code >= -1}. Values of 0 (or 1) disable multithreading. |
| * Default value is -1: the number of threads used will be implementation-defined and |
| * platform-dependent. |
| */ |
| public int getNumThreads() { |
| return numThreads; |
| } |
| |
| /** Sets whether to use NN API (if available) for op execution. Defaults to false (disabled). */ |
| public Options setUseNNAPI(boolean useNNAPI) { |
| this.useNNAPI = useNNAPI; |
| return this; |
| } |
| |
| /** |
| * Returns whether to use NN API (if available) for op execution. Default value is false |
| * (disabled). |
| */ |
| public boolean getUseNNAPI() { |
| return useNNAPI != null && useNNAPI; |
| } |
| |
| /** |
| * Advanced: Sets whether the interpreter can be cancelled. |
| * |
| * <p>Interpreters may have an experimental API <a |
| * href="https://www.tensorflow.org/lite/api_docs/java/org/tensorflow/lite/Interpreter#setCancelled(boolean)">setCancelled(boolean)</a>. |
| * If this interpreter is cancellable and such a method is invoked, a cancellation flag will be |
| * set to true. The interpreter will check the flag between Op invocations, and if it's {@code |
| * true}, the interpreter will stop execution. The interpreter will remain in a cancelled state |
| * until explicitly "uncancelled" by {@code setCancelled(false)}. |
| */ |
| public Options setCancellable(boolean allow) { |
| this.allowCancellation = allow; |
| return this; |
| } |
| |
| /** |
| * Advanced: Returns whether the interpreter is able to be cancelled. |
| * |
| * <p>Interpreters may have an experimental API <a |
| * href="https://www.tensorflow.org/lite/api_docs/java/org/tensorflow/lite/Interpreter#setCancelled(boolean)">setCancelled(boolean)</a>. |
| * If this interpreter is cancellable and such a method is invoked, a cancellation flag will be |
| * set to true. The interpreter will check the flag between Op invocations, and if it's {@code |
| * true}, the interpreter will stop execution. The interpreter will remain in a cancelled state |
| * until explicitly "uncancelled" by {@code setCancelled(false)}. |
| */ |
| public boolean isCancellable() { |
| return allowCancellation != null && allowCancellation; |
| } |
| |
| /** Adds a {@link Delegate} to be applied during interpreter creation. */ |
| public Options addDelegate(Delegate delegate) { |
| delegates.add(delegate); |
| return this; |
| } |
| |
| /** |
| * Returns the list of delegates intended to be applied during interpreter creation (that have |
| * been registered via {@code addDelegate}). |
| */ |
| public List<Delegate> getDelegates() { |
| return Collections.unmodifiableList(delegates); |
| } |
| |
| /** Enum to represent where to get the TensorFlow Lite runtime implementation from. */ |
| public static enum TfLiteRuntime { |
| /** |
| * Use a TF Lite runtime implementation that is linked into the application. If there is no |
| * suitable TF Lite runtime implementation linked into the application, then attempting to |
| * create an InterpreterApi instance with this TfLiteRuntime setting will throw an |
| * IllegalStateException (even if the OS or system services could provide a TF Lite |
| * runtime implementation). |
| * |
| * <p>This is the default setting. This setting is also appropriate for apps that must run on |
| * systems that don't provide a TF Lite runtime implementation. |
| */ |
| FROM_APPLICATION_ONLY, |
| |
| /** |
| * Use a TF Lite runtime implementation provided by the OS or system services. This will be |
| * obtained from a system library / shared object / service, such as Google Play Services. It |
| * may be newer than the version linked into the application (if any). If there is no suitable |
| * TF Lite runtime implementation provided by the system, then attempting to create an |
| * InterpreterApi instance with this TfLiteRuntime setting will throw an IllegalStateException |
| * (even if there is a TF Lite runtime implementation linked into the application). |
| * |
| * <p>This setting is appropriate for code that will use a system-provided TF Lite runtime, |
| * which can reduce app binary size and can be updated more frequently. |
| */ |
| FROM_SYSTEM_ONLY, |
| |
| /** |
| * Use a system-provided TF Lite runtime implementation, if any, otherwise use the TF Lite |
| * runtime implementation linked into the application, if any. If no suitable TF Lite runtime |
| * can be found in any location, then attempting to create an InterpreterApi instance with |
| * this TfLiteRuntime setting will throw an IllegalStateException. If there is both a suitable |
| * TF Lite runtime linked into the application and also a suitable TF Lite runtime provided by |
| * the system, the one provided by the system will be used. |
| * |
| * <p>This setting is suitable for use in code that doesn't care where the TF Lite runtime is |
| * coming from (e.g. middleware layers). |
| */ |
| PREFER_SYSTEM_OVER_APPLICATION, |
| }; |
| |
| /** |
| * Specifies where to get the TF Lite runtime implementation from. |
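| * |
| * <p>For example, a sketch for middleware-style code that prefers a system-provided runtime |
| * when one is available, using the enum values defined above: |
| * |
| * <pre>{@code |
| * InterpreterApi.Options options = new InterpreterApi.Options() |
| *     .setRuntime(InterpreterApi.Options.TfLiteRuntime.PREFER_SYSTEM_OVER_APPLICATION); |
| * }</pre> |
| */ |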
| public Options setRuntime(TfLiteRuntime runtime) { |
| this.runtime = runtime; |
| return this; |
| } |
| |
| TfLiteRuntime runtime = TfLiteRuntime.FROM_APPLICATION_ONLY; |
| int numThreads = -1; |
| Boolean useNNAPI; |
| Boolean allowCancellation; |
| |
| // See InterpreterApi.Options#addDelegate(Delegate). |
| final List<Delegate> delegates; |
| } |
| |
| /** |
| * Runs model inference if the model takes only one input, and provides only one output. |
| * |
| * <p>Warning: The API is more efficient if a {@code Buffer} (preferably direct, but not required) |
| * is used as the input/output data type. Please consider using {@code Buffer} to feed and fetch |
| * primitive data for better performance. The following concrete {@code Buffer} types are |
| * supported: |
| * |
| * <ul> |
| * <li>{@code ByteBuffer} - compatible with any underlying primitive Tensor type. |
| * <li>{@code FloatBuffer} - compatible with float Tensors. |
| * <li>{@code IntBuffer} - compatible with int32 Tensors. |
| * <li>{@code LongBuffer} - compatible with int64 Tensors. |
| * </ul> |
| * |
| * Note that boolean types are supported only as arrays, not as {@code Buffer}s or as scalar inputs. |
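| * |
| * <p>For example, a sketch of feeding and fetching float data with {@code FloatBuffer}s, |
| * assuming a hypothetical single-input, single-output float model and an already-created |
| * {@code interpreter}: |
| * |
| * <pre>{@code |
| * FloatBuffer input = FloatBuffer.allocate(interpreter.getInputTensor(0).numElements()); |
| * // Populate input, then reset its position so the interpreter reads from the start. |
| * input.rewind(); |
| * FloatBuffer output = FloatBuffer.allocate(interpreter.getOutputTensor(0).numElements()); |
| * interpreter.run(input, output); |
| * output.rewind(); // reset the position before reading the results |
| * }</pre> |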
| * |
| * @param input an array or multidimensional array, or a {@code Buffer} of primitive types |
| * including int, float, long, and byte. {@code Buffer} is the preferred way to pass large |
| * input data for primitive types, whereas string types require using the (multi-dimensional) |
| * array input path. When a {@code Buffer} is used, its content should remain unchanged until |
| * model inference is done, and the caller must ensure that the {@code Buffer} is at the |
| * appropriate read position. A {@code null} value is allowed only if the caller is using a |
| * {@link Delegate} that allows buffer handle interop, and such a buffer has been bound to the |
| * input {@link Tensor}. |
| * @param output a multidimensional array of output data, or a {@code Buffer} of primitive types |
| * including int, float, long, and byte. When a {@code Buffer} is used, the caller must ensure |
| * that it is set to the appropriate write position. A null value is allowed, and is useful for |
| * certain cases, e.g., if the caller is using a {@link Delegate} that allows buffer handle |
| * interop, and such a buffer has been bound to the output {@link Tensor} (see also <a |
| * href="https://www.tensorflow.org/lite/api_docs/java/org/tensorflow/lite/Interpreter.Options#setAllowBufferHandleOutput(boolean)">Interpreter.Options#setAllowBufferHandleOutput(boolean)</a>), |
| * or if the graph has dynamically shaped outputs and the caller must query the output {@link |
| * Tensor} shape after inference has been invoked, fetching the data directly from the output |
| * tensor (via {@link Tensor#asReadOnlyBuffer()}). |
| * @throws IllegalArgumentException if {@code input} is null or empty, or if an error occurs when |
| * running inference. |
| * @throws IllegalArgumentException (EXPERIMENTAL, subject to change) if the inference is |
| * interrupted by {@code setCancelled(true)}. |
| */ |
| public void run(Object input, Object output); |
| |
| /** |
| * Runs model inference if the model takes multiple inputs, or returns multiple outputs. |
| * |
| * <p>Warning: The API is more efficient if {@code Buffer}s (preferably direct, but not required) |
| * are used as the input/output data types. Please consider using {@code Buffer} to feed and fetch |
| * primitive data for better performance. The following concrete {@code Buffer} types are |
| * supported: |
| * |
| * <ul> |
| * <li>{@code ByteBuffer} - compatible with any underlying primitive Tensor type. |
| * <li>{@code FloatBuffer} - compatible with float Tensors. |
| * <li>{@code IntBuffer} - compatible with int32 Tensors. |
| * <li>{@code LongBuffer} - compatible with int64 Tensors. |
| * </ul> |
| * |
| * Note that boolean types are supported only as arrays, not as {@code Buffer}s or as scalar inputs. |
| * |
| * <p>Note: {@code null} values for individual elements of {@code inputs} and {@code outputs} are |
| * allowed only if the caller is using a {@link Delegate} that allows buffer handle interop, and |
| * such a buffer has been bound to the corresponding input or output {@link Tensor}(s). |
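| * |
| * <p>For example, a sketch of handling a dynamically shaped output by passing an empty output |
| * map and reading the result directly from the output {@link Tensor} afterwards; {@code inputs} |
| * is assumed to be populated as described below: |
| * |
| * <pre>{@code |
| * Map<Integer, Object> outputs = new HashMap<>(); // empty: fetch from the tensor afterwards |
| * interpreter.runForMultipleInputsOutputs(inputs, outputs); |
| * Tensor outputTensor = interpreter.getOutputTensor(0); |
| * int[] shape = outputTensor.shape(); // now reflects the dynamically computed shape |
| * ByteBuffer data = outputTensor.asReadOnlyBuffer(); |
| * }</pre> |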
| * |
| * @param inputs an array of input data. The inputs should be in the same order as inputs of the |
| * model. Each input can be an array or multidimensional array, or a {@code Buffer} of |
| * primitive types including int, float, long, and byte. {@code Buffer} is the preferred way |
| * to pass large input data, whereas string types require using the (multi-dimensional) array |
| * input path. When {@code Buffer} is used, its content should remain unchanged until model |
| * inference is done, and the caller must ensure that the {@code Buffer} is at the appropriate |
| * read position. |
| * @param outputs a map mapping output indices to multidimensional arrays of output data or {@code |
| * Buffer}s of primitive types including int, float, long, and byte. It only needs to keep |
| * entries for the outputs to be used. When a {@code Buffer} is used, the caller must ensure |
| * that it is set to the appropriate write position. The map may be empty for cases where either |
| * buffer handles are used for output tensor data, or cases where the outputs are dynamically |
| * shaped and the caller must query the output {@link Tensor} shape after inference has been |
| * invoked, fetching the data directly from the output tensor (via {@link |
| * Tensor#asReadOnlyBuffer()}). |
| * @throws IllegalArgumentException if {@code inputs} is null or empty, if {@code outputs} is |
| * null, or if an error occurs when running inference. |
| */ |
| public void runForMultipleInputsOutputs( |
| Object @NonNull [] inputs, @NonNull Map<Integer, Object> outputs); |
| |
| /** |
| * Explicitly updates allocations for all tensors, if necessary. |
| * |
| * <p>This will propagate shapes and memory allocations for dependent tensors using the input |
| * tensor shape(s) as given. |
| * |
| * <p>Note: This call is *purely optional*. Tensor allocation will occur automatically during |
| * execution if any input tensors have been resized. This call is most useful in determining the |
| * shapes for any output tensors before executing the graph, e.g., |
| * |
| * <pre>{@code |
| * interpreter.resizeInput(0, new int[]{1, 4, 4, 3}); |
| * interpreter.allocateTensors(); |
| * FloatBuffer input = FloatBuffer.allocate(interpreter.getInputTensor(0).numElements()); |
| * // Populate inputs... |
| * FloatBuffer output = FloatBuffer.allocate(interpreter.getOutputTensor(0).numElements()); |
| * interpreter.run(input, output); |
| * // Process outputs... |
| * }</pre> |
| * |
| * <p>Note: Some graphs have dynamically shaped outputs, in which case the output shape may not |
| * fully propagate until inference is executed. |
| * |
| * @throws IllegalStateException if the graph's tensors could not be successfully allocated. |
| */ |
| public void allocateTensors(); |
| |
| /** |
| * Resizes the idx-th input of the native model to the given dims. |
| * |
| * @throws IllegalArgumentException if {@code idx} is negative or is not smaller than the number |
| * of model inputs; or if an error occurs when resizing the idx-th input. |
| */ |
| public void resizeInput(int idx, @NonNull int[] dims); |
| |
| /** |
| * Resizes the idx-th input of the native model to the given dims. |
| * |
| * <p>When {@code strict} is true, only unknown dimensions can be resized. Unknown dimensions are |
| * indicated as {@code -1} in the array returned by {@code Tensor.shapeSignature()}. |
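| * |
| * <p>For example, a sketch assuming an input whose shape signature is {@code [-1, 4]}, i.e. the |
| * first dimension is unknown and the second is fixed: |
| * |
| * <pre>{@code |
| * interpreter.resizeInput(0, new int[] {8, 4}, true); // OK: only the unknown dimension changes |
| * interpreter.resizeInput(0, new int[] {8, 5}, true); // throws: dimension 1 is fixed |
| * }</pre> |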
| * |
| * @throws IllegalArgumentException if {@code idx} is negative or is not smaller than the number |
| * of model inputs; if an error occurs when resizing the idx-th input; or if attempting to |
| * resize a tensor with fixed dimensions when {@code strict} is true. |
| */ |
| public void resizeInput(int idx, @NonNull int[] dims, boolean strict); |
| |
| /** Gets the number of input tensors. */ |
| public int getInputTensorCount(); |
| |
| /** |
| * Gets index of an input given the op name of the input. |
| * |
| * @throws IllegalArgumentException if {@code opName} does not match any input in the model used |
| * to initialize the interpreter. |
| */ |
| public int getInputIndex(String opName); |
| |
| /** |
| * Gets the Tensor associated with the provided input index. |
| * |
| * @throws IllegalArgumentException if {@code inputIndex} is negative or is not smaller than the |
| * number of model inputs. |
| */ |
| public Tensor getInputTensor(int inputIndex); |
| |
| /** Gets the number of output Tensors. */ |
| public int getOutputTensorCount(); |
| |
| /** |
| * Gets index of an output given the op name of the output. |
| * |
| * @throws IllegalArgumentException if {@code opName} does not match any output in the model used |
| * to initialize the interpreter. |
| */ |
| public int getOutputIndex(String opName); |
| |
| /** |
| * Gets the Tensor associated with the provided output index. |
| * |
| * <p>Note: Output tensor details (e.g., shape) may not be fully populated until after inference |
| * is executed. If you need updated details *before* running inference (e.g., after resizing an |
| * input tensor, which may invalidate output tensor shapes), use {@link #allocateTensors()} to |
| * explicitly trigger allocation and shape propagation. Note that, for graphs with output shapes |
| * that are dependent on input *values*, the output shape may not be fully determined until |
| * running inference. |
| * |
| * @throws IllegalArgumentException if {@code outputIndex} is negative or is not smaller than the |
| * number of model outputs. |
| */ |
| public Tensor getOutputTensor(int outputIndex); |
| |
| /** |
| * Returns native inference timing. |
| * |
| * @throws IllegalArgumentException if the model is not initialized by the interpreter. |
| */ |
| public Long getLastNativeInferenceDurationNanoseconds(); |
| |
| /** Releases resources associated with the {@code InterpreterApi} instance. */ |
| @Override |
| public void close(); |
| } |