dns_proxy: UDP DNS forwarding

This CL implements the DNS query sending and reply back to client, with
some details left to be filled in later follow-up CLs.
To accomodate both test and production, android-specific FFIs are moved
behind "android-ffi" feature flag.

Flag: EXEMPT MAINLINE
Bug: 379992903
Test: m; cargo test --all-features
Change-Id: If6c06a56d397298a93c5c3d2ac3e72a7625d4ce5
diff --git a/Android.bp b/Android.bp
index 9a57c85..ee65b72 100644
--- a/Android.bp
+++ b/Android.bp
@@ -526,6 +526,7 @@
         "liblog_rust",
         "libnix",
         "libnum_enum",
+        "librand",
         "libsocket2",
         "libthiserror",
         "libtokio",
diff --git a/dns_proxy/Cargo.toml b/dns_proxy/Cargo.toml
index 5e243dc..574469b 100644
--- a/dns_proxy/Cargo.toml
+++ b/dns_proxy/Cargo.toml
@@ -30,6 +30,7 @@
 mockall = "^0.13.1"
 num_enum = "^0.7.3"
 nix = { version = "^0.29.0", features = ["fs", "net", "user"] }
+rand = { version = "^0.8.5", features = ["std", "std_rng"] }
 socket2 = "0.5.8"
 thiserror = "^2.0.11"
 tokio = { version = "^1.42.0", features = ["macros", "net", "rt", "rt-multi-thread", "sync", "time"] }
diff --git a/dns_proxy/ffi.rs b/dns_proxy/ffi.rs
index 6764fbb..f33782f 100644
--- a/dns_proxy/ffi.rs
+++ b/dns_proxy/ffi.rs
@@ -20,13 +20,34 @@
 
 use log::error;
 
+use crate::server::NetContextClient;
 use crate::server::Server;
