Dispose nested snapshots created from transparent snapshots

Adds a flag to transparent snapshots to "manage" wrapped snapshots which forces dispose of the wrapped snapshot whenever transparent one is disposed.
This flag is only enabled for the snapshots taken inside transparent snapshots, fixing memory leaks in certain conditions.

Fixes a minor bug where transparent snapshot wasn't receiving reads from nested snapshots as well.

Fixes: 239603305
Test: SnapshotTests#testNestedWithinTransparentSnapshotDisposedCorrectly

Change-Id: I62eddd279c8cf44b032d852d646c9ba21ad08a39
diff --git a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/Snapshot.kt b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/Snapshot.kt
index cbc5577..489920f 100644
--- a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/Snapshot.kt
+++ b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/Snapshot.kt
@@ -452,7 +452,8 @@
                             previousSnapshot = currentSnapshot as? MutableSnapshot,
                             specifiedReadObserver = readObserver,
                             specifiedWriteObserver = writeObserver,
-                            mergeParentObservers = true
+                            mergeParentObservers = true,
+                            ownsPreviousSnapshot = false
                         )
                     else if (readObserver == null) return block()
                     else currentSnapshot.takeNestedSnapshot(readObserver)
@@ -1434,7 +1435,8 @@
     private val previousSnapshot: MutableSnapshot?,
     internal val specifiedReadObserver: ((Any) -> Unit)?,
     internal val specifiedWriteObserver: ((Any) -> Unit)?,
