Separate the reflect and accept tests.

This will allow testing accepting connections in various routing
modes (uid, SO_BINDTODEVICE, mark, etc.). Currently we only test
mark and SO_BINDTODEVICE.

Change-Id: Ic93ae839e56ae70ecf2c878661ab4f60a8a2440e
diff --git a/net/test/mark_test.py b/net/test/mark_test.py
index b14d48b..e6751c5 100755
--- a/net/test/mark_test.py
+++ b/net/test/mark_test.py
@@ -623,7 +623,7 @@
   def ExpectNoPacketsOn(self, netid, msg):
     packets = self.ReadAllPacketsOn(netid)
     if packets:
-      firstpacket = str(packets[0]).encode("hex")
+      firstpacket = repr(packets[0])
     else:
       firstpacket = ""
     self.assertFalse(packets, msg + ": unexpected packet: " + firstpacket)
@@ -894,96 +894,76 @@
     self.CheckRemarking(6, False)
     self.CheckRemarking(6, True)
 
-  def CheckReflection(self, version, packet_generator, reply_generator,
-                      mark_behaviour, callback=None):
-    """Checks that replies go out on the same interface as the original.
-
-    Iterates through all the combinations of the interfaces in self.tuns and the
-    IP addresses assigned to them. For each combination:
-     - Calls packet_generator to generate a packet to that IP address.
-     - Writes the packet generated by packet_generator on the given tun
-       interface, causing the kernel to receive it.
-     - Checks that the kernel's reply matches the packet generated by
-       reply_generator.
-     - Calls the given callback function.
-
-    Args:
-      version: An integer, 4 or 6.
-      packet_generator: A function taking an IP version (an integer), a source
-        address and a destination address (strings), and returning a scapy
-        packet.
-      reply_generator: A function taking the same arguments as packet_generator,
-        plus a scapy packet, and returning a scapy packet.
-      mark_behaviour: A string describing the mark behaviour to test. Tests are
-        performed with the corresponding sysctl set to both 0 and 1.
-      callback: A function to call to perform extra checks if the packet
-        matches. Takes netid, version, local address, remote address, original
-        packet, kernel reply, and a message.
-    """
-    # What are we testing?
-    sysctl_function = {"accept": self._SetTCPMarkAcceptSysctl,
-                       "reflect": self._SetMarkReflectSysctls}[mark_behaviour]
+  def Combinations(self, version):
+    """Produces a list of combinations to test."""
+    combinations = []
 
     # Check packets addressed to the IP addresses of all our interfaces...
     for dest_ip_netid in self.tuns:
-      dest_ip_iface = self.GetInterfaceName(dest_ip_netid)
-
+      ip_iface = self.GetInterfaceName(dest_ip_netid)
       myaddr = self.MyAddress(version, dest_ip_netid)
       remote_addr = self._GetRemoteAddress(version)
 
-      # ... coming in on all our interfaces...
-      for iif_netid in self.tuns:
-        iif = self.GetInterfaceName(iif_netid)
+      # ... coming in on all our interfaces.
+      for netid in self.tuns:
+        iif = self.GetInterfaceName(netid)
+        combinations.append((netid, iif, ip_iface, myaddr, remote_addr))
 
-        # ... with inbound mark sysctl enabled and disabled.
-        for sysctl_value in [0, 1]:
+    return combinations
 
-          # If we're testing accepting TCP connections, also check that
-          # SO_BINDTODEVICE correctly sets the interface the SYN+ACK is sent on.
-          # Since SO_BINDTODEVICE and the sysctl do the same thing, it doesn't
-          # really make sense to test with sysctl_value=1 and SO_BINDTODEVICE
-          # turned on at the same time.
-          if mark_behaviour == "accept" and not sysctl_value:
-            bind_devices = [None, iif]
-          else:
-            bind_devices = [None]
+  def _FormatMessage(self, iif, ip_iface, extra, desc, reply_desc):
+    msg = "Receiving %s on %s to %s IP, %s" % (desc, iif, ip_iface, extra)
+    if reply_desc:
+      msg += ": Expecting %s on %s" % (reply_desc, iif)
+    else:
+      msg += ": Expecting no packets on %s" % iif
+    return msg
 
-          for bound_dev in bind_devices:
-            # The socket is unbound in tearDown.
-            self.BindToDevice(self.listensocket, bound_dev)
+  def _ReceiveAndExpectResponse(self, netid, packet, reply, msg):
+    self.ReceivePacketOn(netid, packet)
+    if reply:
+      return self.ExpectPacketOn(netid, msg, reply)
+    else:
+      self.ExpectNoPacketsOn(netid, msg)
+      return None
 
