Clean up derived state dependencies in composition

Fixes two cases when derived state dependencies were retained:

1) Derived state instance is no longer observed in composition.
2) Derived state instance is still observed in composition, but its dependencies has been changed since last observation.

As composition uses `dependency -> derived state` mapping, this fix relies on iterating over all map values with `.removeValueIf`. In the future, we can optimize the lookup by using reverse index to lookup dependencies for derived state.

Fixes: 230168389
Test: CompositionAndDerivedStateTests#changingTheDerivedStateInstanceShouldClearDependencies, CompositionAndDerivedStateTests#changingDerivedStateDependenciesShouldClearThem
Change-Id: Id175952fedae53d669d028e460c65d6d2d5d4174
diff --git a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/Composition.kt b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/Composition.kt
index 37ef908..e318d61 100644
--- a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/Composition.kt
+++ b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/Composition.kt
@@ -19,12 +19,12 @@
 
 import androidx.compose.runtime.collection.IdentityArrayMap
 import androidx.compose.runtime.collection.IdentityArraySet
-import kotlin.coroutines.CoroutineContext
-import kotlin.coroutines.EmptyCoroutineContext
 import androidx.compose.runtime.collection.IdentityScopeMap
 import androidx.compose.runtime.snapshots.fastAll
 import androidx.compose.runtime.snapshots.fastAny
 import androidx.compose.runtime.snapshots.fastForEach
