Test Updating OUTPUT_MARK on Active SAs

Test that the xfrm output mark is update-able
on an Active SA, which means that we can dynamically
reroute traffic to new underlying networks after an
SA has been created.

Bug: 71645364
Test: run_net_test.sh xfrm_test.py
Change-Id: I561fdc27439d33807667c4a58a03bca3c468599b
diff --git a/net/test/xfrm_test.py b/net/test/xfrm_test.py
index 3a3d9b0..afcacde 100755
--- a/net/test/xfrm_test.py
+++ b/net/test/xfrm_test.py
@@ -38,6 +38,9 @@
 TEST_ADDR1 = "2001:4860:4860::8888"
 TEST_ADDR2 = "2001:4860:4860::8844"
 
+XFRM_STATS_PROCFILE = "/proc/net/xfrm_stat"
+XFRM_STATS_OUT_NO_STATES = "XfrmOutNoStates"
+
 # IP addresses to use for tunnel endpoints. For generality, these should be
 # different from the addresses we send packets to.
 TUNNEL_ENDPOINTS = {4: "8.8.4.4", 6: TEST_ADDR2}
@@ -778,7 +781,7 @@
             xfrm_base._ALGO_HMAC_SHA1, None, None, None, 0)
 
   def testUpdateSaAddMark(self):
-    """Test that when an SA has no mark, it can be updated to add a mark."""
+    """Test that an embryonic SA can be updated to add a mark."""
     for version in [4, 6]:
       spi = 0xABCD
       # Test that an SA created with ALLOCSPI can be updated with the mark.
@@ -798,7 +801,94 @@
       self.xfrm.DeleteSaInfo(net_test.GetWildcardAddress(version),
                              spi, IPPROTO_ESP, mark)
 
-      # TODO: we might also need to update the mark for a VALID SA.
+  def getXfrmStat(self, statName):
+    stateVal = 0
+    with open(XFRM_STATS_PROCFILE, 'r') as f:
+      for line in f:
+          if statName in line:
+            stateVal = int(line.split()[1])
+            break
+      f.close()
+    return stateVal
+
+  def testUpdateActiveSaMarks(self):
+    """Test that the OUTPUT_MARK can be updated on an ACTIVE SA."""
+    for version in [4, 6]:
+      family = net_test.GetAddressFamily(version)
+      netid = self.RandomNetid()
+      remote = self.GetRemoteAddress(version)
+      local = self.MyAddress(version, netid)
+      s = socket(family, SOCK_DGRAM, 0)
+      self.SelectInterface(s, netid, "mark")
+      # Create a mark that we will apply to the policy and later the SA
+      mark = xfrm.ExactMatchMark(netid)
+
+      # Create a global policy that selects using the mark.
+      sel = xfrm.EmptySelector(family)
+      policy = xfrm.UserPolicy(xfrm.XFRM_POLICY_OUT, sel)
+      tmpl = xfrm.UserTemplate(family, 0, 0, (local, remote))
+      self.xfrm.AddPolicyInfo(policy, tmpl, mark)
+
+      # Pull /proc/net/xfrm_stats for baseline
+      outNoStateCount = self.getXfrmStat(XFRM_STATS_OUT_NO_STATES);
+
+      # should increment XfrmOutNoStates
+      s.sendto(net_test.UDP_PAYLOAD, (remote, 53))
+
+      # Check to make sure XfrmOutNoStates is incremented by exactly 1
+      self.assertEquals(outNoStateCount + 1,
+                        self.getXfrmStat(XFRM_STATS_OUT_NO_STATES))
+
+      length = xfrm_base.GetEspPacketLength(xfrm.XFRM_MODE_TUNNEL,
+                                            version, False,
+                                            net_test.UDP_PAYLOAD,
+                                            xfrm_base._ALGO_HMAC_SHA1,
+                                            xfrm_base._ALGO_CBC_AES_256)
+
+      # Add a default SA with no mark that routes to nowhere.
+      self.xfrm.AddSaInfo(local,
+                          remote,
+                          TEST_SPI, xfrm.XFRM_MODE_TUNNEL, 0,
+                          xfrm_base._ALGO_CBC_AES_256,
+                          xfrm_base._ALGO_HMAC_SHA1,
+                          None, None, None, 0, is_update=False)
+      self.assertRaisesErrno(
+          ENETUNREACH,
+          s.sendto, net_test.UDP_PAYLOAD, (remote, 53))
+
+      # Update the SA to route to a valid netid.
+      self.xfrm.AddSaInfo(local,
+                          remote,
+                          TEST_SPI, xfrm.XFRM_MODE_TUNNEL, 0,
+                          xfrm_base._ALGO_CBC_AES_256,
+                          xfrm_base._ALGO_HMAC_SHA1,
+                          None, None, None, netid, is_update=True)
+
+      # Now the payload routes to the updated netid.
+      s.sendto(net_test.UDP_PAYLOAD, (remote, 53))
+      self._ExpectEspPacketOn(netid, TEST_SPI, 1, length, None, None)
+
+      # Get a new netid and reroute the packets to the new netid.
+      reroute_netid = self.RandomNetid(netid)
+      # Update the SA to change the output mark.
+      self.xfrm.AddSaInfo(local,
+                         remote,
+                         TEST_SPI, xfrm.XFRM_MODE_TUNNEL, 0,
+                         xfrm_base._ALGO_CBC_AES_256,
+                         xfrm_base._ALGO_HMAC_SHA1,
+                         None, None, None, reroute_netid, is_update=True)
+
+      s.sendto(net_test.UDP_PAYLOAD, (remote, 53))
+      self._ExpectEspPacketOn(reroute_netid, TEST_SPI, 2, length, None, None)
+
+      dump = self.xfrm.DumpSaInfo()
+
+      self.assertEquals(1, len(dump)) # check that update updated
+      sainfo, attributes = dump[0]
+      self.assertEquals(reroute_netid, attributes["XFRMA_OUTPUT_MARK"])
+
+      self.xfrm.DeleteSaInfo(remote, TEST_SPI, IPPROTO_ESP, None)
+      self.xfrm.DeletePolicyInfo(sel, xfrm.XFRM_POLICY_OUT, mark)
 
 if __name__ == "__main__":
   unittest.main()