| // Copyright 2019 The Chromium Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "discovery/dnssd/impl/querier_impl.h" |
| |
| #include <algorithm> |
| #include <string> |
| #include <utility> |
| #include <vector> |
| |
| #include "discovery/common/reporting_client.h" |
| #include "discovery/dnssd/impl/conversion_layer.h" |
| #include "discovery/dnssd/impl/network_interface_config.h" |
| #include "platform/api/task_runner.h" |
| #include "util/osp_logging.h" |
| |
| namespace openscreen { |
| namespace discovery { |
| namespace { |
| |
// Domain under which every DNS-SD service handled by this querier is looked up.
// (Already internal linkage: this constant lives in an anonymous namespace.)
constexpr char kLocalDomain[] = "local";
| |
| // Removes all error instances from the below records, and calls the log |
| // function on all errors present in |new_endpoints|. Input vectors are expected |
| // to be sorted in ascending order. |
| void ProcessErrors(std::vector<ErrorOr<DnsSdInstanceEndpoint>>* old_endpoints, |
| std::vector<ErrorOr<DnsSdInstanceEndpoint>>* new_endpoints, |
| std::function<void(Error)> log) { |
| OSP_DCHECK(old_endpoints); |
| OSP_DCHECK(new_endpoints); |
| |
| auto old_it = old_endpoints->begin(); |
| auto new_it = new_endpoints->begin(); |
| |
| // Iterate across both vectors and log new errors in the process. |
| // NOTE: In sorted order, all errors will appear before all non-errors. |
| while (old_it != old_endpoints->end() && new_it != new_endpoints->end()) { |
| ErrorOr<DnsSdInstanceEndpoint>& old_ep = *old_it; |
| ErrorOr<DnsSdInstanceEndpoint>& new_ep = *new_it; |
| |
| if (new_ep.is_value()) { |
| break; |
| } |
| |
| // If they are equal, the element is in both |old_endpoints| and |
| // |new_endpoints|, so skip it in both vectors. |
| if (old_ep == new_ep) { |
| old_it++; |
| new_it++; |
| continue; |
| } |
| |
| // There's an error in |old_endpoints| not in |new_endpoints|, so skip it. |
| if (old_ep < new_ep) { |
| old_it++; |
| continue; |
| } |
| |
| // There's an error in |new_endpoints| not in |old_endpoints|, so it's a new |
| // error from the applied changes. Log it. |
| log(std::move(new_ep.error())); |
| new_it++; |
| } |
| |
| // Skip all remaining errors in the old vector. |
| for (; old_it != old_endpoints->end() && old_it->is_error(); old_it++) { |
| } |
| |
| // Log all errors remaining in the new vector. |
| for (; new_it != new_endpoints->end() && new_it->is_error(); new_it++) { |
| log(std::move(new_it->error())); |
| } |
| |
| // Erase errors. |
| old_endpoints->erase(old_endpoints->begin(), old_it); |
| new_endpoints->erase(new_endpoints->begin(), new_it); |
| } |
| |
| // Returns a vector containing the value of each ErrorOr<> instance provided. |
| // All ErrorOr<> values are expected to be non-errors. |
| std::vector<DnsSdInstanceEndpoint> GetValues( |
| std::vector<ErrorOr<DnsSdInstanceEndpoint>> endpoints) { |
| std::vector<DnsSdInstanceEndpoint> results; |
| results.reserve(endpoints.size()); |
| for (ErrorOr<DnsSdInstanceEndpoint>& endpoint : endpoints) { |
| results.push_back(std::move(endpoint.value())); |
| } |
| return results; |
| } |
| |
| bool IsEqualOrUpdate(const absl::optional<DnsSdInstanceEndpoint>& first, |
| const absl::optional<DnsSdInstanceEndpoint>& second) { |
| if (!first.has_value() || !second.has_value()) { |
| return !first.has_value() && !second.has_value(); |
| } |
| |
| // In the remaining case, both |first| and |second| must be values. |
| const DnsSdInstanceEndpoint& a = first.value(); |
| const DnsSdInstanceEndpoint& b = second.value(); |
| |
| // All endpoints from this querier should have the same network interface |
| // because the querier is only associated with a single network interface. |
| OSP_DCHECK_EQ(a.network_interface(), b.network_interface()); |
| |
| // Function returns true if first < second. |
| return a.instance_id() == b.instance_id() && |
| a.service_id() == b.service_id() && a.domain_id() == b.domain_id(); |
| } |
| |
| bool IsNotEqualOrUpdate(const absl::optional<DnsSdInstanceEndpoint>& first, |
| const absl::optional<DnsSdInstanceEndpoint>& second) { |
| return !IsEqualOrUpdate(first, second); |
| } |
| |
| // Calculates the created, updated, and deleted elements using the provided |
| // sets, appending these values to the provided vectors. Each of the input |
| // vectors is expected to contain only elements such that |
| // |element|.is_error() == false. Additionally, input vectors are expected to |
| // be sorted in ascending order. |
| // |
| // NOTE: A lot of operations are used to do this, but each is only O(n) so the |
| // resulting algorithm is still fast. |
| void CalculateChangeSets(std::vector<DnsSdInstanceEndpoint> old_endpoints, |
| std::vector<DnsSdInstanceEndpoint> new_endpoints, |
| std::vector<DnsSdInstanceEndpoint>* created_out, |
| std::vector<DnsSdInstanceEndpoint>* updated_out, |
| std::vector<DnsSdInstanceEndpoint>* deleted_out) { |
| OSP_DCHECK(created_out); |
| OSP_DCHECK(updated_out); |
| OSP_DCHECK(deleted_out); |
| |
| // Use set difference with default operators to find the elements present in |
| // one list but not the others. |
| // |
| // NOTE: Because absl::optional<...> types are used here and below, calls to |
| // the ctor and dtor for empty elements are no-ops. |
| const int total_count = old_endpoints.size() + new_endpoints.size(); |
| |
| // This is the set of elements that aren't in the old endpoints, meaning the |
| // old endpoint either didn't exist or had different TXT / Address / etc.. |
| std::vector<absl::optional<DnsSdInstanceEndpoint>> created_or_updated( |
| total_count); |
| auto new_end = std::set_difference(new_endpoints.begin(), new_endpoints.end(), |
| old_endpoints.begin(), old_endpoints.end(), |
| created_or_updated.begin()); |
| created_or_updated.erase(new_end, created_or_updated.end()); |
| |
| // This is the set of elements that are only in the old endpoints, similar to |
| // the above. |
| std::vector<absl::optional<DnsSdInstanceEndpoint>> deleted_or_updated( |
| total_count); |
| new_end = std::set_difference(old_endpoints.begin(), old_endpoints.end(), |
| new_endpoints.begin(), new_endpoints.end(), |
| deleted_or_updated.begin()); |
| deleted_or_updated.erase(new_end, deleted_or_updated.end()); |
| |
| // Next, find the elements which were updated. |
| const size_t max_count = |
| std::max(created_or_updated.size(), deleted_or_updated.size()); |
| std::vector<absl::optional<DnsSdInstanceEndpoint>> updated(max_count); |
| new_end = std::set_intersection( |
| created_or_updated.begin(), created_or_updated.end(), |
| deleted_or_updated.begin(), deleted_or_updated.end(), updated.begin(), |
| IsNotEqualOrUpdate); |
| updated.erase(new_end, updated.end()); |
| |
| // Use the updated elements to find all created and deleted elements. |
| std::vector<absl::optional<DnsSdInstanceEndpoint>> created( |
| created_or_updated.size()); |
| new_end = std::set_difference( |
| created_or_updated.begin(), created_or_updated.end(), updated.begin(), |
| updated.end(), created.begin(), IsNotEqualOrUpdate); |
| created.erase(new_end, created.end()); |
| |
| std::vector<absl::optional<DnsSdInstanceEndpoint>> deleted( |
| deleted_or_updated.size()); |
| new_end = std::set_difference( |
| deleted_or_updated.begin(), deleted_or_updated.end(), updated.begin(), |
| updated.end(), deleted.begin(), IsNotEqualOrUpdate); |
| deleted.erase(new_end, deleted.end()); |
| |
| // Return the calculated elements back to the caller in the output variables. |
| created_out->reserve(created.size()); |
| for (absl::optional<DnsSdInstanceEndpoint>& endpoint : created) { |
| OSP_DCHECK(endpoint.has_value()); |
| created_out->push_back(std::move(endpoint.value())); |
| } |
| |
| updated_out->reserve(updated.size()); |
| for (absl::optional<DnsSdInstanceEndpoint>& endpoint : updated) { |
| OSP_DCHECK(endpoint.has_value()); |
| updated_out->push_back(std::move(endpoint.value())); |
| } |
| |
| deleted_out->reserve(deleted.size()); |
| for (absl::optional<DnsSdInstanceEndpoint>& endpoint : deleted) { |
| OSP_DCHECK(endpoint.has_value()); |
| deleted_out->push_back(std::move(endpoint.value())); |
| } |
| } |
| |
| } // namespace |
| |
| QuerierImpl::QuerierImpl(MdnsService* mdns_querier, |
| TaskRunner* task_runner, |
| ReportingClient* reporting_client, |
| const NetworkInterfaceConfig* network_config) |
| : mdns_querier_(mdns_querier), |
| task_runner_(task_runner), |
| reporting_client_(reporting_client) { |
| OSP_DCHECK(mdns_querier_); |
| OSP_DCHECK(task_runner_); |
| |
| OSP_DCHECK(network_config); |
| graph_ = DnsDataGraph::Create(network_config->network_interface()); |
| } |
| |
| QuerierImpl::~QuerierImpl() = default; |
| |
| void QuerierImpl::StartQuery(const std::string& service, Callback* callback) { |
| OSP_DCHECK(callback); |
| OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); |
| |
| OSP_DVLOG << "Starting DNS-SD query for service '" << service << "'"; |
| |
| // Start tracking the new callback |
| const ServiceKey key(service, kLocalDomain); |
| auto it = callback_map_.emplace(key, std::vector<Callback*>{}).first; |
| it->second.push_back(callback); |
| |
| const DomainName domain = key.GetName(); |
| |
| // If the associated service isn't tracked yet, start tracking it and start |
| // queries for the relevant PTR records. |
| if (!graph_->IsTracked(domain)) { |
| std::function<void(const DomainName&)> mdns_query( |
| [this, &domain](const DomainName& changed_domain) { |
| OSP_DVLOG << "Starting mDNS query for '" << domain.ToString() << "'"; |
| mdns_querier_->StartQuery(changed_domain, DnsType::kANY, |
| DnsClass::kANY, this); |
| }); |
| graph_->StartTracking(domain, std::move(mdns_query)); |
| return; |
| } |
| |
| // Else, it's already being tracked so fire creation callbacks for any already |
| // found service instances. |
| const std::vector<ErrorOr<DnsSdInstanceEndpoint>> endpoints = |
| graph_->CreateEndpoints(DnsDataGraph::DomainGroup::kPtr, domain); |
| for (const auto& endpoint : endpoints) { |
| if (endpoint.is_value()) { |
| callback->OnEndpointCreated(endpoint.value()); |
| } |
| } |
| } |
| |
| void QuerierImpl::StopQuery(const std::string& service, Callback* callback) { |
| OSP_DCHECK(callback); |
| OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); |
| |
| OSP_DVLOG << "Stopping DNS-SD query for service '" << service << "'"; |
| |
| ServiceKey key(service, kLocalDomain); |
| const auto callbacks_it = callback_map_.find(key); |
| if (callbacks_it == callback_map_.end()) { |
| return; |
| } |
| |
| std::vector<Callback*>& callbacks = callbacks_it->second; |
| const auto it = std::find(callbacks.begin(), callbacks.end(), callback); |
| if (it == callbacks.end()) { |
| return; |
| } |
| |
| callbacks.erase(it); |
| if (callbacks.empty()) { |
| callback_map_.erase(callbacks_it); |
| |
| ServiceKey key(service, kLocalDomain); |
| DomainName domain = key.GetName(); |
| |
| std::function<void(const DomainName&)> stop_mdns_query( |
| [this](const DomainName& changed_domain) { |
| OSP_DVLOG << "Stopping mDNS query for '" << changed_domain.ToString() |
| << "'"; |
| mdns_querier_->StopQuery(changed_domain, DnsType::kANY, |
| DnsClass::kANY, this); |
| }); |
| graph_->StopTracking(domain, std::move(stop_mdns_query)); |
| } |
| } |
| |
| bool QuerierImpl::IsQueryRunning(const std::string& service) const { |
| OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); |
| const ServiceKey key(service, kLocalDomain); |
| return graph_->IsTracked(key.GetName()); |
| } |
| |
| void QuerierImpl::ReinitializeQueries(const std::string& service) { |
| OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); |
| |
| OSP_DVLOG << "Re-initializing query for service '" << service << "'"; |
| |
| const ServiceKey key(service, kLocalDomain); |
| const DomainName domain = key.GetName(); |
| |
| std::function<void(const DomainName&)> start_callback( |
| [this](const DomainName& domain) { |
| mdns_querier_->StartQuery(domain, DnsType::kANY, DnsClass::kANY, this); |
| }); |
| std::function<void(const DomainName&)> stop_callback( |
| [this](const DomainName& domain) { |
| mdns_querier_->StopQuery(domain, DnsType::kANY, DnsClass::kANY, this); |
| }); |
| graph_->StopTracking(domain, std::move(stop_callback)); |
| |
| // Restart top-level queries. |
| mdns_querier_->ReinitializeQueries(GetPtrQueryInfo(key).name); |
| |
| graph_->StartTracking(domain, std::move(start_callback)); |
| } |
| |
| std::vector<PendingQueryChange> QuerierImpl::OnRecordChanged( |
| const MdnsRecord& record, |
| RecordChangedEvent event) { |
| OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); |
| |
| OSP_DVLOG << "Record " << record.ToString() |
| << " has received change of type '" << event << "'"; |
| |
| std::function<void(Error)> log = [this](Error error) mutable { |
| reporting_client_->OnRecoverableError( |
| Error(Error::Code::kProcessReceivedRecordFailure)); |
| }; |
| |
| // Get the details to use for calling CreateEndpoints(). Special case PTR |
| // records to optimize performance. |
| const DomainName& create_endpoints_domain = |
| record.dns_type() != DnsType::kPTR |
| ? record.name() |
| : absl::get<PtrRecordRdata>(record.rdata()).ptr_domain(); |
| const DnsDataGraph::DomainGroup create_endpoints_group = |
| record.dns_type() != DnsType::kPTR |
| ? DnsDataGraph::GetDomainGroup(record) |
| : DnsDataGraph::DomainGroup::kSrvAndTxt; |
| |
| // Get the current set of DnsSdInstanceEndpoints prior to this change. Special |
| // case PTR records to avoid iterating over unrelated child domains. |
| std::vector<ErrorOr<DnsSdInstanceEndpoint>> old_endpoints_or_errors = |
| graph_->CreateEndpoints(create_endpoints_group, create_endpoints_domain); |
| |
| // Apply the changes, creating a list of all pending changes that should be |
| // applied afterwards. |
| ErrorOr<std::vector<PendingQueryChange>> pending_changes_or_error = |
| ApplyRecordChanges(record, event); |
| if (pending_changes_or_error.is_error()) { |
| OSP_DVLOG << "Failed to apply changes for " << record.dns_type() |
| << " record change of type " << event << " with error " |
| << pending_changes_or_error.error(); |
| log(std::move(pending_changes_or_error.error())); |
| return {}; |
| } |
| std::vector<PendingQueryChange>& pending_changes = |
| pending_changes_or_error.value(); |
| |
| // Get the new set of DnsSdInstanceEndpoints following this change. |
| std::vector<ErrorOr<DnsSdInstanceEndpoint>> new_endpoints_or_errors = |
| graph_->CreateEndpoints(create_endpoints_group, create_endpoints_domain); |
| |
| // Return early if the resulting sets are equal. This will frequently be the |
| // case, especially when both sets are empty. |
| std::sort(old_endpoints_or_errors.begin(), old_endpoints_or_errors.end()); |
| std::sort(new_endpoints_or_errors.begin(), new_endpoints_or_errors.end()); |
| if (old_endpoints_or_errors.size() == new_endpoints_or_errors.size() && |
| std::equal(old_endpoints_or_errors.begin(), old_endpoints_or_errors.end(), |
| new_endpoints_or_errors.begin())) { |
| return pending_changes; |
| } |
| |
| // Log all errors and erase them. |
| ProcessErrors(&old_endpoints_or_errors, &new_endpoints_or_errors, |
| std::move(log)); |
| const size_t old_endpoints_or_errors_count = old_endpoints_or_errors.size(); |
| const size_t new_endpoints_or_errors_count = new_endpoints_or_errors.size(); |
| std::vector<DnsSdInstanceEndpoint> old_endpoints = |
| GetValues(std::move(old_endpoints_or_errors)); |
| std::vector<DnsSdInstanceEndpoint> new_endpoints = |
| GetValues(std::move(new_endpoints_or_errors)); |
| OSP_DCHECK_EQ(old_endpoints.size(), old_endpoints_or_errors_count); |
| OSP_DCHECK_EQ(new_endpoints.size(), new_endpoints_or_errors_count); |
| |
| // Calculate the changes and call callbacks. |
| // |
| // NOTE: As the input sets are expected to be small, the generated sets will |
| // also be small. |
| std::vector<DnsSdInstanceEndpoint> created; |
| std::vector<DnsSdInstanceEndpoint> updated; |
| std::vector<DnsSdInstanceEndpoint> deleted; |
| CalculateChangeSets(std::move(old_endpoints), std::move(new_endpoints), |
| &created, &updated, &deleted); |
| |
| InvokeChangeCallbacks(std::move(created), std::move(updated), |
| std::move(deleted)); |
| return pending_changes; |
| } |
| |
| void QuerierImpl::InvokeChangeCallbacks( |
| std::vector<DnsSdInstanceEndpoint> created, |
| std::vector<DnsSdInstanceEndpoint> updated, |
| std::vector<DnsSdInstanceEndpoint> deleted) { |
| // Find an endpoint and use it to create the key, or return if there is none. |
| DnsSdInstanceEndpoint* some_endpoint; |
| if (!created.empty()) { |
| some_endpoint = &created.front(); |
| } else if (!updated.empty()) { |
| some_endpoint = &updated.front(); |
| } else if (!deleted.empty()) { |
| some_endpoint = &deleted.front(); |
| } else { |
| return; |
| } |
| ServiceKey key(some_endpoint->service_id(), some_endpoint->domain_id()); |
| |
| // Find all callbacks. |
| auto it = callback_map_.find(key); |
| if (it == callback_map_.end()) { |
| return; |
| } |
| |
| // Call relevant callbacks. |
| std::vector<Callback*>& callbacks = it->second; |
| for (Callback* callback : callbacks) { |
| for (const DnsSdInstanceEndpoint& endpoint : created) { |
| callback->OnEndpointCreated(endpoint); |
| } |
| for (const DnsSdInstanceEndpoint& endpoint : updated) { |
| callback->OnEndpointUpdated(endpoint); |
| } |
| for (const DnsSdInstanceEndpoint& endpoint : deleted) { |
| callback->OnEndpointDeleted(endpoint); |
| } |
| } |
| } |
| |
| ErrorOr<std::vector<PendingQueryChange>> QuerierImpl::ApplyRecordChanges( |
| const MdnsRecord& record, |
| RecordChangedEvent event) { |
| std::vector<PendingQueryChange> pending_changes; |
| std::function<void(DomainName)> creation_callback( |
| [this, &pending_changes](DomainName domain) mutable { |
| pending_changes.push_back({std::move(domain), DnsType::kANY, |
| DnsClass::kANY, this, |
| PendingQueryChange::kStartQuery}); |
| }); |
| std::function<void(DomainName)> deletion_callback( |
| [this, &pending_changes](DomainName domain) mutable { |
| pending_changes.push_back({std::move(domain), DnsType::kANY, |
| DnsClass::kANY, this, |
| PendingQueryChange::kStopQuery}); |
| }); |
| Error result = |
| graph_->ApplyDataRecordChange(record, event, std::move(creation_callback), |
| std::move(deletion_callback)); |
| if (!result.ok()) { |
| return result; |
| } |
| |
| return pending_changes; |
| } |
| |
| } // namespace discovery |
| } // namespace openscreen |