Merge "Make DnsTlsTransport's query method static"
diff --git a/server/DnsProxyListener.cpp b/server/DnsProxyListener.cpp
index 8db1da8..a4ed07a 100644
--- a/server/DnsProxyListener.cpp
+++ b/server/DnsProxyListener.cpp
@@ -100,16 +100,14 @@
ALOGE("qhook abort: unknown address family");
return res_goahead;
}
- sockaddr_storage secureResolver;
- std::set<std::vector<uint8_t>> fingerprints;
+ DnsTlsTransport::Server tlsServer;
if (net::gCtls->resolverCtrl.shouldUseTls(thread_netcontext.dns_netid,
- insecureResolver, &secureResolver, &fingerprints)) {
+ insecureResolver, &tlsServer)) {
if (DBG) {
ALOGD("qhook using TLS");
}
- DnsTlsTransport xport(thread_netcontext.dns_mark, IPPROTO_TCP,
- secureResolver, fingerprints);
- auto response = xport.doQuery(*buf, *buflen, ans, anssiz, resplen);
+ auto response = DnsTlsTransport::query(tlsServer, thread_netcontext.dns_mark,
+ *buf, *buflen, ans, anssiz, resplen);
if (response == DnsTlsTransport::Response::success) {
if (DBG) {
ALOGD("qhook success");
diff --git a/server/ResolverController.cpp b/server/ResolverController.cpp
index 5266166..238eef7 100644
--- a/server/ResolverController.cpp
+++ b/server/ResolverController.cpp
@@ -53,31 +53,25 @@
namespace {
-struct PrivateDnsServer {
- PrivateDnsServer(const sockaddr_storage& ss) : ss(ss) {}
- const sockaddr_storage ss;
- // For now, the fingerprints are always SHA-256. This is the only digest algorithm
- // that is mandatory to support (https://tools.ietf.org/html/rfc7858#section-4.2).
- std::set<std::vector<uint8_t>> fingerprints;
-};
-
// This comparison ignores ports and fingerprints.
-bool operator<(const PrivateDnsServer& x, const PrivateDnsServer& y) {
- if (x.ss.ss_family != y.ss.ss_family) {
- return x.ss.ss_family < y.ss.ss_family;
+struct AddressComparator {
+ bool operator() (const DnsTlsTransport::Server& x, const DnsTlsTransport::Server& y) const {
+ if (x.ss.ss_family != y.ss.ss_family) {
+ return x.ss.ss_family < y.ss.ss_family;
+ }
+ // Same address family. Compare IP addresses.
+ if (x.ss.ss_family == AF_INET) {
+ const sockaddr_in& x_sin = reinterpret_cast<const sockaddr_in&>(x.ss);
+ const sockaddr_in& y_sin = reinterpret_cast<const sockaddr_in&>(y.ss);
+ return x_sin.sin_addr.s_addr < y_sin.sin_addr.s_addr;
+ } else if (x.ss.ss_family == AF_INET6) {
+ const sockaddr_in6& x_sin6 = reinterpret_cast<const sockaddr_in6&>(x.ss);
+ const sockaddr_in6& y_sin6 = reinterpret_cast<const sockaddr_in6&>(y.ss);
+ return std::memcmp(x_sin6.sin6_addr.s6_addr, y_sin6.sin6_addr.s6_addr, 16) < 0;
+ }
+ return false; // Unknown address type. This is an error.
}
- // Same address family. Compare IP addresses.
- if (x.ss.ss_family == AF_INET) {
- const sockaddr_in& x_sin = reinterpret_cast<const sockaddr_in&>(x.ss);
- const sockaddr_in& y_sin = reinterpret_cast<const sockaddr_in&>(y.ss);
- return x_sin.sin_addr.s_addr < y_sin.sin_addr.s_addr;
- } else if (x.ss.ss_family == AF_INET6) {
- const sockaddr_in6& x_sin6 = reinterpret_cast<const sockaddr_in6&>(x.ss);
- const sockaddr_in6& y_sin6 = reinterpret_cast<const sockaddr_in6&>(y.ss);
- return std::memcmp(x_sin6.sin6_addr.s6_addr, y_sin6.sin6_addr.s6_addr, 16) < 0;
- }
- return false; // Unknown address type. This is an error.
-}
+};
bool parseServer(const char* server, in_port_t port, sockaddr_storage* parsed) {
sockaddr_in* sin = reinterpret_cast<sockaddr_in*>(parsed);
@@ -102,13 +96,13 @@
// Structure for tracking the entire set of known Private DNS servers.
std::mutex privateDnsLock;
-typedef std::set<PrivateDnsServer> PrivateDnsSet;
+typedef std::set<DnsTlsTransport::Server, AddressComparator> PrivateDnsSet;
PrivateDnsSet privateDnsServers;
// Structure for tracking the validation status of servers on a specific netid.
// Servers that fail validation are removed from the tracker, and can be retried.
enum class Validation : bool { in_process, success };
-typedef std::map<PrivateDnsServer, Validation> PrivateDnsTracker;
+typedef std::map<DnsTlsTransport::Server, Validation, AddressComparator> PrivateDnsTracker;
std::map<unsigned, PrivateDnsTracker> privateDnsTransports;
PrivateDnsSet parseServers(const char** servers, int numservers, in_port_t port) {
@@ -139,7 +133,8 @@
PrivateDnsSet intersection;
std::set_intersection(privateDnsServers.begin(), privateDnsServers.end(),
serversToCheck.begin(), serversToCheck.end(),
- std::inserter(intersection, intersection.begin()));
+ std::inserter(intersection, intersection.begin()),
+ AddressComparator());
if (intersection.empty()) {
return;
}
@@ -162,10 +157,9 @@
}
tracker[privateServer] = Validation::in_process;
std::thread validate_thread([privateServer, netId] {
- // validateDnsTlsServer() is a blocking call that performs network operations.
+ // DnsTlsTransport::validate() is a blocking call that performs network operations.
// It can take milliseconds to minutes, up to the SYN retry limit.
- bool success = validateDnsTlsServer(netId,
- privateServer.ss, privateServer.fingerprints);
+ bool success = DnsTlsTransport::validate(privateServer, netId);
std::lock_guard<std::mutex> guard(privateDnsLock);
auto netPair = privateDnsTransports.find(netId);
if (netPair == privateDnsTransports.end()) {
@@ -210,7 +204,7 @@
}
bool ResolverController::shouldUseTls(unsigned netId, const sockaddr_storage& insecureServer,
- sockaddr_storage* secureServer, std::set<std::vector<uint8_t>>* fingerprints) {
+ DnsTlsTransport::Server* secureServer) {
// This mutex is on the critical path of every DNS lookup that doesn't hit a local cache.
// If the overhead of mutex acquisition proves too high, we could reduce it by maintaining
// an atomic_int32_t counter of validated connections, and returning early if it's zero.
@@ -225,8 +219,7 @@
return false;
}
const auto& validatedServer = serverPair->first;
- *secureServer = validatedServer.ss;
- *fingerprints = validatedServer.fingerprints;
+ *secureServer = validatedServer;
return true;
}
@@ -457,13 +450,16 @@
if (!parseServer(server.c_str(), port, &parsed)) {
return INetd::PRIVATE_DNS_BAD_ADDRESS;
}
- PrivateDnsServer privateServer(parsed);
+ DnsTlsTransport::Server privateServer(parsed);
privateServer.fingerprints = fingerprints;
std::lock_guard<std::mutex> guard(privateDnsLock);
// Ensure we overwrite any previous matching server. This is necessary because equality is
// based only on the IP address, not the port or fingerprints.
privateDnsServers.erase(privateServer);
privateDnsServers.insert(privateServer);
+ if (DBG) {
+ ALOGD("Recorded private DNS server: %s", server.c_str());
+ }
return INetd::PRIVATE_DNS_SUCCESS;
}
diff --git a/server/ResolverController.h b/server/ResolverController.h
index a6a559d..1475c5e 100644
--- a/server/ResolverController.h
+++ b/server/ResolverController.h
@@ -18,10 +18,10 @@
#define _RESOLVER_CONTROLLER_H_
#include <vector>
-#include <netinet/in.h>
-#include <linux/in.h>
+#include "dns/DnsTlsTransport.h"
struct __res_params;
+struct sockaddr_storage;
namespace android {
namespace net {
@@ -42,11 +42,11 @@
// Given a netId and the address of an insecure (i.e. normal) DNS server, this method checks
// if there is a known secure DNS server with the same IP address that has been validated as
// accessible on this netId. If so, it returns true, providing the server's address
- // (including port) and pin fingerprints (possibly empty) in the output parameters.
+ // (including port) and pin fingerprints (possibly empty) in the output parameter.
// TODO: Add support for optional stronger security, by returning true even if the secure
// server is not accessible.
bool shouldUseTls(unsigned netId, const sockaddr_storage& insecureServer,
- sockaddr_storage* secureServer, std::set<std::vector<uint8_t>>* fingerprints);
+ DnsTlsTransport::Server* secureServer);
int clearDnsServers(unsigned netid);
diff --git a/server/dns/DnsTlsTransport.cpp b/server/dns/DnsTlsTransport.cpp
index 8d27d20..4988023 100644
--- a/server/dns/DnsTlsTransport.cpp
+++ b/server/dns/DnsTlsTransport.cpp
@@ -77,7 +77,7 @@
android::base::unique_fd DnsTlsTransport::makeConnectedSocket() const {
android::base::unique_fd fd;
int type = SOCK_NONBLOCK | SOCK_CLOEXEC;
- switch (mProtocol) {
+ switch (mServer.protocol) {
case IPPROTO_TCP:
type |= SOCK_STREAM;
break;
@@ -86,7 +86,7 @@
return fd;
}
- fd.reset(socket(mAddr.ss_family, type, mProtocol));
+ fd.reset(socket(mServer.ss.ss_family, type, mServer.protocol));
if (fd.get() == -1) {
return fd;
}
@@ -95,7 +95,7 @@
if (setsockopt(fd.get(), SOL_SOCKET, SO_MARK, &mMark, len) == -1) {
fd.reset();
} else if (connect(fd.get(),
- reinterpret_cast<const struct sockaddr *>(&mAddr), sizeof(mAddr)) != 0
+ reinterpret_cast<const struct sockaddr *>(&mServer.ss), sizeof(mServer.ss)) != 0
&& errno != EINPROGRESS) {
fd.reset();
}
@@ -179,7 +179,7 @@
}
}
- if (!mFingerprints.empty()) {
+ if (!mServer.fingerprints.empty()) {
if (DBG) {
ALOGD("Checking DNS over TLS fingerprint");
}
@@ -195,7 +195,7 @@
return nullptr;
}
- if (mFingerprints.count(digest) == 0) {
+ if (mServer.fingerprints.count(digest) == 0) {
ALOGW("No matching fingerprint");
return nullptr;
}
@@ -278,6 +278,15 @@
return true;
}
+// static
+DnsTlsTransport::Response DnsTlsTransport::query(const Server& server, unsigned mark,
+ const uint8_t *query, size_t qlen, uint8_t *response, size_t limit, int *resplen) {
+ // TODO: Keep a static container of transports instead of constructing a new one
+ // for every query.
+ DnsTlsTransport xport(server, mark);
+ return xport.doQuery(query, qlen, response, limit, resplen);
+}
+
DnsTlsTransport::Response DnsTlsTransport::doQuery(const uint8_t *query, size_t qlen,
uint8_t *response, size_t limit, int *resplen) {
*resplen = 0; // Zero indicates an error.
@@ -346,8 +355,8 @@
return Response::success;
}
-bool validateDnsTlsServer(unsigned netid, const struct sockaddr_storage& ss,
- const std::set<std::vector<uint8_t>>& fingerprints) {
+// static
+bool DnsTlsTransport::validate(const Server& server, unsigned netid) {
if (DBG) {
ALOGD("Beginning validation on %u", netid);
}
@@ -392,12 +401,11 @@
fwmark.protectedFromVpn = true;
fwmark.netId = netid;
unsigned mark = fwmark.intValue;
- DnsTlsTransport xport(mark, IPPROTO_TCP, ss, fingerprints);
int replylen = 0;
- xport.doQuery(query, qlen, recvbuf, kRecvBufSize, &replylen);
+ DnsTlsTransport::query(server, mark, query, qlen, recvbuf, kRecvBufSize, &replylen);
if (replylen == 0) {
if (DBG) {
- ALOGD("doQuery failed");
+ ALOGD("query failed");
}
return false;
}
diff --git a/server/dns/DnsTlsTransport.h b/server/dns/DnsTlsTransport.h
index b9e9f7f..ddcaa1f 100644
--- a/server/dns/DnsTlsTransport.h
+++ b/server/dns/DnsTlsTransport.h
@@ -33,21 +33,38 @@
class DnsTlsTransport {
public:
- DnsTlsTransport(unsigned mark, int protocol, const sockaddr_storage &ss,
- const std::set<std::vector<uint8_t>>& fingerprints)
- : mMark(mark), mProtocol(protocol), mAddr(ss), mFingerprints(fingerprints)
- {}
- ~DnsTlsTransport() {}
+ struct Server {
+ // Default constructor
+ Server() {}
+ // Allow sockaddr_storage to be promoted to Server automatically.
+ Server(const sockaddr_storage& ss) : ss(ss) {}
+ sockaddr_storage ss;
+ std::set<std::vector<uint8_t>> fingerprints;
+ int protocol = IPPROTO_TCP;
+ };
enum class Response : uint8_t { success, network_error, limit_error, internal_error };
- // Given a |query| of length |qlen|, sends it to the server and writes the
- // response into |ans|, which can accept up to |anssiz| bytes. Indicates
+ // Given a |query| of length |qlen|, sends it to the server on the network indicated by |mark|,
+ // and writes the response into |ans|, which can accept up to |anssiz| bytes. Indicates
// the number of bytes written in |resplen|. If |resplen| is zero, an
// error has occurred.
- Response doQuery(const uint8_t *query, size_t qlen, uint8_t *ans, size_t anssiz, int *resplen);
+ static Response query(const Server& server, unsigned mark, const uint8_t *query, size_t qlen,
+ uint8_t *ans, size_t anssiz, int *resplen);
+
+ // Check that |server| is fully working on the specified netid, and matches one of its
+ // SHA-256 fingerprints (if any are set). This function is used in ResolverController
+ // to ensure that we don't enable DNS over TLS on networks where it doesn't actually work.
+ static bool validate(const Server& server, unsigned netid);
private:
+ DnsTlsTransport(const Server& server, unsigned mark)
+ : mMark(mark), mServer(server)
+ {}
+ ~DnsTlsTransport() {}
+
+ Response doQuery(const uint8_t *query, size_t qlen, uint8_t *ans, size_t anssiz, int *resplen);
+
// On success, returns a non-blocking socket connected to mAddr (the
// connection will likely be in progress if mProtocol is IPPROTO_TCP).
// On error, returns -1 with errno set appropriately.
@@ -63,17 +80,9 @@
bool sslRead(int fd, SSL *ssl, uint8_t *buffer, int len);
const unsigned mMark; // Socket mark
- const int mProtocol;
- const sockaddr_storage mAddr;
- const std::set<std::vector<uint8_t>> mFingerprints;
+ const Server mServer;
};
-// Check that a given TLS server (ss) is fully working on the specified netid, and has a
-// provided SHA-256 fingerprint (if nonempty). This function is used in ResolverController
-// to ensure that we don't enable DNS over TLS on networks where it doesn't actually work.
-bool validateDnsTlsServer(unsigned netid, const sockaddr_storage& ss,
- const std::set<std::vector<uint8_t>>& fingerprints);
-
} // namespace net
} // namespace android