blob: 939d3abc4073a97fd93b31e39b0fa3df12dd87ff [file] [log] [blame]
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::anyhow;
use bytes::Bytes;
use log::{debug, warn};
use socket2::{Protocol, Socket};
use std::mem::MaybeUninit;
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use std::sync::mpsc;
/// IPv4 multicast group joined by the forwarder socket; also used as the
/// destination address of every synthesized IPv4 header.
const MDNS_IP: Ipv4Addr = Ipv4Addr::new(224, 0, 0, 251);
/// UDP port the forwarder socket binds to; also used as both the source and
/// destination port of every synthesized UDP header.
const MDNS_PORT: u16 = 5353;
/// A 48-bit MAC address kept in the low six bytes of a `u64`.
///
/// The first octet of the address occupies the least significant byte, so
/// converting from and back to `[u8; 6]` yields the octets in their original
/// order.
struct MacAddress(u64);

impl MacAddress {
    /// Returns the six address octets in transmission order.
    ///
    /// The name mirrors the integer `to_be_bytes` methods so that this type can
    /// be passed to the `be_vec!` macro alongside plain integers.
    fn to_be_bytes(&self) -> [u8; 6] {
        // NOTE: the address is stored little-endian, so the low six bytes are
        // already in transmission order.
        let [b0, b1, b2, b3, b4, b5, _, _] = self.0.to_le_bytes();
        [b0, b1, b2, b3, b4, b5]
    }
}

impl From<MacAddress> for [u8; 6] {
    /// Extracts the six address octets, discarding the two unused high bytes.
    fn from(mac: MacAddress) -> Self {
        let [b0, b1, b2, b3, b4, b5, ..] = mac.0.to_le_bytes();
        [b0, b1, b2, b3, b4, b5]
    }
}

impl From<&[u8; 6]> for MacAddress {
    /// Packs six octets into the low bytes of the backing `u64`; the two high
    /// bytes are zero.
    fn from(bytes: &[u8; 6]) -> Self {
        let mut padded = [0u8; 8];
        padded[..6].copy_from_slice(bytes);
        Self(u64::from_le_bytes(padded))
    }
}
/// IPv4 header without options (20 bytes).
///
/// Fields hold plain integer values; `to_be_bytes` produces the big-endian
/// wire representation. The struct is packed so that `size_of` equals the
/// serialized length.
#[repr(C, packed)]
struct Ipv4Header {
    /// 4 bits Version, 4 bits Internet Header Length
    version_ihl: u8,
    /// 6 bits Differentiated Services Code Point, 2 bits Explicit Congestion Notification
    dscp_ecn: u8,
    total_length: u16,
    identification: u16,
    /// 3 bits Flags, 13 bits Fragment Offset
    flags_fragment_offset: u16,
    time_to_live: u8,
    protocol: u8,
    header_checksum: u16,
    source_ip: [u8; 4],
    destination_ip: [u8; 4],
}

// Concatenates the `to_be_bytes()` output of every argument, in order, and
// collects the bytes into whatever collection the call site asks for.
macro_rules! be_vec {
    ( $( $field:expr ),* ) => {
        ::std::iter::empty::<u8>()
            $( .chain($field.to_be_bytes()) )*
            .collect()
    };
}

impl Ipv4Header {
    /// Computes the 16-bit one's-complement checksum over the serialized header.
    ///
    /// The current `header_checksum` field takes part in the sum, so callers
    /// that want a fresh checksum must zero it first (see `update_checksum`).
    fn calculate_checksum(&self) -> u16 {
        let wire = self.to_be_bytes();
        // Add up the header as ten big-endian 16-bit words.
        let mut acc: u32 = wire
            .chunks_exact(2)
            .map(|pair| u32::from(u16::from_be_bytes([pair[0], pair[1]])))
            .sum();
        // Fold any carry out of the low 16 bits back into the sum.
        while acc > 0xFFFF {
            acc = (acc >> 16) + (acc & 0xFFFF);
        }
        // One's complement
        !(acc as u16)
    }

