balloon: add event registration mechanism for sidecar processes.

This patch introduces a mechanism for so-called 'sidecar processes' to
register as event listeners, consuming events that are generated in
various parts of crosvm. In this patch we focus on events that enable
roziere cooperative ballooning, but lay the groundwork for a more
general framework.

The general idea is that a sidecar process opens a listening socket and
registers its path with crosvm for a specific event. crosvm passes a
handle to a registered event tube to the appropriate device (or other
internal component), which is then responsible for sending events over
it. Once events arrive back in the crosvm control loop, an attempt is
made to dispatch them to any registered sockets.

BUG=b:269609274
TEST=sidecar program that performs registration and receives events

Change-Id: Iaff41aad8f862ed99a104c75623caaabc53e9e88
Reviewed-on: https://chromium-review.googlesource.com/c/crosvm/crosvm/+/4237140
Commit-Queue: Maciek Swiech <drmasquatch@google.com>
Reviewed-by: Daniel Verkamp <dverkamp@chromium.org>
diff --git a/base/src/sys/unix/net.rs b/base/src/sys/unix/net.rs
index b267f38..a4c6bd8 100644
--- a/base/src/sys/unix/net.rs
+++ b/base/src/sys/unix/net.rs
@@ -361,7 +361,7 @@
 }
 
 /// A Unix `SOCK_SEQPACKET` socket point to given `path`
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
 pub struct UnixSeqpacket {
     #[serde(with = "super::with_raw_descriptor")]
     fd: RawFd,
diff --git a/crosvm_control/src/lib.rs b/crosvm_control/src/lib.rs
index 96629d5..648b7af 100644
--- a/crosvm_control/src/lib.rs
+++ b/crosvm_control/src/lib.rs
@@ -26,6 +26,7 @@
 use vm_control::BalloonControlCommand;
 use vm_control::BalloonStats;
 use vm_control::DiskControlCommand;
+use vm_control::RegisteredEvent;
 use vm_control::UsbControlAttachedDevice;
 use vm_control::UsbControlResult;
 use vm_control::VmRequest;
@@ -439,3 +440,76 @@
     })
     .unwrap_or(false)
 }
+
+/// Registers the connected process as a listener for `event`.
+#[no_mangle]
+pub extern "C" fn crosvm_client_register_events_listener(
+    socket_path: *const c_char,
+    listening_socket_path: *const c_char,
+    event: RegisteredEvent,
+) -> bool {
+    catch_unwind(|| {
+        if let Some(socket_path) = validate_socket_path(socket_path) {
+            if let Some(listening_socket_path) = validate_socket_path(listening_socket_path) {
+                let request = VmRequest::RegisterListener {
+                    event,
+                    socket_addr: listening_socket_path.to_str().unwrap().to_string(),
+                };
+                vms_request(&request, &socket_path).is_ok()
+            } else {
+                false
+            }
+        } else {
+            false
+        }
+    })
+    .unwrap_or(false)
+}
+
+/// Unregisters the connected process as a listener for `event`.
+#[no_mangle]
+pub extern "C" fn crosvm_client_unregister_events_listener(
+    socket_path: *const c_char,
+    listening_socket_path: *const c_char,
+    event: RegisteredEvent,
+) -> bool {
+    catch_unwind(|| {
+        if let Some(socket_path) = validate_socket_path(socket_path) {
+            if let Some(listening_socket_path) = validate_socket_path(listening_socket_path) {
+                let request = VmRequest::UnregisterListener {
+                    event,
+                    socket_addr: listening_socket_path.to_str().unwrap().to_string(),
+                };
+                vms_request(&request, &socket_path).is_ok()
+            } else {
+                false
+            }
+        } else {
+            false
+        }
+    })
+    .unwrap_or(false)
+}
+
+/// Unregisters the connected process as a listener for all events.
+#[no_mangle]
+pub extern "C" fn crosvm_client_unregister_listener(
+    socket_path: *const c_char,
+    listening_socket_path: *const c_char,
+) -> bool {
+    catch_unwind(|| {
+        if let Some(socket_path) = validate_socket_path(socket_path) {
+            if let Some(listening_socket_path) = validate_socket_path(listening_socket_path) {
+                let request = VmRequest::Unregister {
+                    socket_addr: listening_socket_path.to_str().unwrap().to_string(),
+                };
+                vms_request(&request, &socket_path).is_ok()
+            } else {
+                false
+            }
+        } else {
+            false
+        }
+    })
+    .unwrap_or(false)
+}
diff --git a/devices/src/virtio/balloon.rs b/devices/src/virtio/balloon.rs
index 2961f0d..6c1052f 100644
--- a/devices/src/virtio/balloon.rs
+++ b/devices/src/virtio/balloon.rs
@@ -17,6 +17,7 @@
 use base::AsRawDescriptor;
 use base::Event;
 use base::RawDescriptor;
