blob: 25887ef2b6caeb8e3d257532bd318fe0fe073f16 [file] [log] [blame]
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "discovery/mdns/mdns_querier.h"
#include <algorithm>
#include <bitset>
#include <memory>
#include <unordered_set>
#include <utility>
#include <vector>
#include "discovery/common/config.h"
#include "discovery/common/reporting_client.h"
#include "discovery/mdns/mdns_random.h"
#include "discovery/mdns/mdns_receiver.h"
#include "discovery/mdns/mdns_sender.h"
#include "discovery/mdns/public/mdns_constants.h"
namespace openscreen {
namespace discovery {
namespace {
// Concrete record types used in place of DnsType::kANY when a received NSEC
// record lists kANY among its types. MdnsQuerier::ProcessRecord() iterates over
// this list so that the negative response is applied to each of these types
// individually.
const std::vector<DnsType> kTranslatedNsecAnyQueryTypes = {
DnsType::kA, DnsType::kPTR, DnsType::kTXT, DnsType::kAAAA, DnsType::kSRV};
bool IsNegativeResponseFor(const MdnsRecord& record, DnsType type) {
if (record.dns_type() != DnsType::kNSEC) {
return false;
}
const NsecRecordRdata& nsec = absl::get<NsecRecordRdata>(record.rdata());
// RFC 6762 section 6.1, the NSEC bit must NOT be set in the received NSEC
// record to indicate this is an mDNS NSEC record rather than a traditional
// DNS NSEC record.
if (std::find(nsec.types().begin(), nsec.types().end(), DnsType::kNSEC) !=
nsec.types().end()) {
return false;
}
return std::find_if(nsec.types().begin(), nsec.types().end(),
[type](DnsType stored_type) {
return stored_type == type ||
stored_type == DnsType::kANY;
}) != nsec.types().end();
}
// Hash functor for DnsType: the hash is simply the enumerator's numeric value.
struct HashDnsType {
  size_t operator()(DnsType type) const {
    const size_t numeric_value = static_cast<size_t>(type);
    return numeric_value;
  }
};
// Strict ordering used when sorting mDNS records. The resulting order has two
// properties:
// - Records sharing a name end up next to each other.
// - Within one name, NSEC records come before every non-NSEC record.
// Remaining ties are resolved by MdnsRecord's own operator<.
bool CompareRecordByNameAndType(const MdnsRecord& first,
                                const MdnsRecord& second) {
  const bool names_differ = first.name() != second.name();
  if (names_differ) {
    return first.name() < second.name();
  }

  const bool first_is_nsec = first.dns_type() == DnsType::kNSEC;
  const bool second_is_nsec = second.dns_type() == DnsType::kNSEC;
  if (first_is_nsec != second_is_nsec) {
    // Exactly one of the two is an NSEC record; it sorts first.
    return first_is_nsec;
  }

  return first < second;
}
// Small set of DnsType values backed by a fixed-size bitset.
//
// DnsType::kANY is stored in bit 0 and behaves as a wildcard: while it is
// present GetTypes() reports only kANY, and removing kANY clears the entire
// set. Every other type is stored at the bit index equal to its numeric value.
//
// Types whose numeric value does not fit in the bitset cannot be stored. They
// are ignored instead of being passed to std::bitset::test()/set()/reset(),
// which throw std::out_of_range for an out-of-range position. This matters
// because the types inserted here come from received NSEC records.
class DnsTypeBitSet {
 public:
  // Returns whether no types are currently stored in this data structure.
  bool IsEmpty() const { return !elements_.any(); }

  // Attempts to insert the given type into this data structure. Returns
  // true iff the type was not already present. Types that cannot be
  // represented are not inserted and yield false.
  bool Insert(DnsType type) {
    const size_t bit =
        (type == DnsType::kANY) ? 0 : static_cast<uint16_t>(type);
    if (bit >= elements_.size()) {
      return false;
    }
    const bool was_set = elements_.test(bit);
    elements_.set(bit);
    return !was_set;
  }

  // Iterates over all members of the provided container, inserting each
  // DnsType contained within to this instance. Returns true iff any element
  // inserted was not already present in this instance.
  template <typename Container>
  bool Insert(const Container& container) {
    bool has_element_been_inserted = false;
    for (DnsType type : container) {
      has_element_been_inserted |= Insert(type);
    }
    return has_element_been_inserted;
  }

