//! DNS name resolution over UDP against the configured name servers.
use crate::config::resolver::{DnsConfig, NameServerConfig};
|
|
use error_stack::report;
|
|
use futures::stream::{SelectAll, StreamExt};
|
|
use hickory_client::rr::Name;
|
|
use hickory_proto::error::ProtoError;
|
|
use hickory_proto::op::message::Message;
|
|
use hickory_proto::op::query::Query;
|
|
use hickory_proto::rr::record_data::RData;
|
|
use hickory_proto::rr::resource::Record;
|
|
use hickory_proto::rr::RecordType;
|
|
use hickory_proto::udp::UdpClientStream;
|
|
use hickory_proto::xfer::dns_request::DnsRequest;
|
|
use hickory_proto::xfer::{DnsRequestSender, DnsResponseStream};
|
|
use std::collections::{HashMap, HashSet};
|
|
use std::net::{IpAddr, SocketAddr};
|
|
use thiserror::Error;
|
|
use tokio::net::UdpSocket;
|
|
use tokio::time::{Duration, Instant};
|
|
|
|
#[derive(Debug, Error)]
|
|
pub enum ResolverConfigError {
|
|
#[error("Bad local domain name '{name}' provided: '{error}'")]
|
|
BadDomainName { name: String, error: ProtoError },
|
|
#[error("Bad domain name for search '{name}' provided: '{error}'")]
|
|
BadSearchName { name: String, error: ProtoError },
|
|
#[error("Couldn't set up client for server at '{address}': '{error}'")]
|
|
FailedToCreateDnsClient {
|
|
address: SocketAddr,
|
|
error: ProtoError,
|
|
},
|
|
#[error("No DNS servers found to search, and mDNS not enabled")]
|
|
NoHosts,
|
|
}
|
|
|
|
#[derive(Debug, Error)]
|
|
pub enum ResolveError {
|
|
#[error("No servers available for query")]
|
|
NoServersAvailable,
|
|
#[error("No responses found for query")]
|
|
NoResponses,
|
|
#[error("Error reading response: {error}")]
|
|
ResponseError { error: ProtoError },
|
|
}
|
|
|
|
pub struct Resolver {
|
|
search_domains: Vec<Name>,
|
|
client_connections: Vec<(NameServerConfig, UdpClientStream<UdpSocket>)>,
|
|
cache: HashMap<Name, Vec<(Instant, IpAddr)>>,
|
|
}
|
|
|
|
impl Resolver {
|
|
/// Create a new DNS resolution engine for use by some part of the system.
|
|
pub async fn new(config: &DnsConfig) -> error_stack::Result<Self, ResolverConfigError> {
|
|
let mut search_domains = Vec::new();
|
|
|
|
if let Some(local) = config.local_domain.as_ref() {
|
|
search_domains.push(Name::from_utf8(local).map_err(|e| {
|
|
report!(ResolverConfigError::BadDomainName {
|
|
name: local.clone(),
|
|
error: e,
|
|
})
|
|
})?);
|
|
}
|
|
|
|
for local_domain in config.search_domains.iter() {
|
|
search_domains.push(Name::from_utf8(local_domain).map_err(|e| {
|
|
report!(ResolverConfigError::BadSearchName {
|
|
name: local_domain.clone(),
|
|
error: e,
|
|
})
|
|
})?);
|
|
}
|
|
|
|
let mut client_connections = Vec::new();
|
|
|
|
for name_server_config in config.name_servers() {
|
|
let stream = UdpClientStream::with_bind_addr_and_timeout(
|
|
name_server_config.address,
|
|
name_server_config.bind_address.clone(),
|
|
Duration::from_secs(name_server_config.timeout_in_seconds.unwrap_or(3)),
|
|
)
|
|
.await
|
|
.map_err(|error| ResolverConfigError::FailedToCreateDnsClient {
|
|
address: name_server_config.address,
|
|
error,
|
|
})?;
|
|
|
|
client_connections.push((name_server_config, stream));
|
|
}
|
|
|
|
Ok(Resolver {
|
|
search_domains,
|
|
client_connections,
|
|
cache: HashMap::new(),
|
|
})
|
|
}
|
|
|
|
/// Look up the address of the given name, returning either a set of results
|
|
/// we've received in a reasonable amount of time, or an error.
|
|
pub async fn lookup(&mut self, name: &Name) -> Result<HashSet<IpAddr>, ResolveError> {
|
|
let names = self.expand_name(name);
|
|
let mut response_stream = self.create_response_stream(names).await;
|
|
|
|
if response_stream.is_empty() {
|
|
return Err(ResolveError::NoServersAvailable);
|
|
}
|
|
|
|
let mut first_error = None;
|
|
|
|
while let Some(response) = response_stream.next().await {
|
|
match response {
|
|
Err(e) => {
|
|
if first_error.is_none() {
|
|
first_error = Some(e);
|
|
}
|
|
}
|
|
|
|
Ok(response) => {
|
|
for answer in response.into_message().take_answers() {
|
|
self.handle_response(name, answer).await;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
match first_error {
|
|
None => Err(ResolveError::NoResponses),
|
|
Some(error) => Err(ResolveError::ResponseError { error }),
|
|
}
|
|
}
|
|
|
|
async fn create_response_stream(&mut self, names: Vec<Name>) -> SelectAll<DnsResponseStream> {
|
|
let connections = std::mem::take(&mut self.client_connections);
|
|
let mut response_stream = futures::stream::SelectAll::new();
|
|
|
|
for (config, mut client) in connections.into_iter() {
|
|
if client.is_shutdown() {
|
|
let stream = UdpClientStream::with_bind_addr_and_timeout(
|
|
config.address,
|
|
config.bind_address.clone(),
|
|
Duration::from_secs(config.timeout_in_seconds.unwrap_or(3)),
|
|
)
|
|
.await;
|
|
|
|
match stream {
|
|
Ok(stream) => client = stream,
|
|
Err(_) => continue,
|
|
}
|
|
}
|
|
|
|
let mut message = Message::new();
|
|
|
|
for name in names.iter() {
|
|
message.add_query(Query::query(name.clone(), RecordType::A));
|
|
message.add_query(Query::query(name.clone(), RecordType::AAAA));
|
|
}
|
|
message.set_recursion_desired(true);
|
|
|
|
let request = DnsRequest::new(message, Default::default());
|
|
response_stream.push(client.send_message(request));
|
|
self.client_connections.push((config, client));
|
|
}
|
|
|
|
response_stream
|
|
}
|
|
|
|
/// Expand the given input name into the complete set of names that we should search for.
|
|
fn expand_name(&self, name: &Name) -> Vec<Name> {
|
|
let mut names = vec![name.clone()];
|
|
|
|
for search_domain in self.search_domains.iter() {
|
|
if let Ok(combined) = name.clone().append_name(search_domain) {
|
|
names.push(combined);
|
|
}
|
|
}
|
|
|
|
names
|
|
}
|
|
|
|
/// Handle an individual response from a server.
|
|
///
|
|
/// Returns true if we got an answer from the server, and updated the internal cache.
|
|
/// This answer is used externally to set a timer, so that we don't wait the full DNS
|
|
/// timeout period for answers to come back.
|
|
async fn handle_response(&mut self, query: &Name, record: Record) -> bool {
|
|
let ttl = record.ttl();
|
|
|
|
let Some(rdata) = record.into_data() else {
|
|
tracing::error!("for some reason, couldn't process incoming message");
|
|
return false;
|
|
};
|
|
|
|
let response: IpAddr = match rdata {
|
|
RData::A(arec) => arec.0.into(),
|
|
RData::AAAA(arec) => arec.0.into(),
|
|
|
|
_ => {
|
|
tracing::warn!(record_type = %rdata.record_type(), "skipping unknown / unexpected record type");
|
|
return false;
|
|
}
|
|
};
|
|
|
|
let expire_at = Instant::now() + Duration::from_secs(ttl as u64);
|
|
|
|
match self.cache.entry(query.clone()) {
|
|
std::collections::hash_map::Entry::Vacant(vec) => {
|
|
vec.insert(vec![(expire_at, response.clone())]);
|
|
}
|
|
|
|
std::collections::hash_map::Entry::Occupied(mut occ) => {
|
|
clean_expired_entries(occ.get_mut());
|
|
occ.get_mut().push((expire_at, response.clone()));
|
|
}
|
|
}
|
|
|
|
true
|
|
}
|
|
}
|
|
|
|
/// Drop every entry of `list` whose expiry instant is not strictly in the
/// future, keeping the remaining entries in their original order.
fn clean_expired_entries(list: &mut Vec<(Instant, IpAddr)>) {
    let cutoff = Instant::now();
    list.retain(|entry| entry.0 > cutoff);
}
|
|
|
|
#[tokio::test]
|
|
async fn uhsure() {
|
|
let mut resolver = Resolver::new(&DnsConfig::default()).await.unwrap();
|
|
let name = Name::from_ascii("uhsure.com").unwrap();
|
|
let result = resolver.lookup(&name).await.unwrap();
|
|
println!("result = {:?}", result);
|
|
assert!(!result.is_empty());
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn manual_lookup() {
|
|
let mut stream: UdpClientStream<UdpSocket> = UdpClientStream::with_bind_addr_and_timeout(
|
|
SocketAddr::V4(std::net::SocketAddrV4::new(
|
|
std::net::Ipv4Addr::new(172, 19, 207, 1),
|
|
53,
|
|
)),
|
|
None,
|
|
Duration::from_secs(3),
|
|
)
|
|
.await
|
|
.expect("can create UDP DNS client stream");
|
|
|
|
let mut message = Message::new();
|
|
let name = Name::from_ascii("uhsure.com").unwrap();
|
|
message.add_query(Query::query(name.clone(), RecordType::A));
|
|
message.add_query(Query::query(name.clone(), RecordType::AAAA));
|
|
message.set_recursion_desired(true);
|
|
|
|
let request = DnsRequest::new(message, Default::default());
|
|
let mut responses = stream.send_message(request);
|
|
|
|
while let Some(response) = responses.next().await {
|
|
println!("response: {:?}", response);
|
|
}
|
|
|
|
unimplemented!();
|
|
}
|