Merge "Test that an SA Can be Updated with a Mark"
diff --git a/net/test/all_tests.sh b/net/test/all_tests.sh
index a5476f9..63576b0 100755
--- a/net/test/all_tests.sh
+++ b/net/test/all_tests.sh
@@ -16,6 +16,27 @@
 
 readonly PREFIX="#####"
 readonly RETRIES=2
+test_prefix=
+
+function checkArgOrExit() {  # $1 = option name, $2 = its value (required)
+  if [[ $# -lt 2 ]]; then
+    echo "Missing argument for option $1" >&2
+    exit 1
+  fi
+}
+
+function usageAndExit() {  # print the usage text to stderr, then exit
+  cat >&2 << EOF
+  all_tests.sh - test runner with support for flake testing
+
+  all_tests.sh [options]
+
+  options:
+  -h, --help                     show this menu
+  -p, --prefix=TEST_PREFIX       specify a prefix for the tests to be run
+EOF
+  exit 0  # always status 0, regardless of why usage was requested
+}
 
 function maybePlural() {
   # $1 = integer to use for plural check
@@ -28,7 +49,6 @@
   fi
 }
 
-
 function runTest() {
   local cmd="$1"
 
@@ -46,10 +66,48 @@
   echo >&2 "Warning: '$cmd' FLAKY, passed $RETRIES/$((RETRIES + 1))"
 }
 
+# Parse arguments
+while [ -n "$1" ]; do
+  case "$1" in
+    -h|--help)
+      usageAndExit
+      ;;
+    -p|--prefix)
+      checkArgOrExit "$@"
+      test_prefix=$2
+      shift 2
+      ;;
+    --prefix=*)
+      # The usage text advertises --prefix=TEST_PREFIX, so accept that form too.
+      test_prefix=${1#--prefix=}
+      shift
+      ;;
+    *)
+      echo "Unknown option $1" >&2
+      echo >&2
+      # usageAndExit always exits 0; run it in a subshell so that an
+      # unrecognized option still fails the script.
+      (usageAndExit)
+      exit 1
+      ;;
+  esac
+done
 
-readonly tests=$(find . -name '*_test.py' -type f -executable)
+# Find the relevant tests. The prefix is handed to find as one quoted word
+# (no eval), so it is never re-parsed by the shell.
+if [[ -z $test_prefix ]]; then
+  readonly tests=$(find . -name '*_test.py' -type f -executable)
+else
+  readonly tests=$(find . -name "$test_prefix*" -type f -executable)
+fi
+
+# Give some readable status messages
 readonly count=$(echo $tests | wc -w)
-echo "$PREFIX Found $count $(maybePlural $count test tests)."
+if [[ -z $test_prefix ]]; then
+  echo "$PREFIX Found $count $(maybePlural $count test tests)."
+else
+  echo "$PREFIX Found $count $(maybePlural $count test tests) with prefix $test_prefix."
+fi
 
 exit_code=0
 
diff --git a/net/test/iproute.py b/net/test/iproute.py
index de7d4bb..1ec7365 100644
--- a/net/test/iproute.py
+++ b/net/test/iproute.py
@@ -302,15 +302,14 @@
                 "IFLA_PROMISCUITY", "IFLA_NUM_RX_QUEUES",
                 "IFLA_NUM_TX_QUEUES", "NDA_PROBES", "RTAX_MTU",
                 "RTAX_HOPLIMIT", "IFLA_CARRIER_CHANGES", "IFLA_GSO_MAX_SEGS",
