blob: 40677bbda54694de2daf6d2fcf12533a72eb09ac [file]
// Copyright 2022 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
use std::ffi::c_void;
use std::ffi::OsString;
use std::io;
use std::ptr;
use winapi::shared::minwindef::ULONG;
use winapi::um::winnt::PVOID;
use super::unicode_string_to_os_string;
// Required for Windows API FFI bindings, as the names of the FFI structs and
// functions get called out by the linter.
#[allow(non_upper_case_globals)]
#[allow(non_camel_case_types)]
#[allow(non_snake_case)]
#[allow(dead_code)]
mod dll_notification_sys {
    use std::io;

    use winapi::shared::minwindef::FARPROC;
    use winapi::shared::minwindef::ULONG;
    use winapi::shared::ntdef::NTSTATUS;
    use winapi::shared::ntdef::PCUNICODE_STRING;
    use winapi::shared::ntstatus::STATUS_SUCCESS;
    use winapi::um::libloaderapi::GetModuleHandleA;
    use winapi::um::libloaderapi::GetProcAddress;
    use winapi::um::winnt::CHAR;
    use winapi::um::winnt::PVOID;

    /// Data passed to an `LdrDllNotification` callback. Which field is valid
    /// depends on the callback's `NotificationReason`: `Loaded` for
    /// `LDR_DLL_NOTIFICATION_REASON_LOADED`, `Unloaded` for
    /// `LDR_DLL_NOTIFICATION_REASON_UNLOADED`.
    #[repr(C)]
    pub union _LDR_DLL_NOTIFICATION_DATA {
        pub Loaded: LDR_DLL_LOADED_NOTIFICATION_DATA,
        pub Unloaded: LDR_DLL_UNLOADED_NOTIFICATION_DATA,
    }
    pub type LDR_DLL_NOTIFICATION_DATA = _LDR_DLL_NOTIFICATION_DATA;
    pub type PLDR_DLL_NOTIFICATION_DATA = *mut LDR_DLL_NOTIFICATION_DATA;

    #[repr(C)]
    #[derive(Debug, Copy, Clone)]
    pub struct _LDR_DLL_LOADED_NOTIFICATION_DATA {
        pub Flags: ULONG, // Reserved.
        pub FullDllName: PCUNICODE_STRING, // The full path name of the DLL module.
        pub BaseDllName: PCUNICODE_STRING, // The base file name of the DLL module.
        pub DllBase: PVOID, // A pointer to the base address for the DLL in memory.
        pub SizeOfImage: ULONG, // The size of the DLL image, in bytes.
    }
    pub type LDR_DLL_LOADED_NOTIFICATION_DATA = _LDR_DLL_LOADED_NOTIFICATION_DATA;
    pub type PLDR_DLL_LOADED_NOTIFICATION_DATA = *mut LDR_DLL_LOADED_NOTIFICATION_DATA;

    #[repr(C)]
    #[derive(Debug, Copy, Clone)]
    pub struct _LDR_DLL_UNLOADED_NOTIFICATION_DATA {
        pub Flags: ULONG, // Reserved.
        pub FullDllName: PCUNICODE_STRING, // The full path name of the DLL module.
        pub BaseDllName: PCUNICODE_STRING, // The base file name of the DLL module.
        pub DllBase: PVOID, // A pointer to the base address for the DLL in memory.
        pub SizeOfImage: ULONG, // The size of the DLL image, in bytes.
    }
    pub type LDR_DLL_UNLOADED_NOTIFICATION_DATA = _LDR_DLL_UNLOADED_NOTIFICATION_DATA;
    pub type PLDR_DLL_UNLOADED_NOTIFICATION_DATA = *mut LDR_DLL_UNLOADED_NOTIFICATION_DATA;

    /// `NotificationReason` value meaning the `Loaded` union field is valid.
    pub const LDR_DLL_NOTIFICATION_REASON_LOADED: ULONG = 1;
    /// `NotificationReason` value meaning the `Unloaded` union field is valid.
    pub const LDR_DLL_NOTIFICATION_REASON_UNLOADED: ULONG = 2;

