Merge "Tcp socket metrics: define INetdEventListener callback"
diff --git a/server/Android.mk b/server/Android.mk
index 8d495e1..f44d92b 100644
--- a/server/Android.mk
+++ b/server/Android.mk
@@ -58,7 +58,9 @@
         external/mdnsresponder/mDNSShared \
         system/netd/include \
 
-LOCAL_CPPFLAGS := -Wall -Werror -Wthread-safety
+# -Wnullable-to-nonnull-conversion makes clang warn (an error, given -Werror) about
+# implicit conversions from _Nullable to _Nonnull pointers, as annotated in dns/.
+LOCAL_CPPFLAGS := -Wall -Werror -Wthread-safety -Wnullable-to-nonnull-conversion
 LOCAL_SANITIZE := unsigned-integer-overflow
 LOCAL_MODULE := netd
 
@@ -143,6 +145,8 @@
         dns/DnsTlsDispatcher.cpp \
         dns/DnsTlsTransport.cpp \
         dns/DnsTlsServer.cpp \
+        dns/DnsTlsSessionCache.cpp \
+        dns/DnsTlsSocket.cpp \
 
 LOCAL_AIDL_INCLUDES := $(LOCAL_PATH)/binder
 
@@ -205,6 +209,11 @@
         binder/android/net/UidRange.cpp \
         binder/android/net/metrics/INetdEventListener.aidl \
         ../tests/tun_interface.cpp \
+        dns/DnsTlsDispatcher.cpp \
+        dns/DnsTlsTransport.cpp \
+        dns/DnsTlsServer.cpp \
+        dns/DnsTlsSessionCache.cpp \
+        dns/DnsTlsSocket.cpp \
 
 LOCAL_MODULE_TAGS := tests
 LOCAL_STATIC_LIBRARIES := libgmock libpcap
diff --git a/server/DnsProxyListener.cpp b/server/DnsProxyListener.cpp
index 7339d93..e3c5cf9 100644
--- a/server/DnsProxyListener.cpp
+++ b/server/DnsProxyListener.cpp
@@ -37,6 +37,7 @@
 #include <vector>
 
 #include <cutils/log.h>
+#include <netdutils/Slice.h>
 #include <utils/String16.h>
 #include <sysutils/SocketClient.h>
 
@@ -82,6 +83,10 @@
 
 thread_local android_net_context thread_netcontext = {};
 
