| use crate::task::{waker_ref, ArcWake}; |
| use futures_core::future::{FusedFuture, Future}; |
| use futures_core::task::{Context, Poll, Waker}; |
| use slab::Slab; |
| use std::cell::UnsafeCell; |
| use std::fmt; |
| use std::pin::Pin; |
| use std::sync::atomic::AtomicUsize; |
| use std::sync::atomic::Ordering::{Acquire, SeqCst}; |
| use std::sync::{Arc, Mutex}; |
| |
| /// Future for the [`shared`](super::FutureExt::shared) method. |
| #[must_use = "futures do nothing unless you `.await` or poll them"] |
| pub struct Shared<Fut: Future> { |
| inner: Option<Arc<Inner<Fut>>>, |
| waker_key: usize, |
| } |
| |
| struct Inner<Fut: Future> { |
| future_or_output: UnsafeCell<FutureOrOutput<Fut>>, |
| notifier: Arc<Notifier>, |
| } |
| |
| struct Notifier { |
| state: AtomicUsize, |
| wakers: Mutex<Option<Slab<Option<Waker>>>>, |
| } |
| |
| // The future itself is polled behind the `Arc`, so it won't be moved |
| // when `Shared` is moved. |
| impl<Fut: Future> Unpin for Shared<Fut> {} |
| |
| impl<Fut: Future> fmt::Debug for Shared<Fut> { |
| fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| f.debug_struct("Shared") |
| .field("inner", &self.inner) |
| .field("waker_key", &self.waker_key) |
| .finish() |
| } |
| } |
| |
| impl<Fut: Future> fmt::Debug for Inner<Fut> { |
| fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| f.debug_struct("Inner").finish() |
| } |
| } |
| |
| enum FutureOrOutput<Fut: Future> { |
| Future(Fut), |
| Output(Fut::Output), |
| } |
| |
| unsafe impl<Fut> Send for Inner<Fut> |
| where |
| Fut: Future + Send, |
| Fut::Output: Send + Sync, |
| { |
| } |
| |
| unsafe impl<Fut> Sync for Inner<Fut> |
| where |
| Fut: Future + Send, |
| Fut::Output: Send + Sync, |
| { |
| } |
| |
| const IDLE: usize = 0; |
| const POLLING: usize = 1; |
| const COMPLETE: usize = 2; |
| const POISONED: usize = 3; |
| |
| const NULL_WAKER_KEY: usize = usize::max_value(); |
| |
| impl<Fut: Future> Shared<Fut> { |
| pub(super) fn new(future: Fut) -> Shared<Fut> { |
| let inner = Inner { |
| future_or_output: UnsafeCell::new(FutureOrOutput::Future(future)), |
| notifier: Arc::new(Notifier { |
| state: AtomicUsize::new(IDLE), |
| wakers: Mutex::new(Some(Slab::new())), |
| }), |
| }; |
| |
| Shared { |
| inner: Some(Arc::new(inner)), |
| waker_key: NULL_WAKER_KEY, |
| } |
| } |
| } |
| |
| impl<Fut> Shared<Fut> |
| where |
| Fut: Future, |
| Fut::Output: Clone, |
| { |
| /// Returns [`Some`] containing a reference to this [`Shared`]'s output if |
| /// it has already been computed by a clone or [`None`] if it hasn't been |
| /// computed yet or this [`Shared`] already returned its output from |
| /// [`poll`](Future::poll). |
| pub fn peek(&self) -> Option<&Fut::Output> { |
| if let Some(inner) = self.inner.as_ref() { |
| match inner.notifier.state.load(SeqCst) { |
| COMPLETE => unsafe { return Some(inner.output()) }, |
| POISONED => panic!("inner future panicked during poll"), |
| _ => {} |
| } |
| } |
| None |
| } |
| } |
| |
| impl<Fut> Inner<Fut> |
| where |
| Fut: Future, |
| Fut::Output: Clone, |
| { |
| /// Safety: callers must first ensure that `self.inner.state` |
| /// is `COMPLETE` |
| unsafe fn output(&self) -> &Fut::Output { |
| match &*self.future_or_output.get() { |
| FutureOrOutput::Output(ref item) => &item, |
| FutureOrOutput::Future(_) => unreachable!(), |
| } |
| } |
| /// Registers the current task to receive a wakeup when we are awoken. |
| fn record_waker(&self, waker_key: &mut usize, cx: &mut Context<'_>) { |
| let mut wakers_guard = self.notifier.wakers.lock().unwrap(); |
| |
| let wakers = match wakers_guard.as_mut() { |
| Some(wakers) => wakers, |
| None => return, |
| }; |
| |
| let new_waker = cx.waker(); |
| |
| if *waker_key == NULL_WAKER_KEY { |
| *waker_key = wakers.insert(Some(new_waker.clone())); |
| } else { |
| match wakers[*waker_key] { |
| Some(ref old_waker) if new_waker.will_wake(old_waker) => {} |
| // Could use clone_from here, but Waker doesn't specialize it. |
| ref mut slot => *slot = Some(new_waker.clone()), |
| } |
| } |
| debug_assert!(*waker_key != NULL_WAKER_KEY); |
| } |
| |
| /// Safety: callers must first ensure that `inner.state` |
| /// is `COMPLETE` |
| unsafe fn take_or_clone_output(self: Arc<Self>) -> Fut::Output { |
| match Arc::try_unwrap(self) { |
| Ok(inner) => match inner.future_or_output.into_inner() { |
| FutureOrOutput::Output(item) => item, |
| FutureOrOutput::Future(_) => unreachable!(), |
| }, |
| Err(inner) => inner.output().clone(), |
| } |
| } |
| } |
| |
| impl<Fut> FusedFuture for Shared<Fut> |
| where |
| Fut: Future, |
| Fut::Output: Clone, |
| { |
| fn is_terminated(&self) -> bool { |
| self.inner.is_none() |
| } |
| } |
| |
| impl<Fut> Future for Shared<Fut> |
| where |
| Fut: Future, |
| Fut::Output: Clone, |
| { |
| type Output = Fut::Output; |
| |
| fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
| let this = &mut *self; |
| |
| let inner = this |
| .inner |
| .take() |
| .expect("Shared future polled again after completion"); |
| |
| // Fast path for when the wrapped future has already completed |
| if inner.notifier.state.load(Acquire) == COMPLETE { |
| // Safety: We're in the COMPLETE state |
| return unsafe { Poll::Ready(inner.take_or_clone_output()) }; |
| } |
| |
| inner.record_waker(&mut this.waker_key, cx); |
| |
| match inner.notifier.state.compare_and_swap(IDLE, POLLING, SeqCst) { |
| IDLE => { |
| // Lock acquired, fall through |
| } |
| POLLING => { |
| // Another task is currently polling, at this point we just want |
| // to ensure that the waker for this task is registered |
| this.inner = Some(inner); |
| return Poll::Pending; |
| } |
| COMPLETE => { |
| // Safety: We're in the COMPLETE state |
| return unsafe { Poll::Ready(inner.take_or_clone_output()) }; |
| } |
| POISONED => panic!("inner future panicked during poll"), |
| _ => unreachable!(), |
| } |
| |
| let waker = waker_ref(&inner.notifier); |
| let mut cx = Context::from_waker(&waker); |
| |
| struct Reset<'a>(&'a AtomicUsize); |
| |
| impl Drop for Reset<'_> { |
| fn drop(&mut self) { |
| use std::thread; |
| |
| if thread::panicking() { |
| self.0.store(POISONED, SeqCst); |
| } |
| } |
| } |
| |
| let _reset = Reset(&inner.notifier.state); |
| |
| let output = { |
| let future = unsafe { |
| match &mut *inner.future_or_output.get() { |
| FutureOrOutput::Future(fut) => Pin::new_unchecked(fut), |
| _ => unreachable!(), |
| } |
| }; |
| |
| match future.poll(&mut cx) { |
| Poll::Pending => { |
| match inner.notifier.state.compare_and_swap(POLLING, IDLE, SeqCst) { |
| POLLING => { |
| // Success |
| drop(_reset); |
| this.inner = Some(inner); |
| return Poll::Pending; |
| } |
| _ => unreachable!(), |
| } |
| } |
| Poll::Ready(output) => output, |
| } |
| }; |
| |
| unsafe { |
| *inner.future_or_output.get() = FutureOrOutput::Output(output); |
| } |
| |
| inner.notifier.state.store(COMPLETE, SeqCst); |
| |
| // Wake all tasks and drop the slab |
| let mut wakers_guard = inner.notifier.wakers.lock().unwrap(); |
| let mut wakers = wakers_guard.take().unwrap(); |
| for opt_waker in wakers.drain() { |
| if let Some(waker) = opt_waker { |
| waker.wake(); |
| } |
| } |
| |
| drop(_reset); // Make borrow checker happy |
| drop(wakers_guard); |
| |
| // Safety: We're in the COMPLETE state |
| unsafe { Poll::Ready(inner.take_or_clone_output()) } |
| } |
| } |
| |
| impl<Fut> Clone for Shared<Fut> |
| where |
| Fut: Future, |
| { |
| fn clone(&self) -> Self { |
| Shared { |
| inner: self.inner.clone(), |
| waker_key: NULL_WAKER_KEY, |
| } |
| } |
| } |
| |
| impl<Fut> Drop for Shared<Fut> |
| where |
| Fut: Future, |
| { |
| fn drop(&mut self) { |
| if self.waker_key != NULL_WAKER_KEY { |
| if let Some(ref inner) = self.inner { |
| if let Ok(mut wakers) = inner.notifier.wakers.lock() { |
| if let Some(wakers) = wakers.as_mut() { |
| wakers.remove(self.waker_key); |
| } |
| } |
| } |
| } |
| } |
| } |
| |
| impl ArcWake for Notifier { |
| fn wake_by_ref(arc_self: &Arc<Self>) { |
| let wakers = &mut *arc_self.wakers.lock().unwrap(); |
| if let Some(wakers) = wakers.as_mut() { |
| for (_key, opt_waker) in wakers { |
| if let Some(waker) = opt_waker.take() { |
| waker.wake(); |
| } |
| } |
| } |
| } |
| } |