Merge "Snap for 4888699 from e7483acd9b0928a2506e786ec1df3fce2db87aad to oreo-mr1-cts-release" into oreo-mr1-cts-release
diff --git a/luni/src/test/java/libcore/java/nio/channels/DatagramChannelMulticastTest.java b/luni/src/test/java/libcore/java/nio/channels/DatagramChannelMulticastTest.java
index 38babc0..74999c0 100644
--- a/luni/src/test/java/libcore/java/nio/channels/DatagramChannelMulticastTest.java
+++ b/luni/src/test/java/libcore/java/nio/channels/DatagramChannelMulticastTest.java
@@ -39,6 +39,7 @@
 import java.nio.channels.MembershipKey;
 import java.util.ArrayList;
 import java.util.Enumeration;
+import java.util.concurrent.TimeUnit;
 
 import libcore.io.IoBridge;
 import libcore.io.IoUtils;
@@ -249,8 +250,7 @@

         // now verify that we received the data as expected
         ByteBuffer recvBuffer = ByteBuffer.allocate(100);
-        SocketAddress sourceAddress = receiverChannel.receive(recvBuffer);
-        assertNotNull(sourceAddress);
+        receiveExpectedDatagram(receiverChannel, recvBuffer);  // asserts a datagram arrives
         assertEquals(msg, new String(recvBuffer.array(), 0, recvBuffer.position()));

         // now verify that we didn't receive the second message
@@ -258,8 +258,6 @@
         createChannelAndSendMulticastMessage(
                 group2, localAddress.getPort(), msg2, networkInterface);
-        recvBuffer.position(0);
-        SocketAddress sourceAddress2 = receiverChannel.receive(recvBuffer);
-        assertNull(sourceAddress2);
+        checkNoDatagramReceived(receiverChannel);  // asserts nothing arrives

         receiverChannel.close();
     }
@@ -375,13 +374,13 @@
                     group, localAddress.getPort(), msg, sendingInterface);

             ByteBuffer recvBuffer = ByteBuffer.allocate(100);
-            SocketAddress sourceAddress = dc.receive(recvBuffer);
             if (thisInterface.equals(sendingInterface)) {
+                receiveExpectedDatagram(dc, recvBuffer);  // asserts a datagram arrives
                 assertEquals(msg, new String(recvBuffer.array(), 0, recvBuffer.position()));
             } else {
-                assertNull(sourceAddress);
+                checkNoDatagramReceived(dc);  // sent via another interface
             }

             dc.close();
         }
     }
@@ -465,8 +463,7 @@

         // receive the datagram
         ByteBuffer recvBuffer = ByteBuffer.allocate(100);
-        SocketAddress sourceAddress = dc.receive(recvBuffer);
-        assertNotNull(sourceAddress);
+        receiveExpectedDatagram(dc, recvBuffer);  // asserts a datagram arrives

         String recvMessage = new String(recvBuffer.array(), 0, recvBuffer.position());
         assertEquals(message, recvMessage);
@@ -479,8 +476,7 @@
         ByteBuffer sendBuffer2 = ByteBuffer.wrap(sendData);
         dc.send(sendBuffer2, new InetSocketAddress(group, localAddress.getPort()));

-        SocketAddress sourceAddress2 = dc.receive(recvBuffer);
-        assertNull(sourceAddress2);
+        checkNoDatagramReceived(dc);  // asserts nothing arrives

         dc.close();
     }
@@ -802,7 +798,8 @@
         String msg1 = "Hello1";
         channel.sendMulticastMessage(msg1, groupSocketAddress);
         IoBridge.poll(receivingChannel.socket().getFileDescriptor$(), POLLIN, 1000);
-        InetSocketAddress sourceAddress1 = (InetSocketAddress) receivingChannel.receive(receiveBuffer);
+        InetSocketAddress sourceAddress1 =
+                (InetSocketAddress) receiveExpectedDatagram(receivingChannel, receiveBuffer);
         assertEquals(sourceAddress1, sendingAddress);
         assertEquals(msg1, new String(receiveBuffer.array(), 0, receiveBuffer.position()));

@@ -817,8 +813,7 @@
             fail();
         } catch (SocketTimeoutException expected) { }
         receiveBuffer.position(0);
-        InetSocketAddress sourceAddress2 = (InetSocketAddress) receivingChannel.receive(receiveBuffer);
-        assertNull(sourceAddress2);
+        checkNoDatagramReceived(receivingChannel);  // sender is blocked

         // Now unblock the sender
         membershipKey.unblock(sendingAddress.getAddress());
@@ -828,7 +823,8 @@
         channel.sendMulticastMessage(msg3, groupSocketAddress);
         IoBridge.poll(receivingChannel.socket().getFileDescriptor$(), POLLIN, 1000);
         receiveBuffer.position(0);
-        InetSocketAddress sourceAddress3 = (InetSocketAddress) receivingChannel.receive(receiveBuffer);
+        SocketAddress receivedFrom3 = receiveExpectedDatagram(receivingChannel, receiveBuffer);
+        InetSocketAddress sourceAddress3 = (InetSocketAddress) receivedFrom3;
         assertEquals(sourceAddress3, sendingAddress);
         assertEquals(msg3, new String(receiveBuffer.array(), 0, receiveBuffer.position()));

@@ -1106,7 +1102,8 @@
         // Send a message. It should be received.
         String msg1 = "Hello1";
         channel.sendMulticastMessage(msg1, groupSocketAddress);
