Build IKE INIT request am: 935e2fd2ea am: 6aaf949023
am: 91c6ff66d9

Change-Id: Ia3e92a992d0b0a90a6ea268da6221b1e03245221
diff --git a/src/java/com/android/ike/ikev2/ChildSessionOptions.java b/src/java/com/android/ike/ikev2/ChildSessionOptions.java
index cbca9f4..b311c66 100644
--- a/src/java/com/android/ike/ikev2/ChildSessionOptions.java
+++ b/src/java/com/android/ike/ikev2/ChildSessionOptions.java
@@ -20,6 +20,6 @@
  * ChildSessionOptions contains user-provided Child SA proposals and negotiated Child SA
  * information.
  */
-public class ChildSessionOptions {
+public final class ChildSessionOptions {
     // TODO: Implement it.
 }
diff --git a/src/java/com/android/ike/ikev2/IkeSessionOptions.java b/src/java/com/android/ike/ikev2/IkeSessionOptions.java
index fe06b85..9f2094f 100644
--- a/src/java/com/android/ike/ikev2/IkeSessionOptions.java
+++ b/src/java/com/android/ike/ikev2/IkeSessionOptions.java
@@ -29,7 +29,7 @@
  *
  * <p>TODO: Make this doc more user-friendly.
  */
-public class IkeSessionOptions {
+public final class IkeSessionOptions {
     private final InetAddress mServerAddress;
     private final UdpEncapsulationSocket mUdpEncapSocket;
     private final SaProposal[] mSaProposals;
@@ -92,7 +92,7 @@
          * @throws IllegalArgumentException if input proposal is not IKE SA proposal.
          */
         public Builder addSaProposal(SaProposal proposal) {
-            if (proposal.mProtocolId != IkePayload.PROTOCOL_ID_IKE) {
+            if (proposal.getProtocolId() != IkePayload.PROTOCOL_ID_IKE) {
                 throw new IllegalArgumentException(
                         "Expected IKE SA Proposal but received Child SA proposal");
             }
diff --git a/src/java/com/android/ike/ikev2/IkeSessionStateMachine.java b/src/java/com/android/ike/ikev2/IkeSessionStateMachine.java
index 31d15e5..b4a7b9c 100644
--- a/src/java/com/android/ike/ikev2/IkeSessionStateMachine.java
+++ b/src/java/com/android/ike/ikev2/IkeSessionStateMachine.java
@@ -17,19 +17,30 @@
 
 import android.os.Looper;
 import android.os.Message;
+import android.system.ErrnoException;
 import android.util.LongSparseArray;
 import android.util.SparseArray;
 
 import com.android.ike.ikev2.SaRecord.IkeSaRecord;
 import com.android.ike.ikev2.exceptions.IkeException;
 import com.android.ike.ikev2.message.IkeHeader;
+import com.android.ike.ikev2.message.IkeKePayload;
 import com.android.ike.ikev2.message.IkeMessage;
+import com.android.ike.ikev2.message.IkeNoncePayload;
 import com.android.ike.ikev2.message.IkeNotifyPayload;
+import com.android.ike.ikev2.message.IkePayload;
+import com.android.ike.ikev2.message.IkeSaPayload;
+import com.android.ike.ikev2.message.IkeSaPayload.DhGroupTransform;
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.internal.util.State;
 import com.android.internal.util.StateMachine;
 
 import java.security.GeneralSecurityException;
+import java.security.SecureRandom;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Set;
 
 /**
  * IkeSessionStateMachine tracks states and manages exchanges of this IKE session.
@@ -73,6 +84,11 @@
     static final int CMD_LOCAL_REQUEST_REKEY_CHILD = CMD_LOCAL_REQUEST_BASE + 7;
     // TODO: Add signals for other procedure types and notificaitons.
 
+    // Remember locally assigned IKE SPIs to avoid SPI collision.
+    private static final Set<Long> ASSIGNED_LOCAL_IKE_SPI_SET = new HashSet<>();
+    private static final int MAX_ASSIGN_IKE_SPI_ATTEMPTS = 100;
+    private static final SecureRandom IKE_SPI_RANDOM = new SecureRandom();
+
     private final IkeSessionOptions mIkeSessionOptions;
     private final ChildSessionOptions mFirstChildSessionOptions;
     /** Map that stores all IkeSaRecords, keyed by remotely generated IKE SPI. */
@@ -84,6 +100,12 @@
      */
     private final SparseArray<ChildSessionStateMachine> mSpiToChildSessionMap;
 
+    /**
+     * Package private socket that sends and receives encoded IKE message. Initialized in Initial
+     * State.
+     */
+    @VisibleForTesting IkeSocket mIkeSocket;
+
     /** Package */
     @VisibleForTesting IkeSaRecord mCurrentIkeSaRecord;
     /** Package */
@@ -145,9 +167,55 @@
         setInitialState(mInitial);
     }
 
+    // Generate a non-zero unused IKE SPI; throw on failure and let the current State handle it.
+    private static synchronized long getIkeSpiOrThrow() {
+        for (int i = 0; i < MAX_ASSIGN_IKE_SPI_ATTEMPTS; i++) {
+            long spi = IKE_SPI_RANDOM.nextLong();
+            if (spi != 0L && ASSIGNED_LOCAL_IKE_SPI_SET.add(spi)) return spi;
+        }
+        throw new IllegalStateException("Failed to generate IKE SPI.");
+    }
+
     private IkeMessage buildIkeInitReq() {
-        // TODO:Build packet according to mIkeSessionOptions.
-        return null;
+        // TODO: Handle IKE SPI assigning error in CreateIkeLocalIkeInit State.
+
+        List<IkePayload> payloadList = new LinkedList<>();
+
+        // Generate a fresh local (IKE initiator) SPI.
+        long initSpi = getIkeSpiOrThrow();
+        long respSpi = 0; // MUST be zero in the first IKE_SA_INIT request (RFC 7296, 3.1).
+
+        // It is validated in IkeSessionOptions.Builder to ensure IkeSessionOptions has at least one
+        // SaProposal and all SaProposals are valid for IKE SA negotiation.
+        SaProposal[] saProposals = mIkeSessionOptions.getSaProposals();
+
+        // Build SA Payload
+        IkeSaPayload saPayload = new IkeSaPayload(saProposals);
+        payloadList.add(saPayload);
+
+        // Build KE Payload using the first DH group number in the first SaProposal.
+        DhGroupTransform dhGroupTransform = saProposals[0].getDhGroupTransforms()[0];
+        IkeKePayload kePayload = new IkeKePayload(dhGroupTransform.id);
+        payloadList.add(kePayload);
+
+        // Build Nonce Payload
+        IkeNoncePayload noncePayload = new IkeNoncePayload();
+        payloadList.add(noncePayload);
+
+        // TODO: Add Notification Payloads according to user configurations.
+
+        // Build IKE header. Next payload type is SA because the SA payload is added first.
+        IkeHeader ikeHeader =
+                new IkeHeader(
+                        initSpi,
+                        respSpi,
+                        IkePayload.PAYLOAD_TYPE_SA,
+                        IkeHeader.EXCHANGE_TYPE_IKE_SA_INIT,
+                        false /*isResponseMsg*/,
+                        true /*fromIkeInitiator*/,
+                        0 /*messageId*/);
+
+        return new IkeMessage(ikeHeader, payloadList);
     }
 
     private IkeMessage buildIkeAuthReq() {
@@ -272,6 +340,15 @@
     /** Initial state of IkeSessionStateMachine. */
     class Initial extends State {
         @Override
+        public void enter() {
+            try { // TODO: close IkeSession on failure instead of throwing.
+                mIkeSocket = IkeSocket.getIkeSocket(mIkeSessionOptions.getUdpEncapsulationSocket());
+            } catch (ErrnoException e) { // Fail fast: later states use mIkeSocket unconditionally.
+                throw new IllegalStateException("Failed to get IkeSocket", e);
+            }
+        }
+
+        @Override
         public boolean processMessage(Message message) {
             switch (message.what) {
                 case CMD_LOCAL_REQUEST_CREATE_IKE:
@@ -429,6 +506,7 @@
         public void enter() {
             mRequestMsg = buildRequest();
             mRequestPacket = encodeRequest();
+            mIkeSocket.sendIkePacket(mRequestPacket, mIkeSessionOptions.getServerAddress());
             // TODO: Send out packet and start retransmission timer.
         }
 
@@ -444,12 +522,20 @@
         // CreateIkeLocalInit should override encodeRequest() to encode unencrypted packet
         protected byte[] encodeRequest() {
             // TODO: encrypt and encode mRequestMsg
-            return null;
-        };
+            return new byte[0];
+        }
     }
 
     /** CreateIkeLocalIkeInit represents state when IKE library initiates IKE_INIT exchange. */
     class CreateIkeLocalIkeInit extends LocalNewExchangeBase {
+
+        @Override
+        public void enter() {
+            super.enter();
+            mIkeSocket.registerIke(
+                    mRequestMsg.ikeHeader.ikeInitiatorSpi, IkeSessionStateMachine.this);
+        }
+
         @Override
         protected IkeMessage buildRequest() {
             return buildIkeInitReq();
@@ -457,8 +543,7 @@
 
         @Override
         protected byte[] encodeRequest() {
-            // TODO: Encode an unencrypted IKE packet.
-            return null;
+            return mRequestMsg.encode();
         }
 
         @Override
@@ -534,8 +619,7 @@
                                         mFirstChildSessionOptions);
                         // TODO: Replace null input params to payload lists in IKE_AUTH request and
                         // IKE_AUTH response for negotiating Child SA.
-                        firstChild.handleFirstChildExchange(
-                                null, null, new ChildSessionCallback());
+                        firstChild.handleFirstChildExchange(null, null, new ChildSessionCallback());
 
                         transitionTo(mIdle);
                     } catch (IkeException e) {
diff --git a/src/java/com/android/ike/ikev2/IkeSocket.java b/src/java/com/android/ike/ikev2/IkeSocket.java
index 0f1b9e6..e235ab8 100644
--- a/src/java/com/android/ike/ikev2/IkeSocket.java
+++ b/src/java/com/android/ike/ikev2/IkeSocket.java
@@ -73,18 +73,20 @@
     @VisibleForTesting static final int NON_ESP_MARKER_LEN = 4;
     @VisibleForTesting static final byte[] NON_ESP_MARKER = new byte[NON_ESP_MARKER_LEN];
 
-    // Package private map from UdpEncapsulationSocket to IkeSocket instances.
-    static Map<UdpEncapsulationSocket, IkeSocket> sFdToIkeSocketMap = new HashMap<>();
+    // Map from UdpEncapsulationSocket to IkeSocket instances.
+    private static Map<UdpEncapsulationSocket, IkeSocket> sFdToIkeSocketMap = new HashMap<>();
 
     private static IPacketReceiver sPacketReceiver = new PacketReceiver();
 
-    // Map from locally generated IKE SPI to IkeSessionStateMachine instances.
-    private final LongSparseArray<IkeSessionStateMachine> mSpiToIkeSession =
+    // Package private map from locally generated IKE SPI to IkeSessionStateMachine instances.
+    @VisibleForTesting
+    final LongSparseArray<IkeSessionStateMachine> mSpiToIkeSession =
             new LongSparseArray<>();
     // UdpEncapsulationSocket for sending and receving IKE packet.
     private final UdpEncapsulationSocket mUdpEncapSocket;
 
     /** Package private */
+    @VisibleForTesting
     int mRefCount;
 
     private IkeSocket(UdpEncapsulationSocket udpEncapSocket, Handler handler) {
diff --git a/src/java/com/android/ike/ikev2/SaProposal.java b/src/java/com/android/ike/ikev2/SaProposal.java
index 04c06d0..53206cc 100644
--- a/src/java/com/android/ike/ikev2/SaProposal.java
+++ b/src/java/com/android/ike/ikev2/SaProposal.java
@@ -139,17 +139,17 @@
     }
 
     /** Package private */
-    @IkePayload.ProtocolId final int mProtocolId;
+    @IkePayload.ProtocolId private final int mProtocolId;
     /** Package private */
-    final EncryptionTransform[] mEncryptionAlgorithms;
+    private final EncryptionTransform[] mEncryptionAlgorithms;
     /** Package private */
-    final PrfTransform[] mPseudorandomFunctions;
+    private final PrfTransform[] mPseudorandomFunctions;
     /** Package private */
-    final IntegrityTransform[] mIntegrityAlgorithms;
+    private final IntegrityTransform[] mIntegrityAlgorithms;
     /** Package private */
-    final DhGroupTransform[] mDhGroups;
+    private final DhGroupTransform[] mDhGroups;
     /** Package private */
-    final EsnTransform[] mEsns;
+    private final EsnTransform[] mEsns;
 
     private SaProposal(
             @IkePayload.ProtocolId int protocol,
@@ -227,6 +227,37 @@
         return Arrays.asList(selectFrom).contains(selected[0]);
     }
 
+    /*Package private*/
+    @IkePayload.ProtocolId
+    int getProtocolId() {
+        return mProtocolId;
+    }
+
+    /*Package private*/
+    EncryptionTransform[] getEncryptionTransforms() {
+        return mEncryptionAlgorithms;
+    }
+
+    /*Package private*/
+    PrfTransform[] getPrfTransforms() {
+        return mPseudorandomFunctions;
+    }
+
+    /*Package private*/
+    IntegrityTransform[] getIntegrityTransforms() {
+        return mIntegrityAlgorithms;
+    }
+
+    /*Package private*/
+    DhGroupTransform[] getDhGroupTransforms() {
+        return mDhGroups;
+    }
+
+    /*Package private*/
+    EsnTransform[] getEsnTransforms() {
+        return mEsns;
+    }
+
     /**
      * Return all SA Transforms in this SaProposal to be encoded for building an outbound IKE
      * message.
diff --git a/src/java/com/android/ike/ikev2/message/IkeEncryptedPayloadBody.java b/src/java/com/android/ike/ikev2/message/IkeEncryptedPayloadBody.java
index f47fb88..a41d855 100644
--- a/src/java/com/android/ike/ikev2/message/IkeEncryptedPayloadBody.java
+++ b/src/java/com/android/ike/ikev2/message/IkeEncryptedPayloadBody.java
@@ -139,7 +139,12 @@
         ByteBuffer authenticatedSectionBuffer = ByteBuffer.allocate(dataToAuthenticateLength);
 
         // Encode IKE header
-        ikeHeader.encodeToByteBuffer(authenticatedSectionBuffer);
+        int encryptedPayloadLength =
+                IkePayload.GENERIC_HEADER_LENGTH
+                        + iv.length
+                        + mEncryptedAndPaddedData.length
+                        + checksumLen;
+        ikeHeader.encodeToByteBuffer(authenticatedSectionBuffer, encryptedPayloadLength);
 
         // Encode payload header. The next payload type field indicates the first payload nested in
         // this SkPayload/SkfPayload.
diff --git a/src/java/com/android/ike/ikev2/message/IkeHeader.java b/src/java/com/android/ike/ikev2/message/IkeHeader.java
index ebd3553..c4f215c 100644
--- a/src/java/com/android/ike/ikev2/message/IkeHeader.java
+++ b/src/java/com/android/ike/ikev2/message/IkeHeader.java
@@ -23,6 +23,7 @@
 import com.android.ike.ikev2.exceptions.IkeException;
 import com.android.ike.ikev2.exceptions.InvalidMajorVersionException;
 import com.android.ike.ikev2.exceptions.InvalidSyntaxException;
+import com.android.internal.annotations.VisibleForTesting;
 
 import java.lang.annotation.Retention;
 import java.lang.annotation.RetentionPolicy;
@@ -36,7 +37,7 @@
  *     Protocol Version 2 (IKEv2)</a>
  */
 public final class IkeHeader {
-    //TODO: b/122838549 Change IkeHeader to static inner class of IkeMessage.
+    // TODO: b/122838549 Change IkeHeader to static inner class of IkeMessage.
     private static final byte IKE_HEADER_VERSION_INFO = (byte) 0x20;
 
     // Indicate whether this message is a response message
@@ -69,7 +70,14 @@
     public final boolean isResponseMsg;
     public final boolean fromIkeInitiator;
     public final int messageId;
-    public final int messageLength;
+
+    // Cannot assign encoded message length value for an outbound IKE message before it's encoded.
+    private static final int ENCODED_MESSAGE_LEN_UNAVAILABLE = -1;
+
+    // mEncodedMessageLength is only set for an inbound IkeMessage. When building an outbound
+    // IkeMessage, message length is not set because message body length is unknown until it gets
+    // encrypted and encoded.
+    private final int mEncodedMessageLength;
 
     /**
      * Construct an instance of IkeHeader. It is only called in the process of building outbound
@@ -82,7 +90,6 @@
      * @param isResp indicates if this message is a response or a request
      * @param fromInit indictaes if this message is sent from the IKE initiator or the IKE responder
      * @param msgId the message identifier
-     * @param length the length of the total message in octets
      */
     public IkeHeader(
             long iSpi,
@@ -91,8 +98,7 @@
             @ExchangeType int eType,
             boolean isResp,
             boolean fromInit,
-            int msgId,
-            int length) {
+            int msgId) {
         ikeInitiatorSpi = iSpi;
         ikeResponderSpi = rSpi;
         nextPayloadType = nextPType;
@@ -100,7 +106,8 @@
         isResponseMsg = isResp;
         fromIkeInitiator = fromInit;
         messageId = msgId;
-        messageLength = length;
+
+        mEncodedMessageLength = ENCODED_MESSAGE_LEN_UNAVAILABLE;
 
         // Major version of IKE protocol in use; it must be set to 2 when building an IKEv2 message.
         majorVersion = 2;
@@ -135,11 +142,21 @@
         fromIkeInitiator = ((flagsByte & 0x08) != 0);
 
         messageId = buffer.getInt();
-        messageLength = buffer.getInt();
+        mEncodedMessageLength = buffer.getInt();
     }
 
-    /** Validate syntax and major version. */
-    public void checkValidOrThrow(int packetLength) throws IkeException {
+    /** Package private. Returns the decoded Length field; throws for outbound headers. */
+    @VisibleForTesting
+    int getInboundMessageLength() {
+        if (mEncodedMessageLength == ENCODED_MESSAGE_LEN_UNAVAILABLE) {
+            throw new UnsupportedOperationException(
+                    "It is not supported to get encoded message length from an outbound message.");
+        }
+        return mEncodedMessageLength;
+    }
+
+    /** Validate syntax and major version of inbound IKE header. */
+    public void checkInboundValidOrThrow(int packetLength) throws IkeException {
         if (majorVersion > 2) {
             // Receive higher version of protocol. Stop parsing.
             throw new InvalidMajorVersionException(majorVersion);
@@ -155,13 +172,13 @@
                 || exchangeType > EXCHANGE_TYPE_INFORMATIONAL) {
             throw new InvalidSyntaxException("Invalid IKE Exchange Type.");
         }
-        if (messageLength != packetLength) {
+        if (mEncodedMessageLength != packetLength) {
             throw new InvalidSyntaxException("Invalid IKE Message Length.");
         }
     }
 
     /** Encode IKE header to ByteBuffer */
-    public void encodeToByteBuffer(ByteBuffer byteBuffer) {
+    public void encodeToByteBuffer(ByteBuffer byteBuffer, int encodedMessageBodyLen) {
         byteBuffer
                 .putLong(ikeInitiatorSpi)
                 .putLong(ikeResponderSpi)
@@ -177,6 +194,6 @@
             flag |= IKE_HEADER_FLAG_FROM_IKE_INITIATOR;
         }
 
-        byteBuffer.put(flag).putInt(messageId).putInt(messageLength);
+        byteBuffer.put(flag).putInt(messageId).putInt(IKE_HEADER_LENGTH + encodedMessageBodyLen);
     }
 }
diff --git a/src/java/com/android/ike/ikev2/message/IkeKePayload.java b/src/java/com/android/ike/ikev2/message/IkeKePayload.java
index a742d94..92e2238 100644
--- a/src/java/com/android/ike/ikev2/message/IkeKePayload.java
+++ b/src/java/com/android/ike/ikev2/message/IkeKePayload.java
@@ -16,7 +16,7 @@
 
 package com.android.ike.ikev2.message;
 
-import android.util.Pair;
+import android.annotation.Nullable;
 
 import com.android.ike.ikev2.IkeDhParams;
 import com.android.ike.ikev2.SaProposal;
@@ -27,9 +27,12 @@
 import java.math.BigInteger;
 import java.nio.ByteBuffer;
 import java.security.GeneralSecurityException;
+import java.security.InvalidAlgorithmParameterException;
 import java.security.KeyFactory;
 import java.security.KeyPair;
 import java.security.KeyPairGenerator;
+import java.security.NoSuchAlgorithmException;
+import java.security.ProviderException;
 import java.security.SecureRandom;
 
 import javax.crypto.KeyAgreement;
@@ -64,8 +67,22 @@
     /** Supported dhGroup falls into {@link DhGroup} */
     public final int dhGroup;
 
+    /** Public DH key for the recipient to calculate shared key. */
     public final byte[] keyExchangeData;
 
+    /** Flag indicating whether this payload was built locally for an outbound message. */
+    public final boolean isOutbound;
+
+    /**
+     * localPrivateKey caches the locally generated private key when building an outbound KE
+     * payload. It is never sent out; it is only used to calculate the DH shared key once the
+     * IKE library receives the public key from the remote server.
+     *
+     * <p>localPrivateKey of an inbound payload is set to null. Callers MUST check isOutbound
+     * before using localPrivateKey.
+     */
+    @Nullable public final DHPrivateKeySpec localPrivateKey;
+
     /**
      * Construct an instance of IkeKePayload in the context of IkePayloadFactory
      *
@@ -79,6 +96,9 @@
     IkeKePayload(boolean critical, byte[] payloadBody) throws IkeException {
         super(PAYLOAD_TYPE_KE, critical);
 
+        isOutbound = false;
+        localPrivateKey = null;
+
         ByteBuffer inputBuffer = ByteBuffer.wrap(payloadBody);
 
         dhGroup = Short.toUnsignedInt(inputBuffer.getShort());
@@ -110,17 +130,67 @@
     /**
      * Construct an instance of IkeKePayload for building an outbound packet.
      *
+     * <p>Generate a DH key pair. Cache the private key and send out the public key as
+     * keyExchangeData. Throws IllegalArgumentException if the DH group is not supported.
+     *
      * <p>Critical bit in this payload must not be set as instructed in RFC 7296.
      *
      * @param dh DH group for this KE payload
-     * @param keData the Key Exchange data
      * @see <a href="https://tools.ietf.org/html/rfc7296#page-76">RFC 7296, Internet Key Exchange
      *     Protocol Version 2 (IKEv2), Critical.
      */
-    private IkeKePayload(@SaProposal.DhGroup int dh, byte[] keData) {
+    public IkeKePayload(@SaProposal.DhGroup int dh) {
         super(PAYLOAD_TYPE_KE, false);
+
         dhGroup = dh;
-        keyExchangeData = keData;
+        isOutbound = true;
+        // No defaults: the compiler checks that every supported DH group sets both values.
+        BigInteger prime;
+        int keySize;
+        switch (dhGroup) {
+            case SaProposal.DH_GROUP_1024_BIT_MODP:
+                prime =
+                        BigIntegerUtils.unsignedHexStringToBigInteger(
+                                IkeDhParams.PRIME_1024_BIT_MODP);
+                keySize = DH_GROUP_1024_BIT_MODP_DATA_LEN;
+                break;
+            case SaProposal.DH_GROUP_2048_BIT_MODP:
+                prime =
+                        BigIntegerUtils.unsignedHexStringToBigInteger(
+                                IkeDhParams.PRIME_2048_BIT_MODP);
+                keySize = DH_GROUP_2048_BIT_MODP_DATA_LEN;
+                break;
+            default:
+                throw new IllegalArgumentException("DH group not supported: " + dh);
+        }
+
+        try {
+            BigInteger baseGen = BigInteger.valueOf(IkeDhParams.BASE_GENERATOR_MODP);
+            DHParameterSpec dhParams = new DHParameterSpec(prime, baseGen);
+
+            KeyPairGenerator dhKeyPairGen =
+                    KeyPairGenerator.getInstance(
+                            KEY_EXCHANGE_ALGORITHM, IkeMessage.getSecurityProvider());
+            // By default SecureRandom uses AndroidOpenSSL provided SHA1PRNG Algorithm, which takes
+            // /dev/urandom as seed source.
+            dhKeyPairGen.initialize(dhParams, new SecureRandom());
+
+            KeyPair keyPair = dhKeyPairGen.generateKeyPair();
+
+            DHPrivateKey privateKey = (DHPrivateKey) keyPair.getPrivate();
+            DHPrivateKeySpec dhPrivateKeyspec =
+                    new DHPrivateKeySpec(privateKey.getX(), prime, baseGen);
+            DHPublicKey publicKey = (DHPublicKey) keyPair.getPublic();
+
+            // Zero-pad the public key without the sign bit
+            keyExchangeData =
+                    BigIntegerUtils.bigIntegerToUnsignedByteArray(publicKey.getY(), keySize);
+            localPrivateKey = dhPrivateKeyspec;
+        } catch (NoSuchAlgorithmException e) {
+            throw new ProviderException("Failed to obtain " + KEY_EXCHANGE_ALGORITHM, e);
+        } catch (InvalidAlgorithmParameterException e) {
+            throw new IllegalArgumentException("Failed to initialize key generator", e);
+        }
     }
 
     /**
@@ -149,56 +219,6 @@
     }
 
     /**
-     * Construct an instance of IkeKePayload according to its {@link DhGroup}.
-     *
-     * @param dh the Dh-Group. It should be in {@link DhGroup}
-     * @return Pair of generated private key and an instance of IkeKePayload with key exchange data.
-     * @throws GeneralSecurityException for security-related exception.
-     */
-    public static Pair<DHPrivateKeySpec, IkeKePayload> getKePayload(@SaProposal.DhGroup int dh)
-            throws GeneralSecurityException {
-        BigInteger baseGen = BigInteger.valueOf(IkeDhParams.BASE_GENERATOR_MODP);
-        BigInteger prime = BigInteger.ZERO;
-        int keySize = 0;
-        switch (dh) {
-            case SaProposal.DH_GROUP_1024_BIT_MODP:
-                prime =
-                        BigIntegerUtils.unsignedHexStringToBigInteger(
-                                IkeDhParams.PRIME_1024_BIT_MODP);
-                keySize = DH_GROUP_1024_BIT_MODP_DATA_LEN;
-                break;
-            case SaProposal.DH_GROUP_2048_BIT_MODP:
-                prime =
-                        BigIntegerUtils.unsignedHexStringToBigInteger(
-                                IkeDhParams.PRIME_2048_BIT_MODP);
-                keySize = DH_GROUP_2048_BIT_MODP_DATA_LEN;
-                break;
-            default:
-                throw new IllegalArgumentException("DH group not supported: " + dh);
-        }
-
-        DHParameterSpec dhParams = new DHParameterSpec(prime, baseGen);
-
-        KeyPairGenerator dhKeyPairGen =
-                KeyPairGenerator.getInstance(
-                        KEY_EXCHANGE_ALGORITHM, IkeMessage.getSecurityProvider());
-        // By default SecureRandom uses AndroidOpenSSL provided SHA1PRNG Algorithm, which takes
-        // /dev/urandom as seed source.
-        dhKeyPairGen.initialize(dhParams, new SecureRandom());
-
-        KeyPair keyPair = dhKeyPairGen.generateKeyPair();
-
-        DHPrivateKey privateKey = (DHPrivateKey) keyPair.getPrivate();
-        DHPrivateKeySpec dhPrivateKeyspec = new DHPrivateKeySpec(privateKey.getX(), prime, baseGen);
-        DHPublicKey publicKey = (DHPublicKey) keyPair.getPublic();
-
-        // Zero-pad the public key without the sign bit
-        byte[] keData = BigIntegerUtils.bigIntegerToUnsignedByteArray(publicKey.getY(), keySize);
-
-        return new Pair(dhPrivateKeyspec, new IkeKePayload(dh, keData));
-    }
-
-    /**
      * Calculate the shared secret.
      *
      * @param privateKeySpec contains the local private key, DH prime and DH base generator.
diff --git a/src/java/com/android/ike/ikev2/message/IkeMessage.java b/src/java/com/android/ike/ikev2/message/IkeMessage.java
index 4adbde0..5b63ee7 100644
--- a/src/java/com/android/ike/ikev2/message/IkeMessage.java
+++ b/src/java/com/android/ike/ikev2/message/IkeMessage.java
@@ -239,7 +239,7 @@
     byte[] attachEncodedHeader(byte[] encodedIkeBody) {
         ByteBuffer outputBuffer =
                 ByteBuffer.allocate(IkeHeader.IKE_HEADER_LENGTH + encodedIkeBody.length);
-        ikeHeader.encodeToByteBuffer(outputBuffer);
+        ikeHeader.encodeToByteBuffer(outputBuffer, encodedIkeBody.length);
         outputBuffer.put(encodedIkeBody);
         return outputBuffer.array();
     }
@@ -352,7 +352,7 @@
 
             ByteBuffer outputBuffer =
                     ByteBuffer.allocate(IkeHeader.IKE_HEADER_LENGTH + skPayload.getPayloadLength());
-            ikeHeader.encodeToByteBuffer(outputBuffer);
+            ikeHeader.encodeToByteBuffer(outputBuffer, skPayload.getPayloadLength());
             skPayload.encodeToByteBuffer(firstPayload, outputBuffer);
 
             return outputBuffer.array();
@@ -360,7 +360,7 @@
 
         @Override
         public IkeMessage decode(IkeHeader header, byte[] inputPacket) throws IkeException {
-            header.checkValidOrThrow(inputPacket.length);
+            header.checkInboundValidOrThrow(inputPacket.length);
 
             byte[] unencryptedPayloads =
                     Arrays.copyOfRange(
@@ -397,7 +397,7 @@
                 SecretKey dKey)
                 throws IkeException, GeneralSecurityException {
 
-            header.checkValidOrThrow(inputPacket.length);
+            header.checkInboundValidOrThrow(inputPacket.length);
 
             if (header.nextPayloadType != IkePayload.PAYLOAD_TYPE_SK) {
                 // TODO: b/123372339 Handle message containing unprotected payloads.
diff --git a/tests/iketests/src/java/com/android/ike/ikev2/ChildSessionStateMachineTest.java b/tests/iketests/src/java/com/android/ike/ikev2/ChildSessionStateMachineTest.java
index e81b17a..1d760b7 100644
--- a/tests/iketests/src/java/com/android/ike/ikev2/ChildSessionStateMachineTest.java
+++ b/tests/iketests/src/java/com/android/ike/ikev2/ChildSessionStateMachineTest.java
@@ -63,12 +63,13 @@
 
     private ISaRecordHelper mMockSaRecordHelper;
     private IChildSessionCallback mMockChildSessionCallback;
-    private ChildSessionOptions mMockChildSessionOptions;
+    private ChildSessionOptions mChildSessionOptions;
 
     public ChildSessionStateMachineTest() {
         mMockSaRecordHelper = mock(SaRecord.ISaRecordHelper.class);
         mMockChildSessionCallback = mock(IChildSessionCallback.class);
-        mMockChildSessionOptions = mock(ChildSessionOptions.class);
+
+        mChildSessionOptions = new ChildSessionOptions();
     }
 
     @Before
@@ -77,7 +78,7 @@
         mLooper = new TestLooper();
         mChildSessionStateMachine =
                 new ChildSessionStateMachine(
-                        "ChildSessionStateMachine", mLooper.getLooper(), mMockChildSessionOptions);
+                        "ChildSessionStateMachine", mLooper.getLooper(), mChildSessionOptions);
         mChildSessionStateMachine.setDbg(true);
         SaRecord.setSaRecordHelper(mMockSaRecordHelper);
 
diff --git a/tests/iketests/src/java/com/android/ike/ikev2/IkeSessionStateMachineTest.java b/tests/iketests/src/java/com/android/ike/ikev2/IkeSessionStateMachineTest.java
index e3ad450..1dcf5b8 100644
--- a/tests/iketests/src/java/com/android/ike/ikev2/IkeSessionStateMachineTest.java
+++ b/tests/iketests/src/java/com/android/ike/ikev2/IkeSessionStateMachineTest.java
@@ -17,6 +17,9 @@
 package com.android.ike.ikev2;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.eq;
@@ -26,8 +29,14 @@
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
+import android.content.Context;
+import android.net.IpSecManager;
+import android.net.IpSecManager.UdpEncapsulationSocket;
+import android.os.Looper;
 import android.os.test.TestLooper;
 
+import androidx.test.InstrumentationRegistry;
+
 import com.android.ike.ikev2.ChildSessionStateMachineFactory.ChildSessionFactoryHelper;
 import com.android.ike.ikev2.ChildSessionStateMachineFactory.IChildSessionFactoryHelper;
 import com.android.ike.ikev2.IkeSessionStateMachine.ReceivedIkePacket;
@@ -43,26 +52,37 @@
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
+import org.mockito.ArgumentCaptor;
 
+import java.net.InetAddress;
 import java.util.LinkedList;
+import java.util.List;
 
 public final class IkeSessionStateMachineTest {
 
+    private static final String SERVER_ADDRESS = "192.0.2.100";
+
+    private UdpEncapsulationSocket mUdpEncapSocket;
+
     private TestLooper mLooper;
     private IkeSessionStateMachine mIkeSessionStateMachine;
 
+    private IkeSessionOptions mIkeSessionOptions;
+    private ChildSessionOptions mChildSessionOptions;
+
     private IIkeMessageHelper mMockIkeMessageHelper;
     private ISaRecordHelper mMockSaRecordHelper;
-    private IkeSessionOptions mMockIkeSessionOptions;
 
     private ChildSessionStateMachine mMockChildSessionStateMachine;
-    private ChildSessionOptions mMockChildSessionOptions;
     private IChildSessionFactoryHelper mMockChildSessionFactoryHelper;
 
     private IkeSaRecord mSpyCurrentIkeSaRecord;
     private IkeSaRecord mSpyLocalInitIkeSaRecord;
     private IkeSaRecord mSpyRemoteInitIkeSaRecord;
 
+    private ArgumentCaptor<IkeMessage> mIkeMessageCaptor =
+            ArgumentCaptor.forClass(IkeMessage.class);
+
     private ReceivedIkePacket makeDummyUnencryptedReceivedIkePacket(int packetType)
             throws Exception {
         IkeMessage dummyIkeMessage = makeDummyIkeMessageForTest(0, 0, false, false);
@@ -83,7 +103,7 @@
         byte[] dummyIkePacketBytes = new byte[0];
 
         when(mMockIkeMessageHelper.decode(
-                        mMockIkeSessionOptions,
+                        mIkeSessionOptions,
                         ikeSaRecord,
                         dummyIkeMessage.ikeHeader,
                         dummyIkePacketBytes))
@@ -97,27 +117,21 @@
         int firstPayloadType =
                 isEncrypted ? IkePayload.PAYLOAD_TYPE_SK : IkePayload.PAYLOAD_TYPE_NO_NEXT;
         IkeHeader header =
-                new IkeHeader(initSpi, respSpi, firstPayloadType, 0, true, fromikeInit, 0, 0);
+                new IkeHeader(initSpi, respSpi, firstPayloadType, 0, true, fromikeInit, 0);
         return new IkeMessage(header, new LinkedList<IkePayload>());
     }
 
     private void verifyDecodeEncryptedMessage(IkeSaRecord record, ReceivedIkePacket rcvPacket)
             throws Exception {
         verify(mMockIkeMessageHelper)
-                .decode(
-                        mMockIkeSessionOptions,
-                        record,
-                        rcvPacket.ikeHeader,
-                        rcvPacket.ikePacketBytes);
+                .decode(mIkeSessionOptions, record, rcvPacket.ikeHeader, rcvPacket.ikePacketBytes);
     }
 
     public IkeSessionStateMachineTest() {
         mMockIkeMessageHelper = mock(IkeMessage.IIkeMessageHelper.class);
         mMockSaRecordHelper = mock(SaRecord.ISaRecordHelper.class);
-        mMockIkeSessionOptions = mock(IkeSessionOptions.class);
 
         mMockChildSessionStateMachine = mock(ChildSessionStateMachine.class);
-        mMockChildSessionOptions = mock(ChildSessionOptions.class);
         mMockChildSessionFactoryHelper = mock(IChildSessionFactoryHelper.class);
 
         mSpyCurrentIkeSaRecord = spy(new IkeSaRecord(11, 12, true, null, null));
@@ -131,17 +145,25 @@
     }
 
     @Before
-    public void setUp() {
+    public void setUp() throws Exception {
+        Context context = InstrumentationRegistry.getContext();
+        IpSecManager ipSecManager = (IpSecManager) context.getSystemService(Context.IPSEC_SERVICE);
+        mUdpEncapSocket = ipSecManager.openUdpEncapsulationSocket();
+
+        mIkeSessionOptions = buildIkeSessionOptions();
+        mChildSessionOptions = new ChildSessionOptions();
+
         // Setup thread and looper
         mLooper = new TestLooper();
         mIkeSessionStateMachine =
                 new IkeSessionStateMachine(
                         "IkeSessionStateMachine",
                         mLooper.getLooper(),
-                        mMockIkeSessionOptions,
-                        mMockChildSessionOptions);
+                        mIkeSessionOptions,
+                        mChildSessionOptions);
         mIkeSessionStateMachine.setDbg(true);
         mIkeSessionStateMachine.start();
+
         IkeMessage.setIkeMessageHelper(mMockIkeMessageHelper);
         SaRecord.setSaRecordHelper(mMockSaRecordHelper);
         ChildSessionStateMachineFactory.setChildSessionFactoryHelper(
@@ -149,17 +171,46 @@
     }
 
     @After
-    public void tearDown() {
+    public void tearDown() throws Exception {
         mIkeSessionStateMachine.quit();
         mIkeSessionStateMachine.setDbg(false);
+        mUdpEncapSocket.close(); // Release the socket opened in setUp().
+        // Reinstall real helper implementations in place of the ones setUp() installed.
         IkeMessage.setIkeMessageHelper(new IkeMessageHelper());
         SaRecord.setSaRecordHelper(new SaRecordHelper());
         ChildSessionStateMachineFactory.setChildSessionFactoryHelper(
                 new ChildSessionFactoryHelper());
     }
 
+    /** Returns IkeSessionOptions for SERVER_ADDRESS carrying a single IKE SA proposal. */
+    private IkeSessionOptions buildIkeSessionOptions() throws Exception {
+        SaProposal saProposal =
+                SaProposal.Builder.newIkeSaProposalBuilder()
+                        .addEncryptionAlgorithm(
+                                SaProposal.ENCRYPTION_ALGORITHM_AES_CBC, SaProposal.KEY_LEN_AES_128)
+                        .addIntegrityAlgorithm(SaProposal.INTEGRITY_ALGORITHM_HMAC_SHA1_96)
+                        .addPseudorandomFunction(SaProposal.PSEUDORANDOM_FUNCTION_HMAC_SHA1)
+                        .addDhGroup(SaProposal.DH_GROUP_1024_BIT_MODP)
+                        .build();
+
+        InetAddress serverAddress = InetAddress.getByName(SERVER_ADDRESS);
+        // Relies on mUdpEncapSocket, which setUp() opens before calling this method.
+        return new IkeSessionOptions.Builder(serverAddress, mUdpEncapSocket)
+                .addSaProposal(saProposal)
+                .build();
+    }
+
+    /**
+     * Returns whether {@code payloadList} contains at least one payload of {@code payloadType}.
+     */
+    private static boolean isIkePayloadExist(
+            List<IkePayload> payloadList, @IkePayload.PayloadType int payloadType) {
+        return payloadList.stream().anyMatch(payload -> payload.payloadType == payloadType);
+    }
+
     @Test
     public void testCreateIkeLocalIkeInit() throws Exception {
+        if (Looper.myLooper() == null) Looper.prepare(); // Ensure this thread has a Looper.
         // Mock IKE_INIT response.
         ReceivedIkePacket dummyReceivedIkePacket =
                 makeDummyUnencryptedReceivedIkePacket(IkeMessage.MESSAGE_TYPE_IKE_INIT_RESP);
@@ -171,15 +222,37 @@
                 IkeSessionStateMachine.CMD_RECEIVE_IKE_PACKET, dummyReceivedIkePacket);
 
         mLooper.dispatchAll();
+
+        // Validate the outbound IKE_SA_INIT request that was handed to the encoder.
+        verify(mMockIkeMessageHelper).encode(mIkeMessageCaptor.capture());
+        IkeMessage ikeInitReqMessage = mIkeMessageCaptor.getValue();
+
+        IkeHeader ikeHeader = ikeInitReqMessage.ikeHeader;
+        assertEquals(IkeHeader.EXCHANGE_TYPE_IKE_SA_INIT, ikeHeader.exchangeType);
+        assertFalse(ikeHeader.isResponseMsg);
+        assertTrue(ikeHeader.fromIkeInitiator);
+        // The request must carry SA, KE and Nonce payloads.
+        List<IkePayload> payloadList = ikeInitReqMessage.ikePayloadList;
+        assertTrue(isIkePayloadExist(payloadList, IkePayload.PAYLOAD_TYPE_SA));
+        assertTrue(isIkePayloadExist(payloadList, IkePayload.PAYLOAD_TYPE_KE));
+        assertTrue(isIkePayloadExist(payloadList, IkePayload.PAYLOAD_TYPE_NONCE));
+        // The session must be registered in its IkeSocket's SPI-to-session map.
+        IkeSocket ikeSocket = mIkeSessionStateMachine.mIkeSocket;
+        assertNotNull(ikeSocket);
+        assertNotEquals(
+                -1 /*not found*/, ikeSocket.mSpiToIkeSession.indexOfValue(mIkeSessionStateMachine));
+
         verify(mMockIkeMessageHelper)
                 .decode(dummyReceivedIkePacket.ikeHeader, dummyReceivedIkePacket.ikePacketBytes);
         verify(mMockIkeMessageHelper).getMessageType(any());
+
         assertTrue(
                 mIkeSessionStateMachine.getCurrentState()
                         instanceof IkeSessionStateMachine.CreateIkeLocalIkeAuth);
     }
 
     private void mockIkeSetup() throws Exception {
+        if (Looper.myLooper() == null) Looper.myLooper().prepare();
         // Mock IKE_INIT response
         ReceivedIkePacket dummyIkeInitRespReceivedPacket =
                 makeDummyUnencryptedReceivedIkePacket(IkeMessage.MESSAGE_TYPE_IKE_INIT_RESP);
diff --git a/tests/iketests/src/java/com/android/ike/ikev2/SaProposalTest.java b/tests/iketests/src/java/com/android/ike/ikev2/SaProposalTest.java
index 428c028..7f40d72 100644
--- a/tests/iketests/src/java/com/android/ike/ikev2/SaProposalTest.java
+++ b/tests/iketests/src/java/com/android/ike/ikev2/SaProposalTest.java
@@ -66,16 +66,17 @@
                         .addDhGroup(SaProposal.DH_GROUP_1024_BIT_MODP)
                         .build();
 
-        assertEquals(IkePayload.PROTOCOL_ID_IKE, proposal.mProtocolId);
+        assertEquals(IkePayload.PROTOCOL_ID_IKE, proposal.getProtocolId());
         assertArrayEquals(
                 new EncryptionTransform[] {mEncryption3DesTransform},
-                proposal.mEncryptionAlgorithms);
+                proposal.getEncryptionTransforms());
         assertArrayEquals(
                 new IntegrityTransform[] {mIntegrityHmacSha1Transform},
-                proposal.mIntegrityAlgorithms);
+                proposal.getIntegrityTransforms());
         assertArrayEquals(
-                new PrfTransform[] {mPrfAes128XCbcTransform}, proposal.mPseudorandomFunctions);
-        assertArrayEquals(new DhGroupTransform[] {mDhGroup1024Transform}, proposal.mDhGroups);
+                new PrfTransform[] {mPrfAes128XCbcTransform}, proposal.getPrfTransforms());
+        assertArrayEquals(
+                new DhGroupTransform[] {mDhGroup1024Transform}, proposal.getDhGroupTransforms());
     }
 
     @Test
@@ -89,14 +90,15 @@
                         .addDhGroup(SaProposal.DH_GROUP_1024_BIT_MODP)
                         .build();
 
-        assertEquals(IkePayload.PROTOCOL_ID_IKE, proposal.mProtocolId);
+        assertEquals(IkePayload.PROTOCOL_ID_IKE, proposal.getProtocolId());
         assertArrayEquals(
                 new EncryptionTransform[] {mEncryptionAesGcm8Transform},
-                proposal.mEncryptionAlgorithms);
+                proposal.getEncryptionTransforms());
         assertArrayEquals(
-                new PrfTransform[] {mPrfAes128XCbcTransform}, proposal.mPseudorandomFunctions);
-        assertArrayEquals(new DhGroupTransform[] {mDhGroup1024Transform}, proposal.mDhGroups);
-        assertTrue(proposal.mIntegrityAlgorithms.length == 0);
+                new PrfTransform[] {mPrfAes128XCbcTransform}, proposal.getPrfTransforms());
+        assertArrayEquals(
+                new DhGroupTransform[] {mDhGroup1024Transform}, proposal.getDhGroupTransforms());
+        assertTrue(proposal.getIntegrityTransforms().length == 0);
     }
 
     @Test
@@ -109,14 +111,15 @@
                         .addIntegrityAlgorithm(SaProposal.INTEGRITY_ALGORITHM_NONE)
                         .build();
 
-        assertEquals(IkePayload.PROTOCOL_ID_ESP, proposal.mProtocolId);
+        assertEquals(IkePayload.PROTOCOL_ID_ESP, proposal.getProtocolId());
         assertArrayEquals(
                 new EncryptionTransform[] {mEncryptionAesGcm8Transform},
-                proposal.mEncryptionAlgorithms);
+                proposal.getEncryptionTransforms());
         assertArrayEquals(
-                new IntegrityTransform[] {mIntegrityNoneTransform}, proposal.mIntegrityAlgorithms);
-        assertTrue(proposal.mPseudorandomFunctions.length == 0);
-        assertTrue(proposal.mDhGroups.length == 0);
+                new IntegrityTransform[] {mIntegrityNoneTransform},
+                proposal.getIntegrityTransforms());
+        assertTrue(proposal.getPrfTransforms().length == 0);
+        assertTrue(proposal.getDhGroupTransforms().length == 0);
     }
 
     @Test
