Add tests for VTI rekey procedure

This change adds tests for rekeying of VTIs. The general flow is to
create new SAs, update the policy to use the new SAs (outbound should
use new SA exclusively, inbound is permissive), and then delete the old
SAs.

Bug: 66467511
Test: this; Passes on all android-common kernels
Change-Id: I65713926e727fdb99200836da918000a5cddfad8
diff --git a/net/test/xfrm.py b/net/test/xfrm.py
index 56b4774..acdfd4f 100755
--- a/net/test/xfrm.py
+++ b/net/test/xfrm.py
@@ -208,7 +208,7 @@
 NO_LIFETIME_CUR = "\x00" * len(XfrmLifetimeCur)
 
 # IPsec constants.
-IPSEC_PROTO_ANY	= 255
+IPSEC_PROTO_ANY = 255
 
 # ESP header, not technically XFRM but we need a place for a protocol
 # header and this is the only one we have.
@@ -219,6 +219,11 @@
 _DEFAULT_REPLAY_WINDOW = 4
 ALL_ALGORITHMS = 0xffffffff
 
+# Policy-SA match method (for VTI/XFRM-I).
+MATCH_METHOD_ALL = "all"
+MATCH_METHOD_MARK = "mark"
+MATCH_METHOD_IFID = "ifid"
+
 
 def RawAddress(addr):
   """Converts an IP address string to binary format."""
@@ -630,7 +635,7 @@
     self._SendNlRequest(XFRM_MSG_FLUSHSA, usersa_flush.Pack(), flags)
 
   def CreateTunnel(self, direction, selector, src, dst, spi, encryption,
-                   auth_trunc, mark, output_mark, xfrm_if_id):
+                   auth_trunc, mark, output_mark, xfrm_if_id, match_method):
     """Create an XFRM Tunnel Consisting of a Policy and an SA.
 
     Create a unidirectional XFRM tunnel, which entails one Policy and one
@@ -652,9 +657,28 @@
       output_mark: The mark used to select the underlying network for packets
         outbound from xfrm. None means unspecified.
       xfrm_if_id: The ID of the XFRM interface to use or None.
+      match_method: One of MATCH_METHOD_[MARK | ALL | IFID]. This determines how
+        SAs and policies are matched.
     """
     outer_family = net_test.GetAddressFamily(net_test.GetAddressVersion(dst))
 
+    # SA mark is currently unused due to UPDSA not updating marks.
+    # Kept as documentation of ideal/desired behavior.
+    if match_method == MATCH_METHOD_MARK:
+      # sa_mark = mark
+      tmpl_spi = 0
+      if_id = None
+    elif match_method == MATCH_METHOD_ALL:
+      # sa_mark = mark
+      tmpl_spi = spi
+      if_id = xfrm_if_id
+    elif match_method == MATCH_METHOD_IFID:
+      # sa_mark = None
+      tmpl_spi = 0
+      if_id = xfrm_if_id
+    else:
+      raise ValueError("Unknown match_method supplied: %s" % match_method)
+
     # Device code does not use mark; during AllocSpi, the mark is unset, and
     # UPDSA does not update marks at this time. Actual use case will have no
     # mark set. Test this use case.
@@ -668,7 +692,7 @@
 
     for selector in selectors:
       policy = UserPolicy(direction, selector)
-      tmpl = UserTemplate(outer_family, spi, 0, (src, dst))
+      tmpl = UserTemplate(outer_family, tmpl_spi, 0, (src, dst))
       self.AddPolicyInfo(policy, tmpl, mark, xfrm_if_id=xfrm_if_id)
 
   def DeleteTunnel(self, direction, selector, dst, spi, mark, xfrm_if_id):
diff --git a/net/test/xfrm_tunnel_test.py b/net/test/xfrm_tunnel_test.py
index 652a0c2..778eb26 100755
--- a/net/test/xfrm_tunnel_test.py
+++ b/net/test/xfrm_tunnel_test.py
@@ -185,17 +185,15 @@
 
     # Create input/ouput SPs, SAs and sockets to simulate a more realistic
     # environment.
