secure_env: add suspend-resume support to rust impl

This mimics the logic of the C++ suspend-resume code, but to reduce FFI
surface area, instead of using a few different primitives (two event_fds
and a condition variable), it uses one socket pair and a simple
protocol.

Bug: 295028759
Test: cvd suspend; cvd resume
Change-Id: I55628436e9236efa08d85f101b8684a6d2093547
Merged-In: I55628436e9236efa08d85f101b8684a6d2093547
diff --git a/host/commands/secure_env/rust/Android.bp b/host/commands/secure_env/rust/Android.bp
index 176430e..c5d6b01 100644
--- a/host/commands/secure_env/rust/Android.bp
+++ b/host/commands/secure_env/rust/Android.bp
@@ -38,6 +38,7 @@
         "libkmr_wire",
         "liblibc",
         "liblog_rust",
+        "libnix",
         "libprotobuf_deprecated",
         "libsecure_env_tpm",
     ],
@@ -71,6 +72,7 @@
         "libkmr_wire",
         "liblibc",
         "liblog_rust",
+        "libnix",
         "libprotobuf_deprecated",
         "libsecure_env_tpm",
     ],
diff --git a/host/commands/secure_env/rust/ffi.rs b/host/commands/secure_env/rust/ffi.rs
index e6360cd..9ffed7b 100644
--- a/host/commands/secure_env/rust/ffi.rs
+++ b/host/commands/secure_env/rust/ffi.rs
@@ -20,18 +20,21 @@
 use kmr_wire::keymint::SecurityLevel;
 use libc::c_int;
 use log::error;
+use std::os::fd::FromRawFd;
 
 /// FFI wrapper around [`kmr_cf::ta_main`].
 ///
 /// # Safety
 ///
-/// `fd_in` and `fd_out` must be valid and open file descriptors.
+/// `fd_in`, `fd_out`, and `snapshot_socket_fd` must be valid and open file descriptors and the
+/// caller must not use or close them after the call.
 #[no_mangle]
 pub unsafe extern "C" fn kmr_ta_main(
     fd_in: c_int,
     fd_out: c_int,
     security_level: c_int,
     trm: *mut libc::c_void,
+    snapshot_socket_fd: c_int,
 ) {
     let security_level = match security_level {
         x if x == SecurityLevel::TrustedEnvironment as i32 => SecurityLevel::TrustedEnvironment,
@@ -42,6 +45,10 @@
             SecurityLevel::Software
         }
     };
+    let snapshot_socket =
+        // SAFETY: fd being valid and open and exclusive is asserted in the unsafe function's
+        // preconditions, so this is pushed up to the caller.
+        unsafe { std::os::unix::net::UnixStream::from_raw_fd(snapshot_socket_fd) };
     // SAFETY: The caller guarantees that `fd_in` and `fd_out` are valid and open.
-    unsafe { kmr_cf::ta_main(fd_in, fd_out, security_level, trm) }
+    unsafe { kmr_cf::ta_main(fd_in, fd_out, security_level, trm, snapshot_socket) }
 }
diff --git a/host/commands/secure_env/rust/kmr_ta.h b/host/commands/secure_env/rust/kmr_ta.h
index ce70873..efa9f2d 100644
--- a/host/commands/secure_env/rust/kmr_ta.h
+++ b/host/commands/secure_env/rust/kmr_ta.h
@@ -26,7 +26,10 @@
 //   values from SecurityLevel.aidl.
 // - trm: pointer to a valid `TpmResourceManager`, which must remain valid
 //   for the entire duration of the function execution.
-void kmr_ta_main(int fd_in, int fd_out, int security_level, void* trm);
+// - snapshot_socket_fd: file descriptor for a socket used to communicate with
+//   the secure_env suspend-resume handler thread.
+void kmr_ta_main(int fd_in, int fd_out, int security_level, void* trm,
+                 int snapshot_socket_fd);
 
 #ifdef __cplusplus
 }
diff --git a/host/commands/secure_env/rust/lib.rs b/host/commands/secure_env/rust/lib.rs
index 49508b9..eb2536b 100644
--- a/host/commands/secure_env/rust/lib.rs
+++ b/host/commands/secure_env/rust/lib.rs
@@ -32,6 +32,7 @@
 use log::{error, info, trace};
 use std::ffi::CString;
 use std::io::{Read, Write};
