Socket ctor should try all addresses

In M and below, calling:

    Socket("www.google.com", 80);

would call InetAddress.getAllByName("www.google.com") and then try
all of the resulting IP addresses in a loop. The new behaviour we
adopted in Enso is only trying the first record.

This restores the old behaviour and make sure we can create socket to
sites with misconfigured AAAA record.

Contributed by: Paul Marks <pmarks@google.com>
Bug: 30007735
Test: libcore.java.net.SocketTest#testSocketTestAllAddresses
Change-Id: Ieaafd20676081f6bf21548d17a95db092eece299
(cherry picked from commit f119f6d48e63d5cab09f5342eb46f84ed91b1195)
diff --git a/luni/src/test/java/libcore/java/net/SocketTest.java b/luni/src/test/java/libcore/java/net/SocketTest.java
index 52638d4..e2d3964 100644
--- a/luni/src/test/java/libcore/java/net/SocketTest.java
+++ b/luni/src/test/java/libcore/java/net/SocketTest.java
@@ -21,6 +21,7 @@
 import java.io.OutputStream;
 import java.net.ConnectException;
 import java.net.Inet4Address;
+import java.net.Inet6Address;
 import java.net.InetAddress;
 import java.net.InetSocketAddress;
 import java.net.Proxy;
@@ -34,6 +35,7 @@
 import java.net.UnknownHostException;
 import java.nio.channels.ServerSocketChannel;
 import java.nio.channels.SocketChannel;
+import java.util.Arrays;
 import java.util.List;
 import java.util.concurrent.Callable;
 import java.util.concurrent.CountDownLatch;
@@ -45,6 +47,10 @@
 
 
 public class SocketTest extends junit.framework.TestCase {
+
+    // This hostname is required to resolve to 127.0.0.1 and ::1 for all tests to pass.
+    private static final String ALL_LOOPBACK_HOSTNAME = "loopback46.unittest.grpc.io";
+
     // See http://b/2980559.
     public void test_close() throws Exception {
         Socket s = new Socket();
@@ -542,4 +548,51 @@
             new SocketThatFailOnClose(InetAddress.getLocalHost(), 1, true);
         } catch(IOException expected) {}
     }
+
+    // b/30007735
+    public void testSocketTestAllAddresses() throws Exception {
+        // Socket Ctor should try all sockets.
+        //
+        // This test creates a server socket bound to 127.0.0.1 or ::1 only, and connects using a
+        // hostname that resolves to both addresses. We should be able to connect to the server
+        // socket in either setup.
+        final String loopbackHost = ALL_LOOPBACK_HOSTNAME;
+
+        assertTrue("Loopback DNS record is unreachable or is invalid.", checkLoopbackHost(
+                loopbackHost));
+
+        final int port = 9999;
+        for (InetAddress addr : new InetAddress[]{ Inet4Address.LOOPBACK, Inet6Address.LOOPBACK }) {
+            try (ServerSocket ss = new ServerSocket(port, 0, addr)) {
+                new Thread(() -> {
+                    try {
+                        ss.accept();
+                    } catch (IOException e) {
+                        e.printStackTrace();
+                    }
+                }).start();
+
+                assertTrue(canConnect(loopbackHost, port));
+            }
+        }
+    }
+
+    /** Confirm the supplied hostname maps to only loopback addresses. */
+    private static boolean checkLoopbackHost(String host) {
+        try {
+            List<InetAddress> addrs = Arrays.asList(InetAddress.getAllByName(host));
+            return addrs.stream().allMatch(InetAddress::isLoopbackAddress) &&
+                    addrs.contains(Inet4Address.LOOPBACK) && addrs.contains(Inet6Address.LOOPBACK);
+        } catch (UnknownHostException e) {
+            return false;
+        }
+    }
+
+    private static boolean canConnect(String host, int port) {
+        try(Socket sock = new Socket(host, port)) {
+            return sock.isConnected();
+        } catch (IOException e) {
+            return false;
+        }
+    }
 }
