Handle live edit errors on save/restore

First live edit action triggers full composition reload. In cases when this reload caused a crash to occur, failed composition was left in invalid state (new group was not recorded). Further targeted invalidations of that group weren't applied, as the group didn't exist in the slot table.

This change handles full reload of composition similarly to initial composition, invalidating all groups and resetting content. This ensures that failed composition will be updated whenever related code change occurs.

Fixes: 243426376
Test: LiveEditApiTests
Change-Id: I4a2fe93ccfc5429b43d0e68fd0add3ee3985ff1f
diff --git a/compose/runtime/runtime/integration-tests/src/androidAndroidTest/kotlin/androidx/compose/runtime/LiveEditApiTests.kt b/compose/runtime/runtime/integration-tests/src/androidAndroidTest/kotlin/androidx/compose/runtime/LiveEditApiTests.kt
index 3e04f55..549814d 100644
--- a/compose/runtime/runtime/integration-tests/src/androidAndroidTest/kotlin/androidx/compose/runtime/LiveEditApiTests.kt
+++ b/compose/runtime/runtime/integration-tests/src/androidAndroidTest/kotlin/androidx/compose/runtime/LiveEditApiTests.kt
@@ -392,6 +392,36 @@
             assertThat(errors).hasSize(0)
         }
     }
+
+    @Test
+    @MediumTest
+    fun throwErrorOnReload_recoversAfterInvalidate() {
+        var shouldThrow = false
+        activity.show {
+            TestError { shouldThrow }
+        }
+
+        activity.waitForAFrame()
+
+        run {
+            shouldThrow = true
+            simulateHotReload(Unit)
+
+            val start = errorInvoked
+            var errors = compositionErrors()
+            assertThat(errors).hasSize(1)
+            assertThat(errors[0].first.message).isEqualTo("Test crash!")
+            assertThat(errors[0].second).isEqualTo(true)
+
+            shouldThrow = false
+            invalidateGroup(errorKey)
+
+            assertTrue("TestError should be invoked!", errorInvoked > start)
+
+            errors = compositionErrors()
+            assertThat(errors).hasSize(0)
+        }
+    }
 }
 
 const val someFunctionKey = -1580285603 // Extracted from .class file
diff --git a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/Composer.kt b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/Composer.kt
index 4e06f54..bc62e6c 100644
--- a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/Composer.kt
+++ b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/Composer.kt
@@ -1485,7 +1485,7 @@
         return if (!forceRecomposeScopes) {
             forceRecomposeScopes = true
             forciblyRecompose = true
-             true
+            true
         } else {
             false
         }
diff --git a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/Recomposer.kt b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/Recomposer.kt
index c4b90f1..725fb0d 100644
--- a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/Recomposer.kt
+++ b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/Recomposer.kt
@@ -261,6 +261,7 @@
         mutableMapOf<MovableContent<Any?>, MutableList<MovableContentStateReference>>()
     private val compositionValueStatesAvailable =
         mutableMapOf<MovableContentStateReference, MovableContentState>()
+    private var failedCompositions: MutableList<ControlledComposition>? = null
     private var workContinuation: CancellableContinuation<Unit>? = null
     private var concurrentCompositionsOutstanding = 0
     private var isClosed: Boolean = false
@@ -280,8 +281,10 @@
             compositionInvalidations.clear()
             compositionsAwaitingApply.clear()
             compositionValuesAwaitingInsert.clear()
+            failedCompositions = null
             workContinuation?.cancel()
             workContinuation = null
+            errorState = null
             return null
         }
 
@@ -362,8 +365,11 @@
                 .fastMap { HotReloadable(it).apply { clearContent() } }
         }
 
-        fun getAndResetErrorState(): RecomposerErrorState? =
-            this@Recomposer.getAndResetErrorState()
+        fun resetErrorState(): RecomposerErrorState? =
+            this@Recomposer.resetErrorState()
+
+        fun retryFailedCompositions() =
+            this@Recomposer.retryFailedCompositions()
     }
 
     private class HotReloadable(
@@ -380,19 +386,14 @@
             composition.composable = composable
         }
 
