Support prioritizing DNS servers

The change introduces a way to prioritize DNS servers on the basis of
DNS query response time. It aims to replace the current design, which
is biased towards using the first DNS server provided by the network.

Server quality is evaluated with the following heuristics:
  - The higher a server's latency, the less likely it is to be used.
  - The longer a server goes unused, the more likely it is to be used.

Compared to the current design, the proposed method detects bad DNS
servers more quickly. For instance, a server that is unreachable or
times out can be detected and deprioritized within a few attempts,
thanks to the backoff penalty and its abnormal latency.

As in the current design, a server that has been regarded as poor
quality can be used again, depending on how much worse it is. A
counter records how many times a DNS server has been skipped, which
prevents the same DNS server from being used all the time.

This change comprises:

[1] Allow the resolver to sort DNS servers on the basis of DNS query
    response time.
[2] Add an experiment flag to enable/disable the sorting.
[3] Show the result of the quantified quality of DNS servers in
    dumpsys dnsresolver.
[4] Add unit tests for DnsStats::getSortedServers().
[5] Revise the integration tests which are sensitive to the nameserver
    sorting, including two big changes in SkipBadServersDueToInternalError
    and SkipBadServersDueToTimeout and some minor changes.

Bug: 137169582
Test: ran resolv_unit_test
      ran resolv_integration_test with the sorting enabled
      ran resolv_integration_test with the sorting disabled
Change-Id: I24b6a317f135a942ce0ea310c81dfe658bada6a7
diff --git a/DnsStats.cpp b/DnsStats.cpp
index 970e30e..2f76d8a 100644
--- a/DnsStats.cpp
+++ b/DnsStats.cpp
@@ -77,11 +77,14 @@
            std::tie(o.serverSockAddr, o.total, o.rcodeCounts, o.latencyUs);
 }
 
+int StatsData::averageLatencyMs() const {
+    return (total == 0) ? 0 : duration_cast<milliseconds>(latencyUs).count() / total;
+}
+
 std::string StatsData::toString() const {
     if (total == 0) return StringPrintf("%s <no data>", serverSockAddr.ip().toString().c_str());
 
     const auto now = std::chrono::steady_clock::now();
-    const int meanLatencyMs = duration_cast<milliseconds>(latencyUs).count() / total;
     const int lastUpdateSec = duration_cast<seconds>(now - lastUpdate).count();
     std::string buf;
     for (const auto& [rcode, counts] : rcodeCounts) {
@@ -90,7 +93,7 @@
         }
     }
     return StringPrintf("%s (%d, %dms, [%s], %ds)", serverSockAddr.ip().toString().c_str(), total,
-                        meanLatencyMs, buf.c_str(), lastUpdateSec);
+                        averageLatencyMs(), buf.c_str(), lastUpdateSec);
 }
 
 StatsRecords::StatsRecords(const IPSockAddr& ipSockAddr, size_t size)
@@ -104,6 +107,10 @@
         updateStatsData(mRecords.front(), false);
         mRecords.pop_front();
     }
+
+    // Update the quality factors.
+    mSkippedCount = 0;
+    updatePenalty(record);
 }
 
 void StatsRecords::updateStatsData(const Record& record, const bool add) {
@@ -120,6 +127,41 @@
     mStatsData.lastUpdate = std::chrono::steady_clock::now();
 }
 
