blob: 016b0bfd87388b9443ebd7d7c34fc5050161e2a6 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <jni.h>
#include <memory>
#include <string>
#include "absl/strings/string_view.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h"
#include "tensorflow_lite_support/cc/task/vision/object_detector.h"
#include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h"
#include "tensorflow_lite_support/cc/task/vision/proto/detections_proto_inc.h"
#include "tensorflow_lite_support/cc/task/vision/proto/object_detector_options_proto_inc.h"
#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
#include "tensorflow_lite_support/cc/utils/jni_utils.h"
#include "tensorflow_lite_support/java/src/native/task/vision/jni_utils.h"
namespace {
using ::tflite::support::StatusOr;
using ::tflite::support::utils::GetMappedFileBuffer;
using ::tflite::support::utils::kAssertionError;
using ::tflite::support::utils::kInvalidPointer;
using ::tflite::support::utils::StringListToVector;
using ::tflite::support::utils::ThrowException;
using ::tflite::task::vision::BoundingBox;
using ::tflite::task::vision::ConvertToCategory;
using ::tflite::task::vision::ConvertToFrameBufferOrientation;
using ::tflite::task::vision::DetectionResult;
using ::tflite::task::vision::FrameBuffer;
using ::tflite::task::vision::ObjectDetector;
using ::tflite::task::vision::ObjectDetectorOptions;
// Creates an ObjectDetectorOptions proto based on the Java class.
//
// Each field is read through the corresponding getter of
// `ObjectDetector$ObjectDetectorOptions`, in this order:
// - display names locale;
// - max results;
// - score threshold (only when `getIsScoreThresholdSet()` returns true);
// - label allow list and deny list;
// - number of threads.
// The model file is not set here; the JNI entry points fill in
// `model_file_with_metadata` themselves.
//
// Local references created here are deleted before returning, so this helper
// does not consume slots of the caller's local reference table.
ObjectDetectorOptions ConvertToProtoOptions(JNIEnv* env, jobject java_options) {
  ObjectDetectorOptions proto_options;
  jclass java_options_class = env->FindClass(
      "org/tensorflow/lite/task/vision/detector/"
      "ObjectDetector$ObjectDetectorOptions");

  jmethodID display_names_locale_id = env->GetMethodID(
      java_options_class, "getDisplayNamesLocale", "()Ljava/lang/String;");
  jstring display_names_locale = static_cast<jstring>(
      env->CallObjectMethod(java_options, display_names_locale_id));
  // GetStringUTFChars must not be given a null jstring. It may also itself
  // return nullptr (e.g. out of memory), and building a std::string from
  // nullptr is undefined behavior. In either case the proto field is simply
  // left unset.
  if (display_names_locale != nullptr) {
    const char* pchars = env->GetStringUTFChars(display_names_locale, nullptr);
    if (pchars != nullptr) {
      proto_options.set_display_names_locale(pchars);
      env->ReleaseStringUTFChars(display_names_locale, pchars);
    }
    env->DeleteLocalRef(display_names_locale);
  }

  jmethodID max_results_id =
      env->GetMethodID(java_options_class, "getMaxResults", "()I");
  jint max_results = env->CallIntMethod(java_options, max_results_id);
  proto_options.set_max_results(max_results);

  jmethodID is_score_threshold_set_id =
      env->GetMethodID(java_options_class, "getIsScoreThresholdSet", "()Z");
  jboolean is_score_threshold_set =
      env->CallBooleanMethod(java_options, is_score_threshold_set_id);
  // The threshold is only forwarded when the Java side explicitly set it.
  if (is_score_threshold_set) {
    jmethodID score_threshold_id =
        env->GetMethodID(java_options_class, "getScoreThreshold", "()F");
    jfloat score_threshold =
        env->CallFloatMethod(java_options, score_threshold_id);
    proto_options.set_score_threshold(score_threshold);
  }

  jmethodID allow_list_id = env->GetMethodID(
      java_options_class, "getLabelAllowList", "()Ljava/util/List;");
  jobject allow_list = env->CallObjectMethod(java_options, allow_list_id);
  std::vector<std::string> allow_list_vector =
      StringListToVector(env, allow_list);
  for (const auto& class_name : allow_list_vector) {
    proto_options.add_class_name_whitelist(class_name);
  }
  env->DeleteLocalRef(allow_list);

  jmethodID deny_list_id = env->GetMethodID(
      java_options_class, "getLabelDenyList", "()Ljava/util/List;");
  jobject deny_list = env->CallObjectMethod(java_options, deny_list_id);
  std::vector<std::string> deny_list_vector =
      StringListToVector(env, deny_list);
  for (const auto& class_name : deny_list_vector) {
    proto_options.add_class_name_blacklist(class_name);
  }
  env->DeleteLocalRef(deny_list);

  jmethodID num_threads_id =
      env->GetMethodID(java_options_class, "getNumThreads", "()I");
  jint num_threads = env->CallIntMethod(java_options, num_threads_id);
  proto_options.set_num_threads(num_threads);

  env->DeleteLocalRef(java_options_class);
  return proto_options;
}
// Converts a native DetectionResult into a Java `java.util.ArrayList` of
// `org.tensorflow.lite.task.vision.detector.Detection` objects.
//
// For each detection:
// - its classes are converted with ConvertToCategory and collected into an
//   ArrayList;
// - its bounding box (origin + width/height) is converted into an
//   `android.graphics.RectF(left, top, right, bottom)`;
// - both are passed to the static factory `Detection.create`.
//
// Every per-iteration local reference is deleted once it has been added to
// its parent ArrayList. The list keeps its own strong reference to each
// element. Without this, the number of live local references grows linearly
// with the number of detections and categories and can overflow the JNI local
// reference table. Only the returned list remains as a live local reference.
jobject ConvertToDetectionResults(JNIEnv* env, const DetectionResult& results) {
  // jclass and factory method of Detection.
  jclass detection_class =
      env->FindClass("org/tensorflow/lite/task/vision/detector/Detection");
  jmethodID detection_create = env->GetStaticMethodID(
      detection_class, "create",
      "(Landroid/graphics/RectF;Ljava/util/List;)Lorg/tensorflow/lite/"
      "task/vision/detector/Detection;");
  // jclass, init, and add of ArrayList.
  jclass array_list_class = env->FindClass("java/util/ArrayList");
  jmethodID array_list_init =
      env->GetMethodID(array_list_class, "<init>", "(I)V");
  jmethodID array_list_add_method =
      env->GetMethodID(array_list_class, "add", "(Ljava/lang/Object;)Z");
  // jclass, init of RectF.
  jclass rectf_class = env->FindClass("android/graphics/RectF");
  jmethodID rectf_init = env->GetMethodID(rectf_class, "<init>", "(FFFF)V");

  jobject detections_list =
      env->NewObject(array_list_class, array_list_init,
                     static_cast<jint>(results.detections_size()));
  for (const auto& detection : results.detections()) {
    // Create the category list.
    jobject category_list =
        env->NewObject(array_list_class, array_list_init,
                       static_cast<jint>(detection.classes_size()));
    for (const auto& classification : detection.classes()) {
      // NOTE(review): assumes ConvertToCategory returns a fresh local
      // reference (it creates a new object per call) - confirm in jni_utils.
      jobject jcategory = ConvertToCategory(env, classification);
      env->CallBooleanMethod(category_list, array_list_add_method, jcategory);
      env->DeleteLocalRef(jcategory);
    }
    // Create the bounding box object.
    const BoundingBox& bounding_box = detection.bounding_box();
    float left = static_cast<float>(bounding_box.origin_x());
    float top = static_cast<float>(bounding_box.origin_y());
    float right = static_cast<float>(left + bounding_box.width());
    float bottom = static_cast<float>(top + bounding_box.height());
    jobject jbounding_box =
        env->NewObject(rectf_class, rectf_init, left, top, right, bottom);
    // Create the java Detection object.
    jobject jdetection = env->CallStaticObjectMethod(
        detection_class, detection_create, jbounding_box, category_list);
    env->CallBooleanMethod(detections_list, array_list_add_method, jdetection);
    // Release per-detection references; detections_list now owns jdetection,
    // which in turn references the bounding box and the category list.
    env->DeleteLocalRef(jdetection);
    env->DeleteLocalRef(jbounding_box);
    env->DeleteLocalRef(category_list);
  }

  env->DeleteLocalRef(rectf_class);
  env->DeleteLocalRef(array_list_class);
  env->DeleteLocalRef(detection_class);
  return detections_list;
}
// Builds a native ObjectDetector from `options`.
//
// On success, returns the detector's address as a jlong handle. The unique_ptr
// is released, so the handle owns the object; presumably it is later freed
// through deinitJni. On failure, throws a Java AssertionError carrying the
// status message and returns kInvalidPointer.
jlong CreateObjectDetectorFromOptions(JNIEnv* env,
                                      const ObjectDetectorOptions& options) {
  StatusOr<std::unique_ptr<ObjectDetector>> object_detector_or =
      ObjectDetector::CreateFromOptions(options);
  if (!object_detector_or.ok()) {
    // status().message() is a string_view, which is not guaranteed to be
    // NUL-terminated. Copy it before handing it to the printf-style "%s".
    const std::string message(object_detector_or.status().message());
    ThrowException(env, kAssertionError,
                   "Error occurred when initializing ObjectDetector: %s",
                   message.c_str());
    return kInvalidPointer;
  }
  return reinterpret_cast<jlong>(object_detector_or->release());
}
// Destroys the native ObjectDetector referenced by `native_handle`, the value
// returned earlier by one of the initJni* entry points.
extern "C" JNIEXPORT void JNICALL
Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_deinitJni(
    JNIEnv* env, jobject thiz, jlong native_handle) {
  ObjectDetector* const detector =
      reinterpret_cast<ObjectDetector*>(native_handle);
  delete detector;  // Deleting a null pointer is a no-op.
}
// Creates an ObjectDetector instance from the model file descriptor.
// file_descriptor_length and file_descriptor_offset are optional: non-positive
// values are ignored and the corresponding proto fields are left unset.
// Returns the native handle, or kInvalidPointer (with a pending Java
// exception) if creation fails.
extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_initJniWithModelFdAndOptions(
    JNIEnv* env, jclass thiz, jint file_descriptor,
    jlong file_descriptor_length, jlong file_descriptor_offset,
    jobject java_options) {
  ObjectDetectorOptions options = ConvertToProtoOptions(env, java_options);
  auto* fd_meta =
      options.mutable_model_file_with_metadata()->mutable_file_descriptor_meta();
  fd_meta->set_fd(file_descriptor);
  // Only strictly positive length/offset values are forwarded.
  if (file_descriptor_length > 0) fd_meta->set_length(file_descriptor_length);
  if (file_descriptor_offset > 0) fd_meta->set_offset(file_descriptor_offset);
  return CreateObjectDetectorFromOptions(env, options);
}
// Creates an ObjectDetector instance from a model held in `model_buffer`.
//
// The model bytes are obtained via GetDirectBufferAddress and
// GetDirectBufferCapacity, so `model_buffer` must be a direct ByteBuffer. For
// a non-direct buffer those calls return nullptr / -1. The original code then
// passed a null pointer with a size of (size_t)-1 to set_file_content.
// Instead, this now throws an AssertionError and returns kInvalidPointer.
// Otherwise, returns the native handle, or kInvalidPointer (with a pending
// Java exception) if creation fails.
extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_initJniWithByteBuffer(
    JNIEnv* env, jclass thiz, jobject model_buffer, jobject java_options) {
  void* const model_data = env->GetDirectBufferAddress(model_buffer);
  const jlong model_size = env->GetDirectBufferCapacity(model_buffer);
  if (model_data == nullptr || model_size < 0) {
    ThrowException(env, kAssertionError,
                   "Error occurred when initializing ObjectDetector: %s",
                   "the model buffer is not a direct ByteBuffer");
    return kInvalidPointer;
  }
  ObjectDetectorOptions proto_options =
      ConvertToProtoOptions(env, java_options);
  proto_options.mutable_model_file_with_metadata()->set_file_content(
      static_cast<char*>(model_data), static_cast<size_t>(model_size));
  return CreateObjectDetectorFromOptions(env, proto_options);
}
// Runs object detection on a packed RGB image held in `image_byte_buffer`.
//
// `native_handle` is the ObjectDetector handle returned by an initJni* entry
// point. `width`/`height` are the image dimensions in pixels. `jorientation`
// is converted with ConvertToFrameBufferOrientation. Returns a Java
// ArrayList<Detection> on success. Returns nullptr with a pending
// AssertionError if:
// - the image buffer is not a direct ByteBuffer;
// - the buffer is smaller than width * height * 3 bytes, the amount
//   CreateFromRgbRawBuffer reads for a 3-bytes-per-pixel RGB frame;
// - detection fails.
extern "C" JNIEXPORT jobject JNICALL
Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_detectNative(
    JNIEnv* env, jclass thiz, jlong native_handle, jobject image_byte_buffer,
    jint width, jint height, jint jorientation) {
  auto* detector = reinterpret_cast<ObjectDetector*>(native_handle);

  // Validate the pixel buffer before handing a raw pointer to FrameBuffer:
  // a non-direct buffer yields nullptr / -1, and a short buffer would be
  // read out of bounds. The product is computed in 64 bits to avoid overflow.
  void* const image_data = env->GetDirectBufferAddress(image_byte_buffer);
  const jlong image_capacity = env->GetDirectBufferCapacity(image_byte_buffer);
  constexpr jlong kRgbPixelBytes = 3;
  const jlong required_bytes =
      static_cast<jlong>(width) * static_cast<jlong>(height) * kRgbPixelBytes;
  if (image_data == nullptr || image_capacity < required_bytes) {
    ThrowException(env, kAssertionError,
                   "Error occurred when detecting the image: %s",
                   "the image buffer is not a direct ByteBuffer or is smaller "
                   "than width * height * 3 bytes");
    return nullptr;
  }

  std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
      static_cast<const uint8*>(image_data),
      FrameBuffer::Dimension{width, height},
      ConvertToFrameBufferOrientation(env, jorientation));
  auto results_or = detector->Detect(*frame_buffer);
  if (!results_or.ok()) {
    // status().message() is a string_view, which is not guaranteed to be
    // NUL-terminated. Copy it before handing it to the printf-style "%s".
    const std::string message(results_or.status().message());
    ThrowException(env, kAssertionError,
                   "Error occurred when detecting the image: %s",
                   message.c_str());
    return nullptr;
  }
  return ConvertToDetectionResults(env, results_or.value());
}
} // namespace