Simplify putting sockets onto networks.

Change-Id: Ibc82cdf3c8dd80f8bcab84b5a76f1e4d36069c89
diff --git a/net/test/mark_test.py b/net/test/mark_test.py
index 594ff04..18baf34 100755
--- a/net/test/mark_test.py
+++ b/net/test/mark_test.py
@@ -483,6 +483,8 @@
     cls._RestoreSysctls()
 
   def SetSocketMark(self, s, netid):
+    if netid is None:
+      netid = 0
     s.setsockopt(SOL_SOCKET, net_test.SO_MARK, netid)
 
   def GetSocketMark(self, s):
@@ -496,19 +498,15 @@
       iface = ""
     s.setsockopt(SOL_SOCKET, SO_BINDTODEVICE, iface)
 
-  def SetUnicastInterface(self, version, s, iface):
-    if iface:
-      ifindex = net_test.GetInterfaceIndex(iface)
-    else:
-      ifindex = 0
-    # Otherwise, Python apparently thinks it's a 1-byte option.
+  def SetUnicastInterface(self, s, ifindex):
+    # Otherwise, Python thinks it's a 1-byte option.
     ifindex = struct.pack("!I", ifindex)
 
-    layer, opt = {
-        4: (net_test.SOL_IP, IP_UNICAST_IF),
-        6: (net_test.SOL_IPV6, IPV6_UNICAST_IF),
-    }[version]
-    s.setsockopt(layer, opt, ifindex)
+    # Always set the IPv4 interface, because it will be used even on IPv6
+    # sockets if the destination address is a mapped address.
+    s.setsockopt(net_test.SOL_IP, IP_UNICAST_IF, ifindex)
+    if s.family == AF_INET6:
+      s.setsockopt(net_test.SOL_IPV6, IPV6_UNICAST_IF, ifindex)
 
   def ReceiveEtherPacketOn(self, netid, packet):
     posix.write(self.tuns[netid].fileno(), str(packet))
@@ -698,71 +696,81 @@
   def _GetRemoteAddress(self, version):
     return {4: self.IPV4_ADDR, 6: self.IPV6_ADDR}[version]
 
-  def BuildSocket(self, version, constructor, mark, uid, oif, ucast_oif):
+  def SelectInterface(self, s, netid, mode):
+    if mode == "uid":
+      raise ValueError("Can't change UID on an existing socket")
+    elif mode == "mark":
+      self.SetSocketMark(s, netid)
+    elif mode == "oif":
+      iface = self.GetInterfaceName(netid) if netid else ""
+      self.BindToDevice(s, iface)
+    elif mode == "ucast_oif":
+      self.SetUnicastInterface(s, self.ifindices.get(netid, 0))
+    else:
+      raise ValueError("Unkown interface selection mode %s" % mode)
+
+  def BuildSocket(self, version, constructor, netid, routing_mode):
+    uid = self.UidForNetid(netid) if routing_mode == "uid" else None
     with RunAsUid(uid):
       family = self.GetProtocolFamily(version)
       s = constructor(family)
-    if mark:
-      self.SetSocketMark(s, mark)
-    if oif:
-      self.BindToDevice(s, oif)
-    if ucast_oif:
-      self.SetUnicastInterface(version, s, ucast_oif)
+
+    if routing_mode not in [None, "uid"]:
+      self.SelectInterface(s, netid, routing_mode)
+
     return s
 
-  def CheckPingPacket(self, version, mark, uid, oif, ucast_oif, dstaddr, packet,
-                      expected_netid):
-    s = self.BuildSocket(version, net_test.PingSocket, mark, uid, oif,
-                         ucast_oif)
+  def CheckPingPacket(self, version, netid, routing_mode, dstaddr, packet):
+    s = self.BuildSocket(version, net_test.PingSocket, netid, routing_mode)
 
-    myaddr = self.MyAddress(version, expected_netid)
+    myaddr = self.MyAddress(version, netid)
     s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
     s.bind((myaddr, PING_IDENT))
     net_test.SetSocketTos(s, PING_TOS)
 
     desc, expected = Packets.ICMPEcho(version, myaddr, dstaddr)
     msg = "IPv%d ping: expected %s on %s" % (
-        version, desc, self.GetInterfaceName(expected_netid))
+        version, desc, self.GetInterfaceName(netid))
     s.sendto(packet + PING_PAYLOAD, (dstaddr, 19321))
-    self.ExpectPacketOn(expected_netid, msg, expected)
+    self.ExpectPacketOn(netid, msg, expected)
 
