EDNS0 clientsubnet (#3447)

* EDN0 ClientSubnet support

* Fix dispatcher

* EDNS0ClientSubnet unit tests

* Python2 fixes
diff --git a/scapy/layers/dns.py b/scapy/layers/dns.py
index 8ce7e75..757836c 100755
--- a/scapy/layers/dns.py
+++ b/scapy/layers/dns.py
@@ -8,25 +8,39 @@
 """
 
 from __future__ import absolute_import
+import operator
+import socket
 import struct
 import time
 import warnings
 
+from scapy.ansmachine import AnsweringMachine
+from scapy.base_classes import Net
 from scapy.config import conf
-from scapy.packet import Packet, bind_layers, NoPayload
+from scapy.compat import orb, raw, chb, bytes_encode, plain_str
+from scapy.error import log_runtime, warning, Scapy_Exception
+from scapy.packet import Packet, bind_layers, NoPayload, Raw
 from scapy.fields import BitEnumField, BitField, ByteEnumField, ByteField, \
     ConditionalField, Field, FieldLenField, FlagsField, IntField, \
     PacketListField, ShortEnumField, ShortField, StrField, \
-    StrLenField, MultipleTypeField, UTCTimeField
-from scapy.compat import orb, raw, chb, bytes_encode
-from scapy.ansmachine import AnsweringMachine
+    StrLenField, MultipleTypeField, UTCTimeField, I
 from scapy.sendrecv import sr1
+from scapy.pton_ntop import inet_ntop, inet_pton
+
 from scapy.layers.inet import IP, DestIPField, IPField, UDP, TCP
 from scapy.layers.inet6 import DestIP6Field, IP6Field
-from scapy.error import log_runtime, warning, Scapy_Exception
 import scapy.libs.six as six
 
 
+from scapy.compat import (
+    Any,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+)
+
+
 # https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-4
 dnstypes = {
     0: "ANY",
@@ -520,15 +534,33 @@
 
 # RFC 2671 - Extension Mechanisms for DNS (EDNS0)
 
+edns0types = {0: "Reserved", 1: "LLQ", 2: "UL", 3: "NSID", 4: "Reserved",
+              5: "PING", 8: "edns-client-subnet"}
+
+
 class EDNS0TLV(Packet):
     name = "DNS EDNS0 TLV"
-    fields_desc = [ShortEnumField("optcode", 0, {0: "Reserved", 1: "LLQ", 2: "UL", 3: "NSID", 4: "Reserved", 5: "PING"}),  # noqa: E501
+    fields_desc = [ShortEnumField("optcode", 0, edns0types),
                    FieldLenField("optlen", None, "optdata", fmt="H"),
-                   StrLenField("optdata", "", length_from=lambda pkt: pkt.optlen)]  # noqa: E501
+                   StrLenField("optdata", "",
+                               length_from=lambda pkt: pkt.optlen)]
 
     def extract_padding(self, p):
+        # type: (bytes) -> Tuple[bytes, Optional[bytes]]
         return "", p
 
+    @classmethod
+    def dispatch_hook(cls, _pkt=None, *args, **kargs):
+        # type: (Optional[bytes], *Any, **Any) -> Type[Packet]
+        if _pkt is None:
+            return EDNS0TLV
+        if len(_pkt) < 2:
+            return Raw
+        edns0type = struct.unpack("!H", _pkt[:2])[0]
+        if edns0type == 8:
+            return EDNS0ClientSubnet
+        return EDNS0TLV
+
 
 class DNSRROPT(InheritOriginDNSStrPacket):
     name = "DNS OPT Resource Record"
@@ -541,7 +573,88 @@
                    BitEnumField("z", 32768, 16, {32768: "D0"}),
                    # D0 means DNSSEC OK from RFC 3225
                    FieldLenField("rdlen", None, length_of="rdata", fmt="H"),
-                   PacketListField("rdata", [], EDNS0TLV, length_from=lambda pkt: pkt.rdlen)]  # noqa: E501
+                   PacketListField("rdata", [], EDNS0TLV,
+                                   length_from=lambda pkt: pkt.rdlen)]
+
+
+# RFC 7871 - Client Subnet in DNS Queries
+
+class ClientSubnetv4(StrLenField):
+    af_familly = socket.AF_INET
+    af_length = 32
+    af_default = b"\xc0"  # 192.0.0.0
+
+    def getfield(self, pkt, s):
+        # type: (Packet, bytes) -> Tuple[bytes, I]
+        sz = operator.floordiv(self.length_from(pkt), 8)
+        sz = min(sz, operator.floordiv(self.af_length, 8))
+        return s[sz:], self.m2i(pkt, s[:sz])
+
+    def m2i(self, pkt, x):
+        # type: (Optional[Packet], bytes) -> str
+        padding = self.af_length - self.length_from(pkt)
+        if padding:
+            x += b"\x00" * operator.floordiv(padding, 8)
+        x = x[: operator.floordiv(self.af_length, 8)]
+        return inet_ntop(self.af_familly, x)
+
+    def _pack_subnet(self, subnet):
+        # type: (bytes) -> bytes
+        packed_subnet = inet_pton(self.af_familly, plain_str(subnet))
+        for i in list(range(operator.floordiv(self.af_length, 8)))[::-1]:
+            if orb(packed_subnet[i]) != 0:
+                i += 1
+                break
+        return packed_subnet[:i]
+
+    def i2m(self, pkt, x):
+        # type: (Optional[Packet], Optional[Union[str, Net]]) -> bytes
+        if x is None:
+            return self.af_default
+        try:
+            return self._pack_subnet(x)
+        except (OSError, socket.error):
+            pkt.family = 2
+            return ClientSubnetv6("", "")._pack_subnet(x)
+
+    def i2len(self, pkt, x):
+        # type: (Packet, Any) -> int
+        if x is None:
+            return 1
+        try:
+            return len(self._pack_subnet(x))
+        except (OSError, socket.error):
+            pkt.family = 2
+            return len(ClientSubnetv6("", "")._pack_subnet(x))
+
+
+class ClientSubnetv6(ClientSubnetv4):
+    af_familly = socket.AF_INET6
+    af_length = 128
+    af_default = b"\x20"  # 2000::
+
+
+class EDNS0ClientSubnet(Packet):
+    name = "DNS EDNS0 Client Subnet"
+    fields_desc = [ShortEnumField("optcode", 8, edns0types),
+                   FieldLenField("optlen", None, "address", fmt="H",
+                                 adjust=lambda pkt, x: x + 4),
+                   ShortField("family", 1),
+                   FieldLenField("source_plen", None,
+                                 length_of="address",
+                                 fmt="B",
+                                 adjust=lambda pkt, x: x * 8),
+                   ByteField("scope_plen", 0),
+                   MultipleTypeField(
+                       [(ClientSubnetv4("address", "192.168.0.0",
+                         length_from=lambda p: p.source_plen),
+                         lambda pkt: pkt.family == 1),
+                        (ClientSubnetv6("address", "2001:db8::",
+                         length_from=lambda p: p.source_plen),
+                         lambda pkt: pkt.family == 2)],
+                       ClientSubnetv4("address", "192.168.0.0",
+                                      length_from=lambda p: p.source_plen))]
+
 
 # RFC 4034 - Resource Records for the DNS Security Extensions
 
diff --git a/test/scapy/layers/dns_edns0.uts b/test/scapy/layers/dns_edns0.uts
index d35871b..143957d 100644
--- a/test/scapy/layers/dns_edns0.uts
+++ b/test/scapy/layers/dns_edns0.uts
@@ -69,3 +69,23 @@
     len(r.ar) and DNSRROPT in r.ar and len(r.ar[DNSRROPT].rdata) and len([x for x in r.ar[DNSRROPT].rdata if x.optcode == 3])
 
 retry_test(_test)
+
+
++ EDNS0 - Client Subnet
+
+= Basic instantiation & dissection
+
+raw_d = b'\x00\x00)\x10\x00\x00\x00\x00\x00\x00\n\x00\x08\x00\x06\x00\x01\x10\x00\xc0\xa8'
+
+d = DNSRROPT(z=0, rdata=[EDNS0ClientSubnet()])
+assert raw(d) == raw_d
+
+d = DNSRROPT(raw_d)
+assert EDNS0ClientSubnet in d.rdata[0] and d.rdata[0].family == 1 and d.rdata[0].address == "192.168.0.0"
+
+raw_d  = b'\x00\x00)\x10\x00\x00\x00\x00\x00\x00\x0c\x00\x08\x00\x08\x00\x02 \x00 \x01\r\xb8'
+d = DNSRROPT(z=0, rdata=[EDNS0ClientSubnet(address="2001:db8::")])
+assert raw(d) == raw_d
+
+d = DNSRROPT(raw_d)
+assert EDNS0ClientSubnet in d.rdata[0] and d.rdata[0].family == 2 and d.rdata[0].address == "2001:db8::"