Better tests for Path MTU discovery.
Add PMTU tests for unconnected sockets, and test PMTUD when
routing using all methods, not just using socket marking.
Change-Id: I8f0f6fc00afa95b8e57792c51e955e2150ef29dc
diff --git a/net/test/mark_test.py b/net/test/mark_test.py
index ec70b5e..2e2e5b4 100755
--- a/net/test/mark_test.py
+++ b/net/test/mark_test.py
@@ -74,6 +74,16 @@
   return result
 
 
+def LinuxVersion():
+  # Example: "3.4.67-00753-gb7a556f".
+  # Get the part before the dash.
+  version = os.uname()[2].split("-")[0]
+  # Convert it into a tuple such as (3, 4, 67). That allows comparing versions
+  # using < and >, since tuples are compared lexicographically.
+  version = tuple(int(i) for i in version.split("."))
+  return version
+
+
 AUTOCONF_TABLE_SYSCTL = "/proc/sys/net/ipv6/conf/default/accept_ra_rt_table"
 IPV4_MARK_REFLECT_SYSCTL = "/proc/sys/net/ipv4/fwmark_reflect"
 IPV6_MARK_REFLECT_SYSCTL = "/proc/sys/net/ipv6/fwmark_reflect"
@@ -1398,59 +1408,123 @@
     self.assertEquals(num_routes, GetNumRoutes())
 
 
-class PMTUTest(MultiNetworkTest):
+class PMTUTest(InboundMarkingTest):
 
   PAYLOAD_SIZE = 1400
+
+  # Socket options to change PMTU behaviour.
   IP_MTU_DISCOVER = 10
-  IP_MTU = 14
   IP_PMTUDISC_DO = 1
-  IPV6_PATHMTU = 61
   IPV6_DONTFRAG = 62
 
+  # Socket options to get the MTU.
+  IP_MTU = 14
+  IPV6_PATHMTU = 61
+
   def GetSocketMTU(self, version, s):
     if version == 6:
       ip6_mtuinfo = s.getsockopt(net_test.SOL_IPV6, self.IPV6_PATHMTU, 32)
-      mtu = struct.unpack("=28sI", ip6_mtuinfo)
-      return mtu[1]
+      unused_sockaddr, mtu = struct.unpack("=28sI", ip6_mtuinfo)
+      return mtu
     else:
       return s.getsockopt(net_test.SOL_IP, self.IP_MTU)
 
-  def CheckPMTU(self, version):
+  def DisableFragmentationAndReportErrors(self, version, s):
+    if version == 4:
+      s.setsockopt(net_test.SOL_IP, self.IP_MTU_DISCOVER, self.IP_PMTUDISC_DO)
+      s.setsockopt(net_test.SOL_IP, net_test.IP_RECVERR, 1)
+    else:
+      s.setsockopt(net_test.SOL_IPV6, self.IPV6_DONTFRAG, 1)
+      s.setsockopt(net_test.SOL_IPV6, net_test.IPV6_RECVERR, 1)
+
+  def CheckPMTU(self, version, use_connect, modes):
     for netid in self.tuns:
-      s = net_test.UDPSocket(self.GetProtocolFamily(version))
+      for mode in modes:
+        s = self.BuildSocket(version, net_test.UDPSocket, netid, mode)
+        self.DisableFragmentationAndReportErrors(version, s)
 
-      srcaddr = self.MyAddress(version, netid)
-      dst_prefix, intermediate = {
-          4: ("172.19.", "172.16.9.12"),
-          6: ("2001:db8::", "2001:db8::1")
-      }[version]
-      dstaddr = self.GetRandomDestination(dst_prefix)
+        srcaddr = self.MyAddress(version, netid)
+        dst_prefix, intermediate = {
+            4: ("172.19.", "172.16.9.12"),
+            6: ("2001:db8::", "2001:db8::1")
+        }[version]
+        dstaddr = self.GetRandomDestination(dst_prefix)
 
-      # So the packet has somewhere to go.
-      self.SetSocketMark(s, netid)
-      s.connect((dstaddr, 1234))
-      self.assertEquals(1500, self.GetSocketMTU(version, s))
+        if use_connect:
+          s.connect((dstaddr, 1234))
 
