Unit tests for SSLAdapter

R=juberti@webrtc.org

Review URL: https://webrtc-codereview.appspot.com/17309004

Patch from Manish Jethani <manish.jethani@gmail.com>.

git-svn-id: http://webrtc.googlecode.com/svn/trunk/webrtc@7269 4adac7df-926f-26a2-2b94-8c16560cd09d
diff --git a/base/base_tests.gyp b/base/base_tests.gyp
index 939d1d7..2d99e81 100644
--- a/base/base_tests.gyp
+++ b/base/base_tests.gyp
@@ -149,6 +149,7 @@
           }],
           ['os_posix==1', {
             'sources': [
+              #'ssladapter_unittest.cc',
               #'sslidentity_unittest.cc',
               #'sslstreamadapter_unittest.cc',
             ],
diff --git a/base/nssstreamadapter.cc b/base/nssstreamadapter.cc
index f4b2d31..40c017f 100644
--- a/base/nssstreamadapter.cc
+++ b/base/nssstreamadapter.cc
@@ -486,6 +486,8 @@
       return -1;
     }
 
+    // TODO(juberti): Check for client_auth_enabled()
+
     rv = SSL_OptionSet(ssl_fd_, SSL_REQUIRE_CERTIFICATE, PR_TRUE);
     if (rv != SECSuccess) {
       Error("BeginSSL", -1, false);
diff --git a/base/opensslstreamadapter.cc b/base/opensslstreamadapter.cc
index ed5ac74..070a948 100644
--- a/base/opensslstreamadapter.cc
+++ b/base/opensslstreamadapter.cc
@@ -743,8 +743,15 @@
   SSL_CTX_set_info_callback(ctx, OpenSSLAdapter::SSLInfoCallback);
 #endif
 
-  SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER |SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
-                     SSLVerifyCallback);
+  int mode = SSL_VERIFY_PEER;
+  if (client_auth_enabled()) {
+    // Require a certificate from the client.
+    // Note: Normally this is always true in production, but it may be disabled
+    // for testing purposes (e.g. SSLAdapter unit tests).
+    mode |= SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
+  }
+
+  SSL_CTX_set_verify(ctx, mode, SSLVerifyCallback);
   SSL_CTX_set_verify_depth(ctx, 4);
   SSL_CTX_set_cipher_list(ctx, "ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
 
diff --git a/base/ssladapter_unittest.cc b/base/ssladapter_unittest.cc
new file mode 100644
index 0000000..6d4536d
--- /dev/null
+++ b/base/ssladapter_unittest.cc
@@ -0,0 +1,342 @@
+/*
+ *  Copyright 2014 The WebRTC Project Authors. All rights reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include <string>
+
+#include "webrtc/base/gunit.h"
+#include "webrtc/base/ipaddress.h"
+#include "webrtc/base/socketstream.h"
+#include "webrtc/base/ssladapter.h"
+#include "webrtc/base/sslstreamadapter.h"
+#include "webrtc/base/stream.h"
+#include "webrtc/base/virtualsocketserver.h"
+
+static const int kTimeout = 5000;
+
+static rtc::AsyncSocket* CreateSocket(const rtc::SSLMode& ssl_mode) {
+  rtc::SocketAddress address(rtc::IPAddress(INADDR_ANY), 0);
+
+  rtc::AsyncSocket* socket = rtc::Thread::Current()->
+      socketserver()->CreateAsyncSocket(
+      address.family(), (ssl_mode == rtc::SSL_MODE_DTLS) ?
+      SOCK_DGRAM : SOCK_STREAM);
+  socket->Bind(address);
+
+  return socket;
+}
+
+static std::string GetSSLProtocolName(const rtc::SSLMode& ssl_mode) {
+  return (ssl_mode == rtc::SSL_MODE_DTLS) ? "DTLS" : "TLS";
+}
+
+class SSLAdapterTestDummyClient : public sigslot::has_slots<> {
+ public:
+  explicit SSLAdapterTestDummyClient(const rtc::SSLMode& ssl_mode)
+      : ssl_mode_(ssl_mode) {
+    rtc::AsyncSocket* socket = CreateSocket(ssl_mode_);
+
+    ssl_adapter_.reset(rtc::SSLAdapter::Create(socket));
+
+    // Ignore any certificate errors for the purpose of testing.
+    // Note: We do this only because we don't have a real certificate.
+    // NEVER USE THIS IN PRODUCTION CODE!
+    ssl_adapter_->set_ignore_bad_cert(true);
+
+    ssl_adapter_->SignalReadEvent.connect(this,
+        &SSLAdapterTestDummyClient::OnSSLAdapterReadEvent);
+    ssl_adapter_->SignalCloseEvent.connect(this,
+        &SSLAdapterTestDummyClient::OnSSLAdapterCloseEvent);
+  }
+
+  rtc::AsyncSocket::ConnState GetState() const {
+    return ssl_adapter_->GetState();
+  }
+
+  const std::string& GetReceivedData() const {
+    return data_;
+  }
+
+  int Connect(const std::string& hostname, const rtc::SocketAddress& address) {
+    LOG(LS_INFO) << "Starting " << GetSSLProtocolName(ssl_mode_)
+        << " handshake with " << hostname;
+
+    if (ssl_adapter_->StartSSL(hostname.c_str(), false) != 0) {
+      return -1;
+    }
+
+    LOG(LS_INFO) << "Initiating connection with " << address;
+
+    return ssl_adapter_->Connect(address);
+  }
+
+  int Close() {
+    return ssl_adapter_->Close();
+  }
+
+  int Send(const std::string& message) {
+    LOG(LS_INFO) << "Client sending '" << message << "'";
+
+    return ssl_adapter_->Send(message.data(), message.length());
+  }
+
+  void OnSSLAdapterReadEvent(rtc::AsyncSocket* socket) {
+    char buffer[4096] = "";
+
+    // Read data received from the server and store it in our internal buffer.
+    int read = socket->Recv(buffer, sizeof(buffer) - 1);
+    if (read != -1) {
+      buffer[read] = '\0';
+
+      LOG(LS_INFO) << "Client received '" << buffer << "'";
+
+      data_ += buffer;
+    }
+  }
+
+  void OnSSLAdapterCloseEvent(rtc::AsyncSocket* socket, int error) {
+    // OpenSSLAdapter signals handshake failure with a close event, but without
+    // closing the socket! Let's close the socket here. This way GetState() can
+    // return CS_CLOSED after failure.
+    if (socket->GetState() != rtc::AsyncSocket::CS_CLOSED) {
+      socket->Close();
+    }
+  }
+
+ private:
+  const rtc::SSLMode ssl_mode_;
+
+  rtc::scoped_ptr<rtc::SSLAdapter> ssl_adapter_;
+
+  std::string data_;
+};
+
+class SSLAdapterTestDummyServer : public sigslot::has_slots<> {
+ public:
+  explicit SSLAdapterTestDummyServer(const rtc::SSLMode& ssl_mode)
+      : ssl_mode_(ssl_mode) {
+    // Generate a key pair and a certificate for this host.
+    ssl_identity_.reset(rtc::SSLIdentity::Generate(GetHostname()));
+
+    server_socket_.reset(CreateSocket(ssl_mode_));
+
+    server_socket_->SignalReadEvent.connect(this,
+        &SSLAdapterTestDummyServer::OnServerSocketReadEvent);
+
+    server_socket_->Listen(1);
+
+    LOG(LS_INFO) << ((ssl_mode_ == rtc::SSL_MODE_DTLS) ? "UDP" : "TCP")
+        << " server listening on " << server_socket_->GetLocalAddress();
+  }
+
+  rtc::SocketAddress GetAddress() const {
+    return server_socket_->GetLocalAddress();
+  }
+
+  std::string GetHostname() const {
+    // Since we don't have a real certificate anyway, the value here doesn't
+    // really matter.
+    return "example.com";
+  }
+
+  const std::string& GetReceivedData() const {
+    return data_;
+  }
+
+  int Send(const std::string& message) {
+    if (ssl_stream_adapter_ == NULL
+        || ssl_stream_adapter_->GetState() != rtc::SS_OPEN) {
+      // No connection yet.
+      return -1;
+    }
+
+    LOG(LS_INFO) << "Server sending '" << message << "'";
+
+    size_t written;
+    int error;
+
+    rtc::StreamResult r = ssl_stream_adapter_->Write(message.data(),
+        message.length(), &written, &error);
+    if (r == rtc::SR_SUCCESS) {
+      return written;
+    } else {
+      return -1;
+    }
+  }
+
+  void OnServerSocketReadEvent(rtc::AsyncSocket* socket) {
+    if (ssl_stream_adapter_ != NULL) {
+      // Only a single connection is supported.
+      return;
+    }
+
+    rtc::SocketAddress address;
+    rtc::AsyncSocket* new_socket = socket->Accept(&address);
+    rtc::SocketStream* stream = new rtc::SocketStream(new_socket);
+
+    ssl_stream_adapter_.reset(rtc::SSLStreamAdapter::Create(stream));
+    ssl_stream_adapter_->SetServerRole();
+
+    // SSLStreamAdapter is normally used for peer-to-peer communication, but
+    // here we're testing communication between a client and a server
+    // (e.g. a WebRTC-based application and an RFC 5766 TURN server), where
+    // clients are not required to provide a certificate during handshake.
+    // Accordingly, we must disable client authentication here.
+    ssl_stream_adapter_->set_client_auth_enabled(false);
+
+    ssl_stream_adapter_->SetIdentity(ssl_identity_->GetReference());
+
+    // Set a bogus peer certificate digest.
+    unsigned char digest[20];
+    size_t digest_len = sizeof(digest);
+    ssl_stream_adapter_->SetPeerCertificateDigest(rtc::DIGEST_SHA_1, digest,
+        digest_len);
+
+    ssl_stream_adapter_->StartSSLWithPeer();
+
+    ssl_stream_adapter_->SignalEvent.connect(this,
+        &SSLAdapterTestDummyServer::OnSSLStreamAdapterEvent);
+  }
+
+  void OnSSLStreamAdapterEvent(rtc::StreamInterface* stream, int sig, int err) {
+    if (sig & rtc::SE_READ) {
+      char buffer[4096] = "";
+
+      size_t read;
+      int error;
+
+      // Read data received from the client and store it in our internal
+      // buffer.
+      rtc::StreamResult r = stream->Read(buffer,
+          sizeof(buffer) - 1, &read, &error);
+      if (r == rtc::SR_SUCCESS) {
+        buffer[read] = '\0';
+
+        LOG(LS_INFO) << "Server received '" << buffer << "'";
+
+        data_ += buffer;
+      }
+    }
+  }
+
+ private:
+  const rtc::SSLMode ssl_mode_;
+
+  rtc::scoped_ptr<rtc::AsyncSocket> server_socket_;
+  rtc::scoped_ptr<rtc::SSLStreamAdapter> ssl_stream_adapter_;
+
+  rtc::scoped_ptr<rtc::SSLIdentity> ssl_identity_;
+
+  std::string data_;
+};
+
+class SSLAdapterTestBase : public testing::Test,
+                           public sigslot::has_slots<> {
+ public:
+  explicit SSLAdapterTestBase(const rtc::SSLMode& ssl_mode)
+      : ssl_mode_(ssl_mode),
+        ss_scope_(new rtc::VirtualSocketServer(NULL)),
+        server_(new SSLAdapterTestDummyServer(ssl_mode_)),
+        client_(new SSLAdapterTestDummyClient(ssl_mode_)),
+        handshake_wait_(kTimeout) {
+  }
+
+  static void SetUpTestCase() {
+    rtc::InitializeSSL();
+  }
+
+  static void TearDownTestCase() {
+    rtc::CleanupSSL();
+  }
+
+  void SetHandshakeWait(int wait) {
+    handshake_wait_ = wait;
+  }
+
+  void TestHandshake(bool expect_success) {
+    int rv;
+
+    // The initial state is CS_CLOSED
+    ASSERT_EQ(rtc::AsyncSocket::CS_CLOSED, client_->GetState());
+
+    rv = client_->Connect(server_->GetHostname(), server_->GetAddress());
+    ASSERT_EQ(0, rv);
+
+    // Now the state should be CS_CONNECTING
+    ASSERT_EQ(rtc::AsyncSocket::CS_CONNECTING, client_->GetState());
+
+    if (expect_success) {
+      // If expecting success, the client should end up in the CS_CONNECTED
+      // state after handshake.
+      EXPECT_EQ_WAIT(rtc::AsyncSocket::CS_CONNECTED, client_->GetState(),
+          handshake_wait_);
+
+      LOG(LS_INFO) << GetSSLProtocolName(ssl_mode_) << " handshake complete.";
+
+    } else {
+      // On handshake failure the client should end up in the CS_CLOSED state.
+      EXPECT_EQ_WAIT(rtc::AsyncSocket::CS_CLOSED, client_->GetState(),
+          handshake_wait_);
+
+      LOG(LS_INFO) << GetSSLProtocolName(ssl_mode_) << " handshake failed.";
+    }
+  }
+
+  void TestTransfer(const std::string& message) {
+    int rv;
+
+    rv = client_->Send(message);
+    ASSERT_EQ(static_cast<int>(message.length()), rv);
+
+    // The server should have received the client's message.
+    EXPECT_EQ_WAIT(message, server_->GetReceivedData(), kTimeout);
+
+    rv = server_->Send(message);
+    ASSERT_EQ(static_cast<int>(message.length()), rv);
+
+    // The client should have received the server's message.
+    EXPECT_EQ_WAIT(message, client_->GetReceivedData(), kTimeout);
+
+    LOG(LS_INFO) << "Transfer complete.";
+  }
+
+ private:
+  const rtc::SSLMode ssl_mode_;
+
+  const rtc::SocketServerScope ss_scope_;
+
+  rtc::scoped_ptr<SSLAdapterTestDummyServer> server_;
+  rtc::scoped_ptr<SSLAdapterTestDummyClient> client_;
+
+  int handshake_wait_;
+};
+
+class SSLAdapterTestTLS : public SSLAdapterTestBase {
+ public:
+  SSLAdapterTestTLS() : SSLAdapterTestBase(rtc::SSL_MODE_TLS) {}
+};
+
+
+#if SSL_USE_OPENSSL
+
+// Basic tests: TLS
+
+// Test that handshake works
+TEST_F(SSLAdapterTestTLS, TestTLSConnect) {
+  TestHandshake(true);
+}
+
+// Test transfer between client and server
+TEST_F(SSLAdapterTestTLS, TestTLSTransfer) {
+  TestHandshake(true);
+  TestTransfer("Hello, world!");
+}
+
+#endif  // SSL_USE_OPENSSL
+
diff --git a/base/sslstreamadapter.h b/base/sslstreamadapter.h
index ffe6b2f..ea966c5 100644
--- a/base/sslstreamadapter.h
+++ b/base/sslstreamadapter.h
@@ -48,11 +48,15 @@
   static SSLStreamAdapter* Create(StreamInterface* stream);
 
   explicit SSLStreamAdapter(StreamInterface* stream)
-      : StreamAdapterInterface(stream), ignore_bad_cert_(false) { }
+      : StreamAdapterInterface(stream), ignore_bad_cert_(false),
+        client_auth_enabled_(true) { }
 
   void set_ignore_bad_cert(bool ignore) { ignore_bad_cert_ = ignore; }
   bool ignore_bad_cert() const { return ignore_bad_cert_; }
 
+  void set_client_auth_enabled(bool enabled) { client_auth_enabled_ = enabled; }
+  bool client_auth_enabled() const { return client_auth_enabled_; }
+
   // Specify our SSL identity: key and certificate. Mostly this is
   // only used in the peer-to-peer mode (unless we actually want to
   // provide a client certificate to a server).
@@ -151,10 +155,16 @@
   static bool HaveDtlsSrtp();
   static bool HaveExporter();
 
+ private:
   // If true, the server certificate need not match the configured
   // server_name, and in fact missing certificate authority and other
   // verification errors are ignored.
   bool ignore_bad_cert_;
+
+  // If true (default), the client is required to provide a certificate during
+  // handshake. If no certificate is given, handshake fails. This applies to
+  // server mode only.
+  bool client_auth_enabled_;
 };
 
 }  // namespace rtc