/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
import android.util.Log
import kotlin.reflect.KClass
import kotlin.test.assertEquals
import kotlin.test.assertNotNull
// Placeholder network (id -1) used by callback entries that carry no real network,
// such as CallbackEntry.Unavailable.
object NULL_NETWORK : Network(-1)

// Wildcard network (id -2). Expectations compare against this object by identity to
// mean "a callback for any network is acceptable".
object ANY_NETWORK : Network(-2)

// Function-call spelling of ANY_NETWORK.
fun anyNetwork(): ANY_NETWORK {
    return ANY_NETWORK
}
/**
 * A [NetworkCallback] that logs every callback it receives and records it as a
 * [CallbackEntry] in [history].
 */
open class RecorderCallback private constructor(
    private val backingRecord: ArrayTrackRecord<CallbackEntry>
) : NetworkCallback() {
    public constructor() : this(ArrayTrackRecord())
    // Reuses the backing record of |src| when one is given, otherwise starts a fresh record.
    protected constructor(src: RecorderCallback?) : this(src?.backingRecord ?: ArrayTrackRecord())

    private val TAG = this::class.simpleName

    sealed class CallbackEntry {
        // To get equals(), hashcode(), componentN() etc for free, the child classes of
        // this class are data classes. But while data classes can inherit from other classes,
        // they may only have visible members in the constructors, so they couldn't declare
        // a constructor with a non-val arg to pass to CallbackEntry. Instead, force all
        // subclasses to implement a `network' property, which can be done in a data class
        // constructor by specifying override.
        abstract val network: Network

        data class Available(override val network: Network) : CallbackEntry()
        data class CapabilitiesChanged(
            override val network: Network,
            val caps: NetworkCapabilities
        ) : CallbackEntry()
        data class LinkPropertiesChanged(
            override val network: Network,
            val lp: LinkProperties
        ) : CallbackEntry()
        data class Suspended(override val network: Network) : CallbackEntry()
        data class Resumed(override val network: Network) : CallbackEntry()
        data class Losing(override val network: Network, val maxMsToLive: Int) : CallbackEntry()
        data class Lost(override val network: Network) : CallbackEntry()
        // onUnavailable() carries no network, so the entry is always built with NULL_NETWORK.
        data class Unavailable private constructor(
            override val network: Network
        ) : CallbackEntry() {
            constructor() : this(NULL_NETWORK)
        }
        data class BlockedStatus(
            override val network: Network,
            val blocked: Boolean
        ) : CallbackEntry()
        data class BlockedStatusInt(
            override val network: Network,
            val blocked: Int
        ) : CallbackEntry()

        // Convenience constants for expecting a type
        companion object {
            val AVAILABLE = Available::class
            val NETWORK_CAPS_UPDATED = CapabilitiesChanged::class
            val LINK_PROPERTIES_CHANGED = LinkPropertiesChanged::class
            val SUSPENDED = Suspended::class
            val RESUMED = Resumed::class
            val LOSING = Losing::class
            val LOST = Lost::class
            val UNAVAILABLE = Unavailable::class
            val BLOCKED_STATUS = BlockedStatus::class
            val BLOCKED_STATUS_INT = BlockedStatusInt::class
        }
    }

    // Read head over the backing record; callbacks are added and polled through it.
    val history = backingRecord.newReadHead()
    // Current mark of |history|, as reported by the read head.
    val mark get() = history.mark

    override fun onAvailable(network: Network) {
        Log.d(TAG, "onAvailable $network")
        // Record the entry like every other callback; expect<Available>() depends on it.
        history.add(Available(network))
    }

    // PreCheck is not used in the tests today. For backward compatibility with existing tests that
    // expect the callbacks not to record this, do not listen to PreCheck here.

    override fun onCapabilitiesChanged(network: Network, caps: NetworkCapabilities) {
        Log.d(TAG, "onCapabilitiesChanged $network $caps")
        history.add(CapabilitiesChanged(network, caps))
    }

    override fun onLinkPropertiesChanged(network: Network, lp: LinkProperties) {
        Log.d(TAG, "onLinkPropertiesChanged $network $lp")
        history.add(LinkPropertiesChanged(network, lp))
    }

    override fun onBlockedStatusChanged(network: Network, blocked: Boolean) {
        Log.d(TAG, "onBlockedStatusChanged $network $blocked")
        history.add(BlockedStatus(network, blocked))
    }

    // Cannot do:
    // fun onBlockedStatusChanged(network: Network, blocked: Int) {
    // because on S, that needs to be "override fun", and on R, that cannot be "override fun".

    override fun onNetworkSuspended(network: Network) {
        Log.d(TAG, "onNetworkSuspended $network $network")
        history.add(Suspended(network))
    }

    override fun onNetworkResumed(network: Network) {
        Log.d(TAG, "$network onNetworkResumed $network")
        history.add(Resumed(network))
    }

    override fun onLosing(network: Network, maxMsToLive: Int) {
        Log.d(TAG, "onLosing $network $maxMsToLive")
        history.add(Losing(network, maxMsToLive))
    }

    override fun onLost(network: Network) {
        Log.d(TAG, "onLost $network")
        history.add(Lost(network))
    }

    override fun onUnavailable() {
        Log.d(TAG, "onUnavailable")
        history.add(Unavailable())
    }
}
// Default wait, in milliseconds, when a callback is expected to arrive.
private const val DEFAULT_TIMEOUT: Long = 30_000L
// Default wait, in milliseconds, when asserting that no callback arrives.
private const val DEFAULT_NO_CALLBACK_TIMEOUT: Long = 200L
// Do-nothing default for the waiter function.
private val NOOP: Runnable = Runnable { }
/**
 * See comments on the public constructor below for a description of the arguments.
 */
