Fix DnsTlsSocket fast shutdown path

Previously, DnsTlsSocket's destructor told the loop thread to
perform a clean shutdown by closing an IPC file descriptor.
However, the IPC file descriptor is now an eventfd, and closing an
eventfd does not wake the loop thread's poll(), so the shutdown request
was lost and the destructor waited for the loop's timeout instead.

This change has the destructor add INT64_MIN to the eventfd counter,
setting its sign bit; the loop thread treats a negative counter value
as a request to close immediately.

Test: Includes regression test (DnsTlsSocketTest.SlowDestructor).
Bug: 123212403
Bug: 124058672
Bug: 122856181
Change-Id: I6edc26bf504cbfbba7d055b1f8e52ac70e02c6e0
Merged-In: I6edc26bf504cbfbba7d055b1f8e52ac70e02c6e0
(cherry picked from commit 83eccadc7e9d0ee0f75aab980cfdc2159c4c98a2)
diff --git a/server/dns/DnsTlsSocket.cpp b/server/dns/DnsTlsSocket.cpp
index 8b2d200..237d6bf 100644
--- a/server/dns/DnsTlsSocket.cpp
+++ b/server/dns/DnsTlsSocket.cpp
@@ -350,6 +350,8 @@
 
         // If we have pending queries, wait for space to write one.
         // Otherwise, listen for new queries.
+        // Note: This blocks the destructor until q is empty, i.e. until all pending
+        // queries are sent or have failed to send.
         if (!q.empty()) {
             fds[SSLFD].events |= POLLOUT;
         } else {
@@ -366,7 +368,7 @@
             ALOGV("Poll failed: %d", errno);
             break;
         }
-        if (fds[SSLFD].revents & (POLLIN | POLLERR)) {
+        if (fds[SSLFD].revents & (POLLIN | POLLERR | POLLHUP)) {
             if (!readResponse()) {
                 ALOGV("SSL remote close or read error.");
                 break;
@@ -379,23 +381,17 @@
                 ALOGW("Error during eventfd read");
                 break;
             } else if (res == 0) {
-                ALOGV("eventfd closed; disconnecting");
+                ALOGW("eventfd closed; disconnecting");
                 break;
             } else if (res != sizeof(num_queries)) {
                 ALOGE("Int size mismatch: %zd != %zu", res, sizeof(num_queries));
                 break;
-            } else if (num_queries <= 0) {
-                ALOGE("eventfd reads should always be positive");
+            } else if (num_queries < 0) {
+                ALOGV("Negative eventfd read indicates destructor-initiated shutdown");
                 break;
             }
             // Take ownership of all pending queries.  (q is always empty here.)
             mQueue.swap(q);
-            // The writing thread writes to mQueue and then increments mEventFd, so
-            // there should be at least num_queries entries in mQueue.
-            if (q.size() < (uint64_t) num_queries) {
-                ALOGE("Synchronization error");
-                break;
-            }
         } else if (fds[SSLFD].revents & POLLOUT) {
             // q cannot be empty here.
             // Sending the entire queue here would risk a TCP flow control deadlock, so
@@ -408,8 +404,6 @@
             q.pop_front();
         }
     }
-    ALOGV("Closing event FD");
-    mEventFd.reset();
     ALOGV("Disconnecting");
     sslDisconnect();
     ALOGV("Calling onClosed");
