Merge "Tcp socket metrics: define INetdEventListener callback"
diff --git a/server/Android.mk b/server/Android.mk
index 8d495e1..f44d92b 100644
--- a/server/Android.mk
+++ b/server/Android.mk
@@ -58,7 +58,7 @@
external/mdnsresponder/mDNSShared \
system/netd/include \
-LOCAL_CPPFLAGS := -Wall -Werror -Wthread-safety
+LOCAL_CPPFLAGS := -Wall -Werror -Wthread-safety -Wnullable-to-nonnull-conversion
LOCAL_SANITIZE := unsigned-integer-overflow
LOCAL_MODULE := netd
@@ -143,6 +143,8 @@
dns/DnsTlsDispatcher.cpp \
dns/DnsTlsTransport.cpp \
dns/DnsTlsServer.cpp \
+ dns/DnsTlsSessionCache.cpp \
+ dns/DnsTlsSocket.cpp \
LOCAL_AIDL_INCLUDES := $(LOCAL_PATH)/binder
@@ -205,6 +207,11 @@
binder/android/net/UidRange.cpp \
binder/android/net/metrics/INetdEventListener.aidl \
../tests/tun_interface.cpp \
+ dns/DnsTlsDispatcher.cpp \
+ dns/DnsTlsTransport.cpp \
+ dns/DnsTlsServer.cpp \
+ dns/DnsTlsSessionCache.cpp \
+ dns/DnsTlsSocket.cpp \
LOCAL_MODULE_TAGS := tests
LOCAL_STATIC_LIBRARIES := libgmock libpcap
diff --git a/server/DnsProxyListener.cpp b/server/DnsProxyListener.cpp
index 7339d93..e3c5cf9 100644
--- a/server/DnsProxyListener.cpp
+++ b/server/DnsProxyListener.cpp
@@ -37,6 +37,7 @@
#include <vector>
#include <cutils/log.h>
+#include <netdutils/Slice.h>
#include <utils/String16.h>
#include <sysutils/SocketClient.h>
@@ -82,6 +83,8 @@
thread_local android_net_context thread_netcontext = {};
+// Dispatcher used by qhook() to send queries over DNS-over-TLS. It caches
+// DnsTlsTransports between queries, so one global instance is kept for the
+// lifetime of the process and shared by every qhook() call.
+DnsTlsDispatcher dnsTlsDispatcher;
+
res_sendhookact qhook(sockaddr* const * nsap, const u_char** buf, int* buflen,
u_char* ans, int anssiz, int* resplen) {
if (!thread_netcontext.qhook) {
@@ -128,8 +131,10 @@
if (DBG) {
ALOGD("Performing query over TLS");
}
- auto response = DnsTlsDispatcher::query(tlsServer, thread_netcontext.dns_mark,
- *buf, *buflen, ans, anssiz, resplen);
+ Slice query = netdutils::Slice(const_cast<u_char*>(*buf), *buflen);
+ Slice answer = netdutils::Slice(const_cast<u_char*>(ans), anssiz);
+ auto response = dnsTlsDispatcher.query(tlsServer, thread_netcontext.dns_mark,
+ query, answer, resplen);
if (response == DnsTlsTransport::Response::success) {
if (DBG) {
ALOGD("qhook success");
diff --git a/server/dns/DnsTlsDispatcher.cpp b/server/dns/DnsTlsDispatcher.cpp
index 6fed9f0..be9c669 100644
--- a/server/dns/DnsTlsDispatcher.cpp
+++ b/server/dns/DnsTlsDispatcher.cpp
@@ -14,34 +14,53 @@
* limitations under the License.
*/
+#define LOG_TAG "DnsTlsDispatcher"
+//#define LOG_NDEBUG 0
+
#include "dns/DnsTlsDispatcher.h"
+#include "log/log.h"
+
namespace android {
namespace net {
+using netdutils::Slice;
+
// static
+// Shared by all DnsTlsDispatcher instances. Guards the transport cache (mStore,
+// mLastCleanup) and each Transport's useCount and lastUsed.
std::mutex DnsTlsDispatcher::sLock;
-std::map<DnsTlsDispatcher::Key, std::unique_ptr<DnsTlsDispatcher::Transport>> DnsTlsDispatcher::sStore;
DnsTlsTransport::Response DnsTlsDispatcher::query(const DnsTlsServer& server, unsigned mark,
- const uint8_t *query, size_t qlen, uint8_t *response, size_t limit, int *resplen) {
+ const Slice query,
+ const Slice ans, int *resplen) {
const Key key = std::make_pair(mark, server);
Transport* xport;
{
std::lock_guard<std::mutex> guard(sLock);
- auto it = sStore.find(key);
- if (it == sStore.end()) {
- xport = new Transport(server, mark);
- if (!xport->transport.initialize()) {
- return DnsTlsTransport::Response::internal_error;
- }
- sStore[key].reset(xport);
+ auto it = mStore.find(key);
+ if (it == mStore.end()) {
+ xport = new Transport(server, mark, mFactory.get());
+ mStore[key].reset(xport);
} else {
xport = it->second.get();
}
++xport->useCount;
}
- DnsTlsTransport::Response res = xport->transport.query(query, qlen, response, limit, resplen);
+ ALOGV("Sending query of length %zu", query.size());
+ auto result = xport->transport.query(query);
+ DnsTlsTransport::Response code = result.code;
+ if (code == DnsTlsTransport::Response::success) {
+ if (result.response.size() > ans.size()) {
+ ALOGV("Response too large: %zu > %zu", result.response.size(), ans.size());
+ code = DnsTlsTransport::Response::limit_error;
+ } else {
+ ALOGV("Got response successfully");
+ *resplen = result.response.size();
+ netdutils::copy(ans, netdutils::makeSlice(result.response));
+ }
+ } else {
+ ALOGV("Query failed: %u", (unsigned int)code);
+ }
+
auto now = std::chrono::steady_clock::now();
{
std::lock_guard<std::mutex> guard(sLock);
@@ -49,25 +68,27 @@
xport->lastUsed = now;
cleanup(now);
}
- return res;
+ return code;
}
+// This timeout effectively controls how long to keep SSL session tickets.
static constexpr std::chrono::minutes IDLE_TIMEOUT(5);
-std::chrono::time_point<std::chrono::steady_clock> DnsTlsDispatcher::sLastCleanup;
+// Erases every cached Transport that has no query in flight (useCount == 0) and has
+// been unused for longer than IDLE_TIMEOUT. The scan of mStore runs at most once per
+// IDLE_TIMEOUT. The caller must hold sLock.
void DnsTlsDispatcher::cleanup(std::chrono::time_point<std::chrono::steady_clock> now) {
- if (now - sLastCleanup < IDLE_TIMEOUT) {
+ // To avoid scanning mStore after every query, return early if a cleanup has been
+ // performed recently.
+ if (now - mLastCleanup < IDLE_TIMEOUT) {
return;
}
- for (auto it = sStore.begin(); it != sStore.end(); ) {
+ for (auto it = mStore.begin(); it != mStore.end();) {
auto& s = it->second;
+ // Transports with queries in flight are kept regardless of lastUsed, which is only
+ // meaningful once useCount has dropped to zero.
if (s->useCount == 0 && now - s->lastUsed > IDLE_TIMEOUT) {
- it = sStore.erase(it);
+ it = mStore.erase(it);
} else {
++it;
}
}
- sLastCleanup = now;
+ mLastCleanup = now;
}
-} // namespace net
-} // namespace android
+} // end of namespace net
+} // end of namespace android
diff --git a/server/dns/DnsTlsDispatcher.h b/server/dns/DnsTlsDispatcher.h
index c32bde2..9487b51 100644
--- a/server/dns/DnsTlsDispatcher.h
+++ b/server/dns/DnsTlsDispatcher.h
@@ -23,54 +23,79 @@
#include <android-base/thread_annotations.h>
+#include <netdutils/Slice.h>
+
#include "dns/DnsTlsServer.h"
+#include "dns/DnsTlsSocket.h"
+#include "dns/DnsTlsSocketFactory.h"
+#include "dns/IDnsTlsSocketFactory.h"
#include "dns/DnsTlsTransport.h"
namespace android {
namespace net {
-// This is a totally static class that manages the collection of active DnsTlsTransports.
+using netdutils::Slice;
+
+// This is a singleton class that manages the collection of active DnsTlsTransports.
// Queries made here are dispatched to an existing or newly constructed DnsTlsTransport.
class DnsTlsDispatcher {
public:
- // Given a |query| of length |qlen|, sends it to the server on the network indicated by |mark|,
- // and writes the response into |ans|, which can accept up to |anssiz| bytes. Indicates
- // the number of bytes written in |resplen|. If |resplen| is zero, an
- // error has occurred.
- static DnsTlsTransport::Response query(const DnsTlsServer& server, unsigned mark,
- const uint8_t *query, size_t qlen, uint8_t *ans, size_t anssiz, int *resplen);
+    // Default constructor: sockets are created by a DnsTlsSocketFactory.
+    DnsTlsDispatcher() : mFactory(std::make_unique<DnsTlsSocketFactory>()) {}
+    // Constructor with dependency injection for testing.
+    DnsTlsDispatcher(std::unique_ptr<IDnsTlsSocketFactory> factory) :
+            mFactory(std::move(factory)) {}
+
+    // Given a |query|, sends it to |server| on the network indicated by |mark|.
+    // On success, copies the response into |ans| and stores its length in |resplen|.
+    // If the response does not fit in |ans|, returns limit_error. For any code other
+    // than success, neither |ans| nor |resplen| is modified.
+    DnsTlsTransport::Response query(const DnsTlsServer& server, unsigned mark,
+                                    const Slice query, const Slice ans, int * _Nonnull resplen);
private:
+    // This lock is static so that it can be used to annotate the Transport struct.
+    // DnsTlsDispatcher is a singleton in practice, so making this static does not change
+    // the locking behavior.
static std::mutex sLock;
+    // Key = <mark, server>
typedef std::pair<unsigned, const DnsTlsServer> Key;
// Transport is a thin wrapper around DnsTlsTransport, adding reference counting and
- // idle monitoring so we can expire unused sessions from the cache.
+    // usage monitoring so we can expire idle sessions from the cache.
struct Transport {
- Transport(const DnsTlsServer& server, unsigned mark) : transport(server, mark) {}
- // DnsTlsSession is thread-safe (internally locked), so it doesn't need to be guarded.
+        Transport(const DnsTlsServer& server, unsigned mark,
+                  IDnsTlsSocketFactory* _Nonnull factory) :
+                transport(server, mark, factory) {}
+        // DnsTlsTransport is thread-safe, so it doesn't need to be guarded.
DnsTlsTransport transport;
// This use counter and timestamp are used to ensure that only idle sessions are
// destroyed.
int useCount GUARDED_BY(sLock) = 0;
+        // lastUsed is only guaranteed to be meaningful after useCount is decremented to zero.
std::chrono::time_point<std::chrono::steady_clock> lastUsed GUARDED_BY(sLock);
};
// Cache of reusable DnsTlsTransports. Transports stay in cache as long as
// they are in use and for a few minutes after.
- // The key is a (netid, server) pair. The netid is first for lexicographic comparison speed.
+    // The key is a (mark, server) pair. The mark is first for lexicographic comparison speed.
- static std::map<Key, std::unique_ptr<Transport>> sStore GUARDED_BY(sLock);
+    std::map<Key, std::unique_ptr<Transport>> mStore GUARDED_BY(sLock);
// The last time we did a cleanup. For efficiency, we only perform a cleanup once every
// few minutes.
- static std::chrono::time_point<std::chrono::steady_clock> sLastCleanup GUARDED_BY(sLock);
+    std::chrono::time_point<std::chrono::steady_clock> mLastCleanup GUARDED_BY(sLock);
// Drop any cache entries whose useCount is zero and which have not been used recently.
- static void cleanup(std::chrono::time_point<std::chrono::steady_clock> now) REQUIRES(sLock);
+    // This function performs a linear scan of mStore.
+    void cleanup(std::chrono::time_point<std::chrono::steady_clock> now) REQUIRES(sLock);
+
+    // Trivial factory for DnsTlsSockets. Dependency injection is only used for testing.
+    std::unique_ptr<IDnsTlsSocketFactory> mFactory;
};
-} // namespace net
-} // namespace android
+} // end of namespace net
+} // end of namespace android
#endif // _DNS_DNSTLSDISPATCHER_H
diff --git a/server/dns/DnsTlsServer.cpp b/server/dns/DnsTlsServer.cpp
index 38dc6c5..9ac9893 100644
--- a/server/dns/DnsTlsServer.cpp
+++ b/server/dns/DnsTlsServer.cpp
@@ -89,7 +89,6 @@
namespace net {
// This comparison ignores ports and fingerprints.
-// TODO: respect IPv6 scope id (e.g. link-local addresses).
bool AddressComparator::operator() (const DnsTlsServer& x, const DnsTlsServer& y) const {
if (x.ss.ss_family != y.ss.ss_family) {
return x.ss.ss_family < y.ss.ss_family;
@@ -102,7 +101,8 @@
} else if (x.ss.ss_family == AF_INET6) {
const sockaddr_in6& x_sin6 = reinterpret_cast<const sockaddr_in6&>(x.ss);
const sockaddr_in6& y_sin6 = reinterpret_cast<const sockaddr_in6&>(y.ss);
- return x_sin6.sin6_addr < y_sin6.sin6_addr;
+ return std::tie(x_sin6.sin6_addr, x_sin6.sin6_scope_id) <
+ std::tie(y_sin6.sin6_addr, y_sin6.sin6_scope_id);
}
return false; // Unknown address type. This is an error.
}
diff --git a/server/dns/DnsTlsServer.h b/server/dns/DnsTlsServer.h
index f11c2e3..c9cbd46 100644
--- a/server/dns/DnsTlsServer.h
+++ b/server/dns/DnsTlsServer.h
@@ -35,8 +35,15 @@
// Allow sockaddr_storage to be promoted to DnsTlsServer automatically.
DnsTlsServer(const sockaddr_storage& ss) : ss(ss) {}
+ enum class Response : uint8_t { success, network_error, limit_error, internal_error };
+
+ struct Result {
+ Response code;
+ std::vector<uint8_t> response;
+ };
+
// The server location, including IP and port.
- sockaddr_storage ss;
+ sockaddr_storage ss = {};
// A set of SHA256 public key fingerprints. If this set is nonempty, the server
// must present a self-consistent certificate chain that contains a certificate
diff --git a/server/dns/DnsTlsSessionCache.cpp b/server/dns/DnsTlsSessionCache.cpp
new file mode 100644
index 0000000..880b773
--- /dev/null
+++ b/server/dns/DnsTlsSessionCache.cpp
@@ -0,0 +1,77 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ #include "DnsTlsSessionCache.h"
+
+#define LOG_TAG "DnsTlsSessionCache"
+//#define LOG_NDEBUG 0
+
+#include "log/log.h"
+
+namespace android {
+namespace net {
+
+bool DnsTlsSessionCache::prepareSsl(SSL* ssl) {
+    // Attach this cache to |ssl| as ex_data slot 0 so that newSessionCallback()
+    // can find it again later. SSL_set_ex_data() reports success by returning 1.
+    return SSL_set_ex_data(ssl, 0, this) == 1;
+}
+
+void DnsTlsSessionCache::prepareSslContext(SSL_CTX* ssl_ctx) {
+    // Keep client-side sessions and route every newly established session to
+    // newSessionCallback(), which stores it in the cache attached to the SSL.
+    SSL_CTX_set_session_cache_mode(ssl_ctx, SSL_SESS_CACHE_CLIENT);
+    SSL_CTX_sess_set_new_cb(ssl_ctx, DnsTlsSessionCache::newSessionCallback);
+}
+
+// static
+// Invoked by libssl for each new session. Returning 1 means this callback keeps a
+// reference to |session|; returning 0 leaves ownership with libssl.
+int DnsTlsSessionCache::newSessionCallback(SSL* ssl, SSL_SESSION* session) {
+    if (ssl == nullptr || session == nullptr) {
+        ALOGE("Null SSL object in new session callback");
+        return 0;
+    }
+    auto* const owner = static_cast<DnsTlsSessionCache*>(SSL_get_ex_data(ssl, 0));
+    if (owner == nullptr) {
+        ALOGE("null transport in new session callback");
+        return 0;
+    }
+    ALOGV("Recording session");
+    owner->recordSession(session);
+    return 1;
+}
+
+void DnsTlsSessionCache::recordSession(SSL_SESSION* session) {
+    std::lock_guard<std::mutex> guard(mLock);
+    // Newest sessions live at the front; once over capacity, evict from the back.
+    mSessions.emplace_front(session);
+    if (mSessions.size() <= kMaxSize) {
+        return;
+    }
+    ALOGV("Too many sessions; trimming");
+    mSessions.pop_back();
+}
+
+bssl::UniquePtr<SSL_SESSION> DnsTlsSessionCache::getSession() {
+    std::lock_guard<std::mutex> guard(mLock);
+    if (mSessions.empty()) {
+        ALOGV("No known sessions");
+        return nullptr;
+    }
+    // Hand out the newest session and drop it from the cache, so each session is
+    // returned at most once.
+    auto newest = std::move(mSessions.front());
+    mSessions.pop_front();
+    return newest;
+}
+
+} // end of namespace net
+} // end of namespace android
diff --git a/server/dns/DnsTlsSessionCache.h b/server/dns/DnsTlsSessionCache.h
new file mode 100644
index 0000000..32b8b55
--- /dev/null
+++ b/server/dns/DnsTlsSessionCache.h
@@ -0,0 +1,63 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _DNS_DNSTLSSESSIONCACHE_H
+#define _DNS_DNSTLSSESSIONCACHE_H
+
+#include <mutex>
+#include <deque>
+
+#include <openssl/ssl.h>
+
+#include <android-base/thread_annotations.h>
+#include <android-base/unique_fd.h>
+
+#include "dns/DnsTlsServer.h"
+
+namespace android {
+namespace net {
+
+// Cache of recently seen SSL_SESSIONs. This is used to support session tickets.
+// At most kMaxSize sessions are retained, and each one is handed out at most once.
+// This class is thread-safe: the session queue is guarded by mLock.
+class DnsTlsSessionCache {
+public:
+ // Prepare SSL objects to use this session cache. These methods must be called
+ // before making use of either object.
+ // prepareSsl() stores a raw pointer to this cache in |ssl| (ex_data index 0), so
+ // |ssl| must not outlive this cache. It returns false if storing the pointer fails.
+ void prepareSslContext(SSL_CTX* _Nonnull ssl_ctx);
+ bool prepareSsl(SSL* _Nonnull ssl);
+
+ // Get the most recently discovered session. For TLS 1.3 compatibility and
+ // maximum privacy, each session will only be returned once, so the caller
+ // gains ownership of the session. (Here and throughout,
+ // bssl::UniquePtr<SSL_SESSION> is actually serving as a reference counted
+ // pointer.) Returns nullptr if the cache is empty.
+ bssl::UniquePtr<SSL_SESSION> getSession() EXCLUDES(mLock);
+
+private:
+ // Maximum number of retained sessions; the oldest is evicted first.
+ static constexpr size_t kMaxSize = 5;
+ // Registered via SSL_CTX_sess_set_new_cb(). Forwards each new session to the cache
+ // stored in the SSL's ex_data; returns 1 when it keeps a reference to |session|.
+ static int newSessionCallback(SSL* _Nullable ssl, SSL_SESSION* _Nullable session);
+
+ std::mutex mLock;
+ // Takes ownership of |session| and trims the queue back to kMaxSize entries.
+ void recordSession(SSL_SESSION* _Nullable session) EXCLUDES(mLock);
+
+ // Queue of sessions, from most recently added (front) to least recently added (back).
+ std::deque<bssl::UniquePtr<SSL_SESSION>> mSessions GUARDED_BY(mLock);
+};
+
+} // end of namespace net
+} // end of namespace android
+
+#endif // _DNS_DNSTLSSESSIONCACHE_H
diff --git a/server/dns/DnsTlsSocket.cpp b/server/dns/DnsTlsSocket.cpp
new file mode 100644
index 0000000..6243968
--- /dev/null
+++ b/server/dns/DnsTlsSocket.cpp
@@ -0,0 +1,417 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#define LOG_TAG "DnsTlsSocket"
+
+#include "dns/DnsTlsSocket.h"
+
+#include <algorithm>
+#include <arpa/inet.h>
+#include <arpa/nameser.h>
+#include <errno.h>
+#include <openssl/err.h>
+#include <poll.h>
+#include <stdint.h>
+#include <sys/select.h>
+
+#include "dns/DnsTlsSessionCache.h"
+
+//#define LOG_NDEBUG 0
+
+#include "log/log.h"
+#include "Fwmark.h"
+#undef ADD // already defined in nameser.h
+#include "NetdConstants.h"
+#include "Permission.h"
+
+
+namespace android {
+namespace net {
+
+using netdutils::Status;
+
+namespace {
+
+constexpr const char kCaCertDir[] = "/system/etc/security/cacerts";
+
+// Blocks until |fd| reports one of |events| (or an error/hangup) and returns poll()'s
+// result: 1 when ready, -1 on error. There is no timeout.
+// poll() is used instead of select() because FD_SET() on a descriptor >= FD_SETSIZE
+// writes past the end of the fd_set, and nothing here bounds |fd|.
+int waitForEvent(int fd, short events) {
+    pollfd fds = {fd, events, 0};
+    return TEMP_FAILURE_RETRY(poll(&fds, 1, -1));
+}
+
+// Waits until |fd| is readable. Returns 1 when ready, <= 0 on failure.
+int waitForReading(int fd) {
+    const int ret = waitForEvent(fd, POLLIN);
+    ALOGV_IF(ret <= 0, "poll failed during read");
+    return ret;
+}
+
+// Waits until |fd| is writable. Returns 1 when ready, <= 0 on failure.
+int waitForWriting(int fd) {
+    const int ret = waitForEvent(fd, POLLOUT);
+    ALOGV_IF(ret <= 0, "poll failed during write");
+    return ret;
+}
+
+} // namespace
+
+// Opens a non-blocking, marked TCP socket to mServer.ss and starts connecting.
+// Returns ok on success (the connection is usually still in progress), otherwise
+// the errno of the failing call. Only IPPROTO_TCP is supported.
+Status DnsTlsSocket::tcpConnect() {
+    ALOGV("%u connecting TCP socket", mMark);
+    int type = SOCK_NONBLOCK | SOCK_CLOEXEC;
+    switch (mServer.protocol) {
+        case IPPROTO_TCP:
+            type |= SOCK_STREAM;
+            break;
+        default:
+            return Status(EPROTONOSUPPORT);
+    }
+
+    // errno is captured right after each failing call, before logging or closing the
+    // socket, so the returned Status reports the real cause of the failure.
+    mSslFd.reset(socket(mServer.ss.ss_family, type, mServer.protocol));
+    if (mSslFd.get() == -1) {
+        const int err = errno;
+        ALOGE("Failed to create socket");
+        return Status(err);
+    }
+
+    const socklen_t len = sizeof(mMark);
+    if (setsockopt(mSslFd.get(), SOL_SOCKET, SO_MARK, &mMark, len) == -1) {
+        const int err = errno;
+        ALOGE("Failed to set socket mark");
+        mSslFd.reset();
+        return Status(err);
+    }
+    if (connect(mSslFd.get(), reinterpret_cast<const struct sockaddr *>(&mServer.ss),
+                sizeof(mServer.ss)) != 0) {
+        const int err = errno;
+        // The socket is non-blocking, so EINPROGRESS only means the connection is still
+        // being established; it is not a failure.
+        if (err != EINPROGRESS) {
+            ALOGV("Socket failed to connect");
+            mSslFd.reset();
+            return Status(err);
+        }
+    }
+
+    return netdutils::status::ok;
+}
+
+// Computes the SHA-256 digest of |cert|'s DER-encoded SubjectPublicKeyInfo into |out|.
+// Returns false on any failure.
+bool getSPKIDigest(const X509* cert, std::vector<uint8_t>* out) {
+    const int spki_len = i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), nullptr);
+    // i2d_X509_PUBKEY() returns a negative value on error; such a length must never be
+    // used to size the buffer. A std::vector is used rather than a variable-length
+    // array, which is not standard C++.
+    if (spki_len <= 0) {
+        ALOGW("SPKI encoding failed");
+        return false;
+    }
+    std::vector<unsigned char> spki(spki_len);
+    unsigned char* temp = spki.data();
+    if (spki_len != i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), &temp)) {
+        ALOGW("SPKI length mismatch");
+        return false;
+    }
+    out->resize(SHA256_SIZE);
+    unsigned int digest_len = 0;
+    const int ret = EVP_Digest(spki.data(), spki.size(), out->data(), &digest_len,
+                               EVP_sha256(), nullptr);
+    if (ret != 1) {
+        ALOGW("Server cert digest extraction failed");
+        return false;
+    }
+    if (digest_len != out->size()) {
+        ALOGW("Wrong digest length: %u", digest_len);
+        return false;
+    }
+    return true;
+}
+
+bool DnsTlsSocket::initialize() {
+    // Expected to run exactly once, right after construction, so the lock is not
+    // needed for correctness; it only helps catch bugs in callers.
+    std::lock_guard<std::mutex> guard(mLock);
+    if (mSslCtx) {
+        // A second call is a bug in the caller.
+        return false;
+    }
+
+    mSslCtx.reset(SSL_CTX_new(TLS_method()));
+    SSL_CTX* const ctx = mSslCtx.get();
+    if (ctx == nullptr) {
+        return false;
+    }
+
+    // System CA certificates back hostname verification.
+    // For discussion of alternative, sustainable approaches see b/71909242.
+    if (SSL_CTX_load_verify_locations(ctx, nullptr, kCaCertDir) != 1) {
+        ALOGE("Failed to load CA cert dir: %s", kCaCertDir);
+        return false;
+    }
+
+    // Allow TLS False Start, even without ALPN.
+    SSL_CTX_set_false_start_allowed_without_alpn(ctx, 1);
+    SSL_CTX_set_mode(ctx, SSL_MODE_ENABLE_FALSE_START);
+
+    // Enable session resumption through mCache.
+    mCache->prepareSslContext(ctx);
+
+    // Open the TCP connection, then run the TLS handshake over it.
+    if (!tcpConnect().ok()) {
+        return false;
+    }
+    mSsl = sslConnect(mSslFd.get());
+    return mSsl != nullptr;
+}
+
+// Runs the TLS handshake over the already-connected socket |fd| and, if the server
+// has pinned fingerprints, checks them. Returns the connected SSL, or nullptr on any
+// failure. Closing |fd| on failure remains the caller's responsibility.
+bssl::UniquePtr<SSL> DnsTlsSocket::sslConnect(int fd) {
+    if (!mSslCtx) {
+        ALOGE("Internal error: context is null in sslConnect");
+        return nullptr;
+    }
+    if (!SSL_CTX_set_min_proto_version(mSslCtx.get(), TLS1_2_VERSION)) {
+        ALOGE("Failed to set minimum TLS version");
+        return nullptr;
+    }
+
+    bssl::UniquePtr<SSL> ssl(SSL_new(mSslCtx.get()));
+    if (!ssl) {
+        ALOGE("Failed to create SSL object");
+        return nullptr;
+    }
+    // This file descriptor is owned by mSslFd, so don't let libssl close it.
+    bssl::UniquePtr<BIO> bio(BIO_new_socket(fd, BIO_NOCLOSE));
+    if (!bio) {
+        ALOGE("Failed to create socket BIO");
+        return nullptr;
+    }
+    // When the read and write BIOs are the same, SSL_set_bio() takes ownership of a
+    // single reference, so release ours.
+    SSL_set_bio(ssl.get(), bio.get(), bio.get());
+    bio.release();
+
+    if (!mCache->prepareSsl(ssl.get())) {
+        return nullptr;
+    }
+
+    if (!mServer.name.empty()) {
+        if (SSL_set_tlsext_host_name(ssl.get(), mServer.name.c_str()) != 1) {
+            ALOGE("Failed to set SNI to %s", mServer.name.c_str());
+            return nullptr;
+        }
+        X509_VERIFY_PARAM* param = SSL_get0_param(ssl.get());
+        // If this fails, the handshake would verify only the chain and not the hostname,
+        // so treat it as fatal rather than silently weakening verification.
+        if (X509_VERIFY_PARAM_set1_host(param, mServer.name.c_str(), 0) != 1) {
+            ALOGE("Failed to set verify host param to %s", mServer.name.c_str());
+            return nullptr;
+        }
+        // This will cause the handshake to fail if certificate verification fails.
+        SSL_set_verify(ssl.get(), SSL_VERIFY_PEER, nullptr);
+    }
+
+    // SSL_set_session() takes its own reference, so |session| may be released afterwards.
+    bssl::UniquePtr<SSL_SESSION> session = mCache->getSession();
+    if (session) {
+        ALOGV("Setting session");
+        SSL_set_session(ssl.get(), session.get());
+    } else {
+        ALOGV("No session available");
+    }
+
+    for (;;) {
+        ALOGV("%u Calling SSL_connect", mMark);
+        const int ret = SSL_connect(ssl.get());
+        ALOGV("%u SSL_connect returned %d", mMark, ret);
+        if (ret == 1) break;  // SSL handshake complete;
+
+        const int ssl_err = SSL_get_error(ssl.get(), ret);
+        switch (ssl_err) {
+            case SSL_ERROR_WANT_READ:
+                if (waitForReading(fd) != 1) {
+                    ALOGW("SSL_connect read error");
+                    return nullptr;
+                }
+                break;
+            case SSL_ERROR_WANT_WRITE:
+                if (waitForWriting(fd) != 1) {
+                    ALOGW("SSL_connect write error");
+                    return nullptr;
+                }
+                break;
+            default:
+                ALOGW("SSL_connect error %d, errno=%d", ssl_err, errno);
+                return nullptr;
+        }
+    }
+
+    // TODO: Call SSL_shutdown before discarding the session if validation fails.
+    if (!mServer.fingerprints.empty()) {
+        ALOGV("Checking DNS over TLS fingerprint");
+
+        // We only care that the chain is internally self-consistent, not that
+        // it chains to a trusted root, so we can ignore some kinds of errors.
+        // TODO: Add a CA root verification mode that respects these errors.
+        const long verify_result = SSL_get_verify_result(ssl.get());
+        switch (verify_result) {
+            case X509_V_OK:
+            case X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT:
+            case X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN:
+            case X509_V_ERR_CERT_UNTRUSTED:
+                break;
+            default:
+                ALOGW("Invalid certificate chain, error %ld", verify_result);
+                return nullptr;
+        }
+
+        STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl.get());
+        if (!chain) {
+            ALOGW("Server has null certificate");
+            return nullptr;
+        }
+        // Chain and its contents are owned by ssl, so we don't need to free explicitly.
+        // sk_X509_value() indexes directly into the stack, so this scan is linear in the
+        // chain length.
+        bool matched = false;
+        for (size_t i = 0; i < sk_X509_num(chain); ++i) {
+            X509* cert = sk_X509_value(chain, i);
+            std::vector<uint8_t> digest;
+            if (!getSPKIDigest(cert, &digest)) {
+                ALOGE("Digest computation failed");
+                return nullptr;
+            }
+
+            if (mServer.fingerprints.count(digest) > 0) {
+                matched = true;
+                break;
+            }
+        }
+
+        if (!matched) {
+            ALOGW("No matching fingerprint");
+            return nullptr;
+        }
+
+        ALOGV("DNS over TLS fingerprint is correct");
+    }
+
+    ALOGV("%u handshake complete", mMark);
+
+    return ssl;
+}
+
+void DnsTlsSocket::sslDisconnect() {
+    // If a TLS session exists, shut it down and drop it before closing the socket.
+    if (SSL* const ssl = mSsl.get()) {
+        SSL_shutdown(ssl);
+        mSsl.reset();
+    }
+    mSslFd.reset();
+}
+
+// Writes all of |buffer| in a single SSL_write, waiting on the socket whenever libssl
+// asks to be retried. Returns false on any error.
+bool DnsTlsSocket::sslWrite(const Slice buffer) {
+    ALOGV("%u Writing %zu bytes", mMark, buffer.size());
+    for (;;) {
+        const int ret = SSL_write(mSsl.get(), buffer.base(), buffer.size());
+        if (ret == int(buffer.size())) break;  // SSL write complete;
+
+        if (ret > 0) {
+            // Partial writes are not enabled, so a short write is unexpected. Retrying the
+            // whole buffer would resend bytes already written and corrupt the stream, so
+            // treat it as an error.
+            ALOGV("SSL_write wrote %d of %zu bytes", ret, buffer.size());
+            return false;
+        }
+
+        // After WANT_WRITE or WANT_READ, SSL_write must be retried with the same arguments.
+        const int ssl_err = SSL_get_error(mSsl.get(), ret);
+        switch (ssl_err) {
+            case SSL_ERROR_WANT_WRITE:
+                if (waitForWriting(mSslFd.get()) != 1) {
+                    ALOGV("SSL_write error");
+                    return false;
+                }
+                break;
+            case SSL_ERROR_WANT_READ:
+                if (waitForReading(mSslFd.get()) != 1) {
+                    ALOGV("SSL_write error");
+                    return false;
+                }
+                break;
+            default:
+                ALOGV("SSL_write error %d", ssl_err);
+                return false;
+        }
+    }
+    ALOGV("%u Wrote %zu bytes", mMark, buffer.size());
+    return true;
+}
+
+// Shuts down the TLS session (if any) and closes the socket.
+// NOTE(review): mLock is not taken here even though sslDisconnect() is annotated
+// REQUIRES(mLock); presumably no other thread can still be using the socket while it
+// is being destroyed. Confirm against the owner's lifetime rules.
+DnsTlsSocket::~DnsTlsSocket() {
+ sslDisconnect();
+}
+
+DnsTlsServer::Result DnsTlsSocket::query(uint16_t id, const Slice query) {
+    // mLock is held across the write and the read, so queries on this socket are
+    // serialized and each response is read by the thread that sent the query.
+    std::lock_guard<std::mutex> guard(mLock);
+    const Query request = { .id = id, .query = query };
+    if (sendQuery(request)) {
+        return readResponse();
+    }
+    return { .code = DnsTlsServer::Response::network_error };
+}
+
+// Reads exactly buffer.size() bytes into |buffer|, blocking until they arrive.
+// Returns false if the connection closes first or on any TLS or socket error.
+bool DnsTlsSocket::sslRead(const Slice buffer) {
+    size_t remaining = buffer.size();
+    while (remaining > 0) {
+        const int ret = SSL_read(mSsl.get(), buffer.limit() - remaining, remaining);
+        if (ret == 0) {
+            ALOGW_IF(remaining < buffer.size(), "SSL closed with %zu of %zu bytes remaining",
+                     remaining, buffer.size());
+            return false;
+        }
+
+        if (ret < 0) {
+            const int ssl_err = SSL_get_error(mSsl.get(), ret);
+            switch (ssl_err) {
+                case SSL_ERROR_WANT_READ:
+                    if (waitForReading(mSslFd.get()) != 1) {
+                        ALOGV("SSL_read error");
+                        return false;
+                    }
+                    continue;
+                case SSL_ERROR_WANT_WRITE:
+                    // SSL_read may also need to write to the socket (see SSL_get_error);
+                    // wait until it is writable and retry instead of failing the read.
+                    if (waitForWriting(mSslFd.get()) != 1) {
+                        ALOGV("SSL_read error");
+                        return false;
+                    }
+                    continue;
+                default:
+                    ALOGV("SSL_read error %d", ssl_err);
+                    return false;
+            }
+        }
+
+        remaining -= ret;
+    }
+    return true;
+}
+
+// Frames |q| in DNS-over-TCP format ([2-byte length][2-byte ID][body]) and writes it.
+// Returns false if the query is too large to frame or the write fails.
+bool DnsTlsSocket::sendQuery(const Query& q) {
+    ALOGV("sending query");
+    // The 16-bit length prefix covers the 2-byte ID plus the body. A larger body would be
+    // framed with a truncated length and desynchronize the stream, so reject it.
+    constexpr size_t kMaxBodySize = UINT16_MAX - sizeof(uint16_t);
+    if (q.query.size() > kMaxBodySize) {
+        ALOGW("Query too large: %zu bytes", q.query.size());
+        return false;
+    }
+    // Compose the entire message in a single buffer, so that it can be
+    // sent as a single TLS record.
+    std::vector<uint8_t> buf(q.query.size() + 4);
+    // Write 2-byte length
+    const uint16_t len = static_cast<uint16_t>(q.query.size() + 2);  // + 2 for the ID.
+    buf[0] = len >> 8;
+    buf[1] = len;
+    // Write 2-byte ID
+    buf[2] = q.id >> 8;
+    buf[3] = q.id;
+    // Copy body
+    std::memcpy(buf.data() + 4, q.query.base(), q.query.size());
+    if (!sslWrite(netdutils::makeSlice(buf))) {
+        return false;
+    }
+    ALOGV("%u SSL_write complete", mMark);
+    return true;
+}
+
+// Reads one length-prefixed response. Returns success with at most MAX_SIZE bytes of
+// payload, or network_error if any read fails.
+DnsTlsServer::Result DnsTlsSocket::readResponse() {
+    ALOGV("reading response");
+    uint8_t responseHeader[2];
+    const DnsTlsServer::Result failed = { .code = DnsTlsServer::Response::network_error };
+    if (!sslRead(Slice(responseHeader, 2))) {
+        return failed;
+    }
+    // Truncate responses larger than MAX_SIZE. This is safe because a DNS packet is
+    // always invalid when truncated, so the response will be treated as an error.
+    constexpr uint16_t MAX_SIZE = 8192;
+    const uint16_t responseSize = (responseHeader[0] << 8) | responseHeader[1];
+    ALOGV("%u Expecting response of size %i", mMark, responseSize);
+    std::vector<uint8_t> response(std::min(responseSize, MAX_SIZE));
+    if (!sslRead(netdutils::makeSlice(response))) {
+        ALOGV("%u Failed to read %zu bytes", mMark, response.size());
+        return failed;
+    }
+    // Drain the part beyond MAX_SIZE so the stream stays aligned on message boundaries.
+    // One discard buffer is reused for every chunk.
+    uint16_t remainingBytes = responseSize - response.size();
+    constexpr uint16_t CHUNK_SIZE = 2048;
+    std::vector<uint8_t> discard;
+    while (remainingBytes > 0) {
+        discard.resize(std::min(remainingBytes, CHUNK_SIZE));
+        if (!sslRead(netdutils::makeSlice(discard))) {
+            ALOGV("%u Failed to discard %zu bytes", mMark, discard.size());
+            return failed;
+        }
+        remainingBytes -= discard.size();
+    }
+    ALOGV("%u SSL_read complete", mMark);
+
+    // Move the payload into the result rather than copying it.
+    return { .code = DnsTlsServer::Response::success, .response = std::move(response) };
+}
+
+} // end of namespace net
+} // end of namespace android
diff --git a/server/dns/DnsTlsSocket.h b/server/dns/DnsTlsSocket.h
new file mode 100644
index 0000000..9f923c5
--- /dev/null
+++ b/server/dns/DnsTlsSocket.h
@@ -0,0 +1,101 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _DNS_DNSTLSSOCKET_H
+#define _DNS_DNSTLSSOCKET_H
+
+#include <future>
+#include <mutex>
+#include <openssl/ssl.h>
+
+#include <android-base/thread_annotations.h>
+#include <android-base/unique_fd.h>
+#include <netdutils/Slice.h>
+#include <netdutils/Status.h>
+
+#include "dns/DnsTlsServer.h"
+#include "dns/IDnsTlsSocket.h"
+
+namespace android {
+namespace net {
+
+class DnsTlsSessionCache;
+
+using netdutils::Slice;
+
+// A class for managing a TLS socket that sends and receives messages in
+// [length][value] format, with a 2-byte length (i.e. DNS-over-TCP format).
+class DnsTlsSocket : public IDnsTlsSocket {
+public:
+ // |cache| is kept as a non-owning pointer and must outlive this socket.
+ DnsTlsSocket(const DnsTlsServer& server, unsigned mark,
+ DnsTlsSessionCache* _Nonnull cache) :
+ mMark(mark), mServer(server), mCache(cache) {}
+ // Shuts down the TLS session (if any) and closes the socket.
+ ~DnsTlsSocket();
+
+ // Creates the SSL context for this socket and connects to the server (TCP, then the
+ // TLS handshake). Returns false on failure.
+ // This method should be called after construction and before use of a DnsTlsSocket.
+ // Only call this method once per DnsTlsSocket.
+ bool initialize() EXCLUDES(mLock);
+
+ // Sends a query over this socket's TLS connection and blocks until the response is
+ // read. |id| is written as the 2-byte DNS ID ahead of |query|, which contains
+ // the body of a query, not including the ID header. Returns the server's response,
+ // or network_error on failure. Queries on one socket are serialized by mLock.
+ DnsTlsServer::Result query(uint16_t id, const Slice query) override;
+
+private:
+ // Lock to be held while performing a query.
+ std::mutex mLock;
+
+ // On success, sets mSslFd to a non-blocking socket connected to mServer.ss (the
+ // connection will likely still be in progress). Only mServer.protocol ==
+ // IPPROTO_TCP is supported. On error, returns the errno.
+ netdutils::Status tcpConnect() REQUIRES(mLock);
+
+ // Connect an SSL session on the provided socket. If connection fails, closing the
+ // socket remains the caller's responsibility.
+ bssl::UniquePtr<SSL> sslConnect(int fd) REQUIRES(mLock);
+
+ // Disconnect the SSL session and close the socket.
+ void sslDisconnect() REQUIRES(mLock);
+
+ // Writes a buffer to the socket.
+ bool sslWrite(const Slice buffer) REQUIRES(mLock);
+
+ // Reads exactly the specified number of bytes from the socket. Blocking.
+ // Returns false if the socket closes before enough bytes can be read.
+ bool sslRead(const Slice buffer) REQUIRES(mLock);
+
+ // One outgoing query: the DNS ID and the query body (without the ID).
+ struct Query {
+ uint16_t id;
+ const Slice query;
+ };
+
+ // Frames |q| as [2-byte length][2-byte ID][body] and writes it to the socket.
+ bool sendQuery(const Query& q) REQUIRES(mLock);
+ // Reads one length-prefixed response from the socket.
+ DnsTlsServer::Result readResponse() REQUIRES(mLock);
+
+ // SSL Socket fields.
+ bssl::UniquePtr<SSL_CTX> mSslCtx GUARDED_BY(mLock);
+ base::unique_fd mSslFd GUARDED_BY(mLock);
+ bssl::UniquePtr<SSL> mSsl GUARDED_BY(mLock);
+
+ const unsigned mMark; // Socket mark
+ const DnsTlsServer mServer;
+ // Not owned; see the constructor.
+ DnsTlsSessionCache* _Nonnull const mCache;
+};
+
+} // end of namespace net
+} // end of namespace android
+
+#endif // _DNS_DNSTLSSOCKET_H
diff --git a/server/dns/DnsTlsSocketFactory.h b/server/dns/DnsTlsSocketFactory.h
new file mode 100644
index 0000000..9c597a0
--- /dev/null
+++ b/server/dns/DnsTlsSocketFactory.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _DNS_DNSTLSSOCKETFACTORY_H
+#define _DNS_DNSTLSSOCKETFACTORY_H
+
+#include <memory>
+
+#include "dns/DnsTlsSocket.h"
+#include "dns/IDnsTlsSocketFactory.h"
+
+namespace android {
+namespace net {
+
+class DnsTlsSessionCache;
+struct DnsTlsServer;
+
+// Production IDnsTlsSocketFactory (owned by DnsTlsDispatcher); hands out only initialized sockets.
+class DnsTlsSocketFactory : public IDnsTlsSocketFactory {
+public:
+    std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(const DnsTlsServer& server, unsigned mark,
+                                                     DnsTlsSessionCache* _Nonnull cache) override {
+        auto tlsSocket = std::make_unique<DnsTlsSocket>(server, mark, cache);
+        if (tlsSocket->initialize()) {
+            return std::move(tlsSocket);
+        }
+        return nullptr;
+    }
+};
+
+} // end of namespace net
+} // end of namespace android
+
+#endif // _DNS_DNSTLSSOCKETFACTORY_H
diff --git a/server/dns/DnsTlsTransport.cpp b/server/dns/DnsTlsTransport.cpp
index 7610c92..4bf33eb 100644
--- a/server/dns/DnsTlsTransport.cpp
+++ b/server/dns/DnsTlsTransport.cpp
@@ -14,15 +14,18 @@
* limitations under the License.
*/
+#define LOG_TAG "DnsTlsTransport"
+
#include "dns/DnsTlsTransport.h"
#include <arpa/inet.h>
#include <arpa/nameser.h>
-#include <errno.h>
-#include <openssl/err.h>
-#define LOG_TAG "DnsTlsTransport"
-#define DBG 0
+#include "dns/DnsTlsServer.h"
+#include "dns/DnsTlsSocketFactory.h"
+#include "dns/IDnsTlsSocketFactory.h"
+
+//#define LOG_NDEBUG 0
#include "log/log.h"
#include "Fwmark.h"
@@ -30,497 +33,33 @@
#include "NetdConstants.h"
#include "Permission.h"
-
namespace android {
namespace net {
-namespace {
-
-constexpr const char kCaCertDir[] = "/system/etc/security/cacerts";
-
-bool setNonBlocking(int fd, bool enabled) {
- int flags = fcntl(fd, F_GETFL);
- if (flags < 0) return false;
-
- if (enabled) {
- flags |= O_NONBLOCK;
- } else {
- flags &= ~O_NONBLOCK;
+// Sends |query|, a DNS message whose first two bytes are its ID, to mServer over a
+// DnsTlsSocket obtained from mFactory, and returns that socket's result.
+// Returns internal_error if |query| is too short to contain an ID, and network_error
+// if the factory does not produce a socket.
+DnsTlsTransport::Result DnsTlsTransport::query(const netdutils::Slice query) {
+ if (query.size() < 2) {
+ return (Result) { .code = Response::internal_error };
}
- return (fcntl(fd, F_SETFL, flags) == 0);
+
+ // Split off the 2-byte big-endian ID; the socket takes the ID and the remaining
+ // message body as separate arguments.
+ const uint8_t* data = query.base();
+ uint16_t id = data[0] << 8 | data[1];
+
+ // A socket is requested from the factory on every call. Passing &mCache each time
+ // means all sockets created by this transport share the same session cache.
+ auto socket = mFactory->createDnsTlsSocket(mServer, mMark, &mCache);
+ if (!socket) {
+ return (Result) { .code = Response::network_error };
+ }
+
+ return socket->query(id, netdutils::drop(query, 2));
}
-int waitForReading(int fd) {
- fd_set fds;
- FD_ZERO(&fds);
- FD_SET(fd, &fds);
- const int ret = TEMP_FAILURE_RETRY(select(fd + 1, &fds, nullptr, nullptr, nullptr));
- if (DBG && ret <= 0) {
- ALOGD("select");
- }
- return ret;
-}
-
-int waitForWriting(int fd) {
- fd_set fds;
- FD_ZERO(&fds);
- FD_SET(fd, &fds);
- const int ret = TEMP_FAILURE_RETRY(select(fd + 1, nullptr, &fds, nullptr, nullptr));
- if (DBG && ret <= 0) {
- ALOGD("select");
- }
- return ret;
-}
-
-} // namespace
-
-android::base::unique_fd DnsTlsTransport::makeConnectedSocket() const {
- if (DBG) {
- ALOGD("%u connecting TCP socket", mMark);
- }
- android::base::unique_fd fd;
- int type = SOCK_NONBLOCK | SOCK_CLOEXEC;
- switch (mServer.protocol) {
- case IPPROTO_TCP:
- type |= SOCK_STREAM;
- break;
- default:
- errno = EPROTONOSUPPORT;
- return fd;
- }
-
- fd.reset(socket(mServer.ss.ss_family, type, mServer.protocol));
- if (fd.get() == -1) {
- return fd;
- }
-
- const socklen_t len = sizeof(mMark);
- if (setsockopt(fd.get(), SOL_SOCKET, SO_MARK, &mMark, len) == -1) {
- fd.reset();
- } else if (connect(fd.get(),
- reinterpret_cast<const struct sockaddr *>(&mServer.ss), sizeof(mServer.ss)) != 0
- && errno != EINPROGRESS) {
- fd.reset();
- }
-
- if (!setNonBlocking(fd, false)) {
- ALOGE("Failed to disable nonblocking status on DNS-over-TLS fd");
- fd.reset();
- }
-
- return fd;
-}
-
-bool getSPKIDigest(const X509* cert, std::vector<uint8_t>* out) {
- int spki_len = i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), NULL);
- unsigned char spki[spki_len];
- unsigned char* temp = spki;
- if (spki_len != i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), &temp)) {
- ALOGW("SPKI length mismatch");
- return false;
- }
- out->resize(SHA256_SIZE);
- unsigned int digest_len = 0;
- int ret = EVP_Digest(spki, spki_len, out->data(), &digest_len, EVP_sha256(), NULL);
- if (ret != 1) {
- ALOGW("Server cert digest extraction failed");
- return false;
- }
- if (digest_len != out->size()) {
- ALOGW("Wrong digest length: %d", digest_len);
- return false;
- }
- return true;
-}
-
-bool DnsTlsTransport::initialize() {
- // This method should only be called once, at the beginning, so locking should be
- // unnecessary. This lock only serves to help catch bugs in code that calls this method.
- std::lock_guard<std::mutex> guard(mLock);
- if (mSslCtx) {
- // This is a bug in the caller.
- return false;
- }
- mSslCtx.reset(SSL_CTX_new(TLS_method()));
- if (!mSslCtx) {
- return false;
- }
-
- // Load system CA certs for hostname verification.
- //
- // For discussion of alternative, sustainable approaches see b/71909242.
- if (SSL_CTX_load_verify_locations(mSslCtx.get(), nullptr, kCaCertDir) != 1) {
- ALOGE("Failed to load CA cert dir: %s", kCaCertDir);
- return false;
- }
-
- SSL_CTX_sess_set_new_cb(mSslCtx.get(), DnsTlsTransport::newSessionCallback);
- SSL_CTX_sess_set_remove_cb(mSslCtx.get(), DnsTlsTransport::removeSessionCallback);
-
- // Enable TLS false start.
- SSL_CTX_set_false_start_allowed_without_alpn(mSslCtx.get(), 1);
- SSL_CTX_set_mode(mSslCtx.get(), SSL_MODE_ENABLE_FALSE_START);
- return true;
-}
-
-bssl::UniquePtr<SSL> DnsTlsTransport::sslConnect(int fd) {
- // Check TLS context.
- if (!mSslCtx) {
- ALOGE("Internal error: context is null in ssl connect");
- return nullptr;
- }
- if (!SSL_CTX_set_max_proto_version(mSslCtx.get(), TLS1_3_VERSION) ||
- !SSL_CTX_set_min_proto_version(mSslCtx.get(), TLS1_2_VERSION)) {
- ALOGE("failed to min/max TLS versions");
- return nullptr;
- }
-
- bssl::UniquePtr<SSL> ssl(SSL_new(mSslCtx.get()));
- // This file descriptor is owned by a unique_fd, so don't let libssl close it.
- bssl::UniquePtr<BIO> bio(BIO_new_socket(fd, BIO_NOCLOSE));
- SSL_set_bio(ssl.get(), bio.get(), bio.get());
- bio.release();
-
- // Add this transport as the 0-index extra data for the socket.
- // This is used by newSessionCallback.
- if (SSL_set_ex_data(ssl.get(), 0, this) != 1) {
- ALOGE("failed to associate SSL socket to transport");
- return nullptr;
- }
-
- // Add this transport as the 0-index extra data for the context.
- // This is used by removeSessionCallback.
- if (SSL_CTX_set_ex_data(mSslCtx.get(), 0, this) != 1) {
- ALOGE("failed to associate SSL context to transport");
- return nullptr;
- }
-
- if (!mServer.name.empty()) {
- if (SSL_set_tlsext_host_name(ssl.get(), mServer.name.c_str()) != 1) {
- ALOGE("Failed to set SNI to %s", mServer.name.c_str());
- return nullptr;
- }
- X509_VERIFY_PARAM* param = SSL_get0_param(ssl.get());
- X509_VERIFY_PARAM_set1_host(param, mServer.name.c_str(), 0);
- // This will cause the handshake to fail if certificate verification fails.
- SSL_set_verify(ssl.get(), SSL_VERIFY_PEER, nullptr);
- }
-
- bssl::UniquePtr<SSL_SESSION> session;
- {
- std::lock_guard<std::mutex> guard(mLock);
- if (!mSessions.empty()) {
- session = std::move(mSessions.front());
- mSessions.pop_front();
- } else if (DBG) {
- ALOGD("Starting without session ticket.");
- }
- }
- if (session) {
- SSL_set_session(ssl.get(), session.get());
- }
-
- for (;;) {
- if (DBG) {
- ALOGD("%u Calling SSL_connect", mMark);
- }
- int ret = SSL_connect(ssl.get());
- if (DBG) {
- ALOGD("%u SSL_connect returned %d", mMark, ret);
- }
- if (ret == 1) break; // SSL handshake complete;
-
- const int ssl_err = SSL_get_error(ssl.get(), ret);
- switch (ssl_err) {
- case SSL_ERROR_WANT_READ:
- if (waitForReading(fd) != 1) {
- ALOGW("SSL_connect read error");
- return nullptr;
- }
- break;
- case SSL_ERROR_WANT_WRITE:
- if (waitForWriting(fd) != 1) {
- ALOGW("SSL_connect write error");
- return nullptr;
- }
- break;
- default:
- ALOGW("SSL_connect error %d, errno=%d", ssl_err, errno);
- return nullptr;
- }
- }
-
- if (!mServer.fingerprints.empty()) {
- if (DBG) {
- ALOGD("Checking DNS over TLS fingerprint");
- }
-
- // We only care that the chain is internally self-consistent, not that
- // it chains to a trusted root, so we can ignore some kinds of errors.
- // TODO: Add a CA root verification mode that respects these errors.
- int verify_result = SSL_get_verify_result(ssl.get());
- switch (verify_result) {
- case X509_V_OK:
- case X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT:
- case X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN:
- case X509_V_ERR_CERT_UNTRUSTED:
- break;
- default:
- ALOGW("Invalid certificate chain, error %d", verify_result);
- return nullptr;
- }
-
- STACK_OF(X509) *chain = SSL_get_peer_cert_chain(ssl.get());
- if (!chain) {
- ALOGW("Server has null certificate");
- return nullptr;
- }
- // Chain and its contents are owned by ssl, so we don't need to free explicitly.
- bool matched = false;
- for (size_t i = 0; i < sk_X509_num(chain); ++i) {
- // This appears to be O(N^2), but there doesn't seem to be a straightforward
- // way to walk a STACK_OF nondestructively in linear time.
- X509* cert = sk_X509_value(chain, i);
- std::vector<uint8_t> digest;
- if (!getSPKIDigest(cert, &digest)) {
- ALOGE("Digest computation failed");
- return nullptr;
- }
-
- if (mServer.fingerprints.count(digest) > 0) {
- matched = true;
- break;
- }
- }
-
- if (!matched) {
- ALOGW("No matching fingerprint");
- return nullptr;
- }
-
- if (DBG) {
- ALOGD("DNS over TLS fingerprint is correct");
- }
- }
-
- if (DBG) {
- ALOGD("%u handshake complete", mMark);
- }
-
- return ssl;
+// Intentionally empty: the transport does not own |mFactory| (validate() passes the
+// address of a local factory), and its remaining members are held by value.
+DnsTlsTransport::~DnsTlsTransport() {
}
// static
-int DnsTlsTransport::newSessionCallback(SSL* ssl, SSL_SESSION* session) {
- if (!session) {
- return 0;
- }
- if (DBG) {
- ALOGD("Recording session ticket");
- }
- DnsTlsTransport* xport = reinterpret_cast<DnsTlsTransport*>(
- SSL_get_ex_data(ssl, 0));
- if (!xport) {
- ALOGE("null transport in new session callback");
- return 0;
- }
- xport->recordSession(session);
- return 1;
-}
-
-void DnsTlsTransport::removeSessionCallback(SSL_CTX* ssl_ctx, SSL_SESSION* session) {
- if (DBG) {
- ALOGD("Removing session ticket");
- }
- DnsTlsTransport* xport = reinterpret_cast<DnsTlsTransport*>(
- SSL_CTX_get_ex_data(ssl_ctx, 0));
- if (!xport) {
- ALOGE("null transport in remove session callback");
- return;
- }
- xport->removeSession(session);
-}
-
-void DnsTlsTransport::recordSession(SSL_SESSION* session) {
- std::lock_guard<std::mutex> guard(mLock);
- mSessions.emplace_front(session);
- if (mSessions.size() > 5) {
- if (DBG) {
- ALOGD("Too many sessions; trimming");
- }
- mSessions.pop_back();
- }
-}
-
-void DnsTlsTransport::removeSession(SSL_SESSION* session) {
- std::lock_guard<std::mutex> guard(mLock);
- if (session) {
- // TODO: Consider implementing targeted removal.
- mSessions.clear();
- }
-}
-
-void DnsTlsTransport::sslDisconnect(bssl::UniquePtr<SSL> ssl, base::unique_fd fd) {
- if (ssl) {
- SSL_shutdown(ssl.get());
- ssl.reset();
- }
- fd.reset();
-}
-
-bool DnsTlsTransport::sslWrite(int fd, SSL *ssl, const uint8_t *buffer, int len) {
- if (DBG) {
- ALOGD("%u Writing %d bytes", mMark, len);
- }
- for (;;) {
- int ret = SSL_write(ssl, buffer, len);
- if (ret == len) break; // SSL write complete;
-
- if (ret < 1) {
- const int ssl_err = SSL_get_error(ssl, ret);
- switch (ssl_err) {
- case SSL_ERROR_WANT_WRITE:
- if (waitForWriting(fd) != 1) {
- if (DBG) {
- ALOGW("SSL_write error");
- }
- return false;
- }
- continue;
- case 0:
- break; // SSL write complete;
- default:
- if (DBG) {
- ALOGW("SSL_write error %d", ssl_err);
- }
- return false;
- }
- }
- }
- if (DBG) {
- ALOGD("%u Wrote %d bytes", mMark, len);
- }
- return true;
-}
-
-// Read exactly len bytes into buffer or fail
-bool DnsTlsTransport::sslRead(int fd, SSL *ssl, uint8_t *buffer, int len) {
- int remaining = len;
- while (remaining > 0) {
- int ret = SSL_read(ssl, buffer + (len - remaining), remaining);
- if (ret == 0) {
- ALOGE("SSL socket closed with %i of %i bytes remaining", remaining, len);
- return false;
- }
-
- if (ret < 0) {
- const int ssl_err = SSL_get_error(ssl, ret);
- if (ssl_err == SSL_ERROR_WANT_READ) {
- if (waitForReading(fd) != 1) {
- if (DBG) {
- ALOGW("SSL_read error");
- }
- return false;
- }
- continue;
- } else {
- if (DBG) {
- ALOGW("SSL_read error %d", ssl_err);
- }
- return false;
- }
- }
-
- remaining -= ret;
- }
- return true;
-}
-
-DnsTlsTransport::Response DnsTlsTransport::query(const uint8_t *query, size_t qlen,
- uint8_t *response, size_t limit, int *resplen) {
- android::base::unique_fd fd = makeConnectedSocket();
- if (fd.get() < 0) {
- ALOGD("%u makeConnectedSocket() failed with: %s", mMark, strerror(errno));
- return Response::network_error;
- }
- bssl::UniquePtr<SSL> ssl = sslConnect(fd.get());
- if (!ssl) {
- return Response::network_error;
- }
-
- Response res = sendQuery(fd.get(), ssl.get(), query, qlen);
- if (res == Response::success) {
- res = readResponse(fd.get(), ssl.get(), query, response, limit, resplen);
- }
-
- sslDisconnect(std::move(ssl), std::move(fd));
- return res;
-}
-
-DnsTlsTransport::Response DnsTlsTransport::sendQuery(int fd, SSL* ssl,
- const uint8_t *query, size_t qlen) {
- if (DBG) {
- ALOGD("sending query");
- }
- uint8_t queryHeader[2];
- queryHeader[0] = qlen >> 8;
- queryHeader[1] = qlen;
- if (!sslWrite(fd, ssl, queryHeader, 2)) {
- return Response::network_error;
- }
- if (!sslWrite(fd, ssl, query, qlen)) {
- return Response::network_error;
- }
- if (DBG) {
- ALOGD("%u SSL_write complete", mMark);
- }
- return Response::success;
-}
-
-DnsTlsTransport::Response DnsTlsTransport::readResponse(int fd, SSL* ssl,
- const uint8_t *query, uint8_t *response, size_t limit, int *resplen) {
- if (DBG) {
- ALOGD("reading response");
- }
- uint8_t responseHeader[2];
- if (!sslRead(fd, ssl, responseHeader, 2)) {
- if (DBG) {
- ALOGW("%u Failed to read 2-byte length header", mMark);
- }
- return Response::network_error;
- }
- const uint16_t responseSize = (responseHeader[0] << 8) | responseHeader[1];
- if (DBG) {
- ALOGD("%u Expecting response of size %i", mMark, responseSize);
- }
- if (responseSize > limit) {
- ALOGE("%u Response doesn't fit in output buffer: %i", mMark, responseSize);
- return Response::limit_error;
- }
- if (!sslRead(fd, ssl, response, responseSize)) {
- if (DBG) {
- ALOGW("%u Failed to read %i bytes", mMark, responseSize);
- }
- return Response::network_error;
- }
- if (DBG) {
- ALOGD("%u SSL_read complete", mMark);
- }
-
- if (response[0] != query[0] || response[1] != query[1]) {
- ALOGE("reply query ID != query ID");
- return Response::internal_error;
- }
-
- *resplen = responseSize;
- return Response::success;
-}
-
-// static
+// TODO: Use this function to preheat the session cache.
+// That may require moving it to DnsTlsDispatcher.
bool DnsTlsTransport::validate(const DnsTlsServer& server, unsigned netid) {
- if (DBG) {
- ALOGD("Beginning validation on %u", netid);
- }
+ ALOGV("Beginning validation on %u", netid);
// Generate "<random>-dnsotls-ds.metric.gstatic.com", which we will lookup through |ss| in
// order to prove that it is actually a working DNS over TLS server.
static const char kDnsSafeChars[] =
@@ -551,9 +90,6 @@
};
const int qlen = ARRAY_SIZE(query);
- const int kRecvBufSize = 4 * 1024;
- uint8_t recvbuf[kRecvBufSize];
-
// At validation time, we only know the netId, so we have to guess/compute the
// corresponding socket mark.
Fwmark fwmark;
@@ -563,22 +99,17 @@
fwmark.netId = netid;
unsigned mark = fwmark.intValue;
- int replylen = 0;
- DnsTlsTransport transport(server, mark);
- if (!transport.initialize()) {
- return false;
- }
- transport.query(query, qlen, recvbuf, kRecvBufSize, &replylen);
- if (replylen == 0) {
- if (DBG) {
- ALOGD("query failed");
- }
+ // Validation uses its own local factory and transport (and so its own session cache).
+ DnsTlsSocketFactory factory;
+ DnsTlsTransport transport(server, mark, &factory);
+ auto r = transport.query(Slice(query, qlen));
+ if (r.code != Response::success) {
+ ALOGV("query failed");
return false;
}
- if (replylen < NS_HFIXEDSZ) {
- if (DBG) {
- ALOGW("short response: %d", replylen);
- }
+ const std::vector<uint8_t>& recvbuf = r.response;
+ if (recvbuf.size() < NS_HFIXEDSZ) {
+ ALOGW("short response: %zu", recvbuf.size());
return false;
}
@@ -589,9 +120,7 @@
}
const int ancount = (recvbuf[6] << 8) | recvbuf[7];
- if (DBG) {
- ALOGD("%u answer count: %d", netid, ancount);
- }
+ ALOGV("%u answer count: %d", netid, ancount);
// TODO: Further validate the response contents (check for valid AAAA record, ...).
// Note that currently, integration tests rely on this function accepting a
@@ -604,5 +133,5 @@
return true;
}
-} // namespace net
-} // namespace android
+} // end of namespace net
+} // end of namespace android
diff --git a/server/dns/DnsTlsTransport.h b/server/dns/DnsTlsTransport.h
index 8f03c8a..4624be7 100644
--- a/server/dns/DnsTlsTransport.h
+++ b/server/dns/DnsTlsTransport.h
@@ -17,37 +17,29 @@
#ifndef _DNS_DNSTLSTRANSPORT_H
#define _DNS_DNSTLSTRANSPORT_H
-#include <deque>
-#include <mutex>
-#include <openssl/ssl.h>
-
-#include <android-base/thread_annotations.h>
-#include <android-base/unique_fd.h>
-
+#include "dns/DnsTlsSessionCache.h"
#include "dns/DnsTlsServer.h"
+#include <netdutils/Slice.h>
+
namespace android {
namespace net {
+class IDnsTlsSocketFactory;
+
class DnsTlsTransport {
public:
- DnsTlsTransport(const DnsTlsServer& server, unsigned mark)
- : mMark(mark), mServer(server)
- {}
- ~DnsTlsTransport() {}
+ DnsTlsTransport(const DnsTlsServer& server, unsigned mark,
+ IDnsTlsSocketFactory* _Nonnull factory) :
+ mMark(mark), mServer(server), mFactory(factory) {}
+ ~DnsTlsTransport();
- // Creates the SSL context for this session. Returns false on failure.
- // This method should be called after construction and before use of a DnsTlsTransport.
- bool initialize();
-
- enum class Response : uint8_t { success, network_error, limit_error, internal_error };
+ typedef DnsTlsServer::Response Response;
+ typedef DnsTlsServer::Result Result;
- // Given a |query| of length |qlen|, this method sends it to the server
- // and writes the response into |ans|, which can accept up to |anssiz| bytes.
- // The number of bytes is written to |resplen|. If |resplen| is zero, an
- // error has occurred.
- Response query(const uint8_t *query, size_t qlen,
- uint8_t *ans, size_t anssiz, int *resplen);
+ // Given a |query|, this method sends it to the server
+ // and returns the server's response synchronously.
+ Result query(const netdutils::Slice query);
// Check that a given TLS server is fully working on the specified netid, and has the
// provided SHA-256 fingerprint (if nonempty). This function is used in ResolverController
@@ -55,50 +47,14 @@
static bool validate(const DnsTlsServer& server, unsigned netid);
private:
- // Send a query on the provided SSL socket.
- Response sendQuery(int fd, SSL* ssl, const uint8_t *query, size_t qlen);
-
- // Wait for the response to |query| on |ssl|, and write it to |ans|, an output buffer
- // of size |anssiz|. If |resplen| is zero, the read failed.
- Response readResponse(int fd, SSL* ssl, const uint8_t *query,
- uint8_t *ans, size_t anssiz, int *resplen);
-
- // On success, returns a non-blocking socket connected to mAddr (the
- // connection will likely be in progress if mProtocol is IPPROTO_TCP).
- // On error, returns -1 with errno set appropriately.
- base::unique_fd makeConnectedSocket() const;
-
- // Connect an SSL session on the provided socket. If connection fails, closing the
- // socket remains the caller's responsibility.
- bssl::UniquePtr<SSL> sslConnect(int fd);
-
- // Disconnect the SSL session and close the socket.
- void sslDisconnect(bssl::UniquePtr<SSL> ssl, base::unique_fd fd);
-
- // Writes a buffer to the socket.
- bool sslWrite(int fd, SSL *ssl, const uint8_t *buffer, int len);
-
- // Reads exactly the specified number of bytes from the socket. Blocking.
- // Returns false if the socket closes before enough bytes can be read.
- bool sslRead(int fd, SSL *ssl, uint8_t *buffer, int len);
-
- // Using SSL_CTX to create new SSL objects is thread-safe, so this object does not
- // require a lock annotation.
- bssl::UniquePtr<SSL_CTX> mSslCtx;
+ DnsTlsSessionCache mCache;
const unsigned mMark; // Socket mark
const DnsTlsServer mServer;
-
- // Cache of recently seen SSL_SESSIONs. This is used to support session tickets.
- static int newSessionCallback(SSL* ssl, SSL_SESSION* session);
- void recordSession(SSL_SESSION* session);
- static void removeSessionCallback(SSL_CTX* ssl_ctx, SSL_SESSION* session);
- void removeSession(SSL_SESSION* session);
- std::mutex mLock;
- std::deque<bssl::UniquePtr<SSL_SESSION>> mSessions GUARDED_BY(mLock);
+ IDnsTlsSocketFactory* _Nonnull const mFactory;
};
-} // namespace net
-} // namespace android
+} // end of namespace net
+} // end of namespace android
#endif // _DNS_DNSTLSTRANSPORT_H
diff --git a/server/dns/IDnsTlsSocket.h b/server/dns/IDnsTlsSocket.h
new file mode 100644
index 0000000..1551418
--- /dev/null
+++ b/server/dns/IDnsTlsSocket.h
@@ -0,0 +1,46 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _DNS_IDNSTLSSOCKET_H
+#define _DNS_IDNSTLSSOCKET_H
+
+#include <cstdint>
+#include <cstddef>
+
+#include <netdutils/Slice.h>
+
+#include "dns/DnsTlsServer.h"
+
+namespace android {
+namespace net {
+
+class IDnsTlsSocketObserver;
+class DnsTlsSessionCache;
+
+// Interface to a TLS socket carrying DNS messages framed as [length][value], where the
+// length is a 2-byte prefix (the same framing DNS uses over TCP).
+class IDnsTlsSocket {
+public:
+    virtual ~IDnsTlsSocket() = default;
+    // Sends one query identified by |id|; |query| holds the message body without the
+    // 2 ID bytes. Returns the server's response.
+    virtual DnsTlsServer::Result query(uint16_t id, const netdutils::Slice query) = 0;
+};
+
+} // end of namespace net
+} // end of namespace android
+
+#endif // _DNS_IDNSTLSSOCKET_H
diff --git a/server/dns/IDnsTlsSocketFactory.h b/server/dns/IDnsTlsSocketFactory.h
new file mode 100644
index 0000000..6fc0cfd
--- /dev/null
+++ b/server/dns/IDnsTlsSocketFactory.h
@@ -0,0 +1,48 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _DNS_IDNSTLSSOCKETFACTORY_H
+#define _DNS_IDNSTLSSOCKETFACTORY_H
+
+#include <memory>
+
+#include "dns/IDnsTlsSocket.h"
+
+namespace android {
+namespace net {
+
+class DnsTlsSessionCache;
+struct DnsTlsServer;
+
+// Dependency injection interface for DnsTlsSocketFactory.
+// This pattern allows mocking of DnsTlsSocket for tests.
+class IDnsTlsSocketFactory {
+public:
+    virtual ~IDnsTlsSocketFactory() = default;
+
+    // Creates a socket for |server| using socket mark |mark|, with |cache| available for
+    // TLS session resumption. May return nullptr if no usable socket could be created,
+    // so callers must check the result.
+    virtual std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
+            const DnsTlsServer& server,
+            unsigned mark,
+            DnsTlsSessionCache* _Nonnull cache) = 0;
+};
+
+} // end of namespace net
+} // end of namespace android
+
+#endif // _DNS_IDNSTLSSOCKETFACTORY_H
diff --git a/server/dns/README.md b/server/dns/README.md
new file mode 100644
index 0000000..23185a0
--- /dev/null
+++ b/server/dns/README.md
@@ -0,0 +1,136 @@
+# DNS-over-TLS query forwarder design
+**NOTE:** This design is not yet fully implemented in this change.
+
+
+## Overview
+
+The DNS-over-TLS query forwarder consists of five classes:
+ * `DnsTlsDispatcher`
+ * `DnsTlsTransport`
+ * `DnsTlsQueryMap`
+ * `DnsTlsSessionCache`
+ * `DnsTlsSocket`
+
+`DnsTlsDispatcher` is a singleton class whose `query` method is the `dns/` directory's
+only public interface. `DnsTlsDispatcher` is just a table holding the
+`DnsTlsTransport` for each server (represented by a `DnsTlsServer` struct) and
+network. `DnsTlsDispatcher` also blocks each query thread, waiting on a
+`std::future` returned by `DnsTlsTransport` that represents the response.
+
+`DnsTlsTransport` sends each query over a `DnsTlsSocket`, opening a
+new one if necessary. It also has to listen for responses from the
+`DnsTlsSocket`, which happen on a different thread.
+`IDnsTlsSocketObserver` is an interface defining how `DnsTlsSocket` returns
+responses to `DnsTlsTransport`.
+
+`DnsTlsQueryMap` and `DnsTlsSessionCache` are helper classes owned by `DnsTlsTransport`.
+`DnsTlsQueryMap` handles ID renumbering and query-response pairing.
+`DnsTlsSessionCache` allows TLS session resumption.
+
+`DnsTlsSocket` interleaves all queries onto a single socket, and reports all
+responses to `DnsTlsTransport` (through the `IDnsTlsSocketObserver` interface). It doesn't
+know anything about which queries correspond to which responses, and does not retain
+state to indicate whether there is an outstanding query.
+
+## Threading
+
+### Overall patterns
+
+For clarity, each of the five classes in this design is thread-safe and holds one lock.
+Classes that spawn a helper thread call `std::thread::join()` in their destructor to ensure
+that it is cleaned up appropriately.
+
+All the classes here make full use of Clang thread annotations (and also null-pointer
+annotations) to minimize the likelihood of a latent threading bug. The unit tests are
+also heavily threaded to exercise this functionality.
+
+This code creates O(1) threads per socket, and does not create a new thread for each
+query or response. However, bionic's stub resolver does create a thread for each query.
+
+### Threading in `DnsTlsSocket`
+
+`DnsTlsSocket` can receive queries on any thread, and send them over a
+"reliable datagram pipe" (`socketpair()` in `SOCK_SEQPACKET` mode).
+The query method writes a struct (containing a pointer to the query) to the pipe
+from its thread, and the loop thread (which owns the SSL socket)
+reads off the other end of the pipe. The pipe doesn't actually have a queue "inside";
+instead, any queueing happens by blocking the query thread until the
+socket thread can read the datagram off the other end.
+
+We need to pass messages between threads using a pipe, and not a condition variable
+or a thread-safe queue, because the socket thread has to be blocked
+in `select` waiting for data from the server, but also has to be woken
+up on inputs from the query threads. Therefore, inputs from the query
+threads have to arrive on a socket, so that `select()` can listen for them.
+(There can only be a single socket thread because [OpenSSL does not support reading and
+writing one connection from different threads](https://www.openssl.org/blog/blog/2017/02/21/threads/).)
+
+## ID renumbering
+
+`DnsTlsDispatcher` accepts queries that have colliding ID numbers and still sends them on
+a single socket. To avoid confusion at the server, `DnsTlsQueryMap` assigns each
+query a new ID for transmission, records the mapping from input IDs to sent IDs, and
+applies the inverse mapping to responses before returning them to the caller.
+
+`DnsTlsQueryMap` assigns each new query the ID number one greater than the largest
+ID number of an outstanding query. This means that ID numbers are initially sequential
+and usually small. If the largest possible ID number is already in use,
+`DnsTlsQueryMap` will scan the ID space to find an available ID, or fail the query
+if there are no available IDs. Queries will not block waiting for an ID number to
+become available.
+
+## Time constants
+
+`DnsTlsSocket` imposes a 20-second inactivity timeout. A socket that has been idle for
+20 seconds will be closed. This sets the limit of tolerance for slow replies,
+which could happen as a result of malfunctioning authoritative DNS servers.
+If there are any pending queries, `DnsTlsTransport` will retry them.
+
+`DnsTlsQueryMap` imposes a retry limit of 3. `DnsTlsTransport` will retry the query up
+to 3 times before reporting failure to `DnsTlsDispatcher`.
+This limit helps to ensure proper functioning in the case of a recursive resolver that
+is malfunctioning or is flooded with requests that are stalled due to malfunctioning
+authoritative servers.
+
+`DnsTlsDispatcher` maintains a 5-minute timeout. Any `DnsTlsTransport` that has had no
+outstanding queries for 5 minutes will be destroyed at the next query on a different
+transport.
+This sets the limit on how long session tickets will be preserved during idle periods,
+because each `DnsTlsTransport` owns a `DnsTlsSessionCache`. Imposing this timeout
+increases latency on the first query after an idle period, but also helps to avoid
+unbounded memory usage.
+
+`DnsTlsSessionCache` sets a limit of 5 sessions in each cache, expiring the oldest one
+when the limit is reached. However, because the client code does not currently
+reuse sessions more than once, it should not be possible to hit this limit.
+
+## Testing
+
+Unit tests are in `../tests/dns_tls_test.cpp`. They cover all the classes except
+`DnsTlsSocket` (which requires `CAP_NET_ADMIN` because it uses `setsockopt(SO_MARK)`) and
+`DnsTlsSessionCache` (which requires integration with libssl). These classes are
+exercised by the integration tests in `../tests/netd_test.cpp`.
+
+### Dependency Injection
+
+For unit testing, we would like to be able to mock out `DnsTlsSocket`. This is
+particularly required for unit testing of `DnsTlsDispatcher` and `DnsTlsTransport`.
+To make these unit tests possible, this code uses a dependency injection pattern:
+`DnsTlsSocket` is produced by a `DnsTlsSocketFactory`, and each implements an abstract
+interface (`IDnsTlsSocket` and `IDnsTlsSocketFactory`, respectively).
+
+`DnsTlsDispatcher`'s constructor takes an `IDnsTlsSocketFactory`,
+which in production is a `DnsTlsSocketFactory`. However, in unit tests, we can
+substitute a test factory that returns a fake socket, so that the unit tests can
+run without actually connecting over TLS to a test server. (The integration tests
+do actual TLS.)
+
+## Logging
+
+This code uses `ALOGV` throughout for low-priority logging, and does not use
+`ALOGD`. `ALOGV` is disabled by default; enable it with `#define LOG_NDEBUG 0` placed before `#include "log/log.h"`.
+(`ALOGD` is not disabled by default, requiring extra measures to avoid spamming the
+system log in production builds.)
+
+## Reference
+ * [BoringSSL API docs](https://commondatastorage.googleapis.com/chromium-boringssl-docs/headers.html)
diff --git a/tests/Android.mk b/tests/Android.mk
index 9812bcb..5c6fde5 100644
--- a/tests/Android.mk
+++ b/tests/Android.mk
@@ -38,11 +38,18 @@
# runtest -x system/netd/tests/netd_integration_test.cpp
LOCAL_SRC_FILES := binder_test.cpp \
dns_responder/dns_responder.cpp \
+ dns_tls_test.cpp \
netd_integration_test.cpp \
netd_test.cpp \
tun_interface.cpp \
../server/NetdConstants.cpp \
- ../server/binder/android/net/metrics/INetdEventListener.aidl
+ ../server/binder/android/net/metrics/INetdEventListener.aidl \
+ ../server/dns/DnsTlsDispatcher.cpp \
+ ../server/dns/DnsTlsTransport.cpp \
+ ../server/dns/DnsTlsServer.cpp \
+ ../server/dns/DnsTlsSessionCache.cpp \
+ ../server/dns/DnsTlsSocket.cpp \
+
LOCAL_MODULE_TAGS := eng tests
include $(BUILD_NATIVE_TEST)
diff --git a/tests/dns_responder/dns_tls_frontend.cpp b/tests/dns_responder/dns_tls_frontend.cpp
index b360060..6c29353 100644
--- a/tests/dns_responder/dns_tls_frontend.cpp
+++ b/tests/dns_responder/dns_tls_frontend.cpp
@@ -317,9 +317,14 @@
}
const uint16_t qlen = (queryHeader[0] << 8) | queryHeader[1];
uint8_t query[qlen];
- if (SSL_read(ssl, &query, qlen) != qlen) {
- ALOGI("Not enough query bytes");
- return false;
+ size_t qbytes = 0;
+ while (qbytes < qlen) {
+ int ret = SSL_read(ssl, query + qbytes, qlen - qbytes);
+ if (ret <= 0) {
+ ALOGI("Error while reading query");
+ return false;
+ }
+ qbytes += ret;
}
int sent = send(backend_socket_, query, qlen, 0);
if (sent != qlen) {
diff --git a/tests/dns_tls_test.cpp b/tests/dns_tls_test.cpp
new file mode 100644
index 0000000..4cb4d50
--- /dev/null
+++ b/tests/dns_tls_test.cpp
@@ -0,0 +1,372 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#define LOG_TAG "dns_tls_test"
+
+#include <arpa/inet.h>
+
+#include <chrono>
+#include <cstring>
+
+#include <android-base/macros.h>
+#include <gtest/gtest.h>
+#include <netdutils/Slice.h>
+
+#include "dns/DnsTlsDispatcher.h"
+#include "dns/DnsTlsServer.h"
+#include "dns/DnsTlsSessionCache.h"
+#include "dns/DnsTlsSocket.h"
+#include "dns/DnsTlsTransport.h"
+#include "dns/IDnsTlsSocket.h"
+#include "dns/IDnsTlsSocketFactory.h"
+
+#include "log/log.h"
+
+namespace android {
+namespace net {
+
+using netdutils::makeSlice;
+using netdutils::Slice;
+
+// Owned byte buffer used for queries, answers and fingerprints in these tests.
+using bytevec = std::vector<uint8_t>;
+
+// Parses the numeric IPv4 or IPv6 address |server| into |*parsed|, setting the
+// address family and |port| (stored in network byte order).
+//
+// |*parsed| is zeroed before anything is written, so fields this function does
+// not set (sin_zero, sin6_flowinfo, sin6_scope_id, trailing padding) are always
+// zero instead of stale or uninitialized bytes. The address is parsed into a
+// local first, so a failed IPv4 attempt cannot leave bytes behind in the
+// storage that the IPv6 path then reuses. On parse failure |*parsed| is left
+// all-zero and an error is logged.
+static void parseServer(const char* server, in_port_t port, sockaddr_storage* parsed) {
+    memset(parsed, 0, sizeof(*parsed));
+
+    in_addr addr4;
+    if (inet_pton(AF_INET, server, &addr4) == 1) {
+        // IPv4 parse succeeded, so it's IPv4.
+        sockaddr_in* sin = reinterpret_cast<sockaddr_in*>(parsed);
+        sin->sin_family = AF_INET;
+        sin->sin_port = htons(port);
+        sin->sin_addr = addr4;
+        return;
+    }
+
+    in6_addr addr6;
+    if (inet_pton(AF_INET6, server, &addr6) == 1) {
+        // IPv6 parse succeeded, so it's IPv6.
+        sockaddr_in6* sin6 = reinterpret_cast<sockaddr_in6*>(parsed);
+        sin6->sin6_family = AF_INET6;
+        sin6->sin6_port = htons(port);
+        sin6->sin6_addr = addr6;
+        return;
+    }
+
+    ALOGE("Failed to parse server address: %s", server);
+}
+
+// Arbitrary, distinct server identity values shared by the tests below. They
+// are const so that no test can accidentally modify state other tests rely on.
+const bytevec FINGERPRINT1 = { 1 };
+const bytevec FINGERPRINT2 = { 2 };
+
+const std::string SERVERNAME1 = "dns.example.com";
+const std::string SERVERNAME2 = "dns.example.org";
+
+// Fixture base class: provides example server addresses (two IPv4, two IPv6,
+// all on port 853) and SERVER1, a DnsTlsServer for V4ADDR1 with FINGERPRINT1
+// and SERVERNAME1 filled in.
+class BaseTest : public ::testing::Test {
+protected:
+    BaseTest() {
+        parseServer("192.0.2.1", 853, &V4ADDR1);
+        parseServer("192.0.2.2", 853, &V4ADDR2);
+        parseServer("2001:db8::1", 853, &V6ADDR1);
+        parseServer("2001:db8::2", 853, &V6ADDR2);
+
+        SERVER1 = DnsTlsServer(V4ADDR1);
+        SERVER1.fingerprints.insert(FINGERPRINT1);
+        SERVER1.name = SERVERNAME1;
+    }
+
+    // Zero-initialized so that any field parseServer() does not write
+    // (e.g. sin_zero, sin6_flowinfo, sin6_scope_id) has a deterministic value
+    // instead of leftover memory contents, which would otherwise leak into
+    // address comparisons.
+    sockaddr_storage V4ADDR1 = {};
+    sockaddr_storage V4ADDR2 = {};
+    sockaddr_storage V6ADDR1 = {};
+    sockaddr_storage V6ADDR2 = {};
+
+    DnsTlsServer SERVER1;
+};
+
+// Builds a fake DNS query of |size| bytes. The first two bytes hold |id| in
+// big-endian order, and the rest is filler derived from |id|. If |size| is
+// smaller than 2, only the ID bytes that fit are written, so a short buffer
+// is never written past its end.
+bytevec make_query(uint16_t id, size_t size) {
+    bytevec vec(size);
+    if (size > 0) {
+        vec[0] = id >> 8;
+    }
+    if (size > 1) {
+        vec[1] = id;
+    }
+    // Arbitrarily fill the query body with data derived from the ID.
+    for (size_t i = 2; i < size; ++i) {
+        vec[i] = id + i;
+    }
+    return vec;
+}
+
+// Query constants shared by the transport and dispatcher tests.
+constexpr unsigned MARK = 123;  // Mark argument passed to transports/dispatchers.
+constexpr uint16_t ID = 52;     // Query ID embedded in QUERY.
+constexpr uint16_t SIZE = 22;   // Length of QUERY in bytes.
+const bytevec QUERY = make_query(ID, SIZE);
+
+// Socket factory that ignores the requested server, mark and session cache,
+// and returns a fresh default-constructed T for every call.
+template <class T>
+class FakeSocketFactory : public IDnsTlsSocketFactory {
+public:
+    FakeSocketFactory() = default;
+
+    std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
+            const DnsTlsServer& server ATTRIBUTE_UNUSED,
+            unsigned mark ATTRIBUTE_UNUSED,
+            DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
+        auto socket = std::make_unique<T>();
+        return socket;
+    }
+};
+
+// Builds a fake response consisting of the two-byte big-endian |id| followed by
+// a verbatim copy of |query|.
+bytevec make_echo(uint16_t id, const Slice query) {
+    const size_t kIdLength = 2;
+    bytevec response(kIdLength + query.size());
+    response[0] = static_cast<uint8_t>(id >> 8);
+    response[1] = static_cast<uint8_t>(id & 0xff);
+    // The query bytes are echoed back unchanged as the fake answer body.
+    memcpy(response.data() + kIdLength, query.base(), query.size());
+    return response;
+}
+
+// Simplest possible fake socket: every query succeeds immediately, and the
+// response is the query echoed back behind the given ID (see make_echo).
+class FakeSocketEcho : public IDnsTlsSocket {
+public:
+    FakeSocketEcho() = default;
+
+    DnsTlsServer::Result query(uint16_t id, const Slice query) override {
+        DnsTlsServer::Result result = {
+            .code = DnsTlsServer::Response::success,
+            .response = make_echo(id, query),
+        };
+        return result;
+    }
+};
+
+// Tests for DnsTlsTransport backed by fake sockets.
+class TransportTest : public BaseTest {};
+
+TEST_F(TransportTest, Query) {
+    FakeSocketFactory<FakeSocketEcho> factory;
+    DnsTlsTransport transport(SERVER1, MARK, &factory);
+    const auto result = transport.query(makeSlice(QUERY));
+
+    // The echo socket answers with the query itself.
+    EXPECT_EQ(DnsTlsTransport::Response::success, result.code);
+    EXPECT_EQ(QUERY, result.response);
+}
+
+TEST_F(TransportTest, SerialQueries) {
+    FakeSocketFactory<FakeSocketEcho> factory;
+    DnsTlsTransport transport(SERVER1, MARK, &factory);
+
+    // More queries than there are 16-bit values (65536), sent one after another.
+    constexpr int kNumQueries = 100000;
+    for (int n = 0; n < kNumQueries; ++n) {
+        const auto result = transport.query(makeSlice(QUERY));
+
+        EXPECT_EQ(DnsTlsTransport::Response::success, result.code);
+        EXPECT_EQ(QUERY, result.response);
+    }
+}
+
+// Factory that simulates a connection failure: returning a null socket from
+// createDnsTlsSocket() is how a failed connection is reported.
+class NullSocketFactory : public IDnsTlsSocketFactory {
+public:
+    NullSocketFactory() = default;
+
+    std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
+            const DnsTlsServer& server ATTRIBUTE_UNUSED,
+            unsigned mark ATTRIBUTE_UNUSED,
+            DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
+        return std::unique_ptr<IDnsTlsSocket>();
+    }
+};
+
+TEST_F(TransportTest, ConnectFail) {
+    NullSocketFactory factory;
+    DnsTlsTransport transport(SERVER1, MARK, &factory);
+    const auto result = transport.query(makeSlice(QUERY));
+
+    // A transport that cannot connect reports a network error and no data.
+    EXPECT_EQ(DnsTlsTransport::Response::network_error, result.code);
+    EXPECT_TRUE(result.response.empty());
+}
+
+// Tests for DnsTlsDispatcher using fake socket factories.
+class DispatcherTest : public BaseTest {};
+
+TEST_F(DispatcherTest, Query) {
+    bytevec answer(4096);
+    int answerLength = 0;
+
+    auto echoFactory = std::make_unique<FakeSocketFactory<FakeSocketEcho>>();
+    DnsTlsDispatcher dispatcher(std::move(echoFactory));
+    const auto code = dispatcher.query(SERVER1, MARK, makeSlice(QUERY),
+                                       makeSlice(answer), &answerLength);
+
+    EXPECT_EQ(DnsTlsTransport::Response::success, code);
+    EXPECT_EQ(static_cast<int>(QUERY.size()), answerLength);
+    // Trim the buffer to the reported length before comparing contents.
+    answer.resize(answerLength);
+    EXPECT_EQ(QUERY, answer);
+}
+
+TEST_F(DispatcherTest, AnswerTooLarge) {
+    // One byte too small to hold the echoed answer.
+    bytevec answer(SIZE - 1);
+    int answerLength = 0;
+
+    auto echoFactory = std::make_unique<FakeSocketFactory<FakeSocketEcho>>();
+    DnsTlsDispatcher dispatcher(std::move(echoFactory));
+    const auto code = dispatcher.query(SERVER1, MARK, makeSlice(QUERY),
+                                       makeSlice(answer), &answerLength);
+
+    EXPECT_EQ(DnsTlsTransport::Response::limit_error, code);
+}
+
+// Fake socket factory that records the (mark, server) pair of every socket it
+// creates, so tests can check which connections were requested. Recording is
+// serialized by mLock because sockets may be requested from several threads.
+// |keys| itself is read without locking, so only inspect it once every caller
+// has finished.
+template <class T>
+class TrackingFakeSocketFactory : public IDnsTlsSocketFactory {
+public:
+    TrackingFakeSocketFactory() = default;
+
+    std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
+            const DnsTlsServer& server,
+            unsigned mark,
+            DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
+        const std::lock_guard<std::mutex> lock(mLock);
+        keys.emplace(mark, server);
+        auto socket = std::make_unique<T>();
+        return socket;
+    }
+
+    std::multiset<std::pair<unsigned, DnsTlsServer>> keys;
+
+private:
+    std::mutex mLock;
+};
+
+TEST_F(DispatcherTest, Dispatching) {
+    auto factory = std::make_unique<TrackingFakeSocketFactory<FakeSocketEcho>>();
+    // Non-owning view of the factory; valid as long as |dispatcher| is alive.
+    TrackingFakeSocketFactory<FakeSocketEcho>* const tracker = factory.get();
+    DnsTlsDispatcher dispatcher(std::move(factory));
+
+    // Two socket marks times two servers: four distinct (mark, server) keys.
+    std::vector<std::pair<unsigned, DnsTlsServer>> keys;
+    keys.emplace_back(MARK, SERVER1);
+    keys.emplace_back(MARK + 1, SERVER1);
+    keys.emplace_back(MARK, V4ADDR2);
+    keys.emplace_back(MARK + 1, V4ADDR2);
+
+    // Issue one query per key, each from its own thread; all should succeed.
+    std::vector<std::thread> threads;
+    for (size_t i = 0; i < keys.size(); ++i) {
+        const std::pair<unsigned, DnsTlsServer> key = keys[i];
+        DnsTlsDispatcher* const target = &dispatcher;
+        threads.emplace_back([key, i, target]() {
+            bytevec query = make_query(static_cast<uint16_t>(i), SIZE);
+            bytevec answer(4096);
+            int answerLength = 0;
+            const auto code = target->query(key.second, key.first, makeSlice(query),
+                                            makeSlice(answer), &answerLength);
+            EXPECT_EQ(DnsTlsTransport::Response::success, code);
+            EXPECT_EQ(static_cast<int>(query.size()), answerLength);
+            answer.resize(answerLength);
+            EXPECT_EQ(query, answer);
+        });
+    }
+    for (std::thread& worker : threads) {
+        worker.join();
+    }
+
+    // Exactly one socket should have been created for each key.
+    EXPECT_EQ(keys.size(), tracker->keys.size());
+    for (const auto& key : keys) {
+        EXPECT_EQ(1U, tracker->keys.count(key));
+    }
+}
+
+// Helpers for checking DnsTlsServer's comparison logic.
+AddressComparator ADDRESS_COMPARATOR;
+
+// Returns true iff AddressComparator orders neither server before the other,
+// i.e. it treats them as the same address. Also expects that the comparator
+// never reports "less than" in both directions at once.
+bool isAddressEqual(const DnsTlsServer& s1, const DnsTlsServer& s2) {
+    const bool firstIsLess = ADDRESS_COMPARATOR(s1, s2);
+    const bool secondIsLess = ADDRESS_COMPARATOR(s2, s1);
+    EXPECT_FALSE(firstIsLess && secondIsLess);
+    return !(firstIsLess || secondIsLess);
+}
+
+// Expects s1 and s2 to be distinct servers. Each must equal itself (under both
+// operator== and AddressComparator), exactly one of s1 < s2 and s2 < s1 must
+// hold, and operator== must report them unequal in both directions.
+void checkUnequal(const DnsTlsServer& s1, const DnsTlsServer& s2) {
+    // Reflexivity.
+    EXPECT_TRUE(s1 == s1);
+    EXPECT_TRUE(s2 == s2);
+    EXPECT_TRUE(isAddressEqual(s1, s1));
+    EXPECT_TRUE(isAddressEqual(s2, s2));
+
+    // Exactly one ordering direction holds (logical XOR).
+    EXPECT_TRUE((s1 < s2) != (s2 < s1));
+    EXPECT_FALSE(s1 == s2);
+    EXPECT_FALSE(s2 == s1);
+}
+
+// Tests for DnsTlsServer's operator==, operator< and AddressComparator.
+class ServerTest : public BaseTest {};
+
+TEST_F(ServerTest, IPv4) {
+    const DnsTlsServer first(V4ADDR1);
+    const DnsTlsServer second(V4ADDR2);
+    checkUnequal(first, second);
+    EXPECT_FALSE(isAddressEqual(first, second));
+}
+
+TEST_F(ServerTest, IPv6) {
+    const DnsTlsServer first(V6ADDR1);
+    const DnsTlsServer second(V6ADDR2);
+    checkUnequal(first, second);
+    EXPECT_FALSE(isAddressEqual(first, second));
+}
+
+TEST_F(ServerTest, MixedAddressFamily) {
+    const DnsTlsServer v6(V6ADDR1);
+    const DnsTlsServer v4(V4ADDR1);
+    checkUnequal(v6, v4);
+    EXPECT_FALSE(isAddressEqual(v6, v4));
+}
+
+TEST_F(ServerTest, IPv6ScopeId) {
+    DnsTlsServer s1(V6ADDR1);
+    DnsTlsServer s2(V6ADDR1);
+    // Same address, different scope IDs: the servers must differ.
+    reinterpret_cast<sockaddr_in6*>(&s1.ss)->sin6_scope_id = 1;
+    reinterpret_cast<sockaddr_in6*>(&s2.ss)->sin6_scope_id = 2;
+    checkUnequal(s1, s2);
+    EXPECT_FALSE(isAddressEqual(s1, s2));
+}
+
+TEST_F(ServerTest, IPv6FlowInfo) {
+    DnsTlsServer s1(V6ADDR1);
+    DnsTlsServer s2(V6ADDR1);
+    // Same address, different flow info: all comparisons ignore flowinfo.
+    reinterpret_cast<sockaddr_in6*>(&s1.ss)->sin6_flowinfo = 1;
+    reinterpret_cast<sockaddr_in6*>(&s2.ss)->sin6_flowinfo = 2;
+    EXPECT_EQ(s1, s2);
+    EXPECT_TRUE(isAddressEqual(s1, s2));
+}
+
+TEST_F(ServerTest, Port) {
+    // Servers that differ only by port are distinct, but share an address.
+    DnsTlsServer v4a;
+    DnsTlsServer v4b;
+    parseServer("192.0.2.1", 853, &v4a.ss);
+    parseServer("192.0.2.1", 854, &v4b.ss);
+    checkUnequal(v4a, v4b);
+    EXPECT_TRUE(isAddressEqual(v4a, v4b));
+
+    DnsTlsServer v6a;
+    DnsTlsServer v6b;
+    parseServer("2001:db8::1", 853, &v6a.ss);
+    parseServer("2001:db8::1", 852, &v6b.ss);
+    checkUnequal(v6a, v6b);
+    EXPECT_TRUE(isAddressEqual(v6a, v6b));
+}
+
+TEST_F(ServerTest, Name) {
+    DnsTlsServer named(V4ADDR1);
+    DnsTlsServer other(V4ADDR1);
+    // A name on only one side makes the servers differ...
+    named.name = SERVERNAME1;
+    checkUnequal(named, other);
+    // ...and so do two different names, even though the address is shared.
+    other.name = SERVERNAME2;
+    checkUnequal(named, other);
+    EXPECT_TRUE(isAddressEqual(named, other));
+}
+
+TEST_F(ServerTest, Fingerprint) {
+    DnsTlsServer a(V4ADDR1);
+    DnsTlsServer b(V4ADDR1);
+
+    // {1} vs {}: different fingerprint sets, same address.
+    a.fingerprints.insert(FINGERPRINT1);
+    checkUnequal(a, b);
+    EXPECT_TRUE(isAddressEqual(a, b));
+
+    // {1} vs {2}.
+    b.fingerprints.insert(FINGERPRINT2);
+    checkUnequal(a, b);
+    EXPECT_TRUE(isAddressEqual(a, b));
+
+    // {1} vs {1, 2}.
+    b.fingerprints.insert(FINGERPRINT1);
+    checkUnequal(a, b);
+    EXPECT_TRUE(isAddressEqual(a, b));
+
+    // {1, 2} vs {1, 2}: identical sets make the servers equal.
+    a.fingerprints.insert(FINGERPRINT2);
+    EXPECT_EQ(a, b);
+    EXPECT_TRUE(isAddressEqual(a, b));
+}
+
+} // end of namespace net
+} // end of namespace android
diff --git a/tests/runtests.sh b/tests/runtests.sh
new file mode 100755
index 0000000..c1a6b70
--- /dev/null
+++ b/tests/runtests.sh
@@ -0,0 +1,94 @@
+#!/usr/bin/env bash
+
+readonly PROJECT_TOP="system/netd"
+
+# TODO:
+# - add Android.bp test targets
+# - switch away from runtest.py
+readonly ALL_TESTS="
+ server/netd_unit_test.cpp
+ tests/netd_integration_test.cpp
+"
+
+REPO_TOP=""
+DEBUG=""
+
+function logToStdErr() {
+ echo "$1" >&2
+}
+
+# Succeeds, and stores $1 in REPO_TOP, when $1 is non-empty and contains a
+# .repo directory. Fails (status 1) otherwise, leaving REPO_TOP untouched.
+testAndSetRepoTop() {
+    local candidate="$1"
+    [[ -n "$candidate" && -d "$candidate/.repo" ]] || return 1
+    REPO_TOP="$candidate"
+    return 0
+}
+
+# Finds the repo top: uses $ANDROID_BUILD_TOP if it qualifies, otherwise walks
+# up from the current directory (changing the working directory as it goes)
+# until a qualifying directory or / is reached. REPO_TOP is only set on success.
+gotoRepoTop() {
+    testAndSetRepoTop "$ANDROID_BUILD_TOP" && return
+
+    until testAndSetRepoTop "$PWD"; do
+        [[ "$PWD" == "/" ]] && break
+        cd ..
+    done
+}
+
+# Runs one test file through runtest.py (or only prints the command when
+# DEBUG=echo), framed by a banner, and returns the command's exit status.
+runOneTest() {
+    local -r testName="$1"
+    local -r cmd="$REPO_TOP/development/testrunner/runtest.py -x $PROJECT_TOP/$testName"
+    printf '###\n# %s\n#\n# %s\n###\n\n' "$testName" "$cmd"
+    # Deliberately unquoted: DEBUG may be empty and cmd must be word-split.
+    $DEBUG $cmd
+    local -r rval=$?
+    printf '\n'
+
+    # NOTE: currently runtest.py returns 0 even for failed tests.
+    return $rval
+}
+
+function main() {
+ gotoRepoTop
+ if ! testAndSetRepoTop "$REPO_TOP"; then
+ logToStdErr "Could not find useful top of repo directory"
+ return 1
+ fi
+ logToStdErr "using REPO_TOP=$REPO_TOP"
+
+ if [[ -n "$1" ]]; then
+ case "$1" in
+ "-n")
+ DEBUG=echo
+ shift
+ ;;
+ esac
+ fi
+
+ # Allow us to do things like "runtests.sh integration", etc.
+ readonly TEST_REGEX="$1"
+
+ failures=0
+ for testName in $ALL_TESTS; do
+ if [[ -z "$TEST_REGEX" || "$testName" =~ "$TEST_REGEX" ]]; then
+ runOneTest "$testName"
+ let failures+=$?
+ else
+ logToStdErr "Skipping $testName"
+ fi
+ done
+
+ echo "Number of tests failing: $failures"
+ return $failures
+}
+
+
+main "$@"
+exit $?