blob: 59a68e69133b3b2e22b6dc02862ff7f915925a99 [file] [log] [blame]
/*
* Copyright 2021 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
@file:JvmName("SnapshotStateKt")
@file:JvmMultifileClass
package androidx.compose.runtime
import androidx.compose.runtime.collection.IdentityArrayMap
import androidx.compose.runtime.collection.MutableVector
import androidx.compose.runtime.collection.mutableVectorOf
import androidx.compose.runtime.snapshots.Snapshot
import androidx.compose.runtime.snapshots.StateObject
import androidx.compose.runtime.snapshots.StateRecord
import androidx.compose.runtime.snapshots.current
import androidx.compose.runtime.snapshots.newWritableRecord
import androidx.compose.runtime.snapshots.sync
import androidx.compose.runtime.snapshots.withCurrent
// Explicit imports for these needed in common source sets.
import kotlin.jvm.JvmName
import kotlin.jvm.JvmMultifileClass
import kotlin.math.min
/**
 * A [State] whose value is computed from one or more other states.
 *
 * @see derivedStateOf
 */
internal interface DerivedState<T> : State<T> {

    /**
     * The derived value, obtained without notifying read observers that it was read.
     */
    val currentValue: T

    /**
     * The objects that were read while producing [value] or [currentValue].
     *
     * Callers can consult this array when a [StateObject] shows up in an apply observer's set
     * of changed objects, to decide whether that change may affect this derived state's value.
     *
     * NOTE(review): implementations may hand back a backing array that contains unused `null`
     * slots — confirm before assuming every entry is non-null.
     */
    val dependencies: Array<Any?>

    /**
     * Mutation policy consulted when a dependency changes, deciding how the newly produced value
     * is handled. A `null` policy means an update is triggered regardless of the produced value,
     * and observers are responsible for invalidating correctly on their own.
     */
    val policy: SnapshotMutationPolicy<T>?
}
/**
 * Depth of nested derived state calculations, held in a [SnapshotThreadLocal].
 *
 * An unset (`null`) value is treated as 0, meaning no calculation is in progress. Each running
 * calculation raises it by one, and dependency reads are recorded relative to it.
 */
private val calculationBlockNestedLevel: SnapshotThreadLocal<Int> = SnapshotThreadLocal()
/**
 * [StateObject] backing [derivedStateOf].
 *
 * Caches the result of [calculation] in snapshot-aware [ResultRecord]s. A cached result is reused
 * while the hash of its direct dependencies' records still matches. Otherwise [calculation] is
 * run again and every [StateObject] it reads is recorded as a dependency.
 *
 * @param calculation produces the derived value.
 * @param policy when non-null and [SnapshotMutationPolicy.equivalent] considers a recalculated
 *   result equal to the cached one, the existing record is kept instead of writing a new one.
 */
