| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| package org.tensorflow.lite.task.vision.segmenter; |
| |
| import android.content.Context; |
| import android.content.res.AssetFileDescriptor; |
| import android.os.ParcelFileDescriptor; |
| import com.google.auto.value.AutoValue; |
| import java.io.File; |
| import java.io.IOException; |
| import java.nio.ByteBuffer; |
| import java.nio.ByteOrder; |
| import java.nio.MappedByteBuffer; |
| import java.util.ArrayList; |
| import java.util.Arrays; |
| import java.util.List; |
| import org.tensorflow.lite.DataType; |
| import org.tensorflow.lite.support.image.TensorImage; |
| import org.tensorflow.lite.task.core.BaseTaskApi; |
| import org.tensorflow.lite.task.core.TaskJniUtils; |
| import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider; |
| import org.tensorflow.lite.task.core.vision.ImageProcessingOptions; |
| |
| /** |
| * Performs segmentation on images. |
| * |
| * <p>The API expects a TFLite model with <a |
 * href="https://www.tensorflow.org/lite/convert/metadata">TFLite Model Metadata</a>.
| * |
| * <p>The API supports models with one image input tensor and one output tensor. To be more |
| * specific, here are the requirements. |
| * |
| * <ul> |
| * <li>Input image tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32}) |
| * <ul> |
| * <li>image input of size {@code [batch x height x width x channels]}. |
| * <li>batch inference is not supported ({@code batch} is required to be 1). |
| * <li>only RGB inputs are supported ({@code channels} is required to be 3). |
| * <li>if type is {@code kTfLiteFloat32}, NormalizationOptions are required to be attached |
| * to the metadata for input normalization. |
| * </ul> |
| * <li>Output image tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32}) |
| * <ul> |
| * <li>tensor of size {@code [batch x mask_height x mask_width x num_classes]}, where {@code |
| * batch} is required to be 1, {@code mask_width} and {@code mask_height} are the |
| * dimensions of the segmentation masks produced by the model, and {@code num_classes} |
| * is the number of classes supported by the model. |
| * <li>optional (but recommended) label map(s) can be attached as AssociatedFile-s with type |
| * TENSOR_AXIS_LABELS, containing one label per line. The first such AssociatedFile (if |
| * any) is used to fill the class name, i.e. {@link ColoredLabel#getClassName} of the |
| * results. The display name, i.e. {@link ColoredLabel#getDisplayName}, is filled from |
| * the AssociatedFile (if any) whose locale matches the `display_names_locale` field of |
| * the `ImageSegmenterOptions` used at creation time ("en" by default, i.e. English). If |
| * none of these are available, only the `index` field of the results will be filled. |
| * </ul> |
| * </ul> |
| * |
| * <p>An example of such model can be found on <a |
 * href="https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/1">TensorFlow Hub</a>.
| */ |
public final class ImageSegmenter extends BaseTaskApi {

  private static final String IMAGE_SEGMENTER_NATIVE_LIB = "task_vision_jni";
  // Sentinel values passed to the native layer when the file descriptor refers to a whole
  // standalone file, so no explicit length/offset region is needed (unlike Android assets,
  // which are a byte range inside the APK and require both).
  private static final int OPTIONAL_FD_LENGTH = -1;
  private static final int OPTIONAL_FD_OFFSET = -1;

  // Output mode chosen at creation time (e.g. category mask vs. confidence masks); used in
  // segment() to decode the raw mask buffers produced by the native layer.
  private final OutputType outputType;

  /**
   * Creates an {@link ImageSegmenter} instance from the default {@link ImageSegmenterOptions}.
   *
   * @param context the Android {@link Context}, used to open the model file from the assets
   * @param modelPath path of the segmentation model with metadata in the assets
   * @throws IOException if an I/O error occurs when loading the tflite model
   * @throws AssertionError if error occurs when creating {@link ImageSegmenter} from the native
   *     code
   */
  public static ImageSegmenter createFromFile(Context context, String modelPath)
      throws IOException {
    return createFromFileAndOptions(context, modelPath, ImageSegmenterOptions.builder().build());
  }