+void StatsRecords::updatePenalty(const Record& record) {
+    switch (record.rcode) {
+        case NS_R_NO_ERROR:
+        case NS_R_NXDOMAIN:
+        case NS_R_NOTAUTH:
+            mPenalty = 0;
+            return;
+        default:
+            // Any other rcode (e.g. NS_R_TIMEOUT, NS_R_INTERNAL_ERROR) counts as a failure.
+            if (mPenalty == 0) {
+                mPenalty = 100;
+            } else {
+                // Double on consecutive failures so the quality drops faster (capped).
+                mPenalty = std::min(mPenalty * 2, kMaxQuality);
+            }
+            return;
+    }
+}
+
+double StatsRecords::score() const {
+    const int avgRtt = mStatsData.averageLatencyMs();
+
+    // Set the lower bound to -1 in case of "avgRtt + mPenalty < mSkippedCount", e.g.
+    //   1) when the server doesn't have any stats yet.
+    //   2) when the sorting has been disabled while it was enabled before.
+    int quality = std::clamp(avgRtt + mPenalty - mSkippedCount, -1, kMaxQuality);
+
+    // Normalize: quality 0 -> 100, kMaxQuality -> 0, and -1 -> 100.01 (skipped, no cost).
+    return static_cast<double>(kMaxQuality - quality) * 100 / kMaxQuality;
+}
+
+void StatsRecords::incrementSkippedCount() {
+    mSkippedCount = std::min(mSkippedCount + 1, kMaxQuality);
+}
+
 bool DnsStats::setServers(const std::vector<netdutils::IPSockAddr>& servers, Protocol protocol) {
     if (!ensureNoInvalidIp(servers)) return false;
 
@@ -147,6 +189,7 @@
 bool DnsStats::addStats(const IPSockAddr& ipSockAddr, const DnsQueryEvent& record) {
     if (ipSockAddr.ip() == INVALID_IPADDRESS) return false;
 
+    bool added = false;
     for (auto& [serverSockAddr, statsRecords] : mStats[record.protocol()]) {
         if (serverSockAddr == ipSockAddr) {
             const StatsRecords::Record rec = {
@@ -154,10 +197,41 @@
                     .latencyUs = microseconds(record.latency_micros()),
             };
             statsRecords.push(rec);
-            return true;
+            added = true;
         }
     }
-    return false;
+    if (!added) return false;
+
+    // Count a skip for every other server only once the record has been accepted, so that a
+    // rejected record (e.g. for an unknown server) leaves every quality factor unchanged.
+    for (auto& [serverSockAddr, statsRecords] : mStats[record.protocol()]) {
+        if (serverSockAddr == ipSockAddr) continue;
+        statsRecords.incrementSkippedCount();
+    }
+    return true;
+}
+
+std::vector<IPSockAddr> DnsStats::getSortedServers(Protocol protocol) const {
+    // DoT unsupported. The handshake overhead is expensive, and the connection will hang for a
+    // while. Need to figure out if it is worth doing for DoT servers.
+    if (protocol == PROTO_DOT) return {};
+
+    auto it = mStats.find(protocol);
+    if (it == mStats.end()) return {};
+
+    // Sort on insertion in decreasing order of score; ties keep the map's iteration order.
+    std::multimap<double, IPSockAddr, std::greater<double>> sortedData;
+    for (const auto& [ip, statsRecords] : it->second) {
+        sortedData.insert({statsRecords.score(), ip});
+    }
+
+    std::vector<IPSockAddr> ret;
+    ret.reserve(sortedData.size());
+    for (const auto& [_, v] : sortedData) {
+        ret.push_back(v);  // IPSockAddr is trivially-copyable.
+    }
+
+    return ret;
 }
 
 std::vector<StatsData> DnsStats::getStats(Protocol protocol) const {
@@ -179,7 +253,10 @@
             return;
         }
         for (const auto& [_, statsRecords] : statsMap) {
-            dw.println("%s", statsRecords.getStatsData().toString().c_str());
+            const StatsData& data = statsRecords.getStatsData();
+            std::string str = data.toString();
+            str += StringPrintf(" score{%.1f}", statsRecords.score());
+            dw.println("%s", str.c_str());
         }
     };
 
diff --git a/DnsStats.h b/DnsStats.h
index 40dad94..c5459b4 100644
--- a/DnsStats.h
+++ b/DnsStats.h
@@ -54,6 +54,7 @@
     // The last update timestamp.
     std::chrono::time_point<std::chrono::steady_clock> lastUpdate;
 
+    int averageLatencyMs() const;
     std::string toString() const;
 
     // For testing.
@@ -77,12 +78,31 @@
 
     const StatsData& getStatsData() const { return mStatsData; }
 
+    // Quantifies the quality based on the current quality factors and the latency, and normalize
+    // the value to a score between 0 to 100.
+    double score() const;
+
+    void incrementSkippedCount();
+
   private:
     void updateStatsData(const Record& record, const bool add);
+    void updatePenalty(const Record& record);
 
     std::deque<Record> mRecords;
     size_t mCapacity;
     StatsData mStatsData;
+
+    // A quality factor used to distinguish if the server can't be evaluated by latency alone, such
+    // as instant failure on connect.
+    int mPenalty = 0;
+
+    // A quality factor used to prevent starvation.
+    int mSkippedCount = 0;
+
+    // The maximum of the quantified result. As the sorting is on the basis of server latency, limit
+    // the maximal value of the quantity to 10000 in correspondence with the maximal cleartext
+    // query timeout 10000 milliseconds. This helps normalize the value of the quality to a score.
+    static constexpr int kMaxQuality = 10000;
 };
 
 // DnsStats class manages the statistics of DNS servers per netId.
@@ -98,13 +118,14 @@
     // Return true if |record| is successfully added into |server|'s stats; otherwise, return false.
     bool addStats(const netdutils::IPSockAddr& server, const DnsQueryEvent& record);
 
+    std::vector<netdutils::IPSockAddr> getSortedServers(Protocol protocol) const;
+
     void dump(netdutils::DumpWriter& dw);
 
     // For testing.
     std::vector<StatsData> getStats(Protocol protocol) const;
 
     // TODO: Compatible support for getResolverInfo().
-    // TODO: Support getSortedServers().
 
     static constexpr size_t kLogSize = 128;
 
