blob: 4706d627efd7a0e9a5260a7b5d10d0c49a81ca58 [file] [log] [blame]
/*
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/
package kotlinx.coroutines.test
import kotlinx.atomicfu.*
import kotlinx.coroutines.*
import kotlinx.coroutines.internal.*
import kotlin.coroutines.*
import kotlin.math.*
/**
* [CoroutineDispatcher] that performs both immediate and lazy execution of coroutines in tests
* and implements [DelayController] to control its virtual clock.
*
* By default, [TestCoroutineDispatcher] is immediate. That means any tasks scheduled to be run without delay are
* immediately executed. If they were scheduled with a delay, the virtual clock-time must be advanced via one of the
* methods on [DelayController].
*
* When switched to lazy execution using [pauseDispatcher] any coroutines started via [launch] or [async] will
* not execute until a call to [DelayController.runCurrent] or the virtual clock-time has been advanced via one of the
* methods on [DelayController].
*
* @see DelayController
*/
@ExperimentalCoroutinesApi // Since 1.2.1, tentatively till 1.3.0
public class TestCoroutineDispatcher: CoroutineDispatcher(), Delay, DelayController {
private var dispatchImmediately = true
set(value) {
field = value
if (value) {
// there may already be tasks from setup code we need to run
advanceUntilIdle()
}
}
// The ordered queue for the runnable tasks.
private val queue = ThreadSafeHeap<TimedRunnable>()
// The per-scheduler global order counter.
private val _counter = atomic(0L)
// Storing time in nanoseconds internally.
private val _time = atomic(0L)
/** @suppress */
override fun dispatch(context: CoroutineContext, block: Runnable) {
if (dispatchImmediately) {
block.run()
} else {
post(block)
}
}
/** @suppress */
@InternalCoroutinesApi
override fun dispatchYield(context: CoroutineContext, block: Runnable) {
post(block)
}
/** @suppress */
override fun scheduleResumeAfterDelay(timeMillis: Long, continuation: CancellableContinuation<Unit>) {
postDelayed(CancellableContinuationRunnable(continuation) { resumeUndispatched(Unit) }, timeMillis)
}
/** @suppress */
override fun invokeOnTimeout(timeMillis: Long, block: Runnable): DisposableHandle {
val node = postDelayed(block, timeMillis)
return object : DisposableHandle {
override fun dispose() {
queue.remove(node)
}
}
}
/** @suppress */
override fun toString(): String {
return "TestCoroutineDispatcher[currentTime=${currentTime}ms, queued=${queue.size}]"
}
private fun post(block: Runnable) =
queue.addLast(TimedRunnable(block, _counter.getAndIncrement()))
private fun postDelayed(block: Runnable, delayTime: Long) =
TimedRunnable(block, _counter.getAndIncrement(), safePlus(currentTime, delayTime))
.also {
queue.addLast(it)
}
private fun safePlus(currentTime: Long, delayTime: Long): Long {
check(delayTime >= 0)
val result = currentTime + delayTime
if (result < currentTime) return Long.MAX_VALUE // clam on overflow
return result
}
private fun doActionsUntil(targetTime: Long) {
while (true) {
val current = queue.removeFirstIf { it.time <= targetTime } ?: break
// If the scheduled time is 0 (immediate) use current virtual time
if (current.time != 0L) _time.value = current.time
current.run()
}
}
/** @suppress */
override val currentTime: Long get() = _time.value
/** @suppress */
override fun advanceTimeBy(delayTimeMillis: Long): Long {
val oldTime = currentTime
advanceUntilTime(oldTime + delayTimeMillis)
return currentTime - oldTime
}
/**
* Moves the CoroutineContext's clock-time to a particular moment in time.
*
* @param targetTime The point in time to which to move the CoroutineContext's clock (milliseconds).
*/
private fun advanceUntilTime(targetTime: Long) {
doActionsUntil(targetTime)
_time.update { currentValue -> max(currentValue, targetTime) }
}
/** @suppress */
override fun advanceUntilIdle(): Long {
val oldTime = currentTime
while(!queue.isEmpty) {
runCurrent()
val next = queue.peek() ?: break
advanceUntilTime(next.time)
}
return currentTime - oldTime
}
/** @suppress */
override fun runCurrent(): Unit = doActionsUntil(currentTime)
/** @suppress */
override suspend fun pauseDispatcher(block: suspend () -> Unit) {
val previous = dispatchImmediately
dispatchImmediately = false
try {
block()
} finally {
dispatchImmediately = previous
}
}
/** @suppress */
override fun pauseDispatcher() {
dispatchImmediately = false
}
/** @suppress */
override fun resumeDispatcher() {
dispatchImmediately = true
}
/** @suppress */
override fun cleanupTestCoroutines() {
// process any pending cancellations or completions, but don't advance time
doActionsUntil(currentTime)
// run through all pending tasks, ignore any submitted coroutines that are not active
val pendingTasks = mutableListOf<TimedRunnable>()
while (true) {
pendingTasks += queue.removeFirstOrNull() ?: break
}
val activeDelays = pendingTasks
.mapNotNull { it.runnable as? CancellableContinuationRunnable<*> }
.filter { it.continuation.isActive }
val activeTimeouts = pendingTasks.filter { it.runnable !is CancellableContinuationRunnable<*> }
if (activeDelays.isNotEmpty() || activeTimeouts.isNotEmpty()) {
throw UncompletedCoroutinesError(
"Unfinished coroutines during teardown. Ensure all coroutines are" +
" completed or cancelled by your test."
)
}
}
}
/**
* This class exists to allow cleanup code to avoid throwing for cancelled continuations scheduled
* in the future.
*/
private class CancellableContinuationRunnable<T>(
@JvmField val continuation: CancellableContinuation<T>,
private val block: CancellableContinuation<T>.() -> Unit
) : Runnable {
override fun run() = continuation.block()
}
/**
* A Runnable for our event loop that represents a task to perform at a time.
*/
private class TimedRunnable(
@JvmField val runnable: Runnable,
private val count: Long = 0,
@JvmField val time: Long = 0
) : Comparable<TimedRunnable>, Runnable by runnable, ThreadSafeHeapNode {
override var heap: ThreadSafeHeap<*>? = null
override var index: Int = 0
override fun compareTo(other: TimedRunnable) = if (time == other.time) {
count.compareTo(other.count)
} else {
time.compareTo(other.time)
}
override fun toString() = "TimedRunnable(time=$time, run=$runnable)"
}