-        fun recompose(rootOnly: Boolean = true) {
-            if (rootOnly) {
-                if (composition.isRoot) {
-                    composition.setContent(composable)
-                }
-            } else {
+        fun recompose() {
+            if (composition.isRoot) {
                 composition.setContent(composable)
             }
         }
     }
 
     private class RecomposerErrorState(
-        val failedInitialComposition: HotReloadable?,
         override val recoverable: Boolean,
         override val cause: Exception
     ) : RecomposerErrorInfo
@@ -647,13 +648,22 @@
                 compositionValueStatesAvailable.clear()
 
                 errorState = RecomposerErrorState(
-                    failedInitialComposition = (failedInitialComposition as? CompositionImpl)?.let {
-                        HotReloadable(it)
-                    },
                     recoverable = recoverable,
                     cause = e
                 )
 
+                if (failedInitialComposition != null) {
+                    val failedCompositions = failedCompositions
+                        ?: mutableListOf<ControlledComposition>().also {
+                            failedCompositions = it
+                        }
+
+                    if (failedInitialComposition !in failedCompositions) {
+                        failedCompositions += failedInitialComposition
+                    }
+                    knownCompositions -= failedInitialComposition
+                }
+
                 deriveStateLocked()
             }
         } else {
@@ -661,7 +671,7 @@
         }
     }
 
