// blob: dac66a04c97a38b9861ea8fe1d980eedf6376c7e [file] [log] [blame]
/*
* Copyright (C) 2021 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
//! DoH backend for the Android DnsResolver module.
use anyhow::{anyhow, bail, Context, Result};
use futures::future::join_all;
use futures::stream::FuturesUnordered;
use futures::StreamExt;
use libc::{c_char, int32_t, size_t, ssize_t, uint32_t, uint64_t};
use log::{debug, error, info, trace, warn};
use quiche::h3;
use ring::rand::SecureRandom;
use std::collections::HashMap;
use std::ffi::CString;
use std::net::{IpAddr, SocketAddr};
use std::ops::Deref;
use std::os::unix::io::{AsRawFd, RawFd};
use std::pin::Pin;
use std::str::FromStr;
use std::sync::{Arc, Mutex, Once};
use std::{ptr, slice};
use tokio::net::UdpSocket;
use tokio::runtime::{Builder, Runtime};
use tokio::sync::{mpsc, oneshot};
use tokio::task;
use tokio::time::{timeout, Duration, Instant};
use url::Url;
/// Guards the one-time logger setup performed by `doh_init_logger`.
static INIT: Once = Once::new();
/// Return code of `doh_query` indicating that there is no answer.
/// In this file it is returned when the answer does not fit in the caller's buffer.
pub const RESULT_INTERNAL_ERROR: ssize_t = -1;
/// Return code of `doh_query` indicating that the query can't be sent
/// (also used when the engine answers with an error or drops the responder).
pub const RESULT_CAN_NOT_SEND: ssize_t = -2;
/// Return code of `doh_query` indicating that the query timed out.
pub const RESULT_TIMEOUT: ssize_t = -255;
/// The error log level.
pub const LOG_LEVEL_ERROR: u32 = 0;
/// The warning log level.
pub const LOG_LEVEL_WARN: u32 = 1;
/// The info log level.
pub const LOG_LEVEL_INFO: u32 = 2;
/// The debug log level.
pub const LOG_LEVEL_DEBUG: u32 = 3;
/// The trace log level.
pub const LOG_LEVEL_TRACE: u32 = 4;
// Capacity of the command channel between the FFI entry points and `doh_handler`.
const MAX_BUFFERED_CMD_SIZE: usize = 400;
// Passed to quiche as the connection-wide initial flow-control limit (initial_max_data).
const MAX_INCOMING_BUFFER_SIZE_WHOLE: u64 = 10000000;
// Passed to quiche as the per-stream initial flow-control limit (bidi local/remote and uni).
const MAX_INCOMING_BUFFER_SIZE_EACH: u64 = 1000000;
// Passed to quiche as the initial limit of concurrent bidi and uni streams.
const MAX_CONCURRENT_STREAM_SIZE: u64 = 100;
// Size of the outgoing packet buffer and of the HTTP/3 body read buffer; also given to
// quiche as the maximum UDP payload size to receive.
const MAX_DATAGRAM_SIZE: usize = 1350;
// Port combined with the server IP address to build the peer address.
const DOH_PORT: u16 = 443;
// QUIC max idle timeout in milliseconds; also the fallback receive timeout and the
// interval used to compute the connection expiry time in `recv_rx`.
const QUICHE_IDLE_TIMEOUT_MS: u64 = 180000;
// Certificate directory used when a domain is given without an explicit cert path.
const SYSTEM_CERT_PATH: &str = "/system/etc/security/cacerts";
// QTYPE (AAAA) used by the probe query.
const NS_T_AAAA: u8 = 28;
// QCLASS (IN) used by the probe query.
const NS_C_IN: u8 = 1;
// Alphabet that random bytes are mapped onto to build the random hostname prefix
// of the probe query.
const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ\
abcdefghijklmnopqrstuvwxyz\
0123456789";
// Raw bytes backing the QUIC source connection ID of a connection.
type SCID = [u8; quiche::MAX_CONN_ID_LEN];
// A DNS query encoded as URL-safe base64 without padding.
type Base64Query = String;
// Sending half of the command channel to `doh_handler`.
type CmdSender = mpsc::Sender<DohCommand>;
// Receiving half of the command channel, owned by `doh_handler`.
type CmdReceiver = mpsc::Receiver<DohCommand>;
// One-shot channel used to hand the result of a single query back to `doh_query`.
type QueryResponder = oneshot::Sender<Response>;
// The HTTP/3 header list that makes up one DoH request.
type DnsRequest = Vec<quiche::h3::Header>;
// C callback invoked with the probe outcome: network id, whether the connection reached
// the Connected state, and the server IP and host name as C strings.
type ValidationCallback =
extern "C" fn(net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char);
/// Reasons reported to the caller when a query cannot be answered.
#[derive(Eq, PartialEq, Debug)]
enum QueryError {
/// The connection is in the Error state when the query arrives.
BrokenServer,
/// Sending the HTTP/3 request failed.
ConnectionError,
/// No usable connection exists for the network (unknown network or still probing).
ServerNotReady,
/// The DNS request could not be built.
Unexpected,
}
/// Configuration of the DoH server associated with one network.
/// Two values are compared to detect whether a network's server was replaced.
#[derive(Eq, PartialEq, Debug, Clone)]
struct ServerInfo {
// Network the server belongs to; key of the connection map.
net_id: u32,
// DoH URL; its host and path are used to build requests.
url: Url,
// Server IP address combined with DOH_PORT.
peer_addr: SocketAddr,
// Server name passed to quiche when connecting; None when the caller gave an empty domain.
domain: Option<String>,
// Value applied to the UDP socket through SO_MARK.
sk_mark: u32,
// Directory of trusted certificates; None disables peer verification.
cert_path: Option<String>,
}
/// Result of a single query, delivered through a `QueryResponder`.
#[derive(Eq, PartialEq, Debug)]
enum Response {
/// The query failed; `error` carries the reason.
Error { error: QueryError },
/// The query succeeded; `answer` is the accumulated response body.
Success { answer: Vec<u8> },
}
/// Commands sent from the FFI entry points to the `doh_handler` task.
#[derive(Debug)]
enum DohCommand {
/// Create (or re-probe) the connection for `info`, giving up after `timeout`.
Probe { info: ServerInfo, timeout: Duration },
/// Send a query on the network `net_id`; the result is delivered through `resp`.
/// Queries still pending after `expired_time` are dropped.
Query { net_id: u32, base64_query: Base64Query, expired_time: Instant, resp: QueryResponder },
/// Remove the server associated with `net_id`.
Clear { net_id: u32 },
/// Make `doh_handler` return.
Exit,
}
/// State machine of a `DohConnection`.
#[allow(clippy::large_enum_variant)]
enum ConnectionState {
/// No socket or QUIC connection exists.
Idle,
/// The QUIC handshake is in progress. The connection and socket are wrapped in
/// `Option` so that `state_to_connected` can move them out with `take()`.
Connecting {
quic_conn: Option<Pin<Box<quiche::Connection>>>,
udp_sk: Option<UdpSocket>,
// Refreshed by `recv_rx`; once passed, the connection is considered expired.
expired_time: Option<BootTime>,
},
/// The QUIC handshake finished and the HTTP/3 layer is set up.
Connected {
quic_conn: Pin<Box<quiche::Connection>>,
udp_sk: UdpSocket,
h3_conn: Option<h3::Connection>,
// In-flight queries keyed by stream id: the body received so far and the responder.
query_map: HashMap<u64, (Vec<u8>, QueryResponder)>,
// Refreshed by `recv_rx`; once passed, the connection is considered expired.
expired_time: Option<BootTime>,
},
/// Indicate that the Connection can't be used due to
/// network or unexpected reasons.
Error,
}
/// Simplified outcome of polling one HTTP/3 event, produced by `recv_h3`.
enum H3Result {
/// A chunk of response body was read from the stream.
Data { data: Vec<u8> },
/// The stream finished.
Finished,
/// An event this module does not act on (headers, datagram, GoAway).
Ignore,
}
/// Extension trait that borrows the `Deref` target held inside an `Option`.
trait OptionDeref<T: Deref> {
    /// Converts `&Option<T>` into `Option<&T::Target>` without consuming the option.
    fn as_deref(&self) -> Option<&T::Target>;
}
impl<T: Deref> OptionDeref<T> for Option<T> {
    fn as_deref(&self) -> Option<&T::Target> {
        match self {
            Some(inner) => Some(inner.deref()),
            None => None,
        }
    }
}
/// A timestamp read from the `CLOCK_BOOTTIME` clock, stored as the duration reported
/// by that clock.
#[derive(Copy, Clone, Debug)]
struct BootTime {
    d: Duration,
}
impl BootTime {
    /// Reads the current `CLOCK_BOOTTIME` value.
    ///
    /// # Panics
    /// Panics if `clock_gettime` reports a failure.
    fn now() -> BootTime {
        let mut ts = libc::timespec { tv_sec: 0, tv_nsec: 0 };
        // SAFETY: `ts` is a valid, exclusively borrowed timespec for the duration of the call.
        let rc = unsafe { libc::clock_gettime(libc::CLOCK_BOOTTIME, &mut ts as *mut libc::timespec) };
        if rc != 0 {
            panic!("get boot time failed: {:?}", std::io::Error::last_os_error());
        }
        BootTime { d: Duration::new(ts.tv_sec as u64, ts.tv_nsec as u32) }
    }
    /// Time passed since `self`, or `None` if `self` is still in the future.
    fn elapsed(&self) -> Option<Duration> {
        Self::now().duration_since(*self)
    }
    /// `self + duration`, or `None` on overflow.
    fn checked_add(&self, duration: Duration) -> Option<BootTime> {
        self.d.checked_add(duration).map(|d| BootTime { d })
    }
    /// `self - earlier`, or `None` if `earlier` is later than `self`.
    fn duration_since(&self, earlier: BootTime) -> Option<Duration> {
        self.d.checked_sub(earlier.d)
    }
}
/// Context for a running DoH engine.
pub struct DohDispatcher {
    /// Used to submit cmds to the I/O task.
    cmd_sender: CmdSender,
    /// Handle of the spawned `doh_handler` task; awaited by `exit_handler`.
    join_handle: task::JoinHandle<Result<()>>,
    /// Runtime on which the handler task runs.
    runtime: Arc<Runtime>,
}
// DoH dispatcher
impl DohDispatcher {
    /// Creates the runtime, spawns `doh_handler` on it and returns the dispatcher.
    ///
    /// # Errors
    /// Returns an error if the tokio runtime cannot be built. This used to panic, which
    /// bypassed the error path of `doh_dispatcher_new` and unwound towards the C caller.
    fn new(validation_fn: ValidationCallback) -> Result<Box<DohDispatcher>> {
        let (cmd_sender, cmd_receiver) = mpsc::channel::<DohCommand>(MAX_BUFFERED_CMD_SIZE);
        let runtime = Arc::new(
            Builder::new_multi_thread()
                .worker_threads(2)
                .enable_all()
                .thread_name("doh-handler")
                .build()
                .context("Failed to create tokio runtime")?,
        );
        let join_handle = runtime.spawn(doh_handler(cmd_receiver, runtime.clone(), validation_fn));
        Ok(Box::new(DohDispatcher { cmd_sender, join_handle, runtime }))
    }
    /// Submits `cmd` to the handler task, blocking while the command channel is full.
    ///
    /// # Errors
    /// Returns an error if the handler task no longer receives commands.
    fn send_cmd(&self, cmd: DohCommand) -> Result<()> {
        self.cmd_sender.blocking_send(cmd)?;
        Ok(())
    }
    /// Asks the handler task to exit and waits for it to finish.
    /// Does nothing more if the exit command cannot be delivered.
    fn exit_handler(&mut self) {
        if self.cmd_sender.blocking_send(DohCommand::Exit).is_err() {
            return;
        }
        let _ = self.runtime.block_on(&mut self.join_handle);
    }
}
/// A DoH connection to one server together with the queries waiting to be sent on it.
struct DohConnection {
// Server configuration this connection was created for.
info: ServerInfo,
// Cache of the quiche configuration, shared between connections.
shared_config: Arc<Mutex<QuicheConfigCache>>,
// Randomly generated source connection ID bytes.
scid: SCID,
// Current position in the connection state machine.
state: ConnectionState,
// Queries not yet sent: request headers, responder and expiry time.
pending_queries: Vec<(DnsRequest, QueryResponder, Instant)>,
// Session data saved after a successful handshake and offered to the next connection.
cached_session: Option<Vec<u8>>,
}
impl DohConnection {
fn new(
info: &ServerInfo,
shared_config: Arc<Mutex<QuicheConfigCache>>,
) -> Result<DohConnection> {
let mut scid = [0; quiche::MAX_CONN_ID_LEN];
ring::rand::SystemRandom::new().fill(&mut scid).context("failed to generate scid")?;
Ok(DohConnection {
info: info.clone(),
shared_config,
scid,
state: ConnectionState::Idle,
pending_queries: Vec::new(),
cached_session: None,
})
}
fn state_to_connecting(&mut self) -> Result<()> {
self.state = match self.state {
ConnectionState::Idle => {
let udp_sk_std = make_doh_udp_socket(self.info.peer_addr, self.info.sk_mark)?;
let udp_sk = UdpSocket::from_std(udp_sk_std)?;
let connid = quiche::ConnectionId::from_ref(&self.scid);
let mut cache = self.shared_config.lock().unwrap();
let config =
cache.get(&self.info.cert_path)?.ok_or_else(|| anyhow!("no quiche config"))?;
debug!("init the connection for Network {}", self.info.net_id);
let mut quic_conn = quiche::connect(
self.info.domain.as_deref(),
&connid,
self.info.peer_addr,
config,
)?;
if let Some(session) = &self.cached_session {
if quic_conn.set_session(session).is_err() {
warn!("can't restore session for network {}", self.info.net_id);
}
}
ConnectionState::Connecting {
quic_conn: Some(quic_conn),
udp_sk: Some(udp_sk),
expired_time: None,
}
}
ConnectionState::Error => {
self.state_to_idle();
return self.state_to_connecting();
}
ConnectionState::Connecting { .. } => return Ok(()),
ConnectionState::Connected { .. } => {
panic!("Invalid state transition to Connecting state!")
}
};
Ok(())
}
/// Moves the connection from Connecting to Connected by creating the HTTP/3 layer on top
/// of the established QUIC connection.
///
/// # Errors
/// Returns an error if the QUIC connection or socket is missing, or if the HTTP/3
/// connection can't be created.
///
/// # Panics
/// Panics when called in any state other than Connecting.
fn state_to_connected(&mut self) -> Result<()> {
    // Only Connecting -> Connected is valid; take ownership of the transport pieces.
    let (mut quic_conn, udp_sk) = match &mut self.state {
        ConnectionState::Connecting { quic_conn, udp_sk, .. } => {
            match (quic_conn.take(), udp_sk.take()) {
                (Some(conn), Some(sk)) => (conn, sk),
                _ => bail!("state transition fail!"),
            }
        }
        // The rest should fail.
        _ => panic!("Invalid state transition to Connected state!"),
    };
    let h3_config = h3::Config::new()?;
    let h3_conn = quiche::h3::Connection::with_transport(&mut quic_conn, &h3_config)?;
    self.state = ConnectionState::Connected {
        quic_conn,
        udp_sk,
        h3_conn: Some(h3_conn),
        query_map: HashMap::new(),
        expired_time: None,
    };
    Ok(())
}
fn state_to_idle(&mut self) {
self.state = match self.state {
// Only either Connected or Error -> Idle is valid.
// TODO: Error -> Idle is the re-probing case, add the relevant statistic.
ConnectionState::Connected { .. } | ConnectionState::Error => ConnectionState::Idle,
// The rest should fail.
_ => panic!("Invalid state transition to Idle state!"),
}
}
fn state_to_error(&mut self) {
self.pending_queries.clear();
self.state = ConnectionState::Error
}
fn is_reprobe_required(&self) -> bool {
matches!(self.state, ConnectionState::Error)
}
fn has_not_handled_queries(&self) -> bool {
match &self.state {
ConnectionState::Connecting { .. } | ConnectionState::Idle => {
!self.pending_queries.is_empty()
}
ConnectionState::Connected { query_map, .. } => {
!query_map.is_empty() || !self.pending_queries.is_empty()
}
_ => false,
}
}
/// Resets the connection to Idle when its expiry time has passed.
///
/// Only the Connecting and Connected states carry an expiry time; other states are
/// left untouched.
fn handle_if_connection_expired(&mut self) {
    let deadline = match &self.state {
        ConnectionState::Connecting { expired_time, .. }
        | ConnectionState::Connected { expired_time, .. } => *expired_time,
        // ignore
        _ => return,
    };
    // `elapsed()` is Some only once the deadline is in the past.
    if let Some(elapsed) = deadline.and_then(|t| t.elapsed()) {
        warn!(
            "Change the state to Idle due to connection timeout, {:?}, {}",
            elapsed, self.info.net_id
        );
        // Assign directly instead of calling `state_to_idle()`: that helper panics for
        // the Connecting state, which is a legitimate state to expire from here.
        self.state = ConnectionState::Idle;
    }
}
async fn probe(&mut self, t: Duration) -> Result<()> {
match timeout(t, async {
self.try_connect().await?;
info!("probe start for {}", self.info.net_id);
if let ConnectionState::Connected { quic_conn, udp_sk, h3_conn, expired_time, .. } =
&mut self.state
{
let h3_conn = h3_conn.as_mut().ok_or_else(|| anyhow!("h3 conn isn't available"))?;
let req = match make_probe_query() {
Ok(q) => match make_dns_request(&q, &self.info.url) {
Ok(req) => req,
Err(e) => bail!(e),
},
Err(e) => bail!(e),
};
// Send the probe query.
let req_id = h3_conn.send_request(quic_conn, &req, true /*fin*/)?;
loop {
flush_tx(quic_conn, udp_sk).await?;
recv_rx(quic_conn, udp_sk, expired_time).await?;
loop {
match recv_h3(quic_conn, h3_conn) {
Ok((stream_id, H3Result::Finished)) => {
if stream_id == req_id {
return Ok(());
}
}
// TODO: Verify the answer
Ok((_stream_id, H3Result::Data { .. })) => {}
Ok((_stream_id, H3Result::Ignore)) => {}
Err(_) => break,
}
}
}
} else {
bail!("state error while performing probe()");
}
})
.await
{
Ok(v) => match v {
Ok(_) => Ok(()),
Err(e) => {
self.state_to_error();
bail!(e);
}
},
Err(e) => {
self.state_to_error();
bail!(e);
}
}
}
async fn try_connect(&mut self) -> Result<()> {
if matches!(self.state, ConnectionState::Connected { .. }) {
return Ok(());
}
self.state_to_connecting()?;
debug!("connecting to Network {}", self.info.net_id);
let (quic_conn, udp_sk, expired_time) = match &mut self.state {
ConnectionState::Connecting { quic_conn, udp_sk, expired_time, .. } => {
if let (Some(quic_conn), Some(udp_sk)) = (quic_conn.as_mut(), udp_sk.as_mut()) {
(quic_conn, udp_sk, expired_time)
} else {
bail!("unexpected error while performing connect()");
}
}
_ => bail!("state error while performing try_connect()"),
};
while !quic_conn.is_established() {
flush_tx(quic_conn, udp_sk).await?;
recv_rx(quic_conn, udp_sk, expired_time).await?;
}
self.cached_session = quic_conn.session();
self.state_to_connected()?;
info!("connected to Network {}", self.info.net_id);
Ok(())
}
/// Sends `req` right away when connected, queues it while the connection is Idle or
/// Connecting, and answers with `BrokenServer` when the connection is in the Error state.
///
/// # Errors
/// Returns an error if the HTTP/3 connection is missing or sending the query fails.
async fn try_send_doh_query(
    &mut self,
    req: DnsRequest,
    resp: QueryResponder,
    expired_time: Instant,
) -> Result<()> {
    self.handle_if_connection_expired();
    match &mut self.state {
        ConnectionState::Error => {
            error!(
                "state is error while performing try_send_doh_query(), network: {}",
                self.info.net_id
            );
            let _ = resp.send(Response::Error { error: QueryError::BrokenServer });
            Ok(())
        }
        ConnectionState::Idle | ConnectionState::Connecting { .. } => {
            self.pending_queries.push((req, resp, expired_time));
            Ok(())
        }
        ConnectionState::Connected { quic_conn, udp_sk, h3_conn, query_map, .. } => {
            let h3_conn = h3_conn.as_mut().ok_or_else(|| anyhow!("h3 conn isn't available"))?;
            send_dns_query(
                quic_conn,
                udp_sk,
                h3_conn,
                query_map,
                &mut self.pending_queries,
                resp,
                expired_time,
                req,
            )
            .await
        }
    }
}
/// Connects if needed, then loops forever: sends the queued queries, flushes outgoing
/// packets, waits for incoming packets and dispatches finished answers to their
/// responders.
///
/// # Errors
/// This function only returns with an error: when connecting, sending or receiving
/// fails, when the QUIC connection is closed (state becomes Idle), or when the
/// connection is not in the Connected state after connecting (state becomes Error).
async fn process_queries(&mut self) -> Result<()> {
    debug!("process_queries entry, Network {}", self.info.net_id);
    self.try_connect().await?;
    if let ConnectionState::Connected { quic_conn, udp_sk, h3_conn, query_map, expired_time } =
        &mut self.state
    {
        let h3_conn = h3_conn.as_mut().ok_or_else(|| anyhow!("h3 conn isn't available"))?;
        loop {
            while let Some((req, resp, exp_time)) = self.pending_queries.pop() {
                // Ignore the expired queries.
                if Instant::now().checked_duration_since(exp_time).is_some() {
                    warn!("Drop the obsolete query for network {}", self.info.net_id);
                    continue;
                }
                let queued_before = self.pending_queries.len();
                send_dns_query(
                    quic_conn,
                    udp_sk,
                    h3_conn,
                    query_map,
                    &mut self.pending_queries,
                    resp,
                    exp_time,
                    req,
                )
                .await?;
                // send_dns_query() puts the query back when the stream is blocked.
                // Popping it again right away would spin forever without doing any
                // I/O, so stop draining and let the flush/receive below make progress.
                if self.pending_queries.len() > queued_before {
                    break;
                }
            }
            flush_tx(quic_conn, udp_sk).await?;
            recv_rx(quic_conn, udp_sk, expired_time).await?;
            // Drain HTTP/3 events until polling fails.
            loop {
                match recv_h3(quic_conn, h3_conn) {
                    Ok((stream_id, H3Result::Data { mut data })) => {
                        if let Some((answer, _)) = query_map.get_mut(&stream_id) {
                            answer.append(&mut data);
                        } else {
                            // Should not happen
                            warn!("No associated receiver found while receiving Data, Network {}, stream id: {}", self.info.net_id, stream_id);
                        }
                    }
                    Ok((stream_id, H3Result::Finished)) => {
                        if let Some((answer, resp)) = query_map.remove(&stream_id) {
                            debug!(
                                "sending answer back to resolv, Network {}, stream id: {}",
                                self.info.net_id, stream_id
                            );
                            resp.send(Response::Success { answer }).unwrap_or_else(|e| {
                                trace!(
                                    "the receiver dropped {:?}, stream id: {}",
                                    e,
                                    stream_id
                                );
                            });
                        } else {
                            // Should not happen
                            warn!("No associated receiver found while receiving Finished, Network {}, stream id: {}", self.info.net_id, stream_id);
                        }
                    }
                    Ok((_stream_id, H3Result::Ignore)) => {}
                    Err(_) => break,
                }
            }
            if quic_conn.is_closed() || !quic_conn.is_established() {
                self.state_to_idle();
                bail!("connection become idle");
            }
        }
    } else {
        self.state_to_error();
        bail!("state error while performing process_queries(), network: {}", self.info.net_id);
    }
}
}
fn recv_h3(
quic_conn: &mut Pin<Box<quiche::Connection>>,
h3_conn: &mut h3::Connection,
) -> Result<(u64, H3Result)> {
match h3_conn.poll(quic_conn) {
// Process HTTP/3 events.
Ok((stream_id, quiche::h3::Event::Data)) => {
debug!("quiche::h3::Event::Data");
let mut buf = vec![0; MAX_DATAGRAM_SIZE];
match h3_conn.recv_body(quic_conn, stream_id, &mut buf) {
Ok(read) => {
trace!(
"got {} bytes of response data on stream {}: {:x?}",
read,
stream_id,
&buf[..read]
);
buf.truncate(read);
Ok((stream_id, H3Result::Data { data: buf }))
}
Err(e) => {
warn!("recv_h3::recv_body {:?}", e);
bail!(e);
}
}
}
Ok((stream_id, quiche::h3::Event::Headers { list, has_body })) => {
trace!(
"got response headers {:?} on stream id {} has_body {}",
list,
stream_id,
has_body
);
Ok((stream_id, H3Result::Ignore))
}
Ok((stream_id, quiche::h3::Event::Finished)) => {
debug!("quiche::h3::Event::Finished on stream id {}", stream_id);
Ok((stream_id, H3Result::Finished))
}
Ok((stream_id, quiche::h3::Event::Datagram)) => {
debug!("quiche::h3::Event::Datagram on stream id {}", stream_id);
Ok((stream_id, H3Result::Ignore))
}
// TODO: Check if it's necessary to handle GoAway event.
Ok((stream_id, quiche::h3::Event::GoAway)) => {
debug!("quiche::h3::Event::GoAway on stream id {}", stream_id);
Ok((stream_id, H3Result::Ignore))
}
Err(e) => {
debug!("recv_h3 {:?}", e);
bail!(e);
}
}
}
/// Sends `req` as a new HTTP/3 request and registers `resp` under the new stream id.
///
/// When the stream is blocked the query is pushed back onto `pending_queries` and
/// `Ok(())` is returned. On any other send failure `resp` receives `ConnectionError`.
///
/// # Errors
/// Returns an error if the QUIC connection is not established, if sending the request
/// fails for a reason other than `StreamBlocked`, or if flushing packets fails.
#[allow(clippy::too_many_arguments)]
async fn send_dns_query(
    quic_conn: &mut Pin<Box<quiche::Connection>>,
    udp_sk: &mut UdpSocket,
    h3_conn: &mut h3::Connection,
    query_map: &mut HashMap<u64, (Vec<u8>, QueryResponder)>,
    pending_queries: &mut Vec<(DnsRequest, QueryResponder, Instant)>,
    resp: QueryResponder,
    expired_time: Instant,
    req: DnsRequest,
) -> Result<()> {
    if !quic_conn.is_established() {
        bail!("quic connection is not ready");
    }
    let stream_id = match h3_conn.send_request(quic_conn, &req, true /*fin*/) {
        Ok(id) => id,
        Err(quiche::h3::Error::StreamBlocked) => {
            warn!("try to send query but error on StreamBlocked");
            pending_queries.push((req, resp, expired_time));
            return Ok(());
        }
        Err(e) => {
            resp.send(Response::Error { error: QueryError::ConnectionError }).ok();
            bail!(e);
        }
    };
    query_map.insert(stream_id, (Vec::new(), resp));
    flush_tx(quic_conn, udp_sk).await?;
    debug!("send dns query successfully stream id: {}", stream_id);
    Ok(())
}
/// Waits for one UDP datagram and feeds it to the QUIC connection.
///
/// The wait is bounded by the connection's own timeout, or by the idle timeout when
/// the connection reports none. When the wait times out, the QUIC timeout handler is
/// run and `Ok(())` is returned. `expired_time` is refreshed to "now + idle timeout"
/// on every call (or cleared if that addition overflows).
///
/// # Errors
/// Returns an error if receiving from the socket or processing the packet fails.
async fn recv_rx(
    quic_conn: &mut Pin<Box<quiche::Connection>>,
    udp_sk: &mut UdpSocket,
    expired_time: &mut Option<BootTime>,
) -> Result<()> {
    // TODO: Evaluate if we could make the buffer smaller.
    let mut buf = [0; 65535];
    let idle_timeout = Duration::from_millis(QUICHE_IDLE_TIMEOUT_MS);
    let ts = quic_conn.timeout().unwrap_or(idle_timeout);
    *expired_time = BootTime::now().checked_add(idle_timeout);
    debug!("recv_rx entry next timeout {:?} {:?}", ts, expired_time);

    let (size, from) = match timeout(ts, udp_sk.recv_from(&mut buf)).await {
        Ok(Ok(received)) => received,
        Ok(Err(e)) => bail!("socket recv failed: {:?}", e),
        Err(_) => {
            warn!("timeout did not receive value within {:?}", ts);
            quic_conn.on_timeout();
            return Ok(());
        }
    };
    let recv_info = quiche::RecvInfo { from };
    match quic_conn.recv(&mut buf[..size], recv_info) {
        Ok(processed) => {
            debug!("processed {} bytes", processed);
            Ok(())
        }
        Err(e) => {
            debug!("recv_rx error {:?}", e);
            bail!("quic recv failed: {:?}", e);
        }
    }
}
/// Writes every packet the QUIC connection currently wants to send to the UDP socket.
///
/// # Errors
/// Returns an error if the socket send fails, or if quiche fails to produce a packet;
/// in the latter case the QUIC connection is closed first.
async fn flush_tx(
    quic_conn: &mut Pin<Box<quiche::Connection>>,
    udp_sk: &mut UdpSocket,
) -> Result<()> {
    let mut out = [0; MAX_DATAGRAM_SIZE];
    loop {
        match quic_conn.send(&mut out) {
            Ok((written, _)) => {
                udp_sk.send(&out[..written]).await?;
                debug!("written {}", written);
            }
            // Nothing left to send.
            Err(quiche::Error::Done) => {
                debug!("done writing");
                return Ok(());
            }
            Err(e) => {
                quic_conn.close(false, 0x1, b"fail").ok();
                bail!(e);
            }
        }
    }
}
/// Reports the validation result of `info` through `validation_fn`.
///
/// The result is a success only when `state` is Connected. The callback runs on the
/// runtime's blocking pool. Nothing is reported if the IP address or the domain can't
/// be converted to a C string.
fn report_private_dns_validation(
    info: &ServerInfo,
    state: &ConnectionState,
    runtime: Arc<Runtime>,
    validation_fn: ValidationCallback,
) {
    let ip_cstr = CString::new(info.peer_addr.ip().to_string());
    let domain_cstr = CString::new(info.domain.clone().unwrap_or_default());
    let (ip_addr, domain) = match (ip_cstr, domain_cstr) {
        (Ok(ip_addr), Ok(domain)) => (ip_addr, domain),
        _ => {
            error!("report_private_dns_validation bad input");
            return;
        }
    };
    let net_id = info.net_id;
    let success = match state {
        ConnectionState::Connected { .. } => true,
        _ => false,
    };
    // The C strings are moved into the closure so the pointers stay valid during the call.
    runtime.spawn_blocking(move || {
        validation_fn(net_id, success, ip_addr.as_ptr(), domain.as_ptr())
    });
}
fn handle_probe_result(
result: (ServerInfo, Result<DohConnection, (anyhow::Error, DohConnection)>),
doh_conn_map: &mut HashMap<u32, (ServerInfo, Option<DohConnection>)>,
runtime: Arc<Runtime>,
validation_fn: ValidationCallback,
) {
let (info, doh_conn) = match result {
(info, Ok(doh_conn)) => {
info!("probing_task success on net_id: {}", info.net_id);
(info, doh_conn)
}
(info, Err((e, doh_conn))) => {
error!("probe failed on network {}, {:?}", e, info.net_id);
(info, doh_conn)
// TODO: Retry probe?
}
};
// If the network is removed or the server is replaced before probing,
// ignore the probe result.
match doh_conn_map.get(&info.net_id) {
Some((server_info, _)) => {
if *server_info != info {
warn!(
"The previous configuration for network {} was replaced before probe finished",
info.net_id
);
return;
}
}
_ => {
warn!("network {} was removed before probe finished", info.net_id);
return;
}
}
report_private_dns_validation(&info, &doh_conn.state, runtime, validation_fn);
doh_conn_map.insert(info.net_id, (info, Some(doh_conn)));
}
/// Runs the probe on `doh` and hands the connection back together with `info`,
/// whether the probe succeeded or not.
async fn probe_task(
    info: ServerInfo,
    mut doh: DohConnection,
    t: Duration,
) -> (ServerInfo, Result<DohConnection, (anyhow::Error, DohConnection)>) {
    let outcome = doh.probe(t).await;
    let result = match outcome {
        Ok(_) => Ok(doh),
        Err(e) => Err((anyhow!(e), doh)),
    };
    (info, result)
}
/// Decides whether `info` needs a (new) probe and returns the connection to probe.
///
/// * Same server, connection in the Error state: the stored connection is taken out of
///   the map and returned for re-probing.
/// * Same server otherwise (usable or still probing): returns `None`.
/// * Unknown network or different server: the old entry is removed and a new
///   connection is returned.
///
/// In every case where a connection is returned, the map entry is left as
/// `(info, None)` until the probe result is stored.
///
/// # Errors
/// Returns an error if a new connection can't be created.
fn make_connection_if_needed(
    info: &ServerInfo,
    doh_conn_map: &mut HashMap<u32, (ServerInfo, Option<DohConnection>)>,
    shared_config: Arc<Mutex<QuicheConfigCache>>,
) -> Result<Option<DohConnection>> {
    // Check if connection exists.
    let (same_server, needs_reprobe) = match doh_conn_map.get(&info.net_id) {
        Some((server_info, conn)) if *server_info == *info => {
            (true, conn.as_ref().map_or(false, |c| c.is_reprobe_required()))
        }
        _ => (false, false),
    };
    if same_server {
        if !needs_reprobe {
            // The connection exists or the connection is under probing, ignore.
            return Ok(None);
        }
        // The connection exists but has failed. Re-probe.
        let (_, doh) = doh_conn_map
            .insert(info.net_id, (info.clone(), None))
            .ok_or_else(|| anyhow!("unexpected error, missing connection"))?;
        return Ok(doh);
    }
    // TODO: change the inner connection instead of removing?
    doh_conn_map.remove(&info.net_id);
    let doh = DohConnection::new(info, shared_config)?;
    doh_conn_map.insert(info.net_id, (info.clone(), None));
    Ok(Some(doh))
}
/// Caches the most recently created quiche config together with the certificate path
/// it was created for.
struct QuicheConfigCache {
    cert_path: Option<String>,
    config: Option<quiche::Config>,
}
impl QuicheConfigCache {
    /// Returns the cached config, creating it first when nothing is cached yet or when
    /// the cached one was built for a different `cert_path`.
    ///
    /// # Errors
    /// Returns an error if the config can't be created.
    fn get(&mut self, cert_path: &Option<String>) -> Result<Option<&mut quiche::Config>> {
        let is_stale = self.config.is_none() || self.cert_path != *cert_path;
        if is_stale {
            let fresh = create_quiche_config(cert_path.as_deref())?;
            self.config = Some(fresh);
            self.cert_path = cert_path.clone();
        }
        Ok(self.config.as_mut())
    }
}
/// Handles one Query command: builds the DoH request and passes it to the connection
/// of `net_id`.
///
/// `resp` receives `ServerNotReady` when the network is unknown or its connection is
/// not available yet, and `Unexpected` when the request can't be built.
async fn handle_query_cmd(
    net_id: u32,
    base64_query: Base64Query,
    expired_time: Instant,
    resp: QueryResponder,
    doh_conn_map: &mut HashMap<u32, (ServerInfo, Option<DohConnection>)>,
) {
    let (info, doh_conn) = match doh_conn_map.get_mut(&net_id) {
        // Connection is ready
        Some((info, Some(doh_conn))) => (info, doh_conn),
        // Connection is not ready, in both strict and opportunistic mode.
        Some((_, None)) => {
            let _ = resp.send(Response::Error { error: QueryError::ServerNotReady });
            return;
        }
        None => {
            error!("No connection is associated with the given net id {}", net_id);
            let _ = resp.send(Response::Error { error: QueryError::ServerNotReady });
            return;
        }
    };
    match make_dns_request(&base64_query, &info.url) {
        Ok(req) => {
            let _ = doh_conn.try_send_doh_query(req, resp, expired_time).await;
        }
        Err(_) => {
            let _ = resp.send(Response::Error { error: QueryError::Unexpected });
        }
    }
}
/// Returns true when at least one stored connection still has queries to send or
/// answers to wait for.
fn need_process_queries(doh_conn_map: &HashMap<u32, (ServerInfo, Option<DohConnection>)>) -> bool {
    doh_conn_map
        .values()
        .any(|(_, doh_conn)| doh_conn.as_ref().map_or(false, |c| c.has_not_handled_queries()))
}
/// Main loop of the DoH engine.
///
/// Each iteration waits for whichever happens first: the query processing of the
/// stored connections returning, a probe finishing, or a command arriving on `cmd_rx`.
/// Returns `Ok(())` when the Exit command is received.
async fn doh_handler(
mut cmd_rx: CmdReceiver,
runtime: Arc<Runtime>,
validation_fn: ValidationCallback,
) -> Result<()> {
info!("doh_dispatcher entry");
let config_cache = Arc::new(Mutex::new(QuicheConfigCache { cert_path: None, config: None }));
// Currently, only support 1 server per network.
// A `None` connection means the server is known but its probe has not finished yet.
let mut doh_conn_map: HashMap<u32, (ServerInfo, Option<DohConnection>)> = HashMap::new();
let mut probe_futures = FuturesUnordered::new();
loop {
tokio::select! {
// Drive every stored connection; only enabled while some connection has work to do.
_ = async {
let mut futures = vec![];
for (_, doh_conn) in doh_conn_map.values_mut() {
if let Some(doh_conn) = doh_conn {
futures.push(doh_conn.process_queries());
}
}
join_all(futures).await
}, if need_process_queries(&doh_conn_map) => {},
// A probe finished: store the connection and report the validation result.
Some(result) = probe_futures.next() => {
let runtime_clone = runtime.clone();
handle_probe_result(result, &mut doh_conn_map, runtime_clone, validation_fn);
info!("probe_futures remaining size: {}", probe_futures.len());
},
// A command arrived from one of the FFI entry points.
Some(cmd) = cmd_rx.recv() => {
trace!("recv {:?}", cmd);
match cmd {
DohCommand::Probe { info, timeout: t } => {
match make_connection_if_needed(&info, &mut doh_conn_map, config_cache.clone()) {
Ok(Some(doh)) => {
// Create a new async task associated to the DoH connection.
probe_futures.push(probe_task(info, doh, t));
debug!("probe_futures size: {}", probe_futures.len());
}
Ok(None) => {
// No further probe is needed.
warn!("connection for network {} already exists", info.net_id);
// TODO: Report the status again?
}
Err(e) => {
error!("create connection for network {} error {:?}", info.net_id, e);
// Passing the Error state makes the report a validation failure.
report_private_dns_validation(&info, &ConnectionState::Error, runtime.clone(), validation_fn);
}
}
},
DohCommand::Query { net_id, base64_query, expired_time, resp } => {
handle_query_cmd(net_id, base64_query, expired_time, resp, &mut doh_conn_map).await;
},
DohCommand::Clear { net_id } => {
doh_conn_map.remove(&net_id);
info!("Doh Clear server for netid: {}", net_id);
},
DohCommand::Exit => return Ok(()),
}
}
}
}
}
/// Builds the HTTP/3 header list of a DoH GET request: the path is the URL path
/// followed by `?dns=<base64_query>`.
///
/// # Errors
/// Returns an error if `url` has no host.
fn make_dns_request(base64_query: &str, url: &url::Url) -> Result<DnsRequest> {
    let authority = url.host_str().ok_or_else(|| anyhow!("failed to get host"))?;
    let path = format!("{}?dns={}", url.path(), base64_query);
    // TODO: is content-length required?
    Ok(vec![
        quiche::h3::Header::new(b":method", b"GET"),
        quiche::h3::Header::new(b":scheme", b"https"),
        quiche::h3::Header::new(b":authority", authority.as_bytes()),
        quiche::h3::Header::new(b":path", path.as_bytes()),
        quiche::h3::Header::new(b"user-agent", b"quiche"),
        quiche::h3::Header::new(b"accept", b"application/dns-message"),
    ])
}
/// Creates a non-blocking UDP socket bound to the wildcard address of the peer's
/// address family, applies `mark` to it and connects it to `peer_addr`.
/// A failure to mark the socket is only logged.
///
/// # Errors
/// Returns an error if binding, configuring or connecting the socket fails.
fn make_doh_udp_socket(peer_addr: SocketAddr, mark: u32) -> Result<std::net::UdpSocket> {
    let bind_addr = if peer_addr.is_ipv4() { "0.0.0.0:0" } else { "[::]:0" };
    let socket = std::net::UdpSocket::bind(bind_addr)?;
    socket.set_nonblocking(true)?;
    let marked = mark_socket(socket.as_raw_fd(), mark);
    if marked.is_err() {
        warn!("Mark socket failed, is it a test?");
    }
    socket.connect(peer_addr)?;
    trace!("connecting to {:} from {:}", peer_addr, socket.local_addr()?);
    Ok(socket)
}
/// Creates the quiche configuration used for DoH connections.
///
/// Peer verification is enabled, with certificates loaded from the `cert_path`
/// directory, only when `cert_path` is given; otherwise it is disabled.
///
/// # Errors
/// Returns an error if the config can't be created, the application protocols can't
/// be set, or the certificates can't be loaded.
fn create_quiche_config(cert_path: Option<&str>) -> Result<quiche::Config> {
    let mut config = quiche::Config::new(quiche::PROTOCOL_VERSION)?;
    config.set_application_protos(h3::APPLICATION_PROTOCOL)?;
    if let Some(path) = cert_path {
        config.verify_peer(true);
        config.load_verify_locations_from_directory(path)?;
    } else {
        config.verify_peer(false);
    }
    // Some of these configs are necessary, or the server can't respond the HTTP/3 request.
    config.set_max_idle_timeout(QUICHE_IDLE_TIMEOUT_MS);
    config.set_max_recv_udp_payload_size(MAX_DATAGRAM_SIZE);
    // Flow control limits.
    config.set_initial_max_data(MAX_INCOMING_BUFFER_SIZE_WHOLE);
    config.set_initial_max_stream_data_bidi_local(MAX_INCOMING_BUFFER_SIZE_EACH);
    config.set_initial_max_stream_data_bidi_remote(MAX_INCOMING_BUFFER_SIZE_EACH);
    config.set_initial_max_stream_data_uni(MAX_INCOMING_BUFFER_SIZE_EACH);
    // Stream count limits.
    config.set_initial_max_streams_bidi(MAX_CONCURRENT_STREAM_SIZE);
    config.set_initial_max_streams_uni(MAX_CONCURRENT_STREAM_SIZE);
    config.set_disable_active_migration(true);
    Ok(config)
}
fn mark_socket(fd: RawFd, mark: u32) -> Result<()> {
// libc::setsockopt is a wrapper function calling into bionic setsockopt.
// Both fd and mark are valid, which makes the function call mostly safe.
if unsafe {
libc::setsockopt(
fd,
libc::SOL_SOCKET,
libc::SO_MARK,
&mark as *const _ as *const libc::c_void,
std::mem::size_of::<u32>() as libc::socklen_t,
)
} == 0
{
Ok(())
} else {
Err(anyhow::Error::new(std::io::Error::last_os_error()))
}
}
/// Builds the probe query: an AAAA/IN question for
/// `<6 random chars>-dnsohttps-ds.metric.gstatic.com` with a random query ID,
/// returned as URL-safe base64 without padding.
///
/// # Errors
/// Returns an error if the random bytes can't be generated.
#[rustfmt::skip]
fn make_probe_query() -> Result<String> {
// rnd[0..6] become the random hostname prefix, rnd[6..8] the query ID.
let mut rnd = [0; 8];
ring::rand::SystemRandom::new().fill(&mut rnd).context("failed to generate probe rnd")?;
// Maps a random byte onto a character of CHARSET.
let c = |byte| CHARSET[(byte as usize) % CHARSET.len()];
let query = vec![
rnd[6], rnd[7], // [0-1] query ID
1, 0, // [2-3] flags; query[2] = 1 for recursion desired (RD).
0, 1, // [4-5] QDCOUNT (number of queries)
0, 0, // [6-7] ANCOUNT (number of answers)
0, 0, // [8-9] NSCOUNT (number of name server records)
0, 0, // [10-11] ARCOUNT (number of additional records)
// First label, 19 bytes long: 6 random chars followed by "-dnsohttps-ds".
19, c(rnd[0]), c(rnd[1]), c(rnd[2]), c(rnd[3]), c(rnd[4]), c(rnd[5]), b'-', b'd', b'n',
b's', b'o', b'h', b't', b't', b'p', b's', b'-', b'd', b's',
// Remaining labels, each prefixed by its length: "metric", "gstatic", "com".
6, b'm', b'e', b't', b'r', b'i', b'c', 7, b'g', b's',
b't', b'a', b't', b'i', b'c', 3, b'c', b'o', b'm',
0, // null terminator of FQDN (root TLD)
0, NS_T_AAAA, // QTYPE
0, NS_C_IN // QCLASS
];
Ok(base64::encode_config(query, base64::URL_SAFE_NO_PAD))
}
/// Performs static initialization for android logger.
///
/// `level` is one of the `LOG_LEVEL_*` constants; every one of them is honoured, in
/// line with `doh_set_log_level`. Any other value selects the error level.
/// Only the first call has an effect.
#[no_mangle]
pub extern "C" fn doh_init_logger(level: u32) {
    INIT.call_once(|| {
        let level = match level {
            LOG_LEVEL_WARN => log::Level::Warn,
            LOG_LEVEL_INFO => log::Level::Info,
            LOG_LEVEL_DEBUG => log::Level::Debug,
            LOG_LEVEL_TRACE => log::Level::Trace,
            // LOG_LEVEL_ERROR and unknown values.
            _ => log::Level::Error,
        };
        android_logger::init_once(android_logger::Config::default().with_min_level(level));
    });
}
/// Sets the maximum log level.
///
/// `level` is one of the `LOG_LEVEL_*` constants; any other value turns logging off.
#[no_mangle]
pub extern "C" fn doh_set_log_level(level: u32) {
    log::set_max_level(match level {
        LOG_LEVEL_ERROR => log::LevelFilter::Error,
        LOG_LEVEL_WARN => log::LevelFilter::Warn,
        LOG_LEVEL_INFO => log::LevelFilter::Info,
        LOG_LEVEL_DEBUG => log::LevelFilter::Debug,
        LOG_LEVEL_TRACE => log::LevelFilter::Trace,
        _ => log::LevelFilter::Off,
    });
}
/// Performs the initialization for the DoH engine.
/// Creates and returns a DoH engine instance, or a null pointer if creation fails.
#[no_mangle]
pub extern "C" fn doh_dispatcher_new(ptr: ValidationCallback) -> *mut DohDispatcher {
    DohDispatcher::new(ptr).map(Box::into_raw).unwrap_or_else(|e| {
        error!("doh_dispatcher_new: failed: {:?}", e);
        ptr::null_mut()
    })
}
/// Deletes a DoH engine created by doh_dispatcher_new().
/// Asks the handler task to exit, waits for it, then frees the dispatcher.
/// A null `doh` is ignored.
/// # Safety
/// `doh` must be null or a pointer previously created by `doh_dispatcher_new()`
/// and not yet deleted by `doh_dispatcher_delete()`.
#[no_mangle]
pub unsafe extern "C" fn doh_dispatcher_delete(doh: *mut DohDispatcher) {
    // doh_dispatcher_new() returns null on failure; don't turn that into a Box.
    if doh.is_null() {
        return;
    }
    let mut dispatcher = Box::from_raw(doh);
    dispatcher.exit_handler();
}
/// Probes and stores the DoH server with the given configurations.
/// Use the negative errno-style codes as the return value to represent the result:
/// 0 when the probe command was submitted, `-EINVAL` for invalid input and `-EPIPE`
/// when the command can't be delivered.
///
/// An empty `domain` disables certificate verification (no domain and no cert path are
/// stored). With a domain, an empty `cert_path` selects the system certificate directory.
/// # Safety
/// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
/// and not yet deleted by `doh_dispatcher_delete()`.
/// `url`, `domain`, `ip_addr`, `cert_path` are null terminated strings.
#[no_mangle]
pub unsafe extern "C" fn doh_net_new(
    doh: &mut DohDispatcher,
    net_id: uint32_t,
    url: *const c_char,
    domain: *const c_char,
    ip_addr: *const c_char,
    sk_mark: libc::uint32_t,
    cert_path: *const c_char,
    timeout_ms: libc::uint64_t,
) -> int32_t {
    let url = std::ffi::CStr::from_ptr(url).to_str();
    let domain = std::ffi::CStr::from_ptr(domain).to_str();
    let ip_addr = std::ffi::CStr::from_ptr(ip_addr).to_str();
    let cert_path = std::ffi::CStr::from_ptr(cert_path).to_str();
    let (url, domain, ip_addr, cert_path) = match (url, domain, ip_addr, cert_path) {
        (Ok(url), Ok(domain), Ok(ip_addr), Ok(cert_path)) => (url, domain, ip_addr, cert_path),
        _ => {
            error!("bad input"); // Should not happen
            return -libc::EINVAL;
        }
    };
    let (domain, cert_path) = if domain.is_empty() {
        (None, None)
    } else {
        let cert_dir = if cert_path.is_empty() { SYSTEM_CERT_PATH } else { cert_path };
        (Some(domain.to_string()), Some(cert_dir.to_string()))
    };
    let (url, ip_addr) = match (url::Url::parse(url), IpAddr::from_str(ip_addr)) {
        (Ok(url), Ok(ip_addr)) => (url, ip_addr),
        _ => {
            error!("bad ip or url"); // Should not happen
            return -libc::EINVAL;
        }
    };
    let info = ServerInfo {
        net_id,
        url,
        peer_addr: SocketAddr::new(ip_addr, DOH_PORT),
        domain,
        sk_mark,
        cert_path,
    };
    let cmd = DohCommand::Probe { info, timeout: Duration::from_millis(timeout_ms) };
    match doh.send_cmd(cmd) {
        Ok(()) => 0,
        Err(e) => {
            error!("Failed to send the probe: {:?}", e);
            -libc::EPIPE
        }
    }
}
/// Sends a DNS query via the network associated to the given |net_id| and waits for the response.
/// The return code should be either one of the public constant RESULT_* to indicate the error or
/// the size of the answer.
/// # Safety
/// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
/// and not yet deleted by `doh_dispatcher_delete()`.
/// `dns_query` must point to a buffer at least `dns_query_len` in size.
/// `response` must point to a buffer at least `response_len` in size.
#[no_mangle]
pub unsafe extern "C" fn doh_query(
doh: &mut DohDispatcher,
net_id: uint32_t,
dns_query: *mut u8,
dns_query_len: size_t,
response: *mut u8,
response_len: size_t,
timeout_ms: uint64_t,
) -> ssize_t {
let q = slice::from_raw_parts_mut(dns_query, dns_query_len);
let (resp_tx, resp_rx) = oneshot::channel();
let t = Duration::from_millis(timeout_ms);
if let Some(expired_time) = Instant::now().checked_add(t) {
let cmd = DohCommand::Query {
net_id,
base64_query: base64::encode_config(q, base64::URL_SAFE_NO_PAD),
expired_time,
resp: resp_tx,
};
if let Err(e) = doh.send_cmd(cmd) {
error!("Failed to send the query: {:?}", e);
return RESULT_CAN_NOT_SEND;
}
} else {
error!("Bad timeout parameter: {}", timeout_ms);
return RESULT_CAN_NOT_SEND;
}
if let Ok(rt) = Runtime::new() {
let local = task::LocalSet::new();
match local.block_on(&rt, async { timeout(t, resp_rx).await }) {
Ok(v) => match v {
Ok(v) => match v {
Response::Success { answer } => {
if answer.len() > response_len || answer.len() > isize::MAX as usize {
return RESULT_INTERNAL_ERROR;
}
let response = slice::from_raw_parts_mut(response, answer.len());
response.copy_from_slice(&answer);
answer.len() as ssize_t
}
_ => RESULT_CAN_NOT_SEND,
},
Err(e) => {
error!("no result {}", e);
RESULT_CAN_NOT_SEND
}
},
Err(e) => {
error!("timeout: {}", e);
RESULT_TIMEOUT
}
}
} else {
RESULT_CAN_NOT_SEND
}
}
/// Clears the DoH servers associated with the given |netid|.
/// # Safety
/// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
/// and not yet deleted by `doh_dispatcher_delete()`.
#[no_mangle]
pub extern "C" fn doh_net_delete(doh: &mut DohDispatcher, net_id: uint32_t) {
if let Err(e) = doh.send_cmd(DohCommand::Clear { net_id }) {
error!("Failed to send the query: {:?}", e);
}
}
#[cfg(test)]
mod tests {
use super::*;
use quiche::h3::NameValue;
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
// Network id placed in the ServerInfo fixtures and expected by the validation callbacks.
const TEST_NET_ID: u32 = 50;
// Expected length, in bytes, of the probe query after base64 decoding.
const PROBE_QUERY_SIZE: usize = 56;
// Expected number of header entries in the H3 DNS request.
const H3_DNS_REQUEST_HEADER_SIZE: usize = 6;
// Socket mark passed to make_doh_udp_socket() and read back through SO_MARK.
const TEST_MARK: u32 = 0xD0033;
// Peer address (ip:port) used by the fixtures.
const LOOPBACK_ADDR: &str = "127.0.0.1:443";
// DoH server URL used by the fixtures.
const LOCALHOST_URL: &str = "https://mylocal.com/dns-query";
// TODO: Make some tests for DohConnection and QuicheConfigCache.
// Builds the fixtures shared by the tests: a ServerInfo for TEST_NET_ID without
// domain or certificate path, an empty connection map, an empty quiche config
// cache and a current-thread tokio runtime with all drivers enabled.
fn make_testing_variables() -> (
    ServerInfo,
    HashMap<u32, (ServerInfo, Option<DohConnection>)>,
    Arc<Mutex<QuicheConfigCache>>,
    Arc<Runtime>,
) {
    let server_info = ServerInfo {
        net_id: TEST_NET_ID,
        url: Url::parse(LOCALHOST_URL).unwrap(),
        peer_addr: LOOPBACK_ADDR.parse().unwrap(),
        domain: None,
        sk_mark: 0,
        cert_path: None,
    };
    let cache = QuicheConfigCache { cert_path: None, config: None };
    let runtime = Builder::new_current_thread()
        .thread_name("test-runtime")
        .enable_all()
        .build()
        .expect("Failed to create testing tokio runtime");
    (server_info, HashMap::new(), Arc::new(Mutex::new(cache)), Arc::new(runtime))
}
// Walks make_connection_if_needed() through three map states: no entry, an
// entry in the Error state, and an entry in the Connected state.
#[test]
fn make_connection_if_needed() {
    let (info, mut test_map, config, rt) = make_testing_variables();
    rt.block_on(async {
        // Empty map: a new connection is expected, and it starts out idle.
        let fresh = super::make_connection_if_needed(&info, &mut test_map, Arc::clone(&config));
        let mut conn = fresh.unwrap().unwrap();
        assert_eq!(conn.info.net_id, info.net_id);
        assert!(matches!(conn.state, ConnectionState::Idle));

        // Put the connection back in the failed state; the very same failed
        // connection is expected to be handed out again.
        conn.state = ConnectionState::Error;
        test_map.insert(info.net_id, (info.clone(), Some(conn)));
        let failed = super::make_connection_if_needed(&info, &mut test_map, Arc::clone(&config));
        let mut conn = failed.unwrap().unwrap();
        assert_eq!(conn.info.net_id, info.net_id);
        assert!(matches!(conn.state, ConnectionState::Error));

        // A connection that is already in the ready state must not be replaced,
        // so None is expected.
        conn.state = make_dummy_connected_state();
        test_map.insert(info.net_id, (info.clone(), Some(conn)));
        let ready = super::make_connection_if_needed(&info, &mut test_map, Arc::clone(&config));
        assert!(ready.unwrap().is_none());
    });
}
// Checks the error responses of handle_query_cmd(): ServerNotReady when the
// network is unknown or has no connection, BrokenServer when the connection is
// in the Error state.
#[test]
fn handle_query_cmd() {
    let (info, mut test_map, config, rt) = make_testing_variables();
    let wait = Duration::from_millis(100);
    rt.block_on(async {
        let probe = super::make_probe_query().unwrap();

        // Case 1: the network id is not in the map at all.
        let (tx, rx) = oneshot::channel();
        let deadline = Instant::now().checked_add(wait).unwrap();
        super::handle_query_cmd(info.net_id, probe.clone(), deadline, tx, &mut test_map).await;
        let got = timeout(wait, rx).await.unwrap().unwrap();
        assert_eq!(got, Response::Error { error: QueryError::ServerNotReady });

        // Case 2: the network is known but holds no connection.
        let (tx, rx) = oneshot::channel();
        test_map.insert(info.net_id, (info.clone(), None));
        let deadline = Instant::now().checked_add(wait).unwrap();
        super::handle_query_cmd(info.net_id, probe.clone(), deadline, tx, &mut test_map).await;
        let got = timeout(wait, rx).await.unwrap().unwrap();
        assert_eq!(got, Response::Error { error: QueryError::ServerNotReady });

        // Case 3: the connection exists but is broken.
        test_map.clear();
        let (tx, rx) = oneshot::channel();
        let mut conn = super::make_connection_if_needed(&info, &mut test_map, config.clone())
            .unwrap()
            .unwrap();
        conn.state = ConnectionState::Error;
        test_map.insert(info.net_id, (info.clone(), Some(conn)));
        let deadline = Instant::now().checked_add(wait).unwrap();
        super::handle_query_cmd(info.net_id, probe.clone(), deadline, tx, &mut test_map).await;
        let got = timeout(wait, rx).await.unwrap().unwrap();
        assert_eq!(got, Response::Error { error: QueryError::BrokenServer });
    });
}
// Validation callback for the cases where a successful validation is expected.
extern "C" fn success_cb(
    net_id: uint32_t,
    success: bool,
    ip_addr: *const c_char,
    host: *const c_char,
) {
    assert!(success);
    // Safety: assumes the caller passes null terminated strings, as
    // assert_validation_info() requires.
    unsafe { assert_validation_info(net_id, ip_addr, host) };
}
// Validation callback for the cases where a failed validation is expected.
extern "C" fn fail_cb(
    net_id: uint32_t,
    success: bool,
    ip_addr: *const c_char,
    host: *const c_char,
) {
    assert!(!success);
    // Safety: assumes the caller passes null terminated strings, as
    // assert_validation_info() requires.
    unsafe { assert_validation_info(net_id, ip_addr, host) };
}
// Asserts that a validation report carries TEST_NET_ID, the IP part of
// LOOPBACK_ADDR and an empty host name.
// # Safety
// `ip_addr`, `host` are null terminated strings
unsafe fn assert_validation_info(
    net_id: uint32_t,
    ip_addr: *const c_char,
    host: *const c_char,
) {
    assert_eq!(net_id, TEST_NET_ID);

    // Only the IP part of the loopback address is reported, without the port.
    let loopback: SocketAddr = LOOPBACK_ADDR.parse().unwrap();
    let reported_ip = std::ffi::CStr::from_ptr(ip_addr).to_str().unwrap();
    assert_eq!(reported_ip, loopback.ip().to_string());

    let reported_host = std::ffi::CStr::from_ptr(host).to_str().unwrap();
    assert_eq!(reported_host, "");
}
// Creates a tokio UDP socket made by make_doh_udp_socket() for the loopback
// address, plus a quiche client connection to the same address with a random
// source connection id and a certificate-less config.
fn make_testing_connection_variables() -> (Pin<Box<quiche::Connection>>, UdpSocket) {
    let std_socket =
        super::make_doh_udp_socket(LOOPBACK_ADDR.parse().unwrap(), TEST_MARK).unwrap();
    let socket = UdpSocket::from_std(std_socket).unwrap();

    let mut scid_bytes = [0; quiche::MAX_CONN_ID_LEN];
    ring::rand::SystemRandom::new()
        .fill(&mut scid_bytes)
        .context("failed to generate scid")
        .unwrap();
    let scid = quiche::ConnectionId::from_ref(&scid_bytes);

    let mut quiche_config = super::create_quiche_config(None).unwrap();
    let connection =
        quiche::connect(None, &scid, LOOPBACK_ADDR.parse().unwrap(), &mut quiche_config).unwrap();
    (connection, socket)
}
// Returns a `Connected` state built around a test connection and socket, with
// no H3 connection, no pending queries and no expiry time.
fn make_dummy_connected_state() -> super::ConnectionState {
    let (connection, socket) = make_testing_connection_variables();
    ConnectionState::Connected {
        quic_conn: connection,
        udp_sk: socket,
        h3_conn: None,
        query_map: HashMap::new(),
        expired_time: None,
    }
}
// Returns a `Connecting` state built around a test connection and socket, with
// no expiry time.
fn make_dummy_connecting_state() -> super::ConnectionState {
    let (connection, socket) = make_testing_connection_variables();
    ConnectionState::Connecting {
        quic_conn: Some(connection),
        udp_sk: Some(socket),
        expired_time: None,
    }
}
// Reports a validation result for each connection state and lets the callbacks
// assert the outcome: only the Connected state is expected to report success.
#[test]
fn report_private_dns_validation() {
    let server = ServerInfo {
        net_id: TEST_NET_ID,
        url: Url::parse(LOCALHOST_URL).unwrap(),
        peer_addr: LOOPBACK_ADDR.parse().unwrap(),
        domain: None,
        sk_mark: 0,
        cert_path: None,
    };
    let runtime = Builder::new_current_thread()
        .thread_name("test-runtime")
        .enable_io()
        .build()
        .expect("Failed to create testing tokio runtime");
    let rt = Arc::new(runtime);

    // Exit the test if the worker inside tokio runtime panicked.
    // The replacement hook chains to the previous one and is not restored
    // when this test finishes.
    let previous_hook = std::panic::take_hook();
    std::panic::set_hook(Box::new(move |panic_info| {
        previous_hook(panic_info);
        std::process::exit(1);
    }));

    rt.block_on(async {
        // Connected: the success callback must fire.
        super::report_private_dns_validation(
            &server,
            &make_dummy_connected_state(),
            rt.clone(),
            success_cb,
        );
        // Error, Connecting and Idle: the failure callback must fire.
        super::report_private_dns_validation(
            &server,
            &ConnectionState::Error,
            rt.clone(),
            fail_cb,
        );
        super::report_private_dns_validation(
            &server,
            &make_dummy_connecting_state(),
            rt.clone(),
            fail_cb,
        );
        super::report_private_dns_validation(
            &server,
            &ConnectionState::Idle,
            rt.clone(),
            fail_cb,
        );
    });
}
// Builds a probe query, turns it into an H3 DNS request and checks both the
// request headers and the size of the decoded probe packet.
#[test]
fn make_probe_query_and_request() {
    let probe_query = super::make_probe_query().unwrap();
    let url = Url::parse(LOCALHOST_URL).unwrap();
    let request = make_dns_request(&probe_query, &url).unwrap();

    // Verify H3 DNS request.
    assert_eq!(request.len(), H3_DNS_REQUEST_HEADER_SIZE);
    let authority = url.host_str().unwrap();
    let expected_path = format!("{}?dns={}", url.path(), probe_query);
    // (index, header name, header value); index 4 is not checked here.
    let expected_headers: [(usize, &[u8], &[u8]); 5] = [
        (0, b":method", b"GET"),
        (1, b":scheme", b"https"),
        (2, b":authority", authority.as_bytes()),
        (3, b":path", expected_path.as_bytes()),
        (5, b"accept", b"application/dns-message"),
    ];
    for &(index, name, value) in expected_headers.iter() {
        assert_eq!(request[index].name(), name);
        assert_eq!(request[index].value(), value);
    }

    // Verify DNS probe packet.
    let packet = base64::decode_config(probe_query, base64::URL_SAFE_NO_PAD).unwrap();
    assert_eq!(packet.len(), PROBE_QUERY_SIZE);
    // TODO: Parse the result to ensure it's a valid DNS packet.
}
// create_quiche_config() must succeed both without and with a certificate path.
#[test]
fn create_quiche_config() {
    let without_cert = super::create_quiche_config(None);
    assert!(without_cert.is_ok(), "quiche config without cert creating failed");

    let with_cert = super::create_quiche_config(Some("data/local/tmp/"));
    assert!(with_cert.is_ok(), "quiche config with cert creating failed");
}
// Checks that make_doh_udp_socket() returns a socket that is connected to the
// requested peer, carries the requested mark and is non-blocking.
#[test]
fn make_doh_udp_socket() {
    // Make a socket connecting to loopback with a test mark.
    let socket = super::make_doh_udp_socket(LOOPBACK_ADDR.parse().unwrap(), TEST_MARK).unwrap();

    // Check if the socket is connected to loopback.
    let expected_peer = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), DOH_PORT));
    assert_eq!(socket.peer_addr().unwrap(), expected_peer);

    // Check if the socket mark is correct.
    let fd: RawFd = socket.as_raw_fd();
    let mut mark: u32 = 50;
    let mut optlen = std::mem::size_of::<u32>() as libc::socklen_t;
    // Safety: It's fine since the fd belongs to this test.
    let rc = unsafe {
        libc::getsockopt(
            fd,
            libc::SOL_SOCKET,
            libc::SO_MARK,
            &mut mark as *mut _ as *mut libc::c_void,
            &mut optlen as *mut libc::socklen_t,
        )
    };
    assert_eq!(rc, 0);
    assert_eq!(mark, TEST_MARK);

    // Check if the socket is non-blocking.
    // Safety: It's fine since the fd belongs to this test.
    let flags = unsafe { libc::fcntl(fd, libc::F_GETFL, 0) };
    assert_eq!(flags & libc::O_NONBLOCK, libc::O_NONBLOCK);
}
}