blob: 36bc3ed0d509368a3c2daa926b09c39d08c01067 [file] [log] [blame]
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "cast/common/mdns/mdns_reader.h"
#include "cast/common/mdns/mdns_constants.h"
#include "platform/api/logging.h"
namespace cast {
namespace mdns {
// Reads a length-prefixed character string: one length byte followed by that
// many bytes. On success |out| is a view that aliases the reader's underlying
// buffer (no copy is made), so it is only valid while that buffer is alive.
// Returns false if either the length byte or the string bytes are missing;
// the cursor is committed only on success.
bool MdnsReader::Read(absl::string_view* out) {
  OSP_DCHECK(out);
  Cursor cursor(this);
  uint8_t string_length;
  if (!Read(&string_length)) {
    return false;
  }
  // Capture the start of the string before Skip() advances the position.
  const char* string_begin = reinterpret_cast<const char*>(current());
  if (!Skip(string_length)) {
    return false;
  }
  *out = absl::string_view(string_begin, string_length);
  cursor.Commit();
  return true;
}
// RFC 1035: https://www.ietf.org/rfc/rfc1035.txt
// See section 4.1.4. Message compression
// Reads a (possibly compressed) domain name starting at the current position.
// Direct labels are collected until a termination byte is found; pointer
// labels redirect |position| to an offset from begin(). On success the reader
// is advanced by the on-wire size of the name (up to and including the first
// pointer or the terminator) and true is returned. Returns false on truncated
// input, out-of-buffer pointers, invalid labels, names longer than
// kMaxDomainNameLength, unknown label types, or compression loops.
bool MdnsReader::Read(DomainName* out) {
  OSP_DCHECK(out);
  const uint8_t* position = current();
  // The number of bytes consumed reading from the starting position to either
  // the first label pointer or the final termination byte, including the
  // pointer or the termination byte. This is equal to the actual wire size of
  // the DomainName accounting for compression.
  size_t bytes_consumed = 0;
  // The number of bytes that were processed when reading the DomainName,
  // including all label pointers and direct labels. It is used to detect
  // circular compression: the number of processed bytes cannot possibly be
  // greater than the length of the buffer.
  size_t bytes_processed = 0;
  // Running encoded length of the name (each label plus its length byte),
  // checked against kMaxDomainNameLength.
  size_t domain_name_length = 0;
  std::vector<absl::string_view> labels;
  // If we are pointing before the beginning or past the end of the buffer, we
  // hit a malformed pointer. If we have processed more bytes than there are in
  // the buffer, we are in a circular compression loop.
  while (position >= begin() && position < end() &&
         bytes_processed <= length()) {
    // Bounds are checked against the remaining byte count instead of by
    // forming pointers past end(): offsets and lengths come from untrusted
    // input, and out-of-range pointer arithmetic is undefined behavior. The
    // loop condition guarantees at least one byte remains here.
    const size_t bytes_remaining = static_cast<size_t>(end() - position);
    const uint8_t label_type = openscreen::ReadBigEndian<uint8_t>(position);
    if (IsTerminationLabel(label_type)) {
      *out = DomainName(std::move(labels));
      if (!bytes_consumed) {
        bytes_consumed = position + sizeof(uint8_t) - current();
      }
      return Skip(bytes_consumed);
    } else if (IsPointerLabel(label_type)) {
      if (bytes_remaining < sizeof(uint16_t)) {
        return false;
      }
      const uint16_t label_offset =
          GetPointerLabelOffset(openscreen::ReadBigEndian<uint16_t>(position));
      // A pointer must target a byte inside the buffer; reject it before
      // forming the pointer rather than after.
      if (label_offset >= static_cast<size_t>(end() - begin())) {
        return false;
      }
      if (!bytes_consumed) {
        bytes_consumed = position + sizeof(uint16_t) - current();
      }
      bytes_processed += sizeof(uint16_t);
      position = begin() + label_offset;
    } else if (IsDirectLabel(label_type)) {
      const uint8_t label_length = GetDirectLabelLength(label_type);
      OSP_DCHECK_GT(label_length, 0);
      // The length byte and the label itself must be followed by at least one
      // more byte (another label, a pointer, or the terminator).
      if (bytes_remaining <= sizeof(uint8_t) + label_length) {
        return false;
      }
      bytes_processed += sizeof(uint8_t);
      position += sizeof(uint8_t);
      const absl::string_view label(reinterpret_cast<const char*>(position),
                                    label_length);
      domain_name_length += label_length + 1;  // including the length byte
      if (!IsValidDomainLabel(label) ||
          domain_name_length > kMaxDomainNameLength) {
        return false;
      }
      labels.push_back(label);
      bytes_processed += label_length;
      position += label_length;
    } else {
      return false;
    }
  }
  return false;
}
// Reads opaque rdata: a 16-bit length followed by that many bytes, which are
// copied verbatim into a RawRecordRdata. The cursor is committed only when
// both the length and the full payload were read.
bool MdnsReader::Read(RawRecordRdata* out) {
  OSP_DCHECK(out);
  Cursor cursor(this);
  uint16_t record_length;
  if (!Read(&record_length)) {
    return false;
  }
  std::vector<uint8_t> payload(record_length);
  if (!Read(payload.size(), payload.data())) {
    return false;
  }
  *out = RawRecordRdata(std::move(payload));
  cursor.Commit();
  return true;
}
// Reads SRV rdata: length, priority, weight, port and target domain name.
// Fails unless every field is read and the bytes consumed (including the
// 16-bit length prefix) exactly match the declared record length.
bool MdnsReader::Read(SrvRecordRdata* out) {
  OSP_DCHECK(out);
  Cursor cursor(this);
  uint16_t record_length;
  uint16_t priority;
  uint16_t weight;
  uint16_t port;
  DomainName target;
  const bool all_fields_read = Read(&record_length) && Read(&priority) &&
                               Read(&weight) && Read(&port) && Read(&target);
  if (!all_fields_read) {
    return false;
  }
  // The declared rdata length must agree with what was actually parsed.
  if (cursor.delta() != sizeof(record_length) + record_length) {
    return false;
  }
  *out = SrvRecordRdata(priority, weight, port, std::move(target));
  cursor.Commit();
  return true;
}
// Reads A rdata: a 16-bit length that must equal IPAddress::kV4Size, followed
// by the IPv4 address bytes. The cursor is committed only on success.
bool MdnsReader::Read(ARecordRdata* out) {
  OSP_DCHECK(out);
  Cursor cursor(this);
  uint16_t record_length;
  if (!Read(&record_length) || record_length != IPAddress::kV4Size) {
    return false;
  }
  IPAddress address;
  if (!Read(IPAddress::Version::kV4, &address)) {
    return false;
  }
  *out = ARecordRdata(address);
  cursor.Commit();
  return true;
}
// Reads AAAA rdata: a 16-bit length that must equal IPAddress::kV6Size,
// followed by the IPv6 address bytes. The cursor is committed only on success.
bool MdnsReader::Read(AAAARecordRdata* out) {
  OSP_DCHECK(out);
  Cursor cursor(this);
  uint16_t record_length;
  if (!Read(&record_length) || record_length != IPAddress::kV6Size) {
    return false;
  }
  IPAddress address;
  if (!Read(IPAddress::Version::kV6, &address)) {
    return false;
  }
  *out = AAAARecordRdata(address);
  cursor.Commit();
  return true;
}
// Reads PTR rdata: a 16-bit length followed by a domain name. The bytes
// consumed (including the length prefix) must exactly match the declared
// record length.
bool MdnsReader::Read(PtrRecordRdata* out) {
  OSP_DCHECK(out);
  Cursor cursor(this);
  uint16_t record_length;
  DomainName ptr_domain;
  if (!Read(&record_length) || !Read(&ptr_domain)) {
    return false;
  }
  if (cursor.delta() != sizeof(record_length) + record_length) {
    return false;
  }
  *out = PtrRecordRdata(std::move(ptr_domain));
  cursor.Commit();
  return true;
}
// Reads TXT rdata: a 16-bit length followed by a sequence of length-prefixed
// character strings filling exactly that many bytes. Empty strings are
// skipped. Fails if any string cannot be read or if the strings overrun the
// declared record length.
bool MdnsReader::Read(TxtRecordRdata* out) {
  OSP_DCHECK(out);
  Cursor cursor(this);
  uint16_t record_length;
  if (!Read(&record_length)) {
    return false;
  }
  // Total on-wire size of this rdata, counting the 16-bit length prefix.
  const size_t expected_size = sizeof(record_length) + record_length;
  std::vector<absl::string_view> texts;
  while (cursor.delta() < expected_size) {
    absl::string_view entry;
    if (!Read(&entry)) {
      return false;
    }
    OSP_DCHECK(entry.length() <= kTXTMaxEntrySize);
    if (!entry.empty()) {
      texts.push_back(entry);
    }
  }
  // The last entry may have run past the declared length.
  if (cursor.delta() != expected_size) {
    return false;
  }
  // |texts| is not used again; move it rather than copy.
  *out = TxtRecordRdata(std::move(texts));
  cursor.Commit();
  return true;
}
// Reads a resource record: owner name, type, class, TTL, then type-specific
// rdata. The class field also carries the cache-flush bit. The cursor is
// committed only after the whole record has been read.
bool MdnsReader::Read(MdnsRecord* out) {
  OSP_DCHECK(out);
  Cursor cursor(this);
  DomainName name;
  uint16_t type;
  uint16_t rrclass;
  uint32_t ttl;
  if (!Read(&name) || !Read(&type) || !Read(&rrclass) || !Read(&ttl)) {
    return false;
  }
  const DnsType record_type = static_cast<DnsType>(type);
  Rdata rdata;
  if (!Read(record_type, &rdata)) {
    return false;
  }
  *out = MdnsRecord(std::move(name), record_type, GetDnsClass(rrclass),
                    GetCacheFlush(rrclass), ttl, std::move(rdata));
  cursor.Commit();
  return true;
}
// Reads a question entry: name, type, and class. The class field also carries
// the unicast-response bit. The cursor is committed only on success.
bool MdnsReader::Read(MdnsQuestion* out) {
  OSP_DCHECK(out);
  Cursor cursor(this);
  DomainName name;
  uint16_t type;
  uint16_t rrclass;
  if (!(Read(&name) && Read(&type) && Read(&rrclass))) {
    return false;
  }
  *out = MdnsQuestion(std::move(name), static_cast<DnsType>(type),
                      GetDnsClass(rrclass), GetUnicastResponse(rrclass));
  cursor.Commit();
  return true;
}
// Reads a full mDNS message: the header followed by as many questions,
// answers, authority records and additional records as the header's counts
// specify. The cursor is committed only if every section is read.
bool MdnsReader::Read(MdnsMessage* out) {
  OSP_DCHECK(out);
  Cursor cursor(this);
  Header header;
  std::vector<MdnsQuestion> questions;
  std::vector<MdnsRecord> answers;
  std::vector<MdnsRecord> authority_records;
  std::vector<MdnsRecord> additional_records;
  if (Read(&header) && Read(header.question_count, &questions) &&
      Read(header.answer_count, &answers) &&
      Read(header.authority_record_count, &authority_records) &&
      Read(header.additional_record_count, &additional_records)) {
    // TODO(yakimakha): Skip messages with non-zero opcode and rcode.
    // One way to do this is to change the method signature to return
    // ErrorOr<MdnsMessage> and return different error codes for failure to read
    // and for messages that were read successfully but are non-conforming.
    // The section vectors are dead after this point; move them instead of
    // deep-copying every question and record.
    *out = MdnsMessage(header.id, GetMessageType(header.flags),
                       std::move(questions), std::move(answers),
                       std::move(authority_records),
                       std::move(additional_records));
    cursor.Commit();
    return true;
  }
  return false;
}
// Reads a raw IP address of the given version. kV6 reads kV6Size bytes; any
// other version reads kV4Size bytes. Returns false, leaving |out| untouched,
// if not enough bytes remain.
bool MdnsReader::Read(IPAddress::Version version, IPAddress* out) {
  OSP_DCHECK(out);
  size_t address_size = IPAddress::kV4Size;
  if (version == IPAddress::Version::kV6) {
    address_size = IPAddress::kV6Size;
  }
  // Remember where the address starts before Skip() advances past it.
  const uint8_t* const address_start = current();
  if (!Skip(address_size)) {
    return false;
  }
  *out = IPAddress(version, address_start);
  return true;
}
// Dispatches to the rdata reader matching |type|. Types without a dedicated
// parser (anything other than SRV, A, AAAA, PTR, TXT) are read as raw bytes.
bool MdnsReader::Read(DnsType type, Rdata* out) {
  OSP_DCHECK(out);
  if (type == DnsType::kSRV) {
    return Read<SrvRecordRdata>(out);
  }
  if (type == DnsType::kA) {
    return Read<ARecordRdata>(out);
  }
  if (type == DnsType::kAAAA) {
    return Read<AAAARecordRdata>(out);
  }
  if (type == DnsType::kPTR) {
    return Read<PtrRecordRdata>(out);
  }
  if (type == DnsType::kTXT) {
    return Read<TxtRecordRdata>(out);
  }
  return Read<RawRecordRdata>(out);
}
// Reads the fixed message header fields in wire order: id, flags, then the
// question, answer, authority and additional record counts. Fields are written
// into |out| as they are read; the cursor is committed only if all succeed.
bool MdnsReader::Read(Header* out) {
  OSP_DCHECK(out);
  Cursor cursor(this);
  const bool header_read =
      Read(&out->id) && Read(&out->flags) && Read(&out->question_count) &&
      Read(&out->answer_count) && Read(&out->authority_record_count) &&
      Read(&out->additional_record_count);
  if (!header_read) {
    return false;
  }
  cursor.Commit();
  return true;
}
} // namespace mdns
} // namespace cast