Support DerivedState in SnapshotStateObserver
Updates SnapshotStateObserver to track derived state updates and invalidate the scopes only if result of the calculation has changed.
Test: new tests in SnapshotStateObserverTestsCommon.kt
Change-Id: I72f0e4771a1b74183f678e5ad477626ede701e35
diff --git a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/Snapshot.kt b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/Snapshot.kt
index 2e5d0da..9ee164d 100644
--- a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/Snapshot.kt
+++ b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/Snapshot.kt
@@ -215,8 +215,8 @@
/**
* Notify the snapshot that all objects created in this snapshot to this point should be
- * considered initialized. If any state object is are modified passed this point it will
- * appear as modified in the snapshot and any applicable snapshot write observer will be
+ * considered initialized. If any state object is modified after this point it will
+ * appear as modified in the snapshot. Any applicable snapshot write observer will be
* called for the object and the object will be part of the a set of mutated objects sent to
* any applicable snapshot apply observer.
*
diff --git a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserver.kt b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserver.kt
index 382f190..55240ef 100644
--- a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserver.kt
+++ b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserver.kt
@@ -16,11 +16,14 @@
package androidx.compose.runtime.snapshots
+import androidx.compose.runtime.DerivedState
import androidx.compose.runtime.TestOnly
import androidx.compose.runtime.collection.IdentityArrayMap
import androidx.compose.runtime.collection.IdentityArraySet
import androidx.compose.runtime.collection.IdentityScopeMap
import androidx.compose.runtime.collection.mutableVectorOf
+import androidx.compose.runtime.observeDerivedStateRecalculations
+import androidx.compose.runtime.structuralEqualityPolicy
/**
* Helper class to efficiently observe snapshot state reads. See [observeReads] for more details.
@@ -121,7 +124,12 @@
currentMap = scopeMap
scopeMap.currentScope = scope
- Snapshot.observe(readObserver, null, block)
+ observeDerivedStateRecalculations(
+ start = { scopeMap.deriveStateScopeCount++ },
+ done = { scopeMap.deriveStateScopeCount-- },
+ ) {
+ Snapshot.observe(readObserver, null, block)
+ }
} finally {
scopeMap.currentScope = oldScope
currentMap = oldMap
@@ -229,6 +237,8 @@
*/
var currentScope: Any? = null
+ var deriveStateScopeCount = 0
+
/**
* Values that have been read during the scope's [SnapshotStateObserver.observeReads].
*/
@@ -245,16 +255,34 @@
*/
private val invalidated = hashSetOf<Any>()
+ private val dependencyToDerivedStates = IdentityScopeMap<DerivedState<*>>()
+
+ private val derivedStateToValue = HashMap<DerivedState<*>, Any?>()
+
/**
* Record that [value] was read in [currentScope].
*/
fun recordRead(value: Any) {
+ if (deriveStateScopeCount > 0) {
+ // Reads coming from derivedStateOf block
+ return
+ }
+
val scope = currentScope!!
valueToScopes.add(value, scope)
+
val recordedValues = scopeToValues[scope]
?: IdentityArraySet<Any>().also { scopeToValues[scope] = it }
-
recordedValues.add(value)
+
+ if (value is DerivedState<*>) {
+ dependencyToDerivedStates.removeScope(value)
+ val dependencies = value.dependencies
+ for (dependency in dependencies) {
+ dependencyToDerivedStates.add(dependency, value)
+ }
+ derivedStateToValue[value] = value.currentValue
+ }
}
/**
@@ -263,7 +291,7 @@
fun clearScopeObservations(scope: Any) {
val recordedValues = scopeToValues[scope] ?: return
recordedValues.fastForEach {
- valueToScopes.remove(it, scope)
+ removeObservation(scope, it)
}
scopeToValues.remove(scope)
}
@@ -276,19 +304,27 @@
val willRemove = predicate(scope)
if (willRemove) {
valueSet.fastForEach {
- valueToScopes.remove(it, scope)
+ removeObservation(scope, it)
}
}
willRemove
}
}
+ private fun removeObservation(scope: Any, value: Any) {
+ valueToScopes.remove(value, scope)
+ if (value is DerivedState<*>) {
+ dependencyToDerivedStates.removeScope(value)
+ }
+ }
+
/**
* Clear all observations.
*/
fun clear() {
valueToScopes.clear()
scopeToValues.clear()
+ dependencyToDerivedStates.clear()
}
/**
@@ -298,6 +334,23 @@
fun recordInvalidation(changes: Set<Any>): Boolean {
var hasValues = false
for (value in changes) {
+ if (value in dependencyToDerivedStates) {
+ // Find derived state that is invalidated by this change
+ dependencyToDerivedStates.forEachScopeOf(value) { derivedState ->
+ derivedState as DerivedState<Any?>
+ val previousValue = derivedStateToValue[derivedState]
+ val policy = derivedState.policy ?: structuralEqualityPolicy()
+
+ // Invalidate only if currentValue is different than observed on read
+ if (!policy.equivalent(derivedState.currentValue, previousValue)) {
+ valueToScopes.forEachScopeOf(derivedState) { scope ->
+ invalidated += scope
+ hasValues = true
+ }
+ }
+ }
+ }
+
valueToScopes.forEachScopeOf(value) { scope ->
invalidated += scope
hasValues = true
diff --git a/compose/runtime/runtime/src/commonTest/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserverTestsCommon.kt b/compose/runtime/runtime/src/commonTest/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserverTestsCommon.kt
index 8dc0656..6b9bd48 100644
--- a/compose/runtime/runtime/src/commonTest/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserverTestsCommon.kt
+++ b/compose/runtime/runtime/src/commonTest/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserverTestsCommon.kt
@@ -17,7 +17,10 @@
package androidx.compose.runtime.snapshots
import androidx.compose.runtime.MutableState
+import androidx.compose.runtime.derivedStateOf
import androidx.compose.runtime.mutableStateOf
+import androidx.compose.runtime.referentialEqualityPolicy
+import androidx.compose.runtime.structuralEqualityPolicy
import kotlin.test.Test
import kotlin.test.assertEquals
@@ -437,6 +440,108 @@
assertEquals(0, changes)
}
+ @Test
+ fun derivedStateOfInvalidatesObserver() {
+ var changes = 0
+
+ runSimpleTest { stateObserver, state ->
+ val derivedState = derivedStateOf { state.value }
+
+ stateObserver.observeReads("scope", { changes++ }) {
+ // read
+ derivedState.value
+ }
+ }
+ assertEquals(1, changes)
+ }
+
+ @Test
+ fun derivedStateOfReferentialChangeDoesNotInvalidateObserver() {
+ var changes = 0
+
+ runSimpleTest { stateObserver, _ ->
+ val state = mutableStateOf(mutableListOf(42), referentialEqualityPolicy())
+ val derivedState = derivedStateOf { state.value }
+
+ stateObserver.observeReads("scope", { changes++ }) {
+ // read
+ derivedState.value
+ }
+
+ state.value = mutableListOf(42)
+ }
+ assertEquals(0, changes)
+ }
+
+ @Test
+ fun nestedDerivedStateOfInvalidatesObserver() {
+ var changes = 0
+
+ runSimpleTest { stateObserver, state ->
+ val derivedState = derivedStateOf { state.value }
+ val derivedState2 = derivedStateOf { derivedState.value }
+
+ stateObserver.observeReads("scope", { changes++ }) {
+ // read
+ derivedState2.value
+ }
+ }
+ assertEquals(1, changes)
+ }
+
+ @Test
+ fun derivedStateOfWithReferentialMutationPolicy() {
+ var changes = 0
+
+ runSimpleTest { stateObserver, _ ->
+ val state = mutableStateOf(mutableListOf(1), referentialEqualityPolicy())
+ val derivedState = derivedStateOf(referentialEqualityPolicy()) { state.value }
+
+ stateObserver.observeReads("scope", { changes++ }) {
+ // read
+ derivedState.value
+ }
+
+ state.value = mutableListOf(1)
+ }
+ assertEquals(1, changes)
+ }
+
+ @Test
+ fun derivedStateOfWithStructuralMutationPolicy() {
+ var changes = 0
+
+ runSimpleTest { stateObserver, _ ->
+ val state = mutableStateOf(mutableListOf(1), referentialEqualityPolicy())
+ val derivedState = derivedStateOf(structuralEqualityPolicy()) { state.value }
+
+ stateObserver.observeReads("scope", { changes++ }) {
+ // read
+ derivedState.value
+ }
+
+ state.value = mutableListOf(1)
+ }
+ assertEquals(0, changes)
+ }
+
+ @Test
+ fun readingDerivedStateAndDependencyInvalidates() {
+ var changes = 0
+
+ runSimpleTest { stateObserver, state ->
+ val derivedState = derivedStateOf { state.value >= 0 }
+
+ stateObserver.observeReads("scope", { changes++ }) {
+ // read derived state
+ derivedState.value
+ // read dependency
+ state.value
+ }
+ }
+ assertEquals(1, changes)
+ }
+
private fun runSimpleTest(
block: (modelObserver: SnapshotStateObserver, data: MutableState<Int>) -> Unit
) {
diff --git a/compose/runtime/runtime/src/jvmTest/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserverTestsJvm.kt b/compose/runtime/runtime/src/jvmTest/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserverTestsJvm.kt
index 380f3ab..02d3ff8 100644
--- a/compose/runtime/runtime/src/jvmTest/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserverTestsJvm.kt
+++ b/compose/runtime/runtime/src/jvmTest/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserverTestsJvm.kt
@@ -178,6 +178,7 @@
val stateObserver = SnapshotStateObserver { it() }
try {
stateObserver.start()
+ Snapshot.notifyObjectsInitialized()
val observer = object : (Any) -> Unit {
override fun invoke(affected: Any) {
@@ -194,11 +195,10 @@
}
}
}
-
- state.value++
+ // read with 0
observer.readWithObservation()
-
- Snapshot.notifyObjectsInitialized()
+ // increase to 1
+ state.value++
Snapshot.sendApplyNotifications()
assertEquals(1, changes)