Add integration test for DNS answer RR with CNAMEs chain

Bug: 123376330
Test: resolv_integration_test
Change-Id: I74ba26f6a892f86e40b6b02611d7f9adee454fec
diff --git a/resolv/dns_responder/dns_responder.cpp b/resolv/dns_responder/dns_responder.cpp
index fe116b8..5af200b 100644
--- a/resolv/dns_responder/dns_responder.cpp
+++ b/resolv/dns_responder/dns_responder.cpp
@@ -27,6 +27,7 @@
 #include <sys/socket.h>
 #include <sys/types.h>
 #include <unistd.h>
+#include <set>
 
 #include <iostream>
 #include <vector>
@@ -827,11 +828,12 @@
             return makeErrorResponse(&header, ns_rcode::ns_r_notimpl, response,
                                      response_len);
         }
+
         if (!addAnswerRecords(question, &header.answers)) {
-            return makeErrorResponse(&header, ns_rcode::ns_r_servfail, response,
-                                     response_len);
+            return makeErrorResponse(&header, ns_rcode::ns_r_servfail, response, response_len);
         }
     }
+
     header.qr = true;
     char* response_cur = header.write(response, response + *response_len);
     if (response_cur == nullptr) {
@@ -844,45 +846,73 @@
 bool DNSResponder::addAnswerRecords(const DNSQuestion& question,
                                     std::vector<DNSRecord>* answers) const {
     std::lock_guard guard(mappings_mutex_);
-    auto it = mappings_.find(QueryKey(question.qname.name, question.qtype));
-    if (it == mappings_.end()) {
+    std::string rname = question.qname.name;
+    std::vector<int> rtypes;
+
+    if (question.qtype == ns_type::ns_t_a || question.qtype == ns_type::ns_t_aaaa)
+        rtypes.push_back(ns_type::ns_t_cname);
+    rtypes.push_back(question.qtype);
+    for (int rtype : rtypes) {
+        std::set<std::string> cnames_Loop;
+        std::unordered_map<QueryKey, std::string, QueryKeyHash>::const_iterator it;
+        while ((it = mappings_.find(QueryKey(rname, rtype))) != mappings_.end()) {
+            if (rtype == ns_type::ns_t_cname) {
+                // When detect CNAME infinite loops by cnames_Loop, it won't save the duplicate one.
+                // As following, the query will stop on loop3 by detecting the same cname.
+                // loop1.{"a.xxx.com", ns_type::ns_t_cname, "b.xxx.com"}(insert in answer record)
+                // loop2.{"b.xxx.com", ns_type::ns_t_cname, "a.xxx.com"}(insert in answer record)
+                // loop3.{"a.xxx.com", ns_type::ns_t_cname, "b.xxx.com"}(When the same cname record
+                //   is found in cnames_Loop already, break the query loop.)
+                if (cnames_Loop.find(it->first.name) != cnames_Loop.end()) break;
+                cnames_Loop.insert(it->first.name);
+            }
+            DNSRecord record{
+                    .name = {.name = it->first.name},
+                    .rtype = it->first.type,
+                    .rclass = ns_class::ns_c_in,
+                    .ttl = 5,  // seconds
+            };
+            fillAnswerRdata(it->second, record);
+            answers->push_back(std::move(record));
+            if (rtype != ns_type::ns_t_cname) break;
+            rname = it->second;
+        }
+    }
+
+    if (answers->size() == 0) {
         // TODO(imaipi): handle correctly
         ALOGI("no mapping found for %s %s, lazily refusing to add an answer",
-            question.qname.name.c_str(), dnstype2str(question.qtype));
-        return true;
+              question.qname.name.c_str(), dnstype2str(question.qtype));
     }
-    DBGLOG("mapping found for %s %s: %s", question.qname.name.c_str(), dnstype2str(question.qtype),
-           it->second.c_str());
-    DNSRecord record;
-    record.name = question.qname;
-    record.rtype = question.qtype;
-    record.rclass = ns_class::ns_c_in;
-    record.ttl = 5;  // seconds
-    if (question.qtype == ns_type::ns_t_a) {
+
+    return true;
+}
+
+bool DNSResponder::fillAnswerRdata(const std::string& rdatastr, DNSRecord& record) const {
+    if (record.rtype == ns_type::ns_t_a) {
         record.rdata.resize(4);
-        if (inet_pton(AF_INET, it->second.c_str(), record.rdata.data()) != 1) {
-            ALOGI("inet_pton(AF_INET, %s) failed", it->second.c_str());
+        if (inet_pton(AF_INET, rdatastr.c_str(), record.rdata.data()) != 1) {
+            ALOGI("inet_pton(AF_INET, %s) failed", rdatastr.c_str());
             return false;
         }
-    } else if (question.qtype == ns_type::ns_t_aaaa) {
+    } else if (record.rtype == ns_type::ns_t_aaaa) {
         record.rdata.resize(16);
-        if (inet_pton(AF_INET6, it->second.c_str(), record.rdata.data()) != 1) {
-            ALOGI("inet_pton(AF_INET6, %s) failed", it->second.c_str());
+        if (inet_pton(AF_INET6, rdatastr.c_str(), record.rdata.data()) != 1) {
+            ALOGI("inet_pton(AF_INET6, %s) failed", rdatastr.c_str());
             return false;
         }
-    } else if (question.qtype == ns_type::ns_t_ptr) {
+    } else if ((record.rtype == ns_type::ns_t_ptr) || (record.rtype == ns_type::ns_t_cname)) {
         constexpr char delimiter = '.';
-        std::string name = it->second;
+        std::string name = rdatastr;
         std::vector<char> rdata;
 
-        // PTRDNAME field
+        // Generating PTRDNAME field(section 3.3.12) or CNAME field(section 3.3.1) in rfc1035.
         // The "name" should be an absolute domain name which ends in a dot.
         if (name.back() != delimiter) {
             ALOGI("invalid absolute domain name");
             return false;
         }
         name.pop_back();  // remove the dot in tail
-
         for (const std::string& label : android::base::Split(name, {delimiter})) {
             // The length of label is limited to 63 octets or less. See RFC 1035 section 3.1.
             if (label.length() == 0 || label.length() > 63) {
@@ -902,10 +932,9 @@
         }
         record.rdata = move(rdata);
     } else {
-        ALOGI("unhandled qtype %s", dnstype2str(question.qtype));
+        ALOGI("unhandled qtype %s", dnstype2str(record.rtype));
         return false;
     }
-    answers->push_back(std::move(record));
     return true;
 }
 
diff --git a/resolv/dns_responder/dns_responder.h b/resolv/dns_responder/dns_responder.h
index 360ddfc..31e41c7 100644
--- a/resolv/dns_responder/dns_responder.h
+++ b/resolv/dns_responder/dns_responder.h
@@ -113,8 +113,9 @@
     bool handleDNSRequest(const char* buffer, ssize_t buffer_len,
                           char* response, size_t* response_len) const;
 
-    bool addAnswerRecords(const DNSQuestion& question,
-                          std::vector<DNSRecord>* answers) const;
+    bool addAnswerRecords(const DNSQuestion& question, std::vector<DNSRecord>* answers) const;
+
+    bool fillAnswerRdata(const std::string& rdatastr, DNSRecord& record) const;
 
     bool generateErrorResponse(DNSHeader* header, ns_rcode rcode,
                                char* response, size_t* response_len) const;
diff --git a/resolv/resolver_test.cpp b/resolv/resolver_test.cpp
index 1f9c8cc..50138a2 100644
--- a/resolv/resolver_test.cpp
+++ b/resolv/resolver_test.cpp
@@ -334,6 +334,80 @@
     EXPECT_TRUE(result->h_addr_list[1] == nullptr);
 }
 
+TEST_F(ResolverTest, GetHostByName_cnames) {
+    constexpr char host_name[] = "host.example.com.";
+    size_t cnamecount = 0;
+    test::DNSResponder dns;
+
+    const std::vector<DnsRecord> records = {
+            {kHelloExampleCom, ns_type::ns_t_cname, "a.example.com."},
+            {"a.example.com.", ns_type::ns_t_cname, "b.example.com."},
+            {"b.example.com.", ns_type::ns_t_cname, "c.example.com."},
+            {"c.example.com.", ns_type::ns_t_cname, "d.example.com."},
+            {"d.example.com.", ns_type::ns_t_cname, "e.example.com."},
+            {"e.example.com.", ns_type::ns_t_cname, host_name},
+            {host_name, ns_type::ns_t_a, "1.2.3.3"},
+            {host_name, ns_type::ns_t_aaaa, "2001:db8::42"},
+    };
+    StartDns(dns, records);
+    ASSERT_TRUE(mDnsClient.SetResolversForNetwork());
+
+    // using gethostbyname2() to resolve ipv4 hello.example.com. to 1.2.3.3
+    // Ensure the v4 address and cnames are correct
+    const hostent* result;
+    result = gethostbyname2("hello", AF_INET);
+    ASSERT_FALSE(result == nullptr);
+
+    for (int i = 0; result != nullptr && result->h_aliases[i] != nullptr; i++) {
+        std::string domain_name = records[i].host_name.substr(0, records[i].host_name.size() - 1);
+        EXPECT_EQ(result->h_aliases[i], domain_name);
+        cnamecount++;
+    }
+    // The size of "Non-cname type" record in DNS records is 2
+    ASSERT_EQ(cnamecount, records.size() - 2);
+    ASSERT_EQ(4, result->h_length);
+    ASSERT_FALSE(result->h_addr_list[0] == nullptr);
+    EXPECT_EQ("1.2.3.3", ToString(result));
+    EXPECT_TRUE(result->h_addr_list[1] == nullptr);
+    EXPECT_EQ(1U, dns.queries().size()) << dns.dumpQueries();
+
+    // using gethostbyname2() to resolve ipv6 hello.example.com. to 2001:db8::42
+    // Ensure the v6 address and cnames are correct
+    cnamecount = 0;
+    dns.clearQueries();
+    result = gethostbyname2("hello", AF_INET6);
+    for (unsigned i = 0; result != nullptr && result->h_aliases[i] != nullptr; i++) {
+        std::string domain_name = records[i].host_name.substr(0, records[i].host_name.size() - 1);
+        EXPECT_EQ(result->h_aliases[i], domain_name);
+        cnamecount++;
+    }
+    // The size of "Non-cname type" DNS record in records is 2
+    ASSERT_EQ(cnamecount, records.size() - 2);
+    ASSERT_FALSE(result == nullptr);
+    ASSERT_EQ(16, result->h_length);
+    ASSERT_FALSE(result->h_addr_list[0] == nullptr);
+    EXPECT_EQ("2001:db8::42", ToString(result));
+    EXPECT_TRUE(result->h_addr_list[1] == nullptr);
+}
+
+TEST_F(ResolverTest, GetHostByName_cnamesInfiniteLoop) {
+    test::DNSResponder dns;
+    const std::vector<DnsRecord> records = {
+            {kHelloExampleCom, ns_type::ns_t_cname, "a.example.com."},
+            {"a.example.com.", ns_type::ns_t_cname, kHelloExampleCom},
+    };
+    StartDns(dns, records);
+    ASSERT_TRUE(mDnsClient.SetResolversForNetwork());
+
+    const hostent* result;
+    result = gethostbyname2("hello", AF_INET);
+    ASSERT_TRUE(result == nullptr);
+
+    dns.clearQueries();
+    result = gethostbyname2("hello", AF_INET6);
+    ASSERT_TRUE(result == nullptr);
+}
+
 TEST_F(ResolverTest, GetHostByName_localhost) {
     constexpr char name_camelcase[] = "LocalHost";
     constexpr char name_ip6_dot[] = "ip6-localhost.";
@@ -677,6 +751,70 @@
     t2.join();
 }
 
+TEST_F(ResolverTest, GetAddrInfo_cnames) {
+    constexpr char host_name[] = "host.example.com.";
+    test::DNSResponder dns;
+    const std::vector<DnsRecord> records = {
+            {kHelloExampleCom, ns_type::ns_t_cname, "a.example.com."},
+            {"a.example.com.", ns_type::ns_t_cname, "b.example.com."},
+            {"b.example.com.", ns_type::ns_t_cname, "c.example.com."},
+            {"c.example.com.", ns_type::ns_t_cname, "d.example.com."},
+            {"d.example.com.", ns_type::ns_t_cname, "e.example.com."},
+            {"e.example.com.", ns_type::ns_t_cname, host_name},
+            {host_name, ns_type::ns_t_a, "1.2.3.3"},
+            {host_name, ns_type::ns_t_aaaa, "2001:db8::42"},
+    };
+    StartDns(dns, records);
+    ASSERT_TRUE(mDnsClient.SetResolversForNetwork());
+
+    addrinfo hints = {.ai_family = AF_INET};
+    ScopedAddrinfo result = safe_getaddrinfo("hello", nullptr, &hints);
+    EXPECT_TRUE(result != nullptr);
+    EXPECT_EQ("1.2.3.3", ToString(result));
+
+    dns.clearQueries();
+    hints = {.ai_family = AF_INET6};
+    result = safe_getaddrinfo("hello", nullptr, &hints);
+    EXPECT_TRUE(result != nullptr);
+    EXPECT_EQ("2001:db8::42", ToString(result));
+}
+
+TEST_F(ResolverTest, GetAddrInfo_cnamesNoIpAddress) {
+    test::DNSResponder dns;
+    const std::vector<DnsRecord> records = {
+            {kHelloExampleCom, ns_type::ns_t_cname, "a.example.com."},
+    };
+    StartDns(dns, records);
+    ASSERT_TRUE(mDnsClient.SetResolversForNetwork());
+
+    addrinfo hints = {.ai_family = AF_INET};
+    ScopedAddrinfo result = safe_getaddrinfo("hello", nullptr, &hints);
+    EXPECT_TRUE(result == nullptr);
+
+    dns.clearQueries();
+    hints = {.ai_family = AF_INET6};
+    result = safe_getaddrinfo("hello", nullptr, &hints);
+    EXPECT_TRUE(result == nullptr);
+}
+
+TEST_F(ResolverTest, GetAddrInfo_cnamesIllegalRdata) {
+    test::DNSResponder dns;
+    const std::vector<DnsRecord> records = {
+            {kHelloExampleCom, ns_type::ns_t_cname, ".!#?"},
+    };
+    StartDns(dns, records);
+    ASSERT_TRUE(mDnsClient.SetResolversForNetwork());
+
+    addrinfo hints = {.ai_family = AF_INET};
+    ScopedAddrinfo result = safe_getaddrinfo("hello", nullptr, &hints);
+    EXPECT_TRUE(result == nullptr);
+
+    dns.clearQueries();
+    hints = {.ai_family = AF_INET6};
+    result = safe_getaddrinfo("hello", nullptr, &hints);
+    EXPECT_TRUE(result == nullptr);
+}
+
 TEST_F(ResolverTest, MultidomainResolution) {
     constexpr char host_name[] = "nihao.example2.com.";
     std::vector<std::string> searchDomains = { "example1.com", "example2.com", "example3.com" };