Test multiple VTI tunnels at the same time.

XfrmVtiTest now creates one IPv4 and one IPv6 tunnel per netid.
Similar to MultinetworkBaseTest, the tunnels are created in
setUpClass and torn down in tearDownClass, so all the tests are
always run with all the VTIs present.

This provides the basic infrastructure to perform more in-depth
testing of the VTI code in a realistic environment.

Bug: 70371070
Test: xfrm tests pass on android-4.9
Change-Id: I295c3af6626b0ced98975bbb67c2592ab9a8f455
diff --git a/net/test/xfrm.py b/net/test/xfrm.py
index 93d15ac..1bd10da 100755
--- a/net/test/xfrm.py
+++ b/net/test/xfrm.py
@@ -630,6 +630,15 @@
       tmpl = UserTemplate(outer_family, spi, 0, (src, dst))
       self.AddPolicyInfo(policy, tmpl, mark)
 
+  def DeleteTunnel(self, direction, selector, dst, spi, mark):
+    self.DeleteSaInfo(dst, spi, IPPROTO_ESP, ExactMatchMark(mark))
+    if selector is None:
+      selectors = [EmptySelector(AF_INET), EmptySelector(AF_INET6)]
+    else:
+      selectors = [selector]
+    for selector in selectors:
+      self.DeletePolicyInfo(selector, direction, ExactMatchMark(mark))
+
 
 if __name__ == "__main__":
   x = Xfrm()
diff --git a/net/test/xfrm_tunnel_test.py b/net/test/xfrm_tunnel_test.py
index c66f74c..94e846e 100755
--- a/net/test/xfrm_tunnel_test.py
+++ b/net/test/xfrm_tunnel_test.py
@@ -32,6 +32,10 @@
 import xfrm_base
 
 # Parameters to Set up VTI as a special network
+_BASE_VTI_NETID = {4: 40, 6: 60}
+_BASE_VTI_OKEY = 2000000100
+_BASE_VTI_IKEY = 2000000200
+
 _VTI_NETID = 50
 _VTI_IFNAME = "test_vti"
 
@@ -100,6 +104,40 @@
     self._CheckTunnelOutput(6, 6)
 
 
+@unittest.skipUnless(net_test.LINUX_VERSION >= (3, 18, 0), "VTI Unsupported")
+class XfrmAddDeleteVtiTest(xfrm_base.XfrmBaseTest):
+
+  def testAddVti(self):
+    """Test the creation of a Virtual Tunnel Interface."""
+    for version in [4, 6]:
+      netid = self.RandomNetid()
+      local_addr = self.MyAddress(version, netid)
+      self.iproute.CreateVirtualTunnelInterface(
+          dev_name=_VTI_IFNAME,
+          local_addr=local_addr,
+          remote_addr=_GetRemoteOuterAddress(version),
+          o_key=_TEST_OKEY,
+          i_key=_TEST_IKEY)
+      if_index = self.iproute.GetIfIndex(_VTI_IFNAME)
+
+      # Validate that the netlink interface matches the ioctl interface.
+      self.assertEquals(net_test.GetInterfaceIndex(_VTI_IFNAME), if_index)
+      self.iproute.DeleteLink(_VTI_IFNAME)
+      with self.assertRaises(IOError):
+        self.iproute.GetIfIndex(_VTI_IFNAME)
+
+  def _QuietDeleteLink(self, ifname):
+    try:
+      self.iproute.DeleteLink(ifname)
+    except IOError:
+      # The link was not present.
+      pass
+
+  def tearDown(self):
+    super(XfrmAddDeleteVtiTest, self).tearDown()
+    self._QuietDeleteLink(_VTI_IFNAME)
+
+
 class VtiInterface(object):
 
   def __init__(self, iface, netid, underlying_netid, local, remote):
@@ -120,6 +158,7 @@
     self.addrs = {}
 
   def Teardown(self):
+    self.TeardownXfrm()
     self.TeardownInterface()
 
   def SetupInterface(self):
@@ -144,9 +183,15 @@
                            xfrm_base._ALGO_HMAC_SHA1,
                            xfrm.ExactMatchMark(self.ikey), None)
 
