Update SnapshotStateObserver state cleanup

Updates SnapshotStateObserver to:
- Cleanup derived state dependencies when state leaves the scope instead of cleaning on every read.
- Check if derived state is observed in other scopes before removing the state.
- Use the same lambdas for observing derived state calculations instead of reallocating them for every observation.
- Reuse the same set for recording observed values per scope.

Test: SnapshotStateObserverTestsCommon#readingDerivedStateConditionallyInvalidatesBothScopes
Change-Id: I57bbd111d824d01511e7738e8335f67347096585
diff --git a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserver.kt b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserver.kt
index 3a7cf06..86dfb6b 100644
--- a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserver.kt
+++ b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserver.kt
@@ -17,6 +17,7 @@
 package androidx.compose.runtime.snapshots
 
 import androidx.compose.runtime.DerivedState
+import androidx.compose.runtime.State
 import androidx.compose.runtime.TestOnly
 import androidx.compose.runtime.collection.IdentityArrayMap
 import androidx.compose.runtime.collection.IdentityArraySet
@@ -125,8 +126,8 @@
             scopeMap.currentScope = scope
 
             observeDerivedStateRecalculations(
-                start = { scopeMap.deriveStateScopeCount++ },
-                done = { scopeMap.deriveStateScopeCount-- },
+                start = scopeMap.derivedStateEnterObserver,
+                done = scopeMap.derivedStateExitObserver
             ) {
                 Snapshot.observe(readObserver, null, block)
             }
@@ -237,7 +238,22 @@
          */
         var currentScope: Any? = null
 
-        var deriveStateScopeCount = 0
+        /**
+         * Start observer for derived state recalculation
+         */
+        val derivedStateEnterObserver: (State<*>) -> Unit = { deriveStateScopeCount++ }
+
+        /**
+         * Exit observer for derived state recalculation
+         */
+        val derivedStateExitObserver: (State<*>) -> Unit = { deriveStateScopeCount-- }
+
+        /**
+         * Counter for skipping reads inside derived states. If count is > 0, read happens inside
+         * a derived state.
+         * Reads for derived states are captured separately through [DerivedState.dependencies].
+         */
+        private var deriveStateScopeCount = 0
 
         /**
          * Values that have been read during the scope's [SnapshotStateObserver.observeReads].
@@ -255,9 +271,15 @@
          */
         private val invalidated = hashSetOf<Any>()
 
+        /**
+         * Invalidation index from state objects to derived states reading them.
+         */
         private val dependencyToDerivedStates = IdentityScopeMap<DerivedState<*>>()
 
-        private val derivedStateToValue = HashMap<DerivedState<*>, Any?>()
+        /**
+         * Last derived state value recorded during read.
+         */
+        private val recordedDerivedStateValues = HashMap<DerivedState<*>, Any?>()
 
         /**
          * Record that [value] was read in [currentScope].
@@ -276,14 +298,13 @@
             recordedValues.add(value)
 
             if (value is DerivedState<*>) {
-                dependencyToDerivedStates.removeScope(value)
                 val dependencies = value.dependencies
                 for (dependency in dependencies) {
                     // skip over dependency array
                     if (dependency == null) break
                     dependencyToDerivedStates.add(dependency, value)
                 }
-                derivedStateToValue[value] = value.currentValue
+                recordedDerivedStateValues[value] = value.currentValue
             }
         }
 
@@ -295,7 +316,9 @@
             recordedValues.fastForEach {
                 removeObservation(scope, it)
             }
-            scopeToValues.remove(scope)
+            // clearing the scope usually means that we are about to start observation again
+            // so it doesn't make sense to reallocate the set
+            recordedValues.clear()
         }
 
         /**
@@ -315,8 +338,9 @@
 
         private fun removeObservation(scope: Any, value: Any) {
             valueToScopes.remove(value, scope)
-            if (value is DerivedState<*>) {
+            if (value is DerivedState<*> && value !in valueToScopes) {
                 dependencyToDerivedStates.removeScope(value)
+                recordedDerivedStateValues.remove(value)
             }
         }
 
@@ -327,6 +351,7 @@
             valueToScopes.clear()
             scopeToValues.clear()
             dependencyToDerivedStates.clear()
+            recordedDerivedStateValues.clear()
         }
 
         /**
@@ -340,7 +365,7 @@
                     // Find derived state that is invalidated by this change
                     dependencyToDerivedStates.forEachScopeOf(value) { derivedState ->
                         derivedState as DerivedState<Any?>
-                        val previousValue = derivedStateToValue[derivedState]
+                        val previousValue = recordedDerivedStateValues[derivedState]
                         val policy = derivedState.policy ?: structuralEqualityPolicy()
 
                         // Invalidate only if currentValue is different than observed on read
diff --git a/compose/runtime/runtime/src/commonTest/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserverTestsCommon.kt b/compose/runtime/runtime/src/commonTest/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserverTestsCommon.kt
index 6b9bd48..efc63b90 100644
--- a/compose/runtime/runtime/src/commonTest/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserverTestsCommon.kt
+++ b/compose/runtime/runtime/src/commonTest/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserverTestsCommon.kt
@@ -542,6 +542,32 @@
         assertEquals(1, changes)
     }
 
+    @Test
+    fun readingDerivedStateConditionallyInvalidatesBothScopes() {
+        var changes = 0
+
+        runSimpleTest { stateObserver, state ->
+            val derivedState = derivedStateOf { state.value }
+
+            val onChange: (String) -> Unit = { changes++ }
+            stateObserver.observeReads("scope", onChange) {
+                // read derived state
+                derivedState.value
+            }
+
+            // read the same state in other scope
+            stateObserver.observeReads("other scope", onChange) {
+                derivedState.value
+            }
+
+            // stop observing state in other scope
+            stateObserver.observeReads("other scope", onChange) {
+                /* no-op */
+            }
+        }
+        assertEquals(1, changes)
+    }
+
     private fun runSimpleTest(
         block: (modelObserver: SnapshotStateObserver, data: MutableState<Int>) -> Unit
     ) {