// Source blob: a76c7908fd477eb23bbd6328f531e92158861e60
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_
#define LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_
#include <functional>
#include <new>
#include <string>
#include <vector>
#include "actions/types.h"
#include "annotator/types.h"
#include "utils/flatbuffers/mutable.h"
#include "utils/strings/stringpiece.h"
#include "utils/variant.h"
#include "flatbuffers/reflection_generated.h"
#ifdef __cplusplus
extern "C" {
#endif
#include "lauxlib.h"
#include "lua.h"
#include "lualib.h"
#ifdef __cplusplus
}
#endif
namespace libtextclassifier3 {
// Metatable field names used by LuaEnvironment when wiring C++ callbacks into
// Lua tables.
// Field that receives the length callback of an iterator metatable.
static constexpr const char kLengthKey[] = "__len";
// Field that receives the callback that starts a key-value iteration.
static constexpr const char kPairsKey[] = "__pairs";
// Field that receives the callback invoked when a member is accessed.
static constexpr const char kIndexKey[] = "__index";
// Field that receives the callback that destroys a pushed C++ closure.
static constexpr const char kGcKey[] = "__gc";
// Field looked up by LuaEnvironment::Next to step through lazily provided
// containers.
static constexpr const char kNextKey[] = "__next";
// Relative stack index of the element on top of the Lua stack.
static constexpr const int kIndexStackTop = -1;
// Converts a value into the untyped `void*` representation used for Lua user
// data.
//
// Pointer overload: drops the const qualifier and erases the pointee type.
template <typename T>
void* AsUserData(const T* value) {
  T* const mutable_value = const_cast<T*>(value);
  return static_cast<void*>(mutable_value);
}
// Value overload: reinterprets the bits of `value` as a pointer. `T` has to be
// a type for which `reinterpret_cast<void*>` is valid.
template <typename T>
void* AsUserData(const T value) {
  void* const user_data = reinterpret_cast<void*>(value);
  return user_data;
}
// Retrieves the user data stored in up-value number `index` of the running
// C closure and casts it to `T`.
template <typename T>
T FromUpValue(const int index, lua_State* state) {
  void* const user_data = lua_touserdata(state, lua_upvalueindex(index));
  return static_cast<T>(user_data);
}
class LuaEnvironment {
public:
// Destructor and constructor are defined outside of this header.
// NOTE(review): presumably they close/create `state_` — confirm against the
// implementation file.
virtual ~LuaEnvironment();
explicit LuaEnvironment();
// Compile a lua snippet into binary bytecode.
// `snippet`: the lua source to compile.
// `bytecode`: output parameter that receives the compiled chunk.
// NOTE: The compiled bytecode might not be compatible across Lua versions
// and platforms.
bool Compile(StringPiece snippet, std::string* bytecode) const;
// Loads default libraries.
void LoadDefaultLibraries();
// Provides a callback to Lua.
//
// Member-function overload: `handler` is invoked on this object, viewed as a
// `T`, every time the pushed Lua function is called.
template <typename T>
void PushFunction(int (T::*handler)()) {
  T* const self = static_cast<T*>(this);
  PushFunction([self, handler]() { return (self->*handler)(); });
}
// Closure overload: copies `func` into Lua-managed user data and leaves a
// Lua function on top of the stack that forwards to that copy.
template <typename F>
void PushFunction(const F& func) const {
  // Copy-construct the closure into memory owned by the Lua state.
  void* const storage = lua_newuserdata(state_, sizeof(F));
  new (storage) F(func);
  // Attach a metatable whose `__gc` entry runs the closure's destructor once
  // Lua collects the user data.
  lua_newtable(state_);
  lua_pushcfunction(state_, &ReleaseFunction<F>);
  lua_setfield(state_, -2, kGcKey);
  lua_setmetatable(state_, -2);
  // Wrap the user data as the single up-value of the dispatching C closure.
  lua_pushcclosure(state_, &CallFunction<F>, 1);
}
// Pushes a table that calls back whenever a member is accessed.
// This allows to lazily provide required information to the script.
//
// Member-function overload: `handler` is invoked on this object, viewed as a
// `T`, for every member access.
template <typename T>
void PushLazyObject(int (T::*handler)()) {
  T* const self = static_cast<T*>(this);
  PushLazyObject([self, handler]() { return (self->*handler)(); });
}
// Closure overload: leaves an empty table on top of the stack whose metatable
// routes member lookups (`__index`) to a copy of `func`.
template <typename F>
void PushLazyObject(const F& func) const {
  // The (empty) object table that is handed to the script.
  lua_newtable(state_);
  // Its metatable.
  lua_newtable(state_);
  // metatable.__index = func
  PushFunction(func);
  lua_setfield(state_, -2, kIndexKey);
  // Attach the metatable; only the object table remains on the stack.
  lua_setmetatable(state_, -2);
}
void Push(const int64 value) const { lua_pushinteger(state_, value); }
void Push(const uint64 value) const { lua_pushinteger(state_, value); }
void Push(const int32 value) const { lua_pushinteger(state_, value); }
void Push(const uint32 value) const { lua_pushinteger(state_, value); }
void Push(const int16 value) const { lua_pushinteger(state_, value); }
void Push(const uint16 value) const { lua_pushinteger(state_, value); }
void Push(const int8 value) const { lua_pushinteger(state_, value); }
void Push(const uint8 value) const { lua_pushinteger(state_, value); }
void Push(const float value) const { lua_pushnumber(state_, value); }
void Push(const double value) const { lua_pushnumber(state_, value); }
void Push(const bool value) const { lua_pushboolean(state_, value); }
void Push(const StringPiece value) const { PushString(value); }
void Push(const flatbuffers::String* value) const {
if (value == nullptr) {
PushString("");
} else {
PushString(StringPiece(value->c_str(), value->size()));
}
}
template <typename T>
T Read(const int index = -1) const;
template <>
int64 Read<int64>(const int index) const {
return static_cast<int64>(lua_tointeger(state_, /*idx=*/index));
}
template <>
uint64 Read<uint64>(const int index) const {
return static_cast<uint64>(lua_tointeger(state_, /*idx=*/index));
}
template <>
int32 Read<int32>(const int index) const {
return static_cast<int32>(lua_tointeger(state_, /*idx=*/index));
}
template <>
uint32 Read<uint32>(const int index) const {
return static_cast<uint32>(lua_tointeger(state_, /*idx=*/index));
}
template <>
int16 Read<int16>(const int index) const {
return static_cast<int16>(lua_tointeger(state_, /*idx=*/index));
}
template <>
uint16 Read<uint16>(const int index) const {
return static_cast<uint16>(lua_tointeger(state_, /*idx=*/index));
}
template <>
int8 Read<int8>(const int index) const {
return static_cast<int8>(lua_tointeger(state_, /*idx=*/index));
}
template <>
uint8 Read<uint8>(const int index) const {
return static_cast<uint8>(lua_tointeger(state_, /*idx=*/index));
}
template <>
float Read<float>(const int index) const {
return static_cast<float>(lua_tonumber(state_, /*idx=*/index));
}
template <>
double Read<double>(const int index) const {
return static_cast<double>(lua_tonumber(state_, /*idx=*/index));
}
template <>
bool Read<bool>(const int index) const {
return lua_toboolean(state_, /*idx=*/index);
}
template <>
StringPiece Read<StringPiece>(const int index) const {
return ReadString(index);
}
template <>
std::string Read<std::string>(const int index) const {
return ReadString(index).ToString();
}
// Reads a string from the stack.
StringPiece ReadString(int index) const;
// Pushes a string to the stack.
void PushString(const StringPiece str) const;
// Pushes a flatbuffer to the stack, interpreting `table` as the root table
// type of `schema`.
void PushFlatbuffer(const reflection::Schema* schema,
                    const flatbuffers::Table* table) const {
  const reflection::Object* root_type = schema->root_table();
  PushFlatbuffer(schema, root_type, table);
}
// Reads a flatbuffer from the stack.
int ReadFlatbuffer(int index, MutableFlatbuffer* buffer) const;
// Pushes an iterator: an empty table whose metatable forwards length queries,
// iteration and member access to the supplied callbacks.
//
// `length`: number of elements that can be accessed by (1-based) index.
// `item_callback`: called with a 0-based position; pushes the element and
//   returns the number of pushed values.
// `key_callback`: called with the accessed name for lookups by string key;
//   returns the number of pushed values.
template <typename ItemCallback, typename KeyCallback>
void PushIterator(const int length, const ItemCallback& item_callback,
                  const KeyCallback& key_callback) const {
  // The object table handed to the script.
  lua_newtable(state_);
  // Metatable with the length / pairs / next callbacks.
  CreateIteratorMetatable(length, item_callback);
  // metatable.__index dispatches on the type of the accessed key.
  const auto index_handler = [this, length, item_callback, key_callback]() {
    return Iterator::Dispatch(this, length, item_callback, key_callback);
  };
  PushFunction(index_handler);
  lua_setfield(state_, -2, kIndexKey);
  lua_setmetatable(state_, -2);
}
// Same as above, but the pushed iterator only supports access by index.
template <typename ItemCallback>
void PushIterator(const int length, const ItemCallback& item_callback) const {
  // The object table handed to the script.
  lua_newtable(state_);
  // Metatable with the length / pairs / next callbacks.
  CreateIteratorMetatable(length, item_callback);
  // metatable.__index handles numeric keys only.
  const auto index_handler = [this, length, item_callback]() {
    return Iterator::Dispatch(this, length, item_callback);
  };
  PushFunction(index_handler);
  lua_setfield(state_, -2, kIndexKey);
  lua_setmetatable(state_, -2);
}
// Pushes a new table and fills it with the iterator callbacks for a container
// of `length` elements: `__len`, `__pairs` and `__next`. The caller is
// expected to add further entries and attach it as a metatable.
template <typename ItemCallback>
void CreateIteratorMetatable(const int length,
                             const ItemCallback& item_callback) const {
  lua_newtable(state_);

  // metatable.__len
  const auto length_handler = [this, length]() {
    return Iterator::Length(this, length);
  };
  PushFunction(length_handler);
  lua_setfield(state_, -2, kLengthKey);

  // metatable.__pairs
  const auto pairs_handler = [this, length, item_callback]() {
    return Iterator::IterItems(this, length, item_callback);
  };
  PushFunction(pairs_handler);
  lua_setfield(state_, -2, kPairsKey);

  // metatable.__next
  const auto next_handler = [this, length, item_callback]() {
    return Iterator::Next(this, length, item_callback);
  };
  PushFunction(next_handler);
  lua_setfield(state_, -2, kNextKey);
}
// Pushes an iterator over `items` that converts elements lazily, one at a time
// as the script accesses them. A null `items` yields an empty iterator.
// The vector is captured by pointer, not copied.
template <typename T>
void PushVectorIterator(const std::vector<T>* items) const {
  const auto push_item = [this, items](const int64 pos) {
    this->Push(items->at(pos));
    return 1;
  };
  PushIterator(items ? items->size() : 0, push_item);
}
// Pushes a new Lua table holding a converted copy of every element of
// `items`, keyed by its 1-based position.
template <typename T>
void PushVector(const std::vector<T>& items) const {
  lua_newtable(state_);
  // Lua sequences start at 1.
  int next_key = 1;
  for (const auto& item : items) {
    Push(next_key);
    Push(item);
    // table[next_key] = item; pops key and value again.
    lua_settable(state_, /*idx=*/-3);
    ++next_key;
  }
}
// Pushes a new, empty Lua table.
void PushEmptyVector() const { lua_newtable(state_); }
// Reads the table at stack position `index` into a vector, converting every
// value with `Read<T>`. Keys are ignored.
//
// `index` may be relative (negative) or absolute (positive).
// If the value at `index` is not a table, an error is logged, the top stack
// element is popped and an empty vector is returned.
// NOTE(review): on success the table stays on the stack while on failure the
// top element is removed; kept as-is since callers may rely on it — confirm.
template <typename T>
std::vector<T> ReadVector(const int index = -1) const {
  if (lua_type(state_, /*idx=*/index) != LUA_TTABLE) {
    // Report the type of the value that was actually inspected.
    TC3_LOG(ERROR) << "Expected a table, got: "
                   << lua_type(state_, /*idx=*/index);
    lua_pop(state_, 1);
    return {};
  }
  std::vector<T> result;
  // Pushing the initial key below moves the table one slot further away from
  // the top for relative indices; absolute indices are not affected.
  const int table_index = index < 0 ? index - 1 : index;
  lua_pushnil(state_);
  while (Next(table_index)) {
    // Stack: ..., key, value. Consume the value, keep the key for `Next`.
    result.push_back(Read<T>(/*index=*/kIndexStackTop));
    lua_pop(state_, 1);
  }
  return result;
}
// Runs a closure in protected mode.
// `func`: closure to run in protected mode.
// `num_lua_args`: number of arguments from the lua stack to process.
// `num_results`: number of result values pushed on the stack.
template <typename F>
int RunProtected(const F& func, const int num_args = 0,
const int num_results = 0) const {
PushFunction(func);
// Put the closure before the arguments on the stack.
if (num_args > 0) {
lua_insert(state_, -(1 + num_args));
}
return lua_pcall(state_, num_args, num_results, /*errorfunc=*/0);
}
// Auxiliary methods to handle model results.
// Provides an annotation to lua.
void PushAnnotation(const ClassificationResult& classification,
const reflection::Schema* entity_data_schema) const;
void PushAnnotation(const ClassificationResult& classification,
StringPiece text,
const reflection::Schema* entity_data_schema) const;
void PushAnnotation(const ActionSuggestionAnnotation& annotation,
const reflection::Schema* entity_data_schema) const;
// Pushes an iterator over `annotations` that supports access by (1-based)
// index as well as by name (resolved through `GetAnnotationByName`).
//
// `annotations` may be null, in which case the iterator is empty and lookups
// by name yield no result. The vector is captured by pointer, not copied.
template <typename Annotation>
void PushAnnotations(const std::vector<Annotation>* annotations,
                     const reflection::Schema* entity_data_schema) const {
  PushIterator(
      annotations ? annotations->size() : 0,
      [this, annotations, entity_data_schema](const int64 index) {
        PushAnnotation(annotations->at(index), entity_data_schema);
        return 1;
      },
      [this, annotations, entity_data_schema](StringPiece name) {
        // Lookups by name are not guarded by the length check that protects
        // access by index, so the null case has to be handled explicitly.
        if (annotations == nullptr) {
          return 0;
        }
        const Annotation* annotation =
            GetAnnotationByName(*annotations, name);
        if (annotation == nullptr) {
          return 0;
        }
        PushAnnotation(*annotation, entity_data_schema);
        return 1;
      });
}
// Pushes a span to the lua stack.
void PushAnnotatedSpan(const AnnotatedSpan& annotated_span,
const reflection::Schema* entity_data_schema) const;
void PushAnnotatedSpans(const std::vector<AnnotatedSpan>* annotated_spans,
const reflection::Schema* entity_data_schema) const;
// Reads a message text span from lua.
MessageTextSpan ReadSpan() const;
ActionSuggestionAnnotation ReadAnnotation(
const reflection::Schema* entity_data_schema) const;
int ReadAnnotations(
const reflection::Schema* entity_data_schema,
std::vector<ActionSuggestionAnnotation>* annotations) const;
ClassificationResult ReadClassificationResult(
const reflection::Schema* entity_data_schema) const;
// Provides an action to lua.
void PushAction(
const ActionSuggestion& action,
const reflection::Schema* actions_entity_data_schema,
const reflection::Schema* annotations_entity_data_schema) const;
void PushActions(
const std::vector<ActionSuggestion>* actions,
const reflection::Schema* actions_entity_data_schema,
const reflection::Schema* annotations_entity_data_schema) const;
ActionSuggestion ReadAction(
const reflection::Schema* actions_entity_data_schema,
const reflection::Schema* annotations_entity_data_schema) const;
int ReadActions(const reflection::Schema* actions_entity_data_schema,
const reflection::Schema* annotations_entity_data_schema,
std::vector<ActionSuggestion>* actions) const;
// Conversation message iterator.
void PushConversation(
const std::vector<ConversationMessage>* conversation,
const reflection::Schema* annotations_entity_data_schema) const;
// Returns the Lua state this environment operates on.
lua_State* state() const { return state_; }
protected:
// Wrapper for handling iteration over containers.
//
// All functions are meant to be called from callbacks registered with Lua:
// they read their arguments from, and push their results to, the stack of
// `env` and return the number of pushed values.
class Iterator {
 public:
  // Starts a new key-value pair iterator.
  // Pushes a stateful function that yields `(position, item)` on every call
  // and a single nil once `length` items have been produced.
  template <typename ItemCallback>
  static int IterItems(const LuaEnvironment* env, const int length,
                       const ItemCallback& callback) {
    env->PushFunction([env, callback, length, pos = 0]() mutable {
      if (pos >= length) {
        lua_pushnil(env->state());
        return 1;
      }
      // Push key (Lua is one based).
      lua_pushinteger(env->state(), pos + 1);
      // Push item.
      return 1 + callback(pos++);
    });
    return 1;  // Num. results.
  }

