Merge "Update NetworkStack version check in NetworkMonitor."
diff --git a/common/moduleutils/src/android/net/util/FdEventsReader.java b/common/moduleutils/src/android/net/util/FdEventsReader.java
index 5a1154f..ebd6c53 100644
--- a/common/moduleutils/src/android/net/util/FdEventsReader.java
+++ b/common/moduleutils/src/android/net/util/FdEventsReader.java
@@ -27,6 +27,7 @@
 
 import androidx.annotation.NonNull;
 import androidx.annotation.Nullable;
+import androidx.annotation.VisibleForTesting;
 
 import java.io.FileDescriptor;
 import java.io.IOException;
@@ -92,6 +93,12 @@
         mBuffer = buffer;
     }
 
+    @VisibleForTesting
+    @NonNull
+    protected MessageQueue getMessageQueue() {
+        return mQueue;
+    }
+
     /** Start this FdEventsReader. */
     public boolean start() {
         if (!onCorrectThread()) {
@@ -185,7 +192,7 @@
 
         if (mFd == null) return false;
 
-        mQueue.addOnFileDescriptorEventListener(
+        getMessageQueue().addOnFileDescriptorEventListener(
                 mFd,
                 FD_EVENTS,
                 (fd, events) -> {
@@ -247,7 +254,7 @@
     private void unregisterAndDestroyFd() {
         if (mFd == null) return;
 
-        mQueue.removeOnFileDescriptorEventListener(mFd);
+        getMessageQueue().removeOnFileDescriptorEventListener(mFd);
         closeFd(mFd);
         mFd = null;
         onStop();
diff --git a/res/values/overlayable.xml b/res/values/overlayable.xml
index 717c5ca..052266d 100644
--- a/res/values/overlayable.xml
+++ b/res/values/overlayable.xml
@@ -20,6 +20,8 @@
             <item type="integer" name="config_captive_portal_dns_probe_timeout"/>
             <item type="string" name="config_captive_portal_http_url"/>
             <item type="string" name="config_captive_portal_https_url"/>
+            <item type="array" name="config_captive_portal_http_urls"/>
+            <item type="array" name="config_captive_portal_https_urls"/>
             <item type="array" name="config_captive_portal_fallback_urls"/>
             <item type="bool" name="config_no_sim_card_uses_neighbor_mcc"/>
             <!-- Configuration value for DhcpResults -->
diff --git a/src/android/net/ip/IpClient.java b/src/android/net/ip/IpClient.java
index bb5565d..9ad2c78 100644
--- a/src/android/net/ip/IpClient.java
+++ b/src/android/net/ip/IpClient.java
@@ -772,14 +772,6 @@
             return;
         }
 
-        mInterfaceParams = mDependencies.getInterfaceParams(mInterfaceName);
-        if (mInterfaceParams == null) {
-            logError("Failed to find InterfaceParams for " + mInterfaceName);
-            doImmediateProvisioningFailure(IpManagerEvent.ERROR_INTERFACE_NOT_FOUND);
-            return;
-        }
-
-        mCallback.setNeighborDiscoveryOffload(true);
         sendMessage(CMD_START, new android.net.shared.ProvisioningConfiguration(req));
     }
 
@@ -1650,6 +1642,17 @@
                 // tethering or during an IpClient restart.
                 stopAllIP();
             }
+
+            // Ensure that interface parameters are fetched on the handler thread so they are
+            // properly ordered with other events, such as restoring the interface MTU on teardown.
+            mInterfaceParams = mDependencies.getInterfaceParams(mInterfaceName);
+            if (mInterfaceParams == null) {
+                logError("Failed to find InterfaceParams for " + mInterfaceName);
+                doImmediateProvisioningFailure(IpManagerEvent.ERROR_INTERFACE_NOT_FOUND);
+                transitionTo(mStoppedState);
+                return;
+            }
+            mCallback.setNeighborDiscoveryOffload(true);
         }
 
         @Override
diff --git a/src/android/net/ip/IpReachabilityMonitor.java b/src/android/net/ip/IpReachabilityMonitor.java
index 17b1f3c..e1d4548 100644
--- a/src/android/net/ip/IpReachabilityMonitor.java
+++ b/src/android/net/ip/IpReachabilityMonitor.java
@@ -27,6 +27,7 @@
 import android.net.LinkProperties;
 import android.net.RouteInfo;
 import android.net.ip.IpNeighborMonitor.NeighborEvent;
+import android.net.ip.IpNeighborMonitor.NeighborEventConsumer;
 import android.net.metrics.IpConnectivityLog;
 import android.net.metrics.IpReachabilityEvent;
 import android.net.netlink.StructNdMsg;
@@ -154,11 +155,12 @@
     }
 
     /**
-     * Encapsulates IpReachabilityMonitor depencencies on systems that hinder unit testing.
+     * Encapsulates IpReachabilityMonitor dependencies on systems that hinder unit testing.
      * TODO: consider also wrapping MultinetworkPolicyTracker in this interface.
      */
     interface Dependencies {
         void acquireWakeLock(long durationMs);
+        IpNeighborMonitor makeIpNeighborMonitor(Handler h, SharedLog log, NeighborEventConsumer cb);
 
         static Dependencies makeDefault(Context context, String iface) {
             final String lockName = TAG + "." + iface;
@@ -169,6 +171,11 @@
                 public void acquireWakeLock(long durationMs) {
                     lock.acquire(durationMs);
                 }
+
+                public IpNeighborMonitor makeIpNeighborMonitor(Handler h, SharedLog log,
+                        NeighborEventConsumer cb) {
+                    return new IpNeighborMonitor(h, log, cb);
+                }
             };
         }
     }
