blob: ea5960c17ebf9e5ae7699ab1437ed57d639372fc [file] [log] [blame]
// Copyright 2021 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//! Provides a struct for registering signal handlers that get cleared on drop.
use std::convert::TryFrom;
use std::fmt;
use std::io::{Cursor, Write};
use std::panic::catch_unwind;
use std::result;
use libc::{c_int, c_void, STDERR_FILENO};
use remain::sorted;
use thiserror::Error;
use crate::errno;
use crate::signal::{
clear_signal_handler, has_default_signal_handler, register_signal_handler, wait_for_signal,
Signal,
};
#[sorted]
#[derive(Error, Debug)]
/// Errors produced while installing, removing, or waiting on scoped signal handlers.
///
/// `#[sorted]` (from the `remain` crate) checks at compile time that the variants stay in sorted
/// order, so new variants have to be inserted at the matching position.
pub enum Error {
/// Already waiting for interrupt.
///
/// NOTE(review): nothing in the code visible in this file constructs this variant.
#[error("already waiting for interrupt.")]
AlreadyWaiting,
/// Signal already has a handler.
///
/// Returned by `ScopedSignalHandler::new` when `has_default_signal_handler` reports that the
/// signal does not currently have the default handler.
#[error("signal handler already set for {0:?}")]
HandlerAlreadySet(Signal),
/// Failed to check if signal has the default signal handler.
///
/// Carries the signal being checked and the error returned by `has_default_signal_handler`.
#[error("failed to check the signal handler for {0:?}: {1}")]
HasDefaultSignalHandler(Signal, errno::Error),
/// Failed to register a signal handler.
///
/// Carries the signal being registered and the error returned by `register_signal_handler`.
#[error("failed to register a signal handler for {0:?}: {1}")]
RegisterSignalHandler(Signal, errno::Error),
/// Sigaction failed.
///
/// Within this file it is only produced by the test helpers that call `sigaction` directly.
#[error("sigaction failed for {0:?}: {1}")]
Sigaction(Signal, errno::Error),
/// Failed to wait for signal.
///
/// Returned by `wait_for_interrupt` when `wait_for_signal` fails.
#[error("wait_for_signal failed: {0}")]
WaitForSignal(errno::Error),
}
/// Result alias whose error type is this module's [`Error`].
pub type Result<T> = result::Result<T, Error>;
/// The interface used by Scoped Signal handler.
///
/// The callback is an associated function (there is no `self`), so a handler is selected purely
/// by type, e.g. `ScopedSignalHandler::new::<MyHandler>(..)`. It is not installed directly;
/// `ScopedSignalHandler::new` registers the `extern "C"` wrapper `call_handler::<H>`, which
/// forwards to `H::handle_signal`.
///
/// # Safety
/// The implementation of handle_signal needs to be async signal-safe.
///
/// NOTE: panics are caught when possible because a panic inside ffi is undefined behavior.
pub unsafe trait SignalHandler {
/// A function that is called to handle the passed signal.
///
/// `signal` is the delivered signal number converted to a `Signal` by the wrapper. The call is
/// made inside `catch_unwind`; if it panics the wrapper writes a short message to stderr.
fn handle_signal(signal: Signal);
}
/// Wrap the handler with an extern "C" function.
///
/// `H::handle_signal` is run inside `catch_unwind` so that a panic (including the one raised when
/// `signum` cannot be converted to a `Signal`) does not unwind across the FFI boundary. When a
/// panic is caught, a one-line diagnostic is emitted on stderr using a single `write` call.
extern "C" fn call_handler<H: SignalHandler>(signum: c_int) {
    // Make an effort to surface an error.
    if catch_unwind(|| H::handle_signal(Signal::try_from(signum).unwrap())).is_ok() {
        return;
    }

    // Note the following cannot be used:
    // eprintln! - uses std::io which has locks that may be held.
    // format! - uses the allocator which enforces mutual exclusion.

    // This should never be needed, but provides a message in case formatting fails. It ends with
    // a newline, like the formatted message, so it does not run into later stderr output.
    const ERROR_DROPPED: &str = "Error dropped by signal handler.\n";

    // Get the debug representation of signum, falling back to the raw number when it does not
    // map to a known `Signal`.
    let signal = Signal::try_from(signum);
    let signal_debug: &dyn fmt::Debug = match &signal {
        Ok(known) => known as &dyn fmt::Debug,
        Err(_) => &signum as &dyn fmt::Debug,
    };

    // Buffer the output, so a single call to write can be used.
    // The fixed text plus the trailing newline account for 31 bytes, which leaves 33 for the
    // string representation of the signal; that is more than enough.
    let mut buffer = [0u8; 64];
    let mut cursor = Cursor::new(&mut buffer[..]);
    let message: &[u8] =
        if writeln!(cursor, "signal handler got error for: {:?}", signal_debug).is_ok() {
            // The cursor never advances past the end of the slice it wraps.
            let len = cursor.position() as usize;
            &buffer[..len]
        } else {
            ERROR_DROPPED.as_bytes()
        };

    // SAFETY: `message` points at `message.len()` initialized bytes that outlive the call. This
    // may print in the middle of an existing write, but that is considered better than dropping
    // the error. The result is ignored because there is nothing more that can be done here.
    unsafe {
        libc::write(
            STDERR_FILENO,
            message.as_ptr() as *const c_void,
            message.len(),
        )
    };
}
/// Represents a signal handler that is registered with a set of signals and is unregistered when
/// the struct goes out of scope. Prefer a signalfd based solution before using this.
pub struct ScopedSignalHandler {
/// Signals whose handler was registered by `new`; each one is cleared again on drop.
signals: Vec<Signal>,
}
impl ScopedSignalHandler {
/// Attempts to register `handler` with the provided `signals`. It will fail if there is already
/// an existing handler on any of `signals`.
///
/// # Safety
/// This is safe if H::handle_signal is async-signal safe.
pub fn new<H: SignalHandler>(signals: &[Signal]) -> Result<Self> {
let mut scoped_handler = ScopedSignalHandler {
signals: Vec::with_capacity(signals.len()),
};
for &signal in signals {
if !has_default_signal_handler((signal).into())
.map_err(|err| Error::HasDefaultSignalHandler(signal, err))?
{
return Err(Error::HandlerAlreadySet(signal));
}
// Requires an async-safe callback.
unsafe {
register_signal_handler((signal).into(), call_handler::<H>)
.map_err(|err| Error::RegisterSignalHandler(signal, err))?
};
scoped_handler.signals.push(signal);
}
Ok(scoped_handler)
}
}
/// Clears the signal handler for any of the associated signals.
impl Drop for ScopedSignalHandler {
    /// Clears every tracked signal's handler; a failure is reported on stderr and the remaining
    /// signals are still processed.
    fn drop(&mut self) {
        self.signals
            .iter()
            .filter_map(|&signal| clear_signal_handler(signal.into()).err())
            .for_each(|err| eprintln!("Error: failed to clear signal handler: {:?}", err));
    }
}
/// A signal handler that does nothing.
///
/// This is useful in cases where wait_for_signal is used since it will never trigger if the signal
/// is blocked and the default handler may have undesired effects like terminating the process.
///
/// `wait_for_interrupt` installs it for `Signal::Interrupt` while it waits.
pub struct EmptySignalHandler;
/// # Safety
/// Safe because handle_signal is async-signal safe.
unsafe impl SignalHandler for EmptySignalHandler {
/// Intentionally empty: the signal is accepted and ignored.
fn handle_signal(_: Signal) {}
}
/// Blocks until SIGINT is received, which often happens because Ctrl-C was pressed in an
/// interactive terminal.
///
/// Note: if you are using a multi-threaded application you need to block SIGINT on all other
/// threads or they may receive the signal instead of the desired thread.
pub fn wait_for_interrupt() -> Result<()> {
// Register a signal handler if there is not one already so the thread is not killed.
let ret = ScopedSignalHandler::new::<EmptySignalHandler>(&[Signal::Interrupt]);
if !matches!(&ret, Ok(_) | Err(Error::HandlerAlreadySet(_))) {
ret?;
}
match wait_for_signal(&[Signal::Interrupt.into()], None) {
Ok(_) => Ok(()),
Err(err) => Err(Error::WaitForSignal(err)),
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::mem::zeroed;
use std::ptr::{null, null_mut};
use std::sync::atomic::{AtomicI32, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, MutexGuard, Once};
use std::thread::{sleep, spawn};
use std::time::{Duration, Instant};
use libc::sigaction;
use crate::{gettid, kill, Pid};
const TEST_SIGNAL: Signal = Signal::User1;
const TEST_SIGNALS: &[Signal] = &[Signal::User1, Signal::User2];
static TEST_SIGNAL_COUNTER: AtomicUsize = AtomicUsize::new(0);
/// Only allows one test case to execute at a time.
///
/// Returns the guard of a lazily created, process-wide mutex; the caller holds it for as long as
/// it needs exclusive use of the process signal handlers.
fn get_mutex() -> MutexGuard<'static, ()> {
    static ONCE: Once = Once::new();
    static mut LOCK: Option<Arc<Mutex<()>>> = None;

    // Safe because the mutation is protected by the Once.
    ONCE.call_once(|| unsafe { LOCK = Some(Arc::new(Mutex::new(()))) });

    // Safe because mutation only happens in the Once, which has completed by now.
    let lock = unsafe { LOCK.as_ref() }.unwrap();
    lock.lock().unwrap()
}
/// Sets the shared test signal counter back to zero.
fn reset_counter() {
    TEST_SIGNAL_COUNTER.store(0, Ordering::SeqCst);
}
fn get_sigaction(signal: Signal) -> Result<sigaction> {
// Safe because sigaction is owned and expected to be initialized ot zeros.
let mut sigact: sigaction = unsafe { zeroed() };
if unsafe { sigaction(signal.into(), null(), &mut sigact) } < 0 {
Err(Error::Sigaction(signal, errno::Error::last()))
} else {
Ok(sigact)
}
}
/// Safety:
/// This is only safe if the signal handler set in sigaction is safe.
unsafe fn restore_sigaction(signal: Signal, sigact: sigaction) -> Result<sigaction> {
if sigaction(signal.into(), &sigact, null_mut()) < 0 {
Err(Error::Sigaction(signal, errno::Error::last()))
} else {
Ok(sigact)
}
}
/// Sends `Signal::User1` to the calling thread, panicking if the signal cannot be sent.
///
/// Safety:
/// Safe if the signal handler for Signal::User1 is safe.
unsafe fn send_test_signal() {
    let current_thread = gettid();
    kill(current_thread, Signal::User1.into()).unwrap()
}
/// Panics unless the shared test signal counter currently equals the given `usize` expression.
macro_rules! assert_counter_eq {
    ($compare_to:expr) => {{
        let expected: usize = $compare_to;
        let got: usize = TEST_SIGNAL_COUNTER.load(Ordering::SeqCst);
        assert!(
            got == expected,
            "wrong signal counter value: got {}; expected {}",
            got,
            expected
        );
    }};
}
/// Handler used by the tests: counts deliveries of `TEST_SIGNAL`.
struct TestHandler;
/// # Safety
/// Safe because handle_signal is async-signal safe.
unsafe impl SignalHandler for TestHandler {
    /// Increments the shared counter for `TEST_SIGNAL`; any other signal is ignored.
    fn handle_signal(signal: Signal) {
        if signal != TEST_SIGNAL {
            return;
        }
        TEST_SIGNAL_COUNTER.fetch_add(1, Ordering::SeqCst);
    }
}
#[test]
fn scopedsignalhandler_success() {
    // Signal handlers are shared by the whole process, so keep other test cases from running
    // at the same time.
    let _guard = get_mutex();
    reset_counter();
    assert_counter_eq!(0);

    // The handler is only in place while the scoped object exists.
    assert!(has_default_signal_handler(TEST_SIGNAL.into()).unwrap());
    let scoped = ScopedSignalHandler::new::<TestHandler>(&[TEST_SIGNAL]).unwrap();
    assert!(!has_default_signal_handler(TEST_SIGNAL.into()).unwrap());

    // Safe because TestHandler is async-signal safe.
    unsafe { send_test_signal() };

    // Give the handler time to run in case it is on a different thread.
    let mut polls = 0;
    while polls < 39 && TEST_SIGNAL_COUNTER.load(Ordering::SeqCst) == 0 {
        sleep(Duration::from_millis(250));
        polls += 1;
    }
    assert_counter_eq!(1);

    // Dropping the scoped handler must put the default handler back.
    drop(scoped);
    assert!(has_default_signal_handler(TEST_SIGNAL.into()).unwrap());
}
#[test]
fn scopedsignalhandler_handleralreadyset() {
    // Signal handlers are shared by the whole process, so keep other test cases from running
    // at the same time.
    let _guard = get_mutex();
    reset_counter();
    assert_counter_eq!(0);

    assert!(has_default_signal_handler(TEST_SIGNAL.into()).unwrap());
    // TestHandler is async-signal safe.
    let first = ScopedSignalHandler::new::<TestHandler>(&[TEST_SIGNAL]).unwrap();
    assert!(!has_default_signal_handler(TEST_SIGNAL.into()).unwrap());

    // A second registration that includes the already handled signal must be rejected and must
    // name that signal in the error.
    assert!(matches!(
        ScopedSignalHandler::new::<TestHandler>(TEST_SIGNALS),
        Err(Error::HandlerAlreadySet(Signal::User1))
    ));
    assert_counter_eq!(0);

    drop(first);
    assert!(has_default_signal_handler(TEST_SIGNAL.into()).unwrap());
}
/// Stores the thread used by WaitForInterruptHandler.
///
/// Zero means that no target thread has been recorded.
static WAIT_FOR_INTERRUPT_THREAD_ID: AtomicI32 = AtomicI32::new(0);
/// Forwards SIGINT to the appropriate thread.
struct WaitForInterruptHandler;
/// # Safety
/// Safe because handle_signal is async-signal safe.
unsafe impl SignalHandler for WaitForInterruptHandler {
    /// Re-sends SIGINT to the recorded thread when the handler ran on any other thread.
    fn handle_signal(_: Signal) {
        let target = WAIT_FOR_INTERRUPT_THREAD_ID.load(Ordering::SeqCst);
        // Nothing to forward when no thread is recorded or this already is the target thread.
        if target == 0 || gettid() == target {
            return;
        }
        // Safe because the handler is safe and the target thread id is expecting the signal.
        unsafe { kill(target, Signal::Interrupt.into()) }.unwrap();
    }
}
/// Query /proc/${tid}/status for its State and check if it is either S (sleeping) or in
/// D (disk sleep).
fn thread_is_sleeping(tid: Pid) -> result::Result<bool, errno::Error> {
const PREFIX: &str = "State:";
let mut status_reader = BufReader::new(File::open(format!("/proc/{}/status", tid))?);
let mut line = String::new();
loop {
let count = status_reader.read_line(&mut line)?;
if count == 0 {
return Err(errno::Error::new(libc::EIO));
}
if line.starts_with(PREFIX) {
return Ok(matches!(
line[PREFIX.len()..].trim_start().chars().next(),
Some('S') | Some('D')
));
}
line.clear();
}
}
/// Wait for a process to block either in a sleeping or disk sleep state.
///
/// Polls every 50 milliseconds and gives up with `EAGAIN` once more than `timeout` has elapsed.
fn wait_for_thread_to_sleep(tid: Pid, timeout: Duration) -> result::Result<(), errno::Error> {
    let started = Instant::now();
    while !thread_is_sleeping(tid)? {
        if started.elapsed() > timeout {
            return Err(errno::Error::new(libc::EAGAIN));
        }
        sleep(Duration::from_millis(50));
    }
    Ok(())
}
#[test]
fn waitforinterrupt_success() {
    // Signal handlers are shared by the whole process, so keep other test cases from running
    // at the same time.
    let _guard = get_mutex();

    // Remember the existing SIGINT action so it can be put back at the end of the test.
    let previous_action = get_sigaction(Signal::Interrupt).unwrap();
    clear_signal_handler(Signal::Interrupt.into()).unwrap();
    // WaitForInterruptHandler is async-signal safe.
    let scoped_handler =
        ScopedSignalHandler::new::<WaitForInterruptHandler>(&[Signal::Interrupt]).unwrap();

    // Record this thread as the one that SIGINT should be forwarded to.
    let waiter_tid = gettid();
    WAIT_FOR_INTERRUPT_THREAD_ID.store(waiter_tid, Ordering::SeqCst);

    let sender = spawn(move || -> result::Result<(), errno::Error> {
        // Wait until the thread is ready to receive the signal.
        wait_for_thread_to_sleep(waiter_tid, Duration::from_secs(10)).unwrap();
        // Safe because the SIGINT handler is safe.
        unsafe { kill(waiter_tid, Signal::Interrupt.into()) }
    });

    let wait_result = wait_for_interrupt();
    let sender_result = sender.join();

    // Clean up before checking the results so a failure does not leave handlers behind.
    drop(scoped_handler);
    // Safe because we are restoring the previous SIGINT handler.
    unsafe { restore_sigaction(Signal::Interrupt, previous_action) }.unwrap();

    wait_result.unwrap();
    sender_result.unwrap().unwrap();
}
}