-            # Generate the packet here instead of in the outer loop, so
-            # subsequent TCP connections use different source ports and
-            # retransmissions from old connections don't confuse subsequent
-            # tests.
-            desc, packet = packet_generator(version, remote_addr, myaddr)
-            reply_desc, reply = reply_generator(version, myaddr, remote_addr,
-                                                packet)
+  def CheckReflection(self, version, gen_packet, gen_reply):
+    """Checks that replies go out on the same interface as the original.
 
-            msg = "Receiving %s on %s to %s IP, %s=%d, bound_dev=%s" % (
-                desc, iif, dest_ip_iface, mark_behaviour, sysctl_value,
-                bound_dev)
-            sysctl_function(sysctl_value)
+    For each combination:
+     - Calls gen_packet to generate a packet to that IP address.
+     - Writes the packet generated by gen_packet on the given tun
+       interface, causing the kernel to receive it.
+     - Checks that the kernel's reply matches the packet generated by
+       gen_reply.
 
-            # Cause the kernel to receive packet on iif_netid.
-            self.ReceivePacketOn(iif_netid, packet)
+    Args:
+      version: An integer, 4 or 6.
+      gen_packet: A function taking an IP version (an integer), a source
+        address and a destination address (strings), and returning a scapy
+        packet.
+      gen_reply: A function taking the same arguments as gen_packet,
+        plus a scapy packet, and returning a scapy packet.
+    """
+    for netid, iif, ip_iface, myaddr, remote_addr in self.Combinations(version):
+      # Generate a test packet.
+      desc, packet = gen_packet(version, remote_addr, myaddr)
 
-            # Expect the kernel to send out reply on the same interface.
-            #
-            # HACK: IPv6 ping replies always do a routing lookup with the
-            # interface the ping came in on. So even if mark reflection is not
-            # working, IPv6 ping replies will be properly reflected. Don't
-            # fail when that happens.
-            if bound_dev or sysctl_value or reply_desc == "ICMPv6 echo reply":
-              msg += ": Expecting %s on %s" % (reply_desc, iif)
-              reply = self.ExpectPacketOn(iif_netid, msg, reply)
-              # If a callback was set, call it.
-              if callback:
-                callback(sysctl_value, iif_netid, version, myaddr, remote_addr,
-                         packet, reply, msg)
-            else:
-              msg += ": Expecting no packets on %s" % iif
-              self.ExpectNoPacketsOn(iif_netid, msg)
+      # Test with mark reflection enabled and disabled.
+      for reflect in [0, 1]:
+        self._SetMarkReflectSysctls(reflect)
+        # HACK: IPv6 ping replies always do a routing lookup with the
+        # interface the ping came in on. So even if mark reflection is not
+        # working, IPv6 ping replies will be properly reflected. Don't
+        # fail when that happens.
+        if reflect or desc == "ICMPv6 echo":
+          reply_desc, reply = gen_reply(version, myaddr, remote_addr, packet)
+        else:
+          reply_desc, reply = None, None
+
+        msg = self._FormatMessage(iif, ip_iface, "reflect=%d" % reflect,
+                                  desc, reply_desc)
+        self._ReceiveAndExpectResponse(netid, packet, reply, msg)
 
   def SYNToClosedPort(self, *args):
     return Packets.SYN(999, *args)
@@ -993,33 +973,33 @@
 
   @unittest.skipUnless(HAVE_MARK_REFLECT, "no mark reflection")
   def testIPv4ICMPErrorsReflectMark(self):
-    self.CheckReflection(4, Packets.UDP, Packets.ICMPPortUnreachable, "reflect")
+    self.CheckReflection(4, Packets.UDP, Packets.ICMPPortUnreachable)
 
   @unittest.skipUnless(HAVE_MARK_REFLECT, "no mark reflection")
   def testIPv6ICMPErrorsReflectMark(self):
-    self.CheckReflection(6, Packets.UDP, Packets.ICMPPortUnreachable, "reflect")
+    self.CheckReflection(6, Packets.UDP, Packets.ICMPPortUnreachable)
 
   @unittest.skipUnless(HAVE_MARK_REFLECT, "no mark reflection")
   def testIPv4PingRepliesReflectMarkAndTos(self):
-    self.CheckReflection(4, Packets.ICMPEcho, Packets.ICMPReply, "reflect")
+    self.CheckReflection(4, Packets.ICMPEcho, Packets.ICMPReply)
 
   @unittest.skipUnless(HAVE_MARK_REFLECT, "no mark reflection")
   def testIPv6PingRepliesReflectMarkAndTos(self):
-    self.CheckReflection(6, Packets.ICMPEcho, Packets.ICMPReply, "reflect")
+    self.CheckReflection(6, Packets.ICMPEcho, Packets.ICMPReply)
 
   @unittest.skipUnless(HAVE_MARK_REFLECT, "no mark reflection")
   def testIPv4RSTsReflectMark(self):
-    self.CheckReflection(4, self.SYNToClosedPort, Packets.RST, "reflect")
+    self.CheckReflection(4, self.SYNToClosedPort, Packets.RST)
 
   @unittest.skipUnless(HAVE_MARK_REFLECT, "no mark reflection")
   def testIPv6RSTsReflectMark(self):
