blob: 079dbf7d79f8f2a305b5f03b2cae7c6919a7efe6 [file] [log] [blame]
// Copyright 2020 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "discovery/dnssd/impl/dns_data_graph.h"
#include <utility>
#include "discovery/dnssd/impl/conversion_layer.h"
#include "discovery/dnssd/impl/instance_key.h"
namespace openscreen {
namespace discovery {
namespace {
// Assembles the user-visible DnsSdInstanceEndpoint for the service instance
// named by |domain|.
//
// |a| and |aaaa| are the optional IPv4 / IPv6 address rdata; each one that is
// present contributes one IPEndpoint using the port from |srv|. |txt| is
// converted with CreateFromDnsTxt(), and a conversion failure is returned as
// the error of the result.
ErrorOr<DnsSdInstanceEndpoint> CreateEndpoint(
    const DomainName& domain,
    const absl::optional<ARecordRdata>& a,
    const absl::optional<AAAARecordRdata>& aaaa,
    const SrvRecordRdata& srv,
    const TxtRecordRdata& txt,
    NetworkInterfaceIndex network_interface) {
  // Convert the raw TXT rdata first; nothing else is built if this fails.
  ErrorOr<DnsSdTxtRecord> parsed_txt = CreateFromDnsTxt(txt);
  if (parsed_txt.is_error()) {
    return parsed_txt.error();
  }

  // Splits |domain| into its instance / service / domain components.
  InstanceKey key(domain);

  // IPv4 (if any) is listed before IPv6 (if any).
  std::vector<IPEndpoint> addresses;
  if (a) {
    addresses.push_back({a->ipv4_address(), srv.port()});
  }
  if (aaaa) {
    addresses.push_back({aaaa->ipv6_address(), srv.port()});
  }

  return DnsSdInstanceEndpoint(key.instance_id(), key.service_id(),
                               key.domain_id(), std::move(parsed_txt.value()),
                               network_interface, std::move(addresses));
}
// Concrete DnsDataGraph. Each tracked domain name owns one Node in |nodes_|;
// PTR and SRV records stored in a node create directed edges to the node of
// the domain they point at. Nodes reachable only through such edges are
// created and destroyed automatically as records are added and removed.
class DnsDataGraphImpl : public DnsDataGraph {
 public:
  using DnsDataGraph::DomainChangeCallback;

  explicit DnsDataGraphImpl(NetworkInterfaceIndex network_interface)
      : network_interface_(network_interface) {}
  DnsDataGraphImpl(const DnsDataGraphImpl& other) = delete;
  DnsDataGraphImpl(DnsDataGraphImpl&& other) = delete;

  // Only raises |is_dtor_running_|; the nodes themselves are destroyed
  // afterwards together with |nodes_|, and consult the flag to skip their
  // edge bookkeeping and deletion callbacks.
  ~DnsDataGraphImpl() override { is_dtor_running_ = true; }

  DnsDataGraphImpl& operator=(const DnsDataGraphImpl& rhs) = delete;
  DnsDataGraphImpl& operator=(DnsDataGraphImpl&& rhs) = delete;

  // DnsDataGraph overrides.
  void StartTracking(const DomainName& domain,
                     DomainChangeCallback on_start_tracking) override;
  void StopTracking(const DomainName& domain,
                    DomainChangeCallback on_stop_tracking) override;
  std::vector<ErrorOr<DnsSdInstanceEndpoint>> CreateEndpoints(
      DomainGroup domain_group,
      const DomainName& name) const override;
  Error ApplyDataRecordChange(MdnsRecord record,
                              RecordChangedEvent event,
                              DomainChangeCallback on_start_tracking,
                              DomainChangeCallback on_stop_tracking) override;

  // Number of domains that currently have a node in the graph.
  size_t GetTrackedDomainCount() const override { return nodes_.size(); }

  // Whether |name| currently has a node in the graph.
  bool IsTracked(const DomainName& name) const override {
    return nodes_.find(name) != nodes_.end();
  }

 private:
  class NodeLifetimeHandler;

  using ScopedCallbackHandler = std::unique_ptr<NodeLifetimeHandler>;

