Improve MarkTest.

1. Add TCP SYN+ACK tests including syncookies and checks that
   accepting connections succeeids and that the sockets returned
   by accept() are marked.
2. Mark the tests more robust with respect to extra packets by
   always explicitly expecting packets (including when testing
   outgoing kernel-generated packets) and looking for them
   anywhere in the queue instead of insisting they're the first
   packet in the queue.
3. Make the tests more robust by using random source port,
   disabling ICMP rate limits, setting SO_REUSEADDR, and
   clearing queues more reliably.
4. Move from 2 to 4 interfaces (mostly made possible by the
   robustness improvements above).
5. Use named constants instead of repeating the numbers in
   multiple places.

Change-Id: I596e557a7eea02ccf603c812a9b8ea6f5b2f95da
diff --git a/net/test/mark_test.py b/net/test/mark_test.py
index 775b441..d837662 100755
--- a/net/test/mark_test.py
+++ b/net/test/mark_test.py
@@ -1,9 +1,11 @@
 #!/usr/bin/python
 
-import fcntl
 import errno
+import fcntl
 import os
 import posix
+import random
+import re
 import struct
 import time
 import unittest
@@ -21,6 +23,21 @@
 
 AUTOCONF_TABLE_SYSCTL = "/proc/sys/net/ipv6/route/autoconf_table_offset"
 
+PING_IDENT = 0xff19
+PING_PAYLOAD = "foobarbaz"
+PING_SEQ = 3
+PING_TOS = 0x83
+
+TCP_SYN = 2
+TCP_RST = 4
+TCP_ACK = 16
+
+TCP_SEQ = 1692871236
+TCP_WINDOW = 14400
+
+UDP_PAYLOAD = "hello"
+
+
 class ConfigurationError(AssertionError):
   pass
 
@@ -32,6 +49,10 @@
 class Packets(object):
 
   @staticmethod
+  def RandomPort():
+    return random.randint(1025, 65535)
+
+  @staticmethod
   def _GetIpLayer(version):
     return {4: scapy.IP, 6: scapy.IPv6}[version]
 
@@ -45,19 +66,25 @@
       raise ValueError("Can't find ToS Field")
 
   @classmethod
-  def UdpPacket(self, version, srcaddr, dstaddr):
+  def UDP(self, version, srcaddr, dstaddr, sport=0):
     ip = self._GetIpLayer(version)
+    # Can't just use "if sport" because None has meaning (it means unspecified).
+    if sport == 0:
+      sport = self.RandomPort()
     return ("UDPv%d packet" % version,
             ip(src=srcaddr, dst=dstaddr) /
-            scapy.UDP(sport=999, dport=1234) / "hello")
+            scapy.UDP(sport=sport, dport=53) / UDP_PAYLOAD)
 
   @classmethod
-  def SYN(self, port, version, srcaddr, dstaddr):
+  def SYN(self, dport, version, srcaddr, dstaddr, sport=0, seq=TCP_SEQ):
     ip = self._GetIpLayer(version)
+    if sport == 0:
+      sport = self.RandomPort()
     return ("TCP SYN",
             ip(src=srcaddr, dst=dstaddr) /
-            scapy.TCP(sport=50999, dport=port, seq=1692871236, ack=0,
-                      flags=2, window=14400))
+            scapy.TCP(sport=sport, dport=dport,
+                      seq=seq, ack=0,
+                      flags=TCP_SYN, window=TCP_WINDOW))
 
   @classmethod
   def RST(self, version, srcaddr, dstaddr, packet):
@@ -67,7 +94,7 @@
             ip(src=srcaddr, dst=dstaddr) /
             scapy.TCP(sport=original.dport, dport=original.sport,
                       ack=original.seq + 1, seq=None,
-                      flags=20, window=None))
+                      flags=TCP_RST | TCP_ACK, window=TCP_WINDOW))
 
   @classmethod
   def SYNACK(self, version, srcaddr, dstaddr, packet):
