blob: 77adfc3eec7544604ae0b47864cd4102ae0d957d [file] [log] [blame]
#![cfg(not(loom))]
//! A mock type implementing [`AsyncRead`] and [`AsyncWrite`].
//!
//!
//! # Overview
//!
//! Provides a type that implements [`AsyncRead`] + [`AsyncWrite`] and can be configured
//! to handle an arbitrary sequence of read and write operations. This is useful
//! for writing unit tests for networking services, because using an actual network
//! type is fairly non-deterministic.
//!
//! # Usage
//!
//! Attempting to write data that the mock isn't expecting will result in a
//! panic.
//!
//! [`AsyncRead`]: tokio::io::AsyncRead
//! [`AsyncWrite`]: tokio::io::AsyncWrite
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::sync::mpsc;
use tokio::time::{self, Duration, Instant, Sleep};
use futures_core::{ready, Stream};
use std::collections::VecDeque;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{self, Poll, Waker};
use std::{cmp, io};
/// An I/O object that follows a predefined script.
///
/// This value is created by `Builder` and implements `AsyncRead` + `AsyncWrite`. It
/// follows the scenario described by the builder and panics otherwise.
///
/// Dropping a `Mock` while a scripted `read` or `write` still holds unconsumed
/// bytes panics, unless the thread is already panicking.
#[derive(Debug)]
pub struct Mock {
// Script state: the queued actions, the wait timer and the channel fed by `Handle`.
inner: Inner,
}
/// A handle to send additional actions to the related `Mock`.
///
/// Returned by `Builder::build_with_handle`. Every method sends one action over
/// an unbounded channel and panics if the send fails.
#[derive(Debug)]
pub struct Handle {
// Sending half of the channel; the receiving half is created alongside it
// in `Inner::new` and kept by the mock.
tx: mpsc::UnboundedSender<Action>,
}
/// Builds `Mock` instances.
///
/// Each builder method appends one action to the script. Building clones the
/// recorded script, so the builder itself is left unchanged by `build`.
#[derive(Debug, Clone, Default)]
pub struct Builder {
// Sequence of actions for the Mock to take, in the order they were added.
actions: VecDeque<Action>,
}
// One step of a mock's script.
#[derive(Debug, Clone)]
enum Action {
// Bytes handed out to `read` calls; drained as they are consumed.
Read(Vec<u8>),
// Bytes that `write` calls are expected to supply; drained as they are matched.
Write(Vec<u8>),
// Pause for the given duration before the script continues.
Wait(Duration),
// `io::Error` is not `Clone`, so the errors are wrapped in `Arc` to let
// `Action` (and therefore `Builder`) derive `Clone`.
// The `Option` is emptied with `take()` when the error is handed out; an
// action holding `None` is treated as spent and removed from the script.
ReadError(Option<Arc<io::Error>>),
WriteError(Option<Arc<io::Error>>),
}
// Script state behind a `Mock`.
struct Inner {
// Actions that have not been fully consumed yet, in script order.
actions: VecDeque<Action>,
// Deadline recorded when a `Wait` action is first seen at the front of the script.
waiting: Option<Instant>,
// Timer that `poll_read` / `poll_write` poll while a `Wait` blocks the script.
sleep: Option<Pin<Box<Sleep>>>,
// Waker of a reader that returned `Pending`; woken from the write path.
read_wait: Option<Waker>,
// Actions sent through a `Handle`. Stored as a boxed stream wrapping the
// channel receiver rather than as a bare `mpsc::UnboundedReceiver<Action>`.
rx: Pin<Box<dyn Stream<Item = Action> + Send>>,
}
impl Builder {
/// Return a new, empty `Builder.
pub fn new() -> Self {
Self::default()
}
/// Sequence a `read` operation.
///
/// The next operation in the mock's script will be to expect a `read` call
/// and return `buf`.
pub fn read(&mut self, buf: &[u8]) -> &mut Self {
self.actions.push_back(Action::Read(buf.into()));
self
}
/// Sequence a `read` operation that produces an error.
///
/// The next operation in the mock's script will be to expect a `read` call
/// and return `error`.
pub fn read_error(&mut self, error: io::Error) -> &mut Self {
let error = Some(error.into());
self.actions.push_back(Action::ReadError(error));
self
}
/// Sequence a `write` operation.
///
/// The next operation in the mock's script will be to expect a `write`
/// call.
pub fn write(&mut self, buf: &[u8]) -> &mut Self {
self.actions.push_back(Action::Write(buf.into()));
self
}
/// Sequence a `write` operation that produces an error.
///
/// The next operation in the mock's script will be to expect a `write`
/// call that provides `error`.
pub fn write_error(&mut self, error: io::Error) -> &mut Self {
let error = Some(error.into());
self.actions.push_back(Action::WriteError(error));
self
}
/// Sequence a wait.
///
/// The next operation in the mock's script will be to wait without doing so
/// for `duration` amount of time.
pub fn wait(&mut self, duration: Duration) -> &mut Self {
let duration = cmp::max(duration, Duration::from_millis(1));
self.actions.push_back(Action::Wait(duration));
self
}
/// Build a `Mock` value according to the defined script.
pub fn build(&mut self) -> Mock {
let (mock, _) = self.build_with_handle();
mock
}
/// Build a `Mock` value paired with a handle
pub fn build_with_handle(&mut self) -> (Mock, Handle) {
let (inner, handle) = Inner::new(self.actions.clone());
let mock = Mock { inner };
(mock, handle)
}
}
impl Handle {
/// Sequence a `read` operation.
///
/// The next operation in the mock's script will be to expect a `read` call
/// and return `buf`.
pub fn read(&mut self, buf: &[u8]) -> &mut Self {
self.tx.send(Action::Read(buf.into())).unwrap();
self
}
/// Sequence a `read` operation error.
///
/// The next operation in the mock's script will be to expect a `read` call
/// and return `error`.
pub fn read_error(&mut self, error: io::Error) -> &mut Self {
let error = Some(error.into());
self.tx.send(Action::ReadError(error)).unwrap();
self
}
/// Sequence a `write` operation.
///
/// The next operation in the mock's script will be to expect a `write`
/// call.
pub fn write(&mut self, buf: &[u8]) -> &mut Self {
self.tx.send(Action::Write(buf.into())).unwrap();
self
}
/// Sequence a `write` operation error.
///
/// The next operation in the mock's script will be to expect a `write`
/// call error.
pub fn write_error(&mut self, error: io::Error) -> &mut Self {
let error = Some(error.into());
self.tx.send(Action::WriteError(error)).unwrap();
self
}
}
impl Inner {
    /// Creates the script state for a mock together with the `Handle` that can
    /// send it further actions.
    ///
    /// `actions` is the initial script. The receiving half of the handle's
    /// channel is wrapped in a stream that yields each received action until
    /// `recv` returns `None`.
    fn new(actions: VecDeque<Action>) -> (Inner, Handle) {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let rx = Box::pin(async_stream::stream! {
            while let Some(item) = rx.recv().await {
                yield item;
            }
        });
        let inner = Inner {
            actions,
            sleep: None,
            read_wait: None,
            rx,
            waiting: None,
        };
        let handle = Handle { tx };
        (inner, handle)
    }

