Clean ups and some other stuff?
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -15,3 +15,4 @@ Cargo.lock
|
||||
|
||||
tarpaulin-report.html
|
||||
proptest-regressions/
|
||||
config.toml
|
||||
|
||||
@@ -1,37 +1,35 @@
|
||||
use futures::stream::StreamExt;
|
||||
use hickory_proto::udp::UdpStream;
|
||||
use hush::config::server::ServerConfiguration;
|
||||
use std::net::{SocketAddr, SocketAddrV4, Ipv4Addr};
|
||||
use std::process;
|
||||
use thiserror::Error;
|
||||
use tokio::net::UdpSocket;
|
||||
|
||||
async fn run_dns() -> tokio::io::Result<()> {
|
||||
let socket = UdpSocket::bind("127.0.0.1:3553").await?;
|
||||
let mut buffer = [0; 65536];
|
||||
|
||||
tracing::info!("Bound socket at {}, starting main DNS handler loop.", socket.local_addr()?);
|
||||
let (first_len, first_addr) = socket.recv_from(&mut buffer).await?;
|
||||
println!("Received {} bytes from {}", first_len, first_addr);
|
||||
|
||||
let remote = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 3553));
|
||||
let (mut stream, _handle) = UdpStream::with_bound(socket, remote);
|
||||
|
||||
while let Some(item) = stream.next().await {
|
||||
match item {
|
||||
Err(e) => eprintln!("Got an error: {:?}", e),
|
||||
Ok(msg) => {
|
||||
println!("Got a message from {}", msg.addr());
|
||||
match msg.to_message() {
|
||||
Err(e) => println!(" ... but it was malformed ({e})."),
|
||||
Ok(v) => println!(" ... and it was {v}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
//async fn run_dns() -> tokio::io::Result<()> {
|
||||
// let socket = UdpSocket::bind("127.0.0.1:3553").await?;
|
||||
// let mut buffer = [0; 65536];
|
||||
//
|
||||
// tracing::info!(
|
||||
// "Bound socket at {}, starting main DNS handler loop.",
|
||||
// socket.local_addr()?
|
||||
// );
|
||||
// let (first_len, first_addr) = socket.recv_from(&mut buffer).await?;
|
||||
// println!("Received {} bytes from {}", first_len, first_addr);
|
||||
//
|
||||
// let remote = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 3553));
|
||||
// let (mut stream, _handle) = UdpStream::with_bound(socket, remote);
|
||||
//
|
||||
// while let Some(item) = stream.next().await {
|
||||
// match item {
|
||||
// Err(e) => eprintln!("Got an error: {:?}", e),
|
||||
// Ok(msg) => {
|
||||
// println!("Got a message from {}", msg.addr());
|
||||
// match msg.to_message() {
|
||||
// Err(e) => println!(" ... but it was malformed ({e})."),
|
||||
// Ok(v) => println!(" ... and it was {v}"),
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// Ok(())
|
||||
//}
|
||||
|
||||
#[cfg(not(tarpaulin_include))]
|
||||
fn main() {
|
||||
@@ -57,7 +55,8 @@ fn main() {
|
||||
};
|
||||
|
||||
let result = runtime.block_on(async move {
|
||||
let version = std::env::var("CARGO_PKG_VERSION").unwrap_or_else(|_| "<unknown>".to_string());
|
||||
let version =
|
||||
std::env::var("CARGO_PKG_VERSION").unwrap_or_else(|_| "<unknown>".to_string());
|
||||
tracing::info!(%version, "Starting Hush server (hushd)");
|
||||
hush::server::run(config).await
|
||||
});
|
||||
|
||||
@@ -95,15 +95,17 @@ impl FromStr for Target {
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unreachable_code,unused_variables)]
|
||||
#[allow(unreachable_code, unused_variables)]
|
||||
async fn connect(
|
||||
base_config: &ClientConfiguration,
|
||||
target: &str,
|
||||
) -> error_stack::Result<(), OperationalError> {
|
||||
let resolver: Resolver = unimplemented!();
|
||||
// let mut resolver = Resolver::new(&base_config.resolver)
|
||||
// .await
|
||||
// .change_context(OperationalError::DnsConfig)?;
|
||||
let mut resolver: Resolver = Resolver::new(&base_config.dns_config)
|
||||
.await
|
||||
.change_context(OperationalError::Resolver)?;
|
||||
// let mut resolver = Resolver::new(&base_config.resolver)
|
||||
// .await
|
||||
// .change_context(OperationalError::DnsConfig)?;
|
||||
let target = Target::from_str(target)
|
||||
.change_context(OperationalError::UnableToParseHostAddress)
|
||||
.attach_printable_lazy(|| format!("target address '{}'", target))?;
|
||||
@@ -137,7 +139,7 @@ async fn connect(
|
||||
.change_context(OperationalError::Connection)?;
|
||||
tracing::trace!("received their preamble");
|
||||
|
||||
let _our_preamble = ssh::Preamble::default()
|
||||
ssh::Preamble::default()
|
||||
.write(&mut stream)
|
||||
.instrument(stream_span)
|
||||
.await
|
||||
|
||||
@@ -9,9 +9,9 @@ pub mod resolver;
|
||||
mod runtime;
|
||||
pub mod server;
|
||||
|
||||
use tokio::runtime::Runtime;
|
||||
use tracing_subscriber::layer::SubscriberExt;
|
||||
use tracing_subscriber::prelude::*;
|
||||
use tokio::runtime::Runtime;
|
||||
|
||||
//impl ClientConfiguration {
|
||||
// pub async fn try_from(
|
||||
@@ -84,4 +84,3 @@ pub fn configured_runtime(
|
||||
.worker_threads(runtime.tokio_worker_threads)
|
||||
.build()
|
||||
}
|
||||
|
||||
|
||||
@@ -1,17 +1,19 @@
|
||||
use clap::Parser;
|
||||
use crate::config::command_line::ClientArguments;
|
||||
pub use crate::config::command_line::ClientCommand;
|
||||
use crate::config::config_file::ConfigFile;
|
||||
use crate::config::error::ConfigurationError;
|
||||
use crate::config::logging::LoggingConfiguration;
|
||||
use crate::config::resolver::DnsConfig;
|
||||
use crate::config::runtime::RuntimeConfiguration;
|
||||
use crate::config::command_line::ClientArguments;
|
||||
pub use crate::config::command_line::ClientCommand;
|
||||
use clap::Parser;
|
||||
use std::ffi::OsString;
|
||||
use tracing_core::LevelFilter;
|
||||
|
||||
pub struct ClientConfiguration {
|
||||
pub runtime: RuntimeConfiguration,
|
||||
pub logging: LoggingConfiguration,
|
||||
config_file: Option<ConfigFile>,
|
||||
pub dns_config: DnsConfig,
|
||||
_config_file: Option<ConfigFile>,
|
||||
command: ClientCommand,
|
||||
}
|
||||
|
||||
@@ -30,25 +32,32 @@ impl ClientConfiguration {
|
||||
T: Into<OsString> + Clone,
|
||||
{
|
||||
let mut command_line_arguments = ClientArguments::try_parse_from(args)?;
|
||||
let mut config_file = ConfigFile::new(command_line_arguments.arguments.config_file.take())?;
|
||||
let mut _config_file =
|
||||
ConfigFile::new(command_line_arguments.arguments.config_file.take())?;
|
||||
let mut runtime = RuntimeConfiguration::default();
|
||||
let mut logging = LoggingConfiguration::default();
|
||||
|
||||
// we prefer the command line to the config file, so first merge
|
||||
// in the config file so that when we later merge the command line,
|
||||
// it overwrites any config file options.
|
||||
if let Some(config_file) = config_file.as_mut() {
|
||||
if let Some(config_file) = _config_file.as_mut() {
|
||||
config_file.merge_standard_options_into(&mut runtime, &mut logging);
|
||||
}
|
||||
command_line_arguments.arguments.merge_standard_options_into(&mut runtime, &mut logging);
|
||||
command_line_arguments
|
||||
.arguments
|
||||
.merge_standard_options_into(&mut runtime, &mut logging);
|
||||
if command_line_arguments.command.is_list_command() {
|
||||
logging.filter = LevelFilter::ERROR;
|
||||
}
|
||||
|
||||
// FIXME!!
|
||||
let dns_config = DnsConfig::default();
|
||||
|
||||
Ok(ClientConfiguration {
|
||||
runtime,
|
||||
logging,
|
||||
config_file,
|
||||
dns_config,
|
||||
_config_file,
|
||||
command: command_line_arguments.command,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::config::logging::{LogTarget, LoggingConfiguration};
|
||||
use crate::config::console::ConsoleConfiguration;
|
||||
use crate::config::logging::{LogTarget, LoggingConfiguration};
|
||||
use crate::config::runtime::RuntimeConfiguration;
|
||||
use clap::{Args, Parser, Subcommand};
|
||||
use std::net::SocketAddr;
|
||||
|
||||
@@ -43,7 +43,7 @@ impl Arbitrary for ConfigFile {
|
||||
keyed_section(ServerConfig::arbitrary()),
|
||||
)
|
||||
.prop_map(
|
||||
|(runtime, logging, resolver,sockets, keys, defaults, servers)| ConfigFile {
|
||||
|(runtime, logging, resolver, sockets, keys, defaults, servers)| ConfigFile {
|
||||
runtime,
|
||||
logging,
|
||||
resolver,
|
||||
@@ -88,40 +88,41 @@ pub enum Permissions {
|
||||
impl From<Permissions> for rustix::fs::Mode {
|
||||
fn from(value: Permissions) -> Self {
|
||||
match value {
|
||||
Permissions::User =>
|
||||
rustix::fs::Mode::RUSR |
|
||||
rustix::fs::Mode::WUSR,
|
||||
Permissions::User => rustix::fs::Mode::RUSR | rustix::fs::Mode::WUSR,
|
||||
|
||||
Permissions::Group =>
|
||||
rustix::fs::Mode::RUSR |
|
||||
rustix::fs::Mode::WUSR,
|
||||
Permissions::Group => rustix::fs::Mode::RUSR | rustix::fs::Mode::WUSR,
|
||||
|
||||
Permissions::UserGroup =>
|
||||
rustix::fs::Mode::RUSR |
|
||||
rustix::fs::Mode::WUSR |
|
||||
rustix::fs::Mode::RGRP |
|
||||
rustix::fs::Mode::WGRP,
|
||||
Permissions::UserGroup => {
|
||||
rustix::fs::Mode::RUSR
|
||||
| rustix::fs::Mode::WUSR
|
||||
| rustix::fs::Mode::RGRP
|
||||
| rustix::fs::Mode::WGRP
|
||||
}
|
||||
|
||||
Permissions::Everyone =>
|
||||
rustix::fs::Mode::RUSR |
|
||||
rustix::fs::Mode::WUSR |
|
||||
rustix::fs::Mode::RGRP |
|
||||
rustix::fs::Mode::WGRP |
|
||||
rustix::fs::Mode::ROTH |
|
||||
rustix::fs::Mode::WOTH,
|
||||
Permissions::Everyone => {
|
||||
rustix::fs::Mode::RUSR
|
||||
| rustix::fs::Mode::WUSR
|
||||
| rustix::fs::Mode::RGRP
|
||||
| rustix::fs::Mode::WGRP
|
||||
| rustix::fs::Mode::ROTH
|
||||
| rustix::fs::Mode::WOTH
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SocketConfig {
|
||||
pub async fn into_listener(self) -> Result<tokio::net::UnixListener, ConfigurationError> {
|
||||
let base = tokio::net::UnixListener::bind(&self.path)
|
||||
.map_err(|error| ConfigurationError::CouldNotMakeSocket {
|
||||
let base = tokio::net::UnixListener::bind(&self.path).map_err(|error| {
|
||||
ConfigurationError::CouldNotMakeSocket {
|
||||
path: self.path.clone(),
|
||||
error,
|
||||
}
|
||||
})?;
|
||||
|
||||
let user = self.user.map(|x| {
|
||||
let user = self
|
||||
.user
|
||||
.map(|x| {
|
||||
nix::unistd::User::from_name(&x)
|
||||
.map_err(|error| ConfigurationError::CouldNotSetPerms {
|
||||
thing: "find user".to_string(),
|
||||
@@ -129,14 +130,22 @@ impl SocketConfig {
|
||||
error: error.into(),
|
||||
})
|
||||
.transpose()
|
||||
.unwrap_or_else(|| Err(ConfigurationError::CouldNotSetPerms {
|
||||
.unwrap_or_else(|| {
|
||||
Err(ConfigurationError::CouldNotSetPerms {
|
||||
thing: "find user".to_string(),
|
||||
path: self.path.clone(),
|
||||
error: std::io::Error::new(std::io::ErrorKind::NotFound, "could not find user"),
|
||||
}))
|
||||
}).transpose()?;
|
||||
error: std::io::Error::new(
|
||||
std::io::ErrorKind::NotFound,
|
||||
"could not find user",
|
||||
),
|
||||
})
|
||||
})
|
||||
})
|
||||
.transpose()?;
|
||||
|
||||
let group = self.group.map(|x| {
|
||||
let group = self
|
||||
.group
|
||||
.map(|x| {
|
||||
nix::unistd::Group::from_name(&x)
|
||||
.map_err(|error| ConfigurationError::CouldNotSetPerms {
|
||||
thing: "find user".to_string(),
|
||||
@@ -144,33 +153,41 @@ impl SocketConfig {
|
||||
error: error.into(),
|
||||
})
|
||||
.transpose()
|
||||
.unwrap_or_else(|| Err(ConfigurationError::CouldNotSetPerms {
|
||||
.unwrap_or_else(|| {
|
||||
Err(ConfigurationError::CouldNotSetPerms {
|
||||
thing: "find user".to_string(),
|
||||
path: self.path.clone(),
|
||||
error: std::io::Error::new(std::io::ErrorKind::NotFound, "could not find user"),
|
||||
}))
|
||||
}).transpose()?;
|
||||
error: std::io::Error::new(
|
||||
std::io::ErrorKind::NotFound,
|
||||
"could not find user",
|
||||
),
|
||||
})
|
||||
})
|
||||
})
|
||||
.transpose()?;
|
||||
|
||||
if user.is_some() || group.is_some() {
|
||||
std::os::unix::fs::chown(&self.path, user.map(|x| x.uid.as_raw()), group.map(|x| x.gid.as_raw()))
|
||||
.map_err(|error|
|
||||
ConfigurationError::CouldNotSetPerms {
|
||||
std::os::unix::fs::chown(
|
||||
&self.path,
|
||||
user.map(|x| x.uid.as_raw()),
|
||||
group.map(|x| x.gid.as_raw()),
|
||||
)
|
||||
.map_err(|error| ConfigurationError::CouldNotSetPerms {
|
||||
thing: "set user/group ownership".to_string(),
|
||||
path: self.path.clone(),
|
||||
error,
|
||||
}
|
||||
)?;
|
||||
})?;
|
||||
}
|
||||
|
||||
let perms = self.permissions.unwrap_or(Permissions::UserGroup);
|
||||
rustix::fs::chmod(&self.path, perms.into())
|
||||
.map_err(|error| ConfigurationError::CouldNotSetPerms {
|
||||
rustix::fs::chmod(&self.path, perms.into()).map_err(|error| {
|
||||
ConfigurationError::CouldNotSetPerms {
|
||||
thing: "set permissions".to_string(),
|
||||
path: self.path.clone(),
|
||||
error: error.into(),
|
||||
}
|
||||
})?;
|
||||
|
||||
|
||||
Ok(base)
|
||||
}
|
||||
}
|
||||
@@ -180,12 +197,18 @@ impl Arbitrary for SocketConfig {
|
||||
type Parameters = ();
|
||||
|
||||
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
|
||||
(PathBuf::arbitrary(),
|
||||
(
|
||||
PathBuf::arbitrary(),
|
||||
proptest::option::of(proptest::string::string_regex("[a-zA-Z0-9]{1,30}").unwrap()),
|
||||
proptest::option::of(proptest::string::string_regex("[a-zA-Z0-9]{1,30}").unwrap()),
|
||||
proptest::option::of(Permissions::arbitrary()))
|
||||
.prop_map(|(path, user, group, permissions)|
|
||||
SocketConfig{ path, user, group, permissions })
|
||||
proptest::option::of(Permissions::arbitrary()),
|
||||
)
|
||||
.prop_map(|(path, user, group, permissions)| SocketConfig {
|
||||
path,
|
||||
user,
|
||||
group,
|
||||
permissions,
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
@@ -200,7 +223,8 @@ impl Arbitrary for Permissions {
|
||||
proptest::strategy::Just(Permissions::Group),
|
||||
proptest::strategy::Just(Permissions::UserGroup),
|
||||
proptest::strategy::Just(Permissions::Everyone),
|
||||
].boxed()
|
||||
]
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ use proptest::strategy::{BoxedStrategy, Strategy};
|
||||
use std::collections::HashSet;
|
||||
use std::str::FromStr;
|
||||
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ClientConnectionOpts {
|
||||
pub key_exchange_algorithms: Vec<KeyExchangeAlgorithm>,
|
||||
@@ -68,10 +67,12 @@ impl Arbitrary for ClientConnectionOpts {
|
||||
}
|
||||
}
|
||||
|
||||
fn finalize_options<T,E>(values: HashSet<&str>) -> Result<Vec<T>, E>
|
||||
where T: FromStr<Err=E>
|
||||
fn finalize_options<T, E>(values: HashSet<&str>) -> Result<Vec<T>, E>
|
||||
where
|
||||
T: FromStr<Err = E>,
|
||||
{
|
||||
Ok(values.into_iter().map(T::from_str).collect::<Result<Vec<T>,E>>()?)
|
||||
values
|
||||
.into_iter()
|
||||
.map(T::from_str)
|
||||
.collect::<Result<Vec<T>, E>>()
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -42,6 +42,24 @@ impl Default for DnsConfig {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl DnsConfig {
|
||||
pub(crate) fn empty() -> Self {
|
||||
DnsConfig {
|
||||
built_in: None,
|
||||
local_domain: None,
|
||||
search_domains: vec![],
|
||||
name_servers: vec![],
|
||||
retry_attempts: None,
|
||||
cache_size: None,
|
||||
max_concurrent_requests_for_query: None,
|
||||
preserve_intermediates: None,
|
||||
shuffle_dns_servers: None,
|
||||
allow_mdns: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Arbitrary for DnsConfig {
|
||||
type Parameters = bool;
|
||||
type Strategy = BoxedStrategy<Self>;
|
||||
@@ -178,12 +196,16 @@ impl Arbitrary for NameServerConfig {
|
||||
proptest::option::of(u64::arbitrary()),
|
||||
proptest::option::of(SocketAddr::arbitrary()),
|
||||
)
|
||||
.prop_map(|(mut address, timeout_in_seconds, mut bind_address)| {
|
||||
.prop_map(|(mut address, mut timeout_in_seconds, mut bind_address)| {
|
||||
clear_flow_and_scope_info(&mut address);
|
||||
if let Some(bind_address) = bind_address.as_mut() {
|
||||
clear_flow_and_scope_info(bind_address);
|
||||
}
|
||||
|
||||
if let Some(timeout_in_seconds) = timeout_in_seconds.as_mut() {
|
||||
*timeout_in_seconds &= 0x7FFF_FFFF_FFFF_FFFF;
|
||||
}
|
||||
|
||||
NameServerConfig {
|
||||
address,
|
||||
timeout_in_seconds,
|
||||
|
||||
@@ -15,5 +15,3 @@ impl Default for RuntimeConfiguration {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
use clap::Parser;
|
||||
use crate::config::config_file::{ConfigFile, SocketConfig};
|
||||
use crate::config::command_line::ServerArguments;
|
||||
use crate::config::config_file::{ConfigFile, SocketConfig};
|
||||
use crate::config::error::ConfigurationError;
|
||||
use crate::config::logging::LoggingConfiguration;
|
||||
use crate::config::runtime::RuntimeConfiguration;
|
||||
use clap::Parser;
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::OsString;
|
||||
use tokio::net::UnixListener;
|
||||
@@ -55,7 +55,9 @@ impl ServerConfiguration {
|
||||
}
|
||||
|
||||
/// Generate a series of UNIX sockets that clients will attempt to connect to.
|
||||
pub async fn generate_listener_sockets(&mut self) -> Result<HashMap<String, UnixListener>, ConfigurationError> {
|
||||
pub async fn generate_listener_sockets(
|
||||
&mut self,
|
||||
) -> Result<HashMap<String, UnixListener>, ConfigurationError> {
|
||||
let mut results = HashMap::new();
|
||||
|
||||
for (name, config) in std::mem::take(&mut self.sockets).into_iter() {
|
||||
|
||||
@@ -49,10 +49,25 @@ impl FromStr for Host {
|
||||
return Ok(Host::IPv6(addr));
|
||||
}
|
||||
|
||||
if let Some(prefix_removed) = s.strip_prefix('[') {
|
||||
if let Some(cleaned) = prefix_removed.strip_suffix(']') {
|
||||
match Ipv6Addr::from_str(cleaned) {
|
||||
Ok(addr) => return Ok(Host::IPv6(addr)),
|
||||
Err(error) => {
|
||||
return Err(HostParseError::CouldNotParseIPv6 {
|
||||
address: s.to_string(),
|
||||
error,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Ok(name) = Name::from_utf8(s) {
|
||||
return Ok(Host::Hostname(name));
|
||||
}
|
||||
|
||||
println!(" ... not a hostname");
|
||||
Err(HostParseError::InvalidHostname {
|
||||
hostname: s.to_string(),
|
||||
})
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use crate::config::resolver::{DnsConfig, NameServerConfig};
|
||||
use error_stack::report;
|
||||
use futures::future::select;
|
||||
use futures::stream::{SelectAll, StreamExt};
|
||||
use hickory_client::rr::Name;
|
||||
use hickory_proto::error::ProtoError;
|
||||
@@ -13,9 +14,14 @@ use hickory_proto::xfer::dns_request::DnsRequest;
|
||||
use hickory_proto::xfer::{DnsRequestSender, DnsResponseStream};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
#[cfg(test)]
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use thiserror::Error;
|
||||
use tokio::net::UdpSocket;
|
||||
use tokio::time::{Duration, Instant};
|
||||
use tokio::sync::{oneshot, Mutex};
|
||||
use tokio::task::JoinSet;
|
||||
use tokio::time::{timeout_at, Duration, Instant};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ResolverConfigError {
|
||||
@@ -44,8 +50,20 @@ pub enum ResolveError {
|
||||
|
||||
pub struct Resolver {
|
||||
search_domains: Vec<Name>,
|
||||
max_time_to_wait_for_initial: Duration,
|
||||
time_to_wait_after_first: Duration,
|
||||
time_to_wait_for_lingering: Duration,
|
||||
state: Arc<Mutex<ResolverState>>,
|
||||
}
|
||||
|
||||
pub struct ResolverState {
|
||||
client_connections: Vec<(NameServerConfig, UdpClientStream<UdpSocket>)>,
|
||||
cache: HashMap<Name, Vec<(Instant, IpAddr)>>,
|
||||
cache: HashMap<Name, Vec<DnsResolution>>,
|
||||
}
|
||||
|
||||
pub struct DnsResolution {
|
||||
address: IpAddr,
|
||||
expires_at: Instant,
|
||||
}
|
||||
|
||||
impl Resolver {
|
||||
@@ -76,7 +94,7 @@ impl Resolver {
|
||||
for name_server_config in config.name_servers() {
|
||||
let stream = UdpClientStream::with_bind_addr_and_timeout(
|
||||
name_server_config.address,
|
||||
name_server_config.bind_address.clone(),
|
||||
name_server_config.bind_address,
|
||||
Duration::from_secs(name_server_config.timeout_in_seconds.unwrap_or(3)),
|
||||
)
|
||||
.await
|
||||
@@ -90,54 +108,110 @@ impl Resolver {
|
||||
|
||||
Ok(Resolver {
|
||||
search_domains,
|
||||
// FIXME: All of these should be configurable
|
||||
max_time_to_wait_for_initial: Duration::from_millis(150),
|
||||
time_to_wait_after_first: Duration::from_millis(50),
|
||||
time_to_wait_for_lingering: Duration::from_secs(2),
|
||||
state: Arc::new(Mutex::new(ResolverState {
|
||||
client_connections,
|
||||
cache: HashMap::new(),
|
||||
})),
|
||||
})
|
||||
}
|
||||
|
||||
/// Look up the address of the given name, returning either a set of results
|
||||
/// we've received in a reasonable amount of time, or an error.
|
||||
pub async fn lookup(&mut self, name: &Name) -> Result<HashSet<IpAddr>, ResolveError> {
|
||||
pub async fn lookup(&self, name: &Name) -> Result<HashSet<IpAddr>, ResolveError> {
|
||||
let names = self.expand_name(name);
|
||||
let mut response_stream = self.create_response_stream(names).await;
|
||||
|
||||
let cached_values = self.cached_lookup(&names).await;
|
||||
if !cached_values.is_empty() {
|
||||
return Ok(cached_values);
|
||||
}
|
||||
|
||||
let mut response_stream = self.create_response_stream(&names).await;
|
||||
if response_stream.is_empty() {
|
||||
return Err(ResolveError::NoServersAvailable);
|
||||
}
|
||||
|
||||
let mut first_error = None;
|
||||
let get_first_by = Instant::now() + self.max_time_to_wait_for_initial;
|
||||
let got_responses =
|
||||
get_responses_by(&self.state, get_first_by, name, &mut response_stream).await?;
|
||||
|
||||
while let Some(response) = response_stream.next().await {
|
||||
match response {
|
||||
Err(e) => {
|
||||
if first_error.is_none() {
|
||||
first_error = Some(e);
|
||||
if got_responses {
|
||||
let get_rest_by = Instant::now() + self.time_to_wait_after_first;
|
||||
let _ = get_responses_by(&self.state, get_rest_by, name, &mut response_stream).await;
|
||||
}
|
||||
|
||||
let lingering_deadline = Instant::now() + self.time_to_wait_for_lingering;
|
||||
let state_copy = self.state.clone();
|
||||
let name_copy = name.clone();
|
||||
tokio::task::spawn(async move {
|
||||
let _ = get_responses_by(
|
||||
&state_copy,
|
||||
lingering_deadline,
|
||||
&name_copy,
|
||||
&mut response_stream,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
let after_search_lookup = self.cached_lookup(&names).await;
|
||||
if after_search_lookup.is_empty() {
|
||||
Err(ResolveError::NoResponses)
|
||||
} else {
|
||||
Ok(after_search_lookup)
|
||||
}
|
||||
}
|
||||
|
||||
Ok(response) => {
|
||||
for answer in response.into_message().take_answers() {
|
||||
self.handle_response(name, answer).await;
|
||||
/// Look up any cached IP address we might have for the given names.
|
||||
///
|
||||
/// As a side effect, this function will clean out any expired entries in the
|
||||
/// cache.
|
||||
async fn cached_lookup(&self, names: &[Name]) -> HashSet<IpAddr> {
|
||||
let mut retval = HashSet::new();
|
||||
let mut state = self.state.lock().await;
|
||||
|
||||
for name in names {
|
||||
if let Some(mut existing) = state.cache.remove(name) {
|
||||
let now = Instant::now();
|
||||
|
||||
existing.retain(|item| {
|
||||
let keeper = item.expires_at > now;
|
||||
|
||||
if keeper {
|
||||
retval.insert(item.address);
|
||||
}
|
||||
|
||||
keeper
|
||||
});
|
||||
|
||||
if !existing.is_empty() {
|
||||
state.cache.insert(name.clone(), existing);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match first_error {
|
||||
None => Err(ResolveError::NoResponses),
|
||||
Some(error) => Err(ResolveError::ResponseError { error }),
|
||||
}
|
||||
retval
|
||||
}
|
||||
|
||||
async fn create_response_stream(&mut self, names: Vec<Name>) -> SelectAll<DnsResponseStream> {
|
||||
let connections = std::mem::take(&mut self.client_connections);
|
||||
/// Reach out to all of our current connections, and start the process of getting
|
||||
/// responses.
|
||||
///
|
||||
/// If the clients are shut down, we'll try to create new connections to them, or
|
||||
/// remove them from our connection list if we can't. If there are no connections,
|
||||
/// the `SelectAll` will be empty, and callers are advised to check for this
|
||||
/// condition and inform the user that they're out of useful DNS servers.
|
||||
async fn create_response_stream(&self, names: &[Name]) -> SelectAll<DnsResponseStream> {
|
||||
let mut state = self.state.lock().await;
|
||||
let connections = std::mem::take(&mut state.client_connections);
|
||||
let mut response_stream = futures::stream::SelectAll::new();
|
||||
|
||||
for (config, mut client) in connections.into_iter() {
|
||||
if client.is_shutdown() {
|
||||
let stream = UdpClientStream::with_bind_addr_and_timeout(
|
||||
config.address,
|
||||
config.bind_address.clone(),
|
||||
config.bind_address,
|
||||
Duration::from_secs(config.timeout_in_seconds.unwrap_or(3)),
|
||||
)
|
||||
.await;
|
||||
@@ -158,7 +232,7 @@ impl Resolver {
|
||||
|
||||
let request = DnsRequest::new(message, Default::default());
|
||||
response_stream.push(client.send_message(request));
|
||||
self.client_connections.push((config, client));
|
||||
state.client_connections.push((config, client));
|
||||
}
|
||||
|
||||
response_stream
|
||||
@@ -177,12 +251,123 @@ impl Resolver {
|
||||
names
|
||||
}
|
||||
|
||||
/// Handle an individual response from a server.
|
||||
/// Run a cleaner task on this Resolver object.
|
||||
///
|
||||
/// Returns true if we got an answer from the server, and updated the internal cache.
|
||||
/// This answer is used externally to set a timer, so that we don't wait the full DNS
|
||||
/// timeout period for answers to come back.
|
||||
async fn handle_response(&mut self, query: &Name, record: Record) -> bool {
|
||||
/// The task will run in the given JoinSet, set it can be easily killed, tracked, waited
|
||||
/// for as desired. Every `period` interval, it will grab the state lock and clean out
|
||||
/// any entries that are timed out. Ideally, you shouldn't set this too frequently, so
|
||||
/// that this process doesn't interfere with other tasks trying to use the resolver.
|
||||
///
|
||||
/// This task will only weakly hang on to the resolver state, so that if the Resolver
|
||||
/// object is dropped it will quietly stop running. It will also stop running if it sees
|
||||
/// a message on the kill signal, if one is provided, or if one is provided and the other
|
||||
/// end drops the sender. (So if you don't want a kill signal, send `None`, or you might
|
||||
/// accidentally kill your GC task too early.)
|
||||
pub async fn run_resolver_gc(
|
||||
&self,
|
||||
task_set: &mut JoinSet<Result<(), ()>>,
|
||||
mut kill_signal: Option<oneshot::Receiver<()>>,
|
||||
interval: Duration,
|
||||
) {
|
||||
let weak_locked_state = Arc::downgrade(&self.state);
|
||||
|
||||
task_set.spawn(async move {
|
||||
loop {
|
||||
tokio::time::sleep(interval).await;
|
||||
|
||||
let Some(locked_state) = weak_locked_state.upgrade() else {
|
||||
break Ok(());
|
||||
};
|
||||
|
||||
let mut state = if let Some(kill_signal) = kill_signal.as_mut() {
|
||||
let locker = locked_state.lock();
|
||||
|
||||
futures::pin_mut!(locker);
|
||||
match select(locker, kill_signal).await {
|
||||
futures::future::Either::Left((state, _)) => state,
|
||||
futures::future::Either::Right((_, _)) => break Ok(()),
|
||||
}
|
||||
} else {
|
||||
locked_state.lock().await
|
||||
};
|
||||
|
||||
let now = Instant::now();
|
||||
state.cache.retain(|_, value| {
|
||||
value.retain(|resolution| resolution.expires_at < now);
|
||||
!value.is_empty()
|
||||
});
|
||||
|
||||
drop(state);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Inject a mapping into the resolver, with the given TTL.
|
||||
///
|
||||
/// This is largely used for testing purposes, but could be used if there are static
|
||||
/// addresses you want to always resolve in a particular way. For that use case, you
|
||||
/// probably want to add absurdly-long TTLs for those names.
|
||||
pub async fn inject_resolution(&self, name: Name, address: IpAddr, ttl: Duration) {
|
||||
let resolution = DnsResolution {
|
||||
address,
|
||||
expires_at: Instant::now() + ttl,
|
||||
};
|
||||
|
||||
let mut state = self.state.lock().await;
|
||||
match state.cache.entry(name) {
|
||||
std::collections::hash_map::Entry::Vacant(vac) => {
|
||||
vac.insert(vec![resolution]);
|
||||
}
|
||||
std::collections::hash_map::Entry::Occupied(mut occ) => {
|
||||
occ.get_mut().push(resolution);
|
||||
}
|
||||
}
|
||||
drop(state);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get all the responses into cache that occur before the given timestamp.
|
||||
///
|
||||
/// Returns an error if all we get are errors while trying to read from these servers,
|
||||
/// otherwise returns Ok(true) if we received some interesting responses, or Ok(false)
|
||||
/// if we didn't. Note that receiving good responses wins out over receiving errors, so
|
||||
/// if we receive 2 good responses and 3 errors, this function will return Ok(true).
|
||||
async fn get_responses_by(
|
||||
state: &Arc<Mutex<ResolverState>>,
|
||||
deadline: Instant,
|
||||
name: &Name,
|
||||
stream: &mut SelectAll<DnsResponseStream>,
|
||||
) -> Result<bool, ResolveError> {
|
||||
let mut received_error = None;
|
||||
let mut got_something = false;
|
||||
|
||||
loop {
|
||||
match timeout_at(deadline, stream.next()).await {
|
||||
Err(_) => return received_error.unwrap_or(Ok(got_something)),
|
||||
|
||||
Ok(None) if got_something => return Ok(got_something),
|
||||
Ok(None) => return received_error.unwrap_or(Ok(false)),
|
||||
|
||||
Ok(Some(Err(error))) => {
|
||||
if received_error.is_none() {
|
||||
received_error = Some(Err(ResolveError::ResponseError { error }));
|
||||
}
|
||||
}
|
||||
Ok(Some(Ok(response))) => {
|
||||
for answer in response.into_message().take_answers() {
|
||||
got_something |= handle_response(state, name, answer).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle an individual response from a server.
|
||||
///
|
||||
/// Returns true if we got an answer from the server, and updates the internal cache.
|
||||
/// This function will never return an error, although there may be some odd processing
|
||||
/// situations that could lead it to printing warnings to the logs.
|
||||
async fn handle_response(state: &Arc<Mutex<ResolverState>>, query: &Name, record: Record) -> bool {
|
||||
let ttl = record.ttl();
|
||||
|
||||
let Some(rdata) = record.into_data() else {
|
||||
@@ -190,7 +375,7 @@ impl Resolver {
|
||||
return false;
|
||||
};
|
||||
|
||||
let response: IpAddr = match rdata {
|
||||
let address: IpAddr = match rdata {
|
||||
RData::A(arec) => arec.0.into(),
|
||||
RData::AAAA(arec) => arec.0.into(),
|
||||
|
||||
@@ -200,62 +385,48 @@ impl Resolver {
|
||||
}
|
||||
};
|
||||
|
||||
let expire_at = Instant::now() + Duration::from_secs(ttl as u64);
|
||||
let now = Instant::now();
|
||||
let expires_at = now + Duration::from_secs(ttl as u64);
|
||||
let resolution = DnsResolution {
|
||||
address,
|
||||
expires_at,
|
||||
};
|
||||
|
||||
match self.cache.entry(query.clone()) {
|
||||
let mut state = state.lock().await;
|
||||
match state.cache.entry(query.clone()) {
|
||||
std::collections::hash_map::Entry::Vacant(vec) => {
|
||||
vec.insert(vec![(expire_at, response.clone())]);
|
||||
vec.insert(vec![resolution]);
|
||||
}
|
||||
|
||||
std::collections::hash_map::Entry::Occupied(mut occ) => {
|
||||
clean_expired_entries(occ.get_mut());
|
||||
occ.get_mut().push((expire_at, response.clone()));
|
||||
let vec = occ.get_mut();
|
||||
vec.retain(|x| x.expires_at > now);
|
||||
vec.push(resolution);
|
||||
}
|
||||
}
|
||||
drop(state);
|
||||
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
fn clean_expired_entries(list: &mut Vec<(Instant, IpAddr)>) {
|
||||
let now = Instant::now();
|
||||
list.retain(|(expire_at, _)| expire_at > &now);
|
||||
#[tokio::test]
|
||||
async fn fetch_cached() {
|
||||
let resolver = Resolver::new(&DnsConfig::empty()).await.unwrap();
|
||||
let name = Name::from_utf8("name.foo").unwrap();
|
||||
let addr = IpAddr::from_str("1.2.4.5").unwrap();
|
||||
|
||||
resolver
|
||||
.inject_resolution(name.clone(), addr.clone(), Duration::from_secs(100000))
|
||||
.await;
|
||||
let read = resolver.lookup(&name).await.unwrap();
|
||||
assert!(read.contains(&addr));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn uhsure() {
|
||||
let mut resolver = Resolver::new(&DnsConfig::default()).await.unwrap();
|
||||
let resolver = Resolver::new(&DnsConfig::default()).await.unwrap();
|
||||
let name = Name::from_ascii("uhsure.com").unwrap();
|
||||
let result = resolver.lookup(&name).await.unwrap();
|
||||
println!("result = {:?}", result);
|
||||
assert!(!result.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn manual_lookup() {
|
||||
let mut stream: UdpClientStream<UdpSocket> = UdpClientStream::with_bind_addr_and_timeout(
|
||||
SocketAddr::V4(std::net::SocketAddrV4::new(
|
||||
std::net::Ipv4Addr::new(172, 19, 207, 1),
|
||||
53,
|
||||
)),
|
||||
None,
|
||||
Duration::from_secs(3),
|
||||
)
|
||||
.await
|
||||
.expect("can create UDP DNS client stream");
|
||||
|
||||
let mut message = Message::new();
|
||||
let name = Name::from_ascii("uhsure.com").unwrap();
|
||||
message.add_query(Query::query(name.clone(), RecordType::A));
|
||||
message.add_query(Query::query(name.clone(), RecordType::AAAA));
|
||||
message.set_recursion_desired(true);
|
||||
|
||||
let request = DnsRequest::new(message, Default::default());
|
||||
let mut responses = stream.send_message(request);
|
||||
|
||||
while let Some(response) = responses.next().await {
|
||||
println!("response: {:?}", response);
|
||||
}
|
||||
|
||||
unimplemented!();
|
||||
}
|
||||
|
||||
@@ -48,4 +48,6 @@ pub enum OperationalError {
|
||||
InvalidHostname(String),
|
||||
#[error("Unable to parse host address")]
|
||||
UnableToParseHostAddress,
|
||||
#[error("Unable to configure resolver")]
|
||||
Resolver,
|
||||
}
|
||||
|
||||
@@ -14,15 +14,14 @@ pub enum TopLevelError {
|
||||
}
|
||||
|
||||
pub async fn run(mut config: ServerConfiguration) -> error_stack::Result<(), TopLevelError> {
|
||||
let mut server_state = state::ServerState::default();
|
||||
let _server_state = state::ServerState::default();
|
||||
|
||||
let listeners = config.generate_listener_sockets().await
|
||||
let listeners = config
|
||||
.generate_listener_sockets()
|
||||
.await
|
||||
.change_context(TopLevelError::ConfigurationError)?;
|
||||
|
||||
for (name, listener) in listeners.into_iter() {
|
||||
}
|
||||
for (_name, _listener) in listeners.into_iter() {}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ pub struct SocketServer {
|
||||
listener: UnixListener,
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
impl SocketServer {
|
||||
/// Create a new server that will handle inputs from the client program.
|
||||
///
|
||||
@@ -21,7 +22,7 @@ impl SocketServer {
|
||||
.local_addr()
|
||||
.map(|x| x.as_pathname().map(|p| format!("{}", p.display())))
|
||||
.unwrap_or_else(|_| None)
|
||||
.unwrap_or_else(|| format!("unknown"));
|
||||
.unwrap_or_else(|| "<unknown>".to_string());
|
||||
|
||||
tracing::trace!(%name, %path, "Creating new socket listener");
|
||||
|
||||
@@ -41,7 +42,11 @@ impl SocketServer {
|
||||
/// there, the core task should be unaffected.
|
||||
pub async fn start(mut self) -> error_stack::Result<(), TopLevelError> {
|
||||
loop {
|
||||
let (stream, addr) = self.listener.accept().await.change_context(TopLevelError::SocketHandlerFailure)?;
|
||||
let (stream, addr) = self
|
||||
.listener
|
||||
.accept()
|
||||
.await
|
||||
.change_context(TopLevelError::SocketHandlerFailure)?;
|
||||
let remote_addr = addr
|
||||
.as_pathname()
|
||||
.map(|x| x.display())
|
||||
@@ -58,8 +63,7 @@ impl SocketServer {
|
||||
|
||||
self.num_sessions_run += 1;
|
||||
|
||||
tokio::task::spawn(Self::run_session(stream)
|
||||
.instrument(span))
|
||||
tokio::task::spawn(Self::run_session(stream).instrument(span))
|
||||
.await
|
||||
.change_context(TopLevelError::SocketHandlerFailure)?;
|
||||
}
|
||||
@@ -70,6 +74,5 @@ impl SocketServer {
|
||||
/// This is here because it's convenient, not because it shares state (obviously,
|
||||
/// given the type signature). But it's somewhat logically associated with this type,
|
||||
/// so it seems reasonable to make it an associated function.
|
||||
async fn run_session(handle: UnixStream) {
|
||||
}
|
||||
async fn run_session(handle: UnixStream) {}
|
||||
}
|
||||
|
||||
@@ -6,3 +6,26 @@ pub struct ServerState {
|
||||
/// The set of top level tasks that are currently running.
|
||||
top_level_tasks: JoinSet<Result<(), TopLevelError>>,
|
||||
}
|
||||
|
||||
impl ServerState {
|
||||
/// Block until all the current top level tasks have closed.
|
||||
///
|
||||
/// This will log any errors found in these tasks to the current log,
|
||||
/// if there is one, but will otherwise drop them.
|
||||
#[allow(unused)]
|
||||
pub async fn shutdown(&mut self) {
|
||||
while let Some(next) = self.top_level_tasks.join_next_with_id().await {
|
||||
match next {
|
||||
Err(e) => tracing::error!(id = %e.id(), "Failed to attach to top-level task"),
|
||||
|
||||
Ok((id, Err(e))) => {
|
||||
tracing::error!(%id, "Top-level server error: {}", e);
|
||||
}
|
||||
|
||||
Ok((id, Ok(()))) => {
|
||||
tracing::debug!(%id, "Cleanly closed server task.");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user