| /* |
| * Copyright (C) 2018 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| // Utility functions for working with FlatBuffers. |
| |
| #ifndef LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_H_ |
| #define LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_H_ |
| |
| #include <map> |
| #include <memory> |
| #include <string> |
| |
| #include "annotator/model_generated.h" |
| #include "utils/strings/stringpiece.h" |
| #include "utils/variant.h" |
| #include "flatbuffers/flatbuffers.h" |
| #include "flatbuffers/reflection.h" |
| |
| namespace libtextclassifier3 { |
| |
| // Loads and interprets the buffer as 'FlatbufferMessage' and verifies its |
| // integrity. |
| template <typename FlatbufferMessage> |
| const FlatbufferMessage* LoadAndVerifyFlatbuffer(const void* buffer, int size) { |
| const FlatbufferMessage* message = |
| flatbuffers::GetRoot<FlatbufferMessage>(buffer); |
| if (message == nullptr) { |
| return nullptr; |
| } |
| flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(buffer), |
| size); |
| if (message->Verify(verifier)) { |
| return message; |
| } else { |
| return nullptr; |
| } |
| } |
| |
| // Same as above but takes string. |
| template <typename FlatbufferMessage> |
| const FlatbufferMessage* LoadAndVerifyFlatbuffer(const std::string& buffer) { |
| return LoadAndVerifyFlatbuffer<FlatbufferMessage>(buffer.c_str(), |
| buffer.size()); |
| } |
| |
| // Loads and interprets the buffer as 'FlatbufferMessage', verifies its |
| // integrity and returns its mutable version. |
| template <typename FlatbufferMessage> |
| std::unique_ptr<typename FlatbufferMessage::NativeTableType> |
| LoadAndVerifyMutableFlatbuffer(const void* buffer, int size) { |
| const FlatbufferMessage* message = |
| LoadAndVerifyFlatbuffer<FlatbufferMessage>(buffer, size); |
| if (message == nullptr) { |
| return nullptr; |
| } |
| return std::unique_ptr<typename FlatbufferMessage::NativeTableType>( |
| message->UnPack()); |
| } |
| |
| // Same as above but takes string. |
| template <typename FlatbufferMessage> |
| std::unique_ptr<typename FlatbufferMessage::NativeTableType> |
| LoadAndVerifyMutableFlatbuffer(const std::string& buffer) { |
| return LoadAndVerifyMutableFlatbuffer<FlatbufferMessage>(buffer.c_str(), |
| buffer.size()); |
| } |
| |
| template <typename FlatbufferMessage> |
| const char* FlatbufferFileIdentifier() { |
| return nullptr; |
| } |
| |
| template <> |
| const char* FlatbufferFileIdentifier<Model>(); |
| |
| // Packs the mutable flatbuffer message to string. |
| template <typename FlatbufferMessage> |
| std::string PackFlatbuffer( |
| const typename FlatbufferMessage::NativeTableType* mutable_message) { |
| flatbuffers::FlatBufferBuilder builder; |
| builder.Finish(FlatbufferMessage::Pack(builder, mutable_message), |
| FlatbufferFileIdentifier<FlatbufferMessage>()); |
| return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()), |
| builder.GetSize()); |
| } |
| |
| // A flatbuffer that can be built using flatbuffer reflection data of the |
| // schema. |
| // Normally, field information is hard-coded in code generated from a flatbuffer |
| // schema. Here we lookup the necessary information for building a flatbuffer |
| // from the provided reflection meta data. |
| // When serializing a flatbuffer, the library requires that the sub messages |
| // are already serialized, therefore we explicitly keep the field values and |
| // serialize the message in (reverse) topological dependency order. |
| class ReflectiveFlatbuffer { |
| public: |
| ReflectiveFlatbuffer(const reflection::Schema* schema, |
| const reflection::Object* type) |
| : schema_(schema), type_(type) {} |
| |
| // Encapsulates a repeated field. |
| // Serves as a common base class for repeated fields. |
| class RepeatedField { |
| public: |
| virtual ~RepeatedField() {} |
| |
| virtual flatbuffers::uoffset_t Serialize( |
| flatbuffers::FlatBufferBuilder* builder) const = 0; |
| }; |
| |
| // Represents a repeated field of particular type. |
| template <typename T> |
| class TypedRepeatedField : public RepeatedField { |
| public: |
| void Add(const T value) { items_.push_back(value); } |
| |
| flatbuffers::uoffset_t Serialize( |
| flatbuffers::FlatBufferBuilder* builder) const override { |
| return builder->CreateVector(items_).o; |
| } |
| |
| private: |
| std::vector<T> items_; |
| }; |
| |
| // Specialization for strings. |
| template <> |
| class TypedRepeatedField<std::string> : public RepeatedField { |
| public: |
| void Add(const std::string& value) { items_.push_back(value); } |
| |
| flatbuffers::uoffset_t Serialize( |
| flatbuffers::FlatBufferBuilder* builder) const override { |
| std::vector<flatbuffers::Offset<flatbuffers::String>> offsets( |
| items_.size()); |
| for (int i = 0; i < items_.size(); i++) { |
| offsets[i] = builder->CreateString(items_[i]); |
| } |
| return builder->CreateVector(offsets).o; |
| } |
| |
| private: |
| std::vector<std::string> items_; |
| }; |
| |
| // Specialization for repeated sub-messages. |
| template <> |
| class TypedRepeatedField<ReflectiveFlatbuffer> : public RepeatedField { |
| public: |
| TypedRepeatedField<ReflectiveFlatbuffer>( |
| const reflection::Schema* const schema, |
| const reflection::Type* const type) |
| : schema_(schema), type_(type) {} |
| |
| ReflectiveFlatbuffer* Add() { |
| items_.emplace_back(new ReflectiveFlatbuffer( |
| schema_, schema_->objects()->Get(type_->index()))); |
| return items_.back().get(); |
| } |
| |
| flatbuffers::uoffset_t Serialize( |
| flatbuffers::FlatBufferBuilder* builder) const override { |
| std::vector<flatbuffers::Offset<void>> offsets(items_.size()); |
| for (int i = 0; i < items_.size(); i++) { |
| offsets[i] = items_[i]->Serialize(builder); |
| } |
| return builder->CreateVector(offsets).o; |
| } |
| |
| private: |
| const reflection::Schema* const schema_; |
| const reflection::Type* const type_; |
| std::vector<std::unique_ptr<ReflectiveFlatbuffer>> items_; |
| }; |
| |
| // Gets the field information for a field name, returns nullptr if the |
| // field was not defined. |
| const reflection::Field* GetFieldOrNull(const StringPiece field_name) const; |
| const reflection::Field* GetFieldOrNull(const FlatbufferField* field) const; |
| const reflection::Field* GetFieldByOffsetOrNull(const int field_offset) const; |
| |
| // Gets a nested field and the message it is defined on. |
| bool GetFieldWithParent(const FlatbufferFieldPath* field_path, |
| ReflectiveFlatbuffer** parent, |
| reflection::Field const** field); |
| |
| // Checks whether a variant value type agrees with a field type. |
| bool IsMatchingType(const reflection::Field* field, |
| const Variant& value) const; |
| |
| // Sets a (primitive) field to a specific value. |
| // Returns true if successful, and false if the field was not found or the |
| // expected type doesn't match. |
| template <typename T> |
| bool Set(StringPiece field_name, T value) { |
| if (const reflection::Field* field = GetFieldOrNull(field_name)) { |
| return Set<T>(field, value); |
| } |
| return false; |
| } |
| |
| // Sets a (primitive) field to a specific value. |
| // Returns true if successful, and false if the expected type doesn't match. |
| // Expects `field` to be non-null. |
| template <typename T> |
| bool Set(const reflection::Field* field, T value) { |
| if (field == nullptr) { |
| TC3_LOG(ERROR) << "Expected non-null field."; |
| return false; |
| } |
| Variant variant_value(value); |
| if (!IsMatchingType(field, variant_value)) { |
| TC3_LOG(ERROR) << "Type mismatch for field `" << field->name()->str() |
| << "`, expected: " << field->type()->base_type() |
| << ", got: " << variant_value.GetType(); |
| return false; |
| } |
| fields_[field] = variant_value; |
| return true; |
| } |
| |
| template <typename T> |
| bool Set(const FlatbufferFieldPath* path, T value) { |
| ReflectiveFlatbuffer* parent; |
| const reflection::Field* field; |
| if (!GetFieldWithParent(path, &parent, &field)) { |
| return false; |
| } |
| return parent->Set<T>(field, value); |
| } |
| |
| // Sets a (primitive) field to a specific value. |
| // Parses the string value according to the field type. |
| bool ParseAndSet(const reflection::Field* field, const std::string& value); |
| bool ParseAndSet(const FlatbufferFieldPath* path, const std::string& value); |
| |
| // Gets the reflective flatbuffer for a table field. |
| // Returns nullptr if the field was not found, or the field type was not a |
| // table. |
| ReflectiveFlatbuffer* Mutable(StringPiece field_name); |
| ReflectiveFlatbuffer* Mutable(const reflection::Field* field); |
| |
| // Gets the reflective flatbuffer for a repeated field. |
| // Returns nullptr if the field was not found, or the field type was not a |
| // vector. |
| RepeatedField* Repeated(StringPiece field_name); |
| RepeatedField* Repeated(const reflection::Field* field); |
| |
| template <typename T> |
| TypedRepeatedField<T>* Repeated(const reflection::Field* field) { |
| return static_cast<TypedRepeatedField<T>*>(Repeated(field)); |
| } |
| |
| template <typename T> |
| TypedRepeatedField<T>* Repeated(StringPiece field_name) { |
| return static_cast<TypedRepeatedField<T>*>(Repeated(field_name)); |
| } |
| |
| // Serializes the flatbuffer. |
| flatbuffers::uoffset_t Serialize( |
| flatbuffers::FlatBufferBuilder* builder) const; |
| std::string Serialize() const; |
| |
| // Merges the fields from the given flatbuffer table into this flatbuffer. |
| // Scalar fields will be overwritten, if present in `from`. |
| // Embedded messages will be merged. |
| bool MergeFrom(const flatbuffers::Table* from); |
| bool MergeFromSerializedFlatbuffer(StringPiece from); |
| |
| // Flattens the flatbuffer as a flat map. |
| // (Nested) fields names are joined by `key_separator`. |
| std::map<std::string, Variant> AsFlatMap( |
| const std::string& key_separator = ".") const { |
| std::map<std::string, Variant> result; |
| AsFlatMap(key_separator, /*key_prefix=*/"", &result); |
| return result; |
| } |
| |
| private: |
| const reflection::Schema* const schema_; |
| const reflection::Object* const type_; |
| |
| // Cached primitive fields (scalars and strings). |
| std::map<const reflection::Field*, Variant> fields_; |
| |
| // Cached sub-messages. |
| std::map<const reflection::Field*, std::unique_ptr<ReflectiveFlatbuffer>> |
| children_; |
| |
| // Cached repeated fields. |
| std::map<const reflection::Field*, std::unique_ptr<RepeatedField>> |
| repeated_fields_; |
| |
| // Flattens the flatbuffer as a flat map. |
| // (Nested) fields names are joined by `key_separator` and prefixed by |
| // `key_prefix`. |
| void AsFlatMap(const std::string& key_separator, |
| const std::string& key_prefix, |
| std::map<std::string, Variant>* result) const; |
| }; |
| |
| // A helper class to build flatbuffers based on schema reflection data. |
| // Can be used to a `ReflectiveFlatbuffer` for the root message of the |
| // schema, or any defined table via name. |
| class ReflectiveFlatbufferBuilder { |
| public: |
| explicit ReflectiveFlatbufferBuilder(const reflection::Schema* schema) |
| : schema_(schema) {} |
| |
| // Starts a new root table message. |
| std::unique_ptr<ReflectiveFlatbuffer> NewRoot() const; |
| |
| // Starts a new table message. Returns nullptr if no table with given name is |
| // found in the schema. |
| std::unique_ptr<ReflectiveFlatbuffer> NewTable( |
| const StringPiece table_name) const; |
| |
| private: |
| const reflection::Schema* const schema_; |
| }; |
| |
| } // namespace libtextclassifier3 |
| |
| #endif // LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_H_ |