Convert mDNS header flags to MessageType

Change-Id: I0d768a97324fbca0ce0abff89330a1f1cfa362f8
Reviewed-on: https://chromium-review.googlesource.com/c/openscreen/+/1715979
Commit-Queue: Max Yakimakha <yakimakha@chromium.org>
Reviewed-by: Peter Thatcher <pthatcher@google.com>
diff --git a/cast/common/mdns/mdns_constants.h b/cast/common/mdns/mdns_constants.h
index d0ff6c6..053f70a 100644
--- a/cast/common/mdns/mdns_constants.h
+++ b/cast/common/mdns/mdns_constants.h
@@ -131,41 +131,33 @@
   uint16_t additional_record_count;
 };
 
-// TODO(mayaki): Here and below consider converting constants to members of
-// enum classes.
+static_assert(sizeof(Header) == 12, "Size of mDNS header must be 12 bytes.");
 
-// DNS Header flags. All flags are formatted to mask directly onto FLAG header
-// field in network-byte order.
+enum class MessageType {
+  Query = 0,
+  Response = 1,
+};
+
 constexpr uint16_t kFlagResponse = 0x8000;
 constexpr uint16_t kFlagAA = 0x0400;
 constexpr uint16_t kFlagTC = 0x0200;
-constexpr uint16_t kFlagRD = 0x0100;
-constexpr uint16_t kFlagRA = 0x0080;
-constexpr uint16_t kFlagZ = 0x0040;  // Unused field
-constexpr uint16_t kFlagAD = 0x0020;
-constexpr uint16_t kFlagCD = 0x0010;
-
-// DNS Header OPCODE mask and values. The mask is formatted to mask directly
-// onto FLAG header field in network-byte order. The values are formatted after
-// shifting into correct position.
 constexpr uint16_t kOpcodeMask = 0x7800;
-constexpr uint8_t kOpcodeQUERY = 0;
-constexpr uint8_t kOpcodeIQUERY = 1;
-constexpr uint8_t kOpcodeSTATUS = 2;
-constexpr uint8_t kOpcodeUNASSIGNED = 3;  // Unused for now
-constexpr uint8_t kOpcodeNOTIFY = 4;
-constexpr uint8_t kOpcodeUPDATE = 5;
-
-// DNS Header RCODE mask and values. The mask is formatted to mask directly onto
-// FLAG header field in network-byte order. The values are formatted after
-// shifting into correct position.
 constexpr uint16_t kRcodeMask = 0x000F;
-constexpr uint8_t kRcodeNOERROR = 0;
-constexpr uint8_t kRcodeFORMERR = 1;
-constexpr uint8_t kRcodeSERVFAIL = 2;
-constexpr uint8_t kRcodeNXDOMAIN = 3;
-constexpr uint8_t kRcodeNOTIMP = 4;
-constexpr uint8_t kRcodeREFUSED = 5;
+
+constexpr MessageType GetMessageType(uint16_t flags) {
+  // RFC 6762 Section 18.2
+  return (flags & kFlagResponse) ? MessageType::Response : MessageType::Query;
+}
+
+constexpr uint16_t MakeFlags(MessageType type) {
+  // RFC 6762 Section 18.2 and Section 18.4
+  return (type == MessageType::Response) ? (kFlagResponse | kFlagAA) : 0;
+}
+
+constexpr bool IsValidFlagsSection(uint16_t flags) {
+  // RFC 6762 Section 18.3 and Section 18.11
+  return (flags & (kOpcodeMask | kRcodeMask)) == 0;
+}
 
 // ============================================================================
 // Domain Name
