Merge "Make DnsTlsTransport's query method static"
diff --git a/server/DnsProxyListener.cpp b/server/DnsProxyListener.cpp
index 8db1da8..a4ed07a 100644
--- a/server/DnsProxyListener.cpp
+++ b/server/DnsProxyListener.cpp
@@ -100,16 +100,14 @@
         ALOGE("qhook abort: unknown address family");
         return res_goahead;
     }
-    sockaddr_storage secureResolver;
-    std::set<std::vector<uint8_t>> fingerprints;
+    DnsTlsTransport::Server tlsServer;
     if (net::gCtls->resolverCtrl.shouldUseTls(thread_netcontext.dns_netid,
-            insecureResolver, &secureResolver, &fingerprints)) {
+            insecureResolver, &tlsServer)) {
         if (DBG) {
             ALOGD("qhook using TLS");
         }
-        DnsTlsTransport xport(thread_netcontext.dns_mark, IPPROTO_TCP,
-                              secureResolver, fingerprints);
-        auto response = xport.doQuery(*buf, *buflen, ans, anssiz, resplen);
+        auto response = DnsTlsTransport::query(tlsServer, thread_netcontext.dns_mark,
+                *buf, *buflen, ans, anssiz, resplen);
         if (response == DnsTlsTransport::Response::success) {
             if (DBG) {
                 ALOGD("qhook success");
diff --git a/server/ResolverController.cpp b/server/ResolverController.cpp
index 5266166..238eef7 100644
--- a/server/ResolverController.cpp
+++ b/server/ResolverController.cpp
@@ -53,31 +53,25 @@
 
 namespace {
 
-struct PrivateDnsServer {
-    PrivateDnsServer(const sockaddr_storage& ss) : ss(ss) {}
-    const sockaddr_storage ss;
-    // For now, the fingerprints are always SHA-256.  This is the only digest algorithm
-    // that is mandatory to support (https://tools.ietf.org/html/rfc7858#section-4.2).
-    std::set<std::vector<uint8_t>> fingerprints;
-};
-
 // This comparison ignores ports and fingerprints.
-bool operator<(const PrivateDnsServer& x, const PrivateDnsServer& y) {
-    if (x.ss.ss_family != y.ss.ss_family) {
-        return x.ss.ss_family < y.ss.ss_family;
+struct AddressComparator {
+    bool operator() (const DnsTlsTransport::Server& x, const DnsTlsTransport::Server& y) const {
+      if (x.ss.ss_family != y.ss.ss_family) {
+          return x.ss.ss_family < y.ss.ss_family;
+      }
+      // Same address family.  Compare IP addresses.
+      if (x.ss.ss_family == AF_INET) {
+          const sockaddr_in& x_sin = reinterpret_cast<const sockaddr_in&>(x.ss);
+          const sockaddr_in& y_sin = reinterpret_cast<const sockaddr_in&>(y.ss);
+          return x_sin.sin_addr.s_addr < y_sin.sin_addr.s_addr;
+      } else if (x.ss.ss_family == AF_INET6) {
+          const sockaddr_in6& x_sin6 = reinterpret_cast<const sockaddr_in6&>(x.ss);
+          const sockaddr_in6& y_sin6 = reinterpret_cast<const sockaddr_in6&>(y.ss);
+          return std::memcmp(x_sin6.sin6_addr.s6_addr, y_sin6.sin6_addr.s6_addr, 16) < 0;
+      }
+      return false;  // Unknown address type.  This is an error.
     }
-    // Same address family.  Compare IP addresses.
-    if (x.ss.ss_family == AF_INET) {
-        const sockaddr_in& x_sin = reinterpret_cast<const sockaddr_in&>(x.ss);
-        const sockaddr_in& y_sin = reinterpret_cast<const sockaddr_in&>(y.ss);
-        return x_sin.sin_addr.s_addr < y_sin.sin_addr.s_addr;
-    } else if (x.ss.ss_family == AF_INET6) {
-        const sockaddr_in6& x_sin6 = reinterpret_cast<const sockaddr_in6&>(x.ss);
-        const sockaddr_in6& y_sin6 = reinterpret_cast<const sockaddr_in6&>(y.ss);
-        return std::memcmp(x_sin6.sin6_addr.s6_addr, y_sin6.sin6_addr.s6_addr, 16) < 0;
-    }
-    return false;  // Unknown address type.  This is an error.
-}
+};
 
 bool parseServer(const char* server, in_port_t port, sockaddr_storage* parsed) {
     sockaddr_in* sin = reinterpret_cast<sockaddr_in*>(parsed);
@@ -102,13 +96,13 @@
 
 // Structure for tracking the entire set of known Private DNS servers.
 std::mutex privateDnsLock;
-typedef std::set<PrivateDnsServer> PrivateDnsSet;
+using PrivateDnsSet = std::set<DnsTlsTransport::Server, AddressComparator>;  // Unique per IP.
 PrivateDnsSet privateDnsServers;
 
 // Structure for tracking the validation status of servers on a specific netid.
 // Servers that fail validation are removed from the tracker, and can be retried.
 enum class Validation : bool { in_process, success };
-typedef std::map<PrivateDnsServer, Validation> PrivateDnsTracker;
+using PrivateDnsTracker = std::map<DnsTlsTransport::Server, Validation, AddressComparator>;
 std::map<unsigned, PrivateDnsTracker> privateDnsTransports;
 
 PrivateDnsSet parseServers(const char** servers, int numservers, in_port_t port) {
@@ -139,7 +133,8 @@
     PrivateDnsSet intersection;
     std::set_intersection(privateDnsServers.begin(), privateDnsServers.end(),
         serversToCheck.begin(), serversToCheck.end(),
-        std::inserter(intersection, intersection.begin()));
+        std::inserter(intersection, intersection.begin()),
+        AddressComparator());
     if (intersection.empty()) {
         return;
     }
@@ -162,10 +157,9 @@
         }
         tracker[privateServer] = Validation::in_process;
         std::thread validate_thread([privateServer, netId] {
-            // validateDnsTlsServer() is a blocking call that performs network operations.
+            // DnsTlsTransport::validate() is a blocking call that performs network operations.
             // It can take milliseconds to minutes, up to the SYN retry limit.
-            bool success = validateDnsTlsServer(netId,
-                    privateServer.ss, privateServer.fingerprints);
+            bool success = DnsTlsTransport::validate(privateServer, netId);
             std::lock_guard<std::mutex> guard(privateDnsLock);
             auto netPair = privateDnsTransports.find(netId);
             if (netPair == privateDnsTransports.end()) {
@@ -210,7 +204,7 @@
 }
 
 bool ResolverController::shouldUseTls(unsigned netId, const sockaddr_storage& insecureServer,
-        sockaddr_storage* secureServer, std::set<std::vector<uint8_t>>* fingerprints) {
+        DnsTlsTransport::Server* secureServer) {
     // This mutex is on the critical path of every DNS lookup that doesn't hit a local cache.
     // If the overhead of mutex acquisition proves too high, we could reduce it by maintaining
     // an atomic_int32_t counter of validated connections, and returning early if it's zero.
@@ -225,8 +219,7 @@
         return false;
     }
     const auto& validatedServer = serverPair->first;
-    *secureServer = validatedServer.ss;
-    *fingerprints = validatedServer.fingerprints;
+    *secureServer = validatedServer;
     return true;
 }
 
@@ -457,13 +450,16 @@
     if (!parseServer(server.c_str(), port, &parsed)) {
         return INetd::PRIVATE_DNS_BAD_ADDRESS;
     }
-    PrivateDnsServer privateServer(parsed);
+    DnsTlsTransport::Server privateServer(parsed);
     privateServer.fingerprints = fingerprints;
     std::lock_guard<std::mutex> guard(privateDnsLock);
     // Ensure we overwrite any previous matching server.  This is necessary because equality is
     // based only on the IP address, not the port or fingerprints.
     privateDnsServers.erase(privateServer);
     privateDnsServers.insert(privateServer);
+    if (DBG) {
+        ALOGD("Recorded private DNS server: %s", server.c_str());
+    }
     return INetd::PRIVATE_DNS_SUCCESS;
 }
 
diff --git a/server/ResolverController.h b/server/ResolverController.h
index a6a559d..1475c5e 100644
--- a/server/ResolverController.h
+++ b/server/ResolverController.h
@@ -18,10 +18,10 @@
 #define _RESOLVER_CONTROLLER_H_
 
 #include <vector>
-#include <netinet/in.h>
-#include <linux/in.h>
+#include "dns/DnsTlsTransport.h"
 
 struct __res_params;
+struct sockaddr_storage;
 
 namespace android {
 namespace net {
@@ -42,11 +42,11 @@
     // Given a netId and the address of an insecure (i.e. normal) DNS server, this method checks
     // if there is a known secure DNS server with the same IP address that has been validated as
     // accessible on this netId.  If so, it returns true, providing the server's address
-    // (including port) and pin fingerprints (possibly empty) in the output parameters.
+    // (including port) and pin fingerprints (possibly empty) in the output parameter.
     // TODO: Add support for optional stronger security, by returning true even if the secure
     // server is not accessible.
     bool shouldUseTls(unsigned netId, const sockaddr_storage& insecureServer,
-            sockaddr_storage* secureServer, std::set<std::vector<uint8_t>>* fingerprints);
+            DnsTlsTransport::Server* secureServer);
 
     int clearDnsServers(unsigned netid);
 
diff --git a/server/dns/DnsTlsTransport.cpp b/server/dns/DnsTlsTransport.cpp
index 8d27d20..4988023 100644
--- a/server/dns/DnsTlsTransport.cpp
+++ b/server/dns/DnsTlsTransport.cpp
@@ -77,7 +77,7 @@
 android::base::unique_fd DnsTlsTransport::makeConnectedSocket() const {
     android::base::unique_fd fd;
     int type = SOCK_NONBLOCK | SOCK_CLOEXEC;
-    switch (mProtocol) {
+    switch (mServer.protocol) {
         case IPPROTO_TCP:
             type |= SOCK_STREAM;
             break;
@@ -86,7 +86,7 @@
             return fd;
     }
 
-    fd.reset(socket(mAddr.ss_family, type, mProtocol));
+    fd.reset(socket(mServer.ss.ss_family, type, mServer.protocol));
     if (fd.get() == -1) {
         return fd;
     }
@@ -95,7 +95,7 @@
     if (setsockopt(fd.get(), SOL_SOCKET, SO_MARK, &mMark, len) == -1) {
         fd.reset();
     } else if (connect(fd.get(),
-            reinterpret_cast<const struct sockaddr *>(&mAddr), sizeof(mAddr)) != 0
+            reinterpret_cast<const struct sockaddr *>(&mServer.ss), sizeof(mServer.ss)) != 0
         && errno != EINPROGRESS) {
         fd.reset();
     }
@@ -179,7 +179,7 @@
         }
     }
 
-    if (!mFingerprints.empty()) {
+    if (!mServer.fingerprints.empty()) {
         if (DBG) {
             ALOGD("Checking DNS over TLS fingerprint");
         }
@@ -195,7 +195,7 @@
             return nullptr;
         }
 
-        if (mFingerprints.count(digest) == 0) {
+        if (mServer.fingerprints.count(digest) == 0) {
             ALOGW("No matching fingerprint");
             return nullptr;
         }
@@ -278,6 +278,15 @@
     return true;
 }
 
+// static
+DnsTlsTransport::Response DnsTlsTransport::query(const Server& server, unsigned mark,
+        const uint8_t *query, size_t qlen, uint8_t *response, size_t limit, int *resplen) {
+    // Each call builds a temporary transport and runs a single query through it.
+    // TODO: Keep a static container of transports instead of constructing a new one
+    // for every query.
+    return DnsTlsTransport(server, mark).doQuery(query, qlen, response, limit, resplen);
+}
+
 DnsTlsTransport::Response DnsTlsTransport::doQuery(const uint8_t *query, size_t qlen,
         uint8_t *response, size_t limit, int *resplen) {
     *resplen = 0;  // Zero indicates an error.
@@ -346,8 +355,8 @@
     return Response::success;
 }
 
-bool validateDnsTlsServer(unsigned netid, const struct sockaddr_storage& ss,
-        const std::set<std::vector<uint8_t>>& fingerprints) {
+// static
+bool DnsTlsTransport::validate(const Server& server, unsigned netid) {
     if (DBG) {
         ALOGD("Beginning validation on %u", netid);
     }
@@ -392,12 +401,11 @@
     fwmark.protectedFromVpn = true;
     fwmark.netId = netid;
     unsigned mark = fwmark.intValue;
-    DnsTlsTransport xport(mark, IPPROTO_TCP, ss, fingerprints);
     int replylen = 0;
-    xport.doQuery(query, qlen, recvbuf, kRecvBufSize, &replylen);
+    DnsTlsTransport::query(server, mark, query, qlen, recvbuf, kRecvBufSize, &replylen);
     if (replylen == 0) {
         if (DBG) {
-            ALOGD("doQuery failed");
+            ALOGD("query failed");
         }
         return false;
     }
diff --git a/server/dns/DnsTlsTransport.h b/server/dns/DnsTlsTransport.h
index b9e9f7f..ddcaa1f 100644
--- a/server/dns/DnsTlsTransport.h
+++ b/server/dns/DnsTlsTransport.h
@@ -33,21 +33,38 @@
 
 class DnsTlsTransport {
 public:
-    DnsTlsTransport(unsigned mark, int protocol, const sockaddr_storage &ss,
-            const std::set<std::vector<uint8_t>>& fingerprints)
-            : mMark(mark), mProtocol(protocol), mAddr(ss), mFingerprints(fingerprints)
-            {}
-    ~DnsTlsTransport() {}
+    struct Server {
+        // Default constructor.  |ss| is zero-filled via its member initializer below.
+        Server() {}
+        // Allow sockaddr_storage to be promoted to Server automatically.
+        Server(const sockaddr_storage& ss) : ss(ss) {}
+        sockaddr_storage ss = {};
+        std::set<std::vector<uint8_t>> fingerprints;  // SHA-256 pins; empty means no pinning.
+        int protocol = IPPROTO_TCP;
+    };
 
     enum class Response : uint8_t { success, network_error, limit_error, internal_error };
 
-    // Given a |query| of length |qlen|, sends it to the server and writes the
-    // response into |ans|, which can accept up to |anssiz| bytes.  Indicates
+    // Given a |query| of length |qlen|, sends it to the server on the network indicated by |mark|,
+    // and writes the response into |ans|, which can accept up to |anssiz| bytes.  Indicates
     // the number of bytes written in |resplen|.  If |resplen| is zero, an
     // error has occurred.
-    Response doQuery(const uint8_t *query, size_t qlen, uint8_t *ans, size_t anssiz, int *resplen);
+    static Response query(const Server& server, unsigned mark, const uint8_t *query, size_t qlen,
+            uint8_t *ans, size_t anssiz, int *resplen);
+
+    // Check that a given TLS server is fully working on the specified netid, and has the
+    // provided SHA-256 fingerprint (if nonempty).  This function is used in ResolverController
+    // to ensure that we don't enable DNS over TLS on networks where it doesn't actually work.
+    static bool validate(const Server& server, unsigned netid);
 
 private:
+    DnsTlsTransport(const Server& server, unsigned mark)
+            : mMark(mark), mServer(server)
+            {}
+    ~DnsTlsTransport() {}
+
+    Response doQuery(const uint8_t *query, size_t qlen, uint8_t *ans, size_t anssiz, int *resplen);
+
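-    // On success, returns a non-blocking socket connected to mAddr (the
-    // connection will likely be in progress if mProtocol is IPPROTO_TCP).
+    // On success, returns a non-blocking socket connected to mServer.ss (the
+    // connection will likely be in progress if mServer.protocol is IPPROTO_TCP).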
     // On error, returns -1 with errno set appropriately.
@@ -63,17 +80,9 @@
     bool sslRead(int fd, SSL *ssl, uint8_t *buffer, int len);
 
     const unsigned mMark;  // Socket mark
-    const int mProtocol;
-    const sockaddr_storage mAddr;
-    const std::set<std::vector<uint8_t>> mFingerprints;
+    const Server mServer;
 };
 
-// Check that a given TLS server (ss) is fully working on the specified netid, and has a
-// provided SHA-256 fingerprint (if nonempty).  This function is used in ResolverController
-// to ensure that we don't enable DNS over TLS on networks where it doesn't actually work.
-bool validateDnsTlsServer(unsigned netid, const sockaddr_storage& ss,
-        const std::set<std::vector<uint8_t>>& fingerprints);
-
 }  // namespace net
 }  // namespace android