Merge "Add support for IKE negotiation of DH groups"
diff --git a/src/java/android/net/ipsec/ike/SaProposal.java b/src/java/android/net/ipsec/ike/SaProposal.java
index 53863cd..fa6fbf1 100644
--- a/src/java/android/net/ipsec/ike/SaProposal.java
+++ b/src/java/android/net/ipsec/ike/SaProposal.java
@@ -19,7 +19,6 @@
 import android.annotation.IntDef;
 import android.annotation.NonNull;
 import android.annotation.SystemApi;
-import android.util.ArraySet;
 import android.util.Pair;
 import android.util.SparseArray;
 
@@ -34,9 +33,9 @@
 import java.lang.annotation.RetentionPolicy;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.LinkedHashSet;
 import java.util.LinkedList;
 import java.util.List;
-import java.util.Set;
 
 /**
  * SaProposal represents a proposed configuration to negotiate an IKE or Child SA.
@@ -348,11 +347,13 @@
     protected abstract static class Builder {
         protected static final String ERROR_TAG = "Invalid SA Proposal: ";
 
-        // Use set to avoid adding repeated algorithms.
-        protected final Set<EncryptionTransform> mProposedEncryptAlgos = new ArraySet<>();
-        protected final Set<PrfTransform> mProposedPrfs = new ArraySet<>();
-        protected final Set<IntegrityTransform> mProposedIntegrityAlgos = new ArraySet<>();
-        protected final Set<DhGroupTransform> mProposedDhGroups = new ArraySet<>();
+        // Use LinkedHashSet to ensure uniqueness while preserving insertion order.
+        protected final LinkedHashSet<EncryptionTransform> mProposedEncryptAlgos =
+                new LinkedHashSet<>();
+        protected final LinkedHashSet<PrfTransform> mProposedPrfs = new LinkedHashSet<>();
+        protected final LinkedHashSet<IntegrityTransform> mProposedIntegrityAlgos =
+                new LinkedHashSet<>();
+        protected final LinkedHashSet<DhGroupTransform> mProposedDhGroups = new LinkedHashSet<>();
 
         protected boolean mHasAead = false;
 
diff --git a/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachine.java b/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachine.java
index 8ad60e9..2e36449 100644
--- a/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachine.java
+++ b/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachine.java
@@ -83,6 +83,7 @@
 import com.android.internal.net.ipsec.ike.crypto.IkeMacIntegrity;
 import com.android.internal.net.ipsec.ike.crypto.IkeMacPrf;
 import com.android.internal.net.ipsec.ike.exceptions.AuthenticationFailedException;
+import com.android.internal.net.ipsec.ike.exceptions.InvalidKeException;
 import com.android.internal.net.ipsec.ike.exceptions.InvalidSyntaxException;
 import com.android.internal.net.ipsec.ike.exceptions.NoValidProposalChosenException;
 import com.android.internal.net.ipsec.ike.message.IkeAuthDigitalSignPayload;
@@ -109,7 +110,6 @@
 import com.android.internal.net.ipsec.ike.message.IkeNotifyPayload;
 import com.android.internal.net.ipsec.ike.message.IkePayload;
 import com.android.internal.net.ipsec.ike.message.IkeSaPayload;
-import com.android.internal.net.ipsec.ike.message.IkeSaPayload.DhGroupTransform;
 import com.android.internal.net.ipsec.ike.message.IkeSaPayload.IkeProposal;
 import com.android.internal.net.ipsec.ike.message.IkeTsPayload;
 import com.android.internal.net.ipsec.ike.message.IkeVendorPayload;
@@ -322,6 +322,9 @@
     final HashMap<ChildSessionCallback, ChildSessionStateMachine> mChildCbToSessions =
             new HashMap<>();
 
+    /** Peer-selected DH group to use. Defaults to first proposed DH group in first SA proposal. */
+    @VisibleForTesting int mPeerSelectedDhGroup;
+
     /**
      * Package private socket that sends and receives encoded IKE message. Initialized in Initial
      * State.
@@ -442,6 +445,11 @@
         mIkeSessionParams = ikeParams;
         mEapAuthenticatorFactory = eapAuthenticatorFactory;
 
+        // SaProposal.Builder guarantees there is at least one SA proposal, and each SA proposal
+        // has at least one DH group.
+        mPeerSelectedDhGroup =
+                mIkeSessionParams.getSaProposals().get(0).getDhGroupTransforms()[0].id;
+
         mTempFailHandler = new TempFailureHandler(looper);
 
         // There are at most three IkeSaRecords co-existing during simultaneous rekeying.
@@ -2522,7 +2530,34 @@
 
                 transitionTo(mCreateIkeLocalIkeAuth);
             } catch (IkeProtocolException | GeneralSecurityException | IOException e) {
-                // TODO: Try another DH group to buld KE Payload if receiving InvalidKeException
+                if (e instanceof InvalidKeException) {
+                    InvalidKeException keException = (InvalidKeException) e;
+
+                    int requestedDhGroup = keException.getDhGroup();
+                    boolean doAllProposalsHaveDhGroup = true;
+                    for (IkeSaProposal proposal : mIkeSessionParams.getSaProposalsInternal()) {
+                        doAllProposalsHaveDhGroup &=
+                                proposal.getDhGroups().contains(requestedDhGroup);
+                    }
+
+                    // If DH group is not acceptable for all proposals, fail. The caller explicitly
+                    // did not want that combination, and the IKE library must honor it.
+                    if (doAllProposalsHaveDhGroup) {
+                        mPeerSelectedDhGroup = requestedDhGroup;
+
+                        // Remove state set during request creation
+                        mIkeSocket.unregisterIke(
+                                mRetransmitter.getMessage().ikeHeader.ikeInitiatorSpi);
+                        mIkeInitRequestBytes = null;
+                        mIkeInitNoncePayload = null;
+
+                        transitionTo(mInitial);
+                        openSession();
+
+                        return;
+                    }
+                }
+
                 handleIkeFatalError(e);
             } finally {
                 if (!ikeInitSuccess) {
@@ -2551,6 +2586,7 @@
             List<IkePayload> payloadList =
                     CreateIkeSaHelper.getIkeInitSaRequestPayloads(
                             saProposals,
+                            mPeerSelectedDhGroup,
                             initSpi,
                             respSpi,
                             mLocalAddress,
@@ -2703,7 +2739,7 @@
             IkeKePayload reqKePayload =
                     reqMsg.getPayloadForType(IkePayload.PAYLOAD_TYPE_KE, IkeKePayload.class);
             if (reqKePayload.dhGroup != respKePayload.dhGroup
-                    && respKePayload.dhGroup != mSaProposal.getDhGroupTransforms()[0].id) {
+                    && respKePayload.dhGroup != mPeerSelectedDhGroup) {
                 throw new InvalidSyntaxException("Received KE payload with mismatched DH group.");
             }
 
@@ -4385,6 +4421,7 @@
     private static class CreateIkeSaHelper {
         public static List<IkePayload> getIkeInitSaRequestPayloads(
                 IkeSaProposal[] saProposals,
+                int selectedDhGroup,
                 long initIkeSpi,
                 long respIkeSpi,
                 InetAddress localAddr,
@@ -4393,7 +4430,8 @@
                 int remotePort)
                 throws IOException {
             List<IkePayload> payloadList =
-                    getCreateIkeSaPayloads(IkeSaPayload.createInitialIkeSaPayload(saProposals));
+                    getCreateIkeSaPayloads(
+                            selectedDhGroup, IkeSaPayload.createInitialIkeSaPayload(saProposals));
 
             // Though RFC says Notify-NAT payload is "just after the Ni and Nr payloads (before the
             // optional CERTREQ payload)", it also says recipient MUST NOT reject " messages in
@@ -4418,7 +4456,12 @@
                 throw new IllegalArgumentException("Local address was null for rekey");
             }
 
+            // Guaranteed to have at least one SA Proposal, since the IKE session was set up
+            // properly.
+            int selectedDhGroup = saProposals[0].getDhGroupTransforms()[0].id;
+
             return getCreateIkeSaPayloads(
+                    selectedDhGroup,
                     IkeSaPayload.createRekeyIkeSaRequestPayload(saProposals, localAddr));
         }
 
@@ -4429,7 +4472,10 @@
                 throw new IllegalArgumentException("Local address was null for rekey");
             }
 
+            int selectedDhGroup = saProposal.getDhGroupTransforms()[0].id;
+
             return getCreateIkeSaPayloads(
+                    selectedDhGroup,
                     IkeSaPayload.createRekeyIkeSaResponsePayload(
                             respProposalNumber, saProposal, localAddr));
         }
@@ -4439,8 +4485,8 @@
          *
          * <p>Will return a non-empty list of IkePayloads, the first of which WILL be the SA payload
          */
