// blob: 17571c3a9311080deaa45df74d433abf7ee67dbc [file] [log] [blame]
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// JNI wrapper for actions.
#include "actions/actions_jni.h"
#include <jni.h>
#include <type_traits>
#include <vector>
#include "actions/actions-suggestions.h"
#include "annotator/annotator.h"
#include "annotator/annotator_jni_common.h"
#include "utils/base/integral_types.h"
#include "utils/java/scoped_local_ref.h"
#include "utils/memory/mmap.h"
using libtextclassifier3::ActionsSuggestions;
using libtextclassifier3::ActionsSuggestionsResponse;
using libtextclassifier3::ActionSuggestion;
using libtextclassifier3::ActionSuggestionOptions;
using libtextclassifier3::Annotator;
using libtextclassifier3::Conversation;
using libtextclassifier3::ScopedLocalRef;
using libtextclassifier3::ToStlString;
namespace libtextclassifier3 {
namespace {
// Converts a Java ActionSuggestionOptions object into the native options.
//
// The result starts out as ActionSuggestionOptions::Default(). Those defaults
// are returned untouched when `joptions` is null, when the Java options class
// cannot be found, or when the getAnnotationOptions() call reports failure.
// Otherwise the annotation options are converted with
// FromJavaAnnotationOptions() and stored in the result.
ActionSuggestionOptions FromJavaActionSuggestionOptions(JNIEnv* env,
                                                        jobject joptions) {
  ActionSuggestionOptions native_options = ActionSuggestionOptions::Default();
  if (joptions != nullptr) {
    const ScopedLocalRef<jclass> joptions_class(
        env->FindClass(TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR
                       "$ActionSuggestionOptions"),
        env);
    if (joptions_class) {
      // Pair of (call succeeded, returned Java AnnotationOptions object).
      const std::pair<bool, jobject> annotation_options_call =
          CallJniMethod0<jobject>(
              env, joptions, joptions_class.get(), &JNIEnv::CallObjectMethod,
              "getAnnotationOptions",
              "L" TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
              "$AnnotationOptions;");
      if (annotation_options_call.first) {
        native_options.annotation_options =
            FromJavaAnnotationOptions(env, annotation_options_call.second);
      }
    }
  }
  return native_options;
}
// Builds a Java ActionSuggestion[] from the native suggestions.
//
// Each element is created through the Java constructor with signature
// (String, String, float), fed with the suggestion's response text, type and
// score. Returns nullptr when the Java class or that constructor cannot be
// resolved, or when the array itself cannot be created.
jobjectArray ActionSuggestionsToJObjectArray(
    JNIEnv* env, const std::vector<ActionSuggestion>& action_result) {
  const ScopedLocalRef<jclass> result_class(
      env->FindClass(TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR
                     "$ActionSuggestion"),
      env);
  if (!result_class) {
    TC3_LOG(ERROR) << "Couldn't find ActionSuggestion class.";
    return nullptr;
  }

  const jmethodID result_class_constructor = env->GetMethodID(
      result_class.get(), "<init>", "(Ljava/lang/String;Ljava/lang/String;F)V");
  if (result_class_constructor == nullptr) {
    // A null method ID must not be handed to NewObject.
    TC3_LOG(ERROR) << "Couldn't find ActionSuggestion constructor.";
    return nullptr;
  }

  const int num_results = static_cast<int>(action_result.size());
  const jobjectArray results =
      env->NewObjectArray(num_results, result_class.get(), nullptr);
  if (results == nullptr) {
    TC3_LOG(ERROR) << "Couldn't create ActionSuggestion array.";
    return nullptr;
  }

  for (int i = 0; i < num_results; i++) {
    const ActionSuggestion& suggestion = action_result[i];
    // The per-element temporaries are held in scoped refs so that the number
    // of live local references does not grow with the number of suggestions.
    // NOTE(review): `env` is passed to every ScopedLocalRef here, matching
    // the other uses in this file; presumably a ref constructed without it
    // cannot release its object -- confirm against scoped_local_ref.h.
    const ScopedLocalRef<jobject> response_text(
        env->NewStringUTF(suggestion.response_text.c_str()), env);
    const ScopedLocalRef<jobject> type(
        env->NewStringUTF(suggestion.type.c_str()), env);
    const ScopedLocalRef<jobject> result(
        env->NewObject(result_class.get(), result_class_constructor,
                       response_text.get(), type.get(),
                       static_cast<jfloat>(suggestion.score)),
        env);
    env->SetObjectArrayElement(results, i, result.get());
  }
  return results;
}
// Converts a Java ConversationMessage object into the native struct.
//
// Reads getText(), getUserId(), getTimeDiffInSeconds() and getLocales() from
// `jmessage`. Returns a default-constructed ConversationMessage when
// `jmessage` is null, when the Java class cannot be found, or when any of the
// four getter calls reports failure.
ConversationMessage FromJavaConversationMessage(JNIEnv* env, jobject jmessage) {
  if (!jmessage) {
    return {};
  }
  const ScopedLocalRef<jclass> message_class(
      env->FindClass(TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR
                     "$ConversationMessage"),
      env);
  if (!message_class) {
    // Same guard as the other class lookups in this file: never pass a null
    // class on to the method calls below.
    return {};
  }
  const std::pair<bool, jobject> status_or_text = CallJniMethod0<jobject>(
      env, jmessage, message_class.get(), &JNIEnv::CallObjectMethod, "getText",
      "Ljava/lang/String;");
  const std::pair<bool, int32> status_or_user_id =
      CallJniMethod0<int32>(env, jmessage, message_class.get(),
                            &JNIEnv::CallIntMethod, "getUserId", "I");
  const std::pair<bool, int32> status_or_time_diff = CallJniMethod0<int32>(
      env, jmessage, message_class.get(), &JNIEnv::CallIntMethod,
      "getTimeDiffInSeconds", "I");
  const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>(
      env, jmessage, message_class.get(), &JNIEnv::CallObjectMethod,
      "getLocales", "Ljava/lang/String;");
  if (!status_or_text.first || !status_or_user_id.first ||
      !status_or_locales.first || !status_or_time_diff.first) {
    return {};
  }

  // This function runs once per message of a conversation, so the String
  // references returned by the getters are released when it returns instead
  // of piling up until the enclosing native call finishes.
  const ScopedLocalRef<jstring> jtext(
      static_cast<jstring>(status_or_text.second), env);
  const ScopedLocalRef<jstring> jlocales(
      static_cast<jstring>(status_or_locales.second), env);

  ConversationMessage message;
  message.text = ToStlString(env, jtext.get());
  message.user_id = status_or_user_id.second;
  message.time_diff_secs = status_or_time_diff.second;
  message.locales = ToStlString(env, jlocales.get());
  return message;
}
// Converts a Java Conversation object into the native struct.
//
// Calls getConversationMessages() on `jconversation` and converts every
// array element with FromJavaConversationMessage(). Returns a
// default-constructed Conversation when `jconversation` is null, when the
// Java class cannot be found, when the getter call reports failure, or when
// the getter returns a null array.
Conversation FromJavaConversation(JNIEnv* env, jobject jconversation) {
  if (!jconversation) {
    return {};
  }
  const ScopedLocalRef<jclass> conversation_class(
      env->FindClass(TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR
                     "$Conversation"),
      env);
  if (!conversation_class) {
    // Same guard as the other class lookups in this file.
    return {};
  }
  const std::pair<bool, jobject> status_or_messages = CallJniMethod0<jobject>(
      env, jconversation, conversation_class.get(), &JNIEnv::CallObjectMethod,
      "getConversationMessages",
      "[L" TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR "$ConversationMessage;");
  if (!status_or_messages.first) {
    return {};
  }
  const ScopedLocalRef<jobjectArray> jmessages(
      static_cast<jobjectArray>(status_or_messages.second), env);
  if (!jmessages) {
    // A null array must not be passed to GetArrayLength.
    return {};
  }

  const int size = env->GetArrayLength(jmessages.get());
  Conversation conversation;
  conversation.messages.reserve(size);
  for (int i = 0; i < size; i++) {
    // Release each element reference at the end of its iteration so long
    // conversations do not accumulate local references.
    const ScopedLocalRef<jobject> jmessage(
        env->GetObjectArrayElement(jmessages.get(), i), env);
    conversation.messages.push_back(
        FromJavaConversationMessage(env, jmessage.get()));
  }
  return conversation;
}
// Returns the locales string of the actions model viewed through `mmap` as a
// new Java string. Yields an empty Java string when the mmap handle is not
// ok, when ViewActionsModel() returns null, or when the model has no locales
// field.
jstring GetLocalesFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
  const char* locales = "";
  const auto& handle = mmap->handle();
  if (handle.ok()) {
    const ActionsModel* const actions_model =
        libtextclassifier3::ViewActionsModel(handle.start(),
                                             handle.num_bytes());
    if (actions_model != nullptr && actions_model->locales() != nullptr) {
      locales = actions_model->locales()->c_str();
    }
  }
  return env->NewStringUTF(locales);
}
// Returns the version of the actions model viewed through `mmap`, or 0 when
// the mmap handle is not ok or ViewActionsModel() returns null.
jint GetVersionFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
  jint version = 0;
  const auto& handle = mmap->handle();
  if (handle.ok()) {
    const ActionsModel* const actions_model =
        libtextclassifier3::ViewActionsModel(handle.start(),
                                             handle.num_bytes());
    if (actions_model != nullptr) {
      version = actions_model->version();
    }
  }
  return version;
}
// Returns the name of the actions model viewed through `mmap` as a new Java
// string. Yields an empty Java string when the mmap handle is not ok, when
// ViewActionsModel() returns null, or when the model has no name field.
jstring GetNameFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
  const char* name = "";
  const auto& handle = mmap->handle();
  if (handle.ok()) {
    const ActionsModel* const actions_model =
        libtextclassifier3::ViewActionsModel(handle.start(),
                                             handle.num_bytes());
    if (actions_model != nullptr && actions_model->name() != nullptr) {
      name = actions_model->name()->c_str();
    }
  }
  return env->NewStringUTF(name);
}
} // namespace
} // namespace libtextclassifier3
using libtextclassifier3::ActionSuggestionsToJObjectArray;
using libtextclassifier3::FromJavaActionSuggestionOptions;
using libtextclassifier3::FromJavaConversation;
// Creates an ActionsSuggestions instance from the file descriptor `fd` and
// returns its address as a jlong. Ownership is released from the smart
// pointer here; the object is deleted again by nativeCloseActionsModel.
TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModel)
(JNIEnv* env, jobject thiz, jint fd) {
  auto* const model = ActionsSuggestions::FromFileDescriptor(fd).release();
  return reinterpret_cast<jlong>(model);
}
// Creates an ActionsSuggestions instance from the model file at `path` and
// returns its address as a jlong. Ownership is released from the smart
// pointer here; the object is deleted again by nativeCloseActionsModel.
TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModelFromPath)
(JNIEnv* env, jobject thiz, jstring path) {
  const std::string model_path = ToStlString(env, path);
  auto* const model = ActionsSuggestions::FromPath(model_path).release();
  return reinterpret_cast<jlong>(model);
}
// Creates an ActionsSuggestions instance from the region [`offset`, `size`]
// of the file behind the Java AssetFileDescriptor `afd` and returns its
// address as a jlong. The native descriptor is obtained through
// GetFdFromAssetFileDescriptor(); ownership of the created object is released
// from the smart pointer here.
TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME,
               nativeNewActionModelsFromAssetFileDescriptor)
