blob: 73032d01ee52b99b50d5fcc141426e7d8b26dc56 [file] [log] [blame] [edit]
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "actions/lua-ranker.h"
#include "utils/base/logging.h"
#include "utils/lua-utils.h"
#ifdef __cplusplus
extern "C" {
#endif
#include "lauxlib.h"
#include "lualib.h"
#ifdef __cplusplus
}
#endif
namespace libtextclassifier3 {
std::unique_ptr<ActionsSuggestionsLuaRanker>
ActionsSuggestionsLuaRanker::Create(
const Conversation& conversation, const std::string& ranker_code,
const reflection::Schema* entity_data_schema,
const reflection::Schema* annotations_entity_data_schema,
ActionsSuggestionsResponse* response) {
auto ranker = std::unique_ptr<ActionsSuggestionsLuaRanker>(
new ActionsSuggestionsLuaRanker(
conversation, ranker_code, entity_data_schema,
annotations_entity_data_schema, response));
if (!ranker->Initialize()) {
TC3_LOG(ERROR) << "Could not initialize lua environment for ranker.";
return nullptr;
}
return ranker;
}
bool ActionsSuggestionsLuaRanker::Initialize() {
return RunProtected([this] {
LoadDefaultLibraries();
// Expose generated actions.
PushActions(&response_->actions, actions_entity_data_schema_,
annotations_entity_data_schema_);
lua_setglobal(state_, "actions");
// Expose conversation message stream.
PushConversation(&conversation_.messages,
annotations_entity_data_schema_);
lua_setglobal(state_, "messages");
return LUA_OK;
}) == LUA_OK;
}
int ActionsSuggestionsLuaRanker::ReadActionsRanking() {
if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
TC3_LOG(ERROR) << "Expected actions table, got: "
<< lua_type(state_, /*idx=*/-1);
lua_pop(state_, 1);
lua_error(state_);
return LUA_ERRRUN;
}
std::vector<ActionSuggestion> ranked_actions;
lua_pushnil(state_);
while (Next(/*index=*/-2)) {
const int action_id = Read<int>(/*index=*/-1) - 1;
lua_pop(state_, 1);
if (action_id < 0 || action_id >= response_->actions.size()) {
TC3_LOG(ERROR) << "Invalid action index: " << action_id;
lua_error(state_);
return LUA_ERRRUN;
}
ranked_actions.push_back(response_->actions[action_id]);
}
lua_pop(state_, 1);
response_->actions = ranked_actions;
return LUA_OK;
}
bool ActionsSuggestionsLuaRanker::RankActions() {
if (response_->actions.empty()) {
// Nothing to do.
return true;
}
if (luaL_loadbuffer(state_, ranker_code_.data(), ranker_code_.size(),
/*name=*/nullptr) != LUA_OK) {
TC3_LOG(ERROR) << "Could not load compiled ranking snippet.";
return false;
}
if (lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0) != LUA_OK) {
TC3_LOG(ERROR) << "Could not run ranking snippet.";
return false;
}
if (RunProtected([this] { return ReadActionsRanking(); }, /*num_args=*/1) !=
LUA_OK) {
TC3_LOG(ERROR) << "Could not read lua result.";
return false;
}
return true;
}
} // namespace libtextclassifier3