Refactor setup and common code into a superclass.

Change-Id: Iee489954175de6eec12b711d6c3ebb9a64cfd6c3
diff --git a/net/test/mark_test.py b/net/test/mark_test.py
index 48c9e98..a5cd120 100755
--- a/net/test/mark_test.py
+++ b/net/test/mark_test.py
@@ -213,10 +213,7 @@
       os.seteuid(self.saved_uid)
 
 
-class MarkTest(net_test.NetworkTest):
-
-  # How many times to run packet reflection tests.
-  ITERATIONS = 5
+class MultiNetworkTest(net_test.NetworkTest):
 
   # Must be between 1 and 256, since we put them in MAC addresses and IIDs.
   NETIDS = [100, 150, 200, 250]
@@ -224,18 +221,12 @@
   # Stores sysctl values to write back when the test completes.
   saved_sysctls = {}
 
-  # For convenience.
-  IPV4_ADDR = net_test.IPV4_ADDR
-  IPV6_ADDR = net_test.IPV6_ADDR
-  IPV4_PING = net_test.IPV4_PING
-  IPV6_PING = net_test.IPV6_PING
-
   # Wether to output setup commands.
   DEBUG = False
 
   @staticmethod
-  def _GetInterfaceName(netid):
-    return "nettest%d" % netid
+  def UidForNetid(netid):
+    return 2000 + netid
 
   @classmethod
   def _TableForNetid(cls, netid):
@@ -245,15 +236,15 @@
       return netid
 
   @staticmethod
-  def _UidForNetid(netid):
-    return 2000 + netid
+  def GetInterfaceName(netid):
+    return "nettest%d" % netid
 
   @staticmethod
-  def _RouterMacAddress(netid):
+  def RouterMacAddress(netid):
     return "02:00:00:00:%02x:00" % netid
 
   @staticmethod
-  def _MyMacAddress(netid):
+  def MyMacAddress(netid):
     return "02:00:00:00:%02x:01" % netid
 
   @staticmethod
@@ -271,22 +262,22 @@
 
   @classmethod
   def _MyIPv6Address(cls, netid):
-    return net_test.GetLinkAddress(cls._GetInterfaceName(netid), False)
+    return net_test.GetLinkAddress(cls.GetInterfaceName(netid), False)
 
   @classmethod
-  def _MyAddress(cls, version, netid):
+  def MyAddress(cls, version, netid):
     return {4: cls._MyIPv4Address(netid),
             6: cls._MyIPv6Address(netid)}[version]
 
   @classmethod
   def _CreateTunInterface(cls, netid):
-    iface = cls._GetInterfaceName(netid)
+    iface = cls.GetInterfaceName(netid)
     f = open("/dev/net/tun", "r+b")
     ifr = struct.pack("16sH", iface, IFF_TAP | IFF_NO_PI)
     ifr += "\x00" * (40 - len(ifr))
     fcntl.ioctl(f, TUNSETIFF, ifr)
     # Give ourselves a predictable MAC address.
-    net_test.SetInterfaceHWAddr(iface, cls._MyMacAddress(netid))
+    net_test.SetInterfaceHWAddr(iface, cls.MyMacAddress(netid))
     # Disable DAD so we don't have to wait for it.
     open("/proc/sys/net/ipv6/conf/%s/dad_transmits" % iface, "w").write("0")
     net_test.SetInterfaceUp(iface)
@@ -297,7 +288,7 @@
   def _SendRA(cls, netid):
     validity = 300                 # seconds
     validity_ms = validity * 1000  # milliseconds
-    macaddr = cls._RouterMacAddress(netid)
+    macaddr = cls.RouterMacAddress(netid)
     lladdr = cls._RouterAddress(netid, 6)
 
     # We don't want any routes in the main table. If the kernel doesn't support
@@ -335,7 +326,7 @@
     for version, iptables in zip([4, 6], ["iptables", "ip6tables"]):
 
       table = cls._TableForNetid(netid)
