base: move read/write wrappers to their own module.

In the future we'll be using read/write wrappers in multiple places,
but the wrapping code currently lives in the named pipe implementation.
This CL pulls it out to its own module. It also brings in some other
code in win_util that wasn't upstreamed.

BUG=b:272614458
TEST=presubmit

Change-Id: I6871f3db6991336f42706652b69935755bf2fbc3
Reviewed-on: https://chromium-review.googlesource.com/c/crosvm/crosvm/+/4326942
Commit-Queue: Noah Gold <nkgold@google.com>
Reviewed-by: Daniel Verkamp <dverkamp@chromium.org>
diff --git a/base/src/sys/windows/mod.rs b/base/src/sys/windows/mod.rs
index d0809d9..a09a3e6 100644
--- a/base/src/sys/windows/mod.rs
+++ b/base/src/sys/windows/mod.rs
@@ -26,6 +26,7 @@
 mod priority;
 // Add conditional compile?
 mod punch_hole;
+mod read_write_wrappers;
 mod sched;
 mod shm;
 mod shm_platform;
@@ -50,6 +51,7 @@
 pub(crate) use mmap_platform::PROT_WRITE;
 pub use priority::*;
 pub(crate) use punch_hole::file_punch_hole;
+pub use read_write_wrappers::*;
 pub use sched::*;
 pub use shm::*;
 pub use shm_platform::*;
diff --git a/base/src/sys/windows/named_pipes.rs b/base/src/sys/windows/named_pipes.rs
index bdd2c4f..df94719 100644
--- a/base/src/sys/windows/named_pipes.rs
+++ b/base/src/sys/windows/named_pipes.rs
@@ -16,12 +16,11 @@
 use rand::Rng;
 use serde::Deserialize;
 use serde::Serialize;
+use win_util::fail_if_zero;
 use win_util::SecurityAttributes;
 use win_util::SelfRelativeSecurityDescriptor;
 use winapi::shared::minwindef::DWORD;
 use winapi::shared::minwindef::FALSE;
-use winapi::shared::minwindef::LPCVOID;
-use winapi::shared::minwindef::LPVOID;
 use winapi::shared::minwindef::TRUE;
 use winapi::shared::winerror::ERROR_IO_INCOMPLETE;
 use winapi::shared::winerror::ERROR_IO_PENDING;
@@ -29,13 +28,12 @@
 use winapi::shared::winerror::ERROR_PIPE_CONNECTED;
 use winapi::um::errhandlingapi::GetLastError;
 use winapi::um::fileapi::FlushFileBuffers;
-use winapi::um::fileapi::ReadFile;
-use winapi::um::fileapi::WriteFile;
 use winapi::um::handleapi::INVALID_HANDLE_VALUE;
 use winapi::um::ioapiset::CancelIoEx;
 use winapi::um::ioapiset::GetOverlappedResult;
 use winapi::um::minwinbase::OVERLAPPED;
 use winapi::um::namedpipeapi::ConnectNamedPipe;
+use winapi::um::namedpipeapi::DisconnectNamedPipe;
 use winapi::um::namedpipeapi::GetNamedPipeInfo;
 use winapi::um::namedpipeapi::PeekNamedPipe;
 use winapi::um::namedpipeapi::SetNamedPipeHandleState;
@@ -120,7 +118,13 @@
         } else {
             None
         };
