blob: c01078cabd5efc52e4a95fbf1a27528fa48369fe [file] [log] [blame]
// There's a lot of scary concurrent code in this module, but it is copied from
// `std::sync::Once` with two changes:
// * no poisoning
// * init function can fail
use std::{
cell::{Cell, UnsafeCell},
hint::unreachable_unchecked,
marker::PhantomData,
panic::{RefUnwindSafe, UnwindSafe},
sync::atomic::{AtomicBool, AtomicUsize, Ordering},
thread::{self, Thread},
};
use crate::take_unchecked;
/// `std`-backed cell whose value is written through a shared reference only
/// by a successful initializer.
///
/// All synchronization goes through `queue`; `value` holds the payload.
#[derive(Debug)]
pub(crate) struct OnceCell<T> {
    // This `queue` field is the core of the implementation. It encodes two
    // pieces of information:
    //
    // * The current state of the cell (`INCOMPLETE`, `RUNNING`, `COMPLETE`)
    // * Linked list of threads waiting for the current cell.
    //
    // State is encoded in two low bits. Only `INCOMPLETE` and `RUNNING` states
    // allow waiters.
    queue: AtomicUsize,
    // Zero-sized marker recording that `queue` stores a `Waiter` address.
    // Being a raw-pointer marker, it also disables the automatic `Send` and
    // `Sync` impls, which are therefore written out by hand for this type.
    _marker: PhantomData<*mut Waiter>,
    // `None` until an initializer succeeds (or `Some` from the start when the
    // cell is built with `with_value`).
    value: UnsafeCell<Option<T>>,
}
// `OnceCell` contains an `UnsafeCell` and a raw-pointer `PhantomData`, so none
// of the auto traits below are derived automatically; each one is opted into
// explicitly with the bounds on `T` it requires.
//
// Why do we need `T: Send`?
// Thread A creates a `OnceCell` and shares it with
// scoped thread B, which fills the cell, which is
// then destroyed by A. That is, destructor observes
// a sent value.
unsafe impl<T: Sync + Send> Sync for OnceCell<T> {}
unsafe impl<T: Send> Send for OnceCell<T> {}
// `UnsafeCell` is not `RefUnwindSafe`, so unwind safety is declared by hand,
// conditional on `T`.
impl<T: RefUnwindSafe + UnwindSafe> RefUnwindSafe for OnceCell<T> {}
impl<T: UnwindSafe> UnwindSafe for OnceCell<T> {}
impl<T> OnceCell<T> {
    /// Creates an empty cell in the `INCOMPLETE` state.
    pub(crate) const fn new() -> OnceCell<T> {
        Self {
            value: UnsafeCell::new(None),
            queue: AtomicUsize::new(INCOMPLETE),
            _marker: PhantomData,
        }
    }

    /// Creates a cell that already holds `value` and starts out `COMPLETE`.
    pub(crate) const fn with_value(value: T) -> OnceCell<T> {
        Self {
            value: UnsafeCell::new(Some(value)),
            queue: AtomicUsize::new(COMPLETE),
            _marker: PhantomData,
        }
    }

    /// Returns `true` once the cell has reached the `COMPLETE` state.
    ///
    /// Safety: the `Acquire` load pairs with the release half of the `AcqRel`
    /// swap performed when the initializing thread's `Guard` is dropped, so a
    /// `true` result makes the stored value visible to the caller.
    #[inline]
    pub(crate) fn is_initialized(&self) -> bool {
        // Fast path: `Acquire` is the weakest ordering that still publishes
        // the initializer's writes to this thread.
        let observed = self.queue.load(Ordering::Acquire);
        observed == COMPLETE
    }