  // Gets the next element.
  // Expects the previous key (or nil to start the iteration) on top of the
  // stack. Pushes the next key and its item, or a single nil if there is no
  // next element. Keys outside of [0, length) terminate the iteration, so
  // that `item_callback` is never called with an invalid position.
  template <typename ItemCallback>
  static int Next(const LuaEnvironment* env, const int length,
                  const ItemCallback& item_callback) {
    const int64 pos = lua_isnil(env->state(), /*idx=*/kIndexStackTop)
                          ? 0
                          : env->Read<int64>(/*index=*/kIndexStackTop);
    if (pos < 0 || pos >= length) {
      lua_pushnil(env->state());
      return 1;
    }
    // Push next key. The previous 1-based key equals the 0-based position of
    // the next item.
    lua_pushinteger(env->state(), pos + 1);
    // Push item.
    return 1 + item_callback(pos);
  }

  // Returns the length of the container the iterator processes.
  static int Length(const LuaEnvironment* env, const int length) {
    lua_pushinteger(env->state(), length);
    return 1;  // Num. results.
  }

  // Handles item queries to the iterator.
  // Elements of the container can either be queried by name or index.
  // Dispatch will check how an element is accessed and
  // calls `key_callback` for access by name and `item_callback` for access by
  // index. Any other key type, or an index out of range, raises a Lua error.
  template <typename ItemCallback, typename KeyCallback>
  static int Dispatch(const LuaEnvironment* env, const int length,
                      const ItemCallback& item_callback,
                      const KeyCallback& key_callback) {
    switch (lua_type(env->state(), kIndexStackTop)) {
      case LUA_TNUMBER:
        return DispatchByIndex(env, length, item_callback);
      case LUA_TSTRING:
        return key_callback(env->ReadString(kIndexStackTop));
      default:
        return RaiseUnexpectedAccessType(env);
    }
  }

