// blob: 8a3a0b9189a5e64657e3cef8d17776cb6130e212 [file] [log] [blame]
use core::cell::UnsafeCell;
use core::fmt;
use core::sync::atomic::{AtomicUsize, Ordering};
use core::task::Waker;
use crate::raw::TaskVTable;
use crate::state::*;
use crate::utils::abort_on_panic;
/// The header of a task.
///
/// This header is stored in memory at the beginning of the heap-allocated task.
///
/// The `state` word and the `awaiter` slot work together: the `NOTIFYING` and `REGISTERING`
/// bits in `state` are used by `Header::take` and `Header::register` to decide which thread
/// may touch `awaiter` at any given moment.
pub(crate) struct Header {
/// Current state of the task.
///
/// Contains flags representing the current state and the reference count.
///
/// The flags are tested by masking (`state & FLAG != 0`), and the reference count is read as
/// `state / REFERENCE`.
pub(crate) state: AtomicUsize,
/// The task that is blocked on the `Task` handle.
///
/// This waker needs to be woken up once the task completes or is closed.
///
/// The cell has no synchronization of its own; `Header::take` and `Header::register` only
/// access it while holding the `NOTIFYING` or `REGISTERING` bit in `state`, respectively.
pub(crate) awaiter: UnsafeCell<Option<Waker>>,
/// The virtual table.
///
/// In addition to the actual waker virtual table, it also contains pointers to several other
/// methods necessary for bookkeeping the heap-allocated task.
pub(crate) vtable: &'static TaskVTable,
}
impl Header {
/// Notifies the awaiter blocked on this task.
///
/// Takes the registered awaiter out of the header (see `take`) and wakes it by value, so the
/// awaiter slot is empty afterwards.
///
/// If the awaiter is the same as the current waker, it will not be notified.
#[inline]
pub(crate) fn notify(&self, current: Option<&Waker>) {
if let Some(w) = self.take(current) {
// The wake call runs user-provided waker code, so it is wrapped in `abort_on_panic`.
abort_on_panic(|| w.wake());
}
}
/// Takes the awaiter blocked on this task.
///
/// If there is no awaiter or if it is the same as the current waker, returns `None`.
///
/// `None` is also returned, without touching the awaiter slot, when the state already had
/// `NOTIFYING` or `REGISTERING` set. In that case the `NOTIFYING` bit set here is left in place:
/// a concurrent `register` observes it when finishing and wakes the new awaiter itself, and a
/// concurrent `take` clears it when it is done.
#[inline]
pub(crate) fn take(&self, current: Option<&Waker>) -> Option<Waker> {
// Set the bit indicating that the task is notifying its awaiter.
// `fetch_or` returns the state as it was *before* the bit was set.
let state = self.state.fetch_or(NOTIFYING, Ordering::AcqRel);
// If the task was not notifying or registering an awaiter...
if state & (NOTIFYING | REGISTERING) == 0 {
// Take the waker out.
// SAFETY (review note): this thread is the one that flipped `NOTIFYING` from unset to
// set while `REGISTERING` was unset; `register` backs off when it sees `NOTIFYING` and
// other `take` calls skip this branch. This assumes `awaiter` is only accessed through
// `take` and `register` - confirm against the rest of the crate.
let waker = unsafe { (*self.awaiter.get()).take() };
// Unset the bit indicating that the task is notifying its awaiter.
// `AWAITER` is cleared as well because the slot is now empty.
self.state
.fetch_and(!NOTIFYING & !AWAITER, Ordering::Release);
// Finally, notify the waker if it's different from the current waker.
if let Some(w) = waker {
match current {
// No current waker to compare against: always hand the awaiter out.
None => return Some(w),
// The awaiter would wake a different task than `current`: hand it out.
Some(c) if !w.will_wake(c) => return Some(w),
// The awaiter is the current waker: drop it instead of returning it. Dropping
// runs waker code, hence the `abort_on_panic` wrapper.
Some(_) => abort_on_panic(|| drop(w)),
}
}
}
None
}
/// Registers a new awaiter blocked on this task.
///
/// This method is called when `Task` is polled and it has not yet completed.
///
/// A clone of `waker` is stored in the awaiter slot and the `AWAITER` flag is set. If a
/// notification is in progress when registration starts, nothing is stored and `waker` is woken
/// by reference instead. If a notification arrives while the waker is being stored, the stored
/// clone is taken back out and woken before returning, so the notification is not lost.
#[inline]
pub(crate) fn register(&self, waker: &Waker) {
// Load the state and synchronize with it.
// `fetch_or(0, ..)` is a read-modify-write that leaves the value unchanged.
let mut state = self.state.fetch_or(0, Ordering::Acquire);
// Phase 1: acquire the `REGISTERING` bit, or bail out if a notification is in progress.
loop {
// There can't be two concurrent registrations because `Task` can only be polled
// by a unique pinned reference.
debug_assert!(state & REGISTERING == 0);
// If we're in the notifying state at this moment, just wake and return without
// registering.
if state & NOTIFYING != 0 {
abort_on_panic(|| waker.wake_by_ref());
return;
}
// Mark the state to let other threads know we're registering a new awaiter.
// The exchange only succeeds if the state is still the one inspected above, i.e. with
// `NOTIFYING` unset.
match self.state.compare_exchange_weak(
state,
state | REGISTERING,
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(_) => {
// Keep the local copy in sync with what was just stored.
state |= REGISTERING;
break;
}
// Lost a race (or failed spuriously): retry with the freshly observed state.
Err(s) => state = s,
}
}
// Put the waker into the awaiter field.
// SAFETY (review note): `REGISTERING` is held, so `take` will not touch the slot, and the
// comment above rules out a second concurrent `register`. Assumes no other code accesses
// `awaiter` directly - confirm against the rest of the crate.
unsafe {
abort_on_panic(|| (*self.awaiter.get()) = Some(waker.clone()));
}
// This variable will contain the newly registered waker if a notification comes in before
// we complete registration.
let mut waker = None;
// Phase 2: release `REGISTERING` and publish the final state.
loop {
// If there was a notification, take the waker out of the awaiter field.
// On the first pass `NOTIFYING` is unset (phase 1 guarantees it); it can only show up
// here after a failed exchange below reveals a concurrent `take`.
if state & NOTIFYING != 0 {
if let Some(w) = unsafe { (*self.awaiter.get()).take() } {
abort_on_panic(|| waker = Some(w));
}
}
// The new state is not being notified nor registered, but there might or might not be
// an awaiter depending on whether there was a concurrent notification.
let new = if waker.is_none() {
// The waker stays in the slot: advertise it with `AWAITER`.
(state & !NOTIFYING & !REGISTERING) | AWAITER
} else {
// The waker was taken back out: the slot is empty, so clear `AWAITER` too.
state & !NOTIFYING & !REGISTERING & !AWAITER
};
match self
.state
.compare_exchange_weak(state, new, Ordering::AcqRel, Ordering::Acquire)
{
Ok(_) => break,
// The state changed underneath us: re-evaluate with the observed value.
Err(s) => state = s,
}
}
// If there was a notification during registration, wake the awaiter now.
if let Some(w) = waker {
abort_on_panic(|| w.wake());
}
}
}
impl fmt::Debug for Header {
    /// Formats a snapshot of the task state.
    ///
    /// The output is a `Header { .. }` struct with one boolean per reported state flag
    /// followed by the reference count extracted from the same state word.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Read the state word once so every field is derived from the same snapshot.
        let snapshot = self.state.load(Ordering::SeqCst);
        // Reports whether the given flag bit is set in the snapshot.
        let is_set = |flag: usize| snapshot & flag != 0;

        let mut out = f.debug_struct("Header");
        out.field("scheduled", &is_set(SCHEDULED));
        out.field("running", &is_set(RUNNING));
        out.field("completed", &is_set(COMPLETED));
        out.field("closed", &is_set(CLOSED));
        out.field("awaiter", &is_set(AWAITER));
        out.field("task", &is_set(TASK));
        // The reference count is stored in units of `REFERENCE`.
        out.field("ref_count", &(snapshot / REFERENCE));
        out.finish()
    }
}