diff --git a/.gitignore b/.gitignore index a047120..0e53496 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,4 @@ Cargo.lock tarpaulin-report.html proptest-regressions/ +config.toml diff --git a/src/bin/hushd.rs b/src/bin/hushd.rs index 3c7b03e..ec84d6b 100644 --- a/src/bin/hushd.rs +++ b/src/bin/hushd.rs @@ -1,37 +1,35 @@ -use futures::stream::StreamExt; -use hickory_proto::udp::UdpStream; use hush::config::server::ServerConfiguration; -use std::net::{SocketAddr, SocketAddrV4, Ipv4Addr}; use std::process; -use thiserror::Error; -use tokio::net::UdpSocket; -async fn run_dns() -> tokio::io::Result<()> { - let socket = UdpSocket::bind("127.0.0.1:3553").await?; - let mut buffer = [0; 65536]; - - tracing::info!("Bound socket at {}, starting main DNS handler loop.", socket.local_addr()?); - let (first_len, first_addr) = socket.recv_from(&mut buffer).await?; - println!("Received {} bytes from {}", first_len, first_addr); - - let remote = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 3553)); - let (mut stream, _handle) = UdpStream::with_bound(socket, remote); - - while let Some(item) = stream.next().await { - match item { - Err(e) => eprintln!("Got an error: {:?}", e), - Ok(msg) => { - println!("Got a message from {}", msg.addr()); - match msg.to_message() { - Err(e) => println!(" ... but it was malformed ({e})."), - Ok(v) => println!(" ... and it was {v}"), - } - } - } - } - - Ok(()) -} +//async fn run_dns() -> tokio::io::Result<()> { +// let socket = UdpSocket::bind("127.0.0.1:3553").await?; +// let mut buffer = [0; 65536]; +// +// tracing::info!( +// "Bound socket at {}, starting main DNS handler loop.", +// socket.local_addr()? +// ); +// let (first_len, first_addr) = socket.recv_from(&mut buffer).await?; +// println!("Received {} bytes from {}", first_len, first_addr); +// +// let remote = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 3553)); +// let (mut stream, _handle) = UdpStream::with_bound(socket, remote); +// +// while let Some(item) = stream.next().await { +// match item { +// Err(e) => eprintln!("Got an error: {:?}", e), +// Ok(msg) => { +// println!("Got a message from {}", msg.addr()); +// match msg.to_message() { +// Err(e) => println!(" ... but it was malformed ({e})."), +// Ok(v) => println!(" ... and it was {v}"), +// } +// } +// } +// } +// +// Ok(()) +//} #[cfg(not(tarpaulin_include))] fn main() { @@ -57,7 +55,8 @@ fn main() { }; let result = runtime.block_on(async move { - let version = std::env::var("CARGO_PKG_VERSION").unwrap_or_else(|_| "".to_string()); + let version = + std::env::var("CARGO_PKG_VERSION").unwrap_or_else(|_| "".to_string()); tracing::info!(%version, "Starting Hush server (hushd)"); hush::server::run(config).await }); diff --git a/src/client.rs b/src/client.rs index 3c4c8f5..180ad4f 100644 --- a/src/client.rs +++ b/src/client.rs @@ -95,15 +95,17 @@ impl FromStr for Target { } } -#[allow(unreachable_code,unused_variables)] +#[allow(unreachable_code, unused_variables)] async fn connect( base_config: &ClientConfiguration, target: &str, ) -> error_stack::Result<(), OperationalError> { - let resolver: Resolver = unimplemented!(); -// let mut resolver = Resolver::new(&base_config.resolver) -// .await -// .change_context(OperationalError::DnsConfig)?; + let mut resolver: Resolver = Resolver::new(&base_config.dns_config) + .await + .change_context(OperationalError::Resolver)?; + // let mut resolver = Resolver::new(&base_config.resolver) + // .await + // .change_context(OperationalError::DnsConfig)?; let target = Target::from_str(target) .change_context(OperationalError::UnableToParseHostAddress) .attach_printable_lazy(|| format!("target address '{}'", target))?; @@ -137,7 +139,7 @@ async fn connect( .change_context(OperationalError::Connection)?; tracing::trace!("received their preamble"); - let _our_preamble = ssh::Preamble::default() + ssh::Preamble::default() .write(&mut stream) .instrument(stream_span) .await diff --git a/src/config.rs b/src/config.rs index 43437c2..e85eff6 100644 --- a/src/config.rs +++ b/src/config.rs @@ -9,9 +9,9 @@ pub mod resolver; mod runtime; pub mod server; +use tokio::runtime::Runtime; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::prelude::*; -use tokio::runtime::Runtime; //impl ClientConfiguration { // pub async fn try_from( @@ -84,4 +84,3 @@ pub fn configured_runtime( .worker_threads(runtime.tokio_worker_threads) .build() } - diff --git a/src/config/client.rs b/src/config/client.rs index e0a0ff5..ba1ead2 100644 --- a/src/config/client.rs +++ b/src/config/client.rs @@ -1,17 +1,19 @@ -use clap::Parser; +use crate::config::command_line::ClientArguments; +pub use crate::config::command_line::ClientCommand; use crate::config::config_file::ConfigFile; use crate::config::error::ConfigurationError; use crate::config::logging::LoggingConfiguration; +use crate::config::resolver::DnsConfig; use crate::config::runtime::RuntimeConfiguration; -use crate::config::command_line::ClientArguments; -pub use crate::config::command_line::ClientCommand; +use clap::Parser; use std::ffi::OsString; use tracing_core::LevelFilter; pub struct ClientConfiguration { pub runtime: RuntimeConfiguration, pub logging: LoggingConfiguration, - config_file: Option, + pub dns_config: DnsConfig, + _config_file: Option, command: ClientCommand, } @@ -30,25 +32,32 @@ impl ClientConfiguration { T: Into + Clone, { let mut command_line_arguments = ClientArguments::try_parse_from(args)?; - let mut config_file = ConfigFile::new(command_line_arguments.arguments.config_file.take())?; + let mut _config_file = + ConfigFile::new(command_line_arguments.arguments.config_file.take())?; let mut runtime = RuntimeConfiguration::default(); let mut logging = LoggingConfiguration::default(); // we prefer the command line to the config file, so first merge // in the config file so that when we later merge the command line, // it overwrites any config file options. - if let Some(config_file) = config_file.as_mut() { + if let Some(config_file) = _config_file.as_mut() { config_file.merge_standard_options_into(&mut runtime, &mut logging); } - command_line_arguments.arguments.merge_standard_options_into(&mut runtime, &mut logging); + command_line_arguments + .arguments + .merge_standard_options_into(&mut runtime, &mut logging); if command_line_arguments.command.is_list_command() { logging.filter = LevelFilter::ERROR; } + // FIXME!! + let dns_config = DnsConfig::default(); + Ok(ClientConfiguration { runtime, logging, - config_file, + dns_config, + _config_file, command: command_line_arguments.command, }) } diff --git a/src/config/command_line.rs b/src/config/command_line.rs index acb394c..f9c67e6 100644 --- a/src/config/command_line.rs +++ b/src/config/command_line.rs @@ -1,5 +1,5 @@ -use crate::config::logging::{LogTarget, LoggingConfiguration}; use crate::config::console::ConsoleConfiguration; +use crate::config::logging::{LogTarget, LoggingConfiguration}; use crate::config::runtime::RuntimeConfiguration; use clap::{Args, Parser, Subcommand}; use std::net::SocketAddr; diff --git a/src/config/config_file.rs b/src/config/config_file.rs index d1bff99..59f8aec 100644 --- a/src/config/config_file.rs +++ b/src/config/config_file.rs @@ -43,7 +43,7 @@ impl Arbitrary for ConfigFile { keyed_section(ServerConfig::arbitrary()), ) .prop_map( - |(runtime, logging, resolver,sockets, keys, defaults, servers)| ConfigFile { + |(runtime, logging, resolver, sockets, keys, defaults, servers)| ConfigFile { runtime, logging, resolver, @@ -88,88 +88,105 @@ pub enum Permissions { impl From for rustix::fs::Mode { fn from(value: Permissions) -> Self { match value { - Permissions::User => - rustix::fs::Mode::RUSR | - rustix::fs::Mode::WUSR, + Permissions::User => rustix::fs::Mode::RUSR | rustix::fs::Mode::WUSR, - Permissions::Group => - rustix::fs::Mode::RUSR | - rustix::fs::Mode::WUSR, - - Permissions::UserGroup => - rustix::fs::Mode::RUSR | - rustix::fs::Mode::WUSR | - rustix::fs::Mode::RGRP | - rustix::fs::Mode::WGRP, + Permissions::Group => rustix::fs::Mode::RUSR | rustix::fs::Mode::WUSR, - Permissions::Everyone => - rustix::fs::Mode::RUSR | - rustix::fs::Mode::WUSR | - rustix::fs::Mode::RGRP | - rustix::fs::Mode::WGRP | - rustix::fs::Mode::ROTH | - rustix::fs::Mode::WOTH, + Permissions::UserGroup => { + rustix::fs::Mode::RUSR + | rustix::fs::Mode::WUSR + | rustix::fs::Mode::RGRP + | rustix::fs::Mode::WGRP + } + + Permissions::Everyone => { + rustix::fs::Mode::RUSR + | rustix::fs::Mode::WUSR + | rustix::fs::Mode::RGRP + | rustix::fs::Mode::WGRP + | rustix::fs::Mode::ROTH + | rustix::fs::Mode::WOTH + } } } } impl SocketConfig { pub async fn into_listener(self) -> Result { - let base = tokio::net::UnixListener::bind(&self.path) - .map_err(|error| ConfigurationError::CouldNotMakeSocket { + let base = tokio::net::UnixListener::bind(&self.path).map_err(|error| { + ConfigurationError::CouldNotMakeSocket { + path: self.path.clone(), + error, + } + })?; + + let user = self + .user + .map(|x| { + nix::unistd::User::from_name(&x) + .map_err(|error| ConfigurationError::CouldNotSetPerms { + thing: "find user".to_string(), + path: self.path.clone(), + error: error.into(), + }) + .transpose() + .unwrap_or_else(|| { + Err(ConfigurationError::CouldNotSetPerms { + thing: "find user".to_string(), + path: self.path.clone(), + error: std::io::Error::new( + std::io::ErrorKind::NotFound, + "could not find user", + ), + }) + }) + }) + .transpose()?; + + let group = self + .group + .map(|x| { + nix::unistd::Group::from_name(&x) + .map_err(|error| ConfigurationError::CouldNotSetPerms { + thing: "find user".to_string(), + path: self.path.clone(), + error: error.into(), + }) + .transpose() + .unwrap_or_else(|| { + Err(ConfigurationError::CouldNotSetPerms { + thing: "find user".to_string(), + path: self.path.clone(), + error: std::io::Error::new( + std::io::ErrorKind::NotFound, + "could not find user", + ), + }) + }) + }) + .transpose()?; + + if user.is_some() || group.is_some() { + std::os::unix::fs::chown( + &self.path, + user.map(|x| x.uid.as_raw()), + group.map(|x| x.gid.as_raw()), + ) + .map_err(|error| ConfigurationError::CouldNotSetPerms { + thing: "set user/group ownership".to_string(), path: self.path.clone(), error, })?; - - let user = self.user.map(|x| { - nix::unistd::User::from_name(&x) - .map_err(|error| ConfigurationError::CouldNotSetPerms { - thing: "find user".to_string(), - path: self.path.clone(), - error: error.into(), - }) - .transpose() - .unwrap_or_else(|| Err(ConfigurationError::CouldNotSetPerms { - thing: "find user".to_string(), - path: self.path.clone(), - error: std::io::Error::new(std::io::ErrorKind::NotFound, "could not find user"), - })) - }).transpose()?; - - let group = self.group.map(|x| { - nix::unistd::Group::from_name(&x) - .map_err(|error| ConfigurationError::CouldNotSetPerms { - thing: "find user".to_string(), - path: self.path.clone(), - error: error.into(), - }) - .transpose() - .unwrap_or_else(|| Err(ConfigurationError::CouldNotSetPerms { - thing: "find user".to_string(), - path: self.path.clone(), - error: std::io::Error::new(std::io::ErrorKind::NotFound, "could not find user"), - })) - }).transpose()?; - - if user.is_some() || group.is_some() { - std::os::unix::fs::chown(&self.path, user.map(|x| x.uid.as_raw()), group.map(|x| x.gid.as_raw())) - .map_err(|error| - ConfigurationError::CouldNotSetPerms { - thing: "set user/group ownership".to_string(), - path: self.path.clone(), - error, - } - )?; } let perms = self.permissions.unwrap_or(Permissions::UserGroup); - rustix::fs::chmod(&self.path, perms.into()) - .map_err(|error| ConfigurationError::CouldNotSetPerms { + rustix::fs::chmod(&self.path, perms.into()).map_err(|error| { + ConfigurationError::CouldNotSetPerms { thing: "set permissions".to_string(), path: self.path.clone(), error: error.into(), - })?; - + } + })?; Ok(base) } @@ -180,12 +197,18 @@ impl Arbitrary for SocketConfig { type Parameters = (); fn arbitrary_with(_: Self::Parameters) -> Self::Strategy { - (PathBuf::arbitrary(), - proptest::option::of(proptest::string::string_regex("[a-zA-Z0-9]{1,30}").unwrap()), - proptest::option::of(proptest::string::string_regex("[a-zA-Z0-9]{1,30}").unwrap()), - proptest::option::of(Permissions::arbitrary())) - .prop_map(|(path, user, group, permissions)| - SocketConfig{ path, user, group, permissions }) + ( + PathBuf::arbitrary(), + proptest::option::of(proptest::string::string_regex("[a-zA-Z0-9]{1,30}").unwrap()), + proptest::option::of(proptest::string::string_regex("[a-zA-Z0-9]{1,30}").unwrap()), + proptest::option::of(Permissions::arbitrary()), + ) + .prop_map(|(path, user, group, permissions)| SocketConfig { + path, + user, + group, + permissions, + }) .boxed() } } @@ -200,7 +223,8 @@ impl Arbitrary for Permissions { proptest::strategy::Just(Permissions::Group), proptest::strategy::Just(Permissions::UserGroup), proptest::strategy::Just(Permissions::Everyone), - ].boxed() + ] + .boxed() } } diff --git a/src/config/connection.rs b/src/config/connection.rs index 4f678f6..f8ec2f5 100644 --- a/src/config/connection.rs +++ b/src/config/connection.rs @@ -10,7 +10,6 @@ use proptest::strategy::{BoxedStrategy, Strategy}; use std::collections::HashSet; use std::str::FromStr; - #[derive(Debug)] pub struct ClientConnectionOpts { pub key_exchange_algorithms: Vec, @@ -68,10 +67,12 @@ impl Arbitrary for ClientConnectionOpts { } } -fn finalize_options(values: HashSet<&str>) -> Result, E> - where T: FromStr +fn finalize_options(values: HashSet<&str>) -> Result, E> +where + T: FromStr, { - Ok(values.into_iter().map(T::from_str).collect::,E>>()?) + values + .into_iter() + .map(T::from_str) + .collect::, E>>() } - - diff --git a/src/config/resolver.rs b/src/config/resolver.rs index f56f882..36b29d4 100644 --- a/src/config/resolver.rs +++ b/src/config/resolver.rs @@ -42,6 +42,24 @@ impl Default for DnsConfig { } } +#[cfg(test)] +impl DnsConfig { + pub(crate) fn empty() -> Self { + DnsConfig { + built_in: None, + local_domain: None, + search_domains: vec![], + name_servers: vec![], + retry_attempts: None, + cache_size: None, + max_concurrent_requests_for_query: None, + preserve_intermediates: None, + shuffle_dns_servers: None, + allow_mdns: None, + } + } +} + impl Arbitrary for DnsConfig { type Parameters = bool; type Strategy = BoxedStrategy; @@ -178,12 +196,16 @@ impl Arbitrary for NameServerConfig { proptest::option::of(u64::arbitrary()), proptest::option::of(SocketAddr::arbitrary()), ) - .prop_map(|(mut address, timeout_in_seconds, mut bind_address)| { + .prop_map(|(mut address, mut timeout_in_seconds, mut bind_address)| { clear_flow_and_scope_info(&mut address); if let Some(bind_address) = bind_address.as_mut() { clear_flow_and_scope_info(bind_address); } + if let Some(timeout_in_seconds) = timeout_in_seconds.as_mut() { + *timeout_in_seconds &= 0x7FFF_FFFF_FFFF_FFFF; + } + NameServerConfig { address, timeout_in_seconds, diff --git a/src/config/runtime.rs b/src/config/runtime.rs index 61f6605..72be273 100644 --- a/src/config/runtime.rs +++ b/src/config/runtime.rs @@ -15,5 +15,3 @@ impl Default for RuntimeConfiguration { } } } - - diff --git a/src/config/server.rs b/src/config/server.rs index 56c0b97..9706cd7 100644 --- a/src/config/server.rs +++ b/src/config/server.rs @@ -1,9 +1,9 @@ -use clap::Parser; -use crate::config::config_file::{ConfigFile, SocketConfig}; use crate::config::command_line::ServerArguments; +use crate::config::config_file::{ConfigFile, SocketConfig}; use crate::config::error::ConfigurationError; use crate::config::logging::LoggingConfiguration; use crate::config::runtime::RuntimeConfiguration; +use clap::Parser; use std::collections::HashMap; use std::ffi::OsString; use tokio::net::UnixListener; @@ -55,7 +55,9 @@ impl ServerConfiguration { } /// Generate a series of UNIX sockets that clients will attempt to connect to. - pub async fn generate_listener_sockets(&mut self) -> Result, ConfigurationError> { + pub async fn generate_listener_sockets( + &mut self, + ) -> Result, ConfigurationError> { let mut results = HashMap::new(); for (name, config) in std::mem::take(&mut self.sockets).into_iter() { diff --git a/src/network/host.rs b/src/network/host.rs index d1f2925..d0e5c91 100644 --- a/src/network/host.rs +++ b/src/network/host.rs @@ -49,10 +49,25 @@ impl FromStr for Host { return Ok(Host::IPv6(addr)); } + if let Some(prefix_removed) = s.strip_prefix('[') { + if let Some(cleaned) = prefix_removed.strip_suffix(']') { + match Ipv6Addr::from_str(cleaned) { + Ok(addr) => return Ok(Host::IPv6(addr)), + Err(error) => { + return Err(HostParseError::CouldNotParseIPv6 { + address: s.to_string(), + error, + }) + } + } + } + } + if let Ok(name) = Name::from_utf8(s) { return Ok(Host::Hostname(name)); } + println!(" ... not a hostname"); Err(HostParseError::InvalidHostname { hostname: s.to_string(), }) diff --git a/src/network/resolver.rs b/src/network/resolver.rs index 70b0663..67302e8 100644 --- a/src/network/resolver.rs +++ b/src/network/resolver.rs @@ -1,5 +1,6 @@ use crate::config::resolver::{DnsConfig, NameServerConfig}; use error_stack::report; +use futures::future::select; use futures::stream::{SelectAll, StreamExt}; use hickory_client::rr::Name; use hickory_proto::error::ProtoError; @@ -13,9 +14,14 @@ use hickory_proto::xfer::dns_request::DnsRequest; use hickory_proto::xfer::{DnsRequestSender, DnsResponseStream}; use std::collections::{HashMap, HashSet}; use std::net::{IpAddr, SocketAddr}; +#[cfg(test)] +use std::str::FromStr; +use std::sync::Arc; use thiserror::Error; use tokio::net::UdpSocket; -use tokio::time::{Duration, Instant}; +use tokio::sync::{oneshot, Mutex}; +use tokio::task::JoinSet; +use tokio::time::{timeout_at, Duration, Instant}; #[derive(Debug, Error)] pub enum ResolverConfigError { @@ -44,8 +50,20 @@ pub enum ResolveError { pub struct Resolver { search_domains: Vec, + max_time_to_wait_for_initial: Duration, + time_to_wait_after_first: Duration, + time_to_wait_for_lingering: Duration, + state: Arc>, +} + +pub struct ResolverState { client_connections: Vec<(NameServerConfig, UdpClientStream)>, - cache: HashMap>, + cache: HashMap>, +} + +pub struct DnsResolution { + address: IpAddr, + expires_at: Instant, } impl Resolver { @@ -76,7 +94,7 @@ impl Resolver { for name_server_config in config.name_servers() { let stream = UdpClientStream::with_bind_addr_and_timeout( name_server_config.address, - name_server_config.bind_address.clone(), + name_server_config.bind_address, Duration::from_secs(name_server_config.timeout_in_seconds.unwrap_or(3)), ) .await @@ -90,54 +108,110 @@ impl Resolver { Ok(Resolver { search_domains, - client_connections, - cache: HashMap::new(), + // FIXME: All of these should be configurable + max_time_to_wait_for_initial: Duration::from_millis(150), + time_to_wait_after_first: Duration::from_millis(50), + time_to_wait_for_lingering: Duration::from_secs(2), + state: Arc::new(Mutex::new(ResolverState { + client_connections, + cache: HashMap::new(), + })), }) } /// Look up the address of the given name, returning either a set of results /// we've received in a reasonable amount of time, or an error. - pub async fn lookup(&mut self, name: &Name) -> Result, ResolveError> { + pub async fn lookup(&self, name: &Name) -> Result, ResolveError> { let names = self.expand_name(name); - let mut response_stream = self.create_response_stream(names).await; + let cached_values = self.cached_lookup(&names).await; + if !cached_values.is_empty() { + return Ok(cached_values); + } + + let mut response_stream = self.create_response_stream(&names).await; if response_stream.is_empty() { return Err(ResolveError::NoServersAvailable); } - let mut first_error = None; + let get_first_by = Instant::now() + self.max_time_to_wait_for_initial; + let got_responses = + get_responses_by(&self.state, get_first_by, name, &mut response_stream).await?; - while let Some(response) = response_stream.next().await { - match response { - Err(e) => { - if first_error.is_none() { - first_error = Some(e); - } - } + if got_responses { + let get_rest_by = Instant::now() + self.time_to_wait_after_first; + let _ = get_responses_by(&self.state, get_rest_by, name, &mut response_stream).await; + } - Ok(response) => { - for answer in response.into_message().take_answers() { - self.handle_response(name, answer).await; + let lingering_deadline = Instant::now() + self.time_to_wait_for_lingering; + let state_copy = self.state.clone(); + let name_copy = name.clone(); + tokio::task::spawn(async move { + let _ = get_responses_by( + &state_copy, + lingering_deadline, + &name_copy, + &mut response_stream, + ) + .await; + }); + + let after_search_lookup = self.cached_lookup(&names).await; + if after_search_lookup.is_empty() { + Err(ResolveError::NoResponses) + } else { + Ok(after_search_lookup) + } + } + + /// Look up any cached IP address we might have for the given names. + /// + /// As a side effect, this function will clean out any expired entries in the + /// cache. + async fn cached_lookup(&self, names: &[Name]) -> HashSet { + let mut retval = HashSet::new(); + let mut state = self.state.lock().await; + + for name in names { + if let Some(mut existing) = state.cache.remove(name) { + let now = Instant::now(); + + existing.retain(|item| { + let keeper = item.expires_at > now; + + if keeper { + retval.insert(item.address); } + + keeper + }); + + if !existing.is_empty() { + state.cache.insert(name.clone(), existing); } } } - match first_error { - None => Err(ResolveError::NoResponses), - Some(error) => Err(ResolveError::ResponseError { error }), - } + retval } - async fn create_response_stream(&mut self, names: Vec) -> SelectAll { - let connections = std::mem::take(&mut self.client_connections); + /// Reach out to all of our current connections, and start the process of getting + /// responses. + /// + /// If the clients are shut down, we'll try to create new connections to them, or + /// remove them from our connection list if we can't. If there are no connections, + /// the `SelectAll` will be empty, and callers are advised to check for this + /// condition and inform the user that they're out of useful DNS servers. + async fn create_response_stream(&self, names: &[Name]) -> SelectAll { + let mut state = self.state.lock().await; + let connections = std::mem::take(&mut state.client_connections); let mut response_stream = futures::stream::SelectAll::new(); for (config, mut client) in connections.into_iter() { if client.is_shutdown() { let stream = UdpClientStream::with_bind_addr_and_timeout( config.address, - config.bind_address.clone(), + config.bind_address, Duration::from_secs(config.timeout_in_seconds.unwrap_or(3)), ) .await; @@ -158,7 +232,7 @@ impl Resolver { let request = DnsRequest::new(message, Default::default()); response_stream.push(client.send_message(request)); - self.client_connections.push((config, client)); + state.client_connections.push((config, client)); } response_stream @@ -177,85 +251,182 @@ impl Resolver { names } - /// Handle an individual response from a server. + /// Run a cleaner task on this Resolver object. /// - /// Returns true if we got an answer from the server, and updated the internal cache. - /// This answer is used externally to set a timer, so that we don't wait the full DNS - /// timeout period for answers to come back. - async fn handle_response(&mut self, query: &Name, record: Record) -> bool { - let ttl = record.ttl(); + /// The task will run in the given JoinSet, set it can be easily killed, tracked, waited + /// for as desired. Every `period` interval, it will grab the state lock and clean out + /// any entries that are timed out. Ideally, you shouldn't set this too frequently, so + /// that this process doesn't interfere with other tasks trying to use the resolver. + /// + /// This task will only weakly hang on to the resolver state, so that if the Resolver + /// object is dropped it will quietly stop running. It will also stop running if it sees + /// a message on the kill signal, if one is provided, or if one is provided and the other + /// end drops the sender. (So if you don't want a kill signal, send `None`, or you might + /// accidentally kill your GC task too early.) + pub async fn run_resolver_gc( + &self, + task_set: &mut JoinSet>, + mut kill_signal: Option>, + interval: Duration, + ) { + let weak_locked_state = Arc::downgrade(&self.state); - let Some(rdata) = record.into_data() else { - tracing::error!("for some reason, couldn't process incoming message"); - return false; + task_set.spawn(async move { + loop { + tokio::time::sleep(interval).await; + + let Some(locked_state) = weak_locked_state.upgrade() else { + break Ok(()); + }; + + let mut state = if let Some(kill_signal) = kill_signal.as_mut() { + let locker = locked_state.lock(); + + futures::pin_mut!(locker); + match select(locker, kill_signal).await { + futures::future::Either::Left((state, _)) => state, + futures::future::Either::Right((_, _)) => break Ok(()), + } + } else { + locked_state.lock().await + }; + + let now = Instant::now(); + state.cache.retain(|_, value| { + value.retain(|resolution| resolution.expires_at < now); + !value.is_empty() + }); + + drop(state); + } + }); + } + + /// Inject a mapping into the resolver, with the given TTL. + /// + /// This is largely used for testing purposes, but could be used if there are static + /// addresses you want to always resolve in a particular way. For that use case, you + /// probably want to add absurdly-long TTLs for those names. + pub async fn inject_resolution(&self, name: Name, address: IpAddr, ttl: Duration) { + let resolution = DnsResolution { + address, + expires_at: Instant::now() + ttl, }; - let response: IpAddr = match rdata { - RData::A(arec) => arec.0.into(), - RData::AAAA(arec) => arec.0.into(), - - _ => { - tracing::warn!(record_type = %rdata.record_type(), "skipping unknown / unexpected record type"); - return false; + let mut state = self.state.lock().await; + match state.cache.entry(name) { + std::collections::hash_map::Entry::Vacant(vac) => { + vac.insert(vec![resolution]); } - }; - - let expire_at = Instant::now() + Duration::from_secs(ttl as u64); - - match self.cache.entry(query.clone()) { - std::collections::hash_map::Entry::Vacant(vec) => { - vec.insert(vec![(expire_at, response.clone())]); - } - std::collections::hash_map::Entry::Occupied(mut occ) => { - clean_expired_entries(occ.get_mut()); - occ.get_mut().push((expire_at, response.clone())); + occ.get_mut().push(resolution); } } - - true + drop(state); } } -fn clean_expired_entries(list: &mut Vec<(Instant, IpAddr)>) { +/// Get all the responses into cache that occur before the given timestamp. +/// +/// Returns an error if all we get are errors while trying to read from these servers, +/// otherwise returns Ok(true) if we received some interesting responses, or Ok(false) +/// if we didn't. Note that receiving good responses wins out over receiving errors, so +/// if we receive 2 good responses and 3 errors, this function will return Ok(true). +async fn get_responses_by( + state: &Arc>, + deadline: Instant, + name: &Name, + stream: &mut SelectAll, +) -> Result { + let mut received_error = None; + let mut got_something = false; + + loop { + match timeout_at(deadline, stream.next()).await { + Err(_) => return received_error.unwrap_or(Ok(got_something)), + + Ok(None) if got_something => return Ok(got_something), + Ok(None) => return received_error.unwrap_or(Ok(false)), + + Ok(Some(Err(error))) => { + if received_error.is_none() { + received_error = Some(Err(ResolveError::ResponseError { error })); + } + } + Ok(Some(Ok(response))) => { + for answer in response.into_message().take_answers() { + got_something |= handle_response(state, name, answer).await; + } + } + } + } +} + +/// Handle an individual response from a server. +/// +/// Returns true if we got an answer from the server, and updates the internal cache. +/// This function will never return an error, although there may be some odd processing +/// situations that could lead it to printing warnings to the logs. +async fn handle_response(state: &Arc>, query: &Name, record: Record) -> bool { + let ttl = record.ttl(); + + let Some(rdata) = record.into_data() else { + tracing::error!("for some reason, couldn't process incoming message"); + return false; + }; + + let address: IpAddr = match rdata { + RData::A(arec) => arec.0.into(), + RData::AAAA(arec) => arec.0.into(), + + _ => { + tracing::warn!(record_type = %rdata.record_type(), "skipping unknown / unexpected record type"); + return false; + } + }; + let now = Instant::now(); - list.retain(|(expire_at, _)| expire_at > &now); + let expires_at = now + Duration::from_secs(ttl as u64); + let resolution = DnsResolution { + address, + expires_at, + }; + + let mut state = state.lock().await; + match state.cache.entry(query.clone()) { + std::collections::hash_map::Entry::Vacant(vec) => { + vec.insert(vec![resolution]); + } + + std::collections::hash_map::Entry::Occupied(mut occ) => { + let vec = occ.get_mut(); + vec.retain(|x| x.expires_at > now); + vec.push(resolution); + } + } + drop(state); + + true +} + +#[tokio::test] +async fn fetch_cached() { + let resolver = Resolver::new(&DnsConfig::empty()).await.unwrap(); + let name = Name::from_utf8("name.foo").unwrap(); + let addr = IpAddr::from_str("1.2.4.5").unwrap(); + + resolver + .inject_resolution(name.clone(), addr.clone(), Duration::from_secs(100000)) + .await; + let read = resolver.lookup(&name).await.unwrap(); + assert!(read.contains(&addr)); } #[tokio::test] async fn uhsure() { - let mut resolver = Resolver::new(&DnsConfig::default()).await.unwrap(); + let resolver = Resolver::new(&DnsConfig::default()).await.unwrap(); let name = Name::from_ascii("uhsure.com").unwrap(); let result = resolver.lookup(&name).await.unwrap(); println!("result = {:?}", result); assert!(!result.is_empty()); } - -#[tokio::test] -async fn manual_lookup() { - let mut stream: UdpClientStream = UdpClientStream::with_bind_addr_and_timeout( - SocketAddr::V4(std::net::SocketAddrV4::new( - std::net::Ipv4Addr::new(172, 19, 207, 1), - 53, - )), - None, - Duration::from_secs(3), - ) - .await - .expect("can create UDP DNS client stream"); - - let mut message = Message::new(); - let name = Name::from_ascii("uhsure.com").unwrap(); - message.add_query(Query::query(name.clone(), RecordType::A)); - message.add_query(Query::query(name.clone(), RecordType::AAAA)); - message.set_recursion_desired(true); - - let request = DnsRequest::new(message, Default::default()); - let mut responses = stream.send_message(request); - - while let Some(response) = responses.next().await { - println!("response: {:?}", response); - } - - unimplemented!(); - } diff --git a/src/operational_error.rs b/src/operational_error.rs index 4c4e1e4..e9626b7 100644 --- a/src/operational_error.rs +++ b/src/operational_error.rs @@ -48,4 +48,6 @@ pub enum OperationalError { InvalidHostname(String), #[error("Unable to parse host address")] UnableToParseHostAddress, + #[error("Unable to configure resolver")] + Resolver, } diff --git a/src/server.rs b/src/server.rs index ed92313..a4e0d09 100644 --- a/src/server.rs +++ b/src/server.rs @@ -14,15 +14,14 @@ pub enum TopLevelError { } pub async fn run(mut config: ServerConfiguration) -> error_stack::Result<(), TopLevelError> { - let mut server_state = state::ServerState::default(); + let _server_state = state::ServerState::default(); - let listeners = config.generate_listener_sockets().await + let listeners = config + .generate_listener_sockets() + .await .change_context(TopLevelError::ConfigurationError)?; - for (name, listener) in listeners.into_iter() { - } + for (_name, _listener) in listeners.into_iter() {} Ok(()) } - - diff --git a/src/server/socket.rs b/src/server/socket.rs index 4bbe2d2..faa5a13 100644 --- a/src/server/socket.rs +++ b/src/server/socket.rs @@ -10,6 +10,7 @@ pub struct SocketServer { listener: UnixListener, } +#[allow(unused)] impl SocketServer { /// Create a new server that will handle inputs from the client program. /// @@ -18,10 +19,10 @@ impl SocketServer { /// method will take ownership of the object. pub fn new(name: String, listener: UnixListener) -> Self { let path = listener - .local_addr() - .map(|x| x.as_pathname().map(|p| format!("{}", p.display()))) - .unwrap_or_else(|_| None) - .unwrap_or_else(|| format!("unknown")); + .local_addr() + .map(|x| x.as_pathname().map(|p| format!("{}", p.display()))) + .unwrap_or_else(|_| None) + .unwrap_or_else(|| "".to_string()); tracing::trace!(%name, %path, "Creating new socket listener"); @@ -41,7 +42,11 @@ impl SocketServer { /// there, the core task should be unaffected. pub async fn start(mut self) -> error_stack::Result<(), TopLevelError> { loop { - let (stream, addr) = self.listener.accept().await.change_context(TopLevelError::SocketHandlerFailure)?; + let (stream, addr) = self + .listener + .accept() + .await + .change_context(TopLevelError::SocketHandlerFailure)?; let remote_addr = addr .as_pathname() .map(|x| x.display()) @@ -58,8 +63,7 @@ impl SocketServer { self.num_sessions_run += 1; - tokio::task::spawn(Self::run_session(stream) - .instrument(span)) + tokio::task::spawn(Self::run_session(stream).instrument(span)) .await .change_context(TopLevelError::SocketHandlerFailure)?; } @@ -70,6 +74,5 @@ impl SocketServer { /// This is here because it's convenient, not because it shares state (obviously, /// given the type signature). But it's somewhat logically associated with this type, /// so it seems reasonable to make it an associated function. - async fn run_session(handle: UnixStream) { - } + async fn run_session(handle: UnixStream) {} } diff --git a/src/server/state.rs b/src/server/state.rs index 6bfe7ba..1235cf9 100644 --- a/src/server/state.rs +++ b/src/server/state.rs @@ -6,3 +6,26 @@ pub struct ServerState { /// The set of top level tasks that are currently running. top_level_tasks: JoinSet>, } + +impl ServerState { + /// Block until all the current top level tasks have closed. + /// + /// This will log any errors found in these tasks to the current log, + /// if there is one, but will otherwise drop them. + #[allow(unused)] + pub async fn shutdown(&mut self) { + while let Some(next) = self.top_level_tasks.join_next_with_id().await { + match next { + Err(e) => tracing::error!(id = %e.id(), "Failed to attach to top-level task"), + + Ok((id, Err(e))) => { + tracing::error!(%id, "Top-level server error: {}", e); + } + + Ok((id, Ok(()))) => { + tracing::debug!(%id, "Cleanly closed server task."); + } + } + } + } +}