| // Copyright 2019 The Chromium Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
#include "cast/sender/public/sender_socket_factory.h"

#include <algorithm>
#include <memory>
#include <utility>
#include <vector>

#include "cast/common/channel/proto/cast_channel.pb.h"
#include "cast/sender/channel/cast_auth_util.h"
#include "cast/sender/channel/message_util.h"
#include "platform/base/tls_connect_options.h"
#include "util/crypto/certificate_utils.h"
#include "util/osp_logging.h"
| |
| using ::cast::channel::CastMessage; |
| |
| namespace openscreen { |
| namespace cast { |
| |
| bool operator<(const std::unique_ptr<SenderSocketFactory::PendingAuth>& a, |
| int b) { |
| return a && a->socket->socket_id() < b; |
| } |
| |
| bool operator<(int a, |
| const std::unique_ptr<SenderSocketFactory::PendingAuth>& b) { |
| return b && a < b->socket->socket_id(); |
| } |
| |
| SenderSocketFactory::SenderSocketFactory(Client* client, |
| TaskRunner* task_runner) |
| : client_(client), task_runner_(task_runner) { |
| OSP_DCHECK(client); |
| OSP_DCHECK(task_runner); |
| } |
| |
| SenderSocketFactory::~SenderSocketFactory() { |
| OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); |
| } |
| |
| void SenderSocketFactory::set_factory(TlsConnectionFactory* factory) { |
| OSP_DCHECK(factory); |
| factory_ = factory; |
| } |
| |
| void SenderSocketFactory::Connect(const IPEndpoint& endpoint, |
| DeviceMediaPolicy media_policy, |
| CastSocket::Client* client) { |
| OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); |
| OSP_DCHECK(client); |
| auto it = FindPendingConnection(endpoint); |
| if (it == pending_connections_.end()) { |
| pending_connections_.emplace_back( |
| PendingConnection{endpoint, media_policy, client}); |
| factory_->Connect(endpoint, TlsConnectOptions{true}); |
| } |
| } |
| |
| void SenderSocketFactory::OnAccepted( |
| TlsConnectionFactory* factory, |
| std::vector<uint8_t> der_x509_peer_cert, |
| std::unique_ptr<TlsConnection> connection) { |
| OSP_NOTREACHED(); |
| OSP_LOG_FATAL << "This factory is connect-only"; |
| } |
| |
| void SenderSocketFactory::OnConnected( |
| TlsConnectionFactory* factory, |
| std::vector<uint8_t> der_x509_peer_cert, |
| std::unique_ptr<TlsConnection> connection) { |
| const IPEndpoint& endpoint = connection->GetRemoteEndpoint(); |
| auto it = FindPendingConnection(endpoint); |
| if (it == pending_connections_.end()) { |
| OSP_DLOG_ERROR << "TLS connection succeeded for unknown endpoint: " |
| << endpoint; |
| return; |
| } |
| DeviceMediaPolicy media_policy = it->media_policy; |
| CastSocket::Client* client = it->client; |
| pending_connections_.erase(it); |
| |
| ErrorOr<bssl::UniquePtr<X509>> peer_cert = |
| ImportCertificate(der_x509_peer_cert.data(), der_x509_peer_cert.size()); |
| if (!peer_cert) { |
| client_->OnError(this, endpoint, peer_cert.error()); |
| return; |
| } |
| |
| auto socket = |
| MakeSerialDelete<CastSocket>(task_runner_, std::move(connection), this); |
| pending_auth_.emplace_back( |
| new PendingAuth{endpoint, media_policy, std::move(socket), client, |
| std::make_unique<AuthContext>(AuthContext::Create()), |
| std::move(peer_cert.value())}); |
| PendingAuth& pending = *pending_auth_.back(); |
| |
| CastMessage auth_challenge = |
| CreateAuthChallengeMessage(*pending.auth_context); |
| Error error = pending.socket->Send(auth_challenge); |
| if (!error.ok()) { |
| pending_auth_.pop_back(); |
| client_->OnError(this, endpoint, error); |
| } |
| } |
| |
| void SenderSocketFactory::OnConnectionFailed(TlsConnectionFactory* factory, |
| const IPEndpoint& remote_address) { |
| auto it = FindPendingConnection(remote_address); |
| if (it == pending_connections_.end()) { |
| OSP_DVLOG << "OnConnectionFailed reported for untracked address: " |
| << remote_address; |
| return; |
| } |
| pending_connections_.erase(it); |
| client_->OnError(this, remote_address, Error::Code::kConnectionFailed); |
| } |
| |
| void SenderSocketFactory::OnError(TlsConnectionFactory* factory, Error error) { |
| std::vector<PendingConnection> connections; |
| pending_connections_.swap(connections); |
| for (const PendingConnection& pending : connections) { |
| client_->OnError(this, pending.endpoint, error); |
| } |
| } |
| |
| std::vector<SenderSocketFactory::PendingConnection>::iterator |
| SenderSocketFactory::FindPendingConnection(const IPEndpoint& endpoint) { |
| return std::find_if(pending_connections_.begin(), pending_connections_.end(), |
| [&endpoint](const PendingConnection& pending) { |
| return pending.endpoint == endpoint; |
| }); |
| } |
| |
| void SenderSocketFactory::OnError(CastSocket* socket, Error error) { |
| auto it = std::find_if(pending_auth_.begin(), pending_auth_.end(), |
| [id = socket->socket_id()]( |
| const std::unique_ptr<PendingAuth>& pending_auth) { |
| return pending_auth->socket->socket_id() == id; |
| }); |
| if (it == pending_auth_.end()) { |
| OSP_DLOG_ERROR << "Got error for unknown pending socket"; |
| return; |
| } |
| IPEndpoint endpoint = (*it)->endpoint; |
| pending_auth_.erase(it); |
| client_->OnError(this, endpoint, error); |
| } |
| |
| void SenderSocketFactory::OnMessage(CastSocket* socket, CastMessage message) { |
| auto it = std::find_if(pending_auth_.begin(), pending_auth_.end(), |
| [id = socket->socket_id()]( |
| const std::unique_ptr<PendingAuth>& pending_auth) { |
| return pending_auth->socket->socket_id() == id; |
| }); |
| if (it == pending_auth_.end()) { |
| OSP_DLOG_ERROR << "Got message for unknown pending socket"; |
| return; |
| } |
| |
| std::unique_ptr<PendingAuth> pending = std::move(*it); |
| pending_auth_.erase(it); |
| if (!IsAuthMessage(message)) { |
| client_->OnError(this, pending->endpoint, |
| Error::Code::kCastV2AuthenticationError); |
| return; |
| } |
| |
| ErrorOr<CastDeviceCertPolicy> policy_or_error = AuthenticateChallengeReply( |
| message, pending->peer_cert.get(), *pending->auth_context); |
| if (policy_or_error.is_error()) { |
| OSP_DLOG_WARN << "Authentication failed for " << pending->endpoint |
| << " with error: " << policy_or_error.error(); |
| client_->OnError(this, pending->endpoint, policy_or_error.error()); |
| return; |
| } |
| |
| if (policy_or_error.value() == CastDeviceCertPolicy::kAudioOnly && |
| pending->media_policy == DeviceMediaPolicy::kIncludesVideo) { |
| client_->OnError(this, pending->endpoint, |
| Error::Code::kCastV2ChannelPolicyMismatch); |
| return; |
| } |
| pending->socket->set_audio_only(policy_or_error.value() == |
| CastDeviceCertPolicy::kAudioOnly); |
| |
| pending->socket->SetClient(pending->client); |
| client_->OnConnected(this, pending->endpoint, |
| std::unique_ptr<CastSocket>(pending->socket.release())); |
| } |
| |
| } // namespace cast |
| } // namespace openscreen |