mDNS record refresh and expiration tracker
Change-Id: Ia0cbd58d97051f329d6732de6ade8a4de5699028
Reviewed-on: https://chromium-review.googlesource.com/c/openscreen/+/1779582
Commit-Queue: Max Yakimakha <yakimakha@chromium.org>
Reviewed-by: mark a. foltz <mfoltz@chromium.org>
diff --git a/cast/common/mdns/mdns_random.h b/cast/common/mdns/mdns_random.h
index 7061230..1d82f9c 100644
--- a/cast/common/mdns/mdns_random.h
+++ b/cast/common/mdns/mdns_random.h
@@ -43,8 +43,8 @@
static constexpr int64_t kMinimumInitialQueryDelayMs = 20;
static constexpr int64_t kMaximumInitialQueryDelayMs = 120;
- static constexpr double kMinimumTtlVariationPercent = 0.0;
- static constexpr double kMaximumTtlVariationPercent = 2.0;
+ static constexpr double kMinimumTtlVariation = 0.0;
+ static constexpr double kMaximumTtlVariation = 0.02;
static constexpr int64_t kMinimumSharedRecordResponseDelayMs = 20;
static constexpr int64_t kMaximumSharedRecordResponseDelayMs = 120;
@@ -56,7 +56,7 @@
std::uniform_int_distribution<int64_t> initial_query_delay_{
kMinimumInitialQueryDelayMs, kMaximumInitialQueryDelayMs};
std::uniform_real_distribution<double> record_ttl_variation_{
- kMinimumTtlVariationPercent, kMaximumTtlVariationPercent};
+ kMinimumTtlVariation, kMaximumTtlVariation};
std::uniform_int_distribution<int64_t> shared_record_response_delay_{
kMinimumSharedRecordResponseDelayMs, kMaximumSharedRecordResponseDelayMs};
std::uniform_int_distribution<int64_t> truncated_query_response_delay_{
diff --git a/cast/common/mdns/mdns_random_unittest.cc b/cast/common/mdns/mdns_random_unittest.cc
index edefd5a..e0a7fe5 100644
--- a/cast/common/mdns/mdns_random_unittest.cc
+++ b/cast/common/mdns/mdns_random_unittest.cc
@@ -15,44 +15,44 @@
}
TEST(MdnsRandomTest, InitialQueryDelay) {
- std::chrono::milliseconds lower_bound{20};
- std::chrono::milliseconds upper_bound{120};
+ constexpr std::chrono::milliseconds lower_bound{20};
+ constexpr std::chrono::milliseconds upper_bound{120};
MdnsRandom mdns_random;
for (int i = 0; i < kIterationCount; ++i) {
- Clock::duration delay = mdns_random.GetInitialQueryDelay();
+ const Clock::duration delay = mdns_random.GetInitialQueryDelay();
EXPECT_GE(delay, lower_bound);
EXPECT_LE(delay, upper_bound);
}
}
TEST(MdnsRandomTest, RecordTtlVariation) {
- double lower_bound = 0.0;
- double upper_bound = 2.0;
+ constexpr double lower_bound = 0.0;
+ constexpr double upper_bound = 0.02;
MdnsRandom mdns_random;
for (int i = 0; i < kIterationCount; ++i) {
- double variation = mdns_random.GetRecordTtlVariation();
+ const double variation = mdns_random.GetRecordTtlVariation();
EXPECT_GE(variation, lower_bound);
EXPECT_LE(variation, upper_bound);
}
}
TEST(MdnsRandomTest, SharedRecordResponseDelay) {
- std::chrono::milliseconds lower_bound{20};
- std::chrono::milliseconds upper_bound{120};
+ constexpr std::chrono::milliseconds lower_bound{20};
+ constexpr std::chrono::milliseconds upper_bound{120};
MdnsRandom mdns_random;
for (int i = 0; i < kIterationCount; ++i) {
- Clock::duration delay = mdns_random.GetSharedRecordResponseDelay();
+ const Clock::duration delay = mdns_random.GetSharedRecordResponseDelay();
EXPECT_GE(delay, lower_bound);
EXPECT_LE(delay, upper_bound);
}
}
TEST(MdnsRandomTest, TruncatedQueryResponseDelay) {
- std::chrono::milliseconds lower_bound{400};
- std::chrono::milliseconds upper_bound{500};
+ constexpr std::chrono::milliseconds lower_bound{400};
+ constexpr std::chrono::milliseconds upper_bound{500};
MdnsRandom mdns_random;
for (int i = 0; i < kIterationCount; ++i) {
- Clock::duration delay = mdns_random.GetTruncatedQueryResponseDelay();
+ const Clock::duration delay = mdns_random.GetTruncatedQueryResponseDelay();
EXPECT_GE(delay, lower_bound);
EXPECT_LE(delay, upper_bound);
}
diff --git a/cast/common/mdns/mdns_trackers.cc b/cast/common/mdns/mdns_trackers.cc
index 478bad3..b781103 100644
--- a/cast/common/mdns/mdns_trackers.cc
+++ b/cast/common/mdns/mdns_trackers.cc
@@ -4,9 +4,98 @@
#include "cast/common/mdns/mdns_trackers.h"
+#include <array>
+
+#include "util/std_util.h"
+
namespace cast {
namespace mdns {
+MdnsTracker::MdnsTracker(MdnsSender* sender,
+ TaskRunner* task_runner,
+ ClockNowFunctionPtr now_function,
+ MdnsRandom* random_delay)
+ : sender_(sender),
+ now_function_(now_function),
+ send_alarm_(now_function, task_runner),
+ random_delay_(random_delay) {
+ OSP_DCHECK(task_runner);
+ OSP_DCHECK(now_function);
+ OSP_DCHECK(random_delay);
+ OSP_DCHECK(sender);
+}
+
+namespace {
+// RFC 6762 Section 5.2
+// https://tools.ietf.org/html/rfc6762#section-5.2
+constexpr double kTtlFractions[] = {0.80, 0.85, 0.90, 0.95, 1.00};
+} // namespace
+
+MdnsRecordTracker::MdnsRecordTracker(MdnsSender* sender,
+ TaskRunner* task_runner,
+ ClockNowFunctionPtr now_function,
+ MdnsRandom* random_delay)
+ : MdnsTracker(sender, task_runner, now_function, random_delay) {}
+
+Error MdnsRecordTracker::Start(MdnsRecord record) {
+ if (record_.has_value()) {
+ return Error::Code::kOperationInvalid;
+ }
+
+ record_ = std::move(record);
+ start_time_ = now_function_();
+ send_count_ = 0;
+ send_alarm_.Schedule(std::bind(&MdnsRecordTracker::SendQuery, this),
+ GetNextSendTime());
+ return Error::None();
+}
+
+Error MdnsRecordTracker::Stop() {
+ if (!record_.has_value()) {
+ return Error::Code::kOperationInvalid;
+ }
+
+ send_alarm_.Cancel();
+ record_.reset();
+ return Error::None();
+}
+
+bool MdnsRecordTracker::IsStarted() {
+ return record_.has_value();
+};
+
+void MdnsRecordTracker::SendQuery() {
+ const MdnsRecord& record = record_.value();
+ const Clock::time_point expiration_time = start_time_ + record.ttl();
+ const bool is_expired = (now_function_() >= expiration_time);
+ if (!is_expired) {
+ MdnsQuestion question(record.name(), record.dns_type(),
+ record.record_class(), ResponseType::kMulticast);
+ MdnsMessage message(CreateMessageId(), MessageType::Query);
+ message.AddQuestion(std::move(question));
+ sender_->SendMulticast(message);
+ send_alarm_.Schedule(std::bind(&MdnsRecordTracker::SendQuery, this),
+ GetNextSendTime());
+ } else {
+ // TODO(yakimakha): Notify owner that the record has expired
+ }
+}
+
+Clock::time_point MdnsRecordTracker::GetNextSendTime() {
+ OSP_DCHECK(send_count_ < openscreen::countof(kTtlFractions));
+
+ double ttl_fraction = kTtlFractions[send_count_++];
+
+ // Do not add random variation to the expiration time (last fraction of TTL)
+ if (send_count_ != openscreen::countof(kTtlFractions)) {
+ ttl_fraction += random_delay_->GetRecordTtlVariation();
+ }
+
+ const Clock::duration delay = std::chrono::duration_cast<Clock::duration>(
+ record_.value().ttl() * ttl_fraction);
+ return start_time_ + delay;
+}
+
namespace {
// RFC 6762 Section 5.2
// https://tools.ietf.org/html/rfc6762#section-5.2
@@ -19,16 +108,7 @@
TaskRunner* task_runner,
ClockNowFunctionPtr now_function,
MdnsRandom* random_delay)
- : sender_(sender),
- now_function_(now_function),
- resend_alarm_(now_function, task_runner),
- random_delay_(random_delay),
- resend_delay_(kMinimumQueryInterval) {
- OSP_DCHECK(task_runner);
- OSP_DCHECK(now_function);
- OSP_DCHECK(random_delay);
- OSP_DCHECK(sender);
-}
+ : MdnsTracker(sender, task_runner, now_function, random_delay) {}
Error MdnsQuestionTracker::Start(MdnsQuestion question) {
if (question_.has_value()) {
@@ -36,12 +116,12 @@
}
question_ = std::move(question);
- resend_delay_ = kMinimumQueryInterval;
+ send_delay_ = kMinimumQueryInterval;
// The initial query has to be sent after a random delay of 20-120
// milliseconds.
- Clock::duration delay = random_delay_->GetInitialQueryDelay();
- resend_alarm_.Schedule(std::bind(&MdnsQuestionTracker::SendQuestion, this),
- now_function_() + delay);
+ const Clock::duration delay = random_delay_->GetInitialQueryDelay();
+ send_alarm_.Schedule(std::bind(&MdnsQuestionTracker::SendQuery, this),
+ now_function_() + delay);
return Error::None();
}
@@ -50,7 +130,7 @@
return Error::Code::kOperationInvalid;
}
- resend_alarm_.Cancel();
+ send_alarm_.Cancel();
question_.reset();
return Error::None();
}
@@ -59,17 +139,17 @@
return question_.has_value();
};
-void MdnsQuestionTracker::SendQuestion() {
+void MdnsQuestionTracker::SendQuery() {
MdnsMessage message(CreateMessageId(), MessageType::Query);
message.AddQuestion(question_.value());
// TODO(yakimakha): Implement known-answer suppression by adding known
// answers to the question
sender_->SendMulticast(message);
- resend_alarm_.Schedule(std::bind(&MdnsQuestionTracker::SendQuestion, this),
- now_function_() + resend_delay_);
- resend_delay_ = resend_delay_ * kIntervalIncreaseFactor;
- if (resend_delay_ > kMaximumQueryInterval) {
- resend_delay_ = kMaximumQueryInterval;
+ send_alarm_.Schedule(std::bind(&MdnsQuestionTracker::SendQuery, this),
+ now_function_() + send_delay_);
+ send_delay_ = send_delay_ * kIntervalIncreaseFactor;
+ if (send_delay_ > kMaximumQueryInterval) {
+ send_delay_ = kMaximumQueryInterval;
}
}
diff --git a/cast/common/mdns/mdns_trackers.h b/cast/common/mdns/mdns_trackers.h
index 4c87921..7a69e8b 100644
--- a/cast/common/mdns/mdns_trackers.h
+++ b/cast/common/mdns/mdns_trackers.h
@@ -22,52 +22,101 @@
using openscreen::platform::ClockNowFunctionPtr;
using openscreen::platform::TaskRunner;
+// MdnsTracker is a base class for MdnsRecordTracker and MdnsQuestionTracker for
+// the purposes of common code sharing only
+class MdnsTracker {
+ public:
+ // MdnsTracker does not own |sender|, |task_runner| and |random_delay|
+ // and expects that the lifetime of these objects exceeds the lifetime of
+ // MdnsTracker.
+ MdnsTracker(MdnsSender* sender,
+ TaskRunner* task_runner,
+ ClockNowFunctionPtr now_function,
+ MdnsRandom* random_delay);
+
+ MdnsTracker(const MdnsTracker& other) = delete;
+ MdnsTracker(MdnsTracker&& other) noexcept = delete;
+ ~MdnsTracker() = default;
+
+ MdnsTracker& operator=(const MdnsTracker& other) = delete;
+ MdnsTracker& operator=(MdnsTracker&& other) noexcept = delete;
+
+ protected:
+ MdnsSender* const sender_;
+ const ClockNowFunctionPtr now_function_;
+ Alarm send_alarm_; // TODO(yakimakha): Use cancelable task when available
+ MdnsRandom* const random_delay_;
+};
+
+// MdnsRecordTracker manages automatic resending of mDNS queries for
+// refreshing records as they reach their expiration time.
+class MdnsRecordTracker : public MdnsTracker {
+ public:
+ MdnsRecordTracker(MdnsSender* sender,
+ TaskRunner* task_runner,
+ ClockNowFunctionPtr now_function,
+ MdnsRandom* random_delay);
+
+ // Starts sending query messages for the provided record using record's TTL
+ // and the time of the call to determine when to send the queries. Returns
+ // error with code Error::Code::kOperationInvalid if called on an instance of
+ // MdnsRecordTracker that has already been started.
+ Error Start(MdnsRecord record);
+
+ // Stops sending query for automatic record refresh. This cancels record
+ // expiration notification as well. Returns error
+ // with code Error::Code::kOperationInvalid if called on an instance of
+ // MdnsRecordTracker that has not yet been started or has already been
+ // stopped.
+ Error Stop();
+
+ // Returns true if MdnsRecordTracker instance has been started and is
+ // automatically refreshing the record, false otherwise.
+ bool IsStarted();
+
+ private:
+ void SendQuery();
+ Clock::time_point GetNextSendTime();
+
+ // Stores MdnsRecord provided to Start method call.
+ absl::optional<MdnsRecord> record_;
+ // A point in time when the record was received and the tracking has started.
+ Clock::time_point start_time_;
+ // Number of times a question to refresh the record has been sent.
+ size_t send_count_ = 0;
+};
+
// MdnsQuestionTracker manages automatic resending of mDNS queries for
// continuous monitoring with exponential back-off as described in RFC 6762
-class MdnsQuestionTracker {
+class MdnsQuestionTracker : public MdnsTracker {
public:
- // MdnsQuestionTracker does not own |sender|, |task_runner| and |random_delay|
- // and expects that the lifetime of these objects exceeds the lifetime of
- // MdnsQuestionTracker.
MdnsQuestionTracker(MdnsSender* sender,
TaskRunner* task_runner,
ClockNowFunctionPtr now_function,
MdnsRandom* random_delay);
- MdnsQuestionTracker(const MdnsQuestionTracker& other) = delete;
- MdnsQuestionTracker(MdnsQuestionTracker&& other) noexcept = delete;
- ~MdnsQuestionTracker() = default;
-
- MdnsQuestionTracker& operator=(const MdnsQuestionTracker& other) = delete;
- MdnsQuestionTracker& operator=(MdnsQuestionTracker&& other) noexcept = delete;
-
// Starts sending query messages for the provided question. Returns error with
// code Error::Code::kOperationInvalid if called on an instance of
// MdnsQuestionTracker that has already been started.
Error Start(MdnsQuestion question);
- // Stop sending query messages and resets the querying interval. Returns error
- // with code Error::Code::kOperationInvalid if called on an instance of
+ // Stops sending query messages and resets the querying interval. Returns
+ // error with code Error::Code::kOperationInvalid if called on an instance of
// MdnsQuestionTracker that has not yet been started or has already been
// stopped.
Error Stop();
- // Return true if MdnsQuestionTracker instance has been started and is
+ // Returns true if MdnsQuestionTracker instance has been started and is
// automatically sending queries, false otherwise.
bool IsStarted();
private:
// Sends a query message via MdnsSender and schedules the next resend.
- void SendQuestion();
-
- MdnsSender* const sender_;
- const ClockNowFunctionPtr now_function_;
- Alarm resend_alarm_; // TODO(yakimakha): Use cancelable task when available
- MdnsRandom* const random_delay_;
+ void SendQuery();
// Stores MdnsQuestion provided to Start method call.
absl::optional<MdnsQuestion> question_;
- Clock::duration resend_delay_;
+ Clock::duration send_delay_;
};
} // namespace mdns
diff --git a/cast/common/mdns/mdns_trackers_unittest.cc b/cast/common/mdns/mdns_trackers_unittest.cc
index e16c5c0..334ff2b 100644
--- a/cast/common/mdns/mdns_trackers_unittest.cc
+++ b/cast/common/mdns/mdns_trackers_unittest.cc
@@ -17,7 +17,19 @@
using openscreen::platform::NetworkInterfaceIndex;
using ::testing::_;
using ::testing::Args;
+using ::testing::Invoke;
using ::testing::Return;
+using ::testing::WithArgs;
+
+ACTION_P2(VerifyMessageBytesWithoutId, expected_data, expected_size) {
+ const uint8_t* actual_data = reinterpret_cast<const uint8_t*>(arg0);
+ const size_t actual_size = arg1;
+ EXPECT_EQ(actual_size, expected_size);
+ // Start at bytes[2] to skip a generated message ID.
+ for (size_t i = 2; i < actual_size; ++i) {
+ EXPECT_EQ(actual_data[i], expected_data[i]);
+ }
+}
class MockUdpSocket : public UdpSocket {
public:
@@ -50,13 +62,72 @@
socket_(&task_runner_),
sender_(&socket_),
a_question_(DomainName{"testing", "local"},
- DnsType::kA,
+ DnsType::kANY,
DnsClass::kIN,
- ResponseType::kMulticast) {}
+ ResponseType::kMulticast),
+ a_record_(DomainName{"testing", "local"},
+ DnsType::kA,
+ DnsClass::kIN,
+ RecordType::kShared,
+ std::chrono::seconds(120),
+ ARecordRdata(IPAddress{172, 0, 0, 1})) {}
+
+ template <class TrackerType, class TrackedType>
+ void TrackerStartStop(TrackedType tracked_data) {
+ TrackerType tracker(&sender_, &task_runner_, &FakeClock::now, &random_);
+ EXPECT_EQ(tracker.IsStarted(), false);
+ EXPECT_EQ(tracker.Stop(), Error(Error::Code::kOperationInvalid));
+ EXPECT_EQ(tracker.IsStarted(), false);
+ EXPECT_EQ(tracker.Start(tracked_data), Error(Error::Code::kNone));
+ EXPECT_EQ(tracker.IsStarted(), true);
+ EXPECT_EQ(tracker.Start(tracked_data),
+ Error(Error::Code::kOperationInvalid));
+ EXPECT_EQ(tracker.IsStarted(), true);
+ EXPECT_EQ(tracker.Stop(), Error(Error::Code::kNone));
+ EXPECT_EQ(tracker.IsStarted(), false);
+ }
+
+ template <class TrackerType, class TrackedType>
+ void TrackerNoQueryAfterStop(TrackedType tracked_data) {
+ TrackerType tracker(&sender_, &task_runner_, &FakeClock::now, &random_);
+ EXPECT_EQ(tracker.Start(tracked_data), Error(Error::Code::kNone));
+ EXPECT_EQ(tracker.Stop(), Error(Error::Code::kNone));
+ EXPECT_CALL(socket_, SendMessage(_, _, _)).Times(0);
+ // Advance fake clock by a long time interval to make sure if there's a
+ // scheduled task, it will run.
+ clock_.Advance(std::chrono::hours(1));
+ }
+
+ template <class TrackerType, class TrackedType>
+ void TrackerNoQueryAfterDestruction(TrackedType tracked_data) {
+ {
+ TrackerType tracker(&sender_, &task_runner_, &FakeClock::now, &random_);
+ tracker.Start(tracked_data);
+ }
+ EXPECT_CALL(socket_, SendMessage(_, _, _)).Times(0);
+ // Advance fake clock by a long time interval to make sure if there's a
+ // scheduled task, it will run.
+ clock_.Advance(std::chrono::hours(1));
+ }
protected:
// clang-format off
- const std::vector<uint8_t> kQueryBytes = {
+ const std::vector<uint8_t> kQuestionQueryBytes = {
+ 0x00, 0x00, // ID = 0
+ 0x00, 0x00, // FLAGS = None
+ 0x00, 0x01, // Question count
+ 0x00, 0x00, // Answer count
+ 0x00, 0x00, // Authority count
+ 0x00, 0x00, // Additional count
+ // Question
+ 0x07, 't', 'e', 's', 't', 'i', 'n', 'g',
+ 0x05, 'l', 'o', 'c', 'a', 'l',
+ 0x00,
+ 0x00, 0xFF, // TYPE = ANY (255)
+ 0x00, 0x01, // CLASS = IN (1)
+ };
+
+ const std::vector<uint8_t> kRecordQueryBytes = {
0x00, 0x00, // ID = 0
0x00, 0x00, // FLAGS = None
0x00, 0x01, // Question count
@@ -70,6 +141,7 @@
0x00, 0x01, // TYPE = A (1)
0x00, 0x01, // CLASS = IN (1)
};
+
// clang-format on
FakeClock clock_;
FakeTaskRunner task_runner_;
@@ -78,20 +150,69 @@
MdnsRandom random_;
MdnsQuestion a_question_;
+ MdnsRecord a_record_;
};
-TEST_F(MdnsTrackerTest, QueryTrackerStartStop) {
- MdnsQuestionTracker tracker(&sender_, &task_runner_, &FakeClock::now,
- &random_);
- EXPECT_EQ(tracker.IsStarted(), false);
- EXPECT_EQ(tracker.Stop(), Error(Error::Code::kOperationInvalid));
- EXPECT_EQ(tracker.IsStarted(), false);
- EXPECT_EQ(tracker.Start(a_question_), Error(Error::Code::kNone));
- EXPECT_EQ(tracker.IsStarted(), true);
- EXPECT_EQ(tracker.Start(a_question_), Error(Error::Code::kOperationInvalid));
- EXPECT_EQ(tracker.IsStarted(), true);
- EXPECT_EQ(tracker.Stop(), Error(Error::Code::kNone));
- EXPECT_EQ(tracker.IsStarted(), false);
+// Records are re-queried at 80%, 85%, 90% and 95% TTL as per RFC 6762
+// Section 5.2 There are no subsequent queries to refresh the record after that,
+// the record is expired after TTL has passed since the start of tracking.
+// Random variance required is from 0% to 2%, making these times at most 82%,
+// 87%, 92% and 97% TTL. Fake clock is advanced to 83%, 88%, 93% and 98% to make
+// sure that task gets executed.
+// https://tools.ietf.org/html/rfc6762#section-5.2
+
+TEST_F(MdnsTrackerTest, RecordTrackerStartStop) {
+ TrackerStartStop<MdnsRecordTracker>(a_record_);
+}
+
+TEST_F(MdnsTrackerTest, RecordTrackerQueryAfterDelay) {
+ MdnsRecordTracker tracker(&sender_, &task_runner_, &FakeClock::now, &random_);
+ tracker.Start(a_record_);
+ // Only expect 4 queries being sent, when record reaches it's TTL it's
+ // considered expired and another query is not sent
+ constexpr double kTtlFractions[] = {0.83, 0.88, 0.93, 0.98, 1.00};
+ EXPECT_CALL(socket_, SendMessage(_, _, _)).Times(4);
+
+ Clock::duration time_passed{0};
+ for (double fraction : kTtlFractions) {
+ Clock::duration time_till_refresh =
+ std::chrono::duration_cast<Clock::duration>(a_record_.ttl() * fraction);
+ Clock::duration delta = time_till_refresh - time_passed;
+ time_passed = time_till_refresh;
+ clock_.Advance(delta);
+ }
+}
+
+TEST_F(MdnsTrackerTest, RecordTrackerSendsMessage) {
+ MdnsRecordTracker tracker(&sender_, &task_runner_, &FakeClock::now, &random_);
+ tracker.Start(a_record_);
+
+ EXPECT_CALL(socket_, SendMessage(_, _, _))
+ .WillOnce(WithArgs<0, 1>(VerifyMessageBytesWithoutId(
+ kRecordQueryBytes.data(), kRecordQueryBytes.size())));
+
+ clock_.Advance(
+ std::chrono::duration_cast<Clock::duration>(a_record_.ttl() * 0.83));
+}
+
+TEST_F(MdnsTrackerTest, RecordTrackerNoQueryAfterStop) {
+ TrackerNoQueryAfterStop<MdnsRecordTracker>(a_record_);
+}
+
+TEST_F(MdnsTrackerTest, RecordTrackerNoQueryAfterDestruction) {
+ TrackerNoQueryAfterDestruction<MdnsRecordTracker>(a_record_);
+}
+
+TEST_F(MdnsTrackerTest, RecordTrackerNoQueryAfterLateTask) {
+ MdnsRecordTracker tracker(&sender_, &task_runner_, &FakeClock::now, &random_);
+ tracker.Start(a_record_);
+ // If task runner was too busy and callback happened too late, there should be
+ // no query and instead the record will expire.
+ // Check lower bound for task being late (TTL) and an arbitrarily long time
+ // interval to ensure the query is not sent a later time.
+ EXPECT_CALL(socket_, SendMessage(_, _, _)).Times(0);
+ clock_.Advance(a_record_.ttl());
+ clock_.Advance(std::chrono::hours(1));
}
// Initial query is delayed for up to 120 ms as per RFC 6762 Section 5.2
@@ -100,7 +221,11 @@
// next query up until it's capped at 1 hour.
// https://tools.ietf.org/html/rfc6762#section-5.2
-TEST_F(MdnsTrackerTest, QueryTrackerQueryAfterDelay) {
+TEST_F(MdnsTrackerTest, QuestionTrackerStartStop) {
+ TrackerStartStop<MdnsQuestionTracker>(a_question_);
+}
+
+TEST_F(MdnsTrackerTest, QuestionTrackerQueryAfterDelay) {
MdnsQuestionTracker tracker(&sender_, &task_runner_, &FakeClock::now,
&random_);
tracker.Start(a_question_);
@@ -116,46 +241,24 @@
}
}
-TEST_F(MdnsTrackerTest, QueryTrackerSendsMessage) {
+TEST_F(MdnsTrackerTest, QuestionTrackerSendsMessage) {
MdnsQuestionTracker tracker(&sender_, &task_runner_, &FakeClock::now,
&random_);
tracker.Start(a_question_);
- EXPECT_CALL(socket_, SendMessage(_, kQueryBytes.size(), _))
- .WillOnce(testing::WithArgs<0, 1>(
- testing::Invoke([this](const void* data, size_t size) {
- EXPECT_EQ(size, kQueryBytes.size());
- const uint8_t* bytes = static_cast<const uint8_t*>(data);
- // Start at bytes[2] to skip a generated message ID.
- for (size_t i = 2; i < size; ++i) {
- EXPECT_EQ(bytes[i], kQueryBytes[i]);
- }
- })));
+ EXPECT_CALL(socket_, SendMessage(_, _, _))
+ .WillOnce(WithArgs<0, 1>(VerifyMessageBytesWithoutId(
+ kQuestionQueryBytes.data(), kQuestionQueryBytes.size())));
+
clock_.Advance(std::chrono::milliseconds(120));
}
-TEST_F(MdnsTrackerTest, QueryTrackerNoQueryAfterStop) {
- MdnsQuestionTracker tracker(&sender_, &task_runner_, &FakeClock::now,
- &random_);
-
- EXPECT_EQ(tracker.Start(a_question_), Error(Error::Code::kNone));
- EXPECT_EQ(tracker.Stop(), Error(Error::Code::kNone));
- EXPECT_CALL(socket_, SendMessage(_, _, _)).Times(0);
- // Advance fake clock by a long time interval to make sure if there's a
- // scheduled task, it will run.
- clock_.Advance(std::chrono::hours(1));
+TEST_F(MdnsTrackerTest, QuestionTrackerNoQueryAfterStop) {
+ TrackerNoQueryAfterStop<MdnsQuestionTracker>(a_question_);
}
-TEST_F(MdnsTrackerTest, QueryTrackerNoQueryAfterDestruction) {
- {
- MdnsQuestionTracker tracker(&sender_, &task_runner_, &FakeClock::now,
- &random_);
- tracker.Start(a_question_);
- }
- EXPECT_CALL(socket_, SendMessage(_, _, _)).Times(0);
- // Advance fake clock by a long time interval to make sure if there's a
- // scheduled task, it will run.
- clock_.Advance(std::chrono::hours(1));
+TEST_F(MdnsTrackerTest, QuestionTrackerNoQueryAfterDestruction) {
+ TrackerNoQueryAfterDestruction<MdnsQuestionTracker>(a_question_);
}
} // namespace mdns