rusty-gd: Fill in new and parse methods for *Data structs
Bug: 171749953
Tag: #gd-refactor
Test: gd/cert/run --rhost SimpleHalTest
Change-Id: Ifff55709c1a3d65ca443351e2e08effdd2fa34d8
diff --git a/system/gd/packet/parser/gen_rust.cc b/system/gd/packet/parser/gen_rust.cc
index 4b035d8..8e7ce02 100644
--- a/system/gd/packet/parser/gen_rust.cc
+++ b/system/gd/packet/parser/gen_rust.cc
@@ -147,6 +147,9 @@
}
for (const auto& packet_def : decls.packet_defs_queue_) {
+ if (packet_def.second->name_.rfind("LeGetVendorCapabilitiesComplete", 0) == 0) {
+ continue;
+ }
packet_def.second->GenRustDef(out_file);
out_file << "\n\n";
}
diff --git a/system/gd/packet/parser/packet_def.cc b/system/gd/packet/parser/packet_def.cc
index ef7660b..ba76f17 100644
--- a/system/gd/packet/parser/packet_def.cc
+++ b/system/gd/packet/parser/packet_def.cc
@@ -745,6 +745,9 @@
s << "#[derive(Debug)] ";
s << "enum " << name_ << "DataChild {";
for (const auto& child : children_) {
+ if (child->name_.rfind("LeGetVendorCapabilitiesComplete", 0) == 0) {
+ continue;
+ }
s << child->name_ << "(Arc<" << child->name_ << "Data>),";
}
s << "None,";
@@ -752,6 +755,9 @@
s << "#[derive(Debug)] ";
s << "pub enum " << name_ << "Child {";
for (const auto& child : children_) {
+ if (child->name_.rfind("LeGetVendorCapabilitiesComplete", 0) == 0) {
+ continue;
+ }
s << child->name_ << "(" << child->name_ << "Packet),";
}
s << "None,";
@@ -851,21 +857,36 @@
s << "impl " << name_ << "Data {";
s << "fn new(";
bool fields_exist = GenRustStructFieldNameAndType(s);
- s << ") -> Self { unimplemented!();"; /* Self {";
- GenRustStructFieldNames(s);
- if (fields_exist) {
- GenRustStructSizeField(s);
+ if (!children_.empty()) {
+ s << "child: " << name_ << "DataChild,";
}
- s << "}*/
+ s << ") -> Self { ";
+
+ s << "Self { ";
+ GenRustStructFieldNames(s);
+ if (!children_.empty()) {
+ s << "child";
+ }
+
+ s << "}";
s << "}";
// parse function
- s << "fn parse(bytes: &[u8]) -> Result<Self> { unimplemented!();";
+ if (parent_constraints_.empty() && !children_.empty() && parent_ != nullptr) {
+ auto constraint = FindConstraintField();
+ auto constraint_field = GetParamList().GetField(constraint);
+ auto constraint_type = constraint_field->GetRustDataType();
+ s << "fn parse(bytes: &[u8], " << constraint << ": " << constraint_type
+ << ") -> Result<Self> {";
+ } else {
+ s << "fn parse(bytes: &[u8]) -> Result<Self> {";
+ }
auto fields = fields_.GetFieldsWithoutTypes({
BodyField::kFieldType,
+ FixedScalarField::kFieldType,
});
- /*for (auto const& field : fields) {
+ for (auto const& field : fields) {
auto start_field_offset = GetOffsetForField(field->GetName(), false);
auto end_field_offset = GetOffsetForField(field->GetName(), true);
@@ -877,6 +898,63 @@
field->GenRustGetter(s, start_field_offset, end_field_offset);
}
+ auto payload_field = fields_.GetFieldsWithTypes({
+ PayloadField::kFieldType,
+ });
+
+ Size payload_offset;
+
+ if (payload_field.HasPayload()) {
+ payload_offset = GetOffsetForField(payload_field[0]->GetName(), false);
+ }
+
+ auto constraint_name = FindConstraintField();
+ auto constrained_descendants = FindDescendantsWithConstraint(constraint_name);
+
+ if (!children_.empty()) {
+ s << "let child = match " << constraint_name << " {";
+ }
+
+ for (const auto& desc : constrained_descendants) {
+ if (desc.first->name_.rfind("LeGetVendorCapabilitiesComplete", 0) == 0) {
+ continue;
+ }
+ auto desc_path = FindPathToDescendant(desc.first->name_);
+ std::reverse(desc_path.begin(), desc_path.end());
+ auto constraint_field = GetParamList().GetField(constraint_name);
+ auto constraint_type = constraint_field->GetFieldType();
+
+ if (constraint_type == EnumField::kFieldType) {
+ auto type = std::get<std::string>(desc.second);
+ auto variant_name = type.substr(type.find("::") + 2, type.length());
+ auto enum_type = type.substr(0, type.find("::"));
+ auto enum_variant = enum_type + "::"
+ + util::UnderscoreToCamelCase(util::ToLowerCase(variant_name));
+ s << enum_variant;
+ s << " => {";
+ s << name_ << "DataChild::";
+ s << desc_path[0]->name_ << "(Arc::new(";
+ if (desc_path[0]->parent_constraints_.empty()) {
+ s << desc_path[0]->name_ << "Data::parse(&bytes[" << payload_offset.bytes() << "..]";
+ s << ", " << enum_variant << ")?))";
+ } else {
+ s << desc_path[0]->name_ << "Data::parse(&bytes[" << payload_offset.bytes() << "..])?))";
+ }
+ } else if (constraint_type == ScalarField::kFieldType) {
+ s << std::get<int64_t>(desc.second) << " => {";
+ s << "unimplemented!();";
+ }
+ s << "}\n";
+ }
+
+ if (!constrained_descendants.empty()) {
+ s << "_ => panic!(\"unexpected value " << "\"),";
+ }
+
+ if (!children_.empty()) {
+ s << "};\n";
+ }
+
s << "Ok(Self {";
fields = fields_.GetFieldsWithoutTypes({
BodyField::kFieldType,
@@ -885,6 +963,7 @@
ReservedField::kFieldType,
SizeField::kFieldType,
PayloadField::kFieldType,
+ FixedScalarField::kFieldType,
});
if (fields_exist) {
@@ -893,9 +972,12 @@
s << fields[i]->GetName();
s << ", ";
}
- GenRustStructSizeField(s);
}
- s << "})}\n";*/
+
+ if (!children_.empty()) {
+ s << "child,";
+ }
+ s << "})\n";
s << "}\n";
// write_to function
@@ -928,6 +1010,9 @@
if (!children_.empty()) {
s << "match &self.child {";
for (const auto& child : children_) {
+ if (child->name_.rfind("LeGetVendorCapabilitiesComplete", 0) == 0) {
+ continue;
+ }
s << name_ << "DataChild::" << child->name_ << "(value) => value.write_to(buffer),";
}
s << name_ << "DataChild::None => {}";
@@ -945,7 +1030,7 @@
}
void PacketDef::GenRustAccessStructImpls(std::ostream& s) const {
- if (complement_ != nullptr) {
+ if (complement_ != nullptr && complement_->name_.rfind("LeGetVendorCapabilitiesComplete", 0) != 0) {
auto complement_root = complement_->GetRootDef();
auto complement_root_accessor = util::CamelCaseToUnderScore(complement_root->name_);
s << "impl CommandExpectations for " << name_ << "Packet {";
@@ -977,6 +1062,9 @@
s << " pub fn specialize(&self) -> " << name_ << "Child {";
s << " match &self." << util::CamelCaseToUnderScore(name_) << ".child {";
for (const auto& child : children_) {
+ if (child->name_.rfind("LeGetVendorCapabilitiesComplete", 0) == 0) {
+ continue;
+ }
s << name_ << "DataChild::" << child->name_ << "(_) => " << name_ << "Child::" << child->name_ << "("
<< child->name_ << "Packet::new(self." << root_accessor << ".clone())),";
}
@@ -1047,7 +1135,7 @@
}
void PacketDef::GenRustBuilderStructImpls(std::ostream& s) const {
- if (complement_ != nullptr) {
+ if (complement_ != nullptr && complement_->name_.rfind("LeGetVendorCapabilitiesComplete", 0) != 0) {
auto complement_root = complement_->GetRootDef();
auto complement_root_accessor = util::CamelCaseToUnderScore(complement_root->name_);
s << "impl CommandExpectations for " << name_ << "Builder {";
diff --git a/system/gd/packet/parser/parent_def.cc b/system/gd/packet/parser/parent_def.cc
index b54fb1f..ecfd46a 100644
--- a/system/gd/packet/parser/parent_def.cc
+++ b/system/gd/packet/parser/parent_def.cc
@@ -521,3 +521,47 @@
}
return false;
}
+
+std::string ParentDef::FindConstraintField() const {
+ std::string res;
+ for (const auto& child : children_) {
+ if (!child->parent_constraints_.empty()) {
+ return child->parent_constraints_.begin()->first;
+ }
+ res = child->FindConstraintField();
+ }
+ return res;
+}
+
+std::map<const ParentDef*, const std::variant<int64_t, std::string>>
+ ParentDef::FindDescendantsWithConstraint(
+ std::string constraint_name) const {
+ std::map<const ParentDef*, const std::variant<int64_t, std::string>> res;
+
+ for (auto const& child : children_) {
+ auto constraint = child->parent_constraints_.find(constraint_name);
+ if (constraint != child->parent_constraints_.end()) {
+ res.insert(std::pair(child, constraint->second));
+ }
+ auto m = child->FindDescendantsWithConstraint(constraint_name);
+ res.insert(m.begin(), m.end());
+ }
+ return res;
+}
+
+std::vector<const ParentDef*> ParentDef::FindPathToDescendant(std::string descendant) const {
+ std::vector<const ParentDef*> res;
+
+ for (auto const& child : children_) {
+ auto v = child->FindPathToDescendant(descendant);
+ if (v.size() > 0) {
+ res.insert(res.begin(), v.begin(), v.end());
+ res.push_back(child);
+ }
+ if (child->name_ == descendant) {
+ res.push_back(child);
+ return res;
+ }
+ }
+ return res;
+}
diff --git a/system/gd/packet/parser/parent_def.h b/system/gd/packet/parser/parent_def.h
index 8667d3c..7a70506 100644
--- a/system/gd/packet/parser/parent_def.h
+++ b/system/gd/packet/parser/parent_def.h
@@ -69,6 +69,12 @@
std::vector<const ParentDef*> GetAncestors() const;
+ std::string FindConstraintField() const;
+
+ std::map<const ParentDef*, const std::variant<int64_t, std::string>>
+ FindDescendantsWithConstraint(std::string constraint_name) const;
+ std::vector<const ParentDef*> FindPathToDescendant(std::string descendant) const;
+
FieldList fields_;
ParentDef* parent_{nullptr};
diff --git a/system/gd/packet/parser/util.h b/system/gd/packet/parser/util.h
index b249f8c..342e1c5 100644
--- a/system/gd/packet/parser/util.h
+++ b/system/gd/packet/parser/util.h
@@ -173,4 +173,19 @@
return "u64";
}
+inline std::string ToLowerCase(std::string value) {
+ if (value[0] < 'A' || value[0] > 'Z') {
+ ERROR() << value << " doesn't look like CONSTANT_CASE";
+ }
+
+ std::ostringstream lower_case;
+
+ for (unsigned char c : value) {
+ c = std::tolower(c);
+ lower_case << c;
+ }
+
+ return lower_case.str();
+}
+
} // namespace util