Merge changes Id4b6122a,Ibac47013 am: 93ca92725e

Original change: https://android-review.googlesource.com/c/platform/packages/modules/IPsec/+/1512481

MUST ONLY BE SUBMITTED BY AUTOMERGER

Change-Id: I929e3f8df43b291213ae12b35f25144558b016b9
diff --git a/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachine.java b/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachine.java
index 84339c9..1163486 100644
--- a/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachine.java
+++ b/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachine.java
@@ -884,6 +884,12 @@
 
     /** Switch all IKE SAs to the new IKE socket due to an underlying network change. */
     private void switchToIkeSocket(IkeSocket newSocket) {
+        // Changing IkeSockets - make sure to quit NAT-T keepalive if it's going
+        if (mIkeNattKeepalive != null) {
+            mIkeNattKeepalive.stop();
+            mIkeNattKeepalive = null;
+        }
+
         long currentLocalSpi = mCurrentIkeSaRecord.getLocalSpi();
         migrateSpiToIkeSocket(currentLocalSpi, mIkeSocket, newSocket);
 
@@ -900,6 +906,46 @@
         mIkeSocket = newSocket;
     }

+    /**
+     * Switches all IKE SAs to a socket on mNetwork that uses the UDP encapsulation port (4500).
+     *
+     * <p>IPv4 sessions move to an IkeUdpEncapSocket and also start NAT-T keepalive on it; IPv6
+     * sessions move to an IkeUdp6WithEncapPortSocket without keepalive. Failures while creating
+     * the socket or the keepalive are passed to handleIkeFatalError() instead of being thrown.
+     *
+     * <p>NOTE(review): buildAndStartNattKeepalive() reads mLocalAddress, but on a Network change
+     * the caller regenerates mLocalAddress only after this method returns, so the keepalive may
+     * use the previous Network's address, or fail its IPv4 cast if that address was IPv6 -
+     * confirm and reorder. Callers also keep reading mIkeSocket after a fatal error here.
+     *
+     * @param isIpv4 true for the IPv4 UDP encap socket, false for the IPv6 encap-port socket
+     */
+    private void buildAndSwitchToIkeSocketWithPort4500(boolean isIpv4) {
+        try {
+            final IkeSocket newSocket;
+            if (isIpv4) {
+                newSocket =
+                        IkeUdpEncapSocket.getIkeUdpEncapSocket(
+                                mNetwork,
+                                mIpSecManager,
+                                IkeSessionStateMachine.this,
+                                getHandler().getLooper());
+            } else {
+                newSocket =
+                        IkeUdp6WithEncapPortSocket.getInstance(
+                                mNetwork, IkeSessionStateMachine.this, getHandler());
+            }
+            switchToIkeSocket(newSocket);
+
+            // NAT-T keepalive only applies to the IPv4 UDP encapsulation socket
+            if (isIpv4) {
+                mIkeNattKeepalive = buildAndStartNattKeepalive();
+            }
+        } catch (ErrnoException | IOException | ResourceUnavailableException e) {
+            handleIkeFatalError(e);
+        }
+    }
+
     private void migrateSpiToIkeSocket(long localSpi, IkeSocket oldSocket, IkeSocket newSocket) {
         newSocket.registerIke(localSpi, IkeSessionStateMachine.this);
         oldSocket.unregisterIke(localSpi);
@@ -3230,20 +3258,11 @@
                                     IkeSessionStateMachine.this,
                                     getHandler().getLooper());
                     switchToIkeSocket(initIkeSpi, newSocket);
+                    mIkeNattKeepalive = buildAndStartNattKeepalive();
+                    mLocalPort = mIkeSocket.getLocalPort();
                 } catch (ErrnoException | IOException | ResourceUnavailableException e) {
                     handleIkeFatalError(e);
                 }
-
-                mIkeNattKeepalive =
-                        new IkeNattKeepalive(
-                                mContext,
-                                NATT_KEEPALIVE_DELAY_SECONDS,
-                                (Inet4Address) mLocalAddress,
-                                (Inet4Address) mRemoteAddress,
-                                ((IkeUdpEncapSocket) mIkeSocket).getUdpEncapsulationSocket(),
-                                mIkeSocket.getNetwork(),
-                                buildKeepaliveIntent());
-                mIkeNattKeepalive.start();
             }
         }
 
@@ -3254,14 +3273,6 @@
             mIkeSocket = newSocket;
         }
 
-        private PendingIntent buildKeepaliveIntent() {
-            return buildIkeAlarmIntent(
-                    mContext,
-                    ACTION_KEEPALIVE,
-                    getIntentIdentifier(),
-                    obtainMessage(CMD_ALARM_FIRED, mIkeSessionId, CMD_SEND_KEEPALIVE));
-        }
-
         @Override
         public void exitState() {
             super.exitState();
@@ -3339,6 +3350,53 @@
         }
     }

