// blob: 64071ca7f66b685d040acd8134863d5b11d5b2e1 [file] [log] [blame]
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "utils/lua-utils.h"
// Compatibility shim for lua_dump: Lua 5.3 takes an extra "strip" argument
// that Lua 5.2 does not. When TC3_AOSP is not defined, the macro below accepts
// the four-argument call form and discards the trailing strip argument, so
// call sites can always be written in the 5.3 style.
// NOTE(review): this presumes non-AOSP builds use Lua 5.2 -- confirm.
#ifndef TC3_AOSP
#define lua_dump(L, w, d, s) lua_dump((L), (w), (d))
#endif
namespace libtextclassifier3 {
namespace {
// Upvalue indices for the flatbuffer callback. They follow the order in which
// PushFlatbuffer pushes the schema, type and table pointers before it creates
// the GetFieldCallback closure, and are read back in GetFieldCallback.
static constexpr int kSchemaArgId = 1;
static constexpr int kTypeArgId = 2;
static constexpr int kTableArgId = 3;
// Libraries opened by LuaEnvironment::LoadDefaultLibraries. The list ends with
// a {nullptr, nullptr} sentinel entry, which stops the iteration there.
static constexpr luaL_Reg defaultlibs[] = {{"_G", luaopen_base},
{LUA_TABLIBNAME, luaopen_table},
{LUA_STRLIBNAME, luaopen_string},
{LUA_BITLIBNAME, luaopen_bit32},
{LUA_MATHLIBNAME, luaopen_math},
{nullptr, nullptr}};
// lua_Writer callback that accumulates the written bytes in a std::string.
// `result` must point to the std::string that receives the data; `state` is
// not used. Always returns LUA_OK.
int LuaStringWriter(lua_State *state, const void *data, size_t size,
                    void *result) {
  auto *const output = static_cast<std::string *>(result);
  const char *const bytes = static_cast<const char *>(data);
  output->append(bytes, size);
  return LUA_OK;
}
} // namespace
// Creates the Lua state used by this environment; the destructor closes it.
LuaEnvironment::LuaEnvironment() : state_(luaL_newstate()) {}
// Closes the Lua state, unless none is held.
LuaEnvironment::~LuaEnvironment() {
  if (state_ == nullptr) {
    return;
  }
  lua_close(state_);
}
// Lua C callback: fetches the Iterator through FromUpValue(kIteratorArgId) and
// forwards to its Next(), returning whatever Next() returns.
int LuaEnvironment::Iterator::NextCallback(lua_State *state) {
  const auto iterator = FromUpValue<Iterator *>(kIteratorArgId, state);
  return iterator->Next(state);
}
// Lua C callback: fetches the Iterator through FromUpValue(kIteratorArgId) and
// forwards to its Length(), returning whatever Length() returns.
int LuaEnvironment::Iterator::LengthCallback(lua_State *state) {
  const auto iterator = FromUpValue<Iterator *>(kIteratorArgId, state);
  return iterator->Length(state);
}
// Lua C callback: fetches the Iterator through FromUpValue(kIteratorArgId) and
// forwards to its Item(), returning whatever Item() returns.
int LuaEnvironment::Iterator::ItemCallback(lua_State *state) {
  const auto iterator = FromUpValue<Iterator *>(kIteratorArgId, state);
  return iterator->Item(state);
}
// Lua C callback: fetches the Iterator through FromUpValue(kIteratorArgId) and
// forwards to its Iteritems(), returning whatever Iteritems() returns.
int LuaEnvironment::Iterator::IteritemsCallback(lua_State *state) {
  const auto iterator = FromUpValue<Iterator *>(kIteratorArgId, state);
  return iterator->Iteritems(state);
}
// Pushes a Lua proxy table for the flatbuffer `table` onto the stack of
// `state`. The proxy carries a metatable whose kIndexKey entry is a
// GetFieldCallback closure holding `schema`, `type` and `table` as upvalues
// (in the order given by kSchemaArgId, kTypeArgId and kTableArgId).
//
// Every proxy gets its own metatable. Registering the metatable under `name`
// with luaL_newmetatable returns the same registry table for every proxy that
// uses that name, so installing the closure for a second flatbuffer would
// silently rebind all earlier proxies of that name to the newest `table`.
//
// `name` is kept for interface compatibility and is no longer used.
void LuaEnvironment::PushFlatbuffer(const char *name,
                                    const reflection::Schema *schema,
                                    const reflection::Object *type,
                                    const flatbuffers::Table *table,
                                    lua_State *state) {
  static_cast<void>(name);
  // The proxy table handed to the script.
  lua_newtable(state);
  // A private metatable for this proxy only.
  lua_newtable(state);
  // Upvalues of the callback; the order must match the k*ArgId constants.
  lua_pushlightuserdata(state, AsUserData(schema));
  lua_pushlightuserdata(state, AsUserData(type));
  lua_pushlightuserdata(state, AsUserData(table));
  lua_pushcclosure(state, &GetFieldCallback, 3);
  // metatable[kIndexKey] = closure; this pops the closure.
  lua_setfield(state, -2, kIndexKey);
  // Attach the metatable to the proxy; this pops the metatable and leaves the
  // proxy at the top of the stack.
  lua_setmetatable(state, -2);
}
// Lua C callback installed by PushFlatbuffer: reads the schema, type and table
// pointers back through FromUpValue and hands them to GetField.
int LuaEnvironment::GetFieldCallback(lua_State *state) {
  return GetField(FromUpValue<reflection::Schema *>(kSchemaArgId, state),
                  FromUpValue<reflection::Object *>(kTypeArgId, state),
                  FromUpValue<flatbuffers::Table *>(kTableArgId, state),
                  state);
}
// Looks up the field whose name is at the top of the Lua stack (index -1) in
// `type` and pushes its value, read from `table`, onto the stack.
//
// Supported field types: Bool, Int, Long, Float, Double, String (an unset
// string is pushed as the empty string) and Obj (pushed as a nested proxy via
// PushFlatbuffer).
//
// Returns 1, the number of values pushed. On failure (field name not
// convertible to a string, unknown field, unset Obj field, unsupported type)
// lua_error is called, which raises a Lua error with the current top of the
// stack as the error value; the `return 0` statements after it are not
// expected to be reached.
int LuaEnvironment::GetField(const reflection::Schema *schema,
                             const reflection::Object *type,
                             const flatbuffers::Table *table,
                             lua_State *state) {
  const char *field_name = lua_tostring(state, -1);
  if (field_name == nullptr) {
    // lua_tostring yields NULL when the key is neither a string nor a number;
    // a null C string must not be handed to LookupByKey.
    TC3_LOG(ERROR) << "Field name is not a string.";
    lua_error(state);
    return 0;
  }
  const reflection::Field *field = type->fields()->LookupByKey(field_name);
  if (field == nullptr) {
    lua_error(state);
    return 0;
  }
  // Provide primitive fields directly.
  const reflection::BaseType field_type = field->type()->base_type();
  switch (field_type) {
    case reflection::Bool:
      lua_pushboolean(state, table->GetField<uint8_t>(
                                 field->offset(), field->default_integer()));
      break;
    case reflection::Int:
      lua_pushinteger(state, table->GetField<int32>(field->offset(),
                                                    field->default_integer()));
      break;
    case reflection::Long:
      lua_pushinteger(state, table->GetField<int64>(field->offset(),
                                                    field->default_integer()));
      break;
    case reflection::Float:
      lua_pushnumber(state, table->GetField<float>(field->offset(),
                                                   field->default_real()));
      break;
    case reflection::Double:
      lua_pushnumber(state, table->GetField<double>(field->offset(),
                                                    field->default_real()));
      break;
    case reflection::String: {
      const flatbuffers::String *string_value =
          table->GetPointer<const flatbuffers::String *>(field->offset());
      if (string_value != nullptr) {
        lua_pushlstring(state, string_value->data(), string_value->Length());
      } else {
        lua_pushlstring(state, "", 0);
      }
      break;
    }
    case reflection::Obj: {
      const flatbuffers::Table *field_table =
          table->GetPointer<const flatbuffers::Table *>(field->offset());
      if (field_table == nullptr) {
        TC3_LOG(ERROR) << "Field was not set in entity data.";
        lua_error(state);
        return 0;
      }
      // Named differently from `field_type` above to avoid shadowing it.
      const reflection::Object *object_type =
          schema->objects()->Get(field->type()->index());
      PushFlatbuffer(field->name()->c_str(), schema, object_type, field_table,
                     state);
      break;
    }
    default:
      TC3_LOG(ERROR) << "Unsupported type: " << field_type;
      lua_error(state);
      return 0;
  }
  return 1;
}
// Copies the entries of the Lua table at the top of the stack into `buffer`.
//
// Every key must be a string naming a field of `buffer`. Bool, Int, Long,
// Float, Double and String values are stored with Set(); an Obj value must
// itself be a table and is read recursively into buffer->Mutable(field).
//
// The table is left at the top of the stack. Returns LUA_OK on success. On
// failure (null buffer, value at the top is not a table, non-string key,
// unknown field, unsupported field type) lua_error is called to raise a Lua
// error; the `return LUA_ERRRUN` statements after it are not expected to be
// reached.
int LuaEnvironment::ReadFlatbuffer(ReflectiveFlatbuffer *buffer) {
  if (buffer == nullptr) {
    TC3_LOG(ERROR) << "Called ReadFlatbuffer with null buffer.";
    lua_error(state_);
    return LUA_ERRRUN;
  }
  if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
    TC3_LOG(ERROR) << "Expected actions table, got: "
                   << lua_type(state_, /*idx=*/-1);
    lua_error(state_);
    return LUA_ERRRUN;
  }
  // Start the traversal: lua_next replaces the key at the top of the stack
  // with the next key/value pair of the table below it.
  lua_pushnil(state_);
  while (lua_next(state_, /*idx=*/-2)) {
    // Stack: ..., table, key (-2), value (-1).
    // lua_tolstring converts a non-string key in place, which confuses the
    // following lua_next call, so only read keys that already are strings.
    if (lua_type(state_, /*idx=*/-2) != LUA_TSTRING) {
      TC3_LOG(ERROR) << "Expected string key, got: "
                     << lua_type(state_, /*idx=*/-2);
      lua_error(state_);
      return LUA_ERRRUN;
    }
    const StringPiece key = ReadString(/*index=*/-2);
    const reflection::Field *field = buffer->GetFieldOrNull(key);
    if (field == nullptr) {
      TC3_LOG(ERROR) << "Unknown field: " << key.ToString();
      lua_error(state_);
      return LUA_ERRRUN;
    }
    switch (field->type()->base_type()) {
      case reflection::Obj: {
        // The nested table is the current value at the top of the stack. Keep
        // iterating afterwards: returning here would drop the remaining
        // fields and leave the key/value pair on the stack.
        const int status = ReadFlatbuffer(buffer->Mutable(field));
        if (status != LUA_OK) {
          return status;
        }
        break;
      }
      case reflection::Bool:
        buffer->Set(field,
                    static_cast<bool>(lua_toboolean(state_, /*idx=*/-1)));
        break;
      case reflection::Int:
        buffer->Set(field, static_cast<int>(lua_tonumber(state_, /*idx=*/-1)));
        break;
      case reflection::Long:
        buffer->Set(field,
                    static_cast<int64>(lua_tonumber(state_, /*idx=*/-1)));
        break;
      case reflection::Float:
        buffer->Set(field,
                    static_cast<float>(lua_tonumber(state_, /*idx=*/-1)));
        break;
      case reflection::Double:
        buffer->Set(field,
                    static_cast<double>(lua_tonumber(state_, /*idx=*/-1)));
        break;
      case reflection::String: {
        buffer->Set(field, ReadString(/*index=*/-1));
        break;
      }
      default:
        TC3_LOG(ERROR) << "Unsupported type: " << field->type()->base_type();
        lua_error(state_);
        return LUA_ERRRUN;
    }
    // Drop the value and keep the key for the next lua_next call.
    lua_pop(state_, 1);
  }
  // lua_next popped the last key when the traversal ended, so only the table
  // remains from this call.
  return LUA_OK;
}
void LuaEnvironment::LoadDefaultLibraries() {
for (const luaL_Reg *lib = defaultlibs; lib->func; lib++) {
luaL_requiref(state_, lib->name, lib->func, 1);
lua_pop(state_, 1); /* remove lib */
}
}
// Pushes `value` onto the Lua stack according to the type it holds: Int,
// Int64, Float and Double via lua_pushnumber, Bool via lua_pushboolean and
// String via lua_pushlstring. Any other content is logged as FATAL.
// NOTE(review): Int64 values go through lua_pushnumber (a lua_Number); confirm
// that any precision loss for large magnitudes is acceptable.
void LuaEnvironment::PushValue(const Variant &value) {
  if (value.HasInt()) {
    lua_pushnumber(state_, value.IntValue());
    return;
  }
  if (value.HasInt64()) {
    lua_pushnumber(state_, value.Int64Value());
    return;
  }
  if (value.HasBool()) {
    lua_pushboolean(state_, value.BoolValue());
    return;
  }
  if (value.HasFloat()) {
    lua_pushnumber(state_, value.FloatValue());
    return;
  }
  if (value.HasDouble()) {
    lua_pushnumber(state_, value.DoubleValue());
    return;
  }
  if (value.HasString()) {
    const auto &text = value.StringValue();
    lua_pushlstring(state_, text.data(), text.size());
    return;
  }
  TC3_LOG(FATAL) << "Unknown value type.";
}
// Returns a view of the Lua value at stack position `index` as produced by
// lua_tolstring. The view refers to memory handed out by Lua and is not
// copied. If lua_tolstring yields no string, the view has a null data pointer
// and the length stays at its initial value of 0.
StringPiece LuaEnvironment::ReadString(const int index) const {
  size_t num_bytes = 0;
  const char *const bytes = lua_tolstring(state_, index, &num_bytes);
  return StringPiece(bytes, num_bytes);
}
// Pushes the bytes referenced by `str` onto the Lua stack as a Lua string.
void LuaEnvironment::PushString(const StringPiece str) {
  const auto bytes = str.data();
  const auto num_bytes = str.size();
  lua_pushlstring(state_, bytes, num_bytes);
}
// Pushes `table` onto this environment's Lua stack, using the root table of
// `schema` as its type and the root table's name as the name argument.
void LuaEnvironment::PushFlatbuffer(const reflection::Schema *schema,
                                    const flatbuffers::Table *table) {
  const reflection::Object *const root_type = schema->root_table();
  const char *const root_name = root_type->name()->c_str();
  PushFlatbuffer(root_name, schema, root_type, table, state_);
}
// Runs `func` through lua_pcall. The top `num_args` stack entries are passed
// as the call's arguments and `num_results` results are requested. Returns the
// status code reported by lua_pcall.
int LuaEnvironment::RunProtected(const std::function<int()> &func,
                                 const int num_args, const int num_results) {
  // Holds a copy of the callable; its address is passed to the C entry point
  // below as the closure's only upvalue.
  struct CallContext {
    std::function<int()> callable;

    static int Invoke(lua_State *lua) {
      void *const raw = lua_touserdata(lua, lua_upvalueindex(1));
      CallContext *const context = static_cast<CallContext *>(raw);
      return context->callable();
    }
  };

  CallContext context{func};
  lua_pushlightuserdata(state_, &context);
  lua_pushcclosure(state_, &CallContext::Invoke, /*n=*/1);

  // lua_pcall expects the function below its arguments, so move the closure
  // underneath them when there are any.
  const bool has_arguments = num_args > 0;
  if (has_arguments) {
    lua_insert(state_, -(num_args + 1));
  }
  return lua_pcall(state_, num_args, num_results, /*errorfunc=*/0);
}
// Compiles `snippet` and appends the dumped chunk to `*bytecode` through
// LuaStringWriter. Returns true on success; on a load or dump failure the
// problem is logged and false is returned. In every case the value that the
// load step pushed (compiled chunk or error message) is popped again.
bool LuaEnvironment::Compile(StringPiece snippet, std::string *bytecode) {
  const int load_status = luaL_loadbuffer(state_, snippet.data(),
                                          snippet.size(), /*name=*/nullptr);
  if (load_status != LUA_OK) {
    // The error message is at the top of the stack.
    TC3_LOG(ERROR) << "Could not compile lua snippet: "
                   << ReadString(/*index=*/-1).ToString();
    lua_pop(state_, 1);
    return false;
  }

  const bool dumped =
      lua_dump(state_, LuaStringWriter, bytecode, /*strip*/ 1) == LUA_OK;
  if (!dumped) {
    TC3_LOG(ERROR) << "Could not dump compiled lua snippet.";
  }
  // Remove the compiled chunk, whether or not the dump succeeded.
  lua_pop(state_, 1);
  return dumped;
}
// Convenience wrapper: compiles `snippet` into `*bytecode` using a
// LuaEnvironment that lives only for the duration of this call.
bool Compile(StringPiece snippet, std::string *bytecode) {
  LuaEnvironment environment;
  return environment.Compile(snippet, bytecode);
}
} // namespace libtextclassifier3