@@ -420,12 +414,7 @@
 DnsTlsSocket::~DnsTlsSocket() {
     ALOGV("Destructor");
     // This will trigger an orderly shutdown in loop().
-    // In principle there is a data race here: If there is an I/O error in the network thread
-    // simultaneous with a call to the destructor in a different thread, both threads could
-    // attempt to call mEventFd.reset() at the same time.  However, the implementation of
-    // UniqueFd::reset appears to be thread-safe, and neither thread reads or writes mEventFd
-    // after this point, so we don't expect an issue in practice.
-    mEventFd.reset();
+    requestLoopShutdown();
     {
         // Wait for the orderly shutdown to complete.
         std::lock_guard<std::mutex> guard(mLock);
@@ -443,10 +432,6 @@
 }
 
 bool DnsTlsSocket::query(uint16_t id, const Slice query) {
-    if (!mEventFd) {
-        return false;
-    }
-
     // Compose the entire message in a single buffer, so that it can be
     // sent as a single TLS record.
     std::vector<uint8_t> buf(query.size() + 4);
@@ -462,9 +447,25 @@
 
     mQueue.push(std::move(buf));
     // Increment the mEventFd counter by 1.
-    constexpr int64_t num_queries = 1;
-    int written = write(mEventFd.get(), &num_queries, sizeof(num_queries));
-    return written == sizeof(num_queries);
+    return incrementEventFd(1);
+}
+
+void DnsTlsSocket::requestLoopShutdown() {
+    // Add INT64_MIN so the counter reads as negative; loop() treats that as a shutdown request.
+    incrementEventFd(INT64_MIN);
+}
+
+bool DnsTlsSocket::incrementEventFd(const int64_t count) {
+    if (!mEventFd) {  // initialize() has not opened the eventfd.
+        ALOGV("eventfd is not initialized");
+        return false;
+    }
+    const ssize_t written = write(mEventFd.get(), &count, sizeof(count));  // -1 on error
+    if (written != static_cast<ssize_t>(sizeof(count))) {
+        ALOGE("Failed to increment eventfd by %" PRId64 ": %d", count, errno);
+        return false;
+    }
+    return true;
 }
 
 // Read exactly len bytes into buffer or fail with an SSL error code
diff --git a/server/dns/DnsTlsSocket.h b/server/dns/DnsTlsSocket.h
index 2593bcf..57e1acc 100644
--- a/server/dns/DnsTlsSocket.h
+++ b/server/dns/DnsTlsSocket.h
@@ -65,7 +65,7 @@
     // notified that the socket is closed.
     // Note that success here indicates successful sending, not receipt of a response.
     // Thread-safe.
-    bool query(uint16_t id, const Slice query) override;
+    bool query(uint16_t id, const Slice query) override EXCLUDES(mLock);
 
 private:
     // Lock to be held by the SSL event loop thread.  This is not normally in contention.
@@ -99,6 +99,15 @@
     bool sendQuery(const std::vector<uint8_t>& buf) REQUIRES(mLock);
     bool readResponse() REQUIRES(mLock);
 
+    // Like query(), this uses incrementEventFd() to send a message to the loop thread.
+    // Instead of adding 1 (one new query), it adds INT64_MIN, which makes the counter
+    // negative when loop() reads it; loop() treats a negative read as a shutdown request.
+    // The return value of incrementEventFd() is ignored; failures are only logged.
+    void requestLoopShutdown() EXCLUDES(mLock);
+
+    // Notifies the loop thread by adding count to mEventFd.  Returns false on failure.
+    bool incrementEventFd(int64_t count) EXCLUDES(mLock);
+
     // Queue of pending queries.  query() pushes items onto the queue and notifies
     // the loop thread by incrementing mEventFd.  loop() reads items off the queue.
     LockedQueue<std::vector<uint8_t>> mQueue;
@@ -106,8 +115,10 @@
     // eventfd socket used for notifying the SSL thread when queries are ready to send.
     // This socket acts similarly to an atomic counter, incremented by query() and cleared
     // by loop().  We have to use a socket because the SSL thread needs to wait in poll()
-    // for input from either a remote server or a query thread.
-    // EOF indicates a close request.
+    // for input from either a remote server or a query thread.  Closing an eventfd does
+    // not wake poll(), so a close request is instead sent by making the counter negative.
+    // This file descriptor is opened by initialize() and closed by the unique_fd
+    // destructor when this DnsTlsSocket is destroyed.
     base::unique_fd mEventFd;
 
     // SSL Socket fields.
diff --git a/tests/dns_tls_test.cpp b/tests/dns_tls_test.cpp
index bb5bfe5..b7fb3a4 100644
--- a/tests/dns_tls_test.cpp
+++ b/tests/dns_tls_test.cpp
@@ -28,6 +28,8 @@
 #include "dns/IDnsTlsSocketFactory.h"
 #include "dns/IDnsTlsSocketObserver.h"
 
+#include "dns_responder/dns_tls_frontend.h"
+
 #include <chrono>
 #include <arpa/inet.h>
 #include <android-base/macros.h>
@@ -871,5 +873,44 @@
     EXPECT_FALSE(map.recordQuery(makeSlice(QUERY)));
 }
 
+struct StubObserver : IDnsTlsSocketObserver {
+    // Set to true by onClosed(); the test checks it after destroying the socket.
+    bool closed{false};
+    void onResponse(std::vector<uint8_t> /*response*/) override {}
+
+    void onClosed() override { closed = true; }
+};
+
+TEST(DnsTlsSocketTest, SlowDestructor) {
+    constexpr char tls_addr[] = "127.0.0.3";
+    constexpr char tls_port[] = "8530";  // High-numbered port so root isn't required.
+    // This test doesn't perform any queries, so the backend address can be invalid.
+    constexpr char backend_addr[] = "192.0.2.1";
+    constexpr char backend_port[] = "1";
+
+    test::DnsTlsFrontend tls(tls_addr, tls_port, backend_addr, backend_port);
+    ASSERT_TRUE(tls.startServer());
+
+    DnsTlsServer server;
+    parseServer(tls_addr, 8530, &server.ss);  // Must match tls_port above.
+
+    StubObserver observer;
+    ASSERT_FALSE(observer.closed);
+    DnsTlsSessionCache cache;
+    auto socket = std::make_unique<DnsTlsSocket>(server, MARK, &observer, &cache);
+    ASSERT_TRUE(socket->initialize());
+
+    // Test: Time the socket destructor.  It should signal loop() and return quickly.
+    auto before = std::chrono::steady_clock::now();
+    socket.reset();
+    auto after = std::chrono::steady_clock::now();
+    auto delay = after - before;
+    ALOGV("Shutdown took %lld ns", static_cast<long long>(delay / std::chrono::nanoseconds{1}));
+    EXPECT_TRUE(observer.closed);
+    // Shutdown should complete in milliseconds, but if the shutdown signal is lost
+    // it will wait for the timeout, which is expected to take 20 seconds.
+    EXPECT_LT(delay, std::chrono::seconds{5});
+}
+
 } // end of namespace net
 } // end of namespace android