Updates to SSLSocketTest to support engine-based socket. (#199)

diff --git a/openjdk-integ-tests/src/test/java/libcore/javax/net/ssl/SSLSocketTest.java b/openjdk-integ-tests/src/test/java/libcore/javax/net/ssl/SSLSocketTest.java
index e0c3656..a86e666 100644
--- a/openjdk-integ-tests/src/test/java/libcore/javax/net/ssl/SSLSocketTest.java
+++ b/openjdk-integ-tests/src/test/java/libcore/javax/net/ssl/SSLSocketTest.java
@@ -25,7 +25,8 @@
 import static org.junit.Assert.assertSame;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
-import static org.junit.Assume.assumeNotNull;
+import static org.junit.Assume.assumeFalse;
+import static org.junit.Assume.assumeNoException;
 import static org.junit.Assume.assumeTrue;
 
 import java.io.ByteArrayInputStream;
@@ -94,6 +95,7 @@
 import javax.net.ssl.SNIHostName;
 import javax.net.ssl.SNIServerName;
 import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLEngine;
 import javax.net.ssl.SSLException;
 import javax.net.ssl.SSLHandshakeException;
 import javax.net.ssl.SSLParameters;
@@ -121,7 +123,6 @@
 import libcore.tlswire.util.TlsProtocolVersion;
 import org.junit.After;
 import org.junit.Before;
-import org.junit.Ignore;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -183,6 +184,16 @@
                     @Override
                     protected SecretKey getKey(
                             String identityHint, String identity, Socket socket) {
+                        return newKey();
+                    }
+
+                    @Override
+                    protected SecretKey getKey(
+                            String identityHint, String identity, SSLEngine engine) {
+                        return newKey();
+                    }
+
+                    private SecretKey newKey() {
                         return new SecretKeySpec("Just an arbitrary key".getBytes(UTF_8), "RAW");
                     }
                 });
@@ -410,8 +421,7 @@
             assertNotNull(localCertificates);
             TestKeyStore.assertChainLength(localCertificates);
             assertNotNull(localCertificates[0]);
-            TestSSLContext.assertServerCertificateChain(
-                    c.serverTrustManager, localCertificates);
+            TestSSLContext.assertServerCertificateChain(c.serverTrustManager, localCertificates);
             TestSSLContext.assertCertificateInKeyStore(localCertificates[0], c.serverKeyStore);
             return null;
         });
@@ -506,10 +516,10 @@
     @Test
     public void test_SSLSocket_startHandshake_noKeyStore() throws Exception {
         TestSSLContext c = TestSSLContext.newBuilder()
-                .useDefaults(false)
-                .clientContext(SSLContext.getDefault())
-                .serverContext(SSLContext.getDefault())
-                .build();
+                                   .useDefaults(false)
+                                   .clientContext(SSLContext.getDefault())
+                                   .serverContext(SSLContext.getDefault())
+                                   .build();
         SSLSocket client =
                 (SSLSocket) c.clientContext.getSocketFactory().createSocket(c.host, c.port);
         final SSLSocket server = (SSLSocket) c.serverSocket.accept();
@@ -578,18 +588,15 @@
                 assertEquals(32, id.length);
                 assertNotNull(c.clientContext.getClientSessionContext().getSession(id));
                 assertNotNull(cipherSuite);
-                assertTrue(
-                        Arrays.asList(client.getEnabledCipherSuites()).contains(cipherSuite));
+                assertTrue(Arrays.asList(client.getEnabledCipherSuites()).contains(cipherSuite));
                 assertTrue(Arrays.asList(c.serverSocket.getEnabledCipherSuites())
                                    .contains(cipherSuite));
                 assertNull(localCertificates);
                 assertNotNull(peerCertificates);
                 TestKeyStore.assertChainLength(peerCertificates);
                 assertNotNull(peerCertificates[0]);
-                TestSSLContext.assertServerCertificateChain(
-                        c.clientTrustManager, peerCertificates);
-                TestSSLContext.assertCertificateInKeyStore(
-                        peerCertificates[0], c.serverKeyStore);
+                TestSSLContext.assertServerCertificateChain(c.clientTrustManager, peerCertificates);
+                TestSSLContext.assertCertificateInKeyStore(peerCertificates[0], c.serverKeyStore);
                 assertNotNull(peerCertificateChain);
                 TestKeyStore.assertChainLength(peerCertificateChain);
                 assertNotNull(peerCertificateChain[0]);
@@ -614,8 +621,8 @@
         });
         client.startHandshake();
         future.get();
