Cache UdpEncapsulationSocket for MOBIKE Child rekeys.

This CL updates ChildSessionStateMachine to update its cached
UdpEncapsulationSocket and local/remote addresses for a MOBIKE-specific
Child rekey. This is necessary so that the new IpSecTransforms are
created with the current state of the IKE Session.

Bug: 175161459
Test: atest FrameworksIkeTests
Change-Id: I8db9206d0d637244671132de8015c9b9d576e6aa
diff --git a/src/java/com/android/internal/net/ipsec/ike/ChildSessionStateMachine.java b/src/java/com/android/internal/net/ipsec/ike/ChildSessionStateMachine.java
index 3fdc4de..278a0b0 100644
--- a/src/java/com/android/internal/net/ipsec/ike/ChildSessionStateMachine.java
+++ b/src/java/com/android/internal/net/ipsec/ike/ChildSessionStateMachine.java
@@ -439,8 +439,19 @@
      * SAs associated with this state machine. However, the caller is notified of Child SA creation
      * via {@link ChildSessionCallback#onIpSecTransformsMigrated(android.net.IpSecTransform,
      * android.net.IpSecTransform)};
+     *
+     * @param localAddress The local (outer) address from which traffic will originate.
+     * @param remoteAddress The remote (outer) address to which traffic will be sent.
+     * @param udpEncapSocket The socket to use for UDP encapsulation, or NULL if no encap needed.
      */
-    public void rekeyChildSessionForMobike() {
+    public void rekeyChildSessionForMobike(
+            InetAddress localAddress,
+            InetAddress remoteAddress,
+            UdpEncapsulationSocket udpEncapSocket) {
+        this.mLocalAddress = localAddress;
+        this.mRemoteAddress = remoteAddress;
+        this.mUdpEncapSocket = udpEncapSocket;
+
         sendMessage(CMD_LOCAL_REQUEST_REKEY_CHILD_MOBIKE);
     }
 
diff --git a/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachine.java b/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachine.java
index 1163486..60f1c6c 100644
--- a/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachine.java
+++ b/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachine.java
@@ -2487,7 +2487,7 @@
                             childData.respPayloads,
                             mLocalAddress,
                             mRemoteAddress,
-                            getEncapSocketIfNatDetected(),
+                            getEncapSocketOrNull(),
                             mIkePrf,
                             mCurrentIkeSaRecord.getSkD());
                     return HANDLED;
@@ -2538,16 +2538,12 @@
         }
 
         // Returns the UDP-Encapsulation socket to the newly created ChildSessionStateMachine if
