/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
| |
| #include "utils/flatbuffers.h" |
| |
| #include <vector> |
| #include "utils/strings/numbers.h" |
| #include "utils/variant.h" |
| |
| namespace libtextclassifier3 { |
| namespace { |
| bool CreateRepeatedField( |
| const reflection::Schema* schema, const reflection::Type* type, |
| std::unique_ptr<ReflectiveFlatbuffer::RepeatedField>* repeated_field) { |
| switch (type->element()) { |
| case reflection::Bool: |
| repeated_field->reset(new ReflectiveFlatbuffer::TypedRepeatedField<bool>); |
| return true; |
| case reflection::Int: |
| repeated_field->reset(new ReflectiveFlatbuffer::TypedRepeatedField<int>); |
| return true; |
| case reflection::Long: |
| repeated_field->reset( |
| new ReflectiveFlatbuffer::TypedRepeatedField<int64>); |
| return true; |
| case reflection::Float: |
| repeated_field->reset( |
| new ReflectiveFlatbuffer::TypedRepeatedField<float>); |
| return true; |
| case reflection::Double: |
| repeated_field->reset( |
| new ReflectiveFlatbuffer::TypedRepeatedField<double>); |
| return true; |
| case reflection::String: |
| repeated_field->reset( |
| new ReflectiveFlatbuffer::TypedRepeatedField<std::string>); |
| return true; |
| case reflection::Obj: |
| repeated_field->reset( |
| new ReflectiveFlatbuffer::TypedRepeatedField<ReflectiveFlatbuffer>( |
| schema, type)); |
| return true; |
| default: |
| TC3_LOG(ERROR) << "Unsupported type: " << type->element(); |
| return false; |
| } |
| } |
| } // namespace |
| |
| template <> |
| const char* FlatbufferFileIdentifier<Model>() { |
| return ModelIdentifier(); |
| } |
| |
| std::unique_ptr<ReflectiveFlatbuffer> ReflectiveFlatbufferBuilder::NewRoot() |
| const { |
| if (!schema_->root_table()) { |
| TC3_LOG(ERROR) << "No root table specified."; |
| return nullptr; |
| } |
| return std::unique_ptr<ReflectiveFlatbuffer>( |
| new ReflectiveFlatbuffer(schema_, schema_->root_table())); |
| } |
| |
| std::unique_ptr<ReflectiveFlatbuffer> ReflectiveFlatbufferBuilder::NewTable( |
| StringPiece table_name) const { |
| for (const reflection::Object* object : *schema_->objects()) { |
| if (table_name.Equals(object->name()->str())) { |
| return std::unique_ptr<ReflectiveFlatbuffer>( |
| new ReflectiveFlatbuffer(schema_, object)); |
| } |
| } |
| return nullptr; |
| } |
| |
| const reflection::Field* ReflectiveFlatbuffer::GetFieldOrNull( |
| const StringPiece field_name) const { |
| return type_->fields()->LookupByKey(field_name.data()); |
| } |
| |
| const reflection::Field* ReflectiveFlatbuffer::GetFieldOrNull( |
| const FlatbufferField* field) const { |
| // Lookup by name might be faster as the fields are sorted by name in the |
| // schema data, so try that first. |
| if (field->field_name() != nullptr) { |
| return GetFieldOrNull(field->field_name()->str()); |
| } |
| return GetFieldByOffsetOrNull(field->field_offset()); |
| } |
| |
| bool ReflectiveFlatbuffer::GetFieldWithParent( |
| const FlatbufferFieldPath* field_path, ReflectiveFlatbuffer** parent, |
| reflection::Field const** field) { |
| const auto* path = field_path->field(); |
| if (path == nullptr || path->size() == 0) { |
| return false; |
| } |
| |
| for (int i = 0; i < path->size(); i++) { |
| *parent = (i == 0 ? this : (*parent)->Mutable(*field)); |
| if (*parent == nullptr) { |
| return false; |
| } |
| *field = (*parent)->GetFieldOrNull(path->Get(i)); |
| if (*field == nullptr) { |
| return false; |
| } |
| } |
| |
| return true; |
| } |
| |
| const reflection::Field* ReflectiveFlatbuffer::GetFieldByOffsetOrNull( |
| const int field_offset) const { |
| if (type_->fields() == nullptr) { |
| return nullptr; |
| } |
| for (const reflection::Field* field : *type_->fields()) { |
| if (field->offset() == field_offset) { |
| return field; |
| } |
| } |
| return nullptr; |
| } |
| |
| bool ReflectiveFlatbuffer::IsMatchingType(const reflection::Field* field, |
| const Variant& value) const { |
| switch (field->type()->base_type()) { |
| case reflection::Bool: |
| return value.HasBool(); |
| case reflection::Int: |
| return value.HasInt(); |
| case reflection::Long: |
| return value.HasInt64(); |
| case reflection::Float: |
| return value.HasFloat(); |
| case reflection::Double: |
| return value.HasDouble(); |
| case reflection::String: |
| return value.HasString(); |
| default: |
| return false; |
| } |
| } |
| |
| bool ReflectiveFlatbuffer::ParseAndSet(const reflection::Field* field, |
| const std::string& value) { |
| switch (field->type()->base_type()) { |
| case reflection::String: |
| return Set(field, value); |
| case reflection::Int: { |
| int32 int_value; |
| if (!ParseInt32(value.data(), &int_value)) { |
| TC3_LOG(ERROR) << "Could not parse '" << value << "' as int32."; |
| return false; |
| } |
| return Set(field, int_value); |
| } |
| case reflection::Long: { |
| int64 int_value; |
| if (!ParseInt64(value.data(), &int_value)) { |
| TC3_LOG(ERROR) << "Could not parse '" << value << "' as int64."; |
| return false; |
| } |
| return Set(field, int_value); |
| } |
| case reflection::Float: { |
| double double_value; |
| if (!ParseDouble(value.data(), &double_value)) { |
| TC3_LOG(ERROR) << "Could not parse '" << value << "' as float."; |
| return false; |
| } |
| return Set(field, static_cast<float>(double_value)); |
| } |
| case reflection::Double: { |
| double double_value; |
| if (!ParseDouble(value.data(), &double_value)) { |
| TC3_LOG(ERROR) << "Could not parse '" << value << "' as double."; |
| return false; |
| } |
| return Set(field, double_value); |
| } |
| default: |
| TC3_LOG(ERROR) << "Unhandled field type: " << field->type()->base_type(); |
| return false; |
| } |
| } |
| |
| bool ReflectiveFlatbuffer::ParseAndSet(const FlatbufferFieldPath* path, |
| const std::string& value) { |
| ReflectiveFlatbuffer* parent; |
| const reflection::Field* field; |
| if (!GetFieldWithParent(path, &parent, &field)) { |
| return false; |
| } |
| return parent->ParseAndSet(field, value); |
| } |
| |
| ReflectiveFlatbuffer* ReflectiveFlatbuffer::Mutable( |
| const StringPiece field_name) { |
| if (const reflection::Field* field = GetFieldOrNull(field_name)) { |
| return Mutable(field); |
| } |
| TC3_LOG(ERROR) << "Unknown field: " << field_name.ToString(); |
| return nullptr; |
| } |
| |
| ReflectiveFlatbuffer* ReflectiveFlatbuffer::Mutable( |
| const reflection::Field* field) { |
| if (field->type()->base_type() != reflection::Obj) { |
| TC3_LOG(ERROR) << "Field is not of type Object."; |
| return nullptr; |
| } |
| const auto entry = children_.find(field); |
| if (entry != children_.end()) { |
| return entry->second.get(); |
| } |
| const auto it = children_.insert( |
| /*hint=*/entry, |
| std::make_pair( |
| field, |
| std::unique_ptr<ReflectiveFlatbuffer>(new ReflectiveFlatbuffer( |
| schema_, schema_->objects()->Get(field->type()->index()))))); |
| return it->second.get(); |
| } |
| |
| ReflectiveFlatbuffer::RepeatedField* ReflectiveFlatbuffer::Repeated( |
| StringPiece field_name) { |
| if (const reflection::Field* field = GetFieldOrNull(field_name)) { |
| return Repeated(field); |
| } |
| TC3_LOG(ERROR) << "Unknown field: " << field_name.ToString(); |
| return nullptr; |
| } |
| |
| ReflectiveFlatbuffer::RepeatedField* ReflectiveFlatbuffer::Repeated( |
| const reflection::Field* field) { |
| if (field->type()->base_type() != reflection::Vector) { |
| TC3_LOG(ERROR) << "Field is not of type Vector."; |
| return nullptr; |
| } |
| |
| // If the repeated field was already set, return its instance. |
| const auto entry = repeated_fields_.find(field); |
| if (entry != repeated_fields_.end()) { |
| return entry->second.get(); |
| } |
| |
| // Otherwise, create a new instance and store it. |
| std::unique_ptr<RepeatedField> repeated_field; |
| if (!CreateRepeatedField(schema_, field->type(), &repeated_field)) { |
| TC3_LOG(ERROR) << "Could not create repeated field."; |
| return nullptr; |
| } |
| const auto it = repeated_fields_.insert( |
| /*hint=*/entry, std::make_pair(field, std::move(repeated_field))); |
| return it->second.get(); |
| } |
| |
| flatbuffers::uoffset_t ReflectiveFlatbuffer::Serialize( |
| flatbuffers::FlatBufferBuilder* builder) const { |
| // Build all children before we can start with this table. |
| std::vector< |
| std::pair</* field vtable offset */ int, |
| /* field data offset in buffer */ flatbuffers::uoffset_t>> |
| offsets; |
| offsets.reserve(children_.size() + repeated_fields_.size()); |
| for (const auto& it : children_) { |
| offsets.push_back({it.first->offset(), it.second->Serialize(builder)}); |
| } |
| |
| // Create strings. |
| for (const auto& it : fields_) { |
| if (it.second.HasString()) { |
| offsets.push_back({it.first->offset(), |
| builder->CreateString(it.second.StringValue()).o}); |
| } |
| } |
| |
| // Build the repeated fields. |
| for (const auto& it : repeated_fields_) { |
| offsets.push_back({it.first->offset(), it.second->Serialize(builder)}); |
| } |
| |
| // Build the table now. |
| const flatbuffers::uoffset_t table_start = builder->StartTable(); |
| |
| // Add scalar fields. |
| for (const auto& it : fields_) { |
| switch (it.second.GetType()) { |
| case Variant::TYPE_BOOL_VALUE: |
| builder->AddElement<uint8_t>( |
| it.first->offset(), static_cast<uint8_t>(it.second.BoolValue()), |
| static_cast<uint8_t>(it.first->default_integer())); |
| continue; |
| case Variant::TYPE_INT_VALUE: |
| builder->AddElement<int32>( |
| it.first->offset(), it.second.IntValue(), |
| static_cast<int32>(it.first->default_integer())); |
| continue; |
| case Variant::TYPE_INT64_VALUE: |
| builder->AddElement<int64>(it.first->offset(), it.second.Int64Value(), |
| it.first->default_integer()); |
| continue; |
| case Variant::TYPE_FLOAT_VALUE: |
| builder->AddElement<float>( |
| it.first->offset(), it.second.FloatValue(), |
| static_cast<float>(it.first->default_real())); |
| continue; |
| case Variant::TYPE_DOUBLE_VALUE: |
| builder->AddElement<double>(it.first->offset(), it.second.DoubleValue(), |
| it.first->default_real()); |
| continue; |
| default: |
| continue; |
| } |
| } |
| |
| // Add strings, subtables and repeated fields. |
| for (const auto& it : offsets) { |
| builder->AddOffset(it.first, flatbuffers::Offset<void>(it.second)); |
| } |
| |
| return builder->EndTable(table_start); |
| } |
| |
| std::string ReflectiveFlatbuffer::Serialize() const { |
| flatbuffers::FlatBufferBuilder builder; |
| builder.Finish(flatbuffers::Offset<void>(Serialize(&builder))); |
| return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()), |
| builder.GetSize()); |
| } |
| |
| bool ReflectiveFlatbuffer::MergeFrom(const flatbuffers::Table* from) { |
| // No fields to set. |
| if (type_->fields() == nullptr) { |
| return true; |
| } |
| |
| for (const reflection::Field* field : *type_->fields()) { |
| // Skip fields that are not explicitly set. |
| if (!from->CheckField(field->offset())) { |
| continue; |
| } |
| const reflection::BaseType type = field->type()->base_type(); |
| switch (type) { |
| case reflection::Bool: |
| Set<bool>(field, from->GetField<uint8_t>(field->offset(), |
| field->default_integer())); |
| break; |
| case reflection::Int: |
| Set<int32>(field, from->GetField<int32>(field->offset(), |
| field->default_integer())); |
| break; |
| case reflection::Long: |
| Set<int64>(field, from->GetField<int64>(field->offset(), |
| field->default_integer())); |
| break; |
| case reflection::Float: |
| Set<float>(field, from->GetField<float>(field->offset(), |
| field->default_real())); |
| break; |
| case reflection::Double: |
| Set<double>(field, from->GetField<double>(field->offset(), |
| field->default_real())); |
| break; |
| case reflection::String: |
| Set<std::string>( |
| field, from->GetPointer<const flatbuffers::String*>(field->offset()) |
| ->str()); |
| break; |
| case reflection::Obj: |
| if (!Mutable(field)->MergeFrom( |
| from->GetPointer<const flatbuffers::Table* const>( |
| field->offset()))) { |
| return false; |
| } |
| break; |
| default: |
| TC3_LOG(ERROR) << "Unsupported type: " << type; |
| return false; |
| } |
| } |
| return true; |
| } |
| |
| bool ReflectiveFlatbuffer::MergeFromSerializedFlatbuffer(StringPiece from) { |
| return MergeFrom(flatbuffers::GetAnyRoot( |
| reinterpret_cast<const unsigned char*>(from.data()))); |
| } |
| |
| void ReflectiveFlatbuffer::AsFlatMap( |
| const std::string& key_separator, const std::string& key_prefix, |
| std::map<std::string, Variant>* result) const { |
| // Add direct fields. |
| for (auto it : fields_) { |
| (*result)[key_prefix + it.first->name()->str()] = it.second; |
| } |
| |
| // Add nested messages. |
| for (auto& it : children_) { |
| it.second->AsFlatMap(key_separator, |
| key_prefix + it.first->name()->str() + key_separator, |
| result); |
| } |
| } |
| |
| } // namespace libtextclassifier3 |