private class DerivedSnapshotState<T>(
    private val calculation: () -> T,
    override val policy: SnapshotMutationPolicy<T>?
) : StateObject, DerivedState<T> {
    private var first: ResultRecord<T> = ResultRecord()

    /**
     * Snapshot record holding one cached calculation result together with the dependency
     * bookkeeping needed to tell whether it is still valid.
     */
    class ResultRecord<T> : StateRecord() {
        companion object {
            /** Sentinel stored in [result] until a calculation has completed. */
            val Unset = Any()
        }

        /**
         * Each [StateObject] read by the calculation, mapped to the shallowest nesting level at
         * which it was read (1 = read directly by this derived state's own calculation).
         */
        var dependencies: IdentityArrayMap<StateObject, Int>? = null

        /** The cached result, or [Unset] if nothing has been calculated yet. */
        var result: Any? = Unset

        /** [readableHash] captured when this record's dependencies were last stored. */
        var resultHash: Int = 0

        override fun assign(value: StateRecord) {
            @Suppress("UNCHECKED_CAST")
            val other = value as ResultRecord<T>
            dependencies = other.dependencies
            result = other.result
            resultHash = other.resultHash
        }

        override fun create(): StateRecord = ResultRecord<T>()

        /**
         * `true` when a result has been calculated and the current [readableHash] of the direct
         * dependencies in [snapshot] still equals [resultHash].
         */
        fun isValid(derivedState: DerivedState<*>, snapshot: Snapshot): Boolean =
            result !== Unset && resultHash == readableHash(derivedState, snapshot)

        /**
         * Hashes the identity and snapshot id of the record visible in [snapshot] for every
         * direct (level 1) dependency. Dependencies that are themselves derived states are
         * brought up to date first, so their latest records are the ones hashed.
         */
        fun readableHash(derivedState: DerivedState<*>, snapshot: Snapshot): Int {
            var hash = 7
            val dependencies = sync { dependencies }
            if (dependencies != null) {
                notifyObservers(derivedState) {
                    dependencies.forEach { stateObject, readLevel ->
                        // Only direct reads participate; deeper reads are covered by the
                        // intermediate derived states' own records.
                        if (readLevel != 1) {
                            return@forEach
                        }
                        // Find the first record without triggering an observer read.
                        val record = if (stateObject is DerivedSnapshotState<*>) {
                            // eagerly access the parent derived states without recording the
                            // read
                            // that way we can be sure derived states in deps were recalculated,
                            // and are updated to the last values
                            stateObject.current(snapshot)
                        } else {
                            current(stateObject.firstStateRecord, snapshot)
                        }
                        hash = 31 * hash + identityHashCode(record)
                        hash = 31 * hash + record.snapshotId
                    }
                }
            }
            return hash
        }
    }

    /**
     * Get current record in snapshot. Forces recalculation if record is invalid to refresh
     * state value.
     *
     * @return latest state record for the derived state.
     */
    fun current(snapshot: Snapshot): StateRecord =
        currentRecord(current(first, snapshot), snapshot, false, calculation)

    /**
     * Returns a valid [ResultRecord] for [snapshot], running [calculation] again when [readable]
     * is no longer valid.
     *
     * @param readable the record currently readable in [snapshot].
     * @param forceDependencyReads when [readable] is still valid, re-report its dependencies to
     *   [snapshot]'s read observer so observers subscribe to them as if the calculation had run.
     * @throws IllegalStateException if [calculation] reads this derived state.
     */
    private fun currentRecord(
        readable: ResultRecord<T>,
        snapshot: Snapshot,
        forceDependencyReads: Boolean,
        calculation: () -> T
    ): ResultRecord<T> {
        if (readable.isValid(this, snapshot)) {
            // If the dependency is not recalculated, emulate nested state reads
            // for correct invalidation later
            if (forceDependencyReads) {
                notifyObservers(this) {
                    val dependencies = readable.dependencies
                    val invalidationNestedLevel = calculationBlockNestedLevel.get() ?: 0
                    try {
                        dependencies?.forEach { dependency, nestedLevel ->
                            calculationBlockNestedLevel.set(nestedLevel + invalidationNestedLevel)
                            snapshot.readObserver?.invoke(dependency)
                        }
                    } finally {
                        // Restore the level even if a read observer throws, so later
                        // calculations do not start from an inflated depth.
                        calculationBlockNestedLevel.set(invalidationNestedLevel)
                    }
                }
            }
            return readable
        }
        val nestedCalculationLevel = calculationBlockNestedLevel.get() ?: 0
        val newDependencies = IdentityArrayMap<StateObject, Int>()
        val result = notifyObservers(this) {
            calculationBlockNestedLevel.set(nestedCalculationLevel + 1)
            try {
                Snapshot.observe(
                    {
                        if (it === this)
                            error("A derived state calculation cannot read itself")
                        if (it is StateObject) {
                            // Non-null: the level was set above before `calculation` started.
                            val readNestedLevel = calculationBlockNestedLevel.get()!!
                            // Keep the shallowest level at which the object was read.
                            newDependencies[it] = min(
                                readNestedLevel - nestedCalculationLevel,
                                newDependencies[it] ?: Int.MAX_VALUE
                            )
                        }
                    },
                    null, calculation
                )
            } finally {
                // Restore the level even if `calculation` throws; otherwise later top-level
                // calculations would look nested and skip `notifyObjectsInitialized` below.
                calculationBlockNestedLevel.set(nestedCalculationLevel)
            }
        }
        val record = sync {
            val currentSnapshot = Snapshot.current
            if (
                readable.result !== ResultRecord.Unset &&
                @Suppress("UNCHECKED_CAST")
                policy?.equivalent(result, readable.result as T) == true
            ) {
                // Equivalent result: keep the existing record, refresh only its bookkeeping.
                readable.dependencies = newDependencies
                readable.resultHash = readable.readableHash(this, currentSnapshot)
                readable
            } else {
                val writable = first.newWritableRecord(this, currentSnapshot)
                writable.dependencies = newDependencies
                writable.resultHash = writable.readableHash(this, currentSnapshot)
                writable.result = result
                writable
            }
        }
        if (nestedCalculationLevel == 0) {
            Snapshot.notifyObjectsInitialized()
        }
        return record
    }

    override val firstStateRecord: StateRecord get() = first

    override fun prependStateRecord(value: StateRecord) {
        @Suppress("UNCHECKED_CAST")
        first = value as ResultRecord<T>
    }

    override val value: T
        get() {
            // Unlike most state objects, the record list of a derived state can change during a read
            // because reading updates the cache. To account for this, instead of calling readable,
            // which sends the read notification, the read observer is notified directly and current
            // value is used instead which doesn't notify. This allow the read observer to read the
            // value and only update the cache once.
            Snapshot.current.readObserver?.invoke(this)
            return first.withCurrent {
                @Suppress("UNCHECKED_CAST")
                currentRecord(it, Snapshot.current, true, calculation).result as T
            }
        }

    override val currentValue: T
        get() = first.withCurrent {
            @Suppress("UNCHECKED_CAST")
            currentRecord(it, Snapshot.current, false, calculation).result as T
        }

    override val dependencies: Array<Any?>
        get() = first.withCurrent {
            val record = currentRecord(it, Snapshot.current, false, calculation)
            // NOTE(review): `keys` looks like the map's backing array and may carry unused
            // trailing slots — confirm consumers tolerate `null` entries.
            record.dependencies?.keys ?: emptyArray()
        }

    override fun toString(): String = first.withCurrent {
        "DerivedState(value=${displayValue()})@${hashCode()}"
    }

    /**
     * A function used by the debugger to display the value of the current value of the mutable
     * state object without triggering read observers.
     */
    @Suppress("unused")
    val debuggerDisplayValue: T?
        @JvmName("getDebuggerDisplayValue")
        get() = first.withCurrent {
            @Suppress("UNCHECKED_CAST")
            if (it.isValid(this, Snapshot.current))
                it.result as T
            else null
        }

    /** Text of the cached result when valid, otherwise `"<Not calculated>"`; never recalculates. */
    private fun displayValue(): String {
        first.withCurrent {
            if (it.isValid(this, Snapshot.current)) {
                return it.result.toString()
            }
            return "<Not calculated>"
        }
    }
}
/**
 * Creates a [State] object whose [State.value] is the result of [calculation].
 *
 * The result is cached, so reading [State.value] repeatedly does not execute [calculation] each
 * time. Reading [State.value] does, however, cause every [State] object read during
 * [calculation] to be read in the current [Snapshot]. As a result, an observed context such as
 * a [Composable] function correctly subscribes to the derived state's inputs.
 *
 * A derived state created without a mutation policy triggers an update on every change to one of
 * its dependencies. To avoid invalidation when the recalculated value has not meaningfully
 * changed, supply a suitable [SnapshotMutationPolicy] through the [derivedStateOf] overload that
 * accepts one.
 *
 * @sample androidx.compose.runtime.samples.DerivedStateSample
 *
 * @param calculation the calculation to create the value this state object represents.
 */
