Add integration test for DNS answer RR with CNAMEs chain
Bug: 123376330
Test: resolv_integration_test
Change-Id: I74ba26f6a892f86e40b6b02611d7f9adee454fec
diff --git a/resolv/dns_responder/dns_responder.cpp b/resolv/dns_responder/dns_responder.cpp
index fe116b8..5af200b 100644
--- a/resolv/dns_responder/dns_responder.cpp
+++ b/resolv/dns_responder/dns_responder.cpp
@@ -27,6 +27,7 @@
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
+#include <set>
#include <iostream>
#include <vector>
@@ -827,11 +828,12 @@
return makeErrorResponse(&header, ns_rcode::ns_r_notimpl, response,
response_len);
}
+
if (!addAnswerRecords(question, &header.answers)) {
- return makeErrorResponse(&header, ns_rcode::ns_r_servfail, response,
- response_len);
+ return makeErrorResponse(&header, ns_rcode::ns_r_servfail, response, response_len);
}
}
+
header.qr = true;
char* response_cur = header.write(response, response + *response_len);
if (response_cur == nullptr) {
@@ -844,45 +846,73 @@
bool DNSResponder::addAnswerRecords(const DNSQuestion& question,
std::vector<DNSRecord>* answers) const {
std::lock_guard guard(mappings_mutex_);
- auto it = mappings_.find(QueryKey(question.qname.name, question.qtype));
- if (it == mappings_.end()) {
+ std::string rname = question.qname.name;
+ std::vector<int> rtypes;
+
+ if (question.qtype == ns_type::ns_t_a || question.qtype == ns_type::ns_t_aaaa)
+ rtypes.push_back(ns_type::ns_t_cname);
+ rtypes.push_back(question.qtype);
+ for (int rtype : rtypes) {
+ std::set<std::string> cnames_Loop;
+ std::unordered_map<QueryKey, std::string, QueryKeyHash>::const_iterator it;
+ while ((it = mappings_.find(QueryKey(rname, rtype))) != mappings_.end()) {
+ if (rtype == ns_type::ns_t_cname) {
+ // When detect CNAME infinite loops by cnames_Loop, it won't save the duplicate one.
+ // As following, the query will stop on loop3 by detecting the same cname.
+ // loop1.{"a.xxx.com", ns_type::ns_t_cname, "b.xxx.com"}(insert in answer record)
+ // loop2.{"b.xxx.com", ns_type::ns_t_cname, "a.xxx.com"}(insert in answer record)
+ // loop3.{"a.xxx.com", ns_type::ns_t_cname, "b.xxx.com"}(When the same cname record
+ // is found in cnames_Loop already, break the query loop.)
+ if (cnames_Loop.find(it->first.name) != cnames_Loop.end()) break;
+ cnames_Loop.insert(it->first.name);
+ }
+ DNSRecord record{
+ .name = {.name = it->first.name},
+ .rtype = it->first.type,
+ .rclass = ns_class::ns_c_in,
+ .ttl = 5, // seconds
+ };
+ fillAnswerRdata(it->second, record);
+ answers->push_back(std::move(record));
+ if (rtype != ns_type::ns_t_cname) break;
+ rname = it->second;
+ }
+ }
+
+ if (answers->size() == 0) {
// TODO(imaipi): handle correctly
ALOGI("no mapping found for %s %s, lazily refusing to add an answer",
- question.qname.name.c_str(), dnstype2str(question.qtype));
- return true;
+ question.qname.name.c_str(), dnstype2str(question.qtype));
}
- DBGLOG("mapping found for %s %s: %s", question.qname.name.c_str(), dnstype2str(question.qtype),
- it->second.c_str());
- DNSRecord record;
- record.name = question.qname;
- record.rtype = question.qtype;
- record.rclass = ns_class::ns_c_in;
- record.ttl = 5; // seconds
- if (question.qtype == ns_type::ns_t_a) {
+
+ return true;
+}
+
+bool DNSResponder::fillAnswerRdata(const std::string& rdatastr, DNSRecord& record) const {
+ if (record.rtype == ns_type::ns_t_a) {
record.rdata.resize(4);
- if (inet_pton(AF_INET, it->second.c_str(), record.rdata.data()) != 1) {
- ALOGI("inet_pton(AF_INET, %s) failed", it->second.c_str());
+ if (inet_pton(AF_INET, rdatastr.c_str(), record.rdata.data()) != 1) {
+ ALOGI("inet_pton(AF_INET, %s) failed", rdatastr.c_str());
return false;
}
- } else if (question.qtype == ns_type::ns_t_aaaa) {
+ } else if (record.rtype == ns_type::ns_t_aaaa) {
record.rdata.resize(16);
- if (inet_pton(AF_INET6, it->second.c_str(), record.rdata.data()) != 1) {
- ALOGI("inet_pton(AF_INET6, %s) failed", it->second.c_str());
+ if (inet_pton(AF_INET6, rdatastr.c_str(), record.rdata.data()) != 1) {
+ ALOGI("inet_pton(AF_INET6, %s) failed", rdatastr.c_str());
return false;
}
- } else if (question.qtype == ns_type::ns_t_ptr) {
+ } else if ((record.rtype == ns_type::ns_t_ptr) || (record.rtype == ns_type::ns_t_cname)) {
constexpr char delimiter = '.';
- std::string name = it->second;
+ std::string name = rdatastr;
std::vector<char> rdata;
- // PTRDNAME field
+ // Generating PTRDNAME field(section 3.3.12) or CNAME field(section 3.3.1) in rfc1035.
// The "name" should be an absolute domain name which ends in a dot.
if (name.back() != delimiter) {
ALOGI("invalid absolute domain name");
return false;
}
name.pop_back(); // remove the dot in tail
-
for (const std::string& label : android::base::Split(name, {delimiter})) {
// The length of label is limited to 63 octets or less. See RFC 1035 section 3.1.
if (label.length() == 0 || label.length() > 63) {
@@ -902,10 +932,9 @@
}
record.rdata = move(rdata);
} else {
- ALOGI("unhandled qtype %s", dnstype2str(question.qtype));
+ ALOGI("unhandled qtype %s", dnstype2str(record.rtype));
return false;
}
- answers->push_back(std::move(record));
return true;
}
diff --git a/resolv/dns_responder/dns_responder.h b/resolv/dns_responder/dns_responder.h
index 360ddfc..31e41c7 100644
--- a/resolv/dns_responder/dns_responder.h
+++ b/resolv/dns_responder/dns_responder.h
@@ -113,8 +113,9 @@
bool handleDNSRequest(const char* buffer, ssize_t buffer_len,
char* response, size_t* response_len) const;
- bool addAnswerRecords(const DNSQuestion& question,
- std::vector<DNSRecord>* answers) const;
+ bool addAnswerRecords(const DNSQuestion& question, std::vector<DNSRecord>* answers) const;
+
+ bool fillAnswerRdata(const std::string& rdatastr, DNSRecord& record) const;
bool generateErrorResponse(DNSHeader* header, ns_rcode rcode,
char* response, size_t* response_len) const;
diff --git a/resolv/resolver_test.cpp b/resolv/resolver_test.cpp
index 1f9c8cc..50138a2 100644
--- a/resolv/resolver_test.cpp
+++ b/resolv/resolver_test.cpp
@@ -334,6 +334,80 @@
EXPECT_TRUE(result->h_addr_list[1] == nullptr);
}
+TEST_F(ResolverTest, GetHostByName_cnames) {
+ constexpr char host_name[] = "host.example.com.";
+ size_t cnamecount = 0;
+ test::DNSResponder dns;
+
+ const std::vector<DnsRecord> records = {
+ {kHelloExampleCom, ns_type::ns_t_cname, "a.example.com."},
+ {"a.example.com.", ns_type::ns_t_cname, "b.example.com."},
+ {"b.example.com.", ns_type::ns_t_cname, "c.example.com."},
+ {"c.example.com.", ns_type::ns_t_cname, "d.example.com."},
+ {"d.example.com.", ns_type::ns_t_cname, "e.example.com."},
+ {"e.example.com.", ns_type::ns_t_cname, host_name},
+ {host_name, ns_type::ns_t_a, "1.2.3.3"},
+ {host_name, ns_type::ns_t_aaaa, "2001:db8::42"},
+ };
+ StartDns(dns, records);
+ ASSERT_TRUE(mDnsClient.SetResolversForNetwork());
+
+ // using gethostbyname2() to resolve ipv4 hello.example.com. to 1.2.3.3
+ // Ensure the v4 address and cnames are correct
+ const hostent* result;
+ result = gethostbyname2("hello", AF_INET);
+ ASSERT_FALSE(result == nullptr);
+
+ for (int i = 0; result != nullptr && result->h_aliases[i] != nullptr; i++) {
+ std::string domain_name = records[i].host_name.substr(0, records[i].host_name.size() - 1);
+ EXPECT_EQ(result->h_aliases[i], domain_name);
+ cnamecount++;
+ }
+ // The size of "Non-cname type" record in DNS records is 2
+ ASSERT_EQ(cnamecount, records.size() - 2);
+ ASSERT_EQ(4, result->h_length);
+ ASSERT_FALSE(result->h_addr_list[0] == nullptr);
+ EXPECT_EQ("1.2.3.3", ToString(result));
+ EXPECT_TRUE(result->h_addr_list[1] == nullptr);
+ EXPECT_EQ(1U, dns.queries().size()) << dns.dumpQueries();
+
+ // using gethostbyname2() to resolve ipv6 hello.example.com. to 2001:db8::42
+ // Ensure the v6 address and cnames are correct
+ cnamecount = 0;
+ dns.clearQueries();
+ result = gethostbyname2("hello", AF_INET6);
+ for (unsigned i = 0; result != nullptr && result->h_aliases[i] != nullptr; i++) {
+ std::string domain_name = records[i].host_name.substr(0, records[i].host_name.size() - 1);
+ EXPECT_EQ(result->h_aliases[i], domain_name);
+ cnamecount++;
+ }
+ // The size of "Non-cname type" DNS record in records is 2
+ ASSERT_EQ(cnamecount, records.size() - 2);
+ ASSERT_FALSE(result == nullptr);
+ ASSERT_EQ(16, result->h_length);
+ ASSERT_FALSE(result->h_addr_list[0] == nullptr);
+ EXPECT_EQ("2001:db8::42", ToString(result));
+ EXPECT_TRUE(result->h_addr_list[1] == nullptr);
+}
+
+TEST_F(ResolverTest, GetHostByName_cnamesInfiniteLoop) {
+ test::DNSResponder dns;
+ const std::vector<DnsRecord> records = {
+ {kHelloExampleCom, ns_type::ns_t_cname, "a.example.com."},
+ {"a.example.com.", ns_type::ns_t_cname, kHelloExampleCom},
+ };
+ StartDns(dns, records);
+ ASSERT_TRUE(mDnsClient.SetResolversForNetwork());
+
+ const hostent* result;
+ result = gethostbyname2("hello", AF_INET);
+ ASSERT_TRUE(result == nullptr);
+
+ dns.clearQueries();
+ result = gethostbyname2("hello", AF_INET6);
+ ASSERT_TRUE(result == nullptr);
+}
+
TEST_F(ResolverTest, GetHostByName_localhost) {
constexpr char name_camelcase[] = "LocalHost";
constexpr char name_ip6_dot[] = "ip6-localhost.";
@@ -677,6 +751,70 @@
t2.join();
}
+TEST_F(ResolverTest, GetAddrInfo_cnames) {
+ constexpr char host_name[] = "host.example.com.";
+ test::DNSResponder dns;
+ const std::vector<DnsRecord> records = {
+ {kHelloExampleCom, ns_type::ns_t_cname, "a.example.com."},
+ {"a.example.com.", ns_type::ns_t_cname, "b.example.com."},
+ {"b.example.com.", ns_type::ns_t_cname, "c.example.com."},
+ {"c.example.com.", ns_type::ns_t_cname, "d.example.com."},
+ {"d.example.com.", ns_type::ns_t_cname, "e.example.com."},
+ {"e.example.com.", ns_type::ns_t_cname, host_name},
+ {host_name, ns_type::ns_t_a, "1.2.3.3"},
+ {host_name, ns_type::ns_t_aaaa, "2001:db8::42"},
+ };
+ StartDns(dns, records);
+ ASSERT_TRUE(mDnsClient.SetResolversForNetwork());
+
+ addrinfo hints = {.ai_family = AF_INET};
+ ScopedAddrinfo result = safe_getaddrinfo("hello", nullptr, &hints);
+ EXPECT_TRUE(result != nullptr);
+ EXPECT_EQ("1.2.3.3", ToString(result));
+
+ dns.clearQueries();
+ hints = {.ai_family = AF_INET6};
+ result = safe_getaddrinfo("hello", nullptr, &hints);
+ EXPECT_TRUE(result != nullptr);
+ EXPECT_EQ("2001:db8::42", ToString(result));
+}
+
+TEST_F(ResolverTest, GetAddrInfo_cnamesNoIpAddress) {
+ test::DNSResponder dns;
+ const std::vector<DnsRecord> records = {
+ {kHelloExampleCom, ns_type::ns_t_cname, "a.example.com."},
+ };
+ StartDns(dns, records);
+ ASSERT_TRUE(mDnsClient.SetResolversForNetwork());
+
+ addrinfo hints = {.ai_family = AF_INET};
+ ScopedAddrinfo result = safe_getaddrinfo("hello", nullptr, &hints);
+ EXPECT_TRUE(result == nullptr);
+
+ dns.clearQueries();
+ hints = {.ai_family = AF_INET6};
+ result = safe_getaddrinfo("hello", nullptr, &hints);
+ EXPECT_TRUE(result == nullptr);
+}
+
+TEST_F(ResolverTest, GetAddrInfo_cnamesIllegalRdata) {
+ test::DNSResponder dns;
+ const std::vector<DnsRecord> records = {
+ {kHelloExampleCom, ns_type::ns_t_cname, ".!#?"},
+ };
+ StartDns(dns, records);
+ ASSERT_TRUE(mDnsClient.SetResolversForNetwork());
+
+ addrinfo hints = {.ai_family = AF_INET};
+ ScopedAddrinfo result = safe_getaddrinfo("hello", nullptr, &hints);
+ EXPECT_TRUE(result == nullptr);
+
+ dns.clearQueries();
+ hints = {.ai_family = AF_INET6};
+ result = safe_getaddrinfo("hello", nullptr, &hints);
+ EXPECT_TRUE(result == nullptr);
+}
+
TEST_F(ResolverTest, MultidomainResolution) {
constexpr char host_name[] = "nihao.example2.com.";
std::vector<std::string> searchDomains = { "example1.com", "example2.com", "example3.com" };