@@ -77,7 +104,18 @@
             ip(src=srcaddr, dst=dstaddr) /
             scapy.TCP(sport=original.dport, dport=original.sport,
                       ack=original.seq + 1, seq=None,
-                      flags=18, window=None))
+                      flags=TCP_SYN | TCP_ACK, window=None))
+
+  @classmethod
+  def ACK(self, version, srcaddr, dstaddr, packet):
+    ip = self._GetIpLayer(version)
+    original = packet.getlayer("TCP")
+    was_syn = (original.flags & TCP_SYN) != 0
+    return ("TCP ACK",
+            ip(src=srcaddr, dst=dstaddr) /
+            scapy.TCP(sport=original.dport, dport=original.sport,
+                      ack=original.seq + was_syn, seq=original.ack,
+                      flags=TCP_ACK, window=TCP_WINDOW))
 
   @classmethod
   def ICMPPortUnreachable(self, version, srcaddr, dstaddr, packet):
@@ -97,26 +135,26 @@
     ip = self._GetIpLayer(version)
     icmp = {4: scapy.ICMP, 6: scapy.ICMPv6EchoRequest}[version]
     packet = (ip(src=srcaddr, dst=dstaddr) /
-              icmp(id=0xff19, seq=3) / "foobarbaz")
-    self._SetPacketTos(packet, 0x83)
+              icmp(id=PING_IDENT, seq=PING_SEQ) / PING_PAYLOAD)
+    self._SetPacketTos(packet, PING_TOS)
     return ("ICMPv%d echo" % version, packet)
 
   @classmethod
-  def ICMPReply(self, version, srcaddr, dstaddr, packet, tos=None):
+  def ICMPReply(self, version, srcaddr, dstaddr, packet):
     ip = self._GetIpLayer(version)
-
     # Scapy doesn't provide an ICMP echo reply constructor.
     icmpv4_reply = lambda **kwargs: scapy.ICMP(type=0, **kwargs)
     icmp = {4: icmpv4_reply, 6: scapy.ICMPv6EchoReply}[version]
     packet = (ip(src=srcaddr, dst=dstaddr) /
-              icmp(id=0xff19, seq=3) / "foobarbaz")
-    self._SetPacketTos(packet, 0x83)
+              icmp(id=PING_IDENT, seq=PING_SEQ) / PING_PAYLOAD)
+    self._SetPacketTos(packet, PING_TOS)
     return ("ICMPv%d echo" % version, packet)
 
 
 class MarkTest(net_test.NetworkTest):
 
-  NETIDS = [100, 200]
+  # Must be between 1 and 256, since we put them in MAC addresses and IIDs.
+  NETIDS = [100, 150, 200, 250]
 
   @staticmethod
   def _RouterMacAddress(netid):
@@ -135,19 +173,27 @@
     else:
       raise ValueError("Don't support IPv%s" % version)
 
-  @staticmethod
-  def _MyIPv4Address(netid):
+  @classmethod
+  def _MyIPv4Address(self, netid):
     return "10.0.%d.2" % netid
 
   @classmethod
+  def _MyIPv6Address(self, netid):
+    return net_test.GetLinkAddress(self._GetInterfaceName(netid), False)
+
+  @classmethod
+  def _MyAddress(self, version, netid):
+    return {4: self._MyIPv4Address(netid),
+            6: self._MyIPv6Address(netid)}[version]
+
+  @classmethod
   def _CreateTunInterface(self, netid):
     iface = self._GetInterfaceName(netid)
     f = open("/dev/net/tun", "r+b")
     ifr = struct.pack("16sH", iface, IFF_TAP | IFF_NO_PI)
-    ifr = ifr + "\x00" * (40 - len(ifr))
+    ifr += "\x00" * (40 - len(ifr))
     fcntl.ioctl(f, TUNSETIFF, ifr)
     # Give ourselves a predictable MAC address.
-    macaddr = self._MyMacAddress(netid)
     net_test.SetInterfaceHWAddr(iface, self._MyMacAddress(netid))
     # Disable DAD so we don't have to wait for it.
     open("/proc/sys/net/ipv6/conf/%s/dad_transmits" % iface, "w").write("0")
@@ -244,13 +290,36 @@
     if self.AUTOCONF_TABLE_OFFSET >= 0:
       return self.ifindices[netid] + self.AUTOCONF_TABLE_OFFSET
     else:
-      return netid      
+      return netid
+
+  @classmethod
+  def _ICMPRatelimitFilename(self, version):
+    return "/proc/sys/net/" + {4: "ipv4/icmp_ratelimit",
+                               6: "ipv6/icmp/ratelimit"}[version]
+
+  @classmethod
+  def _GetICMPRatelimit(self, version):
+    return int(open(self._ICMPRatelimitFilename(version), "r").read().strip())
+
+  @classmethod
+  def _SetICMPRatelimit(self, version, limit):
+    return open(self._ICMPRatelimitFilename(version), "w").write("%d" % limit)
 
   @classmethod
   def setUpClass(self):