-  def CheckTCPSYNPacket(self, version, mark, uid, oif, ucast_oif, dstaddr,
-                        expected_netid):
-    s = self.BuildSocket(version, net_test.TCPSocket, mark, uid, oif, ucast_oif)
+  def CheckTCPSYNPacket(self, version, netid, routing_mode, dstaddr):
+    s = self.BuildSocket(version, net_test.TCPSocket, netid, routing_mode)
 
     if version == 6 and dstaddr.startswith("::ffff"):
       version = 4
-    myaddr = self.MyAddress(version, expected_netid)
+    myaddr = self.MyAddress(version, netid)
     desc, expected = Packets.SYN(53, version, myaddr, dstaddr,
                                  sport=None, seq=None)
 
     # Non-blocking TCP connects always return EINPROGRESS.
     self.assertRaisesErrno(errno.EINPROGRESS, s.connect, (dstaddr, 53))
     msg = "IPv%s TCP connect: expected %s on %s" % (
-        version, desc, self.GetInterfaceName(expected_netid))
-    self.ExpectPacketOn(expected_netid, msg, expected)
+        version, desc, self.GetInterfaceName(netid))
+    self.ExpectPacketOn(netid, msg, expected)
     s.close()
 
-  def CheckUDPPacket(self, version, mark, uid, oif, ucast_oif,
-                     dstaddr, expected_netid):
-    s = self.BuildSocket(version, net_test.UDPSocket, mark, uid, oif, ucast_oif)
+  def CheckUDPPacket(self, version, netid, routing_mode, dstaddr):
+    s = self.BuildSocket(version, net_test.UDPSocket, netid, routing_mode)
 
     if version == 6 and dstaddr.startswith("::ffff"):
       version = 4
-    myaddr = self.MyAddress(version, expected_netid)
+    myaddr = self.MyAddress(version, netid)
     desc, expected = Packets.UDP(version, myaddr, dstaddr, sport=None)
     msg = "IPv%s UDP %%s: expected %s on %s" % (
-        version, desc, self.GetInterfaceName(expected_netid))
+        version, desc, self.GetInterfaceName(netid))
 
     s.sendto(UDP_PAYLOAD, (dstaddr, 53))
-    self.ExpectPacketOn(expected_netid, msg % "sendto", expected)
+    self.ExpectPacketOn(netid, msg % "sendto", expected)
 
-    s.connect((dstaddr, 53))
-    s.send(UDP_PAYLOAD)
-    self.ExpectPacketOn(expected_netid, msg % "connect/send", expected)
-    s.close()
+    # IP_UNICAST_IF doesn't seem to work on connected sockets, so no TCP.
+    if routing_mode != "ucast_oif":
+      s.connect((dstaddr, 53))
+      s.send(UDP_PAYLOAD)
+      self.ExpectPacketOn(netid, msg % "connect/send", expected)
+      s.close()
 
-  def CheckOutgoingPackets(self, mode):
+  def CheckOutgoingPackets(self, routing_mode):
     v4addr = self.IPV4_ADDR
     v6addr = self.IPV6_ADDR
     v4mapped = "::ffff:" + v4addr
@@ -770,35 +778,18 @@
     for _ in xrange(self.ITERATIONS):
       for netid in self.tuns:
 
-        mark = uid = oif = ucast_oif = None
-        if mode == "mark":
-          mark = netid
-        elif mode == "uid":
-          uid = self.UidForNetid(netid)
-        elif mode == "oif":
-          oif = self.GetInterfaceName(netid)
-        elif mode == "ucast_oif":
-          ucast_oif = self.GetInterfaceName(netid)
-        else:
-          raise ValueError("Unkown routing mode %s" % mode)
+        self.CheckPingPacket(4, netid, routing_mode, v4addr, self.IPV4_PING)
+        self.CheckPingPacket(6, netid, routing_mode, v6addr, self.IPV6_PING)
 
-        self.CheckPingPacket(4, mark, uid, oif, ucast_oif, v4addr,
-                             self.IPV4_PING, netid)
-        self.CheckPingPacket(6, mark, uid, oif, ucast_oif, v6addr,
-                             self.IPV6_PING, netid)
+        # IP_UNICAST_IF doesn't seem to work on connected sockets, so no TCP.
+        if routing_mode != "ucast_oif":
+          self.CheckTCPSYNPacket(4, netid, routing_mode, v4addr)
+          self.CheckTCPSYNPacket(6, netid, routing_mode, v6addr)
+          self.CheckTCPSYNPacket(6, netid, routing_mode, v4mapped)
 
