blob: e067ac4891cad3f23ccd7f703e133c22d4dfc30e [file] [log] [blame]
// Copyright 2019 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
use std::cmp;
use std::io;
use std::io::Write;
use std::iter::FromIterator;
use std::marker::PhantomData;
use std::mem::size_of;
use std::mem::MaybeUninit;
use std::ptr::copy_nonoverlapping;
use std::sync::Arc;
use anyhow::Context;
use base::FileReadWriteAtVolatile;
use base::FileReadWriteVolatile;
use base::VolatileSlice;
use cros_async::MemRegion;
use cros_async::MemRegionIter;
use data_model::Le16;
use data_model::Le32;
use data_model::Le64;
use disk::AsyncDisk;
use smallvec::SmallVec;
use vm_memory::GuestAddress;
use vm_memory::GuestMemory;
use zerocopy::AsBytes;
use zerocopy::FromBytes;
use zerocopy::FromZeroes;
use super::DescriptorChain;
use crate::virtio::SplitDescriptorChain;
// Cursor over the memory regions that back one direction (readable or writable) of a descriptor
// chain. It is shared by `Reader` and `Writer` and records how far the consumer has advanced.
struct DescriptorChainRegions {
// Memory regions of the chain, in the order in which they are consumed.
regions: SmallVec<[MemRegion; 2]>,
// Index of the current region in `regions`. Equal to `regions.len()` once every region has been
// consumed.
current_region_index: usize,
// Number of bytes consumed in the current region.
current_region_offset: usize,
// Total bytes consumed in the entire descriptor chain.
bytes_consumed: usize,
}
impl DescriptorChainRegions {
    /// Creates a cursor positioned at the first byte of `regions` with nothing consumed yet.
    fn new(regions: SmallVec<[MemRegion; 2]>) -> Self {
        DescriptorChainRegions {
            regions,
            current_region_index: 0,
            current_region_offset: 0,
            bytes_consumed: 0,
        }
    }

    /// Returns the number of bytes that have not been consumed yet.
    fn available_bytes(&self) -> usize {
        // This is guaranteed not to overflow because the total length of the chain is checked
        // during all creations of `DescriptorChain` (see `DescriptorChain::new()`).
        self.get_remaining_regions()
            .fold(0usize, |count, region| count + region.len)
    }

    /// Returns the total number of bytes consumed so far.
    fn bytes_consumed(&self) -> usize {
        self.bytes_consumed
    }

    /// Returns all the remaining buffers in the `DescriptorChain`. Calling this function does not
    /// consume any bytes from the `DescriptorChain`. Instead callers should use the `consume`
    /// method to advance the `DescriptorChain`. Multiple calls to `get` with no intervening calls
    /// to `consume` will return the same data.
    fn get_remaining_regions(&self) -> MemRegionIter {
        MemRegionIter::new(&self.regions[self.current_region_index..])
            .skip_bytes(self.current_region_offset)
    }

    /// Like `get_remaining_regions` but guarantees that the combined length of all the returned
    /// iovecs is not greater than `count`. The combined length of the returned iovecs may be less
    /// than `count` but will always be greater than 0 as long as there is still space left in the
    /// `DescriptorChain`.
    fn get_remaining_regions_with_count(&self, count: usize) -> MemRegionIter {
        MemRegionIter::new(&self.regions[self.current_region_index..])
            .skip_bytes(self.current_region_offset)
            .take_bytes(count)
    }

    /// Returns all the remaining buffers in the `DescriptorChain` as `VolatileSlice`s of the given
    /// `GuestMemory`. Calling this function does not consume any bytes from the `DescriptorChain`.
    /// Instead callers should use the `consume` method to advance the `DescriptorChain`. Multiple
    /// calls to `get` with no intervening calls to `consume` will return the same data.
    ///
    /// Regions that `mem` cannot resolve to a slice are silently left out of the result.
    fn get_remaining<'mem>(&self, mem: &'mem GuestMemory) -> SmallVec<[VolatileSlice<'mem>; 16]> {
        self.get_remaining_regions()
            .filter_map(|region| {
                mem.get_slice_at_addr(GuestAddress(region.offset), region.len)
                    .ok()
            })
            .collect()
    }

    /// Like 'get_remaining_regions_with_count' except convert the offsets to volatile slices in
    /// the 'GuestMemory' given by 'mem'.
    ///
    /// Regions that `mem` cannot resolve to a slice are silently left out of the result.
    fn get_remaining_with_count<'mem>(
        &self,
        mem: &'mem GuestMemory,
        count: usize,
    ) -> SmallVec<[VolatileSlice<'mem>; 16]> {
        self.get_remaining_regions_with_count(count)
            .filter_map(|region| {
                mem.get_slice_at_addr(GuestAddress(region.offset), region.len)
                    .ok()
            })
            .collect()
    }

    /// Consumes `count` bytes from the `DescriptorChain`. If `count` is larger than
    /// `self.available_bytes()` then all remaining bytes in the `DescriptorChain` will be consumed.
    fn consume(&mut self, mut count: usize) {
        while let Some(region) = self.regions.get(self.current_region_index) {
            let region_remaining = region.len - self.current_region_offset;
            if count < region_remaining {
                // The remaining count to consume is less than the remaining un-consumed length of
                // the current region. Adjust the region offset without advancing to the next region
                // and stop.
                self.current_region_offset += count;
                self.bytes_consumed += count;
                return;
            }
            // The current region has been exhausted. Advance to the next region.
            self.current_region_index += 1;
            self.current_region_offset = 0;
            self.bytes_consumed += region_remaining;
            count -= region_remaining;
        }
    }

    /// Splits the unconsumed part of the chain `offset` bytes after the current position.
    ///
    /// After the call `self` covers only the first `offset` unconsumed bytes and keeps its
    /// position and `bytes_consumed` count. The returned value covers everything after those
    /// bytes and starts with a `bytes_consumed` count of 0. If `offset` is larger than
    /// `self.available_bytes()`, `self` is left unchanged and the returned value is empty.
    fn split_at(&mut self, offset: usize) -> DescriptorChainRegions {
        let mut other = DescriptorChainRegions {
            regions: self.regions.clone(),
            current_region_index: self.current_region_index,
            current_region_offset: self.current_region_offset,
            bytes_consumed: self.bytes_consumed,
        };
        other.consume(offset);
        other.bytes_consumed = 0;

        let mut rem = offset;
        let mut end = self.current_region_index;
        // Bytes at the front of the inspected region that `self` has already consumed. Only the
        // current region can be partially consumed; every later region starts untouched. These
        // bytes must stay inside the trimmed region, otherwise `current_region_offset` would end
        // up pointing past the region's end.
        let mut skipped = self.current_region_offset;
        for region in &mut self.regions[self.current_region_index..] {
            // `consume` keeps `current_region_offset` within the current region, so this cannot
            // underflow.
            let unconsumed = region.len - skipped;
            if rem <= unconsumed {
                region.len = skipped + rem;
                break;
            }
            end += 1;
            rem -= unconsumed;
            skipped = 0;
        }
        // Drop every region after the one that contains the split point.
        self.regions.truncate(end + 1);
        other
    }
}
/// Provides high-level interface over the sequence of memory regions
/// defined by readable descriptors in the descriptor chain.
///
/// Note that virtio spec requires driver to place any device-writable
/// descriptors after any device-readable descriptors (2.6.4.2 in Virtio Spec v1.1).
/// Reader will skip iterating over descriptor chain when first writable
/// descriptor is encountered.
pub struct Reader {
// Guest memory against which the region offsets in `regions` are resolved.
mem: GuestMemory,
// Readable regions of the chain together with the current read position.
regions: DescriptorChainRegions,
}
/// An iterator over `FromBytes` objects on readable descriptors in the descriptor chain.
///
/// Created by `Reader::iter`. Every object it yields is read from, and consumed out of, the
/// borrowed `Reader`.
pub struct ReaderIterator<'a, T: FromBytes> {
// Reader that the objects are read from.
reader: &'a mut Reader,
// Ties the iterator to the item type `T` without storing a `T`.
phantom: PhantomData<T>,
}
impl<'a, T: FromBytes> Iterator for ReaderIterator<'a, T> {
    type Item = io::Result<T>;

