blob: f882515c0191c5747fafa167433fa2a2dcb3ebd1 [file] [log] [blame]
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "utils/intents/intent-generator.h"
#include <utility>
#include <vector>
#include "actions/lua-utils.h"
#include "actions/types.h"
#include "annotator/types.h"
#include "utils/base/logging.h"
#include "utils/hash/farmhash.h"
#include "utils/java/jni-base.h"
#include "utils/java/string_utils.h"
#include "utils/lua-utils.h"
#include "utils/strings/stringpiece.h"
#include "utils/strings/substitute.h"
#include "utils/utf8/unicodetext.h"
#include "utils/variant.h"
#include "utils/zlib/zlib.h"
#include "flatbuffers/reflection_generated.h"
#ifdef __cplusplus
extern "C" {
#endif
#include "lauxlib.h"
#include "lua.h"
#ifdef __cplusplus
}
#endif
namespace libtextclassifier3 {
namespace {
// Keys of the fields exposed to intent generator Lua snippets. The enclosing
// anonymous namespace already gives these internal linkage.

// Field of the `external` table holding the reference time.
// NOTE(review): despite "Usec" in the identifier, the key (and the value
// stored under it) is in milliseconds.
constexpr const char* kReferenceTimeUsecKey = "reference_time_ms_utc";

// Functions resolved on the `external` table.
constexpr const char* kHashKey = "hash";
constexpr const char* kFormatKey = "format";

// Fields resolved on the `external.android` table.
constexpr const char* kUrlSchemaKey = "url_schema";
constexpr const char* kUrlHostKey = "url_host";
constexpr const char* kUrlEncodeKey = "urlencode";
constexpr const char* kPackageNameKey = "package_name";
constexpr const char* kDeviceLocaleKey = "device_locales";
// An Android specific Lua environment with JNI backed callbacks.
//
// Snippets see a global `external` table (installed by Initialize() from
// SetupExternalHook()). Its `android` field is served by the Handle*
// callbacks below, which call into Java through the cached JNI handles.
// NOTE(review): `jenv_` is captured once in the constructor; presumably every
// call must happen on that same thread — confirm with callers.
class JniLuaEnvironment : public LuaEnvironment {
public:
// `resources` is stored by reference and must outlive this object.
// `jni_cache` may be null, in which case the cached JNI handles are null.
// NOTE(review): `context` is stored as given (not as a global reference), so
// it presumably must remain valid for the lifetime of this object.
JniLuaEnvironment(const Resources& resources, const JniCache* jni_cache,
const jobject context,
const std::vector<Locale>& device_locales);
// Environment setup: creates the cached Java strings and installs the
// `external` global. Returns false on failure.
bool Initialize();
// Runs an intent generator snippet and appends the remote action templates
// it returns to `remote_actions`. Returns false on failure.
bool RunIntentGenerator(const std::string& generator_snippet,
std::vector<RemoteActionTemplate>* remote_actions);
protected:
// Pushes the `external` table onto the Lua stack. Subclasses override this
// to add further fields to that table.
virtual void SetupExternalHook();
// Handlers for key lookups on the `external`, `external.android` and
// `external.android.user_restrictions` tables.
int HandleExternalCallback();
int HandleAndroidCallback();
int HandleUserRestrictionsCallback();
// Native functions bound for snippets (`android.urlencode`,
// `android.url_schema`, `external.hash`, `external.format`).
int HandleUrlEncode();
int HandleUrlSchema();
int HandleHash();
int HandleFormat();
// Handler for key lookups on `external.android.R`.
int HandleAndroidStringResources();
// Native function bound as `android.url_host`.
int HandleUrlHost();
// Checks and retrieves string resources from the model.
bool LookupModelStringResource();
// Reads the intent template table on top of the Lua stack into a
// RemoteActionTemplate.
RemoteActionTemplate ReadRemoteActionTemplateResult();
// Reads the extras from the Lua value on top of the stack.
void ReadExtras(std::map<std::string, Variant>* extra);
// Reads the intent categories array from the Lua value on top of the stack.
void ReadCategories(std::vector<std::string>* category);
// Retrieves user manager if not previously done.
bool RetrieveUserManager();
// Retrieves system resources if not previously done.
bool RetrieveSystemResources();
// Parses the url string by using Uri.parse from Java. May return a null
// reference on failure.
ScopedLocalRef<jobject> ParseUri(StringPiece url) const;
// Reads the remote action templates array returned by the Lua generator.
int ReadRemoteActionTemplates(std::vector<RemoteActionTemplate>* result);
// Model resources; consulted for `android.R` lookups before the system ones.
const Resources& resources_;
// JNI environment taken from `jni_cache_` at construction (null without one).
JNIEnv* jenv_;
const JniCache* jni_cache_;
const jobject context_;
std::vector<Locale> device_locales_;
// Lazily retrieved UserManager (see RetrieveUserManager).
ScopedGlobalRef<jobject> usermanager_;
// Whether we previously attempted to retrieve the UserManager.
bool usermanager_retrieved_;
// Lazily retrieved system Resources (see RetrieveSystemResources).
ScopedGlobalRef<jobject> system_resources_;
// Whether we previously attempted to retrieve the system resources.
bool system_resources_resources_retrieved_;
// Cached JNI references for Java strings `string` and `android`.
ScopedGlobalRef<jstring> string_;
ScopedGlobalRef<jstring> android_;
};
JniLuaEnvironment::JniLuaEnvironment(const Resources& resources,
const JniCache* jni_cache,
const jobject context,
const std::vector<Locale>& device_locales)
: resources_(resources),
jenv_(jni_cache ? jni_cache->GetEnv() : nullptr),
jni_cache_(jni_cache),
context_(context),
device_locales_(device_locales),
usermanager_(/*object=*/nullptr,
/*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
usermanager_retrieved_(false),
system_resources_(/*object=*/nullptr,
/*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
system_resources_resources_retrieved_(false),
string_(/*object=*/nullptr,
/*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
android_(/*object=*/nullptr,
/*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)) {}
bool JniLuaEnvironment::Initialize() {
string_ =
MakeGlobalRef(jenv_->NewStringUTF("string"), jenv_, jni_cache_->jvm);
android_ =
MakeGlobalRef(jenv_->NewStringUTF("android"), jenv_, jni_cache_->jvm);
if (string_ == nullptr || android_ == nullptr) {
TC3_LOG(ERROR) << "Could not allocate constant strings references.";
return false;
}
return (RunProtected([this] {
LoadDefaultLibraries();
SetupExternalHook();
lua_setglobal(state_, "external");
return LUA_OK;
}) == LUA_OK);
}
// Pushes the `external` table onto the Lua stack. Initialize() stores it as
// the global `external`, and subclasses add further fields (such as `entity`)
// after calling this base implementation.
void JniLuaEnvironment::SetupExternalHook() {
// This exposes an `external` object with the following fields:
// * hash, format: native functions resolved by HandleExternalCallback.
// * android: callbacks into specific android provided methods.
// * android.user_restrictions: callbacks to check user permissions.
// * android.R: callbacks to retrieve string resources.
// Stack afterwards: [..., external].
BindTable<JniLuaEnvironment, &JniLuaEnvironment::HandleExternalCallback>(
"external");
// android
// Stack: [..., external, android].
BindTable<JniLuaEnvironment, &JniLuaEnvironment::HandleAndroidCallback>(
"android");
{
// android.user_restrictions
// Stack: [..., external, android, user_restrictions]; attach it to
// `android` (at -2), which pops it.
BindTable<JniLuaEnvironment,
&JniLuaEnvironment::HandleUserRestrictionsCallback>(
"user_restrictions");
lua_setfield(state_, /*idx=*/-2, "user_restrictions");
// android.R
// Callback to access android string resources; attached to `android` the
// same way.
BindTable<JniLuaEnvironment,
&JniLuaEnvironment::HandleAndroidStringResources>("R");
lua_setfield(state_, /*idx=*/-2, "R");
}
// Stack: [..., external, android]; attach `android` to `external`.
lua_setfield(state_, /*idx=*/-2, "android");
}
int JniLuaEnvironment::HandleExternalCallback() {
const StringPiece key = ReadString(/*index=*/-1);
if (key.Equals(kHashKey)) {
Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleHash>();
return 1;
} else if (key.Equals(kFormatKey)) {
Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleFormat>();
return 1;
} else {
TC3_LOG(ERROR) << "Undefined external access " << key.ToString();
lua_error(state_);
return 0;
}
}
int JniLuaEnvironment::HandleAndroidCallback() {
const StringPiece key = ReadString(/*index=*/-1);
if (key.Equals(kDeviceLocaleKey)) {
// Provide the locale as table with the individual fields set.
lua_newtable(state_);
for (int i = 0; i < device_locales_.size(); i++) {
// Adjust index to 1-based indexing for Lua.
lua_pushinteger(state_, i + 1);
lua_newtable(state_);
PushString(device_locales_[i].Language());
lua_setfield(state_, -2, "language");
PushString(device_locales_[i].Region());
lua_setfield(state_, -2, "region");
PushString(device_locales_[i].Script());
lua_setfield(state_, -2, "script");
lua_settable(state_, /*idx=*/-3);
}
return 1;
} else if (key.Equals(kPackageNameKey)) {
if (context_ == nullptr) {
TC3_LOG(ERROR) << "Context invalid.";
lua_error(state_);
return 0;
}
ScopedLocalRef<jstring> package_name_str(
static_cast<jstring>(jenv_->CallObjectMethod(
context_, jni_cache_->context_get_package_name)));
if (jni_cache_->ExceptionCheckAndClear()) {
TC3_LOG(ERROR) << "Error calling Context.getPackageName";
lua_error(state_);
return 0;
}
PushString(ToStlString(jenv_, package_name_str.get()));
return 1;
} else if (key.Equals(kUrlEncodeKey)) {
Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleUrlEncode>();
return 1;
} else if (key.Equals(kUrlHostKey)) {
Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleUrlHost>();
return 1;
} else if (key.Equals(kUrlSchemaKey)) {
Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleUrlSchema>();
return 1;
} else {
TC3_LOG(ERROR) << "Undefined android reference " << key.ToString();
lua_error(state_);
return 0;
}
}
int JniLuaEnvironment::HandleUserRestrictionsCallback() {
if (jni_cache_->usermanager_class == nullptr ||
jni_cache_->usermanager_get_user_restrictions == nullptr) {
// UserManager is only available for API level >= 17 and
// getUserRestrictions only for API level >= 18, so we just return false
// normally here.
lua_pushboolean(state_, false);
return 1;
}
// Get user manager if not previously retrieved.
if (!RetrieveUserManager()) {
TC3_LOG(ERROR) << "Error retrieving user manager.";
lua_error(state_);
return 0;
}
ScopedLocalRef<jobject> bundle(jenv_->CallObjectMethod(
usermanager_.get(), jni_cache_->usermanager_get_user_restrictions));
if (jni_cache_->ExceptionCheckAndClear() || bundle == nullptr) {
TC3_LOG(ERROR) << "Error calling getUserRestrictions";
lua_error(state_);
return 0;
}
const StringPiece key_str = ReadString(/*index=*/-1);
if (key_str.empty()) {
TC3_LOG(ERROR) << "Expected string, got null.";
lua_error(state_);
return 0;
}
ScopedLocalRef<jstring> key = jni_cache_->ConvertToJavaString(key_str);
if (jni_cache_->ExceptionCheckAndClear() || key == nullptr) {
TC3_LOG(ERROR) << "Expected string, got null.";
lua_error(state_);
return 0;
}
const bool permission = jenv_->CallBooleanMethod(
bundle.get(), jni_cache_->bundle_get_boolean, key.get());
if (jni_cache_->ExceptionCheckAndClear()) {
TC3_LOG(ERROR) << "Error getting bundle value";
lua_pushboolean(state_, false);
} else {
lua_pushboolean(state_, permission);
}
return 1;
}
int JniLuaEnvironment::HandleUrlEncode() {
const StringPiece input = ReadString(/*index=*/1);
if (input.empty()) {
TC3_LOG(ERROR) << "Expected string, got null.";
lua_error(state_);
return 0;
}
// Call Java URL encoder.
ScopedLocalRef<jstring> input_str = jni_cache_->ConvertToJavaString(input);
if (jni_cache_->ExceptionCheckAndClear() || input_str == nullptr) {
TC3_LOG(ERROR) << "Expected string, got null.";
lua_error(state_);
return 0;
}
ScopedLocalRef<jstring> encoded_str(
static_cast<jstring>(jenv_->CallStaticObjectMethod(
jni_cache_->urlencoder_class.get(), jni_cache_->urlencoder_encode,
input_str.get(), jni_cache_->string_utf8.get())));
if (jni_cache_->ExceptionCheckAndClear()) {
TC3_LOG(ERROR) << "Error calling UrlEncoder.encode";
lua_error(state_);
return 0;
}
PushString(ToStlString(jenv_, encoded_str.get()));
return 1;
}
// Parses `url` with Java's Uri.parse (the cached `uri_parse` method).
// Returns a null reference if `url` is empty, if it cannot be converted to a
// Java string, or if Uri.parse throws or returns null.
ScopedLocalRef<jobject> JniLuaEnvironment::ParseUri(StringPiece url) const {
  if (url.empty()) {
    return nullptr;
  }
  // Call to Java URI parser.
  ScopedLocalRef<jstring> url_str = jni_cache_->ConvertToJavaString(url);
  if (jni_cache_->ExceptionCheckAndClear() || url_str == nullptr) {
    TC3_LOG(ERROR) << "Expected string, got null";
    return nullptr;
  }
  // Try to parse the uri.
  ScopedLocalRef<jobject> uri(jenv_->CallStaticObjectMethod(
      jni_cache_->uri_class.get(), jni_cache_->uri_parse, url_str.get()));
  if (jni_cache_->ExceptionCheckAndClear() || uri == nullptr) {
    TC3_LOG(ERROR) << "Error calling Uri.parse";
    // Report failure explicitly instead of handing out whatever value came
    // back together with a Java exception.
    return nullptr;
  }
  return uri;
}
int JniLuaEnvironment::HandleUrlSchema() {
StringPiece url = ReadString(/*index=*/1);
ScopedLocalRef<jobject> parsed_uri = ParseUri(url);
if (parsed_uri == nullptr) {
lua_error(state_);
return 0;
}
ScopedLocalRef<jstring> scheme_str(static_cast<jstring>(
jenv_->CallObjectMethod(parsed_uri.get(), jni_cache_->uri_get_scheme)));
if (jni_cache_->ExceptionCheckAndClear()) {
TC3_LOG(ERROR) << "Error calling Uri.getScheme";
lua_error(state_);
return 0;
}
if (scheme_str == nullptr) {
lua_pushnil(state_);
} else {
PushString(ToStlString(jenv_, scheme_str.get()));
}
return 1;
}
int JniLuaEnvironment::HandleUrlHost() {
StringPiece url = ReadString(/*index=*/-1);
ScopedLocalRef<jobject> parsed_uri = ParseUri(url);
if (parsed_uri == nullptr) {
lua_error(state_);
return 0;
}
ScopedLocalRef<jstring> host_str(static_cast<jstring>(
jenv_->CallObjectMethod(parsed_uri.get(), jni_cache_->uri_get_host)));
if (jni_cache_->ExceptionCheckAndClear()) {
TC3_LOG(ERROR) << "Error calling Uri.getHost";
lua_error(state_);
return 0;
}
if (host_str == nullptr) {
lua_pushnil(state_);
} else {
PushString(ToStlString(jenv_, host_str.get()));
}
return 1;
}
int JniLuaEnvironment::HandleHash() {
const StringPiece input = ReadString(/*index=*/-1);
lua_pushinteger(state_, tc3farmhash::Hash32(input.data(), input.length()));
return 1;
}
int JniLuaEnvironment::HandleFormat() {
const int num_args = lua_gettop(state_);
std::vector<StringPiece> args(num_args - 1);
for (int i = 0; i < num_args - 1; i++) {
args[i] = ReadString(/*index=*/i + 2);
}
PushString(strings::Substitute(ReadString(/*index=*/1), args));
return 1;
}
// Tries to serve an `android.R` lookup from the model's own resources, using
// the device locales. Only lookups by name (a string key on top of the Lua
// stack) are handled. On success it pushes the resource text and returns
// true; otherwise it pushes nothing and returns false.
bool JniLuaEnvironment::LookupModelStringResource() {
  // Handle only lookup by name. Check and read the same stack slot (the top),
  // as HandleAndroidStringResources does.
  if (lua_type(state_, /*idx=*/-1) != LUA_TSTRING) {
    return false;
  }
  const StringPiece resource_name = ReadString(/*index=*/-1);
  std::string resource_content;
  if (!resources_.GetResourceContent(device_locales_, resource_name,
                                     &resource_content)) {
    // Resource cannot be provided by the model.
    return false;
  }
  PushString(resource_content);
  return true;
}
// Resolves a key looked up on `external.android.R`.
// The key is first offered to LookupModelStringResource (model resources).
// Otherwise it is resolved against Android's system resources:
//   - number keys are used directly as resource ids;
//   - string keys are mapped through Resources.getIdentifier(name, "string",
//     "android").
// Pushes the resource text, or nil if the resource is not found. JNI failures
// and unsupported key types raise a Lua error.
int JniLuaEnvironment::HandleAndroidStringResources() {
  // Check whether the requested resource can be served from the model data.
  if (LookupModelStringResource()) {
    return 1;
  }
  // Get system resources if not previously retrieved.
  if (!RetrieveSystemResources()) {
    TC3_LOG(ERROR) << "Error retrieving system resources.";
    lua_error(state_);
    return 0;
  }
  int resource_id = 0;
  switch (lua_type(state_, /*idx=*/-1)) {
    case LUA_TNUMBER:
      resource_id = static_cast<int>(lua_tonumber(state_, /*idx=*/-1));
      break;
    case LUA_TSTRING: {
      const StringPiece resource_name_str = ReadString(/*index=*/-1);
      if (resource_name_str.empty()) {
        TC3_LOG(ERROR) << "No resource name provided.";
        lua_error(state_);
        return 0;
      }
      ScopedLocalRef<jstring> resource_name =
          jni_cache_->ConvertToJavaString(resource_name_str);
      // As at the other ConvertToJavaString call sites, clear any pending
      // Java exception before making further JNI calls.
      if (jni_cache_->ExceptionCheckAndClear() || resource_name == nullptr) {
        TC3_LOG(ERROR) << "Invalid resource name.";
        lua_error(state_);
        return 0;
      }
      resource_id = jenv_->CallIntMethod(
          system_resources_.get(), jni_cache_->resources_get_identifier,
          resource_name.get(), string_.get(), android_.get());
      if (jni_cache_->ExceptionCheckAndClear()) {
        TC3_LOG(ERROR) << "Error calling getIdentifier.";
        lua_error(state_);
        return 0;
      }
      break;
    }
    default:
      TC3_LOG(ERROR) << "Unexpected type for resource lookup.";
      lua_error(state_);
      return 0;
  }
  if (resource_id == 0) {
    TC3_LOG(ERROR) << "Resource not found.";
    lua_pushnil(state_);
    return 1;
  }
  ScopedLocalRef<jstring> resource_str(static_cast<jstring>(
      jenv_->CallObjectMethod(system_resources_.get(),
                              jni_cache_->resources_get_string, resource_id)));
  if (jni_cache_->ExceptionCheckAndClear()) {
    TC3_LOG(ERROR) << "Error calling getString.";
    lua_error(state_);
    return 0;
  }
  if (resource_str == nullptr) {
    lua_pushnil(state_);
  } else {
    PushString(ToStlString(jenv_, resource_str.get()));
  }
  return 1;
}
// Retrieves Resources.getSystem() once and caches it as a global reference.
// Later calls return the cached outcome without retrying.
// Returns true if the system resources are available.
bool JniLuaEnvironment::RetrieveSystemResources() {
  if (system_resources_resources_retrieved_) {
    return system_resources_ != nullptr;
  }
  system_resources_resources_retrieved_ = true;
  jobject system_resources_ref = jenv_->CallStaticObjectMethod(
      jni_cache_->resources_class.get(), jni_cache_->resources_get_system);
  if (jni_cache_->ExceptionCheckAndClear()) {
    TC3_LOG(ERROR) << "Error calling getSystem.";
    return false;
  }
  system_resources_ =
      MakeGlobalRef(system_resources_ref, jenv_, jni_cache_->jvm);
  // The global reference keeps the object alive. Release the local one so it
  // does not accumulate in the current JNI local reference frame.
  if (system_resources_ref != nullptr) {
    jenv_->DeleteLocalRef(system_resources_ref);
  }
  return system_resources_ != nullptr;
}
bool JniLuaEnvironment::RetrieveUserManager() {
if (context_ == nullptr) {
return false;
}
if (usermanager_retrieved_) {
return (usermanager_ != nullptr);
}
usermanager_retrieved_ = true;
ScopedLocalRef<jstring> service(jenv_->NewStringUTF("user"));
jobject usermanager_ref = jenv_->CallObjectMethod(
context_, jni_cache_->context_get_system_service, service.get());
if (jni_cache_->ExceptionCheckAndClear()) {
TC3_LOG(ERROR) << "Error calling getSystemService.";
return false;
}
usermanager_ = MakeGlobalRef(usermanager_ref, jenv_, jni_cache_->jvm);
return (usermanager_ != nullptr);
}
// Reads the intent template table on top of the Lua stack into a
// RemoteActionTemplate and pops that table. Unknown string keys are logged
// and ignored. Non-string keys are skipped.
RemoteActionTemplate JniLuaEnvironment::ReadRemoteActionTemplateResult() {
  RemoteActionTemplate result;
  // Read intent template.
  lua_pushnil(state_);
  while (lua_next(state_, /*idx=*/-2)) {
    // Stack: [..., template, key, value].
    // Reading a numeric key as a string may convert it in place
    // (NOTE(review): assumes ReadString is built on lua_tolstring), which
    // would break the lua_next traversal. Only string keys are meaningful here.
    if (lua_type(state_, /*idx=*/-2) != LUA_TSTRING) {
      TC3_LOG(INFO) << "Skipping non-string intent template key.";
      lua_pop(state_, 1);
      continue;
    }
    const StringPiece key = ReadString(/*index=*/-2);
    if (key.Equals("title_without_entity")) {
      result.title_without_entity = ReadString(/*index=*/-1).ToString();
    } else if (key.Equals("title_with_entity")) {
      result.title_with_entity = ReadString(/*index=*/-1).ToString();
    } else if (key.Equals("description")) {
      result.description = ReadString(/*index=*/-1).ToString();
    } else if (key.Equals("description_with_app_name")) {
      result.description_with_app_name = ReadString(/*index=*/-1).ToString();
    } else if (key.Equals("action")) {
      result.action = ReadString(/*index=*/-1).ToString();
    } else if (key.Equals("data")) {
      result.data = ReadString(/*index=*/-1).ToString();
    } else if (key.Equals("type")) {
      result.type = ReadString(/*index=*/-1).ToString();
    } else if (key.Equals("flags")) {
      result.flags = static_cast<int>(lua_tointeger(state_, /*idx=*/-1));
    } else if (key.Equals("package_name")) {
      result.package_name = ReadString(/*index=*/-1).ToString();
    } else if (key.Equals("request_code")) {
      result.request_code = static_cast<int>(lua_tointeger(state_, /*idx=*/-1));
    } else if (key.Equals("category")) {
      ReadCategories(&result.category);
    } else if (key.Equals("extra")) {
      ReadExtras(&result.extra);
    } else {
      TC3_LOG(INFO) << "Unknown entry: " << key.ToString();
    }
    // Pop the value, keeping the key for the next lua_next call.
    lua_pop(state_, 1);
  }
  // Pop the template table itself.
  lua_pop(state_, 1);
  return result;
}
// Appends the entries of the categories array on top of the Lua stack to
// `category`. The array value itself is left on the stack in every case; the
// caller (ReadRemoteActionTemplateResult) pops it as part of its lua_next
// loop. A non-table value is logged and ignored.
void JniLuaEnvironment::ReadCategories(std::vector<std::string>* category) {
  // Read category array.
  if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
    TC3_LOG(ERROR) << "Expected categories table, got: "
                   << lua_type(state_, /*idx=*/-1);
    // Do not pop here: the caller pops this value. Popping it too would
    // remove the caller's iteration key and corrupt its traversal.
    return;
  }
  lua_pushnil(state_);
  while (lua_next(state_, /*idx=*/-2)) {
    category->push_back(ReadString(/*index=*/-1).ToString());
    lua_pop(state_, 1);
  }
}
void JniLuaEnvironment::ReadExtras(std::map<std::string, Variant>* extra) {
if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
TC3_LOG(ERROR) << "Expected extras table, got: "
<< lua_type(state_, /*idx=*/-1);
lua_pop(state_, 1);
return;
}
lua_pushnil(state_);
while (lua_next(state_, /*idx=*/-2)) {
// Each entry is a table specifying name and value.
// The value is specified via a type specific field as Lua doesn't allow
// to easily distinguish between different number types.
if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
TC3_LOG(ERROR) << "Expected a table for an extra, got: "
<< lua_type(state_, /*idx=*/-1);
lua_pop(state_, 1);
return;
}
std::string name;
Variant value;
lua_pushnil(state_);
while (lua_next(state_, /*idx=*/-2)) {
const StringPiece key = ReadString(/*index=*/-2);
if (key.Equals("name")) {
name = ReadString(/*index=*/-1).ToString();
} else if (key.Equals("int_value")) {
value = Variant(static_cast<int>(lua_tonumber(state_, /*idx=*/-1)));
} else if (key.Equals("long_value")) {
value = Variant(static_cast<int64>(lua_tonumber(state_, /*idx=*/-1)));
} else if (key.Equals("float_value")) {
value = Variant(static_cast<float>(lua_tonumber(state_, /*idx=*/-1)));
} else if (key.Equals("bool_value")) {
value = Variant(static_cast<bool>(lua_toboolean(state_, /*idx=*/-1)));
} else if (key.Equals("string_value")) {
value = Variant(ReadString(/*index=*/-1).ToString());
} else {
TC3_LOG(INFO) << "Unknown extra field: " << key.ToString();
}
lua_pop(state_, 1);
}
if (!name.empty()) {
(*extra)[name] = value;
} else {
TC3_LOG(ERROR) << "Unnamed extra entry. Skipping.";
}
lua_pop(state_, 1);
}
}
// Reads the array of intent templates returned by the generator snippet (on
// top of the Lua stack) into `result`, then pops the array. Entries that are
// not tables are logged and skipped. A result that is not a table raises a
// Lua error.
int JniLuaEnvironment::ReadRemoteActionTemplates(
    std::vector<RemoteActionTemplate>* result) {
  if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
    TC3_LOG(ERROR) << "Unexpected result for snippet: " << lua_type(state_, -1);
    lua_error(state_);
    return LUA_ERRRUN;
  }

  lua_pushnil(state_);
  while (lua_next(state_, /*idx=*/-2) != 0) {
    if (lua_type(state_, /*idx=*/-1) == LUA_TTABLE) {
      // ReadRemoteActionTemplateResult pops the template table itself.
      result->push_back(ReadRemoteActionTemplateResult());
    } else {
      TC3_LOG(ERROR) << "Expected intent table, got: "
                     << lua_type(state_, /*idx=*/-1);
      lua_pop(state_, 1);
    }
  }

  // Pop the templates array.
  lua_pop(state_, /*n=*/1);
  return LUA_OK;
}
bool JniLuaEnvironment::RunIntentGenerator(
const std::string& generator_snippet,
std::vector<RemoteActionTemplate>* remote_actions) {
int status;
status = luaL_loadbuffer(state_, generator_snippet.data(),
generator_snippet.size(),
/*name=*/nullptr);
if (status != LUA_OK) {
TC3_LOG(ERROR) << "Couldn't load generator snippet: " << status;
return false;
}
status = lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0);
if (status != LUA_OK) {
TC3_LOG(ERROR) << "Couldn't run generator snippet: " << status;
return false;
}
if (RunProtected(
[this, remote_actions] {
return ReadRemoteActionTemplates(remote_actions);
},
/*num_args=*/1) != LUA_OK) {
TC3_LOG(ERROR) << "Could not read results.";
return false;
}
// Check that we correctly cleaned-up the state.
const int stack_size = lua_gettop(state_);
if (stack_size > 0) {
TC3_LOG(ERROR) << "Unexpected stack size.";
lua_settop(state_, 0);
return false;
}
return true;
}
// Lua environment for classification result intent generation. Besides the
// fields set up by JniLuaEnvironment, its `external` table carries the
// reference time and the classified `entity`.
class AnnotatorJniEnvironment : public JniLuaEnvironment {
 public:
  // `entity_text` and `classification` are held by reference and must
  // outlive this environment.
  AnnotatorJniEnvironment(const Resources& resources, const JniCache* jni_cache,
                          const jobject context,
                          const std::vector<Locale>& device_locales,
                          const std::string& entity_text,
                          const ClassificationResult& classification,
                          const int64 reference_time_ms_utc,
                          const reflection::Schema* entity_data_schema)
      : JniLuaEnvironment(resources, jni_cache, context, device_locales),
        classification_(classification),
        entity_text_(entity_text),
        entity_data_schema_(entity_data_schema),
        reference_time_ms_utc_(reference_time_ms_utc) {}

