Merge "Add support for session tickets"
diff --git a/server/Android.mk b/server/Android.mk
index 93cd1fd..cd67f87 100644
--- a/server/Android.mk
+++ b/server/Android.mk
@@ -19,7 +19,7 @@
###
include $(CLEAR_VARS)
-LOCAL_CFLAGS := -Wall -Werror
+LOCAL_CFLAGS := -Wall -Werror -Wthread-safety
LOCAL_SANITIZE := unsigned-integer-overflow
LOCAL_MODULE := libnetdaidl_static
LOCAL_SHARED_LIBRARIES := \
@@ -36,7 +36,7 @@
include $(CLEAR_VARS)
-LOCAL_CFLAGS := -Wall -Werror
+LOCAL_CFLAGS := -Wall -Werror -Wthread-safety
LOCAL_SANITIZE := unsigned-integer-overflow
LOCAL_MODULE := libnetdaidl
LOCAL_SHARED_LIBRARIES := \
@@ -58,7 +58,7 @@
external/mdnsresponder/mDNSShared \
system/netd/include \
-LOCAL_CPPFLAGS := -Wall -Werror
+LOCAL_CPPFLAGS := -Wall -Werror -Wthread-safety
LOCAL_SANITIZE := unsigned-integer-overflow
LOCAL_MODULE := netd
@@ -147,7 +147,7 @@
###
include $(CLEAR_VARS)
-LOCAL_CFLAGS := -Wall -Werror
+LOCAL_CFLAGS := -Wall -Werror -Wthread-safety
LOCAL_SANITIZE := unsigned-integer-overflow
LOCAL_CLANG := true
LOCAL_MODULE := ndc
@@ -163,7 +163,7 @@
LOCAL_MODULE := netd_unit_test
LOCAL_COMPATIBILITY_SUITE := device-tests
LOCAL_SANITIZE := unsigned-integer-overflow
-LOCAL_CFLAGS := -Wall -Werror -Wunused-parameter
+LOCAL_CFLAGS := -Wall -Werror -Wunused-parameter -Wthread-safety
# Bug: http://b/29823425 Disable -Wvarargs for Clang update to r271374
LOCAL_CFLAGS += -Wno-varargs
diff --git a/server/dns/DnsTlsTransport.cpp b/server/dns/DnsTlsTransport.cpp
index b369022..542b4a9 100644
--- a/server/dns/DnsTlsTransport.cpp
+++ b/server/dns/DnsTlsTransport.cpp
@@ -22,7 +22,6 @@
#include <arpa/nameser.h>
#include <errno.h>
#include <openssl/err.h>
-#include <openssl/ssl.h>
#include <stdlib.h>
#define LOG_TAG "DnsTlsTransport"
@@ -143,6 +142,9 @@
} // namespace
android::base::unique_fd DnsTlsTransport::makeConnectedSocket() const {
+ if (DBG) {
+ ALOGD("%u connecting TCP socket", mMark);
+ }
android::base::unique_fd fd;
int type = SOCK_NONBLOCK | SOCK_CLOEXEC;
switch (mServer.protocol) {
@@ -168,6 +170,11 @@
fd.reset();
}
+ if (!setNonBlocking(fd, false)) {
+ ALOGE("Failed to disable nonblocking status on DNS-over-TLS fd");
+ fd.reset();
+ }
+
return fd;
}
@@ -231,28 +238,45 @@
return make_tie(*this) == make_tie(other);
}
-SSL* DnsTlsTransport::sslConnect(int fd) {
- if (fd < 0) {
- ALOGD("%u makeConnectedSocket() failed with: %s", mMark, strerror(errno));
+// Creates and configures the SSL_CTX that every connection made by this transport shares.
+// Requires sLock; query() calls this before the transport is published in sStore.
+// Returns false if the context cannot be allocated.
+bool DnsTlsTransport::initialize() {
+    mSslCtx.reset(SSL_CTX_new(TLS_method()));
+    if (!mSslCtx) {
+        return false;
+    }
+    // BoringSSL's default cache mode is server-only, and the new-session callback is
+    // not invoked on a client unless client-side caching is enabled. Without this,
+    // newSessionCallback never runs and no session tickets are ever recorded.
+    SSL_CTX_set_session_cache_mode(mSslCtx.get(), SSL_SESS_CACHE_CLIENT);
+    SSL_CTX_sess_set_new_cb(mSslCtx.get(), DnsTlsTransport::newSessionCallback);
+    SSL_CTX_sess_set_remove_cb(mSslCtx.get(), DnsTlsTransport::removeSessionCallback);
+    return true;
+}
+
+bssl::UniquePtr<SSL> DnsTlsTransport::sslConnect(int fd) {
+ // Check TLS context.
+ if (!mSslCtx) {
+ ALOGE("Internal error: context is null in ssl connect");
+ return nullptr;
+ }
+ if (!SSL_CTX_set_max_proto_version(mSslCtx.get(), TLS1_3_VERSION) ||
+ !SSL_CTX_set_min_proto_version(mSslCtx.get(), TLS1_2_VERSION)) {
+ ALOGE("failed to min/max TLS versions");
return nullptr;
}
- // Set up TLS context.
- bssl::UniquePtr<SSL_CTX> ssl_ctx(SSL_CTX_new(TLS_method()));
- if (!SSL_CTX_set_max_proto_version(ssl_ctx.get(), TLS1_3_VERSION) ||
- !SSL_CTX_set_min_proto_version(ssl_ctx.get(), TLS1_1_VERSION)) {
- ALOGD("failed to min/max TLS versions");
- return nullptr;
- }
-
- bssl::UniquePtr<SSL> ssl(SSL_new(ssl_ctx.get()));
+ bssl::UniquePtr<SSL> ssl(SSL_new(mSslCtx.get()));
// This file descriptor is owned by a unique_fd, so don't let libssl close it.
bssl::UniquePtr<BIO> bio(BIO_new_socket(fd, BIO_NOCLOSE));
SSL_set_bio(ssl.get(), bio.get(), bio.get());
bio.release();
- if (!setNonBlocking(fd, false)) {
- ALOGE("Failed to disable nonblocking status on DNS-over-TLS fd");
+ // Add this transport as the 0-index extra data for the socket.
+ // This is used by newSessionCallback.
+ if (SSL_set_ex_data(ssl.get(), 0, this) != 1) {
+ ALOGE("failed to associate SSL socket to transport");
+ return nullptr;
+ }
+
+ // Add this transport as the 0-index extra data for the context.
+ // This is used by removeSessionCallback.
+ if (SSL_CTX_set_ex_data(mSslCtx.get(), 0, this) != 1) {
+ ALOGE("failed to associate SSL context to transport");
return nullptr;
}
@@ -267,6 +291,20 @@
SSL_set_verify(ssl.get(), SSL_VERIFY_PEER, nullptr);
}
+ bssl::UniquePtr<SSL_SESSION> session;
+ {
+ std::lock_guard<std::mutex> guard(sLock);
+ if (!mSessions.empty()) {
+ session = std::move(mSessions.front());
+ mSessions.pop_front();
+ } else if (DBG) {
+ ALOGD("Starting without session ticket.");
+ }
+ }
+ if (session) {
+ SSL_set_session(ssl.get(), session.get());
+ }
+
for (;;) {
if (DBG) {
ALOGD("%u Calling SSL_connect", mMark);
@@ -353,7 +391,66 @@
if (DBG) {
ALOGD("%u handshake complete", mMark);
}
- return ssl.release();
+
+ return ssl;
+}
+
+// static
+// Registered with the SSL_CTX in initialize(); called when a connection yields a session
+// that may be cached for resumption.
+// Returning 1 reports that this callback took ownership of |session|'s reference.
+// Returning 0 leaves ownership with the library.
+int DnsTlsTransport::newSessionCallback(SSL* ssl, SSL_SESSION* session) {
+    if (session == nullptr) {
+        return 0;
+    }
+    if (DBG) {
+        ALOGD("Recording session ticket");
+    }
+    // sslConnect() stores the owning transport in ex_data slot 0 of each SSL it creates.
+    auto* xport = static_cast<DnsTlsTransport*>(SSL_get_ex_data(ssl, 0));
+    if (xport == nullptr) {
+        ALOGE("null transport in new session callback");
+        return 0;
+    }
+    // recordSession() adopts the reference into a bssl::UniquePtr, so ownership is taken.
+    xport->recordSession(session);
+    return 1;
+}
+
+// static
+// Registered with the SSL_CTX in initialize(); called when |session| is removed from
+// |ssl_ctx|'s session cache. Forwards the removal to the owning transport.
+void DnsTlsTransport::removeSessionCallback(SSL_CTX* ssl_ctx, SSL_SESSION* session) {
+    if (DBG) {
+        ALOGD("Removing session ticket");
+    }
+    // sslConnect() stores the owning transport in ex_data slot 0 of the SSL_CTX.
+    auto* xport = static_cast<DnsTlsTransport*>(SSL_CTX_get_ex_data(ssl_ctx, 0));
+    if (xport == nullptr) {
+        ALOGE("null transport in remove session callback");
+        return;
+    }
+    // NOTE(review): removeSession() acquires sLock (a non-recursive std::mutex). If the
+    // library ever invoked this callback while sLock is held, it would self-deadlock.
+    // An example is freeing the SSL_CTX when cleanup() erases a transport.
+    // Confirm this cannot happen for this client-only context.
+    xport->removeSession(session);
+}
+
+// Adopts |session| (taking ownership of the caller's reference) as the most recently seen
+// session. At most five sessions are retained; the oldest is dropped when that is exceeded.
+void DnsTlsTransport::recordSession(SSL_SESSION* session) {
+    constexpr size_t kMaxSessions = 5;
+    std::lock_guard<std::mutex> guard(sLock);
+    mSessions.emplace_front(session);
+    if (mSessions.size() <= kMaxSessions) {
+        return;
+    }
+    if (DBG) {
+        ALOGD("Too many sessions; trimming");
+    }
+    mSessions.pop_back();
+}
+
+// Reacts to a session being dropped from the SSL_CTX cache. Removal of the specific session
+// is not implemented: any non-null |session| discards every cached session of this transport.
+void DnsTlsTransport::removeSession(SSL_SESSION* session) {
+    std::lock_guard<std::mutex> guard(sLock);
+    if (session == nullptr) {
+        return;
+    }
+    // TODO: Consider implementing targeted removal.
+    mSessions.clear();
+}
+
+// Tears down one connection. SSL_shutdown() runs (when a TLS object exists) while the
+// socket is still open; the SSL object is freed before the descriptor is closed.
+void DnsTlsTransport::sslDisconnect(bssl::UniquePtr<SSL> ssl, base::unique_fd fd) {
+    if (ssl != nullptr) {
+        SSL_shutdown(ssl.get());
+    }
+    // Release explicitly so the order does not depend on parameter destruction order.
+    ssl.reset();
+    fd.reset();
}
bool DnsTlsTransport::sslWrite(int fd, SSL *ssl, const uint8_t *buffer, int len) {
@@ -425,48 +522,101 @@
}
+// Process-wide transport cache and the lock guarding it.
+std::mutex DnsTlsTransport::sLock;
+std::map<DnsTlsTransport::Key, std::unique_ptr<DnsTlsTransport>> DnsTlsTransport::sStore;
+
// static
+// Sends |query| to |server| over the cached transport for (mark, server). The transport is
+// created and initialized on first use. mUseCount keeps it from being evicted by cleanup()
+// while doQuery() runs outside sLock. *resplen is 0 unless a response was received.
DnsTlsTransport::Response DnsTlsTransport::query(const Server& server, unsigned mark,
const uint8_t *query, size_t qlen, uint8_t *response, size_t limit, int *resplen) {
- // TODO: Keep a static container of transports instead of constructing a new one
- // for every query.
- DnsTlsTransport xport(server, mark);
- return xport.doQuery(query, qlen, response, limit, resplen);
+    *resplen = 0;  // Zero indicates an error.
+    const Key key = std::make_pair(mark, server);
+    DnsTlsTransport* xport = nullptr;
+    {
+        std::lock_guard<std::mutex> guard(sLock);
+        auto it = sStore.find(key);
+        if (it == sStore.end()) {
+            // Own the new transport immediately so it is freed if initialization fails.
+            std::unique_ptr<DnsTlsTransport> created(new DnsTlsTransport(server, mark));
+            if (!created->initialize()) {
+                return DnsTlsTransport::Response::internal_error;
+            }
+            it = sStore.emplace(key, std::move(created)).first;
+        }
+        xport = it->second.get();
+        ++xport->mUseCount;
+    }
+
+    Response res = xport->doQuery(query, qlen, response, limit, resplen);
+    auto now = std::chrono::steady_clock::now();
+    {
+        std::lock_guard<std::mutex> guard(sLock);
+        --xport->mUseCount;
+        xport->mLastUsed = now;
+        cleanup(now);
+    }
+    return res;
+}
+
+// Transports left unused for longer than this are evicted from sStore.
+static constexpr std::chrono::minutes IDLE_TIMEOUT(5);
+std::chrono::time_point<std::chrono::steady_clock> DnsTlsTransport::sLastCleanup;
+
+// Requires sLock. Runs at most once per IDLE_TIMEOUT. It evicts every transport that
+// no query is using (mUseCount == 0) and that was last used more than IDLE_TIMEOUT ago.
+void DnsTlsTransport::cleanup(std::chrono::time_point<std::chrono::steady_clock> now) {
+    const bool ranRecently = now - sLastCleanup < IDLE_TIMEOUT;
+    if (ranRecently) {
+        return;
+    }
+    auto entry = sStore.begin();
+    while (entry != sStore.end()) {
+        const DnsTlsTransport& transport = *entry->second;
+        const bool idle =
+                transport.mUseCount == 0 && now - transport.mLastUsed > IDLE_TIMEOUT;
+        if (idle) {
+            entry = sStore.erase(entry);
+        } else {
+            ++entry;
+        }
+    }
+    sLastCleanup = now;
}
+// Performs one query over a fresh TLS connection. It connects, writes the length-prefixed
+// query, reads the reply into |response| (bounded by |limit|), and then disconnects.
+// *resplen stays 0 on every failure path and holds the response length on success.
DnsTlsTransport::Response DnsTlsTransport::doQuery(const uint8_t *query, size_t qlen,
uint8_t *response, size_t limit, int *resplen) {
*resplen = 0; // Zero indicates an error.
-
- if (DBG) {
- ALOGD("%u connecting TCP socket", mMark);
+    android::base::unique_fd fd = makeConnectedSocket();
+    if (fd.get() < 0) {
+        ALOGD("%u makeConnectedSocket() failed with: %s", mMark, strerror(errno));
+        return Response::network_error;
}
- android::base::unique_fd fd(makeConnectedSocket());
- if (DBG) {
- ALOGD("%u connecting SSL", mMark);
- }
- bssl::UniquePtr<SSL> ssl(sslConnect(fd));
- if (ssl == nullptr) {
- if (DBG) {
- ALOGW("%u SSL connection failed", mMark);
- }
+    bssl::UniquePtr<SSL> ssl = sslConnect(fd.get());
+    if (!ssl) {
return Response::network_error;
}
+    Response res = sendQuery(fd.get(), ssl.get(), query, qlen);
+    if (res == Response::success) {
+        res = readResponse(fd.get(), ssl.get(), query, response, limit, resplen);
+    }
+
+    // Shut down TLS and close the socket on every path once the connection is up.
+    sslDisconnect(std::move(ssl), std::move(fd));
+    return res;
+}
+
+// Writes |query| to the TLS stream, preceded by its 2-byte big-endian length (DNS over
+// TCP/TLS framing). Returns Response::success once both writes complete.
+DnsTlsTransport::Response DnsTlsTransport::sendQuery(int fd, SSL* ssl, const uint8_t *query, size_t qlen) {
+    if (DBG) {
+        ALOGD("sending query");
+    }
+    // The length prefix is only 16 bits wide. A longer query would be framed incorrectly
+    // (and would also narrow when passed as sslWrite's int length), so reject it.
+    if (qlen > 0xFFFF) {
+        ALOGE("%u Query too long for DNS-over-TLS framing: %zu", mMark, qlen);
+        return Response::limit_error;
+    }
uint8_t queryHeader[2];
queryHeader[0] = qlen >> 8;
queryHeader[1] = qlen;
- if (!sslWrite(fd.get(), ssl.get(), queryHeader, 2)) {
+    if (!sslWrite(fd, ssl, queryHeader, 2)) {
return Response::network_error;
}
- if (!sslWrite(fd.get(), ssl.get(), query, qlen)) {
+    if (!sslWrite(fd, ssl, query, qlen)) {
return Response::network_error;
}
if (DBG) {
ALOGD("%u SSL_write complete", mMark);
}
+    return Response::success;
+}
+
+DnsTlsTransport::Response DnsTlsTransport::readResponse(int fd, SSL* ssl, const uint8_t *query, uint8_t *response, size_t limit, int *resplen) {
+ if (DBG) {
+ ALOGD("reading response");
+ }
uint8_t responseHeader[2];
- if (!sslRead(fd.get(), ssl.get(), responseHeader, 2)) {
+ if (!sslRead(fd, ssl, responseHeader, 2)) {
if (DBG) {
ALOGW("%u Failed to read 2-byte length header", mMark);
}
@@ -480,7 +630,7 @@
ALOGE("%u Response doesn't fit in output buffer: %i", mMark, responseSize);
return Response::limit_error;
}
- if (!sslRead(fd.get(), ssl.get(), response, responseSize)) {
+ if (!sslRead(fd, ssl, response, responseSize)) {
if (DBG) {
ALOGW("%u Failed to read %i bytes", mMark, responseSize);
}
@@ -495,8 +645,6 @@
return Response::internal_error;
}
- SSL_shutdown(ssl.get());
-
*resplen = responseSize;
return Response::success;
}
diff --git a/server/dns/DnsTlsTransport.h b/server/dns/DnsTlsTransport.h
index 5a066a3..d526b1a 100644
--- a/server/dns/DnsTlsTransport.h
+++ b/server/dns/DnsTlsTransport.h
@@ -17,14 +17,20 @@
#ifndef _DNS_DNSTLSTRANSPORT_H
#define _DNS_DNSTLSTRANSPORT_H
+#include <deque>
+#include <memory>
+#include <map>
+#include <mutex>
#include <netinet/in.h>
+#include <openssl/ssl.h>
#include <set>
#include <string>
#include <sys/socket.h>
#include <sys/types.h>
#include <vector>
-#include "android-base/unique_fd.h"
+#include <android-base/thread_annotations.h>
+#include <android-base/unique_fd.h>
// Forward declaration.
typedef struct ssl_st SSL;
@@ -34,6 +40,7 @@
class DnsTlsTransport {
public:
+ ~DnsTlsTransport() {}
struct Server {
// Default constructor
Server() {}
@@ -66,16 +73,31 @@
DnsTlsTransport(const Server& server, unsigned mark)
: mMark(mark), mServer(server)
{}
- ~DnsTlsTransport() {}
+
+ // Cache of reusable DnsTlsTransports. Transports stay in cache as long as
+ // they are in use and for a few minutes after.
+ // The key is a (socket mark, server) pair. The mark comes first so the cheap integer comparison runs before the Server comparison.
+ typedef std::pair<unsigned, const Server> Key;
+ static std::mutex sLock;
+ static std::map<Key, std::unique_ptr<DnsTlsTransport>> sStore GUARDED_BY(sLock);
+ static std::chrono::time_point<std::chrono::steady_clock> sLastCleanup GUARDED_BY(sLock);
+ static void cleanup(std::chrono::time_point<std::chrono::steady_clock> now) REQUIRES(sLock);
+
+ // Creates the SSL context for this transport. Returns false on failure.
+ bool initialize() REQUIRES(sLock);
Response doQuery(const uint8_t *query, size_t qlen, uint8_t *ans, size_t anssiz, int *resplen);
+ Response sendQuery(int fd, SSL* ssl, const uint8_t *query, size_t qlen);
+ Response readResponse(int fd, SSL* ssl, const uint8_t *query, uint8_t *ans, size_t anssiz,
+ int *resplen);
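-// On success, returns a non-blocking socket connected to mAddr (the
-// connection will likely be in progress if mProtocol is IPPROTO_TCP).
-// On error, returns -1 with errno set appropriately.
+// On success, returns a socket connected to mServer that has been switched back
+// to blocking mode (the connection may still be in progress if mServer.protocol
+// is IPPROTO_TCP). On error, returns -1 with errno set appropriately.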
android::base::unique_fd makeConnectedSocket() const;
- SSL* sslConnect(int fd);
+ bssl::UniquePtr<SSL> sslConnect(int fd);
+ void sslDisconnect(bssl::UniquePtr<SSL> ssl, base::unique_fd fd);
// Writes a buffer to the socket.
bool sslWrite(int fd, SSL *ssl, const uint8_t *buffer, int len);
@@ -84,8 +106,24 @@
// Returns false if the socket closes before enough bytes can be read.
bool sslRead(int fd, SSL *ssl, uint8_t *buffer, int len);
+ // There is a 1:1:1 correspondence between Key, DnsTlsTransport, and SSL_CTX.
+ // Using SSL_CTX to create new SSL objects is thread-safe.
+ bssl::UniquePtr<SSL_CTX> mSslCtx;
+
const unsigned mMark; // Socket mark
const Server mServer;
+
+ // Cache of recently seen SSL_SESSIONs. This is used to support session tickets.
+ static int newSessionCallback(SSL* ssl, SSL_SESSION* session);
+ void recordSession(SSL_SESSION* session);
+ static void removeSessionCallback(SSL_CTX* ssl_ctx, SSL_SESSION* session);
+ void removeSession(SSL_SESSION* session);
+ std::deque<bssl::UniquePtr<SSL_SESSION>> mSessions GUARDED_BY(sLock);
+
+ // This use counter and timestamp are used to ensure that only idle transports are
+ // destroyed.
+ int mUseCount GUARDED_BY(sLock) = 0;
+ std::chrono::time_point<std::chrono::steady_clock> mLastUsed GUARDED_BY(sLock);
};
// This comparison ignores ports, names, and fingerprints.
diff --git a/tests/dns_responder/Android.mk b/tests/dns_responder/Android.mk
index b9bc8f8..e50eff5 100644
--- a/tests/dns_responder/Android.mk
+++ b/tests/dns_responder/Android.mk
@@ -18,7 +18,7 @@
# TODO describe library here
include $(CLEAR_VARS)
LOCAL_MODULE := libnetd_test_dnsresponder
-LOCAL_CFLAGS := -Wall -Werror -Wunused-parameter
+LOCAL_CFLAGS := -Wall -Werror -Wunused-parameter -Wthread-safety
# Bug: http://b/29823425 Disable -Wvarargs for Clang update to r271374
LOCAL_CFLAGS += -Wno-varargs
diff --git a/tests/dns_responder/dns_responder.cpp b/tests/dns_responder/dns_responder.cpp
index 8704bdb..21f5b36 100644
--- a/tests/dns_responder/dns_responder.cpp
+++ b/tests/dns_responder/dns_responder.cpp
@@ -791,6 +791,7 @@
bool DNSResponder::addAnswerRecords(const DNSQuestion& question,
std::vector<DNSRecord>* answers) const {
+ std::lock_guard<std::mutex> guard(mappings_mutex_);
auto it = mappings_.find(QueryKey(question.qname.name, question.qtype));
if (it == mappings_.end()) {
// TODO(imaipi): handle correctly
diff --git a/tests/dns_responder/dns_responder.h b/tests/dns_responder/dns_responder.h
index ccb63d9..77a33e5 100644
--- a/tests/dns_responder/dns_responder.h
+++ b/tests/dns_responder/dns_responder.h
@@ -116,8 +116,7 @@
// mutex protecting them.
std::unordered_map<QueryKey, std::string, QueryKeyHash> mappings_
GUARDED_BY(mappings_mutex_);
- // TODO(imaipi): enable GUARDED_BY(mappings_mutex_);
- std::mutex mappings_mutex_;
+ mutable std::mutex mappings_mutex_;
// Query names received so far and the corresponding mutex.
mutable std::vector<std::pair<std::string, ns_type>> queries_
GUARDED_BY(queries_mutex_);
@@ -127,7 +126,7 @@
// File descriptor for epoll.
int epoll_fd_;
// Signal for request handler termination.
- std::atomic<bool> terminate_ GUARDED_BY(update_mutex_);
+ std::atomic<bool> terminate_;
// Thread for handling incoming threads.
std::thread handler_thread_ GUARDED_BY(update_mutex_);
std::mutex update_mutex_;
diff --git a/tests/dns_responder/dns_tls_frontend.h b/tests/dns_responder/dns_tls_frontend.h
index 0a2556c..b4630cf 100644
--- a/tests/dns_responder/dns_tls_frontend.h
+++ b/tests/dns_responder/dns_tls_frontend.h
@@ -75,7 +75,7 @@
int socket_ = -1;
int backend_socket_ = -1;
std::atomic<int> queries_;
- std::atomic<bool> terminate_ GUARDED_BY(update_mutex_);
+ std::atomic<bool> terminate_;
std::thread handler_thread_ GUARDED_BY(update_mutex_);
std::mutex update_mutex_;
int chain_length_ = 1;