Support flushing XFRM state.
Test: all_tests.sh passes on common and device kernels
Bug: 34114242
Change-Id: I95e7a92bba2e82c4e3affb4bf7788e139268ab44
diff --git a/net/test/xfrm.py b/net/test/xfrm.py
index e9b4fce..56f581f 100755
--- a/net/test/xfrm.py
+++ b/net/test/xfrm.py
@@ -134,7 +134,8 @@
XfrmAlgo = cstruct.Struct("XfrmAlgo", "=64AI", "name key_len")
-XfrmAlgoAuth = cstruct.Struct("XfrmAlgo", "=64AII", "name key_len trunc_len")
+XfrmAlgoAuth = cstruct.Struct("XfrmAlgoAuth", "=64AII",
+ "name key_len trunc_len")
XfrmAlgoAead = cstruct.Struct("XfrmAlgoAead", "=64AII", "name key_len icv_len")
@@ -164,6 +165,8 @@
"sel lft curlft priority index dir action flags share",
[XfrmSelector, XfrmLifetimeCfg, XfrmLifetimeCur])
+XfrmUsersaFlush = cstruct.Struct("XfrmUsersaFlush", "=B", "proto")
+
# Socket options. See include/uapi/linux/in.h.
IP_IPSEC_POLICY = 16
IP_XFRM_POLICY = 17
@@ -179,6 +182,9 @@
NO_LIFETIME_CFG = XfrmLifetimeCfg((_INF, _INF, _INF, _INF, 0, 0, 0, 0))
NO_LIFETIME_CUR = "\x00" * len(XfrmLifetimeCur)
+# IPsec constants.
+IPSEC_PROTO_ANY = 255
+
def RawAddress(addr):
"""Converts an IP address string to binary format."""
@@ -282,6 +288,11 @@
sainfo = [sa for sa, attrs in self.DumpSaInfo() if sa.id.spi == spi]
return sainfo[0] if sainfo else None
+ def FlushSaInfo(self):
+ usersa_flush = XfrmUsersaFlush((IPSEC_PROTO_ANY,))
+ flags = netlink.NLM_F_REQUEST | netlink.NLM_F_ACK
+ self._SendNlRequest(XFRM_MSG_FLUSHSA, usersa_flush.Pack(), flags)
+
if __name__ == "__main__":
x = Xfrm()
diff --git a/net/test/xfrm_test.py b/net/test/xfrm_test.py
index 5700ad8..734ac62 100755
--- a/net/test/xfrm_test.py
+++ b/net/test/xfrm_test.py
@@ -50,6 +50,7 @@
ALGO_CBC_AES_256 = xfrm.XfrmAlgo(("cbc(aes)", 256))
ALGO_HMAC_SHA1 = xfrm.XfrmAlgoAuth(("hmac(sha1)", 128, 96))
+
class XfrmTest(multinetwork_base.MultiNetworkBaseTest):
@classmethod
@@ -60,7 +61,11 @@
def setUp(self):
# TODO: delete this when we're more diligent about deleting our SAs.
super(XfrmTest, self).setUp()
- subprocess.call("ip xfrm state flush".split())
+ self.xfrm.FlushSaInfo()
+
+ def tearDown(self):
+ super(XfrmTest, self).tearDown()
+ self.xfrm.FlushSaInfo()
def expectIPv6EspPacketOn(self, netid, spi, seq, length):
packets = self.ReadAllPacketsOn(netid)
@@ -100,6 +105,19 @@
finally:
self.xfrm.DeleteSaInfo(TEST_ADDR1, htonl(TEST_SPI), IPPROTO_ESP)
+ def testFlush(self):
+ self.assertEquals(0, len(self.xfrm.DumpSaInfo()))
+ self.xfrm.AddMinimalSaInfo("::", "2000::", htonl(TEST_SPI),
+ IPPROTO_ESP, xfrm.XFRM_MODE_TRANSPORT, 1234,
+ ALGO_CBC_AES_256, ENCRYPTION_KEY,
+ ALGO_HMAC_SHA1, AUTH_TRUNC_KEY, None)
+ self.xfrm.AddMinimalSaInfo("0.0.0.0", "192.0.2.1", htonl(TEST_SPI),
+ IPPROTO_ESP, xfrm.XFRM_MODE_TRANSPORT, 4321,
+ ALGO_CBC_AES_256, ENCRYPTION_KEY,
+ ALGO_HMAC_SHA1, AUTH_TRUNC_KEY, None)
+ self.assertEquals(2, len(self.xfrm.DumpSaInfo()))
+ self.xfrm.FlushSaInfo()
+ self.assertEquals(0, len(self.xfrm.DumpSaInfo()))
@unittest.skipUnless(net_test.LINUX_VERSION < (4, 4, 0), "regression")
def testSocketPolicy(self):