+use base::SendTube;
 use base::Tube;
 use base::WorkerThread;
 use cros_async::block_on;
@@ -25,6 +26,7 @@
 use cros_async::AsyncTube;
 use cros_async::EventAsync;
 use cros_async::Executor;
+use cros_async::SendTubeAsync;
 use data_model::Le16;
 use data_model::Le32;
 use data_model::Le64;
@@ -34,6 +36,7 @@
 use futures::StreamExt;
 use remain::sorted;
 use thiserror::Error as ThisError;
+use vm_control::RegisteredEvent;
 use vm_memory::GuestAddress;
 use vm_memory::GuestMemory;
 use zerocopy::AsBytes;
@@ -509,6 +512,7 @@
     interrupt: Interrupt,
     state: Arc<AsyncMutex<BalloonState>>,
     mut stats_tx: mpsc::Sender<u64>,
+    registered_evt_q: Option<SendTubeAsync>,
 ) -> Result<()> {
     loop {
         match command_tube.next().await {
@@ -532,6 +536,15 @@
                             state.failable_update = true;
                         }
                     }
+
+                    if let Some(registered_evt_q) = &registered_evt_q {
+                        if let Err(e) = registered_evt_q
+                            .send(&RegisteredEvent::VirtioBalloonResize)
+                            .await
+                        {
+                            error!("failed to send VirtioBalloonResize event: {}", e);
+                        }
+                    }
                 }
                 BalloonTubeCommand::Stats { id } => {
                     if let Err(e) = stats_tx.try_send(id) {
@@ -580,9 +593,13 @@
     pending_adjusted_response_event: Event,
     mem: GuestMemory,
     state: Arc<AsyncMutex<BalloonState>>,
-) -> (Option<Tube>, Tube) {
+    registered_evt_q: Option<SendTube>,
+) -> (Option<Tube>, Tube, Option<SendTube>) {
     let ex = Executor::new().unwrap();
     let command_tube = AsyncTube::new(&ex, command_tube).unwrap();
+    let registered_evt_q_async = registered_evt_q
+        .as_ref()
+        .map(|q| SendTubeAsync::new(q.try_clone().unwrap(), &ex).unwrap());
 
     // We need a block to release all references to command_tube at the end before returning it.
     {
@@ -670,8 +687,13 @@
         pin_mut!(reporting);
 
         // Future to handle command messages that resize the balloon.
-        let command =
-            handle_command_tube(&command_tube, interrupt.clone(), state.clone(), stats_tx);
+        let command = handle_command_tube(
+            &command_tube,
+            interrupt.clone(),
+            state.clone(),
+            stats_tx,
+            registered_evt_q_async,
+        );
         pin_mut!(command);
 
         // Process any requests to resample the irq value.
@@ -724,7 +746,7 @@
         }
     }
 
-    (release_memory_tube, command_tube.into())
+    (release_memory_tube, command_tube.into(), registered_evt_q)
 }
 
 /// Virtio device for memory balloon inflation/deflation.
@@ -737,7 +759,8 @@
     state: Arc<AsyncMutex<BalloonState>>,
     features: u64,
     acked_features: u64,
-    worker_thread: Option<WorkerThread<(Option<Tube>, Tube)>>,
+    worker_thread: Option<WorkerThread<(Option<Tube>, Tube, Option<SendTube>)>>,
+    registered_evt_q: Option<SendTube>,
 }
 
 /// Operation mode of the balloon.
@@ -763,6 +786,7 @@
         init_balloon_size: u64,
         mode: BalloonMode,
         enabled_features: u64,
+        registered_evt_q: Option<SendTube>,
     ) -> Result<Balloon> {
         let features = base_features
             | 1 << VIRTIO_BALLOON_F_MUST_TELL_HOST
@@ -790,6 +814,7 @@
             worker_thread: None,
             features,
             acked_features: 0,
+            registered_evt_q,
         })
     }
 
@@ -899,6 +924,7 @@
         #[cfg(windows)]
         let mapping_tube = self.dynamic_mapping_tube.take().unwrap();
         let release_memory_tube = self.release_memory_tube.take();
+        let registered_evt_q = self.registered_evt_q.take();
         let pending_adjusted_response_event = self
             .pending_adjusted_response_event
             .try_clone()
@@ -920,6 +946,7 @@
                 pending_adjusted_response_event,
                 mem,
                 state,
+                registered_evt_q,
             )
         }));
 
