Snap for 10117275 from be908da35b79969eb11a853cb2090f26346cd821 to udc-d1-release

Change-Id: I058e740684398ed5e058e7cddb30a3495bc4dc5e
diff --git a/.cargo_vcs_info.json b/.cargo_vcs_info.json
index dc95465..d2f328c 100644
--- a/.cargo_vcs_info.json
+++ b/.cargo_vcs_info.json
@@ -1,6 +1,6 @@
 {
   "git": {
-    "sha1": "70f3f76626420e854f1d7cd1dbc8060c27d848cf"
+    "sha1": "8e52adace55c5e082ba2effffcb70bf480d76ec0"
   },
   "path_in_vcs": ""
 }
\ No newline at end of file
diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index b2ab149..025b7e9 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -78,7 +78,7 @@
     steps:
       - uses: actions/checkout@v2
       - name: Install QEMU
-        run: sudo apt update && sudo apt install ${{ matrix.packages }}
+        run: sudo apt update && sudo apt install ${{ matrix.packages }} && sudo chmod 666 /dev/vhost-vsock
       - uses: actions-rs/toolchain@v1
         with:
           profile: minimal
diff --git a/Android.bp b/Android.bp
index b8ee9ba..4e47048 100644
--- a/Android.bp
+++ b/Android.bp
@@ -24,7 +24,7 @@
     name: "libvirtio_drivers",
     crate_name: "virtio_drivers",
     cargo_env_compat: true,
-    cargo_pkg_version: "0.3.0",
+    cargo_pkg_version: "0.4.0",
     srcs: ["src/lib.rs"],
     edition: "2018",
     no_stdlibs: true,
@@ -43,7 +43,7 @@
     name: "virtio-drivers_test_src_lib",
     crate_name: "virtio_drivers",
     cargo_env_compat: true,
-    cargo_pkg_version: "0.3.0",
+    cargo_pkg_version: "0.4.0",
     srcs: ["src/lib.rs"],
     test_suites: ["general-tests"],
     auto_gen_config: true,
diff --git a/Cargo.toml b/Cargo.toml
index 17673fb..7f9968a 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -12,7 +12,7 @@
 [package]
 edition = "2018"
 name = "virtio-drivers"
-version = "0.3.0"
+version = "0.4.0"
 authors = [
     "Jiajie Chen <noc@jiegec.ac.cn>",
     "Runji Wang <wangrunji0408@163.com>",
diff --git a/Cargo.toml.orig b/Cargo.toml.orig
index 01bb35c..431ccd2 100644
--- a/Cargo.toml.orig
+++ b/Cargo.toml.orig
@@ -1,6 +1,6 @@
 [package]
 name = "virtio-drivers"
-version = "0.3.0"
+version = "0.4.0"
 license = "MIT"
 authors = [
   "Jiajie Chen <noc@jiegec.ac.cn>",
diff --git a/METADATA b/METADATA
index 3623029..b865aa4 100644
--- a/METADATA
+++ b/METADATA
@@ -11,13 +11,13 @@
   }
   url {
     type: ARCHIVE
-    value: "https://static.crates.io/crates/virtio-drivers/virtio-drivers-0.3.0.crate"
+    value: "https://static.crates.io/crates/virtio-drivers/virtio-drivers-0.4.0.crate"
   }
-  version: "0.3.0"
+  version: "0.4.0"
   license_type: NOTICE
   last_upgrade_date {
     year: 2023
-    month: 1
-    day: 24
+    month: 4
+    day: 19
   }
 }
diff --git a/README.md b/README.md
index fb1cda5..fdb61d8 100644
--- a/README.md
+++ b/README.md
@@ -17,6 +17,7 @@
 | GPU     | ✅        |
 | Input   | ✅        |
 | Console | ✅        |
+| Socket  | ✅        |
 | ...     | ❌        |
 
 ### Transports
diff --git a/src/device/blk.rs b/src/device/blk.rs
index 69528b6..d095047 100644
--- a/src/device/blk.rs
+++ b/src/device/blk.rs
@@ -109,7 +109,7 @@
         let mut resp = BlkResp::default();
         self.queue.add_notify_wait_pop(
             &[req.as_bytes()],
-            &[buf, resp.as_bytes_mut()],
+            &mut [buf, resp.as_bytes_mut()],
             &mut self.transport,
         )?;
         resp.status.into()
@@ -187,7 +187,7 @@
         };
         let token = self
             .queue
-            .add(&[req.as_bytes()], &[buf, resp.as_bytes_mut()])?;
+            .add(&[req.as_bytes()], &mut [buf, resp.as_bytes_mut()])?;
         if self.queue.should_notify() {
             self.transport.notify(QUEUE);
         }
@@ -208,7 +208,7 @@
         resp: &mut BlkResp,
     ) -> Result<()> {
         self.queue
-            .pop_used(token, &[req.as_bytes()], &[buf, resp.as_bytes_mut()])?;
+            .pop_used(token, &[req.as_bytes()], &mut [buf, resp.as_bytes_mut()])?;
         resp.status.into()
     }
 
@@ -225,7 +225,7 @@
         let mut resp = BlkResp::default();
         self.queue.add_notify_wait_pop(
             &[req.as_bytes(), buf],
-            &[resp.as_bytes_mut()],
+            &mut [resp.as_bytes_mut()],
             &mut self.transport,
         )?;
         resp.status.into()
@@ -268,7 +268,7 @@
         };
         let token = self
             .queue
-            .add(&[req.as_bytes(), buf], &[resp.as_bytes_mut()])?;
+            .add(&[req.as_bytes(), buf], &mut [resp.as_bytes_mut()])?;
         if self.queue.should_notify() {
             self.transport.notify(QUEUE);
         }
@@ -289,7 +289,7 @@
         resp: &mut BlkResp,
     ) -> Result<()> {
         self.queue
-            .pop_used(token, &[req.as_bytes(), buf], &[resp.as_bytes_mut()])?;
+            .pop_used(token, &[req.as_bytes(), buf], &mut [resp.as_bytes_mut()])?;
         resp.status.into()
     }
 
diff --git a/src/device/common.rs b/src/device/common.rs
new file mode 100644
index 0000000..2c8be3e
--- /dev/null
+++ b/src/device/common.rs
@@ -0,0 +1,23 @@
+//! Common part shared across all the devices.
+
+use bitflags::bitflags;
+
+bitflags! {
+    pub(crate) struct Feature: u64 {
+        // device independent
+        const NOTIFY_ON_EMPTY       = 1 << 24; // legacy
+        const ANY_LAYOUT            = 1 << 27; // legacy
+        const RING_INDIRECT_DESC    = 1 << 28;
+        const RING_EVENT_IDX        = 1 << 29;
+        const UNUSED                = 1 << 30; // legacy
+        const VERSION_1             = 1 << 32; // detect legacy
+
+        // since virtio v1.1
+        const ACCESS_PLATFORM       = 1 << 33;
+        const RING_PACKED           = 1 << 34;
+        const IN_ORDER              = 1 << 35;
+        const ORDER_PLATFORM        = 1 << 36;
+        const SR_IOV                = 1 << 37;
+        const NOTIFICATION_DATA     = 1 << 38;
+    }
+}
diff --git a/src/device/console.rs b/src/device/console.rs
index 749ebc1..e0b0356 100644
--- a/src/device/console.rs
+++ b/src/device/console.rs
@@ -118,7 +118,7 @@
         if self.receive_token.is_none() && self.cursor == self.pending_len {
             // Safe because the buffer lasts at least as long as the queue, and there are no other
             // outstanding requests using the buffer.
-            self.receive_token = Some(unsafe { self.receiveq.add(&[], &[self.queue_buf_rx]) }?);
+            self.receive_token = Some(unsafe { self.receiveq.add(&[], &mut [self.queue_buf_rx]) }?);
             if self.receiveq.should_notify() {
                 self.transport.notify(QUEUE_RECEIVEQ_PORT_0);
             }
@@ -145,13 +145,19 @@
         let mut flag = false;
         if let Some(receive_token) = self.receive_token {
             if self.receive_token == self.receiveq.peek_used() {
-                let len = self
-                    .receiveq
-                    .pop_used(receive_token, &[], &[self.queue_buf_rx])?;
+                // Safe because we are passing the same buffer as we passed to `VirtQueue::add` in
+                // `poll_retrieve` and it is still valid.
+                let len = unsafe {
+                    self.receiveq
+                        .pop_used(receive_token, &[], &mut [self.queue_buf_rx])?
+                };
                 flag = true;
                 assert_ne!(len, 0);
                 self.cursor = 0;
                 self.pending_len = len as usize;
+                // Clear `receive_token` so that when the buffer is used up the next call to
+                // `poll_retrieve` will add a new pending request.
+                self.receive_token.take();
             }
         }
         Ok(flag)
@@ -176,9 +182,8 @@
     /// Sends a character to the console.
     pub fn send(&mut self, chr: u8) -> Result<()> {
         let buf: [u8; 1] = [chr];
-        // Safe because the buffer is valid until we pop_used below.
         self.transmitq
-            .add_notify_wait_pop(&[&buf], &[], &mut self.transport)?;
+            .add_notify_wait_pop(&[&buf], &mut [], &mut self.transport)?;
         Ok(())
     }
 }
diff --git a/src/device/gpu.rs b/src/device/gpu.rs
index eabf2d4..b1b53bd 100644
--- a/src/device/gpu.rs
+++ b/src/device/gpu.rs
@@ -7,6 +7,7 @@
 use crate::{pages, Error, Result};
 use bitflags::bitflags;
 use log::info;
+use zerocopy::{AsBytes, FromBytes};
 
 const QUEUE_SIZE: u16 = 2;
 
@@ -173,86 +174,86 @@
     }
 
     /// Send a request to the device and block for a response.
