Convert mDNS header flags to MessageType
Change-Id: I0d768a97324fbca0ce0abff89330a1f1cfa362f8
Reviewed-on: https://chromium-review.googlesource.com/c/openscreen/+/1715979
Commit-Queue: Max Yakimakha <yakimakha@chromium.org>
Reviewed-by: Peter Thatcher <pthatcher@google.com>
diff --git a/cast/common/mdns/mdns_constants.h b/cast/common/mdns/mdns_constants.h
index d0ff6c6..053f70a 100644
--- a/cast/common/mdns/mdns_constants.h
+++ b/cast/common/mdns/mdns_constants.h
@@ -131,41 +131,33 @@
uint16_t additional_record_count;
};
-// TODO(mayaki): Here and below consider converting constants to members of
-// enum classes.
+static_assert(sizeof(Header) == 12, "Size of mDNS header must be 12 bytes.");
-// DNS Header flags. All flags are formatted to mask directly onto FLAG header
-// field in network-byte order.
+enum class MessageType {
+ Query = 0,
+ Response = 1,
+};
+
constexpr uint16_t kFlagResponse = 0x8000;
constexpr uint16_t kFlagAA = 0x0400;
constexpr uint16_t kFlagTC = 0x0200;
-constexpr uint16_t kFlagRD = 0x0100;
-constexpr uint16_t kFlagRA = 0x0080;
-constexpr uint16_t kFlagZ = 0x0040; // Unused field
-constexpr uint16_t kFlagAD = 0x0020;
-constexpr uint16_t kFlagCD = 0x0010;
-
-// DNS Header OPCODE mask and values. The mask is formatted to mask directly
-// onto FLAG header field in network-byte order. The values are formatted after
-// shifting into correct position.
constexpr uint16_t kOpcodeMask = 0x7800;
-constexpr uint8_t kOpcodeQUERY = 0;
-constexpr uint8_t kOpcodeIQUERY = 1;
-constexpr uint8_t kOpcodeSTATUS = 2;
-constexpr uint8_t kOpcodeUNASSIGNED = 3; // Unused for now
-constexpr uint8_t kOpcodeNOTIFY = 4;
-constexpr uint8_t kOpcodeUPDATE = 5;
-
-// DNS Header RCODE mask and values. The mask is formatted to mask directly onto
-// FLAG header field in network-byte order. The values are formatted after
-// shifting into correct position.
constexpr uint16_t kRcodeMask = 0x000F;
-constexpr uint8_t kRcodeNOERROR = 0;
-constexpr uint8_t kRcodeFORMERR = 1;
-constexpr uint8_t kRcodeSERVFAIL = 2;
-constexpr uint8_t kRcodeNXDOMAIN = 3;
-constexpr uint8_t kRcodeNOTIMP = 4;
-constexpr uint8_t kRcodeREFUSED = 5;
+
+constexpr MessageType GetMessageType(uint16_t flags) {
+ // RFC 6762 Section 18.2
+ return (flags & kFlagResponse) ? MessageType::Response : MessageType::Query;
+}
+
+constexpr uint16_t MakeFlags(MessageType type) {
+ // RFC 6762 Section 18.2 and Section 18.4
+ return (type == MessageType::Response) ? (kFlagResponse | kFlagAA) : 0;
+}
+
+constexpr bool IsValidFlagsSection(uint16_t flags) {
+ // RFC 6762 Section 18.3 and Section 18.11
+ return (flags & (kOpcodeMask | kRcodeMask)) == 0;
+}
// ============================================================================
// Domain Name
diff --git a/cast/common/mdns/mdns_reader.cc b/cast/common/mdns/mdns_reader.cc
index ae4443d..36bc3ed 100644
--- a/cast/common/mdns/mdns_reader.cc
+++ b/cast/common/mdns/mdns_reader.cc
@@ -237,8 +237,12 @@
Read(header.answer_count, &answers) &&
Read(header.authority_record_count, &authority_records) &&
Read(header.additional_record_count, &additional_records)) {
- *out = MdnsMessage(header.id, header.flags, questions, answers,
- authority_records, additional_records);
+ // TODO(yakimakha): Skip messages with non-zero opcode and rcode.
+ // One way to do this is to change the method signature to return
+ // ErrorOr<MdnsMessage> and return different error codes for failure to read
+ // and for messages that were read successfully but are non-conforming.
+ *out = MdnsMessage(header.id, GetMessageType(header.flags), questions,
+ answers, authority_records, additional_records);
cursor.Commit();
return true;
}
diff --git a/cast/common/mdns/mdns_reader_unittest.cc b/cast/common/mdns/mdns_reader_unittest.cc
index 5153806..5fdf4d4 100644
--- a/cast/common/mdns/mdns_reader_unittest.cc
+++ b/cast/common/mdns/mdns_reader_unittest.cc
@@ -563,9 +563,10 @@
120, PtrRecordRdata(DomainName{"testing", "local"}));
MdnsRecord record2(DomainName{"record2"}, DnsType::kA, DnsClass::kIN, false,
120, ARecordRdata(IPAddress{172, 0, 0, 1}));
- MdnsMessage message(
- 1, 0x8400, std::vector<MdnsQuestion>{}, std::vector<MdnsRecord>{record1},
- std::vector<MdnsRecord>{}, std::vector<MdnsRecord>{record2});
+ MdnsMessage message(1, MessageType::Response, std::vector<MdnsQuestion>{},
+ std::vector<MdnsRecord>{record1},
+ std::vector<MdnsRecord>{},
+ std::vector<MdnsRecord>{record2});
TestReadEntrySucceeds(kTestMessage, sizeof(kTestMessage), message);
}
diff --git a/cast/common/mdns/mdns_receiver.cc b/cast/common/mdns/mdns_receiver.cc
index 72b6683..e3c3251 100644
--- a/cast/common/mdns/mdns_receiver.cc
+++ b/cast/common/mdns/mdns_receiver.cc
@@ -52,8 +52,7 @@
if (!reader.Read(&message)) {
return;
}
- // TODO(yakimakha): make flags a proper type and hide bit manipulation
- if ((message.flags() & kFlagResponse) != 0) {
+ if (message.type() == MessageType::Response) {
delegate_->OnResponseReceived(message, packet.source());
} else {
delegate_->OnQueryReceived(message, packet.source());
diff --git a/cast/common/mdns/mdns_receiver_unittest.cc b/cast/common/mdns/mdns_receiver_unittest.cc
index 5c056e0..45233d7 100644
--- a/cast/common/mdns/mdns_receiver_unittest.cc
+++ b/cast/common/mdns/mdns_receiver_unittest.cc
@@ -38,7 +38,7 @@
// clang-format off
const std::vector<uint8_t> kQueryBytes = {
0x00, 0x01, // ID = 1
- 0x04, 0x00, // FLAGS = AA
+ 0x00, 0x00, // FLAGS = None
0x00, 0x01, // Question count
0x00, 0x00, // Answer count
0x00, 0x00, // Authority count
@@ -63,7 +63,7 @@
MdnsQuestion question(DomainName{"testing", "local"}, DnsType::kA,
DnsClass::kIN, false);
- MdnsMessage message(1, 0x0400);
+ MdnsMessage message(1, MessageType::Query);
message.AddQuestion(question);
UdpPacket packet(kQueryBytes.size());
@@ -119,7 +119,7 @@
MdnsRecord record(DomainName{"testing", "local"}, DnsType::kA, DnsClass::kIN,
false, 120, ARecordRdata(IPAddress{172, 0, 0, 1}));
- MdnsMessage message(1, 0x8400);
+ MdnsMessage message(1, MessageType::Response);
message.AddAnswer(record);
UdpPacket packet(kResponseBytes.size());
diff --git a/cast/common/mdns/mdns_records.cc b/cast/common/mdns/mdns_records.cc
index 3c26df1..7775b78 100644
--- a/cast/common/mdns/mdns_records.cc
+++ b/cast/common/mdns/mdns_records.cc
@@ -31,11 +31,8 @@
}
bool DomainName::operator==(const DomainName& rhs) const {
- auto predicate = [](const std::string& left, const std::string& right) {
- return absl::EqualsIgnoreCase(left, right);
- };
return std::equal(labels_.begin(), labels_.end(), rhs.labels_.begin(),
- rhs.labels_.end(), predicate);
+ rhs.labels_.end(), absl::EqualsIgnoreCase);
}
bool DomainName::operator!=(const DomainName& rhs) const {
@@ -232,17 +229,17 @@
return name_.MaxWireSize() + sizeof(type_) + sizeof(record_class_);
}
-MdnsMessage::MdnsMessage(uint16_t id, uint16_t flags)
- : id_(id), flags_(flags) {}
+MdnsMessage::MdnsMessage(uint16_t id, MessageType type)
+ : id_(id), type_(type) {}
MdnsMessage::MdnsMessage(uint16_t id,
- uint16_t flags,
+ MessageType type,
std::vector<MdnsQuestion> questions,
std::vector<MdnsRecord> answers,
std::vector<MdnsRecord> authority_records,
std::vector<MdnsRecord> additional_records)
: id_(id),
- flags_(flags),
+ type_(type),
questions_(std::move(questions)),
answers_(std::move(answers)),
authority_records_(std::move(authority_records)),
@@ -267,8 +264,8 @@
}
bool MdnsMessage::operator==(const MdnsMessage& rhs) const {
- return id_ == rhs.id_ && flags_ == rhs.flags_ &&
- questions_ == rhs.questions_ && answers_ == rhs.answers_ &&
+ return id_ == rhs.id_ && type_ == rhs.type_ && questions_ == rhs.questions_ &&
+ answers_ == rhs.answers_ &&
authority_records_ == rhs.authority_records_ &&
additional_records_ == rhs.additional_records_;
}
diff --git a/cast/common/mdns/mdns_records.h b/cast/common/mdns/mdns_records.h
index adbece2..d40dfa9 100644
--- a/cast/common/mdns/mdns_records.h
+++ b/cast/common/mdns/mdns_records.h
@@ -351,9 +351,9 @@
MdnsMessage() = default;
// Constructs a message with ID, flags and empty question, answer, authority
// and additional record collections.
- MdnsMessage(uint16_t id, uint16_t flags);
+ MdnsMessage(uint16_t id, MessageType type);
MdnsMessage(uint16_t id,
- uint16_t flags,
+ MessageType type,
std::vector<MdnsQuestion> questions,
std::vector<MdnsRecord> answers,
std::vector<MdnsRecord> authority_records,
@@ -375,7 +375,7 @@
size_t MaxWireSize() const;
uint16_t id() const { return id_; }
- uint16_t flags() const { return flags_; }
+ MessageType type() const { return type_; }
const std::vector<MdnsQuestion>& questions() const { return questions_; }
const std::vector<MdnsRecord>& answers() const { return answers_; }
const std::vector<MdnsRecord>& authority_records() const {
@@ -389,7 +389,7 @@
// The mDNS header is 12 bytes long
size_t max_wire_size_ = sizeof(Header);
uint16_t id_ = 0;
- uint16_t flags_ = 0;
+ MessageType type_ = MessageType::Query;
std::vector<MdnsQuestion> questions_;
std::vector<MdnsRecord> answers_;
std::vector<MdnsRecord> authority_records_;
diff --git a/cast/common/mdns/mdns_records_unittest.cc b/cast/common/mdns/mdns_records_unittest.cc
index c82a81e..87e0c4c 100644
--- a/cast/common/mdns/mdns_records_unittest.cc
+++ b/cast/common/mdns/mdns_records_unittest.cc
@@ -387,7 +387,7 @@
MdnsMessage message1;
EXPECT_EQ(message1.MaxWireSize(), UINT64_C(12));
EXPECT_EQ(message1.id(), UINT16_C(0));
- EXPECT_EQ(message1.flags(), UINT16_C(0));
+ EXPECT_EQ(message1.type(), MessageType::Query);
EXPECT_EQ(message1.questions().size(), UINT64_C(0));
EXPECT_EQ(message1.answers().size(), UINT64_C(0));
EXPECT_EQ(message1.authority_records().size(), UINT64_C(0));
@@ -402,10 +402,10 @@
MdnsRecord record3(DomainName{"record3"}, DnsType::kPTR, DnsClass::kIN, false,
120, PtrRecordRdata(DomainName{"device", "local"}));
- MdnsMessage message2(123, 0x8400);
+ MdnsMessage message2(123, MessageType::Response);
EXPECT_EQ(message2.MaxWireSize(), UINT64_C(12));
EXPECT_EQ(message2.id(), UINT16_C(123));
- EXPECT_EQ(message2.flags(), UINT16_C(0x8400));
+ EXPECT_EQ(message2.type(), MessageType::Response);
EXPECT_EQ(message2.questions().size(), UINT64_C(0));
EXPECT_EQ(message2.answers().size(), UINT64_C(0));
EXPECT_EQ(message2.authority_records().size(), UINT64_C(0));
@@ -427,10 +427,10 @@
EXPECT_EQ(message2.authority_records()[0], record2);
EXPECT_EQ(message2.additional_records()[0], record3);
- MdnsMessage message3(123, 0x8400, std::vector<MdnsQuestion>{question},
- std::vector<MdnsRecord>{record1},
- std::vector<MdnsRecord>{record2},
- std::vector<MdnsRecord>{record3});
+ MdnsMessage message3(
+ 123, MessageType::Response, std::vector<MdnsQuestion>{question},
+ std::vector<MdnsRecord>{record1}, std::vector<MdnsRecord>{record2},
+ std::vector<MdnsRecord>{record3});
EXPECT_EQ(message3.MaxWireSize(), UINT64_C(118));
ASSERT_EQ(message3.questions().size(), UINT64_C(1));
@@ -454,38 +454,38 @@
MdnsRecord record3(DomainName{"record3"}, DnsType::kPTR, DnsClass::kIN, false,
120, PtrRecordRdata(DomainName{"device", "local"}));
- MdnsMessage message1(123, 0x8400, std::vector<MdnsQuestion>{question},
+ MdnsMessage message1(
+ 123, MessageType::Response, std::vector<MdnsQuestion>{question},
+ std::vector<MdnsRecord>{record1}, std::vector<MdnsRecord>{record2},
+ std::vector<MdnsRecord>{record3});
+ MdnsMessage message2(
+ 123, MessageType::Response, std::vector<MdnsQuestion>{question},
+ std::vector<MdnsRecord>{record1}, std::vector<MdnsRecord>{record2},
+ std::vector<MdnsRecord>{record3});
+ MdnsMessage message3(
+ 456, MessageType::Response, std::vector<MdnsQuestion>{question},
+ std::vector<MdnsRecord>{record1}, std::vector<MdnsRecord>{record2},
+ std::vector<MdnsRecord>{record3});
+ MdnsMessage message4(
+ 123, MessageType::Query, std::vector<MdnsQuestion>{question},
+ std::vector<MdnsRecord>{record1}, std::vector<MdnsRecord>{record2},
+ std::vector<MdnsRecord>{record3});
+ MdnsMessage message5(123, MessageType::Response, std::vector<MdnsQuestion>{},
std::vector<MdnsRecord>{record1},
std::vector<MdnsRecord>{record2},
std::vector<MdnsRecord>{record3});
- MdnsMessage message2(123, 0x8400, std::vector<MdnsQuestion>{question},
- std::vector<MdnsRecord>{record1},
- std::vector<MdnsRecord>{record2},
- std::vector<MdnsRecord>{record3});
- MdnsMessage message3(456, 0x8400, std::vector<MdnsQuestion>{question},
- std::vector<MdnsRecord>{record1},
- std::vector<MdnsRecord>{record2},
- std::vector<MdnsRecord>{record3});
- MdnsMessage message4(123, 0x400, std::vector<MdnsQuestion>{question},
- std::vector<MdnsRecord>{record1},
- std::vector<MdnsRecord>{record2},
- std::vector<MdnsRecord>{record3});
- MdnsMessage message5(123, 0x8400, std::vector<MdnsQuestion>{},
- std::vector<MdnsRecord>{record1},
- std::vector<MdnsRecord>{record2},
- std::vector<MdnsRecord>{record3});
- MdnsMessage message6(123, 0x8400, std::vector<MdnsQuestion>{question},
- std::vector<MdnsRecord>{},
- std::vector<MdnsRecord>{record2},
- std::vector<MdnsRecord>{record3});
- MdnsMessage message7(123, 0x8400, std::vector<MdnsQuestion>{question},
- std::vector<MdnsRecord>{record1},
- std::vector<MdnsRecord>{},
- std::vector<MdnsRecord>{record3});
- MdnsMessage message8(123, 0x8400, std::vector<MdnsQuestion>{question},
- std::vector<MdnsRecord>{record1},
- std::vector<MdnsRecord>{record2},
- std::vector<MdnsRecord>{});
+ MdnsMessage message6(
+ 123, MessageType::Response, std::vector<MdnsQuestion>{question},
+ std::vector<MdnsRecord>{}, std::vector<MdnsRecord>{record2},
+ std::vector<MdnsRecord>{record3});
+ MdnsMessage message7(
+ 123, MessageType::Response, std::vector<MdnsQuestion>{question},
+ std::vector<MdnsRecord>{record1}, std::vector<MdnsRecord>{},
+ std::vector<MdnsRecord>{record3});
+ MdnsMessage message8(
+ 123, MessageType::Response, std::vector<MdnsQuestion>{question},
+ std::vector<MdnsRecord>{record1}, std::vector<MdnsRecord>{record2},
+ std::vector<MdnsRecord>{});
EXPECT_EQ(message1, message2);
EXPECT_NE(message1, message3);
@@ -505,10 +505,10 @@
120, TxtRecordRdata{"foo=1", "bar=2"});
MdnsRecord record3(DomainName{"record3"}, DnsType::kPTR, DnsClass::kIN, false,
120, PtrRecordRdata(DomainName{"device", "local"}));
- MdnsMessage message(123, 0x8400, std::vector<MdnsQuestion>{question},
- std::vector<MdnsRecord>{record1},
- std::vector<MdnsRecord>{record2},
- std::vector<MdnsRecord>{record3});
+ MdnsMessage message(
+ 123, MessageType::Response, std::vector<MdnsQuestion>{question},
+ std::vector<MdnsRecord>{record1}, std::vector<MdnsRecord>{record2},
+ std::vector<MdnsRecord>{record3});
TestCopyAndMove(message);
}
diff --git a/cast/common/mdns/mdns_sender_unittest.cc b/cast/common/mdns/mdns_sender_unittest.cc
index 998395e..12c5b49 100644
--- a/cast/common/mdns/mdns_sender_unittest.cc
+++ b/cast/common/mdns/mdns_sender_unittest.cc
@@ -46,8 +46,8 @@
false,
120,
ARecordRdata(IPAddress{172, 0, 0, 1})),
- query_message_(1, 0x0400),
- response_message_(1, 0x8400),
+ query_message_(1, MessageType::Query),
+ response_message_(1, MessageType::Response),
ipv4_multicast_endpoint_{
.address = IPAddress(kDefaultMulticastGroupIPv4),
.port = kDefaultMulticastPort},
@@ -62,7 +62,7 @@
// clang-format off
const std::vector<uint8_t> kQueryBytes = {
0x00, 0x01, // ID = 1
- 0x04, 0x00, // FLAGS = AA
+ 0x00, 0x00, // FLAGS = None
0x00, 0x01, // Question count
0x00, 0x00, // Answer count
0x00, 0x00, // Authority count
@@ -149,7 +149,7 @@
}
TEST_F(MdnsSenderTest, MessageTooBig) {
- MdnsMessage big_message_(1, 0x0400);
+ MdnsMessage big_message_(1, MessageType::Query);
for (size_t i = 0; i < 100; ++i) {
big_message_.AddQuestion(a_question_);
big_message_.AddAnswer(a_record_);
diff --git a/cast/common/mdns/mdns_writer.cc b/cast/common/mdns/mdns_writer.cc
index 79d1017..a6ae96f 100644
--- a/cast/common/mdns/mdns_writer.cc
+++ b/cast/common/mdns/mdns_writer.cc
@@ -228,7 +228,7 @@
Cursor cursor(this);
Header header;
header.id = message.id();
- header.flags = message.flags();
+ header.flags = MakeFlags(message.type());
header.question_count = message.questions().size();
header.answer_count = message.answers().size();
header.authority_record_count = message.authority_records().size();
diff --git a/cast/common/mdns/mdns_writer_unittest.cc b/cast/common/mdns/mdns_writer_unittest.cc
index bf8b255..a7bf676 100644
--- a/cast/common/mdns/mdns_writer_unittest.cc
+++ b/cast/common/mdns/mdns_writer_unittest.cc
@@ -365,7 +365,7 @@
// clang-format off
constexpr uint8_t kExpectedMessage[] = {
0x00, 0x01, // ID = 1
- 0x04, 0x00, // FLAGS = AA
+ 0x00, 0x00, // FLAGS = None
0x00, 0x01, // Question count
0x00, 0x00, // Answer count
0x00, 0x01, // Authority count
@@ -392,7 +392,7 @@
MdnsRecord auth_record(DomainName{"auth"}, DnsType::kTXT, DnsClass::kIN,
false, 120, TxtRecordRdata{"foo=1", "bar=2"});
- MdnsMessage message(1, 0x0400);
+ MdnsMessage message(1, MessageType::Query);
message.AddQuestion(question);
message.AddAuthorityRecord(auth_record);
@@ -410,7 +410,7 @@
MdnsRecord auth_record(DomainName{"auth"}, DnsType::kTXT, DnsClass::kIN, 120,
false, TxtRecordRdata{"foo=1", "bar=2"});
- MdnsMessage message(1, 0x0400);
+ MdnsMessage message(1, MessageType::Query);
message.AddQuestion(question);
message.AddAuthorityRecord(auth_record);
TestWriteEntryInsufficientBuffer(message);