+    # This is per-class setup instead of per-testcase setup because shelling out
+    # to ip and iptables is slow, and because routing configuration doesn't
+    # change during the test.
     self.tuns = {}
     self.ifindices = {}
     self._SetAutoconfTableSysctl(1000)
+
+    # Disable ICMP rate limits.
+    self.ratelimits = {}
+    for version in [4, 6]:
+      self.ratelimits[version] = self._GetICMPRatelimit(version)
+      self._SetICMPRatelimit(version, 0)
+
     for netid in self.NETIDS:
       self.tuns[netid] = self._CreateTunInterface(netid)
 
@@ -268,6 +337,7 @@
     # combination is tried.
     self.listenport = 1234
     self.listensocket = net_test.IPv6TCPSocket()
+    self.listensocket.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
     self.listensocket.bind(("::", self.listenport))
     self.listensocket.listen(100)
 
@@ -276,77 +346,108 @@
     # Uncomment to look around at interface and rule configuration while
     # running in the background. (Once the test finishes running, all the
     # interfaces and rules are gone.)
-    #time.sleep(30)
+    # time.sleep(30)
 
   @classmethod
   def tearDownClass(self):
     for netid in self.tuns:
       self._RunSetupCommands(netid, False)
       self.tuns[netid].close()
+    self._SetAutoconfTableSysctl(-1)
+    for version in [4, 6]:
+      self._SetICMPRatelimit(version, self.ratelimits[version])
 
-  def CheckExpectedPacket(self, expected, actual, msg):
-      # Remove the Ethernet header from the incoming packet.
-      actual = scapy.Ether(actual).payload
+  def assertPacketMatches(self, expected, actual):
+    # Remove the Ethernet header from the incoming packet.
+    actual = scapy.Ether(actual).payload
 
-      # Blank out IPv4 fields that we can't predict, like ID and the DF bit.
-      actualip = actual.getlayer("IP")
-      expectedip = expected.getlayer("IP")
-      if actualip and expectedip:
-        actualip.id = expectedip.id
-        actualip.flags &= 5
-        actualip.chksum = None  # Change the header, recalculate the checksum.
+    # Blank out IPv4 fields that we can't predict, like ID and the DF bit.
+    actualip = actual.getlayer("IP")
+    expectedip = expected.getlayer("IP")
+    if actualip and expectedip:
+      actualip.id = expectedip.id
+      actualip.flags &= 5
+      actualip.chksum = None  # Change the header, recalculate the checksum.
 
-      # Blank out TCP fields that we can't predict.
-      actualtcp = actual.getlayer("TCP")
-      expectedtcp = expected.getlayer("TCP")
-      if actualtcp and expectedtcp:
-        actualtcp.dataofs = expectedtcp.dataofs
-        actualtcp.options = expectedtcp.options
-        actualtcp.window = expectedtcp.window
-        if expectedtcp.seq is None:
-          actualtcp.seq = None
-        if expectedtcp.ack is None:
-          actualtcp.ack = None
-        actualtcp.chksum = None
+    # Blank out UDP fields that we can't predict (e.g., the source port for
+    # kernel-originated packets).
+    actualudp = actual.getlayer("UDP")
+    expectedudp = expected.getlayer("UDP")
+    if actualudp and expectedudp:
+      if expectedudp.sport is None:
+        actualudp.sport = None
+        actualudp.chksum = None
 
-      # Serialize the packet so:
-      # - Expected packet fields that are only set when a packet is serialized
-      #   (e.g., the checksum) are filled in.
-      # - The packet is readable. Scapy has detailed dissection capabilities,
-      #   but they only seem to be usable to print the packet, not return its
-      #   dissection as a string.
-      #   TODO: Check if this is true.
-      self.assertMultiLineEqual(str(expected).encode("hex"),
-                                str(actual).encode("hex"))
-    
-  def assertNoPacketsOn(self, netids, msg):
-    for netid in netids:
-      try:
-        self.assertRaisesErrno(errno.EAGAIN, self.tuns[netid].read, 4096)
-      except AssertionError, e:
-        raise UnexpectedPacketError("%s: Unexpected packet on %s" % (
-            msg, self._GetInterfaceName(netid)))
+    # Since the TCP code below messes with options, recalculate the length.
+    if actualip:
+      actualip.len = None
+    actualipv6 = actual.getlayer("IPv6")
+    if actualipv6:
+      actualipv6.plen = None
 
