Support DerivedState in SnapshotStateObserver

Updates SnapshotStateObserver to track derived state updates and invalidate the scopes only if result of the calculation has changed.

Test: new tests in SnapshotStateObserverTestsCommon.kt
Change-Id: I72f0e4771a1b74183f678e5ad477626ede701e35
diff --git a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/Snapshot.kt b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/Snapshot.kt
index 2e5d0da..9ee164d 100644
--- a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/Snapshot.kt
+++ b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/Snapshot.kt
@@ -215,8 +215,8 @@
 
     /**
      * Notify the snapshot that all objects created in this snapshot to this point should be
-     * considered initialized. If any state object is are modified passed this point it will
-     * appear as modified in the snapshot and any applicable snapshot write observer will be
+     * considered initialized. If any state object is modified after this point it will
+     * appear as modified in the snapshot. Any applicable snapshot write observer will be
      * called for the object and the object will be part of the a set of mutated objects sent to
      * any applicable snapshot apply observer.
      *
diff --git a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserver.kt b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserver.kt
index 382f190..55240ef 100644
--- a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserver.kt
+++ b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserver.kt
@@ -16,11 +16,14 @@
 
 package androidx.compose.runtime.snapshots
 
+import androidx.compose.runtime.DerivedState
 import androidx.compose.runtime.TestOnly
 import androidx.compose.runtime.collection.IdentityArrayMap
 import androidx.compose.runtime.collection.IdentityArraySet
 import androidx.compose.runtime.collection.IdentityScopeMap
 import androidx.compose.runtime.collection.mutableVectorOf
+import androidx.compose.runtime.observeDerivedStateRecalculations
+import androidx.compose.runtime.structuralEqualityPolicy
 
 /**
  * Helper class to efficiently observe snapshot state reads. See [observeReads] for more details.
@@ -121,7 +124,12 @@
             currentMap = scopeMap
             scopeMap.currentScope = scope
 
-            Snapshot.observe(readObserver, null, block)
+            observeDerivedStateRecalculations(
+                start = { scopeMap.deriveStateScopeCount++ },
+                done = { scopeMap.deriveStateScopeCount-- },
+            ) {
+                Snapshot.observe(readObserver, null, block)
+            }
         } finally {
             scopeMap.currentScope = oldScope
             currentMap = oldMap
@@ -229,6 +237,8 @@
          */
         var currentScope: Any? = null
 
+        var deriveStateScopeCount = 0
+
         /**
          * Values that have been read during the scope's [SnapshotStateObserver.observeReads].
          */
@@ -245,16 +255,34 @@
          */
         private val invalidated = hashSetOf<Any>()
 
