| /* |
| * Copyright (C) 2018 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #ifndef LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_ |
| #define LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_ |
| |
| #include <functional> |
| #include <vector> |
| |
| #include "utils/flatbuffers.h" |
| #include "utils/strings/stringpiece.h" |
| #include "utils/variant.h" |
| #include "flatbuffers/reflection_generated.h" |
| |
| #ifdef __cplusplus |
| extern "C" { |
| #endif |
| #include "lauxlib.h" |
| #include "lua.h" |
| #include "lualib.h" |
| #ifdef __cplusplus |
| } |
| #endif |
| |
| namespace libtextclassifier3 { |
| |
// Lua metamethod field names installed on the metatables built in this file:
// "__len" answers `#t`, "__pairs" answers `pairs(t)` and "__index" answers
// `t[key]` lookups of absent keys.
static constexpr char const *kLengthKey = "__len";
static constexpr char const *kPairsKey = "__pairs";
static constexpr char const *kIndexKey = "__index";
| |
// Converts a value into the untyped `void *` that Lua stores as light user
// data.
//
// Pointer-to-const overload: constness is stripped because Lua only deals in
// plain `void *`; code reading the pointer back must not write through it if
// the original object is const.
template <typename T>
void *AsUserData(const T *value) {
  const void *const erased = value;
  return const_cast<void *>(erased);
}
// Overload for everything else (e.g. integral handles, or pointers to
// non-const objects): the value is reinterpreted as a pointer.
template <typename T>
void *AsUserData(const T value) {
  void *const erased = reinterpret_cast<void *>(value);
  return erased;
}
| |
| // Retrieves up-values. |
| template <typename T> |
| T FromUpValue(const int index, lua_State *state) { |
| return static_cast<T>(lua_touserdata(state, lua_upvalueindex(index))); |
| } |
| |
| class LuaEnvironment { |
| public: |
| // Wrapper for handling an iterator. |
| class Iterator { |
| public: |
| virtual ~Iterator() {} |
| static int NextCallback(lua_State *state); |
| static int LengthCallback(lua_State *state); |
| static int ItemCallback(lua_State *state); |
| static int IteritemsCallback(lua_State *state); |
| |
| // Called when the next element of an iterator is fetched. |
| virtual int Next(lua_State *state) const = 0; |
| |
| // Called when the length of the iterator is queried. |
| virtual int Length(lua_State *state) const = 0; |
| |
| // Called when an item is queried. |
| virtual int Item(lua_State *state) const = 0; |
| |
| // Called when a new iterator is started. |
| virtual int Iteritems(lua_State *state) const = 0; |
| |
| protected: |
| static constexpr int kIteratorArgId = 1; |
| }; |
| |
| template <typename T> |
| class ItemIterator : public Iterator { |
| public: |
| void NewIterator(StringPiece name, const T *items, lua_State *state) const { |
| lua_newtable(state); |
| luaL_newmetatable(state, name.data()); |
| lua_pushlightuserdata(state, AsUserData(this)); |
| lua_pushlightuserdata(state, AsUserData(items)); |
| lua_pushcclosure(state, &Iterator::ItemCallback, 2); |
| lua_setfield(state, -2, kIndexKey); |
| lua_pushlightuserdata(state, AsUserData(this)); |
| lua_pushlightuserdata(state, AsUserData(items)); |
| lua_pushcclosure(state, &Iterator::LengthCallback, 2); |
| lua_setfield(state, -2, kLengthKey); |
| lua_pushlightuserdata(state, AsUserData(this)); |
| lua_pushlightuserdata(state, AsUserData(items)); |
| lua_pushcclosure(state, &Iterator::IteritemsCallback, 2); |
| lua_setfield(state, -2, kPairsKey); |
| lua_setmetatable(state, -2); |
| } |
| |
| int Iteritems(lua_State *state) const override { |
| lua_pushlightuserdata(state, AsUserData(this)); |
| lua_pushlightuserdata( |
| state, lua_touserdata(state, lua_upvalueindex(kItemsArgId))); |
| lua_pushnumber(state, 0); |
| lua_pushcclosure(state, &Iterator::NextCallback, 3); |
| return /*num results=*/1; |
| } |
| |
| int Length(lua_State *state) const override { |
| lua_pushinteger(state, FromUpValue<T *>(kItemsArgId, state)->size()); |
| return /*num results=*/1; |
| } |
| |
| int Next(lua_State *state) const override { |
| return Next(FromUpValue<T *>(kItemsArgId, state), |
| lua_tointeger(state, lua_upvalueindex(kIterValueArgId)), |
| state); |
| } |
| |
| int Next(const T *items, const int64 pos, lua_State *state) const { |
| if (pos >= items->size()) { |
| return 0; |
| } |
| |
| // Update iterator value. |
| lua_pushnumber(state, pos + 1); |
| lua_replace(state, lua_upvalueindex(3)); |
| |
| // Push key. |
| lua_pushinteger(state, pos + 1); |
| |
| // Push item. |
| return 1 + Item(items, pos, state); |
| } |
| |
| int Item(lua_State *state) const override { |
| const T *items = FromUpValue<T *>(kItemsArgId, state); |
| switch (lua_type(state, -1)) { |
| case LUA_TNUMBER: { |
| // Lua is one based, so adjust the index here. |
| const int64 index = |
| static_cast<int64>(lua_tonumber(state, /*idx=*/-1)) - 1; |
| if (index < 0 || index >= items->size()) { |
| TC3_LOG(ERROR) << "Invalid index: " << index; |
| lua_error(state); |
| return 0; |
| } |
| return Item(items, index, state); |
| } |
| case LUA_TSTRING: { |
| size_t key_length = 0; |
| const char *key = lua_tolstring(state, /*idx=*/-1, &key_length); |
| return Item(items, StringPiece(key, key_length), state); |
| } |
| default: |
| TC3_LOG(ERROR) << "Unexpected access type: " << lua_type(state, -1); |
| lua_error(state); |
| return 0; |
| } |
| } |
| |
| virtual int Item(const T *items, const int64 pos, |
| lua_State *state) const = 0; |
| |
| virtual int Item(const T *items, StringPiece key, lua_State *state) const { |
| TC3_LOG(ERROR) << "Unexpected key access: " << key.ToString(); |
| lua_error(state); |
| return 0; |
| } |
| |
| protected: |
| static constexpr int kItemsArgId = 2; |
| static constexpr int kIterValueArgId = 3; |
| }; |
| |
| virtual ~LuaEnvironment(); |
| LuaEnvironment(); |
| |
| // Compile a lua snippet into binary bytecode. |
| // NOTE: The compiled bytecode might not be compatible across Lua versions |
| // and platforms. |
| bool Compile(StringPiece snippet, std::string *bytecode); |
| |
| typedef int (*CallbackHandler)(lua_State *); |
| |
| // Loads default libraries. |
| void LoadDefaultLibraries(); |
| |
| // Provides a callback to Lua. |
| template <typename T, int (T::*handler)()> |
| void Bind() { |
| lua_pushlightuserdata(state_, static_cast<void *>(this)); |
| lua_pushcclosure(state_, &Dispatch<T, handler>, 1); |
| } |
| |
| // Setup a named table that callsback whenever a member is accessed. |
| // This allows to lazily provide required information to the script. |
| template <typename T, int (T::*handler)()> |
| void BindTable(const char *name) { |
| lua_newtable(state_); |
| luaL_newmetatable(state_, name); |
| lua_pushlightuserdata(state_, static_cast<void *>(this)); |
| lua_pushcclosure(state_, &Dispatch<T, handler>, 1); |
| lua_setfield(state_, -2, kIndexKey); |
| lua_setmetatable(state_, -2); |
| } |
| |
| void PushValue(const Variant &value); |
| |
| // Reads a string from the stack. |
| StringPiece ReadString(const int index) const; |
| |
| // Pushes a string to the stack. |
| void PushString(const StringPiece str); |
| |
| // Pushes a flatbuffer to the stack. |
| void PushFlatbuffer(const reflection::Schema *schema, |
| const flatbuffers::Table *table); |
| |
| // Reads a flatbuffer from the stack. |
| int ReadFlatbuffer(ReflectiveFlatbuffer *buffer); |
| |
| // Runs a closure in protected mode. |
| // `func`: closure to run in protected mode. |
| // `num_lua_args`: number of arguments from the lua stack to process. |
| // `num_results`: number of result values pushed on the stack. |
| int RunProtected(const std::function<int()> &func, const int num_args = 0, |
| const int num_results = 0); |
| |
| lua_State *state() const { return state_; } |
| |
| protected: |
| lua_State *state_; |
| |
| private: |
| // Auxiliary methods to expose (reflective) flatbuffer based data to Lua. |
| static void PushFlatbuffer(const char *name, const reflection::Schema *schema, |
| const reflection::Object *type, |
| const flatbuffers::Table *table, lua_State *state); |
| static int GetFieldCallback(lua_State *state); |
| static int GetField(const reflection::Schema *schema, |
| const reflection::Object *type, |
| const flatbuffers::Table *table, lua_State *state); |
| |
| template <typename T, int (T::*handler)()> |
| static int Dispatch(lua_State *state) { |
| T *env = FromUpValue<T *>(1, state); |
| return ((*env).*handler)(); |
| } |
| }; |
| |
| bool Compile(StringPiece snippet, std::string *bytecode); |
| |
| } // namespace libtextclassifier3 |
| |
| #endif // LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_ |