-      s.send(self.PAYLOAD_SIZE * "a")
-      packets = self.ReadAllPacketsOn(netid)
-      self.assertEquals(1, len(packets))
-      _, toobig = Packets.ICMPPacketTooBig(version, intermediate, srcaddr,
-                                           packets[0])
-      self.ReceivePacketOn(netid, toobig)
-      self.assertEquals(1280, self.GetSocketMTU(version, s))
-      s.close()
+        payload = self.PAYLOAD_SIZE * "a"
 
-      # Open another socket to ensure the path MTU is cached.
-      s2 = net_test.UDPSocket(self.GetProtocolFamily(version))
-      self.BindToDevice(s2, self.GetInterfaceName(netid))
-      s2.connect((dstaddr, 1234))
-      self.assertEquals(1280, self.GetSocketMTU(version, s2))
+        def SendBigPacket():
+          if use_connect:
+            s.send(payload)
+          else:
+            self.SendOnNetid(version, s, dstaddr, 1234, netid, payload, [])
 
-  def testIPv4PMTU(self):
-    self.CheckPMTU(4)
+        # Send a packet and receive a packet too big.
+        SendBigPacket()
+        packets = self.ReadAllPacketsOn(netid)
+        self.assertEquals(1, len(packets))
+        _, toobig = Packets.ICMPPacketTooBig(version, intermediate, srcaddr,
+                                             packets[0])
+        self.ReceivePacketOn(netid, toobig)
 
-  def testIPv6PMTU(self):
-    self.CheckPMTU(6)
+        # Check that another send on the same socket returns EMSGSIZE.
+        self.assertRaisesErrno(errno.EMSGSIZE, SendBigPacket)
+
+        # If this is a connected socket, make sure the socket MTU was set.
+        # Note that in IPv4 this only started working in Linux 3.6!
+        if use_connect and (version == 6 or LinuxVersion() >= (3, 6)):
+          self.assertEquals(1280, self.GetSocketMTU(version, s))
+
+        s.close()
+
+        # Check that other sockets pick up the PMTU we have been told about by
+        # connecting another socket to the same destination and getting its MTU.
+        # This new socket can use any method to select its outgoing interface;
+        # here we use a mark for simplicity.
+        s2 = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
+        s2.connect((dstaddr, 1234))
+        self.assertEquals(1280, self.GetSocketMTU(version, s2))
+
+  def testIPv4BasicPMTU(self):
+    self.CheckPMTU(4, True, ["mark", "oif"])
+    self.CheckPMTU(4, False, ["mark", "oif"])
+
+  def testIPv6BasicPMTU(self):
+    self.CheckPMTU(6, True, ["mark", "oif"])
+    self.CheckPMTU(6, False, ["mark", "oif"])
+
+  @unittest.skipUnless(HAVE_EXPERIMENTAL_UID_ROUTING, "no UID routing")
+  def testIPv4UIDPMTU(self):
+    self.CheckPMTU(4, True, ["uid"])
+    self.CheckPMTU(4, False, ["uid"])
+
+  @unittest.skipUnless(HAVE_EXPERIMENTAL_UID_ROUTING, "no UID routing")
+  def testIPv6ConnectedSocketUIDPMTU(self):
+    self.CheckPMTU(6, True, ["uid"])
+    self.CheckPMTU(6, False, ["uid"])
+
+  # Making Path MTU Discovery work on unmarked  sockets requires that mark
+  # reflection be enabled. Otherwise the kernel has no way to know what routing
+  # table the original packet used, and thus it won't be able to clone the
+  # correct route.
+
+  @unittest.skipUnless(HAVE_MARK_REFLECT, "no mark reflection")
+  def testIPv4UnmarkedSocketPMTU(self):
+    self.SetMarkReflectSysctls(1)
+    try:
+      self.CheckPMTU(4, False, [None])
+    finally:
+      self.SetMarkReflectSysctls(0)
+
+  @unittest.skipUnless(HAVE_MARK_REFLECT, "no mark reflection")
+  def testIPv6UnmarkedSocketPMTU(self):
+    self.SetMarkReflectSysctls(1)
+    try:
+      self.CheckPMTU(6, False, [None])
+    finally:
+      self.SetMarkReflectSysctls(0)
 
 
 @unittest.skipUnless(HAVE_EXPERIMENTAL_UID_ROUTING, "no UID routing")