-    fn request<Req, Rsp>(&mut self, req: Req) -> Result<Rsp> {
-        unsafe {
-            (self.queue_buf_send.as_mut_ptr() as *mut Req).write(req);
-        }
+    fn request<Req: AsBytes, Rsp: FromBytes>(&mut self, req: Req) -> Result<Rsp> {
+        req.write_to_prefix(&mut *self.queue_buf_send).unwrap();
         self.control_queue.add_notify_wait_pop(
             &[self.queue_buf_send],
-            &[self.queue_buf_recv],
+            &mut [self.queue_buf_recv],
             &mut self.transport,
         )?;
-        Ok(unsafe { (self.queue_buf_recv.as_ptr() as *const Rsp).read() })
+        Ok(Rsp::read_from_prefix(&*self.queue_buf_recv).unwrap())
     }
 
     /// Send a mouse cursor operation request to the device and block for a response.
-    fn cursor_request<Req>(&mut self, req: Req) -> Result {
-        unsafe {
-            (self.queue_buf_send.as_mut_ptr() as *mut Req).write(req);
-        }
-        self.cursor_queue
-            .add_notify_wait_pop(&[self.queue_buf_send], &[], &mut self.transport)?;
+    fn cursor_request<Req: AsBytes>(&mut self, req: Req) -> Result {
+        req.write_to_prefix(&mut *self.queue_buf_send).unwrap();
+        self.cursor_queue.add_notify_wait_pop(
+            &[self.queue_buf_send],
+            &mut [],
+            &mut self.transport,
+        )?;
         Ok(())
     }
 
     fn get_display_info(&mut self) -> Result<RespDisplayInfo> {
-        let info: RespDisplayInfo = self.request(CtrlHeader::with_type(Command::GetDisplayInfo))?;
-        info.header.check_type(Command::OkDisplayInfo)?;
+        let info: RespDisplayInfo =
+            self.request(CtrlHeader::with_type(Command::GET_DISPLAY_INFO))?;
+        info.header.check_type(Command::OK_DISPLAY_INFO)?;
         Ok(info)
     }
 
     fn resource_create_2d(&mut self, resource_id: u32, width: u32, height: u32) -> Result {
         let rsp: CtrlHeader = self.request(ResourceCreate2D {
-            header: CtrlHeader::with_type(Command::ResourceCreate2d),
+            header: CtrlHeader::with_type(Command::RESOURCE_CREATE_2D),
             resource_id,
             format: Format::B8G8R8A8UNORM,
             width,
             height,
         })?;
-        rsp.check_type(Command::OkNodata)
+        rsp.check_type(Command::OK_NODATA)
     }
 
     fn set_scanout(&mut self, rect: Rect, scanout_id: u32, resource_id: u32) -> Result {
         let rsp: CtrlHeader = self.request(SetScanout {
-            header: CtrlHeader::with_type(Command::SetScanout),
+            header: CtrlHeader::with_type(Command::SET_SCANOUT),
             rect,
             scanout_id,
             resource_id,
         })?;
-        rsp.check_type(Command::OkNodata)
+        rsp.check_type(Command::OK_NODATA)
     }
 
     fn resource_flush(&mut self, rect: Rect, resource_id: u32) -> Result {
         let rsp: CtrlHeader = self.request(ResourceFlush {
-            header: CtrlHeader::with_type(Command::ResourceFlush),
+            header: CtrlHeader::with_type(Command::RESOURCE_FLUSH),
             rect,
             resource_id,
             _padding: 0,
         })?;
-        rsp.check_type(Command::OkNodata)
+        rsp.check_type(Command::OK_NODATA)
     }
 
     fn transfer_to_host_2d(&mut self, rect: Rect, offset: u64, resource_id: u32) -> Result {
         let rsp: CtrlHeader = self.request(TransferToHost2D {
-            header: CtrlHeader::with_type(Command::TransferToHost2d),
+            header: CtrlHeader::with_type(Command::TRANSFER_TO_HOST_2D),
             rect,
             offset,
             resource_id,
             _padding: 0,
         })?;
-        rsp.check_type(Command::OkNodata)
+        rsp.check_type(Command::OK_NODATA)
     }
 
     fn resource_attach_backing(&mut self, resource_id: u32, paddr: u64, length: u32) -> Result {
         let rsp: CtrlHeader = self.request(ResourceAttachBacking {
-            header: CtrlHeader::with_type(Command::ResourceAttachBacking),
+            header: CtrlHeader::with_type(Command::RESOURCE_ATTACH_BACKING),
             resource_id,
             nr_entries: 1,
             addr: paddr,
             length,
             _padding: 0,
         })?;
-        rsp.check_type(Command::OkNodata)
+        rsp.check_type(Command::OK_NODATA)
     }
 
     fn update_cursor(
@@ -267,9 +268,9 @@
     ) -> Result {
         self.cursor_request(UpdateCursor {
             header: if is_move {
-                CtrlHeader::with_type(Command::MoveCursor)
+                CtrlHeader::with_type(Command::MOVE_CURSOR)
             } else {
-                CtrlHeader::with_type(Command::UpdateCursor)
+                CtrlHeader::with_type(Command::UPDATE_CURSOR)
             },
             pos: CursorPos {
                 scanout_id,
@@ -336,39 +337,41 @@
     }
 }
 
-#[repr(u32)]
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
-enum Command {
-    GetDisplayInfo = 0x100,
-    ResourceCreate2d = 0x101,
-    ResourceUnref = 0x102,
-    SetScanout = 0x103,
-    ResourceFlush = 0x104,
-    TransferToHost2d = 0x105,
-    ResourceAttachBacking = 0x106,
-    ResourceDetachBacking = 0x107,
-    GetCapsetInfo = 0x108,
-    GetCapset = 0x109,
-    GetEdid = 0x10a,
+#[repr(transparent)]
+#[derive(AsBytes, Clone, Copy, Debug, Eq, PartialEq, FromBytes)]
+struct Command(u32);
 
-    UpdateCursor = 0x300,
-    MoveCursor = 0x301,
+impl Command {
+    const GET_DISPLAY_INFO: Command = Command(0x100);
+    const RESOURCE_CREATE_2D: Command = Command(0x101);
+    const RESOURCE_UNREF: Command = Command(0x102);
+    const SET_SCANOUT: Command = Command(0x103);
+    const RESOURCE_FLUSH: Command = Command(0x104);
+    const TRANSFER_TO_HOST_2D: Command = Command(0x105);
+    const RESOURCE_ATTACH_BACKING: Command = Command(0x106);
+    const RESOURCE_DETACH_BACKING: Command = Command(0x107);
+    const GET_CAPSET_INFO: Command = Command(0x108);
+    const GET_CAPSET: Command = Command(0x109);
+    const GET_EDID: Command = Command(0x10a);
 
-    OkNodata = 0x1100,
-    OkDisplayInfo = 0x1101,
-    OkCapsetInfo = 0x1102,
-    OkCapset = 0x1103,
-    OkEdid = 0x1104,
+    const UPDATE_CURSOR: Command = Command(0x300);
+    const MOVE_CURSOR: Command = Command(0x301);
 
-    ErrUnspec = 0x1200,
-    ErrOutOfMemory = 0x1201,
-    ErrInvalidScanoutId = 0x1202,
+    const OK_NODATA: Command = Command(0x1100);
+    const OK_DISPLAY_INFO: Command = Command(0x1101);
+    const OK_CAPSET_INFO: Command = Command(0x1102);
+    const OK_CAPSET: Command = Command(0x1103);
+    const OK_EDID: Command = Command(0x1104);
+
+    const ERR_UNSPEC: Command = Command(0x1200);
+    const ERR_OUT_OF_MEMORY: Command = Command(0x1201);
+    const ERR_INVALID_SCANOUT_ID: Command = Command(0x1202);
 }
 
 const GPU_FLAG_FENCE: u32 = 1 << 0;
 
 #[repr(C)]
-#[derive(Debug, Clone, Copy)]
+#[derive(AsBytes, Debug, Clone, Copy, FromBytes)]
 struct CtrlHeader {
     hdr_type: Command,
     flags: u32,
@@ -399,7 +402,7 @@
 }
 
 #[repr(C)]
-#[derive(Debug, Copy, Clone, Default)]
+#[derive(AsBytes, Debug, Copy, Clone, Default, FromBytes)]
 struct Rect {
     x: u32,
     y: u32,
@@ -408,7 +411,7 @@
 }
 
 #[repr(C)]
-#[derive(Debug)]
+#[derive(Debug, FromBytes)]
 struct RespDisplayInfo {
     header: CtrlHeader,
     rect: Rect,
@@ -417,7 +420,7 @@
 }
 
 #[repr(C)]
-#[derive(Debug)]
+#[derive(AsBytes, Debug)]
 struct ResourceCreate2D {
     header: CtrlHeader,
     resource_id: u32,
@@ -427,13 +430,13 @@
 }
 
 #[repr(u32)]
-#[derive(Debug)]
+#[derive(AsBytes, Debug)]
 enum Format {
     B8G8R8A8UNORM = 1,
 }
 
 #[repr(C)]
-#[derive(Debug)]
+#[derive(AsBytes, Debug)]
 struct ResourceAttachBacking {
     header: CtrlHeader,
     resource_id: u32,
@@ -444,7 +447,7 @@
 }
 
 #[repr(C)]
-#[derive(Debug)]
+#[derive(AsBytes, Debug)]
 struct SetScanout {
     header: CtrlHeader,
     rect: Rect,
@@ -453,7 +456,7 @@
 }
 
 #[repr(C)]
-#[derive(Debug)]
+#[derive(AsBytes, Debug)]
 struct TransferToHost2D {
     header: CtrlHeader,
     rect: Rect,
@@ -463,7 +466,7 @@
 }
 
 #[repr(C)]
-#[derive(Debug)]
+#[derive(AsBytes, Debug)]
 struct ResourceFlush {
     header: CtrlHeader,
     rect: Rect,
@@ -472,7 +475,7 @@
 }
 
 #[repr(C)]
-#[derive(Debug, Clone, Copy)]
+#[derive(AsBytes, Debug, Clone, Copy)]
 struct CursorPos {
     scanout_id: u32,
     x: u32,
@@ -481,7 +484,7 @@
 }
 
 #[repr(C)]
-#[derive(Debug, Clone, Copy)]
+#[derive(AsBytes, Debug, Clone, Copy)]
 struct UpdateCursor {
     header: CtrlHeader,
     pos: CursorPos,
diff --git a/src/device/input.rs b/src/device/input.rs
index 8554282..dee2fec 100644
--- a/src/device/input.rs
+++ b/src/device/input.rs
@@ -1,12 +1,12 @@
 //! Driver for VirtIO input devices.
 
+use super::common::Feature;
 use crate::hal::Hal;
 use crate::queue::VirtQueue;
 use crate::transport::Transport;
 use crate::volatile::{volread, volwrite, ReadOnly, WriteOnly};
 use crate::Result;
 use alloc::boxed::Box;
-use bitflags::bitflags;
 use core::ptr::NonNull;
 use log::info;
 use zerocopy::{AsBytes, FromBytes};
@@ -42,7 +42,7 @@
         let status_queue = VirtQueue::new(&mut transport, QUEUE_STATUS)?;
         for (i, event) in event_buf.as_mut().iter_mut().enumerate() {
             // Safe because the buffer lasts as long as the queue.
-            let token = unsafe { event_queue.add(&[], &[event.as_bytes_mut()])? };
+            let token = unsafe { event_queue.add(&[], &mut [event.as_bytes_mut()])? };
             assert_eq!(token, i as u16);
         }
         if event_queue.should_notify() {
@@ -69,12 +69,18 @@
     pub fn pop_pending_event(&mut self) -> Option<InputEvent> {
         if let Some(token) = self.event_queue.peek_used() {
             let event = &mut self.event_buf[token as usize];
-            self.event_queue
-                .pop_used(token, &[], &[event.as_bytes_mut()])
-                .ok()?;
+            // Safe because we are passing the same buffer as we passed to `VirtQueue::add` and it
+            // is still valid.
+            unsafe {
+                self.event_queue
+                    .pop_used(token, &[], &mut [event.as_bytes_mut()])
+                    .ok()?;
+            }
+            let event_saved = *event;
             // requeue
             // Safe because buffer lasts as long as the queue.
-            if let Ok(new_token) = unsafe { self.event_queue.add(&[], &[event.as_bytes_mut()]) } {
+            if let Ok(new_token) = unsafe { self.event_queue.add(&[], &mut [event.as_bytes_mut()]) }
+            {
                 // This only works because nothing happen between `pop_used` and `add` that affects
                 // the list of free descriptors in the queue, so `add` reuses the descriptor which
                 // was just freed by `pop_used`.
@@ -82,7 +88,7 @@
                 if self.event_queue.should_notify() {
                     self.transport.notify(QUEUE_EVENT);
                 }
-                return Some(*event);
+                return Some(event_saved);
             }
         }
         None
@@ -185,26 +191,6 @@
     pub value: u32,
 }
 
-bitflags! {
-    struct Feature: u64 {
-        // device independent
-        const NOTIFY_ON_EMPTY       = 1 << 24; // legacy
-        const ANY_LAYOUT            = 1 << 27; // legacy
-        const RING_INDIRECT_DESC    = 1 << 28;
-        const RING_EVENT_IDX        = 1 << 29;
-        const UNUSED                = 1 << 30; // legacy
-        const VERSION_1             = 1 << 32; // detect legacy
-
-        // since virtio v1.1
-        const ACCESS_PLATFORM       = 1 << 33;
-        const RING_PACKED           = 1 << 34;
-        const IN_ORDER              = 1 << 35;
-        const ORDER_PLATFORM        = 1 << 36;
-        const SR_IOV                = 1 << 37;
-        const NOTIFICATION_DATA     = 1 << 38;
-    }
-}
-
 const QUEUE_EVENT: u16 = 0;
 const QUEUE_STATUS: u16 = 1;
 
diff --git a/src/device/mod.rs b/src/device/mod.rs
index f3e4f66..ca68901 100644
--- a/src/device/mod.rs
+++ b/src/device/mod.rs
@@ -5,4 +5,8 @@
 pub mod gpu;
 #[cfg(feature = "alloc")]
 pub mod input;
+#[cfg(feature = "alloc")]
 pub mod net;
+pub mod socket;
+
+pub(crate) mod common;
diff --git a/src/device/net.rs b/src/device/net.rs
index 7ca487e..4441f63 100644
--- a/src/device/net.rs
+++ b/src/device/net.rs
@@ -4,13 +4,95 @@
 use crate::queue::VirtQueue;
 use crate::transport::Transport;
 use crate::volatile::{volread, ReadOnly};
-use crate::Result;
+use crate::{Error, Result};
+use alloc::{vec, vec::Vec};
 use bitflags::bitflags;
-use core::mem::{size_of, MaybeUninit};
-use log::{debug, info};
+use core::{convert::TryInto, mem::size_of};
+use log::{debug, info, warn};
 use zerocopy::{AsBytes, FromBytes};
 
-const QUEUE_SIZE: u16 = 2;
+const MAX_BUFFER_LEN: usize = 65535;
+const MIN_BUFFER_LEN: usize = 1526;
+const NET_HDR_SIZE: usize = size_of::<VirtioNetHdr>();
+
+/// A buffer used for transmitting.
+pub struct TxBuffer(Vec<u8>);
+
+/// A buffer used for receiving.
+pub struct RxBuffer {
+    buf: Vec<usize>, // for alignment
+    packet_len: usize,
+    idx: u16,
+}
+
+impl TxBuffer {
+    /// Constructs the buffer from the given slice.
+    pub fn from(buf: &[u8]) -> Self {
+        Self(Vec::from(buf))
+    }
+
+    /// Returns the network packet length.
+    pub fn packet_len(&self) -> usize {
+        self.0.len()
+    }
+
+    /// Returns the network packet as a slice.
+    pub fn packet(&self) -> &[u8] {
+        self.0.as_slice()
+    }
+
+    /// Returns the network packet as a mutable slice.
+    pub fn packet_mut(&mut self) -> &mut [u8] {
+        self.0.as_mut_slice()
+    }
+}
+
+impl RxBuffer {
+    /// Allocates a new buffer with length `buf_len`.
+    fn new(idx: usize, buf_len: usize) -> Self {
+        Self {
+            buf: vec![0; buf_len / size_of::<usize>()],
+            packet_len: 0,
+            idx: idx.try_into().unwrap(),
+        }
+    }
+
+    /// Set the network packet length.
+    fn set_packet_len(&mut self, packet_len: usize) {
+        self.packet_len = packet_len
+    }
+
+    /// Returns the network packet length (without header).
+    pub const fn packet_len(&self) -> usize {
+        self.packet_len
+    }
+
+    /// Returns all data in the buffer, including both the header and the packet.
+    pub fn as_bytes(&self) -> &[u8] {
+        self.buf.as_bytes()
+    }
+
+    /// Returns all data in the buffer with the mutable reference,
+    /// including both the header and the packet.
+    pub fn as_bytes_mut(&mut self) -> &mut [u8] {
+        self.buf.as_bytes_mut()
+    }
+
+    /// Returns the reference of the header.
+    pub fn header(&self) -> &VirtioNetHdr {
+        unsafe { &*(self.buf.as_ptr() as *const VirtioNetHdr) }
+    }
+
+    /// Returns the network packet as a slice.
+    pub fn packet(&self) -> &[u8] {
+        &self.buf.as_bytes()[NET_HDR_SIZE..NET_HDR_SIZE + self.packet_len]
+    }
+
+    /// Returns the network packet as a mutable slice.
+    pub fn packet_mut(&mut self) -> &mut [u8] {
+        &mut self.buf.as_bytes_mut()[NET_HDR_SIZE..NET_HDR_SIZE + self.packet_len]
+    }
+}
 
 /// The virtio network device is a virtual ethernet card.
 ///
@@ -19,16 +101,17 @@
 /// Empty buffers are placed in one virtqueue for receiving packets, and
 /// outgoing packets are enqueued into another for transmission in that order.
 /// A third command queue is used to control advanced filtering features.
-pub struct VirtIONet<H: Hal, T: Transport> {
+pub struct VirtIONet<H: Hal, T: Transport, const QUEUE_SIZE: usize> {
     transport: T,
     mac: EthernetAddress,
-    recv_queue: VirtQueue<H, { QUEUE_SIZE as usize }>,
-    send_queue: VirtQueue<H, { QUEUE_SIZE as usize }>,
+    recv_queue: VirtQueue<H, QUEUE_SIZE>,
+    send_queue: VirtQueue<H, QUEUE_SIZE>,
+    rx_buffers: [Option<RxBuffer>; QUEUE_SIZE],
 }
 
-impl<H: Hal, T: Transport> VirtIONet<H, T> {
+impl<H: Hal, T: Transport, const QUEUE_SIZE: usize> VirtIONet<H, T, QUEUE_SIZE> {
     /// Create a new VirtIO-Net driver.
-    pub fn new(mut transport: T) -> Result<Self> {
+    pub fn new(mut transport: T, buf_len: usize) -> Result<Self> {
         transport.begin_init(|features| {
             let features = Features::from_bits_truncate(features);
             info!("Device features {:?}", features);
@@ -41,11 +124,37 @@
         // Safe because config points to a valid MMIO region for the config space.
         unsafe {
             mac = volread!(config, mac);
-            debug!("Got MAC={:?}, status={:?}", mac, volread!(config, status));
+            debug!(
+                "Got MAC={:02x?}, status={:?}",
+                mac,
+                volread!(config, status)
+            );
         }
 
-        let recv_queue = VirtQueue::new(&mut transport, QUEUE_RECEIVE)?;
+        if !(MIN_BUFFER_LEN..=MAX_BUFFER_LEN).contains(&buf_len) {
+            warn!(
+                "Receive buffer len {} is not in range [{}, {}]",
+                buf_len, MIN_BUFFER_LEN, MAX_BUFFER_LEN
+            );
+            return Err(Error::InvalidParam);
+        }
+
         let send_queue = VirtQueue::new(&mut transport, QUEUE_TRANSMIT)?;
+        let mut recv_queue = VirtQueue::new(&mut transport, QUEUE_RECEIVE)?;
+
+        const NONE_BUF: Option<RxBuffer> = None;
+        let mut rx_buffers = [NONE_BUF; QUEUE_SIZE];
+        for (i, rx_buf_place) in rx_buffers.iter_mut().enumerate() {
+            let mut rx_buf = RxBuffer::new(i, buf_len);
+            // Safe because the buffer lives as long as the queue.
+            let token = unsafe { recv_queue.add(&[], &mut [rx_buf.as_bytes_mut()])? };
+            assert_eq!(token, i as u16);
+            *rx_buf_place = Some(rx_buf);
+        }
+
+        if recv_queue.should_notify() {
+            transport.notify(QUEUE_RECEIVE);
+        }
 
         transport.finish_init();
 
@@ -54,6 +163,7 @@
             mac,
             recv_queue,
             send_queue,
+            rx_buffers,
         })
     }
 
@@ -63,7 +173,7 @@
     }
 
     /// Get MAC address.
-    pub fn mac(&self) -> EthernetAddress {
+    pub fn mac_address(&self) -> EthernetAddress {
         self.mac
     }
 
@@ -77,27 +187,72 @@
         self.recv_queue.can_pop()
     }
 
-    /// Receive a packet.
-    pub fn recv(&mut self, buf: &mut [u8]) -> Result<usize> {
-        let mut header = MaybeUninit::<Header>::uninit();
-        let header_buf = unsafe { (*header.as_mut_ptr()).as_bytes_mut() };
-        let len =
-            self.recv_queue
-                .add_notify_wait_pop(&[], &[header_buf, buf], &mut self.transport)?;
-        // let header = unsafe { header.assume_init() };
-        Ok(len as usize - size_of::<Header>())
+    /// Receives a [`RxBuffer`] from network. If currently no data, returns an
+    /// error with type [`Error::NotReady`].
+    ///
+    /// It will try to pop a buffer that completed data reception in the
+    /// NIC queue.
+    pub fn receive(&mut self) -> Result<RxBuffer> {
+        if let Some(token) = self.recv_queue.peek_used() {
+            let mut rx_buf = self.rx_buffers[token as usize]
+                .take()
+                .ok_or(Error::WrongToken)?;
+            if token != rx_buf.idx {
+                return Err(Error::WrongToken);
+            }
+
+            // Safe because `token` == `rx_buf.idx`, we are passing the same
+            // buffer as we passed to `VirtQueue::add` and it is still valid.
+            let len = unsafe {
+                self.recv_queue
+                    .pop_used(token, &[], &mut [rx_buf.as_bytes_mut()])?
+            } as usize;
+            rx_buf.set_packet_len(len.checked_sub(NET_HDR_SIZE).ok_or(Error::IoError)?);
+            Ok(rx_buf)
+        } else {
+            Err(Error::NotReady)
+        }
     }
 
-    /// Send a packet.
-    pub fn send(&mut self, buf: &[u8]) -> Result {
-        let header = unsafe { MaybeUninit::<Header>::zeroed().assume_init() };
-        self.send_queue
-            .add_notify_wait_pop(&[header.as_bytes(), buf], &[], &mut self.transport)?;
+    /// Gives back the ownership of `rx_buf`, and recycles it for next use.
+    ///
+    /// It will add the buffer back to the NIC queue.
+    pub fn recycle_rx_buffer(&mut self, mut rx_buf: RxBuffer) -> Result {
+        // Safe because we take the ownership of `rx_buf` back to `rx_buffers`,
+        // it lives as long as the queue.
+        let new_token = unsafe { self.recv_queue.add(&[], &mut [rx_buf.as_bytes_mut()]) }?;
+        // `rx_buffers[new_token]` is expected to be `None` since it was taken
+        // away at `Self::receive()` and has not been added back.
+        if self.rx_buffers[new_token as usize].is_some() {
+            return Err(Error::WrongToken);
+        }
+        rx_buf.idx = new_token;
+        self.rx_buffers[new_token as usize] = Some(rx_buf);
+        if self.recv_queue.should_notify() {
+            self.transport.notify(QUEUE_RECEIVE);
+        }
+        Ok(())
+    }
+
+    /// Allocate a new buffer for transmitting.
+    pub fn new_tx_buffer(&self, buf_len: usize) -> TxBuffer {
+        TxBuffer(vec![0; buf_len])
+    }
+
+    /// Sends a [`TxBuffer`] to the network, and blocks until the request
+    /// completed.
+    pub fn send(&mut self, tx_buf: TxBuffer) -> Result {
+        let header = VirtioNetHdr::default();
+        self.send_queue.add_notify_wait_pop(
+            &[header.as_bytes(), tx_buf.packet()],
+            &mut [],
+            &mut self.transport,
+        )?;
         Ok(())
     }
 }
 
-impl<H: Hal, T: Transport> Drop for VirtIONet<H, T> {
+impl<H: Hal, T: Transport, const QUEUE_SIZE: usize> Drop for VirtIONet<H, T, QUEUE_SIZE> {
     fn drop(&mut self) {
         // Clear any pointers pointing to DMA regions, so the device doesn't try to access them
         // after they have been freed.
@@ -185,26 +340,33 @@
 struct Config {
     mac: ReadOnly<EthernetAddress>,
     status: ReadOnly<Status>,
+    max_virtqueue_pairs: ReadOnly<u16>,
+    mtu: ReadOnly<u16>,
 }
 
 type EthernetAddress = [u8; 6];
 
-// virtio 5.1.6 Device Operation
+/// VirtIO 5.1.6 Device Operation:
+///
+/// Packets are transmitted by placing them in the transmitq1...transmitqN,
+/// and buffers for incoming packets are placed in the receiveq1...receiveqN.
+/// In each case, the packet itself is preceded by a header.
 #[repr(C)]
-#[derive(AsBytes, Debug, FromBytes)]
-struct Header {
+#[derive(AsBytes, Debug, Default, FromBytes)]
+pub struct VirtioNetHdr {
     flags: Flags,
     gso_type: GsoType,
     hdr_len: u16, // cannot rely on this
     gso_size: u16,
     csum_start: u16,
     csum_offset: u16,
+    // num_buffers: u16, // only available when the feature MRG_RXBUF is negotiated.
     // payload starts from here
 }
 
 bitflags! {
     #[repr(transparent)]
-    #[derive(AsBytes, FromBytes)]
+    #[derive(AsBytes, Default, FromBytes)]
     struct Flags: u8 {
         const NEEDS_CSUM = 1;
         const DATA_VALID = 2;
@@ -213,7 +375,7 @@
 }
 
 #[repr(transparent)]
-#[derive(AsBytes, Debug, Copy, Clone, Eq, FromBytes, PartialEq)]
+#[derive(AsBytes, Debug, Copy, Clone, Default, Eq, FromBytes, PartialEq)]
 struct GsoType(u8);
 
 impl GsoType {
diff --git a/src/device/socket/error.rs b/src/device/socket/error.rs
new file mode 100644
index 0000000..4beec38
--- /dev/null
+++ b/src/device/socket/error.rs
@@ -0,0 +1,69 @@
+//! This module contains the error type of the VirtIO socket driver.
+
+use core::{fmt, result};
+
+/// The error type of the VirtIO socket driver.
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
+pub enum SocketError {
+    /// There is an existing connection.
+    ConnectionExists,
+    /// Failed to establish the connection, e.g. because the peer closed it instead of accepting.
+    ConnectionFailed,
+    /// The device is not connected to any peer.
+    NotConnected,
+    /// The peer socket has shut down.
+    PeerSocketShutdown,
+    /// No response received.
+    NoResponseReceived,
+    /// The given buffer is shorter than expected.
+    BufferTooShort,
+    /// The given output buffer is too short; holds the number of bytes needed.
+    OutputBufferTooShort(usize),
+    /// The given buffer is too long; holds its length and the maximum allowed length.
+    BufferTooLong(usize, usize),
+    /// The packet's operation code is not recognised; holds the raw code.
+    UnknownOperation(u16),
+    /// Invalid operation.
+    InvalidOperation,
+    /// Invalid number, e.g. a buffer offset or length computation that overflows.
+    InvalidNumber,
+    /// Unexpected data in packet.
+    UnexpectedDataInPacket,
+    /// Peer has insufficient buffer space, try again later.
+    InsufficientBufferSpaceInPeer,
+    /// Recycled a wrong buffer.
+    RecycledWrongBuffer,
+}
+
+impl fmt::Display for SocketError {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        // Variants carrying data are formatted directly; every other variant maps to a fixed
+        // message which is written verbatim.
+        let message = match *self {
+            SocketError::BufferTooLong(actual, max) => {
+                return write!(f, "The given buffer length '{actual}' has exceeded the maximum allowed buffer length '{max}'");
+            }
+            SocketError::OutputBufferTooShort(expected) => {
+                return write!(f, "The given output buffer is too short. '{expected}' bytes is needed for the output buffer.");
+            }
+            SocketError::UnknownOperation(op) => {
+                return write!(f, "The operation code '{op}' is unknown");
+            }
+            SocketError::ConnectionExists => "There is an existing connection. Please close the current connection before attempting to connect again.",
+            SocketError::ConnectionFailed => "Failed to establish the connection. The packet sent may have an unknown type value",
+            SocketError::NotConnected => "The device is not connected to any peer. Please connect it to a peer first.",
+            SocketError::PeerSocketShutdown => "The peer socket is shutdown.",
+            SocketError::NoResponseReceived => "No response received",
+            SocketError::BufferTooShort => "The given buffer is shorter than expected",
+            SocketError::InvalidOperation => "Invalid operation",
+            SocketError::InvalidNumber => "Invalid number",
+            SocketError::UnexpectedDataInPacket => "No data is expected in the packet",
+            SocketError::InsufficientBufferSpaceInPeer => "Peer has insufficient buffer space, try again later",
+            SocketError::RecycledWrongBuffer => "Recycled a wrong buffer",
+        };
+
+        f.write_str(message)
+    }
+}
+/// Result type returned by the VirtIO socket driver, with [`SocketError`] as the error.
+pub type Result<T> = result::Result<T, SocketError>;
diff --git a/src/device/socket/mod.rs b/src/device/socket/mod.rs
new file mode 100644
index 0000000..65280aa
--- /dev/null
+++ b/src/device/socket/mod.rs
@@ -0,0 +1,8 @@
+//! This module implements the virtio vsock device.
+
+mod error;
+mod protocol;
+mod vsock;
+
+pub use error::SocketError;
+pub use vsock::{DisconnectReason, VirtIOSocket, VsockEvent, VsockEventType};
diff --git a/src/device/socket/protocol.rs b/src/device/socket/protocol.rs
new file mode 100644
index 0000000..abc1702
--- /dev/null
+++ b/src/device/socket/protocol.rs
@@ -0,0 +1,184 @@
+//! This module defines the socket device protocol according to the virtio spec v1.1 5.10 Socket Device
+
+use super::error::{self, SocketError};
+use crate::volatile::ReadOnly;
+use core::{
+    convert::{TryFrom, TryInto},
+    fmt,
+};
+use zerocopy::{
+    byteorder::{LittleEndian, U16, U32, U64},
+    AsBytes, FromBytes,
+};
+
+/// The type of a vsock socket. Only stream sockets (type 1) are currently supported.
+#[derive(Copy, Clone, Debug)]
+#[repr(u16)]
+pub enum SocketType {
+    /// Stream sockets provide in-order, guaranteed, connection-oriented delivery without message boundaries.
+    Stream = 1,
+}
+
+impl From<SocketType> for U16<LittleEndian> {
+    fn from(socket_type: SocketType) -> Self {
+        Self::from(socket_type as u16)
+    }
+}
+
+/// The layout of the vsock device configuration space.
+#[repr(C)]
+pub struct VirtioVsockConfig {
+    /// The guest_cid field contains the guest’s context ID, which uniquely identifies
+    /// the device for its lifetime. The upper 32 bits of the CID are reserved and zeroed.
+    ///
+    /// According to virtio spec v1.1 2.4.1 Driver Requirements: Device Configuration Space,
+    /// drivers MUST NOT assume reads from fields greater than 32 bits wide are atomic.
+    /// So the 64-bit guest_cid is split into this low half and `guest_cid_high`.
+    pub guest_cid_low: ReadOnly<u32>,
+    pub guest_cid_high: ReadOnly<u32>, // Upper 32 bits of the guest CID.
+}
+
+/// The header that precedes every packet sent on the TX queue or received on the RX queue.
+#[repr(packed)] // 44 bytes, with no padding between fields.
+#[derive(AsBytes, Clone, Copy, Debug, FromBytes)]
+pub struct VirtioVsockHdr {
+    pub src_cid: U64<LittleEndian>,
+    pub dst_cid: U64<LittleEndian>,
+    pub src_port: U32<LittleEndian>,
+    pub dst_port: U32<LittleEndian>,
+    pub len: U32<LittleEndian>, // Length in bytes of the body following this header.
+    pub socket_type: U16<LittleEndian>,
+    pub op: U16<LittleEndian>, // A `VirtioVsockOp` value.
+    pub flags: U32<LittleEndian>,
+    /// Total receive buffer space for this socket. This includes both free and in-use buffers.
+    pub buf_alloc: U32<LittleEndian>,
+    /// Free-running bytes received counter.
+    pub fwd_cnt: U32<LittleEndian>,
+}
+
+impl Default for VirtioVsockHdr {
+    fn default() -> Self {
+        Self {
+            socket_type: SocketType::Stream.into(), // The only socket type supported so far.
+            src_cid: 0.into(),
+            dst_cid: 0.into(),
+            src_port: 0.into(),
+            dst_port: 0.into(),
+            len: 0.into(),
+            op: 0.into(),
+            flags: 0.into(),
+            buf_alloc: 0.into(),
+            fwd_cnt: 0.into(),
+        }
+    }
+}
+
+impl VirtioVsockHdr {
+    /// Returns the length of the data.
+    pub fn len(&self) -> u32 {
+        self.len.get()
+    }
+
+    /// Decodes the `op` field, failing with `UnknownOperation` for unrecognised values.
+    pub fn op(&self) -> error::Result<VirtioVsockOp> {
+        VirtioVsockOp::try_from(self.op)
+    }
+
+    /// Returns the address of the sender of this packet.
+    pub fn source(&self) -> VsockAddr {
+        VsockAddr {
+            cid: self.src_cid.get(),
+            port: self.src_port.get(),
+        }
+    }
+
+    /// Returns the address this packet is destined for.
+    pub fn destination(&self) -> VsockAddr {
+        VsockAddr {
+            cid: self.dst_cid.get(),
+            port: self.dst_port.get(),
+        }
+    }
+
+    /// Fails with `UnexpectedDataInPacket` unless the packet has an empty body.
+    pub fn check_data_is_empty(&self) -> error::Result<()> {
+        (self.len() == 0).then_some(()).ok_or(SocketError::UnexpectedDataInPacket)
+    }
+}
+
+/// Socket address: a context identifier (CID) together with a port number.
+#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
+pub struct VsockAddr {
+    /// Context Identifier.
+    pub cid: u64,
+    /// Port number.
+    pub port: u32,
+}
+
+/// An event delivered on the event virtqueue.
+#[derive(Copy, Clone, Debug, Default, AsBytes, FromBytes)]
+#[repr(C)]
+pub struct VirtioVsockEvent {
+    /// ID of the event, as defined by `virtio_vsock_event_id` in the virtio spec.
+    pub id: U32<LittleEndian>,
+}
+/// The operation carried in the `op` field of a `VirtioVsockHdr`.
+#[derive(Copy, Clone, Eq, PartialEq)]
+#[repr(u16)]
+pub enum VirtioVsockOp {
+    Invalid = 0,
+
+    /* Connect operations */
+    Request = 1,
+    Response = 2,
+    Rst = 3,
+    Shutdown = 4,
+
+    /* To send payload */
+    Rw = 5,
+
+    /* Tell the peer our credit info */
+    CreditUpdate = 6,
+    /* Request the peer to send the credit info to us */
+    CreditRequest = 7,
+}
+
+impl From<VirtioVsockOp> for U16<LittleEndian> {
+    fn from(op: VirtioVsockOp) -> Self {
+        Self::from(op as u16)
+    }
+}
+
+impl TryFrom<U16<LittleEndian>> for VirtioVsockOp {
+    type Error = SocketError;
+
+    fn try_from(v: U16<LittleEndian>) -> Result<Self, Self::Error> {
+        let code = v.get();
+        Ok(match code {
+            0 => Self::Invalid,
+            1 => Self::Request,
+            2 => Self::Response,
+            3 => Self::Rst,
+            4 => Self::Shutdown,
+            5 => Self::Rw,
+            6 => Self::CreditUpdate,
+            7 => Self::CreditRequest,
+            unknown => return Err(SocketError::UnknownOperation(unknown)),
+        })
+    }
+}
+
+impl fmt::Debug for VirtioVsockOp {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        let name = match self {
+            Self::Invalid => "VIRTIO_VSOCK_OP_INVALID",
+            Self::Request => "VIRTIO_VSOCK_OP_REQUEST",
+            Self::Response => "VIRTIO_VSOCK_OP_RESPONSE",
+            Self::Rst => "VIRTIO_VSOCK_OP_RST",
+            Self::Shutdown => "VIRTIO_VSOCK_OP_SHUTDOWN",
+            Self::Rw => "VIRTIO_VSOCK_OP_RW",
+            Self::CreditUpdate => "VIRTIO_VSOCK_OP_CREDIT_UPDATE",
+            Self::CreditRequest => "VIRTIO_VSOCK_OP_CREDIT_REQUEST",
+        };
+        f.write_str(name)
+    }
+}
diff --git a/src/device/socket/vsock.rs b/src/device/socket/vsock.rs
new file mode 100644
index 0000000..686d7a6
--- /dev/null
+++ b/src/device/socket/vsock.rs
@@ -0,0 +1,596 @@
+//! Driver for VirtIO socket devices.
+#![deny(unsafe_op_in_unsafe_fn)]
+
+use super::error::SocketError;
+use super::protocol::{VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp, VsockAddr};
+use crate::device::common::Feature;
+use crate::hal::{BufferDirection, Dma, Hal};
+use crate::queue::VirtQueue;
+use crate::transport::Transport;
+use crate::volatile::volread;
+use crate::Result;
+use core::hint::spin_loop;
+use core::mem::size_of;
+use core::ptr::NonNull;
+use log::{debug, info};
+use zerocopy::{AsBytes, FromBytes};
+
+const RX_QUEUE_IDX: u16 = 0;
+const TX_QUEUE_IDX: u16 = 1;
+const EVENT_QUEUE_IDX: u16 = 2;
+/// Entries per virtqueue; the RX DMA buffer is divided evenly among the RX entries.
+const QUEUE_SIZE: usize = 8;
+/// Per-connection state, including the credit bookkeeping used for vsock flow control.
+#[derive(Clone, Debug, Default, PartialEq, Eq)]
+struct ConnectionInfo {
+    dst: VsockAddr, // The peer's address.
+    src_port: u32, // Our local port.
+    /// The last `buf_alloc` value the peer sent to us, indicating how much receive buffer space in
+    /// bytes it has allocated for packet bodies.
+    peer_buf_alloc: u32,
+    /// The last `fwd_cnt` value the peer sent to us, indicating how many bytes of packet bodies it
+    /// has finished processing.
+    peer_fwd_cnt: u32,
+    /// The number of bytes of packet bodies which we have sent to the peer.
+    tx_cnt: u32,
+    /// The number of bytes of packet bodies which we have received from the peer and handled.
+    fwd_cnt: u32,
+    /// Whether we have recently requested credit from the peer.
+    ///
+    /// This is set to true when we send a `VIRTIO_VSOCK_OP_CREDIT_REQUEST`, and false when we
+    /// receive a `VIRTIO_VSOCK_OP_CREDIT_UPDATE`.
+    has_pending_credit_request: bool,
+}
+
+impl ConnectionInfo {
+    /// Peer RX space left for our data; counters wrap, and a shrunk `buf_alloc` clamps to 0.
+    fn peer_free(&self) -> u32 {
+        self.peer_buf_alloc.saturating_sub(self.tx_cnt.wrapping_sub(self.peer_fwd_cnt))
+    }
+    fn new_header(&self, src_cid: u64) -> VirtioVsockHdr {
+        VirtioVsockHdr {
+            src_cid: src_cid.into(),
+            dst_cid: self.dst.cid.into(),
+            src_port: self.src_port.into(),
+            dst_port: self.dst.port.into(),
+            fwd_cnt: self.fwd_cnt.into(),
+            ..Default::default()
+        }
+    }
+}
+
+/// An event received from a VirtIO socket device, as returned by `poll_recv`.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct VsockEvent {
+    /// The source of the event, i.e. the peer who sent it.
+    pub source: VsockAddr,
+    /// The destination of the event, i.e. the CID and port on our side.
+    pub destination: VsockAddr,
+    /// The type of event.
+    pub event_type: VsockEventType,
+}
+
+/// The reason why a vsock connection was closed, as reported by `VsockEventType::Disconnected`.
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
+pub enum DisconnectReason {
+    /// The peer has either closed the connection in response to our shutdown request, or forcibly
+    /// closed it of its own accord.
+    Reset,
+    /// The peer asked to shut down the connection.
+    Shutdown,
+}
+
+/// Details of the type of an event received from a VirtIO socket.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub enum VsockEventType {
+    /// The connection was successfully established.
+    Connected,
+    /// The connection was closed.
+    Disconnected {
+        /// The reason for the disconnection.
+        reason: DisconnectReason,
+    },
+    /// Data was received on the connection and copied into the caller's buffer.
+    Received {
+        /// The length of the data in bytes.
+        length: usize,
+    },
+}
+
+/// Driver for a VirtIO socket device.
+pub struct VirtIOSocket<H: Hal, T: Transport> {
+    transport: T,
+    /// Virtqueue to receive packets.
+    rx: VirtQueue<H, { QUEUE_SIZE }>,
+    tx: VirtQueue<H, { QUEUE_SIZE }>, // Virtqueue to send packets.
+    /// Virtqueue to receive events from the device.
+    event: VirtQueue<H, { QUEUE_SIZE }>, // NOTE(review): no buffers are ever added here yet.
+    /// The guest_cid field contains the guest’s context ID, which uniquely identifies
+    /// the device for its lifetime. The upper 32 bits of the CID are reserved and zeroed.
+    guest_cid: u64,
+    rx_buf_dma: Dma<H>, // DMA memory split into the buffers queued on `rx`.
+
+    /// Currently the device is only allowed to be connected to one destination at a time.
+    connection_info: Option<ConnectionInfo>,
+}
+
+impl<H: Hal, T: Transport> Drop for VirtIOSocket<H, T> {
+    fn drop(&mut self) {
+        // Detach every virtqueue from the device before the queues and DMA buffers are freed,
+        // so the device can no longer access that memory.
+        for queue_idx in [RX_QUEUE_IDX, TX_QUEUE_IDX, EVENT_QUEUE_IDX] {
+            self.transport.queue_unset(queue_idx);
+        }
+    }
+}
+
+impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
+    /// Creates a new VirtIO socket driver, negotiating no optional device features.
+    pub fn new(mut transport: T) -> Result<Self> {
+        transport.begin_init(|features| {
+            let features = Feature::from_bits_truncate(features);
+            info!("Device features: {:?}", features);
+            // No optional features are supported yet, so none of them are negotiated.
+            let supported_features = Feature::empty();
+            (features & supported_features).bits()
+        });
+
+        let config = transport.config_space::<VirtioVsockConfig>()?;
+        info!("config: {:?}", config);
+        // SAFETY: `config` is a valid pointer to the device configuration space.
+        let guest_cid = unsafe {
+            volread!(config, guest_cid_low) as u64 | (volread!(config, guest_cid_high) as u64) << 32
+        };
+        info!("guest cid: {guest_cid:?}");
+
+        let mut rx = VirtQueue::new(&mut transport, RX_QUEUE_IDX)?;
+        let tx = VirtQueue::new(&mut transport, TX_QUEUE_IDX)?;
+        let event = VirtQueue::new(&mut transport, EVENT_QUEUE_IDX)?;
+
+        // Allocate one page of DMA memory as the RX buffer, shared evenly by the RX entries.
+        let rx_buf_dma = Dma::new(
+            1, // pages
+            BufferDirection::DeviceToDriver,
+        )?;
+        let rx_buf = rx_buf_dma.raw_slice();
+        // SAFETY: `rx_buf_dma` is stored in `Self`, which drops it after `rx`, so it outlives `rx`.
+        unsafe {
+            Self::fill_rx_queue(&mut rx, rx_buf, &mut transport)?;
+        }
+        transport.finish_init();
+
+        Ok(Self {
+            transport,
+            rx,
+            tx,
+            event,
+            guest_cid,
+            rx_buf_dma,
+            connection_info: None,
+        })
+    }
+
+    /// Fills the `rx` queue with `QUEUE_SIZE` equal, device-writable slices of `rx_buf`.
+    ///
+    /// # Safety
+    ///
+    /// `rx_buf` must live at least as long as the `rx` queue, and the parts of the buffer which are
+    /// in the queue must not be used anywhere else at the same time.
+    unsafe fn fill_rx_queue(
+        rx: &mut VirtQueue<H, { QUEUE_SIZE }>,
+        rx_buf: NonNull<[u8]>,
+        transport: &mut T,
+    ) -> Result {
+        if rx_buf.len() < size_of::<VirtioVsockHdr>() * QUEUE_SIZE {
+            return Err(SocketError::BufferTooShort.into());
+        }
+        for i in 0..QUEUE_SIZE {
+            // SAFETY: the buffer lives as long as the queue, as the function's safety requirement
+            // demands, and we don't access it until it is popped.
+            unsafe {
+                let buffer = Self::as_mut_sub_rx_buffer(rx_buf, i)?;
+                let token = rx.add(&[], &mut [buffer])?;
+                assert_eq!(i, token.into()); // Token i must map to slice i when popped later.
+            }
+        }
+
+        if rx.should_notify() {
+            transport.notify(RX_QUEUE_IDX);
+        }
+        Ok(())
+    }
+
+    /// Asks the peer at `dst_port` on `dst_cid` to accept a connection from our `src_port`.
+    ///
+    /// Only the request is sent here. Wait for `poll_recv` to report a
+    /// `VsockEventType::Connected` event (the peer accepted) before sending any data, or use
+    /// `wait_for_connect` to block until the peer answers.
+    pub fn connect(&mut self, dst_cid: u64, src_port: u32, dst_port: u32) -> Result {
+        if self.connection_info.is_some() {
+            return Err(SocketError::ConnectionExists.into());
+        }
+        let dst = VsockAddr {
+            cid: dst_cid,
+            port: dst_port,
+        };
+        let pending = ConnectionInfo {
+            dst,
+            src_port,
+            ..Default::default()
+        };
+        // A header-only REQUEST packet asks the socket listening at `dst` to accept us.
+        let request = VirtioVsockHdr {
+            op: VirtioVsockOp::Request.into(),
+            ..pending.new_header(self.guest_cid)
+        };
+        self.send_packet_to_tx_queue(&request, &[])?;
+        // Only remember the connection once the request has actually been sent.
+        self.connection_info = Some(pending);
+        debug!("Connection requested: {:?}", self.connection_info);
+        Ok(())
+    }
+
+    /// Blocks until the peer accepts (`VIRTIO_VSOCK_OP_RESPONSE`) or rejects our connection
+    /// request. A disconnect yields `ConnectionFailed`; received data yields `InvalidOperation`.
+    pub fn wait_for_connect(&mut self) -> Result {
+        let error = match self.wait_for_recv(&mut [])?.event_type {
+            VsockEventType::Connected => return Ok(()),
+            VsockEventType::Disconnected { .. } => SocketError::ConnectionFailed,
+            VsockEventType::Received { .. } => SocketError::InvalidOperation,
+        };
+        Err(error.into())
+    }
+
+    /// Asks the peer to send us a credit update for the current connection.
+    ///
+    /// Fails if there is no current connection.
+    fn request_credit(&mut self) -> Result {
+        let connection_info = self.connection_info()?;
+        let mut request = connection_info.new_header(self.guest_cid);
+        request.op = VirtioVsockOp::CreditRequest.into();
+        self.send_packet_to_tx_queue(&request, &[])
+    }
+
+    /// Sends `buffer` to the peer as one RW packet, if the peer has advertised enough credit.
+    pub fn send(&mut self, buffer: &[u8]) -> Result {
+        let mut connection_info = self.connection_info()?;
+        let result = self.check_peer_buffer_is_sufficient(&mut connection_info, buffer.len());
+        self.connection_info = Some(connection_info.clone());
+        result?;
+        let len = buffer.len() as u32; // Lossless: the check above bounded it by a `u32` credit.
+        let header = VirtioVsockHdr {
+            op: VirtioVsockOp::Rw.into(),
+            len: len.into(),
+            buf_alloc: 0.into(),
+            ..connection_info.new_header(self.guest_cid)
+        };
+        self.send_packet_to_tx_queue(&header, buffer)?;
+        // Count only bytes the device accepted; `tx_cnt` is free-running, so it must wrap.
+        self.connection_info.as_mut().unwrap().tx_cnt = connection_info.tx_cnt.wrapping_add(len);
+        Ok(())
+    }
+
+    fn check_peer_buffer_is_sufficient(
+        &mut self,
+        connection_info: &mut ConnectionInfo,
+        buffer_len: usize,
+    ) -> Result {
+        // Enough cached credit: the peer can take `buffer_len` more bytes right now.
+        if buffer_len <= connection_info.peer_free() as usize {
+            return Ok(());
+        }
+        // Not enough: ask for fresh credit (at most one outstanding request) and tell the caller
+        // to try again later.
+        if !connection_info.has_pending_credit_request {
+            self.request_credit()?;
+            connection_info.has_pending_credit_request = true;
+        }
+        Err(SocketError::InsufficientBufferSpaceInPeer.into())
+    }
+
+    /// Polls the vsock device to receive data or other updates.
+    ///
+    /// `buffer` receives the body of any RW packet that arrives, and its length is advertised to
+    /// the peer as our receive buffer space. Returns `Ok(None)` if no event is pending.
+    pub fn poll_recv(&mut self, buffer: &mut [u8]) -> Result<Option<VsockEvent>> {
+        let connection_info = self.connection_info()?;
+
+        // Tell the peer how much space we have, saturating rather than truncating huge buffers.
+        let header = VirtioVsockHdr {
+            op: VirtioVsockOp::CreditUpdate.into(),
+            buf_alloc: (buffer.len().min(u32::MAX as usize) as u32).into(),
+            ..connection_info.new_header(self.guest_cid)
+        };
+        self.send_packet_to_tx_queue(&header, &[])?;
+
+        // Handle entries from the RX virtqueue until we find one that generates an event.
+        let result = self.poll_rx_queue(buffer);
+        // Notify even on error: popped RX buffers may have been re-added to the queue already.
+        if self.rx.should_notify() {
+            self.transport.notify(RX_QUEUE_IDX);
+        }
+
+        result
+    }
+
+    /// Busy-waits, calling `poll_recv` repeatedly, until the vsock device reports an event.
+    ///
+    /// `buffer` is passed to each `poll_recv` call and receives the body of any RW packet.
+    pub fn wait_for_recv(&mut self, buffer: &mut [u8]) -> Result<VsockEvent> {
+        let event = loop {
+            match self.poll_recv(buffer)? {
+                Some(event) => break event,
+                // Nothing pending yet; hint to the CPU that we are spinning.
+                None => spin_loop(),
+            }
+        };
+        Ok(event)
+    }
+
+    /// Requests a clean shutdown of the current connection.
+    ///
+    /// Only the request is sent here, and the connection stays registered. Wait until `poll_recv`
+    /// returns a `VsockEventType::Disconnected` event if you need to know that the peer has
+    /// acknowledged the shutdown.
+    pub fn shutdown(&mut self) -> Result {
+        let connection_info = self.connection_info()?;
+        let mut request = connection_info.new_header(self.guest_cid);
+        request.op = VirtioVsockOp::Shutdown.into();
+        // Header only: a SHUTDOWN packet carries no body. Our state is cleared once the peer
+        // replies with RST or SHUTDOWN (see `poll_rx_queue`).
+        self.send_packet_to_tx_queue(&request, &[])
+    }
+
+    /// Forcibly closes the connection without waiting for the peer.
+    ///
+    /// Sends a RST packet and then forgets the connection; if sending fails, the connection is
+    /// kept and the error is returned.
+    pub fn force_close(&mut self) -> Result {
+        let mut reset = self.connection_info()?.new_header(self.guest_cid);
+        reset.op = VirtioVsockOp::Rst.into();
+        self.send_packet_to_tx_queue(&reset, &[])?;
+        self.connection_info = None;
+        Ok(())
+    }
+
+    fn send_packet_to_tx_queue(&mut self, header: &VirtioVsockHdr, buffer: &[u8]) -> Result {
+        // The header and body are queued as two read-only parts with no device-writable part;
+        // the length returned by `add_notify_wait_pop` is not needed.
+        let parts = [header.as_bytes(), buffer];
+        self.tx
+            .add_notify_wait_pop(&parts, &mut [], &mut self.transport)
+            .map(|_| ())
+    }
+
+    /// Polls the RX virtqueue until either it is empty, there is an error, or we find a packet
+    /// which generates a `VsockEvent`.
+    ///
+    /// Returns `Ok(None)` if the virtqueue is empty, possibly after processing some packets which
+    /// don't result in any events to return.
+    fn poll_rx_queue(&mut self, body: &mut [u8]) -> Result<Option<VsockEvent>> {
+        loop {
+            let mut connection_info = self.connection_info.clone().unwrap_or_default();
+            let Some(header) = self.pop_packet_from_rx_queue(body)? else {
+                return Ok(None);
+            };
+
+            let op = header.op()?;
+
+            // Skip packets which don't match our current connection.
+            if header.source() != connection_info.dst
+                || header.dst_cid.get() != self.guest_cid
+                || header.dst_port.get() != connection_info.src_port
+            {
+                debug!(
+                    "Skipping {:?} as connection is {:?}",
+                    header, connection_info
+                );
+                continue;
+            }
+
+            connection_info.peer_buf_alloc = header.buf_alloc.into();
+            connection_info.peer_fwd_cnt = header.fwd_cnt.into();
+            if self.connection_info.is_some() {
+                self.connection_info = Some(connection_info.clone());
+                debug!("Connection info updated: {:?}", self.connection_info);
+            }
+
+            match op {
+                VirtioVsockOp::Request => {
+                    header.check_data_is_empty()?;
+                    // TODO: Send a Rst, or support listening.
+                }
+                VirtioVsockOp::Response => {
+                    header.check_data_is_empty()?;
+                    return Ok(Some(VsockEvent {
+                        source: connection_info.dst,
+                        destination: VsockAddr {
+                            cid: self.guest_cid,
+                            port: connection_info.src_port,
+                        },
+                        event_type: VsockEventType::Connected,
+                    }));
+                }
+                VirtioVsockOp::CreditUpdate => {
+                    header.check_data_is_empty()?;
+                    connection_info.has_pending_credit_request = false;
+                    if self.connection_info.is_some() {
+                        self.connection_info = Some(connection_info.clone());
+                    }
+
+                    // Virtio v1.1 5.10.6.3: a CREDIT_UPDATE may also arrive without a prior
+                    // CREDIT_REQUEST, whenever the peer's buffer space changes.
+                    continue;
+                }
+                VirtioVsockOp::Rst | VirtioVsockOp::Shutdown => {
+                    header.check_data_is_empty()?;
+
+                    self.connection_info = None;
+                    info!("Disconnected from the peer");
+
+                    let reason = if op == VirtioVsockOp::Rst {
+                        DisconnectReason::Reset
+                    } else {
+                        DisconnectReason::Shutdown
+                    };
+                    return Ok(Some(VsockEvent {
+                        source: connection_info.dst,
+                        destination: VsockAddr {
+                            cid: self.guest_cid,
+                            port: connection_info.src_port,
+                        },
+                        event_type: VsockEventType::Disconnected { reason },
+                    }));
+                }
+                VirtioVsockOp::Rw => {
+                    // Credit the body just copied out for the caller. `fwd_cnt` is free-running,
+                    // so it must wrap rather than overflow.
+                    if let Some(info) = self.connection_info.as_mut() {
+                        info.fwd_cnt = info.fwd_cnt.wrapping_add(header.len());
+                    }
+                    return Ok(Some(VsockEvent {
+                        source: connection_info.dst,
+                        destination: VsockAddr {
+                            cid: self.guest_cid,
+                            port: connection_info.src_port,
+                        },
+                        event_type: VsockEventType::Received { length: header.len() as usize },
+                    }));
+                }
+                VirtioVsockOp::CreditRequest => {
+                    header.check_data_is_empty()?;
+                    // TODO: Send a credit update.
+                }
+                VirtioVsockOp::Invalid => {
+                    return Err(SocketError::InvalidOperation.into());
+                }
+            }
+        }
+    }
+
+    /// Pops one packet from the RX queue, if there is one pending. Returns the header, and copies
+    /// the body into the given buffer.
+    ///
+    /// Returns `None` if there is no pending packet. If the body doesn't fit in `body`, an error is
+    /// returned and the packet is dropped, since its RX buffer is recycled either way.
+    fn pop_packet_from_rx_queue(&mut self, body: &mut [u8]) -> Result<Option<VirtioVsockHdr>> {
+        let Some(token) = self.rx.peek_used() else {
+            return Ok(None);
+        };
+
+        // SAFETY: tokens map 1:1 to RX sub-buffers, so we pass `pop_used` the same buffer that
+        // `add` got for this token; once re-added, the buffer isn't touched until popped again.
+        // NOTE(review): if `pop_used` fails, `?` returns before the slot is re-added — confirm.
+        let header = unsafe {
+            let buffer = Self::as_mut_sub_rx_buffer(self.rx_buf_dma.raw_slice(), token.into())?;
+            let _len = self.rx.pop_used(token, &[], &mut [buffer])?;
+
+            // Read the header and body from the buffer. Don't check the result yet, because we need
+            // to add the buffer back to the queue either way.
+            let header_result = read_header_and_body(buffer, body);
+
+            // Recycle the buffer into the RX queue so the device can fill it again.
+            let new_token = self.rx.add(&[], &mut [buffer])?;
+            // If the RX buffer somehow gets assigned a different token, then our safety assumptions
+            // are broken and we can't safely continue to do anything with the device.
+            assert_eq!(new_token, token);
+
+            header_result
+        }?;
+
+        debug!("Received packet {:?}. Op {:?}", header, header.op());
+        Ok(Some(header))
+    }
+
+    /// Returns a copy of the current connection's state, failing if there is none.
+    fn connection_info(&self) -> Result<ConnectionInfo> {
+        let info = self.connection_info.as_ref().ok_or(SocketError::NotConnected)?;
+        Ok(info.clone())
+    }
+
+    /// Gets a reference to a subslice of the RX buffer to be used for the given entry in the RX
+    /// virtqueue: the buffer is split into `QUEUE_SIZE` equal parts and part `i` is returned.
+    ///
+    /// # Safety
+    ///
+    /// `rx_buf` must be a valid dereferenceable pointer.
+    /// The returned reference has an arbitrary lifetime `'a`. This lifetime must not overlap with
+    /// any other references to the same subslice of the RX buffer or outlive the buffer.
+    unsafe fn as_mut_sub_rx_buffer<'a>(
+        mut rx_buf: NonNull<[u8]>,
+        i: usize,
+    ) -> Result<&'a mut [u8]> {
+        let buffer_size = rx_buf.len() / QUEUE_SIZE;
+        let start = buffer_size
+            .checked_mul(i)
+            .ok_or(SocketError::InvalidNumber)?;
+        let end = start
+            .checked_add(buffer_size)
+            .ok_or(SocketError::InvalidNumber)?;
+        // SAFETY: no alignment or initialisation is required for [u8], and our caller assures
+        // us that `rx_buf` is dereferenceable and that the lifetime of the slice we are creating
+        // won't overlap with any other references to the same slice or outlive it.
+        unsafe {
+            rx_buf
+                .as_mut()
+                .get_mut(start..end)
+                .ok_or(SocketError::BufferTooShort.into())
+        }
+    }
+}
+
+/// Parses the header at the start of `buffer` and copies the body that follows it into `body`.
+fn read_header_and_body(buffer: &[u8], body: &mut [u8]) -> Result<VirtioVsockHdr> {
+    let header = VirtioVsockHdr::read_from_prefix(buffer).ok_or(SocketError::BufferTooShort)?;
+    let header_size = size_of::<VirtioVsockHdr>();
+    let length = header.len() as usize;
+    // `len` was written by the device, so guard against overflow and short buffers.
+    let end = header_size.checked_add(length).ok_or(SocketError::InvalidNumber)?;
+    let source = buffer.get(header_size..end).ok_or(SocketError::BufferTooShort)?;
+    let destination = body
+        .get_mut(..length)
+        .ok_or(SocketError::OutputBufferTooShort(length))?;
+    destination.copy_from_slice(source);
+    Ok(header)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::volatile::ReadOnly;
+    use crate::{
+        hal::fake::FakeHal,
+        transport::{
+            fake::{FakeTransport, QueueStatus, State},
+            DeviceStatus, DeviceType,
+        },
+    };
+    use alloc::{sync::Arc, vec};
+    use core::ptr::NonNull;
+    use std::sync::Mutex;
+
+    #[test]
+    fn config() {
+        let mut config_space = VirtioVsockConfig {
+            guest_cid_low: ReadOnly::new(66),
+            guest_cid_high: ReadOnly::new(0),
+        };
+        let state = Arc::new(Mutex::new(State {
+            status: DeviceStatus::empty(),
+            driver_features: 0,
+            guest_page_size: 0,
+            interrupt_pending: false,
+            queues: vec![QueueStatus::default(); 3], // rx, tx and event queues
+        }));
+        let transport = FakeTransport {
+            device_type: DeviceType::Socket,
+            max_queue_size: 32,
+            device_features: 0,
+            config_space: NonNull::from(&mut config_space),
+            state: state.clone(),
+        };
+        let socket =
+            VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap();
+        assert_eq!(socket.guest_cid, 0x00_0000_0042); // 66 == 0x42; the high half is 0.
+    }
+}
diff --git a/src/hal.rs b/src/hal.rs
index fd8c435..6295f5f 100644
--- a/src/hal.rs
+++ b/src/hal.rs
@@ -53,35 +53,81 @@
 
 impl<H: Hal> Drop for Dma<H> {
     fn drop(&mut self) {
-        let err = H::dma_dealloc(self.paddr, self.vaddr, self.pages);
+        // Safe because `Dma::new` passed `self.pages` to `dma_alloc` and stored the `paddr` and
+        // `vaddr` it returned, and `drop` runs at most once, so the memory is not yet deallocated.
+        let err = unsafe { H::dma_dealloc(self.paddr, self.vaddr, self.pages) };
         assert_eq!(err, 0, "failed to deallocate DMA");
     }
 }
 
 /// The interface which a particular hardware implementation must implement.
-pub trait Hal {
+///
+/// # Safety
+///
+/// Implementations of this trait must follow the "implementation safety" requirements documented
+/// for each method. Callers must follow the safety requirements documented for the unsafe methods.
+pub unsafe trait Hal {
     /// Allocates the given number of contiguous physical pages of DMA memory for VirtIO use.
     ///
     /// Returns both the physical address which the device can use to access the memory, and a
     /// pointer to the start of it which the driver can use to access it.
+    ///
+    /// # Implementation safety
+    ///
+    /// Implementations of this method must ensure that the `NonNull<u8>` returned is a
+    /// [_valid_](https://doc.rust-lang.org/std/ptr/index.html#safety) pointer, aligned to
+    /// [`PAGE_SIZE`], and won't alias any other allocations or references in the program until it
+    /// is deallocated by `dma_dealloc`.
     fn dma_alloc(pages: usize, direction: BufferDirection) -> (PhysAddr, NonNull<u8>);
+
     /// Deallocates the given contiguous physical DMA memory pages.
-    fn dma_dealloc(paddr: PhysAddr, vaddr: NonNull<u8>, pages: usize) -> i32;
+    /// A non-zero return value is treated as failure: the `Drop` impl of `Dma` asserts it is 0.
+    /// # Safety
+    ///
+    /// The memory must have been allocated by `dma_alloc` on the same `Hal` implementation, and not
+    /// yet deallocated. `pages` must be the same number passed to `dma_alloc` originally, and both
+    /// `paddr` and `vaddr` must be the values returned by `dma_alloc`.
+    unsafe fn dma_dealloc(paddr: PhysAddr, vaddr: NonNull<u8>, pages: usize) -> i32;
+
     /// Converts a physical address used for MMIO to a virtual address which the driver can access.
     ///
     /// This is only used for MMIO addresses within BARs read from the device, for the PCI
     /// transport. It may check that the address range up to the given size is within the region
     /// expected for MMIO.
-    fn mmio_phys_to_virt(paddr: PhysAddr, size: usize) -> NonNull<u8>;
+    ///
+    /// # Implementation safety
+    ///
+    /// Implementations of this method must ensure that the `NonNull<u8>` returned is a
+    /// [_valid_](https://doc.rust-lang.org/std/ptr/index.html#safety) pointer, and won't alias any
+    /// other allocations or references in the program.
+    ///
+    /// # Safety
+    ///
+    /// The `paddr` and `size` must describe a valid MMIO region. The implementation may validate it
+    /// in some way (and panic if it is invalid) but is not guaranteed to.
+    unsafe fn mmio_phys_to_virt(paddr: PhysAddr, size: usize) -> NonNull<u8>;
+
     /// Shares the given memory range with the device, and returns the physical address that the
     /// device can use to access it.
     ///
     /// This may involve mapping the buffer into an IOMMU, giving the host permission to access the
     /// memory, or copying it to a special region where it can be accessed.
-    fn share(buffer: NonNull<[u8]>, direction: BufferDirection) -> PhysAddr;
+    ///
+    /// # Safety
+    ///
+    /// The buffer must be a valid pointer to memory which will not be accessed by any other thread
+    /// for the duration of this method call.
+    unsafe fn share(buffer: NonNull<[u8]>, direction: BufferDirection) -> PhysAddr;
+
     /// Unshares the given memory range from the device and (if necessary) copies it back to the
     /// original buffer.
-    fn unshare(paddr: PhysAddr, buffer: NonNull<[u8]>, direction: BufferDirection);
+    ///
+    /// # Safety
+    ///
+    /// The buffer must be a valid pointer to memory which will not be accessed by any other thread
+    /// for the duration of this method call. The `paddr` must be the value previously returned by
+    /// the corresponding `share` call.
+    unsafe fn unshare(paddr: PhysAddr, buffer: NonNull<[u8]>, direction: BufferDirection);
 }
 
 /// The direction in which a buffer is passed.
diff --git a/src/hal/fake.rs b/src/hal/fake.rs
index 2af60a9..5d46835 100644
--- a/src/hal/fake.rs
+++ b/src/hal/fake.rs
@@ -1,5 +1,7 @@
 //! Fake HAL implementation for tests.
 
+#![deny(unsafe_op_in_unsafe_fn)]
+
 use crate::{BufferDirection, Hal, PhysAddr, PAGE_SIZE};
 use alloc::alloc::{alloc_zeroed, dealloc, handle_alloc_error};
 use core::{alloc::Layout, ptr::NonNull};
@@ -8,7 +10,7 @@
 pub struct FakeHal;
 
 /// Fake HAL implementation for use in unit tests.
-impl Hal for FakeHal {
+unsafe impl Hal for FakeHal {
     fn dma_alloc(pages: usize, _direction: BufferDirection) -> (PhysAddr, NonNull<u8>) {
         assert_ne!(pages, 0);
         let layout = Layout::from_size_align(pages * PAGE_SIZE, PAGE_SIZE).unwrap();
@@ -21,7 +23,7 @@
         }
     }
 
-    fn dma_dealloc(_paddr: PhysAddr, vaddr: NonNull<u8>, pages: usize) -> i32 {
+    unsafe fn dma_dealloc(_paddr: PhysAddr, vaddr: NonNull<u8>, pages: usize) -> i32 {
         assert_ne!(pages, 0);
         let layout = Layout::from_size_align(pages * PAGE_SIZE, PAGE_SIZE).unwrap();
         // Safe because the layout is the same as was used when the memory was allocated by
@@ -32,17 +34,17 @@
         0
     }
 
-    fn mmio_phys_to_virt(paddr: PhysAddr, _size: usize) -> NonNull<u8> {
+    unsafe fn mmio_phys_to_virt(paddr: PhysAddr, _size: usize) -> NonNull<u8> {
         NonNull::new(paddr as _).unwrap()
     }
 
-    fn share(buffer: NonNull<[u8]>, _direction: BufferDirection) -> PhysAddr {
+    unsafe fn share(buffer: NonNull<[u8]>, _direction: BufferDirection) -> PhysAddr {
         let vaddr = buffer.as_ptr() as *mut u8 as usize;
         // Nothing to do, as the host already has access to all memory.
         virt_to_phys(vaddr)
     }
 
-    fn unshare(_paddr: PhysAddr, _buffer: NonNull<[u8]>, _direction: BufferDirection) {
+    unsafe fn unshare(_paddr: PhysAddr, _buffer: NonNull<[u8]>, _direction: BufferDirection) {
         // Nothing to do, as the host already has access to all memory and we didn't copy the buffer
         // anywhere else.
     }
diff --git a/src/lib.rs b/src/lib.rs
index 6a12401..754dd51 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -89,6 +89,8 @@
     ConfigSpaceTooSmall,
     /// The device doesn't have any config space, but the driver expects some.
     ConfigSpaceMissing,
+    /// Error from the socket device.
+    SocketDeviceError(device::socket::SocketError),
 }
 
 impl Display for Error {
@@ -115,10 +117,17 @@
                     "The device doesn't have any config space, but the driver expects some"
                 )
             }
+            Self::SocketDeviceError(e) => write!(f, "Error from the socket device: {e:?}"),
         }
     }
 }
 
+impl From<device::socket::SocketError> for Error {
+    fn from(source: device::socket::SocketError) -> Self {
+        Error::SocketDeviceError(source)
+    }
+}
+
 /// Align `size` up to a page.
 fn align_up(size: usize) -> usize {
     (size + PAGE_SIZE) & !(PAGE_SIZE - 1)
diff --git a/src/queue.rs b/src/queue.rs
index f45da11..d6baf17 100644
--- a/src/queue.rs
+++ b/src/queue.rs
@@ -1,3 +1,5 @@
+#![deny(unsafe_op_in_unsafe_fn)]
+
 use crate::hal::{BufferDirection, Dma, Hal, PhysAddr};
 use crate::transport::Transport;
 use crate::{align_up, nonnull_slice_from_raw_parts, pages, Error, Result, PAGE_SIZE};
@@ -5,7 +7,7 @@
 #[cfg(test)]
 use core::cmp::min;
 use core::hint::spin_loop;
-use core::mem::size_of;
+use core::mem::{size_of, take};
 #[cfg(test)]
 use core::ptr;
 use core::ptr::NonNull;
@@ -114,8 +116,13 @@
     ///
     /// # Safety
     ///
-    /// The input and output buffers must remain valid until the token is returned by `pop_used`.
-    pub unsafe fn add(&mut self, inputs: &[*const [u8]], outputs: &[*mut [u8]]) -> Result<u16> {
+    /// The input and output buffers must remain valid and not be accessed until a call to
+    /// `pop_used` with the returned token succeeds.
+    pub unsafe fn add<'a, 'b>(
+        &mut self,
+        inputs: &'a [&'b [u8]],
+        outputs: &'a mut [&'b mut [u8]],
+    ) -> Result<u16> {
         if inputs.is_empty() && outputs.is_empty() {
             return Err(Error::InvalidParam);
         }
@@ -127,10 +134,14 @@
         let head = self.free_head;
         let mut last = self.free_head;
 
-        for (buffer, direction) in input_output_iter(inputs, outputs) {
+        for (buffer, direction) in InputOutputIter::new(inputs, outputs) {
             // Write to desc_shadow then copy.
             let desc = &mut self.desc_shadow[usize::from(self.free_head)];
-            desc.set_buf::<H>(buffer, direction, DescFlags::NEXT);
+            // Safe because our caller promises that the buffers live at least until `pop_used`
+            // returns them.
+            unsafe {
+                desc.set_buf::<H>(buffer, direction, DescFlags::NEXT);
+            }
             last = self.free_head;
             self.free_head = desc.next;
 
@@ -172,14 +183,14 @@
     /// them, then pops them.
     ///
     /// This assumes that the device isn't processing any other buffers at the same time.
-    pub fn add_notify_wait_pop(
+    pub fn add_notify_wait_pop<'a>(
         &mut self,
-        inputs: &[*const [u8]],
-        outputs: &[*mut [u8]],
+        inputs: &'a [&'a [u8]],
+        outputs: &'a mut [&'a mut [u8]],
         transport: &mut impl Transport,
     ) -> Result<u32> {
-        // Safe because we don't return until the same token has been popped, so they remain valid
-        // until then.
+        // Safe because we don't return until the same token has been popped, so the buffers remain
+        // valid and are not otherwise accessed until then.
         let token = unsafe { self.add(inputs, outputs) }?;
 
         // Notify the queue.
@@ -192,7 +203,9 @@
             spin_loop();
         }
 
-        self.pop_used(token, inputs, outputs)
+        // Safe because these are the same buffers as we passed to `add` above and they are still
+        // valid.
+        unsafe { self.pop_used(token, inputs, outputs) }
     }
 
     /// Returns whether the driver should notify the device after adding a new buffer to the
@@ -252,12 +265,22 @@
     /// passed in too.
     ///
     /// This will push all linked descriptors at the front of the free list.
-    fn recycle_descriptors(&mut self, head: u16, inputs: &[*const [u8]], outputs: &[*mut [u8]]) {
+    ///
+    /// # Safety
+    ///
+    /// The buffers in `inputs` and `outputs` must match the set of buffers originally added to the
+    /// queue by `add`.
+    unsafe fn recycle_descriptors<'a>(
+        &mut self,
+        head: u16,
+        inputs: &'a [&'a [u8]],
+        outputs: &'a mut [&'a mut [u8]],
+    ) {
         let original_free_head = self.free_head;
         self.free_head = head;
         let mut next = Some(head);
 
-        for (buffer, direction) in input_output_iter(inputs, outputs) {
+        for (buffer, direction) in InputOutputIter::new(inputs, outputs) {
             let desc_index = next.expect("Descriptor chain was shorter than expected.");
             let desc = &mut self.desc_shadow[usize::from(desc_index)];
 
@@ -271,8 +294,12 @@
 
             self.write_desc(desc_index);
 
-            // Unshare the buffer (and perhaps copy its contents back to the original buffer).
-            H::unshare(paddr as usize, buffer, direction);
+            // Safe because the caller ensures that the buffer is valid and matches the descriptor
+            // from which we got `paddr`.
+            unsafe {
+                // Unshare the buffer (and perhaps copy its contents back to the original buffer).
+                H::unshare(paddr as usize, buffer, direction);
+            }
         }
 
         if next.is_some() {
@@ -284,11 +311,16 @@
     /// length which was used (written) by the device.
     ///
     /// Ref: linux virtio_ring.c virtqueue_get_buf_ctx
-    pub fn pop_used(
+    ///
+    /// # Safety
+    ///
+    /// The buffers in `inputs` and `outputs` must match the set of buffers originally added to the
+    /// queue by `add` when it returned the token being passed in here.
+    pub unsafe fn pop_used<'a>(
         &mut self,
         token: u16,
-        inputs: &[*const [u8]],
-        outputs: &[*mut [u8]],
+        inputs: &'a [&'a [u8]],
+        outputs: &'a mut [&'a mut [u8]],
     ) -> Result<u32> {
         if !self.can_pop() {
             return Err(Error::NotReady);
@@ -311,7 +343,10 @@
             return Err(Error::WrongToken);
         }
 
-        self.recycle_descriptors(index, inputs, outputs);
+        // Safe because the caller ensures the buffers are valid and match the descriptor.
+        unsafe {
+            self.recycle_descriptors(index, inputs, outputs);
+        }
         self.last_used_idx = self.last_used_idx.wrapping_add(1);
 
         Ok(len)
@@ -486,7 +521,10 @@
         direction: BufferDirection,
         extra_flags: DescFlags,
     ) {
-        self.addr = H::share(buf, direction) as u64;
+        // Safe because our caller promises that the buffer is valid.
+        unsafe {
+            self.addr = H::share(buf, direction) as u64;
+        }
         self.len = buf.len() as u32;
         self.flags = extra_flags
             | match direction {
@@ -558,6 +596,46 @@
     len: u32,
 }
 
+/// Iterates over a descriptor chain's buffers: all `inputs` (driver-to-device) first, then all
+/// `outputs` (device-to-driver), each paired with its `BufferDirection`.
+struct InputOutputIter<'a, 'b> {
+    inputs: &'a [&'b [u8]],
+    outputs: &'a mut [&'b mut [u8]],
+}
+
+impl<'a, 'b> InputOutputIter<'a, 'b> {
+    fn new(inputs: &'a [&'b [u8]], outputs: &'a mut [&'b mut [u8]]) -> Self {
+        InputOutputIter { inputs, outputs }
+    }
+}
+
+impl<'a, 'b> Iterator for InputOutputIter<'a, 'b> {
+    type Item = (NonNull<[u8]>, BufferDirection);
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if let Some(input) = take_first(&mut self.inputs) {
+            return Some(((*input).into(), BufferDirection::DriverToDevice));
+        }
+        let output = take_first_mut(&mut self.outputs)?;
+        Some(((*output).into(), BufferDirection::DeviceToDriver))
+    }
+}
+
+/// Pops the first element off `slice`. TODO: replace with `slice::take_first` once stable
+/// (<https://github.com/rust-lang/rust/issues/62280>).
+fn take_first<'a, T>(slice: &mut &'a [T]) -> Option<&'a T> {
+    let (head, tail) = slice.split_first()?;
+    *slice = tail;
+    Some(head)
+}
+
+/// Mutable counterpart of `take_first`. TODO: replace with `slice::take_first_mut` once stable.
+fn take_first_mut<'a, T>(slice: &mut &'a mut [T]) -> Option<&'a mut T> {
+    let (head, tail) = take(slice).split_first_mut()?; // Move the `&'a mut` out to keep `'a`.
+    *slice = tail;
+    Some(head)
+}
+
 /// Simulates the device reading from a VirtIO queue and writing a response back, for use in tests.
 ///
 /// The fake device always uses descriptors in order.
@@ -680,7 +758,7 @@
         let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
         let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0).unwrap();
         assert_eq!(
-            unsafe { queue.add(&[], &[]) }.unwrap_err(),
+            unsafe { queue.add(&[], &mut []) }.unwrap_err(),
             Error::InvalidParam
         );
     }
@@ -692,7 +770,7 @@
         let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0).unwrap();
         assert_eq!(queue.available_desc(), 4);
         assert_eq!(
-            unsafe { queue.add(&[&[], &[], &[]], &[&mut [], &mut []]) }.unwrap_err(),
+            unsafe { queue.add(&[&[], &[], &[]], &mut [&mut [], &mut []]) }.unwrap_err(),
             Error::QueueFull
         );
     }
