Add QuicClient+QuicServer+MessageDemuxer
This change adds a basic QUIC layer which includes a CBOR message
demuxer. It also includes a number of other changes due to varying
degrees of necessity.
1. A few .cc files were moved from api/public to api/impl when they
relied on other api/impl files to clean up the build dependencies.
2. Binding a UDP socket now uses the address of the IPEndpoint.
3. A platform LogInit function is added to optionally open a log file
or pipe instead of always using stdout.
Bug: openscreen:14
Change-Id: I18dae59d1a961328c45644f68f1e529d0cdb0e68
Reviewed-on: https://chromium-review.googlesource.com/c/1336451
Commit-Queue: Brandon Tolsch <btolsch@chromium.org>
Reviewed-by: mark a. foltz <mfoltz@chromium.org>
diff --git a/BUILD.gn b/BUILD.gn
index 724baf2..14c270a 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -19,17 +19,29 @@
]
}
-executable("demo") {
- sources = [
- "//demo/demo.cc",
- ]
+if (current_os == "mac") {
+ source_set("demo") {
+ }
+} else {
+ # TODO(btolsch): Darwin linker has deprecated -m, which handles the multiple
+ # definition error. Until the boringssl/mDNSResponder conflict is fixed
+ # (which appears to be exactly the same code in both projects), the demo is
+ # excluded on mac.
+ executable("demo") {
+ sources = [
+ "//demo/demo.cc",
+ ]
- deps = [
- "//api",
- "//base",
- "//discovery/mdns",
- "//platform",
- ]
+ # TODO(btolsch): Handles MD5_* conflict between boringssl and mDNSResponder
+ # temporarily.
+ ldflags = [ "-Wl,-z,muldefs" ]
+ deps = [
+ "//api:api_with_chromium_quic",
+ "//base",
+ "//discovery/mdns",
+ "//platform",
+ ]
+ }
}
executable("unittests") {
diff --git a/api/BUILD.gn b/api/BUILD.gn
index 73b3415..a435696 100644
--- a/api/BUILD.gn
+++ b/api/BUILD.gn
@@ -3,9 +3,20 @@
# found in the LICENSE file.
source_set("api") {
+ public_deps = [
+ "public:api",
+ ]
deps = [
"impl",
- "public:api",
+ ]
+}
+
+source_set("api_with_chromium_quic") {
+ public_deps = [
+ ":api",
+ ]
+ deps = [
+ "impl:chromium_quic_integration",
]
}
@@ -14,17 +25,20 @@
sources = [
"impl/mdns_responder_service_unittest.cc",
+ "impl/quic/quic_client_unittest.cc",
+ "impl/quic/quic_server_unittest.cc",
"impl/screen_list_unittest.cc",
"impl/screen_listener_impl_unittest.cc",
"impl/screen_publisher_impl_unittest.cc",
+ "public/message_demuxer_unittest.cc",
"public/screen_info_unittest.cc",
]
deps = [
- "impl",
+ ":api",
+ "impl/quic:test_support",
"impl/testing",
"impl/testing:fakes_unittests",
- "public:api",
"//third_party/googletest:gmock",
"//third_party/googletest:gtest",
]
diff --git a/api/impl/BUILD.gn b/api/impl/BUILD.gn
index 175a7c1..d5a936a 100644
--- a/api/impl/BUILD.gn
+++ b/api/impl/BUILD.gn
@@ -11,9 +11,8 @@
"mdns_responder_service.cc",
"mdns_responder_service.h",
"mdns_screen_listener_factory.cc",
- "mdns_screen_listener_factory.h",
"mdns_screen_publisher_factory.cc",
- "mdns_screen_publisher_factory.h",
+ "network_service_manager.cc",
"screen_list.cc",
"screen_list.h",
"screen_listener_impl.cc",
@@ -22,9 +21,34 @@
"screen_publisher_impl.h",
]
- deps = [
+ public_deps = [
"../public:api",
+ "//msgs",
+ ]
+ deps = [
+ "quic",
"//base",
"//platform",
]
}
+
+source_set("chromium_quic_integration") {
+ sources = [
+ "protocol_connection_client_factory.cc",
+ "protocol_connection_server_factory.cc",
+ "quic/quic_connection_factory_impl.cc",
+ "quic/quic_connection_factory_impl.h",
+ "quic/quic_connection_impl.cc",
+ "quic/quic_connection_impl.h",
+ ]
+
+ public_configs = [ "//third_party/chromium_quic:chromium_quic_config" ]
+
+ deps = [
+ "quic",
+ "//base",
+ "//msgs",
+ "//platform",
+ "//third_party/chromium_quic",
+ ]
+}
diff --git a/api/impl/internal_services.cc b/api/impl/internal_services.cc
index 9ef0d3e..c6bcb8f 100644
--- a/api/impl/internal_services.cc
+++ b/api/impl/internal_services.cc
@@ -23,11 +23,7 @@
0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xfb,
};
-const IPEndpoint kMulticastListeningEndpoint{IPAddress{0, 0, 0, 0}, 5353};
-const IPEndpoint kMulticastIPv6ListeningEndpoint{
- IPAddress{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00, 0x00, 0x00},
- 5353};
+const uint16_t kMulticastListeningPort = 5353;
class MdnsResponderAdapterImplFactory final
: public MdnsResponderAdapterFactory {
@@ -51,10 +47,7 @@
break;
}
- const IPEndpoint listen_endpoint = IsIPv6Socket(socket)
- ? kMulticastIPv6ListeningEndpoint
- : kMulticastListeningEndpoint;
- if (!BindUdpSocket(socket, listen_endpoint, ifindex)) {
+ if (!BindUdpSocket(socket, {{}, kMulticastListeningPort}, ifindex)) {
OSP_LOG_ERROR << "bind failed for interface " << ifindex << ": "
<< platform::GetLastErrorString();
break;
@@ -81,7 +74,7 @@
void InternalServices::RunEventLoopOnce() {
OSP_CHECK(g_instance) << "No listener or publisher is alive.";
g_instance->mdns_service_.HandleNewEvents(
- platform::OnePlatformLoopIteration(g_instance->internal_service_waiter_));
+ platform::OnePlatformLoopIteration(g_instance->mdns_waiter_));
}
// static
@@ -174,20 +167,20 @@
kServiceProtocol,
std::make_unique<MdnsResponderAdapterImplFactory>(),
std::make_unique<InternalPlatformLinkage>(this)),
- internal_service_waiter_(platform::CreateEventWaiter()) {
- OSP_DCHECK(internal_service_waiter_);
+ mdns_waiter_(platform::CreateEventWaiter()) {
+ OSP_DCHECK(mdns_waiter_);
}
InternalServices::~InternalServices() {
- DestroyEventWaiter(internal_service_waiter_);
+ DestroyEventWaiter(mdns_waiter_);
}
void InternalServices::RegisterMdnsSocket(platform::UdpSocketPtr socket) {
- platform::WatchUdpSocketReadable(internal_service_waiter_, socket);
+ platform::WatchUdpSocketReadable(mdns_waiter_, socket);
}
void InternalServices::DeregisterMdnsSocket(platform::UdpSocketPtr socket) {
- platform::StopWatchingUdpSocketReadable(internal_service_waiter_, socket);
+ platform::StopWatchingUdpSocketReadable(mdns_waiter_, socket);
}
// static
diff --git a/api/impl/internal_services.h b/api/impl/internal_services.h
index 0de74f2..113564e 100644
--- a/api/impl/internal_services.h
+++ b/api/impl/internal_services.h
@@ -10,10 +10,14 @@
#include "api/impl/mdns_platform_service.h"
#include "api/impl/mdns_responder_service.h"
-#include "api/impl/mdns_screen_listener_factory.h"
-#include "api/impl/mdns_screen_publisher_factory.h"
+#include "api/impl/quic/quic_connection_factory.h"
#include "api/impl/screen_listener_impl.h"
#include "api/impl/screen_publisher_impl.h"
+#include "api/public/mdns_screen_listener_factory.h"
+#include "api/public/mdns_screen_publisher_factory.h"
+#include "api/public/protocol_connection_client.h"
+#include "api/public/protocol_connection_server.h"
+#include "base/ip_address.h"
#include "base/macros.h"
#include "platform/api/event_waiter.h"
#include "platform/api/network_interface.h"
@@ -70,7 +74,7 @@
// - remember who registered for what in a wrapper here
// - something else...
// Currently, RegisterMdnsSocket is our hook to do 1 or 2.
- platform::EventWaiterPtr internal_service_waiter_;
+ platform::EventWaiterPtr mdns_waiter_;
DISALLOW_COPY_AND_ASSIGN(InternalServices);
};
diff --git a/api/impl/mdns_responder_service.h b/api/impl/mdns_responder_service.h
index 2fe46b2..546d1c3 100644
--- a/api/impl/mdns_responder_service.h
+++ b/api/impl/mdns_responder_service.h
@@ -75,6 +75,7 @@
// NOTE: hostname implicit in map key.
struct HostnameWatchers {
std::vector<ServiceInstance*> services;
+ // TODO(btolsch): std::vector<IPAddress>
IPAddress address;
};
diff --git a/api/impl/mdns_screen_listener_factory.cc b/api/impl/mdns_screen_listener_factory.cc
index f9ab9a5..9ce282a 100644
--- a/api/impl/mdns_screen_listener_factory.cc
+++ b/api/impl/mdns_screen_listener_factory.cc
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
-#include "api/impl/mdns_screen_listener_factory.h"
+#include "api/public/mdns_screen_listener_factory.h"
#include "api/impl/internal_services.h"
diff --git a/api/impl/mdns_screen_publisher_factory.cc b/api/impl/mdns_screen_publisher_factory.cc
index ed31f54..ed57975 100644
--- a/api/impl/mdns_screen_publisher_factory.cc
+++ b/api/impl/mdns_screen_publisher_factory.cc
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
-#include "api/impl/mdns_screen_publisher_factory.h"
+#include "api/public/mdns_screen_publisher_factory.h"
#include "api/impl/internal_services.h"
diff --git a/api/public/network_service_manager.cc b/api/impl/network_service_manager.cc
similarity index 95%
rename from api/public/network_service_manager.cc
rename to api/impl/network_service_manager.cc
index 14a4fb9..bbf9f4b 100644
--- a/api/public/network_service_manager.cc
+++ b/api/impl/network_service_manager.cc
@@ -48,6 +48,10 @@
void NetworkServiceManager::RunEventLoopOnce() {
InternalServices::RunEventLoopOnce();
+ if (connection_client_)
+ connection_client_->RunTasks();
+ if (connection_server_)
+ connection_server_->RunTasks();
}
ScreenListener* NetworkServiceManager::GetMdnsScreenListener() {
diff --git a/api/impl/protocol_connection_client_factory.cc b/api/impl/protocol_connection_client_factory.cc
new file mode 100644
index 0000000..431b709
--- /dev/null
+++ b/api/impl/protocol_connection_client_factory.cc
@@ -0,0 +1,22 @@
+// Copyright 2018 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "api/public/protocol_connection_client_factory.h"
+
+#include "api/impl/quic/quic_client.h"
+#include "api/impl/quic/quic_connection_factory_impl.h"
+#include "base/make_unique.h"
+
+namespace openscreen {
+
+// static
+std::unique_ptr<ProtocolConnectionClient>
+ProtocolConnectionClientFactory::Create(
+ MessageDemuxer* demuxer,
+ ProtocolConnectionServiceObserver* observer) {
+ return MakeUnique<QuicClient>(
+ demuxer, MakeUnique<QuicConnectionFactoryImpl>(), observer);
+}
+
+} // namespace openscreen
diff --git a/api/impl/protocol_connection_server_factory.cc b/api/impl/protocol_connection_server_factory.cc
new file mode 100644
index 0000000..b82f307
--- /dev/null
+++ b/api/impl/protocol_connection_server_factory.cc
@@ -0,0 +1,23 @@
+// Copyright 2018 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "api/public/protocol_connection_server_factory.h"
+
+#include "api/impl/quic/quic_connection_factory_impl.h"
+#include "api/impl/quic/quic_server.h"
+#include "base/make_unique.h"
+
+namespace openscreen {
+
+// static
+std::unique_ptr<ProtocolConnectionServer>
+ProtocolConnectionServerFactory::Create(
+ const ServerConfig& config,
+ MessageDemuxer* demuxer,
+ ProtocolConnectionServer::Observer* observer) {
+ return MakeUnique<QuicServer>(
+ config, demuxer, MakeUnique<QuicConnectionFactoryImpl>(), observer);
+}
+
+} // namespace openscreen
diff --git a/api/impl/quic/BUILD.gn b/api/impl/quic/BUILD.gn
new file mode 100644
index 0000000..c1ec3f9
--- /dev/null
+++ b/api/impl/quic/BUILD.gn
@@ -0,0 +1,37 @@
+# Copyright 2018 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+source_set("quic") {
+ sources = [
+ "quic_client.cc",
+ "quic_client.h",
+ "quic_connection.h",
+ "quic_connection_factory.h",
+ "quic_server.cc",
+ "quic_server.h",
+ "quic_service_common.cc",
+ "quic_service_common.h",
+ ]
+
+ deps = [
+ "../../public:api",
+ "//base",
+ "//platform",
+ ]
+}
+
+source_set("test_support") {
+ sources = [
+ "testing/fake_quic_connection.cc",
+ "testing/fake_quic_connection.h",
+ "testing/fake_quic_connection_factory.cc",
+ "testing/fake_quic_connection_factory.h",
+ ]
+
+ deps = [
+ "//base",
+ "//msgs",
+ "//platform",
+ ]
+}
diff --git a/api/impl/quic/quic_client.cc b/api/impl/quic/quic_client.cc
new file mode 100644
index 0000000..e3dcbd4
--- /dev/null
+++ b/api/impl/quic/quic_client.cc
@@ -0,0 +1,206 @@
+// Copyright 2018 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "api/impl/quic/quic_client.h"
+
+#include <algorithm>
+
+#include "base/make_unique.h"
+#include "platform/api/logging.h"
+
+namespace openscreen {
+
+QuicClient::QuicClient(
+ MessageDemuxer* demuxer,
+ std::unique_ptr<QuicConnectionFactory> connection_factory,
+ ProtocolConnectionServiceObserver* observer)
+ : ProtocolConnectionClient(demuxer, observer),
+ connection_factory_(std::move(connection_factory)) {}
+
+QuicClient::~QuicClient() {
+ CloseAllConnections();
+}
+
+bool QuicClient::Start() {
+ if (state_ == State::kRunning)
+ return false;
+ state_ = State::kRunning;
+ observer_->OnRunning();
+ return true;
+}
+
+bool QuicClient::Stop() {
+ if (state_ == State::kStopped)
+ return false;
+ CloseAllConnections();
+ state_ = State::kStopped;
+ observer_->OnStopped();
+ return true;
+}
+
+void QuicClient::RunTasks() {
+ connection_factory_->RunTasks();
+ for (auto& entry : delete_connections_)
+ connections_.erase(entry);
+ delete_connections_.clear();
+}
+
+QuicClient::ConnectRequest QuicClient::Connect(
+ const IPEndpoint& endpoint,
+ ConnectionRequestCallback* request) {
+ if (state_ != State::kRunning)
+ return ConnectRequest(this, 0);
+ auto endpoint_entry = endpoint_map_.find(endpoint);
+ if (endpoint_entry != endpoint_map_.end()) {
+ auto connection_entry = connections_.find(endpoint_entry->second);
+ if (connection_entry != connections_.end()) {
+ std::unique_ptr<QuicProtocolConnection> pc =
+ QuicProtocolConnection::FromExisting(
+ this, connection_entry->second.connection.get(),
+ connection_entry->second.delegate.get(), endpoint_entry->second);
+ request->OnConnectionOpened(0, std::move(pc));
+ return ConnectRequest(this, 0);
+ }
+ auto pending_entry = pending_connections_.find(endpoint);
+ if (pending_entry == pending_connections_.end()) {
+ uint64_t request_id = StartConnectionRequest(endpoint, request);
+ return ConnectRequest(this, request_id);
+ } else {
+ uint64_t request_id = next_request_id_++;
+ pending_entry->second.callbacks.emplace_back(request_id, request);
+ return ConnectRequest(this, request_id);
+ }
+ }
+
+ uint64_t request_id = StartConnectionRequest(endpoint, request);
+ return ConnectRequest(this, request_id);
+}
+
+std::unique_ptr<ProtocolConnection> QuicClient::CreateProtocolConnection(
+ uint64_t endpoint_id) {
+ if (state_ != State::kRunning)
+ return nullptr;
+ auto connection_entry = connections_.find(endpoint_id);
+ if (connection_entry == connections_.end())
+ return nullptr;
+ return QuicProtocolConnection::FromExisting(
+ this, connection_entry->second.connection.get(),
+ connection_entry->second.delegate.get(), endpoint_id);
+}
+
+void QuicClient::OnConnectionDestroyed(QuicProtocolConnection* connection) {
+ auto connection_entry = connections_.find(connection->endpoint_id());
+ if (connection_entry == connections_.end())
+ return;
+ connection_entry->second.delegate->DropProtocolConnection(connection);
+}
+
+uint64_t QuicClient::OnCryptoHandshakeComplete(
+ ServiceConnectionDelegate* delegate,
+ uint64_t connection_id) {
+ const IPEndpoint& endpoint = delegate->endpoint();
+ auto pending_entry = pending_connections_.find(endpoint);
+ if (pending_entry == pending_connections_.end())
+ return 0;
+ ServiceConnectionData connection_data = std::move(pending_entry->second.data);
+ auto* connection = connection_data.connection.get();
+ uint64_t endpoint_id = next_endpoint_id_++;
+ endpoint_map_[endpoint] = endpoint_id;
+ connections_.emplace(endpoint_id, std::move(connection_data));
+
+ for (auto& request : pending_entry->second.callbacks) {
+ request_map_.erase(request.first);
+ std::unique_ptr<QuicProtocolConnection> pc =
+ QuicProtocolConnection::FromExisting(this, connection, delegate,
+ endpoint_id);
+ request_map_.erase(request.first);
+ request.second->OnConnectionOpened(request.first, std::move(pc));
+ }
+ pending_connections_.erase(pending_entry);
+ return endpoint_id;
+}
+
+void QuicClient::OnIncomingStream(
+ std::unique_ptr<QuicProtocolConnection>&& connection) {
+ connection->CloseWriteEnd();
+ connection.reset();
+}
+
+void QuicClient::OnConnectionClosed(uint64_t endpoint_id,
+ uint64_t connection_id) {
+ // TODO(btolsch): Is this how handshake failure is communicated to the
+ // delegate?
+ auto connection_entry = connections_.find(endpoint_id);
+ if (connection_entry == connections_.end())
+ return;
+ delete_connections_.emplace_back(connection_entry);
+}
+
+void QuicClient::OnDataReceived(uint64_t endpoint_id,
+ uint64_t connection_id,
+ const uint8_t* data,
+ size_t data_size) {
+ demuxer_->OnStreamData(endpoint_id, connection_id, data, data_size);
+}
+
+QuicClient::PendingConnectionData::PendingConnectionData(
+ ServiceConnectionData&& data)
+ : data(std::move(data)) {}
+QuicClient::PendingConnectionData::PendingConnectionData(
+ PendingConnectionData&&) = default;
+QuicClient::PendingConnectionData::~PendingConnectionData() = default;
+QuicClient::PendingConnectionData& QuicClient::PendingConnectionData::operator=(
+ PendingConnectionData&&) = default;
+
+uint64_t QuicClient::StartConnectionRequest(
+ const IPEndpoint& endpoint,
+ ConnectionRequestCallback* request) {
+ auto delegate = MakeUnique<ServiceConnectionDelegate>(this, endpoint);
+ std::unique_ptr<QuicConnection> connection =
+ connection_factory_->Connect(endpoint, delegate.get());
+ auto pending_result = pending_connections_.emplace(
+ endpoint, PendingConnectionData(ServiceConnectionData(
+ std::move(connection), std::move(delegate))));
+ uint64_t request_id = next_request_id_++;
+ pending_result.first->second.callbacks.emplace_back(request_id, request);
+ return request_id;
+}
+
+void QuicClient::CloseAllConnections() {
+ for (auto& conn : pending_connections_)
+ conn.second.data.connection->Close();
+ pending_connections_.clear();
+ for (auto& conn : connections_)
+ conn.second.connection->Close();
+ connections_.clear();
+ endpoint_map_.clear();
+ next_endpoint_id_ = 0;
+ for (auto& request : request_map_) {
+ request.second.second->OnConnectionFailed(request.first);
+ }
+ request_map_.clear();
+}
+
+void QuicClient::CancelConnectRequest(uint64_t request_id) {
+ auto request_entry = request_map_.find(request_id);
+ if (request_entry == request_map_.end())
+ return;
+ auto pending_entry = pending_connections_.find(request_entry->second.first);
+ if (pending_entry != pending_connections_.end()) {
+ auto& callbacks = pending_entry->second.callbacks;
+ callbacks.erase(
+ std::remove_if(
+ callbacks.begin(), callbacks.end(),
+ [request_id](const std::pair<uint64_t, ConnectionRequestCallback*>&
+ callback) {
+ return request_id == callback.first;
+ }),
+ callbacks.end());
+ if (callbacks.empty())
+ pending_connections_.erase(pending_entry);
+ }
+ request_map_.erase(request_entry);
+}
+
+} // namespace openscreen
diff --git a/api/impl/quic/quic_client.h b/api/impl/quic/quic_client.h
new file mode 100644
index 0000000..cfedf48
--- /dev/null
+++ b/api/impl/quic/quic_client.h
@@ -0,0 +1,125 @@
+// Copyright 2018 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef API_IMPL_QUIC_QUIC_CLIENT_H_
+#define API_IMPL_QUIC_QUIC_CLIENT_H_
+
+#include <cstdint>
+#include <map>
+#include <memory>
+#include <vector>
+
+#include "api/impl/quic/quic_connection_factory.h"
+#include "api/impl/quic/quic_service_common.h"
+#include "api/public/protocol_connection_client.h"
+#include "base/ip_address.h"
+
+namespace openscreen {
+
+// This class is the default implementation of ProtocolConnectionClient for the
+// library. It manages connections to other endpoints as well as the lifetime
+// of each incoming and outgoing stream. It works in conjunction with a
+// QuicConnectionFactory implementation and MessageDemuxer.
+// QuicConnectionFactory provides the actual ability to make a new QUIC
+// connection with another endpoint. Incoming data is given to the QuicClient
+// by the underlying QUIC implementation (through QuicConnectionFactory) and
+// this is in turn handed to MessageDemuxer for routing CBOR messages.
+//
+// The two most significant methods of this class are Connect and
+// CreateProtocolConnection. Both will return a new QUIC stream to a given
+// endpoint to which the caller can write but the former is allowed to be
+// asynchronous. If there isn't currently a connection to the specified
+// endpoint, Connect will start a connection attempt and store the callback for
+// when the connection completes. CreateProtocolConnection simply returns
+// nullptr if there's no existing connection.
+class QuicClient final : public ProtocolConnectionClient,
+ public ServiceConnectionDelegate::ServiceDelegate {
+ public:
+ QuicClient(MessageDemuxer* demuxer,
+ std::unique_ptr<QuicConnectionFactory> connection_factory,
+ ProtocolConnectionServiceObserver* observer);
+ ~QuicClient() override;
+
+ // ProtocolConnectionClient overrides.
+ bool Start() override;
+ bool Stop() override;
+ void RunTasks() override;
+ ConnectRequest Connect(const IPEndpoint& endpoint,
+ ConnectionRequestCallback* request) override;
+ std::unique_ptr<ProtocolConnection> CreateProtocolConnection(
+ uint64_t endpoint_id) override;
+
+ // QuicProtocolConnection::Owner overrides.
+ void OnConnectionDestroyed(QuicProtocolConnection* connection) override;
+
+ // ServiceConnectionDelegate::ServiceDelegate overrides.
+ uint64_t OnCryptoHandshakeComplete(ServiceConnectionDelegate* delegate,
+ uint64_t connection_id) override;
+ void OnIncomingStream(
+ std::unique_ptr<QuicProtocolConnection>&& connection) override;
+ void OnConnectionClosed(uint64_t endpoint_id,
+ uint64_t connection_id) override;
+ void OnDataReceived(uint64_t endpoint_id,
+ uint64_t connection_id,
+ const uint8_t* data,
+ size_t data_size) override;
+
+ private:
+ struct PendingConnectionData {
+ explicit PendingConnectionData(ServiceConnectionData&& data);
+ PendingConnectionData(PendingConnectionData&&);
+ ~PendingConnectionData();
+ PendingConnectionData& operator=(PendingConnectionData&&);
+
+ ServiceConnectionData data;
+
+ // Pairs of request IDs and the associated connection callback.
+ std::vector<std::pair<uint64_t, ConnectionRequestCallback*>> callbacks;
+ };
+
+ uint64_t StartConnectionRequest(const IPEndpoint& endpoint,
+ ConnectionRequestCallback* request);
+ void CloseAllConnections();
+ std::unique_ptr<QuicProtocolConnection> MakeProtocolConnection(
+ QuicConnection* connection,
+ ServiceConnectionDelegate* delegate,
+ uint64_t endpoint_id);
+
+ void CancelConnectRequest(uint64_t request_id) override;
+
+ std::unique_ptr<QuicConnectionFactory> connection_factory_;
+
+ // Maps an IPEndpoint to a generated endpoint ID. This is used to insulate
+ // callers from post-handshake changes to a connections actual peer endpoint.
+ std::map<IPEndpoint, uint64_t, IPEndpointComparator> endpoint_map_;
+
+ // Value that will be used for the next new endpoint in a Connect call.
+ uint64_t next_endpoint_id_ = 1;
+
+ // Maps request IDs to their callbacks. The callback is paired with the
+ // IPEndpoint it originally requested to connect to so cancelling the request
+ // can also remove a pending connection.
+ std::map<uint64_t, std::pair<IPEndpoint, ConnectionRequestCallback*>>
+ request_map_;
+
+ // Value that will be used for the next new connection request.
+ uint64_t next_request_id_ = 1;
+
+ // Maps endpoint addresses to data about connections that haven't successfully
+ // completed the QUIC handshake.
+ std::map<IPEndpoint, PendingConnectionData, IPEndpointComparator>
+ pending_connections_;
+
+ // Maps endpoint IDs to data about connections that have successfully
+ // completed the QUIC handshake.
+ std::map<uint64_t, ServiceConnectionData> connections_;
+
+ // Connections that need to be destroyed, but have to wait for the next event
+ // loop due to the underlying QUIC implementation's way of referencing them.
+ std::vector<decltype(connections_)::iterator> delete_connections_;
+};
+
+} // namespace openscreen
+
+#endif // API_IMPL_QUIC_QUIC_CLIENT_H_
diff --git a/api/impl/quic/quic_client_unittest.cc b/api/impl/quic/quic_client_unittest.cc
new file mode 100644
index 0000000..ff93176
--- /dev/null
+++ b/api/impl/quic/quic_client_unittest.cc
@@ -0,0 +1,232 @@
+// Copyright 2018 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "api/impl/quic/quic_client.h"
+
+#include "api/impl/quic/quic_service_common.h"
+#include "api/impl/quic/testing/fake_quic_connection_factory.h"
+#include "api/public/network_metrics.h"
+#include "base/error.h"
+#include "base/make_unique.h"
+#include "platform/api/logging.h"
+#include "third_party/googletest/src/googlemock/include/gmock/gmock.h"
+#include "third_party/googletest/src/googletest/include/gtest/gtest.h"
+
+namespace openscreen {
+namespace {
+
+using ::testing::_;
+using ::testing::Invoke;
+
+class MockServiceObserver final : public ProtocolConnectionServiceObserver {
+ public:
+ ~MockServiceObserver() override = default;
+
+ MOCK_METHOD0(OnRunning, void());
+ MOCK_METHOD0(OnStopped, void());
+ MOCK_METHOD1(OnMetrics, void(const NetworkMetrics& metrics));
+ MOCK_METHOD1(OnError, void(const Error& error));
+};
+
+class MockMessageCallback final : public MessageDemuxer::MessageCallback {
+ public:
+ ~MockMessageCallback() override = default;
+
+ MOCK_METHOD5(OnStreamMessage,
+ ErrorOr<size_t>(uint64_t endpoint_id,
+ uint64_t connection_id,
+ msgs::Type message_type,
+ const uint8_t* buffer,
+ size_t buffer_size));
+};
+
+class MockConnectionObserver final : public ProtocolConnection::Observer {
+ public:
+ ~MockConnectionObserver() override = default;
+
+ MOCK_METHOD1(OnConnectionChanged, void(const ProtocolConnection& connection));
+ MOCK_METHOD1(OnConnectionClosed, void(const ProtocolConnection& connection));
+};
+
+class ConnectionCallback final
+ : public ProtocolConnectionClient::ConnectionRequestCallback {
+ public:
+ explicit ConnectionCallback(std::unique_ptr<ProtocolConnection>* connection)
+ : connection_(connection) {}
+ ~ConnectionCallback() override = default;
+
+ bool failed() const { return failed_; }
+
+ void OnConnectionOpened(
+ uint64_t request_id,
+ std::unique_ptr<ProtocolConnection>&& connection) override {
+ OSP_DCHECK(!failed_ && !*connection_);
+ *connection_ = std::move(connection);
+ }
+
+ void OnConnectionFailed(uint64_t request_id) override {
+ OSP_DCHECK(!failed_ && !*connection_);
+ failed_ = true;
+ }
+
+ private:
+ bool failed_ = false;
+ std::unique_ptr<ProtocolConnection>* const connection_;
+};
+
+class QuicClientTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ auto connection_factory = MakeUnique<FakeQuicConnectionFactory>(
+ local_endpoint_, &server_demuxer_);
+ connection_factory_ = connection_factory.get();
+ client_ = MakeUnique<QuicClient>(&demuxer_, std::move(connection_factory),
+ &mock_observer_);
+ }
+
+ void RunTasksUntilIdle() {
+ do {
+ client_->RunTasks();
+ } while (!connection_factory_->idle());
+ }
+
+ void SendTestMessage(ProtocolConnection* connection) {
+ MockMessageCallback mock_message_callback;
+ MessageDemuxer::MessageWatch message_watch =
+ server_demuxer_.WatchMessageType(
+ 0, msgs::Type::kPresentationConnectionMessage,
+ &mock_message_callback);
+
+ msgs::CborEncodeBuffer buffer;
+ msgs::PresentationConnectionMessage message;
+ message.presentation_id = "some-id";
+ message.connection_id = 7;
+ message.message.which = decltype(message.message.which)::kString;
+ new (&message.message.str) std::string("message from client");
+ ASSERT_TRUE(msgs::EncodePresentationConnectionMessage(message, &buffer));
+ connection->Write(buffer.data(), buffer.size());
+ connection->CloseWriteEnd();
+
+ ssize_t decode_result = 0;
+ msgs::PresentationConnectionMessage received_message;
+ EXPECT_CALL(
+ mock_message_callback,
+ OnStreamMessage(0, connection->connection_id(),
+ msgs::Type::kPresentationConnectionMessage, _, _))
+ .WillOnce(Invoke([&decode_result, &received_message](
+ uint64_t endpoint_id, uint64_t connection_id,
+ msgs::Type message_type, const uint8_t* buffer,
+ size_t buffer_size) {
+ decode_result = msgs::DecodePresentationConnectionMessage(
+ buffer, buffer_size, &received_message);
+ if (decode_result < 0)
+ return ErrorOr<size_t>(Error::Code::kCborParsing);
+ return ErrorOr<size_t>(decode_result);
+ }));
+ RunTasksUntilIdle();
+
+ ASSERT_GT(decode_result, 0);
+ EXPECT_EQ(decode_result, static_cast<ssize_t>(buffer.size() - 1));
+ EXPECT_EQ(received_message.presentation_id, message.presentation_id);
+ EXPECT_EQ(received_message.connection_id, message.connection_id);
+ ASSERT_EQ(received_message.message.which,
+ decltype(received_message.message.which)::kString);
+ EXPECT_EQ(received_message.message.str, message.message.str);
+ }
+
+ const IPEndpoint local_endpoint_{{192, 168, 1, 10}, 44327};
+ const IPEndpoint server_endpoint_{{192, 168, 1, 15}, 54368};
+ MessageDemuxer demuxer_;
+ MessageDemuxer server_demuxer_;
+ FakeQuicConnectionFactory* connection_factory_;
+ MockServiceObserver mock_observer_;
+ std::unique_ptr<QuicClient> client_;
+};
+
+} // namespace
+
+TEST_F(QuicClientTest, Connect) {
+ client_->Start();
+
+ std::unique_ptr<ProtocolConnection> connection;
+ ConnectionCallback connection_callback(&connection);
+ ProtocolConnectionClient::ConnectRequest request =
+ client_->Connect(server_endpoint_, &connection_callback);
+ ASSERT_TRUE(request);
+
+ RunTasksUntilIdle();
+ ASSERT_TRUE(connection);
+
+ SendTestMessage(connection.get());
+
+ client_->Stop();
+}
+
+TEST_F(QuicClientTest, OpenImmediate) {
+ client_->Start();
+
+ std::unique_ptr<ProtocolConnection> connection1;
+ std::unique_ptr<ProtocolConnection> connection2;
+
+ connection2 = client_->CreateProtocolConnection(1);
+ EXPECT_FALSE(connection2);
+
+ ConnectionCallback connection_callback(&connection1);
+ ProtocolConnectionClient::ConnectRequest request =
+ client_->Connect(server_endpoint_, &connection_callback);
+ ASSERT_TRUE(request);
+
+ connection2 = client_->CreateProtocolConnection(1);
+ EXPECT_FALSE(connection2);
+
+ RunTasksUntilIdle();
+ ASSERT_TRUE(connection1);
+
+ connection2 = client_->CreateProtocolConnection(connection1->endpoint_id());
+ ASSERT_TRUE(connection2);
+
+ SendTestMessage(connection2.get());
+
+ client_->Stop();
+}
+
+TEST_F(QuicClientTest, States) {
+ std::unique_ptr<ProtocolConnection> connection1;
+ ConnectionCallback connection_callback(&connection1);
+ ProtocolConnectionClient::ConnectRequest request =
+ client_->Connect(server_endpoint_, &connection_callback);
+ EXPECT_FALSE(request);
+ std::unique_ptr<ProtocolConnection> connection2 =
+ client_->CreateProtocolConnection(1);
+ EXPECT_FALSE(connection2);
+
+ EXPECT_CALL(mock_observer_, OnRunning());
+ EXPECT_TRUE(client_->Start());
+ EXPECT_FALSE(client_->Start());
+
+ request = client_->Connect(server_endpoint_, &connection_callback);
+ ASSERT_TRUE(request);
+ RunTasksUntilIdle();
+ ASSERT_TRUE(connection1);
+ MockConnectionObserver mock_connection_observer1;
+ connection1->SetObserver(&mock_connection_observer1);
+
+ connection2 = client_->CreateProtocolConnection(connection1->endpoint_id());
+ ASSERT_TRUE(connection2);
+ MockConnectionObserver mock_connection_observer2;
+ connection2->SetObserver(&mock_connection_observer2);
+
+ EXPECT_CALL(mock_connection_observer1, OnConnectionClosed(_));
+ EXPECT_CALL(mock_connection_observer2, OnConnectionClosed(_));
+ EXPECT_CALL(mock_observer_, OnStopped());
+ EXPECT_TRUE(client_->Stop());
+ EXPECT_FALSE(client_->Stop());
+
+ request = client_->Connect(server_endpoint_, &connection_callback);
+ EXPECT_FALSE(request);
+ connection2 = client_->CreateProtocolConnection(1);
+ EXPECT_FALSE(connection2);
+}
+
+} // namespace openscreen
diff --git a/api/impl/quic/quic_connection.h b/api/impl/quic/quic_connection.h
new file mode 100644
index 0000000..0503683
--- /dev/null
+++ b/api/impl/quic/quic_connection.h
@@ -0,0 +1,86 @@
+// Copyright 2018 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef API_IMPL_QUIC_QUIC_CONNECTION_H_
+#define API_IMPL_QUIC_QUIC_CONNECTION_H_
+
+#include <memory>
+#include <vector>
+
+#include "platform/api/socket.h"
+#include "platform/base/event_loop.h"
+
+namespace openscreen {
+
+class QuicStream {
+ public:
+ class Delegate {
+ public:
+ virtual ~Delegate() = default;
+
+ virtual void OnReceived(QuicStream* stream,
+ const char* data,
+ size_t data_size) = 0;
+ virtual void OnClose(uint64_t stream_id) = 0;
+ };
+
+ QuicStream(Delegate* delegate, uint64_t id) : delegate_(delegate), id_(id) {}
+ virtual ~QuicStream() = default;
+
+ uint64_t id() const { return id_; }
+ virtual void Write(const uint8_t* data, size_t data_size) = 0;
+ virtual void CloseWriteEnd() = 0;
+
+ protected:
+ Delegate* const delegate_;
+ uint64_t id_;
+};
+
+class QuicConnection {
+ public:
+ class Delegate {
+ public:
+ virtual ~Delegate() = default;
+
+ // Called when the QUIC handshake has successfully completed.
+ virtual void OnCryptoHandshakeComplete(uint64_t connection_id) = 0;
+
+ // Called when a new stream on this connection is initiated by the other
+ // endpoint. |stream| will use a delegate returned by NextStreamDelegate.
+ virtual void OnIncomingStream(uint64_t connection_id,
+ std::unique_ptr<QuicStream> stream) = 0;
+
+ // Called when the QUIC connection was closed. The QuicConnection should
+ // not be destroyed immediately, because the QUIC implementation will still
+ // reference it briefly. Instead, it should be destroyed during the next
+ // event loop.
+ // TODO(btolsch): Hopefully this can be changed with future QUIC
+ // implementations.
+ virtual void OnConnectionClosed(uint64_t connection_id) = 0;
+
+ // This is used to get a QuicStream::Delegate for an incoming stream, which
+ // will be returned via OnIncomingStream immediately after this call.
+ virtual QuicStream::Delegate* NextStreamDelegate(uint64_t connection_id,
+ uint64_t stream_id) = 0;
+ };
+
+ explicit QuicConnection(Delegate* delegate) : delegate_(delegate) {}
+ virtual ~QuicConnection() = default;
+
+ // Passes a received UDP packet to the QUIC implementation. If this contains
+ // any stream data, it will be passed automatically to the relevant
+ // QuicStream::Delegate objects.
+ virtual void OnDataReceived(const platform::ReceivedData& data) = 0;
+
+ virtual std::unique_ptr<QuicStream> MakeOutgoingStream(
+ QuicStream::Delegate* delegate) = 0;
+ virtual void Close() = 0;
+
+ protected:
+ Delegate* const delegate_;
+};
+
+} // namespace openscreen
+
+#endif // API_IMPL_QUIC_QUIC_CONNECTION_H_
diff --git a/api/impl/quic/quic_connection_factory.h b/api/impl/quic/quic_connection_factory.h
new file mode 100644
index 0000000..22a1076
--- /dev/null
+++ b/api/impl/quic/quic_connection_factory.h
@@ -0,0 +1,48 @@
+// Copyright 2018 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef API_IMPL_QUIC_QUIC_CONNECTION_FACTORY_H_
+#define API_IMPL_QUIC_QUIC_CONNECTION_FACTORY_H_
+
+#include <memory>
+#include <vector>
+
+#include "api/impl/quic/quic_connection.h"
+#include "base/ip_address.h"
+
+namespace openscreen {
+
+// This interface provides a way to make new QUIC connections to endpoints. It
+// also provides a way to receive incoming QUIC connections (as a server).
+class QuicConnectionFactory {
+ public:
+ class ServerDelegate {
+ public:
+ virtual ~ServerDelegate() = default;
+
+ virtual QuicConnection::Delegate* NextConnectionDelegate(
+ const IPEndpoint& source) = 0;
+ virtual void OnIncomingConnection(
+ std::unique_ptr<QuicConnection>&& connection) = 0;
+ };
+
+ virtual ~QuicConnectionFactory() = default;
+
+ // Initializes a server socket listening on |port| where new connection
+ // callbacks are sent to |delegate|.
+ virtual void SetServerDelegate(ServerDelegate* delegate,
+ const std::vector<IPEndpoint>& endpoints) = 0;
+
+ // Listen for incoming network packets on both client and server sockets and
+ // dispatch any results.
+ virtual void RunTasks() = 0;
+
+ virtual std::unique_ptr<QuicConnection> Connect(
+ const IPEndpoint& endpoint,
+ QuicConnection::Delegate* connection_delegate) = 0;
+};
+
+} // namespace openscreen
+
+#endif // API_IMPL_QUIC_QUIC_CONNECTION_FACTORY_H_
diff --git a/api/impl/quic/quic_connection_factory_impl.cc b/api/impl/quic/quic_connection_factory_impl.cc
new file mode 100644
index 0000000..6ec0200
--- /dev/null
+++ b/api/impl/quic/quic_connection_factory_impl.cc
@@ -0,0 +1,179 @@
+// Copyright 2018 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "api/impl/quic/quic_connection_factory_impl.h"
+
+#include <algorithm>
+
+#include "api/impl/quic/quic_connection_impl.h"
+#include "base/make_unique.h"
+#include "platform/api/logging.h"
+#include "platform/base/event_loop.h"
+#include "third_party/chromium_quic/src/base/location.h"
+#include "third_party/chromium_quic/src/base/task_runner.h"
+#include "third_party/chromium_quic/src/net/third_party/quic/core/quic_constants.h"
+#include "third_party/chromium_quic/src/net/third_party/quic/platform/impl/quic_chromium_clock.h"
+
+namespace openscreen {
+
+struct Task {
+ ::base::Location whence;
+ ::base::OnceClosure task;
+ ::base::TimeDelta delay;
+};
+
+class QuicTaskRunner final : public ::base::TaskRunner {
+ public:
+ QuicTaskRunner();
+ ~QuicTaskRunner() override;
+
+ void RunTasks();
+
+ // base::TaskRunner overrides.
+ bool PostDelayedTask(const ::base::Location& whence,
+ ::base::OnceClosure task,
+ ::base::TimeDelta delay) override;
+
+ bool RunsTasksInCurrentSequence() const override;
+
+ private:
+ uint64_t last_run_unix_;
+ std::list<Task> tasks_;
+};
+
+QuicTaskRunner::QuicTaskRunner() = default;
+QuicTaskRunner::~QuicTaskRunner() = default;
+
+void QuicTaskRunner::RunTasks() {
+ auto* clock = ::quic::QuicChromiumClock::GetInstance();
+ ::quic::QuicWallTime now = clock->WallNow();
+ uint64_t now_unix = now.ToUNIXMicroseconds();
+ for (auto it = tasks_.begin(); it != tasks_.end();) {
+ Task& next_task = *it;
+ next_task.delay -=
+ ::base::TimeDelta::FromMicroseconds(now_unix - last_run_unix_);
+ if (next_task.delay.InMicroseconds() < 0) {
+ std::move(next_task.task).Run();
+ it = tasks_.erase(it);
+ } else {
+ ++it;
+ }
+ }
+ last_run_unix_ = now_unix;
+}
+
+bool QuicTaskRunner::PostDelayedTask(const ::base::Location& whence,
+ ::base::OnceClosure task,
+ ::base::TimeDelta delay) {
+ tasks_.push_back({whence, std::move(task), delay});
+ return true;
+}
+
+bool QuicTaskRunner::RunsTasksInCurrentSequence() const {
+ return true;
+}
+
+QuicConnectionFactoryImpl::QuicConnectionFactoryImpl() {
+ task_runner_ = ::base::MakeRefCounted<QuicTaskRunner>();
+ alarm_factory_ = MakeUnique<::net::QuicChromiumAlarmFactory>(
+ task_runner_.get(), ::quic::QuicChromiumClock::GetInstance());
+ ::quic::QuartcFactoryConfig factory_config;
+ factory_config.alarm_factory = alarm_factory_.get();
+ factory_config.clock = ::quic::QuicChromiumClock::GetInstance();
+ quartc_factory_ = MakeUnique<::quic::QuartcFactory>(factory_config);
+ waiter_ = platform::CreateEventWaiter();
+}
+
+QuicConnectionFactoryImpl::~QuicConnectionFactoryImpl() {
+ OSP_DCHECK(connections_.empty());
+ for (auto* socket : server_sockets_)
+ platform::DestroyUdpSocket(socket);
+ platform::DestroyEventWaiter(waiter_);
+}
+
+void QuicConnectionFactoryImpl::SetServerDelegate(
+ ServerDelegate* delegate,
+ const std::vector<IPEndpoint>& endpoints) {
+ server_delegate_ = delegate;
+ server_sockets_.reserve(endpoints.size());
+ for (const auto& endpoint : endpoints) {
+ auto server_socket = (endpoint.address.version() == IPAddress::Version::kV4)
+ ? platform::CreateUdpSocketIPv4()
+ : platform::CreateUdpSocketIPv6();
+ platform::BindUdpSocket(server_socket, endpoint, 0);
+ platform::WatchUdpSocketReadable(waiter_, server_socket);
+ server_sockets_.push_back(server_socket);
+ }
+}
+
+void QuicConnectionFactoryImpl::RunTasks() {
+ for (const auto& packet : platform::OnePlatformLoopIteration(waiter_)) {
+ // TODO(btolsch): We will need to rethink this both for ICE and connection
+ // migration support.
+ auto conn_it = connections_.find(packet.source);
+ if (conn_it == connections_.end()) {
+ if (server_delegate_) {
+ OSP_VLOG(1) << __func__ << ": spawning connection from "
+ << packet.source;
+ auto transport = MakeUnique<UdpTransport>(packet.socket, packet.source);
+ ::quic::QuartcSessionConfig session_config;
+ session_config.perspective = ::quic::Perspective::IS_SERVER;
+ session_config.packet_transport = transport.get();
+
+ auto result = MakeUnique<QuicConnectionImpl>(
+ this, server_delegate_->NextConnectionDelegate(packet.source),
+ std::move(transport),
+ quartc_factory_->CreateQuartcSession(session_config));
+ connections_.emplace(packet.source, result.get());
+ auto* result_ptr = result.get();
+ server_delegate_->OnIncomingConnection(std::move(result));
+ result_ptr->OnDataReceived(packet);
+ }
+ } else {
+ OSP_VLOG(1) << __func__ << ": data for existing connection from "
+ << packet.source;
+ conn_it->second->OnDataReceived(packet);
+ }
+ }
+}
+
+std::unique_ptr<QuicConnection> QuicConnectionFactoryImpl::Connect(
+ const IPEndpoint& endpoint,
+ QuicConnection::Delegate* connection_delegate) {
+ auto* socket = endpoint.address.IsV4() ? platform::CreateUdpSocketIPv4()
+ : platform::CreateUdpSocketIPv6();
+ platform::BindUdpSocket(socket, {}, 0);
+ auto transport = MakeUnique<UdpTransport>(socket, endpoint);
+
+ ::quic::QuartcSessionConfig session_config;
+ session_config.perspective = ::quic::Perspective::IS_CLIENT;
+ // TODO(btolsch): Proper server id. Does this go in the QUIC server name
+ // parameter?
+ session_config.unique_remote_server_id = "turtle";
+ session_config.packet_transport = transport.get();
+
+ auto result = MakeUnique<QuicConnectionImpl>(
+ this, connection_delegate, std::move(transport),
+ quartc_factory_->CreateQuartcSession(session_config));
+
+ platform::WatchUdpSocketReadable(waiter_, socket);
+ // TODO(btolsch): This presents a problem for multihomed receivers, which may
+ // register as a different endpoint in their response. I think QUIC is
+ // already tolerant of this via connection IDs but this hasn't been tested
+ // (and even so, those aren't necessarily stable either).
+ connections_.emplace(endpoint, result.get());
+
+ return result;
+}
+
+void QuicConnectionFactoryImpl::OnConnectionClosed(QuicConnection* connection) {
+ auto entry = std::find_if(
+ connections_.begin(), connections_.end(),
+ [connection](const std::pair<IPEndpoint, QuicConnection*>& entry) {
+ return entry.second == connection;
+ });
+ connections_.erase(entry);
+}
+
+} // namespace openscreen
diff --git a/api/impl/quic/quic_connection_factory_impl.h b/api/impl/quic/quic_connection_factory_impl.h
new file mode 100644
index 0000000..65377b5
--- /dev/null
+++ b/api/impl/quic/quic_connection_factory_impl.h
@@ -0,0 +1,53 @@
+// Copyright 2018 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef API_IMPL_QUIC_QUIC_CONNECTION_FACTORY_IMPL_H_
+#define API_IMPL_QUIC_QUIC_CONNECTION_FACTORY_IMPL_H_
+
+#include <map>
+#include <memory>
+
+#include "api/impl/quic/quic_connection_factory.h"
+#include "base/ip_address.h"
+#include "platform/api/event_waiter.h"
+#include "platform/api/socket.h"
+#include "third_party/chromium_quic/src/base/at_exit.h"
+#include "third_party/chromium_quic/src/net/quic/quic_chromium_alarm_factory.h"
+#include "third_party/chromium_quic/src/net/third_party/quic/quartc/quartc_factory.h"
+
+namespace openscreen {
+
+class QuicTaskRunner;
+
+class QuicConnectionFactoryImpl final : public QuicConnectionFactory {
+ public:
+ QuicConnectionFactoryImpl();
+ ~QuicConnectionFactoryImpl() override;
+
+ // QuicConnectionFactory overrides.
+ void SetServerDelegate(ServerDelegate* delegate,
+ const std::vector<IPEndpoint>& endpoints) override;
+ void RunTasks() override;
+ std::unique_ptr<QuicConnection> Connect(
+ const IPEndpoint& endpoint,
+ QuicConnection::Delegate* connection_delegate) override;
+
+ void OnConnectionClosed(QuicConnection* connection);
+
+ private:
+ ::base::AtExitManager exit_manager_;
+ scoped_refptr<QuicTaskRunner> task_runner_;
+ std::unique_ptr<::net::QuicChromiumAlarmFactory> alarm_factory_;
+ std::unique_ptr<::quic::QuartcFactory> quartc_factory_;
+
+ ServerDelegate* server_delegate_ = nullptr;
+ std::vector<platform::UdpSocketPtr> server_sockets_;
+
+ platform::EventWaiterPtr waiter_;
+ std::map<IPEndpoint, QuicConnection*, IPEndpointComparator> connections_;
+};
+
+} // namespace openscreen
+
+#endif // API_IMPL_QUIC_QUIC_CONNECTION_FACTORY_IMPL_H_
diff --git a/api/impl/quic/quic_connection_impl.cc b/api/impl/quic/quic_connection_impl.cc
new file mode 100644
index 0000000..884c9c2
--- /dev/null
+++ b/api/impl/quic/quic_connection_impl.cc
@@ -0,0 +1,109 @@
+// Copyright 2018 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "api/impl/quic/quic_connection_impl.h"
+
+#include "api/impl/quic/quic_connection_factory_impl.h"
+#include "base/make_unique.h"
+#include "third_party/chromium_quic/src/net/third_party/quic/platform/impl/quic_chromium_clock.h"
+
+namespace openscreen {
+
+UdpTransport::UdpTransport(platform::UdpSocketPtr socket,
+ const IPEndpoint& destination)
+ : socket_(socket), destination_(destination) {}
+UdpTransport::UdpTransport(UdpTransport&&) = default;
+UdpTransport::~UdpTransport() = default;
+
+UdpTransport& UdpTransport::operator=(UdpTransport&&) = default;
+
+int UdpTransport::Write(const char* buffer,
+ size_t buffer_length,
+ const PacketInfo& info) {
+ return platform::SendUdp(socket_, buffer, buffer_length, destination_);
+}
+
+QuicStreamImpl::QuicStreamImpl(QuicStream::Delegate* delegate,
+ ::quic::QuartcStream* stream)
+ : QuicStream(delegate, stream->id()), stream_(stream) {
+ stream_->SetDelegate(this);
+}
+
+QuicStreamImpl::~QuicStreamImpl() = default;
+
+void QuicStreamImpl::Write(const uint8_t* data, size_t data_size) {
+ stream_->WriteOrBufferData(
+ ::quic::QuicStringPiece(reinterpret_cast<const char*>(data), data_size),
+ false, nullptr);
+}
+
+void QuicStreamImpl::CloseWriteEnd() {
+ stream_->FinishWriting();
+}
+
+void QuicStreamImpl::OnReceived(::quic::QuartcStream* stream,
+ const char* data,
+ size_t data_size) {
+ delegate_->OnReceived(this, data, data_size);
+}
+
+void QuicStreamImpl::OnClose(::quic::QuartcStream* stream) {
+ delegate_->OnClose(stream->id());
+}
+
+void QuicStreamImpl::OnBufferChanged(::quic::QuartcStream* stream) {}
+
+QuicConnectionImpl::QuicConnectionImpl(
+ QuicConnectionFactoryImpl* parent_factory,
+ QuicConnection::Delegate* delegate,
+ std::unique_ptr<UdpTransport>&& udp_transport,
+ std::unique_ptr<::quic::QuartcSession>&& session)
+ : QuicConnection(delegate),
+ parent_factory_(parent_factory),
+ session_(std::move(session)),
+ udp_transport_(std::move(udp_transport)) {
+ session_->SetDelegate(this);
+ session_->OnTransportCanWrite();
+ session_->StartCryptoHandshake();
+}
+
+QuicConnectionImpl::~QuicConnectionImpl() = default;
+
+void QuicConnectionImpl::OnDataReceived(const platform::ReceivedData& data) {
+ session_->OnTransportReceived(
+ reinterpret_cast<const char*>(data.bytes.data()), data.length);
+}
+
+std::unique_ptr<QuicStream> QuicConnectionImpl::MakeOutgoingStream(
+ QuicStream::Delegate* delegate) {
+ ::quic::QuartcStream* stream = session_->CreateOutgoingDynamicStream();
+ return MakeUnique<QuicStreamImpl>(delegate, stream);
+}
+
+void QuicConnectionImpl::Close() {
+ session_->CloseConnection("closed");
+}
+
+void QuicConnectionImpl::OnCryptoHandshakeComplete() {
+ delegate_->OnCryptoHandshakeComplete(session_->connection_id());
+}
+
+void QuicConnectionImpl::OnIncomingStream(::quic::QuartcStream* stream) {
+ auto public_stream = MakeUnique<QuicStreamImpl>(
+ delegate_->NextStreamDelegate(session_->connection_id(), stream->id()),
+ stream);
+ streams_.push_back(public_stream.get());
+ delegate_->OnIncomingStream(session_->connection_id(),
+ std::move(public_stream));
+}
+
+void QuicConnectionImpl::OnConnectionClosed(
+ ::quic::QuicErrorCode error_code,
+ const ::quic::QuicString& error_details,
+ ::quic::ConnectionCloseSource source) {
+ parent_factory_->OnConnectionClosed(this);
+ delegate_->OnConnectionClosed(session_->connection_id());
+}
+
+} // namespace openscreen
diff --git a/api/impl/quic/quic_connection_impl.h b/api/impl/quic/quic_connection_impl.h
new file mode 100644
index 0000000..826148c
--- /dev/null
+++ b/api/impl/quic/quic_connection_impl.h
@@ -0,0 +1,99 @@
+// Copyright 2018 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef API_IMPL_QUIC_QUIC_CONNECTION_IMPL_H_
+#define API_IMPL_QUIC_QUIC_CONNECTION_IMPL_H_
+
+#include <list>
+#include <memory>
+
+#include "api/impl/quic/quic_connection.h"
+#include "base/ip_address.h"
+#include "platform/api/socket.h"
+#include "third_party/chromium_quic/src/base/callback.h"
+#include "third_party/chromium_quic/src/base/location.h"
+#include "third_party/chromium_quic/src/base/task_runner.h"
+#include "third_party/chromium_quic/src/base/time/time.h"
+#include "third_party/chromium_quic/src/net/third_party/quic/quartc/quartc_packet_writer.h"
+#include "third_party/chromium_quic/src/net/third_party/quic/quartc/quartc_session.h"
+#include "third_party/chromium_quic/src/net/third_party/quic/quartc/quartc_stream.h"
+
+namespace openscreen {
+
+class QuicConnectionFactoryImpl;
+
+class UdpTransport final : public ::quic::QuartcPacketTransport {
+ public:
+ UdpTransport(platform::UdpSocketPtr socket, const IPEndpoint& destination);
+ UdpTransport(UdpTransport&&);
+ ~UdpTransport() override;
+
+ UdpTransport& operator=(UdpTransport&&);
+
+ // ::quic::QuartcPacketTransport overrides.
+ int Write(const char* buffer,
+ size_t buffer_length,
+ const PacketInfo& info) override;
+
+ platform::UdpSocketPtr socket() { return socket_; }
+
+ private:
+ platform::UdpSocketPtr socket_;
+ IPEndpoint destination_;
+};
+
+class QuicStreamImpl final : public QuicStream,
+ public ::quic::QuartcStream::Delegate {
+ public:
+ QuicStreamImpl(QuicStream::Delegate* delegate, ::quic::QuartcStream* stream);
+ ~QuicStreamImpl() override;
+
+ // QuicStream overrides.
+ void Write(const uint8_t* data, size_t size) override;
+ void CloseWriteEnd() override;
+
+ // ::quic::QuartcStream::Delegate overrides.
+ void OnReceived(::quic::QuartcStream* stream,
+ const char* data,
+ size_t data_size) override;
+ void OnClose(::quic::QuartcStream* stream) override;
+ void OnBufferChanged(::quic::QuartcStream* stream) override;
+
+ private:
+ ::quic::QuartcStream* const stream_;
+};
+
+class QuicConnectionImpl final : public QuicConnection,
+ public ::quic::QuartcSession::Delegate {
+ public:
+ QuicConnectionImpl(QuicConnectionFactoryImpl* parent_factory,
+ QuicConnection::Delegate* delegate,
+ std::unique_ptr<UdpTransport>&& udp_transport,
+ std::unique_ptr<::quic::QuartcSession>&& session);
+
+ ~QuicConnectionImpl() override;
+
+ // QuicConnection overrides.
+ void OnDataReceived(const platform::ReceivedData& data) override;
+ std::unique_ptr<QuicStream> MakeOutgoingStream(
+ QuicStream::Delegate* delegate) override;
+ void Close() override;
+
+ // ::quic::QuartcSession::Delegate overrides.
+ void OnCryptoHandshakeComplete() override;
+ void OnIncomingStream(::quic::QuartcStream* stream) override;
+ void OnConnectionClosed(::quic::QuicErrorCode error_code,
+ const ::quic::QuicString& error_details,
+ ::quic::ConnectionCloseSource source) override;
+
+ private:
+ QuicConnectionFactoryImpl* const parent_factory_;
+ const std::unique_ptr<::quic::QuartcSession> session_;
+ const std::unique_ptr<UdpTransport> udp_transport_;
+ std::vector<QuicStream*> streams_;
+};
+
+} // namespace openscreen
+
+#endif // API_IMPL_QUIC_QUIC_CONNECTION_IMPL_H_
diff --git a/api/impl/quic/quic_server.cc b/api/impl/quic/quic_server.cc
new file mode 100644
index 0000000..7d851fe
--- /dev/null
+++ b/api/impl/quic/quic_server.cc
@@ -0,0 +1,156 @@
+// Copyright 2018 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "api/impl/quic/quic_server.h"
+
+#include "base/make_unique.h"
+#include "platform/api/logging.h"
+
+namespace openscreen {
+
+QuicServer::QuicServer(
+ const ServerConfig& config,
+ MessageDemuxer* demuxer,
+ std::unique_ptr<QuicConnectionFactory> connection_factory,
+ ProtocolConnectionServer::Observer* observer)
+ : ProtocolConnectionServer(demuxer, observer),
+ connection_endpoints_(config.connection_endpoints),
+ connection_factory_(std::move(connection_factory)) {}
+
+QuicServer::~QuicServer() {
+ CloseAllConnections();
+}
+
+bool QuicServer::Start() {
+ if (state_ != State::kStopped)
+ return false;
+ state_ = State::kRunning;
+ connection_factory_->SetServerDelegate(this, connection_endpoints_);
+ observer_->OnRunning();
+ return true;
+}
+
+bool QuicServer::Stop() {
+ if (state_ != State::kRunning && state_ != State::kSuspended)
+ return false;
+ connection_factory_->SetServerDelegate(nullptr, {});
+ CloseAllConnections();
+ state_ = State::kStopped;
+ observer_->OnStopped();
+ return true;
+}
+
+bool QuicServer::Suspend() {
+ // TODO(btolsch): QuicStreams should either buffer or reject writes.
+ if (state_ != State::kRunning)
+ return false;
+ state_ = State::kSuspended;
+ observer_->OnSuspended();
+ return true;
+}
+
+bool QuicServer::Resume() {
+ if (state_ != State::kSuspended)
+ return false;
+ state_ = State::kRunning;
+ observer_->OnRunning();
+ return true;
+}
+
+void QuicServer::RunTasks() {
+ if (state_ == State::kRunning)
+ connection_factory_->RunTasks();
+ for (auto& entry : delete_connections_)
+ connections_.erase(entry);
+ delete_connections_.clear();
+}
+
+std::unique_ptr<ProtocolConnection> QuicServer::CreateProtocolConnection(
+ uint64_t endpoint_id) {
+ if (state_ != State::kRunning)
+ return nullptr;
+ auto connection_entry = connections_.find(endpoint_id);
+ if (connection_entry == connections_.end())
+ return nullptr;
+ return QuicProtocolConnection::FromExisting(
+ this, connection_entry->second.connection.get(),
+ connection_entry->second.delegate.get(), endpoint_id);
+}
+
+void QuicServer::OnConnectionDestroyed(QuicProtocolConnection* connection) {
+ auto connection_entry = connections_.find(connection->endpoint_id());
+ if (connection_entry == connections_.end())
+ return;
+ connection_entry->second.delegate->DropProtocolConnection(connection);
+}
+
+uint64_t QuicServer::OnCryptoHandshakeComplete(
+ ServiceConnectionDelegate* delegate,
+ uint64_t connection_id) {
+ OSP_DCHECK_EQ(state_, State::kRunning);
+ const IPEndpoint& endpoint = delegate->endpoint();
+ auto pending_entry = pending_connections_.find(endpoint);
+ if (pending_entry == pending_connections_.end())
+ return 0;
+ ServiceConnectionData connection_data = std::move(pending_entry->second);
+ pending_connections_.erase(pending_entry);
+ uint64_t endpoint_id = next_endpoint_id_++;
+ endpoint_map_[endpoint] = endpoint_id;
+ connections_.emplace(endpoint_id, std::move(connection_data));
+ return endpoint_id;
+}
+
+void QuicServer::OnIncomingStream(
+ std::unique_ptr<QuicProtocolConnection>&& connection) {
+ OSP_DCHECK_EQ(state_, State::kRunning);
+ observer_->OnIncomingConnection(std::move(connection));
+}
+
+void QuicServer::OnConnectionClosed(uint64_t endpoint_id,
+ uint64_t connection_id) {
+ OSP_DCHECK_EQ(state_, State::kRunning);
+ auto connection_entry = connections_.find(endpoint_id);
+ if (connection_entry == connections_.end())
+ return;
+ delete_connections_.emplace_back(connection_entry);
+}
+
+void QuicServer::OnDataReceived(uint64_t endpoint_id,
+ uint64_t connection_id,
+ const uint8_t* data,
+ size_t data_size) {
+ OSP_DCHECK_EQ(state_, State::kRunning);
+ demuxer_->OnStreamData(endpoint_id, connection_id, data, data_size);
+}
+
+void QuicServer::CloseAllConnections() {
+ for (auto& conn : pending_connections_)
+ conn.second.connection->Close();
+ pending_connections_.clear();
+ for (auto& conn : connections_)
+ conn.second.connection->Close();
+ connections_.clear();
+ endpoint_map_.clear();
+ next_endpoint_id_ = 0;
+}
+
+QuicConnection::Delegate* QuicServer::NextConnectionDelegate(
+ const IPEndpoint& source) {
+ OSP_DCHECK_EQ(state_, State::kRunning);
+ OSP_DCHECK(!pending_connection_delegate_);
+ pending_connection_delegate_ =
+ MakeUnique<ServiceConnectionDelegate>(this, source);
+ return pending_connection_delegate_.get();
+}
+
+void QuicServer::OnIncomingConnection(
+ std::unique_ptr<QuicConnection>&& connection) {
+ OSP_DCHECK_EQ(state_, State::kRunning);
+ const IPEndpoint& endpoint = pending_connection_delegate_->endpoint();
+ pending_connections_.emplace(
+ endpoint, ServiceConnectionData(std::move(connection),
+ std::move(pending_connection_delegate_)));
+}
+
+} // namespace openscreen
diff --git a/api/impl/quic/quic_server.h b/api/impl/quic/quic_server.h
new file mode 100644
index 0000000..76cf752
--- /dev/null
+++ b/api/impl/quic/quic_server.h
@@ -0,0 +1,99 @@
+// Copyright 2018 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef API_IMPL_QUIC_QUIC_SERVER_H_
+#define API_IMPL_QUIC_QUIC_SERVER_H_
+
+#include <cstdint>
+#include <map>
+#include <memory>
+
+#include "api/impl/quic/quic_connection_factory.h"
+#include "api/impl/quic/quic_service_common.h"
+#include "api/public/protocol_connection_server.h"
+#include "base/ip_address.h"
+
+namespace openscreen {
+
+// This class is the default implementation of ProtocolConnectionServer for the
+// library. It manages connections to other endpoints as well as the lifetime
+// of each incoming and outgoing stream. It works in conjunction with a
+// QuicConnectionFactory implementation and MessageDemuxer.
+// QuicConnectionFactory provides the ability to make a new QUIC
+// connection from packets received on its server sockets. Incoming data is
+// given to the QuicServer by the underlying QUIC implementation (through
+// QuicConnectionFactory) and this is in turn handed to MessageDemuxer for
+// routing CBOR messages.
+class QuicServer final : public ProtocolConnectionServer,
+ public QuicConnectionFactory::ServerDelegate,
+ public ServiceConnectionDelegate::ServiceDelegate {
+ public:
+ QuicServer(const ServerConfig& config,
+ MessageDemuxer* demuxer,
+ std::unique_ptr<QuicConnectionFactory> connection_factory,
+ ProtocolConnectionServer::Observer* observer);
+ ~QuicServer() override;
+
+ // ProtocolConnectionServer overrides.
+ bool Start() override;
+ bool Stop() override;
+ bool Suspend() override;
+ bool Resume() override;
+ void RunTasks() override;
+ std::unique_ptr<ProtocolConnection> CreateProtocolConnection(
+ uint64_t endpoint_id) override;
+
+ // QuicProtocolConnection::Owner overrides.
+ void OnConnectionDestroyed(QuicProtocolConnection* connection) override;
+
+ // ServiceConnectionDelegate::ServiceDelegate overrides.
+ uint64_t OnCryptoHandshakeComplete(ServiceConnectionDelegate* delegate,
+ uint64_t connection_id) override;
+ void OnIncomingStream(
+ std::unique_ptr<QuicProtocolConnection>&& connection) override;
+ void OnConnectionClosed(uint64_t endpoint_id,
+ uint64_t connection_id) override;
+ void OnDataReceived(uint64_t endpoint_id,
+ uint64_t connection_id,
+ const uint8_t* data,
+ size_t data_size) override;
+
+ private:
+ void CloseAllConnections();
+
+ // QuicConnectionFactory::ServerDelegate overrides.
+ QuicConnection::Delegate* NextConnectionDelegate(
+ const IPEndpoint& source) override;
+ void OnIncomingConnection(
+ std::unique_ptr<QuicConnection>&& connection) override;
+
+ const std::vector<IPEndpoint> connection_endpoints_;
+ std::unique_ptr<QuicConnectionFactory> connection_factory_;
+
+ std::unique_ptr<ServiceConnectionDelegate> pending_connection_delegate_;
+
+ // Maps an IPEndpoint to a generated endpoint ID. This is used to insulate
+ // callers from post-handshake changes to a connections actual peer endpoint.
+ std::map<IPEndpoint, uint64_t, IPEndpointComparator> endpoint_map_;
+
+ // Value that will be used for the next new endpoint in a Connect call.
+ uint64_t next_endpoint_id_ = 0;
+
+ // Maps endpoint addresses to data about connections that haven't successfully
+ // completed the QUIC handshake.
+ std::map<IPEndpoint, ServiceConnectionData, IPEndpointComparator>
+ pending_connections_;
+
+ // Maps endpoint IDs to data about connections that have successfully
+ // completed the QUIC handshake.
+ std::map<uint64_t, ServiceConnectionData> connections_;
+
+ // Connections that need to be destroyed, but have to wait for the next event
+ // loop due to the underlying QUIC implementation's way of referencing them.
+ std::vector<decltype(connections_)::iterator> delete_connections_;
+};
+
+} // namespace openscreen
+
+#endif // API_IMPL_QUIC_QUIC_SERVER_H_
diff --git a/api/impl/quic/quic_server_unittest.cc b/api/impl/quic/quic_server_unittest.cc
new file mode 100644
index 0000000..cd5e80d
--- /dev/null
+++ b/api/impl/quic/quic_server_unittest.cc
@@ -0,0 +1,214 @@
+// Copyright 2018 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "api/impl/quic/quic_server.h"
+
+#include "api/impl/quic/testing/fake_quic_connection_factory.h"
+#include "api/public/network_metrics.h"
+#include "base/error.h"
+#include "base/make_unique.h"
+#include "third_party/googletest/src/googlemock/include/gmock/gmock.h"
+#include "third_party/googletest/src/googletest/include/gtest/gtest.h"
+
+namespace openscreen {
+namespace {
+
+using ::testing::_;
+using ::testing::Invoke;
+
+class MockServerObserver final : public ProtocolConnectionServer::Observer {
+ public:
+ ~MockServerObserver() override = default;
+
+ MOCK_METHOD0(OnRunning, void());
+ MOCK_METHOD0(OnStopped, void());
+ MOCK_METHOD1(OnMetrics, void(const NetworkMetrics& metrics));
+ MOCK_METHOD1(OnError, void(const Error& error));
+
+ MOCK_METHOD0(OnSuspended, void());
+
+ void OnIncomingConnection(
+ std::unique_ptr<ProtocolConnection>&& connection) override {
+ OnIncomingConnectionMock(connection.release());
+ }
+ MOCK_METHOD1(OnIncomingConnectionMock, void(ProtocolConnection* connection));
+};
+
+class MockMessageCallback final : public MessageDemuxer::MessageCallback {
+ public:
+ ~MockMessageCallback() override = default;
+
+ MOCK_METHOD5(OnStreamMessage,
+ ErrorOr<size_t>(uint64_t endpoint_id,
+ uint64_t connection_id,
+ msgs::Type message_type,
+ const uint8_t* buffer,
+ size_t buffer_size));
+};
+
+class MockConnectionObserver final : public ProtocolConnection::Observer {
+ public:
+ ~MockConnectionObserver() override = default;
+
+ MOCK_METHOD1(OnConnectionChanged, void(const ProtocolConnection& connection));
+ MOCK_METHOD1(OnConnectionClosed, void(const ProtocolConnection& connection));
+};
+
+class QuicServerTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ auto connection_factory = MakeUnique<FakeQuicConnectionFactory>(
+ local_endpoint_, &client_demuxer_);
+ connection_factory_ = connection_factory.get();
+ ServerConfig config;
+ config.connection_endpoints.push_back(local_endpoint_);
+ server_ = MakeUnique<QuicServer>(
+ config, &demuxer_, std::move(connection_factory), &mock_observer_);
+ }
+
+ void RunTasksUntilIdle() {
+ do {
+ server_->RunTasks();
+ } while (!connection_factory_->idle());
+ }
+
+ void SendTestMessage(ProtocolConnection* connection) {
+ MockMessageCallback mock_message_callback;
+ MessageDemuxer::MessageWatch message_watch =
+ client_demuxer_.WatchMessageType(
+ 0, msgs::Type::kPresentationConnectionMessage,
+ &mock_message_callback);
+
+ msgs::CborEncodeBuffer buffer;
+ msgs::PresentationConnectionMessage message;
+ message.presentation_id = "some-id";
+ message.connection_id = 7;
+ message.message.which = decltype(message.message.which)::kString;
+ new (&message.message.str) std::string("message from server");
+ ASSERT_TRUE(msgs::EncodePresentationConnectionMessage(message, &buffer));
+ connection->Write(buffer.data(), buffer.size());
+ connection->CloseWriteEnd();
+
+ ssize_t decode_result = 0;
+ msgs::PresentationConnectionMessage received_message;
+ EXPECT_CALL(
+ mock_message_callback,
+ OnStreamMessage(0, connection->connection_id(),
+ msgs::Type::kPresentationConnectionMessage, _, _))
+ .WillOnce(Invoke([&decode_result, &received_message](
+ uint64_t endpoint_id, uint64_t connection_id,
+ msgs::Type message_type, const uint8_t* buffer,
+ size_t buffer_size) {
+ decode_result = msgs::DecodePresentationConnectionMessage(
+ buffer, buffer_size, &received_message);
+ if (decode_result < 0)
+ return ErrorOr<size_t>(Error::Code::kCborParsing);
+ return ErrorOr<size_t>(decode_result);
+ }));
+ RunTasksUntilIdle();
+
+ ASSERT_GT(decode_result, 0);
+ EXPECT_EQ(decode_result, static_cast<ssize_t>(buffer.size() - 1));
+ EXPECT_EQ(received_message.presentation_id, message.presentation_id);
+ EXPECT_EQ(received_message.connection_id, message.connection_id);
+ ASSERT_EQ(received_message.message.which,
+ decltype(received_message.message.which)::kString);
+ EXPECT_EQ(received_message.message.str, message.message.str);
+ }
+
+ const IPEndpoint local_endpoint_{{192, 168, 1, 10}, 44327};
+ const IPEndpoint client_endpoint_{{192, 168, 1, 15}, 54368};
+ MessageDemuxer demuxer_;
+ MessageDemuxer client_demuxer_;
+ FakeQuicConnectionFactory* connection_factory_;
+ MockServerObserver mock_observer_;
+ std::unique_ptr<QuicServer> server_;
+};
+
+} // namespace
+
+TEST_F(QuicServerTest, Connect) {
+ server_->Start();
+
+ std::unique_ptr<ProtocolConnection> connection;
+ EXPECT_CALL(mock_observer_, OnIncomingConnectionMock(_))
+ .WillOnce(Invoke(
+ [&connection](ProtocolConnection* c) { connection.reset(c); }));
+ connection_factory_->StartServerConnection(client_endpoint_);
+ RunTasksUntilIdle();
+ connection_factory_->StartIncomingStream(client_endpoint_);
+ RunTasksUntilIdle();
+ ASSERT_TRUE(connection);
+
+ SendTestMessage(connection.get());
+
+ server_->Stop();
+}
+
+TEST_F(QuicServerTest, OpenImmediate) {
+ server_->Start();
+
+ EXPECT_FALSE(server_->CreateProtocolConnection(1));
+
+ std::unique_ptr<ProtocolConnection> connection1;
+ EXPECT_CALL(mock_observer_, OnIncomingConnectionMock(_))
+ .WillOnce(Invoke(
+ [&connection1](ProtocolConnection* c) { connection1.reset(c); }));
+ connection_factory_->StartServerConnection(client_endpoint_);
+ RunTasksUntilIdle();
+ connection_factory_->StartIncomingStream(client_endpoint_);
+ RunTasksUntilIdle();
+ ASSERT_TRUE(connection1);
+
+ std::unique_ptr<ProtocolConnection> connection2;
+ connection2 = server_->CreateProtocolConnection(connection1->endpoint_id());
+
+ SendTestMessage(connection2.get());
+
+ server_->Stop();
+}
+
+TEST_F(QuicServerTest, States) {
+ EXPECT_CALL(mock_observer_, OnRunning());
+ EXPECT_TRUE(server_->Start());
+ EXPECT_FALSE(server_->Start());
+
+ std::unique_ptr<ProtocolConnection> connection;
+ EXPECT_CALL(mock_observer_, OnIncomingConnectionMock(_))
+ .WillOnce(Invoke(
+ [&connection](ProtocolConnection* c) { connection.reset(c); }));
+ connection_factory_->StartServerConnection(client_endpoint_);
+ RunTasksUntilIdle();
+ connection_factory_->StartIncomingStream(client_endpoint_);
+ RunTasksUntilIdle();
+ ASSERT_TRUE(connection);
+ MockConnectionObserver mock_connection_observer;
+ connection->SetObserver(&mock_connection_observer);
+
+ EXPECT_CALL(mock_connection_observer, OnConnectionClosed(_));
+ EXPECT_CALL(mock_observer_, OnStopped());
+ EXPECT_TRUE(server_->Stop());
+ EXPECT_FALSE(server_->Stop());
+
+ EXPECT_CALL(mock_observer_, OnRunning());
+ EXPECT_TRUE(server_->Start());
+
+ EXPECT_CALL(mock_observer_, OnSuspended());
+ EXPECT_TRUE(server_->Suspend());
+ EXPECT_FALSE(server_->Suspend());
+ EXPECT_FALSE(server_->Start());
+
+ EXPECT_CALL(mock_observer_, OnRunning());
+ EXPECT_TRUE(server_->Resume());
+ EXPECT_FALSE(server_->Resume());
+ EXPECT_FALSE(server_->Start());
+
+ EXPECT_CALL(mock_observer_, OnSuspended());
+ EXPECT_TRUE(server_->Suspend());
+
+ EXPECT_CALL(mock_observer_, OnStopped());
+ EXPECT_TRUE(server_->Stop());
+}
+
+} // namespace openscreen
diff --git a/api/impl/quic/quic_service_common.cc b/api/impl/quic/quic_service_common.cc
new file mode 100644
index 0000000..ef3d5c4
--- /dev/null
+++ b/api/impl/quic/quic_service_common.cc
@@ -0,0 +1,150 @@
+// Copyright 2018 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "api/impl/quic/quic_service_common.h"
+
+#include "base/make_unique.h"
+#include "platform/api/logging.h"
+
+namespace openscreen {
+
+// static
+std::unique_ptr<QuicProtocolConnection> QuicProtocolConnection::FromExisting(
+ Owner* owner,
+ QuicConnection* connection,
+ ServiceConnectionDelegate* delegate,
+ uint64_t endpoint_id) {
+ std::unique_ptr<QuicStream> stream = connection->MakeOutgoingStream(delegate);
+ auto pc =
+ MakeUnique<QuicProtocolConnection>(owner, endpoint_id, stream->id());
+ pc->set_stream(stream.get());
+ delegate->AddStreamPair(ServiceStreamPair(std::move(stream), pc.get()));
+ return pc;
+}
+
+QuicProtocolConnection::QuicProtocolConnection(Owner* owner,
+ uint64_t endpoint_id,
+ uint64_t connection_id)
+ : ProtocolConnection(endpoint_id, connection_id), owner_(owner) {}
+
+QuicProtocolConnection::~QuicProtocolConnection() {
+ if (stream_)
+ owner_->OnConnectionDestroyed(this);
+}
+
+void QuicProtocolConnection::Write(const uint8_t* data, size_t data_size) {
+ if (stream_)
+ stream_->Write(data, data_size);
+}
+
+void QuicProtocolConnection::CloseWriteEnd() {
+ if (stream_)
+ stream_->CloseWriteEnd();
+}
+
+void QuicProtocolConnection::OnClose() {
+ if (observer_)
+ observer_->OnConnectionClosed(*this);
+}
+
+ServiceStreamPair::ServiceStreamPair(
+ std::unique_ptr<QuicStream>&& stream,
+ QuicProtocolConnection* protocol_connection)
+ : stream(std::move(stream)),
+ protocol_connection(std::move(protocol_connection)) {}
+ServiceStreamPair::~ServiceStreamPair() = default;
+
+ServiceStreamPair::ServiceStreamPair(ServiceStreamPair&& other) = default;
+
+ServiceStreamPair& ServiceStreamPair::operator=(ServiceStreamPair&& other) =
+ default;
+
+ServiceConnectionDelegate::ServiceConnectionDelegate(ServiceDelegate* parent,
+ const IPEndpoint& endpoint)
+ : parent_(parent), endpoint_(endpoint) {}
+
+ServiceConnectionDelegate::~ServiceConnectionDelegate() {
+ OSP_DCHECK(streams_.empty());
+}
+
+void ServiceConnectionDelegate::AddStreamPair(ServiceStreamPair&& stream_pair) {
+ uint64_t stream_id = stream_pair.stream->id();
+ streams_.emplace(stream_id, std::move(stream_pair));
+}
+
+void ServiceConnectionDelegate::DropProtocolConnection(
+ QuicProtocolConnection* connection) {
+ auto stream_entry = streams_.find(connection->stream()->id());
+ if (stream_entry == streams_.end())
+ return;
+ stream_entry->second.protocol_connection = nullptr;
+}
+
+void ServiceConnectionDelegate::OnCryptoHandshakeComplete(
+ uint64_t connection_id) {
+ endpoint_id_ = parent_->OnCryptoHandshakeComplete(this, connection_id);
+}
+
+void ServiceConnectionDelegate::OnIncomingStream(
+ uint64_t connection_id,
+ std::unique_ptr<QuicStream> stream) {
+ pending_connection_->set_stream(stream.get());
+ AddStreamPair(
+ ServiceStreamPair(std::move(stream), pending_connection_.get()));
+ parent_->OnIncomingStream(std::move(pending_connection_));
+}
+
+void ServiceConnectionDelegate::OnConnectionClosed(uint64_t connection_id) {
+ parent_->OnConnectionClosed(endpoint_id_, connection_id);
+}
+
+QuicStream::Delegate* ServiceConnectionDelegate::NextStreamDelegate(
+ uint64_t connection_id,
+ uint64_t stream_id) {
+ OSP_DCHECK(!pending_connection_);
+ pending_connection_ =
+ MakeUnique<QuicProtocolConnection>(parent_, endpoint_id_, stream_id);
+ return this;
+}
+
+void ServiceConnectionDelegate::OnReceived(QuicStream* stream,
+ const char* data,
+ size_t data_size) {
+ auto stream_entry = streams_.find(stream->id());
+ if (stream_entry == streams_.end())
+ return;
+ // TODO(btolsch): It happens that for normal stream data, OnClose is called
+ // before the fin bit is passed here. Would OnClose instead be the last
+ // callback if a RST_STREAM or CONNECTION_CLOSE is received?
+ if (stream_entry->second.protocol_connection) {
+ parent_->OnDataReceived(
+ stream_entry->second.protocol_connection->endpoint_id(),
+ stream_entry->second.protocol_connection->connection_id(),
+ reinterpret_cast<const uint8_t*>(data), data_size);
+ }
+ if (!data_size) {
+ if (stream_entry->second.protocol_connection)
+ stream_entry->second.protocol_connection->set_stream(nullptr);
+ streams_.erase(stream_entry);
+ }
+}
+
+void ServiceConnectionDelegate::OnClose(uint64_t stream_id) {
+ auto stream_entry = streams_.find(stream_id);
+ if (stream_entry == streams_.end())
+ return;
+ if (stream_entry->second.protocol_connection)
+ stream_entry->second.protocol_connection->OnClose();
+}
+
+ServiceConnectionData::ServiceConnectionData(
+ std::unique_ptr<QuicConnection>&& connection,
+ std::unique_ptr<ServiceConnectionDelegate>&& delegate)
+ : connection(std::move(connection)), delegate(std::move(delegate)) {}
+ServiceConnectionData::ServiceConnectionData(ServiceConnectionData&&) = default;
+ServiceConnectionData::~ServiceConnectionData() = default;
+ServiceConnectionData& ServiceConnectionData::operator=(
+ ServiceConnectionData&&) = default;
+
+} // namespace openscreen
diff --git a/api/impl/quic/quic_service_common.h b/api/impl/quic/quic_service_common.h
new file mode 100644
index 0000000..b5b4ea0
--- /dev/null
+++ b/api/impl/quic/quic_service_common.h
@@ -0,0 +1,130 @@
+// Copyright 2018 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef API_IMPL_QUIC_QUIC_SERVICE_COMMON_H_
+#define API_IMPL_QUIC_QUIC_SERVICE_COMMON_H_
+
+#include <cstdint>
+#include <map>
+#include <memory>
+#include <vector>
+
+#include "api/impl/quic/quic_connection.h"
+#include "api/public/protocol_connection.h"
+
+namespace openscreen {
+
+class ServiceConnectionDelegate;
+
+class QuicProtocolConnection final : public ProtocolConnection {
+ public:
+ class Owner {
+ public:
+ virtual ~Owner() = default;
+
+ // Called right before |connection| is destroyed (destructor runs).
+ virtual void OnConnectionDestroyed(QuicProtocolConnection* connection) = 0;
+ };
+
+ static std::unique_ptr<QuicProtocolConnection> FromExisting(
+ Owner* owner,
+ QuicConnection* connection,
+ ServiceConnectionDelegate* delegate,
+ uint64_t endpoint_id);
+
+ QuicProtocolConnection(Owner* owner,
+ uint64_t endpoint_id,
+ uint64_t connection_id);
+ ~QuicProtocolConnection() override;
+
+ // ProtocolConnection overrides.
+ void Write(const uint8_t* data, size_t data_size) override;
+ void CloseWriteEnd() override;
+
+ QuicStream* stream() { return stream_; }
+ void set_stream(QuicStream* stream) { stream_ = stream; }
+
+ void OnClose();
+
+ private:
+ Owner* const owner_;
+ QuicStream* stream_ = nullptr;
+};
+
+struct ServiceStreamPair {
+ ServiceStreamPair(std::unique_ptr<QuicStream>&& stream,
+ QuicProtocolConnection* protocol_connection);
+ ~ServiceStreamPair();
+ ServiceStreamPair(ServiceStreamPair&&);
+ ServiceStreamPair& operator=(ServiceStreamPair&&);
+
+ std::unique_ptr<QuicStream> stream;
+ QuicProtocolConnection* protocol_connection;
+};
+
+class ServiceConnectionDelegate final : public QuicConnection::Delegate,
+ public QuicStream::Delegate {
+ public:
+ class ServiceDelegate : public QuicProtocolConnection::Owner {
+ public:
+ ~ServiceDelegate() override = default;
+
+ virtual uint64_t OnCryptoHandshakeComplete(
+ ServiceConnectionDelegate* delegate,
+ uint64_t connection_id) = 0;
+ virtual void OnIncomingStream(
+ std::unique_ptr<QuicProtocolConnection>&& connection) = 0;
+ virtual void OnConnectionClosed(uint64_t endpoint_id,
+ uint64_t connection_id) = 0;
+ virtual void OnDataReceived(uint64_t endpoint_id,
+ uint64_t connection_id,
+ const uint8_t* data,
+ size_t data_size) = 0;
+ };
+
+ ServiceConnectionDelegate(ServiceDelegate* parent,
+ const IPEndpoint& endpoint);
+ ~ServiceConnectionDelegate() override;
+
+ void AddStreamPair(ServiceStreamPair&& stream_pair);
+ void DropProtocolConnection(QuicProtocolConnection* connection);
+ const IPEndpoint& endpoint() const { return endpoint_; }
+
+ // QuicConnection::Delegate overrides.
+ void OnCryptoHandshakeComplete(uint64_t connection_id) override;
+ void OnIncomingStream(uint64_t connection_id,
+ std::unique_ptr<QuicStream> stream) override;
+ void OnConnectionClosed(uint64_t connection_id) override;
+ QuicStream::Delegate* NextStreamDelegate(uint64_t connection_id,
+ uint64_t stream_id) override;
+
+ // QuicStream::Delegate overrides.
+ void OnReceived(QuicStream* stream,
+ const char* data,
+ size_t data_size) override;
+ void OnClose(uint64_t stream_id) override;
+
+ private:
+ ServiceDelegate* const parent_;
+ IPEndpoint endpoint_;
+ uint64_t endpoint_id_;
+ std::unique_ptr<QuicProtocolConnection> pending_connection_;
+ std::map<uint64_t, ServiceStreamPair> streams_;
+};
+
+struct ServiceConnectionData {
+ explicit ServiceConnectionData(
+ std::unique_ptr<QuicConnection>&& connection,
+ std::unique_ptr<ServiceConnectionDelegate>&& delegate);
+ ServiceConnectionData(ServiceConnectionData&&);
+ ~ServiceConnectionData();
+ ServiceConnectionData& operator=(ServiceConnectionData&&);
+
+ std::unique_ptr<QuicConnection> connection;
+ std::unique_ptr<ServiceConnectionDelegate> delegate;
+};
+
+} // namespace openscreen
+
+#endif // API_IMPL_QUIC_QUIC_SERVICE_COMMON_H_
diff --git a/api/impl/quic/testing/fake_quic_connection.cc b/api/impl/quic/testing/fake_quic_connection.cc
new file mode 100644
index 0000000..ca5b504
--- /dev/null
+++ b/api/impl/quic/testing/fake_quic_connection.cc
@@ -0,0 +1,82 @@
+// Copyright 2018 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "api/impl/quic/testing/fake_quic_connection.h"
+
+#include "api/impl/quic/testing/fake_quic_connection_factory.h"
+#include "base/make_unique.h"
+#include "platform/api/logging.h"
+
+namespace openscreen {
+
+FakeQuicStream::FakeQuicStream(Delegate* delegate, uint64_t id)
+ : QuicStream(delegate, id) {}
+
+FakeQuicStream::~FakeQuicStream() = default;
+
+void FakeQuicStream::ReceiveData(const uint8_t* data, size_t size) {
+ OSP_DCHECK(!read_end_closed_);
+ read_buffer_.insert(read_buffer_.end(), data, data + size);
+}
+
+void FakeQuicStream::CloseReadEnd() {
+ read_end_closed_ = true;
+}
+
+std::vector<uint8_t> FakeQuicStream::TakeReceivedData() {
+ return std::move(read_buffer_);
+}
+
+std::vector<uint8_t> FakeQuicStream::TakeWrittenData() {
+ return std::move(write_buffer_);
+}
+
+void FakeQuicStream::Write(const uint8_t* data, size_t size) {
+ OSP_DCHECK(!write_end_closed_);
+ write_buffer_.insert(write_buffer_.end(), data, data + size);
+}
+
+void FakeQuicStream::CloseWriteEnd() {
+ write_end_closed_ = true;
+}
+
+FakeQuicConnection::FakeQuicConnection(
+ FakeQuicConnectionFactory* parent_factory,
+ uint64_t connection_id,
+ Delegate* delegate)
+ : QuicConnection(delegate),
+ parent_factory_(parent_factory),
+ connection_id_(connection_id) {}
+
+FakeQuicConnection::~FakeQuicConnection() = default;
+
+std::unique_ptr<FakeQuicStream> FakeQuicConnection::MakeIncomingStream() {
+ uint64_t stream_id = next_stream_id_++;
+ auto result = MakeUnique<FakeQuicStream>(
+ delegate()->NextStreamDelegate(id(), stream_id), stream_id);
+ streams_.emplace(result->id(), result.get());
+ return result;
+}
+
+void FakeQuicConnection::OnDataReceived(const platform::ReceivedData& data) {
+ OSP_DCHECK(false) << "data should go directly to fake streams";
+}
+
+std::unique_ptr<QuicStream> FakeQuicConnection::MakeOutgoingStream(
+ QuicStream::Delegate* delegate) {
+ auto result = MakeUnique<FakeQuicStream>(delegate, next_stream_id_++);
+ streams_.emplace(result->id(), result.get());
+ return result;
+}
+
+void FakeQuicConnection::Close() {
+ parent_factory_->OnConnectionClosed(this);
+ delegate()->OnConnectionClosed(connection_id_);
+ for (auto& stream : streams_) {
+ stream.second->delegate()->OnClose(stream.first);
+ stream.second->delegate()->OnReceived(stream.second, nullptr, 0);
+ }
+}
+
+} // namespace openscreen
diff --git a/api/impl/quic/testing/fake_quic_connection.h b/api/impl/quic/testing/fake_quic_connection.h
new file mode 100644
index 0000000..3af1dd4
--- /dev/null
+++ b/api/impl/quic/testing/fake_quic_connection.h
@@ -0,0 +1,73 @@
+// Copyright 2018 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef API_IMPL_QUIC_TESTING_FAKE_QUIC_CONNECTION_H_
+#define API_IMPL_QUIC_TESTING_FAKE_QUIC_CONNECTION_H_
+
+#include <map>
+#include <vector>
+
+#include "api/impl/quic/quic_connection.h"
+
+namespace openscreen {
+
+class FakeQuicConnectionFactory;
+
+class FakeQuicStream final : public QuicStream {
+ public:
+ FakeQuicStream(Delegate* delegate, uint64_t id);
+ ~FakeQuicStream() override;
+
+ void ReceiveData(const uint8_t* data, size_t size);
+ void CloseReadEnd();
+
+ std::vector<uint8_t> TakeReceivedData();
+ std::vector<uint8_t> TakeWrittenData();
+
+ bool write_end_closed() const { return write_end_closed_; }
+ bool read_end_closed() const { return read_end_closed_; }
+
+ Delegate* delegate() { return delegate_; }
+
+ void Write(const uint8_t* data, size_t size) override;
+ void CloseWriteEnd() override;
+
+ private:
+ bool write_end_closed_ = false;
+ bool read_end_closed_ = false;
+ std::vector<uint8_t> write_buffer_;
+ std::vector<uint8_t> read_buffer_;
+};
+
+class FakeQuicConnection final : public QuicConnection {
+ public:
+ FakeQuicConnection(FakeQuicConnectionFactory* parent_factory,
+ uint64_t connection_id,
+ Delegate* delegate);
+ ~FakeQuicConnection() override;
+
+ Delegate* delegate() { return delegate_; }
+ uint64_t id() const { return connection_id_; }
+ const std::map<uint64_t, FakeQuicStream*>& streams() const {
+ return streams_;
+ }
+
+ std::unique_ptr<FakeQuicStream> MakeIncomingStream();
+
+ // QuicConnection overrides.
+ void OnDataReceived(const platform::ReceivedData& data) override;
+ std::unique_ptr<QuicStream> MakeOutgoingStream(
+ QuicStream::Delegate* delegate) override;
+ void Close() override;
+
+ private:
+ FakeQuicConnectionFactory* const parent_factory_;
+ const uint64_t connection_id_;
+ uint64_t next_stream_id_ = 1;
+ std::map<uint64_t, FakeQuicStream*> streams_;
+};
+
+} // namespace openscreen
+
+#endif // API_IMPL_QUIC_TESTING_FAKE_QUIC_CONNECTION_H_
diff --git a/api/impl/quic/testing/fake_quic_connection_factory.cc b/api/impl/quic/testing/fake_quic_connection_factory.cc
new file mode 100644
index 0000000..21c754a
--- /dev/null
+++ b/api/impl/quic/testing/fake_quic_connection_factory.cc
@@ -0,0 +1,99 @@
+// Copyright 2018 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "api/impl/quic/testing/fake_quic_connection_factory.h"
+
+#include <algorithm>
+
+#include "base/make_unique.h"
+#include "platform/api/logging.h"
+
+namespace openscreen {
+
+FakeQuicConnectionFactory::FakeQuicConnectionFactory(
+ const IPEndpoint& local_endpoint,
+ MessageDemuxer* remote_demuxer)
+ : remote_demuxer_(remote_demuxer), local_endpoint_(local_endpoint) {}
+
+FakeQuicConnectionFactory::~FakeQuicConnectionFactory() = default;
+
+void FakeQuicConnectionFactory::StartServerConnection(
+ const IPEndpoint& endpoint) {
+ QuicConnection::Delegate* delegate =
+ server_delegate_->NextConnectionDelegate(endpoint);
+ auto connection =
+ MakeUnique<FakeQuicConnection>(this, next_connection_id_++, delegate);
+ pending_connections_.emplace(endpoint, connection.get());
+ server_delegate_->OnIncomingConnection(std::move(connection));
+}
+
+FakeQuicStream* FakeQuicConnectionFactory::StartIncomingStream(
+ const IPEndpoint& endpoint) {
+ auto connection_entry = connections_.find(endpoint);
+ if (connection_entry == connections_.end())
+ return nullptr;
+ std::unique_ptr<FakeQuicStream> stream =
+ connection_entry->second->MakeIncomingStream();
+ FakeQuicStream* ptr = stream.get();
+ connection_entry->second->delegate()->OnIncomingStream(
+ connection_entry->second->id(), std::move(stream));
+ return ptr;
+}
+
+void FakeQuicConnectionFactory::OnConnectionClosed(QuicConnection* connection) {
+ for (auto entry = connections_.begin(); entry != connections_.end();
+ ++entry) {
+ if (entry->second == connection) {
+ connections_.erase(entry);
+ return;
+ }
+ }
+ OSP_DCHECK(false) << "reporting an unknown connection as closed";
+}
+
+void FakeQuicConnectionFactory::SetServerDelegate(
+ ServerDelegate* delegate,
+ const std::vector<IPEndpoint>& endpoints) {
+ server_delegate_ = delegate;
+}
+
+void FakeQuicConnectionFactory::RunTasks() {
+ idle_ = true;
+ for (auto& connection : connections_) {
+ for (auto& stream : connection.second->streams()) {
+ std::vector<uint8_t> received_data = stream.second->TakeReceivedData();
+ std::vector<uint8_t> written_data = stream.second->TakeWrittenData();
+ if (received_data.size()) {
+ idle_ = false;
+ stream.second->delegate()->OnReceived(
+ stream.second, reinterpret_cast<const char*>(received_data.data()),
+ received_data.size());
+ }
+ if (written_data.size()) {
+ idle_ = false;
+ remote_demuxer_->OnStreamData(0, stream.second->id(),
+ written_data.data(), written_data.size());
+ }
+ }
+ }
+ for (auto& connection : pending_connections_) {
+ idle_ = false;
+ connection.second->delegate()->OnCryptoHandshakeComplete(
+ connection.second->id());
+ connections_.emplace(connection.first, connection.second);
+ }
+ pending_connections_.clear();
+}
+
+std::unique_ptr<QuicConnection> FakeQuicConnectionFactory::Connect(
+ const IPEndpoint& endpoint,
+ QuicConnection::Delegate* connection_delegate) {
+ OSP_DCHECK(pending_connections_.find(endpoint) == pending_connections_.end());
+ auto connection = MakeUnique<FakeQuicConnection>(this, next_connection_id_++,
+ connection_delegate);
+ pending_connections_.emplace(endpoint, connection.get());
+ return connection;
+}
+
+} // namespace openscreen
diff --git a/api/impl/quic/testing/fake_quic_connection_factory.h b/api/impl/quic/testing/fake_quic_connection_factory.h
new file mode 100644
index 0000000..483face
--- /dev/null
+++ b/api/impl/quic/testing/fake_quic_connection_factory.h
@@ -0,0 +1,49 @@
+// Copyright 2018 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef API_IMPL_QUIC_TESTING_FAKE_QUIC_CONNECTION_FACTORY_H_
+#define API_IMPL_QUIC_TESTING_FAKE_QUIC_CONNECTION_FACTORY_H_
+
+#include <vector>
+
+#include "api/impl/quic/quic_connection_factory.h"
+#include "api/impl/quic/testing/fake_quic_connection.h"
+#include "api/public/message_demuxer.h"
+
+namespace openscreen {
+
+class FakeQuicConnectionFactory final : public QuicConnectionFactory {
+ public:
+ FakeQuicConnectionFactory(const IPEndpoint& local_endpoint,
+ MessageDemuxer* remote_demuxer);
+ ~FakeQuicConnectionFactory() override;
+
+ bool idle() const { return idle_; }
+
+ void StartServerConnection(const IPEndpoint& endpoint);
+ FakeQuicStream* StartIncomingStream(const IPEndpoint& endpoint);
+ void OnConnectionClosed(QuicConnection* connection);
+
+ // QuicConnectionFactory overrides.
+ void SetServerDelegate(ServerDelegate* delegate,
+ const std::vector<IPEndpoint>& endpoints) override;
+ void RunTasks() override;
+ std::unique_ptr<QuicConnection> Connect(
+ const IPEndpoint& endpoint,
+ QuicConnection::Delegate* connection_delegate) override;
+
+ private:
+ ServerDelegate* server_delegate_ = nullptr;
+ MessageDemuxer* const remote_demuxer_;
+ const IPEndpoint local_endpoint_;
+ bool idle_ = true;
+ uint64_t next_connection_id_ = 0;
+ std::map<IPEndpoint, FakeQuicConnection*, IPEndpointComparator>
+ pending_connections_;
+ std::map<IPEndpoint, FakeQuicConnection*, IPEndpointComparator> connections_;
+};
+
+} // namespace openscreen
+
+#endif // API_IMPL_QUIC_TESTING_FAKE_QUIC_CONNECTION_FACTORY_H_
diff --git a/api/public/BUILD.gn b/api/public/BUILD.gn
index 4d866f3..27fee99 100644
--- a/api/public/BUILD.gn
+++ b/api/public/BUILD.gn
@@ -6,17 +6,23 @@
sources = [
"client_config.cc",
"client_config.h",
+ "mdns_screen_listener_factory.h",
+ "mdns_screen_publisher_factory.h",
+ "message_demuxer.cc",
+ "message_demuxer.h",
"network_metrics.h",
- "network_service_manager.cc",
"network_service_manager.h",
- "presentation/presentation_common.h",
+ "presentation/presentation_connection.h",
"presentation/presentation_controller.h",
"presentation/presentation_receiver.h",
+ "protocol_connection.cc",
"protocol_connection.h",
"protocol_connection_client.cc",
"protocol_connection_client.h",
+ "protocol_connection_client_factory.h",
"protocol_connection_server.cc",
"protocol_connection_server.h",
+ "protocol_connection_server_factory.h",
"screen_info.cc",
"screen_info.h",
"screen_listener.cc",
@@ -27,7 +33,11 @@
"server_config.h",
]
+ public_deps = [
+ "//msgs",
+ ]
deps = [
"//base",
+ "//platform",
]
}
diff --git a/api/impl/mdns_screen_listener_factory.h b/api/public/mdns_screen_listener_factory.h
similarity index 79%
rename from api/impl/mdns_screen_listener_factory.h
rename to api/public/mdns_screen_listener_factory.h
index 15cee2c..486c67b 100644
--- a/api/impl/mdns_screen_listener_factory.h
+++ b/api/public/mdns_screen_listener_factory.h
@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
-#ifndef API_IMPL_MDNS_SCREEN_LISTENER_FACTORY_H_
-#define API_IMPL_MDNS_SCREEN_LISTENER_FACTORY_H_
+#ifndef API_PUBLIC_MDNS_SCREEN_LISTENER_FACTORY_H_
+#define API_PUBLIC_MDNS_SCREEN_LISTENER_FACTORY_H_
#include <memory>
@@ -25,4 +25,4 @@
} // namespace openscreen
-#endif // API_IMPL_MDNS_SCREEN_LISTENER_FACTORY_H_
+#endif // API_PUBLIC_MDNS_SCREEN_LISTENER_FACTORY_H_
diff --git a/api/impl/mdns_screen_publisher_factory.h b/api/public/mdns_screen_publisher_factory.h
similarity index 74%
rename from api/impl/mdns_screen_publisher_factory.h
rename to api/public/mdns_screen_publisher_factory.h
index b5ee3e2..f57db88 100644
--- a/api/impl/mdns_screen_publisher_factory.h
+++ b/api/public/mdns_screen_publisher_factory.h
@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
-#ifndef API_IMPL_MDNS_SCREEN_PUBLISHER_FACTORY_H_
-#define API_IMPL_MDNS_SCREEN_PUBLISHER_FACTORY_H_
+#ifndef API_PUBLIC_MDNS_SCREEN_PUBLISHER_FACTORY_H_
+#define API_PUBLIC_MDNS_SCREEN_PUBLISHER_FACTORY_H_
#include <memory>
@@ -20,4 +20,4 @@
} // namespace openscreen
-#endif // API_IMPL_MDNS_SCREEN_PUBLISHER_FACTORY_H_
+#endif // API_PUBLIC_MDNS_SCREEN_PUBLISHER_FACTORY_H_
diff --git a/api/public/message_demuxer.cc b/api/public/message_demuxer.cc
new file mode 100644
index 0000000..023da54
--- /dev/null
+++ b/api/public/message_demuxer.cc
@@ -0,0 +1,197 @@
+// Copyright 2018 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "api/public/message_demuxer.h"
+
+#include "api/impl/quic/quic_connection.h"
+#include "platform/api/logging.h"
+
+namespace openscreen {
+
+MessageDemuxer::MessageWatch::MessageWatch() = default;
+
+MessageDemuxer::MessageWatch::MessageWatch(MessageDemuxer* parent,
+ bool is_default,
+ uint64_t endpoint_id,
+ msgs::Type message_type)
+ : parent_(parent),
+ is_default_(is_default),
+ endpoint_id_(endpoint_id),
+ message_type_(message_type) {}
+
+MessageDemuxer::MessageWatch::MessageWatch(MessageDemuxer::MessageWatch&& other)
+ : parent_(other.parent_),
+ is_default_(other.is_default_),
+ endpoint_id_(other.endpoint_id_),
+ message_type_(other.message_type_) {
+ other.parent_ = nullptr;
+}
+
+MessageDemuxer::MessageWatch::~MessageWatch() {
+ if (parent_) {
+ if (is_default_) {
+ parent_->StopDefaultMessageTypeWatch(message_type_);
+ } else {
+ parent_->StopWatchingMessageType(endpoint_id_, message_type_);
+ }
+ }
+}
+
+MessageDemuxer::MessageWatch& MessageDemuxer::MessageWatch::operator=(
+ MessageWatch&& other) {
+ using std::swap;
+ swap(parent_, other.parent_);
+ swap(is_default_, other.is_default_);
+ swap(endpoint_id_, other.endpoint_id_);
+ swap(message_type_, other.message_type_);
+ return *this;
+}
+
+MessageDemuxer::MessageDemuxer(size_t buffer_limit)
+ : buffer_limit_(buffer_limit) {}
+MessageDemuxer::~MessageDemuxer() = default;
+
+MessageDemuxer::MessageWatch MessageDemuxer::WatchMessageType(
+ uint64_t endpoint_id,
+ msgs::Type message_type,
+ MessageCallback* callback) {
+ auto callbacks_entry = message_callbacks_.find(endpoint_id);
+ if (callbacks_entry == message_callbacks_.end()) {
+ callbacks_entry =
+ message_callbacks_
+ .emplace(endpoint_id, std::map<msgs::Type, MessageCallback*>{})
+ .first;
+ }
+ auto emplace_result = callbacks_entry->second.emplace(message_type, callback);
+ if (!emplace_result.second)
+ return MessageWatch();
+ auto endpoint_entry = buffers_.find(endpoint_id);
+ if (endpoint_entry != buffers_.end()) {
+ for (auto& buffer : endpoint_entry->second) {
+ if (buffer.second.empty())
+ continue;
+ auto buffered_type = static_cast<msgs::Type>(buffer.second[0]);
+ if (message_type == buffered_type) {
+ HandleStreamBufferLoop(endpoint_id, buffer.first, callbacks_entry,
+ &buffer.second);
+ }
+ }
+ }
+ return MessageWatch(this, false, endpoint_id, message_type);
+}
+
+MessageDemuxer::MessageWatch MessageDemuxer::SetDefaultMessageTypeWatch(
+ msgs::Type message_type,
+ MessageCallback* callback) {
+ auto emplace_result = default_callbacks_.emplace(message_type, callback);
+ if (!emplace_result.second)
+ return MessageWatch();
+ for (auto& endpoint_buffers : buffers_) {
+ for (auto& buffer : endpoint_buffers.second) {
+ if (buffer.second.empty())
+ continue;
+ auto buffered_type = static_cast<msgs::Type>(buffer.second[0]);
+ if (message_type == buffered_type) {
+ auto callbacks_entry = message_callbacks_.find(endpoint_buffers.first);
+ HandleStreamBufferLoop(endpoint_buffers.first, buffer.first,
+ callbacks_entry, &buffer.second);
+ }
+ }
+ }
+ return MessageWatch(this, true, 0, message_type);
+}
+
+void MessageDemuxer::OnStreamData(uint64_t endpoint_id,
+ uint64_t connection_id,
+ const uint8_t* data,
+ size_t data_size) {
+ OSP_VLOG(1) << __func__ << ": " << endpoint_id << " - (" << data_size << ")";
+ auto& stream_map = buffers_[endpoint_id];
+ if (!data_size) {
+ stream_map.erase(connection_id);
+ if (stream_map.empty())
+ buffers_.erase(endpoint_id);
+ return;
+ }
+ std::vector<uint8_t>& buffer = stream_map[connection_id];
+ buffer.insert(buffer.end(), data, data + data_size);
+
+ auto callbacks_entry = message_callbacks_.find(endpoint_id);
+ HandleStreamBufferLoop(endpoint_id, connection_id, callbacks_entry, &buffer);
+
+ if (buffer.size() > buffer_limit_)
+ stream_map.erase(connection_id);
+}
+
+void MessageDemuxer::StopWatchingMessageType(uint64_t endpoint_id,
+ msgs::Type message_type) {
+ auto& message_map = message_callbacks_[endpoint_id];
+ auto it = message_map.find(message_type);
+ message_map.erase(it);
+}
+
+void MessageDemuxer::StopDefaultMessageTypeWatch(msgs::Type message_type) {
+ default_callbacks_.erase(message_type);
+}
+
+MessageDemuxer::HandleStreamBufferResult MessageDemuxer::HandleStreamBufferLoop(
+ uint64_t endpoint_id,
+ uint64_t connection_id,
+ std::map<uint64_t, std::map<msgs::Type, MessageCallback*>>::iterator
+ callbacks_entry,
+ std::vector<uint8_t>* buffer) {
+ HandleStreamBufferResult result;
+ do {
+ result = {false, 0};
+ if (callbacks_entry != message_callbacks_.end()) {
+ OSP_VLOG(1) << "attempting endpoint-specific handling";
+ result = HandleStreamBuffer(endpoint_id, connection_id,
+ &callbacks_entry->second, buffer);
+ }
+ if (!result.handled) {
+ if (!default_callbacks_.empty()) {
+ OSP_VLOG(1) << "attempting generic message handling";
+ result = HandleStreamBuffer(endpoint_id, connection_id,
+ &default_callbacks_, buffer);
+ }
+ }
+ } while (result.consumed && !buffer->empty());
+ return result;
+}
+
+MessageDemuxer::HandleStreamBufferResult MessageDemuxer::HandleStreamBuffer(
+ uint64_t endpoint_id,
+ uint64_t connection_id,
+ std::map<msgs::Type, MessageCallback*>* message_callbacks,
+ std::vector<uint8_t>* buffer) {
+ size_t consumed = 0;
+ size_t total_consumed = 0;
+ bool handled = false;
+ do {
+ consumed = 0;
+ auto message_type = static_cast<msgs::Type>((*buffer)[0]);
+ auto callback_entry = message_callbacks->find(message_type);
+ if (callback_entry == message_callbacks->end())
+ break;
+ handled = true;
+ OSP_VLOG(1) << "handling message type " << static_cast<int>(message_type);
+ auto consumed_or_error = callback_entry->second->OnStreamMessage(
+ endpoint_id, connection_id, message_type, buffer->data() + 1,
+ buffer->size() - 1);
+ if (consumed_or_error.is_error()) {
+ if (consumed_or_error.error().code() !=
+ Error::Code::kCborIncompleteMessage) {
+ buffer->clear();
+ break;
+ }
+ } else {
+ consumed = consumed_or_error.value();
+ buffer->erase(buffer->begin(), buffer->begin() + consumed + 1);
+ }
+ total_consumed += consumed;
+ } while (consumed && !buffer->empty());
+ return HandleStreamBufferResult{handled, total_consumed};
+}
+
+} // namespace openscreen
diff --git a/api/public/message_demuxer.h b/api/public/message_demuxer.h
new file mode 100644
index 0000000..aab109a
--- /dev/null
+++ b/api/public/message_demuxer.h
@@ -0,0 +1,112 @@
+// Copyright 2018 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef API_PUBLIC_MESSAGE_DEMUXER_H_
+#define API_PUBLIC_MESSAGE_DEMUXER_H_
+
+#include <map>
+#include <vector>
+
+#include "base/error.h"
+#include "msgs/osp_messages.h"
+
+namespace openscreen {
+
+class QuicStream;
+
+// This class separates QUIC stream data into CBOR messages by reading a type
+// prefix from the stream and passes those messages to any callback matching the
+// source endpoint and message type. If there is no callback for a given
+// message type, it will also try a default message listener.
+class MessageDemuxer {
+ public:
+ class MessageCallback {
+ public:
+ virtual ~MessageCallback() = default;
+
+ // |buffer| contains data for a message of type |message_type|. However,
+ // the data may be incomplete, in which case the callback should return an
+ // error code of Error::Code::kCborIncompleteMessage. This way, the
+ // MessageDemuxer knows to neither consume the data nor discard it as bad.
+ virtual ErrorOr<size_t> OnStreamMessage(uint64_t endpoint_id,
+ uint64_t connection_id,
+ msgs::Type message_type,
+ const uint8_t* buffer,
+ size_t buffer_size) = 0;
+ };
+
+ class MessageWatch {
+ public:
+ MessageWatch();
+ MessageWatch(MessageDemuxer* parent,
+ bool is_default,
+ uint64_t endpoint_id,
+ msgs::Type message_type);
+ MessageWatch(MessageWatch&&);
+ ~MessageWatch();
+ MessageWatch& operator=(MessageWatch&&);
+
+ explicit operator bool() const { return parent_; }
+
+ private:
+ MessageDemuxer* parent_ = nullptr;
+ bool is_default_;
+ uint64_t endpoint_id_;
+ msgs::Type message_type_;
+ };
+
+ static constexpr size_t kDefaultBufferLimit = 1 << 16;
+
+ explicit MessageDemuxer(size_t buffer_limit = kDefaultBufferLimit);
+ ~MessageDemuxer();
+
+ // Starts watching for messages of type |message_type| from the endpoint
+ // identified by |endpoint_id|. When such a message arrives, or if some are
+ // already buffered, |callback| will be called with the message data.
+ MessageWatch WatchMessageType(uint64_t endpoint_id,
+ msgs::Type message_type,
+ MessageCallback* callback);
+
+ // Starts watching for messages of type |message_type| from any endpoint when
+ // there is not callback set for its specific endpoint ID.
+ MessageWatch SetDefaultMessageTypeWatch(msgs::Type message_type,
+ MessageCallback* callback);
+
+ // Gives data from |endpoint_id| to the demuxer for processing.
+ void OnStreamData(uint64_t endpoint_id,
+ uint64_t connection_id,
+ const uint8_t* data,
+ size_t data_size);
+
+ private:
+ struct HandleStreamBufferResult {
+ bool handled;
+ size_t consumed;
+ };
+
+ void StopWatchingMessageType(uint64_t endpoint_id, msgs::Type message_type);
+ void StopDefaultMessageTypeWatch(msgs::Type message_type);
+
+ HandleStreamBufferResult HandleStreamBufferLoop(
+ uint64_t endpoint_id,
+ uint64_t connection_id,
+ std::map<uint64_t, std::map<msgs::Type, MessageCallback*>>::iterator
+ endpoint_entry,
+ std::vector<uint8_t>* buffer);
+
+ HandleStreamBufferResult HandleStreamBuffer(
+ uint64_t endpoint_id,
+ uint64_t connection_id,
+ std::map<msgs::Type, MessageCallback*>* message_callbacks,
+ std::vector<uint8_t>* buffer);
+
+ const size_t buffer_limit_;
+ std::map<uint64_t, std::map<msgs::Type, MessageCallback*>> message_callbacks_;
+ std::map<msgs::Type, MessageCallback*> default_callbacks_;
+ std::map<uint64_t, std::map<uint64_t, std::vector<uint8_t>>> buffers_;
+};
+
+} // namespace openscreen
+
+#endif // API_PUBLIC_MESSAGE_DEMUXER_H_
diff --git a/api/public/message_demuxer_unittest.cc b/api/public/message_demuxer_unittest.cc
new file mode 100644
index 0000000..8c5af51
--- /dev/null
+++ b/api/public/message_demuxer_unittest.cc
@@ -0,0 +1,349 @@
+// Copyright 2018 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "api/public/message_demuxer.h"
+
+#include "third_party/googletest/src/googlemock/include/gmock/gmock.h"
+#include "third_party/googletest/src/googletest/include/gtest/gtest.h"
+#include "third_party/tinycbor/src/src/cbor.h"
+
+namespace openscreen {
+namespace {
+
+using ::testing::_;
+using ::testing::Invoke;
+
+class MockMessageCallback final : public MessageDemuxer::MessageCallback {
+ public:
+ ~MockMessageCallback() override = default;
+
+ MOCK_METHOD5(OnStreamMessage,
+ ErrorOr<size_t>(uint64_t endpoint_id,
+ uint64_t connection_id,
+ msgs::Type message_type,
+ const uint8_t* buffer,
+ size_t buffer_size));
+};
+
+ErrorOr<size_t> ConvertDecodeResult(ssize_t result) {
+ if (result < 0) {
+ if (result == -CborErrorUnexpectedEOF)
+ return Error::Code::kCborIncompleteMessage;
+ else
+ return Error::Code::kCborParsing;
+ } else {
+ return result;
+ }
+}
+
+class MessageDemuxerTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ ASSERT_TRUE(
+ msgs::EncodePresentationConnectionOpenRequest(request_, &buffer_));
+ }
+
+ void ExpectDecodedRequest(
+ ssize_t decode_result,
+ const msgs::PresentationConnectionOpenRequest& received_request) {
+ ASSERT_GT(decode_result, 0);
+ EXPECT_EQ(decode_result, static_cast<ssize_t>(buffer_.size() - 1));
+ EXPECT_EQ(request_.request_id, received_request.request_id);
+ EXPECT_EQ(request_.presentation_id, received_request.presentation_id);
+ EXPECT_EQ(request_.connection_id, received_request.connection_id);
+ }
+
+ const uint64_t endpoint_id_ = 13;
+ const uint64_t connection_id_ = 45;
+ msgs::CborEncodeBuffer buffer_;
+ msgs::PresentationConnectionOpenRequest request_{1, "fry-am-the-egg-man", 3};
+ MockMessageCallback mock_callback_;
+ MessageDemuxer demuxer_;
+};
+
+} // namespace
+
+TEST_F(MessageDemuxerTest, WatchStartStop) {
+ MessageDemuxer::MessageWatch watch = demuxer_.WatchMessageType(
+ endpoint_id_, msgs::Type::kPresentationConnectionOpenRequest,
+ &mock_callback_);
+ ASSERT_TRUE(watch);
+
+ EXPECT_CALL(mock_callback_, OnStreamMessage(_, _, _, _, _)).Times(0);
+ demuxer_.OnStreamData(endpoint_id_ + 1, 14, buffer_.data(), buffer_.size());
+
+ msgs::PresentationConnectionOpenRequest received_request;
+ ssize_t decode_result = 0;
+ EXPECT_CALL(
+ mock_callback_,
+ OnStreamMessage(endpoint_id_, connection_id_,
+ msgs::Type::kPresentationConnectionOpenRequest, _, _))
+ .WillOnce(Invoke([&decode_result, &received_request](
+ uint64_t endpoint_id, uint64_t connection_id,
+ msgs::Type message_type, const uint8_t* buffer,
+ size_t buffer_sizesize) {
+ decode_result = msgs::DecodePresentationConnectionOpenRequest(
+ buffer, buffer_sizesize, &received_request);
+ return ConvertDecodeResult(decode_result);
+ }));
+ demuxer_.OnStreamData(endpoint_id_, connection_id_, buffer_.data(),
+ buffer_.size());
+ ExpectDecodedRequest(decode_result, received_request);
+
+ watch = MessageDemuxer::MessageWatch();
+ EXPECT_CALL(mock_callback_, OnStreamMessage(_, _, _, _, _)).Times(0);
+ demuxer_.OnStreamData(endpoint_id_, connection_id_, buffer_.data(),
+ buffer_.size());
+}
+
+TEST_F(MessageDemuxerTest, BufferPartialMessage) {
+ MessageDemuxer demuxer_;
+ MockMessageCallback mock_callback_;
+ constexpr uint64_t endpoint_id_ = 13;
+
+ MessageDemuxer::MessageWatch watch = demuxer_.WatchMessageType(
+ endpoint_id_, msgs::Type::kPresentationConnectionOpenRequest,
+ &mock_callback_);
+ ASSERT_TRUE(watch);
+
+ msgs::PresentationConnectionOpenRequest received_request;
+ ssize_t decode_result = 0;
+ EXPECT_CALL(
+ mock_callback_,
+ OnStreamMessage(endpoint_id_, connection_id_,
+ msgs::Type::kPresentationConnectionOpenRequest, _, _))
+ .Times(2)
+ .WillRepeatedly(Invoke([&decode_result, &received_request](
+ uint64_t endpoint_id, uint64_t connection_id,
+ msgs::Type message_type, const uint8_t* buffer,
+ size_t buffer_size) {
+ decode_result = msgs::DecodePresentationConnectionOpenRequest(
+ buffer, buffer_size, &received_request);
+ return ConvertDecodeResult(decode_result);
+ }));
+ demuxer_.OnStreamData(endpoint_id_, connection_id_, buffer_.data(),
+ buffer_.size() - 3);
+ demuxer_.OnStreamData(endpoint_id_, connection_id_,
+ buffer_.data() + buffer_.size() - 3, 3);
+ ExpectDecodedRequest(decode_result, received_request);
+}
+
+TEST_F(MessageDemuxerTest, DefaultWatch) {
+ MessageDemuxer demuxer_;
+ MockMessageCallback mock_callback_;
+ constexpr uint64_t endpoint_id_ = 13;
+
+ MessageDemuxer::MessageWatch watch = demuxer_.SetDefaultMessageTypeWatch(
+ msgs::Type::kPresentationConnectionOpenRequest, &mock_callback_);
+ ASSERT_TRUE(watch);
+
+ msgs::PresentationConnectionOpenRequest received_request;
+ ssize_t decode_result = 0;
+ EXPECT_CALL(
+ mock_callback_,
+ OnStreamMessage(endpoint_id_, connection_id_,
+ msgs::Type::kPresentationConnectionOpenRequest, _, _))
+ .WillOnce(Invoke([&decode_result, &received_request](
+ uint64_t endpoint_id, uint64_t connection_id,
+ msgs::Type message_type, const uint8_t* buffer,
+ size_t buffer_size) {
+ decode_result = msgs::DecodePresentationConnectionOpenRequest(
+ buffer, buffer_size, &received_request);
+ return ConvertDecodeResult(decode_result);
+ }));
+ demuxer_.OnStreamData(endpoint_id_, connection_id_, buffer_.data(),
+ buffer_.size());
+ ExpectDecodedRequest(decode_result, received_request);
+}
+
+TEST_F(MessageDemuxerTest, DefaultWatchOverridden) {
+ MessageDemuxer demuxer_;
+ MockMessageCallback mock_callback_global;
+ MockMessageCallback mock_callback_;
+ constexpr uint64_t endpoint_id_ = 13;
+
+ MessageDemuxer::MessageWatch default_watch =
+ demuxer_.SetDefaultMessageTypeWatch(
+ msgs::Type::kPresentationConnectionOpenRequest,
+ &mock_callback_global);
+ ASSERT_TRUE(default_watch);
+ MessageDemuxer::MessageWatch watch = demuxer_.WatchMessageType(
+ endpoint_id_, msgs::Type::kPresentationConnectionOpenRequest,
+ &mock_callback_);
+ ASSERT_TRUE(watch);
+
+ msgs::PresentationConnectionOpenRequest received_request;
+ ssize_t decode_result = 0;
+ EXPECT_CALL(mock_callback_, OnStreamMessage(_, _, _, _, _)).Times(0);
+ EXPECT_CALL(
+ mock_callback_global,
+ OnStreamMessage(endpoint_id_ + 1, 14,
+ msgs::Type::kPresentationConnectionOpenRequest, _, _))
+ .WillOnce(Invoke([&decode_result, &received_request](
+ uint64_t endpoint_id, uint64_t connection_id,
+ msgs::Type message_type, const uint8_t* buffer,
+ size_t buffer_size) {
+ decode_result = msgs::DecodePresentationConnectionOpenRequest(
+ buffer, buffer_size, &received_request);
+ return ConvertDecodeResult(decode_result);
+ }));
+ demuxer_.OnStreamData(endpoint_id_ + 1, 14, buffer_.data(), buffer_.size());
+ ExpectDecodedRequest(decode_result, received_request);
+
+ decode_result = 0;
+ EXPECT_CALL(
+ mock_callback_,
+ OnStreamMessage(endpoint_id_, connection_id_,
+ msgs::Type::kPresentationConnectionOpenRequest, _, _))
+ .WillOnce(Invoke([&decode_result, &received_request](
+ uint64_t endpoint_id, uint64_t connection_id,
+ msgs::Type message_type, const uint8_t* buffer,
+ size_t buffer_size) {
+ decode_result = msgs::DecodePresentationConnectionOpenRequest(
+ buffer, buffer_size, &received_request);
+ return ConvertDecodeResult(decode_result);
+ }));
+ demuxer_.OnStreamData(endpoint_id_, connection_id_, buffer_.data(),
+ buffer_.size());
+ ExpectDecodedRequest(decode_result, received_request);
+}
+
+TEST_F(MessageDemuxerTest, WatchAfterData) {
+ demuxer_.OnStreamData(endpoint_id_, connection_id_, buffer_.data(),
+ buffer_.size());
+
+ msgs::PresentationConnectionOpenRequest received_request;
+ ssize_t decode_result = 0;
+ EXPECT_CALL(
+ mock_callback_,
+ OnStreamMessage(endpoint_id_, connection_id_,
+ msgs::Type::kPresentationConnectionOpenRequest, _, _))
+ .WillOnce(Invoke([&decode_result, &received_request](
+ uint64_t endpoint_id, uint64_t connection_id,
+ msgs::Type message_type, const uint8_t* buffer,
+ size_t buffer_size) {
+ decode_result = msgs::DecodePresentationConnectionOpenRequest(
+ buffer, buffer_size, &received_request);
+ return ConvertDecodeResult(decode_result);
+ }));
+ MessageDemuxer::MessageWatch watch = demuxer_.WatchMessageType(
+ endpoint_id_, msgs::Type::kPresentationConnectionOpenRequest,
+ &mock_callback_);
+ ASSERT_TRUE(watch);
+ ExpectDecodedRequest(decode_result, received_request);
+}
+
+TEST_F(MessageDemuxerTest, WatchAfterMultipleData) {
+ demuxer_.OnStreamData(endpoint_id_, connection_id_, buffer_.data(),
+ buffer_.size());
+
+ msgs::CborEncodeBuffer buffer;
+ msgs::PresentationInitiationRequest request;
+ request.request_id = 2;
+ request.url = "https://example.com/recv";
+ request.connection_id = 98;
+ request.has_connection_id = true;
+ ASSERT_TRUE(msgs::EncodePresentationInitiationRequest(request, &buffer));
+ demuxer_.OnStreamData(endpoint_id_, connection_id_, buffer.data(),
+ buffer.size());
+
+ MockMessageCallback mock_init_callback;
+ msgs::PresentationConnectionOpenRequest received_request;
+ msgs::PresentationInitiationRequest received_init_request;
+ ssize_t decode_result1 = 0;
+ ssize_t decode_result2 = 0;
+ MessageDemuxer::MessageWatch init_watch = demuxer_.WatchMessageType(
+ endpoint_id_, msgs::Type::kPresentationInitiationRequest,
+ &mock_init_callback);
+ EXPECT_CALL(
+ mock_callback_,
+ OnStreamMessage(endpoint_id_, connection_id_,
+ msgs::Type::kPresentationConnectionOpenRequest, _, _))
+ .WillOnce(Invoke([&decode_result1, &received_request](
+ uint64_t endpoint_id, uint64_t connection_id,
+ msgs::Type message_type, const uint8_t* buffer,
+ size_t buffer_size) {
+ decode_result1 = msgs::DecodePresentationConnectionOpenRequest(
+ buffer, buffer_size, &received_request);
+ return ConvertDecodeResult(decode_result1);
+ }));
+ EXPECT_CALL(mock_init_callback,
+ OnStreamMessage(endpoint_id_, connection_id_,
+ msgs::Type::kPresentationInitiationRequest, _, _))
+ .WillOnce(Invoke([&decode_result2, &received_init_request](
+ uint64_t endpoint_id, uint64_t connection_id,
+ msgs::Type message_type, const uint8_t* buffer,
+ size_t buffer_size) {
+ decode_result2 = msgs::DecodePresentationInitiationRequest(
+ buffer, buffer_size, &received_init_request);
+ return ConvertDecodeResult(decode_result2);
+ }));
+ MessageDemuxer::MessageWatch watch = demuxer_.WatchMessageType(
+ endpoint_id_, msgs::Type::kPresentationConnectionOpenRequest,
+ &mock_callback_);
+ ASSERT_TRUE(watch);
+ ExpectDecodedRequest(decode_result1, received_request);
+
+ ASSERT_GT(decode_result2, 0);
+ EXPECT_EQ(decode_result2, static_cast<ssize_t>(buffer.size() - 1));
+ EXPECT_EQ(request.request_id, received_init_request.request_id);
+ EXPECT_EQ(request.url, received_init_request.url);
+ EXPECT_EQ(request.connection_id, received_init_request.connection_id);
+ EXPECT_EQ(request.has_connection_id, received_init_request.has_connection_id);
+}
+
+TEST_F(MessageDemuxerTest, GlobalWatchAfterData) {
+ demuxer_.OnStreamData(endpoint_id_, connection_id_, buffer_.data(),
+ buffer_.size());
+
+ msgs::PresentationConnectionOpenRequest received_request;
+ ssize_t decode_result = 0;
+ EXPECT_CALL(
+ mock_callback_,
+ OnStreamMessage(endpoint_id_, connection_id_,
+ msgs::Type::kPresentationConnectionOpenRequest, _, _))
+ .WillOnce(Invoke([&decode_result, &received_request](
+ uint64_t endpoint_id, uint64_t connection_id,
+ msgs::Type message_type, const uint8_t* buffer,
+ size_t buffer_size) {
+ decode_result = msgs::DecodePresentationConnectionOpenRequest(
+ buffer, buffer_size, &received_request);
+ return ConvertDecodeResult(decode_result);
+ }));
+ MessageDemuxer::MessageWatch watch = demuxer_.SetDefaultMessageTypeWatch(
+ msgs::Type::kPresentationConnectionOpenRequest, &mock_callback_);
+ ASSERT_TRUE(watch);
+ ExpectDecodedRequest(decode_result, received_request);
+}
+
+TEST_F(MessageDemuxerTest, BufferLimit) {
+ MessageDemuxer demuxer(10);
+
+ demuxer.OnStreamData(endpoint_id_, connection_id_, buffer_.data(),
+ buffer_.size());
+ EXPECT_CALL(mock_callback_, OnStreamMessage(_, _, _, _, _)).Times(0);
+ MessageDemuxer::MessageWatch watch = demuxer.WatchMessageType(
+ endpoint_id_, msgs::Type::kPresentationConnectionOpenRequest,
+ &mock_callback_);
+
+ msgs::PresentationConnectionOpenRequest received_request;
+ ssize_t decode_result = 0;
+ EXPECT_CALL(
+ mock_callback_,
+ OnStreamMessage(endpoint_id_, connection_id_,
+ msgs::Type::kPresentationConnectionOpenRequest, _, _))
+ .WillOnce(Invoke([&decode_result, &received_request](
+ uint64_t endpoint_id, uint64_t connection_id,
+ msgs::Type message_type, const uint8_t* buffer,
+ size_t buffer_size) {
+ decode_result = msgs::DecodePresentationConnectionOpenRequest(
+ buffer, buffer_size, &received_request);
+ return ConvertDecodeResult(decode_result);
+ }));
+ demuxer.OnStreamData(endpoint_id_, connection_id_, buffer_.data(),
+ buffer_.size());
+ ExpectDecodedRequest(decode_result, received_request);
+}
+
+} // namespace openscreen
diff --git a/api/public/network_service_manager.h b/api/public/network_service_manager.h
index 6a6b3af..5456a77 100644
--- a/api/public/network_service_manager.h
+++ b/api/public/network_service_manager.h
@@ -40,6 +40,9 @@
// by the service instance destructors.
static void Dispose();
+ // Runs the event loop once for all of its owned services. This mostly
+ // consists of check for available network events and passing that data to the
+ // listening services.
void RunEventLoopOnce();
// Returns an instance of the mDNS screen listener, or nullptr if
diff --git a/api/public/protocol_connection.cc b/api/public/protocol_connection.cc
new file mode 100644
index 0000000..30081ec
--- /dev/null
+++ b/api/public/protocol_connection.cc
@@ -0,0 +1,20 @@
+// Copyright 2018 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "api/public/protocol_connection.h"
+
+#include "platform/api/logging.h"
+
+namespace openscreen {
+
+ProtocolConnection::ProtocolConnection(uint64_t endpoint_id,
+ uint64_t connection_id)
+ : endpoint_id_(endpoint_id), connection_id_(connection_id) {}
+
+void ProtocolConnection::SetObserver(Observer* observer) {
+ OSP_DCHECK(!observer_ || !observer);
+ observer_ = observer;
+}
+
+} // namespace openscreen
diff --git a/api/public/protocol_connection.h b/api/public/protocol_connection.h
index 2e06b25..b68ac3c 100644
--- a/api/public/protocol_connection.h
+++ b/api/public/protocol_connection.h
@@ -5,10 +5,13 @@
#ifndef API_PUBLIC_PROTOCOL_CONNECTION_H_
#define API_PUBLIC_PROTOCOL_CONNECTION_H_
+#include <cstddef>
+#include <cstdint>
+
namespace openscreen {
class Error;
-class NetworkMetrics;
+struct NetworkMetrics;
// Represents an embedder's view of a connection between an Open Screen
// controller and a receiver. Both the controller and receiver will have a
@@ -19,7 +22,20 @@
// standard and can be extended by embedders with additional protocols.
class ProtocolConnection {
public:
- ProtocolConnection() = default;
+ class Observer {
+ public:
+ virtual ~Observer() = default;
+
+ // Called when the state of |connection| has changed.
+ virtual void OnConnectionChanged(const ProtocolConnection& connection) = 0;
+
+ // Called when |connection| is no longer available, either because the
+ // underlying transport was terminated, the underying system resource was
+ // closed, or data can no longer be exchanged.
+ virtual void OnConnectionClosed(const ProtocolConnection& connection) = 0;
+ };
+
+ ProtocolConnection(uint64_t endpoint_id, uint64_t connection_id);
virtual ~ProtocolConnection() = default;
// TODO(mfoltz): Define extension API exposed to embedders. This would be
@@ -31,30 +47,37 @@
// establishment. What about server connections? We probably want to have
// two different structures representing what the client and server know about
// a connection.
+
+ void SetObserver(Observer* observer);
+
+ // TODO(btolsch): This should be derived from the handshake auth identifier
+ // when that is finalized and implemented.
+ uint64_t endpoint_id() const { return endpoint_id_; }
+ uint64_t connection_id() const { return connection_id_; }
+
+ virtual void Write(const uint8_t* data, size_t data_size) = 0;
+ virtual void CloseWriteEnd() = 0;
+
+ protected:
+ uint64_t endpoint_id_;
+ uint64_t connection_id_;
+ Observer* observer_ = nullptr;
};
-class ProtocolConnectionObserver {
+class ProtocolConnectionServiceObserver {
public:
// Called when the state becomes kRunning.
virtual void OnRunning() = 0;
// Called when the state becomes kStopped.
virtual void OnStopped() = 0;
- // Called when a new connection was created between 5-tuples.
- virtual void OnConnectionAdded(const ProtocolConnection& connection) = 0;
- // Called when the state of |connection| has changed.
- virtual void OnConnectionChanged(const ProtocolConnection& connection) = 0;
- // Called when |connection| is no longer available, either because the
- // underlying transport was terminated, the underying system resource was
- // closed, or data can no longer be exchanged.
- virtual void OnConnectionRemoved(const ProtocolConnection& connection) = 0;
// Called when metrics have been collected by the service.
virtual void OnMetrics(const NetworkMetrics& metrics) = 0;
// Called when an error has occurred.
virtual void OnError(const Error& error) = 0;
protected:
- virtual ~ProtocolConnectionObserver() = default;
+ virtual ~ProtocolConnectionServiceObserver() = default;
};
} // namespace openscreen
diff --git a/api/public/protocol_connection_client.cc b/api/public/protocol_connection_client.cc
index 0696222..88e0c16 100644
--- a/api/public/protocol_connection_client.cc
+++ b/api/public/protocol_connection_client.cc
@@ -6,10 +6,52 @@
namespace openscreen {
+ProtocolConnectionClient::ConnectRequest::ConnectRequest() = default;
+
+ProtocolConnectionClient::ConnectRequest::ConnectRequest(
+ ProtocolConnectionClient* parent,
+ uint64_t request_id)
+ : parent_(parent), request_id_(request_id) {}
+
+ProtocolConnectionClient::ConnectRequest::ConnectRequest(ConnectRequest&& other)
+ : parent_(other.parent_), request_id_(other.request_id_) {
+ other.request_id_ = 0;
+}
+
+ProtocolConnectionClient::ConnectRequest::~ConnectRequest() {
+ if (request_id_)
+ parent_->CancelConnectRequest(request_id_);
+}
+
+ProtocolConnectionClient::ConnectRequest&
+ProtocolConnectionClient::ConnectRequest::operator=(ConnectRequest&& other) {
+ using std::swap;
+ swap(parent_, other.parent_);
+ swap(request_id_, other.request_id_);
+ return *this;
+}
+
ProtocolConnectionClient::ProtocolConnectionClient(
- ProtocolConnectionObserver* observer)
- : observer_(observer) {}
+ MessageDemuxer* demuxer,
+ ProtocolConnectionServiceObserver* observer)
+ : demuxer_(demuxer), observer_(observer) {}
ProtocolConnectionClient::~ProtocolConnectionClient() = default;
+std::ostream& operator<<(std::ostream& os,
+ ProtocolConnectionClient::State state) {
+ switch (state) {
+ case ProtocolConnectionClient::State::kStopped:
+ return os << "STOPPED";
+ case ProtocolConnectionClient::State::kStarting:
+ return os << "STARTING";
+ case ProtocolConnectionClient::State::kRunning:
+ return os << "RUNNING";
+ case ProtocolConnectionClient::State::kStopping:
+ return os << "STOPPING";
+ default:
+ return os << "UNKNOWN";
+ }
+}
+
} // namespace openscreen
diff --git a/api/public/protocol_connection_client.h b/api/public/protocol_connection_client.h
index 829f024..32b5ebf 100644
--- a/api/public/protocol_connection_client.h
+++ b/api/public/protocol_connection_client.h
@@ -5,10 +5,14 @@
#ifndef API_PUBLIC_PROTOCOL_CONNECTION_CLIENT_H_
#define API_PUBLIC_PROTOCOL_CONNECTION_CLIENT_H_
+#include <memory>
+#include <ostream>
#include <string>
+#include "api/public/message_demuxer.h"
#include "api/public/protocol_connection.h"
#include "base/error.h"
+#include "base/ip_address.h"
#include "base/macros.h"
namespace openscreen {
@@ -23,6 +27,38 @@
public:
enum class State { kStopped = 0, kStarting, kRunning, kStopping };
+ class ConnectionRequestCallback {
+ public:
+ virtual ~ConnectionRequestCallback() = default;
+
+ // Called when a new connection was created between 5-tuples.
+ virtual void OnConnectionOpened(
+ uint64_t request_id,
+ std::unique_ptr<ProtocolConnection>&& connection) = 0;
+ virtual void OnConnectionFailed(uint64_t request_id) = 0;
+ };
+
+ class ConnectRequest {
+ public:
+ ConnectRequest();
+ ConnectRequest(ProtocolConnectionClient* parent, uint64_t request_id);
+ ConnectRequest(ConnectRequest&& other);
+ ~ConnectRequest();
+ ConnectRequest& operator=(ConnectRequest&& other);
+
+ explicit operator bool() const { return request_id_; }
+
+ uint64_t request_id() const { return request_id_; }
+
+ // Records that the requested connect operation is successful so it doesn't
+ // need to attempt a cancel on destruction.
+ void MarkSuccess() { request_id_ = 0; }
+
+ private:
+ ProtocolConnectionClient* parent_ = nullptr;
+ uint64_t request_id_ = 0;
+ };
+
virtual ~ProtocolConnectionClient();
// Starts the client using the config object.
@@ -39,6 +75,20 @@
// Returns true if state() != (kStopped|kStopping).
virtual bool Stop() = 0;
+ virtual void RunTasks() = 0;
+
+ // Open a new connection to |endpoint|. This may succeed synchronously if
+ // there are already connections open to |endpoint|, otherwise it will be
+ // asynchronous.
+ virtual ConnectRequest Connect(const IPEndpoint& endpoint,
+ ConnectionRequestCallback* request) = 0;
+
+ // Synchronously open a new connection to an endpoint identified by
+ // |endpoint_id|. Returns nullptr if it can't be completed synchronously
+ // (e.g. there are no existing open connections to that endpoint).
+ virtual std::unique_ptr<ProtocolConnection> CreateProtocolConnection(
+ uint64_t endpoint_id) = 0;
+
// Returns the current state of the listener.
State state() const { return state_; }
@@ -46,15 +96,23 @@
const Error& last_error() const { return last_error_; }
protected:
- explicit ProtocolConnectionClient(ProtocolConnectionObserver* observer);
+ explicit ProtocolConnectionClient(
+ MessageDemuxer* demuxer,
+ ProtocolConnectionServiceObserver* observer);
+
+ virtual void CancelConnectRequest(uint64_t request_id) = 0;
State state_ = State::kStopped;
Error last_error_;
- ProtocolConnectionObserver* const observer_;
+ MessageDemuxer* const demuxer_;
+ ProtocolConnectionServiceObserver* const observer_;
DISALLOW_COPY_AND_ASSIGN(ProtocolConnectionClient);
};
+std::ostream& operator<<(std::ostream& os,
+ ProtocolConnectionClient::State state);
+
} // namespace openscreen
#endif // API_PUBLIC_PROTOCOL_CONNECTION_CLIENT_H_
diff --git a/api/public/protocol_connection_client_factory.h b/api/public/protocol_connection_client_factory.h
new file mode 100644
index 0000000..199c7e2
--- /dev/null
+++ b/api/public/protocol_connection_client_factory.h
@@ -0,0 +1,23 @@
+// Copyright 2018 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef API_PUBLIC_PROTOCOL_CONNECTION_CLIENT_FACTORY_H_
+#define API_PUBLIC_PROTOCOL_CONNECTION_CLIENT_FACTORY_H_
+
+#include <memory>
+
+#include "api/public/protocol_connection_client.h"
+
+namespace openscreen {
+
+class ProtocolConnectionClientFactory {
+ public:
+ static std::unique_ptr<ProtocolConnectionClient> Create(
+ MessageDemuxer* demuxer,
+ ProtocolConnectionServiceObserver* observer);
+};
+
+} // namespace openscreen
+
+#endif // API_PUBLIC_PROTOCOL_CONNECTION_CLIENT_FACTORY_H_
diff --git a/api/public/protocol_connection_server.cc b/api/public/protocol_connection_server.cc
index c778cb2..fb3845c 100644
--- a/api/public/protocol_connection_server.cc
+++ b/api/public/protocol_connection_server.cc
@@ -6,10 +6,28 @@
namespace openscreen {
-ProtocolConnectionServer::ProtocolConnectionServer(const ServerConfig& config,
+ProtocolConnectionServer::ProtocolConnectionServer(MessageDemuxer* demuxer,
Observer* observer)
- : config_(config), observer_(observer) {}
+ : demuxer_(demuxer), observer_(observer) {}
ProtocolConnectionServer::~ProtocolConnectionServer() = default;
+std::ostream& operator<<(std::ostream& os,
+ ProtocolConnectionServer::State state) {
+ switch (state) {
+ case ProtocolConnectionServer::State::kStopped:
+ return os << "STOPPED";
+ case ProtocolConnectionServer::State::kStarting:
+ return os << "STARTING";
+ case ProtocolConnectionServer::State::kRunning:
+ return os << "RUNNING";
+ case ProtocolConnectionServer::State::kStopping:
+ return os << "STOPPING";
+ case ProtocolConnectionServer::State::kSuspended:
+ return os << "SUSPENDED";
+ default:
+ return os << "UNKNOWN";
+ }
+}
+
} // namespace openscreen
diff --git a/api/public/protocol_connection_server.h b/api/public/protocol_connection_server.h
index e0b20ae..9d91d8a 100644
--- a/api/public/protocol_connection_server.h
+++ b/api/public/protocol_connection_server.h
@@ -5,12 +5,16 @@
#ifndef API_PUBLIC_PROTOCOL_CONNECTION_SERVER_H_
#define API_PUBLIC_PROTOCOL_CONNECTION_SERVER_H_
+#include <memory>
+#include <ostream>
#include <string>
#include <vector>
+#include "api/public/message_demuxer.h"
#include "api/public/protocol_connection.h"
#include "api/public/server_config.h"
#include "base/error.h"
+#include "base/ip_address.h"
#include "base/macros.h"
namespace openscreen {
@@ -25,12 +29,15 @@
kSuspended,
};
- class Observer : public ProtocolConnectionObserver {
+ class Observer : public ProtocolConnectionServiceObserver {
public:
virtual ~Observer() = default;
// Called when the state becomes kSuspended.
virtual void OnSuspended() = 0;
+
+ virtual void OnIncomingConnection(
+ std::unique_ptr<ProtocolConnection>&& connection) = 0;
};
virtual ~ProtocolConnectionServer();
@@ -57,6 +64,14 @@
// connections.
virtual bool Resume() = 0;
+ virtual void RunTasks() = 0;
+
+ // Synchronously open a new connection to an endpoint identified by
+ // |endpoint_id|. Returns nullptr if it can't be completed synchronously
+ // (e.g. there are no existing open connections to that endpoint).
+ virtual std::unique_ptr<ProtocolConnection> CreateProtocolConnection(
+ uint64_t endpoint_id) = 0;
+
// Returns the current state of the listener.
State state() const { return state_; }
@@ -64,16 +79,20 @@
const Error& last_error() const { return last_error_; }
protected:
- ProtocolConnectionServer(const ServerConfig& config, Observer* observer);
+ explicit ProtocolConnectionServer(MessageDemuxer* demuxer,
+ Observer* observer);
- ServerConfig config_;
State state_ = State::kStopped;
Error last_error_;
+ MessageDemuxer* const demuxer_;
Observer* const observer_;
DISALLOW_COPY_AND_ASSIGN(ProtocolConnectionServer);
};
+std::ostream& operator<<(std::ostream& os,
+ ProtocolConnectionServer::State state);
+
} // namespace openscreen
#endif // API_PUBLIC_PROTOCOL_CONNECTION_SERVER_H_
diff --git a/api/public/protocol_connection_server_factory.h b/api/public/protocol_connection_server_factory.h
new file mode 100644
index 0000000..c932642
--- /dev/null
+++ b/api/public/protocol_connection_server_factory.h
@@ -0,0 +1,25 @@
+// Copyright 2018 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef API_PUBLIC_PROTOCOL_CONNECTION_SERVER_FACTORY_H_
+#define API_PUBLIC_PROTOCOL_CONNECTION_SERVER_FACTORY_H_
+
+#include <memory>
+
+#include "api/public/protocol_connection_server.h"
+#include "api/public/server_config.h"
+
+namespace openscreen {
+
+class ProtocolConnectionServerFactory {
+ public:
+ static std::unique_ptr<ProtocolConnectionServer> Create(
+ const ServerConfig& config,
+ MessageDemuxer* demuxer,
+ ProtocolConnectionServer::Observer* observer);
+};
+
+} // namespace openscreen
+
+#endif // API_PUBLIC_PROTOCOL_CONNECTION_SERVER_FACTORY_H_
diff --git a/base/error.h b/base/error.h
index 01c42c9..e72a3d8 100644
--- a/base/error.h
+++ b/base/error.h
@@ -22,6 +22,7 @@
kNone = 0,
// CBOR parsing error.
kCborParsing = 1,
+ kCborIncompleteMessage,
// Presentation start errors.
kNoAvailableScreens,
diff --git a/base/ip_address.cc b/base/ip_address.cc
index 2a5dad1..d39f9ed 100644
--- a/base/ip_address.cc
+++ b/base/ip_address.cc
@@ -206,6 +206,26 @@
return true;
}
+bool operator==(const IPEndpoint& a, const IPEndpoint& b) {
+ return (a.address == b.address) && (a.port == b.port);
+}
+
+bool IPEndpointComparator::operator()(const IPEndpoint& a,
+ const IPEndpoint& b) const {
+ if (a.address.version() != b.address.version())
+ return a.address.version() < b.address.version();
+ if (a.address.IsV4()) {
+ int ret = memcmp(a.address.bytes_.data(), b.address.bytes_.data(), 4);
+ if (ret != 0)
+ return ret < 0;
+ } else {
+ int ret = memcmp(a.address.bytes_.data(), b.address.bytes_.data(), 16);
+ if (ret != 0)
+ return ret < 0;
+ }
+ return a.port < b.port;
+}
+
std::ostream& operator<<(std::ostream& out, const IPAddress& address) {
uint8_t values[16];
size_t len = 0;
diff --git a/base/ip_address.h b/base/ip_address.h
index 1ee763c..fb67b07 100644
--- a/base/ip_address.h
+++ b/base/ip_address.h
@@ -76,6 +76,8 @@
static bool ParseV4(const std::string& s, IPAddress* address);
static bool ParseV6(const std::string& s, IPAddress* address);
+ friend class IPEndpointComparator;
+
Version version_;
std::array<uint8_t, 16> bytes_;
};
@@ -86,6 +88,13 @@
uint16_t port;
};
+bool operator==(const IPEndpoint& a, const IPEndpoint& b);
+
+class IPEndpointComparator {
+ public:
+ bool operator()(const IPEndpoint& a, const IPEndpoint& b) const;
+};
+
// Outputs a string of the form:
// 123.234.34.56
// or fe80:0000:0000:0000:1234:5678:9abc:def0
diff --git a/base/macros.h b/base/macros.h
index 108cc51..451e90f 100644
--- a/base/macros.h
+++ b/base/macros.h
@@ -11,5 +11,8 @@
#define DISALLOW_COPY_AND_ASSIGN(ClassName) \
DISALLOW_COPY(ClassName); \
DISALLOW_ASSIGN(ClassName)
+#define DISALLOW_IMPLICIT_CONSTRUCTORS(ClassName) \
+ ClassName() = delete; \
+ DISALLOW_COPY_AND_ASSIGN(ClassName)
#endif // BASE_MACROS_H_
diff --git a/demo/demo.cc b/demo/demo.cc
index 8449f84..104f597 100644
--- a/demo/demo.cc
+++ b/demo/demo.cc
@@ -8,12 +8,21 @@
#include <algorithm>
#include <vector>
-#include "api/impl/mdns_screen_listener_factory.h"
-#include "api/impl/mdns_screen_publisher_factory.h"
+#include "api/public/mdns_screen_listener_factory.h"
+#include "api/public/mdns_screen_publisher_factory.h"
+#include "api/public/message_demuxer.h"
#include "api/public/network_service_manager.h"
+#include "api/public/protocol_connection_client.h"
+#include "api/public/protocol_connection_client_factory.h"
+#include "api/public/protocol_connection_server.h"
+#include "api/public/protocol_connection_server_factory.h"
#include "api/public/screen_listener.h"
#include "api/public/screen_publisher.h"
+#include "base/make_unique.h"
+#include "msgs/osp_messages.h"
#include "platform/api/logging.h"
+#include "platform/api/network_interface.h"
+#include "third_party/tinycbor/src/src/cbor.h"
namespace openscreen {
namespace {
@@ -50,6 +59,38 @@
OSP_LOG_INFO << "pid: " << getpid();
}
+class AutoMessage final
+ : public ProtocolConnectionClient::ConnectionRequestCallback {
+ public:
+ ~AutoMessage() override = default;
+
+ void TakeRequest(ProtocolConnectionClient::ConnectRequest&& request) {
+ request_ = std::move(request);
+ }
+
+ void OnConnectionOpened(
+ uint64_t request_id,
+ std::unique_ptr<ProtocolConnection>&& connection) override {
+ request_ = ProtocolConnectionClient::ConnectRequest();
+ msgs::CborEncodeBuffer buffer;
+ msgs::PresentationConnectionMessage message;
+ message.connection_id = 0;
+ message.presentation_id = "presentation-id-foo";
+ message.message.which = decltype(message.message.which)::kString;
+ new (&message.message.str) std::string("message from client");
+ if (msgs::EncodePresentationConnectionMessage(message, &buffer))
+ connection->Write(buffer.data(), buffer.size());
+ connection->CloseWriteEnd();
+ }
+
+ void OnConnectionFailed(uint64_t request_id) override {
+ request_ = ProtocolConnectionClient::ConnectRequest();
+ }
+
+ private:
+ ProtocolConnectionClient::ConnectRequest request_;
+};
+
class ListenerObserver final : public ScreenListener::Observer {
public:
~ListenerObserver() override = default;
@@ -60,6 +101,12 @@
void OnScreenAdded(const ScreenInfo& info) override {
OSP_LOG_INFO << "found! " << info.friendly_name;
+ if (!auto_message_) {
+ auto_message_ = MakeUnique<AutoMessage>();
+ auto_message_->TakeRequest(
+ NetworkServiceManager::Get()->GetProtocolConnectionClient()->Connect(
+ info.endpoint, auto_message_.get()));
+ }
}
void OnScreenChanged(const ScreenInfo& info) override {
OSP_LOG_INFO << "changed! " << info.friendly_name;
@@ -70,6 +117,9 @@
void OnAllScreensRemoved() override { OSP_LOG_INFO << "all removed!"; }
void OnError(ScreenListenerError) override {}
void OnMetrics(ScreenListener::Metrics) override {}
+
+ private:
+ std::unique_ptr<AutoMessage> auto_message_;
};
class PublisherObserver final : public ScreenPublisher::Observer {
@@ -84,6 +134,93 @@
void OnMetrics(ScreenPublisher::Metrics) override {}
};
+class ConnectionClientObserver final
+ : public ProtocolConnectionServiceObserver {
+ public:
+ ~ConnectionClientObserver() override = default;
+ void OnRunning() override {}
+ void OnStopped() override {}
+
+ void OnMetrics(const NetworkMetrics& metrics) override {}
+ void OnError(const Error& error) override {}
+};
+
+class ConnectionServerObserver final
+ : public ProtocolConnectionServer::Observer {
+ public:
+ class ConnectionObserver final : public ProtocolConnection::Observer {
+ public:
+ explicit ConnectionObserver(ConnectionServerObserver* parent)
+ : parent_(parent) {}
+ ~ConnectionObserver() override = default;
+
+ void OnConnectionChanged(const ProtocolConnection& connection) override {}
+
+ void OnConnectionClosed(const ProtocolConnection& connection) override {
+ auto& connections = parent_->connections_;
+ connections.erase(
+ std::remove_if(
+ connections.begin(), connections.end(),
+ [this](const std::pair<std::unique_ptr<ConnectionObserver>,
+ std::unique_ptr<ProtocolConnection>>& p) {
+ return p.first.get() == this;
+ }),
+ connections.end());
+ }
+
+ private:
+ ConnectionServerObserver* const parent_;
+ };
+
+ ~ConnectionServerObserver() override = default;
+
+ void OnRunning() override {}
+ void OnStopped() override {}
+ void OnSuspended() override {}
+
+ void OnMetrics(const NetworkMetrics& metrics) override {}
+ void OnError(const Error& error) override {}
+
+ void OnIncomingConnection(
+ std::unique_ptr<ProtocolConnection>&& connection) override {
+ auto observer = MakeUnique<ConnectionObserver>(this);
+ connection->SetObserver(observer.get());
+ connections_.emplace_back(std::move(observer), std::move(connection));
+ connections_.back().second->CloseWriteEnd();
+ }
+
+ private:
+ std::vector<std::pair<std::unique_ptr<ConnectionObserver>,
+ std::unique_ptr<ProtocolConnection>>>
+ connections_;
+};
+
+class ConnectionMessageCallback final : public MessageDemuxer::MessageCallback {
+ public:
+ ~ConnectionMessageCallback() override = default;
+
+ ErrorOr<size_t> OnStreamMessage(uint64_t endpoint_id,
+ uint64_t connection_id,
+ msgs::Type message_type,
+ const uint8_t* buffer,
+ size_t buffer_size) override {
+ msgs::PresentationConnectionMessage message;
+ ssize_t result = msgs::DecodePresentationConnectionMessage(
+ buffer, buffer_size, &message);
+ if (result < 0) {
+ // TODO(btolsch): Need something better than including tinycbor.
+ if (result == -CborErrorUnexpectedEOF) {
+ return Error::Code::kCborIncompleteMessage;
+ } else {
+ return Error::Code::kCborParsing;
+ }
+ } else {
+ OSP_LOG_INFO << "message: " << message.message.str;
+ return result;
+ }
+ }
+};
+
void ListenerDemo() {
SignalThings();
@@ -92,16 +229,23 @@
auto mdns_listener =
MdnsScreenListenerFactory::Create(listener_config, &listener_observer);
+ MessageDemuxer demuxer;
+ ConnectionClientObserver client_observer;
+ auto connection_client =
+ ProtocolConnectionClientFactory::Create(&demuxer, &client_observer);
+
auto* network_service = NetworkServiceManager::Create(
- std::move(mdns_listener), nullptr, nullptr, nullptr);
+ std::move(mdns_listener), nullptr, std::move(connection_client), nullptr);
network_service->GetMdnsScreenListener()->Start();
+ network_service->GetProtocolConnectionClient()->Start();
while (!g_done) {
network_service->RunEventLoopOnce();
}
network_service->GetMdnsScreenListener()->Stop();
+ network_service->GetProtocolConnectionClient()->Stop();
NetworkServiceManager::Dispose();
}
@@ -118,16 +262,34 @@
publisher_config.connection_server_port = 6667;
auto mdns_publisher =
MdnsScreenPublisherFactory::Create(publisher_config, &publisher_observer);
- auto* network_service = NetworkServiceManager::Create(
- nullptr, std::move(mdns_publisher), nullptr, nullptr);
+
+ ServerConfig server_config;
+ std::vector<platform::InterfaceAddresses> interfaces =
+ platform::GetInterfaceAddresses();
+ for (const auto& interface : interfaces) {
+ server_config.connection_endpoints.push_back(
+ IPEndpoint{interface.addresses[0].address, 6667});
+ }
+ MessageDemuxer demuxer;
+ ConnectionMessageCallback message_callback;
+ MessageDemuxer::MessageWatch message_watch = demuxer.WatchMessageType(
+ 0, msgs::Type::kPresentationConnectionMessage, &message_callback);
+ ConnectionServerObserver server_observer;
+ auto connection_server = ProtocolConnectionServerFactory::Create(
+ server_config, &demuxer, &server_observer);
+ auto* network_service =
+ NetworkServiceManager::Create(nullptr, std::move(mdns_publisher), nullptr,
+ std::move(connection_server));
network_service->GetMdnsScreenPublisher()->Start();
+ network_service->GetProtocolConnectionServer()->Start();
while (!g_done) {
network_service->RunEventLoopOnce();
}
network_service->GetMdnsScreenPublisher()->Stop();
+ network_service->GetProtocolConnectionServer()->Stop();
NetworkServiceManager::Dispose();
}
@@ -136,6 +298,7 @@
} // namespace openscreen
int main(int argc, char** argv) {
+ openscreen::platform::LogInit(nullptr);
openscreen::platform::SetLogLevel(openscreen::platform::LogLevel::kVerbose,
1);
if (argc == 1) {
diff --git a/discovery/mdns/embedder_demo.cc b/discovery/mdns/embedder_demo.cc
index 174878f..bd7bdc0 100644
--- a/discovery/mdns/embedder_demo.cc
+++ b/discovery/mdns/embedder_demo.cc
@@ -107,8 +107,7 @@
DestroyUdpSocket(socket);
continue;
}
- if (!BindUdpSocket(socket, IPEndpoint{IPAddress{0, 0, 0, 0}, 5353},
- ifindex)) {
+ if (!BindUdpSocket(socket, {{}, 5353}, ifindex)) {
OSP_LOG_ERROR << "bind failed for interface " << ifindex << ": "
<< platform::GetLastErrorString();
DestroyUdpSocket(socket);
@@ -314,6 +313,7 @@
} // namespace openscreen
int main(int argc, char** argv) {
+ openscreen::platform::LogInit(nullptr);
openscreen::platform::SetLogLevel(openscreen::platform::LogLevel::kVerbose,
0);
std::string service_instance;
diff --git a/platform/api/logging.h b/platform/api/logging.h
index 7cec606..c323810 100644
--- a/platform/api/logging.h
+++ b/platform/api/logging.h
@@ -22,7 +22,8 @@
//
// PLATFORM IMPLEMENTATION
-// The follow functions must be implemented by the platform.
+// The following functions must be implemented by the platform.
+void LogInit(const char* filename);
void SetLogLevel(LogLevel level, int verbose_level = 0);
void LogWithLevel(LogLevel level,
int verbose_level,
diff --git a/platform/base/logging.cc b/platform/base/logging.cc
index 56cc4a8..8c203bf 100644
--- a/platform/base/logging.cc
+++ b/platform/base/logging.cc
@@ -2,8 +2,11 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
+#include <unistd.h>
+
#include <cstdlib>
#include <iostream>
+#include <sstream>
#include "platform/api/logging.h"
@@ -39,6 +42,8 @@
} // namespace
+int g_log_fd;
+
void SetLogLevel(LogLevel level, int verbose_level) {
g_log_level = CombinedLogLevel{level, verbose_level};
}
@@ -51,8 +56,10 @@
if (CombinedLogLevel{level, verbose_level} < g_log_level)
return;
- std::cout << "[" << CombinedLogLevel{level, verbose_level} << ":" << file
- << ":" << line << "] " << msg << std::endl;
+ std::stringstream ss;
+ ss << "[" << CombinedLogLevel{level, verbose_level} << ":" << file << ":"
+ << line << "] " << msg << std::endl;
+ write(g_log_fd, ss.str().c_str(), ss.str().size());
}
void Break() {
diff --git a/platform/posix/BUILD.gn b/platform/posix/BUILD.gn
index 1159b30..030b2db 100644
--- a/platform/posix/BUILD.gn
+++ b/platform/posix/BUILD.gn
@@ -13,6 +13,7 @@
sources = [
"error.cc",
"event_waiter.cc",
+ "logging.cc",
"socket.cc",
"socket.h",
]
diff --git a/platform/posix/logging.cc b/platform/posix/logging.cc
new file mode 100644
index 0000000..861ca92
--- /dev/null
+++ b/platform/posix/logging.cc
@@ -0,0 +1,32 @@
+// Copyright 2018 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+namespace openscreen {
+namespace platform {
+
+extern int g_log_fd;
+
+void LogInit(const char* filename) {
+ struct stat st = {};
+ if (stat(filename, &st) == -1 && errno == ENOENT) {
+ if (mkfifo(filename, 0644) == 0) {
+ g_log_fd = open(filename, O_WRONLY);
+ } else {
+ g_log_fd = STDOUT_FILENO;
+ }
+ } else if (S_ISFIFO(st.st_mode)) {
+ g_log_fd = open(filename, O_WRONLY);
+ } else {
+ g_log_fd = STDOUT_FILENO;
+ }
+}
+
+} // namespace platform
+} // namespace openscreen
diff --git a/platform/posix/socket.cc b/platform/posix/socket.cc
index c56ce3d..7be9679 100644
--- a/platform/posix/socket.cc
+++ b/platform/posix/socket.cc
@@ -101,7 +101,8 @@
struct sockaddr_in address;
address.sin_family = AF_INET;
address.sin_port = htons(endpoint.port);
- address.sin_addr.s_addr = INADDR_ANY;
+ endpoint.address.CopyToV4(
+ reinterpret_cast<uint8_t*>(&address.sin_addr.s_addr));
return bind(socket->fd, reinterpret_cast<struct sockaddr*>(&address),
sizeof(address)) != -1;
} else {
@@ -124,7 +125,7 @@
address.sin6_family = AF_INET6;
address.sin6_flowinfo = 0;
address.sin6_port = htons(endpoint.port);
- address.sin6_addr = IN6ADDR_ANY_INIT;
+ endpoint.address.CopyToV6(reinterpret_cast<uint8_t*>(&address.sin6_addr));
address.sin6_scope_id = 0;
return bind(socket->fd, reinterpret_cast<struct sockaddr*>(&address),
sizeof(address)) != -1;