+        private val dependencyToDerivedStates = IdentityScopeMap<DerivedState<*>>()
+
+        private val derivedStateToValue = HashMap<DerivedState<*>, Any?>()
+
         /**
          * Record that [value] was read in [currentScope].
          */
         fun recordRead(value: Any) {
+            if (deriveStateScopeCount > 0) {
+                // Reads coming from derivedStateOf block
+                return
+            }
+
             val scope = currentScope!!
             valueToScopes.add(value, scope)
+
             val recordedValues = scopeToValues[scope]
                 ?: IdentityArraySet<Any>().also { scopeToValues[scope] = it }
-
             recordedValues.add(value)
+
+            if (value is DerivedState<*>) {
+                dependencyToDerivedStates.removeScope(value)
+                val dependencies = value.dependencies
+                for (dependency in dependencies) {
+                    dependencyToDerivedStates.add(dependency, value)
+                }
+                derivedStateToValue[value] = value.currentValue
+            }
         }
 
         /**
@@ -263,7 +291,7 @@
         fun clearScopeObservations(scope: Any) {
             val recordedValues = scopeToValues[scope] ?: return
             recordedValues.fastForEach {
-                valueToScopes.remove(it, scope)
+                removeObservation(scope, it)
             }
             scopeToValues.remove(scope)
         }
@@ -276,19 +304,27 @@
                 val willRemove = predicate(scope)
                 if (willRemove) {
                     valueSet.fastForEach {
-                        valueToScopes.remove(it, scope)
+                        removeObservation(scope, it)
                     }
                 }
                 willRemove
             }
         }
 
+        private fun removeObservation(scope: Any, value: Any) {
+            valueToScopes.remove(value, scope)
+            if (value is DerivedState<*>) {
+                dependencyToDerivedStates.removeScope(value)
+            }
+        }
+
         /**
          * Clear all observations.
          */
         fun clear() {
             valueToScopes.clear()
             scopeToValues.clear()
+            dependencyToDerivedStates.clear()
         }
 
         /**
@@ -298,6 +334,23 @@
         fun recordInvalidation(changes: Set<Any>): Boolean {
             var hasValues = false
             for (value in changes) {
+                if (value in dependencyToDerivedStates) {
+                    // Find derived state that is invalidated by this change
+                    dependencyToDerivedStates.forEachScopeOf(value) { derivedState ->
+                        derivedState as DerivedState<Any?>
+                        val previousValue = derivedStateToValue[derivedState]
+                        val policy = derivedState.policy ?: structuralEqualityPolicy()
+
+                        // Invalidate only if currentValue is different than observed on read
+                        if (!policy.equivalent(derivedState.currentValue, previousValue)) {
+                            valueToScopes.forEachScopeOf(derivedState) { scope ->
+                                invalidated += scope
+                                hasValues = true
+                            }
+                        }
+                    }
+                }
+
                 valueToScopes.forEachScopeOf(value) { scope ->
                     invalidated += scope
                     hasValues = true
diff --git a/compose/runtime/runtime/src/commonTest/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserverTestsCommon.kt b/compose/runtime/runtime/src/commonTest/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserverTestsCommon.kt
index 8dc0656..6b9bd48 100644
--- a/compose/runtime/runtime/src/commonTest/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserverTestsCommon.kt
+++ b/compose/runtime/runtime/src/commonTest/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserverTestsCommon.kt
@@ -17,7 +17,10 @@
 package androidx.compose.runtime.snapshots
 
 import androidx.compose.runtime.MutableState
+import androidx.compose.runtime.derivedStateOf
 import androidx.compose.runtime.mutableStateOf
+import androidx.compose.runtime.referentialEqualityPolicy
+import androidx.compose.runtime.structuralEqualityPolicy
 import kotlin.test.Test
 import kotlin.test.assertEquals
 
@@ -437,6 +440,108 @@
         assertEquals(0, changes)
     }
 
+    @Test
+    fun derivedStateOfInvalidatesObserver() {
+        var changes = 0
+
+        runSimpleTest { stateObserver, state ->
+            val derivedState = derivedStateOf { state.value }
+
+            stateObserver.observeReads("scope", { changes++ }) {
+                // read
+                derivedState.value
+            }
+        }
+        assertEquals(1, changes)
+    }
+
+    @Test
+    fun derivedStateOfReferentialChangeDoesNotInvalidateObserver() {
+        var changes = 0
+
+        runSimpleTest { stateObserver, _ ->
+            val state = mutableStateOf(mutableListOf(42), referentialEqualityPolicy())
+            val derivedState = derivedStateOf { state.value }
+
+            stateObserver.observeReads("scope", { changes++ }) {
+                // read
+                derivedState.value
+            }
+
+            state.value = mutableListOf(42)
+        }
+        assertEquals(0, changes)
+    }
+
+    @Test
+    fun nestedDerivedStateOfInvalidatesObserver() {
+        var changes = 0
+
+        runSimpleTest { stateObserver, state ->
+            val derivedState = derivedStateOf { state.value }
+            val derivedState2 = derivedStateOf { derivedState.value }
+
+            stateObserver.observeReads("scope", { changes++ }) {
+                // read
+                derivedState2.value
+            }
+        }
+        assertEquals(1, changes)
+    }
+
+    @Test
+    fun derivedStateOfWithReferentialMutationPolicy() {
+        var changes = 0
+
+        runSimpleTest { stateObserver, _ ->
+            val state = mutableStateOf(mutableListOf(1), referentialEqualityPolicy())
+            val derivedState = derivedStateOf(referentialEqualityPolicy()) { state.value }
+
+            stateObserver.observeReads("scope", { changes++ }) {
+                // read
+                derivedState.value
+            }
+
+            state.value = mutableListOf(1)
+        }
+        assertEquals(1, changes)
+    }
+
+    @Test
+    fun derivedStateOfWithStructuralMutationPolicy() {
+        var changes = 0
+
+        runSimpleTest { stateObserver, _ ->
+            val state = mutableStateOf(mutableListOf(1), referentialEqualityPolicy())
+            val derivedState = derivedStateOf(structuralEqualityPolicy()) { state.value }
+
+            stateObserver.observeReads("scope", { changes++ }) {
+                // read
+                derivedState.value
+            }
+
+            state.value = mutableListOf(1)
+        }
+        assertEquals(0, changes)
+    }
+
+    @Test
+    fun readingDerivedStateAndDependencyInvalidates() {
+        var changes = 0
+
+        runSimpleTest { stateObserver, state ->
+            val derivedState = derivedStateOf { state.value >= 0 }
+
+            stateObserver.observeReads("scope", { changes++ }) {
+                // read derived state
+                derivedState.value
+                // read dependency
+                state.value
+            }
+        }
+        assertEquals(1, changes)
+    }
+
     private fun runSimpleTest(
         block: (modelObserver: SnapshotStateObserver, data: MutableState<Int>) -> Unit
     ) {
diff --git a/compose/runtime/runtime/src/jvmTest/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserverTestsJvm.kt b/compose/runtime/runtime/src/jvmTest/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserverTestsJvm.kt
index 380f3ab..02d3ff8 100644
--- a/compose/runtime/runtime/src/jvmTest/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserverTestsJvm.kt
+++ b/compose/runtime/runtime/src/jvmTest/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserverTestsJvm.kt
@@ -178,6 +178,7 @@
         val stateObserver = SnapshotStateObserver { it() }
         try {
             stateObserver.start()
+            Snapshot.notifyObjectsInitialized()
 
             val observer = object : (Any) -> Unit {
                 override fun invoke(affected: Any) {
@@ -194,11 +195,10 @@
                     }
                 }
             }
-
-            state.value++
+            // read with 0
             observer.readWithObservation()
-
-            Snapshot.notifyObjectsInitialized()
+            // increase to 1
+            state.value++
             Snapshot.sendApplyNotifications()
 
             assertEquals(1, changes)