  // A single node of the graph represented by this type.
  class Node {
   public:
    // NOTE: This class is non-copyable, non-movable because either operation
    // would invalidate the pointer references or bidirectional edge states
    // maintained by instances of this class.
    Node(DomainName name, DnsDataGraphImpl* graph);
    Node(const Node& other) = delete;
    Node(Node&& other) = delete;

    ~Node();

    Node& operator=(const Node& rhs) = delete;
    Node& operator=(Node&& rhs) = delete;

    // Applies a record change for this node.
    Error ApplyDataRecordChange(MdnsRecord record, RecordChangedEvent event);

    // Returns a copy of the rdata of the first record in this node's
    // |records_| whose type matches |type|, or absl::nullopt if no such record
    // exists. |T| must be the rdata alternative actually held by that record.
    template <typename T>
    absl::optional<T> GetRdata(DnsType type) {
      auto it = FindRecord(type);
      if (it == records_.end()) {
        return absl::nullopt;
      }
      return absl::get<T>(it->rdata());
    }

    const DomainName& name() const { return name_; }
    const std::vector<Node*>& parents() const { return parents_; }
    const std::vector<Node*>& children() const { return children_; }
    const std::vector<MdnsRecord>& records() const { return records_; }

   private:
    // Adds or removes an edge in |graph_|.
    // NOTE: The same edge may be added multiple times, and one call to remove
    // is needed for every such call.
    void AddChild(Node* child);
    void RemoveChild(Node* child);

    // Applies the specified change to domain |child| for this node.
    void ApplyChildChange(DomainName child_name, RecordChangedEvent event);

    // Finds an iterator to the record of the provided type, or to
    // records_.end() if no such record exists.
    std::vector<MdnsRecord>::iterator FindRecord(DnsType type);

    // The domain with which the data records stored in this node are
    // associated.
    const DomainName name_;

    // Currently extant mDNS Records at |name_|.
    std::vector<MdnsRecord> records_;

    // Nodes which contain records pointing to this node's |name|.
    std::vector<Node*> parents_;

    // Nodes containing records pointed to by the records in this node.
    std::vector<Node*> children_;

    // Graph containing this node.
    DnsDataGraphImpl* graph_;
  };

  // Wrapper to handle the creation and deletion callbacks. When the object is
  // created, it sets the callback to use, and erases the callback when it goes
  // out of scope. This class allows all node creations to complete before
  // calling the user-provided callback to ensure there are no race-conditions.
  class NodeLifetimeHandler {
   public:
    NodeLifetimeHandler(DomainChangeCallback* callback_ptr,
                        DomainChangeCallback callback);

    // NOTE: The copy and move ctors and operators must be deleted because
    // they would invalidate the pointer logic used here (the callback
    // installed through |callback_ptr_| captures |this|).
    NodeLifetimeHandler(const NodeLifetimeHandler& other) = delete;
    NodeLifetimeHandler(NodeLifetimeHandler&& other) = delete;

    ~NodeLifetimeHandler();

    NodeLifetimeHandler& operator=(const NodeLifetimeHandler& other) = delete;
    NodeLifetimeHandler& operator=(NodeLifetimeHandler&& other) = delete;

   private:
    // Domains reported through |*callback_ptr_| while this handler is alive,
    // in the order they were reported.
    std::vector<DomainName> domains_changed;

    // Slot in the owning graph that this handler fills on construction and
    // clears on destruction.
    DomainChangeCallback* callback_ptr_;

    // User-provided callback, invoked once per recorded domain on destruction.
    DomainChangeCallback callback_;
  };

  // Helpers to create the ScopedCallbackHandlers for creation and deletion
  // callbacks.
  ScopedCallbackHandler GetScopedCreationHandler(
      DomainChangeCallback creation_callback);
  ScopedCallbackHandler GetScopedDeletionHandler(
      DomainChangeCallback deletion_callback);

  // Determines whether the provided node has the necessary records to be a
  // valid node at the specified domain level.
  static bool IsValidAddressNode(Node* node);
  static bool IsValidSrvAndTxtNode(Node* node);

