blob: 0860adab58e76e7ddbae25e4d0d730c07a068082 [file] [log] [blame]
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "discovery/mdns/mdns_responder.h"

#include <algorithm>
#include <chrono>
#include <iterator>
#include <utility>
#include <vector>

#include "discovery/common/config.h"
#include "discovery/mdns/mdns_probe_manager.h"
#include "discovery/mdns/mdns_random.h"
#include "discovery/mdns/mdns_receiver.h"
#include "discovery/mdns/mdns_records.h"
#include "discovery/mdns/mdns_sender.h"
#include "platform/test/fake_clock.h"
#include "platform/test/fake_task_runner.h"
#include "platform/test/fake_udp_socket.h"
namespace openscreen {
namespace discovery {
namespace {
// Longest delay the responder may apply before multicasting a response for
// records it does not exclusively own. Tests advance the fake clock by this
// amount to let any such delayed response fire. The value is written in
// milliseconds so it does not depend on the tick size of Clock::duration; the
// conversion is implicit and lossless for any finer-grained duration.
constexpr Clock::duration kMaximumSharedRecordResponseDelayMs =
    std::chrono::milliseconds(120);
bool ContainsRecordType(const std::vector<MdnsRecord>& records, DnsType type) {
return std::find_if(records.begin(), records.end(),
[type](const MdnsRecord& record) {
return record.dns_type() == type;
}) != records.end();
}
// Checks that |message| has exactly one answer, that this answer is an NSEC
// record, and that the NSEC record lists exactly one type, equal to |type|.
// The ASSERT_* macros return from this helper on the first failure; they do
// not abort the calling test.
void CheckSingleNsecRecordType(const MdnsMessage& message, DnsType type) {
  ASSERT_EQ(message.answers().size(), size_t{1});
  // The record is only inspected, so bind a reference instead of copying it.
  const MdnsRecord& record = message.answers()[0];
  ASSERT_EQ(record.dns_type(), DnsType::kNSEC);
  const NsecRecordRdata& rdata = absl::get<NsecRecordRdata>(record.rdata());
  ASSERT_EQ(rdata.types().size(), size_t{1});
  EXPECT_EQ(rdata.types()[0], type);
}
void CheckPtrDomain(const MdnsRecord& record, const DomainName& domain) {
ASSERT_EQ(record.dns_type(), DnsType::kPTR);
const PtrRecordRdata& rdata = absl::get<PtrRecordRdata>(record.rdata());
EXPECT_EQ(rdata.ptr_domain(), domain);
}
// Non-fatally expects |records| to contain an NSEC record whose type list is
// exactly {|type|}.
void ExpectContainsNsecRecordType(const std::vector<MdnsRecord>& records,
                                  DnsType type) {
  const bool found = std::any_of(
      records.begin(), records.end(), [type](const MdnsRecord& record) {
        if (record.dns_type() != DnsType::kNSEC) {
          return false;
        }
        const NsecRecordRdata& rdata =
            absl::get<NsecRecordRdata>(record.rdata());
        return rdata.types().size() == 1 && rdata.types()[0] == type;
      });
  // Several callers check more than one type, so say which one is missing.
  EXPECT_TRUE(found) << "No single-type NSEC record found for DnsType "
                     << static_cast<int>(type);
}
} // namespace
using testing::_;
using testing::Args;
using testing::Invoke;
using testing::Return;
using testing::StrictMock;
class MockRecordHandler : public MdnsResponder::RecordHandler {
public:
void AddRecord(MdnsRecord record) { records_.push_back(record); }
MOCK_METHOD3(HasRecords, bool(const DomainName&, DnsType, DnsClass));
std::vector<MdnsRecord::ConstRef> GetRecords(const DomainName& name,
DnsType type,
DnsClass clazz) override {
std::vector<MdnsRecord::ConstRef> records;
for (const auto& record : records_) {
if (type == DnsType::kANY || record.dns_type() == type) {
records.push_back(record);
}
}
return records;
}
std::vector<MdnsRecord::ConstRef> GetPtrRecords(DnsClass clazz) override {
std::vector<MdnsRecord::ConstRef> records;
for (const auto& record : records_) {
if (record.dns_type() == DnsType::kPTR) {
records.push_back(record);
}
}
return records;
}
private:
std::vector<MdnsRecord> records_;
};
// MdnsSender whose SendMessage() and SendMulticast() are mocked, so tests can
// set expectations on (and inspect) every message the responder emits.
class MockMdnsSender : public MdnsSender {
 public:
  explicit MockMdnsSender(UdpSocket* udp_socket) : MdnsSender(udp_socket) {}

