// There's a lot of scary concurrent code in this module, but it is copied from
// `std::sync::Once` with three changes:
// * no poisoning
// * init function can fail
// * threads can wait for initialization without running an initializer
| |
| use std::{ |
| cell::{Cell, UnsafeCell}, |
| hint::unreachable_unchecked, |
| marker::PhantomData, |
| panic::{RefUnwindSafe, UnwindSafe}, |
| sync::atomic::{AtomicBool, AtomicUsize, Ordering}, |
| thread::{self, Thread}, |
| }; |
| |
| use crate::take_unchecked; |
| |
| #[derive(Debug)] |
| pub(crate) struct OnceCell<T> { |
| // This `queue` field is the core of the implementation. It encodes two |
| // pieces of information: |
| // |
| // * The current state of the cell (`INCOMPLETE`, `RUNNING`, `COMPLETE`) |
| // * Linked list of threads waiting for the current cell. |
| // |
| // State is encoded in two low bits. Only `INCOMPLETE` and `RUNNING` states |
| // allow waiters. |
| queue: AtomicUsize, |
| _marker: PhantomData<*mut Waiter>, |
| value: UnsafeCell<Option<T>>, |
| } |
| |
| // Why do we need `T: Send`? |
| // Thread A creates a `OnceCell` and shares it with |
| // scoped thread B, which fills the cell, which is |
| // then destroyed by A. That is, destructor observes |
| // a sent value. |
| unsafe impl<T: Sync + Send> Sync for OnceCell<T> {} |
| unsafe impl<T: Send> Send for OnceCell<T> {} |
| |
| impl<T: RefUnwindSafe + UnwindSafe> RefUnwindSafe for OnceCell<T> {} |
| impl<T: UnwindSafe> UnwindSafe for OnceCell<T> {} |
| |
| impl<T> OnceCell<T> { |
| pub(crate) const fn new() -> OnceCell<T> { |
| OnceCell { |
| queue: AtomicUsize::new(INCOMPLETE), |
| _marker: PhantomData, |
| value: UnsafeCell::new(None), |
| } |
| } |
| |
| pub(crate) const fn with_value(value: T) -> OnceCell<T> { |
| OnceCell { |
| queue: AtomicUsize::new(COMPLETE), |
| _marker: PhantomData, |
| value: UnsafeCell::new(Some(value)), |
| } |
| } |
| |
| /// Safety: synchronizes with store to value via Release/(Acquire|SeqCst). |
| #[inline] |
| pub(crate) fn is_initialized(&self) -> bool { |
| // An `Acquire` load is enough because that makes all the initialization |
| // operations visible to us, and, this being a fast path, weaker |
| // ordering helps with performance. This `Acquire` synchronizes with |
| // `SeqCst` operations on the slow path. |
| self.queue.load(Ordering::Acquire) == COMPLETE |
| } |
| |
| /// Safety: synchronizes with store to value via SeqCst read from state, |
| /// writes value only once because we never get to INCOMPLETE state after a |
| /// successful write. |
| #[cold] |
| pub(crate) fn initialize<F, E>(&self, f: F) -> Result<(), E> |
| where |
| F: FnOnce() -> Result<T, E>, |
| { |
| let mut f = Some(f); |
| let mut res: Result<(), E> = Ok(()); |
| let slot: *mut Option<T> = self.value.get(); |
| initialize_or_wait( |
| &self.queue, |
| Some(&mut || { |
| let f = unsafe { take_unchecked(&mut f) }; |
| match f() { |
| Ok(value) => { |
| unsafe { *slot = Some(value) }; |
| true |
| } |
| Err(err) => { |
| res = Err(err); |
| false |
| } |
| } |
| }), |
| ); |
| res |
| } |
| |
| #[cold] |
| pub(crate) fn wait(&self) { |
| initialize_or_wait(&self.queue, None); |
| } |
| |
| /// Get the reference to the underlying value, without checking if the cell |
| /// is initialized. |
| /// |
| /// # Safety |
| /// |
| /// Caller must ensure that the cell is in initialized state, and that |
| /// the contents are acquired by (synchronized to) this thread. |
| pub(crate) unsafe fn get_unchecked(&self) -> &T { |
| debug_assert!(self.is_initialized()); |
| let slot: &Option<T> = &*self.value.get(); |
| match slot { |
| Some(value) => value, |
| // This unsafe does improve performance, see `examples/bench`. |
| None => { |
| debug_assert!(false); |
| unreachable_unchecked() |
| } |
| } |
| } |
| |
| /// Gets the mutable reference to the underlying value. |
| /// Returns `None` if the cell is empty. |
| pub(crate) fn get_mut(&mut self) -> Option<&mut T> { |
| // Safe b/c we have a unique access. |
| unsafe { &mut *self.value.get() }.as_mut() |
| } |
| |
| /// Consumes this `OnceCell`, returning the wrapped value. |
| /// Returns `None` if the cell was empty. |
| #[inline] |
| pub(crate) fn into_inner(self) -> Option<T> { |
| // Because `into_inner` takes `self` by value, the compiler statically |
| // verifies that it is not currently borrowed. |
| // So, it is safe to move out `Option<T>`. |
| self.value.into_inner() |
| } |
| } |
| |
// The state of a `OnceCell` is stored in the two low bits of its `queue` word.
const INCOMPLETE: usize = 0b00;
const RUNNING: usize = 0b01;
const COMPLETE: usize = 0b10;