    // NUL-terminated ASCII names handed to GetModuleHandleA / GetProcAddress.
    const NTDLL: &[u8] = b"ntdll\0";
    const LDR_REGISTER_DLL_NOTIFICATION: &[u8] = b"LdrRegisterDllNotification\0";
    const LDR_UNREGISTER_DLL_NOTIFICATION: &[u8] = b"LdrUnregisterDllNotification\0";

    /// Signature of the loader's notification callback. NTDLL declares it
    /// `CALLBACK` (`__stdcall`), which is `extern "system"` in Rust. On 32-bit
    /// x86 `extern "C"` (`__cdecl`) would disagree on who pops the arguments.
    pub type LdrDllNotification = unsafe extern "system" fn(
        NotificationReason: ULONG,
        NotificationData: PLDR_DLL_NOTIFICATION_DATA,
        Context: PVOID,
    );
    // Both NTDLL entry points are `NTAPI` (`__stdcall`) => `extern "system"`.
    pub type FnLdrRegisterDllNotification =
        unsafe extern "system" fn(ULONG, LdrDllNotification, PVOID, *mut PVOID) -> NTSTATUS;
    pub type FnLdrUnregisterDllNotification = unsafe extern "system" fn(PVOID) -> NTSTATUS;

    // `RtlNtStatusToDosError` is also `NTAPI` (`__stdcall`).
    extern "system" {
        pub fn RtlNtStatusToDosError(Status: NTSTATUS) -> ULONG;
    }

    /// Looks up the exported symbol `proc_name` in the already-loaded `ntdll`
    /// module.
    ///
    /// # Errors
    /// Returns the thread's last OS error if the module handle or the symbol
    /// cannot be obtained. On success the returned address is non-null.
    ///
    /// # Panics
    /// Panics if `proc_name` is not NUL-terminated.
    fn get_ntdll_proc(proc_name: &[u8]) -> io::Result<FARPROC> {
        assert_eq!(
            proc_name.last(),
            Some(&0),
            "proc name must be NUL-terminated"
        );
        // SAFETY: `NTDLL` is a static, NUL-terminated ASCII string.
        let module = unsafe { GetModuleHandleA(/* lpModuleName= */ NTDLL.as_ptr() as *const CHAR) };
        if module.is_null() {
            return Err(io::Error::last_os_error());
        }
        // SAFETY: `module` is a non-null module handle and `proc_name` is
        // NUL-terminated (asserted above).
        let proc_addr = unsafe {
            GetProcAddress(
                /* hModule= */ module,
                /* lpProcName= */ proc_name.as_ptr() as *const CHAR,
            )
        };
        if proc_addr.is_null() {
            return Err(io::Error::last_os_error());
        }
        Ok(proc_addr)
    }

    /// Maps `STATUS_SUCCESS` to `Ok(())`. Any other status becomes the Win32
    /// error code reported by `RtlNtStatusToDosError`.
    fn nt_status_to_result(status: NTSTATUS) -> io::Result<()> {
        if status == STATUS_SUCCESS {
            return Ok(());
        }
        // SAFETY: The function takes a plain integer and no pointers.
        let win32_error = unsafe { RtlNtStatusToDosError(/* Status= */ status) };
        // Win32 error codes fit in an i32, which is what `io::Error` stores.
        Err(io::Error::from_raw_os_error(win32_error as i32))
    }

    /// Wrapper for the NTDLL `LdrRegisterDllNotification` function. Dynamically
    /// gets the address of the function and invokes the function with the given
    /// arguments.
    ///
    /// # Errors
    /// Fails if the function cannot be resolved in `ntdll`, or if it returns a
    /// status other than `STATUS_SUCCESS`.
    ///
    /// # Safety
    /// Unsafe as this function does not verify its arguments; the caller is
    /// expected to verify the safety as if invoking the underlying C function.
    pub unsafe fn LdrRegisterDllNotification(
        Flags: ULONG,
        NotificationFunction: LdrDllNotification,
        Context: PVOID,
        Cookie: *mut PVOID,
    ) -> io::Result<()> {
        let proc_addr = get_ntdll_proc(LDR_REGISTER_DLL_NOTIFICATION)?;
        let ldr_register_dll_notification: FnLdrRegisterDllNotification =
            std::mem::transmute(proc_addr);
        nt_status_to_result(ldr_register_dll_notification(
            Flags,
            NotificationFunction,
            Context,
            Cookie,
        ))
    }

