blob: 4f06674f492a1e4eb90504621e62a7256d382be9 [file] [log] [blame]
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LIBTEXTCLASSIFIER_ACTIONS_LUA_UTILS_H_
#define LIBTEXTCLASSIFIER_ACTIONS_LUA_UTILS_H_
#include "actions/types.h"
#include "annotator/types.h"
#include "utils/flatbuffers.h"
#include "utils/lua-utils.h"
#ifdef __cplusplus
extern "C" {
#endif
#include "lauxlib.h"
#include "lua.h"
#include "lualib.h"
#ifdef __cplusplus
}
#endif
// Action-specific shared Lua utilities.
namespace libtextclassifier3 {
// Provides an annotation to lua.
//
// Three overloads are declared here; their definitions are not part of this
// header. NOTE(review): presumably each overload pushes exactly one Lua value
// describing the annotation onto the stack of `env`. That would be consistent
// with `AnnotationIterator::Item`, which calls these and then reports one
// result. Confirm against the implementation.
//
// `entity_data_schema` is passed through for the annotation's entity data.
// TODO confirm whether nullptr is accepted.
void PushAnnotation(const ClassificationResult& classification,
const reflection::Schema* entity_data_schema,
LuaEnvironment* env);
// Overload that additionally takes `text`. NOTE(review): presumably this is
// the text the classification refers to; confirm in the implementation.
void PushAnnotation(const ClassificationResult& classification,
StringPiece text,
const reflection::Schema* entity_data_schema,
LuaEnvironment* env);
// Overload taking an `ActionSuggestionAnnotation` instead of a bare
// `ClassificationResult`.
void PushAnnotation(const ActionSuggestionAnnotation& annotation,
const reflection::Schema* entity_data_schema,
LuaEnvironment* env);
// Lua iterator that exposes a vector of annotations to lua, one element per
// integer index.
//
// `Annotation` must be a type for which a
// `PushAnnotation(const Annotation&, const reflection::Schema*, LuaEnvironment*)`
// overload exists. This header declares such overloads for
// ClassificationResult and ActionSuggestionAnnotation.
template <typename Annotation>
class AnnotationIterator
    : public LuaEnvironment::ItemIterator<std::vector<Annotation>> {
 public:
  // Both pointers are stored as-is. The iterator never frees them.
  AnnotationIterator(const reflection::Schema* entity_data_schema,
                     LuaEnvironment* env)
      : env_{env}, entity_data_schema_{entity_data_schema} {}

  // Pushes the annotation stored at index `pos` via `env_` and reports a
  // single result. `pos` is not range-checked here. NOTE(review): presumably
  // the ItemIterator base guarantees it is in bounds; confirm. `state` is
  // unused because the push goes through `env_`.
  int Item(const std::vector<Annotation>* annotations, const int64 pos,
           lua_State* state) const override {
    const Annotation& selected = (*annotations)[pos];
    PushAnnotation(selected, entity_data_schema_, env_);
    return /*num results=*/1;
  }

  // Key-based lookup. It is only given explicit specializations (declared
  // after this class) for the supported annotation types.
  int Item(const std::vector<Annotation>* annotations, StringPiece key,
           lua_State* state) const override;

 private:
  // Not owned.
  LuaEnvironment* env_;
  // Not owned; forwarded unchanged to every PushAnnotation call.
  const reflection::Schema* entity_data_schema_;
};
// Explicit specializations of the key-based `AnnotationIterator::Item`
// lookup. Only these two annotation types get a key-based accessor in this
// header, and their definitions are not visible here.
// NOTE(review): the primary template declares `Item(..., StringPiece key, ...)`
// without defining it. Instantiating AnnotationIterator with any other
// Annotation type would therefore presumably fail to link; confirm.
template <>
int AnnotationIterator<ClassificationResult>::Item(
const std::vector<ClassificationResult>* annotations, StringPiece key,
lua_State* state) const;
template <>
int AnnotationIterator<ActionSuggestionAnnotation>::Item(
const std::vector<ActionSuggestionAnnotation>* annotations, StringPiece key,
lua_State* state) const;
// Pushes `annotated_span` to lua. `annotation_iterator` is passed along,
// presumably to expose the span's classification results. Confirm in the
// implementation, which is not in this header.
void PushAnnotatedSpan(
const AnnotatedSpan& annotated_span,
const AnnotationIterator<ClassificationResult>& annotation_iterator,
LuaEnvironment* env);
// Reads a MessageTextSpan from lua.
// NOTE(review): which stack slot is read (presumably the top of `env`'s stack)
// is not visible here; confirm.
MessageTextSpan ReadSpan(LuaEnvironment* env);
// Reads a single ActionSuggestionAnnotation from lua. `entity_data_schema` is
// presumably used to decode its entity data; confirm.
ActionSuggestionAnnotation ReadAnnotation(
const reflection::Schema* entity_data_schema, LuaEnvironment* env);
// Reads a list of annotations from lua into `*annotations`.
// NOTE(review): this header does not show whether the vector is appended to or
// overwritten, nor what the returned int means (status code vs. count).
// Confirm against the implementation.
int ReadAnnotations(const reflection::Schema* entity_data_schema,
LuaEnvironment* env,
std::vector<ActionSuggestionAnnotation>* annotations);
// Reads a ClassificationResult from lua. `entity_data_schema` is presumably
// used for its entity data; confirm.
ClassificationResult ReadClassificationResult(
const reflection::Schema* entity_data_schema, LuaEnvironment* env);
// Lua iterator that exposes a vector of annotated spans to lua, one element
// per integer index.
class AnnotatedSpanIterator
    : public LuaEnvironment::ItemIterator<std::vector<AnnotatedSpan>> {
 public:
  // Stores a copy of `annotation_iterator`, which is handed to
  // PushAnnotatedSpan for every span.
  AnnotatedSpanIterator(
      const AnnotationIterator<ClassificationResult>& annotation_iterator,
      LuaEnvironment* env)
      : env_(env), annotation_iterator_(annotation_iterator) {}

  // Convenience constructor: builds the annotation iterator from
  // `entity_data_schema` and `env`, then delegates to the constructor above.
  AnnotatedSpanIterator(const reflection::Schema* entity_data_schema,
                        LuaEnvironment* env)
      : AnnotatedSpanIterator(
            AnnotationIterator<ClassificationResult>(entity_data_schema, env),
            env) {}

  // Pushes the span stored at index `pos` via `env_` and reports a single
  // result. `pos` is not range-checked here, and `state` is unused.
  int Item(const std::vector<AnnotatedSpan>* spans, const int64 pos,
           lua_State* state) const override {
    const AnnotatedSpan& span = (*spans)[pos];
    PushAnnotatedSpan(span, annotation_iterator_, env_);
    return /*num results=*/1;
  }

 private:
  // Not owned.
  LuaEnvironment* env_;
  // Held by value: a copy of the caller's iterator, or one built from a schema.
  AnnotationIterator<ClassificationResult> annotation_iterator_;
};
// Provides an action to lua.
// `entity_data_schema` is passed for the action itself. The annotations are
// handled through `annotation_iterator`; ActionsIterator builds that iterator
// from a separate annotations schema before calling this function.
// NOTE(review): presumably pushes exactly one Lua value onto `env`'s stack,
// since ActionsIterator::Item reports one result after calling it. Confirm.
void PushAction(
const ActionSuggestion& action,
const reflection::Schema* entity_data_schema,
const AnnotationIterator<ActionSuggestionAnnotation>& annotation_iterator,
LuaEnvironment* env);
// Reads an ActionSuggestion from lua. Two schemas are taken: one for the
// action's entity data and one for its annotations' entity data (per the
// parameter names; confirm in the implementation).
ActionSuggestion ReadAction(
const reflection::Schema* actions_entity_data_schema,
const reflection::Schema* annotations_entity_data_schema,
LuaEnvironment* env);
// Reads a list of actions from lua into `*actions`.
// NOTE(review): the meaning of the returned int and whether `*actions` is
// appended to or overwritten are not visible here; confirm.
int ReadActions(const reflection::Schema* actions_entity_data_schema,
const reflection::Schema* annotations_entity_data_schema,
LuaEnvironment* env, std::vector<ActionSuggestion>* actions);
// Lua iterator that exposes a vector of action suggestions to lua, one
// element per integer index.
class ActionsIterator
    : public LuaEnvironment::ItemIterator<std::vector<ActionSuggestion>> {
 public:
  // `entity_data_schema` is forwarded to PushAction for each action.
  // `annotations_entity_data_schema` is used to build the
  // AnnotationIterator<ActionSuggestionAnnotation> that PushAction receives.
  // None of the pointers are owned.
  ActionsIterator(const reflection::Schema* entity_data_schema,
                  const reflection::Schema* annotations_entity_data_schema,
                  LuaEnvironment* env)
      : env_{env},
        entity_data_schema_{entity_data_schema},
        annotation_iterator_{annotations_entity_data_schema, env} {}

  // Pushes the action stored at index `pos` via `env_` and reports a single
  // result. `pos` is not range-checked here, and `state` is unused.
  int Item(const std::vector<ActionSuggestion>* actions, const int64 pos,
           lua_State* state) const override {
    const ActionSuggestion& action = (*actions)[pos];
    PushAction(action, entity_data_schema_, annotation_iterator_, env_);
    return /*num results=*/1;
  }

 private:
  // Not owned.
  LuaEnvironment* env_;
  // Schema for the actions themselves; not owned.
  const reflection::Schema* entity_data_schema_;
  // Built from the annotations schema in the constructor.
  AnnotationIterator<ActionSuggestionAnnotation> annotation_iterator_;
};
// Lua iterator that exposes a vector of conversation messages to lua.
class ConversationIterator
    : public LuaEnvironment::ItemIterator<std::vector<ConversationMessage>> {
 public:
  // Builds the span iterator from a copy of `annotation_iterator`.
  // The member is initialized directly rather than from a temporary.
  ConversationIterator(
      const AnnotationIterator<ClassificationResult>& annotation_iterator,
      LuaEnvironment* env)
      : env_(env), annotated_span_iterator_(annotation_iterator, env) {}

  // Builds the span iterator (and its annotation iterator) from
  // `entity_data_schema`.
  ConversationIterator(const reflection::Schema* entity_data_schema,
                       LuaEnvironment* env)
      : env_(env), annotated_span_iterator_(entity_data_schema, env) {}

  // Declared here; the definition is not in this header.
  // NOTE(review): presumably it uses `annotated_span_iterator_` to expose each
  // message's annotated spans. Confirm.
  int Item(const std::vector<ConversationMessage>* messages, const int64 pos,
           lua_State* state) const override;

 private:
  // Not owned.
  LuaEnvironment* env_;
  AnnotatedSpanIterator annotated_span_iterator_;
};
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_ACTIONS_LUA_UTILS_H_