  // Calculates the set of DnsSdInstanceEndpoints associated with the PTR
  // records present at the given |node|.
  std::vector<ErrorOr<DnsSdInstanceEndpoint>> CalculatePtrRecordEndpoints(
      Node* node) const;

  // Denotes whether the dtor for this instance has been called. This is
  // required for validation of Node instance functionality. See the
  // implementation of DnsDataGraphImpl::Node::~Node() for more details.
  bool is_dtor_running_ = false;

  // Map from domain name to the node containing all records associated with the
  // name.
  std::map<DomainName, std::unique_ptr<Node>> nodes_;

  // Network interface stamped onto every endpoint created by this graph.
  const NetworkInterfaceIndex network_interface_;

  // The methods to be called when a domain name either starts or stops being
  // referenced. These will only be set when a record change is ongoing, and act
  // as a single source of truth for the creation and deletion callbacks that
  // should be used during that operation.
  DomainChangeCallback on_node_creation_;
  DomainChangeCallback on_node_deletion_;
};
// Creates the node for |name| inside |graph| and immediately reports the new
// domain through the graph's creation callback. A creation handler must
// therefore be installed on |graph| before any Node is constructed.
DnsDataGraphImpl::Node::Node(DomainName name, DnsDataGraphImpl* graph)
    : name_(std::move(name)), graph_(graph) {
  OSP_DCHECK(graph_);

  // Mirrors the check done for |on_node_deletion_| in ~Node(): invoking an
  // unset callback is a programming error in the caller.
  OSP_DCHECK(graph_->on_node_creation_);
  graph_->on_node_creation_(name_);
}
// Detaches this node from all of its children and reports the deletion of
// |name_| through the graph's deletion callback. Both steps are skipped when
// the whole graph is being destroyed.
DnsDataGraphImpl::Node::~Node() {
  // A node should only be deleted when it has no parents. The only case where
  // a deletion can occur when parents are still extant is during destruction of
  // the holding graph. In that case, the state of the graph no longer matters
  // and all nodes will be deleted, so no need to consider the child pointers.
  if (graph_->is_dtor_running_) {
    return;
  }

  // A self-edge (this node listed as its own parent) is tolerated.
  auto it = std::find_if(parents_.begin(), parents_.end(),
                         [this](Node* parent) { return parent != this; });
  OSP_DCHECK(it == parents_.end());

  // Erase all childrens' parent pointers to this node. Every RemoveChild()
  // call erases one entry from |children_|, so iterating over |children_|
  // directly would use invalidated iterators (skipping children and reading
  // stale slots). Drain the vector from the front instead, which also keeps
  // the front-to-back removal order.
  while (!children_.empty()) {
    RemoveChild(children_.front());
  }

  OSP_DCHECK(graph_->on_node_deletion_);
  graph_->on_node_deletion_(name_);
}
// Applies |event| for |record| to this node's record list and, for PTR and SRV
// records, to the edge leading to the domain the record points at.
//
// Returns kItemAlreadyExists when creating a record that is already present,
// kItemNotFound when updating or expiring one that is absent, and
// Error::None() otherwise. |record.name()| must equal this node's name.
Error DnsDataGraphImpl::Node::ApplyDataRecordChange(MdnsRecord record,
                                                    RecordChangedEvent event) {
  OSP_DCHECK(record.name() == name_);

  // The child domain to which the changed record points, or none. This is only
  // applicable for PTR and SRV records, and is empty in all other cases.
  DomainName child_name;

  // The location of the current record. In the case of PTR records, multiple
  // records are allowed for the same domain. In all other cases, this is not
  // valid.
  std::vector<MdnsRecord>::iterator it;
  if (record.dns_type() == DnsType::kPTR) {
    child_name = absl::get<PtrRecordRdata>(record.rdata()).ptr_domain();
    // |record| is only read while searching, so capture it by reference
    // rather than copying the whole record into the predicate.
    it = std::find_if(records_.begin(), records_.end(),
                      [&record](const MdnsRecord& rhs) {
                        return record.IsReannouncementOf(rhs);
                      });
  } else {
    if (record.dns_type() == DnsType::kSRV) {
      child_name = absl::get<SrvRecordRdata>(record.rdata()).target();
    }
    it = FindRecord(record.dns_type());
  }

  // Validate that the requested change is allowed and apply it.
  switch (event) {
    case RecordChangedEvent::kCreated:
      if (it != records_.end()) {
        return Error::Code::kItemAlreadyExists;
      }
      records_.push_back(std::move(record));
      break;
    case RecordChangedEvent::kUpdated:
      if (it == records_.end()) {
        return Error::Code::kItemNotFound;
      }
      *it = std::move(record);
      break;
    case RecordChangedEvent::kExpired:
      if (it == records_.end()) {
        return Error::Code::kItemNotFound;
      }
      records_.erase(it);
      break;
  }

  // Apply any required edge changes to the graph. This is only applicable if
  // a |child| was found earlier. Note that the same child can be added multiple
  // times to the |children_| vector, which simplifies the code dramatically.
  // NOTE(review): for kUpdated the edge is left untouched, so an updated SRV
  // record with a different target keeps pointing at the old child node --
  // confirm that the caller never delivers such an update.
  if (!child_name.empty()) {
    ApplyChildChange(std::move(child_name), event);
  }

  return Error::None();
}
// Adds (kCreated) or removes (kExpired) one edge from this node to the node of
// |child_name|. Any other event leaves the graph untouched. On creation the
// child node is created first if the graph does not have one yet.
void DnsDataGraphImpl::Node::ApplyChildChange(DomainName child_name,
                                              RecordChangedEvent event) {
  if (event == RecordChangedEvent::kExpired) {
    const auto existing = graph_->nodes_.find(child_name);
    OSP_DCHECK(existing != graph_->nodes_.end());
    RemoveChild(existing->second.get());
    return;
  }

  if (event != RecordChangedEvent::kCreated) {
    return;
  }

  // Reserve the map slot with a null pointer first, and only build the Node
  // (whose constructor fires the creation callback) when the slot is new.
  const auto insertion =
      graph_->nodes_.emplace(child_name, std::unique_ptr<Node>());
  std::unique_ptr<Node>& slot = insertion.first->second;
  if (insertion.second) {
    slot = std::make_unique<Node>(std::move(child_name), graph_);
  }
  AddChild(slot.get());
}
// Records one edge this -> |child| by appending to both endpoints' lists. No
// de-duplication is done, so repeated calls stack up repeated entries.
void DnsDataGraphImpl::Node::AddChild(Node* child) {
  OSP_DCHECK(child);

  // Back edge first, then the forward edge; the two lists are independent.
  child->parents_.push_back(this);
  children_.push_back(child);
}
// Removes one instance of the edge this -> |child| from both endpoints. If
// |child| is then left with no parent other than itself, its entry is erased
// from the graph's node map.
void DnsDataGraphImpl::Node::RemoveChild(Node* child) {
  OSP_DCHECK(child);

  // Forward edge: first matching entry in this node's child list.
  const auto forward_edge =
      std::find(children_.begin(), children_.end(), child);
  OSP_DCHECK(forward_edge != children_.end());
  children_.erase(forward_edge);

  // Back edge: first matching entry in the child's parent list.
  std::vector<Node*>& child_parents = child->parents_;
  const auto back_edge =
      std::find(child_parents.begin(), child_parents.end(), this);
  OSP_DCHECK(back_edge != child_parents.end());
  child_parents.erase(back_edge);

  // If the node has been orphaned, remove it. A self-edge does not count as a
  // parent that keeps the node alive.
  const auto other_parent =
      std::find_if(child_parents.begin(), child_parents.end(),
                   [child](Node* parent) { return parent != child; });
  if (other_parent != child_parents.end()) {
    return;
  }

  // Look up by a copy of the name: the key must stay valid while the erase
  // destroys |child|.
  const DomainName orphan_name = child->name();
  const size_t erased = graph_->nodes_.erase(orphan_name);
  OSP_DCHECK(child == this || erased);
}
// Returns an iterator to the first stored record whose DNS type equals |type|,
// or records_.end() when there is none.
std::vector<MdnsRecord>::iterator DnsDataGraphImpl::Node::FindRecord(
    DnsType type) {
  const auto has_requested_type = [type](const MdnsRecord& candidate) {
    return candidate.dns_type() == type;
  };
  return std::find_if(records_.begin(), records_.end(), has_requested_type);
}
// Installs a recording callback into |*callback_ptr| and keeps |callback| to
// be invoked later, from the destructor, once per recorded domain.
//
// |callback_ptr| must be non-null and must currently hold no callback;
// |callback| must be callable.
DnsDataGraphImpl::NodeLifetimeHandler::NodeLifetimeHandler(
    DomainChangeCallback* callback_ptr,
    DomainChangeCallback callback)
    : callback_ptr_(callback_ptr), callback_(std::move(callback)) {
  OSP_DCHECK(callback_ptr_);
  // |callback| has been moved from, so validate the stored copy instead.
  OSP_DCHECK(callback_);
  OSP_DCHECK(*callback_ptr_ == nullptr);

  // The installed callback captures |this|, which is why this type is neither
  // copyable nor movable.
  *callback_ptr_ = [this](DomainName domain) {
    domains_changed.push_back(std::move(domain));
  };
}
// Clears the slot this handler installed itself into, then replays every
// domain recorded during its lifetime to the user-provided callback, in the
// order the domains were recorded.
DnsDataGraphImpl::NodeLifetimeHandler::~NodeLifetimeHandler() {
  // Uninstall first so nothing can be recorded while replaying.
  *callback_ptr_ = nullptr;

  for (DomainName& recorded_domain : domains_changed) {
    callback_(recorded_domain);
  }
}
// Returns a handler bound to |on_node_creation_|: node creations are recorded
// while the handler lives and are forwarded to |creation_callback| when it is
// destroyed.
DnsDataGraphImpl::ScopedCallbackHandler
DnsDataGraphImpl::GetScopedCreationHandler(
    DomainChangeCallback creation_callback) {
  ScopedCallbackHandler handler = std::make_unique<NodeLifetimeHandler>(
      &on_node_creation_, std::move(creation_callback));
  return handler;
}
// Returns a handler bound to |on_node_deletion_|: node deletions are recorded
// while the handler lives and are forwarded to |deletion_callback| when it is
// destroyed.
DnsDataGraphImpl::ScopedCallbackHandler
DnsDataGraphImpl::GetScopedDeletionHandler(
    DomainChangeCallback deletion_callback) {
  ScopedCallbackHandler handler = std::make_unique<NodeLifetimeHandler>(
      &on_node_deletion_, std::move(deletion_callback));
  return handler;
}
void DnsDataGraphImpl::StartTracking(const DomainName& domain,
DomainChangeCallback on_start_tracking) {
ScopedCallbackHandler creation_handler =
GetScopedCreationHandler(std::move(on_start_tracking));
auto pair = nodes_.emplace(domain, std::make_unique<Node>(domain, this));
OSP_DCHECK(pair.second);
OSP_DCHECK(nodes_.find(domain) != nodes_.end());
}
// Destroys the node for |domain| and removes its entry from the graph.
// |on_stop_tracking| is invoked, when this call finishes, for every domain
// whose node was deleted as a result. |domain| must currently be tracked and
// is expected to have no parents.
void DnsDataGraphImpl::StopTracking(const DomainName& domain,
                                    DomainChangeCallback on_stop_tracking) {
  // Must exist before any Node is destroyed: Node destructors report
  // themselves through the deletion callback this handler installs.
  ScopedCallbackHandler deletion_scope =
      GetScopedDeletionHandler(std::move(on_stop_tracking));

  auto tracked = nodes_.find(domain);
  OSP_CHECK(tracked != nodes_.end());
  OSP_DCHECK(tracked->second->parents().empty());

  // Destroy the node while its map entry still exists, then drop the entry
  // with a fresh lookup by key.
  tracked->second.reset();
  const size_t removed_entries = nodes_.erase(domain);
  OSP_DCHECK(removed_entries);
}
// Routes |record| / |event| to the node tracking |record.name()|.
//
// Returns kOperationCancelled when that name is not tracked; otherwise returns
// whatever the node reports. Domains created or deleted as a side effect are
// reported to |on_start_tracking| / |on_stop_tracking| when this call
// finishes.
Error DnsDataGraphImpl::ApplyDataRecordChange(
    MdnsRecord record,
    RecordChangedEvent event,
    DomainChangeCallback on_start_tracking,
    DomainChangeCallback on_stop_tracking) {
  // Both handlers are installed up front, before the lookup, and are torn
  // down in reverse order when the function returns.
  ScopedCallbackHandler creation_scope =
      GetScopedCreationHandler(std::move(on_start_tracking));
  ScopedCallbackHandler deletion_scope =
      GetScopedDeletionHandler(std::move(on_stop_tracking));

  const auto owner = nodes_.find(record.name());
  if (owner != nodes_.end()) {
    return owner->second->ApplyDataRecordChange(std::move(record), event);
  }
  return Error::Code::kOperationCancelled;
}
// Builds every DnsSdInstanceEndpoint that can be derived from the node at
// |name|, interpreting that node according to |domain_group|:
//  - kAddress:   |name| holds the A/AAAA records; its parents are candidate
//                SRV+TXT nodes.
//  - kSrvAndTxt: |name| holds the SRV and TXT records; its children are
//                candidate address nodes.
//  - kPtr:       delegates to CalculatePtrRecordEndpoints().
// Returns an empty vector when |name| is not in the graph, when the node lacks
// the records required for |domain_group|, or for any other group value.
std::vector<ErrorOr<DnsSdInstanceEndpoint>> DnsDataGraphImpl::CreateEndpoints(
    DomainGroup domain_group,
    const DomainName& name) const {
  const auto it = nodes_.find(name);
  if (it == nodes_.end()) {
    return {};
  }
  Node* target_node = it->second.get();

  // NOTE: One of these will contain no more than one element, so iterating over
  // them both will be fast.
  std::vector<Node*> srv_and_txt_record_nodes;
  std::vector<Node*> address_record_nodes;

  switch (domain_group) {
    case DomainGroup::kAddress:
      if (!IsValidAddressNode(target_node)) {
        return {};
      }
      address_record_nodes.push_back(target_node);
      srv_and_txt_record_nodes = target_node->parents();
      break;

    case DomainGroup::kSrvAndTxt:
      if (!IsValidSrvAndTxtNode(target_node)) {
        return {};
      }
      srv_and_txt_record_nodes.push_back(target_node);
      address_record_nodes = target_node->children();
      break;

    case DomainGroup::kPtr:
      return CalculatePtrRecordEndpoints(target_node);

    default:
      return {};
  }

  // Iterate across all node pairs and create all possible DnsSdInstanceEndpoint
  // objects.
  std::vector<ErrorOr<DnsSdInstanceEndpoint>> endpoints;
  for (Node* srv_and_txt : srv_and_txt_record_nodes) {
    // The SRV and TXT rdata depend only on the outer node, so look them up
    // (and copy them) once per node rather than once per address pairing.
    //
    // A SRV record has to be present to provide the port number, and a TXT
    // record must be present to provide additional connection information
    // about the service per RFC 6763.
    const absl::optional<SrvRecordRdata> srv =
        srv_and_txt->GetRdata<SrvRecordRdata>(DnsType::kSRV);
    const absl::optional<TxtRecordRdata> txt =
        srv_and_txt->GetRdata<TxtRecordRdata>(DnsType::kTXT);
    if (!srv.has_value() || !txt.has_value()) {
      continue;
    }

    for (Node* address : address_record_nodes) {
      // The target of the SRV record has to be the node where the address
      // records are sourced from.
      if (srv.value().target() != address->name()) {
        continue;
      }

      // At least one address record must be present to provide an endpoint
      // for this instance.
      const absl::optional<ARecordRdata> a =
          address->GetRdata<ARecordRdata>(DnsType::kA);
      const absl::optional<AAAARecordRdata> aaaa =
          address->GetRdata<AAAARecordRdata>(DnsType::kAAAA);
      if (!a.has_value() && !aaaa.has_value()) {
        continue;
      }

      // Then use the above info to create an endpoint object. If an error
      // occurs, this is only related to the one endpoint and it's possible
      // that other endpoints may still be valid, so only the one endpoint is
      // treated as failing. For instance, a bad TXT record for service A will
      // not affect the endpoints for service B.
      ErrorOr<DnsSdInstanceEndpoint> endpoint =
          CreateEndpoint(srv_and_txt->name(), a, aaaa, srv.value(), txt.value(),
                         network_interface_);
      endpoints.push_back(std::move(endpoint));
    }
  }

  return endpoints;
}
// static
// An address node is valid when it holds an A record, an AAAA record, or both.
bool DnsDataGraphImpl::IsValidAddressNode(Node* node) {
  const bool has_ipv4 = node->GetRdata<ARecordRdata>(DnsType::kA).has_value();
  const bool has_ipv6 =
      node->GetRdata<AAAARecordRdata>(DnsType::kAAAA).has_value();
  return has_ipv4 || has_ipv6;
}
// static
// A service node is valid only when it holds both a SRV and a TXT record.
bool DnsDataGraphImpl::IsValidSrvAndTxtNode(Node* node) {
  const bool has_srv =
      node->GetRdata<SrvRecordRdata>(DnsType::kSRV).has_value();
  const bool has_txt =
      node->GetRdata<TxtRecordRdata>(DnsType::kTXT).has_value();
  return has_srv && has_txt;
}
// Collects the endpoints of every service instance that a PTR record stored
// at |node| points to. Records of any other type at |node| are ignored.
std::vector<ErrorOr<DnsSdInstanceEndpoint>>
DnsDataGraphImpl::CalculatePtrRecordEndpoints(Node* node) const {
  // PTR records aren't actually part of the generated endpoint objects, so
  // resolve each pointed-to domain as a SRV+TXT node and merge the results.
  std::vector<ErrorOr<DnsSdInstanceEndpoint>> endpoints;
  for (const MdnsRecord& record : node->records()) {
    if (record.dns_type() != DnsType::kPTR) {
      continue;
    }

    const DomainName& domain =
        absl::get<PtrRecordRdata>(record.rdata()).ptr_domain();

    // CreateEndpoints() performs the node lookup itself and yields an empty
    // result for a domain that has no node, so a missing child is skipped
    // here instead of dereferencing an end() iterator.
    std::vector<ErrorOr<DnsSdInstanceEndpoint>> child_endpoints =
        CreateEndpoints(DomainGroup::kSrvAndTxt, domain);
    for (auto& endpoint_or_error : child_endpoints) {
      endpoints.push_back(std::move(endpoint_or_error));
    }
  }
  return endpoints;
}
} // namespace
// Out-of-line definition of the defaulted DnsDataGraph destructor.
DnsDataGraph::~DnsDataGraph() = default;
// static
// Factory for the graph implementation defined in this file; the result is
// handed back through the DnsDataGraph interface.
std::unique_ptr<DnsDataGraph> DnsDataGraph::Create(
    NetworkInterfaceIndex network_interface) {
  std::unique_ptr<DnsDataGraph> graph =
      std::make_unique<DnsDataGraphImpl>(network_interface);
  return graph;
}
// static
// Maps a DNS record type to the domain group its records belong to:
// A/AAAA -> kAddress, SRV/TXT -> kSrvAndTxt, PTR -> kPtr. Any other type hits
// OSP_NOTREACHED().
DnsDataGraphImpl::DomainGroup DnsDataGraph::GetDomainGroup(DnsType type) {
  if (type == DnsType::kA || type == DnsType::kAAAA) {
    return DnsDataGraphImpl::DomainGroup::kAddress;
  }
  if (type == DnsType::kSRV || type == DnsType::kTXT) {
    return DnsDataGraphImpl::DomainGroup::kSrvAndTxt;
  }
  if (type == DnsType::kPTR) {
    return DnsDataGraphImpl::DomainGroup::kPtr;
  }
  OSP_NOTREACHED();
}
// static
// Convenience overload: classifies |record| by its DNS type.
DnsDataGraphImpl::DomainGroup DnsDataGraph::GetDomainGroup(
    const MdnsRecord record) {
  const DnsType record_type = record.dns_type();
  return GetDomainGroup(record_type);
}
} // namespace discovery
} // namespace openscreen