Update and expiration callbacks for MdnsRecordTracker

Change-Id: I21fe7ded3f52cd1749d8a154e4dfc472efdf25c3
Reviewed-on: https://chromium-review.googlesource.com/c/openscreen/+/1799365
Reviewed-by: Ryan Keane <rwkeane@google.com>
Commit-Queue: Max Yakimakha <yakimakha@chromium.org>
diff --git a/cast/common/mdns/mdns_records.cc b/cast/common/mdns/mdns_records.cc
index 128a60e..45d9d25 100644
--- a/cast/common/mdns/mdns_records.cc
+++ b/cast/common/mdns/mdns_records.cc
@@ -172,6 +172,7 @@
       ttl_(ttl),
       rdata_(std::move(rdata)) {
   OSP_DCHECK(!name_.empty());
+  OSP_DCHECK_LE(ttl_.count(), std::numeric_limits<uint32_t>::max());
   OSP_DCHECK((dns_type == DnsType::kSRV &&
               absl::holds_alternative<SrvRecordRdata>(rdata_)) ||
              (dns_type == DnsType::kA &&
diff --git a/cast/common/mdns/mdns_trackers.cc b/cast/common/mdns/mdns_trackers.cc
index 200b5b3..18e0658 100644
--- a/cast/common/mdns/mdns_trackers.cc
+++ b/cast/common/mdns/mdns_trackers.cc
@@ -11,6 +11,37 @@
 namespace cast {
 namespace mdns {
 
+namespace {
+
+// RFC 6762 Section 5.2
+// https://tools.ietf.org/html/rfc6762#section-5.2
+
+// Attempt to refresh a record should be performed at 80%, 85%, 90% and 95% TTL.
+constexpr double kTtlFractions[] = {0.80, 0.85, 0.90, 0.95, 1.00};
+
+// Intervals between successive queries must increase by at least a factor of 2.
+constexpr int kIntervalIncreaseFactor = 2;
+
+// The interval between the first two queries must be at least one second.
+constexpr std::chrono::seconds kMinimumQueryInterval{1};
+
+// The querier may cap the question refresh interval to a maximum of 60 minutes.
+constexpr std::chrono::minutes kMaximumQueryInterval{60};
+
+// RFC 6762 Section 10.1
+// https://tools.ietf.org/html/rfc6762#section-10.1
+
+// A goodbye record is a record with TTL of 0.
+bool IsGoodbyeRecord(const MdnsRecord& record) {
+  constexpr std::chrono::seconds zero_ttl{0};
+  return record.ttl() == zero_ttl;
+}
+
+// The interval between the first two queries must be at least one second.
+constexpr std::chrono::seconds kGoodbyeRecordTtl{1};
+
+}  // namespace
+
 MdnsTracker::MdnsTracker(MdnsSender* sender,
                          TaskRunner* task_runner,
                          ClockNowFunctionPtr now_function,
@@ -25,17 +56,19 @@
   OSP_DCHECK(sender);
 }
 
-namespace {
-// RFC 6762 Section 5.2
-// https://tools.ietf.org/html/rfc6762#section-5.2
-constexpr double kTtlFractions[] = {0.80, 0.85, 0.90, 0.95, 1.00};
-}  // namespace
-
-MdnsRecordTracker::MdnsRecordTracker(MdnsSender* sender,
-                                     TaskRunner* task_runner,
-                                     ClockNowFunctionPtr now_function,
-                                     MdnsRandom* random_delay)
-    : MdnsTracker(sender, task_runner, now_function, random_delay) {}
+MdnsRecordTracker::MdnsRecordTracker(
+    MdnsSender* sender,
+    TaskRunner* task_runner,
+    ClockNowFunctionPtr now_function,
+    MdnsRandom* random_delay,
+    std::function<void(const MdnsRecord&)> record_updated_callback,
+    std::function<void(const MdnsRecord&)> record_expired_callback)
+    : MdnsTracker(sender, task_runner, now_function, random_delay),
+      record_updated_callback_(record_updated_callback),
+      record_expired_callback_(record_expired_callback) {
+  OSP_DCHECK(record_updated_callback);
+  OSP_DCHECK(record_expired_callback);
+}
 
 Error MdnsRecordTracker::Start(MdnsRecord record) {
   if (record_.has_value()) {
@@ -60,6 +93,49 @@
   return Error::None();
 }
 
+Error MdnsRecordTracker::Update(const MdnsRecord& new_record) {
+  if (!record_.has_value()) {
+    return Error::Code::kOperationInvalid;
+  }
+
+  MdnsRecord& old_record = record_.value();
+  if ((old_record.dns_type() != new_record.dns_type()) ||
+      (old_record.dns_class() != new_record.dns_class()) ||
+      (old_record.name() != new_record.name())) {
+    // The new record has been passed to a wrong tracker
+    return Error::Code::kParameterInvalid;
+  }
+
+  // Check if RDATA has changed before a call to Stop clears the old record
+  bool is_updated = (new_record.rdata() != old_record.rdata());
+
+  Error error = Stop();
+  if (!error.ok()) {
+    return error;
+  }
+
+  if (IsGoodbyeRecord(new_record)) {
+    // RFC 6762 Section 10.1
+    // https://tools.ietf.org/html/rfc6762#section-10.1
+    // In case of a goodbye record, the querier should set TTL to 1 second
+    error = Start(MdnsRecord(new_record.name(), new_record.dns_type(),
+                             new_record.dns_class(), new_record.record_type(),
+                             kGoodbyeRecordTtl, new_record.rdata()));
+  } else {
+    error = Start(new_record);
+  }
+
+  if (!error.ok()) {
+    return error;
+  }
+
+  if (is_updated) {
+    record_updated_callback_(record_.value());
+  }
+
+  return Error::None();
+}
+
 bool MdnsRecordTracker::IsStarted() {
   return record_.has_value();
 };
@@ -77,7 +153,7 @@
     send_alarm_.Schedule(std::bind(&MdnsRecordTracker::SendQuery, this),
                          GetNextSendTime());
   } else {
-    // TODO(yakimakha): Notify owner that the record has expired
+    record_expired_callback_(record);
   }
 }
 
@@ -96,14 +172,6 @@
   return start_time_ + delay;
 }
 
-namespace {
-// RFC 6762 Section 5.2
-// https://tools.ietf.org/html/rfc6762#section-5.2
-constexpr int kIntervalIncreaseFactor = 2;
-constexpr std::chrono::seconds kMinimumQueryInterval{1};
-constexpr std::chrono::minutes kMaximumQueryInterval{60};
-}  // namespace
-
 MdnsQuestionTracker::MdnsQuestionTracker(MdnsSender* sender,
                                          TaskRunner* task_runner,
                                          ClockNowFunctionPtr now_function,
diff --git a/cast/common/mdns/mdns_trackers.h b/cast/common/mdns/mdns_trackers.h
index 7a69e8b..153facd 100644
--- a/cast/common/mdns/mdns_trackers.h
+++ b/cast/common/mdns/mdns_trackers.h
@@ -52,10 +52,13 @@
 // refreshing records as they reach their expiration time.
 class MdnsRecordTracker : public MdnsTracker {
  public:
-  MdnsRecordTracker(MdnsSender* sender,
-                    TaskRunner* task_runner,
-                    ClockNowFunctionPtr now_function,
-                    MdnsRandom* random_delay);
+  MdnsRecordTracker(
+      MdnsSender* sender,
+      TaskRunner* task_runner,
+      ClockNowFunctionPtr now_function,
+      MdnsRandom* random_delay,
+      std::function<void(const MdnsRecord&)> record_updated_callback,
+      std::function<void(const MdnsRecord&)> record_expired_callback);
 
   // Starts sending query messages for the provided record using record's TTL
   // and the time of the call to determine when to send the queries. Returns
@@ -70,6 +73,12 @@
   // stopped.
   Error Stop();
 
+  // Updates record tracker with the new record:
+  // 1. Calls update callback if RDATA has changed.
+  // 2. Resets TTL to the value specified in new_record.
+  // 3. Schedules expiration in case of a goodbye record.
+  Error Update(const MdnsRecord& new_record);
+
   // Returns true if MdnsRecordTracker instance has been started and is
   // automatically refreshing the record, false otherwise.
   bool IsStarted();
@@ -84,6 +93,8 @@
   Clock::time_point start_time_;
   // Number of times a question to refresh the record has been sent.
   size_t send_count_ = 0;
+  std::function<void(const MdnsRecord&)> record_updated_callback_;
+  std::function<void(const MdnsRecord&)> record_expired_callback_;
 };
 
 // MdnsQuestionTracker manages automatic resending of mDNS queries for
diff --git a/cast/common/mdns/mdns_trackers_unittest.cc b/cast/common/mdns/mdns_trackers_unittest.cc
index 334ff2b..95609ad 100644
--- a/cast/common/mdns/mdns_trackers_unittest.cc
+++ b/cast/common/mdns/mdns_trackers_unittest.cc
@@ -73,25 +73,25 @@
                   ARecordRdata(IPAddress{172, 0, 0, 1})) {}
 
   template <class TrackerType, class TrackedType>
-  void TrackerStartStop(TrackedType tracked_data) {
-    TrackerType tracker(&sender_, &task_runner_, &FakeClock::now, &random_);
-    EXPECT_EQ(tracker.IsStarted(), false);
-    EXPECT_EQ(tracker.Stop(), Error(Error::Code::kOperationInvalid));
-    EXPECT_EQ(tracker.IsStarted(), false);
-    EXPECT_EQ(tracker.Start(tracked_data), Error(Error::Code::kNone));
-    EXPECT_EQ(tracker.IsStarted(), true);
-    EXPECT_EQ(tracker.Start(tracked_data),
+  void TrackerStartStop(std::unique_ptr<TrackerType> tracker,
+                        TrackedType tracked_data) {
+    EXPECT_EQ(tracker->IsStarted(), false);
+    EXPECT_EQ(tracker->Stop(), Error(Error::Code::kOperationInvalid));
+    EXPECT_EQ(tracker->IsStarted(), false);
+    EXPECT_EQ(tracker->Start(tracked_data), Error(Error::Code::kNone));
+    EXPECT_EQ(tracker->IsStarted(), true);
+    EXPECT_EQ(tracker->Start(tracked_data),
               Error(Error::Code::kOperationInvalid));
-    EXPECT_EQ(tracker.IsStarted(), true);
-    EXPECT_EQ(tracker.Stop(), Error(Error::Code::kNone));
-    EXPECT_EQ(tracker.IsStarted(), false);
+    EXPECT_EQ(tracker->IsStarted(), true);
+    EXPECT_EQ(tracker->Stop(), Error(Error::Code::kNone));
+    EXPECT_EQ(tracker->IsStarted(), false);
   }
 
   template <class TrackerType, class TrackedType>
-  void TrackerNoQueryAfterStop(TrackedType tracked_data) {
-    TrackerType tracker(&sender_, &task_runner_, &FakeClock::now, &random_);
-    EXPECT_EQ(tracker.Start(tracked_data), Error(Error::Code::kNone));
-    EXPECT_EQ(tracker.Stop(), Error(Error::Code::kNone));
+  void TrackerNoQueryAfterStop(std::unique_ptr<TrackerType> tracker,
+                               TrackedType tracked_data) {
+    EXPECT_EQ(tracker->Start(tracked_data), Error(Error::Code::kNone));
+    EXPECT_EQ(tracker->Stop(), Error(Error::Code::kNone));
     EXPECT_CALL(socket_, SendMessage(_, _, _)).Times(0);
     // Advance fake clock by a long time interval to make sure if there's a
     // scheduled task, it will run.
@@ -99,17 +99,33 @@
   }
 
   template <class TrackerType, class TrackedType>
-  void TrackerNoQueryAfterDestruction(TrackedType tracked_data) {
-    {
-      TrackerType tracker(&sender_, &task_runner_, &FakeClock::now, &random_);
-      tracker.Start(tracked_data);
-    }
+  void TrackerNoQueryAfterDestruction(std::unique_ptr<TrackerType> tracker,
+                                      TrackedType tracked_data) {
+    tracker->Start(tracked_data);
+    tracker.reset();
     EXPECT_CALL(socket_, SendMessage(_, _, _)).Times(0);
     // Advance fake clock by a long time interval to make sure if there's a
     // scheduled task, it will run.
     clock_.Advance(std::chrono::hours(1));
   }
 
+  void UpdateCallback(const MdnsRecord&) { update_called_ = true; }
+  void ExpirationCallback(const MdnsRecord&) { expiration_called_ = true; }
+
+  std::unique_ptr<MdnsRecordTracker> CreateRecordTracker() {
+    return std::make_unique<MdnsRecordTracker>(
+        &sender_, &task_runner_, &FakeClock::now, &random_,
+        std::bind(&MdnsTrackerTest::UpdateCallback, this,
+                  std::placeholders::_1),
+        std::bind(&MdnsTrackerTest::ExpirationCallback, this,
+                  std::placeholders::_1));
+  }
+
+  std::unique_ptr<MdnsQuestionTracker> CreateQuestionTracker() {
+    return std::make_unique<MdnsQuestionTracker>(&sender_, &task_runner_,
+                                                 &FakeClock::now, &random_);
+  }
+
  protected:
   // clang-format off
   const std::vector<uint8_t> kQuestionQueryBytes = {
@@ -151,6 +167,9 @@
 
   MdnsQuestion a_question_;
   MdnsRecord a_record_;
+
+  bool update_called_ = false;
+  bool expiration_called_ = false;
 };
 
 // Records are re-queried at 80%, 85%, 90% and 95% TTL as per RFC 6762
@@ -162,12 +181,13 @@
 // https://tools.ietf.org/html/rfc6762#section-5.2
 
 TEST_F(MdnsTrackerTest, RecordTrackerStartStop) {
-  TrackerStartStop<MdnsRecordTracker>(a_record_);
+  std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker();
+  TrackerStartStop(std::move(tracker), a_record_);
 }
 
 TEST_F(MdnsTrackerTest, RecordTrackerQueryAfterDelay) {
-  MdnsRecordTracker tracker(&sender_, &task_runner_, &FakeClock::now, &random_);
-  tracker.Start(a_record_);
+  std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker();
+  tracker->Start(a_record_);
   // Only expect 4 queries being sent, when record reaches it's TTL it's
   // considered expired and another query is not sent
   constexpr double kTtlFractions[] = {0.83, 0.88, 0.93, 0.98, 1.00};
@@ -184,8 +204,8 @@
 }
 
 TEST_F(MdnsTrackerTest, RecordTrackerSendsMessage) {
-  MdnsRecordTracker tracker(&sender_, &task_runner_, &FakeClock::now, &random_);
-  tracker.Start(a_record_);
+  std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker();
+  tracker->Start(a_record_);
 
   EXPECT_CALL(socket_, SendMessage(_, _, _))
       .WillOnce(WithArgs<0, 1>(VerifyMessageBytesWithoutId(
@@ -196,16 +216,18 @@
 }
 
 TEST_F(MdnsTrackerTest, RecordTrackerNoQueryAfterStop) {
-  TrackerNoQueryAfterStop<MdnsRecordTracker>(a_record_);
+  std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker();
+  TrackerNoQueryAfterStop(std::move(tracker), a_record_);
 }
 
 TEST_F(MdnsTrackerTest, RecordTrackerNoQueryAfterDestruction) {
-  TrackerNoQueryAfterDestruction<MdnsRecordTracker>(a_record_);
+  std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker();
+  TrackerNoQueryAfterDestruction(std::move(tracker), a_record_);
 }
 
 TEST_F(MdnsTrackerTest, RecordTrackerNoQueryAfterLateTask) {
-  MdnsRecordTracker tracker(&sender_, &task_runner_, &FakeClock::now, &random_);
-  tracker.Start(a_record_);
+  std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker();
+  tracker->Start(a_record_);
   // If task runner was too busy and callback happened too late, there should be
   // no query and instead the record will expire.
   // Check lower bound for task being late (TTL) and an arbitrarily long time
@@ -215,6 +237,113 @@
   clock_.Advance(std::chrono::hours(1));
 }
 
+TEST_F(MdnsTrackerTest, RecordTrackerUpdateFailsWhenNotStarted) {
+  std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker();
+  EXPECT_EQ(tracker->Update(a_record_), Error(Error::Code::kOperationInvalid));
+}
+
+TEST_F(MdnsTrackerTest, RecordTrackerUpdateFailsForMismatchedRecord) {
+  std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker();
+  tracker->Start(a_record_);
+  MdnsRecord updated_record = MdnsRecord(
+      DomainName{"alpha"}, a_record_.dns_type(), a_record_.dns_class(),
+      a_record_.record_type(), a_record_.ttl(), a_record_.rdata());
+  EXPECT_EQ(tracker->Update(updated_record),
+            Error(Error::Code::kParameterInvalid));
+
+  updated_record =
+      MdnsRecord(a_record_.name(), DnsType::kPTR, a_record_.dns_class(),
+                 a_record_.record_type(), a_record_.ttl(),
+                 PtrRecordRdata(DomainName{"bravo"}));
+  EXPECT_EQ(tracker->Update(updated_record),
+            Error(Error::Code::kParameterInvalid));
+
+  updated_record = MdnsRecord(a_record_.name(), a_record_.dns_type(),
+                              static_cast<DnsClass>(2), a_record_.record_type(),
+                              a_record_.ttl(), a_record_.rdata());
+  EXPECT_EQ(tracker->Update(updated_record),
+            Error(Error::Code::kParameterInvalid));
+}
+
+TEST_F(MdnsTrackerTest, RecordTrackerCallbackOnRdataUpdate) {
+  MdnsRecord updated_record(a_record_.name(), a_record_.dns_type(),
+                            a_record_.dns_class(), a_record_.record_type(),
+                            a_record_.ttl(),
+                            ARecordRdata(IPAddress{192, 168, 0, 1}));
+
+  update_called_ = false;
+  std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker();
+  tracker->Start(a_record_);
+  EXPECT_EQ(tracker->Update(updated_record), Error::None());
+  EXPECT_TRUE(update_called_);
+}
+
+TEST_F(MdnsTrackerTest, RecordTrackerNoCallbackOnTtlUpdate) {
+  update_called_ = false;
+  std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker();
+  tracker->Start(a_record_);
+  EXPECT_EQ(tracker->Update(a_record_), Error::None());
+  EXPECT_FALSE(update_called_);
+}
+
+TEST_F(MdnsTrackerTest, RecordTrackerUpdateResetsTtl) {
+  expiration_called_ = false;
+  std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker();
+  tracker->Start(a_record_);
+  // Advance time by 60% of record's TTL
+  Clock::duration advance_time =
+      std::chrono::duration_cast<Clock::duration>(a_record_.ttl() * 0.6);
+  clock_.Advance(advance_time);
+  // Now update the record, this must reset expiration time
+  EXPECT_EQ(tracker->Update(a_record_), Error::None());
+  // Advance time by 60% of record's TTL again
+  clock_.Advance(advance_time);
+  // Check that expiration callback was not called
+  EXPECT_FALSE(expiration_called_);
+}
+
+TEST_F(MdnsTrackerTest, RecordTrackerExpirationCallback) {
+  expiration_called_ = false;
+  std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker();
+  tracker->Start(a_record_);
+  clock_.Advance(a_record_.ttl());
+  EXPECT_TRUE(expiration_called_);
+}
+
+TEST_F(MdnsTrackerTest, RecordTrackerExpirationCallbackAfterGoodbye) {
+  update_called_ = false;
+  expiration_called_ = false;
+  std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker();
+  tracker->Start(a_record_);
+  MdnsRecord goodbye_record(a_record_.name(), a_record_.dns_type(),
+                            a_record_.dns_class(), a_record_.record_type(),
+                            std::chrono::seconds{0}, a_record_.rdata());
+
+  EXPECT_EQ(tracker->Update(goodbye_record), Error::None());
+  // After a goodbye record is received, expiration is schedule in a second.
+  clock_.Advance(std::chrono::seconds{1});
+  EXPECT_FALSE(update_called_);
+  EXPECT_TRUE(expiration_called_);
+}
+
+TEST_F(MdnsTrackerTest, RecordTrackerNoExpirationCallbackAfterStop) {
+  expiration_called_ = false;
+  std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker();
+  tracker->Start(a_record_);
+  tracker->Stop();
+  clock_.Advance(a_record_.ttl());
+  EXPECT_FALSE(expiration_called_);
+}
+
+TEST_F(MdnsTrackerTest, RecordTrackerNoExpirationCallbackAfterDestruction) {
+  expiration_called_ = false;
+  std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker();
+  tracker->Start(a_record_);
+  tracker.reset();
+  clock_.Advance(a_record_.ttl());
+  EXPECT_FALSE(expiration_called_);
+}
+
 // Initial query is delayed for up to 120 ms as per RFC 6762 Section 5.2
 // Subsequent queries happen no sooner than a second after the initial query and
 // the interval between the queries increases at least by a factor of 2 for each
@@ -222,13 +351,13 @@
 // https://tools.ietf.org/html/rfc6762#section-5.2
 
 TEST_F(MdnsTrackerTest, QuestionTrackerStartStop) {
-  TrackerStartStop<MdnsQuestionTracker>(a_question_);
+  std::unique_ptr<MdnsQuestionTracker> tracker = CreateQuestionTracker();
+  TrackerStartStop(std::move(tracker), a_question_);
 }
 
 TEST_F(MdnsTrackerTest, QuestionTrackerQueryAfterDelay) {
-  MdnsQuestionTracker tracker(&sender_, &task_runner_, &FakeClock::now,
-                              &random_);
-  tracker.Start(a_question_);
+  std::unique_ptr<MdnsQuestionTracker> tracker = CreateQuestionTracker();
+  tracker->Start(a_question_);
 
   EXPECT_CALL(socket_, SendMessage(_, _, _)).Times(1);
   clock_.Advance(std::chrono::milliseconds(120));
@@ -242,9 +371,8 @@
 }
 
 TEST_F(MdnsTrackerTest, QuestionTrackerSendsMessage) {
-  MdnsQuestionTracker tracker(&sender_, &task_runner_, &FakeClock::now,
-                              &random_);
-  tracker.Start(a_question_);
+  std::unique_ptr<MdnsQuestionTracker> tracker = CreateQuestionTracker();
+  tracker->Start(a_question_);
 
   EXPECT_CALL(socket_, SendMessage(_, _, _))
       .WillOnce(WithArgs<0, 1>(VerifyMessageBytesWithoutId(
@@ -254,11 +382,13 @@
 }
 
 TEST_F(MdnsTrackerTest, QuestionTrackerNoQueryAfterStop) {
-  TrackerNoQueryAfterStop<MdnsQuestionTracker>(a_question_);
+  std::unique_ptr<MdnsQuestionTracker> tracker = CreateQuestionTracker();
+  TrackerNoQueryAfterStop(std::move(tracker), a_question_);
 }
 
 TEST_F(MdnsTrackerTest, QuestionTrackerNoQueryAfterDestruction) {
-  TrackerNoQueryAfterDestruction<MdnsQuestionTracker>(a_question_);
+  std::unique_ptr<MdnsQuestionTracker> tracker = CreateQuestionTracker();
+  TrackerNoQueryAfterDestruction(std::move(tracker), a_question_);
 }
 
 }  // namespace mdns