Copy queries synchronously in DnsTlsSocket

Prior to this change, each outgoing query was copied only once,
on the DnsTlsSocket's loop thread.  This could create a problem
if a misbehaving server sent an erroneous response with a
colliding ID number after the query was given to DnsTlsSocket
but before the copy was made.  The erroneous response would
complete the query, causing the caller to deallocate the backing
buffer, resulting in a segfault on copy.

This change moves the copy earlier, onto the calling thread, thus
ensuring that the backing buffer cannot have been deallocated.
Instead of sending the network thread pointers to query buffers,
copies of queries are stored in a shared queue, and the network
thread is notified of new queries via an eventfd.

Bug: 122133500
Bug: 122856181
Test: Integration tests pass, manual tests good.  No regression test.
Change-Id: Ia4e72da561aeef69a17e87bfdc7aa04340c12fd0
Merged-In: Ia4e72da561aeef69a17e87bfdc7aa04340c12fd0
(cherry picked from commit 8b8cf0388d3d463f474795e8996197f267a416e7)
diff --git a/server/dns/DnsTlsSocket.cpp b/server/dns/DnsTlsSocket.cpp
index ca1cdc9..8b2d200 100644
--- a/server/dns/DnsTlsSocket.cpp
+++ b/server/dns/DnsTlsSocket.cpp
@@ -19,13 +19,14 @@
 
 #include "dns/DnsTlsSocket.h"
 
-#include <algorithm>
 #include <arpa/inet.h>
 #include <arpa/nameser.h>
 #include <errno.h>
 #include <linux/tcp.h>
 #include <openssl/err.h>
+#include <sys/eventfd.h>
 #include <sys/poll.h>
+#include <algorithm>
 
 #include "dns/DnsTlsSessionCache.h"
 #include "dns/IDnsTlsSocketObserver.h"
@@ -166,14 +167,8 @@
     if (!mSsl) {
         return false;
     }
-    int sv[2];
-    if (socketpair(AF_LOCAL, SOCK_SEQPACKET, 0, sv)) {
-        return false;
-    }
-    // The two sockets are perfectly symmetrical, so the choice of which one is
-    // "in" and which one is "out" is arbitrary.
-    mIpcInFd.reset(sv[0]);
-    mIpcOutFd.reset(sv[1]);
+
+    mEventFd.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC));
 
     // Start the I/O loop.
     mLoopThread.reset(new std::thread(&DnsTlsSocket::loop, this));