 protected:
  // Adds the reference time and the `entity` to the `external` table left on
  // top of the Lua stack by the base implementation.
  void SetupExternalHook() override {
    JniLuaEnvironment::SetupExternalHook();
    lua_pushinteger(state_, reference_time_ms_utc_);
    lua_setfield(state_, /*idx=*/-2, kReferenceTimeUsecKey);
    PushAnnotation(classification_, entity_text_, entity_data_schema_, this);
    lua_setfield(state_, /*idx=*/-2, "entity");
  }

  // Borrowed from the caller.
  const ClassificationResult& classification_;
  const std::string& entity_text_;
  // Reflection schema passed on to PushAnnotation.
  const reflection::Schema* const entity_data_schema_;
  const int64 reference_time_ms_utc_;
};
// Lua environment for actions intent generation. Besides the fields set up by
// JniLuaEnvironment, its `external` table carries the `conversation` and the
// suggested action as `entity`.
class ActionsJniLuaEnvironment : public JniLuaEnvironment {
 public:
  // `conversation` and `action` are held by reference and must outlive this
  // environment.
  ActionsJniLuaEnvironment(
      const Resources& resources, const JniCache* jni_cache,
      const jobject context, const std::vector<Locale>& device_locales,
      const Conversation& conversation, const ActionSuggestion& action,
      const reflection::Schema* actions_entity_data_schema,
      const reflection::Schema* annotations_entity_data_schema)
      : JniLuaEnvironment(resources, jni_cache, context, device_locales),
        conversation_(conversation),
        action_(action),
        annotation_iterator_(annotations_entity_data_schema, this),
        conversation_iterator_(annotations_entity_data_schema, this),
        entity_data_schema_(actions_entity_data_schema) {}