diff --git a/cast/common/mdns/mdns_reader.cc b/cast/common/mdns/mdns_reader.cc
index ae4443d..36bc3ed 100644
--- a/cast/common/mdns/mdns_reader.cc
+++ b/cast/common/mdns/mdns_reader.cc
@@ -237,8 +237,12 @@
       Read(header.answer_count, &answers) &&
       Read(header.authority_record_count, &authority_records) &&
       Read(header.additional_record_count, &additional_records)) {
-    *out = MdnsMessage(header.id, header.flags, questions, answers,
-                       authority_records, additional_records);
+    // TODO(yakimakha): Skip messages with non-zero opcode and rcode.
+    // One way to do this is to change the method signature to return
+    // ErrorOr<MdnsMessage> and return different error codes for failure to read
+    // and for messages that were read successfully but are non-conforming.
+    *out = MdnsMessage(header.id, GetMessageType(header.flags), questions,
+                       answers, authority_records, additional_records);
     cursor.Commit();
     return true;
   }
diff --git a/cast/common/mdns/mdns_reader_unittest.cc b/cast/common/mdns/mdns_reader_unittest.cc
index 5153806..5fdf4d4 100644
--- a/cast/common/mdns/mdns_reader_unittest.cc
+++ b/cast/common/mdns/mdns_reader_unittest.cc
@@ -563,9 +563,10 @@
                      120, PtrRecordRdata(DomainName{"testing", "local"}));
   MdnsRecord record2(DomainName{"record2"}, DnsType::kA, DnsClass::kIN, false,
                      120, ARecordRdata(IPAddress{172, 0, 0, 1}));
-  MdnsMessage message(
-      1, 0x8400, std::vector<MdnsQuestion>{}, std::vector<MdnsRecord>{record1},
-      std::vector<MdnsRecord>{}, std::vector<MdnsRecord>{record2});
+  MdnsMessage message(1, MessageType::Response, std::vector<MdnsQuestion>{},
+                      std::vector<MdnsRecord>{record1},
+                      std::vector<MdnsRecord>{},
+                      std::vector<MdnsRecord>{record2});
   TestReadEntrySucceeds(kTestMessage, sizeof(kTestMessage), message);
 }
 
diff --git a/cast/common/mdns/mdns_receiver.cc b/cast/common/mdns/mdns_receiver.cc
index 72b6683..e3c3251 100644
--- a/cast/common/mdns/mdns_receiver.cc
+++ b/cast/common/mdns/mdns_receiver.cc
@@ -52,8 +52,7 @@
   if (!reader.Read(&message)) {
     return;
   }
