Merge "DO NOT MERGE - Skip pie-platform-release (PPRL.190705.004) in master"
diff --git a/client/Android.bp b/client/Android.bp
index 3dae6f0..7b51322 100644
--- a/client/Android.bp
+++ b/client/Android.bp
@@ -28,11 +28,6 @@
"system/netd/libnetdutils/include",
],
defaults: ["netd_defaults"],
- product_variables: {
- debuggable: {
- cflags: ["-DNETD_CLIENT_DEBUGGABLE_BUILD"],
- }
- }
}
cc_test {
diff --git a/client/FwmarkClient.cpp b/client/FwmarkClient.cpp
index cc4893d..592fe31 100644
--- a/client/FwmarkClient.cpp
+++ b/client/FwmarkClient.cpp
@@ -31,21 +31,11 @@
namespace {
// Env flag to control whether FwmarkClient sends sockets to netd for marking.
-// This can only be disabled in debuggable builds and is meant for kernel testing.
+// This can only be disabled when the process is running as root and is meant for kernel testing.
inline constexpr char ANDROID_NO_USE_FWMARK_CLIENT[] = "ANDROID_NO_USE_FWMARK_CLIENT";
const sockaddr_un FWMARK_SERVER_PATH = {AF_UNIX, "/dev/socket/fwmarkd"};
-#if defined(NETD_CLIENT_DEBUGGABLE_BUILD)
-constexpr bool isBuildDebuggable = true;
-#else
-constexpr bool isBuildDebuggable = false;
-#endif
-
-bool isOverriddenBy(const char *name) {
- return isBuildDebuggable && getenv(name);
-}
-
bool commandHasFd(int cmdId) {
return (cmdId != FwmarkCommand::QUERY_USER_ACCESS) &&
(cmdId != FwmarkCommand::SET_COUNTERSET) &&
@@ -55,13 +45,20 @@
} // namespace
bool FwmarkClient::shouldSetFwmark(int family) {
- if (isOverriddenBy(ANDROID_NO_USE_FWMARK_CLIENT)) return false;
- return FwmarkCommand::isSupportedFamily(family);
-}
+ // Check whether the family is supported before checking whether marking can be
+ // disabled, because some existing processes use AF_LOCAL sockets but don't have
+ // permission to call geteuid(). Reference b/135422468.
+ if (!FwmarkCommand::isSupportedFamily(family)) {
+ return false;
+ }
-bool FwmarkClient::shouldReportConnectComplete(int family) {
- if (isOverriddenBy(ANDROID_NO_USE_FWMARK_CLIENT)) return false;
- return shouldSetFwmark(family);
+ // Permit processes running as root to disable marking. This is required, for
+ // example, to run the kernel networking tests.
+ if (getenv(ANDROID_NO_USE_FWMARK_CLIENT) && geteuid() == 0) {
+ return false;
+ }
+
+ return true;
}
FwmarkClient::FwmarkClient() : mChannel(-1) {
diff --git a/client/FwmarkClient.h b/client/FwmarkClient.h
index 31fcbc4..c51688f 100644
--- a/client/FwmarkClient.h
+++ b/client/FwmarkClient.h
@@ -28,10 +28,6 @@
// its SO_MARK set.
static bool shouldSetFwmark(int family);
- // Returns true if an additional call should be made after ON_CONNECT calls, to log extra
- // information like latency and source IP.
- static bool shouldReportConnectComplete(int family);
-
FwmarkClient();
~FwmarkClient();
diff --git a/client/NetdClient.cpp b/client/NetdClient.cpp
index f6fa886..d5945d0 100644
--- a/client/NetdClient.cpp
+++ b/client/NetdClient.cpp
@@ -132,7 +132,7 @@
const int connectErrno = errno;
const auto latencyMs = static_cast<unsigned>(s.timeTakenUs() / 1000);
// Send an ON_CONNECT_COMPLETE command that includes sockaddr and connect latency for reporting
- if (shouldSetFwmark && FwmarkClient::shouldReportConnectComplete(addr->sa_family)) {
+ if (shouldSetFwmark) {
FwmarkConnectInfo connectInfo(ret == 0 ? 0 : connectErrno, latencyMs, addr);
// TODO: get the netId from the socket mark once we have continuous benchmark runs
FwmarkCommand command = {FwmarkCommand::ON_CONNECT_COMPLETE, /* netId (ignored) */ 0,
diff --git a/resolv/ResolverController.cpp b/resolv/ResolverController.cpp
index 7347a44..2e330d5 100644
--- a/resolv/ResolverController.cpp
+++ b/resolv/ResolverController.cpp
@@ -221,14 +221,6 @@
return err;
}
- // Convert network-assigned server list to bionic's format.
- const size_t serverCount = std::min<size_t>(MAXNS, resolverParams.servers.size());
- std::vector<const char*> server_ptrs;
- for (size_t i = 0; i < serverCount; ++i) {
- server_ptrs.push_back(resolverParams.servers[i].c_str());
- }
-
- // TODO: Change resolv_set_nameservers() to use ResolverParamsParcel directly.
res_params res_params = {};
res_params.sample_validity = resolverParams.sampleValiditySeconds;
res_params.success_threshold = resolverParams.successThreshold;
@@ -237,11 +229,8 @@
res_params.base_timeout_msec = resolverParams.baseTimeoutMsec;
res_params.retry_count = resolverParams.retryCount;
- LOG(VERBOSE) << "setDnsServers netId = " << resolverParams.netId
- << ", numservers = " << resolverParams.servers.size();
-
- return -resolv_set_nameservers(resolverParams.netId, server_ptrs.data(), server_ptrs.size(),
- resolverParams.domains, &res_params);
+ return -resolv_set_nameservers(resolverParams.netId, resolverParams.servers,
+ resolverParams.domains, res_params);
}
int ResolverController::getResolverInfo(int32_t netId, std::vector<std::string>* servers,
diff --git a/resolv/ResolverController.h b/resolv/ResolverController.h
index 6d08cdb..4649225 100644
--- a/resolv/ResolverController.h
+++ b/resolv/ResolverController.h
@@ -31,8 +31,6 @@
namespace android {
namespace net {
-struct ResolverStats;
-
class ResolverController {
public:
ResolverController();
diff --git a/resolv/dns_responder/dns_responder.h b/resolv/dns_responder/dns_responder.h
index bf147a3..006dbcd 100644
--- a/resolv/dns_responder/dns_responder.h
+++ b/resolv/dns_responder/dns_responder.h
@@ -119,6 +119,10 @@
DNSResponder(std::string listen_address = kDefaultListenAddr,
std::string listen_service = kDefaultListenService,
ns_rcode error_rcode = ns_rcode::ns_r_servfail);
+
+ DNSResponder(ns_rcode error_rcode)
+ : DNSResponder(kDefaultListenAddr, kDefaultListenService, error_rcode){};
+
~DNSResponder();
enum class Edns : uint8_t {
diff --git a/resolv/getaddrinfo.cpp b/resolv/getaddrinfo.cpp
index 9f43ec8..2cb2034 100644
--- a/resolv/getaddrinfo.cpp
+++ b/resolv/getaddrinfo.cpp
@@ -256,7 +256,7 @@
}
// Internal version of getaddrinfo(), but limited to AI_NUMERICHOST.
-// NOTE: also called by resolv_set_nameservers_for_net().
+// NOTE: also called by resolv_set_nameservers().
int getaddrinfo_numeric(const char* hostname, const char* servname, addrinfo hints,
addrinfo** result) {
hints.ai_flags = AI_NUMERICHOST;
diff --git a/resolv/libnetd_resolv_test.cpp b/resolv/libnetd_resolv_test.cpp
index d6fd80d..d0e6ecf 100644
--- a/resolv/libnetd_resolv_test.cpp
+++ b/resolv/libnetd_resolv_test.cpp
@@ -90,15 +90,20 @@
return found;
}
- const char* mDefaultSearchDomains = "example.com";
- const res_params mDefaultParams_Binder = {
- .sample_validity = 300,
- .success_threshold = 25,
- .min_samples = 8,
- .max_samples = 8,
- .base_timeout_msec = 1000,
- .retry_count = 2,
- };
+ int setResolvers() {
+ const std::vector<std::string> servers = {test::kDefaultListenAddr};
+ const std::vector<std::string> domains = {"example.com"};
+ const res_params params = {
+ .sample_validity = 300,
+ .success_threshold = 25,
+ .min_samples = 8,
+ .max_samples = 8,
+ .base_timeout_msec = 1000,
+ .retry_count = 2,
+ };
+ return resolv_set_nameservers(TEST_NETID, servers, domains, params);
+ }
+
const android_net_context mNetcontext = {
.app_netid = TEST_NETID,
.app_mark = MARK_UNSET,
@@ -361,16 +366,11 @@
}
TEST_F(ResolvGetAddrInfoTest, AlphabeticalHostname_NoData) {
- constexpr char listen_addr[] = "127.0.0.3";
- constexpr char listen_srv[] = "53";
constexpr char v4_host_name[] = "v4only.example.com.";
- test::DNSResponder dns(listen_addr, listen_srv, ns_rcode::ns_r_servfail);
+ test::DNSResponder dns;
dns.addMapping(v4_host_name, ns_type::ns_t_a, "1.2.3.3");
ASSERT_TRUE(dns.startServer());
- const char* servers[] = {listen_addr};
- ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers),
- mDefaultSearchDomains, &mDefaultParams_Binder));
- dns.clearQueries();
+ ASSERT_EQ(0, setResolvers());
// Want AAAA answer but DNS server has A answer only.
addrinfo* result = nullptr;
@@ -384,19 +384,15 @@
}
TEST_F(ResolvGetAddrInfoTest, AlphabeticalHostname) {
- constexpr char listen_addr[] = "127.0.0.3";
- constexpr char listen_srv[] = "53";
constexpr char host_name[] = "sawadee.example.com.";
constexpr char v4addr[] = "1.2.3.4";
constexpr char v6addr[] = "::1.2.3.4";
- test::DNSResponder dns(listen_addr, listen_srv, ns_rcode::ns_r_servfail);
+ test::DNSResponder dns;
dns.addMapping(host_name, ns_type::ns_t_a, v4addr);
dns.addMapping(host_name, ns_type::ns_t_aaaa, v6addr);
ASSERT_TRUE(dns.startServer());
- const char* servers[] = {listen_addr};
- ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers),
- mDefaultSearchDomains, &mDefaultParams_Binder));
+ ASSERT_EQ(0, setResolvers());
static const struct TestConfig {
int ai_family;
@@ -423,15 +419,9 @@
}
TEST_F(ResolvGetAddrInfoTest, IllegalHostname) {
- constexpr char listen_addr[] = "127.0.0.3";
- constexpr char listen_srv[] = "53";
-
- test::DNSResponder dns(listen_addr, listen_srv, ns_rcode::ns_r_servfail);
+ test::DNSResponder dns;
ASSERT_TRUE(dns.startServer());
-
- const char* servers[] = {listen_addr};
- ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers),
- mDefaultSearchDomains, &mDefaultParams_Binder));
+ ASSERT_EQ(0, setResolvers());
// Illegal hostname is verified by res_hnok() in system/netd/resolv/res_comp.cpp.
static constexpr char const* illegalHostnames[] = {
@@ -469,8 +459,6 @@
}
TEST_F(ResolvGetAddrInfoTest, ServerResponseError) {
- constexpr char listen_addr[] = "127.0.0.3";
- constexpr char listen_srv[] = "53";
constexpr char host_name[] = "hello.example.com.";
static const struct TestConfig {
@@ -492,41 +480,33 @@
for (const auto& config : testConfigs) {
SCOPED_TRACE(StringPrintf("rcode: %d", config.rcode));
- test::DNSResponder dns(listen_addr, listen_srv, config.rcode /*response specific rcode*/);
+ test::DNSResponder dns(config.rcode);
dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.4");
dns.setResponseProbability(0.0); // always ignore requests and response preset rcode
ASSERT_TRUE(dns.startServer());
- const char* servers[] = {listen_addr};
- ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers),
- mDefaultSearchDomains, &mDefaultParams_Binder));
+ ASSERT_EQ(0, setResolvers());
addrinfo* result = nullptr;
const addrinfo hints = {.ai_family = AF_UNSPEC};
NetworkDnsEventReported event;
int rv = resolv_getaddrinfo(host_name, nullptr, &hints, &mNetcontext, &result, &event);
- ScopedAddrinfo result_cleanup(result);
EXPECT_EQ(config.expected_eai_error, rv);
}
}
// TODO: Add private DNS server timeout test.
TEST_F(ResolvGetAddrInfoTest, ServerTimeout) {
- constexpr char listen_addr[] = "127.0.0.3";
- constexpr char listen_srv[] = "53";
constexpr char host_name[] = "hello.example.com.";
- test::DNSResponder dns(listen_addr, listen_srv, static_cast<ns_rcode>(-1) /*no response*/);
+ test::DNSResponder dns(static_cast<ns_rcode>(-1) /*no response*/);
dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.4");
dns.setResponseProbability(0.0); // always ignore requests and don't response
ASSERT_TRUE(dns.startServer());
- const char* servers[] = {listen_addr};
- ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers),
- mDefaultSearchDomains, &mDefaultParams_Binder));
+ ASSERT_EQ(0, setResolvers());
addrinfo* result = nullptr;
const addrinfo hints = {.ai_family = AF_UNSPEC};
NetworkDnsEventReported event;
int rv = resolv_getaddrinfo("hello", nullptr, &hints, &mNetcontext, &result, &event);
- ScopedAddrinfo result_cleanup(result);
EXPECT_EQ(NETD_RESOLV_TIMEOUT, rv);
}
@@ -534,17 +514,11 @@
constexpr char ACNAME[] = "acname"; // expect a cname in answer
constexpr char CNAMES[] = "cnames"; // expect cname chain in answer
- constexpr char listen_addr[] = "127.0.0.3";
- constexpr char listen_srv[] = "53";
-
- test::DNSResponder dns(listen_addr, listen_srv, ns_rcode::ns_r_servfail);
+ test::DNSResponder dns;
dns.addMapping("cnames.example.com.", ns_type::ns_t_cname, "acname.example.com.");
dns.addMapping("acname.example.com.", ns_type::ns_t_cname, "hello.example.com.");
ASSERT_TRUE(dns.startServer());
-
- const char* servers[] = {listen_addr};
- ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers),
- mDefaultSearchDomains, &mDefaultParams_Binder));
+ ASSERT_EQ(0, setResolvers());
static const struct TestConfig {
const char* name;
@@ -575,15 +549,9 @@
}
TEST_F(ResolvGetAddrInfoTest, CnamesBrokenChainByIllegalCname) {
- constexpr char listen_addr[] = "127.0.0.3";
- constexpr char listen_srv[] = "53";
-
- test::DNSResponder dns(listen_addr, listen_srv, ns_rcode::ns_r_servfail);
+ test::DNSResponder dns;
ASSERT_TRUE(dns.startServer());
-
- const char* servers[] = {listen_addr};
- ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers),
- mDefaultSearchDomains, &mDefaultParams_Binder));
+ ASSERT_EQ(0, setResolvers());
static const struct TestConfig {
const char* name;
@@ -633,17 +601,11 @@
}
TEST_F(ResolvGetAddrInfoTest, CnamesInfiniteLoop) {
- constexpr char listen_addr[] = "127.0.0.3";
- constexpr char listen_srv[] = "53";
-
- test::DNSResponder dns(listen_addr, listen_srv, ns_rcode::ns_r_servfail);
+ test::DNSResponder dns;
dns.addMapping("hello.example.com.", ns_type::ns_t_cname, "a.example.com.");
dns.addMapping("a.example.com.", ns_type::ns_t_cname, "hello.example.com.");
ASSERT_TRUE(dns.startServer());
-
- const char* servers[] = {listen_addr};
- ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers),
- mDefaultSearchDomains, &mDefaultParams_Binder));
+ ASSERT_EQ(0, setResolvers());
for (const auto& family : {AF_INET, AF_INET6, AF_UNSPEC}) {
SCOPED_TRACE(StringPrintf("family: %d", family));
@@ -659,19 +621,15 @@
}
TEST_F(GetHostByNameForNetContextTest, AlphabeticalHostname) {
- constexpr char listen_addr[] = "127.0.0.3";
- constexpr char listen_srv[] = "53";
constexpr char host_name[] = "jiababuei.example.com.";
constexpr char v4addr[] = "1.2.3.4";
constexpr char v6addr[] = "::1.2.3.4";
- test::DNSResponder dns(listen_addr, listen_srv, ns_rcode::ns_r_servfail);
+ test::DNSResponder dns;
dns.addMapping(host_name, ns_type::ns_t_a, v4addr);
dns.addMapping(host_name, ns_type::ns_t_aaaa, v6addr);
ASSERT_TRUE(dns.startServer());
- const char* servers[] = {listen_addr};
- ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers),
- mDefaultSearchDomains, &mDefaultParams_Binder));
+ ASSERT_EQ(0, setResolvers());
static const struct TestConfig {
int ai_family;
@@ -697,15 +655,9 @@
}
TEST_F(GetHostByNameForNetContextTest, IllegalHostname) {
- constexpr char listen_addr[] = "127.0.0.3";
- constexpr char listen_srv[] = "53";
-
- test::DNSResponder dns(listen_addr, listen_srv, ns_rcode::ns_r_servfail);
+ test::DNSResponder dns;
ASSERT_TRUE(dns.startServer());
-
- const char* servers[] = {listen_addr};
- ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers),
- mDefaultSearchDomains, &mDefaultParams_Binder));
+ ASSERT_EQ(0, setResolvers());
// Illegal hostname is verified by res_hnok() in system/netd/resolv/res_comp.cpp.
static constexpr char const* illegalHostnames[] = {
@@ -742,15 +694,12 @@
}
TEST_F(GetHostByNameForNetContextTest, NoData) {
- constexpr char listen_addr[] = "127.0.0.3";
- constexpr char listen_srv[] = "53";
constexpr char v4_host_name[] = "v4only.example.com.";
- test::DNSResponder dns(listen_addr, listen_srv, ns_rcode::ns_r_servfail);
+
+ test::DNSResponder dns;
dns.addMapping(v4_host_name, ns_type::ns_t_a, "1.2.3.3");
ASSERT_TRUE(dns.startServer());
- const char* servers[] = {listen_addr};
- ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers),
- mDefaultSearchDomains, &mDefaultParams_Binder));
+ ASSERT_EQ(0, setResolvers());
dns.clearQueries();
// Want AAAA answer but DNS server has A answer only.
@@ -763,8 +712,6 @@
}
TEST_F(GetHostByNameForNetContextTest, ServerResponseError) {
- constexpr char listen_addr[] = "127.0.0.3";
- constexpr char listen_srv[] = "53";
constexpr char host_name[] = "hello.example.com.";
static const struct TestConfig {
@@ -788,13 +735,11 @@
for (const auto& config : testConfigs) {
SCOPED_TRACE(StringPrintf("rcode: %d", config.rcode));
- test::DNSResponder dns(listen_addr, listen_srv, config.rcode /*response specific rcode*/);
+ test::DNSResponder dns(config.rcode);
dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.4");
dns.setResponseProbability(0.0); // always ignore requests and response preset rcode
ASSERT_TRUE(dns.startServer());
- const char* servers[] = {listen_addr};
- ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers),
- mDefaultSearchDomains, &mDefaultParams_Binder));
+ ASSERT_EQ(0, setResolvers());
hostent* hp = nullptr;
NetworkDnsEventReported event;
@@ -806,16 +751,12 @@
// TODO: Add private DNS server timeout test.
TEST_F(GetHostByNameForNetContextTest, ServerTimeout) {
- constexpr char listen_addr[] = "127.0.0.3";
- constexpr char listen_srv[] = "53";
constexpr char host_name[] = "hello.example.com.";
- test::DNSResponder dns(listen_addr, listen_srv, static_cast<ns_rcode>(-1) /*no response*/);
+ test::DNSResponder dns(static_cast<ns_rcode>(-1) /*no response*/);
dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.4");
dns.setResponseProbability(0.0); // always ignore requests and don't response
ASSERT_TRUE(dns.startServer());
- const char* servers[] = {listen_addr};
- ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers),
- mDefaultSearchDomains, &mDefaultParams_Binder));
+ ASSERT_EQ(0, setResolvers());
hostent* hp = nullptr;
NetworkDnsEventReported event;
@@ -827,17 +768,11 @@
constexpr char ACNAME[] = "acname"; // expect a cname in answer
constexpr char CNAMES[] = "cnames"; // expect cname chain in answer
- constexpr char listen_addr[] = "127.0.0.3";
- constexpr char listen_srv[] = "53";
-
- test::DNSResponder dns(listen_addr, listen_srv, ns_rcode::ns_r_servfail);
+ test::DNSResponder dns;
dns.addMapping("cnames.example.com.", ns_type::ns_t_cname, "acname.example.com.");
dns.addMapping("acname.example.com.", ns_type::ns_t_cname, "hello.example.com.");
ASSERT_TRUE(dns.startServer());
-
- const char* servers[] = {listen_addr};
- ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers),
- mDefaultSearchDomains, &mDefaultParams_Binder));
+ ASSERT_EQ(0, setResolvers());
static const struct TestConfig {
const char* name;
@@ -863,15 +798,9 @@
}
TEST_F(GetHostByNameForNetContextTest, CnamesBrokenChainByIllegalCname) {
- constexpr char listen_addr[] = "127.0.0.3";
- constexpr char listen_srv[] = "53";
-
- test::DNSResponder dns(listen_addr, listen_srv, ns_rcode::ns_r_servfail);
+ test::DNSResponder dns;
ASSERT_TRUE(dns.startServer());
-
- const char* servers[] = {listen_addr};
- ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers),
- mDefaultSearchDomains, &mDefaultParams_Binder));
+ ASSERT_EQ(0, setResolvers());
static const struct TestConfig {
const char* name;
@@ -920,17 +849,11 @@
}
TEST_F(GetHostByNameForNetContextTest, CnamesInfiniteLoop) {
- constexpr char listen_addr[] = "127.0.0.3";
- constexpr char listen_srv[] = "53";
-
- test::DNSResponder dns(listen_addr, listen_srv, ns_rcode::ns_r_servfail);
+ test::DNSResponder dns;
dns.addMapping("hello.example.com.", ns_type::ns_t_cname, "a.example.com.");
dns.addMapping("a.example.com.", ns_type::ns_t_cname, "hello.example.com.");
ASSERT_TRUE(dns.startServer());
-
- const char* servers[] = {listen_addr};
- ASSERT_EQ(0, resolv_set_nameservers(TEST_NETID, servers, std::size(servers),
- mDefaultSearchDomains, &mDefaultParams_Binder));
+ ASSERT_EQ(0, setResolvers());
for (const auto& family : {AF_INET, AF_INET6}) {
SCOPED_TRACE(StringPrintf("family: %d", family));
diff --git a/resolv/res_cache.cpp b/resolv/res_cache.cpp
index befcbed..c26ea75 100644
--- a/resolv/res_cache.cpp
+++ b/resolv/res_cache.cpp
@@ -38,6 +38,7 @@
#include <algorithm>
#include <mutex>
#include <set>
+#include <string>
#include <vector>
#include <arpa/inet.h>
@@ -542,9 +543,8 @@
* - there is no point for a query packet sent to a server
* to have the TC bit set, but the implementation might
* set the bit in the query buffer for its own needs
- * between a resolv_cache_lookup and a
- * _resolv_cache_add. We should not freak out if this
- * is the case.
+ * between a resolv_cache_lookup and a resolv_cache_add.
+ * We should not freak out if this is the case.
*
* - we consider that the result from a query might depend on
* the RD, AD, and CD bits, so these bits
@@ -1141,8 +1141,8 @@
Cache* cache;
struct resolv_cache_info* next;
int nscount;
- char* nameservers[MAXNS];
- struct addrinfo* nsaddrinfo[MAXNS];
+ std::vector<std::string> nameservers;
+ struct addrinfo* nsaddrinfo[MAXNS]; // TODO: Use struct sockaddr_storage.
int revision_id; // # times the nameservers have been replaced
res_params params;
struct res_stats nsstats[MAXNS];
@@ -1317,7 +1317,7 @@
* So, the caller must check '*result' to check for success/failure.
*
* The main idea is that the result can later be used directly in
- * calls to _resolv_cache_add or _resolv_cache_remove as the 'lookup'
+ * calls to resolv_cache_add or _resolv_cache_remove as the 'lookup'
* parameter. This makes the code simpler and avoids re-searching
* for the key position in the htable.
*
@@ -1593,10 +1593,9 @@
static resolv_cache_info* create_cache_info();
// empty the nameservers set for the named cache
static void free_nameservers_locked(resolv_cache_info* cache_info);
-// return 1 if the provided list of name servers differs from the list of name servers
-// currently attached to the provided cache_info
-static int resolv_is_nameservers_equal_locked(resolv_cache_info* cache_info, const char** servers,
- int numservers);
+// Order-insensitive comparison of two sets of servers.
+static bool resolv_is_nameservers_equal(const std::vector<std::string>& oldServers,
+ const std::vector<std::string>& newServers);
// clears the stats samples contained withing the given cache_info
static void res_cache_clear_stats_locked(resolv_cache_info* cache_info);
@@ -1697,15 +1696,6 @@
return cache_info;
}
-static void resolv_set_default_params(res_params* params) {
- params->sample_validity = NSSAMPLE_VALIDITY;
- params->success_threshold = SUCCESS_THRESHOLD;
- params->min_samples = 0;
- params->max_samples = 0;
- params->base_timeout_msec = 0; // 0 = legacy algorithm
- params->retry_count = 0;
-}
-
static void resolv_set_experiment_params(res_params* params) {
using android::base::ParseInt;
using server_configurable_flags::GetServerConfigurableFlag;
@@ -1722,12 +1712,6 @@
}
}
-int resolv_set_nameservers(unsigned netid, const char** servers, int numservers,
- const char* domains, const res_params* params) {
- return resolv_set_nameservers(netid, servers, numservers, android::base::Split(domains, " "),
- params);
-}
-
namespace {
// Returns valid domains without duplicates which are limited to max size |MAXDNSRCH|.
@@ -1747,56 +1731,59 @@
return res;
}
+std::vector<std::string> filter_nameservers(const std::vector<std::string>& servers) {
+ std::vector<std::string> res = servers;
+ if (res.size() > MAXNS) {
+ LOG(WARNING) << __func__ << ": too many servers: " << res.size();
+ res.resize(MAXNS);
+ }
+ return res;
+}
+
} // namespace
-int resolv_set_nameservers(unsigned netid, const char** servers, int numservers,
- const std::vector<std::string>& domains, const res_params* params) {
- if (numservers > MAXNS) {
- LOG(ERROR) << __func__ << ": numservers=" << numservers << ", MAXNS=" << MAXNS;
- return E2BIG;
- }
+int resolv_set_nameservers(unsigned netid, const std::vector<std::string>& servers,
+ const std::vector<std::string>& domains, const res_params& params) {
+ std::vector<std::string> nameservers = filter_nameservers(servers);
+ const int numservers = static_cast<int>(nameservers.size());
+
+ LOG(INFO) << __func__ << ": netId = " << netid << ", numservers = " << numservers;
// Parse the addresses before actually locking or changing any state, in case there is an error.
// As a side effect this also reduces the time the lock is kept.
// TODO: find a better way to replace addrinfo*, something like std::vector<SafeAddrinfo>
addrinfo* nsaddrinfo[MAXNS];
- char sbuf[NI_MAXSERV];
- snprintf(sbuf, sizeof(sbuf), "%u", NAMESERVER_PORT);
for (int i = 0; i < numservers; i++) {
// The addrinfo structures allocated here are freed in free_nameservers_locked().
const addrinfo hints = {
.ai_family = AF_UNSPEC, .ai_socktype = SOCK_DGRAM, .ai_flags = AI_NUMERICHOST};
- int rt = getaddrinfo_numeric(servers[i], sbuf, hints, &nsaddrinfo[i]);
+ const int rt = getaddrinfo_numeric(nameservers[i].c_str(), "53", hints, &nsaddrinfo[i]);
if (rt != 0) {
for (int j = 0; j < i; j++) {
freeaddrinfo(nsaddrinfo[j]);
}
- LOG(INFO) << __func__ << ": getaddrinfo_numeric(" << servers[i]
+ LOG(INFO) << __func__ << ": getaddrinfo_numeric(" << nameservers[i]
<< ") = " << gai_strerror(rt);
- return EINVAL;
+ return -EINVAL;
}
}
std::lock_guard guard(cache_mutex);
-
resolv_cache_info* cache_info = find_cache_info_locked(netid);
- if (cache_info == NULL) return ENONET;
+ if (cache_info == nullptr) return -ENONET;
uint8_t old_max_samples = cache_info->params.max_samples;
- if (params != NULL) {
- cache_info->params = *params;
- } else {
- resolv_set_default_params(&cache_info->params);
- }
+ cache_info->params = params;
resolv_set_experiment_params(&cache_info->params);
- if (!resolv_is_nameservers_equal_locked(cache_info, servers, numservers)) {
+ if (!resolv_is_nameservers_equal(cache_info->nameservers, nameservers)) {
// free current before adding new
free_nameservers_locked(cache_info);
+ cache_info->nameservers = std::move(nameservers);
for (int i = 0; i < numservers; i++) {
cache_info->nsaddrinfo[i] = nsaddrinfo[i];
- cache_info->nameservers[i] = strdup(servers[i]);
- LOG(INFO) << __func__ << ": netid = " << netid << ", addr = " << servers[i];
+ LOG(INFO) << __func__ << ": netid = " << netid
+ << ", addr = " << cache_info->nameservers[i];
}
cache_info->nscount = numservers;
@@ -1830,37 +1817,23 @@
return 0;
}
-static int resolv_is_nameservers_equal_locked(resolv_cache_info* cache_info, const char** servers,
- int numservers) {
- if (cache_info->nscount != numservers) {
- return 0;
- }
+static bool resolv_is_nameservers_equal(const std::vector<std::string>& oldServers,
+ const std::vector<std::string>& newServers) {
+ const std::set<std::string> olds(oldServers.begin(), oldServers.end());
+ const std::set<std::string> news(newServers.begin(), newServers.end());
- // Compare each name server against current name servers.
// TODO: this is incorrect if the list of current or previous nameservers
// contains duplicates. This does not really matter because the framework
// filters out duplicates, but we should probably fix it. It's also
// insensitive to the order of the nameservers; we should probably fix that
// too.
- for (int i = 0; i < numservers; i++) {
- for (int j = 0;; j++) {
- if (j >= numservers) {
- return 0;
- }
- if (strcmp(cache_info->nameservers[i], servers[j]) == 0) {
- break;
- }
- }
- }
-
- return 1;
+ return olds == news;
}
static void free_nameservers_locked(resolv_cache_info* cache_info) {
int i;
for (i = 0; i < cache_info->nscount; i++) {
- free(cache_info->nameservers[i]);
- cache_info->nameservers[i] = NULL;
+ cache_info->nameservers.clear();
if (cache_info->nsaddrinfo[i] != NULL) {
freeaddrinfo(cache_info->nsaddrinfo[i]);
cache_info->nsaddrinfo[i] = NULL;
@@ -2019,27 +1992,26 @@
return find_named_cache_locked(netid) != nullptr;
}
-int resolv_cache_get_expiration(unsigned netid, const std::vector<char> query, time_t* expiration) {
+int resolv_cache_get_expiration(unsigned netid, const std::vector<char>& query,
+ time_t* expiration) {
Entry key;
- Entry** lookup;
- Entry* e;
- Cache* cache;
*expiration = -1;
- // A malfored query is not allowed.
+ // A malformed query is not allowed.
if (!entry_init_key(&key, query.data(), query.size())) {
LOG(WARNING) << __func__ << ": unsupported query";
return -EINVAL;
}
// lookup cache.
+ Cache* cache;
std::lock_guard guard(cache_mutex);
if (cache = find_named_cache_locked(netid); cache == nullptr) {
LOG(WARNING) << __func__ << ": cache not created in the network " << netid;
return -ENONET;
}
- lookup = _cache_lookup_p(cache, &key);
- e = *lookup;
+ Entry** lookup = _cache_lookup_p(cache, &key);
+ Entry* e = *lookup;
if (e == NULL) {
LOG(WARNING) << __func__ << ": not in cache";
return -ENODATA;
diff --git a/resolv/res_cache_test.cpp b/resolv/res_cache_test.cpp
index 97386e6..5c17d4d 100644
--- a/resolv/res_cache_test.cpp
+++ b/resolv/res_cache_test.cpp
@@ -16,6 +16,7 @@
#include <gtest/gtest.h>
+#include <array>
#include <atomic>
#include <chrono>
#include <ctime>
@@ -26,6 +27,7 @@
#include <android/multinetwork.h>
#include "dns_responder/dns_responder.h"
+#include "netd_resolv/stats.h"
#include "resolv_cache.h"
#include "resolv_private.h"
@@ -46,6 +48,18 @@
std::vector<char> answer;
};
+struct SetupParams {
+ std::vector<std::string> servers;
+ std::vector<std::string> domains;
+ res_params params;
+};
+
+struct CacheStats {
+ SetupParams setup;
+ std::vector<res_stats> stats;
+ int pendingReqTimeoutCount;
+};
+
std::vector<char> makeQuery(int op, const char* qname, int qclass, int qtype) {
res_state res = res_get_state();
uint8_t buf[MAXPACKET] = {};
@@ -80,10 +94,39 @@
return std::time(nullptr);
}
+std::string addrToString(const sockaddr_storage* addr) {
+ char out[INET6_ADDRSTRLEN] = {0};
+ getnameinfo((const sockaddr*)addr, sizeof(sockaddr_storage), out, INET6_ADDRSTRLEN, nullptr, 0,
+ NI_NUMERICHOST);
+ return std::string(out);
+}
+
+// Comparison for res_stats. The cache test only compares the sample counters.
+bool operator==(const res_stats& a, const res_stats& b) {
+ return std::tie(a.sample_count, a.sample_next) == std::tie(b.sample_count, b.sample_next);
+}
+
+// Comparison for res_params.
+bool operator==(const res_params& a, const res_params& b) {
+ return std::tie(a.sample_validity, a.success_threshold, a.min_samples, a.max_samples,
+ a.base_timeout_msec, a.retry_count) ==
+ std::tie(b.sample_validity, b.success_threshold, b.min_samples, b.max_samples,
+ b.base_timeout_msec, b.retry_count);
+}
+
} // namespace
class ResolvCacheTest : public ::testing::Test {
protected:
+ static constexpr res_params kParams = {
+ .sample_validity = 300,
+ .success_threshold = 25,
+ .min_samples = 8,
+ .max_samples = 8,
+ .base_timeout_msec = 1000,
+ .retry_count = 2,
+ };
+
ResolvCacheTest() {
// Store the default one and conceal 10000+ lines of resolver cache logs.
defaultLogSeverity = android::base::SetMinimumLogSeverity(
@@ -144,6 +187,45 @@
_resolv_cache_query_failed(netId, ce.query.data(), ce.query.size(), flags);
}
+ int cacheSetupResolver(uint32_t netId, const SetupParams& setup) {
+ return resolv_set_nameservers(netId, setup.servers, setup.domains, setup.params);
+ }
+
+ void expectCacheStats(const std::string& msg, uint32_t netId, const CacheStats& expected) {
+ int nscount = -1;
+ sockaddr_storage servers[MAXNS];
+ int dcount = -1;
+ char domains[MAXDNSRCH][MAXDNSRCHPATH];
+ res_stats stats[MAXNS];
+ res_params params = {};
+ int res_wait_for_pending_req_timeout_count;
+ android_net_res_stats_get_info_for_net(netId, &nscount, servers, &dcount, domains, ¶ms,
+ stats, &res_wait_for_pending_req_timeout_count);
+
+ // Server checking.
+ EXPECT_EQ(nscount, static_cast<int>(expected.setup.servers.size())) << msg;
+ for (int i = 0; i < nscount; i++) {
+ EXPECT_EQ(addrToString(&servers[i]), expected.setup.servers[i]) << msg;
+ }
+
+ // Domain checking
+ EXPECT_EQ(dcount, static_cast<int>(expected.setup.domains.size())) << msg;
+ for (int i = 0; i < dcount; i++) {
+ EXPECT_EQ(std::string(domains[i]), expected.setup.domains[i]) << msg;
+ }
+
+ // res_params checking.
+ EXPECT_TRUE(params == expected.setup.params) << msg;
+
+ // res_stats checking.
+ for (size_t i = 0; i < expected.stats.size(); i++) {
+ EXPECT_TRUE(stats[i] == expected.stats[i]) << msg;
+ }
+
+ // wait_for_pending_req_timeout_count checking.
+ EXPECT_EQ(res_wait_for_pending_req_timeout_count, expected.pendingReqTimeoutCount) << msg;
+ }
+
CacheEntry makeCacheEntry(int op, const char* qname, int qclass, int qtype, const char* rdata,
std::chrono::seconds ttl = 10s) {
CacheEntry ce;
@@ -511,6 +593,124 @@
EXPECT_TRUE(cacheLookup(RESOLV_CACHE_NOTFOUND, TEST_NETID, ce1));
}
+TEST_F(ResolvCacheTest, ResolverSetup) {  // Setup requires the network's cache to exist.
+ const SetupParams setup = {
+ .servers = {"127.0.0.1", "::127.0.0.2", "fe80::3"},
+ .domains = {"domain1.com", "domain2.com"},
+ .params = kParams,
+ };
+
+ // Without a cache for TEST_NETID, setup fails with -ENONET and installs no name servers.
+ EXPECT_EQ(-ENONET, cacheSetupResolver(TEST_NETID, setup));
+ EXPECT_FALSE(resolv_has_nameservers(TEST_NETID));
+
+ // After the cache is created, the same setup succeeds and name servers are present.
+ EXPECT_EQ(0, cacheCreate(TEST_NETID));
+ EXPECT_EQ(0, cacheSetupResolver(TEST_NETID, setup));
+ EXPECT_TRUE(resolv_has_nameservers(TEST_NETID));
+}
+
+TEST_F(ResolvCacheTest, ResolverSetup_InvalidNameServers) {
+ EXPECT_EQ(0, cacheCreate(TEST_NETID));
+ const std::string invalidServers[]{  // Malformed address literals plus an empty string.
+ "127.A.b.1",
+ "127.^.0",
+ "::^:1",
+ "",
+ };
+ SetupParams setup = {
+ .servers = {},
+ .domains = {"domain1.com"},
+ .params = kParams,
+ };
+
+ // One invalid address between two valid ones fails the whole setup with -EINVAL.
+ for (const auto& server : invalidServers) {
+ SCOPED_TRACE(server);
+ setup.servers = {"127.0.0.1", server, "127.0.0.2"};
+ EXPECT_EQ(-EINVAL, cacheSetupResolver(TEST_NETID, setup));
+ EXPECT_FALSE(resolv_has_nameservers(TEST_NETID));
+ }
+}
+
+TEST_F(ResolvCacheTest, ResolverSetup_DropDomain) {
+ EXPECT_EQ(0, cacheCreate(TEST_NETID));
+
+ // A MAXDNSRCHPATH-char domain is too long; MAXDNSRCHPATH - 1 chars is expected to be kept.
+ const std::vector<std::string> servers = {"127.0.0.1", "fe80::1"};
+ const std::string domainTooLong(MAXDNSRCHPATH, '1');
+ const std::string validDomain1(MAXDNSRCHPATH - 1, '2');
+ const std::string validDomain2(MAXDNSRCHPATH - 1, '3');
+ SetupParams setup = {
+ .servers = servers,
+ .domains = {},
+ .params = kParams,
+ };
+ CacheStats expect = {
+ .setup = setup,
+ .stats = {},
+ .pendingReqTimeoutCount = 0,
+ };
+
+ // Overlength domains are dropped; the remaining domains keep their order.
+ setup.domains = {validDomain1, domainTooLong, validDomain2};
+ expect.setup.domains = {validDomain1, validDomain2};
+ EXPECT_EQ(0, cacheSetupResolver(TEST_NETID, setup));
+ EXPECT_TRUE(resolv_has_nameservers(TEST_NETID));
+ expectCacheStats("ResolverSetup_Domains drop overlength", TEST_NETID, expect);
+
+ // Duplicate domains are dropped so each distinct domain is reported once.
+ setup.domains = {validDomain1, validDomain2, validDomain1, validDomain2};
+ expect.setup.domains = {validDomain1, validDomain2};
+ EXPECT_EQ(0, cacheSetupResolver(TEST_NETID, setup));
+ EXPECT_TRUE(resolv_has_nameservers(TEST_NETID));
+ expectCacheStats("ResolverSetup_Domains drop duplicates", TEST_NETID, expect);
+}
+
+TEST_F(ResolvCacheTest, ResolverSetup_Prune) {  // Extra servers/domains are truncated.
+ EXPECT_EQ(0, cacheCreate(TEST_NETID));
+ const std::vector<std::string> servers = {"127.0.0.1", "::127.0.0.2", "fe80::1", "fe80::2",
+ "fe80::3"};
+ const std::vector<std::string> domains = {"d1.com", "d2.com", "d3.com", "d4.com",
+ "d5.com", "d6.com", "d7.com"};
+ const SetupParams setup = {
+ .servers = servers,
+ .domains = domains,
+ .params = kParams,
+ };
+
+ EXPECT_EQ(0, cacheSetupResolver(TEST_NETID, setup));
+ EXPECT_TRUE(resolv_has_nameservers(TEST_NETID));
+
+ const CacheStats cacheStats = {  // Only the first MAXNS servers / MAXDNSRCH domains remain.
+ .setup = {.servers = std::vector(servers.begin(), servers.begin() + MAXNS),
+ .domains = std::vector(domains.begin(), domains.begin() + MAXDNSRCH),
+ .params = setup.params},
+ .stats = {},
+ .pendingReqTimeoutCount = 0,
+ };
+ expectCacheStats("ResolverSetup_Prune", TEST_NETID, cacheStats);
+}
+
+TEST_F(ResolvCacheTest, GetStats) {  // Fetched info echoes |setup| with empty stats.
+ EXPECT_EQ(0, cacheCreate(TEST_NETID));
+ const SetupParams setup = {
+ .servers = {"127.0.0.1", "::127.0.0.2", "fe80::3"},
+ .domains = {"domain1.com", "domain2.com"},
+ .params = kParams,
+ };
+
+ EXPECT_EQ(0, cacheSetupResolver(TEST_NETID, setup));
+ EXPECT_TRUE(resolv_has_nameservers(TEST_NETID));
+
+ const CacheStats cacheStats = {
+ .setup = setup,
+ .stats = {},
+ .pendingReqTimeoutCount = 0,
+ };
+ expectCacheStats("GetStats", TEST_NETID, cacheStats);
+}
+
// TODO: Tests for struct resolv_cache_info, including:
// - res_params
// -- resolv_cache_get_resolver_stats()
diff --git a/resolv/resolv_cache.h b/resolv/resolv_cache.h
index 1a27e4b..cbd1ac2 100644
--- a/resolv/resolv_cache.h
+++ b/resolv/resolv_cache.h
@@ -64,12 +64,8 @@
void _resolv_cache_query_failed(unsigned netid, const void* query, int querylen, uint32_t flags);
// Sets name servers for a given network.
-int resolv_set_nameservers(unsigned netid, const char** servers, int numservers,
- const std::vector<std::string>& domains, const res_params* params);
-
-// TODO: remove it after updating all callers.
-int resolv_set_nameservers(unsigned netid, const char** servers, int numservers,
- const char* domains, const res_params* params);
+int resolv_set_nameservers(unsigned netid, const std::vector<std::string>& servers,
+ const std::vector<std::string>& domains, const res_params& params);
// Creates the cache associated with the given network.
int resolv_create_cache_for_net(unsigned netid);
@@ -84,4 +80,4 @@
// For test only.
// Get the expiration time of a cache entry. Return 0 on success; otherwise, an negative error is
// returned if the expiration time can't be acquired.
-int resolv_cache_get_expiration(unsigned netid, const std::vector<char> query, time_t* expiration);
+int resolv_cache_get_expiration(unsigned netid, const std::vector<char>& query, time_t* expiration);