diff --git a/DnsStatsTest.cpp b/DnsStatsTest.cpp
index 419e6db..38a7a21 100644
--- a/DnsStatsTest.cpp
+++ b/DnsStatsTest.cpp
@@ -52,6 +52,8 @@
 
 }  // namespace
 
+// TODO: add StatsDataTest to ensure its methods return correct outputs.
+
 class StatsRecordsTest : public ::testing::Test {};
 
 TEST_F(StatsRecordsTest, PushRecord) {
@@ -95,9 +97,9 @@
     void verifyDumpOutput(const std::vector<StatsData>& tcpData,
                           const std::vector<StatsData>& udpData,
                           const std::vector<StatsData>& dotData) {
-        // A simple pattern to capture two matches:
-        //     server address (empty allowed) and its statistics.
-        const std::regex pattern(R"(\s{4,}([0-9a-fA-F:\.]*) ([<(].*[>)]))");
+        // A pattern to capture three matches:
+        //     server address (empty allowed), the statistics, and the score.
+        const std::regex pattern(R"(\s{4,}([0-9a-fA-F:\.]*)[ ]?([<(].*[>)])[ ]?(\S*))");
         std::string dumpString = captureDumpOutput();
 
         const auto check = [&](const std::vector<StatsData>& statsData, const std::string& protocol,
@@ -111,6 +113,7 @@
                 ASSERT_TRUE(std::regex_search(*dumpString, sm, pattern));
                 EXPECT_TRUE(sm[1].str().empty());
                 EXPECT_EQ(sm[2], "<no server>");
+                EXPECT_TRUE(sm[3].str().empty());
                 *dumpString = sm.suffix();
                 return;
             }
@@ -119,6 +122,7 @@
                 ASSERT_TRUE(std::regex_search(*dumpString, sm, pattern));
                 EXPECT_EQ(sm[1], stats.serverSockAddr.ip().toString());
                 EXPECT_FALSE(sm[2].str().empty());
+                EXPECT_FALSE(sm[3].str().empty());
                 *dumpString = sm.suffix();
             }
         };
@@ -379,4 +383,110 @@
     verifyDumpOutput(expectedStats, expectedStats, expectedStats);
 }
 
