Workspacify

This commit is contained in:
2025-05-03 17:30:01 -07:00
parent 9fe5b78962
commit d036997de3
60 changed files with 450 additions and 212 deletions

24
configuration/Cargo.toml Normal file
View File

@@ -0,0 +1,24 @@
[package]
name = "configuration"
edition = "2024"
[lints.rust]
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(tarpaulin_include)'] }
[dependencies]
clap = { workspace = true }
console-subscriber = { workspace = true }
crypto = { workspace = true }
nix = { workspace = true }
proptest = { workspace = true }
rustix = { workspace = true }
serde = { workspace = true, features = ["derive"] }
tempfile = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true }
toml = { workspace = true }
tracing = { workspace = true }
tracing-core = { workspace = true }
tracing-subscriber = { workspace = true }
url = { workspace = true }
xdg = { workspace = true }

View File

@@ -0,0 +1,69 @@
use crate::command_line::ClientArguments;
pub use crate::command_line::ClientCommand;
use crate::config_file::ConfigFile;
use crate::error::ConfigurationError;
use crate::logging::LoggingConfiguration;
use crate::resolver::DnsConfig;
use crate::runtime::RuntimeConfiguration;
use clap::Parser;
use std::ffi::OsString;
use tracing_core::LevelFilter;
pub struct ClientConfiguration {
pub runtime: RuntimeConfiguration,
pub logging: LoggingConfiguration,
pub dns_config: DnsConfig,
_config_file: Option<ConfigFile>,
command: ClientCommand,
}
impl ClientConfiguration {
/// Load a basic client configuration for this run.
///
/// This will parse the process's command line arguments, and parse
/// a config file if given, but will not interpret this information
/// beyond that required to understand the user's goals for the
/// runtime system and logging. For this reason, it is not async,
/// as it is responsible for determining all the information we
/// will use to generate the runtime.
pub fn new<I, T>(args: I) -> Result<Self, ConfigurationError>
where
I: IntoIterator<Item = T>,
T: Into<OsString> + Clone,
{
let mut command_line_arguments = ClientArguments::try_parse_from(args)?;
let mut _config_file =
ConfigFile::new(command_line_arguments.arguments.config_file.take())?;
let mut runtime = RuntimeConfiguration::default();
let mut logging = LoggingConfiguration::default();
// we prefer the command line to the config file, so first merge
// in the config file so that when we later merge the command line,
// it overwrites any config file options.
if let Some(config_file) = _config_file.as_mut() {
config_file.merge_standard_options_into(&mut runtime, &mut logging);
}
command_line_arguments
.arguments
.merge_standard_options_into(&mut runtime, &mut logging);
if command_line_arguments.command.is_list_command() {
logging.filter = LevelFilter::ERROR;
}
// FIXME!!
let dns_config = DnsConfig::default();
Ok(ClientConfiguration {
runtime,
logging,
dns_config,
_config_file,
command: command_line_arguments.command,
})
}
// Returns the command we received from the user
pub fn command(&self) -> &ClientCommand {
&self.command
}
}

View File

@@ -0,0 +1,217 @@
use crate::console::ConsoleConfiguration;
use crate::logging::{LogTarget, LoggingConfiguration};
use crate::runtime::RuntimeConfiguration;
use clap::{Args, Parser, Subcommand};
use std::net::SocketAddr;
use std::path::PathBuf;
#[cfg(test)]
use std::str::FromStr;
use tracing_core::LevelFilter;
#[derive(Parser, Debug, Default)]
#[command(version, about, long_about = None)]
pub struct ServerArguments {
#[command(flatten)]
pub arguments: CommandLineArguments,
}
#[derive(Parser, Debug, Default)]
#[command(version, about, long_about = None)]
pub struct ClientArguments {
#[command(flatten)]
pub arguments: CommandLineArguments,
/// The command we're running as the client
#[command(subcommand)]
pub command: ClientCommand,
}
#[derive(Args, Debug, Default)]
pub struct CommandLineArguments {
/// The config file to use for this command.
#[arg(short, long)]
pub config_file: Option<PathBuf>,
/// The number of "normal" threads to use for this process.
///
/// These are the threads that do the predominant part of the work
/// for the system.
#[arg(short, long)]
threads: Option<usize>,
/// The number of "blocking" threads to use for this process.
///
/// These threads are used for long-running operations, or operations
/// that require extensive interaction with non-asynchronous-friendly
/// IO. This should definitely be >= 1, but does not need to be super
/// high.
#[arg(short, long)]
blocking_threads: Option<usize>,
/// The place to send log data to.
#[arg(short = 'o', long)]
log_file: Option<PathBuf>,
/// The log level to report.
#[arg(short, long)]
log_level: Option<LevelFilter>,
/// A network server IP address to use for tokio-console inspection.
#[arg(short = 's', long, group = "console")]
console_network_server: Option<SocketAddr>,
/// A unix domain socket address to use for tokio-console inspection.
#[arg(short = 'u', long, group = "console")]
console_unix_socket: Option<PathBuf>,
}
#[derive(Debug, Default, Subcommand)]
pub enum ClientCommand {
/// List the key exchange algorithms we currently allow
#[default]
ListKeyExchangeAlgorithms,
/// List the host key algorithms we currently allow
ListHostKeyAlgorithms,
/// List the encryption algorithms we currently allow
ListEncryptionAlgorithms,
/// List the MAC algorithms we currently allow
ListMacAlgorithms,
/// List the compression algorithms we currently allow
ListCompressionAlgorithms,
/// Connect to the given host and port
Connect { target: String },
}
impl ClientCommand {
/// Is this a command that's just going to list some information to the console?
pub fn is_list_command(&self) -> bool {
matches!(
self,
ClientCommand::ListKeyExchangeAlgorithms
| ClientCommand::ListHostKeyAlgorithms
| ClientCommand::ListEncryptionAlgorithms
| ClientCommand::ListMacAlgorithms
| ClientCommand::ListCompressionAlgorithms
)
}
}
impl CommandLineArguments {
pub fn merge_standard_options_into(
&mut self,
runtime_config: &mut RuntimeConfiguration,
logging_config: &mut LoggingConfiguration,
) {
if let Some(threads) = self.threads {
runtime_config.tokio_worker_threads = threads;
}
if let Some(threads) = self.blocking_threads {
runtime_config.tokio_blocking_threads = threads;
}
if let Some(log_file) = self.log_file.take() {
logging_config.target = LogTarget::File(log_file);
}
if let Some(log_level) = self.log_level.take() {
logging_config.filter = log_level;
}
if (self.console_network_server.is_some() || self.console_unix_socket.is_some())
&& runtime_config.console.is_none()
{
runtime_config.console = Some(ConsoleConfiguration::default());
}
if let Some(cns) = self.console_network_server.take() {
if let Some(x) = runtime_config.console.as_mut() {
x.server_addr = cns.into();
}
}
if let Some(cus) = self.console_unix_socket.take() {
if let Some(x) = runtime_config.console.as_mut() {
x.server_addr = cus.into();
}
}
}
}
#[cfg(test)]
fn apply_command_line(
mut cmdargs: CommandLineArguments,
) -> (RuntimeConfiguration, LoggingConfiguration) {
let mut original_runtime = RuntimeConfiguration::default();
let mut original_logging = LoggingConfiguration::default();
cmdargs.merge_standard_options_into(&mut original_runtime, &mut original_logging);
(original_runtime, original_logging)
}
#[test]
fn command_line_wins() {
let original_runtime = RuntimeConfiguration::default();
let cmd = CommandLineArguments {
threads: Some(original_runtime.tokio_worker_threads + 1),
..CommandLineArguments::default()
};
let (test1_run, _) = apply_command_line(cmd);
assert_ne!(
original_runtime.tokio_worker_threads,
test1_run.tokio_worker_threads
);
assert_eq!(
original_runtime.tokio_blocking_threads,
test1_run.tokio_blocking_threads
);
let cmd = CommandLineArguments {
blocking_threads: Some(original_runtime.tokio_blocking_threads + 1),
..CommandLineArguments::default()
};
let (test2_run, _) = apply_command_line(cmd);
assert_eq!(
original_runtime.tokio_worker_threads,
test2_run.tokio_worker_threads
);
assert_ne!(
original_runtime.tokio_blocking_threads,
test2_run.tokio_blocking_threads
);
}
#[test]
fn can_set_console_settings() {
let cmd = CommandLineArguments {
console_network_server: Some(
SocketAddr::from_str("127.0.0.1:8080").expect("reasonable address"),
),
..CommandLineArguments::default()
};
let (test1_run, _) = apply_command_line(cmd);
assert!(test1_run.console.is_some());
assert!(matches!(
test1_run.console.unwrap().server_addr,
console_subscriber::ServerAddr::Tcp(_)
));
let temp_path = tempfile::NamedTempFile::new()
.expect("can build temp file")
.into_temp_path();
std::fs::remove_file(&temp_path).unwrap();
let filename = temp_path.to_path_buf();
let cmd = CommandLineArguments {
console_unix_socket: Some(filename),
..CommandLineArguments::default()
};
let (test2_run, _) = apply_command_line(cmd);
assert!(test2_run.console.is_some());
assert!(matches!(
test2_run.console.unwrap().server_addr,
console_subscriber::ServerAddr::Unix(_)
));
}