-  def assertNoOtherPackets(self, msg):
-    self.assertNoPacketsOn([netid for netid in self.tuns], msg)
+    # Blank out TCP fields that we can't predict.
+    actualtcp = actual.getlayer("TCP")
+    expectedtcp = expected.getlayer("TCP")
+    if actualtcp and expectedtcp:
+      actualtcp.dataofs = expectedtcp.dataofs
+      actualtcp.options = expectedtcp.options
+      actualtcp.window = expectedtcp.window
+      if expectedtcp.sport is None:
+        actualtcp.sport = None
+      if expectedtcp.seq is None:
+        actualtcp.seq = None
+      if expectedtcp.ack is None:
+        actualtcp.ack = None
+      actualtcp.chksum = None
 
-  def assertNoPacketsExceptOn(self, netid, msg):
-    self.assertNoPacketsOn([n for n in self.tuns if n != netid], msg)
+    # Serialize the packet so:
+    # - Expected packet fields that are only set when a packet is serialized
+    #   (e.g., the checksum) are filled in.
+    # - The packet is vaguely human-readable. Scapy has sophisticated packet
+    #   dissection capabilities, but unfortunately they can only be used to
+    #   print the packet, not to return its dissection as as string.
+    self.assertMultiLineEqual(str(expected).encode("hex"),
+                              str(actual).encode("hex"))
 
-  def ExpectPacketOn(self, netid, msg, expected=None):
-    # Check no packets were sent on any other netid.
-    self.assertNoPacketsExceptOn(netid, msg)
-
-    # Check that a packet was sent on netid.
+  def PacketMatches(self, expected, actual):
     try:
-      actual = self.tuns[netid].read(4096)
-    except IOError, e:
-      raise AssertionError(msg + ": " + str(e))
-    self.assertTrue(actual)
+      self.assertPacketMatches(expected, actual)
+      return True
+    except AssertionError:
+      return False
 
-    # If we know what sort of packet we expect, check that here.
-    if expected:
-      self.CheckExpectedPacket(expected, actual, msg)
+  def ReadAllPacketsOn(self, netid):
+    packets = []
+    while True:
+      try:
+        packets.append(posix.read(self.tuns[netid].fileno(), 4096))
+      except OSError, e:
+        # EAGAIN means there are no more packets waiting.
+        if re.match(e.message, os.strerror(errno.EAGAIN)):
+          break
+        # Anything else is unexpected.
+        else:
+          raise e
+    return packets
+
+  def ExpectPacketOn(self, netid, msg, expected):
+    packets = self.ReadAllPacketsOn(netid)
+    self.assertTrue(packets, msg + ": received no packets")
+
+    # If we receive a packet that matches what we expected, return it.
+    for packet in packets:
+      if self.PacketMatches(expected, packet):
+        return scapy.Ether(packet).payload
+
+    # None of the packets matched. Call assertPacketMatches to output a diff
+    # between the expected packet and the last packet we received. In theory,
+    # we'd output a diff to the packet that's the best match for what we
+    # expected, but this is good enough for now.
+    try:
+      self.assertPacketMatches(expected, packets[-1])
+    except Exception, e:
+      raise UnexpectedPacketError(
+          "%s: diff with last packet:\n%s" % (msg, e.message))
 
   def ReceivePacketOn(self, netid, ip_packet):
     routermac = self._RouterMacAddress(netid)
@@ -355,12 +456,10 @@
     posix.write(self.tuns[netid].fileno(), str(packet))
 
   def ClearTunQueues(self):
-    for f in self.tuns.values():
-      try:
-        f.read(4096)
-      except IOError:
-        continue
-    self.assertNoOtherPackets("Unexpected packets after clearing queues")
+    # Keep reading packets on all netids until we get no packets on any of them.
+    waiting = None
+    while waiting != 0:
+      waiting = sum(len(self.ReadAllPacketsOn(netid)) for netid in self.NETIDS)
 
   def setUp(self):
     self.ClearTunQueues()
@@ -369,33 +468,53 @@
   def _GetRemoteAddress(version):
     return {4: net_test.IPV4_ADDR, 6: net_test.IPV6_ADDR}[version]
 
