// blob: 2d52f93714e721b8b468c5e791e8454bd9933fb8 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <jni.h>
#include <memory>
#include <string>
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h"
#include "tensorflow_lite_support/cc/task/vision/image_classifier.h"
#include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h"
#include "tensorflow_lite_support/cc/task/vision/proto/classifications_proto_inc.h"
#include "tensorflow_lite_support/cc/task/vision/proto/image_classifier_options_proto_inc.h"
#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
#include "tensorflow_lite_support/cc/utils/jni_utils.h"
#include "tensorflow_lite_support/java/src/native/task/vision/jni_utils.h"
namespace {
using ::tflite::support::StatusOr;
using ::tflite::support::utils::GetMappedFileBuffer;
using ::tflite::support::utils::kAssertionError;
using ::tflite::support::utils::kInvalidPointer;
using ::tflite::support::utils::StringListToVector;
using ::tflite::support::utils::ThrowException;
using ::tflite::task::vision::BoundingBox;
using ::tflite::task::vision::ClassificationResult;
using ::tflite::task::vision::Classifications;
using ::tflite::task::vision::ConvertToCategory;
using ::tflite::task::vision::ConvertToFrameBufferOrientation;
using ::tflite::task::vision::FrameBuffer;
using ::tflite::task::vision::ImageClassifier;
using ::tflite::task::vision::ImageClassifierOptions;
// Creates an ImageClassifierOptions proto based on the Java class.
//
// Each option is read through its getter on the Java
// `ImageClassifier.ImageClassifierOptions` instance `java_options` and copied
// into the returned proto:
//   getDisplayNamesLocale -> display_names_locale (skipped when null)
//   getMaxResults         -> max_results
//   getScoreThreshold     -> score_threshold (only if getIsScoreThresholdSet)
//   getLabelAllowList     -> class_name_whitelist
//   getLabelDenyList      -> class_name_blacklist
//   getNumThreads         -> num_threads
// The model file is not set here; callers fill in model_file_with_metadata.
ImageClassifierOptions ConvertToProtoOptions(JNIEnv* env,
                                             jobject java_options) {
  ImageClassifierOptions proto_options;
  jclass java_options_class = env->FindClass(
      "org/tensorflow/lite/task/vision/classifier/"
      "ImageClassifier$ImageClassifierOptions");

  jmethodID display_names_locale_id = env->GetMethodID(
      java_options_class, "getDisplayNamesLocale", "()Ljava/lang/String;");
  jstring display_names_locale = static_cast<jstring>(
      env->CallObjectMethod(java_options, display_names_locale_id));
  // GetStringUTFChars must not be given a null jstring, and it returns nullptr
  // if the VM cannot allocate the UTF-8 copy; in either case
  // display_names_locale is simply left unset in the proto.
  if (display_names_locale != nullptr) {
    const char* pchars = env->GetStringUTFChars(display_names_locale, nullptr);
    if (pchars != nullptr) {
      proto_options.set_display_names_locale(pchars);
      env->ReleaseStringUTFChars(display_names_locale, pchars);
    }
    env->DeleteLocalRef(display_names_locale);
  }

  jmethodID max_results_id =
      env->GetMethodID(java_options_class, "getMaxResults", "()I");
  jint max_results = env->CallIntMethod(java_options, max_results_id);
  proto_options.set_max_results(max_results);

  jmethodID is_score_threshold_set_id =
      env->GetMethodID(java_options_class, "getIsScoreThresholdSet", "()Z");
  jboolean is_score_threshold_set =
      env->CallBooleanMethod(java_options, is_score_threshold_set_id);
  // Only copy the threshold when the Java side marked it as explicitly set.
  if (is_score_threshold_set) {
    jmethodID score_threshold_id =
        env->GetMethodID(java_options_class, "getScoreThreshold", "()F");
    jfloat score_threshold =
        env->CallFloatMethod(java_options, score_threshold_id);
    proto_options.set_score_threshold(score_threshold);
  }

  jmethodID allow_list_id = env->GetMethodID(
      java_options_class, "getLabelAllowList", "()Ljava/util/List;");
  jobject allow_list = env->CallObjectMethod(java_options, allow_list_id);
  auto allow_list_vector = StringListToVector(env, allow_list);
  for (const auto& class_name : allow_list_vector) {
    proto_options.add_class_name_whitelist(class_name);
  }
  env->DeleteLocalRef(allow_list);

  jmethodID deny_list_id = env->GetMethodID(
      java_options_class, "getLabelDenyList", "()Ljava/util/List;");
  jobject deny_list = env->CallObjectMethod(java_options, deny_list_id);
  auto deny_list_vector = StringListToVector(env, deny_list);
  for (const auto& class_name : deny_list_vector) {
    proto_options.add_class_name_blacklist(class_name);
  }
  env->DeleteLocalRef(deny_list);

  jmethodID num_threads_id =
      env->GetMethodID(java_options_class, "getNumThreads", "()I");
  jint num_threads = env->CallIntMethod(java_options, num_threads_id);
  proto_options.set_num_threads(num_threads);

  env->DeleteLocalRef(java_options_class);
  return proto_options;
}
// Converts a native ClassificationResult into a Java java.util.ArrayList of
// `org.tensorflow.lite.task.vision.classifier.Classifications` objects, one
// per entry of `results.classifications()`. Each element is built with the
// static factory `Classifications.create(List, int)`, passing an ArrayList of
// categories (converted by ConvertToCategory) and the head index.
//
// Returns a local reference to the new ArrayList. Intermediate local refs are
// released inside the loop so large results do not exhaust the local
// reference table.
jobject ConvertToClassificationResults(JNIEnv* env,
                                       const ClassificationResult& results) {
  // jclass and static factory of Classifications.
  jclass classifications_class = env->FindClass(
      "org/tensorflow/lite/task/vision/classifier/Classifications");
  jmethodID classifications_create =
      env->GetStaticMethodID(classifications_class, "create",
                             "(Ljava/util/List;I)Lorg/tensorflow/lite/"
                             "task/vision/classifier/Classifications;");
  // jclass, init, and add of ArrayList.
  jclass array_list_class = env->FindClass("java/util/ArrayList");
  jmethodID array_list_init =
      env->GetMethodID(array_list_class, "<init>", "(I)V");
  jmethodID array_list_add_method =
      env->GetMethodID(array_list_class, "add", "(Ljava/lang/Object;)Z");

  jobject classifications_list =
      env->NewObject(array_list_class, array_list_init,
                     static_cast<jint>(results.classifications_size()));
  // Bind by const reference to avoid copying each Classifications message
  // (and all of its categories) on every iteration.
  for (const auto& classifications : results.classifications()) {
    jobject jcategory_list =
        env->NewObject(array_list_class, array_list_init,
                       static_cast<jint>(classifications.classes_size()));
    for (const auto& classification : classifications.classes()) {
      jobject jcategory = ConvertToCategory(env, classification);
      env->CallBooleanMethod(jcategory_list, array_list_add_method, jcategory);
      env->DeleteLocalRef(jcategory);
    }
    jobject jclassifications = env->CallStaticObjectMethod(
        classifications_class, classifications_create, jcategory_list,
        static_cast<jint>(classifications.head_index()));
    env->CallBooleanMethod(classifications_list, array_list_add_method,
                           jclassifications);
    env->DeleteLocalRef(jcategory_list);
    env->DeleteLocalRef(jclassifications);
  }

  env->DeleteLocalRef(classifications_class);
  env->DeleteLocalRef(array_list_class);
  return classifications_list;
}
// Creates a native ImageClassifier from `options` and returns its address as a
// jlong handle for the Java side to hold.
//
// On success, ownership is released to the caller; deletion is handled at
// deinitJni time. On failure, ThrowException is called with kAssertionError
// and the status message, and kInvalidPointer is returned.
jlong CreateImageClassifierFromOptions(JNIEnv* env,
                                       const ImageClassifierOptions& options) {
  StatusOr<std::unique_ptr<ImageClassifier>> image_classifier_or =
      ImageClassifier::CreateFromOptions(options);
  if (!image_classifier_or.ok()) {
    // status().message() is a string_view whose data() is not guaranteed to be
    // NUL-terminated; copy it so the "%s" conversion reads a proper C string.
    const std::string message(image_classifier_or.status().message());
    ThrowException(env, kAssertionError,
                   "Error occurred when initializing ImageClassifier: %s",
                   message.c_str());
    return kInvalidPointer;
  }
  // Deletion is handled at deinitJni time.
  return reinterpret_cast<jlong>(image_classifier_or->release());
}
// JNI entry point for ImageClassifier.deinitJni.
//
// Destroys the native ImageClassifier whose address was handed to Java as
// `native_handle` by one of the initJni* entry points. `env` and `thiz` are
// not used.
extern "C" JNIEXPORT void JNICALL
Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_deinitJni(
    JNIEnv* env, jobject thiz, jlong native_handle) {
  ImageClassifier* const owned_classifier =
      reinterpret_cast<ImageClassifier*>(native_handle);
  delete owned_classifier;
}
// Creates an ImageClassifier instance from the model file descriptor.
// file_descriptor_length and file_descriptor_offset are optional: non-positive
// values are ignored and left unset in the proto.
//
// Returns the native handle produced by CreateImageClassifierFromOptions.
extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_initJniWithModelFdAndOptions(
    JNIEnv* env, jclass thiz, jint file_descriptor,
    jlong file_descriptor_length, jlong file_descriptor_offset,
    jobject java_options) {
  ImageClassifierOptions proto_options =
      ConvertToProtoOptions(env, java_options);

  auto* fd_meta = proto_options.mutable_model_file_with_metadata()
                      ->mutable_file_descriptor_meta();
  fd_meta->set_fd(file_descriptor);

  const bool has_explicit_length = file_descriptor_length > 0;
  const bool has_explicit_offset = file_descriptor_offset > 0;
  if (has_explicit_length) {
    fd_meta->set_length(file_descriptor_length);
  }
  if (has_explicit_offset) {
    fd_meta->set_offset(file_descriptor_offset);
  }

  return CreateImageClassifierFromOptions(env, proto_options);
}
// Creates an ImageClassifier instance from a model held in a direct
// java.nio.ByteBuffer.
//
// Returns the native handle produced by CreateImageClassifierFromOptions. If
// `model_buffer` is not a direct buffer, ThrowException is called with
// kAssertionError and kInvalidPointer is returned.
extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_initJniWithByteBuffer(
    JNIEnv* env, jclass thiz, jobject model_buffer, jobject java_options) {
  // For a non-direct buffer the JNI spec has GetDirectBufferAddress return
  // nullptr and GetDirectBufferCapacity return -1. Casting -1 to size_t would
  // make set_file_content read far past valid memory, so reject it up front.
  char* model_data =
      static_cast<char*>(env->GetDirectBufferAddress(model_buffer));
  const jlong model_size = env->GetDirectBufferCapacity(model_buffer);
  if (model_data == nullptr || model_size < 0) {
    ThrowException(env, kAssertionError,
                   "Error occurred when initializing ImageClassifier: %s",
                   "the model buffer must be a direct ByteBuffer.");
    return kInvalidPointer;
  }

  ImageClassifierOptions proto_options =
      ConvertToProtoOptions(env, java_options);
  // External proto generated header does not overload `set_file_content` with
  // string_view, therefore GetMappedFileBuffer does not apply here.
  // Creating a std::string will cause one extra copying of data. Thus, the
  // most efficient way here is to set file_content using char* and its size.
  proto_options.mutable_model_file_with_metadata()->set_file_content(
      model_data, static_cast<size_t>(model_size));
  return CreateImageClassifierFromOptions(env, proto_options);
}
// JNI entry point for ImageClassifier.classifyNative.
//
// Wraps the RGB pixels in `image_byte_buffer` (of size `width` x `height`,
// with orientation `jorientation`) in a FrameBuffer and classifies the region
// given by `jroi`. `jroi` is read as {origin_x, origin_y, width, height} and
// must contain at least 4 elements.
//
// Returns the Java list built by ConvertToClassificationResults on success. On
// failure (bad ROI array or a classification error), ThrowException is called
// with kAssertionError and nullptr is returned.
extern "C" JNIEXPORT jobject JNICALL
Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_classifyNative(
    JNIEnv* env, jclass thiz, jlong native_handle, jobject image_byte_buffer,
    jint width, jint height, jintArray jroi, jint jorientation) {
  auto* classifier = reinterpret_cast<ImageClassifier*>(native_handle);
  auto image = GetMappedFileBuffer(env, image_byte_buffer);
  std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
      reinterpret_cast<const uint8*>(image.data()),
      FrameBuffer::Dimension{width, height},
      ConvertToFrameBufferOrientation(env, jorientation));