-      uid = cls._UidForNetid(netid)
+      uid = cls.UidForNetid(netid)
       if HAVE_EXPERIMENTAL_UID_ROUTING:
         cls.iproute.UidRule(version, is_add, uid, table, priority=100)
       cls.iproute.FwmarkRule(version, is_add, netid, table, priority=200)
@@ -358,10 +349,10 @@
       cmds = str("\n".join(cmds) % {
           "add_del": "add" if is_add else "del",
           "append_delete": "-A" if is_add else "-D",
-          "iface": cls._GetInterfaceName(netid),
+          "iface": cls.GetInterfaceName(netid),
           "iptables": iptables,
           "ipv4addr": cls._MyIPv4Address(netid),
-          "macaddr": cls._RouterMacAddress(netid),
+          "macaddr": cls.RouterMacAddress(netid),
           "mark": netid,
           "router": cls._RouterAddress(netid, version),
           "table": table,
@@ -375,16 +366,16 @@
           raise ConfigurationError("Setup command failed: %s" % " ".join(cmd))
 
   @classmethod
-  def _GetSysctl(cls, sysctl):
+  def GetSysctl(cls, sysctl):
     return open(sysctl, "r").read()
 
   @classmethod
-  def _SetSysctl(cls, sysctl, value):
+  def SetSysctl(cls, sysctl, value):
     # Only save each sysctl value the first time we set it. This is so we can
     # set it to arbitrary values multiple times and still write it back
     # correctly at the end.
     if sysctl not in cls.saved_sysctls:
-      cls.saved_sysctls[sysctl] = cls._GetSysctl(sysctl)
+      cls.saved_sysctls[sysctl] = cls.GetSysctl(sysctl)
     open(sysctl, "w").write(str(value) + "\n")
 
   @classmethod
@@ -399,21 +390,7 @@
 
   @classmethod
   def _SetICMPRatelimit(cls, version, limit):
-    cls._SetSysctl(cls._ICMPRatelimitFilename(version), limit)
-
-  @classmethod
-  def _SetMarkReflectSysctls(cls, value):
-    cls._SetSysctl(IPV4_MARK_REFLECT_SYSCTL, value)
-    try:
-      cls._SetSysctl(IPV6_MARK_REFLECT_SYSCTL, value)
-    except IOError:
-      # This does not exist if we use the version of the patch that uses a
-      # common sysctl for IPv4 and IPv6.
-      pass
-
-  @classmethod
-  def _SetTCPMarkAcceptSysctl(cls, value):
-    cls._SetSysctl(TCP_MARK_ACCEPT_SYSCTL, value)
+    cls.SetSysctl(cls._ICMPRatelimitFilename(version), limit)
 
   @classmethod
   def setUpClass(cls):
@@ -424,7 +401,7 @@
     cls.tuns = {}
     cls.ifindices = {}
     if HAVE_AUTOCONF_TABLE:
-      cls._SetSysctl(AUTOCONF_TABLE_SYSCTL, -1000)
+      cls.SetSysctl(AUTOCONF_TABLE_SYSCTL, -1000)
       cls.AUTOCONF_TABLE_OFFSET = -1000
     else:
       cls.AUTOCONF_TABLE_OFFSET = None
@@ -436,12 +413,58 @@
     for netid in cls.NETIDS:
       cls.tuns[netid] = cls._CreateTunInterface(netid)
 
-      iface = cls._GetInterfaceName(netid)
+      iface = cls.GetInterfaceName(netid)
       cls.ifindices[netid] = net_test.GetInterfaceIndex(iface)
 
       cls._SendRA(netid)
       cls._RunSetupCommands(netid, True)
 
+    # Uncomment to look around at interface and rule configuration while
+    # running in the background. (Once the test finishes running, all the
+    # interfaces and rules are gone.)
+    # time.sleep(30)
+
+  @classmethod
+  def tearDownClass(cls):
+    for netid in cls.tuns:
+      cls._RunSetupCommands(netid, False)
+      cls.tuns[netid].close()
+    cls._RestoreSysctls()
+
+  def SetSocketMark(self, s, netid):
+    s.setsockopt(SOL_SOCKET, net_test.SO_MARK, netid)
+
+  def GetSocketMark(self, s):
+    return s.getsockopt(SOL_SOCKET, net_test.SO_MARK)
+
+  def ReceivePacketOn(self, netid, ip_packet):
+    routermac = self.RouterMacAddress(netid)
+    mymac = self.MyMacAddress(netid)
+    packet = scapy.Ether(src=routermac, dst=mymac) / ip_packet
+    posix.write(self.tuns[netid].fileno(), str(packet))
+
+  def ClearTunQueues(self):
+    # Keep reading packets on all netids until we get no packets on any of them.
+    waiting = None
+    while waiting != 0:
+      waiting = sum(len(self.ReadAllPacketsOn(netid)) for netid in self.NETIDS)
+
+
+class MarkTest(MultiNetworkTest):
+
+  # How many times to run packet reflection tests.
+  ITERATIONS = 5
+
+  # For convenience.
+  IPV4_ADDR = net_test.IPV4_ADDR
+  IPV6_ADDR = net_test.IPV6_ADDR
+  IPV4_PING = net_test.IPV4_PING
+  IPV6_PING = net_test.IPV6_PING
+
+  @classmethod
+  def setUpClass(cls):
+    super(MarkTest, cls).setUpClass()
+
     # Open a port so we can observe SYN+ACKs. Since it's a dual-stack socket it
     # will accept both IPv4 and IPv6 connections. We do this here instead of in
     # each test so we can use the same socket every time. That way, if a kernel
@@ -454,17 +477,19 @@
     cls.listensocket.bind(("::", cls.listenport))
     cls.listensocket.listen(100)
 
-    # Uncomment to look around at interface and rule configuration while
-    # running in the background. (Once the test finishes running, all the
-    # interfaces and rules are gone.)
-    # time.sleep(30)
+  @classmethod
+  def _SetMarkReflectSysctls(cls, value):
+    cls.SetSysctl(IPV4_MARK_REFLECT_SYSCTL, value)
+    try:
+      cls.SetSysctl(IPV6_MARK_REFLECT_SYSCTL, value)
+    except IOError:
+      # This does not exist if we use the version of the patch that uses a
+      # common sysctl for IPv4 and IPv6.
+      pass
 
   @classmethod
-  def tearDownClass(cls):
-    for netid in cls.tuns:
-      cls._RunSetupCommands(netid, False)
-      cls.tuns[netid].close()
-    cls._RestoreSysctls()
+  def _SetTCPMarkAcceptSysctl(cls, value):
+    cls.SetSysctl(TCP_MARK_ACCEPT_SYSCTL, value)
 
   def assertPacketMatches(self, expected, actual):
     # The expected packet is just a rough sketch of the packet we expect to
@@ -578,37 +603,18 @@
       raise UnexpectedPacketError(
           "%s: diff with last packet:\n%s" % (msg, e.message))
 
-  def ReceivePacketOn(self, netid, ip_packet):
-    routermac = self._RouterMacAddress(netid)
-    mymac = self._MyMacAddress(netid)
-    packet = scapy.Ether(src=routermac, dst=mymac) / ip_packet
-    posix.write(self.tuns[netid].fileno(), str(packet))
-
-  def ClearTunQueues(self):
-    # Keep reading packets on all netids until we get no packets on any of them.
-    waiting = None
-    while waiting != 0:
-      waiting = sum(len(self.ReadAllPacketsOn(netid)) for netid in self.NETIDS)
-
   def setUp(self):
     self.ClearTunQueues()
 
-  @classmethod
-  def _GetRemoteAddress(cls, version):
-    return {4: cls.IPV4_ADDR, 6: cls.IPV6_ADDR}[version]
+  def _GetRemoteAddress(self, version):
+    return {4: self.IPV4_ADDR, 6: self.IPV6_ADDR}[version]
 
-  def SetSocketMark(self, s, netid):
-    s.setsockopt(SOL_SOCKET, net_test.SO_MARK, netid)
-
-  def GetSocketMark(self, s):
-    return s.getsockopt(SOL_SOCKET, net_test.SO_MARK)
-
-  def GetProtocolFamily(self, version):
+  def _GetProtocolFamily(self, version):
     return {4: AF_INET, 6: AF_INET6}[version]
 
   def BuildSocket(self, version, constructor, mark, uid):
     with RunAsUid(uid):
-      family = self.GetProtocolFamily(version)
+      family = self._GetProtocolFamily(version)
       s = constructor(family)
     if mark:
       self.SetSocketMark(s, mark)
@@ -618,7 +624,7 @@
                       expected_netid):
     s = self.BuildSocket(version, net_test.PingSocket, mark, uid)
 
-    myaddr = self._MyAddress(version, expected_netid)
+    myaddr = self.MyAddress(version, expected_netid)
     s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
     s.bind((myaddr, PING_IDENT))
     net_test.SetSocketTos(s, PING_TOS)
@@ -628,7 +634,7 @@
     self.ClearTunQueues()
     s.sendto(packet + PING_PAYLOAD, (dstaddr, 19321))
     msg = "IPv%d ping: expected %s on %s" % (
-        version, desc, self._GetInterfaceName(expected_netid))
+        version, desc, self.GetInterfaceName(expected_netid))
     self.ExpectPacketOn(expected_netid, msg, expected)
 
   def CheckTCPSYNPacket(self, version, mark, uid, dstaddr, expected_netid):