-    private val mergeParentObservers: Boolean
+    private val mergeParentObservers: Boolean,
+    private val ownsPreviousSnapshot: Boolean
 ) : MutableSnapshot(
     INVALID_SNAPSHOT,
     SnapshotIdSet.EMPTY,
@@ -1454,6 +1456,9 @@
     override fun dispose() {
         // Explicitly don't call super.dispose()
         disposed = true
+        if (ownsPreviousSnapshot) {
+            previousSnapshot?.dispose()
+        }
     }
 
     override var id: Int
@@ -1486,7 +1491,8 @@
         return if (!mergeParentObservers) {
             createTransparentSnapshotWithNoParentReadObserver(
                 previousSnapshot = currentSnapshot.takeNestedSnapshot(null),
-                readObserver = readObserver
+                readObserver = mergedReadObserver,
+                ownsPreviousSnapshot = true
             )
         } else {
             currentSnapshot.takeNestedSnapshot(mergedReadObserver)
@@ -1508,7 +1514,8 @@
                 previousSnapshot = nestedSnapshot,
                 specifiedReadObserver = mergedReadObserver,
                 specifiedWriteObserver = mergedWriteObserver,
-                mergeParentObservers = false
+                mergeParentObservers = false,
+                ownsPreviousSnapshot = true
             )
         } else {
             currentSnapshot.takeNestedMutableSnapshot(
@@ -1532,7 +1539,8 @@
 internal class TransparentObserverSnapshot(
     private val previousSnapshot: Snapshot?,
     specifiedReadObserver: ((Any) -> Unit)?,
-    private val mergeParentObservers: Boolean
+    private val mergeParentObservers: Boolean,
+    private val ownsPreviousSnapshot: Boolean
 ) : Snapshot(
     INVALID_SNAPSHOT,
     SnapshotIdSet.EMPTY,
@@ -1552,6 +1560,9 @@
     override fun dispose() {
         // Explicitly don't call super.dispose()
         disposed = true
+        if (ownsPreviousSnapshot) {
+            previousSnapshot?.dispose()
+        }
     }
 
     override var id: Int
@@ -1580,8 +1591,9 @@
         val mergedReadObserver = mergedReadObserver(readObserver, this.readObserver)
         return if (!mergeParentObservers) {
             createTransparentSnapshotWithNoParentReadObserver(
-                previousSnapshot = currentSnapshot.takeNestedSnapshot(null),
-                readObserver = readObserver
+                currentSnapshot.takeNestedSnapshot(null),
+                mergedReadObserver,
+                ownsPreviousSnapshot = true
             )
         } else {
             currentSnapshot.takeNestedSnapshot(mergedReadObserver)
@@ -1599,18 +1611,21 @@
 private fun createTransparentSnapshotWithNoParentReadObserver(
     previousSnapshot: Snapshot?,
     readObserver: ((Any) -> Unit)? = null,
+    ownsPreviousSnapshot: Boolean = false
 ): Snapshot = if (previousSnapshot is MutableSnapshot || previousSnapshot == null) {
     TransparentObserverMutableSnapshot(
         previousSnapshot = previousSnapshot as? MutableSnapshot,
         specifiedReadObserver = readObserver,
         specifiedWriteObserver = null,
-        mergeParentObservers = false
+        mergeParentObservers = false,
+        ownsPreviousSnapshot = ownsPreviousSnapshot
     )
 } else {
     TransparentObserverSnapshot(
         previousSnapshot = previousSnapshot,
         specifiedReadObserver = readObserver,
-        mergeParentObservers = false
+        mergeParentObservers = false,
+        ownsPreviousSnapshot = ownsPreviousSnapshot
     )
 }
 
diff --git a/compose/runtime/runtime/src/commonTest/kotlin/androidx/compose/runtime/snapshots/SnapshotTests.kt b/compose/runtime/runtime/src/commonTest/kotlin/androidx/compose/runtime/snapshots/SnapshotTests.kt
index 8d279be..cdff066 100644
--- a/compose/runtime/runtime/src/commonTest/kotlin/androidx/compose/runtime/snapshots/SnapshotTests.kt
+++ b/compose/runtime/runtime/src/commonTest/kotlin/androidx/compose/runtime/snapshots/SnapshotTests.kt
@@ -936,6 +936,126 @@
         }
     }
 
+    @Test
+    fun testNestedWithinTransparentSnapshotDisposedCorrectly() {
+        val outerSnapshot = TransparentObserverSnapshot(
+            previousSnapshot = currentSnapshot(),
+            specifiedReadObserver = null,
+            mergeParentObservers = false,
+            ownsPreviousSnapshot = false
+        )
+
+        try {
+            outerSnapshot.enter {
+                val innerSnapshot = outerSnapshot.takeNestedSnapshot()
+
+                try {
+                    innerSnapshot.enter { }
+                } finally {
+                    innerSnapshot.dispose()
+                }
+            }
+        } finally {
+            outerSnapshot.dispose()
+        }
+    }
+
+    @Test
+    fun testNestedWithinTransparentMutableSnapshotDisposedCorrectly() {
+        val outerSnapshot = TransparentObserverMutableSnapshot(
+            previousSnapshot = currentSnapshot() as? MutableSnapshot,
+            specifiedReadObserver = null,
+            specifiedWriteObserver = null,
+            mergeParentObservers = false,
+            ownsPreviousSnapshot = false
+        )
+
+        try {
+            outerSnapshot.enter {
+                val innerSnapshot = outerSnapshot.takeNestedSnapshot()
+
+                try {
+                    innerSnapshot.enter { }
+                } finally {
+                    innerSnapshot.dispose()
+                }
+            }
+        } finally {
+            outerSnapshot.dispose()
+        }
+    }
+
+    @Test
+    fun testTransparentSnapshotMergedWithNestedReadObserver() {
+        var outerChanges = 0
+        var innerChanges = 0
+        val state by mutableStateOf(0)
+
+        val outerSnapshot = TransparentObserverSnapshot(
+            previousSnapshot = currentSnapshot(),
+            specifiedReadObserver = { outerChanges++ },
+            mergeParentObservers = false,
+            ownsPreviousSnapshot = false
+        )
+
+        try {
+            outerSnapshot.enter {
+                val innerSnapshot = outerSnapshot.takeNestedSnapshot(
+                    readObserver = { innerChanges++ }
+                )
+
+                try {
+                    innerSnapshot.enter {
+                        state // read
+                    }
+                } finally {
+                    innerSnapshot.dispose()
+                }
+            }
+        } finally {
+            outerSnapshot.dispose()
+        }
+
+        assertEquals(1, outerChanges)
+        assertEquals(1, innerChanges)
+    }
+
+    @Test
+    fun testTransparentMutableSnapshotMergedWithNestedReadObserver() {
+        var outerChanges = 0
+        var innerChanges = 0
+        val state by mutableStateOf(0)
+
+        val outerSnapshot = TransparentObserverMutableSnapshot(
+            previousSnapshot = currentSnapshot() as? MutableSnapshot,
+            specifiedReadObserver = { outerChanges++ },
+            specifiedWriteObserver = null,
+            mergeParentObservers = false,
+            ownsPreviousSnapshot = false
+        )
+
+        try {
+            outerSnapshot.enter {
+                val innerSnapshot = outerSnapshot.takeNestedSnapshot(
+                    readObserver = { innerChanges++ }
+                )
+
+                try {
+                    innerSnapshot.enter {
+                        state // read
+                    }
+                } finally {
+                    innerSnapshot.dispose()
+                }
+            }
+        } finally {
+            outerSnapshot.dispose()
+        }
+
+        assertEquals(1, outerChanges)
+        assertEquals(1, innerChanges)
+    }
+
     private var count = 0
 
     @BeforeTest