-                "IFLA_GSO_MAX_SIZE"]:
+                "IFLA_GSO_MAX_SIZE", "RTA_UID"]:
       data = struct.unpack("=I", nla_data)[0]
     elif name == "FRA_SUPPRESS_PREFIXLEN":
       data = struct.unpack("=i", nla_data)[0]
     elif name in ["IFLA_LINKMODE", "IFLA_OPERSTATE", "IFLA_CARRIER"]:
       data = ord(nla_data)
     elif name in ["IFA_ADDRESS", "IFA_LOCAL", "RTA_DST", "RTA_SRC",
-                  "RTA_GATEWAY", "RTA_PREFSRC", "RTA_UID",
-                  "NDA_DST"]:
+                  "RTA_GATEWAY", "RTA_PREFSRC", "NDA_DST"]:
       data = socket.inet_ntop(msg.family, nla_data)
     elif name in ["FRA_IIFNAME", "FRA_OIFNAME", "IFLA_IFNAME", "IFLA_QDISC",
                   "IFA_LABEL", "IFLA_INFO_KIND"]:
diff --git a/net/test/multinetwork_base.py b/net/test/multinetwork_base.py
index eeadc77..362285e 100644
--- a/net/test/multinetwork_base.py
+++ b/net/test/multinetwork_base.py
@@ -302,6 +302,12 @@
                                  cls.OnlinkPrefixLen(4), ifindex)
 
   @classmethod
+  def SetMarkReflectSysctls(cls, value):
+    """Makes kernel-generated replies use the mark of the original packet."""
+    cls.SetSysctl(IPV4_MARK_REFLECT_SYSCTL, value)
+    cls.SetSysctl(IPV6_MARK_REFLECT_SYSCTL, value)
+
+  @classmethod
   def _SetInboundMarking(cls, netid, iface, is_add):
     for version in [4, 6]:
       # Run iptables to set up incoming packet marking.
@@ -313,6 +319,11 @@
         raise ConfigurationError("Setup command failed: %s" % args)
 
   @classmethod
+  def SetInboundMarks(cls, is_add):
+    for netid in cls.tuns:
+      cls._SetInboundMarking(netid, cls.GetInterfaceName(netid), is_add)
+
+  @classmethod
   def SetDefaultNetwork(cls, netid):
     table = cls._TableForNetid(netid) if netid else None
     for version in [4, 6]:
@@ -725,21 +736,9 @@
   @classmethod
   def setUpClass(cls):
     super(InboundMarkingTest, cls).setUpClass()
-    for netid in cls.tuns:
-      cls._SetInboundMarking(netid, cls.GetInterfaceName(netid), True)
+    cls.SetInboundMarks(True)
 
   @classmethod
   def tearDownClass(cls):
-    for netid in cls.tuns:
-      cls._SetInboundMarking(netid, cls.GetInterfaceName(netid), False)
+    cls.SetInboundMarks(False)
     super(InboundMarkingTest, cls).tearDownClass()
-
-  @classmethod
-  def SetMarkReflectSysctls(cls, value):
-    cls.SetSysctl(IPV4_MARK_REFLECT_SYSCTL, value)
-    try:
-      cls.SetSysctl(IPV6_MARK_REFLECT_SYSCTL, value)
-    except IOError:
-      # This does not exist if we use the version of the patch that uses a
-      # common sysctl for IPv4 and IPv6.
-      pass
diff --git a/net/test/packets.py b/net/test/packets.py
index b2cf3e4..3a40cbe 100644
--- a/net/test/packets.py
+++ b/net/test/packets.py
@@ -156,8 +156,6 @@
             scapy.IP(src=srcaddr, dst=dstaddr, proto=1) /
             scapy.ICMPerror(type=3, code=4, unused=1280) / str(packet)[:64])
   else:
-    udp = packet.getlayer("UDP")
-    udp.payload = str(udp.payload)[:1280-40-8]
     return ("ICMPv6 Packet Too Big",
             scapy.IPv6(src=srcaddr, dst=dstaddr) /
             scapy.ICMPv6PacketTooBig() / str(packet)[:1232])
diff --git a/net/test/ping6_test.py b/net/test/ping6_test.py
index 7a260e7..7cd686d 100755
--- a/net/test/ping6_test.py
+++ b/net/test/ping6_test.py
@@ -314,7 +314,13 @@
                 "%08X:%08X" % (txmem, rxmem),
                 str(os.getuid()), "2", "0"]
     actual = self.ReadProcNetSocket(name)[-1]
