blob: 18e06587820e9384c94154ee5cb6cf8ffa634577 [file] [log] [blame]
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "cast/common/mdns/mdns_trackers.h"
#include <array>
#include "util/std_util.h"
namespace cast {
namespace mdns {
namespace {
// RFC 6762 Section 5.2
// https://tools.ietf.org/html/rfc6762#section-5.2
// Fractions of a record's TTL at which MdnsRecordTracker schedules its alarm.
// Attempts to refresh a record should be performed at 80%, 85%, 90% and 95%
// of the TTL. The last entry (100%) is not a refresh attempt: it is the moment
// the record expires, at which point the tracker reports expiration instead of
// sending another query.
constexpr double kTtlFractions[] = {0.80, 0.85, 0.90, 0.95, 1.00};
// Intervals between successive queries must increase by at least a factor of 2.
// MdnsQuestionTracker multiplies its send delay by this factor after each
// query it sends.
constexpr int kIntervalIncreaseFactor = 2;
// The interval between the first two queries must be at least one second.
// Used as the initial send delay when a question tracker is started.
constexpr std::chrono::seconds kMinimumQueryInterval{1};
// The querier may cap the question refresh interval to a maximum of 60 minutes.
// The question tracker's send delay is clamped to this value.
constexpr std::chrono::minutes kMaximumQueryInterval{60};
// RFC 6762 Section 10.1
// https://tools.ietf.org/html/rfc6762#section-10.1
// Returns true when |record| is a goodbye record, i.e. its TTL is exactly
// zero seconds.
bool IsGoodbyeRecord(const MdnsRecord& record) {
  return record.ttl() == std::chrono::seconds::zero();
}
// RFC 6762 Section 10.1
// https://tools.ietf.org/html/rfc6762#section-10.1
// TTL that replaces the zero TTL of a received goodbye record: the querier
// should keep such a record for one more second before expiring it.
constexpr std::chrono::seconds kGoodbyeRecordTtl{1};
} // namespace
// Common state shared by the record and question trackers.
//   |sender|       - used to send the multicast query messages.
//   |task_runner|  - handed to the send alarm together with |now_function|.
//   |now_function| - clock source used for all scheduling decisions.
//   |random_delay| - source of the random delays/variations applied to the
//                    send times.
// None of the arguments may be null (checked in debug builds only).
MdnsTracker::MdnsTracker(MdnsSender* sender,
                         TaskRunner* task_runner,
                         ClockNowFunctionPtr now_function,
                         MdnsRandom* random_delay)
    : sender_(sender),
      now_function_(now_function),
      send_alarm_(now_function, task_runner),
      random_delay_(random_delay) {
  // Validate the arguments in declaration order.
  OSP_DCHECK(sender);
  OSP_DCHECK(task_runner);
  OSP_DCHECK(now_function);
  OSP_DCHECK(random_delay);
}
// Creates a tracker for a single mDNS record.
//   |sender|, |task_runner|, |now_function|, |random_delay| - forwarded to the
//       MdnsTracker base class.
//   |record_updated_callback| - invoked from Update() when the RDATA of the
//       tracked record has changed.
//   |record_expired_callback| - invoked from SendQuery() once the TTL of the
//       tracked record has elapsed.
// Both callbacks must be non-empty (checked in debug builds only).
MdnsRecordTracker::MdnsRecordTracker(
    MdnsSender* sender,
    TaskRunner* task_runner,
    ClockNowFunctionPtr now_function,
    MdnsRandom* random_delay,
    std::function<void(const MdnsRecord&)> record_updated_callback,
    std::function<void(const MdnsRecord&)> record_expired_callback)
    : MdnsTracker(sender, task_runner, now_function, random_delay),
      // The callbacks are taken by value, so move them into place rather than
      // paying for a second copy of each std::function.
      record_updated_callback_(std::move(record_updated_callback)),
      record_expired_callback_(std::move(record_expired_callback)) {
  // The parameters are moved-from at this point; validate the stored members.
  OSP_DCHECK(record_updated_callback_);
  OSP_DCHECK(record_expired_callback_);
}
// Begins tracking |record| and schedules the first refresh query.
// Returns Error::Code::kOperationInvalid if a record is already being tracked,
// Error::None() otherwise.
Error MdnsRecordTracker::Start(MdnsRecord record) {
  if (record_.has_value()) {
    // Already tracking a record; Stop() has to be called first.
    return Error::Code::kOperationInvalid;
  }

  record_ = std::move(record);
  // All send times are computed relative to the moment tracking started.
  start_time_ = now_function_();
  send_count_ = 0;

  const Clock::time_point first_send_time = GetNextSendTime();
  send_alarm_.Schedule([this] { SendQuery(); }, first_send_time);
  return Error::None();
}
// Stops tracking the current record and cancels the pending query alarm.
// Returns Error::Code::kOperationInvalid if no record is being tracked,
// Error::None() otherwise.
Error MdnsRecordTracker::Stop() {
  if (record_.has_value()) {
    send_alarm_.Cancel();
    record_.reset();
    return Error::None();
  }
  return Error::Code::kOperationInvalid;
}
// Replaces the tracked record with |new_record| and restarts the refresh
// schedule.
// Returns:
//   Error::Code::kOperationInvalid - the tracker has not been started.
//   Error::Code::kParameterInvalid - |new_record| differs from the tracked
//       record in name, DNS type or DNS class.
//   Error::None()                  - the record was replaced.
// The record-updated callback is invoked only when the RDATA has changed.
Error MdnsRecordTracker::Update(const MdnsRecord& new_record) {
  if (!record_.has_value()) {
    return Error::Code::kOperationInvalid;
  }

  const MdnsRecord& current = record_.value();
  const bool is_for_other_tracker =
      (current.dns_type() != new_record.dns_type()) ||
      (current.dns_class() != new_record.dns_class()) ||
      (current.name() != new_record.name());
  if (is_for_other_tracker) {
    // The new record has been passed to a wrong tracker.
    return Error::Code::kParameterInvalid;
  }

  // Stop() below clears the tracked record, so the RDATA comparison has to
  // happen first.
  const bool rdata_changed = (new_record.rdata() != current.rdata());

  Error result = Stop();
  if (!result.ok()) {
    return result;
  }

  // RFC 6762 Section 10.1
  // https://tools.ietf.org/html/rfc6762#section-10.1
  // In case of a goodbye record, the querier should set TTL to 1 second, so a
  // copy of the record with the adjusted TTL is tracked instead.
  result = IsGoodbyeRecord(new_record)
               ? Start(MdnsRecord(new_record.name(), new_record.dns_type(),
                                  new_record.dns_class(),
                                  new_record.record_type(), kGoodbyeRecordTtl,
                                  new_record.rdata()))
               : Start(new_record);
  if (!result.ok()) {
    return result;
  }

  if (rdata_changed) {
    record_updated_callback_(record_.value());
  }
  return Error::None();
}
// Returns true while a record is being tracked, i.e. after a successful
// Start() and before the matching Stop().
bool MdnsRecordTracker::IsStarted() {
  return record_.has_value();
}
void MdnsRecordTracker::SendQuery() {
const MdnsRecord& record = record_.value();
const Clock::time_point expiration_time = start_time_ + record.ttl();
const bool is_expired = (now_function_() >= expiration_time);
if (!is_expired) {
MdnsQuestion question(record.name(), record.dns_type(), record.dns_class(),
ResponseType::kMulticast);
MdnsMessage message(CreateMessageId(), MessageType::Query);
message.AddQuestion(std::move(question));
sender_->SendMulticast(message);
send_alarm_.Schedule(std::bind(&MdnsRecordTracker::SendQuery, this),
GetNextSendTime());
} else {
record_expired_callback_(record);
}
}
// Returns the time at which the next alarm for the tracked record should
// fire and advances |send_count_| to the following entry of kTtlFractions.
// The result is |start_time_| plus the selected fraction of the record's TTL.
Clock::time_point MdnsRecordTracker::GetNextSendTime() {
  const auto fraction_count = openscreen::countof(kTtlFractions);
  OSP_DCHECK(send_count_ < fraction_count);

  const auto fraction_index = send_count_++;
  double fraction_of_ttl = kTtlFractions[fraction_index];

  // The last fraction of TTL is the expiration time, which must not receive
  // any random variation; every earlier slot does.
  const bool is_expiration_slot = (send_count_ == fraction_count);
  if (!is_expiration_slot) {
    fraction_of_ttl += random_delay_->GetRecordTtlVariation();
  }

  const auto scaled_ttl = record_.value().ttl() * fraction_of_ttl;
  return start_time_ + std::chrono::duration_cast<Clock::duration>(scaled_ttl);
}
// Creates a tracker for a single mDNS question. All arguments are forwarded
// unchanged to the MdnsTracker base class; no other state is initialized here.
MdnsQuestionTracker::MdnsQuestionTracker(MdnsSender* sender,
TaskRunner* task_runner,
ClockNowFunctionPtr now_function,
MdnsRandom* random_delay)
: MdnsTracker(sender, task_runner, now_function, random_delay) {}
// Begins tracking |question| and schedules the initial query.
// Returns Error::Code::kOperationInvalid if a question is already being
// tracked, Error::None() otherwise.
Error MdnsQuestionTracker::Start(MdnsQuestion question) {
  if (question_.has_value()) {
    // Already tracking a question; Stop() has to be called first.
    return Error::Code::kOperationInvalid;
  }

  question_ = std::move(question);
  // Delay between the first and second query; grows after each send.
  send_delay_ = kMinimumQueryInterval;

  // The initial query has to be sent after a random delay of 20-120
  // milliseconds.
  const Clock::duration initial_delay = random_delay_->GetInitialQueryDelay();
  const Clock::time_point first_send_time = now_function_() + initial_delay;
  send_alarm_.Schedule([this] { SendQuery(); }, first_send_time);
  return Error::None();
}
// Stops tracking the current question and cancels the pending query alarm.
// Returns Error::Code::kOperationInvalid if no question is being tracked,
// Error::None() otherwise.
Error MdnsQuestionTracker::Stop() {
  if (question_.has_value()) {
    send_alarm_.Cancel();
    question_.reset();
    return Error::None();
  }
  return Error::Code::kOperationInvalid;
}
// Returns true while a question is being tracked, i.e. after a successful
// Start() and before the matching Stop().
bool MdnsQuestionTracker::IsStarted() {
  return question_.has_value();
}
void MdnsQuestionTracker::SendQuery() {
MdnsMessage message(CreateMessageId(), MessageType::Query);
message.AddQuestion(question_.value());
// TODO(yakimakha): Implement known-answer suppression by adding known
// answers to the question
sender_->SendMulticast(message);
send_alarm_.Schedule(std::bind(&MdnsQuestionTracker::SendQuery, this),
now_function_() + send_delay_);
send_delay_ = send_delay_ * kIntervalIncreaseFactor;
if (send_delay_ > kMaximumQueryInterval) {
send_delay_ = kMaximumQueryInterval;
}
}
} // namespace mdns
} // namespace cast