// blob: 5c8434bd809f7295fe0834a65a64647a1dfdce44 [file] [log] [blame]
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "discovery/mdns/mdns_querier.h"

#include <algorithm>
#include <memory>
#include <utility>

#include "discovery/common/config.h"
#include "discovery/common/testing/mock_reporting_client.h"
#include "discovery/mdns/mdns_random.h"
#include "discovery/mdns/mdns_receiver.h"
#include "discovery/mdns/mdns_record_changed_callback.h"
#include "discovery/mdns/mdns_sender.h"
#include "discovery/mdns/mdns_trackers.h"
#include "discovery/mdns/mdns_writer.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "platform/base/udp_packet.h"
#include "platform/test/fake_clock.h"
#include "platform/test/fake_task_runner.h"
#include "platform/test/mock_udp_socket.h"
namespace openscreen {
namespace discovery {
using testing::_;
using testing::Args;
using testing::Invoke;
using testing::Return;
using testing::StrictMock;
using testing::WithArgs;
// gMock action that checks the record passed as the first argument against
// |expected| on NAME, CLASS, TYPE and RDATA only; no other field is compared.
// It always returns an empty list of pending query changes.
ACTION_P(PartialCompareRecords, expected) {
  const MdnsRecord& received = arg0;
  EXPECT_TRUE(received.name() == expected.name());
  EXPECT_TRUE(received.dns_class() == expected.dns_class());
  EXPECT_TRUE(received.dns_type() == expected.dns_type());
  EXPECT_TRUE(received.rdata() == expected.rdata());
  return std::vector<PendingQueryChange>();
}
// gMock implementation of MdnsRecordChangedCallback, letting tests set
// expectations on the record-changed notifications handed to a callback.
class MockRecordChangedCallback : public MdnsRecordChangedCallback {
 public:
  MOCK_METHOD(std::vector<PendingQueryChange>,
              OnRecordChanged,
              (const MdnsRecord& record, RecordChangedEvent event),
              (override));
};
// Test fixture for MdnsQuerier. It owns a fake clock and task runner, a mock
// UDP socket, a sender and a started receiver, plus a set of pre-built records
// that the individual tests feed through the receiver.
class MdnsQuerierTest : public testing::Test {
 public:
  MdnsQuerierTest()
      : clock_(Clock::now()),
        task_runner_(&clock_),
        sender_(&socket_),
        receiver_(config_),
        record0_created_(DomainName{"testing", "local"},
                         DnsType::kA,
                         DnsClass::kIN,
                         RecordType::kUnique,
                         std::chrono::seconds(120),
                         ARecordRdata(IPAddress{172, 0, 0, 1})),
        record0_updated_(DomainName{"testing", "local"},
                         DnsType::kA,
                         DnsClass::kIN,
                         RecordType::kUnique,
                         std::chrono::seconds(120),
                         ARecordRdata(IPAddress{172, 0, 0, 2})),
        record0_deleted_(DomainName{"testing", "local"},
                         DnsType::kA,
                         DnsClass::kIN,
                         RecordType::kUnique,
                         std::chrono::seconds(0),  // a goodbye record
                         ARecordRdata(IPAddress{172, 0, 0, 2})),
        record1_created_(DomainName{"poking", "local"},
                         DnsType::kA,
                         DnsClass::kIN,
                         RecordType::kShared,
                         std::chrono::seconds(120),
                         ARecordRdata(IPAddress{192, 168, 0, 1})),
        record1_deleted_(DomainName{"poking", "local"},
                         DnsType::kA,
                         DnsClass::kIN,
                         RecordType::kShared,
                         std::chrono::seconds(0),  // a goodbye record
                         ARecordRdata(IPAddress{192, 168, 0, 1})),
        record2_created_(DomainName{"testing", "local"},
                         DnsType::kAAAA,
                         DnsClass::kIN,
                         RecordType::kUnique,
                         std::chrono::seconds(120),
                         AAAARecordRdata(IPAddress{1, 2, 3, 4, 5, 6, 7, 8})),
        nsec_record_created_(
            DomainName{"testing", "local"},
            DnsType::kNSEC,
            DnsClass::kIN,
            RecordType::kUnique,
            std::chrono::seconds(120),
            NsecRecordRdata(DomainName{"testing", "local"}, DnsType::kA)) {
    receiver_.Start();
  }

  // Builds a new querier wired to this fixture's sender, receiver, task
  // runner, fake clock, random source, reporting client and config.
  std::unique_ptr<MdnsQuerier> CreateQuerier() {
    return std::make_unique<MdnsQuerier>(&sender_, &receiver_, &task_runner_,
                                         &FakeClock::now, &random_,
                                         &reporting_client_, config_);
  }

 protected:
  // Creates an NSEC record for |name| covering |types|. All other fields
  // (type, class, record type, TTL) are copied from |nsec_record_created_|.
  template <typename... DnsTypes>
  MdnsRecord CreateNsec(DomainName name, DnsTypes... types) {
    NsecRecordRdata rdata(name, types...);
    return MdnsRecord(std::move(name), nsec_record_created_.dns_type(),
                      nsec_record_created_.dns_class(),
                      nsec_record_created_.record_type(),
                      nsec_record_created_.ttl(), std::move(rdata));
  }

  // Serializes a response message carrying |records| as answers and
  // |additional_records| as additional records into a UDP packet. The packet
  // is shrunk to the number of bytes actually written.
  UdpPacket CreatePacketWithRecords(
      const std::vector<MdnsRecord::ConstRef>& records,
      const std::vector<MdnsRecord::ConstRef>& additional_records) {
    MdnsMessage message(CreateMessageId(), MessageType::Response);
    for (const MdnsRecord& record : records) {
      message.AddAnswer(record);
    }
    for (const MdnsRecord& additional_record : additional_records) {
      message.AddAdditionalRecord(additional_record);
    }
    UdpPacket packet(message.MaxWireSize());
    MdnsWriter writer(packet.data(), packet.size());
    EXPECT_TRUE(writer.Write(message));
    packet.resize(writer.offset());
    return packet;
  }

  // Convenience overload: answers only, no additional records.
  UdpPacket CreatePacketWithRecords(
      const std::vector<MdnsRecord::ConstRef>& records) {
    return CreatePacketWithRecords(records, {});
  }

  // Convenience overload: a packet with |record| as its single answer.
  UdpPacket CreatePacketWithRecord(const MdnsRecord& record) {
    return CreatePacketWithRecords({MdnsRecord::ConstRef(record)});
  }

  // NSEC records are never exposed to outside callers, so the below methods are
  // necessary to validate that they are functioning as expected.