@@ -129,14 +132,16 @@
                         .addDhGroup(SaProposal.DH_GROUP_1024_BIT_MODP)
                         .build();
 
-        assertEquals(IkePayload.PROTOCOL_ID_ESP, proposal.mProtocolId);
+        assertEquals(IkePayload.PROTOCOL_ID_ESP, proposal.getProtocolId());
         assertArrayEquals(
                 new EncryptionTransform[] {mEncryption3DesTransform},
-                proposal.mEncryptionAlgorithms);
+                proposal.getEncryptionTransforms());
         assertArrayEquals(
-                new IntegrityTransform[] {mIntegrityNoneTransform}, proposal.mIntegrityAlgorithms);
-        assertArrayEquals(new DhGroupTransform[] {mDhGroup1024Transform}, proposal.mDhGroups);
-        assertTrue(proposal.mPseudorandomFunctions.length == 0);
+                new IntegrityTransform[] {mIntegrityNoneTransform},
+                proposal.getIntegrityTransforms());
+        assertArrayEquals(
+                new DhGroupTransform[] {mDhGroup1024Transform}, proposal.getDhGroupTransforms());
+        assertTrue(proposal.getPrfTransforms().length == 0);
     }
 
     @Test
diff --git a/tests/iketests/src/java/com/android/ike/ikev2/message/IkeHeaderTest.java b/tests/iketests/src/java/com/android/ike/ikev2/message/IkeHeaderTest.java
index 1bc52c3..08b1612 100644
--- a/tests/iketests/src/java/com/android/ike/ikev2/message/IkeHeaderTest.java
+++ b/tests/iketests/src/java/com/android/ike/ikev2/message/IkeHeaderTest.java
@@ -62,6 +62,7 @@
 
     private static final int IKE_MSG_ID = 0;
     private static final int IKE_MSG_LENGTH = 336;