+// Process-wide DNS-over-TLS dispatcher.  It owns the cache of DnsTlsTransports,
+// so sharing one instance lets successive queries reuse existing connections.
+DnsTlsDispatcher dnsTlsDispatcher;
+
 res_sendhookact qhook(sockaddr* const * nsap, const u_char** buf, int* buflen,
                       u_char* ans, int anssiz, int* resplen) {
     if (!thread_netcontext.qhook) {
@@ -128,8 +133,12 @@
         if (DBG) {
             ALOGD("Performing query over TLS");
         }
-        auto response = DnsTlsDispatcher::query(tlsServer, thread_netcontext.dns_mark,
-                *buf, *buflen, ans, anssiz, resplen);
+        // netdutils::Slice wraps a non-const pointer, so the const query buffer needs a
+        // const_cast; the query is only sent, never modified.  |ans| is already writable.
+        Slice query = netdutils::Slice(const_cast<u_char*>(*buf), *buflen);
+        Slice answer = netdutils::Slice(ans, anssiz);
+        auto response = dnsTlsDispatcher.query(tlsServer, thread_netcontext.dns_mark,
+                query, answer, resplen);
         if (response == DnsTlsTransport::Response::success) {
             if (DBG) {
                 ALOGD("qhook success");
diff --git a/server/dns/DnsTlsDispatcher.cpp b/server/dns/DnsTlsDispatcher.cpp
index 6fed9f0..be9c669 100644
--- a/server/dns/DnsTlsDispatcher.cpp
+++ b/server/dns/DnsTlsDispatcher.cpp
@@ -14,34 +14,55 @@
  * limitations under the License.
  */
 
+#define LOG_TAG "DnsTlsDispatcher"
+//#define LOG_NDEBUG 0
+
 #include "dns/DnsTlsDispatcher.h"
 
+#include "log/log.h"
+
 namespace android {
 namespace net {
 
+using netdutils::Slice;
+
 // static
 std::mutex DnsTlsDispatcher::sLock;
-std::map<DnsTlsDispatcher::Key, std::unique_ptr<DnsTlsDispatcher::Transport>> DnsTlsDispatcher::sStore;
 DnsTlsTransport::Response DnsTlsDispatcher::query(const DnsTlsServer& server, unsigned mark,
-        const uint8_t *query, size_t qlen, uint8_t *response, size_t limit, int *resplen) {
+                                                  const Slice query,
+                                                  const Slice ans, int *resplen) {
     const Key key = std::make_pair(mark, server);
     Transport* xport;
     {
         std::lock_guard<std::mutex> guard(sLock);
-        auto it = sStore.find(key);
-        if (it == sStore.end()) {
-            xport = new Transport(server, mark);
-            if (!xport->transport.initialize()) {
-                return DnsTlsTransport::Response::internal_error;
-            }
-            sStore[key].reset(xport);
+        auto it = mStore.find(key);
+        if (it == mStore.end()) {
+            // The map owns the new Transport from the start, so it cannot leak if insertion throws.
+            auto inserted = mStore.emplace(
+                    key, std::make_unique<Transport>(server, mark, mFactory.get()));
+            xport = inserted.first->second.get();
         } else {
             xport = it->second.get();
         }
         ++xport->useCount;
     }
 
-    DnsTlsTransport::Response res = xport->transport.query(query, qlen, response, limit, resplen);
+    ALOGV("Sending query of length %zu", query.size());
+    auto result = xport->transport.query(query);
+    DnsTlsTransport::Response code = result.code;
+    if (code == DnsTlsTransport::Response::success) {
+        if (result.response.size() > ans.size()) {
+            ALOGV("Response too large: %zu > %zu", result.response.size(), ans.size());
+            code = DnsTlsTransport::Response::limit_error;
+        } else {
+            ALOGV("Got response successfully");
+            *resplen = static_cast<int>(result.response.size());
+            netdutils::copy(ans, netdutils::makeSlice(result.response));
+        }
+    } else {
+        ALOGV("Query failed: %u", static_cast<unsigned>(code));
+    }
+
     auto now = std::chrono::steady_clock::now();
     {
         std::lock_guard<std::mutex> guard(sLock);
@@ -49,25 +70,27 @@
         xport->lastUsed = now;
         cleanup(now);
     }
-    return res;
+    return code;
 }
 
+// This timeout effectively controls how long to keep SSL session tickets.
 static constexpr std::chrono::minutes IDLE_TIMEOUT(5);
-std::chrono::time_point<std::chrono::steady_clock> DnsTlsDispatcher::sLastCleanup;
 void DnsTlsDispatcher::cleanup(std::chrono::time_point<std::chrono::steady_clock> now) {
-    if (now - sLastCleanup < IDLE_TIMEOUT) {
+    // To avoid scanning mStore after every query, return early if a cleanup has been
+    // performed recently.
+    if (now - mLastCleanup < IDLE_TIMEOUT) {
         return;
     }
-    for (auto it = sStore.begin(); it != sStore.end(); ) {
+    for (auto it = mStore.begin(); it != mStore.end();) {
         auto& s = it->second;
         if (s->useCount == 0 && now - s->lastUsed > IDLE_TIMEOUT) {
-            it = sStore.erase(it);
+            it = mStore.erase(it);
         } else {
             ++it;
         }
     }
-    sLastCleanup = now;
+    mLastCleanup = now;
 }
 
-}  // namespace net
-}  // namespace android
+}  // end of namespace net
+}  // end of namespace android
diff --git a/server/dns/DnsTlsDispatcher.h b/server/dns/DnsTlsDispatcher.h
index c32bde2..9487b51 100644
--- a/server/dns/DnsTlsDispatcher.h
+++ b/server/dns/DnsTlsDispatcher.h
@@ -23,54 +23,79 @@
 
 #include <android-base/thread_annotations.h>
 
+#include <netdutils/Slice.h>
+
 #include "dns/DnsTlsServer.h"
+#include "dns/DnsTlsSocket.h"
+#include "dns/DnsTlsSocketFactory.h"
+#include "dns/IDnsTlsSocketFactory.h"
 #include "dns/DnsTlsTransport.h"
 
 namespace android {
 namespace net {
 
-// This is a totally static class that manages the collection of active DnsTlsTransports.
+using netdutils::Slice;
+
+// This is a singleton class that manages the collection of active DnsTlsTransports.
 // Queries made here are dispatched to an existing or newly constructed DnsTlsTransport.
 class DnsTlsDispatcher {
 public:
-    // Given a |query| of length |qlen|, sends it to the server on the network indicated by |mark|,
-    // and writes the response into |ans|, which can accept up to |anssiz| bytes.  Indicates
-    // the number of bytes written in |resplen|.  If |resplen| is zero, an
-    // error has occurred.
-    static DnsTlsTransport::Response query(const DnsTlsServer& server, unsigned mark,
-            const uint8_t *query, size_t qlen, uint8_t *ans, size_t anssiz, int *resplen);
+    // Default constructor: sockets are created by a DnsTlsSocketFactory.
+    DnsTlsDispatcher() : mFactory(std::make_unique<DnsTlsSocketFactory>()) {}
+
+    // Constructor with dependency injection for testing.
+    explicit DnsTlsDispatcher(std::unique_ptr<IDnsTlsSocketFactory> factory) :
+            mFactory(std::move(factory)) {}
+
+    // Given a |query|, sends it to the server on the network indicated by |mark|.
+    // On success, copies the response into |ans| and stores its length in |resplen|.
+    // Returns limit_error if the response does not fit in |ans|.  On any error,
+    // neither |ans| nor |resplen| is written.
+    DnsTlsTransport::Response query(const DnsTlsServer& server, unsigned mark,
+                                    const Slice query, const Slice ans, int * _Nonnull resplen);
 
 private:
+    // This lock is static so that it can be used to annotate the Transport struct.
+    // DnsTlsDispatcher is a singleton in practice, so making this static does not change
+    // the locking behavior.
     static std::mutex sLock;
 
+    // Key = <mark, server>
     typedef std::pair<unsigned, const DnsTlsServer> Key;
 
     // Transport is a thin wrapper around DnsTlsTransport, adding reference counting and
-    // idle monitoring so we can expire unused sessions from the cache.
+    // usage monitoring so we can expire idle sessions from the cache.
     struct Transport {
-        Transport(const DnsTlsServer& server, unsigned mark) : transport(server, mark) {}
-        // DnsTlsSession is thread-safe (internally locked), so it doesn't need to be guarded.
+        Transport(const DnsTlsServer& server, unsigned mark,
+                  IDnsTlsSocketFactory* _Nonnull factory) :
+                transport(server, mark, factory) {}
+        // DnsTlsTransport is thread-safe, so it doesn't need to be guarded.
         DnsTlsTransport transport;
         // This use counter and timestamp are used to ensure that only idle sessions are
         // destroyed.
         int useCount GUARDED_BY(sLock) = 0;
+        // lastUsed is only guaranteed to be meaningful after useCount is decremented to zero.
         std::chrono::time_point<std::chrono::steady_clock> lastUsed GUARDED_BY(sLock);
     };
 
     // Cache of reusable DnsTlsTransports.  Transports stay in cache as long as
     // they are in use and for a few minutes after.
-    // The key is a (netid, server) pair.  The netid is first for lexicographic comparison speed.
-    static std::map<Key, std::unique_ptr<Transport>> sStore GUARDED_BY(sLock);
+    // The key is a (mark, server) pair.  The mark is first for lexicographic comparison speed.
+    std::map<Key, std::unique_ptr<Transport>> mStore GUARDED_BY(sLock);
 
     // The last time we did a cleanup.  For efficiency, we only perform a cleanup once every
     // few minutes.
-    static std::chrono::time_point<std::chrono::steady_clock> sLastCleanup GUARDED_BY(sLock);
+    std::chrono::time_point<std::chrono::steady_clock> mLastCleanup GUARDED_BY(sLock);
 
     // Drop any cache entries whose useCount is zero and which have not been used recently.
-    static void cleanup(std::chrono::time_point<std::chrono::steady_clock> now) REQUIRES(sLock);
+    // This function performs a linear scan of mStore.
+    void cleanup(std::chrono::time_point<std::chrono::steady_clock> now) REQUIRES(sLock);
+
+    // Trivial factory for DnsTlsSockets.  Dependency injection is only used for testing.
+    std::unique_ptr<IDnsTlsSocketFactory> mFactory;
 };
 
-}  // namespace net
-}  // namespace android
+}  // end of namespace net
+}  // end of namespace android
 
 #endif  // _DNS_DNSTLSDISPATCHER_H
diff --git a/server/dns/DnsTlsServer.cpp b/server/dns/DnsTlsServer.cpp
index 38dc6c5..9ac9893 100644
--- a/server/dns/DnsTlsServer.cpp
+++ b/server/dns/DnsTlsServer.cpp
@@ -89,7 +89,6 @@
 namespace net {
 
 // This comparison ignores ports and fingerprints.
-// TODO: respect IPv6 scope id (e.g. link-local addresses).
 bool AddressComparator::operator() (const DnsTlsServer& x, const DnsTlsServer& y) const {
     if (x.ss.ss_family != y.ss.ss_family) {
         return x.ss.ss_family < y.ss.ss_family;
@@ -102,7 +101,11 @@
     } else if (x.ss.ss_family == AF_INET6) {
         const sockaddr_in6& x_sin6 = reinterpret_cast<const sockaddr_in6&>(x.ss);
         const sockaddr_in6& y_sin6 = reinterpret_cast<const sockaddr_in6&>(y.ss);
-        return x_sin6.sin6_addr < y_sin6.sin6_addr;
+        // Order by address, breaking ties on the scope id so that one address in two
+        // scopes (e.g. a link-local address on two interfaces) yields two distinct servers.
+        if (x_sin6.sin6_addr < y_sin6.sin6_addr) return true;
+        if (y_sin6.sin6_addr < x_sin6.sin6_addr) return false;
+        return x_sin6.sin6_scope_id < y_sin6.sin6_scope_id;
     }
     return false;  // Unknown address type.  This is an error.
 }
diff --git a/server/dns/DnsTlsServer.h b/server/dns/DnsTlsServer.h
index f11c2e3..c9cbd46 100644
--- a/server/dns/DnsTlsServer.h
+++ b/server/dns/DnsTlsServer.h
@@ -35,8 +35,23 @@
     // Allow sockaddr_storage to be promoted to DnsTlsServer automatically.
     DnsTlsServer(const sockaddr_storage& ss) : ss(ss) {}
 
+    // Status of a query sent to this server.
+    enum class Response : uint8_t {
+        success,
+        network_error,
+        limit_error,
+        internal_error,
+    };
+
+    // A status code together with the raw response bytes.  |response| is only
+    // meaningful when |code| is success.
+    struct Result {
+        Response code;
+        std::vector<uint8_t> response;
+    };
+
     // The server location, including IP and port.
-    sockaddr_storage ss;
+    sockaddr_storage ss{};
 
     // A set of SHA256 public key fingerprints.  If this set is nonempty, the server
     // must present a self-consistent certificate chain that contains a certificate
diff --git a/server/dns/DnsTlsSessionCache.cpp b/server/dns/DnsTlsSessionCache.cpp
new file mode 100644
index 0000000..880b773
--- /dev/null
+++ b/server/dns/DnsTlsSessionCache.cpp
@@ -0,0 +1,79 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#define LOG_TAG "DnsTlsSessionCache"
+//#define LOG_NDEBUG 0
+
+#include "dns/DnsTlsSessionCache.h"
+
+#include "log/log.h"
+
+namespace android {
+namespace net {
+
+bool DnsTlsSessionCache::prepareSsl(SSL* ssl) {
+    // Add this cache as the 0-index extra data for the socket.
+    // This is used by newSessionCallback.
+    int ret = SSL_set_ex_data(ssl, 0, this);
+    return ret == 1;
+}
+
+void DnsTlsSessionCache::prepareSslContext(SSL_CTX* ssl_ctx) {
+    // Enable client-side session caching and deliver new sessions to newSessionCallback.
+    SSL_CTX_set_session_cache_mode(ssl_ctx, SSL_SESS_CACHE_CLIENT);
+    SSL_CTX_sess_set_new_cb(ssl_ctx, &DnsTlsSessionCache::newSessionCallback);
+}
+
+// static
+int DnsTlsSessionCache::newSessionCallback(SSL* ssl, SSL_SESSION* session) {
+    if (!ssl || !session) {
+        ALOGE("Null SSL object or session in new session callback");
+        return 0;
+    }
+    DnsTlsSessionCache* cache = static_cast<DnsTlsSessionCache*>(SSL_get_ex_data(ssl, 0));
+    if (!cache) {
+        ALOGE("null cache in new session callback");
+        return 0;
+    }
+    ALOGV("Recording session");
+    cache->recordSession(session);
+    // Returning 1 tells libssl that this cache has taken ownership of |session|'s reference.
+    return 1;
+}
+
+void DnsTlsSessionCache::recordSession(SSL_SESSION* session) {
+    std::lock_guard<std::mutex> guard(mLock);
+    // Newest first: getSession() hands out the front, and the oldest is evicted from the back.
+    mSessions.emplace_front(session);
+    if (mSessions.size() > kMaxSize) {
+        ALOGV("Too many sessions; trimming");
+        mSessions.pop_back();
+    }
+}
+
+bssl::UniquePtr<SSL_SESSION> DnsTlsSessionCache::getSession() {
+    std::lock_guard<std::mutex> guard(mLock);
+    if (mSessions.empty()) {
+        ALOGV("No known sessions");
+        return nullptr;
+    }
+    bssl::UniquePtr<SSL_SESSION> ret = std::move(mSessions.front());
+    mSessions.pop_front();
+    return ret;
+}
+
+}  // end of namespace net
+}  // end of namespace android
diff --git a/server/dns/DnsTlsSessionCache.h b/server/dns/DnsTlsSessionCache.h
new file mode 100644
index 0000000..32b8b55
--- /dev/null
+++ b/server/dns/DnsTlsSessionCache.h
@@ -0,0 +1,67 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _DNS_DNSTLSSESSIONCACHE_H
+#define _DNS_DNSTLSSESSIONCACHE_H
+
+#include <mutex>
+#include <deque>
+
+#include <openssl/ssl.h>
+
+#include <android-base/thread_annotations.h>
+#include <android-base/unique_fd.h>
+
+#include "dns/DnsTlsServer.h"
+
+namespace android {
+namespace net {
+
+// Cache of recently seen SSL_SESSIONs.  This is used to support session tickets.
+// This class is thread-safe.
+class DnsTlsSessionCache {
+public:
+    // Prepare SSL objects to use this session cache.  These methods must be called
+    // before making use of either object.  prepareSsl() returns false on failure.
+    void prepareSslContext(SSL_CTX* _Nonnull ssl_ctx);
+    bool prepareSsl(SSL* _Nonnull ssl);
+
+    // Get the most recently discovered session.  For TLS 1.3 compatibility and
+    // maximum privacy, each session will only be returned once, so the caller
+    // gains ownership of the session.  Returns null if no session is cached.
+    // (Here and throughout, bssl::UniquePtr<SSL_SESSION> is actually serving as a
+    // reference counted pointer.)
+    bssl::UniquePtr<SSL_SESSION> getSession() EXCLUDES(mLock);
+
+private:
+    // Maximum number of cached sessions; recording one more drops the oldest.
+    static constexpr size_t kMaxSize = 5;
+    // New-session callback installed by prepareSslContext().  Stores |session| in the
+    // cache attached to |ssl| by prepareSsl().  Returns 1 if it took ownership of |session|.
+    static int newSessionCallback(SSL* _Nullable ssl, SSL_SESSION* _Nullable session);
+
+    std::mutex mLock;
+    // Takes ownership of |session| and stores it as the newest entry.
+    void recordSession(SSL_SESSION* _Nullable session) EXCLUDES(mLock);
+
+    // Queue of sessions, from most recently added (front) to least recently added (back).
+    std::deque<bssl::UniquePtr<SSL_SESSION>> mSessions GUARDED_BY(mLock);
+};
+
+}  // end of namespace net
+}  // end of namespace android
+
+#endif  // _DNS_DNSTLSSESSIONCACHE_H
diff --git a/server/dns/DnsTlsSocket.cpp b/server/dns/DnsTlsSocket.cpp
new file mode 100644
index 0000000..6243968
--- /dev/null
+++ b/server/dns/DnsTlsSocket.cpp
@@ -0,0 +1,474 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#define LOG_TAG "DnsTlsSocket"
+
+#include "dns/DnsTlsSocket.h"
+
+#include <algorithm>
+#include <arpa/inet.h>
+#include <arpa/nameser.h>
+#include <cstring>
+#include <errno.h>
+#include <limits>
+#include <openssl/err.h>
+#include <poll.h>
+#include <sys/select.h>
+#include <utility>
+
+#include "dns/DnsTlsSessionCache.h"
+
+//#define LOG_NDEBUG 0
+
+#include "log/log.h"
+#include "Fwmark.h"
+#undef ADD  // already defined in nameser.h
+#include "NetdConstants.h"
+#include "Permission.h"
+
+
+namespace android {
+namespace net {
+
+using netdutils::Status;
+
+namespace {
+
+constexpr const char kCaCertDir[] = "/system/etc/security/cacerts";
+
+// Blocks until |fd| reports one of |events| (or an error/hangup), retrying on EINTR.
+// poll() is used instead of select() because FD_SET() on a descriptor >= FD_SETSIZE
+// is undefined behavior.  Returns 1 when |fd| is ready and <= 0 on failure.
+int waitForEvents(int fd, short events) {
+    pollfd pfd = {};
+    pfd.fd = fd;
+    pfd.events = events;
+    return TEMP_FAILURE_RETRY(poll(&pfd, 1, -1));
+}
+
+int waitForReading(int fd) {
+    const int ret = waitForEvents(fd, POLLIN);
+    ALOGV_IF(ret <= 0, "poll failed during read");
+    return ret;
+}
+
+int waitForWriting(int fd) {
+    const int ret = waitForEvents(fd, POLLOUT);
+    ALOGV_IF(ret <= 0, "poll failed during write");
+    return ret;
+}
+
+}  // namespace
+
+Status DnsTlsSocket::tcpConnect() {
+    ALOGV("%u connecting TCP socket", mMark);
+    int type = SOCK_NONBLOCK | SOCK_CLOEXEC;
+    switch (mServer.protocol) {
+        case IPPROTO_TCP:
+            type |= SOCK_STREAM;
+            break;
+        default:
+            return Status(EPROTONOSUPPORT);
+    }
+
+    mSslFd.reset(socket(mServer.ss.ss_family, type, mServer.protocol));
+    if (mSslFd.get() == -1) {
+        // Capture errno first: logging may overwrite it.
+        const int err = errno;
+        ALOGE("Failed to create socket");
+        return Status(err);
+    }
+
+    const socklen_t len = sizeof(mMark);
+    if (setsockopt(mSslFd.get(), SOL_SOCKET, SO_MARK, &mMark, len) == -1) {
+        const int err = errno;
+        ALOGE("Failed to set socket mark");
+        mSslFd.reset();
+        return Status(err);
+    }
+    if (connect(mSslFd.get(), reinterpret_cast<const struct sockaddr *>(&mServer.ss),
+                sizeof(mServer.ss)) != 0 &&
+            errno != EINPROGRESS) {
+        const int err = errno;
+        ALOGV("Socket failed to connect");
+        mSslFd.reset();
+        return Status(err);
+    }
+
+    return netdutils::status::ok;
+}
+
+// Computes the SHA-256 digest of |cert|'s DER-encoded SubjectPublicKeyInfo into |out|.
+// Returns false if the key cannot be encoded or the digest cannot be computed.
+bool getSPKIDigest(const X509* cert, std::vector<uint8_t>* out) {
+    const int spki_len = i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), nullptr);
+    if (spki_len <= 0) {
+        // i2d_* functions return a negative value on error; never size a buffer from it.
+        ALOGW("SPKI encoding failed");
+        return false;
+    }
+    // Heap buffer: a variable-length array is not standard C++ and lives on the stack.
+    std::vector<unsigned char> spki(spki_len);
+    unsigned char* temp = spki.data();
+    if (spki_len != i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), &temp)) {
+        ALOGW("SPKI length mismatch");
+        return false;
+    }
+    out->resize(SHA256_SIZE);
+    unsigned int digest_len = 0;
+    int ret = EVP_Digest(spki.data(), spki.size(), out->data(), &digest_len, EVP_sha256(),
+                         nullptr);
+    if (ret != 1) {
+        ALOGW("Server cert digest extraction failed");
+        return false;
+    }
+    if (digest_len != out->size()) {
+        ALOGW("Wrong digest length: %u", digest_len);
+        return false;
+    }
+    return true;
+}
+
+bool DnsTlsSocket::initialize() {
+    // This method should only be called once, at the beginning, so locking should be
+    // unnecessary.  This lock only serves to help catch bugs in code that calls this method.
+    std::lock_guard<std::mutex> guard(mLock);
+    if (mSslCtx) {
+        // This is a bug in the caller.
+        return false;
+    }
+    mSslCtx.reset(SSL_CTX_new(TLS_method()));
+    if (!mSslCtx) {
+        return false;
+    }
+
+    // Load system CA certs for hostname verification.
+    //
+    // For discussion of alternative, sustainable approaches see b/71909242.
+    if (SSL_CTX_load_verify_locations(mSslCtx.get(), nullptr, kCaCertDir) != 1) {
+        ALOGE("Failed to load CA cert dir: %s", kCaCertDir);
+        return false;
+    }
+
+    // Enable TLS false start
+    SSL_CTX_set_false_start_allowed_without_alpn(mSslCtx.get(), 1);
+    SSL_CTX_set_mode(mSslCtx.get(), SSL_MODE_ENABLE_FALSE_START);
+
+    // Enable session cache
+    mCache->prepareSslContext(mSslCtx.get());
+
+    // Connect
+    Status status = tcpConnect();
+    if (!status.ok()) {
+        return false;
+    }
+    mSsl = sslConnect(mSslFd.get());
+    if (!mSsl) {
+        return false;
+    }
+
+    return true;
+}
+
+bssl::UniquePtr<SSL> DnsTlsSocket::sslConnect(int fd) {
+    if (!mSslCtx) {
+        ALOGE("Internal error: context is null in sslConnect");
+        return nullptr;
+    }
+    if (!SSL_CTX_set_min_proto_version(mSslCtx.get(), TLS1_2_VERSION)) {
+        ALOGE("Failed to set minimum TLS version");
+        return nullptr;
+    }
+
+    bssl::UniquePtr<SSL> ssl(SSL_new(mSslCtx.get()));
+    if (!ssl) {
+        ALOGE("Failed to create SSL object");
+        return nullptr;
+    }
+    // This file descriptor is owned by mSslFd, so don't let libssl close it.
+    bssl::UniquePtr<BIO> bio(BIO_new_socket(fd, BIO_NOCLOSE));
+    if (!bio) {
+        ALOGE("Failed to create socket BIO");
+        return nullptr;
+    }
+    // With the same BIO for reading and writing, SSL_set_bio takes ownership of one reference.
+    SSL_set_bio(ssl.get(), bio.get(), bio.get());
+    bio.release();
+
+    if (!mCache->prepareSsl(ssl.get())) {
+        return nullptr;
+    }
+
+    if (!mServer.name.empty()) {
+        if (SSL_set_tlsext_host_name(ssl.get(), mServer.name.c_str()) != 1) {
+            ALOGE("Failed to set SNI to %s", mServer.name.c_str());
+            return nullptr;
+        }
+        X509_VERIFY_PARAM* param = SSL_get0_param(ssl.get());
+        if (X509_VERIFY_PARAM_set1_host(param, mServer.name.c_str(), 0) != 1) {
+            // Without a host name to check, any trusted certificate would be accepted.
+            ALOGE("Failed to set verification host to %s", mServer.name.c_str());
+            return nullptr;
+        }
+        // This will cause the handshake to fail if certificate verification fails.
+        SSL_set_verify(ssl.get(), SSL_VERIFY_PEER, nullptr);
+    }
+
+    bssl::UniquePtr<SSL_SESSION> session = mCache->getSession();
+    if (session) {
+        ALOGV("Setting session");
+        SSL_set_session(ssl.get(), session.get());
+    } else {
+        ALOGV("No session available");
+    }
+
+    for (;;) {
+        ALOGV("%u Calling SSL_connect", mMark);
+        int ret = SSL_connect(ssl.get());
+        ALOGV("%u SSL_connect returned %d", mMark, ret);
+        if (ret == 1) break;  // SSL handshake complete;
+
+        const int ssl_err = SSL_get_error(ssl.get(), ret);
+        switch (ssl_err) {
+            case SSL_ERROR_WANT_READ:
+                if (waitForReading(fd) != 1) {
+                    ALOGW("SSL_connect read error");
+                    return nullptr;
+                }
+                break;
+            case SSL_ERROR_WANT_WRITE:
+                if (waitForWriting(fd) != 1) {
+                    ALOGW("SSL_connect write error");
+                    return nullptr;
+                }
+                break;
+            default:
+                ALOGW("SSL_connect error %d, errno=%d", ssl_err, errno);
+                return nullptr;
+        }
+    }
+
+    // TODO: Call SSL_shutdown before discarding the session if validation fails.
+    if (!mServer.fingerprints.empty()) {
+        ALOGV("Checking DNS over TLS fingerprint");
+
+        // We only care that the chain is internally self-consistent, not that
+        // it chains to a trusted root, so we can ignore some kinds of errors.
+        // TODO: Add a CA root verification mode that respects these errors.
+        int verify_result = SSL_get_verify_result(ssl.get());
+        switch (verify_result) {
+            case X509_V_OK:
+            case X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT:
+            case X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN:
+            case X509_V_ERR_CERT_UNTRUSTED:
+                break;
+            default:
+                ALOGW("Invalid certificate chain, error %d", verify_result);
+                return nullptr;
+        }
+
+        STACK_OF(X509) *chain = SSL_get_peer_cert_chain(ssl.get());
+        if (!chain) {
+            ALOGW("Server has null certificate");
+            return nullptr;
+        }
+        // Chain and its contents are owned by ssl, so we don't need to free explicitly.
+        bool matched = false;
+        for (size_t i = 0; i < sk_X509_num(chain); ++i) {
+            X509* cert = sk_X509_value(chain, i);
+            std::vector<uint8_t> digest;
+            if (!getSPKIDigest(cert, &digest)) {
+                ALOGE("Digest computation failed");
+                return nullptr;
+            }
+
+            if (mServer.fingerprints.count(digest) > 0) {
+                matched = true;
+                break;
+            }
+        }
+
+        if (!matched) {
+            ALOGW("No matching fingerprint");
+            return nullptr;
+        }
+
+        ALOGV("DNS over TLS fingerprint is correct");
+    }
+
+    ALOGV("%u handshake complete", mMark);
+
+    return ssl;
+}
+
+void DnsTlsSocket::sslDisconnect() {
+    if (mSsl) {
+        SSL_shutdown(mSsl.get());
+        mSsl.reset();
+    }
+    mSslFd.reset();
+}
+
+// Writes all of |buffer| in one SSL_write call, waiting on the socket as needed.
+// Returns false on any error.
+bool DnsTlsSocket::sslWrite(const Slice buffer) {
+    ALOGV("%u Writing %zu bytes", mMark, buffer.size());
+    for (;;) {
+        int ret = SSL_write(mSsl.get(), buffer.base(), buffer.size());
+        if (ret == int(buffer.size())) break;  // SSL write complete;
+
+        if (ret > 0) {
+            // A short write is only possible with SSL_MODE_ENABLE_PARTIAL_WRITE, which this
+            // file never enables.  Retrying from the start would resend data, so fail.
+            ALOGW("SSL_write wrote %d of %zu bytes", ret, buffer.size());
+            return false;
+        }
+
+        const int ssl_err = SSL_get_error(mSsl.get(), ret);
+        switch (ssl_err) {
+            case SSL_ERROR_WANT_WRITE:
+                if (waitForWriting(mSslFd.get()) != 1) {
+                    ALOGV("SSL_write error");
+                    return false;
+                }
+                continue;
+            case SSL_ERROR_WANT_READ:
+                // SSL_write may also need to read from the socket.
+                if (waitForReading(mSslFd.get()) != 1) {
+                    ALOGV("SSL_write error");
+                    return false;
+                }
+                continue;
+            default:
+                ALOGV("SSL_write error %d", ssl_err);
+                return false;
+        }
+    }
+    ALOGV("%u Wrote %zu bytes", mMark, buffer.size());
+    return true;
+}
+
+DnsTlsSocket::~DnsTlsSocket() {
+    sslDisconnect();
+}
+
+DnsTlsServer::Result DnsTlsSocket::query(uint16_t id, const Slice query) {
+    std::lock_guard<std::mutex> guard(mLock);
+    const Query q = { .id = id, .query = query };
+    if (!sendQuery(q)) {
+        return { .code = DnsTlsServer::Response::network_error };
+    }
+    return readResponse();
+}
+
+// Reads exactly |buffer.size()| bytes into |buffer|, or returns false.
+bool DnsTlsSocket::sslRead(const Slice buffer) {
+    size_t remaining = buffer.size();
+    while (remaining > 0) {
+        int ret = SSL_read(mSsl.get(), buffer.limit() - remaining, remaining);
+        if (ret == 0) {
+            ALOGW_IF(remaining < buffer.size(), "SSL closed with %zu of %zu bytes remaining",
+                     remaining, buffer.size());
+            return false;
+        }
+
+        if (ret < 0) {
+            const int ssl_err = SSL_get_error(mSsl.get(), ret);
+            if (ssl_err == SSL_ERROR_WANT_READ) {
+                if (waitForReading(mSslFd.get()) != 1) {
+                    ALOGV("SSL_read error");
+                    return false;
+                }
+                continue;
+            } else if (ssl_err == SSL_ERROR_WANT_WRITE) {
+                // SSL_read may also need to write to the socket.
+                if (waitForWriting(mSslFd.get()) != 1) {
+                    ALOGV("SSL_read error");
+                    return false;
+                }
+                continue;
+            } else {
+                ALOGV("SSL_read error %d", ssl_err);
+                return false;
+            }
+        }
+
+        remaining -= ret;
+    }
+    return true;
+}
+
+bool DnsTlsSocket::sendQuery(const Query& q) {
+    ALOGV("sending query");
+    // The 2-byte length prefix counts the 2-byte ID plus the body, so both must fit in 16 bits.
+    constexpr size_t kMaxQuerySize = std::numeric_limits<uint16_t>::max() - 2;
+    if (q.query.size() > kMaxQuerySize) {
+        ALOGW("Query too large: %zu bytes", q.query.size());
+        return false;
+    }
+    // Compose the entire message in a single buffer, so that it can be
+    // sent as a single TLS record.
+    std::vector<uint8_t> buf(q.query.size() + 4);
+    // Write 2-byte length
+    const uint16_t len = static_cast<uint16_t>(q.query.size() + 2); // + 2 for the ID.
+    buf[0] = len >> 8;
+    buf[1] = len;
+    // Write 2-byte ID
+    buf[2] = q.id >> 8;
+    buf[3] = q.id;
+    // Copy body
+    std::memcpy(buf.data() + 4, q.query.base(), q.query.size());
+    if (!sslWrite(netdutils::makeSlice(buf))) {
+        return false;
+    }
+    ALOGV("%u SSL_write complete", mMark);
+    return true;
+}
+
+DnsTlsServer::Result DnsTlsSocket::readResponse() {
+    ALOGV("reading response");
+    uint8_t responseHeader[2];
+    const DnsTlsServer::Result failed = { .code = DnsTlsServer::Response::network_error };
+    if (!sslRead(Slice(responseHeader, 2))) {
+        return failed;
+    }
+    // Truncate responses larger than MAX_SIZE.  This is safe because a DNS packet is
+    // always invalid when truncated, so the response will be treated as an error.
+    constexpr uint16_t MAX_SIZE = 8192;
+    const uint16_t responseSize = (responseHeader[0] << 8) | responseHeader[1];
+    ALOGV("%u Expecting response of size %i", mMark, responseSize);
+    std::vector<uint8_t> response(std::min(responseSize, MAX_SIZE));
+    if (!sslRead(netdutils::makeSlice(response))) {
+        ALOGV("%u Failed to read %zu bytes", mMark, response.size());
+        return failed;
+    }
+    uint16_t remainingBytes = responseSize - response.size();
+    while (remainingBytes > 0) {
+        constexpr uint16_t CHUNK_SIZE = 2048;
+        std::vector<uint8_t> discard(std::min(remainingBytes, CHUNK_SIZE));
+        if (!sslRead(netdutils::makeSlice(discard))) {
+            ALOGV("%u Failed to discard %zu bytes", mMark, discard.size());
+            return failed;
+        }
+        remainingBytes -= discard.size();
+    }
+    ALOGV("%u SSL_read complete", mMark);
+
+    return { .code = DnsTlsServer::Response::success, .response = std::move(response) };
+}
+
+}  // end of namespace net
+}  // end of namespace android
diff --git a/server/dns/DnsTlsSocket.h b/server/dns/DnsTlsSocket.h
new file mode 100644
index 0000000..9f923c5
--- /dev/null
+++ b/server/dns/DnsTlsSocket.h
@@ -0,0 +1,101 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _DNS_DNSTLSSOCKET_H
+#define _DNS_DNSTLSSOCKET_H
+
+#include <future>
+#include <mutex>
+#include <openssl/ssl.h>
+
+#include <android-base/thread_annotations.h>
+#include <android-base/unique_fd.h>
+#include <netdutils/Slice.h>
+#include <netdutils/Status.h>
+
+#include "dns/DnsTlsServer.h"
+#include "dns/IDnsTlsSocket.h"
+
+namespace android {
+namespace net {
+
+class DnsTlsSessionCache;
+
+using netdutils::Slice;
+
+// A class for managing a TLS socket that sends and receives messages in
+// [length][value] format, with a 2-byte length (i.e. DNS-over-TCP format).
+class DnsTlsSocket : public IDnsTlsSocket {
+public:
+    DnsTlsSocket(const DnsTlsServer& server, unsigned mark,
+                 DnsTlsSessionCache* _Nonnull cache) :
+            mMark(mark), mServer(server), mCache(cache) {}
+    ~DnsTlsSocket() override;
+
+    // Creates the SSL context for this session and connects.  Returns false on failure.
+    // This method should be called after construction and before use of a DnsTlsSocket.
+    // Only call this method once per DnsTlsSocket.
+    bool initialize() EXCLUDES(mLock);
+
+    // Sends a query with DNS ID |id| over this object's SSL socket.  |query| contains
+    // the body of a query, not including the 2 ID bytes.  Returns the server's response.
+    DnsTlsServer::Result query(uint16_t id, const Slice query) override;
+
+private:
+    // Guards the SSL socket fields below; held while performing a query.
+    std::mutex mLock;
+
+    // On success, sets mSslFd to a socket connected to mServer (the connection may
+    // still be in progress, e.g. for a non-blocking TCP connect).
+    // On error, returns a failure Status (presumably carrying errno; TODO confirm).
+    netdutils::Status tcpConnect() REQUIRES(mLock);
+
+    // Connect an SSL session on the provided socket.  If connection fails, closing the
+    // socket remains the caller's responsibility.
+    bssl::UniquePtr<SSL> sslConnect(int fd) REQUIRES(mLock);
+
+    // Disconnect the SSL session and close the socket.
+    void sslDisconnect() REQUIRES(mLock);
+
+    // Writes |buffer| to the SSL socket.  Returns false on failure.
+    bool sslWrite(const Slice buffer) REQUIRES(mLock);
+
+    // Reads exactly the specified number of bytes from the socket.  Blocking.
+    // Returns false if the socket closes before enough bytes can be read.
+    bool sslRead(const Slice buffer) REQUIRES(mLock);
+
+    struct Query {
+        uint16_t id;
+        const Slice query;
+    };
+    // sendQuery() writes one [length][id][body] message; readResponse() reads one reply.
+    bool sendQuery(const Query& q) REQUIRES(mLock);
+    DnsTlsServer::Result readResponse() REQUIRES(mLock);
+
+    // SSL Socket fields.
+    bssl::UniquePtr<SSL_CTX> mSslCtx GUARDED_BY(mLock);
+    base::unique_fd mSslFd GUARDED_BY(mLock);
+    bssl::UniquePtr<SSL> mSsl GUARDED_BY(mLock);
+
+    const unsigned mMark;  // Socket mark
+    const DnsTlsServer mServer;  // The server this socket was created for.
+    DnsTlsSessionCache* _Nonnull const mCache;  // Not owned.
+};
+
+}  // end of namespace net
+}  // end of namespace android
+
+#endif  // _DNS_DNSTLSSOCKET_H
diff --git a/server/dns/DnsTlsSocketFactory.h b/server/dns/DnsTlsSocketFactory.h
new file mode 100644
index 0000000..9c597a0
--- /dev/null
+++ b/server/dns/DnsTlsSocketFactory.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _DNS_DNSTLSSOCKETFACTORY_H
+#define _DNS_DNSTLSSOCKETFACTORY_H
+
+#include <memory>
+
+#include "dns/DnsTlsSocket.h"
+#include "dns/IDnsTlsSocketFactory.h"
+
+namespace android {
+namespace net {
+
+class DnsTlsSessionCache;
+struct DnsTlsServer;
+
+// Trivial RAII factory for DnsTlsSocket.  This is owned by DnsTlsDispatcher
+// (DnsTlsTransport::validate() also creates a short-lived local instance).
+class DnsTlsSocketFactory : public IDnsTlsSocketFactory {
+public:
+    // Returns an initialized DnsTlsSocket, or nullptr if DnsTlsSocket::initialize() fails.
+    std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(const DnsTlsServer& server, unsigned mark,
+                                                     DnsTlsSessionCache* _Nonnull cache) override {
+        auto candidate = std::make_unique<DnsTlsSocket>(server, mark, cache);
+        const bool ready = candidate->initialize();
+        return ready ? std::unique_ptr<IDnsTlsSocket>(std::move(candidate)) : nullptr;
+    }
+};
+
+}  // end of namespace net
+}  // end of namespace android
+
+#endif  // _DNS_DNSTLSSOCKETFACTORY_H
diff --git a/server/dns/DnsTlsTransport.cpp b/server/dns/DnsTlsTransport.cpp
index 7610c92..4bf33eb 100644
--- a/server/dns/DnsTlsTransport.cpp
+++ b/server/dns/DnsTlsTransport.cpp
@@ -14,15 +14,18 @@
  * limitations under the License.
  */
 