View File

@@ -0,0 +1,586 @@
use crate::error::ConfigurationError;
use crate::logging::{LogMode, LogTarget, LoggingConfiguration};
use crate::resolver::DnsConfig;
use crate::runtime::RuntimeConfiguration;
use crypto::known_algorithms::{
ALLOWED_COMPRESSION_ALGORITHMS, ALLOWED_ENCRYPTION_ALGORITHMS, ALLOWED_HOST_KEY_ALGORITHMS,
ALLOWED_KEY_EXCHANGE_ALGORITHMS, ALLOWED_MAC_ALGORITHMS,
};
use proptest::arbitrary::{any, Arbitrary};
use proptest::strategy::{BoxedStrategy, Just, Strategy};
use serde::de::{self, Unexpected};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::collections::{HashMap, HashSet};
use std::io::Read;
use std::path::PathBuf;
use thiserror::Error;
use tracing_core::Level;
#[allow(dead_code)]
#[derive(Debug, Deserialize, PartialEq, Serialize)]
pub struct ConfigFile {
runtime: Option<RuntimeConfig>,
logging: Option<LoggingConfig>,
pub resolver: Option<DnsConfig>,
pub sockets: Option<HashMap<String, SocketConfig>>,
pub keys: HashMap<String, KeyConfig>,
defaults: Option<ServerConfig>,
servers: HashMap<String, ServerConfig>,
}
impl Arbitrary for ConfigFile {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
(
any::<Option<RuntimeConfig>>(),
any::<Option<LoggingConfig>>(),
any::<Option<DnsConfig>>(),
proptest::option::of(keyed_section(SocketConfig::arbitrary())),
keyed_section(KeyConfig::arbitrary()),
any::<Option<ServerConfig>>(),
keyed_section(ServerConfig::arbitrary()),
)
.prop_map(
|(runtime, logging, resolver, sockets, keys, defaults, servers)| ConfigFile {
runtime,
logging,
resolver,
sockets,
keys,
defaults,
servers,
},
)
.boxed()
}
}
fn keyed_section<S>(strat: S) -> BoxedStrategy<HashMap<String, S::Value>>
where
S: Strategy + 'static,
{
proptest::collection::hash_map(
proptest::string::string_regex("[a-zA-Z0-9-]{1,30}").unwrap(),
strat,
0..40,
)
.boxed()
}
#[derive(Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
pub struct SocketConfig {
path: PathBuf,
user: Option<String>,
group: Option<String>,
permissions: Option<Permissions>,
}
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
pub enum Permissions {
User,
Group,
UserGroup,
Everyone,
}
impl From<Permissions> for rustix::fs::Mode {
fn from(value: Permissions) -> Self {
match value {
Permissions::User => rustix::fs::Mode::RUSR | rustix::fs::Mode::WUSR,
Permissions::Group => rustix::fs::Mode::RUSR | rustix::fs::Mode::WUSR,
Permissions::UserGroup => {
rustix::fs::Mode::RUSR
| rustix::fs::Mode::WUSR
| rustix::fs::Mode::RGRP
| rustix::fs::Mode::WGRP
}
Permissions::Everyone => {
rustix::fs::Mode::RUSR
| rustix::fs::Mode::WUSR
| rustix::fs::Mode::RGRP
| rustix::fs::Mode::WGRP
| rustix::fs::Mode::ROTH
| rustix::fs::Mode::WOTH
}
}
}
}
impl SocketConfig {
pub async fn into_listener(self) -> Result<tokio::net::UnixListener, ConfigurationError> {
let base = tokio::net::UnixListener::bind(&self.path).map_err(|error| {
ConfigurationError::CouldNotMakeSocket {
path: self.path.clone(),
error,
}
})?;
let user = self
.user
.map(|x| {
nix::unistd::User::from_name(&x)
.map_err(|error| ConfigurationError::CouldNotSetPerms {
thing: "find user".to_string(),
path: self.path.clone(),
error: error.into(),
})
.transpose()
.unwrap_or_else(|| {
Err(ConfigurationError::CouldNotSetPerms {
thing: "find user".to_string(),
path: self.path.clone(),
error: std::io::Error::new(
std::io::ErrorKind::NotFound,
"could not find user",
),
})
})
})
.transpose()?;
let group = self
.group
.map(|x| {
nix::unistd::Group::from_name(&x)
.map_err(|error| ConfigurationError::CouldNotSetPerms {
thing: "find user".to_string(),
path: self.path.clone(),
error: error.into(),
})
.transpose()
.unwrap_or_else(|| {
Err(ConfigurationError::CouldNotSetPerms {
thing: "find user".to_string(),
path: self.path.clone(),
error: std::io::Error::new(
std::io::ErrorKind::NotFound,
"could not find user",
),
})
})
})
.transpose()?;
if user.is_some() || group.is_some() {
std::os::unix::fs::chown(
&self.path,
user.map(|x| x.uid.as_raw()),
group.map(|x| x.gid.as_raw()),
)
.map_err(|error| ConfigurationError::CouldNotSetPerms {
thing: "set user/group ownership".to_string(),
path: self.path.clone(),
error,
})?;
}
let perms = self.permissions.unwrap_or(Permissions::UserGroup);
rustix::fs::chmod(&self.path, perms.into()).map_err(|error| {
ConfigurationError::CouldNotSetPerms {
thing: "set permissions".to_string(),
path: self.path.clone(),
error: error.into(),
}
})?;
Ok(base)
}
}
impl Arbitrary for SocketConfig {
type Strategy = BoxedStrategy<SocketConfig>;
type Parameters = ();
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
(
PathBuf::arbitrary(),
proptest::option::of(proptest::string::string_regex("[a-zA-Z0-9]{1,30}").unwrap()),
proptest::option::of(proptest::string::string_regex("[a-zA-Z0-9]{1,30}").unwrap()),
proptest::option::of(Permissions::arbitrary()),
)
.prop_map(|(path, user, group, permissions)| SocketConfig {
path,
user,
group,
permissions,
})
.boxed()
}
}
impl Arbitrary for Permissions {
type Strategy = BoxedStrategy<Permissions>;
type Parameters = ();
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
proptest::prop_oneof![
proptest::strategy::Just(Permissions::User),
proptest::strategy::Just(Permissions::Group),
proptest::strategy::Just(Permissions::UserGroup),
proptest::strategy::Just(Permissions::Everyone),
]
.boxed()
}
}
#[derive(Debug, Deserialize, PartialEq, Serialize)]
pub struct RuntimeConfig {
worker_threads: Option<usize>,
blocking_threads: Option<usize>,
}
impl Arbitrary for RuntimeConfig {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
(any::<Option<u16>>(), any::<Option<u16>>())
.prop_map(|(worker_threads, blocking_threads)| RuntimeConfig {
worker_threads: worker_threads.map(Into::into),
blocking_threads: blocking_threads.map(Into::into),
})
.boxed()
}
}
#[allow(dead_code)]
#[derive(Debug, Deserialize, PartialEq, Serialize)]
pub struct LoggingConfig {
#[serde(
default,
deserialize_with = "parse_level",
serialize_with = "write_level"
)]
level: Option<Level>,
include_filename: Option<bool>,
include_lineno: Option<bool>,
include_thread_ids: Option<bool>,
include_thread_names: Option<bool>,
#[serde(
default,
deserialize_with = "parse_mode",
serialize_with = "write_mode"
)]
mode: Option<LogMode>,
#[serde(
default,
deserialize_with = "parse_target",
serialize_with = "write_target"
)]
target: Option<LogTarget>,
}
impl Arbitrary for LoggingConfig {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with((): Self::Parameters) -> Self::Strategy {
let level_strat = proptest::prop_oneof![
Just(None),
Just(Some(Level::TRACE)),
Just(Some(Level::DEBUG)),
Just(Some(Level::INFO)),
Just(Some(Level::WARN)),
Just(Some(Level::ERROR)),
];
let mode_strat = proptest::prop_oneof![
Just(None),
Just(Some(LogMode::Compact)),
Just(Some(LogMode::Pretty)),
Just(Some(LogMode::Json)),
];
let target_strat = proptest::prop_oneof![
Just(None),
Just(Some(LogTarget::StdErr)),
Just(Some(LogTarget::StdOut)),
Just(Some({
let tempfile = tempfile::NamedTempFile::new().unwrap();
let name = tempfile.into_temp_path();
LogTarget::File(name.to_path_buf())
})),
];
(
level_strat,
any::<Option<bool>>(),
any::<Option<bool>>(),
any::<Option<bool>>(),
any::<Option<bool>>(),
mode_strat,
target_strat,
)
.prop_map(
|(
level,
include_filename,
include_lineno,
include_thread_ids,
include_thread_names,
mode,
target,
)| LoggingConfig {
level,
include_filename,
include_lineno,
include_thread_ids,
include_thread_names,
mode,
target,
},
)
.boxed()
}
}
fn parse_level<'de, D>(deserializer: D) -> Result<Option<Level>, D::Error>
where
D: Deserializer<'de>,
{
let s: Option<String> = Option::deserialize(deserializer)?;
s.map(|x| match x.to_lowercase().as_str() {
"trace" => Ok(Some(Level::TRACE)),
"debug" => Ok(Some(Level::DEBUG)),
"info" => Ok(Some(Level::INFO)),
"warn" => Ok(Some(Level::WARN)),
"error" => Ok(Some(Level::ERROR)),
_ => Err(de::Error::invalid_value(
Unexpected::Str(&x),
&"valid logging level (trace, debug, info, warn, or error",
)),
})
.unwrap_or_else(|| Ok(None))
}
fn write_level<S: Serializer>(item: &Option<Level>, serializer: S) -> Result<S::Ok, S::Error> {
match item {
None => serializer.serialize_none(),
Some(Level::TRACE) => serializer.serialize_some("trace"),
Some(Level::DEBUG) => serializer.serialize_some("debug"),
Some(Level::INFO) => serializer.serialize_some("info"),
Some(Level::WARN) => serializer.serialize_some("warn"),
Some(Level::ERROR) => serializer.serialize_some("error"),
}
}
fn parse_mode<'de, D>(deserializer: D) -> Result<Option<LogMode>, D::Error>
where
D: Deserializer<'de>,
{
let s: Option<String> = Option::deserialize(deserializer)?;
s.map(|x| match x.to_lowercase().as_str() {
"compact" => Ok(Some(LogMode::Compact)),
"pretty" => Ok(Some(LogMode::Pretty)),
"json" => Ok(Some(LogMode::Json)),
_ => Err(de::Error::invalid_value(
Unexpected::Str(&x),
&"valid logging level (trace, debug, info, warn, or error",
)),
})
.unwrap_or_else(|| Ok(None))
}
fn write_mode<S: Serializer>(item: &Option<LogMode>, serializer: S) -> Result<S::Ok, S::Error> {
match item {
None => serializer.serialize_none(),
Some(LogMode::Compact) => serializer.serialize_some("compact"),
Some(LogMode::Pretty) => serializer.serialize_some("pretty"),
Some(LogMode::Json) => serializer.serialize_some("json"),
}
}
fn parse_target<'de, D>(deserializer: D) -> Result<Option<LogTarget>, D::Error>
where
D: Deserializer<'de>,
{
let s: Option<String> = Option::deserialize(deserializer)?;
Ok(s.map(|x| match x.to_lowercase().as_str() {
"stdout" => LogTarget::StdOut,
"stderr" => LogTarget::StdErr,
_ => LogTarget::File(x.into()),
}))
}
fn write_target<S: Serializer>(item: &Option<LogTarget>, serializer: S) -> Result<S::Ok, S::Error> {
match item {
None => serializer.serialize_none(),
Some(LogTarget::StdOut) => serializer.serialize_some("stdout"),
Some(LogTarget::StdErr) => serializer.serialize_some("stderr"),
Some(LogTarget::File(file)) => serializer.serialize_some(file),
}
}
#[derive(Debug, Deserialize, PartialEq, Serialize)]
pub struct KeyConfig {
pub public: PathBuf,
pub private: PathBuf,
pub password: Option<String>,
}
impl Arbitrary for KeyConfig {
type Parameters = bool;
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_generate_real_keys: Self::Parameters) -> Self::Strategy {
let password = proptest::string::string_regex("[a-zA-Z0-9_!@#$%^&*]{8,40}").unwrap();
proptest::option::of(password)
.prop_map(|password| {
let public_file = tempfile::NamedTempFile::new().unwrap();
let private_file = tempfile::NamedTempFile::new().unwrap();
let public = public_file.into_temp_path().to_path_buf();
let private = private_file.into_temp_path().to_path_buf();
KeyConfig {
public,
private,
password,
}
})
.boxed()
}
}
#[derive(Debug, Deserialize, PartialEq, Serialize)]
pub struct ServerConfig {
key_exchange_algorithms: Option<Vec<String>>,
server_host_algorithms: Option<Vec<String>>,
encryption_algorithms: Option<Vec<String>>,
mac_algorithms: Option<Vec<String>>,
compression_algorithms: Option<Vec<String>>,
predict: Option<String>,
}
impl Arbitrary for ServerConfig {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
let keyx = proptest::sample::select(ALLOWED_KEY_EXCHANGE_ALGORITHMS);
let hostkey = proptest::sample::select(ALLOWED_HOST_KEY_ALGORITHMS);
let enc = proptest::sample::select(ALLOWED_ENCRYPTION_ALGORITHMS);
let mac = proptest::sample::select(ALLOWED_MAC_ALGORITHMS);
let comp = proptest::sample::select(ALLOWED_COMPRESSION_ALGORITHMS);
(
proptest::collection::hash_set(keyx.clone(), 0..ALLOWED_KEY_EXCHANGE_ALGORITHMS.len()),
proptest::collection::hash_set(hostkey, 0..ALLOWED_HOST_KEY_ALGORITHMS.len()),
proptest::collection::hash_set(enc, 0..ALLOWED_ENCRYPTION_ALGORITHMS.len()),
proptest::collection::hash_set(mac, 0..ALLOWED_MAC_ALGORITHMS.len()),
proptest::collection::hash_set(comp, 0..ALLOWED_COMPRESSION_ALGORITHMS.len()),
proptest::option::of(keyx),
)
.prop_map(|(kex, host, enc, mac, comp, pred)| ServerConfig {
key_exchange_algorithms: finalize_options(kex),
server_host_algorithms: finalize_options(host),
encryption_algorithms: finalize_options(enc),
mac_algorithms: finalize_options(mac),
compression_algorithms: finalize_options(comp),
predict: pred.map(str::to_string),
})
.boxed()
}
}
fn finalize_options(values: HashSet<&str>) -> Option<Vec<String>> {
if values.is_empty() {
None
} else {
Some(values.into_iter().map(str::to_string).collect())
}
}
#[derive(Debug, Error)]
pub enum ConfigFileError {
#[error("Could not open provided config file: {0}")]
CouldNotOpen(std::io::Error),
#[error("Could not read config file: {0}")]
CouldNotRead(std::io::Error),
#[error("Error in config file: {0}")]
ParseError(#[from] toml::de::Error),
}
impl ConfigFile {
/// Try to read in a config file, using the given path if provided, or the XDG-standard
/// path if not.
///
/// If the user didn't provide a config file, and there isn't a config file in the standard
/// XDG place, returns `Ok(None)`.
///
/// This will return errors if there is a parse error in understanding the file, if
/// there's some basic disk error in reading the file, or if the user provided a config
/// file that we couldn't find on disk.
pub fn new(provided_path: Option<PathBuf>) -> Result<Option<Self>, ConfigFileError> {
let config_file = if let Some(path) = provided_path {
let file = std::fs::File::open(path).map_err(ConfigFileError::CouldNotOpen)?;
Some(file)
} else {
let Ok(xdg_base_dirs) = xdg::BaseDirectories::with_prefix("hushd") else {
return Ok(None);
};
let path = xdg_base_dirs.find_config_file("config.toml");
path.and_then(|x| std::fs::File::open(x).ok())
};
let Some(mut config_file) = config_file else {
return Ok(None);
};
let mut contents = String::new();
let _ = config_file
.read_to_string(&mut contents)
.map_err(ConfigFileError::CouldNotRead)?;
let config = toml::from_str(&contents)?;
Ok(Some(config))
}
/// Merge any settings found in the config file into our current configuration.
///
pub fn merge_standard_options_into(
&mut self,
runtime: &mut RuntimeConfiguration,
_logging: &mut LoggingConfiguration,
) {
if let Some(runtime_config) = self.runtime.take() {
runtime.tokio_worker_threads = runtime_config
.worker_threads
.unwrap_or(runtime.tokio_worker_threads);
runtime.tokio_blocking_threads = runtime_config
.blocking_threads
.unwrap_or(runtime.tokio_blocking_threads);
}
}
}
#[test]
fn all_keys_example_parses() {
let path = format!("{}/tests/all_keys.toml", env!("CARGO_MANIFEST_DIR"));
let result = ConfigFile::new(Some(path.into()));
assert!(result.is_ok());
}
proptest::proptest! {
#[test]
fn valid_configs_parse(config in ConfigFile::arbitrary()) {
use std::io::Write;
let mut tempfile = tempfile::NamedTempFile::new().unwrap();
let contents = toml::to_string(&config).unwrap();
tempfile.write_all(contents.as_bytes()).unwrap();
let path = tempfile.into_temp_path();
let parsed = ConfigFile::new(Some(path.to_path_buf())).unwrap().unwrap();
assert_eq!(config, parsed);
}
}

View File

@@ -0,0 +1,78 @@
use crypto::known_algorithms::{
ALLOWED_COMPRESSION_ALGORITHMS, ALLOWED_ENCRYPTION_ALGORITHMS, ALLOWED_HOST_KEY_ALGORITHMS,
ALLOWED_KEY_EXCHANGE_ALGORITHMS, ALLOWED_MAC_ALGORITHMS,
};
use crypto::{
CompressionAlgorithm, EncryptionAlgorithm, HostKeyAlgorithm, KeyExchangeAlgorithm, MacAlgorithm,
};
use proptest::arbitrary::Arbitrary;
use proptest::strategy::{BoxedStrategy, Strategy};
use std::collections::HashSet;
use std::str::FromStr;
#[derive(Debug)]
pub struct ClientConnectionOpts {
pub key_exchange_algorithms: Vec<KeyExchangeAlgorithm>,
pub server_host_key_algorithms: Vec<HostKeyAlgorithm>,
pub encryption_algorithms: Vec<EncryptionAlgorithm>,
pub mac_algorithms: Vec<MacAlgorithm>,
pub compression_algorithms: Vec<CompressionAlgorithm>,
pub languages: Vec<String>,
pub predict: Option<KeyExchangeAlgorithm>,
}
impl Default for ClientConnectionOpts {
fn default() -> Self {
ClientConnectionOpts {
key_exchange_algorithms: vec![KeyExchangeAlgorithm::Curve25519Sha256],
server_host_key_algorithms: vec![HostKeyAlgorithm::Ed25519],
encryption_algorithms: vec![EncryptionAlgorithm::Aes256Ctr],
mac_algorithms: vec![MacAlgorithm::HmacSha256],
compression_algorithms: vec![],
languages: vec![],
predict: None,
}
}
}
impl Arbitrary for ClientConnectionOpts {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
let keyx = proptest::sample::select(ALLOWED_KEY_EXCHANGE_ALGORITHMS);
let hostkey = proptest::sample::select(ALLOWED_HOST_KEY_ALGORITHMS);
let enc = proptest::sample::select(ALLOWED_ENCRYPTION_ALGORITHMS);
let mac = proptest::sample::select(ALLOWED_MAC_ALGORITHMS);
let comp = proptest::sample::select(ALLOWED_COMPRESSION_ALGORITHMS);
(
proptest::collection::hash_set(keyx.clone(), 1..ALLOWED_KEY_EXCHANGE_ALGORITHMS.len()),
proptest::collection::hash_set(hostkey, 1..ALLOWED_HOST_KEY_ALGORITHMS.len()),
proptest::collection::hash_set(enc, 1..ALLOWED_ENCRYPTION_ALGORITHMS.len()),
proptest::collection::hash_set(mac, 1..ALLOWED_MAC_ALGORITHMS.len()),
proptest::collection::hash_set(comp, 1..ALLOWED_COMPRESSION_ALGORITHMS.len()),
proptest::option::of(keyx),
)
.prop_map(|(kex, host, enc, mac, comp, pred)| ClientConnectionOpts {
key_exchange_algorithms: finalize_options(kex).unwrap(),
server_host_key_algorithms: finalize_options(host).unwrap(),
encryption_algorithms: finalize_options(enc).unwrap(),
mac_algorithms: finalize_options(mac).unwrap(),
compression_algorithms: finalize_options(comp).unwrap(),
languages: vec![],
predict: pred.and_then(|x| KeyExchangeAlgorithm::from_str(x).ok()),
})
.boxed()
}
}
fn finalize_options<T, E>(values: HashSet<&str>) -> Result<Vec<T>, E>
where
T: FromStr<Err = E>,
{
values
.into_iter()
.map(T::from_str)
.collect::<Result<Vec<T>, E>>()
}

