blob: 446a880c82a5ff0ef9318d4647bd2489db2fd72f [file] [log] [blame]
// Copyright 2020 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include <openssl/evp.h>
#include <openssl/mem.h>
#include <atomic>
#include <chrono>
#include "cast/common/certificate/cast_trust_store.h"
#include "cast/common/certificate/testing/test_helpers.h"
#include "cast/common/channel/connection_namespace_handler.h"
#include "cast/common/channel/message_util.h"
#include "cast/common/channel/virtual_connection_router.h"
#include "cast/common/public/cast_socket.h"
#include "cast/receiver/channel/device_auth_namespace_handler.h"
#include "cast/receiver/channel/static_credentials.h"
#include "cast/receiver/public/receiver_socket_factory.h"
#include "cast/sender/public/sender_socket_factory.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "platform/api/serial_delete_ptr.h"
#include "platform/api/tls_connection_factory.h"
#include "platform/base/tls_connect_options.h"
#include "platform/base/tls_credentials.h"
#include "platform/base/tls_listen_options.h"
#include "platform/impl/logging.h"
#include "platform/impl/network_interface.h"
#include "platform/impl/platform_client_posix.h"
#include "testing/util/task_util.h"
#include "util/crypto/certificate_utils.h"
#include "util/osp_logging.h"
namespace openscreen {
namespace cast {
namespace {

using ::testing::StrictMock;
using ::testing::_;

// Prefix prepended to the progress messages this test file logs.
constexpr char kLogDecorator[] = "--- ";

}  // namespace
// Sender-side test client. Receives the CastSocket produced by a
// SenderSocketFactory, hands ownership of it to |router_|, and forwards the
// router's socket close/error notifications to mockable methods so tests can
// set expectations on them.
class SenderSocketsClient : public SenderSocketFactory::Client,
                            public VirtualConnectionRouter::SocketErrorHandler {
 public:
  explicit SenderSocketsClient(VirtualConnectionRouter* router)
      : router_(router) {}
  virtual ~SenderSocketsClient() = default;

  // Returns the currently connected socket, or nullptr if none has connected
  // yet or it has since been closed/errored. |socket_| is atomic, presumably
  // because it is written from task-runner callbacks and read by the test
  // body — NOTE(review): confirm threading model.
  CastSocket* socket() const { return socket_; }

  // SenderSocketFactory::Client overrides.
  void OnConnected(SenderSocketFactory* factory,
                   const IPEndpoint& endpoint,
                   std::unique_ptr<CastSocket> socket) override {
    // Each test expects exactly one sender connection.
    OSP_CHECK(!socket_);
    OSP_LOG_INFO << kLogDecorator
                 << "Sender connected to endpoint: " << endpoint;
    // Record the raw pointer before ownership moves into the router.
    socket_ = socket.get();
    router_->TakeSocket(this, std::move(socket));
  }

  void OnError(SenderSocketFactory* factory,
               const IPEndpoint& endpoint,
               Error error) override {
    // A factory-level connection failure is treated as fatal for these tests.
    OSP_LOG_FATAL << error;
  }

  // VirtualConnectionRouter::SocketErrorHandler overrides.
  void OnClose(CastSocket* socket) override {
    socket_ = nullptr;
    OnCloseMock(socket);
  }

  void OnError(CastSocket* socket, Error error) override {
    socket_ = nullptr;
    OnErrorMock(socket, std::move(error));
  }

  // Mock hooks invoked by OnClose()/OnError() above so tests can use
  // EXPECT_CALL on socket shutdown.
  MOCK_METHOD(void, OnCloseMock, (CastSocket * socket), ());
  MOCK_METHOD(void, OnErrorMock, (CastSocket * socket, Error error), ());

 private:
  VirtualConnectionRouter* const router_;
  std::atomic<CastSocket*> socket_{nullptr};
};
// Receiver-side test client. Receives each CastSocket accepted by a
// ReceiverSocketFactory, remembers the peer endpoint, hands ownership of the
// socket to |router_|, and forwards the router's socket close/error
// notifications to mockable methods so tests can set expectations on them.
class ReceiverSocketsClient
    : public ReceiverSocketFactory::Client,
      public VirtualConnectionRouter::SocketErrorHandler {
 public:
  explicit ReceiverSocketsClient(VirtualConnectionRouter* router)
      : router_(router) {}
  virtual ~ReceiverSocketsClient() = default;

  // Endpoint of the most recently accepted peer. NOTE(review): |endpoint_| is
  // not atomic, unlike |socket_|; only read it after the connection has been
  // observed via socket().
  const IPEndpoint& endpoint() const { return endpoint_; }

  // Returns the currently connected socket, or nullptr if none has connected
  // yet or it has since been closed/errored.
  CastSocket* socket() const { return socket_; }

  // ReceiverSocketFactory::Client overrides.
  void OnConnected(ReceiverSocketFactory* factory,
                   const IPEndpoint& endpoint,
                   std::unique_ptr<CastSocket> socket) override {
    // Each test expects exactly one inbound connection.
    OSP_CHECK(!socket_);
    OSP_LOG_INFO << kLogDecorator
                 << "Receiver got connection from endpoint: " << endpoint;
    endpoint_ = endpoint;
    // Record the raw pointer before ownership moves into the router.
    socket_ = socket.get();
    router_->TakeSocket(this, std::move(socket));
  }

  void OnError(ReceiverSocketFactory* factory, Error error) override {
    // A factory-level failure is treated as fatal for these tests.
    OSP_LOG_FATAL << error;
  }

  // VirtualConnectionRouter::SocketErrorHandler overrides.
  void OnClose(CastSocket* socket) override {
    socket_ = nullptr;
    OnCloseMock(socket);
  }

  void OnError(CastSocket* socket, Error error) override {
    socket_ = nullptr;
    OnErrorMock(socket, std::move(error));
  }

  // Mock hooks invoked by OnClose()/OnError() above so tests can use
  // EXPECT_CALL on socket shutdown.
  MOCK_METHOD(void, OnCloseMock, (CastSocket * socket), ());
  MOCK_METHOD(void, OnErrorMock, (CastSocket * socket, Error error), ());

 private:
  // Never reseated; const for consistency with SenderSocketsClient.
  VirtualConnectionRouter* const router_;
  IPEndpoint endpoint_;
  std::atomic<CastSocket*> socket_{nullptr};
};
class CastSocketE2ETest : public ::testing::Test {
public:
void SetUp() override {
PlatformClientPosix::Create(std::chrono::milliseconds(10));
task_runner_ = PlatformClientPosix::GetInstance()->GetTaskRunner();
sender_router_ = MakeSerialDelete<VirtualConnectionRouter>(task_runner_);
sender_client_ =
std::make_unique<StrictMock<SenderSocketsClient>>(sender_router_.get());
sender_factory_ = MakeSerialDelete<SenderSocketFactory>(
task_runner_, sender_client_.get(), task_runner_);
sender_tls_factory_ = SerialDeletePtr<TlsConnectionFactory>(
task_runner_,
TlsConnectionFactory::CreateFactory(sender_factory_.get(), task_runner_)
.release());
sender_factory_->set_factory(sender_tls_factory_.get());
ErrorOr<GeneratedCredentials> creds =
GenerateCredentialsForTesting("Device ID");
ASSERT_TRUE(creds.is_value());
credentials_ = std::move(creds.value());
CastTrustStore::CreateInstanceForTest(credentials_.root_cert_der);
auth_handler_ = MakeSerialDelete<DeviceAuthNamespaceHandler>(
task_runner_, credentials_.provider.get());
receiver_router_ = MakeSerialDelete<VirtualConnectionRouter>(task_runner_);
receiver_router_->AddHandlerForLocalId(kPlatformReceiverId,
auth_handler_.get());
receiver_client_ = std::make_unique<StrictMock<ReceiverSocketsClient>>(
receiver_router_.get());
receiver_factory_ = MakeSerialDelete<ReceiverSocketFactory>(
task_runner_, receiver_client_.get(), receiver_router_.get());
receiver_tls_factory_ = SerialDeletePtr<TlsConnectionFactory>(
task_runner_, TlsConnectionFactory::CreateFactory(
receiver_factory_.get(), task_runner_)
.release());
}
void TearDown() override {
OSP_LOG_INFO << "Shutting down";
sender_router_.reset();
receiver_router_.reset();
receiver_tls_factory_.reset();
receiver_factory_.reset();
auth_handler_.reset();
sender_tls_factory_.reset();
sender_factory_.reset();
CastTrustStore::ResetInstance();
PlatformClientPosix::ShutDown();
}
protected:
IPAddress GetLoopbackV4Address() {
absl::optional<InterfaceInfo> loopback = GetLoopbackInterfaceForTesting();
OSP_CHECK(loopback);
IPAddress address = loopback->GetIpAddressV4();
OSP_CHECK(address);
return address;
}
IPAddress GetLoopbackV6Address() {
absl::optional<InterfaceInfo> loopback = GetLoopbackInterfaceForTesting();
OSP_CHECK(loopback);
IPAddress address = loopback->GetIpAddressV6();
return address;
}
void Connect(const IPAddress& address) {
uint16_t port = 65321;
OSP_LOG_INFO << kLogDecorator << "Starting socket factories";
task_runner_->PostTask([this, &address, port]() {
OSP_LOG_INFO << kLogDecorator << "Receiver TLS factory Listen()";
receiver_tls_factory_->SetListenCredentials(credentials_.tls_credentials);
receiver_tls_factory_->Listen(IPEndpoint{address, port},
TlsListenOptions{1u});
});
task_runner_->PostTask([this, &address, port]() {
OSP_LOG_INFO << kLogDecorator << "Sender CastSocket factory Connect()";
sender_factory_->Connect(IPEndpoint{address, port},
SenderSocketFactory::DeviceMediaPolicy::kNone,
sender_router_.get());
});
WaitForCondition([this]() { return sender_client_->socket(); });
}
void ConnectSocketsV4() {
OSP_LOG_INFO << "Getting loopback IPv4 address";
IPAddress loopback_address = GetLoopbackV4Address();
OSP_LOG_INFO << "Connecting CastSockets";
Connect(loopback_address);
}
template <typename SocketClient, typename PeerSocketClient>
void CloseSocketsFromOneEnd(VirtualConnectionRouter* router,
SocketClient* client,
PeerSocketClient* peer_client) {
// TODO(issuetracker.google.com/169967989): Would like to have a symmetric
// OnClose check.
EXPECT_CALL(*client, OnCloseMock(client->socket()));
EXPECT_CALL(*peer_client, OnErrorMock(peer_client->socket(), _))
.WillOnce([](CastSocket* socket, Error error) {
EXPECT_EQ(error.code(), Error::Code::kSocketClosedFailure);
});
int32_t id = client->socket()->socket_id();
std::atomic_bool did_run{false};
task_runner_->PostTask([id, router, &did_run]() {
router->CloseSocket(id);
did_run = true;
});
OSP_LOG_INFO << "Waiting for socket to close";
WaitForCondition([&did_run]() { return did_run.load(); });
EXPECT_FALSE(sender_client_->socket());
EXPECT_FALSE(receiver_client_->socket());
}
TaskRunner* task_runner_;
// NOTE: Sender components.
SerialDeletePtr<VirtualConnectionRouter> sender_router_;
std::unique_ptr<StrictMock<SenderSocketsClient>> sender_client_;
SerialDeletePtr<SenderSocketFactory> sender_factory_;
SerialDeletePtr<TlsConnectionFactory> sender_tls_factory_;
// NOTE: Receiver components.
SerialDeletePtr<VirtualConnectionRouter> receiver_router_;
GeneratedCredentials credentials_;
SerialDeletePtr<DeviceAuthNamespaceHandler> auth_handler_;
std::unique_ptr<StrictMock<ReceiverSocketsClient>> receiver_client_;
SerialDeletePtr<ReceiverSocketFactory> receiver_factory_;
SerialDeletePtr<TlsConnectionFactory> receiver_tls_factory_;
};
// These tests exercise the most basic setup of a complete CastSocket: they
// construct both a SenderSocketFactory and a ReceiverSocketFactory, make a TLS
// connection to a known port over the loopback device, and check device
// authentication.
TEST_F(CastSocketE2ETest, ConnectV4) {
  // Establishing the IPv4 loopback connection is the entire test; the fixture
  // helper posts the listen/connect tasks and waits for the result.
  ConnectSocketsV4();
}
// Same as ConnectV4 but over the IPv6 loopback address. Reported as skipped
// (rather than silently passing) when the host has no IPv6 loopback.
TEST_F(CastSocketE2ETest, ConnectV6) {
  OSP_LOG_INFO << "Getting loopback IPv6 address";
  const IPAddress loopback_address = GetLoopbackV6Address();
  if (!loopback_address) {
    OSP_LOG_WARN << "Test skipped due to missing IPv6 loopback address";
    GTEST_SKIP() << "missing IPv6 loopback address";
  }
  OSP_LOG_INFO << "Connecting CastSockets";
  Connect(loopback_address);
}
// Connects over IPv4, then closes the connection from the sender side.
TEST_F(CastSocketE2ETest, SenderClose) {
  ConnectSocketsV4();
  VirtualConnectionRouter* const closing_router = sender_router_.get();
  CloseSocketsFromOneEnd(closing_router, sender_client_.get(),
                         receiver_client_.get());
}
// Connects over IPv4, then closes the connection from the receiver side.
TEST_F(CastSocketE2ETest, ReceiverClose) {
  ConnectSocketsV4();
  VirtualConnectionRouter* const closing_router = receiver_router_.get();
  CloseSocketsFromOneEnd(closing_router, receiver_client_.get(),
                         sender_client_.get());
}
} // namespace cast
} // namespace openscreen