+#define LOG_TAG "DnsTlsTransport"
+
 #include "dns/DnsTlsTransport.h"
 
 #include <arpa/inet.h>
 #include <arpa/nameser.h>
-#include <errno.h>
-#include <openssl/err.h>
 
-#define LOG_TAG "DnsTlsTransport"
-#define DBG 0
+#include "dns/DnsTlsServer.h"
+#include "dns/DnsTlsSocketFactory.h"
+#include "dns/IDnsTlsSocketFactory.h"
+
+//#define LOG_NDEBUG 0
 
 #include "log/log.h"
 #include "Fwmark.h"
@@ -30,497 +33,33 @@
 #include "NetdConstants.h"
 #include "Permission.h"
 
-
 namespace android {
 namespace net {
 
-namespace {
-
-constexpr const char kCaCertDir[] = "/system/etc/security/cacerts";
-
-bool setNonBlocking(int fd, bool enabled) {
-    int flags = fcntl(fd, F_GETFL);
-    if (flags < 0) return false;
-
-    if (enabled) {
-        flags |= O_NONBLOCK;
-    } else {
-        flags &= ~O_NONBLOCK;
+DnsTlsTransport::Result DnsTlsTransport::query(const netdutils::Slice query) {
+    if (query.size() < 2) {
+        return Result{ .code = Response::internal_error };
     }
-    return (fcntl(fd, F_SETFL, flags) == 0);
+    // Split off the 2-byte DNS ID; IDnsTlsSocket::query() takes it separately.
+    const uint8_t* data = query.base();
+    const uint16_t id = data[0] << 8 | data[1];
+
+    auto socket = mFactory->createDnsTlsSocket(mServer, mMark, &mCache);
+    if (!socket) {
+        return Result{ .code = Response::network_error };
+    }
+
+    return socket->query(id, netdutils::drop(query, 2));
 }
 
-int waitForReading(int fd) {
-    fd_set fds;
-    FD_ZERO(&fds);
-    FD_SET(fd, &fds);
-    const int ret = TEMP_FAILURE_RETRY(select(fd + 1, &fds, nullptr, nullptr, nullptr));
-    if (DBG && ret <= 0) {
-        ALOGD("select");
-    }
-    return ret;
-}
-
-int waitForWriting(int fd) {
-    fd_set fds;
-    FD_ZERO(&fds);
-    FD_SET(fd, &fds);
-    const int ret = TEMP_FAILURE_RETRY(select(fd + 1, nullptr, &fds, nullptr, nullptr));
-    if (DBG && ret <= 0) {
-        ALOGD("select");
-    }
-    return ret;
-}
-
-}  // namespace
-
-android::base::unique_fd DnsTlsTransport::makeConnectedSocket() const {
-    if (DBG) {
-        ALOGD("%u connecting TCP socket", mMark);
-    }
-    android::base::unique_fd fd;
-    int type = SOCK_NONBLOCK | SOCK_CLOEXEC;
-    switch (mServer.protocol) {
-        case IPPROTO_TCP:
-            type |= SOCK_STREAM;
-            break;
-        default:
-            errno = EPROTONOSUPPORT;
-            return fd;
-    }
-
-    fd.reset(socket(mServer.ss.ss_family, type, mServer.protocol));
-    if (fd.get() == -1) {
-        return fd;
-    }
-
-    const socklen_t len = sizeof(mMark);
-    if (setsockopt(fd.get(), SOL_SOCKET, SO_MARK, &mMark, len) == -1) {
-        fd.reset();
-    } else if (connect(fd.get(),
-            reinterpret_cast<const struct sockaddr *>(&mServer.ss), sizeof(mServer.ss)) != 0
-        && errno != EINPROGRESS) {
-        fd.reset();
-    }
-
-    if (!setNonBlocking(fd, false)) {
-        ALOGE("Failed to disable nonblocking status on DNS-over-TLS fd");
-        fd.reset();
-    }
-
-    return fd;
-}
-
-bool getSPKIDigest(const X509* cert, std::vector<uint8_t>* out) {
-    int spki_len = i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), NULL);
-    unsigned char spki[spki_len];
-    unsigned char* temp = spki;
-    if (spki_len != i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), &temp)) {
-        ALOGW("SPKI length mismatch");
-        return false;
-    }
-    out->resize(SHA256_SIZE);
-    unsigned int digest_len = 0;
-    int ret = EVP_Digest(spki, spki_len, out->data(), &digest_len, EVP_sha256(), NULL);
-    if (ret != 1) {
-        ALOGW("Server cert digest extraction failed");
-        return false;
-    }
-    if (digest_len != out->size()) {
-        ALOGW("Wrong digest length: %d", digest_len);
-        return false;
-    }
-    return true;
-}
-
-bool DnsTlsTransport::initialize() {
-    // This method should only be called once, at the beginning, so locking should be
-    // unnecessary.  This lock only serves to help catch bugs in code that calls this method.
-    std::lock_guard<std::mutex> guard(mLock);
-    if (mSslCtx) {
-        // This is a bug in the caller.
-        return false;
-    }
-    mSslCtx.reset(SSL_CTX_new(TLS_method()));
-    if (!mSslCtx) {
-        return false;
-    }
-
-    // Load system CA certs for hostname verification.
-    //
-    // For discussion of alternative, sustainable approaches see b/71909242.
-    if (SSL_CTX_load_verify_locations(mSslCtx.get(), nullptr, kCaCertDir) != 1) {
-        ALOGE("Failed to load CA cert dir: %s", kCaCertDir);
-        return false;
-    }
-
-    SSL_CTX_sess_set_new_cb(mSslCtx.get(), DnsTlsTransport::newSessionCallback);
-    SSL_CTX_sess_set_remove_cb(mSslCtx.get(), DnsTlsTransport::removeSessionCallback);
-
-    // Enable TLS false start.
-    SSL_CTX_set_false_start_allowed_without_alpn(mSslCtx.get(), 1);
-    SSL_CTX_set_mode(mSslCtx.get(), SSL_MODE_ENABLE_FALSE_START);
-    return true;
-}
-
-bssl::UniquePtr<SSL> DnsTlsTransport::sslConnect(int fd) {
-    // Check TLS context.
-    if (!mSslCtx) {
-        ALOGE("Internal error: context is null in ssl connect");
-        return nullptr;
-    }
-    if (!SSL_CTX_set_max_proto_version(mSslCtx.get(), TLS1_3_VERSION) ||
-        !SSL_CTX_set_min_proto_version(mSslCtx.get(), TLS1_2_VERSION)) {
-        ALOGE("failed to min/max TLS versions");
-        return nullptr;
-    }
-
-    bssl::UniquePtr<SSL> ssl(SSL_new(mSslCtx.get()));
-    // This file descriptor is owned by a unique_fd, so don't let libssl close it.
-    bssl::UniquePtr<BIO> bio(BIO_new_socket(fd, BIO_NOCLOSE));
-    SSL_set_bio(ssl.get(), bio.get(), bio.get());
-    bio.release();
-
-    // Add this transport as the 0-index extra data for the socket.
-    // This is used by newSessionCallback.
-    if (SSL_set_ex_data(ssl.get(), 0, this) != 1) {
-        ALOGE("failed to associate SSL socket to transport");
-        return nullptr;
-    }
-
-    // Add this transport as the 0-index extra data for the context.
-    // This is used by removeSessionCallback.
-    if (SSL_CTX_set_ex_data(mSslCtx.get(), 0, this) != 1) {
-        ALOGE("failed to associate SSL context to transport");
-        return nullptr;
-    }
-
-    if (!mServer.name.empty()) {
-        if (SSL_set_tlsext_host_name(ssl.get(), mServer.name.c_str()) != 1) {
-            ALOGE("Failed to set SNI to %s", mServer.name.c_str());
-            return nullptr;
-        }
-        X509_VERIFY_PARAM* param = SSL_get0_param(ssl.get());
-        X509_VERIFY_PARAM_set1_host(param, mServer.name.c_str(), 0);
-        // This will cause the handshake to fail if certificate verification fails.
-        SSL_set_verify(ssl.get(), SSL_VERIFY_PEER, nullptr);
-    }
-
-    bssl::UniquePtr<SSL_SESSION> session;
-    {
-        std::lock_guard<std::mutex> guard(mLock);
-        if (!mSessions.empty()) {
-            session = std::move(mSessions.front());
-            mSessions.pop_front();
-        } else if (DBG) {
-            ALOGD("Starting without session ticket.");
-        }
-    }
-    if (session) {
-        SSL_set_session(ssl.get(), session.get());
-    }
-
-    for (;;) {
-        if (DBG) {
-            ALOGD("%u Calling SSL_connect", mMark);
-        }
-        int ret = SSL_connect(ssl.get());
-        if (DBG) {
-            ALOGD("%u SSL_connect returned %d", mMark, ret);
-        }
-        if (ret == 1) break;  // SSL handshake complete;
-
-        const int ssl_err = SSL_get_error(ssl.get(), ret);
-        switch (ssl_err) {
-            case SSL_ERROR_WANT_READ:
-                if (waitForReading(fd) != 1) {
-                    ALOGW("SSL_connect read error");
-                    return nullptr;
-                }
-                break;
-            case SSL_ERROR_WANT_WRITE:
-                if (waitForWriting(fd) != 1) {
-                    ALOGW("SSL_connect write error");
-                    return nullptr;
-                }
-                break;
-            default:
-                ALOGW("SSL_connect error %d, errno=%d", ssl_err, errno);
-                return nullptr;
-        }
-    }
-
-    if (!mServer.fingerprints.empty()) {
-        if (DBG) {
-            ALOGD("Checking DNS over TLS fingerprint");
-        }
-
-        // We only care that the chain is internally self-consistent, not that
-        // it chains to a trusted root, so we can ignore some kinds of errors.
-        // TODO: Add a CA root verification mode that respects these errors.
-        int verify_result = SSL_get_verify_result(ssl.get());
-        switch (verify_result) {
-            case X509_V_OK:
-            case X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT:
-            case X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN:
-            case X509_V_ERR_CERT_UNTRUSTED:
-                break;
-            default:
-                ALOGW("Invalid certificate chain, error %d", verify_result);
-                return nullptr;
-        }
-
-        STACK_OF(X509) *chain = SSL_get_peer_cert_chain(ssl.get());
-        if (!chain) {
-            ALOGW("Server has null certificate");
-            return nullptr;
-        }
-        // Chain and its contents are owned by ssl, so we don't need to free explicitly.
-        bool matched = false;
-        for (size_t i = 0; i < sk_X509_num(chain); ++i) {
-            // This appears to be O(N^2), but there doesn't seem to be a straightforward
-            // way to walk a STACK_OF nondestructively in linear time.
-            X509* cert = sk_X509_value(chain, i);
-            std::vector<uint8_t> digest;
-            if (!getSPKIDigest(cert, &digest)) {
-                ALOGE("Digest computation failed");
-                return nullptr;
-            }
-
-            if (mServer.fingerprints.count(digest) > 0) {
-                matched = true;
-                break;
-            }
-        }
-
-        if (!matched) {
-            ALOGW("No matching fingerprint");
-            return nullptr;
-        }
-
-        if (DBG) {
-            ALOGD("DNS over TLS fingerprint is correct");
-        }
-    }
-
-    if (DBG) {
-        ALOGD("%u handshake complete", mMark);
-    }
-
-    return ssl;
+DnsTlsTransport::~DnsTlsTransport() {  // Empty: members are destroyed automatically.
 }
 
 // static
-int DnsTlsTransport::newSessionCallback(SSL* ssl, SSL_SESSION* session) {
-    if (!session) {
-        return 0;
-    }
-    if (DBG) {
-        ALOGD("Recording session ticket");
-    }
-    DnsTlsTransport* xport = reinterpret_cast<DnsTlsTransport*>(
-            SSL_get_ex_data(ssl, 0));
-    if (!xport) {
-        ALOGE("null transport in new session callback");
-        return 0;
-    }
-    xport->recordSession(session);
-    return 1;
-}
-
-void DnsTlsTransport::removeSessionCallback(SSL_CTX* ssl_ctx, SSL_SESSION* session) {
-    if (DBG) {
-        ALOGD("Removing session ticket");
-    }
-    DnsTlsTransport* xport = reinterpret_cast<DnsTlsTransport*>(
-            SSL_CTX_get_ex_data(ssl_ctx, 0));
-    if (!xport) {
-        ALOGE("null transport in remove session callback");
-        return;
-    }
-    xport->removeSession(session);
-}
-
-void DnsTlsTransport::recordSession(SSL_SESSION* session) {
-    std::lock_guard<std::mutex> guard(mLock);
-    mSessions.emplace_front(session);
-    if (mSessions.size() > 5) {
-        if (DBG) {
-            ALOGD("Too many sessions; trimming");
-        }
-        mSessions.pop_back();
-    }
-}
-
-void DnsTlsTransport::removeSession(SSL_SESSION* session) {
-    std::lock_guard<std::mutex> guard(mLock);
-    if (session) {
-        // TODO: Consider implementing targeted removal.
-        mSessions.clear();
-    }
-}
-
-void DnsTlsTransport::sslDisconnect(bssl::UniquePtr<SSL> ssl, base::unique_fd fd) {
-    if (ssl) {
-        SSL_shutdown(ssl.get());
-        ssl.reset();
-    }
-    fd.reset();
-}
-
-bool DnsTlsTransport::sslWrite(int fd, SSL *ssl, const uint8_t *buffer, int len) {
-    if (DBG) {
-        ALOGD("%u Writing %d bytes", mMark, len);
-    }
-    for (;;) {
-        int ret = SSL_write(ssl, buffer, len);
-        if (ret == len) break;  // SSL write complete;
-
-        if (ret < 1) {
-            const int ssl_err = SSL_get_error(ssl, ret);
-            switch (ssl_err) {
-                case SSL_ERROR_WANT_WRITE:
-                    if (waitForWriting(fd) != 1) {
-                        if (DBG) {
-                            ALOGW("SSL_write error");
-                        }
-                        return false;
-                    }
-                    continue;
-                case 0:
-                    break;  // SSL write complete;
-                default:
-                    if (DBG) {
-                        ALOGW("SSL_write error %d", ssl_err);
-                    }
-                    return false;
-            }
-        }
-    }
-    if (DBG) {
-        ALOGD("%u Wrote %d bytes", mMark, len);
-    }
-    return true;
-}
-
-// Read exactly len bytes into buffer or fail
-bool DnsTlsTransport::sslRead(int fd, SSL *ssl, uint8_t *buffer, int len) {
-    int remaining = len;
-    while (remaining > 0) {
-        int ret = SSL_read(ssl, buffer + (len - remaining), remaining);
-        if (ret == 0) {
-            ALOGE("SSL socket closed with %i of %i bytes remaining", remaining, len);
-            return false;
-        }
-
-        if (ret < 0) {
-            const int ssl_err = SSL_get_error(ssl, ret);
-            if (ssl_err == SSL_ERROR_WANT_READ) {
-                if (waitForReading(fd) != 1) {
-                    if (DBG) {
-                        ALOGW("SSL_read error");
-                    }
-                    return false;
-                }
-                continue;
-            } else {
-                if (DBG) {
-                    ALOGW("SSL_read error %d", ssl_err);
-                }
-                return false;
-            }
-        }
-
-        remaining -= ret;
-    }
-    return true;
-}
-
-DnsTlsTransport::Response DnsTlsTransport::query(const uint8_t *query, size_t qlen,
-        uint8_t *response, size_t limit, int *resplen) {
-    android::base::unique_fd fd = makeConnectedSocket();
-    if (fd.get() < 0) {
-        ALOGD("%u makeConnectedSocket() failed with: %s", mMark, strerror(errno));
-        return Response::network_error;
-    }
-    bssl::UniquePtr<SSL> ssl = sslConnect(fd.get());
-    if (!ssl) {
-        return Response::network_error;
-    }
-
-    Response res = sendQuery(fd.get(), ssl.get(), query, qlen);
-    if (res == Response::success) {
-        res = readResponse(fd.get(), ssl.get(), query, response, limit, resplen);
-    }
-
-    sslDisconnect(std::move(ssl), std::move(fd));
-    return res;
-}
-
-DnsTlsTransport::Response DnsTlsTransport::sendQuery(int fd, SSL* ssl,
-        const uint8_t *query, size_t qlen) {
-    if (DBG) {
-        ALOGD("sending query");
-    }
-    uint8_t queryHeader[2];
-    queryHeader[0] = qlen >> 8;
-    queryHeader[1] = qlen;
-    if (!sslWrite(fd, ssl, queryHeader, 2)) {
-        return Response::network_error;
-    }
-    if (!sslWrite(fd, ssl, query, qlen)) {
-        return Response::network_error;
-    }
-    if (DBG) {
-        ALOGD("%u SSL_write complete", mMark);
-    }
-    return Response::success;
-}
-
-DnsTlsTransport::Response DnsTlsTransport::readResponse(int fd, SSL* ssl,
-        const uint8_t *query, uint8_t *response, size_t limit, int *resplen) {
-    if (DBG) {
-        ALOGD("reading response");
-    }
-    uint8_t responseHeader[2];
-    if (!sslRead(fd, ssl, responseHeader, 2)) {
-        if (DBG) {
-            ALOGW("%u Failed to read 2-byte length header", mMark);
-        }
-        return Response::network_error;
-    }
-    const uint16_t responseSize = (responseHeader[0] << 8) | responseHeader[1];
-    if (DBG) {
-        ALOGD("%u Expecting response of size %i", mMark, responseSize);
-    }
-    if (responseSize > limit) {
-        ALOGE("%u Response doesn't fit in output buffer: %i", mMark, responseSize);
-        return Response::limit_error;
-    }
-    if (!sslRead(fd, ssl, response, responseSize)) {
-        if (DBG) {
-            ALOGW("%u Failed to read %i bytes", mMark, responseSize);
-        }
-        return Response::network_error;
-    }
-    if (DBG) {
-        ALOGD("%u SSL_read complete", mMark);
-    }
-
-    if (response[0] != query[0] || response[1] != query[1]) {
-        ALOGE("reply query ID != query ID");
-        return Response::internal_error;
-    }
-
-    *resplen = responseSize;
-    return Response::success;
-}
-
-// static
+// TODO: Use this function to preheat the session cache.
+// That may require moving it to DnsTlsDispatcher.
 bool DnsTlsTransport::validate(const DnsTlsServer& server, unsigned netid) {
-    if (DBG) {
-        ALOGD("Beginning validation on %u", netid);
-    }
+    ALOGV("Beginning validation on %u", netid);
     // Generate "<random>-dnsotls-ds.metric.gstatic.com", which we will lookup through |ss| in
     // order to prove that it is actually a working DNS over TLS server.
     static const char kDnsSafeChars[] =
@@ -551,9 +90,6 @@
     };
     const int qlen = ARRAY_SIZE(query);
 
-    const int kRecvBufSize = 4 * 1024;
-    uint8_t recvbuf[kRecvBufSize];
-
     // At validation time, we only know the netId, so we have to guess/compute the
     // corresponding socket mark.
     Fwmark fwmark;
@@ -563,22 +99,16 @@
     fwmark.netId = netid;
     unsigned mark = fwmark.intValue;
-    int replylen = 0;
-    DnsTlsTransport transport(server, mark);
-    if (!transport.initialize()) {
-        return false;
-    }
-    transport.query(query, qlen, recvbuf, kRecvBufSize, &replylen);
-    if (replylen == 0) {
-        if (DBG) {
-            ALOGD("query failed");
-        }
+    DnsTlsSocketFactory factory;
+    DnsTlsTransport transport(server, mark, &factory);
+    auto r = transport.query(Slice(query, qlen));
+    if (r.code != Response::success) {
+        ALOGV("query failed");
         return false;
     }
 
-    if (replylen < NS_HFIXEDSZ) {
-        if (DBG) {
-            ALOGW("short response: %d", replylen);
-        }
+    const std::vector<uint8_t>& recvbuf = r.response;
+    if (recvbuf.size() < NS_HFIXEDSZ) {
+        ALOGW("short response: %zu", recvbuf.size());
         return false;
     }
 
@@ -589,9 +119,7 @@
     }
 
     const int ancount = (recvbuf[6] << 8) | recvbuf[7];
-    if (DBG) {
-        ALOGD("%u answer count: %d", netid, ancount);
-    }
+    ALOGV("%u answer count: %d", netid, ancount);
 
     // TODO: Further validate the response contents (check for valid AAAA record, ...).
     // Note that currently, integration tests rely on this function accepting a
@@ -604,5 +132,5 @@
     return true;
 }
 
-}  // namespace net
-}  // namespace android
+}  // end of namespace net
+}  // end of namespace android
diff --git a/server/dns/DnsTlsTransport.h b/server/dns/DnsTlsTransport.h
index 8f03c8a..4624be7 100644
--- a/server/dns/DnsTlsTransport.h
+++ b/server/dns/DnsTlsTransport.h
@@ -17,37 +17,29 @@
 #ifndef _DNS_DNSTLSTRANSPORT_H
 #define _DNS_DNSTLSTRANSPORT_H
 