+TEST_F(DnsStatsTest, GetServers_SortingByLatency) {
+    const IPSockAddr server1 = IPSockAddr::toIPSockAddr("127.0.0.1", 53);
+    const IPSockAddr server2 = IPSockAddr::toIPSockAddr("127.0.0.2", 53);
+    const IPSockAddr server3 = IPSockAddr::toIPSockAddr("2001:db8:cafe:d00d::1", 53);
+    const IPSockAddr server4 = IPSockAddr::toIPSockAddr("2001:db8:cafe:d00d::2", 53);
+
+    // The list is empty before any servers are set up.
+    EXPECT_THAT(mDnsStats.getSortedServers(PROTO_UDP), IsEmpty());
+
+    // Before there are any stats, all scores tie and the sorted list matches the setup order.
+    EXPECT_TRUE(mDnsStats.setServers({server1, server2, server3, server4}, PROTO_UDP));
+    EXPECT_TRUE(mDnsStats.setServers({server1, server2, server3, server4}, PROTO_DOT));
+    EXPECT_THAT(mDnsStats.getSortedServers(PROTO_UDP),
+                testing::ElementsAreArray({server1, server2, server3, server4}));
+
+    // Add a record to server1. The other servers were skipped, so their scores rise.
+    EXPECT_TRUE(mDnsStats.addStats(server1, makeDnsQueryEvent(PROTO_UDP, NS_R_NO_ERROR, 10ms)));
+    EXPECT_THAT(mDnsStats.getSortedServers(PROTO_UDP),
+                testing::ElementsAreArray({server2, server3, server4, server1}));
+
+    // Add a record to server3 with a shorter response time than server1's.
+    EXPECT_TRUE(mDnsStats.addStats(server3, makeDnsQueryEvent(PROTO_UDP, NS_R_NO_ERROR, 5ms)));
+    EXPECT_THAT(mDnsStats.getSortedServers(PROTO_UDP),
+                testing::ElementsAreArray({server2, server4, server3, server1}));
+
+    // Even though server2 has zero response time, select server4 as the first server because it
+    // doesn't have stats yet.
+    EXPECT_TRUE(mDnsStats.addStats(server2, makeDnsQueryEvent(PROTO_UDP, NS_R_NO_ERROR, 0ms)));
+    EXPECT_THAT(mDnsStats.getSortedServers(PROTO_UDP),
+                testing::ElementsAreArray({server4, server2, server3, server1}));
+
+    // Adding a DoT record for server4 does not affect the UDP ordering.
+    EXPECT_TRUE(mDnsStats.addStats(server4, makeDnsQueryEvent(PROTO_DOT, NS_R_NO_ERROR, 10ms)));
+    EXPECT_THAT(mDnsStats.getSortedServers(PROTO_UDP),
+                testing::ElementsAreArray({server4, server2, server3, server1}));
+
+    // Add a record with a very large response time to server4.
+    EXPECT_TRUE(mDnsStats.addStats(server4, makeDnsQueryEvent(PROTO_UDP, NS_R_NO_ERROR, 500000ms)));
+    EXPECT_THAT(mDnsStats.getSortedServers(PROTO_UDP),
+                testing::ElementsAreArray({server2, server3, server1, server4}));
+
+    // Change the list of DNS servers.
+    EXPECT_TRUE(mDnsStats.setServers({server2, server4}, PROTO_UDP));
+    EXPECT_THAT(mDnsStats.getSortedServers(PROTO_UDP),
+                testing::ElementsAreArray({server2, server4}));
+
+    // Adding a record for a server that is not in the list fails, and the sorted list does
+    // not change.
+    EXPECT_FALSE(mDnsStats.addStats(server1, makeDnsQueryEvent(PROTO_UDP, NS_R_NO_ERROR, 10ms)));
+    EXPECT_THAT(mDnsStats.getSortedServers(PROTO_UDP),
+                testing::ElementsAreArray({server2, server4}));
+}
+
+TEST_F(DnsStatsTest, GetServers_DeprioritizingBadServers) {
+    const IPSockAddr server1 = IPSockAddr::toIPSockAddr("127.0.0.1", 53);
+    const IPSockAddr server2 = IPSockAddr::toIPSockAddr("127.0.0.2", 53);
+    const IPSockAddr server3 = IPSockAddr::toIPSockAddr("127.0.0.3", 53);
+    const IPSockAddr server4 = IPSockAddr::toIPSockAddr("127.0.0.4", 53);
+
+    EXPECT_TRUE(mDnsStats.setServers({server1, server2, server3, server4}, PROTO_UDP));
+
+    int server1Counts = 0;
+    int server2Counts = 0;
+    for (int i = 0; i < 5000; i++) {
+        const auto servers = mDnsStats.getSortedServers(PROTO_UDP);
+        EXPECT_EQ(servers.size(), 4U);
+        if (servers[0] == server1) {
+            // server1 is relatively slowly responsive.
+            EXPECT_TRUE(mDnsStats.addStats(servers[0],
+                                           makeDnsQueryEvent(PROTO_UDP, NS_R_NO_ERROR, 200ms)));
+            server1Counts++;
+        } else if (servers[0] == server2) {
+            // server2 is relatively quickly responsive.
+            EXPECT_TRUE(mDnsStats.addStats(servers[0],
+                                           makeDnsQueryEvent(PROTO_UDP, NS_R_NO_ERROR, 100ms)));
+            server2Counts++;
+        } else if (servers[0] == server3) {
+            // server3 always times out.
+            EXPECT_TRUE(mDnsStats.addStats(servers[0],
+                                           makeDnsQueryEvent(PROTO_UDP, NS_R_TIMEOUT, 1000ms)));
+        } else if (servers[0] == server4) {
+            // server4 is unusable.
+            EXPECT_TRUE(mDnsStats.addStats(servers[0],
+                                           makeDnsQueryEvent(PROTO_UDP, NS_R_INTERNAL_ERROR, 1ms)));
+        }
+    }
+
+    const std::vector<StatsData> allStatsData = mDnsStats.getStats(PROTO_UDP);
+    for (const auto& data : allStatsData) {
+        EXPECT_EQ(data.rcodeCounts.size(), 1U);
+        if (data.serverSockAddr == server1 || data.serverSockAddr == server2) {
+            const auto it = data.rcodeCounts.find(NS_R_NO_ERROR);
+            ASSERT_NE(it, data.rcodeCounts.end());
+            EXPECT_GT(server2Counts, 2 * server1Counts);  // At least twice larger.
+        } else if (data.serverSockAddr == server3) {
+            const auto it = data.rcodeCounts.find(NS_R_TIMEOUT);
+            ASSERT_NE(it, data.rcodeCounts.end());
+            EXPECT_LT(it->second, 10);
+        } else if (data.serverSockAddr == server4) {
+            const auto it = data.rcodeCounts.find(NS_R_INTERNAL_ERROR);
+            ASSERT_NE(it, data.rcodeCounts.end());
+            EXPECT_LT(it->second, 10);
+        }
+    }
+}
+
 }  // namespace android::net
diff --git a/Experiments.h b/Experiments.h
index b200373..e267e50 100644
--- a/Experiments.h
+++ b/Experiments.h
@@ -49,7 +49,8 @@
     // TODO: Migrate other experiment flags to here.
     // (retry_count, retransmission_time_interval, dot_connect_timeout_ms)
     static constexpr const char* const kExperimentFlagKeyList[] = {
-            "keep_listening_udp", "parallel_lookup", "parallel_lookup_sleep_time"};
+            "keep_listening_udp", "parallel_lookup", "parallel_lookup_sleep_time",
+            "sort_nameservers"};
     // This value is used in updateInternal as the default value if any flags can't be found.
     static constexpr int kFlagIntDefault = INT_MIN;
     // For testing.