@@ -636,7 +642,7 @@
 
     if version == 6 and dstaddr.startswith("::ffff"):
       version = 4
-    myaddr = self._MyAddress(version, expected_netid)
+    myaddr = self.MyAddress(version, expected_netid)
     desc, expected = Packets.SYN(53, version, myaddr, dstaddr,
                                  sport=None, seq=None)
 
@@ -644,7 +650,7 @@
     # Non-blocking TCP connects always return EINPROGRESS.
     self.assertRaisesErrno(errno.EINPROGRESS, s.connect, (dstaddr, 53))
     msg = "IPv%s TCP connect: expected %s on %s" % (
-        version, desc, self._GetInterfaceName(expected_netid))
+        version, desc, self.GetInterfaceName(expected_netid))
     self.ExpectPacketOn(expected_netid, msg, expected)
     s.close()
 
@@ -653,10 +659,10 @@
 
     if version == 6 and dstaddr.startswith("::ffff"):
       version = 4
-    myaddr = self._MyAddress(version, expected_netid)
+    myaddr = self.MyAddress(version, expected_netid)
     desc, expected = Packets.UDP(version, myaddr, dstaddr, sport=None)
     msg = "IPv%s UDP %%s: expected %s on %s" % (
-        version, desc, self._GetInterfaceName(expected_netid))
+        version, desc, self.GetInterfaceName(expected_netid))
 
     self.ClearTunQueues()
     s.sendto(UDP_PAYLOAD, (dstaddr, 53))
