blob: 01267994e62fb116f9745f3b7bc1a9682b29aa91 [file] [log] [blame]
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "discovery/dnssd/impl/querier_impl.h"
#include <algorithm>
#include <functional>
#include <iterator>
#include <string>
#include <utility>
#include <vector>
#include "discovery/common/reporting_client.h"
#include "discovery/dnssd/impl/conversion_layer.h"
#include "discovery/dnssd/impl/network_interface_config.h"
#include "platform/api/task_runner.h"
#include "util/osp_logging.h"
namespace openscreen {
namespace discovery {
namespace {
// Domain combined with the service name to build the ServiceKey for every
// DNS-SD query issued from this file. Internal linkage already comes from the
// enclosing anonymous namespace.
constexpr char kLocalDomain[] = "local";
// Strips the error entries out of both |old_endpoints| and |new_endpoints|,
// invoking |log| once for every error that is present in |new_endpoints| but
// not in |old_endpoints| (i.e. errors introduced by the change being applied).
//
// Both vectors must be sorted in ascending order. In that order every error
// precedes every value, so the errors form a prefix of each vector.
void ProcessErrors(std::vector<ErrorOr<DnsSdInstanceEndpoint>>* old_endpoints,
                   std::vector<ErrorOr<DnsSdInstanceEndpoint>>* new_endpoints,
                   std::function<void(Error)> log) {
  OSP_DCHECK(old_endpoints);
  OSP_DCHECK(new_endpoints);

  auto old_cursor = old_endpoints->begin();
  auto new_cursor = new_endpoints->begin();
  const auto old_last = old_endpoints->end();
  const auto new_last = new_endpoints->end();

  // Merge-walk the error prefix of |new_endpoints| against |old_endpoints|.
  while (new_cursor != new_last && new_cursor->is_error()) {
    const bool has_old = old_cursor != old_last;
    if (has_old && *old_cursor == *new_cursor) {
      // Pre-existing error present on both sides: nothing new to report.
      ++old_cursor;
      ++new_cursor;
    } else if (has_old && *old_cursor < *new_cursor) {
      // Old error with no counterpart in |new_endpoints|: skip it.
      ++old_cursor;
    } else {
      // Error that only appears after the change: report it.
      log(std::move(new_cursor->error()));
      ++new_cursor;
    }
  }

  // Any errors still at the front of |old_endpoints| are simply dropped.
  while (old_cursor != old_last && old_cursor->is_error()) {
    ++old_cursor;
  }

  // Remove the error prefixes, leaving only values behind.
  old_endpoints->erase(old_endpoints->begin(), old_cursor);
  new_endpoints->erase(new_endpoints->begin(), new_cursor);
}
// Unwraps every ErrorOr<> in |endpoints| by moving its value out. Every entry
// must hold a value; callers strip errors (see ProcessErrors) beforehand.
std::vector<DnsSdInstanceEndpoint> GetValues(
    std::vector<ErrorOr<DnsSdInstanceEndpoint>> endpoints) {
  const size_t count = endpoints.size();
  std::vector<DnsSdInstanceEndpoint> values;
  values.reserve(count);
  for (size_t i = 0; i < count; ++i) {
    values.push_back(std::move(endpoints[i].value()));
  }
  return values;
}
bool IsEqualOrUpdate(const absl::optional<DnsSdInstanceEndpoint>& first,
const absl::optional<DnsSdInstanceEndpoint>& second) {
if (!first.has_value() || !second.has_value()) {
return !first.has_value() && !second.has_value();
}
// In the remaining case, both |first| and |second| must be values.
const DnsSdInstanceEndpoint& a = first.value();
const DnsSdInstanceEndpoint& b = second.value();
// All endpoints from this querier should have the same network interface
// because the querier is only associated with a single network interface.
OSP_DCHECK_EQ(a.network_interface(), b.network_interface());
// Function returns true if first < second.
return a.instance_id() == b.instance_id() &&
a.service_id() == b.service_id() && a.domain_id() == b.domain_id();
}
bool IsNotEqualOrUpdate(const absl::optional<DnsSdInstanceEndpoint>& first,
const absl::optional<DnsSdInstanceEndpoint>& second) {
return !IsEqualOrUpdate(first, second);
}
// Calculates the created, updated, and deleted elements using the provided
// sets, appending these values to the provided vectors. Each of the input
// vectors is expected to contain only elements such that
// |element|.is_error() == false. Additionally, input vectors are expected to
// be sorted in ascending order.
//
// NOTE: A lot of operations are used to do this, but each is only O(n) so the
// resulting algorithm is still fast.
void CalculateChangeSets(std::vector<DnsSdInstanceEndpoint> old_endpoints,
std::vector<DnsSdInstanceEndpoint> new_endpoints,
std::vector<DnsSdInstanceEndpoint>* created_out,
std::vector<DnsSdInstanceEndpoint>* updated_out,
std::vector<DnsSdInstanceEndpoint>* deleted_out) {
OSP_DCHECK(created_out);
OSP_DCHECK(updated_out);
OSP_DCHECK(deleted_out);
// Use set difference with default operators to find the elements present in
// one list but not the others.
//
// NOTE: Because absl::optional<...> types are used here and below, calls to
// the ctor and dtor for empty elements are no-ops.
const int total_count = old_endpoints.size() + new_endpoints.size();
// This is the set of elements that aren't in the old endpoints, meaning the
// old endpoint either didn't exist or had different TXT / Address / etc..
std::vector<absl::optional<DnsSdInstanceEndpoint>> created_or_updated(
total_count);
auto new_end = std::set_difference(new_endpoints.begin(), new_endpoints.end(),
old_endpoints.begin(), old_endpoints.end(),
created_or_updated.begin());
created_or_updated.erase(new_end, created_or_updated.end());
// This is the set of elements that are only in the old endpoints, similar to
// the above.
std::vector<absl::optional<DnsSdInstanceEndpoint>> deleted_or_updated(
total_count);
new_end = std::set_difference(old_endpoints.begin(), old_endpoints.end(),
new_endpoints.begin(), new_endpoints.end(),
deleted_or_updated.begin());
deleted_or_updated.erase(new_end, deleted_or_updated.end());
// Next, find the elements which were updated.
const size_t max_count =
std::max(created_or_updated.size(), deleted_or_updated.size());
std::vector<absl::optional<DnsSdInstanceEndpoint>> updated(max_count);
new_end = std::set_intersection(
created_or_updated.begin(), created_or_updated.end(),
deleted_or_updated.begin(), deleted_or_updated.end(), updated.begin(),
IsNotEqualOrUpdate);
updated.erase(new_end, updated.end());
// Use the updated elements to find all created and deleted elements.
std::vector<absl::optional<DnsSdInstanceEndpoint>> created(
created_or_updated.size());
new_end = std::set_difference(
created_or_updated.begin(), created_or_updated.end(), updated.begin(),
updated.end(), created.begin(), IsNotEqualOrUpdate);
created.erase(new_end, created.end());
std::vector<absl::optional<DnsSdInstanceEndpoint>> deleted(
deleted_or_updated.size());
new_end = std::set_difference(
deleted_or_updated.begin(), deleted_or_updated.end(), updated.begin(),
updated.end(), deleted.begin(), IsNotEqualOrUpdate);
deleted.erase(new_end, deleted.end());
// Return the calculated elements back to the caller in the output variables.
created_out->reserve(created.size());
for (absl::optional<DnsSdInstanceEndpoint>& endpoint : created) {
OSP_DCHECK(endpoint.has_value());
created_out->push_back(std::move(endpoint.value()));
}
updated_out->reserve(updated.size());
for (absl::optional<DnsSdInstanceEndpoint>& endpoint : updated) {
OSP_DCHECK(endpoint.has_value());
updated_out->push_back(std::move(endpoint.value()));
}
deleted_out->reserve(deleted.size());
for (absl::optional<DnsSdInstanceEndpoint>& endpoint : deleted) {
OSP_DCHECK(endpoint.has_value());
deleted_out->push_back(std::move(endpoint.value()));
}
}
} // namespace
QuerierImpl::QuerierImpl(MdnsService* mdns_querier,
TaskRunner* task_runner,
ReportingClient* reporting_client,
const NetworkInterfaceConfig* network_config)
: mdns_querier_(mdns_querier),
task_runner_(task_runner),
reporting_client_(reporting_client) {
OSP_DCHECK(mdns_querier_);
OSP_DCHECK(task_runner_);
OSP_DCHECK(network_config);
graph_ = DnsDataGraph::Create(network_config->network_interface());
}
// Defaulted destructor: no explicit cleanup (such as stopping outstanding mDNS
// queries) is performed here.
QuerierImpl::~QuerierImpl() = default;
// Registers |callback| for endpoints of |service| in the local domain. If the
// service is not yet tracked, tracking (and the underlying mDNS queries) is
// started. Otherwise, |callback| is immediately told about every endpoint
// that has already been discovered.
void QuerierImpl::StartQuery(const std::string& service, Callback* callback) {
  OSP_DCHECK(callback);
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());