    /// Runs `f` to fill the cell unless it is already filled, blocking while
    /// another thread's initializer is in flight.
    ///
    /// Returns the error produced by `f`, if `f` ran on this thread and
    /// failed; the cell then goes back to `INCOMPLETE` so a later call can
    /// retry. Returns `Ok(())` in every other case.
    ///
    /// Safety: the value is written only while this thread holds the `RUNNING`
    /// state, and the cell never returns to `INCOMPLETE` after a successful
    /// write, so the slot is written at most once.
    #[cold]
    pub(crate) fn initialize<F, E>(&self, f: F) -> Result<(), E>
    where
        F: FnOnce() -> Result<T, E>,
    {
        let mut pending = Some(f);
        let mut outcome: Result<(), E> = Ok(());
        let slot: *mut Option<T> = self.value.get();

        // Adapter handed to the monomorphic slow path: it reports success as a
        // plain `bool` and smuggles the error out through `outcome`.
        let mut attempt = || {
            // SAFETY: `initialize_or_wait` invokes the closure at most once,
            // so `pending` is still `Some` here.
            let init = unsafe { take_unchecked(&mut pending) };
            match init() {
                Err(err) => {
                    outcome = Err(err);
                    false
                }
                Ok(value) => {
                    // SAFETY: only the thread that moved the cell to `RUNNING`
                    // executes this closure, so nobody else touches the slot.
                    unsafe { *slot = Some(value) };
                    true
                }
            }
        };
        initialize_or_wait(&self.queue, Some(&mut attempt));

        outcome
    }

    /// Blocks the calling thread until the cell becomes `COMPLETE`, without
    /// offering an initializer of its own.
    #[cold]
    pub(crate) fn wait(&self) {
        initialize_or_wait(&self.queue, None)
    }

    /// Get the reference to the underlying value, without checking if the cell
    /// is initialized.
    ///
    /// # Safety
    ///
    /// Caller must ensure that the cell is in initialized state, and that
    /// the contents are acquired by (synchronized to) this thread.
    pub(crate) unsafe fn get_unchecked(&self) -> &T {
        debug_assert!(self.is_initialized());
        let slot: &Option<T> = &*self.value.get();
        if let Some(value) = slot {
            return value;
        }
        // Skipping the `None` check in release builds does improve
        // performance, see `examples/bench`.
        debug_assert!(false);
        unreachable_unchecked()
    }

    /// Gets the mutable reference to the underlying value.
    /// Returns `None` if the cell is empty.
    pub(crate) fn get_mut(&mut self) -> Option<&mut T> {
        // SAFETY: `&mut self` guarantees exclusive access to the slot.
        let slot: &mut Option<T> = unsafe { &mut *self.value.get() };
        slot.as_mut()
    }