-    self.xfrm.CreateTunnel(xfrm.XFRM_POLICY_IN,
-                           xfrm.SrcDstSelector(remote_inner, local_inner),
-                           remote_outer, local_outer, _TEST_IN_SPI,
-                           xfrm_base._ALGO_CRYPT_NULL,
-                           xfrm_base._ALGO_AUTH_NULL, None, None, None)
+    self.xfrm.CreateTunnel(
+        xfrm.XFRM_POLICY_IN, xfrm.SrcDstSelector(remote_inner, local_inner),
+        remote_outer, local_outer, _TEST_IN_SPI, xfrm_base._ALGO_CRYPT_NULL,
+        xfrm_base._ALGO_AUTH_NULL, None, None, None, xfrm.MATCH_METHOD_ALL)
 
-    self.xfrm.CreateTunnel(xfrm.XFRM_POLICY_OUT,
-                           xfrm.SrcDstSelector(local_inner, remote_inner),
-                           local_outer, remote_outer, _TEST_OUT_SPI,
-                           xfrm_base._ALGO_CBC_AES_256,
-                           xfrm_base._ALGO_HMAC_SHA1, None, u_netid, None)
+    self.xfrm.CreateTunnel(
+        xfrm.XFRM_POLICY_OUT, xfrm.SrcDstSelector(local_inner, remote_inner),
+        local_outer, remote_outer, _TEST_OUT_SPI, xfrm_base._ALGO_CBC_AES_256,
+        xfrm_base._ALGO_HMAC_SHA1, None, u_netid, None, xfrm.MATCH_METHOD_ALL)
 
     write_sock = socket(net_test.GetAddressFamily(inner_version), SOCK_DGRAM, 0)
     self.SelectInterface(write_sock, netid, "mark")
@@ -321,12 +319,31 @@
 
     self._SetupXfrmByType(auth, crypt)
 
+  def Rekey(self, outer_family, new_out_sa, new_in_sa):
+    """Rekeys the Tunnel Interface
+
+    Creates new SAs and updates the outbound security policy to use new SAs.
+
+    Args:
+      outer_family: AF_INET or AF_INET6
+      new_out_sa: An SaInfo struct representing the new outbound SA's info
+      new_in_sa: An SaInfo struct representing the new inbound SA's info
+    """
+    self._Rekey(outer_family, new_out_sa, new_in_sa)
+
+    # Update Interface object
+    self.out_sa = new_out_sa
+    self.in_sa = new_in_sa
+
   def TeardownXfrm(self):
     raise NotImplementedError("Subclasses should implement this")
 
   def _SetupXfrmByType(self, auth_algo, crypt_algo):
     raise NotImplementedError("Subclasses should implement this")
 
+  def _Rekey(self, outer_family, new_out_sa, new_in_sa):
+    raise NotImplementedError("Subclasses should implement this")
+
 
 class VtiInterface(IpSecBaseInterface):
 