-        InetSocketAddress sourceAddress1 = (InetSocketAddress) receivingChannel.receive(receiveBuffer);
+        SocketAddress receivedFrom1 = receiveExpectedDatagram(receivingChannel, receiveBuffer);
+        InetSocketAddress sourceAddress1 = (InetSocketAddress) receivedFrom1;
         assertEquals(sourceAddress1, sendingAddress);
         assertEquals(msg1, new String(receiveBuffer.array(), 0, receiveBuffer.position()));

@@ -1117,8 +1114,7 @@
         // Send a message. It should not be received.
         String msg2 = "Hello2";
         channel.sendMulticastMessage(msg2, groupSocketAddress);
-        InetSocketAddress sourceAddress2 = (InetSocketAddress) receivingChannel.receive(receiveBuffer);
-        assertNull(sourceAddress2);
+        checkNoDatagramReceived(receivingChannel);  // asserts nothing arrives

         receivingChannel.close();
         sendingChannel.close();
@@ -1221,15 +1217,83 @@

     private static void configureChannelForReceiving(DatagramChannel receivingChannel)
             throws Exception {
-
-        // NOTE: At the time of writing setSoTimeout() has no effect in the RI, making these tests hang
-        // if the channel is in blocking mode. configureBlocking(false) is used instead and rely on the
-        // network to the local host being instantaneous.
-        // receivingChannel.socket().setSoTimeout(200);
-        // receivingChannel.configureBlocking(true);
+        /* NOTE: At the time of writing setSoTimeout() has no effect in the RI, making
+         * these tests hang if the channel is in blocking mode.
+         *
+         * Therefore these tests instead use configureBlocking(false) together with
+         * receiveWithTimeout(), which does its own blocking by polling the channel
+         * until a datagram arrives or a timeout expires.
+         */
         receivingChannel.configureBlocking(false);
     }

+    /**
+     * Asserts that a datagram is received from the supplied {@code receivingChannel}
+     * when {@link #receiveWithTimeout(DatagramChannel, ByteBuffer, long) polling}
+     * with a moderate timeout, and returns its source address. The datagram content is
+     * written into {@code byteBuffer}.
+     *
+     * <p>Polling stops as soon as a datagram arrives, so the timeout only delays tests
+     * that are about to fail; a generous value avoids flakiness on slow devices.
+     */
+    private static SocketAddress receiveExpectedDatagram(DatagramChannel receivingChannel,
+            ByteBuffer byteBuffer) throws InterruptedException, IOException {
+        long timeoutMillis = 1000L;
+        SocketAddress result = receiveWithTimeout(receivingChannel, byteBuffer, timeoutMillis);
+        assertNotNull("Expected Datagram, but none received after " + timeoutMillis + " msec",
+                result);
+        return result;
+    }
+
+    /**
+     * Asserts that no datagram is received from the supplied {@code receivingChannel}
+     * when {@link #receiveWithTimeout(DatagramChannel, ByteBuffer, long) polling}
+     * with a moderate timeout. Any datagram that does arrive is read into a private
+     * scratch buffer, so callers' buffers are left untouched.
+     */
+    private static void checkNoDatagramReceived(DatagramChannel receivingChannel)
+            throws InterruptedException, IOException {
+        ByteBuffer byteBuffer = ByteBuffer.allocate(100);
+        long startNanos = System.nanoTime();
+        SocketAddress result = receiveWithTimeout(receivingChannel, byteBuffer, 1000L);
+        long elapsedMillis = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNanos);
+        assertNull("Datagram unexpectedly received after " + elapsedMillis + " msec", result);
+    }
+
+    /**
+     * Receives a datagram from {@code receivingChannel} and writes the datagram content
+     * into {@code byteBuffer}.
+     * This method polls periodically (at most every 20 msec) until either a datagram is
+     * received or the indicated timeout has expired. The deadline is tracked with
+     * {@link System#nanoTime()}, which, unlike {@link System#currentTimeMillis()}, is not
+     * affected by wall-clock adjustments.
+     *
+     * <p>{@code receivingChannel} should be in non-blocking mode (see
+     * {@link #configureChannelForReceiving}); in blocking mode the first receive would
+     * block until a datagram arrives and the timeout would have no effect.
+     *
+     * @return the received datagram's source address, or null if no datagram was received
+     *         before the timeout was found to have expired.
+     */
+    private static SocketAddress receiveWithTimeout(DatagramChannel receivingChannel,
+            ByteBuffer byteBuffer, long timeoutMillis) throws InterruptedException, IOException {
+        long deadlineNanos = System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(timeoutMillis);
+        while (true) {
+            SocketAddress result = receivingChannel.receive(byteBuffer);
+            if (result != null) {
+                return result;
+            }
+            // Only differences between nanoTime() values are meaningful (they may overflow).
+            long remainingNanos = deadlineNanos - System.nanoTime();
+            if (remainingNanos <= 0) {
+                return null;
+            }
+            // toMillis() truncates; +1 avoids a 0 msec sleep (busy spin) near the deadline.
+            long remainingMillis = TimeUnit.NANOSECONDS.toMillis(remainingNanos) + 1;
+            Thread.sleep(Math.min(20L, remainingMillis));
+        }
+    }
+
     private static boolean willWorkForMulticast(NetworkInterface iface) throws IOException {
         return iface.isUp()
                 // On Oreo+ NetworkInterface.isUp() doesn't check the IFF_RUNNING flag so we do