@@ -223,7 +230,7 @@
         }
         setNeighbourParametersForSteadyState();
 
-        mIpNeighborMonitor = new IpNeighborMonitor(h, mLog,
+        mIpNeighborMonitor = mDependencies.makeIpNeighborMonitor(h, mLog,
                 (NeighborEvent event) -> {
                     if (mInterfaceParams.index != event.ifindex) return;
                     if (!mNeighborWatchList.containsKey(event.ip)) return;
diff --git a/tests/integration/src/android/net/ip/IpClientIntegrationTest.java b/tests/integration/src/android/net/ip/IpClientIntegrationTest.java
index 0fa6266..9f0ef99 100644
--- a/tests/integration/src/android/net/ip/IpClientIntegrationTest.java
+++ b/tests/integration/src/android/net/ip/IpClientIntegrationTest.java
@@ -130,6 +130,7 @@
 import com.android.server.NetworkObserverRegistry;
 import com.android.server.NetworkStackService.NetworkStackServiceManager;
 import com.android.server.connectivity.ipmemorystore.IpMemoryStoreService;
+import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo;
 import com.android.testutils.HandlerUtilsKt;
 import com.android.testutils.TapPacketReader;
 
@@ -524,7 +525,6 @@
         mDependencies.setHostnameConfiguration(isHostnameConfigurationEnabled, hostname);
         mIpc.setL2KeyAndGroupHint(TEST_L2KEY, TEST_GROUPHINT);
         mIpc.startProvisioning(builder.build());
-        verify(mCb).setNeighborDiscoveryOffload(true);
         if (!isPreconnectionEnabled) {
             verify(mCb, timeout(TEST_TIMEOUT_MS)).setFallbackMulticastFilter(false);
         }
@@ -642,7 +642,6 @@
                             ArgumentCaptor.forClass(LinkProperties.class);
                     verifyProvisioningSuccess(captor, Collections.singletonList(CLIENT_ADDR));
                 }
-
                 return packetList;
             }
         }
@@ -660,6 +659,12 @@
                 null /* captivePortalApiUrl */, null /* displayName */, null /* scanResultInfo */);
     }
 