@@ -351,11 +368,12 @@
     self.xfrm.CreateTunnel(xfrm.XFRM_POLICY_OUT, None, self.local, self.remote,
                            self.out_sa.spi, crypt_algo, auth_algo,
                            xfrm.ExactMatchMark(self.okey),
-                           self.underlying_netid, None)
+                           self.underlying_netid, None, xfrm.MATCH_METHOD_ALL)
 
     self.xfrm.CreateTunnel(xfrm.XFRM_POLICY_IN, None, self.remote, self.local,
                            self.in_sa.spi, crypt_algo, auth_algo,
-                           xfrm.ExactMatchMark(self.ikey), None, None)
+                           xfrm.ExactMatchMark(self.ikey), None, None,
+                           xfrm.MATCH_METHOD_MARK)
 
   def TeardownXfrm(self):
     self.xfrm.DeleteTunnel(xfrm.XFRM_POLICY_OUT, None, self.remote,
@@ -363,6 +381,35 @@
     self.xfrm.DeleteTunnel(xfrm.XFRM_POLICY_IN, None, self.local,
                            self.in_sa.spi, self.ikey, None)
 
+  def _Rekey(self, outer_family, new_out_sa, new_in_sa):
+    # TODO: Consider ways to share code with xfrm.CreateTunnel(). It's mostly
+    #       the same, but rekeys are asymmetric, and only update the outbound
+    #       policy.
+    self.xfrm.AddSaInfo(self.local, self.remote, new_out_sa.spi,
+                        xfrm.XFRM_MODE_TUNNEL, 0, xfrm_base._ALGO_CRYPT_NULL,
+                        xfrm_base._ALGO_AUTH_NULL, None, None,
+                        xfrm.ExactMatchMark(self.okey), self.underlying_netid)
+
+    self.xfrm.AddSaInfo(self.remote, self.local, new_in_sa.spi,
+                        xfrm.XFRM_MODE_TUNNEL, 0, xfrm_base._ALGO_CRYPT_NULL,
+                        xfrm_base._ALGO_AUTH_NULL, None, None,
+                        xfrm.ExactMatchMark(self.ikey), None)
+
+    # Create new policies for IPv4 and IPv6.
+    for sel in [xfrm.EmptySelector(AF_INET), xfrm.EmptySelector(AF_INET6)]:
+      # Add SPI-specific output policy to enforce using new outbound SPI
+      policy = xfrm.UserPolicy(xfrm.XFRM_POLICY_OUT, sel)
+      tmpl = xfrm.UserTemplate(outer_family, new_out_sa.spi, 0,
+                                    (self.local, self.remote))
+      self.xfrm.UpdatePolicyInfo(policy, tmpl, xfrm.ExactMatchMark(self.okey),
+                                 0)
+
+  def DeleteOldSaInfo(self, outer_family, old_in_spi, old_out_spi):
+    self.xfrm.DeleteSaInfo(self.local, old_in_spi, IPPROTO_ESP,
+                           xfrm.ExactMatchMark(self.ikey))
+    self.xfrm.DeleteSaInfo(self.remote, old_out_spi, IPPROTO_ESP,
+                           xfrm.ExactMatchMark(self.okey))
+
 
 @unittest.skipUnless(HAVE_XFRM_INTERFACES, "XFRM interfaces unsupported")
 class XfrmAddDeleteXfrmInterfaceTest(xfrm_base.XfrmBaseTest):
@@ -401,10 +448,11 @@
   def _SetupXfrmByType(self, auth_algo, crypt_algo):
     self.xfrm.CreateTunnel(xfrm.XFRM_POLICY_OUT, None, self.local, self.remote,
                            self.out_sa.spi, crypt_algo, auth_algo, None,
-                           self.underlying_netid, self.xfrm_if_id)
+                           self.underlying_netid, self.xfrm_if_id,
+                           xfrm.MATCH_METHOD_ALL)
     self.xfrm.CreateTunnel(xfrm.XFRM_POLICY_IN, None, self.remote, self.local,
                            self.in_sa.spi, crypt_algo, auth_algo, None, None,
-                           self.xfrm_if_id)
+                           self.xfrm_if_id, xfrm.MATCH_METHOD_IFID)
 
   def TeardownXfrm(self):
     self.xfrm.DeleteTunnel(xfrm.XFRM_POLICY_OUT, None, self.remote,
@@ -412,6 +460,33 @@
     self.xfrm.DeleteTunnel(xfrm.XFRM_POLICY_IN, None, self.local,
                            self.in_sa.spi, None, self.xfrm_if_id)
 
+  def _Rekey(self, outer_family, new_out_sa, new_in_sa):
+    # TODO: Consider ways to share code with xfrm.CreateTunnel(). It's mostly
+    #       the same, but rekeys are asymmetric, and only update the outbound
+    #       policy.
+    self.xfrm.AddSaInfo(
+        self.local, self.remote, new_out_sa.spi, xfrm.XFRM_MODE_TUNNEL, 0,
+        xfrm_base._ALGO_CRYPT_NULL, xfrm_base._ALGO_AUTH_NULL, None, None,
+        None, self.underlying_netid, xfrm_if_id=self.xfrm_if_id)
+
+    self.xfrm.AddSaInfo(
+        self.remote, self.local, new_in_sa.spi, xfrm.XFRM_MODE_TUNNEL, 0,
+        xfrm_base._ALGO_CRYPT_NULL, xfrm_base._ALGO_AUTH_NULL, None, None,
+        None, None, xfrm_if_id=self.xfrm_if_id)
+
+    # Create new policies for IPv4 and IPv6.
+    for sel in [xfrm.EmptySelector(AF_INET), xfrm.EmptySelector(AF_INET6)]:
+      # Add SPI-specific output policy to enforce using new outbound SPI
+      policy = xfrm.UserPolicy(xfrm.XFRM_POLICY_OUT, sel)
+      tmpl = xfrm.UserTemplate(outer_family, new_out_sa.spi, 0,
+                                    (self.local, self.remote))
+      self.xfrm.UpdatePolicyInfo(policy, tmpl, None, self.xfrm_if_id)
+
+  def DeleteOldSaInfo(self, outer_family, old_in_spi, old_out_spi):
+    self.xfrm.DeleteSaInfo(self.local, old_in_spi, IPPROTO_ESP, None,
+                           self.xfrm_if_id)
+    self.xfrm.DeleteSaInfo(self.remote, old_out_spi, IPPROTO_ESP, None,
+                           self.xfrm_if_id)
 
 
 class XfrmTunnelBase(xfrm_base.XfrmBaseTest):
@@ -555,7 +630,7 @@
     sa_info.seq_num += 1
 
   def _CheckTunnelInput(self, tunnel, inner_version, local_inner, remote_inner,
-                        sa_info=None):
+                        sa_info=None, expect_fail=False):
     """Test null-crypt input path over an IPsec interface."""
     if sa_info is None:
       sa_info = tunnel.in_sa