+  def TeardownXfrm(self):
+    self.xfrm.DeleteTunnel(xfrm.XFRM_POLICY_OUT, None, self.remote,
+                           self.out_spi, self.okey)
+    self.xfrm.DeleteTunnel(xfrm.XFRM_POLICY_IN, None, self.local,
+                           self.in_spi, self.ikey)
+
 
 @unittest.skipUnless(net_test.LINUX_VERSION >= (3, 18, 0), "VTI Unsupported")
-class XfrmVtiTest(xfrm_base.XfrmLazyTest):
+class XfrmVtiTest(xfrm_base.XfrmBaseTest):
 
   @classmethod
   def setUpClass(cls):
@@ -154,34 +199,39 @@
     # VTI interfaces use marks extensively, so configure realistic packet
     # marking rules to make the test representative, make PMTUD work, etc.
     cls.SetInboundMarks(True)
-    cls._SetInboundMarking(_VTI_NETID, _VTI_IFNAME, True)
     cls.SetMarkReflectSysctls(1)
 
+    cls.vtis = {}
+    for i, underlying_netid in enumerate(cls.tuns):
+      for version in 4, 6:
+        netid = _BASE_VTI_NETID[version] + i
+        iface = "ipsec%s" % netid
+        local = cls.MyAddress(version, underlying_netid)
+        if version == 4:
+          remote = net_test.IPV4_ADDR2 if (i % 2) else net_test.IPV4_ADDR
+        else:
+          remote = net_test.IPV6_ADDR2 if (i % 2) else net_test.IPV6_ADDR
+        vti = VtiInterface(iface, netid, underlying_netid, local, remote)
+        cls._SetInboundMarking(netid, iface, True)
+        cls._SetupVtiNetwork(vti, True)
+        cls.vtis[netid] = vti
+
   @classmethod
   def tearDownClass(cls):
     # The sysctls are restored by MultinetworkBaseTest.tearDownClass.
-    cls._SetInboundMarking(_VTI_NETID, _VTI_IFNAME, False)
     cls.SetInboundMarks(False)
+    for vti in cls.vtis.values():
+      cls._SetInboundMarking(vti.netid, vti.iface, False)
+      cls._SetupVtiNetwork(vti, False)
+      vti.Teardown()
     xfrm_base.XfrmBaseTest.tearDownClass()
 
   def setUp(self):
-    super(XfrmVtiTest, self).setUp()
-    # If the hard-coded netids are redefined this will catch the error.
-    self.assertNotIn(_VTI_NETID, self.NETIDS,
-                     "VTI netid %d already in use" % _VTI_NETID)
+    multinetwork_base.MultiNetworkBaseTest.setUp(self)
     self.iproute = iproute.IPRoute()
-    self._QuietDeleteLink(_VTI_IFNAME)
 
   def tearDown(self):
-    super(XfrmVtiTest, self).tearDown()
-    self._QuietDeleteLink(_VTI_IFNAME)
-
-  def _QuietDeleteLink(self, ifname):
-    try:
-      self.iproute.DeleteLink(ifname)
-    except IOError:
-      # The link was not present.
-      pass
+    multinetwork_base.MultiNetworkBaseTest.tearDown(self)
 
   def _SwapInterfaceAddress(self, ifname, old_addr, new_addr):
     """Exchange two addresses on a given interface.
@@ -198,26 +248,8 @@
     self.iproute.DelAddress(old_addr,
                             net_test.AddressLengthBits(version), ifindex)
 
-  def testAddVti(self):
-    """Test the creation of a Virtual Tunnel Interface."""
-    for version in [4, 6]:
-      netid = self.RandomNetid()
-      local_addr = self.MyAddress(version, netid)
-      self.iproute.CreateVirtualTunnelInterface(
-          dev_name=_VTI_IFNAME,
-          local_addr=local_addr,
-          remote_addr=_GetRemoteOuterAddress(version),
-          o_key=_TEST_OKEY,
-          i_key=_TEST_IKEY)
-      if_index = self.iproute.GetIfIndex(_VTI_IFNAME)
-
-      # Validate that the netlink interface matches the ioctl interface.
-      self.assertEquals(net_test.GetInterfaceIndex(_VTI_IFNAME), if_index)
-      self.iproute.DeleteLink(_VTI_IFNAME)
-      with self.assertRaises(IOError):
-        self.iproute.GetIfIndex(_VTI_IFNAME)
-
-  def _SetupVtiNetwork(self, vti, is_add):
+  @classmethod
+  def _SetupVtiNetwork(cls, vti, is_add):
     """Setup rules and routes for a VTI Network.
 
     Takes an interface and depending on the boolean
