blob: 5d8c3d5c6d3d1081350e1eeeaae881e1f3fcba64 [file] [log] [blame]
/*
* Copyright 2016-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/
package kotlinx.coroutines
import org.junit.*
import org.junit.Test
import java.lang.IllegalStateException
import kotlin.test.*
@Suppress("RedundantAsync")
class ThreadLocalTest : TestBase() {
private val stringThreadLocal = ThreadLocal<String?>()
private val intThreadLocal = ThreadLocal<Int?>()
private val executor = newFixedThreadPoolContext(1, "threadLocalTest")
@After
fun tearDown() {
executor.close()
}
@Test
fun testThreadLocal() = runTest {
assertNull(stringThreadLocal.get())
assertFalse(stringThreadLocal.isPresent())
val deferred = async(Dispatchers.Default + stringThreadLocal.asContextElement("value")) {
assertEquals("value", stringThreadLocal.get())
assertTrue(stringThreadLocal.isPresent())
withContext(executor) {
assertTrue(stringThreadLocal.isPresent())
assertFailsWith<IllegalStateException> { intThreadLocal.ensurePresent() }
assertEquals("value", stringThreadLocal.get())
}
assertTrue(stringThreadLocal.isPresent())
assertEquals("value", stringThreadLocal.get())
}
assertNull(stringThreadLocal.get())
deferred.await()
assertNull(stringThreadLocal.get())
assertFalse(stringThreadLocal.isPresent())
}
@Test
fun testThreadLocalInitialValue() = runTest {
intThreadLocal.set(42)
assertFalse(intThreadLocal.isPresent())
val deferred = async(Dispatchers.Default + intThreadLocal.asContextElement(239)) {
assertEquals(239, intThreadLocal.get())
withContext(executor) {
intThreadLocal.ensurePresent()
assertEquals(239, intThreadLocal.get())
}
assertEquals(239, intThreadLocal.get())
}
deferred.await()
assertEquals(42, intThreadLocal.get())
}
@Test
fun testMultipleThreadLocals() = runTest {
stringThreadLocal.set("test")
intThreadLocal.set(314)
val deferred = async(Dispatchers.Default
+ intThreadLocal.asContextElement(value = 239) + stringThreadLocal.asContextElement(value = "pew")) {
assertEquals(239, intThreadLocal.get())
assertEquals("pew", stringThreadLocal.get())
withContext(executor) {
assertEquals(239, intThreadLocal.get())
assertEquals("pew", stringThreadLocal.get())
intThreadLocal.ensurePresent()
stringThreadLocal.ensurePresent()
}
assertEquals(239, intThreadLocal.get())
assertEquals("pew", stringThreadLocal.get())
}
deferred.await()
assertEquals(314, intThreadLocal.get())
assertEquals("test", stringThreadLocal.get())
}
@Test
fun testConflictingThreadLocals() = runTest {
intThreadLocal.set(42)
val deferred = GlobalScope.async(intThreadLocal.asContextElement(1)) {
assertEquals(1, intThreadLocal.get())
withContext(executor + intThreadLocal.asContextElement(42)) {
assertEquals(42, intThreadLocal.get())
}
assertEquals(1, intThreadLocal.get())
val deferred = async(intThreadLocal.asContextElement(53)) {
assertEquals(53, intThreadLocal.get())
}
deferred.await()
assertEquals(1, intThreadLocal.get())
val deferred2 = GlobalScope.async(executor) {
assertNull(intThreadLocal.get())
}
deferred2.await()
assertEquals(1, intThreadLocal.get())
}
deferred.await()
assertEquals(42, intThreadLocal.get())
}
@Test
fun testThreadLocalModification() = runTest {
stringThreadLocal.set("main")
val deferred = async(Dispatchers.Default
+ stringThreadLocal.asContextElement("initial")) {
assertEquals("initial", stringThreadLocal.get())
stringThreadLocal.set("overridden") // <- this value is not reflected in the context, so it's not restored
withContext(executor + stringThreadLocal.asContextElement("ctx")) {
assertEquals("ctx", stringThreadLocal.get())
}
val deferred = async(stringThreadLocal.asContextElement("async")) {
assertEquals("async", stringThreadLocal.get())
}
deferred.await()
assertEquals("initial", stringThreadLocal.get()) // <- not restored
}
deferred.await()
assertFalse(stringThreadLocal.isPresent())
assertEquals("main", stringThreadLocal.get())
}
private data class Counter(var cnt: Int)
private val myCounterLocal = ThreadLocal<Counter>()
@Test
fun testThreadLocalModificationMutableBox() = runTest {
myCounterLocal.set(Counter(42))
val deferred = async(Dispatchers.Default
+ myCounterLocal.asContextElement(Counter(0))) {
assertEquals(0, myCounterLocal.get().cnt)
// Mutate
myCounterLocal.get().cnt = 71
withContext(executor + myCounterLocal.asContextElement(Counter(-1))) {
assertEquals(-1, myCounterLocal.get().cnt)
++myCounterLocal.get().cnt
}
val deferred = async(myCounterLocal.asContextElement(Counter(31))) {
assertEquals(31, myCounterLocal.get().cnt)
++myCounterLocal.get().cnt
}
deferred.await()
assertEquals(71, myCounterLocal.get().cnt)
}
deferred.await()
assertEquals(42, myCounterLocal.get().cnt)
}
@Test
fun testWithContext() = runTest {
expect(1)
newSingleThreadContext("withContext").use {
val data = 42
GlobalScope.async(Dispatchers.Default + intThreadLocal.asContextElement(42)) {
assertEquals(data, intThreadLocal.get())
expect(2)
GlobalScope.async(it + intThreadLocal.asContextElement(31)) {
assertEquals(31, intThreadLocal.get())
expect(3)
}.await()
withContext(it + intThreadLocal.asContextElement(2)) {
assertEquals(2, intThreadLocal.get())
expect(4)
}
GlobalScope.async(it) {
assertNull(intThreadLocal.get())
expect(5)
}.await()
expect(6)
}.await()
}
finish(7)
}
@Test
fun testScope() = runTest {
intThreadLocal.set(42)
val mainThread = Thread.currentThread()
GlobalScope.async {
assertNull(intThreadLocal.get())
assertNotSame(mainThread, Thread.currentThread())
}.await()
GlobalScope.async(intThreadLocal.asContextElement()) {
assertEquals(42, intThreadLocal.get())
assertNotSame(mainThread, Thread.currentThread())
}.await()
}
@Test
fun testMissingThreadLocal() = runTest {
assertFailsWith<IllegalStateException> { stringThreadLocal.ensurePresent() }
assertFailsWith<IllegalStateException> { intThreadLocal.ensurePresent() }
}
}