@@ -566,11 +641,14 @@
         local_inner, tunnel.local, local_port, sa_info.spi, sa_info.seq_num)
     self.ReceivePacketOn(tunnel.underlying_netid, input_pkt)
 
-    # Verify that the packet data and src are correct
-    self.assertReceivedPacket(tunnel, sa_info)
-    data, src = read_sock.recvfrom(4096)
-    self.assertEquals(net_test.UDP_PAYLOAD, data)
-    self.assertEquals((remote_inner, _TEST_REMOTE_PORT), src[:2])
+    if expect_fail:
+      self.assertRaisesErrno(EAGAIN, read_sock.recv, 4096)
+    else:
+      # Verify that the packet data and src are correct
+      self.assertReceivedPacket(tunnel, sa_info)
+      data, src = read_sock.recvfrom(4096)
+      self.assertEquals(net_test.UDP_PAYLOAD, data)
+      self.assertEquals((remote_inner, _TEST_REMOTE_PORT), src[:2])
 
   def _CheckTunnelOutput(self, tunnel, inner_version, local_inner,
                          remote_inner, sa_info=None):
@@ -711,13 +789,25 @@
     tunnel = self.randomTunnel(outer_version)
 
     try:
+      # Some tests require that the out_seq_num and in_seq_num are the same
+      # (Specifically encrypted tests), rebuild SAs to ensure seq_num is 1
+      #
+      # Until we get better scapy support, the only way we can build an
+      # encrypted packet is to send it out, and read the packet from the wire.
+      # We then generally use this as the "inbound" encrypted packet, injecting
+      # it into the interface for which it is expected on.
+      #
+      # As such, this is required to ensure that encrypted packets (which we
+      # currently have no way to easily modify) are not considered replay
+      # attacks by the inbound SA.  (eg: received 3 packets, seq_num_in = 3,
+      # sent only 1, # seq_num_out = 1, inbound SA would consider this a replay
+      # attack)
       tunnel.TeardownXfrm()
       tunnel.SetupXfrm(use_null_crypt)
 
       local_inner = tunnel.addrs[inner_version]
       remote_inner = _GetRemoteInnerAddress(inner_version)
 
