| // Copyright 2020 The Chromium Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include <openssl/evp.h> |
| #include <openssl/mem.h> |
| |
| #include <atomic> |
| #include <chrono> |
| |
| #include "cast/common/certificate/cast_trust_store.h" |
| #include "cast/common/certificate/testing/test_helpers.h" |
| #include "cast/common/channel/connection_namespace_handler.h" |
| #include "cast/common/channel/message_util.h" |
| #include "cast/common/channel/virtual_connection_router.h" |
| #include "cast/common/public/cast_socket.h" |
| #include "cast/receiver/channel/device_auth_namespace_handler.h" |
| #include "cast/receiver/channel/static_credentials.h" |
| #include "cast/receiver/public/receiver_socket_factory.h" |
| #include "cast/sender/public/sender_socket_factory.h" |
| #include "gmock/gmock.h" |
| #include "gtest/gtest.h" |
| #include "platform/api/serial_delete_ptr.h" |
| #include "platform/api/tls_connection_factory.h" |
| #include "platform/base/tls_connect_options.h" |
| #include "platform/base/tls_credentials.h" |
| #include "platform/base/tls_listen_options.h" |
| #include "platform/impl/logging.h" |
| #include "platform/impl/network_interface.h" |
| #include "platform/impl/platform_client_posix.h" |
| #include "testing/util/task_util.h" |
| #include "util/crypto/certificate_utils.h" |
| #include "util/osp_logging.h" |
| |
| namespace openscreen { |
| namespace cast { |
| namespace { |
| |
| using ::testing::_; |
| using ::testing::StrictMock; |
| |
| constexpr char kLogDecorator[] = "--- "; |
| |
| } // namespace |
| |
| class SenderSocketsClient : public SenderSocketFactory::Client, |
| public VirtualConnectionRouter::SocketErrorHandler { |
| public: |
| explicit SenderSocketsClient(VirtualConnectionRouter* router) // NOLINT |
| : router_(router) {} |
| virtual ~SenderSocketsClient() = default; |
| |
| CastSocket* socket() const { return socket_; } |
| |
| // SenderSocketFactory::Client overrides. |
| void OnConnected(SenderSocketFactory* factory, |
| const IPEndpoint& endpoint, |
| std::unique_ptr<CastSocket> socket) { |
| OSP_CHECK(!socket_); |
| OSP_LOG_INFO << kLogDecorator |
| << "Sender connected to endpoint: " << endpoint; |
| socket_ = socket.get(); |
| router_->TakeSocket(this, std::move(socket)); |
| } |
| |
| void OnError(SenderSocketFactory* factory, |
| const IPEndpoint& endpoint, |
| Error error) override { |
| OSP_LOG_FATAL << error; |
| } |
| |
| // VirtualConnectionRouter::SocketErrorHandler overrides. |
| void OnClose(CastSocket* socket) override { |
| socket_ = nullptr; |
| OnCloseMock(socket); |
| } |
| void OnError(CastSocket* socket, Error error) override { |
| socket_ = nullptr; |
| OnErrorMock(socket, std::move(error)); |
| } |
| |
| MOCK_METHOD(void, OnCloseMock, (CastSocket * socket), ()); |
| MOCK_METHOD(void, OnErrorMock, (CastSocket * socket, Error error), ()); |
| |
| private: |
| VirtualConnectionRouter* const router_; |
| std::atomic<CastSocket*> socket_{nullptr}; |
| }; |
| |
| class ReceiverSocketsClient |
| : public ReceiverSocketFactory::Client, |
| public VirtualConnectionRouter::SocketErrorHandler { |
| public: |
| explicit ReceiverSocketsClient(VirtualConnectionRouter* router) |
| : router_(router) {} |
| virtual ~ReceiverSocketsClient() = default; |
| |
| const IPEndpoint& endpoint() const { return endpoint_; } |
| CastSocket* socket() const { return socket_; } |
| |
| // ReceiverSocketFactory::Client overrides. |
| void OnConnected(ReceiverSocketFactory* factory, |
| const IPEndpoint& endpoint, |
| std::unique_ptr<CastSocket> socket) override { |
| OSP_CHECK(!socket_); |
| OSP_LOG_INFO << kLogDecorator |
| << "Receiver got connection from endpoint: " << endpoint; |
| endpoint_ = endpoint; |
| socket_ = socket.get(); |
| router_->TakeSocket(this, std::move(socket)); |
| } |
| |
| void OnError(ReceiverSocketFactory* factory, Error error) override { |
| OSP_LOG_FATAL << error; |
| } |
| |
| // VirtualConnectionRouter::SocketErrorHandler overrides. |
| void OnClose(CastSocket* socket) override { |
| socket_ = nullptr; |
| OnCloseMock(socket); |
| } |
| void OnError(CastSocket* socket, Error error) override { |
| socket_ = nullptr; |
| OnErrorMock(socket, std::move(error)); |
| } |
| |
| MOCK_METHOD(void, OnCloseMock, (CastSocket * socket), ()); |
| MOCK_METHOD(void, OnErrorMock, (CastSocket * socket, Error error), ()); |
| |
| private: |
| VirtualConnectionRouter* router_; |
| IPEndpoint endpoint_; |
| std::atomic<CastSocket*> socket_{nullptr}; |
| }; |
| |
| class CastSocketE2ETest : public ::testing::Test { |
| public: |
| void SetUp() override { |
| PlatformClientPosix::Create(std::chrono::milliseconds(10), |
| std::chrono::milliseconds(0)); |
| task_runner_ = PlatformClientPosix::GetInstance()->GetTaskRunner(); |
| |
| sender_router_ = MakeSerialDelete<VirtualConnectionRouter>(task_runner_); |
| sender_client_ = |
| std::make_unique<StrictMock<SenderSocketsClient>>(sender_router_.get()); |
| sender_factory_ = MakeSerialDelete<SenderSocketFactory>( |
| task_runner_, sender_client_.get(), task_runner_); |
| sender_tls_factory_ = SerialDeletePtr<TlsConnectionFactory>( |
| task_runner_, |
| TlsConnectionFactory::CreateFactory(sender_factory_.get(), task_runner_) |
| .release()); |
| sender_factory_->set_factory(sender_tls_factory_.get()); |
| |
| ErrorOr<GeneratedCredentials> creds = |
| GenerateCredentialsForTesting("Device ID"); |
| ASSERT_TRUE(creds.is_value()); |
| credentials_ = std::move(creds.value()); |
| |
| CastTrustStore::CreateInstanceForTest(credentials_.root_cert_der); |
| auth_handler_ = MakeSerialDelete<DeviceAuthNamespaceHandler>( |
| task_runner_, credentials_.provider.get()); |
| receiver_router_ = MakeSerialDelete<VirtualConnectionRouter>(task_runner_); |
| receiver_router_->AddHandlerForLocalId(kPlatformReceiverId, |
| auth_handler_.get()); |
| receiver_client_ = std::make_unique<StrictMock<ReceiverSocketsClient>>( |
| receiver_router_.get()); |
| receiver_factory_ = MakeSerialDelete<ReceiverSocketFactory>( |
| task_runner_, receiver_client_.get(), receiver_router_.get()); |
| |
| receiver_tls_factory_ = SerialDeletePtr<TlsConnectionFactory>( |
| task_runner_, TlsConnectionFactory::CreateFactory( |
| receiver_factory_.get(), task_runner_) |
| .release()); |
| } |
| |
| void TearDown() override { |
| OSP_LOG_INFO << "Shutting down"; |
| sender_router_.reset(); |
| receiver_router_.reset(); |
| receiver_tls_factory_.reset(); |
| receiver_factory_.reset(); |
| auth_handler_.reset(); |
| sender_tls_factory_.reset(); |
| sender_factory_.reset(); |
| CastTrustStore::ResetInstance(); |
| PlatformClientPosix::ShutDown(); |
| } |
| |
| protected: |
| IPAddress GetLoopbackV4Address() { |
| absl::optional<InterfaceInfo> loopback = GetLoopbackInterfaceForTesting(); |
| OSP_CHECK(loopback); |
| IPAddress address = loopback->GetIpAddressV4(); |
| OSP_CHECK(address); |
| return address; |
| } |
| |
| IPAddress GetLoopbackV6Address() { |
| absl::optional<InterfaceInfo> loopback = GetLoopbackInterfaceForTesting(); |
| OSP_CHECK(loopback); |
| IPAddress address = loopback->GetIpAddressV6(); |
| return address; |
| } |
| |
| void Connect(const IPAddress& address) { |
| uint16_t port = 65321; |
| OSP_LOG_INFO << kLogDecorator << "Starting socket factories"; |
| task_runner_->PostTask([this, &address, port]() { |
| OSP_LOG_INFO << kLogDecorator << "Receiver TLS factory Listen()"; |
| receiver_tls_factory_->SetListenCredentials(credentials_.tls_credentials); |
| receiver_tls_factory_->Listen(IPEndpoint{address, port}, |
| TlsListenOptions{1u}); |
| }); |
| |
| task_runner_->PostTask([this, &address, port]() { |
| OSP_LOG_INFO << kLogDecorator << "Sender CastSocket factory Connect()"; |
| sender_factory_->Connect(IPEndpoint{address, port}, |
| SenderSocketFactory::DeviceMediaPolicy::kNone, |
| sender_router_.get()); |
| }); |
| |
| WaitForCondition([this]() { return sender_client_->socket(); }); |
| } |
| |
| void ConnectSocketsV4() { |
| OSP_LOG_INFO << "Getting loopback IPv4 address"; |
| IPAddress loopback_address = GetLoopbackV4Address(); |
| OSP_LOG_INFO << "Connecting CastSockets"; |
| Connect(loopback_address); |
| } |
| |
| template <typename SocketClient, typename PeerSocketClient> |
| void CloseSocketsFromOneEnd(VirtualConnectionRouter* router, |
| SocketClient* client, |
| PeerSocketClient* peer_client) { |
| // TODO(issuetracker.google.com/169967989): Would like to have a symmetric |
| // OnClose check. |
| EXPECT_CALL(*client, OnCloseMock(client->socket())); |
| EXPECT_CALL(*peer_client, OnErrorMock(peer_client->socket(), _)) |
| .WillOnce([](CastSocket* socket, Error error) { |
| EXPECT_EQ(error.code(), Error::Code::kSocketClosedFailure); |
| }); |
| int32_t id = client->socket()->socket_id(); |
| std::atomic_bool did_run{false}; |
| task_runner_->PostTask([id, router, &did_run]() { |
| router->CloseSocket(id); |
| did_run = true; |
| }); |
| OSP_LOG_INFO << "Waiting for socket to close"; |
| WaitForCondition([&did_run]() { return did_run.load(); }); |
| EXPECT_FALSE(sender_client_->socket()); |
| EXPECT_FALSE(receiver_client_->socket()); |
| } |
| |
| TaskRunner* task_runner_; |
| |
| // NOTE: Sender components. |
| SerialDeletePtr<VirtualConnectionRouter> sender_router_; |
| std::unique_ptr<StrictMock<SenderSocketsClient>> sender_client_; |
| SerialDeletePtr<SenderSocketFactory> sender_factory_; |
| SerialDeletePtr<TlsConnectionFactory> sender_tls_factory_; |
| |
| // NOTE: Receiver components. |
| SerialDeletePtr<VirtualConnectionRouter> receiver_router_; |
| GeneratedCredentials credentials_; |
| SerialDeletePtr<DeviceAuthNamespaceHandler> auth_handler_; |
| std::unique_ptr<StrictMock<ReceiverSocketsClient>> receiver_client_; |
| SerialDeletePtr<ReceiverSocketFactory> receiver_factory_; |
| SerialDeletePtr<TlsConnectionFactory> receiver_tls_factory_; |
| }; |
| |
// These tests exercise the most basic setup of a complete CastSocket:
// constructing both a SenderSocketFactory and a ReceiverSocketFactory, making a
// TLS connection to a known port over the loopback device, and checking device
// authentication.
| TEST_F(CastSocketE2ETest, ConnectV4) { |
| ConnectSocketsV4(); |
| } |
| |
| TEST_F(CastSocketE2ETest, ConnectV6) { |
| OSP_LOG_INFO << "Getting loopback IPv6 address"; |
| IPAddress loopback_address = GetLoopbackV6Address(); |
| if (loopback_address) { |
| OSP_LOG_INFO << "Connecting CastSockets"; |
| Connect(loopback_address); |
| } else { |
| OSP_LOG_WARN << "Test skipped due to missing IPv6 loopback address"; |
| } |
| } |
| |
| TEST_F(CastSocketE2ETest, SenderClose) { |
| ConnectSocketsV4(); |
| |
| CloseSocketsFromOneEnd(sender_router_.get(), sender_client_.get(), |
| receiver_client_.get()); |
| } |
| |
| TEST_F(CastSocketE2ETest, ReceiverClose) { |
| ConnectSocketsV4(); |
| |
| CloseSocketsFromOneEnd(receiver_router_.get(), receiver_client_.get(), |
| sender_client_.get()); |
| } |
| |
| } // namespace cast |
| } // namespace openscreen |