+use std::os::fd::AsRawFd;
 use std::os::unix::{ffi::OsStrExt, io::FromRawFd};
 
 pub mod attest;
@@ -44,16 +45,23 @@
 #[cfg(test)]
 mod tests;
 
+// See `SnapshotSocketMessage` in suspend_resume_handler.h for docs.
+const SNAPSHOT_SOCKET_MESSAGE_SUSPEND: u8 = 1;
+const SNAPSHOT_SOCKET_MESSAGE_SUSPEND_ACK: u8 = 2;
+const SNAPSHOT_SOCKET_MESSAGE_RESUME: u8 = 3;
+
 /// Main routine for the KeyMint TA. Only returns if there is a fatal error.
 ///
 /// # Safety
 ///
-/// `fd_in` and `fd_out` must be valid and open file descriptors.
+/// `fd_in` and `fd_out` must be valid and open file descriptors and the caller must not use or
+/// close them after the call.
 pub unsafe fn ta_main(
     fd_in: c_int,
     fd_out: c_int,
     security_level: SecurityLevel,
     trm: *mut libc::c_void,
+    mut snapshot_socket: std::os::unix::net::UnixStream,
 ) {
     log::set_logger(&AndroidCppLogger).unwrap();
     log::set_max_level(log::LevelFilter::Debug); // Filtering happens elsewhere
@@ -62,9 +70,9 @@
         fd_in, fd_out, security_level,
     );
 
-    // SAFETY: The caller guarantees that `fd_in` is valid and open.
+    // SAFETY: The caller guarantees that `fd_in` is valid and open and exclusive.
     let mut infile = unsafe { std::fs::File::from_raw_fd(fd_in) };