    /// Reads the next `T` from the underlying `Reader`.
    ///
    /// Returns `None` once the reader has no bytes left. While bytes remain, the result of
    /// `Reader::read_obj` is passed through, so a trailing fragment shorter than `T` is reported
    /// as an error rather than ending the iteration.
    fn next(&mut self) -> Option<io::Result<T>> {
        match self.reader.available_bytes() {
            0 => None,
            _ => Some(self.reader.read_obj()),
        }
    }
}
impl Reader {
/// Construct a new Reader wrapper over `readable_regions`.
pub fn new_from_regions(
mem: &GuestMemory,
readable_regions: SmallVec<[MemRegion; 2]>,
) -> Reader {
Reader {
mem: mem.clone(),
regions: DescriptorChainRegions::new(readable_regions),
}
}
/// Reads an object from the descriptor chain buffer without consuming it.
pub fn peek_obj<T: FromBytes>(&self) -> io::Result<T> {
let mut obj = MaybeUninit::uninit();
// SAFETY: We pass a valid pointer and size of `obj`.
let copied = unsafe {
copy_regions_to_mut_ptr(
&self.mem,
self.get_remaining_regions(),
obj.as_mut_ptr() as *mut u8,
size_of::<T>(),
)?
};
if copied != size_of::<T>() {
return Err(io::Error::from(io::ErrorKind::UnexpectedEof));
}
// SAFETY: `FromBytes` guarantees any set of initialized bytes is a valid value for `T`, and
// we initialized all bytes in `obj` in the copy above.
Ok(unsafe { obj.assume_init() })
}
/// Reads and consumes an object from the descriptor chain buffer.
pub fn read_obj<T: FromBytes>(&mut self) -> io::Result<T> {
let obj = self.peek_obj::<T>()?;
self.consume(size_of::<T>());
Ok(obj)
}
/// Reads objects by consuming all the remaining data in the descriptor chain buffer and returns
/// them as a collection. Returns an error if the size of the remaining data is indivisible by
/// the size of an object of type `T`.
pub fn collect<C: FromIterator<io::Result<T>>, T: FromBytes>(&mut self) -> C {
self.iter().collect()
}
/// Creates an iterator for sequentially reading `FromBytes` objects from the `Reader`.
/// Unlike `collect`, this doesn't consume all the remaining data in the `Reader` and
/// doesn't require the objects to be stored in a separate collection.
pub fn iter<T: FromBytes>(&mut self) -> ReaderIterator<T> {
ReaderIterator {
reader: self,
phantom: PhantomData,
}
}
/// Reads data into a volatile slice up to the minimum of the slice's length or the number of
/// bytes remaining. Returns the number of bytes read.
pub fn read_to_volatile_slice(&mut self, slice: VolatileSlice) -> usize {
let mut read = 0usize;
let mut dst = slice;
for src in self.get_remaining() {
src.copy_to_volatile_slice(dst);
let copied = std::cmp::min(src.size(), dst.size());
read += copied;
dst = match dst.offset(copied) {
Ok(v) => v,
Err(_) => break, // The slice is fully consumed
};
}
self.regions.consume(read);
read
}
/// Reads data from the descriptor chain buffer and passes the `VolatileSlice`s to the callback
/// `cb`.
pub fn read_to_cb<C: FnOnce(&[VolatileSlice]) -> usize>(
&mut self,
cb: C,
count: usize,
) -> usize {
let iovs = self.regions.get_remaining_with_count(&self.mem, count);
let written = cb(&iovs[..]);
self.regions.consume(written);
written
}
/// Reads data from the descriptor chain buffer into a writable object.
/// Returns the number of bytes read from the descriptor chain buffer.
/// The number of bytes read can be less than `count` if there isn't
/// enough data in the descriptor chain buffer.
pub fn read_to<F: FileReadWriteVolatile>(
&mut self,
mut dst: F,
count: usize,
) -> io::Result<usize> {
let iovs = self.regions.get_remaining_with_count(&self.mem, count);
let written = dst.write_vectored_volatile(&iovs[..])?;
self.regions.consume(written);
Ok(written)
}
/// Reads data from the descriptor chain buffer into a File at offset `off`.
/// Returns the number of bytes read from the descriptor chain buffer.
/// The number of bytes read can be less than `count` if there isn't
/// enough data in the descriptor chain buffer.
pub fn read_to_at<F: FileReadWriteAtVolatile>(
&mut self,
mut dst: F,
count: usize,
off: u64,
) -> io::Result<usize> {
let iovs = self.regions.get_remaining_with_count(&self.mem, count);
let written = dst.write_vectored_at_volatile(&iovs[..], off)?;
self.regions.consume(written);
Ok(written)
}
/// Reads data from the descriptor chain similar to 'read_to' except reading 'count' or
/// returning an error if 'count' bytes can't be read.
pub fn read_exact_to<F: FileReadWriteVolatile>(
&mut self,
mut dst: F,
mut count: usize,
) -> io::Result<()> {
while count > 0 {
match self.read_to(&mut dst, count) {
Ok(0) => {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"failed to fill whole buffer",
))
}
Ok(n) => count -= n,
Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
Err(e) => return Err(e),
}
}
Ok(())
}
/// Reads data from the descriptor chain similar to 'read_to_at' except reading 'count' or
/// returning an error if 'count' bytes can't be read.
pub fn read_exact_to_at<F: FileReadWriteAtVolatile>(
&mut self,
mut dst: F,
mut count: usize,
mut off: u64,
) -> io::Result<()> {
while count > 0 {
match self.read_to_at(&mut dst, count, off) {
Ok(0) => {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"failed to fill whole buffer",
))
}
Ok(n) => {
count -= n;
off += n as u64;
}
Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
Err(e) => return Err(e),
}
}
Ok(())
}
/// Reads data from the descriptor chain buffer into an `AsyncDisk` at offset `off`.
/// Returns the number of bytes read from the descriptor chain buffer.
/// The number of bytes read can be less than `count` if there isn't
/// enough data in the descriptor chain buffer.
pub async fn read_to_at_fut<F: AsyncDisk + ?Sized>(
&mut self,
dst: &F,
count: usize,
off: u64,
) -> disk::Result<usize> {
let written = dst
.write_from_mem(
off,
Arc::new(self.mem.clone()),
self.regions.get_remaining_regions_with_count(count),
)
.await?;
self.regions.consume(written);
Ok(written)
}
/// Reads exactly `count` bytes from the chain to the disk asynchronously or returns an error if
/// not enough data can be read.
pub async fn read_exact_to_at_fut<F: AsyncDisk + ?Sized>(
&mut self,
dst: &F,
mut count: usize,
mut off: u64,
) -> disk::Result<()> {
while count > 0 {
let nread = self.read_to_at_fut(dst, count, off).await?;
if nread == 0 {
return Err(disk::Error::ReadingData(io::Error::new(
io::ErrorKind::UnexpectedEof,
"failed to write whole buffer",
)));
}
count -= nread;
off += nread as u64;
}
Ok(())
}
/// Returns number of bytes available for reading. May return an error if the combined
/// lengths of all the buffers in the DescriptorChain would cause an integer overflow.
pub fn available_bytes(&self) -> usize {
self.regions.available_bytes()
}
/// Returns number of bytes already read from the descriptor chain buffer.
pub fn bytes_read(&self) -> usize {
self.regions.bytes_consumed()
}
pub fn get_remaining_regions(&self) -> MemRegionIter {
self.regions.get_remaining_regions()
}
/// Returns a `&[VolatileSlice]` that represents all the remaining data in this `Reader`.
/// Calling this method does not actually consume any data from the `Reader` and callers should
/// call `consume` to advance the `Reader`.
pub fn get_remaining(&self) -> SmallVec<[VolatileSlice; 16]> {
self.regions.get_remaining(&self.mem)
}
/// Consumes `amt` bytes from the underlying descriptor chain. If `amt` is larger than the
/// remaining data left in this `Reader`, then all remaining data will be consumed.
pub fn consume(&mut self, amt: usize) {
self.regions.consume(amt)
}
/// Splits this `Reader` into two at the given offset in the `DescriptorChain` buffer. After the
/// split, `self` will be able to read up to `offset` bytes while the returned `Reader` can read
/// up to `available_bytes() - offset` bytes. If `offset > self.available_bytes()`, then the
/// returned `Reader` will not be able to read any bytes.
pub fn split_at(&mut self, offset: usize) -> Reader {
Reader {
mem: self.mem.clone(),
regions: self.regions.split_at(offset),
}
}
}
/// Copies bytes from the guest memory regions in `src` into `dst`, stopping after `size` bytes
/// or when `src` runs out, whichever comes first.
///
/// Returns the total number of bytes copied. Fails with `InvalidData` if a region cannot be
/// resolved in `mem`.
///
/// # Safety
///
/// The caller must ensure that it is safe to write `size` bytes of data into `dst`.
///
/// After the function returns, only the number of bytes given by the return value (possibly
/// fewer than `size`) may be assumed to be initialized. This function does not initialize
/// anything beyond that point.
unsafe fn copy_regions_to_mut_ptr(
    mem: &GuestMemory,
    src: MemRegionIter,
    dst: *mut u8,
    size: usize,
) -> io::Result<usize> {
    // Number of bytes that may still be written into `dst`.
    let mut left = size;
    for region in src {
        if left == 0 {
            break;
        }
        let chunk = cmp::min(left, region.len);
        let vslice = mem
            .get_slice_at_addr(GuestAddress(region.offset), chunk)
            .map_err(|_e| io::Error::from(io::ErrorKind::InvalidData))?;
        let done = size - left;
        // SAFETY: `get_slice_at_addr()` verified that the region points to valid memory, and
        // `chunk` never exceeds `left`, so at most `size` bytes are written into `dst` in total.
        unsafe {
            copy_nonoverlapping(vslice.as_ptr(), dst.add(done), chunk);
        }
        left -= chunk;
    }
    Ok(size - left)
}
impl io::Read for Reader {
    /// Fills `buf` with as many of the remaining bytes as fit and consumes them.
    ///
    /// Returns the number of bytes copied.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let capacity = buf.len();
        let dst = buf.as_mut_ptr();
        let src = self.regions.get_remaining_regions();
        // SAFETY: `dst` and `capacity` both come from `buf`, so `capacity` bytes are writable.
        let total = unsafe { copy_regions_to_mut_ptr(&self.mem, src, dst, capacity)? };
        self.regions.consume(total);
        Ok(total)
    }
}
/// Provides high-level interface over the sequence of memory regions
/// defined by writable descriptors in the descriptor chain.
///
/// Note that virtio spec requires driver to place any device-writable
/// descriptors after any device-readable descriptors (2.6.4.2 in Virtio Spec v1.1).
/// Writer will start iterating the descriptors from the first writable one and will
/// assume that all following descriptors are writable.
pub struct Writer {
// Guest memory against which the region offsets in `regions` are resolved.
mem: GuestMemory,
// Writable regions of the chain together with the current write position.
regions: DescriptorChainRegions,
}
impl Writer {
/// Construct a new Writer wrapper over `writable_regions`.
pub fn new_from_regions(
mem: &GuestMemory,
writable_regions: SmallVec<[MemRegion; 2]>,
) -> Writer {
Writer {
mem: mem.clone(),
regions: DescriptorChainRegions::new(writable_regions),
}
}
/// Writes an object to the descriptor chain buffer.
pub fn write_obj<T: AsBytes>(&mut self, val: T) -> io::Result<()> {
self.write_all(val.as_bytes())
}
/// Writes all objects produced by `iter` into the descriptor chain buffer. Unlike `consume`,
/// this doesn't require the values to be stored in an intermediate collection first. It also
/// allows callers to choose which elements in a collection to write, for example by using the
/// `filter` or `take` methods of the `Iterator` trait.
pub fn write_iter<T: AsBytes, I: Iterator<Item = T>>(&mut self, mut iter: I) -> io::Result<()> {
iter.try_for_each(|v| self.write_obj(v))
}
/// Writes a collection of objects into the descriptor chain buffer.
pub fn consume<T: AsBytes, C: IntoIterator<Item = T>>(&mut self, vals: C) -> io::Result<()> {
self.write_iter(vals.into_iter())
}
/// Returns number of bytes available for writing. May return an error if the combined
/// lengths of all the buffers in the DescriptorChain would cause an overflow.
pub fn available_bytes(&self) -> usize {
self.regions.available_bytes()
}
/// Reads data into a volatile slice up to the minimum of the slice's length or the number of
/// bytes remaining. Returns the number of bytes read.
pub fn write_from_volatile_slice(&mut self, slice: VolatileSlice) -> usize {
let mut written = 0usize;
let mut src = slice;
for dst in self.get_remaining() {
src.copy_to_volatile_slice(dst);
let copied = std::cmp::min(src.size(), dst.size());
written += copied;
src = match src.offset(copied) {
Ok(v) => v,
Err(_) => break, // The slice is fully consumed
};
}
self.regions.consume(written);
written
}
/// Writes data to the descriptor chain buffer from a readable object.
/// Returns the number of bytes written to the descriptor chain buffer.
/// The number of bytes written can be less than `count` if
/// there isn't enough data in the descriptor chain buffer.
pub fn write_from<F: FileReadWriteVolatile>(
&mut self,
mut src: F,
count: usize,
) -> io::Result<usize> {
let iovs = self.regions.get_remaining_with_count(&self.mem, count);
let read = src.read_vectored_volatile(&iovs[..])?;
self.regions.consume(read);
Ok(read)
}
/// Writes data to the descriptor chain buffer from a File at offset `off`.
/// Returns the number of bytes written to the descriptor chain buffer.
/// The number of bytes written can be less than `count` if
/// there isn't enough data in the descriptor chain buffer.
pub fn write_from_at<F: FileReadWriteAtVolatile>(
&mut self,
mut src: F,
count: usize,
off: u64,
) -> io::Result<usize> {
let iovs = self.regions.get_remaining_with_count(&self.mem, count);
let read = src.read_vectored_at_volatile(&iovs[..], off)?;
self.regions.consume(read);
Ok(read)
}
pub fn write_all_from<F: FileReadWriteVolatile>(
&mut self,
mut src: F,
mut count: usize,
) -> io::Result<()> {
while count > 0 {
match self.write_from(&mut src, count) {
Ok(0) => {
return Err(io::Error::new(
io::ErrorKind::WriteZero,
"failed to write whole buffer",
))
}
Ok(n) => count -= n,
Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
Err(e) => return Err(e),
}
}
Ok(())
}
pub fn write_all_from_at<F: FileReadWriteAtVolatile>(
&mut self,
mut src: F,
mut count: usize,
mut off: u64,
) -> io::Result<()> {
while count > 0 {
match self.write_from_at(&mut src, count, off) {
Ok(0) => {
return Err(io::Error::new(
io::ErrorKind::WriteZero,
"failed to write whole buffer",
))
}
Ok(n) => {
count -= n;
off += n as u64;
}
Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
Err(e) => return Err(e),
}
}
Ok(())
}
/// Writes data to the descriptor chain buffer from an `AsyncDisk` at offset `off`.
/// Returns the number of bytes written to the descriptor chain buffer.
/// The number of bytes written can be less than `count` if
/// there isn't enough data in the descriptor chain buffer.
pub async fn write_from_at_fut<F: AsyncDisk + ?Sized>(
&mut self,
src: &F,
count: usize,
off: u64,
) -> disk::Result<usize> {
let read = src
.read_to_mem(
off,
Arc::new(self.mem.clone()),
self.regions.get_remaining_regions_with_count(count),
)
.await?;
self.regions.consume(read);
Ok(read)
}
pub async fn write_all_from_at_fut<F: AsyncDisk + ?Sized>(
&mut self,
src: &F,
mut count: usize,
mut off: u64,
) -> disk::Result<()> {
while count > 0 {
let nwritten = self.write_from_at_fut(src, count, off).await?;
if nwritten == 0 {
return Err(disk::Error::WritingData(io::Error::new(
io::ErrorKind::UnexpectedEof,
"failed to write whole buffer",
)));
}
count -= nwritten;
off += nwritten as u64;
}
Ok(())
}
/// Returns number of bytes already written to the descriptor chain buffer.
pub fn bytes_written(&self) -> usize {
self.regions.bytes_consumed()
}
pub fn get_remaining_regions(&self) -> MemRegionIter {
self.regions.get_remaining_regions()
}
/// Returns a `&[VolatileSlice]` that represents all the remaining data in this `Writer`.
/// Calling this method does not actually advance the current position of the `Writer` in the
/// buffer and callers should call `consume_bytes` to advance the `Writer`. Not calling
/// `consume_bytes` with the amount of data copied into the returned `VolatileSlice`s will
/// result in that that data being overwritten the next time data is written into the `Writer`.
pub fn get_remaining(&self) -> SmallVec<[VolatileSlice; 16]> {
self.regions.get_remaining(&self.mem)
}
/// Consumes `amt` bytes from the underlying descriptor chain. If `amt` is larger than the
/// remaining data left in this `Reader`, then all remaining data will be consumed.
pub fn consume_bytes(&mut self, amt: usize) {
self.regions.consume(amt)
}
/// Splits this `Writer` into two at the given offset in the `DescriptorChain` buffer. After the
/// split, `self` will be able to write up to `offset` bytes while the returned `Writer` can
/// write up to `available_bytes() - offset` bytes. If `offset > self.available_bytes()`, then
/// the returned `Writer` will not be able to write any bytes.
pub fn split_at(&mut self, offset: usize) -> Writer {
Writer {
mem: self.mem.clone(),
regions: self.regions.split_at(offset),
}
}
}
impl io::Write for Writer {
    /// Copies as much of `buf` as fits into the remaining writable space and consumes it.
    ///
    /// Returns the number of bytes copied.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let mut total = 0;
        for dst in self.regions.get_remaining(&self.mem) {
            let pending = &buf[total..];
            if pending.is_empty() {
                break;
            }
            let count = cmp::min(pending.len(), dst.size());
            // SAFETY:
            // Safe because `dst` was obtained from `get_remaining`, which only returns slices
            // that resolved in guest memory, and `count` exceeds neither `pending` nor `dst`.
            unsafe {
                copy_nonoverlapping(pending.as_ptr(), dst.as_mut_ptr(), count);
            }
            total += count;
        }
        self.regions.consume(total);
        Ok(total)
    }

    /// Does nothing: writes land directly in the buffer, so there is nothing to flush.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}