+    /**
+     * Builds and starts NAT-T keepalive for the current IkeUdpEncapSocket.
+     *
+     * <p>Keepalive packets are addressed from mLocalAddress to mRemoteAddress on the socket's
+     * Network; the keepalive PendingIntent wraps a CMD_ALARM_FIRED / CMD_SEND_KEEPALIVE message
+     * for this session.
+     *
+     * @return the started keepalive; callers store it in mIkeNattKeepalive
+     * @throws IllegalStateException if mIkeSocket is not an IkeUdpEncapSocket, or if either
+     *     mLocalAddress or mRemoteAddress is not an IPv4 address
+     * @throws IOException if setting up the keepalive fails
+     */
+    private IkeNattKeepalive buildAndStartNattKeepalive() throws IOException {
+        if (!(mIkeSocket instanceof IkeUdpEncapSocket)) {
+            throw new IllegalStateException(
+                    "Cannot start NAT-T keepalive when IKE Session is not using UDP Encap socket");
+        }
+
+        // Check the address family explicitly so a stale non-IPv4 address (e.g. one left over
+        // from a previous Network) fails with a clear error rather than a ClassCastException.
+        if (!(mLocalAddress instanceof Inet4Address)
+                || !(mRemoteAddress instanceof Inet4Address)) {
+            throw new IllegalStateException(
+                    "NAT-T keepalive requires IPv4 local and remote addresses");
+        }
+
+        PendingIntent keepaliveIntent =
+                buildIkeAlarmIntent(
+                        mContext,
+                        ACTION_KEEPALIVE,
+                        getIntentIdentifier(),
+                        obtainMessage(CMD_ALARM_FIRED, mIkeSessionId, CMD_SEND_KEEPALIVE));
+
+        IkeNattKeepalive keepalive =
+                new IkeNattKeepalive(
+                        mContext,
+                        mConnectivityManager,
+                        NATT_KEEPALIVE_DELAY_SECONDS,
+                        (Inet4Address) mLocalAddress,
+                        (Inet4Address) mRemoteAddress,
+                        ((IkeUdpEncapSocket) mIkeSocket).getUdpEncapsulationSocket(),
+                        mIkeSocket.getNetwork(),
+                        keepaliveIntent);
+        keepalive.start();
+        return keepalive;
+    }
+
     /**
      * CreateIkeLocalIkeAuthBase represents the common state and functionality required to perform
      * IKE AUTH exchanges in both the EAP and non-EAP flows.
@@ -3490,8 +3529,6 @@
                 mSupportMobike = true;
                 mEnabledExtensions.add(EXTENSION_TYPE_MOBIKE);
 
-                // TODO(b/173237734): use port 4500 if NAT-T is enabled
-
                 try {
                     if (mIkeSessionParams.getConfiguredNetwork() != null) {
                         // Caller configured a specific Network - track it
@@ -3512,6 +3549,17 @@
                     // Error occurred while registering the NetworkCallback
                     throw new IkeInternalException("Error while registering NetworkCallback", e);
                 }
+
+                // Use port 4500 if NAT-T is enabled
+                if (mSupportNatTraversal && !(mIkeSocket instanceof IkeUdpEncapSocket)) {
+                    buildAndSwitchToIkeSocketWithPort4500(mIkeSocket instanceof IkeUdp4Socket);
+                    try {
+                        mLocalPort = mIkeSocket.getLocalPort();
+                    } catch (ErrnoException e) {
+                        throw new IkeInternalException(e);
+                    }
+                }
+
                 return;
             } else {
                 // Unknown and unexpected status notifications are ignored as per
@@ -5372,27 +5420,24 @@
             // Only switch the IkeSocket if the underlying Network actually changes. This may not
             // always happen (ex: the underlying Network loses the current local address)
             if (!mNetwork.equals(oldNetwork)) {
-                // Changing IkeSockets - make sure to quit NAT-T keepalive if it's going
-                if (mIkeNattKeepalive != null) {
-                    mIkeNattKeepalive.stop();
-                    mIkeNattKeepalive = null;
-                }
-
-                IkeSocket newSocket;
-                // TODO(b/173237734): use port 4500 if NAT-T is enabled
-                if (isIpv4) {
-                    newSocket =
-                            IkeUdp4Socket.getInstance(
-                                    mNetwork, IkeSessionStateMachine.this, getHandler());
+                // Use port 4500 if NAT-T is supported by both sides
+                if (mSupportNatTraversal) {
+                    buildAndSwitchToIkeSocketWithPort4500(isIpv4);
                 } else {
-                    newSocket =
-                            IkeUdp6Socket.getInstance(
-                                    mNetwork, IkeSessionStateMachine.this, getHandler());
+                    IkeSocket newSocket;
+                    if (isIpv4) {
+                        newSocket =
+                                IkeUdp4Socket.getInstance(
+                                        mNetwork, IkeSessionStateMachine.this, getHandler());
+                    } else {
+                        newSocket =
+                                IkeUdp6Socket.getInstance(
+                                        mNetwork, IkeSessionStateMachine.this, getHandler());
+                    }
+                    switchToIkeSocket(newSocket);
                 }
-                switchToIkeSocket(newSocket);
-                mLocalPort = mIkeSocket.getLocalPort();
             }
-
+            mLocalPort = mIkeSocket.getLocalPort();
             mLocalAddress =
                     mIkeLocalAddressGenerator.generateLocalAddress(
                             mNetwork, isIpv4, mRemoteAddress, mIkeSocket.getIkeServerPort());
diff --git a/src/java/com/android/internal/net/ipsec/ike/keepalive/HardwareKeepaliveImpl.java b/src/java/com/android/internal/net/ipsec/ike/keepalive/HardwareKeepaliveImpl.java
index bb0abe5..28ea3b3 100644
--- a/src/java/com/android/internal/net/ipsec/ike/keepalive/HardwareKeepaliveImpl.java
+++ b/src/java/com/android/internal/net/ipsec/ike/keepalive/HardwareKeepaliveImpl.java
@@ -49,6 +49,7 @@
     /** Construct an instance of HardwareKeepaliveImpl */
     public HardwareKeepaliveImpl(
             Context context,
+            ConnectivityManager connectMgr,
             int keepaliveDelaySeconds,
             Inet4Address src,
             Inet4Address dest,
@@ -61,10 +62,8 @@
         mKeepaliveDelaySeconds = keepaliveDelaySeconds;
         mHardwareKeepaliveCb = hardwareKeepaliveCb;
 
-        ConnectivityManager connMgr =
-                (ConnectivityManager) context.getSystemService(Context.CONNECTIVITY_SERVICE);
         mSocketKeepalive =
-                connMgr.createSocketKeepalive(
+                connectMgr.createSocketKeepalive(
                         network,
                         socket,
                         src,
diff --git a/src/java/com/android/internal/net/ipsec/ike/keepalive/IkeNattKeepalive.java b/src/java/com/android/internal/net/ipsec/ike/keepalive/IkeNattKeepalive.java
index 129a1f0..00c84c7 100644
--- a/src/java/com/android/internal/net/ipsec/ike/keepalive/IkeNattKeepalive.java
+++ b/src/java/com/android/internal/net/ipsec/ike/keepalive/IkeNattKeepalive.java
@@ -20,6 +20,7 @@
 
 import android.app.PendingIntent;
 import android.content.Context;
+import android.net.ConnectivityManager;
 import android.net.IpSecManager.UdpEncapsulationSocket;
 import android.net.Network;
 
@@ -40,6 +41,7 @@
     /** Construct an instance of IkeNattKeepalive */
     public IkeNattKeepalive(
             Context context,
+            ConnectivityManager connectMgr,
             int keepaliveDelaySeconds,
             Inet4Address src,
             Inet4Address dest,
@@ -50,6 +52,7 @@
         mNattKeepalive =
                 new HardwareKeepaliveImpl(
                         context,
+                        connectMgr,
                         keepaliveDelaySeconds,
                         src,
                         dest,
diff --git a/tests/iketests/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachineTest.java b/tests/iketests/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachineTest.java
index e662294..71dc37e 100644
--- a/tests/iketests/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachineTest.java
+++ b/tests/iketests/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachineTest.java
@@ -748,6 +748,9 @@
         when(mMockIkeLocalAddressGenerator.generateLocalAddress(
                         eq(mMockDefaultNetwork), eq(true /* isIpv4 */), any(), anyInt()))
                 .thenReturn(LOCAL_ADDRESS);
+        when(mMockIkeLocalAddressGenerator.generateLocalAddress(
+                        eq(mMockDefaultNetwork), eq(false /* isIpv4 */), any(), anyInt()))
+                .thenReturn(LOCAL_ADDRESS_V6);
 
         mMockEapAuthenticatorFactory = mock(IkeEapAuthenticatorFactory.class);
         mMockEapAuthenticator = mock(EapAuthenticator.class);
@@ -2316,7 +2319,10 @@
                         respIdPayload,
                         authRelatedPayloads,
                         hasChildPayloads,
-                        hasConfigPayloadInResp);
+                        hasConfigPayloadInResp,
+                        false /* isMobikeEnabled */,
+                        true /* isIpv4 */,
+                        0 /* ike3gppCallbackInvocations */);
 
         verify(spyAuthPayload)
                 .verifyInboundSignature(
@@ -2358,6 +2364,27 @@
             boolean isMobikeEnabled,
             int ike3gppCallbackInvocations)
             throws Exception {
+        return verifySharedKeyAuthentication(
+                spyAuthPayload,
+                respIdPayload,
+                authRelatedPayloads,
+                hasChildPayloads,
+                hasConfigPayloadInResp,
+                isMobikeEnabled,
+                true /* isIpv4 */,
+                ike3gppCallbackInvocations);
+    }
+
+    private IkeMessage verifySharedKeyAuthentication(
+            IkeAuthPskPayload spyAuthPayload,
+            IkeIdPayload respIdPayload,
+            List<IkePayload> authRelatedPayloads,
+            boolean hasChildPayloads,
+            boolean hasConfigPayloadInResp,
+            boolean isMobikeEnabled,
+            boolean isIpv4,
+            int ike3gppCallbackInvocations)
+            throws Exception {
         IkeMessage ikeAuthReqMessage =
                 verifyAuthenticationCommonAndGetIkeMessage(
                         respIdPayload,
@@ -2365,6 +2392,7 @@
                         hasChildPayloads,
                         hasConfigPayloadInResp,
                         isMobikeEnabled,
+                        isIpv4,
                         ike3gppCallbackInvocations);
 
         // Validate authentication is done. Cannot use matchers because IkeAuthPskPayload is final.
@@ -2388,23 +2416,9 @@
             IkeIdPayload respIdPayload,
             List<IkePayload> authRelatedPayloads,
             boolean hasChildPayloads,
-            boolean hasConfigPayloadInResp)
-            throws Exception {
-        return verifyAuthenticationCommonAndGetIkeMessage(
-                respIdPayload,
-                authRelatedPayloads,
-                hasChildPayloads,
-                hasConfigPayloadInResp,
-                false /* isMobikeEnabled */,
-                0 /* ike3gppCallbackInvocations */);
-    }
-
-    private IkeMessage verifyAuthenticationCommonAndGetIkeMessage(
-            IkeIdPayload respIdPayload,
-            List<IkePayload> authRelatedPayloads,
-            boolean hasChildPayloads,
             boolean hasConfigPayloadInResp,
             boolean isMobikeEnabled,
+            boolean isIpv4,
             int ike3gppCallbackInvocations)
             throws Exception {
         // Send IKE AUTH response to IKE state machine
@@ -2470,8 +2484,12 @@
                 sessionConfig.isIkeExtensionEnabled(IkeSessionConfiguration.EXTENSION_TYPE_MOBIKE));
 
         IkeSessionConnectionInfo ikeConnInfo = sessionConfig.getIkeSessionConnectionInfo();
-        assertEquals(LOCAL_ADDRESS, ikeConnInfo.getLocalAddress());
-        assertEquals(REMOTE_ADDRESS, ikeConnInfo.getRemoteAddress());
+
+        InetAddress expectedLocalAddress = isIpv4 ? LOCAL_ADDRESS : LOCAL_ADDRESS_V6;
+        InetAddress expectedRemoteAddress = isIpv4 ? REMOTE_ADDRESS : REMOTE_ADDRESS_V6;
+
+        assertEquals(expectedLocalAddress, ikeConnInfo.getLocalAddress());
+        assertEquals(expectedRemoteAddress, ikeConnInfo.getRemoteAddress());
         assertEquals(mMockDefaultNetwork, ikeConnInfo.getNetwork());
 
         // Verify payload list pair for first Child negotiation
@@ -2483,8 +2501,8 @@
                 .handleFirstChildExchange(
                         mReqPayloadListCaptor.capture(),
                         mRespPayloadListCaptor.capture(),
-                        eq(LOCAL_ADDRESS),
-                        eq(REMOTE_ADDRESS),
+                        eq(expectedLocalAddress),
+                        eq(expectedRemoteAddress),
                         any(), // udpEncapSocket
                         eq(mIkeSessionStateMachine.mIkePrf),
                         any()); // sk_d
@@ -5311,6 +5329,37 @@
         assertTrue(mIkeSessionStateMachine.mSupportMobike);
     }

+    @Test
+    public void testMobikeEnabledNattSupportedIpv4() throws Exception {
+        verifyMobikeEnabledAndKillSession(true /* doesPeerSupportNatt */, true /* isIpv4 */);
+    }
+
+    @Test
+    public void testMobikeEnabledNattUnsupportedIpv4() throws Exception {
+        verifyMobikeEnabledAndKillSession(false /* doesPeerSupportNatt */, true /* isIpv4 */);
+    }
+
+    @Test
+    public void testMobikeEnabledNattSupportedIpv6() throws Exception {
+        verifyMobikeEnabledAndKillSession(true /* doesPeerSupportNatt */, false /* isIpv4 */);
+    }
+
+    @Test
+    public void testMobikeEnabledNattUnsupportedIpv6() throws Exception {
+        verifyMobikeEnabledAndKillSession(false /* doesPeerSupportNatt */, false /* isIpv4 */);
+    }
+
+    /**
+     * Activates MOBIKE for the given peer NAT-T support and IP family, then kills the Session
+     * and expects its NetworkCallback to be unregistered.
+     */
+    private void verifyMobikeEnabledAndKillSession(boolean doesPeerSupportNatt, boolean isIpv4)
+            throws Exception {
+        verifyMobikeEnabled(doesPeerSupportNatt, isIpv4);
+
+        killSessionAndVerifyNetworkCallback(true /* expectCallbackUnregistered */);
+    }
+
     /**
      * Restarts the IkeSessionStateMachine with MOBIKE enabled. If doesPeerSupportMobike, MOBIKE
      * will be active for the Session.
@@ -5328,9 +5374,75 @@
     @Nullable
     private IkeNetworkCallbackBase verifyMobikeEnabled(
             boolean doesPeerSupportMobike, Network configuredNetwork) throws Exception {
+        // Original signature: keeps assuming a NAT-T capable peer and an IPv4 session
+        final boolean doesPeerSupportNatt = true;
+        final boolean isIpv4 = true;
+        return verifyMobikeEnabled(
+                doesPeerSupportMobike, doesPeerSupportNatt, isIpv4, configuredNetwork);
+    }
+
+    @Nullable
+    private IkeDefaultNetworkCallback verifyMobikeEnabled(
+            boolean doesPeerSupportNatt, boolean isIpv4) throws Exception {
+        final IkeNetworkCallbackBase callback =
+                verifyMobikeEnabled(
+                        true /* doesPeerSupportMobike */,
+                        doesPeerSupportNatt,
+                        isIpv4,
+                        null /* configuredNetwork */);
+        return (IkeDefaultNetworkCallback) callback; // no configured Network => default callback
+    }
+
+    /**
+     * Returns the IkeSocket type expected once MOBIKE is supported by both sides.
+     *
+     * <p>With a NAT-T capable peer the session uses the encapsulation port (4500): an
+     * IkeUdpEncapSocket on IPv4 or an IkeUdp6WithEncapPortSocket on IPv6. Otherwise it uses the
+     * plain IkeUdp4Socket / IkeUdp6Socket.
+     */
+    private Class<? extends IkeSocket> getExpectedSocketType(
+            boolean doesPeerSupportNatt, boolean isIpv4) {
+        final Class<? extends IkeSocket> expected;
+        if (isIpv4) {
+            expected = doesPeerSupportNatt ? IkeUdpEncapSocket.class : IkeUdp4Socket.class;
+        } else {
+            expected = doesPeerSupportNatt ? IkeUdp6WithEncapPortSocket.class : IkeUdp6Socket.class;
+        }
+        return expected;
+    }
+
+    @Nullable
+    private IkeNetworkCallbackBase verifyMobikeEnabled(
+            boolean doesPeerSupportMobike,
+            boolean doesPeerSupportNatt,
+            boolean isIpv4,
+            Network configuredNetwork)
+            throws Exception {
         mIkeSessionStateMachine = restartStateMachineWithMobikeConfigured(configuredNetwork);
         mockIkeInitAndTransitionToIkeAuth(mIkeSessionStateMachine.mCreateIkeLocalIkeAuth);
 
+        // IKE client always supports NAT-T. So the peer decides if both sides support NAT-T.
+        mIkeSessionStateMachine.mSupportNatTraversal = doesPeerSupportNatt;
+        mIkeSessionStateMachine.mLocalAddress = isIpv4 ? LOCAL_ADDRESS : LOCAL_ADDRESS_V6;
+        mIkeSessionStateMachine.mRemoteAddress = isIpv4 ? REMOTE_ADDRESS : REMOTE_ADDRESS_V6;
+
+        if (doesPeerSupportNatt && isIpv4) {
+            // Assume NATs are detected on both sides
+            mIkeSessionStateMachine.mLocalNatDetected = true;
+            mIkeSessionStateMachine.mRemoteNatDetected = true;
+
+            mIkeSessionStateMachine.mIkeSocket = mSpyIkeUdpEncapSocket;
+        } else {
+            mIkeSessionStateMachine.mLocalNatDetected = false;
+            mIkeSessionStateMachine.mRemoteNatDetected = false;
+
+            if (isIpv4) {
+                mIkeSessionStateMachine.mIkeSocket = mSpyIkeUdp4Socket;
+            } else {
+                mIkeSessionStateMachine.mIkeSocket = mSpyIkeUdp6Socket;
+            }
+        }
+
         // Build IKE AUTH response. Include MOBIKE_SUPPORTED if doesPeerSupportMobike is true
         List<IkePayload> authRelatedPayloads = new ArrayList<>();
         IkeAuthPskPayload spyAuthPayload = makeSpyRespPskPayload();
@@ -5352,6 +5464,7 @@
                         true /* hasChildPayloads */,
                         true /* hasConfigPayloadInResp */,
                         doesPeerSupportMobike,
+                        isIpv4,
                         0 /* ike3gppCallbackInvocations */);
         verifyRetransmissionStopped();
 
