Add cast sender socket factory

This change adds a CastSocket factory for the sender-side which performs
the sender auth challenge and verification before passing a CastSocket
back to the caller.

Bug: openscreen:59
Change-Id: Ibbbdb2b8881e385cc0a8defbe309c7f10a2af323
Reviewed-on: https://chromium-review.googlesource.com/c/openscreen/+/1834457
Commit-Queue: Brandon Tolsch <btolsch@chromium.org>
Reviewed-by: Ryan Keane <rwkeane@google.com>
diff --git a/cast/common/channel/BUILD.gn b/cast/common/channel/BUILD.gn
index 71b393b..7a362e3 100644
--- a/cast/common/channel/BUILD.gn
+++ b/cast/common/channel/BUILD.gn
@@ -8,6 +8,7 @@
     "cast_socket.h",
     "message_framer.cc",
     "message_framer.h",
+    "message_util.h",
   ]
 
   deps = [
diff --git a/cast/common/channel/cast_socket.cc b/cast/common/channel/cast_socket.cc
index 8ad6154..ce35e54 100644
--- a/cast/common/channel/cast_socket.cc
+++ b/cast/common/channel/cast_socket.cc
@@ -4,6 +4,8 @@
 
 #include "cast/common/channel/cast_socket.h"
 
+#include <atomic>
+
 #include "cast/common/channel/message_framer.h"
 #include "platform/api/logging.h"
 
@@ -14,6 +16,11 @@
 using openscreen::ErrorOr;
 using openscreen::platform::TlsConnection;
 
+uint32_t GetNextSocketId() {
+  static std::atomic<uint32_t> id(1);
+  return id++;
+}
+
 CastSocket::CastSocket(std::unique_ptr<TlsConnection> connection,
                        Client* client,
                        uint32_t socket_id)
diff --git a/cast/common/channel/cast_socket.h b/cast/common/channel/cast_socket.h
index a8fa3c4..0c173ef 100644
--- a/cast/common/channel/cast_socket.h
+++ b/cast/common/channel/cast_socket.h
@@ -17,6 +17,8 @@
 
 class CastMessage;
 
+uint32_t GetNextSocketId();
+
 // Represents a simple message-oriented socket for communicating with the Cast
 // V2 protocol.  It isn't thread-safe, so it should only be used on the same
 // TaskRunner thread as its TlsConnection.
@@ -38,9 +40,9 @@
   ~CastSocket();
 
   // Sends |message| immediately unless the underlying TLS connection is
-  // write-blocked, in which case |message| will be queued.  No error is
-  // returned for both queueing and successful sending.  An error will be
-  // returned if |message| cannot be serialized for any reason.
+  // write-blocked, in which case |message| will be queued.  An error will be
+  // returned if |message| cannot be serialized for any reason, even while
+  // write-blocked.
   Error SendMessage(const CastMessage& message);
 
   void set_client(Client* client) {
diff --git a/cast/common/channel/message_util.h b/cast/common/channel/message_util.h
new file mode 100644
index 0000000..5b84dbd
--- /dev/null
+++ b/cast/common/channel/message_util.h
@@ -0,0 +1,39 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_COMMON_CHANNEL_MESSAGE_UTIL_H_
+#define CAST_COMMON_CHANNEL_MESSAGE_UTIL_H_
+
+#include "cast/common/channel/proto/cast_channel.pb.h"
+
+namespace cast {
+namespace channel {
+
+// Reserved message namespaces for internal messages.
+static constexpr char kCastInternalNamespacePrefix[] =
+    "urn:x-cast:com.google.cast.";
+static constexpr char kAuthNamespace[] =
+    "urn:x-cast:com.google.cast.tp.deviceauth";
+static constexpr char kHeartbeatNamespace[] =
+    "urn:x-cast:com.google.cast.tp.heartbeat";
+static constexpr char kConnectionNamespace[] =
+    "urn:x-cast:com.google.cast.tp.connection";
+static constexpr char kReceiverNamespace[] =
+    "urn:x-cast:com.google.cast.receiver";
+static constexpr char kBroadcastNamespace[] =
+    "urn:x-cast:com.google.cast.broadcast";
+static constexpr char kMediaNamespace[] = "urn:x-cast:com.google.cast.media";
+
+// Sender and receiver IDs to use for platform messages.
+static constexpr char kPlatformSenderId[] = "sender-0";
+static constexpr char kPlatformReceiverId[] = "receiver-0";
+
+inline bool IsAuthMessage(const CastMessage& message) {
+  return message.namespace_() == kAuthNamespace;
+}
+
+}  // namespace channel
+}  // namespace cast
+
+#endif  // CAST_COMMON_CHANNEL_MESSAGE_UTIL_H_
diff --git a/cast/sender/channel/BUILD.gn b/cast/sender/channel/BUILD.gn
index 85f633b..5fbed9c 100644
--- a/cast/sender/channel/BUILD.gn
+++ b/cast/sender/channel/BUILD.gn
@@ -6,6 +6,10 @@
   sources = [
     "cast_auth_util.cc",
     "cast_auth_util.h",
+    "message_util.cc",
+    "message_util.h",
+    "sender_socket_factory.cc",
+    "sender_socket_factory.h",
   ]
 
   deps = [
diff --git a/cast/sender/channel/message_util.cc b/cast/sender/channel/message_util.cc
new file mode 100644
index 0000000..ab3ed5d
--- /dev/null
+++ b/cast/sender/channel/message_util.cc
@@ -0,0 +1,34 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/sender/channel/message_util.h"
+
+#include "cast/sender/channel/cast_auth_util.h"
+
+namespace cast {
+namespace channel {
+
+CastMessage CreateAuthChallengeMessage(const AuthContext& auth_context) {
+  CastMessage message;
+  DeviceAuthMessage auth_message;
+
+  AuthChallenge* challenge = auth_message.mutable_challenge();
+  challenge->set_sender_nonce(auth_context.nonce());
+  challenge->set_hash_algorithm(SHA256);
+
+  std::string auth_message_string;
+  auth_message.SerializeToString(&auth_message_string);
+
+  message.set_protocol_version(CastMessage::CASTV2_1_0);
+  message.set_source_id(kPlatformSenderId);
+  message.set_destination_id(kPlatformReceiverId);
+  message.set_namespace_(kAuthNamespace);
+  message.set_payload_type(CastMessage_PayloadType_BINARY);
+  message.set_payload_binary(auth_message_string);
+
+  return message;
+}
+
+}  // namespace channel
+}  // namespace cast
diff --git a/cast/sender/channel/message_util.h b/cast/sender/channel/message_util.h
new file mode 100644
index 0000000..e2da0cd
--- /dev/null
+++ b/cast/sender/channel/message_util.h
@@ -0,0 +1,21 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_SENDER_CHANNEL_MESSAGE_UTIL_H_
+#define CAST_SENDER_CHANNEL_MESSAGE_UTIL_H_
+
+#include "cast/common/channel/message_util.h"
+#include "cast/common/channel/proto/cast_channel.pb.h"
+
+namespace cast {
+namespace channel {
+
+class AuthContext;
+
+CastMessage CreateAuthChallengeMessage(const AuthContext& auth_context);
+
+}  // namespace channel
+}  // namespace cast
+
+#endif  // CAST_SENDER_CHANNEL_MESSAGE_UTIL_H_
diff --git a/cast/sender/channel/sender_socket_factory.cc b/cast/sender/channel/sender_socket_factory.cc
new file mode 100644
index 0000000..83e7337
--- /dev/null
+++ b/cast/sender/channel/sender_socket_factory.cc
@@ -0,0 +1,166 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "cast/sender/channel/sender_socket_factory.h"
+
+#include "cast/common/channel/cast_socket.h"
+#include "cast/sender/channel/message_util.h"
+
+namespace cast {
+namespace channel {
+
+using openscreen::platform::TlsConnectOptions;
+
+bool operator<(const std::unique_ptr<SenderSocketFactory::PendingAuth>& a,
+               uint32_t b) {
+  return a && a->socket->socket_id() < b;
+}
+
+bool operator<(uint32_t a,
+               const std::unique_ptr<SenderSocketFactory::PendingAuth>& b) {
+  return b && a < b->socket->socket_id();
+}
+
+SenderSocketFactory::SenderSocketFactory(Client* client) : client_(client) {
+  OSP_DCHECK(client);
+}
+
+SenderSocketFactory::~SenderSocketFactory() = default;
+
+void SenderSocketFactory::Connect(const IPEndpoint& endpoint,
+                                  DeviceMediaPolicy media_policy,
+                                  CastSocket::Client* client) {
+  OSP_DCHECK(client);
+  auto it = FindPendingConnection(endpoint);
+  if (it == pending_connections_.end()) {
+    pending_connections_.emplace_back(
+        PendingConnection{endpoint, media_policy, client});
+    factory_->Connect(endpoint, TlsConnectOptions{true});
+  }
+}
+
+void SenderSocketFactory::OnAccepted(
+    TlsConnectionFactory* factory,
+    X509* peer_cert,
+    std::unique_ptr<TlsConnection> connection) {
+  OSP_NOTREACHED() << "This factory is connect-only.";
+}
+
+void SenderSocketFactory::OnConnected(
+    TlsConnectionFactory* factory,
+    X509* peer_cert,
+    std::unique_ptr<TlsConnection> connection) {
+  const IPEndpoint& endpoint = connection->remote_address();
+  auto it = FindPendingConnection(endpoint);
+  if (it == pending_connections_.end()) {
+    OSP_DLOG_ERROR << "TLS connection succeeded for unknown endpoint: "
+                   << endpoint;
+    return;
+  }
+  DeviceMediaPolicy media_policy = it->media_policy;
+  CastSocket::Client* client = it->client;
+  pending_connections_.erase(it);
+
+  if (!peer_cert) {
+    client_->OnError(this, endpoint, Error::Code::kErrCertsMissing);
+    return;
+  }
+
+  auto socket = std::make_unique<CastSocket>(std::move(connection), this,
+                                             GetNextSocketId());
+  pending_auth_.emplace_back(new PendingAuth{endpoint, media_policy,
+                                             std::move(socket), client,
+                                             AuthContext::Create(), peer_cert});
+  PendingAuth& pending = *pending_auth_.back();
+
+  CastMessage auth_challenge = CreateAuthChallengeMessage(pending.auth_context);
+  Error error = pending.socket->SendMessage(auth_challenge);
+  if (!error.ok()) {
+    pending_auth_.pop_back();
+    client_->OnError(this, endpoint, error);
+  }
+}
+
+void SenderSocketFactory::OnConnectionFailed(TlsConnectionFactory* factory,
+                                             const IPEndpoint& remote_address) {
+  auto it = FindPendingConnection(remote_address);
+  if (it == pending_connections_.end()) {
+    OSP_DVLOG << "OnConnectionFailed reported for untracked address: "
+              << remote_address;
+    return;
+  }
+  pending_connections_.erase(it);
+  client_->OnError(this, remote_address, Error::Code::kConnectionFailed);
+}
+
+void SenderSocketFactory::OnError(TlsConnectionFactory* factory, Error error) {
+  std::vector<PendingConnection> connections;
+  pending_connections_.swap(connections);
+  for (const PendingConnection& pending : connections) {
+    client_->OnError(this, pending.endpoint, error);
+  }
+}
+
+std::vector<SenderSocketFactory::PendingConnection>::iterator
+SenderSocketFactory::FindPendingConnection(const IPEndpoint& endpoint) {
+  return std::find_if(pending_connections_.begin(), pending_connections_.end(),
+                      [&endpoint](const PendingConnection& pending) {
+                        return pending.endpoint == endpoint;
+                      });
+}
+
+void SenderSocketFactory::OnError(CastSocket* socket, Error error) {
+  auto it = std::find_if(pending_auth_.begin(), pending_auth_.end(),
+                         [id = socket->socket_id()](
+                             const std::unique_ptr<PendingAuth>& pending_auth) {
+                           return pending_auth->socket->socket_id() == id;
+                         });
+  if (it == pending_auth_.end()) {
+    OSP_DLOG_ERROR << "Got error for unknown pending socket";
+    return;
+  }
+  IPEndpoint endpoint = (*it)->endpoint;
+  pending_auth_.erase(it);
+  client_->OnError(this, endpoint, error);
+}
+
+void SenderSocketFactory::OnMessage(CastSocket* socket, CastMessage message) {
+  auto it = std::find_if(pending_auth_.begin(), pending_auth_.end(),
+                         [id = socket->socket_id()](
+                             const std::unique_ptr<PendingAuth>& pending_auth) {
+                           return pending_auth->socket->socket_id() == id;
+                         });
+  if (it == pending_auth_.end()) {
+    OSP_DLOG_ERROR << "Got message for unknown pending socket";
+    return;
+  }
+
+  std::unique_ptr<PendingAuth> pending = std::move(*it);
+  pending_auth_.erase(it);
+  if (!IsAuthMessage(message)) {
+    client_->OnError(this, pending->endpoint,
+                     Error::Code::kCastV2AuthenticationError);
+    return;
+  }
+
+  ErrorOr<CastDeviceCertPolicy> policy_or_error = AuthenticateChallengeReply(
+      message, (*it)->peer_cert, (*it)->auth_context);
+  if (policy_or_error.is_error()) {
+    client_->OnError(this, pending->endpoint, policy_or_error.error());
+    return;
+  }
+
+  if (policy_or_error.value() == CastDeviceCertPolicy::kAudioOnly &&
+      pending->media_policy != DeviceMediaPolicy::kAudioOnly) {
+    client_->OnError(this, pending->endpoint,
+                     Error::Code::kCastV2ChannelPolicyMismatch);
+    return;
+  }
+
+  pending->socket->set_client(pending->client);
+  client_->OnConnected(this, pending->endpoint, std::move(pending->socket));
+}
+
+}  // namespace channel
+}  // namespace cast
diff --git a/cast/sender/channel/sender_socket_factory.h b/cast/sender/channel/sender_socket_factory.h
new file mode 100644
index 0000000..d5b6622
--- /dev/null
+++ b/cast/sender/channel/sender_socket_factory.h
@@ -0,0 +1,104 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CAST_SENDER_CHANNEL_SENDER_SOCKET_FACTORY_H_
+#define CAST_SENDER_CHANNEL_SENDER_SOCKET_FACTORY_H_
+
+#include <set>
+#include <utility>
+#include <vector>
+
+#include "cast/common/channel/cast_socket.h"
+#include "cast/sender/channel/cast_auth_util.h"
+#include "platform/api/logging.h"
+#include "platform/api/tls_connection_factory.h"
+#include "platform/base/ip_address.h"
+
+namespace cast {
+namespace channel {
+
+using openscreen::Error;
+using openscreen::IPEndpoint;
+using openscreen::IPEndpointComparator;
+using openscreen::platform::TlsConnection;
+using openscreen::platform::TlsConnectionFactory;
+
+class SenderSocketFactory final : public TlsConnectionFactory::Client,
+                                  public CastSocket::Client {
+ public:
+  class Client {
+   public:
+    virtual void OnConnected(SenderSocketFactory* factory,
+                             const IPEndpoint& endpoint,
+                             std::unique_ptr<CastSocket> socket) = 0;
+    virtual void OnError(SenderSocketFactory* factory,
+                         const IPEndpoint& endpoint,
+                         Error error) = 0;
+  };
+
+  enum class DeviceMediaPolicy {
+    kAudioOnly,
+    kIncludesVideo,
+  };
+
+  // |client| must outlive |this|.
+  explicit SenderSocketFactory(Client* client);
+  ~SenderSocketFactory();
+
+  void set_factory(TlsConnectionFactory* factory) {
+    OSP_DCHECK(factory);
+    factory_ = factory;
+  }
+
+  void Connect(const IPEndpoint& endpoint,
+               DeviceMediaPolicy media_policy,
+               CastSocket::Client* client);
+
+  // TlsConnectionFactory::Client overrides.
+  void OnAccepted(TlsConnectionFactory* factory,
+                  X509* peer_cert,
+                  std::unique_ptr<TlsConnection> connection) override;
+  void OnConnected(TlsConnectionFactory* factory,
+                   X509* peer_cert,
+                   std::unique_ptr<TlsConnection> connection) override;
+  void OnConnectionFailed(TlsConnectionFactory* factory,
+                          const IPEndpoint& remote_address) override;
+  void OnError(TlsConnectionFactory* factory, Error error) override;
+
+ private:
+  struct PendingConnection {
+    IPEndpoint endpoint;
+    DeviceMediaPolicy media_policy;
+    CastSocket::Client* client;
+  };
+
+  struct PendingAuth {
+    IPEndpoint endpoint;
+    DeviceMediaPolicy media_policy;
+    std::unique_ptr<CastSocket> socket;
+    CastSocket::Client* client;
+    AuthContext auth_context;
+    X509* peer_cert;
+  };
+
+  friend bool operator<(const std::unique_ptr<PendingAuth>& a, uint32_t b);
+  friend bool operator<(uint32_t a, const std::unique_ptr<PendingAuth>& b);
+
+  std::vector<PendingConnection>::iterator FindPendingConnection(
+      const IPEndpoint& endpoint);
+
+  // CastSocket::Client overrides.
+  void OnError(CastSocket* socket, Error error) override;
+  void OnMessage(CastSocket* socket, CastMessage message) override;
+
+  Client* const client_;
+  TlsConnectionFactory* factory_ = nullptr;
+  std::vector<PendingConnection> pending_connections_;
+  std::vector<std::unique_ptr<PendingAuth>> pending_auth_;
+};
+
+}  // namespace channel
+}  // namespace cast
+
+#endif  // CAST_SENDER_CHANNEL_SENDER_SOCKET_FACTORY_H_
diff --git a/platform/api/tls_connection_factory.cc b/platform/api/tls_connection_factory.cc
index 0ec13bf..8a74b1e 100644
--- a/platform/api/tls_connection_factory.cc
+++ b/platform/api/tls_connection_factory.cc
@@ -8,19 +8,23 @@
 namespace platform {
 
 void TlsConnectionFactory::OnAccepted(
+    X509* peer_cert,
     std::unique_ptr<TlsConnection> connection) {
-  task_runner_->PostTask([c = std::move(connection), this]() mutable {
-    // TODO(issues/71): |this| may be invalid at this point.
-    this->client_->OnAccepted(this, std::move(c));
-  });
+  task_runner_->PostTask(
+      [peer_cert, c = std::move(connection), this]() mutable {
+        // TODO(issues/71): |this| may be invalid at this point.
+        this->client_->OnAccepted(this, peer_cert, std::move(c));
+      });
 }
 
 void TlsConnectionFactory::OnConnected(
+    X509* peer_cert,
     std::unique_ptr<TlsConnection> connection) {
-  task_runner_->PostTask([c = std::move(connection), this]() mutable {
-    // TODO(issues/71): |this| may be invalid at this point.
-    this->client_->OnConnected(this, std::move(c));
-  });
+  task_runner_->PostTask(
+      [peer_cert, c = std::move(connection), this]() mutable {
+        // TODO(issues/71): |this| may be invalid at this point.
+        this->client_->OnConnected(this, peer_cert, std::move(c));
+      });
 }
 
 void TlsConnectionFactory::OnConnectionFailed(
diff --git a/platform/api/tls_connection_factory.h b/platform/api/tls_connection_factory.h
index 21e6a14..81fb76c 100644
--- a/platform/api/tls_connection_factory.h
+++ b/platform/api/tls_connection_factory.h
@@ -28,9 +28,11 @@
   class Client {
    public:
     virtual void OnAccepted(TlsConnectionFactory* factory,
+                            X509* peer_cert,
                             std::unique_ptr<TlsConnection> connection) = 0;
 
     virtual void OnConnected(TlsConnectionFactory* factory,
+                             X509* peer_cert,
                              std::unique_ptr<TlsConnection> connection) = 0;
 
     virtual void OnConnectionFailed(TlsConnectionFactory* factory,
@@ -69,9 +71,9 @@
       : client_(client), task_runner_(task_runner) {}
 
   // The below methods proxy calls to this TlsConnectionFactory's Client.
-  void OnAccepted(std::unique_ptr<TlsConnection> connection);
+  void OnAccepted(X509* peer_cert, std::unique_ptr<TlsConnection> connection);
 
-  void OnConnected(std::unique_ptr<TlsConnection> connection);
+  void OnConnected(X509* peer_cert, std::unique_ptr<TlsConnection> connection);
 
   void OnConnectionFailed(const IPEndpoint& remote_address);
 
diff --git a/platform/base/error.cc b/platform/base/error.cc
index 33b0279..fd300c1 100644
--- a/platform/base/error.cc
+++ b/platform/base/error.cc
@@ -218,6 +218,8 @@
       return os << "Failure: kCastV2ConnectTimeout";
     case Error::Code::kCastV2PingTimeout:
       return os << "Failure: kCastV2PingTimeout";
+    case Error::Code::kCastV2ChannelPolicyMismatch:
+      return os << "Failure: kCastV2ChannelPolicyMismatch";
   }
 
   // Unused 'return' to get around failure on GCC.
diff --git a/platform/base/error.h b/platform/base/error.h
index 557c30b..11901cd 100644
--- a/platform/base/error.h
+++ b/platform/base/error.h
@@ -153,6 +153,7 @@
     kCastV2InvalidChannelId,
     kCastV2ConnectTimeout,
     kCastV2PingTimeout,
+    kCastV2ChannelPolicyMismatch,
 
     // Generic errors.
     kUnknownError,
diff --git a/platform/impl/tls_connection_factory_posix.cc b/platform/impl/tls_connection_factory_posix.cc
index cd83d71..3815996 100644
--- a/platform/impl/tls_connection_factory_posix.cc
+++ b/platform/impl/tls_connection_factory_posix.cc
@@ -80,7 +80,8 @@
     return;
   }
 
-  OnConnected(std::move(connection));
+  X509* peer_cert = SSL_get_peer_certificate(connection->ssl_.get());
+  OnConnected(peer_cert, std::move(connection));
 }
 
 void TlsConnectionFactoryPosix::SetListenCredentials(
@@ -151,7 +152,8 @@
     return;
   }
 
-  OnAccepted(std::move(connection));
+  X509* peer_cert = SSL_get_peer_certificate(connection->ssl_.get());
+  OnAccepted(peer_cert, std::move(connection));
 }
 
 bool TlsConnectionFactoryPosix::ConfigureSsl(TlsConnectionPosix* connection) {