secure_env: add suspend-resume support to rust impl
This mimics the logic of the C++ suspend-resume code, but to reduce FFI
surface area, instead of using a few different primitives (two event_fds
and a condition variable), it uses one socket pair and a simple
protocol.
Bug: 295028759
Test: cvd suspend; cvd resume
Change-Id: I55628436e9236efa08d85f101b8684a6d2093547
Merged-In: I55628436e9236efa08d85f101b8684a6d2093547
diff --git a/host/commands/secure_env/rust/Android.bp b/host/commands/secure_env/rust/Android.bp
index 176430e..c5d6b01 100644
--- a/host/commands/secure_env/rust/Android.bp
+++ b/host/commands/secure_env/rust/Android.bp
@@ -38,6 +38,7 @@
"libkmr_wire",
"liblibc",
"liblog_rust",
+ "libnix",
"libprotobuf_deprecated",
"libsecure_env_tpm",
],
@@ -71,6 +72,7 @@
"libkmr_wire",
"liblibc",
"liblog_rust",
+ "libnix",
"libprotobuf_deprecated",
"libsecure_env_tpm",
],
diff --git a/host/commands/secure_env/rust/ffi.rs b/host/commands/secure_env/rust/ffi.rs
index e6360cd..9ffed7b 100644
--- a/host/commands/secure_env/rust/ffi.rs
+++ b/host/commands/secure_env/rust/ffi.rs
@@ -20,18 +20,21 @@
use kmr_wire::keymint::SecurityLevel;
use libc::c_int;
use log::error;
+use std::os::fd::FromRawFd;
/// FFI wrapper around [`kmr_cf::ta_main`].
///
/// # Safety
///
-/// `fd_in` and `fd_out` must be valid and open file descriptors.
+/// `fd_in`, `fd_out`, and `snapshot_socket_fd` must be valid and open file descriptors and the
+/// caller must not use or close them after the call.
#[no_mangle]
pub unsafe extern "C" fn kmr_ta_main(
fd_in: c_int,
fd_out: c_int,
security_level: c_int,
trm: *mut libc::c_void,
+ snapshot_socket_fd: c_int,
) {
let security_level = match security_level {
x if x == SecurityLevel::TrustedEnvironment as i32 => SecurityLevel::TrustedEnvironment,
@@ -42,6 +45,10 @@
SecurityLevel::Software
}
};
+ let snapshot_socket =
+ // SAFETY: fd being valid and open and exclusive is asserted in the unsafe function's
+ // preconditions, so this is pushed up to the caller.
+ unsafe { std::os::unix::net::UnixStream::from_raw_fd(snapshot_socket_fd) };
// SAFETY: The caller guarantees that `fd_in` and `fd_out` are valid and open.
- unsafe { kmr_cf::ta_main(fd_in, fd_out, security_level, trm) }
+ unsafe { kmr_cf::ta_main(fd_in, fd_out, security_level, trm, snapshot_socket) }
}
diff --git a/host/commands/secure_env/rust/kmr_ta.h b/host/commands/secure_env/rust/kmr_ta.h
index ce70873..efa9f2d 100644
--- a/host/commands/secure_env/rust/kmr_ta.h
+++ b/host/commands/secure_env/rust/kmr_ta.h
@@ -26,7 +26,10 @@
// values from SecurityLevel.aidl.
// - trm: pointer to a valid `TpmResourceManager`, which must remain valid
// for the entire duration of the function execution.
-void kmr_ta_main(int fd_in, int fd_out, int security_level, void* trm);
+// - snapshot_socket_fd: file descriptor for a socket used to communicate with
+// the secure_env suspend-resume handler thread.
+void kmr_ta_main(int fd_in, int fd_out, int security_level, void* trm,
+ int snapshot_socket_fd);
#ifdef __cplusplus
}
diff --git a/host/commands/secure_env/rust/lib.rs b/host/commands/secure_env/rust/lib.rs
index 49508b9..eb2536b 100644
--- a/host/commands/secure_env/rust/lib.rs
+++ b/host/commands/secure_env/rust/lib.rs
@@ -32,6 +32,7 @@
use log::{error, info, trace};
use std::ffi::CString;
use std::io::{Read, Write};
+use std::os::fd::AsRawFd;
use std::os::unix::{ffi::OsStrExt, io::FromRawFd};
pub mod attest;
@@ -44,16 +45,23 @@
#[cfg(test)]
mod tests;
+// See `SnapshotSocketMessage` in suspend_resume_handler.h for docs.
+const SNAPSHOT_SOCKET_MESSAGE_SUSPEND: u8 = 1;
+const SNAPSHOT_SOCKET_MESSAGE_SUSPEND_ACK: u8 = 2;
+const SNAPSHOT_SOCKET_MESSAGE_RESUME: u8 = 3;
+
/// Main routine for the KeyMint TA. Only returns if there is a fatal error.
///
/// # Safety
///
-/// `fd_in` and `fd_out` must be valid and open file descriptors.
+/// `fd_in` and `fd_out` must be valid and open file descriptors and the caller must not use or
+/// close them after the call.
pub unsafe fn ta_main(
fd_in: c_int,
fd_out: c_int,
security_level: SecurityLevel,
trm: *mut libc::c_void,
+ mut snapshot_socket: std::os::unix::net::UnixStream,
) {
log::set_logger(&AndroidCppLogger).unwrap();
log::set_max_level(log::LevelFilter::Debug); // Filtering happens elsewhere
@@ -62,9 +70,9 @@
fd_in, fd_out, security_level,
);
- // SAFETY: The caller guarantees that `fd_in` is valid and open.
+ // SAFETY: The caller guarantees that `fd_in` is valid and open and exclusive.
let mut infile = unsafe { std::fs::File::from_raw_fd(fd_in) };
- // SAFETY: The caller guarantees that `fd_out` is valid and open.
+ // SAFETY: The caller guarantees that `fd_out` is valid and open and exclusive.
let mut outfile = unsafe { std::fs::File::from_raw_fd(fd_out) };
let hw_info = HardwareInfo {
@@ -147,52 +155,107 @@
let mut buf = [0; kmr_wire::DEFAULT_MAX_SIZE];
loop {
- // Read a request message from the pipe, as a 4-byte BE length followed by the message.
- let mut req_len_data = [0u8; 4];
- if let Err(e) = infile.read_exact(&mut req_len_data) {
- error!("FATAL: Failed to read request length from connection: {:?}", e);
- return;
- }
- let req_len = u32::from_be_bytes(req_len_data) as usize;
- if req_len > kmr_wire::DEFAULT_MAX_SIZE {
- error!("FATAL: Request too long ({})", req_len);
- return;
- }
- let req_data = &mut buf[..req_len];
- if let Err(e) = infile.read_exact(req_data) {
- error!(
- "FATAL: Failed to read request data of length {} from connection: {:?}",
- req_len, e
- );
+ // Wait for data from either `infile` or `snapshot_socket`. If both have data, we prioritize
+ // processing only `infile` until it is empty so that there is no pending state when we
+ // suspend the loop.
+ let mut fd_set = nix::sys::select::FdSet::new();
+ fd_set.insert(infile.as_raw_fd());
+ fd_set.insert(snapshot_socket.as_raw_fd());
+ if let Err(e) = nix::sys::select::select(
+ None,
+ /*readfds=*/ Some(&mut fd_set),
+ None,
+ None,
+ /*timeout=*/ None,
+ ) {
+ error!("FATAL: Failed to select on input FDs: {:?}", e);
return;
}
- // Pass to the TA to process.
- trace!("-> TA: received data: (len={})", req_data.len());
- let rsp = ta.process(req_data);
- trace!("<- TA: send data: (len={})", rsp.len());
-
- // Send the response message down the pipe, as a 4-byte BE length followed by the message.
- let rsp_len: u32 = match rsp.len().try_into() {
- Ok(l) => l,
- Err(_e) => {
- error!("FATAL: Response too long (len={})", rsp.len());
+ if fd_set.contains(infile.as_raw_fd()) {
+ // Read a request message from the pipe, as a 4-byte BE length followed by the message.
+ let mut req_len_data = [0u8; 4];
+ if let Err(e) = infile.read_exact(&mut req_len_data) {
+ error!("FATAL: Failed to read request length from connection: {:?}", e);
return;
}
- };
- let rsp_len_data = rsp_len.to_be_bytes();
- if let Err(e) = outfile.write_all(&rsp_len_data[..]) {
- error!("FATAL: Failed to write response length to connection: {:?}", e);
- return;
+ let req_len = u32::from_be_bytes(req_len_data) as usize;
+ if req_len > kmr_wire::DEFAULT_MAX_SIZE {
+ error!("FATAL: Request too long ({})", req_len);
+ return;
+ }
+ let req_data = &mut buf[..req_len];
+ if let Err(e) = infile.read_exact(req_data) {
+ error!(
+ "FATAL: Failed to read request data of length {} from connection: {:?}",
+ req_len, e
+ );
+ return;
+ }
+
+ // Pass to the TA to process.
+ trace!("-> TA: received data: (len={})", req_data.len());
+ let rsp = ta.process(req_data);
+ trace!("<- TA: send data: (len={})", rsp.len());
+
+ // Send the response message down the pipe, as a 4-byte BE length followed by the message.
+ let rsp_len: u32 = match rsp.len().try_into() {
+ Ok(l) => l,
+ Err(_e) => {
+ error!("FATAL: Response too long (len={})", rsp.len());
+ return;
+ }
+ };
+ let rsp_len_data = rsp_len.to_be_bytes();
+ if let Err(e) = outfile.write_all(&rsp_len_data[..]) {
+ error!("FATAL: Failed to write response length to connection: {:?}", e);
+ return;
+ }
+ if let Err(e) = outfile.write_all(&rsp) {
+ error!(
+ "FATAL: Failed to write response data of length {} to connection: {:?}",
+ rsp_len, e
+ );
+ return;
+ }
+ let _ = outfile.flush();
+
+ continue;
}
- if let Err(e) = outfile.write_all(&rsp) {
- error!(
- "FATAL: Failed to write response data of length {} to connection: {:?}",
- rsp_len, e
- );
- return;
+
+ if fd_set.contains(snapshot_socket.as_raw_fd()) {
+ // Read suspend request.
+ let mut suspend_request = 0u8;
+ if let Err(e) = snapshot_socket.read_exact(std::slice::from_mut(&mut suspend_request)) {
+ error!("FATAL: Failed to read suspend request: {:?}", e);
+ return;
+ }
+ if suspend_request != SNAPSHOT_SOCKET_MESSAGE_SUSPEND {
+ error!(
+ "FATAL: Unexpected value from snapshot socket: got {}, expected {}",
+ suspend_request, SNAPSHOT_SOCKET_MESSAGE_SUSPEND
+ );
+ return;
+ }
+ // Write ACK.
+ if let Err(e) = snapshot_socket.write_all(&[SNAPSHOT_SOCKET_MESSAGE_SUSPEND_ACK]) {
+ error!("FATAL: Failed to write suspend ACK request: {:?}", e);
+ return;
+ }
+ // Block until we get a resume request.
+ let mut resume_request = 0u8;
+ if let Err(e) = snapshot_socket.read_exact(std::slice::from_mut(&mut resume_request)) {
+ error!("FATAL: Failed to read resume request: {:?}", e);
+ return;
+ }
+ if resume_request != SNAPSHOT_SOCKET_MESSAGE_RESUME {
+ error!(
+ "FATAL: Unexpected value from snapshot socket: got {}, expected {}",
+ resume_request, SNAPSHOT_SOCKET_MESSAGE_RESUME
+ );
+ return;
+ }
}
- let _ = outfile.flush();
}
}
diff --git a/host/commands/secure_env/secure_env_not_windows_main.cpp b/host/commands/secure_env/secure_env_not_windows_main.cpp
index d9122df..b7a617a 100644
--- a/host/commands/secure_env/secure_env_not_windows_main.cpp
+++ b/host/commands/secure_env/secure_env_not_windows_main.cpp
@@ -259,13 +259,16 @@
}
// go/cf-secure-env-snapshot
+ auto [rust_snapshot_socket1, rust_snapshot_socket2] =
+ CF_EXPECT(SharedFD::SocketPair(AF_UNIX, SOCK_STREAM, 0));
SnapshotRunningFlag running;
SharedFD channel_to_run_cvd = DupFdFlag(FLAGS_snapshot_control_fd);
EventFdsManager event_fds_manager = CF_EXPECT(EventFdsManager::Create());
EventNotifiers suspended_notifiers;
SnapshotCommandHandler suspend_resume_handler(
- channel_to_run_cvd, event_fds_manager, suspended_notifiers, running);
+ channel_to_run_cvd, event_fds_manager, suspended_notifiers, running,
+ std::move(rust_snapshot_socket1));
// The guest image may have either the C++ implementation of
// KeyMint/Keymaster, xor the Rust implementation of KeyMint. Those different
@@ -281,8 +284,12 @@
int keymint_in = FLAGS_keymint_fd_in;
int keymint_out = FLAGS_keymint_fd_out;
TpmResourceManager* rm = resource_manager;
- threads.emplace_back([rm, keymint_in, keymint_out, security_level]() {
- kmr_ta_main(keymint_in, keymint_out, security_level, rm);
+ threads.emplace_back([rm, keymint_in, keymint_out, security_level,
+ rust_snapshot_socket2 =
+ std::move(rust_snapshot_socket2)]() {
+ int snapshot_socket_fd = std::move(rust_snapshot_socket2)->UNMANAGED_Dup();
+ kmr_ta_main(keymint_in, keymint_out, security_level, rm,
+ snapshot_socket_fd);
});
#endif
diff --git a/host/commands/secure_env/suspend_resume_handler.cpp b/host/commands/secure_env/suspend_resume_handler.cpp
index a2b55ae..0b47134 100644
--- a/host/commands/secure_env/suspend_resume_handler.cpp
+++ b/host/commands/secure_env/suspend_resume_handler.cpp
@@ -21,6 +21,34 @@
#include "host/libs/config/cuttlefish_config.h"
namespace cuttlefish {
+namespace {
+
+Result<void> WriteSuspendRequest(const SharedFD& socket) {
+ const SnapshotSocketMessage suspend_request = SnapshotSocketMessage::kSuspend;
+ CF_EXPECT_EQ(sizeof(suspend_request),
+ socket->Write(&suspend_request, sizeof(suspend_request)),
+ "socket write failed: " << socket->StrError());
+ return {};
+}
+
+Result<void> ReadSuspendAck(const SharedFD& socket) {
+ SnapshotSocketMessage ack_response;
+ CF_EXPECT_EQ(sizeof(ack_response),
+ socket->Read(&ack_response, sizeof(ack_response)),
+ "socket read failed: " << socket->StrError());
+ CF_EXPECT_EQ(SnapshotSocketMessage::kSuspendAck, ack_response);
+ return {};
+}
+
+Result<void> WriteResumeRequest(const SharedFD& socket) {
+ const SnapshotSocketMessage resume_request = SnapshotSocketMessage::kResume;
+ CF_EXPECT_EQ(sizeof(resume_request),
+ socket->Write(&resume_request, sizeof(resume_request)),
+ "socket write failed: " << socket->StrError());
+ return {};
+}
+
+} // namespace
SnapshotCommandHandler::~SnapshotCommandHandler() { Join(); }
@@ -32,11 +60,13 @@
SnapshotCommandHandler::SnapshotCommandHandler(
SharedFD channel_to_run_cvd, EventFdsManager& event_fds_manager,
- EventNotifiers& suspended_notifiers, SnapshotRunningFlag& running)
+ EventNotifiers& suspended_notifiers, SnapshotRunningFlag& running,
+ SharedFD rust_snapshot_socket)
: channel_to_run_cvd_(channel_to_run_cvd),
event_fds_manager_(event_fds_manager),
suspended_notifiers_(suspended_notifiers),
- shared_running_(running) {
+ shared_running_(running),
+ rust_snapshot_socket_(std::move(rust_snapshot_socket)) {
handler_thread_ = std::thread([this]() {
while (true) {
auto result = SuspendResumeHandler();
@@ -67,13 +97,18 @@
switch (snapshot_cmd) {
case ExtendedActionType::kSuspend: {
LOG(DEBUG) << "Handling suspended...";
+ // Request all worker threads to suspend.
shared_running_.UnsetRunning(); // running := false
CF_EXPECT(event_fds_manager_.SuspendKeymasterResponder());
CF_EXPECT(event_fds_manager_.SuspendGatekeeperResponder());
CF_EXPECT(event_fds_manager_.SuspendOemlockResponder());
+ CF_EXPECT(WriteSuspendRequest(rust_snapshot_socket_));
+ // Wait for ACKs from worker threads.
suspended_notifiers_.keymaster_suspended_.WaitAndReset();
suspended_notifiers_.gatekeeper_suspended_.WaitAndReset();
suspended_notifiers_.oemlock_suspended_.WaitAndReset();
+ CF_EXPECT(ReadSuspendAck(rust_snapshot_socket_));
+ // Write response to run_cvd.
auto response = LauncherResponse::kSuccess;
const auto n_written =
channel_to_run_cvd_->Write(&response, sizeof(response));
@@ -82,7 +117,10 @@
};
case ExtendedActionType::kResume: {
LOG(DEBUG) << "Handling resume...";
+ // Request all worker threads to resume.
shared_running_.SetRunning(); // running := true, and notifies all
+ CF_EXPECT(WriteResumeRequest(rust_snapshot_socket_));
+ // Write response to run_cvd.
auto response = LauncherResponse::kSuccess;
const auto n_written =
channel_to_run_cvd_->Write(&response, sizeof(response));
diff --git a/host/commands/secure_env/suspend_resume_handler.h b/host/commands/secure_env/suspend_resume_handler.h
index b36f3cc..e2e8c20 100644
--- a/host/commands/secure_env/suspend_resume_handler.h
+++ b/host/commands/secure_env/suspend_resume_handler.h
@@ -26,13 +26,43 @@
namespace cuttlefish {
+// `SnapshotCommandHandler` can request threads to suspend and resume using the
+// following protocol. Each message on the socket is 1 byte.
+//
+// Suspend flow:
+//
+// 1. `SnapshotCommandHandler` writes `kSuspend` to the socket.
+// 2. When the worker thread sees the socket is readable, it should assume the
+// incoming message is `kSuspend`, finish all non-blocking work, read the
+// `kSuspend` message, write a `kSuspendAck` message back into the socket,
+// and then, finally, block until it receives another message from the
+// socket (which will always be `kResume`).
+// 3. `SnapshotCommandHandler` waits for the `kSuspendAck` to ensure the
+// worker thread is actually suspended and then proceeds.
+//
+// Resume flow:
+//
+// 1. The worker thread is already blocked waiting for a `kResume` from the
+// socket.
+// 2. `SnapshotCommandHandler` sends a `kResume`.
+// 3. The worker thread sees it and goes back to normal operation.
+//
+// WARNING: Keep in sync with the `SNAPSHOT_SOCKET_MESSAGE_*` constants in
+// secure_env/rust/lib.rs.
+enum SnapshotSocketMessage : uint8_t {
+ kSuspend = 1,
+ kSuspendAck = 2,
+ kResume = 3,
+};
+
class SnapshotCommandHandler {
public:
~SnapshotCommandHandler();
SnapshotCommandHandler(SharedFD channel_to_run_cvd,
EventFdsManager& event_fds,
EventNotifiers& suspended_notifiers,
- SnapshotRunningFlag& running);
+ SnapshotRunningFlag& running,
+ SharedFD rust_snapshot_socket);
private:
Result<void> SuspendResumeHandler();
@@ -43,6 +73,7 @@
EventFdsManager& event_fds_manager_;
EventNotifiers& suspended_notifiers_;
SnapshotRunningFlag& shared_running_; // shared by other components outside
+ SharedFD rust_snapshot_socket_;
std::thread handler_thread_;
};