| use crate::codec::decoder::Decoder; |
| use crate::codec::encoder::Encoder; |
| |
| use futures_core::Stream; |
| use tokio::io::{AsyncRead, AsyncWrite}; |
| |
| use bytes::BytesMut; |
| use futures_core::ready; |
| use futures_sink::Sink; |
| use pin_project_lite::pin_project; |
| use std::borrow::{Borrow, BorrowMut}; |
| use std::io; |
| use std::pin::Pin; |
| use std::task::{Context, Poll}; |
| use tracing::trace; |
| |
pin_project! {
    /// An I/O object paired with a codec and the buffer state used to read and/or
    /// write frames through it.
    ///
    /// The `Stream` impl requires `State` to lend out a `ReadFrame`, and the `Sink`
    /// impl requires it to lend out a `WriteFrame`; `RWFrames` provides both.
    #[derive(Debug)]
    pub(crate) struct FramedImpl<T, U, State> {
        // The underlying I/O object; structurally pinned so it can be polled.
        #[pin]
        pub(crate) inner: T,
        // Read and/or write buffers plus their state flags.
        pub(crate) state: State,
        // The decoder and/or encoder applied to the buffers.
        pub(crate) codec: U,
    }
}
| |
/// Capacity, in bytes, preallocated for (and guaranteed to caller-supplied) read and
/// write buffers; also the default write backpressure boundary.
const INITIAL_CAPACITY: usize = 8 * 1024;
| |
/// Read-side state driven by the `Stream` impl of `FramedImpl`.
///
/// The three flags together select the state of the decoding state machine in
/// `poll_next`.
#[derive(Debug)]
pub(crate) struct ReadFrame {
    /// Set when the reader returned 0 bytes; cleared when a later read yields data.
    pub(crate) eof: bool,
    /// Whether `buffer` may hold data worth handing to the decoder.
    pub(crate) is_readable: bool,
    /// Bytes read from the underlying I/O that have not been decoded yet.
    pub(crate) buffer: BytesMut,
    /// Set after a read or decode error; the next poll then yields `None`.
    pub(crate) has_errored: bool,
}
| |
| pub(crate) struct WriteFrame { |
| pub(crate) buffer: BytesMut, |
| pub(crate) backpressure_boundary: usize, |
| } |
| |
| #[derive(Default)] |
| pub(crate) struct RWFrames { |
| pub(crate) read: ReadFrame, |
| pub(crate) write: WriteFrame, |
| } |
| |
| impl Default for ReadFrame { |
| fn default() -> Self { |
| Self { |
| eof: false, |
| is_readable: false, |
| buffer: BytesMut::with_capacity(INITIAL_CAPACITY), |
| has_errored: false, |
| } |
| } |
| } |
| |
| impl Default for WriteFrame { |
| fn default() -> Self { |
| Self { |
| buffer: BytesMut::with_capacity(INITIAL_CAPACITY), |
| backpressure_boundary: INITIAL_CAPACITY, |
| } |
| } |
| } |
| |
| impl From<BytesMut> for ReadFrame { |
| fn from(mut buffer: BytesMut) -> Self { |
| let size = buffer.capacity(); |
| if size < INITIAL_CAPACITY { |
| buffer.reserve(INITIAL_CAPACITY - size); |
| } |
| |
| Self { |
| buffer, |
| is_readable: size > 0, |
| eof: false, |
| has_errored: false, |
| } |
| } |
| } |
| |
| impl From<BytesMut> for WriteFrame { |
| fn from(mut buffer: BytesMut) -> Self { |
| let size = buffer.capacity(); |
| if size < INITIAL_CAPACITY { |
| buffer.reserve(INITIAL_CAPACITY - size); |
| } |
| |
| Self { |
| buffer, |
| backpressure_boundary: INITIAL_CAPACITY, |
| } |
| } |
| } |
| |
| impl Borrow<ReadFrame> for RWFrames { |
| fn borrow(&self) -> &ReadFrame { |
| &self.read |
| } |
| } |
| impl BorrowMut<ReadFrame> for RWFrames { |
| fn borrow_mut(&mut self) -> &mut ReadFrame { |
| &mut self.read |
| } |
| } |
| impl Borrow<WriteFrame> for RWFrames { |
| fn borrow(&self) -> &WriteFrame { |
| &self.write |
| } |
| } |
| impl BorrowMut<WriteFrame> for RWFrames { |
| fn borrow_mut(&mut self) -> &mut WriteFrame { |
| &mut self.write |
| } |
| } |
/// Decodes frames from the wrapped `AsyncRead` with the codec's `Decoder`.
///
/// `R` is any state that can lend out a `ReadFrame` (e.g. `ReadFrame` or `RWFrames`).
impl<T, U, R> Stream for FramedImpl<T, U, R>
where
    T: AsyncRead,
    U: Decoder,
    R: BorrowMut<ReadFrame>,
{
    type Item = Result<U::Item, U::Error>;

    /// Polls for the next frame.
    ///
    /// Yields `Some(Ok(frame))` for each decoded frame and `Some(Err(_))` for a read or
    /// decode error. It yields `None` when the reader is at EOF and `decode_eof` has no
    /// further frames, and also on the first poll after an error. If the reader later
    /// produces more bytes, polling again can yield further frames.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        use crate::util::poll_read_buf;

        let mut pinned = self.project();
        let state: &mut ReadFrame = pinned.state.borrow_mut();
        // The following loop implements a state machine with each state corresponding
        // to a combination of the `is_readable`, `eof` and `has_errored` flags. States
        // are stored in `state`, so they persist across calls to `poll_next`, and most
        // state transitions occur with a return.
        //
        // The initial state is `reading`.
        //
        // | state   | eof   | is_readable | has_errored |
        // |---------|-------|-------------|-------------|
        // | reading | false | false       | false       |
        // | framing | false | true        | false       |
        // | pausing | true  | true        | false       |
        // | paused  | true  | false       | false       |
        // | errored | <any> | <any>       | true        |
        //
        // Transitions (what is returned to the caller in parentheses):
        //
        // reading --read n>0 bytes------------------> framing
        // reading --read 0 bytes--------------------> pausing
        // reading --read is pending-----------------> reading  (`Pending`)
        // framing --`decode` returns `Some`---------> framing  (the frame)
        // framing --`decode` returns `None`---------> reading
        // framing --`decode` returns Err------------> errored  (the error)
        // pausing --`decode_eof` returns `Ok(Some)`-> pausing  (the frame)
        // pausing --`decode_eof` returns `Ok(None)`-> paused   (`None`)
        // pausing --`decode_eof` returns Err--------> errored  (the error)
        // paused  --read n>0 bytes------------------> framing
        // paused  --read 0 bytes--------------------> paused   (`None`)
        // paused  --read is pending-----------------> paused   (`Pending`)
        // reading or paused --read returns Err------> errored  (the error)
        // errored ----------------------------------> clears `is_readable` and
        //     `has_errored` but leaves `eof` untouched (`None`): the result is
        //     `paused` if `eof` was set, otherwise `reading`.
        loop {
            // Return `None` if we have encountered an error from the decoder or the
            // underlying reader.
            // See: https://github.com/tokio-rs/tokio/issues/3976
            if state.has_errored {
                // errored -> paused (or -> reading when `eof` is not set)
                trace!("Returning None and setting paused");
                state.is_readable = false;
                state.has_errored = false;
                return Poll::Ready(None);
            }

            // Repeatedly call `decode` or `decode_eof` while the buffer is "readable",
            // i.e. it _might_ contain data consumable as a frame or closing frame.
            // Both signal that there is no such data by returning `None`.
            //
            // If `decode` couldn't read a frame and the upstream source has returned eof,
            // `decode_eof` will attempt to decode the remaining bytes as closing frames.
            //
            // If the underlying AsyncRead is resumable, we may continue after an EOF,
            // but must finish emitting all of its associated `decode_eof` frames.
            // Furthermore, we don't want to emit any `decode_eof` frames on retried
            // reads after an EOF unless we've actually read more data.
            if state.is_readable {
                // pausing or framing
                if state.eof {
                    // pausing
                    let frame = pinned.codec.decode_eof(&mut state.buffer).map_err(|err| {
                        trace!("Got an error, going to errored state");
                        state.has_errored = true;
                        err
                    })?;
                    if frame.is_none() {
                        state.is_readable = false; // prepare pausing -> paused
                    }
                    // implicit pausing -> pausing or pausing -> paused
                    return Poll::Ready(frame.map(Ok));
                }

                // framing
                trace!("attempting to decode a frame");

                if let Some(frame) = pinned.codec.decode(&mut state.buffer).map_err(|op| {
                    trace!("Got an error, going to errored state");
                    state.has_errored = true;
                    op
                })? {
                    trace!("frame decoded from buffer");
                    // implicit framing -> framing
                    return Poll::Ready(Some(Ok(frame)));
                }

                // framing -> reading
                state.is_readable = false;
            }
            // reading or paused
            // If we can't build a frame yet, try to read more data and try again.
            // Make sure we've got room for at least one byte to read to ensure
            // that we don't get a spurious 0 that looks like EOF.
            state.buffer.reserve(1);
            let bytect = match poll_read_buf(pinned.inner.as_mut(), cx, &mut state.buffer).map_err(
                |err| {
                    trace!("Got an error, going to errored state");
                    state.has_errored = true;
                    err
                },
            )? {
                Poll::Ready(ct) => ct,
                // implicit reading -> reading or implicit paused -> paused
                Poll::Pending => return Poll::Pending,
            };
            if bytect == 0 {
                if state.eof {
                    // We're already at an EOF, and since we've reached this path
                    // we're also not readable. This implies that we've already finished
                    // our `decode_eof` handling, so we can simply return `None`.
                    // implicit paused -> paused
                    return Poll::Ready(None);
                }
                // prepare reading -> pausing
                state.eof = true;
            } else {
                // prepare paused -> framing or noop reading -> framing
                state.eof = false;
            }

            // paused -> framing or reading -> framing or reading -> pausing
            state.is_readable = true;
        }
    }
}
| |
| impl<T, I, U, W> Sink<I> for FramedImpl<T, U, W> |
| where |
| T: AsyncWrite, |
| U: Encoder<I>, |
| U::Error: From<io::Error>, |
| W: BorrowMut<WriteFrame>, |
| { |
| type Error = U::Error; |
| |
| fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
| if self.state.borrow().buffer.len() >= self.state.borrow().backpressure_boundary { |
| self.as_mut().poll_flush(cx) |
| } else { |
| Poll::Ready(Ok(())) |
| } |
| } |
| |
| fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { |
| let pinned = self.project(); |
| pinned |
| .codec |
| .encode(item, &mut pinned.state.borrow_mut().buffer)?; |
| Ok(()) |
| } |
| |
| fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
| use crate::util::poll_write_buf; |
| trace!("flushing framed transport"); |
| let mut pinned = self.project(); |
| |
| while !pinned.state.borrow_mut().buffer.is_empty() { |
| let WriteFrame { buffer, .. } = pinned.state.borrow_mut(); |
| trace!(remaining = buffer.len(), "writing;"); |
| |
| let n = ready!(poll_write_buf(pinned.inner.as_mut(), cx, buffer))?; |
| |
| if n == 0 { |
| return Poll::Ready(Err(io::Error::new( |
| io::ErrorKind::WriteZero, |
| "failed to \ |
| write frame to transport", |
| ) |
| .into())); |
| } |
| } |
| |
| // Try flushing the underlying IO |
| ready!(pinned.inner.poll_flush(cx))?; |
| |
| trace!("framed transport flushed"); |
| Poll::Ready(Ok(())) |
| } |
| |
| fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
| ready!(self.as_mut().poll_flush(cx))?; |
| ready!(self.project().inner.poll_shutdown(cx))?; |
| |
| Poll::Ready(Ok(())) |
| } |
| } |