+    private static final int IKE_MSG_BODY_LENGTH = IKE_MSG_LENGTH - IkeHeader.IKE_HEADER_LENGTH;
 
     // Byte offsets of version field in IKE message header.
     private static final int VERSION_OFFSET = 17;
@@ -89,7 +90,7 @@
         assertFalse(header.isResponseMsg);
         assertTrue(header.fromIkeInitiator);
         assertEquals(IKE_MSG_ID, header.messageId);
-        assertEquals(IKE_MSG_LENGTH, header.messageLength);
+        assertEquals(IKE_MSG_LENGTH, header.getInboundMessageLength());
     }
 
     @Test
@@ -142,7 +143,7 @@
         IkeHeader header = new IkeHeader(inputPacket);
 
         ByteBuffer byteBuffer = ByteBuffer.allocate(IkeHeader.IKE_HEADER_LENGTH);
-        header.encodeToByteBuffer(byteBuffer);
+        header.encodeToByteBuffer(byteBuffer, IKE_MSG_BODY_LENGTH);
 
         byte[] expectedPacket = TestUtils.hexStringToByteArray(IKE_HEADER_HEX_STRING);
         assertArrayEquals(expectedPacket, byteBuffer.array());
diff --git a/tests/iketests/src/java/com/android/ike/ikev2/message/IkeKePayloadTest.java b/tests/iketests/src/java/com/android/ike/ikev2/message/IkeKePayloadTest.java
index 4f45f26..1bb0b70 100644
--- a/tests/iketests/src/java/com/android/ike/ikev2/message/IkeKePayloadTest.java
+++ b/tests/iketests/src/java/com/android/ike/ikev2/message/IkeKePayloadTest.java
@@ -18,11 +18,10 @@
 
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
-import android.util.Pair;
-
 import com.android.ike.ikev2.IkeDhParams;
 import com.android.ike.ikev2.SaProposal;
 import com.android.ike.ikev2.exceptions.InvalidSyntaxException;
