Retry IKE INIT if receiving Notify-Cookie

IKE Session will retry IKE INIT request if the received IKE INIT
response contains a Notify-Cookie payload

Bug: 174997213
Test: FrameworksIkeTests(new tests)
Change-Id: Iaafb2f9aa0a4a14458852ad791fd3030a8da6499
diff --git a/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachine.java b/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachine.java
index 151f8ad..6deac8d 100644
--- a/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachine.java
+++ b/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachine.java
@@ -32,6 +32,7 @@
 import static com.android.internal.net.ipsec.ike.message.IkeMessage.DECODE_STATUS_PARTIAL;
 import static com.android.internal.net.ipsec.ike.message.IkeMessage.DECODE_STATUS_PROTECTED_ERROR;
 import static com.android.internal.net.ipsec.ike.message.IkeMessage.DECODE_STATUS_UNPROTECTED_ERROR;
+import static com.android.internal.net.ipsec.ike.message.IkeNotifyPayload.NOTIFY_TYPE_COOKIE;
 import static com.android.internal.net.ipsec.ike.message.IkeNotifyPayload.NOTIFY_TYPE_COOKIE2;
 import static com.android.internal.net.ipsec.ike.message.IkeNotifyPayload.NOTIFY_TYPE_EAP_ONLY_AUTHENTICATION;
 import static com.android.internal.net.ipsec.ike.message.IkeNotifyPayload.NOTIFY_TYPE_IKEV2_FRAGMENTATION_SUPPORTED;
@@ -2884,23 +2885,27 @@
         @Override
         public void enterState() {
             try {
-                IkeMessage request = buildIkeInitReq();
-
-                // Register local SPI to receive the IKE INIT response.
-                mIkeSocket.registerIke(
-                        request.ikeHeader.ikeInitiatorSpi, IkeSessionStateMachine.this);
-
-                mIkeInitRequestBytes = request.encode();
-                mIkeInitNoncePayload =
-                        request.getPayloadForType(
-                                IkePayload.PAYLOAD_TYPE_NONCE, IkeNoncePayload.class);
-                mRetransmitter = new UnencryptedRetransmitter(request);
+                sendRequest(buildIkeInitReq());
             } catch (IOException e) {
                 // Fail to assign IKE SPI
                 handleIkeFatalError(e);
             }
         }
 
+        private void sendRequest(IkeMessage request) {
+            // Register local SPI to receive the IKE INIT response.
+            mIkeSocket.registerIke(request.ikeHeader.ikeInitiatorSpi, IkeSessionStateMachine.this);
+
+            mIkeInitRequestBytes = request.encode();
+            mIkeInitNoncePayload =
+                    request.getPayloadForType(IkePayload.PAYLOAD_TYPE_NONCE, IkeNoncePayload.class);
+
+            if (mRetransmitter != null) {
+                mRetransmitter.stopRetransmitting();
+            }
+            mRetransmitter = new UnencryptedRetransmitter(request);
+        }
+
         @Override
         protected void triggerRetransmit() {
             mRetransmitter.retransmit();
@@ -2977,10 +2982,41 @@
             }
         }
 
