blob: ad22305de759f0acf72c5c6523e7a354a7a965d3 [file] [log] [blame]
/*
* Copyright (C) 2023 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.server.connectivity.mdns
import android.net.InetAddresses.parseNumericAddress
import android.net.LinkAddress
import android.net.nsd.NsdServiceInfo
import android.os.Build
import android.os.HandlerThread
import com.android.server.connectivity.mdns.MdnsAnnouncer.AnnouncementInfo
import com.android.server.connectivity.mdns.MdnsInterfaceAdvertiser.EXIT_ANNOUNCEMENT_DELAY_MS
import com.android.server.connectivity.mdns.MdnsPacketRepeater.PacketRepeaterCallback
import com.android.server.connectivity.mdns.MdnsProber.ProbingInfo
import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
import com.android.testutils.DevSdkIgnoreRunner
import com.android.testutils.waitForIdle
import org.junit.After
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.mockito.ArgumentCaptor
import org.mockito.Mockito.any
import org.mockito.Mockito.anyInt
import org.mockito.Mockito.doReturn
import org.mockito.Mockito.mock
import org.mockito.Mockito.verify
private const val LOG_TAG = "testlogtag"
private const val TIMEOUT_MS = 10_000L
private val TEST_ADDRS = listOf(LinkAddress(parseNumericAddress("2001:db8::123"), 64))
private val TEST_BUFFER = ByteArray(1300)
private const val TEST_SERVICE_ID_1 = 42
private val TEST_SERVICE_1 = NsdServiceInfo().apply {
serviceType = "_testservice._tcp"
serviceName = "MyTestService"
port = 12345
}
@RunWith(DevSdkIgnoreRunner::class)
@IgnoreUpTo(Build.VERSION_CODES.S_V2)
class MdnsInterfaceAdvertiserTest {
private val socket = mock(MdnsInterfaceSocket::class.java)
private val thread = HandlerThread(MdnsInterfaceAdvertiserTest::class.simpleName)
private val cb = mock(MdnsInterfaceAdvertiser.Callback::class.java)
private val deps = mock(MdnsInterfaceAdvertiser.Dependencies::class.java)
private val repository = mock(MdnsRecordRepository::class.java)
private val replySender = mock(MdnsReplySender::class.java)
private val announcer = mock(MdnsAnnouncer::class.java)
private val prober = mock(MdnsProber::class.java)
private val probeCbCaptor = ArgumentCaptor.forClass(PacketRepeaterCallback::class.java)
as ArgumentCaptor<PacketRepeaterCallback<ProbingInfo>>
private val announceCbCaptor = ArgumentCaptor.forClass(PacketRepeaterCallback::class.java)
as ArgumentCaptor<PacketRepeaterCallback<AnnouncementInfo>>
private val probeCb get() = probeCbCaptor.value
private val announceCb get() = announceCbCaptor.value
private val advertiser by lazy {
MdnsInterfaceAdvertiser(LOG_TAG, socket, TEST_ADDRS, thread.looper, TEST_BUFFER, cb, deps)
}
@Before
fun setUp() {
doReturn(repository).`when`(deps).makeRecordRepository(any())
doReturn(replySender).`when`(deps).makeReplySender(any(), any(), any())
doReturn(announcer).`when`(deps).makeMdnsAnnouncer(any(), any(), any(), any())
doReturn(prober).`when`(deps).makeMdnsProber(any(), any(), any(), any())
doReturn(-1).`when`(repository).addService(anyInt(), any())
thread.start()
advertiser.start()
verify(deps).makeMdnsProber(any(), any(), any(), probeCbCaptor.capture())
verify(deps).makeMdnsAnnouncer(any(), any(), any(), announceCbCaptor.capture())
}
@After
fun tearDown() {
thread.quitSafely()
}
@Test
fun testAddRemoveService() {
val testProbingInfo = mock(ProbingInfo::class.java)
doReturn(TEST_SERVICE_ID_1).`when`(testProbingInfo).serviceId
doReturn(testProbingInfo).`when`(repository).setServiceProbing(TEST_SERVICE_ID_1)
advertiser.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
verify(repository).addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
verify(prober).startProbing(testProbingInfo)
// Simulate probing success: continues to announcing
val testAnnouncementInfo = mock(AnnouncementInfo::class.java)
doReturn(testAnnouncementInfo).`when`(repository).onProbingSucceeded(testProbingInfo)
probeCb.onFinished(testProbingInfo)
verify(announcer).startSending(TEST_SERVICE_ID_1, testAnnouncementInfo,
0L /* initialDelayMs */)
thread.waitForIdle(TIMEOUT_MS)
verify(cb).onRegisterServiceSucceeded(advertiser, TEST_SERVICE_ID_1)
// Remove the service: expect exit announcements
val testExitInfo = mock(AnnouncementInfo::class.java)
doReturn(testExitInfo).`when`(repository).exitService(TEST_SERVICE_ID_1)
advertiser.removeService(TEST_SERVICE_ID_1)
verify(announcer).startSending(TEST_SERVICE_ID_1, testExitInfo, EXIT_ANNOUNCEMENT_DELAY_MS)
// TODO: after exit announcements are implemented, verify that announceCb.onFinished causes
// cb.onDestroyed to be called.
}
}