// blob: 8ae70ce13bccc06971792cb2f8f76d8a10d104f1 [file] [log] [blame]
//! `TcpStream` split support.
//!
//! A `TcpStream` can be split into a `ReadHalf` and a
//! `WriteHalf` with the `TcpStream::split` method. `ReadHalf`
//! implements `AsyncRead` while `WriteHalf` implements `AsyncWrite`.
//!
//! Compared to the generic split of `AsyncRead + AsyncWrite`, this specialized
//! split has no associated overhead and enforces all invariants at the type
//! level.
use crate::future::poll_fn;
use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
use crate::net::TcpStream;
use std::io;
use std::net::Shutdown;
use std::pin::Pin;
use std::task::{Context, Poll};
/// Borrowed read half of a [`TcpStream`], created by [`split`].
///
/// The half wraps a shared reference to the stream it was split from, so it
/// cannot outlive that stream.
///
/// Reading from a `ReadHalf` is usually done using the convenience methods found on the
/// [`AsyncReadExt`] trait.
///
/// [`TcpStream`]: TcpStream
/// [`split`]: TcpStream::split()
/// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt
#[derive(Debug)]
pub struct ReadHalf<'a>(&'a TcpStream);
/// Borrowed write half of a [`TcpStream`], created by [`split`].
///
/// The half wraps a shared reference to the stream it was split from, so it
/// cannot outlive that stream.
///
/// Note that in the [`AsyncWrite`] implementation of this type, [`poll_shutdown`] will
/// shut down the TCP stream in the write direction.
///
/// Writing to a `WriteHalf` is usually done using the convenience methods found
/// on the [`AsyncWriteExt`] trait.
///
/// [`TcpStream`]: TcpStream
/// [`split`]: TcpStream::split()
/// [`AsyncWrite`]: trait@crate::io::AsyncWrite
/// [`poll_shutdown`]: fn@crate::io::AsyncWrite::poll_shutdown
/// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt
#[derive(Debug)]
pub struct WriteHalf<'a>(&'a TcpStream);
/// Splits a mutably borrowed stream into a read half and a write half.
///
/// Both halves hold a shared reference to the same stream. Because the
/// returned halves are tied to the `&mut` borrow of `stream`, the stream
/// cannot be used directly for as long as either half is alive.
pub(crate) fn split(stream: &mut TcpStream) -> (ReadHalf<'_>, WriteHalf<'_>) {
    // Downgrade the exclusive borrow once and hand the same shared reference
    // to each half.
    let shared: &TcpStream = &*stream;
    let reader = ReadHalf(shared);
    let writer = WriteHalf(shared);
    (reader, writer)
}
impl ReadHalf<'_> {
/// Attempt to receive data on the socket, without removing that data from
/// the queue, registering the current task for wakeup if data is not yet
/// available.
///
/// Note that on multiple calls to `poll_peek` or `poll_read`, only the
/// `Waker` from the `Context` passed to the most recent call is scheduled
/// to receive a wakeup.
///
/// See the [`TcpStream::poll_peek`] level documentation for more details.
///
/// # Examples
///
/// ```no_run
/// use tokio::io::{self, ReadBuf};
/// use tokio::net::TcpStream;
///
/// use futures::future::poll_fn;
///
/// #[tokio::main]
/// async fn main() -> io::Result<()> {
/// let mut stream = TcpStream::connect("127.0.0.1:8000").await?;
/// let (mut read_half, _) = stream.split();
/// let mut buf = [0; 10];
/// let mut buf = ReadBuf::new(&mut buf);
///
/// poll_fn(|cx| {
/// read_half.poll_peek(cx, &mut buf)
/// }).await?;
///
/// Ok(())
/// }
/// ```
///
/// [`TcpStream::poll_peek`]: TcpStream::poll_peek
pub fn poll_peek(
&mut self,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<usize>> {
self.0.poll_peek(cx, buf)
}
/// Receives data on the socket from the remote address to which it is
/// connected, without removing that data from the queue. On success,
/// returns the number of bytes peeked.
///
/// See the [`TcpStream::peek`] level documentation for more details.
///
/// [`TcpStream::peek`]: TcpStream::peek
///
/// # Examples
///
/// ```no_run
/// use tokio::net::TcpStream;
/// use tokio::io::AsyncReadExt;
/// use std::error::Error;
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn Error>> {
/// // Connect to a peer
/// let mut stream = TcpStream::connect("127.0.0.1:8080").await?;
/// let (mut read_half, _) = stream.split();
///
/// let mut b1 = [0; 10];
/// let mut b2 = [0; 10];
///
/// // Peek at the data
/// let n = read_half.peek(&mut b1).await?;
///
/// // Read the data
/// assert_eq!(n, read_half.read(&mut b2[..n]).await?);
/// assert_eq!(&b1[..n], &b2[..n]);
///
/// Ok(())
/// }
/// ```
///
/// The [`read`] method is defined on the [`AsyncReadExt`] trait.
///
/// [`read`]: fn@crate::io::AsyncReadExt::read
/// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt
pub async fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let mut buf = ReadBuf::new(buf);
poll_fn(|cx| self.poll_peek(cx, &mut buf)).await
}
}
impl AsyncRead for ReadHalf<'_> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
self.0.poll_read_priv(cx, buf)
}
}
impl AsyncWrite for WriteHalf<'_> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.0.poll_write_priv(cx, buf)
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
self.0.poll_write_vectored_priv(cx, bufs)
}
fn is_write_vectored(&self) -> bool {
self.0.is_write_vectored()
}
#[inline]
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
// tcp flush is a no-op
Poll::Ready(Ok(()))
}
// `poll_shutdown` on a write half shutdowns the stream in the "write" direction.
fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
self.0.shutdown_std(Shutdown::Write).into()
}
}
/// Gives access to the stream this read half was split from.
impl AsRef<TcpStream> for ReadHalf<'_> {
    /// Returns the shared reference to the underlying stream.
    fn as_ref(&self) -> &TcpStream {
        let stream: &TcpStream = self.0;
        stream
    }
}
/// Gives access to the stream this write half was split from.
impl AsRef<TcpStream> for WriteHalf<'_> {
    /// Returns the shared reference to the underlying stream.
    fn as_ref(&self) -> &TcpStream {
        let stream: &TcpStream = self.0;
        stream
    }
}