 protected:
  // Adds `conversation` and `entity` to the `external` table left on top of
  // the Lua stack by the base implementation.
  void SetupExternalHook() override {
    JniLuaEnvironment::SetupExternalHook();
    conversation_iterator_.NewIterator("conversation", &conversation_.messages,
                                       state_);
    lua_setfield(state_, /*idx=*/-2, "conversation");
    PushAction(action_, entity_data_schema_, annotation_iterator_, this);
    lua_setfield(state_, /*idx=*/-2, "entity");
  }

  // Borrowed from the caller.
  const Conversation& conversation_;
  const ActionSuggestion& action_;
  // Both iterators use the annotations entity data schema.
  const AnnotationIterator<ActionSuggestionAnnotation> annotation_iterator_;
  const ConversationIterator conversation_iterator_;
  // Reflection schema of the action's entity data. It is never reseated, so
  // the pointer itself is const as well.
  const reflection::Schema* const entity_data_schema_;
};
} // namespace
// Creates an intent generator from `options`. Each generator's Lua template is
// decompressed and, if the model requests it, precompiled; the resulting code
// is keyed by the generator's type.
// Returns nullptr if:
//   - `options` has no generators;
//   - any generator has no type;
//   - any template cannot be decompressed or compiled.
std::unique_ptr<IntentGenerator> IntentGenerator::Create(
    const IntentFactoryModel* options, const ResourcePool* resources,
    const std::shared_ptr<JniCache>& jni_cache) {
  // Validate the options before allocating anything.
  if (options == nullptr || options->generator() == nullptr) {
    TC3_LOG(ERROR) << "No intent generator options.";
    return nullptr;
  }
  std::unique_ptr<IntentGenerator> intent_generator(
      new IntentGenerator(options, resources, jni_cache));
  std::unique_ptr<ZlibDecompressor> zlib_decompressor =
      ZlibDecompressor::Instance();
  if (!zlib_decompressor) {
    TC3_LOG(ERROR) << "Cannot initialize decompressor.";
    return nullptr;
  }
  for (const IntentFactoryModel_::IntentGenerator* generator :
       *options->generator()) {
    // The type is the lookup key in `generators_`; dereferencing a missing
    // one would crash.
    if (generator->type() == nullptr) {
      TC3_LOG(ERROR) << "Intent generator without type.";
      return nullptr;
    }
    std::string lua_template_generator;
    if (!zlib_decompressor->MaybeDecompressOptionallyCompressedBuffer(
            generator->lua_template_generator(),
            generator->compressed_lua_template_generator(),
            &lua_template_generator)) {
      TC3_LOG(ERROR) << "Could not decompress generator template.";
      return nullptr;
    }
    std::string lua_code = lua_template_generator;
    if (options->precompile_generators()) {
      if (!Compile(lua_template_generator, &lua_code)) {
        TC3_LOG(ERROR) << "Could not precompile generator template.";
        return nullptr;
      }
    }
    // Move rather than copy the (possibly large) code into the map.
    intent_generator->generators_[generator->type()->str()] =
        std::move(lua_code);
  }
  return intent_generator;
}
std::vector<Locale> IntentGenerator::ParseDeviceLocales(
const jstring device_locales) const {
if (device_locales == nullptr) {
TC3_LOG(ERROR) << "No locales provided.";
return {};
}
ScopedStringChars locales_str =
GetScopedStringChars(jni_cache_->GetEnv(), device_locales);
if (locales_str == nullptr) {
TC3_LOG(ERROR) << "Cannot retrieve provided locales.";
return {};
}
std::vector<Locale> locales;
if (!ParseLocales(reinterpret_cast<const char*>(locales_str.get()),
&locales)) {
TC3_LOG(ERROR) << "Cannot parse locales.";
return {};
}
return locales;
}
// Runs the Lua generator registered for `classification.collection`, if any.
// The generator sees the classified span of `text`, and the templates it
// produces are appended to `remote_actions`.
// Returns:
//   - true without output when no generator is registered;
//   - false when the generator options are missing, or when the Lua
//     environment cannot be initialized or run.
bool IntentGenerator::GenerateIntents(
    const jstring device_locales, const ClassificationResult& classification,
    const int64 reference_time_ms_utc, const std::string& text,
    const CodepointSpan selection_indices, const jobject context,
    const reflection::Schema* annotations_entity_data_schema,
    std::vector<RemoteActionTemplate>* remote_actions) const {
  if (options_ == nullptr) {
    return false;
  }

  const auto generator = generators_.find(classification.collection);
  if (generator == generators_.end()) {
    // Nothing to generate for this collection; not an error.
    return true;
  }

  // The selection is a codepoint span, so slice the text as Unicode.
  const std::string entity_text =
      UTF8ToUnicodeText(text, /*do_copy=*/false)
          .UTF8Substring(selection_indices.first, selection_indices.second);

  AnnotatorJniEnvironment interpreter(
      resources_, jni_cache_.get(), context, ParseDeviceLocales(device_locales),
      entity_text, classification, reference_time_ms_utc,
      annotations_entity_data_schema);
  if (!interpreter.Initialize()) {
    TC3_LOG(ERROR) << "Could not create Lua interpreter.";
    return false;
  }
  return interpreter.RunIntentGenerator(generator->second, remote_actions);
}
// Runs the Lua generator registered for `action.type`, if any, on the
// suggested action within `conversation`. The templates it produces are
// appended to `remote_actions`.
// Returns:
//   - true without output when no generator is registered;
//   - false when the generator options are missing, or when the Lua
//     environment cannot be initialized or run.
bool IntentGenerator::GenerateIntents(
    const jstring device_locales, const ActionSuggestion& action,
    const Conversation& conversation, const jobject context,
    const reflection::Schema* annotations_entity_data_schema,
    const reflection::Schema* actions_entity_data_schema,
    std::vector<RemoteActionTemplate>* remote_actions) const {
  if (options_ == nullptr) {
    return false;
  }

  const auto generator = generators_.find(action.type);
  if (generator == generators_.end()) {
    // Nothing to generate for this action type; not an error.
    return true;
  }

  ActionsJniLuaEnvironment interpreter(
      resources_, jni_cache_.get(), context, ParseDeviceLocales(device_locales),
      conversation, action, actions_entity_data_schema,
      annotations_entity_data_schema);
  if (!interpreter.Initialize()) {
    TC3_LOG(ERROR) << "Could not create Lua interpreter.";
    return false;
  }
  return interpreter.RunIntentGenerator(generator->second, remote_actions);
}
} // namespace libtextclassifier3