-        // a NAT is detected. It allows the ChildSessionStateMachine to build IPsec transforms that
-        // can send and receive IPsec traffic through a NAT.
-        private UdpEncapsulationSocket getEncapSocketIfNatDetected() {
-            boolean isNatDetected = mLocalNatDetected || mRemoteNatDetected;
-
-            if (!isNatDetected) return null;
-
+        // a NAT is detected or if NAT-T AND MOBIKE are enabled by both parties. It allows the
+        // ChildSessionStateMachine to build IPsec transforms that can send and receive IPsec
+        // traffic through a NAT.
+        private UdpEncapsulationSocket getEncapSocketOrNull() {
             if (!(mIkeSocket instanceof IkeUdpEncapSocket)) {
-                throw new IllegalStateException(
-                        "NAT is detected but IKE packet is not UDP-Encapsulated.");
+                return null;
             }
             return ((IkeUdpEncapSocket) mIkeSocket).getUdpEncapsulationSocket();
         }
@@ -2574,7 +2570,7 @@
                     mChildInLocalProcedure.createChildSession(
                             mLocalAddress,
                             mRemoteAddress,
-                            getEncapSocketIfNatDetected(),
+                            getEncapSocketOrNull(),
                             mIkePrf,
                             mCurrentIkeSaRecord.getSkD());
                     break;
@@ -2582,7 +2578,8 @@
                     mChildInLocalProcedure.rekeyChildSession();
                     break;
                 case CMD_LOCAL_REQUEST_REKEY_CHILD_MOBIKE:
-                    mChildInLocalProcedure.rekeyChildSessionForMobike();
+                    mChildInLocalProcedure.rekeyChildSessionForMobike(
+                            mLocalAddress, mRemoteAddress, getEncapSocketOrNull());
                     break;
                 case CMD_LOCAL_REQUEST_DELETE_CHILD:
                     mChildInLocalProcedure.deleteChildSession();
diff --git a/tests/iketests/src/java/com/android/internal/net/ipsec/ike/ChildSessionStateMachineTest.java b/tests/iketests/src/java/com/android/internal/net/ipsec/ike/ChildSessionStateMachineTest.java
index 2e2ba5f..beff4d9 100644
--- a/tests/iketests/src/java/com/android/internal/net/ipsec/ike/ChildSessionStateMachineTest.java
+++ b/tests/iketests/src/java/com/android/internal/net/ipsec/ike/ChildSessionStateMachineTest.java
@@ -143,6 +143,8 @@
 
     private static final Inet4Address LOCAL_ADDRESS =
             (Inet4Address) InetAddresses.parseNumericAddress("192.0.2.200");
+    private static final Inet4Address UPDATED_LOCAL_ADDRESS =
+            (Inet4Address) InetAddresses.parseNumericAddress("192.0.2.201");
     private static final Inet4Address REMOTE_ADDRESS =
             (Inet4Address) InetAddresses.parseNumericAddress("192.0.2.100");
     private static final Inet4Address INTERNAL_ADDRESS =
@@ -400,16 +402,27 @@
             int initSpi,
             int respSpi,
             boolean isLocalInit) {
+        verifyChildSaRecordConfig(
+                childSaRecordConfig, initSpi, respSpi, isLocalInit, LOCAL_ADDRESS, REMOTE_ADDRESS);
+    }
+
+    private void verifyChildSaRecordConfig(
+            ChildSaRecordConfig childSaRecordConfig,
+            int initSpi,
+            int respSpi,
+            boolean isLocalInit,
+            InetAddress localAddress,
+            InetAddress remoteAddress) {
         assertEquals(mContext, childSaRecordConfig.context);
         assertEquals(initSpi, childSaRecordConfig.initSpi.getSpi());
         assertEquals(respSpi, childSaRecordConfig.respSpi.getSpi());
 
         if (isLocalInit) {
-            assertEquals(LOCAL_ADDRESS, childSaRecordConfig.initAddress);
-            assertEquals(REMOTE_ADDRESS, childSaRecordConfig.respAddress);
+            assertEquals(localAddress, childSaRecordConfig.initAddress);
+            assertEquals(remoteAddress, childSaRecordConfig.respAddress);
         } else {
-            assertEquals(REMOTE_ADDRESS, childSaRecordConfig.initAddress);
-            assertEquals(LOCAL_ADDRESS, childSaRecordConfig.respAddress);
+            assertEquals(remoteAddress, childSaRecordConfig.initAddress);
+            assertEquals(localAddress, childSaRecordConfig.respAddress);
         }
 
         assertEquals(mMockUdpEncapSocket, childSaRecordConfig.udpEncapSocket);
@@ -1024,9 +1037,14 @@
     }
 
     private void setupStateMachineAndSpiForLocalRekey() throws Exception {
+        setupStateMachineAndSpiForLocalRekey(LOCAL_ADDRESS, REMOTE_ADDRESS);
+    }
+
+    private void setupStateMachineAndSpiForLocalRekey(
+            InetAddress updatedLocalAddress, InetAddress updatedRemoteAddress) throws Exception {
         setupIdleStateMachine();
-        setUpSpiResource(LOCAL_ADDRESS, LOCAL_INIT_NEW_CHILD_SA_SPI_IN);
-        setUpSpiResource(REMOTE_ADDRESS, LOCAL_INIT_NEW_CHILD_SA_SPI_OUT);
+        setUpSpiResource(updatedLocalAddress, LOCAL_INIT_NEW_CHILD_SA_SPI_IN);
+        setUpSpiResource(updatedRemoteAddress, LOCAL_INIT_NEW_CHILD_SA_SPI_OUT);
     }
 
     @Test
