Clean ups and some other stuff?

This commit is contained in:
2025-02-07 20:59:04 -08:00
parent 268ca2d1a5
commit 31cd34d280
17 changed files with 506 additions and 236 deletions

1
.gitignore vendored
View File

@@ -15,3 +15,4 @@ Cargo.lock
tarpaulin-report.html
proptest-regressions/
config.toml

View File

@@ -1,37 +1,35 @@
use futures::stream::StreamExt;
use hickory_proto::udp::UdpStream;
use hush::config::server::ServerConfiguration;
use std::net::{SocketAddr, SocketAddrV4, Ipv4Addr};
use std::process;
use thiserror::Error;
use tokio::net::UdpSocket;
async fn run_dns() -> tokio::io::Result<()> {
let socket = UdpSocket::bind("127.0.0.1:3553").await?;
let mut buffer = [0; 65536];
tracing::info!("Bound socket at {}, starting main DNS handler loop.", socket.local_addr()?);
let (first_len, first_addr) = socket.recv_from(&mut buffer).await?;
println!("Received {} bytes from {}", first_len, first_addr);
let remote = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 3553));
let (mut stream, _handle) = UdpStream::with_bound(socket, remote);
while let Some(item) = stream.next().await {
match item {
Err(e) => eprintln!("Got an error: {:?}", e),
Ok(msg) => {
println!("Got a message from {}", msg.addr());
match msg.to_message() {
Err(e) => println!(" ... but it was malformed ({e})."),
Ok(v) => println!(" ... and it was {v}"),
}
}
}
}
Ok(())
}
//async fn run_dns() -> tokio::io::Result<()> {
// let socket = UdpSocket::bind("127.0.0.1:3553").await?;
// let mut buffer = [0; 65536];
//
// tracing::info!(
// "Bound socket at {}, starting main DNS handler loop.",
// socket.local_addr()?
// );
// let (first_len, first_addr) = socket.recv_from(&mut buffer).await?;
// println!("Received {} bytes from {}", first_len, first_addr);
//
// let remote = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 3553));
// let (mut stream, _handle) = UdpStream::with_bound(socket, remote);
//
// while let Some(item) = stream.next().await {
// match item {
// Err(e) => eprintln!("Got an error: {:?}", e),
// Ok(msg) => {
// println!("Got a message from {}", msg.addr());
// match msg.to_message() {
// Err(e) => println!(" ... but it was malformed ({e})."),
// Ok(v) => println!(" ... and it was {v}"),
// }
// }
// }
// }
//
// Ok(())
//}
#[cfg(not(tarpaulin_include))]
fn main() {
@@ -57,7 +55,8 @@ fn main() {
};
let result = runtime.block_on(async move {
let version = std::env::var("CARGO_PKG_VERSION").unwrap_or_else(|_| "<unknown>".to_string());
let version =
std::env::var("CARGO_PKG_VERSION").unwrap_or_else(|_| "<unknown>".to_string());
tracing::info!(%version, "Starting Hush server (hushd)");
hush::server::run(config).await
});

View File

@@ -95,15 +95,17 @@ impl FromStr for Target {
}
}
#[allow(unreachable_code,unused_variables)]
#[allow(unreachable_code, unused_variables)]
async fn connect(
base_config: &ClientConfiguration,
target: &str,
) -> error_stack::Result<(), OperationalError> {
let resolver: Resolver = unimplemented!();
// let mut resolver = Resolver::new(&base_config.resolver)
// .await
// .change_context(OperationalError::DnsConfig)?;
let mut resolver: Resolver = Resolver::new(&base_config.dns_config)
.await
.change_context(OperationalError::Resolver)?;
// let mut resolver = Resolver::new(&base_config.resolver)
// .await
// .change_context(OperationalError::DnsConfig)?;
let target = Target::from_str(target)
.change_context(OperationalError::UnableToParseHostAddress)
.attach_printable_lazy(|| format!("target address '{}'", target))?;
@@ -137,7 +139,7 @@ async fn connect(
.change_context(OperationalError::Connection)?;
tracing::trace!("received their preamble");
let _our_preamble = ssh::Preamble::default()
ssh::Preamble::default()
.write(&mut stream)
.instrument(stream_span)
.await

View File

@@ -9,9 +9,9 @@ pub mod resolver;
mod runtime;
pub mod server;
use tokio::runtime::Runtime;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::prelude::*;
use tokio::runtime::Runtime;
//impl ClientConfiguration {
// pub async fn try_from(
@@ -84,4 +84,3 @@ pub fn configured_runtime(
.worker_threads(runtime.tokio_worker_threads)
.build()
}

View File

@@ -1,17 +1,19 @@
use clap::Parser;
use crate::config::command_line::ClientArguments;
pub use crate::config::command_line::ClientCommand;
use crate::config::config_file::ConfigFile;
use crate::config::error::ConfigurationError;
use crate::config::logging::LoggingConfiguration;
use crate::config::resolver::DnsConfig;
use crate::config::runtime::RuntimeConfiguration;
use crate::config::command_line::ClientArguments;
pub use crate::config::command_line::ClientCommand;
use clap::Parser;
use std::ffi::OsString;
use tracing_core::LevelFilter;
pub struct ClientConfiguration {
pub runtime: RuntimeConfiguration,
pub logging: LoggingConfiguration,
config_file: Option<ConfigFile>,
pub dns_config: DnsConfig,
_config_file: Option<ConfigFile>,
command: ClientCommand,
}
@@ -30,25 +32,32 @@ impl ClientConfiguration {
T: Into<OsString> + Clone,
{
let mut command_line_arguments = ClientArguments::try_parse_from(args)?;
let mut config_file = ConfigFile::new(command_line_arguments.arguments.config_file.take())?;
let mut _config_file =
ConfigFile::new(command_line_arguments.arguments.config_file.take())?;
let mut runtime = RuntimeConfiguration::default();
let mut logging = LoggingConfiguration::default();
// we prefer the command line to the config file, so first merge
// in the config file so that when we later merge the command line,
// it overwrites any config file options.
if let Some(config_file) = config_file.as_mut() {
if let Some(config_file) = _config_file.as_mut() {
config_file.merge_standard_options_into(&mut runtime, &mut logging);
}
command_line_arguments.arguments.merge_standard_options_into(&mut runtime, &mut logging);
command_line_arguments
.arguments
.merge_standard_options_into(&mut runtime, &mut logging);
if command_line_arguments.command.is_list_command() {
logging.filter = LevelFilter::ERROR;
}
// FIXME!!
let dns_config = DnsConfig::default();
Ok(ClientConfiguration {
runtime,
logging,
config_file,
dns_config,
_config_file,
command: command_line_arguments.command,
})
}

View File

@@ -1,5 +1,5 @@
use crate::config::logging::{LogTarget, LoggingConfiguration};
use crate::config::console::ConsoleConfiguration;
use crate::config::logging::{LogTarget, LoggingConfiguration};
use crate::config::runtime::RuntimeConfiguration;
use clap::{Args, Parser, Subcommand};
use std::net::SocketAddr;

View File

@@ -43,7 +43,7 @@ impl Arbitrary for ConfigFile {
keyed_section(ServerConfig::arbitrary()),
)
.prop_map(
|(runtime, logging, resolver,sockets, keys, defaults, servers)| ConfigFile {
|(runtime, logging, resolver, sockets, keys, defaults, servers)| ConfigFile {
runtime,
logging,
resolver,
@@ -88,88 +88,105 @@ pub enum Permissions {
impl From<Permissions> for rustix::fs::Mode {
fn from(value: Permissions) -> Self {
match value {
Permissions::User =>
rustix::fs::Mode::RUSR |
rustix::fs::Mode::WUSR,
Permissions::User => rustix::fs::Mode::RUSR | rustix::fs::Mode::WUSR,
Permissions::Group =>
rustix::fs::Mode::RUSR |
rustix::fs::Mode::WUSR,
Permissions::UserGroup =>
rustix::fs::Mode::RUSR |
rustix::fs::Mode::WUSR |
rustix::fs::Mode::RGRP |
rustix::fs::Mode::WGRP,
Permissions::Group => rustix::fs::Mode::RUSR | rustix::fs::Mode::WUSR,
Permissions::Everyone =>
rustix::fs::Mode::RUSR |
rustix::fs::Mode::WUSR |
rustix::fs::Mode::RGRP |
rustix::fs::Mode::WGRP |
rustix::fs::Mode::ROTH |
rustix::fs::Mode::WOTH,
Permissions::UserGroup => {
rustix::fs::Mode::RUSR
| rustix::fs::Mode::WUSR
| rustix::fs::Mode::RGRP
| rustix::fs::Mode::WGRP
}
Permissions::Everyone => {
rustix::fs::Mode::RUSR
| rustix::fs::Mode::WUSR
| rustix::fs::Mode::RGRP
| rustix::fs::Mode::WGRP
| rustix::fs::Mode::ROTH
| rustix::fs::Mode::WOTH
}
}
}
}
impl SocketConfig {
pub async fn into_listener(self) -> Result<tokio::net::UnixListener, ConfigurationError> {
let base = tokio::net::UnixListener::bind(&self.path)
.map_err(|error| ConfigurationError::CouldNotMakeSocket {
let base = tokio::net::UnixListener::bind(&self.path).map_err(|error| {
ConfigurationError::CouldNotMakeSocket {
path: self.path.clone(),
error,
}
})?;
let user = self
.user
.map(|x| {
nix::unistd::User::from_name(&x)
.map_err(|error| ConfigurationError::CouldNotSetPerms {
thing: "find user".to_string(),
path: self.path.clone(),
error: error.into(),
})
.transpose()
.unwrap_or_else(|| {
Err(ConfigurationError::CouldNotSetPerms {
thing: "find user".to_string(),
path: self.path.clone(),
error: std::io::Error::new(
std::io::ErrorKind::NotFound,
"could not find user",
),
})
})
})
.transpose()?;
let group = self
.group
.map(|x| {
nix::unistd::Group::from_name(&x)
.map_err(|error| ConfigurationError::CouldNotSetPerms {
thing: "find user".to_string(),
path: self.path.clone(),
error: error.into(),
})
.transpose()
.unwrap_or_else(|| {
Err(ConfigurationError::CouldNotSetPerms {
thing: "find user".to_string(),
path: self.path.clone(),
error: std::io::Error::new(
std::io::ErrorKind::NotFound,
"could not find user",
),
})
})
})
.transpose()?;
if user.is_some() || group.is_some() {
std::os::unix::fs::chown(
&self.path,
user.map(|x| x.uid.as_raw()),
group.map(|x| x.gid.as_raw()),
)
.map_err(|error| ConfigurationError::CouldNotSetPerms {
thing: "set user/group ownership".to_string(),
path: self.path.clone(),
error,
})?;
let user = self.user.map(|x| {
nix::unistd::User::from_name(&x)
.map_err(|error| ConfigurationError::CouldNotSetPerms {
thing: "find user".to_string(),
path: self.path.clone(),
error: error.into(),
})
.transpose()
.unwrap_or_else(|| Err(ConfigurationError::CouldNotSetPerms {
thing: "find user".to_string(),
path: self.path.clone(),
error: std::io::Error::new(std::io::ErrorKind::NotFound, "could not find user"),
}))
}).transpose()?;
let group = self.group.map(|x| {
nix::unistd::Group::from_name(&x)
.map_err(|error| ConfigurationError::CouldNotSetPerms {
thing: "find user".to_string(),
path: self.path.clone(),
error: error.into(),
})
.transpose()
.unwrap_or_else(|| Err(ConfigurationError::CouldNotSetPerms {
thing: "find user".to_string(),
path: self.path.clone(),
error: std::io::Error::new(std::io::ErrorKind::NotFound, "could not find user"),
}))
}).transpose()?;
if user.is_some() || group.is_some() {
std::os::unix::fs::chown(&self.path, user.map(|x| x.uid.as_raw()), group.map(|x| x.gid.as_raw()))
.map_err(|error|
ConfigurationError::CouldNotSetPerms {
thing: "set user/group ownership".to_string(),
path: self.path.clone(),
error,
}
)?;
}
let perms = self.permissions.unwrap_or(Permissions::UserGroup);
rustix::fs::chmod(&self.path, perms.into())
.map_err(|error| ConfigurationError::CouldNotSetPerms {
rustix::fs::chmod(&self.path, perms.into()).map_err(|error| {
ConfigurationError::CouldNotSetPerms {
thing: "set permissions".to_string(),
path: self.path.clone(),
error: error.into(),
})?;
}
})?;
Ok(base)
}
@@ -180,12 +197,18 @@ impl Arbitrary for SocketConfig {
type Parameters = ();
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
(PathBuf::arbitrary(),
proptest::option::of(proptest::string::string_regex("[a-zA-Z0-9]{1,30}").unwrap()),
proptest::option::of(proptest::string::string_regex("[a-zA-Z0-9]{1,30}").unwrap()),
proptest::option::of(Permissions::arbitrary()))
.prop_map(|(path, user, group, permissions)|
SocketConfig{ path, user, group, permissions })
(
PathBuf::arbitrary(),
proptest::option::of(proptest::string::string_regex("[a-zA-Z0-9]{1,30}").unwrap()),
proptest::option::of(proptest::string::string_regex("[a-zA-Z0-9]{1,30}").unwrap()),
proptest::option::of(Permissions::arbitrary()),
)
.prop_map(|(path, user, group, permissions)| SocketConfig {
path,
user,
group,
permissions,
})
.boxed()
}
}
@@ -200,7 +223,8 @@ impl Arbitrary for Permissions {
proptest::strategy::Just(Permissions::Group),
proptest::strategy::Just(Permissions::UserGroup),
proptest::strategy::Just(Permissions::Everyone),
].boxed()
]
.boxed()
}
}

View File

@@ -10,7 +10,6 @@ use proptest::strategy::{BoxedStrategy, Strategy};
use std::collections::HashSet;
use std::str::FromStr;
#[derive(Debug)]
pub struct ClientConnectionOpts {
pub key_exchange_algorithms: Vec<KeyExchangeAlgorithm>,
@@ -68,10 +67,12 @@ impl Arbitrary for ClientConnectionOpts {
}
}
fn finalize_options<T,E>(values: HashSet<&str>) -> Result<Vec<T>, E>
where T: FromStr<Err=E>
fn finalize_options<T, E>(values: HashSet<&str>) -> Result<Vec<T>, E>
where
T: FromStr<Err = E>,
{
Ok(values.into_iter().map(T::from_str).collect::<Result<Vec<T>,E>>()?)
values
.into_iter()
.map(T::from_str)
.collect::<Result<Vec<T>, E>>()
}

View File

@@ -42,6 +42,24 @@ impl Default for DnsConfig {
}
}
#[cfg(test)]
impl DnsConfig {
pub(crate) fn empty() -> Self {
DnsConfig {
built_in: None,
local_domain: None,
search_domains: vec![],
name_servers: vec![],
retry_attempts: None,
cache_size: None,
max_concurrent_requests_for_query: None,
preserve_intermediates: None,
shuffle_dns_servers: None,
allow_mdns: None,
}
}
}
impl Arbitrary for DnsConfig {
type Parameters = bool;
type Strategy = BoxedStrategy<Self>;
@@ -178,12 +196,16 @@ impl Arbitrary for NameServerConfig {
proptest::option::of(u64::arbitrary()),
proptest::option::of(SocketAddr::arbitrary()),
)
.prop_map(|(mut address, timeout_in_seconds, mut bind_address)| {
.prop_map(|(mut address, mut timeout_in_seconds, mut bind_address)| {
clear_flow_and_scope_info(&mut address);
if let Some(bind_address) = bind_address.as_mut() {
clear_flow_and_scope_info(bind_address);
}
if let Some(timeout_in_seconds) = timeout_in_seconds.as_mut() {
*timeout_in_seconds &= 0x7FFF_FFFF_FFFF_FFFF;
}
NameServerConfig {
address,
timeout_in_seconds,

View File

@@ -15,5 +15,3 @@ impl Default for RuntimeConfiguration {
}
}
}

View File

@@ -1,9 +1,9 @@
use clap::Parser;
use crate::config::config_file::{ConfigFile, SocketConfig};
use crate::config::command_line::ServerArguments;
use crate::config::config_file::{ConfigFile, SocketConfig};
use crate::config::error::ConfigurationError;
use crate::config::logging::LoggingConfiguration;
use crate::config::runtime::RuntimeConfiguration;
use clap::Parser;
use std::collections::HashMap;
use std::ffi::OsString;
use tokio::net::UnixListener;
@@ -55,7 +55,9 @@ impl ServerConfiguration {
}
/// Generate a series of UNIX sockets that clients will attempt to connect to.
pub async fn generate_listener_sockets(&mut self) -> Result<HashMap<String, UnixListener>, ConfigurationError> {
pub async fn generate_listener_sockets(
&mut self,
) -> Result<HashMap<String, UnixListener>, ConfigurationError> {
let mut results = HashMap::new();
for (name, config) in std::mem::take(&mut self.sockets).into_iter() {

View File

@@ -49,10 +49,25 @@ impl FromStr for Host {
return Ok(Host::IPv6(addr));
}
if let Some(prefix_removed) = s.strip_prefix('[') {
if let Some(cleaned) = prefix_removed.strip_suffix(']') {
match Ipv6Addr::from_str(cleaned) {
Ok(addr) => return Ok(Host::IPv6(addr)),
Err(error) => {
return Err(HostParseError::CouldNotParseIPv6 {
address: s.to_string(),
error,
})
}
}
}
}
if let Ok(name) = Name::from_utf8(s) {
return Ok(Host::Hostname(name));
}
println!(" ... not a hostname");
Err(HostParseError::InvalidHostname {
hostname: s.to_string(),
})

View File

@@ -1,5 +1,6 @@
use crate::config::resolver::{DnsConfig, NameServerConfig};
use error_stack::report;
use futures::future::select;
use futures::stream::{SelectAll, StreamExt};
use hickory_client::rr::Name;
use hickory_proto::error::ProtoError;
@@ -13,9 +14,14 @@ use hickory_proto::xfer::dns_request::DnsRequest;
use hickory_proto::xfer::{DnsRequestSender, DnsResponseStream};
use std::collections::{HashMap, HashSet};
use std::net::{IpAddr, SocketAddr};
#[cfg(test)]
use std::str::FromStr;
use std::sync::Arc;
use thiserror::Error;
use tokio::net::UdpSocket;
use tokio::time::{Duration, Instant};
use tokio::sync::{oneshot, Mutex};
use tokio::task::JoinSet;
use tokio::time::{timeout_at, Duration, Instant};
#[derive(Debug, Error)]
pub enum ResolverConfigError {
@@ -44,8 +50,20 @@ pub enum ResolveError {
pub struct Resolver {
search_domains: Vec<Name>,
max_time_to_wait_for_initial: Duration,
time_to_wait_after_first: Duration,
time_to_wait_for_lingering: Duration,
state: Arc<Mutex<ResolverState>>,
}
pub struct ResolverState {
client_connections: Vec<(NameServerConfig, UdpClientStream<UdpSocket>)>,
cache: HashMap<Name, Vec<(Instant, IpAddr)>>,
cache: HashMap<Name, Vec<DnsResolution>>,
}
pub struct DnsResolution {
address: IpAddr,
expires_at: Instant,
}
impl Resolver {
@@ -76,7 +94,7 @@ impl Resolver {
for name_server_config in config.name_servers() {
let stream = UdpClientStream::with_bind_addr_and_timeout(
name_server_config.address,
name_server_config.bind_address.clone(),
name_server_config.bind_address,
Duration::from_secs(name_server_config.timeout_in_seconds.unwrap_or(3)),
)
.await
@@ -90,54 +108,110 @@ impl Resolver {
Ok(Resolver {
search_domains,
client_connections,
cache: HashMap::new(),
// FIXME: All of these should be configurable
max_time_to_wait_for_initial: Duration::from_millis(150),
time_to_wait_after_first: Duration::from_millis(50),
time_to_wait_for_lingering: Duration::from_secs(2),
state: Arc::new(Mutex::new(ResolverState {
client_connections,
cache: HashMap::new(),
})),
})
}
/// Look up the address of the given name, returning either a set of results
/// we've received in a reasonable amount of time, or an error.
pub async fn lookup(&mut self, name: &Name) -> Result<HashSet<IpAddr>, ResolveError> {
pub async fn lookup(&self, name: &Name) -> Result<HashSet<IpAddr>, ResolveError> {
let names = self.expand_name(name);
let mut response_stream = self.create_response_stream(names).await;
let cached_values = self.cached_lookup(&names).await;
if !cached_values.is_empty() {
return Ok(cached_values);
}
let mut response_stream = self.create_response_stream(&names).await;
if response_stream.is_empty() {
return Err(ResolveError::NoServersAvailable);
}
let mut first_error = None;
let get_first_by = Instant::now() + self.max_time_to_wait_for_initial;
let got_responses =
get_responses_by(&self.state, get_first_by, name, &mut response_stream).await?;
while let Some(response) = response_stream.next().await {
match response {
Err(e) => {
if first_error.is_none() {
first_error = Some(e);
}
}
if got_responses {
let get_rest_by = Instant::now() + self.time_to_wait_after_first;
let _ = get_responses_by(&self.state, get_rest_by, name, &mut response_stream).await;
}
Ok(response) => {
for answer in response.into_message().take_answers() {
self.handle_response(name, answer).await;
let lingering_deadline = Instant::now() + self.time_to_wait_for_lingering;
let state_copy = self.state.clone();
let name_copy = name.clone();
tokio::task::spawn(async move {
let _ = get_responses_by(
&state_copy,
lingering_deadline,
&name_copy,
&mut response_stream,
)
.await;
});
let after_search_lookup = self.cached_lookup(&names).await;
if after_search_lookup.is_empty() {
Err(ResolveError::NoResponses)
} else {
Ok(after_search_lookup)
}
}
/// Look up any cached IP address we might have for the given names.
///
/// As a side effect, this function will clean out any expired entries in the
/// cache.
async fn cached_lookup(&self, names: &[Name]) -> HashSet<IpAddr> {
let mut retval = HashSet::new();
let mut state = self.state.lock().await;
for name in names {
if let Some(mut existing) = state.cache.remove(name) {
let now = Instant::now();
existing.retain(|item| {
let keeper = item.expires_at > now;
if keeper {
retval.insert(item.address);
}
keeper
});
if !existing.is_empty() {
state.cache.insert(name.clone(), existing);
}
}
}
match first_error {
None => Err(ResolveError::NoResponses),
Some(error) => Err(ResolveError::ResponseError { error }),
}
retval
}
async fn create_response_stream(&mut self, names: Vec<Name>) -> SelectAll<DnsResponseStream> {
let connections = std::mem::take(&mut self.client_connections);
/// Reach out to all of our current connections, and start the process of getting
/// responses.
///
/// If the clients are shut down, we'll try to create new connections to them, or
/// remove them from our connection list if we can't. If there are no connections,
/// the `SelectAll` will be empty, and callers are advised to check for this
/// condition and inform the user that they're out of useful DNS servers.
async fn create_response_stream(&self, names: &[Name]) -> SelectAll<DnsResponseStream> {
let mut state = self.state.lock().await;
let connections = std::mem::take(&mut state.client_connections);
let mut response_stream = futures::stream::SelectAll::new();
for (config, mut client) in connections.into_iter() {
if client.is_shutdown() {
let stream = UdpClientStream::with_bind_addr_and_timeout(
config.address,
config.bind_address.clone(),
config.bind_address,
Duration::from_secs(config.timeout_in_seconds.unwrap_or(3)),
)
.await;
@@ -158,7 +232,7 @@ impl Resolver {
let request = DnsRequest::new(message, Default::default());
response_stream.push(client.send_message(request));
self.client_connections.push((config, client));
state.client_connections.push((config, client));
}
response_stream
@@ -177,85 +251,182 @@ impl Resolver {
names
}
/// Handle an individual response from a server.
/// Run a cleaner task on this Resolver object.
///
/// Returns true if we got an answer from the server, and updated the internal cache.
/// This answer is used externally to set a timer, so that we don't wait the full DNS
/// timeout period for answers to come back.
async fn handle_response(&mut self, query: &Name, record: Record) -> bool {
let ttl = record.ttl();
/// The task will run in the given JoinSet, set it can be easily killed, tracked, waited
/// for as desired. Every `period` interval, it will grab the state lock and clean out
/// any entries that are timed out. Ideally, you shouldn't set this too frequently, so
/// that this process doesn't interfere with other tasks trying to use the resolver.
///
/// This task will only weakly hang on to the resolver state, so that if the Resolver
/// object is dropped it will quietly stop running. It will also stop running if it sees
/// a message on the kill signal, if one is provided, or if one is provided and the other
/// end drops the sender. (So if you don't want a kill signal, send `None`, or you might
/// accidentally kill your GC task too early.)
pub async fn run_resolver_gc(
&self,
task_set: &mut JoinSet<Result<(), ()>>,
mut kill_signal: Option<oneshot::Receiver<()>>,
interval: Duration,
) {
let weak_locked_state = Arc::downgrade(&self.state);
let Some(rdata) = record.into_data() else {
tracing::error!("for some reason, couldn't process incoming message");
return false;
task_set.spawn(async move {
loop {
tokio::time::sleep(interval).await;
let Some(locked_state) = weak_locked_state.upgrade() else {
break Ok(());
};
let mut state = if let Some(kill_signal) = kill_signal.as_mut() {
let locker = locked_state.lock();
futures::pin_mut!(locker);
match select(locker, kill_signal).await {
futures::future::Either::Left((state, _)) => state,
futures::future::Either::Right((_, _)) => break Ok(()),
}
} else {
locked_state.lock().await
};
let now = Instant::now();
state.cache.retain(|_, value| {
value.retain(|resolution| resolution.expires_at < now);
!value.is_empty()
});
drop(state);
}
});
}
/// Inject a mapping into the resolver, with the given TTL.
///
/// This is largely used for testing purposes, but could be used if there are static
/// addresses you want to always resolve in a particular way. For that use case, you
/// probably want to add absurdly-long TTLs for those names.
pub async fn inject_resolution(&self, name: Name, address: IpAddr, ttl: Duration) {
let resolution = DnsResolution {
address,
expires_at: Instant::now() + ttl,
};
let response: IpAddr = match rdata {
RData::A(arec) => arec.0.into(),
RData::AAAA(arec) => arec.0.into(),
_ => {
tracing::warn!(record_type = %rdata.record_type(), "skipping unknown / unexpected record type");
return false;
let mut state = self.state.lock().await;
match state.cache.entry(name) {
std::collections::hash_map::Entry::Vacant(vac) => {
vac.insert(vec![resolution]);
}
};
let expire_at = Instant::now() + Duration::from_secs(ttl as u64);
match self.cache.entry(query.clone()) {
std::collections::hash_map::Entry::Vacant(vec) => {
vec.insert(vec![(expire_at, response.clone())]);
}
std::collections::hash_map::Entry::Occupied(mut occ) => {
clean_expired_entries(occ.get_mut());
occ.get_mut().push((expire_at, response.clone()));
occ.get_mut().push(resolution);
}
}
true
drop(state);
}
}
fn clean_expired_entries(list: &mut Vec<(Instant, IpAddr)>) {
/// Get all the responses into cache that occur before the given timestamp.
///
/// Returns an error if all we get are errors while trying to read from these servers,
/// otherwise returns Ok(true) if we received some interesting responses, or Ok(false)
/// if we didn't. Note that receiving good responses wins out over receiving errors, so
/// if we receive 2 good responses and 3 errors, this function will return Ok(true).
async fn get_responses_by(
state: &Arc<Mutex<ResolverState>>,
deadline: Instant,
name: &Name,
stream: &mut SelectAll<DnsResponseStream>,
) -> Result<bool, ResolveError> {
let mut received_error = None;
let mut got_something = false;
loop {
match timeout_at(deadline, stream.next()).await {
Err(_) => return received_error.unwrap_or(Ok(got_something)),
Ok(None) if got_something => return Ok(got_something),
Ok(None) => return received_error.unwrap_or(Ok(false)),
Ok(Some(Err(error))) => {
if received_error.is_none() {
received_error = Some(Err(ResolveError::ResponseError { error }));
}
}
Ok(Some(Ok(response))) => {
for answer in response.into_message().take_answers() {
got_something |= handle_response(state, name, answer).await;
}
}
}
}
}
/// Handle an individual response from a server.
///
/// Returns true if we got an answer from the server, and updates the internal cache.
/// This function will never return an error, although there may be some odd processing
/// situations that could lead it to printing warnings to the logs.
async fn handle_response(state: &Arc<Mutex<ResolverState>>, query: &Name, record: Record) -> bool {
let ttl = record.ttl();
let Some(rdata) = record.into_data() else {
tracing::error!("for some reason, couldn't process incoming message");
return false;
};
let address: IpAddr = match rdata {
RData::A(arec) => arec.0.into(),
RData::AAAA(arec) => arec.0.into(),
_ => {
tracing::warn!(record_type = %rdata.record_type(), "skipping unknown / unexpected record type");
return false;
}
};
let now = Instant::now();
list.retain(|(expire_at, _)| expire_at > &now);
let expires_at = now + Duration::from_secs(ttl as u64);
let resolution = DnsResolution {
address,
expires_at,
};
let mut state = state.lock().await;
match state.cache.entry(query.clone()) {
std::collections::hash_map::Entry::Vacant(vec) => {
vec.insert(vec![resolution]);
}
std::collections::hash_map::Entry::Occupied(mut occ) => {
let vec = occ.get_mut();
vec.retain(|x| x.expires_at > now);
vec.push(resolution);
}
}
drop(state);
true
}
#[tokio::test]
async fn fetch_cached() {
let resolver = Resolver::new(&DnsConfig::empty()).await.unwrap();
let name = Name::from_utf8("name.foo").unwrap();
let addr = IpAddr::from_str("1.2.4.5").unwrap();
resolver
.inject_resolution(name.clone(), addr.clone(), Duration::from_secs(100000))
.await;
let read = resolver.lookup(&name).await.unwrap();
assert!(read.contains(&addr));
}
#[tokio::test]
async fn uhsure() {
let mut resolver = Resolver::new(&DnsConfig::default()).await.unwrap();
let resolver = Resolver::new(&DnsConfig::default()).await.unwrap();
let name = Name::from_ascii("uhsure.com").unwrap();
let result = resolver.lookup(&name).await.unwrap();
println!("result = {:?}", result);
assert!(!result.is_empty());
}
#[tokio::test]
async fn manual_lookup() {
let mut stream: UdpClientStream<UdpSocket> = UdpClientStream::with_bind_addr_and_timeout(
SocketAddr::V4(std::net::SocketAddrV4::new(
std::net::Ipv4Addr::new(172, 19, 207, 1),
53,
)),
None,
Duration::from_secs(3),
)
.await
.expect("can create UDP DNS client stream");
let mut message = Message::new();
let name = Name::from_ascii("uhsure.com").unwrap();
message.add_query(Query::query(name.clone(), RecordType::A));
message.add_query(Query::query(name.clone(), RecordType::AAAA));
message.set_recursion_desired(true);
let request = DnsRequest::new(message, Default::default());
let mut responses = stream.send_message(request);
while let Some(response) = responses.next().await {
println!("response: {:?}", response);
}
unimplemented!();
}

View File

@@ -48,4 +48,6 @@ pub enum OperationalError {
InvalidHostname(String),
#[error("Unable to parse host address")]
UnableToParseHostAddress,
#[error("Unable to configure resolver")]
Resolver,
}

View File

@@ -14,15 +14,14 @@ pub enum TopLevelError {
}
pub async fn run(mut config: ServerConfiguration) -> error_stack::Result<(), TopLevelError> {
let mut server_state = state::ServerState::default();
let _server_state = state::ServerState::default();
let listeners = config.generate_listener_sockets().await
let listeners = config
.generate_listener_sockets()
.await
.change_context(TopLevelError::ConfigurationError)?;
for (name, listener) in listeners.into_iter() {
}
for (_name, _listener) in listeners.into_iter() {}
Ok(())
}

View File

@@ -10,6 +10,7 @@ pub struct SocketServer {
listener: UnixListener,
}
#[allow(unused)]
impl SocketServer {
/// Create a new server that will handle inputs from the client program.
///
@@ -18,10 +19,10 @@ impl SocketServer {
/// method will take ownership of the object.
pub fn new(name: String, listener: UnixListener) -> Self {
let path = listener
.local_addr()
.map(|x| x.as_pathname().map(|p| format!("{}", p.display())))
.unwrap_or_else(|_| None)
.unwrap_or_else(|| format!("unknown"));
.local_addr()
.map(|x| x.as_pathname().map(|p| format!("{}", p.display())))
.unwrap_or_else(|_| None)
.unwrap_or_else(|| "<unknown>".to_string());
tracing::trace!(%name, %path, "Creating new socket listener");
@@ -41,7 +42,11 @@ impl SocketServer {
/// there, the core task should be unaffected.
pub async fn start(mut self) -> error_stack::Result<(), TopLevelError> {
loop {
let (stream, addr) = self.listener.accept().await.change_context(TopLevelError::SocketHandlerFailure)?;
let (stream, addr) = self
.listener
.accept()
.await
.change_context(TopLevelError::SocketHandlerFailure)?;
let remote_addr = addr
.as_pathname()
.map(|x| x.display())
@@ -58,8 +63,7 @@ impl SocketServer {
self.num_sessions_run += 1;
tokio::task::spawn(Self::run_session(stream)
.instrument(span))
tokio::task::spawn(Self::run_session(stream).instrument(span))
.await
.change_context(TopLevelError::SocketHandlerFailure)?;
}
@@ -70,6 +74,5 @@ impl SocketServer {
/// This is here because it's convenient, not because it shares state (obviously,
/// given the type signature). But it's somewhat logically associated with this type,
/// so it seems reasonable to make it an associated function.
async fn run_session(handle: UnixStream) {
}
async fn run_session(handle: UnixStream) {}
}

View File

@@ -6,3 +6,26 @@ pub struct ServerState {
/// The set of top level tasks that are currently running.
top_level_tasks: JoinSet<Result<(), TopLevelError>>,
}
impl ServerState {
/// Block until all the current top level tasks have closed.
///
/// This will log any errors found in these tasks to the current log,
/// if there is one, but will otherwise drop them.
#[allow(unused)]
pub async fn shutdown(&mut self) {
while let Some(next) = self.top_level_tasks.join_next_with_id().await {
match next {
Err(e) => tracing::error!(id = %e.id(), "Failed to attach to top-level task"),
Ok((id, Err(e))) => {
tracing::error!(%id, "Top-level server error: {}", e);
}
Ok((id, Ok(()))) => {
tracing::debug!(%id, "Cleanly closed server task.");
}
}
}
}
}