Merge "Add support for session tickets"
diff --git a/server/Android.mk b/server/Android.mk
index 93cd1fd..cd67f87 100644
--- a/server/Android.mk
+++ b/server/Android.mk
@@ -19,7 +19,7 @@
 ###
 include $(CLEAR_VARS)
 
-LOCAL_CFLAGS := -Wall -Werror
+LOCAL_CFLAGS := -Wall -Werror -Wthread-safety
 LOCAL_SANITIZE := unsigned-integer-overflow
 LOCAL_MODULE := libnetdaidl_static
 LOCAL_SHARED_LIBRARIES := \
@@ -36,7 +36,7 @@
 
 include $(CLEAR_VARS)
 
-LOCAL_CFLAGS := -Wall -Werror
+LOCAL_CFLAGS := -Wall -Werror -Wthread-safety
 LOCAL_SANITIZE := unsigned-integer-overflow
 LOCAL_MODULE := libnetdaidl
 LOCAL_SHARED_LIBRARIES := \
@@ -58,7 +58,7 @@
         external/mdnsresponder/mDNSShared \
         system/netd/include \
 
-LOCAL_CPPFLAGS := -Wall -Werror
+LOCAL_CPPFLAGS := -Wall -Werror -Wthread-safety
 LOCAL_SANITIZE := unsigned-integer-overflow
 LOCAL_MODULE := netd
 
@@ -147,7 +147,7 @@
 ###
 include $(CLEAR_VARS)
 
-LOCAL_CFLAGS := -Wall -Werror
+LOCAL_CFLAGS := -Wall -Werror -Wthread-safety
 LOCAL_SANITIZE := unsigned-integer-overflow
 LOCAL_CLANG := true
 LOCAL_MODULE := ndc
@@ -163,7 +163,7 @@
 LOCAL_MODULE := netd_unit_test
 LOCAL_COMPATIBILITY_SUITE := device-tests
 LOCAL_SANITIZE := unsigned-integer-overflow
-LOCAL_CFLAGS := -Wall -Werror -Wunused-parameter
+LOCAL_CFLAGS := -Wall -Werror -Wunused-parameter -Wthread-safety
 # Bug: http://b/29823425 Disable -Wvarargs for Clang update to r271374
 LOCAL_CFLAGS += -Wno-varargs
 
diff --git a/server/dns/DnsTlsTransport.cpp b/server/dns/DnsTlsTransport.cpp
index b369022..542b4a9 100644
--- a/server/dns/DnsTlsTransport.cpp
+++ b/server/dns/DnsTlsTransport.cpp
@@ -22,7 +22,6 @@
 #include <arpa/nameser.h>
 #include <errno.h>
 #include <openssl/err.h>
-#include <openssl/ssl.h>
 #include <stdlib.h>
 
 #define LOG_TAG "DnsTlsTransport"
@@ -143,6 +142,9 @@
 }  // namespace
 
 android::base::unique_fd DnsTlsTransport::makeConnectedSocket() const {
+    if (DBG) {
+        ALOGD("%u connecting TCP socket", mMark);
+    }
     android::base::unique_fd fd;
     int type = SOCK_NONBLOCK | SOCK_CLOEXEC;
     switch (mServer.protocol) {
@@ -168,6 +170,11 @@
         fd.reset();
     }
 
+    // Skip this when the socket already failed, so errno still describes that failure.
+    if (fd.get() >= 0 && !setNonBlocking(fd, false)) {
+        ALOGE("Failed to disable nonblocking status on DNS-over-TLS fd");
+        fd.reset();
+    }
     return fd;
 }
 
@@ -231,28 +238,45 @@
     return make_tie(*this) == make_tie(other);
 }
 