    /// Consumes this `OnceCell`, returning the wrapped value.
    /// Returns `None` if the cell was empty.
    #[inline]
    pub(crate) fn into_inner(self) -> Option<T> {
        // Taking `self` by value proves there are no outstanding borrows, so
        // the `Option<T>` can simply be moved out.
        let slot = self.value;
        slot.into_inner()
    }
}
// Three states that a OnceCell can be in, encoded into the lower bits of `queue` in
// the OnceCell structure.
//
// `INCOMPLETE`: no value stored and no initializer currently running.
const INCOMPLETE: usize = 0x0;
// `RUNNING`: one thread has claimed the cell and is executing its initializer.
const RUNNING: usize = 0x1;
// `COMPLETE`: a value has been stored; `initialize_or_wait` returns immediately.
const COMPLETE: usize = 0x2;
// Mask to learn about the state. All other bits hold the address of the first
// queued `Waiter`, which may be non-null in the `INCOMPLETE` and `RUNNING`
// states.
const STATE_MASK: usize = 0x3;
/// Representation of a node in the linked list of waiters attached to a cell
/// in the `INCOMPLETE` or `RUNNING` state.
/// A waiter is stored on the stack of the waiting thread (it is a local
/// variable of `wait`).
#[repr(align(4))] // Ensure the two lower bits are free to use as state bits.
struct Waiter {
    // Handle of the parked thread; `Guard::drop` takes it out to unpark it.
    thread: Cell<Option<Thread>>,
    // Set to `true` by `Guard::drop`; the waiter keeps parking until it sees it.
    signaled: AtomicBool,
    // Node that was at the head of the list when this one was pushed, or null.
    next: *const Waiter,
}
/// Drains and notifies the queue of waiters on drop.
///
/// Built by the thread that moved the cell to `RUNNING`. Doing the work in
/// `Drop` means the waiters are released on every exit path of the
/// initializer, including unwinding from a panic.
struct Guard<'a> {
    // State word of the cell being initialized.
    queue: &'a AtomicUsize,
    // State stored into `queue` on drop: starts as `INCOMPLETE` and is switched
    // to `COMPLETE` only after the initializer reports success.
    new_queue: usize,
}
impl Drop for Guard<'_> {
    /// Publishes `new_queue` as the cell's state and wakes every queued waiter.
    ///
    /// # Panics
    ///
    /// Panics if the state being replaced is not `RUNNING`.
    fn drop(&mut self) {
        // `AcqRel`: the release half publishes the initializer's writes along
        // with the new state; the acquire half pairs with the `Release`
        // `compare_exchange` in `wait`, so the queued nodes can be read below.
        let queue = self.queue.swap(self.new_queue, Ordering::AcqRel);
        assert_eq!(queue & STATE_MASK, RUNNING);
        unsafe {
            // Everything above the two state bits is the address of the list head.
            let mut waiter = (queue & !STATE_MASK) as *const Waiter;
            while !waiter.is_null() {
                // Order matters: read `next` and take the thread handle
                // *before* setting `signaled`. As soon as the flag is visible
                // the waiting thread may return from `wait`, which destroys
                // its stack-allocated node, so the node must not be touched
                // after the store.
                let next = (*waiter).next;
                let thread = (*waiter).thread.take().unwrap();
                (*waiter).signaled.store(true, Ordering::Release);
                waiter = next;
                thread.unpark();
            }
        }
    }
}
// Corresponds to `std::sync::Once::call_inner`.
//
// Originally copied from std, but since modified to remove poisoning and to
// support wait.
//
// Note: this is intentionally monomorphic
#[inline(never)]
fn initialize_or_wait(queue: &AtomicUsize, mut init: Option<&mut dyn FnMut() -> bool>) {
let mut curr_queue = queue.load(Ordering::Acquire);
loop {
let curr_state = curr_queue & STATE_MASK;
match (curr_state, &mut init) {
(COMPLETE, _) => return,
(INCOMPLETE, Some(init)) => {
let exchange = queue.compare_exchange(
curr_queue,
(curr_queue & !STATE_MASK) | RUNNING,
Ordering::Acquire,
Ordering::Acquire,
);
if let Err(new_queue) = exchange {
curr_queue = new_queue;
continue;
}
let mut guard = Guard { queue, new_queue: INCOMPLETE };
if init() {
guard.new_queue = COMPLETE;
}
return;
}
(INCOMPLETE, None) | (RUNNING, _) => {
wait(&queue, curr_queue);
curr_queue = queue.load(Ordering::Acquire);
}
_ => debug_assert!(false),
}
}
}
// Pushes the current thread onto the waiter list stored in `queue` and parks
// until `Guard::drop` signals it.
//
// `curr_queue` is the caller's last observed value of `queue`. If the state
// bits change before the node could be enqueued, the function returns without
// waiting so the caller can re-evaluate the new state.
fn wait(queue: &AtomicUsize, mut curr_queue: usize) {
    let entry_state = curr_queue & STATE_MASK;

    loop {
        // The node lives on this stack frame; it stays valid because we do not
        // leave the frame until `signaled` has been observed.
        let node = Waiter {
            thread: Cell::new(Some(thread::current())),
            signaled: AtomicBool::new(false),
            next: (curr_queue & !STATE_MASK) as *const Waiter,
        };
        let node_addr = &node as *const Waiter as usize;

        // `Release` on success publishes the node's fields to the thread that
        // later drains the list.
        let pushed = queue.compare_exchange(
            curr_queue,
            node_addr | entry_state,
            Ordering::Release,
            Ordering::Relaxed,
        );

        match pushed {
            Ok(_) => {
                // `park` may wake spuriously, hence the flag check in a loop.
                while !node.signaled.load(Ordering::Acquire) {
                    thread::park();
                }
                return;
            }
            // The state moved on under us: let the caller decide what to do.
            Err(actual) if actual & STATE_MASK != entry_state => return,
            // Only the waiter list changed: retry on top of the new head.
            Err(actual) => curr_queue = actual,
        }
    }
}
// These tests are snatched from std as well. Several comments still talk about
// "poisoning" because they come from `std::sync::Once`; this cell has no
// poisoning, so a panicking initializer simply leaves it uninitialized.
#[cfg(test)]
mod tests {
    use std::panic;
    use std::{sync::mpsc::channel, thread};
    use super::OnceCell;
    impl<T> OnceCell<T> {
        // Test-only helper: adapts an infallible initializer to `initialize`
        // by using an uninhabited error type.
        fn init(&self, f: impl FnOnce() -> T) {
            enum Void {}
            let _ = self.initialize(|| Ok::<T, Void>(f()));
        }
    }
    // The initializer runs on the first call only.
    #[test]
    fn smoke_once() {
        static O: OnceCell<()> = OnceCell::new();
        let mut a = 0;
        O.init(|| a += 1);
        assert_eq!(a, 1);
        O.init(|| a += 1);
        assert_eq!(a, 1);
    }
    // Eleven threads race to initialize the same cell; exactly one initializer
    // may run, and every thread must observe its effect afterwards.
    // Not compiled under Miri.
    #[test]
    #[cfg(not(miri))]
    fn stampede_once() {
        static O: OnceCell<()> = OnceCell::new();
        static mut RUN: bool = false;
        let (tx, rx) = channel();
        for _ in 0..10 {
            let tx = tx.clone();
            thread::spawn(move || {
                for _ in 0..4 {
                    thread::yield_now()
                }
                unsafe {
                    O.init(|| {
                        assert!(!RUN);
                        RUN = true;
                    });
                    assert!(RUN);
                }
                tx.send(()).unwrap();
            });
        }
        unsafe {
            O.init(|| {
                assert!(!RUN);
                RUN = true;
            });
            assert!(RUN);
        }
        for _ in 0..10 {
            rx.recv().unwrap();
        }
    }
    // A panicking initializer must leave the cell usable.
    #[test]
    fn poison_bad() {
        static O: OnceCell<()> = OnceCell::new();
        // panic inside the initializer (this would poison a `std::sync::Once`)
        let t = panic::catch_unwind(|| {
            O.init(|| panic!());
        });
        assert!(t.is_err());
        // the cell is still uninitialized, so the next initializer runs
        let mut called = false;
        O.init(|| {
            called = true;
        });
        assert!(called);
        // the cell is now complete; further calls return without running anything
        O.init(|| {});
    }
    // A second thread that arrives while an initializer is running must not
    // run its own initializer.
    #[test]
    fn wait_for_force_to_finish() {
        static O: OnceCell<()> = OnceCell::new();
        // start from a cell whose first initializer panicked
        let t = panic::catch_unwind(|| {
            O.init(|| panic!());
        });
        assert!(t.is_err());
        // make sure a thread is inside the initializer: `t1` announces itself
        // on `tx1` and then blocks on `rx2` until the main thread releases it
        let (tx1, rx1) = channel();
        let (tx2, rx2) = channel();
        let t1 = thread::spawn(move || {
            O.init(|| {
                tx1.send(()).unwrap();
                rx2.recv().unwrap();
            });
        });
        rx1.recv().unwrap();
        // start a second thread while `t1` is still inside its initializer;
        // its own initializer must never be called
        let t2 = thread::spawn(|| {
            let mut called = false;
            O.init(|| {
                called = true;
            });
            assert!(!called);
        });
        tx2.send(()).unwrap();
        assert!(t1.join().is_ok());
        assert!(t2.join().is_ok());
    }
    // Pins the size of the cell on 64-bit targets: `OnceCell<u32>` must occupy
    // exactly four `u32`s (16 bytes).
    #[test]
    #[cfg(target_pointer_width = "64")]
    fn test_size() {
        use std::mem::size_of;
        assert_eq!(size_of::<OnceCell<u32>>(), 4 * size_of::<u32>());
    }
}