blob: e971976bebd6030c5a58ad7d30a258f2657a7539 [file] [log] [blame]
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "cast/sender/public/sender_socket_factory.h"
#include "cast/common/channel/proto/cast_channel.pb.h"
#include "cast/sender/channel/cast_auth_util.h"
#include "cast/sender/channel/message_util.h"
#include "platform/base/tls_connect_options.h"
#include "util/crypto/certificate_utils.h"
#include "util/osp_logging.h"
using ::cast::channel::CastMessage;
namespace openscreen {
namespace cast {
// Compares a pending-auth entry against a raw socket id. A null entry is never
// considered less than any id.
bool operator<(const std::unique_ptr<SenderSocketFactory::PendingAuth>& a,
               int b) {
  if (!a) {
    return false;
  }
  return a->socket->socket_id() < b;
}
// Compares a raw socket id against a pending-auth entry. No id is considered
// less than a null entry.
bool operator<(int a,
               const std::unique_ptr<SenderSocketFactory::PendingAuth>& b) {
  if (!b) {
    return false;
  }
  return a < b->socket->socket_id();
}
// Stores the (non-owning) client and task runner pointers. Both must be
// non-null; this is only enforced via OSP_DCHECK.
SenderSocketFactory::SenderSocketFactory(Client* client,
                                         TaskRunner* task_runner)
    : client_(client), task_runner_(task_runner) {
  OSP_DCHECK(client_);
  OSP_DCHECK(task_runner_);
}
// Destruction is expected to happen on the factory's task runner; this is
// only verified via OSP_DCHECK.
SenderSocketFactory::~SenderSocketFactory() {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
}
// Installs the TLS connection factory that Connect() uses to open
// connections. |factory| must be non-null (OSP_DCHECK only); the pointer is
// stored as-is.
void SenderSocketFactory::set_factory(TlsConnectionFactory* factory) {
  OSP_DCHECK(factory);
  factory_ = factory;
}
// Starts a TLS connection attempt to |endpoint| on behalf of |client|.
//
// If an attempt to the same endpoint is already pending, this call is a no-op:
// no second attempt is started and |client| is not recorded. Otherwise the
// request is remembered in |pending_connections_| so the later TLS callbacks
// can recover |media_policy| and |client|.
//
// Preconditions (OSP_DCHECK only): called on |task_runner_|, |client| is
// non-null, and set_factory() has already provided a TLS connection factory.
void SenderSocketFactory::Connect(const IPEndpoint& endpoint,
                                  DeviceMediaPolicy media_policy,
                                  CastSocket::Client* client) {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
  OSP_DCHECK(client);
  // |factory_| is dereferenced unconditionally below, so make a missing
  // set_factory() call fail loudly in debug builds.
  OSP_DCHECK(factory_);

  if (FindPendingConnection(endpoint) != pending_connections_.end()) {
    // A connection attempt to this endpoint is already in flight.
    return;
  }

  pending_connections_.emplace_back(
      PendingConnection{endpoint, media_policy, client});
  // NOTE(review): the `true` option presumably allows self-signed
  // certificates -- confirm against TlsConnectOptions.
  factory_->Connect(endpoint, TlsConnectOptions{true});
}
// TlsConnectionFactory callback for an accepted (inbound) connection. This
// factory only initiates outbound connections via Connect(), so reaching this
// callback is treated as a fatal programming error. None of the arguments are
// used.
void SenderSocketFactory::OnAccepted(
TlsConnectionFactory* factory,
std::vector<uint8_t> der_x509_peer_cert,
std::unique_ptr<TlsConnection> connection) {
OSP_NOTREACHED();
OSP_LOG_FATAL << "This factory is connect-only";
}
// TlsConnectionFactory callback invoked with a connected |connection| and the
// peer's DER-encoded X.509 certificate.
//
// Steps:
//   1. Match the connection's remote endpoint to a pending Connect() request;
//      unknown endpoints are logged and ignored.
//   2. Consume that request and parse the peer certificate. A parse failure
//      is reported through |client_|->OnError().
//   3. Wrap the connection in a CastSocket (with this factory as its client),
//      record it in |pending_auth_|, and send an auth challenge. If the send
//      fails, the entry is dropped again and the error is reported.
void SenderSocketFactory::OnConnected(
    TlsConnectionFactory* factory,
    std::vector<uint8_t> der_x509_peer_cert,
    std::unique_ptr<TlsConnection> connection) {
  const IPEndpoint& endpoint = connection->GetRemoteEndpoint();

  auto request_it = FindPendingConnection(endpoint);
  if (request_it == pending_connections_.end()) {
    OSP_DLOG_ERROR << "TLS connection succeeded for unknown endpoint: "
                   << endpoint;
    return;
  }

  // Copy what is needed out of the request before erasing it.
  DeviceMediaPolicy requested_policy = request_it->media_policy;
  CastSocket::Client* socket_client = request_it->client;
  pending_connections_.erase(request_it);

  ErrorOr<bssl::UniquePtr<X509>> cert_or_error =
      ImportCertificate(der_x509_peer_cert.data(), der_x509_peer_cert.size());
  if (!cert_or_error) {
    client_->OnError(this, endpoint, cert_or_error.error());
    return;
  }

  // Until authentication completes, this factory acts as the socket's client
  // so that it receives the challenge reply (or an error).
  auto cast_socket =
      MakeSerialDelete<CastSocket>(task_runner_, std::move(connection), this);
  std::unique_ptr<PendingAuth> entry(new PendingAuth{
      endpoint, requested_policy, std::move(cast_socket), socket_client,
      std::make_unique<AuthContext>(AuthContext::Create()),
      std::move(cert_or_error.value())});
  PendingAuth& pending = *entry;
  pending_auth_.push_back(std::move(entry));

  CastMessage challenge = CreateAuthChallengeMessage(*pending.auth_context);
  Error send_error = pending.socket->Send(challenge);
  if (send_error.ok()) {
    return;
  }

  // The challenge could not be sent: discard the entry that was just added.
  pending_auth_.pop_back();
  client_->OnError(this, endpoint, send_error);
}
// TlsConnectionFactory callback for a failed connection attempt to
// |remote_address|. If a matching pending request exists, it is removed and
// the failure is reported to |client_| as kConnectionFailed; otherwise the
// event is only logged.
void SenderSocketFactory::OnConnectionFailed(TlsConnectionFactory* factory,
                                             const IPEndpoint& remote_address) {
  auto request_it = FindPendingConnection(remote_address);
  if (request_it != pending_connections_.end()) {
    pending_connections_.erase(request_it);
    client_->OnError(this, remote_address, Error::Code::kConnectionFailed);
    return;
  }

  OSP_DVLOG << "OnConnectionFailed reported for untracked address: "
            << remote_address;
}
// TlsConnectionFactory callback for a factory-level error. Every pending
// connection request is dropped and |error| is reported to |client_| once per
// dropped endpoint.
void SenderSocketFactory::OnError(TlsConnectionFactory* factory, Error error) {
  // Take the pending list before notifying, so |pending_connections_| is
  // already empty while the client callbacks run.
  std::vector<PendingConnection> failed;
  failed.swap(pending_connections_);

  for (const PendingConnection& request : failed) {
    client_->OnError(this, request.endpoint, error);
  }
}
// Returns an iterator to the first pending connection request whose endpoint
// equals |endpoint|, or pending_connections_.end() if there is none.
std::vector<SenderSocketFactory::PendingConnection>::iterator
SenderSocketFactory::FindPendingConnection(const IPEndpoint& endpoint) {
  const auto targets_endpoint = [&endpoint](const PendingConnection& request) {
    return request.endpoint == endpoint;
  };
  return std::find_if(pending_connections_.begin(), pending_connections_.end(),
                      targets_endpoint);
}
// CastSocket client callback for an error on |socket|. If the socket belongs
// to an entry in |pending_auth_|, that entry is discarded and |error| is
// reported to |client_| for the entry's endpoint; otherwise the event is only
// logged.
void SenderSocketFactory::OnError(CastSocket* socket, Error error) {
  const auto failed_id = socket->socket_id();
  const auto owns_failed_socket =
      [failed_id](const std::unique_ptr<PendingAuth>& entry) {
        return entry->socket->socket_id() == failed_id;
      };

  auto entry_it = std::find_if(pending_auth_.begin(), pending_auth_.end(),
                               owns_failed_socket);
  if (entry_it == pending_auth_.end()) {
    OSP_DLOG_ERROR << "Got error for unknown pending socket";
    return;
  }

  // Copy the endpoint first: erasing the entry destroys the PendingAuth that
  // holds it.
  IPEndpoint failed_endpoint = (*entry_it)->endpoint;
  pending_auth_.erase(entry_it);
  client_->OnError(this, failed_endpoint, error);
}
// CastSocket client callback for a message received on |socket| while the
// socket is still in |pending_auth_|.
//
// The matching entry is removed from |pending_auth_| no matter what happens
// next. |client_|->OnError() is called when:
//   - |message| is not an auth message (kCastV2AuthenticationError),
//   - AuthenticateChallengeReply() fails (its error is forwarded), or
//   - the certificate policy is audio-only but the request asked for video
//     (kCastV2ChannelPolicyMismatch).
// Otherwise the socket is handed to |client_|->OnConnected() after its
// audio-only flag and the client from the original request have been set.
void SenderSocketFactory::OnMessage(CastSocket* socket, CastMessage message) {
  const auto sender_id = socket->socket_id();
  const auto owns_sender =
      [sender_id](const std::unique_ptr<PendingAuth>& entry) {
        return entry->socket->socket_id() == sender_id;
      };

  auto entry_it =
      std::find_if(pending_auth_.begin(), pending_auth_.end(), owns_sender);
  if (entry_it == pending_auth_.end()) {
    OSP_DLOG_ERROR << "Got message for unknown pending socket";
    return;
  }

  // Take ownership of the entry; from here on every early return destroys it.
  std::unique_ptr<PendingAuth> pending = std::move(*entry_it);
  pending_auth_.erase(entry_it);

  if (!IsAuthMessage(message)) {
    client_->OnError(this, pending->endpoint,
                     Error::Code::kCastV2AuthenticationError);
    return;
  }

  ErrorOr<CastDeviceCertPolicy> policy_or_error = AuthenticateChallengeReply(
      message, pending->peer_cert.get(), *pending->auth_context);
  if (policy_or_error.is_error()) {
    OSP_DLOG_WARN << "Authentication failed for " << pending->endpoint
                  << " with error: " << policy_or_error.error();
    client_->OnError(this, pending->endpoint, policy_or_error.error());
    return;
  }

  const bool audio_only =
      policy_or_error.value() == CastDeviceCertPolicy::kAudioOnly;
  if (audio_only &&
      pending->media_policy == DeviceMediaPolicy::kIncludesVideo) {
    // The request asked for video but the device certificate is audio-only.
    client_->OnError(this, pending->endpoint,
                     Error::Code::kCastV2ChannelPolicyMismatch);
    return;
  }

  pending->socket->set_audio_only(audio_only);
  pending->socket->SetClient(pending->client);
  client_->OnConnected(this, pending->endpoint,
                       std::unique_ptr<CastSocket>(pending->socket.release()));
}
} // namespace cast
} // namespace openscreen