-  // TODO(yakimakha): make flags a proper type and hide bit manipulation
-  if ((message.flags() & kFlagResponse) != 0) {
+  if (message.type() == MessageType::Response) {
     delegate_->OnResponseReceived(message, packet.source());
   } else {
     delegate_->OnQueryReceived(message, packet.source());
diff --git a/cast/common/mdns/mdns_receiver_unittest.cc b/cast/common/mdns/mdns_receiver_unittest.cc
index 5c056e0..45233d7 100644
--- a/cast/common/mdns/mdns_receiver_unittest.cc
+++ b/cast/common/mdns/mdns_receiver_unittest.cc
@@ -38,7 +38,7 @@
   // clang-format off
   const std::vector<uint8_t> kQueryBytes = {
       0x00, 0x01,  // ID = 1
-      0x04, 0x00,  // FLAGS = AA
+      0x00, 0x00,  // FLAGS = None
       0x00, 0x01,  // Question count
       0x00, 0x00,  // Answer count
       0x00, 0x00,  // Authority count
@@ -63,7 +63,7 @@
 
   MdnsQuestion question(DomainName{"testing", "local"}, DnsType::kA,
                         DnsClass::kIN, false);
-  MdnsMessage message(1, 0x0400);
+  MdnsMessage message(1, MessageType::Query);
   message.AddQuestion(question);
 
   UdpPacket packet(kQueryBytes.size());
@@ -119,7 +119,7 @@
 
   MdnsRecord record(DomainName{"testing", "local"}, DnsType::kA, DnsClass::kIN,
                     false, 120, ARecordRdata(IPAddress{172, 0, 0, 1}));
-  MdnsMessage message(1, 0x8400);
+  MdnsMessage message(1, MessageType::Response);
   message.AddAnswer(record);
 
   UdpPacket packet(kResponseBytes.size());
diff --git a/cast/common/mdns/mdns_records.cc b/cast/common/mdns/mdns_records.cc
index 3c26df1..7775b78 100644
--- a/cast/common/mdns/mdns_records.cc
+++ b/cast/common/mdns/mdns_records.cc
@@ -31,11 +31,8 @@
 }
 
 bool DomainName::operator==(const DomainName& rhs) const {
-  auto predicate = [](const std::string& left, const std::string& right) {
-    return absl::EqualsIgnoreCase(left, right);
-  };
   return std::equal(labels_.begin(), labels_.end(), rhs.labels_.begin(),
-                    rhs.labels_.end(), predicate);
+                    rhs.labels_.end(), absl::EqualsIgnoreCase);
 }
 
 bool DomainName::operator!=(const DomainName& rhs) const {
@@ -232,17 +229,17 @@
   return name_.MaxWireSize() + sizeof(type_) + sizeof(record_class_);
 }
 
-MdnsMessage::MdnsMessage(uint16_t id, uint16_t flags)
-    : id_(id), flags_(flags) {}
+MdnsMessage::MdnsMessage(uint16_t id, MessageType type)
+    : id_(id), type_(type) {}
 
 MdnsMessage::MdnsMessage(uint16_t id,
-                         uint16_t flags,
+                         MessageType type,
                          std::vector<MdnsQuestion> questions,
                          std::vector<MdnsRecord> answers,
                          std::vector<MdnsRecord> authority_records,
                          std::vector<MdnsRecord> additional_records)
     : id_(id),
-      flags_(flags),
+      type_(type),
       questions_(std::move(questions)),
       answers_(std::move(answers)),
       authority_records_(std::move(authority_records)),
@@ -267,8 +264,8 @@
 }
 
 bool MdnsMessage::operator==(const MdnsMessage& rhs) const {
-  return id_ == rhs.id_ && flags_ == rhs.flags_ &&
-         questions_ == rhs.questions_ && answers_ == rhs.answers_ &&
+  return id_ == rhs.id_ && type_ == rhs.type_ && questions_ == rhs.questions_ &&
+         answers_ == rhs.answers_ &&
          authority_records_ == rhs.authority_records_ &&
          additional_records_ == rhs.additional_records_;
 }
diff --git a/cast/common/mdns/mdns_records.h b/cast/common/mdns/mdns_records.h
index adbece2..d40dfa9 100644
--- a/cast/common/mdns/mdns_records.h
+++ b/cast/common/mdns/mdns_records.h
@@ -351,9 +351,9 @@
   MdnsMessage() = default;
   // Constructs a message with ID, flags and empty question, answer, authority
   // and additional record collections.
