This builds, I guess.

This commit is contained in:
2025-01-12 14:10:23 -08:00
parent b30823a502
commit 268ca2d1a5
25 changed files with 1750 additions and 735 deletions

261
src/network/resolver.rs Normal file
View File

@@ -0,0 +1,261 @@
use crate::config::resolver::{DnsConfig, NameServerConfig};
use error_stack::report;
use futures::stream::{SelectAll, StreamExt};
use hickory_client::rr::Name;
use hickory_proto::error::ProtoError;
use hickory_proto::op::message::Message;
use hickory_proto::op::query::Query;
use hickory_proto::rr::record_data::RData;
use hickory_proto::rr::resource::Record;
use hickory_proto::rr::RecordType;
use hickory_proto::udp::UdpClientStream;
use hickory_proto::xfer::dns_request::DnsRequest;
use hickory_proto::xfer::{DnsRequestSender, DnsResponseStream};
use std::collections::{HashMap, HashSet};
use std::net::{IpAddr, SocketAddr};
use thiserror::Error;
use tokio::net::UdpSocket;
use tokio::time::{Duration, Instant};
#[derive(Debug, Error)]
pub enum ResolverConfigError {
#[error("Bad local domain name '{name}' provided: '{error}'")]
BadDomainName { name: String, error: ProtoError },
#[error("Bad domain name for search '{name}' provided: '{error}'")]
BadSearchName { name: String, error: ProtoError },
#[error("Couldn't set up client for server at '{address}': '{error}'")]
FailedToCreateDnsClient {
address: SocketAddr,
error: ProtoError,
},
#[error("No DNS servers found to search, and mDNS not enabled")]
NoHosts,
}
#[derive(Debug, Error)]
pub enum ResolveError {
#[error("No servers available for query")]
NoServersAvailable,
#[error("No responses found for query")]
NoResponses,
#[error("Error reading response: {error}")]
ResponseError { error: ProtoError },
}
pub struct Resolver {
search_domains: Vec<Name>,
client_connections: Vec<(NameServerConfig, UdpClientStream<UdpSocket>)>,
cache: HashMap<Name, Vec<(Instant, IpAddr)>>,
}
impl Resolver {
/// Create a new DNS resolution engine for use by some part of the system.
pub async fn new(config: &DnsConfig) -> error_stack::Result<Self, ResolverConfigError> {
let mut search_domains = Vec::new();
if let Some(local) = config.local_domain.as_ref() {
search_domains.push(Name::from_utf8(local).map_err(|e| {
report!(ResolverConfigError::BadDomainName {
name: local.clone(),
error: e,
})
})?);
}
for local_domain in config.search_domains.iter() {
search_domains.push(Name::from_utf8(local_domain).map_err(|e| {
report!(ResolverConfigError::BadSearchName {
name: local_domain.clone(),
error: e,
})
})?);
}
let mut client_connections = Vec::new();
for name_server_config in config.name_servers() {
let stream = UdpClientStream::with_bind_addr_and_timeout(
name_server_config.address,
name_server_config.bind_address.clone(),
Duration::from_secs(name_server_config.timeout_in_seconds.unwrap_or(3)),
)
.await
.map_err(|error| ResolverConfigError::FailedToCreateDnsClient {
address: name_server_config.address,
error,
})?;
client_connections.push((name_server_config, stream));
}
Ok(Resolver {
search_domains,
client_connections,
cache: HashMap::new(),
})
}
/// Look up the address of the given name, returning either a set of results
/// we've received in a reasonable amount of time, or an error.
pub async fn lookup(&mut self, name: &Name) -> Result<HashSet<IpAddr>, ResolveError> {
let names = self.expand_name(name);
let mut response_stream = self.create_response_stream(names).await;
if response_stream.is_empty() {
return Err(ResolveError::NoServersAvailable);
}
let mut first_error = None;
while let Some(response) = response_stream.next().await {
match response {
Err(e) => {
if first_error.is_none() {
first_error = Some(e);
}
}
Ok(response) => {
for answer in response.into_message().take_answers() {
self.handle_response(name, answer).await;
}
}
}
}
match first_error {
None => Err(ResolveError::NoResponses),
Some(error) => Err(ResolveError::ResponseError { error }),
}
}
async fn create_response_stream(&mut self, names: Vec<Name>) -> SelectAll<DnsResponseStream> {
let connections = std::mem::take(&mut self.client_connections);
let mut response_stream = futures::stream::SelectAll::new();
for (config, mut client) in connections.into_iter() {
if client.is_shutdown() {
let stream = UdpClientStream::with_bind_addr_and_timeout(
config.address,
config.bind_address.clone(),
Duration::from_secs(config.timeout_in_seconds.unwrap_or(3)),
)
.await;
match stream {
Ok(stream) => client = stream,
Err(_) => continue,
}
}
let mut message = Message::new();
for name in names.iter() {
message.add_query(Query::query(name.clone(), RecordType::A));
message.add_query(Query::query(name.clone(), RecordType::AAAA));
}
message.set_recursion_desired(true);
let request = DnsRequest::new(message, Default::default());
response_stream.push(client.send_message(request));
self.client_connections.push((config, client));
}
response_stream
}
/// Expand the given input name into the complete set of names that we should search for.
fn expand_name(&self, name: &Name) -> Vec<Name> {
let mut names = vec![name.clone()];
for search_domain in self.search_domains.iter() {
if let Ok(combined) = name.clone().append_name(search_domain) {
names.push(combined);
}
}
names
}
/// Handle an individual response from a server.
///
/// Returns true if we got an answer from the server, and updated the internal cache.
/// This answer is used externally to set a timer, so that we don't wait the full DNS
/// timeout period for answers to come back.
async fn handle_response(&mut self, query: &Name, record: Record) -> bool {
let ttl = record.ttl();
let Some(rdata) = record.into_data() else {
tracing::error!("for some reason, couldn't process incoming message");
return false;
};
let response: IpAddr = match rdata {
RData::A(arec) => arec.0.into(),
RData::AAAA(arec) => arec.0.into(),
_ => {
tracing::warn!(record_type = %rdata.record_type(), "skipping unknown / unexpected record type");
return false;
}
};
let expire_at = Instant::now() + Duration::from_secs(ttl as u64);
match self.cache.entry(query.clone()) {
std::collections::hash_map::Entry::Vacant(vec) => {
vec.insert(vec![(expire_at, response.clone())]);
}
std::collections::hash_map::Entry::Occupied(mut occ) => {
clean_expired_entries(occ.get_mut());
occ.get_mut().push((expire_at, response.clone()));
}
}
true
}
}
fn clean_expired_entries(list: &mut Vec<(Instant, IpAddr)>) {
let now = Instant::now();
list.retain(|(expire_at, _)| expire_at > &now);
}
#[tokio::test]
async fn uhsure() {
let mut resolver = Resolver::new(&DnsConfig::default()).await.unwrap();
let name = Name::from_ascii("uhsure.com").unwrap();
let result = resolver.lookup(&name).await.unwrap();
println!("result = {:?}", result);
assert!(!result.is_empty());
}
#[tokio::test]
async fn manual_lookup() {
let mut stream: UdpClientStream<UdpSocket> = UdpClientStream::with_bind_addr_and_timeout(
SocketAddr::V4(std::net::SocketAddrV4::new(
std::net::Ipv4Addr::new(172, 19, 207, 1),
53,
)),
None,
Duration::from_secs(3),
)
.await
.expect("can create UDP DNS client stream");
let mut message = Message::new();
let name = Name::from_ascii("uhsure.com").unwrap();
message.add_query(Query::query(name.clone(), RecordType::A));
message.add_query(Query::query(name.clone(), RecordType::AAAA));
message.set_recursion_desired(true);
let request = DnsRequest::new(message, Default::default());
let mut responses = stream.send_message(request);
while let Some(response) = responses.next().await {
println!("response: {:?}", response);
}
unimplemented!();
}