Test that local IPv6 connectivity goes direct.

Change-Id: Ia7f78b040358d787a1cdd183c0517927b14c4054
diff --git a/net/test/mark_test.py b/net/test/mark_test.py
index f5a4fd1..8085ae0 100755
--- a/net/test/mark_test.py
+++ b/net/test/mark_test.py
@@ -8,7 +8,6 @@
 import re
 from socket import *  # pylint: disable=wildcard-import
 import struct
-import time
 import unittest
 
 from scapy import all as scapy
@@ -210,6 +209,23 @@
     cls._SetPacketTos(packet, PING_TOS)
     return ("ICMPv%d echo reply" % version, packet)
 
+  @classmethod
+  def NS(cls, srcaddr, tgtaddr, srcmac):
+    solicited = inet_pton(AF_INET6, tgtaddr)
+    last3bytes = tuple([ord(b) for b in solicited[-3:]])
+    solicited = "ff02::1:ff%02x:%02x%02x" % last3bytes
+    packet = (scapy.IPv6(src=srcaddr, dst=solicited) /
+              scapy.ICMPv6ND_NS(tgt=tgtaddr) /
+              scapy.ICMPv6NDOptSrcLLAddr(lladdr=srcmac))
+    return ("ICMPv6 NS", packet)
+
+  @classmethod
+  def NA(cls, srcaddr, dstaddr, srcmac):
+    packet = (scapy.IPv6(src=srcaddr, dst=dstaddr) /
+              scapy.ICMPv6ND_NA(tgt=srcaddr, R=0, S=1, O=1) /
+              scapy.ICMPv6NDOptDstLLAddr(lladdr=srcmac))
+    return ("ICMPv6 NA", packet)
+
 
 class RunAsUid(object):
 
@@ -285,6 +301,18 @@
     return {4: cls._MyIPv4Address(netid),
             6: cls._MyIPv6Address(netid)}[version]
 
+  @staticmethod
+  def IPv6Prefix(netid):
+    return "2001:db8:%02x::" % netid
+
+  @staticmethod
+  def GetRandomDestination(prefix):
+    if "." in prefix:
+      return prefix + "%d.%d" % (random.randint(0, 31), random.randint(0, 255))
+    else:
+      return prefix + "%x:%x" % (random.randint(0, 65535),
+                                 random.randint(0, 65535))
+
   @classmethod
   def CreateTunInterface(cls, netid):
     iface = cls.GetInterfaceName(netid)
@@ -316,7 +344,7 @@
           scapy.ICMPv6ND_RA(retranstimer=validity_ms,
                             routerlifetime=routerlifetime) /
           scapy.ICMPv6NDOptSrcLLAddr(lladdr=macaddr) /