+import kotlin.coroutines.CoroutineContext
+import kotlin.coroutines.EmptyCoroutineContext
 
 /**
  * A composition object is usually constructed for you, and returned from an API that
@@ -400,6 +400,11 @@
     private val derivedStates = IdentityScopeMap<DerivedState<*>>()
 
     /**
+     * Used for testing. Returns dependencies of derived states that are currently observed.
+     */
+    internal val derivedStateDependencies get() = derivedStates.values.filterNotNull()
+
+    /**
      * A list of changes calculated by [Composer] to be applied to the [Applier] and the
      * [SlotTable] to reflect the result of composition. This is a list of lambdas that need to
      * be invoked in order to produce the desired effects.
@@ -674,14 +679,20 @@
             observations.removeValueIf { scope ->
                 scope in conditionallyInvalidatedScopes || invalidated?.let { scope in it } == true
             }
+            cleanUpDerivedStateObservations()
             conditionallyInvalidatedScopes.clear()
         } else {
             invalidated?.let {
                 observations.removeValueIf { scope -> scope in it }
+                cleanUpDerivedStateObservations()
             }
         }
     }
 
+    private fun cleanUpDerivedStateObservations() {
+        derivedStates.removeValueIf { derivedValue -> derivedValue !in observations }
+    }
+
     override fun recordReadOf(value: Any) {
         // Not acquiring lock since this happens during composition with it already held
         if (!areChildrenComposing) {
@@ -691,6 +702,7 @@
 
                 // Record derived state dependency mapping
                 if (value is DerivedState<*>) {
+                    derivedStates.removeScope(value)
                     value.dependencies.forEach { dependency ->
                         derivedStates.add(dependency, value)
                     }
@@ -778,7 +790,7 @@
                 trace("Compose:unobserve") {
                     pendingInvalidScopes = false
                     observations.removeValueIf { scope -> !scope.valid }
-                    derivedStates.removeValueIf { derivedValue -> derivedValue !in observations }
+                    cleanUpDerivedStateObservations()
                 }
             }
         } finally {
@@ -906,6 +918,10 @@
         observations.remove(instance, scope)
     }
 
+    internal fun removeDerivedStateObservation(state: DerivedState<*>) {
+        derivedStates.removeScope(state)
+    }
+
     /**
      * This takes ownership of the current invalidations and sets up a new array map to hold the
      * new invalidations.
diff --git a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/RecomposeScopeImpl.kt b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/RecomposeScopeImpl.kt
index e2b0c73..36271aa 100644
--- a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/RecomposeScopeImpl.kt
+++ b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/RecomposeScopeImpl.kt
@@ -287,6 +287,7 @@
                             if (remove) {
                                 composition.removeObservation(instance, this)
                                 (instance as? DerivedState<*>)?.let {
+                                    composition.removeDerivedStateObservation(it)
                                     trackedDependencies?.let { dependencies ->
                                         dependencies.remove(it)
                                         if (dependencies.size == 0) {
diff --git a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/collection/IdentityScopeMap.kt b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/collection/IdentityScopeMap.kt
index 1a4aaf0..2bb6ec1 100644
--- a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/collection/IdentityScopeMap.kt
+++ b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/collection/IdentityScopeMap.kt
@@ -217,11 +217,27 @@
      * removed, that value is removed also.
      */
     inline fun removeValueIf(predicate: (scope: T) -> Boolean) {
+        removingScopes { scopeSet ->
+            scopeSet.removeValueIf(predicate)
+        }
+    }
+
+    /**
+     * Removes given scope from all sets. If all scopes for a given value are removed, that value
+     * is removed as well.
+     */
+    fun removeScope(scope: T) {
+        removingScopes { scopeSet ->
+            scopeSet.remove(scope)
+        }
+    }
+
+    private inline fun removingScopes(removalOperation: (IdentityArraySet<T>) -> Unit) {
         var destinationIndex = 0
         for (i in 0 until size) {
             val valueIndex = valueOrder[i]
             val set = scopeSets[valueIndex]!!
-            set.removeValueIf(predicate)
+            removalOperation(set)
             if (set.size > 0) {
                 if (destinationIndex != i) {
                     // We'll bubble-up the now-free key-order by swapping the index with the one
diff --git a/compose/runtime/runtime/src/commonTest/kotlin/androidx/compose/runtime/CompositionAndDerivedStateTests.kt b/compose/runtime/runtime/src/commonTest/kotlin/androidx/compose/runtime/CompositionAndDerivedStateTests.kt
index 98e83fb..c73242e 100644
--- a/compose/runtime/runtime/src/commonTest/kotlin/androidx/compose/runtime/CompositionAndDerivedStateTests.kt
+++ b/compose/runtime/runtime/src/commonTest/kotlin/androidx/compose/runtime/CompositionAndDerivedStateTests.kt
@@ -385,6 +385,79 @@
         val observed = (composition as? CompositionImpl)?.observedObjects ?: emptyList()
         assertEquals(2, observed.count())
     }
+
+    @Test
+    fun changingTheDerivedStateInstanceShouldClearDependencies() = compositionTest {
+        var reload by mutableStateOf(0)
+
+        compose {
+            val itemValue = remember(reload) {
+                derivedStateOf {
+                    reload
+                }
+            }
+
+            val items = remember(reload) {
+                derivedStateOf {
+                    List(10) { itemValue.value }
+                }
+            }
+
+            Text("List of size ${items.value.size}")
+        }
+
+        validate {
+            Text("List of size 10")
+        }
+
+        repeat(10) {
+            reload++
+            advance()
+        }
+
+        revalidate()
+
+        // Validate there are only 2 observed dependencies, one per each derived state
+        val observed = (composition as? CompositionImpl)?.derivedStateDependencies ?: emptyList()
+        assertEquals(2, observed.count())
+    }
+
+    @Test
+    fun changingDerivedStateDependenciesShouldClearThem() = compositionTest {
+        var reload by mutableStateOf(0)
+
+        compose {
+            val itemValue = remember(reload) {
+                derivedStateOf { 1 }
+            }
+
+            val intermediateState = rememberUpdatedState(itemValue)
+
+            val snapshot = remember {
+                derivedStateOf {
+                    List(10) { intermediateState.value.value }
+                }
+            }
+
+            Text("List of size ${snapshot.value.size}")
+        }
+
+        validate {
+            Text("List of size 10")
+        }
+
+        repeat(10) {
+            reload++
+            advance()
+        }
+
+        revalidate()
+
+        // Validate there are only 2 observed dependencies, one for intermediateState, one for itemValue
+        val observed = (composition as? CompositionImpl)?.derivedStateDependencies ?: emptyList()
+        println(observed)
+        assertEquals(2, observed.count())
+    }
 }
 
 @Composable