// blob: d825cb90a541b937b2bdb85b3a9107123868d1e6 [file] [log] [blame]
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_
#define LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_
#include <functional>
#include <vector>
#include "utils/flatbuffers.h"
#include "utils/strings/stringpiece.h"
#include "utils/variant.h"
#include "flatbuffers/reflection_generated.h"
#ifdef __cplusplus
extern "C" {
#endif
#include "lauxlib.h"
#include "lua.h"
#include "lualib.h"
#ifdef __cplusplus
}
#endif
namespace libtextclassifier3 {
// Metatable field names (Lua metamethods) installed by the bindings below:
// `#t` length, `pairs(t)` iteration and `t[key]` lookup respectively.
static constexpr const char *kLengthKey{"__len"};
static constexpr const char *kPairsKey{"__pairs"};
static constexpr const char *kIndexKey{"__index"};
// Converts a pointer-to-const into the untyped `void *` accepted by
// `lua_pushlightuserdata`. Constness is stripped; the pointee must not be
// modified through the result unless it is actually mutable.
template <typename T>
void *AsUserData(const T *value) {
  const void *untyped = value;
  return const_cast<void *>(untyped);
}
// Reinterprets a pointer-like value (for example a non-const object pointer)
// as the untyped `void *` accepted by `lua_pushlightuserdata`.
template <typename T>
void *AsUserData(const T value) {
  void *const untyped = reinterpret_cast<void *>(value);
  return untyped;
}
// Retrieves the user data stored in up-value `index` (1-based) of the
// currently running C closure and casts it to `T`, which is expected to be
// the pointer type the value was originally stored as.
template <typename T>
T FromUpValue(const int index, lua_State *state) {
  void *const raw = lua_touserdata(state, lua_upvalueindex(index));
  return static_cast<T>(raw);
}
class LuaEnvironment {
public:
// Wrapper for handling an iterator.
class Iterator {
public:
virtual ~Iterator() {}
static int NextCallback(lua_State *state);
static int LengthCallback(lua_State *state);
static int ItemCallback(lua_State *state);
static int IteritemsCallback(lua_State *state);
// Called when the next element of an iterator is fetched.
virtual int Next(lua_State *state) const = 0;
// Called when the length of the iterator is queried.
virtual int Length(lua_State *state) const = 0;
// Called when an item is queried.
virtual int Item(lua_State *state) const = 0;
// Called when a new iterator is started.
virtual int Iteritems(lua_State *state) const = 0;
protected:
static constexpr int kIteratorArgId = 1;
};
template <typename T>
class ItemIterator : public Iterator {
public:
void NewIterator(StringPiece name, const T *items, lua_State *state) const {
lua_newtable(state);
luaL_newmetatable(state, name.data());
lua_pushlightuserdata(state, AsUserData(this));
lua_pushlightuserdata(state, AsUserData(items));
lua_pushcclosure(state, &Iterator::ItemCallback, 2);
lua_setfield(state, -2, kIndexKey);
lua_pushlightuserdata(state, AsUserData(this));
lua_pushlightuserdata(state, AsUserData(items));
lua_pushcclosure(state, &Iterator::LengthCallback, 2);
lua_setfield(state, -2, kLengthKey);
lua_pushlightuserdata(state, AsUserData(this));
lua_pushlightuserdata(state, AsUserData(items));
lua_pushcclosure(state, &Iterator::IteritemsCallback, 2);
lua_setfield(state, -2, kPairsKey);
lua_setmetatable(state, -2);
}
int Iteritems(lua_State *state) const override {
lua_pushlightuserdata(state, AsUserData(this));
lua_pushlightuserdata(
state, lua_touserdata(state, lua_upvalueindex(kItemsArgId)));
lua_pushnumber(state, 0);
lua_pushcclosure(state, &Iterator::NextCallback, 3);
return /*num results=*/1;
}
int Length(lua_State *state) const override {
lua_pushinteger(state, FromUpValue<T *>(kItemsArgId, state)->size());
return /*num results=*/1;
}
int Next(lua_State *state) const override {
return Next(FromUpValue<T *>(kItemsArgId, state),
lua_tointeger(state, lua_upvalueindex(kIterValueArgId)),
state);
}
int Next(const T *items, const int64 pos, lua_State *state) const {
if (pos >= items->size()) {
return 0;
}
// Update iterator value.
lua_pushnumber(state, pos + 1);
lua_replace(state, lua_upvalueindex(3));
// Push key.
lua_pushinteger(state, pos + 1);
// Push item.
return 1 + Item(items, pos, state);
}
int Item(lua_State *state) const override {
const T *items = FromUpValue<T *>(kItemsArgId, state);
switch (lua_type(state, -1)) {
case LUA_TNUMBER: {
// Lua is one based, so adjust the index here.
const int64 index =
static_cast<int64>(lua_tonumber(state, /*idx=*/-1)) - 1;
if (index < 0 || index >= items->size()) {
TC3_LOG(ERROR) << "Invalid index: " << index;
lua_error(state);
return 0;
}
return Item(items, index, state);
}
case LUA_TSTRING: {
size_t key_length = 0;
const char *key = lua_tolstring(state, /*idx=*/-1, &key_length);
return Item(items, StringPiece(key, key_length), state);
}
default:
TC3_LOG(ERROR) << "Unexpected access type: " << lua_type(state, -1);
lua_error(state);
return 0;
}
}
virtual int Item(const T *items, const int64 pos,
lua_State *state) const = 0;
virtual int Item(const T *items, StringPiece key, lua_State *state) const {
TC3_LOG(ERROR) << "Unexpected key access: " << key.ToString();
lua_error(state);
return 0;
}
protected:
static constexpr int kItemsArgId = 2;
static constexpr int kIterValueArgId = 3;
};
virtual ~LuaEnvironment();
LuaEnvironment();
// Compile a lua snippet into binary bytecode.
// NOTE: The compiled bytecode might not be compatible across Lua versions
// and platforms.
bool Compile(StringPiece snippet, std::string *bytecode);
typedef int (*CallbackHandler)(lua_State *);
// Loads default libraries.
void LoadDefaultLibraries();
// Provides a callback to Lua.
template <typename T, int (T::*handler)()>
void Bind() {
lua_pushlightuserdata(state_, static_cast<void *>(this));
lua_pushcclosure(state_, &Dispatch<T, handler>, 1);
}
// Setup a named table that callsback whenever a member is accessed.
// This allows to lazily provide required information to the script.
template <typename T, int (T::*handler)()>
void BindTable(const char *name) {
lua_newtable(state_);
luaL_newmetatable(state_, name);
lua_pushlightuserdata(state_, static_cast<void *>(this));
lua_pushcclosure(state_, &Dispatch<T, handler>, 1);
lua_setfield(state_, -2, kIndexKey);
lua_setmetatable(state_, -2);
}
void PushValue(const Variant &value);
// Reads a string from the stack.
StringPiece ReadString(const int index) const;
// Pushes a string to the stack.
void PushString(const StringPiece str);
// Pushes a flatbuffer to the stack.
void PushFlatbuffer(const reflection::Schema *schema,
const flatbuffers::Table *table);
// Reads a flatbuffer from the stack.
int ReadFlatbuffer(ReflectiveFlatbuffer *buffer);
// Runs a closure in protected mode.
// `func`: closure to run in protected mode.
// `num_lua_args`: number of arguments from the lua stack to process.
// `num_results`: number of result values pushed on the stack.
int RunProtected(const std::function<int()> &func, const int num_args = 0,
const int num_results = 0);
lua_State *state() const { return state_; }
protected:
lua_State *state_;
private:
// Auxiliary methods to expose (reflective) flatbuffer based data to Lua.
static void PushFlatbuffer(const char *name, const reflection::Schema *schema,
const reflection::Object *type,
const flatbuffers::Table *table, lua_State *state);
static int GetFieldCallback(lua_State *state);
static int GetField(const reflection::Schema *schema,
const reflection::Object *type,
const flatbuffers::Table *table, lua_State *state);
template <typename T, int (T::*handler)()>
static int Dispatch(lua_State *state) {
T *env = FromUpValue<T *>(1, state);
return ((*env).*handler)();
}
};
// Free-function counterpart of `LuaEnvironment::Compile` (same signature).
// NOTE(review): definition not visible here — presumably compiles `snippet`
// into Lua bytecode written to `bytecode` and returns success; confirm.
bool Compile(StringPiece snippet, std::string *bytecode);
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_