-    self.CheckReflection(6, self.SYNToClosedPort, Packets.RST, "reflect")
+    self.CheckReflection(6, self.SYNToClosedPort, Packets.RST)
 
   def CheckTCPConnection(self, sysctl_value, netid, version,
                          myaddr, remote_addr, packet, reply, msg):
     establishing_ack = Packets.ACK(version, remote_addr, myaddr, reply)[1]
     self.ReceivePacketOn(netid, establishing_ack)
-    s, unused_peer = self.listensocket.accept()
+    s, _ = self.listensocket.accept()
     try:
       mark = self.GetSocketMark(s)
     finally:
@@ -1043,25 +1023,73 @@
     desc, finackack = Packets.ACK(version, myaddr, remote_addr, finack)
     self.ExpectPacketOn(netid, msg + ": expecting final ack", finackack)
 
+  def CheckTCP(self, version, gen_packet, gen_reply):
+    """Checks that incoming TCP connections work.
+
+    Args:
+      version: An integer, 4 or 6.
+      gen_packet: A function taking an IP version (an integer), a source
+        address and a destination address (strings), and returning a scapy
+        packet.
+      gen_reply: A function taking the same arguments as gen_packet,
+        plus a scapy packet, and returning a scapy packet.
+        packet, kernel reply, and a message.
+    """
+    for netid, iif, ip_iface, myaddr, remote_addr in self.Combinations(version):
+      desc, packet = gen_packet(version, remote_addr, myaddr)
+
+      for sysctl_value in [0, 1]:
+        self._SetTCPMarkAcceptSysctl(sysctl_value)
+
+        # If we're testing accepting TCP connections, also check that
+        # SO_BINDTODEVICE correctly sets the interface the SYN+ACK is sent on.
+        # Since SO_BINDTODEVICE and the sysctl do the same thing, it doesn't
+        # really make sense to test with sysctl_value=1 and SO_BINDTODEVICE
+        # turned on at the same time.
+        if not sysctl_value:
+          bind_devices = [None, iif]
+        else:
+          bind_devices = [None]
+
+        for bound_dev in bind_devices:
+          # The socket is unbound in tearDown.
+          self.BindToDevice(self.listensocket, bound_dev)
+
+          # Generate the packet here instead of in the outer loop, so
+          # subsequent TCP connections use different source ports and
+          # retransmissions from old connections don't confuse subsequent
+          # tests.
+          desc, packet = gen_packet(version, remote_addr, myaddr)
+
+          if bound_dev or sysctl_value:
+            reply_desc, reply = gen_reply(version, myaddr, remote_addr, packet)
+          else:
+            reply_desc, reply = None, None
+
+          extra = "accept=%d, bound_dev=%s" % (sysctl_value, bound_dev)
+          msg = self._FormatMessage(iif, ip_iface, extra, desc, reply_desc)
+          reply = self._ReceiveAndExpectResponse(netid, packet, reply, msg)
+
+          if reply:
+            self.CheckTCPConnection(sysctl_value, netid, version,
+                                    myaddr, remote_addr,
+                                    packet, reply, msg)
+
   @unittest.skipUnless(HAVE_TCP_MARK_ACCEPT, "fwmark writeback not supported")
   def testIPv4TCPConnections(self):
-    self.CheckReflection(4, self.SYNToOpenPort, Packets.SYNACK, "accept",
-                         self.CheckTCPConnection)
+    self.CheckTCP(4, self.SYNToOpenPort, Packets.SYNACK)
 
   @unittest.skipUnless(HAVE_TCP_MARK_ACCEPT, "fwmark writeback not supported")
   def testIPv6TCPConnections(self):
-    self.CheckReflection(6, self.SYNToOpenPort, Packets.SYNACK, "accept",
-                         self.CheckTCPConnection)
+    self.CheckTCP(6, self.SYNToOpenPort, Packets.SYNACK)
 
   @unittest.skipUnless(HAVE_TCP_MARK_ACCEPT, "fwmark writeback not supported")
   def testTCPConnectionsWithSynCookies(self):
     # Force SYN cookies on all connections.
     self.SetSysctl(SYNCOOKIES_SYSCTL, 2)
     try:
-      self.CheckReflection(4, self.SYNToOpenPort, Packets.SYNACK, "accept",
-                           self.CheckTCPConnection)
-      self.CheckReflection(6, self.SYNToOpenPort, Packets.SYNACK, "accept",
-                           self.CheckTCPConnection)
+      self.CheckTCP(4, self.SYNToOpenPort, Packets.SYNACK)
+      self.CheckTCP(6, self.SYNToOpenPort, Packets.SYNACK)
     finally:
       # Stop forcing SYN cookies on all connections.
       self.SetSysctl(SYNCOOKIES_SYSCTL, 1)