Update SnapshotStateObserver state cleanup
Updates SnapshotStateObserver to:
- Cleanup derived state dependencies when state leaves the scope instead of cleaning on every read.
- Check if derived state is observed in other scopes before removing the state.
- Use the same lambdas for observing derived state calculations instead of reallocating them for every observation.
- Reuse the same set for recording observed values per scope.
Test: SnapshotStateObserverTestsCommon#readingDerivedStateConditionallyInvalidatesBothScopes
Change-Id: I57bbd111d824d01511e7738e8335f67347096585
diff --git a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserver.kt b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserver.kt
index 3a7cf06..86dfb6b 100644
--- a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserver.kt
+++ b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserver.kt
@@ -17,6 +17,7 @@
package androidx.compose.runtime.snapshots
import androidx.compose.runtime.DerivedState
+import androidx.compose.runtime.State
import androidx.compose.runtime.TestOnly
import androidx.compose.runtime.collection.IdentityArrayMap
import androidx.compose.runtime.collection.IdentityArraySet
@@ -125,8 +126,8 @@
scopeMap.currentScope = scope
observeDerivedStateRecalculations(
- start = { scopeMap.deriveStateScopeCount++ },
- done = { scopeMap.deriveStateScopeCount-- },
+ start = scopeMap.derivedStateEnterObserver,
+ done = scopeMap.derivedStateExitObserver
) {
Snapshot.observe(readObserver, null, block)
}
@@ -237,7 +238,22 @@
*/
var currentScope: Any? = null
- var deriveStateScopeCount = 0
+ /**
+ * Start observer for derived state recalculation
+ */
+ val derivedStateEnterObserver: (State<*>) -> Unit = { deriveStateScopeCount++ }
+
+ /**
+ * Exit observer for derived state recalculation
+ */
+ val derivedStateExitObserver: (State<*>) -> Unit = { deriveStateScopeCount-- }
+
+ /**
+ * Counter for skipping reads inside derived states. If count is > 0, read happens inside
+ * a derived state.
+ * Reads for derived states are captured separately through [DerivedState.dependencies].
+ */
+ private var deriveStateScopeCount = 0
/**
* Values that have been read during the scope's [SnapshotStateObserver.observeReads].
@@ -255,9 +271,15 @@
*/
private val invalidated = hashSetOf<Any>()
+ /**
+ * Invalidation index from state objects to derived states reading them.
+ */
private val dependencyToDerivedStates = IdentityScopeMap<DerivedState<*>>()
- private val derivedStateToValue = HashMap<DerivedState<*>, Any?>()
+ /**
+ * Last derived state value recorded during read.
+ */
+ private val recordedDerivedStateValues = HashMap<DerivedState<*>, Any?>()
/**
* Record that [value] was read in [currentScope].
@@ -276,14 +298,13 @@
recordedValues.add(value)
if (value is DerivedState<*>) {
- dependencyToDerivedStates.removeScope(value)
val dependencies = value.dependencies
for (dependency in dependencies) {
// skip over dependency array
if (dependency == null) break
dependencyToDerivedStates.add(dependency, value)
}
- derivedStateToValue[value] = value.currentValue
+ recordedDerivedStateValues[value] = value.currentValue
}
}
@@ -295,7 +316,9 @@
recordedValues.fastForEach {
removeObservation(scope, it)
}
- scopeToValues.remove(scope)
+ // clearing the scope usually means that we are about to start observation again
+ // so it doesn't make sense to reallocate the set
+ recordedValues.clear()
}
/**
@@ -315,8 +338,9 @@
private fun removeObservation(scope: Any, value: Any) {
valueToScopes.remove(value, scope)
- if (value is DerivedState<*>) {
+ if (value is DerivedState<*> && value !in valueToScopes) {
dependencyToDerivedStates.removeScope(value)
+ recordedDerivedStateValues.remove(value)
}
}
@@ -327,6 +351,7 @@
valueToScopes.clear()
scopeToValues.clear()
dependencyToDerivedStates.clear()
+ recordedDerivedStateValues.clear()
}
/**
@@ -340,7 +365,7 @@
// Find derived state that is invalidated by this change
dependencyToDerivedStates.forEachScopeOf(value) { derivedState ->
derivedState as DerivedState<Any?>
- val previousValue = derivedStateToValue[derivedState]
+ val previousValue = recordedDerivedStateValues[derivedState]
val policy = derivedState.policy ?: structuralEqualityPolicy()
// Invalidate only if currentValue is different than observed on read
diff --git a/compose/runtime/runtime/src/commonTest/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserverTestsCommon.kt b/compose/runtime/runtime/src/commonTest/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserverTestsCommon.kt
index 6b9bd48..efc63b90 100644
--- a/compose/runtime/runtime/src/commonTest/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserverTestsCommon.kt
+++ b/compose/runtime/runtime/src/commonTest/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserverTestsCommon.kt
@@ -542,6 +542,32 @@
assertEquals(1, changes)
}
+ @Test
+ fun readingDerivedStateConditionallyInvalidatesBothScopes() {
+ var changes = 0
+
+ runSimpleTest { stateObserver, state ->
+ val derivedState = derivedStateOf { state.value }
+
+ val onChange: (String) -> Unit = { changes++ }
+ stateObserver.observeReads("scope", onChange) {
+ // read derived state
+ derivedState.value
+ }
+
+ // read the same state in other scope
+ stateObserver.observeReads("other scope", onChange) {
+ derivedState.value
+ }
+
+ // stop observing state in other scope
+ stateObserver.observeReads("other scope", onChange) {
+ /* no-op */
+ }
+ }
+ assertEquals(1, changes)
+ }
+
private fun runSimpleTest(
block: (modelObserver: SnapshotStateObserver, data: MutableState<Int>) -> Unit
) {