diff --git a/res_cache.cpp b/res_cache.cpp
index 560de93..ffa2929 100644
--- a/res_cache.cpp
+++ b/res_cache.cpp
@@ -60,6 +60,7 @@
 #include <server_configurable_flags/get_flags.h>
 
 #include "DnsStats.h"
+#include "Experiments.h"
 #include "res_comp.h"
 #include "res_debug.h"
 #include "resolv_private.h"
@@ -69,6 +70,7 @@
 using android::base::StringAppendF;
 using android::net::DnsQueryEvent;
 using android::net::DnsStats;
+using android::net::Experiments;
 using android::net::PROTO_DOT;
 using android::net::PROTO_TCP;
 using android::net::PROTO_UDP;
@@ -1682,7 +1684,10 @@
     NetConfig* info = find_netconfig_locked(statp->netid);
     if (info == nullptr) return;
 
-    statp->nsaddrs = info->nameserverSockAddrs;
+    const bool sortNameservers = Experiments::getInstance()->getFlag("sort_nameservers", 0);
+    statp->sort_nameservers = sortNameservers;
+    statp->nsaddrs = sortNameservers ? info->dnsStats.getSortedServers(PROTO_UDP)
+                                     : info->nameserverSockAddrs;
     statp->search_domains = info->search_domains;
     statp->tc_mode = info->tc_mode;
     statp->enforce_dns_uid = info->enforceDnsUid;
diff --git a/res_send.cpp b/res_send.cpp
index 2db8220..fe38994 100644
--- a/res_send.cpp
+++ b/res_send.cpp
@@ -497,6 +497,17 @@
     int usableServersCount = android_net_res_stats_get_usable_servers(
             &params, stats, statp->nameserverCount(), usable_servers);
 
+    if (statp->sort_nameservers) {
+        // It's unnecessary to mark a DNS server as unusable since broken servers will be less
+        // likely to be chosen.
+        for (int i = 0; i < statp->nameserverCount(); i++) {
+            usable_servers[i] = true;
+        }
+        // Keep the count in sync with |usable_servers|; the NO_RETRY selection below relies on it.
+        usableServersCount = statp->nameserverCount();
+    }
+
+    // TODO: Let it always choose the first nameserver when sort_nameservers is enabled.
     if ((flags & ANDROID_RESOLV_NO_RETRY) && usableServersCount > 1) {
         auto hp = reinterpret_cast<const HEADER*>(buf);
 
diff --git a/resolv_private.h b/resolv_private.h
index f1b667e..de127e0 100644
--- a/resolv_private.h
+++ b/resolv_private.h
@@ -113,6 +113,8 @@
     uint32_t netcontext_flags;
     int tc_mode = 0;
     bool enforce_dns_uid = false;
+    bool sort_nameservers = false;              // A flag to indicate whether nsaddrs has been
+                                                // sorted or not.
     // clang-format on
 };
 
diff --git a/tests/resolv_integration_test.cpp b/tests/resolv_integration_test.cpp
index 1ea7d01..f44a0b5 100644
--- a/tests/resolv_integration_test.cpp
+++ b/tests/resolv_integration_test.cpp
@@ -148,6 +148,19 @@
     int internal_errors = 0;
 };
 
+class ScopedSystemProperties {
+  public:
+    ScopedSystemProperties(const std::string& key, const std::string& value) : mStoredKey(key) {
+        mStoredValue = android::base::GetProperty(key, "");
+        android::base::SetProperty(key, value);
+    }
+    ~ScopedSystemProperties() { android::base::SetProperty(mStoredKey, mStoredValue); }
+
+  private:
+    std::string mStoredKey;
+    std::string mStoredValue;
+};
+
 }  // namespace
 
 class ResolverTest : public ::testing::Test {
@@ -250,7 +263,18 @@
         } while (true);
     }
 