@@ -1038,19 +1056,30 @@
         mLooper.dispatchAll();
 
         verifyRekeyChildLocalCreateHandlesResponse(
-                ChildSessionStateMachine.RekeyChildLocalCreate.class, false /* isMobikeRekey */);
+                ChildSessionStateMachine.RekeyChildLocalCreate.class,
+                false /* isMobikeRekey */,
+                LOCAL_ADDRESS,
+                REMOTE_ADDRESS);
     }
 
     private void verifyRekeyChildLocalCreateHandlesResponse(
-            Class<?> expectedState, boolean isMobikeRekey) throws Exception {
+            Class<?> expectedState,
+            boolean isMobikeRekey,
+            InetAddress localAddress,
+            InetAddress remoteAddress)
+            throws Exception {
         assertTrue(expectedState.isInstance(mChildSessionStateMachine.getCurrentState()));
 
         List<IkePayload> rekeyRespPayloads = receiveRekeyChildResponse();
-        verifyLocalRekeyCreateIsDone(rekeyRespPayloads, isMobikeRekey);
+        verifyLocalRekeyCreateIsDone(rekeyRespPayloads, isMobikeRekey, localAddress, remoteAddress);
     }
 
     private void verifyLocalRekeyCreateIsDone(
-            List<IkePayload> rekeyRespPayloads, boolean isMobikeRekey) throws Exception {
+            List<IkePayload> rekeyRespPayloads,
+            boolean isMobikeRekey,
+            InetAddress localAddress,
+            InetAddress remoteAddress)
+            throws Exception {
         // Verify state transition
         assertTrue(
                 mChildSessionStateMachine.getCurrentState()
@@ -1075,7 +1104,9 @@
                 childSaRecordConfig,
                 LOCAL_INIT_NEW_CHILD_SA_SPI_IN,
                 LOCAL_INIT_NEW_CHILD_SA_SPI_OUT,
-                true /*isLocalInit*/);
+                true /*isLocalInit*/,
+                localAddress,
+                remoteAddress);
 
         // Verify users have been notified
         verify(mSpyUserCbExecutor).execute(any(Runnable.class));
@@ -1136,7 +1167,8 @@
 
         // Receive Rekey Create response and verify creation is done
         List<IkePayload> rekeyRespPayloads = receiveRekeyChildResponse();
-        verifyLocalRekeyCreateIsDone(rekeyRespPayloads, false /* isMobikeRekey */);
+        verifyLocalRekeyCreateIsDone(
+                rekeyRespPayloads, false /* isMobikeRekey */, LOCAL_ADDRESS, REMOTE_ADDRESS);
     }
 
     @Test
@@ -1954,14 +1986,21 @@
 
     @Test
     public void testMobikeRekeyChildLocalCreateHandlesResp() throws Exception {
-        setupStateMachineAndSpiForLocalRekey();
+        setupStateMachineAndSpiForLocalRekey(UPDATED_LOCAL_ADDRESS, REMOTE_ADDRESS);
 
         // Send MOBIKE Rekey-Create request
-        mChildSessionStateMachine.rekeyChildSessionForMobike();
+        mChildSessionStateMachine.rekeyChildSessionForMobike(
+                UPDATED_LOCAL_ADDRESS, REMOTE_ADDRESS, mMockUdpEncapSocket);
         mLooper.dispatchAll();
 
         verifyRekeyChildLocalCreateHandlesResponse(
                 ChildSessionStateMachine.MobikeRekeyChildLocalCreate.class,
-                true /* isMobikeRekey */);
+                true /* isMobikeRekey */,
+                UPDATED_LOCAL_ADDRESS,
+                REMOTE_ADDRESS);
+
+        assertEquals(UPDATED_LOCAL_ADDRESS, mChildSessionStateMachine.mLocalAddress);
+        assertEquals(REMOTE_ADDRESS, mChildSessionStateMachine.mRemoteAddress);
+        assertEquals(mMockUdpEncapSocket, mChildSessionStateMachine.mUdpEncapSocket);
     }
 }
diff --git a/tests/iketests/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachineTest.java b/tests/iketests/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachineTest.java
index 71dc37e..4a7d8ea 100644
--- a/tests/iketests/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachineTest.java
+++ b/tests/iketests/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachineTest.java
@@ -5990,7 +5990,8 @@
         assertTrue(
                 mIkeSessionStateMachine.getCurrentState()
                         instanceof IkeSessionStateMachine.ChildProcedureOngoing);
-        verify(mMockChildSessionStateMachine).rekeyChildSessionForMobike();
+        verify(mMockChildSessionStateMachine)
+                .rekeyChildSessionForMobike(eq(expectedLocalAddr), eq(expectedRemoteAddr), any());
     }
 
     @Test