use crate::config::resolver::{DnsConfig, NameServerConfig};
use error_stack::report;
use futures::stream::{SelectAll, StreamExt};
use hickory_client::rr::Name;
use hickory_proto::error::ProtoError;
use hickory_proto::op::message::Message;
use hickory_proto::op::query::Query;
use hickory_proto::rr::record_data::RData;
use hickory_proto::rr::resource::Record;
use hickory_proto::rr::RecordType;
use hickory_proto::udp::UdpClientStream;
use hickory_proto::xfer::dns_request::DnsRequest;
use hickory_proto::xfer::{DnsRequestSender, DnsResponseStream};
use std::collections::{HashMap, HashSet};
use std::net::{IpAddr, SocketAddr};
use thiserror::Error;
use tokio::net::UdpSocket;
use tokio::time::{Duration, Instant};

/// Errors that can occur while building a [`Resolver`] from a [`DnsConfig`].
#[derive(Debug, Error)]
pub enum ResolverConfigError {
    /// The configured local domain could not be parsed as a DNS name.
    #[error("Bad local domain name '{name}' provided: '{error}'")]
    BadDomainName { name: String, error: ProtoError },
    /// One of the configured search domains could not be parsed as a DNS name.
    #[error("Bad domain name for search '{name}' provided: '{error}'")]
    BadSearchName { name: String, error: ProtoError },
    /// The UDP client stream for the name server at `address` could not be created.
    #[error("Couldn't set up client for server at '{address}': '{error}'")]
    FailedToCreateDnsClient {
        address: SocketAddr,
        error: ProtoError,
    },
    /// No name server is available to query.
    // NOTE(review): nothing in this file constructs this variant yet.
    #[error("No DNS servers found to search, and mDNS not enabled")]
    NoHosts,
}

/// Errors that can occur while resolving a name with [`Resolver::lookup`].
#[derive(Debug, Error)]
pub enum ResolveError {
    /// No request could be handed to any name server.
    #[error("No servers available for query")]
    NoServersAvailable,
    /// Every response stream finished without an error, but there was nothing to return.
    #[error("No responses found for query")]
    NoResponses,
    /// A response stream reported an error; `error` is the first one that was seen.
    #[error("Error reading response: {error}")]
    ResponseError { error: ProtoError },
}

/// A DNS resolver that queries the configured name servers over UDP and remembers
/// the addresses it has learned.
pub struct Resolver {
    /// Domains appended to a queried name to build the additional names to search for.
    search_domains: Vec<Name>,
    /// One entry per name server: its configuration (kept so the stream can be
    /// re-created after a shutdown) and the UDP client stream used to talk to it.
    client_connections: Vec<(NameServerConfig, UdpClientStream<UdpSocket>)>,
    /// Addresses learned so far, keyed by the queried name; each address is paired
    /// with the instant at which it expires.
    cache: HashMap<Name, Vec<(Instant, IpAddr)>>,
}