-#include <deque>
-#include <mutex>
-#include <openssl/ssl.h>
-
-#include <android-base/thread_annotations.h>
-#include <android-base/unique_fd.h>
-
+#include "dns/DnsTlsSessionCache.h"
 #include "dns/DnsTlsServer.h"
 
+#include <netdutils/Slice.h>
+
 namespace android {
 namespace net {
 
+class IDnsTlsSocketFactory;
+
 class DnsTlsTransport {
 public:
-    DnsTlsTransport(const DnsTlsServer& server, unsigned mark)
-            : mMark(mark), mServer(server)
-            {}
-    ~DnsTlsTransport() {}
+    DnsTlsTransport(const DnsTlsServer& server, unsigned mark,
+                    IDnsTlsSocketFactory* _Nonnull factory) :
+            mMark(mark), mServer(server), mFactory(factory) {}
+    ~DnsTlsTransport();
 
-    // Creates the SSL context for this session.  Returns false on failure.
-    // This method should be called after construction and before use of a DnsTlsTransport.
-    bool initialize();
-    
-    enum class Response : uint8_t { success, network_error, limit_error, internal_error };
+    using Response = DnsTlsServer::Response;
+    using Result = DnsTlsServer::Result;
 
-    // Given a |query| of length |qlen|, this method sends it to the server
-    // and writes the response into |ans|, which can accept up to |anssiz| bytes.
-    // The number of bytes is written to |resplen|.  If |resplen| is zero, an
-    // error has occurred.
-    Response query(const uint8_t *query, size_t qlen,
-            uint8_t *ans, size_t anssiz, int *resplen);
+    // Sends |query| (a DNS query beginning with its 2-byte ID) to the server and returns
+    // the server's response synchronously.  Shorter input yields internal_error.
+    Result query(const netdutils::Slice query);
 
     // Check that a given TLS server is fully working on the specified netid, and has the
     // provided SHA-256 fingerprint (if nonempty).  This function is used in ResolverController
@@ -55,50 +47,14 @@
     static bool validate(const DnsTlsServer& server, unsigned netid);
 
 private:
-    // Send a query on the provided SSL socket.
-    Response sendQuery(int fd, SSL* ssl, const uint8_t *query, size_t qlen);
-
-    // Wait for the response to |query| on |ssl|, and write it to |ans|, an output buffer
-    // of size |anssiz|.  If |resplen| is zero, the read failed.
-    Response readResponse(int fd, SSL* ssl, const uint8_t *query,
-        uint8_t *ans, size_t anssiz, int *resplen);
-
-    // On success, returns a non-blocking socket connected to mAddr (the
-    // connection will likely be in progress if mProtocol is IPPROTO_TCP).
-    // On error, returns -1 with errno set appropriately.
-    base::unique_fd makeConnectedSocket() const;
-
-    // Connect an SSL session on the provided socket.  If connection fails, closing the
-    // socket remains the caller's responsibility.
-    bssl::UniquePtr<SSL> sslConnect(int fd);
-
-    // Disconnect the SSL session and close the socket.
-    void sslDisconnect(bssl::UniquePtr<SSL> ssl, base::unique_fd fd);
-
-    // Writes a buffer to the socket.
-    bool sslWrite(int fd, SSL *ssl, const uint8_t *buffer, int len);
-
-    // Reads exactly the specified number of bytes from the socket.  Blocking.
-    // Returns false if the socket closes before enough bytes can be read.
-    bool sslRead(int fd, SSL *ssl, uint8_t *buffer, int len);
-
-    // Using SSL_CTX to create new SSL objects is thread-safe, so this object does not
-    // require a lock annotation.
-    bssl::UniquePtr<SSL_CTX> mSslCtx;
+    DnsTlsSessionCache mCache;  // Passed to every socket this transport creates.
 
     const unsigned mMark;  // Socket mark
     const DnsTlsServer mServer;
-
-    // Cache of recently seen SSL_SESSIONs.  This is used to support session tickets.
-    static int newSessionCallback(SSL* ssl, SSL_SESSION* session);
-    void recordSession(SSL_SESSION* session);
-    static void removeSessionCallback(SSL_CTX* ssl_ctx, SSL_SESSION* session);
-    void removeSession(SSL_SESSION* session);
-    std::mutex mLock;
-    std::deque<bssl::UniquePtr<SSL_SESSION>> mSessions GUARDED_BY(mLock);
+    IDnsTlsSocketFactory* _Nonnull const mFactory;  // Not owned.
 };
 
-}  // namespace net
-}  // namespace android
+}  // end of namespace net
+}  // end of namespace android
 
 #endif  // _DNS_DNSTLSTRANSPORT_H