-    bool expectStatsFromGetResolverInfo(const std::vector<NameserverStats>& nameserversStats) {
+    enum class StatsCmp { LE, EQ };
+
+    bool expectStatsNotGreaterThan(const std::vector<NameserverStats>& nameserversStats) {
+        return expectStatsFromGetResolverInfo(nameserversStats, StatsCmp::LE);
+    }
+
+    bool expectStatsEqualTo(const std::vector<NameserverStats>& nameserversStats) {
+        return expectStatsFromGetResolverInfo(nameserversStats, StatsCmp::EQ);
+    }
+
+    bool expectStatsFromGetResolverInfo(const std::vector<NameserverStats>& nameserversStats,
+                                        const StatsCmp cmp) {
         std::vector<std::string> res_servers;
         std::vector<std::string> res_domains;
         std::vector<std::string> res_tls_servers;
@@ -289,10 +313,23 @@
 
             // The check excludes rtt_avg, last_sample_time, and usable since they will be obsolete
             // after |res_stats| is retrieved from NetConfig.dnsStats rather than NetConfig.nsstats.
-            EXPECT_EQ(res_stats[index].successes, stats.successes);
-            EXPECT_EQ(res_stats[index].errors, stats.errors);
-            EXPECT_EQ(res_stats[index].timeouts, stats.timeouts);
-            EXPECT_EQ(res_stats[index].internal_errors, stats.internal_errors);
+            switch (cmp) {
+                case StatsCmp::EQ:
+                    EXPECT_EQ(res_stats[index].successes, stats.successes);
+                    EXPECT_EQ(res_stats[index].errors, stats.errors);
+                    EXPECT_EQ(res_stats[index].timeouts, stats.timeouts);
+                    EXPECT_EQ(res_stats[index].internal_errors, stats.internal_errors);
+                    break;
+                case StatsCmp::LE:
+                    EXPECT_LE(res_stats[index].successes, stats.successes);
+                    EXPECT_LE(res_stats[index].errors, stats.errors);
+                    EXPECT_LE(res_stats[index].timeouts, stats.timeouts);
+                    EXPECT_LE(res_stats[index].internal_errors, stats.internal_errors);
+                    break;
+                default:
+                    ADD_FAILURE() << "Unknown comparator " << static_cast<int>(cmp);
+                    return false;
+            }
         }
 
         return true;
@@ -1040,39 +1077,66 @@
 }
 
 TEST_F(ResolverTest, SkipBadServersDueToInternalError) {
+    const std::string kSortNameserversFlag("persist.device_config.netd_native.sort_nameservers");
     constexpr char listen_addr1[] = "fe80::1";
     constexpr char listen_addr2[] = "255.255.255.255";
     constexpr char listen_addr3[] = "127.0.0.3";
-
+    int counter = 0;  // To generate unique hostnames.
     test::DNSResponder dns(listen_addr3);
     ASSERT_TRUE(dns.startServer());
 
-    ResolverParamsParcel parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
-    parcel.servers = {listen_addr1, listen_addr2, listen_addr3};
+    ResolverParamsParcel setupParams = DnsResponderClient::GetDefaultResolverParamsParcel();
+    setupParams.servers = {listen_addr1, listen_addr2, listen_addr3};
+    setupParams.minSamples = 2;  // Recognize bad servers in two attempts when sorting not enabled.
 
-    // Bad servers can be distinguished after two attempts.
-    parcel.minSamples = 2;
-    ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
+    ResolverParamsParcel cleanupParams = DnsResponderClient::GetDefaultResolverParamsParcel();
+    cleanupParams.servers.clear();
+    cleanupParams.tlsServers.clear();
 
-    // Start querying five times.
-    for (int i = 0; i < 5; i++) {
-        std::string hostName = StringPrintf("hello%d.com.", i);
-        dns.addMapping(hostName, ns_type::ns_t_a, "1.2.3.4");
-        const addrinfo hints = {.ai_family = AF_INET, .ai_socktype = SOCK_DGRAM};
-        EXPECT_TRUE(safe_getaddrinfo(hostName.c_str(), nullptr, &hints) != nullptr);
+    for (const auto& sortNameserversFlag : {"" /* unset */, "0" /* off */, "1" /* on */}) {
+        SCOPED_TRACE(fmt::format("sortNameversFlag_{}", sortNameserversFlag));
+        ScopedSystemProperties scopedSystemProperties(kSortNameserversFlag, sortNameserversFlag);
+
+        ASSERT_TRUE(mDnsClient.SetResolversFromParcel(setupParams));
+
+        // Start sending synchronized querying.
+        for (int i = 0; i < 100; i++) {
+            std::string hostName = StringPrintf("hello%d.com.", counter++);
+            dns.addMapping(hostName, ns_type::ns_t_a, "1.2.3.4");
+            const addrinfo hints = {.ai_family = AF_INET, .ai_socktype = SOCK_DGRAM};
+            EXPECT_TRUE(safe_getaddrinfo(hostName.c_str(), nullptr, &hints) != nullptr);
+        }
+
+        const std::vector<NameserverStats> targetStats = {
+                NameserverStats(listen_addr1).setInternalErrors(5),
+                NameserverStats(listen_addr2).setInternalErrors(5),
+                NameserverStats(listen_addr3).setSuccesses(setupParams.maxSamples),
+        };
+        EXPECT_TRUE(expectStatsNotGreaterThan(targetStats));
+
+        // Also verify the number of queries received in the server because res_stats.successes has
+        // a maximum.
+        EXPECT_EQ(dns.queries().size(), 100U);
+
+        // Reset the state.
+        ASSERT_TRUE(mDnsClient.SetResolversFromParcel(cleanupParams));
+        dns.clearQueries();
     }
-
-    const std::vector<NameserverStats> expectedCleartextDnsStats = {
-            NameserverStats(listen_addr1).setInternalErrors(2),
-            NameserverStats(listen_addr2).setInternalErrors(2),
-            NameserverStats(listen_addr3).setSuccesses(5),
-    };
-    EXPECT_TRUE(expectStatsFromGetResolverInfo(expectedCleartextDnsStats));
 }
 
 TEST_F(ResolverTest, SkipBadServersDueToTimeout) {
+    const std::string kSortNameserversFlag("persist.device_config.netd_native.sort_nameservers");
     constexpr char listen_addr1[] = "127.0.0.3";
     constexpr char listen_addr2[] = "127.0.0.4";
+    int counter = 0;  // To generate unique hostnames.
+
+    ResolverParamsParcel setupParams = DnsResponderClient::GetDefaultResolverParamsParcel();
+    setupParams.servers = {listen_addr1, listen_addr2};
+    setupParams.minSamples = 2;  // Recognize bad servers in two attempts when sorting not enabled.
+
+    ResolverParamsParcel cleanupParams = DnsResponderClient::GetDefaultResolverParamsParcel();
+    cleanupParams.servers.clear();
+    cleanupParams.tlsServers.clear();
 
     // Set dns1 non-responsive and dns2 workable.
     test::DNSResponder dns1(listen_addr1, test::kDefaultListenService, static_cast<ns_rcode>(-1));
@@ -1081,29 +1145,38 @@
     ASSERT_TRUE(dns1.startServer());
     ASSERT_TRUE(dns2.startServer());
 
-    ResolverParamsParcel parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
-    parcel.servers = {listen_addr1, listen_addr2};
+    for (const auto& sortNameserversFlag : {"" /* unset */, "0" /* off */, "1" /* on */}) {
+        SCOPED_TRACE(fmt::format("sortNameversFlag_{}", sortNameserversFlag));
+        ScopedSystemProperties scopedSystemProperties(kSortNameserversFlag, sortNameserversFlag);
 
-    // Bad servers can be distinguished after two attempts.
-    parcel.minSamples = 2;
-    ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
+        ASSERT_TRUE(mDnsClient.SetResolversFromParcel(setupParams));
 
-    // Start querying five times.
-    for (int i = 0; i < 5; i++) {
-        std::string hostName = StringPrintf("hello%d.com.", i);
-        dns1.addMapping(hostName, ns_type::ns_t_a, "1.2.3.4");
-        dns2.addMapping(hostName, ns_type::ns_t_a, "1.2.3.5");
-        const addrinfo hints = {.ai_family = AF_INET, .ai_socktype = SOCK_DGRAM};
-        EXPECT_TRUE(safe_getaddrinfo(hostName.c_str(), nullptr, &hints) != nullptr);
+        // Start sending synchronized querying.
+        for (int i = 0; i < 100; i++) {
+            std::string hostName = StringPrintf("hello%d.com.", counter++);
+            dns1.addMapping(hostName, ns_type::ns_t_a, "1.2.3.4");
+            dns2.addMapping(hostName, ns_type::ns_t_a, "1.2.3.5");
+            const addrinfo hints = {.ai_family = AF_INET, .ai_socktype = SOCK_DGRAM};
+            EXPECT_TRUE(safe_getaddrinfo(hostName.c_str(), nullptr, &hints) != nullptr);
+        }
+
+        const std::vector<NameserverStats> targetStats = {
+                NameserverStats(listen_addr1).setTimeouts(5),
+                NameserverStats(listen_addr2).setSuccesses(setupParams.maxSamples),
+        };
+        EXPECT_TRUE(expectStatsNotGreaterThan(targetStats));
+
+        // Also verify the number of queries received in the server because res_stats.successes has
+        // an upper bound.
+        EXPECT_GT(dns1.queries().size(), 0U);
+        EXPECT_LT(dns1.queries().size(), 5U);
+        EXPECT_EQ(dns2.queries().size(), 100U);
+
+        // Reset the state.
+        ASSERT_TRUE(mDnsClient.SetResolversFromParcel(cleanupParams));
+        dns1.clearQueries();
+        dns2.clearQueries();
     }
-
-    const std::vector<NameserverStats> expectedCleartextDnsStats = {
-            NameserverStats(listen_addr1).setTimeouts(2),
-            NameserverStats(listen_addr2).setSuccesses(5),
-    };
-    EXPECT_TRUE(expectStatsFromGetResolverInfo(expectedCleartextDnsStats));
-    EXPECT_EQ(dns1.queries().size(), 2U);
-    EXPECT_EQ(dns2.queries().size(), 5U);
 }
 
 TEST_F(ResolverTest, GetAddrInfoFromCustTable_InvalidInput) {
@@ -1489,7 +1562,7 @@
             NameserverStats(listen_addr2).setErrors(1),
             NameserverStats(listen_addr3).setSuccesses(1),
     };
-    EXPECT_TRUE(expectStatsFromGetResolverInfo(expectedCleartextDnsStats));
+    EXPECT_TRUE(expectStatsEqualTo(expectedCleartextDnsStats));
 }
 
 TEST_F(ResolverTest, AlwaysUseLatestSetupParamsInLookups) {
@@ -1547,7 +1620,7 @@
             NameserverStats(listen_addr2),
             NameserverStats(listen_addr3).setSuccesses(1),
     };
-    EXPECT_TRUE(expectStatsFromGetResolverInfo(expectedCleartextDnsStats));
+    EXPECT_TRUE(expectStatsEqualTo(expectedCleartextDnsStats));
 }
 
 // Test what happens if the specified TLS server is nonexistent.
