release-request-94bbded7-c270-40fa-9a74-fedecfd046c6-for-git_oc-release-4034177 snap-temp-L73200000066785187

Change-Id: Iae85c227f94a00ac2180cc1f2286a89c55ba48d4
diff --git a/common/src/jni/main/cpp/NativeCrypto.cpp b/common/src/jni/main/cpp/NativeCrypto.cpp
index afe06a8..cbb0f23 100644
--- a/common/src/jni/main/cpp/NativeCrypto.cpp
+++ b/common/src/jni/main/cpp/NativeCrypto.cpp
@@ -8636,6 +8636,11 @@
     if (bio == nullptr) {
         return -1;
     }
+    if (len < 0 || BIO_ctrl_get_write_guarantee(bio) < static_cast<size_t>(len)) {
+        // Write nothing (return 0) if len is negative or if the network BIO can't accept all
+        // len bytes, so that we only process one packet at a time.
+        return 0;
+    }
     const char* sourcePtr = reinterpret_cast<const char*>(address);
 
     AppData* appData = toAppData(ssl);
@@ -8684,6 +8689,11 @@
     if (bio == nullptr) {
         return -1;
     }
+    if (sourceLength < 0 || BIO_ctrl_get_write_guarantee(bio) < static_cast<size_t>(sourceLength)) {
+        // Write nothing (return 0) if sourceLength is negative or if the network BIO can't
+        // accept all sourceLength bytes, so that we only process one packet at a time.
+        return 0;
+    }
     ScopedByteArrayRO source(env, sourceJava);
     if (source.get() == nullptr) {
         JNI_TRACE("ssl=%p NativeCrypto_ENGINE_SSL_write_BIO_heap => threw exception", ssl);
diff --git a/common/src/main/java/org/conscrypt/OpenSSLEngineImpl.java b/common/src/main/java/org/conscrypt/OpenSSLEngineImpl.java
index cd46f9a..86245be 100644
--- a/common/src/main/java/org/conscrypt/OpenSSLEngineImpl.java
+++ b/common/src/main/java/org/conscrypt/OpenSSLEngineImpl.java
@@ -267,8 +267,8 @@
 
         synchronized (stateLock) {
             if (isHandshakeStarted()) {
-                throw new IllegalStateException(
-                        "Could not change Channel ID private key after the initial handshake has begun.");
+                throw new IllegalStateException("Could not change Channel ID private key "
+                        + "after the initial handshake has begun.");
             }
 
             if (privateKey == null) {
@@ -491,12 +491,6 @@
         return NativeCrypto.SSL_pending_readable_bytes(sslNativePointer);
     }
 
-    private int pendingInboundCleartextBytes(HandshakeStatus handshakeStatus) {
-        // There won't be any application data until we're done handshaking.
-        // We first check handshakeFinished to eliminate the overhead of extra JNI call if possible.
-        return handshakeStatus == HandshakeStatus.FINISHED ? pendingInboundCleartextBytes() : 0;
-    }
-
     private static SSLEngineResult.HandshakeStatus pendingStatus(int pendingOutboundBytes) {
         // Depending on if there is something left in the BIO we need to WRAP or UNWRAP
         return pendingOutboundBytes > 0 ? NEED_WRAP : NEED_UNWRAP;
@@ -646,28 +640,11 @@
         checkIndex(dsts.length, dstsOffset, dstsLength, "dsts");
 
         // Determine the output capacity.
-        int capacity = 0;
+        final int dstLength = calcDstsLength(dsts, dstsOffset, dstsLength);
         final int endOffset = dstsOffset + dstsLength;
-        for (int i = 0; i < dsts.length; i++) {
-            ByteBuffer dst = dsts[i];
-            checkNotNull(dst, "one of the dst");
-            if (dst.isReadOnly()) {
-                throw new ReadOnlyBufferException();
-            }
-            if (i >= dstsOffset && i < dstsOffset + dstsLength) {
-                capacity += dst.remaining();
-            }
-        }
 
         final int srcsEndOffset = srcsOffset + srcsLength;
-        long len = 0;
-        for (int i = srcsOffset; i < srcsEndOffset; i++) {
-            ByteBuffer src = srcs[i];
-            if (src == null) {
-                throw new IllegalArgumentException("srcs[" + i + "] is null");
-            }
-            len += src.remaining();
-        }
+        final long srcLength = calcSrcsLength(srcs, srcsOffset, srcsEndOffset);
 
         synchronized (stateLock) {
             switch (engineState) {
@@ -698,41 +675,51 @@
                 // NEED_UNWRAP - just fall through to perform the unwrap.
             }
 
-            if (len < SSL3_RT_HEADER_LENGTH) {
+            // Consume any source data. Skip this if there is unread cleartext data.
+            boolean noCleartextDataAvailable = pendingInboundCleartextBytes() <= 0;
+            int lenRemaining = 0;
+            if (srcLength > 0 && noCleartextDataAvailable) {
+                if (srcLength < SSL3_RT_HEADER_LENGTH) {
+                    // Need to be able to read a full TLS header.
+                    return new SSLEngineResult(BUFFER_UNDERFLOW, getHandshakeStatus(), 0, 0);
+                }
+
+                int packetLength = SSLUtils.getEncryptedPacketLength(srcs, srcsOffset);
+                if (packetLength < 0) {
+                    throw new SSLException("Unable to parse TLS packet header");
+                }
+
+                if (srcLength < packetLength) {
+                    // We have a full packet header but not enough data to read the whole
+                    // packet.
+                    return new SSLEngineResult(BUFFER_UNDERFLOW, getHandshakeStatus(), 0, 0);
+                }
+
+                // Limit the amount of data to be read to a single packet.
+                lenRemaining = packetLength;
+            } else if (noCleartextDataAvailable) {
+                // No pending data and nothing provided as input. Need more data.
                 return new SSLEngineResult(BUFFER_UNDERFLOW, getHandshakeStatus(), 0, 0);
             }
 
-            int packetLength = SSLUtils.getEncryptedPacketLength(srcs, srcsOffset);
-            if (packetLength < 0) {
-                throw new SSLException("Unable to parse TLS packet header");
-            }
-
-            if (len < packetLength) {
-                // We either have not enough data to read the packet header or not enough for
-                // reading the whole packet.
-                return new SSLEngineResult(BUFFER_UNDERFLOW, getHandshakeStatus(), 0, 0);
-            }
-
-            // Write all of the source data to the networkBio
+            // Write at most one packet (lenRemaining bytes) of encrypted data to the networkBio.
             int bytesConsumed = 0;
-            if (srcsOffset < srcsEndOffset) {
-                int packetLengthRemaining = packetLength;
+            if (lenRemaining > 0 && srcsOffset < srcsEndOffset) {
                 do {
                     ByteBuffer src = srcs[srcsOffset];
                     int remaining = src.remaining();
                     if (remaining == 0) {
-                        // We must skip empty buffers as BIO_write will return 0 if asked to write
-                        // something
-                        // with length 0.
+                        // We must skip empty buffers as BIO_write will return 0 if asked to
+                        // write something with length 0.
                         srcsOffset++;
                         continue;
                     }
                     // Write the source encrypted data to the networkBio.
-                    int written =
-                            writeEncryptedData(src, Math.min(packetLengthRemaining, remaining));
+                    int written = writeEncryptedData(src, Math.min(lenRemaining, remaining));
                     if (written > 0) {
-                        packetLengthRemaining -= written;
-                        if (packetLengthRemaining == 0) {
+                        bytesConsumed += written;
+                        lenRemaining -= written;
+                        if (lenRemaining == 0) {
                             // A whole packet has been consumed.
                             break;
                         }
@@ -740,30 +727,26 @@
                         if (written == remaining) {
                             srcsOffset++;
                         } else {
-                            // We were not able to write everything into the BIO so break the write
-                            // loop as otherwise
-                            // we will produce an error on the next write attempt, which will
-                            // trigger a SSL.clearError()
-                            // later.
+                            // We were not able to write everything into the BIO, so stop here;
+                            // otherwise the next write attempt would fail and end up in the
+                            // NativeCrypto.SSL_clear_error() branch below.
                             break;
                         }
                     } else {
                         // BIO_write returned a negative or zero number, this means we could not
-                        // complete the write
-                        // operation and should retry later.
-                        // We ignore BIO_* errors here as we use in memory BIO anyway and will do
-                        // another SSL_* call
-                        // later on in which we will produce an exception in case of an error
+                        // complete the write operation and should retry later.
+                        // We ignore BIO_* errors here as we use in memory BIO anyway and will
+                        // do another SSL_* call later on in which we will produce an exception
+                        // in case of an error.
                         NativeCrypto.SSL_clear_error();
                         break;
                     }
                 } while (srcsOffset < srcsEndOffset);
-                bytesConsumed = packetLength - packetLengthRemaining;
             }
 
             // Now read any available plaintext data.
             int bytesProduced = 0;
-            if (capacity > 0) {
+            if (dstLength > 0) {
                 // Write decrypted data to dsts buffers
                 for (int idx = dstsOffset; idx < endOffset; ++idx) {
                     ByteBuffer dst = dsts[idx];
@@ -772,38 +755,35 @@
                     }
 
                     int bytesRead = readPlaintextData(dst);
-
                     if (bytesRead > 0) {
                         bytesProduced += bytesRead;
-                        if (!dst.hasRemaining()) {
-                            continue;
+                        if (dst.hasRemaining()) {
+                            // We haven't filled this buffer fully, break out of the loop
+                            // and determine the correct response status below.
+                            break;
                         }
-
-                        // We read everything return now.
-                        return newResult(bytesConsumed, bytesProduced, handshakeStatus);
-                    }
-
-                    // Return an appropriate result based on the error code.
-                    int sslError = NativeCrypto.SSL_get_error(sslNativePointer, bytesRead);
-                    switch (sslError) {
-                        case SSL_ERROR_ZERO_RETURN:
-                            // This means the connection was shutdown correctly, close inbound and
-                            // outbound
-                            closeAll();
-                            return newResult(bytesConsumed, bytesProduced, handshakeStatus);
-                        case SSL_ERROR_WANT_READ:
-                        case SSL_ERROR_WANT_WRITE:
-                            return newResult(bytesConsumed, bytesProduced, handshakeStatus);
-                        default:
-                            return sslReadErrorResult(NativeCrypto.SSL_get_last_error_number(),
-                                    bytesConsumed, bytesProduced);
+                    } else {
+                        // Return an appropriate result based on the error code.
+                        int sslError = NativeCrypto.SSL_get_error(sslNativePointer, bytesRead);
+                        switch (sslError) {
+                            case SSL_ERROR_ZERO_RETURN:
+                                // This means the connection was shut down correctly, close inbound
+                                // and outbound
+                                closeAll();
+                                return newResult(bytesConsumed, bytesProduced, handshakeStatus);
+                            case SSL_ERROR_WANT_READ:
+                            case SSL_ERROR_WANT_WRITE:
+                                return newResult(bytesConsumed, bytesProduced, handshakeStatus);
+                            default:
+                                return sslReadErrorResult(NativeCrypto.SSL_get_last_error_number(),
+                                        bytesConsumed, bytesProduced);
+                        }
                     }
                 }
             } else {
                 // If the capacity of all destination buffers is 0 we need to trigger a SSL_read
-                // anyway to ensure
-                // everything is flushed in the BIO pair and so we can detect it in the
-                // pendingInboundCleartextBytes() call.
+                // anyway to ensure everything is flushed in the BIO pair and so we can detect it
+                // in the pendingInboundCleartextBytes() call.
                 try {
                     if (NativeCrypto.ENGINE_SSL_read_direct(sslNativePointer, EMPTY_ADDR, 0, this)
                             <= 0) {
@@ -818,7 +798,12 @@
                     throw new SSLException(e);
                 }
             }
-            if (pendingInboundCleartextBytes(handshakeStatus) > 0) {
+
+            // There won't be any application data until we're done handshaking.
+            // Check handshakeFinished first to avoid the overhead of an extra JNI call when
+            // possible.
+            int pendingCleartextBytes = handshakeFinished ? pendingInboundCleartextBytes() : 0;
+            if (pendingCleartextBytes > 0) {
                 // We filled all buffers but there is still some data pending in the BIO buffer,
                 // return BUFFER_OVERFLOW.
                 return new SSLEngineResult(BUFFER_OVERFLOW,
@@ -832,6 +817,52 @@
         }
     }
 
+    /**
+     * Validates every buffer in {@code dsts} and returns the total number of bytes remaining in
+     * the index range {@code [dstsOffset, dstsOffset + dstsLength)}.
+     *
+     * <p>Every buffer in the array is checked, not only those in range: each must be non-null
+     * (per {@code checkNotNull}) and must not be read-only. The sum is accumulated as a
+     * {@code long} and clamped to {@link Integer#MAX_VALUE}, so many large buffers cannot wrap
+     * around to a negative capacity.
+     *
+     * @throws ReadOnlyBufferException if any buffer in {@code dsts} is read-only
+     */
+    private static int calcDstsLength(ByteBuffer[] dsts, int dstsOffset, int dstsLength) {
+        final int endOffset = dstsOffset + dstsLength;
+        long capacity = 0;
+        for (int i = 0; i < dsts.length; i++) {
+            ByteBuffer dst = dsts[i];
+            checkNotNull(dst, "one of the dst");
+            if (dst.isReadOnly()) {
+                throw new ReadOnlyBufferException();
+            }
+            if (i >= dstsOffset && i < endOffset) {
+                capacity += dst.remaining();
+            }
+        }
+        // Clamp instead of letting an int wrap: a negative sum would look like "no capacity".
+        return (int) Math.min(capacity, Integer.MAX_VALUE);
+    }
+
+    /**
+     * Returns the total number of bytes remaining in {@code srcs} over the index range
+     * {@code [srcsOffset, srcsEndOffset)}.
+     *
+     * @throws IllegalArgumentException if any buffer in that range is {@code null}
+     */
+    private static long calcSrcsLength(ByteBuffer[] srcs, int srcsOffset, int srcsEndOffset) {
+        long len = 0;
+        for (int i = srcsOffset; i < srcsEndOffset; i++) {
+            ByteBuffer src = srcs[i];
+            if (src == null) {
+                throw new IllegalArgumentException("srcs[" + i + "] is null");
+            }
+            len += src.remaining();
+        }
+        return len;
+    }
+
     private SSLEngineResult.HandshakeStatus handshake() throws SSLException {
         long sslSessionCtx = 0L;
         try {
diff --git a/openjdk/src/test/java/org/conscrypt/OpenSSLEngineImplTest.java b/openjdk/src/test/java/org/conscrypt/OpenSSLEngineImplTest.java
index af62889..39ec165 100644
--- a/openjdk/src/test/java/org/conscrypt/OpenSSLEngineImplTest.java
+++ b/openjdk/src/test/java/org/conscrypt/OpenSSLEngineImplTest.java
@@ -23,13 +23,17 @@
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 
+import java.io.ByteArrayOutputStream;
 import java.nio.ByteBuffer;
 import java.security.NoSuchAlgorithmException;
+import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.List;
 import javax.net.ssl.SSLContext;
 import javax.net.ssl.SSLEngine;
 import javax.net.ssl.SSLEngineResult;
 import javax.net.ssl.SSLEngineResult.HandshakeStatus;
+import javax.net.ssl.SSLEngineResult.Status;
 import javax.net.ssl.SSLException;
 import javax.net.ssl.SSLHandshakeException;
 import libcore.java.security.TestKeyStore;
@@ -183,6 +187,81 @@
         }
     }
 
+    /**
+     * Sends a message slightly larger than 2^14 bytes (the maximum TLS record plaintext size),
+     * so the client must wrap it into more than one record, then unwraps every record on the
+     * server into an 8 KiB destination buffer. This exercises BUFFER_OVERFLOW handling and
+     * unwrapping from a multi-buffer source, and checks the reassembled cleartext is intact.
+     */
+    @Test
+    public void exchangeLargeMessage() throws Exception {
+        setupEngines(TestKeyStore.getClient(), TestKeyStore.getServer());
+        TestUtil.doEngineHandshake(clientEngine, serverEngine);
+
+        // Create the input message.
+        final int largeMessageSize = 16413;
+        final byte[] message = newTextMessage(largeMessageSize);
+        ByteBuffer inputBuffer = bufferType.newBuffer(largeMessageSize);
+        inputBuffer.put(message);
+        inputBuffer.flip();
+
+        // Encrypt the input message.
+        List<ByteBuffer> encryptedBufferList = new ArrayList<ByteBuffer>();
+        while (inputBuffer.hasRemaining()) {
+            ByteBuffer encryptedBuffer =
+                    bufferType.newBuffer(clientEngine.getSession().getPacketBufferSize());
+            SSLEngineResult wrapResult = clientEngine.wrap(inputBuffer, encryptedBuffer);
+            assertEquals(Status.OK, wrapResult.getStatus());
+            encryptedBuffer.flip();
+            encryptedBufferList.add(encryptedBuffer);
+        }
+
+        // Unwrap all of the encrypted messages.
+        ByteArrayOutputStream cleartextStream = new ByteArrayOutputStream();
+        ByteBuffer[] encryptedBuffers =
+                encryptedBufferList.toArray(new ByteBuffer[encryptedBufferList.size()]);
+        int decryptedBufferSize = 8192;
+        final ByteBuffer decryptedBuffer = bufferType.newBuffer(decryptedBufferSize);
+        for (ByteBuffer encryptedBuffer : encryptedBuffers) {
+            Status status = Status.OK;
+            while (encryptedBuffer.hasRemaining() || status == Status.BUFFER_OVERFLOW) {
+                if (!decryptedBuffer.hasRemaining()) {
+                    decryptedBuffer.clear();
+                }
+                int prevPos = decryptedBuffer.position();
+                SSLEngineResult unwrapResult = Conscrypt.Engines.unwrap(serverEngine,
+                        encryptedBuffers, new ByteBuffer[]{decryptedBuffer});
+                status = unwrapResult.getStatus();
+                int newPos = decryptedBuffer.position();
+                int bytesProduced = unwrapResult.bytesProduced();
+                assertEquals(bytesProduced, newPos - prevPos);
+                if (unwrapResult.bytesConsumed() == 0 && bytesProduced == 0) {
+                    // Fail fast instead of spinning forever if unwrap stops making progress.
+                    throw new AssertionError("unwrap made no progress: " + unwrapResult);
+                }
+
+                // Add any generated bytes to the output stream.
+                if (bytesProduced > 0) {
+                    byte[] decryptedBytes = new byte[bytesProduced];
+
+                    // Read back the chunk that was just written to the decrypted buffer.
+                    int limit = decryptedBuffer.limit();
+                    decryptedBuffer.limit(newPos);
+                    decryptedBuffer.position(prevPos);
+                    decryptedBuffer.get(decryptedBytes);
+
+                    // get() advanced the position back to newPos; now restore the limit.
+                    decryptedBuffer.limit(limit);
+
+                    // Write the decrypted bytes to the stream.
+                    cleartextStream.write(decryptedBytes);
+                }
+            }
+        }
+        byte[] actualMessage = cleartextStream.toByteArray();
+        assertArrayEquals(message, actualMessage);
+    }
+
     private void doMutualAuthHandshake(TestKeyStore clientKs, TestKeyStore serverKs, ClientAuth clientAuth) throws Exception {
         setupEngines(clientKs, serverKs);
         clientAuth.apply(serverEngine);