Add unit tests for probing

Add tests for MdnsRecordRepository and MdnsInterfaceAdvertiser
implementations of probing.

Bug: 241738458
Test: atest
Change-Id: If41a387f14e805e81b6d0d8217d081ca053e340f
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsAnnouncer.java b/service/mdns/com/android/server/connectivity/mdns/MdnsAnnouncer.java
index c056e693a..7c84323 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsAnnouncer.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsAnnouncer.java
@@ -38,7 +38,8 @@
     @NonNull
     private final String mLogTag;
 
-    static class AnnouncementInfo implements MdnsPacketRepeater.Request {
+    /** Announcement request to send with {@link MdnsAnnouncer}. */
+    public static class AnnouncementInfo implements MdnsPacketRepeater.Request {
         @NonNull
         private final MdnsPacket mPacket;
 
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java b/service/mdns/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
index 5d45367..997dcbb 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
@@ -17,12 +17,16 @@
 package com.android.server.connectivity.mdns;
 
 import android.annotation.NonNull;
+import android.annotation.Nullable;
 import android.net.LinkAddress;
 import android.net.nsd.NsdServiceInfo;
 import android.os.Handler;
 import android.os.Looper;
 import android.util.Log;
 
+import com.android.internal.annotations.VisibleForTesting;
+import com.android.server.connectivity.mdns.MdnsPacketRepeater.PacketRepeaterCallback;
+
 import java.io.IOException;
 import java.util.List;
 
@@ -31,7 +35,8 @@
  */
 public class MdnsInterfaceAdvertiser {
     private static final boolean DBG = MdnsAdvertiser.DBG;
-    private static final long EXIT_ANNOUNCEMENT_DELAY_MS = 100L;
+    @VisibleForTesting
+    public static final long EXIT_ANNOUNCEMENT_DELAY_MS = 100L;
     @NonNull
     private final String mTag;
     @NonNull
@@ -84,7 +89,7 @@
      * Callbacks from {@link MdnsProber}.
      */
     private class ProbingCallback implements
-            MdnsPacketRepeater.PacketRepeaterCallback<MdnsProber.ProbingInfo> {
+            PacketRepeaterCallback<MdnsProber.ProbingInfo> {
         @Override
         public void onFinished(MdnsProber.ProbingInfo info) {
             final MdnsAnnouncer.AnnouncementInfo announcementInfo;
@@ -109,23 +114,64 @@
      * Callbacks from {@link MdnsAnnouncer}.
      */
     private class AnnouncingCallback
-            implements MdnsPacketRepeater.PacketRepeaterCallback<MdnsAnnouncer.AnnouncementInfo> {
+            implements PacketRepeaterCallback<MdnsAnnouncer.AnnouncementInfo> {
         // TODO: implement
     }
 
+    /**
+     * Dependencies for {@link MdnsInterfaceAdvertiser}, useful for testing.
+     */
+    @VisibleForTesting
+    public static class Dependencies {
+        /** @see MdnsRecordRepository */
+        @NonNull
+        public MdnsRecordRepository makeRecordRepository(@NonNull Looper looper) {
+            return new MdnsRecordRepository(looper);
+        }
+
+        /** @see MdnsReplySender */
+        @NonNull
+        public MdnsReplySender makeReplySender(@NonNull Looper looper,
+                @NonNull MdnsInterfaceSocket socket, @NonNull byte[] packetCreationBuffer) {
+            return new MdnsReplySender(looper, socket, packetCreationBuffer);
+        }
+
+        /** @see MdnsAnnouncer */
+        public MdnsAnnouncer makeMdnsAnnouncer(@NonNull String interfaceTag, @NonNull Looper looper,
+                @NonNull MdnsReplySender replySender,
+                @Nullable PacketRepeaterCallback<MdnsAnnouncer.AnnouncementInfo> cb) {
+            return new MdnsAnnouncer(interfaceTag, looper, replySender, cb);
+        }
+
+        /** @see MdnsProber */
+        public MdnsProber makeMdnsProber(@NonNull String interfaceTag, @NonNull Looper looper,
+                @NonNull MdnsReplySender replySender,
+                @NonNull PacketRepeaterCallback<MdnsProber.ProbingInfo> cb) {
+            return new MdnsProber(interfaceTag, looper, replySender, cb);
+        }
+    }
+
     public MdnsInterfaceAdvertiser(@NonNull String logTag,
             @NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> initialAddresses,
             @NonNull Looper looper, @NonNull byte[] packetCreationBuffer, @NonNull Callback cb) {
+        this(logTag, socket, initialAddresses, looper, packetCreationBuffer, cb,
+                new Dependencies());
+    }
+
+    public MdnsInterfaceAdvertiser(@NonNull String logTag,
+            @NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> initialAddresses,
+            @NonNull Looper looper, @NonNull byte[] packetCreationBuffer, @NonNull Callback cb,
+            @NonNull Dependencies deps) {
         mTag = MdnsInterfaceAdvertiser.class.getSimpleName() + "/" + logTag;
-        mRecordRepository = new MdnsRecordRepository(looper);
+        mRecordRepository = deps.makeRecordRepository(looper);
         mRecordRepository.updateAddresses(initialAddresses);
         mSocket = socket;
         mCb = cb;
         mCbHandler = new Handler(looper);
-        mReplySender = new MdnsReplySender(looper, socket, packetCreationBuffer);
-        mAnnouncer = new MdnsAnnouncer(logTag, looper, mReplySender,
+        mReplySender = deps.makeReplySender(looper, socket, packetCreationBuffer);
+        mAnnouncer = deps.makeMdnsAnnouncer(logTag, looper, mReplySender,
                 mAnnouncingCallback);
-        mProber = new MdnsProber(logTag, looper, mReplySender, mProbingCallback);
+        mProber = deps.makeMdnsProber(logTag, looper, mReplySender, mProbingCallback);
     }
 
     /**
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsProber.java b/service/mdns/com/android/server/connectivity/mdns/MdnsProber.java
index 9a1e62b..2cd9148 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsProber.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsProber.java
@@ -44,7 +44,8 @@
         mLogTag = MdnsProber.class.getSimpleName() + "/" + interfaceTag;
     }
 
-    static class ProbingInfo implements Request {
+    /** Probing request to send with {@link MdnsProber}. */
+    public static class ProbingInfo implements Request {
 
         private final int mServiceId;
         @NonNull
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
new file mode 100644
index 0000000..ad22305
--- /dev/null
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
@@ -0,0 +1,129 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.connectivity.mdns
+
+import android.net.InetAddresses.parseNumericAddress
+import android.net.LinkAddress
+import android.net.nsd.NsdServiceInfo
+import android.os.Build
+import android.os.HandlerThread
+import com.android.server.connectivity.mdns.MdnsAnnouncer.AnnouncementInfo
+import com.android.server.connectivity.mdns.MdnsInterfaceAdvertiser.EXIT_ANNOUNCEMENT_DELAY_MS
+import com.android.server.connectivity.mdns.MdnsPacketRepeater.PacketRepeaterCallback
+import com.android.server.connectivity.mdns.MdnsProber.ProbingInfo
+import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
+import com.android.testutils.DevSdkIgnoreRunner
+import com.android.testutils.waitForIdle
+import org.junit.After
+import org.junit.Before
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.mockito.ArgumentCaptor
+import org.mockito.Mockito.any
+import org.mockito.Mockito.anyInt
+import org.mockito.Mockito.doReturn
+import org.mockito.Mockito.mock
+import org.mockito.Mockito.verify
+
+private const val LOG_TAG = "testlogtag"
+private const val TIMEOUT_MS = 10_000L
+
+private val TEST_ADDRS = listOf(LinkAddress(parseNumericAddress("2001:db8::123"), 64))
+private val TEST_BUFFER = ByteArray(1300)
+
+private const val TEST_SERVICE_ID_1 = 42
+private val TEST_SERVICE_1 = NsdServiceInfo().apply {
+    serviceType = "_testservice._tcp"
+    serviceName = "MyTestService"
+    port = 12345
+}
+
+@RunWith(DevSdkIgnoreRunner::class)
+@IgnoreUpTo(Build.VERSION_CODES.S_V2)
+class MdnsInterfaceAdvertiserTest {
+    private val socket = mock(MdnsInterfaceSocket::class.java)
+    private val thread = HandlerThread(MdnsInterfaceAdvertiserTest::class.simpleName)
+    private val cb = mock(MdnsInterfaceAdvertiser.Callback::class.java)
+    private val deps = mock(MdnsInterfaceAdvertiser.Dependencies::class.java)
+    private val repository = mock(MdnsRecordRepository::class.java)
+    private val replySender = mock(MdnsReplySender::class.java)
+    private val announcer = mock(MdnsAnnouncer::class.java)
+    private val prober = mock(MdnsProber::class.java)
+    private val probeCbCaptor = ArgumentCaptor.forClass(PacketRepeaterCallback::class.java)
+            as ArgumentCaptor<PacketRepeaterCallback<ProbingInfo>>
+    private val announceCbCaptor = ArgumentCaptor.forClass(PacketRepeaterCallback::class.java)
+            as ArgumentCaptor<PacketRepeaterCallback<AnnouncementInfo>>
+
+    private val probeCb get() = probeCbCaptor.value
+    private val announceCb get() = announceCbCaptor.value
+
+    private val advertiser by lazy {
+        MdnsInterfaceAdvertiser(LOG_TAG, socket, TEST_ADDRS, thread.looper, TEST_BUFFER, cb, deps)
+    }
+
+    @Before
+    fun setUp() {
+        doReturn(repository).`when`(deps).makeRecordRepository(any())
+        doReturn(replySender).`when`(deps).makeReplySender(any(), any(), any())
+        doReturn(announcer).`when`(deps).makeMdnsAnnouncer(any(), any(), any(), any())
+        doReturn(prober).`when`(deps).makeMdnsProber(any(), any(), any(), any())
+
+        doReturn(-1).`when`(repository).addService(anyInt(), any())
+        thread.start()
+        advertiser.start()
+
+        verify(deps).makeMdnsProber(any(), any(), any(), probeCbCaptor.capture())
+        verify(deps).makeMdnsAnnouncer(any(), any(), any(), announceCbCaptor.capture())
+    }
+
+    @After
+    fun tearDown() {
+        thread.quitSafely()
+    }
+
+    @Test
+    fun testAddRemoveService() {
+        val testProbingInfo = mock(ProbingInfo::class.java)
+        doReturn(TEST_SERVICE_ID_1).`when`(testProbingInfo).serviceId
+        doReturn(testProbingInfo).`when`(repository).setServiceProbing(TEST_SERVICE_ID_1)
+
+        advertiser.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+        verify(repository).addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+        verify(prober).startProbing(testProbingInfo)
+
+        // Simulate probing success: continues to announcing
+        val testAnnouncementInfo = mock(AnnouncementInfo::class.java)
+        doReturn(testAnnouncementInfo).`when`(repository).onProbingSucceeded(testProbingInfo)
+        probeCb.onFinished(testProbingInfo)
+
+        verify(announcer).startSending(TEST_SERVICE_ID_1, testAnnouncementInfo,
+                0L /* initialDelayMs */)
+
+        thread.waitForIdle(TIMEOUT_MS)
+        verify(cb).onRegisterServiceSucceeded(advertiser, TEST_SERVICE_ID_1)
+
+        // Remove the service: expect exit announcements
+        val testExitInfo = mock(AnnouncementInfo::class.java)
+        doReturn(testExitInfo).`when`(repository).exitService(TEST_SERVICE_ID_1)
+        advertiser.removeService(TEST_SERVICE_ID_1)
+
+        verify(announcer).startSending(TEST_SERVICE_ID_1, testExitInfo, EXIT_ANNOUNCEMENT_DELAY_MS)
+
+        // TODO: after exit announcements are implemented, verify that announceCb.onFinished causes
+        // cb.onDestroyed to be called.
+    }
+}
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
new file mode 100644
index 0000000..502a36a
--- /dev/null
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
@@ -0,0 +1,127 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.connectivity.mdns
+
+import android.net.InetAddresses.parseNumericAddress
+import android.net.nsd.NsdServiceInfo
+import android.os.Build
+import android.os.HandlerThread
+import com.android.server.connectivity.mdns.MdnsRecordRepository.Dependencies
+import com.android.testutils.DevSdkIgnoreRule
+import com.android.testutils.DevSdkIgnoreRunner
+import java.net.NetworkInterface
+import java.util.Collections
+import kotlin.test.assertContentEquals
+import kotlin.test.assertEquals
+import kotlin.test.assertFailsWith
+import kotlin.test.assertNotNull
+import kotlin.test.assertTrue
+import org.junit.After
+import org.junit.Before
+import org.junit.Test
+import org.junit.runner.RunWith
+
+private const val TEST_SERVICE_ID_1 = 42
+private const val TEST_SERVICE_ID_2 = 43
+private const val TEST_PORT = 12345
+private val TEST_HOSTNAME = arrayOf("Android_000102030405060708090A0B0C0D0E0F", "local")
+private val TEST_ADDRESSES = arrayOf(
+        parseNumericAddress("192.0.2.111"),
+        parseNumericAddress("2001:db8::111"),
+        parseNumericAddress("2001:db8::222"))
+
+private val TEST_SERVICE_1 = NsdServiceInfo().apply {
+    serviceType = "_testservice._tcp"
+    serviceName = "MyTestService"
+    port = TEST_PORT
+}
+
+@RunWith(DevSdkIgnoreRunner::class)
+@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
+class MdnsRecordRepositoryTest {
+    private val thread = HandlerThread(MdnsRecordRepositoryTest::class.simpleName)
+    private val deps = object : Dependencies() {
+        override fun getHostname() = TEST_HOSTNAME
+        override fun getInterfaceInetAddresses(iface: NetworkInterface) =
+                Collections.enumeration(TEST_ADDRESSES.toList())
+    }
+
+    @Before
+    fun setUp() {
+        thread.start()
+    }
+
+    @After
+    fun tearDown() {
+        thread.quitSafely()
+    }
+
+    @Test
+    fun testAddServiceAndProbe() {
+        val repository = MdnsRecordRepository(thread.looper, deps)
+        assertEquals(0, repository.servicesCount)
+        assertEquals(-1, repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1))
+        assertEquals(1, repository.servicesCount)
+
+        val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
+        assertNotNull(probingInfo)
+        assertTrue(repository.isProbing(TEST_SERVICE_ID_1))
+
+        assertEquals(TEST_SERVICE_ID_1, probingInfo.serviceId)
+        val packet = probingInfo.getPacket(0)
+
+        assertEquals(MdnsConstants.FLAGS_QUERY, packet.flags)
+        assertEquals(0, packet.answers.size)
+        assertEquals(0, packet.additionalRecords.size)
+
+        assertEquals(1, packet.questions.size)
+        val expectedName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
+        assertEquals(MdnsAnyRecord(expectedName, false /* unicast */), packet.questions[0])
+
+        assertEquals(1, packet.authorityRecords.size)
+        assertEquals(MdnsServiceRecord(expectedName,
+                0L /* receiptTimeMillis */,
+                false /* cacheFlush */,
+                120_000L /* ttlMillis */,
+                0 /* servicePriority */, 0 /* serviceWeight */,
+                TEST_PORT, TEST_HOSTNAME), packet.authorityRecords[0])
+
+        assertContentEquals(intArrayOf(TEST_SERVICE_ID_1), repository.clearServices())
+    }
+
+    @Test
+    fun testAddAndConflicts() {
+        val repository = MdnsRecordRepository(thread.looper, deps)
+        repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+        assertFailsWith(NameConflictException::class) {
+            repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1)
+        }
+    }
+
+    @Test
+    fun testExitingServiceReAdded() {
+        val repository = MdnsRecordRepository(thread.looper, deps)
+        repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+        repository.exitService(TEST_SERVICE_ID_1)
+
+        assertEquals(TEST_SERVICE_ID_1, repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1))
+        assertEquals(1, repository.servicesCount)
+
+        repository.removeService(TEST_SERVICE_ID_2)
+        assertEquals(0, repository.servicesCount)
+    }
+}