blob: 7cf871af65914bb145ae3766effad666123f1032 [file] [log] [blame] [edit]
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "actions/lua-actions.h"
#include "utils/base/logging.h"
#include "utils/lua-utils.h"
#ifdef __cplusplus
extern "C" {
#endif
#include "lauxlib.h"
#include "lualib.h"
#ifdef __cplusplus
}
#endif
namespace libtextclassifier3 {
namespace {

// Shared precondition for the output readers below: the output index must be
// non-negative and both the executor and the interpreter must be present.
bool CanReadOutput(const TfLiteModelExecutor* model_executor,
                   const tflite::Interpreter* interpreter, int output) {
  return model_executor != nullptr && interpreter != nullptr && output >= 0;
}

// Returns the float view of model output `output` as provided by
// `model_executor->OutputView<float>`, or TensorView<float>::Invalid() when
// CanReadOutput() rejects the arguments.
TensorView<float> GetTensorViewForOutput(
    const TfLiteModelExecutor* model_executor,
    const tflite::Interpreter* interpreter, int output) {
  if (CanReadOutput(model_executor, interpreter, output)) {
    return model_executor->OutputView<float>(output, interpreter);
  }
  return TensorView<float>::Invalid();
}

// Returns the strings of model output `output` as provided by
// `model_executor->Output<std::string>`, or an empty vector when
// CanReadOutput() rejects the arguments.
std::vector<std::string> GetStringTensorForOutput(
    const TfLiteModelExecutor* model_executor,
    const tflite::Interpreter* interpreter, int output) {
  if (CanReadOutput(model_executor, interpreter, output)) {
    return model_executor->Output<std::string>(output, interpreter);
  }
  return std::vector<std::string>();
}

}  // namespace
// Factory: builds a LuaActionsSuggestions from the given snippet, conversation,
// model handles and entity-data schemas, then prepares its Lua environment via
// Initialize(). Returns the ready instance, or nullptr (after logging an error)
// when initialization fails.
std::unique_ptr<LuaActionsSuggestions>
LuaActionsSuggestions::CreateLuaActionsSuggestions(
    const std::string& snippet, const Conversation& conversation,
    const TfLiteModelExecutor* model_executor,
    const TensorflowLiteModelSpec* model_spec,
    const tflite::Interpreter* interpreter,
    const reflection::Schema* actions_entity_data_schema,
    const reflection::Schema* annotations_entity_data_schema) {
  std::unique_ptr<LuaActionsSuggestions> suggestions(new LuaActionsSuggestions(
      snippet, conversation, model_executor, model_spec, interpreter,
      actions_entity_data_schema, annotations_entity_data_schema));
  if (suggestions->Initialize()) {
    return suggestions;
  }
  TC3_LOG(ERROR)
      << "Could not initialize lua environment for actions suggestions.";
  return nullptr;
}
LuaActionsSuggestions::LuaActionsSuggestions(
const std::string& snippet, const Conversation& conversation,
const TfLiteModelExecutor* model_executor,
const TensorflowLiteModelSpec* model_spec,
const tflite::Interpreter* interpreter,
const reflection::Schema* actions_entity_data_schema,
const reflection::Schema* annotations_entity_data_schema)
: snippet_(snippet),
conversation_(conversation),
actions_scores_(
model_spec == nullptr
? TensorView<float>::Invalid()
: GetTensorViewForOutput(model_executor, interpreter,
model_spec->output_actions_scores())),
smart_reply_scores_(
model_spec == nullptr
? TensorView<float>::Invalid()
: GetTensorViewForOutput(model_executor, interpreter,
model_spec->output_replies_scores())),
sensitivity_score_(model_spec == nullptr
? TensorView<float>::Invalid()
: GetTensorViewForOutput(
model_executor, interpreter,
model_spec->output_sensitive_topic_score())),
triggering_score_(
model_spec == nullptr
? TensorView<float>::Invalid()
: GetTensorViewForOutput(model_executor, interpreter,
model_spec->output_triggering_score())),
smart_replies_(model_spec == nullptr ? std::vector<std::string>{}
: GetStringTensorForOutput(
model_executor, interpreter,
model_spec->output_replies())),
actions_entity_data_schema_(actions_entity_data_schema),
annotations_entity_data_schema_(annotations_entity_data_schema) {}
bool LuaActionsSuggestions::Initialize() {
return RunProtected([this] {
LoadDefaultLibraries();
// Expose conversation message stream.
PushConversation(&conversation_.messages,
annotations_entity_data_schema_);
lua_setglobal(state_, "messages");
// Expose ML model output.
lua_newtable(state_);
PushTensor(&actions_scores_);
lua_setfield(state_, /*idx=*/-2, "actions_scores");
PushTensor(&smart_reply_scores_);
lua_setfield(state_, /*idx=*/-2, "reply_scores");
PushTensor(&sensitivity_score_);
lua_setfield(state_, /*idx=*/-2, "sensitivity");
PushTensor(&triggering_score_);
lua_setfield(state_, /*idx=*/-2, "triggering_score");
PushVectorIterator(&smart_replies_);
lua_setfield(state_, /*idx=*/-2, "reply");
lua_setglobal(state_, "model");
return LUA_OK;
}) == LUA_OK;
}
bool LuaActionsSuggestions::SuggestActions(
std::vector<ActionSuggestion>* actions) {
if (luaL_loadbuffer(state_, snippet_.data(), snippet_.size(),
/*name=*/nullptr) != LUA_OK) {
TC3_LOG(ERROR) << "Could not load actions suggestions snippet.";
return false;
}
if (lua_pcall(state_, /*nargs=*/0, /*nargs=*/1, /*errfunc=*/0) != LUA_OK) {
TC3_LOG(ERROR) << "Could not run actions suggestions snippet.";
return false;
}
if (RunProtected(
[this, actions] {
return ReadActions(actions_entity_data_schema_,
annotations_entity_data_schema_, actions);
},
/*num_args=*/1) != LUA_OK) {
TC3_LOG(ERROR) << "Could not read lua result.";
return false;
}
return true;
}
} // namespace libtextclassifier3