+    private List<DhcpPacket> performDhcpHandshake() throws Exception {
+        return performDhcpHandshake(true /* isSuccessLease */, TEST_LEASE_DURATION_S,
+                false /* isDhcpLeaseCacheEnabled */, false /* shouldReplyRapidCommitAck */,
+                TEST_DEFAULT_MTU, false /* isDhcpIpConflictDetectEnabled */);
+    }
+
     private DhcpPacket getNextDhcpPacket() throws ParseException {
         byte[] packet;
         while ((packet = mPacketReader.popPacket(PACKET_TIMEOUT_MS)) != null) {
@@ -725,6 +730,12 @@
             assertEquals(NetworkInterface.getByName(mIfaceName).getMTU(), mtu);
         }
 
+        // Sometimes, IpClient receives an update with an empty LinkProperties during startup,
+        // when the link-local address is deleted after interface bringup. Reset expectations
+        // here to ensure that verifyAfterIpClientShutdown does not fail because it sees two
+        // empty LinkProperties changes instead of one.
+        reset(mCb);
+
         if (shouldRemoveTapInterface) removeTapInterface(mTapFd);
         try {
             mIpc.shutdown();
@@ -1047,6 +1058,14 @@
         assertTrue(packet instanceof DhcpDiscoverPacket);
     }
 
+    @Test @IgnoreUpTo(Build.VERSION_CODES.Q)
+    public void testDhcpServerInLinkProperties() throws Exception {
+        performDhcpHandshake();
+        ArgumentCaptor<LinkProperties> captor = ArgumentCaptor.forClass(LinkProperties.class);
+        verify(mCb, timeout(TEST_TIMEOUT_MS)).onProvisioningSuccess(captor.capture());
+        assertEquals(SERVER_ADDR, captor.getValue().getDhcpServerAddress());
+    }
+
     @Test
     public void testRestoreInitialInterfaceMtu() throws Exception {
         doRestoreInitialMtuTest(true /* shouldChangeMtu */, false /* shouldRemoveTapInterface */);
@@ -1081,10 +1100,37 @@
                 .build();
 
         mIpc.startProvisioning(config);
-        verify(mCb).onProvisioningFailure(any());
+        verify(mCb, timeout(TEST_TIMEOUT_MS)).onProvisioningFailure(any());
         verify(mCb, never()).setNeighborDiscoveryOffload(true);
     }
 
+    @Test
+    public void testRestoreInitialInterfaceMtu_stopIpClientAndRestart() throws Exception {
+        long currentTime = System.currentTimeMillis();
+
+        performDhcpHandshake(true /* isSuccessLease */, TEST_LEASE_DURATION_S,
+                true /* isDhcpLeaseCacheEnabled */, false /* shouldReplyRapidCommitAck */,
+                TEST_MIN_MTU, false /* isDhcpIpConflictDetectEnabled */);
+        assertIpMemoryStoreNetworkAttributes(TEST_LEASE_DURATION_S, currentTime, TEST_MIN_MTU);
+
+        // Pretend that ConnectivityService set the MTU.
+        mNetd.interfaceSetMtu(mIfaceName, TEST_MIN_MTU);
+        assertEquals(TEST_MIN_MTU, NetworkInterface.getByName(mIfaceName).getMTU());
+
+        reset(mCb);
+        reset(mIpMemoryStore);
+
+        // Stop IpClient and then restart provisioning immediately.
+        mIpc.stop();
+        currentTime = System.currentTimeMillis();
+        // Set the MTU option to 0 and verify that this does not affect the interface MTU restore.
+        performDhcpHandshake(true /* isSuccessLease */, TEST_LEASE_DURATION_S,
+                true /* isDhcpLeaseCacheEnabled */, false /* shouldReplyRapidCommitAck */,
+                0 /* mtu */, false /* isDhcpIpConflictDetectEnabled */);
+        assertIpMemoryStoreNetworkAttributes(TEST_LEASE_DURATION_S, currentTime, 0 /* mtu */);
+        assertEquals(TEST_DEFAULT_MTU, NetworkInterface.getByName(mIfaceName).getMTU());
+    }
+
     private boolean isRouterSolicitation(final byte[] packetBytes) {
         ByteBuffer packet = ByteBuffer.wrap(packetBytes);
         return packet.getShort(ETHER_TYPE_OFFSET) == (short) ETH_P_IPV6
diff --git a/tests/lib/src/com/android/testutils/DevSdkIgnoreRule.kt b/tests/lib/src/com/android/testutils/DevSdkIgnoreRule.kt
index fd7deb6..d30138d 100644
--- a/tests/lib/src/com/android/testutils/DevSdkIgnoreRule.kt
+++ b/tests/lib/src/com/android/testutils/DevSdkIgnoreRule.kt
@@ -27,8 +27,14 @@
  *
  * If the device is not using a release SDK, the development SDK is considered to be higher than
  * [Build.VERSION.SDK_INT].
+ *
+ * @param ignoreClassUpTo Skip all tests in the class if the device dev SDK is <= this value.
+ * @param ignoreClassAfter Skip all tests in the class if the device dev SDK is > this value.
  */
-class DevSdkIgnoreRule : TestRule {
+class DevSdkIgnoreRule @JvmOverloads constructor(
+    private val ignoreClassUpTo: Int? = null,
+    private val ignoreClassAfter: Int? = null
+) : TestRule {
     override fun apply(base: Statement, description: Description): Statement {
         return IgnoreBySdkStatement(base, description)
     }
@@ -49,7 +55,7 @@
      */
     annotation class IgnoreUpTo(val value: Int)
 
-    private class IgnoreBySdkStatement(
+    private inner class IgnoreBySdkStatement(
         private val base: Statement,
         private val description: Description
     ) : Statement() {
@@ -63,6 +69,8 @@
             val sdkInt = Build.VERSION.SDK_INT
             val devApiLevel = sdkInt + if (release) 0 else 1
             val message = "Skipping test for ${if (!release) "non-" else ""}release SDK $sdkInt"
+            assumeTrue(message, ignoreClassAfter == null || devApiLevel <= ignoreClassAfter)
+            assumeTrue(message, ignoreClassUpTo == null || devApiLevel > ignoreClassUpTo)
             assumeTrue(message, ignoreAfter == null || devApiLevel <= ignoreAfter.value)
             assumeTrue(message, ignoreUpTo == null || devApiLevel > ignoreUpTo.value)
             base.evaluate()
diff --git a/tests/lib/src/com/android/testutils/NetworkStatsUtils.kt b/tests/lib/src/com/android/testutils/NetworkStatsUtils.kt
index 51e9e4d..8324b25 100644
--- a/tests/lib/src/com/android/testutils/NetworkStatsUtils.kt
+++ b/tests/lib/src/com/android/testutils/NetworkStatsUtils.kt
@@ -29,16 +29,24 @@
     if (compareTime && leftStats.getElapsedRealtime() != rightStats.getElapsedRealtime()) {
         return false
     }
-    if (leftStats.size() != rightStats.size()) return false
+
+    // Operations such as add/subtract preserve empty entries, which makes the results hard
+    // to verify in tests. Remove them before comparing, since they do not affect
+    // correctness.
+    // TODO (b/152827872): Remove empty entries after addition/subtraction.
+    val leftTrimmedEmpty = leftStats.removeEmptyEntries()
+    val rightTrimmedEmpty = rightStats.removeEmptyEntries()
+
+    if (leftTrimmedEmpty.size() != rightTrimmedEmpty.size()) return false
     val left = NetworkStats.Entry()
     val right = NetworkStats.Entry()
     // Order insensitive compare.
-    for (i in 0 until leftStats.size()) {
-        leftStats.getValues(i, left)
-        val j: Int = rightStats.findIndexHinted(left.iface, left.uid, left.set, left.tag,
+    for (i in 0 until leftTrimmedEmpty.size()) {
+        leftTrimmedEmpty.getValues(i, left)
+        val j: Int = rightTrimmedEmpty.findIndexHinted(left.iface, left.uid, left.set, left.tag,
                 left.metered, left.roaming, left.defaultNetwork, i)
         if (j == -1) return false
-        rightStats.getValues(j, right)
+        rightTrimmedEmpty.getValues(j, right)
         if (left != right) return false
     }
     return true
@@ -58,5 +66,13 @@
     compareTime: Boolean = false
 ) {
     assertTrue(orderInsensitiveEquals(expected, actual, compareTime),
-            "expected: " + expected + "but was: " + actual)
+            "expected: " + expected + " but was: " + actual)
+}
+
+/**
+ * Assert that after being parceled then unparceled, [stats] is equal to the original object,
+ * comparing entries in an order-insensitive way.
+ */
+fun assertParcelingIsLossless(stats: NetworkStats) {
+    assertParcelingIsLossless(stats) { a, b -> orderInsensitiveEquals(a, b) }
 }
diff --git a/tests/lib/src/com/android/testutils/ParcelUtils.kt b/tests/lib/src/com/android/testutils/ParcelUtils.kt
index 9cbd053..5784f7c 100644
--- a/tests/lib/src/com/android/testutils/ParcelUtils.kt
+++ b/tests/lib/src/com/android/testutils/ParcelUtils.kt
@@ -18,13 +18,13 @@
 
 import android.os.Parcel
 import android.os.Parcelable
-import kotlin.test.assertEquals
+import kotlin.test.assertTrue
 import kotlin.test.fail
 
 /**
  * Return a new instance of `T` after being parceled then unparceled.
  */
-fun <T: Parcelable> parcelingRoundTrip(source: T): T {
+fun <T : Parcelable> parcelingRoundTrip(source: T): T {
     val creator: Parcelable.Creator<T>
     try {
         creator = source.javaClass.getField("CREATOR").get(null) as Parcelable.Creator<T>
@@ -46,13 +46,23 @@
 
 /**
  * Assert that after being parceled then unparceled, `source` is equal to the original
- * object.
+ * object. If [equals] is provided, it is used instead of `==` for the comparison.
  */
-fun <T: Parcelable> assertParcelingIsLossless(source: T) {
-    assertEquals(source, parcelingRoundTrip(source))
+@JvmOverloads
+fun <T : Parcelable> assertParcelingIsLossless(
+    source: T,
+    equals: (T, T) -> Boolean = { a, b -> a == b }
+) {
+    val actual = parcelingRoundTrip(source)
+    assertTrue(equals(source, actual), "Expected $source, but was $actual")
 }
 
-fun <T: Parcelable> assertParcelSane(obj: T, fieldCount: Int) {
+@JvmOverloads
+fun <T : Parcelable> assertParcelSane(
+    obj: T,
+    fieldCount: Int,
+    equals: (T, T) -> Boolean = { a, b -> a == b }
+) {
     assertFieldCountEquals(fieldCount, obj::class.java)
-    assertParcelingIsLossless(obj)
+    assertParcelingIsLossless(obj, equals)
 }
diff --git a/tests/unit/src/android/net/ip/IpClientTest.java b/tests/unit/src/android/net/ip/IpClientTest.java
index 4be5442..5fa5a6c 100644
--- a/tests/unit/src/android/net/ip/IpClientTest.java
+++ b/tests/unit/src/android/net/ip/IpClientTest.java
@@ -221,7 +221,7 @@
         final IpClient ipc = new IpClient(mContext, TEST_IFNAME, mCb, mObserverRegistry,
                 mNetworkStackServiceManager, mDependencies);
         ipc.startProvisioning(new ProvisioningConfiguration());
-        verify(mCb, times(1)).onProvisioningFailure(any());
+        verify(mCb, timeout(TEST_TIMEOUT_MS).times(1)).onProvisioningFailure(any());
         verify(mIpMemoryStore, never()).storeNetworkAttributes(any(), any(), any());
         ipc.shutdown();
     }
@@ -249,7 +249,7 @@
                 .build();
 
         ipc.startProvisioning(config);
-        verify(mCb, times(1)).setNeighborDiscoveryOffload(true);
+        verify(mCb, timeout(TEST_TIMEOUT_MS).times(1)).setNeighborDiscoveryOffload(true);
         verify(mCb, timeout(TEST_TIMEOUT_MS).times(1)).setFallbackMulticastFilter(false);
 
         final LinkProperties lp = makeIPv6ProvisionedLinkProperties();
@@ -364,7 +364,7 @@
                 .build();
 
         ipc.startProvisioning(config);
-        verify(mCb, times(1)).setNeighborDiscoveryOffload(true);
+        verify(mCb, timeout(TEST_TIMEOUT_MS).times(1)).setNeighborDiscoveryOffload(true);
         verify(mCb, timeout(TEST_TIMEOUT_MS).times(1)).setFallbackMulticastFilter(false);
         verify(mCb, never()).onProvisioningFailure(any());
         ipc.setL2KeyAndGroupHint(l2Key, groupHint);
diff --git a/tests/unit/src/android/net/ip/IpReachabilityMonitorTest.java b/tests/unit/src/android/net/ip/IpReachabilityMonitorTest.java
deleted file mode 100644
index ba3b306..0000000
--- a/tests/unit/src/android/net/ip/IpReachabilityMonitorTest.java
+++ /dev/null
@@ -1,70 +0,0 @@
-/*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package android.net.ip;
-
-import static org.mockito.Mockito.anyString;
-import static org.mockito.Mockito.when;
-
-import android.content.Context;
-import android.net.INetd;
-import android.net.util.InterfaceParams;
-import android.net.util.SharedLog;
-import android.os.Handler;
-import android.os.Looper;
-
-import androidx.test.filters.SmallTest;
-import androidx.test.runner.AndroidJUnit4;
-
-import org.junit.Before;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.mockito.Mock;
-import org.mockito.MockitoAnnotations;
-
-/**
- * Tests for IpReachabilityMonitor.
- */
-@RunWith(AndroidJUnit4.class)
-@SmallTest
-public class IpReachabilityMonitorTest {
-    @Mock IpReachabilityMonitor.Callback mCallback;
-    @Mock IpReachabilityMonitor.Dependencies mDependencies;
-    @Mock SharedLog mLog;
-    @Mock Context mContext;
-    @Mock INetd mNetd;
-    Handler mHandler;
-
-    @Before
-    public void setUp() {
-        MockitoAnnotations.initMocks(this);
-        when(mLog.forSubComponent(anyString())).thenReturn(mLog);
-        mHandler = new Handler(Looper.getMainLooper());
-    }
-
-    IpReachabilityMonitor makeMonitor() {
-        final InterfaceParams ifParams = new InterfaceParams("fake0", 1, null);
-        return new IpReachabilityMonitor(
-                mContext, ifParams, mHandler, mLog, mCallback, false, mDependencies, mNetd);
-    }
-
-    @Test
-    public void testNothing() {
-        // make sure the unit test runs in the same thread with main looper.
-        // Otherwise, throwing IllegalStateException would cause test fails.
-        mHandler.post(() -> makeMonitor());
-    }
-}
diff --git a/tests/unit/src/android/net/ip/IpReachabilityMonitorTest.kt b/tests/unit/src/android/net/ip/IpReachabilityMonitorTest.kt
new file mode 100644
index 0000000..ac50651
--- /dev/null
+++ b/tests/unit/src/android/net/ip/IpReachabilityMonitorTest.kt
@@ -0,0 +1,247 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package android.net.ip
+
+import android.content.Context
+import android.net.INetd
+import android.net.InetAddresses.parseNumericAddress
+import android.net.IpPrefix
+import android.net.LinkAddress
+import android.net.LinkProperties
+import android.net.RouteInfo
+import android.net.netlink.StructNdMsg.NUD_FAILED
+import android.net.netlink.StructNdMsg.NUD_STALE
+import android.net.netlink.makeNewNeighMessage
+import android.net.util.InterfaceParams
+import android.net.util.SharedLog
+import android.os.Handler
+import android.os.HandlerThread
+import android.os.MessageQueue
+import android.os.MessageQueue.OnFileDescriptorEventListener
+import android.system.ErrnoException
+import android.system.OsConstants.EAGAIN
+import androidx.test.filters.SmallTest
+import androidx.test.runner.AndroidJUnit4
+import org.junit.After
+import org.junit.Before
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.mockito.ArgumentCaptor
+import org.mockito.ArgumentMatchers.any
+import org.mockito.ArgumentMatchers.anyInt
+import org.mockito.ArgumentMatchers.anyString
+import org.mockito.ArgumentMatchers.eq
+import org.mockito.Mockito.doAnswer
+import org.mockito.Mockito.doReturn
+import org.mockito.Mockito.mock
+import org.mockito.Mockito.never
+import org.mockito.Mockito.timeout
+import org.mockito.Mockito.verify
+import java.io.FileDescriptor
+import java.net.Inet4Address
+import java.net.Inet6Address
+import java.net.InetAddress
+import java.util.concurrent.CompletableFuture
+import java.util.concurrent.ConcurrentLinkedQueue
+import java.util.concurrent.TimeUnit
+import kotlin.test.assertTrue
+import kotlin.test.fail
+
+private const val TEST_TIMEOUT_MS = 10_000L
+
+private val TEST_IPV4_GATEWAY = parseNumericAddress("192.168.222.3") as Inet4Address
+private val TEST_IPV6_GATEWAY = parseNumericAddress("2001:db8::1") as Inet6Address
+
+private val TEST_IPV4_LINKADDR = LinkAddress("192.168.222.123/24")
+private val TEST_IPV6_LINKADDR = LinkAddress("2001:db8::123/64")
+
+// DNS servers inside the IP prefixes of the link addresses above
+private val TEST_IPV4_DNS = parseNumericAddress("192.168.222.1") as Inet4Address
+private val TEST_IPV6_DNS = parseNumericAddress("2001:db8::321") as Inet6Address
+
+private val TEST_IFACE = InterfaceParams("fake0", 21, null)
+private val TEST_LINK_PROPERTIES = LinkProperties().apply {
+    interfaceName = TEST_IFACE.name
+    addLinkAddress(TEST_IPV4_LINKADDR)
+    addLinkAddress(TEST_IPV6_LINKADDR)
+
+    // Add on-link routes
+    addRoute(RouteInfo(TEST_IPV4_LINKADDR, null /* gateway */, TEST_IFACE.name))
+    addRoute(RouteInfo(TEST_IPV6_LINKADDR, null /* gateway */, TEST_IFACE.name))
+
+    // Add default routes
+    addRoute(RouteInfo(IpPrefix(parseNumericAddress("0.0.0.0"), 0), TEST_IPV4_GATEWAY))
+    addRoute(RouteInfo(IpPrefix(parseNumericAddress("::"), 0), TEST_IPV6_GATEWAY))
+
+    addDnsServer(TEST_IPV4_DNS)
+    addDnsServer(TEST_IPV6_DNS)
+}
+
+/**
+ * Tests for [IpReachabilityMonitor] using injected netlink neighbor messages.
+ */
+@RunWith(AndroidJUnit4::class)
+@SmallTest
+class IpReachabilityMonitorTest {
+    private val callback = mock(IpReachabilityMonitor.Callback::class.java)
+    private val dependencies = mock(IpReachabilityMonitor.Dependencies::class.java)
+    private val log = mock(SharedLog::class.java)
+    private val context = mock(Context::class.java)
+    private val netd = mock(INetd::class.java)
+    private val fd = mock(FileDescriptor::class.java)
+
+    private val handlerThread = HandlerThread(IpReachabilityMonitorTest::class.simpleName)
+    private val handler by lazy { Handler(handlerThread.looper) }
+
+    private lateinit var reachabilityMonitor: IpReachabilityMonitor
+    private lateinit var neighborMonitor: TestIpNeighborMonitor
+
+    /**
+     * A version of [IpNeighborMonitor] that overrides packet reading from a socket, and instead
+     * allows the test to enqueue test packets via [enqueuePacket].
+     */
+    private class TestIpNeighborMonitor(
+        handler: Handler,
+        log: SharedLog,
+        cb: NeighborEventConsumer,
+        private val fd: FileDescriptor
+    ) : IpNeighborMonitor(handler, log, cb) {
+
+        private val pendingPackets = ConcurrentLinkedQueue<ByteArray>()
+        val msgQueue = mock(MessageQueue::class.java)
+
+        private var eventListener: OnFileDescriptorEventListener? = null
+
+        override fun createFd() = fd
+        override fun getMessageQueue() = msgQueue
+
+        fun enqueuePacket(packet: ByteArray) {
+            val listener = eventListener ?: fail("IpNeighborMonitor was not yet started")
+            pendingPackets.add(packet)
+            handler.post {
+                listener.onFileDescriptorEvents(fd, OnFileDescriptorEventListener.EVENT_INPUT)
+            }
+        }
+
+        override fun readPacket(fd: FileDescriptor, packetBuffer: ByteArray): Int {
+            val packet = pendingPackets.poll() ?: throw ErrnoException("No pending packet", EAGAIN)
+            if (packet.size > packetBuffer.size) {
+                fail("Buffer (${packetBuffer.size}) is too small for packet (${packet.size})")
+            }
+            System.arraycopy(packet, 0, packetBuffer, 0, packet.size)
+            return packet.size
+        }
+
+        override fun onStart() {
+            super.onStart()
+
+            // Find the file descriptor listener that was registered on the instrumented queue
+            val captor = ArgumentCaptor.forClass(OnFileDescriptorEventListener::class.java)
+            verify(msgQueue).addOnFileDescriptorEventListener(
+                    eq(fd), anyInt(), captor.capture())
+            eventListener = captor.value
+        }
+    }
+
+    @Before
+    fun setUp() {
+        doReturn(log).`when`(log).forSubComponent(anyString())
+        doReturn(true).`when`(fd).valid()
+        handlerThread.start()
+
+        doAnswer { inv ->
+            val handler = inv.getArgument<Handler>(0)
+            val log = inv.getArgument<SharedLog>(1)
+            val cb = inv.getArgument<IpNeighborMonitor.NeighborEventConsumer>(2)
+            neighborMonitor = TestIpNeighborMonitor(handler, log, cb, fd)
+            neighborMonitor
+        }.`when`(dependencies).makeIpNeighborMonitor(any(), any(), any())
+
+        val monitorFuture = CompletableFuture<IpReachabilityMonitor>()
+        // IpReachabilityMonitor needs to be started from the handler thread
+        handler.post {
+            monitorFuture.complete(IpReachabilityMonitor(
+                    context,
+                    TEST_IFACE,
+                    handler,
+                    log,
+                    callback,
+                    false /* useMultinetworkPolicyTracker */,
+                    dependencies,
+                    netd))
+        }
+        reachabilityMonitor = monitorFuture.get(TEST_TIMEOUT_MS, TimeUnit.MILLISECONDS)
+        assertTrue(::neighborMonitor.isInitialized,
+                "IpReachabilityMonitor did not call makeIpNeighborMonitor")
+    }
+
+    @After
+    fun tearDown() {
+        doReturn(false).`when`(fd).valid()
+        handlerThread.quitSafely()
+    }
+
+    // TODO: this test documents a bug (a failed first probe stops IpNeighborMonitor); fix it.
+    @Test
+    fun testLoseProvisioning_CrashIfFirstProbeIsFailed() {
+        reachabilityMonitor.updateLinkProperties(TEST_LINK_PROPERTIES)
+
+        doAnswer {
+            // Set the fd as invalid when the event listener is removed, to avoid a crash when the
+            // reader tries to close the mock fd.
+            // This does not exactly reflect behavior on close, but this test is only demonstrating
+            // a bug that causes the close, and it will be removed once the bug is fixed.
+            doReturn(false).`when`(fd).valid()
+        }.`when`(neighborMonitor.msgQueue).removeOnFileDescriptorEventListener(any())
+
+        neighborMonitor.enqueuePacket(makeNewNeighMessage(TEST_IPV4_DNS, NUD_FAILED))
+        verify(neighborMonitor.msgQueue, timeout(TEST_TIMEOUT_MS))
+                .removeOnFileDescriptorEventListener(any())
+        verify(callback, never()).notifyLost(eq(TEST_IPV4_DNS), anyString())
+    }
+
+    private fun runLoseProvisioningTest(lostNeighbor: InetAddress) {
+        reachabilityMonitor.updateLinkProperties(TEST_LINK_PROPERTIES)
+
+        neighborMonitor.enqueuePacket(makeNewNeighMessage(TEST_IPV4_GATEWAY, NUD_STALE))
+        neighborMonitor.enqueuePacket(makeNewNeighMessage(TEST_IPV6_GATEWAY, NUD_STALE))
+        neighborMonitor.enqueuePacket(makeNewNeighMessage(TEST_IPV4_DNS, NUD_STALE))
+        neighborMonitor.enqueuePacket(makeNewNeighMessage(TEST_IPV6_DNS, NUD_STALE))
+
+        neighborMonitor.enqueuePacket(makeNewNeighMessage(lostNeighbor, NUD_FAILED))
+        verify(callback, timeout(TEST_TIMEOUT_MS)).notifyLost(eq(lostNeighbor), anyString())
+    }
+
+    @Test
+    fun testLoseProvisioning_Ipv4DnsLost() {
+        runLoseProvisioningTest(TEST_IPV4_DNS)
+    }
+
+    @Test
+    fun testLoseProvisioning_Ipv6DnsLost() {
+        runLoseProvisioningTest(TEST_IPV6_DNS)
+    }
+
+    @Test
+    fun testLoseProvisioning_Ipv4GatewayLost() {
+        runLoseProvisioningTest(TEST_IPV4_GATEWAY)
+    }
+
+    @Test
+    fun testLoseProvisioning_Ipv6GatewayLost() {
+        runLoseProvisioningTest(TEST_IPV6_GATEWAY)
+    }
+}
\ No newline at end of file
diff --git a/tests/unit/src/android/net/netlink/NetlinkTestUtils.kt b/tests/unit/src/android/net/netlink/NetlinkTestUtils.kt
new file mode 100644
index 0000000..6655e96
--- /dev/null
+++ b/tests/unit/src/android/net/netlink/NetlinkTestUtils.kt
@@ -0,0 +1,100 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.net.netlink
+
+import android.net.netlink.NetlinkConstants.RTM_DELNEIGH
+import android.net.netlink.NetlinkConstants.RTM_NEWNEIGH
+import libcore.util.HexEncoding
+import libcore.util.HexEncoding.encodeToString
+import java.net.Inet6Address
+import java.net.InetAddress
+
+/**
+ * Make a RTM_NEWNEIGH netlink message for [neighAddr] with NUD state [nudState].
+ */
+fun makeNewNeighMessage(
+    neighAddr: InetAddress,
+    nudState: Short
+): ByteArray = makeNeighborMessage(
+        neighAddr = neighAddr,
+        type = RTM_NEWNEIGH,
+        nudState = nudState
+)
+
+/**
+ * Make a RTM_DELNEIGH netlink message for [neighAddr] with NUD state [nudState].
+ */
+fun makeDelNeighMessage(
+    neighAddr: InetAddress,
+    nudState: Short
+): ByteArray = makeNeighborMessage(
+        neighAddr = neighAddr,
+        type = RTM_DELNEIGH,
+        nudState = nudState
+)
+
+/** Build an RTM_*NEIGH message of [type] for [neighAddr] with NUD state [nudState]. */
+private fun makeNeighborMessage(
+    neighAddr: InetAddress,
+    type: Short,
+    nudState: Short
+): ByteArray {
+    val isIpv6 = neighAddr is Inet6Address
+    val hex = buildString {
+        // -- struct nlmsghdr -- length: 88 (IPv6) or 76 (IPv4)
+        append(if (isIpv6) "58000000" else "4c000000")
+        append(type.toLEHex())          // type
+        append("0000")                  // flags
+        append("00000000")              // seqno
+        append("00000000")              // pid (0 == kernel)
+        // -- struct ndmsg -- family: AF_INET6 or AF_INET
+        append(if (isIpv6) "0a" else "02")
+        append("00")                    // pad1
+        append("0000")                  // pad2
+        append("15000000")              // interface index (21 == wlan0, on test device)
+        append(nudState.toLEHex())      // NUD state
+        append("00")                    // flags
+        append("01")                    // type
+        // -- struct nlattr: NDA_DST -- length: 20 (IPv6) or 8 (IPv4)
+        append(if (isIpv6) "1400" else "0800")
+        append("0100")                  // type (1 == NDA_DST, for neighbor messages)
+        append(encodeToString(neighAddr.address, false /* upperCase */)) // IP address
+        // -- struct nlattr: NDA_LLADDR --
+        append("0a00")                  // length = 10
+        append("0200")                  // type (2 == NDA_LLADDR, for neighbor messages)
+        append("00005e000164")          // MAC Address (== 00:00:5e:00:01:64)
+        append("0000")                  // padding, for 4 byte alignment
+        // -- struct nlattr: NDA_PROBES --
+        append("0800")                  // length = 8
+        append("0400")                  // type (4 == NDA_PROBES, for neighbor messages)
+        append("01000000")              // number of probes
+        // -- struct nlattr: NDA_CACHEINFO --
+        append("1400")                  // length = 20
+        append("0300")                  // type (3 == NDA_CACHEINFO, for neighbor messages)
+        append("05190000")              // ndm_used, as "clock ticks ago"
+        append("05190000")              // ndm_confirmed, as "clock ticks ago"
+        append("190d0000")              // ndm_updated, as "clock ticks ago"
+        append("00000000")              // ndm_refcnt
+    }
+    // Turn the concatenated hex fields into the raw message bytes.
+    return HexEncoding.decode(hex, false /* allowSingleChar */)
+}
+
+/**
+ * Render this [Short] as four hex digits in little-endian byte order (low byte first).
+ */
+private fun Short.toLEHex() = "%02x%02x".format(toInt() and 0xff, (toInt() shr 8) and 0xff)
diff --git a/tests/unit/src/android/net/netlink/RtNetlinkNeighborMessageTest.java b/tests/unit/src/android/net/netlink/RtNetlinkNeighborMessageTest.java
index 72e6bca..34257b8 100644
--- a/tests/unit/src/android/net/netlink/RtNetlinkNeighborMessageTest.java
+++ b/tests/unit/src/android/net/netlink/RtNetlinkNeighborMessageTest.java
@@ -16,10 +16,15 @@
 
 package android.net.netlink;
 
+import static android.net.netlink.NetlinkTestUtilsKt.makeDelNeighMessage;
+import static android.net.netlink.NetlinkTestUtilsKt.makeNewNeighMessage;
+import static android.net.netlink.StructNdMsg.NUD_STALE;
+
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
 
+import android.net.InetAddresses;
 import android.net.netlink.NetlinkConstants;
 import android.net.netlink.NetlinkMessage;
 import android.net.netlink.RtNetlinkNeighborMessage;
@@ -46,83 +51,11 @@
 public class RtNetlinkNeighborMessageTest {
     private final String TAG = "RtNetlinkNeighborMessageTest";
 
-    // Hexadecimal representation of packet capture.
-    public static final String RTM_DELNEIGH_HEX =
-            // struct nlmsghdr
-            "4c000000" +     // length = 76
-            "1d00" +         // type = 29 (RTM_DELNEIGH)
-            "0000" +         // flags
-            "00000000" +     // seqno
-            "00000000" +     // pid (0 == kernel)
-            // struct ndmsg
-            "02" +           // family
-            "00" +           // pad1
-            "0000" +         // pad2
-            "15000000" +     // interface index (21  == wlan0, on test device)
-            "0400" +         // NUD state (0x04 == NUD_STALE)
-            "00" +           // flags
-            "01" +           // type
-            // struct nlattr: NDA_DST
-            "0800" +         // length = 8
-            "0100" +         // type (1 == NDA_DST, for neighbor messages)
-            "c0a89ffe" +     // IPv4 address (== 192.168.159.254)
-            // struct nlattr: NDA_LLADDR
-            "0a00" +         // length = 10
-            "0200" +         // type (2 == NDA_LLADDR, for neighbor messages)
-            "00005e000164" + // MAC Address (== 00:00:5e:00:01:64)
-            "0000" +         // padding, for 4 byte alignment
-            // struct nlattr: NDA_PROBES
-            "0800" +         // length = 8
-            "0400" +         // type (4 == NDA_PROBES, for neighbor messages)
-            "01000000" +     // number of probes
-            // struct nlattr: NDA_CACHEINFO
-            "1400" +         // length = 20
-            "0300" +         // type (3 == NDA_CACHEINFO, for neighbor messages)
-            "05190000" +     // ndm_used, as "clock ticks ago"
-            "05190000" +     // ndm_confirmed, as "clock ticks ago"
-            "190d0000" +     // ndm_updated, as "clock ticks ago"
-            "00000000";      // ndm_refcnt
-    public static final byte[] RTM_DELNEIGH =
-            HexEncoding.decode(RTM_DELNEIGH_HEX.toCharArray(), false);
+    public static final byte[] RTM_DELNEIGH = makeDelNeighMessage(  // IPv4 neighbor, NUD_STALE
+            InetAddresses.parseNumericAddress("192.168.159.254"), NUD_STALE);
 
-    // Hexadecimal representation of packet capture.
-    public static final String RTM_NEWNEIGH_HEX =
-            // struct nlmsghdr
-            "58000000" +     // length = 88
-            "1c00" +         // type = 28 (RTM_NEWNEIGH)
-            "0000" +         // flags
-            "00000000" +     // seqno
-            "00000000" +     // pid (0 == kernel)
-            // struct ndmsg
-            "0a" +           // family
-            "00" +           // pad1
-            "0000" +         // pad2
-            "15000000" +     // interface index (21  == wlan0, on test device)
-            "0400" +         // NUD state (0x04 == NUD_STALE)
-            "80" +           // flags
-            "01" +           // type
-            // struct nlattr: NDA_DST
-            "1400" +         // length = 20
-            "0100" +         // type (1 == NDA_DST, for neighbor messages)
-            "fe8000000000000086c9b2fffe6aed4b" + // IPv6 address (== fe80::86c9:b2ff:fe6a:ed4b)
-            // struct nlattr: NDA_LLADDR
-            "0a00" +         // length = 10
-            "0200" +         // type (2 == NDA_LLADDR, for neighbor messages)
-            "84c9b26aed4b" + // MAC Address (== 84:c9:b2:6a:ed:4b)
-            "0000" +         // padding, for 4 byte alignment
-            // struct nlattr: NDA_PROBES
-            "0800" +         // length = 8
-            "0400" +         // type (4 == NDA_PROBES, for neighbor messages)
-            "01000000" +     // number of probes
-            // struct nlattr: NDA_CACHEINFO
-            "1400" +         // length = 20
-            "0300" +         // type (3 == NDA_CACHEINFO, for neighbor messages)
-            "eb0e0000" +     // ndm_used, as "clock ticks ago"
-            "861f0000" +     // ndm_confirmed, as "clock ticks ago"
-            "00000000" +     // ndm_updated, as "clock ticks ago"
-            "05000000";      // ndm_refcnt
-    public static final byte[] RTM_NEWNEIGH =
-            HexEncoding.decode(RTM_NEWNEIGH_HEX.toCharArray(), false);
+    public static final byte[] RTM_NEWNEIGH = makeNewNeighMessage(  // link-local IPv6, NUD_STALE
+            InetAddresses.parseNumericAddress("fe80::86c9:b2ff:fe6a:ed4b"), NUD_STALE);
 
     // An example of the full response from an RTM_GETNEIGH query.
     private static final String RTM_GETNEIGH_RESPONSE_HEX =
@@ -165,7 +98,7 @@
         assertNotNull(ndmsgHdr);
         assertEquals((byte) OsConstants.AF_INET, ndmsgHdr.ndm_family);
         assertEquals(21, ndmsgHdr.ndm_ifindex);
-        assertEquals(StructNdMsg.NUD_STALE, ndmsgHdr.ndm_state);
+        assertEquals(NUD_STALE, ndmsgHdr.ndm_state);
         final InetAddress destination = neighMsg.getDestination();
         assertNotNull(destination);
         assertEquals(InetAddress.parseNumericAddress("192.168.159.254"), destination);
@@ -192,7 +125,7 @@
         assertNotNull(ndmsgHdr);
         assertEquals((byte) OsConstants.AF_INET6, ndmsgHdr.ndm_family);
         assertEquals(21, ndmsgHdr.ndm_ifindex);
-        assertEquals(StructNdMsg.NUD_STALE, ndmsgHdr.ndm_state);
+        assertEquals(NUD_STALE, ndmsgHdr.ndm_state);
         final InetAddress destination = neighMsg.getDestination();
         assertNotNull(destination);
         assertEquals(InetAddress.parseNumericAddress("fe80::86c9:b2ff:fe6a:ed4b"), destination);