-        private static List<IkePayload> getCreateIkeSaPayloads(IkeSaPayload saPayload)
-                throws IOException {
+        private static List<IkePayload> getCreateIkeSaPayloads(
+                int selectedDhGroup, IkeSaPayload saPayload) throws IOException {
             if (saPayload.proposalList.size() == 0) {
                 throw new IllegalArgumentException("Invalid SA proposal list - was empty");
             }
@@ -4451,11 +4497,7 @@
             payloadList.add(new IkeNoncePayload());
 
             // SaPropoals.Builder guarantees that each SA proposal has at least one DH group.
-            DhGroupTransform dhGroupTransform =
-                    ((IkeProposal) saPayload.proposalList.get(0))
-                            .saProposal
-                            .getDhGroupTransforms()[0];
-            payloadList.add(new IkeKePayload(dhGroupTransform.id));
+            payloadList.add(new IkeKePayload(selectedDhGroup));
 
             return payloadList;
         }
diff --git a/tests/iketests/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachineTest.java b/tests/iketests/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachineTest.java
index f44e3a5..0539d45 100644
--- a/tests/iketests/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachineTest.java
+++ b/tests/iketests/src/java/com/android/internal/net/ipsec/ike/IkeSessionStateMachineTest.java
@@ -207,6 +207,7 @@
                     + "92f46bef84f0be7db860351843858f8acf87056e272377f7"
                     + "0c9f2d81e29c7b0ce4f291a3a72476bb0b278fd4b7b0a4c2"
                     + "6bbeb08214c7071376079587";
+    private static final String INVALID_KE_PAYLOAD_HEX_STRING = "0000000a00000011000e";
     private static final String NONCE_INIT_PAYLOAD_HEX_STRING =
             "29000024c39b7f368f4681b89fa9b7be6465abd7c5f68b6ed5d3b4c72cb4240eb5c46412";
     private static final String NONCE_RESP_PAYLOAD_HEX_STRING =
@@ -776,13 +777,20 @@
     }
 
     public static IkeSaProposal buildSaProposal() throws Exception {
+        return buildSaProposalCommon().addDhGroup(SaProposal.DH_GROUP_2048_BIT_MODP).build();
+    }
+
+    private static IkeSaProposal buildNegotiatedSaProposal() throws Exception {
+        return buildSaProposalCommon().build();
+    }
+
+    private static IkeSaProposal.Builder buildSaProposalCommon() throws Exception {
         return new IkeSaProposal.Builder()
                 .addEncryptionAlgorithm(
                         SaProposal.ENCRYPTION_ALGORITHM_AES_CBC, SaProposal.KEY_LEN_AES_128)
                 .addIntegrityAlgorithm(SaProposal.INTEGRITY_ALGORITHM_HMAC_SHA1_96)
                 .addPseudorandomFunction(SaProposal.PSEUDORANDOM_FUNCTION_HMAC_SHA1)
-                .addDhGroup(SaProposal.DH_GROUP_1024_BIT_MODP)
-                .build();
+                .addDhGroup(SaProposal.DH_GROUP_1024_BIT_MODP);
     }
 
     private IkeSessionParams.Builder buildIkeSessionParamsCommon() throws Exception {
@@ -1255,6 +1263,33 @@
         verify(mMockDefaultNetwork).getByName(REMOTE_HOSTNAME);
     }
 