  // Returns true if |querier| tracks a record for |record|'s name and class
  // under |type| whose rdata and TTL both equal those of |record|.
  bool ContainsRecord(MdnsQuerier* querier,
                      const MdnsRecord& record,
                      DnsType type = DnsType::kANY) {
    auto record_trackers =
        querier->records_.Find(record.name(), type, record.dns_class());
    return std::any_of(record_trackers.begin(), record_trackers.end(),
                       [&record](const MdnsRecordTracker& tracker) {
                         return tracker.rdata() == record.rdata() &&
                                tracker.ttl() == record.ttl();
                       });
  }

  // Returns the number of records currently tracked by |querier|.
  size_t RecordCount(MdnsQuerier* querier) { return querier->records_.size(); }

  Config config_;
  FakeClock clock_;
  FakeTaskRunner task_runner_;
  testing::NiceMock<MockUdpSocket> socket_;
  MdnsSender sender_;
  MdnsReceiver receiver_;
  MdnsRandom random_;
  StrictMock<MockReportingClient> reporting_client_;

  // A record for "testing.local" in its created, updated (different address)
  // and goodbye (TTL 0) forms.
  MdnsRecord record0_created_;
  MdnsRecord record0_updated_;
  MdnsRecord record0_deleted_;
  // Shared A record for "poking.local" and its goodbye form.
  MdnsRecord record1_created_;
  MdnsRecord record1_deleted_;
  // AAAA record for "testing.local".
  MdnsRecord record2_created_;
  // NSEC record for "testing.local" covering type A.
  MdnsRecord nsec_record_created_;
};
// Walks a unique record through its lifecycle and expects exactly one
// kCreated, one kUpdated and one kExpired notification.
TEST_F(MdnsQuerierTest, UniqueRecordCreatedUpdatedDeleted) {
  const DomainName queried_name{"testing", "local"};
  std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
  MockRecordChangedCallback callback;
  querier->StartQuery(queried_name, DnsType::kA, DnsClass::kIN, &callback);

  EXPECT_CALL(callback, OnRecordChanged(_, RecordChangedEvent::kCreated))
      .WillOnce(WithArgs<0>(PartialCompareRecords(record0_created_)));
  EXPECT_CALL(callback, OnRecordChanged(_, RecordChangedEvent::kUpdated))
      .WillOnce(WithArgs<0>(PartialCompareRecords(record0_updated_)));
  EXPECT_CALL(callback, OnRecordChanged(_, RecordChangedEvent::kExpired))
      .WillOnce(WithArgs<0>(PartialCompareRecords(record0_deleted_)));

  // The created record is delivered twice on purpose: the repeat should only
  // reset the TTL and must not produce another callback.
  const MdnsRecord* const incoming[] = {&record0_created_, &record0_created_,
                                        &record0_updated_, &record0_deleted_};
  for (const MdnsRecord* record : incoming) {
    receiver_.OnRead(&socket_, CreatePacketWithRecord(*record));
  }

  // Expiration is delayed by 1 second as per RFC 6762, so move the fake clock
  // forward to let it happen.
  clock_.Advance(std::chrono::seconds(1));
}
// A query for type ANY / class ANY is notified about an A record for the
// queried name, both when it is created and when its goodbye record arrives.
TEST_F(MdnsQuerierTest, WildcardQuery) {
  const DomainName queried_name{"poking", "local"};
  std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
  MockRecordChangedCallback callback;
  querier->StartQuery(queried_name, DnsType::kANY, DnsClass::kANY, &callback);

  EXPECT_CALL(callback, OnRecordChanged(_, RecordChangedEvent::kCreated))
      .WillOnce(WithArgs<0>(PartialCompareRecords(record1_created_)));
  EXPECT_CALL(callback, OnRecordChanged(_, RecordChangedEvent::kExpired))
      .WillOnce(WithArgs<0>(PartialCompareRecords(record1_deleted_)));

  const MdnsRecord* const incoming[] = {&record1_created_, &record1_deleted_};
  for (const MdnsRecord* record : incoming) {
    receiver_.OnRead(&socket_, CreatePacketWithRecord(*record));
  }

  // Expiration is delayed by 1 second as per RFC 6762, so move the fake clock
  // forward to let it happen.
  clock_.Advance(std::chrono::seconds(1));
}
// A shared record yields one kCreated and, after its goodbye record, one
// kExpired notification.
TEST_F(MdnsQuerierTest, SharedRecordCreatedDeleted) {
  const DomainName queried_name{"poking", "local"};
  std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
  MockRecordChangedCallback callback;
  querier->StartQuery(queried_name, DnsType::kA, DnsClass::kIN, &callback);

  EXPECT_CALL(callback, OnRecordChanged(_, RecordChangedEvent::kCreated))
      .WillOnce(WithArgs<0>(PartialCompareRecords(record1_created_)));
  EXPECT_CALL(callback, OnRecordChanged(_, RecordChangedEvent::kExpired))
      .WillOnce(WithArgs<0>(PartialCompareRecords(record1_deleted_)));

  // The created record is delivered twice on purpose: the repeat should only
  // reset the TTL and must not produce another callback.
  const MdnsRecord* const incoming[] = {&record1_created_, &record1_created_,
                                        &record1_deleted_};
  for (const MdnsRecord* record : incoming) {
    receiver_.OnRead(&socket_, CreatePacketWithRecord(*record));
  }

  // Expiration is delayed by 1 second as per RFC 6762, so move the fake clock
  // forward to let it happen.
  clock_.Advance(std::chrono::seconds(1));
}
// Starting the identical query twice with the same callback must still result
// in a single notification per record change.
TEST_F(MdnsQuerierTest, StartQueryTwice) {
  const DomainName queried_name{"testing", "local"};
  std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
  MockRecordChangedCallback callback;
  for (int attempt = 0; attempt < 2; ++attempt) {
    querier->StartQuery(queried_name, DnsType::kA, DnsClass::kIN, &callback);
  }

  EXPECT_CALL(callback, OnRecordChanged(_, _)).Times(1);
  UdpPacket packet = CreatePacketWithRecord(record0_created_);
  receiver_.OnRead(&socket_, std::move(packet));
}
// Two callbacks registered for the same question are notified independently,
// and each stops being notified once its own query is stopped.
TEST_F(MdnsQuerierTest, MultipleCallbacks) {
  const DomainName queried_name{"testing", "local"};
  std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
  MockRecordChangedCallback callback_1;
  MockRecordChangedCallback callback_2;
  querier->StartQuery(queried_name, DnsType::kA, DnsClass::kIN, &callback_1);
  querier->StartQuery(queried_name, DnsType::kA, DnsClass::kIN, &callback_2);

  EXPECT_CALL(callback_1, OnRecordChanged(_, _)).Times(1);
  EXPECT_CALL(callback_2, OnRecordChanged(_, _)).Times(2);

  // Step 1: both queries are active, so both callbacks fire.
  receiver_.OnRead(&socket_, CreatePacketWithRecord(record0_created_));

  // Step 2: with callback_1 stopped, only callback_2 fires.
  querier->StopQuery(queried_name, DnsType::kA, DnsClass::kIN, &callback_1);
  receiver_.OnRead(&socket_, CreatePacketWithRecord(record0_updated_));

  // Step 3: with every query stopped, nobody is notified.
  querier->StopQuery(queried_name, DnsType::kA, DnsClass::kIN, &callback_2);
  receiver_.OnRead(&socket_, CreatePacketWithRecord(record0_updated_));
}
// Once a query has been stopped, further record changes are not reported to
// its callback.
TEST_F(MdnsQuerierTest, NoRecordChangesAfterStop) {
  const DomainName queried_name{"testing", "local"};
  std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
  MockRecordChangedCallback callback;
  querier->StartQuery(queried_name, DnsType::kA, DnsClass::kIN, &callback);

  // Only the record received while the query is active may be reported.
  EXPECT_CALL(callback, OnRecordChanged(_, _)).Times(1);
  receiver_.OnRead(&socket_, CreatePacketWithRecord(record0_created_));

  querier->StopQuery(queried_name, DnsType::kA, DnsClass::kIN, &callback);
  receiver_.OnRead(&socket_, CreatePacketWithRecord(record0_updated_));
}
// The pending query changes returned from OnRecordChanged are applied: here
// the callback asks for its own query to be stopped, so the later update must
// not reach it.
TEST_F(MdnsQuerierTest, OnRecordChangeCallbacksGetRun) {
  std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
  MockRecordChangedCallback callback;
  const DomainName queried_name{"testing", "local"};
  querier->StartQuery(queried_name, DnsType::kA, DnsClass::kIN, &callback);

  PendingQueryChange stop_request{queried_name, DnsType::kA, DnsClass::kIN,
                                  &callback, PendingQueryChange::kStopQuery};
  EXPECT_CALL(callback, OnRecordChanged(_, _))
      .WillOnce(Return(std::vector<PendingQueryChange>{stop_request}));

  const MdnsRecord* const incoming[] = {&record0_created_, &record0_updated_};
  for (const MdnsRecord* record : incoming) {
    receiver_.OnRead(&socket_, CreatePacketWithRecord(*record));
  }
}
// Stopping the same query twice is harmless, and no callbacks are delivered
// afterwards.
TEST_F(MdnsQuerierTest, StopQueryTwice) {
  const DomainName queried_name{"testing", "local"};
  std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
  MockRecordChangedCallback callback;
  querier->StartQuery(queried_name, DnsType::kA, DnsClass::kIN, &callback);

  EXPECT_CALL(callback, OnRecordChanged(_, _)).Times(0);
  for (int attempt = 0; attempt < 2; ++attempt) {
    querier->StopQuery(queried_name, DnsType::kA, DnsClass::kIN, &callback);
  }

  UdpPacket packet = CreatePacketWithRecord(record0_created_);
  receiver_.OnRead(&socket_, std::move(packet));
}
// Smoke test: stopping a query that was never started must not crash.
TEST_F(MdnsQuerierTest, StopNonExistingQuery) {
  const DomainName never_queried{"testing", "local"};
  std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
  MockRecordChangedCallback callback;
  querier->StopQuery(never_queried, DnsType::kA, DnsClass::kIN, &callback);
}
// A record for a name other than the queried one must not be reported to the
// callback.
TEST_F(MdnsQuerierTest, IrrelevantRecordReceived) {
  const DomainName queried_name{"testing", "local"};
  std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
  MockRecordChangedCallback callback;
  querier->StartQuery(queried_name, DnsType::kA, DnsClass::kIN, &callback);

  EXPECT_CALL(callback, OnRecordChanged(_, _)).Times(1);

  // The first record answers the query; the second is for "poking.local".
  const MdnsRecord* const incoming[] = {&record0_created_, &record1_created_};
  for (const MdnsRecord* record : incoming) {
    receiver_.OnRead(&socket_, CreatePacketWithRecord(*record));
  }
}
// Two distinct callbacks asking the same question each get notified once for
// a single matching record.
TEST_F(MdnsQuerierTest, DifferentCallersSameQuestion) {
  const DomainName queried_name{"testing", "local"};
  std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
  MockRecordChangedCallback callback1;
  MockRecordChangedCallback callback2;
  MockRecordChangedCallback* const callers[] = {&callback1, &callback2};
  for (MockRecordChangedCallback* caller : callers) {
    querier->StartQuery(queried_name, DnsType::kA, DnsClass::kIN, caller);
  }

  EXPECT_CALL(callback1, OnRecordChanged(_, _)).Times(1);
  EXPECT_CALL(callback2, OnRecordChanged(_, _)).Times(1);
  UdpPacket packet = CreatePacketWithRecord(record0_created_);
  receiver_.OnRead(&socket_, std::move(packet));
}
// Two callbacks with different questions are each notified once when one
// record per question is received.
TEST_F(MdnsQuerierTest, DifferentCallersDifferentQuestions) {
  const DomainName first_name{"testing", "local"};
  const DomainName second_name{"poking", "local"};
  std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
  MockRecordChangedCallback callback1;
  MockRecordChangedCallback callback2;
  querier->StartQuery(first_name, DnsType::kA, DnsClass::kIN, &callback1);
  querier->StartQuery(second_name, DnsType::kA, DnsClass::kIN, &callback2);

  EXPECT_CALL(callback1, OnRecordChanged(_, _)).Times(1);
  EXPECT_CALL(callback2, OnRecordChanged(_, _)).Times(1);

  const MdnsRecord* const incoming[] = {&record0_created_, &record1_created_};
  for (const MdnsRecord* record : incoming) {
    receiver_.OnRead(&socket_, CreatePacketWithRecord(*record));
  }
}
// One callback registered for two different questions is notified once per
// matching record, i.e. twice in total.
TEST_F(MdnsQuerierTest, SameCallerDifferentQuestions) {
  const DomainName first_name{"testing", "local"};
  const DomainName second_name{"poking", "local"};
  std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
  MockRecordChangedCallback callback;
  querier->StartQuery(first_name, DnsType::kA, DnsClass::kIN, &callback);
  querier->StartQuery(second_name, DnsType::kA, DnsClass::kIN, &callback);

  EXPECT_CALL(callback, OnRecordChanged(_, _)).Times(2);

  const MdnsRecord* const incoming[] = {&record0_created_, &record1_created_};
  for (const MdnsRecord* record : incoming) {
    receiver_.OnRead(&socket_, CreatePacketWithRecord(*record));
  }
}
// Checks that ReinitializeQueries() keeps the query running while discarding
// the records received for that domain, and that reinitializing an unrelated
// domain leaves existing records alone.
TEST_F(MdnsQuerierTest, ReinitializeQueries) {
  std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
  // Strict, so that any callback arriving between the verification points
  // below (where no expectation is active) fails the test instead of being
  // reported as an uninteresting call.
  StrictMock<MockRecordChangedCallback> callback;
  querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA,
                      DnsClass::kIN, &callback);