    /// Recomputes `header_checksum` from the other header fields.
    fn update_checksum(&mut self) {
        // The checksum field itself must count as zero while summing.
        self.header_checksum = 0;
        let checksum = self.calculate_checksum();
        self.header_checksum = checksum;
    }

    /// Serializes the header into its 20-byte big-endian wire form.
    fn to_be_bytes(&self) -> [u8; 20] {
        let mut wire: Vec<u8> = be_vec![
            self.version_ihl,
            self.dscp_ecn,
            self.total_length,
            self.identification,
            self.flags_fragment_offset,
            self.time_to_live,
            self.protocol,
            self.header_checksum
        ];
        // The addresses are already stored as octets in transmission order.
        wire.extend(self.source_ip);
        wire.extend(self.destination_ip);
        <[u8; 20]>::try_from(wire).unwrap()
    }
}
/// UDP header (8 bytes).
///
/// Fields hold plain integer values; `to_be_bytes` produces the big-endian
/// wire representation. The struct is packed so that `size_of` equals the
/// serialized length.
#[repr(C, packed)]
struct UdpHeader {
    source_port: u16,
    destination_port: u16,
    length: u16,
    checksum: u16,
}

impl UdpHeader {
    /// Serializes the header into its 8-byte big-endian wire form.
    fn to_be_bytes(&self) -> [u8; 8] {
        // Copy the packed fields out by value before serializing them.
        let (src, dst, len, sum) =
            (self.source_port, self.destination_port, self.length, self.checksum);
        let wire: Vec<u8> = be_vec![src, dst, len, sum];
        <[u8; 8]>::try_from(wire).unwrap()
    }
}
/* 10Mb/s ethernet header */
#[repr(C, packed)]
struct EtherHeader {
ether_dhost: [u8; 6],
ether_shost: [u8; 6],
ether_type: u16,
}
/* Ethernet protocol ID's */
const ETHER_TYPE_IP: u16 = 0x0800;
impl EtherHeader {
fn to_be_bytes(&self) -> [u8; 14] {
let v: Vec<u8> = be_vec![
MacAddress::from(&self.ether_dhost),
MacAddress::from(&self.ether_shost),
self.ether_type
];
v.try_into().unwrap()
}
}
// Define constants for header sizes (bytes)
// Each header struct is declared `#[repr(C, packed)]`, so `size_of` contains no
// padding and matches the length of the array returned by its `to_be_bytes`.
const UDP_HEADER_LEN: usize = std::mem::size_of::<UdpHeader>();
const IPV4_HEADER_LEN: usize = std::mem::size_of::<Ipv4Header>();
const ETHER_HEADER_LEN: usize = std::mem::size_of::<EtherHeader>();
/// Creates a new UDP socket bound to `addr` and joined to the mDNS IPv4
/// multicast group.
///
/// Address reuse is enabled (plus port reuse on Unix) before binding, and
/// IPv4 multicast loopback is disabled on the socket.
///
/// `non_block` indicates whether to set O_NONBLOCK for the socket.
///
/// # Errors
///
/// Returns an error if the socket cannot be created, if any socket option
/// cannot be applied, if joining the multicast group fails, or if binding to
/// `addr` fails.
fn new_socket(addr: SocketAddr, non_block: bool) -> anyhow::Result<Socket> {
    let domain = match addr {
        SocketAddr::V4(_) => socket2::Domain::IPV4,
        SocketAddr::V6(_) => socket2::Domain::IPV6,
    };
    let socket = Socket::new(domain, socket2::Type::DGRAM, Some(Protocol::UDP))
        .map_err(|e| anyhow!("create socket failed: {:?}", e))?;
    socket.set_reuse_address(true).map_err(|e| anyhow!("set ReuseAddr failed: {:?}", e))?;
    // Port reuse only needs to be requested once.
    #[cfg(unix)] // this is currently restricted to Unix's in socket2
    socket.set_reuse_port(true).map_err(|e| anyhow!("set ReusePort failed: {:?}", e))?;
    if non_block {
        socket.set_nonblocking(true).map_err(|e| anyhow!("set O_NONBLOCK: {:?}", e))?;
    }
    socket
        .join_multicast_v4(&MDNS_IP, &Ipv4Addr::UNSPECIFIED)
        .map_err(|e| anyhow!("join multicast group {} failed: {:?}", MDNS_IP, e))?;
    // Report the failure to the caller instead of panicking.
    socket
        .set_multicast_loop_v4(false)
        .map_err(|e| anyhow!("set_multicast_loop_v4 call failed: {:?}", e))?;
    socket.bind(&addr.into()).map_err(|e| anyhow!("socket bind to {} failed: {:?}", &addr, e))?;
    Ok(socket)
}
fn create_ethernet_frame(packet: &[u8], ip_addr: &Ipv4Addr) -> anyhow::Result<Vec<u8>> {
// TODO: Use the etherparse crate
let ether_header = EtherHeader {
// mDNS multicast IP address
ether_dhost: [0x01, 0x00, 0x5e, 0x00, 0x00, 0xfb],
ether_shost: [0x01, 0x00, 0x5e, 0x00, 0x00, 0xfb],
ether_type: ETHER_TYPE_IP,
};
// Create UDP Header
let udp_header = UdpHeader {
source_port: MDNS_PORT,
destination_port: MDNS_PORT,
length: (packet.len() + UDP_HEADER_LEN) as u16,
// Usually 0 for mDNS
checksum: 0,
};
// Create IPv4 Header
let mut ipv4_header = Ipv4Header {
version_ihl: 0x45,
dscp_ecn: 0,
total_length: (packet.len() + UDP_HEADER_LEN + IPV4_HEADER_LEN) as u16,
identification: 0,
flags_fragment_offset: 0,
time_to_live: 64,
protocol: 17,
header_checksum: 0,
source_ip: ip_addr.octets(),
// mDNS multicast
destination_ip: MDNS_IP.octets(),
};
ipv4_header.update_checksum();
// Combine Headers and Payload (Safely using Vec)
let mut response_packet =
Vec::with_capacity(ETHER_HEADER_LEN + IPV4_HEADER_LEN + UDP_HEADER_LEN + packet.len());
response_packet.extend_from_slice(&ether_header.to_be_bytes());
response_packet.extend_from_slice(&ipv4_header.to_be_bytes());
response_packet.extend_from_slice(&udp_header.to_be_bytes());
response_packet.extend_from_slice(packet);
Ok(response_packet)
}
pub fn run_mdns_forwarder(tx: mpsc::Sender<Bytes>) -> anyhow::Result<()> {
let addr = SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), MDNS_PORT);
let socket = new_socket(addr.into(), false)?;
// Typical max mDNS packet size
let mut buf: [MaybeUninit<u8>; 1500] = [MaybeUninit::new(0 as u8); 1500];
loop {
let (size, src_addr) = socket.recv_from(&mut buf[..])?;
// SAFETY: `recv_from` implementation promises not to write uninitialized bytes to `buf`.
// Documentation: https://docs.rs/socket2/latest/socket2/struct.Socket.html#method.recv_from
let packet = unsafe { &*(&buf[..size] as *const [MaybeUninit<u8>] as *const [u8]) };
if let Some(socket_addr_v4) = src_addr.as_socket_ipv4() {
debug!("Received {} bytes from {:?}", packet.len(), socket_addr_v4);
match create_ethernet_frame(packet, socket_addr_v4.ip()) {
Ok(ethernet_frame) => {
if let Err(e) = tx.send(ethernet_frame.into()) {
warn!("Failed to send packet: {e}");
}
}
Err(e) => warn!("Failed to create ethernet frame from UDP payload: {}", e),
};
} else {
warn!("Forwarding mDNS from IPv6 is not supported: {:?}", src_addr);
}
}
}