mDNS record refresh and expiration tracker

Change-Id: Ia0cbd58d97051f329d6732de6ade8a4de5699028
Reviewed-on: https://chromium-review.googlesource.com/c/openscreen/+/1779582
Commit-Queue: Max Yakimakha <yakimakha@chromium.org>
Reviewed-by: mark a. foltz <mfoltz@chromium.org>
diff --git a/cast/common/mdns/mdns_random.h b/cast/common/mdns/mdns_random.h
index 7061230..1d82f9c 100644
--- a/cast/common/mdns/mdns_random.h
+++ b/cast/common/mdns/mdns_random.h
@@ -43,8 +43,8 @@
   static constexpr int64_t kMinimumInitialQueryDelayMs = 20;
   static constexpr int64_t kMaximumInitialQueryDelayMs = 120;
 
-  static constexpr double kMinimumTtlVariationPercent = 0.0;
-  static constexpr double kMaximumTtlVariationPercent = 2.0;
+  static constexpr double kMinimumTtlVariation = 0.0;
+  static constexpr double kMaximumTtlVariation = 0.02;
 
   static constexpr int64_t kMinimumSharedRecordResponseDelayMs = 20;
   static constexpr int64_t kMaximumSharedRecordResponseDelayMs = 120;
@@ -56,7 +56,7 @@
   std::uniform_int_distribution<int64_t> initial_query_delay_{
       kMinimumInitialQueryDelayMs, kMaximumInitialQueryDelayMs};
   std::uniform_real_distribution<double> record_ttl_variation_{
-      kMinimumTtlVariationPercent, kMaximumTtlVariationPercent};
+      kMinimumTtlVariation, kMaximumTtlVariation};
   std::uniform_int_distribution<int64_t> shared_record_response_delay_{
       kMinimumSharedRecordResponseDelayMs, kMaximumSharedRecordResponseDelayMs};
   std::uniform_int_distribution<int64_t> truncated_query_response_delay_{
diff --git a/cast/common/mdns/mdns_random_unittest.cc b/cast/common/mdns/mdns_random_unittest.cc
index edefd5a..e0a7fe5 100644
--- a/cast/common/mdns/mdns_random_unittest.cc
+++ b/cast/common/mdns/mdns_random_unittest.cc
@@ -15,44 +15,44 @@
 }
 
 TEST(MdnsRandomTest, InitialQueryDelay) {
-  std::chrono::milliseconds lower_bound{20};
-  std::chrono::milliseconds upper_bound{120};
+  constexpr std::chrono::milliseconds lower_bound{20};
+  constexpr std::chrono::milliseconds upper_bound{120};
   MdnsRandom mdns_random;
   for (int i = 0; i < kIterationCount; ++i) {
-    Clock::duration delay = mdns_random.GetInitialQueryDelay();
+    const Clock::duration delay = mdns_random.GetInitialQueryDelay();
     EXPECT_GE(delay, lower_bound);
     EXPECT_LE(delay, upper_bound);
   }
 }
 
 TEST(MdnsRandomTest, RecordTtlVariation) {
-  double lower_bound = 0.0;
-  double upper_bound = 2.0;
+  constexpr double lower_bound = 0.0;
+  constexpr double upper_bound = 0.02;
   MdnsRandom mdns_random;
   for (int i = 0; i < kIterationCount; ++i) {
-    double variation = mdns_random.GetRecordTtlVariation();
+    const double variation = mdns_random.GetRecordTtlVariation();
     EXPECT_GE(variation, lower_bound);
     EXPECT_LE(variation, upper_bound);
   }
 }
 
 TEST(MdnsRandomTest, SharedRecordResponseDelay) {
-  std::chrono::milliseconds lower_bound{20};
-  std::chrono::milliseconds upper_bound{120};
+  constexpr std::chrono::milliseconds lower_bound{20};
+  constexpr std::chrono::milliseconds upper_bound{120};
   MdnsRandom mdns_random;
   for (int i = 0; i < kIterationCount; ++i) {
-    Clock::duration delay = mdns_random.GetSharedRecordResponseDelay();
+    const Clock::duration delay = mdns_random.GetSharedRecordResponseDelay();
     EXPECT_GE(delay, lower_bound);
     EXPECT_LE(delay, upper_bound);
   }
 }
 
 TEST(MdnsRandomTest, TruncatedQueryResponseDelay) {
-  std::chrono::milliseconds lower_bound{400};
-  std::chrono::milliseconds upper_bound{500};
+  constexpr std::chrono::milliseconds lower_bound{400};
+  constexpr std::chrono::milliseconds upper_bound{500};
   MdnsRandom mdns_random;
   for (int i = 0; i < kIterationCount; ++i) {
-    Clock::duration delay = mdns_random.GetTruncatedQueryResponseDelay();
+    const Clock::duration delay = mdns_random.GetTruncatedQueryResponseDelay();
     EXPECT_GE(delay, lower_bound);
     EXPECT_LE(delay, upper_bound);
   }
diff --git a/cast/common/mdns/mdns_trackers.cc b/cast/common/mdns/mdns_trackers.cc
index 478bad3..b781103 100644
--- a/cast/common/mdns/mdns_trackers.cc
+++ b/cast/common/mdns/mdns_trackers.cc
@@ -4,9 +4,98 @@
 
 #include "cast/common/mdns/mdns_trackers.h"
 
+#include <array>
+
+#include "util/std_util.h"
+
 namespace cast {
 namespace mdns {
 
+MdnsTracker::MdnsTracker(MdnsSender* sender,
+                         TaskRunner* task_runner,
+                         ClockNowFunctionPtr now_function,
+                         MdnsRandom* random_delay)
+    : sender_(sender),
+      now_function_(now_function),
+      send_alarm_(now_function, task_runner),
+      random_delay_(random_delay) {
+  OSP_DCHECK(task_runner);
+  OSP_DCHECK(now_function);
+  OSP_DCHECK(random_delay);
+  OSP_DCHECK(sender);
+}
+
+namespace {
+// RFC 6762 Section 5.2
+// https://tools.ietf.org/html/rfc6762#section-5.2
+constexpr double kTtlFractions[] = {0.80, 0.85, 0.90, 0.95, 1.00};
+}  // namespace
+
+MdnsRecordTracker::MdnsRecordTracker(MdnsSender* sender,
+                                     TaskRunner* task_runner,
+                                     ClockNowFunctionPtr now_function,
+                                     MdnsRandom* random_delay)
+    : MdnsTracker(sender, task_runner, now_function, random_delay) {}
+
+Error MdnsRecordTracker::Start(MdnsRecord record) {
+  if (record_.has_value()) {
+    return Error::Code::kOperationInvalid;
+  }
+
+  record_ = std::move(record);
+  start_time_ = now_function_();
+  send_count_ = 0;
+  send_alarm_.Schedule(std::bind(&MdnsRecordTracker::SendQuery, this),
+                       GetNextSendTime());
+  return Error::None();
+}
+
+Error MdnsRecordTracker::Stop() {
+  if (!record_.has_value()) {
+    return Error::Code::kOperationInvalid;
+  }
+
+  send_alarm_.Cancel();
+  record_.reset();
+  return Error::None();
+}
+
+bool MdnsRecordTracker::IsStarted() {
+  return record_.has_value();
+};
+
+void MdnsRecordTracker::SendQuery() {
+  const MdnsRecord& record = record_.value();
+  const Clock::time_point expiration_time = start_time_ + record.ttl();
+  const bool is_expired = (now_function_() >= expiration_time);
+  if (!is_expired) {
+    MdnsQuestion question(record.name(), record.dns_type(),
+                          record.record_class(), ResponseType::kMulticast);
+    MdnsMessage message(CreateMessageId(), MessageType::Query);
+    message.AddQuestion(std::move(question));
+    sender_->SendMulticast(message);
+    send_alarm_.Schedule(std::bind(&MdnsRecordTracker::SendQuery, this),
+                         GetNextSendTime());
+  } else {
+    // TODO(yakimakha): Notify owner that the record has expired
+  }
+}
+
+Clock::time_point MdnsRecordTracker::GetNextSendTime() {
+  OSP_DCHECK(send_count_ < openscreen::countof(kTtlFractions));
+
+  double ttl_fraction = kTtlFractions[send_count_++];
+
+  // Do not add random variation to the expiration time (last fraction of TTL)
+  if (send_count_ != openscreen::countof(kTtlFractions)) {
+    ttl_fraction += random_delay_->GetRecordTtlVariation();
+  }
+
+  const Clock::duration delay = std::chrono::duration_cast<Clock::duration>(
+      record_.value().ttl() * ttl_fraction);
+  return start_time_ + delay;
+}
+
 namespace {
 // RFC 6762 Section 5.2
 // https://tools.ietf.org/html/rfc6762#section-5.2
@@ -19,16 +108,7 @@
                                          TaskRunner* task_runner,
                                          ClockNowFunctionPtr now_function,
                                          MdnsRandom* random_delay)
-    : sender_(sender),
-      now_function_(now_function),
-      resend_alarm_(now_function, task_runner),
-      random_delay_(random_delay),
-      resend_delay_(kMinimumQueryInterval) {
-  OSP_DCHECK(task_runner);
-  OSP_DCHECK(now_function);
-  OSP_DCHECK(random_delay);
-  OSP_DCHECK(sender);
-}
+    : MdnsTracker(sender, task_runner, now_function, random_delay) {}
 
 Error MdnsQuestionTracker::Start(MdnsQuestion question) {
   if (question_.has_value()) {
@@ -36,12 +116,12 @@
   }
 
   question_ = std::move(question);
-  resend_delay_ = kMinimumQueryInterval;
+  send_delay_ = kMinimumQueryInterval;
   // The initial query has to be sent after a random delay of 20-120
   // milliseconds.
-  Clock::duration delay = random_delay_->GetInitialQueryDelay();
-  resend_alarm_.Schedule(std::bind(&MdnsQuestionTracker::SendQuestion, this),
-                         now_function_() + delay);
+  const Clock::duration delay = random_delay_->GetInitialQueryDelay();
+  send_alarm_.Schedule(std::bind(&MdnsQuestionTracker::SendQuery, this),
+                       now_function_() + delay);
   return Error::None();
 }
 
@@ -50,7 +130,7 @@
     return Error::Code::kOperationInvalid;
   }
 
-  resend_alarm_.Cancel();
+  send_alarm_.Cancel();
   question_.reset();
   return Error::None();
 }
@@ -59,17 +139,17 @@
   return question_.has_value();
 };
 