+use crate::server::UpstreamParam;
+
+#[derive(Debug)]
+struct AndroidNetContextClient;
+
+impl AndroidNetContextClient {
+    fn new() -> Self {
+        Self {}
+    }
+}
+
+impl NetContextClient for AndroidNetContextClient {
+    fn get_dns_mark(&self, _upstream_param: &UpstreamParam) -> Option<u32> {
+        todo!();
+    }
+
+    fn get_name_servers(&self, _upstream_param: &UpstreamParam) -> Vec<std::net::IpAddr> {
+        todo!();
+    }
+}
 
 /// Constructs the DNS proxy server.
 /// Returns a pointer to the DNS proxy instance.
 #[no_mangle]
 pub extern "C" fn proxy_server_new() -> *mut Server {
-    match Server::new() {
+    match Server::new(AndroidNetContextClient::new()) {
         Ok(server) => Box::into_raw(Box::new(server)),
         Err(e) => {
             error!("proxy_server_new failed: {:?}", e);
@@ -62,8 +83,8 @@
     if let Err(e) =
         server.configure_dns_proxy(upstream_net_id, uid, downstream_if_index, downstream_port)
     {
-        error!("Error configure DNS proxy: {}", e);
-        panic!();
+        error!("Error configure DNS proxy: {}", &e);
+        panic!("Error configure DNS proxy: {}", e);
     }
 }
 
@@ -78,8 +99,8 @@
     downstream_port: u16,
 ) {
     if let Err(e) = server.stop_dns_proxy(downstream_if_index, downstream_port) {
-        error!("Error stop DNS proxy: {}", e);
-        panic!();
+        error!("Error stop DNS proxy: {}", &e);
+        panic!("Error stop DNS proxy: {}", e);
     }
 }
 
diff --git a/dns_proxy/lib.rs b/dns_proxy/lib.rs
index b6cea9a..8035713 100644
--- a/dns_proxy/lib.rs
+++ b/dns_proxy/lib.rs
@@ -18,6 +18,7 @@
 // TODO (b/379992903): Remove this after library is completed.
 #![allow(dead_code)]
 
+#[cfg(feature = "android-ffi")]
 mod ffi;
 pub(crate) mod packet;
 mod server;
diff --git a/dns_proxy/server.rs b/dns_proxy/server.rs
index 4f70615..946416f 100644
--- a/dns_proxy/server.rs
+++ b/dns_proxy/server.rs
@@ -17,8 +17,11 @@
 //! DNS proxy server implementation.
 
 use std::io::Error as IoError;
+use std::net::IpAddr;
 use std::thread;
 
+#[cfg(test)]
+use mockall::automock;
 use thiserror::Error;
 use tokio::runtime::Builder as RuntimeBuilder;
 use tokio::sync::mpsc;
@@ -26,6 +29,8 @@
 use tokio::sync::oneshot;
 use tokio::sync::oneshot::error::RecvError;
 
+use crate::packet::PacketError;
+
 mod driver;
 use driver::Driver;
 use driver::UdpDnsQuery;
@@ -42,6 +47,15 @@
     /// Command send error:
     #[error(transparent)]
     CommandSend(#[from] SendError<Command>),
+    /// DNS response does not match query sent
+    #[error("DNS response does not match query sent")]
+    DnsResponseMismatch,
+    /// No name server on upstream
+    #[error("No name server on upstream")]
+    NoNameServer,
+    /// Packet error
+    #[error(transparent)]
+    Packet(#[from] PacketError),
     /// Query send error:
     #[error("Query send error: {0}")]
     QuerySend(String),
@@ -128,12 +142,14 @@
 
 impl Server {
     /// Creates a server running a current thread runtime.
-    pub fn new() -> Result<Server> {
+    pub fn new(net_context_client: impl NetContextClient + 'static) -> Result<Server> {
         let runtime = RuntimeBuilder::new_current_thread().enable_all().build()?;
         let (command_tx, command_rx) = mpsc::channel(100 /* capacity */);
         let weak_command_tx = command_tx.clone().downgrade();
         let join_handle = thread::spawn(move || {
-            runtime.block_on(async { Driver::new(weak_command_tx, command_rx).drive().await });
+            runtime.block_on(async {
+                Driver::new(net_context_client, weak_command_tx, command_rx).drive().await
+            });
         });
         Ok(Server { command_tx, join_handle })
     }
@@ -172,6 +188,16 @@
     }
 }
 
+/// NetContextClient gets the net context for upstream configuration.
+#[cfg_attr(test, automock)]
+pub(crate) trait NetContextClient: Send + Sync + std::fmt::Debug {
+    /// Returns the name servers given |upstream_param|.
+    fn get_name_servers(&self, upstream_param: &UpstreamParam) -> Vec<IpAddr>;
+
+    /// Returns the DNS fwmark for the upstream sockets. Returns None if none is available.
+    fn get_dns_mark(&self, upstream_param: &UpstreamParam) -> Option<u32>;
+}
+
 #[cfg(test)]
 pub mod tests {
     use std::sync::atomic::AtomicU16;
@@ -188,14 +214,14 @@
     /// Checks that the server can be created and deleted.
     #[test]
     fn server_new_delete() {
-        let server = Server::new().unwrap();
+        let server = Server::new(MockNetContextClient::new()).unwrap();
         server.stop();
     }
 
     /// Checks that the server can be created, added with a downstream, and deleted.
     #[test]
     fn server_new_listen_delete() {
-        let server = Server::new().unwrap();
+        let server = Server::new(MockNetContextClient::new()).unwrap();
         let test_port = next_test_port();
         server
             .configure_dns_proxy(
diff --git a/dns_proxy/server/driver.rs b/dns_proxy/server/driver.rs
index 7c7e925..9bbf9d2 100644
--- a/dns_proxy/server/driver.rs
+++ b/dns_proxy/server/driver.rs
@@ -21,7 +21,6 @@
 use std::net::IpAddr;
 use std::net::Ipv6Addr;
 use std::net::SocketAddr;
-use std::net::UdpSocket as SyncUdpSocket;
 use std::os::fd::AsFd;
 use std::os::fd::AsRawFd;
 use std::sync::Arc;
@@ -31,6 +30,8 @@
 use log::info;
 use nix::libc::c_int;
 use nix::libc::setsockopt;
+use rand::rngs::ThreadRng;
+use rand::seq::SliceRandom;
 use socket2::Domain;
 use socket2::Protocol;
 use socket2::Socket;
@@ -44,6 +45,7 @@
 use super::Command;
 use super::DownstreamIndexPort;
 use super::Error;
+use super::NetContextClient;
 use super::Result;
 use super::UpstreamParam;
 
@@ -75,7 +77,9 @@
 }
 
 #[derive(Debug)]
-pub(super) struct Driver {
+pub(super) struct Driver<C: NetContextClient> {
+    /// NetContext client
+    net_context_client: C,
     /// Weak Command sender
     weak_command_tx: mpsc::WeakSender<Command>,
     /// Command receiver.
@@ -84,18 +88,23 @@
     upstream_map: HashMap<DownstreamIndexPort, UpstreamParam>,
     /// Map of DownstreamIndexPort pair to the handle of UDP socket it is listening to.
     downstream_task_handles_map: HashMap<DownstreamIndexPort, JoinHandle<Result<()>>>,
+    /// Random number generator
+    rng: ThreadRng,
 }
 
-impl Driver {
+impl<C: NetContextClient> Driver<C> {
     pub fn new(
+        net_context_client: C,
         weak_command_tx: mpsc::WeakSender<Command>,
         command_rx: mpsc::Receiver<Command>,
     ) -> Self {
         Self {
+            net_context_client,
             weak_command_tx,
             command_rx,
             upstream_map: HashMap::new(),
             downstream_task_handles_map: HashMap::new(),
+            rng: rand::thread_rng(),
         }
     }
 
@@ -129,7 +138,7 @@
                 Some(())
             }
             Command::ForwardUdpQuery(udp_dns_query) => {
-                if let Err(e) = self.handle_udp_dns_query(udp_dns_query).await {
+                if let Err(e) = self.handle_udp_dns_query(udp_dns_query) {
                     error!("Error handling UDP query: {}", e);
                 }
                 Some(())
@@ -146,10 +155,10 @@
         if let HashMapEntry::Vacant(vacant_entry) =
             self.downstream_task_handles_map.entry(index_port)
         {
-            let socket = Driver::build_udp_socket(&index_port)?;
+            let socket = build_udp_socket(&index_port)?;
             let handle = spawn_downstream_udp_socket(
                 self.weak_command_tx.clone(),
-                Arc::new(UdpSocket::from_std(socket)?),
+                Arc::new(socket),
                 index_port,
             );
             vacant_entry.insert(handle);
@@ -157,32 +166,6 @@
         Ok(())
     }
 
-    fn build_udp_socket(index_port: &DownstreamIndexPort) -> Result<SyncUdpSocket> {
-        let socket = Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?;
-        Driver::set_downstream_sockopts(&socket, index_port.if_index)?;
-        socket
-            .bind(&SocketAddr::new(IpAddr::V6(Ipv6Addr::from_bits(0)), index_port.port).into())?;
-        Ok(socket.into())
-    }
-
-    fn set_downstream_sockopts(socket: &Socket, if_index: u32) -> Result<()> {
-        socket.set_nonblocking(true)?;
-        if if_index > 0 {
-            // TODO(409455084): workaround while waiting for upstream changes.
-            // Safety: setting if_index is safe since we own the FD and if_index is fixed length.
-            unsafe {
-                setsockopt(
-                    socket.as_fd().as_raw_fd(),
-                    nix::libc::SOL_SOCKET,
-                    SO_BINDTOIFINDEX,
-                    &if_index as *const _ as *const nix::libc::c_void,
-                    std::mem::size_of::<c_int>() as nix::libc::socklen_t,
-                );
-            }
-        }
-        Ok(())
-    }
-
     async fn handle_stop_dns_proxy_cmd(&mut self, index_port: &DownstreamIndexPort) -> Result<()> {
         self.upstream_map.remove(index_port);
         if let Some(handle) = self.downstream_task_handles_map.remove(index_port) {
@@ -201,9 +184,81 @@
         }
     }
 
-    async fn handle_udp_dns_query(&mut self, _query: UdpDnsQuery) -> Result<()> {
-        todo!("Send query");
+    fn handle_udp_dns_query(&mut self, query: UdpDnsQuery) -> Result<()> {
+        let upstream_param = match self.upstream_map.get(&query.index_port) {
+            Some(p) => p.to_owned(),
+            None => return Ok(()),
+        };
+        let socket = self.configure_upstream_udp_socket(&upstream_param)?;
+        tokio::spawn(async move {
+            if let Err(e) = resolve_and_send_udp(socket, query).await {
+                error!("Error resolving and sending UDP query: {}", e);
+            }
+        });
+        Ok(())
     }
+
+    fn configure_upstream_udp_socket(
+        &mut self,
+        upstream_param: &UpstreamParam,
+    ) -> Result<UdpSocket> {
+        let name_servers = self.net_context_client.get_name_servers(upstream_param);
+        let name_server = name_servers.choose(&mut self.rng).ok_or(Error::NoNameServer)?;
+        let domain = if name_server.is_ipv4() { Domain::IPV4 } else { Domain::IPV6 };
+        let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
+        socket.set_nonblocking(true)?;
+        if let Some(mark) = self.net_context_client.get_dns_mark(upstream_param) {
+            socket.set_mark(mark)?;
+        }
+        // TODO(b:379992903): randomize port selection.
+        socket.bind(&SocketAddr::new(IpAddr::V6(Ipv6Addr::from_bits(0)), 0).into())?;
+        socket.connect(&SocketAddr::new(*name_server, 53).into())?;
+        Ok(UdpSocket::from_std(socket.into())?)
+    }
+}
+
+fn build_udp_socket(index_port: &DownstreamIndexPort) -> Result<UdpSocket> {
+    let socket = Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?;
+    set_downstream_sockopts(&socket, index_port.if_index)?;
+    socket.bind(&SocketAddr::new(IpAddr::V6(Ipv6Addr::from_bits(0)), index_port.port).into())?;
+    Ok(UdpSocket::from_std(socket.into())?)
+}
+
+fn set_downstream_sockopts(socket: &Socket, if_index: u32) -> Result<()> {
+    socket.set_nonblocking(true)?;
+    if if_index > 0 {
+        // TODO(409455084): workaround while waiting for upstream changes.
+        // Safety: setting if_index is safe since we own the FD and if_index is fixed length.
+        unsafe {
+            setsockopt(
+                socket.as_fd().as_raw_fd(),
+                nix::libc::SOL_SOCKET,
+                SO_BINDTOIFINDEX,
+                &if_index as *const _ as *const nix::libc::c_void,
+                std::mem::size_of::<c_int>() as nix::libc::socklen_t,
+            );
+        }
+    }
+    Ok(())
+}
+
+async fn resolve_and_send_udp(socket: UdpSocket, query: UdpDnsQuery) -> Result<()> {
+    // TODO (b:379992903): randomize DNS ID.
+    let query_dns_id = query.query_packet.header().id;
+    socket.send(query.query_packet.as_bytes()).await?;
+    // RFC 6891 6.2.3: UDP payload between 1280 and 1410 reasonable for EDNS.
+    // Fallback to TCP otherwise.
+    let mut buf = [0u8; 1410];
+    let size = socket.recv(&mut buf).await?;
+    let bytes = buf[0..size].to_vec();
+    let response = DnsPacket::try_from(bytes)?;
+    if response.header().id != query_dns_id {
+        return Err(Error::DnsResponseMismatch);
+    }
+    if let Some(resp_socket) = query.resp_socket.upgrade() {
+        resp_socket.send_to(response.as_bytes(), query.client_addr).await?;
+    }
+    Ok(())
 }
 
 /// Create downstream UDP socket that sends `UdpDnsQuery` through |query_tx|.