Refactor OpenSSLSocketImplTest to cover both socket types (#182)

diff --git a/common/src/main/java/org/conscrypt/OpenSSLEngineImpl.java b/common/src/main/java/org/conscrypt/OpenSSLEngineImpl.java
index cd46f9a..aca9504 100644
--- a/common/src/main/java/org/conscrypt/OpenSSLEngineImpl.java
+++ b/common/src/main/java/org/conscrypt/OpenSSLEngineImpl.java
@@ -399,7 +399,7 @@
             if (engineState == EngineState.CLOSED || engineState == EngineState.CLOSED_OUTBOUND) {
                 return;
             }
-            if (engineState != EngineState.MODE_SET && engineState != EngineState.NEW) {
+            if (isHandshakeStarted()) {
                 shutdownAndFreeSslNative();
             }
             if (engineState == EngineState.CLOSED_INBOUND) {
@@ -581,7 +581,7 @@
     @Override
     public void setUseClientMode(boolean mode) {
         synchronized (stateLock) {
-            if (engineState != EngineState.MODE_SET && engineState != EngineState.NEW) {
+            if (isHandshakeStarted()) {
                 throw new IllegalArgumentException(
                         "Can not change mode after handshake: engineState == " + engineState);
             }
@@ -1344,8 +1344,7 @@
                     NativeCrypto.SSL_get1_session(sslNativePointer), null, peerCertChain, ocspData,
                     tlsSctData, getSniHostname(), getPeerPort(), null);
 
-            boolean client = sslParameters.getUseClientMode();
-            if (client) {
+            if (getUseClientMode()) {
                 Platform.checkServerTrusted(x509tm, peerCertChain, authMethod, this);
             } else {
                 String authType = peerCertChain[0].getPublicKey().getAlgorithm();
diff --git a/common/src/main/java/org/conscrypt/OpenSSLEngineSocketImpl.java b/common/src/main/java/org/conscrypt/OpenSSLEngineSocketImpl.java
index e87e9a0..eb06c9a 100644
--- a/common/src/main/java/org/conscrypt/OpenSSLEngineSocketImpl.java
+++ b/common/src/main/java/org/conscrypt/OpenSSLEngineSocketImpl.java
@@ -94,7 +94,7 @@
                 case NEED_UNWRAP: {
                     if (inputStreamWrapper.read(EmptyArray.BYTE) == -1) {
                         // Can't complete the handshake due to EOF.
-                        throw new EOFException();
+                        throw SSLUtils.toSSLHandshakeException(new EOFException());
                     }
                     break;
                 }
diff --git a/common/src/main/java/org/conscrypt/OpenSSLSocketImpl.java b/common/src/main/java/org/conscrypt/OpenSSLSocketImpl.java
index 09cedfa..b1dd921 100644
--- a/common/src/main/java/org/conscrypt/OpenSSLSocketImpl.java
+++ b/common/src/main/java/org/conscrypt/OpenSSLSocketImpl.java
@@ -438,7 +438,7 @@
                 try {
                     shutdownAndFreeSslNative();
                 } catch (IOException ignored) {
-
+                    // Ignored.
                 }
             }
         }
@@ -702,8 +702,6 @@
          * Reads one byte. If there is no data in the underlying buffer,
          * this operation can block until the data will be
          * available.
-         * @return read value.
-         * @throws IOException
          */
         @Override
         public int read() throws IOException {
@@ -1122,8 +1120,8 @@
     public void close() throws IOException {
         // TODO: Close SSL sockets using a background thread so they close gracefully.
 
-        SSLInputStream sslInputStream = null;
-        SSLOutputStream sslOutputStream = null;
+        SSLInputStream sslInputStream;
+        SSLOutputStream sslOutputStream;
 
         synchronized (stateLock) {
             if (state == STATE_CLOSED) {
diff --git a/openjdk/src/test/java/org/conscrypt/OpenSSLSocketImplTest.java b/openjdk/src/test/java/org/conscrypt/OpenSSLSocketImplTest.java
index 8689dbb..929629a 100644
--- a/openjdk/src/test/java/org/conscrypt/OpenSSLSocketImplTest.java
+++ b/openjdk/src/test/java/org/conscrypt/OpenSSLSocketImplTest.java
@@ -25,15 +25,15 @@
 import static org.junit.Assert.assertTrue;
 
 import java.io.IOException;
-import java.lang.reflect.Constructor;
 import java.lang.reflect.Field;
 import java.net.ServerSocket;
 import java.net.Socket;
+import java.security.KeyManagementException;
 import java.security.KeyStore;
 import java.security.PrivateKey;
 import java.security.cert.CertificateException;
 import java.security.cert.X509Certificate;
-import java.util.concurrent.Callable;
+import java.util.Arrays;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
@@ -47,48 +47,103 @@
 import javax.net.ssl.SSLSocketFactory;
 import javax.net.ssl.TrustManager;
 import javax.net.ssl.TrustManagerFactory;
+import org.junit.After;
 import org.junit.Before;
 import org.junit.Ignore;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
 
+@RunWith(Parameterized.class)
 public class OpenSSLSocketImplTest {
     private static final long TIMEOUT_SECONDS = 5;
     private static final char[] EMPTY_PASSWORD = new char[0];
 
+    /**
+     * Various factories for SSL server sockets.
+     */
+    public enum SocketType {
+        DEFAULT(false),
+        ENGINE(true);
+
+        private final boolean useEngineSocket;
+
+        SocketType(boolean useEngineSocket) {
+            this.useEngineSocket = useEngineSocket;
+        }
+
+        OpenSSLSocketImpl createClientSocket(OpenSSLContextImpl context, ServerSocket listener)
+                throws IOException {
+            SSLSocketFactory factory = context.engineGetSocketFactory();
+            Conscrypt.SocketFactories.setUseEngineSocket(factory, useEngineSocket);
+            OpenSSLSocketImpl socket = (OpenSSLSocketImpl) factory.createSocket(
+                    listener.getInetAddress(), listener.getLocalPort());
+            socket.setUseClientMode(true);
+            return socket;
+        }
+
+        OpenSSLSocketImpl createServerSocket(OpenSSLContextImpl context, ServerSocket listener)
+                throws IOException {
+            SSLSocketFactory factory = context.engineGetSocketFactory();
+            Conscrypt.SocketFactories.setUseEngineSocket(factory, useEngineSocket);
+            OpenSSLSocketImpl socket = (OpenSSLSocketImpl) factory.createSocket(listener.accept(),
+                    null, -1, // hostname, port
+                    true); // autoclose
+            socket.setUseClientMode(false);
+            return socket;
+        }
+    }
+
+    @Parameters(name = "{0}")
+    public static Iterable<SocketType> data() {
+        return Arrays.asList(SocketType.DEFAULT, SocketType.ENGINE);
+    }
+
+    @Parameter public SocketType socketType;
+
     private X509Certificate ca;
     private X509Certificate cert;
     private X509Certificate certEmbedded;
     private PrivateKey certKey;
 
     private Field contextSSLParameters;
-    private Field sslParametersTrustManager;
+    private ExecutorService executor;
 
     @Before
     public void setUp() throws Exception {
         contextSSLParameters = OpenSSLContextImpl.class.getDeclaredField("sslParameters");
         contextSSLParameters.setAccessible(true);
 
-        sslParametersTrustManager = SSLParametersImpl.class.getDeclaredField("x509TrustManager");
-        sslParametersTrustManager.setAccessible(true);
-
         ca = OpenSSLX509Certificate.fromX509PemInputStream(openTestFile("ca-cert.pem"));
         cert = OpenSSLX509Certificate.fromX509PemInputStream(openTestFile("cert.pem"));
         certEmbedded =
                 OpenSSLX509Certificate.fromX509PemInputStream(openTestFile("cert-ct-embedded.pem"));
         certKey = OpenSSLKey.fromPrivateKeyPemInputStream(openTestFile("cert-key.pem"))
                           .getPrivateKey();
+        executor = Executors.newCachedThreadPool();
+    }
+
+    @After
+    public void teardown() throws Exception {
+        executor.shutdown();
+        executor.awaitTermination(5, TimeUnit.SECONDS);
     }
 
     abstract class Hooks implements HandshakeCompletedListener {
         KeyManager[] keyManagers;
         TrustManager[] trustManagers;
 
-        abstract OpenSSLSocketImpl createSocket(SSLSocketFactory factory, ServerSocket listener)
-                throws IOException;
+        abstract OpenSSLSocketImpl createSocket(ServerSocket listener) throws IOException;
 
-        public OpenSSLContextImpl createContext() throws Exception {
+        OpenSSLContextImpl createContext() throws IOException {
             OpenSSLContextImpl context = OpenSSLContextImpl.getPreferred();
-            context.engineInit(keyManagers, trustManagers, null);
+            try {
+                context.engineInit(keyManagers, trustManagers, null);
+            } catch (KeyManagementException e) {
+                throw new IOException(e);
+            }
             return context;
         }
 
@@ -98,39 +153,31 @@
             isHandshakeCompleted = true;
         }
 
-        protected SSLParametersImpl getContextSSLParameters(OpenSSLContextImpl context)
+        SSLParametersImpl getContextSSLParameters(OpenSSLContextImpl context)
                 throws IllegalAccessException {
             return (SSLParametersImpl) contextSSLParameters.get(context);
         }
-
-        protected TrustManager getSSLParametersTrustManager(SSLParametersImpl params)
-                throws IllegalAccessException {
-            return (TrustManager) sslParametersTrustManager.get(params);
-        }
     }
 
     class ClientHooks extends Hooks {
-        boolean ctVerificationEnabled;
         String hostname = "example.com";
 
         @Override
-        public OpenSSLContextImpl createContext() throws Exception {
+        public OpenSSLContextImpl createContext() throws IOException {
             OpenSSLContextImpl context = super.createContext();
-            SSLParametersImpl sslParameters = getContextSSLParameters(context);
-            if (ctVerificationEnabled) {
-                sslParameters.setCTVerificationEnabled(ctVerificationEnabled);
+            try {
+                SSLParametersImpl sslParameters = getContextSSLParameters(context);
+                sslParameters.setCTVerificationEnabled(true);
+            } catch (IllegalAccessException e) {
+                throw new IOException(e);
             }
             return context;
         }
 
         @Override
-        public OpenSSLSocketImpl createSocket(SSLSocketFactory factory, ServerSocket listener)
-                throws IOException {
-            OpenSSLSocketImpl socket = (OpenSSLSocketImpl) factory.createSocket(
-                    listener.getInetAddress(), listener.getLocalPort());
-            socket.setUseClientMode(true);
+        OpenSSLSocketImpl createSocket(ServerSocket listener) throws IOException {
+            OpenSSLSocketImpl socket = socketType.createClientSocket(createContext(), listener);
             socket.setHostname(hostname);
-
             return socket;
         }
     }
@@ -140,22 +187,21 @@
         byte[] ocspResponse;
 
         @Override
-        public OpenSSLContextImpl createContext() throws Exception {
+        public OpenSSLContextImpl createContext() throws IOException {
             OpenSSLContextImpl context = super.createContext();
-            SSLParametersImpl sslParameters = getContextSSLParameters(context);
-            sslParameters.setSCTExtension(sctTLSExtension);
-            sslParameters.setOCSPResponse(ocspResponse);
-            return context;
+            try {
+                SSLParametersImpl sslParameters = getContextSSLParameters(context);
+                sslParameters.setSCTExtension(sctTLSExtension);
+                sslParameters.setOCSPResponse(ocspResponse);
+                return context;
+            } catch (IllegalAccessException e) {
+                throw new IOException(e);
+            }
         }
 
         @Override
-        public OpenSSLSocketImpl createSocket(SSLSocketFactory factory, ServerSocket listener)
-                throws IOException {
-            OpenSSLSocketImpl socket = (OpenSSLSocketImpl) factory.createSocket(listener.accept(),
-                    null, -1, // hostname, port
-                    true); // autoclose
-            socket.setUseClientMode(false);
-            return socket;
+        OpenSSLSocketImpl createSocket(ServerSocket listener) throws IOException {
+            return socketType.createServerSocket(createContext(), listener);
         }
     }
 
@@ -169,7 +215,7 @@
         Exception clientException;
         Exception serverException;
 
-        public TestConnection(X509Certificate[] chain, PrivateKey key) throws Exception {
+        TestConnection(X509Certificate[] chain, PrivateKey key) throws Exception {
             clientHooks = new ClientHooks();
             serverHooks = new ServerHooks();
             setCertificates(chain, key);
@@ -209,12 +255,11 @@
             }
         }
 
-        public void doHandshake() throws Exception {
+        void doHandshake() throws Exception {
             ServerSocket listener = new ServerSocket(0);
             Future<OpenSSLSocketImpl> clientFuture = handshake(listener, clientHooks);
             Future<OpenSSLSocketImpl> serverFuture = handshake(listener, serverHooks);
 
-            Exception cause = null;
             try {
                 client = getOrThrowCause(clientFuture, TIMEOUT_SECONDS, TimeUnit.SECONDS);
             } catch (Exception e) {
@@ -228,24 +273,14 @@
         }
 
         Future<OpenSSLSocketImpl> handshake(final ServerSocket listener, final Hooks hooks) {
-            ExecutorService executor = Executors.newSingleThreadExecutor();
-            Future<OpenSSLSocketImpl> future = executor.submit(new Callable<OpenSSLSocketImpl>() {
-                @Override
-                public OpenSSLSocketImpl call() throws Exception {
-                    OpenSSLContextImpl context = hooks.createContext();
-                    SSLSocketFactory factory = context.engineGetSocketFactory();
-                    OpenSSLSocketImpl socket = hooks.createSocket(factory, listener);
-                    socket.addHandshakeCompletedListener(hooks);
+            return executor.submit(() -> {
+                OpenSSLSocketImpl socket = hooks.createSocket(listener);
+                socket.addHandshakeCompletedListener(hooks);
 
-                    socket.startHandshake();
+                socket.startHandshake();
 
-                    return socket;
-                }
+                return socket;
             });
-
-            executor.shutdown();
-
-            return future;
         }
     }
 
@@ -263,8 +298,6 @@
         TestConnection connection =
                 new TestConnection(new X509Certificate[] {certEmbedded, ca}, certKey);
 
-        connection.clientHooks.ctVerificationEnabled = true;
-
         connection.doHandshake();
 
         assertTrue(connection.clientHooks.isHandshakeCompleted);
@@ -275,7 +308,6 @@
     public void test_handshakeWithSCTFromOCSPResponse() throws Exception {
         TestConnection connection = new TestConnection(new X509Certificate[] {cert, ca}, certKey);
 
-        connection.clientHooks.ctVerificationEnabled = true;
         connection.serverHooks.ocspResponse = readTestFile("ocsp-response.der");
 
         connection.doHandshake();
@@ -288,7 +320,6 @@
     public void test_handshakeWithSCTFromTLSExtension() throws Exception {
         TestConnection connection = new TestConnection(new X509Certificate[] {cert, ca}, certKey);
 
-        connection.clientHooks.ctVerificationEnabled = true;
         connection.serverHooks.sctTLSExtension = readTestFile("ct-signed-timestamp-list");
 
         connection.doHandshake();
@@ -302,8 +333,6 @@
     public void test_handshake_failsWithMissingSCT() throws Exception {
         TestConnection connection = new TestConnection(new X509Certificate[] {cert, ca}, certKey);
 
-        connection.clientHooks.ctVerificationEnabled = true;
-
         connection.doHandshake();
         assertThat(connection.clientException, instanceOf(SSLHandshakeException.class));
         assertThat(connection.clientException.getCause(), instanceOf(CertificateException.class));
@@ -314,7 +343,6 @@
     public void test_handshake_failsWithInvalidSCT() throws Exception {
         TestConnection connection = new TestConnection(new X509Certificate[] {cert, ca}, certKey);
 
-        connection.clientHooks.ctVerificationEnabled = true;
         connection.serverHooks.sctTLSExtension = readTestFile("ct-signed-timestamp-list-invalid");
 
         connection.doHandshake();
@@ -328,11 +356,20 @@
         ServerSocket listening = new ServerSocket(0);
         Socket underlying = new Socket(listening.getInetAddress(), listening.getLocalPort());
 
-        Constructor<OpenSSLSocketImpl> cons = OpenSSLSocketImpl.class.getDeclaredConstructor(
-                Socket.class, String.class, Integer.TYPE, Boolean.TYPE, SSLParametersImpl.class);
-        cons.setAccessible(true);
-        OpenSSLSocketImpl simpl =
-                cons.newInstance(underlying, null, listening.getLocalPort(), false, null);
+        OpenSSLSocketImpl simpl;
+        switch (socketType) {
+            case DEFAULT:
+                simpl = new OpenSSLSocketImpl(underlying, null, listening.getLocalPort(), false,
+                        SSLParametersImpl.getDefault());
+                break;
+            case ENGINE:
+                simpl = new OpenSSLEngineSocketImpl(underlying, null, listening.getLocalPort(),
+                        false, SSLParametersImpl.getDefault());
+                break;
+            default:
+                throw new IllegalArgumentException("Unexpected socketType " + socketType);
+        }
+
         simpl.setSoTimeout(1000);
         simpl.close();
 
@@ -347,9 +384,8 @@
 
         connection.clientHooks = new ClientHooks() {
             @Override
-            public OpenSSLSocketImpl createSocket(SSLSocketFactory factory, ServerSocket listener)
-                    throws IOException {
-                OpenSSLSocketImpl socket = super.createSocket(factory, listener);
+            public OpenSSLSocketImpl createSocket(ServerSocket listener) throws IOException {
+                OpenSSLSocketImpl socket = super.createSocket(listener);
                 socket.setEnabledProtocols(new String[] {"SSLv3"});
                 assertEquals(
                         "SSLv3 should be filtered out", 0, socket.getEnabledProtocols().length);