-    // SAFETY: The caller guarantees that `fd_out` is valid and open.
+    // SAFETY: The caller guarantees that `fd_out` is valid and open and exclusive.
     let mut outfile = unsafe { std::fs::File::from_raw_fd(fd_out) };
 
     let hw_info = HardwareInfo {
@@ -147,52 +155,107 @@
 
     let mut buf = [0; kmr_wire::DEFAULT_MAX_SIZE];
     loop {
-        // Read a request message from the pipe, as a 4-byte BE length followed by the message.
-        let mut req_len_data = [0u8; 4];
-        if let Err(e) = infile.read_exact(&mut req_len_data) {
-            error!("FATAL: Failed to read request length from connection: {:?}", e);
-            return;
-        }
-        let req_len = u32::from_be_bytes(req_len_data) as usize;
-        if req_len > kmr_wire::DEFAULT_MAX_SIZE {
-            error!("FATAL: Request too long ({})", req_len);
-            return;
-        }
-        let req_data = &mut buf[..req_len];
-        if let Err(e) = infile.read_exact(req_data) {
-            error!(
-                "FATAL: Failed to read request data of length {} from connection: {:?}",
-                req_len, e
-            );
+        // Wait for data from either `infile` or `snapshot_socket`. If both have data, we prioritize
+        // processing only `infile` until it is empty so that there is no pending state when we
+        // suspend the loop.
+        let mut fd_set = nix::sys::select::FdSet::new();
+        fd_set.insert(infile.as_raw_fd());
+        fd_set.insert(snapshot_socket.as_raw_fd());
+        if let Err(e) = nix::sys::select::select(
+            None,
+            /*readfds=*/ Some(&mut fd_set),
+            None,
+            None,
+            /*timeout=*/ None,
+        ) {
+            error!("FATAL: Failed to select on input FDs: {:?}", e);
             return;
         }
 
-        // Pass to the TA to process.
-        trace!("-> TA: received data: (len={})", req_data.len());
-        let rsp = ta.process(req_data);
-        trace!("<- TA: send data: (len={})", rsp.len());
-
-        // Send the response message down the pipe, as a 4-byte BE length followed by the message.
-        let rsp_len: u32 = match rsp.len().try_into() {
-            Ok(l) => l,
-            Err(_e) => {
-                error!("FATAL: Response too long (len={})", rsp.len());
+        if fd_set.contains(infile.as_raw_fd()) {
+            // Read a request message from the pipe, as a 4-byte BE length followed by the message.
+            let mut req_len_data = [0u8; 4];
+            if let Err(e) = infile.read_exact(&mut req_len_data) {
+                error!("FATAL: Failed to read request length from connection: {:?}", e);
                 return;
             }
-        };
-        let rsp_len_data = rsp_len.to_be_bytes();
-        if let Err(e) = outfile.write_all(&rsp_len_data[..]) {
-            error!("FATAL: Failed to write response length to connection: {:?}", e);
-            return;
+            let req_len = u32::from_be_bytes(req_len_data) as usize;
+            if req_len > kmr_wire::DEFAULT_MAX_SIZE {
+                error!("FATAL: Request too long ({})", req_len);
+                return;
+            }
+            let req_data = &mut buf[..req_len];
+            if let Err(e) = infile.read_exact(req_data) {
+                error!(
+                    "FATAL: Failed to read request data of length {} from connection: {:?}",
+                    req_len, e
+                );
+                return;
+            }
+
+            // Pass to the TA to process.
+            trace!("-> TA: received data: (len={})", req_data.len());
+            let rsp = ta.process(req_data);
+            trace!("<- TA: send data: (len={})", rsp.len());
+
+            // Send the response message down the pipe, as a 4-byte BE length followed by the message.
+            let rsp_len: u32 = match rsp.len().try_into() {
+                Ok(l) => l,
+                Err(_e) => {
+                    error!("FATAL: Response too long (len={})", rsp.len());
+                    return;
+                }
+            };
+            let rsp_len_data = rsp_len.to_be_bytes();
+            if let Err(e) = outfile.write_all(&rsp_len_data[..]) {
+                error!("FATAL: Failed to write response length to connection: {:?}", e);
+                return;
+            }
+            if let Err(e) = outfile.write_all(&rsp) {
+                error!(
+                    "FATAL: Failed to write response data of length {} to connection: {:?}",
+                    rsp_len, e
+                );
+                return;
+            }
+            let _ = outfile.flush();
+
+            continue;
         }
-        if let Err(e) = outfile.write_all(&rsp) {
-            error!(
-                "FATAL: Failed to write response data of length {} to connection: {:?}",
-                rsp_len, e
-            );
-            return;
+
+        if fd_set.contains(snapshot_socket.as_raw_fd()) {
+            // Read suspend request.
+            let mut suspend_request = 0u8;
+            if let Err(e) = snapshot_socket.read_exact(std::slice::from_mut(&mut suspend_request)) {
+                error!("FATAL: Failed to read suspend request: {:?}", e);
+                return;
+            }
+            if suspend_request != SNAPSHOT_SOCKET_MESSAGE_SUSPEND {
+                error!(
+                    "FATAL: Unexpected value from snapshot socket: got {}, expected {}",
+                    suspend_request, SNAPSHOT_SOCKET_MESSAGE_SUSPEND
+                );
+                return;
+            }
+            // Write ACK.
+            if let Err(e) = snapshot_socket.write_all(&[SNAPSHOT_SOCKET_MESSAGE_SUSPEND_ACK]) {
+                error!("FATAL: Failed to write suspend ACK request: {:?}", e);
+                return;
+            }
+            // Block until we get a resume request.
+            let mut resume_request = 0u8;
+            if let Err(e) = snapshot_socket.read_exact(std::slice::from_mut(&mut resume_request)) {
+                error!("FATAL: Failed to read resume request: {:?}", e);
+                return;
+            }
+            if resume_request != SNAPSHOT_SOCKET_MESSAGE_RESUME {
+                error!(
+                    "FATAL: Unexpected value from snapshot socket: got {}, expected {}",
+                    resume_request, SNAPSHOT_SOCKET_MESSAGE_RESUME
+                );
+                return;
+            }
         }
-        let _ = outfile.flush();
     }
 }
 
diff --git a/host/commands/secure_env/secure_env_not_windows_main.cpp b/host/commands/secure_env/secure_env_not_windows_main.cpp
index d9122df..b7a617a 100644
--- a/host/commands/secure_env/secure_env_not_windows_main.cpp
+++ b/host/commands/secure_env/secure_env_not_windows_main.cpp
@@ -259,13 +259,16 @@
   }
 
   // go/cf-secure-env-snapshot
+  auto [rust_snapshot_socket1, rust_snapshot_socket2] =
+      CF_EXPECT(SharedFD::SocketPair(AF_UNIX, SOCK_STREAM, 0));
   SnapshotRunningFlag running;
   SharedFD channel_to_run_cvd = DupFdFlag(FLAGS_snapshot_control_fd);
   EventFdsManager event_fds_manager = CF_EXPECT(EventFdsManager::Create());
   EventNotifiers suspended_notifiers;
 
   SnapshotCommandHandler suspend_resume_handler(
-      channel_to_run_cvd, event_fds_manager, suspended_notifiers, running);
+      channel_to_run_cvd, event_fds_manager, suspended_notifiers, running,
+      std::move(rust_snapshot_socket1));
 
   // The guest image may have either the C++ implementation of
   // KeyMint/Keymaster, xor the Rust implementation of KeyMint.  Those different
@@ -281,8 +284,12 @@
   int keymint_in = FLAGS_keymint_fd_in;
   int keymint_out = FLAGS_keymint_fd_out;
   TpmResourceManager* rm = resource_manager;
-  threads.emplace_back([rm, keymint_in, keymint_out, security_level]() {
-    kmr_ta_main(keymint_in, keymint_out, security_level, rm);
+  threads.emplace_back([rm, keymint_in, keymint_out, security_level,
+                        rust_snapshot_socket2 =
+                            std::move(rust_snapshot_socket2)]() {
+    int snapshot_socket_fd = std::move(rust_snapshot_socket2)->UNMANAGED_Dup();
+    kmr_ta_main(keymint_in, keymint_out, security_level, rm,
+                snapshot_socket_fd);
   });
 #endif
 
diff --git a/host/commands/secure_env/suspend_resume_handler.cpp b/host/commands/secure_env/suspend_resume_handler.cpp
index a2b55ae..0b47134 100644
--- a/host/commands/secure_env/suspend_resume_handler.cpp
+++ b/host/commands/secure_env/suspend_resume_handler.cpp
@@ -21,6 +21,34 @@
 #include "host/libs/config/cuttlefish_config.h"
 
 namespace cuttlefish {
+namespace {
+
+Result<void> WriteSuspendRequest(const SharedFD& socket) {
+  const SnapshotSocketMessage suspend_request = SnapshotSocketMessage::kSuspend;
+  CF_EXPECT_EQ(sizeof(suspend_request),
+               socket->Write(&suspend_request, sizeof(suspend_request)),
+               "socket write failed: " << socket->StrError());
+  return {};
+}
+
+Result<void> ReadSuspendAck(const SharedFD& socket) {
+  SnapshotSocketMessage ack_response;
+  CF_EXPECT_EQ(sizeof(ack_response),
+               socket->Read(&ack_response, sizeof(ack_response)),
+               "socket read failed: " << socket->StrError());
+  CF_EXPECT_EQ(SnapshotSocketMessage::kSuspendAck, ack_response);
+  return {};
+}
+
+Result<void> WriteResumeRequest(const SharedFD& socket) {
+  const SnapshotSocketMessage resume_request = SnapshotSocketMessage::kResume;
+  CF_EXPECT_EQ(sizeof(resume_request),
+               socket->Write(&resume_request, sizeof(resume_request)),
+               "socket write failed: " << socket->StrError());
+  return {};
+}
+
+}  // namespace
 
 SnapshotCommandHandler::~SnapshotCommandHandler() { Join(); }
 
@@ -32,11 +60,13 @@
 
 SnapshotCommandHandler::SnapshotCommandHandler(
     SharedFD channel_to_run_cvd, EventFdsManager& event_fds_manager,
-    EventNotifiers& suspended_notifiers, SnapshotRunningFlag& running)
+    EventNotifiers& suspended_notifiers, SnapshotRunningFlag& running,
+    SharedFD rust_snapshot_socket)
     : channel_to_run_cvd_(channel_to_run_cvd),
       event_fds_manager_(event_fds_manager),
       suspended_notifiers_(suspended_notifiers),
-      shared_running_(running) {
+      shared_running_(running),
+      rust_snapshot_socket_(std::move(rust_snapshot_socket)) {
   handler_thread_ = std::thread([this]() {
     while (true) {
       auto result = SuspendResumeHandler();
@@ -67,13 +97,18 @@
   switch (snapshot_cmd) {
     case ExtendedActionType::kSuspend: {
       LOG(DEBUG) << "Handling suspended...";
+      // Request all worker threads to suspend.
       shared_running_.UnsetRunning();  // running := false
       CF_EXPECT(event_fds_manager_.SuspendKeymasterResponder());
       CF_EXPECT(event_fds_manager_.SuspendGatekeeperResponder());
       CF_EXPECT(event_fds_manager_.SuspendOemlockResponder());
+      CF_EXPECT(WriteSuspendRequest(rust_snapshot_socket_));
+      // Wait for ACKs from worker threads.
       suspended_notifiers_.keymaster_suspended_.WaitAndReset();
       suspended_notifiers_.gatekeeper_suspended_.WaitAndReset();
       suspended_notifiers_.oemlock_suspended_.WaitAndReset();
+      CF_EXPECT(ReadSuspendAck(rust_snapshot_socket_));
+      // Write response to run_cvd.
       auto response = LauncherResponse::kSuccess;
       const auto n_written =
           channel_to_run_cvd_->Write(&response, sizeof(response));
@@ -82,7 +117,10 @@
     };
     case ExtendedActionType::kResume: {
       LOG(DEBUG) << "Handling resume...";
+      // Request all worker threads to resume.
       shared_running_.SetRunning();  // running := true, and notifies all
+      CF_EXPECT(WriteResumeRequest(rust_snapshot_socket_));
+      // Write response to run_cvd.
       auto response = LauncherResponse::kSuccess;
       const auto n_written =
           channel_to_run_cvd_->Write(&response, sizeof(response));
diff --git a/host/commands/secure_env/suspend_resume_handler.h b/host/commands/secure_env/suspend_resume_handler.h
index b36f3cc..e2e8c20 100644
--- a/host/commands/secure_env/suspend_resume_handler.h
+++ b/host/commands/secure_env/suspend_resume_handler.h
@@ -26,13 +26,43 @@
 
 namespace cuttlefish {
 
+// `SnapshotCommandHandler` can request threads to suspend and resume using the
+// following protocol. Each message on the socket is 1 byte.
+//
+// Suspend flow:
+//
+//   1. `SnapshotCommandHandler` writes `kSuspend` to the socket.
+//   2. When the worker thread sees the socket is readable, it should assume the
+//      incoming message is `kSuspend`, finish all non-blocking work, read the
+//      `kSuspend` message, write a `kSuspendAck` message back into the socket,
+//      and then, finally, block until it receives another message from the
+//      socket (which will always be `kResume`).
+//   3. `SnapshotCommandHandler` waits for the `kSuspendAck` to ensure the
+//      worker thread is actually suspended and then proceeds.
+//
+// Resume flow:
+//
+//   1. The worker thread is already blocked waiting for a `kResume` from the
+//      socket.
+//   2. `SnapshotCommandHandler` sends a `kResume`.
+//   3. The worker thread sees it and goes back to normal operation.
+//
+// WARNING: Keep in sync with the `SNAPSHOT_SOCKET_MESSAGE_*` constants in
+// secure_env/rust/lib.rs.
+enum SnapshotSocketMessage : uint8_t {
+  kSuspend = 1,
+  kSuspendAck = 2,
+  kResume = 3,
+};
+
 class SnapshotCommandHandler {
  public:
   ~SnapshotCommandHandler();
   SnapshotCommandHandler(SharedFD channel_to_run_cvd,
                          EventFdsManager& event_fds,
                          EventNotifiers& suspended_notifiers,
-                         SnapshotRunningFlag& running);
+                         SnapshotRunningFlag& running,
+                         SharedFD rust_snapshot_socket);
 
  private:
   Result<void> SuspendResumeHandler();
@@ -43,6 +73,7 @@
   EventFdsManager& event_fds_manager_;
   EventNotifiers& suspended_notifiers_;
   SnapshotRunningFlag& shared_running_;  // shared by other components outside
+  SharedFD rust_snapshot_socket_;
   std::thread handler_thread_;
 };