blob: f6321ec2464caa04979a9426c9f54de78ea21f58 [file] [log] [blame]
/*
* Copyright (C) 2017 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dns_tls_frontend.h"
#include <netdb.h>
#include <stdio.h>
#include <unistd.h>
#include <sys/poll.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <arpa/inet.h>
#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/ssl.h>
#include <unistd.h>
#define LOG_TAG "DnsTlsFrontend"
#include <log/log.h>
#include <netdutils/SocketOption.h>
using android::netdutils::enableSockopt;
namespace {
// Copied from DnsTlsTransport.
bool getSPKIDigest(const X509* cert, std::vector<uint8_t>* out) {
int spki_len = i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), NULL);
unsigned char spki[spki_len];
unsigned char* temp = spki;
if (spki_len != i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), &temp)) {
ALOGE("SPKI length mismatch");
return false;
}
out->resize(test::SHA256_SIZE);
unsigned int digest_len = 0;
int ret = EVP_Digest(spki, spki_len, out->data(), &digest_len, EVP_sha256(), NULL);
if (ret != 1) {
ALOGE("Server cert digest extraction failed");
return false;
}
if (digest_len != out->size()) {
ALOGE("Wrong digest length: %d", digest_len);
return false;
}
return true;
}
// strerror_r() comes in two incompatible flavours, and which one is declared depends on
// feature macros (_GNU_SOURCE is defined by default for C++ on Linux/Android):
//  - XSI/POSIX: returns int (0 on success) and writes the message into |buf|;
//  - GNU:       returns char* pointing at the message, which need not be |buf|.
// Overload resolution on the return type picks the right handler. Only one overload is used in
// any given build, hence [[gnu::unused]] to keep -Wunused-function quiet.
[[gnu::unused]] std::string strerrorResult(int rc, const char* buf) {
    return rc == 0 ? std::string(buf) : std::string();
}
[[gnu::unused]] std::string strerrorResult(const char* msg, const char* /* buf */) {
    return msg != nullptr ? std::string(msg) : std::string();
}
// Returns the text describing the current errno value, or "" if it cannot be described.
// errno is restored before returning, so evaluating errno and errno2str() in either order
// within one expression (as APLOGI below does) yields consistent output.
std::string errno2str() {
    const int saved_errno = errno;
    char error_msg[512] = { 0 };
    std::string result =
            strerrorResult(strerror_r(saved_errno, error_msg, sizeof(error_msg)), error_msg);
    errno = saved_errno;
    return result;
}
// ALOGI wrapper that appends ": [<errno>] <errno description>" to |fmt|. Requires at least one
// format argument after |fmt|.
#define APLOGI(fmt, ...) ALOGI(fmt ": [%d] %s", __VA_ARGS__, errno, errno2str().c_str())
// Formats the host part of |sa| numerically (e.g. "127.0.0.1" or "::1") via getnameinfo()
// with NI_NUMERICHOST, so no DNS lookup is performed. Returns an empty string if getnameinfo()
// reports an error.
std::string addr2str(const sockaddr* sa, socklen_t sa_len) {
    char numeric_host[NI_MAXHOST] = {};
    const int gai_error = getnameinfo(sa, sa_len,
                                      numeric_host, sizeof(numeric_host),
                                      nullptr, 0,
                                      NI_NUMERICHOST);
    return (gai_error != 0) ? std::string() : std::string(numeric_host);
}
// Generates a fresh 2048-bit RSA key pair with public exponent RSA_F4 and wraps it in an
// EVP_PKEY. Every failure path logs which OpenSSL call failed and returns nullptr; any
// intermediate objects are released by their UniquePtr owners.
bssl::UniquePtr<EVP_PKEY> make_private_key() {
    constexpr int kKeyBits = 2048;

    bssl::UniquePtr<BIGNUM> exponent(BN_new());
    if (exponent == nullptr) {
        ALOGE("BN_new failed");
        return nullptr;
    }
    if (BN_set_word(exponent.get(), RSA_F4) == 0) {
        ALOGE("BN_set_word failed");
        return nullptr;
    }

    bssl::UniquePtr<RSA> rsa_key(RSA_new());
    if (rsa_key == nullptr) {
        ALOGE("RSA_new failed");
        return nullptr;
    }
    if (RSA_generate_key_ex(rsa_key.get(), kKeyBits, exponent.get(), nullptr) == 0) {
        ALOGE("RSA_generate_key_ex failed");
        return nullptr;
    }

    bssl::UniquePtr<EVP_PKEY> pkey(EVP_PKEY_new());
    if (pkey == nullptr) {
        ALOGE("EVP_PKEY_new failed");
        return nullptr;
    }
    if (EVP_PKEY_assign_RSA(pkey.get(), rsa_key.get()) == 0) {
        ALOGE("EVP_PKEY_assign_RSA failed");
        return nullptr;
    }
    // After a successful assign, |pkey| owns the RSA object; give up our reference without
    // freeing it. (On the failure path above, |rsa_key| still owns and frees it.)
    static_cast<void>(rsa_key.release());
    return pkey;
}
// Creates a certificate carrying the public half of |privkey| and signs it with |parent_key|
// using SHA-256 (pass the same key twice for a self-signed certificate). The certificate has
// serial number 1 and is valid from now for one hour; no subject or issuer name is set.
// Returns nullptr, after logging the failing call, if any step fails.
bssl::UniquePtr<X509> make_cert(EVP_PKEY* privkey, EVP_PKEY* parent_key) {
    bssl::UniquePtr<X509> cert(X509_new());
    if (!cert) {
        ALOGE("X509_new failed");
        return nullptr;
    }
    if (!ASN1_INTEGER_set(X509_get_serialNumber(cert.get()), 1)) {
        ALOGE("ASN1_INTEGER_set failed");
        return nullptr;
    }
    // Set one hour expiration.
    if (!X509_gmtime_adj(X509_get_notBefore(cert.get()), 0) ||
        !X509_gmtime_adj(X509_get_notAfter(cert.get()), 60 * 60)) {
        ALOGE("X509_gmtime_adj failed");
        return nullptr;
    }
    // Without a public key the certificate would be useless, so treat this as fatal rather
    // than signing an incomplete certificate.
    if (!X509_set_pubkey(cert.get(), privkey)) {
        ALOGE("X509_set_pubkey failed");
        return nullptr;
    }
    if (!X509_sign(cert.get(), parent_key, EVP_sha256())) {
        ALOGE("X509_sign failed");
        return nullptr;
    }
    return cert;
}
}
namespace test {
// Creates a TCP socket bound to the first usable address getaddrinfo() returns for
// |address|:|service| (SO_REUSEPORT and SO_REUSEADDR enabled) and puts it into the listening
// state. Returns the socket fd, or -1 after logging; no descriptor is leaked on failure.
static int makeListeningSocket(const std::string& address, const std::string& service) {
    // Designators listed in struct addrinfo declaration order (ai_flags, ai_family,
    // ai_socktype), which C++ requires.
    addrinfo hints{
        .ai_flags = AI_PASSIVE,
        .ai_family = AF_UNSPEC,
        .ai_socktype = SOCK_STREAM
    };
    addrinfo* res;
    const int rv = getaddrinfo(address.c_str(), service.c_str(), &hints, &res);
    if (rv) {
        ALOGE("frontend getaddrinfo(%s, %s) failed: %s", address.c_str(),
              service.c_str(), gai_strerror(rv));
        return -1;
    }
    int s = -1;
    for (const addrinfo* ai = res; ai; ai = ai->ai_next) {
        s = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
        if (s < 0) continue;
        enableSockopt(s, SOL_SOCKET, SO_REUSEPORT);
        enableSockopt(s, SOL_SOCKET, SO_REUSEADDR);
        if (bind(s, ai->ai_addr, ai->ai_addrlen)) {
            APLOGI("bind failed for socket %d", s);
            close(s);
            s = -1;
            continue;
        }
        std::string host_str = addr2str(ai->ai_addr, ai->ai_addrlen);
        ALOGI("bound to TCP %s:%s", host_str.c_str(), service.c_str());
        break;
    }
    freeaddrinfo(res);
    if (s < 0) {
        ALOGE("server socket creation failed");
        return -1;
    }
    if (listen(s, 1) < 0) {
        ALOGE("listen failed");
        close(s);
        return -1;
    }
    return s;
}

// Creates a UDP socket connected to the first address getaddrinfo() returns for
// |address|:|service|. Returns the socket fd, or -1 after logging; the addrinfo list and the
// socket are always released on failure.
static int makeBackendSocket(const std::string& address, const std::string& service) {
    addrinfo hints{
        .ai_family = AF_UNSPEC,
        .ai_socktype = SOCK_DGRAM
    };
    addrinfo* res;
    const int rv = getaddrinfo(address.c_str(), service.c_str(), &hints, &res);
    if (rv) {
        ALOGE("backend getaddrinfo(%s, %s) failed: %s", address.c_str(),
              service.c_str(), gai_strerror(rv));
        return -1;
    }
    int s = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
    if (s < 0) {
        ALOGE("backend socket creation failed");
    } else if (connect(s, res->ai_addr, res->ai_addrlen)) {
        // handleOneRequest() uses send()/recv() without an address, so an unconnected
        // backend socket could never forward a query.
        APLOGI("connect failed for backend socket %d", s);
        close(s);
        s = -1;
    }
    freeaddrinfo(res);
    return s;
}

// Starts the DNS-over-TLS frontend:
//  1. creates the TLS context and installs a freshly generated certificate chain of
//     |chain_length_| RSA keys, recording the SPKI digest of its middle certificate in
//     |fingerprint_|;
//  2. listens on TCP |listen_address_|:|listen_service_| for clients;
//  3. connects a UDP socket to the plaintext backend at |backend_address_|:|backend_service_|;
//  4. spawns |handler_thread_| running requestHandler().
// Returns false, after logging, if any step fails. In that case no sockets are left open
// and |socket_|/|backend_socket_| are not modified.
bool DnsTlsFrontend::startServer() {
    SSL_load_error_strings();
    OpenSSL_add_ssl_algorithms();
    ctx_.reset(SSL_CTX_new(TLS_server_method()));
    if (!ctx_) {
        ALOGE("SSL context creation failed");
        return false;
    }
    SSL_CTX_set_ecdh_auto(ctx_.get(), 1);

    if (chain_length_ < 1) {
        ALOGE("Invalid certificate chain length: %d", chain_length_);
        return false;
    }
    // Make certificate chain: certs[i] carries keys[i] and is signed by keys[i + 1]; the last
    // certificate is self-signed.
    std::vector<bssl::UniquePtr<EVP_PKEY>> keys(chain_length_);
    for (int i = 0; i < chain_length_; ++i) {
        keys[i] = make_private_key();
        if (!keys[i]) return false;  // make_private_key() has already logged the cause.
    }
    std::vector<bssl::UniquePtr<X509>> certs(chain_length_);
    for (int i = 0; i < chain_length_; ++i) {
        const int next = std::min(i + 1, chain_length_ - 1);
        certs[i] = make_cert(keys[i].get(), keys[next].get());
        if (!certs[i]) return false;  // make_cert() has already logged the cause.
    }
    // Install certificate chain.
    if (SSL_CTX_use_certificate(ctx_.get(), certs[0].get()) <= 0) {
        ALOGE("SSL_CTX_use_certificate failed");
        return false;
    }
    if (SSL_CTX_use_PrivateKey(ctx_.get(), keys[0].get()) <= 0) {
        ALOGE("SSL_CTX_use_PrivateKey failed");
        return false;
    }
    for (int i = 1; i < chain_length_; ++i) {
        if (SSL_CTX_add1_chain_cert(ctx_.get(), certs[i].get()) != 1) {
            ALOGE("SSL_CTX_add1_chain_cert failed");
            return false;
        }
    }
    // Report the fingerprint of the "middle" cert. For N = 2, this is the root.
    const int fp_index = chain_length_ / 2;
    if (!getSPKIDigest(certs[fp_index].get(), &fingerprint_)) {
        ALOGE("getSPKIDigest failed");
        return false;
    }

    // Set up TCP server socket for clients.
    const int s = makeListeningSocket(listen_address_, listen_service_);
    if (s < 0) return false;

    // Set up UDP client socket to backend.
    const int backend = makeBackendSocket(backend_address_, backend_service_);
    if (backend < 0) {
        close(s);
        return false;
    }

    // Publish the descriptors only once both exist, so a failed start never leaves
    // |socket_| pointing at a listening socket with no handler thread behind it.
    socket_ = s;
    backend_socket_ = backend;
    {
        std::lock_guard<std::mutex> lock(update_mutex_);
        handler_thread_ = std::thread(&DnsTlsFrontend::requestHandler, this);
    }
    ALOGI("server started successfully");
    return true;
}
void DnsTlsFrontend::requestHandler() {
ALOGD("Request handler started");
struct pollfd fds[1] = {{ .fd = socket_, .events = POLLIN }};
while (!terminate_) {
int poll_code = poll(fds, 1, 10 /* ms */);
if (poll_code == 0) {
// Timeout. Poll again.
continue;
} else if (poll_code < 0) {
ALOGW("Poll failed with error %d", poll_code);
// Error.
break;
}
sockaddr_storage addr;
socklen_t len = sizeof(addr);
ALOGD("Trying to accept a client");
int client = accept(socket_, reinterpret_cast<sockaddr*>(&addr), &len);
ALOGD("Got client socket %d", client);
if (client < 0) {
// Stop
break;
}
bssl::UniquePtr<SSL> ssl(SSL_new(ctx_.get()));
SSL_set_fd(ssl.get(), client);
ALOGD("Doing SSL handshake");
bool success = false;
if (SSL_accept(ssl.get()) <= 0) {
ALOGI("SSL negotiation failure");
} else {
ALOGD("SSL handshake complete");
success = handleOneRequest(ssl.get());
}
close(client);
if (success) {
// Increment queries_ as late as possible, because it represents
// a query that is fully processed, and the response returned to the
// client, including cleanup actions.
++queries_;
}
}
ALOGD("Request handler terminating");
}
// Serves a single query on an established TLS session using DNS-over-TCP framing:
//  - reads a 2-byte big-endian length prefix and then that many query bytes;
//  - forwards the query to the backend over |backend_socket_| with send();
//  - reads one reply (at most 4096 bytes) with recv();
//  - writes the reply back to the client preceded by its own 2-byte length prefix.
// Returns true only if every step succeeded. A zero-length query is rejected, because
// forwarding an empty datagram would leave recv() waiting for a reply that never comes.
bool DnsTlsFrontend::handleOneRequest(SSL* ssl) {
    // SSL_read() may return fewer bytes than requested (e.g. when data spans TLS records), so
    // keep reading until |len| bytes have arrived. Returns false on EOF or error.
    const auto readFully = [ssl](uint8_t* buf, size_t len) {
        size_t done = 0;
        while (done < len) {
            const int ret = SSL_read(ssl, buf + done, static_cast<int>(len - done));
            if (ret <= 0) return false;
            done += static_cast<size_t>(ret);
        }
        return true;
    };

    uint8_t queryHeader[2];
    if (!readFully(queryHeader, sizeof(queryHeader))) {
        ALOGI("Not enough header bytes");
        return false;
    }
    const uint16_t qlen = (queryHeader[0] << 8) | queryHeader[1];
    if (qlen == 0) {
        ALOGI("Empty query");
        return false;
    }
    // Heap buffer rather than a (non-standard) stack VLA of up to 64 KiB.
    std::vector<uint8_t> query(qlen);
    if (!readFully(query.data(), query.size())) {
        ALOGI("Error while reading query");
        return false;
    }
    const ssize_t sent = send(backend_socket_, query.data(), query.size(), 0);
    if (sent != static_cast<ssize_t>(query.size())) {
        ALOGI("Failed to send query");
        return false;
    }
    constexpr int max_size = 4096;
    uint8_t recv_buffer[max_size];
    const int rlen = recv(backend_socket_, recv_buffer, max_size, 0);
    if (rlen <= 0) {
        ALOGI("Failed to receive response");
        return false;
    }
    const uint8_t responseHeader[2] = {
        static_cast<uint8_t>(rlen >> 8),
        static_cast<uint8_t>(rlen & 0xff)
    };
    if (SSL_write(ssl, responseHeader, 2) != 2) {
        ALOGI("Failed to write response header");
        return false;
    }
    if (SSL_write(ssl, recv_buffer, rlen) != rlen) {
        ALOGI("Failed to write response body");
        return false;
    }
    return true;
}
// Stops the frontend: sets |terminate_| so requestHandler() exits, joins |handler_thread_|,
// closes both sockets, and clears the TLS context and fingerprint so the object can be
// restarted. Returns false without changing anything if running() reports the server as
// stopped, or if |terminate_| is unexpectedly already set.
bool DnsTlsFrontend::stopServer() {
    std::lock_guard<std::mutex> lock(update_mutex_);
    if (!running()) {
        ALOGI("server not running");
        return false;
    }
    if (terminate_) {
        ALOGI("LOGIC ERROR");
        return false;
    }
    ALOGI("stopping frontend");
    terminate_ = true;
    // join() on a thread that was never started throws std::system_error, which could happen
    // after a partially failed startServer(); only join a real thread.
    if (handler_thread_.joinable()) {
        handler_thread_.join();
    }
    if (socket_ >= 0) close(socket_);
    if (backend_socket_ >= 0) close(backend_socket_);
    terminate_ = false;
    socket_ = -1;
    backend_socket_ = -1;
    ctx_.reset();
    fingerprint_.clear();
    ALOGI("frontend stopped successfully");
    return true;
}
// Waits until at least |number| queries have been fully handled. |queries_| is sampled every
// 20 ms, for timeoutMs / 20 + 1 samples in total. Returns true as soon as the target is seen
// (after one extra sleep), or false once all samples are exhausted.
bool DnsTlsFrontend::waitForQueries(int number, int timeoutMs) const {
    constexpr int kPollIntervalMs = 20;
    const int maxIteration = timeoutMs / kPollIntervalMs;
    int iteration = 0;
    while (iteration <= maxIteration) {
        // Sample first, then always sleep one interval, even when the target is already
        // reached. This gives any immediate post-query actions the client may take (such as
        // marking this server as reachable during validation) time to finish.
        const bool reached = (queries_ >= number);
        usleep(kPollIntervalMs * 1000);
        if (reached) {
            // Reporting how much of the budget was used helps size timeouts for slow machines.
            ALOGD("Query arrived in %d/%d of allotted time", iteration, maxIteration);
            return true;
        }
        ++iteration;
    }
    return false;
}
} // namespace test