    /// Wrapper for the NTDLL `LdrUnregisterDllNotification` function. Dynamically
    /// gets the address of the function and invokes the function with the given
    /// arguments.
    ///
    /// # Errors
    /// Fails if the function cannot be resolved in `ntdll`, or if it returns a
    /// status other than `STATUS_SUCCESS`.
    ///
    /// # Safety
    /// Unsafe as this function does not verify its arguments; the caller is
    /// expected to verify the safety as if invoking the underlying C function.
    pub unsafe fn LdrUnregisterDllNotification(Cookie: PVOID) -> io::Result<()> {
        let proc_addr = get_ntdll_proc(LDR_UNREGISTER_DLL_NOTIFICATION)?;
        let ldr_unregister_dll_notification: FnLdrUnregisterDllNotification =
            std::mem::transmute(proc_addr);
        nt_status_to_result(ldr_unregister_dll_notification(Cookie))
    }
}

use dll_notification_sys::*;

/// Names of a DLL reported by a load or unload notification.
#[derive(Debug)]
pub struct DllNotificationData {
    /// The full path name of the DLL module.
    pub full_dll_name: OsString,
    /// The base file name of the DLL module.
    pub base_dll_name: OsString,
}

/// Callback context wrapper for DLL load notification functions.
///
/// This struct provides a wrapper for invoking a function-like type any time a
/// DLL is loaded in the current process. This is done in a type-safe way,
/// provided that users of this struct observe some safety invariants.
///
/// # Safety
/// The struct instance must not be used once it has been registered as a
/// notification target. The callback function assumes that it has a mutable
/// reference to the struct instance. Only once the callback is unregistered is
/// it safe to re-use the struct instance.
struct CallbackContext<F1, F2>
where
    F1: FnMut(DllNotificationData),
    F2: FnMut(DllNotificationData),
{
    loaded_callback: F1,
    unloaded_callback: F2,
}