@@ -4124,28 +4197,9 @@
     }
 }
 
-namespace {
-
-const std::string kDotConnectTimeoutMsFlag(
-        "persist.device_config.netd_native.dot_connect_timeout_ms");
-
-class ScopedSystemProperties {
-  public:
-    explicit ScopedSystemProperties(const std::string& key, const std::string& value)
-        : mStoredKey(key) {
-        mStoredValue = android::base::GetProperty(key, "");
-        android::base::SetProperty(key, value);
-    }
-    ~ScopedSystemProperties() { android::base::SetProperty(mStoredKey, mStoredValue); }
-
-  private:
-    std::string mStoredKey;
-    std::string mStoredValue;
-};
-
-}  // namespace
-
 TEST_F(ResolverTest, ConnectTlsServerTimeout) {
+    const std::string kDotConnectTimeoutMsFlag(
+            "persist.device_config.netd_native.dot_connect_timeout_ms");
     constexpr int expectedTimeout = 1000;
     constexpr char hostname1[] = "query1.example.com.";
     constexpr char hostname2[] = "query2.example.com.";
@@ -4390,6 +4444,11 @@
         dns.clearQueries();
         dns2.clearQueries();
         ASSERT_TRUE(mDnsClient.resolvService()->flushNetworkCache(TEST_NETID).isOk());
+
+        // Clear the stats to make the resolver always choose the same server for the first query.
+        parcel.servers.clear();
+        parcel.tlsServers.clear();
+        ASSERT_EQ(mDnsClient.resolvService()->setResolverConfiguration(parcel).isOk(), config.ret);
     }
 }
 