@@ -928,9 +955,10 @@
 
     fn reset(&mut self) -> bool {
         if let Some(worker_thread) = self.worker_thread.take() {
-            let (release_memory_tube, command_tube) = worker_thread.stop();
+            let (release_memory_tube, command_tube, registered_evt_q) = worker_thread.stop();
             self.release_memory_tube = release_memory_tube;
             self.command_tube = Some(command_tube);
+            self.registered_evt_q = registered_evt_q;
             return true;
         }
         false
diff --git a/src/crosvm/sys/unix.rs b/src/crosvm/sys/unix.rs
index 3494a74..3ed6f98 100644
--- a/src/crosvm/sys/unix.rs
+++ b/src/crosvm/sys/unix.rs
@@ -15,6 +15,8 @@
 use std::cmp::Reverse;
 use std::collections::BTreeMap;
 use std::collections::BTreeSet;
+use std::collections::HashMap;
+use std::collections::HashSet;
 use std::convert::TryInto;
 use std::ffi::CString;
 use std::fs::File;
@@ -202,6 +204,7 @@
     >,
     vvu_proxy_device_tubes: &mut Vec<Tube>,
     vvu_proxy_max_sibling_mem_size: u64,
+    registered_evt_q: &SendTube,
 ) -> DeviceResult<Vec<VirtioDeviceStub>> {
     let mut devs = Vec::new();
 
@@ -480,6 +483,11 @@
             balloon_inflate_tube,
             init_balloon_size,
             balloon_features,
+            Some(
+                registered_evt_q
+                    .try_clone()
+                    .context("failed to clone registered_evt_q tube")?,
+            ),
         )?);
     }
 
@@ -707,6 +715,7 @@
     vvu_proxy_device_tubes: &mut Vec<Tube>,
     vvu_proxy_max_sibling_mem_size: u64,
     iova_max_addr: &mut Option<u64>,
+    registered_evt_q: &SendTube,
 ) -> DeviceResult<Vec<(Box<dyn BusDeviceObj>, Option<Minijail>)>> {
     let mut devices: Vec<(Box<dyn BusDeviceObj>, Option<Minijail>)> = Vec::new();
     #[cfg(feature = "balloon")]
@@ -837,6 +846,7 @@
         render_server_fd,
         vvu_proxy_device_tubes,
         vvu_proxy_max_sibling_mem_size,
+        registered_evt_q,
     )?;
 
     for stub in stubs {
@@ -1706,6 +1716,10 @@
     let mut iommu_attached_endpoints: BTreeMap<u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>> =
         BTreeMap::new();
     let mut iova_max_addr: Option<u64> = None;
+
+    let (reg_evt_wrtube, reg_evt_rdtube) =
+        Tube::directional_pair().context("failed to create registered event tube")?;
+
     let mut devices = create_devices(
         &cfg,
         &mut vm,
@@ -1730,6 +1744,7 @@
         &mut vvu_proxy_device_tubes,
         components.memory_size,
         &mut iova_max_addr,
+        &reg_evt_wrtube,
     )?;
 
     #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
@@ -1926,6 +1941,7 @@
         hp_thread,
         #[cfg(feature = "swap")]
         swap_controller,
+        reg_evt_rdtube,
     )
 }
 
@@ -2397,6 +2413,7 @@
     >,
     #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] hp_thread: std::thread::JoinHandle<()>,
     #[cfg(feature = "swap")] swap_controller: Option<SwapController>,
+    reg_evt_rdtube: RecvTube,
 ) -> Result<ExitState> {
     #[derive(EventToken)]
     enum Token {
@@ -2405,6 +2422,7 @@
         ChildSignal,
         VmControlServer,
         VmControl { index: usize },
+        RegisteredEvent,
     }
 
     let mut iommu_client = iommu_host_tube
@@ -2421,6 +2439,7 @@
         (&linux.suspend_evt, Token::Suspend),
         (&sigchld_fd, Token::ChildSignal),
         (&vm_evt_rdtube, Token::VmEvent),
+        (&reg_evt_rdtube, Token::RegisteredEvent),
     ])
     .context("failed to build wait context")?;
 
@@ -2650,6 +2669,8 @@
     let mut pvpanic_code = PvPanicCode::Unknown;
     #[cfg(feature = "balloon")]
     let mut balloon_stats_id: u64 = 0;