impl<F1, F2> CallbackContext<F1, F2>
where
    F1: FnMut(DllNotificationData),
    F2: FnMut(DllNotificationData),
{
    /// Create a new `CallbackContext` with the two callback functions. Takes
    /// two callbacks, a `loaded_callback` which is called when a DLL is
    /// loaded, and `unloaded_callback` which is called when a DLL is unloaded.
    pub fn new(loaded_callback: F1, unloaded_callback: F2) -> Self {
        CallbackContext {
            loaded_callback,
            unloaded_callback,
        }
    }

    /// Provides a notification function that can be passed to the
    /// `LdrRegisterDllNotification` function.
    pub fn get_notification_function(&self) -> LdrDllNotification {
        Self::notification_function
    }

    /// The notification function, using the `extern "system"` calling
    /// convention that the loader expects. It assumes exclusive access to the
    /// instance of the struct passed through the `context` parameter.
    ///
    /// It dispatches to `loaded_callback` or `unloaded_callback` according to
    /// `notification_reason`. Other reason values are ignored.
    ///
    /// # Panics
    /// Panics if `context`, `notification_data`, or either DLL name pointer is
    /// null. A panic cannot unwind out of this `extern` function into the
    /// loader.
    extern "system" fn notification_function(
        notification_reason: ULONG,
        notification_data: PLDR_DLL_NOTIFICATION_DATA,
        context: PVOID,
    ) {
        let callback_context =
            // SAFETY: `DllWatcher` registers a pointer to a live CallbackContext instance and
            // guarantees that we have exclusive access to it while it is registered.
            unsafe { (context as *mut Self).as_mut() }.expect("context was null");
        assert!(!notification_data.is_null());
        match notification_reason {
            LDR_DLL_NOTIFICATION_REASON_LOADED => {
                // SAFETY: The union holds LDR_DLL_LOADED_NOTIFICATION_DATA because we got
                // LDR_DLL_NOTIFICATION_REASON_LOADED as the notification reason, and the pointer
                // is non-null (asserted above). The loader's data is only read, so a shared
                // reference is used.
                let loaded = unsafe { &(*notification_data).Loaded };
                // SAFETY: We expect the OS to provide valid UNICODE_STRING structs for the
                // duration of this callback. The callee asserts the pointers are non-null.
                let data = unsafe {
                    dll_notification_data_from_names(loaded.FullDllName, loaded.BaseDllName)
                };
                (callback_context.loaded_callback)(data);
            }
            LDR_DLL_NOTIFICATION_REASON_UNLOADED => {
                // SAFETY: The union holds LDR_DLL_UNLOADED_NOTIFICATION_DATA because we got
                // LDR_DLL_NOTIFICATION_REASON_UNLOADED as the notification reason, and the
                // pointer is non-null (asserted above). The data is only read.
                let unloaded = unsafe { &(*notification_data).Unloaded };
                // SAFETY: We expect the OS to provide valid UNICODE_STRING structs for the
                // duration of this callback. The callee asserts the pointers are non-null.
                let data = unsafe {
                    dll_notification_data_from_names(unloaded.FullDllName, unloaded.BaseDllName)
                };
                (callback_context.unloaded_callback)(data);
            }
            // Only the two reasons above are documented. A panic cannot unwind out of this
            // callback into the loader, so unrecognized (e.g. future) reasons are ignored rather
            // than treated as fatal.
            _ => {}
        }
    }
}

/// Copies the full and base DLL names of a notification into an owned
/// `DllNotificationData`.
///
/// # Panics
/// Panics if either pointer is null.
///
/// # Safety
/// Each non-null pointer must reference a valid `UNICODE_STRING` for the
/// duration of the call.
unsafe fn dll_notification_data_from_names(
    full_dll_name: winapi::shared::ntdef::PCUNICODE_STRING,
    base_dll_name: winapi::shared::ntdef::PCUNICODE_STRING,
) -> DllNotificationData {
    assert!(!base_dll_name.is_null());
    // SAFETY: The pointer is non-null (asserted above) and the caller guarantees it references a
    // valid UNICODE_STRING.
    let base_dll_name = unsafe { unicode_string_to_os_string(&*base_dll_name) };
    assert!(!full_dll_name.is_null());
    // SAFETY: The pointer is non-null (asserted above) and the caller guarantees it references a
    // valid UNICODE_STRING.
    let full_dll_name = unsafe { unicode_string_to_os_string(&*full_dll_name) };
    DllNotificationData {
        full_dll_name,
        base_dll_name,
    }
}
/// DLL watcher for monitoring DLL loads/unloads.
///
/// Provides a method to invoke a function-like type any time a DLL
/// is loaded or unloaded in the current process. The callbacks are
/// registered by `DllWatcher::new` and unregistered when the watcher is
/// dropped.
pub struct DllWatcher<F1, F2>
where
    F1: FnMut(DllNotificationData),
    F2: FnMut(DllNotificationData),
{
    // Callback context whose address is handed to NTDLL as the notification
    // `Context`. It is held as a pointer obtained from `Box::leak` rather than
    // as a `Box`. Moving the watcher therefore never re-asserts unique ownership
    // of memory the loader holds a pointer to, and `Drop` can deliberately leak
    // the context if unregistration fails.
    context: ptr::NonNull<CallbackContext<F1, F2>>,
    // Registration cookie returned by `LdrRegisterDllNotification`. `None`
    // before registration and after unregistration.
    cookie: Option<ptr::NonNull<c_void>>,
    // Tells the compiler (variance / drop check) that the watcher owns the
    // `CallbackContext<F1, F2>` behind `context`.
    _context_owner: std::marker::PhantomData<Box<CallbackContext<F1, F2>>>,
}

