| // Copyright 2019 The Chromium Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "discovery/mdns/mdns_reader.h" |
| |
| #include <algorithm> |
| #include <utility> |
| |
| #include "discovery/common/config.h" |
| #include "discovery/mdns/public/mdns_constants.h" |
| #include "util/osp_logging.h" |
| |
| namespace openscreen { |
| namespace discovery { |
| namespace { |
| |
| bool TryParseDnsType(uint16_t to_parse, DnsType* type) { |
| auto it = std::find(kSupportedDnsTypes.begin(), kSupportedDnsTypes.end(), |
| static_cast<DnsType>(to_parse)); |
| if (it == kSupportedDnsTypes.end()) { |
| return false; |
| } |
| |
| *type = *it; |
| return true; |
| } |
| |
| } // namespace |
| |
| MdnsReader::MdnsReader(const Config& config, |
| const uint8_t* buffer, |
| size_t length) |
| : BigEndianReader(buffer, length), |
| kMaximumAllowedRdataSize( |
| static_cast<size_t>(config.maximum_valid_rdata_size)) { |
| // TODO(rwkeane): Validate |maximum_valid_rdata_size| > MaxWireSize() for |
| // rdata types A, AAAA, SRV, PTR. |
| OSP_DCHECK_GT(config.maximum_valid_rdata_size, 0); |
| } |
| |
| bool MdnsReader::Read(TxtRecordRdata::Entry* out) { |
| Cursor cursor(this); |
| uint8_t entry_length; |
| if (!Read(&entry_length)) { |
| return false; |
| } |
| const uint8_t* entry_begin = current(); |
| if (!Skip(entry_length)) { |
| return false; |
| } |
| out->reserve(entry_length); |
| out->insert(out->end(), entry_begin, entry_begin + entry_length); |
| cursor.Commit(); |
| return true; |
| } |
| |
| // RFC 1035: https://www.ietf.org/rfc/rfc1035.txt |
| // See section 4.1.4. Message compression. |
| bool MdnsReader::Read(DomainName* out) { |
| OSP_DCHECK(out); |
| const uint8_t* position = current(); |
| // The number of bytes consumed reading from the starting position to either |
| // the first label pointer or the final termination byte, including the |
| // pointer or the termination byte. This is equal to the actual wire size of |
| // the DomainName accounting for compression. |
| size_t bytes_consumed = 0; |
| // The number of bytes that was processed when reading the DomainName, |
| // including all label pointers and direct labels. It is used to detect |
| // circular compression. The number of processed bytes cannot be possibly |
| // greater than the length of the buffer. |
| size_t bytes_processed = 0; |
| size_t domain_name_length = 0; |
| std::vector<absl::string_view> labels; |
| // If we are pointing before the beginning or past the end of the buffer, we |
| // hit a malformed pointer. If we have processed more bytes than there are in |
| // the buffer, we are in a circular compression loop. |
| while (position >= begin() && position < end() && |
| bytes_processed <= length()) { |
| const uint8_t label_type = ReadBigEndian<uint8_t>(position); |
| if (IsTerminationLabel(label_type)) { |
| ErrorOr<DomainName> domain = |
| DomainName::TryCreate(labels.begin(), labels.end()); |
| if (domain.is_error()) { |
| return false; |
| } |
| *out = std::move(domain.value()); |
| if (!bytes_consumed) { |
| bytes_consumed = position + sizeof(uint8_t) - current(); |
| } |
| return Skip(bytes_consumed); |
| } else if (IsPointerLabel(label_type)) { |
| if (position + sizeof(uint16_t) > end()) { |
| return false; |
| } |
| const uint16_t label_offset = |
| GetPointerLabelOffset(ReadBigEndian<uint16_t>(position)); |
| if (!bytes_consumed) { |
| bytes_consumed = position + sizeof(uint16_t) - current(); |
| } |
| bytes_processed += sizeof(uint16_t); |
| position = begin() + label_offset; |
| } else if (IsDirectLabel(label_type)) { |
| const uint8_t label_length = GetDirectLabelLength(label_type); |
| OSP_DCHECK_GT(label_length, 0); |
| bytes_processed += sizeof(uint8_t); |
| position += sizeof(uint8_t); |
| if (position + label_length >= end()) { |
| return false; |
| } |
| const absl::string_view label(reinterpret_cast<const char*>(position), |
| label_length); |
| domain_name_length += label_length + 1; // including the length byte |
| if (!IsValidDomainLabel(label) || |
| domain_name_length > kMaxDomainNameLength) { |
| return false; |
| } |
| labels.push_back(label); |
| bytes_processed += label_length; |
| position += label_length; |
| } else { |
| return false; |
| } |
| } |
| return false; |
| } |
| |
| bool MdnsReader::Read(RawRecordRdata* out) { |
| OSP_DCHECK(out); |
| Cursor cursor(this); |
| uint16_t record_length; |
| if (Read(&record_length)) { |
| if (record_length > kMaximumAllowedRdataSize) { |
| return false; |
| } |
| |
| std::vector<uint8_t> buffer(record_length); |
| if (Read(buffer.size(), buffer.data())) { |
| ErrorOr<RawRecordRdata> rdata = |
| RawRecordRdata::TryCreate(std::move(buffer)); |
| if (rdata.is_error()) { |
| return false; |
| } |
| *out = std::move(rdata.value()); |
| cursor.Commit(); |
| return true; |
| } |
| } |
| return false; |
| } |
| |
| bool MdnsReader::Read(SrvRecordRdata* out) { |
| OSP_DCHECK(out); |
| Cursor cursor(this); |
| uint16_t record_length; |
| uint16_t priority; |
| uint16_t weight; |
| uint16_t port; |
| DomainName target; |
| if (Read(&record_length) && Read(&priority) && Read(&weight) && Read(&port) && |
| Read(&target) && |
| (cursor.delta() == sizeof(record_length) + record_length)) { |
| *out = SrvRecordRdata(priority, weight, port, std::move(target)); |
| cursor.Commit(); |
| return true; |
| } |
| return false; |
| } |
| |
| bool MdnsReader::Read(ARecordRdata* out) { |
| OSP_DCHECK(out); |
| Cursor cursor(this); |
| uint16_t record_length; |
| IPAddress address; |
| if (Read(&record_length) && (record_length == IPAddress::kV4Size) && |
| Read(IPAddress::Version::kV4, &address)) { |
| *out = ARecordRdata(address); |
| cursor.Commit(); |
| return true; |
| } |
| return false; |
| } |
| |
| bool MdnsReader::Read(AAAARecordRdata* out) { |
| OSP_DCHECK(out); |
| Cursor cursor(this); |
| uint16_t record_length; |
| IPAddress address; |
| if (Read(&record_length) && (record_length == IPAddress::kV6Size) && |
| Read(IPAddress::Version::kV6, &address)) { |
| *out = AAAARecordRdata(address); |
| cursor.Commit(); |
| return true; |
| } |
| return false; |
| } |
| |
| bool MdnsReader::Read(PtrRecordRdata* out) { |
| OSP_DCHECK(out); |
| Cursor cursor(this); |
| uint16_t record_length; |
| DomainName ptr_domain; |
| if (Read(&record_length) && Read(&ptr_domain) && |
| (cursor.delta() == sizeof(record_length) + record_length)) { |
| *out = PtrRecordRdata(std::move(ptr_domain)); |
| cursor.Commit(); |
| return true; |
| } |
| return false; |
| } |
| |
| bool MdnsReader::Read(TxtRecordRdata* out) { |
| OSP_DCHECK(out); |
| Cursor cursor(this); |
| uint16_t record_length; |
| if (!Read(&record_length)) { |
| return false; |
| } |
| if (record_length > kMaximumAllowedRdataSize) { |
| return false; |
| } |
| std::vector<TxtRecordRdata::Entry> texts; |
| while (cursor.delta() < sizeof(record_length) + record_length) { |
| TxtRecordRdata::Entry entry; |
| if (!Read(&entry)) { |
| return false; |
| } |
| OSP_DCHECK(entry.size() <= kTXTMaxEntrySize); |
| if (!entry.empty()) { |
| texts.emplace_back(entry); |
| } |
| } |
| if (cursor.delta() != sizeof(record_length) + record_length) { |
| return false; |
| } |
| ErrorOr<TxtRecordRdata> rdata = TxtRecordRdata::TryCreate(std::move(texts)); |
| if (rdata.is_error()) { |
| return false; |
| } |
| *out = std::move(rdata.value()); |
| cursor.Commit(); |
| return true; |
| } |
| |
| bool MdnsReader::Read(NsecRecordRdata* out) { |
| OSP_DCHECK(out); |
| Cursor cursor(this); |
| |
| const uint8_t* start_position = current(); |
| uint16_t record_length; |
| DomainName next_record_name; |
| if (!Read(&record_length) || !Read(&next_record_name)) { |
| return false; |
| } |
| if (record_length > kMaximumAllowedRdataSize) { |
| return false; |
| } |
| |
| // Calculate the next record name length. This may not be equal to the length |
| // of |next_record_name| due to domain name compression. |
| const int encoded_next_name_length = |
| current() - start_position - sizeof(record_length); |
| const int remaining_length = record_length - encoded_next_name_length; |
| if (remaining_length <= 0) { |
| // This means either the length is invalid or the NSEC record has no |
| // associated types. |
| return false; |
| } |
| |
| std::vector<DnsType> types; |
| if (Read(&types, remaining_length)) { |
| *out = NsecRecordRdata(std::move(next_record_name), std::move(types)); |
| cursor.Commit(); |
| return true; |
| } |
| |
| return false; |
| } |
| |
| bool MdnsReader::Read(MdnsRecord* out) { |
| OSP_DCHECK(out); |
| Cursor cursor(this); |
| DomainName name; |
| uint16_t type; |
| uint16_t rrclass; |
| uint32_t ttl; |
| Rdata rdata; |
| if (Read(&name) && Read(&type) && Read(&rrclass) && Read(&ttl) && |
| Read(static_cast<DnsType>(type), &rdata)) { |
| ErrorOr<MdnsRecord> record = MdnsRecord::TryCreate( |
| std::move(name), static_cast<DnsType>(type), GetDnsClass(rrclass), |
| GetRecordType(rrclass), std::chrono::seconds(ttl), std::move(rdata)); |
| if (record.is_error()) { |
| return false; |
| } |
| *out = std::move(record.value()); |
| |
| cursor.Commit(); |
| return true; |
| } |
| return false; |
| } |
| |
| bool MdnsReader::Read(MdnsQuestion* out) { |
| OSP_DCHECK(out); |
| Cursor cursor(this); |
| DomainName name; |
| uint16_t type; |
| uint16_t rrclass; |
| if (Read(&name) && Read(&type) && Read(&rrclass)) { |
| ErrorOr<MdnsQuestion> question = |
| MdnsQuestion::TryCreate(std::move(name), static_cast<DnsType>(type), |
| GetDnsClass(rrclass), GetResponseType(rrclass)); |
| if (question.is_error()) { |
| return false; |
| } |
| *out = std::move(question.value()); |
| |
| cursor.Commit(); |
| return true; |
| } |
| return false; |
| } |
| |
| bool MdnsReader::Read(MdnsMessage* out) { |
| OSP_DCHECK(out); |
| Cursor cursor(this); |
| Header header; |
| std::vector<MdnsQuestion> questions; |
| std::vector<MdnsRecord> answers; |
| std::vector<MdnsRecord> authority_records; |
| std::vector<MdnsRecord> additional_records; |
| if (Read(&header) && Read(header.question_count, &questions) && |
| Read(header.answer_count, &answers) && |
| Read(header.authority_record_count, &authority_records) && |
| Read(header.additional_record_count, &additional_records)) { |
| // TODO(yakimakha): Skip messages with non-zero opcode and rcode. |
| // One way to do this is to change the method signature to return |
| // ErrorOr<MdnsMessage> and return different error codes for failure to read |
| // and for messages that were read successfully but are non-conforming. |
| ErrorOr<MdnsMessage> message = MdnsMessage::TryCreate( |
| header.id, GetMessageType(header.flags), questions, answers, |
| authority_records, additional_records); |
| if (message.is_error()) { |
| return false; |
| } |
| *out = std::move(message.value()); |
| |
| if (IsMessageTruncated(header.flags)) { |
| out->set_truncated(); |
| } |
| |
| cursor.Commit(); |
| return true; |
| } |
| return false; |
| } |
| |
| bool MdnsReader::Read(IPAddress::Version version, IPAddress* out) { |
| OSP_DCHECK(out); |
| size_t ipaddress_size = (version == IPAddress::Version::kV6) |
| ? IPAddress::kV6Size |
| : IPAddress::kV4Size; |
| const uint8_t* const address_bytes = current(); |
| if (Skip(ipaddress_size)) { |
| *out = IPAddress(version, address_bytes); |
| return true; |
| } |
| return false; |
| } |
| |
| bool MdnsReader::Read(DnsType type, Rdata* out) { |
| OSP_DCHECK(out); |
| switch (type) { |
| case DnsType::kSRV: |
| return Read<SrvRecordRdata>(out); |
| case DnsType::kA: |
| return Read<ARecordRdata>(out); |
| case DnsType::kAAAA: |
| return Read<AAAARecordRdata>(out); |
| case DnsType::kPTR: |
| return Read<PtrRecordRdata>(out); |
| case DnsType::kTXT: |
| return Read<TxtRecordRdata>(out); |
| case DnsType::kNSEC: |
| return Read<NsecRecordRdata>(out); |
| default: |
| OSP_DCHECK(std::find(kSupportedDnsTypes.begin(), kSupportedDnsTypes.end(), |
| type) == kSupportedDnsTypes.end()); |
| return Read<RawRecordRdata>(out); |
| } |
| } |
| |
| bool MdnsReader::Read(Header* out) { |
| OSP_DCHECK(out); |
| Cursor cursor(this); |
| if (Read(&out->id) && Read(&out->flags) && Read(&out->question_count) && |
| Read(&out->answer_count) && Read(&out->authority_record_count) && |
| Read(&out->additional_record_count)) { |
| cursor.Commit(); |
| return true; |
| } |
| return false; |
| } |
| |
| bool MdnsReader::Read(std::vector<DnsType>* out, int remaining_size) { |
| OSP_DCHECK(out); |
| Cursor cursor(this); |
| |
| // Continue reading bitmaps until the entire input is read. If we have gone |
| // past the end of the record, it's malformed input so fail. |
| *out = std::vector<DnsType>(); |
| int processed_bytes = 0; |
| while (processed_bytes < remaining_size) { |
| NsecBitMapField bitmap; |
| if (!Read(&bitmap)) { |
| return false; |
| } |
| |
| processed_bytes += bitmap.bitmap_length + 2; |
| if (processed_bytes > remaining_size) { |
| return false; |
| } |
| |
| // The ith bit of the bitmap represents DnsType with value i, shifted |
| // a multiple of 0x100 according to the window. |
| for (int32_t i = 0; i < bitmap.bitmap_length * 8; i++) { |
| int current_byte = i / 8; |
| uint8_t bitmask = 0x80 >> i % 8; |
| |
| // If this bit flag represents a type we support, add it to the vector. |
| // Else, we won't be able to use it later on in the code anyway, so drop |
| // it. |
| DnsType type; |
| uint16_t type_index = i | (bitmap.window_block << 8); |
| if ((bitmap.bitmap[current_byte] & bitmask) && |
| TryParseDnsType(type_index, &type)) { |
| out->push_back(type); |
| } |
| } |
| } |
| |
| cursor.Commit(); |
| return true; |
| } |
| |
| bool MdnsReader::Read(NsecBitMapField* out) { |
| OSP_DCHECK(out); |
| Cursor cursor(this); |
| |
| // Read the window and bitmap length, then one byte for each byte called out |
| // by the length. |
| if (Read(&out->window_block) && Read(&out->bitmap_length)) { |
| if (out->bitmap_length == 0 || out->bitmap_length > 32) { |
| return false; |
| } |
| |
| out->bitmap = current(); |
| if (!Skip(out->bitmap_length)) { |
| return false; |
| } |
| cursor.Commit(); |
| return true; |
| } |
| |
| return false; |
| } |
| |
| } // namespace discovery |
| } // namespace openscreen |