blob: 9c8ca0715fe26386b234315e438090f2da02bb35 [file] [log] [blame]
/*
* Copyright 2019 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fields/vector_field.h"
#include "fields/count_field.h"
#include "fields/custom_field.h"
#include "util.h"
// Type tag returned by VectorField::GetFieldType().
const std::string VectorField::kFieldType("VectorField");
// Builds a vector whose elements are scalars of |element_size| bits; the element is modelled as a ScalarField
// named "val". An element size outside [0, 64] bits, or one that is not a multiple of 8 bits, is reported
// through ERROR.
VectorField::VectorField(std::string name, int element_size, std::string size_modifier, ParseLocation loc)
    : PacketField(name, loc),
      element_field_(new ScalarField("val", element_size, loc)),
      element_size_(element_size),
      size_modifier_(size_modifier) {
  const bool size_supported = (element_size >= 0) && (element_size <= 64);
  if (!size_supported) {
    ERROR(this) << __func__ << ": Not implemented for element size = " << element_size;
  }

  const bool byte_aligned = (element_size % 8) == 0;
  if (!byte_aligned) {
    ERROR(this) << "Can only have arrays with elements that are byte aligned (" << element_size << ")";
  }
}
// Builds a vector whose elements are instances of |type_def|; the element field is obtained from
// type_def->GetNewField("val", loc) and the element size is whatever that field reports. A known element size
// that is not a multiple of 8 bits is reported through ERROR.
VectorField::VectorField(std::string name, TypeDef* type_def, std::string size_modifier, ParseLocation loc)
    : PacketField(name, loc),
      element_field_(type_def->GetNewField("val", loc)),
      element_size_(element_field_->GetSize()),
      size_modifier_(size_modifier) {
  // Nothing to check when the element size is not known.
  if (element_size_.empty()) {
    return;
  }

  if (element_size_.bits() % 8 != 0) {
    ERROR(this) << "Can only have arrays with elements that are byte aligned (" << element_size_ << ")";
  }
}
const std::string& VectorField::GetFieldType() const {
return VectorField::kFieldType;
}
// Returns the size of the vector, in bits, as an expression built on the getter of the associated size field.
//  - No size field: unknown size (empty Size).
//  - SIZE field: the getter's value multiplied by 8, minus the size modifier when one is set.
//  - COUNT field with fixed-size elements: the getter's value multiplied by the element width in bits.
//  - COUNT field with elements of unknown or dynamic size: unknown size (empty Size).
Size VectorField::GetSize() const {
  // If there is no size field, then it is of unknown size.
  if (size_field_ == nullptr) {
    return Size();
  }

  // size_field_ is of type SIZE
  if (size_field_->GetFieldType() == SizeField::kFieldType) {
    std::string ret = "(static_cast<size_t>(Get" + util::UnderscoreToCamelCase(size_field_->GetName()) + "()) * 8)";
    // The modifier has to be subtracted from the value carried by the size field, exactly as GetStructSize()
    // and the generated Rust bounds checks do. Appending it verbatim would add it instead (the Rust generators
    // strip its first character with substr(1), presumably a leading '+').
    if (!size_modifier_.empty()) {
      ret += "-" + size_modifier_;
    }
    return ret;
  }

  // size_field_ is of type COUNT and elements have a fixed size
  if (!element_size_.empty() && !element_size_.has_dynamic()) {
    return "(static_cast<size_t>(Get" + util::UnderscoreToCamelCase(size_field_->GetName()) + "()) * " +
           std::to_string(element_size_.bits()) + ")";
  }

  // COUNT field but the element size is not a compile-time constant.
  return Size();
}
// Returns the size of the vector, in bits, as an expression over the builder's member "<name>_".
// Fixed-size elements yield "<name>_.size() * <element bits>"; otherwise an immediately-invoked lambda sums
// size() * 8 over the elements, dereferencing with "->" when the element type must be moved and "." otherwise.
Size VectorField::GetBuilderSize() const {
  const bool fixed_element_size = !element_size_.empty() && !element_size_.has_dynamic();
  if (fixed_element_size) {
    return "(static_cast<size_t>(" + GetName() + "_.size()) * " + std::to_string(element_size_.bits()) + ")";
  }

  const char* const summing_tail = element_field_->BuilderParameterMustBeMoved()
                                       ? "_) { length += elem->size() * 8; } return length; }()"
                                       : "_) { length += elem.size() * 8; } return length; }()";
  return "[this](){ size_t length = 0; for (const auto& elem : " + GetName() + summing_tail;
}
// Returns the size of the vector, in bits, as an expression over "to_fill-><size field>_extracted_".
// The result is an empty Size when there is no size field, or when the size field is not a SIZE field and the
// element size is unknown or dynamic.
Size VectorField::GetStructSize() const {
  // Without a size field the size is unknown.
  if (size_field_ == nullptr) {
    return Size();
  }

  const bool described_by_size = size_field_->GetFieldType() == SizeField::kFieldType;
  if (described_by_size) {
    // SIZE field: extracted value times 8, with the modifier subtracted when present.
    std::string bits = "(static_cast<size_t>(to_fill->" + size_field_->GetName() + "_extracted_) * 8)";
    if (!size_modifier_.empty()) {
      bits.append("-").append(size_modifier_);
    }
    return bits;
  }

  // Otherwise the size is only known when every element has the same, fixed width.
  if (element_size_.empty() || element_size_.has_dynamic()) {
    return Size();
  }
  return "(static_cast<size_t>(to_fill->" + size_field_->GetName() + "_extracted_) * " +
         std::to_string(element_size_.bits()) + ")";
}
// Returns the generated C++ type of the field: a std::vector of the element's data type.
std::string VectorField::GetDataType() const {
  std::string type = "std::vector<";
  type += element_field_->GetDataType();
  type += ">";
  return type;
}
// Writes to |s| the generated loop that extracts elements starting at "<name>_it" and pushes them into
// "<name>_ptr".
//  - With a COUNT size field, a "<elem>_count" local is emitted first (read from to_fill when |for_struct| is
//    set, from the size field's getter otherwise) and the loop also stops once it reaches zero.
//  - The loop runs while enough bytes remain for one element (or any bytes, when the element size is unknown).
//  - The element's own extractor is emitted inside the loop; |num_leading_bits| and |for_struct| are forwarded.
void VectorField::GenExtractor(std::ostream& s, int num_leading_bits, bool for_struct) const {
  const std::string elem = element_field_->GetName();
  const bool counted = size_field_ != nullptr && size_field_->GetFieldType() == CountField::kFieldType;

  // Iterator copy, plus the element counter when the vector is described by a COUNT field.
  s << "auto " << elem << "_it = " << GetName() << "_it;";
  if (counted) {
    s << "size_t " << elem << "_count = ";
    if (for_struct) {
      s << "to_fill->" << size_field_->GetName() << "_extracted_;";
    } else {
      s << "Get" << util::UnderscoreToCamelCase(size_field_->GetName()) << "();";
    }
  }

  // Loop header.
  s << "while (";
  if (counted) {
    s << "(" << elem << "_count-- > 0) && ";
  }
  if (element_size_.empty()) {
    s << elem << "_it.NumBytesRemaining() > 0) {";
  } else {
    s << elem << "_it.NumBytesRemaining() >= " << element_size_.bytes() << ") {";
  }

  // Storage the element extractor fills in: the element itself when it must be moved, otherwise a value plus
  // a pointer to it.
  const bool moved = element_field_->BuilderParameterMustBeMoved();
  const std::string elem_type = element_field_->GetDataType();
  if (moved) {
    s << elem_type << " " << elem << "_ptr;";
  } else {
    s << elem_type << " " << elem << "_value;";
    s << elem_type << "* " << elem << "_ptr = &" << elem << "_value;";
  }

  element_field_->GenExtractor(s, num_leading_bits, for_struct);

  // Append the element only if the generated "<elem>_ptr" is non-null.
  s << "if (" << elem << "_ptr != nullptr) { ";
  if (moved) {
    s << GetName() << "_ptr->push_back(std::move(" << elem << "_ptr));";
  } else {
    s << GetName() << "_ptr->push_back(" << elem << "_value);";
  }
  s << "}";
  s << "}";
}
// Returns the name of the generated getter: "Get" followed by the camel-cased field name.
std::string VectorField::GetGetterFunctionName() const {
  return "Get" + util::UnderscoreToCamelCase(GetName());
}
// Writes to |s| the generated getter for this field. The getter asserts was_validated_, sets up the bounds via
// GenBounds(), extracts the elements into a local "<name>_value" with GenExtractor() and returns it.
void VectorField::GenGetter(std::ostream& s, Size start_offset, Size end_offset) const {
  const std::string vector_type = GetDataType();
  const std::string field_name = GetName();

  // Signature and prologue.
  s << vector_type << " " << GetGetterFunctionName() << "() {";
  s << "ASSERT(was_validated_);";
  s << "size_t end_index = size();";
  s << "auto to_bound = begin();";

  const int num_leading_bits = GenBounds(s, start_offset, end_offset, GetSize());

  // Result storage and the pointer GenExtractor() pushes through.
  s << vector_type << " " << field_name << "_value{};";
  s << vector_type << "* " << field_name << "_ptr = &" << field_name << "_value;";
  GenExtractor(s, num_leading_bits, false);

  s << "return " << field_name << "_value;";
  s << "}\n";
}
// Returns the type used for this field in builder parameters: the vector by value when the element type must
// be moved, a const reference to the vector otherwise.
std::string VectorField::GetBuilderParameterType() const {
  const bool by_value = element_field_->BuilderParameterMustBeMoved();
  const std::string vector_type = "std::vector<" + element_field_->GetDataType() + ">";
  if (by_value) {
    return vector_type;
  }
  return "const " + vector_type + "&";
}
// Reports whether the builder parameter must be moved; the answer is delegated to the element field.
bool VectorField::BuilderParameterMustBeMoved() const {
return element_field_->BuilderParameterMustBeMoved();
}
// Writes the builder member declaration "std::vector<element type> <name>" (without terminator) to |s|.
// Always returns true.
bool VectorField::GenBuilderMember(std::ostream& s) const {
  const std::string element_type = element_field_->GetDataType();
  s << "std::vector<" << element_type << "> ";
  s << GetName();
  return true;
}
// Reports whether a parameter validator is generated for this field. Always false for now.
bool VectorField::HasParameterValidator() const {
// Does not have a parameter validator yet.
// TODO: See comment in GenParameterValidator
return false;
}
// Intentionally generates nothing: there is no parameter validator for dynamically sized fields.
void VectorField::GenParameterValidator(std::ostream&) const {
  // TODO: Maybe add a validator to ensure that the size isn't larger than what the size field can hold.
}
// Writes to |s| a generated range-for over the builder member "<name>_" whose body is the element field's
// inserter; the loop variable is named "val_".
void VectorField::GenInserter(std::ostream& s) const {
  const std::string member = GetName() + "_";
  s << "for (const auto& val_ : " << member << ") {";
  element_field_->GenInserter(s);
  s << "}\n";
}
// Intentionally generates no validation code for vectors.
void VectorField::GenValidator(std::ostream&) const {
// NOTE: We could check whether the element size divides the array size evenly, but we decided to forgo that
// in favor of returning as many elements as possible, best-effort style.
//
// Apart from that, length is the only thing an array could be validated on, so nothing needs to be done
// here.
}
void VectorField::SetSizeField(const SizeField* size_field) {
if (size_field->GetFieldType() == CountField::kFieldType && !size_modifier_.empty()) {
ERROR(this, size_field) << "Can not use count field to describe array with a size modifier."
<< " Use size instead";
}
size_field_ = size_field;
}
// Returns the size modifier given at construction; empty when there is none.
const std::string& VectorField::GetSizeModifier() const {
return size_modifier_;
}
// A vector is always a container field.
bool VectorField::IsContainerField() const {
return true;
}
// Returns the field that describes a single element of the vector.
const PacketField* VectorField::GetElementField() const {
return element_field_;
}
// Writes to |s| generated code that appends a textual form of the vector reachable through |accessor| to a
// stream named "ss": the literal "VECTOR[", the elements separated by ", ", then "]".
// Elements of custom type are rendered through their ToString(); every other element type is rendered by the
// element field's own GenStringRepresentation(). The final statement is left unterminated, as the caller is
// expected to close it (not visible from this function).
void VectorField::GenStringRepresentation(std::ostream& s, std::string accessor) const {
  s << "\"VECTOR[\";";
  s << "for (size_t index = 0; index < " << accessor << ".size(); index++) {";

  // Separator first, so that no trailing ", " is produced.
  s << "ss << ((index == 0) ? \"\" : \", \") << ";

  const std::string element_accessor = "(" + accessor + "[index])";
  if (element_field_->GetFieldType() == CustomField::kFieldType) {
    s << element_accessor << ".ToString()";
  } else {
    element_field_->GenStringRepresentation(s, element_accessor);
  }

  s << ";}";
  s << "ss << \"]\"";
}
// Returns the generated Rust type of the field: a Vec of the element's Rust data type.
std::string VectorField::GetRustDataType() const {
  std::string type = "Vec::<";
  type += element_field_->GetRustDataType();
  type += ">";
  return type;
}
void VectorField::GenBoundsCheck(std::ostream& s, Size start_offset, Size, std::string context) const {
auto element_field_type = GetElementField()->GetFieldType();
auto element_field = GetElementField();
auto element_size = element_field->GetSize().bytes();
if (element_field_type == ScalarField::kFieldType) {
if (size_field_ == nullptr) {
s << "let rem_ = (bytes.len() - " << start_offset.bytes() << ") % " << element_size << ";";
s << "if rem_ != 0 {";
s << " return Err(Error::InvalidLengthError{";
s << " obj: \"" << context << "\".to_string(),";
s << " field: \"" << GetName() << "\".to_string(),";
s << " wanted: bytes.len() + rem_,";
s << " got: bytes.len()});";
s << "}";
} else if (size_field_->GetFieldType() == CountField::kFieldType) {
s << "let want_ = " << start_offset.bytes() << " + ((" << size_field_->GetName() << " as usize) * "
<< element_size << ");";
s << "if bytes.len() < want_ {";
s << " return Err(Error::InvalidLengthError{";
s << " obj: \"" << context << "\".to_string(),";
s << " field: \"" << GetName() << "\".to_string(),";
s << " wanted: want_,";
s << " got: bytes.len()});";
s << "}";
} else {
s << "let want_ = " << start_offset.bytes() << " + (" << size_field_->GetName() << " as usize)";
if (GetSizeModifier() != "") {
s << " - ((" << GetSizeModifier().substr(1) << ") / 8)";
}
s << ";";
s << "if bytes.len() < want_ {";
s << " return Err(Error::InvalidLengthError{";
s << " obj: \"" << context << "\".to_string(),";
s << " field: \"" << GetName() << "\".to_string(),";
s << " wanted: want_,";
s << " got: bytes.len()});";
s << "}";
if (GetSizeModifier() != "") {
s << "if ((" << size_field_->GetName() << " as usize) < ((" << GetSizeModifier().substr(1) << ") / 8)) {";
s << " return Err(Error::ImpossibleStructError);";
s << "}";
}
}
}
}
void VectorField::GenRustGetter(std::ostream& s, Size start_offset, Size) const {
auto element_field_type = GetElementField()->GetFieldType();
auto element_field = GetElementField();
auto element_size = element_field->GetSize().bytes();
if (element_field_type == ScalarField::kFieldType) {
s << "let " << GetName() << ": " << GetRustDataType() << " = ";
if (size_field_ == nullptr) {
s << "bytes[" << start_offset.bytes() << "..]";
} else if (size_field_->GetFieldType() == CountField::kFieldType) {
s << "bytes[" << start_offset.bytes() << ".." << start_offset.bytes() << " + ((";
s << size_field_->GetName() << " as usize) * " << element_size << ")]";
} else {
s << "bytes[" << start_offset.bytes() << "..(";
s << start_offset.bytes() << " + " << size_field_->GetName();
s << " as usize)";
if (GetSizeModifier() != "") {
s << " - ((" << GetSizeModifier().substr(1) << ") / 8)";
}
s << "]";
}
s << ".to_vec().chunks_exact(" << element_size << ").into_iter().map(|i| ";
s << element_field->GetRustDataType() << "::from_le_bytes([";
for (int j=0; j < element_size; j++) {
s << "i[" << j << "]";
if (j != element_size - 1) {
s << ", ";
}
}
s << "])).collect();";
} else {
s << "let mut " << GetName() << ": " << GetRustDataType() << " = Vec::new();";
if (size_field_ == nullptr) {
s << "let mut parsable_ = &bytes[" << start_offset.bytes() << "..];";
s << "while parsable_.len() > 0 {";
} else if (size_field_->GetFieldType() == CountField::kFieldType) {
s << "let mut parsable_ = &bytes[" << start_offset.bytes() << "..];";
s << "let count_ = " << size_field_->GetName() << " as usize;";
s << "for _ in 0..count_ {";
} else {
s << "let mut parsable_ = &bytes[" << start_offset.bytes() << ".." << start_offset.bytes() << " + ("
<< size_field_->GetName() << " as usize)";
if (GetSizeModifier() != "") {
s << " - ((" << GetSizeModifier().substr(1) << ") / 8)";
}
s << "];";
s << "while parsable_.len() > 0 {";
}
s << " match " << element_field->GetRustDataType() << "::parse(&parsable_) {";
s << " Ok(parsed) => {";
s << " parsable_ = &parsable_[parsed.get_total_size()..];";
s << GetName() << ".push(parsed);";
s << " },";
s << " Err(Error::ImpossibleStructError) => break,";
s << " Err(e) => return Err(e),";
s << " }";
s << "}";
}
}
// Writes to |s| the generated Rust code that serializes this vector into "buffer", starting at |start_offset|.
//  - Scalar elements: element i is copied, little-endian, into the element-sized slot number i.
//  - Other elements: each element's write_to() is called on a slice of exactly get_total_size() bytes, and the
//    working slice is advanced by that amount.
void VectorField::GenRustWriter(std::ostream& s, Size start_offset, Size) const {
  if (GetElementField()->GetFieldType() == ScalarField::kFieldType) {
    const auto element_bytes = GetElementField()->GetSize().bytes();
    s << "for (i, e) in self." << GetName() << ".iter().enumerate() {";
    // The slot of element i starts at i * element size. Using the bare index i as a byte offset is only
    // correct for one-byte elements and makes wider elements overlap; the generated reader splits the same
    // range with chunks_exact(<element size>), so the writer must use the same layout.
    s << "buffer[" << start_offset.bytes() << "+i*" << element_bytes << "..";
    s << start_offset.bytes() << "+(i+1)*" << element_bytes << "]";
    s << ".copy_from_slice(&e.to_le_bytes())";
    s << "}";
  } else {
    s << "let mut vec_buffer_ = &mut buffer[" << start_offset.bytes() << "..];";
    s << "for e_ in &self." << GetName() << " {";
    s << " e_.write_to(&mut vec_buffer_[0..e_.get_total_size()]);";
    s << " vec_buffer_ = &mut vec_buffer_[e_.get_total_size()..];";
    s << "}";
  }
}