+    let mut registered_evt_sockets: HashMap<RegisteredEvent, HashSet<UnixSeqpacket>> =
+        HashMap::new();
 
     'wait: loop {
         let events = {
@@ -2665,6 +2686,22 @@
         let mut vm_control_indices_to_remove = Vec::new();
         for event in events.iter().filter(|e| e.is_readable) {
             match event.token {
+                Token::RegisteredEvent => match reg_evt_rdtube.recv::<RegisteredEvent>() {
+                    Ok(reg_evt) => {
+                        if let Some(sockets) = registered_evt_sockets.get_mut(&reg_evt) {
+                            for socket in sockets.iter() {
+                                let tube =
+                                    Tube::new_from_unix_seqpacket(socket.try_clone().unwrap());
+                                if let Err(e) = tube.send(&reg_evt) {
+                                    warn!("failed to send registered event: {}", e);
+                                }
+                            }
+                        }
+                    }
+                    Err(e) => {
+                        warn!("failed to recv RegisteredEvent: {}", e);
+                    }
+                },
                 Token::VmEvent => {
                     let mut break_to_wait: bool = true;
                     match vm_evt_rdtube.recv::<VmEventType>() {
@@ -2807,6 +2844,44 @@
                                                 VmResponse::Ok
                                             }
                                         }
+                                        VmRequest::RegisterListener { socket_addr, event } => {
+                                            if let Ok(socket) = UnixSeqpacket::connect(socket_addr)
+                                            {
+                                                if let Some(sockets) =
+                                                    registered_evt_sockets.get_mut(&event)
+                                                {
+                                                    sockets.insert(socket);
+                                                } else {
+                                                    registered_evt_sockets.insert(
+                                                        event,
+                                                        vec![socket].into_iter().collect(),
+                                                    );
+                                                }
+                                            }
+                                            VmResponse::Ok
+                                        }
+                                        VmRequest::UnregisterListener { socket_addr, event } => {
+                                            if let Ok(socket) = UnixSeqpacket::connect(socket_addr)
+                                            {
+                                                if let Some(sockets) =
+                                                    registered_evt_sockets.get_mut(&event)
+                                                {
+                                                    sockets.remove(&socket);
+                                                }
+                                            }
+                                            VmResponse::Ok
+                                        }
+                                        VmRequest::Unregister { socket_addr } => {
+                                            if let Ok(socket) = UnixSeqpacket::connect(socket_addr)
+                                            {
+                                                for (_, sockets) in
+                                                    registered_evt_sockets.iter_mut()
+                                                {
+                                                    sockets.remove(&socket);
+                                                }
+                                            }
+                                            VmResponse::Ok
+                                        }
                                         _ => {
                                             let response = request.execute(
                                                 &mut run_mode_opt,
diff --git a/src/crosvm/sys/unix/device_helpers.rs b/src/crosvm/sys/unix/device_helpers.rs
index 60a64dd..05424d5 100644
--- a/src/crosvm/sys/unix/device_helpers.rs
+++ b/src/crosvm/sys/unix/device_helpers.rs
@@ -708,6 +708,7 @@
     inflate_tube: Option<Tube>,
     init_balloon_size: u64,
     enabled_features: u64,
+    registered_evt_q: Option<SendTube>,
 ) -> DeviceResult {
     let dev = virtio::Balloon::new(
         virtio::base_features(protection_type),
@@ -716,6 +717,7 @@
         init_balloon_size,
         mode,
         enabled_features,
+        registered_evt_q,
     )
     .context("failed to create balloon")?;
 
diff --git a/src/sys/windows.rs b/src/sys/windows.rs
index 0afcc62..8a824d7 100644
--- a/src/sys/windows.rs
+++ b/src/sys/windows.rs
@@ -445,6 +445,7 @@
             BalloonMode::Relaxed
         },
         balloon_features,
+        None,
     )
     .exit_context(Exit::BalloonDeviceNew, "failed to create balloon")?;
 
diff --git a/vm_control/src/lib.rs b/vm_control/src/lib.rs
index 86cb661..d325075 100644
--- a/vm_control/src/lib.rs
+++ b/vm_control/src/lib.rs
@@ -1041,6 +1041,26 @@
     Snapshot(SnapshotCommand),
     /// Command to Restore devices
     Restore(RestoreCommand),
+    /// Register for event notification
+    RegisterListener {
+        socket_addr: String,
+        event: RegisteredEvent,
+    },
+    /// Unregister for notifications for event
+    UnregisterListener {
+        socket_addr: String,
+        event: RegisteredEvent,
+    },
+    /// Unregister for all event notification
+    Unregister { socket_addr: String },
+}
+
+#[repr(C)]
+#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone, Copy)]
+pub enum RegisteredEvent {
+    VirtioBalloonWssReport,
+    VirtioBalloonResize,
+    VirtioBalloonOOMDeflation,
 }
 
 pub fn handle_disk_command(command: &DiskControlCommand, disk_host_tube: &Tube) -> VmResponse {
@@ -1534,6 +1554,15 @@
                     }
                 }
             }
+            VmRequest::RegisterListener {
+                socket_addr: _,
+                event: _,
+            } => VmResponse::Ok,
+            VmRequest::UnregisterListener {
+                socket_addr: _,
+                event: _,
+            } => VmResponse::Ok,
+            VmRequest::Unregister { socket_addr: _ } => VmResponse::Ok,
         }
     }
 }