    /// Polls the handle's stream for the next action sent at runtime.
    fn poll_action(&mut self, cx: &mut task::Context<'_>) -> Poll<Option<Action>> {
        Pin::new(&mut self.rx).poll_next(cx)
    }

    /// Takes the error out of a `ReadError` / `WriteError` slot and returns it
    /// by value, leaving `None` behind so `action` drops the spent step.
    ///
    /// Panics if the slot has already been emptied.
    fn take_error(slot: &mut Option<Arc<io::Error>>) -> io::Error {
        let shared = slot
            .take()
            .expect("Should have been removed from actions.");
        // `Builder::build_with_handle` clones the script, so a builder that is
        // still alive shares this `Arc`. `io::Error` is not `Clone`; in that
        // case hand out an equivalent error carrying the same kind and message
        // instead of panicking.
        Arc::try_unwrap(shared)
            .unwrap_or_else(|shared| io::Error::new(shared.kind(), shared.to_string()))
    }

    /// Serves one `read` from the front of the script.
    ///
    /// * `Read` at the front: copies as many bytes as fit into `dst` and
    ///   removes them from the action.
    /// * `ReadError` at the front: returns the scripted error.
    /// * Any other action at the front: returns `WouldBlock`.
    /// * Empty script: returns `Ok(())` without touching `dst`.
    fn read(&mut self, dst: &mut ReadBuf<'_>) -> io::Result<()> {
        match self.action() {
            Some(Action::Read(data)) => {
                // Figure out how much to copy
                let n = cmp::min(dst.remaining(), data.len());
                // Copy the data into the `dst` slice
                dst.put_slice(&data[..n]);
                // Drain the data from the source
                data.drain(..n);
                Ok(())
            }
            Some(Action::ReadError(err)) => Err(Self::take_error(err)),
            Some(Action::Write(_)) | Some(Action::Wait(_)) | Some(Action::WriteError(_)) => {
                // Either waiting or expecting a write
                Err(io::ErrorKind::WouldBlock.into())
            }
            None => Ok(()),
        }
    }

