|  | //! Futures task based helpers | 
|  |  | 
|  | #![allow(clippy::mutex_atomic)] | 
|  |  | 
|  | use std::future::Future; | 
|  | use std::mem; | 
|  | use std::ops; | 
|  | use std::pin::Pin; | 
|  | use std::sync::{Arc, Condvar, Mutex}; | 
|  | use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker}; | 
|  |  | 
|  | use tokio_stream::Stream; | 
|  |  | 
|  | /// TODO: dox | 
|  | pub fn spawn<T>(task: T) -> Spawn<T> { | 
|  | Spawn { | 
|  | task: MockTask::new(), | 
|  | future: Box::pin(task), | 
|  | } | 
|  | } | 
|  |  | 
/// Future spawned on a mock task
///
/// Pairs a heap-pinned value with the `MockTask` whose waker is supplied when
/// the value is driven through `enter`, `poll` or `poll_next`.
#[derive(Debug)]
pub struct Spawn<T> {
    /// Mock task that provides the waker and records wake notifications.
    task: MockTask,
    /// The wrapped value, boxed and pinned.
    future: Pin<Box<T>>,
}
|  |  | 
/// Mock task
///
/// A mock task is able to intercept and track wake notifications.
#[derive(Debug, Clone)]
struct MockTask {
    /// Shared wake state; every `Waker` created by this task holds a clone of
    /// this `Arc`.
    waker: Arc<ThreadWaker>,
}
|  |  | 
/// Wake state shared between a `MockTask` and the wakers created from it.
#[derive(Debug)]
struct ThreadWaker {
    /// Current state: one of `IDLE`, `WAKE` or `SLEEP`.
    state: Mutex<usize>,
    /// Notified by `wake` when the state it replaces is `SLEEP`.
    condvar: Condvar,
}
|  |  | 
/// No wake has been recorded since the state was created or last cleared.
const IDLE: usize = 0;
/// A wake notification has been recorded.
const WAKE: usize = 1;
/// Per the comment in `ThreadWaker::wake`, marks that "the other half is
/// sleeping" and must be notified through the condvar.
/// NOTE(review): nothing in the visible code ever stores this state — confirm
/// whether it is still needed.
const SLEEP: usize = 2;
|  |  | 
|  | impl<T> Spawn<T> { | 
|  | /// Consumes `self` returning the inner value | 
|  | pub fn into_inner(self) -> T | 
|  | where | 
|  | T: Unpin, | 
|  | { | 
|  | *Pin::into_inner(self.future) | 
|  | } | 
|  |  | 
|  | /// Returns `true` if the inner future has received a wake notification | 
|  | /// since the last call to `enter`. | 
|  | pub fn is_woken(&self) -> bool { | 
|  | self.task.is_woken() | 
|  | } | 
|  |  | 
|  | /// Returns the number of references to the task waker | 
|  | /// | 
|  | /// The task itself holds a reference. The return value will never be zero. | 
|  | pub fn waker_ref_count(&self) -> usize { | 
|  | self.task.waker_ref_count() | 
|  | } | 
|  |  | 
|  | /// Enter the task context | 
|  | pub fn enter<F, R>(&mut self, f: F) -> R | 
|  | where | 
|  | F: FnOnce(&mut Context<'_>, Pin<&mut T>) -> R, | 
|  | { | 
|  | let fut = self.future.as_mut(); | 
|  | self.task.enter(|cx| f(cx, fut)) | 
|  | } | 
|  | } | 
|  |  | 
|  | impl<T: Unpin> ops::Deref for Spawn<T> { | 
|  | type Target = T; | 
|  |  | 
|  | fn deref(&self) -> &T { | 
|  | &self.future | 
|  | } | 
|  | } | 
|  |  | 
|  | impl<T: Unpin> ops::DerefMut for Spawn<T> { | 
|  | fn deref_mut(&mut self) -> &mut T { | 
|  | &mut self.future | 
|  | } | 
|  | } | 
|  |  | 
|  | impl<T: Future> Spawn<T> { | 
|  | /// Polls a future | 
|  | pub fn poll(&mut self) -> Poll<T::Output> { | 
|  | let fut = self.future.as_mut(); | 
|  | self.task.enter(|cx| fut.poll(cx)) | 
|  | } | 
|  | } | 
|  |  | 
|  | impl<T: Stream> Spawn<T> { | 
|  | /// Polls a stream | 
|  | pub fn poll_next(&mut self) -> Poll<Option<T::Item>> { | 
|  | let stream = self.future.as_mut(); | 
|  | self.task.enter(|cx| stream.poll_next(cx)) | 
|  | } | 
|  | } | 
|  |  | 
|  | impl<T: Future> Future for Spawn<T> { | 
|  | type Output = T::Output; | 
|  |  | 
|  | fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { | 
|  | self.future.as_mut().poll(cx) | 
|  | } | 
|  | } | 
|  |  | 
|  | impl<T: Stream> Stream for Spawn<T> { | 
|  | type Item = T::Item; | 
|  |  | 
|  | fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { | 
|  | self.future.as_mut().poll_next(cx) | 
|  | } | 
|  | } | 
|  |  | 
|  | impl MockTask { | 
|  | /// Creates new mock task | 
|  | fn new() -> Self { | 
|  | MockTask { | 
|  | waker: Arc::new(ThreadWaker::new()), | 
|  | } | 
|  | } | 
|  |  | 
|  | /// Runs a closure from the context of the task. | 
|  | /// | 
|  | /// Any wake notifications resulting from the execution of the closure are | 
|  | /// tracked. | 
|  | fn enter<F, R>(&mut self, f: F) -> R | 
|  | where | 
|  | F: FnOnce(&mut Context<'_>) -> R, | 
|  | { | 
|  | self.waker.clear(); | 
|  | let waker = self.waker(); | 
|  | let mut cx = Context::from_waker(&waker); | 
|  |  | 
|  | f(&mut cx) | 
|  | } | 
|  |  | 
|  | /// Returns `true` if the inner future has received a wake notification | 
|  | /// since the last call to `enter`. | 
|  | fn is_woken(&self) -> bool { | 
|  | self.waker.is_woken() | 
|  | } | 
|  |  | 
|  | /// Returns the number of references to the task waker | 
|  | /// | 
|  | /// The task itself holds a reference. The return value will never be zero. | 
|  | fn waker_ref_count(&self) -> usize { | 
|  | Arc::strong_count(&self.waker) | 
|  | } | 
|  |  | 
|  | fn waker(&self) -> Waker { | 
|  | unsafe { | 
|  | let raw = to_raw(self.waker.clone()); | 
|  | Waker::from_raw(raw) | 
|  | } | 
|  | } | 
|  | } | 
|  |  | 
|  | impl Default for MockTask { | 
|  | fn default() -> Self { | 
|  | Self::new() | 
|  | } | 
|  | } | 
|  |  | 
|  | impl ThreadWaker { | 
|  | fn new() -> Self { | 
|  | ThreadWaker { | 
|  | state: Mutex::new(IDLE), | 
|  | condvar: Condvar::new(), | 
|  | } | 
|  | } | 
|  |  | 
|  | /// Clears any previously received wakes, avoiding potential spurrious | 
|  | /// wake notifications. This should only be called immediately before running the | 
|  | /// task. | 
|  | fn clear(&self) { | 
|  | *self.state.lock().unwrap() = IDLE; | 
|  | } | 
|  |  | 
|  | fn is_woken(&self) -> bool { | 
|  | match *self.state.lock().unwrap() { | 
|  | IDLE => false, | 
|  | WAKE => true, | 
|  | _ => unreachable!(), | 
|  | } | 
|  | } | 
|  |  | 
|  | fn wake(&self) { | 
|  | // First, try transitioning from IDLE -> NOTIFY, this does not require a lock. | 
|  | let mut state = self.state.lock().unwrap(); | 
|  | let prev = *state; | 
|  |  | 
|  | if prev == WAKE { | 
|  | return; | 
|  | } | 
|  |  | 
|  | *state = WAKE; | 
|  |  | 
|  | if prev == IDLE { | 
|  | return; | 
|  | } | 
|  |  | 
|  | // The other half is sleeping, so we wake it up. | 
|  | assert_eq!(prev, SLEEP); | 
|  | self.condvar.notify_one(); | 
|  | } | 
|  | } | 
|  |  | 
/// Vtable that routes `RawWaker`s built by `to_raw` to the `clone`, `wake`,
/// `wake_by_ref` and `drop_waker` functions in this file.
static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop_waker);
|  |  | 
|  | unsafe fn to_raw(waker: Arc<ThreadWaker>) -> RawWaker { | 
|  | RawWaker::new(Arc::into_raw(waker) as *const (), &VTABLE) | 
|  | } | 
|  |  | 
|  | unsafe fn from_raw(raw: *const ()) -> Arc<ThreadWaker> { | 
|  | Arc::from_raw(raw as *const ThreadWaker) | 
|  | } | 
|  |  | 
|  | unsafe fn clone(raw: *const ()) -> RawWaker { | 
|  | let waker = from_raw(raw); | 
|  |  | 
|  | // Increment the ref count | 
|  | mem::forget(waker.clone()); | 
|  |  | 
|  | to_raw(waker) | 
|  | } | 
|  |  | 
|  | unsafe fn wake(raw: *const ()) { | 
|  | let waker = from_raw(raw); | 
|  | waker.wake(); | 
|  | } | 
|  |  | 
|  | unsafe fn wake_by_ref(raw: *const ()) { | 
|  | let waker = from_raw(raw); | 
|  | waker.wake(); | 
|  |  | 
|  | // We don't actually own a reference to the unparker | 
|  | mem::forget(waker); | 
|  | } | 
|  |  | 
|  | unsafe fn drop_waker(raw: *const ()) { | 
|  | let _ = from_raw(raw); | 
|  | } |