@@ -690,18 +696,18 @@
     """Checks that UID routing selects the right outgoing interface."""
     for _ in xrange(self.ITERATIONS):
       for netid in self.tuns:
-        uid = self._UidForNetid(netid)
+        uid = self.UidForNetid(netid)
         self.CheckPingPacket(4, 0, uid, self.IPV4_ADDR, self.IPV4_PING, netid)
         self.CheckPingPacket(6, 0, uid, self.IPV6_ADDR, self.IPV6_PING, netid)
 
       for netid in self.tuns:
-        uid = self._UidForNetid(netid)
+        uid = self.UidForNetid(netid)
         self.CheckTCPSYNPacket(4, 0, uid, self.IPV4_ADDR, netid)
         self.CheckTCPSYNPacket(6, 0, uid, self.IPV6_ADDR, netid)
         self.CheckTCPSYNPacket(6, 0, uid, "::ffff:" + self.IPV4_ADDR, netid)
 
       for netid in self.tuns:
-        uid = self._UidForNetid(netid)
+        uid = self.UidForNetid(netid)
         self.CheckUDPPacket(4, 0, uid, self.IPV4_ADDR, netid)
         self.CheckUDPPacket(6, 0, uid, self.IPV6_ADDR, netid)
         self.CheckUDPPacket(6, 0, uid, "::ffff:" + self.IPV4_ADDR, netid)
