blob: c1c5ec5c3e3297ec84136eed77409c4d552d9b43 [file] [log] [blame]
// Copyright 2019 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
use std::{array::TryFromSliceError, convert::TryInto, error, fmt, io, mem, os::unix::io::RawFd};
use cras_sys::gen::{
cras_client_connected, cras_client_message, cras_client_stream_connected,
CRAS_CLIENT_MAX_MSG_SIZE,
CRAS_CLIENT_MESSAGE_ID::{self, *},
};
use data_model::DataInit;
use sys_util::ScmSocket;
use crate::cras_server_socket::CrasServerSocket;
use crate::cras_shm::*;
use crate::cras_stream;
/// Errors reported while receiving or decoding a message from the CRAS server.
#[derive(Debug)]
pub enum Error {
/// Wraps a `std::io::Error`.
IoError(io::Error),
/// Wraps a `sys_util::Error`.
SysUtilError(sys_util::Error),
/// Wraps a `cras_stream::Error`.
CrasStreamError(cras_stream::Error),
/// Wraps the `TryFromSliceError` raised when the message id bytes cannot be converted into a
/// fixed-size array.
ArrayTryFromSliceError(TryFromSliceError),
/// The received length does not equal the size of the requested message type.
InvalidSize,
/// The message id is valid but is not one this module handles.
MessageTypeError,
/// The number of received fds does not match what the message id requires.
MessageNumFdError,
/// Fewer bytes than `size_of::<cras_client_message>()` were received.
MessageTruncated,
/// The raw message id does not correspond to any id this module recognizes.
MessageIdError,
/// The message bytes could not be reinterpreted as the requested message type.
MessageFromSliceError,
}
// Empty impl: only the trait's default methods are used, so wrapped errors are not exposed
// through `source()`.
impl error::Error for Error {}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Error::IoError(ref err) => err.fmt(f),
Error::SysUtilError(ref err) => err.fmt(f),
Error::MessageTypeError => write!(f, "Message type error"),
Error::CrasStreamError(ref err) => err.fmt(f),
Error::ArrayTryFromSliceError(ref err) => err.fmt(f),
Error::MessageNumFdError => write!(f, "Message the number of fds is not matched"),
Error::MessageTruncated => write!(f, "Read truncated message"),
Error::MessageIdError => write!(f, "No such id"),
Error::MessageFromSliceError => write!(f, "Message from slice error"),
Error::InvalidSize => write!(f, "Invalid data size"),
}
}
}
/// Module-private shorthand for `std::result::Result` whose error type is this module's `Error`.
type Result<T> = std::result::Result<T, Error>;
impl From<io::Error> for Error {
fn from(io_err: io::Error) -> Self {
Error::IoError(io_err)
}
}
impl From<sys_util::Error> for Error {
fn from(sys_util_err: sys_util::Error) -> Self {
Error::SysUtilError(sys_util_err)
}
}
impl From<cras_stream::Error> for Error {
fn from(err: cras_stream::Error) -> Self {
Error::CrasStreamError(err)
}
}
impl From<TryFromSliceError> for Error {
fn from(err: TryFromSliceError) -> Self {
Error::ArrayTryFromSliceError(err)
}
}
/// A handled server result from one message sent from CRAS server.
///
/// Each variant corresponds to one message id accepted by
/// `ServerResult::handle_server_message`.
pub enum ServerResult {
/// Built from a `CRAS_CLIENT_CONNECTED` message: `(client_id, CrasServerStateShmFd)`.
/// The fd wrapper is created from the single fd received with the message.
Connected(u32, CrasServerStateShmFd),
/// Built from a `CRAS_CLIENT_STREAM_CONNECTED` message: `(stream_id, header_fd, samples_fd)`.
/// `header_fd` wraps the first received fd and `samples_fd` the second.
StreamConnected(u32, CrasAudioShmHeaderFd, CrasShmFd),
/// Built from a `CRAS_CLIENT_AUDIO_DEBUG_INFO_READY` message; carries no payload.
DebugInfoReady,
}
impl ServerResult {
/// Reads and handles one server message and converts `CrasClientMessage` into `ServerResult`
/// with error handling.
///
/// # Arguments
/// * `server_socket`: A reference to `CrasServerSocket`.
pub fn handle_server_message(server_socket: &CrasServerSocket) -> Result<ServerResult> {
let message = CrasClientMessage::try_new(&server_socket)?;
match message.get_id()? {
CRAS_CLIENT_MESSAGE_ID::CRAS_CLIENT_CONNECTED => {
let cmsg: &cras_client_connected = message.get_message()?;
// CRAS server should return a shared memory area which contains
// `cras_server_state`.
let server_state_fd = unsafe { CrasServerStateShmFd::new(message.fds[0]) };
Ok(ServerResult::Connected(cmsg.client_id, server_state_fd))
}
CRAS_CLIENT_MESSAGE_ID::CRAS_CLIENT_STREAM_CONNECTED => {
let cmsg: &cras_client_stream_connected = message.get_message()?;
// CRAS should return two shared memory areas the first which has
// mem::size_of::<cras_audio_shm_header>() bytes, and the second which has
// `samples_shm_size` bytes.
Ok(ServerResult::StreamConnected(
cmsg.stream_id,
// Safe because CRAS ensures that the first fd contains a cras_audio_shm_header
unsafe { CrasAudioShmHeaderFd::new(message.fds[0]) },
// Safe because CRAS ensures that the second fd has length 'samples_shm_size'
unsafe { CrasShmFd::new(message.fds[1], cmsg.samples_shm_size as usize) },
))
}
CRAS_CLIENT_MESSAGE_ID::CRAS_CLIENT_AUDIO_DEBUG_INFO_READY => {
Ok(ServerResult::DebugInfoReady)
}
_ => Err(Error::MessageTypeError),
}
}
}
// A structure for raw message with fds from CRAS server.
struct CrasClientMessage {
// File descriptors received alongside the message; slots default to -1 until filled.
fds: [RawFd; 2],
// Raw message bytes; the buffer is always `CRAS_CLIENT_MAX_MSG_SIZE` bytes long.
data: [u8; CRAS_CLIENT_MAX_MSG_SIZE as usize],
// Number of bytes of `data` actually received from the socket.
len: usize,
}
/// The default constructor won't be used outside of this file and it's an optimization to prevent
/// having to copy the message data from a temp buffer.
impl Default for CrasClientMessage {
// Initializes fields with default values.
fn default() -> Self {
Self {
fds: [-1; 2],
data: [0; CRAS_CLIENT_MAX_MSG_SIZE as usize],
len: 0,
}
}
}
impl CrasClientMessage {
// Reads a message from server_socket and checks validity of the read result
fn try_new(server_socket: &CrasServerSocket) -> Result<CrasClientMessage> {
let mut message: Self = Default::default();
let (len, fd_nums) = server_socket.recv_with_fds(&mut message.data, &mut message.fds)?;
if len < mem::size_of::<cras_client_message>() {
Err(Error::MessageTruncated)
} else {
message.len = len;
message.check_fd_nums(fd_nums)?;
Ok(message)
}
}
// Check if `fd nums` of a read result is valid
fn check_fd_nums(&self, fd_nums: usize) -> Result<()> {
match self.get_id()? {
CRAS_CLIENT_CONNECTED => match fd_nums {
1 => Ok(()),
_ => Err(Error::MessageNumFdError),
},
CRAS_CLIENT_STREAM_CONNECTED => match fd_nums {
// CRAS should return two shared memory areas the first which has
// mem::size_of::<cras_audio_shm_header>() bytes, and the second which has
// `samples_shm_size` bytes.
2 => Ok(()),
_ => Err(Error::MessageNumFdError),
},
CRAS_CLIENT_AUDIO_DEBUG_INFO_READY => match fd_nums {
0 => Ok(()),
_ => Err(Error::MessageNumFdError),
},
_ => Err(Error::MessageTypeError),
}
}
// Gets the message id
fn get_id(&self) -> Result<CRAS_CLIENT_MESSAGE_ID> {
let offset = mem::size_of::<u32>();
match u32::from_le_bytes(self.data[offset..offset + 4].try_into()?) {
id if id == (CRAS_CLIENT_CONNECTED as u32) => Ok(CRAS_CLIENT_CONNECTED),
id if id == (CRAS_CLIENT_STREAM_CONNECTED as u32) => Ok(CRAS_CLIENT_STREAM_CONNECTED),
id if id == (CRAS_CLIENT_AUDIO_DEBUG_INFO_READY as u32) => {
Ok(CRAS_CLIENT_AUDIO_DEBUG_INFO_READY)
}
_ => Err(Error::MessageIdError),
}
}
// Gets a reference to the message content
fn get_message<T: DataInit>(&self) -> Result<&T> {
if self.len != mem::size_of::<T>() {
return Err(Error::InvalidSize);
}
T::from_slice(&self.data[..mem::size_of::<T>()]).ok_or(Error::MessageFromSliceError)
}
}