Test the oif rules using SO_BINDTODEVICE.

For now, this only tests incoming connections.

Change-Id: Ie1fcf53786d6c65c7f4ec80eb6573e824e730899
diff --git a/net/test/mark_test.py b/net/test/mark_test.py
index 89537e4..b107506 100755
--- a/net/test/mark_test.py
+++ b/net/test/mark_test.py
@@ -25,6 +25,8 @@
 PING_SEQ = 3
 PING_TOS = 0x83
 
+SO_BINDTODEVICE = 25
+
 UDP_PAYLOAD = "hello"
 
 
@@ -205,7 +207,7 @@
     packet = (ip(src=srcaddr, dst=dstaddr) /
               icmp(id=PING_IDENT, seq=PING_SEQ) / PING_PAYLOAD)
     cls._SetPacketTos(packet, PING_TOS)
-    return ("ICMPv%d echo" % version, packet)
+    return ("ICMPv%d echo reply" % version, packet)
 
 
 class RunAsUid(object):
@@ -342,7 +344,8 @@
       iface = cls.GetInterfaceName(netid)
       if HAVE_EXPERIMENTAL_UID_ROUTING:
         cls.iproute.UidRule(version, is_add, uid, table, priority=100)
-      cls.iproute.FwmarkRule(version, is_add, netid, table, priority=200)
+      cls.iproute.OifRule(version, is_add, iface, table, priority=200)
+      cls.iproute.FwmarkRule(version, is_add, netid, table, priority=300)
 
       if cls.DEBUG:
         os.spawnvp(os.P_WAIT, "/sbin/ip", ["ip", "-6", "rule", "list"])
@@ -452,6 +455,11 @@
   def GetSocketMark(self, s):
     return s.getsockopt(SOL_SOCKET, net_test.SO_MARK)
 
+  def BindToDevice(self, s, iface):
+    if not iface:
+      iface = ""
+    s.setsockopt(SOL_SOCKET, SO_BINDTODEVICE, iface)
+
   def ReceivePacketOn(self, netid, ip_packet):
     routermac = self.RouterMacAddress(netid)
     mymac = self.MyMacAddress(netid)
@@ -621,6 +629,10 @@
   def setUp(self):
     self.ClearTunQueues()
 
+  def tearDown(self):
+    # In case there was an exception in one of the tests and we didn't clean up.
+    self.BindToDevice(self.listensocket, None)
+
   def _GetRemoteAddress(self, version):
     return {4: self.IPV4_ADDR, 6: self.IPV6_ADDR}[version]
 
@@ -767,28 +779,57 @@
       # ... coming in on all our interfaces...
       for iif_netid in self.tuns:
         iif = self.GetInterfaceName(iif_netid)
-        desc, packet = packet_generator(version, remote_addr, myaddr)
-        reply_desc, reply = reply_generator(version, myaddr, remote_addr,
-                                            packet)
 
         # ... with inbound mark sysctl enabled and disabled.
         for sysctl_value in [0, 1]:
-          msg = "Receiving %s on %s to %s IP, %s=%d" % (
-              desc, iif, dest_ip_iface, mark_behaviour, sysctl_value)
-          sysctl_function(sysctl_value)
-          self.ClearTunQueues()
-          # Cause the kernel to receive packet on iif_netid.
-          self.ReceivePacketOn(iif_netid, packet)
-          # Expect the kernel to send out reply on the same interface.
-          if sysctl_value:
-            msg += ": Expecting %s on %s" % (reply_desc, iif)
-            reply = self.ExpectPacketOn(iif_netid, msg, reply)
-            if callback:
-              callback(iif_netid, version, myaddr, remote_addr, packet, reply,
-                       msg)
+
+          # If we're testing accepting TCP connections, also check that
+          # SO_BINDTODEVICE correctly sets the interface the SYN+ACK is sent on.
+          # Since SO_BINDTODEVICE and the sysctl do the same thing, it doesn't
+          # really make sense to test with sysctl_value=1 and SO_BINDTODEVICE
+          # turned on at the same time.
+          if mark_behaviour == "accept" and not sysctl_value:
+            bind_devices = [None, iif]
           else:
-            msg += ": Expecting no packets on %s" % reply_desc
-            self.ExpectNoPacketsOn(iif_netid, msg, reply)
+            bind_devices = [None]
+
+          for bound_dev in bind_devices:
+            # The socket is unbound in tearDown.
+            self.BindToDevice(self.listensocket, bound_dev)
+
+            # Generate the packet here instead of in the outer loop, so
+            # subsequent TCP connections use different source ports and
+            # retransmissions from old connections don't confuse subsequent
+            # tests.
+            desc, packet = packet_generator(version, remote_addr, myaddr)
+            reply_desc, reply = reply_generator(version, myaddr, remote_addr,
+                                                packet)
+
+            msg = "Receiving %s on %s to %s IP, %s=%d, bound_dev=%s" % (
+                desc, iif, dest_ip_iface, mark_behaviour, sysctl_value,
+                bound_dev)
+            sysctl_function(sysctl_value)
+            self.ClearTunQueues()
+
+            # Cause the kernel to receive packet on iif_netid.
+            self.ReceivePacketOn(iif_netid, packet)
+
+            # Expect the kernel to send out reply on the same interface.
+            #
+            # HACK: IPv6 ping replies always do a routing lookup with the
+            # interface the ping came in on. So even if mark reflection is not
+            # working, IPv6 ping replies will be properly reflected. Don't
+            # fail when that happens.
+            if bound_dev or sysctl_value or reply_desc == "ICMPv6 echo reply":
+              msg += ": Expecting %s on %s" % (reply_desc, iif)
+              reply = self.ExpectPacketOn(iif_netid, msg, reply)
+              # If a callback was set, call it.
+              if callback:
+                callback(sysctl_value, iif_netid, version, myaddr, remote_addr,
+                         packet, reply, msg)
+            else:
+              msg += ": Expecting no packets on %s" % iif
+              self.ExpectNoPacketsOn(iif_netid, msg, reply)
 
   def SYNToClosedPort(self, *args):
     return Packets.SYN(999, *args)
@@ -820,8 +861,8 @@
   def testIPv6RSTsReflectMark(self):
     self.CheckReflection(6, self.SYNToClosedPort, Packets.RST, "reflect")
 
-  def CheckTCPConnection(self, netid, version, myaddr, remote_addr,
-                         packet, reply, msg):
+  def CheckTCPConnection(self, sysctl_value, netid, version,
+                         myaddr, remote_addr, packet, reply, msg):
     establishing_ack = Packets.ACK(version, remote_addr, myaddr, reply)[1]
     self.ReceivePacketOn(netid, establishing_ack)
     s, unused_peer = self.listensocket.accept()
@@ -829,9 +870,10 @@
       mark = self.GetSocketMark(s)
     finally:
       s.close()
-    self.assertEquals(netid, mark,
-                      msg + ": Accepted socket: Expected mark %d, got %d" % (
-                          netid, mark))
+    if sysctl_value:
+      self.assertEquals(netid, mark,
+                        msg + ": Accepted socket: Expected mark %d, got %d" % (
+                            netid, mark))
 
     # Check the FIN was sent on the right interface, and ack it. We don't expect
     # this to fail because by the time the connection is established things are