blob: 634591e4754599800e8511b59b09c1dad3190308 [file] [log] [blame]
// Copyright 2018 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
/// Support for virtual sockets.
use std::fmt;
use std::io;
use std::mem::{self, size_of};
use std::num::ParseIntError;
use std::os::raw::{c_uchar, c_uint, c_ushort};
use std::os::unix::io::{AsRawFd, IntoRawFd, RawFd};
use std::result;
use std::str::FromStr;
use libc::{
self, c_void, sa_family_t, size_t, sockaddr, socklen_t, F_GETFL, F_SETFL, O_NONBLOCK,
VMADDR_CID_ANY, VMADDR_CID_HOST, VMADDR_CID_HYPERVISOR,
};
// The domain for vsock sockets.
const AF_VSOCK: sa_family_t = 40;
// Vsock loopback address.
const VMADDR_CID_LOCAL: c_uint = 1;
/// Vsock equivalent of binding on port 0. Binds to a random port.
pub const VMADDR_PORT_ANY: c_uint = c_uint::max_value();
// The number of bytes of padding to be added to the sockaddr_vm struct. Taken directly
// from linux/vm_sockets.h.
const PADDING: usize = size_of::<sockaddr>()
- size_of::<sa_family_t>()
- size_of::<c_ushort>()
- (2 * size_of::<c_uint>());
#[repr(C)]
#[derive(Default)]
struct sockaddr_vm {
svm_family: sa_family_t,
svm_reserved1: c_ushort,
svm_port: c_uint,
svm_cid: c_uint,
svm_zero: [c_uchar; PADDING],
}
#[derive(Debug)]
pub struct AddrParseError;
impl fmt::Display for AddrParseError {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
write!(fmt, "failed to parse vsock address")
}
}
/// The vsock equivalent of an IP address.
#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq)]
pub enum VsockCid {
/// Vsock equivalent of INADDR_ANY. Indicates the context id of the current endpoint.
Any,
/// An address that refers to the bare-metal machine that serves as the hypervisor.
Hypervisor,
/// The loopback address.
Local,
/// The parent machine. It may not be the hypervisor for nested VMs.
Host,
/// An assigned CID that serves as the address for VSOCK.
Cid(c_uint),
}
impl fmt::Display for VsockCid {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
match &self {
VsockCid::Any => write!(fmt, "Any"),
VsockCid::Hypervisor => write!(fmt, "Hypervisor"),
VsockCid::Local => write!(fmt, "Local"),
VsockCid::Host => write!(fmt, "Host"),
VsockCid::Cid(c) => write!(fmt, "'{}'", c),
}
}
}
impl From<c_uint> for VsockCid {
fn from(c: c_uint) -> Self {
match c {
VMADDR_CID_ANY => VsockCid::Any,
VMADDR_CID_HYPERVISOR => VsockCid::Hypervisor,
VMADDR_CID_LOCAL => VsockCid::Local,
VMADDR_CID_HOST => VsockCid::Host,
_ => VsockCid::Cid(c),
}
}
}
impl FromStr for VsockCid {
type Err = ParseIntError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let c: c_uint = s.parse()?;
Ok(c.into())
}
}
impl Into<c_uint> for VsockCid {
fn into(self) -> c_uint {
match self {
VsockCid::Any => VMADDR_CID_ANY,
VsockCid::Hypervisor => VMADDR_CID_HYPERVISOR,
VsockCid::Local => VMADDR_CID_LOCAL,
VsockCid::Host => VMADDR_CID_HOST,
VsockCid::Cid(c) => c,
}
}
}
/// An address associated with a virtual socket.
#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq)]
pub struct SocketAddr {
pub cid: VsockCid,
pub port: c_uint,
}
pub trait ToSocketAddr {
fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>;
}
impl ToSocketAddr for SocketAddr {
fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
Ok(*self)
}
}
impl ToSocketAddr for str {
fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
self.parse()
}
}
impl ToSocketAddr for (VsockCid, c_uint) {
fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
let (cid, port) = *self;
Ok(SocketAddr { cid, port })
}
}
impl<'a, T: ToSocketAddr + ?Sized> ToSocketAddr for &'a T {
fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
(**self).to_socket_addr()
}
}
impl FromStr for SocketAddr {
type Err = AddrParseError;
/// Parse a vsock SocketAddr from a string. vsock socket addresses are of the form
/// "vsock:cid:port".
fn from_str(s: &str) -> Result<SocketAddr, AddrParseError> {
let components: Vec<&str> = s.split(':').collect();
if components.len() != 3 || components[0] != "vsock" {
return Err(AddrParseError);
}
Ok(SocketAddr {
cid: components[1].parse().map_err(|_| AddrParseError)?,
port: components[2].parse().map_err(|_| AddrParseError)?,
})
}
}
impl fmt::Display for SocketAddr {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
write!(fmt, "{}:{}", self.cid, self.port)
}
}
/// Sets `fd` to be blocking or nonblocking. `fd` must be a valid fd of a type that accepts the
/// `O_NONBLOCK` flag. This includes regular files, pipes, and sockets.
unsafe fn set_nonblocking(fd: RawFd, nonblocking: bool) -> io::Result<()> {
let flags = libc::fcntl(fd, F_GETFL, 0);
if flags < 0 {
return Err(io::Error::last_os_error());
}
let flags = if nonblocking {
flags | O_NONBLOCK
} else {
flags & !O_NONBLOCK
};
let ret = libc::fcntl(fd, F_SETFL, flags);
if ret < 0 {
return Err(io::Error::last_os_error());
}
Ok(())
}
/// A virtual socket.
///
/// Do not use this class unless you need to change socket options or query the
/// state of the socket prior to calling listen or connect. Instead use either VsockStream or
/// VsockListener.
#[derive(Debug)]
pub struct VsockSocket {
fd: RawFd,
}
impl VsockSocket {
pub fn new() -> io::Result<Self> {
let fd = unsafe { libc::socket(libc::AF_VSOCK, libc::SOCK_STREAM | libc::SOCK_CLOEXEC, 0) };
if fd < 0 {
Err(io::Error::last_os_error())
} else {
Ok(VsockSocket { fd })
}
}
pub fn bind<A: ToSocketAddr>(&mut self, addr: A) -> io::Result<()> {
let sockaddr = addr
.to_socket_addr()
.map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?;
// The compiler should optimize this out since these are both compile-time constants.
assert_eq!(size_of::<sockaddr_vm>(), size_of::<sockaddr>());
let mut svm: sockaddr_vm = Default::default();
svm.svm_family = AF_VSOCK;
svm.svm_cid = sockaddr.cid.into();
svm.svm_port = sockaddr.port;
// Safe because this doesn't modify any memory and we check the return value.
let ret = unsafe {
libc::bind(
self.fd,
&svm as *const sockaddr_vm as *const sockaddr,
size_of::<sockaddr_vm>() as socklen_t,
)
};
if ret < 0 {
let bind_err = io::Error::last_os_error();
Err(bind_err)
} else {
Ok(())
}
}
pub fn connect<A: ToSocketAddr>(self, addr: A) -> io::Result<VsockStream> {
let sockaddr = addr
.to_socket_addr()
.map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?;
let mut svm: sockaddr_vm = Default::default();
svm.svm_family = AF_VSOCK;
svm.svm_cid = sockaddr.cid.into();
svm.svm_port = sockaddr.port;
// Safe because this just connects a vsock socket, and the return value is checked.
let ret = unsafe {
libc::connect(
self.fd,
&svm as *const sockaddr_vm as *const sockaddr,
size_of::<sockaddr_vm>() as socklen_t,
)
};
if ret < 0 {
let connect_err = io::Error::last_os_error();
Err(connect_err)
} else {
Ok(VsockStream { sock: self })
}
}
pub fn listen(self) -> io::Result<VsockListener> {
// Safe because this doesn't modify any memory and we check the return value.
let ret = unsafe { libc::listen(self.fd, 1) };
if ret < 0 {
let listen_err = io::Error::last_os_error();
return Err(listen_err);
}
Ok(VsockListener { sock: self })
}
/// Returns the port that this socket is bound to. This can only succeed after bind is called.
pub fn local_port(&self) -> io::Result<u32> {
let mut svm: sockaddr_vm = Default::default();
// Safe because we give a valid pointer for addrlen and check the length.
let mut addrlen = size_of::<sockaddr_vm>() as socklen_t;
let ret = unsafe {
// Get the socket address that was actually bound.
libc::getsockname(
self.fd,
&mut svm as *mut sockaddr_vm as *mut sockaddr,
&mut addrlen as *mut socklen_t,
)
};
if ret < 0 {
let getsockname_err = io::Error::last_os_error();
Err(getsockname_err)
} else {
// If this doesn't match, it's not safe to get the port out of the sockaddr.
assert_eq!(addrlen as usize, size_of::<sockaddr_vm>());
Ok(svm.svm_port)
}
}
pub fn try_clone(&self) -> io::Result<Self> {
// Safe because this doesn't modify any memory and we check the return value.
let dup_fd = unsafe { libc::fcntl(self.fd, libc::F_DUPFD_CLOEXEC, 0) };
if dup_fd < 0 {
Err(io::Error::last_os_error())
} else {
Ok(Self { fd: dup_fd })
}
}
pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
// Safe because the fd is valid and owned by this stream.
unsafe { set_nonblocking(self.fd, nonblocking) }
}
}
impl IntoRawFd for VsockSocket {
fn into_raw_fd(self) -> RawFd {
let fd = self.fd;
mem::forget(self);
fd
}
}
impl AsRawFd for VsockSocket {
fn as_raw_fd(&self) -> RawFd {
self.fd
}
}
impl Drop for VsockSocket {
fn drop(&mut self) {
// Safe because this doesn't modify any memory and we are the only
// owner of the file descriptor.
unsafe { libc::close(self.fd) };
}
}
/// A virtual stream socket.
#[derive(Debug)]
pub struct VsockStream {
sock: VsockSocket,
}
impl VsockStream {
pub fn connect<A: ToSocketAddr>(addr: A) -> io::Result<VsockStream> {
let sock = VsockSocket::new()?;
sock.connect(addr)
}
/// Returns the port that this stream is bound to.
pub fn local_port(&self) -> io::Result<u32> {
self.sock.local_port()
}
pub fn try_clone(&self) -> io::Result<VsockStream> {
self.sock.try_clone().map(|f| VsockStream { sock: f })
}
pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
self.sock.set_nonblocking(nonblocking)
}
}
impl io::Read for VsockStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
// Safe because this will only modify the contents of |buf| and we check the return value.
let ret = unsafe {
libc::read(
self.sock.as_raw_fd(),
buf as *mut [u8] as *mut c_void,
buf.len() as size_t,
)
};
if ret < 0 {
return Err(io::Error::last_os_error());
}
Ok(ret as usize)
}
}
impl io::Write for VsockStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
// Safe because this doesn't modify any memory and we check the return value.
let ret = unsafe {
libc::write(
self.sock.as_raw_fd(),
buf as *const [u8] as *const c_void,
buf.len() as size_t,
)
};
if ret < 0 {
return Err(io::Error::last_os_error());
}
Ok(ret as usize)
}
fn flush(&mut self) -> io::Result<()> {
// No buffered data so nothing to do.
Ok(())
}
}
impl AsRawFd for VsockStream {
fn as_raw_fd(&self) -> RawFd {
self.sock.as_raw_fd()
}
}
impl IntoRawFd for VsockStream {
fn into_raw_fd(self) -> RawFd {
self.sock.into_raw_fd()
}
}
/// Represents a virtual socket server.
#[derive(Debug)]
pub struct VsockListener {
sock: VsockSocket,
}
impl VsockListener {
/// Creates a new `VsockListener` bound to the specified port on the current virtual socket
/// endpoint.
pub fn bind<A: ToSocketAddr>(addr: A) -> io::Result<VsockListener> {
let mut sock = VsockSocket::new()?;
sock.bind(addr)?;
sock.listen()
}
/// Returns the port that this listener is bound to.
pub fn local_port(&self) -> io::Result<u32> {
self.sock.local_port()
}
/// Accepts a new incoming connection on this listener. Blocks the calling thread until a
/// new connection is established. When established, returns the corresponding `VsockStream`
/// and the remote peer's address.
pub fn accept(&self) -> io::Result<(VsockStream, SocketAddr)> {
let mut svm: sockaddr_vm = Default::default();
// Safe because this will only modify |svm| and we check the return value.
let mut socklen: socklen_t = size_of::<sockaddr_vm>() as socklen_t;
let fd = unsafe {
libc::accept4(
self.sock.as_raw_fd(),
&mut svm as *mut sockaddr_vm as *mut sockaddr,
&mut socklen as *mut socklen_t,
libc::SOCK_CLOEXEC,
)
};
if fd < 0 {
return Err(io::Error::last_os_error());
}
if svm.svm_family != AF_VSOCK {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("unexpected address family: {}", svm.svm_family),
));
}
Ok((
VsockStream {
sock: VsockSocket { fd },
},
SocketAddr {
cid: svm.svm_cid.into(),
port: svm.svm_port,
},
))
}
pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
self.sock.set_nonblocking(nonblocking)
}
}
impl AsRawFd for VsockListener {
fn as_raw_fd(&self) -> RawFd {
self.sock.as_raw_fd()
}
}