| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| package org.tensorflow.lite.task.vision.classifier; |
| |
| import android.content.Context; |
| import android.graphics.Rect; |
| import android.os.ParcelFileDescriptor; |
| import java.io.File; |
| import java.io.IOException; |
| import java.nio.ByteBuffer; |
| import java.nio.MappedByteBuffer; |
| import java.util.ArrayList; |
| import java.util.Collections; |
| import java.util.List; |
| import org.tensorflow.lite.DataType; |
| import org.tensorflow.lite.annotations.UsedByReflection; |
| import org.tensorflow.lite.support.image.TensorImage; |
| import org.tensorflow.lite.task.core.BaseTaskApi; |
| import org.tensorflow.lite.task.core.TaskJniUtils; |
| import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider; |
| import org.tensorflow.lite.task.core.TaskJniUtils.FdAndOptionsHandleProvider; |
| import org.tensorflow.lite.task.core.vision.ImageProcessingOptions; |
| |
| /** |
| * Performs classification on images. |
| * |
| * <p>The API expects a TFLite model with optional, but strongly recommended, <a |
| * href="https://www.tensorflow.org/lite/convert/metadata">TFLite Model Metadata.</a>. |
| * |
| * <p>The API supports models with one image input tensor and one classification output tensor. To |
| * be more specific, here are the requirements. |
| * |
| * <ul> |
| * <li>Input image tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32}) |
| * <ul> |
| * <li>image input of size {@code [batch x height x width x channels]}. |
| * <li>batch inference is not supported ({@code batch} is required to be 1). |
| * <li>only RGB inputs are supported ({@code channels} is required to be 3). |
| * <li>if type is {@code kTfLiteFloat32}, NormalizationOptions are required to be attached |
| * to the metadata for input normalization. |
| * </ul> |
| * <li>Output score tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32}) |
| * <ul> |
| * <li>with {@code N} classes of either 2 or 4 dimensions, such as {@code [1 x N]} or {@code |
| * [1 x 1 x 1 x N]} |
| * <li>the label file is required to be packed to the metadata. See the <a |
| * href="https://www.tensorflow.org/lite/convert/metadata#label_output">example of |
| * creating metadata for an image classifier</a>. If no label files are packed, it will |
| * use index as label in the result. |
| * </ul> |
| * </ul> |
| * |
| * <p>An example of such model can be found on <a |
| * href="https://tfhub.dev/bohemian-visual-recognition-alliance/lite-model/models/mushroom-identification_v1/1">TensorFlow |
| * Hub.</a>. |
| */ |
| public final class ImageClassifier extends BaseTaskApi { |
| |
| private static final String IMAGE_CLASSIFIER_NATIVE_LIB = "task_vision_jni"; |
| private static final int OPTIONAL_FD_LENGTH = -1; |
| private static final int OPTIONAL_FD_OFFSET = -1; |
| |
| /** |
| * Creates an {@link ImageClassifier} instance from the default {@link ImageClassifierOptions}. |
| * |
| * @param modelPath path of the classification model with metadata in the assets |
| * @throws IOException if an I/O error occurs when loading the tflite model |
| * @throws AssertionError if error occurs when creating {@link ImageClassifier} from the native |
| * code |
| */ |
| public static ImageClassifier createFromFile(Context context, String modelPath) |
| throws IOException { |
| return createFromFileAndOptions(context, modelPath, ImageClassifierOptions.builder().build()); |
| } |
| |
| /** |
| * Creates an {@link ImageClassifier} instance from the default {@link ImageClassifierOptions}. |
| * |
| * @param modelFile the classification model {@link File} instance |
| * @throws IOException if an I/O error occurs when loading the tflite model |
| * @throws AssertionError if error occurs when creating {@link ImageClassifier} from the native |
| * code |
| */ |
| public static ImageClassifier createFromFile(File modelFile) throws IOException { |
| return createFromFileAndOptions(modelFile, ImageClassifierOptions.builder().build()); |
| } |
| |
| /** |
| * Creates an {@link ImageClassifier} instance with a model buffer and the default {@link |
| * ImageClassifierOptions}. |
| * |
| * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the |
| * classification model |
| * @throws AssertionError if error occurs when creating {@link ImageClassifier} from the native |
| * code |
| * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a |
| * {@link MappedByteBuffer} |
| */ |
| public static ImageClassifier createFromBuffer(final ByteBuffer modelBuffer) { |
| return createFromBufferAndOptions(modelBuffer, ImageClassifierOptions.builder().build()); |
| } |
| |
| /** |
| * Creates an {@link ImageClassifier} instance from {@link ImageClassifierOptions}. |
| * |
| * @param modelPath path of the classification model with metadata in the assets |
| * @throws IOException if an I/O error occurs when loading the tflite model |
| * @throws AssertionError if error occurs when creating {@link ImageClassifier} from the native |
| * code |
| */ |
| public static ImageClassifier createFromFileAndOptions( |
| Context context, String modelPath, ImageClassifierOptions options) throws IOException { |
| return new ImageClassifier( |
| TaskJniUtils.createHandleFromFdAndOptions( |
| context, |
| new FdAndOptionsHandleProvider<ImageClassifierOptions>() { |
| @Override |
| public long createHandle( |
| int fileDescriptor, |
| long fileDescriptorLength, |
| long fileDescriptorOffset, |
| ImageClassifierOptions options) { |
| return initJniWithModelFdAndOptions( |
| fileDescriptor, fileDescriptorLength, fileDescriptorOffset, options); |
| } |
| }, |
| IMAGE_CLASSIFIER_NATIVE_LIB, |
| modelPath, |
| options)); |
| } |
| |
| /** |
| * Creates an {@link ImageClassifier} instance. |
| * |
| * @param modelFile the classification model {@link File} instance |
| * @throws IOException if an I/O error occurs when loading the tflite model |
| * @throws AssertionError if error occurs when creating {@link ImageClassifier} from the native |
| * code |
| */ |
| public static ImageClassifier createFromFileAndOptions( |
| File modelFile, final ImageClassifierOptions options) throws IOException { |
| try (ParcelFileDescriptor descriptor = |
| ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { |
| return new ImageClassifier( |
| TaskJniUtils.createHandleFromLibrary( |
| new TaskJniUtils.EmptyHandleProvider() { |
| @Override |
| public long createHandle() { |
| return initJniWithModelFdAndOptions( |
| descriptor.getFd(), |
| /*fileDescriptorLength=*/ OPTIONAL_FD_LENGTH, |
| /*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET, |
| options); |
| } |
| }, |
| IMAGE_CLASSIFIER_NATIVE_LIB)); |
| } |
| } |
| |
| /** |
| * Creates an {@link ImageClassifier} instance with a model buffer and {@link |
| * ImageClassifierOptions}. |
| * |
| * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the |
| * classification model |
| * @throws AssertionError if error occurs when creating {@link ImageClassifier} from the native |
| * code |
| * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a |
| * {@link MappedByteBuffer} |
| */ |
| public static ImageClassifier createFromBufferAndOptions( |
| final ByteBuffer modelBuffer, final ImageClassifierOptions options) { |
| if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) { |
| throw new IllegalArgumentException( |
| "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); |
| } |
| return new ImageClassifier( |
| TaskJniUtils.createHandleFromLibrary( |
| new EmptyHandleProvider() { |
| @Override |
| public long createHandle() { |
| return initJniWithByteBuffer(modelBuffer, options); |
| } |
| }, |
| IMAGE_CLASSIFIER_NATIVE_LIB)); |
| } |
| |
| /** |
| * Constructor to initialize the JNI with a pointer from C++. |
| * |
| * @param nativeHandle a pointer referencing memory allocated in C++ |
| */ |
| private ImageClassifier(long nativeHandle) { |
| super(nativeHandle); |
| } |
| |
| /** Options for setting up an ImageClassifier. */ |
| @UsedByReflection("image_classifier_jni.cc") |
| public static class ImageClassifierOptions { |
| // Not using AutoValue for this class because scoreThreshold cannot have default value |
| // (otherwise, the default value would override the one in the model metadata) and `Optional` is |
| // not an option here, because |
| // 1. java.util.Optional require Java 8 while we need to support Java 7. |
| // 2. The Guava library (com.google.common.base.Optional) is avoided in this project. See the |
| // comments for labelAllowList. |
| private final String displayNamesLocale; |
| private final int maxResults; |
| private final float scoreThreshold; |
| private final boolean isScoreThresholdSet; |
| // As an open source project, we've been trying avoiding depending on common java libraries, |
| // such as Guava, because it may introduce conflicts with clients who also happen to use those |
| // libraries. Therefore, instead of using ImmutableList here, we convert the List into |
| // unmodifiableList in setLabelAllowList() and setLabelDenyList() to make it less |
| // vulnerable. |
| private final List<String> labelAllowList; |
| private final List<String> labelDenyList; |
| private final int numThreads; |
| |
| public static Builder builder() { |
| return new Builder(); |
| } |
| |
| /** A builder that helps to configure an instance of ImageClassifierOptions. */ |
| public static class Builder { |
| private String displayNamesLocale = "en"; |
| private int maxResults = -1; |
| private float scoreThreshold; |
| private boolean isScoreThresholdSet = false; |
| private List<String> labelAllowList = new ArrayList<>(); |
| private List<String> labelDenyList = new ArrayList<>(); |
| private int numThreads = -1; |
| |
| private Builder() {} |
| |
| /** |
| * Sets the locale to use for display names specified through the TFLite Model Metadata, if |
| * any. |
| * |
| * <p>Defaults to English({@code "en"}). See the <a |
| * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite |
| * Metadata schema file.</a> for the accepted pattern of locale. |
| */ |
| public Builder setDisplayNamesLocale(String displayNamesLocale) { |
| this.displayNamesLocale = displayNamesLocale; |
| return this; |
| } |
| |
| /** |
| * Sets the maximum number of top scored results to return. |
| * |
| * <p>If < 0, all results will be returned. If 0, an invalid argument error is returned. |
| * Defaults to -1. |
| * |
| * @throws IllegalArgumentException if maxResults is 0. |
| */ |
| public Builder setMaxResults(int maxResults) { |
| if (maxResults == 0) { |
| throw new IllegalArgumentException("maxResults cannot be 0."); |
| } |
| this.maxResults = maxResults; |
| return this; |
| } |
| |
| /** |
| * Sets the score threshold in [0,1). |
| * |
| * <p>It overrides the one provided in the model metadata (if any). Results below this value |
| * are rejected. |
| */ |
| public Builder setScoreThreshold(float scoreThreshold) { |
| this.scoreThreshold = scoreThreshold; |
| isScoreThresholdSet = true; |
| return this; |
| } |
| |
| /** |
| * Sets the optional allowlist of labels. |
| * |
| * <p>If non-empty, classifications whose label is not in this set will be filtered out. |
| * Duplicate or unknown labels are ignored. Mutually exclusive with labelDenyList. |
| */ |
| public Builder setLabelAllowList(List<String> labelAllowList) { |
| this.labelAllowList = Collections.unmodifiableList(new ArrayList<>(labelAllowList)); |
| return this; |
| } |
| |
| /** |
| * Sets the optional denylist of labels. |
| * |
| * <p>If non-empty, classifications whose label is in this set will be filtered out. Duplicate |
| * or unknown labels are ignored. Mutually exclusive with labelAllowList. |
| */ |
| public Builder setLabelDenyList(List<String> labelDenyList) { |
| this.labelDenyList = Collections.unmodifiableList(new ArrayList<>(labelDenyList)); |
| return this; |
| } |
| |
| /** |
| * Sets the number of threads to be used for TFLite ops that support multi-threading when |
| * running inference with CPU. Defaults to -1. |
| * |
| * <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has the |
| * effect to let TFLite runtime set the value. |
| */ |
| public Builder setNumThreads(int numThreads) { |
| this.numThreads = numThreads; |
| return this; |
| } |
| |
| public ImageClassifierOptions build() { |
| return new ImageClassifierOptions(this); |
| } |
| } |
| |
| @UsedByReflection("image_classifier_jni.cc") |
| public String getDisplayNamesLocale() { |
| return displayNamesLocale; |
| } |
| |
| @UsedByReflection("image_classifier_jni.cc") |
| public int getMaxResults() { |
| return maxResults; |
| } |
| |
| @UsedByReflection("image_classifier_jni.cc") |
| public float getScoreThreshold() { |
| return scoreThreshold; |
| } |
| |
| @UsedByReflection("image_classifier_jni.cc") |
| public boolean getIsScoreThresholdSet() { |
| return isScoreThresholdSet; |
| } |
| |
| @UsedByReflection("image_classifier_jni.cc") |
| public List<String> getLabelAllowList() { |
| return new ArrayList<>(labelAllowList); |
| } |
| |
| @UsedByReflection("image_classifier_jni.cc") |
| public List<String> getLabelDenyList() { |
| return new ArrayList<>(labelDenyList); |
| } |
| |
| @UsedByReflection("image_classifier_jni.cc") |
| public int getNumThreads() { |
| return numThreads; |
| } |
| |
| private ImageClassifierOptions(Builder builder) { |
| displayNamesLocale = builder.displayNamesLocale; |
| maxResults = builder.maxResults; |
| scoreThreshold = builder.scoreThreshold; |
| isScoreThresholdSet = builder.isScoreThresholdSet; |
| labelAllowList = builder.labelAllowList; |
| labelDenyList = builder.labelDenyList; |
| numThreads = builder.numThreads; |
| } |
| } |
| |
| /** |
| * Performs actual classification on the provided image. |
| * |
| * @param image a {@link TensorImage} object that represents an RGB image |
| * @throws AssertionError if error occurs when classifying the image from the native code |
| */ |
| public List<Classifications> classify(TensorImage image) { |
| return classify(image, ImageProcessingOptions.builder().build()); |
| } |
| |
| /** |
| * Performs actual classification on the provided image with {@link ImageProcessingOptions}. |
| * |
| * <p>{@link ImageClassifier} supports the following options: |
| * |
| * <ul> |
| * <li>Region of interest (ROI) (through {@link ImageProcessingOptions#Builder#setRoi}). It |
| * defaults to the entire image. |
| * <li>image rotation (through {@link ImageProcessingOptions#Builder#setOrientation}). It |
| * defaults to {@link ImageProcessingOptions#Orientation#TOP_LEFT}. |
| * </ul> |
| * |
| * @param image a {@link TensorImage} object that represents an RGB image |
| * @throws AssertionError if error occurs when classifying the image from the native code |
| */ |
| public List<Classifications> classify(TensorImage image, ImageProcessingOptions options) { |
| checkNotClosed(); |
| |
| // image_classifier_jni.cc expects an uint8 image. Convert image of other types into uint8. |
| TensorImage imageUint8 = |
| image.getDataType() == DataType.UINT8 |
| ? image |
| : TensorImage.createFrom(image, DataType.UINT8); |
| |
| Rect roi = |
| options.getRoi().isEmpty() |
| ? new Rect(0, 0, imageUint8.getWidth(), imageUint8.getHeight()) |
| : options.getRoi(); |
| |
| return classifyNative( |
| getNativeHandle(), |
| imageUint8.getBuffer(), |
| imageUint8.getWidth(), |
| imageUint8.getHeight(), |
| new int[] {roi.left, roi.top, roi.width(), roi.height()}, |
| options.getOrientation().getValue()); |
| } |
| |
| private static native long initJniWithModelFdAndOptions( |
| int fileDescriptor, |
| long fileDescriptorLength, |
| long fileDescriptorOffset, |
| ImageClassifierOptions options); |
| |
| private static native long initJniWithByteBuffer( |
| ByteBuffer modelBuffer, ImageClassifierOptions options); |
| |
| /** |
| * The native method to classify an image with the ROI and orientation. |
| * |
| * @param roi the ROI of the input image, an array representing the bounding box as {left, top, |
| * width, height} |
| * @param orientation the integer value corresponding to {@link |
| * ImageProcessingOptions#Orientation} |
| */ |
| private static native List<Classifications> classifyNative( |
| long nativeHandle, ByteBuffer image, int width, int height, int[] roi, int orientation); |
| |
| @Override |
| protected void deinit(long nativeHandle) { |
| deinitJni(nativeHandle); |
| } |
| |
| /** |
| * Native implementation to release memory pointed by the pointer. |
| * |
| * @param nativeHandle pointer to memory allocated |
| */ |
| private native void deinitJni(long nativeHandle); |
| } |