| /* |
| * Copyright 2019 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #include "packet_def.h" |
| |
#include <iomanip>
#include <list>
#include <set>
#include <utility>
| |
| #include "fields/all_fields.h" |
| #include "util.h" |
| |
| PacketDef::PacketDef(std::string name, FieldList fields) : ParentDef(name, fields, nullptr) {} |
| PacketDef::PacketDef(std::string name, FieldList fields, PacketDef* parent) : ParentDef(name, fields, parent) {} |
| |
| PacketField* PacketDef::GetNewField(const std::string&, ParseLocation) const { |
| return nullptr; // Packets can't be fields |
| } |
| |
| void PacketDef::GenParserDefinition(std::ostream& s) const { |
| s << "class " << name_ << "View"; |
| if (parent_ != nullptr) { |
| s << " : public " << parent_->name_ << "View {"; |
| } else { |
| s << " : public PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> {"; |
| } |
| s << " public:"; |
| |
| // Specialize function |
| if (parent_ != nullptr) { |
| s << "static " << name_ << "View Create(" << parent_->name_ << "View parent)"; |
| s << "{ return " << name_ << "View(std::move(parent)); }"; |
| } else { |
| s << "static " << name_ << "View Create(PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> packet) "; |
| s << "{ return " << name_ << "View(std::move(packet)); }"; |
| } |
| |
| GenTestingParserFromBytes(s); |
| |
| std::set<std::string> fixed_types = { |
| FixedScalarField::kFieldType, |
| FixedEnumField::kFieldType, |
| }; |
| |
| // Print all of the public fields which are all the fields minus the fixed fields. |
| const auto& public_fields = fields_.GetFieldsWithoutTypes(fixed_types); |
| bool has_fixed_fields = public_fields.size() != fields_.size(); |
| for (const auto& field : public_fields) { |
| GenParserFieldGetter(s, field); |
| s << "\n"; |
| } |
| GenValidator(s); |
| s << "\n"; |
| |
| s << " public:"; |
| GenParserToString(s); |
| s << "\n"; |
| |
| s << " protected:\n"; |
| // Constructor from a View |
| if (parent_ != nullptr) { |
| s << "explicit " << name_ << "View(" << parent_->name_ << "View parent)"; |
| s << " : " << parent_->name_ << "View(std::move(parent)) { was_validated_ = false; }"; |
| } else { |
| s << "explicit " << name_ << "View(PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> packet) "; |
| s << " : PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>(packet) { was_validated_ = false;}"; |
| } |
| |
| // Print the private fields which are the fixed fields. |
| if (has_fixed_fields) { |
| const auto& private_fields = fields_.GetFieldsWithTypes(fixed_types); |
| s << " private:\n"; |
| for (const auto& field : private_fields) { |
| GenParserFieldGetter(s, field); |
| s << "\n"; |
| } |
| } |
| s << "};\n"; |
| } |
| |
| void PacketDef::GenTestingParserFromBytes(std::ostream& s) const { |
| s << "\n#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING) || defined(FUZZ_TARGET)\n"; |
| |
| s << "static " << name_ << "View FromBytes(std::vector<uint8_t> bytes) {"; |
| s << "auto vec = std::make_shared<std::vector<uint8_t>>(bytes);"; |
| s << "return " << name_ << "View::Create("; |
| auto ancestor_ptr = parent_; |
| size_t parent_parens = 0; |
| while (ancestor_ptr != nullptr) { |
| s << ancestor_ptr->name_ << "View::Create("; |
| parent_parens++; |
| ancestor_ptr = ancestor_ptr->parent_; |
| } |
| s << "PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>(vec)"; |
| for (size_t i = 0; i < parent_parens; i++) { |
| s << ")"; |
| } |
| s << ");"; |
| s << "}"; |
| |
| s << "\n#endif\n"; |
| } |
| |
| void PacketDef::GenParserDefinitionPybind11(std::ostream& s) const { |
| s << "py::class_<" << name_ << "View"; |
| if (parent_ != nullptr) { |
| s << ", " << parent_->name_ << "View"; |
| } else { |
| s << ", PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>"; |
| } |
| s << ">(m, \"" << name_ << "View\")"; |
| if (parent_ != nullptr) { |
| s << ".def(py::init([](" << parent_->name_ << "View parent) {"; |
| } else { |
| s << ".def(py::init([](PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> parent) {"; |
| } |
| s << "auto view =" << name_ << "View::Create(std::move(parent));"; |
| s << "if (!view.IsValid()) { throw std::invalid_argument(\"Bad packet view\"); }"; |
| s << "return view; }))"; |
| |
| s << ".def(py::init(&" << name_ << "View::Create))"; |
| std::set<std::string> protected_field_types = { |
| FixedScalarField::kFieldType, |
| FixedEnumField::kFieldType, |
| SizeField::kFieldType, |
| CountField::kFieldType, |
| }; |
| const auto& public_fields = fields_.GetFieldsWithoutTypes(protected_field_types); |
| for (const auto& field : public_fields) { |
| auto getter_func_name = field->GetGetterFunctionName(); |
| if (getter_func_name.empty()) { |
| continue; |
| } |
| s << ".def(\"" << getter_func_name << "\", &" << name_ << "View::" << getter_func_name << ")"; |
| } |
| s << ".def(\"IsValid\", &" << name_ << "View::IsValid)"; |
| s << ";\n"; |
| } |
| |
| void PacketDef::GenParserFieldGetter(std::ostream& s, const PacketField* field) const { |
| // Start field offset |
| auto start_field_offset = GetOffsetForField(field->GetName(), false); |
| auto end_field_offset = GetOffsetForField(field->GetName(), true); |
| |
| if (start_field_offset.empty() && end_field_offset.empty()) { |
| ERROR(field) << "Field location for " << field->GetName() << " is ambiguous, " |
| << "no method exists to determine field location from begin() or end().\n"; |
| } |
| |
| field->GenGetter(s, start_field_offset, end_field_offset); |
| } |
| |
| TypeDef::Type PacketDef::GetDefinitionType() const { |
| return TypeDef::Type::PACKET; |
| } |
| |
| void PacketDef::GenValidator(std::ostream& s) const { |
| // Get the static offset for all of our fields. |
| int bits_size = 0; |
| for (const auto& field : fields_) { |
| if (field->GetFieldType() != PaddingField::kFieldType) { |
| bits_size += field->GetSize().bits(); |
| } |
| } |
| |
| // Write the function declaration. |
| s << "virtual bool IsValid() " << (parent_ != nullptr ? " override" : "") << " {"; |
| s << "if (was_validated_) { return true; } "; |
| s << "else { was_validated_ = true; was_validated_ = IsValid_(); return was_validated_; }"; |
| s << "}"; |
| |
| s << "protected:"; |
| s << "virtual bool IsValid_() const {"; |
| |
| if (parent_ != nullptr) { |
| s << "if (!" << parent_->name_ << "View::IsValid_()) { return false; } "; |
| } |
| |
| // Offset by the parents known size. We know that any dynamic fields can |
| // already be called since the parent must have already been validated by |
| // this point. |
| auto parent_size = Size(0); |
| if (parent_ != nullptr) { |
| parent_size = parent_->GetSize(true); |
| } |
| |
| s << "auto it = begin() + (" << parent_size << ") / 8;"; |
| |
| // Check if you can extract the static fields. |
| // At this point you know you can use the size getters without crashing |
| // as long as they follow the instruction that size fields cant come before |
| // their corrisponding variable length field. |
| s << "it += " << ((bits_size + 7) / 8) << " /* Total size of the fixed fields */;"; |
| s << "if (it > end()) return false;"; |
| |
| // For any variable length fields, use their size check. |
| for (const auto& field : fields_) { |
| if (field->GetFieldType() == ChecksumStartField::kFieldType) { |
| auto offset = GetOffsetForField(field->GetName(), false); |
| if (!offset.empty()) { |
| s << "size_t sum_index = (" << offset << ") / 8;"; |
| } else { |
| offset = GetOffsetForField(field->GetName(), true); |
| if (offset.empty()) { |
| ERROR(field) << "Checksum Start Field offset can not be determined."; |
| } |
| s << "size_t sum_index = size() - (" << offset << ") / 8;"; |
| } |
| |
| const auto& field_name = ((ChecksumStartField*)field)->GetStartedFieldName(); |
| const auto& started_field = fields_.GetField(field_name); |
| if (started_field == nullptr) { |
| ERROR(field) << __func__ << ": Can't find checksum field named " << field_name << "(" << field->GetName() |
| << ")"; |
| } |
| auto end_offset = GetOffsetForField(started_field->GetName(), false); |
| if (!end_offset.empty()) { |
| s << "size_t end_sum_index = (" << end_offset << ") / 8;"; |
| } else { |
| end_offset = GetOffsetForField(started_field->GetName(), true); |
| if (end_offset.empty()) { |
| ERROR(started_field) << "Checksum Field end_offset can not be determined."; |
| } |
| s << "size_t end_sum_index = size() - (" << started_field->GetSize() << " - " << end_offset << ") / 8;"; |
| } |
| if (is_little_endian_) { |
| s << "auto checksum_view = GetLittleEndianSubview(sum_index, end_sum_index);"; |
| } else { |
| s << "auto checksum_view = GetBigEndianSubview(sum_index, end_sum_index);"; |
| } |
| s << started_field->GetDataType() << " checksum;"; |
| s << "checksum.Initialize();"; |
| s << "for (uint8_t byte : checksum_view) { "; |
| s << "checksum.AddByte(byte);}"; |
| s << "if (checksum.GetChecksum() != (begin() + end_sum_index).extract<" |
| << util::GetTypeForSize(started_field->GetSize().bits()) << ">()) { return false; }"; |
| |
| continue; |
| } |
| |
| auto field_size = field->GetSize(); |
| // Fixed size fields have already been handled. |
| if (!field_size.has_dynamic()) { |
| continue; |
| } |
| |
| // Custom fields with dynamic size must have the offset for the field passed in as well |
| // as the end iterator so that they may ensure that they don't try to read past the end. |
| // Custom fields with fixed sizes will be handled in the static offset checking. |
| if (field->GetFieldType() == CustomField::kFieldType) { |
| // Check if we can determine offset from begin(), otherwise error because by this point, |
| // the size of the custom field is unknown and can't be subtracted from end() to get the |
| // offset. |
| auto offset = GetOffsetForField(field->GetName(), false); |
| if (offset.empty()) { |
| ERROR(field) << "Custom Field offset can not be determined from begin()."; |
| } |
| |
| if (offset.bits() % 8 != 0) { |
| ERROR(field) << "Custom fields must be byte aligned."; |
| } |
| |
| // Custom fields are special as their size field takes an argument. |
| const auto& custom_size_var = field->GetName() + "_size"; |
| s << "const auto& " << custom_size_var << " = " << field_size.dynamic_string(); |
| s << "(begin() + (" << offset << ") / 8);"; |
| |
| s << "if (!" << custom_size_var << ".has_value()) { return false; }"; |
| s << "it += *" << custom_size_var << ";"; |
| s << "if (it > end()) return false;"; |
| continue; |
| } else { |
| s << "it += (" << field_size.dynamic_string() << ") / 8;"; |
| s << "if (it > end()) return false;"; |
| } |
| } |
| |
| // Validate constraints after validating the size |
| if (parent_constraints_.size() > 0 && parent_ == nullptr) { |
| ERROR() << "Can't have a constraint on a NULL parent"; |
| } |
| |
| for (const auto& constraint : parent_constraints_) { |
| s << "if (Get" << util::UnderscoreToCamelCase(constraint.first) << "() != "; |
| const auto& field = parent_->GetParamList().GetField(constraint.first); |
| if (field->GetFieldType() == ScalarField::kFieldType) { |
| s << std::get<int64_t>(constraint.second); |
| } else { |
| s << std::get<std::string>(constraint.second); |
| } |
| s << ") return false;"; |
| } |
| |
| // Validate the packets fields last |
| for (const auto& field : fields_) { |
| field->GenValidator(s); |
| s << "\n"; |
| } |
| |
| s << "return true;"; |
| s << "}\n"; |
| if (parent_ == nullptr) { |
| s << "bool was_validated_{false};\n"; |
| } |
| } |
| |
| void PacketDef::GenParserToString(std::ostream& s) const { |
| s << "virtual std::string ToString() " << (parent_ != nullptr ? " override" : "") << " {"; |
| s << "std::stringstream ss;"; |
| s << "ss << std::showbase << std::hex << \"" << name_ << " { \";"; |
| |
| if (fields_.size() > 0) { |
| s << "ss << \"\" "; |
| bool firstfield = true; |
| for (const auto& field : fields_) { |
| if (field->GetFieldType() == ReservedField::kFieldType || field->GetFieldType() == FixedScalarField::kFieldType || |
| field->GetFieldType() == ChecksumStartField::kFieldType) |
| continue; |
| |
| s << (firstfield ? " << \"" : " << \", ") << field->GetName() << " = \" << "; |
| |
| field->GenStringRepresentation(s, field->GetGetterFunctionName() + "()"); |
| |
| if (firstfield) { |
| firstfield = false; |
| } |
| } |
| s << ";"; |
| } |
| |
| s << "ss << \" }\";"; |
| s << "return ss.str();"; |
| s << "}\n"; |
| } |
| |
| void PacketDef::GenBuilderDefinition(std::ostream& s) const { |
| s << "class " << name_ << "Builder"; |
| if (parent_ != nullptr) { |
| s << " : public " << parent_->name_ << "Builder"; |
| } else { |
| if (is_little_endian_) { |
| s << " : public PacketBuilder<kLittleEndian>"; |
| } else { |
| s << " : public PacketBuilder<!kLittleEndian>"; |
| } |
| } |
| s << " {"; |
| s << " public:"; |
| s << " virtual ~" << name_ << "Builder() = default;"; |
| |
| if (!fields_.HasBody()) { |
| GenBuilderCreate(s); |
| s << "\n"; |
| |
| GenTestingFromView(s); |
| s << "\n"; |
| } |
| |
| GenSerialize(s); |
| s << "\n"; |
| |
| GenSize(s); |
| s << "\n"; |
| |
| s << " protected:\n"; |
| GenBuilderConstructor(s); |
| s << "\n"; |
| |
| GenBuilderParameterChecker(s); |
| s << "\n"; |
| |
| GenMembers(s); |
| s << "};\n"; |
| |
| GenTestDefine(s); |
| s << "\n"; |
| |
| GenFuzzTestDefine(s); |
| s << "\n"; |
| } |
| |
| void PacketDef::GenTestingFromView(std::ostream& s) const { |
| s << "#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING) || defined(FUZZ_TARGET)\n"; |
| |
| s << "static std::unique_ptr<" << name_ << "Builder> FromView(" << name_ << "View view) {"; |
| s << "return " << name_ << "Builder::Create("; |
| FieldList params = GetParamList().GetFieldsWithoutTypes({ |
| BodyField::kFieldType, |
| }); |
| for (std::size_t i = 0; i < params.size(); i++) { |
| params[i]->GenBuilderParameterFromView(s); |
| if (i != params.size() - 1) { |
| s << ", "; |
| } |
| } |
| s << ");"; |
| s << "}"; |
| |
| s << "\n#endif\n"; |
| } |
| |
| void PacketDef::GenBuilderDefinitionPybind11(std::ostream& s) const { |
| s << "py::class_<" << name_ << "Builder"; |
| if (parent_ != nullptr) { |
| s << ", " << parent_->name_ << "Builder"; |
| } else { |
| if (is_little_endian_) { |
| s << ", PacketBuilder<kLittleEndian>"; |
| } else { |
| s << ", PacketBuilder<!kLittleEndian>"; |
| } |
| } |
| s << ", std::shared_ptr<" << name_ << "Builder>"; |
| s << ">(m, \"" << name_ << "Builder\")"; |
| if (!fields_.HasBody()) { |
| GenBuilderCreatePybind11(s); |
| } |
| s << ".def(\"Serialize\", [](" << name_ << "Builder& builder){"; |
| s << "std::vector<uint8_t> bytes;"; |
| s << "BitInserter bi(bytes);"; |
| s << "builder.Serialize(bi);"; |
| s << "return bytes;})"; |
| s << ";\n"; |
| } |
| |
| void PacketDef::GenTestDefine(std::ostream& s) const { |
| s << "#ifdef PACKET_TESTING\n"; |
| s << "#define DEFINE_AND_INSTANTIATE_" << name_ << "ReflectionTest(...)"; |
| s << "class " << name_ << "ReflectionTest : public testing::TestWithParam<std::vector<uint8_t>> { "; |
| s << "public: "; |
| s << "void CompareBytes(std::vector<uint8_t> captured_packet) {"; |
| s << name_ << "View view = " << name_ << "View::FromBytes(captured_packet);"; |
| s << "if (!view.IsValid()) { LOG_INFO(\"Invalid Packet Bytes (size = %zu)\", view.size());"; |
| s << "for (size_t i = 0; i < view.size(); i++) { LOG_INFO(\"%5zd:%02X\", i, *(view.begin() + i)); }}"; |
| s << "ASSERT_TRUE(view.IsValid());"; |
| s << "auto packet = " << name_ << "Builder::FromView(view);"; |
| s << "std::shared_ptr<std::vector<uint8_t>> packet_bytes = std::make_shared<std::vector<uint8_t>>();"; |
| s << "packet_bytes->reserve(packet->size());"; |
| s << "BitInserter it(*packet_bytes);"; |
| s << "packet->Serialize(it);"; |
| s << "ASSERT_EQ(*packet_bytes, captured_packet);"; |
| s << "}"; |
| s << "};"; |
| s << "TEST_P(" << name_ << "ReflectionTest, generatedReflectionTest) {"; |
| s << "CompareBytes(GetParam());"; |
| s << "}"; |
| s << "INSTANTIATE_TEST_SUITE_P(" << name_ << "_reflection, "; |
| s << name_ << "ReflectionTest, testing::Values(__VA_ARGS__))"; |
| int i = 0; |
| for (const auto& bytes : test_cases_) { |
| s << "\nuint8_t " << name_ << "_test_bytes_" << i << "[] = \"" << bytes << "\";"; |
| s << "std::vector<uint8_t> " << name_ << "_test_vec_" << i << "("; |
| s << name_ << "_test_bytes_" << i << ","; |
| s << name_ << "_test_bytes_" << i << " + sizeof("; |
| s << name_ << "_test_bytes_" << i << ") - 1);"; |
| i++; |
| } |
| if (!test_cases_.empty()) { |
| i = 0; |
| s << "\nDEFINE_AND_INSTANTIATE_" << name_ << "ReflectionTest("; |
| for (auto bytes : test_cases_) { |
| if (i > 0) { |
| s << ","; |
| } |
| s << name_ << "_test_vec_" << i++; |
| } |
| s << ");"; |
| } |
| s << "\n#endif"; |
| } |
| |
| void PacketDef::GenFuzzTestDefine(std::ostream& s) const { |
| s << "#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING)\n"; |
| s << "#define DEFINE_" << name_ << "ReflectionFuzzTest() "; |
| s << "void Run" << name_ << "ReflectionFuzzTest(const uint8_t* data, size_t size) {"; |
| s << "auto vec = std::vector<uint8_t>(data, data + size);"; |
| s << name_ << "View view = " << name_ << "View::FromBytes(vec);"; |
| s << "if (!view.IsValid()) { return; }"; |
| s << "auto packet = " << name_ << "Builder::FromView(view);"; |
| s << "std::shared_ptr<std::vector<uint8_t>> packet_bytes = std::make_shared<std::vector<uint8_t>>();"; |
| s << "packet_bytes->reserve(packet->size());"; |
| s << "BitInserter it(*packet_bytes);"; |
| s << "packet->Serialize(it);"; |
| s << "}"; |
| s << "\n#endif\n"; |
| s << "#ifdef PACKET_FUZZ_TESTING\n"; |
| s << "#define DEFINE_AND_REGISTER_" << name_ << "ReflectionFuzzTest(REGISTRY) "; |
| s << "DEFINE_" << name_ << "ReflectionFuzzTest();"; |
| s << " class " << name_ << "ReflectionFuzzTestRegistrant {"; |
| s << "public: "; |
| s << "explicit " << name_ |
| << "ReflectionFuzzTestRegistrant(std::vector<void(*)(const uint8_t*, size_t)>& fuzz_test_registry) {"; |
| s << "fuzz_test_registry.push_back(Run" << name_ << "ReflectionFuzzTest);"; |
| s << "}}; "; |
| s << name_ << "ReflectionFuzzTestRegistrant " << name_ << "_reflection_fuzz_test_registrant(REGISTRY);"; |
| s << "\n#endif"; |
| } |
| |
| FieldList PacketDef::GetParametersToValidate() const { |
| FieldList params_to_validate; |
| for (const auto& field : GetParamList()) { |
| if (field->HasParameterValidator()) { |
| params_to_validate.AppendField(field); |
| } |
| } |
| return params_to_validate; |
| } |
| |
| void PacketDef::GenBuilderCreate(std::ostream& s) const { |
| s << "static std::unique_ptr<" << name_ << "Builder> Create("; |
| |
| auto params = GetParamList(); |
| for (std::size_t i = 0; i < params.size(); i++) { |
| params[i]->GenBuilderParameter(s); |
| if (i != params.size() - 1) { |
| s << ", "; |
| } |
| } |
| s << ") {"; |
| |
| // Call the constructor |
| s << "auto builder = std::unique_ptr<" << name_ << "Builder>(new " << name_ << "Builder("; |
| |
| params = params.GetFieldsWithoutTypes({ |
| PayloadField::kFieldType, |
| BodyField::kFieldType, |
| }); |
| // Add the parameters. |
| for (std::size_t i = 0; i < params.size(); i++) { |
| if (params[i]->BuilderParameterMustBeMoved()) { |
| s << "std::move(" << params[i]->GetName() << ")"; |
| } else { |
| s << params[i]->GetName(); |
| } |
| if (i != params.size() - 1) { |
| s << ", "; |
| } |
| } |
| |
| s << "));"; |
| if (fields_.HasPayload()) { |
| s << "builder->payload_ = std::move(payload);"; |
| } |
| s << "return builder;"; |
| s << "}\n"; |
| } |
| |
| void PacketDef::GenBuilderCreatePybind11(std::ostream& s) const { |
| s << ".def(py::init([]("; |
| auto params = GetParamList(); |
| std::vector<std::string> constructor_args; |
| int i = 1; |
| for (const auto& param : params) { |
| i++; |
| std::stringstream ss; |
| auto param_type = param->GetBuilderParameterType(); |
| if (param_type.empty()) { |
| continue; |
| } |
| // Use shared_ptr instead of unique_ptr for the Python interface |
| if (param->BuilderParameterMustBeMoved()) { |
| param_type = util::StringFindAndReplaceAll(param_type, "unique_ptr", "shared_ptr"); |
| } |
| ss << param_type << " " << param->GetName(); |
| constructor_args.push_back(ss.str()); |
| } |
| s << util::StringJoin(",", constructor_args) << "){"; |
| |
| // Deal with move only args |
| for (const auto& param : params) { |
| std::stringstream ss; |
| auto param_type = param->GetBuilderParameterType(); |
| if (param_type.empty()) { |
| continue; |
| } |
| if (!param->BuilderParameterMustBeMoved()) { |
| continue; |
| } |
| auto move_only_param_name = param->GetName() + "_move_only"; |
| s << param_type << " " << move_only_param_name << ";"; |
| if (param->IsContainerField()) { |
| // Assume single layer container and copy it |
| auto struct_type = param->GetElementField()->GetDataType(); |
| struct_type = util::StringFindAndReplaceAll(struct_type, "std::unique_ptr<", ""); |
| struct_type = util::StringFindAndReplaceAll(struct_type, ">", ""); |
| s << "for (size_t i = 0; i < " << param->GetName() << ".size(); i++) {"; |
| // Serialize each struct |
| s << "auto " << param->GetName() + "_bytes = std::make_shared<std::vector<uint8_t>>();"; |
| s << param->GetName() + "_bytes->reserve(" << param->GetName() << "[i]->size());"; |
| s << "BitInserter " << param->GetName() + "_bi(*" << param->GetName() << "_bytes);"; |
| s << param->GetName() << "[i]->Serialize(" << param->GetName() << "_bi);"; |
| // Parse it again |
| s << "auto " << param->GetName() << "_view = PacketView<kLittleEndian>(" << param->GetName() << "_bytes);"; |
| s << param->GetElementField()->GetDataType() << " " << param->GetName() << "_reparsed = "; |
| s << "Parse" << struct_type << "(" << param->GetName() + "_view.begin());"; |
| // Push it into a new container |
| if (param->GetFieldType() == VectorField::kFieldType) { |
| s << move_only_param_name << ".push_back(std::move(" << param->GetName() + "_reparsed));"; |
| } else if (param->GetFieldType() == ArrayField::kFieldType) { |
| s << move_only_param_name << "[i] = std::move(" << param->GetName() << "_reparsed);"; |
| } else { |
| ERROR() << param << " is not supported by Pybind11"; |
| } |
| s << "}"; |
| } else { |
| // Serialize the parameter and pass the bytes in a RawBuilder |
| s << "std::vector<uint8_t> " << param->GetName() + "_bytes;"; |
| s << param->GetName() + "_bytes.reserve(" << param->GetName() << "->size());"; |
| s << "BitInserter " << param->GetName() + "_bi(" << param->GetName() << "_bytes);"; |
| s << param->GetName() << "->Serialize(" << param->GetName() + "_bi);"; |
| s << move_only_param_name << " = "; |
| s << "std::make_unique<RawBuilder>(" << param->GetName() << "_bytes);"; |
| } |
| } |
| s << "return " << name_ << "Builder::Create("; |
| std::vector<std::string> builder_vars; |
| for (const auto& param : params) { |
| std::stringstream ss; |
| auto param_type = param->GetBuilderParameterType(); |
| if (param_type.empty()) { |
| continue; |
| } |
| auto param_name = param->GetName(); |
| if (param->BuilderParameterMustBeMoved()) { |
| ss << "std::move(" << param_name << "_move_only)"; |
| } else { |
| ss << param_name; |
| } |
| builder_vars.push_back(ss.str()); |
| } |
| s << util::StringJoin(",", builder_vars) << ");}"; |
| s << "))"; |
| } |
| |
| void PacketDef::GenBuilderParameterChecker(std::ostream& s) const { |
| FieldList params_to_validate = GetParametersToValidate(); |
| |
| // Skip writing this function if there is nothing to validate. |
| if (params_to_validate.size() == 0) { |
| return; |
| } |
| |
| // Generate function arguments. |
| s << "void CheckParameterValues("; |
| for (std::size_t i = 0; i < params_to_validate.size(); i++) { |
| params_to_validate[i]->GenBuilderParameter(s); |
| if (i != params_to_validate.size() - 1) { |
| s << ", "; |
| } |
| } |
| s << ") {"; |
| |
| // Check the parameters. |
| for (const auto& field : params_to_validate) { |
| field->GenParameterValidator(s); |
| } |
| s << "}\n"; |
| } |
| |
| void PacketDef::GenBuilderConstructor(std::ostream& s) const { |
| s << "explicit " << name_ << "Builder("; |
| |
| // Generate the constructor parameters. |
| auto params = GetParamList().GetFieldsWithoutTypes({ |
| PayloadField::kFieldType, |
| BodyField::kFieldType, |
| }); |
| for (std::size_t i = 0; i < params.size(); i++) { |
| params[i]->GenBuilderParameter(s); |
| if (i != params.size() - 1) { |
| s << ", "; |
| } |
| } |
| if (params.size() > 0 || parent_constraints_.size() > 0) { |
| s << ") :"; |
| } else { |
| s << ")"; |
| } |
| |
| // Get the list of parent params to call the parent constructor with. |
| FieldList parent_params; |
| if (parent_ != nullptr) { |
| // Pass parameters to the parent constructor |
| s << parent_->name_ << "Builder("; |
| parent_params = parent_->GetParamList().GetFieldsWithoutTypes({ |
| PayloadField::kFieldType, |
| BodyField::kFieldType, |
| }); |
| |
| // Go through all the fields and replace constrained fields with fixed values |
| // when calling the parent constructor. |
| for (std::size_t i = 0; i < parent_params.size(); i++) { |
| const auto& field = parent_params[i]; |
| const auto& constraint = parent_constraints_.find(field->GetName()); |
| if (constraint != parent_constraints_.end()) { |
| if (field->GetFieldType() == ScalarField::kFieldType) { |
| s << std::get<int64_t>(constraint->second); |
| } else if (field->GetFieldType() == EnumField::kFieldType) { |
| s << std::get<std::string>(constraint->second); |
| } else { |
| ERROR(field) << "Constraints on non enum/scalar fields should be impossible."; |
| } |
| |
| s << "/* " << field->GetName() << "_ */"; |
| } else { |
| s << field->GetName(); |
| } |
| |
| if (i != parent_params.size() - 1) { |
| s << ", "; |
| } |
| } |
| s << ") "; |
| } |
| |
| // Build a list of parameters that excludes all parent parameters. |
| FieldList saved_params; |
| for (const auto& field : params) { |
| if (parent_params.GetField(field->GetName()) == nullptr) { |
| saved_params.AppendField(field); |
| } |
| } |
| if (parent_ != nullptr && saved_params.size() > 0) { |
| s << ","; |
| } |
| for (std::size_t i = 0; i < saved_params.size(); i++) { |
| const auto& saved_param_name = saved_params[i]->GetName(); |
| if (saved_params[i]->BuilderParameterMustBeMoved()) { |
| s << saved_param_name << "_(std::move(" << saved_param_name << "))"; |
| } else { |
| s << saved_param_name << "_(" << saved_param_name << ")"; |
| } |
| if (i != saved_params.size() - 1) { |
| s << ","; |
| } |
| } |
| s << " {"; |
| |
| FieldList params_to_validate = GetParametersToValidate(); |
| |
| if (params_to_validate.size() > 0) { |
| s << "CheckParameterValues("; |
| for (std::size_t i = 0; i < params_to_validate.size(); i++) { |
| s << params_to_validate[i]->GetName() << "_"; |
| if (i != params_to_validate.size() - 1) { |
| s << ", "; |
| } |
| } |
| s << ");"; |
| } |
| |
| s << "}\n"; |
| } |
| |
| void PacketDef::GenRustChildEnums(std::ostream& s) const { |
| if (HasChildEnums()) { |
| bool payload = fields_.HasPayload(); |
| s << "#[derive(Debug)] "; |
| s << "enum " << name_ << "DataChild {"; |
| for (const auto& child : children_) { |
| s << child->name_ << "(Arc<" << child->name_ << "Data>),"; |
| } |
| if (payload) { |
| s << "Payload(Bytes),"; |
| } |
| s << "None,"; |
| s << "}\n"; |
| |
| s << "impl " << name_ << "DataChild {"; |
| s << "fn get_total_size(&self) -> usize {"; |
| s << "match self {"; |
| for (const auto& child : children_) { |
| s << name_ << "DataChild::" << child->name_ << "(value) => value.get_total_size(),"; |
| } |
| if (payload) { |
| s << name_ << "DataChild::Payload(p) => p.len(),"; |
| } |
| s << name_ << "DataChild::None => 0,"; |
| s << "}\n"; |
| s << "}\n"; |
| s << "}\n"; |
| |
| s << "#[derive(Debug)] "; |
| s << "pub enum " << name_ << "Child {"; |
| for (const auto& child : children_) { |
| s << child->name_ << "(" << child->name_ << "Packet),"; |
| } |
| if (payload) { |
| s << "Payload(Bytes),"; |
| } |
| s << "None,"; |
| s << "}\n"; |
| } |
| } |
| |
| void PacketDef::GenRustStructDeclarations(std::ostream& s) const { |
| s << "#[derive(Debug)] "; |
| s << "struct " << name_ << "Data {"; |
| |
| // Generate struct fields |
| GenRustStructFieldNameAndType(s); |
| if (HasChildEnums()) { |
| s << "child: " << name_ << "DataChild,"; |
| } |
| s << "}\n"; |
| |
| // Generate accessor struct |
| s << "#[derive(Debug, Clone)] "; |
| s << "pub struct " << name_ << "Packet {"; |
| auto lineage = GetAncestors(); |
| lineage.push_back(this); |
| for (auto it = lineage.begin(); it != lineage.end(); it++) { |
| auto def = *it; |
| s << util::CamelCaseToUnderScore(def->name_) << ": Arc<" << def->name_ << "Data>,"; |
| } |
| s << "}\n"; |
| |
| // Generate builder struct |
| s << "#[derive(Debug)] "; |
| s << "pub struct " << name_ << "Builder {"; |
| auto params = GetParamList().GetFieldsWithoutTypes({ |
| PayloadField::kFieldType, |
| BodyField::kFieldType, |
| }); |
| for (auto param : params) { |
| s << "pub "; |
| param->GenRustNameAndType(s); |
| s << ", "; |
| } |
| if (fields_.HasPayload()) { |
| s << "pub payload: Option<Bytes>,"; |
| } |
| s << "}\n"; |
| } |
| |
| bool PacketDef::GenRustStructFieldNameAndType(std::ostream& s) const { |
| auto fields = fields_.GetFieldsWithoutTypes({ |
| BodyField::kFieldType, |
| CountField::kFieldType, |
| PaddingField::kFieldType, |
| ReservedField::kFieldType, |
| SizeField::kFieldType, |
| PayloadField::kFieldType, |
| FixedScalarField::kFieldType, |
| }); |
| if (fields.size() == 0) { |
| return false; |
| } |
| for (const auto& field : fields) { |
| field->GenRustNameAndType(s); |
| s << ", "; |
| } |
| return true; |
| } |
| |
| void PacketDef::GenRustStructFieldNames(std::ostream& s) const { |
| auto fields = fields_.GetFieldsWithoutTypes({ |
| BodyField::kFieldType, |
| CountField::kFieldType, |
| PaddingField::kFieldType, |
| ReservedField::kFieldType, |
| SizeField::kFieldType, |
| PayloadField::kFieldType, |
| FixedScalarField::kFieldType, |
| }); |
| for (const auto field : fields) { |
| s << field->GetName(); |
| s << ", "; |
| } |
| } |
| |
| void PacketDef::GenRustStructImpls(std::ostream& s) const { |
| s << "impl " << name_ << "Data {"; |
| |
| // conforms function |
| s << "fn conforms(bytes: &[u8]) -> bool {"; |
| GenRustConformanceCheck(s); |
| |
| auto fields = fields_.GetFieldsWithTypes({ |
| StructField::kFieldType, |
| }); |
| |
| for (auto const& field : fields) { |
| auto start_offset = GetOffsetForField(field->GetName(), false); |
| auto end_offset = GetOffsetForField(field->GetName(), true); |
| |
| s << "if !" << field->GetRustDataType() << "::conforms(&bytes[" << start_offset.bytes(); |
| s << ".." << start_offset.bytes() + field->GetSize().bytes() << "]) { return false; }"; |
| } |
| |
| s << " true"; |
| s << "}"; |
| |
| // parse function |
| if (parent_constraints_.empty() && children_.size() > 1 && parent_ != nullptr) { |
| auto constraint = FindConstraintField(); |
| auto constraint_field = GetParamList().GetField(constraint); |
| auto constraint_type = constraint_field->GetRustDataType(); |
| s << "fn parse(bytes: &[u8], " << constraint << ": " << constraint_type << ") -> Result<Self> {"; |
| } else { |
| s << "fn parse(bytes: &[u8]) -> Result<Self> {"; |
| } |
| fields = fields_.GetFieldsWithoutTypes({ |
| BodyField::kFieldType, |
| }); |
| |
| for (auto const& field : fields) { |
| auto start_field_offset = GetOffsetForField(field->GetName(), false); |
| auto end_field_offset = GetOffsetForField(field->GetName(), true); |
| |
| if (start_field_offset.empty() && end_field_offset.empty()) { |
| ERROR(field) << "Field location for " << field->GetName() << " is ambiguous, " |
| << "no method exists to determine field location from begin() or end().\n"; |
| } |
| |
| field->GenBoundsCheck(s, start_field_offset, end_field_offset, name_); |
| field->GenRustGetter(s, start_field_offset, end_field_offset); |
| } |
| |
| auto payload_field = fields_.GetFieldsWithTypes({ |
| PayloadField::kFieldType, |
| }); |
| |
| Size payload_offset; |
| |
| if (payload_field.HasPayload()) { |
| payload_offset = GetOffsetForField(payload_field[0]->GetName(), false); |
| } |
| |
| auto constraint_name = FindConstraintField(); |
| auto constrained_descendants = FindDescendantsWithConstraint(constraint_name); |
| |
| if (children_.size() > 1) { |
| s << "let child = match " << constraint_name << " {"; |
| |
| for (const auto& desc : constrained_descendants) { |
| auto desc_path = FindPathToDescendant(desc.first->name_); |
| std::reverse(desc_path.begin(), desc_path.end()); |
| auto constraint_field = GetParamList().GetField(constraint_name); |
| auto constraint_type = constraint_field->GetFieldType(); |
| |
| if (constraint_type == EnumField::kFieldType) { |
| auto type = std::get<std::string>(desc.second); |
| auto variant_name = type.substr(type.find("::") + 2, type.length()); |
| auto enum_type = type.substr(0, type.find("::")); |
| auto enum_variant = enum_type + "::" |
| + util::UnderscoreToCamelCase(util::ToLowerCase(variant_name)); |
| s << enum_variant; |
| s << " if " << desc_path[0]->name_ << "Data::conforms(&bytes[..])"; |
| s << " => {"; |
| s << name_ << "DataChild::"; |
| s << desc_path[0]->name_ << "(Arc::new("; |
| if (desc_path[0]->parent_constraints_.empty()) { |
| s << desc_path[0]->name_ << "Data::parse(&bytes[..]"; |
| s << ", " << enum_variant << ")?))"; |
| } else { |
| s << desc_path[0]->name_ << "Data::parse(&bytes[..])?))"; |
| } |
| } else if (constraint_type == ScalarField::kFieldType) { |
| s << std::get<int64_t>(desc.second) << " => {"; |
| s << "unimplemented!();"; |
| } |
| s << "}\n"; |
| } |
| |
| if (!constrained_descendants.empty()) { |
| s << "v => return Err(Error::ConstraintOutOfBounds{field: \"" << constraint_name |
| << "\".to_string(), value: v as u64}),"; |
| } |
| |
| s << "};\n"; |
| } else if (children_.size() == 1) { |
| auto child = children_.at(0); |
| s << "let child = match " << child->name_ << "Data::parse(&bytes[..]) {"; |
| s << " Ok(c) if " << child->name_ << "Data::conforms(&bytes[..]) => {"; |
| s << name_ << "DataChild::" << child->name_ << "(Arc::new(c))"; |
| s << " },"; |
| s << " Err(Error::InvalidLengthError { .. }) => " << name_ << "DataChild::None,"; |
| s << " _ => return Err(Error::InvalidPacketError),"; |
| s << "};"; |
| } else if (fields_.HasPayload()) { |
| s << "let child = if payload.len() > 0 {"; |
| s << name_ << "DataChild::Payload(Bytes::from(payload))"; |
| s << "} else {"; |
| s << name_ << "DataChild::None"; |
| s << "};"; |
| } |
| |
| s << "Ok(Self {"; |
| fields = fields_.GetFieldsWithoutTypes({ |
| BodyField::kFieldType, |
| CountField::kFieldType, |
| PaddingField::kFieldType, |
| ReservedField::kFieldType, |
| SizeField::kFieldType, |
| PayloadField::kFieldType, |
| FixedScalarField::kFieldType, |
| }); |
| |
| if (fields.size() > 0) { |
| for (const auto& field : fields) { |
| auto field_type = field->GetFieldType(); |
| s << field->GetName(); |
| s << ", "; |
| } |
| } |
| |
| if (HasChildEnums()) { |
| s << "child,"; |
| } |
| s << "})\n"; |
| s << "}\n"; |
| |
| // write_to function |
| s << "fn write_to(&self, buffer: &mut BytesMut) {"; |
| GenRustWriteToFields(s); |
| |
| if (HasChildEnums()) { |
| s << "match &self.child {"; |
| for (const auto& child : children_) { |
| s << name_ << "DataChild::" << child->name_ << "(value) => value.write_to(buffer),"; |
| } |
| if (fields_.HasPayload()) { |
| auto offset = GetOffsetForField("payload"); |
| s << name_ << "DataChild::Payload(p) => buffer[" << offset.bytes() << "..].copy_from_slice(&p[..]),"; |
| } |
| s << name_ << "DataChild::None => {}"; |
| s << "}"; |
| } |
| |
| s << "}\n"; |
| |
| s << "fn get_total_size(&self) -> usize {"; |
| if (HasChildEnums()) { |
| s << "self.get_size() + self.child.get_total_size()"; |
| } else { |
| s << "self.get_size()"; |
| } |
| s << "}\n"; |
| |
| s << "fn get_size(&self) -> usize {"; |
| GenSizeRetVal(s); |
| s << "}\n"; |
| s << "}\n"; |
| } |
| |
| void PacketDef::GenRustAccessStructImpls(std::ostream& s) const { |
| if (complement_ != nullptr) { |
| auto complement_root = complement_->GetRootDef(); |
| auto complement_root_accessor = util::CamelCaseToUnderScore(complement_root->name_); |
| s << "impl CommandExpectations for " << name_ << "Packet {"; |
| s << " type ResponseType = " << complement_->name_ << "Packet;"; |
| s << " fn _to_response_type(pkt: EventPacket) -> Self::ResponseType { "; |
| s << complement_->name_ << "Packet::new(pkt." << complement_root_accessor << ".clone())"; |
| s << " }"; |
| s << "}"; |
| } |
| |
| s << "impl Packet for " << name_ << "Packet {"; |
| auto root = GetRootDef(); |
| auto root_accessor = util::CamelCaseToUnderScore(root->name_); |
| |
| s << "fn to_bytes(self) -> Bytes {"; |
| s << " let mut buffer = BytesMut::new();"; |
| s << " buffer.resize(self." << root_accessor << ".get_total_size(), 0);"; |
| s << " self." << root_accessor << ".write_to(&mut buffer);"; |
| s << " buffer.freeze()"; |
| s << "}\n"; |
| |
| s << "fn to_vec(self) -> Vec<u8> { self.to_bytes().to_vec() }\n"; |
| s << "}"; |
| |
| s << "impl From<" << name_ << "Packet" |
| << "> for Bytes {\n"; |
| s << "fn from(packet: " << name_ << "Packet" |
| << ") -> Self {\n"; |
| s << "packet.to_bytes()\n"; |
| s << "}\n"; |
| s << "}\n"; |
| |
| s << "impl From<" << name_ << "Packet" |
| << "> for Vec<u8> {\n"; |
| s << "fn from(packet: " << name_ << "Packet" |
| << ") -> Self {\n"; |
| s << "packet.to_vec()\n"; |
| s << "}\n"; |
| s << "}\n"; |
| |
| s << "impl " << name_ << "Packet {"; |
| if (parent_ == nullptr) { |
| s << "pub fn parse(bytes: &[u8]) -> Result<Self> { "; |
| s << "Ok(Self::new(Arc::new(" << name_ << "Data::parse(bytes)?)))"; |
| s << "}"; |
| } |
| |
| if (HasChildEnums()) { |
| s << " pub fn specialize(&self) -> " << name_ << "Child {"; |
| s << " match &self." << util::CamelCaseToUnderScore(name_) << ".child {"; |
| for (const auto& child : children_) { |
| s << name_ << "DataChild::" << child->name_ << "(_) => " << name_ << "Child::" << child->name_ << "(" |
| << child->name_ << "Packet::new(self." << root_accessor << ".clone())),"; |
| } |
| if (fields_.HasPayload()) { |
| s << name_ << "DataChild::Payload(p) => " << name_ << "Child::Payload(p.clone()),"; |
| } |
| s << name_ << "DataChild::None => " << name_ << "Child::None,"; |
| s << "}}"; |
| } |
| auto lineage = GetAncestors(); |
| lineage.push_back(this); |
| const ParentDef* prev = nullptr; |
| |
| s << " fn new(root: Arc<" << root->name_ << "Data>) -> Self {"; |
| for (auto it = lineage.begin(); it != lineage.end(); it++) { |
| auto def = *it; |
| auto accessor_name = util::CamelCaseToUnderScore(def->name_); |
| if (prev == nullptr) { |
| s << "let " << accessor_name << " = root;"; |
| } else { |
| s << "let " << accessor_name << " = match &" << util::CamelCaseToUnderScore(prev->name_) << ".child {"; |
| s << prev->name_ << "DataChild::" << def->name_ << "(value) => (*value).clone(),"; |
| s << "_ => panic!(\"inconsistent state - child was not " << def->name_ << "\"),"; |
| s << "};"; |
| } |
| prev = def; |
| } |
| s << "Self {"; |
| for (auto it = lineage.begin(); it != lineage.end(); it++) { |
| auto def = *it; |
| s << util::CamelCaseToUnderScore(def->name_) << ","; |
| } |
| s << "}}"; |
| |
| for (auto it = lineage.begin(); it != lineage.end(); it++) { |
| auto def = *it; |
| auto fields = def->fields_.GetFieldsWithoutTypes({ |
| BodyField::kFieldType, |
| CountField::kFieldType, |
| PaddingField::kFieldType, |
| ReservedField::kFieldType, |
| SizeField::kFieldType, |
| PayloadField::kFieldType, |
| FixedScalarField::kFieldType, |
| }); |
| |
| for (auto const& field : fields) { |
| if (field->GetterIsByRef()) { |
| s << "pub fn get_" << field->GetName() << "(&self) -> &" << field->GetRustDataType() << "{"; |
| s << " &self." << util::CamelCaseToUnderScore(def->name_) << ".as_ref()." << field->GetName(); |
| s << "}\n"; |
| } else { |
| s << "pub fn get_" << field->GetName() << "(&self) -> " << field->GetRustDataType() << "{"; |
| s << " self." << util::CamelCaseToUnderScore(def->name_) << ".as_ref()." << field->GetName(); |
| s << "}\n"; |
| } |
| } |
| } |
| |
| s << "}\n"; |
| |
| lineage = GetAncestors(); |
| for (auto it = lineage.begin(); it != lineage.end(); it++) { |
| auto def = *it; |
| s << "impl Into<" << def->name_ << "Packet> for " << name_ << "Packet {"; |
| s << " fn into(self) -> " << def->name_ << "Packet {"; |
| s << def->name_ << "Packet::new(self." << util::CamelCaseToUnderScore(root->name_) << ")"; |
| s << " }"; |
| s << "}\n"; |
| } |
| } |
| |
| void PacketDef::GenRustBuilderStructImpls(std::ostream& s) const { |
| if (complement_ != nullptr) { |
| auto complement_root = complement_->GetRootDef(); |
| auto complement_root_accessor = util::CamelCaseToUnderScore(complement_root->name_); |
| s << "impl CommandExpectations for " << name_ << "Builder {"; |
| s << " type ResponseType = " << complement_->name_ << "Packet;"; |
| s << " fn _to_response_type(pkt: EventPacket) -> Self::ResponseType { "; |
| s << complement_->name_ << "Packet::new(pkt." << complement_root_accessor << ".clone())"; |
| s << " }"; |
| s << "}"; |
| } |
| |
| s << "impl " << name_ << "Builder {"; |
| s << "pub fn build(self) -> " << name_ << "Packet {"; |
| auto lineage = GetAncestors(); |
| lineage.push_back(this); |
| std::reverse(lineage.begin(), lineage.end()); |
| |
| auto all_constraints = GetAllConstraints(); |
| |
| const ParentDef* prev = nullptr; |
| for (auto ancestor : lineage) { |
| auto fields = ancestor->fields_.GetFieldsWithoutTypes({ |
| BodyField::kFieldType, |
| CountField::kFieldType, |
| PaddingField::kFieldType, |
| ReservedField::kFieldType, |
| SizeField::kFieldType, |
| PayloadField::kFieldType, |
| FixedScalarField::kFieldType, |
| }); |
| |
| auto accessor_name = util::CamelCaseToUnderScore(ancestor->name_); |
| s << "let " << accessor_name << "= Arc::new(" << ancestor->name_ << "Data {"; |
| for (auto field : fields) { |
| auto constraint = all_constraints.find(field->GetName()); |
| s << field->GetName() << ": "; |
| if (constraint != all_constraints.end()) { |
| if (field->GetFieldType() == ScalarField::kFieldType) { |
| s << std::get<int64_t>(constraint->second); |
| } else if (field->GetFieldType() == EnumField::kFieldType) { |
| auto value = std::get<std::string>(constraint->second); |
| auto constant = value.substr(value.find("::") + 2, std::string::npos); |
| s << field->GetDataType() << "::" << util::ConstantCaseToCamelCase(constant); |
| ; |
| } else { |
| ERROR(field) << "Constraints on non enum/scalar fields should be impossible."; |
| } |
| } else { |
| s << "self." << field->GetName(); |
| } |
| s << ", "; |
| } |
| if (ancestor->HasChildEnums()) { |
| if (prev == nullptr) { |
| if (ancestor->fields_.HasPayload()) { |
| s << "child: match self.payload { "; |
| s << "None => " << name_ << "DataChild::None,"; |
| s << "Some(bytes) => " << name_ << "DataChild::Payload(bytes),"; |
| s << "},"; |
| } else { |
| s << "child: " << name_ << "DataChild::None,"; |
| } |
| } else { |
| s << "child: " << ancestor->name_ << "DataChild::" << prev->name_ << "(" |
| << util::CamelCaseToUnderScore(prev->name_) << "),"; |
| } |
| } |
| s << "});"; |
| prev = ancestor; |
| } |
| |
| s << name_ << "Packet::new(" << util::CamelCaseToUnderScore(prev->name_) << ")"; |
| s << "}\n"; |
| |
| s << "}\n"; |
| for (const auto ancestor : GetAncestors()) { |
| s << "impl Into<" << ancestor->name_ << "Packet> for " << name_ << "Builder {"; |
| s << " fn into(self) -> " << ancestor->name_ << "Packet { self.build().into() }"; |
| s << "}\n"; |
| } |
| } |
| |
| void PacketDef::GenRustBuilderTest(std::ostream& s) const { |
| auto lineage = GetAncestors(); |
| lineage.push_back(this); |
| if (!lineage.empty() && !test_cases_.empty()) { |
| s << "macro_rules! " << util::CamelCaseToUnderScore(name_) << "_builder_tests { "; |
| s << "($($name:ident: $byte_string:expr,)*) => {"; |
| s << "$("; |
| s << "\n#[test]\n"; |
| s << "pub fn $name() { "; |
| s << "let raw_bytes = $byte_string;"; |
| for (size_t i = 0; i < lineage.size(); i++) { |
| s << "/* (" << i << ") */\n"; |
| if (i == 0) { |
| s << "match " << lineage[i]->name_ << "Packet::parse(raw_bytes) {"; |
| s << "Ok(" << util::CamelCaseToUnderScore(lineage[i]->name_) << "_packet) => {"; |
| s << "match " << util::CamelCaseToUnderScore(lineage[i]->name_) << "_packet.specialize() {"; |
| } else if (i != lineage.size() - 1) { |
| s << lineage[i - 1]->name_ << "Child::" << lineage[i]->name_ << "("; |
| s << util::CamelCaseToUnderScore(lineage[i]->name_) << "_packet) => {"; |
| s << "match " << util::CamelCaseToUnderScore(lineage[i]->name_) << "_packet.specialize() {"; |
| } else { |
| s << lineage[i - 1]->name_ << "Child::" << lineage[i]->name_ << "(packet) => {"; |
| s << "let rebuilder = " << lineage[i]->name_ << "Builder {"; |
| FieldList params = GetParamList(); |
| if (params.HasBody()) { |
| ERROR() << "Packets with body fields can't be auto-tested. Test a child."; |
| } |
| for (const auto param : params) { |
| s << param->GetName() << " : packet."; |
| if (param->GetFieldType() == VectorField::kFieldType) { |
| s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "().to_vec(),"; |
| } else if (param->GetFieldType() == ArrayField::kFieldType) { |
| const auto array_param = static_cast<const ArrayField*>(param); |
| const auto element_field = array_param->GetElementField(); |
| if (element_field->GetFieldType() == StructField::kFieldType) { |
| s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "().to_vec(),"; |
| } else { |
| s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "().clone(),"; |
| } |
| } else if (param->GetFieldType() == StructField::kFieldType) { |
| s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "().clone(),"; |
| } else { |
| s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "(),"; |
| } |
| } |
| s << "};"; |
| s << "let rebuilder_base : " << lineage[0]->name_ << "Packet = rebuilder.into();"; |
| s << "let rebuilder_bytes : &[u8] = &rebuilder_base.to_bytes();"; |
| s << "assert_eq!(rebuilder_bytes, raw_bytes);"; |
| s << "}"; |
| } |
| } |
| for (size_t i = 1; i < lineage.size(); i++) { |
| s << "_ => {"; |
| s << "println!(\"Couldn't parse " << util::CamelCaseToUnderScore(lineage[lineage.size() - i]->name_); |
| s << "{:02x?}\", " << util::CamelCaseToUnderScore(lineage[lineage.size() - i - 1]->name_) << "_packet); "; |
| s << "}}}"; |
| } |
| |
| s << ","; |
| s << "Err(e) => panic!(\"could not parse " << lineage[0]->name_ << ": {:?} {:02x?}\", e, raw_bytes),"; |
| s << "}"; |
| s << "}"; |
| s << ")*"; |
| s << "}"; |
| s << "}"; |
| |
| s << util::CamelCaseToUnderScore(name_) << "_builder_tests! { "; |
| int number = 0; |
| for (const auto& test_case : test_cases_) { |
| s << util::CamelCaseToUnderScore(name_) << "_builder_test_"; |
| s << std::setfill('0') << std::setw(2) << number++ << ": "; |
| s << "b\"" << test_case << "\","; |
| } |
| s << "}"; |
| s << "\n"; |
| } |
| } |
| |
| void PacketDef::GenRustDef(std::ostream& s) const { |
| GenRustChildEnums(s); |
| GenRustStructDeclarations(s); |
| GenRustStructImpls(s); |
| GenRustAccessStructImpls(s); |
| GenRustBuilderStructImpls(s); |
| GenRustBuilderTest(s); |
| } |