impl<F1, F2> DllWatcher<F1, F2>
where
    F1: FnMut(DllNotificationData),
    F2: FnMut(DllNotificationData),
{
    /// Create a new `DllWatcher` with the two callback functions. Takes two
    /// callbacks, a `loaded_callback` which is called when a DLL is loaded,
    /// and `unloaded_callback` which is called when a DLL is unloaded.
    ///
    /// # Errors
    /// Returns the error reported while resolving or calling
    /// `LdrRegisterDllNotification`.
    pub fn new(loaded_callback: F1, unloaded_callback: F2) -> io::Result<Self> {
        let context = Box::new(CallbackContext::new(loaded_callback, unloaded_callback));
        let notification_function = context.get_notification_function();
        // From here on the watcher owns the allocation; it is only reclaimed in `Drop`
        // (including when registration below fails and `watcher` is dropped early).
        let mut watcher = Self {
            context: ptr::NonNull::from(Box::leak(context)),
            cookie: None,
            _context_owner: std::marker::PhantomData,
        };
        let mut cookie: PVOID = ptr::null_mut();
        // SAFETY: `watcher.context` points to a live CallbackContext that stays allocated until
        // the notification is unregistered in `Drop`, and nothing else accesses it meanwhile, so
        // the notification function has exclusive access to it. `cookie` is a valid out-pointer.
        unsafe {
            LdrRegisterDllNotification(
                /* Flags= */ 0,
                /* NotificationFunction= */ notification_function,
                /* Context= */ watcher.context.as_ptr() as PVOID,
                /* Cookie= */ &mut cookie as *mut PVOID,
            )?
        };
        watcher.cookie = ptr::NonNull::new(cookie);
        Ok(watcher)
    }

    /// Unregisters the notification if one is registered. This is a no-op when
    /// there is no cookie.
    fn unregister_dll_notification(&mut self) -> io::Result<()> {
        if let Some(c) = self.cookie.take() {
            // SAFETY: We guarantee that `Cookie` was previously initialized.
            unsafe {
                LdrUnregisterDllNotification(/* Cookie= */ c.as_ptr() as PVOID)?
            }
        }
        Ok(())
    }
}

impl<F1, F2> Drop for DllWatcher<F1, F2>
where
    F1: FnMut(DllNotificationData),
    F2: FnMut(DllNotificationData),
{
    /// Unregisters the notification and then frees the callback context.
    ///
    /// # Panics
    /// Panics if unregistration fails. In that case the context is leaked
    /// rather than freed, because the loader may still hold its address.
    fn drop(&mut self) {
        if let Err(e) = self.unregister_dll_notification() {
            // Freeing the context here could lead to a use-after-free if the loader invokes the
            // still-registered callback, so the context is intentionally leaked.
            panic!("error unregistering dll notification: {:?}", e);
        }
        // SAFETY: The notification is no longer registered (or never was), so the loader is not
        // expected to use the context pointer again. It was produced by `Box::leak` in `new`
        // and is reclaimed exactly once, here.
        drop(unsafe { Box::from_raw(self.context.as_ptr()) });
    }
}
#[cfg(test)]
mod tests {
    use std::collections::HashSet;
    use std::ffi::CString;
    use std::io;

    use winapi::shared::minwindef::FALSE;
    use winapi::shared::minwindef::TRUE;
    use winapi::um::handleapi::CloseHandle;
    use winapi::um::libloaderapi::FreeLibrary;
    use winapi::um::libloaderapi::LoadLibraryA;
    use winapi::um::synchapi::CreateEventA;
    use winapi::um::synchapi::SetEvent;
    use winapi::um::synchapi::WaitForSingleObject;
    use winapi::um::winbase::WAIT_OBJECT_0;

    use super::*;

    // Arbitrarily chosen DLLs for the load/unload tests. They were picked
    // because they are obscure enough that they are probably not already
    // loaded in the process, so load/unload notifications can be observed.
    //
    // Using a single DLL can lead to flakiness: the tests run in the same
    // process, and it can be hard to rely on the OS to clean up the DLL loaded
    // by one test before the other test runs. Using a different DLL per test
    // makes the tests more independent.
    const TEST_DLL_NAME_1: &str = "Imagehlp.dll";
    const TEST_DLL_NAME_2: &str = "dbghelp.dll";