-          scapy.ICMPv6NDOptPrefixInfo(prefix="2001:db8:%d::" % netid,
+          scapy.ICMPv6NDOptPrefixInfo(prefix=cls.IPv6Prefix(netid),
                                       prefixlen=64,
                                       L=1, A=1,
                                       validlifetime=validity,
@@ -467,21 +495,24 @@
       iface = ""
     s.setsockopt(SOL_SOCKET, SO_BINDTODEVICE, iface)
 
+  def ReceiveEtherPacketOn(self, netid, packet):
+    posix.write(self.tuns[netid].fileno(), str(packet))
+
   def ReceivePacketOn(self, netid, ip_packet):
     routermac = self.RouterMacAddress(netid)
     mymac = self.MyMacAddress(netid)
     packet = scapy.Ether(src=routermac, dst=mymac) / ip_packet
-    posix.write(self.tuns[netid].fileno(), str(packet))
+    self.ReceiveEtherPacketOn(netid, packet)
 
-  def ReadAllPacketsOn(self, netid):
+  def ReadAllPacketsOn(self, netid, include_multicast=False):
     packets = []
     while True:
       try:
         packet = posix.read(self.tuns[netid].fileno(), 4096)
         ether = scapy.Ether(packet)
-        # Skip multicast frames, i.e., frames where the first byte of the
-        # destination MAC address has 1 in the least-significant bit.
-        if not int(ether.dst.split(":")[0], 16) & 0x1:
+        # Multicast frames are frames where the first byte of the destination
+        # MAC address has 1 in the least-significant bit.
+        if include_multicast or not int(ether.dst.split(":")[0], 16) & 0x1:
           packets.append(ether.payload)
       except OSError, e:
         # EAGAIN means there are no more packets waiting.
@@ -498,48 +529,6 @@
     while waiting != 0:
       waiting = sum(len(self.ReadAllPacketsOn(netid)) for netid in self.NETIDS)
 
-
-class MarkTest(MultiNetworkTest):
-
-  # How many times to run packet reflection tests.
-  ITERATIONS = 5
-
-  # For convenience.
-  IPV4_ADDR = net_test.IPV4_ADDR
-  IPV6_ADDR = net_test.IPV6_ADDR
-  IPV4_PING = net_test.IPV4_PING
-  IPV6_PING = net_test.IPV6_PING
-
-  @classmethod
-  def setUpClass(cls):
-    super(MarkTest, cls).setUpClass()
-
-    # Open a port so we can observe SYN+ACKs. Since it's a dual-stack socket it
-    # will accept both IPv4 and IPv6 connections. We do this here instead of in
-    # each test so we can use the same socket every time. That way, if a kernel
-    # bug causes incoming packets to mark the listening socket instead of the
-    # accepted socket, the test will fail as soon as the next address/interface
-    # combination is tried.
-    cls.listenport = 1234
-    cls.listensocket = net_test.IPv6TCPSocket()
-    cls.listensocket.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
-    cls.listensocket.bind(("::", cls.listenport))
-    cls.listensocket.listen(100)
-
-  @classmethod
-  def _SetMarkReflectSysctls(cls, value):
-    cls.SetSysctl(IPV4_MARK_REFLECT_SYSCTL, value)
-    try:
-      cls.SetSysctl(IPV6_MARK_REFLECT_SYSCTL, value)
-    except IOError:
-      # This does not exist if we use the version of the patch that uses a
-      # common sysctl for IPv4 and IPv6.
-      pass
-
-  @classmethod
-  def _SetTCPMarkAcceptSysctl(cls, value):
-    cls.SetSysctl(TCP_MARK_ACCEPT_SYSCTL, value)
-
   def assertPacketMatches(self, expected, actual):
     # The expected packet is just a rough sketch of the packet we expect to
     # receive. For example, it doesn't contain fields we can't predict, such as
@@ -606,7 +595,7 @@
     except AssertionError:
       return False
 
-  def ExpectNoPacketsOn(self, netid, msg, expected):
+  def ExpectNoPacketsOn(self, netid, msg):
     packets = self.ReadAllPacketsOn(netid)
     if packets:
       firstpacket = str(packets[0]).encode("hex")
@@ -615,7 +604,16 @@
     self.assertFalse(packets, msg + ": unexpected packet: " + firstpacket)
 
   def ExpectPacketOn(self, netid, msg, expected):
-    packets = self.ReadAllPacketsOn(netid)
+    # To avoid confusion due to lots of ICMPv6 ND going on all the time, drop
+    # multicast packets unless the packet we expect to see is a multicast
+    # packet. For now the only tests that use this are IPv6.
+    ipv6 = expected.getlayer("IPv6")
+    if ipv6 and ipv6.dst.startswith("ff"):
+      include_multicast = True
+    else:
+      include_multicast = False
+
+    packets = self.ReadAllPacketsOn(netid, include_multicast=include_multicast)
     self.assertTrue(packets, msg + ": received no packets")
 
     # If we receive a packet that matches what we expected, return it.
@@ -633,6 +631,48 @@
       raise UnexpectedPacketError(
           "%s: diff with last packet:\n%s" % (msg, e.message))
 
+
+class MarkTest(MultiNetworkTest):
+
+  # How many times to run packet reflection tests.
+  ITERATIONS = 5
+
+  # For convenience.
+  IPV4_ADDR = net_test.IPV4_ADDR
+  IPV6_ADDR = net_test.IPV6_ADDR
+  IPV4_PING = net_test.IPV4_PING
+  IPV6_PING = net_test.IPV6_PING
+
+  @classmethod
+  def setUpClass(cls):
+    super(MarkTest, cls).setUpClass()
+
+    # Open a port so we can observe SYN+ACKs. Since it's a dual-stack socket it
+    # will accept both IPv4 and IPv6 connections. We do this here instead of in
+    # each test so we can use the same socket every time. That way, if a kernel
+    # bug causes incoming packets to mark the listening socket instead of the
+    # accepted socket, the test will fail as soon as the next address/interface
+    # combination is tried.
+    cls.listenport = 1234
+    cls.listensocket = net_test.IPv6TCPSocket()
+    cls.listensocket.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
+    cls.listensocket.bind(("::", cls.listenport))
+    cls.listensocket.listen(100)
+
+  @classmethod
+  def _SetMarkReflectSysctls(cls, value):
+    cls.SetSysctl(IPV4_MARK_REFLECT_SYSCTL, value)
+    try:
+      cls.SetSysctl(IPV6_MARK_REFLECT_SYSCTL, value)
+    except IOError:
+      # This does not exist if we use the version of the patch that uses a
+      # common sysctl for IPv4 and IPv6.
+      pass
+
+  @classmethod
+  def _SetTCPMarkAcceptSysctl(cls, value):
+    cls.SetSysctl(TCP_MARK_ACCEPT_SYSCTL, value)
+
   def setUp(self):
     self.ClearTunQueues()
 
@@ -672,7 +712,6 @@
         version, desc, self.GetInterfaceName(expected_netid))
     self.ExpectPacketOn(expected_netid, msg, expected)
 
-
   def CheckTCPSYNPacket(self, version, mark, uid, oif, dstaddr, expected_netid):
     s = self.BuildSocket(version, net_test.TCPSocket, mark, uid, oif)
 
@@ -838,7 +877,7 @@
                          packet, reply, msg)
             else:
               msg += ": Expecting no packets on %s" % iif
-              self.ExpectNoPacketsOn(iif_netid, msg, reply)
+              self.ExpectNoPacketsOn(iif_netid, msg)
 
   def SYNToClosedPort(self, *args):
     return Packets.SYN(999, *args)
