blob: fa98bae0b0157ded37a001893bc725aba2285709 [file] [log] [blame]
//! Futures task-based helpers: poll futures and streams on a mock task that
//! records wake notifications.
#![allow(clippy::mutex_atomic)]
use std::future::Future;
use std::mem;
use std::ops;
use std::pin::Pin;
use std::sync::{Arc, Condvar, Mutex};
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
use tokio_stream::Stream;
/// Wraps `task` in a fresh mock task so it can be polled by hand.
///
/// `task` is typically a future or a stream. It is boxed and pinned, so `T`
/// does not need to be `Unpin`. The returned [`Spawn`] offers `poll` (when
/// `T: Future`), `poll_next` (when `T: Stream`), and `is_woken`. `is_woken`
/// reports whether the value issued a wake notification since it was last
/// polled through the wrapper.
pub fn spawn<T>(task: T) -> Spawn<T> {
    Spawn {
        task: MockTask::new(),
        future: Box::pin(task),
    }
}
/// Future spawned on a mock task
///
/// Created by `spawn`. Pairs the pinned, boxed value with a `MockTask`.
/// Polling through this wrapper's `poll`, `poll_next` or `enter` lets the
/// mock task record wake notifications.
#[derive(Debug)]
pub struct Spawn<T> {
    /// Mock task whose waker is handed to the inner value by `poll`,
    /// `poll_next` and `enter`.
    task: MockTask,
    /// The spawned value, boxed and pinned so it can be polled even when
    /// `T` is not `Unpin`.
    future: Pin<Box<T>>,
}
/// Mock task
///
/// A mock task is able to intercept and track wake notifications.
#[derive(Debug, Clone)]
struct MockTask {
    /// Shared wake state. Every `Waker` produced by `MockTask::waker` holds
    /// its own strong reference to this same `ThreadWaker`.
    waker: Arc<ThreadWaker>,
}
/// Wake state shared between a `MockTask` and every `Waker` created from it.
#[derive(Debug)]
struct ThreadWaker {
    /// Current state, one of `IDLE`, `WAKE` or `SLEEP`.
    state: Mutex<usize>,
    /// Notified by `wake` when the previous state was `SLEEP`.
    condvar: Condvar,
}
/// No wake notification has been recorded since the state was last cleared.
const IDLE: usize = 0;
/// A wake notification has been recorded.
const WAKE: usize = 1;
/// Treated by `wake` as "the other half is sleeping". Leaving this state
/// notifies the condvar.
const SLEEP: usize = 2;
impl<T> Spawn<T> {
/// Consumes `self` returning the inner value
pub fn into_inner(self) -> T
where
T: Unpin,
{
*Pin::into_inner(self.future)
}
/// Returns `true` if the inner future has received a wake notification
/// since the last call to `enter`.
pub fn is_woken(&self) -> bool {
self.task.is_woken()
}
/// Returns the number of references to the task waker
///
/// The task itself holds a reference. The return value will never be zero.
pub fn waker_ref_count(&self) -> usize {
self.task.waker_ref_count()
}
/// Enter the task context
pub fn enter<F, R>(&mut self, f: F) -> R
where
F: FnOnce(&mut Context<'_>, Pin<&mut T>) -> R,
{
let fut = self.future.as_mut();
self.task.enter(|cx| f(cx, fut))
}
}
impl<T: Unpin> ops::Deref for Spawn<T> {
    type Target = T;

    /// Gives shared access to the spawned value.
    fn deref(&self) -> &T {
        self.future.as_ref().get_ref()
    }
}
impl<T: Unpin> ops::DerefMut for Spawn<T> {
    /// Gives mutable access to the spawned value.
    ///
    /// Only available for `T: Unpin`, since a `&mut T` to a pinned value
    /// would otherwise allow moving it.
    fn deref_mut(&mut self) -> &mut T {
        self.future.as_mut().get_mut()
    }
}
impl<T: Future> Spawn<T> {
    /// Polls the spawned future once, inside the mock task's context.
    ///
    /// The wake flag is reset before polling, so `is_woken` afterwards
    /// reflects only wakes issued from this poll onward.
    pub fn poll(&mut self) -> Poll<T::Output> {
        self.enter(|cx, fut| fut.poll(cx))
    }
}
impl<T: Stream> Spawn<T> {
    /// Polls the spawned stream for its next item, inside the mock task's
    /// context.
    ///
    /// The wake flag is reset before polling, so `is_woken` afterwards
    /// reflects only wakes issued from this poll onward.
    pub fn poll_next(&mut self) -> Poll<Option<T::Item>> {
        self.enter(|cx, stream| stream.poll_next(cx))
    }
}
impl<T: Future> Future for Spawn<T> {
    type Output = T::Output;

    /// Forwards to the spawned future using the caller's `cx`.
    ///
    /// This bypasses the mock task, so no wake is recorded for `is_woken`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        this.future.as_mut().poll(cx)
    }
}
impl<T: Stream> Stream for Spawn<T> {
    type Item = T::Item;