@@ -5389,6 +5502,9 @@
                             ? IkeDefaultNetworkCallback.class
                             : IkeSpecificNetworkCallback.class;
             assertTrue(expectedCallbackType.isInstance(networkCallback));
+            assertTrue(
+                    getExpectedSocketType(doesPeerSupportNatt, isIpv4)
+                            .isInstance(mIkeSessionStateMachine.mIkeSocket));
         }
         return networkCallback;
     }
@@ -5414,7 +5530,7 @@
         // makeAndStartIkeSession() expects no use of ConnectivityManager#getActiveNetwork when
         // there is a configured Network. Use reset() to forget usage in setUp()
         if (configuredNetwork != null) {
-            reset(mMockConnectManager);
+            resetMockConnectManager();
         }
 
         setupChildStateMachineFactory(mMockChildSessionStateMachine);
@@ -5466,7 +5582,7 @@
     public void testMobikeActiveMobilityEvent() throws Exception {
         IkeDefaultNetworkCallback callback = verifyMobikeEnabled(true /* doesPeerSupportMobike */);
 
-        Network newNetwork = mockNewNetworkAndAddress();
+        Network newNetwork = mockNewNetworkAndAddress(true /* isIpv4 */);
 
         callback.onAvailable(newNetwork);
         mLooper.dispatchAll();
@@ -5474,20 +5590,19 @@
         verifyNetworkAndLocalAddress(newNetwork, UPDATED_LOCAL_ADDRESS, REMOTE_ADDRESS, callback);
         verify(mMockIkeLocalAddressGenerator)
                 .generateLocalAddress(
-                        eq(newNetwork),
-                        eq(true /* isIpv4 */),
-                        eq(REMOTE_ADDRESS),
-                        eq(IkeSocket.SERVER_PORT_NON_UDP_ENCAPSULATED));
+                        eq(newNetwork), eq(true /* isIpv4 */), eq(REMOTE_ADDRESS), anyInt());
     }
 
