Add tests for TestableNetworkCallback

Test: this
Change-Id: I47b8f7c7584df0735ff9f46bee7ef10ec58e68c6
diff --git a/tests/lib/src/com/android/testutils/ConcurrentIntepreter.kt b/tests/lib/src/com/android/testutils/ConcurrentIntepreter.kt
index 7796dcd..25b1e0f 100644
--- a/tests/lib/src/com/android/testutils/ConcurrentIntepreter.kt
+++ b/tests/lib/src/com/android/testutils/ConcurrentIntepreter.kt
@@ -4,6 +4,7 @@
 import java.util.concurrent.CyclicBarrier
 import kotlin.system.measureTimeMillis
 import kotlin.test.assertEquals
+import kotlin.test.assertFails
 import kotlin.test.assertNull
 import kotlin.test.assertTrue
 
@@ -64,7 +65,15 @@
 
     // Spins as many threads as needed by the test spec and interpret each program concurrently,
     // having all threads waiting on a CyclicBarrier after each line.
-    fun interpretTestSpec(spec: String, initial: T, threadTransform: (T) -> T = { it }) {
+    // |lineShift| says how many lines after the call the spec starts. This is used for error
+    // reporting. Unfortunately AFAICT there is no way to get the line of an argument rather
+    // than the line at which the expression starts.
+    fun interpretTestSpec(
+        spec: String,
+        initial: T,
+        lineShift: Int = 0,
+        threadTransform: (T) -> T = { it }
+    ) {
         // For nice stack traces
         val callSite = getCallingMethod()
         val lines = spec.trim().trim('\n').split("\n").map { it.split("|") }
@@ -91,7 +100,8 @@
                         // testing. Instead, catch the exception, cancel other threads, and report
                         // nicely. Catch throwable because fail() is AssertionError, which inherits
                         // from Error.
-                        crash = InterpretException(threadIndex, it, callSite.lineNumber + lineNum,
+                        crash = InterpretException(threadIndex, it,
+                                callSite.lineNumber + lineNum + lineShift,
                                 callSite.className, callSite.methodName, callSite.fileName, e)
                     }
                     barrier.await()
@@ -147,6 +157,9 @@
     // Interpret sleep. Optional argument for the count, in INTERPRET_TIME_UNIT units.
     Regex("""sleep(\((\d+)\))?""") to { i, t, r ->
         SystemClock.sleep(if (r.strArg(2).isEmpty()) i.interpretTimeUnit else r.timeArg(2))
+    },
+    Regex("""(.*)\s*fails""") to { i, t, r ->
+        assertFails { i.interpret(r.strArg(1), t) }
     }
 )
 
diff --git a/tests/lib/src/com/android/testutils/TestableNetworkCallback.kt b/tests/lib/src/com/android/testutils/TestableNetworkCallback.kt
index 1cc1168..bbb279e 100644
--- a/tests/lib/src/com/android/testutils/TestableNetworkCallback.kt
+++ b/tests/lib/src/com/android/testutils/TestableNetworkCallback.kt
@@ -21,11 +21,15 @@
 import android.net.Network
 import android.net.NetworkCapabilities
 import android.net.NetworkCapabilities.NET_CAPABILITY_VALIDATED
-import com.android.testutils.RecorderCallback.CallbackRecord.Available
-import com.android.testutils.RecorderCallback.CallbackRecord.BlockedStatus
-import com.android.testutils.RecorderCallback.CallbackRecord.CapabilitiesChanged
-import com.android.testutils.RecorderCallback.CallbackRecord.LinkPropertiesChanged
-import com.android.testutils.RecorderCallback.CallbackRecord.Lost
+import com.android.testutils.RecorderCallback.CallbackEntry.Available
+import com.android.testutils.RecorderCallback.CallbackEntry.BlockedStatus
+import com.android.testutils.RecorderCallback.CallbackEntry.CapabilitiesChanged
+import com.android.testutils.RecorderCallback.CallbackEntry.LinkPropertiesChanged
+import com.android.testutils.RecorderCallback.CallbackEntry.Losing
+import com.android.testutils.RecorderCallback.CallbackEntry.Lost
+import com.android.testutils.RecorderCallback.CallbackEntry.Resumed
+import com.android.testutils.RecorderCallback.CallbackEntry.Suspended
+import com.android.testutils.RecorderCallback.CallbackEntry.Unavailable
 import kotlin.reflect.KClass
 import kotlin.test.assertEquals
 import kotlin.test.assertTrue
@@ -35,38 +39,43 @@
 
 private val Int.capabilityName get() = NetworkCapabilities.capabilityNameOf(this)
 
-open class RecorderCallback : NetworkCallback() {
-    sealed class CallbackRecord {
+open class RecorderCallback private constructor(
+    private val backingRecord: ArrayTrackRecord<CallbackEntry>
+) : NetworkCallback() {
+    public constructor() : this(ArrayTrackRecord())
+    protected constructor(src: RecorderCallback?): this(src?.backingRecord ?: ArrayTrackRecord())
+
+    sealed class CallbackEntry {
         // To get equals(), hashcode(), componentN() etc for free, the child classes of
         // this class are data classes. But while data classes can inherit from other classes,
         // they may only have visible members in the constructors, so they couldn't declare
-        // a constructor with a non-val arg to pass to CallbackRecord. Instead, force all
+        // a constructor with a non-val arg to pass to CallbackEntry. Instead, force all
         // subclasses to implement a `network' property, which can be done in a data class
         // constructor by specifying override.
         abstract val network: Network
 
-        data class Available(override val network: Network) : CallbackRecord()
+        data class Available(override val network: Network) : CallbackEntry()
         data class CapabilitiesChanged(
             override val network: Network,
             val caps: NetworkCapabilities
-        ) : CallbackRecord()
+        ) : CallbackEntry()
         data class LinkPropertiesChanged(
             override val network: Network,
             val lp: LinkProperties
-        ) : CallbackRecord()
-        data class Suspended(override val network: Network) : CallbackRecord()
-        data class Resumed(override val network: Network) : CallbackRecord()
-        data class Losing(override val network: Network, val maxMsToLive: Int) : CallbackRecord()
-        data class Lost(override val network: Network) : CallbackRecord()
+        ) : CallbackEntry()
+        data class Suspended(override val network: Network) : CallbackEntry()
+        data class Resumed(override val network: Network) : CallbackEntry()
+        data class Losing(override val network: Network, val maxMsToLive: Int) : CallbackEntry()
+        data class Lost(override val network: Network) : CallbackEntry()
         data class Unavailable private constructor(
             override val network: Network
-        ) : CallbackRecord() {
+        ) : CallbackEntry() {
             constructor() : this(NULL_NETWORK)
         }
         data class BlockedStatus(
             override val network: Network,
             val blocked: Boolean
-        ) : CallbackRecord()
+        ) : CallbackEntry()
 
         // Convenience constants for expecting a type
         companion object {
@@ -91,12 +100,15 @@
         }
     }
 
-    protected val history = ArrayTrackRecord<CallbackRecord>().newReadHead()
+    protected val history = backingRecord.newReadHead()
 
     override fun onAvailable(network: Network) {
         history.add(Available(network))
     }
 
+    // PreCheck is not used in the tests today. For backward compatibility with existing tests that
+    // expect the callbacks not to record this, do not listen to PreCheck here.
+
     override fun onCapabilitiesChanged(network: Network, caps: NetworkCapabilities) {
         history.add(CapabilitiesChanged(network, caps))
     }
@@ -110,39 +122,46 @@
     }
 
     override fun onNetworkSuspended(network: Network) {
-        history.add(CallbackRecord.Suspended(network))
+        history.add(Suspended(network))
     }
 
     override fun onNetworkResumed(network: Network) {
-        history.add(CallbackRecord.Resumed(network))
+        history.add(Resumed(network))
     }
 
     override fun onLosing(network: Network, maxMsToLive: Int) {
-        history.add(CallbackRecord.Losing(network, maxMsToLive))
+        history.add(Losing(network, maxMsToLive))
     }
 
     override fun onLost(network: Network) {
-        history.add(CallbackRecord.Lost(network))
+        history.add(Lost(network))
     }
 
     override fun onUnavailable() {
-        history.add(CallbackRecord.Unavailable())
+        history.add(Unavailable())
     }
 }
 
