Fix unwrap bug for large messages. (#189)

If you write a record and don't have enough destination buffer space to read all the plaintext, the plaintext gets left in the plaintext buffer and the next record you write ends up in the ciphertext buffer (and you read the leftover plaintext from the last record), and you continue to have a record sitting in the ciphertext buffer until you get two records that don't fit in the buffer together, at which point you get the short write and subsequent exception.

Also added a test to verify the bug.
diff --git a/common/src/jni/main/cpp/NativeCrypto.cpp b/common/src/jni/main/cpp/NativeCrypto.cpp
index dd754f7..a18ae5c 100644
--- a/common/src/jni/main/cpp/NativeCrypto.cpp
+++ b/common/src/jni/main/cpp/NativeCrypto.cpp
@@ -8599,6 +8599,11 @@
     if (bio == nullptr) {
         return -1;
     }
+    if (BIO_ctrl_get_write_guarantee(bio) < len) {
+        // The network BIO couldn't handle the entire write. Don't write anything, so that we
+        // only process one packet at a time.
+        return 0;
+    }
     const char* sourcePtr = reinterpret_cast<const char*>(address);
 
     AppData* appData = toAppData(ssl);
@@ -8647,6 +8652,11 @@
     if (bio == nullptr) {
         return -1;
     }
+    if (BIO_ctrl_get_write_guarantee(bio) < sourceLength) {
+        // The network BIO couldn't handle the entire write. Don't write anything, so that we
+        // only process one packet at a time.
+        return 0;
+    }
     ScopedByteArrayRO source(env, sourceJava);
     if (source.get() == nullptr) {
         JNI_TRACE("ssl=%p NativeCrypto_ENGINE_SSL_write_BIO_heap => threw exception", ssl);
diff --git a/common/src/main/java/org/conscrypt/OpenSSLEngineImpl.java b/common/src/main/java/org/conscrypt/OpenSSLEngineImpl.java
index aca9504..c42a0ff 100644
--- a/common/src/main/java/org/conscrypt/OpenSSLEngineImpl.java
+++ b/common/src/main/java/org/conscrypt/OpenSSLEngineImpl.java
@@ -267,8 +267,8 @@
 
         synchronized (stateLock) {
             if (isHandshakeStarted()) {
-                throw new IllegalStateException(
-                        "Could not change Channel ID private key after the initial handshake has begun.");
+                throw new IllegalStateException("Could not change Channel ID private key "
+                        + "after the initial handshake has begun.");
             }
 
             if (privateKey == null) {
@@ -491,12 +491,6 @@
         return NativeCrypto.SSL_pending_readable_bytes(sslNativePointer);
     }
 
-    private int pendingInboundCleartextBytes(HandshakeStatus handshakeStatus) {
-        // There won't be any application data until we're done handshaking.
-        // We first check handshakeFinished to eliminate the overhead of extra JNI call if possible.
-        return handshakeStatus == HandshakeStatus.FINISHED ? pendingInboundCleartextBytes() : 0;
-    }
-
     private static SSLEngineResult.HandshakeStatus pendingStatus(int pendingOutboundBytes) {
         // Depending on if there is something left in the BIO we need to WRAP or UNWRAP
         return pendingOutboundBytes > 0 ? NEED_WRAP : NEED_UNWRAP;
@@ -646,28 +640,11 @@
         checkIndex(dsts.length, dstsOffset, dstsLength, "dsts");
 
         // Determine the output capacity.
-        int capacity = 0;
+        final int dstLength = calcDstsLength(dsts, dstsOffset, dstsLength);
         final int endOffset = dstsOffset + dstsLength;
-        for (int i = 0; i < dsts.length; i++) {
-            ByteBuffer dst = dsts[i];
-            checkNotNull(dst, "one of the dst");
-            if (dst.isReadOnly()) {
-                throw new ReadOnlyBufferException();
-            }
-            if (i >= dstsOffset && i < dstsOffset + dstsLength) {
-                capacity += dst.remaining();
-            }
-        }
 
         final int srcsEndOffset = srcsOffset + srcsLength;
-        long len = 0;
-        for (int i = srcsOffset; i < srcsEndOffset; i++) {
-            ByteBuffer src = srcs[i];
-            if (src == null) {
-                throw new IllegalArgumentException("srcs[" + i + "] is null");
-            }
-            len += src.remaining();
-        }
+        final long srcLength = calcSrcsLength(srcs, srcsOffset, srcsEndOffset);
 
         synchronized (stateLock) {
             switch (engineState) {
@@ -698,41 +675,51 @@
                 // NEED_UNWRAP - just fall through to perform the unwrap.
             }
 
-            if (len < SSL3_RT_HEADER_LENGTH) {
+            // Consume any source data. Skip this if there are unread cleartext data.
+            boolean noCleartextDataAvailable = pendingInboundCleartextBytes() <= 0;
+            int lenRemaining = 0;
+            if (srcLength > 0 && noCleartextDataAvailable) {
+                if (srcLength < SSL3_RT_HEADER_LENGTH) {
+                    // Need to be able to read a full TLS header.
+                    return new SSLEngineResult(BUFFER_UNDERFLOW, getHandshakeStatus(), 0, 0);
+                }
+
+                int packetLength = SSLUtils.getEncryptedPacketLength(srcs, srcsOffset);
+                if (packetLength < 0) {
+                    throw new SSLException("Unable to parse TLS packet header");
+                }
+
+                if (srcLength < packetLength) {
+                    // We either have not enough data to read the packet header or not enough for
+                    // reading the whole packet.
+                    return new SSLEngineResult(BUFFER_UNDERFLOW, getHandshakeStatus(), 0, 0);
+                }
+
+                // Limit the amount of data to be read to a single packet.
+                lenRemaining = packetLength;
+            } else if (noCleartextDataAvailable) {
+                // No pending data and nothing provided as input.  Need more data.
                 return new SSLEngineResult(BUFFER_UNDERFLOW, getHandshakeStatus(), 0, 0);
             }
 
-            int packetLength = SSLUtils.getEncryptedPacketLength(srcs, srcsOffset);
-            if (packetLength < 0) {
-                throw new SSLException("Unable to parse TLS packet header");
-            }
-
-            if (len < packetLength) {
-                // We either have not enough data to read the packet header or not enough for
-                // reading the whole packet.
-                return new SSLEngineResult(BUFFER_UNDERFLOW, getHandshakeStatus(), 0, 0);
-            }
-
-            // Write all of the source data to the networkBio
+            // Write all of the encrypted source data to the networkBio
             int bytesConsumed = 0;
-            if (srcsOffset < srcsEndOffset) {
-                int packetLengthRemaining = packetLength;
+            if (lenRemaining > 0 && srcsOffset < srcsEndOffset) {
                 do {
                     ByteBuffer src = srcs[srcsOffset];
                     int remaining = src.remaining();
                     if (remaining == 0) {
-                        // We must skip empty buffers as BIO_write will return 0 if asked to write
-                        // something
-                        // with length 0.
+                        // We must skip empty buffers as BIO_write will return 0 if asked to
+                        // write something with length 0.
                         srcsOffset++;
                         continue;
                     }
                     // Write the source encrypted data to the networkBio.
-                    int written =
-                            writeEncryptedData(src, Math.min(packetLengthRemaining, remaining));
+                    int written = writeEncryptedData(src, Math.min(lenRemaining, remaining));
                     if (written > 0) {
-                        packetLengthRemaining -= written;
-                        if (packetLengthRemaining == 0) {
+                        bytesConsumed += written;
+                        lenRemaining -= written;
+                        if (lenRemaining == 0) {
                             // A whole packet has been consumed.
                             break;
                         }
@@ -740,30 +727,26 @@
                         if (written == remaining) {
                             srcsOffset++;
                         } else {
-                            // We were not able to write everything into the BIO so break the write
-                            // loop as otherwise
-                            // we will produce an error on the next write attempt, which will
-                            // trigger a SSL.clearError()
-                            // later.
+                            // We were not able to write everything into the BIO so break the
+                            // write loop as otherwise we will produce an error on the next
+                            // write attempt, which will trigger a SSL.clearError() later.
                             break;
                         }
                     } else {
                         // BIO_write returned a negative or zero number, this means we could not
-                        // complete the write
-                        // operation and should retry later.
-                        // We ignore BIO_* errors here as we use in memory BIO anyway and will do
-                        // another SSL_* call
-                        // later on in which we will produce an exception in case of an error
+                        // complete the write operation and should retry later.
+                        // We ignore BIO_* errors here as we use in memory BIO anyway and will
+                        // do another SSL_* call later on in which we will produce an exception
+                        // in case of an error
                         NativeCrypto.SSL_clear_error();
                         break;
                     }
                 } while (srcsOffset < srcsEndOffset);
-                bytesConsumed = packetLength - packetLengthRemaining;
             }
 
             // Now read any available plaintext data.
             int bytesProduced = 0;
-            if (capacity > 0) {
+            if (dstLength > 0) {
                 // Write decrypted data to dsts buffers
                 for (int idx = dstsOffset; idx < endOffset; ++idx) {
                     ByteBuffer dst = dsts[idx];
@@ -772,38 +755,35 @@
                     }
 
                     int bytesRead = readPlaintextData(dst);
-
                     if (bytesRead > 0) {
                         bytesProduced += bytesRead;
-                        if (!dst.hasRemaining()) {
-                            continue;
+                        if (dst.hasRemaining()) {
+                            // We haven't filled this buffer fully, break out of the loop
+                            // and determine the correct response status below.
+                            break;
                         }
-
-                        // We read everything return now.
-                        return newResult(bytesConsumed, bytesProduced, handshakeStatus);
-                    }
-
-                    // Return an appropriate result based on the error code.
-                    int sslError = NativeCrypto.SSL_get_error(sslNativePointer, bytesRead);
-                    switch (sslError) {
-                        case SSL_ERROR_ZERO_RETURN:
-                            // This means the connection was shutdown correctly, close inbound and
-                            // outbound
-                            closeAll();
-                            return newResult(bytesConsumed, bytesProduced, handshakeStatus);
-                        case SSL_ERROR_WANT_READ:
-                        case SSL_ERROR_WANT_WRITE:
-                            return newResult(bytesConsumed, bytesProduced, handshakeStatus);
-                        default:
-                            return sslReadErrorResult(NativeCrypto.SSL_get_last_error_number(),
-                                    bytesConsumed, bytesProduced);
+                    } else {
+                        // Return an appropriate result based on the error code.
+                        int sslError = NativeCrypto.SSL_get_error(sslNativePointer, bytesRead);
+                        switch (sslError) {
+                            case SSL_ERROR_ZERO_RETURN:
+                                // This means the connection was shutdown correctly, close inbound
+                                // and outbound
+                                closeAll();
+                                return newResult(bytesConsumed, bytesProduced, handshakeStatus);
+                            case SSL_ERROR_WANT_READ:
+                            case SSL_ERROR_WANT_WRITE:
+                                return newResult(bytesConsumed, bytesProduced, handshakeStatus);
+                            default:
+                                return sslReadErrorResult(NativeCrypto.SSL_get_last_error_number(),
+                                        bytesConsumed, bytesProduced);
+                        }
                     }
                 }
             } else {
                 // If the capacity of all destination buffers is 0 we need to trigger a SSL_read
-                // anyway to ensure
-                // everything is flushed in the BIO pair and so we can detect it in the
-                // pendingInboundCleartextBytes() call.
+                // anyway to ensure everything is flushed in the BIO pair and so we can detect it
+                // in the pendingInboundCleartextBytes() call.
                 try {
                     if (NativeCrypto.ENGINE_SSL_read_direct(sslNativePointer, EMPTY_ADDR, 0, this)
                             <= 0) {
@@ -818,7 +798,12 @@
                     throw new SSLException(e);
                 }
             }
-            if (pendingInboundCleartextBytes(handshakeStatus) > 0) {
+
+            // There won't be any application data until we're done handshaking.
+            // We first check handshakeFinished to eliminate the overhead of extra JNI call if
+            // possible.
+            int pendingCleartextBytes = handshakeFinished ? pendingInboundCleartextBytes() : 0;
+            if (pendingCleartextBytes > 0) {
                 // We filled all buffers but there is still some data pending in the BIO buffer,
                 // return BUFFER_OVERFLOW.
                 return new SSLEngineResult(BUFFER_OVERFLOW,
@@ -832,6 +817,33 @@
         }
     }
 
+    private static int calcDstsLength(ByteBuffer[] dsts, int dstsOffset, int dstsLength) {
+        int capacity = 0;
+        for (int i = 0; i < dsts.length; i++) {
+            ByteBuffer dst = dsts[i];
+            checkNotNull(dst, "one of the dst");
+            if (dst.isReadOnly()) {
+                throw new ReadOnlyBufferException();
+            }
+            if (i >= dstsOffset && i < dstsOffset + dstsLength) {
+                capacity += dst.remaining();
+            }
+        }
+        return capacity;
+    }
+
+    private static long calcSrcsLength(ByteBuffer[] srcs, int srcsOffset, int srcsEndOffset) {
+        long len = 0;
+        for (int i = srcsOffset; i < srcsEndOffset; i++) {
+            ByteBuffer src = srcs[i];
+            if (src == null) {
+                throw new IllegalArgumentException("srcs[" + i + "] is null");
+            }
+            len += src.remaining();
+        }
+        return len;
+    }
+
     private SSLEngineResult.HandshakeStatus handshake() throws SSLException {
         long sslSessionCtx = 0L;
         try {
diff --git a/openjdk/src/test/java/org/conscrypt/OpenSSLEngineImplTest.java b/openjdk/src/test/java/org/conscrypt/OpenSSLEngineImplTest.java
index 188ee47..27df90e 100644
--- a/openjdk/src/test/java/org/conscrypt/OpenSSLEngineImplTest.java
+++ b/openjdk/src/test/java/org/conscrypt/OpenSSLEngineImplTest.java
@@ -23,13 +23,17 @@
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 
+import java.io.ByteArrayOutputStream;
 import java.nio.ByteBuffer;
 import java.security.NoSuchAlgorithmException;
+import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.List;
 import javax.net.ssl.SSLContext;
 import javax.net.ssl.SSLEngine;
 import javax.net.ssl.SSLEngineResult;
 import javax.net.ssl.SSLEngineResult.HandshakeStatus;
+import javax.net.ssl.SSLEngineResult.Status;
 import javax.net.ssl.SSLException;
 import javax.net.ssl.SSLHandshakeException;
 import libcore.java.security.TestKeyStore;
@@ -182,6 +186,69 @@
         }
     }
 
+    @Test
+    public void exchangeLargeMessage() throws Exception {
+        setupEngines(TestKeyStore.getClient(), TestKeyStore.getServer());
+        TestUtils.doEngineHandshake(clientEngine, serverEngine);
+
+        // Create the input message.
+        final int largeMessageSize = 16413;
+        final byte[] message = newTextMessage(largeMessageSize);
+        ByteBuffer inputBuffer = bufferType.newBuffer(largeMessageSize);
+        inputBuffer.put(message);
+        inputBuffer.flip();
+
+        // Encrypt the input message.
+        List<ByteBuffer> encryptedBufferList = new ArrayList<ByteBuffer>();
+        while(inputBuffer.hasRemaining()) {
+            ByteBuffer encryptedBuffer = bufferType.newBuffer(clientEngine.getSession().getPacketBufferSize());
+            SSLEngineResult wrapResult = clientEngine.wrap(inputBuffer, encryptedBuffer);
+            assertEquals(SSLEngineResult.Status.OK, wrapResult.getStatus());
+            encryptedBuffer.flip();
+            encryptedBufferList.add(encryptedBuffer);
+        }
+
+        // Unwrap the all of the encrypted messages.
+        ByteArrayOutputStream cleartextStream = new ByteArrayOutputStream();
+        ByteBuffer[] encryptedBuffers = encryptedBufferList.toArray(new ByteBuffer[encryptedBufferList.size()]);
+        int decryptedBufferSize = 8192;
+        final ByteBuffer decryptedBuffer = bufferType.newBuffer(decryptedBufferSize);
+        for (ByteBuffer encryptedBuffer : encryptedBuffers) {
+            SSLEngineResult.Status status = SSLEngineResult.Status.OK;
+            while (encryptedBuffer.hasRemaining() || status.equals(Status.BUFFER_OVERFLOW)) {
+                if (!decryptedBuffer.hasRemaining()) {
+                    decryptedBuffer.clear();
+                }
+                int prevPos = decryptedBuffer.position();
+                SSLEngineResult unwrapResult = Conscrypt.Engines.unwrap(serverEngine,
+                        encryptedBuffers, new ByteBuffer[]{decryptedBuffer});
+                status = unwrapResult.getStatus();
+                int newPos = decryptedBuffer.position();
+                int bytesProduced = unwrapResult.bytesProduced();
+                assertEquals(bytesProduced, newPos - prevPos);
+
+                // Add any generated bytes to the output stream.
+                if (bytesProduced > 0) {
+                    byte[] decryptedBytes = new byte[unwrapResult.bytesProduced()];
+
+                    // Read the chunk that was just written to the output array.
+                    int limit = decryptedBuffer.limit();
+                    decryptedBuffer.limit(newPos);
+                    decryptedBuffer.position(prevPos);
+                    decryptedBuffer.get(decryptedBytes);
+
+                    // Restore the position and limit.
+                    decryptedBuffer.limit(limit);
+
+                    // Write the decrypted bytes to the stream.
+                    cleartextStream.write(decryptedBytes);
+                }
+            }
+        }
+        byte[] actualMessage = cleartextStream.toByteArray();
+        assertArrayEquals(message, actualMessage);
+    }
+
     private void doMutualAuthHandshake(TestKeyStore clientKs, TestKeyStore serverKs, ClientAuth clientAuth) throws Exception {
         setupEngines(clientKs, serverKs);
         clientAuth.apply(serverEngine);