  // Attempts to remove the given type from this data structure. Returns true
  // iff the type was present prior to this call. Removing DnsType::kANY from a
  // non-empty set clears every stored type.
  bool Remove(DnsType type) {
    if (IsEmpty()) {
      return false;
    }
    if (type == DnsType::kANY) {
      elements_.reset();
      return true;
    }

    const size_t bit = static_cast<uint16_t>(type);
    if (bit >= elements_.size()) {
      // Such a type can never have been inserted.
      return false;
    }
    const bool was_set = elements_.test(bit);
    elements_.reset(bit);
    return was_set;
  }

  // Returns the DnsTypes currently stored in this data structure. If kANY is
  // stored, only {kANY} is returned. Otherwise only members of
  // kSupportedDnsTypes are reported.
  std::vector<DnsType> GetTypes() const {
    if (elements_.test(0)) {
      return {DnsType::kANY};
    }

    std::vector<DnsType> types;
    for (DnsType type : kSupportedDnsTypes) {
      if (type == DnsType::kANY) {
        continue;
      }
      const size_t bit = static_cast<uint16_t>(type);
      if (bit < elements_.size() && elements_.test(bit)) {
        types.push_back(type);
      }
    }
    return types;
  }

 private:
  // Bit 0 represents DnsType::kANY; bit N (N > 0) represents the DnsType with
  // numeric value N.
  std::bitset<64> elements_;
};
// Modifies |records| such that no NSEC record signifies the nonexistence of a
// record which is also present in the same message. Order of the input vector
// is NOT preserved.
//
// For each record name this does the following:
// - Merges all NSEC records with that name into the first one.
// - Removes from the merged type set every type for which a non-NSEC record
//   with the same name is present.
// - Drops the NSEC record entirely if a change left it with no types, or
//   rewrites it with the remaining types if it changed otherwise.
// NOTE: |records| is not of type MdnsRecord::ConstRef because the members must
// be modified.
// TODO(b/170353378): Break this logic into a separate processing module between
// the MdnsReader and the MdnsQuerier.
void RemoveInvalidNsecFlags(std::vector<MdnsRecord>* records) {
// Sort the records so NSEC records are first so that only one iteration
// through all records is needed.
std::sort(records->begin(), records->end(), CompareRecordByNameAndType);
// The set of NSEC records that need to be removed from |records|. This can't
// be done as part of the below loop because it would invalidate the iterator
// that's still being used.
std::vector<std::vector<MdnsRecord>::iterator> nsecs_to_delete;
// Process all elements. |it| is advanced manually inside the loop body.
for (auto it = records->begin(); it != records->end();) {
// Non-NSEC records that are not preceded by an NSEC record of the same name
// need no processing.
if (it->dns_type() != DnsType::kNSEC) {
it++;
continue;
}
// Track whether the current NSEC record in the input vector has been
// modified by some step of this algorithm, be that merging with another
// record, removing a DnsType, or any other modification.
bool has_changed = false;
// The types for the new record to create, if |has_changed|.
const NsecRecordRdata& nsec_rdata = absl::get<NsecRecordRdata>(it->rdata());
DnsTypeBitSet types;
for (DnsType type : nsec_rdata.types()) {
types.Insert(type);
}
// |nsec| stays pointed at the first NSEC record for this name. The erase()
// calls below only happen at positions after |nsec|, so |nsec| and
// |nsec_rdata| remain valid.
auto nsec = it;
it++;
// Combine multiple NSECs to simplify the following code. This probably
// won't happen, but the RFC doesn't exclude the possibility, so account for
// it. Define the TTL of this new NSEC record created by this merge process
// to be the minimum of all merged NSEC records.
// NOTE(review): |has_changed| is only set here when the merge contributes a
// type that was not already present. Merging an NSEC record with an identical
// type list erases the duplicate but leaves the surviving record's TTL
// untouched even when |new_ttl| is smaller - confirm this is intended.
std::chrono::seconds new_ttl = nsec->ttl();
while (it != records->end() && it->name() == nsec->name() &&
it->dns_type() == DnsType::kNSEC) {
has_changed |=
types.Insert(absl::get<NsecRecordRdata>(it->rdata()).types());
new_ttl = std::min(new_ttl, it->ttl());
it = records->erase(it);
}
// Remove any types associated with a known record type.
for (; it != records->end() && it->name() == nsec->name(); it++) {
OSP_DCHECK(it->dns_type() != DnsType::kNSEC);
has_changed |= types.Remove(it->dns_type());
}
// Modify the stored NSEC record, if needed. An unchanged record is left
// exactly as received.
if (has_changed && types.IsEmpty()) {
nsecs_to_delete.push_back(nsec);
} else if (has_changed) {
NsecRecordRdata new_rdata(nsec_rdata.next_domain_name(),
types.GetTypes());
*nsec = MdnsRecord(nsec->name(), nsec->dns_type(), nsec->dns_class(),
nsec->record_type(), new_ttl, std::move(new_rdata));
}
}
// Erase invalid NSEC records. Go backwards to avoid invalidating the
// remaining iterators.
for (auto erase_it = nsecs_to_delete.rbegin();
erase_it != nsecs_to_delete.rend(); erase_it++) {
records->erase(*erase_it);
}
}
} // namespace
// Creates an empty cache of record trackers. The pointers are stored as-is and
// later handed to each tracker created by StartTracking(); |querier| receives
// the record-expired notifications and |reporting_client| receives recoverable
// update errors. |config.querier_max_records_cached| bounds the number of
// tracked records and must be positive.
MdnsQuerier::RecordTrackerLruCache::RecordTrackerLruCache(
MdnsQuerier* querier,
MdnsSender* sender,
MdnsRandom* random_delay,
TaskRunner* task_runner,
ClockNowFunctionPtr now_function,
ReportingClient* reporting_client,
const Config& config)
: querier_(querier),
sender_(sender),
random_delay_(random_delay),
task_runner_(task_runner),
now_function_(now_function),
reporting_client_(reporting_client),
config_(config) {
// NOTE(review): |querier_| and |now_function_| are not checked here although
// both are used later - confirm whether that is intentional.
OSP_DCHECK(sender_);
OSP_DCHECK(random_delay_);
OSP_DCHECK(task_runner_);
OSP_DCHECK(reporting_client_);
OSP_DCHECK_GT(config_.querier_max_records_cached, 0);
}
// Returns every tracker stored under |name|, regardless of record type or
// class.
std::vector<std::reference_wrapper<const MdnsRecordTracker>>
MdnsQuerier::RecordTrackerLruCache::Find(const DomainName& name) {
  const DnsType any_type = DnsType::kANY;
  const DnsClass any_class = DnsClass::kANY;
  return Find(name, any_type, any_class);
}
// Returns the trackers stored under |name| that match |dns_type| and
// |dns_class|. DnsType::kANY and DnsClass::kANY act as wildcards that match
// every tracker.
std::vector<std::reference_wrapper<const MdnsRecordTracker>>
MdnsQuerier::RecordTrackerLruCache::Find(const DomainName& name,
                                         DnsType dns_type,
                                         DnsClass dns_class) {
  std::vector<RecordTrackerConstRef> matches;
  const auto range = records_.equal_range(name);
  for (auto entry = range.first; entry != range.second; ++entry) {
    const MdnsRecordTracker& candidate = *entry->second;
    if (dns_type != DnsType::kANY && dns_type != candidate.dns_type()) {
      continue;
    }
    if (dns_class != DnsClass::kANY && dns_class != candidate.dns_class()) {
      continue;
    }
    matches.push_back(std::cref(candidate));
  }
  return matches;
}
int MdnsQuerier::RecordTrackerLruCache::Erase(const DomainName& domain,
TrackerApplicableCheck check) {
auto pair = records_.equal_range(domain);
int count = 0;
for (RecordMap::iterator it = pair.first; it != pair.second;) {
if (check(*it->second)) {
lru_order_.erase(it->second);
it = records_.erase(it);
count++;
} else {
it++;
}
}
return count;
}
int MdnsQuerier::RecordTrackerLruCache::ExpireSoon(
const DomainName& domain,
TrackerApplicableCheck check) {
auto pair = records_.equal_range(domain);
int count = 0;
for (RecordMap::iterator it = pair.first; it != pair.second; it++) {
if (check(*it->second)) {
MoveToEnd(it);
it->second->ExpireSoon();
count++;
}
}
return count;
}
int MdnsQuerier::RecordTrackerLruCache::Update(const MdnsRecord& record,
TrackerApplicableCheck check) {
return Update(record, check, [](const MdnsRecordTracker& t) {});
}
int MdnsQuerier::RecordTrackerLruCache::Update(
const MdnsRecord& record,
TrackerApplicableCheck check,
TrackerChangeCallback on_rdata_update) {
auto pair = records_.equal_range(record.name());
int count = 0;
for (RecordMap::iterator it = pair.first; it != pair.second; it++) {
if (check(*it->second)) {
auto result = it->second->Update(record);
if (result.is_error()) {
reporting_client_->OnRecoverableError(
Error(Error::Code::kUpdateReceivedRecordFailure,
result.error().ToString()));
continue;
}
count++;
if (result.value() == MdnsRecordTracker::UpdateType::kGoodbye) {
it->second->ExpireSoon();
MoveToEnd(it);
} else {
MoveToBeginning(it);
if (result.value() == MdnsRecordTracker::UpdateType::kRdata) {
on_rdata_update(*it->second);
}
}
}
}
return count;
}
// Creates a tracker for |record|, representing |dns_type|, at the front of the
// LRU list and indexes it by the record's name. While the cache already holds
// the configured maximum number of records, the tracker at the back of the LRU
// list is expired immediately to make room. Returns the new tracker.
const MdnsRecordTracker& MdnsQuerier::RecordTrackerLruCache::StartTracking(
    MdnsRecord record,
    DnsType dns_type) {
  // Forwards tracker expiration to the owning querier.
  auto on_expired = [this](const MdnsRecordTracker* expired_tracker,
                           const MdnsRecord& expired_record) {
    querier_->OnRecordExpired(expired_tracker, expired_record);
  };

  const size_t capacity =
      static_cast<size_t>(config_.querier_max_records_cached);
  while (lru_order_.size() >= capacity) {
    // This call erases one of the tracked records.
    OSP_DVLOG << "Maximum cacheable record count exceeded ("
              << config_.querier_max_records_cached << ")";
    lru_order_.back().ExpireNow();
  }

  // Copy the name before |record| is moved into the new tracker.
  auto key = record.name();
  lru_order_.emplace_front(std::move(record), dns_type, sender_, task_runner_,
                           now_function_, random_delay_,
                           std::move(on_expired));
  records_.emplace(std::move(key), lru_order_.begin());
  return lru_order_.front();
}
// Moves the tracker referenced by the map entry |it| to the front of the LRU
// list and refreshes the list iterator stored in the entry.
void MdnsQuerier::RecordTrackerLruCache::MoveToBeginning(
    MdnsQuerier::RecordTrackerLruCache::RecordMap::iterator it) {
  const auto front = lru_order_.begin();
  lru_order_.splice(front, lru_order_, it->second);
  it->second = lru_order_.begin();
}
// Moves the tracker referenced by the map entry |it| to the back of the LRU
// list and refreshes the list iterator stored in the entry.
void MdnsQuerier::RecordTrackerLruCache::MoveToEnd(
    MdnsQuerier::RecordTrackerLruCache::RecordMap::iterator it) {
  lru_order_.splice(lru_order_.end(), lru_order_, it->second);
  auto last = lru_order_.end();
  --last;
  it->second = last;
}
// Constructs a querier that sends through |sender| and registers itself with
// |receiver| as a response callback. All pointer arguments must be non-null;
// they are stored as raw pointers. |config| is taken by value and moved into
// |config_|.
MdnsQuerier::MdnsQuerier(MdnsSender* sender,
MdnsReceiver* receiver,
TaskRunner* task_runner,
ClockNowFunctionPtr now_function,
MdnsRandom* random_delay,
ReportingClient* reporting_client,
Config config)
: sender_(sender),
receiver_(receiver),
task_runner_(task_runner),
now_function_(now_function),
random_delay_(random_delay),
reporting_client_(reporting_client),
config_(std::move(config)),
// NOTE(review): |records_| is built from the members above (including
// |config_|), so it presumably must be declared after them in the class -
// verify against the header.
records_(this,
sender_,
random_delay_,
task_runner_,
now_function_,
reporting_client_,
config_) {
OSP_DCHECK(sender_);
OSP_DCHECK(receiver_);
OSP_DCHECK(task_runner_);
OSP_DCHECK(now_function_);
OSP_DCHECK(random_delay_);
OSP_DCHECK(reporting_client_);
// From here on, OnMessageReceived() may be invoked by |receiver_|.
receiver_->AddResponseCallback(this);
}
// Unregisters this querier from |receiver_|, undoing the
// AddResponseCallback() call made in the constructor.
MdnsQuerier::~MdnsQuerier() {
receiver_->RemoveResponseCallback(this);
}
// NOTE: The code below is range loops instead of std:find_if, for better
// readability, brevity and homogeneity. Using std::find_if results in a few
// more lines of code, readability suffers from extra lambdas.
void MdnsQuerier::StartQuery(const DomainName& name,
DnsType dns_type,
DnsClass dns_class,
MdnsRecordChangedCallback* callback) {
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
OSP_DCHECK(callback);
OSP_DCHECK(CanBeQueried(dns_type));
// Add a new callback if haven't seen it before
auto callbacks_it = callbacks_.equal_range(name);
for (auto entry = callbacks_it.first; entry != callbacks_it.second; ++entry) {
const CallbackInfo& callback_info = entry->second;
if (dns_type == callback_info.dns_type &&
dns_class == callback_info.dns_class &&
callback == callback_info.callback) {
// Already have this callback
return;
}
}
callbacks_.emplace(name, CallbackInfo{callback, dns_type, dns_class});
// Notify the new callback with previously cached records.
// NOTE: In the future, could allow callers to fetch cached records after
// adding a callback, for example to prime the UI.
std::vector<PendingQueryChange> pending_changes;
const std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
records_.Find(name, dns_type, dns_class);
for (const MdnsRecordTracker& tracker : trackers) {
if (!tracker.is_negative_response()) {
MdnsRecord stored_record(name, tracker.dns_type(), tracker.dns_class(),
tracker.record_type(), tracker.ttl(),
tracker.rdata());
std::vector<PendingQueryChange> new_changes = callback->OnRecordChanged(
std::move(stored_record), RecordChangedEvent::kCreated);
pending_changes.insert(pending_changes.end(), new_changes.begin(),
new_changes.end());
}
}
// Add a new question if haven't seen it before
auto questions_it = questions_.equal_range(name);
const bool is_question_already_tracked =
std::find_if(questions_it.first, questions_it.second,
[dns_type, dns_class](const auto& entry) {
const MdnsQuestion& tracked_question =
entry.second->question();
return dns_type == tracked_question.dns_type() &&
dns_class == tracked_question.dns_class();
}) != questions_it.second;
if (!is_question_already_tracked) {
AddQuestion(
MdnsQuestion(name, dns_type, dns_class, ResponseType::kMulticast));
}
// Apply any pending changes from the OnRecordChanged() callbacks.
ApplyPendingChanges(std::move(pending_changes));
}
// Unregisters |callback| for the (|name|, |dns_type|, |dns_class|) query. If
// no other callback remains registered for that exact key, the question
// tracked for it is removed as well. Does nothing for a |dns_type| that cannot
// be queried.
void MdnsQuerier::StopQuery(const DomainName& name,
                            DnsType dns_type,
                            DnsClass dns_class,
                            MdnsRecordChangedCallback* callback) {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
  OSP_DCHECK(callback);

  if (!CanBeQueried(dns_type)) {
    return;
  }

  // Find and remove the callback, counting the other callbacks that share the
  // same key.
  int callbacks_for_key = 0;
  auto callbacks_it = callbacks_.equal_range(name);
  for (auto entry = callbacks_it.first; entry != callbacks_it.second;) {
    const CallbackInfo& callback_info = entry->second;
    const bool is_same_key = dns_type == callback_info.dns_type &&
                             dns_class == callback_info.dns_class;
    if (is_same_key && callback == callback_info.callback) {
      // erase() returns the iterator following the removed entry.
      entry = callbacks_.erase(entry);
      continue;
    }
    if (is_same_key) {
      ++callbacks_for_key;
    }
    // Always advance past entries that are kept. This includes entries for
    // the same name but a different type or class; failing to step over those
    // would make this loop spin forever.
    ++entry;
  }

  // Exit if there are still callbacks registered for DomainName + DnsType +
  // DnsClass
  if (callbacks_for_key > 0) {
    return;
  }

  // Find and delete a question that does not have any associated callbacks
  auto questions_it = questions_.equal_range(name);
  for (auto entry = questions_it.first; entry != questions_it.second; ++entry) {
    const MdnsQuestion& tracked_question = entry->second->question();
    if (dns_type == tracked_question.dns_type() &&
        dns_class == tracked_question.dns_class()) {
      questions_.erase(entry);
      return;
    }
  }
}
void MdnsQuerier::ReinitializeQueries(const DomainName& name) {
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
// Get the ongoing queries and their callbacks.
std::vector<CallbackInfo> callbacks;
auto its = callbacks_.equal_range(name);
for (auto it = its.first; it != its.second; it++) {
callbacks.push_back(std::move(it->second));
}
callbacks_.erase(name);
// Remove all known questions and answers.
questions_.erase(name);
records_.Erase(name, [](const MdnsRecordTracker& tracker) { return true; });
// Restart the queries.
for (const auto& cb : callbacks) {
StartQuery(name, cb.dns_type, cb.dns_class, cb.callback);
}
}
void MdnsQuerier::OnMessageReceived(const MdnsMessage& message) {
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
OSP_DCHECK(message.type() == MessageType::Response);
OSP_DVLOG << "Received mDNS Response message with "
<< message.answers().size() << " answers and "
<< message.additional_records().size()
<< " additional records. Processing...";
std::vector<MdnsRecord> records_to_process;
// Add any records that are relevant for this querier.
bool found_relevant_records = false;
for (const MdnsRecord& record : message.answers()) {
if (ShouldAnswerRecordBeProcessed(record)) {
records_to_process.push_back(record);
found_relevant_records = true;
}
}
// If any of the message's answers are relevant, add all additional records.
// Else, since the message has already been received and parsed, use any
// individual records relevant to this querier to update the cache.
for (const MdnsRecord& record : message.additional_records()) {
if (found_relevant_records || ShouldAnswerRecordBeProcessed(record)) {
records_to_process.push_back(record);
}
}
// Drop NSEC records associated with a non-NSEC record of the same type.
RemoveInvalidNsecFlags(&records_to_process);
// Process all remaining records.
for (const MdnsRecord& record_to_process : records_to_process) {
ProcessRecord(record_to_process);
}
OSP_DVLOG << "\tmDNS Response processed (" << records_to_process.size()
<< " records accepted)!";
// TODO(crbug.com/openscreen/83): Check authority records.
}
bool MdnsQuerier::ShouldAnswerRecordBeProcessed(const MdnsRecord& answer) {
// First, accept the record if it's associated with an ongoing question.
const auto questions_range = questions_.equal_range(answer.name());
const auto it = std::find_if(
questions_range.first, questions_range.second,
[&answer](const auto& pair) {
return (pair.second->question().dns_type() == DnsType::kANY ||
IsNegativeResponseFor(answer,
pair.second->question().dns_type()) ||
pair.second->question().dns_type() == answer.dns_type()) &&
(pair.second->question().dns_class() == DnsClass::kANY ||
pair.second->question().dns_class() == answer.dns_class());
});
if (it != questions_range.second) {
return true;
}
// If not, check if it corresponds to an already existing record. This is
// required because records which are already stored may either have been
// received in an additional records section, or are associated with a query
// which is no longer active.
std::vector<DnsType> types{answer.dns_type()};
if (answer.dns_type() == DnsType::kNSEC) {
const auto& nsec_rdata = absl::get<NsecRecordRdata>(answer.rdata());
types = nsec_rdata.types();
}
for (DnsType type : types) {
std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
records_.Find(answer.name(), type, answer.dns_class());
if (!trackers.empty()) {
return true;
}
}
// In all other cases, the record isn't relevant. Drop it.
return false;
}
// Invoked when |tracker| expires. Callbacks are told that |record| expired
// unless the tracker held a negative response, and then the tracker is removed
// from the cache.
void MdnsQuerier::OnRecordExpired(const MdnsRecordTracker* tracker,
                                  const MdnsRecord& record) {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());