// Descriptor flag set by `create_descriptor_chain` on every descriptor except the last one: the
// chain continues with the descriptor named by the `next` field.
const VIRTQ_DESC_F_NEXT: u16 = 0x1;
// Descriptor flag set by `create_descriptor_chain` for `DescriptorType::Writable` buffers.
const VIRTQ_DESC_F_WRITE: u16 = 0x2;
/// Direction of a buffer passed to `create_descriptor_chain`.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum DescriptorType {
// Descriptor is created without `VIRTQ_DESC_F_WRITE`.
Readable,
// Descriptor is created with `VIRTQ_DESC_F_WRITE` set.
Writable,
}
// One descriptor table entry as `create_descriptor_chain` writes it into guest memory. All
// fields are little-endian and the layout is fixed by `repr(C)`.
#[derive(Copy, Clone, Debug, FromZeroes, FromBytes, AsBytes)]
#[repr(C)]
struct virtq_desc {
// Guest address of the buffer.
addr: Le64,
// Length of the buffer in bytes.
len: Le32,
// Combination of the `VIRTQ_DESC_F_*` flags.
flags: Le16,
// Index of the following descriptor in the table.
next: Le16,
}
/// Test utility function to create a descriptor chain in guest memory.
///
/// One descriptor per entry of `descriptors` is written into the table that starts at
/// `descriptor_array_addr`. The buffer of the first descriptor starts at `buffers_start_addr`;
/// each following buffer starts `spaces_between_regions` bytes after the end of the previous
/// one. Every descriptor except the last is linked to its successor.
///
/// # Errors
///
/// Returns an error if a buffer or descriptor address overflows, if a descriptor cannot be
/// written to `memory`, or if building the `DescriptorChain` from the table fails.
pub fn create_descriptor_chain(
    memory: &GuestMemory,
    descriptor_array_addr: GuestAddress,
    mut buffers_start_addr: GuestAddress,
    descriptors: Vec<(DescriptorType, u32)>,
    spaces_between_regions: u32,
) -> anyhow::Result<DescriptorChain> {
    let descriptors_len = descriptors.len();
    for (index, (type_, size)) in descriptors.into_iter().enumerate() {
        let mut flags = 0;
        if let DescriptorType::Writable = type_ {
            flags |= VIRTQ_DESC_F_WRITE;
        }
        if index + 1 < descriptors_len {
            flags |= VIRTQ_DESC_F_NEXT;
        }

        let index = index as u16;
        let desc = virtq_desc {
            addr: buffers_start_addr.offset().into(),
            len: size.into(),
            flags: flags.into(),
            next: (index + 1).into(),
        };

        // Widen before adding so that `size + spaces_between_regions` cannot overflow `u32`.
        let offset = u64::from(size) + u64::from(spaces_between_regions);
        buffers_start_addr = buffers_start_addr
            .checked_add(offset)
            .context("Invalid buffers_start_addr")?;

        let desc_addr = descriptor_array_addr
            .checked_add(u64::from(index) * std::mem::size_of::<virtq_desc>() as u64)
            .context("Invalid descriptor_array_addr")?;
        // Report a failed write instead of handing back a chain built from stale memory.
        memory
            .write_obj_at_addr(desc, desc_addr)
            .context("failed to write descriptor to guest memory")?;
    }

    // NOTE(review): 0x100 is presumably the queue size and the zeros the head/chain index;
    // confirm against `SplitDescriptorChain::new` and `DescriptorChain::new`.
    let chain = SplitDescriptorChain::new(memory, descriptor_array_addr, 0x100, 0);
    DescriptorChain::new(chain, memory, 0)
}
#[cfg(test)]
mod tests {
use std::fs::File;
use std::io::Read;
use cros_async::Executor;
use tempfile::tempfile;
use tempfile::NamedTempFile;
use super::*;
#[test]
fn reader_test_simple_chain() {
use DescriptorType::*;
let memory_start_addr = GuestAddress(0x0);
let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
let mut chain = create_descriptor_chain(
&memory,
GuestAddress(0x0),
GuestAddress(0x100),
vec![
(Readable, 8),
(Readable, 16),
(Readable, 18),
(Readable, 64),
],
0,
)
.expect("create_descriptor_chain failed");
let reader = &mut chain.reader;
assert_eq!(reader.available_bytes(), 106);
assert_eq!(reader.bytes_read(), 0);
let mut buffer = [0u8; 64];
reader
.read_exact(&mut buffer)
.expect("read_exact should not fail here");
assert_eq!(reader.available_bytes(), 42);
assert_eq!(reader.bytes_read(), 64);
match reader.read(&mut buffer) {
Err(_) => panic!("read should not fail here"),
Ok(length) => assert_eq!(length, 42),
}
assert_eq!(reader.available_bytes(), 0);
assert_eq!(reader.bytes_read(), 106);
}
#[test]
fn writer_test_simple_chain() {
use DescriptorType::*;
let memory_start_addr = GuestAddress(0x0);
let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
let mut chain = create_descriptor_chain(
&memory,
GuestAddress(0x0),
GuestAddress(0x100),
vec![
(Writable, 8),
(Writable, 16),
(Writable, 18),
(Writable, 64),
],
0,
)
.expect("create_descriptor_chain failed");
let writer = &mut chain.writer;
assert_eq!(writer.available_bytes(), 106);
assert_eq!(writer.bytes_written(), 0);
let buffer = [0; 64];
writer
.write_all(&buffer)
.expect("write_all should not fail here");
assert_eq!(writer.available_bytes(), 42);
assert_eq!(writer.bytes_written(), 64);
match writer.write(&buffer) {
Err(_) => panic!("write should not fail here"),
Ok(length) => assert_eq!(length, 42),
}
assert_eq!(writer.available_bytes(), 0);
assert_eq!(writer.bytes_written(), 106);
}
#[test]
fn reader_test_incompatible_chain() {
use DescriptorType::*;
let memory_start_addr = GuestAddress(0x0);
let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
let mut chain = create_descriptor_chain(
&memory,
GuestAddress(0x0),
GuestAddress(0x100),
vec![(Writable, 8)],
0,
)
.expect("create_descriptor_chain failed");
let reader = &mut chain.reader;
assert_eq!(reader.available_bytes(), 0);
assert_eq!(reader.bytes_read(), 0);
assert!(reader.read_obj::<u8>().is_err());
assert_eq!(reader.available_bytes(), 0);
assert_eq!(reader.bytes_read(), 0);
}
#[test]
fn writer_test_incompatible_chain() {
use DescriptorType::*;
let memory_start_addr = GuestAddress(0x0);
let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
let mut chain = create_descriptor_chain(
&memory,
GuestAddress(0x0),
GuestAddress(0x100),
vec![(Readable, 8)],
0,
)
.expect("create_descriptor_chain failed");
let writer = &mut chain.writer;
assert_eq!(writer.available_bytes(), 0);
assert_eq!(writer.bytes_written(), 0);
assert!(writer.write_obj(0u8).is_err());
assert_eq!(writer.available_bytes(), 0);
assert_eq!(writer.bytes_written(), 0);
}
#[test]
fn reader_failing_io() {
use DescriptorType::*;
let memory_start_addr = GuestAddress(0x0);
let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
let mut chain = create_descriptor_chain(
&memory,
GuestAddress(0x0),
GuestAddress(0x100),
vec![(Readable, 256), (Readable, 256)],
0,
)
.expect("create_descriptor_chain failed");
let reader = &mut chain.reader;
// Open a file in read-only mode so writes to it to trigger an I/O error.
let device_file = if cfg!(windows) { "NUL" } else { "/dev/zero" };
let mut ro_file = File::open(device_file).expect("failed to open device file");
reader
.read_exact_to(&mut ro_file, 512)
.expect_err("successfully read more bytes than SharedMemory size");
// The write above should have failed entirely, so we end up not writing any bytes at all.
assert_eq!(reader.available_bytes(), 512);
assert_eq!(reader.bytes_read(), 0);
}
#[test]
fn writer_failing_io() {
use DescriptorType::*;
let memory_start_addr = GuestAddress(0x0);
let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
let mut chain = create_descriptor_chain(
&memory,
GuestAddress(0x0),
GuestAddress(0x100),
vec![(Writable, 256), (Writable, 256)],
0,
)
.expect("create_descriptor_chain failed");
let writer = &mut chain.writer;
let mut file = tempfile().unwrap();
file.set_len(384).unwrap();
writer
.write_all_from(&mut file, 512)
.expect_err("successfully wrote more bytes than in SharedMemory");
assert_eq!(writer.available_bytes(), 128);
assert_eq!(writer.bytes_written(), 384);
}
#[test]
fn reader_writer_shared_chain() {
use DescriptorType::*;
let memory_start_addr = GuestAddress(0x0);
let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
let mut chain = create_descriptor_chain(
&memory,
GuestAddress(0x0),
GuestAddress(0x100),
vec![
(Readable, 16),
(Readable, 16),
(Readable, 96),
(Writable, 64),
(Writable, 1),
(Writable, 3),
],
0,
)
.expect("create_descriptor_chain failed");
let reader = &mut chain.reader;
let writer = &mut chain.writer;
assert_eq!(reader.bytes_read(), 0);
assert_eq!(writer.bytes_written(), 0);
let mut buffer = Vec::with_capacity(200);
assert_eq!(
reader
.read_to_end(&mut buffer)
.expect("read should not fail here"),
128
);
// The writable descriptors are only 68 bytes long.
writer
.write_all(&buffer[..68])
.expect("write should not fail here");
assert_eq!(reader.available_bytes(), 0);
assert_eq!(reader.bytes_read(), 128);
assert_eq!(writer.available_bytes(), 0);
assert_eq!(writer.bytes_written(), 68);
}
#[test]
fn reader_writer_shattered_object() {
use DescriptorType::*;
let memory_start_addr = GuestAddress(0x0);
let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
let secret: Le32 = 0x12345678.into();
// Create a descriptor chain with memory regions that are properly separated.
let mut chain_writer = create_descriptor_chain(
&memory,
GuestAddress(0x0),
GuestAddress(0x100),
vec![(Writable, 1), (Writable, 1), (Writable, 1), (Writable, 1)],
123,
)
.expect("create_descriptor_chain failed");
let writer = &mut chain_writer.writer;
writer
.write_obj(secret)
.expect("write_obj should not fail here");
// Now create new descriptor chain pointing to the same memory and try to read it.
let mut chain_reader = create_descriptor_chain(
&memory,
GuestAddress(0x0),
GuestAddress(0x100),
vec![(Readable, 1), (Readable, 1), (Readable, 1), (Readable, 1)],
123,
)
.expect("create_descriptor_chain failed");
let reader = &mut chain_reader.reader;
match reader.read_obj::<Le32>() {
Err(_) => panic!("read_obj should not fail here"),
Ok(read_secret) => assert_eq!(read_secret, secret),
}
}
#[test]
fn reader_unexpected_eof() {
use DescriptorType::*;
let memory_start_addr = GuestAddress(0x0);
let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
let mut chain = create_descriptor_chain(
&memory,
GuestAddress(0x0),
GuestAddress(0x100),
vec![(Readable, 256), (Readable, 256)],
0,
)
.expect("create_descriptor_chain failed");
let reader = &mut chain.reader;
let mut buf = vec![0; 1024];
assert_eq!(
reader
.read_exact(&mut buf[..])
.expect_err("read more bytes than available")
.kind(),
io::ErrorKind::UnexpectedEof
);
}
#[test]
fn split_border() {
use DescriptorType::*;
let memory_start_addr = GuestAddress(0x0);
let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
let mut chain = create_descriptor_chain(
&memory,
GuestAddress(0x0),
GuestAddress(0x100),
vec![
(Readable, 16),
(Readable, 16),
(Readable, 96),
(Writable, 64),
(Writable, 1),
(Writable, 3),
],
0,
)
.expect("create_descriptor_chain failed");
let reader = &mut chain.reader;
let other = reader.split_at(32);
assert_eq!(reader.available_bytes(), 32);
assert_eq!(other.available_bytes(), 96);
}
#[test]
fn split_middle() {
    use DescriptorType::*;

    let guest_mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
    // 16 + 16 + 96 = 128 readable bytes, followed by writable descriptors.
    let descriptors = vec![
        (Readable, 16),
        (Readable, 16),
        (Readable, 96),
        (Writable, 64),
        (Writable, 1),
        (Writable, 3),
    ];
    let mut chain = create_descriptor_chain(
        &guest_mem,
        GuestAddress(0x0),
        GuestAddress(0x100),
        descriptors,
        0,
    )
    .expect("create_descriptor_chain failed");

    // Offset 24 lies in the middle of the second readable descriptor, so the
    // split has to cut that descriptor in two (24 + 104 = 128).
    let tail = chain.reader.split_at(24);
    assert_eq!(chain.reader.available_bytes(), 24);
    assert_eq!(tail.available_bytes(), 104);
}
#[test]
fn split_end() {
    use DescriptorType::*;

    let guest_mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
    // 16 + 16 + 96 = 128 readable bytes, followed by writable descriptors.
    let descriptors = vec![
        (Readable, 16),
        (Readable, 16),
        (Readable, 96),
        (Writable, 64),
        (Writable, 1),
        (Writable, 3),
    ];
    let mut chain = create_descriptor_chain(
        &guest_mem,
        GuestAddress(0x0),
        GuestAddress(0x100),
        descriptors,
        0,
    )
    .expect("create_descriptor_chain failed");

    // Splitting at the very end of the readable area keeps everything in the
    // original reader and leaves the returned reader empty.
    let tail = chain.reader.split_at(128);
    assert_eq!(chain.reader.available_bytes(), 128);
    assert_eq!(tail.available_bytes(), 0);
}
#[test]
fn split_beginning() {
    use DescriptorType::*;

    let guest_mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
    // 16 + 16 + 96 = 128 readable bytes, followed by writable descriptors.
    let descriptors = vec![
        (Readable, 16),
        (Readable, 16),
        (Readable, 96),
        (Writable, 64),
        (Writable, 1),
        (Writable, 3),
    ];
    let mut chain = create_descriptor_chain(
        &guest_mem,
        GuestAddress(0x0),
        GuestAddress(0x100),
        descriptors,
        0,
    )
    .expect("create_descriptor_chain failed");

    // Splitting at offset 0 empties the original reader and hands all 128
    // readable bytes to the returned one.
    let tail = chain.reader.split_at(0);
    assert_eq!(chain.reader.available_bytes(), 0);
    assert_eq!(tail.available_bytes(), 128);
}
#[test]
fn split_outofbounds() {
    use DescriptorType::*;

    let guest_mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
    // 16 + 16 + 96 = 128 readable bytes, followed by writable descriptors.
    let descriptors = vec![
        (Readable, 16),
        (Readable, 16),
        (Readable, 96),
        (Writable, 64),
        (Writable, 1),
        (Writable, 3),
    ];
    let mut chain = create_descriptor_chain(
        &guest_mem,
        GuestAddress(0x0),
        GuestAddress(0x100),
        descriptors,
        0,
    )
    .expect("create_descriptor_chain failed");

    // 256 is past the 128 readable bytes; the reader returned for such a split
    // must not expose any data.
    let past_end = chain.reader.split_at(256);
    let leftover = past_end.available_bytes();
    assert_eq!(
        leftover, 0,
        "Reader returned from out-of-bounds split still has available bytes"
    );
}
#[test]
fn read_full() {
    use DescriptorType::*;

    let guest_mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
    // Three 16-byte readable descriptors: 48 bytes in total.
    let descriptors = vec![(Readable, 16), (Readable, 16), (Readable, 16)];
    let mut chain = create_descriptor_chain(
        &guest_mem,
        GuestAddress(0x0),
        GuestAddress(0x100),
        descriptors,
        0,
    )
    .expect("create_descriptor_chain failed");

    // The destination (64 bytes) is larger than the chain; a single `read`
    // is expected to return all 48 bytes, spanning every descriptor.
    let mut dest = [0u8; 64];
    let bytes_read = chain
        .reader
        .read(&mut dest[..])
        .expect("failed to read to buffer");
    assert_eq!(bytes_read, 48);
}
#[test]
fn write_full() {
    use DescriptorType::*;

    let guest_mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
    // Three 16-byte writable descriptors: 48 bytes of capacity in total.
    let descriptors = vec![(Writable, 16), (Writable, 16), (Writable, 16)];
    let mut chain = create_descriptor_chain(
        &guest_mem,
        GuestAddress(0x0),
        GuestAddress(0x100),
        descriptors,
        0,
    )
    .expect("create_descriptor_chain failed");

    // The source (64 bytes) is larger than the chain; a single `write` is
    // expected to accept exactly the 48 bytes that fit.
    let payload = [0xdeu8; 64];
    let bytes_written = chain
        .writer
        .write(&payload[..])
        .expect("failed to write from buffer");
    assert_eq!(bytes_written, 48);
}
#[test]
fn consume_collect() {
    use DescriptorType::*;

    let guest_mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();

    // Three 8-byte values: exactly the 24 bytes each chain below provides.
    let vs: Vec<Le64> = vec![
        Le64::from(0x0101010101010101),
        Le64::from(0x0202020202020202),
        Le64::from(0x0303030303030303),
    ];

    // Store the values in guest memory through a writable chain.
    let mut write_chain = create_descriptor_chain(
        &guest_mem,
        GuestAddress(0x0),
        GuestAddress(0x100),
        vec![(Writable, 24)],
        0,
    )
    .expect("create_descriptor_chain failed");
    write_chain
        .writer
        .consume(vs.clone())
        .expect("failed to consume() a vector");

    // Read them back through a readable chain built with the same addresses.
    let mut read_chain = create_descriptor_chain(
        &guest_mem,
        GuestAddress(0x0),
        GuestAddress(0x100),
        vec![(Readable, 24)],
        0,
    )
    .expect("create_descriptor_chain failed");
    let vs_read = read_chain
        .reader
        .collect::<io::Result<Vec<Le64>>, _>()
        .expect("failed to collect() values");

    assert_eq!(vs, vs_read);
}
#[test]
fn get_remaining_region_with_count() {
    use DescriptorType::*;

    let memory_start_addr = GuestAddress(0x0);
    let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
    // 16 + 16 + 96 = 128 readable bytes, followed by writable descriptors.
    let chain = create_descriptor_chain(
        &memory,
        GuestAddress(0x0),
        GuestAddress(0x100),
        vec![
            (Readable, 16),
            (Readable, 16),
            (Readable, 96),
            (Writable, 64),
            (Writable, 1),
            (Writable, 3),
        ],
        0,
    )
    .expect("create_descriptor_chain failed");

    // Take the reader apart to exercise its region tracker directly.
    let Reader {
        mem: _,
        mut regions,
    } = chain.reader;

    // An effectively unlimited count yields every readable byte.
    let drain: usize = regions
        .get_remaining_regions_with_count(usize::MAX)
        .map(|region| region.len)
        .sum();
    assert_eq!(drain, 128);

    // A count that ends on a descriptor border: only the bounds are asserted,
    // not the exact number of bytes returned.
    let exact: usize = regions
        .get_remaining_regions_with_count(32)
        .map(|region| region.len)
        .sum();
    assert!(exact > 0);
    assert!(exact <= 32);

    // A count that ends in the middle of the second descriptor.
    let split: usize = regions
        .get_remaining_regions_with_count(24)
        .map(|region| region.len)
        .sum();
    assert!(split > 0);
    assert!(split <= 24);

    // After consuming 64 bytes, a small count must still be honoured.
    regions.consume(64);
    let first: usize = regions
        .get_remaining_regions_with_count(8)
        .map(|region| region.len)
        .sum();
    assert!(first > 0);
    assert!(first <= 8);
}
#[test]
fn get_remaining_with_count() {
    use DescriptorType::*;

    let memory_start_addr = GuestAddress(0x0);
    let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
    // 16 + 16 + 96 = 128 readable bytes, followed by writable descriptors.
    let chain = create_descriptor_chain(
        &memory,
        GuestAddress(0x0),
        GuestAddress(0x100),
        vec![
            (Readable, 16),
            (Readable, 16),
            (Readable, 96),
            (Writable, 64),
            (Writable, 1),
            (Writable, 3),
        ],
        0,
    )
    .expect("create_descriptor_chain failed");

    // Take the reader apart to exercise its region tracker directly.
    let Reader {
        mem: _,
        mut regions,
    } = chain.reader;

    // An effectively unlimited count yields every readable byte.
    let drain: usize = regions
        .get_remaining_with_count(&memory, usize::MAX)
        .iter()
        .map(|iov| iov.size())
        .sum();
    assert_eq!(drain, 128);

    // A count that ends on a descriptor border: only the bounds are asserted,
    // not the exact number of bytes returned.
    let exact: usize = regions
        .get_remaining_with_count(&memory, 32)
        .iter()
        .map(|iov| iov.size())
        .sum();
    assert!(exact > 0);
    assert!(exact <= 32);

    // A count that ends in the middle of the second descriptor.
    let split: usize = regions
        .get_remaining_with_count(&memory, 24)
        .iter()
        .map(|iov| iov.size())
        .sum();
    assert!(split > 0);
    assert!(split <= 24);

    // After consuming 64 bytes, a small count must still be honoured.
    regions.consume(64);
    let first: usize = regions
        .get_remaining_with_count(&memory, 8)
        .iter()
        .map(|iov| iov.size())
        .sum();
    assert!(first > 0);
    assert!(first <= 8);
}
#[test]
fn reader_peek_obj() {
    use DescriptorType::*;

    let guest_mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();

    // Seed guest memory with one 16-bit value at 0x100 and another at 0x200.
    guest_mem
        .write_obj_at_addr(Le16::from(0xBEEF), GuestAddress(0x100))
        .unwrap();
    guest_mem
        .write_obj_at_addr(Le16::from(0xDEAD), GuestAddress(0x200))
        .unwrap();

    // Two 2-byte readable descriptors. NOTE(review): the final argument
    // (0x100 - 2) presumably spaces the descriptors so that the second one
    // starts at 0x200 -- confirm against `create_descriptor_chain`.
    let descriptors = vec![(Readable, 2), (Readable, 2)];
    let mut chain_reader = create_descriptor_chain(
        &guest_mem,
        GuestAddress(0x0),
        GuestAddress(0x100),
        descriptors,
        0x100 - 2,
    )
    .expect("create_descriptor_chain failed");
    let reader = &mut chain_reader.reader;

    // Peeking at the start of the chain yields the first object...
    assert_eq!(reader.peek_obj::<Le16>().unwrap(), Le16::from(0xBEEF));

    // ...and does so again, because peeking consumes nothing.
    assert_eq!(reader.peek_obj::<Le16>().unwrap(), Le16::from(0xBEEF));

    // A peek wider than one descriptor gathers bytes from both of them.
    assert_eq!(reader.peek_obj::<Le32>().unwrap(), Le32::from(0xDEADBEEF));

    // Reading (unlike peeking) consumes the first object.
    assert_eq!(reader.read_obj::<Le16>().unwrap(), Le16::from(0xBEEF));

    // Only two bytes are left now, so a four-byte peek has to fail.
    reader
        .peek_obj::<Le32>()
        .expect_err("peek_obj past end of chain");

    // The next read returns the second object.
    assert_eq!(reader.read_obj::<Le16>().unwrap(), Le16::from(0xDEAD));

    // The chain is exhausted: even a two-byte peek has to fail.
    reader
        .peek_obj::<Le16>()
        .expect_err("peek_obj past end of chain");
}
#[test]
fn region_reader_failing_io() {
    // Drive the async test body to completion on a dedicated executor.
    let executor = Executor::new().unwrap();
    let test_body = region_reader_failing_io_async(&executor);
    executor.run_until(test_body).unwrap();
}
/// Async body of `region_reader_failing_io`.
///
/// Transfers data from a 512-byte readable chain into a disk backed by a file
/// that was opened read-only, and checks that the failed transfer leaves the
/// reader untouched.
async fn region_reader_failing_io_async(ex: &Executor) {
    use DescriptorType::*;

    let memory_start_addr = GuestAddress(0x0);
    let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
    // 2 * 256 = 512 readable bytes.
    let mut chain = create_descriptor_chain(
        &memory,
        GuestAddress(0x0),
        GuestAddress(0x100),
        vec![(Readable, 256), (Readable, 256)],
        0,
    )
    .expect("create_descriptor_chain failed");
    let reader = &mut chain.reader;

    // Open the file in read-only mode so that writing to it triggers an I/O error.
    let named_temp_file = NamedTempFile::new().expect("failed to create temp file");
    let ro_file =
        File::open(named_temp_file.path()).expect("failed to open temp file read only");
    let async_ro_file = disk::SingleFileDisk::new(ro_file, ex).expect("Failed to create SFD");

    reader
        .read_exact_to_at_fut(&async_ro_file, 512, 0)
        .await
        .expect_err("successfully wrote to a read-only SingleFileDisk");

    // The transfer above should have failed entirely, so no bytes were consumed
    // from the reader.
    assert_eq!(reader.available_bytes(), 512);
    assert_eq!(reader.bytes_read(), 0);
}
#[test]
fn region_writer_failing_io() {
    // Drive the async test body to completion on a dedicated executor.
    let executor = Executor::new().unwrap();
    let test_body = region_writer_failing_io_async(&executor);
    executor.run_until(test_body).unwrap();
}
/// Async body of `region_writer_failing_io`.
///
/// Asks a 512-byte writable chain to take 512 bytes from a disk whose backing
/// file is only 384 bytes long, and checks how much of the chain was used when
/// the request fails.
async fn region_writer_failing_io_async(ex: &Executor) {
    use DescriptorType::*;

    let memory_start_addr = GuestAddress(0x0);
    let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
    // 2 * 256 = 512 writable bytes.
    let mut chain = create_descriptor_chain(
        &memory,
        GuestAddress(0x0),
        GuestAddress(0x100),
        vec![(Writable, 256), (Writable, 256)],
        0,
    )
    .expect("create_descriptor_chain failed");
    let writer = &mut chain.writer;

    // The backing file holds fewer bytes (384) than the 512 requested below.
    let file = tempfile().expect("failed to create temp file");
    file.set_len(384).unwrap();
    let async_file = disk::SingleFileDisk::new(file, ex).expect("Failed to create SFD");

    writer
        .write_all_from_at_fut(&async_file, 512, 0)
        .await
        .expect_err("successfully wrote more bytes than in SingleFileDisk");

    // The 384 bytes that exist are expected to have been written before the
    // error, leaving 512 - 384 = 128 bytes of the chain unused.
    assert_eq!(writer.available_bytes(), 128);
    assert_eq!(writer.bytes_written(), 384);
}
}