    /// Forwards to the spawned stream using the caller's `cx`.
    ///
    /// This bypasses the mock task, so no wake is recorded for `is_woken`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        this.future.as_mut().poll_next(cx)
    }
}
impl MockTask {
/// Creates new mock task
fn new() -> Self {
MockTask {
waker: Arc::new(ThreadWaker::new()),
}
}
/// Runs a closure from the context of the task.
///
/// Any wake notifications resulting from the execution of the closure are
/// tracked.
fn enter<F, R>(&mut self, f: F) -> R
where
F: FnOnce(&mut Context<'_>) -> R,
{
self.waker.clear();
let waker = self.waker();
let mut cx = Context::from_waker(&waker);
f(&mut cx)
}
/// Returns `true` if the inner future has received a wake notification
/// since the last call to `enter`.
fn is_woken(&self) -> bool {
self.waker.is_woken()
}
/// Returns the number of references to the task waker
///
/// The task itself holds a reference. The return value will never be zero.
fn waker_ref_count(&self) -> usize {
Arc::strong_count(&self.waker)
}
fn waker(&self) -> Waker {
unsafe {
let raw = to_raw(self.waker.clone());
Waker::from_raw(raw)
}
}
}
impl Default for MockTask {
fn default() -> Self {
Self::new()
}
}
impl ThreadWaker {
    /// Creates wake state in `IDLE` (no wake recorded).
    fn new() -> Self {
        ThreadWaker {
            state: Mutex::new(IDLE),
            condvar: Condvar::new(),
        }
    }

    /// Clears any previously received wakes, avoiding potential spurious
    /// wake notifications. This should only be called immediately before
    /// running the task.
    fn clear(&self) {
        *self.state.lock().unwrap() = IDLE;
    }

    /// Returns `true` if a wake has been recorded (state is `WAKE`).
    ///
    /// # Panics
    ///
    /// Panics if the state is anything other than `IDLE` or `WAKE` (e.g.
    /// `SLEEP`), or if the mutex is poisoned.
    fn is_woken(&self) -> bool {
        match *self.state.lock().unwrap() {
            IDLE => false,
            WAKE => true,
            _ => unreachable!(),
        }
    }

    /// Records a wake notification by moving the state to `WAKE`.
    ///
    /// Does nothing if already `WAKE`. If the previous state was `SLEEP`,
    /// one thread waiting on `condvar` is notified.
    ///
    /// # Panics
    ///
    /// Panics if the previous state was not `IDLE`, `WAKE` or `SLEEP`, or if
    /// the mutex is poisoned.
    fn wake(&self) {
        // All transitions happen under the mutex, so this read-then-write
        // cannot interleave with `clear`, `is_woken`, or another `wake`.
        let mut state = self.state.lock().unwrap();
        let prev = *state;
        if prev == WAKE {
            // A wake is already recorded; nothing more to do.
            return;
        }
        *state = WAKE;
        if prev == IDLE {
            // Only a transition out of `SLEEP` needs a condvar notification.
            return;
        }
        // The other half is sleeping, so we wake it up.
        assert_eq!(prev, SLEEP);
        self.condvar.notify_one();
    }
}
/// Vtable for wakers built by `to_raw`. Every entry treats the data pointer
/// as an `Arc<ThreadWaker>` produced by `Arc::into_raw`.
static VTABLE: RawWakerVTable = RawWakerVTable::new(
    clone,       // new RawWaker with +1 strong reference
    wake,        // wake, consuming the reference
    wake_by_ref, // wake, leaving the reference untouched
    drop_waker,  // release the reference
);
/// Moves one strong reference of `waker` into a `RawWaker` using `VTABLE`.
///
/// The reference is now owned by the returned `RawWaker`. It is released only
/// through the vtable (`wake` or `drop_waker`); otherwise it leaks.
unsafe fn to_raw(waker: Arc<ThreadWaker>) -> RawWaker {
    let data = Arc::into_raw(waker).cast::<()>();
    RawWaker::new(data, &VTABLE)
}
/// Rebuilds the `Arc<ThreadWaker>` behind a `RawWaker` data pointer.
///
/// # Safety
///
/// `raw` must have come from `Arc::into_raw` on an `Arc<ThreadWaker>` (as
/// `to_raw` does). The returned `Arc` takes over one strong reference. A
/// caller that does not own that reference must not let the `Arc` drop.
unsafe fn from_raw(raw: *const ()) -> Arc<ThreadWaker> {
    let ptr = raw.cast::<ThreadWaker>();
    Arc::from_raw(ptr)
}
unsafe fn clone(raw: *const ()) -> RawWaker {
let waker = from_raw(raw);
// Increment the ref count
mem::forget(waker.clone());
to_raw(waker)
}
/// Vtable `wake` entry: records a wake, then releases the reference this
/// `RawWaker` owned.
unsafe fn wake(raw: *const ()) {
    from_raw(raw).wake();
}
/// Vtable `wake_by_ref` entry: records a wake without touching the
/// reference count.
unsafe fn wake_by_ref(raw: *const ()) {
    // We don't actually own a reference to the waker state. Wrapping it in
    // `ManuallyDrop` before calling `wake` keeps the count intact even if
    // `wake` panics (poisoned mutex or failed state assertion). The previous
    // `mem::forget` after the call was skipped during unwinding, which
    // decremented a reference we never owned.
    let waker = mem::ManuallyDrop::new(from_raw(raw));
    waker.wake();
}
/// Vtable `drop` entry: releases the strong reference owned by this
/// `RawWaker`.
unsafe fn drop_waker(raw: *const ()) {
    drop(from_raw(raw));
}