-  def MarkSocket(self, s, netid):
+  def SetSocketMark(self, s, netid):
     s.setsockopt(SOL_SOCKET, net_test.SO_MARK, netid)
 
+  def GetSocketMark(self, s):
+    return s.getsockopt(SOL_SOCKET, net_test.SO_MARK)
+
   def GetProtocolFamily(self, version):
     return {4: AF_INET, 6: AF_INET6}[version]
 
   def testOutgoingPackets(self):
     """Checks that socket marking selects the right outgoing interface."""
 
-    def CheckPingPacket(version, netid, packet):
+    def CheckPingPacket(version, netid, dstaddr, packet):
       s = net_test.PingSocket(self.GetProtocolFamily(version))
-      dstaddr = self._GetRemoteAddress(version)
-      self.MarkSocket(s, netid)
-      s.sendto(packet, (dstaddr, 19321))
-      self.ExpectPacketOn(netid, "IPv%d ping: mark %d" % (version, netid))
+      myaddr = self._MyAddress(version, netid)
+      s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
+      s.bind((myaddr, PING_IDENT))
+      self.SetSocketMark(s, netid)
+      net_test.SetSocketTos(s, PING_TOS)
+
+      desc, expected = Packets.ICMPEcho(version, myaddr, dstaddr)
+
+      self.ClearTunQueues()
+      s.sendto(packet + PING_PAYLOAD, (dstaddr, 19321))
+      msg = "IPv%d ping: expected %s on %s" % (
+          version, desc, self._GetInterfaceName(netid))
+      self.ExpectPacketOn(netid, msg, expected)
 
     for netid in self.tuns:
-      CheckPingPacket(4, netid, net_test.IPV4_PING)
-      CheckPingPacket(6, netid, net_test.IPV6_PING)
+      CheckPingPacket(4, netid, net_test.IPV4_ADDR, net_test.IPV4_PING)
+      CheckPingPacket(6, netid, net_test.IPV6_ADDR, net_test.IPV6_PING)
 
     def CheckTCPSYNPacket(version, netid, dstaddr):
       s = net_test.TCPSocket(self.GetProtocolFamily(version))
-      self.MarkSocket(s, netid)
+      self.SetSocketMark(s, netid)
+      if version == 6 and dstaddr.startswith("::ffff"):
+        version = 4
+      myaddr = self._MyAddress(version, netid)
+      desc, expected = Packets.SYN(53, version, myaddr, dstaddr,
+                                   sport=None, seq=None)
+
+      self.ClearTunQueues()
       # Non-blocking TCP connects always return EINPROGRESS.
       self.assertRaisesErrno(errno.EINPROGRESS, s.connect, (dstaddr, 53))
-      self.ExpectPacketOn(netid, "IPv%d TCP connect: mark %d" % (version,
-                                                                 netid))
+      msg = "IPv%s TCP connect: expected %s on %s" % (
+          version, desc, self._GetInterfaceName(netid))
+      self.ExpectPacketOn(netid, msg, expected)
       s.close()
 
     for netid in self.tuns:
@@ -405,13 +524,22 @@
 
     def CheckUDPPacket(version, netid, dstaddr):
       s = net_test.UDPSocket(self.GetProtocolFamily(version))
-      self.MarkSocket(s, netid)
-      s.sendto("hello", (dstaddr, 53))
-      self.ExpectPacketOn(netid, "IPv%d UDP sendto: mark %d" % (version, netid))
+      self.SetSocketMark(s, netid)
+      if version == 6 and dstaddr.startswith("::ffff"):
+        version = 4
+      myaddr = self._MyAddress(version, netid)
+      desc, expected = Packets.UDP(version, myaddr, dstaddr, sport=None)
+      msg = "IPv%s UDP %%s: expected %s on %s" % (
+          version, desc, self._GetInterfaceName(netid))
+
+      self.ClearTunQueues()
+      s.sendto(UDP_PAYLOAD, (dstaddr, 53))
+      self.ExpectPacketOn(netid, msg % "sendto", expected)
+
+      self.ClearTunQueues()
       s.connect((dstaddr, 53))
-      s.send("hello")
-      self.ExpectPacketOn(netid, "IPv%d UDP connect/send: mark %d" % (version,
-                                                                      netid))
+      s.send(UDP_PAYLOAD)
+      self.ExpectPacketOn(netid, msg % "connect/send", expected)
       s.close()
 
     for netid in self.tuns:
