Speed up scope observations cleanup in SnapshotStateObserver

Updates SnapshotStateObserver to use reverse index for faster scope observation cleanups. Previous implementation was using linear access, degrading performance when measuring deeper hierarchies.

Restructures logic inside state observer to support the additional structure + small updates to documentation and naming.

Test: perf: SnapshotStateObserverBenchmark
Test: correctness: SnapshotStateObserverTestsCommon, SnapshotStateObserverTestsJvm

Change-Id: I17fe9ccbeeda37fe1da5774b725c36b3310d0a89
diff --git a/compose/runtime/runtime/compose-runtime-benchmark/src/androidTest/java/androidx/compose/runtime/benchmark/SnapshotStateObserverBenchmark.kt b/compose/runtime/runtime/compose-runtime-benchmark/src/androidTest/java/androidx/compose/runtime/benchmark/SnapshotStateObserverBenchmark.kt
index b8a6cfd..25e96a6 100644
--- a/compose/runtime/runtime/compose-runtime-benchmark/src/androidTest/java/androidx/compose/runtime/benchmark/SnapshotStateObserverBenchmark.kt
+++ b/compose/runtime/runtime/compose-runtime-benchmark/src/androidTest/java/androidx/compose/runtime/benchmark/SnapshotStateObserverBenchmark.kt
@@ -24,16 +24,16 @@
 import androidx.compose.runtime.mutableStateOf
 import androidx.compose.runtime.snapshots.Snapshot
 import androidx.compose.runtime.snapshots.SnapshotStateObserver
-import androidx.test.filters.LargeTest
-import org.junit.After
-import org.junit.Before
-import org.junit.Test
-import org.junit.runner.RunWith
 import androidx.test.ext.junit.runners.AndroidJUnit4
+import androidx.test.filters.LargeTest
 import kotlin.math.pow
 import kotlin.math.roundToInt
 import kotlin.random.Random
+import org.junit.After
 import org.junit.Assume.assumeTrue
+import org.junit.Before
+import org.junit.Test
+import org.junit.runner.RunWith
 
 @LargeTest
 @RunWith(AndroidJUnit4::class)
@@ -116,6 +116,36 @@
     }
 
     @Test