  EXPECT_CALL(callback, OnRecordChanged(_, RecordChangedEvent::kCreated))
      .WillOnce(WithArgs<0>(PartialCompareRecords(record0_created_)));
  receiver_.OnRead(&socket_, CreatePacketWithRecord(record0_created_));
  // Receiving the same record should only reset TTL, no callback
  receiver_.OnRead(&socket_, CreatePacketWithRecord(record0_created_));
  // Verify against the mock callback. The previous code passed |receiver_|,
  // which is not a mock, so the verification was a no-op.
  testing::Mock::VerifyAndClearExpectations(&callback);

  // Queries should still be ongoing but all received records should have been
  // deleted, so the same record is reported as created again.
  querier->ReinitializeQueries(DomainName{"testing", "local"});
  EXPECT_CALL(callback, OnRecordChanged(_, RecordChangedEvent::kCreated))
      .WillOnce(WithArgs<0>(PartialCompareRecords(record0_created_)));
  receiver_.OnRead(&socket_, CreatePacketWithRecord(record0_created_));
  testing::Mock::VerifyAndClearExpectations(&callback);

  // Reinitializing a different domain should not affect other queries: the
  // record is still known, so receiving it again triggers no callback.
  querier->ReinitializeQueries(DomainName{"testing2", "local"});
  receiver_.OnRead(&socket_, CreatePacketWithRecord(record0_created_));
}
// Records for names without an active query are dropped rather than cached:
// starting a query for such a name later must not report them.
TEST_F(MdnsQuerierTest, MessagesForUnknownQueriesDropped) {
  std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
  // Strict, so that the first half of the test (which sets no expectations)
  // actually fails if a dropped record is reported. With a plain mock such a
  // call would only be logged as uninteresting.
  StrictMock<MockRecordChangedCallback> callback;

  // Message for unknown query does not get processed: only "testing.local" is
  // queried when the "poking.local" record arrives, so starting the
  // "poking.local" query afterwards must not produce a callback.
  querier->StartQuery(DomainName{"testing", "local"}, DnsType::kA,
                      DnsClass::kIN, &callback);
  receiver_.OnRead(&socket_, CreatePacketWithRecord(record1_created_));
  querier->StartQuery(DomainName{"poking", "local"}, DnsType::kA, DnsClass::kIN,
                      &callback);
  testing::Mock::VerifyAndClearExpectations(&callback);

  querier->StopQuery(DomainName{"poking", "local"}, DnsType::kA, DnsClass::kIN,
                     &callback);

  // Only known records from the message are processed: of the two answers,
  // just the "testing.local" one is reported, and restarting the
  // "poking.local" query does not surface the other.
  EXPECT_CALL(callback, OnRecordChanged(_, RecordChangedEvent::kCreated))
      .Times(1);
  receiver_.OnRead(
      &socket_, CreatePacketWithRecords({record0_created_, record1_created_}));
  querier->StartQuery(DomainName{"poking", "local"}, DnsType::kA, DnsClass::kIN,
                      &callback);
}
// Updates to a record that is already known keep being processed after the
// query that discovered it has been stopped.
TEST_F(MdnsQuerierTest, MessagesForKnownRecordsAllowed) {
  const DomainName known_name{"testing", "local"};
  const DomainName other_name{"poking", "local"};
  std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
  MockRecordChangedCallback callback;

  // Have the querier learn a record while a matching query is running.
  querier->StartQuery(known_name, DnsType::kA, DnsClass::kIN, &callback);
  receiver_.OnRead(&socket_, CreatePacketWithRecord(record0_created_));
  testing::Mock::VerifyAndClearExpectations(&callback);

  // With the query stopped, deliver an update for that same record.
  querier->StopQuery(known_name, DnsType::kA, DnsClass::kIN, &callback);
  receiver_.OnRead(&socket_, CreatePacketWithRecord(record0_updated_));
  testing::Mock::VerifyAndClearExpectations(&callback);
  querier->StopQuery(other_name, DnsType::kA, DnsClass::kIN, &callback);

  // Restarting the query is expected to report the updated record as created.
  EXPECT_CALL(callback,
              OnRecordChanged(record0_updated_, RecordChangedEvent::kCreated))
      .Times(1);
  querier->StartQuery(known_name, DnsType::kA, DnsClass::kIN, &callback);
}
// A record matching an active query is reported even when it only appears in
// the additional-records section of a message whose answer is for another
// name.
TEST_F(MdnsQuerierTest, MessagesForUnknownKnownRecordsAllowsAdditionalRecords) {
  const DomainName queried_name{"testing", "local"};
  std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
  MockRecordChangedCallback callback;
  querier->StartQuery(queried_name, DnsType::kA, DnsClass::kIN, &callback);

  EXPECT_CALL(callback,
              OnRecordChanged(record0_created_, RecordChangedEvent::kCreated))
      .Times(1);

  // Answer: "poking.local" (not queried). Additional: "testing.local".
  UdpPacket packet =
      CreatePacketWithRecords({record1_created_}, {record0_created_});
  receiver_.OnRead(&socket_, std::move(packet));
  testing::Mock::VerifyAndClearExpectations(&callback);
}
// When a packet carries an A record, an AAAA record and an NSEC record for
// type A of the same name, only the two non-NSEC records end up stored,
// whatever the order of the records in the packet.
TEST_F(MdnsQuerierTest, NsecDroppedWhenCorrespondsToNonNsec) {
  const DomainName queried_name{"testing", "local"};
  std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
  StrictMock<MockRecordChangedCallback> callback;

  // First ordering: A, NSEC, AAAA.
  querier->StartQuery(queried_name, DnsType::kANY, DnsClass::kIN, &callback);
  EXPECT_CALL(callback, OnRecordChanged(_, RecordChangedEvent::kCreated))
      .WillOnce(WithArgs<0>(PartialCompareRecords(record0_created_)))
      .WillOnce(WithArgs<0>(PartialCompareRecords(record2_created_)));
  receiver_.OnRead(&socket_,
                   CreatePacketWithRecords({record0_created_,
                                            nsec_record_created_,
                                            record2_created_}));
  ASSERT_EQ(RecordCount(querier.get()), size_t{2});
  EXPECT_TRUE(ContainsRecord(querier.get(), record0_created_, DnsType::kA));
  EXPECT_TRUE(ContainsRecord(querier.get(), record2_created_, DnsType::kAAAA));

  // Second ordering, on a fresh querier: NSEC, AAAA, A.
  querier = CreateQuerier();
  querier->StartQuery(queried_name, DnsType::kANY, DnsClass::kIN, &callback);
  EXPECT_CALL(callback, OnRecordChanged(_, RecordChangedEvent::kCreated))
      .WillOnce(WithArgs<0>(PartialCompareRecords(record0_created_)))
      .WillOnce(WithArgs<0>(PartialCompareRecords(record2_created_)));
  receiver_.OnRead(&socket_,
                   CreatePacketWithRecords({nsec_record_created_,
                                            record2_created_,
                                            record0_created_}));
  ASSERT_EQ(RecordCount(querier.get()), size_t{2});
  EXPECT_TRUE(ContainsRecord(querier.get(), record0_created_, DnsType::kA));
  EXPECT_TRUE(ContainsRecord(querier.get(), record2_created_, DnsType::kAAAA));
}
// Two NSEC records for the same name (types A and AAAA) are stored as one
// combined NSEC covering both types, alongside the NSEC for a second name.
TEST_F(MdnsQuerierTest, MultipleNsecCombinedWhenSameName) {
  const DomainName primary_name{"testing", "local"};
  const DomainName alternate_name{"poking", "local"};
  MdnsRecord combined_nsec =
      CreateNsec(nsec_record_created_.name(), DnsType::kA, DnsType::kAAAA);
  MdnsRecord aaaa_only_nsec =
      CreateNsec(nsec_record_created_.name(), DnsType::kAAAA);
  MdnsRecord srv_only_nsec = CreateNsec(alternate_name, DnsType::kSRV);

  std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
  StrictMock<MockRecordChangedCallback> callback;
  querier->StartQuery(primary_name, DnsType::kANY, DnsClass::kIN, &callback);
  querier->StartQuery(alternate_name, DnsType::kANY, DnsClass::kIN, &callback);

  // The strict mock has no expectations: NSEC records must not be reported.
  receiver_.OnRead(&socket_,
                   CreatePacketWithRecords({nsec_record_created_,
                                            srv_only_nsec, aaaa_only_nsec}));

  MdnsQuerier* const raw_querier = querier.get();
  ASSERT_EQ(RecordCount(raw_querier), size_t{3});
  EXPECT_TRUE(ContainsRecord(raw_querier, combined_nsec, DnsType::kA));
  EXPECT_TRUE(ContainsRecord(raw_querier, combined_nsec, DnsType::kAAAA));
  EXPECT_TRUE(ContainsRecord(raw_querier, srv_only_nsec, DnsType::kSRV));
}
// An NSEC record covering A and AAAA that arrives together with a real A
// record is stored with only AAAA left, whatever the order of the records in
// the packet.
TEST_F(MdnsQuerierTest, NsecBitsUnSetWhenCorrespondsToExistingRecord) {
  const DomainName queried_name{"testing", "local"};
  MdnsRecord a_and_aaaa_nsec =
      CreateNsec(nsec_record_created_.name(), DnsType::kA, DnsType::kAAAA);
  MdnsRecord aaaa_only_nsec =
      CreateNsec(nsec_record_created_.name(), DnsType::kAAAA);
  std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
  StrictMock<MockRecordChangedCallback> callback;

  // First ordering: NSEC before the A record.
  querier->StartQuery(queried_name, DnsType::kANY, DnsClass::kIN, &callback);
  EXPECT_CALL(callback, OnRecordChanged(_, RecordChangedEvent::kCreated))
      .WillOnce(WithArgs<0>(PartialCompareRecords(record0_created_)));
  receiver_.OnRead(&socket_, CreatePacketWithRecords(
                                 {a_and_aaaa_nsec, record0_created_}));
  ASSERT_EQ(RecordCount(querier.get()), size_t{2});
  EXPECT_TRUE(ContainsRecord(querier.get(), record0_created_, DnsType::kA));
  EXPECT_TRUE(ContainsRecord(querier.get(), aaaa_only_nsec, DnsType::kAAAA));

  // Second ordering, on a fresh querier: the A record before the NSEC.
  querier = CreateQuerier();
  querier->StartQuery(queried_name, DnsType::kANY, DnsClass::kIN, &callback);
  EXPECT_CALL(callback, OnRecordChanged(_, RecordChangedEvent::kCreated))
      .WillOnce(WithArgs<0>(PartialCompareRecords(record0_created_)));
  receiver_.OnRead(&socket_, CreatePacketWithRecords(
                                 {record0_created_, a_and_aaaa_nsec}));
  ASSERT_EQ(RecordCount(querier.get()), size_t{2});
  EXPECT_TRUE(ContainsRecord(querier.get(), record0_created_, DnsType::kA));
  EXPECT_TRUE(ContainsRecord(querier.get(), aaaa_only_nsec, DnsType::kAAAA));
}
// An NSEC record covering A and AAAA is not stored at all when real A and
// AAAA records for the same name arrive in the same packet.
TEST_F(MdnsQuerierTest, NsecDroppedWhenAllBitsUnset) {
  const DomainName queried_name{"testing", "local"};
  MdnsRecord a_and_aaaa_nsec =
      CreateNsec(nsec_record_created_.name(), DnsType::kA, DnsType::kAAAA);
  MdnsRecord aaaa_record(
      record0_created_.name(), DnsType::kAAAA, record0_created_.dns_class(),
      record0_created_.record_type(), record0_created_.ttl(),
      AAAARecordRdata(IPAddress{172, 0, 0, 0, 0, 0, 0, 1}));

  std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
  StrictMock<MockRecordChangedCallback> callback;
  querier->StartQuery(queried_name, DnsType::kANY, DnsClass::kIN, &callback);

  EXPECT_CALL(callback, OnRecordChanged(_, RecordChangedEvent::kCreated))
      .WillOnce(WithArgs<0>(PartialCompareRecords(record0_created_)))
      .WillOnce(WithArgs<0>(PartialCompareRecords(aaaa_record)));
  receiver_.OnRead(&socket_,
                   CreatePacketWithRecords(
                       {a_and_aaaa_nsec, record0_created_, aaaa_record}));

  MdnsQuerier* const raw_querier = querier.get();
  ASSERT_EQ(RecordCount(raw_querier), size_t{2});
  EXPECT_TRUE(ContainsRecord(raw_querier, record0_created_, DnsType::kA));
  EXPECT_TRUE(ContainsRecord(raw_querier, aaaa_record, DnsType::kAAAA));
}
// Three names each receive an NSEC covering A, AAAA and SRV plus one real
// record of a different type. Each stored NSEC must lose exactly the type of
// the real record received for its name.
TEST_F(MdnsQuerierTest, NsecUpdatesMultipleRecords) {
  const DomainName& first_name = record0_created_.name();
  const DomainName alternate_name{"poking", "local"};
  const DomainName third_name{"third", "local"};

  // Incoming NSEC records: one per name, all covering A, AAAA and SRV.
  MdnsRecord first_nsec_in =
      CreateNsec(first_name, DnsType::kA, DnsType::kAAAA, DnsType::kSRV);
  MdnsRecord second_nsec_in =
      CreateNsec(alternate_name, DnsType::kA, DnsType::kAAAA, DnsType::kSRV);
  MdnsRecord third_nsec_in =
      CreateNsec(third_name, DnsType::kA, DnsType::kAAAA, DnsType::kSRV);

  // Incoming real records: A for the first name (the fixture's
  // |record0_created_|), AAAA for the second, SRV for the third.
  MdnsRecord aaaa_record(
      alternate_name, DnsType::kAAAA, record0_created_.dns_class(),
      record0_created_.record_type(), record0_created_.ttl(),
      AAAARecordRdata(IPAddress{172, 0, 0, 0, 0, 0, 0, 1}));
  MdnsRecord srv_record(third_name, DnsType::kSRV,
                        record0_created_.dns_class(),
                        record0_created_.record_type(), record0_created_.ttl(),
                        SrvRecordRdata(1, 2, 3, third_name));

  std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
  StrictMock<MockRecordChangedCallback> callback;
  querier->StartQuery(first_name, DnsType::kANY, DnsClass::kIN, &callback);
  querier->StartQuery(alternate_name, DnsType::kANY, DnsClass::kIN, &callback);
  querier->StartQuery(third_name, DnsType::kANY, DnsClass::kIN, &callback);

  // Only the three real records are reported to the callback.
  EXPECT_CALL(callback, OnRecordChanged(_, RecordChangedEvent::kCreated))
      .Times(3);
  receiver_.OnRead(&socket_, CreatePacketWithRecords(
                                 {first_nsec_in, second_nsec_in, third_nsec_in,
                                  record0_created_, aaaa_record, srv_record}));

  MdnsQuerier* const raw_querier = querier.get();
  ASSERT_EQ(RecordCount(raw_querier), size_t{9});

  // Expected stored NSEC records, each missing the real record's type.
  MdnsRecord first_nsec_stored =
      CreateNsec(first_name, DnsType::kAAAA, DnsType::kSRV);
  MdnsRecord second_nsec_stored =
      CreateNsec(alternate_name, DnsType::kA, DnsType::kSRV);
  MdnsRecord third_nsec_stored =
      CreateNsec(third_name, DnsType::kA, DnsType::kAAAA);

  EXPECT_TRUE(ContainsRecord(raw_querier, record0_created_, DnsType::kA));
  EXPECT_TRUE(ContainsRecord(raw_querier, first_nsec_stored, DnsType::kAAAA));
  EXPECT_TRUE(ContainsRecord(raw_querier, first_nsec_stored, DnsType::kSRV));
  EXPECT_TRUE(ContainsRecord(raw_querier, second_nsec_stored, DnsType::kA));
  EXPECT_TRUE(ContainsRecord(raw_querier, aaaa_record, DnsType::kAAAA));
  EXPECT_TRUE(ContainsRecord(raw_querier, second_nsec_stored, DnsType::kSRV));
  EXPECT_TRUE(ContainsRecord(raw_querier, third_nsec_stored, DnsType::kA));
  EXPECT_TRUE(ContainsRecord(raw_querier, third_nsec_stored, DnsType::kAAAA));
  EXPECT_TRUE(ContainsRecord(raw_querier, srv_record, DnsType::kSRV));
}
// Every type named by the NSEC records in the packet also has a real record
// in the same packet, so only the three real records are stored.
TEST_F(MdnsQuerierTest, NsecDeletesMultipleRecords) {
  const DomainName& first_name = record0_created_.name();
  const DomainName alternate_name{"poking", "local"};

  // NSEC records: A for the first name; AAAA+SRV and AAAA for the second.
  MdnsRecord first_a_nsec = CreateNsec(first_name, DnsType::kA);
  MdnsRecord second_aaaa_srv_nsec =
      CreateNsec(alternate_name, DnsType::kAAAA, DnsType::kSRV);
  MdnsRecord second_aaaa_nsec = CreateNsec(alternate_name, DnsType::kAAAA);

  // Real AAAA and SRV records for the second name.
  MdnsRecord aaaa_record(
      alternate_name, DnsType::kAAAA, record0_created_.dns_class(),
      record0_created_.record_type(), record0_created_.ttl(),
      AAAARecordRdata(IPAddress{172, 0, 0, 0, 0, 0, 0, 1}));
  MdnsRecord srv_record(alternate_name, DnsType::kSRV,
                        record0_created_.dns_class(),
                        record0_created_.record_type(), record0_created_.ttl(),
                        SrvRecordRdata(1, 2, 3, first_name));

  std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
  StrictMock<MockRecordChangedCallback> callback;
  querier->StartQuery(first_name, DnsType::kANY, DnsClass::kIN, &callback);
  querier->StartQuery(alternate_name, DnsType::kANY, DnsClass::kIN, &callback);

  EXPECT_CALL(callback, OnRecordChanged(_, RecordChangedEvent::kCreated))
      .Times(3);
  receiver_.OnRead(
      &socket_,
      CreatePacketWithRecords({record0_created_, aaaa_record, srv_record,
                               first_a_nsec, second_aaaa_srv_nsec,
                               second_aaaa_nsec}));

  MdnsQuerier* const raw_querier = querier.get();
  ASSERT_EQ(RecordCount(raw_querier), size_t{3});
  EXPECT_TRUE(ContainsRecord(raw_querier, record0_created_, DnsType::kA));
  EXPECT_TRUE(ContainsRecord(raw_querier, aaaa_record, DnsType::kAAAA));
  EXPECT_TRUE(ContainsRecord(raw_querier, srv_record, DnsType::kSRV));
}
// Mixes the NSEC outcomes in one packet: the NSEC records for the first and
// second names are fully covered by real records received alongside them,
// while the NSEC for the third name has no matching real record and is
// stored. Four records are expected in total.
TEST_F(MdnsQuerierTest, NsecCreatesUpdatesAndDeletesMultipleRecords) {
  const DomainName& first_name = record0_created_.name();
  const DomainName alternate_name{"poking", "local"};
  const DomainName third_name{"third", "local"};
  MdnsRecord nsec_first_name = CreateNsec(first_name, DnsType::kA);
  MdnsRecord aaaa(alternate_name, DnsType::kAAAA, record0_created_.dns_class(),
                  record0_created_.record_type(), record0_created_.ttl(),
                  AAAARecordRdata(IPAddress{172, 0, 0, 0, 0, 0, 0, 1}));
  MdnsRecord nsec_second_name =
      CreateNsec(alternate_name, DnsType::kAAAA, DnsType::kSRV);
  MdnsRecord srv(alternate_name, DnsType::kSRV, record0_created_.dns_class(),
                 record0_created_.record_type(), record0_created_.ttl(),
                 SrvRecordRdata(1, 2, 3, third_name));
  MdnsRecord nsec_third_name = CreateNsec(third_name, DnsType::kA);

  std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
  StrictMock<MockRecordChangedCallback> callback;
  querier->StartQuery(first_name, DnsType::kANY, DnsClass::kIN, &callback);
  querier->StartQuery(alternate_name, DnsType::kANY, DnsClass::kIN, &callback);
  querier->StartQuery(third_name, DnsType::kANY, DnsClass::kIN, &callback);

  auto packet =
      CreatePacketWithRecords({nsec_first_name, nsec_second_name,
                               nsec_third_name, record0_created_, aaaa, srv});
  // Only the three real records are reported to the callback.
  EXPECT_CALL(callback, OnRecordChanged(_, RecordChangedEvent::kCreated))
      .Times(3);
  receiver_.OnRead(&socket_, std::move(packet));

  // The unused |new_second_nsec| local has been removed: the second name's
  // AAAA and SRV slots are both checked below against the real records.
  ASSERT_EQ(RecordCount(querier.get()), size_t{4});
  EXPECT_TRUE(ContainsRecord(querier.get(), record0_created_, DnsType::kA));
  EXPECT_TRUE(ContainsRecord(querier.get(), aaaa, DnsType::kAAAA));
  EXPECT_TRUE(ContainsRecord(querier.get(), srv, DnsType::kSRV));
  EXPECT_TRUE(ContainsRecord(querier.get(), nsec_third_name, DnsType::kA));
}
// A stored NSEC record must not be reported to a callback, neither when it is
// received nor when a query for the same question is started afterwards.
TEST_F(MdnsQuerierTest, CallbackNotCalledOnStartQueryForNsecRecords) {
  const DomainName queried_name{"testing", "local"};
  std::unique_ptr<MdnsQuerier> querier = CreateQuerier();

  // Arrange: an NSEC record is received while the query is active. The strict
  // mock has no expectations, so any callback fails the test.
  StrictMock<MockRecordChangedCallback> callback;
  querier->StartQuery(queried_name, DnsType::kA, DnsClass::kIN, &callback);
  receiver_.OnRead(&socket_, CreatePacketWithRecord(nsec_record_created_));

  MdnsQuerier* const raw_querier = querier.get();
  ASSERT_EQ(RecordCount(raw_querier), size_t{1});
  EXPECT_TRUE(ContainsRecord(raw_querier, nsec_record_created_, DnsType::kA));

  // Act: start the query again with the NSEC record already stored.
  querier->StartQuery(queried_name, DnsType::kA, DnsClass::kIN, &callback);
}
// One NSEC record naming three types is stored once per named type.
TEST_F(MdnsQuerierTest, ReceiveNsecRecordFansOutToEachType) {
  const DomainName queried_name{"testing", "local"};
  std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
  StrictMock<MockRecordChangedCallback> callback;
  querier->StartQuery(queried_name, DnsType::kA, DnsClass::kIN, &callback);

  // CreateNsec() fills type, class, record type and TTL from
  // |nsec_record_created_|, so only the covered types differ from it.
  MdnsRecord multi_type_nsec =
      CreateNsec(nsec_record_created_.name(), DnsType::kA, DnsType::kSRV,
                 DnsType::kAAAA);
  receiver_.OnRead(&socket_, CreatePacketWithRecord(multi_type_nsec));

  MdnsQuerier* const raw_querier = querier.get();
  ASSERT_EQ(RecordCount(raw_querier), size_t{3});
  const DnsType covered_types[] = {DnsType::kA, DnsType::kAAAA, DnsType::kSRV};
  for (DnsType type : covered_types) {
    EXPECT_TRUE(ContainsRecord(raw_querier, multi_type_nsec, type));
  }
}
// An NSEC record for type ANY is stored under each of A, AAAA, SRV, TXT and
// PTR, giving five tracked records.
TEST_F(MdnsQuerierTest, ReceiveNsecKAnyRecordFansOutToAllTypes) {
  const DomainName queried_name{"testing", "local"};
  std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
  StrictMock<MockRecordChangedCallback> callback;
  querier->StartQuery(queried_name, DnsType::kA, DnsClass::kIN, &callback);

  // CreateNsec() fills type, class, record type and TTL from
  // |nsec_record_created_|, so only the covered type differs from it.
  MdnsRecord any_type_nsec =
      CreateNsec(nsec_record_created_.name(), DnsType::kANY);
  receiver_.OnRead(&socket_, CreatePacketWithRecord(any_type_nsec));

  MdnsQuerier* const raw_querier = querier.get();
  ASSERT_EQ(RecordCount(raw_querier), size_t{5});
  const DnsType expected_types[] = {DnsType::kA, DnsType::kAAAA, DnsType::kSRV,
                                    DnsType::kTXT, DnsType::kPTR};
  for (DnsType type : expected_types) {
    EXPECT_TRUE(ContainsRecord(raw_querier, any_type_nsec, type));
  }
}
// When an NSEC record for type A arrives after a real A record, the A record
// is reported as expired and the NSEC record takes its place.
TEST_F(MdnsQuerierTest, CorrectCallbackCalledWhenNsecRecordReplacesNonNsec) {
  const DomainName queried_name{"testing", "local"};
  std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
  MdnsQuerier* const raw_querier = querier.get();

  // Arrange: receive the A record first.
  StrictMock<MockRecordChangedCallback> callback;
  querier->StartQuery(queried_name, DnsType::kA, DnsClass::kIN, &callback);
  EXPECT_CALL(callback,
              OnRecordChanged(record0_created_, RecordChangedEvent::kCreated));
  receiver_.OnRead(&socket_, CreatePacketWithRecord(record0_created_));
  testing::Mock::VerifyAndClearExpectations(&callback);
  ASSERT_TRUE(ContainsRecord(raw_querier, record0_created_, DnsType::kA));
  EXPECT_FALSE(ContainsRecord(raw_querier, nsec_record_created_, DnsType::kA));

  // Act: the NSEC record arrives and displaces the A record.
  EXPECT_CALL(callback,
              OnRecordChanged(record0_created_, RecordChangedEvent::kExpired));
  receiver_.OnRead(&socket_, CreatePacketWithRecord(nsec_record_created_));
  EXPECT_FALSE(ContainsRecord(raw_querier, record0_created_, DnsType::kA));
  EXPECT_TRUE(ContainsRecord(raw_querier, nsec_record_created_, DnsType::kA));
}
// With |ignore_nsec_responses| set in the config, an NSEC record arriving
// after a real A record changes nothing: the A record stays, the NSEC record
// is not stored, and no callback fires.
TEST_F(MdnsQuerierTest,
       NoCallbackCalledWhenNsecRecordWouldReplaceNonNsecButNsecDisabled) {
  // The flag must be set before the querier is created from |config_|.
  config_.ignore_nsec_responses = true;
  const DomainName queried_name{"testing", "local"};
  std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
  MdnsQuerier* const raw_querier = querier.get();

  // Arrange: receive the A record first.
  StrictMock<MockRecordChangedCallback> callback;
  querier->StartQuery(queried_name, DnsType::kA, DnsClass::kIN, &callback);
  EXPECT_CALL(callback,
              OnRecordChanged(record0_created_, RecordChangedEvent::kCreated));
  receiver_.OnRead(&socket_, CreatePacketWithRecord(record0_created_));
  testing::Mock::VerifyAndClearExpectations(&callback);
  ASSERT_TRUE(ContainsRecord(raw_querier, record0_created_, DnsType::kA));
  EXPECT_FALSE(ContainsRecord(raw_querier, nsec_record_created_, DnsType::kA));

  // Act: the NSEC record arrives; the strict mock has no expectations left.
  receiver_.OnRead(&socket_, CreatePacketWithRecord(nsec_record_created_));
  EXPECT_TRUE(ContainsRecord(raw_querier, record0_created_, DnsType::kA));
  EXPECT_FALSE(ContainsRecord(raw_querier, nsec_record_created_, DnsType::kA));
}
// Verifies that when an A record arrives for a name/type currently covered by
// an NSEC record, the callback is notified of the A record's creation and the
// NSEC record is no longer reported for that type.
TEST_F(MdnsQuerierTest, CorrectCallbackCalledWhenNonNsecRecordReplacesNsec) {
  std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
  MdnsQuerier* const querier_ptr = querier.get();

  // Phase 1: start an A query and deliver the NSEC record first. No
  // expectation is set, so the StrictMock fails the test if the NSEC record
  // causes a callback.
  StrictMock<MockRecordChangedCallback> callback;
  const DomainName query_name{"testing", "local"};
  querier_ptr->StartQuery(query_name, DnsType::kA, DnsClass::kIN, &callback);
  auto nsec_packet = CreatePacketWithRecord(nsec_record_created_);
  receiver_.OnRead(&socket_, std::move(nsec_packet));

  ASSERT_TRUE(ContainsRecord(querier_ptr, nsec_record_created_, DnsType::kA));
  EXPECT_FALSE(ContainsRecord(querier_ptr, record0_created_, DnsType::kA));

  // Phase 2: deliver the real A record; it must be announced as created.
  EXPECT_CALL(callback,
              OnRecordChanged(record0_created_, RecordChangedEvent::kCreated));
  auto a_record_packet = CreatePacketWithRecord(record0_created_);
  receiver_.OnRead(&socket_, std::move(a_record_packet));

  // The A record has replaced the NSEC record.
  EXPECT_FALSE(ContainsRecord(querier_ptr, nsec_record_created_, DnsType::kA));
  EXPECT_TRUE(ContainsRecord(querier_ptr, record0_created_, DnsType::kA));
}
// Verifies that replacing one NSEC record with another NSEC record for the
// same name swaps the stored record without invoking the callback.
TEST_F(MdnsQuerierTest, NoCallbackCalledWhenSecondNsecRecordReceived) {
  std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
  MdnsQuerier* const querier_ptr = querier.get();

  // Same header fields as nsec_record_created_, but with rdata listing the
  // A, SRV and AAAA types.
  const MdnsRecord multi_type_nsec(
      nsec_record_created_.name(), nsec_record_created_.dns_type(),
      nsec_record_created_.dns_class(), nsec_record_created_.record_type(),
      nsec_record_created_.ttl(),
      NsecRecordRdata(nsec_record_created_.name(), DnsType::kA, DnsType::kSRV,
                      DnsType::kAAAA));

  // Phase 1: start an A query and deliver the first NSEC record. No
  // expectations are set in this test, so the StrictMock fails it on any
  // callback invocation.
  StrictMock<MockRecordChangedCallback> callback;
  const DomainName query_name{"testing", "local"};
  querier_ptr->StartQuery(query_name, DnsType::kA, DnsClass::kIN, &callback);
  auto first_nsec_packet = CreatePacketWithRecord(nsec_record_created_);
  receiver_.OnRead(&socket_, std::move(first_nsec_packet));

  ASSERT_TRUE(ContainsRecord(querier_ptr, nsec_record_created_, DnsType::kA));
  EXPECT_FALSE(ContainsRecord(querier_ptr, multi_type_nsec, DnsType::kA));

  // Phase 2: deliver the second NSEC record.
  auto second_nsec_packet = CreatePacketWithRecord(multi_type_nsec);
  receiver_.OnRead(&socket_, std::move(second_nsec_packet));

  // The newer NSEC record has replaced the first one.
  EXPECT_FALSE(ContainsRecord(querier_ptr, nsec_record_created_, DnsType::kA));
  EXPECT_TRUE(ContainsRecord(querier_ptr, multi_type_nsec, DnsType::kA));
}
// Verifies that the querier honors `querier_max_records_cached`: with a limit
// of one, receiving a second record is expected to expire the first so that
// exactly one record stays tracked.
TEST_F(MdnsQuerierTest, TestMaxRecordsRespected) {
  // The limit must be set before the querier is created.
  config_.querier_max_records_cached = 1;
  std::unique_ptr<MdnsQuerier> querier = CreateQuerier();
  MdnsQuerier* const querier_ptr = querier.get();

  // Start ANY-type queries for two names, sharing a single strict callback.
  StrictMock<MockRecordChangedCallback> callback;
  const DomainName first_name{"testing", "local"};
  const DomainName second_name{"poking", "local"};
  querier_ptr->StartQuery(first_name, DnsType::kANY, DnsClass::kIN, &callback);
  querier_ptr->StartQuery(second_name, DnsType::kANY, DnsClass::kIN,
                          &callback);

  // Phase 1: the first record fits within the limit and is announced.
  auto first_packet = CreatePacketWithRecord(record0_created_);
  EXPECT_CALL(callback,
              OnRecordChanged(record0_created_, RecordChangedEvent::kCreated));
  receiver_.OnRead(&socket_, std::move(first_packet));

  ASSERT_EQ(RecordCount(querier_ptr), size_t{1});
  EXPECT_TRUE(ContainsRecord(querier_ptr, record0_created_, DnsType::kA));
  EXPECT_FALSE(ContainsRecord(querier_ptr, record1_created_, DnsType::kA));
  testing::Mock::VerifyAndClearExpectations(&callback);

  // Phase 2: the second record exceeds the limit, so the first record must be
  // reported expired and the second reported created.
  EXPECT_CALL(callback,
              OnRecordChanged(record0_created_, RecordChangedEvent::kExpired));
  EXPECT_CALL(callback,
              OnRecordChanged(record1_created_, RecordChangedEvent::kCreated));
  auto second_packet = CreatePacketWithRecord(record1_created_);
  receiver_.OnRead(&socket_, std::move(second_packet));

  // Still exactly one record, and it is the newer one.
  ASSERT_EQ(RecordCount(querier_ptr), size_t{1});
  EXPECT_FALSE(ContainsRecord(querier_ptr, record0_created_, DnsType::kA));
  EXPECT_TRUE(ContainsRecord(querier_ptr, record1_created_, DnsType::kA));
}
} // namespace discovery
} // namespace openscreen