diff --git a/server/dns/IDnsTlsSocket.h b/server/dns/IDnsTlsSocket.h
new file mode 100644
index 0000000..1551418
--- /dev/null
+++ b/server/dns/IDnsTlsSocket.h
@@ -0,0 +1,46 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _DNS_IDNSTLSSOCKET_H
+#define _DNS_IDNSTLSSOCKET_H
+
+#include <cstdint>
+#include <cstddef>
+
+#include <netdutils/Slice.h>
+
+#include "dns/DnsTlsServer.h"
+
+namespace android {
+namespace net {
+
+class IDnsTlsSocketObserver;
+class DnsTlsSessionCache;
+
+// Interface for a TLS socket that sends and receives messages in [length][value]
+// format, with a 2-byte length (i.e. DNS-over-TCP format).  Allows mocking in tests.
+class IDnsTlsSocket {
+public:
+    virtual ~IDnsTlsSocket() = default;
+    // Sends a query with DNS ID |id| and returns the server's response.  |query|
+    // contains the body of a query, not including the 2 ID bytes.
+    virtual DnsTlsServer::Result query(uint16_t id, const netdutils::Slice query) = 0;
+};
+
+}  // end of namespace net
+}  // end of namespace android
+
+#endif  // _DNS_IDNSTLSSOCKET_H
diff --git a/server/dns/IDnsTlsSocketFactory.h b/server/dns/IDnsTlsSocketFactory.h
new file mode 100644
index 0000000..6fc0cfd
--- /dev/null
+++ b/server/dns/IDnsTlsSocketFactory.h
@@ -0,0 +1,43 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _DNS_IDNSTLSSOCKETFACTORY_H
+#define _DNS_IDNSTLSSOCKETFACTORY_H
+
+#include <memory>
+
+#include "dns/IDnsTlsSocket.h"
+
+namespace android {
+namespace net {
+
+class DnsTlsSessionCache;
+struct DnsTlsServer;
+
+// Dependency injection interface for DnsTlsSocketFactory.
+// This pattern allows mocking of DnsTlsSocket for tests.
+class IDnsTlsSocketFactory {
+public:
+    virtual ~IDnsTlsSocketFactory() = default;
+    // Returns a ready-to-use socket for |server| and |mark|, or nullptr on failure.
+    virtual std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
+            const DnsTlsServer& server, unsigned mark,
+            DnsTlsSessionCache* _Nonnull cache) = 0;
+};
+
+}  // end of namespace net
+}  // end of namespace android
+
+#endif  // _DNS_IDNSTLSSOCKETFACTORY_H
diff --git a/server/dns/README.md b/server/dns/README.md
new file mode 100644
index 0000000..23185a0
--- /dev/null
+++ b/server/dns/README.md
@@ -0,0 +1,136 @@
+# DNS-over-TLS query forwarder design
+# NOTE: This design is not yet implemented in this change.
+
+
+## Overview
+
+The DNS-over-TLS query forwarder consists of five classes:
+ * `DnsTlsDispatcher`
+ * `DnsTlsTransport`
+ * `DnsTlsQueryMap`
+ * `DnsTlsSessionCache`
+ * `DnsTlsSocket`
+
+`DnsTlsDispatcher` is used as a singleton (one global instance); its `query` method is the
+`dns/` directory's only public interface.  `DnsTlsDispatcher` is just a table holding the
+`DnsTlsTransport` for each server (represented by a `DnsTlsServer` struct) and
+network.  `DnsTlsDispatcher` also blocks each query thread, waiting on a
+`std::future` returned by `DnsTlsTransport` that represents the response.
+
+`DnsTlsTransport` sends each query over a `DnsTlsSocket`, opening a
+new one if necessary.  It also has to listen for responses from the
+`DnsTlsSocket`, which happen on a different thread.
+`IDnsTlsSocketObserver` is an interface defining how `DnsTlsSocket` returns
+responses to `DnsTlsTransport`.
+
+`DnsTlsQueryMap` and `DnsTlsSessionCache` are helper classes owned by `DnsTlsTransport`.
+`DnsTlsQueryMap` handles ID renumbering and query-response pairing.
+`DnsTlsSessionCache` allows TLS session resumption.
+
+`DnsTlsSocket` interleaves all queries onto a single socket, and reports all
+responses to `DnsTlsTransport` (through the `IDnsTlsSocketObserver` interface).  It doesn't
+know anything about which queries correspond to which responses, and does not retain
+state to indicate whether there is an outstanding query.
+
+## Threading
+
+### Overall patterns
+
+For clarity, each of the five classes in this design is thread-safe and holds one lock.
+Classes that spawn a helper thread call `thread::join()` in their destructor to ensure
+that it is cleaned up appropriately.
+
+All the classes here make full use of Clang thread annotations (and also null-pointer
+annotations) to minimize the likelihood of a latent threading bug.  The unit tests are
+also heavily threaded to exercise this functionality.
+
+This code creates O(1) threads per socket, and does not create a new thread for each
+query or response.  However, bionic's stub resolver does create a thread for each query.
+
+### Threading in `DnsTlsSocket`
+
+`DnsTlsSocket` can receive queries on any thread, and send them over a
+"reliable datagram pipe" (`socketpair()` in `SOCK_SEQPACKET` mode).
+The query method writes a struct (containing a pointer to the query) to the pipe
+from its thread, and the loop thread (which owns the SSL socket)
+reads off the other end of the pipe.  The pipe doesn't actually have a queue "inside";
+instead, any queueing happens by blocking the query thread until the
+socket thread can read the datagram off the other end.
+
+We need to pass messages between threads using a pipe, and not a condition variable
+or a thread-safe queue, because the socket thread has to be blocked
+in `select()` waiting for data from the server, but also has to be woken
+up on inputs from the query threads.  Therefore, inputs from the query
+threads have to arrive on a socket, so that `select()` can listen for them.
+(Only a single thread may use the SSL socket because [you can't use different threads
+to read and write in OpenSSL](https://www.openssl.org/blog/blog/2017/02/21/threads/)).
+
+## ID renumbering
+
+`DnsTlsDispatcher` accepts queries that have colliding ID numbers and still sends them on
+a single socket.  To avoid confusion at the server, `DnsTlsQueryMap` assigns each
+query a new ID for transmission, records the mapping from input IDs to sent IDs, and
+applies the inverse mapping to responses before returning them to the caller.
+
+`DnsTlsQueryMap` assigns each new query the ID number one greater than the largest
+ID number of an outstanding query.  This means that ID numbers are initially sequential
+and usually small.  If the largest possible ID number is already in use,
+`DnsTlsQueryMap` will scan the ID space to find an available ID, or fail the query
+if there are no available IDs.  Queries will not block waiting for an ID number to
+become available.
+
+## Time constants
+
+`DnsTlsSocket` imposes a 20-second inactivity timeout.  A socket that has been idle for
+20 seconds will be closed.  This sets the limit of tolerance for slow replies,
+which could happen as a result of malfunctioning authoritative DNS servers.
+If there are any pending queries, `DnsTlsTransport` will retry them.
+
+`DnsTlsQueryMap` imposes a retry limit of 3.  `DnsTlsTransport` will retry the query up
+to 3 times before reporting failure to `DnsTlsDispatcher`.
+This limit helps to ensure proper functioning in the case of a recursive resolver that
+is malfunctioning or is flooded with requests that are stalled due to malfunctioning
+authoritative servers.
+
+`DnsTlsDispatcher` maintains a 5-minute timeout.  Any `DnsTlsTransport` that has had no
+outstanding queries for 5 minutes will be destroyed at the next query on a different
+transport.
+This sets the limit on how long session tickets will be preserved during idle periods,
+because each `DnsTlsTransport` owns a `DnsTlsSessionCache`.  Imposing this timeout
+increases latency on the first query after an idle period, but also helps to avoid
+unbounded memory usage.
+
+`DnsTlsSessionCache` sets a limit of 5 sessions in each cache, expiring the oldest one
+when the limit is reached.  However, because the client code does not currently
+reuse sessions more than once, it should not be possible to hit this limit.
+
+## Testing
+
+Unit tests are in `../tests/dns_tls_test.cpp`.  They cover all the classes except
+`DnsTlsSocket` (which requires `CAP_NET_ADMIN` because it uses `setsockopt(SO_MARK)`) and
+`DnsTlsSessionCache` (which requires integration with libssl).  These classes are
+exercised by the integration tests in `../tests/netd_test.cpp`.
+
+### Dependency Injection
+
+For unit testing, we would like to be able to mock out `DnsTlsSocket`.  This is
+particularly required for unit testing of `DnsTlsDispatcher` and `DnsTlsTransport`.
+To make these unit tests possible, this code uses a dependency injection pattern:
+`DnsTlsSocket` is produced by a `DnsTlsSocketFactory`, and each implements an abstract
+interface (`IDnsTlsSocket` and `IDnsTlsSocketFactory`, respectively).
+
+`DnsTlsDispatcher`'s constructor takes an `IDnsTlsSocketFactory`,
+which in production is a `DnsTlsSocketFactory`.  However, in unit tests, we can
+substitute a test factory that returns a fake socket, so that the unit tests can
+run without actually connecting over TLS to a test server.  (The integration tests
+do actual TLS.)
+
+## Logging
+
+This code uses `ALOGV` throughout for low-priority logging, and does not use
+`ALOGD`.  `ALOGV` is disabled by default, unless activated by `#define LOG_NDEBUG 0`.
+(`ALOGD` is not disabled by default, requiring extra measures to avoid spamming the
+system log in production builds.)
+
+## Reference
+ * [BoringSSL API docs](https://commondatastorage.googleapis.com/chromium-boringssl-docs/headers.html)
diff --git a/tests/Android.mk b/tests/Android.mk
index 9812bcb..5c6fde5 100644
--- a/tests/Android.mk
+++ b/tests/Android.mk
@@ -38,11 +38,18 @@
 # runtest -x system/netd/tests/netd_integration_test.cpp
 LOCAL_SRC_FILES := binder_test.cpp \
                    dns_responder/dns_responder.cpp \
+                   dns_tls_test.cpp \
                    netd_integration_test.cpp \
                    netd_test.cpp \
                    tun_interface.cpp \
                    ../server/NetdConstants.cpp \
-                   ../server/binder/android/net/metrics/INetdEventListener.aidl
+                   ../server/binder/android/net/metrics/INetdEventListener.aidl \
+                   ../server/dns/DnsTlsDispatcher.cpp \
+                   ../server/dns/DnsTlsTransport.cpp \
+                   ../server/dns/DnsTlsServer.cpp \
+                   ../server/dns/DnsTlsSessionCache.cpp \
+                   ../server/dns/DnsTlsSocket.cpp \
+
 LOCAL_MODULE_TAGS := eng tests
 include $(BUILD_NATIVE_TEST)
 
diff --git a/tests/dns_responder/dns_tls_frontend.cpp b/tests/dns_responder/dns_tls_frontend.cpp
index b360060..6c29353 100644
--- a/tests/dns_responder/dns_tls_frontend.cpp
+++ b/tests/dns_responder/dns_tls_frontend.cpp
@@ -317,9 +317,14 @@
     }
     const uint16_t qlen = (queryHeader[0] << 8) | queryHeader[1];
     uint8_t query[qlen];
-    if (SSL_read(ssl, &query, qlen) != qlen) {
-        ALOGI("Not enough query bytes");
-        return false;
+    size_t qbytes = 0;
+    while (qbytes < qlen) {
+        int ret = SSL_read(ssl, query + qbytes, qlen - qbytes);
+        if (ret <= 0) {
+            ALOGI("Error while reading query");
+            return false;
+        }
+        qbytes += ret;
     }
     int sent = send(backend_socket_, query, qlen, 0);
     if (sent != qlen) {
diff --git a/tests/dns_tls_test.cpp b/tests/dns_tls_test.cpp
new file mode 100644
index 0000000..4cb4d50
--- /dev/null
+++ b/tests/dns_tls_test.cpp
@@ -0,0 +1,372 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#define LOG_TAG "dns_tls_test"
+
+#include <gtest/gtest.h>
+
+#include "dns/DnsTlsDispatcher.h"
+#include "dns/DnsTlsServer.h"
+#include "dns/DnsTlsSessionCache.h"
+#include "dns/DnsTlsSocket.h"
+#include "dns/DnsTlsTransport.h"
+#include "dns/IDnsTlsSocket.h"
+#include "dns/IDnsTlsSocketFactory.h"
+
+#include <chrono>
+#include <arpa/inet.h>
+#include <android-base/macros.h>
+#include <netdutils/Slice.h>
+
+#include "log/log.h"
+
+namespace android {
+namespace net {
+
+using netdutils::Slice;
+using netdutils::makeSlice;
+
+typedef std::vector<uint8_t> bytevec;
+
+// Parses the IPv4 or IPv6 literal |server|, with |port|, into |parsed|.
+static void parseServer(const char* server, in_port_t port, sockaddr_storage* parsed) {
+    *parsed = {};  // Start all-zero so unset fields (and a failed parse) are deterministic.
+    sockaddr_in* sin = reinterpret_cast<sockaddr_in*>(parsed);
+    if (inet_pton(AF_INET, server, &(sin->sin_addr)) == 1) {
+        sin->sin_family = AF_INET;  // IPv4 parse succeeded, so it's IPv4.
+        sin->sin_port = htons(port);
+        return;
+    }
+    sockaddr_in6* sin6 = reinterpret_cast<sockaddr_in6*>(parsed);
+    if (inet_pton(AF_INET6, server, &(sin6->sin6_addr)) == 1) {
+        sin6->sin6_family = AF_INET6;  // IPv6 parse succeeded, so it's IPv6.
+        sin6->sin6_port = htons(port);
+        return;
+    }
+    ALOGE("Failed to parse server address: %s", server);
+}
+
+// Arbitrary fingerprints and server names used to build distinct DnsTlsServers.
+const bytevec FINGERPRINT1 = { 1 };
+const bytevec FINGERPRINT2 = { 2 };
+const std::string SERVERNAME1 = "dns.example.com";
+const std::string SERVERNAME2 = "dns.example.org";
+
+// BaseTest provides server addresses and a DnsTlsServer shared by the tests.
+class BaseTest : public ::testing::Test {
+protected:
+    BaseTest() {
+        parseServer("192.0.2.1", 853, &V4ADDR1);
+        parseServer("192.0.2.2", 853, &V4ADDR2);
+        parseServer("2001:db8::1", 853, &V6ADDR1);
+        parseServer("2001:db8::2", 853, &V6ADDR2);
+
+        SERVER1 = DnsTlsServer(V4ADDR1);
+        SERVER1.fingerprints.insert(FINGERPRINT1);
+        SERVER1.name = SERVERNAME1;
+    }
+    // Value-initialized so that no field (e.g. sin6_scope_id) is ever indeterminate.
+    sockaddr_storage V4ADDR1 = {};
+    sockaddr_storage V4ADDR2 = {};
+    sockaddr_storage V6ADDR1 = {};
+    sockaddr_storage V6ADDR2 = {};
+
+    DnsTlsServer SERVER1;
+};
+
+bytevec make_query(uint16_t id, size_t size) {
+    // Build a fake query: big-endian ID followed by ID-derived filler bytes.
+    bytevec query(size);
+    query[0] = static_cast<uint8_t>(id >> 8);
+    query[1] = static_cast<uint8_t>(id & 0xff);
+    for (size_t pos = 2; pos < query.size(); ++pos) {
+        query[pos] = static_cast<uint8_t>(id + pos);
+    }
+    return query;
+}
+
+// Socket mark, ID and size of the canonical test query, and the query itself.
+constexpr unsigned MARK = 123;
+constexpr uint16_t ID = 52;
+constexpr uint16_t SIZE = 22;
+const bytevec QUERY = make_query(ID, SIZE);
+
+// Fake factory that ignores its arguments and returns a new, default-constructed T.
+template <class T>
+class FakeSocketFactory : public IDnsTlsSocketFactory {
+public:
+    FakeSocketFactory() = default;
+    std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
+            const DnsTlsServer& /* server */, unsigned /* mark */,
+            DnsTlsSessionCache* /* cache */) override {
+        return std::make_unique<T>();
+    }
+};
+
+bytevec make_echo(uint16_t id, const Slice query) {
+    // Fake response: the big-endian ID followed by a verbatim copy of |query|.
+    bytevec response = {static_cast<uint8_t>(id >> 8), static_cast<uint8_t>(id)};
+    const uint8_t* body = static_cast<const uint8_t*>(query.base());
+    response.reserve(response.size() + query.size());
+    response.insert(response.end(), body, body + query.size());
+    return response;
+}
+
+// Minimal fake socket: every query is answered synchronously with an echo of itself.
+class FakeSocketEcho : public IDnsTlsSocket {
+public:
+    FakeSocketEcho() = default;
+    DnsTlsServer::Result query(uint16_t id, const Slice query) override {
+        bytevec echoed = make_echo(id, query);
+        return { .code = DnsTlsServer::Response::success, .response = std::move(echoed) };
+    }
+};
+
+class TransportTest : public BaseTest {};
+
+TEST_F(TransportTest, Query) {
+    FakeSocketFactory<FakeSocketEcho> factory;
+    DnsTlsTransport transport(SERVER1, MARK, &factory);
+    auto r = transport.query(makeSlice(QUERY));
+    // The echo socket returns the query body, so the full response should equal QUERY.
+    EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
+    EXPECT_EQ(QUERY, r.response);
+}
+
+TEST_F(TransportTest, SerialQueries) {
+    FakeSocketFactory<FakeSocketEcho> factory;
+    DnsTlsTransport transport(SERVER1, MARK, &factory);
+    // Send more than 65536 queries serially, exceeding the 16-bit ID space.
+    for (int i = 0; i < 100000; ++i) {
+        auto r = transport.query(makeSlice(QUERY));
+        // ASSERT (not EXPECT) so a regression stops the loop instead of logging 100000 failures.
+        ASSERT_EQ(DnsTlsTransport::Response::success, r.code) << "query " << i;
+        ASSERT_EQ(QUERY, r.response) << "query " << i;
+    }
+}
+
+// Factory whose sockets never connect: returning null signals a connection failure.
+class NullSocketFactory : public IDnsTlsSocketFactory {
+public:
+    NullSocketFactory() = default;
+    std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
+            const DnsTlsServer& /* server */,
+            unsigned /* mark */,
+            DnsTlsSessionCache* /* cache */) override {
+        return std::unique_ptr<IDnsTlsSocket>();
+    }
+};
+
+TEST_F(TransportTest, ConnectFail) {
+    NullSocketFactory factory;
+    DnsTlsTransport transport(SERVER1, MARK, &factory);
+    const auto result = transport.query(makeSlice(QUERY));
+    // Without a socket, the transport reports a network error and an empty response.
+    EXPECT_EQ(DnsTlsTransport::Response::network_error, result.code);
+    EXPECT_TRUE(result.response.empty());
+}
+
+// Dispatcher tests
+class DispatcherTest : public BaseTest {};
+
+TEST_F(DispatcherTest, Query) {
+    bytevec answer(4096);  // Ample room for the echoed response.
+    int resplen = 0;
+    DnsTlsDispatcher dispatcher(std::make_unique<FakeSocketFactory<FakeSocketEcho>>());
+
+    const auto code = dispatcher.query(SERVER1, MARK, makeSlice(QUERY),
+                                       makeSlice(answer), &resplen);
+    EXPECT_EQ(DnsTlsTransport::Response::success, code);
+
+    // On success, the first |resplen| bytes of |answer| hold the response.
+    EXPECT_EQ(int(QUERY.size()), resplen);
+    answer.resize(resplen);
+    EXPECT_EQ(QUERY, answer);
+}
+
+TEST_F(DispatcherTest, AnswerTooLarge) {
+    bytevec answer(SIZE - 1);  // One byte too small to hold the echoed response.
+    int resplen = 0;
+    DnsTlsDispatcher dispatcher(std::make_unique<FakeSocketFactory<FakeSocketEcho>>());
+
+    const auto code = dispatcher.query(SERVER1, MARK, makeSlice(QUERY),
+                                       makeSlice(answer), &resplen);
+
+    // A response larger than the caller's buffer is reported as limit_error.
+    EXPECT_EQ(DnsTlsTransport::Response::limit_error, code);
+}
+
+// Fake factory that records the (mark, server) key of every socket it creates.
+template<class T>
+class TrackingFakeSocketFactory : public IDnsTlsSocketFactory {
+public:
+    TrackingFakeSocketFactory() = default;
+    std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
+            const DnsTlsServer& server,
+            unsigned mark, DnsTlsSessionCache* /* cache */) override {
+        std::lock_guard<std::mutex> guard(mLock);  // May be called concurrently.
+        keys.emplace(mark, server);
+        return std::make_unique<T>();
+    }
+    std::multiset<std::pair<unsigned, DnsTlsServer>> keys;  // Written under mLock.
+private:
+    std::mutex mLock;
+};
+
+TEST_F(DispatcherTest, Dispatching) {
+    auto factory = std::make_unique<TrackingFakeSocketFactory<FakeSocketEcho>>();
+    auto* weak_factory = factory.get();  // Valid as long as dispatcher is in scope.
+    DnsTlsDispatcher dispatcher(std::move(factory));
+
+    // Build every combination of two servers and two socket marks,
+    // four keys in total.
+    std::vector<std::pair<unsigned, DnsTlsServer>> keys;
+    for (const DnsTlsServer& server : {SERVER1, DnsTlsServer(V4ADDR2)}) {
+        keys.emplace_back(MARK, server);
+        keys.emplace_back(MARK + 1, server);
+    }
+
+    // Query each key exactly once, each from its own thread.  All should succeed.
+    std::vector<std::thread> threads;
+    for (size_t i = 0; i < keys.size(); ++i) {
+        const auto key = keys[i];
+        threads.emplace_back([key, i] (DnsTlsDispatcher* d) {
+            auto q = make_query(i, SIZE);
+            bytevec ans(4096);
+            int resplen = 0;
+            auto r = d->query(key.second /* server */, key.first /* mark */,
+                              makeSlice(q), makeSlice(ans), &resplen);
+            EXPECT_EQ(DnsTlsTransport::Response::success, r);
+
+            // The echo socket's response should be identical to the query.
+            EXPECT_EQ(int(q.size()), resplen);
+            ans.resize(resplen);
+            EXPECT_EQ(q, ans);
+        }, &dispatcher);
+    }
+    for (std::thread& t : threads) {
+        t.join();
+    }
+    // The dispatcher should have asked the factory for exactly one socket per key.
+    EXPECT_EQ(keys.size(), weak_factory->keys.size());
+    for (const auto& key : keys) {
+        EXPECT_EQ(1U, weak_factory->keys.count(key));
+    }
+}
+
+// Check DnsTlsServer's comparison logic.
+AddressComparator ADDRESS_COMPARATOR;
+// Address-equal means neither server orders before the other under AddressComparator.
+bool isAddressEqual(const DnsTlsServer& s1, const DnsTlsServer& s2) {
+    const bool lt = ADDRESS_COMPARATOR(s1, s2), gt = ADDRESS_COMPARATOR(s2, s1);
+    EXPECT_FALSE(lt && gt);  // A strict ordering can't hold in both directions.
+    return !(lt || gt);
+}
+
+// Expects s1 and s2 to be self-equal, mutually unequal, and strictly ordered.
+void checkUnequal(const DnsTlsServer& s1, const DnsTlsServer& s2) {
+    for (const DnsTlsServer* s : {&s1, &s2}) {
+        EXPECT_TRUE(*s == *s);
+        EXPECT_TRUE(isAddressEqual(*s, *s));
+    }
+    EXPECT_TRUE((s1 < s2) != (s2 < s1));  // Exactly one orders first.
+    EXPECT_FALSE(s1 == s2);
+    EXPECT_FALSE(s2 == s1);
+}
+
+class ServerTest : public BaseTest {};
+// Distinct IPv4 addresses make distinct servers, even comparing addresses alone.
+TEST_F(ServerTest, IPv4) {
+    checkUnequal(DnsTlsServer(V4ADDR1), DnsTlsServer(V4ADDR2));
+    EXPECT_FALSE(isAddressEqual(DnsTlsServer(V4ADDR1), DnsTlsServer(V4ADDR2)));
+}
+// Likewise for distinct IPv6 addresses.
+TEST_F(ServerTest, IPv6) {
+    checkUnequal(DnsTlsServer(V6ADDR1), DnsTlsServer(V6ADDR2));
+    EXPECT_FALSE(isAddressEqual(DnsTlsServer(V6ADDR1), DnsTlsServer(V6ADDR2)));
+}
+// Servers in different address families are distinct.
+TEST_F(ServerTest, MixedAddressFamily) {
+    checkUnequal(DnsTlsServer(V6ADDR1), DnsTlsServer(V4ADDR1));
+    EXPECT_FALSE(isAddressEqual(DnsTlsServer(V6ADDR1), DnsTlsServer(V4ADDR1)));
+}
+
+// Servers differing only in IPv6 scope ID are distinct, even by address alone.
+TEST_F(ServerTest, IPv6ScopeId) {
+    DnsTlsServer s1(V6ADDR1), s2(V6ADDR1);
+    reinterpret_cast<sockaddr_in6*>(&s1.ss)->sin6_scope_id = 1;
+    reinterpret_cast<sockaddr_in6*>(&s2.ss)->sin6_scope_id = 2;
+
+    checkUnequal(s1, s2);
+    EXPECT_FALSE(isAddressEqual(s1, s2));
+}
+
+// Servers differing only in IPv6 flow info compare equal:
+// all comparisons ignore flowinfo.
+TEST_F(ServerTest, IPv6FlowInfo) {
+    DnsTlsServer s1(V6ADDR1), s2(V6ADDR1);
+    reinterpret_cast<sockaddr_in6*>(&s1.ss)->sin6_flowinfo = 1;
+    reinterpret_cast<sockaddr_in6*>(&s2.ss)->sin6_flowinfo = 2;
+
+    EXPECT_EQ(s1, s2);
+    EXPECT_TRUE(isAddressEqual(s1, s2));
+}
+
+// Servers differing only in port are distinct, yet equal by address.
+TEST_F(ServerTest, Port) {
+    const struct { const char* addr; in_port_t port1; in_port_t port2; } cases[] = {
+        {"192.0.2.1", 853, 854},
+        {"2001:db8::1", 853, 852}};
+    for (const auto& c : cases) {
+        DnsTlsServer s1, s2;
+        parseServer(c.addr, c.port1, &s1.ss);
+        parseServer(c.addr, c.port2, &s2.ss);
+        checkUnequal(s1, s2);
+        EXPECT_TRUE(isAddressEqual(s1, s2));
+    }
+}
+
+TEST_F(ServerTest, Name) {  // The name takes part in equality, not in address equality.
+    DnsTlsServer first(V4ADDR1), second(V4ADDR1);
+    first.name = SERVERNAME1;   // One named, one unnamed.
+    checkUnequal(first, second);
+    second.name = SERVERNAME2;  // Both named, differently.
+    checkUnequal(first, second);
+    EXPECT_TRUE(isAddressEqual(first, second));
+}
+
+TEST_F(ServerTest, Fingerprint) {
+    DnsTlsServer s1(V4ADDR1), s2(V4ADDR1);
+    // Fingerprint sets take part in equality, but never in address equality.
+    auto expectUnequalButSameAddress = [&s1, &s2] {
+        checkUnequal(s1, s2);
+        EXPECT_TRUE(isAddressEqual(s1, s2));
+    };
+
+    s1.fingerprints.insert(FINGERPRINT1);  // {FP1} vs {}
+    expectUnequalButSameAddress();
+    s2.fingerprints.insert(FINGERPRINT2);  // {FP1} vs {FP2}
+    expectUnequalButSameAddress();
+    s2.fingerprints.insert(FINGERPRINT1);  // {FP1} vs {FP1, FP2}
+    expectUnequalButSameAddress();
+
+    s1.fingerprints.insert(FINGERPRINT2);  // {FP1, FP2} vs {FP1, FP2}: identical sets.
+    EXPECT_EQ(s1, s2);
+    EXPECT_TRUE(isAddressEqual(s1, s2));
+}
+
+} // end of namespace net
+} // end of namespace android
diff --git a/tests/runtests.sh b/tests/runtests.sh
new file mode 100755
index 0000000..c1a6b70
--- /dev/null
+++ b/tests/runtests.sh
@@ -0,0 +1,94 @@
+#!/usr/bin/env bash
+
+readonly PROJECT_TOP="system/netd"
+
+# TODO:
+#   - add Android.bp test targets
+#   - switch away from runtest.py
+readonly ALL_TESTS="
+    server/netd_unit_test.cpp
+    tests/netd_integration_test.cpp
+"
+
+REPO_TOP=""
+DEBUG=""
+
+function logToStdErr() {
+    >&2 echo "$1"  # Prints only the first argument, on stderr.
+}
+
+function testAndSetRepoTop() {
+    # Succeeds, recording $1 in REPO_TOP, iff $1 is non-empty and has a .repo dir.
+    if ! [[ -n "$1" && -d "$1/.repo" ]]; then
+        return 1
+    fi
+    REPO_TOP="$1"
+}
+
+function gotoRepoTop() {
+    # Prefer $ANDROID_BUILD_TOP when it points at a repo checkout.
+    if testAndSetRepoTop "$ANDROID_BUILD_TOP"; then
+        return
+    fi
+    # Otherwise cd upwards from $PWD until a directory containing .repo is found,
+    # giving up at / (in which case REPO_TOP is left unchanged).
+    until testAndSetRepoTop "$PWD"; do
+        [[ "$PWD" == "/" ]] && break
+        cd ..
+    done
+}
+
+function runOneTest() {
+    local testName="$1"
+    # An array keeps each path a single word even if $REPO_TOP contains whitespace.
+    local cmd=("$REPO_TOP/development/testrunner/runtest.py" -x "$PROJECT_TOP/$testName")
+    echo "###"
+    echo "# $testName"
+    echo "#"
+    echo "# ${cmd[*]}"
+    echo "###"
+    echo ""
+    $DEBUG "${cmd[@]}"  # With DEBUG=echo this prints the command instead of running it.
+    local rval=$?
+    echo ""
+    # NOTE: currently runtest.py returns 0 even for failed tests.
+    return $rval
+}
+
+function main() {
+    gotoRepoTop
+    if ! testAndSetRepoTop "$REPO_TOP"; then
+        logToStdErr "Could not find useful top of repo directory"
+        return 1
+    fi
+    logToStdErr "using REPO_TOP=$REPO_TOP"
+
+    # Optional leading "-n": dry run that prints each test command
+    # instead of running it.
+    if [[ "$1" == "-n" ]]; then
+        DEBUG=echo
+        shift
+    fi
+
+    # Allow us to do things like "runtests.sh integration", etc.  The pattern is
+    # a bash regex, so it must stay unquoted after =~ to be matched as one.
+    readonly TEST_REGEX="$1"
+
+    local failures=0
+    for testName in $ALL_TESTS; do
+        if [[ -z "$TEST_REGEX" || "$testName" =~ $TEST_REGEX ]]; then
+            # Count failing tests rather than summing exit codes, so the
+            # total can't wrap modulo 256 into a false "success" status.
+            runOneTest "$testName" || failures=$((failures + 1))
+        else
+            logToStdErr "Skipping $testName"
+        fi
+    done
+
+    echo "Number of tests failing: $failures"
+    return $failures
+}
+
+
+main "$@"
+exit  # With no argument, exit uses main's status: 1 or the failing-test count.