-  MdnsMessage(uint16_t id, uint16_t flags);
+  MdnsMessage(uint16_t id, MessageType type);
   MdnsMessage(uint16_t id,
-              uint16_t flags,
+              MessageType type,
               std::vector<MdnsQuestion> questions,
               std::vector<MdnsRecord> answers,
               std::vector<MdnsRecord> authority_records,
@@ -375,7 +375,7 @@
 
   size_t MaxWireSize() const;
   uint16_t id() const { return id_; }
-  uint16_t flags() const { return flags_; }
+  MessageType type() const { return type_; }
   const std::vector<MdnsQuestion>& questions() const { return questions_; }
   const std::vector<MdnsRecord>& answers() const { return answers_; }
   const std::vector<MdnsRecord>& authority_records() const {
@@ -389,7 +389,7 @@
   // The mDNS header is 12 bytes long
   size_t max_wire_size_ = sizeof(Header);
   uint16_t id_ = 0;
-  uint16_t flags_ = 0;
+  MessageType type_ = MessageType::Query;
   std::vector<MdnsQuestion> questions_;
   std::vector<MdnsRecord> answers_;
   std::vector<MdnsRecord> authority_records_;
diff --git a/cast/common/mdns/mdns_records_unittest.cc b/cast/common/mdns/mdns_records_unittest.cc
index c82a81e..87e0c4c 100644
--- a/cast/common/mdns/mdns_records_unittest.cc
+++ b/cast/common/mdns/mdns_records_unittest.cc
@@ -387,7 +387,7 @@
   MdnsMessage message1;
   EXPECT_EQ(message1.MaxWireSize(), UINT64_C(12));
   EXPECT_EQ(message1.id(), UINT16_C(0));
-  EXPECT_EQ(message1.flags(), UINT16_C(0));
+  EXPECT_EQ(message1.type(), MessageType::Query);
   EXPECT_EQ(message1.questions().size(), UINT64_C(0));
   EXPECT_EQ(message1.answers().size(), UINT64_C(0));
   EXPECT_EQ(message1.authority_records().size(), UINT64_C(0));
@@ -402,10 +402,10 @@
   MdnsRecord record3(DomainName{"record3"}, DnsType::kPTR, DnsClass::kIN, false,
                      120, PtrRecordRdata(DomainName{"device", "local"}));
 
-  MdnsMessage message2(123, 0x8400);
+  MdnsMessage message2(123, MessageType::Response);
   EXPECT_EQ(message2.MaxWireSize(), UINT64_C(12));
   EXPECT_EQ(message2.id(), UINT16_C(123));
-  EXPECT_EQ(message2.flags(), UINT16_C(0x8400));
+  EXPECT_EQ(message2.type(), MessageType::Response);
   EXPECT_EQ(message2.questions().size(), UINT64_C(0));
   EXPECT_EQ(message2.answers().size(), UINT64_C(0));
   EXPECT_EQ(message2.authority_records().size(), UINT64_C(0));
@@ -427,10 +427,10 @@
   EXPECT_EQ(message2.authority_records()[0], record2);
   EXPECT_EQ(message2.additional_records()[0], record3);
 
-  MdnsMessage message3(123, 0x8400, std::vector<MdnsQuestion>{question},
-                       std::vector<MdnsRecord>{record1},
-                       std::vector<MdnsRecord>{record2},
-                       std::vector<MdnsRecord>{record3});
+  MdnsMessage message3(
+      123, MessageType::Response, std::vector<MdnsQuestion>{question},
+      std::vector<MdnsRecord>{record1}, std::vector<MdnsRecord>{record2},
+      std::vector<MdnsRecord>{record3});
 
   EXPECT_EQ(message3.MaxWireSize(), UINT64_C(118));
   ASSERT_EQ(message3.questions().size(), UINT64_C(1));
@@ -454,38 +454,38 @@
   MdnsRecord record3(DomainName{"record3"}, DnsType::kPTR, DnsClass::kIN, false,
                      120, PtrRecordRdata(DomainName{"device", "local"}));
 
-  MdnsMessage message1(123, 0x8400, std::vector<MdnsQuestion>{question},
+  MdnsMessage message1(
+      123, MessageType::Response, std::vector<MdnsQuestion>{question},
+      std::vector<MdnsRecord>{record1}, std::vector<MdnsRecord>{record2},
+      std::vector<MdnsRecord>{record3});
+  MdnsMessage message2(
+      123, MessageType::Response, std::vector<MdnsQuestion>{question},
+      std::vector<MdnsRecord>{record1}, std::vector<MdnsRecord>{record2},
+      std::vector<MdnsRecord>{record3});
+  MdnsMessage message3(
+      456, MessageType::Response, std::vector<MdnsQuestion>{question},
+      std::vector<MdnsRecord>{record1}, std::vector<MdnsRecord>{record2},
+      std::vector<MdnsRecord>{record3});
+  MdnsMessage message4(
+      123, MessageType::Query, std::vector<MdnsQuestion>{question},
+      std::vector<MdnsRecord>{record1}, std::vector<MdnsRecord>{record2},
+      std::vector<MdnsRecord>{record3});
+  MdnsMessage message5(123, MessageType::Response, std::vector<MdnsQuestion>{},
                        std::vector<MdnsRecord>{record1},
                        std::vector<MdnsRecord>{record2},
                        std::vector<MdnsRecord>{record3});
-  MdnsMessage message2(123, 0x8400, std::vector<MdnsQuestion>{question},
-                       std::vector<MdnsRecord>{record1},
-                       std::vector<MdnsRecord>{record2},
-                       std::vector<MdnsRecord>{record3});
-  MdnsMessage message3(456, 0x8400, std::vector<MdnsQuestion>{question},
-                       std::vector<MdnsRecord>{record1},
-                       std::vector<MdnsRecord>{record2},
-                       std::vector<MdnsRecord>{record3});
-  MdnsMessage message4(123, 0x400, std::vector<MdnsQuestion>{question},
-                       std::vector<MdnsRecord>{record1},
-                       std::vector<MdnsRecord>{record2},
-                       std::vector<MdnsRecord>{record3});
-  MdnsMessage message5(123, 0x8400, std::vector<MdnsQuestion>{},
-                       std::vector<MdnsRecord>{record1},
-                       std::vector<MdnsRecord>{record2},
-                       std::vector<MdnsRecord>{record3});
-  MdnsMessage message6(123, 0x8400, std::vector<MdnsQuestion>{question},
-                       std::vector<MdnsRecord>{},
-                       std::vector<MdnsRecord>{record2},
-                       std::vector<MdnsRecord>{record3});
-  MdnsMessage message7(123, 0x8400, std::vector<MdnsQuestion>{question},
-                       std::vector<MdnsRecord>{record1},
-                       std::vector<MdnsRecord>{},
-                       std::vector<MdnsRecord>{record3});
-  MdnsMessage message8(123, 0x8400, std::vector<MdnsQuestion>{question},
-                       std::vector<MdnsRecord>{record1},
-                       std::vector<MdnsRecord>{record2},
-                       std::vector<MdnsRecord>{});
+  MdnsMessage message6(
+      123, MessageType::Response, std::vector<MdnsQuestion>{question},
+      std::vector<MdnsRecord>{}, std::vector<MdnsRecord>{record2},
+      std::vector<MdnsRecord>{record3});
+  MdnsMessage message7(
+      123, MessageType::Response, std::vector<MdnsQuestion>{question},
+      std::vector<MdnsRecord>{record1}, std::vector<MdnsRecord>{},
+      std::vector<MdnsRecord>{record3});
+  MdnsMessage message8(
+      123, MessageType::Response, std::vector<MdnsQuestion>{question},
+      std::vector<MdnsRecord>{record1}, std::vector<MdnsRecord>{record2},
+      std::vector<MdnsRecord>{});
 
   EXPECT_EQ(message1, message2);
   EXPECT_NE(message1, message3);
@@ -505,10 +505,10 @@
                      120, TxtRecordRdata{"foo=1", "bar=2"});
   MdnsRecord record3(DomainName{"record3"}, DnsType::kPTR, DnsClass::kIN, false,
                      120, PtrRecordRdata(DomainName{"device", "local"}));
-  MdnsMessage message(123, 0x8400, std::vector<MdnsQuestion>{question},
-                      std::vector<MdnsRecord>{record1},
-                      std::vector<MdnsRecord>{record2},
-                      std::vector<MdnsRecord>{record3});
+  MdnsMessage message(
+      123, MessageType::Response, std::vector<MdnsQuestion>{question},
+      std::vector<MdnsRecord>{record1}, std::vector<MdnsRecord>{record2},
+      std::vector<MdnsRecord>{record3});
   TestCopyAndMove(message);
 }
 
diff --git a/cast/common/mdns/mdns_sender_unittest.cc b/cast/common/mdns/mdns_sender_unittest.cc
index 998395e..12c5b49 100644
--- a/cast/common/mdns/mdns_sender_unittest.cc
+++ b/cast/common/mdns/mdns_sender_unittest.cc
@@ -46,8 +46,8 @@
                   false,
                   120,
                   ARecordRdata(IPAddress{172, 0, 0, 1})),