+    # Check all the parameters except rxmem and txmem.
+    expected[3] = actual[3]
     self.assertListEqual(expected, actual)
+    # Check that txmem and rxmem are each within 25% of their expected values.
+    actual_txmem, actual_rxmem = expected[3].split(":")
+    self.assertAlmostEqual(txmem, int(actual_txmem, 16), delta=txmem / 4)
+    self.assertAlmostEqual(rxmem, int(actual_rxmem, 16), delta=rxmem / 4)
 
   def testIPv4SendWithNoConnection(self):
     s = net_test.IPv4PingSocket()
diff --git a/net/test/xfrm_base.py b/net/test/xfrm_base.py
index d41fae5..b7c7cb7 100644
--- a/net/test/xfrm_base.py
+++ b/net/test/xfrm_base.py
@@ -374,3 +374,49 @@
     esp_hdr, _ = cstruct.Read(str(packet.payload), xfrm.EspHdr)
     self.assertEquals(xfrm.EspHdr((spi, seq)), esp_hdr)
     return packet
+
+  def CreateTunnel(self, direction, selector, src, dst, spi, encryption,
+                   auth_trunc, mark, output_mark):
+    """Create an XFRM Tunnel Consisting of a Policy and an SA.
+
+    Create a unidirectional XFRM tunnel, which entails one Policy and one
+    security association.
+
+    Args:
+      direction: XFRM_POLICY_IN or XFRM_POLICY_OUT
+      selector: An XfrmSelector that specifies the packets to be transformed.
+        This is only applied to the policy; the selector in the SA is always
+        empty. If the passed-in selector is None, then the tunnel is made
+        dual-stack. This requires two policies, one for IPv4 and one for IPv6.
+      src: The source address of the tunneled packets
+      dst: The destination address of the tunneled packets
+      spi: The SPI for the IPsec SA that encapsulates the tunneled packet
+      encryption: A tuple (XfrmAlgo, key), the encryption parameters.
+      auth_trunc: A tuple (XfrmAlgoAuth, key), the authentication parameters.
+      mark: An XfrmMark, the mark used for selecting packets to be tunneled, and
+        for matching the security policy and security association. None means
+        unspecified.
+      output_mark: The mark used to select the underlying network for packets
+        outbound from xfrm. None means unspecified.
+    """
+    outer_family = net_test.GetAddressFamily(net_test.GetAddressVersion(dst))
+
+    self.xfrm.AddSaInfo(
+        src, dst,
+        spi, xfrm.XFRM_MODE_TUNNEL, 0,
+        encryption,
+        auth_trunc,
+        None,
+        None,
+        mark,
+        output_mark)
+
+    if selector is None:  # dual-stack: one wildcard policy per address family
+      selectors = [xfrm.EmptySelector(AF_INET), xfrm.EmptySelector(AF_INET6)]
+    else:
+      selectors = [selector]
+
+    for selector in selectors:  # install one policy per selector
+      policy = UserPolicy(direction, selector)
+      tmpl = UserTemplate(outer_family, spi, 0, (src, dst))
+      self.xfrm.AddPolicyInfo(policy, tmpl, mark)
diff --git a/net/test/xfrm_tunnel_test.py b/net/test/xfrm_tunnel_test.py
index fc7a709..a6a0eec 100755
--- a/net/test/xfrm_tunnel_test.py
+++ b/net/test/xfrm_tunnel_test.py
@@ -40,106 +40,21 @@
 _TEST_IKEY = _TEST_IN_SPI + _VTI_NETID
 
 
-@unittest.skipUnless(net_test.LINUX_VERSION >= (3, 18, 0), "VTI Unsupported")
+def _GetLocalInnerAddress(version):
+  return {4: "10.16.5.15", 6: "2001:db8:1::1"}[version]
+
+
+def _GetRemoteInnerAddress(version):
+  return {4: "10.16.5.20", 6: "2001:db8:2::1"}[version]
+
+
+def _GetRemoteOuterAddress(version):
+  return {4: net_test.IPV4_ADDR, 6: net_test.IPV6_ADDR}[version]
+
+
+
 class XfrmTunnelTest(xfrm_base.XfrmBaseTest):
 