  // Unicast send to a specific endpoint.
  MOCK_METHOD2(SendMessage,
               Error(const MdnsMessage& message, const IPEndpoint& endpoint));
  // Multicast send.
  MOCK_METHOD1(SendMulticast, Error(const MdnsMessage& message));
};
// MdnsProbeManager with both entry points mocked. Tests use IsDomainClaimed()
// to choose whether the responder treats the queried domain as claimed.
class MockProbeManager : public MdnsProbeManager {
 public:
  MOCK_METHOD2(RespondToProbeQuery,
               void(const MdnsMessage&, const IPEndpoint&));
  MOCK_CONST_METHOD1(IsDomainClaimed, bool(const DomainName&));
};
// Fixture that wires an MdnsResponder to mocked collaborators (sender, record
// handler, probe manager) and a fake clock and task runner. Tests feed queries
// in through OnMessageReceived() and check what the responder sends.
class MdnsResponderTest : public testing::Test {
 public:
  MdnsResponderTest()
      : clock_(Clock::now()),
        task_runner_(&clock_),
        socket_(&task_runner_),
        sender_(&socket_),
        receiver_(config_),
        responder_(&record_handler_,
                   &probe_manager_,
                   &sender_,
                   &receiver_,
                   &task_runner_,
                   FakeClock::now,
                   &random_,
                   config_) {}

 protected:
  // Returns a PTR record that points at |target|. The record is owned by
  // |target| with its first label removed, e.g. a target of
  // "instance._googlecast._tcp.local" gives a record named
  // "_googlecast._tcp.local". |target| must have at least one label.
  MdnsRecord GetFakePtrRecord(const DomainName& target) {
    // Take the label list once so both iterators refer to the same container.
    const auto& labels = target.labels();
    DomainName name(std::next(labels.begin()), labels.end());
    PtrRecordRdata rdata(target);
    return MdnsRecord(std::move(name), DnsType::kPTR, DnsClass::kIN,
                      RecordType::kUnique, std::chrono::seconds(0), rdata);
  }

  // Returns a unique SRV record for |name| that targets |name| on port 80.
  MdnsRecord GetFakeSrvRecord(const DomainName& name) {
    SrvRecordRdata rdata(0, 0, 80, name);
    return MdnsRecord(name, DnsType::kSRV, DnsClass::kIN, RecordType::kUnique,
                      std::chrono::seconds(0), rdata);
  }

  // Returns a unique TXT record for |name| with empty rdata.
  MdnsRecord GetFakeTxtRecord(const DomainName& name) {
    TxtRecordRdata rdata;
    return MdnsRecord(name, DnsType::kTXT, DnsClass::kIN, RecordType::kUnique,
                      std::chrono::seconds(0), rdata);
  }

  // Returns a unique A record for |name| with address 192.168.0.0.
  MdnsRecord GetFakeARecord(const DomainName& name) {
    ARecordRdata rdata(IPAddress(192, 168, 0, 0));
    return MdnsRecord(name, DnsType::kA, DnsClass::kIN, RecordType::kUnique,
                      std::chrono::seconds(0), rdata);
  }

  // Returns a unique AAAA record for |name| with a fixed IPv6 address.
  MdnsRecord GetFakeAAAARecord(const DomainName& name) {
    AAAARecordRdata rdata(IPAddress(1, 2, 3, 4, 5, 6, 7, 8));
    return MdnsRecord(name, DnsType::kAAAA, DnsClass::kIN, RecordType::kUnique,
                      std::chrono::seconds(0), rdata);
  }

  // Passes |message| to the responder as if it had arrived from |src|.
  void OnMessageReceived(const MdnsMessage& message, const IPEndpoint& src) {
    responder_.OnMessageReceived(message, src);
  }

  // Sends a multicast query for |type| on |domain_|, with the domain reported
  // as claimed and HasRecords() returning true. Expects exactly one multicast
  // response whose only answer is an NSEC record listing just |type|. This
  // helper adds no records to |record_handler_|.
  void QueryForRecordTypeWhenNonePresent(DnsType type) {
    EXPECT_CALL(sender_, SendMulticast(_))
        .WillOnce([type](const MdnsMessage& msg) -> Error {
          CheckSingleNsecRecordType(msg, type);
          return Error::None();
        });
    EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
    EXPECT_CALL(record_handler_, HasRecords(_, _, _))
        .WillRepeatedly(Return(true));
    OnMessageReceived(CreateMulticastMdnsQuery(type), endpoint_);
  }

  // Returns a query (id 0) with a single multicast-response question for
  // |type| on |domain_|.
  MdnsMessage CreateMulticastMdnsQuery(DnsType type) {
    return CreateMulticastQuery(domain_, type);
  }

  // Returns a query (id 0) with a single multicast-response PTR question for
  // the DNS-SD service type enumeration domain.
  MdnsMessage CreateTypeEnumerationQuery() {
    return CreateMulticastQuery(type_enumeration_domain_, DnsType::kPTR);
  }

  // Shared implementation of the query builders above: builds a query (id 0)
  // holding one question for |domain| / |type| in class kANY, requesting a
  // multicast response.
  MdnsMessage CreateMulticastQuery(const DomainName& domain, DnsType type) {
    MdnsQuestion question(domain, type, DnsClass::kANY,
                          ResponseType::kMulticast);
    MdnsMessage message(0, MessageType::Query);
    message.AddQuestion(std::move(question));
    return message;
  }