+    @Test
+    public void testCreateIkeLocalIkeInitNegotiatesDhGroup() throws Exception {
+        setupFirstIkeSa();
+        mIkeSessionStateMachine.sendMessage(IkeSessionStateMachine.CMD_LOCAL_REQUEST_CREATE_IKE);
+        mLooper.dispatchAll();
+
+        // Verify we started with the top proposed DH group
+        assertEquals(
+                SaProposal.DH_GROUP_1024_BIT_MODP, mIkeSessionStateMachine.mPeerSelectedDhGroup);
+
+        // Send back an INVALID_KE_PAYLOAD, and verify that the selected DH group changes
+        ReceivedIkePacket resp =
+                makeDummyReceivedIkeInitRespPacket(
+                        1L /*initiator SPI*/,
+                        2L /*responder SPI*/,
+                        IkeHeader.EXCHANGE_TYPE_IKE_SA_INIT,
+                        true /*isResp*/,
+                        false /*fromIkeInit*/,
+                        Arrays.asList(IkePayload.PAYLOAD_TYPE_NOTIFY),
+                        Arrays.asList(INVALID_KE_PAYLOAD_HEX_STRING));
+        mIkeSessionStateMachine.sendMessage(IkeSessionStateMachine.CMD_RECEIVE_IKE_PACKET, resp);
+        mLooper.dispatchAll();
+
+        assertEquals(
+                SaProposal.DH_GROUP_2048_BIT_MODP, mIkeSessionStateMachine.mPeerSelectedDhGroup);
+    }
+
     @Ignore
     public void disableTestCreateIkeLocalIkeInit() throws Exception {
         setupFirstIkeSa();
@@ -1343,7 +1378,7 @@
         mIkeSessionStateMachine.mIkeCipher = mock(IkeCipher.class);
         mIkeSessionStateMachine.mIkeIntegrity = mock(IkeMacIntegrity.class);
         mIkeSessionStateMachine.mIkePrf = mock(IkeMacPrf.class);
-        mIkeSessionStateMachine.mSaProposal = buildSaProposal();
+        mIkeSessionStateMachine.mSaProposal = buildNegotiatedSaProposal();
         mIkeSessionStateMachine.mCurrentIkeSaRecord = mSpyCurrentIkeSaRecord;
         mIkeSessionStateMachine.mLocalAddress = LOCAL_ADDRESS;
         mIkeSessionStateMachine.mIsLocalBehindNat = true;
@@ -4098,6 +4133,11 @@
         IkeSessionParams mockSessionParams = mock(IkeSessionParams.class);
         when(mockSessionParams.getSaProposalsInternal()).thenThrow(mock(RuntimeException.class));
 
+        DhGroupTransform dhGroupTransform = new DhGroupTransform(SaProposal.DH_GROUP_2048_BIT_MODP);
+        IkeSaProposal mockSaProposal = mock(IkeSaProposal.class);
+        when(mockSaProposal.getDhGroupTransforms())
+                .thenReturn(new DhGroupTransform[] {dhGroupTransform});
+        when(mockSessionParams.getSaProposals()).thenReturn(Arrays.asList(mockSaProposal));
         IkeSessionStateMachine ikeSession =
                 new IkeSessionStateMachine(
                         mLooper.getLooper(),