-    private fun getAndResetErrorState(): RecomposerErrorState? {
+    private fun resetErrorState(): RecomposerErrorState? {
         val errorState = synchronized(stateLock) {
             val error = errorState
             if (error != null) {
@@ -673,6 +683,22 @@
         return errorState
     }
 
+    private fun retryFailedCompositions() {
+        synchronized(stateLock) {
+            val failedCompositions = failedCompositions ?: return
+
+            while (failedCompositions.isNotEmpty()) {
+                val composition = failedCompositions.removeLast()
+                if (composition !is CompositionImpl) continue
+
+                composition.invalidateAll()
+                composition.setContent(composition.composable)
+
+                if (errorState != null) break
+            }
+        }
+    }
+
     /**
      * Await the invalidation of any associated [Composer]s, recompose them, and apply their
      * changes to their associated [Composition]s if recomposition is successful.
@@ -945,9 +971,6 @@
             performInitialMovableContentInserts(composition)
         } catch (e: Exception) {
             processCompositionError(e, composition, recoverable = true)
-            synchronized(stateLock) {
-                knownCompositions -= composition
-            }
             return
         }
 
@@ -1250,8 +1273,8 @@
             // to ensure that we pause recompositions before this call.
             _hotReloadEnabled.set(true)
 
-            val errorStates = _runningRecomposers.value.map {
-                it.getAndResetErrorState()
+            _runningRecomposers.value.forEach {
+                it.resetErrorState()
             }
 
             @Suppress("UNCHECKED_CAST")
@@ -1259,11 +1282,8 @@
             holders.fastForEach { it.resetContent() }
             holders.fastForEach { it.recompose() }
 
-            errorStates.fastForEach {
-                 it?.failedInitialComposition?.let { c ->
-                     c.resetContent()
-                     c.recompose(rootOnly = false)
-                 }
+            _runningRecomposers.value.forEach {
+                it.retryFailedCompositions()
             }
         }
 
@@ -1274,14 +1294,11 @@
                     return@forEach
                 }
 
-                val errorState = it.getAndResetErrorState()
+                it.resetErrorState()
 
                 it.invalidateGroupsWithKey(key)
 
-                errorState?.failedInitialComposition?.let { c ->
-                    c.resetContent()
-                    c.recompose(rootOnly = false)
-                }
+                it.retryFailedCompositions()
             }
         }
 
@@ -1292,7 +1309,7 @@
 
         internal fun clearErrors() {
             _runningRecomposers.value.mapNotNull {
-                it.getAndResetErrorState()
+                it.resetErrorState()
             }
         }
     }
diff --git a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/SlotTable.kt b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/SlotTable.kt
index c5a0450..612736f 100644
--- a/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/SlotTable.kt
+++ b/compose/runtime/runtime/src/commonMain/kotlin/androidx/compose/runtime/SlotTable.kt
@@ -298,8 +298,8 @@
      * which we know will no longer have the same structure so we want to remove them before
      * recomposing.
      *
-     * Returns true if all the groups were successfully invalidated. If this returns fals then
-     * the a full composition must be foreced.
+     * Returns a list of groups if they were successfully invalidated. If this returns null then
+     * a full composition must be forced.
      */
     internal fun invalidateGroupsWithKey(target: Int): List<RecomposeScopeImpl>? {
         val anchors = mutableListOf<Anchor>()
@@ -939,7 +939,7 @@
      *  Skip a group. Must be called at the start of a group.
      */
     fun skipGroup(): Int {
-        require(emptyCount == 0) { "Cannot skip while in an empty region" }
+        runtimeCheck(emptyCount == 0) { "Cannot skip while in an empty region" }
         val count = if (groups.isNode(currentGroup)) 1 else groups.nodeCount(currentGroup)
         currentGroup += groups.groupSize(currentGroup)
         return count
@@ -949,7 +949,7 @@
      * Skip to the end of the current group.
      */
     fun skipToGroupEnd() {
-        require(emptyCount == 0) { "Cannot skip the enclosing group while in an empty region" }
+        runtimeCheck(emptyCount == 0) { "Cannot skip the enclosing group while in an empty region" }
         currentGroup = currentEnd
     }
 
@@ -957,7 +957,7 @@
      * Reposition the read to the group at [index].
      */
     fun reposition(index: Int) {
-        require(emptyCount == 0) { "Cannot reposition while in an empty region" }
+        runtimeCheck(emptyCount == 0) { "Cannot reposition while in an empty region" }
         currentGroup = index
         val parent = if (index < groupsSize) groups.parentAnchor(index) else -1
         this.parent = parent
@@ -976,7 +976,7 @@
         val newCurrentEnd = index + groups.groupSize(index)
         val current = currentGroup
         @Suppress("ConvertTwoComparisonsToRangeCheck")
-        require(current >= index && current <= newCurrentEnd) {
+        runtimeCheck(current >= index && current <= newCurrentEnd) {
             "Index $index is not a parent of $current"
         }
         this.parent = index
@@ -990,7 +990,9 @@
      */
     fun endGroup() {
         if (emptyCount == 0) {
-            require(currentGroup == currentEnd) { "endGroup() not called at the end of a group" }
+            runtimeCheck(currentGroup == currentEnd) {
+                "endGroup() not called at the end of a group"
+            }
             val parent = groups.parentAnchor(parent)
             this.parent = parent
             currentEnd = if (parent < 0)
@@ -1476,7 +1478,7 @@
      * currently started [parent].
      */
     fun advanceBy(amount: Int) {
-        require(amount >= 0) { "Cannot seek backwards" }
+        runtimeCheck(amount >= 0) { "Cannot seek backwards" }
         check(insertCount <= 0) { "Cannot call seek() while inserting" }
         if (amount == 0) return
         val index = currentGroup + amount
@@ -1532,7 +1534,7 @@
      * Enter the group at current without changing it. Requires not currently inserting.
      */
     fun startGroup() {
-        require(insertCount == 0) { "Key must be supplied when inserting" }
+        runtimeCheck(insertCount == 0) { "Key must be supplied when inserting" }
         startGroup(key = 0, objectKey = Composer.Empty, isNode = false, aux = Composer.Empty)
     }
 
@@ -1657,7 +1659,7 @@
             nodeCount = nodeCountStack.pop() + if (isNode) 1 else newNodes
             parent = groups.parent(groupIndex)
         } else {
-            require(currentGroup == currentGroupEnd) {
+            runtimeCheck(currentGroup == currentGroupEnd) {
                 "Expected to be at the end of a group"
             }
             // Update group length
@@ -1733,12 +1735,12 @@
      * group is reached.
      */
     fun ensureStarted(index: Int) {
-        require(insertCount <= 0) { "Cannot call ensureStarted() while inserting" }
+        runtimeCheck(insertCount <= 0) { "Cannot call ensureStarted() while inserting" }
         val parent = parent
         if (parent != index) {
             // The new parent a child of the current group.
             @Suppress("ConvertTwoComparisonsToRangeCheck")
-            require(index >= parent && index < currentGroupEnd) {
+            runtimeCheck(index >= parent && index < currentGroupEnd) {
                 "Started group at $index must be a subgroup of the group at $parent"
             }
 
@@ -1770,7 +1772,7 @@
      * Remove the current group. Returns if any anchors were in the group removed.
      */
     fun removeGroup(): Boolean {
-        require(insertCount == 0) { "Cannot remove group while inserting" }
+        runtimeCheck(insertCount == 0) { "Cannot remove group while inserting" }
         val oldGroup = currentGroup
         val oldSlot = currentSlot
         val count = skipGroup()
@@ -1813,8 +1815,8 @@
      * number of groups after the [currentGroup] left in the [parent] group.
      */
     fun moveGroup(offset: Int) {
-        require(insertCount == 0) { "Cannot move a group while inserting" }
-        require(offset >= 0) { "Parameter offset is out of bounds" }
+        runtimeCheck(insertCount == 0) { "Cannot move a group while inserting" }
+        runtimeCheck(offset >= 0) { "Parameter offset is out of bounds" }
         if (offset == 0) return
         val current = currentGroup
         val parent = parent
@@ -1827,7 +1829,7 @@
             groupToMove += groups.groupSize(
                 address = groupIndexToAddress(groupToMove)
             )
-            require(groupToMove <= parentEnd) { "Parameter offset is out of bounds" }
+            runtimeCheck(groupToMove <= parentEnd) { "Parameter offset is out of bounds" }
             count--
         }
 
@@ -2104,12 +2106,12 @@
      * This requires [writer] be inserting and this writer to not be inserting.
      */
     fun moveTo(anchor: Anchor, offset: Int, writer: SlotWriter): List<Anchor> {
-        require(writer.insertCount > 0)
-        require(insertCount == 0)
-        require(anchor.valid)
+        runtimeCheck(writer.insertCount > 0)
+        runtimeCheck(insertCount == 0)
+        runtimeCheck(anchor.valid)
         val location = anchorIndex(anchor) + offset
         val currentGroup = currentGroup
-        require(location in currentGroup until currentGroupEnd)
+        runtimeCheck(location in currentGroup until currentGroupEnd)
         val parent = parent(location)
         val size = groupSize(location)
         val nodes = if (isNode(location)) 1 else nodeCount(location)
@@ -2154,7 +2156,7 @@
      * @return a list of the anchors that were moved
      */
     fun moveFrom(table: SlotTable, index: Int): List<Anchor> {
-        require(insertCount > 0)
+        runtimeCheck(insertCount > 0)
 
         if (index == 0 && currentGroup == 0 && this.table.groupsSize == 0) {
             // Special case for moving the entire slot table into an empty table. This case occurs
@@ -3218,7 +3220,7 @@
 
 private fun IntArray.updateNodeCount(address: Int, value: Int) {
     @Suppress("ConvertTwoComparisonsToRangeCheck")
-    require(value >= 0 && value < NodeCount_Mask)
+    runtimeCheck(value >= 0 && value < NodeCount_Mask)
     this[address * Group_Fields_Size + GroupInfo_Offset] =
         (this[address * Group_Fields_Size + GroupInfo_Offset] and NodeCount_Mask.inv()) or value
 }
@@ -3241,7 +3243,7 @@
 // Slot count access
 private fun IntArray.groupSize(address: Int) = this[address * Group_Fields_Size + Size_Offset]
 private fun IntArray.updateGroupSize(address: Int, value: Int) {
-    require(value >= 0)
+    runtimeCheck(value >= 0)
     this[address * Group_Fields_Size + Size_Offset] = value
 }
 
diff --git a/compose/runtime/runtime/src/jvmTest/kotlin/androidx/compose/runtime/LiveEditTests.kt b/compose/runtime/runtime/src/jvmTest/kotlin/androidx/compose/runtime/LiveEditTests.kt
index 3e2426c..dcb954f 100644
--- a/compose/runtime/runtime/src/jvmTest/kotlin/androidx/compose/runtime/LiveEditTests.kt
+++ b/compose/runtime/runtime/src/jvmTest/kotlin/androidx/compose/runtime/LiveEditTests.kt
@@ -34,7 +34,7 @@
     @After
     fun tearDown() {
         clearCompositionErrors()
-        Recomposer.setHotReloadEnabled(true)
+        Recomposer.setHotReloadEnabled(false)
     }
 
     @Test