Adding all factory methods for engine socket. (#192)
Also properly throwing SSLHandshakeException in some cases.
Fixes #191
diff --git a/common/src/main/java/org/conscrypt/OpenSSLEngineImpl.java b/common/src/main/java/org/conscrypt/OpenSSLEngineImpl.java
index c42a0ff..d12d4d8 100644
--- a/common/src/main/java/org/conscrypt/OpenSSLEngineImpl.java
+++ b/common/src/main/java/org/conscrypt/OpenSSLEngineImpl.java
@@ -369,7 +369,7 @@
String logMessage = String.format("ssl_unexpected_ccs: host=%s", getSniHostname());
Platform.logEvent(logMessage);
}
- throw new SSLException(e);
+ throw SSLUtils.toSSLHandshakeException(e);
} finally {
if (releaseResources) {
engineState = EngineState.CLOSED;
@@ -1317,6 +1317,11 @@
public void onSSLStateChange(int type, int val) {
synchronized (stateLock) {
switch (type) {
+ case SSL_CB_HANDSHAKE_START:
+ // For clients, this will allow the NEED_UNWRAP status to be
+ // returned.
+ engineState = EngineState.HANDSHAKE_STARTED;
+ break;
case SSL_CB_HANDSHAKE_DONE:
if (engineState != EngineState.HANDSHAKE_STARTED
&& engineState != EngineState.READY_HANDSHAKE_CUT_THROUGH) {
@@ -1325,11 +1330,7 @@
}
engineState = EngineState.HANDSHAKE_COMPLETED;
break;
- case SSL_CB_HANDSHAKE_START:
- // For clients, this will allow the NEED_UNWRAP status to be
- // returned.
- engineState = EngineState.HANDSHAKE_STARTED;
- break;
+
}
}
}
diff --git a/common/src/main/java/org/conscrypt/OpenSSLEngineSocketImpl.java b/common/src/main/java/org/conscrypt/OpenSSLEngineSocketImpl.java
index eb06c9a..7e711d8 100644
--- a/common/src/main/java/org/conscrypt/OpenSSLEngineSocketImpl.java
+++ b/common/src/main/java/org/conscrypt/OpenSSLEngineSocketImpl.java
@@ -23,6 +23,7 @@
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
+import java.net.InetAddress;
import java.net.Socket;
import java.net.SocketException;
import java.nio.ByteBuffer;
@@ -49,6 +50,31 @@
private final InputStreamWrapper inputStreamWrapper;
private boolean handshakeComplete;
+ OpenSSLEngineSocketImpl(SSLParametersImpl sslParameters) throws IOException {
+ this(new Socket(), null, -1, false, sslParameters);
+ }
+
+ OpenSSLEngineSocketImpl(String host, int port, SSLParametersImpl sslParameters)
+ throws IOException {
+ this(new Socket(host, port), host, port, false, sslParameters);
+ }
+
+ OpenSSLEngineSocketImpl(String host, int port, InetAddress clientAddress, int clientPort,
+ SSLParametersImpl sslParameters) throws IOException {
+ this(new Socket(host, port, clientAddress, clientPort), host, port, false, sslParameters);
+ }
+
+ OpenSSLEngineSocketImpl(InetAddress address, int port, SSLParametersImpl sslParameters)
+ throws IOException {
+ this(new Socket(address, port), null, port, false, sslParameters);
+ }
+
+ OpenSSLEngineSocketImpl(InetAddress address, int port, InetAddress clientAddress,
+ int clientPort, SSLParametersImpl sslParameters) throws IOException {
+ this(new Socket(address, port, clientAddress, clientPort), null, port, false,
+ sslParameters);
+ }
+
OpenSSLEngineSocketImpl(Socket socket, String hostname, int port, boolean autoClose,
SSLParametersImpl sslParameters) throws IOException {
super(socket, hostname, port, autoClose, sslParameters);
@@ -72,37 +98,44 @@
@Override
public void startHandshake() throws IOException {
- // Trigger the handshake
- boolean beginHandshakeCalled = false;
- while (!handshakeComplete) {
- switch (engine.getHandshakeStatus()) {
- case NOT_HANDSHAKING: {
- if (!beginHandshakeCalled) {
- beginHandshakeCalled = true;
- engine.beginHandshake();
+ try {
+ // Trigger the handshake
+ boolean beginHandshakeCalled = false;
+ while (!handshakeComplete) {
+ switch (engine.getHandshakeStatus()) {
+ case NOT_HANDSHAKING: {
+ if (!beginHandshakeCalled) {
+ beginHandshakeCalled = true;
+ engine.beginHandshake();
+ break;
+ }
break;
}
- break;
- }
- case FINISHED: {
- return;
- }
- case NEED_WRAP: {
- outputStreamWrapper.write(EMPTY_BUFFER);
- break;
- }
- case NEED_UNWRAP: {
- if (inputStreamWrapper.read(EmptyArray.BYTE) == -1) {
- // Can't complete the handshake due to EOF.
- throw SSLUtils.toSSLHandshakeException(new EOFException());
+ case FINISHED: {
+ return;
}
- break;
+ case NEED_WRAP: {
+ outputStreamWrapper.write(EMPTY_BUFFER);
+ break;
+ }
+ case NEED_UNWRAP: {
+ if (inputStreamWrapper.read(EmptyArray.BYTE) == -1) {
+ // Can't complete the handshake due to EOF.
+ throw SSLUtils.toSSLHandshakeException(new EOFException());
+ }
+ break;
+ }
+ case NEED_TASK: {
+ throw new IllegalStateException("OpenSSLEngineImpl returned NEED_TASK");
+ }
+ default: {
+ break;
+ }
}
- case NEED_TASK: {
- throw new IllegalStateException("OpenSSLEngineImpl returned NEED_TASK");
- }
- default: { break; }
}
+ } catch (Exception e) {
+ close();
+ throw SSLUtils.toSSLHandshakeException(e);
}
}
@@ -394,10 +427,8 @@
}
} while (len > 0);
} catch (IOException e) {
- e.printStackTrace();
throw e;
} catch (RuntimeException e) {
- e.printStackTrace();
throw e;
}
}
@@ -541,10 +572,8 @@
// Continue the loop and return the data from the engine buffer.
}
} catch (IOException e) {
- e.printStackTrace();
throw e;
} catch (RuntimeException e) {
- e.printStackTrace();
throw e;
}
}
diff --git a/common/src/main/java/org/conscrypt/OpenSSLSocketFactoryImpl.java b/common/src/main/java/org/conscrypt/OpenSSLSocketFactoryImpl.java
index 8e38e53..1794d9b 100644
--- a/common/src/main/java/org/conscrypt/OpenSSLSocketFactoryImpl.java
+++ b/common/src/main/java/org/conscrypt/OpenSSLSocketFactoryImpl.java
@@ -80,27 +80,47 @@
if (instantiationException != null) {
throw instantiationException;
}
- return new OpenSSLSocketImpl((SSLParametersImpl) sslParameters.clone());
+ if (useEngineSocket) {
+ return new OpenSSLEngineSocketImpl((SSLParametersImpl) sslParameters.clone());
+ } else {
+ return new OpenSSLSocketImpl((SSLParametersImpl) sslParameters.clone());
+ }
}
@Override
public Socket createSocket(String hostname, int port) throws IOException, UnknownHostException {
- return new OpenSSLSocketImpl(hostname, port, (SSLParametersImpl) sslParameters.clone());
+ if (useEngineSocket) {
+ return new OpenSSLEngineSocketImpl(hostname, port, (SSLParametersImpl) sslParameters.clone());
+ } else {
+ return new OpenSSLSocketImpl(hostname, port, (SSLParametersImpl) sslParameters.clone());
+ }
}
@Override
public Socket createSocket(String hostname, int port, InetAddress localHost, int localPort)
throws IOException, UnknownHostException {
- return new OpenSSLSocketImpl(hostname,
- port,
- localHost,
- localPort,
- (SSLParametersImpl) sslParameters.clone());
+ if (useEngineSocket) {
+ return new OpenSSLEngineSocketImpl(hostname,
+ port,
+ localHost,
+ localPort,
+ (SSLParametersImpl) sslParameters.clone());
+ } else {
+ return new OpenSSLSocketImpl(hostname,
+ port,
+ localHost,
+ localPort,
+ (SSLParametersImpl) sslParameters.clone());
+ }
}
@Override
public Socket createSocket(InetAddress address, int port) throws IOException {
- return new OpenSSLSocketImpl(address, port, (SSLParametersImpl) sslParameters.clone());
+ if (useEngineSocket) {
+ return new OpenSSLEngineSocketImpl(address, port, (SSLParametersImpl) sslParameters.clone());
+ } else {
+ return new OpenSSLSocketImpl(address, port, (SSLParametersImpl) sslParameters.clone());
+ }
}
@Override
@@ -109,25 +129,25 @@
InetAddress localAddress,
int localPort)
throws IOException {
- return new OpenSSLSocketImpl(address,
- port,
- localAddress,
- localPort,
- (SSLParametersImpl) sslParameters.clone());
+ if (useEngineSocket) {
+ return new OpenSSLEngineSocketImpl(address,
+ port,
+ localAddress,
+ localPort,
+ (SSLParametersImpl) sslParameters.clone());
+ } else {
+ return new OpenSSLSocketImpl(address,
+ port,
+ localAddress,
+ localPort,
+ (SSLParametersImpl) sslParameters.clone());
+ }
}
@Override
public Socket createSocket(Socket s, String hostname, int port, boolean autoClose)
throws IOException {
- boolean socketHasFd = false;
- try {
- // If socket has a file descriptor we can use OpenSSLSocketImplWrapper directly
- // otherwise we need to use the engine.
- socketHasFd = Platform.getFileDescriptor(s) != null;
- } catch (RuntimeException re) {
- // Ignore
- }
- if (socketHasFd && !useEngineSocket) {
+ if (hasFileDescriptor(s) && !useEngineSocket) {
return new OpenSSLSocketImplWrapper(
s, hostname, port, autoClose, (SSLParametersImpl) sslParameters.clone());
} else {
@@ -135,4 +155,15 @@
s, hostname, port, autoClose, (SSLParametersImpl) sslParameters.clone());
}
}
+
+ private boolean hasFileDescriptor(Socket s) {
+ try {
+ // If socket has a file descriptor we can use OpenSSLSocketImplWrapper directly
+ // otherwise we need to use the engine.
+ Platform.getFileDescriptor(s);
+ return true;
+ } catch (RuntimeException re) {
+ return false;
+ }
+ }
}
diff --git a/common/src/main/java/org/conscrypt/OpenSSLSocketImplWrapper.java b/common/src/main/java/org/conscrypt/OpenSSLSocketImplWrapper.java
index 126dbb2..848ef7f 100644
--- a/common/src/main/java/org/conscrypt/OpenSSLSocketImplWrapper.java
+++ b/common/src/main/java/org/conscrypt/OpenSSLSocketImplWrapper.java
@@ -42,17 +42,17 @@
@Override
public void connect(SocketAddress sockaddr, int timeout)
throws IOException {
- throw new IOException("Underlying socket is already connected.");
+ socket.connect(sockaddr, timeout);
}
@Override
public void connect(SocketAddress sockaddr) throws IOException {
- throw new IOException("Underlying socket is already connected.");
+ socket.connect(sockaddr);
}
@Override
public void bind(SocketAddress sockaddr) throws IOException {
- throw new IOException("Underlying socket is already connected.");
+ socket.bind(sockaddr);
}
@Override
diff --git a/openjdk/src/test/java/org/conscrypt/OpenSSLSocketImplTest.java b/openjdk/src/test/java/org/conscrypt/OpenSSLSocketImplTest.java
index 929629a..8c59e8f 100644
--- a/openjdk/src/test/java/org/conscrypt/OpenSSLSocketImplTest.java
+++ b/openjdk/src/test/java/org/conscrypt/OpenSSLSocketImplTest.java
@@ -26,6 +26,7 @@
import java.io.IOException;
import java.lang.reflect.Field;
+import java.net.InetAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.security.KeyManagementException;
@@ -65,8 +66,20 @@
* Various factories for SSL server sockets.
*/
public enum SocketType {
- DEFAULT(false),
- ENGINE(true);
+ DEFAULT(false) {
+ @Override
+ void assertSocketType(Socket socket) {
+ assertTrue("Unexpected socket type: " + socket.getClass().getName(),
+ socket instanceof OpenSSLSocketImpl);
+ }
+ },
+ ENGINE(true) {
+ @Override
+ void assertSocketType(Socket socket) {
+ assertTrue("Unexpected socket type: " + socket.getClass().getName(),
+ socket instanceof OpenSSLEngineSocketImpl);
+ }
+ };
private final boolean useEngineSocket;
@@ -80,6 +93,7 @@
Conscrypt.SocketFactories.setUseEngineSocket(factory, useEngineSocket);
OpenSSLSocketImpl socket = (OpenSSLSocketImpl) factory.createSocket(
listener.getInetAddress(), listener.getLocalPort());
+ assertSocketType(socket);
socket.setUseClientMode(true);
return socket;
}
@@ -91,9 +105,12 @@
OpenSSLSocketImpl socket = (OpenSSLSocketImpl) factory.createSocket(listener.accept(),
null, -1, // hostname, port
true); // autoclose
+ assertSocketType(socket);
socket.setUseClientMode(false);
return socket;
}
+
+ abstract void assertSocketType(Socket socket);
}
@Parameters(name = "{0}")
@@ -256,7 +273,7 @@
}
void doHandshake() throws Exception {
- ServerSocket listener = new ServerSocket(0);
+ ServerSocket listener = newServerSocket();
Future<OpenSSLSocketImpl> clientFuture = handshake(listener, clientHooks);
Future<OpenSSLSocketImpl> serverFuture = handshake(listener, serverHooks);
@@ -353,29 +370,18 @@
// http://b/27250522
@Test
public void test_setSoTimeout_doesNotCreateSocketImpl() throws Exception {
- ServerSocket listening = new ServerSocket(0);
+ ServerSocket listening = newServerSocket();
Socket underlying = new Socket(listening.getInetAddress(), listening.getLocalPort());
- OpenSSLSocketImpl simpl;
- switch (socketType) {
- case DEFAULT:
- simpl = new OpenSSLSocketImpl(underlying, null, listening.getLocalPort(), false,
- SSLParametersImpl.getDefault());
- break;
- case ENGINE:
- simpl = new OpenSSLEngineSocketImpl(underlying, null, listening.getLocalPort(),
- false, SSLParametersImpl.getDefault());
- break;
- default:
- throw new IllegalArgumentException("Unexpected socketType " + socketType);
- }
-
- simpl.setSoTimeout(1000);
- simpl.close();
+ Socket socket = TestUtils.getConscryptSocketFactory(socketType == SocketType.ENGINE)
+ .createSocket(underlying, null, listening.getLocalPort(), false);
+ socketType.assertSocketType(socket);
+ socket.setSoTimeout(1000);
+ socket.close();
Field f = Socket.class.getDeclaredField("created");
f.setAccessible(true);
- assertFalse(f.getBoolean(simpl));
+ assertFalse(f.getBoolean(socket));
}
@Test
@@ -402,4 +408,8 @@
assertFalse(connection.clientHooks.isHandshakeCompleted);
assertFalse(connection.serverHooks.isHandshakeCompleted);
}
+
+ private static ServerSocket newServerSocket() throws IOException {
+ return new ServerSocket(0, 50, InetAddress.getLoopbackAddress());
+ }
}