@@ -706,7 +784,7 @@
 
         // Add a buffer chain consisting of two device-readable parts followed by two
         // device-writable parts.
-        let token = unsafe { queue.add(&[&[1, 2], &[3]], &[&mut [0, 0], &mut [0]]) }.unwrap();
+        let token = unsafe { queue.add(&[&[1, 2], &[3]], &mut [&mut [0, 0], &mut [0]]) }.unwrap();
 
         assert_eq!(queue.available_desc(), 0);
         assert!(!queue.can_pop());
@@ -757,27 +835,3 @@
         }
     }
 }
-
-/// Returns an iterator over the buffers of first `inputs` and then `outputs`, paired with the
-/// corresponding `BufferDirection`.
-///
-/// Panics if any of the buffer pointers is null.
-fn input_output_iter<'a>(
-    inputs: &'a [*const [u8]],
-    outputs: &'a [*mut [u8]],
-) -> impl Iterator<Item = (NonNull<[u8]>, BufferDirection)> + 'a {
-    inputs
-        .iter()
-        .map(|input| {
-            (
-                NonNull::new(*input as *mut [u8]).unwrap(),
-                BufferDirection::DriverToDevice,
-            )
-        })
-        .chain(outputs.iter().map(|output| {
-            (
-                NonNull::new(*output).unwrap(),
-                BufferDirection::DeviceToDriver,
-            )
-        }))
-}
diff --git a/src/transport/fake.rs b/src/transport/fake.rs
index 1d599d8..a578db2 100644
--- a/src/transport/fake.rs
+++ b/src/transport/fake.rs
@@ -38,6 +38,10 @@
         self.state.lock().unwrap().queues[queue as usize].notified = true;
     }
 