open class TestableNetworkCallback private constructor(
    src: TestableNetworkCallback?,
    val defaultTimeoutMs: Long = DEFAULT_TIMEOUT,
    val defaultNoCallbackTimeoutMs: Long = DEFAULT_NO_CALLBACK_TIMEOUT,
    val waiterFunc: Runnable = NOOP // "() -> Unit" would forbid calling with a void func from Java
) : RecorderCallback(src) {
    /**
     * Construct a testable network callback.
     * @param timeoutMs the default timeout for expecting a callback. Default 30 seconds. This
     *                  should be long in most cases, because the success case doesn't incur
     *                  the wait.
     * @param noCallbackTimeoutMs the timeout for expecting that no callback is received. Default
     *                            200ms. Because the success case does incur the timeout, this
     *                            should be short in most cases, but not so short as to frequently
     *                            time out before an incorrect callback is received.
     * @param waiterFunc a function to use before asserting no callback. For some specific tests,
     *                   it is useful to run test-specific code before asserting no callback to
     *                   increase the likelihood that a spurious callback is correctly detected.
     *                   As an example, a unit test using mock loopers may want to use this to
     *                   make sure the loopers are drained before asserting no callback, since
     *                   one of them may cause a callback to be called. @see ConnectivityServiceTest
     *                   for such an example.
     */
    // @JvmOverloads lets Java callers omit trailing arguments, as Kotlin callers can.
    @JvmOverloads
    constructor(
        timeoutMs: Long = DEFAULT_TIMEOUT,
        noCallbackTimeoutMs: Long = DEFAULT_NO_CALLBACK_TIMEOUT,
        waiterFunc: Runnable = NOOP
    ) : this(null, timeoutMs, noCallbackTimeoutMs, waiterFunc)
// Creates another callback with the same timeouts and waiter function, constructed from
// this instance so that it reuses this instance's backing record.
fun createLinkedCopy(): TestableNetworkCallback {
    return TestableNetworkCallback(
        src = this,
        defaultTimeoutMs = defaultTimeoutMs,
        defaultNoCallbackTimeoutMs = defaultNoCallbackTimeoutMs,
        waiterFunc = waiterFunc
    )
}
// The last available network, or null if any network was lost since the last call to
// onAvailable. TODO : fix this by fixing the tests that rely on this behavior
val lastAvailableNetwork: Network?
    get() = when (val it = history.lastOrNull { it is Available || it is Lost }) {
        // The most recent Available/Lost entry is an Available: report its network.
        is Available -> it.network
        // Either nothing was recorded yet, or the most recent such entry is a Lost.
        else -> null
    }
/**
 * Get the next callback or null if timeout.
 *
 * With no argument, this method waits out the default timeout. To wait forever, pass
 * Long.MAX_VALUE.
 *
 * @param timeoutMs how long to wait for a callback, in milliseconds.
 * @return the next recorded callback, or null if none arrived within the timeout.
 */
fun poll(timeoutMs: Long = defaultTimeoutMs): CallbackEntry? = history.poll(timeoutMs)
/**
 * Get the next callback or throw if timeout.
 *
 * With no argument, this method waits out the default timeout. To wait forever, pass
 * Long.MAX_VALUE.
 *
 * @param timeoutMs how long to wait for a callback, in milliseconds.
 * @param errorMsg the failure message used when no callback arrives in time.
 * @return the next recorded callback.
 */
fun pollOrThrow(
    timeoutMs: Long = defaultTimeoutMs,
    errorMsg: String = "Did not receive callback after $timeoutMs"
): CallbackEntry = poll(timeoutMs) ?: fail(errorMsg)
/**
 * expect family of methods.
 * These methods fetch the next callback and assert it matches the conditions : type,
 * passed predicate. If no callback is received within the timeout, these methods fail.
 *
 * @param type the expected class of the next callback.
 * @param network the network the callback must be for, or ANY_NETWORK to accept any.
 * @param timeoutMs how long to wait for the callback.
 * @param errorMsg message used when the predicate rejects the callback.
 * @param test predicate the callback must satisfy.
 * @return the received callback, typed as T.
 */
@JvmOverloads
fun <T : CallbackEntry> expect(
    type: KClass<T>,
    network: Network = ANY_NETWORK,
    timeoutMs: Long = defaultTimeoutMs,
    errorMsg: String? = null,
    test: (T) -> Boolean = { true }
) = expect<CallbackEntry>(network, timeoutMs, errorMsg) {
    // T is erased in this non-reified function, so a safe cast to T can never fail.
    // Check against the runtime class instead so a wrong callback type is reported here.
    if (!type.java.isInstance(it)) fail("Expected callback ${type.simpleName}, got $it")
    test(type.java.cast(it))
} as T
// Same as the Network version, taking the network from a HasNetwork holder.
@JvmOverloads
fun <T : CallbackEntry> expect(
    type: KClass<T>,
    network: HasNetwork,
    timeoutMs: Long = defaultTimeoutMs,
    errorMsg: String? = null,
    test: (T) -> Boolean = { true }
) = expect(type, network.network, timeoutMs, errorMsg, test)
// Java needs an explicit overload to let it omit arguments in the middle, so define these
// here. Note that @JvmOverloads give us the versions without the last arguments too, so
// there is no need to explicitly define versions without the test predicate.
// Without |network|
@JvmOverloads
fun <T : CallbackEntry> expect(
    type: KClass<T>,
    timeoutMs: Long,
    errorMsg: String?,
    test: (T) -> Boolean = { true }
) = expect(type, ANY_NETWORK, timeoutMs, errorMsg, test)

// Without |timeout|, in Network and HasNetwork versions
@JvmOverloads
fun <T : CallbackEntry> expect(
    type: KClass<T>,
    network: Network,
    errorMsg: String?,
    test: (T) -> Boolean = { true }
) = expect(type, network, defaultTimeoutMs, errorMsg, test)

@JvmOverloads
fun <T : CallbackEntry> expect(
    type: KClass<T>,
    network: HasNetwork,
    errorMsg: String?,
    test: (T) -> Boolean = { true }
) = expect(type, network.network, defaultTimeoutMs, errorMsg, test)

// Without |errorMsg|, in Network and HasNetwork versions
fun <T : CallbackEntry> expect(
    type: KClass<T>,
    network: Network,
    timeoutMs: Long,
    test: (T) -> Boolean
) = expect(type, network, timeoutMs, null, test)

fun <T : CallbackEntry> expect(
    type: KClass<T>,
    network: HasNetwork,
    timeoutMs: Long,
    test: (T) -> Boolean
) = expect(type, network.network, timeoutMs, null, test)

// Without |network| or |timeout|
@JvmOverloads
fun <T : CallbackEntry> expect(
    type: KClass<T>,
    errorMsg: String?,
    test: (T) -> Boolean = { true }
) = expect(type, ANY_NETWORK, defaultTimeoutMs, errorMsg, test)

// Without |network| or |errorMsg|
@JvmOverloads
fun <T : CallbackEntry> expect(
    type: KClass<T>,
    timeoutMs: Long,
    test: (T) -> Boolean = { true }
) = expect(type, ANY_NETWORK, timeoutMs, null, test)

// Without |timeout| or |errorMsg|, in Network and HasNetwork versions
fun <T : CallbackEntry> expect(
    type: KClass<T>,
    network: Network,
    test: (T) -> Boolean
) = expect(type, network, defaultTimeoutMs, null, test)

fun <T : CallbackEntry> expect(
    type: KClass<T>,
    network: HasNetwork,
    test: (T) -> Boolean
) = expect(type, network.network, defaultTimeoutMs, null, test)

// Without |network| or |timeout| or |errorMsg|
fun <T : CallbackEntry> expect(
    type: KClass<T>,
    test: (T) -> Boolean
) = expect(type, ANY_NETWORK, defaultTimeoutMs, null, test)
// Kotlin reified versions. Don't call methods above, or the predicate would need to be noinline
//
// Polls the next callback and fails if none arrives within |timeoutMs|, if it is not a T,
// if |network| is not ANY_NETWORK and differs from the callback's network, or if |test|
// rejects it. Returns the callback typed as T.
inline fun <reified T : CallbackEntry> expect(
    network: Network = ANY_NETWORK,
    timeoutMs: Long = defaultTimeoutMs,
    errorMsg: String? = null,
    test: (T) -> Boolean = { true }
) = pollOrThrow(timeoutMs, "Did not receive ${T::class.simpleName} after ${timeoutMs}ms").also {
    if (it !is T) fail("Expected callback ${T::class.simpleName}, got $it")
    // ANY_NETWORK is matched by identity and disables the network check.
    if (ANY_NETWORK !== network && it.network != network) {
        fail("Expected network $network for callback : $it")
    }
    if (!test(it)) {
        fail("${errorMsg ?: "Callback doesn't match predicate"} : $it")
    }
} as T
// Reified variant taking the expected network from a HasNetwork holder.
inline fun <reified T : CallbackEntry> expect(
    network: HasNetwork,
    timeoutMs: Long = defaultTimeoutMs,
    errorMsg: String? = null,
    test: (T) -> Boolean = { true }
) = expect(network.network, timeoutMs, errorMsg, test)
// Asserts that no callback accepted by |valid| is received within |timeoutMs|.
// Callbacks rejected by |valid| do not cause a failure.
fun assertNoCallback(
    timeoutMs: Long = defaultNoCallbackTimeoutMs,
    valid: (CallbackEntry) -> Boolean = { true }
) {
    // Run the test-supplied waiter first (a no-op by default) so that pending work that
    // could deliver a spurious callback gets a chance to do so before the check.
    waiterFunc.run()
    history.poll(timeoutMs) { valid(it) }?.let { fail("Expected no callback but got $it") }
}

// Overload without |timeoutMs|, using the default no-callback timeout.
fun assertNoCallback(valid: (CallbackEntry) -> Boolean) =
    assertNoCallback(defaultNoCallbackTimeoutMs, valid)
// Waits up to |timeoutMs| for a callback of type T that satisfies |predicate|, searching
// from position |from|. Callbacks that do not match are skipped rather than treated as
// errors; the only failure is when nothing matching shows up before the timeout.
inline fun <reified T : CallbackEntry> eventuallyExpect(
    timeoutMs: Long = defaultTimeoutMs,
    from: Int = mark,
    crossinline predicate: (T) -> Boolean = { true }
): T {
    val found = eventuallyExpectOrNull(timeoutMs, from, predicate)
    assertNotNull(found, "Callback ${T::class} not received within ${timeoutMs}ms")
    return found as T
}
// Non-reified variant: waits up to |timeoutMs| for a callback that is an instance of
// |type| and satisfies |predicate|, skipping the others. Fails if none is received.
fun <T : CallbackEntry> eventuallyExpect(
    type: KClass<T>,
    timeoutMs: Long = defaultTimeoutMs,
    predicate: (cb: T) -> Boolean = { true }
) = history.poll(timeoutMs) { type.java.isInstance(it) && predicate(it as T) }.also {
    assertNotNull(it, "Callback ${type.simpleName} not received within ${timeoutMs}ms")
} as T

// Same as above, but starts looking at position |from| of the history.
fun <T : CallbackEntry> eventuallyExpect(
    type: KClass<T>,
    timeoutMs: Long = defaultTimeoutMs,
    from: Int = mark,
    predicate: (cb: T) -> Boolean = { true }
) = history.poll(timeoutMs, from) { type.java.isInstance(it) && predicate(it as T) }.also {
    assertNotNull(it, "Callback ${type.simpleName} not received within ${timeoutMs}ms")
} as T
// TODO (b/157405399) straighten and unify the method names
// Polls the history from |from| for up to |timeoutMs| for a callback of type T accepted
// by |predicate|, and returns the poll result cast to T?.
inline fun <reified T : CallbackEntry> eventuallyExpectOrNull(
    timeoutMs: Long = defaultTimeoutMs,
    from: Int = mark,
    crossinline predicate: (T) -> Boolean = { true }
): T? {
    val match = history.poll(timeoutMs, from) { entry -> entry is T && predicate(entry) }
    return match as T?
}
// Expects the next callback to be a LinkPropertiesChanged for |net| whose LinkProperties
// are accepted by |valid|, and returns that callback.
inline fun expectLinkPropertiesThat(
    net: Network,
    tmt: Long = defaultTimeoutMs,
    valid: (LinkProperties) -> Boolean
): LinkPropertiesChanged {
    val mismatchMsg = "LinkProperties don't match expectations"
    return expect<LinkPropertiesChanged>(net, tmt, mismatchMsg) { change -> valid(change.lp) }
}
// Expects onAvailable and the callbacks that follow it. These are:
// - onSuspended, iff the network was suspended when the callbacks fire.
// - onCapabilitiesChanged.
// - onLinkPropertiesChanged.
// - onBlockedStatusChanged.
// @param network the network to expect the callbacks on.
// @param suspended whether to expect a SUSPENDED callback.
// @param validated the expected value of the VALIDATED capability in the
//        onCapabilitiesChanged callback. Null skips the check.
// @param blocked the expected value carried by the onBlockedStatusChanged callback.
// @param tmt how long to wait for the callbacks.
fun expectAvailableCallbacks(
    net: Network,
    suspended: Boolean = false,
    validated: Boolean? = true,
    blocked: Boolean = false,
    tmt: Long = defaultTimeoutMs
) {
    expectAvailableCallbacksCommon(net, suspended, validated, tmt)
    expectBlockedStatusCallback(blocked, net, tmt)
}

// Variant expecting the Int form of the blocked status.
fun expectAvailableCallbacks(
    net: Network,
    suspended: Boolean,
    validated: Boolean,
    blockedStatus: Int,
    tmt: Long
) {
    expectAvailableCallbacksCommon(net, suspended, validated, tmt)
    // Pass |tmt| through so the blocked-status expectation honors the caller's timeout
    // instead of silently falling back to the default one.
    expectBlockedStatusCallback(blockedStatus, net, tmt)
}
// Shared part of expectAvailableCallbacks : expects Available, then Suspended if
// |suspended|, then CapabilitiesChanged and LinkPropertiesChanged, all on |net|.
// A null |validated| accepts either value of the VALIDATED capability.
private fun expectAvailableCallbacksCommon(
    net: Network,
    suspended: Boolean,
    validated: Boolean?,
    tmt: Long
) {
    expect<Available>(net, tmt)
    if (suspended) {
        expect<Suspended>(net, tmt)
    }
    expect<CapabilitiesChanged>(net, tmt) {
        validated == null || validated == it.caps.hasCapability(NET_CAPABILITY_VALIDATED)
    }
    expect<LinkPropertiesChanged>(net, tmt)
}
// Kept only for existing Java callers. Prefer expectAvailableCallbacks with named
// arguments, and delete this once nothing calls it any more.
fun expectAvailableAndSuspendedCallbacks(
    net: Network,
    validated: Boolean,
    tmt: Long = defaultTimeoutMs
) {
    expectAvailableCallbacks(net = net, suspended = true, validated = validated, tmt = tmt)
}
// Expects the next callback to be a BlockedStatus for |net| carrying |blocked|.
fun expectBlockedStatusCallback(blocked: Boolean, net: Network, tmt: Long = defaultTimeoutMs) =
    expect<BlockedStatus>(net, tmt, "Unexpected blocked status") {
        it.blocked == blocked
    }

// Expects the next callback to be a BlockedStatusInt for |net| carrying |blocked|.
fun expectBlockedStatusCallback(blocked: Int, net: Network, tmt: Long = defaultTimeoutMs) =
    expect<BlockedStatusInt>(net, tmt, "Unexpected blocked status") {
        it.blocked == blocked
    }
// Expects the available callbacks (where the onCapabilitiesChanged must contain the
// VALIDATED capability), plus another onCapabilitiesChanged which is identical to the
// one we just sent.
// TODO: this is likely a bug. Fix it and remove this method.
fun expectAvailableDoubleValidatedCallbacks(net: Network, tmt: Long = defaultTimeoutMs) {
    // Remember where the history stood so the first CapabilitiesChanged can be looked up
    // again after expectAvailableCallbacks has consumed it. Named so as not to shadow the
    // |mark| property.
    val startMark = history.mark
    expectAvailableCallbacks(net, tmt = tmt)
    val firstCaps = history.poll(tmt, startMark) { it is CapabilitiesChanged }
    assertEquals(firstCaps, expect<CapabilitiesChanged>(net, tmt))
}
// Expects the available callbacks where the onCapabilitiesChanged must not have validated,
// then expects another onCapabilitiesChanged that has the validated bit set. This is used
// when a network connects and satisfies a callback, and then immediately validates.
fun expectAvailableThenValidatedCallbacks(net: Network, tmt: Long = defaultTimeoutMs) {
    expectAvailableCallbacks(net, validated = false, tmt = tmt)
    expectCaps(net, tmt) { it.hasCapability(NET_CAPABILITY_VALIDATED) }
}

// Same as above, for the Int form of the blocked status.
fun expectAvailableThenValidatedCallbacks(
    net: Network,
    blockedStatus: Int,
    tmt: Long = defaultTimeoutMs
) {
    expectAvailableCallbacks(net, validated = false, suspended = false,
            blockedStatus = blockedStatus, tmt = tmt)
    expectCaps(net, tmt) { it.hasCapability(NET_CAPABILITY_VALIDATED) }
}
// Temporary Java compat measure : have MockNetworkAgent implement this so that all existing
// calls with networkAgent can be routed through here without moving MockNetworkAgent.
// TODO: clean this up, remove this method.
interface HasNetwork {
    // The network that expectations taking this holder are checked against.
    val network: Network
}
// HasNetwork flavors of the expectAvailable* helpers. Each one unwraps the holder's
// network and delegates to the corresponding Network-based method.
fun expectAvailableCallbacks(
    n: HasNetwork,
    suspended: Boolean,
    validated: Boolean,
    blocked: Boolean,
    timeoutMs: Long
) = expectAvailableCallbacks(n.network, suspended, validated, blocked, timeoutMs)

fun expectAvailableAndSuspendedCallbacks(n: HasNetwork, expectValidated: Boolean) {
    expectAvailableAndSuspendedCallbacks(n.network, expectValidated)
}

// Relies on the defaults of expectAvailableCallbacks : validated, not suspended, not blocked.
fun expectAvailableCallbacksValidated(n: HasNetwork) {
    expectAvailableCallbacks(n.network)
}

fun expectAvailableCallbacksValidatedAndBlocked(n: HasNetwork) {
    expectAvailableCallbacks(n.network, blocked = true)
}

fun expectAvailableCallbacksUnvalidated(n: HasNetwork) {
    expectAvailableCallbacks(n.network, validated = false)
}

fun expectAvailableCallbacksUnvalidatedAndBlocked(n: HasNetwork) {
    expectAvailableCallbacks(n.network, validated = false, blocked = true)
}

fun expectAvailableDoubleValidatedCallbacks(n: HasNetwork) {
    expectAvailableDoubleValidatedCallbacks(n.network, defaultTimeoutMs)
}

fun expectAvailableThenValidatedCallbacks(n: HasNetwork) {
    expectAvailableThenValidatedCallbacks(n.network, defaultTimeoutMs)
}
// HasNetwork flavor : expects a LinkPropertiesChanged on the holder's network whose
// LinkProperties are accepted by |valid|.
fun expectLinkPropertiesThat(
    n: HasNetwork,
    tmt: Long = defaultTimeoutMs,
    valid: (LinkProperties) -> Boolean
) = expectLinkPropertiesThat(n.network, tmt, valid)
// expectCaps family : expect the next callback to be a CapabilitiesChanged whose
// capabilities are accepted by |valid|, and return those capabilities.

// On the holder's network, with an optional timeout and predicate.
fun expectCaps(
    n: HasNetwork,
    tmt: Long = defaultTimeoutMs,
    valid: (NetworkCapabilities) -> Boolean = { true }
) = expect<CapabilitiesChanged>(n.network, tmt) { valid(it.caps) }.caps

// On the given network.
fun expectCaps(
    n: Network,
    tmt: Long = defaultTimeoutMs,
    valid: (NetworkCapabilities) -> Boolean
) = expect<CapabilitiesChanged>(n, tmt) { valid(it.caps) }.caps

// On the holder's network, with the default timeout.
fun expectCaps(
    n: HasNetwork,
    valid: (NetworkCapabilities) -> Boolean
) = expect<CapabilitiesChanged>(n.network) { valid(it.caps) }.caps

// On any network.
fun expectCaps(
    tmt: Long,
    valid: (NetworkCapabilities) -> Boolean
) = expect<CapabilitiesChanged>(ANY_NETWORK, tmt) { valid(it.caps) }.caps
// HasNetwork flavor : expects a BlockedStatus callback on the holder's network carrying
// |expectBlocked|, waiting up to the default timeout.
fun expectBlockedStatusCallback(expectBlocked: Boolean, n: HasNetwork) {
    expectBlockedStatusCallback(expectBlocked, n.network, defaultTimeoutMs)
}