Simplify putting sockets onto networks.
Change-Id: Ibc82cdf3c8dd80f8bcab84b5a76f1e4d36069c89
diff --git a/net/test/mark_test.py b/net/test/mark_test.py
index 594ff04..18baf34 100755
--- a/net/test/mark_test.py
+++ b/net/test/mark_test.py
@@ -483,6 +483,8 @@
cls._RestoreSysctls()
def SetSocketMark(self, s, netid):
+ if netid is None:
+ netid = 0
s.setsockopt(SOL_SOCKET, net_test.SO_MARK, netid)
def GetSocketMark(self, s):
@@ -496,19 +498,15 @@
iface = ""
s.setsockopt(SOL_SOCKET, SO_BINDTODEVICE, iface)
- def SetUnicastInterface(self, version, s, iface):
- if iface:
- ifindex = net_test.GetInterfaceIndex(iface)
- else:
- ifindex = 0
- # Otherwise, Python apparently thinks it's a 1-byte option.
+ def SetUnicastInterface(self, s, ifindex):
+ # Otherwise, Python thinks it's a 1-byte option.
ifindex = struct.pack("!I", ifindex)
- layer, opt = {
- 4: (net_test.SOL_IP, IP_UNICAST_IF),
- 6: (net_test.SOL_IPV6, IPV6_UNICAST_IF),
- }[version]
- s.setsockopt(layer, opt, ifindex)
+ # Always set the IPv4 interface, because it will be used even on IPv6
+ # sockets if the destination address is a mapped address.
+ s.setsockopt(net_test.SOL_IP, IP_UNICAST_IF, ifindex)
+ if s.family == AF_INET6:
+ s.setsockopt(net_test.SOL_IPV6, IPV6_UNICAST_IF, ifindex)
def ReceiveEtherPacketOn(self, netid, packet):
posix.write(self.tuns[netid].fileno(), str(packet))
@@ -698,71 +696,81 @@
def _GetRemoteAddress(self, version):
return {4: self.IPV4_ADDR, 6: self.IPV6_ADDR}[version]
- def BuildSocket(self, version, constructor, mark, uid, oif, ucast_oif):
+ def SelectInterface(self, s, netid, mode):
+ if mode == "uid":
+ raise ValueError("Can't change UID on an existing socket")
+ elif mode == "mark":
+ self.SetSocketMark(s, netid)
+ elif mode == "oif":
+ iface = self.GetInterfaceName(netid) if netid else ""
+ self.BindToDevice(s, iface)
+ elif mode == "ucast_oif":
+ self.SetUnicastInterface(s, self.ifindices.get(netid, 0))
+ else:
+ raise ValueError("Unkown interface selection mode %s" % mode)
+
+ def BuildSocket(self, version, constructor, netid, routing_mode):
+ uid = self.UidForNetid(netid) if routing_mode == "uid" else None
with RunAsUid(uid):
family = self.GetProtocolFamily(version)
s = constructor(family)
- if mark:
- self.SetSocketMark(s, mark)
- if oif:
- self.BindToDevice(s, oif)
- if ucast_oif:
- self.SetUnicastInterface(version, s, ucast_oif)
+
+ if routing_mode not in [None, "uid"]:
+ self.SelectInterface(s, netid, routing_mode)
+
return s
- def CheckPingPacket(self, version, mark, uid, oif, ucast_oif, dstaddr, packet,
- expected_netid):
- s = self.BuildSocket(version, net_test.PingSocket, mark, uid, oif,
- ucast_oif)
+ def CheckPingPacket(self, version, netid, routing_mode, dstaddr, packet):
+ s = self.BuildSocket(version, net_test.PingSocket, netid, routing_mode)
- myaddr = self.MyAddress(version, expected_netid)
+ myaddr = self.MyAddress(version, netid)
s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
s.bind((myaddr, PING_IDENT))
net_test.SetSocketTos(s, PING_TOS)
desc, expected = Packets.ICMPEcho(version, myaddr, dstaddr)
msg = "IPv%d ping: expected %s on %s" % (
- version, desc, self.GetInterfaceName(expected_netid))
+ version, desc, self.GetInterfaceName(netid))
s.sendto(packet + PING_PAYLOAD, (dstaddr, 19321))
- self.ExpectPacketOn(expected_netid, msg, expected)
+ self.ExpectPacketOn(netid, msg, expected)
- def CheckTCPSYNPacket(self, version, mark, uid, oif, ucast_oif, dstaddr,
- expected_netid):
- s = self.BuildSocket(version, net_test.TCPSocket, mark, uid, oif, ucast_oif)
+ def CheckTCPSYNPacket(self, version, netid, routing_mode, dstaddr):
+ s = self.BuildSocket(version, net_test.TCPSocket, netid, routing_mode)
if version == 6 and dstaddr.startswith("::ffff"):
version = 4
- myaddr = self.MyAddress(version, expected_netid)
+ myaddr = self.MyAddress(version, netid)
desc, expected = Packets.SYN(53, version, myaddr, dstaddr,
sport=None, seq=None)
# Non-blocking TCP connects always return EINPROGRESS.
self.assertRaisesErrno(errno.EINPROGRESS, s.connect, (dstaddr, 53))
msg = "IPv%s TCP connect: expected %s on %s" % (
- version, desc, self.GetInterfaceName(expected_netid))
- self.ExpectPacketOn(expected_netid, msg, expected)
+ version, desc, self.GetInterfaceName(netid))
+ self.ExpectPacketOn(netid, msg, expected)
s.close()
- def CheckUDPPacket(self, version, mark, uid, oif, ucast_oif,
- dstaddr, expected_netid):
- s = self.BuildSocket(version, net_test.UDPSocket, mark, uid, oif, ucast_oif)
+ def CheckUDPPacket(self, version, netid, routing_mode, dstaddr):
+ s = self.BuildSocket(version, net_test.UDPSocket, netid, routing_mode)
if version == 6 and dstaddr.startswith("::ffff"):
version = 4
- myaddr = self.MyAddress(version, expected_netid)
+ myaddr = self.MyAddress(version, netid)
desc, expected = Packets.UDP(version, myaddr, dstaddr, sport=None)
msg = "IPv%s UDP %%s: expected %s on %s" % (
- version, desc, self.GetInterfaceName(expected_netid))
+ version, desc, self.GetInterfaceName(netid))
s.sendto(UDP_PAYLOAD, (dstaddr, 53))
- self.ExpectPacketOn(expected_netid, msg % "sendto", expected)
+ self.ExpectPacketOn(netid, msg % "sendto", expected)
- s.connect((dstaddr, 53))
- s.send(UDP_PAYLOAD)
- self.ExpectPacketOn(expected_netid, msg % "connect/send", expected)
- s.close()
+ # IP_UNICAST_IF doesn't seem to work on connected sockets, so no TCP.
+ if routing_mode != "ucast_oif":
+ s.connect((dstaddr, 53))
+ s.send(UDP_PAYLOAD)
+ self.ExpectPacketOn(netid, msg % "connect/send", expected)
+ s.close()
- def CheckOutgoingPackets(self, mode):
+ def CheckOutgoingPackets(self, routing_mode):
v4addr = self.IPV4_ADDR
v6addr = self.IPV6_ADDR
v4mapped = "::ffff:" + v4addr
@@ -770,35 +778,18 @@
for _ in xrange(self.ITERATIONS):
for netid in self.tuns:
- mark = uid = oif = ucast_oif = None
- if mode == "mark":
- mark = netid
- elif mode == "uid":
- uid = self.UidForNetid(netid)
- elif mode == "oif":
- oif = self.GetInterfaceName(netid)
- elif mode == "ucast_oif":
- ucast_oif = self.GetInterfaceName(netid)
- else:
- raise ValueError("Unkown routing mode %s" % mode)
+ self.CheckPingPacket(4, netid, routing_mode, v4addr, self.IPV4_PING)
+ self.CheckPingPacket(6, netid, routing_mode, v6addr, self.IPV6_PING)
- self.CheckPingPacket(4, mark, uid, oif, ucast_oif, v4addr,
- self.IPV4_PING, netid)
- self.CheckPingPacket(6, mark, uid, oif, ucast_oif, v6addr,
- self.IPV6_PING, netid)
+ # IP_UNICAST_IF doesn't seem to work on connected sockets, so no TCP.
+ if routing_mode != "ucast_oif":
+ self.CheckTCPSYNPacket(4, netid, routing_mode, v4addr)
+ self.CheckTCPSYNPacket(6, netid, routing_mode, v6addr)
+ self.CheckTCPSYNPacket(6, netid, routing_mode, v4mapped)
- # TCP doesn't seem to honour IP_UNICAST_IF.
- if mode != "ucast_oif":
- self.CheckTCPSYNPacket(4, mark, uid, oif, ucast_oif, v4addr, netid)
- self.CheckTCPSYNPacket(6, mark, uid, oif, ucast_oif, v6addr, netid)
- self.CheckTCPSYNPacket(6, mark, uid, oif, ucast_oif, v4mapped, netid)
-
- if mode != "ucast_oif":
- # This doesn't work.
- self.CheckUDPPacket(4, mark, uid, oif, ucast_oif, v4addr, netid)
- # These work, but the source addresses are incorrect.
- self.CheckUDPPacket(6, mark, uid, oif, ucast_oif, v6addr, netid)
- self.CheckUDPPacket(6, mark, uid, oif, ucast_oif, v4mapped, netid)
+ self.CheckUDPPacket(4, netid, routing_mode, v4addr)
+ self.CheckUDPPacket(6, netid, routing_mode, v6addr)
+ self.CheckUDPPacket(6, netid, routing_mode, v4mapped)
def testMarkRouting(self):
"""Checks that socket marking selects the right outgoing interface."""
@@ -817,34 +808,53 @@
"""Checks that ucast oif routing selects the right outgoing interface."""
self.CheckOutgoingPackets("ucast_oif")
- def CheckRemarking(self, version):
- s = net_test.UDPSocket(self.GetProtocolFamily(version))
+ def CheckRemarking(self, version, use_connect):
+ # Remarking or resetting UNICAST_IF on connected sockets does not work.
+ if use_connect:
+ modes = ["oif"]
+ else:
+ modes = ["mark", "oif", "ucast_oif"]
- # Figure out what packets to expect.
- unspec = {4: "0.0.0.0", 6: "::"}[version]
- sport = Packets.RandomPort()
- s.bind((unspec, sport))
- dstaddr = {4: self.IPV4_ADDR, 6: self.IPV6_ADDR}[version]
- desc, expected = Packets.UDP(version, unspec, dstaddr, sport)
+ for mode in modes:
+ s = net_test.UDPSocket(self.GetProtocolFamily(version))
- # For each netid, set that netid's mark on the socket without closing it,
- # and check that the packets sent on that socket go out on the right
- # network.
- for netid in self.tuns:
- self.SetSocketMark(s, netid)
- expected.src = self.MyAddress(version, netid)
- s.sendto("hello", (dstaddr, 53))
- msg = "Remarked UDPv%d socket: expecting %s on %s" % (
- version, desc, self.GetInterfaceName(netid))
- self.ExpectPacketOn(netid, msg, expected)
+ # Figure out what packets to expect.
+ unspec = {4: "0.0.0.0", 6: "::"}[version]
+ sport = Packets.RandomPort()
+ s.bind((unspec, sport))
+ dstaddr = {4: self.IPV4_ADDR, 6: self.IPV6_ADDR}[version]
+ desc, expected = Packets.UDP(version, unspec, dstaddr, sport)
+
+ # If we're testing connected sockets, connect the socket on the first
+ # netid now.
+ if use_connect:
+ netid = self.tuns.keys()[0]
+ self.SelectInterface(s, netid, mode)
+ s.connect((dstaddr, 53))
+ expected.src = self.MyAddress(version, netid)
+
+ # For each netid, select that network without closing the socket, and
+ # check that the packets sent on that socket go out on the right network.
+ for netid in self.tuns:
+ self.SelectInterface(s, netid, mode)
+ if not use_connect:
+ expected.src = self.MyAddress(version, netid)
+ s.sendto("hello", (dstaddr, 53))
+ connected_str = "Connected" if use_connect else "Unconnected"
+ msg = "%s UDPv%d socket remarked using %s: expecting %s on %s" % (
+ connected_str, version, mode, desc, self.GetInterfaceName(netid))
+ self.ExpectPacketOn(netid, msg, expected)
+ self.SelectInterface(s, None, mode)
def testIPv4Remarking(self):
"""Checks that updating the mark on an IPv4 socket changes routing."""
- self.CheckRemarking(4)
+ self.CheckRemarking(4, False)
+ self.CheckRemarking(4, True)
def testIPv6Remarking(self):
"""Checks that updating the mark on an IPv6 socket changes routing."""
- self.CheckRemarking(6)
+ self.CheckRemarking(6, False)
+ self.CheckRemarking(6, True)
def CheckReflection(self, version, packet_generator, reply_generator,
mark_behaviour, callback=None):