@@ -738,14 +744,14 @@
 
     # Check packets addressed to the IP addresses of all our interfaces...
     for dest_ip_netid in self.tuns:
-      dest_ip_iface = self._GetInterfaceName(dest_ip_netid)
+      dest_ip_iface = self.GetInterfaceName(dest_ip_netid)
 
-      myaddr = self._MyAddress(version, dest_ip_netid)
+      myaddr = self.MyAddress(version, dest_ip_netid)
       remote_addr = self._GetRemoteAddress(version)
 
       # ... coming in on all our interfaces...
       for iif_netid in self.tuns:
-        iif = self._GetInterfaceName(iif_netid)
+        iif = self.GetInterfaceName(iif_netid)
         desc, packet = packet_generator(version, remote_addr, myaddr)
         reply_desc, reply = reply_generator(version, myaddr, remote_addr,
                                             packet)
@@ -799,8 +805,8 @@
   def testIPv6RSTsReflectMark(self):
     self.CheckReflection(6, self.SYNToClosedPort, Packets.RST, "reflect")
 
-  def CheckAcceptedSocketMarkCallback(self, netid, version, myaddr,
-                                      remote_addr, packet, reply, msg):
+  def CheckTCPConnection(self, netid, version, myaddr, remote_addr,
+                         packet, reply, msg):
     establishing_ack = Packets.ACK(version, remote_addr, myaddr, reply)[1]
     self.ReceivePacketOn(netid, establishing_ack)
     s, unused_peer = self.listensocket.accept()
@@ -829,25 +835,25 @@
   @unittest.skipUnless(HAVE_TCP_MARK_ACCEPT, "fwmark writeback not supported")
   def testIPv4TCPConnections(self):
     self.CheckReflection(4, self.SYNToOpenPort, Packets.SYNACK, "accept",
-                         self.CheckAcceptedSocketMarkCallback)
+                         self.CheckTCPConnection)
 
   @unittest.skipUnless(HAVE_TCP_MARK_ACCEPT, "fwmark writeback not supported")
   def testIPv6TCPConnections(self):
     self.CheckReflection(6, self.SYNToOpenPort, Packets.SYNACK, "accept",
-                         self.CheckAcceptedSocketMarkCallback)
+                         self.CheckTCPConnection)
 
   @unittest.skipUnless(HAVE_TCP_MARK_ACCEPT, "fwmark writeback not supported")
   def testTCPConnectionsWithSynCookies(self):
     # Force SYN cookies on all connections.
-    self._SetSysctl(SYNCOOKIES_SYSCTL, 2)
+    self.SetSysctl(SYNCOOKIES_SYSCTL, 2)
     try:
       self.CheckReflection(4, self.SYNToOpenPort, Packets.SYNACK, "accept",
-                           self.CheckAcceptedSocketMarkCallback)
+                           self.CheckTCPConnection)
       self.CheckReflection(6, self.SYNToOpenPort, Packets.SYNACK, "accept",
-                           self.CheckAcceptedSocketMarkCallback)
+                           self.CheckTCPConnection)
     finally:
       # Stop forcing SYN cookies on all connections.
-      self._SetSysctl(SYNCOOKIES_SYSCTL, 1)
+      self.SetSysctl(SYNCOOKIES_SYSCTL, 1)
 
 
 if __name__ == "__main__":