-        assertNotNull(c.serverContext.getServerSessionContext().getSession(
-                    client.getSession().getId()));
+        assertNotNull(
+                c.serverContext.getServerSessionContext().getSession(client.getSession().getId()));
         synchronized (handshakeCompletedListenerCalled) {
             while (!handshakeCompletedListenerCalled[0]) {
                 handshakeCompletedListenerCalled.wait();
@@ -648,9 +655,7 @@
             server.startHandshake();
             return null;
         });
-        client.addHandshakeCompletedListener(event -> {
-            throw expectedException;
-        });
+        client.addHandshakeCompletedListener(event -> { throw expectedException; });
         client.startHandshake();
         future.get();
         client.close();
@@ -1557,9 +1562,6 @@
         SSLSocket client =
                 (SSLSocket) c.clientContext.getSocketFactory().createSocket(c.host, c.port);
 
-        Method writeTimeoutMethod = getWriteTimeoutSetter(client);
-        assumeNotNull("Client socket does not support setting write timeout", writeTimeoutMethod);
-
         // Try to make the client SO_SNDBUF size as small as possible
         // (it can default to 512k or even megabytes).  Note that
         // socket(7) says that the kernel will double the request to
@@ -1578,8 +1580,9 @@
         });
         server.startHandshake();
 
-        writeTimeoutMethod.invoke(client, 1);
         try {
+            setWriteTimeout(client, 1);
+
             // Add extra space to the write to exceed the send buffer
             // size and cause the write to block.
             final int extra = 1;
@@ -1592,37 +1595,45 @@
         }
     }
 
-    private static Method getWriteTimeoutSetter(Object socket) {
-        try {
-            return socket.getClass().getDeclaredMethod("setSoWriteTimeout", int.class);
-        } catch (Exception e) {
-            return null;
-        }
-    }
-
-    private static String osName() {
-        return System.getProperty("os.name").toLowerCase(Locale.US).replaceAll("[^a-z0-9]+", "");
-    }
-
-    private static boolean isLinux() {
-        return osName().startsWith("linux");
-    }
-
-    @Ignore("TODO(nmittler): Fix this.")
+    // TODO(nmittler): FD socket read may return -1 instead of SocketException.
     @Test
