Let netd to use the new set_nameservers_for_net call.
Also add more test for netd's resolver.
Change-Id: I79fa6c2d754ace6a76804afccf60c4443b49bf6a
diff --git a/server/CommandListener.cpp b/server/CommandListener.cpp
index 7ecbffc..47edd86 100644
--- a/server/CommandListener.cpp
+++ b/server/CommandListener.cpp
@@ -26,6 +26,7 @@
#include <string.h>
#include <linux/if.h>
#include <resolv_netid.h>
+#include <resolv_params.h>
#define __STDC_FORMAT_MACROS 1
#include <inttypes.h>
@@ -829,12 +830,10 @@
// and making that check here.
if (!strcmp(argv[1], "setnetdns")) {
- // "resolver setnetdns <netId> <domains> <dns1> <dns2> ..."
- if (argc >= 5) {
- rc = sResolverCtrl->setDnsServers(netId, argv[3], &argv[4], argc - 4);
- } else {
+ // "resolver setnetdns <netId> <domains> <dns servers> [<params>]"
+ if (!parseAndExecuteSetNetDns(netId, argc, argv)) {
cli->sendMsg(ResponseCode::CommandSyntaxError,
- "Wrong number of arguments to resolver setnetdns", false);
+ "Wrong number of or invalid arguments to resolver setnetdns", false);
return 0;
}
} else if (!strcmp(argv[1], "clearnetdns")) { // "resolver clearnetdns <netId>"
@@ -867,6 +866,28 @@
return 0;
}
+bool CommandListener::ResolverCmd::parseAndExecuteSetNetDns(int netId, int argc,
+ const char** argv) {
+ // "resolver setnetdns <netId> <domains> <dns1> [<dns2> ...] [--params <params>]"
+ // TODO: This code has to be replaced by a Binder call ASAP
+ if (argc < 5) {
+ return false;
+ }
+ int end = argc;
+ __res_params params;
+ const __res_params* paramsPtr = nullptr;
+ if (end > 6 && !strcmp(argv[end - 2], "--params")) {
+ const char* paramsStr = argv[end - 1];
+ end -= 2;
+ if (sscanf(paramsStr, "%hu %hhu %hhu %hhu", ¶ms.sample_validity,
+ ¶ms.success_threshold, ¶ms.min_samples, ¶ms.max_samples) != 4) {
+ return false;
+ }
+ paramsPtr = ¶ms;
+ }
+ return sResolverCtrl->setDnsServers(netId, argv[3], &argv[4], end - 4, paramsPtr) == 0;
+}
+
CommandListener::BandwidthControlCmd::BandwidthControlCmd() :
NetdCommand("bandwidth") {
}
diff --git a/server/CommandListener.h b/server/CommandListener.h
index 72f4da1..6846323 100644
--- a/server/CommandListener.h
+++ b/server/CommandListener.h
@@ -126,6 +126,9 @@
ResolverCmd();
virtual ~ResolverCmd() {}
int runCommand(SocketClient *c, int argc, char ** argv);
+
+ private:
+ bool parseAndExecuteSetNetDns(int netId, int argc, const char** argv);
};
class FirewallCmd: public NetdCommand {
diff --git a/server/ResolverController.cpp b/server/ResolverController.cpp
index 639423d..16cfd53 100644
--- a/server/ResolverController.cpp
+++ b/server/ResolverController.cpp
@@ -25,21 +25,21 @@
// declarations for _resolv_set_nameservers_for_net and
// _resolv_flush_cache_for_net
#include <resolv_netid.h>
+#include <resolv_params.h>
#include "ResolverController.h"
-int ResolverController::setDnsServers(unsigned netId, const char* domains,
- const char** servers, int numservers) {
+int ResolverController::setDnsServers(unsigned netId, const char* searchDomains,
+ const char** servers, int numservers, const __res_params* params) {
if (DBG) {
ALOGD("setDnsServers netId = %u\n", netId);
}
- _resolv_set_nameservers_for_net(netId, servers, numservers, domains);
-
+ _resolv_set_nameservers_for_net(netId, servers, numservers, searchDomains, params);
return 0;
}
int ResolverController::clearDnsServers(unsigned netId) {
- _resolv_set_nameservers_for_net(netId, NULL, 0, "");
+ _resolv_set_nameservers_for_net(netId, NULL, 0, "", NULL);
if (DBG) {
ALOGD("clearDnsServers netId = %u\n", netId);
}
diff --git a/server/ResolverController.h b/server/ResolverController.h
index 39f002d..048ff3f 100644
--- a/server/ResolverController.h
+++ b/server/ResolverController.h
@@ -20,13 +20,14 @@
#include <netinet/in.h>
#include <linux/in.h>
+struct __res_params;
+
class ResolverController {
public:
ResolverController() {};
virtual ~ResolverController() {};
-
- int setDnsServers(unsigned netid, const char * domains, const char** servers,
- int numservers);
+ int setDnsServers(unsigned netId, const char* searchDomains, const char** servers,
+ int numservers, const __res_params* params);
int clearDnsServers(unsigned netid);
int flushDnsCache(unsigned netid);
// TODO: Add deleteDnsCache(unsigned netId)
diff --git a/tests/Android.mk b/tests/Android.mk
index 848292a..211411b 100644
--- a/tests/Android.mk
+++ b/tests/Android.mk
@@ -27,10 +27,10 @@
include $(CLEAR_VARS)
LOCAL_MODULE := netd_test
EXTRA_LDLIBS := -lpthread
-LOCAL_SHARED_LIBRARIES += libcutils libutils liblog libnetd_client
+LOCAL_SHARED_LIBRARIES += libbase libcutils libutils liblog libnetd_client
LOCAL_STATIC_LIBRARIES += libtestUtil
-LOCAL_C_INCLUDES += system/netd/include system/extras/tests/include
+LOCAL_C_INCLUDES += system/core/base/include system/netd/include \
+ system/extras/tests/include bionic/libc/dns/include
LOCAL_SRC_FILES := netd_test.cpp dns_responder.cpp
LOCAL_MODULE_TAGS := eng tests
include $(BUILD_NATIVE_TEST)
-
diff --git a/tests/dns_responder.cpp b/tests/dns_responder.cpp
index e7baeca..09d6379 100644
--- a/tests/dns_responder.cpp
+++ b/tests/dns_responder.cpp
@@ -21,12 +21,14 @@
#include <netdb.h>
#include <stdarg.h>
#include <stdio.h>
+#include <stdlib.h>
#include <string.h>
#include <sys/epoll.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
+#include <iostream>
#include <vector>
#include <log/log.h>
@@ -365,7 +367,7 @@
bool qr;
uint8_t opcode;
bool aa;
- bool tc;
+ bool tr;
bool rd;
std::vector<DNSQuestion> questions;
std::vector<DNSRecord> answers;
@@ -378,8 +380,8 @@
private:
struct Header {
uint16_t id;
- uint8_t rcode;
- uint8_t op;
+ uint8_t flags0;
+ uint8_t flags1;
uint16_t qdcount;
uint16_t ancount;
uint16_t nscount;
@@ -451,9 +453,13 @@
return nullptr;
}
Header& header = *reinterpret_cast<Header*>(buffer);
+ // bytes 0-1
header.id = htons(id);
- header.rcode = (rcode << 4) | ra;
- header.op = (rd << 7) | (tc << 6) | (aa << 5) | (opcode << 1) | qr;
+ // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd
+ header.flags0 = (qr << 7) | (opcode << 3) | (aa << 2) | (tr << 1) | rd;
+ // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode
+ header.flags1 = rcode;
+ // rest of header
header.qdcount = htons(questions.size());
header.ancount = htons(answers.size());
header.nscount = htons(authorities.size());
@@ -489,14 +495,18 @@
if (buffer + sizeof(Header) > buffer_end)
return 0;
const auto& header = *reinterpret_cast<const Header*>(buffer);
+ // bytes 0-1
id = ntohs(header.id);
- ra = header.rcode & 1;
- rcode = header.rcode >> 4;
- qr = header.op & 1;
- opcode = (header.op >> 1) & 0x0F;
- aa = (header.op >> 5) & 1;
- tc = (header.op >> 6) & 1;
- rd = header.op >> 7;
+ // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd
+ qr = header.flags0 >> 7;
+ opcode = (header.flags0 >> 3) & 0x0F;
+ aa = (header.flags0 >> 2) & 1;
+ tr = (header.flags0 >> 1) & 1;
+ rd = header.flags0 & 1;
+ // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode
+ ra = header.flags1 >> 7;
+ rcode = header.flags1 & 0xF;
+ // rest of header
*qdcount = ntohs(header.qdcount);
*ancount = ntohs(header.ancount);
*nscount = ntohs(header.nscount);
@@ -508,9 +518,10 @@
DNSResponder::DNSResponder(const char* listen_address,
const char* listen_service, int poll_timeout_ms,
- uint16_t error_rcode) :
+ uint16_t error_rcode, double response_probability) :
listen_address_(listen_address), listen_service_(listen_service),
poll_timeout_ms_(poll_timeout_ms), error_rcode_(error_rcode),
+ response_probability_(response_probability),
socket_(-1), epoll_fd_(-1), terminate_(false) { }
DNSResponder::~DNSResponder() {
@@ -544,6 +555,10 @@
mappings_.erase(it);
}
+void DNSResponder::setResponseProbability(double response_probability) {
+ response_probability_ = response_probability;
+}
+
bool DNSResponder::running() const {
return socket_ != -1;
}
@@ -742,6 +757,15 @@
ns_type(question.qtype)));
}
}
+
+ // Ignore requests with the preset probability.
+ auto constexpr bound = std::numeric_limits<unsigned>::max();
+ if (arc4random_uniform(bound) > bound*response_probability_) {
+ ALOGI("returning SRVFAIL in accordance with probability distribution");
+ return makeErrorResponse(&header, ns_rcode::ns_r_servfail, response,
+ response_len);
+ }
+
for (const DNSQuestion& question : header.questions) {
if (question.qclass != ns_class::ns_c_in &&
question.qclass != ns_class::ns_c_any) {
@@ -754,6 +778,7 @@
response_len);
}
}
+ header.qr = true;
char* response_cur = header.write(response, response + *response_len);
if (response_cur == nullptr) {
return false;
@@ -805,6 +830,7 @@
header->authorities.clear();
header->additionals.clear();
header->rcode = rcode;
+ header->qr = true;
char* response_cur = header->write(response, response + *response_len);
if (response_cur == nullptr) return false;
*response_len = response_cur - response;
diff --git a/tests/dns_responder.h b/tests/dns_responder.h
index 097185b..4ed4bb2 100644
--- a/tests/dns_responder.h
+++ b/tests/dns_responder.h
@@ -27,13 +27,7 @@
#include <unordered_map>
#include <vector>
-// TODO(imaipi): This doesn't belong here.
-#if defined(__clang__) && (!defined(SWIG))
-#define THREAD_ANNOTATION_ATTRIBUTE__(x) __attribute__((x))
-#else
-#define THREAD_ANNOTATION_ATTRIBUTE__(x) // no-op
-#endif
-#define GUARDED_BY(x) THREAD_ANNOTATION_ATTRIBUTE__(guarded_by(x))
+#include <android-base/thread_annotations.h>
namespace test {
@@ -49,10 +43,12 @@
class DNSResponder {
public:
DNSResponder(const char* listen_address, const char* listen_service,
- int poll_timeout_ms, uint16_t error_rcode);
+ int poll_timeout_ms, uint16_t error_rcode,
+ double response_probability);
~DNSResponder();
void addMapping(const char* name, ns_type type, const char* addr);
void removeMapping(const char* name, ns_type type);
+ void setResponseProbability(double response_probability);
bool running() const;
bool startServer();
bool stopServer();
@@ -106,6 +102,9 @@
const int poll_timeout_ms_;
// Error code to return for requests for an unknown name.
const uint16_t error_rcode_;
+ // Probability that a valid response is being sent instead of being sent
+ // instead of returning error_rcode_.
+ std::atomic<double> response_probability_;
// Mappings from (name, type) to registered response and the
// mutex protecting them.
diff --git a/tests/netd_test.cpp b/tests/netd_test.cpp
index 3816346..fc6240a 100644
--- a/tests/netd_test.cpp
+++ b/tests/netd_test.cpp
@@ -21,9 +21,14 @@
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
+#include <unistd.h>
#include <cutils/sockets.h>
+#include <android-base/stringprintf.h>
#include <private/android_filesystem_config.h>
+
+#include <thread>
+
#include "NetdClient.h"
#include <gtest/gtest.h>
@@ -32,24 +37,19 @@
#include <testUtil.h>
#include "dns_responder.h"
+#include "resolv_params.h"
+
+using android::base::StringPrintf;
+using android::base::StringAppendF;
// TODO: make this dynamic and stop depending on implementation details.
#define TEST_OEM_NETWORK "oem29"
#define TEST_NETID 30
-class ResponseCode {
-public:
- // Keep in sync with
- // frameworks/base/services/java/com/android/server/NetworkManagementService.java
- static const int CommandOkay = 200;
- static const int DnsProxyQueryResult = 222;
-
- static const int DnsProxyOperationFailed = 401;
-
- static const int CommandSyntaxError = 500;
- static const int CommandParameterError = 501;
-};
-
+// The only response code used in this test, see
+// frameworks/base/services/java/com/android/server/NetworkManagementService.java
+// for others.
+static constexpr int ResponseCodeOK = 200;
// Returns ResponseCode.
int netdCommand(const char* sockname, const char* command) {
@@ -81,15 +81,15 @@
}
-bool expectNetdResult(int code, const char* sockname, const char* format, ...) {
+bool expectNetdResult(int expected, const char* sockname, const char* format, ...) {
char command[256];
va_list args;
va_start(args, format);
vsnprintf(command, sizeof(command), format, args);
va_end(args);
int result = netdCommand(sockname, command);
- EXPECT_EQ(code, result) << command;
- return (200 <= code && code < 300);
+ EXPECT_EQ(expected, result) << command;
+ return (200 <= expected && expected < 300);
}
@@ -110,7 +110,7 @@
void SetupOemNetwork() {
netdCommand("netd", "network destroy " TEST_OEM_NETWORK);
- if (expectNetdResult(ResponseCode::CommandOkay, "netd",
+ if (expectNetdResult(ResponseCodeOK, "netd",
"network create %s", TEST_OEM_NETWORK)) {
oemNetId = TEST_NETID;
}
@@ -120,69 +120,123 @@
void TearDownOemNetwork() {
if (oemNetId != -1) {
- expectNetdResult(ResponseCode::CommandOkay, "netd",
+ expectNetdResult(ResponseCodeOK, "netd",
"network destroy %s", TEST_OEM_NETWORK);
}
}
- bool SetResolverForNetwork(const char* address) const {
- return
- expectNetdResult(ResponseCode::CommandOkay, "netd",
- "resolver setnetdns %d \"example.com\" %s", oemNetId,
- address) &&
- FlushCache();
+ bool SetResolversForNetwork(const std::vector<std::string>& searchDomains,
+ const std::vector<std::string>& servers, const std::string& params) {
+ // No use case for empty domains / servers (yet).
+ if (searchDomains.empty() || servers.empty()) return false;
+
+ std::string cmd = StringPrintf("resolver setnetdns %d \"%s", oemNetId,
+ searchDomains[0].c_str());
+ for (size_t i = 1 ; i < searchDomains.size() ; ++i) {
+ cmd += " ";
+ cmd += searchDomains[i];
+ }
+ cmd += "\" ";
+
+ cmd += servers[0];
+ for (size_t i = 1 ; i < servers.size() ; ++i) {
+ cmd += " ";
+ cmd += servers[i];
+ }
+
+ if (!params.empty()) {
+ cmd += " --params \"";
+ cmd += params;
+ cmd += "\"";
+ }
+
+ int rv = netdCommand("netd", cmd.c_str());
+ std::cout << "command: '" << cmd << "', rv = " << rv << "\n";
+ if (rv != ResponseCodeOK) {
+ return false;
+ }
+ return true;
}
bool FlushCache() const {
- return expectNetdResult(ResponseCode::CommandOkay, "netd",
- "resolver flushnet %d", oemNetId);
+ return expectNetdResult(ResponseCodeOK, "netd", "resolver flushnet %d", oemNetId);
}
- const char* ToString(const addrinfo* result) const {
- if (!result)
+ std::string ToString(const hostent* he) const {
+ if (he == nullptr) return "<null>";
+ char buffer[INET6_ADDRSTRLEN];
+ if (!inet_ntop(he->h_addrtype, he->h_addr_list[0], buffer, sizeof(buffer))) {
+ return "<invalid>";
+ }
+ return buffer;
+ }
+
+ std::string ToString(const addrinfo* ai) const {
+ if (!ai)
return "<null>";
- sockaddr_in* addr = reinterpret_cast<sockaddr_in*>(result->ai_addr);
- return inet_ntoa(addr->sin_addr);
+ for (const auto* aip = ai ; aip != nullptr ; aip = aip->ai_next) {
+ char host[NI_MAXHOST];
+ int rv = getnameinfo(aip->ai_addr, aip->ai_addrlen, host, sizeof(host), nullptr, 0,
+ NI_NUMERICHOST);
+ if (rv != 0)
+ return gai_strerror(rv);
+ return host;
+ }
+ return "<invalid>";
}
- const char* ToString(const hostent* result) const {
- in_addr addr;
- memcpy(reinterpret_cast<char*>(&addr), result->h_addr_list[0],
- sizeof(addr));
- return inet_ntoa(addr);
+ size_t GetNumQueries(const test::DNSResponder& dns, const char* name) const {
+ auto queries = dns.queries();
+ size_t found = 0;
+ for (const auto& p : queries) {
+ std::cout << "query " << p.first << "\n";
+ if (p.first == name) {
+ ++found;
+ }
+ }
+ return found;
+ }
+
+ size_t GetNumQueriesForType(const test::DNSResponder& dns, ns_type type,
+ const char* name) const {
+ auto queries = dns.queries();
+ size_t found = 0;
+ for (const auto& p : queries) {
+ std::cout << "query " << p.first << "\n";
+ if (p.second == type && p.first == name) {
+ ++found;
+ }
+ }
+ return found;
}
int pid;
int uid;
int oemNetId = -1;
+ const std::vector<std::string> mDefaultSearchDomains = { "example.com" };
+ // <sample validity in s> <success threshold in percent> <min samples> <max samples>
+ const std::string mDefaultParams = "300 25 8 8";
};
-
TEST_F(ResolverTest, GetHostByName) {
const char* listen_addr = "127.0.0.3";
const char* listen_srv = "53";
- test::DNSResponder resp(listen_addr, listen_srv, 250,
- ns_rcode::ns_r_servfail);
- resp.addMapping("hello.example.com.", ns_type::ns_t_a, "1.2.3.3");
- ASSERT_TRUE(resp.startServer());
- ASSERT_TRUE(SetResolverForNetwork(listen_addr));
+ const char* host_name = "hello.example.com.";
+ test::DNSResponder dns(listen_addr, listen_srv, 250, ns_rcode::ns_r_servfail, 1.0);
+ dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.3");
+ ASSERT_TRUE(dns.startServer());
+ std::vector<std::string> servers = { listen_addr };
+ ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
- resp.clearQueries();
+ dns.clearQueries();
const hostent* result = gethostbyname("hello");
- auto queries = resp.queries();
- size_t found = 0;
- for (const auto& p : queries) {
- if (p.second == ns_type::ns_t_a && p.first == "hello.example.com.") {
- ++found;
- }
- }
- EXPECT_EQ(1, found);
+ EXPECT_EQ(1U, GetNumQueriesForType(dns, ns_type::ns_t_a, host_name));
ASSERT_FALSE(result == nullptr);
ASSERT_EQ(4, result->h_length);
ASSERT_FALSE(result->h_addr_list[0] == nullptr);
- EXPECT_STREQ("1.2.3.3", ToString(result));
+ EXPECT_EQ("1.2.3.3", ToString(result));
EXPECT_TRUE(result->h_addr_list[1] == nullptr);
- resp.stopServer();
+ dns.stopServer();
}
TEST_F(ResolverTest, GetAddrInfo) {
@@ -190,54 +244,50 @@
const char* listen_addr = "127.0.0.4";
const char* listen_srv = "53";
- test::DNSResponder resp(listen_addr, listen_srv, 250,
- ns_rcode::ns_r_servfail);
- resp.addMapping("howdie.example.com.", ns_type::ns_t_a, "1.2.3.4");
- resp.addMapping("howdie.example.com.", ns_type::ns_t_aaaa, "::1.2.3.4");
- ASSERT_TRUE(resp.startServer());
- ASSERT_TRUE(SetResolverForNetwork(listen_addr));
+ const char* host_name = "howdie.example.com.";
+ test::DNSResponder dns(listen_addr, listen_srv, 250,
+ ns_rcode::ns_r_servfail, 1.0);
+ dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.4");
+ dns.addMapping(host_name, ns_type::ns_t_aaaa, "::1.2.3.4");
+ ASSERT_TRUE(dns.startServer());
+ std::vector<std::string> servers = { listen_addr };
+ ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
- resp.clearQueries();
+ dns.clearQueries();
EXPECT_EQ(0, getaddrinfo("howdie", nullptr, nullptr, &result));
- auto queries = resp.queries();
- size_t found = 0;
- for (const auto& p : queries) {
- if (p.first == "howdie.example.com.") {
- ++found;
- }
- }
- EXPECT_LE(1, found);
+ size_t found = GetNumQueries(dns, host_name);
+ EXPECT_LE(1U, found);
// Could be A or AAAA
std::string result_str = ToString(result);
- EXPECT_TRUE(result_str == "1.2.3.4" || result_str == "::1.2.3.4");
+ EXPECT_TRUE(result_str == "1.2.3.4" || result_str == "::1.2.3.4")
+ << ", result_str='" << result_str << "'";
if (result) freeaddrinfo(result);
result = nullptr;
// Verify that it's cached.
+ size_t old_found = found;
EXPECT_EQ(0, getaddrinfo("howdie", nullptr, nullptr, &result));
+ found = GetNumQueries(dns, host_name);
+ EXPECT_LE(1U, found);
+ EXPECT_EQ(old_found, found);
result_str = ToString(result);
- EXPECT_TRUE(result_str == "1.2.3.4" || result_str == "::1.2.3.4");
+ EXPECT_TRUE(result_str == "1.2.3.4" || result_str == "::1.2.3.4")
+ << result_str;
if (result) freeaddrinfo(result);
result = nullptr;
// Verify that cache can be flushed.
- resp.clearQueries();
+ dns.clearQueries();
ASSERT_TRUE(FlushCache());
- resp.addMapping("howdie.example.com.", ns_type::ns_t_a, "1.2.3.44");
- resp.addMapping("howdie.example.com.", ns_type::ns_t_aaaa, "::1.2.3.44");
+ dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.44");
+ dns.addMapping(host_name, ns_type::ns_t_aaaa, "::1.2.3.44");
EXPECT_EQ(0, getaddrinfo("howdie", nullptr, nullptr, &result));
- queries = resp.queries();
- found = 0;
- for (const auto& p : queries) {
- if (p.first == "howdie.example.com.") {
- ++found;
- }
- }
- EXPECT_LE(1, found);
+ EXPECT_LE(1U, GetNumQueries(dns, host_name));
// Could be A or AAAA
result_str = ToString(result);
- EXPECT_TRUE(result_str == "1.2.3.44" || result_str == "::1.2.3.44");
+ EXPECT_TRUE(result_str == "1.2.3.44" || result_str == "::1.2.3.44")
+ << ", result_str='" << result_str << "'";
if (result) freeaddrinfo(result);
}
@@ -246,24 +296,133 @@
const char* listen_addr = "127.0.0.5";
const char* listen_srv = "53";
- test::DNSResponder resp(listen_addr, listen_srv, 250,
- ns_rcode::ns_r_servfail);
- resp.addMapping("hola.example.com.", ns_type::ns_t_a, "1.2.3.5");
- ASSERT_TRUE(resp.startServer());
- ASSERT_TRUE(SetResolverForNetwork(listen_addr));
+ const char* host_name = "hola.example.com.";
+ test::DNSResponder dns(listen_addr, listen_srv, 250,
+ ns_rcode::ns_r_servfail, 1.0);
+ dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.5");
+ ASSERT_TRUE(dns.startServer());
+ std::vector<std::string> servers = { listen_addr };
+ ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
addrinfo hints;
memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_INET;
EXPECT_EQ(0, getaddrinfo("hola", nullptr, &hints, &result));
- auto queries = resp.queries();
- size_t found = 0;
- for (const auto& p : queries) {
- if (p.first == "hola.example.com.") {
- ++found;
- }
- }
- EXPECT_LE(1, found);
- EXPECT_STREQ("1.2.3.5", ToString(result));
+ EXPECT_EQ(1U, GetNumQueries(dns, host_name));
+ EXPECT_EQ("1.2.3.5", ToString(result));
if (result) freeaddrinfo(result);
}
+
+TEST_F(ResolverTest, MultidomainResolution) {
+ std::vector<std::string> searchDomains = { "example1.com", "example2.com", "example3.com" };
+ const char* listen_addr = "127.0.0.6";
+ const char* listen_srv = "53";
+ const char* host_name = "nihao.example2.com.";
+ test::DNSResponder dns(listen_addr, listen_srv, 250,
+ ns_rcode::ns_r_servfail, 1.0);
+ dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.3");
+ ASSERT_TRUE(dns.startServer());
+ std::vector<std::string> servers = { listen_addr };
+ ASSERT_TRUE(SetResolversForNetwork(searchDomains, servers, mDefaultParams));
+
+ dns.clearQueries();
+ const hostent* result = gethostbyname("nihao");
+ EXPECT_EQ(1U, GetNumQueriesForType(dns, ns_type::ns_t_a, host_name));
+ ASSERT_FALSE(result == nullptr);
+ ASSERT_EQ(4, result->h_length);
+ ASSERT_FALSE(result->h_addr_list[0] == nullptr);
+ EXPECT_EQ("1.2.3.3", ToString(result));
+ EXPECT_TRUE(result->h_addr_list[1] == nullptr);
+ dns.stopServer();
+}
+
+TEST_F(ResolverTest, GetAddrInfoV6_failing) {
+ addrinfo* result = nullptr;
+
+ const char* listen_addr0 = "127.0.0.7";
+ const char* listen_addr1 = "127.0.0.8";
+ const char* listen_srv = "53";
+ const char* host_name = "ohayou.example.com.";
+ test::DNSResponder dns0(listen_addr0, listen_srv, 250,
+ ns_rcode::ns_r_servfail, 0.0);
+ test::DNSResponder dns1(listen_addr1, listen_srv, 250,
+ ns_rcode::ns_r_servfail, 1.0);
+ dns0.addMapping(host_name, ns_type::ns_t_aaaa, "2001:db8::5");
+ dns1.addMapping(host_name, ns_type::ns_t_aaaa, "2001:db8::6");
+ ASSERT_TRUE(dns0.startServer());
+ ASSERT_TRUE(dns1.startServer());
+ std::vector<std::string> servers = { listen_addr0, listen_addr1 };
+ // <sample validity in s> <success threshold in percent> <min samples> <max samples>
+ unsigned sample_validity = 300;
+ int success_threshold = 25;
+ int sample_count = 8;
+ std::string params = StringPrintf("%u %d %d %d", sample_validity, success_threshold,
+ sample_count, sample_count);
+ ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, params));
+
+ // Repeatedly perform resolutions for non-existing domains until MAXNSSAMPLES resolutions have
+ // reached the dns0, which is set to fail. No more requests should then arrive at that server
+ // for the next sample_lifetime seconds.
+ // TODO: This approach is implementation-dependent, change once metrics reporting is available.
+ addrinfo hints;
+ memset(&hints, 0, sizeof(hints));
+ hints.ai_family = AF_INET6;
+ for (int i = 0 ; i < sample_count ; ++i) {
+ std::string domain = StringPrintf("nonexistent%d", i);
+ getaddrinfo(domain.c_str(), nullptr, &hints, &result);
+ }
+ // Due to 100% errors for all possible samples, the server should be ignored from now on and
+ // only the second one used for all following queries, until NSSAMPLE_VALIDITY is reached.
+ dns0.clearQueries();
+ dns1.clearQueries();
+ EXPECT_EQ(0, getaddrinfo("ohayou", nullptr, &hints, &result));
+ EXPECT_EQ(0, GetNumQueries(dns0, host_name));
+ EXPECT_EQ(1U, GetNumQueries(dns1, host_name));
+ if (result) freeaddrinfo(result);
+}
+
+TEST_F(ResolverTest, GetAddrInfoV6_concurrent) {
+ const char* listen_addr0 = "127.0.0.9";
+ const char* listen_addr1 = "127.0.0.10";
+ const char* listen_addr2 = "127.0.0.11";
+ const char* listen_srv = "53";
+ const char* host_name = "konbanha.example.com.";
+ test::DNSResponder dns0(listen_addr0, listen_srv, 250,
+ ns_rcode::ns_r_servfail, 1.0);
+ test::DNSResponder dns1(listen_addr1, listen_srv, 250,
+ ns_rcode::ns_r_servfail, 1.0);
+ test::DNSResponder dns2(listen_addr2, listen_srv, 250,
+ ns_rcode::ns_r_servfail, 1.0);
+ dns0.addMapping(host_name, ns_type::ns_t_aaaa, "2001:db8::5");
+ dns1.addMapping(host_name, ns_type::ns_t_aaaa, "2001:db8::6");
+ dns2.addMapping(host_name, ns_type::ns_t_aaaa, "2001:db8::7");
+ ASSERT_TRUE(dns0.startServer());
+ ASSERT_TRUE(dns1.startServer());
+ ASSERT_TRUE(dns2.startServer());
+ const std::vector<std::string> servers = { listen_addr0, listen_addr1, listen_addr2 };
+ std::vector<std::thread> threads(10);
+ for (std::thread& thread : threads) {
+ thread = std::thread([this, &servers, &dns0, &dns1, &dns2]() {
+ unsigned delay = arc4random_uniform(1*1000*1000); // <= 1s
+ usleep(delay);
+ std::vector<std::string> serverSubset;
+ for (const auto& server : servers) {
+ if (arc4random_uniform(2)) {
+ serverSubset.push_back(server);
+ }
+ }
+ if (serverSubset.empty()) serverSubset = servers;
+ ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, serverSubset,
+ mDefaultParams));
+ addrinfo hints;
+ memset(&hints, 0, sizeof(hints));
+ hints.ai_family = AF_INET6;
+ addrinfo* result = nullptr;
+ int rv = getaddrinfo("konbanha", nullptr, &hints, &result);
+ EXPECT_EQ(0, rv) << "error [" << rv << "] " << gai_strerror(rv);
+ });
+ }
+ for (std::thread& thread : threads) {
+ thread.join();
+ }
+}