@@ -419,38 +547,53 @@
       CheckUDPPacket(6, netid, net_test.IPV6_ADDR)
       CheckUDPPacket(6, netid, "::ffff:" + net_test.IPV4_ADDR)
 
-  def CheckReflection(self, version, packet_generator, reply_generator):
-    """Checks that replies go out on the same interface as the original."""
+  def CheckReflection(self, version, packet_generator, reply_generator,
+                      callback=None):
+    """Checks that replies go out on the same interface as the original.
 
+    Iterates through all the combinations of the interfaces in self.tuns and the
+    IP addresses assigned to them. For each combination:
+     - Calls packet_generator to generate a packet to that IP address.
+     - Writes the packet generated by packet_generator on the given tun
+       interface, causing the kernel to receive it.
+     - Checks that the kernel's reply matches the packet generated by
+       reply_generator.
+     - Calls the given callback function.
+
+    Args:
+      version: An integer, 4 or 6.
+      packet_generator: A function taking an IP version (an integer), a source
+        address and a destination address (strings), and returning a scapy
+        packet.
+      reply_generator: A function taking the same arguments as packet_generator,
+        plus a scapy packet, and returning a scapy packet.
+      callback: A function to call to perform extra checks if the packet
+        matches. Takes netid, version, local address, remote address, original
+        packet, kernel reply, and a message.
+    """
     # Check packets addressed to the IP addresses of all our interfaces...
     for dest_ip_netid in self.tuns:
       dest_ip_iface = self._GetInterfaceName(dest_ip_netid)
 
-      if version == 4:
-        myaddr = self._MyIPv4Address(dest_ip_netid)
-      else:
-        myaddr = net_test.GetLinkAddress(self._GetInterfaceName(dest_ip_netid),
-                                                                False)
+      myaddr = self._MyAddress(version, dest_ip_netid)
       remote_addr = self._GetRemoteAddress(version)
 
       # ... coming in on all our interfaces.
       for iif_netid in self.tuns:
         iif = self._GetInterfaceName(iif_netid)
         desc, packet = packet_generator(version, remote_addr, myaddr)
-        if reply_generator:
-          # We know what we want a reply to.
-          reply_desc, reply = reply_generator(version, myaddr, remote_addr,
-                                              packet)
-        else:
-          # Expect any reply.
-          reply_desc, reply = "any packet", None
+        reply_desc, reply = reply_generator(version, myaddr, remote_addr,
+                                            packet)
         msg = "Receiving %s on %s to %s IP: Expecting %s on %s" % (
             desc, iif, dest_ip_iface, reply_desc, iif)
 
-        # Expect a reply on the interface the original packet came in on.
         self.ClearTunQueues()
+        # Cause the kernel to receive packet on iif_netid.
         self.ReceivePacketOn(iif_netid, packet)
-        self.ExpectPacketOn(iif_netid, msg, reply)
+        # Expect the kernel to send out reply on the same interface.
+        reply = self.ExpectPacketOn(iif_netid, msg, reply)
+        if callback:
+          callback(iif_netid, version, myaddr, remote_addr, packet, reply, msg)
 
   def SYNToClosedPort(self, *args):
     return Packets.SYN(999, *args)
@@ -459,10 +602,10 @@
     return Packets.SYN(self.listenport, *args)
 
   def testIPv4ICMPErrorsReflectMark(self):
-    self.CheckReflection(4, Packets.UdpPacket, Packets.ICMPPortUnreachable)
+    self.CheckReflection(4, Packets.UDP, Packets.ICMPPortUnreachable)
 
   def testIPv6ICMPErrorsReflectMark(self):
-    self.CheckReflection(6, Packets.UdpPacket, Packets.ICMPPortUnreachable)
+    self.CheckReflection(6, Packets.UDP, Packets.ICMPPortUnreachable)
 
   def testIPv4PingRepliesReflectMarkAndTos(self):
     self.CheckReflection(4, Packets.ICMPEcho, Packets.ICMPReply)
@@ -476,13 +619,39 @@
   def testIPv6RSTsReflectMark(self):
     self.CheckReflection(6, self.SYNToClosedPort, Packets.RST)
 
