blob: 55f14e6c6eb7c7dba374d399b03a42bd2339880d [file] [log] [blame]
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "annotator/annotator_jni_common.h"
#include "utils/java/jni-base.h"
#include "utils/java/scoped_local_ref.h"
namespace libtextclassifier3 {
namespace {
std::unordered_set<std::string> EntityTypesFromJObject(JNIEnv* env,
const jobject& jobject) {
std::unordered_set<std::string> entity_types;
jobjectArray jentity_types = reinterpret_cast<jobjectArray>(jobject);
const int size = env->GetArrayLength(jentity_types);
for (int i = 0; i < size; ++i) {
jstring jentity_type =
reinterpret_cast<jstring>(env->GetObjectArrayElement(jentity_types, i));
entity_types.insert(ToStlString(env, jentity_type));
}
return entity_types;
}
template <typename T>
T FromJavaOptionsInternal(JNIEnv* env, jobject joptions,
const std::string& class_name) {
if (!joptions) {
return {};
}
const ScopedLocalRef<jclass> options_class(env->FindClass(class_name.c_str()),
env);
if (!options_class) {
return {};
}
const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>(
env, joptions, options_class.get(), &JNIEnv::CallObjectMethod,
"getLocale", "Ljava/lang/String;");
const std::pair<bool, jobject> status_or_reference_timezone =
CallJniMethod0<jobject>(env, joptions, options_class.get(),
&JNIEnv::CallObjectMethod, "getReferenceTimezone",
"Ljava/lang/String;");
const std::pair<bool, int64> status_or_reference_time_ms_utc =
CallJniMethod0<int64>(env, joptions, options_class.get(),
&JNIEnv::CallLongMethod, "getReferenceTimeMsUtc",
"J");
const std::pair<bool, jobject> status_or_detected_text_language_tags =
CallJniMethod0<jobject>(
env, joptions, options_class.get(), &JNIEnv::CallObjectMethod,
"getDetectedTextLanguageTags", "Ljava/lang/String;");
const std::pair<bool, int> status_or_annotation_usecase =
CallJniMethod0<int>(env, joptions, options_class.get(),
&JNIEnv::CallIntMethod, "getAnnotationUsecase", "I");
if (!status_or_locales.first || !status_or_reference_timezone.first ||
!status_or_reference_time_ms_utc.first ||
!status_or_detected_text_language_tags.first ||
!status_or_annotation_usecase.first) {
return {};
}
T options;
options.locales =
ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second));
options.reference_timezone = ToStlString(
env, reinterpret_cast<jstring>(status_or_reference_timezone.second));
options.reference_time_ms_utc = status_or_reference_time_ms_utc.second;
options.detected_text_language_tags = ToStlString(
env,
reinterpret_cast<jstring>(status_or_detected_text_language_tags.second));
options.annotation_usecase =
static_cast<AnnotationUsecase>(status_or_annotation_usecase.second);
return options;
}
} // namespace
SelectionOptions FromJavaSelectionOptions(JNIEnv* env, jobject joptions) {
if (!joptions) {
return {};
}
const ScopedLocalRef<jclass> options_class(
env->FindClass(TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
"$SelectionOptions"),
env);
const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>(
env, joptions, options_class.get(), &JNIEnv::CallObjectMethod,
"getLocales", "Ljava/lang/String;");
const std::pair<bool, int> status_or_annotation_usecase =
CallJniMethod0<int>(env, joptions, options_class.get(),
&JNIEnv::CallIntMethod, "getAnnotationUsecase", "I");
if (!status_or_locales.first || !status_or_annotation_usecase.first) {
return {};
}
SelectionOptions options;
options.locales =
ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second));
options.annotation_usecase =
static_cast<AnnotationUsecase>(status_or_annotation_usecase.second);
return options;
}
ClassificationOptions FromJavaClassificationOptions(JNIEnv* env,
jobject joptions) {
return FromJavaOptionsInternal<ClassificationOptions>(
env, joptions,
TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR "$ClassificationOptions");
}
AnnotationOptions FromJavaAnnotationOptions(JNIEnv* env, jobject joptions) {
if (!joptions) return {};
const ScopedLocalRef<jclass> options_class(
env->FindClass(TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
"$AnnotationOptions"),
env);
if (!options_class) return {};
const std::pair<bool, jobject> status_or_entity_types =
CallJniMethod0<jobject>(env, joptions, options_class.get(),
&JNIEnv::CallObjectMethod, "getEntityTypes",
"[Ljava/lang/String;");
if (!status_or_entity_types.first) return {};
const std::pair<bool, bool> status_or_enable_serialized_entity_data =
CallJniMethod0<bool>(env, joptions, options_class.get(),
&JNIEnv::CallBooleanMethod,
"isSerializedEntityDataEnabled", "Z");
if (!status_or_enable_serialized_entity_data.first) return {};
AnnotationOptions annotation_options =
FromJavaOptionsInternal<AnnotationOptions>(
env, joptions,
TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR "$AnnotationOptions");
annotation_options.entity_types =
EntityTypesFromJObject(env, status_or_entity_types.second);
annotation_options.is_serialized_entity_data_enabled =
status_or_enable_serialized_entity_data.second;
return annotation_options;
}
} // namespace libtextclassifier3