-      # Run twice to ensure sequence numbers are tested
       for i in range(2):
         func(tunnel, inner_version, local_inner, remote_inner)
     finally:
@@ -725,6 +815,67 @@
         tunnel.TeardownXfrm()
         tunnel.SetupXfrm(False)
 
+  def _CheckTunnelRekey(self, tunnel, inner_version, local_inner, remote_inner):
+    old_out_sa = tunnel.out_sa
+    old_in_sa = tunnel.in_sa
+
+    # Check to make sure that both directions work before rekey
+    self._CheckTunnelInput(tunnel, inner_version, local_inner, remote_inner,
+                           old_in_sa)
+    self._CheckTunnelOutput(tunnel, inner_version, local_inner, remote_inner,
+                            old_out_sa)
+
+    # Rekey
+    outer_family = net_test.GetAddressFamily(tunnel.version)
+
+    # Create new SA
+    # Distinguish the new SAs with new SPIs.
+    new_out_sa = SaInfo(old_out_sa.spi + 1)
+    new_in_sa = SaInfo(old_in_sa.spi + 1)
+
+    # Perform Rekey
+    tunnel.Rekey(outer_family, new_out_sa, new_in_sa)
+
+    # Expect that the old SPI still works for inbound packets
+    self._CheckTunnelInput(tunnel, inner_version, local_inner, remote_inner,
+                           old_in_sa)
+
+    # Test both paths with new SPIs, expect outbound to use new SPI
+    self._CheckTunnelInput(tunnel, inner_version, local_inner, remote_inner,
+                           new_in_sa)
+    self._CheckTunnelOutput(tunnel, inner_version, local_inner, remote_inner,
+                            new_out_sa)
+
+    # Delete old SAs
+    tunnel.DeleteOldSaInfo(outer_family, old_in_sa.spi, old_out_sa.spi)
+
+    # Test both paths with new SPIs; should still work
+    self._CheckTunnelInput(tunnel, inner_version, local_inner, remote_inner,
+                           new_in_sa)
+    self._CheckTunnelOutput(tunnel, inner_version, local_inner, remote_inner,
+                            new_out_sa)
+
+    # Expect failure upon trying to receive a packet with the deleted SPI
+    self._CheckTunnelInput(tunnel, inner_version, local_inner, remote_inner,
+                           old_in_sa, True)
+
+  def _TestTunnelRekey(self, inner_version, outer_version):
+    """Test packet input and output over a Virtual Tunnel Interface."""
+    tunnel = self.randomTunnel(outer_version)
+
+    try:
+      # Always use null_crypt, so we can check input and output separately
+      tunnel.TeardownXfrm()
+      tunnel.SetupXfrm(True)
+
+      local_inner = tunnel.addrs[inner_version]
+      remote_inner = _GetRemoteInnerAddress(inner_version)
+
+      self._CheckTunnelRekey(tunnel, inner_version, local_inner, remote_inner)
+    finally:
+      tunnel.TeardownXfrm()
+      tunnel.SetupXfrm(False)
+
 
 @unittest.skipUnless(net_test.LINUX_VERSION >= (3, 18, 0), "VTI Unsupported")
 class XfrmVtiTest(XfrmTunnelBase):
@@ -749,6 +900,9 @@
     self._TestTunnel(inner_version, outer_version,
                      self._CheckTunnelEncryptionWithIcmp, False)
 
+  def ParamTestVtiRekey(self, inner_version, outer_version):
+    self._TestTunnelRekey(inner_version, outer_version)
+
 
 @unittest.skipUnless(HAVE_XFRM_INTERFACES, "XFRM interfaces unsupported")
 class XfrmInterfaceTest(XfrmTunnelBase):
@@ -773,6 +927,9 @@
     self._TestTunnel(inner_version, outer_version,
                      self._CheckTunnelEncryptionWithIcmp, False)
 
+  def ParamTestXfrmIntfRekey(self, inner_version, outer_version):
+    self._TestTunnelRekey(inner_version, outer_version)
+
 
 if __name__ == "__main__":
   InjectTests()