Update usage of ErrorOr in openscreen
This patch includes changes that make the usage of the ErrorOr
type more widespread throughout our codebase. Currently, we
only have a few example of this class, however there are multiple
opportunities for its use elsewhere.
NOTE: this patch mostly leaves alone the api/public services,
specifically the following classes are left using bool:
codegen (Write*)
mdns_responder_service (Handle*Event)
quic_client (Start, Stop)
quic_server (Start, Stop, Suspend, Resume)
service_listener, _impl (Start, Stop, ...)
service_publisher, _impl (Start, Stop, ....)
Bug: openscreen:26
Change-Id: I558b61a29046a263014e295958059d979cdca67e
Reviewed-on: https://chromium-review.googlesource.com/c/1443817
Commit-Queue: Jordan Bayles <jophba@chromium.org>
Reviewed-by: mark a. foltz <mfoltz@chromium.org>
diff --git a/api/impl/internal_services.cc b/api/impl/internal_services.cc
index d1ac976..734255d 100644
--- a/api/impl/internal_services.cc
+++ b/api/impl/internal_services.cc
@@ -7,6 +7,7 @@
#include <algorithm>
#include "api/impl/mdns_responder_service.h"
+#include "base/error.h"
#include "discovery/mdns/mdns_responder_adapter_impl.h"
#include "platform/api/error.h"
#include "platform/api/logging.h"
@@ -36,31 +37,35 @@
}
};
-bool SetUpMulticastSocket(platform::UdpSocketPtr socket,
- platform::NetworkInterfaceIndex ifindex) {
- do {
- const IPAddress broadcast_address =
- IsIPv6Socket(socket) ? kMulticastIPv6Address : kMulticastAddress;
- if (!JoinUdpMulticastGroup(socket, broadcast_address, ifindex)) {
- OSP_LOG_ERROR << "join multicast group failed for interface " << ifindex
- << ": " << platform::GetLastErrorString();
- break;
- }
-
- if (!BindUdpSocket(socket, {{}, kMulticastListeningPort}, ifindex)) {
- OSP_LOG_ERROR << "bind failed for interface " << ifindex << ": "
- << platform::GetLastErrorString();
- break;
- }
-
- return true;
- } while (false);
-
+Error GetLastError() {
+ // TODO(jophba): Add platform error handling, we shouldn't know about
+ // EADDRINUSE here.
if (platform::GetLastError() == EADDRINUSE) {
OSP_LOG_ERROR
<< "Is there a mDNS service, such as Bonjour, already running?";
+ return Error::Code::kAddressInUse;
}
- return false;
+
+ return Error::Code::kGenericPlatformError;
+}
+
+Error SetUpMulticastSocket(platform::UdpSocketPtr socket,
+ platform::NetworkInterfaceIndex ifindex) {
+ const IPAddress broadcast_address =
+ IsIPv6Socket(socket) ? kMulticastIPv6Address : kMulticastAddress;
+ if (!JoinUdpMulticastGroup(socket, broadcast_address, ifindex).ok()) {
+ OSP_LOG_ERROR << "join multicast group failed for interface " << ifindex
+ << ": " << platform::GetLastErrorString();
+ return GetLastError();
+ }
+
+ if (!BindUdpSocket(socket, {{}, kMulticastListeningPort}, ifindex).ok()) {
+ OSP_LOG_ERROR << "bind failed for interface " << ifindex << ": "
+ << platform::GetLastErrorString();
+ return GetLastError();
+ }
+
+ return Error::None();
}
// Ref-counted singleton instance of InternalServices. This lives only as long
@@ -139,7 +144,7 @@
auto* socket = addr.addresses.front().address.IsV4()
? platform::CreateUdpSocketIPv4()
: platform::CreateUdpSocketIPv6();
- if (!SetUpMulticastSocket(socket, index)) {
+ if (!SetUpMulticastSocket(socket, index).ok()) {
DestroyUdpSocket(socket);
continue;
}
diff --git a/api/impl/mdns_responder_service.cc b/api/impl/mdns_responder_service.cc
index dda8517..03d4cfe 100644
--- a/api/impl/mdns_responder_service.cc
+++ b/api/impl/mdns_responder_service.cc
@@ -7,6 +7,7 @@
#include <algorithm>
#include <utility>
+#include "base/error.h"
#include "base/make_unique.h"
#include "platform/api/logging.h"
@@ -241,17 +242,17 @@
interface.subnet, interface.socket);
}
}
- mdns::DomainName service_type;
- OSP_CHECK(mdns::DomainName::FromLabels(service_type_.begin(),
- service_type_.end(), &service_type));
+ ErrorOr<mdns::DomainName> service_type =
+ mdns::DomainName::FromLabels(service_type_.begin(), service_type_.end());
+ OSP_CHECK(service_type);
for (const auto& interface : bound_interfaces_)
- mdns_responder_->StartPtrQuery(interface.socket, service_type);
+ mdns_responder_->StartPtrQuery(interface.socket, service_type.value());
}
void MdnsResponderService::StopListening() {
- mdns::DomainName service_type;
- OSP_CHECK(mdns::DomainName::FromLabels(service_type_.begin(),
- service_type_.end(), &service_type));
+ ErrorOr<mdns::DomainName> service_type =
+ mdns::DomainName::FromLabels(service_type_.begin(), service_type_.end());
+ OSP_CHECK(service_type);
for (const auto& kv : network_scoped_domain_to_host_) {
const NetworkScopedDomainName& scoped_domain = kv.first;
@@ -268,7 +269,7 @@
}
service_by_name_.clear();
for (const auto& interface : bound_interfaces_)
- mdns_responder_->StopPtrQuery(interface.socket, service_type);
+ mdns_responder_->StopPtrQuery(interface.socket, service_type.value());
RemoveAllReceivers();
}
@@ -298,13 +299,16 @@
}
}
mdns_responder_->SetHostLabel(service_hostname_);
- mdns::DomainName domain_name;
- OSP_CHECK(mdns::DomainName::FromLabels(&service_hostname_,
- &service_hostname_ + 1, &domain_name))
- << "bad hostname configured: " << service_hostname_;
- OSP_CHECK(domain_name.Append(mdns::DomainName::GetLocalDomain()));
+ ErrorOr<mdns::DomainName> domain_name =
+ mdns::DomainName::FromLabels(&service_hostname_, &service_hostname_ + 1);
+ OSP_CHECK(domain_name) << "bad hostname configured: " << service_hostname_;
+ mdns::DomainName name = domain_name.MoveValue();
+
+ Error error = name.Append(mdns::DomainName::GetLocalDomain());
+ OSP_CHECK(error.ok());
+
mdns_responder_->RegisterService(service_instance_name_, service_type_[0],
- service_type_[1], domain_name, service_port_,
+ service_type_[1], name, service_port_,
service_txt_data_);
}
diff --git a/api/impl/presentation/url_availability_requester.cc b/api/impl/presentation/url_availability_requester.cc
index abbe551..c7f47a8 100644
--- a/api/impl/presentation/url_availability_requester.cc
+++ b/api/impl/presentation/url_availability_requester.cc
@@ -190,8 +190,7 @@
return;
uint64_t request_id = next_request_id++;
ErrorOr<uint64_t> watch_id_or_error(0);
- if (!connection ||
- (watch_id_or_error = SendRequest(request_id, urls)).is_value()) {
+ if (!connection || (watch_id_or_error = SendRequest(request_id, urls))) {
request_by_id.emplace(request_id,
Request{watch_id_or_error.value(), std::move(urls)});
} else {
@@ -324,7 +323,7 @@
for (auto& url : still_observed_urls)
urls.emplace_back(std::move(url));
if (!connection ||
- (watch_id_or_error = SendRequest(new_request_id, urls)).is_value()) {
+ (watch_id_or_error = SendRequest(new_request_id, urls))) {
new_requests.emplace(new_request_id,
Request{watch_id_or_error.value(), std::move(urls)});
} else {
@@ -386,8 +385,7 @@
this->connection = std::move(connection);
ErrorOr<uint64_t> watch_id_or_error(0);
for (auto entry = request_by_id.begin(); entry != request_by_id.end();) {
- if ((watch_id_or_error = SendRequest(entry->first, entry->second.urls))
- .is_value()) {
+ if ((watch_id_or_error = SendRequest(entry->first, entry->second.urls))) {
entry->second.watch_id = watch_id_or_error.value();
++entry;
} else {
diff --git a/api/impl/quic/quic_client_unittest.cc b/api/impl/quic/quic_client_unittest.cc
index ff93176..74ee334 100644
--- a/api/impl/quic/quic_client_unittest.cc
+++ b/api/impl/quic/quic_client_unittest.cc
@@ -56,8 +56,6 @@
: connection_(connection) {}
~ConnectionCallback() override = default;
- bool failed() const { return failed_; }
-
void OnConnectionOpened(
uint64_t request_id,
std::unique_ptr<ProtocolConnection>&& connection) override {
diff --git a/api/impl/receiver_list.cc b/api/impl/receiver_list.cc
index 8e6d10a..f21333a 100644
--- a/api/impl/receiver_list.cc
+++ b/api/impl/receiver_list.cc
@@ -15,31 +15,31 @@
receivers_.emplace_back(info);
}
-bool ReceiverList::OnReceiverChanged(const ServiceInfo& info) {
+Error ReceiverList::OnReceiverChanged(const ServiceInfo& info) {
auto existing_info = std::find_if(receivers_.begin(), receivers_.end(),
[&info](const ServiceInfo& x) {
return x.service_id == info.service_id;
});
if (existing_info == receivers_.end())
- return false;
+ return Error::Code::kNoItemFound;
*existing_info = info;
- return true;
+ return Error::None();
}
-bool ReceiverList::OnReceiverRemoved(const ServiceInfo& info) {
+Error ReceiverList::OnReceiverRemoved(const ServiceInfo& info) {
const auto it = std::remove(receivers_.begin(), receivers_.end(), info);
if (it == receivers_.end())
- return false;
+ return Error::Code::kNoItemFound;
receivers_.erase(it, receivers_.end());
- return true;
+ return Error::None();
}
-bool ReceiverList::OnAllReceiversRemoved() {
+Error ReceiverList::OnAllReceiversRemoved() {
const auto empty = receivers_.empty();
receivers_.clear();
- return !empty;
+ return empty ? Error::Code::kNoItemFound : Error::None();
}
} // namespace openscreen
diff --git a/api/impl/receiver_list.h b/api/impl/receiver_list.h
index 6a258cb..686b6c7 100644
--- a/api/impl/receiver_list.h
+++ b/api/impl/receiver_list.h
@@ -8,6 +8,7 @@
#include <vector>
#include "api/public/service_info.h"
+#include "base/error.h"
namespace openscreen {
@@ -20,17 +21,9 @@
void OnReceiverAdded(const ServiceInfo& info);
- // Returns true if |info.service_id| matched an item in |receivers_| and was
- // therefore changed, otherwise false.
- bool OnReceiverChanged(const ServiceInfo& info);
-
- // Returns true if an item matching |info| was removed from |receivers_|,
- // otherwise false.
- bool OnReceiverRemoved(const ServiceInfo& info);
-
- // Returns true if |receivers_| was not empty before this call, otherwise
- // false.
- bool OnAllReceiversRemoved();
+ Error OnReceiverChanged(const ServiceInfo& info);
+ Error OnReceiverRemoved(const ServiceInfo& info);
+ Error OnAllReceiversRemoved();
const std::vector<ServiceInfo>& receivers() const { return receivers_; }
diff --git a/api/impl/receiver_list_unittest.cc b/api/impl/receiver_list_unittest.cc
index 6215bf5..7ca93b6 100644
--- a/api/impl/receiver_list_unittest.cc
+++ b/api/impl/receiver_list_unittest.cc
@@ -3,7 +3,7 @@
// found in the LICENSE file.
#include "api/impl/receiver_list.h"
-
+#include "base/error.h"
#include "third_party/googletest/src/googletest/include/gtest/gtest.h"
namespace openscreen {
@@ -50,8 +50,8 @@
list.OnReceiverAdded(receiver1);
list.OnReceiverAdded(receiver2);
- EXPECT_TRUE(list.OnReceiverChanged(receiver1_alt_name));
- EXPECT_FALSE(list.OnReceiverChanged(receiver3));
+ EXPECT_TRUE(list.OnReceiverChanged(receiver1_alt_name).ok());
+ EXPECT_FALSE(list.OnReceiverChanged(receiver3).ok());
ASSERT_EQ(2u, list.receivers().size());
EXPECT_EQ(receiver1_alt_name, list.receivers()[0]);
@@ -64,13 +64,13 @@
"id1", "name1", 1, {{192, 168, 1, 10}, 12345}, {}};
const ServiceInfo receiver2{
"id2", "name2", 1, {{192, 168, 1, 11}, 12345}, {}};
- EXPECT_FALSE(list.OnReceiverRemoved(receiver1));
+ EXPECT_FALSE(list.OnReceiverRemoved(receiver1).ok());
list.OnReceiverAdded(receiver1);
- EXPECT_FALSE(list.OnReceiverRemoved(receiver2));
+ EXPECT_FALSE(list.OnReceiverRemoved(receiver2).ok());
list.OnReceiverAdded(receiver2);
list.OnReceiverAdded(receiver1);
- EXPECT_TRUE(list.OnReceiverRemoved(receiver1));
+ EXPECT_TRUE(list.OnReceiverRemoved(receiver1).ok());
ASSERT_EQ(1u, list.receivers().size());
EXPECT_EQ(receiver2, list.receivers()[0]);
@@ -82,11 +82,11 @@
"id1", "name1", 1, {{192, 168, 1, 10}, 12345}, {}};
const ServiceInfo receiver2{
"id2", "name2", 1, {{192, 168, 1, 11}, 12345}, {}};
- EXPECT_FALSE(list.OnAllReceiversRemoved());
+ EXPECT_FALSE(list.OnAllReceiversRemoved().ok());
list.OnReceiverAdded(receiver1);
list.OnReceiverAdded(receiver2);
- EXPECT_TRUE(list.OnAllReceiversRemoved());
+ EXPECT_TRUE(list.OnAllReceiversRemoved().ok());
ASSERT_TRUE(list.receivers().empty());
}
diff --git a/api/impl/service_listener_impl.cc b/api/impl/service_listener_impl.cc
index 97e183d..1469d32 100644
--- a/api/impl/service_listener_impl.cc
+++ b/api/impl/service_listener_impl.cc
@@ -3,7 +3,7 @@
// found in the LICENSE file.
#include "api/impl/service_listener_impl.h"
-
+#include "base/error.h"
#include "platform/api/logging.h"
namespace openscreen {
@@ -67,20 +67,20 @@
}
void ServiceListenerImpl::OnReceiverChanged(const ServiceInfo& info) {
- const auto any_changed = receiver_list_.OnReceiverChanged(info);
- if (any_changed && observer_)
+ const Error changed_error = receiver_list_.OnReceiverChanged(info);
+ if (changed_error.ok() && observer_)
observer_->OnReceiverChanged(info);
}
void ServiceListenerImpl::OnReceiverRemoved(const ServiceInfo& info) {
- const auto any_removed = receiver_list_.OnReceiverRemoved(info);
- if (any_removed && observer_)
+ const Error removed_error = receiver_list_.OnReceiverRemoved(info);
+ if (removed_error.ok() && observer_)
observer_->OnReceiverRemoved(info);
}
void ServiceListenerImpl::OnAllReceiversRemoved() {
- const auto any_removed = receiver_list_.OnAllReceiversRemoved();
- if (any_removed && observer_)
+ const Error removed_all_error = receiver_list_.OnAllReceiversRemoved();
+ if (removed_all_error.ok() && observer_)
observer_->OnAllReceiversRemoved();
}
diff --git a/api/impl/testing/fake_mdns_responder_adapter.cc b/api/impl/testing/fake_mdns_responder_adapter.cc
index de42405..f3d27d0 100644
--- a/api/impl/testing/fake_mdns_responder_adapter.cc
+++ b/api/impl/testing/fake_mdns_responder_adapter.cc
@@ -6,6 +6,7 @@
#include <algorithm>
+#include "base/error.h"
#include "platform/api/logging.h"
namespace openscreen {
@@ -18,12 +19,12 @@
platform::UdpSocketPtr socket) {
const auto labels = std::vector<std::string>{service_instance, service_type,
service_protocol, kLocalDomain};
- mdns::DomainName full_instance_name;
- OSP_CHECK(mdns::DomainName::FromLabels(labels.begin(), labels.end(),
- &full_instance_name));
+ ErrorOr<mdns::DomainName> full_instance_name =
+ mdns::DomainName::FromLabels(labels.begin(), labels.end());
+ OSP_CHECK(full_instance_name);
mdns::PtrEvent result{
mdns::QueryEventHeader{mdns::QueryEventHeader::Type::kAdded, socket},
- full_instance_name};
+ full_instance_name.value()};
return result;
}
@@ -35,16 +36,18 @@
platform::UdpSocketPtr socket) {
const auto instance_labels = std::vector<std::string>{
service_instance, service_type, service_protocol, kLocalDomain};
- mdns::DomainName full_instance_name;
- OSP_CHECK(mdns::DomainName::FromLabels(
- instance_labels.begin(), instance_labels.end(), &full_instance_name));
+ ErrorOr<mdns::DomainName> full_instance_name = mdns::DomainName::FromLabels(
+ instance_labels.begin(), instance_labels.end());
+ OSP_CHECK(full_instance_name);
+
const auto host_labels = std::vector<std::string>{hostname, kLocalDomain};
- mdns::DomainName domain_name;
- OSP_CHECK(mdns::DomainName::FromLabels(host_labels.begin(), host_labels.end(),
- &domain_name));
+ ErrorOr<mdns::DomainName> domain_name =
+ mdns::DomainName::FromLabels(host_labels.begin(), host_labels.end());
+ OSP_CHECK(domain_name);
+
mdns::SrvEvent result{
mdns::QueryEventHeader{mdns::QueryEventHeader::Type::kAdded, socket},
- full_instance_name, domain_name, port};
+ full_instance_name.value(), domain_name.value(), port};
return result;
}
@@ -55,12 +58,12 @@
platform::UdpSocketPtr socket) {
const auto labels = std::vector<std::string>{service_instance, service_type,
service_protocol, kLocalDomain};
- mdns::DomainName full_instance_name;
- OSP_CHECK(mdns::DomainName::FromLabels(labels.begin(), labels.end(),
- &full_instance_name));
+ ErrorOr<mdns::DomainName> domain_name =
+ mdns::DomainName::FromLabels(labels.begin(), labels.end());
+ OSP_CHECK(domain_name);
mdns::TxtEvent result{
mdns::QueryEventHeader{mdns::QueryEventHeader::Type::kAdded, socket},
- full_instance_name, txt_lines};
+ domain_name.value(), txt_lines};
return result;
}
@@ -68,12 +71,12 @@
IPAddress address,
platform::UdpSocketPtr socket) {
const auto labels = std::vector<std::string>{hostname, kLocalDomain};
- mdns::DomainName domain_name;
- OSP_CHECK(
- mdns::DomainName::FromLabels(labels.begin(), labels.end(), &domain_name));
+ ErrorOr<mdns::DomainName> domain_name =
+ mdns::DomainName::FromLabels(labels.begin(), labels.end());
+ OSP_CHECK(domain_name);
mdns::AEvent result{
mdns::QueryEventHeader{mdns::QueryEventHeader::Type::kAdded, socket},
- domain_name, address};
+ domain_name.value(), address};
return result;
}
@@ -81,12 +84,12 @@
IPAddress address,
platform::UdpSocketPtr socket) {
const auto labels = std::vector<std::string>{hostname, kLocalDomain};
- mdns::DomainName domain_name;
- OSP_CHECK(
- mdns::DomainName::FromLabels(labels.begin(), labels.end(), &domain_name));
+ ErrorOr<mdns::DomainName> domain_name =
+ mdns::DomainName::FromLabels(labels.begin(), labels.end());
+ OSP_CHECK(domain_name);
mdns::AaaaEvent result{
mdns::QueryEventHeader{mdns::QueryEventHeader::Type::kAdded, socket},
- domain_name, address};
+ domain_name.value(), address};
return result;
}
@@ -180,10 +183,10 @@
return true;
}
-bool FakeMdnsResponderAdapter::Init() {
+Error FakeMdnsResponderAdapter::Init() {
OSP_CHECK(!running_);
running_ = true;
- return true;
+ return Error::None();
}
void FakeMdnsResponderAdapter::Close() {
@@ -198,28 +201,28 @@
running_ = false;
}
-bool FakeMdnsResponderAdapter::SetHostLabel(const std::string& host_label) {
- return false;
+Error FakeMdnsResponderAdapter::SetHostLabel(const std::string& host_label) {
+ return Error::Code::kNotImplemented;
}
-bool FakeMdnsResponderAdapter::RegisterInterface(
+Error FakeMdnsResponderAdapter::RegisterInterface(
const platform::InterfaceInfo& interface_info,
const platform::IPSubnet& interface_address,
platform::UdpSocketPtr socket) {
if (!running_)
- return false;
+ return Error::Code::kNotRunning;
if (std::find_if(registered_interfaces_.begin(), registered_interfaces_.end(),
[&socket](const RegisteredInterface& interface) {
return interface.socket == socket;
}) != registered_interfaces_.end()) {
- return false;
+ return Error::Code::kNoItemFound;
}
registered_interfaces_.push_back({interface_info, interface_address, socket});
- return true;
+ return Error::None();
}
-bool FakeMdnsResponderAdapter::DeregisterInterface(
+Error FakeMdnsResponderAdapter::DeregisterInterface(
platform::UdpSocketPtr socket) {
auto it =
std::find_if(registered_interfaces_.begin(), registered_interfaces_.end(),
@@ -227,10 +230,10 @@
return interface.socket == socket;
});
if (it == registered_interfaces_.end())
- return false;
+ return Error::Code::kNoItemFound;
registered_interfaces_.erase(it);
- return true;
+ return Error::None();
}
void FakeMdnsResponderAdapter::OnDataReceived(
@@ -373,7 +376,7 @@
auto canonical_service_type = service_type;
if (!canonical_service_type.EndsWithLocalDomain())
OSP_CHECK(
- canonical_service_type.Append(mdns::DomainName::GetLocalDomain()));
+ canonical_service_type.Append(mdns::DomainName::GetLocalDomain()).ok());
auto maybe_inserted =
queries_[socket].ptr_queries.insert(canonical_service_type);
@@ -450,7 +453,7 @@
auto canonical_service_type = service_type;
if (!canonical_service_type.EndsWithLocalDomain())
OSP_CHECK(
- canonical_service_type.Append(mdns::DomainName::GetLocalDomain()));
+ canonical_service_type.Append(mdns::DomainName::GetLocalDomain()).ok());
auto it = ptr_queries.find(canonical_service_type);
if (it == ptr_queries.end())
diff --git a/api/impl/testing/fake_mdns_responder_adapter.h b/api/impl/testing/fake_mdns_responder_adapter.h
index c61de14..57f8714 100644
--- a/api/impl/testing/fake_mdns_responder_adapter.h
+++ b/api/impl/testing/fake_mdns_responder_adapter.h
@@ -98,17 +98,17 @@
bool running() const { return running_; }
// mdns::MdnsResponderAdapter overrides.
- bool Init() override;
+ Error Init() override;
void Close() override;
- bool SetHostLabel(const std::string& host_label) override;
+ Error SetHostLabel(const std::string& host_label) override;
// TODO(btolsch): Reject/OSP_CHECK events that don't match any registered
// interface?
- bool RegisterInterface(const platform::InterfaceInfo& interface_info,
- const platform::IPSubnet& interface_address,
- platform::UdpSocketPtr socket) override;
- bool DeregisterInterface(platform::UdpSocketPtr socket) override;
+ Error RegisterInterface(const platform::InterfaceInfo& interface_info,
+ const platform::IPSubnet& interface_address,
+ platform::UdpSocketPtr socket) override;
+ Error DeregisterInterface(platform::UdpSocketPtr socket) override;
void OnDataReceived(const IPEndpoint& source,
const IPEndpoint& original_destination,
diff --git a/api/impl/testing/fake_mdns_responder_adapter_unittest.cc b/api/impl/testing/fake_mdns_responder_adapter_unittest.cc
index 4f62dd7..7fe6a3e 100644
--- a/api/impl/testing/fake_mdns_responder_adapter_unittest.cc
+++ b/api/impl/testing/fake_mdns_responder_adapter_unittest.cc
@@ -110,10 +110,10 @@
EXPECT_EQ(kTestServiceInstance, labels[0]);
// TODO(btolsch): qname if PtrEvent gets it.
- mdns::DomainName st;
- ASSERT_TRUE(
- mdns::DomainName::FromLabels(labels.begin() + 1, labels.end(), &st));
- EXPECT_EQ(kTestServiceTypeCanon, st);
+ ErrorOr<mdns::DomainName> st =
+ mdns::DomainName::FromLabels(labels.begin() + 1, labels.end());
+ ASSERT_TRUE(st);
+ EXPECT_EQ(kTestServiceTypeCanon, st.value());
result = mdns_responder.StopPtrQuery(default_socket, kTestServiceType);
EXPECT_EQ(mdns::MdnsResponderErrorCode::kNoError, result);
@@ -255,26 +255,26 @@
auto socket1 = reinterpret_cast<platform::UdpSocketPtr>(16);
auto socket2 = reinterpret_cast<platform::UdpSocketPtr>(24);
- auto result = mdns_responder.RegisterInterface(platform::InterfaceInfo{},
- platform::IPSubnet{}, socket1);
- EXPECT_TRUE(result);
+ Error result = mdns_responder.RegisterInterface(
+ platform::InterfaceInfo{}, platform::IPSubnet{}, socket1);
+ EXPECT_TRUE(result.ok());
EXPECT_EQ(1u, mdns_responder.registered_interfaces().size());
result = mdns_responder.RegisterInterface(platform::InterfaceInfo{},
platform::IPSubnet{}, socket1);
- EXPECT_FALSE(result);
+ EXPECT_FALSE(result.ok());
EXPECT_EQ(1u, mdns_responder.registered_interfaces().size());
result = mdns_responder.RegisterInterface(platform::InterfaceInfo{},
platform::IPSubnet{}, socket2);
- EXPECT_TRUE(result);
+ EXPECT_TRUE(result.ok());
EXPECT_EQ(2u, mdns_responder.registered_interfaces().size());
result = mdns_responder.DeregisterInterface(socket2);
- EXPECT_TRUE(result);
+ EXPECT_TRUE(result.ok());
EXPECT_EQ(1u, mdns_responder.registered_interfaces().size());
result = mdns_responder.DeregisterInterface(socket2);
- EXPECT_FALSE(result);
+ EXPECT_FALSE(result.ok());
EXPECT_EQ(1u, mdns_responder.registered_interfaces().size());
mdns_responder.Close();
@@ -283,7 +283,7 @@
result = mdns_responder.RegisterInterface(platform::InterfaceInfo{},
platform::IPSubnet{}, socket1);
- EXPECT_FALSE(result);
+ EXPECT_FALSE(result.ok());
EXPECT_EQ(0u, mdns_responder.registered_interfaces().size());
}
diff --git a/api/public/message_demuxer.cc b/api/public/message_demuxer.cc
index 0f88cac..e1155b1 100644
--- a/api/public/message_demuxer.cc
+++ b/api/public/message_demuxer.cc
@@ -189,7 +189,7 @@
auto consumed_or_error = callback_entry->second->OnStreamMessage(
endpoint_id, connection_id, message_type, buffer->data() + 1,
buffer->size() - 1);
- if (consumed_or_error.is_error()) {
+ if (!consumed_or_error) {
if (consumed_or_error.error().code() !=
Error::Code::kCborIncompleteMessage) {
buffer->clear();
diff --git a/api/public/service_info_unittest.cc b/api/public/service_info_unittest.cc
index 717c41e..3436116 100644
--- a/api/public/service_info_unittest.cc
+++ b/api/public/service_info_unittest.cc
@@ -3,6 +3,7 @@
// found in the LICENSE file.
#include "api/public/service_info.h"
+#include "base/error.h"
#include "third_party/googletest/src/googletest/include/gtest/gtest.h"
@@ -24,7 +25,10 @@
const ServiceInfo receiver1_alt_port{
"id3", "name1", 1, {{192, 168, 1, 10}, 12645}, {}};
ServiceInfo receiver1_ipv6{"id3", "name1", 1, {}, {}};
- ASSERT_TRUE(IPAddress::Parse("::12:34", &receiver1_ipv6.v6_endpoint.address));
+
+ ErrorOr<IPAddress> address = IPAddress::Parse("::12:34");
+ ASSERT_TRUE(address);
+ receiver1_ipv6.v6_endpoint.address = address.value();
EXPECT_EQ(receiver1, receiver1);
EXPECT_EQ(receiver2, receiver2);
diff --git a/base/error.cc b/base/error.cc
index e92cffd..bb875d5 100644
--- a/base/error.cc
+++ b/base/error.cc
@@ -20,6 +20,8 @@
Error::Error(Code code, std::string&& message)
: code_(code), message_(std::move(message)) {}
+Error::~Error() = default;
+
Error& Error::operator=(const Error& other) = default;
Error& Error::operator=(Error&& other) = default;
@@ -50,6 +52,12 @@
}
}
+// static
+const Error& Error::None() {
+ static Error& error = *new Error(Code::kNone);
+ return error;
+}
+
std::ostream& operator<<(std::ostream& out, const Error& error) {
out << Error::CodeToString(error.code()) << ": " << error.message();
return out;
diff --git a/base/error.h b/base/error.h
index 2cf2d92..57567d0 100644
--- a/base/error.h
+++ b/base/error.h
@@ -20,6 +20,7 @@
enum class Code {
// No error occurred.
kNone = 0,
+
// CBOR errors.
kCborParsing = 1,
kCborEncoding,
@@ -33,23 +34,50 @@
kNoPresentationFound,
kPreviousStartInProgress,
kUnknownStartError,
+
+ kAddressInUse,
+ kAlreadyListening,
+ kDomainNameTooLong,
+ kDomainNameLabelTooLong,
+
+ kGenericPlatformError,
+
+ kIOFailure,
+ kInitializationFailure,
+ kInvalidIPV4Address,
+ kInvalidIPV6Address,
+
+ kSocketOptionSettingFailure,
+ kSocketBindFailure,
+ kSocketClosedFailure,
+ kSocketReadFailure,
+
+ kMdnsRegisterFailure,
+
+ kNoItemFound,
+ kNotImplemented,
+ kNotRunning,
};
Error();
Error(const Error& error);
Error(Error&& error) noexcept;
- explicit Error(Code code);
+
+ Error(Code code);
Error(Code code, const std::string& message);
Error(Code code, std::string&& message);
+ ~Error();
Error& operator=(const Error& other);
Error& operator=(Error&& other);
bool operator==(const Error& other) const;
+ bool ok() const { return code_ == Code::kNone; }
Code code() const { return code_; }
const std::string& message() const { return message_; }
static std::string CodeToString(Error::Code code);
+ static const Error& None();
private:
Code code_ = Code::kNone;
@@ -79,8 +107,13 @@
template <typename Value>
class ErrorOr {
public:
+ static ErrorOr<Value> None() {
+ static ErrorOr<Value> error(Error::Code::kNone);
+ return error;
+ }
+
ErrorOr(ErrorOr&& error_or) = default;
- ErrorOr(Value&& value) noexcept : value_(value) {}
+ ErrorOr(Value&& value) noexcept : value_(std::move(value)) {}
ErrorOr(Error error) : error_(std::move(error)) {}
ErrorOr(Error::Code code) : error_(code) {}
ErrorOr(Error::Code code, std::string message)
@@ -91,6 +124,11 @@
bool is_error() const { return error_.code() != Error::Code::kNone; }
bool is_value() const { return !is_error(); }
+
+ // Unlike Error, we CAN provide an operator bool here, since it is
+ // more obvious to callers that ErrorOr<Foo> will be true if it's Foo.
+ operator bool() const { return is_value(); }
+
const Error& error() const { return error_; }
Error&& MoveError() { return std::move(error_); }
diff --git a/base/error_unittest.cc b/base/error_unittest.cc
index e13ed1e..8e537da 100644
--- a/base/error_unittest.cc
+++ b/base/error_unittest.cc
@@ -22,7 +22,7 @@
TEST(ErrorTest, TestDefaultError) {
Error error;
- EXPECT_EQ(error.code(), Error::Code::kNone);
+ EXPECT_EQ(error, Error::None());
EXPECT_EQ(error.message(), "");
}
@@ -56,16 +56,19 @@
ErrorOr<Dummy> error_or2(Error::Code::kCborParsing);
ErrorOr<Dummy> error_or3(Error::Code::kCborParsing, "Parse Error Again");
+ EXPECT_FALSE(error_or1);
EXPECT_FALSE(error_or1.is_value());
EXPECT_TRUE(error_or1.is_error());
EXPECT_EQ(error_or1.error().code(), Error::Code::kCborParsing);
EXPECT_EQ(error_or1.error().message(), "Parse Error");
+ EXPECT_FALSE(error_or2);
EXPECT_FALSE(error_or2.is_value());
EXPECT_TRUE(error_or2.is_error());
EXPECT_EQ(error_or2.error().code(), Error::Code::kCborParsing);
EXPECT_EQ(error_or2.error().message(), "");
+ EXPECT_FALSE(error_or3);
EXPECT_FALSE(error_or3.is_value());
EXPECT_TRUE(error_or3.is_error());
EXPECT_EQ(error_or3.error().code(), Error::Code::kCborParsing);
@@ -74,11 +77,13 @@
ErrorOr<Dummy> error_or4(std::move(error_or1));
ErrorOr<Dummy> error_or5 = std::move(error_or3);
+ EXPECT_FALSE(error_or4);
EXPECT_FALSE(error_or4.is_value());
EXPECT_TRUE(error_or4.is_error());
EXPECT_EQ(error_or4.error().code(), Error::Code::kCborParsing);
EXPECT_EQ(error_or4.error().message(), "Parse Error");
+ EXPECT_FALSE(error_or5);
EXPECT_FALSE(error_or5.is_value());
EXPECT_TRUE(error_or5.is_error());
EXPECT_EQ(error_or5.error().code(), Error::Code::kCborParsing);
@@ -89,28 +94,32 @@
ErrorOr<Dummy> error_or1(Dummy("Winterfell"));
ErrorOr<Dummy> error_or2(Dummy("Riverrun"));
+ EXPECT_TRUE(error_or1);
EXPECT_TRUE(error_or1.is_value());
EXPECT_FALSE(error_or1.is_error());
EXPECT_EQ(error_or1.value().message, "Winterfell");
- EXPECT_EQ(error_or1.error().code(), Error::Code::kNone);
+ EXPECT_EQ(error_or1.error(), Error::None());
+ EXPECT_TRUE(error_or2);
EXPECT_TRUE(error_or2.is_value());
EXPECT_FALSE(error_or2.is_error());
EXPECT_EQ(error_or2.value().message, "Riverrun");
- EXPECT_EQ(error_or2.error().code(), Error::Code::kNone);
+ EXPECT_EQ(error_or2.error(), Error::None());
ErrorOr<Dummy> error_or3(std::move(error_or1));
ErrorOr<Dummy> error_or4 = std::move(error_or2);
+ EXPECT_TRUE(error_or3);
EXPECT_TRUE(error_or3.is_value());
EXPECT_FALSE(error_or3.is_error());
EXPECT_EQ(error_or3.value().message, "Winterfell");
- EXPECT_EQ(error_or3.error().code(), Error::Code::kNone);
+ EXPECT_EQ(error_or3.error(), Error::None());
+ EXPECT_TRUE(error_or4);
EXPECT_TRUE(error_or4.is_value());
EXPECT_FALSE(error_or4.is_error());
EXPECT_EQ(error_or4.value().message, "Riverrun");
- EXPECT_EQ(error_or4.error().code(), Error::Code::kNone);
+ EXPECT_EQ(error_or4.error(), Error::None());
Dummy value = error_or4.MoveValue();
EXPECT_EQ(value.message, "Riverrun");
diff --git a/base/ip_address.cc b/base/ip_address.cc
index 4ae688c..1becbd6 100644
--- a/base/ip_address.cc
+++ b/base/ip_address.cc
@@ -13,9 +13,11 @@
namespace openscreen {
// static
-bool IPAddress::Parse(const std::string& s, IPAddress* address) {
- return ParseV4(s, address) || ParseV6(s, address);
-}
+ErrorOr<IPAddress> IPAddress::Parse(const std::string& s) {
+ ErrorOr<IPAddress> v4 = ParseV4(s);
+
+ return v4 ? std::move(v4) : ParseV6(s);
+} // namespace openscreen
IPAddress::IPAddress() : version_(Version::kV4), bytes_({}) {}
IPAddress::IPAddress(const std::array<uint8_t, 4>& bytes)
@@ -99,49 +101,50 @@
}
// static
-bool IPAddress::ParseV4(const std::string& s, IPAddress* address) {
+ErrorOr<IPAddress> IPAddress::ParseV4(const std::string& s) {
if (s.size() > 0 && s[0] == '.')
- return false;
+ return Error::Code::kInvalidIPV4Address;
+ IPAddress address;
uint16_t next_octet = 0;
int i = 0;
bool previous_dot = false;
for (auto c : s) {
if (c == '.') {
if (previous_dot) {
- return false;
+ return Error::Code::kInvalidIPV4Address;
}
- address->bytes_[i++] = static_cast<uint8_t>(next_octet);
+ address.bytes_[i++] = static_cast<uint8_t>(next_octet);
next_octet = 0;
previous_dot = true;
if (i > 3)
- return false;
+ return Error::Code::kInvalidIPV4Address;
continue;
}
previous_dot = false;
if (!std::isdigit(c))
- return false;
+ return Error::Code::kInvalidIPV4Address;
next_octet = next_octet * 10 + (c - '0');
if (next_octet > 255)
- return false;
+ return Error::Code::kInvalidIPV4Address;
}
if (previous_dot)
- return false;
+ return Error::Code::kInvalidIPV4Address;
if (i != 3)
- return false;
+ return Error::Code::kInvalidIPV4Address;
- address->bytes_[i] = static_cast<uint8_t>(next_octet);
- address->version_ = Version::kV4;
- return true;
+ address.bytes_[i] = static_cast<uint8_t>(next_octet);
+ address.version_ = Version::kV4;
+ return address;
}
// static
-bool IPAddress::ParseV6(const std::string& s, IPAddress* address) {
+ErrorOr<IPAddress> IPAddress::ParseV6(const std::string& s) {
if (s.size() > 1 && s[0] == ':' && s[1] != ':')
- return false;
+ return Error::Code::kInvalidIPV6Address;
uint16_t next_value = 0;
uint8_t values[16];
@@ -153,11 +156,11 @@
++num_previous_colons;
if (num_previous_colons == 2) {
if (double_colon_index) {
- return false;
+ return Error::Code::kInvalidIPV6Address;
}
double_colon_index = i;
} else if (i >= 15 || num_previous_colons > 2) {
- return false;
+ return Error::Code::kInvalidIPV6Address;
} else {
values[i++] = static_cast<uint8_t>(next_value >> 8);
values[i++] = static_cast<uint8_t>(next_value & 0xff);
@@ -173,37 +176,39 @@
} else if (c >= 'A' && c <= 'F') {
x = c - 'A' + 10;
} else {
- return false;
+ return Error::Code::kInvalidIPV6Address;
}
if (next_value & 0xf000) {
- return false;
+ return Error::Code::kInvalidIPV6Address;
} else {
next_value = static_cast<uint16_t>(next_value * 16 + x);
}
}
}
if (num_previous_colons == 1)
- return false;
+ return Error::Code::kInvalidIPV6Address;
if (i >= 15)
- return false;
+ return Error::Code::kInvalidIPV6Address;
values[i++] = static_cast<uint8_t>(next_value >> 8);
values[i] = static_cast<uint8_t>(next_value & 0xff);
if (!((i == 15 && !double_colon_index) || (i < 14 && double_colon_index))) {
- return false;
+ return Error::Code::kInvalidIPV6Address;
}
+
+ IPAddress address;
for (int j = 15; j >= 0;) {
if (double_colon_index && (i == double_colon_index)) {
- address->bytes_[j--] = values[i--];
+ address.bytes_[j--] = values[i--];
while (j > i)
- address->bytes_[j--] = 0;
+ address.bytes_[j--] = 0;
} else {
- address->bytes_[j--] = values[i--];
+ address.bytes_[j--] = values[i--];
}
}
- address->version_ = Version::kV6;
- return true;
+ address.version_ = Version::kV6;
+ return address;
}
bool operator==(const IPEndpoint& a, const IPEndpoint& b) {
diff --git a/base/ip_address.h b/base/ip_address.h
index f72a54c..7280a10 100644
--- a/base/ip_address.h
+++ b/base/ip_address.h
@@ -11,6 +11,8 @@
#include <string>
#include <type_traits>
+#include "base/error.h"
+
namespace openscreen {
class IPAddress {
@@ -25,9 +27,7 @@
// Parses a text representation of an IPv4 address (e.g. "192.168.0.1") or an
// IPv6 address (e.g. "abcd::1234") and puts the result into |address|.
- // Returns true if the parsing was successful and |address| was set, false
- // otherwise.
- static bool Parse(const std::string& s, IPAddress* address);
+ static ErrorOr<IPAddress> Parse(const std::string& s);
IPAddress();
explicit IPAddress(const std::array<uint8_t, 4>& bytes);
@@ -73,8 +73,8 @@
void CopyToV6(uint8_t* x) const;
private:
- static bool ParseV4(const std::string& s, IPAddress* address);
- static bool ParseV6(const std::string& s, IPAddress* address);
+ static ErrorOr<IPAddress> ParseV4(const std::string& s);
+ static ErrorOr<IPAddress> ParseV6(const std::string& s);
friend class IPEndpointComparator;
diff --git a/base/ip_address_unittest.cc b/base/ip_address_unittest.cc
index 72ddd5f..7319316 100644
--- a/base/ip_address_unittest.cc
+++ b/base/ip_address_unittest.cc
@@ -3,6 +3,8 @@
// found in the LICENSE file.
#include "base/ip_address.h"
+#include "base/error.h"
+
#include "third_party/googletest/src/googlemock/include/gmock/gmock.h"
#include "third_party/googletest/src/googletest/include/gtest/gtest.h"
@@ -59,40 +61,39 @@
TEST(IPAddressTest, V4Parse) {
uint8_t bytes[4] = {};
- IPAddress address;
- ASSERT_TRUE(IPAddress::Parse("192.168.0.1", &address));
- address.CopyToV4(bytes);
+
+ ErrorOr<IPAddress> address = IPAddress::Parse("192.168.0.1");
+ ASSERT_TRUE(address);
+ address.value().CopyToV4(bytes);
EXPECT_THAT(bytes, ElementsAreArray({192, 168, 0, 1}));
}
TEST(IPAddressTest, V4ParseFailures) {
- IPAddress address;
-
- EXPECT_FALSE(IPAddress::Parse("192..0.1", &address))
+ EXPECT_FALSE(IPAddress::Parse("192..0.1"))
<< "empty value should fail to parse";
- EXPECT_FALSE(IPAddress::Parse(".192.168.0.1", &address))
+ EXPECT_FALSE(IPAddress::Parse(".192.168.0.1"))
<< "leading dot should fail to parse";
- EXPECT_FALSE(IPAddress::Parse(".192.168.1", &address))
+ EXPECT_FALSE(IPAddress::Parse(".192.168.1"))
<< "leading dot should fail to parse";
- EXPECT_FALSE(IPAddress::Parse("..192.168.0.1", &address))
+ EXPECT_FALSE(IPAddress::Parse("..192.168.0.1"))
<< "leading dot should fail to parse";
- EXPECT_FALSE(IPAddress::Parse("..192.1", &address))
+ EXPECT_FALSE(IPAddress::Parse("..192.1"))
<< "leading dot should fail to parse";
- EXPECT_FALSE(IPAddress::Parse("192.168.0.1.", &address))
+ EXPECT_FALSE(IPAddress::Parse("192.168.0.1."))
<< "trailing dot should fail to parse";
- EXPECT_FALSE(IPAddress::Parse("192.168.1.", &address))
+ EXPECT_FALSE(IPAddress::Parse("192.168.1."))
<< "trailing dot should fail to parse";
- EXPECT_FALSE(IPAddress::Parse("192.168.1..", &address))
+ EXPECT_FALSE(IPAddress::Parse("192.168.1.."))
<< "trailing dot should fail to parse";
- EXPECT_FALSE(IPAddress::Parse("192.168..", &address))
+ EXPECT_FALSE(IPAddress::Parse("192.168.."))
<< "trailing dot should fail to parse";
- EXPECT_FALSE(IPAddress::Parse("192.x3.0.1", &address))
+ EXPECT_FALSE(IPAddress::Parse("192.x3.0.1"))
<< "non-digit character should fail to parse";
- EXPECT_FALSE(IPAddress::Parse("192.3.1", &address))
+ EXPECT_FALSE(IPAddress::Parse("192.3.1"))
<< "too few values should fail to parse";
- EXPECT_FALSE(IPAddress::Parse("192.3.2.0.1", &address))
+ EXPECT_FALSE(IPAddress::Parse("192.3.2.0.1"))
<< "too many values should fail to parse";
- EXPECT_FALSE(IPAddress::Parse("1920.3.2.1", &address))
+ EXPECT_FALSE(IPAddress::Parse("1920.3.2.1"))
<< "value > 255 should fail to parse";
}
@@ -145,10 +146,10 @@
TEST(IPAddressTest, V6ParseBasic) {
uint8_t bytes[16] = {};
- IPAddress address;
- ASSERT_TRUE(
- IPAddress::Parse("abcd:ef01:2345:6789:9876:5432:10FE:DBCA", &address));
- address.CopyToV6(bytes);
+ ErrorOr<IPAddress> address =
+ IPAddress::Parse("abcd:ef01:2345:6789:9876:5432:10FE:DBCA");
+ ASSERT_TRUE(address);
+ address.value().CopyToV6(bytes);
EXPECT_THAT(bytes, ElementsAreArray({0xab, 0xcd, 0xef, 0x01, 0x23, 0x45, 0x67,
0x89, 0x98, 0x76, 0x54, 0x32, 0x10, 0xfe,
0xdb, 0xca}));
@@ -156,30 +157,30 @@
TEST(IPAddressTest, V6ParseDoubleColon) {
uint8_t bytes[16] = {};
- IPAddress address1;
- ASSERT_TRUE(
- IPAddress::Parse("abcd:ef01:2345:6789:9876:5432::dbca", &address1));
- address1.CopyToV6(bytes);
+ ErrorOr<IPAddress> address1 =
+ IPAddress::Parse("abcd:ef01:2345:6789:9876:5432::dbca");
+ ASSERT_TRUE(address1);
+ address1.value().CopyToV6(bytes);
EXPECT_THAT(bytes, ElementsAreArray({0xab, 0xcd, 0xef, 0x01, 0x23, 0x45, 0x67,
0x89, 0x98, 0x76, 0x54, 0x32, 0x00, 0x00,
0xdb, 0xca}));
- IPAddress address2;
- ASSERT_TRUE(IPAddress::Parse("abcd::10fe:dbca", &address2));
- address2.CopyToV6(bytes);
+ ErrorOr<IPAddress> address2 = IPAddress::Parse("abcd::10fe:dbca");
+ ASSERT_TRUE(address2);
+ address2.value().CopyToV6(bytes);
EXPECT_THAT(bytes, ElementsAreArray({0xab, 0xcd, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0xfe,
0xdb, 0xca}));
- IPAddress address3;
- ASSERT_TRUE(IPAddress::Parse("::10fe:dbca", &address3));
- address3.CopyToV6(bytes);
+ ErrorOr<IPAddress> address3 = IPAddress::Parse("::10fe:dbca");
+ ASSERT_TRUE(address3);
+ address3.value().CopyToV6(bytes);
EXPECT_THAT(bytes, ElementsAreArray({0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0xfe,
0xdb, 0xca}));
- IPAddress address4;
- ASSERT_TRUE(IPAddress::Parse("10fe:dbca::", &address4));
- address4.CopyToV6(bytes);
+ ErrorOr<IPAddress> address4 = IPAddress::Parse("10fe:dbca::");
+ ASSERT_TRUE(address4);
+ address4.value().CopyToV6(bytes);
EXPECT_THAT(bytes, ElementsAreArray({0x10, 0xfe, 0xdb, 0xca, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00}));
@@ -187,59 +188,57 @@
TEST(IPAddressTest, V6SmallValues) {
uint8_t bytes[16] = {};
- IPAddress address1;
- ASSERT_TRUE(IPAddress::Parse("::", &address1));
- address1.CopyToV6(bytes);
+ ErrorOr<IPAddress> address1 = IPAddress::Parse("::");
+ ASSERT_TRUE(address1);
+ address1.value().CopyToV6(bytes);
EXPECT_THAT(bytes, ElementsAreArray({0x0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00}));
- IPAddress address2;
- ASSERT_TRUE(IPAddress::Parse("::1", &address2));
- address2.CopyToV6(bytes);
+ ErrorOr<IPAddress> address2 = IPAddress::Parse("::1");
+ ASSERT_TRUE(address2);
+ address2.value().CopyToV6(bytes);
EXPECT_THAT(bytes, ElementsAreArray({0x0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x01}));
- IPAddress address3;
- ASSERT_TRUE(IPAddress::Parse("::2:1", &address3));
- address3.CopyToV6(bytes);
+ ErrorOr<IPAddress> address3 = IPAddress::Parse("::2:1");
+ ASSERT_TRUE(address3);
+ address3.value().CopyToV6(bytes);
EXPECT_THAT(bytes, ElementsAreArray({0x0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02,
0x00, 0x01}));
}
TEST(IPAddressTest, V6ParseFailures) {
- IPAddress address;
-
- EXPECT_FALSE(IPAddress::Parse(":abcd::dbca", &address))
+ EXPECT_FALSE(IPAddress::Parse(":abcd::dbca"))
<< "leading colon should fail to parse";
- EXPECT_FALSE(IPAddress::Parse("abcd::dbca:", &address))
+ EXPECT_FALSE(IPAddress::Parse("abcd::dbca:"))
<< "trailing colon should fail to parse";
- EXPECT_FALSE(IPAddress::Parse("abxd::1234", &address))
+ EXPECT_FALSE(IPAddress::Parse("abxd::1234"))
<< "non-hex digit should fail to parse";
- EXPECT_FALSE(IPAddress::Parse("abcd:1234", &address))
+ EXPECT_FALSE(IPAddress::Parse("abcd:1234"))
<< "too few values should fail to parse";
- EXPECT_FALSE(IPAddress::Parse("a:b:c:d:e:f:0:1:2:3:4:5:6:7:8:9:a", &address))
+ EXPECT_FALSE(IPAddress::Parse("a:b:c:d:e:f:0:1:2:3:4:5:6:7:8:9:a"))
<< "too many values should fail to parse";
- EXPECT_FALSE(IPAddress::Parse("abcd1::dbca", &address))
+ EXPECT_FALSE(IPAddress::Parse("abcd1::dbca"))
<< "value > 0xffff should fail to parse";
- EXPECT_FALSE(IPAddress::Parse("::abcd::dbca", &address))
+ EXPECT_FALSE(IPAddress::Parse("::abcd::dbca"))
<< "multiple double colon should fail to parse";
- EXPECT_FALSE(IPAddress::Parse(":::abcd::dbca", &address))
+ EXPECT_FALSE(IPAddress::Parse(":::abcd::dbca"))
<< "leading triple colon should fail to parse";
- EXPECT_FALSE(IPAddress::Parse("abcd:::dbca", &address))
+ EXPECT_FALSE(IPAddress::Parse("abcd:::dbca"))
<< "triple colon should fail to parse";
- EXPECT_FALSE(IPAddress::Parse("abcd:dbca:::", &address))
+ EXPECT_FALSE(IPAddress::Parse("abcd:dbca:::"))
<< "trailing triple colon should fail to parse";
}
TEST(IPAddressTest, V6ParseThreeDigitValue) {
uint8_t bytes[16] = {};
- IPAddress address;
- ASSERT_TRUE(IPAddress::Parse("::123", &address));
- address.CopyToV6(bytes);
+ ErrorOr<IPAddress> address = IPAddress::Parse("::123");
+ ASSERT_TRUE(address);
+ address.value().CopyToV6(bytes);
EXPECT_THAT(bytes, ElementsAreArray({0x0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x01, 0x23}));
diff --git a/discovery/mdns/domain_name.cc b/discovery/mdns/domain_name.cc
index dc9eccf..e96e9fe 100644
--- a/discovery/mdns/domain_name.cc
+++ b/discovery/mdns/domain_name.cc
@@ -18,28 +18,31 @@
}
// static
-bool DomainName::Append(const DomainName& first,
- const DomainName& second,
- DomainName* result) {
+ErrorOr<DomainName> DomainName::Append(const DomainName& first,
+ const DomainName& second) {
OSP_CHECK(first.domain_name_.size());
OSP_CHECK(second.domain_name_.size());
- OSP_DCHECK_EQ(first.domain_name_.back(), 0);
- OSP_DCHECK_EQ(second.domain_name_.back(), 0);
+
+ // Both vectors should represent null terminated domain names.
+ OSP_DCHECK_EQ(first.domain_name_.back(), '\0');
+ OSP_DCHECK_EQ(second.domain_name_.back(), '\0');
if ((first.domain_name_.size() + second.domain_name_.size() - 1) >
kDomainNameMaxLength) {
- return false;
+ return Error::Code::kDomainNameTooLong;
}
- result->domain_name_.clear();
- result->domain_name_.insert(result->domain_name_.begin(),
- first.domain_name_.begin(),
- first.domain_name_.end());
- result->domain_name_.insert(result->domain_name_.end() - 1,
- second.domain_name_.begin(),
- second.domain_name_.end() - 1);
- return true;
+
+ DomainName result;
+ result.domain_name_.clear();
+ result.domain_name_.insert(result.domain_name_.begin(),
+ first.domain_name_.begin(),
+ first.domain_name_.end());
+ result.domain_name_.insert(result.domain_name_.end() - 1,
+ second.domain_name_.begin(),
+ second.domain_name_.end() - 1);
+ return result;
}
-DomainName::DomainName() : domain_name_{0} {}
+DomainName::DomainName() : domain_name_{'\0'} {}
DomainName::DomainName(std::vector<uint8_t>&& domain_name)
: domain_name_(std::move(domain_name)) {
OSP_CHECK_LE(domain_name_.size(), kDomainNameMaxLength);
@@ -69,16 +72,18 @@
domain_name_.end() - local_domain.domain_name_.size());
}
-bool DomainName::Append(const DomainName& after) {
+Error DomainName::Append(const DomainName& after) {
OSP_CHECK(after.domain_name_.size());
- OSP_DCHECK_EQ(after.domain_name_.back(), 0);
+ OSP_DCHECK_EQ(after.domain_name_.back(), '\0');
+
if ((domain_name_.size() + after.domain_name_.size() - 1) >
kDomainNameMaxLength) {
- return false;
+ return Error::Code::kDomainNameTooLong;
}
+
domain_name_.insert(domain_name_.end() - 1, after.domain_name_.begin(),
after.domain_name_.end() - 1);
- return true;
+ return Error::None();
}
std::vector<std::string> DomainName::GetLabels() const {
diff --git a/discovery/mdns/domain_name.h b/discovery/mdns/domain_name.h
index cbec313..5c18994 100644
--- a/discovery/mdns/domain_name.h
+++ b/discovery/mdns/domain_name.h
@@ -10,37 +10,37 @@
#include <string>
#include <vector>
+#include "base/error.h"
#include "platform/api/logging.h"
namespace openscreen {
namespace mdns {
struct DomainName {
- // TODO(issues/2): Replace bool here and elsewhere with ErrorOr<DomainName>.
- static bool Append(const DomainName& first,
- const DomainName& second,
- DomainName* result);
+ static ErrorOr<DomainName> Append(const DomainName& first,
+ const DomainName& second);
template <typename It>
- static bool FromLabels(It first, It last, DomainName* result) {
+ static ErrorOr<DomainName> FromLabels(It first, It last) {
size_t total_length = 1;
for (auto label = first; label != last; ++label) {
if (label->size() > kDomainNameMaxLabelLength)
- return false;
+ return Error::Code::kDomainNameLabelTooLong;
total_length += label->size() + 1;
}
if (total_length > kDomainNameMaxLength)
- return false;
+ return Error::Code::kDomainNameTooLong;
- result->domain_name_.resize(total_length);
- auto result_it = result->domain_name_.begin();
+ DomainName result;
+ result.domain_name_.resize(total_length);
+ auto result_it = result.domain_name_.begin();
for (auto label = first; label != last; ++label) {
*result_it++ = static_cast<uint8_t>(label->size());
result_it = std::copy(label->begin(), label->end(), result_it);
}
*result_it = 0;
- return true;
+ return std::move(result);
}
static DomainName GetLocalDomain();
@@ -62,7 +62,7 @@
bool EndsWithLocalDomain() const;
bool IsEmpty() const { return domain_name_.size() == 1 && !domain_name_[0]; }
- bool Append(const DomainName& after);
+ Error Append(const DomainName& after);
// TODO: If there's significant use of this, we would rather have string_span
// or similar for this so we could use iterators for zero-copy.
std::vector<std::string> GetLabels() const;
diff --git a/discovery/mdns/domain_name_unittest.cc b/discovery/mdns/domain_name_unittest.cc
index 71ac11c..43de93b 100644
--- a/discovery/mdns/domain_name_unittest.cc
+++ b/discovery/mdns/domain_name_unittest.cc
@@ -6,6 +6,8 @@
#include <sstream>
+#include "base/error.h"
+
#include "third_party/googletest/src/googletest/include/gtest/gtest.h"
namespace openscreen {
@@ -13,8 +15,14 @@
namespace {
-bool FromLabels(const std::vector<std::string>& labels, DomainName* result) {
- return DomainName::FromLabels(labels.begin(), labels.end(), result);
+ErrorOr<DomainName> FromLabels(const std::vector<std::string>& labels) {
+ return DomainName::FromLabels(labels.begin(), labels.end());
+}
+
+template <typename T>
+T UnpackErrorOr(ErrorOr<T> error_or) {
+ EXPECT_TRUE(error_or);
+ return error_or.MoveValue();
}
} // namespace
@@ -52,48 +60,42 @@
const auto typical =
std::vector<uint8_t>{10, 'o', 'p', 'e', 'n', 's', 'c', 'r',
'e', 'e', 'n', 3, 'o', 'r', 'g', 0};
- DomainName result;
- ASSERT_TRUE(FromLabels({"openscreen", "org"}, &result));
+ DomainName result = UnpackErrorOr(FromLabels({"openscreen", "org"}));
EXPECT_EQ(result.domain_name(), typical);
const auto includes_dot =
std::vector<uint8_t>{11, 'o', 'p', 'e', 'n', '.', 's', 'c', 'r',
'e', 'e', 'n', 3, 'o', 'r', 'g', 0};
- ASSERT_TRUE(FromLabels({"open.screen", "org"}, &result));
+ result = UnpackErrorOr(FromLabels({"open.screen", "org"}));
EXPECT_EQ(result.domain_name(), includes_dot);
const auto includes_non_ascii =
std::vector<uint8_t>{11, 'o', 'p', 'e', 'n', 7, 's', 'c', 'r',
'e', 'e', 'n', 3, 'o', 'r', 'g', 0};
- ASSERT_TRUE(FromLabels({"open\7screen", "org"}, &result));
+ result = UnpackErrorOr(FromLabels({"open\7screen", "org"}));
EXPECT_EQ(result.domain_name(), includes_non_ascii);
- ASSERT_FALSE(FromLabels({"extremely-long-label-that-is-actually-too-long-"
- "for-rfc-1034-and-will-not-generate"},
- &result));
+ ASSERT_FALSE(
+ FromLabels({"extremely-long-label-that-is-actually-too-long-"
+ "for-rfc-1034-and-will-not-generate"}));
- ASSERT_FALSE(FromLabels(
- {
- "extremely-long-domain-name-that-is-made-of",
- "valid-labels",
- "however-overall-it-is-too-long-for-rfc-1034",
- "so-it-should-fail-to-generate",
- "filler-filler-filler-filler-filler",
- "filler-filler-filler-filler-filler",
- "filler-filler-filler-filler-filler",
- "filler-filler-filler-filler-filler",
- },
- &result));
+ ASSERT_FALSE(FromLabels({
+ "extremely-long-domain-name-that-is-made-of",
+ "valid-labels",
+ "however-overall-it-is-too-long-for-rfc-1034",
+ "so-it-should-fail-to-generate",
+ "filler-filler-filler-filler-filler",
+ "filler-filler-filler-filler-filler",
+ "filler-filler-filler-filler-filler",
+ "filler-filler-filler-filler-filler",
+ }));
}
TEST(DomainNameTest, Equality) {
- DomainName alpha;
- DomainName beta;
- DomainName alpha_copy;
+ DomainName alpha = UnpackErrorOr(FromLabels({"alpha", "openscreen", "org"}));
+ DomainName beta = UnpackErrorOr(FromLabels({"beta", "openscreen", "org"}));
- ASSERT_TRUE(FromLabels({"alpha", "openscreen", "org"}, &alpha));
- ASSERT_TRUE(FromLabels({"beta", "openscreen", "org"}, &beta));
- alpha_copy = alpha;
+ const DomainName alpha_copy = alpha;
EXPECT_TRUE(alpha == alpha);
EXPECT_FALSE(alpha != alpha);
@@ -105,12 +107,10 @@
TEST(DomainNameTest, EndsWithLocalDomain) {
DomainName alpha;
- DomainName beta;
-
EXPECT_FALSE(alpha.EndsWithLocalDomain());
- ASSERT_TRUE(FromLabels({"alpha", "openscreen", "org"}, &alpha));
- ASSERT_TRUE(FromLabels({"beta", "local"}, &beta));
+ alpha = UnpackErrorOr(FromLabels({"alpha", "openscreen", "org"}));
+ DomainName beta = UnpackErrorOr(FromLabels({"beta", "local"}));
EXPECT_FALSE(alpha.EndsWithLocalDomain());
EXPECT_TRUE(beta.EndsWithLocalDomain());
@@ -123,36 +123,44 @@
EXPECT_TRUE(alpha.IsEmpty());
EXPECT_TRUE(beta.IsEmpty());
- ASSERT_TRUE(FromLabels({"alpha", "openscreen", "org"}, &alpha));
+ alpha = UnpackErrorOr(FromLabels({"alpha", "openscreen", "org"}));
EXPECT_FALSE(alpha.IsEmpty());
}
TEST(DomainNameTest, Append) {
+ const auto expected_service_name =
+ std::vector<uint8_t>{5, 'a', 'l', 'p', 'h', 'a', '\0'};
+ const auto expected_service_type_initial = std::vector<uint8_t>{
+ 11, '_', 'o', 'p', 'e', 'n', 's', 'c', 'r', 'e', 'e', 'n', '\0'};
+ const auto expected_protocol =
+ std::vector<uint8_t>{5, '_', 'q', 'u', 'i', 'c', '\0'};
const auto expected_service_type =
std::vector<uint8_t>{11, '_', 'o', 'p', 'e', 'n', 's', 'c', 'r', 'e',
- 'e', 'n', 5, '_', 'q', 'u', 'i', 'c', 0};
+ 'e', 'n', 5, '_', 'q', 'u', 'i', 'c', '\0'};
const auto total_expected = std::vector<uint8_t>{
5, 'a', 'l', 'p', 'h', 'a', 11, '_', 'o', 'p', 'e', 'n', 's',
- 'c', 'r', 'e', 'e', 'n', 5, '_', 'q', 'u', 'i', 'c', 0};
- DomainName service_name;
- DomainName service_type;
- DomainName protocol;
- ASSERT_TRUE(FromLabels({"alpha"}, &service_name));
- ASSERT_TRUE(FromLabels({"_openscreen"}, &service_type));
- ASSERT_TRUE(FromLabels({"_quic"}, &protocol));
+ 'c', 'r', 'e', 'e', 'n', 5, '_', 'q', 'u', 'i', 'c', '\0'};
- EXPECT_TRUE(service_type.Append(protocol));
+ DomainName service_name = UnpackErrorOr(FromLabels({"alpha"}));
+ EXPECT_EQ(service_name.domain_name(), expected_service_name);
+
+ DomainName service_type = UnpackErrorOr(FromLabels({"_openscreen"}));
+ EXPECT_EQ(service_type.domain_name(), expected_service_type_initial);
+
+ DomainName protocol = UnpackErrorOr(FromLabels({"_quic"}));
+ EXPECT_EQ(protocol.domain_name(), expected_protocol);
+
+ EXPECT_TRUE(service_type.Append(protocol).ok());
EXPECT_EQ(service_type.domain_name(), expected_service_type);
- DomainName result;
- EXPECT_TRUE(DomainName::Append(service_name, service_type, &result));
+ DomainName result =
+ UnpackErrorOr(DomainName::Append(service_name, service_type));
EXPECT_EQ(result.domain_name(), total_expected);
}
TEST(DomainNameTest, GetLabels) {
const auto labels = std::vector<std::string>{"alpha", "beta", "gamma", "org"};
- DomainName d;
- ASSERT_TRUE(FromLabels(labels, &d));
+ DomainName d = UnpackErrorOr(FromLabels(labels));
EXPECT_EQ(d.GetLabels(), labels);
}
diff --git a/discovery/mdns/embedder_demo.cc b/discovery/mdns/embedder_demo.cc
index c8e0d7b..5895d25 100644
--- a/discovery/mdns/embedder_demo.cc
+++ b/discovery/mdns/embedder_demo.cc
@@ -9,6 +9,7 @@
#include <map>
#include <vector>
+#include "base/error.h"
#include "base/make_unique.h"
#include "discovery/mdns/mdns_responder_adapter_impl.h"
#include "platform/api/error.h"
@@ -101,13 +102,14 @@
std::vector<platform::UdpSocketPtr> fds;
for (const auto ifindex : index_list) {
auto* socket = platform::CreateUdpSocketIPv4();
- if (!JoinUdpMulticastGroup(socket, IPAddress{224, 0, 0, 251}, ifindex)) {
+ if (!JoinUdpMulticastGroup(socket, IPAddress{224, 0, 0, 251}, ifindex)
+ .ok()) {
OSP_LOG_ERROR << "join multicast group failed for interface " << ifindex
<< ": " << platform::GetLastErrorString();
DestroyUdpSocket(socket);
continue;
}
- if (!BindUdpSocket(socket, {{}, 5353}, ifindex)) {
+ if (!BindUdpSocket(socket, {{}, 5353}, ifindex).ok()) {
OSP_LOG_ERROR << "bind failed for interface " << ifindex << ": "
<< platform::GetLastErrorString();
DestroyUdpSocket(socket);
@@ -222,10 +224,11 @@
const std::string& service_instance) {
SignalThings();
- mdns::DomainName service_type;
std::vector<std::string> labels{service_name, service_protocol};
- if (!mdns::DomainName::FromLabels(labels.begin(), labels.end(),
- &service_type)) {
+ ErrorOr<mdns::DomainName> service_type =
+ mdns::DomainName::FromLabels(labels.begin(), labels.end());
+
+ if (!service_type) {
OSP_LOG_ERROR << "bad domain labels: " << service_name << ", "
<< service_protocol;
return;
@@ -268,7 +271,7 @@
for (auto* socket : sockets) {
platform::WatchUdpSocketReadable(waiter, socket);
- mdns_adapter->StartPtrQuery(socket, service_type);
+ mdns_adapter->StartPtrQuery(socket, service_type.value());
}
while (!g_done) {
diff --git a/discovery/mdns/mdns_responder_adapter.h b/discovery/mdns/mdns_responder_adapter.h
index 263ed21..d69ca87 100644
--- a/discovery/mdns/mdns_responder_adapter.h
+++ b/discovery/mdns/mdns_responder_adapter.h
@@ -10,6 +10,7 @@
#include <string>
#include <vector>
+#include "base/error.h"
#include "base/ip_address.h"
#include "discovery/mdns/domain_name.h"
#include "discovery/mdns/mdns_responder_platform.h"
@@ -170,7 +171,7 @@
// Initializes mDNSResponder. This should be called before any queries or
// service registrations are made.
- virtual bool Init() = 0;
+ virtual Error Init() = 0;
// Stops all open queries and service registrations. If this is not called
// before destruction, any registered services will not send their goodbye
@@ -181,16 +182,16 @@
// when any service is active (via RegisterService). Returns true if the
// label was set successfully, false otherwise (e.g. the label did not meet
// DNS name requirements).
- virtual bool SetHostLabel(const std::string& host_label) = 0;
+ virtual Error SetHostLabel(const std::string& host_label) = 0;
// The following methods register and deregister a network interface with
// mDNSResponder. |socket| will be used to identify which interface received
// the data in OnDataReceived and will be used to send data via the platform
// layer.
- virtual bool RegisterInterface(const platform::InterfaceInfo& interface_info,
- const platform::IPSubnet& interface_address,
- platform::UdpSocketPtr socket) = 0;
- virtual bool DeregisterInterface(platform::UdpSocketPtr socket) = 0;
+ virtual Error RegisterInterface(const platform::InterfaceInfo& interface_info,
+ const platform::IPSubnet& interface_address,
+ platform::UdpSocketPtr socket) = 0;
+ virtual Error DeregisterInterface(platform::UdpSocketPtr socket) = 0;
virtual void OnDataReceived(const IPEndpoint& source,
const IPEndpoint& original_destination,
diff --git a/discovery/mdns/mdns_responder_adapter_impl.cc b/discovery/mdns/mdns_responder_adapter_impl.cc
index 9687293..9610630 100644
--- a/discovery/mdns/mdns_responder_adapter_impl.cc
+++ b/discovery/mdns/mdns_responder_adapter_impl.cc
@@ -194,12 +194,14 @@
MdnsResponderAdapterImpl::MdnsResponderAdapterImpl() = default;
MdnsResponderAdapterImpl::~MdnsResponderAdapterImpl() = default;
-bool MdnsResponderAdapterImpl::Init() {
+Error MdnsResponderAdapterImpl::Init() {
const auto err =
mDNS_Init(&mdns_, &platform_storage_, rr_cache_, kRrCacheSize,
mDNS_Init_DontAdvertiseLocalAddresses, &MdnsStatusCallback,
mDNS_Init_NoInitCallbackContext);
- return err == mStatus_NoError;
+
+ return (err == mStatus_NoError) ? Error::None()
+ : Error::Code::kInitializationFailure;
}
void MdnsResponderAdapterImpl::Close() {
@@ -223,9 +225,9 @@
service_records_.clear();
}
-bool MdnsResponderAdapterImpl::SetHostLabel(const std::string& host_label) {
+Error MdnsResponderAdapterImpl::SetHostLabel(const std::string& host_label) {
if (host_label.size() > DomainName::kDomainNameMaxLabelLength)
- return false;
+ return Error::Code::kDomainNameTooLong;
MakeDomainLabelFromLiteralString(&mdns_.hostlabel, host_label.c_str());
mDNS_SetFQDN(&mdns_);
@@ -233,16 +235,16 @@
DeadvertiseInterfaces();
AdvertiseInterfaces();
}
- return true;
+ return Error::None();
}
-bool MdnsResponderAdapterImpl::RegisterInterface(
+Error MdnsResponderAdapterImpl::RegisterInterface(
const platform::InterfaceInfo& interface_info,
const platform::IPSubnet& interface_address,
platform::UdpSocketPtr socket) {
const auto info_it = responder_interface_info_.find(socket);
if (info_it != responder_interface_info_.end())
- return true;
+ return Error::None();
NetworkInterfaceInfo& info = responder_interface_info_[socket];
std::memset(&info, 0, sizeof(NetworkInterfaceInfo));
@@ -268,14 +270,16 @@
auto result = mDNS_RegisterInterface(&mdns_, &info, mDNSfalse);
OSP_LOG_IF(WARN, result != mStatus_NoError)
<< "mDNS_RegisterInterface failed: " << result;
- return result == mStatus_NoError;
+
+ return (result == mStatus_NoError) ? Error::None()
+ : Error::Code::kMdnsRegisterFailure;
}
-bool MdnsResponderAdapterImpl::DeregisterInterface(
+Error MdnsResponderAdapterImpl::DeregisterInterface(
platform::UdpSocketPtr socket) {
const auto info_it = responder_interface_info_.find(socket);
if (info_it == responder_interface_info_.end())
- return false;
+ return Error::Code::kNoItemFound;
const auto it = std::find(platform_storage_.sockets.begin(),
platform_storage_.sockets.end(), socket);
@@ -287,7 +291,7 @@
}
mDNS_DeregisterInterface(&mdns_, &info_it->second, mDNSfalse);
responder_interface_info_.erase(info_it);
- return true;
+ return Error::None();
}
void MdnsResponderAdapterImpl::OnDataReceived(
@@ -366,13 +370,14 @@
service_type.domain_name().end(), question.qname.c);
} else {
const DomainName local_domain = DomainName::GetLocalDomain();
- DomainName service_type_with_local;
- if (!DomainName::Append(service_type, local_domain,
- &service_type_with_local)) {
+ ErrorOr<DomainName> service_type_with_local =
+ DomainName::Append(service_type, local_domain);
+ if (!service_type_with_local) {
return MdnsResponderErrorCode::kDomainOverflowError;
}
- std::copy(service_type_with_local.domain_name().begin(),
- service_type_with_local.domain_name().end(), question.qname.c);
+ std::copy(service_type_with_local.value().domain_name().begin(),
+ service_type_with_local.value().domain_name().end(),
+ question.qname.c);
}
question.qtype = kDNSType_PTR;
question.qclass = kDNSClass_IN;
diff --git a/discovery/mdns/mdns_responder_adapter_impl.h b/discovery/mdns/mdns_responder_adapter_impl.h
index 0e6ffd5..c221ac7 100644
--- a/discovery/mdns/mdns_responder_adapter_impl.h
+++ b/discovery/mdns/mdns_responder_adapter_impl.h
@@ -9,6 +9,7 @@
#include <memory>
#include <vector>
+#include "base/error.h"
#include "discovery/mdns/mdns_responder_adapter.h"
#include "platform/api/socket.h"
#include "third_party/mDNSResponder/src/mDNSCore/mDNSEmbeddedAPI.h"
@@ -23,15 +24,15 @@
MdnsResponderAdapterImpl();
~MdnsResponderAdapterImpl() override;
- bool Init() override;
+ Error Init() override;
void Close() override;
- bool SetHostLabel(const std::string& host_label) override;
+ Error SetHostLabel(const std::string& host_label) override;
- bool RegisterInterface(const platform::InterfaceInfo& interface_info,
- const platform::IPSubnet& interface_address,
- platform::UdpSocketPtr socket) override;
- bool DeregisterInterface(platform::UdpSocketPtr socket) override;
+ Error RegisterInterface(const platform::InterfaceInfo& interface_info,
+ const platform::IPSubnet& interface_address,
+ platform::UdpSocketPtr socket) override;
+ Error DeregisterInterface(platform::UdpSocketPtr socket) override;
void OnDataReceived(const IPEndpoint& source,
const IPEndpoint& original_destination,
diff --git a/platform/api/event_waiter.h b/platform/api/event_waiter.h
index 6d27afc..dafd7aa 100644
--- a/platform/api/event_waiter.h
+++ b/platform/api/event_waiter.h
@@ -7,6 +7,7 @@
#include <vector>
+#include "base/error.h"
#include "platform/api/socket.h"
#include "platform/api/time.h"
@@ -43,37 +44,18 @@
EventWaiterPtr CreateEventWaiter();
void DestroyEventWaiter(EventWaiterPtr waiter);
-// Returns true if |socket| was successfully added to the set of watched
-// sockets, false otherwise. It will also return false if |socket| is already
-// being watched.
-bool WatchUdpSocketReadable(EventWaiterPtr waiter, UdpSocketPtr socket);
+Error WatchUdpSocketReadable(EventWaiterPtr waiter, UdpSocketPtr socket);
+Error StopWatchingUdpSocketReadable(EventWaiterPtr waiter, UdpSocketPtr socket);
-// Returns true if |socket| was successfully removed from the set of watched
-// sockets.
-bool StopWatchingUdpSocketReadable(EventWaiterPtr waiter, UdpSocketPtr socket);
+Error WatchUdpSocketWritable(EventWaiterPtr waiter, UdpSocketPtr socket);
+Error StopWatchingUdpSocketWritable(EventWaiterPtr waiter, UdpSocketPtr socket);
-// Returns true if |socket| was successfully added to the set of watched
-// sockets, false otherwise. It will also return false if |socket| is already
-// being watched.
-bool WatchUdpSocketWritable(EventWaiterPtr waiter, UdpSocketPtr socket);
-
-// Returns true if |socket| was successfully removed from the set of watched
-// sockets.
-bool StopWatchingUdpSocketWritable(EventWaiterPtr waiter, UdpSocketPtr socket);
-
-// Returns true if |waiter| successfully started monitoring network change
-// events, false otherwise. It will also return false if |waiter| is already
-// monitoring network change events.
-bool WatchNetworkChange(EventWaiterPtr waiter);
-
-// Returns true if |waiter| successfully stopped monitoring network change
-// events, false otherwise. It will also return false if |waiter| wasn't
-// monitoring network change events already.
-bool StopWatchingNetworkChange(EventWaiterPtr waiter);
+Error WatchNetworkChange(EventWaiterPtr waiter);
+Error StopWatchingNetworkChange(EventWaiterPtr waiter);
// Returns the number of events that were added to |events| if there were any, 0
// if there were no events, and -1 if an error occurred.
-int WaitForEvents(EventWaiterPtr waiter, Events* events);
+ErrorOr<Events> WaitForEvents(EventWaiterPtr waiter);
} // namespace platform
} // namespace openscreen
diff --git a/platform/api/socket.h b/platform/api/socket.h
index 994b49d..0636225 100644
--- a/platform/api/socket.h
+++ b/platform/api/socket.h
@@ -5,6 +5,7 @@
#ifndef PLATFORM_API_SOCKET_H_
#define PLATFORM_API_SOCKET_H_
+#include "base/error.h"
#include "base/ip_address.h"
#include "platform/api/network_interface.h"
#include "third_party/abseil/src/absl/types/optional.h"
@@ -29,12 +30,12 @@
// Closes the underlying platform socket and frees any allocated memory.
void DestroyUdpSocket(UdpSocketPtr socket);
-bool BindUdpSocket(UdpSocketPtr socket,
- const IPEndpoint& endpoint,
- NetworkInterfaceIndex ifindex);
-bool JoinUdpMulticastGroup(UdpSocketPtr socket,
- const IPAddress& address,
- NetworkInterfaceIndex ifindex);
+Error BindUdpSocket(UdpSocketPtr socket,
+ const IPEndpoint& endpoint,
+ NetworkInterfaceIndex ifindex);
+Error JoinUdpMulticastGroup(UdpSocketPtr socket,
+ const IPAddress& address,
+ NetworkInterfaceIndex ifindex);
absl::optional<int64_t> ReceiveUdp(UdpSocketPtr socket,
void* data,
diff --git a/platform/base/event_loop.cc b/platform/base/event_loop.cc
index d4b7c84..66a68b3 100644
--- a/platform/base/event_loop.cc
+++ b/platform/base/event_loop.cc
@@ -14,41 +14,41 @@
ReceivedData::ReceivedData() = default;
ReceivedData::~ReceivedData() = default;
-bool ReceiveDataFromEvent(const UdpSocketReadableEvent& read_event,
- ReceivedData* data) {
+Error ReceiveDataFromEvent(const UdpSocketReadableEvent& read_event,
+ ReceivedData* data) {
OSP_DCHECK(data);
absl::optional<int> len =
ReceiveUdp(read_event.socket, &data->bytes[0], data->bytes.size(),
&data->source, &data->original_destination);
if (!len) {
OSP_LOG_ERROR << "recv() failed: " << GetLastErrorString();
- return false;
+ return Error::Code::kSocketReadFailure;
} else if (len == 0) {
OSP_LOG_WARN << "recv() = 0, closed?";
- return false;
+ return Error::Code::kSocketClosedFailure;
}
OSP_DCHECK_LE(*len, kUdpMaxPacketSize);
data->length = *len;
data->socket = read_event.socket;
- return true;
+ return Error::None();
}
std::vector<ReceivedData> HandleUdpSocketReadEvents(const Events& events) {
std::vector<ReceivedData> data;
for (const auto& read_event : events.udp_readable_events) {
ReceivedData next_data;
- if (ReceiveDataFromEvent(read_event, &next_data))
+ if (ReceiveDataFromEvent(read_event, &next_data).ok())
data.emplace_back(std::move(next_data));
}
return data;
}
std::vector<ReceivedData> OnePlatformLoopIteration(EventWaiterPtr waiter) {
- Events events;
- if (!WaitForEvents(waiter, &events))
+ ErrorOr<Events> events = WaitForEvents(waiter);
+ if (!events)
return {};
- return HandleUdpSocketReadEvents(events);
+ return HandleUdpSocketReadEvents(events.value());
}
} // namespace platform
diff --git a/platform/base/event_loop.h b/platform/base/event_loop.h
index 492df3e..9a2194a 100644
--- a/platform/base/event_loop.h
+++ b/platform/base/event_loop.h
@@ -11,6 +11,7 @@
#include <cstdint>
#include <vector>
+#include "base/error.h"
#include "base/ip_address.h"
#include "platform/api/event_waiter.h"
@@ -30,8 +31,8 @@
UdpSocketPtr socket;
};
-bool ReceiveDataFromEvent(const UdpSocketReadableEvent& read_event,
- ReceivedData* data);
+Error ReceiveDataFromEvent(const UdpSocketReadableEvent& read_event,
+ ReceivedData* data);
std::vector<ReceivedData> HandleUdpSocketReadEvents(const Events& events);
std::vector<ReceivedData> OnePlatformLoopIteration(EventWaiterPtr waiter);
diff --git a/platform/posix/event_waiter.cc b/platform/posix/event_waiter.cc
index 26292b7..fa85962 100644
--- a/platform/posix/event_waiter.cc
+++ b/platform/posix/event_waiter.cc
@@ -9,6 +9,7 @@
#include <algorithm>
#include <vector>
+#include "base/error.h"
#include "platform/api/logging.h"
#include "platform/posix/socket.h"
@@ -17,24 +18,24 @@
namespace {
template <typename T>
-bool WatchUdpSocket(std::vector<T>* watched_sockets, T socket) {
+Error WatchUdpSocket(std::vector<T>* watched_sockets, T socket) {
for (const auto* s : *watched_sockets) {
if (s->fd == socket->fd)
- return false;
+ return Error::Code::kAlreadyListening;
}
watched_sockets->push_back(socket);
- return true;
+ return Error::None();
}
template <typename T>
-bool StopWatchingUdpSocket(std::vector<T>* watched_sockets, T socket) {
+Error StopWatchingUdpSocket(std::vector<T>* watched_sockets, T socket) {
const auto it = std::find_if(watched_sockets->begin(), watched_sockets->end(),
[socket](T s) { return s->fd == socket->fd; });
if (it == watched_sockets->end())
- return false;
+ return Error::Code::kNoItemFound;
watched_sockets->erase(it);
- return true;
+ return Error::None();
}
} // namespace
@@ -52,35 +53,37 @@
delete waiter;
}
-bool WatchUdpSocketReadable(EventWaiterPtr waiter, UdpSocketPtr socket) {
+Error WatchUdpSocketReadable(EventWaiterPtr waiter, UdpSocketPtr socket) {
return WatchUdpSocket(&waiter->read_sockets, socket);
}
-bool StopWatchingUdpSocketReadable(EventWaiterPtr waiter, UdpSocketPtr socket) {
+Error StopWatchingUdpSocketReadable(EventWaiterPtr waiter,
+ UdpSocketPtr socket) {
return StopWatchingUdpSocket(&waiter->read_sockets, socket);
}
-bool WatchUdpSocketWritable(EventWaiterPtr waiter, UdpSocketPtr socket) {
+Error WatchUdpSocketWritable(EventWaiterPtr waiter, UdpSocketPtr socket) {
return WatchUdpSocket(&waiter->write_sockets, socket);
}
-bool StopWatchingUdpSocketWritable(EventWaiterPtr waiter, UdpSocketPtr socket) {
+Error StopWatchingUdpSocketWritable(EventWaiterPtr waiter,
+ UdpSocketPtr socket) {
return StopWatchingUdpSocket(&waiter->write_sockets, socket);
}
-bool WatchNetworkChange(EventWaiterPtr waiter) {
+Error WatchNetworkChange(EventWaiterPtr waiter) {
// TODO(btolsch): Implement network change watching.
OSP_UNIMPLEMENTED();
- return false;
+ return Error::Code::kNotImplemented;
}
-bool StopWatchingNetworkChange(EventWaiterPtr waiter) {
+Error StopWatchingNetworkChange(EventWaiterPtr waiter) {
// TODO(btolsch): Implement stop network change watching.
OSP_UNIMPLEMENTED();
- return false;
+ return Error::Code::kNotImplemented;
}
-int WaitForEvents(EventWaiterPtr waiter, Events* events) {
+ErrorOr<Events> WaitForEvents(EventWaiterPtr waiter) {
int max_fd = -1;
fd_set readfds;
fd_set writefds;
@@ -95,22 +98,24 @@
max_fd = std::max(max_fd, write_socket->fd);
}
if (max_fd == -1)
- return 0;
+ return Error::Code::kIOFailure;
struct timeval tv = {};
const int rv = select(max_fd + 1, &readfds, &writefds, nullptr, &tv);
if (rv == -1 || rv == 0)
- return rv;
+ return Error::Code::kIOFailure;
+ Events events;
for (auto* read_socket : waiter->read_sockets) {
if (FD_ISSET(read_socket->fd, &readfds))
- events->udp_readable_events.push_back({read_socket});
+ events.udp_readable_events.push_back({read_socket});
}
for (auto* write_socket : waiter->write_sockets) {
if (FD_ISSET(write_socket->fd, &writefds))
- events->udp_writable_events.push_back({write_socket});
+ events.udp_writable_events.push_back({write_socket});
}
- return rv;
+
+ return std::move(events);
}
} // namespace platform
diff --git a/platform/posix/socket.cc b/platform/posix/socket.cc
index 8839cce..a1f3d20 100644
--- a/platform/posix/socket.cc
+++ b/platform/posix/socket.cc
@@ -14,6 +14,7 @@
#include <cstring>
#include <memory>
+#include "base/error.h"
#include "platform/api/logging.h"
#include "platform/posix/socket.h"
@@ -74,9 +75,9 @@
delete socket;
}
-bool BindUdpSocket(UdpSocketPtr socket,
- const IPEndpoint& endpoint,
- NetworkInterfaceIndex ifindex) {
+Error BindUdpSocket(UdpSocketPtr socket,
+ const IPEndpoint& endpoint,
+ NetworkInterfaceIndex ifindex) {
OSP_DCHECK_GE(socket->fd, 0);
if (socket->version == UdpSocketPrivate::Version::kV4) {
if (ifindex > 0) {
@@ -89,14 +90,14 @@
if (setsockopt(socket->fd, IPPROTO_IP, IP_MULTICAST_IF,
&multicast_properties,
sizeof(multicast_properties)) == -1) {
- return false;
+ return Error::Code::kSocketOptionSettingFailure;
}
// This is effectively a boolean passed to setsockopt() to allow a future
// bind() on |socket| to succeed, even if the address is already in use.
const int reuse_addr = 1;
if (setsockopt(socket->fd, SOL_SOCKET, SO_REUSEADDR, &reuse_addr,
sizeof(reuse_addr)) == -1) {
- return false;
+ return Error::Code::kSocketOptionSettingFailure;
}
}
@@ -105,21 +106,24 @@
address.sin_port = htons(endpoint.port);
endpoint.address.CopyToV4(
reinterpret_cast<uint8_t*>(&address.sin_addr.s_addr));
- return bind(socket->fd, reinterpret_cast<struct sockaddr*>(&address),
- sizeof(address)) != -1;
+ if (bind(socket->fd, reinterpret_cast<struct sockaddr*>(&address),
+ sizeof(address)) == -1) {
+ return Error::Code::kSocketBindFailure;
+ }
+ return Error::None();
} else {
if (ifindex > 0) {
const auto index = static_cast<IPv6NetworkInterfaceIndex>(ifindex);
if (setsockopt(socket->fd, IPPROTO_IPV6, IPV6_MULTICAST_IF, &index,
sizeof(index)) == -1) {
- return false;
+ return Error::Code::kSocketOptionSettingFailure;
}
// This is effectively a boolean passed to setsockopt() to allow a future
// bind() on |socket| to succeed, even if the address is already in use.
const int reuse_addr = 1;
if (setsockopt(socket->fd, SOL_SOCKET, SO_REUSEADDR, &reuse_addr,
sizeof(reuse_addr)) == -1) {
- return false;
+ return Error::Code::kSocketOptionSettingFailure;
}
}
@@ -129,14 +133,18 @@
address.sin6_port = htons(endpoint.port);
endpoint.address.CopyToV6(reinterpret_cast<uint8_t*>(&address.sin6_addr));
address.sin6_scope_id = 0;
- return bind(socket->fd, reinterpret_cast<struct sockaddr*>(&address),
- sizeof(address)) != -1;
+
+ if (bind(socket->fd, reinterpret_cast<struct sockaddr*>(&address),
+ sizeof(address)) == -1) {
+ return Error::Code::kSocketBindFailure;
+ }
+ return Error::None();
}
}
-bool JoinUdpMulticastGroup(UdpSocketPtr socket,
- const IPAddress& address,
- NetworkInterfaceIndex ifindex) {
+Error JoinUdpMulticastGroup(UdpSocketPtr socket,
+ const IPAddress& address,
+ NetworkInterfaceIndex ifindex) {
OSP_DCHECK_GE(socket->fd, 0);
if (socket->version == UdpSocketPrivate::Version::kV4) {
// Passed as data to setsockopt(). 1 means return IP_PKTINFO control data
@@ -144,7 +152,7 @@
const int enable_pktinfo = 1;
if (setsockopt(socket->fd, IPPROTO_IP, IP_PKTINFO, &enable_pktinfo,
sizeof(enable_pktinfo)) == -1) {
- return false;
+ return Error::Code::kSocketOptionSettingFailure;
}
struct ip_mreqn multicast_properties;
// Appropriate address is set based on |imr_ifindex| when set.
@@ -157,16 +165,16 @@
reinterpret_cast<uint8_t*>(&multicast_properties.imr_multiaddr));
if (setsockopt(socket->fd, IPPROTO_IP, IP_ADD_MEMBERSHIP,
&multicast_properties, sizeof(multicast_properties)) == -1) {
- return false;
+ return Error::Code::kSocketOptionSettingFailure;
}
- return true;
+ return Error::None();
} else {
// Passed as data to setsockopt(). 1 means return IPV6_PKTINFO control data
// in recvmsg() calls.
const int enable_pktinfo = 1;
if (setsockopt(socket->fd, IPPROTO_IPV6, IPV6_RECVPKTINFO, &enable_pktinfo,
sizeof(enable_pktinfo)) == -1) {
- return false;
+ return Error::Code::kSocketOptionSettingFailure;
}
struct ipv6_mreq multicast_properties = {
{/* filled-in below */},
@@ -180,9 +188,9 @@
// synonymous with IPV6_ADD_MEMBERSHIP.
if (setsockopt(socket->fd, IPPROTO_IPV6, IPV6_JOIN_GROUP,
&multicast_properties, sizeof(multicast_properties)) == -1) {
- return false;
+ return Error::Code::kSocketOptionSettingFailure;
}
- return true;
+ return Error::None();
}
}