    /// Matches `src` against the scripted writes and returns how many bytes
    /// were accepted.
    ///
    /// Returns `BrokenPipe` when the script is empty, `WouldBlock` when a
    /// `Wait` is at the front, and the scripted error when a `WriteError` is at
    /// the front. Otherwise `src` is compared against the `Write` actions in
    /// script order, skipping over read actions and stopping at the first
    /// `Wait` or `WriteError`.
    ///
    /// Panics if `src` differs from the expected bytes.
    fn write(&mut self, mut src: &[u8]) -> io::Result<usize> {
        let mut ret = 0;
        if self.actions.is_empty() {
            return Err(io::ErrorKind::BrokenPipe.into());
        }
        if let Some(Action::Wait(..)) = self.action() {
            return Err(io::ErrorKind::WouldBlock.into());
        }
        if let Some(Action::WriteError(err)) = self.action() {
            return Err(Self::take_error(err));
        }
        for action in self.actions.iter_mut() {
            match action {
                Action::Write(expect) => {
                    let n = cmp::min(src.len(), expect.len());
                    assert_eq!(&src[..n], &expect[..n]);
                    // Drop data that was matched
                    expect.drain(..n);
                    src = &src[n..];
                    ret += n;
                    if src.is_empty() {
                        return Ok(ret);
                    }
                }
                Action::Wait(..) | Action::WriteError(..) => {
                    break;
                }
                Action::Read(..) | Action::ReadError(..) => {}
            }
            // TODO: remove write
        }
        Ok(ret)
    }

    /// Returns the duration of the `Wait` at the front of the script, or `None`
    /// when the front action is not a pending `Wait`.
    fn remaining_wait(&mut self) -> Option<Duration> {
        match self.action() {
            Some(Action::Wait(dur)) => Some(*dur),
            _ => None,
        }
    }