  const bool held_real_record = !tracker->is_negative_response();
  if (held_real_record) {
    ProcessCallbacks(record, RecordChangedEvent::kExpired);
  }

  // Match by address so that only this exact tracker is erased.
  auto is_expired_tracker = [tracker](const MdnsRecordTracker& candidate) {
    return tracker == &candidate;
  };
  records_.Erase(record.name(), is_expired_tracker);
}
void MdnsQuerier::ProcessRecord(const MdnsRecord& record) {
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
// Skip all records that can't be processed.
if (!CanBeProcessed(record.dns_type())) {
return;
}
// Ignore NSEC records if the embedder has configured us to do so.
if (config_.ignore_nsec_responses && record.dns_type() == DnsType::kNSEC) {
return;
}
// Get the types which the received record is associated with. In most cases
// this will only be the type of the provided record, but in the case of
// NSEC records this will be all records which the record dictates the
// nonexistence of.
std::vector<DnsType> types;
const std::vector<DnsType>* types_ptr = &types;
if (record.dns_type() == DnsType::kNSEC) {
const auto& nsec_rdata = absl::get<NsecRecordRdata>(record.rdata());
if (std::find(nsec_rdata.types().begin(), nsec_rdata.types().end(),
DnsType::kANY) != nsec_rdata.types().end()) {
types_ptr = &kTranslatedNsecAnyQueryTypes;
} else {
types_ptr = &nsec_rdata.types();
}
} else {
types.push_back(record.dns_type());
}
// Apply the update for each type that the record is associated with.
for (DnsType dns_type : *types_ptr) {
switch (record.record_type()) {
case RecordType::kShared: {
ProcessSharedRecord(record, dns_type);
break;
}
case RecordType::kUnique: {
ProcessUniqueRecord(record, dns_type);
break;
}
}
}
}
// Handles a received shared record. Trackers with the same type, class and
// RDATA are updated; if none was updated, the record is added to the cache and
// callbacks are told it was created. NSEC records are ignored here.
void MdnsQuerier::ProcessSharedRecord(const MdnsRecord& record,
                                      DnsType dns_type) {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
  OSP_DCHECK(record.record_type() == RecordType::kShared);

  // By design, NSEC records are never shared records.
  if (record.dns_type() == DnsType::kNSEC) {
    return;
  }

  // For any records updated, this host already has this shared record. Since
  // the RDATA matches, this is only a TTL update.
  auto is_same_record = [&record](const MdnsRecordTracker& tracker) {
    if (record.dns_type() != tracker.dns_type()) {
      return false;
    }
    if (record.dns_class() != tracker.dns_class()) {
      return false;
    }
    return record.rdata() == tracker.rdata();
  };
  const auto updated_count = records_.Update(record, std::move(is_same_record));
  if (updated_count) {
    return;
  }

  // Have never before seen this shared record, insert a new one.
  AddRecord(record, dns_type);
  ProcessCallbacks(record, RecordChangedEvent::kCreated);
}
void MdnsQuerier::ProcessUniqueRecord(const MdnsRecord& record,
DnsType dns_type) {
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
OSP_DCHECK(record.record_type() == RecordType::kUnique);
std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
records_.Find(record.name(), dns_type, record.dns_class());
size_t num_records_for_key = trackers.size();
// Have not seen any records with this key before. This case is expected the
// first time a record is received.
if (num_records_for_key == size_t{0}) {
const bool will_exist = record.dns_type() != DnsType::kNSEC;
AddRecord(record, dns_type);
if (will_exist) {
ProcessCallbacks(record, RecordChangedEvent::kCreated);
}
} else if (num_records_for_key == size_t{1}) {
// There is exactly one tracker associated with this key. This is the
// expected case when a record matching this one has already been seen.
ProcessSinglyTrackedUniqueRecord(record, trackers[0]);
} else {
// Multiple records with the same key.
ProcessMultiTrackedUniqueRecord(record, dns_type);
}
}
void MdnsQuerier::ProcessSinglyTrackedUniqueRecord(
const MdnsRecord& record,
const MdnsRecordTracker& tracker) {
const bool existed_previously = !tracker.is_negative_response();
const bool will_exist = record.dns_type() != DnsType::kNSEC;
// Calculate the callback to call on record update success while the old
// record still exists.
MdnsRecord record_for_callback = record;
if (existed_previously && !will_exist) {
record_for_callback =
MdnsRecord(record.name(), tracker.dns_type(), tracker.dns_class(),
tracker.record_type(), tracker.ttl(), tracker.rdata());
}
auto on_rdata_change = [this, r = std::move(record_for_callback),
existed_previously,
will_exist](const MdnsRecordTracker& tracker) {
// If RDATA on the record is different, notify that the record has
// been updated.
if (existed_previously && will_exist) {
ProcessCallbacks(r, RecordChangedEvent::kUpdated);
} else if (existed_previously) {
// Do not expire the tracker, because it still holds an NSEC record.
ProcessCallbacks(r, RecordChangedEvent::kExpired);
} else if (will_exist) {
ProcessCallbacks(r, RecordChangedEvent::kCreated);
}
};
int updated_count = records_.Update(
record, [&tracker](const MdnsRecordTracker& t) { return &tracker == &t; },
std::move(on_rdata_change));
OSP_DCHECK_EQ(updated_count, 1);
}
// Handles a received unique |record| when several trackers already exist for
// its (name, |dns_type|, class) key:
// - A tracker with identical RDATA is updated in place.
// - All trackers with different RDATA are marked to expire soon.
// - If no tracker with identical RDATA existed, |record| is added to the cache
//   and, unless it is an NSEC record, callbacks are told it was created.
//
// The last step must not depend on whether other trackers were expired.
// Otherwise a record that supersedes the tracked ones would expire them
// without ever being cached itself.
void MdnsQuerier::ProcessMultiTrackedUniqueRecord(const MdnsRecord& record,
                                                  DnsType dns_type) {
  // Set by |update_check| when a tracker with identical RDATA is seen. This is
  // tracked separately from |update_count| so that a matching tracker whose
  // update failed (and was reported as an error) does not cause a duplicate
  // tracker to be added.
  bool has_tracker_with_same_rdata = false;
  auto update_check = [&record, dns_type, &has_tracker_with_same_rdata](
                          const MdnsRecordTracker& tracker) {
    const bool is_match = tracker.dns_type() == dns_type &&
                          tracker.dns_class() == record.dns_class() &&
                          tracker.rdata() == record.rdata();
    has_tracker_with_same_rdata = has_tracker_with_same_rdata || is_match;
    return is_match;
  };
  // The RDATA-changed callback can never fire because only trackers whose
  // RDATA already equals the record's are updated.
  int update_count = records_.Update(
      record, std::move(update_check),
      [](const MdnsRecordTracker& tracker) { OSP_NOTREACHED(); });
  OSP_DCHECK_LE(update_count, 1);

  auto expire_check = [&record, dns_type](const MdnsRecordTracker& tracker) {
    return tracker.dns_type() == dns_type &&
           tracker.dns_class() == record.dns_class() &&
           tracker.rdata() != record.rdata();
  };
  int expire_count =
      records_.ExpireSoon(record.name(), std::move(expire_check));
  OSP_DCHECK_GE(expire_count, 1);

  // Did not find an existing record to update.
  if (!has_tracker_with_same_rdata) {
    AddRecord(record, dns_type);
    if (record.dns_type() != DnsType::kNSEC) {
      ProcessCallbacks(record, RecordChangedEvent::kCreated);
    }
  }
}
void MdnsQuerier::ProcessCallbacks(const MdnsRecord& record,
RecordChangedEvent event) {
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
std::vector<PendingQueryChange> pending_changes;
auto callbacks_it = callbacks_.equal_range(record.name());
for (auto entry = callbacks_it.first; entry != callbacks_it.second; ++entry) {
const CallbackInfo& callback_info = entry->second;
if ((callback_info.dns_type == DnsType::kANY ||
record.dns_type() == callback_info.dns_type) &&
(callback_info.dns_class == DnsClass::kANY ||
record.dns_class() == callback_info.dns_class)) {
std::vector<PendingQueryChange> new_changes =
callback_info.callback->OnRecordChanged(record, event);
pending_changes.insert(pending_changes.end(), new_changes.begin(),
new_changes.end());
}
}
ApplyPendingChanges(std::move(pending_changes));
}
void MdnsQuerier::AddQuestion(const MdnsQuestion& question) {
auto tracker = std::make_unique<MdnsQuestionTracker>(
question, sender_, task_runner_, now_function_, random_delay_, config_);
MdnsQuestionTracker* ptr = tracker.get();
questions_.emplace(question.name(), std::move(tracker));
// Let all records associated with this question know that there is a new
// query that can be used for their refresh.
std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
records_.Find(question.name(), question.dns_type(), question.dns_class());
for (const MdnsRecordTracker& tracker : trackers) {
// NOTE: When the pointed to object is deleted, its dtor removes itself
// from all associated records.
ptr->AddAssociatedRecord(&tracker);
}
}
// Starts tracking |record| as a record of |type| and associates the new
// tracker with every ongoing question for the same name that it answers.
//
// A question answers-to the record when its type and class equal the record's,
// or when the question's type or class is the kANY wildcard. This mirrors the
// wildcard matching that AddQuestion() gets from RecordTrackerLruCache::Find(),
// so that the association does not depend on whether the question or the
// record arrived first.
void MdnsQuerier::AddRecord(const MdnsRecord& record, DnsType type) {
  // Add the new record.
  const auto& tracker = records_.StartTracking(record, type);

  // Let all questions associated with this record know that there is a new
  // record that answers them (for known answer suppression).
  auto query_it = questions_.equal_range(record.name());
  for (auto entry = query_it.first; entry != query_it.second; ++entry) {
    const MdnsQuestion& query = entry->second->question();
    const bool is_relevant_type = type == DnsType::kANY ||
                                  query.dns_type() == DnsType::kANY ||
                                  type == query.dns_type();
    const bool is_relevant_class = record.dns_class() == DnsClass::kANY ||
                                   query.dns_class() == DnsClass::kANY ||
                                   record.dns_class() == query.dns_class();
    if (is_relevant_type && is_relevant_class) {
      // NOTE: When the pointed to object is deleted, its dtor removes itself
      // from all associated queries.
      entry->second->AddAssociatedRecord(&tracker);
    }
  }
}
void MdnsQuerier::ApplyPendingChanges(
std::vector<PendingQueryChange> pending_changes) {
for (auto& pending_change : pending_changes) {
switch (pending_change.change_type) {
case PendingQueryChange::kStartQuery:
StartQuery(std::move(pending_change.name), pending_change.dns_type,
pending_change.dns_class, pending_change.callback);
break;
case PendingQueryChange::kStopQuery:
StopQuery(std::move(pending_change.name), pending_change.dns_type,
pending_change.dns_class, pending_change.callback);
break;
}
}
}
} // namespace discovery
} // namespace openscreen