View File

@@ -0,0 +1,48 @@
use console_subscriber::{ConsoleLayer, ServerAddr};
use core::time::Duration;
use std::path::PathBuf;
use tracing_subscriber::Layer;
pub struct ConsoleConfiguration {
client_buffer_capacity: usize,
publish_interval: Duration,
retention: Duration,
pub server_addr: ServerAddr,
poll_duration_histogram_max: Duration,
}
impl Default for ConsoleConfiguration {
fn default() -> Self {
ConsoleConfiguration {
client_buffer_capacity: ConsoleLayer::DEFAULT_CLIENT_BUFFER_CAPACITY,
publish_interval: ConsoleLayer::DEFAULT_PUBLISH_INTERVAL,
retention: ConsoleLayer::DEFAULT_RETENTION,
server_addr: xdg::BaseDirectories::with_prefix("hushd")
.and_then(|x| x.get_runtime_directory().cloned())
.map(|mut v| {
v.push("console.sock");
v
})
.map(|p| ServerAddr::Unix(p.clone()))
.unwrap_or_else(|_| PathBuf::from("console.sock").into()),
poll_duration_histogram_max: ConsoleLayer::DEFAULT_SCHEDULED_DURATION_MAX,
}
}
}
impl ConsoleConfiguration {
#[cfg(not(tarpaulin_include))]
pub fn layer<S>(&self) -> Box<dyn Layer<S> + Send + Sync + 'static>
where
S: tracing_core::Subscriber + for<'a> tracing_subscriber::registry::LookupSpan<'a>,
{
ConsoleLayer::builder()
.client_buffer_capacity(self.client_buffer_capacity)
.publish_interval(self.publish_interval)
.retention(self.retention)
.server_addr(self.server_addr.clone())
.poll_duration_histogram_max(self.poll_duration_histogram_max)
.spawn()
.boxed()
}
}

