Fix race condition when advancing the global snapshot

A race condition existed when advancing the global snapshot where
simultanious calls to `advanceGlobalSnapshot` would race against
each other to clame the current global snapshot. The "winner" of
the race would create a global snapshot that was never closed as
the "loser" would close the wrong snapshot.

Fixes: 236043352
Test: ./gradlew :compose:r:r:tDUT
Change-Id: Icab068bbfaa7c72d23dcc9dc7ad9ff806b03f97b
diff --git a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/Snapshot.kt b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/Snapshot.kt
index c89245c..26fd0ed 100644
--- a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/Snapshot.kt
+++ b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/Snapshot.kt
@@ -1734,8 +1734,9 @@
 }
 
 private fun <T> advanceGlobalSnapshot(block: (invalid: SnapshotIdSet) -> T): T {
-    val previousGlobalSnapshot = currentGlobalSnapshot.get()
+    var previousGlobalSnapshot = snapshotInitializer as GlobalSnapshot
     val result = sync {
+        previousGlobalSnapshot = currentGlobalSnapshot.get()
         takeNewGlobalSnapshot(previousGlobalSnapshot, block)
     }
 
diff --git a/compose/runtime/runtime/src/jvmTest/kotlin/androidx/compose/runtime/JvmCompositionTests.kt b/compose/runtime/runtime/src/jvmTest/kotlin/androidx/compose/runtime/JvmCompositionTests.kt
index 5869ebb..6e5c2d70 100644
--- a/compose/runtime/runtime/src/jvmTest/kotlin/androidx/compose/runtime/JvmCompositionTests.kt
+++ b/compose/runtime/runtime/src/jvmTest/kotlin/androidx/compose/runtime/JvmCompositionTests.kt
@@ -29,8 +29,11 @@
 import kotlinx.coroutines.test.UnconfinedTestDispatcher
 import kotlinx.coroutines.test.runTest
 import kotlin.concurrent.thread
+import kotlin.test.AfterTest
+import kotlin.test.BeforeTest
 import kotlinx.coroutines.delay
 import kotlin.test.Test
+import kotlin.test.assertEquals
 
 @Stable
 @OptIn(InternalComposeApi::class)
@@ -163,4 +166,14 @@
         value = 2
         expectChanges()
     }
+
+    private var count = 0
+    @BeforeTest fun saveSnapshotCount() {
+        count = Snapshot.openSnapshotCount()
+    }
+
+    @AfterTest fun checkSnapshotCount() {
+        val afterCount = Snapshot.openSnapshotCount()
+        assertEquals(count, afterCount, "A snapshot was left open after the test")
+    }
 }