-  def setUp(self):
-    super(XfrmTunnelTest, self).setUp()
-    # If the hard-coded netids are redefined this will catch the error.
-    self.assertNotIn(_VTI_NETID, self.NETIDS,
-                     "VTI netid %d already in use" % _VTI_NETID)
-    self.iproute = iproute.IPRoute()
-    self._QuietDeleteLink(_VTI_IFNAME)
-
-  def tearDown(self):
-    super(XfrmTunnelTest, self).tearDown()
-    self._QuietDeleteLink(_VTI_IFNAME)
-
-  @staticmethod
-  def _GetLocalInnerAddress(version):
-    return {4: "10.16.5.15", 6: "2001:db8:1::1"}[version]
-
-  @staticmethod
-  def _GetRemoteInnerAddress(version):
-    return {4: "10.16.5.20", 6: "2001:db8:2::1"}[version]
-
-  def _GetRemoteOuterAddress(self, version):
-    return self.GetRemoteAddress(version)
-
-  def _QuietDeleteLink(self, ifname):
-    try:
-      self.iproute.DeleteLink(ifname)
-    except IOError:
-      # The link was not present.
-      pass
-
-  def _SwapInterfaceAddress(self, ifname, old_addr, new_addr):
-    """Exchange two addresses on a given interface.
-
-    Args:
-      ifname: Name of the interface
-      old_addr: An address to be removed from the interface
-      new_addr: An address to be added to an interface
-    """
-    version = 6 if ":" in new_addr else 4
-    ifindex = net_test.GetInterfaceIndex(ifname)
-    self.iproute.AddAddress(new_addr,
-                            net_test.AddressLengthBits(version), ifindex)
-    self.iproute.DelAddress(old_addr,
-                            net_test.AddressLengthBits(version), ifindex)
-
-  # TODO: Take encryption and auth parameters.
-  def _CreateXfrmTunnel(self,
-                        direction,
-                        selector,
-                        tsrc_addr,
-                        tdst_addr,
-                        spi,
-                        mark=None,
-                        output_mark=None):
-    """Create an XFRM Tunnel Consisting of a Policy and an SA.
-
-    Create a unidirectional XFRM tunnel, which entails one Policy and one
-    security association.
-
-    Args:
-      direction: XFRM_POLICY_IN or XFRM_POLICY_OUT
-      selector: An XfrmSelector that specifies the packets to be transformed.
-        This is only applied to the policy; the selector in the SA is always
-        empty. If the passed-in selector is None, then the tunnel is made
-        dual-stack. This requires two policies, one for IPv4 and one for IPv6.
-      tsrc_addr: The source address of the tunneled packets
-      tdst_addr: The destination address of the tunneled packets
-      spi: The SPI for the IPsec SA that encapsulates the tunneled packet
-      mark: The mark used for selecting packets to be tunneled, and for
-        matching the security policy and security association.
-      output_mark: The mark used to select the underlying network for packets
-        outbound from xfrm.
-    """
-    outer_family = net_test.GetAddressFamily(
-        net_test.GetAddressVersion(tdst_addr))
-
-    self.xfrm.AddSaInfo(
-        tsrc_addr, tdst_addr,
-        spi, xfrm.XFRM_MODE_TUNNEL, 0,
-        xfrm_base._ALGO_CBC_AES_256,
-        xfrm_base._ALGO_HMAC_SHA1,
-        None,
-        None,
-        mark,
-        output_mark)
-
-    if selector is None:
-      selectors = [xfrm.EmptySelector(AF_INET), xfrm.EmptySelector(AF_INET6)]
-    else:
-      selectors = [selector]
-
-    for selector in selectors:
-      policy = xfrm_base.UserPolicy(direction, selector)
-      tmpl = xfrm_base.UserTemplate(outer_family, spi, 0,
-                                    (tsrc_addr, tdst_addr))
-      self.xfrm.AddPolicyInfo(policy, tmpl, mark)
-
   def _CheckTunnelOutput(self, inner_version, outer_version):
     """Test a bi-directional XFRM Tunnel with explicit selectors"""
     # Select the underlying netid, which represents the external
