| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| package org.tensorflow.lite.task.text.nlclassifier; |
| |
import android.content.Context;
import android.os.ParcelFileDescriptor;
import com.google.auto.value.AutoValue;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.MappedByteBuffer;
import java.util.List;
import java.util.Objects;
import org.tensorflow.lite.annotations.UsedByReflection;
import org.tensorflow.lite.support.label.Category;
import org.tensorflow.lite.task.core.BaseTaskApi;
import org.tensorflow.lite.task.core.TaskJniUtils;
import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider;
| |
| /** |
| * Classifier API for natural language classification tasks, categorizes string into different |
| * classes. |
| * |
| * <p>The API expects a TFLite model with the following input/output tensor: |
| * |
| * <ul> |
| * <li>Input tensor (kTfLiteString) |
| * <ul> |
| * <li>input of the model, accepts a string. |
| * </ul> |
| * <li>Output score tensor |
| * (kTfLiteUInt8/kTfLiteInt8/kTfLiteInt16/kTfLiteFloat32/kTfLiteFloat64/kTfLiteBool) |
| * <ul> |
| * <li>output scores for each class, if type is one of the Int types, dequantize it, if it |
| * is Bool type, convert the values to 0.0 and 1.0 respectively. |
| * <li>can have an optional associated file in metadata for labels, the file should be a |
| * plain text file with one label per line, the number of labels should match the number |
| * of categories the model outputs. Output label tensor: optional (kTfLiteString) - |
| * output classname for each class, should be of the same length with scores. If this |
| * tensor is not present, the API uses score indices as classnames. - will be ignored if |
| * output score tensor already has an associated label file. |
| * </ul> |
| * <li>Optional Output label tensor (kTfLiteString/kTfLiteInt32) |
| * <ul> |
| * <li>output classname for each class, should be of the same length with scores. If this |
| * tensor is not present, the API uses score indices as classnames. |
| * <li>will be ignored if output score tensor already has an associated labe file. |
| * </ul> |
| * </ul> |
| * |
| * <p>By default the API tries to find the input/output tensors with default configurations in |
| * {@link NLClassifierOptions}, with tensor name prioritized over tensor index. The option is |
| * configurable for different TFLite models. |
| */ |
| public class NLClassifier extends BaseTaskApi { |
| |
| /** Options to identify input and output tensors of the model. */ |
| @AutoValue |
| @UsedByReflection("nl_classifier_jni.cc") |
| public abstract static class NLClassifierOptions { |
| private static final int DEFAULT_INPUT_TENSOR_INDEX = 0; |
| private static final int DEFAULT_OUTPUT_SCORE_TENSOR_INDEX = 0; |
| // By default there is no output label tensor. The label file can be attached |
| // to the output score tensor metadata. |
| private static final int DEFAULT_OUTPUT_LABEL_TENSOR_INDEX = -1; |
| private static final String DEFAULT_INPUT_TENSOR_NAME = "INPUT"; |
| private static final String DEFAULT_OUTPUT_SCORE_TENSOR_NAME = "OUTPUT_SCORE"; |
| private static final String DEFAULT_OUTPUT_LABEL_TENSOR_NAME = "OUTPUT_LABEL"; |
| |
| @UsedByReflection("nl_classifier_jni.cc") |
| abstract int inputTensorIndex(); |
| |
| @UsedByReflection("nl_classifier_jni.cc") |
| abstract int outputScoreTensorIndex(); |
| |
| @UsedByReflection("nl_classifier_jni.cc") |
| abstract int outputLabelTensorIndex(); |
| |
| @UsedByReflection("nl_classifier_jni.cc") |
| abstract String inputTensorName(); |
| |
| @UsedByReflection("nl_classifier_jni.cc") |
| abstract String outputScoreTensorName(); |
| |
| @UsedByReflection("nl_classifier_jni.cc") |
| abstract String outputLabelTensorName(); |
| |
| public static Builder builder() { |
| return new AutoValue_NLClassifier_NLClassifierOptions.Builder() |
| .setInputTensorIndex(DEFAULT_INPUT_TENSOR_INDEX) |
| .setOutputScoreTensorIndex(DEFAULT_OUTPUT_SCORE_TENSOR_INDEX) |
| .setOutputLabelTensorIndex(DEFAULT_OUTPUT_LABEL_TENSOR_INDEX) |
| .setInputTensorName(DEFAULT_INPUT_TENSOR_NAME) |
| .setOutputScoreTensorName(DEFAULT_OUTPUT_SCORE_TENSOR_NAME) |
| .setOutputLabelTensorName(DEFAULT_OUTPUT_LABEL_TENSOR_NAME); |
| } |
| |
| /** Builder for {@link NLClassifierOptions}. */ |
| @AutoValue.Builder |
| public abstract static class Builder { |
| public abstract Builder setInputTensorIndex(int value); |
| |
| public abstract Builder setOutputScoreTensorIndex(int value); |
| |
| public abstract Builder setOutputLabelTensorIndex(int value); |
| |
| public abstract Builder setInputTensorName(String value); |
| |
| public abstract Builder setOutputScoreTensorName(String value); |
| |
| public abstract Builder setOutputLabelTensorName(String value); |
| |
| public abstract NLClassifierOptions build(); |
| } |
| } |
| |
| private static final String NL_CLASSIFIER_NATIVE_LIBNAME = "task_text_jni"; |
| |
| /** |
| * Constructor to initialize the JNI with a pointer from C++. |
| * |
| * @param nativeHandle a pointer referencing memory allocated in C++. |
| */ |
| protected NLClassifier(long nativeHandle) { |
| super(nativeHandle); |
| } |
| |
| /** |
| * Create {@link NLClassifier} from default {@link NLClassifierOptions}. |
| * |
| * @param context Android context. |
| * @param pathToModel Path to the classification model relative to asset dir. |
| * @return {@link NLClassifier} instance. |
| * @throws IOException If model file fails to load. |
| */ |
| public static NLClassifier createFromFile(Context context, String pathToModel) |
| throws IOException { |
| return createFromFileAndOptions(context, pathToModel, NLClassifierOptions.builder().build()); |
| } |
| |
| /** |
| * Create {@link NLClassifier} from default {@link NLClassifierOptions}. |
| * |
| * @param modelFile The classification model {@link File} instance. |
| * @return {@link NLClassifier} instance. |
| * @throws IOException If model file fails to load. |
| */ |
| public static NLClassifier createFromFile(File modelFile) throws IOException { |
| return createFromFileAndOptions(modelFile, NLClassifierOptions.builder().build()); |
| } |
| |
| /** |
| * Create {@link NLClassifier} from {@link NLClassifierOptions}. |
| * |
| * @param context Android context |
| * @param pathToModel Path to the classification model relative to asset dir. |
| * @param options Configurations for the model. |
| * @return {@link NLClassifier} instance. |
| * @throws IOException If model file fails to load. |
| */ |
| public static NLClassifier createFromFileAndOptions( |
| Context context, String pathToModel, NLClassifierOptions options) throws IOException { |
| return createFromBufferAndOptions(TaskJniUtils.loadMappedFile(context, pathToModel), options); |
| } |
| |
| /** |
| * Create {@link NLClassifier} from {@link NLClassifierOptions}. |
| * |
| * @param modelFile The classification model {@link File} instance. |
| * @param options Configurations for the model. |
| * @return {@link NLClassifier} instance. |
| * @throws IOException If model file fails to load. |
| */ |
| public static NLClassifier createFromFileAndOptions( |
| File modelFile, final NLClassifierOptions options) throws IOException { |
| try (ParcelFileDescriptor descriptor = |
| ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { |
| return new NLClassifier( |
| TaskJniUtils.createHandleFromLibrary( |
| new EmptyHandleProvider() { |
| @Override |
| public long createHandle() { |
| return initJniWithFileDescriptor(options, descriptor.getFd()); |
| } |
| }, |
| NL_CLASSIFIER_NATIVE_LIBNAME)); |
| } |
| } |
| |
| /** |
| * Create {@link NLClassifier} with a model {@link ByteBuffer} and {@link NLClassifierOptions}. |
| * |
| * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the |
| * classification model |
| * @param options Configurations for the model |
| * @return {@link NLClassifier} instance |
| * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a |
| * {@link MappedByteBuffer} |
| */ |
| public static NLClassifier createFromBufferAndOptions( |
| final ByteBuffer modelBuffer, final NLClassifierOptions options) { |
| if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) { |
| throw new IllegalArgumentException( |
| "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); |
| } |
| return new NLClassifier( |
| TaskJniUtils.createHandleFromLibrary( |
| new EmptyHandleProvider() { |
| @Override |
| public long createHandle() { |
| return initJniWithByteBuffer(options, modelBuffer); |
| } |
| }, |
| NL_CLASSIFIER_NATIVE_LIBNAME)); |
| } |
| |
| /** |
| * Perform classification on a string input, returns classified {@link Category}s. |
| * |
| * @param text input text to the model. |
| * @return A list of Category results. |
| */ |
| public List<Category> classify(String text) { |
| return classifyNative(getNativeHandle(), text); |
| } |
| |
| private static native long initJniWithByteBuffer( |
| NLClassifierOptions options, ByteBuffer modelBuffer); |
| |
| private static native long initJniWithFileDescriptor(NLClassifierOptions options, int fd); |
| |
| private static native List<Category> classifyNative(long nativeHandle, String text); |
| |
| @Override |
| protected void deinit(long nativeHandle) { |
| deinitJni(nativeHandle); |
| } |
| |
| /** |
| * Native implementation to release memory pointed by the pointer. |
| * |
| * @param nativeHandle pointer to memory allocated |
| */ |
| private native void deinitJni(long nativeHandle); |
| } |