blob: 9117c54068c76d7bac2107594a61cd5e066fd03c [file] [log] [blame]
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "utils/lua-utils.h"
namespace libtextclassifier3 {
namespace {
// Standard Lua libraries opened by LuaEnvironment::LoadDefaultLibraries().
// The list ends with a {nullptr, nullptr} sentinel; the loader stops at the
// first entry whose function is null.
constexpr luaL_Reg defaultlibs[] = {
    {"_G", luaopen_base},
    {LUA_TABLIBNAME, luaopen_table},
    {LUA_STRLIBNAME, luaopen_string},
    {LUA_MATHLIBNAME, luaopen_math},
    {nullptr, nullptr},
};
// Field names used in the Lua tables exchanged with scripts.
constexpr char kTextKey[] = "text";
// NOTE: despite the "Usec" in the constant's name, this key carries the
// millisecond UTC timestamp of a datetime parse result.
constexpr char kTimeUsecKey[] = "parsed_time_ms_utc";
constexpr char kGranularityKey[] = "granularity";
constexpr char kCollectionKey[] = "collection";
constexpr char kNameKey[] = "name";
constexpr char kScoreKey[] = "score";
constexpr char kPriorityScoreKey[] = "priority_score";
constexpr char kTypeKey[] = "type";
constexpr char kResponseTextKey[] = "response_text";
constexpr char kAnnotationKey[] = "annotation";
constexpr char kSpanKey[] = "span";
constexpr char kMessageKey[] = "message";
constexpr char kBeginKey[] = "begin";
constexpr char kEndKey[] = "end";
constexpr char kClassificationKey[] = "classification";
constexpr char kSerializedEntity[] = "serialized_entity";
constexpr char kEntityKey[] = "entity";
// Implementation of a lua_Writer that appends the data to a string.
int LuaStringWriter(lua_State* state, const void* data, size_t size,
void* result) {
std::string* const result_string = static_cast<std::string*>(result);
result_string->insert(result_string->size(), static_cast<const char*>(data),
size);
return LUA_OK;
}
} // namespace
// Creates a fresh Lua state. luaL_newstate returns nullptr when the state
// cannot be allocated; the destructor tolerates that case.
LuaEnvironment::LuaEnvironment() : state_(luaL_newstate()) {}
// Releases the Lua state, if one was created.
LuaEnvironment::~LuaEnvironment() {
  if (state_ == nullptr) {
    return;
  }
  lua_close(state_);
}
// Pushes the flatbuffer `table` (an instance of object `type` from `schema`)
// onto the Lua stack via PushLazyObject. The getter handed to PushLazyObject
// calls GetField with this same schema, type and table; GetField resolves
// the field name found on top of the Lua stack when it is invoked.
// NOTE(review): only raw pointers are bound here, so the schema and the
// underlying buffer presumably must outlive the pushed Lua value — confirm
// against callers.
void LuaEnvironment::PushFlatbuffer(const reflection::Schema* schema,
                                    const reflection::Object* type,
                                    const flatbuffers::Table* table) const {
  PushLazyObject(
      std::bind(&LuaEnvironment::GetField, this, schema, type, table));
}
// Getter used by PushFlatbuffer: treats the value on top of the Lua stack as
// a field name of `type` and pushes the value that `table` holds for it.
//
// Scalars and strings are pushed directly. Sub-messages are pushed as nested
// flatbuffer objects. Vectors are pushed via PushRepeatedField /
// PushRepeatedFlatbufferField, and an unset vector becomes an empty vector.
//
// Returns the number of values pushed: 1, or 0 when a sub-message field is
// unset. Keys that cannot name a field, unknown fields and unsupported types
// raise a Lua error via lua_error. lua_error does not return, so the
// `return 0` statements after it only satisfy the compiler.
int LuaEnvironment::GetField(const reflection::Schema* schema,
                             const reflection::Object* type,
                             const flatbuffers::Table* table) const {
  // lua_tostring yields nullptr for keys that are neither strings nor numbers
  // (booleans, tables, ...). Such a key cannot name a field, and passing
  // nullptr on to LookupByKey would dereference it.
  const char* field_name = lua_tostring(state_, /*idx=*/kIndexStackTop);
  if (field_name == nullptr) {
    lua_error(state_);
    return 0;
  }
  const reflection::Field* field = type->fields()->LookupByKey(field_name);
  if (field == nullptr) {
    lua_error(state_);
    return 0;
  }
  // Provide primitive fields directly.
  const reflection::BaseType field_type = field->type()->base_type();
  switch (field_type) {
    case reflection::Bool:
      Push(table->GetField<bool>(field->offset(), field->default_integer()));
      break;
    case reflection::UByte:
      Push(table->GetField<uint8>(field->offset(), field->default_integer()));
      break;
    case reflection::Byte:
      Push(table->GetField<int8>(field->offset(), field->default_integer()));
      break;
    case reflection::Int:
      Push(table->GetField<int32>(field->offset(), field->default_integer()));
      break;
    case reflection::UInt:
      Push(table->GetField<uint32>(field->offset(), field->default_integer()));
      break;
    case reflection::Long:
      Push(table->GetField<int64>(field->offset(), field->default_integer()));
      break;
    case reflection::ULong:
      Push(table->GetField<uint64>(field->offset(), field->default_integer()));
      break;
    case reflection::Float:
      Push(table->GetField<float>(field->offset(), field->default_real()));
      break;
    case reflection::Double:
      Push(table->GetField<double>(field->offset(), field->default_real()));
      break;
    case reflection::String: {
      Push(table->GetPointer<const flatbuffers::String*>(field->offset()));
      break;
    }
    case reflection::Obj: {
      const flatbuffers::Table* field_table =
          table->GetPointer<const flatbuffers::Table*>(field->offset());
      if (field_table == nullptr) {
        // Field was not set in entity data.
        return 0;
      }
      // Schema object describing the sub-message; named distinctly from the
      // `field_type` base type above.
      const reflection::Object* field_object =
          schema->objects()->Get(field->type()->index());
      PushFlatbuffer(schema, field_object, field_table);
      break;
    }
    case reflection::Vector: {
      const flatbuffers::Vector<flatbuffers::Offset<void>>* field_vector =
          table->GetPointer<
              const flatbuffers::Vector<flatbuffers::Offset<void>>*>(
              field->offset());
      if (field_vector == nullptr) {
        // Repeated field was not set in flatbuffer.
        PushEmptyVector();
        break;
      }
      switch (field->type()->element()) {
        case reflection::Bool:
          PushRepeatedField(table->GetPointer<const flatbuffers::Vector<bool>*>(
              field->offset()));
          break;
        case reflection::UByte:
          PushRepeatedField(
              table->GetPointer<const flatbuffers::Vector<uint8>*>(
                  field->offset()));
          break;
        case reflection::Byte:
          PushRepeatedField(table->GetPointer<const flatbuffers::Vector<int8>*>(
              field->offset()));
          break;
        case reflection::Int:
          PushRepeatedField(
              table->GetPointer<const flatbuffers::Vector<int32>*>(
                  field->offset()));
          break;
        case reflection::UInt:
          PushRepeatedField(
              table->GetPointer<const flatbuffers::Vector<uint32>*>(
                  field->offset()));
          break;
        case reflection::Long:
          PushRepeatedField(
              table->GetPointer<const flatbuffers::Vector<int64>*>(
                  field->offset()));
          break;
        case reflection::ULong:
          PushRepeatedField(
              table->GetPointer<const flatbuffers::Vector<uint64>*>(
                  field->offset()));
          break;
        case reflection::Float:
          PushRepeatedField(
              table->GetPointer<const flatbuffers::Vector<float>*>(
                  field->offset()));
          break;
        case reflection::Double:
          PushRepeatedField(
              table->GetPointer<const flatbuffers::Vector<double>*>(
                  field->offset()));
          break;
        case reflection::String:
          PushRepeatedField(
              table->GetPointer<const flatbuffers::Vector<
                  flatbuffers::Offset<flatbuffers::String>>*>(field->offset()));
          break;
        case reflection::Obj:
          PushRepeatedFlatbufferField(
              schema, schema->objects()->Get(field->type()->index()),
              table->GetPointer<const flatbuffers::Vector<
                  flatbuffers::Offset<flatbuffers::Table>>*>(field->offset()));
          break;
        default:
          TC3_LOG(ERROR) << "Unsupported repeated type: "
                         << field->type()->element();
          lua_error(state_);
          return 0;
      }
      break;
    }
    default:
      TC3_LOG(ERROR) << "Unsupported type: " << field_type;
      lua_error(state_);
      return 0;
  }
  return 1;
}
int LuaEnvironment::ReadFlatbuffer(const int index,
MutableFlatbuffer* buffer) const {
if (buffer == nullptr) {
TC3_LOG(ERROR) << "Called ReadFlatbuffer with null buffer: " << index;
lua_error(state_);
return LUA_ERRRUN;
}
if (lua_type(state_, /*idx=*/index) != LUA_TTABLE) {
TC3_LOG(ERROR) << "Expected table, got: "
<< lua_type(state_, /*idx=*/kIndexStackTop);
lua_error(state_);
return LUA_ERRRUN;
}
lua_pushnil(state_);
while (Next(index - 1)) {
const StringPiece key = ReadString(/*index=*/index - 1);
const reflection::Field* field = buffer->GetFieldOrNull(key);
if (field == nullptr) {
TC3_LOG(ERROR) << "Unknown field: " << key;
lua_error(state_);
return LUA_ERRRUN;
}
switch (field->type()->base_type()) {
case reflection::Obj:
ReadFlatbuffer(/*index=*/kIndexStackTop, buffer->Mutable(field));
break;
case reflection::Bool:
buffer->Set(field, Read<bool>(/*index=*/kIndexStackTop));
break;
case reflection::Byte:
buffer->Set(field, Read<int8>(/*index=*/kIndexStackTop));
break;
case reflection::UByte:
buffer->Set(field, Read<uint8>(/*index=*/kIndexStackTop));
break;
case reflection::Int:
buffer->Set(field, Read<int32>(/*index=*/kIndexStackTop));
break;
case reflection::UInt:
buffer->Set(field, Read<uint32>(/*index=*/kIndexStackTop));
break;
case reflection::Long:
buffer->Set(field, Read<int64>(/*index=*/kIndexStackTop));
break;
case reflection::ULong:
buffer->Set(field, Read<uint64>(/*index=*/kIndexStackTop));
break;
case reflection::Float:
buffer->Set(field, Read<float>(/*index=*/kIndexStackTop));
break;
case reflection::Double:
buffer->Set(field, Read<double>(/*index=*/kIndexStackTop));
break;
case reflection::String: {
buffer->Set(field, ReadString(/*index=*/kIndexStackTop));
break;
}
case reflection::Vector: {
// Read repeated field.
switch (field->type()->element()) {
case reflection::Bool:
ReadRepeatedField<bool>(/*index=*/kIndexStackTop,
buffer->Repeated(field));
break;
case reflection::Byte:
ReadRepeatedField<int8>(/*index=*/kIndexStackTop,
buffer->Repeated(field));
break;
case reflection::UByte:
ReadRepeatedField<uint8>(/*index=*/kIndexStackTop,
buffer->Repeated(field));
break;
case reflection::Int:
ReadRepeatedField<int32>(/*index=*/kIndexStackTop,
buffer->Repeated(field));
break;
case reflection::UInt:
ReadRepeatedField<uint32>(/*index=*/kIndexStackTop,
buffer->Repeated(field));
break;
case reflection::Long:
ReadRepeatedField<int64>(/*index=*/kIndexStackTop,
buffer->Repeated(field));
break;
case reflection::ULong:
ReadRepeatedField<uint64>(/*index=*/kIndexStackTop,
buffer->Repeated(field));
break;
case reflection::Float:
ReadRepeatedField<float>(/*index=*/kIndexStackTop,
buffer->Repeated(field));
break;
case reflection::Double:
ReadRepeatedField<double>(/*index=*/kIndexStackTop,
buffer->Repeated(field));
break;
case reflection::String:
ReadRepeatedField<std::string>(/*index=*/kIndexStackTop,
buffer->Repeated(field));
break;
case reflection::Obj:
ReadRepeatedField<MutableFlatbuffer>(/*index=*/kIndexStackTop,
buffer->Repeated(field));
break;
default:
TC3_LOG(ERROR) << "Unsupported repeated field type: "
<< field->type()->element();
lua_error(state_);
return LUA_ERRRUN;
}
break;
}
default:
TC3_LOG(ERROR) << "Unsupported type: " << field->type()->base_type();
lua_error(state_);
return LUA_ERRRUN;
}
lua_pop(state_, 1);
}
return LUA_OK;
}
void LuaEnvironment::LoadDefaultLibraries() {
for (const luaL_Reg* lib = defaultlibs; lib->func; lib++) {
luaL_requiref(state_, lib->name, lib->func, 1);
lua_pop(state_, 1); // Remove lib.
}
}
// Returns a view of the string value at stack position `index`, including its
// exact byte length, so embedded NUL characters are preserved. Like
// lua_tolstring, this converts a number at `index` to a string in place.
// NOTE(review): the view's data is nullptr when the value is neither a string
// nor a number. The view is only valid while the value stays on the stack.
StringPiece LuaEnvironment::ReadString(const int index) const {
  size_t byte_count = 0;
  const char* const bytes = lua_tolstring(state_, index, &byte_count);
  return StringPiece(bytes, byte_count);
}
// Pushes a copy of `str` onto the Lua stack. The explicit length means the
// data need not be NUL-terminated and may contain embedded NULs.
void LuaEnvironment::PushString(const StringPiece str) const {
  const char* const bytes = str.data();
  const size_t byte_count = str.size();
  lua_pushlstring(state_, bytes, byte_count);
}
// Compiles the Lua source `snippet` and appends its stripped bytecode to
// `*bytecode`. Returns false (after logging) if the snippet does not compile
// or the compiled chunk cannot be dumped. The stack is restored either way.
bool LuaEnvironment::Compile(StringPiece snippet, std::string* bytecode) const {
  // On success the compiled chunk is left on top of the stack; on failure the
  // error message is.
  const int load_status = luaL_loadbuffer(state_, snippet.data(),
                                          snippet.size(), /*name=*/nullptr);
  if (load_status != LUA_OK) {
    TC3_LOG(ERROR) << "Could not compile lua snippet: "
                   << ReadString(/*index=*/kIndexStackTop);
    lua_pop(state_, 1);  // Discard the error message.
    return false;
  }
  // Serialize the chunk without debug information, then drop the function
  // regardless of the outcome.
  const bool dumped =
      lua_dump(state_, LuaStringWriter, bytecode, /*strip*/ 1) == LUA_OK;
  lua_pop(state_, 1);
  if (!dumped) {
    TC3_LOG(ERROR) << "Could not dump compiled lua snippet.";
    return false;
  }
  return true;
}
// Pushes a table describing `classification`.
//
// The table starts from the entity data, pushed via PushFlatbuffer, when both
// a schema and serialized entity data are available. Otherwise it starts from
// an empty table. It is then extended with the datetime fields, collection,
// score and raw serialized entity data.
void LuaEnvironment::PushAnnotation(
    const ClassificationResult& classification,
    const reflection::Schema* entity_data_schema) const {
  const bool has_entity_data = entity_data_schema != nullptr &&
                               !classification.serialized_entity_data.empty();
  if (has_entity_data) {
    PushFlatbuffer(entity_data_schema,
                   flatbuffers::GetRoot<flatbuffers::Table>(
                       classification.serialized_entity_data.data()));
  } else {
    // Empty table.
    lua_newtable(state_);
  }
  // Position of the annotation table once a value is pushed above it.
  const int table_slot = kIndexStackTop - 1;
  Push(classification.datetime_parse_result.time_ms_utc);
  lua_setfield(state_, table_slot, kTimeUsecKey);
  Push(classification.datetime_parse_result.granularity);
  lua_setfield(state_, table_slot, kGranularityKey);
  Push(classification.collection);
  lua_setfield(state_, table_slot, kCollectionKey);
  Push(classification.score);
  lua_setfield(state_, table_slot, kScoreKey);
  Push(classification.serialized_entity_data);
  lua_setfield(state_, table_slot, kSerializedEntity);
}
// Pushes the same table as PushAnnotation(classification, schema) and
// additionally stores `text` under the "text" key.
void LuaEnvironment::PushAnnotation(
    const ClassificationResult& classification, StringPiece text,
    const reflection::Schema* entity_data_schema) const {
  PushAnnotation(classification, entity_data_schema);
  // Position of the annotation table once the text is pushed above it.
  const int table_slot = kIndexStackTop - 1;
  Push(text);
  lua_setfield(state_, table_slot, kTextKey);
}
// Pushes a table for an action suggestion annotation. It contains the
// classification fields of its entity, the span text, the annotation name and
// a nested span table with the message index and begin/end offsets.
void LuaEnvironment::PushAnnotation(
    const ActionSuggestionAnnotation& annotation,
    const reflection::Schema* entity_data_schema) const {
  // Position of the table being filled once a value is pushed above it.
  const int table_slot = kIndexStackTop - 1;
  PushAnnotation(annotation.entity, annotation.span.text, entity_data_schema);
  PushString(annotation.name);
  lua_setfield(state_, table_slot, kNameKey);

  // Nested span table: {message, begin, end}.
  lua_newtable(state_);
  Push(annotation.span.message_index);
  lua_setfield(state_, table_slot, kMessageKey);
  Push(annotation.span.span.first);
  lua_setfield(state_, table_slot, kBeginKey);
  Push(annotation.span.span.second);
  lua_setfield(state_, table_slot, kEndKey);
  // Attach the span table to the annotation table.
  lua_setfield(state_, table_slot, kSpanKey);
}
// Pushes a table for `annotated_span`: a nested {begin, end} span table plus
// the span's classifications (pushed via PushAnnotations).
void LuaEnvironment::PushAnnotatedSpan(
    const AnnotatedSpan& annotated_span,
    const reflection::Schema* entity_data_schema) const {
  // Position of the table being filled once a value is pushed above it.
  const int table_slot = kIndexStackTop - 1;
  lua_newtable(state_);  // Annotated span table.

  lua_newtable(state_);  // Nested {begin, end} table.
  Push(annotated_span.span.first);
  lua_setfield(state_, table_slot, kBeginKey);
  Push(annotated_span.span.second);
  lua_setfield(state_, table_slot, kEndKey);
  lua_setfield(state_, table_slot, kSpanKey);

  PushAnnotations(&annotated_span.classification, entity_data_schema);
  lua_setfield(state_, table_slot, kClassificationKey);
}
// Pushes `annotated_spans` through PushIterator; the callback pushes the
// table for the span at `index`. A null pointer is treated as no spans.
void LuaEnvironment::PushAnnotatedSpans(
    const std::vector<AnnotatedSpan>* annotated_spans,
    const reflection::Schema* entity_data_schema) const {
  const size_t num_spans =
      annotated_spans == nullptr ? 0 : annotated_spans->size();
  PushIterator(num_spans, [this, annotated_spans,
                           entity_data_schema](const int64 index) {
    const AnnotatedSpan& span = annotated_spans->at(index);
    PushAnnotatedSpan(span, entity_data_schema);
    return 1;  // One value pushed.
  });
}
// Reads a span table ({message, begin, end, text}) from the top of the Lua
// stack. Unknown string keys are logged and ignored, and non-string keys are
// skipped. The span table itself is left on the stack.
MessageTextSpan LuaEnvironment::ReadSpan() const {
  MessageTextSpan span;
  lua_pushnil(state_);
  while (Next(/*index=*/kIndexStackTop - 1)) {
    // Stack: ..., span table, key, value.
    // Only string keys can name a span field. Reading any other key through
    // ReadString (lua_tolstring) would convert a numeric key in place, which
    // breaks the following lua_next call.
    if (lua_type(state_, /*idx=*/kIndexStackTop - 1) != LUA_TSTRING) {
      TC3_LOG(INFO) << "Ignoring non-string span key of type: "
                    << lua_type(state_, /*idx=*/kIndexStackTop - 1);
      lua_pop(state_, 1);
      continue;
    }
    const StringPiece key = ReadString(/*index=*/kIndexStackTop - 1);
    if (key.Equals(kMessageKey)) {
      span.message_index = Read<int>(/*index=*/kIndexStackTop);
    } else if (key.Equals(kBeginKey)) {
      span.span.first = Read<int>(/*index=*/kIndexStackTop);
    } else if (key.Equals(kEndKey)) {
      span.span.second = Read<int>(/*index=*/kIndexStackTop);
    } else if (key.Equals(kTextKey)) {
      span.text = Read<std::string>(/*index=*/kIndexStackTop);
    } else {
      TC3_LOG(INFO) << "Unknown span field: " << key;
    }
    // Drop the value; the key stays for the next Next() call.
    lua_pop(state_, 1);
  }
  return span;
}
// Appends the annotations listed in the table on top of the Lua stack to
// `annotations`, decoding entity data with `entity_data_schema`. Entries that
// are not tables are logged and skipped, and the list table is left on the
// stack. Returns LUA_OK. A non-table argument is popped and raises a Lua
// error instead (lua_error does not return).
int LuaEnvironment::ReadAnnotations(
    const reflection::Schema* entity_data_schema,
    std::vector<ActionSuggestionAnnotation>* annotations) const {
  const int list_type = lua_type(state_, /*idx=*/kIndexStackTop);
  if (list_type != LUA_TTABLE) {
    TC3_LOG(ERROR) << "Expected annotations table, got: " << list_type;
    lua_pop(state_, 1);
    lua_error(state_);
    return LUA_ERRRUN;
  }
  // Read actions.
  lua_pushnil(state_);
  while (Next(/*index=*/kIndexStackTop - 1)) {
    const int entry_type = lua_type(state_, /*idx=*/kIndexStackTop);
    if (entry_type == LUA_TTABLE) {
      annotations->push_back(ReadAnnotation(entity_data_schema));
    } else {
      TC3_LOG(ERROR) << "Expected annotation table, got: " << entry_type;
    }
    // Drop the value; the key stays for the next Next() call.
    lua_pop(state_, 1);
  }
  return LUA_OK;
}
// Reads one annotation table ({name, span, entity}) from the top of the Lua
// stack. The entity is decoded with `entity_data_schema`. Unknown string keys
// are logged and ignored, and non-string keys are skipped. The annotation
// table itself is left on the stack.
ActionSuggestionAnnotation LuaEnvironment::ReadAnnotation(
    const reflection::Schema* entity_data_schema) const {
  ActionSuggestionAnnotation annotation;
  lua_pushnil(state_);
  while (Next(/*index=*/kIndexStackTop - 1)) {
    // Stack: ..., annotation table, key, value.
    // Only string keys can name a field. Reading any other key through
    // ReadString (lua_tolstring) would convert a numeric key in place, which
    // breaks the following lua_next call.
    if (lua_type(state_, /*idx=*/kIndexStackTop - 1) != LUA_TSTRING) {
      TC3_LOG(ERROR) << "Ignoring non-string annotation key of type: "
                     << lua_type(state_, /*idx=*/kIndexStackTop - 1);
      lua_pop(state_, 1);
      continue;
    }
    const StringPiece key = ReadString(/*index=*/kIndexStackTop - 1);
    if (key.Equals(kNameKey)) {
      annotation.name = Read<std::string>(/*index=*/kIndexStackTop);
    } else if (key.Equals(kSpanKey)) {
      annotation.span = ReadSpan();
    } else if (key.Equals(kEntityKey)) {
      annotation.entity = ReadClassificationResult(entity_data_schema);
    } else {
      TC3_LOG(ERROR) << "Unknown annotation field: " << key;
    }
    // Drop the value; the key stays for the next Next() call.
    lua_pop(state_, 1);
  }
  return annotation;
}
// Reads a classification table from the top of the Lua stack.
//
// Recognized keys are collection, score, parsed_time_ms_utc, granularity,
// serialized_entity and entity. "entity" is a Lua table that is encoded into
// a flatbuffer using `entity_data_schema`. Unknown string keys are logged and
// ignored, and non-string keys are skipped. The table itself is left on the
// stack.
ClassificationResult LuaEnvironment::ReadClassificationResult(
    const reflection::Schema* entity_data_schema) const {
  ClassificationResult classification;
  lua_pushnil(state_);
  while (Next(/*index=*/kIndexStackTop - 1)) {
    // Stack: ..., classification table, key, value.
    // Only string keys can name a field. Reading any other key through
    // ReadString (lua_tolstring) would convert a numeric key in place, which
    // breaks the following lua_next call.
    if (lua_type(state_, /*idx=*/kIndexStackTop - 1) != LUA_TSTRING) {
      TC3_LOG(INFO) << "Ignoring non-string classification key of type: "
                    << lua_type(state_, /*idx=*/kIndexStackTop - 1);
      lua_pop(state_, 1);
      continue;
    }
    const StringPiece key = ReadString(/*index=*/kIndexStackTop - 1);
    if (key.Equals(kCollectionKey)) {
      classification.collection = Read<std::string>(/*index=*/kIndexStackTop);
    } else if (key.Equals(kScoreKey)) {
      classification.score = Read<float>(/*index=*/kIndexStackTop);
    } else if (key.Equals(kTimeUsecKey)) {
      classification.datetime_parse_result.time_ms_utc =
          Read<int64>(/*index=*/kIndexStackTop);
    } else if (key.Equals(kGranularityKey)) {
      classification.datetime_parse_result.granularity =
          static_cast<DatetimeGranularity>(
              lua_tonumber(state_, /*idx=*/kIndexStackTop));
    } else if (key.Equals(kSerializedEntity)) {
      classification.serialized_entity_data =
          Read<std::string>(/*index=*/kIndexStackTop);
    } else if (key.Equals(kEntityKey)) {
      auto buffer = MutableFlatbufferBuilder(entity_data_schema).NewRoot();
      ReadFlatbuffer(/*index=*/kIndexStackTop, buffer.get());
      classification.serialized_entity_data = buffer->Serialize();
    } else {
      TC3_LOG(INFO) << "Unknown classification result field: " << key;
    }
    // Drop the value; the key stays for the next Next() call.
    lua_pop(state_, 1);
  }
  return classification;
}
// Pushes a table describing `action`.
//
// The table starts from the action's entity data, pushed via PushFlatbuffer
// with `actions_entity_data_schema`, when both are available. Otherwise it
// starts from an empty table. It is then extended with the type, response
// text, scores and the annotations, which are pushed with
// `annotations_entity_data_schema`.
void LuaEnvironment::PushAction(
    const ActionSuggestion& action,
    const reflection::Schema* actions_entity_data_schema,
    const reflection::Schema* annotations_entity_data_schema) const {
  const bool has_entity_data = actions_entity_data_schema != nullptr &&
                               !action.serialized_entity_data.empty();
  if (has_entity_data) {
    PushFlatbuffer(actions_entity_data_schema,
                   flatbuffers::GetRoot<flatbuffers::Table>(
                       action.serialized_entity_data.data()));
  } else {
    // Empty table.
    lua_newtable(state_);
  }
  // Position of the action table once a value is pushed above it.
  const int table_slot = kIndexStackTop - 1;
  PushString(action.type);
  lua_setfield(state_, table_slot, kTypeKey);
  PushString(action.response_text);
  lua_setfield(state_, table_slot, kResponseTextKey);
  Push(action.score);
  lua_setfield(state_, table_slot, kScoreKey);
  Push(action.priority_score);
  lua_setfield(state_, table_slot, kPriorityScoreKey);
  PushAnnotations(&action.annotations, annotations_entity_data_schema);
  lua_setfield(state_, table_slot, kAnnotationKey);
}
// Pushes `actions` through PushIterator; the callback pushes the table for
// the action at `index` via PushAction. A null pointer is treated as no
// actions.
void LuaEnvironment::PushActions(
    const std::vector<ActionSuggestion>* actions,
    const reflection::Schema* actions_entity_data_schema,
    const reflection::Schema* annotations_entity_data_schema) const {
  const size_t num_actions = actions == nullptr ? 0 : actions->size();
  PushIterator(num_actions,
               [this, actions, actions_entity_data_schema,
                annotations_entity_data_schema](const int64 index) {
                 const ActionSuggestion& action = actions->at(index);
                 PushAction(action, actions_entity_data_schema,
                            annotations_entity_data_schema);
                 return 1;  // One value pushed.
               });
}
// Reads one action table from the top of the Lua stack.
//
// Recognized keys are response_text, type, score, priority_score, annotation
// and entity. The action's own entity data is encoded with
// `actions_entity_data_schema`; the entity data of its annotations uses
// `annotations_entity_data_schema`. Unknown string keys are logged and
// ignored, and non-string keys are skipped. The action table itself is left
// on the stack.
ActionSuggestion LuaEnvironment::ReadAction(
    const reflection::Schema* actions_entity_data_schema,
    const reflection::Schema* annotations_entity_data_schema) const {
  ActionSuggestion action;
  lua_pushnil(state_);
  while (Next(/*index=*/kIndexStackTop - 1)) {
    // Stack: ..., action table, key, value.
    // Only string keys can name a field. Reading any other key through
    // ReadString (lua_tolstring) would convert a numeric key in place, which
    // breaks the following lua_next call.
    if (lua_type(state_, /*idx=*/kIndexStackTop - 1) != LUA_TSTRING) {
      TC3_LOG(INFO) << "Ignoring non-string action key of type: "
                    << lua_type(state_, /*idx=*/kIndexStackTop - 1);
      lua_pop(state_, 1);
      continue;
    }
    const StringPiece key = ReadString(/*index=*/kIndexStackTop - 1);
    if (key.Equals(kResponseTextKey)) {
      action.response_text = Read<std::string>(/*index=*/kIndexStackTop);
    } else if (key.Equals(kTypeKey)) {
      action.type = Read<std::string>(/*index=*/kIndexStackTop);
    } else if (key.Equals(kScoreKey)) {
      action.score = Read<float>(/*index=*/kIndexStackTop);
    } else if (key.Equals(kPriorityScoreKey)) {
      action.priority_score = Read<float>(/*index=*/kIndexStackTop);
    } else if (key.Equals(kAnnotationKey)) {
      // Annotation entities use the annotations schema (as PushAction does
      // when writing them), not the action schema.
      ReadAnnotations(annotations_entity_data_schema, &action.annotations);
    } else if (key.Equals(kEntityKey)) {
      auto buffer =
          MutableFlatbufferBuilder(actions_entity_data_schema).NewRoot();
      ReadFlatbuffer(/*index=*/kIndexStackTop, buffer.get());
      action.serialized_entity_data = buffer->Serialize();
    } else {
      TC3_LOG(INFO) << "Unknown action field: " << key;
    }
    // Drop the value; the key stays for the next Next() call.
    lua_pop(state_, 1);
  }
  return action;
}
// Appends the actions listed in the table on top of the Lua stack to
// `actions`. Entries that are not tables are logged and skipped. Pops the
// list table when done and returns LUA_OK.
int LuaEnvironment::ReadActions(
    const reflection::Schema* actions_entity_data_schema,
    const reflection::Schema* annotations_entity_data_schema,
    std::vector<ActionSuggestion>* actions) const {
  // Read actions.
  lua_pushnil(state_);
  while (Next(/*index=*/kIndexStackTop - 1)) {
    const int entry_type = lua_type(state_, /*idx=*/kIndexStackTop);
    if (entry_type == LUA_TTABLE) {
      actions->push_back(ReadAction(actions_entity_data_schema,
                                    annotations_entity_data_schema));
    } else {
      TC3_LOG(ERROR) << "Expected action table, got: " << entry_type;
    }
    // Drop the value; the key stays for the next Next() call.
    lua_pop(state_, /*n=*/1);
  }
  // Finally consume the actions table itself.
  lua_pop(state_, /*n=*/1);
  return LUA_OK;
}
// Pushes `conversation` through PushIterator. The callback pushes, for the
// message at `index`, a table with its user id, text, reference time,
// timezone and annotated spans. A null pointer is treated as no messages.
void LuaEnvironment::PushConversation(
    const std::vector<ConversationMessage>* conversation,
    const reflection::Schema* annotations_entity_data_schema) const {
  const size_t num_messages =
      conversation == nullptr ? 0 : conversation->size();
  PushIterator(
      num_messages,
      [this, conversation, annotations_entity_data_schema](const int64 index) {
        const ConversationMessage& message = conversation->at(index);
        // Position of the message table once a value is pushed above it.
        const int table_slot = kIndexStackTop - 1;
        lua_newtable(state_);
        Push(message.user_id);
        lua_setfield(state_, table_slot, "user_id");
        Push(message.text);
        lua_setfield(state_, table_slot, "text");
        Push(message.reference_time_ms_utc);
        lua_setfield(state_, table_slot, "time_ms_utc");
        Push(message.reference_timezone);
        lua_setfield(state_, table_slot, "timezone");
        PushAnnotatedSpans(&message.annotations,
                           annotations_entity_data_schema);
        lua_setfield(state_, table_slot, "annotation");
        return 1;  // One value pushed.
      });
}
// Compiles `snippet` into Lua bytecode appended to `*bytecode`, using a
// throwaway LuaEnvironment that is destroyed before returning. Returns false
// if compilation or dumping fails.
bool Compile(StringPiece snippet, std::string* bytecode) {
  LuaEnvironment environment;
  return environment.Compile(snippet, bytecode);
}
} // namespace libtextclassifier3