@@ -150,17 +65,15 @@
     netid = self.RandomNetid(exclude=underlying_netid)
 
     local_inner = self.MyAddress(inner_version, netid)
-    remote_inner = self._GetRemoteInnerAddress(inner_version)
+    remote_inner = _GetRemoteInnerAddress(inner_version)
     local_outer = self.MyAddress(outer_version, underlying_netid)
-    remote_outer = self._GetRemoteOuterAddress(outer_version)
+    remote_outer = _GetRemoteOuterAddress(outer_version)
 
-    self._CreateXfrmTunnel(
-        direction=xfrm.XFRM_POLICY_OUT,
-        selector=xfrm.SrcDstSelector(local_inner, remote_inner),
-        tsrc_addr=local_outer,
-        tdst_addr=remote_outer,
-        spi=_TEST_OUT_SPI,
-        output_mark=underlying_netid)
+    self.CreateTunnel(xfrm.XFRM_POLICY_OUT,
+                      xfrm.SrcDstSelector(local_inner, remote_inner),
+                      local_outer, remote_outer, _TEST_OUT_SPI,
+                      xfrm_base._ALGO_CBC_AES_256, xfrm_base._ALGO_HMAC_SHA1,
+                      None, underlying_netid)
 
     write_sock = socket(net_test.GetAddressFamily(inner_version), SOCK_DGRAM, 0)
     # Select an interface, which provides the source address of the inner
@@ -184,6 +97,44 @@
   def testIpv6InIpv6TunnelOutput(self):
     self._CheckTunnelOutput(6, 6)
 
+
+@unittest.skipUnless(net_test.LINUX_VERSION >= (3, 18, 0), "VTI Unsupported")
+class XfrmVtiTest(xfrm_base.XfrmBaseTest):
+
+  def setUp(self):
+    super(XfrmVtiTest, self).setUp()
+    # If the hard-coded netids are redefined this will catch the error.
+    self.assertNotIn(_VTI_NETID, self.NETIDS,
+                     "VTI netid %d already in use" % _VTI_NETID)
+    self.iproute = iproute.IPRoute()
+    self._QuietDeleteLink(_VTI_IFNAME)
+
+  def tearDown(self):
+    super(XfrmVtiTest, self).tearDown()
+    self._QuietDeleteLink(_VTI_IFNAME)
+
+  def _QuietDeleteLink(self, ifname):
+    try:
+      self.iproute.DeleteLink(ifname)
+    except IOError:
+      # The link was not present.
+      pass
+
+  def _SwapInterfaceAddress(self, ifname, old_addr, new_addr):
+    """Exchange two addresses on a given interface.
+
+    Args:
+      ifname: Name of the interface
+      old_addr: An address to be removed from the interface
+      new_addr: An address to be added to an interface
+    """
+    version = 6 if ":" in new_addr else 4
+    ifindex = net_test.GetInterfaceIndex(ifname)
+    self.iproute.AddAddress(new_addr,
+                            net_test.AddressLengthBits(version), ifindex)
+    self.iproute.DelAddress(old_addr,
+                            net_test.AddressLengthBits(version), ifindex)
+
   def testAddVti(self):
     """Test the creation of a Virtual Tunnel Interface."""
     for version in [4, 6]:
@@ -192,7 +143,7 @@
       self.iproute.CreateVirtualTunnelInterface(
           dev_name=_VTI_IFNAME,
           local_addr=local_addr,
-          remote_addr=self._GetRemoteOuterAddress(version),
+          remote_addr=_GetRemoteOuterAddress(version),
           o_key=_TEST_OKEY,
           i_key=_TEST_IKEY)
       if_index = self.iproute.GetIfIndex(_VTI_IFNAME)