-        overlapped.hEvent = h_event.as_ref().unwrap().as_raw_descriptor();
+
+        overlapped.hEvent = if let Some(event) = h_event.as_ref() {
+            event.as_raw_descriptor()
+        } else {
+            0 as RawDescriptor
+        };
+
         Ok(OverlappedWrapper {
             overlapped: Box::new(overlapped),
             h_event,
@@ -563,49 +567,26 @@
         buf: &mut [T],
         overlapped: Option<&mut OVERLAPPED>,
     ) -> Result<usize> {
-        let max_bytes_to_read: DWORD = mem::size_of_val(buf) as DWORD;
-        // Used to verify if ERROR_IO_PENDING should be an error.
-        let is_overlapped = overlapped.is_some();
-
-        // Safe because we cap the size of the read to the size of the buffer
-        // and check the return code
-        let mut bytes_read: DWORD = 0;
-        let success_flag = ReadFile(
-            handle.as_raw_descriptor(),
-            buf.as_ptr() as LPVOID,
-            max_bytes_to_read,
-            match overlapped {
-                Some(_) => std::ptr::null_mut(),
-                None => &mut bytes_read,
-            },
-            match overlapped {
-                Some(v) => v,
-                None => std::ptr::null_mut(),
-            },
+        let res = crate::platform::read_file(
+            handle,
+            buf.as_mut_ptr() as *mut u8,
+            mem::size_of_val(buf),
+            overlapped,
         );
-
-        if success_flag == 0 {
-            let e = io::Error::last_os_error();
-            match e.raw_os_error() {
-                Some(error_code)
-                    if blocking_mode == BlockingMode::NoWait
-                        && error_code == ERROR_NO_DATA as i32 =>
-                {
-                    // A NOWAIT pipe will return ERROR_NO_DATA when no data is available; however,
-                    // this code is interpreted as a std::io::ErrorKind::BrokenPipe, which is not
-                    // correct. For further details see:
-                    // https://docs.microsoft.com/en-us/windows/win32/debug/system-error-codes--0-499-
-                    // https://docs.microsoft.com/en-us/windows/win32/ipc/named-pipe-type-read-and-wait-modes
-                    Err(std::io::Error::new(std::io::ErrorKind::WouldBlock, e))
-                }
-                // ERROR_IO_PENDING, according the to docs, isn't really an error. This just means
-                // that the ReadFile operation hasn't completed. In this case,
-                // `get_overlapped_result` will wait until the operation is completed.
-                Some(error_code) if error_code == ERROR_IO_PENDING as i32 && is_overlapped => Ok(0),
-                _ => Err(e),
+        match res {
+            Ok(bytes_read) => Ok(bytes_read),
+            Err(e)
+                if blocking_mode == BlockingMode::NoWait
+                    && e.raw_os_error() == Some(ERROR_NO_DATA as i32) =>
+            {
+                // A NOWAIT pipe will return ERROR_NO_DATA when no data is available; however,
+                // this code is interpreted as a std::io::ErrorKind::BrokenPipe, which is not
+                // correct. For further details see:
+                // https://docs.microsoft.com/en-us/windows/win32/debug/system-error-codes--0-499-
+                // https://docs.microsoft.com/en-us/windows/win32/ipc/named-pipe-type-read-and-wait-modes
+                Err(std::io::Error::new(std::io::ErrorKind::WouldBlock, e))
             }
-        } else {
-            Ok(bytes_read as usize)
+            Err(e) => Err(e),
         }
     }
 
@@ -702,7 +683,7 @@
 
         // Safe because the underlying pipe handle is guaranteed to be open, and the output values
         // live at valid memory locations.
-        let res = unsafe {
+        fail_if_zero!(unsafe {
             PeekNamedPipe(
                 self.as_raw_descriptor(),
                 ptr::null_mut(),
@@ -711,13 +692,9 @@
                 &mut total_bytes_avail,
                 ptr::null_mut(),
             )
-        };
+        });
 
-        if res == 0 {
-            Err(io::Error::last_os_error())
-        } else {
-            Ok(total_bytes_avail)
-        }
+        Ok(total_bytes_avail)
     }
 
     /// Writes the bytes from a slice into the pipe. Returns the number of bytes written, which
@@ -762,6 +739,8 @@
     /// In order to get how many bytes were written, call `get_overlapped_result`. That function will
     /// also help with waiting until the write operation is complete. The pipe must be opened in
     /// overlapped otherwise there may be unexpected behavior.
+    ///
+    /// WARNING: this function is unsafe. TODO(b/272812234): mark unsafe.
     pub fn write_overlapped<T: PipeSendable>(
         &mut self,
         buf: &[T],
@@ -784,42 +763,21 @@
     }
 
     /// Helper for `write_overlapped` and `write`.
+    /// WARNING: this function is unsafe for overlapped IO. TODO(b/272812234): mark unsafe.
     fn write_internal<T: PipeSendable>(
         handle: &SafeDescriptor,
         buf: &[T],
         overlapped: Option<&mut OVERLAPPED>,
     ) -> Result<usize> {
-        let bytes_to_write: DWORD = mem::size_of_val(buf) as DWORD;
-        let is_overlapped = overlapped.is_some();
-
-        // Safe because buf points to a valid region of memory whose size we have computed,
-        // pipe has not been closed (as it's managed by this object), and we check the return
-        // value for any errors
+        // Safe because buf points to memory valid until the write completes and we pass a valid
+        // length for that memory.
         unsafe {
-            let mut bytes_written: DWORD = 0;
-            let success_flag = WriteFile(
-                handle.as_raw_descriptor(),
-                buf.as_ptr() as LPCVOID,
-                bytes_to_write,
-                match overlapped {
-                    Some(_) => std::ptr::null_mut(),
-                    None => &mut bytes_written,
-                },
-                match overlapped {
-                    Some(v) => v,
-                    None => std::ptr::null_mut(),
-                },
-            );
-
-            if success_flag == 0 {
-                let err = io::Error::last_os_error().raw_os_error().unwrap() as u32;
-                if err == ERROR_IO_PENDING && is_overlapped {
-                    return Ok(0);
-                }
-                Err(io::Error::last_os_error())
-            } else {
-                Ok(bytes_written as usize)
-            }
+            crate::platform::write_file(
+                handle,
+                buf.as_ptr() as *const u8,
+                mem::size_of_val(buf),
+                overlapped,
+            )
         }
     }
 
@@ -834,15 +792,58 @@
 
     /// For a server named pipe, waits for a client to connect
     pub fn wait_for_client_connection(&self) -> Result<()> {
+        let mut overlapped_wrapper = OverlappedWrapper::new(/* include_event = */ true)?;
+        self.wait_for_client_connection_internal(
+            &mut overlapped_wrapper,
+            /* should_block = */ true,
+        )
+    }
+
+    /// For a server named pipe, waits for a client to connect using the given overlapped wrapper
+    /// to signal connection.
+    pub fn wait_for_client_connection_overlapped(
+        &self,
+        overlapped_wrapper: &mut OverlappedWrapper,
+    ) -> Result<()> {
+        self.wait_for_client_connection_internal(
+            overlapped_wrapper,
+            /* should_block = */ false,
+        )
+    }
+
+    fn wait_for_client_connection_internal(
+        &self,
+        overlapped_wrapper: &mut OverlappedWrapper,
+        should_block: bool,
+    ) -> Result<()> {
         // Safe because the handle is valid and we're checking the return
         // code according to the documentation
         unsafe {
             let success_flag = ConnectNamedPipe(
                 self.as_raw_descriptor(),
-                /* lpOverlapped= */ ptr::null_mut(),
+                // Note: The overlapped structure is only used if the pipe was opened in
+                // OVERLAPPED mode, but is necessary in that case.
+                &mut *overlapped_wrapper.overlapped,
             );
-            if success_flag == 0 && GetLastError() != ERROR_PIPE_CONNECTED {
-                return Err(io::Error::last_os_error());
+            if success_flag == 0 {
+                return match GetLastError() {
+                    ERROR_PIPE_CONNECTED => {
+                        if !should_block {
+                            // If async, make sure the event is signalled to indicate the client
+                            // is ready.
+                            overlapped_wrapper.get_h_event_ref().unwrap().signal()?;
+                        }
+
+                        Ok(())
+                    }
+                    ERROR_IO_PENDING => {
+                        if should_block {
+                            overlapped_wrapper.get_h_event_ref().unwrap().wait()?;
+                        }
+                        Ok(())
+                    }
+                    err => Err(io::Error::from_raw_os_error(err as i32)),
+                };
             }
         }
         Ok(())
@@ -902,35 +903,29 @@
         let mut size_transferred = 0;
         // Safe as long as `overlapped_struct` isn't copied and also contains a valid event.
         // Also the named pipe handle must created with `FILE_FLAG_OVERLAPPED`.
-        let res = unsafe {
+        fail_if_zero!(unsafe {
             GetOverlappedResult(
                 self.handle.as_raw_descriptor(),
                 &mut *overlapped_wrapper.overlapped,
                 &mut size_transferred,
                 if wait { TRUE } else { FALSE },
             )
-        };
-        if res == 0 {
-            Err(io::Error::last_os_error())
-        } else {
-            Ok(size_transferred)
-        }
+        });
+
+        Ok(size_transferred)
     }
 
     /// Cancels I/O Operations in the current process. Since `lpOverlapped` is null, this will
     /// cancel all I/O requests for the file handle passed in.
     pub fn cancel_io(&mut self) -> Result<()> {
-        let res = unsafe {
+        fail_if_zero!(unsafe {
             CancelIoEx(
                 self.handle.as_raw_descriptor(),
                 /* lpOverlapped= */ std::ptr::null_mut(),
             )
-        };
-        if res == 0 {
-            Err(io::Error::last_os_error())
-        } else {
-            Ok(())
-        }
+        });
+
+        Ok(())
     }
 
     /// Get the framing mode of the pipe.
@@ -957,7 +952,7 @@
         }
         // Safe because we have allocated all pointers and own
         // them as mutable.
-        let res = unsafe {
+        fail_if_zero!(unsafe {
             GetNamedPipeInfo(
                 self.as_raw_descriptor(),
                 flags as *mut u32,
@@ -965,17 +960,13 @@
                 incoming_buffer_size as *mut u32,
                 max_instances as *mut u32,
             )
-        };
+        });
 
-        if res == 0 {
-            Err(io::Error::last_os_error())
-        } else {
-            Ok(NamedPipeInfo {
-                outgoing_buffer_size,
-                incoming_buffer_size,
-                max_instances,
-            })
-        }
+        Ok(NamedPipeInfo {
+            outgoing_buffer_size,
+            incoming_buffer_size,
+            max_instances,
+        })
     }
 
     /// For a server pipe, flush the pipe contents. This will
@@ -985,12 +976,16 @@
     pub fn flush_data_blocking(&self) -> Result<()> {
         // Safe because the only buffers interacted with are
         // outside of Rust memory
-        let res = unsafe { FlushFileBuffers(self.as_raw_descriptor()) };
-        if res == 0 {
-            Err(io::Error::last_os_error())
-        } else {
-            Ok(())
-        }
+        fail_if_zero!(unsafe { FlushFileBuffers(self.as_raw_descriptor()) });
+        Ok(())
+    }
+
+    /// For a server pipe, disconnect all clients, discarding any buffered data.
+    pub fn disconnect_clients(&self) -> Result<()> {
+        // Safe because we own the handle passed in and know it will remain valid for the duration
+        // of the call. Discarded buffers are not managed by rust.
+        fail_if_zero!(unsafe { DisconnectNamedPipe(self.as_raw_descriptor()) });
+        Ok(())
     }
 }
 
diff --git a/base/src/sys/windows/read_write_wrappers.rs b/base/src/sys/windows/read_write_wrappers.rs
new file mode 100644
index 0000000..b10c86e
--- /dev/null
+++ b/base/src/sys/windows/read_write_wrappers.rs
@@ -0,0 +1,101 @@
+// Copyright 2022 The ChromiumOS Authors
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+use std::io;
+
+use winapi::shared::minwindef::DWORD;
+use winapi::shared::minwindef::LPCVOID;
+use winapi::shared::minwindef::LPVOID;
+use winapi::shared::winerror::ERROR_IO_PENDING;
+use winapi::um::fileapi::ReadFile;
+use winapi::um::fileapi::WriteFile;
+use winapi::um::minwinbase::OVERLAPPED;
+
+use crate::AsRawDescriptor;
+
+/// Safety requirements:
+/// 1. buf points to memory that will not be freed until the write operation completes.
+/// 2. buf points to at least buf_len bytes.
+pub unsafe fn write_file(
+    handle: &dyn AsRawDescriptor,
+    buf: *const u8,
+    buf_len: usize,
+    overlapped: Option<&mut OVERLAPPED>,
+) -> io::Result<usize> {
+    let is_overlapped = overlapped.is_some();
+
+    // Safe because buf points to a valid region of memory whose size we have computed,
+    // pipe has not been closed (as it's managed by this object), and we check the return
+    // value for any errors
+    let mut bytes_written: DWORD = 0;
+    let success_flag = WriteFile(
+        handle.as_raw_descriptor(),
+        buf as LPCVOID,
+        buf_len
+            .try_into()
+            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?,
+        match overlapped {
+            Some(_) => std::ptr::null_mut(),
+            None => &mut bytes_written,
+        },
+        match overlapped {
+            Some(v) => v,
+            None => std::ptr::null_mut(),
+        },
+    );
+
+    if success_flag == 0 {
+        let err = io::Error::last_os_error();
+        if Some(ERROR_IO_PENDING as i32) == err.raw_os_error() && is_overlapped {
+            Ok(0)
+        } else {
+            Err(err)
+        }
+    } else {
+        Ok(bytes_written as usize)
+    }
+}
+
+/// Safety requirements:
+/// 1. buf points to memory that will not be freed until the read operation completes.
+/// 2. buf points to at least buf_len bytes.
+pub unsafe fn read_file(
+    handle: &dyn AsRawDescriptor,
+    buf: *mut u8,
+    buf_len: usize,
+    overlapped: Option<&mut OVERLAPPED>,
+) -> io::Result<usize> {
+    // Used to verify if ERROR_IO_PENDING should be an error.
+    let is_overlapped = overlapped.is_some();
+
+    // Safe because we cap the size of the read to the size of the buffer
+    // and check the return code
+    let mut bytes_read: DWORD = 0;
+    let success_flag = ReadFile(
+        handle.as_raw_descriptor(),
+        buf as LPVOID,
+        buf_len as DWORD,
+        match overlapped {
+            Some(_) => std::ptr::null_mut(),
+            None => &mut bytes_read,
+        },
+        match overlapped {
+            Some(v) => v,
+            None => std::ptr::null_mut(),
+        },
+    );
+
+    if success_flag == 0 {
+        let e = io::Error::last_os_error();
+        match e.raw_os_error() {
+            // ERROR_IO_PENDING, according the to docs, isn't really an error. This just means
+            // that the ReadFile operation hasn't completed. In this case,
+            // `get_overlapped_result` will wait until the operation is completed.
+            Some(error_code) if error_code == ERROR_IO_PENDING as i32 && is_overlapped => Ok(0),
+            _ => Err(e),
+        }
+    } else {
+        Ok(bytes_read as usize)
+    }
+}
diff --git a/win_util/src/lib.rs b/win_util/src/lib.rs
index 660d51d..1fecf65 100644
--- a/win_util/src/lib.rs
+++ b/win_util/src/lib.rs
@@ -68,6 +68,15 @@
     };
 }
 
+#[macro_export]
+macro_rules! fail_if_zero {
+    ($syscall:expr) => {
+        if $syscall == 0 {
+            return Err(io::Error::last_os_error());
+        }
+    };
+}
+
 /// Returns the lower 32 bits of a u64 as a u32 (c_ulong/DWORD)
 pub fn get_low_order(number: u64) -> c_ulong {
     (number & (u32::max_value() as u64)) as c_ulong