impl Resolver {
    /// Create a new DNS resolution engine for use by some part of the system.
pub async fn new(config: &DnsConfig) -> error_stack::Result { let mut search_domains = Vec::new(); if let Some(local) = config.local_domain.as_ref() { search_domains.push(Name::from_utf8(local).map_err(|e| { report!(ResolverConfigError::BadDomainName { name: local.clone(), error: e, }) })?); } for local_domain in config.search_domains.iter() { search_domains.push(Name::from_utf8(local_domain).map_err(|e| { report!(ResolverConfigError::BadSearchName { name: local_domain.clone(), error: e, }) })?); } let mut client_connections = Vec::new(); for name_server_config in config.name_servers() { let stream = UdpClientStream::with_bind_addr_and_timeout( name_server_config.address, name_server_config.bind_address.clone(), Duration::from_secs(name_server_config.timeout_in_seconds.unwrap_or(3)), ) .await .map_err(|error| ResolverConfigError::FailedToCreateDnsClient { address: name_server_config.address, error, })?; client_connections.push((name_server_config, stream)); } Ok(Resolver { search_domains, client_connections, cache: HashMap::new(), }) } /// Look up the address of the given name, returning either a set of results /// we've received in a reasonable amount of time, or an error. 
pub async fn lookup(&mut self, name: &Name) -> Result, ResolveError> { let names = self.expand_name(name); let mut response_stream = self.create_response_stream(names).await; if response_stream.is_empty() { return Err(ResolveError::NoServersAvailable); } let mut first_error = None; while let Some(response) = response_stream.next().await { match response { Err(e) => { if first_error.is_none() { first_error = Some(e); } } Ok(response) => { for answer in response.into_message().take_answers() { self.handle_response(name, answer).await; } } } } match first_error { None => Err(ResolveError::NoResponses), Some(error) => Err(ResolveError::ResponseError { error }), } } async fn create_response_stream(&mut self, names: Vec) -> SelectAll { let connections = std::mem::take(&mut self.client_connections); let mut response_stream = futures::stream::SelectAll::new(); for (config, mut client) in connections.into_iter() { if client.is_shutdown() { let stream = UdpClientStream::with_bind_addr_and_timeout( config.address, config.bind_address.clone(), Duration::from_secs(config.timeout_in_seconds.unwrap_or(3)), ) .await; match stream { Ok(stream) => client = stream, Err(_) => continue, } } let mut message = Message::new(); for name in names.iter() { message.add_query(Query::query(name.clone(), RecordType::A)); message.add_query(Query::query(name.clone(), RecordType::AAAA)); } message.set_recursion_desired(true); let request = DnsRequest::new(message, Default::default()); response_stream.push(client.send_message(request)); self.client_connections.push((config, client)); } response_stream } /// Expand the given input name into the complete set of names that we should search for. fn expand_name(&self, name: &Name) -> Vec { let mut names = vec![name.clone()]; for search_domain in self.search_domains.iter() { if let Ok(combined) = name.clone().append_name(search_domain) { names.push(combined); } } names } /// Handle an individual response from a server. 
///
    /// Returns true if we got an answer from the server, and updated the internal cache.
    /// This answer is used externally to set a timer, so that we don't wait the full DNS
    /// timeout period for answers to come back.
    ///
    /// Only A and AAAA records are accepted; any other record, or a record without
    /// data, is logged and yields `false`. An accepted address is stored under
    /// `query` with an expiry of "now + TTL". Expired entries for `query` are pruned
    /// at the same time, and an address that is already cached has its expiry
    /// replaced rather than being stored a second time.
    async fn handle_response(&mut self, query: &Name, record: Record) -> bool {
        let ttl = record.ttl();

        let Some(rdata) = record.into_data() else {
            tracing::error!("for some reason, couldn't process incoming message");
            return false;
        };

        let response: IpAddr = match rdata {
            RData::A(arec) => arec.0.into(),
            RData::AAAA(arec) => arec.0.into(),
            _ => {
                tracing::warn!(record_type = %rdata.record_type(), "skipping unknown / unexpected record type");
                return false;
            }
        };

        let expire_at = Instant::now() + Duration::from_secs(u64::from(ttl));

        let entries = self.cache.entry(query.clone()).or_default();
        clean_expired_entries(entries);
        // Keep one entry per address so repeated lookups don't grow the list.
        entries.retain(|(_, address)| *address != response);
        entries.push((expire_at, response));

        true
    }
}

/// Remove every entry whose expiry instant is not strictly in the future.
fn clean_expired_entries(list: &mut Vec<(Instant, IpAddr)>) {
    let now = Instant::now();
    list.retain(|(expire_at, _)| *expire_at > now);
}

#[tokio::test]
async fn uhsure() {
    let mut resolver = Resolver::new(&DnsConfig::default()).await.unwrap();
    let name = Name::from_ascii("uhsure.com").unwrap();
    let result = resolver.lookup(&name).await.unwrap();
    println!("result = {:?}", result);
    assert!(!result.is_empty());
}

// Manual debugging aid: prints whatever the hard-coded server sends back. It asserts
// nothing and needs that specific server, so it is excluded from normal test runs.
#[tokio::test]
#[ignore = "needs a DNS server at 172.19.207.1:53; run with `cargo test -- --ignored --nocapture`"]
async fn manual_lookup() {
    let mut stream: UdpClientStream<UdpSocket> = UdpClientStream::with_bind_addr_and_timeout(
        SocketAddr::V4(std::net::SocketAddrV4::new(
            std::net::Ipv4Addr::new(172, 19, 207, 1),
            53,
        )),
        None,
        Duration::from_secs(3),
    )
    .await
    .expect("can create UDP DNS client stream");

    let mut message = Message::new();
    let name = Name::from_ascii("uhsure.com").unwrap();
    message.add_query(Query::query(name.clone(), RecordType::A));
    message.add_query(Query::query(name.clone(), RecordType::AAAA));
    message.set_recursion_desired(true);

    let request = DnsRequest::new(message, Default::default());
    let mut responses = stream.send_message(request);

    while let Some(response) = responses.next().await {
        println!("response: {:?}", response);
    }
}