-    public void test_SSLSocket_interrupt() throws Exception {
+    public void test_SSLSocket_interrupt_readUnderlyingAndCloseUnderlying() throws Exception {
         test_SSLSocket_interrupt_case(true, true);
+    }
+
+    // TODO(nmittler): FD socket read may return -1 instead of SocketException.
+    @Test
+    public void test_SSLSocket_interrupt_readUnderlyingAndCloseWrapper() throws Exception {
         test_SSLSocket_interrupt_case(true, false);
+    }
+
+    // TODO(nmittler): FD socket gets stuck in read on Windows and OSX.
+    @Test
+    public void test_SSLSocket_interrupt_readWrapperAndCloseUnderlying() throws Exception {
         test_SSLSocket_interrupt_case(false, true);
+    }
+
+    // TODO(nmittler): FD socket read may return -1 instead of SocketException.
+    @Test
+    public void test_SSLSocket_interrupt_readWrapperAndCloseWrapper() throws Exception {
         test_SSLSocket_interrupt_case(false, false);
     }
+
     private void test_SSLSocket_interrupt_case(boolean readUnderlying, boolean closeUnderlying)
             throws Exception {
         final int readingTimeoutMillis = 5000;
         TestSSLContext c = TestSSLContext.create();
         final Socket underlying = new Socket(c.host, c.port);
-        final SSLSocket clientWrapping = (SSLSocket) c.clientContext.getSocketFactory().createSocket(
-                underlying, c.host.getHostName(), c.port, false);
+        final SSLSocket clientWrapping =
+                (SSLSocket) c.clientContext.getSocketFactory().createSocket(
+                        underlying, c.host.getHostName(), c.port, true);
+
+        if (isConscryptFdSocket(clientWrapping) && !readUnderlying && closeUnderlying) {
+            // TODO(nmittler): FD socket gets stuck in the read on Windows and OSX.
+            assumeFalse("Skipping interrupt test on Windows", isWindows());
+            assumeFalse("Skipping interrupt test on OSX", isOsx());
+        }
+
         SSLSocket server = (SSLSocket) c.serverSocket.accept();
 
         // Start the handshake.
@@ -1639,6 +1650,7 @@
         // Schedule the socket to be closes in 1 second.
         Future<Void> future = runAsync(() -> {
             Thread.sleep(1000);
+            toClose.shutdownInput();
             toClose.close();
             return null;
         });
@@ -1647,11 +1659,16 @@
         try {
             toRead.setSoTimeout(readingTimeoutMillis);
             final InputStream inputStream = toRead.getInputStream();
-            @SuppressWarnings("unused")
             int value = inputStream.read();
-            fail();
-        } catch (SocketException expected) {
-            // Ignored.
+            if (isConscryptFdSocket(clientWrapping)) {
+                // TODO(nmittler): FD socket read may return -1 instead of SocketException.
+                assertEquals(-1, value);
+            } else {
+                // For every other condition, we should expect SocketException.
+                fail();
+            }
+        } catch (SocketException e) {
+            // Otherwise, ignore the exception since it's expected.
         }
 
         future.get();
@@ -1659,6 +1676,7 @@
         underlying.close();
         server.close();
     }
+
     /**
      * b/7014266 Test to confirm that an SSLSocket.close() on one
      * thread will interrupt another thread blocked reading on the same
@@ -1673,8 +1691,22 @@
                 underlying, c.host.getHostName(), c.port, false);
         Future<Void> clientFuture = runAsync(() -> {
             wrapping.startHandshake();
-            wrapping.setSoTimeout(readingTimeoutMillis);
-            assertEquals(-1, wrapping.getInputStream().read());
+            try {
+                wrapping.setSoTimeout(readingTimeoutMillis);
+                int ret = wrapping.getInputStream().read();
+                // Android returns -1 rather than throwing.
+                if (isConscryptFdSocket(wrapping)) {
+                    // This seems to only happen with Conscrypt's FD-based socket.
+                    assertEquals(-1, ret);
+                } else {
+                    // For every other condition, we should expect SocketException.
+                    fail();
+                }
+            } catch (SocketException e) {
+                // Opposite condition of the one above. Verify the behavior we expect.
+                assertFalse(isConscryptFdSocket(wrapping));
+                // Otherwise, ignore the exception since it's expected.
+            }
             return null;
         });
         SSLSocket server = (SSLSocket) c.serverSocket.accept();
@@ -1733,25 +1765,11 @@
             protected SSLSocket configureSocket(SSLSocket socket) {
                 // Enable SNI extension on the socket (this is typically enabled by default)
                 // to increase the size of ClientHello.
-                try {
-                    Method setHostname = socket.getClass().getMethod("setHostname", String.class);
-                    setHostname.invoke(socket, "sslsockettest.androidcts.google.com");
-                } catch (NoSuchMethodException ignored) {
-                    // Ignored.
-                } catch (Exception e) {
-                    throw new RuntimeException("Failed to enable SNI", e);
-                }
+                setSniHostname(socket);
+
                 // Enable Session Tickets extension on the socket (this is typically enabled
                 // by default) to increase the size of ClientHello.
-                try {
-                    Method setUseSessionTickets =
-                            socket.getClass().getMethod("setUseSessionTickets", boolean.class);
-                    setUseSessionTickets.invoke(socket, true);
-                } catch (NoSuchMethodException ignored) {
-                    // Ignored.
-                } catch (Exception e) {
-                    throw new RuntimeException("Failed to enable Session Tickets", e);
-                }
+                enableSessionTickets(socket);
                 return socket;
             }
         };
@@ -1777,8 +1795,7 @@
             // indicate that a certain TLS extension should be used.
             HelloExtension renegotiationInfoExtension =
                     clientHello.findExtensionByType(HelloExtension.TYPE_RENEGOTIATION_INFO);
-            if (renegotiationInfoExtension != null
-                    && renegotiationInfoExtension.data.length == 1
+            if (renegotiationInfoExtension != null && renegotiationInfoExtension.data.length == 1
                     && renegotiationInfoExtension.data[0] == 0) {
                 cipherSuites = new String[clientHello.cipherSuites.size() + 1];
                 cipherSuites[clientHello.cipherSuites.size()] =
@@ -1843,8 +1860,7 @@
     }
     private List<Pair<String, SSLSocketFactory>> getSSLSocketFactoriesToTest()
             throws NoSuchAlgorithmException, KeyManagementException {
-        List<Pair<String, SSLSocketFactory>> result =
-                new ArrayList<>();
+        List<Pair<String, SSLSocketFactory>> result = new ArrayList<>();
         result.add(Pair.of("default", (SSLSocketFactory) SSLSocketFactory.getDefault()));
         for (String sslContextProtocol : StandardNames.SSL_CONTEXT_PROTOCOLS) {
             SSLContext sslContext = SSLContext.getInstance(sslContextProtocol);
@@ -1898,21 +1914,20 @@
             listeningSocket = ServerSocketFactory.getDefault().createServerSocket(0);
             final ServerSocket finalListeningSocket = listeningSocket;
             // 2. (in background) Wait for an incoming connection and read its first chunk.
-            final Future<byte[]> readFirstReceivedChunkFuture =
-                    runAsync(() -> {
-                        Socket socket = finalListeningSocket.accept();
-                        sockets[1] = socket;
-                        try {
-                            byte[] buffer = new byte[64 * 1024];
-                            int bytesRead = socket.getInputStream().read(buffer);
-                            if (bytesRead == -1) {
-                                throw new EOFException("Failed to read anything");
-                            }
-                            return Arrays.copyOf(buffer, bytesRead);
-                        } finally {
-                            closeQuietly(socket);
-                        }
-                    });
+            final Future<byte[]> readFirstReceivedChunkFuture = runAsync(() -> {
+                Socket socket = finalListeningSocket.accept();
+                sockets[1] = socket;
+                try {
+                    byte[] buffer = new byte[64 * 1024];
+                    int bytesRead = socket.getInputStream().read(buffer);
+                    if (bytesRead == -1) {
+                        throw new EOFException("Failed to read anything");
+                    }
+                    return Arrays.copyOf(buffer, bytesRead);
+                } finally {
+                    closeQuietly(socket);
+                }
+            });
             // 3. Create a client socket, connect it to the server socket, and start the TLS/SSL
             //    handshake.
             runAsync((Callable<Void>) () -> {
@@ -1924,8 +1939,7 @@
                     // server socket receives a ClientHello.
                     try {
                         SSLSocket sslSocket = (SSLSocket) sslSocketFactory.createSocket(client,
-                                "localhost.localdomain", finalListeningSocket.getLocalPort(),
-                                true);
+                                "localhost.localdomain", finalListeningSocket.getLocalPort(), true);
                         sslSocket.startHandshake();
                         fail();
                         return null;
@@ -1949,16 +1963,10 @@
     @Test
     public void test_SSLSocket_getPortWithSNI() throws Exception {
         TestSSLContext context = TestSSLContext.create();
-        try (SSLSocket client = (SSLSocket) context.clientContext.getSocketFactory()
-                .createSocket()) {
+        try (SSLSocket client =
+                        (SSLSocket) context.clientContext.getSocketFactory().createSocket()) {
             client.connect(new InetSocketAddress(context.host, context.port));
-            try {
-                // This is crucial to reproducing issue 18428603.
-                Method setHostname = client.getClass().getMethod("setHostname", String.class);
-                setHostname.invoke(client, "sslsockettest.androidcts.google.com");
-            } catch (NoSuchMethodException ignored) {
-                // Ignored.
-            }
+            setSniHostname(client);
             assertTrue(client.getPort() > 0);
         } finally {
             context.close();
@@ -1969,8 +1977,7 @@
         TestSSLContext c = TestSSLContext.create();
         final SSLSocket client = (SSLSocket) c.clientContext.getSocketFactory().createSocket();
         SSLParameters clientParams = client.getSSLParameters();
-        clientParams.setServerNames(
-                Collections.singletonList(new SNIHostName("www.example.com")));
+        clientParams.setServerNames(Collections.singletonList(new SNIHostName("www.example.com")));
         client.setSSLParameters(clientParams);
         SSLParameters serverParams = c.serverSocket.getSSLParameters();
         serverParams.setSNIMatchers(
@@ -2050,10 +2057,13 @@
         server.close();
         context.close();
     }
+
     private static void assertInappropriateFallbackIsCause(Throwable cause) {
-        assertTrue(cause.getMessage(), cause.getMessage().contains("inappropriate fallback")
+        assertTrue(cause.getMessage(),
+                cause.getMessage().contains("inappropriate fallback")
                         || cause.getMessage().contains("INAPPROPRIATE_FALLBACK"));
     }
+
     @Test
     public void test_SSLSocket_sendsTlsFallbackScsv_InappropriateFallback_Failure()
             throws Exception {
@@ -2123,11 +2133,9 @@
             int bytesRead = server.getInputStream().read(scratch);
             // Write a bogus TLS alert:
             // TLSv1.2 Record Layer: Alert (Level: Warning, Description: Protocol Version)
-            server.getOutputStream().write(
-                    new byte[] {0x15, 0x03, 0x03, 0x00, 0x02, 0x01, 0x46});
+            server.getOutputStream().write(new byte[] {0x15, 0x03, 0x03, 0x00, 0x02, 0x01, 0x46});
             // TLSv1.2 Record Layer: Alert (Level: Warning, Description: Close Notify)
-            server.getOutputStream().write(
-                    new byte[] {0x15, 0x03, 0x03, 0x00, 0x02, 0x01, 0x00});
+            server.getOutputStream().write(new byte[] {0x15, 0x03, 0x03, 0x00, 0x02, 0x01, 0x00});
             return null;
         });
         c.get(5, TimeUnit.SECONDS);
@@ -2203,12 +2211,10 @@
             // Write a bogus TLS alert:
             // TLSv1.2 Record Layer: Alert (Level: Warning, Description:
             // Protocol Version)
-            client.getOutputStream().write(
-                    new byte[] {0x15, 0x03, 0x03, 0x00, 0x02, 0x01, 0x46});
+            client.getOutputStream().write(new byte[] {0x15, 0x03, 0x03, 0x00, 0x02, 0x01, 0x46});
             // TLSv1.2 Record Layer: Alert (Level: Warning, Description:
             // Close Notify)
-            client.getOutputStream().write(
-                    new byte[] {0x15, 0x03, 0x03, 0x00, 0x02, 0x01, 0x00});
+            client.getOutputStream().write(new byte[] {0x15, 0x03, 0x03, 0x00, 0x02, 0x01, 0x00});
             return null;
         });
         c.get(5, TimeUnit.SECONDS);
@@ -2233,6 +2239,65 @@
         }
     }
 
+    private static void setWriteTimeout(Object socket, int timeout) {
+        Exception ex = null;
+        try {
+            Method method = socket.getClass().getMethod("setSoWriteTimeout", int.class);
+            method.setAccessible(true);
+            method.invoke(socket, timeout);
+        } catch (Exception e) {
+            ex = e;
+        }
+        // Engine-based socket currently has the method but throws UnsupportedOperationException.
+        assumeNoException("Client socket does not support setting write timeout", ex);
+    }
+
+    private static void setSniHostname(SSLSocket socket) {
+        try {
+            Method method = socket.getClass().getMethod("setHostname", String.class);
+            method.setAccessible(true);
+            method.invoke(socket, "sslsockettest.androidcts.google.com");
+        } catch (NoSuchMethodException ignored) {
+            // Ignored.
+        } catch (Exception e) {
+            throw new RuntimeException("Failed to enable SNI", e);
+        }
+    }
+
+    private static void enableSessionTickets(SSLSocket socket) {
+        try {
+            Method method =
+                    socket.getClass().getMethod("setUseSessionTickets", boolean.class);
+            method.setAccessible(true);
+            method.invoke(socket, true);
+        } catch (NoSuchMethodException ignored) {
+            // Ignored.
+        } catch (Exception e) {
+            throw new RuntimeException("Failed to enable Session Tickets", e);
+        }
+    }
+
+    private static boolean isConscryptFdSocket(Socket socket) {
+        return "OpenSSLSocketImplWrapper".equals(socket.getClass().getSimpleName());
+    }
+
+    private static String osName() {
+        return System.getProperty("os.name").toLowerCase(Locale.US).replaceAll("[^a-z0-9]+", "");
+    }
+
+    private static boolean isLinux() {
+        return osName().startsWith("linux");
+    }
+
+    private static boolean isWindows() {
+        return osName().startsWith("windows");
+    }
+
+    private static boolean isOsx() {
+        String name = osName();
+        return name.startsWith("macosx") || name.startsWith("osx");
+    }
+
     private <T> Future<T> runAsync(Callable<T> callable) {
         return executor.submit(callable);
     }