  OSP_DVLOG << "Starting DNS-SD query for service '" << service << "'";

  // Start tracking the new callback
  const ServiceKey key(service, kLocalDomain);
  auto it = callback_map_.emplace(key, std::vector<Callback*>{}).first;
  it->second.push_back(callback);

  const DomainName domain = key.GetName();

  // If the associated service isn't tracked yet, start tracking it and start
  // queries for the relevant PTR records.
  if (!graph_->IsTracked(domain)) {
    // Only |this| is captured. The graph hands the callback the domain to
    // query (which may differ from |domain|), so that is what gets logged. The
    // callback also does not depend on the lifetime of the local |domain|.
    std::function<void(const DomainName&)> mdns_query(
        [this](const DomainName& changed_domain) {
          OSP_DVLOG << "Starting mDNS query for '" << changed_domain.ToString()
                    << "'";
          mdns_querier_->StartQuery(changed_domain, DnsType::kANY,
                                    DnsClass::kANY, this);
        });
    graph_->StartTracking(domain, std::move(mdns_query));
    return;
  }

  // Else, it's already being tracked so fire creation callbacks for any already
  // found service instances.
  const std::vector<ErrorOr<DnsSdInstanceEndpoint>> endpoints =
      graph_->CreateEndpoints(DnsDataGraph::DomainGroup::kPtr, domain);
  for (const auto& endpoint : endpoints) {
    if (endpoint.is_value()) {
      callback->OnEndpointCreated(endpoint.value());
    }
  }
}
void QuerierImpl::StopQuery(const std::string& service, Callback* callback) {
OSP_DCHECK(callback);
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
OSP_DVLOG << "Stopping DNS-SD query for service '" << service << "'";
ServiceKey key(service, kLocalDomain);
const auto callbacks_it = callback_map_.find(key);
if (callbacks_it == callback_map_.end()) {
return;
}
std::vector<Callback*>& callbacks = callbacks_it->second;
const auto it = std::find(callbacks.begin(), callbacks.end(), callback);
if (it == callbacks.end()) {
return;
}
callbacks.erase(it);
if (callbacks.empty()) {
callback_map_.erase(callbacks_it);
ServiceKey key(service, kLocalDomain);
DomainName domain = key.GetName();
std::function<void(const DomainName&)> stop_mdns_query(
[this](const DomainName& changed_domain) {
OSP_DVLOG << "Stopping mDNS query for '" << changed_domain.ToString()
<< "'";
mdns_querier_->StopQuery(changed_domain, DnsType::kANY,
DnsClass::kANY, this);
});
graph_->StopTracking(domain, std::move(stop_mdns_query));
}
}
bool QuerierImpl::IsQueryRunning(const std::string& service) const {
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
const ServiceKey key(service, kLocalDomain);
return graph_->IsTracked(key.GetName());
}
void QuerierImpl::ReinitializeQueries(const std::string& service) {
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
OSP_DVLOG << "Re-initializing query for service '" << service << "'";
const ServiceKey key(service, kLocalDomain);
const DomainName domain = key.GetName();
std::function<void(const DomainName&)> start_callback(
[this](const DomainName& domain) {
mdns_querier_->StartQuery(domain, DnsType::kANY, DnsClass::kANY, this);
});
std::function<void(const DomainName&)> stop_callback(
[this](const DomainName& domain) {
mdns_querier_->StopQuery(domain, DnsType::kANY, DnsClass::kANY, this);
});
graph_->StopTracking(domain, std::move(stop_callback));
// Restart top-level queries.
mdns_querier_->ReinitializeQueries(GetPtrQueryInfo(key).name);
graph_->StartTracking(domain, std::move(start_callback));
}
std::vector<PendingQueryChange> QuerierImpl::OnRecordChanged(
const MdnsRecord& record,
RecordChangedEvent event) {
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
OSP_DVLOG << "Record " << record.ToString()
<< " has received change of type '" << event << "'";
std::function<void(Error)> log = [this](Error error) mutable {
reporting_client_->OnRecoverableError(
Error(Error::Code::kProcessReceivedRecordFailure));
};
// Get the details to use for calling CreateEndpoints(). Special case PTR
// records to optimize performance.
const DomainName& create_endpoints_domain =
record.dns_type() != DnsType::kPTR
? record.name()
: absl::get<PtrRecordRdata>(record.rdata()).ptr_domain();
const DnsDataGraph::DomainGroup create_endpoints_group =
record.dns_type() != DnsType::kPTR
? DnsDataGraph::GetDomainGroup(record)
: DnsDataGraph::DomainGroup::kSrvAndTxt;
// Get the current set of DnsSdInstanceEndpoints prior to this change. Special
// case PTR records to avoid iterating over unrelated child domains.
std::vector<ErrorOr<DnsSdInstanceEndpoint>> old_endpoints_or_errors =
graph_->CreateEndpoints(create_endpoints_group, create_endpoints_domain);
// Apply the changes, creating a list of all pending changes that should be
// applied afterwards.
ErrorOr<std::vector<PendingQueryChange>> pending_changes_or_error =
ApplyRecordChanges(record, event);
if (pending_changes_or_error.is_error()) {
OSP_DVLOG << "Failed to apply changes for " << record.dns_type()
<< " record change of type " << event << " with error "
<< pending_changes_or_error.error();
log(std::move(pending_changes_or_error.error()));
return {};
}
std::vector<PendingQueryChange>& pending_changes =
pending_changes_or_error.value();
// Get the new set of DnsSdInstanceEndpoints following this change.
std::vector<ErrorOr<DnsSdInstanceEndpoint>> new_endpoints_or_errors =
graph_->CreateEndpoints(create_endpoints_group, create_endpoints_domain);
// Return early if the resulting sets are equal. This will frequently be the
// case, especially when both sets are empty.
std::sort(old_endpoints_or_errors.begin(), old_endpoints_or_errors.end());
std::sort(new_endpoints_or_errors.begin(), new_endpoints_or_errors.end());
if (old_endpoints_or_errors.size() == new_endpoints_or_errors.size() &&
std::equal(old_endpoints_or_errors.begin(), old_endpoints_or_errors.end(),
new_endpoints_or_errors.begin())) {
return pending_changes;
}
// Log all errors and erase them.
ProcessErrors(&old_endpoints_or_errors, &new_endpoints_or_errors,
std::move(log));
const size_t old_endpoints_or_errors_count = old_endpoints_or_errors.size();
const size_t new_endpoints_or_errors_count = new_endpoints_or_errors.size();
std::vector<DnsSdInstanceEndpoint> old_endpoints =
GetValues(std::move(old_endpoints_or_errors));
std::vector<DnsSdInstanceEndpoint> new_endpoints =
GetValues(std::move(new_endpoints_or_errors));
OSP_DCHECK_EQ(old_endpoints.size(), old_endpoints_or_errors_count);
OSP_DCHECK_EQ(new_endpoints.size(), new_endpoints_or_errors_count);
// Calculate the changes and call callbacks.
//
// NOTE: As the input sets are expected to be small, the generated sets will
// also be small.
std::vector<DnsSdInstanceEndpoint> created;
std::vector<DnsSdInstanceEndpoint> updated;
std::vector<DnsSdInstanceEndpoint> deleted;
CalculateChangeSets(std::move(old_endpoints), std::move(new_endpoints),
&created, &updated, &deleted);
InvokeChangeCallbacks(std::move(created), std::move(updated),
std::move(deleted));
return pending_changes;
}
void QuerierImpl::InvokeChangeCallbacks(
std::vector<DnsSdInstanceEndpoint> created,
std::vector<DnsSdInstanceEndpoint> updated,
std::vector<DnsSdInstanceEndpoint> deleted) {
// Find an endpoint and use it to create the key, or return if there is none.
DnsSdInstanceEndpoint* some_endpoint;
if (!created.empty()) {
some_endpoint = &created.front();
} else if (!updated.empty()) {
some_endpoint = &updated.front();
} else if (!deleted.empty()) {
some_endpoint = &deleted.front();
} else {
return;
}
ServiceKey key(some_endpoint->service_id(), some_endpoint->domain_id());
// Find all callbacks.
auto it = callback_map_.find(key);
if (it == callback_map_.end()) {
return;
}
// Call relevant callbacks.
std::vector<Callback*>& callbacks = it->second;
for (Callback* callback : callbacks) {
for (const DnsSdInstanceEndpoint& endpoint : created) {
callback->OnEndpointCreated(endpoint);
}
for (const DnsSdInstanceEndpoint& endpoint : updated) {
callback->OnEndpointUpdated(endpoint);
}
for (const DnsSdInstanceEndpoint& endpoint : deleted) {
callback->OnEndpointDeleted(endpoint);
}
}
}
// Applies |event| for |record| to the data graph. Instead of starting or
// stopping queries directly, the graph's domain-added / domain-removed
// notifications are collected as PendingQueryChange entries. These are
// returned so they can be applied afterwards. Any error reported by the graph
// is returned instead.
ErrorOr<std::vector<PendingQueryChange>> QuerierImpl::ApplyRecordChanges(
    const MdnsRecord& record,
    RecordChangedEvent event) {
  std::vector<PendingQueryChange> pending_changes;

  std::function<void(DomainName)> on_domain_added =
      [this, &pending_changes](DomainName added) {
        pending_changes.push_back({std::move(added), DnsType::kANY,
                                   DnsClass::kANY, this,
                                   PendingQueryChange::kStartQuery});
      };
  std::function<void(DomainName)> on_domain_removed =
      [this, &pending_changes](DomainName removed) {
        pending_changes.push_back({std::move(removed), DnsType::kANY,
                                   DnsClass::kANY, this,
                                   PendingQueryChange::kStopQuery});
      };

  Error result = graph_->ApplyDataRecordChange(
      record, event, std::move(on_domain_added), std::move(on_domain_removed));
  if (result.ok()) {
    return pending_changes;
  }
  return result;
}
} // namespace discovery
} // namespace openscreen