  /**
   * Creates an {@link ImageSegmenter} instance from the default {@link ImageSegmenterOptions}.
   *
   * @param modelFile the segmentation model {@link File} instance
   * @throws IOException if an I/O error occurs when loading the tflite model
   * @throws AssertionError if error occurs when creating {@link ImageSegmenter} from the native
   *     code
   */
  public static ImageSegmenter createFromFile(File modelFile) throws IOException {
    return createFromFileAndOptions(modelFile, ImageSegmenterOptions.builder().build());
  }

  /**
   * Creates an {@link ImageSegmenter} instance with a model buffer and the default {@link
   * ImageSegmenterOptions}.
   *
   * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
   *     segmentation model
   * @throws AssertionError if error occurs when creating {@link ImageSegmenter} from the native
   *     code
   * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
   *     {@link MappedByteBuffer}
   */
  public static ImageSegmenter createFromBuffer(final ByteBuffer modelBuffer) {
    return createFromBufferAndOptions(modelBuffer, ImageSegmenterOptions.builder().build());
  }

  /**
   * Creates an {@link ImageSegmenter} instance from {@link ImageSegmenterOptions}.
   *
   * @param context the Android {@link Context}, used to open the model file from the assets
   * @param modelPath path of the segmentation model with metadata in the assets
   * @param options the {@link ImageSegmenterOptions} to configure the segmenter with
   * @throws IOException if an I/O error occurs when loading the tflite model
   * @throws AssertionError if error occurs when creating {@link ImageSegmenter} from the native
   *     code
   */
  public static ImageSegmenter createFromFileAndOptions(
      Context context, String modelPath, final ImageSegmenterOptions options) throws IOException {
    // The asset is a byte range inside the APK, hence the explicit length and start offset.
    try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(modelPath)) {
      return createFromModelFdAndOptions(
          /*fileDescriptor=*/ assetFileDescriptor.getParcelFileDescriptor().getFd(),
          /*fileDescriptorLength=*/ assetFileDescriptor.getLength(),
          /*fileDescriptorOffset=*/ assetFileDescriptor.getStartOffset(),
          options);
    }
  }

  /**
   * Creates an {@link ImageSegmenter} instance from {@link ImageSegmenterOptions}.
   *
   * @param modelFile the segmentation model {@link File} instance
   * @param options the {@link ImageSegmenterOptions} to configure the segmenter with
   * @throws IOException if an I/O error occurs when loading the tflite model
   * @throws AssertionError if error occurs when creating {@link ImageSegmenter} from the native
   *     code
   */
  public static ImageSegmenter createFromFileAndOptions(
      File modelFile, final ImageSegmenterOptions options) throws IOException {
    try (ParcelFileDescriptor descriptor =
        ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
      // A standalone file is read in full: pass the -1 sentinels so the native side
      // ignores length/offset.
      return createFromModelFdAndOptions(
          /*fileDescriptor=*/ descriptor.getFd(),
          /*fileDescriptorLength=*/ OPTIONAL_FD_LENGTH,
          /*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET,
          options);
    }
  }

  /**
   * Creates an {@link ImageSegmenter} instance with a model buffer and {@link
   * ImageSegmenterOptions}.
   *
   * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
   *     segmentation model
   * @param options the {@link ImageSegmenterOptions} to configure the segmenter with
   * @throws AssertionError if error occurs when creating {@link ImageSegmenter} from the native
   *     code
   * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
   *     {@link MappedByteBuffer}
   */
  public static ImageSegmenter createFromBufferAndOptions(
      final ByteBuffer modelBuffer, final ImageSegmenterOptions options) {
    // The native layer reads the buffer without copying, so it must be direct or memory-mapped.
    if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
      throw new IllegalArgumentException(
          "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
    }
    return new ImageSegmenter(
        TaskJniUtils.createHandleFromLibrary(
            new EmptyHandleProvider() {
              @Override
              public long createHandle() {
                return initJniWithByteBuffer(
                    modelBuffer,
                    options.getDisplayNamesLocale(),
                    options.getOutputType().getValue(),
                    options.getNumThreads());
              }
            },
            IMAGE_SEGMENTER_NATIVE_LIB),
        options.getOutputType());
  }

  /**
   * Constructor to initialize the JNI with a pointer from C++.
   *
   * @param nativeHandle a pointer referencing memory allocated in C++
   * @param outputType the output mode the native segmenter was configured with; kept so that
   *     {@link #segment} can decode the returned mask buffers accordingly
   */
  private ImageSegmenter(long nativeHandle, OutputType outputType) {
    super(nativeHandle);
    this.outputType = outputType;
  }

  /** Options for setting up an {@link ImageSegmenter}. */
  @AutoValue
  public abstract static class ImageSegmenterOptions {
    private static final String DEFAULT_DISPLAY_NAME_LOCALE = "en";
    private static final OutputType DEFAULT_OUTPUT_TYPE = OutputType.CATEGORY_MASK;
    // -1 lets the TFLite runtime pick the number of threads; see setNumThreads.
    private static final int NUM_THREADS = -1;

    public abstract String getDisplayNamesLocale();

    public abstract OutputType getOutputType();

    public abstract int getNumThreads();

    public static Builder builder() {
      return new AutoValue_ImageSegmenter_ImageSegmenterOptions.Builder()
          .setDisplayNamesLocale(DEFAULT_DISPLAY_NAME_LOCALE)
          .setOutputType(DEFAULT_OUTPUT_TYPE)
          .setNumThreads(NUM_THREADS);
    }

    /** Builder for {@link ImageSegmenterOptions}. */
    @AutoValue.Builder
    public abstract static class Builder {

      /**
       * Sets the locale to use for display names specified through the TFLite Model Metadata, if
       * any.
       *
       * <p>Defaults to English ({@code "en"}). See the <a
       * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite
       * Metadata schema file</a> for the accepted pattern of locale.
       */
      public abstract Builder setDisplayNamesLocale(String displayNamesLocale);

      /** Sets the {@link OutputType} of the segmentation results (category or confidence masks). */
      public abstract Builder setOutputType(OutputType outputType);

      /**
       * Sets the number of threads to be used for TFLite ops that support multi-threading when
       * running inference with CPU. Defaults to -1.
       *
       * <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has the
       * effect to let TFLite runtime set the value.
       */
      public abstract Builder setNumThreads(int numThreads);

      public abstract ImageSegmenterOptions build();
    }
  }

  /**
   * Performs actual segmentation on the provided image.
   *
   * @param image a {@link TensorImage} object that represents an RGB image
   * @return results of performing image segmentation. Note that at this time, a single {@link
   *     Segmentation} element is expected to be returned. The result is stored in a {@link List}
   *     for later extension to e.g. instance segmentation models, which may return one segmentation
   *     per object.
   * @throws AssertionError if error occurs when segmenting the image from the native code
   */
  public List<Segmentation> segment(TensorImage image) {
    return segment(image, ImageProcessingOptions.builder().build());
  }

  /**
   * Performs actual segmentation on the provided image with {@link ImageProcessingOptions}.
   *
   * @param image a {@link TensorImage} object that represents an RGB image
   * @param options {@link ImageSegmenter} only supports image rotation (through {@link
   *     ImageProcessingOptions#Builder#setOrientation}) currently. The orientation of an image
   *     defaults to {@link ImageProcessingOptions#Orientation#TOP_LEFT}.
   * @return results of performing image segmentation. Note that at this time, a single {@link
   *     Segmentation} element is expected to be returned. The result is stored in a {@link List}
   *     for later extension to e.g. instance segmentation models, which may return one segmentation
   *     per object.
   * @throws AssertionError if error occurs when segmenting the image from the native code
   */
  public List<Segmentation> segment(TensorImage image, ImageProcessingOptions options) {
    checkNotClosed();

    // image_segmenter_jni.cc expects an uint8 image. Convert image of other types into uint8.
    TensorImage imageUint8 =
        image.getDataType() == DataType.UINT8
            ? image
            : TensorImage.createFrom(image, DataType.UINT8);
    // Output containers filled in by the native call below.
    List<byte[]> maskByteArrays = new ArrayList<>();
    List<ColoredLabel> coloredLabels = new ArrayList<>();
    // Two-element mask shape, populated by the native layer (see segmentNative).
    int[] maskShape = new int[2];
    segmentNative(
        getNativeHandle(),
        imageUint8.getBuffer(),
        imageUint8.getWidth(),
        imageUint8.getHeight(),
        maskByteArrays,
        maskShape,
        coloredLabels,
        options.getOrientation().getValue());

    List<ByteBuffer> maskByteBuffers = new ArrayList<>();
    for (byte[] bytes : maskByteArrays) {
      ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
      // Change the byte order to little_endian, since the buffers were generated in jni.
      byteBuffer.order(ByteOrder.LITTLE_ENDIAN);
      maskByteBuffers.add(byteBuffer);
    }

    // Single-element list for now; see the javadoc note about instance segmentation.
    return Arrays.asList(
        Segmentation.create(
            outputType,
            outputType.createMasksFromBuffer(maskByteBuffers, maskShape),
            coloredLabels));
  }

  /**
   * Creates an {@link ImageSegmenter} from a model file descriptor and {@link
   * ImageSegmenterOptions}, loading the native library on first use.
   *
   * <p>{@code fileDescriptorLength} and {@code fileDescriptorOffset} may be the -1 sentinels
   * ({@link #OPTIONAL_FD_LENGTH}/{@link #OPTIONAL_FD_OFFSET}) when the descriptor refers to a
   * whole standalone file rather than a region inside an asset archive.
   */
  private static ImageSegmenter createFromModelFdAndOptions(
      final int fileDescriptor,
      final long fileDescriptorLength,
      final long fileDescriptorOffset,
      final ImageSegmenterOptions options) {
    long nativeHandle =
        TaskJniUtils.createHandleFromLibrary(
            new EmptyHandleProvider() {
              @Override
              public long createHandle() {
                return initJniWithModelFdAndOptions(
                    fileDescriptor,
                    fileDescriptorLength,
                    fileDescriptorOffset,
                    options.getDisplayNamesLocale(),
                    options.getOutputType().getValue(),
                    options.getNumThreads());
              }
            },
            IMAGE_SEGMENTER_NATIVE_LIB);
    return new ImageSegmenter(nativeHandle, options.getOutputType());
  }

  /** Native initializer from a model file descriptor; returns a C++ handle. */
  private static native long initJniWithModelFdAndOptions(
      int fileDescriptor,
      long fileDescriptorLength,
      long fileDescriptorOffset,
      String displayNamesLocale,
      int outputType,
      int numThreads);

  /** Native initializer from a direct/mapped model buffer; returns a C++ handle. */
  private static native long initJniWithByteBuffer(
      ByteBuffer modelBuffer, String displayNamesLocale, int outputType, int numThreads);

  /**
   * The native method to segment the image.
   *
   * <p>{@code maskBuffers}, {@code maskShape}, {@code coloredLabels} will be updated in the native
   * layer.
   */
  private static native void segmentNative(
      long nativeHandle,
      ByteBuffer image,
      int width,
      int height,
      List<byte[]> maskByteArrays,
      int[] maskShape,
      List<ColoredLabel> coloredLabels,
      int orientation);

  @Override
  protected void deinit(long nativeHandle) {
    deinitJni(nativeHandle);
  }

  /**
   * Native implementation to release memory pointed by the pointer.
   *
   * @param nativeHandle pointer to memory allocated
   */
  private native void deinitJni(long nativeHandle);
}