fun <T> derivedStateOf(
    calculation: () -> T,
): State<T> {
    return DerivedSnapshotState(calculation = calculation, policy = null)
}
/**
 * Creates a [State] object whose [State.value] is the result of [calculation], using [policy] to
 * decide how recalculated results are handled.
 *
 * The result is cached, so reading [State.value] repeatedly does not execute [calculation] each
 * time. Reading [State.value] does, however, cause every [State] object read during
 * [calculation] to be read in the current [Snapshot]. As a result, an observed context such as
 * a [Composable] function correctly subscribes to the derived state's inputs.
 *
 * When [policy] reports a recalculated result as equivalent to the cached one, the existing
 * cached value is kept rather than a new one being written.
 *
 * @sample androidx.compose.runtime.samples.DerivedStateSample
 *
 * @param policy mutation policy to control when changes to the [calculation] result trigger update.
 * @param calculation the calculation to create the value this state object represents.
 */
fun <T> derivedStateOf(
    policy: SnapshotMutationPolicy<T>,
    calculation: () -> T,
): State<T> {
    return DerivedSnapshotState(calculation = calculation, policy = policy)
}
/** A `(start, done)` callback pair registered through [observeDerivedStateRecalculations]. */
private typealias DerivedStateObservers =
    Pair<(DerivedState<*>) -> Unit, (DerivedState<*>) -> Unit>

