Add unit tests for probing
Add tests for MdnsRecordRepository and MdnsInterfaceAdvertiser
implementations of probing.
Bug: 241738458
Test: atest
Change-Id: If41a387f14e805e81b6d0d8217d081ca053e340f
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsAnnouncer.java b/service/mdns/com/android/server/connectivity/mdns/MdnsAnnouncer.java
index c056e693a..7c84323 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsAnnouncer.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsAnnouncer.java
@@ -38,7 +38,8 @@
@NonNull
private final String mLogTag;
- static class AnnouncementInfo implements MdnsPacketRepeater.Request {
+ /** Announcement request to send with {@link MdnsAnnouncer}. */
+ public static class AnnouncementInfo implements MdnsPacketRepeater.Request {
@NonNull
private final MdnsPacket mPacket;
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java b/service/mdns/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
index 5d45367..997dcbb 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
@@ -17,12 +17,16 @@
package com.android.server.connectivity.mdns;
import android.annotation.NonNull;
+import android.annotation.Nullable;
import android.net.LinkAddress;
import android.net.nsd.NsdServiceInfo;
import android.os.Handler;
import android.os.Looper;
import android.util.Log;
+import com.android.internal.annotations.VisibleForTesting;
+import com.android.server.connectivity.mdns.MdnsPacketRepeater.PacketRepeaterCallback;
+
import java.io.IOException;
import java.util.List;
@@ -31,7 +35,8 @@
*/
public class MdnsInterfaceAdvertiser {
private static final boolean DBG = MdnsAdvertiser.DBG;
- private static final long EXIT_ANNOUNCEMENT_DELAY_MS = 100L;
+ @VisibleForTesting
+ public static final long EXIT_ANNOUNCEMENT_DELAY_MS = 100L;
@NonNull
private final String mTag;
@NonNull
@@ -84,7 +89,7 @@
* Callbacks from {@link MdnsProber}.
*/
private class ProbingCallback implements
- MdnsPacketRepeater.PacketRepeaterCallback<MdnsProber.ProbingInfo> {
+ PacketRepeaterCallback<MdnsProber.ProbingInfo> {
@Override
public void onFinished(MdnsProber.ProbingInfo info) {
final MdnsAnnouncer.AnnouncementInfo announcementInfo;
@@ -109,23 +114,64 @@
* Callbacks from {@link MdnsAnnouncer}.
*/
private class AnnouncingCallback
- implements MdnsPacketRepeater.PacketRepeaterCallback<MdnsAnnouncer.AnnouncementInfo> {
+ implements PacketRepeaterCallback<MdnsAnnouncer.AnnouncementInfo> {
// TODO: implement
}
+ /**
+ * Dependencies for {@link MdnsInterfaceAdvertiser}, useful for testing.
+ */
+ @VisibleForTesting
+ public static class Dependencies {
+ /** @see MdnsRecordRepository */
+ @NonNull
+ public MdnsRecordRepository makeRecordRepository(@NonNull Looper looper) {
+ return new MdnsRecordRepository(looper);
+ }
+
+ /** @see MdnsReplySender */
+ @NonNull
+ public MdnsReplySender makeReplySender(@NonNull Looper looper,
+ @NonNull MdnsInterfaceSocket socket, @NonNull byte[] packetCreationBuffer) {
+ return new MdnsReplySender(looper, socket, packetCreationBuffer);
+ }
+
+ /** @see MdnsAnnouncer */
+ public MdnsAnnouncer makeMdnsAnnouncer(@NonNull String interfaceTag, @NonNull Looper looper,
+ @NonNull MdnsReplySender replySender,
+ @Nullable PacketRepeaterCallback<MdnsAnnouncer.AnnouncementInfo> cb) {
+ return new MdnsAnnouncer(interfaceTag, looper, replySender, cb);
+ }
+
+ /** @see MdnsProber */
+ public MdnsProber makeMdnsProber(@NonNull String interfaceTag, @NonNull Looper looper,
+ @NonNull MdnsReplySender replySender,
+ @NonNull PacketRepeaterCallback<MdnsProber.ProbingInfo> cb) {
+ return new MdnsProber(interfaceTag, looper, replySender, cb);
+ }
+ }
+
public MdnsInterfaceAdvertiser(@NonNull String logTag,
@NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> initialAddresses,
@NonNull Looper looper, @NonNull byte[] packetCreationBuffer, @NonNull Callback cb) {
+ this(logTag, socket, initialAddresses, looper, packetCreationBuffer, cb,
+ new Dependencies());
+ }
+
+ public MdnsInterfaceAdvertiser(@NonNull String logTag,
+ @NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> initialAddresses,
+ @NonNull Looper looper, @NonNull byte[] packetCreationBuffer, @NonNull Callback cb,
+ @NonNull Dependencies deps) {
mTag = MdnsInterfaceAdvertiser.class.getSimpleName() + "/" + logTag;
- mRecordRepository = new MdnsRecordRepository(looper);
+ mRecordRepository = deps.makeRecordRepository(looper);
mRecordRepository.updateAddresses(initialAddresses);
mSocket = socket;
mCb = cb;
mCbHandler = new Handler(looper);
- mReplySender = new MdnsReplySender(looper, socket, packetCreationBuffer);
- mAnnouncer = new MdnsAnnouncer(logTag, looper, mReplySender,
+ mReplySender = deps.makeReplySender(looper, socket, packetCreationBuffer);
+ mAnnouncer = deps.makeMdnsAnnouncer(logTag, looper, mReplySender,
mAnnouncingCallback);
- mProber = new MdnsProber(logTag, looper, mReplySender, mProbingCallback);
+ mProber = deps.makeMdnsProber(logTag, looper, mReplySender, mProbingCallback);
}
/**
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsProber.java b/service/mdns/com/android/server/connectivity/mdns/MdnsProber.java
index 9a1e62b..2cd9148 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsProber.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsProber.java
@@ -44,7 +44,8 @@
mLogTag = MdnsProber.class.getSimpleName() + "/" + interfaceTag;
}
- static class ProbingInfo implements Request {
+ /** Probing request to send with {@link MdnsProber}. */
+ public static class ProbingInfo implements Request {
private final int mServiceId;
@NonNull
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
new file mode 100644
index 0000000..ad22305
--- /dev/null
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
@@ -0,0 +1,129 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.connectivity.mdns
+
+import android.net.InetAddresses.parseNumericAddress
+import android.net.LinkAddress
+import android.net.nsd.NsdServiceInfo
+import android.os.Build
+import android.os.HandlerThread
+import com.android.server.connectivity.mdns.MdnsAnnouncer.AnnouncementInfo
+import com.android.server.connectivity.mdns.MdnsInterfaceAdvertiser.EXIT_ANNOUNCEMENT_DELAY_MS
+import com.android.server.connectivity.mdns.MdnsPacketRepeater.PacketRepeaterCallback
+import com.android.server.connectivity.mdns.MdnsProber.ProbingInfo
+import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
+import com.android.testutils.DevSdkIgnoreRunner
+import com.android.testutils.waitForIdle
+import org.junit.After
+import org.junit.Before
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.mockito.ArgumentCaptor
+import org.mockito.Mockito.any
+import org.mockito.Mockito.anyInt
+import org.mockito.Mockito.doReturn
+import org.mockito.Mockito.mock
+import org.mockito.Mockito.verify
+
+private const val LOG_TAG = "testlogtag"
+private const val TIMEOUT_MS = 10_000L
+
+private val TEST_ADDRS = listOf(LinkAddress(parseNumericAddress("2001:db8::123"), 64))
+private val TEST_BUFFER = ByteArray(1300)
+
+private const val TEST_SERVICE_ID_1 = 42
+private val TEST_SERVICE_1 = NsdServiceInfo().apply {
+ serviceType = "_testservice._tcp"
+ serviceName = "MyTestService"
+ port = 12345
+}
+
+@RunWith(DevSdkIgnoreRunner::class)
+@IgnoreUpTo(Build.VERSION_CODES.S_V2)
+class MdnsInterfaceAdvertiserTest {
+ private val socket = mock(MdnsInterfaceSocket::class.java)
+ private val thread = HandlerThread(MdnsInterfaceAdvertiserTest::class.simpleName)
+ private val cb = mock(MdnsInterfaceAdvertiser.Callback::class.java)
+ private val deps = mock(MdnsInterfaceAdvertiser.Dependencies::class.java)
+ private val repository = mock(MdnsRecordRepository::class.java)
+ private val replySender = mock(MdnsReplySender::class.java)
+ private val announcer = mock(MdnsAnnouncer::class.java)
+ private val prober = mock(MdnsProber::class.java)
+ private val probeCbCaptor = ArgumentCaptor.forClass(PacketRepeaterCallback::class.java)
+ as ArgumentCaptor<PacketRepeaterCallback<ProbingInfo>>
+ private val announceCbCaptor = ArgumentCaptor.forClass(PacketRepeaterCallback::class.java)
+ as ArgumentCaptor<PacketRepeaterCallback<AnnouncementInfo>>
+
+ private val probeCb get() = probeCbCaptor.value
+ private val announceCb get() = announceCbCaptor.value
+
+ private val advertiser by lazy {
+ MdnsInterfaceAdvertiser(LOG_TAG, socket, TEST_ADDRS, thread.looper, TEST_BUFFER, cb, deps)
+ }
+
+ @Before
+ fun setUp() {
+ doReturn(repository).`when`(deps).makeRecordRepository(any())
+ doReturn(replySender).`when`(deps).makeReplySender(any(), any(), any())
+ doReturn(announcer).`when`(deps).makeMdnsAnnouncer(any(), any(), any(), any())
+ doReturn(prober).`when`(deps).makeMdnsProber(any(), any(), any(), any())
+
+ doReturn(-1).`when`(repository).addService(anyInt(), any())
+ thread.start()
+ advertiser.start()
+
+ verify(deps).makeMdnsProber(any(), any(), any(), probeCbCaptor.capture())
+ verify(deps).makeMdnsAnnouncer(any(), any(), any(), announceCbCaptor.capture())
+ }
+
+ @After
+ fun tearDown() {
+ thread.quitSafely()
+ }
+
+ @Test
+ fun testAddRemoveService() {
+ val testProbingInfo = mock(ProbingInfo::class.java)
+ doReturn(TEST_SERVICE_ID_1).`when`(testProbingInfo).serviceId
+ doReturn(testProbingInfo).`when`(repository).setServiceProbing(TEST_SERVICE_ID_1)
+
+ advertiser.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+ verify(repository).addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+ verify(prober).startProbing(testProbingInfo)
+
+ // Simulate probing success: continues to announcing
+ val testAnnouncementInfo = mock(AnnouncementInfo::class.java)
+ doReturn(testAnnouncementInfo).`when`(repository).onProbingSucceeded(testProbingInfo)
+ probeCb.onFinished(testProbingInfo)
+
+ verify(announcer).startSending(TEST_SERVICE_ID_1, testAnnouncementInfo,
+ 0L /* initialDelayMs */)
+
+ thread.waitForIdle(TIMEOUT_MS)
+ verify(cb).onRegisterServiceSucceeded(advertiser, TEST_SERVICE_ID_1)
+
+ // Remove the service: expect exit announcements
+ val testExitInfo = mock(AnnouncementInfo::class.java)
+ doReturn(testExitInfo).`when`(repository).exitService(TEST_SERVICE_ID_1)
+ advertiser.removeService(TEST_SERVICE_ID_1)
+
+ verify(announcer).startSending(TEST_SERVICE_ID_1, testExitInfo, EXIT_ANNOUNCEMENT_DELAY_MS)
+
+ // TODO: after exit announcements are implemented, verify that announceCb.onFinished causes
+ // cb.onDestroyed to be called.
+ }
+}
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
new file mode 100644
index 0000000..502a36a
--- /dev/null
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
@@ -0,0 +1,127 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.connectivity.mdns
+
+import android.net.InetAddresses.parseNumericAddress
+import android.net.nsd.NsdServiceInfo
+import android.os.Build
+import android.os.HandlerThread
+import com.android.server.connectivity.mdns.MdnsRecordRepository.Dependencies
+import com.android.testutils.DevSdkIgnoreRule
+import com.android.testutils.DevSdkIgnoreRunner
+import java.net.NetworkInterface
+import java.util.Collections
+import kotlin.test.assertContentEquals
+import kotlin.test.assertEquals
+import kotlin.test.assertFailsWith
+import kotlin.test.assertNotNull
+import kotlin.test.assertTrue
+import org.junit.After
+import org.junit.Before
+import org.junit.Test
+import org.junit.runner.RunWith
+
+private const val TEST_SERVICE_ID_1 = 42
+private const val TEST_SERVICE_ID_2 = 43
+private const val TEST_PORT = 12345
+private val TEST_HOSTNAME = arrayOf("Android_000102030405060708090A0B0C0D0E0F", "local")
+private val TEST_ADDRESSES = arrayOf(
+ parseNumericAddress("192.0.2.111"),
+ parseNumericAddress("2001:db8::111"),
+ parseNumericAddress("2001:db8::222"))
+
+private val TEST_SERVICE_1 = NsdServiceInfo().apply {
+ serviceType = "_testservice._tcp"
+ serviceName = "MyTestService"
+ port = TEST_PORT
+}
+
+@RunWith(DevSdkIgnoreRunner::class)
+@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
+class MdnsRecordRepositoryTest {
+ private val thread = HandlerThread(MdnsRecordRepositoryTest::class.simpleName)
+ private val deps = object : Dependencies() {
+ override fun getHostname() = TEST_HOSTNAME
+ override fun getInterfaceInetAddresses(iface: NetworkInterface) =
+ Collections.enumeration(TEST_ADDRESSES.toList())
+ }
+
+ @Before
+ fun setUp() {
+ thread.start()
+ }
+
+ @After
+ fun tearDown() {
+ thread.quitSafely()
+ }
+
+ @Test
+ fun testAddServiceAndProbe() {
+ val repository = MdnsRecordRepository(thread.looper, deps)
+ assertEquals(0, repository.servicesCount)
+ assertEquals(-1, repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1))
+ assertEquals(1, repository.servicesCount)
+
+ val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
+ assertNotNull(probingInfo)
+ assertTrue(repository.isProbing(TEST_SERVICE_ID_1))
+
+ assertEquals(TEST_SERVICE_ID_1, probingInfo.serviceId)
+ val packet = probingInfo.getPacket(0)
+
+ assertEquals(MdnsConstants.FLAGS_QUERY, packet.flags)
+ assertEquals(0, packet.answers.size)
+ assertEquals(0, packet.additionalRecords.size)
+
+ assertEquals(1, packet.questions.size)
+ val expectedName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
+ assertEquals(MdnsAnyRecord(expectedName, false /* unicast */), packet.questions[0])
+
+ assertEquals(1, packet.authorityRecords.size)
+ assertEquals(MdnsServiceRecord(expectedName,
+ 0L /* receiptTimeMillis */,
+ false /* cacheFlush */,
+ 120_000L /* ttlMillis */,
+ 0 /* servicePriority */, 0 /* serviceWeight */,
+ TEST_PORT, TEST_HOSTNAME), packet.authorityRecords[0])
+
+ assertContentEquals(intArrayOf(TEST_SERVICE_ID_1), repository.clearServices())
+ }
+
+ @Test
+ fun testAddAndConflicts() {
+ val repository = MdnsRecordRepository(thread.looper, deps)
+ repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+ assertFailsWith(NameConflictException::class) {
+ repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1)
+ }
+ }
+
+ @Test
+ fun testExitingServiceReAdded() {
+ val repository = MdnsRecordRepository(thread.looper, deps)
+ repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+ repository.exitService(TEST_SERVICE_ID_1)
+
+ assertEquals(TEST_SERVICE_ID_1, repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1))
+ assertEquals(1, repository.servicesCount)
+
+ repository.removeService(TEST_SERVICE_ID_2)
+ assertEquals(0, repository.servicesCount)
+ }
+}