  // Members are constructed in declaration order, and the constructor relies
  // on that order: |receiver_| and |responder_| read |config_|, and
  // |responder_| keeps pointers to the members declared before it.
  const Config config_;
  FakeClock clock_;
  FakeTaskRunner task_runner_;
  FakeUdpSocket socket_;
  StrictMock<MockMdnsSender> sender_;
  StrictMock<MockRecordHandler> record_handler_;
  StrictMock<MockProbeManager> probe_manager_;
  MdnsReceiver receiver_;
  MdnsRandom random_;
  MdnsResponder responder_;
  DomainName domain_{"instance", "_googlecast", "_tcp", "local"};
  DomainName type_enumeration_domain_{"_services", "_dns-sd", "_udp", "local"};
  IPEndpoint endpoint_{IPAddress(192, 168, 0, 0), 80};
};
// Validate that when records may also be sent by other responders (the domain
// is not claimed), the multicast response is delayed, and that it is sent
// immediately otherwise.
TEST_F(MdnsResponderTest, OwnedRecordsSentImmediately) {
  // The domain is claimed, so the response goes out while the query is being
  // handled, and nothing more is sent once the shared-record delay passes.
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(sender_, SendMulticast(_)).Times(1);

  const MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kANY);
  OnMessageReceived(query, endpoint_);

  testing::Mock::VerifyAndClearExpectations(&sender_);
  testing::Mock::VerifyAndClearExpectations(&record_handler_);
  testing::Mock::VerifyAndClearExpectations(&probe_manager_);

  EXPECT_CALL(sender_, SendMulticast(_)).Times(0);
  clock_.Advance(kMaximumSharedRecordResponseDelayMs);
}
TEST_F(MdnsResponderTest, NonOwnedRecordsDelayed) {
  // The domain is not claimed, so nothing is sent while the query is handled;
  // the response appears only after the shared-record delay.
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(false));
  EXPECT_CALL(sender_, SendMulticast(_)).Times(0);