(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
  const jint native_fd =
      libtextclassifier3::GetFdFromAssetFileDescriptor(env, afd);
  auto* const model =
      ActionsSuggestions::FromFileDescriptor(native_fd, offset, size)
          .release();
  return reinterpret_cast<jlong>(model);
}
// Runs action suggestion on the model addressed by `ptr`.
//
// Converts `jconversation` and `joptions` to their native forms, calls
// SuggestActions() and returns the resulting actions as a Java array.
// Returns nullptr when `ptr` is zero.
TC3_JNI_METHOD(jobjectArray, TC3_ACTIONS_CLASS_NAME, nativeSuggestActions)
(JNIEnv* env, jobject clazz, jlong ptr, jobject jconversation,
 jobject joptions) {
  if (ptr == 0) {
    return nullptr;
  }
  ActionsSuggestions* const suggestions_model =
      reinterpret_cast<ActionsSuggestions*>(ptr);
  const Conversation native_conversation =
      FromJavaConversation(env, jconversation);
  const ActionSuggestionOptions native_options =
      FromJavaActionSuggestionOptions(env, joptions);
  const ActionsSuggestionsResponse suggestions =
      suggestions_model->SuggestActions(native_conversation, native_options);
  return ActionSuggestionsToJObjectArray(env, suggestions.actions);
}
// Deletes the ActionsSuggestions object whose address is stored in `ptr`.
TC3_JNI_METHOD(void, TC3_ACTIONS_CLASS_NAME, nativeCloseActionsModel)
(JNIEnv* env, jobject clazz, jlong ptr) {
  delete reinterpret_cast<ActionsSuggestions*>(ptr);
}
// Maps the model file behind `fd` and returns its locales as a Java string
// (see GetLocalesFromMmap for the empty-string fallbacks).
//
// The mapping only has to live for the duration of this call, so it is a
// plain scoped object rather than a heap allocation.
TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetLocales)
(JNIEnv* env, jobject clazz, jint fd) {
  libtextclassifier3::ScopedMmap model_mmap(fd);
  return libtextclassifier3::GetLocalesFromMmap(env, &model_mmap);
}
// Maps the model file behind `fd` and returns its name as a Java string
// (see GetNameFromMmap for the empty-string fallbacks).
//
// The mapping only has to live for the duration of this call, so it is a
// plain scoped object rather than a heap allocation.
TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetName)
(JNIEnv* env, jobject clazz, jint fd) {
  libtextclassifier3::ScopedMmap model_mmap(fd);
  return libtextclassifier3::GetNameFromMmap(env, &model_mmap);
}
// Maps the model file behind `fd` and returns its version (see
// GetVersionFromMmap for the zero fallbacks).
//
// The mapping only has to live for the duration of this call, so it is a
// plain scoped object rather than a heap allocation.
TC3_JNI_METHOD(jint, TC3_ACTIONS_CLASS_NAME, nativeGetVersion)
(JNIEnv* env, jobject clazz, jint fd) {
  libtextclassifier3::ScopedMmap model_mmap(fd);
  return libtextclassifier3::GetVersionFromMmap(env, &model_mmap);
}
// Hands the Annotator addressed by `annotatorPtr` to the actions model
// addressed by `ptr` via SetAnnotator(). Does nothing when `ptr` is zero;
// `annotatorPtr` is passed through without a check.
TC3_JNI_METHOD(void, TC3_ACTIONS_CLASS_NAME, nativeSetAnnotator)
(JNIEnv* env, jobject clazz, jlong ptr, jlong annotatorPtr) {
  if (ptr != 0) {
    ActionsSuggestions* const suggestions_model =
        reinterpret_cast<ActionsSuggestions*>(ptr);
    suggestions_model->SetAnnotator(reinterpret_cast<Annotator*>(annotatorPtr));
  }
}