blob: 00c5a6090f9ee681849624b6b9be1560e11c5071 [file] [log] [blame]
/*
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/
package kotlinx.coroutines.channels
import kotlinx.coroutines.*
import kotlinx.coroutines.selects.*
import org.junit.*
import org.junit.Test
import org.junit.runner.*
import org.junit.runners.*
import java.util.concurrent.atomic.*
import kotlin.test.*
/**
 * Concurrency stress test for channel `send`/`receive`.
 *
 * Runs [nSenders] sender coroutines and [nReceivers] receiver coroutines on a dedicated fixed
 * thread pool against a channel created by [kind]. Senders collectively send every integer in
 * `0 until nEvents` exactly once (sender `k` sends `k, k + nSenders, k + 2 * nSenders, ...`),
 * alternating between plain `send` and `select`-based `onSend`. Receivers use five different
 * receive styles, chosen by `receiverIndex % 5`.
 *
 * After all senders finish, the channel is closed and the test checks that:
 * - every sender and receiver completed;
 * - no event was received twice;
 * - all events were sent;
 * - for non-conflated channels, all events were received;
 * - every receiver received at least one event.
 */
@RunWith(Parameterized::class)
class ChannelSendReceiveStressTest(
    private val kind: TestChannelKind,
    private val nSenders: Int,
    private val nReceivers: Int
) : TestBase() {
    companion object {
        /** Every combination of sender count (1, 2, 10), receiver count (1, 10) and channel kind. */
        @Parameterized.Parameters(name = "{0}, nSenders={1}, nReceivers={2}")
        @JvmStatic
        fun params(): Collection<Array<Any>> =
            listOf(1, 2, 10).flatMap { nSenders ->
                listOf(1, 10).flatMap { nReceivers ->
                    TestChannelKind.values().map { arrayOf(it, nSenders, nReceivers) }
                }
            }
    }

    private val timeLimit = 30_000L * stressTestMultiplier // 30 sec, scaled by stressTestMultiplier
    private val nEvents = 200_000 * stressTestMultiplier
    private val maxBuffer = 10_000 // artificial limit for LinkedListChannel

    val channel = kind.create()

    private val sendersCompleted = AtomicInteger()
    private val receiversCompleted = AtomicInteger()
    private val dupes = AtomicInteger()
    private val sentTotal = AtomicInteger()

    // received[e] is CAS-ed from 0 to 1 the first time event e arrives; a failed CAS marks a duplicate.
    val received = AtomicIntegerArray(nEvents)
    private val receivedTotal = AtomicInteger()

    // Slot i is only incremented by receiver coroutine i; it is read for reporting and assertions
    // after the receivers have been joined (or after the time limit expired).
    private val receivedBy = IntArray(nReceivers)

    private val pool =
        newFixedThreadPoolContext(nSenders + nReceivers, "ChannelSendReceiveStressTest")

    /** Releases the dedicated thread pool after each parameterized run. */
    @After
    fun tearDown() {
        pool.close()
    }

    @Test
    fun testSendReceiveStress() = runBlocking {
        println("--- ChannelSendReceiveStressTest $kind with nSenders=$nSenders, nReceivers=$nReceivers")
        val receivers = List(nReceivers) { receiverIndex ->
            // different event receivers use different code
            launch(pool + CoroutineName("receiver$receiverIndex")) {
                when (receiverIndex % 5) {
                    0 -> doReceive(receiverIndex)
                    1 -> doReceiveOrNull(receiverIndex)
                    2 -> doIterator(receiverIndex)
                    3 -> doReceiveSelect(receiverIndex)
                    4 -> doReceiveSelectOrNull(receiverIndex)
                }
                receiversCompleted.incrementAndGet()
            }
        }
        val senders = List(nSenders) { senderIndex ->
            launch(pool + CoroutineName("sender$senderIndex")) {
                when (senderIndex % 2) {
                    0 -> doSend(senderIndex)
                    1 -> doSendSelect(senderIndex)
                }
                sendersCompleted.incrementAndGet()
            }
        }
        // print progress once per second until the run finishes
        val progressJob = launch {
            var seconds = 0
            while (true) {
                delay(1000)
                println("${++seconds}: Sent ${sentTotal.get()}, received ${receivedTotal.get()}")
            }
        }
        try {
            withTimeout(timeLimit) {
                senders.forEach { it.join() }
                // closing lets every receiver loop terminate once the channel is drained
                channel.close()
                receivers.forEach { it.join() }
            }
        } catch (e: TimeoutCancellationException) {
            // Only the timeout is reported and tolerated here; the completion assertions below are
            // then expected to fail. Any other cancellation propagates instead of being swallowed.
            println("!!! Test timed out $e")
        } finally {
            progressJob.cancel()
        }
        printSummary()
        assertEquals(nSenders, sendersCompleted.get())
        assertEquals(nReceivers, receiversCompleted.get())
        assertEquals(0, dupes.get())
        assertEquals(nEvents, sentTotal.get())
        // a conflated channel may legitimately drop events, so only count them for other kinds
        if (!kind.isConflated) assertEquals(nEvents, receivedTotal.get())
        repeat(nReceivers) { receiveIndex ->
            assertTrue(receivedBy[receiveIndex] > 0, "Each receiver should have received something")
        }
    }

    /** Prints the final sent/received/duplicate counters and the per-receiver breakdown. */
    private fun printSummary() {
        println("Tested $kind with nSenders=$nSenders, nReceivers=$nReceivers")
        println("Completed successfully ${sendersCompleted.get()} sender coroutines")
        println("Completed successfully ${receiversCompleted.get()} receiver coroutines")
        println(" Sent ${sentTotal.get()} events")
        println(" Received ${receivedTotal.get()} events")
        println(" Received dupes ${dupes.get()}")
        repeat(nReceivers) { receiveIndex ->
            println(" Received by #$receiveIndex ${receivedBy[receiveIndex]}")
        }
    }

    /**
     * Records one successful send.
     *
     * For non-conflated kinds, it yields while more than [maxBuffer] events are in flight
     * (sent but not yet received). This bounds memory for unbounded channel kinds.
     */
    private suspend fun doSent() {
        sentTotal.incrementAndGet()
        if (!kind.isConflated) {
            while (sentTotal.get() > receivedTotal.get() + maxBuffer)
                yield() // throttle fast senders to prevent OOM with LinkedListChannel
        }
    }

    /** Sends this sender's share of events (`senderIndex`, stepping by [nSenders]) with `send`. */
    private suspend fun doSend(senderIndex: Int) {
        for (i in senderIndex until nEvents step nSenders) {
            channel.send(i)
            doSent()
        }
    }

    /** Same event share as [doSend], but each send goes through a single-clause `select`. */
    private suspend fun doSendSelect(senderIndex: Int) {
        for (i in senderIndex until nEvents step nSenders) {
            select<Unit> { channel.onSend(i) {} }
            doSent()
        }
    }

    /** Marks [event] as received by [receiverIndex], counting it as a duplicate if it was already seen. */
    private fun doReceived(receiverIndex: Int, event: Int) {
        if (!received.compareAndSet(event, 0, 1)) {
            println("Duplicate event $event at $receiverIndex")
            dupes.incrementAndGet()
        }
        receivedTotal.incrementAndGet()
        receivedBy[receiverIndex]++
    }

    /** Receives with `receive()`, stopping on [ClosedReceiveChannelException]. */
    private suspend fun doReceive(receiverIndex: Int) {
        while (true) {
            try {
                doReceived(receiverIndex, channel.receive())
            } catch (ex: ClosedReceiveChannelException) {
                break
            }
        }
    }

    /** Receives with `receiveOrNull()`, stopping when it returns `null`. */
    private suspend fun doReceiveOrNull(receiverIndex: Int) {
        while (true) {
            doReceived(receiverIndex, channel.receiveOrNull() ?: break)
        }
    }

    /** Receives by iterating the channel with a `for` loop, which ends when the channel is closed. */
    private suspend fun doIterator(receiverIndex: Int) {
        for (event in channel) {
            doReceived(receiverIndex, event)
        }
    }

    /** Receives through a single-clause `select` on `onReceive`, stopping on [ClosedReceiveChannelException]. */
    private suspend fun doReceiveSelect(receiverIndex: Int) {
        while (true) {
            try {
                val event = select<Int> { channel.onReceive { it } }
                doReceived(receiverIndex, event)
            } catch (ex: ClosedReceiveChannelException) {
                break
            }
        }
    }

    /** Receives through a single-clause `select` on `onReceiveOrNull`, stopping when it yields `null`. */
    private suspend fun doReceiveSelectOrNull(receiverIndex: Int) {
        while (true) {
            val event = select<Int?> { channel.onReceiveOrNull { it } } ?: break
            doReceived(receiverIndex, event)
        }
    }
}