    #[test]
    fn load_dll() {
        let test_dll_name = CString::new(TEST_DLL_NAME_1).expect("failed to create CString");
        let mut loaded_dlls: HashSet<OsString> = HashSet::new();
        let (h_module, load_error) = {
            let _watcher = DllWatcher::new(
                |data| {
                    loaded_dlls.insert(data.base_dll_name);
                },
                |_data| (),
            )
            .expect("failed to create DllWatcher");
            // SAFETY: We pass a valid C string in to the function.
            let h_module = unsafe { LoadLibraryA(test_dll_name.as_ptr()) };
            // Read the error code now. Dropping `_watcher` at the end of this block calls further
            // Windows APIs, which may overwrite the thread's last-error value before the
            // assertion below would read it.
            (h_module, io::Error::last_os_error())
        };
        assert!(
            !h_module.is_null(),
            "failed to load {}: {}",
            TEST_DLL_NAME_1,
            load_error
        );
        assert!(
            !loaded_dlls.is_empty(),
            "no DLL loads recorded by DLL watcher"
        );
        assert!(
            loaded_dlls.contains::<OsString>(&(TEST_DLL_NAME_1.to_owned().into())),
            "{} load wasn't recorded by DLL watcher",
            TEST_DLL_NAME_1
        );
        // SAFETY: We initialized h_module with a LoadLibraryA call.
        // FreeLibrary returns a nonzero BOOL on success.
        let success = unsafe { FreeLibrary(h_module) } != 0;
        assert!(
            success,
            "failed to free {}: {}",
            TEST_DLL_NAME_1,
            io::Error::last_os_error(),
        )
    }

    #[test]
    fn unload_dll() {
        let mut unloaded_dlls: HashSet<OsString> = HashSet::new();
        let event =
            // SAFETY: No pointers are passed. The handle may leak if the test fails.
            unsafe { CreateEventA(std::ptr::null_mut(), TRUE, FALSE, std::ptr::null_mut()) };
        assert!(
            !event.is_null(),
            "failed to create event; event was NULL: {}",
            io::Error::last_os_error()
        );
        {
            let test_dll_name = CString::new(TEST_DLL_NAME_2).expect("failed to create CString");
            let _watcher = DllWatcher::new(
                |_data| (),
                |data| {
                    unloaded_dlls.insert(data.base_dll_name);
                    // SAFETY: We assert that the event is valid above.
                    unsafe { SetEvent(event) };
                },
            )
            .expect("failed to create DllWatcher");
            // SAFETY: We pass a valid C string in to the function.
            let h_module = unsafe { LoadLibraryA(test_dll_name.as_ptr()) };
            assert!(
                !h_module.is_null(),
                "failed to load {}: {}",
                TEST_DLL_NAME_2,
                io::Error::last_os_error()
            );
            // SAFETY: We initialized h_module with a LoadLibraryA call.
            // FreeLibrary returns a nonzero BOOL on success.
            let success = unsafe { FreeLibrary(h_module) } != 0;
            assert!(
                success,
                "failed to free {}: {}",
                TEST_DLL_NAME_2,
                io::Error::last_os_error(),
            )
        };
        // SAFETY: We assert that the event is valid above.
        assert_eq!(unsafe { WaitForSingleObject(event, 5000) }, WAIT_OBJECT_0);
        assert!(
            !unloaded_dlls.is_empty(),
            "no DLL unloads recorded by DLL watcher"
        );
        assert!(
            unloaded_dlls.contains::<OsString>(&(TEST_DLL_NAME_2.to_owned().into())),
            "{} unload wasn't recorded by DLL watcher",
            TEST_DLL_NAME_2
        );
        // SAFETY: We assert that the event is valid above.
        unsafe { CloseHandle(event) };
    }
}