@@ -203,7 +154,7 @@
       with self.assertRaises(IOError):
         self.iproute.GetIfIndex(_VTI_IFNAME)
 
-  def _SetupVtiNetwork(self, ifname, is_add):
+  def _SetupVtiNetwork(self, netid, ifname, is_add):
     """Setup rules and routes for a VTI Network.
 
     Takes an interface and depending on the boolean
@@ -219,7 +170,7 @@
     if is_add:
       # Bring up the interface so that we can start adding addresses
       # and routes.
-      net_test.SetInterfaceUp(_VTI_IFNAME)
+      net_test.SetInterfaceUp(ifname)
 
       # Disable router solicitations to avoid occasional spurious packets
       # arriving on the underlying network; there are two possible behaviors
@@ -228,133 +179,131 @@
       # the UDP_PAYLOAD; or, two packets may arrive on the underlying
       # network which fails the assertion that only one ESP packet is received.
       self.SetSysctl(
-          "/proc/sys/net/ipv6/conf/%s/router_solicitations" % _VTI_IFNAME, 0)
+          "/proc/sys/net/ipv6/conf/%s/router_solicitations" % ifname, 0)
     for version in [4, 6]:
       ifindex = net_test.GetInterfaceIndex(ifname)
-      table = _VTI_NETID
+      table = netid
 
       # Set up routing rules.
-      start, end = self.UidRangeForNetid(_VTI_NETID)
+      start, end = self.UidRangeForNetid(netid)
       self.iproute.UidRangeRule(version, is_add, start, end, table,
                                 self.PRIORITY_UID)
       self.iproute.OifRule(version, is_add, ifname, table, self.PRIORITY_OIF)
-      self.iproute.FwmarkRule(version, is_add, _VTI_NETID, table,
+      self.iproute.FwmarkRule(version, is_add, netid, table,
                               self.PRIORITY_FWMARK)
       if is_add:
         self.iproute.AddAddress(
-            self._GetLocalInnerAddress(version),
+            _GetLocalInnerAddress(version),
             net_test.AddressLengthBits(version), ifindex)
         self.iproute.AddRoute(version, table, "default", 0, None, ifindex)
       else:
         self.iproute.DelRoute(version, table, "default", 0, None, ifindex)
         self.iproute.DelAddress(
-            self._GetLocalInnerAddress(version),
+            _GetLocalInnerAddress(version),
             net_test.AddressLengthBits(version), ifindex)
     if not is_add:
-      net_test.SetInterfaceDown(_VTI_IFNAME)
+      net_test.SetInterfaceDown(ifname)
 
   # TODO: Should we completely re-write this using null encryption and null
   # authentication? We could then assemble and disassemble packets for each
   # direction individually. This approach would improve debuggability, avoid the
   # complexity of the twister, and allow the test to more-closely validate
   # deployable configurations.
-  def _CheckVtiOutput(self, inner_version, outer_version):
-    """Test packet input and output over a Virtual Tunnel Interface."""
-    netid = self.RandomNetid()
+  def _CreateVti(self, netid, vti_netid, ifname, outer_version):
     local_outer = self.MyAddress(outer_version, netid)
-    remote_outer = self._GetRemoteOuterAddress(outer_version)
+    remote_outer = _GetRemoteOuterAddress(outer_version)
     self.iproute.CreateVirtualTunnelInterface(
-        dev_name=_VTI_IFNAME,
+        dev_name=ifname,
         local_addr=local_outer,
         remote_addr=remote_outer,
         i_key=_TEST_IKEY,
         o_key=_TEST_OKEY)
-    self._SetupVtiNetwork(_VTI_IFNAME, True)
+
+    self._SetupVtiNetwork(vti_netid, ifname, True)
+
+    # For the VTI, the selectors are wildcard since packets will only
+    # be selected if they have the appropriate mark, hence the inner
+    # addresses are wildcard.
+    self.CreateTunnel(xfrm.XFRM_POLICY_OUT, None, local_outer, remote_outer,
+                      _TEST_OUT_SPI, xfrm_base._ALGO_CBC_AES_256,
+                      xfrm_base._ALGO_HMAC_SHA1,
+                      xfrm.ExactMatchMark(_TEST_OKEY), netid)
+
+    self.CreateTunnel(xfrm.XFRM_POLICY_IN, None, remote_outer, local_outer,
+                      _TEST_IN_SPI, xfrm_base._ALGO_CBC_AES_256,
+                      xfrm_base._ALGO_HMAC_SHA1,
+                      xfrm.ExactMatchMark(_TEST_IKEY), None)
+
+  def _CheckVtiInputOutput(self, netid, vti_netid, iface, outer_version,
+                           inner_version, rx, tx):
+    local_outer = self.MyAddress(outer_version, netid)
+    remote_outer = _GetRemoteOuterAddress(outer_version)
+
+    # Create a socket to receive packets.
+    read_sock = socket(
+        net_test.GetAddressFamily(inner_version), SOCK_DGRAM, 0)
+    read_sock.bind((net_test.GetWildcardAddress(inner_version), 0))
+    # The second parameter of the tuple is the port number regardless of AF.
+    port = read_sock.getsockname()[1]
+    # Guard against the eventuality of the receive failing.
+    csocket.SetSocketTimeout(read_sock, 100)
+
+    # Send a packet out via the vti-backed network, bound for the port number
+    # of the input socket.
+    write_sock = socket(
+        net_test.GetAddressFamily(inner_version), SOCK_DGRAM, 0)
+    self.SelectInterface(write_sock, vti_netid, "mark")
+    write_sock.sendto(net_test.UDP_PAYLOAD,
+                      (_GetRemoteInnerAddress(inner_version), port))
+
+    # Read a tunneled IP packet on the underlying (outbound) network
+    # verifying that it is an ESP packet.
+    pkt = self._ExpectEspPacketOn(netid, _TEST_OUT_SPI, tx + 1, None,
+                                  local_outer, remote_outer)
+
+    self.assertEquals((rx, tx + 1), self.iproute.GetRxTxPackets(iface))
+
+    # Perform an address switcheroo so that the inner address of the remote
+    # end of the tunnel is now the address on the local VTI interface; this
+    # way, the twisted inner packet finds a destination via the VTI once
+    # decrypted.
+    remote = _GetRemoteInnerAddress(inner_version)
+    local = _GetLocalInnerAddress(inner_version)
+    self._SwapInterfaceAddress(iface, new_addr=remote, old_addr=local)
+    try:
+      # Swap the packet's IP headers and write it back to the
+      # underlying network.
+      pkt = TunTwister.TwistPacket(pkt)
+      self.ReceivePacketOn(netid, pkt)
+      # Receive the decrypted packet on the dest port number.
+      read_packet = read_sock.recv(4096)
+      self.assertEquals(read_packet, net_test.UDP_PAYLOAD)
+      self.assertEquals((rx + 1, tx + 1), self.iproute.GetRxTxPackets(iface))
+    finally:
+      # Unwind the switcheroo
+      self._SwapInterfaceAddress(iface, new_addr=local, old_addr=remote)
+
+    return rx + 1, tx + 1
+
+  def _TestVti(self, outer_version):
+    """Test packet input and output over a Virtual Tunnel Interface."""
+    netid = self.RandomNetid()
+
+    self._CreateVti(netid, _VTI_NETID, _VTI_IFNAME, outer_version)
 
     try:
-      # For the VTI, the selectors are wildcard since packets will only
-      # be selected if they have the appropriate mark, hence the inner
-      # addresses are wildcard.
-      self._CreateXfrmTunnel(
-          direction=xfrm.XFRM_POLICY_OUT,
-          selector=None,
-          tsrc_addr=local_outer,
-          tdst_addr=remote_outer,
-          mark=xfrm.ExactMatchMark(_TEST_OKEY),
-          spi=_TEST_OUT_SPI,
-          output_mark=netid)
-
-      self._CreateXfrmTunnel(
-          direction=xfrm.XFRM_POLICY_IN,
-          selector=None,
-          tsrc_addr=remote_outer,
-          tdst_addr=local_outer,
-          mark=xfrm.ExactMatchMark(_TEST_IKEY),
-          spi=_TEST_IN_SPI,
-          output_mark=netid)
-
-      # Create a socket to receive packets.
-      read_sock = socket(
-          net_test.GetAddressFamily(inner_version), SOCK_DGRAM, 0)
-      read_sock.bind((net_test.GetWildcardAddress(inner_version), 0))
-      # The second parameter of the tuple is the port number regardless of AF.
-      port = read_sock.getsockname()[1]
-      # Guard against the eventuality of the receive failing.
-      csocket.SetSocketTimeout(read_sock, 100)
-
-      # Start counting packets.
-      rx, tx = self.iproute.GetRxTxPackets(_VTI_IFNAME)
-
-      # Send a packet out via the vti-backed network, bound for the port number
-      # of the input socket.
-      write_sock = socket(
-          net_test.GetAddressFamily(inner_version), SOCK_DGRAM, 0)
-      self.SelectInterface(write_sock, _VTI_NETID, "mark")
-      write_sock.sendto(net_test.UDP_PAYLOAD,
-                        (self._GetRemoteInnerAddress(inner_version), port))
-
-      # Read a tunneled IP packet on the underlying (outbound) network
-      # verifying that it is an ESP packet.
-      pkt = self._ExpectEspPacketOn(netid, _TEST_OUT_SPI, 1, None, local_outer,
-                                    remote_outer)
-
-      self.assertEquals((rx, tx + 1), self.iproute.GetRxTxPackets(_VTI_IFNAME))
-
-      # Perform an address switcheroo so that the inner address of the remote
-      # end of the tunnel is now the address on the local VTI interface; this
-      # way, the twisted inner packet finds a destination via the VTI once
-      # decrypted.
-      remote = self._GetRemoteInnerAddress(inner_version)
-      local = self._GetLocalInnerAddress(inner_version)
-      self._SwapInterfaceAddress(_VTI_IFNAME, new_addr=remote, old_addr=local)
-      try:
-        # Swap the packet's IP headers and write it back to the
-        # underlying network.
-        pkt = TunTwister.TwistPacket(pkt)
-        self.ReceivePacketOn(netid, pkt)
-        # Receive the decrypted packet on the dest port number.
-        read_packet = read_sock.recv(4096)
-        self.assertEquals(read_packet, net_test.UDP_PAYLOAD)
-        self.assertEquals((rx + 1, tx + 1),
-                          self.iproute.GetRxTxPackets(_VTI_IFNAME))
-      finally:
-        # Unwind the switcheroo
-        self._SwapInterfaceAddress(_VTI_IFNAME, new_addr=local, old_addr=remote)
-
+      rx, tx = self._CheckVtiInputOutput(netid, _VTI_NETID, _VTI_IFNAME,
+                                         outer_version, 4, 0, 0)
+      self._CheckVtiInputOutput(netid, _VTI_NETID, _VTI_IFNAME, outer_version,
+                                4, rx, tx)
     finally:
-      self._SetupVtiNetwork(_VTI_IFNAME, False)
+      self._SetupVtiNetwork(_VTI_NETID, _VTI_IFNAME, False)
 
-  def testIpv4InIpv4VtiOutput(self):
-    self._CheckVtiOutput(4, 4)
+  def testIpv4Vti(self):
+    self._TestVti(4)
 
-  def testIpv4InIpv6VtiOutput(self):
-    self._CheckVtiOutput(4, 6)
-
-  def testIpv6InIpv4VtiOutput(self):
-    self._CheckVtiOutput(6, 4)
-
-  def testIpv6InIpv6VtiOutput(self):
-    self._CheckVtiOutput(6, 6)
+  def testIpv6Vti(self):
+    self._TestVti(6)
 
 
 if __name__ == "__main__":