View File

@@ -0,0 +1,33 @@
use crate::config_file::ConfigFileError;
use std::path::PathBuf;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum ConfigurationError {
#[error(transparent)]
ConfigFile(#[from] ConfigFileError),
#[error(transparent)]
CommandLineError(#[from] clap::error::Error),
#[error("Could not read file {file}: {error}")]
CouldNotRead {
file: PathBuf,
error: std::io::Error,
},
#[error("Error loading public key information")]
PublicKey,
#[error("Error loading private key")]
PrivateKey,
#[error("Error configuring DNS resolver")]
Resolver,
#[error("Could not create UNIX listener socket at {path}: {error}")]
CouldNotMakeSocket {
path: PathBuf,
error: std::io::Error,
},
#[error("Could not {thing} for {path}: {error}")]
CouldNotSetPerms {
thing: String,
path: PathBuf,
error: std::io::Error,
},
}

86
configuration/src/lib.rs Normal file
View File

@@ -0,0 +1,86 @@
pub mod client;
mod command_line;
mod config_file;
pub mod connection;
mod console;
pub mod error;
mod logging;
pub mod resolver;
mod runtime;
pub mod server;
use tokio::runtime::Runtime;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::prelude::*;
//impl ClientConfiguration {
// pub async fn try_from(
// mut basic: BasicClientConfiguration,
// ) -> error_stack::Result<Self, ConfigurationError> {
// let mut _ssh_keys = HashMap::new();
// let mut _defaults = ClientConnectionOpts::default();
// let resolver = basic
// .config_file
// .as_mut()
// .and_then(|x| x.resolver.take())
// .unwrap_or_default();
// let user_ssh_keys = basic.config_file.map(|cf| cf.keys).unwrap_or_default();
//
// tracing::info!(
// provided_ssh_keys = user_ssh_keys.len(),
// "loading user-provided SSH keys"
// );
// for (key_name, key_info) in user_ssh_keys.into_iter() {
// let _public_keys = PublicKey::load(key_info.public).await.unwrap();
// tracing::info!(?key_name, "public keys loaded");
// let _private_key = load_openssh_file_keys(key_info.private, &key_info.password)
// .await
// .change_context(ConfigurationError::PrivateKey)?;
// tracing::info!(?key_name, "private keys loaded");
// }
//
// Ok(ClientConfiguration {
// _runtime: basic.runtime,
// _logging: basic.logging,
// resolver,
// _ssh_keys,
// _defaults,
// command: basic.command,
// })
// }
//}
//
/// Set up the tracing subscribers based on the config file and command line options.
///
/// This will definitely set up our logging substrate, but may also create a subscriber
/// for the console.
#[cfg(not(tarpaulin_include))]
pub fn establish_subscribers(
logging: &logging::LoggingConfiguration,
runtime: &mut runtime::RuntimeConfiguration,
) -> Result<(), std::io::Error> {
let tracing_layer = logging.layer()?;
let mut layers = vec![tracing_layer];
if let Some(console_config) = runtime.console.take() {
layers.push(console_config.layer());
}
tracing_subscriber::registry().with(layers).init();
Ok(())
}
/// Generate a new tokio runtime based on the configuration / command line options
/// provided.
#[cfg(not(tarpaulin_include))]
pub fn configured_runtime(
runtime: &runtime::RuntimeConfiguration,
) -> Result<Runtime, std::io::Error> {
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.max_blocking_threads(runtime.tokio_blocking_threads)
.worker_threads(runtime.tokio_worker_threads)
.build()
}

View File

@@ -0,0 +1,95 @@
use std::path::PathBuf;
use tracing_core::Subscriber;
use tracing_subscriber::filter::LevelFilter;
use tracing_subscriber::registry::LookupSpan;
use tracing_subscriber::{EnvFilter, Layer};
pub struct LoggingConfiguration {
pub(crate) filter: LevelFilter,
pub(crate) include_filename: bool,
pub(crate) include_lineno: bool,
pub(crate) include_thread_ids: bool,
pub(crate) include_thread_names: bool,
pub(crate) mode: LogMode,
pub(crate) target: LogTarget,
}
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum LogMode {
Compact,
Pretty,
Json,
}
#[derive(Clone, Debug, PartialEq)]
pub enum LogTarget {
StdOut,
StdErr,
File(PathBuf),
}
impl LogTarget {
fn supports_ansi(&self) -> bool {
matches!(self, LogTarget::StdOut | LogTarget::StdErr)
}
}
impl Default for LoggingConfiguration {
fn default() -> Self {
LoggingConfiguration {
filter: LevelFilter::INFO,
include_filename: false,
include_lineno: false,
include_thread_ids: true,
include_thread_names: true,
mode: LogMode::Compact,
target: LogTarget::StdErr,
}
}
}
impl LoggingConfiguration {
#[cfg(not(tarpaulin_include))]
pub fn layer<S>(&self) -> Result<Box<dyn Layer<S> + Send + Sync + 'static>, std::io::Error>
where
S: Subscriber + for<'a> LookupSpan<'a>,
{
let filter = EnvFilter::builder()
.with_default_directive(self.filter.into())
.with_env_var("HUSHD_LOG")
.from_env_lossy();
let base = tracing_subscriber::fmt::layer()
.with_file(self.include_filename)
.with_line_number(self.include_lineno)
.with_thread_ids(self.include_thread_ids)
.with_thread_names(self.include_thread_names)
.with_ansi(self.target.supports_ansi());
macro_rules! finalize {
($layer: expr) => {
match self.mode {
LogMode::Compact => Ok($layer.compact().with_filter(filter).boxed()),
LogMode::Json => Ok($layer.json().with_filter(filter).boxed()),
LogMode::Pretty => Ok($layer.pretty().with_filter(filter).boxed()),
}
};
}
match self.target {
LogTarget::StdOut => finalize!(base.with_writer(std::io::stdout)),
LogTarget::StdErr => finalize!(base.with_writer(std::io::stderr)),
LogTarget::File(ref path) => {
let log_file = std::fs::File::create(path)?;
finalize!(base.with_writer(std::sync::Mutex::new(log_file)))
}
}
}
}
#[test]
fn supports_ansi() {
assert!(LogTarget::StdOut.supports_ansi());
assert!(LogTarget::StdErr.supports_ansi());
assert!(!LogTarget::File("/dev/null".into()).supports_ansi());
}

View File

@@ -0,0 +1,359 @@
use proptest::arbitrary::Arbitrary;
use proptest::strategy::{BoxedStrategy, Just, Strategy};
use serde::{Deserialize, Serialize};
use std::net::SocketAddr;
use url::Url;
#[derive(Debug, PartialEq, Deserialize, Serialize)]
pub struct DnsConfig {
built_in: Option<BuiltinDnsOption>,
pub local_domain: Option<String>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub search_domains: Vec<String>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub name_servers: Vec<NameServerConfig>,
#[serde(default)]
pub retry_attempts: Option<u16>,
#[serde(default)]
pub cache_size: Option<u32>,
#[serde(default)]
pub max_concurrent_requests_for_query: Option<u16>,
#[serde(default)]
pub preserve_intermediates: Option<bool>,
#[serde(default)]
pub shuffle_dns_servers: Option<bool>,
#[serde(default)]
pub allow_mdns: Option<bool>,
}
impl Default for DnsConfig {
fn default() -> Self {
DnsConfig {
built_in: Some(BuiltinDnsOption::Cloudflare),
local_domain: None,
search_domains: vec![],
name_servers: vec![],
retry_attempts: None,
cache_size: None,
max_concurrent_requests_for_query: None,
preserve_intermediates: None,
shuffle_dns_servers: None,
allow_mdns: None,
}
}
}
#[cfg(test)]
impl DnsConfig {
pub(crate) fn empty() -> Self {
DnsConfig {
built_in: None,
local_domain: None,
search_domains: vec![],
name_servers: vec![],
retry_attempts: None,
cache_size: None,
max_concurrent_requests_for_query: None,
preserve_intermediates: None,
shuffle_dns_servers: None,
allow_mdns: None,
}
}
}
impl Arbitrary for DnsConfig {
type Parameters = bool;
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(always_use_builtin: Self::Parameters) -> Self::Strategy {
if always_use_builtin {
BuiltinDnsOption::arbitrary()
.prop_map(|x| DnsConfig {
built_in: Some(x),
local_domain: None,
search_domains: vec![],
name_servers: vec![],
retry_attempts: None,
cache_size: None,
max_concurrent_requests_for_query: None,
preserve_intermediates: None,
shuffle_dns_servers: None,
allow_mdns: None,
})
.boxed()
} else {
let built_in = proptest::option::of(BuiltinDnsOption::arbitrary());
built_in
.prop_flat_map(|built_in| {
let local_domain = proptest::option::of(domain_name_strat());
let search_domains = proptest::collection::vec(domain_name_strat(), 0..10);
let min_servers = if built_in.is_some() { 0 } else { 1 };
let name_servers =
proptest::collection::vec(NameServerConfig::arbitrary(), min_servers..6);
let retry_attempts = proptest::option::of(u16::arbitrary());
let cache_size = proptest::option::of(u32::arbitrary());
let max_concurrent_requests_for_query = proptest::option::of(u16::arbitrary());
let preserve_intermediates = proptest::option::of(bool::arbitrary());
let shuffle_dns_servers = proptest::option::of(bool::arbitrary());
let allow_mdns = proptest::option::of(bool::arbitrary());
(
local_domain,
search_domains,
name_servers,
retry_attempts,
cache_size,
max_concurrent_requests_for_query,
preserve_intermediates,
shuffle_dns_servers,
allow_mdns,
)
.prop_map(
move |(
local_domain,
search_domains,
name_servers,
retry_attempts,
cache_size,
max_concurrent_requests_for_query,
preserve_intermediates,
shuffle_dns_servers,
allow_mdns,
)| DnsConfig {
built_in,
local_domain,
search_domains,
name_servers,
retry_attempts,
cache_size,
max_concurrent_requests_for_query,
preserve_intermediates,
shuffle_dns_servers,
allow_mdns,
},
)
})
.boxed()
}
}
}
fn domain_name_strat() -> BoxedStrategy<String> {
let chunk = proptest::string::string_regex("[a-zA-Z0-9]{2,32}").unwrap();
let sets = proptest::collection::vec(chunk, 2..6);
sets.prop_map(|set| {
let mut output = String::new();
for x in set.into_iter() {
if !output.is_empty() {
output.push('.');
}
output.push_str(&x);
}
output
})
.boxed()
}
#[derive(Clone, Copy, Debug, PartialEq, Deserialize, Serialize)]
enum BuiltinDnsOption {
Google,
Cloudflare,
Quad9,
}
impl Arbitrary for BuiltinDnsOption {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with((): Self::Parameters) -> Self::Strategy {
proptest::prop_oneof![
Just(BuiltinDnsOption::Google),
Just(BuiltinDnsOption::Cloudflare),
Just(BuiltinDnsOption::Quad9),
]
.boxed()
}
}
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct NameServerConfig {
#[serde(serialize_with = "serialize_url", deserialize_with = "deserialize_url")]
pub address: Url,
#[serde(default)]
pub timeout_in_seconds: Option<u64>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub bind_address: Option<SocketAddr>,
}
fn serialize_url<S: serde::Serializer>(url: &Url, serializer: S) -> Result<S::Ok, S::Error> {
serializer.collect_str(url)
}
fn deserialize_url<'de, D: serde::Deserializer<'de>>(deserializer: D) -> Result<Url, D::Error> {
struct Visitor {}
impl<'de> serde::de::Visitor<'de> for Visitor {
type Value = Url;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(formatter, "a legal URL")
}
fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<Self::Value, E> {
Url::parse(v).map_err(|e| E::custom(e))
}
fn visit_borrowed_str<E: serde::de::Error>(self, v: &'de str) -> Result<Self::Value, E> {
Url::parse(v).map_err(|e| E::custom(e))
}
}
deserializer.deserialize_str(Visitor {})
}
impl Arbitrary for NameServerConfig {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with((): Self::Parameters) -> Self::Strategy {
let scheme = proptest::prop_oneof![
Just("http".to_string()),
Just("https".to_string()),
Just("tcp".to_string()),
Just("udp".to_string()),
];
let domain = proptest::string::string_regex("[A-Za-z0-9]{1,40}").unwrap();
let port = proptest::option::of(u16::arbitrary());
let user = proptest::option::of(proptest::string::string_regex("[A-Za-z]{1,30}").unwrap());
let password =
proptest::option::of(proptest::string::string_regex(":[A-Za-z]{1,30}").unwrap());
let path =
proptest::option::of(proptest::string::string_regex("/[A-Za-z/]{1,30}").unwrap());
let uri_strategy = (scheme, domain, user, password, port, path).prop_map(
|(scheme, domain, user, password, port, path)| {
let userpass_prefix = match (user, password) {
(None, None) => String::new(),
(Some(u), None) => format!("{u}@"),
(None, Some(p)) => format!(":{p}@"),
(Some(u), Some(p)) => format!("{u}:{p}@"),
};
let path = path.unwrap_or_default();
let port = port.map(|x| format!(":{x}")).unwrap_or_default();
let uri_str = format!("{scheme}://{userpass_prefix}{domain}{port}{path}");
Url::parse(&uri_str).unwrap()
},
);
(
uri_strategy,
proptest::option::of(u64::arbitrary()),
proptest::option::of(SocketAddr::arbitrary()),
)
.prop_map(|(address, mut timeout_in_seconds, mut bind_address)| {
if let Some(bind_address) = bind_address.as_mut() {
clear_flow_and_scope_info(bind_address);
}
if let Some(timeout_in_seconds) = timeout_in_seconds.as_mut() {
*timeout_in_seconds &= 0x7FFF_FFFF_FFFF_FFFF;
}
NameServerConfig {
address,
timeout_in_seconds,
bind_address,
}
})
.boxed()
}
}
fn clear_flow_and_scope_info(address: &mut SocketAddr) {
if let SocketAddr::V6(addr) = address {
addr.set_flowinfo(0);
addr.set_scope_id(0);
}
}
proptest::proptest! {
#[test]
fn valid_configs_parse(config in DnsConfig::arbitrary_with(false)) {
let toml = toml::to_string(&config).unwrap();
let reversed: DnsConfig = toml::from_str(&toml).unwrap();
assert_eq!(config, reversed);
}
}
impl DnsConfig {
/// Return the configurations for all of the name servers that the user
/// has selected.
pub fn name_servers(&self) -> Vec<NameServerConfig> {
let mut results = self.name_servers.clone();
match self.built_in {
None => {}
Some(BuiltinDnsOption::Cloudflare) => {
results.push(NameServerConfig {
address: Url::parse("udp://1.1.1.1:53").unwrap(),
timeout_in_seconds: None,
bind_address: None,
});
results.push(NameServerConfig {
address: Url::parse("udp://1.0.0.1").unwrap(),
timeout_in_seconds: None,
bind_address: None,
});
results.push(NameServerConfig {
address: Url::parse("udp://2606:4700:4700::1111").unwrap(),
timeout_in_seconds: None,
bind_address: None,
});
results.push(NameServerConfig {
address: Url::parse("udp://2606:4700:4700::1001:53").unwrap(),
timeout_in_seconds: None,
bind_address: None,
});
}
Some(BuiltinDnsOption::Google) => {
results.push(NameServerConfig {
address: Url::parse("udp://8.8.8.8:53").unwrap(),
timeout_in_seconds: None,
bind_address: None,
});
results.push(NameServerConfig {
address: Url::parse("udp://8.8.4.4").unwrap(),
timeout_in_seconds: None,
bind_address: None,
});
results.push(NameServerConfig {
address: Url::parse("udp://2001:4860:4860::8888:53").unwrap(),
timeout_in_seconds: None,
bind_address: None,
});
results.push(NameServerConfig {
address: Url::parse("udp://2001:4860:4860::8844").unwrap(),
timeout_in_seconds: None,
bind_address: None,
});
}
Some(BuiltinDnsOption::Quad9) => {
results.push(NameServerConfig {
address: Url::parse("udp://9.9.9.9:53").unwrap(),
timeout_in_seconds: None,
bind_address: None,
});
results.push(NameServerConfig {
address: Url::parse("udp://2620::00fe:00f3").unwrap(),
timeout_in_seconds: None,
bind_address: None,
});
}
}
results
}
}

View File

@@ -0,0 +1,17 @@
use crate::console::ConsoleConfiguration;
pub struct RuntimeConfiguration {
pub tokio_worker_threads: usize,
pub tokio_blocking_threads: usize,
pub console: Option<ConsoleConfiguration>,
}
impl Default for RuntimeConfiguration {
fn default() -> Self {
RuntimeConfiguration {
tokio_worker_threads: 4,
tokio_blocking_threads: 16,
console: None,
}
}
}

View File

@@ -0,0 +1,69 @@
use crate::command_line::ServerArguments;
use crate::config_file::{ConfigFile, SocketConfig};
use crate::error::ConfigurationError;
use crate::logging::LoggingConfiguration;
use crate::runtime::RuntimeConfiguration;
use clap::Parser;
use std::collections::HashMap;
use std::ffi::OsString;
use tokio::net::UnixListener;
#[derive(Default)]
pub struct ServerConfiguration {
pub runtime: RuntimeConfiguration,
pub logging: LoggingConfiguration,
pub sockets: HashMap<String, SocketConfig>,
}
impl ServerConfiguration {
/// Load a server configuration for this run.
///
/// This will parse the process's command line arguments, and parse
/// a config file if given, so it can take awhile. Even though this
/// function does a bunch of IO, it is not async, because it is expected
/// ro run before we have a tokio runtime fully established. (This
/// function will determine a bunch of things related to the runtime,
/// like how many threads to run, what tracing subscribers to include,
/// etc.)
pub fn new<I, T>(args: I) -> Result<Self, ConfigurationError>
where
I: IntoIterator<Item = T>,
T: Into<OsString> + Clone,
{
let mut new_configuration = Self::default();
let mut command_line = ServerArguments::try_parse_from(args)?;
let mut config_file = ConfigFile::new(command_line.arguments.config_file.take())?;
// we prefer the command line to the config file, so first merge
// in the config file so that when we later merge the command line,
// it overwrites any config file options.
if let Some(config_file) = config_file.as_mut() {
config_file.merge_standard_options_into(
&mut new_configuration.runtime,
&mut new_configuration.logging,
);
new_configuration.sockets = config_file.sockets.take().unwrap_or_default();
}
command_line.arguments.merge_standard_options_into(
&mut new_configuration.runtime,
&mut new_configuration.logging,
);
Ok(new_configuration)
}
/// Generate a series of UNIX sockets that clients will attempt to connect to.
pub async fn generate_listener_sockets(
&mut self,
) -> Result<HashMap<String, UnixListener>, ConfigurationError> {
let mut results = HashMap::new();
for (name, config) in std::mem::take(&mut self.sockets).into_iter() {
results.insert(name, config.into_listener().await?);
}
Ok(results)
}
}