+        /** Returns the Notify-Cookie payload, or null if it does not exist */
+        private IkeNotifyPayload getNotifyCookie(IkeMessage ikeMessage) {
+            List<IkeNotifyPayload> notifyPayloads =
+                    ikeMessage.getPayloadListForType(PAYLOAD_TYPE_NOTIFY, IkeNotifyPayload.class);
+            for (IkeNotifyPayload notify : notifyPayloads) {
+                if (notify.notifyType == NOTIFY_TYPE_COOKIE) {
+                    return notify;
+                }
+            }
+            return null;
+        }
+
         @Override
         protected void handleResponseIkeMessage(IkeMessage ikeMessage) {
             boolean ikeInitSuccess = false;
             try {
+                int exchangeType = ikeMessage.ikeHeader.exchangeType;
+                if (exchangeType != IkeHeader.EXCHANGE_TYPE_IKE_SA_INIT) {
+                    throw new InvalidSyntaxException(
+                            "Expected EXCHANGE_TYPE_IKE_SA_INIT but received: " + exchangeType);
+                }
+
+                // Retry IKE INIT if there is Notify-Cookie
+                IkeNotifyPayload inCookiePayload = getNotifyCookie(ikeMessage);
+                if (inCookiePayload != null) {
+                    IkeNotifyPayload outCookiePayload =
+                            IkeNotifyPayload.handleCookieAndGenerateCopy(inCookiePayload);
+                    IkeMessage initReq =
+                            buildReqWithCookie(mRetransmitter.getMessage(), outCookiePayload);
+
+                    sendRequest(initReq);
+                    return;
+                }
+
+                // Negotiate IKE SA
                 validateIkeInitResp(mRetransmitter.getMessage(), ikeMessage);
 
                 mCurrentIkeSaRecord =
@@ -3094,18 +3130,45 @@
             return new IkeMessage(ikeHeader, payloadList);
         }
 
+        /**
+         * Builds an IKE INIT request that has the same payloads and SPI with the original request,
+         * and with the new Notify-Cookie Payload as the first payload.
+         */
+        private IkeMessage buildReqWithCookie(
+                IkeMessage originalReq, IkeNotifyPayload cookieNotify) {
+            List<IkePayload> payloads = new ArrayList<>();
+
+            // Notify-Cookie MUST be the first payload.
+            payloads.add(cookieNotify);
+
+            for (IkePayload payload : originalReq.ikePayloadList) {
+                // Keep all previous payloads except COOKIEs
+                if (payload instanceof IkeNotifyPayload
+                        && ((IkeNotifyPayload) payload).notifyType == NOTIFY_TYPE_COOKIE) {
+                    continue;
+                }
+                payloads.add(payload);
+            }
+
+            IkeHeader originalHeader = originalReq.ikeHeader;
+            IkeHeader header =
+                    new IkeHeader(
+                            originalHeader.ikeInitiatorSpi,
+                            originalHeader.ikeResponderSpi,
+                            PAYLOAD_TYPE_NOTIFY,
+                            IkeHeader.EXCHANGE_TYPE_IKE_SA_INIT,
+                            false /* isResponseMsg */,
+                            true /* fromIkeInitiator */,
+                            0 /* messageId */);
+            return new IkeMessage(header, payloads);
+        }
+
         private void validateIkeInitResp(IkeMessage reqMsg, IkeMessage respMsg)
                 throws IkeProtocolException, IOException {
             IkeHeader respIkeHeader = respMsg.ikeHeader;
             mRemoteIkeSpiResource =
                     mIkeSpiGenerator.allocateSpi(mRemoteAddress, respIkeHeader.ikeResponderSpi);
 
-            int exchangeType = respIkeHeader.exchangeType;
-            if (exchangeType != IkeHeader.EXCHANGE_TYPE_IKE_SA_INIT) {
-                throw new InvalidSyntaxException(
-                        "Expected EXCHANGE_TYPE_IKE_SA_INIT but received: " + exchangeType);
-            }
-
             IkeSaPayload respSaPayload = null;
             IkeKePayload respKePayload = null;
 
diff --git a/tests/iketests/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachineTest.java b/tests/iketests/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachineTest.java
index ccd731d..90986a6 100644
--- a/tests/iketests/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachineTest.java
+++ b/tests/iketests/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachineTest.java
@@ -51,6 +51,7 @@
 import static com.android.internal.net.ipsec.ike.message.IkeConfigPayload.CONFIG_ATTR_IP6_PCSCF;
 import static com.android.internal.net.ipsec.ike.message.IkeHeader.EXCHANGE_TYPE_CREATE_CHILD_SA;
 import static com.android.internal.net.ipsec.ike.message.IkeHeader.EXCHANGE_TYPE_INFORMATIONAL;
+import static com.android.internal.net.ipsec.ike.message.IkeNotifyPayload.NOTIFY_TYPE_COOKIE;
 import static com.android.internal.net.ipsec.ike.message.IkeNotifyPayload.NOTIFY_TYPE_COOKIE2;
 import static com.android.internal.net.ipsec.ike.message.IkeNotifyPayload.NOTIFY_TYPE_EAP_ONLY_AUTHENTICATION;
 import static com.android.internal.net.ipsec.ike.message.IkeNotifyPayload.NOTIFY_TYPE_IKEV2_FRAGMENTATION_SUPPORTED;
@@ -335,13 +336,16 @@
 
     private static final int PAYLOAD_TYPE_UNSUPPORTED = 127;
 
+    private static final int COOKIE_DATA_LEN = 64;
     private static final int COOKIE2_DATA_LEN = 64;
 
+    private static final byte[] COOKIE_DATA = new byte[COOKIE_DATA_LEN];
     private static final byte[] COOKIE2_DATA = new byte[COOKIE2_DATA_LEN];
 
     private static final int NATT_KEEPALIVE_DELAY = 20;
 
     static {
+        new Random().nextBytes(COOKIE_DATA);
         new Random().nextBytes(COOKIE2_DATA);
     }
 
@@ -417,31 +421,39 @@
     private ArgumentCaptor<List<IkePayload>> mPayloadListCaptor =
             ArgumentCaptor.forClass(List.class);
 
-    private ReceivedIkePacket makeDummyReceivedIkeInitRespPacket(
-            long initiatorSpi,
-            long responderSpi,
-            @IkeHeader.ExchangeType int eType,
-            boolean isResp,
-            boolean fromIkeInit,
-            List<Integer> payloadTypeList,
-            List<String> payloadHexStringList)
+    private ReceivedIkePacket makeDummyReceivedIkeInitRespPacket(List<IkePayload> payloadList)
             throws Exception {
+        long dummyInitSpi = 1L;
+        long dummyRespSpi = 2L;
 
-        List<IkePayload> payloadList =
-                hexStrListToIkePayloadList(payloadTypeList, payloadHexStringList, isResp);
         // Build a remotely generated NAT_DETECTION_SOURCE_IP payload to mock a remote node's
         // network that is not behind NAT.
         IkePayload sourceNatPayload =
                 new IkeNotifyPayload(
                         NOTIFY_TYPE_NAT_DETECTION_SOURCE_IP,
                         IkeNotifyPayload.generateNatDetectionData(
-                                initiatorSpi,
-                                responderSpi,
+                                dummyInitSpi,
+                                dummyRespSpi,
                                 REMOTE_ADDRESS,
                                 IkeSocket.SERVER_PORT_UDP_ENCAPSULATED));
         payloadList.add(sourceNatPayload);
+
         return makeDummyUnencryptedReceivedIkePacket(
-                initiatorSpi, responderSpi, eType, isResp, fromIkeInit, payloadList);
+                dummyInitSpi,
+                dummyRespSpi,
+                IkeHeader.EXCHANGE_TYPE_IKE_SA_INIT,
+                true /*isResp*/,
+                false /*fromIkeInit*/,
+                payloadList);
+    }
+
+    private ReceivedIkePacket makeDummyReceivedIkeInitRespPacket(
+            List<Integer> payloadTypeList, List<String> payloadHexStringList) throws Exception {
+
+        List<IkePayload> payloadList =
+                hexStrListToIkePayloadList(
+                        payloadTypeList, payloadHexStringList, true /* isResp */);
+        return makeDummyReceivedIkeInitRespPacket(payloadList);
     }
 
     private ReceivedIkePacket makeDummyUnencryptedReceivedIkePacket(
@@ -997,16 +1009,7 @@
         payloadHexStringList.add(NONCE_RESP_PAYLOAD_HEX_STRING);
         payloadHexStringList.addAll(optionalPayloadHexStrings);
 
-        // In each test assign different IKE responder SPI in IKE INIT response to avoid remote SPI
-        // collision during response validation.
-        // STOPSHIP: b/131617794 allow #mockIkeSetup to be independent in each test after we can
-        // support IkeSession cleanup.
         return makeDummyReceivedIkeInitRespPacket(
-                1L /*initiator SPI*/,
-                2L /*responder SPI*/,
-                IkeHeader.EXCHANGE_TYPE_IKE_SA_INIT,
-                true /*isResp*/,
-                false /*fromIkeInit*/,
                 payloadTypeList,
                 payloadHexStringList);
     }
@@ -1457,11 +1460,6 @@
         // Send back a INVALID_KE_PAYLOAD, and verify that the selected DH group changes
         ReceivedIkePacket resp =
                 makeDummyReceivedIkeInitRespPacket(
-                        1L /*initiator SPI*/,
-                        2L /*responder SPI*/,
-                        IkeHeader.EXCHANGE_TYPE_IKE_SA_INIT,
-                        true /*isResp*/,
-                        false /*fromIkeInit*/,
                         Arrays.asList(IkePayload.PAYLOAD_TYPE_NOTIFY),
                         Arrays.asList(INVALID_KE_PAYLOAD_HEX_STRING));
         mIkeSessionStateMachine.sendMessage(IkeSessionStateMachine.CMD_RECEIVE_IKE_PACKET, resp);
@@ -1472,6 +1470,45 @@
     }
 
     @Test
