Clean ups and some other stuff?

This commit is contained in:
2025-02-07 20:59:04 -08:00
parent 268ca2d1a5
commit 31cd34d280
17 changed files with 506 additions and 236 deletions

1
.gitignore vendored
View File

@@ -15,3 +15,4 @@ Cargo.lock
tarpaulin-report.html tarpaulin-report.html
proptest-regressions/ proptest-regressions/
config.toml

View File

@@ -1,37 +1,35 @@
use futures::stream::StreamExt;
use hickory_proto::udp::UdpStream;
use hush::config::server::ServerConfiguration; use hush::config::server::ServerConfiguration;
use std::net::{SocketAddr, SocketAddrV4, Ipv4Addr};
use std::process; use std::process;
use thiserror::Error;
use tokio::net::UdpSocket;
async fn run_dns() -> tokio::io::Result<()> { //async fn run_dns() -> tokio::io::Result<()> {
let socket = UdpSocket::bind("127.0.0.1:3553").await?; // let socket = UdpSocket::bind("127.0.0.1:3553").await?;
let mut buffer = [0; 65536]; // let mut buffer = [0; 65536];
//
tracing::info!("Bound socket at {}, starting main DNS handler loop.", socket.local_addr()?); // tracing::info!(
let (first_len, first_addr) = socket.recv_from(&mut buffer).await?; // "Bound socket at {}, starting main DNS handler loop.",
println!("Received {} bytes from {}", first_len, first_addr); // socket.local_addr()?
// );
let remote = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 3553)); // let (first_len, first_addr) = socket.recv_from(&mut buffer).await?;
let (mut stream, _handle) = UdpStream::with_bound(socket, remote); // println!("Received {} bytes from {}", first_len, first_addr);
//
while let Some(item) = stream.next().await { // let remote = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 3553));
match item { // let (mut stream, _handle) = UdpStream::with_bound(socket, remote);
Err(e) => eprintln!("Got an error: {:?}", e), //
Ok(msg) => { // while let Some(item) = stream.next().await {
println!("Got a message from {}", msg.addr()); // match item {
match msg.to_message() { // Err(e) => eprintln!("Got an error: {:?}", e),
Err(e) => println!(" ... but it was malformed ({e})."), // Ok(msg) => {
Ok(v) => println!(" ... and it was {v}"), // println!("Got a message from {}", msg.addr());
} // match msg.to_message() {
} // Err(e) => println!(" ... but it was malformed ({e})."),
} // Ok(v) => println!(" ... and it was {v}"),
} // }
// }
Ok(()) // }
} // }
//
// Ok(())
//}
#[cfg(not(tarpaulin_include))] #[cfg(not(tarpaulin_include))]
fn main() { fn main() {
@@ -57,7 +55,8 @@ fn main() {
}; };
let result = runtime.block_on(async move { let result = runtime.block_on(async move {
let version = std::env::var("CARGO_PKG_VERSION").unwrap_or_else(|_| "<unknown>".to_string()); let version =
std::env::var("CARGO_PKG_VERSION").unwrap_or_else(|_| "<unknown>".to_string());
tracing::info!(%version, "Starting Hush server (hushd)"); tracing::info!(%version, "Starting Hush server (hushd)");
hush::server::run(config).await hush::server::run(config).await
}); });

View File

@@ -95,15 +95,17 @@ impl FromStr for Target {
} }
} }
#[allow(unreachable_code,unused_variables)] #[allow(unreachable_code, unused_variables)]
async fn connect( async fn connect(
base_config: &ClientConfiguration, base_config: &ClientConfiguration,
target: &str, target: &str,
) -> error_stack::Result<(), OperationalError> { ) -> error_stack::Result<(), OperationalError> {
let resolver: Resolver = unimplemented!(); let mut resolver: Resolver = Resolver::new(&base_config.dns_config)
// let mut resolver = Resolver::new(&base_config.resolver) .await
// .await .change_context(OperationalError::Resolver)?;
// .change_context(OperationalError::DnsConfig)?; // let mut resolver = Resolver::new(&base_config.resolver)
// .await
// .change_context(OperationalError::DnsConfig)?;
let target = Target::from_str(target) let target = Target::from_str(target)
.change_context(OperationalError::UnableToParseHostAddress) .change_context(OperationalError::UnableToParseHostAddress)
.attach_printable_lazy(|| format!("target address '{}'", target))?; .attach_printable_lazy(|| format!("target address '{}'", target))?;
@@ -137,7 +139,7 @@ async fn connect(
.change_context(OperationalError::Connection)?; .change_context(OperationalError::Connection)?;
tracing::trace!("received their preamble"); tracing::trace!("received their preamble");
let _our_preamble = ssh::Preamble::default() ssh::Preamble::default()
.write(&mut stream) .write(&mut stream)
.instrument(stream_span) .instrument(stream_span)
.await .await

View File

@@ -9,9 +9,9 @@ pub mod resolver;
mod runtime; mod runtime;
pub mod server; pub mod server;
use tokio::runtime::Runtime;
use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::prelude::*; use tracing_subscriber::prelude::*;
use tokio::runtime::Runtime;
//impl ClientConfiguration { //impl ClientConfiguration {
// pub async fn try_from( // pub async fn try_from(
@@ -84,4 +84,3 @@ pub fn configured_runtime(
.worker_threads(runtime.tokio_worker_threads) .worker_threads(runtime.tokio_worker_threads)
.build() .build()
} }

View File

@@ -1,17 +1,19 @@
use clap::Parser; use crate::config::command_line::ClientArguments;
pub use crate::config::command_line::ClientCommand;
use crate::config::config_file::ConfigFile; use crate::config::config_file::ConfigFile;
use crate::config::error::ConfigurationError; use crate::config::error::ConfigurationError;
use crate::config::logging::LoggingConfiguration; use crate::config::logging::LoggingConfiguration;
use crate::config::resolver::DnsConfig;
use crate::config::runtime::RuntimeConfiguration; use crate::config::runtime::RuntimeConfiguration;
use crate::config::command_line::ClientArguments; use clap::Parser;
pub use crate::config::command_line::ClientCommand;
use std::ffi::OsString; use std::ffi::OsString;
use tracing_core::LevelFilter; use tracing_core::LevelFilter;
pub struct ClientConfiguration { pub struct ClientConfiguration {
pub runtime: RuntimeConfiguration, pub runtime: RuntimeConfiguration,
pub logging: LoggingConfiguration, pub logging: LoggingConfiguration,
config_file: Option<ConfigFile>, pub dns_config: DnsConfig,
_config_file: Option<ConfigFile>,
command: ClientCommand, command: ClientCommand,
} }
@@ -30,25 +32,32 @@ impl ClientConfiguration {
T: Into<OsString> + Clone, T: Into<OsString> + Clone,
{ {
let mut command_line_arguments = ClientArguments::try_parse_from(args)?; let mut command_line_arguments = ClientArguments::try_parse_from(args)?;
let mut config_file = ConfigFile::new(command_line_arguments.arguments.config_file.take())?; let mut _config_file =
ConfigFile::new(command_line_arguments.arguments.config_file.take())?;
let mut runtime = RuntimeConfiguration::default(); let mut runtime = RuntimeConfiguration::default();
let mut logging = LoggingConfiguration::default(); let mut logging = LoggingConfiguration::default();
// we prefer the command line to the config file, so first merge // we prefer the command line to the config file, so first merge
// in the config file so that when we later merge the command line, // in the config file so that when we later merge the command line,
// it overwrites any config file options. // it overwrites any config file options.
if let Some(config_file) = config_file.as_mut() { if let Some(config_file) = _config_file.as_mut() {
config_file.merge_standard_options_into(&mut runtime, &mut logging); config_file.merge_standard_options_into(&mut runtime, &mut logging);
} }
command_line_arguments.arguments.merge_standard_options_into(&mut runtime, &mut logging); command_line_arguments
.arguments
.merge_standard_options_into(&mut runtime, &mut logging);
if command_line_arguments.command.is_list_command() { if command_line_arguments.command.is_list_command() {
logging.filter = LevelFilter::ERROR; logging.filter = LevelFilter::ERROR;
} }
// FIXME!!
let dns_config = DnsConfig::default();
Ok(ClientConfiguration { Ok(ClientConfiguration {
runtime, runtime,
logging, logging,
config_file, dns_config,
_config_file,
command: command_line_arguments.command, command: command_line_arguments.command,
}) })
} }

View File

@@ -1,5 +1,5 @@
use crate::config::logging::{LogTarget, LoggingConfiguration};
use crate::config::console::ConsoleConfiguration; use crate::config::console::ConsoleConfiguration;
use crate::config::logging::{LogTarget, LoggingConfiguration};
use crate::config::runtime::RuntimeConfiguration; use crate::config::runtime::RuntimeConfiguration;
use clap::{Args, Parser, Subcommand}; use clap::{Args, Parser, Subcommand};
use std::net::SocketAddr; use std::net::SocketAddr;

View File

@@ -43,7 +43,7 @@ impl Arbitrary for ConfigFile {
keyed_section(ServerConfig::arbitrary()), keyed_section(ServerConfig::arbitrary()),
) )
.prop_map( .prop_map(
|(runtime, logging, resolver,sockets, keys, defaults, servers)| ConfigFile { |(runtime, logging, resolver, sockets, keys, defaults, servers)| ConfigFile {
runtime, runtime,
logging, logging,
resolver, resolver,
@@ -88,88 +88,105 @@ pub enum Permissions {
impl From<Permissions> for rustix::fs::Mode { impl From<Permissions> for rustix::fs::Mode {
fn from(value: Permissions) -> Self { fn from(value: Permissions) -> Self {
match value { match value {
Permissions::User => Permissions::User => rustix::fs::Mode::RUSR | rustix::fs::Mode::WUSR,
rustix::fs::Mode::RUSR |
rustix::fs::Mode::WUSR,
Permissions::Group => Permissions::Group => rustix::fs::Mode::RUSR | rustix::fs::Mode::WUSR,
rustix::fs::Mode::RUSR |
rustix::fs::Mode::WUSR,
Permissions::UserGroup => Permissions::UserGroup => {
rustix::fs::Mode::RUSR | rustix::fs::Mode::RUSR
rustix::fs::Mode::WUSR | | rustix::fs::Mode::WUSR
rustix::fs::Mode::RGRP | | rustix::fs::Mode::RGRP
rustix::fs::Mode::WGRP, | rustix::fs::Mode::WGRP
}
Permissions::Everyone => Permissions::Everyone => {
rustix::fs::Mode::RUSR | rustix::fs::Mode::RUSR
rustix::fs::Mode::WUSR | | rustix::fs::Mode::WUSR
rustix::fs::Mode::RGRP | | rustix::fs::Mode::RGRP
rustix::fs::Mode::WGRP | | rustix::fs::Mode::WGRP
rustix::fs::Mode::ROTH | | rustix::fs::Mode::ROTH
rustix::fs::Mode::WOTH, | rustix::fs::Mode::WOTH
}
} }
} }
} }
impl SocketConfig { impl SocketConfig {
pub async fn into_listener(self) -> Result<tokio::net::UnixListener, ConfigurationError> { pub async fn into_listener(self) -> Result<tokio::net::UnixListener, ConfigurationError> {
let base = tokio::net::UnixListener::bind(&self.path) let base = tokio::net::UnixListener::bind(&self.path).map_err(|error| {
.map_err(|error| ConfigurationError::CouldNotMakeSocket { ConfigurationError::CouldNotMakeSocket {
path: self.path.clone(),
error,
}
})?;
let user = self
.user
.map(|x| {
nix::unistd::User::from_name(&x)
.map_err(|error| ConfigurationError::CouldNotSetPerms {
thing: "find user".to_string(),
path: self.path.clone(),
error: error.into(),
})
.transpose()
.unwrap_or_else(|| {
Err(ConfigurationError::CouldNotSetPerms {
thing: "find user".to_string(),
path: self.path.clone(),
error: std::io::Error::new(
std::io::ErrorKind::NotFound,
"could not find user",
),
})
})
})
.transpose()?;
let group = self
.group
.map(|x| {
nix::unistd::Group::from_name(&x)
.map_err(|error| ConfigurationError::CouldNotSetPerms {
thing: "find user".to_string(),
path: self.path.clone(),
error: error.into(),
})
.transpose()
.unwrap_or_else(|| {
Err(ConfigurationError::CouldNotSetPerms {
thing: "find user".to_string(),
path: self.path.clone(),
error: std::io::Error::new(
std::io::ErrorKind::NotFound,
"could not find user",
),
})
})
})
.transpose()?;
if user.is_some() || group.is_some() {
std::os::unix::fs::chown(
&self.path,
user.map(|x| x.uid.as_raw()),
group.map(|x| x.gid.as_raw()),
)
.map_err(|error| ConfigurationError::CouldNotSetPerms {
thing: "set user/group ownership".to_string(),
path: self.path.clone(), path: self.path.clone(),
error, error,
})?; })?;
let user = self.user.map(|x| {
nix::unistd::User::from_name(&x)
.map_err(|error| ConfigurationError::CouldNotSetPerms {
thing: "find user".to_string(),
path: self.path.clone(),
error: error.into(),
})
.transpose()
.unwrap_or_else(|| Err(ConfigurationError::CouldNotSetPerms {
thing: "find user".to_string(),
path: self.path.clone(),
error: std::io::Error::new(std::io::ErrorKind::NotFound, "could not find user"),
}))
}).transpose()?;
let group = self.group.map(|x| {
nix::unistd::Group::from_name(&x)
.map_err(|error| ConfigurationError::CouldNotSetPerms {
thing: "find user".to_string(),
path: self.path.clone(),
error: error.into(),
})
.transpose()
.unwrap_or_else(|| Err(ConfigurationError::CouldNotSetPerms {
thing: "find user".to_string(),
path: self.path.clone(),
error: std::io::Error::new(std::io::ErrorKind::NotFound, "could not find user"),
}))
}).transpose()?;
if user.is_some() || group.is_some() {
std::os::unix::fs::chown(&self.path, user.map(|x| x.uid.as_raw()), group.map(|x| x.gid.as_raw()))
.map_err(|error|
ConfigurationError::CouldNotSetPerms {
thing: "set user/group ownership".to_string(),
path: self.path.clone(),
error,
}
)?;
} }
let perms = self.permissions.unwrap_or(Permissions::UserGroup); let perms = self.permissions.unwrap_or(Permissions::UserGroup);
rustix::fs::chmod(&self.path, perms.into()) rustix::fs::chmod(&self.path, perms.into()).map_err(|error| {
.map_err(|error| ConfigurationError::CouldNotSetPerms { ConfigurationError::CouldNotSetPerms {
thing: "set permissions".to_string(), thing: "set permissions".to_string(),
path: self.path.clone(), path: self.path.clone(),
error: error.into(), error: error.into(),
})?; }
})?;
Ok(base) Ok(base)
} }
@@ -180,12 +197,18 @@ impl Arbitrary for SocketConfig {
type Parameters = (); type Parameters = ();
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy { fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
(PathBuf::arbitrary(), (
proptest::option::of(proptest::string::string_regex("[a-zA-Z0-9]{1,30}").unwrap()), PathBuf::arbitrary(),
proptest::option::of(proptest::string::string_regex("[a-zA-Z0-9]{1,30}").unwrap()), proptest::option::of(proptest::string::string_regex("[a-zA-Z0-9]{1,30}").unwrap()),
proptest::option::of(Permissions::arbitrary())) proptest::option::of(proptest::string::string_regex("[a-zA-Z0-9]{1,30}").unwrap()),
.prop_map(|(path, user, group, permissions)| proptest::option::of(Permissions::arbitrary()),
SocketConfig{ path, user, group, permissions }) )
.prop_map(|(path, user, group, permissions)| SocketConfig {
path,
user,
group,
permissions,
})
.boxed() .boxed()
} }
} }
@@ -200,7 +223,8 @@ impl Arbitrary for Permissions {
proptest::strategy::Just(Permissions::Group), proptest::strategy::Just(Permissions::Group),
proptest::strategy::Just(Permissions::UserGroup), proptest::strategy::Just(Permissions::UserGroup),
proptest::strategy::Just(Permissions::Everyone), proptest::strategy::Just(Permissions::Everyone),
].boxed() ]
.boxed()
} }
} }

View File

@@ -10,7 +10,6 @@ use proptest::strategy::{BoxedStrategy, Strategy};
use std::collections::HashSet; use std::collections::HashSet;
use std::str::FromStr; use std::str::FromStr;
#[derive(Debug)] #[derive(Debug)]
pub struct ClientConnectionOpts { pub struct ClientConnectionOpts {
pub key_exchange_algorithms: Vec<KeyExchangeAlgorithm>, pub key_exchange_algorithms: Vec<KeyExchangeAlgorithm>,
@@ -68,10 +67,12 @@ impl Arbitrary for ClientConnectionOpts {
} }
} }
fn finalize_options<T,E>(values: HashSet<&str>) -> Result<Vec<T>, E> fn finalize_options<T, E>(values: HashSet<&str>) -> Result<Vec<T>, E>
where T: FromStr<Err=E> where
T: FromStr<Err = E>,
{ {
Ok(values.into_iter().map(T::from_str).collect::<Result<Vec<T>,E>>()?) values
.into_iter()
.map(T::from_str)
.collect::<Result<Vec<T>, E>>()
} }

View File

@@ -42,6 +42,24 @@ impl Default for DnsConfig {
} }
} }
#[cfg(test)]
impl DnsConfig {
pub(crate) fn empty() -> Self {
DnsConfig {
built_in: None,
local_domain: None,
search_domains: vec![],
name_servers: vec![],
retry_attempts: None,
cache_size: None,
max_concurrent_requests_for_query: None,
preserve_intermediates: None,
shuffle_dns_servers: None,
allow_mdns: None,
}
}
}
impl Arbitrary for DnsConfig { impl Arbitrary for DnsConfig {
type Parameters = bool; type Parameters = bool;
type Strategy = BoxedStrategy<Self>; type Strategy = BoxedStrategy<Self>;
@@ -178,12 +196,16 @@ impl Arbitrary for NameServerConfig {
proptest::option::of(u64::arbitrary()), proptest::option::of(u64::arbitrary()),
proptest::option::of(SocketAddr::arbitrary()), proptest::option::of(SocketAddr::arbitrary()),
) )
.prop_map(|(mut address, timeout_in_seconds, mut bind_address)| { .prop_map(|(mut address, mut timeout_in_seconds, mut bind_address)| {
clear_flow_and_scope_info(&mut address); clear_flow_and_scope_info(&mut address);
if let Some(bind_address) = bind_address.as_mut() { if let Some(bind_address) = bind_address.as_mut() {
clear_flow_and_scope_info(bind_address); clear_flow_and_scope_info(bind_address);
} }
if let Some(timeout_in_seconds) = timeout_in_seconds.as_mut() {
*timeout_in_seconds &= 0x7FFF_FFFF_FFFF_FFFF;
}
NameServerConfig { NameServerConfig {
address, address,
timeout_in_seconds, timeout_in_seconds,

View File

@@ -15,5 +15,3 @@ impl Default for RuntimeConfiguration {
} }
} }
} }

View File

@@ -1,9 +1,9 @@
use clap::Parser;
use crate::config::config_file::{ConfigFile, SocketConfig};
use crate::config::command_line::ServerArguments; use crate::config::command_line::ServerArguments;
use crate::config::config_file::{ConfigFile, SocketConfig};
use crate::config::error::ConfigurationError; use crate::config::error::ConfigurationError;
use crate::config::logging::LoggingConfiguration; use crate::config::logging::LoggingConfiguration;
use crate::config::runtime::RuntimeConfiguration; use crate::config::runtime::RuntimeConfiguration;
use clap::Parser;
use std::collections::HashMap; use std::collections::HashMap;
use std::ffi::OsString; use std::ffi::OsString;
use tokio::net::UnixListener; use tokio::net::UnixListener;
@@ -55,7 +55,9 @@ impl ServerConfiguration {
} }
/// Generate a series of UNIX sockets that clients will attempt to connect to. /// Generate a series of UNIX sockets that clients will attempt to connect to.
pub async fn generate_listener_sockets(&mut self) -> Result<HashMap<String, UnixListener>, ConfigurationError> { pub async fn generate_listener_sockets(
&mut self,
) -> Result<HashMap<String, UnixListener>, ConfigurationError> {
let mut results = HashMap::new(); let mut results = HashMap::new();
for (name, config) in std::mem::take(&mut self.sockets).into_iter() { for (name, config) in std::mem::take(&mut self.sockets).into_iter() {

View File

@@ -49,10 +49,25 @@ impl FromStr for Host {
return Ok(Host::IPv6(addr)); return Ok(Host::IPv6(addr));
} }
if let Some(prefix_removed) = s.strip_prefix('[') {
if let Some(cleaned) = prefix_removed.strip_suffix(']') {
match Ipv6Addr::from_str(cleaned) {
Ok(addr) => return Ok(Host::IPv6(addr)),
Err(error) => {
return Err(HostParseError::CouldNotParseIPv6 {
address: s.to_string(),
error,
})
}
}
}
}
if let Ok(name) = Name::from_utf8(s) { if let Ok(name) = Name::from_utf8(s) {
return Ok(Host::Hostname(name)); return Ok(Host::Hostname(name));
} }
println!(" ... not a hostname");
Err(HostParseError::InvalidHostname { Err(HostParseError::InvalidHostname {
hostname: s.to_string(), hostname: s.to_string(),
}) })

View File

@@ -1,5 +1,6 @@
use crate::config::resolver::{DnsConfig, NameServerConfig}; use crate::config::resolver::{DnsConfig, NameServerConfig};
use error_stack::report; use error_stack::report;
use futures::future::select;
use futures::stream::{SelectAll, StreamExt}; use futures::stream::{SelectAll, StreamExt};
use hickory_client::rr::Name; use hickory_client::rr::Name;
use hickory_proto::error::ProtoError; use hickory_proto::error::ProtoError;
@@ -13,9 +14,14 @@ use hickory_proto::xfer::dns_request::DnsRequest;
use hickory_proto::xfer::{DnsRequestSender, DnsResponseStream}; use hickory_proto::xfer::{DnsRequestSender, DnsResponseStream};
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
#[cfg(test)]
use std::str::FromStr;
use std::sync::Arc;
use thiserror::Error; use thiserror::Error;
use tokio::net::UdpSocket; use tokio::net::UdpSocket;
use tokio::time::{Duration, Instant}; use tokio::sync::{oneshot, Mutex};
use tokio::task::JoinSet;
use tokio::time::{timeout_at, Duration, Instant};
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum ResolverConfigError { pub enum ResolverConfigError {
@@ -44,8 +50,20 @@ pub enum ResolveError {
pub struct Resolver { pub struct Resolver {
search_domains: Vec<Name>, search_domains: Vec<Name>,
max_time_to_wait_for_initial: Duration,
time_to_wait_after_first: Duration,
time_to_wait_for_lingering: Duration,
state: Arc<Mutex<ResolverState>>,
}
pub struct ResolverState {
client_connections: Vec<(NameServerConfig, UdpClientStream<UdpSocket>)>, client_connections: Vec<(NameServerConfig, UdpClientStream<UdpSocket>)>,
cache: HashMap<Name, Vec<(Instant, IpAddr)>>, cache: HashMap<Name, Vec<DnsResolution>>,
}
pub struct DnsResolution {
address: IpAddr,
expires_at: Instant,
} }
impl Resolver { impl Resolver {
@@ -76,7 +94,7 @@ impl Resolver {
for name_server_config in config.name_servers() { for name_server_config in config.name_servers() {
let stream = UdpClientStream::with_bind_addr_and_timeout( let stream = UdpClientStream::with_bind_addr_and_timeout(
name_server_config.address, name_server_config.address,
name_server_config.bind_address.clone(), name_server_config.bind_address,
Duration::from_secs(name_server_config.timeout_in_seconds.unwrap_or(3)), Duration::from_secs(name_server_config.timeout_in_seconds.unwrap_or(3)),
) )
.await .await
@@ -90,54 +108,110 @@ impl Resolver {
Ok(Resolver { Ok(Resolver {
search_domains, search_domains,
client_connections, // FIXME: All of these should be configurable
cache: HashMap::new(), max_time_to_wait_for_initial: Duration::from_millis(150),
time_to_wait_after_first: Duration::from_millis(50),
time_to_wait_for_lingering: Duration::from_secs(2),
state: Arc::new(Mutex::new(ResolverState {
client_connections,
cache: HashMap::new(),
})),
}) })
} }
/// Look up the address of the given name, returning either a set of results /// Look up the address of the given name, returning either a set of results
/// we've received in a reasonable amount of time, or an error. /// we've received in a reasonable amount of time, or an error.
pub async fn lookup(&mut self, name: &Name) -> Result<HashSet<IpAddr>, ResolveError> { pub async fn lookup(&self, name: &Name) -> Result<HashSet<IpAddr>, ResolveError> {
let names = self.expand_name(name); let names = self.expand_name(name);
let mut response_stream = self.create_response_stream(names).await;
let cached_values = self.cached_lookup(&names).await;
if !cached_values.is_empty() {
return Ok(cached_values);
}
let mut response_stream = self.create_response_stream(&names).await;
if response_stream.is_empty() { if response_stream.is_empty() {
return Err(ResolveError::NoServersAvailable); return Err(ResolveError::NoServersAvailable);
} }
let mut first_error = None; let get_first_by = Instant::now() + self.max_time_to_wait_for_initial;
let got_responses =
get_responses_by(&self.state, get_first_by, name, &mut response_stream).await?;
while let Some(response) = response_stream.next().await { if got_responses {
match response { let get_rest_by = Instant::now() + self.time_to_wait_after_first;
Err(e) => { let _ = get_responses_by(&self.state, get_rest_by, name, &mut response_stream).await;
if first_error.is_none() { }
first_error = Some(e);
}
}
Ok(response) => { let lingering_deadline = Instant::now() + self.time_to_wait_for_lingering;
for answer in response.into_message().take_answers() { let state_copy = self.state.clone();
self.handle_response(name, answer).await; let name_copy = name.clone();
tokio::task::spawn(async move {
let _ = get_responses_by(
&state_copy,
lingering_deadline,
&name_copy,
&mut response_stream,
)
.await;
});
let after_search_lookup = self.cached_lookup(&names).await;
if after_search_lookup.is_empty() {
Err(ResolveError::NoResponses)
} else {
Ok(after_search_lookup)
}
}
/// Look up any cached IP address we might have for the given names.
///
/// As a side effect, this function will clean out any expired entries in the
/// cache.
async fn cached_lookup(&self, names: &[Name]) -> HashSet<IpAddr> {
let mut retval = HashSet::new();
let mut state = self.state.lock().await;
for name in names {
if let Some(mut existing) = state.cache.remove(name) {
let now = Instant::now();
existing.retain(|item| {
let keeper = item.expires_at > now;
if keeper {
retval.insert(item.address);
} }
keeper
});
if !existing.is_empty() {
state.cache.insert(name.clone(), existing);
} }
} }
} }
match first_error { retval
None => Err(ResolveError::NoResponses),
Some(error) => Err(ResolveError::ResponseError { error }),
}
} }
async fn create_response_stream(&mut self, names: Vec<Name>) -> SelectAll<DnsResponseStream> { /// Reach out to all of our current connections, and start the process of getting
let connections = std::mem::take(&mut self.client_connections); /// responses.
///
/// If the clients are shut down, we'll try to create new connections to them, or
/// remove them from our connection list if we can't. If there are no connections,
/// the `SelectAll` will be empty, and callers are advised to check for this
/// condition and inform the user that they're out of useful DNS servers.
async fn create_response_stream(&self, names: &[Name]) -> SelectAll<DnsResponseStream> {
let mut state = self.state.lock().await;
let connections = std::mem::take(&mut state.client_connections);
let mut response_stream = futures::stream::SelectAll::new(); let mut response_stream = futures::stream::SelectAll::new();
for (config, mut client) in connections.into_iter() { for (config, mut client) in connections.into_iter() {
if client.is_shutdown() { if client.is_shutdown() {
let stream = UdpClientStream::with_bind_addr_and_timeout( let stream = UdpClientStream::with_bind_addr_and_timeout(
config.address, config.address,
config.bind_address.clone(), config.bind_address,
Duration::from_secs(config.timeout_in_seconds.unwrap_or(3)), Duration::from_secs(config.timeout_in_seconds.unwrap_or(3)),
) )
.await; .await;
@@ -158,7 +232,7 @@ impl Resolver {
let request = DnsRequest::new(message, Default::default()); let request = DnsRequest::new(message, Default::default());
response_stream.push(client.send_message(request)); response_stream.push(client.send_message(request));
self.client_connections.push((config, client)); state.client_connections.push((config, client));
} }
response_stream response_stream
@@ -177,85 +251,182 @@ impl Resolver {
names names
} }
/// Handle an individual response from a server. /// Run a cleaner task on this Resolver object.
/// ///
/// Returns true if we got an answer from the server, and updated the internal cache. /// The task will run in the given JoinSet, set it can be easily killed, tracked, waited
/// This answer is used externally to set a timer, so that we don't wait the full DNS /// for as desired. Every `period` interval, it will grab the state lock and clean out
/// timeout period for answers to come back. /// any entries that are timed out. Ideally, you shouldn't set this too frequently, so
async fn handle_response(&mut self, query: &Name, record: Record) -> bool { /// that this process doesn't interfere with other tasks trying to use the resolver.
let ttl = record.ttl(); ///
/// This task will only weakly hang on to the resolver state, so that if the Resolver
/// object is dropped it will quietly stop running. It will also stop running if it sees
/// a message on the kill signal, if one is provided, or if one is provided and the other
/// end drops the sender. (So if you don't want a kill signal, send `None`, or you might
/// accidentally kill your GC task too early.)
pub async fn run_resolver_gc(
&self,
task_set: &mut JoinSet<Result<(), ()>>,
mut kill_signal: Option<oneshot::Receiver<()>>,
interval: Duration,
) {
let weak_locked_state = Arc::downgrade(&self.state);
let Some(rdata) = record.into_data() else { task_set.spawn(async move {
tracing::error!("for some reason, couldn't process incoming message"); loop {
return false; tokio::time::sleep(interval).await;
let Some(locked_state) = weak_locked_state.upgrade() else {
break Ok(());
};
let mut state = if let Some(kill_signal) = kill_signal.as_mut() {
let locker = locked_state.lock();
futures::pin_mut!(locker);
match select(locker, kill_signal).await {
futures::future::Either::Left((state, _)) => state,
futures::future::Either::Right((_, _)) => break Ok(()),
}
} else {
locked_state.lock().await
};
let now = Instant::now();
state.cache.retain(|_, value| {
value.retain(|resolution| resolution.expires_at < now);
!value.is_empty()
});
drop(state);
}
});
}
/// Inject a mapping into the resolver, with the given TTL.
///
/// This is largely used for testing purposes, but could be used if there are static
/// addresses you want to always resolve in a particular way. For that use case, you
/// probably want to add absurdly-long TTLs for those names.
pub async fn inject_resolution(&self, name: Name, address: IpAddr, ttl: Duration) {
let resolution = DnsResolution {
address,
expires_at: Instant::now() + ttl,
}; };
let response: IpAddr = match rdata { let mut state = self.state.lock().await;
RData::A(arec) => arec.0.into(), match state.cache.entry(name) {
RData::AAAA(arec) => arec.0.into(), std::collections::hash_map::Entry::Vacant(vac) => {
vac.insert(vec![resolution]);
_ => {
tracing::warn!(record_type = %rdata.record_type(), "skipping unknown / unexpected record type");
return false;
} }
};
let expire_at = Instant::now() + Duration::from_secs(ttl as u64);
match self.cache.entry(query.clone()) {
std::collections::hash_map::Entry::Vacant(vec) => {
vec.insert(vec![(expire_at, response.clone())]);
}
std::collections::hash_map::Entry::Occupied(mut occ) => { std::collections::hash_map::Entry::Occupied(mut occ) => {
clean_expired_entries(occ.get_mut()); occ.get_mut().push(resolution);
occ.get_mut().push((expire_at, response.clone()));
} }
} }
drop(state);
true
} }
} }
fn clean_expired_entries(list: &mut Vec<(Instant, IpAddr)>) { /// Get all the responses into cache that occur before the given timestamp.
///
/// Returns an error if all we get are errors while trying to read from these servers,
/// otherwise returns Ok(true) if we received some interesting responses, or Ok(false)
/// if we didn't. Note that receiving good responses wins out over receiving errors, so
/// if we receive 2 good responses and 3 errors, this function will return Ok(true).
async fn get_responses_by(
state: &Arc<Mutex<ResolverState>>,
deadline: Instant,
name: &Name,
stream: &mut SelectAll<DnsResponseStream>,
) -> Result<bool, ResolveError> {
let mut received_error = None;
let mut got_something = false;
loop {
match timeout_at(deadline, stream.next()).await {
Err(_) => return received_error.unwrap_or(Ok(got_something)),
Ok(None) if got_something => return Ok(got_something),
Ok(None) => return received_error.unwrap_or(Ok(false)),
Ok(Some(Err(error))) => {
if received_error.is_none() {
received_error = Some(Err(ResolveError::ResponseError { error }));
}
}
Ok(Some(Ok(response))) => {
for answer in response.into_message().take_answers() {
got_something |= handle_response(state, name, answer).await;
}
}
}
}
}
/// Handle an individual response from a server.
///
/// Returns true if we got an answer from the server, and updates the internal cache.
/// This function will never return an error, although there may be some odd processing
/// situations that could lead it to printing warnings to the logs.
async fn handle_response(state: &Arc<Mutex<ResolverState>>, query: &Name, record: Record) -> bool {
let ttl = record.ttl();
let Some(rdata) = record.into_data() else {
tracing::error!("for some reason, couldn't process incoming message");
return false;
};
let address: IpAddr = match rdata {
RData::A(arec) => arec.0.into(),
RData::AAAA(arec) => arec.0.into(),
_ => {
tracing::warn!(record_type = %rdata.record_type(), "skipping unknown / unexpected record type");
return false;
}
};
let now = Instant::now(); let now = Instant::now();
list.retain(|(expire_at, _)| expire_at > &now); let expires_at = now + Duration::from_secs(ttl as u64);
let resolution = DnsResolution {
address,
expires_at,
};
let mut state = state.lock().await;
match state.cache.entry(query.clone()) {
std::collections::hash_map::Entry::Vacant(vec) => {
vec.insert(vec![resolution]);
}
std::collections::hash_map::Entry::Occupied(mut occ) => {
let vec = occ.get_mut();
vec.retain(|x| x.expires_at > now);
vec.push(resolution);
}
}
drop(state);
true
}
#[tokio::test]
async fn fetch_cached() {
let resolver = Resolver::new(&DnsConfig::empty()).await.unwrap();
let name = Name::from_utf8("name.foo").unwrap();
let addr = IpAddr::from_str("1.2.4.5").unwrap();
resolver
.inject_resolution(name.clone(), addr.clone(), Duration::from_secs(100000))
.await;
let read = resolver.lookup(&name).await.unwrap();
assert!(read.contains(&addr));
} }
#[tokio::test] #[tokio::test]
async fn uhsure() { async fn uhsure() {
let mut resolver = Resolver::new(&DnsConfig::default()).await.unwrap(); let resolver = Resolver::new(&DnsConfig::default()).await.unwrap();
let name = Name::from_ascii("uhsure.com").unwrap(); let name = Name::from_ascii("uhsure.com").unwrap();
let result = resolver.lookup(&name).await.unwrap(); let result = resolver.lookup(&name).await.unwrap();
println!("result = {:?}", result); println!("result = {:?}", result);
assert!(!result.is_empty()); assert!(!result.is_empty());
} }
#[tokio::test]
async fn manual_lookup() {
let mut stream: UdpClientStream<UdpSocket> = UdpClientStream::with_bind_addr_and_timeout(
SocketAddr::V4(std::net::SocketAddrV4::new(
std::net::Ipv4Addr::new(172, 19, 207, 1),
53,
)),
None,
Duration::from_secs(3),
)
.await
.expect("can create UDP DNS client stream");
let mut message = Message::new();
let name = Name::from_ascii("uhsure.com").unwrap();
message.add_query(Query::query(name.clone(), RecordType::A));
message.add_query(Query::query(name.clone(), RecordType::AAAA));
message.set_recursion_desired(true);
let request = DnsRequest::new(message, Default::default());
let mut responses = stream.send_message(request);
while let Some(response) = responses.next().await {
println!("response: {:?}", response);
}
unimplemented!();
}

View File

@@ -48,4 +48,6 @@ pub enum OperationalError {
InvalidHostname(String), InvalidHostname(String),
#[error("Unable to parse host address")] #[error("Unable to parse host address")]
UnableToParseHostAddress, UnableToParseHostAddress,
#[error("Unable to configure resolver")]
Resolver,
} }

View File

@@ -14,15 +14,14 @@ pub enum TopLevelError {
} }
pub async fn run(mut config: ServerConfiguration) -> error_stack::Result<(), TopLevelError> { pub async fn run(mut config: ServerConfiguration) -> error_stack::Result<(), TopLevelError> {
let mut server_state = state::ServerState::default(); let _server_state = state::ServerState::default();
let listeners = config.generate_listener_sockets().await let listeners = config
.generate_listener_sockets()
.await
.change_context(TopLevelError::ConfigurationError)?; .change_context(TopLevelError::ConfigurationError)?;
for (name, listener) in listeners.into_iter() { for (_name, _listener) in listeners.into_iter() {}
}
Ok(()) Ok(())
} }

View File

@@ -10,6 +10,7 @@ pub struct SocketServer {
listener: UnixListener, listener: UnixListener,
} }
#[allow(unused)]
impl SocketServer { impl SocketServer {
/// Create a new server that will handle inputs from the client program. /// Create a new server that will handle inputs from the client program.
/// ///
@@ -18,10 +19,10 @@ impl SocketServer {
/// method will take ownership of the object. /// method will take ownership of the object.
pub fn new(name: String, listener: UnixListener) -> Self { pub fn new(name: String, listener: UnixListener) -> Self {
let path = listener let path = listener
.local_addr() .local_addr()
.map(|x| x.as_pathname().map(|p| format!("{}", p.display()))) .map(|x| x.as_pathname().map(|p| format!("{}", p.display())))
.unwrap_or_else(|_| None) .unwrap_or_else(|_| None)
.unwrap_or_else(|| format!("unknown")); .unwrap_or_else(|| "<unknown>".to_string());
tracing::trace!(%name, %path, "Creating new socket listener"); tracing::trace!(%name, %path, "Creating new socket listener");
@@ -41,7 +42,11 @@ impl SocketServer {
/// there, the core task should be unaffected. /// there, the core task should be unaffected.
pub async fn start(mut self) -> error_stack::Result<(), TopLevelError> { pub async fn start(mut self) -> error_stack::Result<(), TopLevelError> {
loop { loop {
let (stream, addr) = self.listener.accept().await.change_context(TopLevelError::SocketHandlerFailure)?; let (stream, addr) = self
.listener
.accept()
.await
.change_context(TopLevelError::SocketHandlerFailure)?;
let remote_addr = addr let remote_addr = addr
.as_pathname() .as_pathname()
.map(|x| x.display()) .map(|x| x.display())
@@ -58,8 +63,7 @@ impl SocketServer {
self.num_sessions_run += 1; self.num_sessions_run += 1;
tokio::task::spawn(Self::run_session(stream) tokio::task::spawn(Self::run_session(stream).instrument(span))
.instrument(span))
.await .await
.change_context(TopLevelError::SocketHandlerFailure)?; .change_context(TopLevelError::SocketHandlerFailure)?;
} }
@@ -70,6 +74,5 @@ impl SocketServer {
/// This is here because it's convenient, not because it shares state (obviously, /// This is here because it's convenient, not because it shares state (obviously,
/// given the type signature). But it's somewhat logically associated with this type, /// given the type signature). But it's somewhat logically associated with this type,
/// so it seems reasonable to make it an associated function. /// so it seems reasonable to make it an associated function.
async fn run_session(handle: UnixStream) { async fn run_session(handle: UnixStream) {}
}
} }

View File

@@ -6,3 +6,26 @@ pub struct ServerState {
/// The set of top level tasks that are currently running. /// The set of top level tasks that are currently running.
top_level_tasks: JoinSet<Result<(), TopLevelError>>, top_level_tasks: JoinSet<Result<(), TopLevelError>>,
} }
impl ServerState {
/// Block until all the current top level tasks have closed.
///
/// This will log any errors found in these tasks to the current log,
/// if there is one, but will otherwise drop them.
#[allow(unused)]
pub async fn shutdown(&mut self) {
while let Some(next) = self.top_level_tasks.join_next_with_id().await {
match next {
Err(e) => tracing::error!(id = %e.id(), "Failed to attach to top-level task"),
Ok((id, Err(e))) => {
tracing::error!(%id, "Top-level server error: {}", e);
}
Ok((id, Ok(()))) => {
tracing::debug!(%id, "Cleanly closed server task.");
}
}
}
}
}