@@ -4423,7 +4482,7 @@
             NameserverStats(unusable_listen_addr).setInternalErrors(1),
             NameserverStats(listen_addr).setSuccesses(1),
     };
-    EXPECT_TRUE(expectStatsFromGetResolverInfo(expectedCleartextDnsStats));
+    EXPECT_TRUE(expectStatsEqualTo(expectedCleartextDnsStats));
     EXPECT_EQ(GetNumQueries(dns, hostname), 1U);
 
     // The stats is supposed to remain as long as the list of cleartext DNS servers is unchanged.
@@ -4453,12 +4512,12 @@
         parcel.tlsServers = config.tlsServers;
         parcel.tlsName = config.tlsName;
         repeatedSetResolversFromParcel(parcel);
-        EXPECT_TRUE(expectStatsFromGetResolverInfo(expectedCleartextDnsStats));
+        EXPECT_TRUE(expectStatsEqualTo(expectedCleartextDnsStats));
 
         // The stats remains when the list of search domains changes.
         parcel.domains.push_back("tmp.domains");
         repeatedSetResolversFromParcel(parcel);
-        EXPECT_TRUE(expectStatsFromGetResolverInfo(expectedCleartextDnsStats));
+        EXPECT_TRUE(expectStatsEqualTo(expectedCleartextDnsStats));
 
         // The stats remains when the parameters change (except maxSamples).
         parcel.sampleValiditySeconds++;
@@ -4467,7 +4526,7 @@
         parcel.baseTimeoutMsec++;
         parcel.retryCount++;
         repeatedSetResolversFromParcel(parcel);
-        EXPECT_TRUE(expectStatsFromGetResolverInfo(expectedCleartextDnsStats));
+        EXPECT_TRUE(expectStatsEqualTo(expectedCleartextDnsStats));
     }
 
     // The cache remains.
@@ -5116,7 +5175,7 @@
             NameserverStats(listen_addr1),
             NameserverStats(listen_addr2),
     };
-    expectStatsFromGetResolverInfo(expectedEmptyDnsStats);
+    expectStatsEqualTo(expectedEmptyDnsStats);
     EXPECT_EQ(dns1.queries().size(), 0U);
     EXPECT_EQ(dns2.queries().size(), 0U);
 }