@@ -341,26 +336,25 @@
 
 void DnsTlsSocket::loop() {
     std::lock_guard<std::mutex> guard(mLock);
-    // Buffer at most one query.
-    Query q;
+    std::deque<std::vector<uint8_t>> q;
 
     const int timeout_msecs = DnsTlsSocket::kIdleTimeout.count() * 1000;
     while (true) {
         // poll() ignores negative fds
         struct pollfd fds[2] = { { .fd = -1 }, { .fd = -1 } };
-        enum { SSLFD = 0, IPCFD = 1 };
+        enum { SSLFD = 0, EVENTFD = 1 };
 
         // Always listen for a response from server.
         fds[SSLFD].fd = mSslFd.get();
         fds[SSLFD].events = POLLIN;
 
-        // If we have a pending query, also wait for space
-        // to write it, otherwise listen for a new query.
-        if (!q.query.empty()) {
+        // If we have pending queries, wait for space to write one.
+        // Otherwise, listen for new queries.
+        if (!q.empty()) {
             fds[SSLFD].events |= POLLOUT;
         } else {
-            fds[IPCFD].fd = mIpcOutFd.get();
-            fds[IPCFD].events = POLLIN;
+            fds[EVENTFD].fd = mEventFd.get();
+            fds[EVENTFD].events = POLLIN;
         }
 
         const int s = TEMP_FAILURE_RETRY(poll(fds, ARRAY_SIZE(fds), timeout_msecs));
@@ -378,28 +372,44 @@
                 break;
             }
         }
-        if (fds[IPCFD].revents & (POLLIN | POLLERR)) {
-            int res = read(mIpcOutFd.get(), &q, sizeof(q));
+        if (fds[EVENTFD].revents & (POLLIN | POLLERR)) {
+            int64_t num_queries;
+            ssize_t res = read(mEventFd.get(), &num_queries, sizeof(num_queries));
             if (res < 0) {
-                ALOGW("Error during IPC read");
+                ALOGW("Error during eventfd read");
                 break;
             } else if (res == 0) {
-                ALOGV("IPC channel closed; disconnecting");
+                ALOGV("eventfd closed; disconnecting");
                 break;
-            } else if (res != sizeof(q)) {
-                ALOGE("Struct size mismatch: %d != %zu", res, sizeof(q));
+            } else if (res != sizeof(num_queries)) {
+                ALOGE("Int size mismatch: %zd != %zu", res, sizeof(num_queries));
+                break;
+            } else if (num_queries <= 0) {
+                ALOGE("eventfd reads should always be positive");
+                break;
+            }
+            // Take ownership of all pending queries.  (q is always empty here.)
+            mQueue.swap(q);
+            // The writing thread writes to mQueue and then increments mEventFd, so
+            // there should be at least num_queries entries in mQueue.
+            if (q.size() < (uint64_t) num_queries) {
+                ALOGE("Synchronization error");
                 break;
             }
         } else if (fds[SSLFD].revents & POLLOUT) {
-            // query cannot be null here.
-            if (!sendQuery(q)) {
+            // q cannot be empty here.
+            // Sending the entire queue here would risk a TCP flow control deadlock, so
+            // we only send a single query on each cycle of this loop.
+            // TODO: Coalesce multiple pending queries if there is enough space in the
+            // write buffer.
+            if (!sendQuery(q.front())) {
                 break;
             }
-            q = Query();  // Reset q to empty
+            q.pop_front();
         }
     }
-    ALOGV("Closing IPC read FD");
-    mIpcOutFd.reset();
+    ALOGV("Closing event FD");
+    mEventFd.reset();
     ALOGV("Disconnecting");
     sslDisconnect();
     ALOGV("Calling onClosed");
@@ -410,7 +420,12 @@
 DnsTlsSocket::~DnsTlsSocket() {
     ALOGV("Destructor");
     // This will trigger an orderly shutdown in loop().
-    mIpcInFd.reset();
+    // In principle there is a data race here: If there is an I/O error in the network thread
+    // simultaneous with a call to the destructor in a different thread, both threads could
+    // attempt to call mEventFd.reset() at the same time.  However, the implementation of
+    // UniqueFd::reset appears to be thread-safe, and neither thread reads or writes mEventFd
+    // after this point, so we don't expect an issue in practice.
+    mEventFd.reset();
     {
         // Wait for the orderly shutdown to complete.
         std::lock_guard<std::mutex> guard(mLock);
@@ -428,12 +443,28 @@
 }
 
 bool DnsTlsSocket::query(uint16_t id, const Slice query) {
-    const Query q = { .id = id, .query = query };
-    if (!mIpcInFd) {
+    // Fail fast if the eventfd is no longer open, or if the query is too large for
+    // the 2-byte length prefix written below (which also counts the 2-byte ID).
+    if (mEventFd.get() < 0 || query.size() > UINT16_MAX - 2) {
         return false;
     }
-    int written = write(mIpcInFd.get(), &q, sizeof(q));
-    return written == sizeof(q);
+    // Copy the query now, on the calling thread.  The entire message is composed in
+    // a single buffer, so that it can be sent as a single TLS record.
+    std::vector<uint8_t> buf(query.size() + 4);
+    // Write 2-byte length
+    const uint16_t len = query.size() + 2;  // + 2 for the ID.
+    buf[0] = len >> 8;
+    buf[1] = len;
+    // Write 2-byte ID
+    buf[2] = id >> 8;
+    buf[3] = id;
+    // Copy body
+    std::memcpy(buf.data() + 4, query.base(), query.size());
+    // Queue the copy first, then increment the mEventFd counter by 1 to notify loop().
+    mQueue.push(std::move(buf));
+    constexpr int64_t num_queries = 1;
+    const ssize_t written = write(mEventFd.get(), &num_queries, sizeof(num_queries));
+    return written == sizeof(num_queries);
 }
 
 // Read exactly len bytes into buffer or fail with an SSL error code
@@ -467,20 +498,7 @@
     return SSL_ERROR_NONE;
 }
 
-bool DnsTlsSocket::sendQuery(const Query& q) {
-    ALOGV("sending query");
-    // Compose the entire message in a single buffer, so that it can be
-    // sent as a single TLS record.
-    std::vector<uint8_t> buf(q.query.size() + 4);
-    // Write 2-byte length
-    uint16_t len = q.query.size() + 2; // + 2 for the ID.
-    buf[0] = len >> 8;
-    buf[1] = len;
-    // Write 2-byte ID
-    buf[2] = q.id >> 8;
-    buf[3] = q.id;
-    // Copy body
-    std::memcpy(buf.data() + 4, q.query.base(), q.query.size());
+bool DnsTlsSocket::sendQuery(const std::vector<uint8_t>& buf) {
     if (!sslWrite(netdutils::makeSlice(buf))) {
         return false;
     }
diff --git a/server/dns/DnsTlsSocket.h b/server/dns/DnsTlsSocket.h
index 0c37a52..2593bcf 100644
--- a/server/dns/DnsTlsSocket.h
+++ b/server/dns/DnsTlsSocket.h
@@ -28,6 +28,7 @@
 
 #include "dns/DnsTlsServer.h"
 #include "dns/IDnsTlsSocket.h"
+#include "dns/LockedQueue.h"
 
 namespace android {
 namespace net {
@@ -95,20 +96,19 @@
     // will return SSL_ERROR_WANT_READ if there is no data from the server to read.
     int sslRead(const Slice buffer, bool wait) REQUIRES(mLock);
 
-    struct Query {
-        uint16_t id;
-        Slice query;
-    };
-
-    bool sendQuery(const Query& q) REQUIRES(mLock);
+    bool sendQuery(const std::vector<uint8_t>& buf) REQUIRES(mLock);
     bool readResponse() REQUIRES(mLock);
 
-    // SOCK_SEQPACKET socket pair used for sending queries from myriad query
-    // threads to the SSL thread.  EOF indicates a close request.
-    // We have to use a socket pair (i.e. a pipe) because the SSL thread needs to wait in
-    // select() for input from either a remote server or a query thread.
-    base::unique_fd mIpcInFd;
-    base::unique_fd mIpcOutFd GUARDED_BY(mLock);
+    // Queue of pending queries.  query() pushes items onto the queue and notifies
+    // the loop thread by incrementing mEventFd.  loop() reads items off the queue.
+    LockedQueue<std::vector<uint8_t>> mQueue;
+
+    // eventfd used for notifying the SSL thread when queries are ready to send.
+    // It acts similarly to an atomic counter, incremented by query() and cleared
+    // by loop().  We have to use a file descriptor because the SSL thread needs to wait
+    // in poll() for input from either a remote server or a query thread.
+    // The destructor closes it to request an orderly shutdown of loop().
+    base::unique_fd mEventFd;
 
     // SSL Socket fields.
     bssl::UniquePtr<SSL_CTX> mSslCtx GUARDED_BY(mLock);
diff --git a/server/dns/LockedQueue.h b/server/dns/LockedQueue.h
new file mode 100644
index 0000000..65b81ce
--- /dev/null
+++ b/server/dns/LockedQueue.h
@@ -0,0 +1,52 @@
+/*
+ * Copyright (C) 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _DNS_LOCKED_QUEUE_H
+#define _DNS_LOCKED_QUEUE_H
+
+#include <algorithm>
+#include <deque>
+#include <mutex>
+
+#include <android-base/thread_annotations.h>
+
+namespace android {
+namespace net {
+
+// Mutex-protected FIFO: producers push() single items; a consumer swap()s out the batch.
+template <typename T>
+class LockedQueue {
+  public:
+    // Append an item at the back, so draining from the front preserves push order.
+    void push(T item) {
+        std::lock_guard<std::mutex> guard(mLock);
+        mQueue.push_back(std::move(item));
+    }
+
+    // Exchange the queue's contents with |other| while holding the lock.
+    void swap(std::deque<T>& other) {
+        std::lock_guard<std::mutex> guard(mLock);
+        mQueue.swap(other);
+    }
+  private:
+    std::mutex mLock;
+    std::deque<T> mQueue GUARDED_BY(mLock);
+};
+
+}  // end of namespace net
+}  // end of namespace android
+
+#endif  // _DNS_LOCKED_QUEUE_H