/**
 * Stack of callback pairs registered through [observeDerivedStateRecalculations], held in a
 * [SnapshotThreadLocal]. The vector is created on first registration.
 */
private val derivedStateObservers: SnapshotThreadLocal<MutableVector<DerivedStateObservers>> =
    SnapshotThreadLocal()
/**
 * Runs [block], calling every registered `start` callback with [derivedState] before it and every
 * `done` callback after it.
 *
 * Callbacks come from [observeDerivedStateRecalculations]. The `done` callbacks run even if
 * [block] throws.
 *
 * @return the value produced by [block].
 */
private inline fun <R> notifyObservers(derivedState: DerivedState<*>, block: () -> R): R {
    // Nullable rather than `?: MutableVector(0)`: this runs on every validity check and
    // recalculation, so skip allocating an empty vector when nothing is registered.
    val observers = derivedStateObservers.get()
    observers?.forEach { (start, _) -> start(derivedState) }
    return try {
        block()
    } finally {
        observers?.forEach { (_, done) -> done(derivedState) }
    }
}
/**
 * Observe the recalculations performed by any derived state that is recalculated during the
 * execution of [block]. [start] is called before a calculation starts and [done] is called
 * after the started calculation is complete.
 *
 * The callbacks also bracket a derived state's validity checks of its cached result and the
 * replaying of its dependency reads, not only full recalculations.
 *
 * @param start a lambda called before each derived state calculation that happens during [block].
 * @param done a lambda that is called after the state passed to [start] is recalculated.
 * @param block the block of code to observe.
 */
internal fun <R> observeDerivedStateRecalculations(
    start: (derivedState: State<*>) -> Unit,
    done: (derivedState: State<*>) -> Unit,
    block: () -> R
) {
    val observers = derivedStateObservers.get()
        ?: mutableVectorOf<DerivedStateObservers>().also { derivedStateObservers.set(it) }
    // Register before entering `try`: the `finally` pops the last entry, which must be the one
    // this call added — never an outer caller's observer if registration itself failed.
    observers.add(start to done)
    try {
        block()
    } finally {
        observers.removeAt(observers.lastIndex)
    }
}