@@ -100,6 +99,7 @@
 
         IkeKePayload payload = new IkeKePayload(CRITICAL_BIT, inputPacket);
 
+        assertFalse(payload.isOutbound);
         assertEquals(EXPECTED_DH_GROUP, payload.dhGroup);
 
         byte[] keyExchangeData = TestUtils.hexStringToByteArray(KEY_EXCHANGE_DATA_RAW_PACKET);
@@ -138,11 +138,11 @@
 
     @Test
     public void testGetIkeKePayload() throws Exception {
-        Pair<DHPrivateKeySpec, IkeKePayload> pair =
-                IkeKePayload.getKePayload(SaProposal.DH_GROUP_1024_BIT_MODP);
+        IkeKePayload payload = new IkeKePayload(SaProposal.DH_GROUP_1024_BIT_MODP);
 
         // Test DHPrivateKeySpec
-        DHPrivateKeySpec privateKeySpec = pair.first;
+        assertTrue(payload.isOutbound); // A payload built from a DH group is outbound.
+        DHPrivateKeySpec privateKeySpec = payload.localPrivateKey;
 
         BigInteger primeValue = privateKeySpec.getP();
         BigInteger expectedPrimeValue = new BigInteger(IkeDhParams.PRIME_1024_BIT_MODP, 16);
@@ -153,8 +153,6 @@
         assertEquals(0, expectedGenValue.compareTo(genValue));
 
         // Test IkeKePayload
-        IkeKePayload payload = pair.second;
-
         assertEquals(EXPECTED_DH_GROUP, payload.dhGroup);
         assertEquals(EXPECTED_KE_DATA_LEN, payload.keyExchangeData.length);
     }