@@ -237,7 +269,7 @@
       # is echoed back to the VTI, it causes the test to fail by not receiving
       # the UDP_PAYLOAD; or, two packets may arrive on the underlying
       # network which fails the assertion that only one ESP packet is received.
-      self.SetSysctl(
+      cls.SetSysctl(
           "/proc/sys/net/ipv6/conf/%s/router_solicitations" % vti.iface, 0)
       net_test.SetInterfaceUp(vti.iface)
 
@@ -246,26 +278,26 @@
       table = vti.netid
 
       # Set up routing rules.
-      start, end = self.UidRangeForNetid(vti.netid)
-      self.iproute.UidRangeRule(version, is_add, start, end, table,
-                                self.PRIORITY_UID)
-      self.iproute.OifRule(version, is_add, vti.iface, table, self.PRIORITY_OIF)
-      self.iproute.FwmarkRule(version, is_add, vti.netid, self.NETID_FWMASK,
-                              table, self.PRIORITY_FWMARK)
+      start, end = cls.UidRangeForNetid(vti.netid)
+      cls.iproute.UidRangeRule(version, is_add, start, end, table,
+                                cls.PRIORITY_UID)
+      cls.iproute.OifRule(version, is_add, vti.iface, table, cls.PRIORITY_OIF)
+      cls.iproute.FwmarkRule(version, is_add, vti.netid, cls.NETID_FWMASK,
+                              table, cls.PRIORITY_FWMARK)
 
       # Configure IP addresses.
       if version == 4:
-        addr = self._MyIPv4Address(vti.netid)
+        addr = cls._MyIPv4Address(vti.netid)
       else:
-        addr = self.OnlinkPrefix(6, vti.netid) + "1"
+        addr = cls.OnlinkPrefix(6, vti.netid) + "1"
       prefixlen = net_test.AddressLengthBits(version)
       vti.addrs[version] = addr
       if is_add:
-        self.iproute.AddAddress(addr, prefixlen, ifindex)
-        self.iproute.AddRoute(version, table, "default", 0, None, ifindex)
+        cls.iproute.AddAddress(addr, prefixlen, ifindex)
+        cls.iproute.AddRoute(version, table, "default", 0, None, ifindex)
       else:
-        self.iproute.DelRoute(version, table, "default", 0, None, ifindex)
-        self.iproute.DelAddress(addr, prefixlen, ifindex)
+        cls.iproute.DelRoute(version, table, "default", 0, None, ifindex)
+        cls.iproute.DelAddress(addr, prefixlen, ifindex)
 
   def assertReceivedPacket(self, vti):
     vti.rx += 1
@@ -280,13 +312,6 @@
   # direction individually. This approach would improve debuggability, avoid the
   # complexity of the twister, and allow the test to more-closely validate
   # deployable configurations.
-  def _CreateVti(self, netid, vti_netid, ifname, outer_version):
-    local_outer = self.MyAddress(outer_version, netid)
-    remote_outer = _GetRemoteOuterAddress(outer_version)
-    vti = VtiInterface(ifname, vti_netid, netid, local_outer, remote_outer)
-    self._SetupVtiNetwork(vti, True)
-    return vti
-
   def _CheckVtiInputOutput(self, vti, inner_version):
     local_outer = vti.local
     remote_outer = vti.remote
@@ -361,23 +386,12 @@
     # Clear PMTU information so that future tests don't have to worry about it.
     self.InvalidateDstCache(version, vti.underlying_netid)
 
-  def _TestVti(self, outer_version):
+  def testVtiInputOutput(self):
     """Test packet input and output over a Virtual Tunnel Interface."""
-    netid = self.RandomNetid()
-
-    vti = self._CreateVti(netid, _VTI_NETID, _VTI_IFNAME, outer_version)
-
-    try:
+    for i in xrange(3 * len(self.vtis.values())):
+      vti = random.choice(self.vtis.values())
       self._CheckVtiInputOutput(vti, 4)
       self._CheckVtiInputOutput(vti, 6)
-    finally:
-      self._SetupVtiNetwork(vti, False)
-
-  def testIpv4Vti(self):
-    self._TestVti(4)
-
-  def testIpv6Vti(self):
-    self._TestVti(6)
 
 
 if __name__ == "__main__":