@@ -951,6 +990,42 @@
         self.SendRA(netid)
       CheckIPv6Connectivity(True)
 
+  def testOnlinkCommunication(self):
+    """Checks that on-link communication goes direct and not through routers."""
+    for netid in self.tuns:
+      # Send a UDP packet to a random on-link destination.
+      s = net_test.UDPSocket(AF_INET6)
+      iface = self.GetInterfaceName(netid)
+      self.BindToDevice(s, iface)
+      # dstaddr can never be our address because GetRandomDestination only fills
+      # in the lower 32 bits, but our address has 0xff in the byte before that
+      # (since it's constructed from the EUI-64 and so has ff:fe in the middle).
+      dstaddr = self.GetRandomDestination(self.IPv6Prefix(netid))
+      s.sendto("hello", (dstaddr, 53))
+
+      # Expect an NS for that destination on the interface.
+      myaddr = self.MyAddress(6, netid)
+      mymac = self.MyMacAddress(netid)
+      desc, expected = Packets.NS(myaddr, dstaddr, mymac)
+      msg = "Sending UDP packet to on-link destination: expecting %s" % desc
+      self.ExpectPacketOn(netid, msg, expected)
+
+      # Send an NA.
+      tgtmac = "02:00:00:00:%02x:99" % netid
+      _, reply = Packets.NA(dstaddr, myaddr, tgtmac)
+      # Don't use ReceivePacketOn, since that uses the router's MAC address as
+      # the source. Instead, construct our own Ethernet header with source
+      # MAC of tgtmac.
+      reply = scapy.Ether(src=tgtmac, dst=mymac) / reply
+      self.ReceiveEtherPacketOn(netid, reply)
+
+      # Expect the kernel to send the original UDP packet now that the ND cache
+      # entry has been populated.
+      sport = s.getsockname()[1]
+      desc, expected = Packets.UDP(6, myaddr, dstaddr, sport=sport)
+      msg = "After NA response, expecting %s" % desc
+      self.ExpectPacketOn(netid, msg, expected)
+
   @unittest.skipUnless(False, "Known bug: routing tables are never deleted")
   def testNoLeftoverRoutes(self):
     def GetNumRoutes():
@@ -972,38 +1047,31 @@
   IPV6_PATHMTU = 61
   IPV6_DONTFRAG = 62
 
-  def GetRandomDestination(self, version):
-    if version == 4:
-      return "172.16.%d.%d" % (random.randint(0, 31), random.randint(0, 255))
-    else:
-      return "2001:db8::%x:%x" % (random.randint(0, 65535),
-                                  random.randint(0, 65535))
-
   def GetSocketMTU(self, s):
     ip6_mtuinfo = s.getsockopt(net_test.SOL_IPV6, self.IPV6_PATHMTU, 32)
     mtu = struct.unpack("=28sI", ip6_mtuinfo)
     return mtu[1]
 
   def testIPv6PMTU(self):
-    s = net_test.Socket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP)
-    s.setsockopt(net_test.SOL_IPV6, self.IPV6_DONTFRAG, 1)
-    s.setsockopt(net_test.SOL_IPV6, net_test.IPV6_RECVERR, 1)
-    netid = self.NETIDS[2]  # Just pick an arbitrary one.
+    for netid in self.tuns:
+      s = net_test.Socket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP)
+      s.setsockopt(net_test.SOL_IPV6, self.IPV6_DONTFRAG, 1)
+      s.setsockopt(net_test.SOL_IPV6, net_test.IPV6_RECVERR, 1)
 
-    srcaddr = self.MyAddress(6, netid)
-    dstaddr = self.GetRandomDestination(6)
-    intermediate = "2001:db8::1"
+      srcaddr = self.MyAddress(6, netid)
+      dstaddr = self.GetRandomDestination("2001:db8::")
+      intermediate = "2001:db8::1"
 
-    self.SetSocketMark(s, netid)  # So the packet has somewhere to go.
-    s.connect((dstaddr, 1234))
-    self.assertEquals(1500, self.GetSocketMTU(s))
+      self.SetSocketMark(s, netid)  # So the packet has somewhere to go.
+      s.connect((dstaddr, 1234))
+      self.assertEquals(1500, self.GetSocketMTU(s))
 
-    s.send(1400 * "a")
-    packets = self.ReadAllPacketsOn(netid)
-    self.assertEquals(1, len(packets))
-    toobig = Packets.ICMPPacketTooBig(6, intermediate, srcaddr, packets[0])[1]
-    self.ReceivePacketOn(netid, toobig)
-    self.assertEquals(1280, self.GetSocketMTU(s))
+      s.send(1400 * "a")
+      packets = self.ReadAllPacketsOn(netid)
+      self.assertEquals(1, len(packets))
+      toobig = Packets.ICMPPacketTooBig(6, intermediate, srcaddr, packets[0])[1]
+      self.ReceivePacketOn(netid, toobig)
+      self.assertEquals(1280, self.GetSocketMTU(s))
 
 
 if __name__ == "__main__":