// Selects the state bits of `queue`. The remaining bits are the pointer to the
// head of the waiter list. That pointer may be non-null while the state is
// `INCOMPLETE` or `RUNNING`.
const STATE_MASK: usize = RUNNING | COMPLETE;
| |
/// One node in the linked list of threads waiting on a `OnceCell`.
///
/// A waiter is stored on the stack of the thread it describes. `next` points
/// to the node that was at the head of the list when this one was pushed.
// Forcing 4-byte alignment keeps the two low bits of every `*const Waiter`
// zero, so they can carry the cell state.
#[repr(align(4))]
struct Waiter {
    /// Handle used to unpark the waiting thread; taken by the waking thread.
    thread: Cell<Option<Thread>>,
    /// Becomes `true` once the waking thread no longer touches this node.
    signaled: AtomicBool,
    /// The next waiter in the list, or null at the tail.
    next: *const Waiter,
}
| |
/// When dropped, stores `new_queue` into `queue` and wakes every waiter that
/// was queued in the value it replaced.
struct Guard<'q> {
    /// The cell's combined state-and-waiters word.
    queue: &'q AtomicUsize,
    /// The word to publish when the guard is dropped.
    new_queue: usize,
}
| |
impl Drop for Guard<'_> {
    /// Publishes `new_queue` and wakes all threads that queued up while the
    /// initializer was running.
    ///
    /// The guard is created before the init closure is called. This therefore
    /// also runs during unwinding if the closure panics, so waiters are never
    /// left parked.
    fn drop(&mut self) {
        // `AcqRel`:
        // - the release half publishes the value written by the initializer to
        //   threads that later `Acquire`-load the new state;
        // - the acquire half makes the waiter nodes, pushed with `Release` in
        //   `wait`, visible before we dereference them below.
        let queue = self.queue.swap(self.new_queue, Ordering::AcqRel);

        // A guard is only created by the thread that moved the state to
        // `RUNNING`.
        assert_eq!(queue & STATE_MASK, RUNNING);

        unsafe {
            let mut waiter = (queue & !STATE_MASK) as *const Waiter;
            while !waiter.is_null() {
                // Copy out everything needed from the node *before* setting
                // `signaled`. Once the waiting thread observes `true` it may
                // return, and its stack-allocated node goes away.
                let next = (*waiter).next;
                let thread = (*waiter).thread.take().unwrap();
                (*waiter).signaled.store(true, Ordering::Release);
                waiter = next;
                thread.unpark();
            }
        }
    }
}
| |
| // Corresponds to `std::sync::Once::call_inner`. |
| // |
| // Originally copied from std, but since modified to remove poisoning and to |
| // support wait. |
| // |
| // Note: this is intentionally monomorphic |
| #[inline(never)] |
| fn initialize_or_wait(queue: &AtomicUsize, mut init: Option<&mut dyn FnMut() -> bool>) { |
| let mut curr_queue = queue.load(Ordering::Acquire); |
| |
| loop { |
| let curr_state = curr_queue & STATE_MASK; |
| match (curr_state, &mut init) { |
| (COMPLETE, _) => return, |
| (INCOMPLETE, Some(init)) => { |
| let exchange = queue.compare_exchange( |
| curr_queue, |
| (curr_queue & !STATE_MASK) | RUNNING, |
| Ordering::Acquire, |
| Ordering::Acquire, |
| ); |
| if let Err(new_queue) = exchange { |
| curr_queue = new_queue; |
| continue; |
| } |
| let mut guard = Guard { queue, new_queue: INCOMPLETE }; |
| if init() { |
| guard.new_queue = COMPLETE; |
| } |
| return; |
| } |
| (INCOMPLETE, None) | (RUNNING, _) => { |
| wait(&queue, curr_queue); |
| curr_queue = queue.load(Ordering::Acquire); |
| } |
| _ => debug_assert!(false), |
| } |
| } |
| } |
| |
/// Pushes the current thread onto the waiter list encoded in `queue`, then
/// parks until a `Guard` signals it.
///
/// `curr_queue` is the most recently observed value of `queue`. If the state
/// bits change before this thread manages to enqueue itself, this returns
/// early without parking. The caller re-reads `queue` and decides again.
fn wait(queue: &AtomicUsize, mut curr_queue: usize) {
    let curr_state = curr_queue & STATE_MASK;
    loop {
        // The node lives on this thread's stack. It must not be dropped until
        // `signaled` is observed as `true`, because the waking thread
        // dereferences it until then.
        let node = Waiter {
            thread: Cell::new(Some(thread::current())),
            signaled: AtomicBool::new(false),
            next: (curr_queue & !STATE_MASK) as *const Waiter,
        };
        let me = &node as *const Waiter as usize;

        // `Release` publishes the node's contents to the thread that later
        // swaps the list out of `queue` (with `AcqRel`) in `Guard::drop`.
        let exchange = queue.compare_exchange(
            curr_queue,
            me | curr_state,
            Ordering::Release,
            Ordering::Relaxed,
        );
        if let Err(new_queue) = exchange {
            // The state itself changed (e.g. initialization finished), so
            // waiting here is no longer appropriate.
            if new_queue & STATE_MASK != curr_state {
                return;
            }
            // Same state but a different list head (e.g. another waiter was
            // pushed first): retry with a fresh node.
            curr_queue = new_queue;
            continue;
        }

        // `park` may return spuriously, so loop until actually signaled.
        while !node.signaled.load(Ordering::Acquire) {
            thread::park();
        }
        break;
    }
}
| |
// Most of these tests are adapted from std's `Once` tests.
#[cfg(test)]
mod tests {
    use std::{
        panic,
        sync::{
            atomic::{AtomicBool, Ordering},
            mpsc::channel,
        },
        thread,
    };

