dns_proxy: UDP DNS forwarding
This CL implements the DNS query sending and reply back to client, with
some details left to be filled in later follow-up CLs.
To accomodate both test and production, android-specific FFIs are moved
behind "android-ffi" feature flag.
Flag: EXEMPT MAINLINE
Bug: 379992903
Test: m; cargo test --all-features
Change-Id: If6c06a56d397298a93c5c3d2ac3e72a7625d4ce5
diff --git a/Android.bp b/Android.bp
index 9a57c85..ee65b72 100644
--- a/Android.bp
+++ b/Android.bp
@@ -526,6 +526,7 @@
"liblog_rust",
"libnix",
"libnum_enum",
+ "librand",
"libsocket2",
"libthiserror",
"libtokio",
diff --git a/dns_proxy/Cargo.toml b/dns_proxy/Cargo.toml
index 5e243dc..574469b 100644
--- a/dns_proxy/Cargo.toml
+++ b/dns_proxy/Cargo.toml
@@ -30,6 +30,7 @@
mockall = "^0.13.1"
num_enum = "^0.7.3"
nix = { version = "^0.29.0", features = ["fs", "net", "user"] }
+rand = { version = "^0.8.5", features = ["std", "std_rng"] }
socket2 = "0.5.8"
thiserror = "^2.0.11"
tokio = { version = "^1.42.0", features = ["macros", "net", "rt", "rt-multi-thread", "sync", "time"] }
diff --git a/dns_proxy/ffi.rs b/dns_proxy/ffi.rs
index 6764fbb..f33782f 100644
--- a/dns_proxy/ffi.rs
+++ b/dns_proxy/ffi.rs
@@ -20,13 +20,34 @@
use log::error;
+use crate::server::NetContextClient;
use crate::server::Server;
+use crate::server::UpstreamParam;
+
+#[derive(Debug)]
+struct AndroidNetContextClient;
+
+impl AndroidNetContextClient {
+ fn new() -> Self {
+ Self {}
+ }
+}
+
+impl NetContextClient for AndroidNetContextClient {
+ fn get_dns_mark(&self, _upstream_param: &UpstreamParam) -> Option<u32> {
+ todo!();
+ }
+
+ fn get_name_servers(&self, _upstream_param: &UpstreamParam) -> Vec<std::net::IpAddr> {
+ todo!();
+ }
+}
/// Constructs the DNS proxy server.
/// Returns a pointer to the DNS proxy instance.
#[no_mangle]
pub extern "C" fn proxy_server_new() -> *mut Server {
- match Server::new() {
+ match Server::new(AndroidNetContextClient::new()) {
Ok(server) => Box::into_raw(Box::new(server)),
Err(e) => {
error!("proxy_server_new failed: {:?}", e);
@@ -62,8 +83,8 @@
if let Err(e) =
server.configure_dns_proxy(upstream_net_id, uid, downstream_if_index, downstream_port)
{
- error!("Error configure DNS proxy: {}", e);
- panic!();
+ error!("Error configure DNS proxy: {}", &e);
+ panic!("Error configure DNS proxy: {}", e);
}
}
@@ -78,8 +99,8 @@
downstream_port: u16,
) {
if let Err(e) = server.stop_dns_proxy(downstream_if_index, downstream_port) {
- error!("Error stop DNS proxy: {}", e);
- panic!();
+ error!("Error stop DNS proxy: {}", &e);
+ panic!("Error stop DNS proxy: {}", e);
}
}
diff --git a/dns_proxy/lib.rs b/dns_proxy/lib.rs
index b6cea9a..8035713 100644
--- a/dns_proxy/lib.rs
+++ b/dns_proxy/lib.rs
@@ -18,6 +18,7 @@
// TODO (b/379992903): Remove this after library is completed.
#![allow(dead_code)]
+#[cfg(feature = "android-ffi")]
mod ffi;
pub(crate) mod packet;
mod server;
diff --git a/dns_proxy/server.rs b/dns_proxy/server.rs
index 4f70615..946416f 100644
--- a/dns_proxy/server.rs
+++ b/dns_proxy/server.rs
@@ -17,8 +17,11 @@
//! DNS proxy server implementation.
use std::io::Error as IoError;
+use std::net::IpAddr;
use std::thread;
+#[cfg(test)]
+use mockall::automock;
use thiserror::Error;
use tokio::runtime::Builder as RuntimeBuilder;
use tokio::sync::mpsc;
@@ -26,6 +29,8 @@
use tokio::sync::oneshot;
use tokio::sync::oneshot::error::RecvError;
+use crate::packet::PacketError;
+
mod driver;
use driver::Driver;
use driver::UdpDnsQuery;
@@ -42,6 +47,15 @@
/// Command send error:
#[error(transparent)]
CommandSend(#[from] SendError<Command>),
+ /// DNS response does not match query sent
+ #[error("DNS response does not match query sent")]
+ DnsResponseMismatch,
+ /// No name server on upstream
+ #[error("No name server on upstream")]
+ NoNameServer,
+ /// Packet error
+ #[error(transparent)]
+ Packet(#[from] PacketError),
/// Query send error:
#[error("Query send error: {0}")]
QuerySend(String),
@@ -128,12 +142,14 @@
impl Server {
/// Creates a server running a current thread runtime.
- pub fn new() -> Result<Server> {
+ pub fn new(net_context_client: impl NetContextClient + 'static) -> Result<Server> {
let runtime = RuntimeBuilder::new_current_thread().enable_all().build()?;
let (command_tx, command_rx) = mpsc::channel(100 /* capacity */);
let weak_command_tx = command_tx.clone().downgrade();
let join_handle = thread::spawn(move || {
- runtime.block_on(async { Driver::new(weak_command_tx, command_rx).drive().await });
+ runtime.block_on(async {
+ Driver::new(net_context_client, weak_command_tx, command_rx).drive().await
+ });
});
Ok(Server { command_tx, join_handle })
}
@@ -172,6 +188,16 @@
}
}
+/// NetContextClient gets the net context for upstream configuration.
+#[cfg_attr(test, automock)]
+pub(crate) trait NetContextClient: Send + Sync + std::fmt::Debug {
+ /// Returns the name servers given |upstream_param|.
+ fn get_name_servers(&self, upstream_param: &UpstreamParam) -> Vec<IpAddr>;
+
+ /// Returns the DNS fwmark for the upstream sockets. Returns None if none is available.
+ fn get_dns_mark(&self, upstream_param: &UpstreamParam) -> Option<u32>;
+}
+
#[cfg(test)]
pub mod tests {
use std::sync::atomic::AtomicU16;
@@ -188,14 +214,14 @@
/// Checks that the server can be created and deleted.
#[test]
fn server_new_delete() {
- let server = Server::new().unwrap();
+ let server = Server::new(MockNetContextClient::new()).unwrap();
server.stop();
}
/// Checks that the server can be created, added with a downstream, and deleted.
#[test]
fn server_new_listen_delete() {
- let server = Server::new().unwrap();
+ let server = Server::new(MockNetContextClient::new()).unwrap();
let test_port = next_test_port();
server
.configure_dns_proxy(
diff --git a/dns_proxy/server/driver.rs b/dns_proxy/server/driver.rs
index 7c7e925..9bbf9d2 100644
--- a/dns_proxy/server/driver.rs
+++ b/dns_proxy/server/driver.rs
@@ -21,7 +21,6 @@
use std::net::IpAddr;
use std::net::Ipv6Addr;
use std::net::SocketAddr;
-use std::net::UdpSocket as SyncUdpSocket;
use std::os::fd::AsFd;
use std::os::fd::AsRawFd;
use std::sync::Arc;
@@ -31,6 +30,8 @@
use log::info;
use nix::libc::c_int;
use nix::libc::setsockopt;
+use rand::rngs::ThreadRng;
+use rand::seq::SliceRandom;
use socket2::Domain;
use socket2::Protocol;
use socket2::Socket;
@@ -44,6 +45,7 @@
use super::Command;
use super::DownstreamIndexPort;
use super::Error;
+use super::NetContextClient;
use super::Result;
use super::UpstreamParam;
@@ -75,7 +77,9 @@
}
#[derive(Debug)]
-pub(super) struct Driver {
+pub(super) struct Driver<C: NetContextClient> {
+ /// NetContext client
+ net_context_client: C,
/// Weak Command sender
weak_command_tx: mpsc::WeakSender<Command>,
/// Command receiver.
@@ -84,18 +88,23 @@
upstream_map: HashMap<DownstreamIndexPort, UpstreamParam>,
/// Map of DownstreamIndexPort pair to the handle of UDP socket it is listening to.
downstream_task_handles_map: HashMap<DownstreamIndexPort, JoinHandle<Result<()>>>,
+ /// Random number generator
+ rng: ThreadRng,
}
-impl Driver {
+impl<C: NetContextClient> Driver<C> {
pub fn new(
+ net_context_client: C,
weak_command_tx: mpsc::WeakSender<Command>,
command_rx: mpsc::Receiver<Command>,
) -> Self {
Self {
+ net_context_client,
weak_command_tx,
command_rx,
upstream_map: HashMap::new(),
downstream_task_handles_map: HashMap::new(),
+ rng: rand::thread_rng(),
}
}
@@ -129,7 +138,7 @@
Some(())
}
Command::ForwardUdpQuery(udp_dns_query) => {
- if let Err(e) = self.handle_udp_dns_query(udp_dns_query).await {
+ if let Err(e) = self.handle_udp_dns_query(udp_dns_query) {
error!("Error handling UDP query: {}", e);
}
Some(())
@@ -146,10 +155,10 @@
if let HashMapEntry::Vacant(vacant_entry) =
self.downstream_task_handles_map.entry(index_port)
{
- let socket = Driver::build_udp_socket(&index_port)?;
+ let socket = build_udp_socket(&index_port)?;
let handle = spawn_downstream_udp_socket(
self.weak_command_tx.clone(),
- Arc::new(UdpSocket::from_std(socket)?),
+ Arc::new(socket),
index_port,
);
vacant_entry.insert(handle);
@@ -157,32 +166,6 @@
Ok(())
}
- fn build_udp_socket(index_port: &DownstreamIndexPort) -> Result<SyncUdpSocket> {
- let socket = Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?;
- Driver::set_downstream_sockopts(&socket, index_port.if_index)?;
- socket
- .bind(&SocketAddr::new(IpAddr::V6(Ipv6Addr::from_bits(0)), index_port.port).into())?;
- Ok(socket.into())
- }
-
- fn set_downstream_sockopts(socket: &Socket, if_index: u32) -> Result<()> {
- socket.set_nonblocking(true)?;
- if if_index > 0 {
- // TODO(409455084): workaround while waiting for upstream changes.
- // Safety: setting if_index is safe since we own the FD and if_index is fixed length.
- unsafe {
- setsockopt(
- socket.as_fd().as_raw_fd(),
- nix::libc::SOL_SOCKET,
- SO_BINDTOIFINDEX,
- &if_index as *const _ as *const nix::libc::c_void,
- std::mem::size_of::<c_int>() as nix::libc::socklen_t,
- );
- }
- }
- Ok(())
- }
-
async fn handle_stop_dns_proxy_cmd(&mut self, index_port: &DownstreamIndexPort) -> Result<()> {
self.upstream_map.remove(index_port);
if let Some(handle) = self.downstream_task_handles_map.remove(index_port) {
@@ -201,9 +184,81 @@
}
}
- async fn handle_udp_dns_query(&mut self, _query: UdpDnsQuery) -> Result<()> {
- todo!("Send query");
+ fn handle_udp_dns_query(&mut self, query: UdpDnsQuery) -> Result<()> {
+ let upstream_param = match self.upstream_map.get(&query.index_port) {
+ Some(p) => p.to_owned(),
+ None => return Ok(()),
+ };
+ let socket = self.configure_upstream_udp_socket(&upstream_param)?;
+ tokio::spawn(async move {
+ if let Err(e) = resolve_and_send_udp(socket, query).await {
+ error!("Error resolving and sending UDP query: {}", e);
+ }
+ });
+ Ok(())
}
+
+ fn configure_upstream_udp_socket(
+ &mut self,
+ upstream_param: &UpstreamParam,
+ ) -> Result<UdpSocket> {
+ let name_servers = self.net_context_client.get_name_servers(upstream_param);
+ let name_server = name_servers.choose(&mut self.rng).ok_or(Error::NoNameServer)?;
+ let domain = if name_server.is_ipv4() { Domain::IPV4 } else { Domain::IPV6 };
+ let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
+ socket.set_nonblocking(true)?;
+ if let Some(mark) = self.net_context_client.get_dns_mark(upstream_param) {
+ socket.set_mark(mark)?;
+ }
+ // TODO(b:379992903): randomize port selection.
+ socket.bind(&SocketAddr::new(IpAddr::V6(Ipv6Addr::from_bits(0)), 0).into())?;
+ socket.connect(&SocketAddr::new(*name_server, 53).into())?;
+ Ok(UdpSocket::from_std(socket.into())?)
+ }
+}
+
+fn build_udp_socket(index_port: &DownstreamIndexPort) -> Result<UdpSocket> {
+ let socket = Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?;
+ set_downstream_sockopts(&socket, index_port.if_index)?;
+ socket.bind(&SocketAddr::new(IpAddr::V6(Ipv6Addr::from_bits(0)), index_port.port).into())?;
+ Ok(UdpSocket::from_std(socket.into())?)
+}
+
+fn set_downstream_sockopts(socket: &Socket, if_index: u32) -> Result<()> {
+ socket.set_nonblocking(true)?;
+ if if_index > 0 {
+ // TODO(409455084): workaround while waiting for upstream changes.
+ // Safety: setting if_index is safe since we own the FD and if_index is fixed length.
+ unsafe {
+ setsockopt(
+ socket.as_fd().as_raw_fd(),
+ nix::libc::SOL_SOCKET,
+ SO_BINDTOIFINDEX,
+ &if_index as *const _ as *const nix::libc::c_void,
+ std::mem::size_of::<c_int>() as nix::libc::socklen_t,
+ );
+ }
+ }
+ Ok(())
+}
+
+async fn resolve_and_send_udp(socket: UdpSocket, query: UdpDnsQuery) -> Result<()> {
+ // TODO (b:379992903): randomize DNS ID.
+ let query_dns_id = query.query_packet.header().id;
+ socket.send(query.query_packet.as_bytes()).await?;
+ // RFC 6891 6.2.3: UDP payload between 1280 and 1410 reasonable for EDNS.
+ // Fallback to TCP otherwise.
+ let mut buf = [0u8; 1410];
+ let size = socket.recv(&mut buf).await?;
+ let bytes = buf[0..size].to_vec();
+ let response = DnsPacket::try_from(bytes)?;
+ if response.header().id != query_dns_id {
+ return Err(Error::DnsResponseMismatch);
+ }
+ if let Some(resp_socket) = query.resp_socket.upgrade() {
+ resp_socket.send_to(response.as_bytes(), query.client_addr).await?;
+ }
+ Ok(())
}
/// Create downstream UDP socket that sends `UdpDnsQuery` through |query_tx|.