    /// Discards spent actions from the front of the script and returns the
    /// first one that still has work to do.
    ///
    /// An action is spent when its byte buffer is empty (`Read` / `Write`),
    /// its error has been taken (`ReadError` / `WriteError`), or its deadline
    /// has passed (`Wait`). The first time a `Wait` is seen at the front, its
    /// deadline is recorded in `self.waiting`.
    fn action(&mut self) -> Option<&mut Action> {
        loop {
            if self.actions.is_empty() {
                return None;
            }
            match self.actions[0] {
                Action::Read(ref mut data) => {
                    if !data.is_empty() {
                        break;
                    }
                }
                Action::Write(ref mut data) => {
                    if !data.is_empty() {
                        break;
                    }
                }
                Action::Wait(ref mut dur) => {
                    if let Some(until) = self.waiting {
                        let now = Instant::now();
                        if now < until {
                            break;
                        }
                        // The deadline has passed. Clear it so that the next
                        // `Wait` in the script records a deadline of its own
                        // instead of being compared against this stale one and
                        // skipped immediately.
                        self.waiting = None;
                    } else {
                        self.waiting = Some(Instant::now() + *dur);
                        break;
                    }
                }
                Action::ReadError(ref mut error) | Action::WriteError(ref mut error) => {
                    if error.is_some() {
                        break;
                    }
                }
            }
            let _action = self.actions.pop_front();
        }
        self.actions.front_mut()
    }
}
// ===== impl Mock =====
impl Mock {
    /// Wakes a parked reader if the script has reached a point where a read
    /// can make progress.
    ///
    /// That is the case when the front of the script is a `Read` or a
    /// `ReadError`, or when the script is empty. The stored waker is consumed
    /// when it is used.
    fn maybe_wakeup_reader(&mut self) {
        let reader_can_proceed = matches!(
            self.inner.action(),
            Some(Action::Read(_)) | Some(Action::ReadError(_)) | None
        );
        if !reader_can_proceed {
            return;
        }
        if let Some(parked) = self.inner.read_wait.take() {
            parked.wake();
        }
    }
}
impl AsyncRead for Mock {
/// Serves a read from the script.
///
/// Returns `Ready(Ok(()))` once bytes were copied into `buf`, or with `buf`
/// unchanged when the script is exhausted and the handle's stream has ended.
/// Returns `Ready(Err(_))` for any error other than `WouldBlock` coming from
/// the script. Returns `Pending` while a scripted wait is running, while the
/// script expects something other than a read, or while waiting for the handle
/// to send more actions.
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
loop {
// A timer from an earlier `Wait` must complete first; `ready!` returns
// `Pending` from this function until it does.
if let Some(ref mut sleep) = self.inner.sleep {
ready!(Pin::new(sleep).poll(cx));
}
// If a sleep is set, it has already fired
self.inner.sleep = None;
// Capture 'filled' to monitor if it changed
let filled = buf.filled().len();
match self.inner.read(buf) {
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
if let Some(rem) = self.inner.remaining_wait() {
// Blocked on a `Wait`: arm a timer and loop so that it is polled
// at the top of the next iteration.
let until = Instant::now() + rem;
self.inner.sleep = Some(Box::pin(time::sleep_until(until)));
} else {
// Blocked on something other than a `Wait`: park this reader so
// that `maybe_wakeup_reader` can wake it later.
self.inner.read_wait = Some(cx.waker().clone());
return Poll::Pending;
}
}
Ok(()) => {
if buf.filled().len() == filled {
// Nothing was copied, so look for actions sent through the handle.
match ready!(self.inner.poll_action(cx)) {
Some(action) => {
self.inner.actions.push_back(action);
continue;
}
None => {
// The handle's stream has ended: report the read as
// complete with no new bytes.
return Poll::Ready(Ok(()));
}
}
} else {
return Poll::Ready(Ok(()));
}
}
Err(e) => return Poll::Ready(Err(e)),
}
}
}
}
impl AsyncWrite for Mock {
/// Checks a write against the script.
///
/// Returns `Ready` with the result of `Inner::write` when it accepted at
/// least one byte or failed with an error other than `WouldBlock`. Returns
/// `Pending` while a scripted wait is running.
///
/// Panics if `Inner::write` reports `WouldBlock` while no `Wait` is pending,
/// or if nothing was accepted, the script is empty and the handle's stream
/// has ended ("unexpected write").
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
loop {
// A timer from an earlier `Wait` must complete first; `ready!` returns
// `Pending` from this function until it does.
if let Some(ref mut sleep) = self.inner.sleep {
ready!(Pin::new(sleep).poll(cx));
}
// If a sleep is set, it has already fired
self.inner.sleep = None;
match self.inner.write(buf) {
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
if let Some(rem) = self.inner.remaining_wait() {
// Blocked on a `Wait`: arm a timer and loop so that it is polled
// at the top of the next iteration.
let until = Instant::now() + rem;
self.inner.sleep = Some(Box::pin(time::sleep_until(until)));
} else {
panic!("unexpected WouldBlock");
}
}
Ok(0) => {
// TODO: Is this correct?
// NOTE(review): this path returns `Pending` without storing or
// registering `cx`'s waker, so nothing visible here wakes the
// writer again. Confirm that this is intended.
if !self.inner.actions.is_empty() {
return Poll::Pending;
}
// TODO: Extract
// The script is empty: look for actions sent through the handle.
match ready!(self.inner.poll_action(cx)) {
Some(action) => {
self.inner.actions.push_back(action);
continue;
}
None => {
panic!("unexpected write");
}
}
}
ret => {
// Bytes were accepted or a non-`WouldBlock` error occurred; a parked
// reader may be able to proceed now.
self.maybe_wakeup_reader();
return Poll::Ready(ret);
}
}
}
}
/// Always completes immediately with `Ok(())`.
fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
/// Always completes immediately with `Ok(())`.
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}
/// Checks on drop that no scripted data was left unused.
impl Drop for Mock {
    /// Panics if any `Read` or `Write` action in the script still holds bytes.
    ///
    /// The check is skipped while the thread is already unwinding, because a
    /// second panic would make the original failure much harder to debug.
    fn drop(&mut self) {
        if std::thread::panicking() {
            return;
        }

        for action in &self.inner.actions {
            match action {
                Action::Read(data) => {
                    assert!(data.is_empty(), "There is still data left to read.")
                }
                Action::Write(data) => {
                    assert!(data.is_empty(), "There is still data left to write.")
                }
                Action::Wait(_) | Action::ReadError(_) | Action::WriteError(_) => {}
            }
        }
    }
}
/*
/// Returns `true` if called from the context of a futures-rs Task
fn is_task_ctx() -> bool {
use std::panic;
// Save the existing panic hook
let h = panic::take_hook();
// Install a new one that does nothing
panic::set_hook(Box::new(|_| {}));
// Attempt to call the fn
let r = panic::catch_unwind(|| task::current()).is_ok();
// Re-install the old one
panic::set_hook(h);
// Return the result
r
}
*/
impl fmt::Debug for Inner {
    /// Writes a fixed placeholder; none of the fields are printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let placeholder = format_args!("Inner {{...}}");
        f.write_fmt(placeholder)
    }
}