-        # TCP doesn't seem to honour IP_UNICAST_IF.
-        if mode != "ucast_oif":
-          self.CheckTCPSYNPacket(4, mark, uid, oif, ucast_oif, v4addr, netid)
-          self.CheckTCPSYNPacket(6, mark, uid, oif, ucast_oif, v6addr, netid)
-          self.CheckTCPSYNPacket(6, mark, uid, oif, ucast_oif, v4mapped, netid)
-
-        if mode != "ucast_oif":
-          # This doesn't work.
-          self.CheckUDPPacket(4, mark, uid, oif, ucast_oif, v4addr, netid)
-          # These work, but the source addresses are incorrect.
-          self.CheckUDPPacket(6, mark, uid, oif, ucast_oif, v6addr, netid)
-          self.CheckUDPPacket(6, mark, uid, oif, ucast_oif, v4mapped, netid)
+        self.CheckUDPPacket(4, netid, routing_mode, v4addr)
+        self.CheckUDPPacket(6, netid, routing_mode, v6addr)
+        self.CheckUDPPacket(6, netid, routing_mode, v4mapped)
 
   def testMarkRouting(self):
     """Checks that socket marking selects the right outgoing interface."""
@@ -817,34 +808,53 @@
     """Checks that ucast oif routing selects the right outgoing interface."""
     self.CheckOutgoingPackets("ucast_oif")
 
-  def CheckRemarking(self, version):
-    s = net_test.UDPSocket(self.GetProtocolFamily(version))
+  def CheckRemarking(self, version, use_connect):
+    # Remarking or resetting UNICAST_IF on connected sockets does not work.
+    if use_connect:
+      modes = ["oif"]
+    else:
+      modes = ["mark", "oif", "ucast_oif"]
 
-    # Figure out what packets to expect.
-    unspec = {4: "0.0.0.0", 6: "::"}[version]
-    sport = Packets.RandomPort()
-    s.bind((unspec, sport))
-    dstaddr = {4: self.IPV4_ADDR, 6: self.IPV6_ADDR}[version]
-    desc, expected = Packets.UDP(version, unspec, dstaddr, sport)
+    for mode in modes:
+      s = net_test.UDPSocket(self.GetProtocolFamily(version))
 
-    # For each netid, set that netid's mark on the socket without closing it,
-    # and check that the packets sent on that socket go out on the right
-    # network.
-    for netid in self.tuns:
-      self.SetSocketMark(s, netid)
-      expected.src = self.MyAddress(version, netid)
-      s.sendto("hello", (dstaddr, 53))
-      msg = "Remarked UDPv%d socket: expecting %s on %s" % (
-          version, desc, self.GetInterfaceName(netid))
-      self.ExpectPacketOn(netid, msg, expected)
+      # Figure out what packets to expect.
+      unspec = {4: "0.0.0.0", 6: "::"}[version]
+      sport = Packets.RandomPort()
+      s.bind((unspec, sport))
+      dstaddr = {4: self.IPV4_ADDR, 6: self.IPV6_ADDR}[version]
+      desc, expected = Packets.UDP(version, unspec, dstaddr, sport)
+
+      # If we're testing connected sockets, connect the socket on the first
+      # netid now.
+      if use_connect:
+        netid = self.tuns.keys()[0]
+        self.SelectInterface(s, netid, mode)
+        s.connect((dstaddr, 53))
+        expected.src = self.MyAddress(version, netid)
+
+      # For each netid, select that network without closing the socket, and
+      # check that the packets sent on that socket go out on the right network.
+      for netid in self.tuns:
+        self.SelectInterface(s, netid, mode)
+        if not use_connect:
+          expected.src = self.MyAddress(version, netid)
+        s.sendto("hello", (dstaddr, 53))
+        connected_str = "Connected" if use_connect else "Unconnected"
+        msg = "%s UDPv%d socket remarked using %s: expecting %s on %s" % (
+            connected_str, version, mode, desc, self.GetInterfaceName(netid))
+        self.ExpectPacketOn(netid, msg, expected)
+        self.SelectInterface(s, None, mode)
 
   def testIPv4Remarking(self):
     """Checks that updating the mark on an IPv4 socket changes routing."""
-    self.CheckRemarking(4)
+    self.CheckRemarking(4, False)
+    self.CheckRemarking(4, True)
 
   def testIPv6Remarking(self):
     """Checks that updating the mark on an IPv6 socket changes routing."""
-    self.CheckRemarking(6)
+    self.CheckRemarking(6, False)
+    self.CheckRemarking(6, True)
 
   def CheckReflection(self, version, packet_generator, reply_generator,
                       mark_behaviour, callback=None):