-    private Network mockNewNetworkAndAddress() throws Exception {
+    private Network mockNewNetworkAndAddress(boolean isIpv4) throws Exception {
         Network newNetwork = mock(Network.class);
+        // Stub the updated local address for this family (any port: IKE may be on 500 or 4500)
+        InetAddress expectedRemoteAddress = isIpv4 ? REMOTE_ADDRESS : REMOTE_ADDRESS_V6;
+        InetAddress injectedLocalAddress =
+                isIpv4 ? UPDATED_LOCAL_ADDRESS : UPDATED_LOCAL_ADDRESS_V6;
         when(mMockIkeLocalAddressGenerator.generateLocalAddress(
-                        eq(newNetwork),
-                        eq(true /* isIpv4 */),
-                        eq(REMOTE_ADDRESS),
-                        eq(IkeSocket.SERVER_PORT_NON_UDP_ENCAPSULATED)))
-                .thenReturn(UPDATED_LOCAL_ADDRESS);
+                        eq(newNetwork), eq(isIpv4), eq(expectedRemoteAddress), anyInt()))
+                .thenReturn(injectedLocalAddress);
+
         return newNetwork;
     }
 
@@ -5540,18 +5655,30 @@
     private void verifySetNetwork(
             IkeNetworkCallbackBase callback, IkeSaRecord rekeySaRecord, State expectedState)
             throws Exception {
-        Network newNetwork = mockNewNetworkAndAddress();
+        // Existing callers keep the original IPv4-only expectations
+        verifySetNetwork(callback, rekeySaRecord, expectedState, true /* isIpv4 */);
+    }
+
+    private void verifySetNetwork(
+            IkeNetworkCallbackBase callback,
+            IkeSaRecord rekeySaRecord,
+            State expectedState,
+            boolean isIpv4)
+            throws Exception {
+        Network newNetwork = mockNewNetworkAndAddress(isIpv4);
 
         mIkeSessionStateMachine.setNetwork(newNetwork);
         mLooper.dispatchAll();
 
-        verifyNetworkAndLocalAddress(newNetwork, UPDATED_LOCAL_ADDRESS, REMOTE_ADDRESS, callback);
+        InetAddress expectedUpdatedLocalAddress =
+                isIpv4 ? UPDATED_LOCAL_ADDRESS : UPDATED_LOCAL_ADDRESS_V6;
+        InetAddress expectedRemoteAddress = isIpv4 ? REMOTE_ADDRESS : REMOTE_ADDRESS_V6;
+
+        verifyNetworkAndLocalAddress(
+                newNetwork, expectedUpdatedLocalAddress, expectedRemoteAddress, callback);
         verify(mMockIkeLocalAddressGenerator)
                 .generateLocalAddress(
-                        eq(newNetwork),
-                        eq(true /* isIpv4 */),
-                        eq(REMOTE_ADDRESS),
-                        eq(IkeSocket.SERVER_PORT_NON_UDP_ENCAPSULATED));
+                        eq(newNetwork), eq(isIpv4), eq(expectedRemoteAddress), anyInt());
 
         assertEquals(
                 mIkeSessionStateMachine,
@@ -5559,7 +5685,7 @@
                         mIkeSessionStateMachine.mCurrentIkeSaRecord.getLocalSpi()));
 
         if (rekeySaRecord != null) {
-            verifyIkeSaAddresses(rekeySaRecord, UPDATED_LOCAL_ADDRESS, REMOTE_ADDRESS);
+            verifyIkeSaAddresses(rekeySaRecord, expectedUpdatedLocalAddress, expectedRemoteAddress);
             assertEquals(
                     mIkeSessionStateMachine,
                     mIkeSessionStateMachine.mIkeSocket.mSpiToIkeSession.get(
@@ -5570,8 +5696,18 @@
     }

     private IkeNetworkCallbackBase setupIdleStateMachineWithMobike() throws Exception {
+        // Default scenario: NAT-T capable peer and an IPv4 session
+        return setupIdleStateMachineWithMobike(true /* doesPeerSupportNatt */, true /* isIpv4 */);
+    }
+
+    private IkeNetworkCallbackBase setupIdleStateMachineWithMobike(
+            boolean doesPeerSupportNatt, boolean isIpv4) throws Exception {
         IkeNetworkCallbackBase callback =
-                verifyMobikeEnabled(true /* doesPeerSupportMobike */, mMockDefaultNetwork);
+                verifyMobikeEnabled(
+                        true /* doesPeerSupportMobike */,
+                        doesPeerSupportNatt,
+                        isIpv4,
+                        mMockDefaultNetwork);
 
         // reset IkeMessageHelper to make verifying outbound req easier
         resetMockIkeMessageHelper();
@@ -5587,12 +5722,43 @@
         return callback;
     }

-    @Test
-    public void testSetNetworkIdleState() throws Exception {
-        IkeNetworkCallbackBase callback = setupIdleStateMachineWithMobike();
+    /**
+     * Sets a new Network on an idle Session with MOBIKE active, then checks the migrated
+     * addresses and that the IkeSocket type matches the peer's NAT-T support and IP family.
+     */
+    private void verifySetNetworkIdleState(boolean doesPeerSupportNatt, boolean isIpv4)
+            throws Exception {
+        IkeNetworkCallbackBase callback =
+                setupIdleStateMachineWithMobike(doesPeerSupportNatt, isIpv4);

         verifySetNetwork(
-                callback, null /* rekeySaRecord */, mIkeSessionStateMachine.mMobikeLocalInfo);
+                callback,
+                null /* rekeySaRecord */,
+                mIkeSessionStateMachine.mMobikeLocalInfo,
+                isIpv4);
+        assertTrue(
+                getExpectedSocketType(doesPeerSupportNatt, isIpv4)
+                        .isInstance(mIkeSessionStateMachine.mIkeSocket));
+    }
+
+    @Test
+    public void testSetNetworkIdleStateNattSupportedIpv4() throws Exception {
+        verifySetNetworkIdleState(true /* doesPeerSupportNatt */, true /* isIpv4 */);
+    }
+
+    @Test
+    public void testSetNetworkIdleStateNattSupportedIpv6() throws Exception {
+        verifySetNetworkIdleState(true /* doesPeerSupportNatt */, false /* isIpv4 */);
+    }
+
+    @Test
+    public void testSetNetworkIdleStateNattUnsupportedIpv4() throws Exception {
+        verifySetNetworkIdleState(false /* doesPeerSupportNatt */, true /* isIpv4 */);
+    }
+
+    @Test
+    public void testSetNetworkIdleStateNattUnsupportedIpv6() throws Exception {
+        verifySetNetworkIdleState(false /* doesPeerSupportNatt */, false /* isIpv4 */);
     }
 
     @Test
@@ -5829,8 +5991,6 @@
                 mIkeSessionStateMachine.getCurrentState()
                         instanceof IkeSessionStateMachine.ChildProcedureOngoing);
         verify(mMockChildSessionStateMachine).rekeyChildSessionForMobike();
-
-        // TODO(b/173237734): check IkeSocket - if includeNatDetection then expect UdpEncap
     }
 
     @Test