  const MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kANY);
  OnMessageReceived(query, endpoint_);

  testing::Mock::VerifyAndClearExpectations(&sender_);
  testing::Mock::VerifyAndClearExpectations(&record_handler_);
  testing::Mock::VerifyAndClearExpectations(&probe_manager_);

  EXPECT_CALL(sender_, SendMulticast(_)).Times(1);
  clock_.Advance(kMaximumSharedRecordResponseDelayMs);
}
TEST_F(MdnsResponderTest, MultipleQuestionsProcessed) {
  // The query holds two identical questions. IsDomainClaimed() answers true
  // on its first call and false on its second, so one response is expected
  // immediately and another after the shared-record delay.
  MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kANY);
  query.AddQuestion(MdnsQuestion(domain_, DnsType::kANY, DnsClass::kANY,
                                 ResponseType::kMulticast));

  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(probe_manager_, IsDomainClaimed(_))
      .WillOnce(Return(true))
      .WillOnce(Return(false));
  EXPECT_CALL(sender_, SendMulticast(_)).Times(1);
  OnMessageReceived(query, endpoint_);

  testing::Mock::VerifyAndClearExpectations(&sender_);
  testing::Mock::VerifyAndClearExpectations(&record_handler_);
  testing::Mock::VerifyAndClearExpectations(&probe_manager_);

  EXPECT_CALL(sender_, SendMulticast(_)).Times(1);
  clock_.Advance(kMaximumSharedRecordResponseDelayMs);
}
// Validate that the correct messaging scheme (unicast vs multicast) is used.
TEST_F(MdnsResponderTest, UnicastMessageSentOverUnicast) {
  // A question that asks for a unicast response is answered with SendMessage()
  // to the querier's endpoint. The strict sender mock rejects any multicast.
  MdnsMessage query(0, MessageType::Query);
  query.AddQuestion(MdnsQuestion(domain_, DnsType::kANY, DnsClass::kANY,
                                 ResponseType::kUnicast));

  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(sender_, SendMessage(_, endpoint_)).Times(1);

  OnMessageReceived(query, endpoint_);
}
TEST_F(MdnsResponderTest, MulticastMessageSentOverMulticast) {
  // A question that asks for a multicast response is answered through
  // SendMulticast(). The strict sender mock rejects any unicast send.
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(sender_, SendMulticast(_)).Times(1);

  const MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kANY);
  OnMessageReceived(query, endpoint_);
}
// Validate that records are added as expected based on the query type, and that
// additional records are populated as specified in RFC 6762 and 6763.
TEST_F(MdnsResponderTest, AnyQueryResultsAllApplied) {
  // An ANY query is answered with every record held for |domain_|, and the
  // response carries no additional records.
  const MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kANY);
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  record_handler_.AddRecord(GetFakeTxtRecord(domain_));
  record_handler_.AddRecord(GetFakeARecord(domain_));
  record_handler_.AddRecord(GetFakeAAAARecord(domain_));
  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(sender_, SendMulticast(_))
      .WillOnce([](const MdnsMessage& response) -> Error {
        const auto& answers = response.answers();
        EXPECT_EQ(response.questions().size(), size_t{0});
        EXPECT_EQ(response.authority_records().size(), size_t{0});
        EXPECT_EQ(response.additional_records().size(), size_t{0});
        EXPECT_EQ(answers.size(), size_t{4});
        EXPECT_TRUE(ContainsRecordType(answers, DnsType::kSRV));
        EXPECT_TRUE(ContainsRecordType(answers, DnsType::kTXT));
        EXPECT_TRUE(ContainsRecordType(answers, DnsType::kA));
        EXPECT_TRUE(ContainsRecordType(answers, DnsType::kAAAA));
        EXPECT_FALSE(ContainsRecordType(answers, DnsType::kPTR));
        return Error::None();
      });

  OnMessageReceived(query, endpoint_);
}
TEST_F(MdnsResponderTest, PtrQueryResultsApplied) {
  // A PTR query for the service type is answered with the PTR record. The
  // instance's SRV, TXT, A and AAAA records go in the additional section.
  DomainName ptr_domain{"_googlecast", "_tcp", "local"};
  MdnsQuestion question(ptr_domain, DnsType::kPTR, DnsClass::kANY,
                        ResponseType::kMulticast);
  MdnsMessage query(0, MessageType::Query);
  query.AddQuestion(question);
  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  record_handler_.AddRecord(GetFakePtrRecord(domain_));
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  record_handler_.AddRecord(GetFakeTxtRecord(domain_));
  record_handler_.AddRecord(GetFakeARecord(domain_));
  record_handler_.AddRecord(GetFakeAAAARecord(domain_));
  EXPECT_CALL(sender_, SendMulticast(_))
      .WillOnce([](const MdnsMessage& response) -> Error {
        EXPECT_EQ(response.questions().size(), size_t{0});
        EXPECT_EQ(response.authority_records().size(), size_t{0});
        EXPECT_EQ(response.answers().size(), size_t{1});
        EXPECT_TRUE(ContainsRecordType(response.answers(), DnsType::kPTR));
        const auto& records = response.additional_records();
        EXPECT_EQ(records.size(), size_t{4});
        EXPECT_TRUE(ContainsRecordType(records, DnsType::kSRV));
        EXPECT_TRUE(ContainsRecordType(records, DnsType::kTXT));
        EXPECT_TRUE(ContainsRecordType(records, DnsType::kA));
        EXPECT_TRUE(ContainsRecordType(records, DnsType::kAAAA));
        EXPECT_FALSE(ContainsRecordType(records, DnsType::kPTR));
        return Error::None();
      });
  OnMessageReceived(query, endpoint_);
}
TEST_F(MdnsResponderTest, SrvQueryResultsApplied) {
  // An SRV query is answered with the SRV record. Only the A and AAAA records
  // go in the additional section; TXT and PTR are left out.
  const MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kSRV);
  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  record_handler_.AddRecord(GetFakePtrRecord(domain_));
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  record_handler_.AddRecord(GetFakeTxtRecord(domain_));
  record_handler_.AddRecord(GetFakeARecord(domain_));
  record_handler_.AddRecord(GetFakeAAAARecord(domain_));
  EXPECT_CALL(sender_, SendMulticast(_))
      .WillOnce([](const MdnsMessage& response) -> Error {
        EXPECT_EQ(response.questions().size(), size_t{0});
        EXPECT_EQ(response.authority_records().size(), size_t{0});
        EXPECT_EQ(response.answers().size(), size_t{1});
        EXPECT_TRUE(ContainsRecordType(response.answers(), DnsType::kSRV));
        const auto& records = response.additional_records();
        EXPECT_EQ(records.size(), size_t{2});
        EXPECT_TRUE(ContainsRecordType(records, DnsType::kA));
        EXPECT_TRUE(ContainsRecordType(records, DnsType::kAAAA));
        EXPECT_FALSE(ContainsRecordType(records, DnsType::kSRV));
        EXPECT_FALSE(ContainsRecordType(records, DnsType::kTXT));
        EXPECT_FALSE(ContainsRecordType(records, DnsType::kPTR));
        return Error::None();
      });
  OnMessageReceived(query, endpoint_);
}
TEST_F(MdnsResponderTest, AQueryResultsApplied) {
  // An A query is answered with the A record, and the AAAA record is the only
  // additional record.
  const MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kA);
  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  record_handler_.AddRecord(GetFakePtrRecord(domain_));
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  record_handler_.AddRecord(GetFakeTxtRecord(domain_));
  record_handler_.AddRecord(GetFakeARecord(domain_));
  record_handler_.AddRecord(GetFakeAAAARecord(domain_));
  EXPECT_CALL(sender_, SendMulticast(_))
      .WillOnce([](const MdnsMessage& response) -> Error {
        EXPECT_EQ(response.questions().size(), size_t{0});
        EXPECT_EQ(response.authority_records().size(), size_t{0});
        EXPECT_EQ(response.answers().size(), size_t{1});
        EXPECT_TRUE(ContainsRecordType(response.answers(), DnsType::kA));
        const auto& records = response.additional_records();
        EXPECT_EQ(records.size(), size_t{1});
        EXPECT_TRUE(ContainsRecordType(records, DnsType::kAAAA));
        EXPECT_FALSE(ContainsRecordType(records, DnsType::kA));
        EXPECT_FALSE(ContainsRecordType(records, DnsType::kSRV));
        EXPECT_FALSE(ContainsRecordType(records, DnsType::kTXT));
        EXPECT_FALSE(ContainsRecordType(records, DnsType::kPTR));
        return Error::None();
      });
  OnMessageReceived(query, endpoint_);
}
TEST_F(MdnsResponderTest, AAAAQueryResultsApplied) {
  // An AAAA query is answered with the AAAA record, and the A record is the
  // only additional record.
  const MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kAAAA);
  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  record_handler_.AddRecord(GetFakePtrRecord(domain_));
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  record_handler_.AddRecord(GetFakeTxtRecord(domain_));
  record_handler_.AddRecord(GetFakeARecord(domain_));
  record_handler_.AddRecord(GetFakeAAAARecord(domain_));
  EXPECT_CALL(sender_, SendMulticast(_))
      .WillOnce([](const MdnsMessage& response) -> Error {
        EXPECT_EQ(response.questions().size(), size_t{0});
        EXPECT_EQ(response.authority_records().size(), size_t{0});
        EXPECT_EQ(response.answers().size(), size_t{1});
        EXPECT_TRUE(ContainsRecordType(response.answers(), DnsType::kAAAA));
        const auto& records = response.additional_records();
        EXPECT_EQ(records.size(), size_t{1});
        EXPECT_TRUE(ContainsRecordType(records, DnsType::kA));
        EXPECT_FALSE(ContainsRecordType(records, DnsType::kAAAA));
        EXPECT_FALSE(ContainsRecordType(records, DnsType::kSRV));
        EXPECT_FALSE(ContainsRecordType(records, DnsType::kTXT));
        EXPECT_FALSE(ContainsRecordType(records, DnsType::kPTR));
        return Error::None();
      });
  OnMessageReceived(query, endpoint_);
}
TEST_F(MdnsResponderTest, MessageOnlySentIfAnswerNotKnown) {
  // The query already includes the only matching AAAA record as a known
  // answer, so the responder must send nothing. The StrictMock sender would
  // reject any send anyway; the explicit Times(0) states the expectation.
  MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kAAAA);
  MdnsRecord aaaa_record = GetFakeAAAARecord(domain_);
  message.AddAnswer(aaaa_record);
  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(sender_, SendMulticast(_)).Times(0);
  EXPECT_CALL(sender_, SendMessage(_, _)).Times(0);
  record_handler_.AddRecord(GetFakePtrRecord(domain_));
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  record_handler_.AddRecord(GetFakeTxtRecord(domain_));
  record_handler_.AddRecord(GetFakeARecord(domain_));
  record_handler_.AddRecord(aaaa_record);
  OnMessageReceived(message, endpoint_);
}
TEST_F(MdnsResponderTest, RecordOnlySentIfNotKnown) {
  // The querier already knows the AAAA record, so only the A record is sent.
  const MdnsRecord known_aaaa = GetFakeAAAARecord(domain_);
  MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kANY);
  query.AddAnswer(known_aaaa);

  record_handler_.AddRecord(GetFakeARecord(domain_));
  record_handler_.AddRecord(known_aaaa);
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(sender_, SendMulticast(_))
      .WillOnce([](const MdnsMessage& response) -> Error {
        EXPECT_EQ(response.questions().size(), size_t{0});
        EXPECT_EQ(response.authority_records().size(), size_t{0});
        EXPECT_EQ(response.additional_records().size(), size_t{0});
        EXPECT_EQ(response.answers().size(), size_t{1});
        EXPECT_TRUE(ContainsRecordType(response.answers(), DnsType::kA));
        return Error::None();
      });

  OnMessageReceived(query, endpoint_);
}
TEST_F(MdnsResponderTest, RecordOnlySentIfNotKnownMultiplePackets) {
  // The query is marked truncated and its known answer (AAAA) arrives in a
  // second packet. Once the clock advances, only the A record is sent.
  MdnsMessage truncated_query = CreateMulticastMdnsQuery(DnsType::kANY);
  truncated_query.set_truncated();
  const MdnsRecord known_aaaa = GetFakeAAAARecord(domain_);
  MdnsMessage known_answers(1, MessageType::Query);
  known_answers.AddAnswer(known_aaaa);

  OnMessageReceived(truncated_query, endpoint_);
  OnMessageReceived(known_answers, endpoint_);

  record_handler_.AddRecord(GetFakeARecord(domain_));
  record_handler_.AddRecord(known_aaaa);
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(sender_, SendMulticast(_))
      .WillOnce([](const MdnsMessage& response) -> Error {
        EXPECT_EQ(response.questions().size(), size_t{0});
        EXPECT_EQ(response.authority_records().size(), size_t{0});
        EXPECT_EQ(response.additional_records().size(), size_t{0});
        EXPECT_EQ(response.answers().size(), size_t{1});
        EXPECT_TRUE(ContainsRecordType(response.answers(), DnsType::kA));
        return Error::None();
      });

  clock_.Advance(std::chrono::seconds(1));
}
TEST_F(MdnsResponderTest, RecordOnlySentIfNotKnownMultiplePacketsOutOfOrder) {
  // The known answers (AAAA, then A) arrive in two packets before the
  // truncated query itself. Both known records must be suppressed, which
  // leaves only the SRV record in the response.
  MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kANY);
  message.set_truncated();
  MdnsMessage message2(2, MessageType::Query);
  MdnsRecord aaaa_record = GetFakeAAAARecord(domain_);
  message2.AddAnswer(aaaa_record);
  message2.set_truncated();
  MdnsMessage message3(3, MessageType::Query);
  MdnsRecord a_record = GetFakeARecord(domain_);
  message3.AddAnswer(a_record);
  OnMessageReceived(message2, endpoint_);
  OnMessageReceived(message3, endpoint_);
  OnMessageReceived(message, endpoint_);
  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  record_handler_.AddRecord(a_record);
  record_handler_.AddRecord(aaaa_record);
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  EXPECT_CALL(sender_, SendMulticast(_))
      .WillOnce([](const MdnsMessage& response) -> Error {
        EXPECT_EQ(response.questions().size(), size_t{0});
        EXPECT_EQ(response.authority_records().size(), size_t{0});
        EXPECT_EQ(response.additional_records().size(), size_t{0});
        EXPECT_EQ(response.answers().size(), size_t{1});
        EXPECT_TRUE(ContainsRecordType(response.answers(), DnsType::kSRV));
        return Error::None();
      });
  clock_.Advance(std::chrono::seconds(1));
}
TEST_F(MdnsResponderTest, RecordSentForMultiPacketsSuppressionIfMoreNotFound) {
  // The truncated query carries AAAA as a known answer, and no follow-up
  // packet arrives. After the clock advances, the unknown A record is still
  // sent.
  const MdnsRecord known_aaaa = GetFakeAAAARecord(domain_);
  MdnsMessage truncated_query = CreateMulticastMdnsQuery(DnsType::kANY);
  truncated_query.AddAnswer(known_aaaa);
  truncated_query.set_truncated();
  OnMessageReceived(truncated_query, endpoint_);

  record_handler_.AddRecord(GetFakeARecord(domain_));
  record_handler_.AddRecord(known_aaaa);
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(sender_, SendMulticast(_))
      .WillOnce([](const MdnsMessage& response) -> Error {
        EXPECT_EQ(response.questions().size(), size_t{0});
        EXPECT_EQ(response.authority_records().size(), size_t{0});
        EXPECT_EQ(response.additional_records().size(), size_t{0});
        EXPECT_EQ(response.answers().size(), size_t{1});
        EXPECT_TRUE(ContainsRecordType(response.answers(), DnsType::kA));
        return Error::None();
      });

  clock_.Advance(std::chrono::seconds(1));
}
TEST_F(MdnsResponderTest, RecordNotSentForMultiPacketsSuppressionIfNoQuery) {
  // A packet that holds only known answers, with no question, must not
  // trigger any response, even after the clock advances. The expectations
  // below state explicitly what the StrictMock sender already enforces.
  MdnsMessage message(1, MessageType::Query);
  MdnsRecord aaaa_record = GetFakeAAAARecord(domain_);
  message.AddAnswer(aaaa_record);
  EXPECT_CALL(sender_, SendMulticast(_)).Times(0);
  EXPECT_CALL(sender_, SendMessage(_, _)).Times(0);
  OnMessageReceived(message, endpoint_);
  clock_.Advance(std::chrono::seconds(1));
}
// Validate NSEC records are used correctly.
TEST_F(MdnsResponderTest, QueryForRecordTypesWhenNonePresent) {
  // With no records stored, each query type gets a single NSEC answer that
  // names that type. The types are run in order, one at a time.
  for (DnsType type : {DnsType::kANY, DnsType::kSRV, DnsType::kTXT,
                       DnsType::kA, DnsType::kAAAA}) {
    QueryForRecordTypeWhenNonePresent(type);
  }
}
TEST_F(MdnsResponderTest, AAAAQueryGiveANsec) {
  // Only the AAAA address record exists. The AAAA answer is therefore paired
  // with an NSEC additional record that denies an A record.
  const MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kAAAA);
  record_handler_.AddRecord(GetFakePtrRecord(domain_));
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  record_handler_.AddRecord(GetFakeTxtRecord(domain_));
  record_handler_.AddRecord(GetFakeAAAARecord(domain_));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(sender_, SendMulticast(_))
      .WillOnce([](const MdnsMessage& response) -> Error {
        const auto& additional = response.additional_records();
        EXPECT_EQ(response.questions().size(), size_t{0});
        EXPECT_EQ(response.authority_records().size(), size_t{0});
        EXPECT_EQ(response.answers().size(), size_t{1});
        EXPECT_TRUE(ContainsRecordType(response.answers(), DnsType::kAAAA));
        EXPECT_EQ(additional.size(), size_t{1});
        ExpectContainsNsecRecordType(additional, DnsType::kA);
        return Error::None();
      });

  OnMessageReceived(query, endpoint_);
}
TEST_F(MdnsResponderTest, AQueryGiveAAAANsec) {
  // Only the A address record exists. The A answer is therefore paired with
  // an NSEC additional record that denies an AAAA record.
  const MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kA);
  record_handler_.AddRecord(GetFakePtrRecord(domain_));
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  record_handler_.AddRecord(GetFakeTxtRecord(domain_));
  record_handler_.AddRecord(GetFakeARecord(domain_));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(sender_, SendMulticast(_))
      .WillOnce([](const MdnsMessage& response) -> Error {
        const auto& additional = response.additional_records();
        EXPECT_EQ(response.questions().size(), size_t{0});
        EXPECT_EQ(response.authority_records().size(), size_t{0});
        EXPECT_EQ(response.answers().size(), size_t{1});
        EXPECT_TRUE(ContainsRecordType(response.answers(), DnsType::kA));
        EXPECT_EQ(additional.size(), size_t{1});
        ExpectContainsNsecRecordType(additional, DnsType::kAAAA);
        return Error::None();
      });

  OnMessageReceived(query, endpoint_);
}
TEST_F(MdnsResponderTest, SrvQueryGiveCorrectNsecForNoAOrAAAA) {
  // No address records exist. The SRV answer is therefore accompanied by NSEC
  // additional records that deny both A and AAAA.
  const MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kSRV);
  record_handler_.AddRecord(GetFakePtrRecord(domain_));
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  record_handler_.AddRecord(GetFakeTxtRecord(domain_));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(sender_, SendMulticast(_))
      .WillOnce([](const MdnsMessage& response) -> Error {
        const auto& additional = response.additional_records();
        EXPECT_EQ(response.questions().size(), size_t{0});
        EXPECT_EQ(response.authority_records().size(), size_t{0});
        EXPECT_EQ(response.answers().size(), size_t{1});
        EXPECT_TRUE(ContainsRecordType(response.answers(), DnsType::kSRV));
        EXPECT_EQ(additional.size(), size_t{2});
        ExpectContainsNsecRecordType(additional, DnsType::kA);
        ExpectContainsNsecRecordType(additional, DnsType::kAAAA);
        return Error::None();
      });

  OnMessageReceived(query, endpoint_);
}
TEST_F(MdnsResponderTest, SrvQueryGiveCorrectNsec) {
  // An A record exists but an AAAA record does not. The additional section
  // therefore holds the A record plus an NSEC record that denies AAAA.
  const MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kSRV);
  record_handler_.AddRecord(GetFakePtrRecord(domain_));
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  record_handler_.AddRecord(GetFakeTxtRecord(domain_));
  record_handler_.AddRecord(GetFakeARecord(domain_));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(sender_, SendMulticast(_))
      .WillOnce([](const MdnsMessage& response) -> Error {
        const auto& additional = response.additional_records();
        EXPECT_EQ(response.questions().size(), size_t{0});
        EXPECT_EQ(response.authority_records().size(), size_t{0});
        EXPECT_EQ(response.answers().size(), size_t{1});
        EXPECT_TRUE(ContainsRecordType(response.answers(), DnsType::kSRV));
        EXPECT_EQ(additional.size(), size_t{2});
        EXPECT_TRUE(ContainsRecordType(additional, DnsType::kA));
        ExpectContainsNsecRecordType(additional, DnsType::kAAAA);
        return Error::None();
      });

  OnMessageReceived(query, endpoint_);
}
TEST_F(MdnsResponderTest, PtrQueryGiveCorrectNsecForNoPtrOrSrv) {
  // Only the PTR record exists; despite the test name, the missing records
  // are TXT and SRV. The PTR answer is accompanied by NSEC additional records
  // that deny TXT and SRV.
  const MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kPTR);
  record_handler_.AddRecord(GetFakePtrRecord(domain_));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(sender_, SendMulticast(_))
      .WillOnce([](const MdnsMessage& response) -> Error {
        const auto& additional = response.additional_records();
        EXPECT_EQ(response.questions().size(), size_t{0});
        EXPECT_EQ(response.authority_records().size(), size_t{0});
        EXPECT_EQ(response.answers().size(), size_t{1});
        EXPECT_TRUE(ContainsRecordType(response.answers(), DnsType::kPTR));
        EXPECT_EQ(additional.size(), size_t{2});
        ExpectContainsNsecRecordType(additional, DnsType::kTXT);
        ExpectContainsNsecRecordType(additional, DnsType::kSRV);
        return Error::None();
      });

  OnMessageReceived(query, endpoint_);
}
TEST_F(MdnsResponderTest, PtrQueryGiveCorrectNsecForOnlyPtr) {
  // PTR and TXT records exist but SRV does not. The additional section
  // therefore holds the TXT record plus an NSEC record that denies SRV.
  const MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kPTR);
  record_handler_.AddRecord(GetFakePtrRecord(domain_));
  record_handler_.AddRecord(GetFakeTxtRecord(domain_));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(sender_, SendMulticast(_))
      .WillOnce([](const MdnsMessage& response) -> Error {
        const auto& additional = response.additional_records();
        EXPECT_EQ(response.questions().size(), size_t{0});
        EXPECT_EQ(response.authority_records().size(), size_t{0});
        EXPECT_EQ(response.answers().size(), size_t{1});
        EXPECT_TRUE(ContainsRecordType(response.answers(), DnsType::kPTR));
        EXPECT_EQ(additional.size(), size_t{2});
        EXPECT_TRUE(ContainsRecordType(additional, DnsType::kTXT));
        ExpectContainsNsecRecordType(additional, DnsType::kSRV);
        return Error::None();
      });

  OnMessageReceived(query, endpoint_);
}
TEST_F(MdnsResponderTest, PtrQueryGiveCorrectNsecForOnlySrv) {
  // PTR and SRV records exist, but TXT and the address records do not. The
  // additional section therefore holds the SRV record plus NSEC records that
  // deny TXT, A and AAAA.
  const MdnsMessage query = CreateMulticastMdnsQuery(DnsType::kPTR);
  record_handler_.AddRecord(GetFakePtrRecord(domain_));
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true));
  EXPECT_CALL(sender_, SendMulticast(_))
      .WillOnce([](const MdnsMessage& response) -> Error {
        const auto& additional = response.additional_records();
        EXPECT_EQ(response.questions().size(), size_t{0});
        EXPECT_EQ(response.authority_records().size(), size_t{0});
        EXPECT_EQ(response.answers().size(), size_t{1});
        EXPECT_TRUE(ContainsRecordType(response.answers(), DnsType::kPTR));
        EXPECT_EQ(additional.size(), size_t{4});
        EXPECT_TRUE(ContainsRecordType(additional, DnsType::kSRV));
        ExpectContainsNsecRecordType(additional, DnsType::kTXT);
        ExpectContainsNsecRecordType(additional, DnsType::kA);
        ExpectContainsNsecRecordType(additional, DnsType::kAAAA);
        return Error::None();
      });

  OnMessageReceived(query, endpoint_);
}
TEST_F(MdnsResponderTest, EnumerateAllQuery) {
  // A DNS-SD service type enumeration query is answered after the
  // shared-record delay. The answer is a PTR record owned by the enumeration
  // domain that points at the service type (the owner of the stored PTR).
  const MdnsMessage enumeration_query = CreateTypeEnumerationQuery();
  const auto instance_ptr = GetFakePtrRecord(domain_);
  record_handler_.AddRecord(instance_ptr);
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  record_handler_.AddRecord(GetFakeTxtRecord(domain_));
  record_handler_.AddRecord(GetFakeARecord(domain_));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(false));

  OnMessageReceived(enumeration_query, endpoint_);

  EXPECT_CALL(sender_, SendMulticast(_))
      .WillOnce([this, &instance_ptr](const MdnsMessage& response) -> Error {
        const auto& answers = response.answers();
        EXPECT_EQ(response.questions().size(), size_t{0});
        EXPECT_EQ(response.authority_records().size(), size_t{0});
        EXPECT_EQ(answers.size(), size_t{1});
        EXPECT_TRUE(ContainsRecordType(answers, DnsType::kPTR));
        EXPECT_EQ(answers[0].name(), type_enumeration_domain_);
        CheckPtrDomain(answers[0], instance_ptr.name());
        return Error::None();
      });
  clock_.Advance(kMaximumSharedRecordResponseDelayMs);
}
TEST_F(MdnsResponderTest, EnumerateAllQueryNoResults) {
  // With no PTR records stored, a service type enumeration query produces no
  // response, even after the shared-record delay. The explicit Times(0)
  // expectations state what the StrictMock sender already enforces.
  MdnsMessage message = CreateTypeEnumerationQuery();
  EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(false));
  EXPECT_CALL(record_handler_, HasRecords(_, _, _))
      .WillRepeatedly(Return(true));
  EXPECT_CALL(sender_, SendMulticast(_)).Times(0);
  EXPECT_CALL(sender_, SendMessage(_, _)).Times(0);
  record_handler_.AddRecord(GetFakeSrvRecord(domain_));
  record_handler_.AddRecord(GetFakeTxtRecord(domain_));
  record_handler_.AddRecord(GetFakeARecord(domain_));
  OnMessageReceived(message, endpoint_);
  clock_.Advance(Clock::duration(kMaximumSharedRecordResponseDelayMs));
}
} // namespace discovery
} // namespace openscreen