  // The ROI is indexed 0..3 below, so validate the array before reading it.
  constexpr jsize kRoiLength = 4;
  if (jroi == nullptr || env->GetArrayLength(jroi) < kRoiLength) {
    ThrowException(env, kAssertionError,
                   "Error occurred when classifying the image: %s",
                   "the ROI array must contain 4 elements.");
    return nullptr;
  }
  // GetIntArrayRegion copies into caller-owned storage, so there is no
  // possibly-null element pointer to check and no Release call to pair with
  // (the original Release mode 0 also needlessly copied the data back).
  jint roi_values[kRoiLength];
  env->GetIntArrayRegion(jroi, 0, kRoiLength, roi_values);
  BoundingBox roi;
  roi.set_origin_x(roi_values[0]);
  roi.set_origin_y(roi_values[1]);
  roi.set_width(roi_values[2]);
  roi.set_height(roi_values[3]);

  auto results_or = classifier->Classify(*frame_buffer, roi);
  if (!results_or.ok()) {
    // message() is a string_view that need not be NUL-terminated; copy it so
    // the "%s" conversion reads a proper C string.
    const std::string message(results_or.status().message());
    ThrowException(env, kAssertionError,
                   "Error occurred when classifying the image: %s",
                   message.c_str());
    return nullptr;
  }
  return ConvertToClassificationResults(env, results_or.value());
}
} // namespace