Merge Conscrypt upstream master.
Contains the following upstream changes:
Tidy ConscryptEngineSocket state machine. (#1120)
Add missing close() calls. (#1122)
Bug: 276304877
Test: MtsConscryptTestCases
Change-Id: Iadb7295b1fa0925dd8b966b48bca3040f858196f
Merged-In: Iadb7295b1fa0925dd8b966b48bca3040f858196f
(cherry picked from commit af289bed7ed17dc6138da62db42370b1c3472172)
diff --git a/common/src/main/java/org/conscrypt/ConscryptEngineSocket.java b/common/src/main/java/org/conscrypt/ConscryptEngineSocket.java
index 48c6b3d..8d96276 100644
--- a/common/src/main/java/org/conscrypt/ConscryptEngineSocket.java
+++ b/common/src/main/java/org/conscrypt/ConscryptEngineSocket.java
@@ -60,7 +60,7 @@
private SSLOutputStream out;
private SSLInputStream in;
- private long handshakeStartedMillis;
+ private long handshakeStartedMillis = 0;
private BufferAllocator bufferAllocator = ConscryptEngine.getDefaultBufferAllocator();
@@ -123,7 +123,7 @@
@Override
public void onHandshakeFinished() {
// Just call the outer class method.
- socket.onHandshakeFinished();
+ socket.onEngineHandshakeFinished();
}
});
@@ -194,8 +194,7 @@
synchronized (stateLock) {
// Initialize the handshake if we haven't already.
if (state == STATE_NEW) {
- state = STATE_HANDSHAKE_STARTED;
- handshakeStartedMillis = Platform.getMillisSinceBoot();
+ transitionTo(STATE_HANDSHAKE_STARTED);
engine.beginHandshake();
in = new SSLInputStream();
out = new SSLOutputStream();
@@ -208,7 +207,6 @@
return;
}
}
-
doHandshake();
}
} catch (SSLException e) {
@@ -232,6 +230,7 @@
case NEED_UNWRAP:
if (in.processDataFromSocket(EmptyArray.BYTE, 0, 0) < 0) {
// Can't complete the handshake due to EOF.
+ close();
throw SSLUtils.toSSLHandshakeException(
new EOFException("connection closed"));
}
@@ -244,15 +243,13 @@
}
case NEED_TASK: {
// Should never get here, since our engine never provides tasks.
+ close();
throw new IllegalStateException("Engine tasks are unsupported");
}
case NOT_HANDSHAKING:
case FINISHED: {
// Handshake is complete.
finished = true;
- Platform.countTlsHandshake(true, engine.getSession().getProtocol(),
- engine.getSession().getCipherSuite(),
- Platform.getMillisSinceBoot() - handshakeStartedMillis);
break;
}
default: {
@@ -261,11 +258,15 @@
}
}
}
+ if (isState(STATE_HANDSHAKE_COMPLETED)) {
+ // STATE_READY_HANDSHAKE_CUT_THROUGH will wake up any waiting threads which can
+ // race with the listeners, but that's OK.
+ transitionTo(STATE_READY_HANDSHAKE_CUT_THROUGH);
+ notifyHandshakeCompletedListeners();
+ transitionTo(STATE_READY);
+ }
} catch (SSLException e) {
drainOutgoingQueue();
- Platform.countTlsHandshake(false, engine.getSession().getProtocol(),
- engine.getSession().getCipherSuite(),
- Platform.getMillisSinceBoot() - handshakeStartedMillis);
close();
throw e;
} catch (IOException e) {
@@ -278,6 +279,64 @@
}
}
+ private boolean isState(int desiredState) {
+ synchronized (stateLock) {
+ return state == desiredState;
+ }
+ }
+
+ private int transitionTo(int newState) {
+ synchronized (stateLock) {
+ if (state == newState) {
+ return state;
+ }
+
+ int previousState = state;
+ boolean notify = false;
+ switch (newState) {
+ case STATE_HANDSHAKE_STARTED:
+ handshakeStartedMillis = Platform.getMillisSinceBoot();
+ break;
+
+ case STATE_READY_HANDSHAKE_CUT_THROUGH:
+ if (handshakeStartedMillis > 0) {
+ Platform.countTlsHandshake(true,
+ engine.getSession().getProtocol(),
+ engine.getSession().getCipherSuite(),
+ Platform.getMillisSinceBoot() - handshakeStartedMillis);
+ handshakeStartedMillis = 0;
+ }
+ notify = true;
+ break;
+
+ case STATE_READY:
+ notify = true;
+ break;
+
+ case STATE_CLOSED:
+ if (handshakeStartedMillis > 0) {
+ // Handshake must have failed.
+ Platform.countTlsHandshake(false,
+ engine.getSession().getProtocol(),
+ engine.getSession().getCipherSuite(),
+ Platform.getMillisSinceBoot() - handshakeStartedMillis);
+ handshakeStartedMillis = 0;
+ }
+ notify = true;
+ break;
+
+ default:
+ break;
+ }
+
+ state = newState;
+ if (notify) {
+ stateLock.notifyAll();
+ }
+ return previousState;
+ }
+ }
+
@Override
public final InputStream getInputStream() throws IOException {
checkOpen();
@@ -441,24 +500,14 @@
// TODO: Close SSL sockets using a background thread so they close gracefully.
if (stateLock == null) {
- // close() has been called before we've initialized the socket, so just
- // return.
+ // Constructor failed, e.g. superclass constructor called close()
return;
}
- int previousState;
- synchronized (stateLock) {
- previousState = state;
- if (state == STATE_CLOSED) {
- // close() has already been called, so do nothing and return.
- return;
- }
-
- state = STATE_CLOSED;
-
- stateLock.notifyAll();
+ int previousState = transitionTo(STATE_CLOSED);
+ if (previousState == STATE_CLOSED) {
+ return;
}
-
try {
// Close the engine.
engine.closeInbound();
@@ -527,25 +576,12 @@
this.bufferAllocator = bufferAllocator;
}
- private void onHandshakeFinished() {
- boolean notify = false;
- synchronized (stateLock) {
- if (state != STATE_CLOSED) {
- if (state == STATE_HANDSHAKE_STARTED) {
- state = STATE_READY_HANDSHAKE_CUT_THROUGH;
- } else if (state == STATE_HANDSHAKE_COMPLETED) {
- state = STATE_READY;
- }
-
- // Unblock threads that are waiting for our state to transition
- // into STATE_READY or STATE_READY_HANDSHAKE_CUT_THROUGH.
- stateLock.notifyAll();
- notify = true;
- }
- }
-
- if (notify) {
- notifyHandshakeCompletedListeners();
+ private void onEngineHandshakeFinished() {
+ // Don't do anything here except change state. This method will be called from
+ // e.g. wrap() which is non re-entrant so we can't call anything that might do
+ // IO until after it exits, e.g. in doHandshake().
+ if (isState(STATE_HANDSHAKE_STARTED)) {
+ transitionTo(STATE_HANDSHAKE_COMPLETED);
}
}
@@ -556,7 +592,9 @@
startHandshake();
synchronized (stateLock) {
- while (state != STATE_READY && state != STATE_READY_HANDSHAKE_CUT_THROUGH
+ while (state != STATE_READY
+ // Waiting threads are allowed to compete with handshake listeners for access.
+ && state != STATE_READY_HANDSHAKE_CUT_THROUGH
&& state != STATE_CLOSED) {
try {
stateLock.wait();
@@ -901,7 +939,7 @@
private boolean isHandshakeFinished() {
synchronized (stateLock) {
- return state >= STATE_READY_HANDSHAKE_CUT_THROUGH;
+ return state > STATE_HANDSHAKE_STARTED;
}
}
diff --git a/common/src/test/java/org/conscrypt/javax/net/ssl/SSLSocketTest.java b/common/src/test/java/org/conscrypt/javax/net/ssl/SSLSocketTest.java
index d0d5dd7..36d0cb1 100644
--- a/common/src/test/java/org/conscrypt/javax/net/ssl/SSLSocketTest.java
+++ b/common/src/test/java/org/conscrypt/javax/net/ssl/SSLSocketTest.java
@@ -16,20 +16,22 @@
package org.conscrypt.javax.net.ssl;
-import static org.conscrypt.TestUtils.UTF_8;
+import static java.nio.charset.StandardCharsets.UTF_8;
+import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.net.ServerSocket;
import java.net.Socket;
+import java.net.SocketException;
import java.net.SocketTimeoutException;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
@@ -44,7 +46,6 @@
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
-import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import javax.crypto.SecretKey;
@@ -71,7 +72,6 @@
import org.conscrypt.tlswire.handshake.HelloExtension;
import org.conscrypt.tlswire.util.TlsProtocolVersion;
import org.junit.After;
-import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -81,24 +81,14 @@
@RunWith(JUnit4.class)
public class SSLSocketTest {
- private ExecutorService executor;
- private ThreadGroup threadGroup;
-
- @Before
- public void setup() {
- threadGroup = new ThreadGroup("SSLSocketTest");
- executor = Executors.newCachedThreadPool(new ThreadFactory() {
- @Override
- public Thread newThread(Runnable r) {
- return new Thread(threadGroup, r);
- }
- });
- }
+ private final ThreadGroup threadGroup = new ThreadGroup("SSLSocketTest");
+ private final ExecutorService executor =
+ Executors.newCachedThreadPool(t -> new Thread(threadGroup, t));
@After
public void teardown() throws InterruptedException {
executor.shutdownNow();
- executor.awaitTermination(5, TimeUnit.SECONDS);
+ assertTrue(executor.awaitTermination(5, TimeUnit.SECONDS));
}
@Test
@@ -110,8 +100,9 @@
@Test
public void test_SSLSocket_getSupportedCipherSuites_returnsCopies() throws Exception {
SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
- SSLSocket ssl = (SSLSocket) sf.createSocket();
- assertNotSame(ssl.getSupportedCipherSuites(), ssl.getSupportedCipherSuites());
+ try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+ assertNotSame(ssl.getSupportedCipherSuites(), ssl.getSupportedCipherSuites());
+ }
}
@Test
@@ -131,7 +122,7 @@
}
private void test_SSLSocket_getSupportedCipherSuites_connect(
- TestKeyStore testKeyStore, StringBuilder error) throws Exception {
+ TestKeyStore testKeyStore, StringBuilder error) {
String clientToServerString = "this is sent from the client to the server...";
String serverToClientString = "... and this from the server to the client";
byte[] clientToServer = clientToServerString.getBytes(UTF_8);
@@ -207,21 +198,9 @@
// Check that the server and the client cannot read anything else
// (reads should time out)
server.setSoTimeout(10);
- try {
- @SuppressWarnings("unused")
- int value = server.getInputStream().read();
- fail();
- } catch (IOException expected) {
- // Ignored.
- }
+ assertThrows(IOException.class, () -> server.getInputStream().read());
client.setSoTimeout(10);
- try {
- @SuppressWarnings("unused")
- int value = client.getInputStream().read();
- fail();
- } catch (IOException expected) {
- // Ignored.
- }
+ assertThrows(IOException.class, () -> client.getInputStream().read());
client.close();
server.close();
} catch (Exception maybeExpected) {
@@ -273,53 +252,44 @@
@Test
public void test_SSLSocket_getEnabledCipherSuites_returnsCopies() throws Exception {
SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
- SSLSocket ssl = (SSLSocket) sf.createSocket();
- assertNotSame(ssl.getEnabledCipherSuites(), ssl.getEnabledCipherSuites());
+ try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+ assertNotSame(ssl.getEnabledCipherSuites(), ssl.getEnabledCipherSuites());
+ }
}
@Test
public void test_SSLSocket_setEnabledCipherSuites_storesCopy() throws Exception {
SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
- SSLSocket ssl = (SSLSocket) sf.createSocket();
- String[] array = new String[] {ssl.getEnabledCipherSuites()[0]};
- String originalFirstElement = array[0];
- ssl.setEnabledCipherSuites(array);
- array[0] = "Modified after having been set";
- assertEquals(originalFirstElement, ssl.getEnabledCipherSuites()[0]);
+ try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+ String[] array = new String[]{ssl.getEnabledCipherSuites()[0]};
+ String originalFirstElement = array[0];
+ ssl.setEnabledCipherSuites(array);
+ array[0] = "Modified after having been set";
+ assertEquals(originalFirstElement, ssl.getEnabledCipherSuites()[0]);
+ }
}
@Test
public void test_SSLSocket_setEnabledCipherSuites_TLS12() throws Exception {
SSLContext context = SSLContext.getInstance("TLSv1.2");
context.init(null, null, null);
- SSLSocket ssl = (SSLSocket) context.getSocketFactory().createSocket();
- try {
- ssl.setEnabledCipherSuites(null);
- fail();
- } catch (IllegalArgumentException expected) {
- // Ignored.
+ try (SSLSocket ssl = (SSLSocket) context.getSocketFactory().createSocket()) {
+ assertThrows(IllegalArgumentException.class,
+ () -> ssl.setEnabledCipherSuites(null));
+ assertThrows(IllegalArgumentException.class,
+ () -> ssl.setEnabledCipherSuites(new String[1]));
+ assertThrows(IllegalArgumentException.class,
+ () -> ssl.setEnabledCipherSuites(new String[]{"Bogus"}));
+ ssl.setEnabledCipherSuites(new String[0]);
+ ssl.setEnabledCipherSuites(ssl.getEnabledCipherSuites());
+ ssl.setEnabledCipherSuites(ssl.getSupportedCipherSuites());
+ // Check that setEnabledCipherSuites affects getEnabledCipherSuites
+ String[] cipherSuites = new String[]{
+ TestUtils.pickArbitraryNonTls13Suite(ssl.getSupportedCipherSuites())
+ };
+ ssl.setEnabledCipherSuites(cipherSuites);
+ assertEquals(Arrays.asList(cipherSuites), Arrays.asList(ssl.getEnabledCipherSuites()));
}
- try {
- ssl.setEnabledCipherSuites(new String[1]);
- fail();
- } catch (IllegalArgumentException expected) {
- // Ignored.
- }
- try {
- ssl.setEnabledCipherSuites(new String[] {"Bogus"});
- fail();
- } catch (IllegalArgumentException expected) {
- // Ignored.
- }
- ssl.setEnabledCipherSuites(new String[0]);
- ssl.setEnabledCipherSuites(ssl.getEnabledCipherSuites());
- ssl.setEnabledCipherSuites(ssl.getSupportedCipherSuites());
- // Check that setEnabledCipherSuites affects getEnabledCipherSuites
- String[] cipherSuites = new String[] {
- TestUtils.pickArbitraryNonTls13Suite(ssl.getSupportedCipherSuites())
- };
- ssl.setEnabledCipherSuites(cipherSuites);
- assertEquals(Arrays.asList(cipherSuites), Arrays.asList(ssl.getEnabledCipherSuites()));
}
@Test
@@ -327,91 +297,81 @@
SSLContext context = SSLContext.getInstance("TLSv1.3");
context.init(null, null, null);
SSLSocketFactory sf = context.getSocketFactory();
- SSLSocket ssl = (SSLSocket) sf.createSocket();
- // The TLS 1.3 cipher suites should be enabled by default
- assertTrue(new HashSet<String>(Arrays.asList(ssl.getEnabledCipherSuites()))
- .containsAll(StandardNames.CIPHER_SUITES_TLS13));
- // Disabling them should be ignored
- ssl.setEnabledCipherSuites(new String[0]);
- assertTrue(new HashSet<String>(Arrays.asList(ssl.getEnabledCipherSuites()))
- .containsAll(StandardNames.CIPHER_SUITES_TLS13));
+ try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+ // The TLS 1.3 cipher suites should be enabled by default
+ assertTrue(new HashSet<>(Arrays.asList(ssl.getEnabledCipherSuites()))
+ .containsAll(StandardNames.CIPHER_SUITES_TLS13));
+ // Disabling them should be ignored
+ ssl.setEnabledCipherSuites(new String[0]);
+ assertTrue(new HashSet<>(Arrays.asList(ssl.getEnabledCipherSuites()))
+ .containsAll(StandardNames.CIPHER_SUITES_TLS13));
- ssl.setEnabledCipherSuites(new String[] {
- TestUtils.pickArbitraryNonTls13Suite(ssl.getSupportedCipherSuites())
- });
- assertTrue(new HashSet<String>(Arrays.asList(ssl.getEnabledCipherSuites()))
- .containsAll(StandardNames.CIPHER_SUITES_TLS13));
+ ssl.setEnabledCipherSuites(new String[]{
+ TestUtils.pickArbitraryNonTls13Suite(ssl.getSupportedCipherSuites())
+ });
+ assertTrue(new HashSet<>(Arrays.asList(ssl.getEnabledCipherSuites()))
+ .containsAll(StandardNames.CIPHER_SUITES_TLS13));
- // Disabling TLS 1.3 should disable 1.3 cipher suites
- ssl.setEnabledProtocols(new String[] { "TLSv1.2" });
- assertFalse(new HashSet<String>(Arrays.asList(ssl.getEnabledCipherSuites()))
- .containsAll(StandardNames.CIPHER_SUITES_TLS13));
+ // Disabling TLS 1.3 should disable 1.3 cipher suites
+ ssl.setEnabledProtocols(new String[]{"TLSv1.2"});
+ assertFalse(new HashSet<>(Arrays.asList(ssl.getEnabledCipherSuites()))
+ .containsAll(StandardNames.CIPHER_SUITES_TLS13));
+ }
}
@Test
public void test_SSLSocket_getSupportedProtocols_returnsCopies() throws Exception {
SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
- SSLSocket ssl = (SSLSocket) sf.createSocket();
- assertNotSame(ssl.getSupportedProtocols(), ssl.getSupportedProtocols());
+ try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+ assertNotSame(ssl.getSupportedProtocols(), ssl.getSupportedProtocols());
+ }
}
@Test
public void test_SSLSocket_getEnabledProtocols_returnsCopies() throws Exception {
SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
- SSLSocket ssl = (SSLSocket) sf.createSocket();
- assertNotSame(ssl.getEnabledProtocols(), ssl.getEnabledProtocols());
+ try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+ assertNotSame(ssl.getEnabledProtocols(), ssl.getEnabledProtocols());
+ }
}
@Test
public void test_SSLSocket_setEnabledProtocols_storesCopy() throws Exception {
SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
- SSLSocket ssl = (SSLSocket) sf.createSocket();
- String[] array = new String[] {ssl.getEnabledProtocols()[0]};
- String originalFirstElement = array[0];
- ssl.setEnabledProtocols(array);
- array[0] = "Modified after having been set";
- assertEquals(originalFirstElement, ssl.getEnabledProtocols()[0]);
+ try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+ String[] array = new String[]{ssl.getEnabledProtocols()[0]};
+ String originalFirstElement = array[0];
+ ssl.setEnabledProtocols(array);
+ array[0] = "Modified after having been set";
+ assertEquals(originalFirstElement, ssl.getEnabledProtocols()[0]);
+ }
}
@Test
public void test_SSLSocket_setEnabledProtocols() throws Exception {
SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
- SSLSocket ssl = (SSLSocket) sf.createSocket();
- try {
- ssl.setEnabledProtocols(null);
- fail();
- } catch (IllegalArgumentException expected) {
- // Ignored.
- }
- try {
- ssl.setEnabledProtocols(new String[1]);
- fail();
- } catch (IllegalArgumentException expected) {
- // Ignored.
- }
- try {
- ssl.setEnabledProtocols(new String[] {"Bogus"});
- fail();
- } catch (IllegalArgumentException expected) {
- // Ignored.
- }
- ssl.setEnabledProtocols(new String[0]);
- ssl.setEnabledProtocols(ssl.getEnabledProtocols());
- ssl.setEnabledProtocols(ssl.getSupportedProtocols());
- // Check that setEnabledProtocols affects getEnabledProtocols
- for (String protocol : ssl.getSupportedProtocols()) {
- if ("SSLv2Hello".equals(protocol)) {
- try {
- ssl.setEnabledProtocols(new String[] {protocol});
- fail("Should fail when SSLv2Hello is set by itself");
- } catch (IllegalArgumentException expected) {
- // Ignored.
+ try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+ assertThrows(IllegalArgumentException.class,
+ () -> ssl.setEnabledProtocols(null));
+ assertThrows(IllegalArgumentException.class,
+ () -> ssl.setEnabledProtocols(new String[1]));
+ assertThrows(IllegalArgumentException.class,
+ () -> ssl.setEnabledProtocols(new String[]{"Bogus"}));
+ ssl.setEnabledProtocols(new String[0]);
+ ssl.setEnabledProtocols(ssl.getEnabledProtocols());
+ ssl.setEnabledProtocols(ssl.getSupportedProtocols());
+ // Check that setEnabledProtocols affects getEnabledProtocols
+ for (String protocol : ssl.getSupportedProtocols()) {
+ if ("SSLv2Hello".equals(protocol)) {
+ // Should fail when SSLv2Hello is set by itself
+ assertThrows(IllegalArgumentException.class,
+ () -> ssl.setEnabledProtocols(new String[]{protocol}));
+ } else {
+ String[] protocols = new String[]{protocol};
+ ssl.setEnabledProtocols(protocols);
+ assertEquals(Arrays.deepToString(protocols),
+ Arrays.deepToString(ssl.getEnabledProtocols()));
}
- } else {
- String[] protocols = new String[] {protocol};
- ssl.setEnabledProtocols(protocols);
- assertEquals(Arrays.deepToString(protocols),
- Arrays.deepToString(ssl.getEnabledProtocols()));
}
}
}
@@ -430,11 +390,9 @@
final SSLSocket server = (SSLSocket) c.serverSocket.accept();
server.setEnabledProtocols(new String[] {"TLSv1.3", "TLSv1.2", "TLSv1.1"});
ExecutorService executor = Executors.newSingleThreadExecutor();
- Future<Void> future = executor.submit(new Callable<Void>() {
- @Override public Void call() throws Exception {
- server.startHandshake();
- return null;
- }
+ Future<Void> future = executor.submit(() -> {
+ server.startHandshake();
+ return null;
});
executor.shutdown();
client.startHandshake();
@@ -461,11 +419,9 @@
final SSLSocket server = (SSLSocket) c.serverSocket.accept();
server.setEnabledProtocols(new String[] {"TLSv1.2", "TLSv1.1"});
ExecutorService executor = Executors.newSingleThreadExecutor();
- Future<Void> future = executor.submit(new Callable<Void>() {
- @Override public Void call() throws Exception {
- server.startHandshake();
- return null;
- }
+ Future<Void> future = executor.submit(() -> {
+ server.startHandshake();
+ return null;
});
executor.shutdown();
client.startHandshake();
@@ -481,18 +437,20 @@
@Test
public void test_SSLSocket_getSession() throws Exception {
SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
- SSLSocket ssl = (SSLSocket) sf.createSocket();
- SSLSession session = ssl.getSession();
- assertNotNull(session);
- assertFalse(session.isValid());
+ try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+ SSLSession session = ssl.getSession();
+ assertNotNull(session);
+ assertFalse(session.isValid());
+ }
}
@Test
public void test_SSLSocket_getHandshakeSession_unconnected() throws Exception {
SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
- SSLSocket socket = (SSLSocket) sf.createSocket();
- SSLSession session = socket.getHandshakeSession();
- assertNull(session);
+ try (SSLSocket socket = (SSLSocket) sf.createSocket()) {
+ SSLSession session = socket.getHandshakeSession();
+ assertNull(session);
+ }
}
@Test
@@ -570,11 +528,9 @@
clientContext.getSocketFactory().createSocket(c.host, c.port);
final SSLSocket server = (SSLSocket) c.serverSocket.accept();
ExecutorService executor = Executors.newSingleThreadExecutor();
- Future<Void> future = executor.submit(new Callable<Void>() {
- @Override public Void call() throws Exception {
- server.startHandshake();
- return null;
- }
+ Future<Void> future = executor.submit(() -> {
+ server.startHandshake();
+ return null;
});
executor.shutdown();
client.startHandshake();
@@ -673,12 +629,10 @@
clientContext.getSocketFactory().createSocket(c.host, c.port);
final SSLSocket server = (SSLSocket) c.serverSocket.accept();
ExecutorService executor = Executors.newSingleThreadExecutor();
- Future<Void> future = executor.submit(new Callable<Void>() {
- @Override public Void call() throws Exception {
- server.setNeedClientAuth(true);
- server.startHandshake();
- return null;
- }
+ Future<Void> future = executor.submit(() -> {
+ server.setNeedClientAuth(true);
+ server.startHandshake();
+ return null;
});
executor.shutdown();
client.startHandshake();
@@ -691,21 +645,11 @@
}
@Test
- public void test_SSLSocket_setUseClientMode_afterHandshake() throws Exception {
+ public void test_SSLSocket_setUseClientMode_afterHandshake() {
// can't set after handshake
TestSSLSocketPair pair = TestSSLSocketPair.create().connect();
- try {
- pair.server.setUseClientMode(false);
- fail();
- } catch (IllegalArgumentException expected) {
- // Ignored.
- }
- try {
- pair.client.setUseClientMode(false);
- fail();
- } catch (IllegalArgumentException expected) {
- // Ignored.
- }
+ assertThrows(IllegalArgumentException.class, () -> pair.server.setUseClientMode(true));
+ assertThrows(IllegalArgumentException.class, () -> pair.client.setUseClientMode(false));
}
@Test
@@ -715,24 +659,14 @@
SSLSocket client =
(SSLSocket) c.clientContext.getSocketFactory().createSocket(c.host, c.port);
final SSLSocket server = (SSLSocket) c.serverSocket.accept();
- Future<Void> future = runAsync(new Callable<Void>() {
- @Override
- public Void call() throws Exception {
- try {
- server.startHandshake();
- fail();
- } catch (SSLHandshakeException expected) {
- // Ignored.
- }
- return null;
- }
+ Future<Void> future = runAsync(() -> {
+ assertThrows(SSLHandshakeException.class, server::startHandshake);
+ return null;
});
- try {
- client.startHandshake();
- fail();
- } catch (SSLHandshakeException expected) {
- assertTrue(expected.getCause() instanceof CertificateException);
- }
+ SSLHandshakeException expected =
+ assertThrows(SSLHandshakeException.class, client::startHandshake);
+ assertTrue(expected.getCause() instanceof CertificateException);
+
future.get();
client.close();
server.close();
@@ -743,90 +677,93 @@
public void test_SSLSocket_getSSLParameters() throws Exception {
TestUtils.assumeSetEndpointIdentificationAlgorithmAvailable();
SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
- SSLSocket ssl = (SSLSocket) sf.createSocket();
- SSLParameters p = ssl.getSSLParameters();
- assertNotNull(p);
- String[] cipherSuites = p.getCipherSuites();
- assertNotSame(cipherSuites, ssl.getEnabledCipherSuites());
- assertEquals(Arrays.asList(cipherSuites), Arrays.asList(ssl.getEnabledCipherSuites()));
- String[] protocols = p.getProtocols();
- assertNotSame(protocols, ssl.getEnabledProtocols());
- assertEquals(Arrays.asList(protocols), Arrays.asList(ssl.getEnabledProtocols()));
- assertEquals(p.getWantClientAuth(), ssl.getWantClientAuth());
- assertEquals(p.getNeedClientAuth(), ssl.getNeedClientAuth());
- assertNull(p.getEndpointIdentificationAlgorithm());
- p.setEndpointIdentificationAlgorithm(null);
- assertNull(p.getEndpointIdentificationAlgorithm());
- p.setEndpointIdentificationAlgorithm("HTTPS");
- assertEquals("HTTPS", p.getEndpointIdentificationAlgorithm());
- p.setEndpointIdentificationAlgorithm("FOO");
- assertEquals("FOO", p.getEndpointIdentificationAlgorithm());
+ try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+ SSLParameters p = ssl.getSSLParameters();
+ assertNotNull(p);
+ String[] cipherSuites = p.getCipherSuites();
+ assertNotSame(cipherSuites, ssl.getEnabledCipherSuites());
+ assertEquals(Arrays.asList(cipherSuites), Arrays.asList(ssl.getEnabledCipherSuites()));
+ String[] protocols = p.getProtocols();
+ assertNotSame(protocols, ssl.getEnabledProtocols());
+ assertEquals(Arrays.asList(protocols), Arrays.asList(ssl.getEnabledProtocols()));
+ assertEquals(p.getWantClientAuth(), ssl.getWantClientAuth());
+ assertEquals(p.getNeedClientAuth(), ssl.getNeedClientAuth());
+ assertNull(p.getEndpointIdentificationAlgorithm());
+ p.setEndpointIdentificationAlgorithm(null);
+ assertNull(p.getEndpointIdentificationAlgorithm());
+ p.setEndpointIdentificationAlgorithm("HTTPS");
+ assertEquals("HTTPS", p.getEndpointIdentificationAlgorithm());
+ p.setEndpointIdentificationAlgorithm("FOO");
+ assertEquals("FOO", p.getEndpointIdentificationAlgorithm());
+ }
}
@Test
public void test_SSLSocket_setSSLParameters() throws Exception {
SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
- SSLSocket ssl = (SSLSocket) sf.createSocket();
- String[] defaultCipherSuites = ssl.getEnabledCipherSuites();
- String[] defaultProtocols = ssl.getEnabledProtocols();
- String[] supportedCipherSuites = ssl.getSupportedCipherSuites();
- String[] supportedProtocols = ssl.getSupportedProtocols();
- {
- SSLParameters p = new SSLParameters();
- ssl.setSSLParameters(p);
- assertEquals(Arrays.asList(defaultCipherSuites),
- Arrays.asList(ssl.getEnabledCipherSuites()));
- assertEquals(Arrays.asList(defaultProtocols), Arrays.asList(ssl.getEnabledProtocols()));
- }
- {
- SSLParameters p = new SSLParameters(supportedCipherSuites, supportedProtocols);
- ssl.setSSLParameters(p);
- assertEquals(Arrays.asList(supportedCipherSuites),
- Arrays.asList(ssl.getEnabledCipherSuites()));
- assertEquals(
- Arrays.asList(supportedProtocols), Arrays.asList(ssl.getEnabledProtocols()));
- }
- {
- SSLParameters p = new SSLParameters();
- p.setNeedClientAuth(true);
- assertFalse(ssl.getNeedClientAuth());
- assertFalse(ssl.getWantClientAuth());
- ssl.setSSLParameters(p);
- assertTrue(ssl.getNeedClientAuth());
- assertFalse(ssl.getWantClientAuth());
- p.setWantClientAuth(true);
- assertTrue(ssl.getNeedClientAuth());
- assertFalse(ssl.getWantClientAuth());
- ssl.setSSLParameters(p);
- assertFalse(ssl.getNeedClientAuth());
- assertTrue(ssl.getWantClientAuth());
- p.setWantClientAuth(false);
- assertFalse(ssl.getNeedClientAuth());
- assertTrue(ssl.getWantClientAuth());
- ssl.setSSLParameters(p);
- assertFalse(ssl.getNeedClientAuth());
- assertFalse(ssl.getWantClientAuth());
+ try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+ String[] defaultCipherSuites = ssl.getEnabledCipherSuites();
+ String[] defaultProtocols = ssl.getEnabledProtocols();
+ String[] supportedCipherSuites = ssl.getSupportedCipherSuites();
+ String[] supportedProtocols = ssl.getSupportedProtocols();
+ {
+ SSLParameters p = new SSLParameters();
+ ssl.setSSLParameters(p);
+ assertEquals(Arrays.asList(defaultCipherSuites),
+ Arrays.asList(ssl.getEnabledCipherSuites()));
+ assertEquals(Arrays.asList(defaultProtocols), Arrays.asList(ssl.getEnabledProtocols()));
+ }
+ {
+ SSLParameters p = new SSLParameters(supportedCipherSuites, supportedProtocols);
+ ssl.setSSLParameters(p);
+ assertEquals(Arrays.asList(supportedCipherSuites),
+ Arrays.asList(ssl.getEnabledCipherSuites()));
+ assertEquals(
+ Arrays.asList(supportedProtocols), Arrays.asList(ssl.getEnabledProtocols()));
+ }
+ {
+ SSLParameters p = new SSLParameters();
+ p.setNeedClientAuth(true);
+ assertFalse(ssl.getNeedClientAuth());
+ assertFalse(ssl.getWantClientAuth());
+ ssl.setSSLParameters(p);
+ assertTrue(ssl.getNeedClientAuth());
+ assertFalse(ssl.getWantClientAuth());
+ p.setWantClientAuth(true);
+ assertTrue(ssl.getNeedClientAuth());
+ assertFalse(ssl.getWantClientAuth());
+ ssl.setSSLParameters(p);
+ assertFalse(ssl.getNeedClientAuth());
+ assertTrue(ssl.getWantClientAuth());
+ p.setWantClientAuth(false);
+ assertFalse(ssl.getNeedClientAuth());
+ assertTrue(ssl.getWantClientAuth());
+ ssl.setSSLParameters(p);
+ assertFalse(ssl.getNeedClientAuth());
+ assertFalse(ssl.getWantClientAuth());
+ }
}
}
@Test
public void test_SSLSocket_setSoTimeout_basic() throws Exception {
- ServerSocket listening = new ServerSocket(0);
- Socket underlying = new Socket(listening.getInetAddress(), listening.getLocalPort());
- assertEquals(0, underlying.getSoTimeout());
- SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
- Socket wrapping = sf.createSocket(underlying, null, -1, false);
- assertEquals(0, wrapping.getSoTimeout());
- // setting wrapper sets underlying and ...
- int expectedTimeoutMillis = 1000; // 10 was too small because it was affected by rounding
- wrapping.setSoTimeout(expectedTimeoutMillis);
- // The kernel can round the requested value based on the HZ setting. We allow up to 10ms.
- assertTrue(Math.abs(expectedTimeoutMillis - wrapping.getSoTimeout()) <= 10);
- assertTrue(Math.abs(expectedTimeoutMillis - underlying.getSoTimeout()) <= 10);
- // ... getting wrapper inspects underlying
- underlying.setSoTimeout(0);
- assertEquals(0, wrapping.getSoTimeout());
- assertEquals(0, underlying.getSoTimeout());
+ try (ServerSocket listening = new ServerSocket(0)) {
+ Socket underlying = new Socket(listening.getInetAddress(), listening.getLocalPort());
+ assertEquals(0, underlying.getSoTimeout());
+ SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
+ Socket wrapping = sf.createSocket(underlying, null, -1, false);
+ assertEquals(0, wrapping.getSoTimeout());
+ // setting wrapper sets underlying and ...
+ int expectedTimeoutMillis = 1000; // 10 was too small because it was affected by rounding
+ wrapping.setSoTimeout(expectedTimeoutMillis);
+ // The kernel can round the requested value based on the HZ setting. We allow up to 10ms.
+ assertTrue(Math.abs(expectedTimeoutMillis - wrapping.getSoTimeout()) <= 10);
+ assertTrue(Math.abs(expectedTimeoutMillis - underlying.getSoTimeout()) <= 10);
+ // ... getting wrapper inspects underlying
+ underlying.setSoTimeout(0);
+ assertEquals(0, wrapping.getSoTimeout());
+ assertEquals(0, underlying.getSoTimeout());
+ }
}
@Test
@@ -838,13 +775,7 @@
SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
Socket clientWrapping = sf.createSocket(underlying, null, -1, false);
underlying.setSoTimeout(1);
- try {
- @SuppressWarnings("unused")
- int value = clientWrapping.getInputStream().read();
- fail();
- } catch (SocketTimeoutException expected) {
- // Ignored.
- }
+ assertThrows(SocketTimeoutException.class, () -> clientWrapping.getInputStream().read());
clientWrapping.close();
server.close();
underlying.close();
@@ -870,90 +801,81 @@
@Test
public void test_SSLSocket_ClientHello_cipherSuites() throws Exception {
- ForEachRunner.runNamed(new ForEachRunner.Callback<SSLSocketFactory>() {
- @Override
- public void run(SSLSocketFactory sslSocketFactory) throws Exception {
- ClientHello clientHello = TlsTester
- .captureTlsHandshakeClientHello(executor, sslSocketFactory);
- final String[] cipherSuites;
- // RFC 5746 allows you to send an empty "renegotiation_info" extension *or*
- // a special signaling cipher suite. The TLS API has no way to check or
- // indicate that a certain TLS extension should be used.
- HelloExtension renegotiationInfoExtension =
- clientHello.findExtensionByType(HelloExtension.TYPE_RENEGOTIATION_INFO);
- if (renegotiationInfoExtension != null
- && renegotiationInfoExtension.data.length == 1
- && renegotiationInfoExtension.data[0] == 0) {
- cipherSuites = new String[clientHello.cipherSuites.size() + 1];
- cipherSuites[clientHello.cipherSuites.size()] =
- StandardNames.CIPHER_SUITE_SECURE_RENEGOTIATION;
- } else {
- cipherSuites = new String[clientHello.cipherSuites.size()];
- }
- for (int i = 0; i < clientHello.cipherSuites.size(); i++) {
- CipherSuite cipherSuite = clientHello.cipherSuites.get(i);
- cipherSuites[i] = cipherSuite.getAndroidName();
- }
- StandardNames.assertDefaultCipherSuites(cipherSuites);
+ ForEachRunner.runNamed(sslSocketFactory -> {
+ ClientHello clientHello = TlsTester
+ .captureTlsHandshakeClientHello(executor, sslSocketFactory);
+ final String[] cipherSuites;
+ // RFC 5746 allows you to send an empty "renegotiation_info" extension *or*
+ // a special signaling cipher suite. The TLS API has no way to check or
+ // indicate that a certain TLS extension should be used.
+ HelloExtension renegotiationInfoExtension =
+ clientHello.findExtensionByType(HelloExtension.TYPE_RENEGOTIATION_INFO);
+ if (renegotiationInfoExtension != null
+ && renegotiationInfoExtension.data.length == 1
+ && renegotiationInfoExtension.data[0] == 0) {
+ cipherSuites = new String[clientHello.cipherSuites.size() + 1];
+ cipherSuites[clientHello.cipherSuites.size()] =
+ StandardNames.CIPHER_SUITE_SECURE_RENEGOTIATION;
+ } else {
+ cipherSuites = new String[clientHello.cipherSuites.size()];
}
- }, getSSLSocketFactoriesToTest());
+ for (int i = 0; i < clientHello.cipherSuites.size(); i++) {
+ CipherSuite cipherSuite = clientHello.cipherSuites.get(i);
+ cipherSuites[i] = cipherSuite.getAndroidName();
+ }
+ StandardNames.assertDefaultCipherSuites(cipherSuites);
+ },
+ getSSLSocketFactoriesToTest());
}
@Test
public void test_SSLSocket_ClientHello_supportedCurves() throws Exception {
- ForEachRunner.runNamed(new ForEachRunner.Callback<SSLSocketFactory>() {
- @Override
- public void run(SSLSocketFactory sslSocketFactory) throws Exception {
- ClientHello clientHello = TlsTester
- .captureTlsHandshakeClientHello(executor, sslSocketFactory);
- EllipticCurvesHelloExtension ecExtension =
- (EllipticCurvesHelloExtension) clientHello.findExtensionByType(
- HelloExtension.TYPE_ELLIPTIC_CURVES);
- final String[] supportedCurves;
- if (ecExtension == null) {
- supportedCurves = new String[0];
- } else {
- assertTrue(ecExtension.wellFormed);
- supportedCurves = new String[ecExtension.supported.size()];
- for (int i = 0; i < ecExtension.supported.size(); i++) {
- EllipticCurve curve = ecExtension.supported.get(i);
- supportedCurves[i] = curve.toString();
- }
+ ForEachRunner.runNamed(sslSocketFactory -> {
+ ClientHello clientHello = TlsTester
+ .captureTlsHandshakeClientHello(executor, sslSocketFactory);
+ EllipticCurvesHelloExtension ecExtension =
+ (EllipticCurvesHelloExtension) clientHello.findExtensionByType(
+ HelloExtension.TYPE_ELLIPTIC_CURVES);
+ final String[] supportedCurves;
+ if (ecExtension == null) {
+ supportedCurves = new String[0];
+ } else {
+ assertTrue(ecExtension.wellFormed);
+ supportedCurves = new String[ecExtension.supported.size()];
+ for (int i = 0; i < ecExtension.supported.size(); i++) {
+ EllipticCurve curve = ecExtension.supported.get(i);
+ supportedCurves[i] = curve.toString();
}
- StandardNames.assertDefaultEllipticCurves(supportedCurves);
}
- }, getSSLSocketFactoriesToTest());
+ StandardNames.assertDefaultEllipticCurves(supportedCurves);
+ },
+ getSSLSocketFactoriesToTest());
}
@Test
public void test_SSLSocket_ClientHello_clientProtocolVersion() throws Exception {
- ForEachRunner.runNamed(new ForEachRunner.Callback<SSLSocketFactory>() {
- @Override
- public void run(SSLSocketFactory sslSocketFactory) throws Exception {
- ClientHello clientHello = TlsTester
- .captureTlsHandshakeClientHello(executor, sslSocketFactory);
- assertEquals(TlsProtocolVersion.TLSv1_2, clientHello.clientVersion);
- }
- }, getSSLSocketFactoriesToTest());
+ ForEachRunner.runNamed(sslSocketFactory -> {
+ ClientHello clientHello = TlsTester
+ .captureTlsHandshakeClientHello(executor, sslSocketFactory);
+ assertEquals(TlsProtocolVersion.TLSv1_2, clientHello.clientVersion);
+ },
+ getSSLSocketFactoriesToTest());
}
@Test
public void test_SSLSocket_ClientHello_compressionMethods() throws Exception {
- ForEachRunner.runNamed(new ForEachRunner.Callback<SSLSocketFactory>() {
- @Override
- public void run(SSLSocketFactory sslSocketFactory) throws Exception {
- ClientHello clientHello = TlsTester
- .captureTlsHandshakeClientHello(executor, sslSocketFactory);
- assertEquals(Collections.singletonList(CompressionMethod.NULL),
- clientHello.compressionMethods);
- }
- }, getSSLSocketFactoriesToTest());
+ ForEachRunner.runNamed(sslSocketFactory -> {
+ ClientHello clientHello = TlsTester
+ .captureTlsHandshakeClientHello(executor, sslSocketFactory);
+ assertEquals(Collections.singletonList(CompressionMethod.NULL),
+ clientHello.compressionMethods);
+ },
+ getSSLSocketFactoriesToTest());
}
private List<Pair<String, SSLSocketFactory>> getSSLSocketFactoriesToTest()
throws NoSuchAlgorithmException, KeyManagementException {
- List<Pair<String, SSLSocketFactory>> result =
- new ArrayList<Pair<String, SSLSocketFactory>>();
+ List<Pair<String, SSLSocketFactory>> result = new ArrayList<>();
result.add(Pair.of("default", (SSLSocketFactory) SSLSocketFactory.getDefault()));
for (String sslContextProtocol : StandardNames.SSL_CONTEXT_PROTOCOLS_WITH_DEFAULT_CONFIG) {
SSLContext sslContext = SSLContext.getInstance(sslContextProtocol);
@@ -977,23 +899,17 @@
final String[] clientCipherSuites = new String[serverCipherSuites.length + 1];
System.arraycopy(serverCipherSuites, 0, clientCipherSuites, 0, serverCipherSuites.length);
clientCipherSuites[serverCipherSuites.length] = StandardNames.CIPHER_SUITE_FALLBACK;
- Future<Void> s = runAsync(new Callable<Void>() {
- @Override
- public Void call() throws Exception {
- server.setEnabledProtocols(new String[]{"TLSv1.2"});
- server.setEnabledCipherSuites(serverCipherSuites);
- server.startHandshake();
- return null;
- }
+ Future<Void> s = runAsync(() -> {
+ server.setEnabledProtocols(new String[]{"TLSv1.2"});
+ server.setEnabledCipherSuites(serverCipherSuites);
+ server.startHandshake();
+ return null;
});
- Future<Void> c = runAsync(new Callable<Void>() {
- @Override
- public Void call() throws Exception {
- client.setEnabledProtocols(new String[]{"TLSv1.2"});
- client.setEnabledCipherSuites(clientCipherSuites);
- client.startHandshake();
- return null;
- }
+ Future<Void> c = runAsync(() -> {
+ client.setEnabledProtocols(new String[]{"TLSv1.2"});
+ client.setEnabledCipherSuites(clientCipherSuites);
+ client.startHandshake();
+ return null;
});
s.get();
c.get();
@@ -1012,21 +928,15 @@
// Confirm absence of TLS_FALLBACK_SCSV.
assertFalse(Arrays.asList(client.getEnabledCipherSuites())
.contains(StandardNames.CIPHER_SUITE_FALLBACK));
- Future<Void> s = runAsync(new Callable<Void>() {
- @Override
- public Void call() throws Exception {
- server.setEnabledProtocols(new String[]{"TLSv1.2", "TLSv1.1"});
- server.startHandshake();
- return null;
- }
+ Future<Void> s = runAsync(() -> {
+ server.setEnabledProtocols(new String[]{"TLSv1.2", "TLSv1.1"});
+ server.startHandshake();
+ return null;
});
- Future<Void> c = runAsync(new Callable<Void>() {
- @Override
- public Void call() throws Exception {
- client.setEnabledProtocols(new String[]{"TLSv1.1"});
- client.startHandshake();
- return null;
- }
+ Future<Void> c = runAsync(() -> {
+ client.setEnabledProtocols(new String[]{"TLSv1.1"});
+ client.startHandshake();
+ return null;
});
s.get();
c.get();
@@ -1053,37 +963,25 @@
final String[] clientCipherSuites = new String[serverCipherSuites.length + 1];
System.arraycopy(serverCipherSuites, 0, clientCipherSuites, 0, serverCipherSuites.length);
clientCipherSuites[serverCipherSuites.length] = StandardNames.CIPHER_SUITE_FALLBACK;
- Future<Void> s = runAsync(new Callable<Void>() {
- @Override
- public Void call() throws Exception {
- server.setEnabledProtocols(new String[] {"TLSv1.2", "TLSv1.1"});
- server.setEnabledCipherSuites(serverCipherSuites);
- try {
- server.startHandshake();
- fail("Should result in inappropriate fallback");
- } catch (SSLHandshakeException expected) {
- Throwable cause = expected.getCause();
- assertEquals(SSLProtocolException.class, cause.getClass());
- assertInappropriateFallbackIsCause(cause);
- }
- return null;
- }
+ Future<Void> s = runAsync(() -> {
+ server.setEnabledProtocols(new String[] {"TLSv1.2", "TLSv1.1"});
+ server.setEnabledCipherSuites(serverCipherSuites);
+ SSLHandshakeException expected =
+ assertThrows(SSLHandshakeException.class, server::startHandshake);
+ Throwable cause = expected.getCause();
+ assertEquals(SSLProtocolException.class, cause.getClass());
+ assertInappropriateFallbackIsCause(cause);
+ return null;
});
- Future<Void> c = runAsync(new Callable<Void>() {
- @Override
- public Void call() throws Exception {
- client.setEnabledProtocols(new String[]{"TLSv1.1"});
- client.setEnabledCipherSuites(clientCipherSuites);
- try {
- client.startHandshake();
- fail("Should receive TLS alert inappropriate fallback");
- } catch (SSLHandshakeException expected) {
- Throwable cause = expected.getCause();
- assertEquals(SSLProtocolException.class, cause.getClass());
- assertInappropriateFallbackIsCause(cause);
- }
- return null;
- }
+ Future<Void> c = runAsync(() -> {
+ client.setEnabledProtocols(new String[]{"TLSv1.1"});
+ client.setEnabledCipherSuites(clientCipherSuites);
+ SSLHandshakeException expected =
+ assertThrows(SSLHandshakeException.class, client::startHandshake);
+ Throwable cause = expected.getCause();
+ assertEquals(SSLProtocolException.class, cause.getClass());
+ assertInappropriateFallbackIsCause(cause);
+ return null;
});
s.get();
c.get();
@@ -1118,6 +1016,74 @@
}
}
+ @Test
+ public void handshakeListenersRunExactlyOnce() {
+ AtomicInteger count = new AtomicInteger(0);
+ TestSSLSocketPair pair = TestSSLSocketPair.create();
+ pair.client.addHandshakeCompletedListener(event -> count.addAndGet(1));
+ pair.client.addHandshakeCompletedListener(event -> count.addAndGet(2));
+ pair.client.addHandshakeCompletedListener(event -> count.addAndGet(4));
+ pair.connect();
+ assertEquals(1 + 2 + 4, count.get());
+ }
+
+ @Test
+ public void closeFromHandshakeListener() throws Exception {
+ TestUtils.assumeEngineSocket();
+
+ TestSSLSocketPair pair = TestSSLSocketPair.create();
+ pair.client.addHandshakeCompletedListener(event -> socketClose(pair.client));
+ Future<Void> serverFuture = runAsync((Callable<Void>) () -> {
+ pair.server.startHandshake();
+ return null;
+ });
+ pair.client.startHandshake();
+ assertThrows(SocketException.class, pair.client::getInputStream);
+ serverFuture.get();
+ InputStream istream = pair.server.getInputStream();
+ assertEquals(-1, istream.read());
+ }
+
+ @Test
+ public void writeFromHandshakeListener() throws Exception {
+ TestUtils.assumeEngineSocket();
+
+ byte[] ping = "ping".getBytes(UTF_8);
+ byte[] pong = "pong".getBytes(UTF_8);
+ TestSSLSocketPair pair = TestSSLSocketPair.create();
+ pair.client.addHandshakeCompletedListener(event -> socketWrite(pair.client, ping));
+ pair.server.addHandshakeCompletedListener(event -> socketWrite(pair.server, pong));
+ Future<Void> serverFuture = runAsync(() -> {
+ pair.server.startHandshake();
+ return null;
+ });
+ byte[] buffer = new byte[4];
+ InputStream clientStream = pair.client.getInputStream();
+ assertEquals(4, clientStream.read(buffer));
+ assertArrayEquals(pong, buffer);
+
+ serverFuture.get();
+ InputStream serverStream = pair.server.getInputStream();
+ assertEquals(4, serverStream.read(buffer));
+ assertArrayEquals(ping, buffer);
+ }
+
+ private void socketClose(Socket socket) {
+ try {
+ socket.close();
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ private void socketWrite(Socket socket, byte[] data) {
+ try {
+ socket.getOutputStream().write(data);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
private <T> Future<T> runAsync(Callable<T> callable) {
return executor.submit(callable);
}
@@ -1134,5 +1100,4 @@
byteCount -= bytesRead;
}
}
-
}
diff --git a/repackaged/common/src/main/java/com/android/org/conscrypt/ConscryptEngineSocket.java b/repackaged/common/src/main/java/com/android/org/conscrypt/ConscryptEngineSocket.java
index 69a341f..c641841 100644
--- a/repackaged/common/src/main/java/com/android/org/conscrypt/ConscryptEngineSocket.java
+++ b/repackaged/common/src/main/java/com/android/org/conscrypt/ConscryptEngineSocket.java
@@ -61,7 +61,7 @@
private SSLOutputStream out;
private SSLInputStream in;
- private long handshakeStartedMillis;
+ private long handshakeStartedMillis = 0;
private BufferAllocator bufferAllocator = ConscryptEngine.getDefaultBufferAllocator();
@@ -124,7 +124,7 @@
@Override
public void onHandshakeFinished() {
// Just call the outer class method.
- socket.onHandshakeFinished();
+ socket.onEngineHandshakeFinished();
}
});
@@ -202,8 +202,7 @@
synchronized (stateLock) {
// Initialize the handshake if we haven't already.
if (state == STATE_NEW) {
- state = STATE_HANDSHAKE_STARTED;
- handshakeStartedMillis = Platform.getMillisSinceBoot();
+ transitionTo(STATE_HANDSHAKE_STARTED);
engine.beginHandshake();
in = new SSLInputStream();
out = new SSLOutputStream();
@@ -216,7 +215,6 @@
return;
}
}
-
doHandshake();
}
} catch (SSLException e) {
@@ -240,6 +238,7 @@
case NEED_UNWRAP:
if (in.processDataFromSocket(EmptyArray.BYTE, 0, 0) < 0) {
// Can't complete the handshake due to EOF.
+ close();
throw SSLUtils.toSSLHandshakeException(
new EOFException("connection closed"));
}
@@ -252,15 +251,13 @@
}
case NEED_TASK: {
// Should never get here, since our engine never provides tasks.
+ close();
throw new IllegalStateException("Engine tasks are unsupported");
}
case NOT_HANDSHAKING:
case FINISHED: {
// Handshake is complete.
finished = true;
- Platform.countTlsHandshake(true, engine.getSession().getProtocol(),
- engine.getSession().getCipherSuite(),
- Platform.getMillisSinceBoot() - handshakeStartedMillis);
break;
}
default: {
@@ -269,11 +266,15 @@
}
}
}
+ if (isState(STATE_HANDSHAKE_COMPLETED)) {
+ // STATE_READY_HANDSHAKE_CUT_THROUGH will wake up any waiting threads which can
+ // race with the listeners, but that's OK.
+ transitionTo(STATE_READY_HANDSHAKE_CUT_THROUGH);
+ notifyHandshakeCompletedListeners();
+ transitionTo(STATE_READY);
+ }
} catch (SSLException e) {
drainOutgoingQueue();
- Platform.countTlsHandshake(false, engine.getSession().getProtocol(),
- engine.getSession().getCipherSuite(),
- Platform.getMillisSinceBoot() - handshakeStartedMillis);
close();
throw e;
} catch (IOException e) {
@@ -286,6 +287,62 @@
}
}
+ private boolean isState(int desiredState) {
+ synchronized (stateLock) {
+ return state == desiredState;
+ }
+ }
+
+ private int transitionTo(int newState) {
+ synchronized (stateLock) {
+ if (state == newState) {
+ return state;
+ }
+
+ int previousState = state;
+ boolean notify = false;
+ switch (newState) {
+ case STATE_HANDSHAKE_STARTED:
+ handshakeStartedMillis = Platform.getMillisSinceBoot();
+ break;
+
+ case STATE_READY_HANDSHAKE_CUT_THROUGH:
+ if (handshakeStartedMillis > 0) {
+ Platform.countTlsHandshake(true, engine.getSession().getProtocol(),
+ engine.getSession().getCipherSuite(),
+ Platform.getMillisSinceBoot() - handshakeStartedMillis);
+ handshakeStartedMillis = 0;
+ }
+ notify = true;
+ break;
+
+ case STATE_READY:
+ notify = true;
+ break;
+
+ case STATE_CLOSED:
+ if (handshakeStartedMillis > 0) {
+ // Handshake must have failed.
+ Platform.countTlsHandshake(false, engine.getSession().getProtocol(),
+ engine.getSession().getCipherSuite(),
+ Platform.getMillisSinceBoot() - handshakeStartedMillis);
+ handshakeStartedMillis = 0;
+ }
+ notify = true;
+ break;
+
+ default:
+ break;
+ }
+
+ state = newState;
+ if (notify) {
+ stateLock.notifyAll();
+ }
+ return previousState;
+ }
+ }
+
@Override
public final InputStream getInputStream() throws IOException {
checkOpen();
@@ -457,24 +514,14 @@
// TODO: Close SSL sockets using a background thread so they close gracefully.
if (stateLock == null) {
- // close() has been called before we've initialized the socket, so just
- // return.
+ // Constructor failed, e.g. superclass constructor called close()
return;
}
- int previousState;
- synchronized (stateLock) {
- previousState = state;
- if (state == STATE_CLOSED) {
- // close() has already been called, so do nothing and return.
- return;
- }
-
- state = STATE_CLOSED;
-
- stateLock.notifyAll();
+ int previousState = transitionTo(STATE_CLOSED);
+ if (previousState == STATE_CLOSED) {
+ return;
}
-
try {
// Close the engine.
engine.closeInbound();
@@ -543,25 +590,12 @@
this.bufferAllocator = bufferAllocator;
}
- private void onHandshakeFinished() {
- boolean notify = false;
- synchronized (stateLock) {
- if (state != STATE_CLOSED) {
- if (state == STATE_HANDSHAKE_STARTED) {
- state = STATE_READY_HANDSHAKE_CUT_THROUGH;
- } else if (state == STATE_HANDSHAKE_COMPLETED) {
- state = STATE_READY;
- }
-
- // Unblock threads that are waiting for our state to transition
- // into STATE_READY or STATE_READY_HANDSHAKE_CUT_THROUGH.
- stateLock.notifyAll();
- notify = true;
- }
- }
-
- if (notify) {
- notifyHandshakeCompletedListeners();
+ private void onEngineHandshakeFinished() {
+ // Don't do anything here except change state. This method will be called from
+ // e.g. wrap() which is non re-entrant so we can't call anything that might do
+ // IO until after it exits, e.g. in doHandshake().
+ if (isState(STATE_HANDSHAKE_STARTED)) {
+ transitionTo(STATE_HANDSHAKE_COMPLETED);
}
}
@@ -572,8 +606,9 @@
startHandshake();
synchronized (stateLock) {
- while (state != STATE_READY && state != STATE_READY_HANDSHAKE_CUT_THROUGH
- && state != STATE_CLOSED) {
+ while (state != STATE_READY
+ // Waiting threads are allowed to compete with handshake listeners for access.
+ && state != STATE_READY_HANDSHAKE_CUT_THROUGH && state != STATE_CLOSED) {
try {
stateLock.wait();
} catch (InterruptedException e) {
@@ -917,7 +952,7 @@
private boolean isHandshakeFinished() {
synchronized (stateLock) {
- return state >= STATE_READY_HANDSHAKE_CUT_THROUGH;
+ return state > STATE_HANDSHAKE_STARTED;
}
}
diff --git a/repackaged/common/src/test/java/com/android/org/conscrypt/javax/net/ssl/SSLSocketTest.java b/repackaged/common/src/test/java/com/android/org/conscrypt/javax/net/ssl/SSLSocketTest.java
index 80c8486..4a3f257 100644
--- a/repackaged/common/src/test/java/com/android/org/conscrypt/javax/net/ssl/SSLSocketTest.java
+++ b/repackaged/common/src/test/java/com/android/org/conscrypt/javax/net/ssl/SSLSocketTest.java
@@ -17,14 +17,15 @@
package com.android.org.conscrypt.javax.net.ssl;
-import static com.android.org.conscrypt.TestUtils.UTF_8;
+import static java.nio.charset.StandardCharsets.UTF_8;
+import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
import com.android.org.conscrypt.TestUtils;
import com.android.org.conscrypt.java.security.StandardNames;
@@ -42,6 +43,7 @@
import java.io.InputStream;
import java.net.ServerSocket;
import java.net.Socket;
+import java.net.SocketException;
import java.net.SocketTimeoutException;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
@@ -56,7 +58,6 @@
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
-import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import javax.crypto.SecretKey;
@@ -72,7 +73,6 @@
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.X509ExtendedTrustManager;
import org.junit.After;
-import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -85,24 +85,14 @@
*/
@RunWith(JUnit4.class)
public class SSLSocketTest {
- private ExecutorService executor;
- private ThreadGroup threadGroup;
-
- @Before
- public void setup() {
- threadGroup = new ThreadGroup("SSLSocketTest");
- executor = Executors.newCachedThreadPool(new ThreadFactory() {
- @Override
- public Thread newThread(Runnable r) {
- return new Thread(threadGroup, r);
- }
- });
- }
+ private final ThreadGroup threadGroup = new ThreadGroup("SSLSocketTest");
+ private final ExecutorService executor =
+ Executors.newCachedThreadPool(t -> new Thread(threadGroup, t));
@After
public void teardown() throws InterruptedException {
executor.shutdownNow();
- executor.awaitTermination(5, TimeUnit.SECONDS);
+ assertTrue(executor.awaitTermination(5, TimeUnit.SECONDS));
}
@Test
@@ -114,8 +104,9 @@
@Test
public void test_SSLSocket_getSupportedCipherSuites_returnsCopies() throws Exception {
SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
- SSLSocket ssl = (SSLSocket) sf.createSocket();
- assertNotSame(ssl.getSupportedCipherSuites(), ssl.getSupportedCipherSuites());
+ try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+ assertNotSame(ssl.getSupportedCipherSuites(), ssl.getSupportedCipherSuites());
+ }
}
@Test
@@ -135,7 +126,7 @@
}
private void test_SSLSocket_getSupportedCipherSuites_connect(
- TestKeyStore testKeyStore, StringBuilder error) throws Exception {
+ TestKeyStore testKeyStore, StringBuilder error) {
String clientToServerString = "this is sent from the client to the server...";
String serverToClientString = "... and this from the server to the client";
byte[] clientToServer = clientToServerString.getBytes(UTF_8);
@@ -211,21 +202,9 @@
// Check that the server and the client cannot read anything else
// (reads should time out)
server.setSoTimeout(10);
- try {
- @SuppressWarnings("unused")
- int value = server.getInputStream().read();
- fail();
- } catch (IOException expected) {
- // Ignored.
- }
+ assertThrows(IOException.class, () -> server.getInputStream().read());
client.setSoTimeout(10);
- try {
- @SuppressWarnings("unused")
- int value = client.getInputStream().read();
- fail();
- } catch (IOException expected) {
- // Ignored.
- }
+ assertThrows(IOException.class, () -> client.getInputStream().read());
client.close();
server.close();
} catch (Exception maybeExpected) {
@@ -277,53 +256,42 @@
@Test
public void test_SSLSocket_getEnabledCipherSuites_returnsCopies() throws Exception {
SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
- SSLSocket ssl = (SSLSocket) sf.createSocket();
- assertNotSame(ssl.getEnabledCipherSuites(), ssl.getEnabledCipherSuites());
+ try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+ assertNotSame(ssl.getEnabledCipherSuites(), ssl.getEnabledCipherSuites());
+ }
}
@Test
public void test_SSLSocket_setEnabledCipherSuites_storesCopy() throws Exception {
SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
- SSLSocket ssl = (SSLSocket) sf.createSocket();
- String[] array = new String[] {ssl.getEnabledCipherSuites()[0]};
- String originalFirstElement = array[0];
- ssl.setEnabledCipherSuites(array);
- array[0] = "Modified after having been set";
- assertEquals(originalFirstElement, ssl.getEnabledCipherSuites()[0]);
+ try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+ String[] array = new String[] {ssl.getEnabledCipherSuites()[0]};
+ String originalFirstElement = array[0];
+ ssl.setEnabledCipherSuites(array);
+ array[0] = "Modified after having been set";
+ assertEquals(originalFirstElement, ssl.getEnabledCipherSuites()[0]);
+ }
}
@Test
public void test_SSLSocket_setEnabledCipherSuites_TLS12() throws Exception {
SSLContext context = SSLContext.getInstance("TLSv1.2");
context.init(null, null, null);
- SSLSocket ssl = (SSLSocket) context.getSocketFactory().createSocket();
- try {
- ssl.setEnabledCipherSuites(null);
- fail();
- } catch (IllegalArgumentException expected) {
- // Ignored.
+ try (SSLSocket ssl = (SSLSocket) context.getSocketFactory().createSocket()) {
+ assertThrows(IllegalArgumentException.class, () -> ssl.setEnabledCipherSuites(null));
+ assertThrows(IllegalArgumentException.class,
+ () -> ssl.setEnabledCipherSuites(new String[1]));
+ assertThrows(IllegalArgumentException.class,
+ () -> ssl.setEnabledCipherSuites(new String[] {"Bogus"}));
+ ssl.setEnabledCipherSuites(new String[0]);
+ ssl.setEnabledCipherSuites(ssl.getEnabledCipherSuites());
+ ssl.setEnabledCipherSuites(ssl.getSupportedCipherSuites());
+ // Check that setEnabledCipherSuites affects getEnabledCipherSuites
+ String[] cipherSuites = new String[] {
+ TestUtils.pickArbitraryNonTls13Suite(ssl.getSupportedCipherSuites())};
+ ssl.setEnabledCipherSuites(cipherSuites);
+ assertEquals(Arrays.asList(cipherSuites), Arrays.asList(ssl.getEnabledCipherSuites()));
}
- try {
- ssl.setEnabledCipherSuites(new String[1]);
- fail();
- } catch (IllegalArgumentException expected) {
- // Ignored.
- }
- try {
- ssl.setEnabledCipherSuites(new String[] {"Bogus"});
- fail();
- } catch (IllegalArgumentException expected) {
- // Ignored.
- }
- ssl.setEnabledCipherSuites(new String[0]);
- ssl.setEnabledCipherSuites(ssl.getEnabledCipherSuites());
- ssl.setEnabledCipherSuites(ssl.getSupportedCipherSuites());
- // Check that setEnabledCipherSuites affects getEnabledCipherSuites
- String[] cipherSuites = new String[] {
- TestUtils.pickArbitraryNonTls13Suite(ssl.getSupportedCipherSuites())
- };
- ssl.setEnabledCipherSuites(cipherSuites);
- assertEquals(Arrays.asList(cipherSuites), Arrays.asList(ssl.getEnabledCipherSuites()));
}
@Test
@@ -331,91 +299,79 @@
SSLContext context = SSLContext.getInstance("TLSv1.3");
context.init(null, null, null);
SSLSocketFactory sf = context.getSocketFactory();
- SSLSocket ssl = (SSLSocket) sf.createSocket();
- // The TLS 1.3 cipher suites should be enabled by default
- assertTrue(new HashSet<String>(Arrays.asList(ssl.getEnabledCipherSuites()))
- .containsAll(StandardNames.CIPHER_SUITES_TLS13));
- // Disabling them should be ignored
- ssl.setEnabledCipherSuites(new String[0]);
- assertTrue(new HashSet<String>(Arrays.asList(ssl.getEnabledCipherSuites()))
- .containsAll(StandardNames.CIPHER_SUITES_TLS13));
+ try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+ // The TLS 1.3 cipher suites should be enabled by default
+ assertTrue(new HashSet<>(Arrays.asList(ssl.getEnabledCipherSuites()))
+ .containsAll(StandardNames.CIPHER_SUITES_TLS13));
+ // Disabling them should be ignored
+ ssl.setEnabledCipherSuites(new String[0]);
+ assertTrue(new HashSet<>(Arrays.asList(ssl.getEnabledCipherSuites()))
+ .containsAll(StandardNames.CIPHER_SUITES_TLS13));
- ssl.setEnabledCipherSuites(new String[] {
- TestUtils.pickArbitraryNonTls13Suite(ssl.getSupportedCipherSuites())
- });
- assertTrue(new HashSet<String>(Arrays.asList(ssl.getEnabledCipherSuites()))
- .containsAll(StandardNames.CIPHER_SUITES_TLS13));
+ ssl.setEnabledCipherSuites(new String[] {
+ TestUtils.pickArbitraryNonTls13Suite(ssl.getSupportedCipherSuites())});
+ assertTrue(new HashSet<>(Arrays.asList(ssl.getEnabledCipherSuites()))
+ .containsAll(StandardNames.CIPHER_SUITES_TLS13));
- // Disabling TLS 1.3 should disable 1.3 cipher suites
- ssl.setEnabledProtocols(new String[] { "TLSv1.2" });
- assertFalse(new HashSet<String>(Arrays.asList(ssl.getEnabledCipherSuites()))
- .containsAll(StandardNames.CIPHER_SUITES_TLS13));
+ // Disabling TLS 1.3 should disable 1.3 cipher suites
+ ssl.setEnabledProtocols(new String[] {"TLSv1.2"});
+ assertFalse(new HashSet<>(Arrays.asList(ssl.getEnabledCipherSuites()))
+ .containsAll(StandardNames.CIPHER_SUITES_TLS13));
+ }
}
@Test
public void test_SSLSocket_getSupportedProtocols_returnsCopies() throws Exception {
SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
- SSLSocket ssl = (SSLSocket) sf.createSocket();
- assertNotSame(ssl.getSupportedProtocols(), ssl.getSupportedProtocols());
+ try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+ assertNotSame(ssl.getSupportedProtocols(), ssl.getSupportedProtocols());
+ }
}
@Test
public void test_SSLSocket_getEnabledProtocols_returnsCopies() throws Exception {
SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
- SSLSocket ssl = (SSLSocket) sf.createSocket();
- assertNotSame(ssl.getEnabledProtocols(), ssl.getEnabledProtocols());
+ try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+ assertNotSame(ssl.getEnabledProtocols(), ssl.getEnabledProtocols());
+ }
}
@Test
public void test_SSLSocket_setEnabledProtocols_storesCopy() throws Exception {
SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
- SSLSocket ssl = (SSLSocket) sf.createSocket();
- String[] array = new String[] {ssl.getEnabledProtocols()[0]};
- String originalFirstElement = array[0];
- ssl.setEnabledProtocols(array);
- array[0] = "Modified after having been set";
- assertEquals(originalFirstElement, ssl.getEnabledProtocols()[0]);
+ try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+ String[] array = new String[] {ssl.getEnabledProtocols()[0]};
+ String originalFirstElement = array[0];
+ ssl.setEnabledProtocols(array);
+ array[0] = "Modified after having been set";
+ assertEquals(originalFirstElement, ssl.getEnabledProtocols()[0]);
+ }
}
@Test
public void test_SSLSocket_setEnabledProtocols() throws Exception {
SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
- SSLSocket ssl = (SSLSocket) sf.createSocket();
- try {
- ssl.setEnabledProtocols(null);
- fail();
- } catch (IllegalArgumentException expected) {
- // Ignored.
- }
- try {
- ssl.setEnabledProtocols(new String[1]);
- fail();
- } catch (IllegalArgumentException expected) {
- // Ignored.
- }
- try {
- ssl.setEnabledProtocols(new String[] {"Bogus"});
- fail();
- } catch (IllegalArgumentException expected) {
- // Ignored.
- }
- ssl.setEnabledProtocols(new String[0]);
- ssl.setEnabledProtocols(ssl.getEnabledProtocols());
- ssl.setEnabledProtocols(ssl.getSupportedProtocols());
- // Check that setEnabledProtocols affects getEnabledProtocols
- for (String protocol : ssl.getSupportedProtocols()) {
- if ("SSLv2Hello".equals(protocol)) {
- try {
- ssl.setEnabledProtocols(new String[] {protocol});
- fail("Should fail when SSLv2Hello is set by itself");
- } catch (IllegalArgumentException expected) {
- // Ignored.
+ try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+ assertThrows(IllegalArgumentException.class, () -> ssl.setEnabledProtocols(null));
+ assertThrows(
+ IllegalArgumentException.class, () -> ssl.setEnabledProtocols(new String[1]));
+ assertThrows(IllegalArgumentException.class,
+ () -> ssl.setEnabledProtocols(new String[] {"Bogus"}));
+ ssl.setEnabledProtocols(new String[0]);
+ ssl.setEnabledProtocols(ssl.getEnabledProtocols());
+ ssl.setEnabledProtocols(ssl.getSupportedProtocols());
+ // Check that setEnabledProtocols affects getEnabledProtocols
+ for (String protocol : ssl.getSupportedProtocols()) {
+ if ("SSLv2Hello".equals(protocol)) {
+ // Should fail when SSLv2Hello is set by itself
+ assertThrows(IllegalArgumentException.class,
+ () -> ssl.setEnabledProtocols(new String[] {protocol}));
+ } else {
+ String[] protocols = new String[] {protocol};
+ ssl.setEnabledProtocols(protocols);
+ assertEquals(Arrays.deepToString(protocols),
+ Arrays.deepToString(ssl.getEnabledProtocols()));
}
- } else {
- String[] protocols = new String[] {protocol};
- ssl.setEnabledProtocols(protocols);
- assertEquals(Arrays.deepToString(protocols),
- Arrays.deepToString(ssl.getEnabledProtocols()));
}
}
}
@@ -434,11 +390,9 @@
final SSLSocket server = (SSLSocket) c.serverSocket.accept();
server.setEnabledProtocols(new String[] {"TLSv1.3", "TLSv1.2", "TLSv1.1"});
ExecutorService executor = Executors.newSingleThreadExecutor();
- Future<Void> future = executor.submit(new Callable<Void>() {
- @Override public Void call() throws Exception {
- server.startHandshake();
- return null;
- }
+ Future<Void> future = executor.submit(() -> {
+ server.startHandshake();
+ return null;
});
executor.shutdown();
client.startHandshake();
@@ -465,11 +419,9 @@
final SSLSocket server = (SSLSocket) c.serverSocket.accept();
server.setEnabledProtocols(new String[] {"TLSv1.2", "TLSv1.1"});
ExecutorService executor = Executors.newSingleThreadExecutor();
- Future<Void> future = executor.submit(new Callable<Void>() {
- @Override public Void call() throws Exception {
- server.startHandshake();
- return null;
- }
+ Future<Void> future = executor.submit(() -> {
+ server.startHandshake();
+ return null;
});
executor.shutdown();
client.startHandshake();
@@ -485,18 +437,20 @@
@Test
public void test_SSLSocket_getSession() throws Exception {
SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
- SSLSocket ssl = (SSLSocket) sf.createSocket();
- SSLSession session = ssl.getSession();
- assertNotNull(session);
- assertFalse(session.isValid());
+ try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+ SSLSession session = ssl.getSession();
+ assertNotNull(session);
+ assertFalse(session.isValid());
+ }
}
@Test
public void test_SSLSocket_getHandshakeSession_unconnected() throws Exception {
SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
- SSLSocket socket = (SSLSocket) sf.createSocket();
- SSLSession session = socket.getHandshakeSession();
- assertNull(session);
+ try (SSLSocket socket = (SSLSocket) sf.createSocket()) {
+ SSLSession session = socket.getHandshakeSession();
+ assertNull(session);
+ }
}
@Test
@@ -574,11 +528,9 @@
clientContext.getSocketFactory().createSocket(c.host, c.port);
final SSLSocket server = (SSLSocket) c.serverSocket.accept();
ExecutorService executor = Executors.newSingleThreadExecutor();
- Future<Void> future = executor.submit(new Callable<Void>() {
- @Override public Void call() throws Exception {
- server.startHandshake();
- return null;
- }
+ Future<Void> future = executor.submit(() -> {
+ server.startHandshake();
+ return null;
});
executor.shutdown();
client.startHandshake();
@@ -677,12 +629,10 @@
clientContext.getSocketFactory().createSocket(c.host, c.port);
final SSLSocket server = (SSLSocket) c.serverSocket.accept();
ExecutorService executor = Executors.newSingleThreadExecutor();
- Future<Void> future = executor.submit(new Callable<Void>() {
- @Override public Void call() throws Exception {
- server.setNeedClientAuth(true);
- server.startHandshake();
- return null;
- }
+ Future<Void> future = executor.submit(() -> {
+ server.setNeedClientAuth(true);
+ server.startHandshake();
+ return null;
});
executor.shutdown();
client.startHandshake();
@@ -695,21 +645,11 @@
}
@Test
- public void test_SSLSocket_setUseClientMode_afterHandshake() throws Exception {
+ public void test_SSLSocket_setUseClientMode_afterHandshake() {
// can't set after handshake
TestSSLSocketPair pair = TestSSLSocketPair.create().connect();
- try {
- pair.server.setUseClientMode(false);
- fail();
- } catch (IllegalArgumentException expected) {
- // Ignored.
- }
- try {
- pair.client.setUseClientMode(false);
- fail();
- } catch (IllegalArgumentException expected) {
- // Ignored.
- }
+ assertThrows(IllegalArgumentException.class, () -> pair.server.setUseClientMode(true));
+ assertThrows(IllegalArgumentException.class, () -> pair.client.setUseClientMode(false));
}
@Test
@@ -719,24 +659,14 @@
SSLSocket client =
(SSLSocket) c.clientContext.getSocketFactory().createSocket(c.host, c.port);
final SSLSocket server = (SSLSocket) c.serverSocket.accept();
- Future<Void> future = runAsync(new Callable<Void>() {
- @Override
- public Void call() throws Exception {
- try {
- server.startHandshake();
- fail();
- } catch (SSLHandshakeException expected) {
- // Ignored.
- }
- return null;
- }
+ Future<Void> future = runAsync(() -> {
+ assertThrows(SSLHandshakeException.class, server::startHandshake);
+ return null;
});
- try {
- client.startHandshake();
- fail();
- } catch (SSLHandshakeException expected) {
- assertTrue(expected.getCause() instanceof CertificateException);
- }
+ SSLHandshakeException expected =
+ assertThrows(SSLHandshakeException.class, client::startHandshake);
+ assertTrue(expected.getCause() instanceof CertificateException);
+
future.get();
client.close();
server.close();
@@ -747,90 +677,96 @@
public void test_SSLSocket_getSSLParameters() throws Exception {
TestUtils.assumeSetEndpointIdentificationAlgorithmAvailable();
SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
- SSLSocket ssl = (SSLSocket) sf.createSocket();
- SSLParameters p = ssl.getSSLParameters();
- assertNotNull(p);
- String[] cipherSuites = p.getCipherSuites();
- assertNotSame(cipherSuites, ssl.getEnabledCipherSuites());
- assertEquals(Arrays.asList(cipherSuites), Arrays.asList(ssl.getEnabledCipherSuites()));
- String[] protocols = p.getProtocols();
- assertNotSame(protocols, ssl.getEnabledProtocols());
- assertEquals(Arrays.asList(protocols), Arrays.asList(ssl.getEnabledProtocols()));
- assertEquals(p.getWantClientAuth(), ssl.getWantClientAuth());
- assertEquals(p.getNeedClientAuth(), ssl.getNeedClientAuth());
- assertNull(p.getEndpointIdentificationAlgorithm());
- p.setEndpointIdentificationAlgorithm(null);
- assertNull(p.getEndpointIdentificationAlgorithm());
- p.setEndpointIdentificationAlgorithm("HTTPS");
- assertEquals("HTTPS", p.getEndpointIdentificationAlgorithm());
- p.setEndpointIdentificationAlgorithm("FOO");
- assertEquals("FOO", p.getEndpointIdentificationAlgorithm());
+ try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+ SSLParameters p = ssl.getSSLParameters();
+ assertNotNull(p);
+ String[] cipherSuites = p.getCipherSuites();
+ assertNotSame(cipherSuites, ssl.getEnabledCipherSuites());
+ assertEquals(Arrays.asList(cipherSuites), Arrays.asList(ssl.getEnabledCipherSuites()));
+ String[] protocols = p.getProtocols();
+ assertNotSame(protocols, ssl.getEnabledProtocols());
+ assertEquals(Arrays.asList(protocols), Arrays.asList(ssl.getEnabledProtocols()));
+ assertEquals(p.getWantClientAuth(), ssl.getWantClientAuth());
+ assertEquals(p.getNeedClientAuth(), ssl.getNeedClientAuth());
+ assertNull(p.getEndpointIdentificationAlgorithm());
+ p.setEndpointIdentificationAlgorithm(null);
+ assertNull(p.getEndpointIdentificationAlgorithm());
+ p.setEndpointIdentificationAlgorithm("HTTPS");
+ assertEquals("HTTPS", p.getEndpointIdentificationAlgorithm());
+ p.setEndpointIdentificationAlgorithm("FOO");
+ assertEquals("FOO", p.getEndpointIdentificationAlgorithm());
+ }
}
@Test
public void test_SSLSocket_setSSLParameters() throws Exception {
SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
- SSLSocket ssl = (SSLSocket) sf.createSocket();
- String[] defaultCipherSuites = ssl.getEnabledCipherSuites();
- String[] defaultProtocols = ssl.getEnabledProtocols();
- String[] supportedCipherSuites = ssl.getSupportedCipherSuites();
- String[] supportedProtocols = ssl.getSupportedProtocols();
- {
- SSLParameters p = new SSLParameters();
- ssl.setSSLParameters(p);
- assertEquals(Arrays.asList(defaultCipherSuites),
- Arrays.asList(ssl.getEnabledCipherSuites()));
- assertEquals(Arrays.asList(defaultProtocols), Arrays.asList(ssl.getEnabledProtocols()));
- }
- {
- SSLParameters p = new SSLParameters(supportedCipherSuites, supportedProtocols);
- ssl.setSSLParameters(p);
- assertEquals(Arrays.asList(supportedCipherSuites),
- Arrays.asList(ssl.getEnabledCipherSuites()));
- assertEquals(
- Arrays.asList(supportedProtocols), Arrays.asList(ssl.getEnabledProtocols()));
- }
- {
- SSLParameters p = new SSLParameters();
- p.setNeedClientAuth(true);
- assertFalse(ssl.getNeedClientAuth());
- assertFalse(ssl.getWantClientAuth());
- ssl.setSSLParameters(p);
- assertTrue(ssl.getNeedClientAuth());
- assertFalse(ssl.getWantClientAuth());
- p.setWantClientAuth(true);
- assertTrue(ssl.getNeedClientAuth());
- assertFalse(ssl.getWantClientAuth());
- ssl.setSSLParameters(p);
- assertFalse(ssl.getNeedClientAuth());
- assertTrue(ssl.getWantClientAuth());
- p.setWantClientAuth(false);
- assertFalse(ssl.getNeedClientAuth());
- assertTrue(ssl.getWantClientAuth());
- ssl.setSSLParameters(p);
- assertFalse(ssl.getNeedClientAuth());
- assertFalse(ssl.getWantClientAuth());
+ try (SSLSocket ssl = (SSLSocket) sf.createSocket()) {
+ String[] defaultCipherSuites = ssl.getEnabledCipherSuites();
+ String[] defaultProtocols = ssl.getEnabledProtocols();
+ String[] supportedCipherSuites = ssl.getSupportedCipherSuites();
+ String[] supportedProtocols = ssl.getSupportedProtocols();
+ {
+ SSLParameters p = new SSLParameters();
+ ssl.setSSLParameters(p);
+ assertEquals(Arrays.asList(defaultCipherSuites),
+ Arrays.asList(ssl.getEnabledCipherSuites()));
+ assertEquals(
+ Arrays.asList(defaultProtocols), Arrays.asList(ssl.getEnabledProtocols()));
+ }
+ {
+ SSLParameters p = new SSLParameters(supportedCipherSuites, supportedProtocols);
+ ssl.setSSLParameters(p);
+ assertEquals(Arrays.asList(supportedCipherSuites),
+ Arrays.asList(ssl.getEnabledCipherSuites()));
+ assertEquals(Arrays.asList(supportedProtocols),
+ Arrays.asList(ssl.getEnabledProtocols()));
+ }
+ {
+ SSLParameters p = new SSLParameters();
+ p.setNeedClientAuth(true);
+ assertFalse(ssl.getNeedClientAuth());
+ assertFalse(ssl.getWantClientAuth());
+ ssl.setSSLParameters(p);
+ assertTrue(ssl.getNeedClientAuth());
+ assertFalse(ssl.getWantClientAuth());
+ p.setWantClientAuth(true);
+ assertTrue(ssl.getNeedClientAuth());
+ assertFalse(ssl.getWantClientAuth());
+ ssl.setSSLParameters(p);
+ assertFalse(ssl.getNeedClientAuth());
+ assertTrue(ssl.getWantClientAuth());
+ p.setWantClientAuth(false);
+ assertFalse(ssl.getNeedClientAuth());
+ assertTrue(ssl.getWantClientAuth());
+ ssl.setSSLParameters(p);
+ assertFalse(ssl.getNeedClientAuth());
+ assertFalse(ssl.getWantClientAuth());
+ }
}
}
@Test
public void test_SSLSocket_setSoTimeout_basic() throws Exception {
- ServerSocket listening = new ServerSocket(0);
- Socket underlying = new Socket(listening.getInetAddress(), listening.getLocalPort());
- assertEquals(0, underlying.getSoTimeout());
- SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
- Socket wrapping = sf.createSocket(underlying, null, -1, false);
- assertEquals(0, wrapping.getSoTimeout());
- // setting wrapper sets underlying and ...
- int expectedTimeoutMillis = 1000; // 10 was too small because it was affected by rounding
- wrapping.setSoTimeout(expectedTimeoutMillis);
- // The kernel can round the requested value based on the HZ setting. We allow up to 10ms.
- assertTrue(Math.abs(expectedTimeoutMillis - wrapping.getSoTimeout()) <= 10);
- assertTrue(Math.abs(expectedTimeoutMillis - underlying.getSoTimeout()) <= 10);
- // ... getting wrapper inspects underlying
- underlying.setSoTimeout(0);
- assertEquals(0, wrapping.getSoTimeout());
- assertEquals(0, underlying.getSoTimeout());
+ try (ServerSocket listening = new ServerSocket(0)) {
+ Socket underlying = new Socket(listening.getInetAddress(), listening.getLocalPort());
+ assertEquals(0, underlying.getSoTimeout());
+ SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
+ Socket wrapping = sf.createSocket(underlying, null, -1, false);
+ assertEquals(0, wrapping.getSoTimeout());
+ // setting wrapper sets underlying and ...
+ int expectedTimeoutMillis =
+ 1000; // 10 was too small because it was affected by rounding
+ wrapping.setSoTimeout(expectedTimeoutMillis);
+ // The kernel can round the requested value based on the HZ setting. We allow up to
+ // 10ms.
+ assertTrue(Math.abs(expectedTimeoutMillis - wrapping.getSoTimeout()) <= 10);
+ assertTrue(Math.abs(expectedTimeoutMillis - underlying.getSoTimeout()) <= 10);
+ // ... getting wrapper inspects underlying
+ underlying.setSoTimeout(0);
+ assertEquals(0, wrapping.getSoTimeout());
+ assertEquals(0, underlying.getSoTimeout());
+ }
}
@Test
@@ -842,13 +778,7 @@
SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
Socket clientWrapping = sf.createSocket(underlying, null, -1, false);
underlying.setSoTimeout(1);
- try {
- @SuppressWarnings("unused")
- int value = clientWrapping.getInputStream().read();
- fail();
- } catch (SocketTimeoutException expected) {
- // Ignored.
- }
+ assertThrows(SocketTimeoutException.class, () -> clientWrapping.getInputStream().read());
clientWrapping.close();
server.close();
underlying.close();
@@ -874,90 +804,76 @@
@Test
public void test_SSLSocket_ClientHello_cipherSuites() throws Exception {
- ForEachRunner.runNamed(new ForEachRunner.Callback<SSLSocketFactory>() {
- @Override
- public void run(SSLSocketFactory sslSocketFactory) throws Exception {
- ClientHello clientHello = TlsTester
- .captureTlsHandshakeClientHello(executor, sslSocketFactory);
- final String[] cipherSuites;
- // RFC 5746 allows you to send an empty "renegotiation_info" extension *or*
- // a special signaling cipher suite. The TLS API has no way to check or
- // indicate that a certain TLS extension should be used.
- HelloExtension renegotiationInfoExtension =
+ ForEachRunner.runNamed(sslSocketFactory -> {
+ ClientHello clientHello =
+ TlsTester.captureTlsHandshakeClientHello(executor, sslSocketFactory);
+ final String[] cipherSuites;
+ // RFC 5746 allows you to send an empty "renegotiation_info" extension *or*
+ // a special signaling cipher suite. The TLS API has no way to check or
+ // indicate that a certain TLS extension should be used.
+ HelloExtension renegotiationInfoExtension =
clientHello.findExtensionByType(HelloExtension.TYPE_RENEGOTIATION_INFO);
- if (renegotiationInfoExtension != null
- && renegotiationInfoExtension.data.length == 1
+ if (renegotiationInfoExtension != null && renegotiationInfoExtension.data.length == 1
&& renegotiationInfoExtension.data[0] == 0) {
- cipherSuites = new String[clientHello.cipherSuites.size() + 1];
- cipherSuites[clientHello.cipherSuites.size()] =
+ cipherSuites = new String[clientHello.cipherSuites.size() + 1];
+ cipherSuites[clientHello.cipherSuites.size()] =
StandardNames.CIPHER_SUITE_SECURE_RENEGOTIATION;
- } else {
- cipherSuites = new String[clientHello.cipherSuites.size()];
- }
- for (int i = 0; i < clientHello.cipherSuites.size(); i++) {
- CipherSuite cipherSuite = clientHello.cipherSuites.get(i);
- cipherSuites[i] = cipherSuite.getAndroidName();
- }
- StandardNames.assertDefaultCipherSuites(cipherSuites);
+ } else {
+ cipherSuites = new String[clientHello.cipherSuites.size()];
}
+ for (int i = 0; i < clientHello.cipherSuites.size(); i++) {
+ CipherSuite cipherSuite = clientHello.cipherSuites.get(i);
+ cipherSuites[i] = cipherSuite.getAndroidName();
+ }
+ StandardNames.assertDefaultCipherSuites(cipherSuites);
}, getSSLSocketFactoriesToTest());
}
@Test
public void test_SSLSocket_ClientHello_supportedCurves() throws Exception {
- ForEachRunner.runNamed(new ForEachRunner.Callback<SSLSocketFactory>() {
- @Override
- public void run(SSLSocketFactory sslSocketFactory) throws Exception {
- ClientHello clientHello = TlsTester
- .captureTlsHandshakeClientHello(executor, sslSocketFactory);
- EllipticCurvesHelloExtension ecExtension =
+ ForEachRunner.runNamed(sslSocketFactory -> {
+ ClientHello clientHello =
+ TlsTester.captureTlsHandshakeClientHello(executor, sslSocketFactory);
+ EllipticCurvesHelloExtension ecExtension =
(EllipticCurvesHelloExtension) clientHello.findExtensionByType(
- HelloExtension.TYPE_ELLIPTIC_CURVES);
- final String[] supportedCurves;
- if (ecExtension == null) {
- supportedCurves = new String[0];
- } else {
- assertTrue(ecExtension.wellFormed);
- supportedCurves = new String[ecExtension.supported.size()];
- for (int i = 0; i < ecExtension.supported.size(); i++) {
- EllipticCurve curve = ecExtension.supported.get(i);
- supportedCurves[i] = curve.toString();
- }
+ HelloExtension.TYPE_ELLIPTIC_CURVES);
+ final String[] supportedCurves;
+ if (ecExtension == null) {
+ supportedCurves = new String[0];
+ } else {
+ assertTrue(ecExtension.wellFormed);
+ supportedCurves = new String[ecExtension.supported.size()];
+ for (int i = 0; i < ecExtension.supported.size(); i++) {
+ EllipticCurve curve = ecExtension.supported.get(i);
+ supportedCurves[i] = curve.toString();
}
- StandardNames.assertDefaultEllipticCurves(supportedCurves);
}
+ StandardNames.assertDefaultEllipticCurves(supportedCurves);
}, getSSLSocketFactoriesToTest());
}
@Test
public void test_SSLSocket_ClientHello_clientProtocolVersion() throws Exception {
- ForEachRunner.runNamed(new ForEachRunner.Callback<SSLSocketFactory>() {
- @Override
- public void run(SSLSocketFactory sslSocketFactory) throws Exception {
- ClientHello clientHello = TlsTester
- .captureTlsHandshakeClientHello(executor, sslSocketFactory);
- assertEquals(TlsProtocolVersion.TLSv1_2, clientHello.clientVersion);
- }
+ ForEachRunner.runNamed(sslSocketFactory -> {
+ ClientHello clientHello =
+ TlsTester.captureTlsHandshakeClientHello(executor, sslSocketFactory);
+ assertEquals(TlsProtocolVersion.TLSv1_2, clientHello.clientVersion);
}, getSSLSocketFactoriesToTest());
}
@Test
public void test_SSLSocket_ClientHello_compressionMethods() throws Exception {
- ForEachRunner.runNamed(new ForEachRunner.Callback<SSLSocketFactory>() {
- @Override
- public void run(SSLSocketFactory sslSocketFactory) throws Exception {
- ClientHello clientHello = TlsTester
- .captureTlsHandshakeClientHello(executor, sslSocketFactory);
- assertEquals(Collections.singletonList(CompressionMethod.NULL),
+ ForEachRunner.runNamed(sslSocketFactory -> {
+ ClientHello clientHello =
+ TlsTester.captureTlsHandshakeClientHello(executor, sslSocketFactory);
+ assertEquals(Collections.singletonList(CompressionMethod.NULL),
clientHello.compressionMethods);
- }
}, getSSLSocketFactoriesToTest());
}
private List<Pair<String, SSLSocketFactory>> getSSLSocketFactoriesToTest()
throws NoSuchAlgorithmException, KeyManagementException {
- List<Pair<String, SSLSocketFactory>> result =
- new ArrayList<Pair<String, SSLSocketFactory>>();
+ List<Pair<String, SSLSocketFactory>> result = new ArrayList<>();
result.add(Pair.of("default", (SSLSocketFactory) SSLSocketFactory.getDefault()));
for (String sslContextProtocol : StandardNames.SSL_CONTEXT_PROTOCOLS_WITH_DEFAULT_CONFIG) {
SSLContext sslContext = SSLContext.getInstance(sslContextProtocol);
@@ -981,23 +897,17 @@
final String[] clientCipherSuites = new String[serverCipherSuites.length + 1];
System.arraycopy(serverCipherSuites, 0, clientCipherSuites, 0, serverCipherSuites.length);
clientCipherSuites[serverCipherSuites.length] = StandardNames.CIPHER_SUITE_FALLBACK;
- Future<Void> s = runAsync(new Callable<Void>() {
- @Override
- public Void call() throws Exception {
- server.setEnabledProtocols(new String[]{"TLSv1.2"});
- server.setEnabledCipherSuites(serverCipherSuites);
- server.startHandshake();
- return null;
- }
+ Future<Void> s = runAsync(() -> {
+ server.setEnabledProtocols(new String[] {"TLSv1.2"});
+ server.setEnabledCipherSuites(serverCipherSuites);
+ server.startHandshake();
+ return null;
});
- Future<Void> c = runAsync(new Callable<Void>() {
- @Override
- public Void call() throws Exception {
- client.setEnabledProtocols(new String[]{"TLSv1.2"});
- client.setEnabledCipherSuites(clientCipherSuites);
- client.startHandshake();
- return null;
- }
+ Future<Void> c = runAsync(() -> {
+ client.setEnabledProtocols(new String[] {"TLSv1.2"});
+ client.setEnabledCipherSuites(clientCipherSuites);
+ client.startHandshake();
+ return null;
});
s.get();
c.get();
@@ -1016,21 +926,15 @@
// Confirm absence of TLS_FALLBACK_SCSV.
assertFalse(Arrays.asList(client.getEnabledCipherSuites())
.contains(StandardNames.CIPHER_SUITE_FALLBACK));
- Future<Void> s = runAsync(new Callable<Void>() {
- @Override
- public Void call() throws Exception {
- server.setEnabledProtocols(new String[]{"TLSv1.2", "TLSv1.1"});
- server.startHandshake();
- return null;
- }
+ Future<Void> s = runAsync(() -> {
+ server.setEnabledProtocols(new String[] {"TLSv1.2", "TLSv1.1"});
+ server.startHandshake();
+ return null;
});
- Future<Void> c = runAsync(new Callable<Void>() {
- @Override
- public Void call() throws Exception {
- client.setEnabledProtocols(new String[]{"TLSv1.1"});
- client.startHandshake();
- return null;
- }
+ Future<Void> c = runAsync(() -> {
+ client.setEnabledProtocols(new String[] {"TLSv1.1"});
+ client.startHandshake();
+ return null;
});
s.get();
c.get();
@@ -1057,37 +961,25 @@
final String[] clientCipherSuites = new String[serverCipherSuites.length + 1];
System.arraycopy(serverCipherSuites, 0, clientCipherSuites, 0, serverCipherSuites.length);
clientCipherSuites[serverCipherSuites.length] = StandardNames.CIPHER_SUITE_FALLBACK;
- Future<Void> s = runAsync(new Callable<Void>() {
- @Override
- public Void call() throws Exception {
- server.setEnabledProtocols(new String[] {"TLSv1.2", "TLSv1.1"});
- server.setEnabledCipherSuites(serverCipherSuites);
- try {
- server.startHandshake();
- fail("Should result in inappropriate fallback");
- } catch (SSLHandshakeException expected) {
- Throwable cause = expected.getCause();
- assertEquals(SSLProtocolException.class, cause.getClass());
- assertInappropriateFallbackIsCause(cause);
- }
- return null;
- }
+ Future<Void> s = runAsync(() -> {
+ server.setEnabledProtocols(new String[] {"TLSv1.2", "TLSv1.1"});
+ server.setEnabledCipherSuites(serverCipherSuites);
+ SSLHandshakeException expected =
+ assertThrows(SSLHandshakeException.class, server::startHandshake);
+ Throwable cause = expected.getCause();
+ assertEquals(SSLProtocolException.class, cause.getClass());
+ assertInappropriateFallbackIsCause(cause);
+ return null;
});
- Future<Void> c = runAsync(new Callable<Void>() {
- @Override
- public Void call() throws Exception {
- client.setEnabledProtocols(new String[] {"TLSv1.1"});
- client.setEnabledCipherSuites(clientCipherSuites);
- try {
- client.startHandshake();
- fail("Should receive TLS alert inappropriate fallback");
- } catch (SSLHandshakeException expected) {
- Throwable cause = expected.getCause();
- assertEquals(SSLProtocolException.class, cause.getClass());
- assertInappropriateFallbackIsCause(cause);
- }
- return null;
- }
+ Future<Void> c = runAsync(() -> {
+ client.setEnabledProtocols(new String[] {"TLSv1.1"});
+ client.setEnabledCipherSuites(clientCipherSuites);
+ SSLHandshakeException expected =
+ assertThrows(SSLHandshakeException.class, client::startHandshake);
+ Throwable cause = expected.getCause();
+ assertEquals(SSLProtocolException.class, cause.getClass());
+ assertInappropriateFallbackIsCause(cause);
+ return null;
});
s.get();
c.get();
@@ -1122,6 +1014,74 @@
}
}
+ @Test
+ public void handshakeListenersRunExactlyOnce() {
+ AtomicInteger count = new AtomicInteger(0);
+ TestSSLSocketPair pair = TestSSLSocketPair.create();
+ pair.client.addHandshakeCompletedListener(event -> count.addAndGet(1));
+ pair.client.addHandshakeCompletedListener(event -> count.addAndGet(2));
+ pair.client.addHandshakeCompletedListener(event -> count.addAndGet(4));
+ pair.connect();
+ assertEquals(1 + 2 + 4, count.get());
+ }
+
+ @Test
+ public void closeFromHandshakeListener() throws Exception {
+ TestUtils.assumeEngineSocket();
+
+ TestSSLSocketPair pair = TestSSLSocketPair.create();
+ pair.client.addHandshakeCompletedListener(event -> socketClose(pair.client));
+ Future<Void> serverFuture = runAsync((Callable<Void>) () -> {
+ pair.server.startHandshake();
+ return null;
+ });
+ pair.client.startHandshake();
+ assertThrows(SocketException.class, pair.client::getInputStream);
+ serverFuture.get();
+ InputStream istream = pair.server.getInputStream();
+ assertEquals(-1, istream.read());
+ }
+
+ @Test
+ public void writeFromHandshakeListener() throws Exception {
+ TestUtils.assumeEngineSocket();
+
+ byte[] ping = "ping".getBytes(UTF_8);
+ byte[] pong = "pong".getBytes(UTF_8);
+ TestSSLSocketPair pair = TestSSLSocketPair.create();
+ pair.client.addHandshakeCompletedListener(event -> socketWrite(pair.client, ping));
+ pair.server.addHandshakeCompletedListener(event -> socketWrite(pair.server, pong));
+ Future<Void> serverFuture = runAsync(() -> {
+ pair.server.startHandshake();
+ return null;
+ });
+ byte[] buffer = new byte[4];
+ InputStream clientStream = pair.client.getInputStream();
+ assertEquals(4, clientStream.read(buffer));
+ assertArrayEquals(pong, buffer);
+
+ serverFuture.get();
+ InputStream serverStream = pair.server.getInputStream();
+ assertEquals(4, serverStream.read(buffer));
+ assertArrayEquals(ping, buffer);
+ }
+
+ private void socketClose(Socket socket) {
+ try {
+ socket.close();
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ private void socketWrite(Socket socket, byte[] data) {
+ try {
+ socket.getOutputStream().write(data);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
private <T> Future<T> runAsync(Callable<T> callable) {
return executor.submit(callable);
}
@@ -1138,5 +1098,4 @@
byteCount -= bytesRead;
}
}
-
}