vhost-user: return the active Queue from stop_queue
Most of this change involves wrapping the Queue object inside an
Rc<RefCell<>>, Arc<Mutex<>>, or Rc<AsyncMutex<>>.
This is necessary for getting device suspend and resume to work.
Bug=280607608
TEST=presubmits
Change-Id: I7e3680aea2927c1fc9d971f27ebbb09ec308a634
Reviewed-on: https://chromium-review.googlesource.com/c/crosvm/crosvm/+/4545603
Commit-Queue: Richard Zhang <rizhang@google.com>
Reviewed-by: Noah Gold <nkgold@google.com>
diff --git a/devices/src/virtio/console.rs b/devices/src/virtio/console.rs
index fda4d26..8f8df3d 100644
--- a/devices/src/virtio/console.rs
+++ b/devices/src/virtio/console.rs
@@ -83,8 +83,11 @@
mem: &GuestMemory,
interrupt: &I,
buffer: &mut VecDeque<u8>,
- receive_queue: &mut Queue,
+ receive_queue: &Arc<Mutex<Queue>>,
) -> result::Result<(), ConsoleError> {
+ let mut receive_queue = receive_queue
+ .try_lock()
+ .expect("Lock should not be unavailable");
loop {
let mut desc = receive_queue
.peek(mem)
@@ -142,10 +145,13 @@
fn process_transmit_queue<I: SignalableInterrupt>(
mem: &GuestMemory,
interrupt: &I,
- transmit_queue: &mut Queue,
+ transmit_queue: &Arc<Mutex<Queue>>,
output: &mut dyn io::Write,
) {
let mut needs_interrupt = false;
+ let mut transmit_queue = transmit_queue
+ .try_lock()
+ .expect("Lock should not be unavailable");
while let Some(mut avail_desc) = transmit_queue.pop(mem) {
process_transmit_request(&mut avail_desc.reader, output)
.unwrap_or_else(|e| error!("console: process_transmit_request failed: {}", e));
@@ -166,9 +172,9 @@
output: Box<dyn io::Write + Send>,
kill_evt: Event,
in_avail_evt: Event,
- receive_queue: Queue,
+ receive_queue: Arc<Mutex<Queue>>,
receive_evt: Event,
- transmit_queue: Queue,
+ transmit_queue: Arc<Mutex<Queue>>,
transmit_evt: Event,
}
@@ -284,7 +290,7 @@
process_transmit_queue(
&self.mem,
&self.interrupt,
- &mut self.transmit_queue,
+ &self.transmit_queue,
&mut self.output,
);
}
@@ -298,7 +304,7 @@
&self.mem,
&self.interrupt,
in_buf_ref.lock().deref_mut(),
- &mut self.receive_queue,
+ &self.receive_queue,
) {
Ok(()) => {}
// Console errors are no-ops, so just continue.
@@ -318,7 +324,7 @@
&self.mem,
&self.interrupt,
in_buf_ref.lock().deref_mut(),
- &mut self.receive_queue,
+ &self.receive_queue,
) {
Ok(()) => {}
// Console errors are no-ops, so just continue.
@@ -450,10 +456,10 @@
in_avail_evt,
kill_evt,
// Device -> driver
- receive_queue,
+ receive_queue: Arc::new(Mutex::new(receive_queue)),
receive_evt,
// Driver -> device
- transmit_queue,
+ transmit_queue: Arc::new(Mutex::new(transmit_queue)),
transmit_evt,
};
worker.run();
diff --git a/devices/src/virtio/console/asynchronous.rs b/devices/src/virtio/console/asynchronous.rs
index 8d87050..0411e5c 100644
--- a/devices/src/virtio/console/asynchronous.rs
+++ b/devices/src/virtio/console/asynchronous.rs
@@ -6,6 +6,7 @@
use std::collections::VecDeque;
use std::io;
+use std::sync::Arc;
use anyhow::anyhow;
use anyhow::Context;
@@ -23,6 +24,7 @@
use cros_async::IoSource;
use futures::FutureExt;
use hypervisor::ProtectionType;
+use sync::Mutex;
use vm_memory::GuestMemory;
use vmm_vhost::message::VhostUserVirtioFeatures;
use zerocopy::AsBytes;
@@ -57,7 +59,7 @@
impl IntoAsync for AsyncSerialInput {}
async fn run_tx_queue<I: SignalableInterrupt>(
- mut queue: virtio::Queue,
+ queue: &Arc<Mutex<virtio::Queue>>,
mem: GuestMemory,
doorbell: I,
kick_evt: EventAsync,
@@ -68,12 +70,12 @@
error!("Failed to read kick event for tx queue: {}", e);
break;
}
- process_transmit_queue(&mem, &doorbell, &mut queue, output.as_mut());
+ process_transmit_queue(&mem, &doorbell, queue, output.as_mut());
}
}
async fn run_rx_queue<I: SignalableInterrupt>(
- mut queue: virtio::Queue,
+ queue: &Arc<Mutex<virtio::Queue>>,
mem: GuestMemory,
doorbell: I,
kick_evt: EventAsync,
@@ -100,7 +102,7 @@
// Submit all the data obtained during this read.
while !in_buffer.is_empty() {
- match handle_input(&mem, &doorbell, &mut in_buffer, &mut queue) {
+ match handle_input(&mem, &doorbell, &mut in_buffer, queue) {
Ok(()) => {}
Err(ConsoleError::RxDescriptorsExhausted) => {
// Wait until a descriptor becomes available and try again.
@@ -129,7 +131,7 @@
&mut self,
ex: &Executor,
mem: GuestMemory,
- queue: virtio::Queue,
+ queue: Arc<Mutex<virtio::Queue>>,
doorbell: I,
kick_evt: Event,
) -> anyhow::Result<()> {
@@ -149,7 +151,7 @@
Ok(async move {
select2(
- run_rx_queue(queue, mem, doorbell, kick_evt, &async_input).boxed_local(),
+ run_rx_queue(&queue, mem, doorbell, kick_evt, &async_input).boxed_local(),
abort,
)
.await;
@@ -173,7 +175,7 @@
&mut self,
ex: &Executor,
mem: GuestMemory,
- queue: virtio::Queue,
+ queue: Arc<Mutex<virtio::Queue>>,
doorbell: I,
kick_evt: Event,
) -> anyhow::Result<()> {
@@ -183,7 +185,7 @@
let tx_future = |mut output, abort| {
Ok(async move {
select2(
- run_tx_queue(queue, mem, doorbell, kick_evt, &mut output).boxed_local(),
+ run_tx_queue(&queue, mem, doorbell, kick_evt, &mut output).boxed_local(),
abort,
)
.await;
@@ -317,6 +319,8 @@
self.state =
VirtioConsoleState::Running(WorkerThread::start("v_console", move |kill_evt| {
let mut console = console;
+ let receive_queue = Arc::new(Mutex::new(receive_queue));
+ let transmit_queue = Arc::new(Mutex::new(transmit_queue));
console.start_receive_queue(
&ex,
diff --git a/devices/src/virtio/fs/worker.rs b/devices/src/virtio/fs/worker.rs
index e6c0a8b..475b9ec 100644
--- a/devices/src/virtio/fs/worker.rs
+++ b/devices/src/virtio/fs/worker.rs
@@ -2,11 +2,13 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
+use std::cell::RefCell;
use std::convert::TryFrom;
use std::convert::TryInto;
use std::fs::File;
use std::io;
use std::os::unix::io::AsRawFd;
+use std::rc::Rc;
use std::sync::Arc;
use base::error;
@@ -142,7 +144,7 @@
pub struct Worker<F: FileSystem + Sync> {
mem: GuestMemory,
- queue: Queue,
+ queue: Rc<RefCell<Queue>>,
server: Arc<fuse::Server<F>>,
irq: Interrupt,
tube: Arc<Mutex<Tube>>,
@@ -152,12 +154,13 @@
pub fn process_fs_queue<I: SignalableInterrupt, F: FileSystem + Sync>(
mem: &GuestMemory,
interrupt: &I,
- queue: &mut Queue,
+ queue: &Rc<RefCell<Queue>>,
server: &Arc<fuse::Server<F>>,
tube: &Arc<Mutex<Tube>>,
slot: u32,
) -> Result<()> {
let mapper = Mapper::new(Arc::clone(tube), slot);
+ let mut queue = queue.borrow_mut();
while let Some(mut avail_desc) = queue.pop(mem) {
let total =
server.handle_message(&mut avail_desc.reader, &mut avail_desc.writer, &mapper)?;
@@ -180,7 +183,7 @@
) -> Worker<F> {
Worker {
mem,
- queue,
+ queue: Rc::new(RefCell::new(queue)),
server,
irq,
tube,
@@ -245,7 +248,7 @@
if let Err(e) = process_fs_queue(
&self.mem,
&self.irq,
- &mut self.queue,
+ &self.queue,
&self.server,
&self.tube,
self.slot,
diff --git a/devices/src/virtio/net.rs b/devices/src/virtio/net.rs
index 040a3eb..a71053b 100644
--- a/devices/src/virtio/net.rs
+++ b/devices/src/virtio/net.rs
@@ -8,6 +8,7 @@
use std::net::Ipv4Addr;
use std::os::raw::c_uint;
use std::str::FromStr;
+use std::sync::Arc;
use anyhow::anyhow;
use base::error;
@@ -29,6 +30,7 @@
use remain::sorted;
use serde::Deserialize;
use serde::Serialize;
+use sync::Mutex;
use thiserror::Error as ThisError;
use virtio_sys::virtio_net;
use virtio_sys::virtio_net::virtio_net_hdr_v1;
@@ -269,12 +271,15 @@
pub fn process_ctrl<I: SignalableInterrupt, T: TapT>(
interrupt: &I,
- ctrl_queue: &mut Queue,
+ ctrl_queue: &Arc<Mutex<Queue>>,
mem: &GuestMemory,
tap: &mut T,
acked_features: u64,
vq_pairs: u16,
) -> Result<(), NetError> {
+ let mut ctrl_queue = ctrl_queue
+ .try_lock()
+ .expect("Lock should not be unavailable");
while let Some(mut desc_chain) = ctrl_queue.pop(mem) {
if let Err(e) = process_ctrl_request(&mut desc_chain.reader, tap, acked_features, vq_pairs)
{
@@ -316,9 +321,9 @@
pub(super) struct Worker<T: TapT> {
pub(super) interrupt: Interrupt,
pub(super) mem: GuestMemory,
- pub(super) rx_queue: Queue,
- pub(super) tx_queue: Queue,
- pub(super) ctrl_queue: Option<Queue>,
+ pub(super) rx_queue: Arc<Mutex<Queue>>,
+ pub(super) tx_queue: Arc<Mutex<Queue>>,
+ pub(super) ctrl_queue: Option<Arc<Mutex<Queue>>>,
pub(super) tap: T,
#[cfg(windows)]
pub(super) overlapped_wrapper: OverlappedWrapper,
@@ -339,12 +344,7 @@
T: TapT + ReadNotifier,
{
fn process_tx(&mut self) {
- process_tx(
- &self.interrupt,
- &mut self.tx_queue,
- &self.mem,
- &mut self.tap,
- )
+ process_tx(&self.interrupt, &self.tx_queue, &self.mem, &mut self.tap)
}
fn process_ctrl(&mut self) -> Result<(), NetError> {
@@ -721,7 +721,7 @@
let (tx_queue, tx_queue_evt) = queues.remove(0);
let (ctrl_queue, ctrl_queue_evt) = if first_queue && ctrl_vq_enabled {
let (queue, evt) = queues.remove(queues.len() - 1);
- (Some(queue), Some(evt))
+ (Some(Arc::new(Mutex::new(queue))), Some(evt))
} else {
(None, None)
};
@@ -735,8 +735,8 @@
let mut worker = Worker {
interrupt,
mem: memory,
- rx_queue,
- tx_queue,
+ rx_queue: Arc::new(Mutex::new(rx_queue)),
+ tx_queue: Arc::new(Mutex::new(tx_queue)),
ctrl_queue,
tap,
#[cfg(windows)]
diff --git a/devices/src/virtio/snd/common_backend/async_funcs.rs b/devices/src/virtio/snd/common_backend/async_funcs.rs
index 14f0378..77d5333 100644
--- a/devices/src/virtio/snd/common_backend/async_funcs.rs
+++ b/devices/src/virtio/snd/common_backend/async_funcs.rs
@@ -508,7 +508,7 @@
pub async fn send_pcm_response_worker<I: SignalableInterrupt>(
mem: &GuestMemory,
- queue: &Rc<AsyncMutex<Queue>>,
+ queue: Rc<AsyncMutex<Queue>>,
interrupt: I,
recv: &mut mpsc::UnboundedReceiver<PcmResponse>,
reset_signal: Option<&(AsyncMutex<bool>, Condvar)>,
@@ -552,7 +552,7 @@
mem: &GuestMemory,
streams: &Rc<AsyncMutex<Vec<AsyncMutex<StreamInfo>>>>,
mut response_sender: mpsc::UnboundedSender<PcmResponse>,
- queue: &Rc<AsyncMutex<Queue>>,
+ queue: Rc<AsyncMutex<Queue>>,
queue_event: &EventAsync,
reset_signal: Option<&(AsyncMutex<bool>, Condvar)>,
) -> Result<(), Error> {
@@ -636,7 +636,7 @@
mem: &GuestMemory,
streams: &Rc<AsyncMutex<Vec<AsyncMutex<StreamInfo>>>>,
snd_data: &SndData,
- queue: &mut Queue,
+ queue: Rc<AsyncMutex<Queue>>,
queue_event: &mut EventAsync,
interrupt: I,
tx_send: mpsc::UnboundedSender<PcmResponse>,
@@ -646,6 +646,7 @@
let on_reset = await_reset_signal(reset_signal).fuse();
pin_mut!(on_reset);
+ let mut queue = queue.lock().await;
loop {
let mut desc_chain = {
let next_async = queue.next_async(mem, queue_event).fuse();
diff --git a/devices/src/virtio/snd/common_backend/mod.rs b/devices/src/virtio/snd/common_backend/mod.rs
index 6d3bcc3..dcf2de5 100644
--- a/devices/src/virtio/snd/common_backend/mod.rs
+++ b/devices/src/virtio/snd/common_backend/mod.rs
@@ -513,7 +513,8 @@
})
.collect();
- let (mut ctrl_queue, mut ctrl_queue_evt) = queues.remove(0);
+ let (ctrl_queue, mut ctrl_queue_evt) = queues.remove(0);
+ let ctrl_queue = Rc::new(AsyncMutex::new(ctrl_queue));
let (_event_queue, _event_queue_evt) = queues.remove(0);
let (tx_queue, tx_queue_evt) = queues.remove(0);
let (rx_queue, rx_queue_evt) = queues.remove(0);
@@ -540,13 +541,13 @@
&snd_data,
&mut f_kill,
&mut f_resample,
- &mut ctrl_queue,
+ ctrl_queue.clone(),
&mut ctrl_queue_evt,
- &tx_queue,
+ tx_queue.clone(),
&tx_queue_evt,
tx_send.clone(),
&mut tx_recv,
- &rx_queue,
+ rx_queue.clone(),
&rx_queue_evt,
rx_send.clone(),
&mut rx_recv,
@@ -594,13 +595,13 @@
snd_data: &SndData,
mut f_kill: &mut (impl Future<Output = anyhow::Result<()>> + FusedFuture + Unpin),
mut f_resample: &mut (impl Future<Output = anyhow::Result<()>> + FusedFuture + Unpin),
- ctrl_queue: &mut Queue,
+ ctrl_queue: Rc<AsyncMutex<Queue>>,
ctrl_queue_evt: &mut EventAsync,
- tx_queue: &Rc<AsyncMutex<Queue>>,
+ tx_queue: Rc<AsyncMutex<Queue>>,
tx_queue_evt: &EventAsync,
tx_send: mpsc::UnboundedSender<PcmResponse>,
tx_recv: &mut mpsc::UnboundedReceiver<PcmResponse>,
- rx_queue: &Rc<AsyncMutex<Queue>>,
+ rx_queue: Rc<AsyncMutex<Queue>>,
rx_queue_evt: &EventAsync,
rx_send: mpsc::UnboundedSender<PcmResponse>,
rx_recv: &mut mpsc::UnboundedReceiver<PcmResponse>,
@@ -636,7 +637,7 @@
mem,
streams,
tx_send2,
- tx_queue,
+ tx_queue.clone(),
tx_queue_evt,
Some(&reset_signal),
)
@@ -653,7 +654,7 @@
mem,
streams,
rx_send2,
- rx_queue,
+ rx_queue.clone(),
rx_queue_evt,
Some(&reset_signal),
)
@@ -757,7 +758,7 @@
let f_tx_response = async {
while send_pcm_response_worker(
mem,
- tx_queue,
+ tx_queue.clone(),
interrupt.clone(),
tx_recv,
Some(&reset_signal),
@@ -770,7 +771,7 @@
let f_rx_response = async {
while send_pcm_response_worker(
mem,
- rx_queue,
+ rx_queue.clone(),
interrupt.clone(),
rx_recv,
Some(&reset_signal),
diff --git a/devices/src/virtio/sys/unix/net.rs b/devices/src/virtio/sys/unix/net.rs
index 31ad288..2caa94f 100644
--- a/devices/src/virtio/sys/unix/net.rs
+++ b/devices/src/virtio/sys/unix/net.rs
@@ -4,6 +4,7 @@
use std::io;
use std::result;
+use std::sync::Arc;
use base::error;
use base::warn;
@@ -11,6 +12,7 @@
use base::ReadNotifier;
use base::WaitContext;
use net_util::TapT;
+use sync::Mutex;
use vm_memory::GuestMemory;
use super::super::super::net::NetError;
@@ -21,13 +23,14 @@
pub fn process_rx<I: SignalableInterrupt, T: TapT>(
interrupt: &I,
- rx_queue: &mut Queue,
+ rx_queue: &Arc<Mutex<Queue>>,
mem: &GuestMemory,
mut tap: &mut T,
) -> result::Result<(), NetError> {
let mut needs_interrupt = false;
let mut exhausted_queue = false;
+ let mut rx_queue = rx_queue.try_lock().expect("Lock should not be unavailable");
// Read as many frames as possible.
loop {
let mut desc_chain = match rx_queue.peek(mem) {
@@ -78,10 +81,11 @@
pub fn process_tx<I: SignalableInterrupt, T: TapT>(
interrupt: &I,
- tx_queue: &mut Queue,
+ tx_queue: &Arc<Mutex<Queue>>,
mem: &GuestMemory,
mut tap: &mut T,
) {
+ let mut tx_queue = tx_queue.try_lock().expect("Lock should not be unavailable");
while let Some(mut desc_chain) = tx_queue.pop(mem) {
let reader = &mut desc_chain.reader;
let expected_count = reader.available_bytes();
@@ -137,11 +141,6 @@
Ok(())
}
pub(super) fn process_rx(&mut self) -> result::Result<(), NetError> {
- process_rx(
- &self.interrupt,
- &mut self.rx_queue,
- &self.mem,
- &mut self.tap,
- )
+ process_rx(&self.interrupt, &self.rx_queue, &self.mem, &mut self.tap)
}
}
diff --git a/devices/src/virtio/sys/windows/net.rs b/devices/src/virtio/sys/windows/net.rs
index 4a0f4cb..45b8ea8 100644
--- a/devices/src/virtio/sys/windows/net.rs
+++ b/devices/src/virtio/sys/windows/net.rs
@@ -6,6 +6,8 @@
use std::io::Read;
use std::io::Write;
use std::result;
+use std::sync::Arc;
+use std::sync::MutexGuard;
use base::error;
use base::named_pipes::OverlappedWrapper;
@@ -15,6 +17,7 @@
use base::WaitContext;
use libc::EEXIST;
use net_util::TapT;
+use sync::Mutex;
use virtio_sys::virtio_net;
use vm_memory::GuestMemory;
@@ -33,7 +36,7 @@
// if a buffer was used, and false if the frame must be deferred until a buffer
// is made available by the driver.
fn rx_single_frame(
- rx_queue: &mut Queue,
+ rx_queue: &mut MutexGuard<Queue>,
mem: &GuestMemory,
rx_buf: &mut [u8],
rx_count: usize,
@@ -65,7 +68,7 @@
pub fn process_rx<I: SignalableInterrupt, T: TapT>(
interrupt: &I,
- rx_queue: &mut Queue,
+ rx_queue: &Arc<Mutex<Queue>>,
mem: &GuestMemory,
tap: &mut T,
rx_buf: &mut [u8],
@@ -76,6 +79,7 @@
let mut needs_interrupt = false;
let mut first_frame = true;
+ let mut rx_queue = rx_queue.try_lock().expect("Lock should not be unavailable");
// Read as many frames as possible.
loop {
let res = if *deferred_rx {
@@ -87,7 +91,7 @@
match res {
Ok(count) => {
*rx_count = count;
- if !rx_single_frame(rx_queue, mem, rx_buf, *rx_count) {
+ if !rx_single_frame(&mut rx_queue, mem, rx_buf, *rx_count) {
*deferred_rx = true;
break;
} else if first_frame {
@@ -145,7 +149,7 @@
pub fn process_tx<I: SignalableInterrupt, T: TapT>(
interrupt: &I,
- tx_queue: &mut Queue,
+ tx_queue: &Arc<Mutex<Queue>>,
mem: &GuestMemory,
tap: &mut T,
) {
@@ -164,6 +168,7 @@
Ok(count)
}
+ let mut tx_queue = tx_queue.try_lock().expect("Lock should not be unavailable");
while let Some(mut desc_chain) = tx_queue.pop(mem) {
let mut frame = [0u8; MAX_BUFFER_SIZE];
match read_to_end(&mut desc_chain.reader, &mut frame[..]) {
@@ -226,7 +231,7 @@
pub(super) fn process_rx_slirp(&mut self) -> bool {
process_rx(
&self.interrupt,
- &mut self.rx_queue,
+ &self.rx_queue,
&self.mem,
&mut self.tap,
&mut self.rx_buf,
@@ -245,7 +250,10 @@
// until we manage to receive this deferred frame.
if self.deferred_rx {
if rx_single_frame(
- &mut self.rx_queue,
+ &mut self
+ .rx_queue
+ .try_lock()
+ .expect("Lock should not be unavailable"),
&self.mem,
&mut self.rx_buf,
self.rx_count,
@@ -264,7 +272,12 @@
}
needs_interrupt |= self.process_rx_slirp();
if needs_interrupt {
- self.interrupt.signal_used_queue(self.rx_queue.vector());
+ self.interrupt.signal_used_queue(
+ self.rx_queue
+ .try_lock()
+ .expect("Lock should not be unavailable")
+ .vector(),
+ );
}
Ok(())
}
@@ -274,14 +287,13 @@
wait_ctx: &WaitContext<Token>,
_tap_polling_enabled: bool,
) -> result::Result<(), NetError> {
+ let mut rx_queue = self
+ .rx_queue
+ .try_lock()
+ .expect("Lock should not be unavailable");
// There should be a buffer available now to receive the frame into.
if self.deferred_rx
- && rx_single_frame(
- &mut self.rx_queue,
- &self.mem,
- &mut self.rx_buf,
- self.rx_count,
- )
+ && rx_single_frame(&mut rx_queue, &self.mem, &mut self.rx_buf, self.rx_count)
{
// The guest has made buffers available, so add the tap back to the
// poll context in case it was removed.
@@ -293,7 +305,7 @@
}
}
self.deferred_rx = false;
- self.interrupt.signal_used_queue(self.rx_queue.vector());
+ self.interrupt.signal_used_queue(rx_queue.vector());
}
Ok(())
}
diff --git a/devices/src/virtio/vhost/user/device/block.rs b/devices/src/virtio/vhost/user/device/block.rs
index 35a56b7..13976c3 100644
--- a/devices/src/virtio/vhost/user/device/block.rs
+++ b/devices/src/virtio/vhost/user/device/block.rs
@@ -41,11 +41,14 @@
use crate::virtio::copy_config;
use crate::virtio::vhost::user::device::handler::sys::Doorbell;
use crate::virtio::vhost::user::device::handler::DeviceRequestHandler;
+use crate::virtio::vhost::user::device::handler::Error as DeviceError;
use crate::virtio::vhost::user::device::handler::VhostBackendReqConnection;
use crate::virtio::vhost::user::device::handler::VhostBackendReqConnectionState;
use crate::virtio::vhost::user::device::handler::VhostUserBackend;
use crate::virtio::vhost::user::device::handler::VhostUserPlatformOps;
+use crate::virtio::vhost::user::device::handler::WorkerState;
use crate::virtio::vhost::user::device::VhostUserDevice;
+use crate::virtio::Queue;
const NUM_QUEUES: u16 = 16;
@@ -61,7 +64,7 @@
flush_timer: Rc<RefCell<TimerAsync>>,
flush_timer_armed: Rc<RefCell<bool>>,
backend_req_conn: Arc<Mutex<VhostBackendReqConnectionState>>,
- workers: [Option<AbortHandle>; NUM_QUEUES as usize],
+ workers: [Option<WorkerState<Rc<RefCell<Queue>>, ()>>; NUM_QUEUES as usize],
}
impl VhostUserDevice for BlockAsync {
@@ -213,9 +216,9 @@
doorbell: Doorbell,
kick_evt: Event,
) -> anyhow::Result<()> {
- if let Some(handle) = self.workers.get_mut(idx).and_then(Option::take) {
+ if self.workers[idx].is_some() {
warn!("Starting new queue handler without stopping old handler");
- handle.abort();
+ self.stop_queue(idx)?;
}
let kick_evt = EventAsync::new(kick_evt, &self.ex)
@@ -225,28 +228,43 @@
let disk_state = Rc::clone(&self.disk_state);
let timer = Rc::clone(&self.flush_timer);
let timer_armed = Rc::clone(&self.flush_timer_armed);
- self.ex
- .spawn_local(Abortable::new(
- handle_queue(
- mem,
- disk_state,
- Rc::new(RefCell::new(queue)),
- kick_evt,
- doorbell,
- timer,
- timer_armed,
- ),
- registration,
- ))
- .detach();
+ let queue = Rc::new(RefCell::new(queue));
+ let queue_task = self.ex.spawn_local(Abortable::new(
+ handle_queue(
+ mem,
+ disk_state,
+ queue.clone(),
+ kick_evt,
+ doorbell,
+ timer,
+ timer_armed,
+ ),
+ registration,
+ ));
- self.workers[idx] = Some(handle);
+ self.workers[idx] = Some(WorkerState {
+ abort_handle: handle,
+ queue_task,
+ queue,
+ });
Ok(())
}
- fn stop_queue(&mut self, idx: usize) {
- if let Some(handle) = self.workers.get_mut(idx).and_then(Option::take) {
- handle.abort();
+ fn stop_queue(&mut self, idx: usize) -> anyhow::Result<virtio::Queue> {
+ if let Some(worker) = self.workers.get_mut(idx).and_then(Option::take) {
+ worker.abort_handle.abort();
+
+ // Wait for queue_task to be aborted.
+ let _ = self.ex.run_until(async { worker.queue_task.await });
+
+ let queue = match Rc::try_unwrap(worker.queue) {
+ Ok(queue_cell) => queue_cell.into_inner(),
+ Err(_) => panic!("failed to recover queue from worker"),
+ };
+
+ Ok(queue)
+ } else {
+ Err(anyhow::Error::new(DeviceError::WorkerNotFound))
}
}
diff --git a/devices/src/virtio/vhost/user/device/console.rs b/devices/src/virtio/vhost/user/device/console.rs
index e53e49a..ba9bb5c 100644
--- a/devices/src/virtio/vhost/user/device/console.rs
+++ b/devices/src/virtio/vhost/user/device/console.rs
@@ -3,6 +3,7 @@
// found in the LICENSE file.
use std::path::PathBuf;
+use std::sync::Arc;
use anyhow::anyhow;
use anyhow::bail;
@@ -14,6 +15,7 @@
use base::Terminal;
use cros_async::Executor;
use hypervisor::ProtectionType;
+use sync::Mutex;
use vm_memory::GuestMemory;
use vmm_vhost::message::VhostUserProtocolFeatures;
use vmm_vhost::message::VhostUserVirtioFeatures;
@@ -26,11 +28,13 @@
use crate::virtio::copy_config;
use crate::virtio::vhost::user::device::handler::sys::Doorbell;
use crate::virtio::vhost::user::device::handler::DeviceRequestHandler;
+use crate::virtio::vhost::user::device::handler::Error as DeviceError;
use crate::virtio::vhost::user::device::handler::VhostUserBackend;
use crate::virtio::vhost::user::device::handler::VhostUserPlatformOps;
use crate::virtio::vhost::user::device::listener::sys::VhostUserListener;
use crate::virtio::vhost::user::device::listener::VhostUserListenerTrait;
use crate::virtio::vhost::user::device::VhostUserDevice;
+use crate::virtio::Queue;
use crate::SerialHardware;
use crate::SerialParameters;
use crate::SerialType;
@@ -79,6 +83,8 @@
acked_features: 0,
acked_protocol_features: VhostUserProtocolFeatures::empty(),
ex: ex.clone(),
+ active_in_queue: None,
+ active_out_queue: None,
};
let handler = DeviceRequestHandler::new(Box::new(backend), ops);
@@ -91,6 +97,8 @@
acked_features: u64,
acked_protocol_features: VhostUserProtocolFeatures,
ex: Executor,
+ active_in_queue: Option<Arc<Mutex<Queue>>>,
+ active_out_queue: Option<Arc<Mutex<Queue>>>,
}
impl VhostUserBackend for ConsoleBackend {
@@ -143,7 +151,9 @@
fn reset(&mut self) {
for queue_num in 0..self.max_queue_num() {
- self.stop_queue(queue_num);
+ if let Err(e) = self.stop_queue(queue_num) {
+ error!("Failed to stop_queue during reset: {}", e);
+ }
}
}
@@ -155,35 +165,71 @@
doorbell: Doorbell,
kick_evt: Event,
) -> anyhow::Result<()> {
+ let queue = Arc::new(Mutex::new(queue));
match idx {
// ReceiveQueue
- 0 => self
- .device
- .console
- .start_receive_queue(&self.ex, mem, queue, doorbell, kick_evt),
+ 0 => {
+ let res = self.device.console.start_receive_queue(
+ &self.ex,
+ mem,
+ queue.clone(),
+ doorbell,
+ kick_evt,
+ );
+ self.active_in_queue = Some(queue);
+ res
+ }
// TransmitQueue
- 1 => self
- .device
- .console
- .start_transmit_queue(&self.ex, mem, queue, doorbell, kick_evt),
+ 1 => {
+ let res = self.device.console.start_transmit_queue(
+ &self.ex,
+ mem,
+ queue.clone(),
+ doorbell,
+ kick_evt,
+ );
+ self.active_out_queue = Some(queue);
+ res
+ }
_ => bail!("attempted to start unknown queue: {}", idx),
}
}
- fn stop_queue(&mut self, idx: usize) {
+ fn stop_queue(&mut self, idx: usize) -> anyhow::Result<virtio::Queue> {
match idx {
0 => {
if let Err(e) = self.device.console.stop_receive_queue() {
error!("error while stopping rx queue: {}", e);
}
+ if let Some(active_in_queue) = self.active_in_queue.take() {
+ let queue = match Arc::try_unwrap(active_in_queue) {
+ Ok(queue_mutex) => queue_mutex.into_inner(),
+ Err(_) => panic!("failed to recover queue from worker"),
+ };
+ Ok(queue)
+ } else {
+ Err(anyhow::Error::new(DeviceError::WorkerNotFound))
+ }
}
1 => {
if let Err(e) = self.device.console.stop_transmit_queue() {
error!("error while stopping tx queue: {}", e);
}
+ if let Some(active_out_queue) = self.active_out_queue.take() {
+ let queue = match Arc::try_unwrap(active_out_queue) {
+ Ok(queue_mutex) => queue_mutex.into_inner(),
+ Err(_) => panic!("failed to recover queue from worker"),
+ };
+ Ok(queue)
+ } else {
+ Err(anyhow::Error::new(DeviceError::WorkerNotFound))
+ }
}
- _ => error!("attempted to stop unknown queue: {}", idx),
- };
+ _ => {
+ error!("attempted to stop unknown queue: {}", idx);
+ Err(anyhow::Error::new(DeviceError::WorkerNotFound))
+ }
+ }
}
}
diff --git a/devices/src/virtio/vhost/user/device/fs.rs b/devices/src/virtio/vhost/user/device/fs.rs
index 765b7f9..65d98e3 100644
--- a/devices/src/virtio/vhost/user/device/fs.rs
+++ b/devices/src/virtio/vhost/user/device/fs.rs
@@ -4,7 +4,9 @@
mod sys;
+use std::cell::RefCell;
use std::path::PathBuf;
+use std::rc::Rc;
use std::sync::Arc;
use anyhow::anyhow;
@@ -39,12 +41,15 @@
use crate::virtio::fs::passthrough::PassthroughFs;
use crate::virtio::fs::process_fs_queue;
use crate::virtio::vhost::user::device::handler::sys::Doorbell;
+use crate::virtio::vhost::user::device::handler::Error as DeviceError;
use crate::virtio::vhost::user::device::handler::VhostUserBackend;
+use crate::virtio::vhost::user::device::handler::WorkerState;
+use crate::virtio::Queue;
const MAX_QUEUE_NUM: usize = 2; /* worker queue and high priority queue */
async fn handle_fs_queue(
- mut queue: virtio::Queue,
+ queue: Rc<RefCell<virtio::Queue>>,
mem: GuestMemory,
doorbell: Doorbell,
kick_evt: EventAsync,
@@ -59,7 +64,7 @@
error!("Failed to read kick event for fs queue: {}", e);
break;
}
- if let Err(e) = process_fs_queue(&mem, &doorbell, &mut queue, &server, &tube, slot) {
+ if let Err(e) = process_fs_queue(&mem, &doorbell, &queue, &server, &tube, slot) {
error!("Process FS queue failed: {}", e);
break;
}
@@ -73,7 +78,7 @@
avail_features: u64,
acked_features: u64,
acked_protocol_features: VhostUserProtocolFeatures,
- workers: [Option<AbortHandle>; MAX_QUEUE_NUM],
+ workers: [Option<WorkerState<Rc<RefCell<Queue>>, ()>>; MAX_QUEUE_NUM],
keep_rds: Vec<RawDescriptor>,
}
@@ -165,8 +170,8 @@
}
fn reset(&mut self) {
- for handle in self.workers.iter_mut().filter_map(Option::take) {
- handle.abort();
+ for worker in self.workers.iter_mut().filter_map(Option::take) {
+ worker.abort_handle.abort();
}
}
@@ -178,9 +183,9 @@
doorbell: Doorbell,
kick_evt: Event,
) -> anyhow::Result<()> {
- if let Some(handle) = self.workers.get_mut(idx).and_then(Option::take) {
+ if self.workers[idx].is_some() {
warn!("Starting new queue handler without stopping old handler");
- handle.abort();
+ self.stop_queue(idx)?;
}
let kick_evt = EventAsync::new(kick_evt, &self.ex)
@@ -188,27 +193,42 @@
let (handle, registration) = AbortHandle::new_pair();
let (_, fs_device_tube) = Tube::pair()?;
- self.ex
- .spawn_local(Abortable::new(
- handle_fs_queue(
- queue,
- mem,
- doorbell,
- kick_evt,
- self.server.clone(),
- Arc::new(Mutex::new(fs_device_tube)),
- ),
- registration,
- ))
- .detach();
+ let queue = Rc::new(RefCell::new(queue));
+ let queue_task = self.ex.spawn_local(Abortable::new(
+ handle_fs_queue(
+ queue.clone(),
+ mem,
+ doorbell,
+ kick_evt,
+ self.server.clone(),
+ Arc::new(Mutex::new(fs_device_tube)),
+ ),
+ registration,
+ ));
- self.workers[idx] = Some(handle);
+ self.workers[idx] = Some(WorkerState {
+ abort_handle: handle,
+ queue_task,
+ queue,
+ });
Ok(())
}
- fn stop_queue(&mut self, idx: usize) {
- if let Some(handle) = self.workers.get_mut(idx).and_then(Option::take) {
- handle.abort();
+ fn stop_queue(&mut self, idx: usize) -> anyhow::Result<virtio::Queue> {
+ if let Some(worker) = self.workers.get_mut(idx).and_then(Option::take) {
+ worker.abort_handle.abort();
+
+ // Wait for queue_task to be aborted.
+ let _ = self.ex.run_until(async { worker.queue_task.await });
+
+ let queue = match Rc::try_unwrap(worker.queue) {
+ Ok(queue_cell) => queue_cell.into_inner(),
+ Err(_) => panic!("failed to recover queue from worker"),
+ };
+
+ Ok(queue)
+ } else {
+ Err(anyhow::Error::new(DeviceError::WorkerNotFound))
}
}
}
diff --git a/devices/src/virtio/vhost/user/device/gpu.rs b/devices/src/virtio/vhost/user/device/gpu.rs
index ab4f360..1a006e8 100644
--- a/devices/src/virtio/vhost/user/device/gpu.rs
+++ b/devices/src/virtio/vhost/user/device/gpu.rs
@@ -28,9 +28,11 @@
use crate::virtio::gpu;
use crate::virtio::vhost::user::device::handler::sys::Doorbell;
+use crate::virtio::vhost::user::device::handler::Error as DeviceError;
use crate::virtio::vhost::user::device::handler::VhostBackendReqConnection;
use crate::virtio::vhost::user::device::handler::VhostBackendReqConnectionState;
use crate::virtio::vhost::user::device::handler::VhostUserBackend;
+use crate::virtio::vhost::user::device::handler::WorkerState;
use crate::virtio::DescriptorChain;
use crate::virtio::Gpu;
use crate::virtio::Queue;
@@ -88,7 +90,7 @@
acked_protocol_features: u64,
state: Option<Rc<RefCell<gpu::Frontend>>>,
fence_state: Arc<Mutex<gpu::FenceState>>,
- queue_workers: [Option<AbortHandle>; MAX_QUEUE_NUM],
+ queue_workers: [Option<WorkerState<Arc<Mutex<Queue>>, ()>>; MAX_QUEUE_NUM],
platform_workers: Rc<RefCell<Vec<AbortHandle>>>,
backend_req_conn: VhostBackendReqConnectionState,
}
@@ -148,9 +150,9 @@
doorbell: Doorbell,
kick_evt: Event,
) -> anyhow::Result<()> {
- if let Some(handle) = self.queue_workers.get_mut(idx).and_then(Option::take) {
+ if self.queue_workers[idx].is_some() {
warn!("Starting new queue handler without stopping old handler");
- handle.abort();
+ self.stop_queue(idx)?;
}
match idx {
@@ -164,8 +166,9 @@
let kick_evt = EventAsync::new(kick_evt, &self.ex)
.context("failed to create EventAsync for kick_evt")?;
+ let queue = Arc::new(Mutex::new(queue));
let reader = SharedReader {
- queue: Arc::new(Mutex::new(queue)),
+ queue: queue.clone(),
doorbell,
};
@@ -201,20 +204,34 @@
// Start handling the control queue.
let (handle, registration) = AbortHandle::new_pair();
- self.ex
- .spawn_local(Abortable::new(
- run_ctrl_queue(reader, mem, kick_evt, state),
- registration,
- ))
- .detach();
+ let queue_task = self.ex.spawn_local(Abortable::new(
+ run_ctrl_queue(reader, mem, kick_evt, state),
+ registration,
+ ));
- self.queue_workers[idx] = Some(handle);
+ self.queue_workers[idx] = Some(WorkerState {
+ abort_handle: handle,
+ queue_task,
+ queue,
+ });
Ok(())
}
- fn stop_queue(&mut self, idx: usize) {
- if let Some(handle) = self.queue_workers.get_mut(idx).and_then(Option::take) {
- handle.abort();
+ fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue> {
+ if let Some(worker) = self.queue_workers.get_mut(idx).and_then(Option::take) {
+ worker.abort_handle.abort();
+
+ // Wait for queue_task to be aborted.
+ let _ = self.ex.run_until(async { worker.queue_task.await });
+
+ let queue = match Arc::try_unwrap(worker.queue) {
+ Ok(queue_mutex) => queue_mutex.into_inner(),
+ Err(_) => panic!("failed to recover queue from worker"),
+ };
+
+ Ok(queue)
+ } else {
+ Err(anyhow::Error::new(DeviceError::WorkerNotFound))
}
}
@@ -224,7 +241,9 @@
}
for queue_num in 0..self.max_queue_num() {
- self.stop_queue(queue_num);
+ if let Err(e) = self.stop_queue(queue_num) {
+ error!("Failed to stop_queue during reset: {}", e);
+ }
}
}
diff --git a/devices/src/virtio/vhost/user/device/handler.rs b/devices/src/virtio/vhost/user/device/handler.rs
index 6eea9d1..dd7b2b2 100644
--- a/devices/src/virtio/vhost/user/device/handler.rs
+++ b/devices/src/virtio/vhost/user/device/handler.rs
@@ -67,7 +67,11 @@
use base::Protection;
use base::SafeDescriptor;
use base::SharedMemory;
+use cros_async::TaskHandle;
+use futures::future::AbortHandle;
+use futures::future::Aborted;
use sys::Doorbell;
+use thiserror::Error as ThisError;
use vm_control::VmMemorySource;
use vm_memory::GuestAddress;
use vm_memory::GuestMemory;
@@ -201,7 +205,8 @@
) -> anyhow::Result<()>;
/// Indicates that the backend should stop processing requests for virtio queue number `idx`.
- fn stop_queue(&mut self, idx: usize);
+ /// This method should return the queue passed to `start_queue` for the corresponding `idx`.
+ fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue>;
/// Resets the vhost-user backend.
fn reset(&mut self);
@@ -532,7 +537,9 @@
// that file descriptor is readable) on the descriptor specified by
// VHOST_USER_SET_VRING_KICK, and stop ring upon receiving
// VHOST_USER_GET_VRING_BASE.
- self.backend.stop_queue(index as usize);
+ if let Err(e) = self.backend.stop_queue(index as usize) {
+ error!("Failed to stop queue in get_vring_base: {}", e);
+ }
let vring = &mut self.vrings[index as usize];
vring.reset();
@@ -823,6 +830,19 @@
}
}
+/// Bookkeeping for a spawned per-queue worker task.
+///
+/// `T` is the shared handle to the `Queue` — e.g. `Arc<Mutex<Queue>>` or
+/// `Rc<RefCell<Queue>>` depending on the device's executor model — and `U`
+/// is the output type of the worker future.
+pub(crate) struct WorkerState<T, U> {
+ // Cancels the worker future.
+ pub(crate) abort_handle: AbortHandle,
+ // Awaited after aborting so the task finishes and drops its queue handle.
+ pub(crate) queue_task: TaskHandle<std::result::Result<U, Aborted>>,
+ // Shared queue handle; reclaimed via `try_unwrap` in `stop_queue`.
+ pub(crate) queue: T,
+}
+
+/// Errors for device operations
+#[derive(Debug, ThisError)]
+pub enum Error {
+ #[error("worker not found when stopping queue")]
+ WorkerNotFound,
+}
+
#[cfg(test)]
mod tests {
#[cfg(unix)]
@@ -932,7 +952,10 @@
Ok(())
}
- fn stop_queue(&mut self, _idx: usize) {}
+ fn stop_queue(&mut self, _idx: usize) -> anyhow::Result<Queue> {
+ // TODO(280607609): Return a `Queue`.
+ Err(anyhow!("Missing queue"))
+ }
}
#[cfg(unix)]
diff --git a/devices/src/virtio/vhost/user/device/net.rs b/devices/src/virtio/vhost/user/device/net.rs
index 1f74062..814b7dd 100644
--- a/devices/src/virtio/vhost/user/device/net.rs
+++ b/devices/src/virtio/vhost/user/device/net.rs
@@ -4,6 +4,8 @@
pub mod sys;
+use std::sync::Arc;
+
use anyhow::anyhow;
use anyhow::bail;
use anyhow::Context;
@@ -12,9 +14,9 @@
use cros_async::EventAsync;
use cros_async::Executor;
use cros_async::IntoAsync;
-use futures::future::AbortHandle;
use net_util::TapT;
use once_cell::sync::OnceCell;
+use sync::Mutex;
pub use sys::start_device as run_net_device;
pub use sys::Options;
use vm_memory::GuestMemory;
@@ -27,7 +29,10 @@
use crate::virtio::net::process_tx;
use crate::virtio::net::virtio_features_to_tap_offload;
use crate::virtio::vhost::user::device::handler::sys::Doorbell;
+use crate::virtio::vhost::user::device::handler::Error as DeviceError;
use crate::virtio::vhost::user::device::handler::VhostUserBackend;
+use crate::virtio::vhost::user::device::handler::WorkerState;
+use crate::virtio::Queue;
thread_local! {
pub(crate) static NET_EXECUTOR: OnceCell<Executor> = OnceCell::new();
@@ -38,7 +43,7 @@
const MAX_QUEUE_NUM: usize = 3; /* rx, tx, ctrl */
async fn run_tx_queue<T: TapT>(
- mut queue: virtio::Queue,
+ queue: Arc<Mutex<Queue>>,
mem: GuestMemory,
mut tap: T,
doorbell: Doorbell,
@@ -50,12 +55,12 @@
break;
}
- process_tx(&doorbell, &mut queue, &mem, &mut tap);
+ process_tx(&doorbell, &queue, &mem, &mut tap);
}
}
async fn run_ctrl_queue<T: TapT>(
- mut queue: virtio::Queue,
+ queue: Arc<Mutex<Queue>>,
mem: GuestMemory,
mut tap: T,
doorbell: Doorbell,
@@ -69,14 +74,7 @@
break;
}
- if let Err(e) = process_ctrl(
- &doorbell,
- &mut queue,
- &mem,
- &mut tap,
- acked_features,
- vq_pairs,
- ) {
+ if let Err(e) = process_ctrl(&doorbell, &queue, &mem, &mut tap, acked_features, vq_pairs) {
error!("Failed to process ctrl queue: {}", e);
break;
}
@@ -88,7 +86,7 @@
avail_features: u64,
acked_features: u64,
acked_protocol_features: VhostUserProtocolFeatures,
- workers: [Option<AbortHandle>; MAX_QUEUE_NUM],
+ workers: [Option<WorkerState<Arc<Mutex<Queue>>, ()>>; MAX_QUEUE_NUM],
mtu: u16,
#[cfg(all(windows, feature = "slirp"))]
slirp_kill_event: Event,
@@ -168,9 +166,24 @@
sys::start_queue(self, idx, queue, mem, doorbell, kick_evt)
}
- fn stop_queue(&mut self, idx: usize) {
- if let Some(handle) = self.workers.get_mut(idx).and_then(Option::take) {
- handle.abort();
+ fn stop_queue(&mut self, idx: usize) -> anyhow::Result<virtio::Queue> {
+ if let Some(worker) = self.workers.get_mut(idx).and_then(Option::take) {
+ worker.abort_handle.abort();
+
+ // Wait for queue_task to be aborted.
+ NET_EXECUTOR.with(|ex| {
+ let ex = ex.get().expect("Executor not initialized");
+ let _ = ex.run_until(async { worker.queue_task.await });
+ });
+
+ let queue = match Arc::try_unwrap(worker.queue) {
+ Ok(queue_mutex) => queue_mutex.into_inner(),
+ Err(_) => panic!("failed to recover queue from worker"),
+ };
+
+ Ok(queue)
+ } else {
+ Err(anyhow::Error::new(DeviceError::WorkerNotFound))
}
}
}
diff --git a/devices/src/virtio/vhost/user/device/net/sys/unix.rs b/devices/src/virtio/vhost/user/device/net/sys/unix.rs
index 0a63935..e24cf5e 100644
--- a/devices/src/virtio/vhost/user/device/net/sys/unix.rs
+++ b/devices/src/virtio/vhost/user/device/net/sys/unix.rs
@@ -4,6 +4,7 @@
use std::net::Ipv4Addr;
use std::str::FromStr;
+use std::sync::Arc;
use std::thread;
use anyhow::anyhow;
@@ -26,6 +27,7 @@
use net_util::sys::unix::Tap;
use net_util::MacAddress;
use net_util::TapT;
+use sync::Mutex;
use virtio_sys::virtio_net;
use vm_memory::GuestMemory;
use vmm_vhost::message::VhostUserProtocolFeatures;
@@ -37,6 +39,7 @@
use crate::virtio::net::NetError;
use crate::virtio::vhost::user::device::handler::sys::Doorbell;
use crate::virtio::vhost::user::device::handler::VhostUserBackend;
+use crate::virtio::vhost::user::device::handler::WorkerState;
use crate::virtio::vhost::user::device::listener::sys::VhostUserListener;
use crate::virtio::vhost::user::device::listener::VhostUserListenerTrait;
use crate::virtio::vhost::user::device::net::run_ctrl_queue;
@@ -135,7 +138,7 @@
}
async fn run_rx_queue<T: TapT>(
- mut queue: virtio::Queue,
+ queue: Arc<Mutex<virtio::Queue>>,
mem: GuestMemory,
mut tap: IoSource<T>,
doorbell: Doorbell,
@@ -146,7 +149,7 @@
error!("Failed to wait for tap device to become readable: {}", e);
break;
}
- match process_rx(&doorbell, &mut queue, &mem, tap.as_source_mut()) {
+ match process_rx(&doorbell, &queue, &mem, tap.as_source_mut()) {
Ok(()) => {}
Err(NetError::RxDescriptorsExhausted) => {
if let Err(e) = kick_evt.next_val().await {
@@ -171,9 +174,9 @@
doorbell: Doorbell,
kick_evt: Event,
) -> anyhow::Result<()> {
- if let Some(handle) = backend.workers.get_mut(idx).and_then(Option::take) {
+ if backend.workers[idx].is_some() {
warn!("Starting new queue handler without stopping old handler");
- handle.abort();
+ backend.stop_queue(idx)?;
}
NET_EXECUTOR.with(|ex| {
@@ -187,29 +190,26 @@
.try_clone()
.context("failed to clone tap device")?;
let (handle, registration) = AbortHandle::new_pair();
- match idx {
+ let queue = Arc::new(Mutex::new(queue));
+ let queue_task = match idx {
0 => {
let tap = ex
.async_from(tap)
.context("failed to create async tap device")?;
ex.spawn_local(Abortable::new(
- run_rx_queue(queue, mem, tap, doorbell, kick_evt),
+ run_rx_queue(queue.clone(), mem, tap, doorbell, kick_evt),
registration,
))
- .detach();
}
- 1 => {
- ex.spawn_local(Abortable::new(
- run_tx_queue(queue, mem, tap, doorbell, kick_evt),
- registration,
- ))
- .detach();
- }
+ 1 => ex.spawn_local(Abortable::new(
+ run_tx_queue(queue.clone(), mem, tap, doorbell, kick_evt),
+ registration,
+ )),
2 => {
ex.spawn_local(Abortable::new(
run_ctrl_queue(
- queue,
+ queue.clone(),
mem,
tap,
doorbell,
@@ -219,12 +219,15 @@
),
registration,
))
- .detach();
}
_ => bail!("attempted to start unknown queue: {}", idx),
- }
+ };
- backend.workers[idx] = Some(handle);
+ backend.workers[idx] = Some(WorkerState {
+ abort_handle: handle,
+ queue_task,
+ queue,
+ });
Ok(())
})
}
diff --git a/devices/src/virtio/vhost/user/device/net/sys/windows.rs b/devices/src/virtio/vhost/user/device/net/sys/windows.rs
index ea65d09..265d9c2 100644
--- a/devices/src/virtio/vhost/user/device/net/sys/windows.rs
+++ b/devices/src/virtio/vhost/user/device/net/sys/windows.rs
@@ -47,11 +47,14 @@
use crate::virtio::vhost::user::device::handler::sys::windows::run_handler;
use crate::virtio::vhost::user::device::handler::sys::Doorbell;
use crate::virtio::vhost::user::device::handler::DeviceRequestHandler;
+use crate::virtio::vhost::user::device::handler::VhostUserBackend;
use crate::virtio::vhost::user::device::handler::VhostUserRegularOps;
+use crate::virtio::vhost::user::device::handler::WorkerState;
use crate::virtio::vhost::user::device::net::run_ctrl_queue;
use crate::virtio::vhost::user::device::net::run_tx_queue;
use crate::virtio::vhost::user::device::net::NetBackend;
use crate::virtio::vhost::user::device::net::NET_EXECUTOR;
+use crate::virtio::Queue;
use crate::virtio::SignalableInterrupt;
impl<T: 'static> NetBackend<T>
@@ -81,7 +84,7 @@
}
async fn run_rx_queue<T: TapT>(
- mut queue: virtio::Queue,
+ queue: Arc<Mutex<virtio::Queue>>,
mem: GuestMemory,
mut tap: IoSource<T>,
call_evt: Doorbell,
@@ -101,6 +104,7 @@
.read_overlapped(&mut rx_buf, &mut overlapped_wrapper)
.expect("read_overlapped failed")
};
loop {
// If we already have a packet from deferred RX, we don't need to wait for the slirp device.
if !deferred_rx {
@@ -112,7 +116,7 @@
let needs_interrupt = process_rx(
&call_evt,
- &mut queue,
+ &queue,
&mem,
tap.as_source_mut(),
&mut rx_buf,
@@ -121,7 +125,12 @@
&mut overlapped_wrapper,
);
if needs_interrupt {
- call_evt.signal_used_queue(queue.vector());
+ call_evt.signal_used_queue(
+ queue
+ .try_lock()
+ .expect("Lock should not be unavailable")
+ .vector(),
+ );
}
// There aren't any RX descriptors available for us to write packets to. Wait for the guest
@@ -144,9 +153,9 @@
doorbell: Doorbell,
kick_evt: Event,
) -> anyhow::Result<()> {
- if let Some(handle) = backend.workers.get_mut(idx).and_then(Option::take) {
+ // `workers.get(idx).is_some()` is true for any in-bounds index even when
+ // the slot holds `None`; test the slot itself, as the unix implementation
+ // does, so the stop path only runs when a worker actually exists.
+ if backend.workers[idx].is_some() {
 warn!("Starting new queue handler without stopping old handler");
- handle.abort();
+ // Propagate failure instead of silently dropping the must-use Result.
+ backend.stop_queue(idx)?;
 }
let overlapped_wrapper =
@@ -163,7 +172,8 @@
.try_clone()
.context("failed to clone tap device")?;
let (handle, registration) = AbortHandle::new_pair();
- match idx {
+ let queue = Arc::new(Mutex::new(queue));
+ let queue_task = match idx {
0 => {
let tap = ex
.async_from(tap)
@@ -178,7 +188,7 @@
ex.spawn_local(Abortable::new(
run_rx_queue(
- queue,
+ queue.clone(),
mem,
tap,
doorbell,
@@ -188,19 +198,15 @@
),
registration,
))
- .detach();
}
- 1 => {
- ex.spawn_local(Abortable::new(
- run_tx_queue(queue, mem, tap, doorbell, kick_evt),
- registration,
- ))
- .detach();
- }
+ 1 => ex.spawn_local(Abortable::new(
+ run_tx_queue(queue.clone(), mem, tap, doorbell, kick_evt),
+ registration,
+ )),
2 => {
ex.spawn_local(Abortable::new(
run_ctrl_queue(
- queue,
+ queue.clone(),
mem,
tap,
doorbell,
@@ -210,12 +216,15 @@
),
registration,
))
- .detach();
}
_ => bail!("attempted to start unknown queue: {}", idx),
- }
+ };
- backend.workers[idx] = Some(handle);
+ backend.workers[idx] = Some(WorkerState {
+ abort_handle: handle,
+ queue_task,
+ queue,
+ });
Ok(())
})
}
diff --git a/devices/src/virtio/vhost/user/device/snd.rs b/devices/src/virtio/vhost/user/device/snd.rs
index 3c1cc4b..dcc5ffd 100644
--- a/devices/src/virtio/vhost/user/device/snd.rs
+++ b/devices/src/virtio/vhost/user/device/snd.rs
@@ -38,14 +38,18 @@
use crate::virtio::snd::common_backend::hardcoded_virtio_snd_config;
use crate::virtio::snd::common_backend::stream_info::StreamInfo;
use crate::virtio::snd::common_backend::stream_info::StreamInfoBuilder;
+use crate::virtio::snd::common_backend::Error;
use crate::virtio::snd::common_backend::PcmResponse;
use crate::virtio::snd::common_backend::SndData;
use crate::virtio::snd::common_backend::MAX_QUEUE_NUM;
use crate::virtio::snd::parameters::Parameters;
use crate::virtio::vhost::user::device::handler::sys::Doorbell;
use crate::virtio::vhost::user::device::handler::DeviceRequestHandler;
+use crate::virtio::vhost::user::device::handler::Error as DeviceError;
use crate::virtio::vhost::user::device::handler::VhostUserBackend;
+use crate::virtio::vhost::user::device::handler::WorkerState;
use crate::virtio::vhost::user::VhostUserDevice;
+use crate::virtio::Queue;
static SND_EXECUTOR: OnceCell<Executor> = OnceCell::new();
@@ -60,8 +64,9 @@
avail_features: u64,
acked_features: u64,
acked_protocol_features: VhostUserProtocolFeatures,
- workers: [Option<AbortHandle>; MAX_QUEUE_NUM],
- response_workers: [Option<AbortHandle>; 2], // tx and rx
+ workers: [Option<WorkerState<Rc<AsyncMutex<Queue>>, Result<(), Error>>>; MAX_QUEUE_NUM],
+ // tx and rx
+ response_workers: [Option<WorkerState<Rc<AsyncMutex<Queue>>, Result<(), Error>>>; 2],
snd_data: Rc<SndData>,
streams: Rc<AsyncMutex<Vec<AsyncMutex<StreamInfo>>>>,
tx_send: mpsc::UnboundedSender<PcmResponse>,
@@ -175,22 +180,22 @@
}
fn reset(&mut self) {
- for handle in self.workers.iter_mut().filter_map(Option::take) {
- handle.abort();
+ for worker in self.workers.iter_mut().filter_map(Option::take) {
+ worker.abort_handle.abort();
}
}
fn start_queue(
&mut self,
idx: usize,
- mut queue: virtio::Queue,
+ queue: virtio::Queue,
mem: GuestMemory,
doorbell: Doorbell,
kick_evt: Event,
) -> anyhow::Result<()> {
- if let Some(handle) = self.workers.get_mut(idx).and_then(Option::take) {
+ if self.workers[idx].is_some() {
warn!("Starting new queue handler without stopping old handler");
- handle.abort();
+ self.stop_queue(idx)?;
}
// Safe because the executor is initialized in main() below.
@@ -199,21 +204,23 @@
let mut kick_evt =
EventAsync::new(kick_evt, ex).context("failed to create EventAsync for kick_evt")?;
let (handle, registration) = AbortHandle::new_pair();
- match idx {
+ let queue = Rc::new(AsyncMutex::new(queue));
+ let queue_task = match idx {
0 => {
// ctrl queue
let streams = self.streams.clone();
let snd_data = self.snd_data.clone();
let tx_send = self.tx_send.clone();
let rx_send = self.rx_send.clone();
- ex.spawn_local(Abortable::new(
+ let ctrl_queue = queue.clone();
+ Some(ex.spawn_local(Abortable::new(
async move {
handle_ctrl_queue(
ex,
&mem,
&streams,
&snd_data,
- &mut queue,
+ ctrl_queue,
&mut kick_evt,
doorbell,
tx_send,
@@ -223,10 +230,9 @@
.await
},
registration,
- ))
- .detach();
+ )))
}
- 1 => {} // TODO(woodychow): Add event queue support
+ 1 => None, // TODO(woodychow): Add event queue support
2 | 3 => {
let (send, recv) = if idx == 2 {
(self.tx_send.clone(), self.tx_recv.take())
@@ -234,50 +240,85 @@
(self.rx_send.clone(), self.rx_recv.take())
};
let mut recv = recv.ok_or_else(|| anyhow!("queue restart is not supported"))?;
- let queue = Rc::new(AsyncMutex::new(queue));
- let queue2 = Rc::clone(&queue);
let mem = Rc::new(mem);
let mem2 = Rc::clone(&mem);
let streams = Rc::clone(&self.streams);
- ex.spawn_local(Abortable::new(
+ let queue_pcm_queue = queue.clone();
+ let queue_task = ex.spawn_local(Abortable::new(
async move {
- handle_pcm_queue(&mem, &streams, send, &queue, &kick_evt, None).await
+ handle_pcm_queue(&mem, &streams, send, queue_pcm_queue, &kick_evt, None)
+ .await
},
registration,
- ))
- .detach();
+ ));
let (handle2, registration2) = AbortHandle::new_pair();
- ex.spawn_local(Abortable::new(
+ let queue_response_queue = queue.clone();
+ let response_queue_task = ex.spawn_local(Abortable::new(
async move {
- send_pcm_response_worker(&mem2, &queue2, doorbell, &mut recv, None).await
+ send_pcm_response_worker(
+ &mem2,
+ queue_response_queue,
+ doorbell,
+ &mut recv,
+ None,
+ )
+ .await
},
registration2,
- ))
- .detach();
+ ));
- self.response_workers[idx - PCM_RESPONSE_WORKER_IDX_OFFSET] = Some(handle2);
+ self.response_workers[idx - PCM_RESPONSE_WORKER_IDX_OFFSET] = Some(WorkerState {
+ abort_handle: handle2,
+ queue_task: response_queue_task,
+ queue: queue.clone(),
+ });
+
+ Some(queue_task)
}
_ => bail!("attempted to start unknown queue: {}", idx),
- }
+ };
- self.workers[idx] = Some(handle);
+ if let Some(queue_task) = queue_task {
+ self.workers[idx] = Some(WorkerState {
+ abort_handle: handle,
+ queue_task,
+ queue,
+ });
+ }
Ok(())
}
- fn stop_queue(&mut self, idx: usize) {
- if let Some(handle) = self.workers.get_mut(idx).and_then(Option::take) {
- handle.abort();
+ fn stop_queue(&mut self, idx: usize) -> anyhow::Result<virtio::Queue> {
+ let ex = SND_EXECUTOR.get().expect("Executor not initialized");
+ if let Some(worker) = self.workers.get_mut(idx).and_then(Option::take) {
+ worker.abort_handle.abort();
+
+ // Wait for queue_task to be aborted.
+ let _ = ex.run_until(async { worker.queue_task.await });
}
if idx == 2 || idx == 3 {
- if let Some(handle) = self
+ if let Some(worker) = self
.response_workers
.get_mut(idx - PCM_RESPONSE_WORKER_IDX_OFFSET)
.and_then(Option::take)
{
- handle.abort();
+ worker.abort_handle.abort();
+
+ // Wait for queue_task to be aborted.
+ let _ = ex.run_until(async { worker.queue_task.await });
}
}
+ if let Some(worker) = self.workers.get_mut(idx).and_then(Option::take) {
+ let queue = match Rc::try_unwrap(worker.queue) {
+ Ok(queue_mutex) => queue_mutex.into_inner(),
+ Err(_) => panic!("failed to recover queue from worker"),
+ };
+
+ Ok(queue)
+ } else {
+ Err(anyhow::Error::new(DeviceError::WorkerNotFound))
+ }
}
}
diff --git a/devices/src/virtio/vhost/user/device/wl.rs b/devices/src/virtio/vhost/user/device/wl.rs
index 75de9af..871b386 100644
--- a/devices/src/virtio/vhost/user/device/wl.rs
+++ b/devices/src/virtio/vhost/user/device/wl.rs
@@ -41,9 +41,11 @@
use crate::virtio::device_constants::wl::VIRTIO_WL_F_TRANS_FLAGS;
use crate::virtio::device_constants::wl::VIRTIO_WL_F_USE_SHMEM;
use crate::virtio::vhost::user::device::handler::sys::Doorbell;
+use crate::virtio::vhost::user::device::handler::Error as DeviceError;
use crate::virtio::vhost::user::device::handler::VhostBackendReqConnection;
use crate::virtio::vhost::user::device::handler::VhostBackendReqConnectionState;
use crate::virtio::vhost::user::device::handler::VhostUserBackend;
+use crate::virtio::vhost::user::device::handler::WorkerState;
use crate::virtio::vhost::user::device::listener::sys::VhostUserListener;
use crate::virtio::vhost::user::device::listener::VhostUserListenerTrait;
use crate::virtio::wl;
@@ -53,7 +55,7 @@
const MAX_QUEUE_NUM: usize = QUEUE_SIZES.len();
async fn run_out_queue(
- mut queue: Queue,
+ queue: Rc<RefCell<Queue>>,
mem: GuestMemory,
doorbell: Doorbell,
kick_evt: EventAsync,
@@ -65,12 +67,12 @@
break;
}
- wl::process_out_queue(&doorbell, &mut queue, &mem, &mut wlstate.borrow_mut());
+ wl::process_out_queue(&doorbell, &queue, &mem, &mut wlstate.borrow_mut());
}
}
async fn run_in_queue(
- mut queue: Queue,
+ queue: Rc<RefCell<Queue>>,
mem: GuestMemory,
doorbell: Doorbell,
kick_evt: EventAsync,
@@ -86,7 +88,7 @@
break;
}
- if wl::process_in_queue(&doorbell, &mut queue, &mem, &mut wlstate.borrow_mut())
+ if wl::process_in_queue(&doorbell, &queue, &mem, &mut wlstate.borrow_mut())
== Err(wl::DescriptorsExhausted)
{
if let Err(e) = kick_evt.next_val().await {
@@ -107,7 +109,7 @@
features: u64,
acked_features: u64,
wlstate: Option<Rc<RefCell<wl::WlState>>>,
- workers: [Option<AbortHandle>; MAX_QUEUE_NUM],
+ workers: [Option<WorkerState<Rc<RefCell<Queue>>, ()>>; MAX_QUEUE_NUM],
backend_req_conn: VhostBackendReqConnectionState,
}
@@ -206,9 +208,9 @@
doorbell: Doorbell,
kick_evt: Event,
) -> anyhow::Result<()> {
- if let Some(handle) = self.workers.get_mut(idx).and_then(Option::take) {
+ if self.workers[idx].is_some() {
warn!("Starting new queue handler without stopping old handler");
- handle.abort();
+ self.stop_queue(idx)?;
}
let kick_evt = EventAsync::new(kick_evt, &self.ex)
@@ -259,7 +261,8 @@
Some(state) => state.clone(),
};
let (handle, registration) = AbortHandle::new_pair();
- match idx {
+ let queue = Rc::new(RefCell::new(queue));
+ let queue_task = match idx {
0 => {
let wlstate_ctx = clone_descriptor(wlstate.borrow().wait_ctx())
.map(|fd| {
@@ -273,35 +276,46 @@
.context("failed to create async WaitContext")
})?;
- self.ex
- .spawn_local(Abortable::new(
- run_in_queue(queue, mem, doorbell, kick_evt, wlstate, wlstate_ctx),
- registration,
- ))
- .detach();
+ self.ex.spawn_local(Abortable::new(
+ run_in_queue(queue.clone(), mem, doorbell, kick_evt, wlstate, wlstate_ctx),
+ registration,
+ ))
}
- 1 => {
- self.ex
- .spawn_local(Abortable::new(
- run_out_queue(queue, mem, doorbell, kick_evt, wlstate),
- registration,
- ))
- .detach();
- }
+ 1 => self.ex.spawn_local(Abortable::new(
+ run_out_queue(queue.clone(), mem, doorbell, kick_evt, wlstate),
+ registration,
+ )),
_ => bail!("attempted to start unknown queue: {}", idx),
- }
- self.workers[idx] = Some(handle);
+ };
+ self.workers[idx] = Some(WorkerState {
+ abort_handle: handle,
+ queue_task,
+ queue,
+ });
Ok(())
}
- fn stop_queue(&mut self, idx: usize) {
- if let Some(handle) = self.workers.get_mut(idx).and_then(Option::take) {
- handle.abort();
+ fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue> {
+ if let Some(worker) = self.workers.get_mut(idx).and_then(Option::take) {
+ worker.abort_handle.abort();
+
+ // Wait for queue_task to be aborted.
+ let _ = self.ex.run_until(async { worker.queue_task.await });
+
+ let queue = match Rc::try_unwrap(worker.queue) {
+ Ok(queue_cell) => queue_cell.into_inner(),
+ Err(_) => panic!("failed to recover queue from worker"),
+ };
+
+ Ok(queue)
+ } else {
+ Err(anyhow::Error::new(DeviceError::WorkerNotFound))
}
}
+
fn reset(&mut self) {
- for handle in self.workers.iter_mut().filter_map(Option::take) {
- handle.abort();
+ for worker in self.workers.iter_mut().filter_map(Option::take) {
+ worker.abort_handle.abort();
}
}
diff --git a/devices/src/virtio/wl.rs b/devices/src/virtio/wl.rs
index 071dc66..d4551ed 100644
--- a/devices/src/virtio/wl.rs
+++ b/devices/src/virtio/wl.rs
@@ -1684,7 +1684,7 @@
/// Handle incoming events and forward them to the VM over the input queue.
pub fn process_in_queue<I: SignalableInterrupt>(
interrupt: &I,
- in_queue: &mut Queue,
+ in_queue: &Rc<RefCell<Queue>>,
mem: &GuestMemory,
state: &mut WlState,
) -> ::std::result::Result<(), DescriptorsExhausted> {
@@ -1692,6 +1692,7 @@
let mut needs_interrupt = false;
let mut exhausted_queue = false;
+ let mut in_queue = in_queue.borrow_mut();
loop {
let mut desc = if let Some(d) = in_queue.peek(mem) {
d
@@ -1736,11 +1737,12 @@
/// Handle messages from the output queue and forward them to the display sever, if necessary.
pub fn process_out_queue<I: SignalableInterrupt>(
interrupt: &I,
- out_queue: &mut Queue,
+ out_queue: &Rc<RefCell<Queue>>,
mem: &GuestMemory,
state: &mut WlState,
) {
let mut needs_interrupt = false;
+ let mut out_queue = out_queue.borrow_mut();
while let Some(mut desc) = out_queue.pop(mem) {
let resp = match state.execute(&mut desc.reader) {
Ok(r) => r,
@@ -1767,9 +1769,9 @@
struct Worker {
interrupt: Interrupt,
mem: GuestMemory,
- in_queue: Queue,
+ in_queue: Rc<RefCell<Queue>>,
in_queue_evt: Event,
- out_queue: Queue,
+ out_queue: Rc<RefCell<Queue>>,
out_queue_evt: Event,
state: WlState,
}
@@ -1791,9 +1793,9 @@
Worker {
interrupt,
mem,
- in_queue: in_queue.0,
+ in_queue: Rc::new(RefCell::new(in_queue.0)),
in_queue_evt: in_queue.1,
- out_queue: out_queue.0,
+ out_queue: Rc::new(RefCell::new(out_queue.0)),
out_queue_evt: out_queue.1,
state: WlState::new(
wayland_paths,
@@ -1860,7 +1862,7 @@
let _ = self.out_queue_evt.wait();
process_out_queue(
&self.interrupt,
- &mut self.out_queue,
+ &self.out_queue,
&self.mem,
&mut self.state,
);
@@ -1869,7 +1871,7 @@
Token::State => {
if let Err(DescriptorsExhausted) = process_in_queue(
&self.interrupt,
- &mut self.in_queue,
+ &self.in_queue,
&self.mem,
&mut self.state,
) {
@@ -1892,8 +1894,18 @@
}
}
+ let in_queue = match Rc::try_unwrap(self.in_queue) {
+ Ok(queue_cell) => queue_cell.into_inner(),
+ Err(_) => panic!("failed to recover queue from worker"),
+ };
+
+ let out_queue = match Rc::try_unwrap(self.out_queue) {
+ Ok(queue_cell) => queue_cell.into_inner(),
+ Err(_) => panic!("failed to recover queue from worker"),
+ };
+
Ok(VirtioDeviceSaved {
- queues: vec![self.in_queue, self.out_queue],
+ queues: vec![in_queue, out_queue],
})
}
}
diff --git a/e2e_tests/tests/suspend_resume.rs b/e2e_tests/tests/suspend_resume.rs
index 7d86da7..d0a3e45 100644
--- a/e2e_tests/tests/suspend_resume.rs
+++ b/e2e_tests/tests/suspend_resume.rs
@@ -114,6 +114,7 @@
let _block_vu_device = VhostUserBackend::new(block_vu_config);
// Spin up net vhost user process.
+ // Queue handlers don't get activated currently.
let net_socket = NamedTempFile::new().unwrap();
let net_config = create_net_config(net_socket.path());
let _net_vu_device = VhostUserBackend::new(net_config).unwrap();