blob: db8543a36aab5fa68c26cf67d270b09411c819f7 [file] [log] [blame]
/*
* Copyright (C) 2019 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <https/SSLSocket.h>
#include <https/SafeCallbackable.h>
#include <https/Support.h>
#include <android-base/logging.h>
#include <sstream>
#include <sys/socket.h>
// static
// Library-wide initialization: registers the SSL/TLS algorithms and loads
// OpenSSL's human-readable error strings. Presumably meant to be called once
// before any SSLSocket is constructed — TODO confirm with callers.
void SSLSocket::Init() {
    (void)SSL_library_init();
    SSL_load_error_strings();
}
// static
// Creates a new SSL_CTX using the version-flexible SSLv23_method(), with the
// obsolete SSLv2 and SSLv3 protocols explicitly disabled.
//
// Returns an owned SSL_CTX*; the caller must release it with SSL_CTX_free()
// (the SSLSocket constructor wraps it in mCtx with SSL_CTX_free as deleter).
// Aborts via CHECK if OpenSSL cannot allocate the context.
SSL_CTX *SSLSocket::CreateSSLContext() {
    SSL_CTX *ctx = SSL_CTX_new(SSLv23_method());

    // SSL_CTX_new() reports failure by returning nullptr. Fail loudly here
    // rather than dereferencing null inside SSL_CTX_set_options()/SSL_new().
    CHECK(ctx != nullptr) << "SSL_CTX_new failed";

    /* Recommended to avoid SSLv2 & SSLv3 */
    SSL_CTX_set_options(
            ctx, SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3);

    return ctx;
}
// Common constructor shared by the server and client constructors below.
//
// Creates a private SSL_CTX and SSL object for this connection and attaches
// a pair of memory BIOs to the SSL object:
//   - mBioR receives ciphertext that this class reads from |sock|;
//   - mBioW collects ciphertext produced by OpenSSL, which this class then
//     queues for sending on |sock|.
// |mode| selects whether the handshake is driven as a server (ACCEPT) or as a
// client (CONNECT). |flags| is stored in mFlags.
SSLSocket::SSLSocket(
        std::shared_ptr<RunLoop> rl, Mode mode, int sock, uint32_t flags)
    : BufferedSocket(rl, sock),
      mMode(mode),
      mFlags(flags),
      mCtx(CreateSSLContext(), SSL_CTX_free),
      mSSL(SSL_new(mCtx.get()), SSL_free),
      mBioR(BIO_new(BIO_s_mem())),
      mBioW(BIO_new(BIO_s_mem())),
      mEOS(false),
      mFinalErrno(0),
      mRecvPending(false),
      mRecvCallback(nullptr),
      mSendPending(false),
      mFlushFn(nullptr) {
    // The OpenSSL allocators used above return nullptr on failure instead of
    // throwing. Catch that here rather than crashing on first use later.
    CHECK(mCtx.get() != nullptr);
    CHECK(mSSL.get() != nullptr);
    CHECK(mBioR != nullptr);
    CHECK(mBioW != nullptr);

    if (mMode == Mode::ACCEPT) {
        SSL_set_accept_state(mSSL.get());
    } else {
        SSL_set_connect_state(mSSL.get());
    }

    // Per the OpenSSL SSL_set_bio() documentation, ownership of both BIOs
    // passes to mSSL; that is why the destructor only nulls these pointers.
    SSL_set_bio(mSSL.get(), mBioR, mBioW);
}
// Loads the PEM-encoded certificate stored at |path| into this connection's
// SSL object. Returns true iff OpenSSL reported success (return value 1).
bool SSLSocket::useCertificate(const std::string &path) {
    const int rc = SSL_use_certificate_file(
            mSSL.get(), path.c_str(), SSL_FILETYPE_PEM);
    return rc == 1;
}
// Loads the PEM-encoded private key stored at |path| into this connection's
// SSL object. On success, also runs SSL_check_private_key(), which checks
// that the key is consistent with the certificate already set on mSSL.
// Returns true only if both steps succeed.
bool SSLSocket::usePrivateKey(const std::string &path) {
    if (SSL_use_PrivateKey_file(
                mSSL.get(), path.c_str(), SSL_FILETYPE_PEM) != 1) {
        return false;
    }
    return SSL_check_private_key(mSSL.get()) == 1;
}
// Adds the PEM file at |path| as the set of trusted CA certificates used to
// verify the peer. The certificates go into this socket's own SSL_CTX (mCtx),
// which CreateSSLContext() creates per socket. Returns true on success.
bool SSLSocket::useTrustedCertificates(const std::string &path) {
    const char *caFile = path.c_str();
    const char *caDir = nullptr;  // Only a single CA file, no CA directory.
    return SSL_CTX_load_verify_locations(mCtx.get(), caFile, caDir) == 1;
}
// Server-side constructor: runs the handshake in ACCEPT mode and presents the
// certificate / private key loaded from the given PEM files. Aborts (CHECK)
// if either file cannot be loaded, or if FLAG_DONT_CHECK_PEER_CERTIFICATE
// was passed, since that flag only applies to clients.
SSLSocket::SSLSocket(
        std::shared_ptr<RunLoop> rl,
        int sock,
        const std::string &certificate_pem_path,
        const std::string &private_key_pem_path,
        uint32_t flags)
    : SSLSocket(rl, Mode::ACCEPT, sock, flags) {
    // This flag makes no sense for a server.
    CHECK(!(mFlags & FLAG_DONT_CHECK_PEER_CERTIFICATE));

    // The certificate must be installed before the key so that
    // usePrivateKey() can check that the two match.
    CHECK(useCertificate(certificate_pem_path) && usePrivateKey(private_key_pem_path));
}
// Client-side constructor: runs the handshake in CONNECT mode. Unless
// FLAG_DONT_CHECK_PEER_CERTIFICATE is set, |trusted_pem_path| is required and
// its certificates become the trust anchors used to validate the server.
// Aborts (CHECK) if the path is missing or cannot be loaded.
SSLSocket::SSLSocket(
        std::shared_ptr<RunLoop> rl,
        int sock,
        uint32_t flags,
        const std::optional<std::string> &trusted_pem_path)
    : SSLSocket(rl, Mode::CONNECT, sock, flags) {
    if (mFlags & FLAG_DONT_CHECK_PEER_CERTIFICATE) {
        // Caller explicitly opted out of verifying the server.
        return;
    }

    CHECK(trusted_pem_path.has_value());
    CHECK(useTrustedCertificates(*trusted_pem_path));
}
// Starts a TLS shutdown on the SSL object before it is released.
// NOTE(review): whatever SSL_shutdown() writes into mBioW (e.g. a
// close_notify alert) is not queued to the socket here — confirm intended.
SSLSocket::~SSLSocket() {
    SSL_shutdown(mSSL.get());

    // Both BIOs were handed to mSSL via SSL_set_bio() in the constructor, so
    // mSSL owns and frees them; only drop our non-owning aliases.
    mBioR = nullptr;
    mBioW = nullptr;
}
// Arranges for |fn| to run once decrypted data (or EOS/error) is available.
//
// If OpenSSL already has plaintext ready (SSL_peek succeeds), |fn| runs
// synchronously. Otherwise |fn| is stored as the single pending receive
// callback, and a socket-readable watch for handleIncomingData() is
// registered unless one is already outstanding (mRecvPending).
void SSLSocket::postRecv(RunLoop::AsyncFunction fn) {
    char probe[128];
    const bool haveData = SSL_peek(mSSL.get(), probe, sizeof(probe)) > 0;
    if (haveData) {
        fn();
        return;
    }

    // Only one receive callback may be pending at a time.
    CHECK(mRecvCallback == nullptr);
    mRecvCallback = fn;

    if (mRecvPending) {
        // handleIncomingData() is already scheduled and will deliver it.
        return;
    }
    mRecvPending = true;
    runLoop()->postSocketRecv(
            fd(), makeSafeCallback(this, &SSLSocket::handleIncomingData));
}
// Run-loop callback invoked when the underlying socket is readable.
//
// Reads one chunk of ciphertext from the socket and writes it into mBioR. The
// next step depends on the handshake state:
//   - While the TLS handshake is unfinished, it advances the handshake with
//     SSL_accept()/SSL_connect(). On completion it verifies the peer and only
//     then releases plaintext buffered in mOutBufferPlain.
//   - Otherwise it uses SSL_peek() to decide whether application data is
//     ready. It then fires the pending receive callback, or re-arms itself
//     for more input.
// EOS and errors are recorded in mEOS/mFinalErrno and reported through the
// pending receive callback.
void SSLSocket::handleIncomingData() {
    mRecvPending = false;

    uint8_t buffer[1024];
    ssize_t len;
    do {
        len = ::recv(fd(), buffer, sizeof(buffer), 0);
    } while (len < 0 && errno == EINTR);

    if (len <= 0) {
        // recv() returned 0 (orderly shutdown) or a non-EINTR error.
        mEOS = true;
        mFinalErrno = (len < 0) ? errno : 0;
        sendRecvCallback();
        return;
    }

    size_t offset = 0;
    while (len > 0) {
        // Feed the ciphertext to OpenSSL through the read memory BIO.
        int n = BIO_write(mBioR, &buffer[offset], len);
        CHECK_GT(n, 0);
        offset += n;
        len -= n;

        if (!SSL_is_init_finished(mSSL.get())) {
            if (mMode == Mode::ACCEPT) {
                n = SSL_accept(mSSL.get());
            } else {
                n = SSL_connect(mSSL.get());
            }

            auto err = SSL_get_error(mSSL.get(), n);
            switch (err) {
                case SSL_ERROR_WANT_READ:
                {
                    CHECK_EQ(len, 0);
                    // Send our part of the handshake, then wait for more.
                    queueOutputDataFromSSL();
                    mRecvPending = true;
                    runLoop()->postSocketRecv(
                            fd(),
                            makeSafeCallback(
                                this, &SSLSocket::handleIncomingData));
                    return;
                }
                case SSL_ERROR_WANT_WRITE:
                {
                    CHECK_EQ(len, 0);
                    mRecvPending = true;
                    runLoop()->postSocketRecv(
                            fd(),
                            makeSafeCallback(
                                this, &SSLSocket::handleIncomingData));
                    return;
                }
                case SSL_ERROR_NONE:
                    break;
                case SSL_ERROR_SYSCALL:
                default:
                {
                    // This is where we end up if the client doesn't trust us.
                    mEOS = true;
                    mFinalErrno = ECONNREFUSED;
                    sendRecvCallback();
                    return;
                }
            }

            CHECK(SSL_is_init_finished(mSSL.get()));

            // Verify the peer BEFORE releasing any plaintext queued while the
            // handshake was in progress. No SSL_CTX_set_verify() call is made,
            // so the handshake itself succeeds regardless of the certificate;
            // this check is the only gate. Draining first would encrypt and
            // send application data to a peer we are about to reject.
            if (!(mFlags & FLAG_DONT_CHECK_PEER_CERTIFICATE)
                    && !isPeerCertificateValid()) {
                mOutBufferPlain.clear();
                mEOS = true;
                mFinalErrno = ECONNREFUSED;
                sendRecvCallback();
                return;
            }

            drainOutputBufferPlain();
        }
    }

    int n = SSL_peek(mSSL.get(), buffer, sizeof(buffer));
    if (n > 0) {
        // Decrypted application data is available.
        sendRecvCallback();
        return;
    }

    auto err = SSL_get_error(mSSL.get(), n);
    switch (err) {
        case SSL_ERROR_WANT_READ:
        {
            queueOutputDataFromSSL();
            // drainOutputBufferPlain() above may already have re-armed the
            // receive; don't register a second one for the same fd.
            if (!mRecvPending) {
                mRecvPending = true;
                runLoop()->postSocketRecv(
                        fd(),
                        makeSafeCallback(this, &SSLSocket::handleIncomingData));
            }
            break;
        }
        case SSL_ERROR_WANT_WRITE:
        {
            if (!mRecvPending) {
                mRecvPending = true;
                runLoop()->postSocketRecv(
                        fd(),
                        makeSafeCallback(this, &SSLSocket::handleIncomingData));
            }
            break;
        }
        case SSL_ERROR_ZERO_RETURN:
        {
            // Peer closed the TLS session cleanly.
            mEOS = true;
            mFinalErrno = 0;
            sendRecvCallback();
            break;
        }
        case SSL_ERROR_NONE:
            break;
        case SSL_ERROR_SYSCALL:
        default:
        {
            // This is where we end up if the client doesn't trust us.
            mEOS = true;
            mFinalErrno = ECONNREFUSED;
            sendRecvCallback();
            break;
        }
    }
}
// Fires the pending receive callback, if any, exactly once.
// mRecvCallback is cleared BEFORE invoking the callback. The callback can
// therefore call postRecv() again, which CHECKs that no callback is pending.
void SSLSocket::sendRecvCallback() {
    auto pending = mRecvCallback;
    mRecvCallback = nullptr;

    if (pending == nullptr) {
        return;
    }
    pending();
}
// sendto() buffers data internally rather than writing to the socket
// directly. The socket is therefore always "ready" from the caller's point
// of view, so |fn| is simply deferred to the next run-loop iteration.
void SSLSocket::postSend(RunLoop::AsyncFunction fn) {
    auto loop = runLoop();
    loop->post(fn);
}
// Reads decrypted application data into |data| (at most |size| bytes).
//
// A TLS stream has no per-datagram source address, so non-null
// |address|/|addressLen| fail with EINVAL. After EOS, returns 0 for a clean
// close, or -1 with errno = mFinalErrno for an error.
ssize_t SSLSocket::recvfrom(
        void *data,
        size_t size,
        sockaddr *address,
        socklen_t *addressLen) {
    const bool wantsAddress = (address != nullptr) || (addressLen != nullptr);
    if (wantsAddress) {
        errno = EINVAL;
        return -1;
    }

    if (mEOS) {
        errno = mFinalErrno;
        if (mFinalErrno == 0) {
            return 0;
        }
        return -1;
    }

    // We should only get here after SSL_peek signaled that there's data to
    // be read.
    const int n = SSL_read(mSSL.get(), data, size);
    CHECK_GT(n, 0);
    return n;
}
// Moves all ciphertext OpenSSL has written into mBioW onto the outgoing
// socket queue (queueOutputData), 1 KiB at a time.
//
// Reading stops when BIO_read() reports no data. An empty memory BIO
// signals this as a "retry" condition (see BIO_s_mem docs). Any other
// failure is treated as fatal.
void SSLSocket::queueOutputDataFromSSL() {
    char chunk[1024];
    for (;;) {
        const int n = BIO_read(mBioW, chunk, sizeof(chunk));
        if (n <= 0) {
            if (!BIO_should_retry(mBioW)) {
                LOG(FATAL) << "Should not be here.";
            }
            break;
        }
        queueOutputData(chunk, n);
    }
}
// Appends |size| bytes of ciphertext to mOutBuffer. If no sendOutputData()
// is scheduled yet (mSendPending), registers one with the run loop. Empty
// writes are ignored.
void SSLSocket::queueOutputData(const void *data, size_t size) {
    if (size == 0) {
        return;
    }

    const size_t oldSize = mOutBuffer.size();
    mOutBuffer.resize(oldSize + size);
    memcpy(mOutBuffer.data() + oldSize, data, size);

    if (mSendPending) {
        // A sendOutputData() is already scheduled and will pick this up.
        return;
    }
    mSendPending = true;
    runLoop()->postSocketSend(
            fd(), makeSafeCallback(this, &SSLSocket::sendOutputData));
}
// Run-loop callback invoked when the socket is writable.
//
// Writes as much of mOutBuffer (ciphertext) to the socket as it accepts
// without blocking, then tries to encrypt more pending plaintext. If
// ciphertext remains, it ensures exactly one sendOutputData() stays
// scheduled. Once everything is written, it fires the postFlush() callback
// (mFlushFn), if any.
void SSLSocket::sendOutputData() {
    mSendPending = false;

    const size_t size = mOutBuffer.size();
    size_t offset = 0;
    while (offset < size) {
        ssize_t n = ::send(
                fd(), mOutBuffer.data() + offset, size - offset, 0);

        if (n < 0) {
            if (errno == EINTR) {
                continue;
            } else if (errno == EAGAIN || errno == EWOULDBLOCK) {
                // Socket buffer full; resume on the next writable event.
                break;
            }
            LOG(FATAL) << "Should not be here.";
        }

        offset += static_cast<size_t>(n);
    }

    mOutBuffer.erase(mOutBuffer.begin(), mOutBuffer.begin() + offset);

    if (!mOutBufferPlain.empty()) {
        // May append ciphertext through queueOutputData(), which itself
        // schedules sendOutputData() and sets mSendPending.
        drainOutputBufferPlain();
    }

    if (!mOutBuffer.empty()) {
        // Re-arm only if queueOutputData() (reached via the drain above) has
        // not already done so. Otherwise two sendOutputData() callbacks would
        // be registered for the same fd.
        if (!mSendPending) {
            mSendPending = true;
            runLoop()->postSocketSend(
                    fd(),
                    makeSafeCallback(this, &SSLSocket::sendOutputData));
        }
        return;
    }

    auto fn = mFlushFn;
    mFlushFn = nullptr;
    if (fn != nullptr) {
        fn();
    }
}
// Queues |size| bytes of plaintext for encryption and transmission, then
// tries to encrypt as much of the queue as possible right away.
//
// Returns |size| on success; all bytes are accepted into the internal
// buffer. A TLS stream has no per-datagram destination, so a non-null
// |addr| or non-zero |addrLen| fails with -1 / errno = EINVAL. After EOS,
// returns 0 for a clean close, or -1 with errno = mFinalErrno.
ssize_t SSLSocket::sendto(
        const void *data,
        size_t size,
        const sockaddr *addr,
        socklen_t addrLen) {
    if (addr || addrLen) {
        // errno values are positive; this previously stored -EINVAL.
        errno = EINVAL;
        return -1;
    }

    if (mEOS) {
        errno = mFinalErrno;
        return (mFinalErrno == 0) ? 0 : -1;
    }

    if (size > 0) {
        // Guarded: &mOutBufferPlain[pos] is out of range when the buffer is
        // empty and nothing is being appended.
        const size_t pos = mOutBufferPlain.size();
        mOutBufferPlain.resize(pos + size);
        memcpy(&mOutBufferPlain[pos], data, size);
    }

    // Still called for empty writes, as before: it may advance a pending
    // handshake.
    drainOutputBufferPlain();

    return static_cast<ssize_t>(size);
}
// Encrypts as much of mOutBufferPlain as possible with SSL_write() and hands
// the resulting ciphertext to queueOutputDataFromSSL().
//
// While the TLS handshake is unfinished, a failed SSL_write() causes the
// handshake to be driven explicitly via SSL_accept()/SSL_connect(). Once it
// completes and the peer certificate is accepted, the write is retried.
// Plaintext that cannot be encrypted yet (OpenSSL wants more I/O) remains
// in mOutBufferPlain for a later call. If the peer is rejected, or a
// post-handshake write fails, buffered plaintext is discarded. The
// connection is then marked EOS with ECONNREFUSED, and the pending receive
// callback is fired.
void SSLSocket::drainOutputBufferPlain() {
    size_t offset = 0;
    const size_t size = mOutBufferPlain.size();

    // Drops the already-encrypted prefix and flushes pending ciphertext. If
    // OpenSSL needs input from the peer, also makes sure a receive is
    // scheduled.
    auto suspend = [&](bool needsRead) {
        mOutBufferPlain.erase(
                mOutBufferPlain.begin(),
                mOutBufferPlain.begin() + offset);
        queueOutputDataFromSSL();
        if (needsRead && !mRecvPending) {
            mRecvPending = true;
            runLoop()->postSocketRecv(
                    fd(),
                    makeSafeCallback(this, &SSLSocket::handleIncomingData));
        }
    };

    // Aborts the connection without sending any more buffered plaintext.
    auto fail = [&]() {
        mOutBufferPlain.clear();
        mEOS = true;
        mFinalErrno = ECONNREFUSED;
        sendRecvCallback();
    };

    while (offset < size) {
        int n = SSL_write(mSSL.get(), &mOutBufferPlain[offset], size - offset);
        if (n > 0) {
            offset += static_cast<size_t>(n);
            continue;
        }

        if (SSL_is_init_finished(mSSL.get())) {
            // Post-handshake write failure. Previously the non-positive |n|
            // was added to the size_t |offset|, corrupting it.
            switch (SSL_get_error(mSSL.get(), n)) {
                case SSL_ERROR_WANT_WRITE:
                    suspend(false /* needsRead */);
                    return;
                case SSL_ERROR_WANT_READ:
                    suspend(true /* needsRead */);
                    return;
                default:
                    fail();
                    return;
            }
        }

        // Handshake still in progress: drive it forward explicitly.
        if (mMode == Mode::ACCEPT) {
            n = SSL_accept(mSSL.get());
        } else {
            n = SSL_connect(mSSL.get());
        }

        auto err = SSL_get_error(mSSL.get(), n);
        switch (err) {
            case SSL_ERROR_WANT_WRITE:
                suspend(false /* needsRead */);
                return;
            case SSL_ERROR_WANT_READ:
                suspend(true /* needsRead */);
                return;
            case SSL_ERROR_SYSCALL:
            {
                // This is where we end up if the client doesn't trust us.
                mEOS = true;
                mFinalErrno = ECONNREFUSED;
                LOG(FATAL) << "Should not be here.";
                return;
            }
            case SSL_ERROR_NONE:
                break;
            default:
                LOG(FATAL) << "Should not be here.";
        }

        CHECK(SSL_is_init_finished(mSSL.get()));
        if (!isPeerCertificateValid()) {
            // Never encrypt buffered plaintext for a peer we reject.
            fail();
            return;
        }

        // Handshake complete and peer accepted. |n| is the handshake's return
        // value, not a byte count, so do NOT advance |offset|: loop back and
        // retry the SSL_write() instead.
    }

    mOutBufferPlain.erase(
            mOutBufferPlain.begin(), mOutBufferPlain.begin() + offset);
    queueOutputDataFromSSL();
}
bool SSLSocket::isPeerCertificateValid() {
if (mMode == Mode::ACCEPT || (mFlags & FLAG_DONT_CHECK_PEER_CERTIFICATE)) {
// For now we won't validate the client if we are the server.
return true;
}
std::unique_ptr<X509, std::function<void(X509 *)>> cert(
SSL_get_peer_certificate(mSSL.get()), X509_free);
if (!cert) {
LOG(ERROR) << "SSLSocket::isPeerCertificateValid no certificate.";
return false;
}
int res = SSL_get_verify_result(mSSL.get());
bool valid = (res == X509_V_OK);
if (!valid) {
LOG(ERROR) << "SSLSocket::isPeerCertificateValid invalid certificate.";
const EVP_MD *digest = EVP_get_digestbyname("sha256");
unsigned char md[EVP_MAX_MD_SIZE];
unsigned int n;
int res = X509_digest(cert.get(), digest, md, &n);
CHECK_EQ(res, 1);
std::stringstream ss;
for (unsigned int i = 0; i < n; ++i) {
if (i > 0) {
ss << ":";
}
auto byte = md[i];
auto nibble = byte >> 4;
ss << (char)((nibble < 10) ? ('0' + nibble) : ('A' + nibble - 10));
nibble = byte & 0x0f;
ss << (char)((nibble < 10) ? ('0' + nibble) : ('A' + nibble - 10));
}
LOG(ERROR)
<< "Server offered certificate w/ fingerprint "
<< ss.str();
}
return valid;
}
// Runs |fn| once all queued ciphertext has been written to the socket.
// With no send scheduled (mSendPending false), |fn| runs immediately.
// Otherwise it is stored in mFlushFn, and sendOutputData() invokes it after
// mOutBuffer drains. Only one flush may be outstanding at a time.
// NOTE(review): plaintext still waiting in mOutBufferPlain (e.g. during the
// handshake) is not tracked by mSendPending — confirm that's acceptable.
void SSLSocket::postFlush(RunLoop::AsyncFunction fn) {
    CHECK(mFlushFn == nullptr);

    if (mSendPending) {
        mFlushFn = fn;
        return;
    }
    fn();
}