Merge "Test that an SA Can be Updated with a Mark"
diff --git a/net/test/xfrm.py b/net/test/xfrm.py
index 456f3bd..04a434d 100755
--- a/net/test/xfrm.py
+++ b/net/test/xfrm.py
@@ -380,7 +380,7 @@
     return self._SendNlRequest(msg_type, msg, flags)
 
   def AddSaInfo(self, src, dst, spi, mode, reqid, encryption, auth_trunc, aead,
-                encap, mark, output_mark):
+                encap, mark, output_mark, is_update=False):
     """Adds an IPsec security association.
 
     Args:
@@ -397,6 +397,8 @@
       mark: A mark match specifier, such as returned by ExactMatchMark(), or
         None for an SA that matches all possible marks.
       output_mark: An integer, the output mark. 0 means unset.
+      is_update: If true, update an existing SA otherwise create a new SA. For
+        compatibility reasons, this value defaults to False.
     """
     proto = IPPROTO_ESP
     xfrm_id = XfrmId((PaddedAddress(dst), spi, proto))
@@ -450,26 +452,47 @@
                          cur, stats, seq, reqid, family, mode, replay, flags))
     msg = sa.Pack() + nlattrs
     flags = netlink.NLM_F_REQUEST | netlink.NLM_F_ACK
-    self._SendNlRequest(XFRM_MSG_NEWSA, msg, flags)
+    nl_msg_type = XFRM_MSG_UPDSA if is_update else XFRM_MSG_NEWSA
+    self._SendNlRequest(nl_msg_type, msg, flags)
 
-  def DeleteSaInfo(self, daddr, spi, proto):
+  def DeleteSaInfo(self, dst, spi, proto, mark=None):
+    """Delete an SA from the SAD.
+
+    Args:
+      dst: A string, the destination IP address. Forms part of the XFRM ID, and
+        must match the destination address of the packets sent by this SA.
+      spi: An integer, the SPI.
+      proto: The IPsec protocol of the SA, such as IPPROTO_ESP.
+      mark: A mark match specifier, such as returned by ExactMatchMark(), or
+        None for an SA without a mark. If not None, it is sent as the
+        XFRMA_MARK attribute of the XFRM_MSG_DELSA request.
+    """
-    # TODO: deletes take a mark as well.
-    family = AF_INET6 if ":" in daddr else AF_INET
-    usersa_id = XfrmUsersaId((PaddedAddress(daddr), spi, family, proto))
-    flags = netlink.NLM_F_REQUEST | netlink.NLM_F_ACK
-    self._SendNlRequest(XFRM_MSG_DELSA, usersa_id.Pack(), flags)
+    family = AF_INET6 if ":" in dst else AF_INET
+    usersa_id = XfrmUsersaId((PaddedAddress(dst), spi, family, proto))
+    nlattrs = []
+    if mark is not None:
+      nlattrs.append((XFRMA_MARK, mark))
+    self.SendXfrmNlRequest(XFRM_MSG_DELSA, usersa_id, nlattrs)
 
   def AllocSpi(self, dst, proto, min_spi, max_spi):
     """Allocate (reserve) an SPI.
 
     This sends an XFRM_MSG_ALLOCSPI message and returns the resulting
     XfrmUsersaInfo struct.
+
+    Args:
+      dst: A string, the destination IP address. Forms part of the XFRM ID, and
+        must match the destination address of the packets sent by this SA.
+      proto: the protocol DB of the SA, such as IPPROTO_ESP.
+      min_spi: The minimum value of the acceptable SPI range (inclusive).
+      max_spi: The maximum value of the acceptable SPI range (inclusive).
     """
     spi = XfrmUserSpiInfo("\x00" * len(XfrmUserSpiInfo))
     spi.min = min_spi
     spi.max = max_spi
     spi.info.id.daddr = PaddedAddress(dst)
     spi.info.id.proto = proto
+    spi.info.family = AF_INET6 if ":" in dst else AF_INET
 
     msg = spi.Pack()
     flags = netlink.NLM_F_REQUEST
diff --git a/net/test/xfrm_test.py b/net/test/xfrm_test.py
index c4fc46e..1ec9692 100755
--- a/net/test/xfrm_test.py
+++ b/net/test/xfrm_test.py
@@ -760,6 +760,28 @@
             xfrm.XFRM_MODE_TRANSPORT, 0, invalid_crypt,
             xfrm_base._ALGO_HMAC_SHA1, None, None, None, 0)
 
+  def testUpdateSaAddMark(self):
+    """Test that when an SA has no mark, it can be updated to add a mark."""
+    spi = 0xABCD
+    mark = xfrm.ExactMatchMark(0xf00d)
+    for version in [4, 6]:
+      # Test that an SA created with ALLOCSPI can be updated with the mark.
+      self.xfrm.AllocSpi(net_test.GetWildcardAddress(version),
+                         IPPROTO_ESP, spi, spi)
+      self.xfrm.AddSaInfo(net_test.GetWildcardAddress(version),
+                          net_test.GetWildcardAddress(version),
+                          spi, xfrm.XFRM_MODE_TUNNEL, 0,
+                          xfrm_base._ALGO_CBC_AES_256,
+                          xfrm_base._ALGO_HMAC_SHA1,
+                          None, None, mark, 0, is_update=True)
+      dump = self.xfrm.DumpSaInfo()
+      self.assertEquals(1, len(dump))  # Updated in place, not duplicated.
+      _, attributes = dump[0]
+      self.assertEquals(mark, attributes["XFRMA_MARK"])
+      self.xfrm.DeleteSaInfo(net_test.GetWildcardAddress(version),
+                             spi, IPPROTO_ESP, mark)
+
+      # TODO: we might also need to update the mark for a VALID SA.
 
 if __name__ == "__main__":
   unittest.main()