diff --git a/ojluni/src/main/java/java/net/Socket.java b/ojluni/src/main/java/java/net/Socket.java
index 2aa057f..cbe081e 100755
--- a/ojluni/src/main/java/java/net/Socket.java
+++ b/ojluni/src/main/java/java/net/Socket.java
@@ -207,9 +207,7 @@
     public Socket(String host, int port)
         throws UnknownHostException, IOException
     {
-        this(host != null ? new InetSocketAddress(host, port) :
-             new InetSocketAddress(InetAddress.getByName(null), port),
-             (SocketAddress) null, true);
+        this(InetAddress.getAllByName(host), port, (SocketAddress) null, true);
     }
 
     /**
@@ -240,8 +238,7 @@
      * @see        SecurityManager#checkConnect
      */
     public Socket(InetAddress address, int port) throws IOException {
-        this(address != null ? new InetSocketAddress(address, port) : null,
-             (SocketAddress) null, true);
+        this(nonNullAddress(address), port, (SocketAddress) null, true);
     }
 
     /**
@@ -279,8 +276,7 @@
      */
     public Socket(String host, int port, InetAddress localAddr,
                   int localPort) throws IOException {
-        this(host != null ? new InetSocketAddress(host, port) :
-               new InetSocketAddress(InetAddress.getByName(null), port),
+        this(InetAddress.getAllByName(host), port,
              new InetSocketAddress(localAddr, localPort), true);
     }
 
@@ -318,7 +314,7 @@
      */
     public Socket(InetAddress address, int port, InetAddress localAddr,
                   int localPort) throws IOException {
-        this(address != null ? new InetSocketAddress(address, port) : null,
+        this(nonNullAddress(address), port,
              new InetSocketAddress(localAddr, localPort), true);
     }
 
@@ -364,9 +360,7 @@
      */
     @Deprecated
     public Socket(String host, int port, boolean stream) throws IOException {
-        this(host != null ? new InetSocketAddress(host, port) :
-               new InetSocketAddress(InetAddress.getByName(null), port),
-             (SocketAddress) null, stream);
+        this(InetAddress.getAllByName(host), port, (SocketAddress) null, stream);
     }
 
     /**
@@ -407,32 +401,56 @@
      */
     @Deprecated
     public Socket(InetAddress host, int port, boolean stream) throws IOException {
-        this(host != null ? new InetSocketAddress(host, port) : null,
-             new InetSocketAddress(0), stream);
+        this(nonNullAddress(host), port, new InetSocketAddress(0), stream);
     }
 
-    private Socket(SocketAddress address, SocketAddress localAddr,
-                   boolean stream) throws IOException {
-        setImpl();
-
+    private static InetAddress[] nonNullAddress(InetAddress address) {
         // backward compatibility
         if (address == null)
             throw new NullPointerException();
 
-        try {
-            createImpl(stream);
-            if (localAddr != null)
-                bind(localAddr);
-            if (address != null)
-                connect(address);
-        } catch (IOException e) {
-            // Do not call #close, classes that extend this class may do not expect a call
-            // to #close coming from the superclass constructor.
-            if (impl != null) {
-                impl.close();
+        return new InetAddress[] { address };
+    }
+
+    // Android-changed: Socket ctor should try all addresses
+    // b/30007735
+    private Socket(InetAddress[] addresses, int port, SocketAddress localAddr,
+            boolean stream) throws IOException {
+        if (addresses == null || addresses.length == 0) {
+            throw new SocketException("Impossible: empty address list");
+        }
+
+        for (int i = 0; i < addresses.length; i++) {
+            setImpl();
+            try {
+                createImpl(stream);
+                if (localAddr != null) {
+                    bind(localAddr);
+                }
+                connect(new InetSocketAddress(addresses[i], port));
+                break;
+            } catch (IOException | IllegalArgumentException | SecurityException e) {
+                try {
+                    // Android-changed:
+                    // Do not call #close, classes that extend this class may do not expect a call
+                    // to #close coming from the superclass constructor.
+                    impl.close();
+                    closed = true;
+                } catch (IOException ce) {
+                    e.addSuppressed(ce);
+                }
+
+                // Only stop on the last address.
+                if (i == addresses.length - 1) {
+                    throw e;
+                }
             }
-            closed = true;
-            throw e;
+
+            // Discard the connection state and try again.
+            impl = null;
+            created = false;
+            bound = false;
+            closed = false;
         }
     }