blob: b6d0c8fef465a443e60b03f72bea4e912dc4247f [file] [log] [blame]
/*
* Copyright (C) 2020 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <netdutils/NetNativeTestBase.h>
#include <resolv_stats_test_utils.h>
#include "PrivateDnsConfiguration.h"
#include "resolv_cache.h"
#include "tests/dns_responder/dns_responder.h"
#include "tests/dns_responder/dns_tls_frontend.h"
#include "tests/resolv_test_utils.h"
namespace android::net {
using namespace std::chrono_literals;
class PrivateDnsConfigurationTest : public NetNativeTestBase {
public:
using ServerIdentity = PrivateDnsConfiguration::ServerIdentity;
static void SetUpTestSuite() {
// stopServer() will be called in their destructor.
ASSERT_TRUE(tls1.startServer());
ASSERT_TRUE(tls2.startServer());
ASSERT_TRUE(backend.startServer());
ASSERT_TRUE(backend1ForUdpProbe.startServer());
ASSERT_TRUE(backend2ForUdpProbe.startServer());
}
void SetUp() {
mPdc.setObserver(&mObserver);
mPdc.mBackoffBuilder.withInitialRetransmissionTime(std::chrono::seconds(1))
.withMaximumRetransmissionTime(std::chrono::seconds(1));
// The default and sole action when the observer is notified of onValidationStateUpdate.
// Don't override the action. In other words, don't use WillOnce() or WillRepeatedly()
// when mObserver.onValidationStateUpdate is expected to be called, like:
//
// EXPECT_CALL(mObserver, onValidationStateUpdate).WillOnce(Return());
//
// This is to ensure that tests can monitor how many validation threads are running. Tests
// must wait until every validation thread finishes.
ON_CALL(mObserver, onValidationStateUpdate)
.WillByDefault([&](const std::string& server, Validation validation, uint32_t) {
if (validation == Validation::in_process) {
std::lock_guard guard(mObserver.lock);
auto it = mObserver.serverStateMap.find(server);
if (it == mObserver.serverStateMap.end() ||
it->second != Validation::in_process) {
// Increment runningThreads only when receive the first in_process
// notification. The rest of the continuous in_process notifications
// are due to probe retry which runs on the same thread.
// TODO: consider adding onValidationThreadStart() and
// onValidationThreadEnd() callbacks.
mObserver.runningThreads++;
}
} else if (validation == Validation::success ||
validation == Validation::fail) {
mObserver.runningThreads--;
}
std::lock_guard guard(mObserver.lock);
mObserver.serverStateMap[server] = validation;
});
// Create a NetConfig for stats.
EXPECT_EQ(0, resolv_create_cache_for_net(kNetId));
}
void TearDown() { resolv_delete_cache_for_net(kNetId); }
protected:
class MockObserver : public PrivateDnsValidationObserver {
public:
MOCK_METHOD(void, onValidationStateUpdate,
(const std::string& serverIp, Validation validation, uint32_t netId),
(override));
std::map<std::string, Validation> getServerStateMap() const {
std::lock_guard guard(lock);
return serverStateMap;
}
void removeFromServerStateMap(const std::string& server) {
std::lock_guard guard(lock);
if (const auto it = serverStateMap.find(server); it != serverStateMap.end())
serverStateMap.erase(it);
}
// The current number of validation threads running.
std::atomic<int> runningThreads = 0;
mutable std::mutex lock;
std::map<std::string, Validation> serverStateMap GUARDED_BY(lock);
};
void expectPrivateDnsStatus(PrivateDnsMode mode) {
// Use PollForCondition because mObserver is notified asynchronously.
EXPECT_TRUE(PollForCondition([&]() { return checkPrivateDnsStatus(mode); }));
}
bool checkPrivateDnsStatus(PrivateDnsMode mode) {
const PrivateDnsStatus status = mPdc.getStatus(kNetId);
if (status.mode != mode) return false;
std::map<std::string, Validation> serverStateMap;
for (const auto& [server, validation] : status.dotServersMap) {
serverStateMap[ToString(&server.ss)] = validation;
}
return (serverStateMap == mObserver.getServerStateMap());
}
bool hasPrivateDnsServer(const ServerIdentity& identity, unsigned netId) {
return mPdc.getDotServer(identity, netId).ok();
}
static constexpr uint32_t kNetId = 30;
static constexpr uint32_t kMark = 30;
static constexpr char kBackend[] = "127.0.2.1";
static constexpr char kServer1[] = "127.0.2.2";
static constexpr char kServer2[] = "127.0.2.3";
MockObserver mObserver;
PrivateDnsConfiguration mPdc;
// TODO: Because incorrect CAs result in validation failed in strict mode, have
// PrivateDnsConfiguration run mocked code rather than DnsTlsTransport::validate().
inline static test::DnsTlsFrontend tls1{kServer1, "853", kBackend, "53"};
inline static test::DnsTlsFrontend tls2{kServer2, "853", kBackend, "53"};
inline static test::DNSResponder backend{kBackend, "53"};
inline static test::DNSResponder backend1ForUdpProbe{kServer1, "53"};
inline static test::DNSResponder backend2ForUdpProbe{kServer2, "53"};
};
// Configuring a single reachable DoT server must be reported as in_process followed by success,
// leave the network in opportunistic mode, and leave no validation thread behind.
TEST_F(PrivateDnsConfigurationTest, ValidationSuccess) {
    const auto noValidationThreadLeft = [&]() { return mObserver.runningThreads == 0; };

    testing::InSequence inOrder;
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));

    const auto setResult = mPdc.set(kNetId, kMark, {}, {kServer1}, {}, {});
    EXPECT_EQ(setResult, 0);
    expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

    ASSERT_TRUE(PollForCondition(noValidationThreadLeft));
}
// With the shared backend stopped, validating kServer1 must be reported as in_process followed
// by fail while the network stays in opportunistic mode. The backend is restarted at the end.
TEST_F(PrivateDnsConfigurationTest, ValidationFail_Opportunistic) {
    const auto noValidationThreadLeft = [&]() { return mObserver.runningThreads == 0; };

    ASSERT_TRUE(backend.stopServer());

    testing::InSequence inOrder;
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::fail, kNetId));

    const auto setResult = mPdc.set(kNetId, kMark, {}, {kServer1}, {}, {});
    EXPECT_EQ(setResult, 0);
    expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

    // Strictly wait for all of the validation finish; otherwise, the test can crash somehow.
    ASSERT_TRUE(PollForCondition(noValidationThreadLeft));

    ASSERT_TRUE(backend.startServer());
}
// A validated server is re-validated on request: while the backend is down the probe is retried
// (a second in_process notification), and once the backend is back the validation succeeds.
TEST_F(PrivateDnsConfigurationTest, Revalidation_Opportunistic) {
    const DnsTlsServer dotServer(netdutils::IPSockAddr::toIPSockAddr(kServer1, 853));
    const auto noValidationThreadLeft = [&]() { return mObserver.runningThreads == 0; };

    // Step 1: Set up and wait for validation complete.
    testing::InSequence inOrder;
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));

    const auto setResult = mPdc.set(kNetId, kMark, {}, {kServer1}, {}, {});
    EXPECT_EQ(setResult, 0);
    expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
    ASSERT_TRUE(PollForCondition(noValidationThreadLeft));

    // Step 2: Simulate the DNS is temporarily broken, and then request a validation.
    // Expect the validation to run as follows:
    //   1. DnsResolver notifies of Validation::in_process when the validation is about to run.
    //   2. The first probing fails. DnsResolver notifies of Validation::in_process.
    //   3. One second later, the second probing begins and succeeds. DnsResolver notifies of
    //      Validation::success.
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId))
            .Times(2);
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));

    // Brings the backend back one second after this thread is created.
    std::thread backendRestarter([] {
        std::this_thread::sleep_for(1s);
        backend.startServer();
    });
    backend.stopServer();

    const ServerIdentity identity(dotServer);
    EXPECT_TRUE(mPdc.requestDotValidation(kNetId, identity, kMark).ok());
    backendRestarter.join();

    expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
    ASSERT_TRUE(PollForCondition(noValidationThreadLeft));
}
// While validations are held in progress, re-applying server lists must not start duplicate
// validations, and invalid arguments must leave the status untouched. After the backend answers
// again, the server that is no longer configured ends with fail and the configured one with
// success.
TEST_F(PrivateDnsConfigurationTest, ValidationBlock) {
    // Polls until the number of running validation threads equals |expected|.
    const auto runningThreadsBecome = [&](int expected) {
        return PollForCondition([&]() { return mObserver.runningThreads == expected; });
    };

    // Defer the backend's responses; the validations below are expected to stay in progress.
    backend.setDeferredResp(true);

    // onValidationStateUpdate() is called in sequence.
    {
        testing::InSequence inOrder;

        // First validation: kServer1.
        EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
        EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {kServer1}, {}, {}), 0);
        ASSERT_TRUE(runningThreadsBecome(1));
        expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

        // Second validation: kServer2 replaces kServer1 in the configuration.
        EXPECT_CALL(mObserver, onValidationStateUpdate(kServer2, Validation::in_process, kNetId));
        EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {kServer2}, {}, {}), 0);
        ASSERT_TRUE(runningThreadsBecome(2));
        mObserver.removeFromServerStateMap(kServer1);
        expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

        // No duplicate validation as long as not in OFF mode; otherwise, an unexpected
        // onValidationStateUpdate() will be caught.
        EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {kServer1}, {}, {}), 0);
        EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {kServer1, kServer2}, {}, {}), 0);
        EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {kServer2}, {}, {}), 0);
        expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

        // The status keeps unchanged if pass invalid arguments.
        EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {"invalid_addr"}, {}, {}), -EINVAL);
        expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
    }

    // The update for |kServer1| will be Validation::fail because |kServer1| is not an expected
    // server for the network.
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::fail, kNetId));
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer2, Validation::success, kNetId));
    backend.setDeferredResp(false);
    ASSERT_TRUE(runningThreadsBecome(0));

    // kServer1 is not a present server and thus should not be available from
    // PrivateDnsConfiguration::getStatus().
    mObserver.removeFromServerStateMap(kServer1);
    expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
}
// A validation that is still in progress when private DNS is switched off (empty server list)
// or when the network is cleared must end with Validation::fail, and the reported mode must
// become OFF.
TEST_F(PrivateDnsConfigurationTest, Validation_NetworkDestroyedOrOffMode) {
    // Polls until the number of running validation threads equals |expected|.
    const auto runningThreadsBecome = [&](int expected) {
        return PollForCondition([&]() { return mObserver.runningThreads == expected; });
    };

    for (const std::string_view config : {"OFF", "NETWORK_DESTROYED"}) {
        SCOPED_TRACE(config);
        backend.setDeferredResp(true);

        testing::InSequence inOrder;
        EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
        EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {kServer1}, {}, {}), 0);
        ASSERT_TRUE(runningThreadsBecome(1));
        expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

        // Drop the configuration while the validation is still pending.
        const bool turnOff = (config == "OFF");
        const bool destroyNetwork = (config == "NETWORK_DESTROYED");
        if (turnOff) {
            EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {}, {}, {}), 0);
        } else if (destroyNetwork) {
            mPdc.clear(kNetId);
        }

        EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::fail, kNetId));
        backend.setDeferredResp(false);
        ASSERT_TRUE(runningThreadsBecome(0));

        // The observer still remembers kServer1; forget it before comparing with the status.
        mObserver.removeFromServerStateMap(kServer1);
        expectPrivateDnsStatus(PrivateDnsMode::OFF);
    }
}
// Neither an invalid server address nor an empty server list may notify the observer; in both
// cases the network must report OFF mode with no DoT servers.
TEST_F(PrivateDnsConfigurationTest, NoValidation) {
    // mObserver is a plain (naggy) mock, so an uninteresting call to onValidationStateUpdate()
    // would only print a warning. State the expectation explicitly so that any notification
    // fails the test.
    EXPECT_CALL(mObserver, onValidationStateUpdate).Times(0);

    // Checks that kNetId reports OFF mode with an empty DoT server map.
    const auto expectStatus = [&]() {
        const PrivateDnsStatus status = mPdc.getStatus(kNetId);
        EXPECT_EQ(status.mode, PrivateDnsMode::OFF);
        EXPECT_THAT(status.dotServersMap, testing::IsEmpty());
    };

    // An unparsable server address is rejected.
    EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {"invalid_addr"}, {}, {}), -EINVAL);
    expectStatus();

    // An empty encrypted-server list is accepted.
    EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {}, {}, {}), 0);
    expectStatus();
}
// Two ServerIdentity objects compare equal only when both the socket address (IP and port) and
// the provider hostname of their servers match.
TEST_F(PrivateDnsConfigurationTest, ServerIdentity_Comparison) {
    DnsTlsServer reference(netdutils::IPSockAddr::toIPSockAddr("127.0.0.1", 853));
    reference.name = "dns.example.com";

    // Different socket address.
    DnsTlsServer variant = reference;
    EXPECT_EQ(ServerIdentity(reference), ServerIdentity(variant));

    // Same IP, different port.
    variant.ss = netdutils::IPSockAddr::toIPSockAddr("127.0.0.1", 5353);
    EXPECT_NE(ServerIdentity(reference), ServerIdentity(variant));

    // Same port, different IP.
    variant.ss = netdutils::IPSockAddr::toIPSockAddr("127.0.0.2", 853);
    EXPECT_NE(ServerIdentity(reference), ServerIdentity(variant));

    // Different provider hostname.
    variant = reference;
    EXPECT_EQ(ServerIdentity(reference), ServerIdentity(variant));

    variant.name = "other.example.com";
    EXPECT_NE(ServerIdentity(reference), ServerIdentity(variant));

    // An empty hostname is also a different identity.
    variant.name = "";
    EXPECT_NE(ServerIdentity(reference), ServerIdentity(variant));
}
// requestDotValidation() is accepted only for a server whose previous validation succeeded.
// It is denied while a validation is in progress, after a failed validation, when the same
// request is repeated, and when the mark or the network ID does not match.
TEST_F(PrivateDnsConfigurationTest, RequestValidation) {
    const DnsTlsServer dotServer(netdutils::IPSockAddr::toIPSockAddr(kServer1, 853));
    const ServerIdentity identity(dotServer);

    // Polls until the number of running validation threads equals |expected|.
    const auto runningThreadsBecome = [&](int expected) {
        return PollForCondition([&]() { return mObserver.runningThreads == expected; });
    };

    testing::InSequence inOrder;
    for (const std::string_view config : {"SUCCESS", "IN_PROGRESS", "FAIL"}) {
        SCOPED_TRACE(config);
        const bool wantSuccess = (config == "SUCCESS");
        const bool wantInProgress = (config == "IN_PROGRESS");
        const bool wantFail = (config == "FAIL");

        // Drive the initial validation into the state named by |config|.
        EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
        if (wantSuccess) {
            EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
        } else if (wantInProgress) {
            backend.setDeferredResp(true);
        } else {
            // config = "FAIL"
            ASSERT_TRUE(backend.stopServer());
            EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::fail, kNetId));
        }
        EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {kServer1}, {}, {}), 0);
        expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

        // Wait until the validation state is transitioned.
        ASSERT_TRUE(runningThreadsBecome(wantInProgress ? 1 : 0));

        // Request a validation from that state.
        if (wantSuccess) {
            EXPECT_CALL(mObserver,
                        onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
            EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
            EXPECT_TRUE(mPdc.requestDotValidation(kNetId, identity, kMark).ok());
        } else if (wantInProgress) {
            EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
            EXPECT_FALSE(mPdc.requestDotValidation(kNetId, identity, kMark).ok());
        } else if (wantFail) {
            EXPECT_FALSE(mPdc.requestDotValidation(kNetId, identity, kMark).ok());
        }

        // Resending the same request or requesting nonexistent servers are denied.
        EXPECT_FALSE(mPdc.requestDotValidation(kNetId, identity, kMark).ok());
        EXPECT_FALSE(mPdc.requestDotValidation(kNetId, identity, kMark + 1).ok());
        EXPECT_FALSE(mPdc.requestDotValidation(kNetId + 1, identity, kMark).ok());

        // Reset the test state.
        backend.setDeferredResp(false);
        backend.startServer();

        // Ensure the status of mObserver is synced.
        expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
        ASSERT_TRUE(runningThreadsBecome(0));
        mPdc.clear(kNetId);
    }
}
// getDotServer() finds a server only after it has been configured, and only on the network it
// was configured for.
TEST_F(PrivateDnsConfigurationTest, GetPrivateDns) {
    const DnsTlsServer dotServer1(netdutils::IPSockAddr::toIPSockAddr(kServer1, 853));
    const DnsTlsServer dotServer2(netdutils::IPSockAddr::toIPSockAddr(kServer2, 853));
    const auto noValidationThreadLeft = [&]() { return mObserver.runningThreads == 0; };

    // Nothing is configured yet.
    EXPECT_FALSE(hasPrivateDnsServer(ServerIdentity(dotServer1), kNetId));
    EXPECT_FALSE(hasPrivateDnsServer(ServerIdentity(dotServer2), kNetId));

    // Suppress the warning.
    EXPECT_CALL(mObserver, onValidationStateUpdate).Times(2);

    const auto setResult = mPdc.set(kNetId, kMark, {}, {kServer1}, {}, {});
    EXPECT_EQ(setResult, 0);
    expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

    // Only kServer1 on kNetId is known now.
    EXPECT_TRUE(hasPrivateDnsServer(ServerIdentity(dotServer1), kNetId));
    EXPECT_FALSE(hasPrivateDnsServer(ServerIdentity(dotServer2), kNetId));
    EXPECT_FALSE(hasPrivateDnsServer(ServerIdentity(dotServer1), kNetId + 1));

    ASSERT_TRUE(PollForCondition(noValidationThreadLeft));
}
// Tests that getStatusForMetrics() returns the correct data, both while servers are configured
// and after the network has been cleared. tls2 is stopped for the duration of the test and
// started again at the end.
TEST_F(PrivateDnsConfigurationTest, GetStatusForMetrics) {
    tls2.stopServer();
    const DnsTlsServer dotServer1(netdutils::IPSockAddr::toIPSockAddr(kServer1, 853));
    const DnsTlsServer dotServer2(netdutils::IPSockAddr::toIPSockAddr(kServer2, 853));
    const auto noValidationThreadLeft = [&]() { return mObserver.runningThreads == 0; };

    // Suppress the warning.
    EXPECT_CALL(mObserver, onValidationStateUpdate).Times(4);

    // Set 1 unencrypted server and 2 encrypted servers (one will pass DoT validation; the other
    // will fail. Both of them don't support DoH).
    EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer2}, {kServer1, kServer2}, {}, {}), 0);
    ASSERT_TRUE(PollForCondition(noValidationThreadLeft));

    // Get the metric before call clear().
    NetworkDnsServerSupportReported event = mPdc.getStatusForMetrics(kNetId);
    NetworkDnsServerSupportReported expectedEvent;

    // Appends one server entry to the expected event.
    const auto expectServer = [&expectedEvent](auto protocol, int index, bool validated) {
        Server* entry = expectedEvent.mutable_servers()->add_server();
        entry->set_protocol(protocol);
        entry->set_index(index);
        entry->set_validated(validated);
    };

    // It's NT_UNKNOWN because this test didn't call resolv_set_nameservers() to set
    // the network type.
    expectedEvent.set_network_type(NetworkType::NT_UNKNOWN);
    expectedEvent.set_private_dns_modes(PrivateDnsModes::PDM_OPPORTUNISTIC);
    expectServer(PROTO_UDP, 0, false);  // kServer2
    expectServer(PROTO_DOT, 0, true);   // kServer1
    expectServer(PROTO_DOT, 1, false);  // kServer2
    expectServer(PROTO_DOH, 0, false);  // kServer1
    expectServer(PROTO_DOH, 1, false);  // kServer2
    EXPECT_THAT(event, NetworkDnsServerSupportEq(expectedEvent));

    // Get the metric after call clear().
    mPdc.clear(kNetId);
    event = mPdc.getStatusForMetrics(kNetId);
    expectedEvent.Clear();
    expectedEvent.set_network_type(NetworkType::NT_UNKNOWN);
    expectedEvent.set_private_dns_modes(PrivateDnsModes::PDM_UNKNOWN);
    EXPECT_THAT(event, NetworkDnsServerSupportEq(expectedEvent));

    tls2.startServer();
}
// TODO: add ValidationFail_Strict test.
} // namespace android::net