| /* |
| * Copyright (C) 2018 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #include "utils/lua-utils.h" |
| |
| namespace libtextclassifier3 { |
| namespace { |
| static constexpr luaL_Reg defaultlibs[] = {{"_G", luaopen_base}, |
| {LUA_TABLIBNAME, luaopen_table}, |
| {LUA_STRLIBNAME, luaopen_string}, |
| {LUA_MATHLIBNAME, luaopen_math}, |
| {nullptr, nullptr}}; |
| |
// Names of the Lua table fields used when exchanging annotations, spans,
// classification results and actions with scripts.
constexpr char kTextKey[] = "text";
// NOTE: despite the "Usec" in the identifier, the value stored under this key
// is `datetime_parse_result.time_ms_utc`.
constexpr char kTimeUsecKey[] = "parsed_time_ms_utc";
constexpr char kGranularityKey[] = "granularity";
constexpr char kCollectionKey[] = "collection";
constexpr char kNameKey[] = "name";
constexpr char kScoreKey[] = "score";
constexpr char kPriorityScoreKey[] = "priority_score";
constexpr char kTypeKey[] = "type";
constexpr char kResponseTextKey[] = "response_text";
constexpr char kAnnotationKey[] = "annotation";
constexpr char kSpanKey[] = "span";
constexpr char kMessageKey[] = "message";
constexpr char kBeginKey[] = "begin";
constexpr char kEndKey[] = "end";
constexpr char kClassificationKey[] = "classification";
constexpr char kSerializedEntity[] = "serialized_entity";
constexpr char kEntityKey[] = "entity";
| |
| // Implementation of a lua_Writer that appends the data to a string. |
| int LuaStringWriter(lua_State* state, const void* data, size_t size, |
| void* result) { |
| std::string* const result_string = static_cast<std::string*>(result); |
| result_string->insert(result_string->size(), static_cast<const char*>(data), |
| size); |
| return LUA_OK; |
| } |
| |
| } // namespace |
| |
| LuaEnvironment::LuaEnvironment() { state_ = luaL_newstate(); } |
| |
| LuaEnvironment::~LuaEnvironment() { |
| if (state_ != nullptr) { |
| lua_close(state_); |
| } |
| } |
| |
| void LuaEnvironment::PushFlatbuffer(const reflection::Schema* schema, |
| const reflection::Object* type, |
| const flatbuffers::Table* table) const { |
| PushLazyObject( |
| std::bind(&LuaEnvironment::GetField, this, schema, type, table)); |
| } |
| |
| int LuaEnvironment::GetField(const reflection::Schema* schema, |
| const reflection::Object* type, |
| const flatbuffers::Table* table) const { |
| const char* field_name = lua_tostring(state_, /*idx=*/kIndexStackTop); |
| const reflection::Field* field = type->fields()->LookupByKey(field_name); |
| if (field == nullptr) { |
| lua_error(state_); |
| return 0; |
| } |
| // Provide primitive fields directly. |
| const reflection::BaseType field_type = field->type()->base_type(); |
| switch (field_type) { |
| case reflection::Bool: |
| Push(table->GetField<bool>(field->offset(), field->default_integer())); |
| break; |
| case reflection::UByte: |
| Push(table->GetField<uint8>(field->offset(), field->default_integer())); |
| break; |
| case reflection::Byte: |
| Push(table->GetField<int8>(field->offset(), field->default_integer())); |
| break; |
| case reflection::Int: |
| Push(table->GetField<int32>(field->offset(), field->default_integer())); |
| break; |
| case reflection::UInt: |
| Push(table->GetField<uint32>(field->offset(), field->default_integer())); |
| break; |
| case reflection::Long: |
| Push(table->GetField<int64>(field->offset(), field->default_integer())); |
| break; |
| case reflection::ULong: |
| Push(table->GetField<uint64>(field->offset(), field->default_integer())); |
| break; |
| case reflection::Float: |
| Push(table->GetField<float>(field->offset(), field->default_real())); |
| break; |
| case reflection::Double: |
| Push(table->GetField<double>(field->offset(), field->default_real())); |
| break; |
| case reflection::String: { |
| Push(table->GetPointer<const flatbuffers::String*>(field->offset())); |
| break; |
| } |
| case reflection::Obj: { |
| const flatbuffers::Table* field_table = |
| table->GetPointer<const flatbuffers::Table*>(field->offset()); |
| if (field_table == nullptr) { |
| // Field was not set in entity data. |
| return 0; |
| } |
| const reflection::Object* field_type = |
| schema->objects()->Get(field->type()->index()); |
| PushFlatbuffer(schema, field_type, field_table); |
| break; |
| } |
| case reflection::Vector: { |
| const flatbuffers::Vector<flatbuffers::Offset<void>>* field_vector = |
| table->GetPointer< |
| const flatbuffers::Vector<flatbuffers::Offset<void>>*>( |
| field->offset()); |
| if (field_vector == nullptr) { |
| // Repeated field was not set in flatbuffer. |
| PushEmptyVector(); |
| break; |
| } |
| switch (field->type()->element()) { |
| case reflection::Bool: |
| PushRepeatedField(table->GetPointer<const flatbuffers::Vector<bool>*>( |
| field->offset())); |
| break; |
| case reflection::UByte: |
| PushRepeatedField( |
| table->GetPointer<const flatbuffers::Vector<uint8>*>( |
| field->offset())); |
| break; |
| case reflection::Byte: |
| PushRepeatedField(table->GetPointer<const flatbuffers::Vector<int8>*>( |
| field->offset())); |
| break; |
| case reflection::Int: |
| PushRepeatedField( |
| table->GetPointer<const flatbuffers::Vector<int32>*>( |
| field->offset())); |
| break; |
| case reflection::UInt: |
| PushRepeatedField( |
| table->GetPointer<const flatbuffers::Vector<uint32>*>( |
| field->offset())); |
| break; |
| case reflection::Long: |
| PushRepeatedField( |
| table->GetPointer<const flatbuffers::Vector<int64>*>( |
| field->offset())); |
| break; |
| case reflection::ULong: |
| PushRepeatedField( |
| table->GetPointer<const flatbuffers::Vector<uint64>*>( |
| field->offset())); |
| break; |
| case reflection::Float: |
| PushRepeatedField( |
| table->GetPointer<const flatbuffers::Vector<float>*>( |
| field->offset())); |
| break; |
| case reflection::Double: |
| PushRepeatedField( |
| table->GetPointer<const flatbuffers::Vector<double>*>( |
| field->offset())); |
| break; |
| case reflection::String: |
| PushRepeatedField( |
| table->GetPointer<const flatbuffers::Vector< |
| flatbuffers::Offset<flatbuffers::String>>*>(field->offset())); |
| break; |
| case reflection::Obj: |
| PushRepeatedFlatbufferField( |
| schema, schema->objects()->Get(field->type()->index()), |
| table->GetPointer<const flatbuffers::Vector< |
| flatbuffers::Offset<flatbuffers::Table>>*>(field->offset())); |
| break; |
| default: |
| TC3_LOG(ERROR) << "Unsupported repeated type: " |
| << field->type()->element(); |
| lua_error(state_); |
| return 0; |
| } |
| break; |
| } |
| default: |
| TC3_LOG(ERROR) << "Unsupported type: " << field_type; |
| lua_error(state_); |
| return 0; |
| } |
| return 1; |
| } |
| |
| int LuaEnvironment::ReadFlatbuffer(const int index, |
| MutableFlatbuffer* buffer) const { |
| if (buffer == nullptr) { |
| TC3_LOG(ERROR) << "Called ReadFlatbuffer with null buffer: " << index; |
| lua_error(state_); |
| return LUA_ERRRUN; |
| } |
| if (lua_type(state_, /*idx=*/index) != LUA_TTABLE) { |
| TC3_LOG(ERROR) << "Expected table, got: " |
| << lua_type(state_, /*idx=*/kIndexStackTop); |
| lua_error(state_); |
| return LUA_ERRRUN; |
| } |
| |
| lua_pushnil(state_); |
| while (Next(index - 1)) { |
| const StringPiece key = ReadString(/*index=*/index - 1); |
| const reflection::Field* field = buffer->GetFieldOrNull(key); |
| if (field == nullptr) { |
| TC3_LOG(ERROR) << "Unknown field: " << key; |
| lua_error(state_); |
| return LUA_ERRRUN; |
| } |
| switch (field->type()->base_type()) { |
| case reflection::Obj: |
| ReadFlatbuffer(/*index=*/kIndexStackTop, buffer->Mutable(field)); |
| break; |
| case reflection::Bool: |
| buffer->Set(field, Read<bool>(/*index=*/kIndexStackTop)); |
| break; |
| case reflection::Byte: |
| buffer->Set(field, Read<int8>(/*index=*/kIndexStackTop)); |
| break; |
| case reflection::UByte: |
| buffer->Set(field, Read<uint8>(/*index=*/kIndexStackTop)); |
| break; |
| case reflection::Int: |
| buffer->Set(field, Read<int32>(/*index=*/kIndexStackTop)); |
| break; |
| case reflection::UInt: |
| buffer->Set(field, Read<uint32>(/*index=*/kIndexStackTop)); |
| break; |
| case reflection::Long: |
| buffer->Set(field, Read<int64>(/*index=*/kIndexStackTop)); |
| break; |
| case reflection::ULong: |
| buffer->Set(field, Read<uint64>(/*index=*/kIndexStackTop)); |
| break; |
| case reflection::Float: |
| buffer->Set(field, Read<float>(/*index=*/kIndexStackTop)); |
| break; |
| case reflection::Double: |
| buffer->Set(field, Read<double>(/*index=*/kIndexStackTop)); |
| break; |
| case reflection::String: { |
| buffer->Set(field, ReadString(/*index=*/kIndexStackTop)); |
| break; |
| } |
| case reflection::Vector: { |
| // Read repeated field. |
| switch (field->type()->element()) { |
| case reflection::Bool: |
| ReadRepeatedField<bool>(/*index=*/kIndexStackTop, |
| buffer->Repeated(field)); |
| break; |
| case reflection::Byte: |
| ReadRepeatedField<int8>(/*index=*/kIndexStackTop, |
| buffer->Repeated(field)); |
| break; |
| case reflection::UByte: |
| ReadRepeatedField<uint8>(/*index=*/kIndexStackTop, |
| buffer->Repeated(field)); |
| break; |
| case reflection::Int: |
| ReadRepeatedField<int32>(/*index=*/kIndexStackTop, |
| buffer->Repeated(field)); |
| break; |
| case reflection::UInt: |
| ReadRepeatedField<uint32>(/*index=*/kIndexStackTop, |
| buffer->Repeated(field)); |
| break; |
| case reflection::Long: |
| ReadRepeatedField<int64>(/*index=*/kIndexStackTop, |
| buffer->Repeated(field)); |
| break; |
| case reflection::ULong: |
| ReadRepeatedField<uint64>(/*index=*/kIndexStackTop, |
| buffer->Repeated(field)); |
| break; |
| case reflection::Float: |
| ReadRepeatedField<float>(/*index=*/kIndexStackTop, |
| buffer->Repeated(field)); |
| break; |
| case reflection::Double: |
| ReadRepeatedField<double>(/*index=*/kIndexStackTop, |
| buffer->Repeated(field)); |
| break; |
| case reflection::String: |
| ReadRepeatedField<std::string>(/*index=*/kIndexStackTop, |
| buffer->Repeated(field)); |
| break; |
| case reflection::Obj: |
| ReadRepeatedField<MutableFlatbuffer>(/*index=*/kIndexStackTop, |
| buffer->Repeated(field)); |
| break; |
| default: |
| TC3_LOG(ERROR) << "Unsupported repeated field type: " |
| << field->type()->element(); |
| lua_error(state_); |
| return LUA_ERRRUN; |
| } |
| break; |
| } |
| default: |
| TC3_LOG(ERROR) << "Unsupported type: " << field->type()->base_type(); |
| lua_error(state_); |
| return LUA_ERRRUN; |
| } |
| lua_pop(state_, 1); |
| } |
| return LUA_OK; |
| } |
| |
| void LuaEnvironment::LoadDefaultLibraries() { |
| for (const luaL_Reg* lib = defaultlibs; lib->func; lib++) { |
| luaL_requiref(state_, lib->name, lib->func, 1); |
| lua_pop(state_, 1); // Remove lib. |
| } |
| } |
| |
| StringPiece LuaEnvironment::ReadString(const int index) const { |
| size_t length = 0; |
| const char* data = lua_tolstring(state_, index, &length); |
| return StringPiece(data, length); |
| } |
| |
| void LuaEnvironment::PushString(const StringPiece str) const { |
| lua_pushlstring(state_, str.data(), str.size()); |
| } |
| |
| bool LuaEnvironment::Compile(StringPiece snippet, std::string* bytecode) const { |
| if (luaL_loadbuffer(state_, snippet.data(), snippet.size(), |
| /*name=*/nullptr) != LUA_OK) { |
| TC3_LOG(ERROR) << "Could not compile lua snippet: " |
| << ReadString(/*index=*/kIndexStackTop); |
| lua_pop(state_, 1); |
| return false; |
| } |
| if (lua_dump(state_, LuaStringWriter, bytecode, /*strip*/ 1) != LUA_OK) { |
| TC3_LOG(ERROR) << "Could not dump compiled lua snippet."; |
| lua_pop(state_, 1); |
| return false; |
| } |
| lua_pop(state_, 1); |
| return true; |
| } |
| |
| void LuaEnvironment::PushAnnotation( |
| const ClassificationResult& classification, |
| const reflection::Schema* entity_data_schema) const { |
| if (entity_data_schema == nullptr || |
| classification.serialized_entity_data.empty()) { |
| // Empty table. |
| lua_newtable(state_); |
| } else { |
| PushFlatbuffer(entity_data_schema, |
| flatbuffers::GetRoot<flatbuffers::Table>( |
| classification.serialized_entity_data.data())); |
| } |
| Push(classification.datetime_parse_result.time_ms_utc); |
| lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kTimeUsecKey); |
| Push(classification.datetime_parse_result.granularity); |
| lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kGranularityKey); |
| Push(classification.collection); |
| lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kCollectionKey); |
| Push(classification.score); |
| lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kScoreKey); |
| Push(classification.serialized_entity_data); |
| lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kSerializedEntity); |
| } |
| |
| void LuaEnvironment::PushAnnotation( |
| const ClassificationResult& classification, StringPiece text, |
| const reflection::Schema* entity_data_schema) const { |
| PushAnnotation(classification, entity_data_schema); |
| Push(text); |
| lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kTextKey); |
| } |
| |
| void LuaEnvironment::PushAnnotation( |
| const ActionSuggestionAnnotation& annotation, |
| const reflection::Schema* entity_data_schema) const { |
| PushAnnotation(annotation.entity, annotation.span.text, entity_data_schema); |
| PushString(annotation.name); |
| lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kNameKey); |
| { |
| lua_newtable(state_); |
| Push(annotation.span.message_index); |
| lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kMessageKey); |
| Push(annotation.span.span.first); |
| lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kBeginKey); |
| Push(annotation.span.span.second); |
| lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kEndKey); |
| } |
| lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kSpanKey); |
| } |
| |
| void LuaEnvironment::PushAnnotatedSpan( |
| const AnnotatedSpan& annotated_span, |
| const reflection::Schema* entity_data_schema) const { |
| lua_newtable(state_); |
| { |
| lua_newtable(state_); |
| Push(annotated_span.span.first); |
| lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kBeginKey); |
| Push(annotated_span.span.second); |
| lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kEndKey); |
| } |
| lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kSpanKey); |
| PushAnnotations(&annotated_span.classification, entity_data_schema); |
| lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kClassificationKey); |
| } |
| |
| void LuaEnvironment::PushAnnotatedSpans( |
| const std::vector<AnnotatedSpan>* annotated_spans, |
| const reflection::Schema* entity_data_schema) const { |
| PushIterator(annotated_spans ? annotated_spans->size() : 0, |
| [this, annotated_spans, entity_data_schema](const int64 index) { |
| PushAnnotatedSpan(annotated_spans->at(index), |
| entity_data_schema); |
| return 1; |
| }); |
| } |
| |
| MessageTextSpan LuaEnvironment::ReadSpan() const { |
| MessageTextSpan span; |
| lua_pushnil(state_); |
| while (Next(/*index=*/kIndexStackTop - 1)) { |
| const StringPiece key = ReadString(/*index=*/kIndexStackTop - 1); |
| if (key.Equals(kMessageKey)) { |
| span.message_index = Read<int>(/*index=*/kIndexStackTop); |
| } else if (key.Equals(kBeginKey)) { |
| span.span.first = Read<int>(/*index=*/kIndexStackTop); |
| } else if (key.Equals(kEndKey)) { |
| span.span.second = Read<int>(/*index=*/kIndexStackTop); |
| } else if (key.Equals(kTextKey)) { |
| span.text = Read<std::string>(/*index=*/kIndexStackTop); |
| } else { |
| TC3_LOG(INFO) << "Unknown span field: " << key; |
| } |
| lua_pop(state_, 1); |
| } |
| return span; |
| } |
| |
| int LuaEnvironment::ReadAnnotations( |
| const reflection::Schema* entity_data_schema, |
| std::vector<ActionSuggestionAnnotation>* annotations) const { |
| if (lua_type(state_, /*idx=*/kIndexStackTop) != LUA_TTABLE) { |
| TC3_LOG(ERROR) << "Expected annotations table, got: " |
| << lua_type(state_, /*idx=*/kIndexStackTop); |
| lua_pop(state_, 1); |
| lua_error(state_); |
| return LUA_ERRRUN; |
| } |
| |
| // Read actions. |
| lua_pushnil(state_); |
| while (Next(/*index=*/kIndexStackTop - 1)) { |
| if (lua_type(state_, /*idx=*/kIndexStackTop) != LUA_TTABLE) { |
| TC3_LOG(ERROR) << "Expected annotation table, got: " |
| << lua_type(state_, /*idx=*/kIndexStackTop); |
| lua_pop(state_, 1); |
| continue; |
| } |
| annotations->push_back(ReadAnnotation(entity_data_schema)); |
| lua_pop(state_, 1); |
| } |
| return LUA_OK; |
| } |
| |
| ActionSuggestionAnnotation LuaEnvironment::ReadAnnotation( |
| const reflection::Schema* entity_data_schema) const { |
| ActionSuggestionAnnotation annotation; |
| lua_pushnil(state_); |
| while (Next(/*index=*/kIndexStackTop - 1)) { |
| const StringPiece key = ReadString(/*index=*/kIndexStackTop - 1); |
| if (key.Equals(kNameKey)) { |
| annotation.name = Read<std::string>(/*index=*/kIndexStackTop); |
| } else if (key.Equals(kSpanKey)) { |
| annotation.span = ReadSpan(); |
| } else if (key.Equals(kEntityKey)) { |
| annotation.entity = ReadClassificationResult(entity_data_schema); |
| } else { |
| TC3_LOG(ERROR) << "Unknown annotation field: " << key; |
| } |
| lua_pop(state_, 1); |
| } |
| return annotation; |
| } |
| |
| ClassificationResult LuaEnvironment::ReadClassificationResult( |
| const reflection::Schema* entity_data_schema) const { |
| ClassificationResult classification; |
| lua_pushnil(state_); |
| while (Next(/*index=*/kIndexStackTop - 1)) { |
| const StringPiece key = ReadString(/*index=*/kIndexStackTop - 1); |
| if (key.Equals(kCollectionKey)) { |
| classification.collection = Read<std::string>(/*index=*/kIndexStackTop); |
| } else if (key.Equals(kScoreKey)) { |
| classification.score = Read<float>(/*index=*/kIndexStackTop); |
| } else if (key.Equals(kTimeUsecKey)) { |
| classification.datetime_parse_result.time_ms_utc = |
| Read<int64>(/*index=*/kIndexStackTop); |
| } else if (key.Equals(kGranularityKey)) { |
| classification.datetime_parse_result.granularity = |
| static_cast<DatetimeGranularity>( |
| lua_tonumber(state_, /*idx=*/kIndexStackTop)); |
| } else if (key.Equals(kSerializedEntity)) { |
| classification.serialized_entity_data = |
| Read<std::string>(/*index=*/kIndexStackTop); |
| } else if (key.Equals(kEntityKey)) { |
| auto buffer = MutableFlatbufferBuilder(entity_data_schema).NewRoot(); |
| ReadFlatbuffer(/*index=*/kIndexStackTop, buffer.get()); |
| classification.serialized_entity_data = buffer->Serialize(); |
| } else { |
| TC3_LOG(INFO) << "Unknown classification result field: " << key; |
| } |
| lua_pop(state_, 1); |
| } |
| return classification; |
| } |
| |
| void LuaEnvironment::PushAction( |
| const ActionSuggestion& action, |
| const reflection::Schema* actions_entity_data_schema, |
| const reflection::Schema* annotations_entity_data_schema) const { |
| if (actions_entity_data_schema == nullptr || |
| action.serialized_entity_data.empty()) { |
| // Empty table. |
| lua_newtable(state_); |
| } else { |
| PushFlatbuffer(actions_entity_data_schema, |
| flatbuffers::GetRoot<flatbuffers::Table>( |
| action.serialized_entity_data.data())); |
| } |
| PushString(action.type); |
| lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kTypeKey); |
| PushString(action.response_text); |
| lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kResponseTextKey); |
| Push(action.score); |
| lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kScoreKey); |
| Push(action.priority_score); |
| lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kPriorityScoreKey); |
| PushAnnotations(&action.annotations, annotations_entity_data_schema); |
| lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kAnnotationKey); |
| } |
| |
| void LuaEnvironment::PushActions( |
| const std::vector<ActionSuggestion>* actions, |
| const reflection::Schema* actions_entity_data_schema, |
| const reflection::Schema* annotations_entity_data_schema) const { |
| PushIterator(actions ? actions->size() : 0, |
| [this, actions, actions_entity_data_schema, |
| annotations_entity_data_schema](const int64 index) { |
| PushAction(actions->at(index), actions_entity_data_schema, |
| annotations_entity_data_schema); |
| return 1; |
| }); |
| } |
| |
| ActionSuggestion LuaEnvironment::ReadAction( |
| const reflection::Schema* actions_entity_data_schema, |
| const reflection::Schema* annotations_entity_data_schema) const { |
| ActionSuggestion action; |
| lua_pushnil(state_); |
| while (Next(/*index=*/kIndexStackTop - 1)) { |
| const StringPiece key = ReadString(/*index=*/kIndexStackTop - 1); |
| if (key.Equals(kResponseTextKey)) { |
| action.response_text = Read<std::string>(/*index=*/kIndexStackTop); |
| } else if (key.Equals(kTypeKey)) { |
| action.type = Read<std::string>(/*index=*/kIndexStackTop); |
| } else if (key.Equals(kScoreKey)) { |
| action.score = Read<float>(/*index=*/kIndexStackTop); |
| } else if (key.Equals(kPriorityScoreKey)) { |
| action.priority_score = Read<float>(/*index=*/kIndexStackTop); |
| } else if (key.Equals(kAnnotationKey)) { |
| ReadAnnotations(actions_entity_data_schema, &action.annotations); |
| } else if (key.Equals(kEntityKey)) { |
| auto buffer = |
| MutableFlatbufferBuilder(actions_entity_data_schema).NewRoot(); |
| ReadFlatbuffer(/*index=*/kIndexStackTop, buffer.get()); |
| action.serialized_entity_data = buffer->Serialize(); |
| } else { |
| TC3_LOG(INFO) << "Unknown action field: " << key; |
| } |
| lua_pop(state_, 1); |
| } |
| return action; |
| } |
| |
| int LuaEnvironment::ReadActions( |
| const reflection::Schema* actions_entity_data_schema, |
| const reflection::Schema* annotations_entity_data_schema, |
| std::vector<ActionSuggestion>* actions) const { |
| // Read actions. |
| lua_pushnil(state_); |
| while (Next(/*index=*/kIndexStackTop - 1)) { |
| if (lua_type(state_, /*idx=*/kIndexStackTop) != LUA_TTABLE) { |
| TC3_LOG(ERROR) << "Expected action table, got: " |
| << lua_type(state_, /*idx=*/kIndexStackTop); |
| lua_pop(state_, 1); |
| continue; |
| } |
| actions->push_back( |
| ReadAction(actions_entity_data_schema, annotations_entity_data_schema)); |
| lua_pop(state_, /*n=*/1); |
| } |
| lua_pop(state_, /*n=*/1); |
| |
| return LUA_OK; |
| } |
| |
| void LuaEnvironment::PushConversation( |
| const std::vector<ConversationMessage>* conversation, |
| const reflection::Schema* annotations_entity_data_schema) const { |
| PushIterator( |
| conversation ? conversation->size() : 0, |
| [this, conversation, annotations_entity_data_schema](const int64 index) { |
| const ConversationMessage& message = conversation->at(index); |
| lua_newtable(state_); |
| Push(message.user_id); |
| lua_setfield(state_, /*idx=*/kIndexStackTop - 1, "user_id"); |
| Push(message.text); |
| lua_setfield(state_, /*idx=*/kIndexStackTop - 1, "text"); |
| Push(message.reference_time_ms_utc); |
| lua_setfield(state_, /*idx=*/kIndexStackTop - 1, "time_ms_utc"); |
| Push(message.reference_timezone); |
| lua_setfield(state_, /*idx=*/kIndexStackTop - 1, "timezone"); |
| PushAnnotatedSpans(&message.annotations, |
| annotations_entity_data_schema); |
| lua_setfield(state_, /*idx=*/kIndexStackTop - 1, "annotation"); |
| return 1; |
| }); |
| } |
| |
| bool Compile(StringPiece snippet, std::string* bytecode) { |
| return LuaEnvironment().Compile(snippet, bytecode); |
| } |
| |
| } // namespace libtextclassifier3 |