-typealias CallbackType = KClass<out RecorderCallback.CallbackRecord>
-const val DEFAULT_TIMEOUT = 200L // ms
+private const val DEFAULT_TIMEOUT = 200L // ms
 
-open class TestableNetworkCallback(val defaultTimeoutMs: Long = DEFAULT_TIMEOUT)
-        : RecorderCallback() {
-    // The last available network. Null if the last available network was lost since.
+open class TestableNetworkCallback private constructor(
+    src: TestableNetworkCallback?,
+    val defaultTimeoutMs: Long = DEFAULT_TIMEOUT
+) : RecorderCallback(src) {
+    @JvmOverloads
+    constructor(timeoutMs: Long = DEFAULT_TIMEOUT): this(null, timeoutMs)
+
+    fun createLinkedCopy() = TestableNetworkCallback(this, defaultTimeoutMs)
+
+    // The last available network, or null if any network was lost since the last call to
+    // onAvailable. TODO : fix this by fixing the tests that rely on this behavior
     val lastAvailableNetwork: Network?
         get() = when (val it = history.lastOrNull { it is Available || it is Lost }) {
             is Available -> it.network
             else -> null
         }
 
-    fun pollForNextCallback(timeoutMs: Long = defaultTimeoutMs): CallbackRecord {
+    fun pollForNextCallback(timeoutMs: Long = defaultTimeoutMs): CallbackEntry {
         return history.poll(timeoutMs) ?: fail("Did not receive callback after ${timeoutMs}ms")
     }
 
@@ -153,7 +172,7 @@
         if (null != cb) fail("Expected no callback but got $cb")
     }
 
-    inline fun <reified T : CallbackRecord> expectCallback(
+    inline fun <reified T : CallbackEntry> expectCallback(
         network: Network,
         timeoutMs: Long = defaultTimeoutMs
     ): T = pollForNextCallback(timeoutMs).let {
@@ -166,7 +185,7 @@
 
     fun expectCallbackThat(
         timeoutMs: Long = defaultTimeoutMs,
-        valid: (CallbackRecord) -> Boolean
+        valid: (CallbackEntry) -> Boolean
     ) = pollForNextCallback(timeoutMs).also { assertTrue(valid(it), "Unexpected callback : $it") }
 
     fun expectCapabilitiesThat(
@@ -209,7 +228,7 @@
     ) {
         expectCallback<Available>(net, tmt)
         if (suspended) {
-            expectCallback<CallbackRecord.Suspended>(net, tmt)
+            expectCallback<CallbackEntry.Suspended>(net, tmt)
         }
         expectCapabilitiesThat(net, tmt) { validated == it.hasCapability(NET_CAPABILITY_VALIDATED) }
         expectCallback<LinkPropertiesChanged>(net, tmt)
@@ -257,7 +276,7 @@
     }
 
     @JvmOverloads
-    open fun <T : CallbackRecord> expectCallback(
+    open fun <T : CallbackEntry> expectCallback(
         type: KClass<T>,
         n: HasNetwork?,
         timeoutMs: Long = defaultTimeoutMs
diff --git a/tests/unit/Android.bp b/tests/unit/Android.bp
index 3410d0e..03bcf95 100644
--- a/tests/unit/Android.bp
+++ b/tests/unit/Android.bp
@@ -22,6 +22,7 @@
     resource_dirs: ["res"],
     static_libs: [
         "androidx.test.rules",
+        "kotlin-reflect",
         "mockito-target-extended-minus-junit4",
         "net-tests-utils",
         "NetworkStackApiCurrentLib",
diff --git a/tests/unit/src/android/net/testutils/TestableNetworkCallbackTest.kt b/tests/unit/src/android/net/testutils/TestableNetworkCallbackTest.kt
new file mode 100644
index 0000000..4e4d25a
--- /dev/null
+++ b/tests/unit/src/android/net/testutils/TestableNetworkCallbackTest.kt
@@ -0,0 +1,293 @@
+package android.net.testutils
+
+import android.net.LinkAddress
+import android.net.LinkProperties
+import android.net.Network
+import android.net.NetworkCapabilities
+import com.android.testutils.ConcurrentIntepreter
+import com.android.testutils.InterpretMatcher
+import com.android.testutils.RecorderCallback.CallbackEntry
+import com.android.testutils.RecorderCallback.CallbackEntry.Available
+import com.android.testutils.RecorderCallback.CallbackEntry.BlockedStatus
+import com.android.testutils.RecorderCallback.CallbackEntry.CapabilitiesChanged
+import com.android.testutils.TestableNetworkCallback
+import com.android.testutils.intArg
+import com.android.testutils.strArg
+import com.android.testutils.timeArg
+import org.junit.Before
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.junit.runners.JUnit4
+import kotlin.reflect.KClass
+import kotlin.test.assertEquals
+import kotlin.test.assertFails
+import kotlin.test.assertNull
+import kotlin.test.assertTrue
+import kotlin.test.fail
+
+const val SHORT_TIMEOUT_MS = 20L
+const val DEFAULT_LINGER_DELAY_MS = 30000
+const val NOT_METERED = NetworkCapabilities.NET_CAPABILITY_NOT_METERED
+const val WIFI = NetworkCapabilities.TRANSPORT_WIFI
+const val CELLULAR = NetworkCapabilities.TRANSPORT_CELLULAR
+const val TEST_INTERFACE_NAME = "testInterfaceName"
+
+@RunWith(JUnit4::class)
+class TestableNetworkCallbackTest {
+    private lateinit var mCallback: TestableNetworkCallback
+
+    private fun makeHasNetwork(netId: Int) = object : TestableNetworkCallback.HasNetwork {
+        override val network: Network = Network(netId)
+    }
+
+    @Before
+    fun setUp() {
+        mCallback = TestableNetworkCallback()
+    }
+
+    @Test
+    fun testLastAvailableNetwork() {
+        // Make sure there is no last available network at first, then the last available network
+        // is returned after onAvailable is called.
+        val net2097 = Network(2097)
+        assertNull(mCallback.lastAvailableNetwork)
+        mCallback.onAvailable(net2097)
+        assertEquals(mCallback.lastAvailableNetwork, net2097)
+
+        // Make sure calling onCapsChanged/onLinkPropertiesChanged don't affect the last available
+        // network.
+        mCallback.onCapabilitiesChanged(net2097, NetworkCapabilities())
+        mCallback.onLinkPropertiesChanged(net2097, LinkProperties())
+        assertEquals(mCallback.lastAvailableNetwork, net2097)
+
+        // Make sure onLost clears the last available network.
+        mCallback.onLost(net2097)
+        assertNull(mCallback.lastAvailableNetwork)
+
+        // Do the same but with a different network after onLost : make sure the last available
+        // network is the new one, not the original one.
+        val net2098 = Network(2098)
+        mCallback.onAvailable(net2098)
+        mCallback.onCapabilitiesChanged(net2098, NetworkCapabilities())
+        mCallback.onLinkPropertiesChanged(net2098, LinkProperties())
+        assertEquals(mCallback.lastAvailableNetwork, net2098)
+
+        // Make sure onAvailable changes the last available network even if onLost was not called.
+        val net2099 = Network(2099)
+        mCallback.onAvailable(net2099)
+        assertEquals(mCallback.lastAvailableNetwork, net2099)
+
+        // For legacy reasons, lastAvailableNetwork is null as soon as any is lost, not necessarily
+        // the last available one. Check that behavior.
+        mCallback.onLost(net2098)
+        assertNull(mCallback.lastAvailableNetwork)
+
+        // Make sure that losing the really last available one still results in null.
+        mCallback.onLost(net2099)
+        assertNull(mCallback.lastAvailableNetwork)
+
+        // Make sure multiple onAvailable in a row then onLost still results in null.
+        mCallback.onAvailable(net2097)
+        mCallback.onAvailable(net2098)
+        mCallback.onAvailable(net2099)
+        mCallback.onLost(net2097)
+        assertNull(mCallback.lastAvailableNetwork)
+    }
+
+    @Test
+    fun testAssertNoCallback() {
+        mCallback.assertNoCallback(SHORT_TIMEOUT_MS)
+        mCallback.onAvailable(Network(100))
+        assertFails { mCallback.assertNoCallback(SHORT_TIMEOUT_MS) }
+    }
+
+    @Test
+    fun testCapabilitiesWithAndWithout() {
+        val net = Network(101)
+        val matcher = makeHasNetwork(101)
+        val meteredNc = NetworkCapabilities()
+        val unmeteredNc = NetworkCapabilities().addCapability(NOT_METERED)
+        // Check that expecting caps (with or without) fails when no callback has been received.
+        assertFails { mCallback.expectCapabilitiesWith(NOT_METERED, matcher, SHORT_TIMEOUT_MS) }
+        assertFails { mCallback.expectCapabilitiesWithout(NOT_METERED, matcher, SHORT_TIMEOUT_MS) }
+
+        // Add NOT_METERED and check that With succeeds and Without fails.
+        mCallback.onCapabilitiesChanged(net, unmeteredNc)
+        mCallback.expectCapabilitiesWith(NOT_METERED, matcher)
+        mCallback.onCapabilitiesChanged(net, unmeteredNc)
+        assertFails { mCallback.expectCapabilitiesWithout(NOT_METERED, matcher, SHORT_TIMEOUT_MS) }
+
+        // Don't add NOT_METERED and check that With fails and Without succeeds.
+        mCallback.onCapabilitiesChanged(net, meteredNc)
+        assertFails { mCallback.expectCapabilitiesWith(NOT_METERED, matcher, SHORT_TIMEOUT_MS) }
+        mCallback.onCapabilitiesChanged(net, meteredNc)
+        mCallback.expectCapabilitiesWithout(NOT_METERED, matcher)
+    }
+
+    @Test
+    fun testExpectCallbackThat() {
+        val net = Network(193)
+        val netCaps = NetworkCapabilities().addTransportType(CELLULAR)
+        // Check that expecting callbackThat anything fails when no callback has been received.
+        assertFails { mCallback.expectCallbackThat(SHORT_TIMEOUT_MS) { true } }
+
+        // Basic test for true and false
+        mCallback.onAvailable(net)
+        mCallback.expectCallbackThat { true }
+        mCallback.onAvailable(net)
+        assertFails { mCallback.expectCallbackThat(SHORT_TIMEOUT_MS) { false } }
+
+        // Try a positive and a negative case
+        mCallback.onBlockedStatusChanged(net, true)
+        mCallback.expectCallbackThat { cb -> cb is BlockedStatus && cb.blocked }
+        mCallback.onCapabilitiesChanged(net, netCaps)
+        assertFails { mCallback.expectCallbackThat(SHORT_TIMEOUT_MS) { cb ->
+            cb is CapabilitiesChanged && cb.caps.hasTransport(WIFI)
+        } }
+    }
+
+    @Test
+    fun testCapabilitiesThat() {
+        val net = Network(101)
+        val netCaps = NetworkCapabilities().addCapability(NOT_METERED).addTransportType(WIFI)
+        // Check that expecting capabilitiesThat anything fails when no callback has been received.
+        assertFails { mCallback.expectCapabilitiesThat(net, SHORT_TIMEOUT_MS) { true } }
+
+        // Basic test for true and false
+        mCallback.onCapabilitiesChanged(net, netCaps)
+        mCallback.expectCapabilitiesThat(net) { true }
+        mCallback.onCapabilitiesChanged(net, netCaps)
+        assertFails { mCallback.expectCapabilitiesThat(net, SHORT_TIMEOUT_MS) { false } }
+
+        // Try a positive and a negative case
+        mCallback.onCapabilitiesChanged(net, netCaps)
+        mCallback.expectCapabilitiesThat(net) { caps ->
+            caps.hasCapability(NOT_METERED) &&
+                    caps.hasTransport(WIFI) &&
+                    !caps.hasTransport(CELLULAR)
+        }
+        mCallback.onCapabilitiesChanged(net, netCaps)
+        assertFails { mCallback.expectCapabilitiesThat(net, SHORT_TIMEOUT_MS) { caps ->
+            caps.hasTransport(CELLULAR)
+        } }
+
+        // Try a matching callback on the wrong network
+        mCallback.onCapabilitiesChanged(net, netCaps)
+        assertFails { mCallback.expectCapabilitiesThat(Network(100), SHORT_TIMEOUT_MS) { true } }
+    }
+
+    @Test
+    fun testLinkPropertiesThat() {
+        val net = Network(112)
+        val linkAddress = LinkAddress("fe80::ace:d00d/64")
+        val mtu = 1984
+        val linkProps = LinkProperties().apply {
+            this.mtu = mtu
+            interfaceName = TEST_INTERFACE_NAME
+            addLinkAddress(linkAddress)
+        }
+
+        // Check that expecting linkPropsThat anything fails when no callback has been received.
+        assertFails { mCallback.expectLinkPropertiesThat(net, SHORT_TIMEOUT_MS) { true } }
+
+        // Basic test for true and false
+        mCallback.onLinkPropertiesChanged(net, linkProps)
+        mCallback.expectLinkPropertiesThat(net) { true }
+        mCallback.onLinkPropertiesChanged(net, linkProps)
+        assertFails { mCallback.expectLinkPropertiesThat(net, SHORT_TIMEOUT_MS) { false } }
+
+        // Try a positive and negative case
+        mCallback.onLinkPropertiesChanged(net, linkProps)
+        mCallback.expectLinkPropertiesThat(net) { lp ->
+            lp.interfaceName == TEST_INTERFACE_NAME &&
+                    lp.linkAddresses.contains(linkAddress) &&
+                    lp.mtu == mtu
+        }
+        mCallback.onLinkPropertiesChanged(net, linkProps)
+        assertFails { mCallback.expectLinkPropertiesThat(net, SHORT_TIMEOUT_MS) { lp ->
+            lp.interfaceName != TEST_INTERFACE_NAME
+        } }
+
+        // Try a matching callback on the wrong network
+        mCallback.onLinkPropertiesChanged(net, linkProps)
+        assertFails { mCallback.expectLinkPropertiesThat(Network(114), SHORT_TIMEOUT_MS) { lp ->
+            lp.interfaceName == TEST_INTERFACE_NAME
+        } }
+    }
+
+    @Test
+    fun testExpectCallback() {
+        val net = Network(103)
+        // Test expectCallback fails when nothing was sent.
+        assertFails { mCallback.expectCallback<BlockedStatus>(net, SHORT_TIMEOUT_MS) }
+
+        // Test onAvailable is seen and can be expected
+        mCallback.onAvailable(net)
+        mCallback.expectCallback<Available>(net, SHORT_TIMEOUT_MS)
+
+        // Test onAvailable won't return calls with a different network
+        mCallback.onAvailable(Network(106))
+        assertFails { mCallback.expectCallback<Available>(net, SHORT_TIMEOUT_MS) }
+
+        // Test onAvailable won't return calls with a different callback
+        mCallback.onAvailable(net)
+        assertFails { mCallback.expectCallback<BlockedStatus>(net, SHORT_TIMEOUT_MS) }
+    }
+
+    @Test
+    fun testPollForNextCallback() {
+        assertFails { mCallback.pollForNextCallback(SHORT_TIMEOUT_MS) }
+        TNCInterpreter.interpretTestSpec(initial = mCallback, lineShift = 1,
+                threadTransform = { cb -> cb.createLinkedCopy() }, spec = """
+            sleep; onAvailable(133)    | poll(2) = Available(133) time 1..4
+                                       | poll(1) fails
+            onCapabilitiesChanged(108) | poll(1) = CapabilitiesChanged(108) time 0..3
+            onBlockedStatus(199)       | poll(1) = BlockedStatus(199) time 0..3
+        """)
+    }
+}
+
+private object TNCInterpreter : ConcurrentIntepreter<TestableNetworkCallback>(interpretTable)
+
+val EntryList = CallbackEntry::class.sealedSubclasses.map { it.simpleName }.joinToString("|")
+private fun callbackEntryFromString(name: String): KClass<out CallbackEntry> {
+    return CallbackEntry::class.sealedSubclasses.first { it.simpleName == name }
+}
+
+private val interpretTable = listOf<InterpretMatcher<TestableNetworkCallback>>(
+    // Interpret "Available(xx)" as "call to onAvailable with netId xx", and likewise for
+    // all callback types. This is implemented above by enumerating the subclasses of
+    // CallbackEntry and reading their simpleName.
+    Regex("""(.*)\s+=\s+($EntryList)\((\d+)\)""") to { i, cb, t ->
+        val record = i.interpret(t.strArg(1), cb)
+        assertTrue(callbackEntryFromString(t.strArg(2)).isInstance(record))
+        // Strictly speaking testing for is CallbackEntry is useless as it's been tested above
+        // but the compiler can't figure things out from the isInstance call. It does understand
+        // from the assertTrue(is CallbackEntry) that this is true, which allows to access
+        // the 'network' member below.
+        assertTrue(record is CallbackEntry)
+        assertEquals(record.network.netId, t.intArg(3))
+    },
+    // Interpret "onAvailable(xx)" as calling "onAvailable" with a netId of xx, and likewise for
+    // all callback types. NetworkCapabilities and LinkProperties just get an empty object
+    // as their argument. Losing gets the default linger timer. Blocked gets false.
+    Regex("""on($EntryList)\((\d+)\)""") to { i, cb, t ->
+        val net = Network(t.intArg(2))
+        when (t.strArg(1)) {
+            "Available" -> cb.onAvailable(net)
+            // PreCheck not used in tests. Add it here if it becomes useful.
+            "CapabilitiesChanged" -> cb.onCapabilitiesChanged(net, NetworkCapabilities())
+            "LinkPropertiesChanged" -> cb.onLinkPropertiesChanged(net, LinkProperties())
+            "Suspended" -> cb.onNetworkSuspended(net)
+            "Resumed" -> cb.onNetworkResumed(net)
+            "Losing" -> cb.onLosing(net, DEFAULT_LINGER_DELAY_MS)
+            "Lost" -> cb.onLost(net)
+            "Unavailable" -> cb.onUnavailable()
+            "BlockedStatus" -> cb.onBlockedStatusChanged(net, false)
+            else -> fail("Unknown callback type")
+        }
+    },
+    Regex("""poll\((\d+)\)""") to { i, cb, t ->
+        cb.pollForNextCallback(t.timeArg(1))
+    }
+)
diff --git a/tests/unit/src/android/net/testutils/TrackRecordTest.kt b/tests/unit/src/android/net/testutils/TrackRecordTest.kt
index 4fe8d37..995d537 100644
--- a/tests/unit/src/android/net/testutils/TrackRecordTest.kt
+++ b/tests/unit/src/android/net/testutils/TrackRecordTest.kt
@@ -352,7 +352,8 @@
 
 private object TRTInterpreter : ConcurrentIntepreter<TrackRecord<Int>>(interpretTable) {
     fun interpretTestSpec(spec: String, useReadHeads: Boolean) = if (useReadHeads) {
-        interpretTestSpec(spec, ArrayTrackRecord(), { (it as ArrayTrackRecord).newReadHead() })
+        interpretTestSpec(spec, initial = ArrayTrackRecord(),
+                threadTransform = { (it as ArrayTrackRecord).newReadHead() })
     } else {
         interpretTestSpec(spec, ArrayTrackRecord())
     }