Merge Conscrypt upstream master.

Contains the following upstream changes:
    Tidy ConscryptEngineSocket state machine. (#1120)
    Add missing close() calls. (#1122)

Bug: 276304877
Test: MtsConscryptTestCases
Change-Id: Iadb7295b1fa0925dd8b966b48bca3040f858196f
Merged-In: Iadb7295b1fa0925dd8b966b48bca3040f858196f
(cherry picked from commit af289bed7ed17dc6138da62db42370b1c3472172)
diff --git a/common/src/main/java/org/conscrypt/ConscryptEngineSocket.java b/common/src/main/java/org/conscrypt/ConscryptEngineSocket.java
index 48c6b3d..8d96276 100644
--- a/common/src/main/java/org/conscrypt/ConscryptEngineSocket.java
+++ b/common/src/main/java/org/conscrypt/ConscryptEngineSocket.java
@@ -60,7 +60,7 @@
     private SSLOutputStream out;
     private SSLInputStream in;
 
-    private long handshakeStartedMillis;
+    private long handshakeStartedMillis = 0;
 
     private BufferAllocator bufferAllocator = ConscryptEngine.getDefaultBufferAllocator();
 
@@ -123,7 +123,7 @@
             @Override
             public void onHandshakeFinished() {
                 // Just call the outer class method.
-                socket.onHandshakeFinished();
+                socket.onEngineHandshakeFinished();
             }
         });
 
@@ -194,8 +194,7 @@
                 synchronized (stateLock) {
                     // Initialize the handshake if we haven't already.
                     if (state == STATE_NEW) {
-                        state = STATE_HANDSHAKE_STARTED;
-                        handshakeStartedMillis = Platform.getMillisSinceBoot();
+                        transitionTo(STATE_HANDSHAKE_STARTED);
                         engine.beginHandshake();
                         in = new SSLInputStream();
                         out = new SSLOutputStream();
@@ -208,7 +207,6 @@
                         return;
                     }
                 }
-
                 doHandshake();
             }
         } catch (SSLException e) {
@@ -232,6 +230,7 @@
                     case NEED_UNWRAP:
                         if (in.processDataFromSocket(EmptyArray.BYTE, 0, 0) < 0) {
                             // Can't complete the handshake due to EOF.
+                            close();
                             throw SSLUtils.toSSLHandshakeException(
                                     new EOFException("connection closed"));
                         }
@@ -244,15 +243,13 @@
                     }
                     case NEED_TASK: {
                         // Should never get here, since our engine never provides tasks.
+                        close();
                         throw new IllegalStateException("Engine tasks are unsupported");
                     }
                     case NOT_HANDSHAKING:
                     case FINISHED: {
                         // Handshake is complete.
                         finished = true;
-                        Platform.countTlsHandshake(true, engine.getSession().getProtocol(),
-                                engine.getSession().getCipherSuite(),
-                                Platform.getMillisSinceBoot() - handshakeStartedMillis);
                         break;
                     }
                     default: {
@@ -261,11 +258,15 @@
                     }
                 }
             }
+            if (isState(STATE_HANDSHAKE_COMPLETED)) {
+                // STATE_READY_HANDSHAKE_CUT_THROUGH will wake up any waiting threads which can
+                // race with the listeners, but that's OK.
+                transitionTo(STATE_READY_HANDSHAKE_CUT_THROUGH);
+                notifyHandshakeCompletedListeners();
+                transitionTo(STATE_READY);
+            }
         } catch (SSLException e) {
             drainOutgoingQueue();
-            Platform.countTlsHandshake(false, engine.getSession().getProtocol(),
-                    engine.getSession().getCipherSuite(),
-                    Platform.getMillisSinceBoot() - handshakeStartedMillis);
             close();
             throw e;
         } catch (IOException e) {
@@ -278,6 +279,64 @@
         }
     }
 