+    fun deeplyNestedModelObservations() {
+        assumeTrue(Build.VERSION.SDK_INT != 29)
+        runOnUiThread {
+            val list = mutableListOf<Any>()
+            repeat(100) {
+                list += nodes[random.nextInt(ScopeCount)]
+            }
+
+            fun observeRecursive(index: Int) {
+                if (index == 100) return
+                val node = list[index]
+                stateObserver.observeReads(node, doNothing) {
+                    observeForNode(node)
+                    observeRecursive(index + 1)
+                }
+            }
+
+            benchmarkRule.measureRepeated {
+                runWithTimingDisabled {
+                    random = Random(0)
+                    nodes.forEach { node ->
+                        stateObserver.clear(node)
+                    }
+                }
+                observeRecursive(0)
+            }
+        }
+    }
+
+    @Test
     fun modelClear() {
         assumeTrue(Build.VERSION.SDK_INT != 29)
         runOnUiThread {
diff --git a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserver.kt b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserver.kt
index 516644a..890dd67 100644
--- a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserver.kt
+++ b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/snapshots/SnapshotStateObserver.kt
@@ -17,30 +17,30 @@
 package androidx.compose.runtime.snapshots
 
 import androidx.compose.runtime.TestOnly
+import androidx.compose.runtime.collection.IdentityArrayMap
+import androidx.compose.runtime.collection.IdentityArraySet
 import androidx.compose.runtime.collection.IdentityScopeMap
 import androidx.compose.runtime.collection.mutableVectorOf
-import androidx.compose.runtime.synchronized
 
+/**
+ * Helper class to efficiently observe snapshot state reads. See [observeReads] for more details.
+ *
+ * NOTE: This class is not thread-safe, so implementations should not reuse observer between
+ * different threads to avoid race conditions.
+ */
 @Suppress("NotCloseable") // we can't implement AutoCloseable from commonMain
 class SnapshotStateObserver(private val onChangedExecutor: (callback: () -> Unit) -> Unit) {
     private val applyObserver: (Set<Any>, Snapshot) -> Unit = { applied, _ ->
         var hasValues = false
 
-        synchronized(applyMaps) {
-            applyMaps.forEach { applyMap ->
-                val invalidated = applyMap.invalidated
-                val map = applyMap.map
-                for (value in applied) {
-                    map.forEachScopeOf(value) { scope ->
-                        invalidated += scope
-                        hasValues = true
-                    }
-                }
-            }
+        forEachScopeMap { scopeMap ->
+            hasValues = scopeMap.recordInvalidation(applied) || hasValues
         }
         if (hasValues) {
             onChangedExecutor {
-                callOnChanged()
+                forEachScopeMap { scopeMap ->
+                    scopeMap.notifyInvalidatedScopes()
+                }
             }
         }
     }
@@ -50,17 +50,29 @@
      */
     private val readObserver: (Any) -> Unit = { state ->
         if (!isPaused) {
-            synchronized(applyMaps) {
-                currentMap!!.addValue(state)
+            synchronized(observedScopeMaps) {
+                currentMap!!.recordRead(state)
             }
         }
     }
 
     /**
-     * List of all [ApplyMap]s. When [observeReads] is called, there will be a [ApplyMap]
-     * associated with its `onChanged` callback in this list. The list only grows.
+     * List of all [ObservedScopeMap]s. When [observeReads] is called, there will be a
+     * [ObservedScopeMap] associated with its [ObservedScopeMap.onChanged] callback in this list.
+     * The list only grows.
      */
-    private val applyMaps = mutableVectorOf<ApplyMap<*>>()
+    private val observedScopeMaps = mutableVectorOf<ObservedScopeMap>()
+
+    /**
+     * Helper for synchronized iteration over [observedScopeMaps]. All observed reads should
+     * happen on the same thread, but snapshots can be applied on a different thread, requiring
+     * synchronization.
+     */
+    private inline fun forEachScopeMap(block: (ObservedScopeMap) -> Unit) {
+        synchronized(observedScopeMaps) {
+            observedScopeMaps.forEach(block)
+        }
+    }
 
     /**
      * Method to call when unsubscribing from the apply observer.
@@ -68,50 +80,53 @@
     private var applyUnsubscribe: ObserverHandle? = null
 
     /**
-     * `true` when [withNoObservations] is called and read observations should no
-     * longer be considered invalidations for the `onCommit` callback.
+     * `true` when [withNoObservations] is called and read observations should not
+     * be considered invalidations for the current scope.
      */
     private var isPaused = false
 
     /**
-     * The [ApplyMap] that should be added to when a model is read during [observeReads].
+     * The [ObservedScopeMap] that should be added to when a model is read during [observeReads].
      */
-    private var currentMap: ApplyMap<*>? = null
+    private var currentMap: ObservedScopeMap? = null
 
     /**
      * Executes [block], observing state object reads during its execution.
      *
      * The [scope] and [onValueChangedForScope] are associated with any values that are read so
-     * that when those values change, [onValueChangedForScope] can be called with the [scope]
+     * that when those values change, [onValueChangedForScope] will be called with the [scope]
      * parameter.
      *
-     * Observation for [scope] will be paused when a new [observeReads] call is made or when
-     * [withNoObservations] is called.
+     * Observation can be paused with [Snapshot.withoutReadObservation].
      *
-     * Any previous observation with the given [scope] and [onValueChangedForScope] will be
-     * cleared when the [onValueChangedForScope] is called for [scope]. The
-     * [onValueChangedForScope] should trigger a new [observeReads] call to resubscribe to
-     * changes. They may also be cleared using [clearIf] or [clear].
+     * @param scope value associated with the observed scope.
+     * @param onValueChangedForScope is called with the [scope] when value read within [block]
+     * has been changed. For repeated observations, it is more performant to pass the same instance
+     * of the callback, as [observedScopeMaps] grows with each new callback instance.
+     * @param block to observe reads within.
      */
     fun <T : Any> observeReads(scope: T, onValueChangedForScope: (T) -> Unit, block: () -> Unit) {
-        val oldMap = currentMap
-        val oldPaused = isPaused
-        val applyMap = synchronized(applyMaps) {
+        val scopeMap = synchronized(observedScopeMaps) {
             ensureMap(onValueChangedForScope).also {
-                it.map.removeScope(scope)
+                it.clearScopeObservations(scope)
             }
         }
-        val oldScope = applyMap.currentScope
 
-        applyMap.currentScope = scope
-        currentMap = applyMap
-        isPaused = false
+        val oldPaused = isPaused
+        val oldMap = currentMap
+        val oldScope = scopeMap.currentScope
 
-        Snapshot.observe(readObserver, null, block)
+        try {
+            isPaused = false
+            currentMap = scopeMap
+            scopeMap.currentScope = scope
 
-        currentMap = oldMap
-        applyMap.currentScope = oldScope
-        isPaused = oldPaused
+            Snapshot.observe(readObserver, null, block)
+        } finally {
+            scopeMap.currentScope = oldScope
+            currentMap = oldMap
+            isPaused = oldPaused
+        }
     }
 
     /**
@@ -136,28 +151,22 @@
     }
 
     /**
-     * Clears all model read observations for a given [scope]. This clears values for all
-     * `onCommit` methods passed in [observeReads].
+     * Clears all state read observations for a given [scope]. This clears values for all
+     * `onValueChangedForScope` callbacks passed in [observeReads].
      */
     fun clear(scope: Any) {
-        synchronized(applyMaps) {
-            applyMaps.forEach { commitMap ->
-                commitMap.map.removeValueIf {
-                    it === scope
-                }
-            }
+        forEachScopeMap {
+            it.clearScopeObservations(scope)
         }
     }
 
     /**
-     * Remove observations using [predicate] to identify scope scopes to be removed. This is
+     * Remove observations using [predicate] to identify scopes to be removed. This is
      * used when a scope is no longer in the hierarchy and should not receive any callbacks.
      */
     fun clearIf(predicate: (scope: Any) -> Boolean) {
-        synchronized(applyMaps) {
-            applyMaps.forEach { applyMap ->
-                applyMap.map.removeValueIf(predicate)
-            }
+        forEachScopeMap { scopeMap ->
+            scopeMap.removeScopeIf(predicate)
         }
     }
 
@@ -188,79 +197,114 @@
      * Remove all observations.
      */
     fun clear() {
-        synchronized(applyMaps) {
-            applyMaps.forEach { applyMap ->
-                applyMap.map.clear()
-            }
+        forEachScopeMap { scopeMap ->
+            scopeMap.clear()
         }
     }
 
     /**
-     * Calls the `onChanged` callback for the given scopes.
-     */
-    private fun callOnChanged() {
-        applyMaps.forEach { applyMap ->
-            val scopes = applyMap.invalidated
-            if (scopes.isNotEmpty()) {
-                applyMap.callOnChanged(scopes)
-                scopes.clear()
-            }
-        }
-    }
-
-    /**
-     * Returns the [ApplyMap] within [applyMaps] associated with [onChanged] or a newly-
+     * Returns the [ObservedScopeMap] within [observedScopeMaps] associated with [onChanged] or a newly-
      * inserted one if it doesn't exist.
      *
      * Must be called inside a synchronized block.
      */
-    private fun <T : Any> ensureMap(onChanged: (T) -> Unit): ApplyMap<T> {
-        val index = applyMaps.indexOfFirst { it.onChanged === onChanged }
-        if (index == -1) {
-            val commitMap = ApplyMap(onChanged)
-            applyMaps += commitMap
-            return commitMap
+    @Suppress("UNCHECKED_CAST")
+    private fun <T : Any> ensureMap(onChanged: (T) -> Unit): ObservedScopeMap {
+        val scopeMap = observedScopeMaps.firstOrNull { it.onChanged === onChanged }
+        if (scopeMap == null) {
+            val map = ObservedScopeMap(onChanged as ((Any) -> Unit))
+            observedScopeMaps += map
+            return map
         }
-        @Suppress("UNCHECKED_CAST")
-        return applyMaps[index] as ApplyMap<T>
+        return scopeMap
     }
 
     /**
-     * Used to tie an [onChanged] to its scope by type. This works around some difficulties in
-     * unchecked casts with kotlin.
+     * Connects observed values to scopes for each [onChanged] callback.
      */
     @Suppress("UNCHECKED_CAST")
-    private class ApplyMap<T : Any>(val onChanged: (T) -> Unit) {
+    private class ObservedScopeMap(val onChanged: (Any) -> Unit) {
         /**
-         * Map (key = model, value = scope). These are the models that have been
-         * read during the scope's [SnapshotStateObserver.observeReads].
+         * Currently observed scope.
          */
-        val map = IdentityScopeMap<T>()
+        var currentScope: Any? = null
 
         /**
-         * Scopes that were invalidated. This and cleared during the [applyObserver] call.
+         * Values that have been read during the scope's [SnapshotStateObserver.observeReads].
          */
-        val invalidated = hashSetOf<Any>()
+        private val valueToScopes = IdentityScopeMap<Any>()
 
         /**
-         * Current scope that adds to [map] will use.
+         * Reverse index (scope -> values) for faster scope invalidation.
          */
-        var currentScope: T? = null
+        private val scopeToValues: IdentityArrayMap<Any, IdentityArraySet<Any>> =
+            IdentityArrayMap()
 
         /**
-         * Adds [value]/[currentScope] to the [map].
+         * Scopes that were invalidated during previous apply step.
          */
-        fun addValue(value: Any) {
-            map.add(value, currentScope!!)
+        private val invalidated = hashSetOf<Any>()
+
+        /**
+         * Record that [value] was read in [currentScope].
+         */
+        fun recordRead(value: Any) {
+            val scope = currentScope!!
+            valueToScopes.add(value, scope)
+            val recordedValues = scopeToValues[scope]
+                ?: IdentityArraySet<Any>().also { scopeToValues[scope] = it }
+
+            recordedValues.add(value)
         }
 
         /**
-         * Calls the `onCommit` callback for scopes affected by the given committed values.
+         * Clear observations for [scope].
          */
-        fun callOnChanged(scopes: Collection<Any>) {
-            scopes.forEach { scope ->
-                onChanged(scope as T)
+        fun clearScopeObservations(scope: Any) {
+            val recordedValues = scopeToValues[scope] ?: return
+            recordedValues.forEach {
+                valueToScopes.remove(it, scope)
             }
+            scopeToValues.remove(scope)
+        }
+
+        /**
+         * Remove observations in scopes matching [predicate].
+         */
+        inline fun removeScopeIf(predicate: (scope: Any) -> Boolean) {
+            valueToScopes.removeValueIf(predicate)
+            scopeToValues.removeIf { scope, _ -> predicate(scope) }
+        }
+
+        /**
+         * Clear all observations.
+         */
+        fun clear() {
+            valueToScopes.clear()
+            scopeToValues.clear()
+        }
+
+        /**
+         * Record scope invalidation for given set of values.
+         * @return whether any scopes observe changed values
+         */
+        fun recordInvalidation(changes: Set<Any>): Boolean {
+            var hasValues = false
+            for (value in changes) {
+                valueToScopes.forEachScopeOf(value) { scope ->
+                    invalidated += scope
+                    hasValues = true
+                }
+            }
+            return hasValues
+        }
+
+        /**
+         * Call [onChanged] for previously invalidated scopes.
+         */
+        fun notifyInvalidatedScopes() {
+            invalidated.forEach(onChanged)
+            invalidated.clear()
         }
     }
 }