-SSL* DnsTlsTransport::sslConnect(int fd) {
-    if (fd < 0) {
-        ALOGD("%u makeConnectedSocket() failed with: %s", mMark, strerror(errno));
+bool DnsTlsTransport::initialize() {
+    mSslCtx.reset(SSL_CTX_new(TLS_method()));
+    // Configure the shared context fully here, before any SSL object is created from it:
+    // mutating an SSL_CTX while other threads call SSL_new() on it is a data race.
+    // Index 0 of the context's extra data is read back by removeSessionCallback.
+    if (!mSslCtx ||
+        !SSL_CTX_set_max_proto_version(mSslCtx.get(), TLS1_3_VERSION) ||
+        !SSL_CTX_set_min_proto_version(mSslCtx.get(), TLS1_2_VERSION) ||
+        SSL_CTX_set_ex_data(mSslCtx.get(), 0, this) != 1) {
+        ALOGE("failed to initialize SSL context");
+        return false;
+    }
+    // Clients only invoke the new-session callback when client-side caching is enabled.
+    SSL_CTX_set_session_cache_mode(mSslCtx.get(), SSL_SESS_CACHE_CLIENT);
+    SSL_CTX_sess_set_new_cb(mSslCtx.get(), DnsTlsTransport::newSessionCallback);
+    SSL_CTX_sess_set_remove_cb(mSslCtx.get(), DnsTlsTransport::removeSessionCallback);
+    return true;
+}
+
+bssl::UniquePtr<SSL> DnsTlsTransport::sslConnect(int fd) {
+    if (!mSslCtx) {
+        ALOGE("Internal error: context is null in ssl connect");
         return nullptr;
     }
 
-    // Set up TLS context.
-    bssl::UniquePtr<SSL_CTX> ssl_ctx(SSL_CTX_new(TLS_method()));
-    if (!SSL_CTX_set_max_proto_version(ssl_ctx.get(), TLS1_3_VERSION) ||
-        !SSL_CTX_set_min_proto_version(ssl_ctx.get(), TLS1_1_VERSION)) {
-        ALOGD("failed to min/max TLS versions");
-        return nullptr;
-    }
-
-    bssl::UniquePtr<SSL> ssl(SSL_new(ssl_ctx.get()));
+    bssl::UniquePtr<SSL> ssl(SSL_new(mSslCtx.get()));
+    if (!ssl) {
+        ALOGE("failed to create SSL object");
+        return nullptr;
+    }
     // This file descriptor is owned by a unique_fd, so don't let libssl close it.
     bssl::UniquePtr<BIO> bio(BIO_new_socket(fd, BIO_NOCLOSE));
     SSL_set_bio(ssl.get(), bio.get(), bio.get());
     bio.release();
 
-    if (!setNonBlocking(fd, false)) {
-        ALOGE("Failed to disable nonblocking status on DNS-over-TLS fd");
+    // Add this transport as the 0-index extra data for the socket.
+    // This is used by newSessionCallback.
+    if (SSL_set_ex_data(ssl.get(), 0, this) != 1) {
+        ALOGE("failed to associate SSL socket to transport");
         return nullptr;
     }
 
@@ -267,6 +291,20 @@
         SSL_set_verify(ssl.get(), SSL_VERIFY_PEER, nullptr);
     }
 
+    bssl::UniquePtr<SSL_SESSION> session;
+    {
+        std::lock_guard<std::mutex> guard(sLock);
+        if (!mSessions.empty()) {
+            session = std::move(mSessions.front());
+            mSessions.pop_front();
+        } else if (DBG) {
+            ALOGD("Starting without session ticket.");
+        }
+    }
+    if (session) {
+        SSL_set_session(ssl.get(), session.get());
+    }
+
     for (;;) {
         if (DBG) {
             ALOGD("%u Calling SSL_connect", mMark);
@@ -353,7 +391,66 @@
     if (DBG) {
         ALOGD("%u handshake complete", mMark);
     }
-    return ssl.release();
+
+    return ssl;
+}
+
+// static
+int DnsTlsTransport::newSessionCallback(SSL* ssl, SSL_SESSION* session) {
+    if (!session) {
+        return 0;
+    }
+    if (DBG) {
+        ALOGD("Recording session ticket");
+    }
+    auto* xport = static_cast<DnsTlsTransport*>(SSL_get_ex_data(ssl, 0));
+    if (!xport) {
+        ALOGE("null transport in new session callback");
+        return 0;
+    }
+    // Returning 1 tells BoringSSL that recordSession() has taken ownership of |session|.
+    xport->recordSession(session);
+    return 1;
+}
+
+// static
+void DnsTlsTransport::removeSessionCallback(SSL_CTX* ssl_ctx, SSL_SESSION* session) {
+    if (DBG) {
+        ALOGD("Removing session ticket");
+    }
+    auto* xport = static_cast<DnsTlsTransport*>(SSL_CTX_get_ex_data(ssl_ctx, 0));
+    if (!xport) {
+        ALOGE("null transport in remove session callback");
+        return;
+    }
+    xport->removeSession(session);
+}
+
+void DnsTlsTransport::recordSession(SSL_SESSION* session) {
+    std::lock_guard<std::mutex> guard(sLock);
+    mSessions.emplace_front(session);
+    if (mSessions.size() > 5) {
+        if (DBG) {
+            ALOGD("Too many sessions; trimming");
+        }
+        mSessions.pop_back();
+    }
+}
+
+void DnsTlsTransport::removeSession(SSL_SESSION* session) {
+    std::lock_guard<std::mutex> guard(sLock);
+    if (session) {
+        // TODO: Consider implementing targeted removal.
+        mSessions.clear();
+    }
+}
+
+void DnsTlsTransport::sslDisconnect(bssl::UniquePtr<SSL> ssl, base::unique_fd fd) {
+    if (ssl) {
+        SSL_shutdown(ssl.get());
+        ssl.reset();
+    }
+    fd.reset();
 }
 
 bool DnsTlsTransport::sslWrite(int fd, SSL *ssl, const uint8_t *buffer, int len) {
@@ -425,48 +522,101 @@
 }
 
 // static
+std::mutex DnsTlsTransport::sLock;
+std::map<DnsTlsTransport::Key, std::unique_ptr<DnsTlsTransport>> DnsTlsTransport::sStore;
 DnsTlsTransport::Response DnsTlsTransport::query(const Server& server, unsigned mark,
         const uint8_t *query, size_t qlen, uint8_t *response, size_t limit, int *resplen) {
-    // TODO: Keep a static container of transports instead of constructing a new one
-    // for every query.
-    DnsTlsTransport xport(server, mark);
-    return xport.doQuery(query, qlen, response, limit, resplen);
+    *resplen = 0;  // Zero indicates an error; only a successfully read response changes it.
+    const Key key = std::make_pair(mark, server);
+    DnsTlsTransport* xport;
+    {
+        std::lock_guard<std::mutex> guard(sLock);
+        auto it = sStore.find(key);
+        if (it == sStore.end()) {
+            std::unique_ptr<DnsTlsTransport> fresh(new DnsTlsTransport(server, mark));
+            if (!fresh->initialize()) {
+                return DnsTlsTransport::Response::internal_error;  // |fresh| is freed here.
+            }
+            xport = sStore.emplace(key, std::move(fresh)).first->second.get();
+        } else {
+            xport = it->second.get();
+        }
+        ++xport->mUseCount;
+    }
+
+    Response res = xport->doQuery(query, qlen, response, limit, resplen);
+    {
+        std::lock_guard<std::mutex> guard(sLock);
+        --xport->mUseCount;
+        xport->mLastUsed = std::chrono::steady_clock::now();
+        cleanup(xport->mLastUsed);
+    }
+    return res;
+}
+
+static constexpr std::chrono::minutes IDLE_TIMEOUT(5);
+std::chrono::time_point<std::chrono::steady_clock> DnsTlsTransport::sLastCleanup;
+void DnsTlsTransport::cleanup(std::chrono::time_point<std::chrono::steady_clock> now) {
+    if (now - sLastCleanup < IDLE_TIMEOUT) {
+        return;
+    }
+    for (auto it = sStore.begin(); it != sStore.end(); ) {
+        auto& xport = it->second;
+        if (xport->mUseCount == 0 && now - xport->mLastUsed > IDLE_TIMEOUT) {
+            it = sStore.erase(it);
+        } else {
+            ++it;
+        }
+    }
+    sLastCleanup = now;
 }
 
 DnsTlsTransport::Response DnsTlsTransport::doQuery(const uint8_t *query, size_t qlen,
         uint8_t *response, size_t limit, int *resplen) {
-    *resplen = 0;  // Zero indicates an error.
-
-    if (DBG) {
-        ALOGD("%u connecting TCP socket", mMark);
+    android::base::unique_fd fd = makeConnectedSocket();
+    if (fd.get() < 0) {
+        ALOGD("%u makeConnectedSocket() failed with: %s", mMark, strerror(errno));
+        return Response::network_error;
     }
-    android::base::unique_fd fd(makeConnectedSocket());
-    if (DBG) {
-        ALOGD("%u connecting SSL", mMark);
-    }
-    bssl::UniquePtr<SSL> ssl(sslConnect(fd));
-    if (ssl == nullptr) {
-        if (DBG) {
-            ALOGW("%u SSL connection failed", mMark);
-        }
+    bssl::UniquePtr<SSL> ssl = sslConnect(fd.get());
+    if (!ssl) {
         return Response::network_error;
     }
 
+    Response res = sendQuery(fd.get(), ssl.get(), query, qlen);
+    if (res == Response::success) {
+        res = readResponse(fd.get(), ssl.get(), query, response, limit, resplen);
+    }
+
+    sslDisconnect(std::move(ssl), std::move(fd));
+    return res;
+}
+
+DnsTlsTransport::Response DnsTlsTransport::sendQuery(int fd, SSL* ssl, const uint8_t *query, size_t qlen) {
+    if (DBG) {
+        ALOGD("sending query");
+    }
     uint8_t queryHeader[2];
     queryHeader[0] = qlen >> 8;
     queryHeader[1] = qlen;
-    if (!sslWrite(fd.get(), ssl.get(), queryHeader, 2)) {
+    if (!sslWrite(fd, ssl, queryHeader, 2)) {
         return Response::network_error;
     }
-    if (!sslWrite(fd.get(), ssl.get(), query, qlen)) {
+    if (!sslWrite(fd, ssl, query, qlen)) {
         return Response::network_error;
     }
     if (DBG) {
         ALOGD("%u SSL_write complete", mMark);
     }
+    return Response::success;
+}
 
+DnsTlsTransport::Response DnsTlsTransport::readResponse(int fd, SSL* ssl, const uint8_t *query, uint8_t *response, size_t limit, int *resplen) {
+    if (DBG) {
+        ALOGD("reading response");
+    }
     uint8_t responseHeader[2];
-    if (!sslRead(fd.get(), ssl.get(), responseHeader, 2)) {
+    if (!sslRead(fd, ssl, responseHeader, 2)) {
         if (DBG) {
             ALOGW("%u Failed to read 2-byte length header", mMark);
         }
@@ -480,7 +630,7 @@
         ALOGE("%u Response doesn't fit in output buffer: %i", mMark, responseSize);
         return Response::limit_error;
     }
-    if (!sslRead(fd.get(), ssl.get(), response, responseSize)) {
+    if (!sslRead(fd, ssl, response, responseSize)) {
         if (DBG) {
             ALOGW("%u Failed to read %i bytes", mMark, responseSize);
         }
@@ -495,8 +645,6 @@
         return Response::internal_error;
     }
 
-    SSL_shutdown(ssl.get());
-
     *resplen = responseSize;
     return Response::success;
 }
diff --git a/server/dns/DnsTlsTransport.h b/server/dns/DnsTlsTransport.h
index 5a066a3..d526b1a 100644
--- a/server/dns/DnsTlsTransport.h
+++ b/server/dns/DnsTlsTransport.h
@@ -17,14 +17,21 @@
 #ifndef _DNS_DNSTLSTRANSPORT_H
 #define _DNS_DNSTLSTRANSPORT_H
 
+#include <chrono>
+#include <deque>
+#include <memory>
+#include <map>
+#include <mutex>
 #include <netinet/in.h>
+#include <openssl/ssl.h>
 #include <set>
 #include <string>
 #include <sys/socket.h>
 #include <sys/types.h>
 #include <vector>
 
-#include "android-base/unique_fd.h"
+#include <android-base/thread_annotations.h>
+#include <android-base/unique_fd.h>
 
 // Forward declaration.
 typedef struct ssl_st SSL;
@@ -34,6 +41,7 @@
 
 class DnsTlsTransport {
 public:
+    ~DnsTlsTransport() {}
     struct Server {
         // Default constructor
         Server() {}
@@ -66,16 +74,31 @@
     DnsTlsTransport(const Server& server, unsigned mark)
             : mMark(mark), mServer(server)
             {}
-    ~DnsTlsTransport() {}
+
+    // Cache of reusable DnsTlsTransports.  Transports stay in cache as long as
+    // they are in use and for a few minutes after.
+    // The key is a (mark, server) pair.  The mark is first for lexicographic comparison speed.
+    typedef std::pair<unsigned, const Server> Key;
+    static std::mutex sLock;
+    static std::map<Key, std::unique_ptr<DnsTlsTransport>> sStore GUARDED_BY(sLock);
+    static std::chrono::time_point<std::chrono::steady_clock> sLastCleanup GUARDED_BY(sLock);
+    static void cleanup(std::chrono::time_point<std::chrono::steady_clock> now) REQUIRES(sLock);
+
+    // Creates the SSL context for this transport.  Returns false on failure.
+    bool initialize() REQUIRES(sLock);
 
     Response doQuery(const uint8_t *query, size_t qlen, uint8_t *ans, size_t anssiz, int *resplen);
+    Response sendQuery(int fd, SSL* ssl, const uint8_t *query, size_t qlen);
+    Response readResponse(int fd, SSL* ssl, const uint8_t *query, uint8_t *ans, size_t anssiz,
+            int *resplen);
 
-    // On success, returns a non-blocking socket connected to mAddr (the
+    // On success, returns a blocking-mode socket connected to mAddr (the
     // connection will likely be in progress if mProtocol is IPPROTO_TCP).
     // On error, returns -1 with errno set appropriately.
     android::base::unique_fd makeConnectedSocket() const;
 
-    SSL* sslConnect(int fd);
+    bssl::UniquePtr<SSL> sslConnect(int fd);
+    void sslDisconnect(bssl::UniquePtr<SSL> ssl, base::unique_fd fd);
 
     // Writes a buffer to the socket.
     bool sslWrite(int fd, SSL *ssl, const uint8_t *buffer, int len);
@@ -84,8 +107,24 @@
     // Returns false if the socket closes before enough bytes can be read.
     bool sslRead(int fd, SSL *ssl, uint8_t *buffer, int len);
 
+    // There is a 1:1:1 correspondence between Key, DnsTlsTransport, and SSL_CTX.
+    // Creating SSL objects from a shared SSL_CTX is thread-safe; mutating it is not.
+    bssl::UniquePtr<SSL_CTX> mSslCtx;
+
     const unsigned mMark;  // Socket mark
     const Server mServer;
+
+    // Cache of recently seen SSL_SESSIONs.  This is used to support session tickets.
+    static int newSessionCallback(SSL* ssl, SSL_SESSION* session);
+    void recordSession(SSL_SESSION* session);
+    static void removeSessionCallback(SSL_CTX* ssl_ctx, SSL_SESSION* session);
+    void removeSession(SSL_SESSION* session);
+    std::deque<bssl::UniquePtr<SSL_SESSION>> mSessions GUARDED_BY(sLock);
+
+    // This use counter and timestamp are used to ensure that only idle transports are
+    // destroyed.
+    int mUseCount GUARDED_BY(sLock) = 0;
+    std::chrono::time_point<std::chrono::steady_clock> mLastUsed GUARDED_BY(sLock);
 };
 
 // This comparison ignores ports, names, and fingerprints.
diff --git a/tests/dns_responder/Android.mk b/tests/dns_responder/Android.mk
index b9bc8f8..e50eff5 100644
--- a/tests/dns_responder/Android.mk
+++ b/tests/dns_responder/Android.mk
@@ -18,7 +18,7 @@
 # TODO describe library here
 include $(CLEAR_VARS)
 LOCAL_MODULE := libnetd_test_dnsresponder
-LOCAL_CFLAGS := -Wall -Werror -Wunused-parameter
+LOCAL_CFLAGS := -Wall -Werror -Wunused-parameter -Wthread-safety
 # Bug: http://b/29823425 Disable -Wvarargs for Clang update to r271374
 LOCAL_CFLAGS += -Wno-varargs
 
diff --git a/tests/dns_responder/dns_responder.cpp b/tests/dns_responder/dns_responder.cpp
index 8704bdb..21f5b36 100644
--- a/tests/dns_responder/dns_responder.cpp
+++ b/tests/dns_responder/dns_responder.cpp
@@ -791,6 +791,7 @@
 
 bool DNSResponder::addAnswerRecords(const DNSQuestion& question,
                                     std::vector<DNSRecord>* answers) const {
+    std::lock_guard<std::mutex> guard(mappings_mutex_);
     auto it = mappings_.find(QueryKey(question.qname.name, question.qtype));
     if (it == mappings_.end()) {
         // TODO(imaipi): handle correctly
diff --git a/tests/dns_responder/dns_responder.h b/tests/dns_responder/dns_responder.h
index ccb63d9..77a33e5 100644
--- a/tests/dns_responder/dns_responder.h
+++ b/tests/dns_responder/dns_responder.h
@@ -116,8 +116,7 @@
     // mutex protecting them.
     std::unordered_map<QueryKey, std::string, QueryKeyHash> mappings_
         GUARDED_BY(mappings_mutex_);
-    // TODO(imaipi): enable GUARDED_BY(mappings_mutex_);
-    std::mutex mappings_mutex_;
+    mutable std::mutex mappings_mutex_;
     // Query names received so far and the corresponding mutex.
     mutable std::vector<std::pair<std::string, ns_type>> queries_
         GUARDED_BY(queries_mutex_);
@@ -127,7 +126,7 @@
     // File descriptor for epoll.
     int epoll_fd_;
     // Signal for request handler termination.
-    std::atomic<bool> terminate_ GUARDED_BY(update_mutex_);
+    std::atomic<bool> terminate_;
     // Thread for handling incoming threads.
     std::thread handler_thread_ GUARDED_BY(update_mutex_);
     std::mutex update_mutex_;
diff --git a/tests/dns_responder/dns_tls_frontend.h b/tests/dns_responder/dns_tls_frontend.h
index 0a2556c..b4630cf 100644
--- a/tests/dns_responder/dns_tls_frontend.h
+++ b/tests/dns_responder/dns_tls_frontend.h
@@ -75,7 +75,7 @@
     int socket_ = -1;
     int backend_socket_ = -1;
     std::atomic<int> queries_;
-    std::atomic<bool> terminate_ GUARDED_BY(update_mutex_);
+    std::atomic<bool> terminate_;
     std::thread handler_thread_ GUARDED_BY(update_mutex_);
     std::mutex update_mutex_;
     int chain_length_ = 1;