blob: 6eed677bfd4fd844f578ef619726b488817c7c7b [file] [log] [blame]
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "discovery/mdns/mdns_records.h"
#include <algorithm>
#include <cctype>
#include <limits>
#include <sstream>
#include <vector>
#include "absl/strings/ascii.h"
#include "absl/strings/match.h"
#include "absl/strings/str_join.h"
#include "discovery/mdns/mdns_writer.h"
namespace openscreen {
namespace discovery {
namespace {
// RDATA length is carried in a 16-bit field, so a raw payload can hold at most
// 65535 bytes.
constexpr size_t kMaxRawRecordSize{std::numeric_limits<uint16_t>::max()};
// Upper bound on the number of entries in one message section (questions,
// answers, authority records or additional records); 65535.
constexpr size_t kMaxMessageFieldEntryCount{
    std::numeric_limits<uint16_t>::max()};
// Compares |x| and |y| byte-by-byte, ignoring case as determined by
// std::tolower() (ASCII-only in the default "C" locale).
// Returns -1 if |x| orders first, 1 if |y| orders first, and 0 if the two are
// equal ignoring case. When one string is a prefix of the other, the shorter
// string orders first.
//
// Every byte is converted to unsigned char before calling std::tolower().
// Passing a negative char value is undefined behavior, and plain char is
// signed on many platforms. Comparing unsigned octets also makes the ordering
// of non-ASCII bytes identical on every platform.
inline int CompareIgnoreCase(const std::string& x, const std::string& y) {
  const size_t common_length = std::min(x.size(), y.size());
  for (size_t i = 0; i < common_length; ++i) {
    const unsigned char x_char = static_cast<unsigned char>(
        std::tolower(static_cast<unsigned char>(x[i])));
    const unsigned char y_char = static_cast<unsigned char>(
        std::tolower(static_cast<unsigned char>(y[i])));
    if (x_char != y_char) {
      return x_char < y_char ? -1 : 1;
    }
  }
  if (x.size() == y.size()) {
    return 0;
  }
  return x.size() < y.size() ? -1 : 1;
}
template <typename RDataType>
bool IsGreaterThan(const Rdata& lhs, const Rdata& rhs) {
const RDataType& lhs_cast = absl::get<RDataType>(lhs);
const RDataType& rhs_cast = absl::get<RDataType>(rhs);
// The Extra 2 in length is from the record size that Write() prepends to the
// result.
const size_t lhs_size = lhs_cast.MaxWireSize() + 2;
const size_t rhs_size = rhs_cast.MaxWireSize() + 2;
uint8_t lhs_bytes[lhs_size];
uint8_t rhs_bytes[rhs_size];
MdnsWriter lhs_writer(lhs_bytes, lhs_size);
MdnsWriter rhs_writer(rhs_bytes, rhs_size);
const bool lhs_write = lhs_writer.Write(lhs_cast);
const bool rhs_write = rhs_writer.Write(rhs_cast);
OSP_DCHECK(lhs_write);
OSP_DCHECK(rhs_write);
// Skip the size bits.
const size_t min_size = std::min(lhs_writer.offset(), rhs_writer.offset());
for (size_t i = 2; i < min_size; i++) {
if (lhs_bytes[i] != rhs_bytes[i]) {
return lhs_bytes[i] > rhs_bytes[i];
}
}
return lhs_size > rhs_size;
}
bool IsGreaterThan(DnsType type, const Rdata& lhs, const Rdata& rhs) {
switch (type) {
case DnsType::kA:
return IsGreaterThan<ARecordRdata>(lhs, rhs);
case DnsType::kPTR:
return IsGreaterThan<PtrRecordRdata>(lhs, rhs);
case DnsType::kTXT:
return IsGreaterThan<TxtRecordRdata>(lhs, rhs);
case DnsType::kAAAA:
return IsGreaterThan<AAAARecordRdata>(lhs, rhs);
case DnsType::kSRV:
return IsGreaterThan<SrvRecordRdata>(lhs, rhs);
case DnsType::kNSEC:
return IsGreaterThan<NsecRecordRdata>(lhs, rhs);
default:
return IsGreaterThan<RawRecordRdata>(lhs, rhs);
}
}
} // namespace
// A domain label is valid when it is non-empty and no longer than
// kMaxLabelLength bytes.
bool IsValidDomainLabel(absl::string_view label) {
  return !label.empty() && label.size() <= kMaxLabelLength;
}
DomainName::DomainName() = default;

// The label-list constructors below all delegate to the constructor that takes
// an iterator range (defined in the header, not shown here).
DomainName::DomainName(std::vector<std::string> labels)
    : DomainName(std::begin(labels), std::end(labels)) {}

DomainName::DomainName(const std::vector<absl::string_view>& labels)
    : DomainName(std::begin(labels), std::end(labels)) {}

DomainName::DomainName(std::initializer_list<absl::string_view> labels)
    : DomainName(std::begin(labels), std::end(labels)) {}

// Adopts |labels| as-is together with an already computed |max_wire_size|.
DomainName::DomainName(std::vector<std::string> labels, size_t max_wire_size)
    : max_wire_size_(max_wire_size), labels_(std::move(labels)) {}

DomainName::DomainName(const DomainName& other) = default;
DomainName::DomainName(DomainName&& other) = default;
DomainName& DomainName::operator=(const DomainName& rhs) = default;
DomainName& DomainName::operator=(DomainName&& rhs) = default;

// Returns the labels joined with '.', e.g. {"foo", "local"} -> "foo.local".
std::string DomainName::ToString() const {
  constexpr char kLabelSeparator[] = ".";
  return absl::StrJoin(labels_, kLabelSeparator);
}
// Orders names label-by-label, ignoring case. The first differing label
// decides. When one name's labels are a prefix of the other's, the shorter
// name orders first.
bool DomainName::operator<(const DomainName& rhs) const {
  const size_t shared_count = std::min(labels_.size(), rhs.labels_.size());
  for (size_t index = 0; index < shared_count; ++index) {
    const int order = CompareIgnoreCase(labels_[index], rhs.labels_[index]);
    if (order != 0) {
      return order < 0;
    }
  }
  return labels_.size() < rhs.labels_.size();
}

// operator< never holds in either direction exactly when operator== holds.
// So the remaining relations can be expressed through operator< alone.
bool DomainName::operator<=(const DomainName& rhs) const {
  return !(rhs < *this);
}

bool DomainName::operator>(const DomainName& rhs) const {
  return rhs < *this;
}

bool DomainName::operator>=(const DomainName& rhs) const {
  return !(*this < rhs);
}

// Two names are equal when they have the same number of labels and every pair
// of labels matches ignoring case.
bool DomainName::operator==(const DomainName& rhs) const {
  if (labels_.size() != rhs.labels_.size()) {
    return false;
  }
  return std::equal(labels_.begin(), labels_.end(), rhs.labels_.begin(),
                    [](const auto& lhs_label, const auto& rhs_label) {
                      return CompareIgnoreCase(lhs_label, rhs_label) == 0;
                    });
}

bool DomainName::operator!=(const DomainName& rhs) const {
  return !operator==(rhs);
}

// Returns the cached wire-size upper bound stored at construction.
size_t DomainName::MaxWireSize() const {
  return max_wire_size_;
}
// static
// Wraps |rdata| unless it is longer than kMaxRawRecordSize bytes, in which case
// kIndexOutOfBounds is returned.
ErrorOr<RawRecordRdata> RawRecordRdata::TryCreate(std::vector<uint8_t> rdata) {
  if (rdata.size() <= kMaxRawRecordSize) {
    return RawRecordRdata(std::move(rdata));
  }
  return Error::Code::kIndexOutOfBounds;
}
RawRecordRdata::RawRecordRdata() = default;

// Takes ownership of |rdata|. Oversized input is a programming error here;
// use TryCreate() for untrusted sizes.
RawRecordRdata::RawRecordRdata(std::vector<uint8_t> rdata)
    : rdata_(std::move(rdata)) {
  OSP_DCHECK(rdata_.size() <= kMaxRawRecordSize);
}

// Copies |size| bytes starting at |begin|.
RawRecordRdata::RawRecordRdata(const uint8_t* begin, size_t size)
    : RawRecordRdata(std::vector<uint8_t>(begin, begin + size)) {}

RawRecordRdata::RawRecordRdata(const RawRecordRdata& other) = default;
RawRecordRdata::RawRecordRdata(RawRecordRdata&& other) = default;
RawRecordRdata& RawRecordRdata::operator=(const RawRecordRdata& rhs) = default;
RawRecordRdata& RawRecordRdata::operator=(RawRecordRdata&& rhs) = default;

bool RawRecordRdata::operator==(const RawRecordRdata& rhs) const {
  return rdata_.size() == rhs.rdata_.size() &&
         std::equal(rdata_.begin(), rdata_.end(), rhs.rdata_.begin());
}

bool RawRecordRdata::operator!=(const RawRecordRdata& rhs) const {
  return !operator==(rhs);
}

// The payload bytes plus the uint16_t record length field.
size_t RawRecordRdata::MaxWireSize() const {
  return rdata_.size() + sizeof(uint16_t);
}
SrvRecordRdata::SrvRecordRdata() = default;

SrvRecordRdata::SrvRecordRdata(uint16_t priority,
                               uint16_t weight,
                               uint16_t port,
                               DomainName target)
    : priority_(priority),
      weight_(weight),
      port_(port),
      target_(std::move(target)) {}

SrvRecordRdata::SrvRecordRdata(const SrvRecordRdata& other) = default;
SrvRecordRdata::SrvRecordRdata(SrvRecordRdata&& other) = default;
SrvRecordRdata& SrvRecordRdata::operator=(const SrvRecordRdata& rhs) = default;
SrvRecordRdata& SrvRecordRdata::operator=(SrvRecordRdata&& rhs) = default;

// All four fields must match; the target compares case-insensitively via
// DomainName::operator==.
bool SrvRecordRdata::operator==(const SrvRecordRdata& rhs) const {
  if (priority_ != rhs.priority_ || weight_ != rhs.weight_ ||
      port_ != rhs.port_) {
    return false;
  }
  return target_ == rhs.target_;
}

bool SrvRecordRdata::operator!=(const SrvRecordRdata& rhs) const {
  return !operator==(rhs);
}

// Record length field + priority + weight + port + target name.
size_t SrvRecordRdata::MaxWireSize() const {
  const size_t fixed_fields_size =
      sizeof(priority_) + sizeof(weight_) + sizeof(port_);
  return sizeof(uint16_t) + fixed_fields_size + target_.MaxWireSize();
}
ARecordRdata::ARecordRdata() = default;

// |ipv4_address| must be an IPv4 address; this is enforced with OSP_CHECK.
ARecordRdata::ARecordRdata(IPAddress ipv4_address,
                           NetworkInterfaceIndex interface_index)
    : ipv4_address_(std::move(ipv4_address)),
      interface_index_(interface_index) {
  OSP_CHECK(ipv4_address_.IsV4());
}

ARecordRdata::ARecordRdata(const ARecordRdata& other) = default;
ARecordRdata::ARecordRdata(ARecordRdata&& other) = default;
ARecordRdata& ARecordRdata::operator=(const ARecordRdata& rhs) = default;
ARecordRdata& ARecordRdata::operator=(ARecordRdata&& rhs) = default;

// Equality also requires the same interface index, not just the address.
bool ARecordRdata::operator==(const ARecordRdata& rhs) const {
  if (interface_index_ != rhs.interface_index_) {
    return false;
  }
  return ipv4_address_ == rhs.ipv4_address_;
}

bool ARecordRdata::operator!=(const ARecordRdata& rhs) const {
  return !operator==(rhs);
}

// The IPv4 address bytes plus the uint16_t record length field.
size_t ARecordRdata::MaxWireSize() const {
  return IPAddress::kV4Size + sizeof(uint16_t);
}
AAAARecordRdata::AAAARecordRdata() = default;

// |ipv6_address| must be an IPv6 address; this is enforced with OSP_CHECK.
AAAARecordRdata::AAAARecordRdata(IPAddress ipv6_address,
                                 NetworkInterfaceIndex interface_index)
    : ipv6_address_(std::move(ipv6_address)),
      interface_index_(interface_index) {
  OSP_CHECK(ipv6_address_.IsV6());
}

AAAARecordRdata::AAAARecordRdata(const AAAARecordRdata& other) = default;
AAAARecordRdata::AAAARecordRdata(AAAARecordRdata&& other) = default;
AAAARecordRdata& AAAARecordRdata::operator=(const AAAARecordRdata& rhs) = default;
AAAARecordRdata& AAAARecordRdata::operator=(AAAARecordRdata&& rhs) = default;

// Equality also requires the same interface index, not just the address.
bool AAAARecordRdata::operator==(const AAAARecordRdata& rhs) const {
  if (interface_index_ != rhs.interface_index_) {
    return false;
  }
  return ipv6_address_ == rhs.ipv6_address_;
}

bool AAAARecordRdata::operator!=(const AAAARecordRdata& rhs) const {
  return !operator==(rhs);
}

// The IPv6 address bytes plus the uint16_t record length field.
size_t AAAARecordRdata::MaxWireSize() const {
  return IPAddress::kV6Size + sizeof(uint16_t);
}
PtrRecordRdata::PtrRecordRdata() = default;

// |ptr_domain| is taken by value and moved into place. The original code
// copied it, duplicating the label vector a second time.
PtrRecordRdata::PtrRecordRdata(DomainName ptr_domain)
    : ptr_domain_(std::move(ptr_domain)) {}

PtrRecordRdata::PtrRecordRdata(const PtrRecordRdata& other) = default;
PtrRecordRdata::PtrRecordRdata(PtrRecordRdata&& other) = default;
PtrRecordRdata& PtrRecordRdata::operator=(const PtrRecordRdata& rhs) = default;
PtrRecordRdata& PtrRecordRdata::operator=(PtrRecordRdata&& rhs) = default;

// Compares the pointed-to domain (case-insensitively, via DomainName).
bool PtrRecordRdata::operator==(const PtrRecordRdata& rhs) const {
  return ptr_domain_ == rhs.ptr_domain_;
}

bool PtrRecordRdata::operator!=(const PtrRecordRdata& rhs) const {
  return !(*this == rhs);
}

size_t PtrRecordRdata::MaxWireSize() const {
  // max_wire_size includes uint16_t record length field.
  return sizeof(uint16_t) + ptr_domain_.MaxWireSize();
}
// static
// Validates |texts| and computes the record's maximum wire size.
// Returns kParameterInvalid if any of the following holds:
//   - an entry is empty;
//   - an entry exceeds 255 bytes;
//   - the total RDATA would not fit in the 16-bit record length.
ErrorOr<TxtRecordRdata> TxtRecordRdata::TryCreate(std::vector<Entry> texts) {
  // Each entry is written with a single length byte in front of it, so it can
  // hold at most 255 bytes (RFC 1035 Section 3.3, <character-string>).
  constexpr size_t kMaxCharacterStringLength =
      std::numeric_limits<uint8_t>::max();

  std::vector<std::string> str_texts;
  // An empty TXT record is sized as 2 bytes of record length + 1 byte.
  size_t max_wire_size = 3;
  if (!texts.empty()) {
    str_texts.reserve(texts.size());
    // max_wire_size includes uint16_t record length field.
    max_wire_size = sizeof(uint16_t);
    for (const auto& text : texts) {
      if (text.empty() || text.size() > kMaxCharacterStringLength) {
        return Error::Code::kParameterInvalid;
      }
      str_texts.emplace_back(reinterpret_cast<const char*>(text.data()),
                             text.size());
      // Include the length byte in the size calculation.
      max_wire_size += text.size() + 1;
    }
    // The RDATA (everything after the record length field) must itself fit in
    // the 16-bit record length.
    if (max_wire_size - sizeof(uint16_t) > kMaxRawRecordSize) {
      return Error::Code::kParameterInvalid;
    }
  }
  return TxtRecordRdata(std::move(str_texts), max_wire_size);
}
TxtRecordRdata::TxtRecordRdata() = default;

// Builds the record through TryCreate(). Invalid input is caught only by the
// debug check below.
TxtRecordRdata::TxtRecordRdata(std::vector<Entry> texts) {
  ErrorOr<TxtRecordRdata> created = TryCreate(std::move(texts));
  OSP_DCHECK(created.is_value());
  *this = std::move(created.value());
}

// Adopts already-validated strings and their precomputed wire size.
TxtRecordRdata::TxtRecordRdata(std::vector<std::string> texts,
                               size_t max_wire_size)
    : max_wire_size_(max_wire_size), texts_(std::move(texts)) {}

TxtRecordRdata::TxtRecordRdata(const TxtRecordRdata& other) = default;
TxtRecordRdata::TxtRecordRdata(TxtRecordRdata&& other) = default;
TxtRecordRdata& TxtRecordRdata::operator=(const TxtRecordRdata& rhs) = default;
TxtRecordRdata& TxtRecordRdata::operator=(TxtRecordRdata&& rhs) = default;

// Only the text entries participate in equality; the wire size is derived.
bool TxtRecordRdata::operator==(const TxtRecordRdata& rhs) const {
  return texts_ == rhs.texts_;
}

bool TxtRecordRdata::operator!=(const TxtRecordRdata& rhs) const {
  return !operator==(rhs);
}

// Returns the size computed by TryCreate().
size_t TxtRecordRdata::MaxWireSize() const {
  return max_wire_size_;
}
NsecRecordRdata::NsecRecordRdata() = default;

// Stores |types| sorted and precomputes their Type Bit Maps encoding
// (RFC 4034 Section 4.1.2) into |encoded_types_|. The encoding is a sequence of
// triples: window number, bitmap length, bitmap bytes.
NsecRecordRdata::NsecRecordRdata(DomainName next_domain_name,
                                 std::vector<DnsType> types)
    : types_(std::move(types)), next_domain_name_(std::move(next_domain_name)) {
  // Sorting keeps windows in increasing order and makes operator== independent
  // of the caller's ordering.
  std::sort(types_.begin(), types_.end());

  std::vector<uint8_t> bitmap;
  uint8_t window = 0;

  // Appends the current window's header and bitmap, if any bit was set in it.
  auto flush_window = [this, &bitmap, &window]() {
    if (bitmap.empty()) {
      return;
    }
    encoded_types_.push_back(window);
    encoded_types_.push_back(static_cast<uint8_t>(bitmap.size()));
    encoded_types_.insert(encoded_types_.end(), bitmap.begin(), bitmap.end());
  };

  for (const DnsType type : types_) {
    const uint16_t value = static_cast<uint16_t>(type);
    // High byte selects the window; low byte is the bit position inside it.
    const uint8_t type_window = static_cast<uint8_t>(value >> 8);
    const uint8_t offset_in_window = static_cast<uint8_t>(value & 0xFF);
    const size_t byte_index = offset_in_window / 8;
    const uint8_t bit = static_cast<uint8_t>(0x80 >> (offset_in_window % 8));

    if (type_window > window) {
      flush_window();
      bitmap.clear();
      window = type_window;
    }

    // Grow the bitmap with zero bytes until |byte_index| is addressable.
    if (bitmap.size() <= byte_index) {
      bitmap.resize(byte_index + 1, 0x00);
    }
    bitmap[byte_index] |= bit;
  }

  flush_window();
}
NsecRecordRdata::NsecRecordRdata(const NsecRecordRdata& other) = default;
NsecRecordRdata::NsecRecordRdata(NsecRecordRdata&& other) = default;
NsecRecordRdata& NsecRecordRdata::operator=(const NsecRecordRdata& rhs) =
    default;
NsecRecordRdata& NsecRecordRdata::operator=(NsecRecordRdata&& rhs) = default;

// |encoded_types_| is derived from |types_|, so it is not compared separately.
bool NsecRecordRdata::operator==(const NsecRecordRdata& rhs) const {
  return types_ == rhs.types_ && next_domain_name_ == rhs.next_domain_name_;
}

bool NsecRecordRdata::operator!=(const NsecRecordRdata& rhs) const {
  return !(*this == rhs);
}

size_t NsecRecordRdata::MaxWireSize() const {
  // max_wire_size includes uint16_t record length field, like every other
  // Rdata type in this file. Write() prepends that field to the NSEC RDATA as
  // well. Omitting it made MdnsRecord::MaxWireSize() (and therefore
  // MdnsMessage::CanAddRecord()) underestimate NSEC records by two bytes.
  return sizeof(uint16_t) + next_domain_name_.MaxWireSize() +
         encoded_types_.size();
}
// OPTION-CODE and OPTION-LENGTH (one uint16_t each, RFC 6891 section 6.1.2)
// followed by the option data.
size_t OptRecordRdata::Option::MaxWireSize() const {
  return 2 * sizeof(uint16_t) + data.size();
}

// Orders options by code, then declared length, then data size, then the
// first differing data byte. Identical options are not greater than each
// other.
bool OptRecordRdata::Option::operator>(
    const OptRecordRdata::Option& rhs) const {
  if (code != rhs.code) {
    return code > rhs.code;
  }
  if (length != rhs.length) {
    return length > rhs.length;
  }
  if (data.size() != rhs.data.size()) {
    return data.size() > rhs.data.size();
  }
  // Sizes are equal here, so the three-iterator std::mismatch is safe.
  const auto first_difference =
      std::mismatch(data.begin(), data.end(), rhs.data.begin());
  if (first_difference.first == data.end()) {
    return false;
  }
  return *first_difference.first > *first_difference.second;
}

bool OptRecordRdata::Option::operator<(
    const OptRecordRdata::Option& rhs) const {
  return rhs.operator>(*this);
}

bool OptRecordRdata::Option::operator>=(
    const OptRecordRdata::Option& rhs) const {
  return !(rhs > *this);
}

bool OptRecordRdata::Option::operator<=(
    const OptRecordRdata::Option& rhs) const {
  return !(*this > rhs);
}

// Equal when neither option orders after the other.
bool OptRecordRdata::Option::operator==(
    const OptRecordRdata::Option& rhs) const {
  return !(*this > rhs) && !(rhs > *this);
}

bool OptRecordRdata::Option::operator!=(
    const OptRecordRdata::Option& rhs) const {
  return !operator==(rhs);
}
OptRecordRdata::OptRecordRdata() = default;

// Stores |options| in sorted order, so operator== compares them element-wise.
// Each option's MaxWireSize() is added to max_wire_size_.
OptRecordRdata::OptRecordRdata(std::vector<Option> options)
    : options_(std::move(options)) {
  std::sort(options_.begin(), options_.end());
  for (const Option& option : options_) {
    max_wire_size_ += option.MaxWireSize();
  }
}

OptRecordRdata::OptRecordRdata(const OptRecordRdata& other) = default;
OptRecordRdata::OptRecordRdata(OptRecordRdata&& other) = default;
OptRecordRdata& OptRecordRdata::operator=(const OptRecordRdata& rhs) = default;
OptRecordRdata& OptRecordRdata::operator=(OptRecordRdata&& rhs) = default;

bool OptRecordRdata::operator==(const OptRecordRdata& rhs) const {
  return options_ == rhs.options_;
}

bool OptRecordRdata::operator!=(const OptRecordRdata& rhs) const {
  return !operator==(rhs);
}
// static
// Builds a record when IsValidConfig() accepts the arguments; otherwise
// returns kParameterInvalid.
ErrorOr<MdnsRecord> MdnsRecord::TryCreate(DomainName name,
                                          DnsType dns_type,
                                          DnsClass dns_class,
                                          RecordType record_type,
                                          std::chrono::seconds ttl,
                                          Rdata rdata) {
  if (IsValidConfig(name, dns_type, ttl, rdata)) {
    return MdnsRecord(std::move(name), dns_type, dns_class, record_type, ttl,
                      std::move(rdata));
  }
  return Error::Code::kParameterInvalid;
}

MdnsRecord::MdnsRecord() = default;

// Callers are expected to pass a valid configuration; this is only checked in
// debug builds. Use TryCreate() for unvalidated input.
MdnsRecord::MdnsRecord(DomainName name,
                       DnsType dns_type,
                       DnsClass dns_class,
                       RecordType record_type,
                       std::chrono::seconds ttl,
                       Rdata rdata)
    : name_(std::move(name)),
      dns_type_(dns_type),
      dns_class_(dns_class),
      record_type_(record_type),
      ttl_(ttl),
      rdata_(std::move(rdata)) {
  OSP_DCHECK(IsValidConfig(name_, dns_type, ttl_, rdata_));
}

MdnsRecord::MdnsRecord(const MdnsRecord& other) = default;
MdnsRecord::MdnsRecord(MdnsRecord&& other) = default;
MdnsRecord& MdnsRecord::operator=(const MdnsRecord& rhs) = default;
MdnsRecord& MdnsRecord::operator=(MdnsRecord&& rhs) = default;
// static
// Returns true when |ttl| is in range and |rdata| holds an alternative that
// matches |dns_type|. RawRecordRdata is accepted for any |dns_type|.
bool MdnsRecord::IsValidConfig(const DomainName& name,
                               DnsType dns_type,
                               std::chrono::seconds ttl,
                               const Rdata& rdata) {
  // NOTE: Although the name_ field was initially expected to be non-empty, this
  // validation is no longer accurate for some record types (such as OPT
  // records). To ensure that future record types correctly parse into
  // RawRecordData types and do not invalidate the received message, this check
  // has been removed.

  // The TTL is encoded as a 32-bit unsigned field (RFC 1035 Section 3.2.1).
  // A negative value would silently wrap when narrowed, so it is rejected
  // together with values above UINT32_MAX.
  if (ttl.count() < 0 || ttl.count() > std::numeric_limits<uint32_t>::max()) {
    return false;
  }

  return (dns_type == DnsType::kSRV &&
          absl::holds_alternative<SrvRecordRdata>(rdata)) ||
         (dns_type == DnsType::kA &&
          absl::holds_alternative<ARecordRdata>(rdata)) ||
         (dns_type == DnsType::kAAAA &&
          absl::holds_alternative<AAAARecordRdata>(rdata)) ||
         (dns_type == DnsType::kPTR &&
          absl::holds_alternative<PtrRecordRdata>(rdata)) ||
         (dns_type == DnsType::kTXT &&
          absl::holds_alternative<TxtRecordRdata>(rdata)) ||
         (dns_type == DnsType::kNSEC &&
          absl::holds_alternative<NsecRecordRdata>(rdata)) ||
         (dns_type == DnsType::kOPT &&
          absl::holds_alternative<OptRecordRdata>(rdata)) ||
         absl::holds_alternative<RawRecordRdata>(rdata);
}
bool MdnsRecord::operator==(const MdnsRecord& rhs) const {
return dns_type_ == rhs.dns_type_ && dns_class_ == rhs.dns_class_ &&
record_type_ == rhs.record_type_ && ttl_ == rhs.ttl_ &&
name_ == rhs.name_ && rdata_ == rhs.rdata_;
}
bool MdnsRecord::operator!=(const MdnsRecord& rhs) const {
return !(*this == rhs);
}
bool MdnsRecord::operator>(const MdnsRecord& rhs) const {
// Returns the record which is lexicographically later. The determination of
// "lexicographically later" is performed by first comparing the record class,
// then the record type, then raw comparison of the binary content of the
// rdata without regard for meaning or structure.
// NOTE: Per RFC, the TTL is not included in this comparison.
if (name() != rhs.name()) {
return name() > rhs.name();
}
if (record_type() != rhs.record_type()) {
return record_type() == RecordType::kUnique;
}
if (dns_class() != rhs.dns_class()) {
return dns_class() > rhs.dns_class();
}
uint16_t this_type = static_cast<uint16_t>(dns_type()) & kClassMask;
uint16_t other_type = static_cast<uint16_t>(rhs.dns_type()) & kClassMask;
if (this_type != other_type) {
return this_type > other_type;
}
return IsGreaterThan(dns_type(), rdata(), rhs.rdata());
}
bool MdnsRecord::operator<(const MdnsRecord& rhs) const {
return rhs > *this;
}
bool MdnsRecord::operator<=(const MdnsRecord& rhs) const {
return !(*this > rhs);
}
bool MdnsRecord::operator>=(const MdnsRecord& rhs) const {
return !(*this < rhs);
}
size_t MdnsRecord::MaxWireSize() const {
auto wire_size_visitor = [](auto&& arg) { return arg.MaxWireSize(); };
// NAME size, 2-byte TYPE, 2-byte CLASS, 4-byte TTL, RDATA size
return name_.MaxWireSize() + absl::visit(wire_size_visitor, rdata_) + 8;
}
std::string MdnsRecord::ToString() const {
std::stringstream ss;
ss << "name: '" << name_.ToString() << "'";
ss << ", type: " << dns_type_;
if (dns_type_ == DnsType::kPTR) {
const DomainName& target = absl::get<PtrRecordRdata>(rdata_).ptr_domain();
ss << ", target: '" << target.ToString() << "'";
} else if (dns_type_ == DnsType::kSRV) {
const DomainName& target = absl::get<SrvRecordRdata>(rdata_).target();
ss << ", target: '" << target.ToString() << "'";
} else if (dns_type_ == DnsType::kNSEC) {
const auto& nsec_rdata = absl::get<NsecRecordRdata>(rdata_);
std::vector<DnsType> types = nsec_rdata.types();
ss << ", representing [";
if (!types.empty()) {
auto it = types.begin();
ss << *it++;
while (it != types.end()) {
ss << ", " << *it++;
}
ss << "]";
}
}
return ss.str();
}
// Creates a unique IN-class address record for |name|.
// IPv4 addresses produce an A record with kARecordTtl; anything else produces
// an AAAA record with kAAAARecordTtl.
MdnsRecord CreateAddressRecord(DomainName name, const IPAddress& address) {
  const bool is_v4 = address.IsV4();
  const DnsType type = is_v4 ? DnsType::kA : DnsType::kAAAA;
  const std::chrono::seconds ttl = is_v4 ? std::chrono::seconds(kARecordTtl)
                                         : std::chrono::seconds(kAAAARecordTtl);
  Rdata rdata = is_v4 ? Rdata(ARecordRdata(address))
                      : Rdata(AAAARecordRdata(address));
  return MdnsRecord(std::move(name), type, DnsClass::kIN, RecordType::kUnique,
                    ttl, std::move(rdata));
}
// static
// Returns kParameterInvalid for an empty |name|; otherwise builds the question.
ErrorOr<MdnsQuestion> MdnsQuestion::TryCreate(DomainName name,
                                              DnsType dns_type,
                                              DnsClass dns_class,
                                              ResponseType response_type) {
  if (!name.empty()) {
    return MdnsQuestion(std::move(name), dns_type, dns_class, response_type);
  }
  return Error::Code::kParameterInvalid;
}

// |name| must be non-empty; this is enforced with OSP_CHECK.
MdnsQuestion::MdnsQuestion(DomainName name,
                           DnsType dns_type,
                           DnsClass dns_class,
                           ResponseType response_type)
    : name_(std::move(name)),
      dns_type_(dns_type),
      dns_class_(dns_class),
      response_type_(response_type) {
  OSP_CHECK(!name_.empty());
}

bool MdnsQuestion::operator==(const MdnsQuestion& rhs) const {
  return name_ == rhs.name_ && dns_type_ == rhs.dns_type_ &&
         dns_class_ == rhs.dns_class_ && response_type_ == rhs.response_type_;
}

bool MdnsQuestion::operator!=(const MdnsQuestion& rhs) const {
  return !operator==(rhs);
}

// The name plus the 2-byte TYPE and 2-byte CLASS fields.
size_t MdnsQuestion::MaxWireSize() const {
  constexpr size_t kTypeAndClassSize = 2 * sizeof(uint16_t);
  return name_.MaxWireSize() + kTypeAndClassSize;
}
// static
// Builds a message, or returns kParameterInvalid if any section has more than
// kMaxMessageFieldEntryCount entries.
//
// Each section's entry count is stored as a 16-bit value, so a section may
// hold up to and including kMaxMessageFieldEntryCount (65535) entries. The
// previous ">=" check rejected exactly 65535 entries, even though the Add*()
// methods allow a section to grow to that size.
ErrorOr<MdnsMessage> MdnsMessage::TryCreate(
    uint16_t id,
    MessageType type,
    std::vector<MdnsQuestion> questions,
    std::vector<MdnsRecord> answers,
    std::vector<MdnsRecord> authority_records,
    std::vector<MdnsRecord> additional_records) {
  if (questions.size() > kMaxMessageFieldEntryCount ||
      answers.size() > kMaxMessageFieldEntryCount ||
      authority_records.size() > kMaxMessageFieldEntryCount ||
      additional_records.size() > kMaxMessageFieldEntryCount) {
    return Error::Code::kParameterInvalid;
  }
  return MdnsMessage(id, type, std::move(questions), std::move(answers),
                     std::move(authority_records),
                     std::move(additional_records));
}

MdnsMessage::MdnsMessage(uint16_t id, MessageType type)
    : id_(id), type_(type) {}

// Takes ownership of all sections and adds each entry's MaxWireSize() to
// max_wire_size_. Section sizes are only checked in debug builds; use
// TryCreate() for unvalidated input.
MdnsMessage::MdnsMessage(uint16_t id,
                         MessageType type,
                         std::vector<MdnsQuestion> questions,
                         std::vector<MdnsRecord> answers,
                         std::vector<MdnsRecord> authority_records,
                         std::vector<MdnsRecord> additional_records)
    : id_(id),
      type_(type),
      questions_(std::move(questions)),
      answers_(std::move(answers)),
      authority_records_(std::move(authority_records)),
      additional_records_(std::move(additional_records)) {
  OSP_DCHECK(questions_.size() <= kMaxMessageFieldEntryCount);
  OSP_DCHECK(answers_.size() <= kMaxMessageFieldEntryCount);
  OSP_DCHECK(authority_records_.size() <= kMaxMessageFieldEntryCount);
  OSP_DCHECK(additional_records_.size() <= kMaxMessageFieldEntryCount);

  for (const MdnsQuestion& question : questions_) {
    max_wire_size_ += question.MaxWireSize();
  }
  for (const MdnsRecord& record : answers_) {
    max_wire_size_ += record.MaxWireSize();
  }
  for (const MdnsRecord& record : authority_records_) {
    max_wire_size_ += record.MaxWireSize();
  }
  for (const MdnsRecord& record : additional_records_) {
    max_wire_size_ += record.MaxWireSize();
  }
}
// Every section and header field must match.
bool MdnsMessage::operator==(const MdnsMessage& rhs) const {
  if (id_ != rhs.id_ || type_ != rhs.type_) {
    return false;
  }
  return questions_ == rhs.questions_ && answers_ == rhs.answers_ &&
         authority_records_ == rhs.authority_records_ &&
         additional_records_ == rhs.additional_records_;
}

bool MdnsMessage::operator!=(const MdnsMessage& rhs) const {
  return !operator==(rhs);
}

// A message is a probe query when some authority record answers one of its
// questions. "Answers" means:
//   - the names are equal;
//   - the type matches, or the question asks for kANY;
//   - the class matches, or the question asks for kANY.
bool MdnsMessage::IsProbeQuery() const {
  if (questions().empty() || authority_records().empty()) {
    return false;
  }

  for (const MdnsQuestion& question : questions_) {
    const bool answered = std::any_of(
        authority_records_.begin(), authority_records_.end(),
        [&question](const MdnsRecord& record) {
          const bool type_matches = question.dns_type() == DnsType::kANY ||
                                    question.dns_type() == record.dns_type();
          const bool class_matches =
              question.dns_class() == DnsClass::kANY ||
              question.dns_class() == record.dns_class();
          return question.name() == record.name() && type_matches &&
                 class_matches;
        });
    if (answered) {
      return true;
    }
  }
  return false;
}

// Returns the running size estimate maintained by the constructor and the
// Add*() methods.
size_t MdnsMessage::MaxWireSize() const {
  return max_wire_size_;
}
// The Add*() methods append one entry to a section and grow max_wire_size_
// by that entry's MaxWireSize(). Overflowing a section's 16-bit count is only
// checked in debug builds.
void MdnsMessage::AddQuestion(MdnsQuestion question) {
  OSP_DCHECK(questions_.size() < kMaxMessageFieldEntryCount);
  max_wire_size_ += question.MaxWireSize();
  questions_.push_back(std::move(question));
}

void MdnsMessage::AddAnswer(MdnsRecord record) {
  OSP_DCHECK(answers_.size() < kMaxMessageFieldEntryCount);
  max_wire_size_ += record.MaxWireSize();
  answers_.push_back(std::move(record));
}

void MdnsMessage::AddAuthorityRecord(MdnsRecord record) {
  OSP_DCHECK(authority_records_.size() < kMaxMessageFieldEntryCount);
  max_wire_size_ += record.MaxWireSize();
  authority_records_.push_back(std::move(record));
}

void MdnsMessage::AddAdditionalRecord(MdnsRecord record) {
  OSP_DCHECK(additional_records_.size() < kMaxMessageFieldEntryCount);
  max_wire_size_ += record.MaxWireSize();
  additional_records_.push_back(std::move(record));
}

// Returns true if adding |record| keeps the estimated message size strictly
// below kMaxMulticastMessageSize.
bool MdnsMessage::CanAddRecord(const MdnsRecord& record) {
  const size_t size_after_add = max_wire_size_ + record.MaxWireSize();
  return size_after_add < kMaxMulticastMessageSize;
}
// Returns sequential message ids starting at 0 and wrapping after 65535.
// The counter is a plain function-local static with no synchronization.
uint16_t CreateMessageId() {
  static uint16_t next_id = 0;
  const uint16_t current_id = next_id;
  next_id = static_cast<uint16_t>(next_id + 1);
  return current_id;
}
bool CanBePublished(DnsType type) {
// NOTE: A 'default' switch statement has intentionally been avoided below to
// enforce that new DnsTypes added must be added below through a compile-time
// check.
switch (type) {
case DnsType::kA:
case DnsType::kAAAA:
case DnsType::kPTR:
case DnsType::kTXT:
case DnsType::kSRV:
return true;
case DnsType::kOPT:
case DnsType::kNSEC:
case DnsType::kANY:
break;
}
return false;
}
bool CanBePublished(const MdnsRecord& record) {
return CanBePublished(record.dns_type());
}
// Returns whether |type| may be used in an outgoing query.
bool CanBeQueried(DnsType type) {
  // NOTE: There is deliberately no 'default' label, so the compiler flags
  // any newly added DnsType that is not handled here.
  switch (type) {
    case DnsType::kOPT:
    case DnsType::kNSEC:
      return false;
    case DnsType::kA:
    case DnsType::kAAAA:
    case DnsType::kPTR:
    case DnsType::kTXT:
    case DnsType::kSRV:
    case DnsType::kANY:
      return true;
  }
  return false;
}

// Returns whether received records of |type| are processed.
bool CanBeProcessed(DnsType type) {
  // NOTE: There is deliberately no 'default' label, so the compiler flags
  // any newly added DnsType that is not handled here.
  switch (type) {
    case DnsType::kOPT:
    case DnsType::kANY:
      return false;
    case DnsType::kA:
    case DnsType::kAAAA:
    case DnsType::kPTR:
    case DnsType::kTXT:
    case DnsType::kSRV:
    case DnsType::kNSEC:
      return true;
  }
  return false;
}
} // namespace discovery
} // namespace openscreen