@@ -5840,8 +6000,11 @@
         verifySetNetwork(
                 callback, null /* rekeySaRecord */, mIkeSessionStateMachine.mMobikeLocalInfo);
 
+        // Keepalive for the old UDP encap socket stopped
         verify(mMockIkeNattKeepalive).stop();
-        assertNull(mIkeSessionStateMachine.mIkeNattKeepalive);
+
+        // Keepalive for the new UDP encap socket started
+        assertNotNull(mIkeSessionStateMachine.mIkeNattKeepalive);
     }
 
     @Test
diff --git a/tests/iketests/src/java/com/android/internal/net/ipsec/ike/IkeSessionTestBase.java b/tests/iketests/src/java/com/android/internal/net/ipsec/ike/IkeSessionTestBase.java
index 0daeb9d..5d14a04 100644
--- a/tests/iketests/src/java/com/android/internal/net/ipsec/ike/IkeSessionTestBase.java
+++ b/tests/iketests/src/java/com/android/internal/net/ipsec/ike/IkeSessionTestBase.java
@@ -27,6 +27,7 @@
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.eq;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.reset;
 import static org.mockito.Mockito.spy;
 
 import android.content.Context;
@@ -48,6 +49,7 @@
 import org.junit.Before;
 
 import java.net.Inet4Address;
