Refactor setup and common code into a superclass.
Change-Id: Iee489954175de6eec12b711d6c3ebb9a64cfd6c3
diff --git a/net/test/mark_test.py b/net/test/mark_test.py
index 48c9e98..a5cd120 100755
--- a/net/test/mark_test.py
+++ b/net/test/mark_test.py
@@ -213,10 +213,7 @@
os.seteuid(self.saved_uid)
-class MarkTest(net_test.NetworkTest):
-
- # How many times to run packet reflection tests.
- ITERATIONS = 5
+class MultiNetworkTest(net_test.NetworkTest):
# Must be between 1 and 256, since we put them in MAC addresses and IIDs.
NETIDS = [100, 150, 200, 250]
@@ -224,18 +221,12 @@
# Stores sysctl values to write back when the test completes.
saved_sysctls = {}
- # For convenience.
- IPV4_ADDR = net_test.IPV4_ADDR
- IPV6_ADDR = net_test.IPV6_ADDR
- IPV4_PING = net_test.IPV4_PING
- IPV6_PING = net_test.IPV6_PING
-
# Wether to output setup commands.
DEBUG = False
@staticmethod
- def _GetInterfaceName(netid):
- return "nettest%d" % netid
+ def UidForNetid(netid):
+ return 2000 + netid
@classmethod
def _TableForNetid(cls, netid):
@@ -245,15 +236,15 @@
return netid
@staticmethod
- def _UidForNetid(netid):
- return 2000 + netid
+ def GetInterfaceName(netid):
+ return "nettest%d" % netid
@staticmethod
- def _RouterMacAddress(netid):
+ def RouterMacAddress(netid):
return "02:00:00:00:%02x:00" % netid
@staticmethod
- def _MyMacAddress(netid):
+ def MyMacAddress(netid):
return "02:00:00:00:%02x:01" % netid
@staticmethod
@@ -271,22 +262,22 @@
@classmethod
def _MyIPv6Address(cls, netid):
- return net_test.GetLinkAddress(cls._GetInterfaceName(netid), False)
+ return net_test.GetLinkAddress(cls.GetInterfaceName(netid), False)
@classmethod
- def _MyAddress(cls, version, netid):
+ def MyAddress(cls, version, netid):
return {4: cls._MyIPv4Address(netid),
6: cls._MyIPv6Address(netid)}[version]
@classmethod
def _CreateTunInterface(cls, netid):
- iface = cls._GetInterfaceName(netid)
+ iface = cls.GetInterfaceName(netid)
f = open("/dev/net/tun", "r+b")
ifr = struct.pack("16sH", iface, IFF_TAP | IFF_NO_PI)
ifr += "\x00" * (40 - len(ifr))
fcntl.ioctl(f, TUNSETIFF, ifr)
# Give ourselves a predictable MAC address.
- net_test.SetInterfaceHWAddr(iface, cls._MyMacAddress(netid))
+ net_test.SetInterfaceHWAddr(iface, cls.MyMacAddress(netid))
# Disable DAD so we don't have to wait for it.
open("/proc/sys/net/ipv6/conf/%s/dad_transmits" % iface, "w").write("0")
net_test.SetInterfaceUp(iface)
@@ -297,7 +288,7 @@
def _SendRA(cls, netid):
validity = 300 # seconds
validity_ms = validity * 1000 # milliseconds
- macaddr = cls._RouterMacAddress(netid)
+ macaddr = cls.RouterMacAddress(netid)
lladdr = cls._RouterAddress(netid, 6)
# We don't want any routes in the main table. If the kernel doesn't support
@@ -335,7 +326,7 @@
for version, iptables in zip([4, 6], ["iptables", "ip6tables"]):
table = cls._TableForNetid(netid)
- uid = cls._UidForNetid(netid)
+ uid = cls.UidForNetid(netid)
if HAVE_EXPERIMENTAL_UID_ROUTING:
cls.iproute.UidRule(version, is_add, uid, table, priority=100)
cls.iproute.FwmarkRule(version, is_add, netid, table, priority=200)
@@ -358,10 +349,10 @@
cmds = str("\n".join(cmds) % {
"add_del": "add" if is_add else "del",
"append_delete": "-A" if is_add else "-D",
- "iface": cls._GetInterfaceName(netid),
+ "iface": cls.GetInterfaceName(netid),
"iptables": iptables,
"ipv4addr": cls._MyIPv4Address(netid),
- "macaddr": cls._RouterMacAddress(netid),
+ "macaddr": cls.RouterMacAddress(netid),
"mark": netid,
"router": cls._RouterAddress(netid, version),
"table": table,
@@ -375,16 +366,16 @@
raise ConfigurationError("Setup command failed: %s" % " ".join(cmd))
@classmethod
- def _GetSysctl(cls, sysctl):
+ def GetSysctl(cls, sysctl):
return open(sysctl, "r").read()
@classmethod
- def _SetSysctl(cls, sysctl, value):
+ def SetSysctl(cls, sysctl, value):
# Only save each sysctl value the first time we set it. This is so we can
# set it to arbitrary values multiple times and still write it back
# correctly at the end.
if sysctl not in cls.saved_sysctls:
- cls.saved_sysctls[sysctl] = cls._GetSysctl(sysctl)
+ cls.saved_sysctls[sysctl] = cls.GetSysctl(sysctl)
open(sysctl, "w").write(str(value) + "\n")
@classmethod
@@ -399,21 +390,7 @@
@classmethod
def _SetICMPRatelimit(cls, version, limit):
- cls._SetSysctl(cls._ICMPRatelimitFilename(version), limit)
-
- @classmethod
- def _SetMarkReflectSysctls(cls, value):
- cls._SetSysctl(IPV4_MARK_REFLECT_SYSCTL, value)
- try:
- cls._SetSysctl(IPV6_MARK_REFLECT_SYSCTL, value)
- except IOError:
- # This does not exist if we use the version of the patch that uses a
- # common sysctl for IPv4 and IPv6.
- pass
-
- @classmethod
- def _SetTCPMarkAcceptSysctl(cls, value):
- cls._SetSysctl(TCP_MARK_ACCEPT_SYSCTL, value)
+ cls.SetSysctl(cls._ICMPRatelimitFilename(version), limit)
@classmethod
def setUpClass(cls):
@@ -424,7 +401,7 @@
cls.tuns = {}
cls.ifindices = {}
if HAVE_AUTOCONF_TABLE:
- cls._SetSysctl(AUTOCONF_TABLE_SYSCTL, -1000)
+ cls.SetSysctl(AUTOCONF_TABLE_SYSCTL, -1000)
cls.AUTOCONF_TABLE_OFFSET = -1000
else:
cls.AUTOCONF_TABLE_OFFSET = None
@@ -436,12 +413,58 @@
for netid in cls.NETIDS:
cls.tuns[netid] = cls._CreateTunInterface(netid)
- iface = cls._GetInterfaceName(netid)
+ iface = cls.GetInterfaceName(netid)
cls.ifindices[netid] = net_test.GetInterfaceIndex(iface)
cls._SendRA(netid)
cls._RunSetupCommands(netid, True)
+ # Uncomment to look around at interface and rule configuration while
+ # running in the background. (Once the test finishes running, all the
+ # interfaces and rules are gone.)
+ # time.sleep(30)
+
+ @classmethod
+ def tearDownClass(cls):
+ for netid in cls.tuns:
+ cls._RunSetupCommands(netid, False)
+ cls.tuns[netid].close()
+ cls._RestoreSysctls()
+
+ def SetSocketMark(self, s, netid):
+ s.setsockopt(SOL_SOCKET, net_test.SO_MARK, netid)
+
+ def GetSocketMark(self, s):
+ return s.getsockopt(SOL_SOCKET, net_test.SO_MARK)
+
+ def ReceivePacketOn(self, netid, ip_packet):
+ routermac = self.RouterMacAddress(netid)
+ mymac = self.MyMacAddress(netid)
+ packet = scapy.Ether(src=routermac, dst=mymac) / ip_packet
+ posix.write(self.tuns[netid].fileno(), str(packet))
+
+ def ClearTunQueues(self):
+ # Keep reading packets on all netids until we get no packets on any of them.
+ waiting = None
+ while waiting != 0:
+ waiting = sum(len(self.ReadAllPacketsOn(netid)) for netid in self.NETIDS)
+
+
+class MarkTest(MultiNetworkTest):
+
+ # How many times to run packet reflection tests.
+ ITERATIONS = 5
+
+ # For convenience.
+ IPV4_ADDR = net_test.IPV4_ADDR
+ IPV6_ADDR = net_test.IPV6_ADDR
+ IPV4_PING = net_test.IPV4_PING
+ IPV6_PING = net_test.IPV6_PING
+
+ @classmethod
+ def setUpClass(cls):
+ super(MarkTest, cls).setUpClass()
+
# Open a port so we can observe SYN+ACKs. Since it's a dual-stack socket it
# will accept both IPv4 and IPv6 connections. We do this here instead of in
# each test so we can use the same socket every time. That way, if a kernel
@@ -454,17 +477,19 @@
cls.listensocket.bind(("::", cls.listenport))
cls.listensocket.listen(100)
- # Uncomment to look around at interface and rule configuration while
- # running in the background. (Once the test finishes running, all the
- # interfaces and rules are gone.)
- # time.sleep(30)
+ @classmethod
+ def _SetMarkReflectSysctls(cls, value):
+ cls.SetSysctl(IPV4_MARK_REFLECT_SYSCTL, value)
+ try:
+ cls.SetSysctl(IPV6_MARK_REFLECT_SYSCTL, value)
+ except IOError:
+ # This does not exist if we use the version of the patch that uses a
+ # common sysctl for IPv4 and IPv6.
+ pass
@classmethod
- def tearDownClass(cls):
- for netid in cls.tuns:
- cls._RunSetupCommands(netid, False)
- cls.tuns[netid].close()
- cls._RestoreSysctls()
+ def _SetTCPMarkAcceptSysctl(cls, value):
+ cls.SetSysctl(TCP_MARK_ACCEPT_SYSCTL, value)
def assertPacketMatches(self, expected, actual):
# The expected packet is just a rough sketch of the packet we expect to
@@ -578,37 +603,18 @@
raise UnexpectedPacketError(
"%s: diff with last packet:\n%s" % (msg, e.message))
- def ReceivePacketOn(self, netid, ip_packet):
- routermac = self._RouterMacAddress(netid)
- mymac = self._MyMacAddress(netid)
- packet = scapy.Ether(src=routermac, dst=mymac) / ip_packet
- posix.write(self.tuns[netid].fileno(), str(packet))
-
- def ClearTunQueues(self):
- # Keep reading packets on all netids until we get no packets on any of them.
- waiting = None
- while waiting != 0:
- waiting = sum(len(self.ReadAllPacketsOn(netid)) for netid in self.NETIDS)
-
def setUp(self):
self.ClearTunQueues()
- @classmethod
- def _GetRemoteAddress(cls, version):
- return {4: cls.IPV4_ADDR, 6: cls.IPV6_ADDR}[version]
+ def _GetRemoteAddress(self, version):
+ return {4: self.IPV4_ADDR, 6: self.IPV6_ADDR}[version]
- def SetSocketMark(self, s, netid):
- s.setsockopt(SOL_SOCKET, net_test.SO_MARK, netid)
-
- def GetSocketMark(self, s):
- return s.getsockopt(SOL_SOCKET, net_test.SO_MARK)
-
- def GetProtocolFamily(self, version):
+ def _GetProtocolFamily(self, version):
return {4: AF_INET, 6: AF_INET6}[version]
def BuildSocket(self, version, constructor, mark, uid):
with RunAsUid(uid):
- family = self.GetProtocolFamily(version)
+ family = self._GetProtocolFamily(version)
s = constructor(family)
if mark:
self.SetSocketMark(s, mark)
@@ -618,7 +624,7 @@
expected_netid):
s = self.BuildSocket(version, net_test.PingSocket, mark, uid)
- myaddr = self._MyAddress(version, expected_netid)
+ myaddr = self.MyAddress(version, expected_netid)
s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
s.bind((myaddr, PING_IDENT))
net_test.SetSocketTos(s, PING_TOS)
@@ -628,7 +634,7 @@
self.ClearTunQueues()
s.sendto(packet + PING_PAYLOAD, (dstaddr, 19321))
msg = "IPv%d ping: expected %s on %s" % (
- version, desc, self._GetInterfaceName(expected_netid))
+ version, desc, self.GetInterfaceName(expected_netid))
self.ExpectPacketOn(expected_netid, msg, expected)
def CheckTCPSYNPacket(self, version, mark, uid, dstaddr, expected_netid):
@@ -636,7 +642,7 @@
if version == 6 and dstaddr.startswith("::ffff"):
version = 4
- myaddr = self._MyAddress(version, expected_netid)
+ myaddr = self.MyAddress(version, expected_netid)
desc, expected = Packets.SYN(53, version, myaddr, dstaddr,
sport=None, seq=None)
@@ -644,7 +650,7 @@
# Non-blocking TCP connects always return EINPROGRESS.
self.assertRaisesErrno(errno.EINPROGRESS, s.connect, (dstaddr, 53))
msg = "IPv%s TCP connect: expected %s on %s" % (
- version, desc, self._GetInterfaceName(expected_netid))
+ version, desc, self.GetInterfaceName(expected_netid))
self.ExpectPacketOn(expected_netid, msg, expected)
s.close()
@@ -653,10 +659,10 @@
if version == 6 and dstaddr.startswith("::ffff"):
version = 4
- myaddr = self._MyAddress(version, expected_netid)
+ myaddr = self.MyAddress(version, expected_netid)
desc, expected = Packets.UDP(version, myaddr, dstaddr, sport=None)
msg = "IPv%s UDP %%s: expected %s on %s" % (
- version, desc, self._GetInterfaceName(expected_netid))
+ version, desc, self.GetInterfaceName(expected_netid))
self.ClearTunQueues()
s.sendto(UDP_PAYLOAD, (dstaddr, 53))
@@ -690,18 +696,18 @@
"""Checks that UID routing selects the right outgoing interface."""
for _ in xrange(self.ITERATIONS):
for netid in self.tuns:
- uid = self._UidForNetid(netid)
+ uid = self.UidForNetid(netid)
self.CheckPingPacket(4, 0, uid, self.IPV4_ADDR, self.IPV4_PING, netid)
self.CheckPingPacket(6, 0, uid, self.IPV6_ADDR, self.IPV6_PING, netid)
for netid in self.tuns:
- uid = self._UidForNetid(netid)
+ uid = self.UidForNetid(netid)
self.CheckTCPSYNPacket(4, 0, uid, self.IPV4_ADDR, netid)
self.CheckTCPSYNPacket(6, 0, uid, self.IPV6_ADDR, netid)
self.CheckTCPSYNPacket(6, 0, uid, "::ffff:" + self.IPV4_ADDR, netid)
for netid in self.tuns:
- uid = self._UidForNetid(netid)
+ uid = self.UidForNetid(netid)
self.CheckUDPPacket(4, 0, uid, self.IPV4_ADDR, netid)
self.CheckUDPPacket(6, 0, uid, self.IPV6_ADDR, netid)
self.CheckUDPPacket(6, 0, uid, "::ffff:" + self.IPV4_ADDR, netid)
@@ -738,14 +744,14 @@
# Check packets addressed to the IP addresses of all our interfaces...
for dest_ip_netid in self.tuns:
- dest_ip_iface = self._GetInterfaceName(dest_ip_netid)
+ dest_ip_iface = self.GetInterfaceName(dest_ip_netid)
- myaddr = self._MyAddress(version, dest_ip_netid)
+ myaddr = self.MyAddress(version, dest_ip_netid)
remote_addr = self._GetRemoteAddress(version)
# ... coming in on all our interfaces...
for iif_netid in self.tuns:
- iif = self._GetInterfaceName(iif_netid)
+ iif = self.GetInterfaceName(iif_netid)
desc, packet = packet_generator(version, remote_addr, myaddr)
reply_desc, reply = reply_generator(version, myaddr, remote_addr,
packet)
@@ -799,8 +805,8 @@
def testIPv6RSTsReflectMark(self):
self.CheckReflection(6, self.SYNToClosedPort, Packets.RST, "reflect")
- def CheckAcceptedSocketMarkCallback(self, netid, version, myaddr,
- remote_addr, packet, reply, msg):
+ def CheckTCPConnection(self, netid, version, myaddr, remote_addr,
+ packet, reply, msg):
establishing_ack = Packets.ACK(version, remote_addr, myaddr, reply)[1]
self.ReceivePacketOn(netid, establishing_ack)
s, unused_peer = self.listensocket.accept()
@@ -829,25 +835,25 @@
@unittest.skipUnless(HAVE_TCP_MARK_ACCEPT, "fwmark writeback not supported")
def testIPv4TCPConnections(self):
self.CheckReflection(4, self.SYNToOpenPort, Packets.SYNACK, "accept",
- self.CheckAcceptedSocketMarkCallback)
+ self.CheckTCPConnection)
@unittest.skipUnless(HAVE_TCP_MARK_ACCEPT, "fwmark writeback not supported")
def testIPv6TCPConnections(self):
self.CheckReflection(6, self.SYNToOpenPort, Packets.SYNACK, "accept",
- self.CheckAcceptedSocketMarkCallback)
+ self.CheckTCPConnection)
@unittest.skipUnless(HAVE_TCP_MARK_ACCEPT, "fwmark writeback not supported")
def testTCPConnectionsWithSynCookies(self):
# Force SYN cookies on all connections.
- self._SetSysctl(SYNCOOKIES_SYSCTL, 2)
+ self.SetSysctl(SYNCOOKIES_SYSCTL, 2)
try:
self.CheckReflection(4, self.SYNToOpenPort, Packets.SYNACK, "accept",
- self.CheckAcceptedSocketMarkCallback)
+ self.CheckTCPConnection)
self.CheckReflection(6, self.SYNToOpenPort, Packets.SYNACK, "accept",
- self.CheckAcceptedSocketMarkCallback)
+ self.CheckTCPConnection)
finally:
# Stop forcing SYN cookies on all connections.
- self._SetSysctl(SYNCOOKIES_SYSCTL, 1)
+ self.SetSysctl(SYNCOOKIES_SYSCTL, 1)
if __name__ == "__main__":