-void MdnsQuestionTracker::SendQuestion() {
+void MdnsQuestionTracker::SendQuery() {
   MdnsMessage message(CreateMessageId(), MessageType::Query);
   message.AddQuestion(question_.value());
   // TODO(yakimakha): Implement known-answer suppression by adding known
   // answers to the question
   sender_->SendMulticast(message);
-  resend_alarm_.Schedule(std::bind(&MdnsQuestionTracker::SendQuestion, this),
-                         now_function_() + resend_delay_);
-  resend_delay_ = resend_delay_ * kIntervalIncreaseFactor;
-  if (resend_delay_ > kMaximumQueryInterval) {
-    resend_delay_ = kMaximumQueryInterval;
+  send_alarm_.Schedule(std::bind(&MdnsQuestionTracker::SendQuery, this),
+                       now_function_() + send_delay_);
+  send_delay_ = send_delay_ * kIntervalIncreaseFactor;
+  if (send_delay_ > kMaximumQueryInterval) {
+    send_delay_ = kMaximumQueryInterval;
   }
 }
 
diff --git a/cast/common/mdns/mdns_trackers.h b/cast/common/mdns/mdns_trackers.h
index 4c87921..7a69e8b 100644
--- a/cast/common/mdns/mdns_trackers.h
+++ b/cast/common/mdns/mdns_trackers.h
@@ -22,52 +22,101 @@
 using openscreen::platform::ClockNowFunctionPtr;
 using openscreen::platform::TaskRunner;
 
+// MdnsTracker is a base class for MdnsRecordTracker and MdnsQuestionTracker for
+// the purposes of common code sharing only
+class MdnsTracker {
+ public:
+  // MdnsTracker does not own |sender|, |task_runner| and |random_delay|
+  // and expects that the lifetime of these objects exceeds the lifetime of
+  // MdnsTracker.
+  MdnsTracker(MdnsSender* sender,
+              TaskRunner* task_runner,
+              ClockNowFunctionPtr now_function,
+              MdnsRandom* random_delay);
+
+  MdnsTracker(const MdnsTracker& other) = delete;
+  MdnsTracker(MdnsTracker&& other) noexcept = delete;
+  ~MdnsTracker() = default;
+
+  MdnsTracker& operator=(const MdnsTracker& other) = delete;
+  MdnsTracker& operator=(MdnsTracker&& other) noexcept = delete;
+
+ protected:
+  MdnsSender* const sender_;
+  const ClockNowFunctionPtr now_function_;
+  Alarm send_alarm_;  // TODO(yakimakha): Use cancelable task when available
+  MdnsRandom* const random_delay_;
+};
+
+// MdnsRecordTracker manages automatic resending of mDNS queries for
+// refreshing records as they reach their expiration time.
+class MdnsRecordTracker : public MdnsTracker {
+ public:
+  MdnsRecordTracker(MdnsSender* sender,
+                    TaskRunner* task_runner,
+                    ClockNowFunctionPtr now_function,
+                    MdnsRandom* random_delay);
+
+  // Starts sending query messages for the provided record using record's TTL
+  // and the time of the call to determine when to send the queries. Returns
+  // error with code Error::Code::kOperationInvalid if called on an instance of
+  // MdnsRecordTracker that has already been started.
+  Error Start(MdnsRecord record);
+
+  // Stops sending query for automatic record refresh. This cancels record
+  // expiration notification as well. Returns error
+  // with code Error::Code::kOperationInvalid if called on an instance of
+  // MdnsRecordTracker that has not yet been started or has already been
+  // stopped.
+  Error Stop();
+
+  // Returns true if MdnsRecordTracker instance has been started and is
+  // automatically refreshing the record, false otherwise.
+  bool IsStarted();
+
+ private:
+  void SendQuery();
+  Clock::time_point GetNextSendTime();
+
+  // Stores MdnsRecord provided to Start method call.
+  absl::optional<MdnsRecord> record_;
+  // A point in time when the record was received and the tracking has started.
+  Clock::time_point start_time_;
+  // Number of times a question to refresh the record has been sent.
+  size_t send_count_ = 0;
+};
+
 // MdnsQuestionTracker manages automatic resending of mDNS queries for
 // continuous monitoring with exponential back-off as described in RFC 6762
-class MdnsQuestionTracker {
+class MdnsQuestionTracker : public MdnsTracker {
  public:
-  // MdnsQuestionTracker does not own |sender|, |task_runner| and |random_delay|
-  // and expects that the lifetime of these objects exceeds the lifetime of
-  // MdnsQuestionTracker.
   MdnsQuestionTracker(MdnsSender* sender,
                       TaskRunner* task_runner,
                       ClockNowFunctionPtr now_function,
                       MdnsRandom* random_delay);
 
-  MdnsQuestionTracker(const MdnsQuestionTracker& other) = delete;
-  MdnsQuestionTracker(MdnsQuestionTracker&& other) noexcept = delete;
-  ~MdnsQuestionTracker() = default;
-
-  MdnsQuestionTracker& operator=(const MdnsQuestionTracker& other) = delete;
-  MdnsQuestionTracker& operator=(MdnsQuestionTracker&& other) noexcept = delete;
-
   // Starts sending query messages for the provided question. Returns error with
   // code Error::Code::kOperationInvalid if called on an instance of
   // MdnsQuestionTracker that has already been started.
   Error Start(MdnsQuestion question);
 
-  // Stop sending query messages and resets the querying interval. Returns error
-  // with code Error::Code::kOperationInvalid if called on an instance of
+  // Stops sending query messages and resets the querying interval. Returns
+  // error with code Error::Code::kOperationInvalid if called on an instance of
   // MdnsQuestionTracker that has not yet been started or has already been
   // stopped.
   Error Stop();
 
-  // Return true if MdnsQuestionTracker instance has been started and is
+  // Returns true if MdnsQuestionTracker instance has been started and is
   // automatically sending queries, false otherwise.
   bool IsStarted();
 
  private:
   // Sends a query message via MdnsSender and schedules the next resend.
-  void SendQuestion();
-
-  MdnsSender* const sender_;
-  const ClockNowFunctionPtr now_function_;
-  Alarm resend_alarm_;  // TODO(yakimakha): Use cancelable task when available
-  MdnsRandom* const random_delay_;
+  void SendQuery();
 
   // Stores MdnsQuestion provided to Start method call.
   absl::optional<MdnsQuestion> question_;
-  Clock::duration resend_delay_;
+  Clock::duration send_delay_;
 };
 
 }  // namespace mdns
diff --git a/cast/common/mdns/mdns_trackers_unittest.cc b/cast/common/mdns/mdns_trackers_unittest.cc
index e16c5c0..334ff2b 100644
--- a/cast/common/mdns/mdns_trackers_unittest.cc
+++ b/cast/common/mdns/mdns_trackers_unittest.cc
@@ -17,7 +17,19 @@
 using openscreen::platform::NetworkInterfaceIndex;
 using ::testing::_;
 using ::testing::Args;
+using ::testing::Invoke;
 using ::testing::Return;
+using ::testing::WithArgs;
+
+ACTION_P2(VerifyMessageBytesWithoutId, expected_data, expected_size) {
+  const uint8_t* actual_data = reinterpret_cast<const uint8_t*>(arg0);
+  const size_t actual_size = arg1;
+  EXPECT_EQ(actual_size, expected_size);
+  // Start at bytes[2] to skip a generated message ID.
+  for (size_t i = 2; i < actual_size; ++i) {
+    EXPECT_EQ(actual_data[i], expected_data[i]);
+  }
+}
 
 class MockUdpSocket : public UdpSocket {
  public:
@@ -50,13 +62,72 @@
         socket_(&task_runner_),
         sender_(&socket_),
         a_question_(DomainName{"testing", "local"},
-                    DnsType::kA,
+                    DnsType::kANY,
                     DnsClass::kIN,
-                    ResponseType::kMulticast) {}
+                    ResponseType::kMulticast),
+        a_record_(DomainName{"testing", "local"},
+                  DnsType::kA,
+                  DnsClass::kIN,
+                  RecordType::kShared,
+                  std::chrono::seconds(120),
+                  ARecordRdata(IPAddress{172, 0, 0, 1})) {}
+
+  template <class TrackerType, class TrackedType>
+  void TrackerStartStop(TrackedType tracked_data) {
+    TrackerType tracker(&sender_, &task_runner_, &FakeClock::now, &random_);
+    EXPECT_EQ(tracker.IsStarted(), false);
+    EXPECT_EQ(tracker.Stop(), Error(Error::Code::kOperationInvalid));
+    EXPECT_EQ(tracker.IsStarted(), false);
+    EXPECT_EQ(tracker.Start(tracked_data), Error(Error::Code::kNone));
+    EXPECT_EQ(tracker.IsStarted(), true);
+    EXPECT_EQ(tracker.Start(tracked_data),
+              Error(Error::Code::kOperationInvalid));
+    EXPECT_EQ(tracker.IsStarted(), true);
+    EXPECT_EQ(tracker.Stop(), Error(Error::Code::kNone));
+    EXPECT_EQ(tracker.IsStarted(), false);
+  }
+
+  template <class TrackerType, class TrackedType>
+  void TrackerNoQueryAfterStop(TrackedType tracked_data) {
+    TrackerType tracker(&sender_, &task_runner_, &FakeClock::now, &random_);
+    EXPECT_EQ(tracker.Start(tracked_data), Error(Error::Code::kNone));
+    EXPECT_EQ(tracker.Stop(), Error(Error::Code::kNone));
+    EXPECT_CALL(socket_, SendMessage(_, _, _)).Times(0);
+    // Advance fake clock by a long time interval to make sure if there's a
+    // scheduled task, it will run.
+    clock_.Advance(std::chrono::hours(1));
+  }
+
+  template <class TrackerType, class TrackedType>
+  void TrackerNoQueryAfterDestruction(TrackedType tracked_data) {
+    {
+      TrackerType tracker(&sender_, &task_runner_, &FakeClock::now, &random_);
+      tracker.Start(tracked_data);
+    }
+    EXPECT_CALL(socket_, SendMessage(_, _, _)).Times(0);
+    // Advance fake clock by a long time interval to make sure if there's a
+    // scheduled task, it will run.
+    clock_.Advance(std::chrono::hours(1));
+  }
 
  protected:
   // clang-format off
-  const std::vector<uint8_t> kQueryBytes = {
+  const std::vector<uint8_t> kQuestionQueryBytes = {
+      0x00, 0x00,  // ID = 0
+      0x00, 0x00,  // FLAGS = None
+      0x00, 0x01,  // Question count
+      0x00, 0x00,  // Answer count
+      0x00, 0x00,  // Authority count
+      0x00, 0x00,  // Additional count
+      // Question
+      0x07, 't', 'e', 's', 't', 'i', 'n', 'g',
+      0x05, 'l', 'o', 'c', 'a', 'l',
+      0x00,
+      0x00, 0xFF,  // TYPE = ANY (255)
+      0x00, 0x01,  // CLASS = IN (1)
+  };
+
+  const std::vector<uint8_t> kRecordQueryBytes = {
       0x00, 0x00,  // ID = 0
       0x00, 0x00,  // FLAGS = None
       0x00, 0x01,  // Question count
@@ -70,6 +141,7 @@
       0x00, 0x01,  // TYPE = A (1)
       0x00, 0x01,  // CLASS = IN (1)
   };
+
   // clang-format on
   FakeClock clock_;
   FakeTaskRunner task_runner_;
@@ -78,20 +150,69 @@
   MdnsRandom random_;
 
   MdnsQuestion a_question_;
+  MdnsRecord a_record_;
 };
 
-TEST_F(MdnsTrackerTest, QueryTrackerStartStop) {
-  MdnsQuestionTracker tracker(&sender_, &task_runner_, &FakeClock::now,
-                              &random_);
-  EXPECT_EQ(tracker.IsStarted(), false);
-  EXPECT_EQ(tracker.Stop(), Error(Error::Code::kOperationInvalid));
-  EXPECT_EQ(tracker.IsStarted(), false);
-  EXPECT_EQ(tracker.Start(a_question_), Error(Error::Code::kNone));
-  EXPECT_EQ(tracker.IsStarted(), true);
-  EXPECT_EQ(tracker.Start(a_question_), Error(Error::Code::kOperationInvalid));
-  EXPECT_EQ(tracker.IsStarted(), true);
-  EXPECT_EQ(tracker.Stop(), Error(Error::Code::kNone));
-  EXPECT_EQ(tracker.IsStarted(), false);
+// Records are re-queried at 80%, 85%, 90% and 95% TTL as per RFC 6762
+// Section 5.2 There are no subsequent queries to refresh the record after that,
+// the record is expired after TTL has passed since the start of tracking.
+// Random variance required is from 0% to 2%, making these times at most 82%,
+// 87%, 92% and 97% TTL. Fake clock is advanced to 83%, 88%, 93% and 98% to make
+// sure that task gets executed.
+// https://tools.ietf.org/html/rfc6762#section-5.2
+
+TEST_F(MdnsTrackerTest, RecordTrackerStartStop) {
+  TrackerStartStop<MdnsRecordTracker>(a_record_);
+}
+
+TEST_F(MdnsTrackerTest, RecordTrackerQueryAfterDelay) {
+  MdnsRecordTracker tracker(&sender_, &task_runner_, &FakeClock::now, &random_);
+  tracker.Start(a_record_);
+  // Only expect 4 queries being sent, when record reaches it's TTL it's
+  // considered expired and another query is not sent
+  constexpr double kTtlFractions[] = {0.83, 0.88, 0.93, 0.98, 1.00};
+  EXPECT_CALL(socket_, SendMessage(_, _, _)).Times(4);
+
+  Clock::duration time_passed{0};
+  for (double fraction : kTtlFractions) {
+    Clock::duration time_till_refresh =
+        std::chrono::duration_cast<Clock::duration>(a_record_.ttl() * fraction);
+    Clock::duration delta = time_till_refresh - time_passed;
+    time_passed = time_till_refresh;
+    clock_.Advance(delta);
+  }
+}
+
+TEST_F(MdnsTrackerTest, RecordTrackerSendsMessage) {
+  MdnsRecordTracker tracker(&sender_, &task_runner_, &FakeClock::now, &random_);
+  tracker.Start(a_record_);
+
+  EXPECT_CALL(socket_, SendMessage(_, _, _))
+      .WillOnce(WithArgs<0, 1>(VerifyMessageBytesWithoutId(
+          kRecordQueryBytes.data(), kRecordQueryBytes.size())));
+
+  clock_.Advance(
+      std::chrono::duration_cast<Clock::duration>(a_record_.ttl() * 0.83));
+}
+
+TEST_F(MdnsTrackerTest, RecordTrackerNoQueryAfterStop) {
+  TrackerNoQueryAfterStop<MdnsRecordTracker>(a_record_);
+}
+
+TEST_F(MdnsTrackerTest, RecordTrackerNoQueryAfterDestruction) {
+  TrackerNoQueryAfterDestruction<MdnsRecordTracker>(a_record_);
+}
+
+TEST_F(MdnsTrackerTest, RecordTrackerNoQueryAfterLateTask) {
+  MdnsRecordTracker tracker(&sender_, &task_runner_, &FakeClock::now, &random_);
+  tracker.Start(a_record_);
+  // If task runner was too busy and callback happened too late, there should be
+  // no query and instead the record will expire.
+  // Check lower bound for task being late (TTL) and an arbitrarily long time
+  // interval to ensure the query is not sent a later time.
+  EXPECT_CALL(socket_, SendMessage(_, _, _)).Times(0);
+  clock_.Advance(a_record_.ttl());
+  clock_.Advance(std::chrono::hours(1));
 }
 
 // Initial query is delayed for up to 120 ms as per RFC 6762 Section 5.2
@@ -100,7 +221,11 @@
 // next query up until it's capped at 1 hour.
 // https://tools.ietf.org/html/rfc6762#section-5.2
 
-TEST_F(MdnsTrackerTest, QueryTrackerQueryAfterDelay) {
+TEST_F(MdnsTrackerTest, QuestionTrackerStartStop) {
+  TrackerStartStop<MdnsQuestionTracker>(a_question_);
+}
+
+TEST_F(MdnsTrackerTest, QuestionTrackerQueryAfterDelay) {
   MdnsQuestionTracker tracker(&sender_, &task_runner_, &FakeClock::now,
                               &random_);
   tracker.Start(a_question_);
@@ -116,46 +241,24 @@
   }
 }
 
-TEST_F(MdnsTrackerTest, QueryTrackerSendsMessage) {
+TEST_F(MdnsTrackerTest, QuestionTrackerSendsMessage) {
   MdnsQuestionTracker tracker(&sender_, &task_runner_, &FakeClock::now,
                               &random_);
   tracker.Start(a_question_);
 
-  EXPECT_CALL(socket_, SendMessage(_, kQueryBytes.size(), _))
-      .WillOnce(testing::WithArgs<0, 1>(
-          testing::Invoke([this](const void* data, size_t size) {
-            EXPECT_EQ(size, kQueryBytes.size());
-            const uint8_t* bytes = static_cast<const uint8_t*>(data);
-            // Start at bytes[2] to skip a generated message ID.
-            for (size_t i = 2; i < size; ++i) {
-              EXPECT_EQ(bytes[i], kQueryBytes[i]);
-            }
-          })));
+  EXPECT_CALL(socket_, SendMessage(_, _, _))
+      .WillOnce(WithArgs<0, 1>(VerifyMessageBytesWithoutId(
+          kQuestionQueryBytes.data(), kQuestionQueryBytes.size())));
+
   clock_.Advance(std::chrono::milliseconds(120));
 }
 
-TEST_F(MdnsTrackerTest, QueryTrackerNoQueryAfterStop) {
-  MdnsQuestionTracker tracker(&sender_, &task_runner_, &FakeClock::now,
-                              &random_);
-
-  EXPECT_EQ(tracker.Start(a_question_), Error(Error::Code::kNone));
-  EXPECT_EQ(tracker.Stop(), Error(Error::Code::kNone));
-  EXPECT_CALL(socket_, SendMessage(_, _, _)).Times(0);
-  // Advance fake clock by a long time interval to make sure if there's a
-  // scheduled task, it will run.
-  clock_.Advance(std::chrono::hours(1));
+TEST_F(MdnsTrackerTest, QuestionTrackerNoQueryAfterStop) {
+  TrackerNoQueryAfterStop<MdnsQuestionTracker>(a_question_);
 }
 
-TEST_F(MdnsTrackerTest, QueryTrackerNoQueryAfterDestruction) {
-  {
-    MdnsQuestionTracker tracker(&sender_, &task_runner_, &FakeClock::now,
-                                &random_);
-    tracker.Start(a_question_);
-  }
-  EXPECT_CALL(socket_, SendMessage(_, _, _)).Times(0);
-  // Advance fake clock by a long time interval to make sure if there's a
-  // scheduled task, it will run.
-  clock_.Advance(std::chrono::hours(1));
+TEST_F(MdnsTrackerTest, QuestionTrackerNoQueryAfterDestruction) {
+  TrackerNoQueryAfterDestruction<MdnsQuestionTracker>(a_question_);
 }
 
 }  // namespace mdns