    use super::OnceCell;

    impl<T> OnceCell<T> {
        /// Test helper: infallible initialization that ignores the result.
        fn init(&self, f: impl FnOnce() -> T) {
            enum Void {}
            let _ = self.initialize(|| Ok::<T, Void>(f()));
        }
    }

    #[test]
    fn smoke_once() {
        static O: OnceCell<()> = OnceCell::new();
        let mut a = 0;
        O.init(|| a += 1);
        assert_eq!(a, 1);
        O.init(|| a += 1);
        assert_eq!(a, 1);
    }

    #[test]
    #[cfg(not(miri))]
    fn stampede_once() {
        static O: OnceCell<()> = OnceCell::new();
        // An atomic instead of `static mut`: no `unsafe` is needed to observe
        // that the initializer ran exactly once.
        static RUN: AtomicBool = AtomicBool::new(false);

        let (tx, rx) = channel();
        for _ in 0..10 {
            let tx = tx.clone();
            thread::spawn(move || {
                for _ in 0..4 {
                    thread::yield_now()
                }
                O.init(|| assert!(!RUN.swap(true, Ordering::SeqCst)));
                assert!(RUN.load(Ordering::SeqCst));
                tx.send(()).unwrap();
            });
        }

        O.init(|| assert!(!RUN.swap(true, Ordering::SeqCst)));
        assert!(RUN.load(Ordering::SeqCst));

        for _ in 0..10 {
            rx.recv().unwrap();
        }
    }

    #[test]
    fn poison_bad() {
        static O: OnceCell<()> = OnceCell::new();

        // Panic inside the initializer (this would poison a std `Once`).
        let t = panic::catch_unwind(|| {
            O.init(|| panic!());
        });
        assert!(t.is_err());

        // There is no poisoning: the next initializer runs normally.
        let mut called = false;
        O.init(|| {
            called = true;
        });
        assert!(called);

        // Once initialization has succeeded, further calls are no-ops.
        O.init(|| {});
    }

    #[test]
    fn wait_for_force_to_finish() {
        static O: OnceCell<()> = OnceCell::new();

        // Panic inside the initializer first.
        let t = panic::catch_unwind(|| {
            O.init(|| panic!());
        });
        assert!(t.is_err());

        // t1 enters the initializer and blocks there until we send on tx2.
        let (tx1, rx1) = channel();
        let (tx2, rx2) = channel();
        let t1 = thread::spawn(move || {
            O.init(|| {
                tx1.send(()).unwrap();
                rx2.recv().unwrap();
            });
        });

        rx1.recv().unwrap();

        // Put another waiter on the cell; it must not run its own initializer.
        let t2 = thread::spawn(|| {
            let mut called = false;
            O.init(|| {
                called = true;
            });
            assert!(!called);
        });

        tx2.send(()).unwrap();

        assert!(t1.join().is_ok());
        assert!(t2.join().is_ok());
    }

    #[test]
    fn failed_init_can_be_retried() {
        let mut cell: OnceCell<u32> = OnceCell::new();

        let res: Result<(), &str> = cell.initialize(|| Err("nope"));
        assert_eq!(res, Err("nope"));
        assert!(!cell.is_initialized());

        assert_eq!(cell.initialize(|| Ok::<u32, ()>(92)), Ok(()));
        assert!(cell.is_initialized());

        // A completed cell never runs another initializer.
        assert_eq!(cell.initialize(|| -> Result<u32, ()> { panic!() }), Ok(()));
        // Waiting on a completed cell returns immediately.
        cell.wait();

        assert_eq!(cell.get_mut(), Some(&mut 92));
        assert_eq!(cell.into_inner(), Some(92));
    }

    #[test]
    fn with_value_is_initialized() {
        let cell = OnceCell::with_value("hello");
        assert!(cell.is_initialized());
        cell.wait();
        // SAFETY: the cell was created initialized on this very thread.
        assert_eq!(unsafe { *cell.get_unchecked() }, "hello");
    }

    #[test]
    #[cfg(target_pointer_width = "64")]
    fn test_size() {
        use std::mem::size_of;

        assert_eq!(size_of::<OnceCell<u32>>(), 4 * size_of::<u32>());
    }
}