| /* |
| * Copyright (C) 2018 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #ifndef LIBTEXTCLASSIFIER_ACTIONS_LUA_UTILS_H_ |
| #define LIBTEXTCLASSIFIER_ACTIONS_LUA_UTILS_H_ |
| |
| #include "actions/types.h" |
| #include "annotator/types.h" |
| #include "utils/flatbuffers.h" |
| #include "utils/lua-utils.h" |
| |
| #ifdef __cplusplus |
| extern "C" { |
| #endif |
| #include "lauxlib.h" |
| #include "lua.h" |
| #include "lualib.h" |
| #ifdef __cplusplus |
| } |
| #endif |
| |
| // Action specific shared lua utilities. |
| namespace libtextclassifier3 { |
| |
| // Provides an annotation to lua. |
| void PushAnnotation(const ClassificationResult& classification, |
| const reflection::Schema* entity_data_schema, |
| LuaEnvironment* env); |
| void PushAnnotation(const ClassificationResult& classification, |
| StringPiece text, |
| const reflection::Schema* entity_data_schema, |
| LuaEnvironment* env); |
| void PushAnnotation(const ActionSuggestionAnnotation& annotation, |
| const reflection::Schema* entity_data_schema, |
| LuaEnvironment* env); |
| |
| // A lua iterator to enumerate annotation. |
| template <typename Annotation> |
| class AnnotationIterator |
| : public LuaEnvironment::ItemIterator<std::vector<Annotation>> { |
| public: |
| AnnotationIterator(const reflection::Schema* entity_data_schema, |
| LuaEnvironment* env) |
| : env_(env), entity_data_schema_(entity_data_schema) {} |
| int Item(const std::vector<Annotation>* annotations, const int64 pos, |
| lua_State* state) const override { |
| PushAnnotation((*annotations)[pos], entity_data_schema_, env_); |
| return 1; |
| } |
| int Item(const std::vector<Annotation>* annotations, StringPiece key, |
| lua_State* state) const override; |
| |
| private: |
| LuaEnvironment* env_; |
| const reflection::Schema* entity_data_schema_; |
| }; |
| |
| template <> |
| int AnnotationIterator<ClassificationResult>::Item( |
| const std::vector<ClassificationResult>* annotations, StringPiece key, |
| lua_State* state) const; |
| |
| template <> |
| int AnnotationIterator<ActionSuggestionAnnotation>::Item( |
| const std::vector<ActionSuggestionAnnotation>* annotations, StringPiece key, |
| lua_State* state) const; |
| |
| void PushAnnotatedSpan( |
| const AnnotatedSpan& annotated_span, |
| const AnnotationIterator<ClassificationResult>& annotation_iterator, |
| LuaEnvironment* env); |
| |
| MessageTextSpan ReadSpan(LuaEnvironment* env); |
| ActionSuggestionAnnotation ReadAnnotation( |
| const reflection::Schema* entity_data_schema, LuaEnvironment* env); |
| int ReadAnnotations(const reflection::Schema* entity_data_schema, |
| LuaEnvironment* env, |
| std::vector<ActionSuggestionAnnotation>* annotations); |
| ClassificationResult ReadClassificationResult( |
| const reflection::Schema* entity_data_schema, LuaEnvironment* env); |
| |
| // A lua iterator to enumerate annotated spans. |
| class AnnotatedSpanIterator |
| : public LuaEnvironment::ItemIterator<std::vector<AnnotatedSpan>> { |
| public: |
| AnnotatedSpanIterator( |
| const AnnotationIterator<ClassificationResult>& annotation_iterator, |
| LuaEnvironment* env) |
| : env_(env), annotation_iterator_(annotation_iterator) {} |
| AnnotatedSpanIterator(const reflection::Schema* entity_data_schema, |
| LuaEnvironment* env) |
| : env_(env), annotation_iterator_(entity_data_schema, env) {} |
| |
| int Item(const std::vector<AnnotatedSpan>* spans, const int64 pos, |
| lua_State* state) const override { |
| PushAnnotatedSpan((*spans)[pos], annotation_iterator_, env_); |
| return /*num results=*/1; |
| } |
| |
| private: |
| LuaEnvironment* env_; |
| AnnotationIterator<ClassificationResult> annotation_iterator_; |
| }; |
| |
| // Provides an action to lua. |
| void PushAction( |
| const ActionSuggestion& action, |
| const reflection::Schema* entity_data_schema, |
| const AnnotationIterator<ActionSuggestionAnnotation>& annotation_iterator, |
| LuaEnvironment* env); |
| |
| ActionSuggestion ReadAction( |
| const reflection::Schema* actions_entity_data_schema, |
| const reflection::Schema* annotations_entity_data_schema, |
| LuaEnvironment* env); |
| int ReadActions(const reflection::Schema* actions_entity_data_schema, |
| const reflection::Schema* annotations_entity_data_schema, |
| LuaEnvironment* env, std::vector<ActionSuggestion>* actions); |
| |
| // A lua iterator to enumerate actions suggestions. |
| class ActionsIterator |
| : public LuaEnvironment::ItemIterator<std::vector<ActionSuggestion>> { |
| public: |
| ActionsIterator(const reflection::Schema* entity_data_schema, |
| const reflection::Schema* annotations_entity_data_schema, |
| LuaEnvironment* env) |
| : env_(env), |
| entity_data_schema_(entity_data_schema), |
| annotation_iterator_(annotations_entity_data_schema, env) {} |
| int Item(const std::vector<ActionSuggestion>* actions, const int64 pos, |
| lua_State* state) const override { |
| PushAction((*actions)[pos], entity_data_schema_, annotation_iterator_, |
| env_); |
| return /*num results=*/1; |
| } |
| |
| private: |
| LuaEnvironment* env_; |
| const reflection::Schema* entity_data_schema_; |
| AnnotationIterator<ActionSuggestionAnnotation> annotation_iterator_; |
| }; |
| |
| // Conversation message lua iterator. |
| class ConversationIterator |
| : public LuaEnvironment::ItemIterator<std::vector<ConversationMessage>> { |
| public: |
| ConversationIterator( |
| const AnnotationIterator<ClassificationResult>& annotation_iterator, |
| LuaEnvironment* env) |
| : env_(env), |
| annotated_span_iterator_( |
| AnnotatedSpanIterator(annotation_iterator, env)) {} |
| ConversationIterator(const reflection::Schema* entity_data_schema, |
| LuaEnvironment* env) |
| : env_(env), |
| annotated_span_iterator_( |
| AnnotatedSpanIterator(entity_data_schema, env)) {} |
| |
| int Item(const std::vector<ConversationMessage>* messages, const int64 pos, |
| lua_State* state) const override; |
| |
| private: |
| LuaEnvironment* env_; |
| AnnotatedSpanIterator annotated_span_iterator_; |
| }; |
| |
| } // namespace libtextclassifier3 |
| |
| #endif // LIBTEXTCLASSIFIER_ACTIONS_LUA_UTILS_H_ |