blob: edeadf9c624182737c80a9c157eb91a3e4146f81 [file] [log] [blame]
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "actions/lua-utils.h"
namespace libtextclassifier3 {
namespace {
// Keys of the Lua tables exchanged between C++ and Lua scripts. The unnamed
// namespace already gives these internal linkage, so no `static` is needed.

// Shared keys.
constexpr const char* kTextKey = "text";
constexpr const char* kNameKey = "name";
constexpr const char* kScoreKey = "score";
constexpr const char* kEntityKey = "entity";

// Classification result keys.
constexpr const char* kCollectionKey = "collection";
// NOTE: despite the "Usec" in its name, this key holds
// `datetime_parse_result.time_ms_utc`, i.e. a millisecond value.
constexpr const char* kTimeUsecKey = "parsed_time_ms_utc";
constexpr const char* kGranularityKey = "granularity";
constexpr const char* kSerializedEntity = "serialized_entity";
constexpr const char* kClassificationKey = "classification";

// Action suggestion keys.
constexpr const char* kTypeKey = "type";
constexpr const char* kResponseTextKey = "response_text";
constexpr const char* kPriorityScoreKey = "priority_score";
constexpr const char* kAnnotationKey = "annotation";

// Span keys.
constexpr const char* kSpanKey = "span";
constexpr const char* kMessageKey = "message";
constexpr const char* kBeginKey = "begin";
constexpr const char* kEndKey = "end";
}  // namespace
// Looks up the annotation whose collection equals |key| and pushes it onto
// the Lua stack via PushAnnotation. The first matching annotation wins.
// Returns 1 (the number of values pushed) on success. If no annotation
// matches, raises a Lua error; lua_error does not return.
template <>
int AnnotationIterator<ClassificationResult>::Item(
    const std::vector<ClassificationResult>* annotations, StringPiece key,
    lua_State* state) const {
  // Lookup annotation by collection.
  for (const ClassificationResult& annotation : *annotations) {
    if (key.Equals(annotation.collection)) {
      PushAnnotation(annotation, entity_data_schema_, env_);
      return 1;
    }
  }
  TC3_LOG(ERROR) << "No annotation with collection: " << key.ToString()
                 << " found.";
  // lua_error raises the value on top of the stack, so push a descriptive
  // message first. The temporary std::string dies at the end of this
  // statement, so no object with a non-trivial destructor is live when
  // lua_error unwinds (it may longjmp).
  lua_pushfstring(state, "No annotation with collection: %s found.",
                  key.ToString().c_str());
  lua_error(state);
  return 0;  // Not reached: lua_error does not return.
}
// Looks up the annotation whose name equals |key| and pushes it onto the Lua
// stack via PushAnnotation. The first matching annotation wins.
// Returns 1 (the number of values pushed) on success. If no annotation
// matches, raises a Lua error; lua_error does not return.
template <>
int AnnotationIterator<ActionSuggestionAnnotation>::Item(
    const std::vector<ActionSuggestionAnnotation>* annotations, StringPiece key,
    lua_State* state) const {
  // Lookup annotation by name.
  for (const ActionSuggestionAnnotation& annotation : *annotations) {
    if (key.Equals(annotation.name)) {
      PushAnnotation(annotation, entity_data_schema_, env_);
      return 1;
    }
  }
  TC3_LOG(ERROR) << "No annotation with name: " << key.ToString() << " found.";
  // lua_error raises the value on top of the stack, so push a descriptive
  // message first. The temporary std::string dies at the end of this
  // statement, so no object with a non-trivial destructor is live when
  // lua_error unwinds (it may longjmp).
  lua_pushfstring(state, "No annotation with name: %s found.",
                  key.ToString().c_str());
  lua_error(state);
  return 0;  // Not reached: lua_error does not return.
}
// Pushes a Lua table describing |classification| onto the stack of |env|.
// The table starts as the decoded entity data when a schema and serialized
// entity data are both available, and as an empty table otherwise. It is
// then extended with the parsed time, granularity, collection, score and the
// raw serialized entity data.
void PushAnnotation(const ClassificationResult& classification,
                    const reflection::Schema* entity_data_schema,
                    LuaEnvironment* env) {
  lua_State* const lua = env->state();
  const bool has_entity_data =
      entity_data_schema != nullptr &&
      !classification.serialized_entity_data.empty();
  if (has_entity_data) {
    const flatbuffers::Table* const entity_root =
        flatbuffers::GetRoot<flatbuffers::Table>(
            classification.serialized_entity_data.data());
    env->PushFlatbuffer(entity_data_schema, entity_root);
  } else {
    lua_newtable(lua);
  }

  // Each value pushed below is stored straight into the table at index -2.
  const auto& datetime = classification.datetime_parse_result;
  lua_pushinteger(lua, datetime.time_ms_utc);
  lua_setfield(lua, /*idx=*/-2, kTimeUsecKey);
  lua_pushinteger(lua, datetime.granularity);
  lua_setfield(lua, /*idx=*/-2, kGranularityKey);

  env->PushString(classification.collection);
  lua_setfield(lua, /*idx=*/-2, kCollectionKey);
  lua_pushnumber(lua, classification.score);
  lua_setfield(lua, /*idx=*/-2, kScoreKey);
  env->PushString(classification.serialized_entity_data);
  lua_setfield(lua, /*idx=*/-2, kSerializedEntity);
}
// Pushes the same table as PushAnnotation(classification, entity_data_schema,
// env) and additionally stores |text| in it under the "text" key.
void PushAnnotation(const ClassificationResult& classification,
                    StringPiece text,
                    const reflection::Schema* entity_data_schema,
                    LuaEnvironment* env) {
  lua_State* const lua = env->state();
  // Build the base classification table; it ends up on top of the stack.
  PushAnnotation(classification, entity_data_schema, env);
  // Attach the annotated text to it.
  env->PushString(text);
  lua_setfield(lua, /*idx=*/-2, kTextKey);
}
void PushAnnotatedSpan(
const AnnotatedSpan& annotated_span,
const AnnotationIterator<ClassificationResult>& annotation_iterator,
LuaEnvironment* env) {
lua_newtable(env->state());
{
lua_newtable(env->state());
lua_pushinteger(env->state(), annotated_span.span.first);
lua_setfield(env->state(), /*idx=*/-2, kBeginKey);
lua_pushinteger(env->state(), annotated_span.span.second);
lua_setfield(env->state(), /*idx=*/-2, kEndKey);
}
lua_setfield(env->state(), /*idx=*/-2, kSpanKey);
annotation_iterator.NewIterator(kClassificationKey,
&annotated_span.classification, env->state());
lua_setfield(env->state(), /*idx=*/-2, kClassificationKey);
}
// Reads a message text span from the Lua table on top of the stack of |env|.
// Recognized fields: "message" (message index), "begin", "end" and "text".
// Unknown fields and entries with non-string keys are logged and ignored.
// The table itself is left on the stack.
MessageTextSpan ReadSpan(LuaEnvironment* env) {
  lua_State* const state = env->state();
  MessageTextSpan span;
  lua_pushnil(state);
  while (lua_next(state, /*idx=*/-2)) {
    // Only string keys are meaningful. Calling lua_tolstring on a non-string
    // key (as ReadString presumably does) converts it in place and breaks the
    // lua_next traversal, so skip such entries.
    if (lua_type(state, /*idx=*/-2) != LUA_TSTRING) {
      TC3_LOG(INFO) << "Ignoring span field with non-string key.";
      lua_pop(state, 1);
      continue;
    }
    const StringPiece key = env->ReadString(/*index=*/-2);
    if (key.Equals(kMessageKey)) {
      span.message_index = static_cast<int>(lua_tonumber(state, /*idx=*/-1));
    } else if (key.Equals(kBeginKey)) {
      span.span.first = static_cast<int>(lua_tonumber(state, /*idx=*/-1));
    } else if (key.Equals(kEndKey)) {
      span.span.second = static_cast<int>(lua_tonumber(state, /*idx=*/-1));
    } else if (key.Equals(kTextKey)) {
      span.text = env->ReadString(/*index=*/-1).ToString();
    } else {
      TC3_LOG(INFO) << "Unknown span field: " << key.ToString();
    }
    // Pop the value, keep the key for the next lua_next call.
    lua_pop(state, 1);
  }
  return span;
}
// Reads a list of annotations from the Lua table on top of the stack of |env|
// and appends them to |annotations|, reading each entry with ReadAnnotation
// and |entity_data_schema|. Non-table entries are logged and skipped. On
// success the table is left on the stack and LUA_OK is returned. If the top
// of the stack is not a table, a Lua error is raised; lua_error does not
// return.
int ReadAnnotations(const reflection::Schema* entity_data_schema,
                    LuaEnvironment* env,
                    std::vector<ActionSuggestionAnnotation>* annotations) {
  const int table_type = lua_type(env->state(), /*idx=*/-1);
  if (table_type != LUA_TTABLE) {
    TC3_LOG(ERROR) << "Expected annotations table, got: " << table_type;
    lua_pop(env->state(), 1);
    // lua_error raises the value on top of the stack: push a message so the
    // error object is meaningful rather than whatever lies below the popped
    // value.
    lua_pushfstring(env->state(), "Expected annotations table, got: %d",
                    table_type);
    lua_error(env->state());
    return LUA_ERRRUN;  // Not reached: lua_error does not return.
  }
  // Read annotations.
  lua_pushnil(env->state());
  while (lua_next(env->state(), /*idx=*/-2)) {
    if (lua_type(env->state(), /*idx=*/-1) != LUA_TTABLE) {
      TC3_LOG(ERROR) << "Expected annotation table, got: "
                     << lua_type(env->state(), /*idx=*/-1);
      lua_pop(env->state(), 1);
      continue;
    }
    annotations->push_back(ReadAnnotation(entity_data_schema, env));
    lua_pop(env->state(), 1);
  }
  return LUA_OK;
}
// Reads an ActionSuggestionAnnotation from the Lua table on top of the stack
// of |env|. Recognized fields:
// - "name"
// - "span", a table read by ReadSpan
// - "entity", a table read by ReadClassificationResult with
//   |entity_data_schema|
// Unknown fields, non-string keys and non-table span/entity values are logged
// and ignored. The table itself is left on the stack.
ActionSuggestionAnnotation ReadAnnotation(
    const reflection::Schema* entity_data_schema, LuaEnvironment* env) {
  lua_State* const state = env->state();
  ActionSuggestionAnnotation annotation;
  lua_pushnil(state);
  while (lua_next(state, /*idx=*/-2)) {
    // Converting a non-string key in place (lua_tolstring) would break the
    // lua_next traversal, so skip such entries.
    if (lua_type(state, /*idx=*/-2) != LUA_TSTRING) {
      TC3_LOG(ERROR) << "Ignoring annotation field with non-string key.";
      lua_pop(state, 1);
      continue;
    }
    const StringPiece key = env->ReadString(/*index=*/-2);
    if (key.Equals(kNameKey)) {
      annotation.name = env->ReadString(/*index=*/-1).ToString();
    } else if (key.Equals(kSpanKey)) {
      // ReadSpan iterates the value with lua_next, which requires a table.
      if (lua_type(state, /*idx=*/-1) == LUA_TTABLE) {
        annotation.span = ReadSpan(env);
      } else {
        TC3_LOG(ERROR) << "Expected span table, got: "
                       << lua_type(state, /*idx=*/-1);
      }
    } else if (key.Equals(kEntityKey)) {
      // ReadClassificationResult iterates the value with lua_next too.
      if (lua_type(state, /*idx=*/-1) == LUA_TTABLE) {
        annotation.entity = ReadClassificationResult(entity_data_schema, env);
      } else {
        TC3_LOG(ERROR) << "Expected entity table, got: "
                       << lua_type(state, /*idx=*/-1);
      }
    } else {
      TC3_LOG(ERROR) << "Unknown annotation field: " << key.ToString();
    }
    // Pop the value, keep the key for the next lua_next call.
    lua_pop(state, 1);
  }
  return annotation;
}
// Reads a ClassificationResult from the Lua table on top of the stack of
// |env|. Recognized fields:
// - "collection"
// - "score"
// - "parsed_time_ms_utc"
// - "granularity"
// - "serialized_entity", the raw serialized entity data
// - "entity", a table converted to serialized entity data using
//   |entity_data_schema|
// Both "serialized_entity" and "entity" write serialized_entity_data, so if
// both are present, whichever lua_next visits last wins. Unknown fields and
// non-string keys are logged and ignored. The table itself is left on the
// stack.
ClassificationResult ReadClassificationResult(
    const reflection::Schema* entity_data_schema, LuaEnvironment* env) {
  lua_State* const state = env->state();
  ClassificationResult classification;
  lua_pushnil(state);
  while (lua_next(state, /*idx=*/-2)) {
    // Converting a non-string key in place (lua_tolstring) would break the
    // lua_next traversal, so skip such entries.
    if (lua_type(state, /*idx=*/-2) != LUA_TSTRING) {
      TC3_LOG(INFO)
          << "Ignoring classification result field with non-string key.";
      lua_pop(state, 1);
      continue;
    }
    const StringPiece key = env->ReadString(/*index=*/-2);
    if (key.Equals(kCollectionKey)) {
      classification.collection = env->ReadString(/*index=*/-1).ToString();
    } else if (key.Equals(kScoreKey)) {
      classification.score =
          static_cast<float>(lua_tonumber(state, /*idx=*/-1));
    } else if (key.Equals(kTimeUsecKey)) {
      classification.datetime_parse_result.time_ms_utc =
          static_cast<int64>(lua_tonumber(state, /*idx=*/-1));
    } else if (key.Equals(kGranularityKey)) {
      classification.datetime_parse_result.granularity =
          static_cast<DatetimeGranularity>(lua_tonumber(state, /*idx=*/-1));
    } else if (key.Equals(kSerializedEntity)) {
      classification.serialized_entity_data =
          env->ReadString(/*index=*/-1).ToString();
    } else if (key.Equals(kEntityKey)) {
      // Guard against a missing schema or root buffer so that the problem is
      // logged instead of a null pointer being dereferenced below.
      if (entity_data_schema == nullptr) {
        TC3_LOG(ERROR) << "Cannot read entity data without a schema.";
      } else {
        auto buffer = ReflectiveFlatbufferBuilder(entity_data_schema).NewRoot();
        if (buffer == nullptr) {
          TC3_LOG(ERROR) << "Could not create entity data buffer.";
        } else {
          env->ReadFlatbuffer(buffer.get());
          classification.serialized_entity_data = buffer->Serialize();
        }
      }
    } else {
      TC3_LOG(INFO) << "Unknown classification result field: "
                    << key.ToString();
    }
    // Pop the value, keep the key for the next lua_next call.
    lua_pop(state, 1);
  }
  return classification;
}
// Pushes a Lua table describing an action suggestion annotation onto the
// stack of |env|. The table is built by the (classification, text) overload
// of PushAnnotation from the annotation's entity and span text. It is then
// extended with "name" and a nested "span" table {message, begin, end}.
void PushAnnotation(const ActionSuggestionAnnotation& annotation,
                    const reflection::Schema* entity_data_schema,
                    LuaEnvironment* env) {
  lua_State* const lua = env->state();
  const auto& text_span = annotation.span;

  // Base table: the entity classification plus the span's text.
  PushAnnotation(annotation.entity, text_span.text, entity_data_schema, env);
  env->PushString(annotation.name);
  lua_setfield(lua, /*idx=*/-2, kNameKey);

  // Nested span table.
  lua_newtable(lua);
  lua_pushinteger(lua, text_span.message_index);
  lua_setfield(lua, /*idx=*/-2, kMessageKey);
  lua_pushinteger(lua, text_span.span.first);
  lua_setfield(lua, /*idx=*/-2, kBeginKey);
  lua_pushinteger(lua, text_span.span.second);
  lua_setfield(lua, /*idx=*/-2, kEndKey);
  lua_setfield(lua, /*idx=*/-2, kSpanKey);
}
// Pushes a Lua table describing |action| onto the stack of |env|. The table
// starts as the decoded entity data when a schema and serialized entity data
// are both available, and as an empty table otherwise. It is then extended
// with type, response text, score and priority score. Its "annotation" field
// holds an iterator over the action's annotations.
void PushAction(
    const ActionSuggestion& action,
    const reflection::Schema* entity_data_schema,
    const AnnotationIterator<ActionSuggestionAnnotation>& annotation_iterator,
    LuaEnvironment* env) {
  lua_State* const lua = env->state();
  if (entity_data_schema != nullptr && !action.serialized_entity_data.empty()) {
    const flatbuffers::Table* const entity_root =
        flatbuffers::GetRoot<flatbuffers::Table>(
            action.serialized_entity_data.data());
    env->PushFlatbuffer(entity_data_schema, entity_root);
  } else {
    lua_newtable(lua);
  }

  // Each value pushed below is stored straight into the table at index -2.
  env->PushString(action.type);
  lua_setfield(lua, /*idx=*/-2, kTypeKey);
  env->PushString(action.response_text);
  lua_setfield(lua, /*idx=*/-2, kResponseTextKey);
  lua_pushnumber(lua, action.score);
  lua_setfield(lua, /*idx=*/-2, kScoreKey);
  lua_pushnumber(lua, action.priority_score);
  lua_setfield(lua, /*idx=*/-2, kPriorityScoreKey);

  annotation_iterator.NewIterator(kAnnotationKey, &action.annotations, lua);
  lua_setfield(lua, /*idx=*/-2, kAnnotationKey);
}
// Reads an ActionSuggestion from the Lua table on top of the stack of |env|.
// Recognized fields:
// - "response_text"
// - "type"
// - "score"
// - "priority_score"
// - "annotation", read by ReadAnnotations with
//   |annotations_entity_data_schema|
// - "entity", a table converted to serialized entity data using
//   |actions_entity_data_schema|
// Unknown fields and non-string keys are logged and ignored. The table itself
// is left on the stack.
ActionSuggestion ReadAction(
    const reflection::Schema* actions_entity_data_schema,
    const reflection::Schema* annotations_entity_data_schema,
    LuaEnvironment* env) {
  lua_State* const state = env->state();
  ActionSuggestion action;
  lua_pushnil(state);
  while (lua_next(state, /*idx=*/-2)) {
    // Converting a non-string key in place (lua_tolstring) would break the
    // lua_next traversal, so skip such entries.
    if (lua_type(state, /*idx=*/-2) != LUA_TSTRING) {
      TC3_LOG(INFO) << "Ignoring action field with non-string key.";
      lua_pop(state, 1);
      continue;
    }
    const StringPiece key = env->ReadString(/*index=*/-2);
    if (key.Equals(kResponseTextKey)) {
      action.response_text = env->ReadString(/*index=*/-1).ToString();
    } else if (key.Equals(kTypeKey)) {
      action.type = env->ReadString(/*index=*/-1).ToString();
    } else if (key.Equals(kScoreKey)) {
      action.score = static_cast<float>(lua_tonumber(state, /*idx=*/-1));
    } else if (key.Equals(kPriorityScoreKey)) {
      action.priority_score =
          static_cast<float>(lua_tonumber(state, /*idx=*/-1));
    } else if (key.Equals(kAnnotationKey)) {
      // Annotations hold ClassificationResult entities, so their entity data
      // is described by the annotations schema, not the actions schema.
      ReadAnnotations(annotations_entity_data_schema, env, &action.annotations);
    } else if (key.Equals(kEntityKey)) {
      // Guard against a missing schema or root buffer so that the problem is
      // logged instead of a null pointer being dereferenced below.
      if (actions_entity_data_schema == nullptr) {
        TC3_LOG(ERROR) << "Cannot read action entity data without a schema.";
      } else {
        auto buffer =
            ReflectiveFlatbufferBuilder(actions_entity_data_schema).NewRoot();
        if (buffer == nullptr) {
          TC3_LOG(ERROR) << "Could not create action entity data buffer.";
        } else {
          env->ReadFlatbuffer(buffer.get());
          action.serialized_entity_data = buffer->Serialize();
        }
      }
    } else {
      TC3_LOG(INFO) << "Unknown action field: " << key.ToString();
    }
    // Pop the value, keep the key for the next lua_next call.
    lua_pop(state, 1);
  }
  return action;
}
// Reads a list of actions from the Lua table on top of the stack of |env| and
// appends them to |actions|, reading each entry with ReadAction. Both schemas
// are forwarded to ReadAction. Non-table entries are logged and skipped.
// On success the actions table is popped from the stack and LUA_OK is
// returned. If the top of the stack is not a table, a Lua error is raised;
// lua_error does not return.
int ReadActions(const reflection::Schema* actions_entity_data_schema,
                const reflection::Schema* annotations_entity_data_schema,
                LuaEnvironment* env, std::vector<ActionSuggestion>* actions) {
  const int table_type = lua_type(env->state(), /*idx=*/-1);
  if (table_type != LUA_TTABLE) {
    TC3_LOG(ERROR) << "Expected actions table, got: " << table_type;
    lua_pop(env->state(), 1);
    // lua_error raises the value on top of the stack: push a message so the
    // error object is meaningful rather than whatever lies below the popped
    // value.
    lua_pushfstring(env->state(), "Expected actions table, got: %d",
                    table_type);
    lua_error(env->state());
    return LUA_ERRRUN;  // Not reached: lua_error does not return.
  }
  // Read actions.
  lua_pushnil(env->state());
  while (lua_next(env->state(), /*idx=*/-2)) {
    if (lua_type(env->state(), /*idx=*/-1) != LUA_TTABLE) {
      TC3_LOG(ERROR) << "Expected action table, got: "
                     << lua_type(env->state(), /*idx=*/-1);
      lua_pop(env->state(), 1);
      continue;
    }
    actions->push_back(ReadAction(actions_entity_data_schema,
                                  annotations_entity_data_schema, env));
    lua_pop(env->state(), /*n=*/1);
  }
  // Pop the actions table itself.
  lua_pop(env->state(), /*n=*/1);
  return LUA_OK;
}
int ConversationIterator::Item(const std::vector<ConversationMessage>* messages,
const int64 pos, lua_State* state) const {
const ConversationMessage& message = (*messages)[pos];
lua_newtable(state);
lua_pushinteger(state, message.user_id);
lua_setfield(state, /*idx=*/-2, "user_id");
env_->PushString(message.text);
lua_setfield(state, /*idx=*/-2, "text");
lua_pushinteger(state, message.reference_time_ms_utc);
lua_setfield(state, /*idx=*/-2, "time_ms_utc");
env_->PushString(message.reference_timezone);
lua_setfield(state, /*idx=*/-2, "timezone");
annotated_span_iterator_.NewIterator("annotation", &message.annotations,
state);
lua_setfield(state, /*idx=*/-2, "annotation");
return 1;
}
} // namespace libtextclassifier3