+import java.net.Inet6Address;
 import java.util.concurrent.Executor;
 
 public abstract class IkeSessionTestBase {
@@ -57,6 +59,12 @@
             (Inet4Address) InetAddresses.parseNumericAddress("192.0.2.201");
     protected static final Inet4Address REMOTE_ADDRESS =
             (Inet4Address) InetAddresses.parseNumericAddress("127.0.0.1");
+    protected static final Inet6Address LOCAL_ADDRESS_V6 =
+            (Inet6Address) InetAddresses.parseNumericAddress("2001:db8::200");
+    protected static final Inet6Address UPDATED_LOCAL_ADDRESS_V6 =
+            (Inet6Address) InetAddresses.parseNumericAddress("2001:db8::201");
+    protected static final Inet6Address REMOTE_ADDRESS_V6 =
+            (Inet6Address) InetAddresses.parseNumericAddress("::1");
     protected static final String REMOTE_HOSTNAME = "ike.test.android.com";
 
     protected PowerManager.WakeLock mMockBusyWakelock;
@@ -100,15 +108,29 @@
                 .when(mPowerManager)
                 .newWakeLock(anyInt(), argThat(tag -> tag.contains(LOCAL_REQUEST_WAKE_LOCK_TAG)));
 
-        mMockConnectManager = mock(ConnectivityManager.class);
         mMockDefaultNetwork = mock(Network.class);
-        doReturn(mMockDefaultNetwork).when(mMockConnectManager).getActiveNetwork();
         doReturn(REMOTE_ADDRESS).when(mMockDefaultNetwork).getByName(REMOTE_HOSTNAME);
         doReturn(REMOTE_ADDRESS)
                 .when(mMockDefaultNetwork)
                 .getByName(REMOTE_ADDRESS.getHostAddress());
 
         mMockSocketKeepalive = mock(SocketKeepalive.class);
+
+        mMockNetworkCapabilities = mock(NetworkCapabilities.class);
+        doReturn(false)
+                .when(mMockNetworkCapabilities)
+                .hasTransport(RandomnessFactory.TRANSPORT_TEST);
+
+        mMockConnectManager = mock(ConnectivityManager.class);
+        doReturn(mMockConnectManager)
+                .when(mSpyContext)
+                .getSystemService(Context.CONNECTIVITY_SERVICE);
+        resetMockConnectManager();
+    }
+
+    protected void resetMockConnectManager() {
+        reset(mMockConnectManager);
+        doReturn(mMockDefaultNetwork).when(mMockConnectManager).getActiveNetwork();
         doReturn(mMockSocketKeepalive)
                 .when(mMockConnectManager)
                 .createSocketKeepalive(
@@ -118,16 +140,8 @@
                         any(Inet4Address.class),
                         any(Executor.class),
                         any(SocketKeepalive.Callback.class));
-        doReturn(mMockConnectManager)
-                .when(mSpyContext)
-                .getSystemService(Context.CONNECTIVITY_SERVICE);
-
-        mMockNetworkCapabilities = mock(NetworkCapabilities.class);
         doReturn(mMockNetworkCapabilities)
                 .when(mMockConnectManager)
                 .getNetworkCapabilities(any(Network.class));
-        doReturn(false)
-                .when(mMockNetworkCapabilities)
-                .hasTransport(RandomnessFactory.TRANSPORT_TEST);
     }
 }