Test that ICMP errors work on VTI interfaces.

Bug: 70371070
Test: xfrm_tunnel_test.py passes on android-4.9
Change-Id: I3a313a0c158efd084cb5601c0ae3c0999fc55467
diff --git a/net/test/xfrm_tunnel_test.py b/net/test/xfrm_tunnel_test.py
index c687163..e1fabb2 100755
--- a/net/test/xfrm_tunnel_test.py
+++ b/net/test/xfrm_tunnel_test.py
@@ -26,6 +26,7 @@
 import iproute
 import multinetwork_base
 import net_test
+import packets
 import xfrm
 import xfrm_base
 
@@ -101,6 +102,22 @@
 @unittest.skipUnless(net_test.LINUX_VERSION >= (3, 18, 0), "VTI Unsupported")
 class XfrmVtiTest(xfrm_base.XfrmLazyTest):
 
+  @classmethod
+  def setUpClass(cls):
+    xfrm_base.XfrmBaseTest.setUpClass()
+    # VTI interfaces use marks extensively, so configure realistic packet
+    # marking rules to make the test representative, make PMTUD work, etc.
+    cls.SetInboundMarks(True)
+    cls._SetInboundMarking(_VTI_NETID, _VTI_IFNAME, True)
+    cls.SetMarkReflectSysctls(1)
+
+  @classmethod
+  def tearDownClass(cls):
+    # The sysctls are restored by MultinetworkBaseTest.tearDownClass.
+    cls._SetInboundMarking(_VTI_NETID, _VTI_IFNAME, False)
+    cls.SetInboundMarks(False)
+    xfrm_base.XfrmBaseTest.tearDownClass()
+
   def setUp(self):
     super(XfrmVtiTest, self).setUp()
     # If the hard-coded netids are redefined this will catch the error.
@@ -285,7 +302,35 @@
       # Unwind the switcheroo
       self._SwapInterfaceAddress(iface, new_addr=local, old_addr=remote)
 
-    return rx + 1, tx + 1
+    # Now attempt to provoke an ICMP error.
+    # TODO: deduplicate with multinetwork_test.py.
+    version = outer_version
+    dst_prefix, intermediate = {
+        4: ("172.19.", "172.16.9.12"),
+        6: ("2001:db8::", "2001:db8::1")
+    }[version]
+
+    write_sock.sendto(net_test.UDP_PAYLOAD,
+                      (_GetRemoteInnerAddress(inner_version), port))
+    pkt = self._ExpectEspPacketOn(netid, _TEST_OUT_SPI, tx + 2, None,
+                                  local_outer, remote_outer)
+    myaddr = self.MyAddress(version, netid)
+    _, toobig = packets.ICMPPacketTooBig(version, intermediate, myaddr, pkt)
+    self.ReceivePacketOn(netid, toobig)
+
+    # Check that the packet too big reduced the MTU.
+    routes = self.iproute.GetRoutes(remote_outer, 0, netid, None)
+    self.assertEquals(1, len(routes))
+    rtmsg, attributes = routes[0]
+    self.assertEquals(iproute.RTN_UNICAST, rtmsg.type)
+    self.assertEquals(packets.PTB_MTU, attributes["RTA_METRICS"]["RTAX_MTU"])
+
+    # Clear PMTU information so that future tests don't have to worry about it.
+    self.InvalidateDstCache(version, netid)
+
+    self.assertEquals((rx + 1, tx + 2),
+                      self.iproute.GetRxTxPackets(iface))
+    return rx + 1, tx + 2
 
   def _TestVti(self, outer_version):
     """Test packet input and output over a Virtual Tunnel Interface."""