This builds, I guess.

This commit is contained in:
2025-01-12 14:10:23 -08:00
parent b30823a502
commit 268ca2d1a5
25 changed files with 1750 additions and 735 deletions

911
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -26,10 +26,11 @@ error-stack = "0.5.0"
futures = "0.3.31"
generic-array = "0.14.7"
hexdump = "0.1.2"
hickory-client = { version = "0.24.1", features = ["mdns"] }
hickory-proto = "0.24.1"
hickory-resolver = "0.24.1"
hostname-validator = "1.1.1"
itertools = "0.13.0"
moka = { version = "0.12.10", features = ["future"] }
nix = { version = "0.28.0", features = ["user"] }
num-bigint-dig = { version = "0.8.4", features = ["arbitrary", "i128", "zeroize", "prime", "rand"] }
num-integer = { version = "0.1.46", features = ["i128"] }
num-traits = { version = "0.2.19", features = ["i128"] }
@@ -40,10 +41,11 @@ p521 = { version = "0.13.3", features = ["ecdh", "ecdsa-core", "hash2curve", "se
proptest = "1.5.0"
rand = "0.8.5"
rand_chacha = "0.3.1"
rustix = "0.38.41"
sec1 = "0.7.3"
serde = { version = "1.0.203", features = ["derive"] }
tempfile = "3.12.0"
thiserror = "1.0.61"
thiserror = "2.0.3"
tokio = { version = "1.38.0", features = ["full", "tracing"] }
toml = "0.8.14"
tracing = "0.1.40"

View File

@@ -1,9 +1,9 @@
use error_stack::ResultExt;
use hush::config::{BasicClientConfiguration, ClientConfiguration};
use hush::config::client::ClientConfiguration;
use std::process;
#[cfg(not(tarpaulin_include))]
fn main() {
let mut base_config = match BasicClientConfiguration::new(std::env::args()) {
let mut config = match ClientConfiguration::new(std::env::args()) {
Ok(config) => config,
Err(e) => {
eprintln!("ERROR: {}", e);
@@ -11,29 +11,26 @@ fn main() {
}
};
if let Err(e) = base_config.establish_subscribers() {
if let Err(e) = hush::config::establish_subscribers(&config.logging, &mut config.runtime) {
eprintln!("ERROR: could not set up logging infrastructure: {}", e);
std::process::exit(2);
process::exit(2);
}
let runtime = match base_config.configured_runtime() {
let runtime = match hush::config::configured_runtime(&config.runtime) {
Ok(runtime) => runtime,
Err(e) => {
tracing::error!(%e, "could not start system runtime");
std::process::exit(3);
process::exit(3);
}
};
let result = runtime.block_on(async move {
tracing::info!("Starting Hush");
match ClientConfiguration::try_from(base_config).await {
Ok(config) => hush::client::hush(config).await,
Err(e) => Err(e).change_context(hush::OperationalError::ConfigurationError),
}
hush::client::hush(config).await
});
if let Err(e) = result {
tracing::error!("{}", e);
std::process::exit(1);
process::exit(1);
}
}

View File

@@ -1,4 +1,69 @@
use futures::stream::StreamExt;
use hickory_proto::udp::UdpStream;
use hush::config::server::ServerConfiguration;
use std::net::{SocketAddr, SocketAddrV4, Ipv4Addr};
use std::process;
use thiserror::Error;
use tokio::net::UdpSocket;
async fn run_dns() -> tokio::io::Result<()> {
let socket = UdpSocket::bind("127.0.0.1:3553").await?;
let mut buffer = [0; 65536];
tracing::info!("Bound socket at {}, starting main DNS handler loop.", socket.local_addr()?);
let (first_len, first_addr) = socket.recv_from(&mut buffer).await?;
println!("Received {} bytes from {}", first_len, first_addr);
let remote = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 3553));
let (mut stream, _handle) = UdpStream::with_bound(socket, remote);
while let Some(item) = stream.next().await {
match item {
Err(e) => eprintln!("Got an error: {:?}", e),
Ok(msg) => {
println!("Got a message from {}", msg.addr());
match msg.to_message() {
Err(e) => println!(" ... but it was malformed ({e})."),
Ok(v) => println!(" ... and it was {v}"),
}
}
}
}
Ok(())
}
#[cfg(not(tarpaulin_include))]
fn main() {
unimplemented!();
let mut config = match ServerConfiguration::new(std::env::args()) {
Ok(config) => config,
Err(e) => {
eprintln!("ERROR: {}", e);
process::exit(1);
}
};
if let Err(e) = hush::config::establish_subscribers(&config.logging, &mut config.runtime) {
eprintln!("ERROR: could not set up logging infrastruxture: {}", e);
process::exit(2);
};
let runtime = match hush::config::configured_runtime(&config.runtime) {
Ok(runtime) => runtime,
Err(error) => {
tracing::error!(%error, "could not start system runtime");
process::exit(3);
}
};
let result = runtime.block_on(async move {
let version = std::env::var("CARGO_PKG_VERSION").unwrap_or_else(|_| "<unknown>".to_string());
tracing::info!(%version, "Starting Hush server (hushd)");
hush::server::run(config).await
});
if let Err(e) = result {
tracing::error!("{}", e);
process::exit(5);
}
}

View File

@@ -1,12 +1,13 @@
use crate::config::{ClientCommand, ClientConfiguration};
use crate::config::client::{ClientCommand, ClientConfiguration};
use crate::crypto::{
CompressionAlgorithm, EncryptionAlgorithm, HostKeyAlgorithm, KeyExchangeAlgorithm, MacAlgorithm,
};
use crate::network::host::Host;
use crate::ssh;
use crate::network::resolver::Resolver;
use crate::ssh::{self, SshKeyExchange};
use crate::OperationalError;
use error_stack::{report, Report, ResultExt};
use std::fmt::Display;
use std::fmt;
use std::str::FromStr;
use thiserror::Error;
use tracing::Instrument;
@@ -27,12 +28,19 @@ pub async fn hush(base_config: ClientConfiguration) -> error_stack::Result<(), O
}
}
#[derive(Debug)]
struct Target {
username: String,
host: Host,
port: u16,
}
impl fmt::Display for Target {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}@{}:{}", self.username, self.host, self.port)
}
}
#[derive(Debug, Error)]
pub enum TargetParseError {
#[error("Invalid port number '{port_string}': {error}")]
@@ -87,35 +95,54 @@ impl FromStr for Target {
}
}
#[allow(unreachable_code,unused_variables)]
async fn connect(
base_config: &ClientConfiguration,
target: &str,
) -> error_stack::Result<(), OperationalError> {
let resolver = base_config.resolver();
let resolver: Resolver = unimplemented!();
// let mut resolver = Resolver::new(&base_config.resolver)
// .await
// .change_context(OperationalError::DnsConfig)?;
let target = Target::from_str(target)
.change_context(OperationalError::UnableToParseHostAddress)
.attach_printable_lazy(|| format!("target address '{}'", target))?;
tracing::trace!(%target, "determined SSH target");
let mut stream = target
.host
.connect(resolver, 22)
.connect(&mut resolver, 22)
.await
.change_context(OperationalError::Connection)?;
let server_addr_str = stream
.peer_addr()
.map(|x| x.to_string())
.unwrap_or_else(|_| "<unknown>".to_string());
let client_addr_str = stream
.local_addr()
.map(|x| x.to_string())
.unwrap_or_else(|_| "<unknown>".to_string());
let stream_span = tracing::debug_span!(
"client connection",
server = %stream.peer_addr().map(|x| x.to_string()).unwrap_or_else(|_| "<unknown>".to_string()),
client = %stream.local_addr().map(|x| x.to_string()).unwrap_or_else(|_| "<unknown>".to_string()),
server = %server_addr_str,
client = %client_addr_str,
);
tracing::trace!(%target, "connected to target server");
let stream_error_info = || server_addr_str.clone();
let their_preamble = ssh::Preamble::read(&mut stream)
.instrument(stream_span.clone())
.await
.change_context(OperationalError::Connection)?;
tracing::trace!("received their preamble");
let our_preamble = ssh::Preamble::default()
let _our_preamble = ssh::Preamble::default()
.write(&mut stream)
.instrument(stream_span)
.await
.change_context(OperationalError::Connection)?;
tracing::trace!("wrote our preamble");
if !their_preamble.preamble.is_empty() {
for line in their_preamble.preamble.lines() {
@@ -130,13 +157,22 @@ async fn connect(
"received server preamble"
);
let mut stream = ssh::SshChannel::new(stream);
let stream = ssh::SshChannel::new(stream);
let their_initial = stream
.read()
.await
.attach_printable_lazy(stream_error_info)
.change_context(OperationalError::KeyExchange)?
.ok_or_else();
.ok_or_else(|| report!(OperationalError::KeyExchange))
.attach_printable_lazy(stream_error_info)
.attach_printable_lazy(|| "No initial key exchange message found")?;
let their_kex: SshKeyExchange = SshKeyExchange::try_from(their_initial)
.attach_printable_lazy(stream_error_info)
.change_context(OperationalError::KeyExchange)?;
println!("their_key: {:?}", their_kex);
// let mut rng = rand::thread_rng();
//
// let packet = stream
@@ -165,7 +201,7 @@ async fn connect(
Ok(())
}
fn print_options<T: Display>(items: &[T]) -> error_stack::Result<(), OperationalError> {
fn print_options<T: fmt::Display>(items: &[T]) -> error_stack::Result<(), OperationalError> {
for item in items.iter() {
println!("{}", item);
}

View File

@@ -1,89 +1,69 @@
pub mod client;
mod command_line;
mod config_file;
pub mod connection;
mod console;
mod error;
pub mod error;
mod logging;
mod resolver;
pub mod resolver;
mod runtime;
pub mod server;
use crate::config::console::ConsoleConfiguration;
use crate::config::logging::LoggingConfiguration;
use crate::crypto::known_algorithms::{
ALLOWED_COMPRESSION_ALGORITHMS, ALLOWED_ENCRYPTION_ALGORITHMS, ALLOWED_HOST_KEY_ALGORITHMS,
ALLOWED_KEY_EXCHANGE_ALGORITHMS, ALLOWED_MAC_ALGORITHMS,
};
use crate::crypto::{
CompressionAlgorithm, EncryptionAlgorithm, HostKeyAlgorithm, KeyExchangeAlgorithm, MacAlgorithm,
};
use crate::encodings::ssh::{load_openssh_file_keys, PublicKey};
use clap::Parser;
use config_file::ConfigFile;
use error_stack::ResultExt;
use hickory_resolver::TokioAsyncResolver;
use proptest::arbitrary::Arbitrary;
use proptest::strategy::{BoxedStrategy, Strategy};
use std::collections::{HashMap, HashSet};
use std::ffi::OsString;
use std::str::FromStr;
use tokio::runtime::Runtime;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::prelude::*;
use tokio::runtime::Runtime;
pub use self::command_line::ClientCommand;
use self::command_line::CommandLineArguments;
pub use self::error::ConfigurationError;
pub struct BasicClientConfiguration {
runtime: RuntimeConfiguration,
logging: LoggingConfiguration,
config_file: Option<ConfigFile>,
command: ClientCommand,
}
impl BasicClientConfiguration {
/// Load a basic client configuration for this run.
///
/// This will parse the process's command line arguments, and parse
/// a config file if given, but will not interpret this information
/// beyond that required to understand the user's goals for the
/// runtime system and logging. For this reason, it is not async,
/// as it is responsible for determining all the information we
/// will use to generate the runtime.
pub fn new<I, T>(args: I) -> Result<Self, ConfigurationError>
where
I: IntoIterator<Item = T>,
T: Into<OsString> + Clone,
{
let mut command_line_arguments = CommandLineArguments::try_parse_from(args)?;
let mut config_file = ConfigFile::new(command_line_arguments.config_file.take())?;
let mut runtime = RuntimeConfiguration::default();
let mut logging = LoggingConfiguration::default();
// we prefer the command line to the config file, so first merge
// in the config file so that when we later merge the command line,
// it overwrites any config file options.
if let Some(config_file) = config_file.as_mut() {
config_file.merge_standard_options_into(&mut runtime, &mut logging);
}
command_line_arguments.merge_standard_options_into(&mut runtime, &mut logging);
Ok(BasicClientConfiguration {
runtime,
logging,
config_file,
command: command_line_arguments.command,
})
}
//impl ClientConfiguration {
// pub async fn try_from(
// mut basic: BasicClientConfiguration,
// ) -> error_stack::Result<Self, ConfigurationError> {
// let mut _ssh_keys = HashMap::new();
// let mut _defaults = ClientConnectionOpts::default();
// let resolver = basic
// .config_file
// .as_mut()
// .and_then(|x| x.resolver.take())
// .unwrap_or_default();
// let user_ssh_keys = basic.config_file.map(|cf| cf.keys).unwrap_or_default();
//
// tracing::info!(
// provided_ssh_keys = user_ssh_keys.len(),
// "loading user-provided SSH keys"
// );
// for (key_name, key_info) in user_ssh_keys.into_iter() {
// let _public_keys = PublicKey::load(key_info.public).await.unwrap();
// tracing::info!(?key_name, "public keys loaded");
// let _private_key = load_openssh_file_keys(key_info.private, &key_info.password)
// .await
// .change_context(ConfigurationError::PrivateKey)?;
// tracing::info!(?key_name, "private keys loaded");
// }
//
// Ok(ClientConfiguration {
// _runtime: basic.runtime,
// _logging: basic.logging,
// resolver,
// _ssh_keys,
// _defaults,
// command: basic.command,
// })
// }
//}
//
/// Set up the tracing subscribers based on the config file and command line options.
///
/// This will definitely set up our logging substrate, but may also create a subscriber
/// for the console.
#[cfg(not(tarpaulin_include))]
pub fn establish_subscribers(&mut self) -> Result<(), std::io::Error> {
let tracing_layer = self.logging.layer()?;
pub fn establish_subscribers(
logging: &logging::LoggingConfiguration,
runtime: &mut runtime::RuntimeConfiguration,
) -> Result<(), std::io::Error> {
let tracing_layer = logging.layer()?;
let mut layers = vec![tracing_layer];
if let Some(console_config) = self.runtime.console.take() {
if let Some(console_config) = runtime.console.take() {
layers.push(console_config.layer());
}
@@ -95,204 +75,13 @@ impl BasicClientConfiguration {
/// Generate a new tokio runtime based on the configuration / command line options
/// provided.
#[cfg(not(tarpaulin_include))]
pub fn configured_runtime(&mut self) -> Result<Runtime, std::io::Error> {
pub fn configured_runtime(
runtime: &runtime::RuntimeConfiguration,
) -> Result<Runtime, std::io::Error> {
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.max_blocking_threads(self.runtime.tokio_blocking_threads)
.worker_threads(self.runtime.tokio_worker_threads)
.max_blocking_threads(runtime.tokio_blocking_threads)
.worker_threads(runtime.tokio_worker_threads)
.build()
}
}
#[derive(Debug)]
pub struct ClientConnectionOpts {
pub key_exchange_algorithms: Vec<KeyExchangeAlgorithm>,
pub server_host_key_algorithms: Vec<HostKeyAlgorithm>,
pub encryption_algorithms: Vec<EncryptionAlgorithm>,
pub mac_algorithms: Vec<MacAlgorithm>,
pub compression_algorithms: Vec<CompressionAlgorithm>,
pub languages: Vec<String>,
pub predict: Option<KeyExchangeAlgorithm>,
}
impl Default for ClientConnectionOpts {
fn default() -> Self {
ClientConnectionOpts {
key_exchange_algorithms: vec![KeyExchangeAlgorithm::Curve25519Sha256],
server_host_key_algorithms: vec![HostKeyAlgorithm::Ed25519],
encryption_algorithms: vec![EncryptionAlgorithm::Aes256Ctr],
mac_algorithms: vec![MacAlgorithm::HmacSha256],
compression_algorithms: vec![],
languages: vec![],
predict: None,
}
}
}
impl Arbitrary for ClientConnectionOpts {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
let keyx = proptest::sample::select(ALLOWED_KEY_EXCHANGE_ALGORITHMS);
let hostkey = proptest::sample::select(ALLOWED_HOST_KEY_ALGORITHMS);
let enc = proptest::sample::select(ALLOWED_ENCRYPTION_ALGORITHMS);
let mac = proptest::sample::select(ALLOWED_MAC_ALGORITHMS);
let comp = proptest::sample::select(ALLOWED_COMPRESSION_ALGORITHMS);
(
proptest::collection::hash_set(keyx.clone(), 1..ALLOWED_KEY_EXCHANGE_ALGORITHMS.len()),
proptest::collection::hash_set(hostkey, 1..ALLOWED_HOST_KEY_ALGORITHMS.len()),
proptest::collection::hash_set(enc, 1..ALLOWED_ENCRYPTION_ALGORITHMS.len()),
proptest::collection::hash_set(mac, 1..ALLOWED_MAC_ALGORITHMS.len()),
proptest::collection::hash_set(comp, 1..ALLOWED_COMPRESSION_ALGORITHMS.len()),
proptest::option::of(keyx),
)
.prop_map(|(kex, host, enc, mac, comp, pred)| ClientConnectionOpts {
key_exchange_algorithms: finalize_options(kex).unwrap(),
server_host_key_algorithms: finalize_options(host).unwrap(),
encryption_algorithms: finalize_options(enc).unwrap(),
mac_algorithms: finalize_options(mac).unwrap(),
compression_algorithms: finalize_options(comp).unwrap(),
languages: vec![],
predict: pred.and_then(|x| KeyExchangeAlgorithm::from_str(x).ok()),
})
.boxed()
}
}
fn finalize_options<T, E>(inputs: HashSet<&str>) -> Result<Vec<T>, E>
where
T: FromStr<Err = E>,
{
let mut result = vec![];
for item in inputs.into_iter() {
let item = T::from_str(item)?;
result.push(item)
}
Ok(result)
}
#[derive(Default)]
pub struct ServerConfiguration {
runtime: RuntimeConfiguration,
logging: LoggingConfiguration,
}
pub struct RuntimeConfiguration {
tokio_worker_threads: usize,
tokio_blocking_threads: usize,
console: Option<ConsoleConfiguration>,
}
impl Default for RuntimeConfiguration {
fn default() -> Self {
RuntimeConfiguration {
tokio_worker_threads: 4,
tokio_blocking_threads: 16,
console: None,
}
}
}
impl ServerConfiguration {
/// Load a server configuration for this run.
///
/// This will parse the process's command line arguments, and parse
/// a config file if given, so it can take awhile. Even though this
/// function does a bunch of IO, it is not async, because it is expected
/// ro run before we have a tokio runtime fully established. (This
/// function will determine a bunch of things related to the runtime,
/// like how many threads to run, what tracing subscribers to include,
/// etc.)
pub fn new() -> Result<Self, ConfigurationError> {
let mut new_configuration = Self::default();
let mut command_line = CommandLineArguments::parse();
let mut config_file = ConfigFile::new(command_line.config_file.take())?;
// we prefer the command line to the config file, so first merge
// in the config file so that when we later merge the command line,
// it overwrites any config file options.
if let Some(config_file) = config_file.as_mut() {
config_file.merge_standard_options_into(
&mut new_configuration.runtime,
&mut new_configuration.logging,
);
}
command_line.merge_standard_options_into(
&mut new_configuration.runtime,
&mut new_configuration.logging,
);
Ok(new_configuration)
}
}
impl BasicClientConfiguration {}
pub struct ClientConfiguration {
_runtime: RuntimeConfiguration,
_logging: LoggingConfiguration,
resolver: TokioAsyncResolver,
_ssh_keys: HashMap<String, SshKey>,
_defaults: ClientConnectionOpts,
command: ClientCommand,
}
impl ClientConfiguration {
pub async fn try_from(
mut basic: BasicClientConfiguration,
) -> error_stack::Result<Self, ConfigurationError> {
let mut _ssh_keys = HashMap::new();
let mut _defaults = ClientConnectionOpts::default();
let resolver = basic
.config_file
.as_mut()
.and_then(|x| x.resolver.take())
.unwrap_or_default()
.resolver()
.change_context(ConfigurationError::Resolver)?;
let user_ssh_keys = basic
.config_file
.map(|cf| cf.keys)
.unwrap_or_default();
tracing::info!(
provided_ssh_keys = user_ssh_keys.len(),
"loading user-provided SSH keys"
);
for (key_name, key_info) in user_ssh_keys.into_iter() {
let _public_keys = PublicKey::load(key_info.public).await.unwrap();
tracing::info!(?key_name, "public keys loaded");
let _private_key = load_openssh_file_keys(key_info.private, &key_info.password)
.await
.change_context(ConfigurationError::PrivateKey)?;
tracing::info!(?key_name, "private keys loaded");
}
Ok(ClientConfiguration {
_runtime: basic.runtime,
_logging: basic.logging,
resolver,
_ssh_keys,
_defaults,
command: basic.command,
})
}
pub fn command(&self) -> &ClientCommand {
&self.command
}
pub fn resolver(&self) -> &TokioAsyncResolver {
&self.resolver
}
}
pub enum SshKey {
Ed25519 {},
Ecdsa {},
Rsa {},
}

60
src/config/client.rs Normal file
View File

@@ -0,0 +1,60 @@
use clap::Parser;
use crate::config::config_file::ConfigFile;
use crate::config::error::ConfigurationError;
use crate::config::logging::LoggingConfiguration;
use crate::config::runtime::RuntimeConfiguration;
use crate::config::command_line::ClientArguments;
pub use crate::config::command_line::ClientCommand;
use std::ffi::OsString;
use tracing_core::LevelFilter;
pub struct ClientConfiguration {
pub runtime: RuntimeConfiguration,
pub logging: LoggingConfiguration,
config_file: Option<ConfigFile>,
command: ClientCommand,
}
impl ClientConfiguration {
/// Load a basic client configuration for this run.
///
/// This will parse the process's command line arguments, and parse
/// a config file if given, but will not interpret this information
/// beyond that required to understand the user's goals for the
/// runtime system and logging. For this reason, it is not async,
/// as it is responsible for determining all the information we
/// will use to generate the runtime.
pub fn new<I, T>(args: I) -> Result<Self, ConfigurationError>
where
I: IntoIterator<Item = T>,
T: Into<OsString> + Clone,
{
let mut command_line_arguments = ClientArguments::try_parse_from(args)?;
let mut config_file = ConfigFile::new(command_line_arguments.arguments.config_file.take())?;
let mut runtime = RuntimeConfiguration::default();
let mut logging = LoggingConfiguration::default();
// we prefer the command line to the config file, so first merge
// in the config file so that when we later merge the command line,
// it overwrites any config file options.
if let Some(config_file) = config_file.as_mut() {
config_file.merge_standard_options_into(&mut runtime, &mut logging);
}
command_line_arguments.arguments.merge_standard_options_into(&mut runtime, &mut logging);
if command_line_arguments.command.is_list_command() {
logging.filter = LevelFilter::ERROR;
}
Ok(ClientConfiguration {
runtime,
logging,
config_file,
command: command_line_arguments.command,
})
}
// Returns the command we received from the user
pub fn command(&self) -> &ClientCommand {
&self.command
}
}

View File

@@ -1,6 +1,7 @@
use crate::config::logging::{LogTarget, LoggingConfiguration};
use crate::config::{ConsoleConfiguration, RuntimeConfiguration};
use clap::{Parser, Subcommand};
use crate::config::console::ConsoleConfiguration;
use crate::config::runtime::RuntimeConfiguration;
use clap::{Args, Parser, Subcommand};
use std::net::SocketAddr;
use std::path::PathBuf;
#[cfg(test)]
@@ -9,6 +10,23 @@ use tracing_core::LevelFilter;
#[derive(Parser, Debug, Default)]
#[command(version, about, long_about = None)]
pub struct ServerArguments {
#[command(flatten)]
pub arguments: CommandLineArguments,
}
#[derive(Parser, Debug, Default)]
#[command(version, about, long_about = None)]
pub struct ClientArguments {
#[command(flatten)]
pub arguments: CommandLineArguments,
/// The command we're running as the client
#[command(subcommand)]
pub command: ClientCommand,
}
#[derive(Args, Debug, Default)]
pub struct CommandLineArguments {
/// The config file to use for this command.
#[arg(short, long)]
@@ -45,10 +63,6 @@ pub struct CommandLineArguments {
/// A unix domain socket address to use for tokio-console inspection.
#[arg(short = 'u', long, group = "console")]
console_unix_socket: Option<PathBuf>,
/// The command we're running as the client
#[command(subcommand)]
pub command: ClientCommand,
}
#[derive(Debug, Default, Subcommand)]
@@ -102,8 +116,6 @@ impl CommandLineArguments {
if let Some(log_level) = self.log_level.take() {
logging_config.filter = log_level;
} else if self.command.is_list_command() {
logging_config.filter = LevelFilter::ERROR;
}
if (self.console_network_server.is_some() || self.console_unix_socket.is_some())

View File

@@ -1,6 +1,7 @@
use crate::config::error::ConfigurationError;
use crate::config::logging::{LogMode, LogTarget, LoggingConfiguration};
use crate::config::resolver::DnsConfig;
use crate::config::RuntimeConfiguration;
use crate::config::runtime::RuntimeConfiguration;
use crate::crypto::known_algorithms::{
ALLOWED_COMPRESSION_ALGORITHMS, ALLOWED_ENCRYPTION_ALGORITHMS, ALLOWED_HOST_KEY_ALGORITHMS,
ALLOWED_KEY_EXCHANGE_ALGORITHMS, ALLOWED_MAC_ALGORITHMS,
@@ -21,6 +22,7 @@ pub struct ConfigFile {
runtime: Option<RuntimeConfig>,
logging: Option<LoggingConfig>,
pub resolver: Option<DnsConfig>,
pub sockets: Option<HashMap<String, SocketConfig>>,
pub keys: HashMap<String, KeyConfig>,
defaults: Option<ServerConfig>,
servers: HashMap<String, ServerConfig>,
@@ -35,15 +37,17 @@ impl Arbitrary for ConfigFile {
any::<Option<RuntimeConfig>>(),
any::<Option<LoggingConfig>>(),
any::<Option<DnsConfig>>(),
proptest::option::of(keyed_section(SocketConfig::arbitrary())),
keyed_section(KeyConfig::arbitrary()),
any::<Option<ServerConfig>>(),
keyed_section(ServerConfig::arbitrary()),
)
.prop_map(
|(runtime, logging, resolver, keys, defaults, servers)| ConfigFile {
|(runtime, logging, resolver,sockets, keys, defaults, servers)| ConfigFile {
runtime,
logging,
resolver,
sockets,
keys,
defaults,
servers,
@@ -65,6 +69,141 @@ where
.boxed()
}
#[derive(Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
pub struct SocketConfig {
path: PathBuf,
user: Option<String>,
group: Option<String>,
permissions: Option<Permissions>,
}
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
pub enum Permissions {
User,
Group,
UserGroup,
Everyone,
}
impl From<Permissions> for rustix::fs::Mode {
fn from(value: Permissions) -> Self {
match value {
Permissions::User =>
rustix::fs::Mode::RUSR |
rustix::fs::Mode::WUSR,
Permissions::Group =>
rustix::fs::Mode::RUSR |
rustix::fs::Mode::WUSR,
Permissions::UserGroup =>
rustix::fs::Mode::RUSR |
rustix::fs::Mode::WUSR |
rustix::fs::Mode::RGRP |
rustix::fs::Mode::WGRP,
Permissions::Everyone =>
rustix::fs::Mode::RUSR |
rustix::fs::Mode::WUSR |
rustix::fs::Mode::RGRP |
rustix::fs::Mode::WGRP |
rustix::fs::Mode::ROTH |
rustix::fs::Mode::WOTH,
}
}
}
impl SocketConfig {
pub async fn into_listener(self) -> Result<tokio::net::UnixListener, ConfigurationError> {
let base = tokio::net::UnixListener::bind(&self.path)
.map_err(|error| ConfigurationError::CouldNotMakeSocket {
path: self.path.clone(),
error,
})?;
let user = self.user.map(|x| {
nix::unistd::User::from_name(&x)
.map_err(|error| ConfigurationError::CouldNotSetPerms {
thing: "find user".to_string(),
path: self.path.clone(),
error: error.into(),
})
.transpose()
.unwrap_or_else(|| Err(ConfigurationError::CouldNotSetPerms {
thing: "find user".to_string(),
path: self.path.clone(),
error: std::io::Error::new(std::io::ErrorKind::NotFound, "could not find user"),
}))
}).transpose()?;
let group = self.group.map(|x| {
nix::unistd::Group::from_name(&x)
.map_err(|error| ConfigurationError::CouldNotSetPerms {
thing: "find user".to_string(),
path: self.path.clone(),
error: error.into(),
})
.transpose()
.unwrap_or_else(|| Err(ConfigurationError::CouldNotSetPerms {
thing: "find user".to_string(),
path: self.path.clone(),
error: std::io::Error::new(std::io::ErrorKind::NotFound, "could not find user"),
}))
}).transpose()?;
if user.is_some() || group.is_some() {
std::os::unix::fs::chown(&self.path, user.map(|x| x.uid.as_raw()), group.map(|x| x.gid.as_raw()))
.map_err(|error|
ConfigurationError::CouldNotSetPerms {
thing: "set user/group ownership".to_string(),
path: self.path.clone(),
error,
}
)?;
}
let perms = self.permissions.unwrap_or(Permissions::UserGroup);
rustix::fs::chmod(&self.path, perms.into())
.map_err(|error| ConfigurationError::CouldNotSetPerms {
thing: "set permissions".to_string(),
path: self.path.clone(),
error: error.into(),
})?;
Ok(base)
}
}
impl Arbitrary for SocketConfig {
type Strategy = BoxedStrategy<SocketConfig>;
type Parameters = ();
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
(PathBuf::arbitrary(),
proptest::option::of(proptest::string::string_regex("[a-zA-Z0-9]{1,30}").unwrap()),
proptest::option::of(proptest::string::string_regex("[a-zA-Z0-9]{1,30}").unwrap()),
proptest::option::of(Permissions::arbitrary()))
.prop_map(|(path, user, group, permissions)|
SocketConfig{ path, user, group, permissions })
.boxed()
}
}
impl Arbitrary for Permissions {
type Strategy = BoxedStrategy<Permissions>;
type Parameters = ();
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
proptest::prop_oneof![
proptest::strategy::Just(Permissions::User),
proptest::strategy::Just(Permissions::Group),
proptest::strategy::Just(Permissions::UserGroup),
proptest::strategy::Just(Permissions::Everyone),
].boxed()
}
}
#[derive(Debug, Deserialize, PartialEq, Serialize)]
pub struct RuntimeConfig {
worker_threads: Option<usize>,

77
src/config/connection.rs Normal file
View File

@@ -0,0 +1,77 @@
use crate::crypto::known_algorithms::{
ALLOWED_COMPRESSION_ALGORITHMS, ALLOWED_ENCRYPTION_ALGORITHMS, ALLOWED_HOST_KEY_ALGORITHMS,
ALLOWED_KEY_EXCHANGE_ALGORITHMS, ALLOWED_MAC_ALGORITHMS,
};
use crate::crypto::{
CompressionAlgorithm, EncryptionAlgorithm, HostKeyAlgorithm, KeyExchangeAlgorithm, MacAlgorithm,
};
use proptest::arbitrary::Arbitrary;
use proptest::strategy::{BoxedStrategy, Strategy};
use std::collections::HashSet;
use std::str::FromStr;
#[derive(Debug)]
pub struct ClientConnectionOpts {
pub key_exchange_algorithms: Vec<KeyExchangeAlgorithm>,
pub server_host_key_algorithms: Vec<HostKeyAlgorithm>,
pub encryption_algorithms: Vec<EncryptionAlgorithm>,
pub mac_algorithms: Vec<MacAlgorithm>,
pub compression_algorithms: Vec<CompressionAlgorithm>,
pub languages: Vec<String>,
pub predict: Option<KeyExchangeAlgorithm>,
}
impl Default for ClientConnectionOpts {
fn default() -> Self {
ClientConnectionOpts {
key_exchange_algorithms: vec![KeyExchangeAlgorithm::Curve25519Sha256],
server_host_key_algorithms: vec![HostKeyAlgorithm::Ed25519],
encryption_algorithms: vec![EncryptionAlgorithm::Aes256Ctr],
mac_algorithms: vec![MacAlgorithm::HmacSha256],
compression_algorithms: vec![],
languages: vec![],
predict: None,
}
}
}
impl Arbitrary for ClientConnectionOpts {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
let keyx = proptest::sample::select(ALLOWED_KEY_EXCHANGE_ALGORITHMS);
let hostkey = proptest::sample::select(ALLOWED_HOST_KEY_ALGORITHMS);
let enc = proptest::sample::select(ALLOWED_ENCRYPTION_ALGORITHMS);
let mac = proptest::sample::select(ALLOWED_MAC_ALGORITHMS);
let comp = proptest::sample::select(ALLOWED_COMPRESSION_ALGORITHMS);
(
proptest::collection::hash_set(keyx.clone(), 1..ALLOWED_KEY_EXCHANGE_ALGORITHMS.len()),
proptest::collection::hash_set(hostkey, 1..ALLOWED_HOST_KEY_ALGORITHMS.len()),
proptest::collection::hash_set(enc, 1..ALLOWED_ENCRYPTION_ALGORITHMS.len()),
proptest::collection::hash_set(mac, 1..ALLOWED_MAC_ALGORITHMS.len()),
proptest::collection::hash_set(comp, 1..ALLOWED_COMPRESSION_ALGORITHMS.len()),
proptest::option::of(keyx),
)
.prop_map(|(kex, host, enc, mac, comp, pred)| ClientConnectionOpts {
key_exchange_algorithms: finalize_options(kex).unwrap(),
server_host_key_algorithms: finalize_options(host).unwrap(),
encryption_algorithms: finalize_options(enc).unwrap(),
mac_algorithms: finalize_options(mac).unwrap(),
compression_algorithms: finalize_options(comp).unwrap(),
languages: vec![],
predict: pred.and_then(|x| KeyExchangeAlgorithm::from_str(x).ok()),
})
.boxed()
}
}
fn finalize_options<T,E>(values: HashSet<&str>) -> Result<Vec<T>, E>
where T: FromStr<Err=E>
{
Ok(values.into_iter().map(T::from_str).collect::<Result<Vec<T>,E>>()?)
}

View File

@@ -19,4 +19,15 @@ pub enum ConfigurationError {
PrivateKey,
#[error("Error configuring DNS resolver")]
Resolver,
#[error("Could not create UNIX listener socket at {path}: {error}")]
CouldNotMakeSocket {
path: PathBuf,
error: std::io::Error,
},
#[error("Could not {thing} for {path}: {error}")]
CouldNotSetPerms {
thing: String,
path: PathBuf,
error: std::io::Error,
},
}

View File

@@ -1,35 +1,28 @@
use error_stack::report;
use hickory_proto::error::ProtoError;
use hickory_resolver::config::{NameServerConfig, Protocol, ResolverConfig, ResolverOpts};
use hickory_resolver::{Name, TokioAsyncResolver};
use proptest::arbitrary::Arbitrary;
use proptest::strategy::{BoxedStrategy, Just, Strategy};
use serde::{Deserialize, Serialize};
use std::net::SocketAddr;
use thiserror::Error;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
#[derive(Debug, PartialEq, Deserialize, Serialize)]
pub struct DnsConfig {
built_in: Option<BuiltinDnsOption>,
local_domain: Option<String>,
pub local_domain: Option<String>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
search_domains: Vec<String>,
pub search_domains: Vec<String>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
name_servers: Vec<ServerConfig>,
pub name_servers: Vec<NameServerConfig>,
#[serde(default)]
timeout_in_seconds: Option<u16>,
pub retry_attempts: Option<u16>,
#[serde(default)]
retry_attempts: Option<u16>,
pub cache_size: Option<u32>,
#[serde(default)]
cache_size: Option<u32>,
pub max_concurrent_requests_for_query: Option<u16>,
#[serde(default)]
use_hosts_file: Option<bool>,
pub preserve_intermediates: Option<bool>,
#[serde(default)]
max_concurrent_requests_for_query: Option<u16>,
pub shuffle_dns_servers: Option<bool>,
#[serde(default)]
preserve_intermediates: Option<bool>,
#[serde(default)]
shuffle_dns_servers: Option<bool>,
pub allow_mdns: Option<bool>,
}
impl Default for DnsConfig {
@@ -39,13 +32,12 @@ impl Default for DnsConfig {
local_domain: None,
search_domains: vec![],
name_servers: vec![],
timeout_in_seconds: None,
retry_attempts: None,
cache_size: None,
use_hosts_file: None,
max_concurrent_requests_for_query: None,
preserve_intermediates: None,
shuffle_dns_servers: None,
allow_mdns: None,
}
}
}
@@ -62,13 +54,12 @@ impl Arbitrary for DnsConfig {
local_domain: None,
search_domains: vec![],
name_servers: vec![],
timeout_in_seconds: None,
retry_attempts: None,
cache_size: None,
use_hosts_file: None,
max_concurrent_requests_for_query: None,
preserve_intermediates: None,
shuffle_dns_servers: None,
allow_mdns: None,
})
.boxed()
} else {
@@ -79,51 +70,47 @@ impl Arbitrary for DnsConfig {
let search_domains = proptest::collection::vec(domain_name_strat(), 0..10);
let min_servers = if built_in.is_some() { 0 } else { 1 };
let name_servers =
proptest::collection::vec(ServerConfig::arbitrary(), min_servers..6);
let timeout_in_seconds = proptest::option::of(u16::arbitrary());
proptest::collection::vec(NameServerConfig::arbitrary(), min_servers..6);
let retry_attempts = proptest::option::of(u16::arbitrary());
let cache_size = proptest::option::of(u32::arbitrary());
let use_hosts_file = proptest::option::of(bool::arbitrary());
let max_concurrent_requests_for_query = proptest::option::of(u16::arbitrary());
let preserve_intermediates = proptest::option::of(bool::arbitrary());
let shuffle_dns_servers = proptest::option::of(bool::arbitrary());
let allow_mdns = proptest::option::of(bool::arbitrary());
(
local_domain,
search_domains,
name_servers,
timeout_in_seconds,
retry_attempts,
cache_size,
use_hosts_file,
max_concurrent_requests_for_query,
preserve_intermediates,
shuffle_dns_servers,
allow_mdns,
)
.prop_map(
move |(
local_domain,
search_domains,
name_servers,
timeout_in_seconds,
retry_attempts,
cache_size,
use_hosts_file,
max_concurrent_requests_for_query,
preserve_intermediates,
shuffle_dns_servers,
allow_mdns,
)| DnsConfig {
built_in,
local_domain,
search_domains,
name_servers,
timeout_in_seconds,
retry_attempts,
cache_size,
use_hosts_file,
max_concurrent_requests_for_query,
preserve_intermediates,
shuffle_dns_servers,
allow_mdns,
},
)
})
@@ -172,34 +159,34 @@ impl Arbitrary for BuiltinDnsOption {
}
}
#[derive(Debug, PartialEq, Deserialize, Serialize)]
struct ServerConfig {
address: SocketAddr,
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct NameServerConfig {
pub address: SocketAddr,
#[serde(default)]
trust_negatives: bool,
pub timeout_in_seconds: Option<u64>,
#[serde(default, skip_serializing_if = "Option::is_none")]
bind_address: Option<SocketAddr>,
pub bind_address: Option<SocketAddr>,
}
impl Arbitrary for ServerConfig {
impl Arbitrary for NameServerConfig {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with((): Self::Parameters) -> Self::Strategy {
(
SocketAddr::arbitrary(),
bool::arbitrary(),
proptest::option::of(u64::arbitrary()),
proptest::option::of(SocketAddr::arbitrary()),
)
.prop_map(|(mut address, trust_negatives, mut bind_address)| {
.prop_map(|(mut address, timeout_in_seconds, mut bind_address)| {
clear_flow_and_scope_info(&mut address);
if let Some(bind_address) = bind_address.as_mut() {
clear_flow_and_scope_info(bind_address);
}
ServerConfig {
NameServerConfig {
address,
trust_negatives,
timeout_in_seconds,
bind_address,
}
})
@@ -214,86 +201,6 @@ fn clear_flow_and_scope_info(address: &mut SocketAddr) {
}
}
#[derive(Debug, Error)]
pub enum ResolverConfigError {
#[error("Bad local domain name '{name}' provided: {error}")]
BadDomainName { name: String, error: ProtoError },
#[error("Bad domain name for search '{name}' provided: {error}")]
BadSearchName { name: String, error: ProtoError },
#[error("No DNS hosts found to search")]
NoHosts,
}
impl DnsConfig {
/// Convert this resolver configuration into an actual ResolverConfig, or say
/// why it's bad.
pub fn resolver(&self) -> error_stack::Result<TokioAsyncResolver, ResolverConfigError> {
let mut config = match &self.built_in {
None => ResolverConfig::new(),
Some(BuiltinDnsOption::Cloudflare) => ResolverConfig::cloudflare(),
Some(BuiltinDnsOption::Google) => ResolverConfig::google(),
Some(BuiltinDnsOption::Quad9) => ResolverConfig::quad9(),
};
if let Some(local_domain) = &self.local_domain {
let name = Name::from_utf8(local_domain).map_err(|error| {
report!(ResolverConfigError::BadDomainName {
name: local_domain.clone(),
error,
})
})?;
config.set_domain(name);
}
for name in self.search_domains.iter() {
let name = Name::from_utf8(name).map_err(|error| {
report!(ResolverConfigError::BadSearchName {
name: name.clone(),
error,
})
})?;
config.add_search(name);
}
for ns in self.name_servers.iter() {
let mut nsconfig = NameServerConfig::new(ns.address, Protocol::Udp);
nsconfig.trust_negative_responses = ns.trust_negatives;
nsconfig.bind_addr = ns.bind_address;
config.add_name_server(nsconfig);
}
if config.name_servers().is_empty() {
return Err(report!(ResolverConfigError::NoHosts));
}
let mut options = ResolverOpts::default();
if let Some(seconds) = self.timeout_in_seconds {
options.timeout = tokio::time::Duration::from_secs(seconds as u64);
}
if let Some(retries) = self.retry_attempts {
options.attempts = retries as usize;
}
if let Some(cache_size) = self.cache_size {
options.cache_size = cache_size as usize;
}
options.use_hosts_file = self.use_hosts_file.unwrap_or(true);
if let Some(max) = self.max_concurrent_requests_for_query {
options.num_concurrent_reqs = max as usize;
}
if let Some(preserve) = self.preserve_intermediates {
options.preserve_intermediates = preserve;
}
if let Some(shuffle) = self.shuffle_dns_servers {
options.shuffle_dns_servers = shuffle;
}
Ok(TokioAsyncResolver::tokio(config, options))
}
}
proptest::proptest! {
#[test]
fn valid_configs_parse(config in DnsConfig::arbitrary_with(false)) {
@@ -302,3 +209,98 @@ proptest::proptest! {
assert_eq!(config, reversed);
}
}
impl DnsConfig {
/// Return the configurations for all of the name servers that the user
/// has selected.
pub fn name_servers(&self) -> Vec<NameServerConfig> {
let mut results = self.name_servers.clone();
match self.built_in {
None => {}
Some(BuiltinDnsOption::Cloudflare) => {
results.push(NameServerConfig {
address: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(1, 1, 1, 1), 53)),
timeout_in_seconds: None,
bind_address: None,
});
results.push(NameServerConfig {
address: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(1, 0, 0, 1), 53)),
timeout_in_seconds: None,
bind_address: None,
});
results.push(NameServerConfig {
address: SocketAddr::V6(SocketAddrV6::new(
Ipv6Addr::new(0x2606, 0x4700, 0x4700, 0, 0, 0, 0, 0x1111),
53,
0,
0,
)),
timeout_in_seconds: None,
bind_address: None,
});
results.push(NameServerConfig {
address: SocketAddr::V6(SocketAddrV6::new(
Ipv6Addr::new(0x2606, 0x4700, 0x4700, 0, 0, 0, 0, 0x1001),
53,
0,
0,
)),
timeout_in_seconds: None,
bind_address: None,
});
}
Some(BuiltinDnsOption::Google) => {
results.push(NameServerConfig {
address: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(8, 8, 8, 8), 53)),
timeout_in_seconds: None,
bind_address: None,
});
results.push(NameServerConfig {
address: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(8, 8, 4, 4), 53)),
timeout_in_seconds: None,
bind_address: None,
});
results.push(NameServerConfig {
address: SocketAddr::V6(SocketAddrV6::new(
Ipv6Addr::new(0x2001, 0x4860, 0x4860, 0, 0, 0, 0, 0x8888),
53,
0,
0,
)),
timeout_in_seconds: None,
bind_address: None,
});
results.push(NameServerConfig {
address: SocketAddr::V6(SocketAddrV6::new(
Ipv6Addr::new(0x2001, 0x4860, 0x4860, 0, 0, 0, 0, 0x8844),
53,
0,
0,
)),
timeout_in_seconds: None,
bind_address: None,
});
}
Some(BuiltinDnsOption::Quad9) => {
results.push(NameServerConfig {
address: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(9, 9, 9, 9), 53)),
timeout_in_seconds: None,
bind_address: None,
});
results.push(NameServerConfig {
address: SocketAddr::V6(SocketAddrV6::new(
Ipv6Addr::new(0x2620, 0, 0, 0, 0, 0, 0xfe, 0xfe),
53,
0,
0,
)),
timeout_in_seconds: None,
bind_address: None,
});
}
}
results
}
}

19
src/config/runtime.rs Normal file
View File

@@ -0,0 +1,19 @@
use crate::config::console::ConsoleConfiguration;
pub struct RuntimeConfiguration {
pub tokio_worker_threads: usize,
pub tokio_blocking_threads: usize,
pub console: Option<ConsoleConfiguration>,
}
impl Default for RuntimeConfiguration {
fn default() -> Self {
RuntimeConfiguration {
tokio_worker_threads: 4,
tokio_blocking_threads: 16,
console: None,
}
}
}

67
src/config/server.rs Normal file
View File

@@ -0,0 +1,67 @@
use clap::Parser;
use crate::config::config_file::{ConfigFile, SocketConfig};
use crate::config::command_line::ServerArguments;
use crate::config::error::ConfigurationError;
use crate::config::logging::LoggingConfiguration;
use crate::config::runtime::RuntimeConfiguration;
use std::collections::HashMap;
use std::ffi::OsString;
use tokio::net::UnixListener;
#[derive(Default)]
pub struct ServerConfiguration {
pub runtime: RuntimeConfiguration,
pub logging: LoggingConfiguration,
pub sockets: HashMap<String, SocketConfig>,
}
impl ServerConfiguration {
/// Load a server configuration for this run.
///
/// This will parse the process's command line arguments, and parse
/// a config file if given, so it can take awhile. Even though this
/// function does a bunch of IO, it is not async, because it is expected
/// ro run before we have a tokio runtime fully established. (This
/// function will determine a bunch of things related to the runtime,
/// like how many threads to run, what tracing subscribers to include,
/// etc.)
pub fn new<I, T>(args: I) -> Result<Self, ConfigurationError>
where
I: IntoIterator<Item = T>,
T: Into<OsString> + Clone,
{
let mut new_configuration = Self::default();
let mut command_line = ServerArguments::try_parse_from(args)?;
let mut config_file = ConfigFile::new(command_line.arguments.config_file.take())?;
// we prefer the command line to the config file, so first merge
// in the config file so that when we later merge the command line,
// it overwrites any config file options.
if let Some(config_file) = config_file.as_mut() {
config_file.merge_standard_options_into(
&mut new_configuration.runtime,
&mut new_configuration.logging,
);
new_configuration.sockets = config_file.sockets.take().unwrap_or_default();
}
command_line.arguments.merge_standard_options_into(
&mut new_configuration.runtime,
&mut new_configuration.logging,
);
Ok(new_configuration)
}
/// Generate a series of UNIX sockets that clients will attempt to connect to.
pub async fn generate_listener_sockets(&mut self) -> Result<HashMap<String, UnixListener>, ConfigurationError> {
let mut results = HashMap::new();
for (name, config) in std::mem::take(&mut self.sockets).into_iter() {
results.insert(name, config.into_listener().await?);
}
Ok(results)
}
}

View File

@@ -4,6 +4,7 @@ pub mod crypto;
pub mod encodings;
pub mod network;
mod operational_error;
pub mod server;
pub mod ssh;
pub use operational_error::OperationalError;

View File

@@ -1 +1,2 @@
pub mod host;
pub mod resolver;

View File

@@ -1,8 +1,7 @@
use crate::network::resolver::{ResolveError, Resolver};
use error_stack::{report, ResultExt};
use futures::stream::{FuturesUnordered, StreamExt};
use hickory_resolver::error::ResolveError;
use hickory_resolver::name_server::ConnectionProvider;
use hickory_resolver::AsyncResolver;
use hickory_client::rr::Name;
use std::collections::HashSet;
use std::fmt;
use std::net::{AddrParseError, IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
@@ -10,10 +9,11 @@ use std::str::FromStr;
use thiserror::Error;
use tokio::net::TcpStream;
#[derive(Debug)]
pub enum Host {
IPv4(Ipv4Addr),
IPv6(Ipv6Addr),
Hostname(String),
Hostname(Name),
}
#[derive(Debug, Error)]
@@ -49,20 +49,8 @@ impl FromStr for Host {
return Ok(Host::IPv6(addr));
}
if s.starts_with('[') && s.ends_with(']') {
match s.trim_start_matches('[').trim_end_matches(']').parse() {
Ok(x) => return Ok(Host::IPv6(x)),
Err(e) => {
return Err(HostParseError::CouldNotParseIPv6 {
address: s.to_string(),
error: e,
})
}
}
}
if hostname_validator::is_valid(s) {
return Ok(Host::Hostname(s.to_string()));
if let Ok(name) = Name::from_utf8(s) {
return Ok(Host::Hostname(name));
}
Err(HostParseError::InvalidHostname {
@@ -95,18 +83,11 @@ impl Host {
/// are no relevant records for us to use for IPv4 or IPv6 connections. There
/// is also no guarantee that the host will have both IPv4 and IPv6 addresses,
/// so you may only see one or the other.
pub async fn resolve<P: ConnectionProvider>(
&self,
resolver: &AsyncResolver<P>,
) -> Result<HashSet<IpAddr>, ResolveError> {
pub async fn resolve(&self, resolver: &mut Resolver) -> Result<HashSet<IpAddr>, ResolveError> {
match self {
Host::IPv4(addr) => Ok(HashSet::from([IpAddr::V4(*addr)])),
Host::IPv6(addr) => Ok(HashSet::from([IpAddr::V6(*addr)])),
Host::Hostname(name) => {
let resolve_result = resolver.lookup_ip(name).await?;
let possibilities = resolve_result.iter().collect();
Ok(possibilities)
}
Host::Hostname(name) => resolver.lookup(name).await,
}
}
@@ -117,9 +98,9 @@ impl Host {
/// connections fail, it will return the first error it receives. This routine
/// will also return an error if there are no addresses to connect to (which
/// can happen in cases in which [`Host::resolve`] would return an empty set.
pub async fn connect<P: ConnectionProvider>(
pub async fn connect(
&self,
resolver: &AsyncResolver<P>,
resolver: &mut Resolver,
port: u16,
) -> error_stack::Result<TcpStream, ConnectionError> {
let addresses = self
@@ -131,6 +112,7 @@ impl Host {
let mut connectors = FuturesUnordered::new();
for address in addresses.into_iter() {
tracing::trace!(?address, "adding possible target address");
let connect_future = TcpStream::connect(SocketAddr::new(address, port));
connectors.push(connect_future);
}

261
src/network/resolver.rs Normal file
View File

@@ -0,0 +1,261 @@
use crate::config::resolver::{DnsConfig, NameServerConfig};
use error_stack::report;
use futures::stream::{SelectAll, StreamExt};
use hickory_client::rr::Name;
use hickory_proto::error::ProtoError;
use hickory_proto::op::message::Message;
use hickory_proto::op::query::Query;
use hickory_proto::rr::record_data::RData;
use hickory_proto::rr::resource::Record;
use hickory_proto::rr::RecordType;
use hickory_proto::udp::UdpClientStream;
use hickory_proto::xfer::dns_request::DnsRequest;
use hickory_proto::xfer::{DnsRequestSender, DnsResponseStream};
use std::collections::{HashMap, HashSet};
use std::net::{IpAddr, SocketAddr};
use thiserror::Error;
use tokio::net::UdpSocket;
use tokio::time::{Duration, Instant};
#[derive(Debug, Error)]
pub enum ResolverConfigError {
#[error("Bad local domain name '{name}' provided: '{error}'")]
BadDomainName { name: String, error: ProtoError },
#[error("Bad domain name for search '{name}' provided: '{error}'")]
BadSearchName { name: String, error: ProtoError },
#[error("Couldn't set up client for server at '{address}': '{error}'")]
FailedToCreateDnsClient {
address: SocketAddr,
error: ProtoError,
},
#[error("No DNS servers found to search, and mDNS not enabled")]
NoHosts,
}
#[derive(Debug, Error)]
pub enum ResolveError {
#[error("No servers available for query")]
NoServersAvailable,
#[error("No responses found for query")]
NoResponses,
#[error("Error reading response: {error}")]
ResponseError { error: ProtoError },
}
pub struct Resolver {
search_domains: Vec<Name>,
client_connections: Vec<(NameServerConfig, UdpClientStream<UdpSocket>)>,
cache: HashMap<Name, Vec<(Instant, IpAddr)>>,
}
impl Resolver {
/// Create a new DNS resolution engine for use by some part of the system.
pub async fn new(config: &DnsConfig) -> error_stack::Result<Self, ResolverConfigError> {
let mut search_domains = Vec::new();
if let Some(local) = config.local_domain.as_ref() {
search_domains.push(Name::from_utf8(local).map_err(|e| {
report!(ResolverConfigError::BadDomainName {
name: local.clone(),
error: e,
})
})?);
}
for local_domain in config.search_domains.iter() {
search_domains.push(Name::from_utf8(local_domain).map_err(|e| {
report!(ResolverConfigError::BadSearchName {
name: local_domain.clone(),
error: e,
})
})?);
}
let mut client_connections = Vec::new();
for name_server_config in config.name_servers() {
let stream = UdpClientStream::with_bind_addr_and_timeout(
name_server_config.address,
name_server_config.bind_address.clone(),
Duration::from_secs(name_server_config.timeout_in_seconds.unwrap_or(3)),
)
.await
.map_err(|error| ResolverConfigError::FailedToCreateDnsClient {
address: name_server_config.address,
error,
})?;
client_connections.push((name_server_config, stream));
}
Ok(Resolver {
search_domains,
client_connections,
cache: HashMap::new(),
})
}
/// Look up the address of the given name, returning either a set of results
/// we've received in a reasonable amount of time, or an error.
pub async fn lookup(&mut self, name: &Name) -> Result<HashSet<IpAddr>, ResolveError> {
let names = self.expand_name(name);
let mut response_stream = self.create_response_stream(names).await;
if response_stream.is_empty() {
return Err(ResolveError::NoServersAvailable);
}
let mut first_error = None;
while let Some(response) = response_stream.next().await {
match response {
Err(e) => {
if first_error.is_none() {
first_error = Some(e);
}
}
Ok(response) => {
for answer in response.into_message().take_answers() {
self.handle_response(name, answer).await;
}
}
}
}
match first_error {
None => Err(ResolveError::NoResponses),
Some(error) => Err(ResolveError::ResponseError { error }),
}
}
async fn create_response_stream(&mut self, names: Vec<Name>) -> SelectAll<DnsResponseStream> {
let connections = std::mem::take(&mut self.client_connections);
let mut response_stream = futures::stream::SelectAll::new();
for (config, mut client) in connections.into_iter() {
if client.is_shutdown() {
let stream = UdpClientStream::with_bind_addr_and_timeout(
config.address,
config.bind_address.clone(),
Duration::from_secs(config.timeout_in_seconds.unwrap_or(3)),
)
.await;
match stream {
Ok(stream) => client = stream,
Err(_) => continue,
}
}
let mut message = Message::new();
for name in names.iter() {
message.add_query(Query::query(name.clone(), RecordType::A));
message.add_query(Query::query(name.clone(), RecordType::AAAA));
}
message.set_recursion_desired(true);
let request = DnsRequest::new(message, Default::default());
response_stream.push(client.send_message(request));
self.client_connections.push((config, client));
}
response_stream
}
/// Expand the given input name into the complete set of names that we should search for.
fn expand_name(&self, name: &Name) -> Vec<Name> {
let mut names = vec![name.clone()];
for search_domain in self.search_domains.iter() {
if let Ok(combined) = name.clone().append_name(search_domain) {
names.push(combined);
}
}
names
}
/// Handle an individual response from a server.
///
/// Returns true if we got an answer from the server, and updated the internal cache.
/// This answer is used externally to set a timer, so that we don't wait the full DNS
/// timeout period for answers to come back.
async fn handle_response(&mut self, query: &Name, record: Record) -> bool {
let ttl = record.ttl();
let Some(rdata) = record.into_data() else {
tracing::error!("for some reason, couldn't process incoming message");
return false;
};
let response: IpAddr = match rdata {
RData::A(arec) => arec.0.into(),
RData::AAAA(arec) => arec.0.into(),
_ => {
tracing::warn!(record_type = %rdata.record_type(), "skipping unknown / unexpected record type");
return false;
}
};
let expire_at = Instant::now() + Duration::from_secs(ttl as u64);
match self.cache.entry(query.clone()) {
std::collections::hash_map::Entry::Vacant(vec) => {
vec.insert(vec![(expire_at, response.clone())]);
}
std::collections::hash_map::Entry::Occupied(mut occ) => {
clean_expired_entries(occ.get_mut());
occ.get_mut().push((expire_at, response.clone()));
}
}
true
}
}
fn clean_expired_entries(list: &mut Vec<(Instant, IpAddr)>) {
let now = Instant::now();
list.retain(|(expire_at, _)| expire_at > &now);
}
#[tokio::test]
async fn uhsure() {
let mut resolver = Resolver::new(&DnsConfig::default()).await.unwrap();
let name = Name::from_ascii("uhsure.com").unwrap();
let result = resolver.lookup(&name).await.unwrap();
println!("result = {:?}", result);
assert!(!result.is_empty());
}
#[tokio::test]
async fn manual_lookup() {
let mut stream: UdpClientStream<UdpSocket> = UdpClientStream::with_bind_addr_and_timeout(
SocketAddr::V4(std::net::SocketAddrV4::new(
std::net::Ipv4Addr::new(172, 19, 207, 1),
53,
)),
None,
Duration::from_secs(3),
)
.await
.expect("can create UDP DNS client stream");
let mut message = Message::new();
let name = Name::from_ascii("uhsure.com").unwrap();
message.add_query(Query::query(name.clone(), RecordType::A));
message.add_query(Query::query(name.clone(), RecordType::AAAA));
message.set_recursion_desired(true);
let request = DnsRequest::new(message, Default::default());
let mut responses = stream.send_message(request);
while let Some(response) = responses.next().await {
println!("response: {:?}", response);
}
unimplemented!();
}

View File

@@ -5,6 +5,8 @@ use thiserror::Error;
pub enum OperationalError {
#[error("Configuration error")]
ConfigurationError,
#[error("DNS client configuration error")]
DnsConfig,
#[error("Failed to connect to target address")]
Connection,
#[error("Failure during key exchange / agreement protocol")]

28
src/server.rs Normal file
View File

@@ -0,0 +1,28 @@
mod socket;
mod state;
use crate::config::server::ServerConfiguration;
use error_stack::ResultExt;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum TopLevelError {
#[error("Configuration error")]
ConfigurationError,
#[error("Failure running UNIX socket handling task")]
SocketHandlerFailure,
}
pub async fn run(mut config: ServerConfiguration) -> error_stack::Result<(), TopLevelError> {
let mut server_state = state::ServerState::default();
let listeners = config.generate_listener_sockets().await
.change_context(TopLevelError::ConfigurationError)?;
for (name, listener) in listeners.into_iter() {
}
Ok(())
}

75
src/server/socket.rs Normal file
View File

@@ -0,0 +1,75 @@
use crate::server::TopLevelError;
use error_stack::ResultExt;
use tokio::net::{UnixListener, UnixStream};
use tracing::Instrument;
pub struct SocketServer {
name: String,
path: String,
num_sessions_run: u64,
listener: UnixListener,
}
impl SocketServer {
/// Create a new server that will handle inputs from the client program.
///
/// This function will just generate the function required, without starting the
/// underlying task. To start the task, use [`SocketServer::start`], although that
/// method will take ownership of the object.
pub fn new(name: String, listener: UnixListener) -> Self {
let path = listener
.local_addr()
.map(|x| x.as_pathname().map(|p| format!("{}", p.display())))
.unwrap_or_else(|_| None)
.unwrap_or_else(|| format!("unknown"));
tracing::trace!(%name, %path, "Creating new socket listener");
SocketServer {
name,
path,
num_sessions_run: 0,
listener,
}
}
/// Start running the service, returning a handle that will pass on an error if
/// one occurs in the core of this task.
///
/// Typically, errors shouldn't happen in the core task, as all it does is listen
/// for new connections and then spawn other tasks based on them. If errors occur
/// there, the core task should be unaffected.
pub async fn start(mut self) -> error_stack::Result<(), TopLevelError> {
loop {
let (stream, addr) = self.listener.accept().await.change_context(TopLevelError::SocketHandlerFailure)?;
let remote_addr = addr
.as_pathname()
.map(|x| x.display())
.map(|x| format!("{}", x))
.unwrap_or("<unknown>".to_string());
let span = tracing::debug_span!(
"unix socket handler",
socket_name = %self.name,
socket_path = %self.path,
session_no = %self.num_sessions_run,
%remote_addr,
);
self.num_sessions_run += 1;
tokio::task::spawn(Self::run_session(stream)
.instrument(span))
.await
.change_context(TopLevelError::SocketHandlerFailure)?;
}
}
/// Run a session.
///
/// This is here because it's convenient, not because it shares state (obviously,
/// given the type signature). But it's somewhat logically associated with this type,
/// so it seems reasonable to make it an associated function.
async fn run_session(handle: UnixStream) {
}
}

8
src/server/state.rs Normal file
View File

@@ -0,0 +1,8 @@
use crate::server::TopLevelError;
use tokio::task::JoinSet;
#[derive(Default)]
pub struct ServerState {
/// The set of top level tasks that are currently running.
top_level_tasks: JoinSet<Result<(), TopLevelError>>,
}

View File

@@ -5,5 +5,5 @@ mod preamble;
pub use channel::SshChannel;
pub use message_ids::SshMessageID;
pub use packets::SshKeyExchangeProcessingError;
pub use packets::{SshKeyExchange, SshKeyExchangeProcessingError};
pub use preamble::Preamble;

View File

@@ -1,3 +1,3 @@
mod key_exchange;
pub use key_exchange::SshKeyExchangeProcessingError;
pub use key_exchange::{SshKeyExchange, SshKeyExchangeProcessingError};

View File

@@ -1,4 +1,4 @@
use crate::config::ClientConnectionOpts;
use crate::config::connection::ClientConnectionOpts;
use crate::ssh::channel::SshPacket;
use crate::ssh::message_ids::SshMessageID;
use bytes::{Buf, BufMut, Bytes, BytesMut};