+    fn get_status(&self) -> DeviceStatus {
+        self.state.lock().map(|state| state.status).unwrap()
+    }
+
     fn set_status(&mut self, status: DeviceStatus) {
         self.state.lock().unwrap().status = status;
     }
diff --git a/src/transport/mmio.rs b/src/transport/mmio.rs
index a6d421e..026646b 100644
--- a/src/transport/mmio.rs
+++ b/src/transport/mmio.rs
@@ -350,6 +350,11 @@
         }
     }
 
+    fn get_status(&self) -> DeviceStatus {
+        // Safe because self.header points to a valid VirtIO MMIO region, and this only reads it.
+        unsafe { volread!(self.header, status) }
+    }
+
     fn set_status(&mut self, status: DeviceStatus) {
         // Safe because self.header points to a valid VirtIO MMIO region.
         unsafe {
@@ -442,7 +447,11 @@
                 // Safe because self.header points to a valid VirtIO MMIO region.
                 unsafe {
                     volwrite!(self.header, queue_sel, queue.into());
+
                     volwrite!(self.header, queue_ready, 0);
+                    // Wait until we read the same value back, to ensure synchronisation (see 4.2.2.2).
+                    while volread!(self.header, queue_ready) != 0 {}
+
                     volwrite!(self.header, queue_num, 0);
                     volwrite!(self.header, queue_desc_low, 0);
                     volwrite!(self.header, queue_desc_high, 0);
diff --git a/src/transport/mod.rs b/src/transport/mod.rs
index 013fa27..f88293c 100644
--- a/src/transport/mod.rs
+++ b/src/transport/mod.rs
@@ -26,6 +26,9 @@
     /// Notifies the given queue on the device.
     fn notify(&mut self, queue: u16);
 
+    /// Gets the device status.
+    fn get_status(&self) -> DeviceStatus;
+
     /// Sets the device status.
     fn set_status(&mut self, status: DeviceStatus);
 
diff --git a/src/transport/pci.rs b/src/transport/pci.rs
index f6473f8..b8bcb15 100644
--- a/src/transport/pci.rs
+++ b/src/transport/pci.rs
@@ -251,6 +251,13 @@
         }
     }
 
+    fn get_status(&self) -> DeviceStatus {
+        // Safe because the common config pointer is valid and we checked in get_bar_region that it
+        // was aligned. Bits not known to `DeviceStatus` are discarded by `from_bits_truncate`.
+        let status = unsafe { volread!(self.common_cfg, device_status) };
+        DeviceStatus::from_bits_truncate(status.into())
+    }
+
     fn set_status(&mut self, status: DeviceStatus) {
         // Safe because the common config pointer is valid and we checked in get_bar_region that it
         // was aligned.
@@ -287,16 +294,9 @@
         }
     }
 
