blob: bd38e81be23cbeda4ac78d5c32cbc937e94fb69d [file] [log] [blame]
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "discovery/mdns/mdns_reader.h"
#include <algorithm>
#include <utility>
#include "discovery/common/config.h"
#include "discovery/mdns/public/mdns_constants.h"
#include "util/osp_logging.h"
namespace openscreen {
namespace discovery {
namespace {
bool TryParseDnsType(uint16_t to_parse, DnsType* type) {
auto it = std::find(kSupportedDnsTypes.begin(), kSupportedDnsTypes.end(),
static_cast<DnsType>(to_parse));
if (it == kSupportedDnsTypes.end()) {
return false;
}
*type = *it;
return true;
}
} // namespace
// Constructs a reader over the |length| bytes starting at |buffer|, which are
// handed to the BigEndianReader base. |config.maximum_valid_rdata_size| is
// stored as the upper bound that the rdata-length checks in this file compare
// against.
MdnsReader::MdnsReader(const Config& config,
const uint8_t* buffer,
size_t length)
: BigEndianReader(buffer, length),
kMaximumAllowedRdataSize(
static_cast<size_t>(config.maximum_valid_rdata_size)) {
// TODO(rwkeane): Validate |maximum_valid_rdata_size| > MaxWireSize() for
// rdata types A, AAAA, SRV, PTR.
// A non-positive limit would make every rdata length check fail, so it is
// treated as a programming error.
OSP_DCHECK_GT(config.maximum_valid_rdata_size, 0);
}
// Reads one length-prefixed TXT entry: a single length byte followed by that
// many payload bytes. The payload bytes are appended to |out|. Returns false
// if either the length byte or the payload cannot be read.
// NOTE(review): presumably an uncommitted Cursor restores the read position
// when it goes out of scope - confirm against the Cursor implementation.
bool MdnsReader::Read(TxtRecordRdata::Entry* out) {
  Cursor cursor(this);
  uint8_t entry_length;
  if (Read(&entry_length)) {
    // Remember where the payload starts before stepping over it.
    const uint8_t* const payload_begin = current();
    if (Skip(entry_length)) {
      const uint8_t* const payload_end = payload_begin + entry_length;
      out->reserve(entry_length);
      out->insert(out->end(), payload_begin, payload_end);
      cursor.Commit();
      return true;
    }
  }
  return false;
}
// RFC 1035: https://www.ietf.org/rfc/rfc1035.txt
// See section 4.1.4. Message compression.
//
// Reads a (possibly compressed) domain name starting at the current position
// and stores it in |out|. On success the reader is advanced by the name's wire
// size, i.e. up to and including the first label pointer or the termination
// byte, whichever comes first. Returns false on malformed input; every failure
// path returns before Skip() is called.
bool MdnsReader::Read(DomainName* out) {
  OSP_DCHECK(out);
  const uint8_t* position = current();
  // The number of bytes consumed reading from the starting position to either
  // the first label pointer or the final termination byte, including the
  // pointer or the termination byte. This is equal to the actual wire size of
  // the DomainName accounting for compression.
  size_t bytes_consumed = 0;
  // The number of bytes that was processed when reading the DomainName,
  // including all label pointers and direct labels. It is used to detect
  // circular compression. The number of processed bytes cannot be possibly
  // greater than the length of the buffer.
  size_t bytes_processed = 0;
  size_t domain_name_length = 0;
  std::vector<absl::string_view> labels;
  // |position| is only ever moved to an address that was first validated
  // against the buffer bounds, so it stays within [begin(), end()]. If we have
  // processed more bytes than there are in the buffer, we are in a circular
  // compression loop.
  while (position >= begin() && position < end() &&
         bytes_processed <= length()) {
    const uint8_t label_type = ReadBigEndian<uint8_t>(position);
    if (IsTerminationLabel(label_type)) {
      ErrorOr<DomainName> domain =
          DomainName::TryCreate(labels.begin(), labels.end());
      if (domain.is_error()) {
        return false;
      }
      *out = std::move(domain.value());
      // If no pointer label was followed, the wire size runs through this
      // termination byte.
      if (!bytes_consumed) {
        bytes_consumed = position + sizeof(uint8_t) - current();
      }
      return Skip(bytes_consumed);
    } else if (IsPointerLabel(label_type)) {
      // A pointer label is two bytes wide; both must lie inside the buffer.
      // Compare sizes rather than pointers so that no address beyond
      // one-past-the-end is ever formed.
      if (static_cast<size_t>(end() - position) < sizeof(uint16_t)) {
        return false;
      }
      const uint16_t label_offset =
          GetPointerLabelOffset(ReadBigEndian<uint16_t>(position));
      // A pointer that targets a location at or past the end of the buffer is
      // malformed. Reject it before computing the target address.
      if (label_offset >= static_cast<size_t>(end() - begin())) {
        return false;
      }
      // Only the first pointer contributes to the wire size of this name.
      if (!bytes_consumed) {
        bytes_consumed = position + sizeof(uint16_t) - current();
      }
      bytes_processed += sizeof(uint16_t);
      position = begin() + label_offset;
    } else if (IsDirectLabel(label_type)) {
      const uint8_t label_length = GetDirectLabelLength(label_type);
      OSP_DCHECK_GT(label_length, 0);
      bytes_processed += sizeof(uint8_t);
      position += sizeof(uint8_t);
      // The label must be followed by at least one more byte (another label,
      // a pointer or the terminator), so it may not reach the end of the
      // buffer. The check is done on sizes to avoid out-of-range pointer
      // arithmetic.
      if (label_length >= static_cast<size_t>(end() - position)) {
        return false;
      }
      const absl::string_view label(reinterpret_cast<const char*>(position),
                                    label_length);
      domain_name_length += label_length + 1;  // including the length byte
      if (!IsValidDomainLabel(label) ||
          domain_name_length > kMaxDomainNameLength) {
        return false;
      }
      labels.push_back(label);
      bytes_processed += label_length;
      position += label_length;
    } else {
      // Neither a terminator, a pointer nor a direct label.
      return false;
    }
  }
  return false;
}
// Reads a 16-bit rdata length followed by that many opaque bytes and stores
// the result in |out|. Fails if the declared length exceeds
// kMaximumAllowedRdataSize, if the bytes cannot be read, or if
// RawRecordRdata::TryCreate() reports an error.
bool MdnsReader::Read(RawRecordRdata* out) {
  OSP_DCHECK(out);
  Cursor cursor(this);
  uint16_t record_length;
  if (!Read(&record_length)) {
    return false;
  }
  // Check the limit before allocating a buffer of the advertised size.
  if (record_length > kMaximumAllowedRdataSize) {
    return false;
  }
  std::vector<uint8_t> payload(record_length);
  if (!Read(payload.size(), payload.data())) {
    return false;
  }
  ErrorOr<RawRecordRdata> parsed =
      RawRecordRdata::TryCreate(std::move(payload));
  if (parsed.is_error()) {
    return false;
  }
  *out = std::move(parsed.value());
  cursor.Commit();
  return true;
}
// Reads SRV rdata: a 16-bit rdata length, then 16-bit priority, weight and
// port values, then the target domain name. The read only succeeds when the
// number of bytes consumed equals the length field's own size plus the length
// it declares.
bool MdnsReader::Read(SrvRecordRdata* out) {
  OSP_DCHECK(out);
  Cursor cursor(this);
  uint16_t record_length;
  uint16_t priority;
  uint16_t weight;
  uint16_t port;
  DomainName target;
  const bool fields_read = Read(&record_length) && Read(&priority) &&
                           Read(&weight) && Read(&port) && Read(&target);
  if (!fields_read) {
    return false;
  }
  // The declared rdata length must match what was actually consumed.
  if (cursor.delta() != sizeof(record_length) + record_length) {
    return false;
  }
  *out = SrvRecordRdata(priority, weight, port, std::move(target));
  cursor.Commit();
  return true;
}
// Reads A-record rdata: a 16-bit rdata length that must equal
// IPAddress::kV4Size, followed by an IPv4 address. Stores the result in |out|.
bool MdnsReader::Read(ARecordRdata* out) {
  OSP_DCHECK(out);
  Cursor cursor(this);
  uint16_t record_length;
  if (!Read(&record_length)) {
    return false;
  }
  if (record_length != IPAddress::kV4Size) {
    return false;
  }
  IPAddress address;
  if (!Read(IPAddress::Version::kV4, &address)) {
    return false;
  }
  *out = ARecordRdata(address);
  cursor.Commit();
  return true;
}
// Reads AAAA-record rdata: a 16-bit rdata length that must equal
// IPAddress::kV6Size, followed by an IPv6 address. Stores the result in |out|.
bool MdnsReader::Read(AAAARecordRdata* out) {
  OSP_DCHECK(out);
  Cursor cursor(this);
  uint16_t record_length;
  if (!Read(&record_length)) {
    return false;
  }
  if (record_length != IPAddress::kV6Size) {
    return false;
  }
  IPAddress address;
  if (!Read(IPAddress::Version::kV6, &address)) {
    return false;
  }
  *out = AAAARecordRdata(address);
  cursor.Commit();
  return true;
}
// Reads PTR rdata: a 16-bit rdata length followed by a domain name. The read
// only succeeds when the bytes consumed equal the length field's own size plus
// the length it declares.
bool MdnsReader::Read(PtrRecordRdata* out) {
  OSP_DCHECK(out);
  Cursor cursor(this);
  uint16_t record_length;
  DomainName ptr_domain;
  if (!Read(&record_length) || !Read(&ptr_domain)) {
    return false;
  }
  // The declared rdata length must match what was actually consumed.
  if (cursor.delta() != sizeof(record_length) + record_length) {
    return false;
  }
  *out = PtrRecordRdata(std::move(ptr_domain));
  cursor.Commit();
  return true;
}
// Reads TXT rdata: a 16-bit rdata length followed by a sequence of
// length-prefixed entries that together must occupy exactly the declared
// length. Empty entries are dropped. Fails if the declared length exceeds
// kMaximumAllowedRdataSize, if an entry cannot be read, if the entries overrun
// the declared length, or if TxtRecordRdata::TryCreate() reports an error.
bool MdnsReader::Read(TxtRecordRdata* out) {
  OSP_DCHECK(out);
  Cursor cursor(this);
  uint16_t record_length;
  if (!Read(&record_length)) {
    return false;
  }
  if (record_length > kMaximumAllowedRdataSize) {
    return false;
  }
  // Total number of bytes this record occupies on the wire, including the
  // length field itself.
  const size_t expected_size = sizeof(record_length) + record_length;
  std::vector<TxtRecordRdata::Entry> texts;
  while (cursor.delta() < expected_size) {
    TxtRecordRdata::Entry entry;
    if (!Read(&entry)) {
      return false;
    }
    OSP_DCHECK(entry.size() <= kTXTMaxEntrySize);
    if (!entry.empty()) {
      // |entry| is not used again in this iteration, so hand its storage over
      // instead of copying it.
      texts.emplace_back(std::move(entry));
    }
  }
  // The last entry must end exactly at the declared rdata boundary.
  if (cursor.delta() != expected_size) {
    return false;
  }
  ErrorOr<TxtRecordRdata> rdata = TxtRecordRdata::TryCreate(std::move(texts));
  if (rdata.is_error()) {
    return false;
  }
  *out = std::move(rdata.value());
  cursor.Commit();
  return true;
}
// Reads NSEC rdata: a 16-bit rdata length, the next record's domain name, and
// then the type bitmaps, which fill the rest of the declared length. Fails if
// the declared length exceeds kMaximumAllowedRdataSize, if no bytes remain for
// the type bitmaps, or if any component cannot be read.
bool MdnsReader::Read(NsecRecordRdata* out) {
  OSP_DCHECK(out);
  Cursor cursor(this);
  const uint8_t* rdata_start = current();
  uint16_t record_length;
  DomainName next_record_name;
  if (!Read(&record_length)) {
    return false;
  }
  if (!Read(&next_record_name)) {
    return false;
  }
  if (record_length > kMaximumAllowedRdataSize) {
    return false;
  }
  // Number of bytes the name occupied on the wire. Because of domain name
  // compression this can differ from the length of |next_record_name|.
  const int name_wire_length =
      current() - rdata_start - sizeof(record_length);
  const int bitmap_bytes = record_length - name_wire_length;
  // A non-positive remainder means the length field is invalid or the record
  // carries no associated types.
  if (bitmap_bytes <= 0) {
    return false;
  }
  std::vector<DnsType> types;
  if (!Read(&types, bitmap_bytes)) {
    return false;
  }
  *out = NsecRecordRdata(std::move(next_record_name), std::move(types));
  cursor.Commit();
  return true;
}
// Reads one resource record: name, 16-bit type, 16-bit class field, 32-bit TTL
// and the rdata for that type, then builds the record through
// MdnsRecord::TryCreate(). Returns false if any field cannot be read or if
// TryCreate() reports an error.
bool MdnsReader::Read(MdnsRecord* out) {
  OSP_DCHECK(out);
  Cursor cursor(this);
  DomainName name;
  uint16_t type;
  uint16_t rrclass;
  uint32_t ttl;
  Rdata rdata;
  if (!Read(&name) || !Read(&type) || !Read(&rrclass) || !Read(&ttl)) {
    return false;
  }
  // The type field decides which rdata parser is used.
  const DnsType dns_type = static_cast<DnsType>(type);
  if (!Read(dns_type, &rdata)) {
    return false;
  }
  ErrorOr<MdnsRecord> record = MdnsRecord::TryCreate(
      std::move(name), dns_type, GetDnsClass(rrclass), GetRecordType(rrclass),
      std::chrono::seconds(ttl), std::move(rdata));
  if (record.is_error()) {
    return false;
  }
  *out = std::move(record.value());
  cursor.Commit();
  return true;
}
// Reads one question: name, 16-bit type and 16-bit class field, then builds it
// through MdnsQuestion::TryCreate(). Returns false if any field cannot be read
// or if TryCreate() reports an error.
bool MdnsReader::Read(MdnsQuestion* out) {
  OSP_DCHECK(out);
  Cursor cursor(this);
  DomainName name;
  uint16_t type;
  uint16_t rrclass;
  if (!Read(&name) || !Read(&type) || !Read(&rrclass)) {
    return false;
  }
  ErrorOr<MdnsQuestion> question =
      MdnsQuestion::TryCreate(std::move(name), static_cast<DnsType>(type),
                              GetDnsClass(rrclass), GetResponseType(rrclass));
  if (question.is_error()) {
    return false;
  }
  *out = std::move(question.value());
  cursor.Commit();
  return true;
}
// Reads a complete mDNS message: the header followed by the question, answer,
// authority and additional record sections, each with the element count given
// in the header.
//
// Returns:
//   - Error::Code::kMdnsReadFailure if the header or any section cannot be
//     read.
//   - Error::Code::kMdnsNonConformingFailure if everything was read but the
//     header flags fail IsValidFlagsSection().
//   - The error reported by MdnsMessage::TryCreate() if construction fails.
//   - Otherwise the parsed message, marked truncated when the header flags
//     say so.
ErrorOr<MdnsMessage> MdnsReader::Read() {
  MdnsMessage out;
  Cursor cursor(this);
  Header header;
  std::vector<MdnsQuestion> questions;
  std::vector<MdnsRecord> answers;
  std::vector<MdnsRecord> authority_records;
  std::vector<MdnsRecord> additional_records;
  if (!Read(&header) || !Read(header.question_count, &questions) ||
      !Read(header.answer_count, &answers) ||
      !Read(header.authority_record_count, &authority_records) ||
      !Read(header.additional_record_count, &additional_records)) {
    return Error::Code::kMdnsReadFailure;
  }
  // The flags are validated only after all sections parsed, so a message that
  // is both unreadable and non-conforming reports a read failure.
  if (!IsValidFlagsSection(header.flags)) {
    return Error::Code::kMdnsNonConformingFailure;
  }
  // The section vectors are not used again, so give them to TryCreate()
  // rather than passing lvalues that would have to be deep-copied.
  ErrorOr<MdnsMessage> message = MdnsMessage::TryCreate(
      header.id, GetMessageType(header.flags), std::move(questions),
      std::move(answers), std::move(authority_records),
      std::move(additional_records));
  if (message.is_error()) {
    return std::move(message.error());
  }
  out = std::move(message.value());
  if (IsMessageTruncated(header.flags)) {
    out.set_truncated();
  }
  cursor.Commit();
  return out;
}
// Reads a raw IP address of the given |version| from the current position:
// IPAddress::kV6Size bytes for kV6, IPAddress::kV4Size bytes otherwise.
// Returns false, leaving |out| untouched, if Skip() over those bytes fails.
bool MdnsReader::Read(IPAddress::Version version, IPAddress* out) {
  OSP_DCHECK(out);
  size_t address_size = IPAddress::kV4Size;
  if (version == IPAddress::Version::kV6) {
    address_size = IPAddress::kV6Size;
  }
  // Capture the start of the address before advancing past it.
  const uint8_t* const raw_address = current();
  if (!Skip(address_size)) {
    return false;
  }
  *out = IPAddress(version, raw_address);
  return true;
}
// Reads the rdata for a record of the given |type| into |out| by forwarding to
// the Read<T>() helper for the matching rdata class. Types without a dedicated
// parser are read as RawRecordRdata.
bool MdnsReader::Read(DnsType type, Rdata* out) {
  OSP_DCHECK(out);
  switch (type) {
    case DnsType::kA:
      return Read<ARecordRdata>(out);
    case DnsType::kAAAA:
      return Read<AAAARecordRdata>(out);
    case DnsType::kPTR:
      return Read<PtrRecordRdata>(out);
    case DnsType::kSRV:
      return Read<SrvRecordRdata>(out);
    case DnsType::kTXT:
      return Read<TxtRecordRdata>(out);
    case DnsType::kNSEC:
      return Read<NsecRecordRdata>(out);
    default:
      break;
  }
  // Every supported type is expected to have its own case above, so reaching
  // this point with a supported type indicates a missing case.
  OSP_DCHECK(std::find(kSupportedDnsTypes.begin(), kSupportedDnsTypes.end(),
                       type) == kSupportedDnsTypes.end());
  return Read<RawRecordRdata>(out);
}
// Reads the message header fields into |out| in this order: id, flags,
// question count, answer count, authority record count, additional record
// count. Returns false as soon as one of them cannot be read; fields read
// before the failure have already been written to |out|.
bool MdnsReader::Read(Header* out) {
  OSP_DCHECK(out);
  Cursor cursor(this);
  const bool header_read =
      Read(&out->id) && Read(&out->flags) && Read(&out->question_count) &&
      Read(&out->answer_count) && Read(&out->authority_record_count) &&
      Read(&out->additional_record_count);
  if (!header_read) {
    return false;
  }
  cursor.Commit();
  return true;
}
bool MdnsReader::Read(std::vector<DnsType>* out, int remaining_size) {
OSP_DCHECK(out);
Cursor cursor(this);
// Continue reading bitmaps until the entire input is read. If we have gone
// past the end of the record, it's malformed input so fail.
*out = std::vector<DnsType>();
int processed_bytes = 0;
while (processed_bytes < remaining_size) {
NsecBitMapField bitmap;
if (!Read(&bitmap)) {
return false;
}
processed_bytes += bitmap.bitmap_length + 2;
if (processed_bytes > remaining_size) {
return false;
}
// The ith bit of the bitmap represents DnsType with value i, shifted
// a multiple of 0x100 according to the window.
for (int32_t i = 0; i < bitmap.bitmap_length * 8; i++) {
int current_byte = i / 8;
uint8_t bitmask = 0x80 >> i % 8;
// If this bit flag represents a type we support, add it to the vector.
// Else, we won't be able to use it later on in the code anyway, so drop
// it.
DnsType type;
uint16_t type_index = i | (bitmap.window_block << 8);
if ((bitmap.bitmap[current_byte] & bitmask) &&
TryParseDnsType(type_index, &type)) {
out->push_back(type);
}
}
}
cursor.Commit();
return true;
}
// Reads one NSEC bitmap field: the window block, the bitmap length, and then
// |bitmap_length| bytes of bitmap. |out->bitmap| is pointed at the bitmap
// bytes in place; nothing is copied. Returns false if a field cannot be read
// or if the bitmap length is outside the range [1, 32]. On failure the members
// of |out| written so far keep the values that were read.
bool MdnsReader::Read(NsecBitMapField* out) {
  OSP_DCHECK(out);
  Cursor cursor(this);
  if (!Read(&out->window_block) || !Read(&out->bitmap_length)) {
    return false;
  }
  const bool length_in_range =
      out->bitmap_length != 0 && out->bitmap_length <= 32;
  if (!length_in_range) {
    return false;
  }
  // Record where the bitmap starts, then step over it.
  out->bitmap = current();
  if (!Skip(out->bitmap_length)) {
    return false;
  }
  cursor.Commit();
  return true;
}
} // namespace discovery
} // namespace openscreen