-        query_message_(1, 0x0400),
-        response_message_(1, 0x8400),
+        query_message_(1, MessageType::Query),
+        response_message_(1, MessageType::Response),
         ipv4_multicast_endpoint_{
             .address = IPAddress(kDefaultMulticastGroupIPv4),
             .port = kDefaultMulticastPort},
@@ -62,7 +62,7 @@
   // clang-format off
   const std::vector<uint8_t> kQueryBytes = {
       0x00, 0x01,  // ID = 1
-      0x04, 0x00,  // FLAGS = AA
+      0x00, 0x00,  // FLAGS = None
       0x00, 0x01,  // Question count
       0x00, 0x00,  // Answer count
       0x00, 0x00,  // Authority count
@@ -149,7 +149,7 @@
 }
 
 TEST_F(MdnsSenderTest, MessageTooBig) {
-  MdnsMessage big_message_(1, 0x0400);
+  MdnsMessage big_message_(1, MessageType::Query);
   for (size_t i = 0; i < 100; ++i) {
     big_message_.AddQuestion(a_question_);
     big_message_.AddAnswer(a_record_);
diff --git a/cast/common/mdns/mdns_writer.cc b/cast/common/mdns/mdns_writer.cc
index 79d1017..a6ae96f 100644
--- a/cast/common/mdns/mdns_writer.cc
+++ b/cast/common/mdns/mdns_writer.cc
@@ -228,7 +228,7 @@
   Cursor cursor(this);
   Header header;
   header.id = message.id();
-  header.flags = message.flags();
+  header.flags = MakeFlags(message.type());
   header.question_count = message.questions().size();
   header.answer_count = message.answers().size();
   header.authority_record_count = message.authority_records().size();
diff --git a/cast/common/mdns/mdns_writer_unittest.cc b/cast/common/mdns/mdns_writer_unittest.cc
index bf8b255..a7bf676 100644
--- a/cast/common/mdns/mdns_writer_unittest.cc
+++ b/cast/common/mdns/mdns_writer_unittest.cc
@@ -365,7 +365,7 @@
   // clang-format off
   constexpr uint8_t kExpectedMessage[] = {
       0x00, 0x01,  // ID = 1
-      0x04, 0x00,  // FLAGS = AA
+      0x00, 0x00,  // FLAGS = None
       0x00, 0x01,  // Question count
       0x00, 0x00,  // Answer count
       0x00, 0x01,  // Authority count
@@ -392,7 +392,7 @@
   MdnsRecord auth_record(DomainName{"auth"}, DnsType::kTXT, DnsClass::kIN,
                          false, 120, TxtRecordRdata{"foo=1", "bar=2"});
 
-  MdnsMessage message(1, 0x0400);
+  MdnsMessage message(1, MessageType::Query);
   message.AddQuestion(question);
   message.AddAuthorityRecord(auth_record);
 
@@ -410,7 +410,7 @@
   MdnsRecord auth_record(DomainName{"auth"}, DnsType::kTXT, DnsClass::kIN, 120,
                          false, TxtRecordRdata{"foo=1", "bar=2"});
 
-  MdnsMessage message(1, 0x0400);
+  MdnsMessage message(1, MessageType::Query);
   message.AddQuestion(question);
   message.AddAuthorityRecord(auth_record);
   TestWriteEntryInsufficientBuffer(message);