Fixed memory leak on a race between adding/removing from lock-free list (#1845)

* The problem was introduced by #1565. When doing concurrent add+removeFirst the following can happen:
  - "add" completes, but has not correct prev pointer in next node yet
  - "removeFirst" removes freshly added element
  - "add" performs "finishAdd" that adjust prev pointer of the next node and thus removed element is pointed from the list again
* A separate LockFreeLinkedListAddRemoveStressTest is added that reproduces this problem.
* The old LockFreeLinkedListAtomicLFStressTest is refactored a bit.
diff --git a/kotlinx-coroutines-core/jvm/src/internal/LockFreeLinkedList.kt b/kotlinx-coroutines-core/jvm/src/internal/LockFreeLinkedList.kt
index 26fd169..f718df0 100644
--- a/kotlinx-coroutines-core/jvm/src/internal/LockFreeLinkedList.kt
+++ b/kotlinx-coroutines-core/jvm/src/internal/LockFreeLinkedList.kt
@@ -390,7 +390,7 @@
         final override fun updatedNext(affected: Node, next: Node): Any = next.removed()
 
         final override fun finishOnSuccess(affected: Node, next: Node) {
-            // Complete removal operation here. It bails out if next node is also removed and it becomes
+            // Complete removal operation here. It bails out if next node is also removed. It becomes
             // responsibility of the next's removes to call correctPrev which would help fix all the links.
             next.correctPrev(null)
         }
@@ -531,7 +531,12 @@
     private fun finishAdd(next: Node) {
         next._prev.loop { nextPrev ->
             if (this.next !== next) return // this or next was removed or another node added, remover/adder fixes up links
-            if (next._prev.compareAndSet(nextPrev, this)) return
+            if (next._prev.compareAndSet(nextPrev, this)) {
+                // This newly added node could have been removed, and the above CAS would have added it physically again.
+                // Let us double-check for this situation and correct if needed
+                if (isRemoved) next.correctPrev(null)
+                return
+            }
         }
     }
 
@@ -546,7 +551,7 @@
      * * When this node is removed. In this case there is no need to waste time on corrections, because
      *   remover of this node will ultimately call [correctPrev] on the next node and that will fix all
      *   the links from this node, too.
-     * * When [op] descriptor is not `null` and and operation descriptor that is [OpDescriptor.isEarlierThan]
+     * * When [op] descriptor is not `null` and operation descriptor that is [OpDescriptor.isEarlierThan]
      *   that current [op] is found while traversing the list. This `null` result will be translated
      *   by callers to [RETRY_ATOMIC].
      */
@@ -554,7 +559,7 @@
         val oldPrev = _prev.value
         var prev: Node = oldPrev
         var last: Node? = null // will be set so that last.next === prev
-        while (true) { // move the the left until first non-removed node
+        while (true) { // move the left until first non-removed node
             val prevNext: Any = prev._next.value
             when {
                 // fast path to find quickly find prev node when everything is properly linked
@@ -565,7 +570,7 @@
                         // Note: retry from scratch on failure to update prev
                         return correctPrev(op)
                     }
-                    return prev // return a correct prev
+                    return prev // return the correct prev
                 }
                 // slow path when we need to help remove operations
                 this.isRemoved -> return null // nothing to do, this node was removed, bail out asap to save time
diff --git a/kotlinx-coroutines-core/jvm/test/internal/LockFreeLinkedListAddRemoveStressTest.kt b/kotlinx-coroutines-core/jvm/test/internal/LockFreeLinkedListAddRemoveStressTest.kt
new file mode 100644
index 0000000..3229e66
--- /dev/null
+++ b/kotlinx-coroutines-core/jvm/test/internal/LockFreeLinkedListAddRemoveStressTest.kt
@@ -0,0 +1,56 @@
+/*
+ * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
+ */
+
+package kotlinx.coroutines.internal
+
+import kotlinx.atomicfu.*
+import kotlinx.coroutines.*
+import java.util.concurrent.*
+import kotlin.concurrent.*
+import kotlin.test.*
+
+class LockFreeLinkedListAddRemoveStressTest : TestBase() {
+    private class Node : LockFreeLinkedListNode()
+    
+    private val nRepeat = 100_000 * stressTestMultiplier
+    private val list = LockFreeLinkedListHead()
+    private val barrier = CyclicBarrier(3)
+    private val done = atomic(false)
+    private val removed = atomic(0)
+
+    @Test
+    fun testStressAddRemove() {
+        val threads = ArrayList<Thread>()
+        threads += testThread("adder") {
+            val node = Node()
+            list.addLast(node)
+            if (node.remove()) removed.incrementAndGet()
+        }
+        threads += testThread("remover") {
+            val node = list.removeFirstOrNull()
+            if (node != null) removed.incrementAndGet()
+        }
+        try {
+            for (i in 1..nRepeat) {
+                barrier.await()
+                barrier.await()
+                assertEquals(i, removed.value)
+                list.validate()
+            }
+        } finally {
+            done.value = true
+            barrier.await()
+            threads.forEach { it.join() }
+        }
+    }
+
+    private fun testThread(name: String, op: () -> Unit) = thread(name = name) {
+        while (true) {
+            barrier.await()
+            if (done.value) break
+            op()
+            barrier.await()
+        }
+    }
+}
\ No newline at end of file
diff --git a/kotlinx-coroutines-core/jvm/test/internal/LockFreeLinkedListAtomicLFStressTest.kt b/kotlinx-coroutines-core/jvm/test/internal/LockFreeLinkedListAtomicLFStressTest.kt
index b967c46..225b848 100644
--- a/kotlinx-coroutines-core/jvm/test/internal/LockFreeLinkedListAtomicLFStressTest.kt
+++ b/kotlinx-coroutines-core/jvm/test/internal/LockFreeLinkedListAtomicLFStressTest.kt
@@ -1,5 +1,5 @@
 /*
- * Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
+ * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
  */
 
 package kotlinx.coroutines.internal
@@ -19,9 +19,9 @@
 class LockFreeLinkedListAtomicLFStressTest {
     private val env = LockFreedomTestEnvironment("LockFreeLinkedListAtomicLFStressTest")
 
-    data class IntNode(val i: Int) : LockFreeLinkedListNode()
+    private data class Node(val i: Long) : LockFreeLinkedListNode()
 
-    private val TEST_DURATION_SEC = 5 * stressTestMultiplier
+    private val nSeconds = 5 * stressTestMultiplier
 
     private val nLists = 4
     private val nAdderThreads = 4
@@ -32,7 +32,8 @@
     private val undone = AtomicLong()
     private val missed = AtomicLong()
     private val removed = AtomicLong()
-    val error = AtomicReference<Throwable>()
+    private val error = AtomicReference<Throwable>()
+    private val index = AtomicLong()
 
     @Test
     fun testStress() {
@@ -42,7 +43,7 @@
                 when (rnd.nextInt(4)) {
                     0 -> {
                         val list = lists[rnd.nextInt(nLists)]
-                        val node = IntNode(threadId)
+                        val node = Node(index.incrementAndGet())
                         addLastOp(list, node)
                         randomSpinWaitIntermission()
                         tryRemoveOp(node)
@@ -50,7 +51,7 @@
                     1 -> {
                         // just to test conditional add
                         val list = lists[rnd.nextInt(nLists)]
-                        val node = IntNode(threadId)
+                        val node = Node(index.incrementAndGet())
                         addLastIfTrueOp(list, node)
                         randomSpinWaitIntermission()
                         tryRemoveOp(node)
@@ -58,7 +59,7 @@
                     2 -> {
                         // just to test failed conditional add and burn some time
                         val list = lists[rnd.nextInt(nLists)]
-                        val node = IntNode(threadId)
+                        val node = Node(index.incrementAndGet())
                         addLastIfFalseOp(list, node)
                     }
                     3 -> {
@@ -68,8 +69,8 @@
                         check(idx1 < idx2) // that is our global order
                         val list1 = lists[idx1]
                         val list2 = lists[idx2]
-                        val node1 = IntNode(threadId)
-                        val node2 = IntNode(-threadId - 1)
+                        val node1 = Node(index.incrementAndGet())
+                        val node2 = Node(index.incrementAndGet())
                         addTwoOp(list1, node1, list2, node2)
                         randomSpinWaitIntermission()
                         tryRemoveOp(node1)
@@ -91,13 +92,13 @@
                 removeTwoOp(list1, list2)
             }
         }
-        env.performTest(TEST_DURATION_SEC) {
-            val _undone = undone.get()
-            val _missed = missed.get()
-            val _removed = removed.get()
-            println("  Adders undone $_undone node additions")
-            println("  Adders missed $_missed nodes")
-            println("Remover removed $_removed nodes")
+        env.performTest(nSeconds) {
+            val undone = undone.get()
+            val missed = missed.get()
+            val removed = removed.get()
+            println("  Adders undone $undone node additions")
+            println("  Adders missed $missed nodes")
+            println("Remover removed $removed nodes")
         }
         error.get()?.let { throw it }
         assertEquals(missed.get(), removed.get())
@@ -106,19 +107,19 @@
         lists.forEach { it.validate() }
     }
 
-    private fun addLastOp(list: LockFreeLinkedListHead, node: IntNode) {
+    private fun addLastOp(list: LockFreeLinkedListHead, node: Node) {
         list.addLast(node)
     }
 
-    private fun addLastIfTrueOp(list: LockFreeLinkedListHead, node: IntNode) {
-        assertTrue(list.addLastIf(node, { true }))
+    private fun addLastIfTrueOp(list: LockFreeLinkedListHead, node: Node) {
+        assertTrue(list.addLastIf(node) { true })
     }
 
-    private fun addLastIfFalseOp(list: LockFreeLinkedListHead, node: IntNode) {
-        assertFalse(list.addLastIf(node, { false }))
+    private fun addLastIfFalseOp(list: LockFreeLinkedListHead, node: Node) {
+        assertFalse(list.addLastIf(node) { false })
     }
 
-    private fun addTwoOp(list1: LockFreeLinkedListHead, node1: IntNode, list2: LockFreeLinkedListHead, node2: IntNode) {
+    private fun addTwoOp(list1: LockFreeLinkedListHead, node1: Node, list2: LockFreeLinkedListHead, node2: Node) {
         val add1 = list1.describeAddLast(node1)
         val add2 = list2.describeAddLast(node2)
         val op = object : AtomicOp<Any?>() {
@@ -138,7 +139,7 @@
         assertTrue(op.perform(null) == null)
     }
 
-    private fun tryRemoveOp(node: IntNode) {
+    private fun tryRemoveOp(node: Node) {
         if (node.remove())
             undone.incrementAndGet()
         else
@@ -165,5 +166,4 @@
         val success = op.perform(null) == null
         if (success) removed.addAndGet(2)
     }
-
 }
\ No newline at end of file