| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include <jni.h> |
| |
| #include <memory> |
| #include <string> |
| |
| #include "absl/strings/str_cat.h" |
| #include "tensorflow_lite_support/cc/port/statusor.h" |
| #include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h" |
| #include "tensorflow_lite_support/cc/task/vision/image_segmenter.h" |
| #include "tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options_proto_inc.h" |
| #include "tensorflow_lite_support/cc/task/vision/proto/segmentations_proto_inc.h" |
| #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" |
| #include "tensorflow_lite_support/cc/utils/jni_utils.h" |
| #include "tensorflow_lite_support/java/src/native/task/vision/jni_utils.h" |
| |
| namespace { |
| |
| using ::tflite::support::StatusOr; |
| using ::tflite::support::utils::CreateByteArray; |
| using ::tflite::support::utils::GetMappedFileBuffer; |
| using ::tflite::support::utils::kAssertionError; |
| using ::tflite::support::utils::kIllegalArgumentException; |
| using ::tflite::support::utils::kInvalidPointer; |
| using ::tflite::support::utils::ThrowException; |
| using ::tflite::task::vision::ConvertToFrameBufferOrientation; |
| using ::tflite::task::vision::FrameBuffer; |
| using ::tflite::task::vision::ImageSegmenter; |
| using ::tflite::task::vision::ImageSegmenterOptions; |
| using ::tflite::task::vision::Segmentation; |
| using ::tflite::task::vision::SegmentationResult; |
| |
// Bare JNI class names (no "L...;" wrapper), in the form handed to FindClass.
constexpr char kArrayListClassNameNoSig[] = "java/util/ArrayList";
constexpr char kColorClassNameNoSig[] = "android/graphics/Color";
constexpr char kColoredLabelClassNameNoSig[] =
    "org/tensorflow/lite/task/vision/segmenter/ColoredLabel";

// JNI type descriptors ("L...;" form), used when assembling method signatures.
constexpr char kObjectClassName[] = "Ljava/lang/Object;";
constexpr char kStringClassName[] = "Ljava/lang/String;";
constexpr char kColorClassName[] = "Landroid/graphics/Color;";
constexpr char kColoredLabelClassName[] =
    "Lorg/tensorflow/lite/task/vision/segmenter/ColoredLabel;";

// Integer codes for the output type argument of the init entry points.
constexpr int kOutputTypeCategoryMask = 0;
constexpr int kOutputTypeConfidenceMask = 1;
| |
| // Creates an ImageSegmenterOptions proto based on the Java class. |
| ImageSegmenterOptions ConvertToProtoOptions(JNIEnv* env, |
| jstring display_names_locale, |
| jint output_type, |
| jint num_threads) { |
| ImageSegmenterOptions proto_options; |
| |
| const char* pchars = env->GetStringUTFChars(display_names_locale, nullptr); |
| proto_options.set_display_names_locale(pchars); |
| env->ReleaseStringUTFChars(display_names_locale, pchars); |
| |
| switch (output_type) { |
| case kOutputTypeCategoryMask: |
| proto_options.set_output_type(ImageSegmenterOptions::CATEGORY_MASK); |
| break; |
| case kOutputTypeConfidenceMask: |
| proto_options.set_output_type(ImageSegmenterOptions::CONFIDENCE_MASK); |
| break; |
| default: |
| // Should never happen. |
| ThrowException(env, kIllegalArgumentException, |
| "Unsupported output type: %d", output_type); |
| } |
| |
| proto_options.set_num_threads(num_threads); |
| |
| return proto_options; |
| } |
| |
| void ConvertToSegmentationResults(JNIEnv* env, |
| const SegmentationResult& results, |
| jobject jmask_buffers, jintArray jmask_shape, |
| jobject jcolored_labels) { |
| if (results.segmentation_size() != 1) { |
| // Should never happen. |
| ThrowException( |
| env, kAssertionError, |
| "ImageSegmenter only supports one segmentation result, getting %d", |
| results.segmentation_size()); |
| } |
| |
| const Segmentation& segmentation = results.segmentation(0); |
| |
| // Get the shape from the C++ Segmentation results. |
| int shape_array[2] = {segmentation.height(), segmentation.width()}; |
| env->SetIntArrayRegion(jmask_shape, 0, 2, shape_array); |
| |
| // jclass, init, and add of ArrayList. |
| jclass array_list_class = env->FindClass(kArrayListClassNameNoSig); |
| jmethodID array_list_add_method = |
| env->GetMethodID(array_list_class, "add", |
| absl::StrCat("(", kObjectClassName, ")Z").c_str()); |
| |
| // Convert the masks into ByteBuffer list. |
| int num_pixels = segmentation.height() * segmentation.width(); |
| if (segmentation.has_category_mask()) { |
| jbyteArray byte_array = CreateByteArray( |
| env, |
| reinterpret_cast<const jbyte*>(segmentation.category_mask().data()), |
| num_pixels * sizeof(uint8)); |
| env->CallBooleanMethod(jmask_buffers, array_list_add_method, byte_array); |
| env->DeleteLocalRef(byte_array); |
| } else { |
| for (const auto& confidence_mask : |
| segmentation.confidence_masks().confidence_mask()) { |
| jbyteArray byte_array = CreateByteArray( |
| env, reinterpret_cast<const jbyte*>(confidence_mask.value().data()), |
| num_pixels * sizeof(float)); |
| env->CallBooleanMethod(jmask_buffers, array_list_add_method, byte_array); |
| env->DeleteLocalRef(byte_array); |
| } |
| } |
| |
| // Convert colored labels from the C++ object to the Java object. |
| jclass color_class = env->FindClass(kColorClassNameNoSig); |
| jmethodID color_rgb_method = |
| env->GetStaticMethodID(color_class, "rgb", "(III)I"); |
| jclass colored_label_class = env->FindClass(kColoredLabelClassNameNoSig); |
| jmethodID colored_label_create_method = env->GetStaticMethodID( |
| colored_label_class, "create", |
| absl::StrCat("(", kStringClassName, kStringClassName, "I)", |
| kColoredLabelClassName) |
| .c_str()); |
| |
| for (const auto& colored_label : segmentation.colored_labels()) { |
| jstring label = env->NewStringUTF(colored_label.class_name().c_str()); |
| jstring display_name = |
| env->NewStringUTF(colored_label.display_name().c_str()); |
| jint rgb = env->CallStaticIntMethod(color_class, color_rgb_method, |
| colored_label.r(), colored_label.g(), |
| colored_label.b()); |
| jobject jcolored_label = env->CallStaticObjectMethod( |
| colored_label_class, colored_label_create_method, label, display_name, |
| rgb); |
| env->CallBooleanMethod(jcolored_labels, array_list_add_method, |
| jcolored_label); |
| |
| env->DeleteLocalRef(label); |
| env->DeleteLocalRef(display_name); |
| env->DeleteLocalRef(jcolored_label); |
| } |
| } |
| |
| jlong CreateImageClassifierFromOptions(JNIEnv* env, |
| const ImageSegmenterOptions& options) { |
| StatusOr<std::unique_ptr<ImageSegmenter>> image_segmenter_or = |
| ImageSegmenter::CreateFromOptions(options); |
| if (image_segmenter_or.ok()) { |
| return reinterpret_cast<jlong>(image_segmenter_or->release()); |
| } else { |
| ThrowException(env, kAssertionError, |
| "Error occurred when initializing ImageSegmenter: %s", |
| image_segmenter_or.status().message().data()); |
| return kInvalidPointer; |
| } |
| } |
| |
| extern "C" JNIEXPORT void JNICALL |
| Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_deinitJni( |
| JNIEnv* env, jobject thiz, jlong native_handle) { |
| delete reinterpret_cast<ImageSegmenter*>(native_handle); |
| } |
| |
| // Creates an ImageSegmenter instance from the model file descriptor. |
| // file_descriptor_length and file_descriptor_offset are optional. Non-possitive |
| // values will be ignored. |
| extern "C" JNIEXPORT jlong JNICALL |
| Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_initJniWithModelFdAndOptions( |
| JNIEnv* env, jclass thiz, jint file_descriptor, |
| jlong file_descriptor_length, jlong file_descriptor_offset, |
| jstring display_names_locale, jint output_type, jint num_threads) { |
| ImageSegmenterOptions proto_options = ConvertToProtoOptions( |
| env, display_names_locale, output_type, num_threads); |
| auto file_descriptor_meta = proto_options.mutable_model_file_with_metadata() |
| ->mutable_file_descriptor_meta(); |
| file_descriptor_meta->set_fd(file_descriptor); |
| if (file_descriptor_length > 0) { |
| file_descriptor_meta->set_length(file_descriptor_length); |
| } |
| if (file_descriptor_offset > 0) { |
| file_descriptor_meta->set_offset(file_descriptor_offset); |
| } |
| return CreateImageClassifierFromOptions(env, proto_options); |
| } |
| |
| extern "C" JNIEXPORT jlong JNICALL |
| Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_initJniWithByteBuffer( |
| JNIEnv* env, jclass thiz, jobject model_buffer, |
| jstring display_names_locale, jint output_type, jint num_threads) { |
| ImageSegmenterOptions proto_options = ConvertToProtoOptions( |
| env, display_names_locale, output_type, num_threads); |
| proto_options.mutable_model_file_with_metadata()->set_file_content( |
| static_cast<char*>(env->GetDirectBufferAddress(model_buffer)), |
| static_cast<size_t>(env->GetDirectBufferCapacity(model_buffer))); |
| return CreateImageClassifierFromOptions(env, proto_options); |
| } |
| |
| extern "C" JNIEXPORT void JNICALL |
| Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_segmentNative( |
| JNIEnv* env, jclass thiz, jlong native_handle, jobject jimage_byte_buffer, |
| jint width, jint height, jobject jmask_buffers, jintArray jmask_shape, |
| jobject jcolored_labels, jint jorientation) { |
| auto* segmenter = reinterpret_cast<ImageSegmenter*>(native_handle); |
| absl::string_view image = GetMappedFileBuffer(env, jimage_byte_buffer); |
| std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer( |
| reinterpret_cast<const uint8*>(image.data()), |
| FrameBuffer::Dimension{width, height}, |
| ConvertToFrameBufferOrientation(env, jorientation)); |
| auto results_or = segmenter->Segment(*frame_buffer); |
| if (results_or.ok()) { |
| ConvertToSegmentationResults(env, results_or.value(), jmask_buffers, |
| jmask_shape, jcolored_labels); |
| } else { |
| ThrowException(env, kAssertionError, |
| "Error occurred when segmenting the image: %s", |
| results_or.status().message().data()); |
| } |
| } |
| |
| } // namespace |