Update and expiration callbacks for MdnsRecordTracker
Change-Id: I21fe7ded3f52cd1749d8a154e4dfc472efdf25c3
Reviewed-on: https://chromium-review.googlesource.com/c/openscreen/+/1799365
Reviewed-by: Ryan Keane <rwkeane@google.com>
Commit-Queue: Max Yakimakha <yakimakha@chromium.org>
diff --git a/cast/common/mdns/mdns_records.cc b/cast/common/mdns/mdns_records.cc
index 128a60e..45d9d25 100644
--- a/cast/common/mdns/mdns_records.cc
+++ b/cast/common/mdns/mdns_records.cc
@@ -172,6 +172,7 @@
ttl_(ttl),
rdata_(std::move(rdata)) {
OSP_DCHECK(!name_.empty());
+ OSP_DCHECK_LE(ttl_.count(), std::numeric_limits<uint32_t>::max());
OSP_DCHECK((dns_type == DnsType::kSRV &&
absl::holds_alternative<SrvRecordRdata>(rdata_)) ||
(dns_type == DnsType::kA &&
diff --git a/cast/common/mdns/mdns_trackers.cc b/cast/common/mdns/mdns_trackers.cc
index 200b5b3..18e0658 100644
--- a/cast/common/mdns/mdns_trackers.cc
+++ b/cast/common/mdns/mdns_trackers.cc
@@ -11,6 +11,37 @@
namespace cast {
namespace mdns {
+namespace {
+
+// RFC 6762 Section 5.2
+// https://tools.ietf.org/html/rfc6762#section-5.2
+
+// Attempt to refresh a record should be performed at 80%, 85%, 90% and 95% TTL.
+constexpr double kTtlFractions[] = {0.80, 0.85, 0.90, 0.95, 1.00};
+
+// Intervals between successive queries must increase by at least a factor of 2.
+constexpr int kIntervalIncreaseFactor = 2;
+
+// The interval between the first two queries must be at least one second.
+constexpr std::chrono::seconds kMinimumQueryInterval{1};
+
+// The querier may cap the question refresh interval to a maximum of 60 minutes.
+constexpr std::chrono::minutes kMaximumQueryInterval{60};
+
+// RFC 6762 Section 10.1
+// https://tools.ietf.org/html/rfc6762#section-10.1
+
+// A goodbye record is a record with TTL of 0.
+bool IsGoodbyeRecord(const MdnsRecord& record) {
+ constexpr std::chrono::seconds zero_ttl{0};
+ return record.ttl() == zero_ttl;
+}
+
+// The interval between the first two queries must be at least one second.
+constexpr std::chrono::seconds kGoodbyeRecordTtl{1};
+
+} // namespace
+
MdnsTracker::MdnsTracker(MdnsSender* sender,
TaskRunner* task_runner,
ClockNowFunctionPtr now_function,
@@ -25,17 +56,19 @@
OSP_DCHECK(sender);
}
-namespace {
-// RFC 6762 Section 5.2
-// https://tools.ietf.org/html/rfc6762#section-5.2
-constexpr double kTtlFractions[] = {0.80, 0.85, 0.90, 0.95, 1.00};
-} // namespace
-
-MdnsRecordTracker::MdnsRecordTracker(MdnsSender* sender,
- TaskRunner* task_runner,
- ClockNowFunctionPtr now_function,
- MdnsRandom* random_delay)
- : MdnsTracker(sender, task_runner, now_function, random_delay) {}
+MdnsRecordTracker::MdnsRecordTracker(
+ MdnsSender* sender,
+ TaskRunner* task_runner,
+ ClockNowFunctionPtr now_function,
+ MdnsRandom* random_delay,
+ std::function<void(const MdnsRecord&)> record_updated_callback,
+ std::function<void(const MdnsRecord&)> record_expired_callback)
+ : MdnsTracker(sender, task_runner, now_function, random_delay),
+ record_updated_callback_(record_updated_callback),
+ record_expired_callback_(record_expired_callback) {
+ OSP_DCHECK(record_updated_callback);
+ OSP_DCHECK(record_expired_callback);
+}
Error MdnsRecordTracker::Start(MdnsRecord record) {
if (record_.has_value()) {
@@ -60,6 +93,49 @@
return Error::None();
}
+Error MdnsRecordTracker::Update(const MdnsRecord& new_record) {
+ if (!record_.has_value()) {
+ return Error::Code::kOperationInvalid;
+ }
+
+ MdnsRecord& old_record = record_.value();
+ if ((old_record.dns_type() != new_record.dns_type()) ||
+ (old_record.dns_class() != new_record.dns_class()) ||
+ (old_record.name() != new_record.name())) {
+ // The new record has been passed to a wrong tracker
+ return Error::Code::kParameterInvalid;
+ }
+
+ // Check if RDATA has changed before a call to Stop clears the old record
+ bool is_updated = (new_record.rdata() != old_record.rdata());
+
+ Error error = Stop();
+ if (!error.ok()) {
+ return error;
+ }
+
+ if (IsGoodbyeRecord(new_record)) {
+ // RFC 6762 Section 10.1
+ // https://tools.ietf.org/html/rfc6762#section-10.1
+ // In case of a goodbye record, the querier should set TTL to 1 second
+ error = Start(MdnsRecord(new_record.name(), new_record.dns_type(),
+ new_record.dns_class(), new_record.record_type(),
+ kGoodbyeRecordTtl, new_record.rdata()));
+ } else {
+ error = Start(new_record);
+ }
+
+ if (!error.ok()) {
+ return error;
+ }
+
+ if (is_updated) {
+ record_updated_callback_(record_.value());
+ }
+
+ return Error::None();
+}
+
bool MdnsRecordTracker::IsStarted() {
return record_.has_value();
};
@@ -77,7 +153,7 @@
send_alarm_.Schedule(std::bind(&MdnsRecordTracker::SendQuery, this),
GetNextSendTime());
} else {
- // TODO(yakimakha): Notify owner that the record has expired
+ record_expired_callback_(record);
}
}
@@ -96,14 +172,6 @@
return start_time_ + delay;
}
-namespace {
-// RFC 6762 Section 5.2
-// https://tools.ietf.org/html/rfc6762#section-5.2
-constexpr int kIntervalIncreaseFactor = 2;
-constexpr std::chrono::seconds kMinimumQueryInterval{1};
-constexpr std::chrono::minutes kMaximumQueryInterval{60};
-} // namespace
-
MdnsQuestionTracker::MdnsQuestionTracker(MdnsSender* sender,
TaskRunner* task_runner,
ClockNowFunctionPtr now_function,
diff --git a/cast/common/mdns/mdns_trackers.h b/cast/common/mdns/mdns_trackers.h
index 7a69e8b..153facd 100644
--- a/cast/common/mdns/mdns_trackers.h
+++ b/cast/common/mdns/mdns_trackers.h
@@ -52,10 +52,13 @@
// refreshing records as they reach their expiration time.
class MdnsRecordTracker : public MdnsTracker {
public:
- MdnsRecordTracker(MdnsSender* sender,
- TaskRunner* task_runner,
- ClockNowFunctionPtr now_function,
- MdnsRandom* random_delay);
+ MdnsRecordTracker(
+ MdnsSender* sender,
+ TaskRunner* task_runner,
+ ClockNowFunctionPtr now_function,
+ MdnsRandom* random_delay,
+ std::function<void(const MdnsRecord&)> record_updated_callback,
+ std::function<void(const MdnsRecord&)> record_expired_callback);
// Starts sending query messages for the provided record using record's TTL
// and the time of the call to determine when to send the queries. Returns
@@ -70,6 +73,12 @@
// stopped.
Error Stop();
+ // Updates record tracker with the new record:
+ // 1. Calls update callback if RDATA has changed.
+ // 2. Resets TTL to the value specified in new_record.
+ // 3. Schedules expiration in case of a goodbye record.
+ Error Update(const MdnsRecord& new_record);
+
// Returns true if MdnsRecordTracker instance has been started and is
// automatically refreshing the record, false otherwise.
bool IsStarted();
@@ -84,6 +93,8 @@
Clock::time_point start_time_;
// Number of times a question to refresh the record has been sent.
size_t send_count_ = 0;
+ std::function<void(const MdnsRecord&)> record_updated_callback_;
+ std::function<void(const MdnsRecord&)> record_expired_callback_;
};
// MdnsQuestionTracker manages automatic resending of mDNS queries for
diff --git a/cast/common/mdns/mdns_trackers_unittest.cc b/cast/common/mdns/mdns_trackers_unittest.cc
index 334ff2b..95609ad 100644
--- a/cast/common/mdns/mdns_trackers_unittest.cc
+++ b/cast/common/mdns/mdns_trackers_unittest.cc
@@ -73,25 +73,25 @@
ARecordRdata(IPAddress{172, 0, 0, 1})) {}
template <class TrackerType, class TrackedType>
- void TrackerStartStop(TrackedType tracked_data) {
- TrackerType tracker(&sender_, &task_runner_, &FakeClock::now, &random_);
- EXPECT_EQ(tracker.IsStarted(), false);
- EXPECT_EQ(tracker.Stop(), Error(Error::Code::kOperationInvalid));
- EXPECT_EQ(tracker.IsStarted(), false);
- EXPECT_EQ(tracker.Start(tracked_data), Error(Error::Code::kNone));
- EXPECT_EQ(tracker.IsStarted(), true);
- EXPECT_EQ(tracker.Start(tracked_data),
+ void TrackerStartStop(std::unique_ptr<TrackerType> tracker,
+ TrackedType tracked_data) {
+ EXPECT_EQ(tracker->IsStarted(), false);
+ EXPECT_EQ(tracker->Stop(), Error(Error::Code::kOperationInvalid));
+ EXPECT_EQ(tracker->IsStarted(), false);
+ EXPECT_EQ(tracker->Start(tracked_data), Error(Error::Code::kNone));
+ EXPECT_EQ(tracker->IsStarted(), true);
+ EXPECT_EQ(tracker->Start(tracked_data),
Error(Error::Code::kOperationInvalid));
- EXPECT_EQ(tracker.IsStarted(), true);
- EXPECT_EQ(tracker.Stop(), Error(Error::Code::kNone));
- EXPECT_EQ(tracker.IsStarted(), false);
+ EXPECT_EQ(tracker->IsStarted(), true);
+ EXPECT_EQ(tracker->Stop(), Error(Error::Code::kNone));
+ EXPECT_EQ(tracker->IsStarted(), false);
}
template <class TrackerType, class TrackedType>
- void TrackerNoQueryAfterStop(TrackedType tracked_data) {
- TrackerType tracker(&sender_, &task_runner_, &FakeClock::now, &random_);
- EXPECT_EQ(tracker.Start(tracked_data), Error(Error::Code::kNone));
- EXPECT_EQ(tracker.Stop(), Error(Error::Code::kNone));
+ void TrackerNoQueryAfterStop(std::unique_ptr<TrackerType> tracker,
+ TrackedType tracked_data) {
+ EXPECT_EQ(tracker->Start(tracked_data), Error(Error::Code::kNone));
+ EXPECT_EQ(tracker->Stop(), Error(Error::Code::kNone));
EXPECT_CALL(socket_, SendMessage(_, _, _)).Times(0);
// Advance fake clock by a long time interval to make sure if there's a
// scheduled task, it will run.
@@ -99,17 +99,33 @@
}
template <class TrackerType, class TrackedType>
- void TrackerNoQueryAfterDestruction(TrackedType tracked_data) {
- {
- TrackerType tracker(&sender_, &task_runner_, &FakeClock::now, &random_);
- tracker.Start(tracked_data);
- }
+ void TrackerNoQueryAfterDestruction(std::unique_ptr<TrackerType> tracker,
+ TrackedType tracked_data) {
+ tracker->Start(tracked_data);
+ tracker.reset();
EXPECT_CALL(socket_, SendMessage(_, _, _)).Times(0);
// Advance fake clock by a long time interval to make sure if there's a
// scheduled task, it will run.
clock_.Advance(std::chrono::hours(1));
}
+ void UpdateCallback(const MdnsRecord&) { update_called_ = true; }
+ void ExpirationCallback(const MdnsRecord&) { expiration_called_ = true; }
+
+ std::unique_ptr<MdnsRecordTracker> CreateRecordTracker() {
+ return std::make_unique<MdnsRecordTracker>(
+ &sender_, &task_runner_, &FakeClock::now, &random_,
+ std::bind(&MdnsTrackerTest::UpdateCallback, this,
+ std::placeholders::_1),
+ std::bind(&MdnsTrackerTest::ExpirationCallback, this,
+ std::placeholders::_1));
+ }
+
+ std::unique_ptr<MdnsQuestionTracker> CreateQuestionTracker() {
+ return std::make_unique<MdnsQuestionTracker>(&sender_, &task_runner_,
+ &FakeClock::now, &random_);
+ }
+
protected:
// clang-format off
const std::vector<uint8_t> kQuestionQueryBytes = {
@@ -151,6 +167,9 @@
MdnsQuestion a_question_;
MdnsRecord a_record_;
+
+ bool update_called_ = false;
+ bool expiration_called_ = false;
};
// Records are re-queried at 80%, 85%, 90% and 95% TTL as per RFC 6762
@@ -162,12 +181,13 @@
// https://tools.ietf.org/html/rfc6762#section-5.2
TEST_F(MdnsTrackerTest, RecordTrackerStartStop) {
- TrackerStartStop<MdnsRecordTracker>(a_record_);
+ std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker();
+ TrackerStartStop(std::move(tracker), a_record_);
}
TEST_F(MdnsTrackerTest, RecordTrackerQueryAfterDelay) {
- MdnsRecordTracker tracker(&sender_, &task_runner_, &FakeClock::now, &random_);
- tracker.Start(a_record_);
+ std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker();
+ tracker->Start(a_record_);
// Only expect 4 queries being sent, when record reaches it's TTL it's
// considered expired and another query is not sent
constexpr double kTtlFractions[] = {0.83, 0.88, 0.93, 0.98, 1.00};
@@ -184,8 +204,8 @@
}
TEST_F(MdnsTrackerTest, RecordTrackerSendsMessage) {
- MdnsRecordTracker tracker(&sender_, &task_runner_, &FakeClock::now, &random_);
- tracker.Start(a_record_);
+ std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker();
+ tracker->Start(a_record_);
EXPECT_CALL(socket_, SendMessage(_, _, _))
.WillOnce(WithArgs<0, 1>(VerifyMessageBytesWithoutId(
@@ -196,16 +216,18 @@
}
TEST_F(MdnsTrackerTest, RecordTrackerNoQueryAfterStop) {
- TrackerNoQueryAfterStop<MdnsRecordTracker>(a_record_);
+ std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker();
+ TrackerNoQueryAfterStop(std::move(tracker), a_record_);
}
TEST_F(MdnsTrackerTest, RecordTrackerNoQueryAfterDestruction) {
- TrackerNoQueryAfterDestruction<MdnsRecordTracker>(a_record_);
+ std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker();
+ TrackerNoQueryAfterDestruction(std::move(tracker), a_record_);
}
TEST_F(MdnsTrackerTest, RecordTrackerNoQueryAfterLateTask) {
- MdnsRecordTracker tracker(&sender_, &task_runner_, &FakeClock::now, &random_);
- tracker.Start(a_record_);
+ std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker();
+ tracker->Start(a_record_);
// If task runner was too busy and callback happened too late, there should be
// no query and instead the record will expire.
// Check lower bound for task being late (TTL) and an arbitrarily long time
@@ -215,6 +237,113 @@
clock_.Advance(std::chrono::hours(1));
}
+TEST_F(MdnsTrackerTest, RecordTrackerUpdateFailsWhenNotStarted) {
+ std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker();
+ EXPECT_EQ(tracker->Update(a_record_), Error(Error::Code::kOperationInvalid));
+}
+
+TEST_F(MdnsTrackerTest, RecordTrackerUpdateFailsForMismatchedRecord) {
+ std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker();
+ tracker->Start(a_record_);
+ MdnsRecord updated_record = MdnsRecord(
+ DomainName{"alpha"}, a_record_.dns_type(), a_record_.dns_class(),
+ a_record_.record_type(), a_record_.ttl(), a_record_.rdata());
+ EXPECT_EQ(tracker->Update(updated_record),
+ Error(Error::Code::kParameterInvalid));
+
+ updated_record =
+ MdnsRecord(a_record_.name(), DnsType::kPTR, a_record_.dns_class(),
+ a_record_.record_type(), a_record_.ttl(),
+ PtrRecordRdata(DomainName{"bravo"}));
+ EXPECT_EQ(tracker->Update(updated_record),
+ Error(Error::Code::kParameterInvalid));
+
+ updated_record = MdnsRecord(a_record_.name(), a_record_.dns_type(),
+ static_cast<DnsClass>(2), a_record_.record_type(),
+ a_record_.ttl(), a_record_.rdata());
+ EXPECT_EQ(tracker->Update(updated_record),
+ Error(Error::Code::kParameterInvalid));
+}
+
+TEST_F(MdnsTrackerTest, RecordTrackerCallbackOnRdataUpdate) {
+ MdnsRecord updated_record(a_record_.name(), a_record_.dns_type(),
+ a_record_.dns_class(), a_record_.record_type(),
+ a_record_.ttl(),
+ ARecordRdata(IPAddress{192, 168, 0, 1}));
+
+ update_called_ = false;
+ std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker();
+ tracker->Start(a_record_);
+ EXPECT_EQ(tracker->Update(updated_record), Error::None());
+ EXPECT_TRUE(update_called_);
+}
+
+TEST_F(MdnsTrackerTest, RecordTrackerNoCallbackOnTtlUpdate) {
+ update_called_ = false;
+ std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker();
+ tracker->Start(a_record_);
+ EXPECT_EQ(tracker->Update(a_record_), Error::None());
+ EXPECT_FALSE(update_called_);
+}
+
+TEST_F(MdnsTrackerTest, RecordTrackerUpdateResetsTtl) {
+ expiration_called_ = false;
+ std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker();
+ tracker->Start(a_record_);
+ // Advance time by 60% of record's TTL
+ Clock::duration advance_time =
+ std::chrono::duration_cast<Clock::duration>(a_record_.ttl() * 0.6);
+ clock_.Advance(advance_time);
+ // Now update the record, this must reset expiration time
+ EXPECT_EQ(tracker->Update(a_record_), Error::None());
+ // Advance time by 60% of record's TTL again
+ clock_.Advance(advance_time);
+ // Check that expiration callback was not called
+ EXPECT_FALSE(expiration_called_);
+}
+
+TEST_F(MdnsTrackerTest, RecordTrackerExpirationCallback) {
+ expiration_called_ = false;
+ std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker();
+ tracker->Start(a_record_);
+ clock_.Advance(a_record_.ttl());
+ EXPECT_TRUE(expiration_called_);
+}
+
+TEST_F(MdnsTrackerTest, RecordTrackerExpirationCallbackAfterGoodbye) {
+ update_called_ = false;
+ expiration_called_ = false;
+ std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker();
+ tracker->Start(a_record_);
+ MdnsRecord goodbye_record(a_record_.name(), a_record_.dns_type(),
+ a_record_.dns_class(), a_record_.record_type(),
+ std::chrono::seconds{0}, a_record_.rdata());
+
+ EXPECT_EQ(tracker->Update(goodbye_record), Error::None());
+ // After a goodbye record is received, expiration is schedule in a second.
+ clock_.Advance(std::chrono::seconds{1});
+ EXPECT_FALSE(update_called_);
+ EXPECT_TRUE(expiration_called_);
+}
+
+TEST_F(MdnsTrackerTest, RecordTrackerNoExpirationCallbackAfterStop) {
+ expiration_called_ = false;
+ std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker();
+ tracker->Start(a_record_);
+ tracker->Stop();
+ clock_.Advance(a_record_.ttl());
+ EXPECT_FALSE(expiration_called_);
+}
+
+TEST_F(MdnsTrackerTest, RecordTrackerNoExpirationCallbackAfterDestruction) {
+ expiration_called_ = false;
+ std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker();
+ tracker->Start(a_record_);
+ tracker.reset();
+ clock_.Advance(a_record_.ttl());
+ EXPECT_FALSE(expiration_called_);
+}
+
// Initial query is delayed for up to 120 ms as per RFC 6762 Section 5.2
// Subsequent queries happen no sooner than a second after the initial query and
// the interval between the queries increases at least by a factor of 2 for each
@@ -222,13 +351,13 @@
// https://tools.ietf.org/html/rfc6762#section-5.2
TEST_F(MdnsTrackerTest, QuestionTrackerStartStop) {
- TrackerStartStop<MdnsQuestionTracker>(a_question_);
+ std::unique_ptr<MdnsQuestionTracker> tracker = CreateQuestionTracker();
+ TrackerStartStop(std::move(tracker), a_question_);
}
TEST_F(MdnsTrackerTest, QuestionTrackerQueryAfterDelay) {
- MdnsQuestionTracker tracker(&sender_, &task_runner_, &FakeClock::now,
- &random_);
- tracker.Start(a_question_);
+ std::unique_ptr<MdnsQuestionTracker> tracker = CreateQuestionTracker();
+ tracker->Start(a_question_);
EXPECT_CALL(socket_, SendMessage(_, _, _)).Times(1);
clock_.Advance(std::chrono::milliseconds(120));
@@ -242,9 +371,8 @@
}
TEST_F(MdnsTrackerTest, QuestionTrackerSendsMessage) {
- MdnsQuestionTracker tracker(&sender_, &task_runner_, &FakeClock::now,
- &random_);
- tracker.Start(a_question_);
+ std::unique_ptr<MdnsQuestionTracker> tracker = CreateQuestionTracker();
+ tracker->Start(a_question_);
EXPECT_CALL(socket_, SendMessage(_, _, _))
.WillOnce(WithArgs<0, 1>(VerifyMessageBytesWithoutId(
@@ -254,11 +382,13 @@
}
TEST_F(MdnsTrackerTest, QuestionTrackerNoQueryAfterStop) {
- TrackerNoQueryAfterStop<MdnsQuestionTracker>(a_question_);
+ std::unique_ptr<MdnsQuestionTracker> tracker = CreateQuestionTracker();
+ TrackerNoQueryAfterStop(std::move(tracker), a_question_);
}
TEST_F(MdnsTrackerTest, QuestionTrackerNoQueryAfterDestruction) {
- TrackerNoQueryAfterDestruction<MdnsQuestionTracker>(a_question_);
+ std::unique_ptr<MdnsQuestionTracker> tracker = CreateQuestionTracker();
+ TrackerNoQueryAfterDestruction(std::move(tracker), a_question_);
}
} // namespace mdns