blob: 9c510fbbf534f72887dc66bea0d2dc5d252f98e0 [file] [log] [blame]
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "discovery/mdns/mdns_trackers.h"
#include <array>
#include <limits>
#include <utility>
#include "discovery/common/config.h"
#include "discovery/mdns/mdns_random.h"
#include "discovery/mdns/mdns_record_changed_callback.h"
#include "discovery/mdns/mdns_sender.h"
#include "util/std_util.h"
namespace openscreen {
namespace discovery {
namespace {
// RFC 6762 Section 5.2
// https://tools.ietf.org/html/rfc6762#section-5.2
// Attempt to refresh a record should be performed at 80%, 85%, 90% and 95% TTL.
constexpr double kTtlFractions[] = {0.80, 0.85, 0.90, 0.95, 1.00};
// Intervals between successive queries must increase by at least a factor of 2.
constexpr int kIntervalIncreaseFactor = 2;
// The interval between the first two queries must be at least one second.
constexpr std::chrono::seconds kMinimumQueryInterval{1};
// The querier may cap the question refresh interval to a maximum of 60 minutes.
constexpr std::chrono::minutes kMaximumQueryInterval{60};
// RFC 6762 Section 10.1
// https://tools.ietf.org/html/rfc6762#section-10.1
// A goodbye record is a record with TTL of 0.
bool IsGoodbyeRecord(const MdnsRecord& record) {
return record.ttl() == std::chrono::seconds(0);
}
bool IsNegativeResponseForType(const MdnsRecord& record, DnsType dns_type) {
if (record.dns_type() != DnsType::kNSEC) {
return false;
}
const auto& nsec_types = absl::get<NsecRecordRdata>(record.rdata()).types();
return std::find_if(nsec_types.begin(), nsec_types.end(),
[dns_type](DnsType type) {
return type == dns_type || type == DnsType::kANY;
}) != nsec_types.end();
}
// RFC 6762 Section 10.1
// https://tools.ietf.org/html/rfc6762#section-10.1
// In case of a goodbye record, the querier should set TTL to 1 second
constexpr std::chrono::seconds kGoodbyeRecordTtl{1};
} // namespace
MdnsTracker::MdnsTracker(MdnsSender* sender,
TaskRunner* task_runner,
ClockNowFunctionPtr now_function,
MdnsRandom* random_delay,
TrackerType tracker_type)
: sender_(sender),
task_runner_(task_runner),
now_function_(now_function),
send_alarm_(now_function, task_runner),
random_delay_(random_delay),
tracker_type_(tracker_type) {
OSP_DCHECK(task_runner_);
OSP_DCHECK(now_function_);
OSP_DCHECK(random_delay_);
OSP_DCHECK(sender_);
}
MdnsTracker::~MdnsTracker() {
send_alarm_.Cancel();
for (const MdnsTracker* node : adjacent_nodes_) {
node->RemovedReverseAdjacency(this);
}
}
bool MdnsTracker::AddAdjacentNode(const MdnsTracker* node) const {
OSP_DCHECK(node);
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
auto it = std::find(adjacent_nodes_.begin(), adjacent_nodes_.end(), node);
if (it != adjacent_nodes_.end()) {
return false;
}
adjacent_nodes_.push_back(node);
node->AddReverseAdjacency(this);
return true;
}
bool MdnsTracker::RemoveAdjacentNode(const MdnsTracker* node) const {
OSP_DCHECK(node);
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
auto it = std::find(adjacent_nodes_.begin(), adjacent_nodes_.end(), node);
if (it == adjacent_nodes_.end()) {
return false;
}
adjacent_nodes_.erase(it);
node->RemovedReverseAdjacency(this);
return true;
}
void MdnsTracker::AddReverseAdjacency(const MdnsTracker* node) const {
OSP_DCHECK(std::find(adjacent_nodes_.begin(), adjacent_nodes_.end(), node) ==
adjacent_nodes_.end());
adjacent_nodes_.push_back(node);
}
void MdnsTracker::RemovedReverseAdjacency(const MdnsTracker* node) const {
auto it = std::find(adjacent_nodes_.begin(), adjacent_nodes_.end(), node);
OSP_DCHECK(it != adjacent_nodes_.end());
adjacent_nodes_.erase(it);
}
MdnsRecordTracker::MdnsRecordTracker(
MdnsRecord record,
DnsType dns_type,
MdnsSender* sender,
TaskRunner* task_runner,
ClockNowFunctionPtr now_function,
MdnsRandom* random_delay,
RecordExpiredCallback record_expired_callback)
: MdnsTracker(sender,
task_runner,
now_function,
random_delay,
TrackerType::kRecordTracker),
record_(std::move(record)),
dns_type_(dns_type),
start_time_(now_function_()),
record_expired_callback_(std::move(record_expired_callback)) {
OSP_DCHECK(record_expired_callback_);
// RecordTrackers cannot be created for tracking NSEC types or ANY types.
OSP_DCHECK(dns_type_ != DnsType::kNSEC);
OSP_DCHECK(dns_type_ != DnsType::kANY);
// Validate that, if the provided |record| is an NSEC record, then it provides
// a negative response for |dns_type|.
OSP_DCHECK(record_.dns_type() != DnsType::kNSEC ||
IsNegativeResponseForType(record_, dns_type_));
ScheduleFollowUpQuery();
}
MdnsRecordTracker::~MdnsRecordTracker() = default;
ErrorOr<MdnsRecordTracker::UpdateType> MdnsRecordTracker::Update(
const MdnsRecord& new_record) {
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
const bool has_same_rdata = record_.dns_type() == new_record.dns_type() &&
record_.rdata() == new_record.rdata();
const bool new_is_negative_response = new_record.dns_type() == DnsType::kNSEC;
const bool current_is_negative_response =
record_.dns_type() == DnsType::kNSEC;
if ((record_.dns_class() != new_record.dns_class()) ||
(record_.name() != new_record.name())) {
// The new record has been passed to a wrong tracker.
return Error::Code::kParameterInvalid;
}
// New response record must correspond to the correct type.
if ((!new_is_negative_response && new_record.dns_type() != dns_type_) ||
(new_is_negative_response &&
!IsNegativeResponseForType(new_record, dns_type_))) {
// The new record has been passed to a wrong tracker.
return Error::Code::kParameterInvalid;
}
// Goodbye records must have the same RDATA but TTL of 0.
// RFC 6762 Section 10.1.
// https://tools.ietf.org/html/rfc6762#section-10.1
if (!new_is_negative_response && !current_is_negative_response &&
IsGoodbyeRecord(new_record) && !has_same_rdata) {
// The new record has been passed to a wrong tracker.
return Error::Code::kParameterInvalid;
}
UpdateType result = UpdateType::kGoodbye;
if (IsGoodbyeRecord(new_record)) {
record_ = MdnsRecord(new_record.name(), new_record.dns_type(),
new_record.dns_class(), new_record.record_type(),
kGoodbyeRecordTtl, new_record.rdata());
// Goodbye records do not need to be re-queried, set the attempt count to
// the last item, which is 100% of TTL, i.e. record expiration.
attempt_count_ = countof(kTtlFractions) - 1;
} else {
record_ = new_record;
attempt_count_ = 0;
result = has_same_rdata ? UpdateType::kTTLOnly : UpdateType::kRdata;
}
start_time_ = now_function_();
ScheduleFollowUpQuery();
return result;
}
bool MdnsRecordTracker::AddAssociatedQuery(
const MdnsQuestionTracker* question_tracker) const {
return AddAdjacentNode(question_tracker);
}
bool MdnsRecordTracker::RemoveAssociatedQuery(
const MdnsQuestionTracker* question_tracker) const {
return RemoveAdjacentNode(question_tracker);
}
void MdnsRecordTracker::ExpireSoon() {
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
record_ =
MdnsRecord(record_.name(), record_.dns_type(), record_.dns_class(),
record_.record_type(), kGoodbyeRecordTtl, record_.rdata());
// Set the attempt count to the last item, which is 100% of TTL, i.e. record
// expiration, to prevent any re-queries
attempt_count_ = countof(kTtlFractions) - 1;
start_time_ = now_function_();
ScheduleFollowUpQuery();
}
void MdnsRecordTracker::ExpireNow() {
record_expired_callback_(this, record_);
}
bool MdnsRecordTracker::IsNearingExpiry() const {
return (now_function_() - start_time_) > record_.ttl() / 2;
}
bool MdnsRecordTracker::SendQuery() const {
const Clock::time_point expiration_time = start_time_ + record_.ttl();
bool is_expired = (now_function_() >= expiration_time);
if (!is_expired) {
for (const MdnsTracker* tracker : adjacent_nodes()) {
tracker->SendQuery();
}
} else {
record_expired_callback_(this, record_);
}
return !is_expired;
}
void MdnsRecordTracker::ScheduleFollowUpQuery() {
send_alarm_.Schedule(
[this] {
if (SendQuery()) {
ScheduleFollowUpQuery();
}
},
GetNextSendTime());
}
std::vector<MdnsRecord> MdnsRecordTracker::GetRecords() const {
return {record_};
}
Clock::time_point MdnsRecordTracker::GetNextSendTime() {
OSP_DCHECK(attempt_count_ < countof(kTtlFractions));
double ttl_fraction = kTtlFractions[attempt_count_++];
// Do not add random variation to the expiration time (last fraction of TTL)
if (attempt_count_ != countof(kTtlFractions)) {
ttl_fraction += random_delay_->GetRecordTtlVariation();
}
const Clock::duration delay =
Clock::to_duration(record_.ttl() * ttl_fraction);
return start_time_ + delay;
}
MdnsQuestionTracker::MdnsQuestionTracker(MdnsQuestion question,
MdnsSender* sender,
TaskRunner* task_runner,
ClockNowFunctionPtr now_function,
MdnsRandom* random_delay,
const Config& config,
QueryType query_type)
: MdnsTracker(sender,
task_runner,
now_function,
random_delay,
TrackerType::kQuestionTracker),
question_(std::move(question)),
send_delay_(kMinimumQueryInterval),
query_type_(query_type),
maximum_announcement_count_(config.new_query_announcement_count < 0
? INT_MAX
: config.new_query_announcement_count) {
// Initialize the last send time to time_point::min() so that the next call to
// SendQuery() is guaranteed to query the network.
last_send_time_ = TrivialClockTraits::time_point::min();
// The initial query has to be sent after a random delay of 20-120
// milliseconds.
if (announcements_so_far_ < maximum_announcement_count_) {
announcements_so_far_++;
if (query_type_ == QueryType::kOneShot) {
task_runner_->PostTask([this] { MdnsQuestionTracker::SendQuery(); });
} else {
OSP_DCHECK(query_type_ == QueryType::kContinuous);
send_alarm_.ScheduleFromNow(
[this]() {
MdnsQuestionTracker::SendQuery();
ScheduleFollowUpQuery();
},
random_delay_->GetInitialQueryDelay());
}
}
}
MdnsQuestionTracker::~MdnsQuestionTracker() = default;
bool MdnsQuestionTracker::AddAssociatedRecord(
const MdnsRecordTracker* record_tracker) const {
return AddAdjacentNode(record_tracker);
}
bool MdnsQuestionTracker::RemoveAssociatedRecord(
const MdnsRecordTracker* record_tracker) const {
return RemoveAdjacentNode(record_tracker);
}
std::vector<MdnsRecord> MdnsQuestionTracker::GetRecords() const {
std::vector<MdnsRecord> records;
for (const MdnsTracker* tracker : adjacent_nodes()) {
OSP_DCHECK(tracker->tracker_type() == TrackerType::kRecordTracker);
// This call cannot result in an infinite loop because MdnsRecordTracker
// instances only return a single record from this call.
std::vector<MdnsRecord> node_records = tracker->GetRecords();
OSP_DCHECK(node_records.size() == 1);
records.push_back(std::move(node_records[0]));
}
return records;
}
bool MdnsQuestionTracker::SendQuery() const {
// NOTE: The RFC does not specify the minimum interval between queries for
// multiple records of the same query when initiated for different reasons
// (such as for different record refreshes or for one record refresh and the
// periodic re-querying for a continuous query). For this reason, a constant
// outside of scope of the RFC has been chosen.
TrivialClockTraits::time_point now = now_function_();
if (now < last_send_time_ + kMinimumQueryInterval) {
return true;
}
last_send_time_ = now;
MdnsMessage message(CreateMessageId(), MessageType::Query);
message.AddQuestion(question_);
// Send the message and additional known answer packets as needed.
for (auto it = adjacent_nodes().begin(); it != adjacent_nodes().end();) {
OSP_DCHECK((*it)->tracker_type() == TrackerType::kRecordTracker);
const MdnsRecordTracker* record_tracker =
static_cast<const MdnsRecordTracker*>(*it);
if (record_tracker->IsNearingExpiry()) {
it++;
continue;
}
// A record tracker should only contain one record.
std::vector<MdnsRecord> node_records = (*it)->GetRecords();
OSP_DCHECK(node_records.size() == 1);
MdnsRecord node_record = std::move(node_records[0]);
if (message.CanAddRecord(node_record)) {
message.AddAnswer(std::move(node_record));
it++;
} else if (message.questions().empty() && message.answers().empty()) {
// This case should never happen, because it means a record is too large
// to fit into its own message.
OSP_LOG_INFO
<< "Encountered unreasonably large message in cache. Skipping "
<< "known answer in suppressions...";
it++;
} else {
message.set_truncated();
sender_->SendMulticast(message);
message = MdnsMessage(CreateMessageId(), MessageType::Query);
}
}
sender_->SendMulticast(message);
return true;
}
void MdnsQuestionTracker::ScheduleFollowUpQuery() {
if (announcements_so_far_ >= maximum_announcement_count_) {
return;
}
announcements_so_far_++;
send_alarm_.ScheduleFromNow(
[this] {
if (SendQuery()) {
ScheduleFollowUpQuery();
}
},
send_delay_);
send_delay_ = send_delay_ * kIntervalIncreaseFactor;
if (send_delay_ > kMaximumQueryInterval) {
send_delay_ = kMaximumQueryInterval;
}
}
} // namespace discovery
} // namespace openscreen