-  @unittest.skipUnless(False, "skipping: doesn't work yet")
-  def testIPv4SYNACKsReflectMark(self):
-    self.CheckReflection(4, Packets.SYNToOpenPort, Packets.SYNACK)
+  def CheckAcceptedSocketMarkCallback(self, netid, version, myaddr,
+                                      remote_addr, packet, reply, msg):
+    establishing_ack = Packets.ACK(version, remote_addr, myaddr, reply)[1]
+    self.ReceivePacketOn(netid, establishing_ack)
+    s, unused_peer = self.listensocket.accept()
+    try:
+      mark = self.GetSocketMark(s)
+    finally:
+      s.close()
+    self.assertEquals(netid, mark,
+                      msg + ": Accepted socket: Expected mark %d, got %d" % (
+                          netid, mark))
 
-  @unittest.skipUnless(False, "skipping: doesn't work yet")
+  def testIPv4SYNACKsReflectMark(self):
+    self.CheckReflection(4, self.SYNToOpenPort, Packets.SYNACK,
+                         self.CheckAcceptedSocketMarkCallback)
+
   def testIPv6SYNACKsReflectMark(self):
-    self.CheckReflection(6, Packets.SYNToOpenPort, Packets.SYNACK)
+    self.CheckReflection(6, self.SYNToOpenPort, Packets.SYNACK,
+                         self.CheckAcceptedSocketMarkCallback)
+
+  def testSynCookiesSYNACKsReflectMark(self):
+    # Force SYN cookies on all connections.
+    open("/proc/sys/net/ipv4/tcp_syncookies", "w").write("2")
+    try:
+      self.CheckReflection(4, self.SYNToOpenPort, Packets.SYNACK,
+                           self.CheckAcceptedSocketMarkCallback)
+      self.CheckReflection(6, self.SYNToOpenPort, Packets.SYNACK,
+                           self.CheckAcceptedSocketMarkCallback)
+    finally:
+      # Stop forcing SYN cookies on all connections.
+      open("/proc/sys/net/ipv4/tcp_syncookies", "w").write("1")
+
 
 
 if __name__ == "__main__":
diff --git a/net/test/net_test.py b/net/test/net_test.py
index caf79a0..bf67785 100755
--- a/net/test/net_test.py
+++ b/net/test/net_test.py
@@ -17,10 +17,11 @@
 from scapy import all as scapy
 
 SOL_IPV6 = 41
-IP_TRANSPARENT = 19
-IPV6_TRANSPARENT = 75
 IP_RECVERR = 11
 IPV6_RECVERR = 25
+IP_TRANSPARENT = 19
+IPV6_TRANSPARENT = 75
+IPV6_TCLASS = 67
 SO_BINDTODEVICE = 25
 SO_MARK = 36
 IPV6_FLOWLABEL_MGR = 32
@@ -60,6 +61,11 @@
   us = (ms % 1000) * 1000
   sock.setsockopt(SOL_SOCKET, SO_RCVTIMEO, struct.pack("LL", s, us))
 
+def SetSocketTos(s, tos):
+  level = {AF_INET: SOL_IP, AF_INET6: SOL_IPV6}[s.family]
+  option = {AF_INET: IP_TOS, AF_INET6: IPV6_TCLASS}[s.family]
+  s.setsockopt(level, option, tos)
+
 def SetNonBlocking(fd):
   flags = fcntl.fcntl(fd, fcntl.F_GETFL, 0)
   fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
diff --git a/net/test/run_net_test.sh b/net/test/run_net_test.sh
index adaa354..306ebe4 100755
--- a/net/test/run_net_test.sh
+++ b/net/test/run_net_test.sh
@@ -2,7 +2,7 @@
 
 # Kernel configration options.
 OPTIONS=" IPV6 IPV6_ROUTER_PREF IPV6_MULTIPLE_TABLES IPV6_ROUTE_INFO"
-OPTIONS="$OPTIONS TUN IP_ADVANCED_ROUTER IP_MULTIPLE_TABLES"
+OPTIONS="$OPTIONS TUN SYN_COOKIES IP_ADVANCED_ROUTER IP_MULTIPLE_TABLES"
 OPTIONS="$OPTIONS NETFILTER NETFILTER_ADVANCED NETFILTER_XTABLES"
 OPTIONS="$OPTIONS NETFILTER_XT_MARK NETFILTER_XT_TARGET_MARK"
 OPTIONS="$OPTIONS IP_NF_IPTABLES IP_NF_MANGLE"