+    private boolean isState(int desiredState) {
+        synchronized (stateLock) {
+            return state == desiredState;
+        }
+    }
+
+    private int transitionTo(int newState) {
+        synchronized (stateLock) {
+            if (state == newState) {
+                return state;
+            }
+
+            int previousState = state;
+            boolean notify = false;
+            switch (newState) {
+                case STATE_HANDSHAKE_STARTED:
+                    handshakeStartedMillis = Platform.getMillisSinceBoot();
+                    break;
+
+                case STATE_READY_HANDSHAKE_CUT_THROUGH:
+                    if (handshakeStartedMillis > 0) {
+                        Platform.countTlsHandshake(true,
+                            engine.getSession().getProtocol(),
+                            engine.getSession().getCipherSuite(),
+                            Platform.getMillisSinceBoot() - handshakeStartedMillis);
+                        handshakeStartedMillis = 0;
+                    }
+                    notify = true;
+                    break;
+
+                case STATE_READY:
+                    notify = true;
+                    break;
+
+                case STATE_CLOSED:
+                    if (handshakeStartedMillis > 0) {
+                        // Handshake must have failed.
+                        Platform.countTlsHandshake(false,
+                            engine.getSession().getProtocol(),
+                            engine.getSession().getCipherSuite(),
+                            Platform.getMillisSinceBoot() - handshakeStartedMillis);
+                        handshakeStartedMillis = 0;
+                    }
+                    notify = true;
+                    break;
+
+                default:
+                    break;
+            }
+
+            state = newState;
+            if (notify) {
+                stateLock.notifyAll();
+            }
+            return previousState;
+        }
+    }
+
     @Override
     public final InputStream getInputStream() throws IOException {
         checkOpen();
@@ -441,24 +500,14 @@
         // TODO: Close SSL sockets using a background thread so they close gracefully.
 
         if (stateLock == null) {
-            // close() has been called before we've initialized the socket, so just
-            // return.
+            // Constructor failed, e.g. superclass constructor called close()
             return;
         }
 
-        int previousState;
-        synchronized (stateLock) {
-            previousState = state;
-            if (state == STATE_CLOSED) {
-                // close() has already been called, so do nothing and return.
-                return;
-            }
-
-            state = STATE_CLOSED;
-
-            stateLock.notifyAll();
+        int previousState = transitionTo(STATE_CLOSED);
+        if (previousState == STATE_CLOSED) {
+            return;
         }
-
         try {
             // Close the engine.
             engine.closeInbound();
@@ -527,25 +576,12 @@
         this.bufferAllocator = bufferAllocator;
     }
 
-    private void onHandshakeFinished() {
-        boolean notify = false;
-        synchronized (stateLock) {
-            if (state != STATE_CLOSED) {
-                if (state == STATE_HANDSHAKE_STARTED) {
-                    state = STATE_READY_HANDSHAKE_CUT_THROUGH;
-                } else if (state == STATE_HANDSHAKE_COMPLETED) {
-                    state = STATE_READY;
-                }
-
-                // Unblock threads that are waiting for our state to transition
-                // into STATE_READY or STATE_READY_HANDSHAKE_CUT_THROUGH.
-                stateLock.notifyAll();
-                notify = true;
-            }
-        }
-
-        if (notify) {
-            notifyHandshakeCompletedListeners();
+    private void onEngineHandshakeFinished() {
+        // Don't do anything here except change state.  This method will be called from
+        // e.g. wrap() which is non re-entrant so we can't call anything that might do
+        // IO until after it exits, e.g. in doHandshake().
+        if (isState(STATE_HANDSHAKE_STARTED)) {
+            transitionTo(STATE_HANDSHAKE_COMPLETED);
         }
     }
 
@@ -556,7 +592,9 @@
         startHandshake();
 
         synchronized (stateLock) {
-            while (state != STATE_READY && state != STATE_READY_HANDSHAKE_CUT_THROUGH
+            while (state != STATE_READY
+                    // Waiting threads are allowed to compete with handshake listeners for access.
+                    && state != STATE_READY_HANDSHAKE_CUT_THROUGH
                     && state != STATE_CLOSED) {
                 try {
                     stateLock.wait();
@@ -901,7 +939,7 @@
 
         private boolean isHandshakeFinished() {
             synchronized (stateLock) {
-                return state >= STATE_READY_HANDSHAKE_CUT_THROUGH;
+                return state > STATE_HANDSHAKE_STARTED;
             }
         }
 
diff --git a/common/src/test/java/org/conscrypt/javax/net/ssl/SSLSocketTest.java b/common/src/test/java/org/conscrypt/javax/net/ssl/SSLSocketTest.java
index d0d5dd7..36d0cb1 100644
--- a/common/src/test/java/org/conscrypt/javax/net/ssl/SSLSocketTest.java
+++ b/common/src/test/java/org/conscrypt/javax/net/ssl/SSLSocketTest.java
@@ -16,20 +16,22 @@
 
 package org.conscrypt.javax.net.ssl;
 
-import static org.conscrypt.TestUtils.UTF_8;
+import static java.nio.charset.StandardCharsets.UTF_8;
+import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertNotSame;
 import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
 
 import java.io.EOFException;
 import java.io.IOException;
 import java.io.InputStream;
 import java.net.ServerSocket;
 import java.net.Socket;
+import java.net.SocketException;
 import java.net.SocketTimeoutException;
 import java.security.KeyManagementException;
 import java.security.NoSuchAlgorithmException;
@@ -44,7 +46,6 @@
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
-import java.util.concurrent.ThreadFactory;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
 import javax.crypto.SecretKey;
@@ -71,7 +72,6 @@
 import org.conscrypt.tlswire.handshake.HelloExtension;
 import org.conscrypt.tlswire.util.TlsProtocolVersion;
 import org.junit.After;
-import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -81,24 +81,14 @@
 
 @RunWith(JUnit4.class)
 public class SSLSocketTest {
-    private ExecutorService executor;
-    private ThreadGroup threadGroup;
-
-    @Before
-    public void setup() {
-        threadGroup = new ThreadGroup("SSLSocketTest");
-        executor = Executors.newCachedThreadPool(new ThreadFactory() {
-            @Override
-            public Thread newThread(Runnable r) {
-                return new Thread(threadGroup, r);
-            }
-        });
-    }
+    private final ThreadGroup threadGroup = new ThreadGroup("SSLSocketTest");
+    private final ExecutorService executor =
+        Executors.newCachedThreadPool(t -> new Thread(threadGroup, t));
 
     @After
     public void teardown() throws InterruptedException {
         executor.shutdownNow();
-        executor.awaitTermination(5, TimeUnit.SECONDS);
+        assertTrue(executor.awaitTermination(5, TimeUnit.SECONDS));
     }
 
     @Test
@@ -110,8 +100,9 @@
     @Test
     public void test_SSLSocket_getSupportedCipherSuites_returnsCopies() throws Exception {
         SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
-        SSLSocket ssl = (SSLSocket) sf.createSocket();
-        assertNotSame(ssl.getSupportedCipherSuites(), ssl.getSupportedCipherSuites());
+        try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+            assertNotSame(ssl.getSupportedCipherSuites(), ssl.getSupportedCipherSuites());
+        }
     }
 
     @Test
@@ -131,7 +122,7 @@
     }
 
     private void test_SSLSocket_getSupportedCipherSuites_connect(
-            TestKeyStore testKeyStore, StringBuilder error) throws Exception {
+            TestKeyStore testKeyStore, StringBuilder error) {
         String clientToServerString = "this is sent from the client to the server...";
         String serverToClientString = "... and this from the server to the client";
         byte[] clientToServer = clientToServerString.getBytes(UTF_8);
@@ -207,21 +198,9 @@
                 // Check that the server and the client cannot read anything else
                 // (reads should time out)
                 server.setSoTimeout(10);
-                try {
-                    @SuppressWarnings("unused")
-                    int value = server.getInputStream().read();
-                    fail();
-                } catch (IOException expected) {
-                    // Ignored.
-                }
+                assertThrows(IOException.class, () -> server.getInputStream().read());
                 client.setSoTimeout(10);
-                try {
-                    @SuppressWarnings("unused")
-                    int value = client.getInputStream().read();
-                    fail();
-                } catch (IOException expected) {
-                    // Ignored.
-                }
+                assertThrows(IOException.class, () -> client.getInputStream().read());
                 client.close();
                 server.close();
             } catch (Exception maybeExpected) {
@@ -273,53 +252,44 @@
     @Test
     public void test_SSLSocket_getEnabledCipherSuites_returnsCopies() throws Exception {
         SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
-        SSLSocket ssl = (SSLSocket) sf.createSocket();
-        assertNotSame(ssl.getEnabledCipherSuites(), ssl.getEnabledCipherSuites());
+        try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+            assertNotSame(ssl.getEnabledCipherSuites(), ssl.getEnabledCipherSuites());
+        }
     }
 
     @Test
     public void test_SSLSocket_setEnabledCipherSuites_storesCopy() throws Exception {
         SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
-        SSLSocket ssl = (SSLSocket) sf.createSocket();
-        String[] array = new String[] {ssl.getEnabledCipherSuites()[0]};
-        String originalFirstElement = array[0];
-        ssl.setEnabledCipherSuites(array);
-        array[0] = "Modified after having been set";
-        assertEquals(originalFirstElement, ssl.getEnabledCipherSuites()[0]);
+        try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+            String[] array = new String[]{ssl.getEnabledCipherSuites()[0]};
+            String originalFirstElement = array[0];
+            ssl.setEnabledCipherSuites(array);
+            array[0] = "Modified after having been set";
+            assertEquals(originalFirstElement, ssl.getEnabledCipherSuites()[0]);
+        }
     }
 
     @Test
     public void test_SSLSocket_setEnabledCipherSuites_TLS12() throws Exception {
         SSLContext context = SSLContext.getInstance("TLSv1.2");
         context.init(null, null, null);
-        SSLSocket ssl = (SSLSocket) context.getSocketFactory().createSocket();
-        try {
-            ssl.setEnabledCipherSuites(null);
-            fail();
-        } catch (IllegalArgumentException expected) {
-            // Ignored.
+        try (SSLSocket ssl = (SSLSocket) context.getSocketFactory().createSocket()) {
+            assertThrows(IllegalArgumentException.class,
+                () -> ssl.setEnabledCipherSuites(null));
+            assertThrows(IllegalArgumentException.class,
+                () -> ssl.setEnabledCipherSuites(new String[1]));
+            assertThrows(IllegalArgumentException.class,
+                () -> ssl.setEnabledCipherSuites(new String[]{"Bogus"}));
+            ssl.setEnabledCipherSuites(new String[0]);
+            ssl.setEnabledCipherSuites(ssl.getEnabledCipherSuites());
+            ssl.setEnabledCipherSuites(ssl.getSupportedCipherSuites());
+            // Check that setEnabledCipherSuites affects getEnabledCipherSuites
+            String[] cipherSuites = new String[]{
+                    TestUtils.pickArbitraryNonTls13Suite(ssl.getSupportedCipherSuites())
+            };
+            ssl.setEnabledCipherSuites(cipherSuites);
+            assertEquals(Arrays.asList(cipherSuites), Arrays.asList(ssl.getEnabledCipherSuites()));
         }
-        try {
-            ssl.setEnabledCipherSuites(new String[1]);
-            fail();
-        } catch (IllegalArgumentException expected) {
-            // Ignored.
-        }
-        try {
-            ssl.setEnabledCipherSuites(new String[] {"Bogus"});
-            fail();
-        } catch (IllegalArgumentException expected) {
-            // Ignored.
-        }
-        ssl.setEnabledCipherSuites(new String[0]);
-        ssl.setEnabledCipherSuites(ssl.getEnabledCipherSuites());
-        ssl.setEnabledCipherSuites(ssl.getSupportedCipherSuites());
-        // Check that setEnabledCipherSuites affects getEnabledCipherSuites
-        String[] cipherSuites = new String[] {
-                TestUtils.pickArbitraryNonTls13Suite(ssl.getSupportedCipherSuites())
-        };
-        ssl.setEnabledCipherSuites(cipherSuites);
-        assertEquals(Arrays.asList(cipherSuites), Arrays.asList(ssl.getEnabledCipherSuites()));
     }
 
     @Test
@@ -327,91 +297,81 @@
         SSLContext context = SSLContext.getInstance("TLSv1.3");
         context.init(null, null, null);
         SSLSocketFactory sf = context.getSocketFactory();
-        SSLSocket ssl = (SSLSocket) sf.createSocket();
-        // The TLS 1.3 cipher suites should be enabled by default
-        assertTrue(new HashSet<String>(Arrays.asList(ssl.getEnabledCipherSuites()))
-                .containsAll(StandardNames.CIPHER_SUITES_TLS13));
-        // Disabling them should be ignored
-        ssl.setEnabledCipherSuites(new String[0]);
-        assertTrue(new HashSet<String>(Arrays.asList(ssl.getEnabledCipherSuites()))
-                .containsAll(StandardNames.CIPHER_SUITES_TLS13));
+        try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+            // The TLS 1.3 cipher suites should be enabled by default
+            assertTrue(new HashSet<>(Arrays.asList(ssl.getEnabledCipherSuites()))
+                    .containsAll(StandardNames.CIPHER_SUITES_TLS13));
+            // Disabling them should be ignored
+            ssl.setEnabledCipherSuites(new String[0]);
+            assertTrue(new HashSet<>(Arrays.asList(ssl.getEnabledCipherSuites()))
+                    .containsAll(StandardNames.CIPHER_SUITES_TLS13));
 
-        ssl.setEnabledCipherSuites(new String[] {
-                TestUtils.pickArbitraryNonTls13Suite(ssl.getSupportedCipherSuites())
-        });
-        assertTrue(new HashSet<String>(Arrays.asList(ssl.getEnabledCipherSuites()))
-                .containsAll(StandardNames.CIPHER_SUITES_TLS13));
+            ssl.setEnabledCipherSuites(new String[]{
+                    TestUtils.pickArbitraryNonTls13Suite(ssl.getSupportedCipherSuites())
+            });
+            assertTrue(new HashSet<>(Arrays.asList(ssl.getEnabledCipherSuites()))
+                    .containsAll(StandardNames.CIPHER_SUITES_TLS13));
 
-        // Disabling TLS 1.3 should disable 1.3 cipher suites
-        ssl.setEnabledProtocols(new String[] { "TLSv1.2" });
-        assertFalse(new HashSet<String>(Arrays.asList(ssl.getEnabledCipherSuites()))
-                .containsAll(StandardNames.CIPHER_SUITES_TLS13));
+            // Disabling TLS 1.3 should disable 1.3 cipher suites
+            ssl.setEnabledProtocols(new String[]{"TLSv1.2"});
+            assertFalse(new HashSet<>(Arrays.asList(ssl.getEnabledCipherSuites()))
+                    .containsAll(StandardNames.CIPHER_SUITES_TLS13));
+        }
     }
 
     @Test
     public void test_SSLSocket_getSupportedProtocols_returnsCopies() throws Exception {
         SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
-        SSLSocket ssl = (SSLSocket) sf.createSocket();
-        assertNotSame(ssl.getSupportedProtocols(), ssl.getSupportedProtocols());
+        try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+            assertNotSame(ssl.getSupportedProtocols(), ssl.getSupportedProtocols());
+        }
     }
 
     @Test
     public void test_SSLSocket_getEnabledProtocols_returnsCopies() throws Exception {
         SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
-        SSLSocket ssl = (SSLSocket) sf.createSocket();
-        assertNotSame(ssl.getEnabledProtocols(), ssl.getEnabledProtocols());
+        try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+            assertNotSame(ssl.getEnabledProtocols(), ssl.getEnabledProtocols());
+        }
     }
 
     @Test
     public void test_SSLSocket_setEnabledProtocols_storesCopy() throws Exception {
         SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
-        SSLSocket ssl = (SSLSocket) sf.createSocket();
-        String[] array = new String[] {ssl.getEnabledProtocols()[0]};
-        String originalFirstElement = array[0];
-        ssl.setEnabledProtocols(array);
-        array[0] = "Modified after having been set";
-        assertEquals(originalFirstElement, ssl.getEnabledProtocols()[0]);
+        try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+            String[] array = new String[]{ssl.getEnabledProtocols()[0]};
+            String originalFirstElement = array[0];
+            ssl.setEnabledProtocols(array);
+            array[0] = "Modified after having been set";
+            assertEquals(originalFirstElement, ssl.getEnabledProtocols()[0]);
+        }
     }
 
     @Test
     public void test_SSLSocket_setEnabledProtocols() throws Exception {
         SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
-        SSLSocket ssl = (SSLSocket) sf.createSocket();
-        try {
-            ssl.setEnabledProtocols(null);
-            fail();
-        } catch (IllegalArgumentException expected) {
-            // Ignored.
-        }
-        try {
-            ssl.setEnabledProtocols(new String[1]);
-            fail();
-        } catch (IllegalArgumentException expected) {
-            // Ignored.
-        }
-        try {
-            ssl.setEnabledProtocols(new String[] {"Bogus"});
-            fail();
-        } catch (IllegalArgumentException expected) {
-            // Ignored.
-        }
-        ssl.setEnabledProtocols(new String[0]);
-        ssl.setEnabledProtocols(ssl.getEnabledProtocols());
-        ssl.setEnabledProtocols(ssl.getSupportedProtocols());
-        // Check that setEnabledProtocols affects getEnabledProtocols
-        for (String protocol : ssl.getSupportedProtocols()) {
-            if ("SSLv2Hello".equals(protocol)) {
-                try {
-                    ssl.setEnabledProtocols(new String[] {protocol});
-                    fail("Should fail when SSLv2Hello is set by itself");
-                } catch (IllegalArgumentException expected) {
-                    // Ignored.
+        try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+            assertThrows(IllegalArgumentException.class,
+                () -> ssl.setEnabledProtocols(null));
+            assertThrows(IllegalArgumentException.class,
+                () -> ssl.setEnabledProtocols(new String[1]));
+            assertThrows(IllegalArgumentException.class,
+                () -> ssl.setEnabledProtocols(new String[]{"Bogus"}));
+            ssl.setEnabledProtocols(new String[0]);
+            ssl.setEnabledProtocols(ssl.getEnabledProtocols());
+            ssl.setEnabledProtocols(ssl.getSupportedProtocols());
+            // Check that setEnabledProtocols affects getEnabledProtocols
+            for (String protocol : ssl.getSupportedProtocols()) {
+                if ("SSLv2Hello".equals(protocol)) {
+                    // Should fail when SSLv2Hello is set by itself
+                    assertThrows(IllegalArgumentException.class,
+                        () -> ssl.setEnabledProtocols(new String[]{protocol}));
+                } else {
+                    String[] protocols = new String[]{protocol};
+                    ssl.setEnabledProtocols(protocols);
+                    assertEquals(Arrays.deepToString(protocols),
+                            Arrays.deepToString(ssl.getEnabledProtocols()));
                 }
-            } else {
-                String[] protocols = new String[] {protocol};
-                ssl.setEnabledProtocols(protocols);
-                assertEquals(Arrays.deepToString(protocols),
-                        Arrays.deepToString(ssl.getEnabledProtocols()));
             }
         }
     }
@@ -430,11 +390,9 @@
         final SSLSocket server = (SSLSocket) c.serverSocket.accept();
         server.setEnabledProtocols(new String[] {"TLSv1.3", "TLSv1.2", "TLSv1.1"});
         ExecutorService executor = Executors.newSingleThreadExecutor();
-        Future<Void> future = executor.submit(new Callable<Void>() {
-            @Override public Void call() throws Exception {
-                server.startHandshake();
-                return null;
-            }
+        Future<Void> future = executor.submit(() -> {
+            server.startHandshake();
+            return null;
         });
         executor.shutdown();
         client.startHandshake();
@@ -461,11 +419,9 @@
         final SSLSocket server = (SSLSocket) c.serverSocket.accept();
         server.setEnabledProtocols(new String[] {"TLSv1.2", "TLSv1.1"});
         ExecutorService executor = Executors.newSingleThreadExecutor();
-        Future<Void> future = executor.submit(new Callable<Void>() {
-            @Override public Void call() throws Exception {
-                server.startHandshake();
-                return null;
-            }
+        Future<Void> future = executor.submit(() -> {
+            server.startHandshake();
+            return null;
         });
         executor.shutdown();
         client.startHandshake();
@@ -481,18 +437,20 @@
     @Test
     public void test_SSLSocket_getSession() throws Exception {
         SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
-        SSLSocket ssl = (SSLSocket) sf.createSocket();
-        SSLSession session = ssl.getSession();
-        assertNotNull(session);
-        assertFalse(session.isValid());
+        try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+            SSLSession session = ssl.getSession();
+            assertNotNull(session);
+            assertFalse(session.isValid());
+        }
     }
 
     @Test
     public void test_SSLSocket_getHandshakeSession_unconnected() throws Exception {
         SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
-        SSLSocket socket = (SSLSocket) sf.createSocket();
-        SSLSession session = socket.getHandshakeSession();
-        assertNull(session);
+        try (SSLSocket socket = (SSLSocket) sf.createSocket()) {
+            SSLSession session = socket.getHandshakeSession();
+            assertNull(session);
+        }
     }
 
     @Test
@@ -570,11 +528,9 @@
             clientContext.getSocketFactory().createSocket(c.host, c.port);
         final SSLSocket server = (SSLSocket) c.serverSocket.accept();
         ExecutorService executor = Executors.newSingleThreadExecutor();
-        Future<Void> future = executor.submit(new Callable<Void>() {
-            @Override public Void call() throws Exception {
-                server.startHandshake();
-                return null;
-            }
+        Future<Void> future = executor.submit(() -> {
+            server.startHandshake();
+            return null;
         });
         executor.shutdown();
         client.startHandshake();
@@ -673,12 +629,10 @@
             clientContext.getSocketFactory().createSocket(c.host, c.port);
         final SSLSocket server = (SSLSocket) c.serverSocket.accept();
         ExecutorService executor = Executors.newSingleThreadExecutor();
-        Future<Void> future = executor.submit(new Callable<Void>() {
-            @Override public Void call() throws Exception {
-                server.setNeedClientAuth(true);
-                server.startHandshake();
-                return null;
-            }
+        Future<Void> future = executor.submit(() -> {
+            server.setNeedClientAuth(true);
+            server.startHandshake();
+            return null;
         });
         executor.shutdown();
         client.startHandshake();
@@ -691,21 +645,11 @@
     }
 
     @Test
-    public void test_SSLSocket_setUseClientMode_afterHandshake() throws Exception {
+    public void test_SSLSocket_setUseClientMode_afterHandshake() {
         // can't set after handshake
         TestSSLSocketPair pair = TestSSLSocketPair.create().connect();
-        try {
-            pair.server.setUseClientMode(false);
-            fail();
-        } catch (IllegalArgumentException expected) {
-            // Ignored.
-        }
-        try {
-            pair.client.setUseClientMode(false);
-            fail();
-        } catch (IllegalArgumentException expected) {
-            // Ignored.
-        }
+        assertThrows(IllegalArgumentException.class, () -> pair.server.setUseClientMode(true));
+        assertThrows(IllegalArgumentException.class, () -> pair.client.setUseClientMode(false));
     }
 
     @Test
@@ -715,24 +659,14 @@
         SSLSocket client =
                 (SSLSocket) c.clientContext.getSocketFactory().createSocket(c.host, c.port);
         final SSLSocket server = (SSLSocket) c.serverSocket.accept();
-        Future<Void> future = runAsync(new Callable<Void>() {
-            @Override
-            public Void call() throws Exception {
-                try {
-                    server.startHandshake();
-                    fail();
-                } catch (SSLHandshakeException expected) {
-                    // Ignored.
-                }
-                return null;
-            }
+        Future<Void> future = runAsync(() -> {
+            assertThrows(SSLHandshakeException.class, server::startHandshake);
+            return null;
         });
-        try {
-            client.startHandshake();
-            fail();
-        } catch (SSLHandshakeException expected) {
-            assertTrue(expected.getCause() instanceof CertificateException);
-        }
+        SSLHandshakeException expected =
+            assertThrows(SSLHandshakeException.class, client::startHandshake);
+        assertTrue(expected.getCause() instanceof CertificateException);
+
         future.get();
         client.close();
         server.close();
@@ -743,90 +677,93 @@
     public void test_SSLSocket_getSSLParameters() throws Exception {
         TestUtils.assumeSetEndpointIdentificationAlgorithmAvailable();
         SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
-        SSLSocket ssl = (SSLSocket) sf.createSocket();
-        SSLParameters p = ssl.getSSLParameters();
-        assertNotNull(p);
-        String[] cipherSuites = p.getCipherSuites();
-        assertNotSame(cipherSuites, ssl.getEnabledCipherSuites());
-        assertEquals(Arrays.asList(cipherSuites), Arrays.asList(ssl.getEnabledCipherSuites()));
-        String[] protocols = p.getProtocols();
-        assertNotSame(protocols, ssl.getEnabledProtocols());
-        assertEquals(Arrays.asList(protocols), Arrays.asList(ssl.getEnabledProtocols()));
-        assertEquals(p.getWantClientAuth(), ssl.getWantClientAuth());
-        assertEquals(p.getNeedClientAuth(), ssl.getNeedClientAuth());
-        assertNull(p.getEndpointIdentificationAlgorithm());
-        p.setEndpointIdentificationAlgorithm(null);
-        assertNull(p.getEndpointIdentificationAlgorithm());
-        p.setEndpointIdentificationAlgorithm("HTTPS");
-        assertEquals("HTTPS", p.getEndpointIdentificationAlgorithm());
-        p.setEndpointIdentificationAlgorithm("FOO");
-        assertEquals("FOO", p.getEndpointIdentificationAlgorithm());
+        try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+            SSLParameters p = ssl.getSSLParameters();
+            assertNotNull(p);
+            String[] cipherSuites = p.getCipherSuites();
+            assertNotSame(cipherSuites, ssl.getEnabledCipherSuites());
+            assertEquals(Arrays.asList(cipherSuites), Arrays.asList(ssl.getEnabledCipherSuites()));
+            String[] protocols = p.getProtocols();
+            assertNotSame(protocols, ssl.getEnabledProtocols());
+            assertEquals(Arrays.asList(protocols), Arrays.asList(ssl.getEnabledProtocols()));
+            assertEquals(p.getWantClientAuth(), ssl.getWantClientAuth());
+            assertEquals(p.getNeedClientAuth(), ssl.getNeedClientAuth());
+            assertNull(p.getEndpointIdentificationAlgorithm());
+            p.setEndpointIdentificationAlgorithm(null);
+            assertNull(p.getEndpointIdentificationAlgorithm());
+            p.setEndpointIdentificationAlgorithm("HTTPS");
+            assertEquals("HTTPS", p.getEndpointIdentificationAlgorithm());
+            p.setEndpointIdentificationAlgorithm("FOO");
+            assertEquals("FOO", p.getEndpointIdentificationAlgorithm());
+        }
     }
 
     @Test
     public void test_SSLSocket_setSSLParameters() throws Exception {
         SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
-        SSLSocket ssl = (SSLSocket) sf.createSocket();
-        String[] defaultCipherSuites = ssl.getEnabledCipherSuites();
-        String[] defaultProtocols = ssl.getEnabledProtocols();
-        String[] supportedCipherSuites = ssl.getSupportedCipherSuites();
-        String[] supportedProtocols = ssl.getSupportedProtocols();
-        {
-            SSLParameters p = new SSLParameters();
-            ssl.setSSLParameters(p);
-            assertEquals(Arrays.asList(defaultCipherSuites),
-                    Arrays.asList(ssl.getEnabledCipherSuites()));
-            assertEquals(Arrays.asList(defaultProtocols), Arrays.asList(ssl.getEnabledProtocols()));
-        }
-        {
-            SSLParameters p = new SSLParameters(supportedCipherSuites, supportedProtocols);
-            ssl.setSSLParameters(p);
-            assertEquals(Arrays.asList(supportedCipherSuites),
-                    Arrays.asList(ssl.getEnabledCipherSuites()));
-            assertEquals(
-                    Arrays.asList(supportedProtocols), Arrays.asList(ssl.getEnabledProtocols()));
-        }
-        {
-            SSLParameters p = new SSLParameters();
-            p.setNeedClientAuth(true);
-            assertFalse(ssl.getNeedClientAuth());
-            assertFalse(ssl.getWantClientAuth());
-            ssl.setSSLParameters(p);
-            assertTrue(ssl.getNeedClientAuth());
-            assertFalse(ssl.getWantClientAuth());
-            p.setWantClientAuth(true);
-            assertTrue(ssl.getNeedClientAuth());
-            assertFalse(ssl.getWantClientAuth());
-            ssl.setSSLParameters(p);
-            assertFalse(ssl.getNeedClientAuth());
-            assertTrue(ssl.getWantClientAuth());
-            p.setWantClientAuth(false);
-            assertFalse(ssl.getNeedClientAuth());
-            assertTrue(ssl.getWantClientAuth());
-            ssl.setSSLParameters(p);
-            assertFalse(ssl.getNeedClientAuth());
-            assertFalse(ssl.getWantClientAuth());
+        try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+            String[] defaultCipherSuites = ssl.getEnabledCipherSuites();
+            String[] defaultProtocols = ssl.getEnabledProtocols();
+            String[] supportedCipherSuites = ssl.getSupportedCipherSuites();
+            String[] supportedProtocols = ssl.getSupportedProtocols();
+            {
+                SSLParameters p = new SSLParameters();
+                ssl.setSSLParameters(p);
+                assertEquals(Arrays.asList(defaultCipherSuites),
+                        Arrays.asList(ssl.getEnabledCipherSuites()));
+                assertEquals(Arrays.asList(defaultProtocols), Arrays.asList(ssl.getEnabledProtocols()));
+            }
+            {
+                SSLParameters p = new SSLParameters(supportedCipherSuites, supportedProtocols);
+                ssl.setSSLParameters(p);
+                assertEquals(Arrays.asList(supportedCipherSuites),
+                        Arrays.asList(ssl.getEnabledCipherSuites()));
+                assertEquals(
+                        Arrays.asList(supportedProtocols), Arrays.asList(ssl.getEnabledProtocols()));
+            }
+            {
+                SSLParameters p = new SSLParameters();
+                p.setNeedClientAuth(true);
+                assertFalse(ssl.getNeedClientAuth());
+                assertFalse(ssl.getWantClientAuth());
+                ssl.setSSLParameters(p);
+                assertTrue(ssl.getNeedClientAuth());
+                assertFalse(ssl.getWantClientAuth());
+                p.setWantClientAuth(true);
+                assertTrue(ssl.getNeedClientAuth());
+                assertFalse(ssl.getWantClientAuth());
+                ssl.setSSLParameters(p);
+                assertFalse(ssl.getNeedClientAuth());
+                assertTrue(ssl.getWantClientAuth());
+                p.setWantClientAuth(false);
+                assertFalse(ssl.getNeedClientAuth());
+                assertTrue(ssl.getWantClientAuth());
+                ssl.setSSLParameters(p);
+                assertFalse(ssl.getNeedClientAuth());
+                assertFalse(ssl.getWantClientAuth());
+            }
         }
     }
 
     @Test
     public void test_SSLSocket_setSoTimeout_basic() throws Exception {
-        ServerSocket listening = new ServerSocket(0);
-        Socket underlying = new Socket(listening.getInetAddress(), listening.getLocalPort());
-        assertEquals(0, underlying.getSoTimeout());
-        SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
-        Socket wrapping = sf.createSocket(underlying, null, -1, false);
-        assertEquals(0, wrapping.getSoTimeout());
-        // setting wrapper sets underlying and ...
-        int expectedTimeoutMillis = 1000; // 10 was too small because it was affected by rounding
-        wrapping.setSoTimeout(expectedTimeoutMillis);
-        // The kernel can round the requested value based on the HZ setting. We allow up to 10ms.
-        assertTrue(Math.abs(expectedTimeoutMillis - wrapping.getSoTimeout()) <= 10);
-        assertTrue(Math.abs(expectedTimeoutMillis - underlying.getSoTimeout()) <= 10);
-        // ... getting wrapper inspects underlying
-        underlying.setSoTimeout(0);
-        assertEquals(0, wrapping.getSoTimeout());
-        assertEquals(0, underlying.getSoTimeout());
+        try (ServerSocket listening = new ServerSocket(0)) {
+            Socket underlying = new Socket(listening.getInetAddress(), listening.getLocalPort());
+            assertEquals(0, underlying.getSoTimeout());
+            SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
+            Socket wrapping = sf.createSocket(underlying, null, -1, false);
+            assertEquals(0, wrapping.getSoTimeout());
+            // setting wrapper sets underlying and ...
+            int expectedTimeoutMillis = 1000; // 10 was too small because it was affected by rounding
+            wrapping.setSoTimeout(expectedTimeoutMillis);
+            // The kernel can round the requested value based on the HZ setting. We allow up to 10ms.
+            assertTrue(Math.abs(expectedTimeoutMillis - wrapping.getSoTimeout()) <= 10);
+            assertTrue(Math.abs(expectedTimeoutMillis - underlying.getSoTimeout()) <= 10);
+            // ... getting wrapper inspects underlying
+            underlying.setSoTimeout(0);
+            assertEquals(0, wrapping.getSoTimeout());
+            assertEquals(0, underlying.getSoTimeout());
+        }
     }
 
     @Test
@@ -838,13 +775,7 @@
         SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
         Socket clientWrapping = sf.createSocket(underlying, null, -1, false);
         underlying.setSoTimeout(1);
-        try {
-            @SuppressWarnings("unused")
-            int value = clientWrapping.getInputStream().read();
-            fail();
-        } catch (SocketTimeoutException expected) {
-            // Ignored.
-        }
+        assertThrows(SocketTimeoutException.class, () -> clientWrapping.getInputStream().read());
         clientWrapping.close();
         server.close();
         underlying.close();
@@ -870,90 +801,81 @@
 
     @Test
     public void test_SSLSocket_ClientHello_cipherSuites() throws Exception {
-        ForEachRunner.runNamed(new ForEachRunner.Callback<SSLSocketFactory>() {
-            @Override
-            public void run(SSLSocketFactory sslSocketFactory) throws Exception {
-                ClientHello clientHello = TlsTester
-                        .captureTlsHandshakeClientHello(executor, sslSocketFactory);
-                final String[] cipherSuites;
-                // RFC 5746 allows you to send an empty "renegotiation_info" extension *or*
-                // a special signaling cipher suite. The TLS API has no way to check or
-                // indicate that a certain TLS extension should be used.
-                HelloExtension renegotiationInfoExtension =
-                    clientHello.findExtensionByType(HelloExtension.TYPE_RENEGOTIATION_INFO);
-                if (renegotiationInfoExtension != null
-                    && renegotiationInfoExtension.data.length == 1
-                    && renegotiationInfoExtension.data[0] == 0) {
-                    cipherSuites = new String[clientHello.cipherSuites.size() + 1];
-                    cipherSuites[clientHello.cipherSuites.size()] =
-                        StandardNames.CIPHER_SUITE_SECURE_RENEGOTIATION;
-                } else {
-                    cipherSuites = new String[clientHello.cipherSuites.size()];
-                }
-                for (int i = 0; i < clientHello.cipherSuites.size(); i++) {
-                    CipherSuite cipherSuite = clientHello.cipherSuites.get(i);
-                    cipherSuites[i] = cipherSuite.getAndroidName();
-                }
-                StandardNames.assertDefaultCipherSuites(cipherSuites);
+        ForEachRunner.runNamed(sslSocketFactory -> {
+            ClientHello clientHello = TlsTester
+                    .captureTlsHandshakeClientHello(executor, sslSocketFactory);
+            final String[] cipherSuites;
+            // RFC 5746 allows you to send an empty "renegotiation_info" extension *or*
+            // a special signaling cipher suite. The TLS API has no way to check or
+            // indicate that a certain TLS extension should be used.
+            HelloExtension renegotiationInfoExtension =
+                clientHello.findExtensionByType(HelloExtension.TYPE_RENEGOTIATION_INFO);
+            if (renegotiationInfoExtension != null
+                && renegotiationInfoExtension.data.length == 1
+                && renegotiationInfoExtension.data[0] == 0) {
+                cipherSuites = new String[clientHello.cipherSuites.size() + 1];
+                cipherSuites[clientHello.cipherSuites.size()] =
+                    StandardNames.CIPHER_SUITE_SECURE_RENEGOTIATION;
+            } else {
+                cipherSuites = new String[clientHello.cipherSuites.size()];
             }
-        }, getSSLSocketFactoriesToTest());
+            for (int i = 0; i < clientHello.cipherSuites.size(); i++) {
+                CipherSuite cipherSuite = clientHello.cipherSuites.get(i);
+                cipherSuites[i] = cipherSuite.getAndroidName();
+            }
+            StandardNames.assertDefaultCipherSuites(cipherSuites);
+        },
+            getSSLSocketFactoriesToTest());
     }
 
     @Test
     public void test_SSLSocket_ClientHello_supportedCurves() throws Exception {
-        ForEachRunner.runNamed(new ForEachRunner.Callback<SSLSocketFactory>() {
-            @Override
-            public void run(SSLSocketFactory sslSocketFactory) throws Exception {
-                ClientHello clientHello = TlsTester
-                        .captureTlsHandshakeClientHello(executor, sslSocketFactory);
-                EllipticCurvesHelloExtension ecExtension =
-                    (EllipticCurvesHelloExtension) clientHello.findExtensionByType(
-                        HelloExtension.TYPE_ELLIPTIC_CURVES);
-                final String[] supportedCurves;
-                if (ecExtension == null) {
-                    supportedCurves = new String[0];
-                } else {
-                    assertTrue(ecExtension.wellFormed);
-                    supportedCurves = new String[ecExtension.supported.size()];
-                    for (int i = 0; i < ecExtension.supported.size(); i++) {
-                        EllipticCurve curve = ecExtension.supported.get(i);
-                        supportedCurves[i] = curve.toString();
-                    }
+        ForEachRunner.runNamed(sslSocketFactory -> {
+            ClientHello clientHello = TlsTester
+                    .captureTlsHandshakeClientHello(executor, sslSocketFactory);
+            EllipticCurvesHelloExtension ecExtension =
+                (EllipticCurvesHelloExtension) clientHello.findExtensionByType(
+                    HelloExtension.TYPE_ELLIPTIC_CURVES);
+            final String[] supportedCurves;
+            if (ecExtension == null) {
+                supportedCurves = new String[0];
+            } else {
+                assertTrue(ecExtension.wellFormed);
+                supportedCurves = new String[ecExtension.supported.size()];
+                for (int i = 0; i < ecExtension.supported.size(); i++) {
+                    EllipticCurve curve = ecExtension.supported.get(i);
+                    supportedCurves[i] = curve.toString();
                 }
-                StandardNames.assertDefaultEllipticCurves(supportedCurves);
             }
-        }, getSSLSocketFactoriesToTest());
+            StandardNames.assertDefaultEllipticCurves(supportedCurves);
+        },
+            getSSLSocketFactoriesToTest());
     }
 
     @Test
     public void test_SSLSocket_ClientHello_clientProtocolVersion() throws Exception {
-        ForEachRunner.runNamed(new ForEachRunner.Callback<SSLSocketFactory>() {
-            @Override
-            public void run(SSLSocketFactory sslSocketFactory) throws Exception {
-                ClientHello clientHello = TlsTester
-                        .captureTlsHandshakeClientHello(executor, sslSocketFactory);
-                assertEquals(TlsProtocolVersion.TLSv1_2, clientHello.clientVersion);
-            }
-        }, getSSLSocketFactoriesToTest());
+        ForEachRunner.runNamed(sslSocketFactory -> {
+            ClientHello clientHello = TlsTester
+                    .captureTlsHandshakeClientHello(executor, sslSocketFactory);
+            assertEquals(TlsProtocolVersion.TLSv1_2, clientHello.clientVersion);
+        },
+            getSSLSocketFactoriesToTest());
     }
 
     @Test
     public void test_SSLSocket_ClientHello_compressionMethods() throws Exception {
-        ForEachRunner.runNamed(new ForEachRunner.Callback<SSLSocketFactory>() {
-            @Override
-            public void run(SSLSocketFactory sslSocketFactory) throws Exception {
-                ClientHello clientHello = TlsTester
-                        .captureTlsHandshakeClientHello(executor, sslSocketFactory);
-                assertEquals(Collections.singletonList(CompressionMethod.NULL),
-                    clientHello.compressionMethods);
-            }
-        }, getSSLSocketFactoriesToTest());
+        ForEachRunner.runNamed(sslSocketFactory -> {
+            ClientHello clientHello = TlsTester
+                    .captureTlsHandshakeClientHello(executor, sslSocketFactory);
+            assertEquals(Collections.singletonList(CompressionMethod.NULL),
+                clientHello.compressionMethods);
+        },
+            getSSLSocketFactoriesToTest());
     }
 
     private List<Pair<String, SSLSocketFactory>> getSSLSocketFactoriesToTest()
             throws NoSuchAlgorithmException, KeyManagementException {
-        List<Pair<String, SSLSocketFactory>> result =
-                new ArrayList<Pair<String, SSLSocketFactory>>();
+        List<Pair<String, SSLSocketFactory>> result = new ArrayList<>();
         result.add(Pair.of("default", (SSLSocketFactory) SSLSocketFactory.getDefault()));
         for (String sslContextProtocol : StandardNames.SSL_CONTEXT_PROTOCOLS_WITH_DEFAULT_CONFIG) {
             SSLContext sslContext = SSLContext.getInstance(sslContextProtocol);
@@ -977,23 +899,17 @@
         final String[] clientCipherSuites = new String[serverCipherSuites.length + 1];
         System.arraycopy(serverCipherSuites, 0, clientCipherSuites, 0, serverCipherSuites.length);
         clientCipherSuites[serverCipherSuites.length] = StandardNames.CIPHER_SUITE_FALLBACK;
-        Future<Void> s = runAsync(new Callable<Void>() {
-            @Override
-            public Void call() throws Exception {
-                server.setEnabledProtocols(new String[]{"TLSv1.2"});
-                server.setEnabledCipherSuites(serverCipherSuites);
-                server.startHandshake();
-                return null;
-            }
+        Future<Void> s = runAsync(() -> {
+            server.setEnabledProtocols(new String[]{"TLSv1.2"});
+            server.setEnabledCipherSuites(serverCipherSuites);
+            server.startHandshake();
+            return null;
         });
-        Future<Void> c = runAsync(new Callable<Void>() {
-            @Override
-            public Void call() throws Exception {
-                client.setEnabledProtocols(new String[]{"TLSv1.2"});
-                client.setEnabledCipherSuites(clientCipherSuites);
-                client.startHandshake();
-                return null;
-            }
+        Future<Void> c = runAsync(() -> {
+            client.setEnabledProtocols(new String[]{"TLSv1.2"});
+            client.setEnabledCipherSuites(clientCipherSuites);
+            client.startHandshake();
+            return null;
         });
         s.get();
         c.get();
@@ -1012,21 +928,15 @@
         // Confirm absence of TLS_FALLBACK_SCSV.
         assertFalse(Arrays.asList(client.getEnabledCipherSuites())
                             .contains(StandardNames.CIPHER_SUITE_FALLBACK));
-        Future<Void> s = runAsync(new Callable<Void>() {
-            @Override
-            public Void call() throws Exception {
-                server.setEnabledProtocols(new String[]{"TLSv1.2", "TLSv1.1"});
-                server.startHandshake();
-                return null;
-            }
+        Future<Void> s = runAsync(() -> {
+            server.setEnabledProtocols(new String[]{"TLSv1.2", "TLSv1.1"});
+            server.startHandshake();
+            return null;
         });
-        Future<Void> c = runAsync(new Callable<Void>() {
-            @Override
-            public Void call() throws Exception {
-                client.setEnabledProtocols(new String[]{"TLSv1.1"});
-                client.startHandshake();
-                return null;
-            }
+        Future<Void> c = runAsync(() -> {
+            client.setEnabledProtocols(new String[]{"TLSv1.1"});
+            client.startHandshake();
+            return null;
         });
         s.get();
         c.get();
@@ -1053,37 +963,25 @@
         final String[] clientCipherSuites = new String[serverCipherSuites.length + 1];
         System.arraycopy(serverCipherSuites, 0, clientCipherSuites, 0, serverCipherSuites.length);
         clientCipherSuites[serverCipherSuites.length] = StandardNames.CIPHER_SUITE_FALLBACK;
-        Future<Void> s = runAsync(new Callable<Void>() {
-            @Override
-            public Void call() throws Exception {
-                server.setEnabledProtocols(new String[] {"TLSv1.2", "TLSv1.1"});
-                server.setEnabledCipherSuites(serverCipherSuites);
-                try {
-                    server.startHandshake();
-                    fail("Should result in inappropriate fallback");
-                } catch (SSLHandshakeException expected) {
-                    Throwable cause = expected.getCause();
-                    assertEquals(SSLProtocolException.class, cause.getClass());
-                    assertInappropriateFallbackIsCause(cause);
-                }
-                return null;
-            }
+        Future<Void> s = runAsync(() -> {
+            server.setEnabledProtocols(new String[] {"TLSv1.2", "TLSv1.1"});
+            server.setEnabledCipherSuites(serverCipherSuites);
+            SSLHandshakeException expected =
+                assertThrows(SSLHandshakeException.class, server::startHandshake);
+            Throwable cause = expected.getCause();
+            assertEquals(SSLProtocolException.class, cause.getClass());
+            assertInappropriateFallbackIsCause(cause);
+            return null;
         });
-        Future<Void> c = runAsync(new Callable<Void>() {
-            @Override
-            public Void call() throws Exception {
-                client.setEnabledProtocols(new String[]{"TLSv1.1"});
-                client.setEnabledCipherSuites(clientCipherSuites);
-                try {
-                    client.startHandshake();
-                    fail("Should receive TLS alert inappropriate fallback");
-                } catch (SSLHandshakeException expected) {
-                    Throwable cause = expected.getCause();
-                    assertEquals(SSLProtocolException.class, cause.getClass());
-                    assertInappropriateFallbackIsCause(cause);
-                }
-                return null;
-            }
+        Future<Void> c = runAsync(() -> {
+            client.setEnabledProtocols(new String[]{"TLSv1.1"});
+            client.setEnabledCipherSuites(clientCipherSuites);
+            SSLHandshakeException expected =
+                assertThrows(SSLHandshakeException.class, client::startHandshake);
+            Throwable cause = expected.getCause();
+            assertEquals(SSLProtocolException.class, cause.getClass());
+            assertInappropriateFallbackIsCause(cause);
+            return null;
         });
         s.get();
         c.get();
@@ -1118,6 +1016,74 @@
         }
     }
 
+    @Test
+    public void handshakeListenersRunExactlyOnce() {
+        AtomicInteger count = new AtomicInteger(0);
+        TestSSLSocketPair pair = TestSSLSocketPair.create();
+        pair.client.addHandshakeCompletedListener(event -> count.addAndGet(1));
+        pair.client.addHandshakeCompletedListener(event -> count.addAndGet(2));
+        pair.client.addHandshakeCompletedListener(event -> count.addAndGet(4));
+        pair.connect();
+        assertEquals(1 + 2 + 4, count.get());
+    }
+
+    @Test
+    public void closeFromHandshakeListener() throws Exception {
+        TestUtils.assumeEngineSocket();
+
+        TestSSLSocketPair pair = TestSSLSocketPair.create();
+        pair.client.addHandshakeCompletedListener(event -> socketClose(pair.client));
+        Future<Void> serverFuture = runAsync((Callable<Void>) () -> {
+            pair.server.startHandshake();
+            return null;
+        });
+        pair.client.startHandshake();
+        assertThrows(SocketException.class, pair.client::getInputStream);
+        serverFuture.get();
+        InputStream istream = pair.server.getInputStream();
+        assertEquals(-1, istream.read());
+    }
+
+    @Test
+    public void writeFromHandshakeListener() throws Exception {
+        TestUtils.assumeEngineSocket();
+
+        byte[] ping = "ping".getBytes(UTF_8);
+        byte[] pong = "pong".getBytes(UTF_8);
+        TestSSLSocketPair pair = TestSSLSocketPair.create();
+        pair.client.addHandshakeCompletedListener(event -> socketWrite(pair.client, ping));
+        pair.server.addHandshakeCompletedListener(event -> socketWrite(pair.server, pong));
+        Future<Void> serverFuture = runAsync(() -> {
+            pair.server.startHandshake();
+            return null;
+        });
+        byte[] buffer = new byte[4];
+        InputStream clientStream = pair.client.getInputStream();
+        assertEquals(4, clientStream.read(buffer));
+        assertArrayEquals(pong, buffer);
+
+        serverFuture.get();
+        InputStream serverStream = pair.server.getInputStream();
+        assertEquals(4, serverStream.read(buffer));
+        assertArrayEquals(ping, buffer);
+    }
+
+    private void socketClose(Socket socket) {
+        try {
+            socket.close();
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    private void socketWrite(Socket socket, byte[] data) {
+        try {
+            socket.getOutputStream().write(data);
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
     private <T> Future<T> runAsync(Callable<T> callable) {
         return executor.submit(callable);
     }
@@ -1134,5 +1100,4 @@
             byteCount -= bytesRead;
         }
     }
-
 }
diff --git a/repackaged/common/src/main/java/com/android/org/conscrypt/ConscryptEngineSocket.java b/repackaged/common/src/main/java/com/android/org/conscrypt/ConscryptEngineSocket.java
index 69a341f..c641841 100644
--- a/repackaged/common/src/main/java/com/android/org/conscrypt/ConscryptEngineSocket.java
+++ b/repackaged/common/src/main/java/com/android/org/conscrypt/ConscryptEngineSocket.java
@@ -61,7 +61,7 @@
     private SSLOutputStream out;
     private SSLInputStream in;
 
-    private long handshakeStartedMillis;
+    private long handshakeStartedMillis = 0;
 
     private BufferAllocator bufferAllocator = ConscryptEngine.getDefaultBufferAllocator();
 
@@ -124,7 +124,7 @@
             @Override
             public void onHandshakeFinished() {
                 // Just call the outer class method.
-                socket.onHandshakeFinished();
+                socket.onEngineHandshakeFinished();
             }
         });
 
@@ -202,8 +202,7 @@
                 synchronized (stateLock) {
                     // Initialize the handshake if we haven't already.
                     if (state == STATE_NEW) {
-                        state = STATE_HANDSHAKE_STARTED;
-                        handshakeStartedMillis = Platform.getMillisSinceBoot();
+                        transitionTo(STATE_HANDSHAKE_STARTED);
                         engine.beginHandshake();
                         in = new SSLInputStream();
                         out = new SSLOutputStream();
@@ -216,7 +215,6 @@
                         return;
                     }
                 }
-
                 doHandshake();
             }
         } catch (SSLException e) {
@@ -240,6 +238,7 @@
                     case NEED_UNWRAP:
                         if (in.processDataFromSocket(EmptyArray.BYTE, 0, 0) < 0) {
                             // Can't complete the handshake due to EOF.
+                            close();
                             throw SSLUtils.toSSLHandshakeException(
                                     new EOFException("connection closed"));
                         }
@@ -252,15 +251,13 @@
                     }
                     case NEED_TASK: {
                         // Should never get here, since our engine never provides tasks.
+                        close();
                         throw new IllegalStateException("Engine tasks are unsupported");
                     }
                     case NOT_HANDSHAKING:
                     case FINISHED: {
                         // Handshake is complete.
                         finished = true;
-                        Platform.countTlsHandshake(true, engine.getSession().getProtocol(),
-                                engine.getSession().getCipherSuite(),
-                                Platform.getMillisSinceBoot() - handshakeStartedMillis);
                         break;
                     }
                     default: {
@@ -269,11 +266,15 @@
                     }
                 }
             }
+            if (isState(STATE_HANDSHAKE_COMPLETED)) {
+                // STATE_READY_HANDSHAKE_CUT_THROUGH will wake up any waiting threads which can
+                // race with the listeners, but that's OK.
+                transitionTo(STATE_READY_HANDSHAKE_CUT_THROUGH);
+                notifyHandshakeCompletedListeners();
+                transitionTo(STATE_READY);
+            }
         } catch (SSLException e) {
             drainOutgoingQueue();
-            Platform.countTlsHandshake(false, engine.getSession().getProtocol(),
-                    engine.getSession().getCipherSuite(),
-                    Platform.getMillisSinceBoot() - handshakeStartedMillis);
             close();
             throw e;
         } catch (IOException e) {
@@ -286,6 +287,62 @@
         }
     }
 
+    private boolean isState(int desiredState) {
+        synchronized (stateLock) {
+            return state == desiredState;
+        }
+    }
+
+    private int transitionTo(int newState) {
+        synchronized (stateLock) {
+            if (state == newState) {
+                return state;
+            }
+
+            int previousState = state;
+            boolean notify = false;
+            switch (newState) {
+                case STATE_HANDSHAKE_STARTED:
+                    handshakeStartedMillis = Platform.getMillisSinceBoot();
+                    break;
+
+                case STATE_READY_HANDSHAKE_CUT_THROUGH:
+                    if (handshakeStartedMillis > 0) {
+                        Platform.countTlsHandshake(true, engine.getSession().getProtocol(),
+                                engine.getSession().getCipherSuite(),
+                                Platform.getMillisSinceBoot() - handshakeStartedMillis);
+                        handshakeStartedMillis = 0;
+                    }
+                    notify = true;
+                    break;
+
+                case STATE_READY:
+                    notify = true;
+                    break;
+
+                case STATE_CLOSED:
+                    if (handshakeStartedMillis > 0) {
+                        // Handshake must have failed.
+                        Platform.countTlsHandshake(false, engine.getSession().getProtocol(),
+                                engine.getSession().getCipherSuite(),
+                                Platform.getMillisSinceBoot() - handshakeStartedMillis);
+                        handshakeStartedMillis = 0;
+                    }
+                    notify = true;
+                    break;
+
+                default:
+                    break;
+            }
+
+            state = newState;
+            if (notify) {
+                stateLock.notifyAll();
+            }
+            return previousState;
+        }
+    }
+
     @Override
     public final InputStream getInputStream() throws IOException {
         checkOpen();
@@ -457,24 +514,14 @@
         // TODO: Close SSL sockets using a background thread so they close gracefully.
 
         if (stateLock == null) {
-            // close() has been called before we've initialized the socket, so just
-            // return.
+            // Constructor failed, e.g. superclass constructor called close()
             return;
         }
 
-        int previousState;
-        synchronized (stateLock) {
-            previousState = state;
-            if (state == STATE_CLOSED) {
-                // close() has already been called, so do nothing and return.
-                return;
-            }
-
-            state = STATE_CLOSED;
-
-            stateLock.notifyAll();
+        int previousState = transitionTo(STATE_CLOSED);
+        if (previousState == STATE_CLOSED) {
+            return;
         }
-
         try {
             // Close the engine.
             engine.closeInbound();
@@ -543,25 +590,12 @@
         this.bufferAllocator = bufferAllocator;
     }
 
-    private void onHandshakeFinished() {
-        boolean notify = false;
-        synchronized (stateLock) {
-            if (state != STATE_CLOSED) {
-                if (state == STATE_HANDSHAKE_STARTED) {
-                    state = STATE_READY_HANDSHAKE_CUT_THROUGH;
-                } else if (state == STATE_HANDSHAKE_COMPLETED) {
-                    state = STATE_READY;
-                }
-
-                // Unblock threads that are waiting for our state to transition
-                // into STATE_READY or STATE_READY_HANDSHAKE_CUT_THROUGH.
-                stateLock.notifyAll();
-                notify = true;
-            }
-        }
-
-        if (notify) {
-            notifyHandshakeCompletedListeners();
+    private void onEngineHandshakeFinished() {
+        // Don't do anything here except change state.  This method will be called from
+        // e.g. wrap() which is non re-entrant so we can't call anything that might do
+        // IO until after it exits, e.g. in doHandshake().
+        if (isState(STATE_HANDSHAKE_STARTED)) {
+            transitionTo(STATE_HANDSHAKE_COMPLETED);
         }
     }
 
@@ -572,8 +606,9 @@
         startHandshake();
 
         synchronized (stateLock) {
-            while (state != STATE_READY && state != STATE_READY_HANDSHAKE_CUT_THROUGH
-                    && state != STATE_CLOSED) {
+            while (state != STATE_READY
+                    // Waiting threads are allowed to compete with handshake listeners for access.
+                    && state != STATE_READY_HANDSHAKE_CUT_THROUGH && state != STATE_CLOSED) {
                 try {
                     stateLock.wait();
                 } catch (InterruptedException e) {
@@ -917,7 +952,7 @@
 
         private boolean isHandshakeFinished() {
             synchronized (stateLock) {
-                return state >= STATE_READY_HANDSHAKE_CUT_THROUGH;
+                return state > STATE_HANDSHAKE_STARTED;
             }
         }
 
diff --git a/repackaged/common/src/test/java/com/android/org/conscrypt/javax/net/ssl/SSLSocketTest.java b/repackaged/common/src/test/java/com/android/org/conscrypt/javax/net/ssl/SSLSocketTest.java
index 80c8486..4a3f257 100644
--- a/repackaged/common/src/test/java/com/android/org/conscrypt/javax/net/ssl/SSLSocketTest.java
+++ b/repackaged/common/src/test/java/com/android/org/conscrypt/javax/net/ssl/SSLSocketTest.java
@@ -17,14 +17,15 @@
 
 package com.android.org.conscrypt.javax.net.ssl;
 
-import static com.android.org.conscrypt.TestUtils.UTF_8;
+import static java.nio.charset.StandardCharsets.UTF_8;
+import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertNotSame;
 import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
 
 import com.android.org.conscrypt.TestUtils;
 import com.android.org.conscrypt.java.security.StandardNames;
@@ -42,6 +43,7 @@
 import java.io.InputStream;
 import java.net.ServerSocket;
 import java.net.Socket;
+import java.net.SocketException;
 import java.net.SocketTimeoutException;
 import java.security.KeyManagementException;
 import java.security.NoSuchAlgorithmException;
@@ -56,7 +58,6 @@
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
-import java.util.concurrent.ThreadFactory;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
 import javax.crypto.SecretKey;
@@ -72,7 +73,6 @@
 import javax.net.ssl.SSLSocketFactory;
 import javax.net.ssl.X509ExtendedTrustManager;
 import org.junit.After;
-import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -85,24 +85,14 @@
  */
 @RunWith(JUnit4.class)
 public class SSLSocketTest {
-    private ExecutorService executor;
-    private ThreadGroup threadGroup;
-
-    @Before
-    public void setup() {
-        threadGroup = new ThreadGroup("SSLSocketTest");
-        executor = Executors.newCachedThreadPool(new ThreadFactory() {
-            @Override
-            public Thread newThread(Runnable r) {
-                return new Thread(threadGroup, r);
-            }
-        });
-    }
+    private final ThreadGroup threadGroup = new ThreadGroup("SSLSocketTest");
+    private final ExecutorService executor =
+            Executors.newCachedThreadPool(t -> new Thread(threadGroup, t));
 
     @After
     public void teardown() throws InterruptedException {
         executor.shutdownNow();
-        executor.awaitTermination(5, TimeUnit.SECONDS);
+        assertTrue(executor.awaitTermination(5, TimeUnit.SECONDS));
     }
 
     @Test
@@ -114,8 +104,9 @@
     @Test
     public void test_SSLSocket_getSupportedCipherSuites_returnsCopies() throws Exception {
         SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
-        SSLSocket ssl = (SSLSocket) sf.createSocket();
-        assertNotSame(ssl.getSupportedCipherSuites(), ssl.getSupportedCipherSuites());
+        try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+            assertNotSame(ssl.getSupportedCipherSuites(), ssl.getSupportedCipherSuites());
+        }
     }
 
     @Test
@@ -135,7 +126,7 @@
     }
 
     private void test_SSLSocket_getSupportedCipherSuites_connect(
-            TestKeyStore testKeyStore, StringBuilder error) throws Exception {
+            TestKeyStore testKeyStore, StringBuilder error) {
         String clientToServerString = "this is sent from the client to the server...";
         String serverToClientString = "... and this from the server to the client";
         byte[] clientToServer = clientToServerString.getBytes(UTF_8);
@@ -211,21 +202,9 @@
                 // Check that the server and the client cannot read anything else
                 // (reads should time out)
                 server.setSoTimeout(10);
-                try {
-                    @SuppressWarnings("unused")
-                    int value = server.getInputStream().read();
-                    fail();
-                } catch (IOException expected) {
-                    // Ignored.
-                }
+                assertThrows(IOException.class, () -> server.getInputStream().read());
                 client.setSoTimeout(10);
-                try {
-                    @SuppressWarnings("unused")
-                    int value = client.getInputStream().read();
-                    fail();
-                } catch (IOException expected) {
-                    // Ignored.
-                }
+                assertThrows(IOException.class, () -> client.getInputStream().read());
                 client.close();
                 server.close();
             } catch (Exception maybeExpected) {
@@ -277,53 +256,42 @@
     @Test
     public void test_SSLSocket_getEnabledCipherSuites_returnsCopies() throws Exception {
         SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
-        SSLSocket ssl = (SSLSocket) sf.createSocket();
-        assertNotSame(ssl.getEnabledCipherSuites(), ssl.getEnabledCipherSuites());
+        try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+            assertNotSame(ssl.getEnabledCipherSuites(), ssl.getEnabledCipherSuites());
+        }
     }
 
     @Test
     public void test_SSLSocket_setEnabledCipherSuites_storesCopy() throws Exception {
         SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
-        SSLSocket ssl = (SSLSocket) sf.createSocket();
-        String[] array = new String[] {ssl.getEnabledCipherSuites()[0]};
-        String originalFirstElement = array[0];
-        ssl.setEnabledCipherSuites(array);
-        array[0] = "Modified after having been set";
-        assertEquals(originalFirstElement, ssl.getEnabledCipherSuites()[0]);
+        try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+            String[] array = new String[] {ssl.getEnabledCipherSuites()[0]};
+            String originalFirstElement = array[0];
+            ssl.setEnabledCipherSuites(array);
+            array[0] = "Modified after having been set";
+            assertEquals(originalFirstElement, ssl.getEnabledCipherSuites()[0]);
+        }
     }
 
     @Test
     public void test_SSLSocket_setEnabledCipherSuites_TLS12() throws Exception {
         SSLContext context = SSLContext.getInstance("TLSv1.2");
         context.init(null, null, null);
-        SSLSocket ssl = (SSLSocket) context.getSocketFactory().createSocket();
-        try {
-            ssl.setEnabledCipherSuites(null);
-            fail();
-        } catch (IllegalArgumentException expected) {
-            // Ignored.
+        try (SSLSocket ssl = (SSLSocket) context.getSocketFactory().createSocket()) {
+            assertThrows(IllegalArgumentException.class, () -> ssl.setEnabledCipherSuites(null));
+            assertThrows(IllegalArgumentException.class,
+                    () -> ssl.setEnabledCipherSuites(new String[1]));
+            assertThrows(IllegalArgumentException.class,
+                    () -> ssl.setEnabledCipherSuites(new String[] {"Bogus"}));
+            ssl.setEnabledCipherSuites(new String[0]);
+            ssl.setEnabledCipherSuites(ssl.getEnabledCipherSuites());
+            ssl.setEnabledCipherSuites(ssl.getSupportedCipherSuites());
+            // Check that setEnabledCipherSuites affects getEnabledCipherSuites
+            String[] cipherSuites = new String[] {
+                    TestUtils.pickArbitraryNonTls13Suite(ssl.getSupportedCipherSuites())};
+            ssl.setEnabledCipherSuites(cipherSuites);
+            assertEquals(Arrays.asList(cipherSuites), Arrays.asList(ssl.getEnabledCipherSuites()));
         }
-        try {
-            ssl.setEnabledCipherSuites(new String[1]);
-            fail();
-        } catch (IllegalArgumentException expected) {
-            // Ignored.
-        }
-        try {
-            ssl.setEnabledCipherSuites(new String[] {"Bogus"});
-            fail();
-        } catch (IllegalArgumentException expected) {
-            // Ignored.
-        }
-        ssl.setEnabledCipherSuites(new String[0]);
-        ssl.setEnabledCipherSuites(ssl.getEnabledCipherSuites());
-        ssl.setEnabledCipherSuites(ssl.getSupportedCipherSuites());
-        // Check that setEnabledCipherSuites affects getEnabledCipherSuites
-        String[] cipherSuites = new String[] {
-                TestUtils.pickArbitraryNonTls13Suite(ssl.getSupportedCipherSuites())
-        };
-        ssl.setEnabledCipherSuites(cipherSuites);
-        assertEquals(Arrays.asList(cipherSuites), Arrays.asList(ssl.getEnabledCipherSuites()));
     }
 
     @Test
@@ -331,91 +299,79 @@
         SSLContext context = SSLContext.getInstance("TLSv1.3");
         context.init(null, null, null);
         SSLSocketFactory sf = context.getSocketFactory();
-        SSLSocket ssl = (SSLSocket) sf.createSocket();
-        // The TLS 1.3 cipher suites should be enabled by default
-        assertTrue(new HashSet<String>(Arrays.asList(ssl.getEnabledCipherSuites()))
-                .containsAll(StandardNames.CIPHER_SUITES_TLS13));
-        // Disabling them should be ignored
-        ssl.setEnabledCipherSuites(new String[0]);
-        assertTrue(new HashSet<String>(Arrays.asList(ssl.getEnabledCipherSuites()))
-                .containsAll(StandardNames.CIPHER_SUITES_TLS13));
+        try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+            // The TLS 1.3 cipher suites should be enabled by default
+            assertTrue(new HashSet<>(Arrays.asList(ssl.getEnabledCipherSuites()))
+                               .containsAll(StandardNames.CIPHER_SUITES_TLS13));
+            // Disabling them should be ignored
+            ssl.setEnabledCipherSuites(new String[0]);
+            assertTrue(new HashSet<>(Arrays.asList(ssl.getEnabledCipherSuites()))
+                               .containsAll(StandardNames.CIPHER_SUITES_TLS13));
 
-        ssl.setEnabledCipherSuites(new String[] {
-                TestUtils.pickArbitraryNonTls13Suite(ssl.getSupportedCipherSuites())
-        });
-        assertTrue(new HashSet<String>(Arrays.asList(ssl.getEnabledCipherSuites()))
-                .containsAll(StandardNames.CIPHER_SUITES_TLS13));
+            ssl.setEnabledCipherSuites(new String[] {
+                    TestUtils.pickArbitraryNonTls13Suite(ssl.getSupportedCipherSuites())});
+            assertTrue(new HashSet<>(Arrays.asList(ssl.getEnabledCipherSuites()))
+                               .containsAll(StandardNames.CIPHER_SUITES_TLS13));
 
-        // Disabling TLS 1.3 should disable 1.3 cipher suites
-        ssl.setEnabledProtocols(new String[] { "TLSv1.2" });
-        assertFalse(new HashSet<String>(Arrays.asList(ssl.getEnabledCipherSuites()))
-                .containsAll(StandardNames.CIPHER_SUITES_TLS13));
+            // Disabling TLS 1.3 should disable 1.3 cipher suites
+            ssl.setEnabledProtocols(new String[] {"TLSv1.2"});
+            assertFalse(new HashSet<>(Arrays.asList(ssl.getEnabledCipherSuites()))
+                                .containsAll(StandardNames.CIPHER_SUITES_TLS13));
+        }
     }
 
     @Test
     public void test_SSLSocket_getSupportedProtocols_returnsCopies() throws Exception {
         SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
-        SSLSocket ssl = (SSLSocket) sf.createSocket();
-        assertNotSame(ssl.getSupportedProtocols(), ssl.getSupportedProtocols());
+        try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+            assertNotSame(ssl.getSupportedProtocols(), ssl.getSupportedProtocols());
+        }
     }
 
     @Test
     public void test_SSLSocket_getEnabledProtocols_returnsCopies() throws Exception {
         SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
-        SSLSocket ssl = (SSLSocket) sf.createSocket();
-        assertNotSame(ssl.getEnabledProtocols(), ssl.getEnabledProtocols());
+        try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+            assertNotSame(ssl.getEnabledProtocols(), ssl.getEnabledProtocols());
+        }
     }
 
     @Test
     public void test_SSLSocket_setEnabledProtocols_storesCopy() throws Exception {
         SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
-        SSLSocket ssl = (SSLSocket) sf.createSocket();
-        String[] array = new String[] {ssl.getEnabledProtocols()[0]};
-        String originalFirstElement = array[0];
-        ssl.setEnabledProtocols(array);
-        array[0] = "Modified after having been set";
-        assertEquals(originalFirstElement, ssl.getEnabledProtocols()[0]);
+        try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+            String[] array = new String[] {ssl.getEnabledProtocols()[0]};
+            String originalFirstElement = array[0];
+            ssl.setEnabledProtocols(array);
+            array[0] = "Modified after having been set";
+            assertEquals(originalFirstElement, ssl.getEnabledProtocols()[0]);
+        }
     }
 
     @Test
     public void test_SSLSocket_setEnabledProtocols() throws Exception {
         SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
-        SSLSocket ssl = (SSLSocket) sf.createSocket();
-        try {
-            ssl.setEnabledProtocols(null);
-            fail();
-        } catch (IllegalArgumentException expected) {
-            // Ignored.
-        }
-        try {
-            ssl.setEnabledProtocols(new String[1]);
-            fail();
-        } catch (IllegalArgumentException expected) {
-            // Ignored.
-        }
-        try {
-            ssl.setEnabledProtocols(new String[] {"Bogus"});
-            fail();
-        } catch (IllegalArgumentException expected) {
-            // Ignored.
-        }
-        ssl.setEnabledProtocols(new String[0]);
-        ssl.setEnabledProtocols(ssl.getEnabledProtocols());
-        ssl.setEnabledProtocols(ssl.getSupportedProtocols());
-        // Check that setEnabledProtocols affects getEnabledProtocols
-        for (String protocol : ssl.getSupportedProtocols()) {
-            if ("SSLv2Hello".equals(protocol)) {
-                try {
-                    ssl.setEnabledProtocols(new String[] {protocol});
-                    fail("Should fail when SSLv2Hello is set by itself");
-                } catch (IllegalArgumentException expected) {
-                    // Ignored.
+        try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+            assertThrows(IllegalArgumentException.class, () -> ssl.setEnabledProtocols(null));
+            assertThrows(
+                    IllegalArgumentException.class, () -> ssl.setEnabledProtocols(new String[1]));
+            assertThrows(IllegalArgumentException.class,
+                    () -> ssl.setEnabledProtocols(new String[] {"Bogus"}));
+            ssl.setEnabledProtocols(new String[0]);
+            ssl.setEnabledProtocols(ssl.getEnabledProtocols());
+            ssl.setEnabledProtocols(ssl.getSupportedProtocols());
+            // Check that setEnabledProtocols affects getEnabledProtocols
+            for (String protocol : ssl.getSupportedProtocols()) {
+                if ("SSLv2Hello".equals(protocol)) {
+                    // Should fail when SSLv2Hello is set by itself
+                    assertThrows(IllegalArgumentException.class,
+                            () -> ssl.setEnabledProtocols(new String[] {protocol}));
+                } else {
+                    String[] protocols = new String[] {protocol};
+                    ssl.setEnabledProtocols(protocols);
+                    assertEquals(Arrays.deepToString(protocols),
+                            Arrays.deepToString(ssl.getEnabledProtocols()));
                 }
-            } else {
-                String[] protocols = new String[] {protocol};
-                ssl.setEnabledProtocols(protocols);
-                assertEquals(Arrays.deepToString(protocols),
-                        Arrays.deepToString(ssl.getEnabledProtocols()));
             }
         }
     }
@@ -434,11 +390,9 @@
         final SSLSocket server = (SSLSocket) c.serverSocket.accept();
         server.setEnabledProtocols(new String[] {"TLSv1.3", "TLSv1.2", "TLSv1.1"});
         ExecutorService executor = Executors.newSingleThreadExecutor();
-        Future<Void> future = executor.submit(new Callable<Void>() {
-            @Override public Void call() throws Exception {
-                server.startHandshake();
-                return null;
-            }
+        Future<Void> future = executor.submit(() -> {
+            server.startHandshake();
+            return null;
         });
         executor.shutdown();
         client.startHandshake();
@@ -465,11 +419,9 @@
         final SSLSocket server = (SSLSocket) c.serverSocket.accept();
         server.setEnabledProtocols(new String[] {"TLSv1.2", "TLSv1.1"});
         ExecutorService executor = Executors.newSingleThreadExecutor();
-        Future<Void> future = executor.submit(new Callable<Void>() {
-            @Override public Void call() throws Exception {
-                server.startHandshake();
-                return null;
-            }
+        Future<Void> future = executor.submit(() -> {
+            server.startHandshake();
+            return null;
         });
         executor.shutdown();
         client.startHandshake();
@@ -485,18 +437,20 @@
     @Test
     public void test_SSLSocket_getSession() throws Exception {
         SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
-        SSLSocket ssl = (SSLSocket) sf.createSocket();
-        SSLSession session = ssl.getSession();
-        assertNotNull(session);
-        assertFalse(session.isValid());
+        try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+            SSLSession session = ssl.getSession();
+            assertNotNull(session);
+            assertFalse(session.isValid());
+        }
     }
 
     @Test
     public void test_SSLSocket_getHandshakeSession_unconnected() throws Exception {
         SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
-        SSLSocket socket = (SSLSocket) sf.createSocket();
-        SSLSession session = socket.getHandshakeSession();
-        assertNull(session);
+        try (SSLSocket socket = (SSLSocket) sf.createSocket()) {
+            SSLSession session = socket.getHandshakeSession();
+            assertNull(session);
+        }
     }
 
     @Test
@@ -574,11 +528,9 @@
             clientContext.getSocketFactory().createSocket(c.host, c.port);
         final SSLSocket server = (SSLSocket) c.serverSocket.accept();
         ExecutorService executor = Executors.newSingleThreadExecutor();
-        Future<Void> future = executor.submit(new Callable<Void>() {
-            @Override public Void call() throws Exception {
-                server.startHandshake();
-                return null;
-            }
+        Future<Void> future = executor.submit(() -> {
+            server.startHandshake();
+            return null;
         });
         executor.shutdown();
         client.startHandshake();
@@ -677,12 +629,10 @@
             clientContext.getSocketFactory().createSocket(c.host, c.port);
         final SSLSocket server = (SSLSocket) c.serverSocket.accept();
         ExecutorService executor = Executors.newSingleThreadExecutor();
-        Future<Void> future = executor.submit(new Callable<Void>() {
-            @Override public Void call() throws Exception {
-                server.setNeedClientAuth(true);
-                server.startHandshake();
-                return null;
-            }
+        Future<Void> future = executor.submit(() -> {
+            server.setNeedClientAuth(true);
+            server.startHandshake();
+            return null;
         });
         executor.shutdown();
         client.startHandshake();
@@ -695,21 +645,11 @@
     }
 
     @Test
-    public void test_SSLSocket_setUseClientMode_afterHandshake() throws Exception {
+    public void test_SSLSocket_setUseClientMode_afterHandshake() {
         // can't set after handshake
         TestSSLSocketPair pair = TestSSLSocketPair.create().connect();
-        try {
-            pair.server.setUseClientMode(false);
-            fail();
-        } catch (IllegalArgumentException expected) {
-            // Ignored.
-        }
-        try {
-            pair.client.setUseClientMode(false);
-            fail();
-        } catch (IllegalArgumentException expected) {
-            // Ignored.
-        }
+        assertThrows(IllegalArgumentException.class, () -> pair.server.setUseClientMode(true));
+        assertThrows(IllegalArgumentException.class, () -> pair.client.setUseClientMode(false));
     }
 
     @Test
@@ -719,24 +659,14 @@
         SSLSocket client =
                 (SSLSocket) c.clientContext.getSocketFactory().createSocket(c.host, c.port);
         final SSLSocket server = (SSLSocket) c.serverSocket.accept();
-        Future<Void> future = runAsync(new Callable<Void>() {
-            @Override
-            public Void call() throws Exception {
-                try {
-                    server.startHandshake();
-                    fail();
-                } catch (SSLHandshakeException expected) {
-                    // Ignored.
-                }
-                return null;
-            }
+        Future<Void> future = runAsync(() -> {
+            assertThrows(SSLHandshakeException.class, server::startHandshake);
+            return null;
         });
-        try {
-            client.startHandshake();
-            fail();
-        } catch (SSLHandshakeException expected) {
-            assertTrue(expected.getCause() instanceof CertificateException);
-        }
+        SSLHandshakeException expected =
+                assertThrows(SSLHandshakeException.class, client::startHandshake);
+        assertTrue(expected.getCause() instanceof CertificateException);
+
         future.get();
         client.close();
         server.close();
@@ -747,90 +677,96 @@
     public void test_SSLSocket_getSSLParameters() throws Exception {
         TestUtils.assumeSetEndpointIdentificationAlgorithmAvailable();
         SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
-        SSLSocket ssl = (SSLSocket) sf.createSocket();
-        SSLParameters p = ssl.getSSLParameters();
-        assertNotNull(p);
-        String[] cipherSuites = p.getCipherSuites();
-        assertNotSame(cipherSuites, ssl.getEnabledCipherSuites());
-        assertEquals(Arrays.asList(cipherSuites), Arrays.asList(ssl.getEnabledCipherSuites()));
-        String[] protocols = p.getProtocols();
-        assertNotSame(protocols, ssl.getEnabledProtocols());
-        assertEquals(Arrays.asList(protocols), Arrays.asList(ssl.getEnabledProtocols()));
-        assertEquals(p.getWantClientAuth(), ssl.getWantClientAuth());
-        assertEquals(p.getNeedClientAuth(), ssl.getNeedClientAuth());
-        assertNull(p.getEndpointIdentificationAlgorithm());
-        p.setEndpointIdentificationAlgorithm(null);
-        assertNull(p.getEndpointIdentificationAlgorithm());
-        p.setEndpointIdentificationAlgorithm("HTTPS");
-        assertEquals("HTTPS", p.getEndpointIdentificationAlgorithm());
-        p.setEndpointIdentificationAlgorithm("FOO");
-        assertEquals("FOO", p.getEndpointIdentificationAlgorithm());
+        try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+            SSLParameters p = ssl.getSSLParameters();
+            assertNotNull(p);
+            String[] cipherSuites = p.getCipherSuites();
+            assertNotSame(cipherSuites, ssl.getEnabledCipherSuites());
+            assertEquals(Arrays.asList(cipherSuites), Arrays.asList(ssl.getEnabledCipherSuites()));
+            String[] protocols = p.getProtocols();
+            assertNotSame(protocols, ssl.getEnabledProtocols());
+            assertEquals(Arrays.asList(protocols), Arrays.asList(ssl.getEnabledProtocols()));
+            assertEquals(p.getWantClientAuth(), ssl.getWantClientAuth());
+            assertEquals(p.getNeedClientAuth(), ssl.getNeedClientAuth());
+            assertNull(p.getEndpointIdentificationAlgorithm());
+            p.setEndpointIdentificationAlgorithm(null);
+            assertNull(p.getEndpointIdentificationAlgorithm());
+            p.setEndpointIdentificationAlgorithm("HTTPS");
+            assertEquals("HTTPS", p.getEndpointIdentificationAlgorithm());
+            p.setEndpointIdentificationAlgorithm("FOO");
+            assertEquals("FOO", p.getEndpointIdentificationAlgorithm());
+        }
     }
 
     @Test
     public void test_SSLSocket_setSSLParameters() throws Exception {
         SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
-        SSLSocket ssl = (SSLSocket) sf.createSocket();
-        String[] defaultCipherSuites = ssl.getEnabledCipherSuites();
-        String[] defaultProtocols = ssl.getEnabledProtocols();
-        String[] supportedCipherSuites = ssl.getSupportedCipherSuites();
-        String[] supportedProtocols = ssl.getSupportedProtocols();
-        {
-            SSLParameters p = new SSLParameters();
-            ssl.setSSLParameters(p);
-            assertEquals(Arrays.asList(defaultCipherSuites),
-                    Arrays.asList(ssl.getEnabledCipherSuites()));
-            assertEquals(Arrays.asList(defaultProtocols), Arrays.asList(ssl.getEnabledProtocols()));
-        }
-        {
-            SSLParameters p = new SSLParameters(supportedCipherSuites, supportedProtocols);
-            ssl.setSSLParameters(p);
-            assertEquals(Arrays.asList(supportedCipherSuites),
-                    Arrays.asList(ssl.getEnabledCipherSuites()));
-            assertEquals(
-                    Arrays.asList(supportedProtocols), Arrays.asList(ssl.getEnabledProtocols()));
-        }
-        {
-            SSLParameters p = new SSLParameters();
-            p.setNeedClientAuth(true);
-            assertFalse(ssl.getNeedClientAuth());
-            assertFalse(ssl.getWantClientAuth());
-            ssl.setSSLParameters(p);
-            assertTrue(ssl.getNeedClientAuth());
-            assertFalse(ssl.getWantClientAuth());
-            p.setWantClientAuth(true);
-            assertTrue(ssl.getNeedClientAuth());
-            assertFalse(ssl.getWantClientAuth());
-            ssl.setSSLParameters(p);
-            assertFalse(ssl.getNeedClientAuth());
-            assertTrue(ssl.getWantClientAuth());
-            p.setWantClientAuth(false);
-            assertFalse(ssl.getNeedClientAuth());
-            assertTrue(ssl.getWantClientAuth());
-            ssl.setSSLParameters(p);
-            assertFalse(ssl.getNeedClientAuth());
-            assertFalse(ssl.getWantClientAuth());
+        try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+            String[] defaultCipherSuites = ssl.getEnabledCipherSuites();
+            String[] defaultProtocols = ssl.getEnabledProtocols();
+            String[] supportedCipherSuites = ssl.getSupportedCipherSuites();
+            String[] supportedProtocols = ssl.getSupportedProtocols();
+            {
+                SSLParameters p = new SSLParameters();
+                ssl.setSSLParameters(p);
+                assertEquals(Arrays.asList(defaultCipherSuites),
+                        Arrays.asList(ssl.getEnabledCipherSuites()));
+                assertEquals(
+                        Arrays.asList(defaultProtocols), Arrays.asList(ssl.getEnabledProtocols()));
+            }
+            {
+                SSLParameters p = new SSLParameters(supportedCipherSuites, supportedProtocols);
+                ssl.setSSLParameters(p);
+                assertEquals(Arrays.asList(supportedCipherSuites),
+                        Arrays.asList(ssl.getEnabledCipherSuites()));
+                assertEquals(Arrays.asList(supportedProtocols),
+                        Arrays.asList(ssl.getEnabledProtocols()));
+            }
+            {
+                SSLParameters p = new SSLParameters();
+                p.setNeedClientAuth(true);
+                assertFalse(ssl.getNeedClientAuth());
+                assertFalse(ssl.getWantClientAuth());
+                ssl.setSSLParameters(p);
+                assertTrue(ssl.getNeedClientAuth());
+                assertFalse(ssl.getWantClientAuth());
+                p.setWantClientAuth(true);
+                assertTrue(ssl.getNeedClientAuth());
+                assertFalse(ssl.getWantClientAuth());
+                ssl.setSSLParameters(p);
+                assertFalse(ssl.getNeedClientAuth());
+                assertTrue(ssl.getWantClientAuth());
+                p.setWantClientAuth(false);
+                assertFalse(ssl.getNeedClientAuth());
+                assertTrue(ssl.getWantClientAuth());
+                ssl.setSSLParameters(p);
+                assertFalse(ssl.getNeedClientAuth());
+                assertFalse(ssl.getWantClientAuth());
+            }
         }
     }
 
     @Test
     public void test_SSLSocket_setSoTimeout_basic() throws Exception {
-        ServerSocket listening = new ServerSocket(0);
-        Socket underlying = new Socket(listening.getInetAddress(), listening.getLocalPort());
-        assertEquals(0, underlying.getSoTimeout());
-        SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
-        Socket wrapping = sf.createSocket(underlying, null, -1, false);
-        assertEquals(0, wrapping.getSoTimeout());
-        // setting wrapper sets underlying and ...
-        int expectedTimeoutMillis = 1000; // 10 was too small because it was affected by rounding
-        wrapping.setSoTimeout(expectedTimeoutMillis);
-        // The kernel can round the requested value based on the HZ setting. We allow up to 10ms.
-        assertTrue(Math.abs(expectedTimeoutMillis - wrapping.getSoTimeout()) <= 10);
-        assertTrue(Math.abs(expectedTimeoutMillis - underlying.getSoTimeout()) <= 10);
-        // ... getting wrapper inspects underlying
-        underlying.setSoTimeout(0);
-        assertEquals(0, wrapping.getSoTimeout());
-        assertEquals(0, underlying.getSoTimeout());
+        try (ServerSocket listening = new ServerSocket(0)) {
+            Socket underlying = new Socket(listening.getInetAddress(), listening.getLocalPort());
+            assertEquals(0, underlying.getSoTimeout());
+            SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
+            Socket wrapping = sf.createSocket(underlying, null, -1, false);
+            assertEquals(0, wrapping.getSoTimeout());
+            // setting wrapper sets underlying and ...
+            int expectedTimeoutMillis =
+                    1000; // 10 was too small because it was affected by rounding
+            wrapping.setSoTimeout(expectedTimeoutMillis);
+            // The kernel can round the requested value based on the HZ setting. We allow up to
+            // 10ms.
+            assertTrue(Math.abs(expectedTimeoutMillis - wrapping.getSoTimeout()) <= 10);
+            assertTrue(Math.abs(expectedTimeoutMillis - underlying.getSoTimeout()) <= 10);
+            // ... getting wrapper inspects underlying
+            underlying.setSoTimeout(0);
+            assertEquals(0, wrapping.getSoTimeout());
+            assertEquals(0, underlying.getSoTimeout());
+        }
     }
 
     @Test
@@ -842,13 +778,7 @@
         SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
         Socket clientWrapping = sf.createSocket(underlying, null, -1, false);
         underlying.setSoTimeout(1);
-        try {
-            @SuppressWarnings("unused")
-            int value = clientWrapping.getInputStream().read();
-            fail();
-        } catch (SocketTimeoutException expected) {
-            // Ignored.
-        }
+        assertThrows(SocketTimeoutException.class, () -> clientWrapping.getInputStream().read());
         clientWrapping.close();
         server.close();
         underlying.close();
@@ -874,90 +804,76 @@
 
     @Test
     public void test_SSLSocket_ClientHello_cipherSuites() throws Exception {
-        ForEachRunner.runNamed(new ForEachRunner.Callback<SSLSocketFactory>() {
-            @Override
-            public void run(SSLSocketFactory sslSocketFactory) throws Exception {
-                ClientHello clientHello = TlsTester
-                        .captureTlsHandshakeClientHello(executor, sslSocketFactory);
-                final String[] cipherSuites;
-                // RFC 5746 allows you to send an empty "renegotiation_info" extension *or*
-                // a special signaling cipher suite. The TLS API has no way to check or
-                // indicate that a certain TLS extension should be used.
-                HelloExtension renegotiationInfoExtension =
+        ForEachRunner.runNamed(sslSocketFactory -> {
+            ClientHello clientHello =
+                    TlsTester.captureTlsHandshakeClientHello(executor, sslSocketFactory);
+            final String[] cipherSuites;
+            // RFC 5746 allows you to send an empty "renegotiation_info" extension *or*
+            // a special signaling cipher suite. The TLS API has no way to check or
+            // indicate that a certain TLS extension should be used.
+            HelloExtension renegotiationInfoExtension =
                     clientHello.findExtensionByType(HelloExtension.TYPE_RENEGOTIATION_INFO);
-                if (renegotiationInfoExtension != null
-                    && renegotiationInfoExtension.data.length == 1
+            if (renegotiationInfoExtension != null && renegotiationInfoExtension.data.length == 1
                     && renegotiationInfoExtension.data[0] == 0) {
-                    cipherSuites = new String[clientHello.cipherSuites.size() + 1];
-                    cipherSuites[clientHello.cipherSuites.size()] =
+                cipherSuites = new String[clientHello.cipherSuites.size() + 1];
+                cipherSuites[clientHello.cipherSuites.size()] =
                         StandardNames.CIPHER_SUITE_SECURE_RENEGOTIATION;
-                } else {
-                    cipherSuites = new String[clientHello.cipherSuites.size()];
-                }
-                for (int i = 0; i < clientHello.cipherSuites.size(); i++) {
-                    CipherSuite cipherSuite = clientHello.cipherSuites.get(i);
-                    cipherSuites[i] = cipherSuite.getAndroidName();
-                }
-                StandardNames.assertDefaultCipherSuites(cipherSuites);
+            } else {
+                cipherSuites = new String[clientHello.cipherSuites.size()];
             }
+            for (int i = 0; i < clientHello.cipherSuites.size(); i++) {
+                CipherSuite cipherSuite = clientHello.cipherSuites.get(i);
+                cipherSuites[i] = cipherSuite.getAndroidName();
+            }
+            StandardNames.assertDefaultCipherSuites(cipherSuites);
         }, getSSLSocketFactoriesToTest());
     }
 
     @Test
     public void test_SSLSocket_ClientHello_supportedCurves() throws Exception {
-        ForEachRunner.runNamed(new ForEachRunner.Callback<SSLSocketFactory>() {
-            @Override
-            public void run(SSLSocketFactory sslSocketFactory) throws Exception {
-                ClientHello clientHello = TlsTester
-                        .captureTlsHandshakeClientHello(executor, sslSocketFactory);
-                EllipticCurvesHelloExtension ecExtension =
+        ForEachRunner.runNamed(sslSocketFactory -> {
+            ClientHello clientHello =
+                    TlsTester.captureTlsHandshakeClientHello(executor, sslSocketFactory);
+            EllipticCurvesHelloExtension ecExtension =
                     (EllipticCurvesHelloExtension) clientHello.findExtensionByType(
-                        HelloExtension.TYPE_ELLIPTIC_CURVES);
-                final String[] supportedCurves;
-                if (ecExtension == null) {
-                    supportedCurves = new String[0];
-                } else {
-                    assertTrue(ecExtension.wellFormed);
-                    supportedCurves = new String[ecExtension.supported.size()];
-                    for (int i = 0; i < ecExtension.supported.size(); i++) {
-                        EllipticCurve curve = ecExtension.supported.get(i);
-                        supportedCurves[i] = curve.toString();
-                    }
+                            HelloExtension.TYPE_ELLIPTIC_CURVES);
+            final String[] supportedCurves;
+            if (ecExtension == null) {
+                supportedCurves = new String[0];
+            } else {
+                assertTrue(ecExtension.wellFormed);
+                supportedCurves = new String[ecExtension.supported.size()];
+                for (int i = 0; i < ecExtension.supported.size(); i++) {
+                    EllipticCurve curve = ecExtension.supported.get(i);
+                    supportedCurves[i] = curve.toString();
                 }
-                StandardNames.assertDefaultEllipticCurves(supportedCurves);
             }
+            StandardNames.assertDefaultEllipticCurves(supportedCurves);
         }, getSSLSocketFactoriesToTest());
     }
 
     @Test
     public void test_SSLSocket_ClientHello_clientProtocolVersion() throws Exception {
-        ForEachRunner.runNamed(new ForEachRunner.Callback<SSLSocketFactory>() {
-            @Override
-            public void run(SSLSocketFactory sslSocketFactory) throws Exception {
-                ClientHello clientHello = TlsTester
-                        .captureTlsHandshakeClientHello(executor, sslSocketFactory);
-                assertEquals(TlsProtocolVersion.TLSv1_2, clientHello.clientVersion);
-            }
+        ForEachRunner.runNamed(sslSocketFactory -> {
+            ClientHello clientHello =
+                    TlsTester.captureTlsHandshakeClientHello(executor, sslSocketFactory);
+            assertEquals(TlsProtocolVersion.TLSv1_2, clientHello.clientVersion);
         }, getSSLSocketFactoriesToTest());
     }
 
     @Test
     public void test_SSLSocket_ClientHello_compressionMethods() throws Exception {
-        ForEachRunner.runNamed(new ForEachRunner.Callback<SSLSocketFactory>() {
-            @Override
-            public void run(SSLSocketFactory sslSocketFactory) throws Exception {
-                ClientHello clientHello = TlsTester
-                        .captureTlsHandshakeClientHello(executor, sslSocketFactory);
-                assertEquals(Collections.singletonList(CompressionMethod.NULL),
+        ForEachRunner.runNamed(sslSocketFactory -> {
+            ClientHello clientHello =
+                    TlsTester.captureTlsHandshakeClientHello(executor, sslSocketFactory);
+            assertEquals(Collections.singletonList(CompressionMethod.NULL),
                     clientHello.compressionMethods);
-            }
         }, getSSLSocketFactoriesToTest());
     }
 
     private List<Pair<String, SSLSocketFactory>> getSSLSocketFactoriesToTest()
             throws NoSuchAlgorithmException, KeyManagementException {
-        List<Pair<String, SSLSocketFactory>> result =
-                new ArrayList<Pair<String, SSLSocketFactory>>();
+        List<Pair<String, SSLSocketFactory>> result = new ArrayList<>();
         result.add(Pair.of("default", (SSLSocketFactory) SSLSocketFactory.getDefault()));
         for (String sslContextProtocol : StandardNames.SSL_CONTEXT_PROTOCOLS_WITH_DEFAULT_CONFIG) {
             SSLContext sslContext = SSLContext.getInstance(sslContextProtocol);
@@ -981,23 +897,17 @@
         final String[] clientCipherSuites = new String[serverCipherSuites.length + 1];
         System.arraycopy(serverCipherSuites, 0, clientCipherSuites, 0, serverCipherSuites.length);
         clientCipherSuites[serverCipherSuites.length] = StandardNames.CIPHER_SUITE_FALLBACK;
-        Future<Void> s = runAsync(new Callable<Void>() {
-            @Override
-            public Void call() throws Exception {
-                server.setEnabledProtocols(new String[]{"TLSv1.2"});
-                server.setEnabledCipherSuites(serverCipherSuites);
-                server.startHandshake();
-                return null;
-            }
+        Future<Void> s = runAsync(() -> {
+            server.setEnabledProtocols(new String[] {"TLSv1.2"});
+            server.setEnabledCipherSuites(serverCipherSuites);
+            server.startHandshake();
+            return null;
         });
-        Future<Void> c = runAsync(new Callable<Void>() {
-            @Override
-            public Void call() throws Exception {
-                client.setEnabledProtocols(new String[]{"TLSv1.2"});
-                client.setEnabledCipherSuites(clientCipherSuites);
-                client.startHandshake();
-                return null;
-            }
+        Future<Void> c = runAsync(() -> {
+            client.setEnabledProtocols(new String[] {"TLSv1.2"});
+            client.setEnabledCipherSuites(clientCipherSuites);
+            client.startHandshake();
+            return null;
         });
         s.get();
         c.get();
@@ -1016,21 +926,15 @@
         // Confirm absence of TLS_FALLBACK_SCSV.
         assertFalse(Arrays.asList(client.getEnabledCipherSuites())
                             .contains(StandardNames.CIPHER_SUITE_FALLBACK));
-        Future<Void> s = runAsync(new Callable<Void>() {
-            @Override
-            public Void call() throws Exception {
-                server.setEnabledProtocols(new String[]{"TLSv1.2", "TLSv1.1"});
-                server.startHandshake();
-                return null;
-            }
+        Future<Void> s = runAsync(() -> {
+            server.setEnabledProtocols(new String[] {"TLSv1.2", "TLSv1.1"});
+            server.startHandshake();
+            return null;
         });
-        Future<Void> c = runAsync(new Callable<Void>() {
-            @Override
-            public Void call() throws Exception {
-                client.setEnabledProtocols(new String[]{"TLSv1.1"});
-                client.startHandshake();
-                return null;
-            }
+        Future<Void> c = runAsync(() -> {
+            client.setEnabledProtocols(new String[] {"TLSv1.1"});
+            client.startHandshake();
+            return null;
         });
         s.get();
         c.get();
@@ -1057,37 +961,25 @@
         final String[] clientCipherSuites = new String[serverCipherSuites.length + 1];
         System.arraycopy(serverCipherSuites, 0, clientCipherSuites, 0, serverCipherSuites.length);
         clientCipherSuites[serverCipherSuites.length] = StandardNames.CIPHER_SUITE_FALLBACK;
-        Future<Void> s = runAsync(new Callable<Void>() {
-            @Override
-            public Void call() throws Exception {
-                server.setEnabledProtocols(new String[] {"TLSv1.2", "TLSv1.1"});
-                server.setEnabledCipherSuites(serverCipherSuites);
-                try {
-                    server.startHandshake();
-                    fail("Should result in inappropriate fallback");
-                } catch (SSLHandshakeException expected) {
-                    Throwable cause = expected.getCause();
-                    assertEquals(SSLProtocolException.class, cause.getClass());
-                    assertInappropriateFallbackIsCause(cause);
-                }
-                return null;
-            }
+        Future<Void> s = runAsync(() -> {
+            server.setEnabledProtocols(new String[] {"TLSv1.2", "TLSv1.1"});
+            server.setEnabledCipherSuites(serverCipherSuites);
+            SSLHandshakeException expected =
+                    assertThrows(SSLHandshakeException.class, server::startHandshake);
+            Throwable cause = expected.getCause();
+            assertEquals(SSLProtocolException.class, cause.getClass());
+            assertInappropriateFallbackIsCause(cause);
+            return null;
         });
-        Future<Void> c = runAsync(new Callable<Void>() {
-            @Override
-            public Void call() throws Exception {
-                client.setEnabledProtocols(new String[] {"TLSv1.1"});
-                client.setEnabledCipherSuites(clientCipherSuites);
-                try {
-                    client.startHandshake();
-                    fail("Should receive TLS alert inappropriate fallback");
-                } catch (SSLHandshakeException expected) {
-                    Throwable cause = expected.getCause();
-                    assertEquals(SSLProtocolException.class, cause.getClass());
-                    assertInappropriateFallbackIsCause(cause);
-                }
-                return null;
-            }
+        Future<Void> c = runAsync(() -> {
+            client.setEnabledProtocols(new String[] {"TLSv1.1"});
+            client.setEnabledCipherSuites(clientCipherSuites);
+            SSLHandshakeException expected =
+                    assertThrows(SSLHandshakeException.class, client::startHandshake);
+            Throwable cause = expected.getCause();
+            assertEquals(SSLProtocolException.class, cause.getClass());
+            assertInappropriateFallbackIsCause(cause);
+            return null;
         });
         s.get();
         c.get();
@@ -1122,6 +1014,74 @@
         }
     }
 
+    @Test
+    public void handshakeListenersRunExactlyOnce() {
+        AtomicInteger count = new AtomicInteger(0);
+        TestSSLSocketPair pair = TestSSLSocketPair.create();
+        pair.client.addHandshakeCompletedListener(event -> count.addAndGet(1));
+        pair.client.addHandshakeCompletedListener(event -> count.addAndGet(2));
+        pair.client.addHandshakeCompletedListener(event -> count.addAndGet(4));
+        pair.connect();
+        assertEquals(1 + 2 + 4, count.get());
+    }
+
+    @Test
+    public void closeFromHandshakeListener() throws Exception {
+        TestUtils.assumeEngineSocket();
+
+        TestSSLSocketPair pair = TestSSLSocketPair.create();
+        pair.client.addHandshakeCompletedListener(event -> socketClose(pair.client));
+        Future<Void> serverFuture = runAsync((Callable<Void>) () -> {
+            pair.server.startHandshake();
+            return null;
+        });
+        pair.client.startHandshake();
+        assertThrows(SocketException.class, pair.client::getInputStream);
+        serverFuture.get();
+        InputStream istream = pair.server.getInputStream();
+        assertEquals(-1, istream.read());
+    }
+
+    @Test
+    public void writeFromHandshakeListener() throws Exception {
+        TestUtils.assumeEngineSocket();
+
+        byte[] ping = "ping".getBytes(UTF_8);
+        byte[] pong = "pong".getBytes(UTF_8);
+        TestSSLSocketPair pair = TestSSLSocketPair.create();
+        pair.client.addHandshakeCompletedListener(event -> socketWrite(pair.client, ping));
+        pair.server.addHandshakeCompletedListener(event -> socketWrite(pair.server, pong));
+        Future<Void> serverFuture = runAsync(() -> {
+            pair.server.startHandshake();
+            return null;
+        });
+        byte[] buffer = new byte[4];
+        InputStream clientStream = pair.client.getInputStream();
+        assertEquals(4, clientStream.read(buffer));
+        assertArrayEquals(pong, buffer);
+
+        serverFuture.get();
+        InputStream serverStream = pair.server.getInputStream();
+        assertEquals(4, serverStream.read(buffer));
+        assertArrayEquals(ping, buffer);
+    }
+
+    private void socketClose(Socket socket) {
+        try {
+            socket.close();
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    private void socketWrite(Socket socket, byte[] data) {
+        try {
+            socket.getOutputStream().write(data);
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
     private <T> Future<T> runAsync(Callable<T> callable) {
         return executor.submit(callable);
     }
@@ -1138,5 +1098,4 @@
             byteCount -= bytesRead;
         }
     }
-
 }