+    public void testCreateIkeLocalIkeInitReceivesCookie() throws Exception {
+        setupFirstIkeSa();
+
+        mIkeSessionStateMachine.sendMessage(IkeSessionStateMachine.CMD_LOCAL_REQUEST_CREATE_IKE);
+        mLooper.dispatchAll();
+
+        // Encode 2 times: one for mIkeInitRequestBytes and one for sending packets
+        verify(mMockIkeMessageHelper, times(2)).encode(mIkeMessageCaptor.capture());
+        IkeMessage originalReqMsg = mIkeMessageCaptor.getValue();
+        List<IkePayload> originalPayloadList = originalReqMsg.ikePayloadList;
+
+        // Reset to forget sending original IKE INIT request
+        resetMockIkeMessageHelper();
+
+        // Send back a Notify-Cookie
+        IkeNotifyPayload inCookieNotify = new IkeNotifyPayload(NOTIFY_TYPE_COOKIE, COOKIE_DATA);
+        List<IkePayload> payloads = new ArrayList<>();
+        payloads.add(inCookieNotify);
+        ReceivedIkePacket resp = makeDummyReceivedIkeInitRespPacket(payloads);
+        mIkeSessionStateMachine.sendMessage(IkeSessionStateMachine.CMD_RECEIVE_IKE_PACKET, resp);
+        mLooper.dispatchAll();
+
+        // Verify retry IKE INIT request
+        verify(mMockIkeMessageHelper, times(2)).encode(mIkeMessageCaptor.capture());
+        IkeMessage ikeInitReqMessage = mIkeMessageCaptor.getValue();
+        List<IkePayload> payloadList = ikeInitReqMessage.ikePayloadList;
+
+        IkeNotifyPayload outCookieNotify = (IkeNotifyPayload) payloadList.get(0);
+        assertEquals(NOTIFY_TYPE_COOKIE, outCookieNotify.notifyType);
+        assertArrayEquals(COOKIE_DATA, outCookieNotify.notifyData);
+
+        assertEquals(originalPayloadList, payloadList.subList(1, payloadList.size()));
+
+        assertTrue(
+                mIkeSessionStateMachine.getCurrentState()
+                        instanceof IkeSessionStateMachine.CreateIkeLocalIkeInit);
+    }
+
+    @Test
     public void testCreateIkeLocalIkeInitSwitchesToEncapPorts() throws Exception {
         setupFirstIkeSa();
         mIkeSessionStateMachine.sendMessage(IkeSessionStateMachine.CMD_LOCAL_REQUEST_CREATE_IKE);