-    fn queue_unset(&mut self, queue: u16) {
-        // Safe because the common config pointer is valid and we checked in get_bar_region that it
-        // was aligned.
-        unsafe {
-            volwrite!(self.common_cfg, queue_select, queue);
-            volwrite!(self.common_cfg, queue_size, 0);
-            volwrite!(self.common_cfg, queue_desc, 0);
-            volwrite!(self.common_cfg, queue_driver, 0);
-            volwrite!(self.common_cfg, queue_device, 0);
-        }
+    fn queue_unset(&mut self, _queue: u16) {
+        // Deliberately a no-op: the VirtIO spec doesn't allow a queue to be unset once it has been
+        // set up for the PCI transport, so `_queue` is ignored rather than written to the device.
     }
 
     fn queue_used(&mut self, queue: u16) -> bool {
@@ -341,7 +341,11 @@
 impl Drop for PciTransport {
     fn drop(&mut self) {
         // Reset the device when the transport is dropped.
-        self.set_status(DeviceStatus::empty())
+        self.set_status(DeviceStatus::empty());
+        // Busy-wait until the device reports the reset as complete, hinting the CPU while spinning.
+        while self.get_status() != DeviceStatus::empty() {
+            core::hint::spin_loop();
+        }
     }
 }
 
@@ -395,7 +396,9 @@
         return Err(VirtioPciError::BarOffsetOutOfRange);
     }
     let paddr = bar_address as PhysAddr + struct_info.offset as PhysAddr;
-    let vaddr = H::mmio_phys_to_virt(paddr, struct_info.length as usize);
+    // Safe because the paddr and size describe a valid MMIO region, at least according to the PCI
+    // bus.
+    let vaddr = unsafe { H::mmio_phys_to_virt(paddr, struct_info.length as usize) };
     if vaddr.as_ptr() as usize % align_of::<T>() != 0 {
         return Err(VirtioPciError::Misaligned {
             vaddr,