| // Copyright 2021 The ChromiumOS Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| //! Library for implementing vhost-user device executables. |
| //! |
| //! This crate provides |
| //! * `VhostUserBackend` trait, which is a collection of methods to handle vhost-user requests, and |
| //! * `DeviceRequestHandler` struct, which makes a connection to a VMM and starts an event loop. |
| //! |
| //! They are expected to be used as follows: |
| //! |
| //! 1. Define a struct and implement `VhostUserBackend` for it. |
| //! 2. Create a `DeviceRequestHandler` with the backend struct. |
| //! 3. Drive the `DeviceRequestHandler::run` async fn with an executor. |
| //! |
| //! ```ignore |
| //! struct MyBackend { |
| //! /* fields */ |
| //! } |
| //! |
| //! impl VhostUserBackend for MyBackend { |
| //! /* implement methods */ |
| //! } |
| //! |
| //! fn main() -> Result<(), Box<dyn Error>> { |
| //! let backend = MyBackend { /* initialize fields */ }; |
| //! let handler = DeviceRequestHandler::new(backend); |
| //! let socket = std::path::Path::new("/path/to/socket"); |
| //! let ex = cros_async::Executor::new()?; |
| //! |
| //! if let Err(e) = ex.run_until(handler.run(socket, &ex)) { |
| //! eprintln!("error happened: {}", e); |
| //! } |
| //! Ok(()) |
| //! } |
| //! ``` |
| //! |
| // Implementation note: |
| // This code lets us take advantage of the vmm_vhost low level implementation of the vhost user |
| // protocol. DeviceRequestHandler implements the VhostUserSlaveReqHandlerMut trait from vmm_vhost, |
| // and includes some common code for setting up guest memory and managing partially configured |
| // vrings. DeviceRequestHandler::run watches the vhost-user socket and then calls handle_request() |
| // when it becomes readable. handle_request() reads and parses the message and then calls one of the |
| // VhostUserSlaveReqHandlerMut trait methods. These dispatch back to the supplied VhostUserBackend |
| // implementation (this is what our devices implement). |
| |
| pub(super) mod sys; |
| |
| use std::collections::BTreeMap; |
| use std::convert::From; |
| use std::convert::TryFrom; |
| use std::fs::File; |
| use std::num::Wrapping; |
| #[cfg(unix)] |
| use std::os::unix::io::AsRawFd; |
| use std::sync::Arc; |
| |
| use anyhow::bail; |
| use anyhow::Context; |
| #[cfg(unix)] |
| use base::clear_fd_flags; |
| use base::error; |
| use base::Event; |
| use base::FromRawDescriptor; |
| use base::IntoRawDescriptor; |
| use base::Protection; |
| use base::SafeDescriptor; |
| use base::SharedMemory; |
| use cros_async::TaskHandle; |
| use futures::future::AbortHandle; |
| use futures::future::Aborted; |
| use sys::Doorbell; |
| use thiserror::Error as ThisError; |
| use vm_control::VmMemorySource; |
| use vm_memory::GuestAddress; |
| use vm_memory::GuestMemory; |
| use vm_memory::MemoryRegion; |
| use vmm_vhost::connection::Endpoint; |
| use vmm_vhost::message::SlaveReq; |
| use vmm_vhost::message::VhostSharedMemoryRegion; |
| use vmm_vhost::message::VhostUserConfigFlags; |
| use vmm_vhost::message::VhostUserGpuMapMsg; |
| use vmm_vhost::message::VhostUserInflight; |
| use vmm_vhost::message::VhostUserMemoryRegion; |
| use vmm_vhost::message::VhostUserProtocolFeatures; |
| use vmm_vhost::message::VhostUserShmemMapMsg; |
| use vmm_vhost::message::VhostUserShmemMapMsgFlags; |
| use vmm_vhost::message::VhostUserShmemUnmapMsg; |
| use vmm_vhost::message::VhostUserSingleMemoryRegion; |
| use vmm_vhost::message::VhostUserVirtioFeatures; |
| use vmm_vhost::message::VhostUserVringAddrFlags; |
| use vmm_vhost::message::VhostUserVringState; |
| use vmm_vhost::Error as VhostError; |
| use vmm_vhost::Protocol; |
| use vmm_vhost::Result as VhostResult; |
| use vmm_vhost::Slave; |
| use vmm_vhost::VhostUserMasterReqHandler; |
| use vmm_vhost::VhostUserSlaveReqHandlerMut; |
| |
| use crate::virtio::Queue; |
| use crate::virtio::SharedMemoryMapper; |
| use crate::virtio::SharedMemoryRegion; |
| use crate::virtio::SignalableInterrupt; |
| |
/// Upper bound on the number of entries a virtqueue may be configured with.
///
/// Every `Vring` is created with this as its maximum size, and vring requests from the
/// frontend are validated against it.
const MAX_VRING_LEN: u16 = 1 << 15;
| |
| /// An event to deliver an interrupt to the guest. |
| /// |
| /// Unlike `devices::Interrupt`, this doesn't support interrupt status and signal resampling. |
| // TODO(b/187487351): To avoid sending unnecessary events, we might want to support interrupt |
| // status. For this purpose, we need a mechanism to share interrupt status between the vmm and the |
| // device process. |
| #[derive(Clone)] |
| pub struct CallEvent(Arc<Event>); |
| |
| impl CallEvent { |
| #[cfg_attr(windows, allow(dead_code))] |
| pub fn into_inner(self) -> Event { |
| Arc::try_unwrap(self.0).unwrap() |
| } |
| } |
| |
| impl SignalableInterrupt for CallEvent { |
| fn signal(&self, _vector: u16, _interrupt_status_mask: u32) { |
| self.0.signal().unwrap(); |
| } |
| |
| fn signal_config_changed(&self) {} // TODO(dgreid) |
| |
| fn get_resample_evt(&self) -> Option<&Event> { |
| None |
| } |
| |
| fn do_interrupt_resample(&self) {} |
| } |
| |
| impl From<File> for CallEvent { |
| fn from(file: File) -> Self { |
| // Safe because we own the file. |
| CallEvent(Arc::new(unsafe { |
| Event::from_raw_descriptor(file.into_raw_descriptor()) |
| })) |
| } |
| } |
| |
/// One mapping from a range of the vmm's virtual addresses to guest physical addresses.
///
/// Used to translate the addresses found in messages from the vmm into guest offsets.
#[derive(Default)]
pub struct MappingInfo {
    /// First vmm virtual address covered by this mapping.
    pub vmm_addr: u64,
    /// Guest physical address that `vmm_addr` corresponds to.
    pub guest_phys: u64,
    /// Length of the mapping in bytes.
    pub size: u64,
}
| |
| pub fn vmm_va_to_gpa(maps: &[MappingInfo], vmm_va: u64) -> VhostResult<GuestAddress> { |
| for map in maps { |
| if vmm_va >= map.vmm_addr && vmm_va < map.vmm_addr + map.size { |
| return Ok(GuestAddress(vmm_va - map.vmm_addr + map.guest_phys)); |
| } |
| } |
| Err(VhostError::InvalidMessage) |
| } |
| |
| /// Trait for vhost-user backend. |
| pub trait VhostUserBackend { |
| /// The maximum number of queues that this backend can manage. |
| fn max_queue_num(&self) -> usize; |
| |
| /// The set of feature bits that this backend supports. |
| fn features(&self) -> u64; |
| |
| /// Acknowledges that this set of features should be enabled. |
| fn ack_features(&mut self, value: u64) -> anyhow::Result<()>; |
| |
| /// Returns the set of enabled features. |
| fn acked_features(&self) -> u64; |
| |
| /// The set of protocol feature bits that this backend supports. |
| fn protocol_features(&self) -> VhostUserProtocolFeatures; |
| |
| /// Acknowledges that this set of protocol features should be enabled. |
| fn ack_protocol_features(&mut self, _value: u64) -> anyhow::Result<()>; |
| |
| /// Returns the set of enabled protocol features. |
| fn acked_protocol_features(&self) -> u64; |
| |
| /// Reads this device configuration space at `offset`. |
| fn read_config(&self, offset: u64, dst: &mut [u8]); |
| |
| /// writes `data` to this device's configuration space at `offset`. |
| fn write_config(&self, _offset: u64, _data: &[u8]) {} |
| |
| /// Indicates that the backend should start processing requests for virtio queue number `idx`. |
| /// This method must not block the current thread so device backends should either spawn an |
| /// async task or another thread to handle messages from the Queue. |
| fn start_queue( |
| &mut self, |
| idx: usize, |
| queue: Queue, |
| mem: GuestMemory, |
| doorbell: Doorbell, |
| kick_evt: Event, |
| ) -> anyhow::Result<()>; |
| |
| /// Indicates that the backend should stop processing requests for virtio queue number `idx`. |
| /// This method should return the queue passed to `start_queue` for the corresponding `idx`. |
| fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue>; |
| |
| /// Resets the vhost-user backend. |
| fn reset(&mut self); |
| |
| /// Returns the device's shared memory region if present. |
| fn get_shared_memory_region(&self) -> Option<SharedMemoryRegion> { |
| None |
| } |
| |
| /// Accepts `VhostBackendReqConnection` to conduct Vhost backend to frontend message |
| /// handling. |
| /// |
| /// This method will be called when `VhostUserProtocolFeatures::SLAVE_REQ` is |
| /// negotiated. |
| fn set_backend_req_connection(&mut self, _conn: VhostBackendReqConnection) { |
| error!("set_backend_req_connection is not implemented"); |
| } |
| } |
| |
| /// A virtio ring entry. |
| struct Vring { |
| queue: Queue, |
| doorbell: Option<Doorbell>, |
| enabled: bool, |
| } |
| |
| impl Vring { |
| fn new(max_size: u16) -> Self { |
| Self { |
| queue: Queue::new(max_size), |
| doorbell: None, |
| enabled: false, |
| } |
| } |
| |
| fn reset(&mut self) { |
| self.queue.reset(); |
| self.doorbell = None; |
| self.enabled = false; |
| } |
| } |
| |
| /// Trait for defining vhost-user ops that are platform-dependent. |
| pub trait VhostUserPlatformOps { |
| /// Returns the protocol implemented by these platform ops. |
| fn protocol(&self) -> Protocol; |
| /// Create the guest memory for the backend. |
| /// |
| /// `contexts` and `files` must be the same size, and provide a description of the memory |
| /// regions to map as well as the file descriptors from which to obtain the memory backing these |
| /// regions, respectively. |
| /// |
| /// The returned tuple contains the constructed `GuestMemory` from these memory contexts, as |
| /// well as a vector describing all the mappings described by these contexts. |
| |
| fn set_mem_table( |
| &mut self, |
| contexts: &[VhostUserMemoryRegion], |
| files: Vec<File>, |
| ) -> VhostResult<(GuestMemory, Vec<MappingInfo>)>; |
| |
| /// Return an `Event` that will be signaled by the frontend whenever vring `index` should be |
| /// processed. |
| /// |
| /// For protocols that support providing that event using a file descriptor (`Regular`), it is |
| /// provided by `file`. For other protocols, `file` will be `None`. |
| fn set_vring_kick(&mut self, index: u8, file: Option<File>) -> VhostResult<Event>; |
| |
| /// Return a `Doorbell` that the backend will signal whenever it puts used buffers for vring |
| /// `index`. |
| /// |
| /// For protocols that support listening to a file descriptor (`Regular`), `file` provides a |
| /// file descriptor from which the `Doorbell` should be built. For other protocols, it will be |
| /// `None`. |
| fn set_vring_call(&mut self, index: u8, file: Option<File>) -> VhostResult<Doorbell>; |
| } |
| |
| /// Ops for running vhost-user over a stream (i.e. regular protocol). |
| pub(super) struct VhostUserRegularOps; |
| |
| impl VhostUserPlatformOps for VhostUserRegularOps { |
| fn protocol(&self) -> Protocol { |
| Protocol::Regular |
| } |
| |
| fn set_mem_table( |
| &mut self, |
| contexts: &[VhostUserMemoryRegion], |
| files: Vec<File>, |
| ) -> VhostResult<(GuestMemory, Vec<MappingInfo>)> { |
| if files.len() != contexts.len() { |
| return Err(VhostError::InvalidParam); |
| } |
| |
| let mut regions = Vec::with_capacity(files.len()); |
| for (region, file) in contexts.iter().zip(files.into_iter()) { |
| let region = MemoryRegion::new_from_shm( |
| region.memory_size, |
| GuestAddress(region.guest_phys_addr), |
| region.mmap_offset, |
| Arc::new( |
| SharedMemory::from_safe_descriptor( |
| SafeDescriptor::from(file), |
| Some(region.memory_size), |
| ) |
| .unwrap(), |
| ), |
| ) |
| .map_err(|e| { |
| error!("failed to create a memory region: {}", e); |
| VhostError::InvalidOperation |
| })?; |
| regions.push(region); |
| } |
| let guest_mem = GuestMemory::from_regions(regions).map_err(|e| { |
| error!("failed to create guest memory: {}", e); |
| VhostError::InvalidOperation |
| })?; |
| |
| let vmm_maps = contexts |
| .iter() |
| .map(|region| MappingInfo { |
| vmm_addr: region.user_addr, |
| guest_phys: region.guest_phys_addr, |
| size: region.memory_size, |
| }) |
| .collect(); |
| Ok((guest_mem, vmm_maps)) |
| } |
| |
| fn set_vring_kick(&mut self, _index: u8, file: Option<File>) -> VhostResult<Event> { |
| let file = file.ok_or(VhostError::InvalidParam)?; |
| // Remove O_NONBLOCK from kick_fd. Otherwise, uring_executor will fails when we read |
| // values via `next_val()` later. |
| // This is only required (and can only be done) on Unix platforms. |
| #[cfg(unix)] |
| if let Err(e) = clear_fd_flags(file.as_raw_fd(), libc::O_NONBLOCK) { |
| error!("failed to remove O_NONBLOCK for kick fd: {}", e); |
| return Err(VhostError::InvalidParam); |
| } |
| |
| // Safe because we own the file. |
| Ok(unsafe { Event::from_raw_descriptor(file.into_raw_descriptor()) }) |
| } |
| |
| fn set_vring_call(&mut self, _index: u8, file: Option<File>) -> VhostResult<Doorbell> { |
| let file = file.ok_or(VhostError::InvalidParam)?; |
| Ok( |
| // `Doorbell` is defined as `CallEvent` on Windows, prevent clippy from giving us a |
| // warning about the unneeded conversion. |
| #[allow(clippy::useless_conversion)] |
| Doorbell::from(CallEvent::try_from(file).map_err(|_| { |
| error!("failed to convert callfd to CallSignal"); |
| VhostError::InvalidParam |
| })?), |
| ) |
| } |
| } |
| |
| /// A request handler for devices implementing `VhostUserBackend`. |
| pub struct DeviceRequestHandler { |
| vrings: Vec<Vring>, |
| owned: bool, |
| vmm_maps: Option<Vec<MappingInfo>>, |
| mem: Option<GuestMemory>, |
| backend: Box<dyn VhostUserBackend>, |
| ops: Box<dyn VhostUserPlatformOps>, |
| } |
| |
| impl DeviceRequestHandler { |
| /// Creates a vhost-user handler instance for `backend` with a different set of platform ops |
| /// than the regular vhost-user ones. |
| pub(crate) fn new( |
| backend: Box<dyn VhostUserBackend>, |
| ops: Box<dyn VhostUserPlatformOps>, |
| ) -> Self { |
| let mut vrings = Vec::with_capacity(backend.max_queue_num()); |
| for _ in 0..backend.max_queue_num() { |
| vrings.push(Vring::new(MAX_VRING_LEN)); |
| } |
| |
| DeviceRequestHandler { |
| vrings, |
| owned: false, |
| vmm_maps: None, |
| mem: None, |
| backend, |
| ops, |
| } |
| } |
| } |
| |
| impl VhostUserSlaveReqHandlerMut for DeviceRequestHandler { |
| fn protocol(&self) -> Protocol { |
| self.ops.protocol() |
| } |
| |
| fn set_owner(&mut self) -> VhostResult<()> { |
| if self.owned { |
| return Err(VhostError::InvalidOperation); |
| } |
| self.owned = true; |
| Ok(()) |
| } |
| |
| fn reset_owner(&mut self) -> VhostResult<()> { |
| self.owned = false; |
| self.backend.reset(); |
| Ok(()) |
| } |
| |
| fn get_features(&mut self) -> VhostResult<u64> { |
| let features = self.backend.features(); |
| Ok(features) |
| } |
| |
| fn set_features(&mut self, features: u64) -> VhostResult<()> { |
| if !self.owned { |
| return Err(VhostError::InvalidOperation); |
| } |
| |
| if (features & !(self.backend.features())) != 0 { |
| return Err(VhostError::InvalidParam); |
| } |
| |
| if let Err(e) = self.backend.ack_features(features) { |
| error!("failed to acknowledge features 0x{:x}: {}", features, e); |
| return Err(VhostError::InvalidOperation); |
| } |
| |
| // If VHOST_USER_F_PROTOCOL_FEATURES has not been negotiated, the ring is initialized in an |
| // enabled state. |
| // If VHOST_USER_F_PROTOCOL_FEATURES has been negotiated, the ring is initialized in a |
| // disabled state. |
| // Client must not pass data to/from the backend until ring is enabled by |
| // VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has been disabled by |
| // VHOST_USER_SET_VRING_ENABLE with parameter 0. |
| let acked_features = self.backend.acked_features(); |
| let vring_enabled = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() & acked_features != 0; |
| for v in &mut self.vrings { |
| v.enabled = vring_enabled; |
| } |
| |
| Ok(()) |
| } |
| |
| fn get_protocol_features(&mut self) -> VhostResult<VhostUserProtocolFeatures> { |
| Ok(self.backend.protocol_features()) |
| } |
| |
| fn set_protocol_features(&mut self, features: u64) -> VhostResult<()> { |
| if let Err(e) = self.backend.ack_protocol_features(features) { |
| error!("failed to set protocol features 0x{:x}: {}", features, e); |
| return Err(VhostError::InvalidOperation); |
| } |
| Ok(()) |
| } |
| |
| fn set_mem_table( |
| &mut self, |
| contexts: &[VhostUserMemoryRegion], |
| files: Vec<File>, |
| ) -> VhostResult<()> { |
| let (guest_mem, vmm_maps) = self.ops.set_mem_table(contexts, files)?; |
| self.mem = Some(guest_mem); |
| self.vmm_maps = Some(vmm_maps); |
| Ok(()) |
| } |
| |
| fn get_queue_num(&mut self) -> VhostResult<u64> { |
| Ok(self.vrings.len() as u64) |
| } |
| |
| fn set_vring_num(&mut self, index: u32, num: u32) -> VhostResult<()> { |
| if index as usize >= self.vrings.len() || num == 0 || num > MAX_VRING_LEN.into() { |
| return Err(VhostError::InvalidParam); |
| } |
| self.vrings[index as usize].queue.set_size(num as u16); |
| |
| Ok(()) |
| } |
| |
| fn set_vring_addr( |
| &mut self, |
| index: u32, |
| _flags: VhostUserVringAddrFlags, |
| descriptor: u64, |
| used: u64, |
| available: u64, |
| _log: u64, |
| ) -> VhostResult<()> { |
| if index as usize >= self.vrings.len() { |
| return Err(VhostError::InvalidParam); |
| } |
| |
| let vmm_maps = self.vmm_maps.as_ref().ok_or(VhostError::InvalidParam)?; |
| let vring = &mut self.vrings[index as usize]; |
| vring |
| .queue |
| .set_desc_table(vmm_va_to_gpa(vmm_maps, descriptor)?); |
| vring |
| .queue |
| .set_avail_ring(vmm_va_to_gpa(vmm_maps, available)?); |
| vring.queue.set_used_ring(vmm_va_to_gpa(vmm_maps, used)?); |
| |
| Ok(()) |
| } |
| |
| fn set_vring_base(&mut self, index: u32, base: u32) -> VhostResult<()> { |
| if index as usize >= self.vrings.len() || base >= MAX_VRING_LEN.into() { |
| return Err(VhostError::InvalidParam); |
| } |
| |
| let vring = &mut self.vrings[index as usize]; |
| vring.queue.next_avail = Wrapping(base as u16); |
| vring.queue.next_used = Wrapping(base as u16); |
| |
| Ok(()) |
| } |
| |
| fn get_vring_base(&mut self, index: u32) -> VhostResult<VhostUserVringState> { |
| if index as usize >= self.vrings.len() { |
| return Err(VhostError::InvalidParam); |
| } |
| |
| // Quotation from vhost-user spec: |
| // Client must start ring upon receiving a kick (that is, detecting |
| // that file descriptor is readable) on the descriptor specified by |
| // VHOST_USER_SET_VRING_KICK, and stop ring upon receiving |
| // VHOST_USER_GET_VRING_BASE. |
| if let Err(e) = self.backend.stop_queue(index as usize) { |
| error!("Failed to stop queue in get_vring_base: {}", e); |
| } |
| |
| let vring = &mut self.vrings[index as usize]; |
| vring.reset(); |
| |
| Ok(VhostUserVringState::new( |
| index, |
| vring.queue.next_avail.0 as u32, |
| )) |
| } |
| |
| fn set_vring_kick(&mut self, index: u8, file: Option<File>) -> VhostResult<()> { |
| if index as usize >= self.vrings.len() { |
| return Err(VhostError::InvalidParam); |
| } |
| |
| let vring = &mut self.vrings[index as usize]; |
| if vring.queue.ready() { |
| error!("kick fd cannot replaced after queue is started"); |
| return Err(VhostError::InvalidOperation); |
| } |
| |
| let kick_evt = self.ops.set_vring_kick(index, file)?; |
| |
| // Enable any virtqueue features that were negotiated (like VIRTIO_RING_F_EVENT_IDX). |
| vring.queue.ack_features(self.backend.acked_features()); |
| vring.queue.set_ready(true); |
| |
| let queue = match vring.queue.activate() { |
| Ok(queue) => queue, |
| Err(e) => { |
| error!("failed to activate vring: {:#}", e); |
| return Err(VhostError::SlaveInternalError); |
| } |
| }; |
| |
| let doorbell = vring.doorbell.clone().ok_or(VhostError::InvalidOperation)?; |
| let mem = self |
| .mem |
| .as_ref() |
| .cloned() |
| .ok_or(VhostError::InvalidOperation)?; |
| |
| if let Err(e) = self |
| .backend |
| .start_queue(index as usize, queue, mem, doorbell, kick_evt) |
| { |
| error!("Failed to start queue {}: {}", index, e); |
| return Err(VhostError::SlaveInternalError); |
| } |
| |
| Ok(()) |
| } |
| |
| fn set_vring_call(&mut self, index: u8, file: Option<File>) -> VhostResult<()> { |
| if index as usize >= self.vrings.len() { |
| return Err(VhostError::InvalidParam); |
| } |
| |
| let doorbell = self.ops.set_vring_call(index, file)?; |
| self.vrings[index as usize].doorbell = Some(doorbell); |
| Ok(()) |
| } |
| |
| fn set_vring_err(&mut self, _index: u8, _fd: Option<File>) -> VhostResult<()> { |
| // TODO |
| Ok(()) |
| } |
| |
| fn set_vring_enable(&mut self, index: u32, enable: bool) -> VhostResult<()> { |
| if index as usize >= self.vrings.len() { |
| return Err(VhostError::InvalidParam); |
| } |
| |
| // This request should be handled only when VHOST_USER_F_PROTOCOL_FEATURES |
| // has been negotiated. |
| if self.backend.acked_features() & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0 { |
| return Err(VhostError::InvalidOperation); |
| } |
| |
| // Slave must not pass data to/from the backend until ring is |
| // enabled by VHOST_USER_SET_VRING_ENABLE with parameter 1, |
| // or after it has been disabled by VHOST_USER_SET_VRING_ENABLE |
| // with parameter 0. |
| self.vrings[index as usize].enabled = enable; |
| |
| Ok(()) |
| } |
| |
| fn get_config( |
| &mut self, |
| offset: u32, |
| size: u32, |
| _flags: VhostUserConfigFlags, |
| ) -> VhostResult<Vec<u8>> { |
| let mut data = vec![0; size as usize]; |
| self.backend.read_config(u64::from(offset), &mut data); |
| Ok(data) |
| } |
| |
| fn set_config( |
| &mut self, |
| offset: u32, |
| buf: &[u8], |
| _flags: VhostUserConfigFlags, |
| ) -> VhostResult<()> { |
| self.backend.write_config(u64::from(offset), buf); |
| Ok(()) |
| } |
| |
| fn set_slave_req_fd(&mut self, ep: Box<dyn Endpoint<SlaveReq>>) { |
| let conn = VhostBackendReqConnection::new( |
| Slave::new(ep), |
| self.backend.get_shared_memory_region().map(|r| r.id), |
| ); |
| self.backend.set_backend_req_connection(conn); |
| } |
| |
| fn get_inflight_fd( |
| &mut self, |
| _inflight: &VhostUserInflight, |
| ) -> VhostResult<(VhostUserInflight, File)> { |
| unimplemented!("get_inflight_fd"); |
| } |
| |
| fn set_inflight_fd(&mut self, _inflight: &VhostUserInflight, _file: File) -> VhostResult<()> { |
| unimplemented!("set_inflight_fd"); |
| } |
| |
| fn get_max_mem_slots(&mut self) -> VhostResult<u64> { |
| //TODO |
| Ok(0) |
| } |
| |
| fn add_mem_region( |
| &mut self, |
| _region: &VhostUserSingleMemoryRegion, |
| _fd: File, |
| ) -> VhostResult<()> { |
| //TODO |
| Ok(()) |
| } |
| |
| fn remove_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion) -> VhostResult<()> { |
| //TODO |
| Ok(()) |
| } |
| |
| fn get_shared_memory_regions(&mut self) -> VhostResult<Vec<VhostSharedMemoryRegion>> { |
| Ok(if let Some(r) = self.backend.get_shared_memory_region() { |
| vec![VhostSharedMemoryRegion::new(r.id, r.length)] |
| } else { |
| Vec::new() |
| }) |
| } |
| } |
| |
| /// Indicates the state of backend request connection |
| pub enum VhostBackendReqConnectionState { |
| /// A backend request connection (`VhostBackendReqConnection`) is established |
| Connected(VhostBackendReqConnection), |
| /// No backend request connection has been established yet |
| NoConnection, |
| } |
| |
| /// Keeps track of Vhost user backend request connection. |
| pub struct VhostBackendReqConnection { |
| conn: Slave, |
| shmem_info: Option<ShmemInfo>, |
| } |
| |
/// Bookkeeping for the mappings placed in one shared memory region.
#[derive(Clone)]
struct ShmemInfo {
    /// Id of the shared memory region.
    shmid: u8,
    /// Currently mapped ranges, keyed by their offset in the region.
    mapped_regions: BTreeMap<u64 /* offset */, u64 /* size */>,
}
| |
| impl VhostBackendReqConnection { |
| pub fn new(conn: Slave, shmid: Option<u8>) -> Self { |
| let shmem_info = shmid.map(|shmid| ShmemInfo { |
| shmid, |
| mapped_regions: BTreeMap::new(), |
| }); |
| Self { conn, shmem_info } |
| } |
| |
| /// Send `VHOST_USER_CONFIG_CHANGE_MSG` to the frontend |
| pub fn send_config_changed(&self) -> anyhow::Result<()> { |
| self.conn |
| .handle_config_change() |
| .context("Could not send config change message")?; |
| Ok(()) |
| } |
| |
| /// Create a SharedMemoryMapper trait object from the ShmemInfo. |
| pub fn take_shmem_mapper(&mut self) -> anyhow::Result<Box<dyn SharedMemoryMapper>> { |
| let shmem_info = self |
| .shmem_info |
| .take() |
| .context("could not take shared memory mapper information")?; |
| |
| Ok(Box::new(VhostShmemMapper { |
| conn: self.conn.clone(), |
| shmem_info, |
| })) |
| } |
| } |
| |
| struct VhostShmemMapper { |
| conn: Slave, |
| shmem_info: ShmemInfo, |
| } |
| |
| impl SharedMemoryMapper for VhostShmemMapper { |
| fn add_mapping( |
| &mut self, |
| source: VmMemorySource, |
| offset: u64, |
| prot: Protection, |
| ) -> anyhow::Result<()> { |
| // True if we should send gpu_map instead of shmem_map. |
| let is_gpu = matches!(&source, &VmMemorySource::Vulkan { .. }); |
| |
| let size = if is_gpu { |
| match source { |
| VmMemorySource::Vulkan { |
| descriptor, |
| handle_type, |
| memory_idx, |
| device_id, |
| size, |
| } => { |
| let msg = VhostUserGpuMapMsg::new( |
| self.shmem_info.shmid, |
| offset, |
| size, |
| memory_idx, |
| handle_type, |
| device_id.device_uuid, |
| device_id.driver_uuid, |
| ); |
| self.conn |
| .gpu_map(&msg, &descriptor) |
| .context("failed to map memory")?; |
| size |
| } |
| _ => unreachable!("inconsistent pattern match"), |
| } |
| } else { |
| let (descriptor, fd_offset, size) = match source { |
| VmMemorySource::Descriptor { |
| descriptor, |
| offset, |
| size, |
| } => (descriptor, offset, size), |
| VmMemorySource::SharedMemory(shmem) => { |
| let size = shmem.size(); |
| // Safe because we own shmem. |
| let descriptor = |
| unsafe { SafeDescriptor::from_raw_descriptor(shmem.into_raw_descriptor()) }; |
| (descriptor, 0, size) |
| } |
| _ => bail!("unsupported source"), |
| }; |
| let flags = VhostUserShmemMapMsgFlags::from(prot); |
| let msg = |
| VhostUserShmemMapMsg::new(self.shmem_info.shmid, offset, fd_offset, size, flags); |
| self.conn |
| .shmem_map(&msg, &descriptor) |
| .context("failed to map memory")?; |
| size |
| }; |
| |
| self.shmem_info.mapped_regions.insert(offset, size); |
| Ok(()) |
| } |
| |
| fn remove_mapping(&mut self, offset: u64) -> anyhow::Result<()> { |
| let size = self |
| .shmem_info |
| .mapped_regions |
| .remove(&offset) |
| .context("unknown offset")?; |
| let msg = VhostUserShmemUnmapMsg::new(self.shmem_info.shmid, offset, size); |
| self.conn |
| .shmem_unmap(&msg) |
| .context("failed to map memory") |
| .map(|_| ()) |
| } |
| } |
| |
| pub(crate) struct WorkerState<T, U> { |
| pub(crate) abort_handle: AbortHandle, |
| pub(crate) queue_task: TaskHandle<std::result::Result<U, Aborted>>, |
| pub(crate) queue: T, |
| } |
| |
| /// Errors for device operations |
| #[derive(Debug, ThisError)] |
| pub enum Error { |
| #[error("worker not found when stopping queue")] |
| WorkerNotFound, |
| } |
| |
| #[cfg(test)] |
| mod tests { |
| #[cfg(unix)] |
| use std::sync::mpsc::channel; |
| #[cfg(unix)] |
| use std::sync::Barrier; |
| |
| use anyhow::anyhow; |
| use anyhow::bail; |
| #[cfg(unix)] |
| use tempfile::Builder; |
| #[cfg(unix)] |
| use tempfile::TempDir; |
| use vmm_vhost::message::MasterReq; |
| use vmm_vhost::SlaveReqHandler; |
| use vmm_vhost::VhostUserSlaveReqHandler; |
| use zerocopy::AsBytes; |
| use zerocopy::FromBytes; |
| |
| use super::*; |
| use crate::virtio::vhost::user::vmm::VhostUserHandler; |
| |
| #[derive(Clone, Copy, Debug, PartialEq, Eq, AsBytes, FromBytes)] |
| #[repr(C, packed(4))] |
| struct FakeConfig { |
| x: u32, |
| y: u64, |
| } |
| |
| const FAKE_CONFIG_DATA: FakeConfig = FakeConfig { x: 1, y: 2 }; |
| |
| pub(super) struct FakeBackend { |
| avail_features: u64, |
| acked_features: u64, |
| acked_protocol_features: VhostUserProtocolFeatures, |
| } |
| |
| impl FakeBackend { |
| const MAX_QUEUE_NUM: usize = 16; |
| |
| pub(super) fn new() -> Self { |
| Self { |
| avail_features: VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(), |
| acked_features: 0, |
| acked_protocol_features: VhostUserProtocolFeatures::empty(), |
| } |
| } |
| } |
| |
| impl VhostUserBackend for FakeBackend { |
| fn max_queue_num(&self) -> usize { |
| Self::MAX_QUEUE_NUM |
| } |
| |
| fn features(&self) -> u64 { |
| self.avail_features |
| } |
| |
| fn ack_features(&mut self, value: u64) -> anyhow::Result<()> { |
| let unrequested_features = value & !self.avail_features; |
| if unrequested_features != 0 { |
| bail!( |
| "invalid protocol features are given: 0x{:x}", |
| unrequested_features |
| ); |
| } |
| self.acked_features |= value; |
| Ok(()) |
| } |
| |
| fn acked_features(&self) -> u64 { |
| self.acked_features |
| } |
| |
| fn protocol_features(&self) -> VhostUserProtocolFeatures { |
| VhostUserProtocolFeatures::CONFIG |
| } |
| |
| fn ack_protocol_features(&mut self, features: u64) -> anyhow::Result<()> { |
| let features = VhostUserProtocolFeatures::from_bits(features).ok_or(anyhow!( |
| "invalid protocol features are given: 0x{:x}", |
| features |
| ))?; |
| let supported = self.protocol_features(); |
| self.acked_protocol_features = features & supported; |
| Ok(()) |
| } |
| |
| fn acked_protocol_features(&self) -> u64 { |
| self.acked_protocol_features.bits() |
| } |
| |
| fn read_config(&self, offset: u64, dst: &mut [u8]) { |
| dst.copy_from_slice(&FAKE_CONFIG_DATA.as_bytes()[offset as usize..]); |
| } |
| |
| fn reset(&mut self) {} |
| |
| fn start_queue( |
| &mut self, |
| _idx: usize, |
| _queue: Queue, |
| _mem: GuestMemory, |
| _doorbell: Doorbell, |
| _kick_evt: Event, |
| ) -> anyhow::Result<()> { |
| Ok(()) |
| } |
| |
| fn stop_queue(&mut self, _idx: usize) -> anyhow::Result<Queue> { |
| // TODO(280607609): Return a `Queue`. |
| Err(anyhow!("Missing queue")) |
| } |
| } |
| |
| #[cfg(unix)] |
| fn temp_dir() -> TempDir { |
| Builder::new().prefix("/tmp/vhost_test").tempdir().unwrap() |
| } |
| |
/// Runs a VMM-side `VhostUserHandler` against a `DeviceRequestHandler` over a Unix socket and
/// checks that the device handles the whole activation sequence, then sees the client exit.
#[cfg(unix)]
#[test]
fn test_vhost_user_activate() {
    use std::os::unix::net::UnixStream;

    use vmm_vhost::connection::socket::Listener as SocketListener;
    use vmm_vhost::SlaveListener;

    const QUEUES_NUM: usize = 2;

    let dir = temp_dir();
    let path = dir.path().join("sock");
    let listener = SocketListener::new(&path, true).unwrap();

    let vmm_bar = Arc::new(Barrier::new(2));
    let dev_bar = vmm_bar.clone();

    let (tx, rx) = channel();

    std::thread::spawn(move || {
        // VMM side
        rx.recv().unwrap(); // Ensure the device is ready.

        let allow_features = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
        let init_features = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
        let allow_protocol_features = VhostUserProtocolFeatures::CONFIG;
        let connection = UnixStream::connect(&path).unwrap();
        let mut vmm_handler = VhostUserHandler::new_from_connection(
            connection,
            QUEUES_NUM as u64,
            allow_features,
            init_features,
            allow_protocol_features,
        )
        .unwrap();

        // Reads the config, sets the memory table and activates every vring.
        vmm_handler_send_requests(&mut vmm_handler, QUEUES_NUM);

        // The VMM side is supposed to stop before the device side.
        drop(vmm_handler);

        vmm_bar.wait();
    });

    // Device side
    let handler = std::sync::Mutex::new(DeviceRequestHandler::new(
        Box::new(FakeBackend::new()),
        Box::new(VhostUserRegularOps),
    ));
    let mut listener = SlaveListener::<SocketListener, _>::new(listener, handler).unwrap();

    // Notify listener is ready.
    tx.send(()).unwrap();

    let mut req_handler = listener.accept().unwrap().unwrap();

    // Handles every request sent by `VhostUserHandler::new_from_connection` and
    // `vmm_handler_send_requests`, in order.
    test_handle_requests(&mut req_handler, QUEUES_NUM);

    dev_bar.wait();

    match handle_request(&mut req_handler) {
        Err(VhostError::ClientExit) => (),
        r => panic!("Err(ClientExit) was expected but {:?}", r),
    }
}
| |
| pub(super) fn vmm_handler_send_requests(vmm_handler: &mut VhostUserHandler, queues_num: usize) { |
| println!("read_config"); |
| let mut buf = vec![0; std::mem::size_of::<FakeConfig>()]; |
| vmm_handler.read_config(0, &mut buf).unwrap(); |
| // Check if the obtained config data is correct. |
| let config = FakeConfig::read_from(buf.as_bytes()).unwrap(); |
| assert_eq!(config, FAKE_CONFIG_DATA); |
| |
| println!("set_mem_table"); |
| let mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap(); |
| vmm_handler.set_mem_table(&mem).unwrap(); |
| |
| for idx in 0..queues_num { |
| println!("activate_mem_table: queue_index={}", idx); |
| let queue = Queue::new(0x10); |
| let queue_evt = Event::new().unwrap(); |
| let irqfd = Event::new().unwrap(); |
| |
| vmm_handler |
| .activate_vring(&mem, idx, &queue, &queue_evt, &irqfd) |
| .unwrap(); |
| } |
| } |
| |
| fn handle_request<S: VhostUserSlaveReqHandler, E: Endpoint<MasterReq>>( |
| handler: &mut SlaveReqHandler<S, E>, |
| ) -> Result<(), VhostError> { |
| let (hdr, files) = handler.recv_header()?; |
| handler.process_message(hdr, files) |
| } |
| |
| pub(super) fn test_handle_requests<S: VhostUserSlaveReqHandler, E: Endpoint<MasterReq>>( |
| req_handler: &mut SlaveReqHandler<S, E>, |
| queues_num: usize, |
| ) { |
| // VhostUserHandler::new() |
| handle_request(req_handler).expect("set_owner"); |
| handle_request(req_handler).expect("get_features"); |
| handle_request(req_handler).expect("set_features"); |
| handle_request(req_handler).expect("get_protocol_features"); |
| handle_request(req_handler).expect("set_protocol_features"); |
| |
| // VhostUserHandler::read_config() |
| handle_request(req_handler).expect("get_config"); |
| |
| // VhostUserHandler::set_mem_table() |
| handle_request(req_handler).expect("set_mem_table"); |
| |
| for _ in 0..queues_num { |
| // VhostUserHandler::activate_vring() |
| handle_request(req_handler).expect("set_vring_num"); |
| handle_request(req_handler).expect("set_vring_addr"); |
| handle_request(req_handler).expect("set_vring_base"); |
| handle_request(req_handler).expect("set_vring_call"); |
| handle_request(req_handler).expect("set_vring_kick"); |
| handle_request(req_handler).expect("set_vring_enable"); |
| } |
| } |
| } |