  // Same as above for containers that only support access by index.
  template <typename ItemCallback>
  static int Dispatch(const LuaEnvironment* env, const int length,
                      const ItemCallback& item_callback) {
    switch (lua_type(env->state(), kIndexStackTop)) {
      case LUA_TNUMBER:
        return DispatchByIndex(env, length, item_callback);
      default:
        return RaiseUnexpectedAccessType(env);
    }
  }

 private:
  // Serves an access by the (1-based) numeric index on top of the stack.
  // Raises a Lua error if the index is out of range.
  template <typename ItemCallback>
  static int DispatchByIndex(const LuaEnvironment* env, const int length,
                             const ItemCallback& item_callback) {
    // Lua is one based, so adjust the index here.
    const int64 index = env->Read<int64>(/*index=*/kIndexStackTop) - 1;
    if (index < 0 || index >= length) {
      TC3_LOG(ERROR) << "Invalid index: " << index;
      // Raises the error with the value on top of the stack as error object.
      lua_error(env->state());
      return 0;
    }
    return item_callback(index);
  }

  // Logs the type of the unsupported key on top of the stack and raises a Lua
  // error.
  static int RaiseUnexpectedAccessType(const LuaEnvironment* env) {
    TC3_LOG(ERROR) << "Unexpected access type: "
                   << lua_type(env->state(), kIndexStackTop);
    lua_error(env->state());
    return 0;
  }
};
// Garbage collection callback for closures pushed with `PushFunction`:
// runs the destructor of the `T` stored in the user data passed as first
// argument. The memory itself is not released here.
template <typename T>
static int ReleaseFunction(lua_State* state) {
  T* const closure = static_cast<T*>(lua_touserdata(state, 1));
  closure->~T();
  return 0;
}
// Dispatch entry point for closures pushed with `PushFunction`: invokes the
// `T` stored in the first up-value and returns its result, i.e. the number
// of values it pushed.
template <typename T>
static int CallFunction(lua_State* state) {
  void* const storage = lua_touserdata(state, lua_upvalueindex(1));
  T& closure = *static_cast<T*>(storage);
  return closure();
}
// Auxiliary methods to expose (reflective) flatbuffer based data to Lua.
void PushFlatbuffer(const reflection::Schema* schema,
const reflection::Object* type,
const flatbuffers::Table* table) const;
int GetField(const reflection::Schema* schema, const reflection::Object* type,
const flatbuffers::Table* table) const;
// Reads a repeated field from lua.
// Converts every value of the table at stack position `index` to `T` and
// appends it to `result`.
template <typename T>
void ReadRepeatedField(const int index, RepeatedField* result) const {
  for (const T& element : ReadVector<T>(index)) {
    result->Add(element);
  }
}
// Specialization for repeated sub-messages: every value of the table at
// stack position `index` is read into a newly added element of `result`.
template <>
void ReadRepeatedField<MutableFlatbuffer>(const int index,
                                          RepeatedField* result) const {
  // Pushing the initial key below moves the table one slot further away from
  // the top for relative indices; absolute indices are not affected.
  const int table_index = index < 0 ? index - 1 : index;
  lua_pushnil(state_);
  while (Next(table_index)) {
    // Stack: ..., key, value. The element to read is the value on top of the
    // stack, independent of where the containing table lives.
    ReadFlatbuffer(/*index=*/kIndexStackTop, result->Add());
    lua_pop(state_, 1);
  }
}
// Pushes a repeated field to the lua stack.
// The elements are exposed through an iterator and converted one at a time
// when accessed; a null `items` yields an empty iterator.
template <typename T>
void PushRepeatedField(const flatbuffers::Vector<T>* items) const {
  const auto push_item = [this, items](const int64 pos) {
    Push(items->Get(pos));
    return 1;  // Num. results.
  };
  PushIterator(items ? items->size() : 0, push_item);
}
// Pushes a repeated field of sub-messages of type `type` to the lua stack.
// The elements are exposed through an iterator and pushed one at a time when
// accessed; a null `items` yields an empty iterator.
void PushRepeatedFlatbufferField(
    const reflection::Schema* schema, const reflection::Object* type,
    const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::Table>>* items)
    const {
  const auto push_item = [this, schema, type, items](const int64 pos) {
    PushFlatbuffer(schema, type, items->Get(pos));
    return 1;  // Num. results.
  };
  PushIterator(items ? items->size() : 0, push_item);
}
// Overloads Lua next function to use __next key on the metatable.
// This allows us to treat lua objects and lazy objects provided by our
// callbacks uniformly.
int Next(int index) const {
// Check whether the (meta)table of this object has an associated "__next"
// entry. This means, we registered our own callback. So we explicitly call
// that.
if (luaL_getmetafield(state_, index, kNextKey)) {
// Callback is now on top of the stack, so adjust relative indices by 1.
if (index < 0) {
index--;
}
// Copy the reference to the table.
lua_pushvalue(state_, index);
// Move the key to top to have it as second argument for the callback.
// Copy the key to the top.
lua_pushvalue(state_, -3);
// Remove the copy of the key.
lua_remove(state_, -4);
// Call the callback with (key and table as arguments).
lua_pcall(state_, /*nargs=*/2 /* table, key */,
/*nresults=*/2 /* key, item */, 0);
// Next returned nil, it's the end.
if (lua_isnil(state_, kIndexStackTop)) {
// Remove nil value.
// Results will be padded to `nresults` specified above, so we need
// to remove two elements here.
lua_pop(state_, 2);
return 0;
}
return 2; // Num. results.
} else if (lua_istable(state_, index)) {
return lua_next(state_, index);
}
// Remove the key.
lua_pop(state_, 1);
return 0;
}
// Returns a pointer to the first classification result whose collection
// equals `name`, or null (after logging an error) if there is none.
// The returned pointer refers to an element of `annotations`.
static const ClassificationResult* GetAnnotationByName(
    const std::vector<ClassificationResult>& annotations, StringPiece name) {
  for (auto it = annotations.begin(); it != annotations.end(); ++it) {
    if (name.Equals(it->collection)) {
      return &*it;
    }
  }
  TC3_LOG(ERROR) << "No annotation with collection: " << name << " found.";
  return nullptr;
}
// Returns a pointer to the first action suggestion annotation whose name
// equals `name`, or null (after logging an error) if there is none.
// The returned pointer refers to an element of `annotations`.
static const ActionSuggestionAnnotation* GetAnnotationByName(
    const std::vector<ActionSuggestionAnnotation>& annotations,
    StringPiece name) {
  for (auto it = annotations.begin(); it != annotations.end(); ++it) {
    if (name.Equals(it->name)) {
      return &*it;
    }
  }
  TC3_LOG(ERROR) << "No annotation with name: " << name << " found.";
  return nullptr;
}
// The Lua state that all stack operations of this class act on.
lua_State* state_;
}; // class LuaEnvironment
// Free-function variant of `LuaEnvironment::Compile`; defined outside of this
// header.
// NOTE(review): presumably compiles `snippet` into `bytecode` using a
// temporary environment and returns whether that succeeded — confirm against
// the implementation file.
bool Compile(StringPiece snippet, std::string* bytecode);
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_