well, tests pass now

This commit is contained in:
2024-10-14 17:33:24 -07:00
parent 795c528754
commit a813b65535
81 changed files with 15233 additions and 6 deletions

304
src/config/resolver.rs Normal file
View File

@@ -0,0 +1,304 @@
use error_stack::report;
use hickory_proto::error::ProtoError;
use hickory_resolver::config::{NameServerConfig, Protocol, ResolverConfig, ResolverOpts};
use hickory_resolver::{Name, TokioAsyncResolver};
use proptest::arbitrary::Arbitrary;
use proptest::strategy::{BoxedStrategy, Just, Strategy};
use serde::{Deserialize, Serialize};
use std::net::SocketAddr;
use thiserror::Error;
#[derive(Debug, PartialEq, Deserialize, Serialize)]
pub struct DnsConfig {
built_in: Option<BuiltinDnsOption>,
local_domain: Option<String>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
search_domains: Vec<String>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
name_servers: Vec<ServerConfig>,
#[serde(default)]
timeout_in_seconds: Option<u16>,
#[serde(default)]
retry_attempts: Option<u16>,
#[serde(default)]
cache_size: Option<u32>,
#[serde(default)]
use_hosts_file: Option<bool>,
#[serde(default)]
max_concurrent_requests_for_query: Option<u16>,
#[serde(default)]
preserve_intermediates: Option<bool>,
#[serde(default)]
shuffle_dns_servers: Option<bool>,
}
impl Default for DnsConfig {
fn default() -> Self {
DnsConfig {
built_in: Some(BuiltinDnsOption::Cloudflare),
local_domain: None,
search_domains: vec![],
name_servers: vec![],
timeout_in_seconds: None,
retry_attempts: None,
cache_size: None,
use_hosts_file: None,
max_concurrent_requests_for_query: None,
preserve_intermediates: None,
shuffle_dns_servers: None,
}
}
}
impl Arbitrary for DnsConfig {
type Parameters = bool;
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(always_use_builtin: Self::Parameters) -> Self::Strategy {
if always_use_builtin {
BuiltinDnsOption::arbitrary()
.prop_map(|x| DnsConfig {
built_in: Some(x),
local_domain: None,
search_domains: vec![],
name_servers: vec![],
timeout_in_seconds: None,
retry_attempts: None,
cache_size: None,
use_hosts_file: None,
max_concurrent_requests_for_query: None,
preserve_intermediates: None,
shuffle_dns_servers: None,
})
.boxed()
} else {
let built_in = proptest::option::of(BuiltinDnsOption::arbitrary());
built_in
.prop_flat_map(|built_in| {
let local_domain = proptest::option::of(domain_name_strat());
let search_domains = proptest::collection::vec(domain_name_strat(), 0..10);
let min_servers = if built_in.is_some() { 0 } else { 1 };
let name_servers =
proptest::collection::vec(ServerConfig::arbitrary(), min_servers..6);
let timeout_in_seconds = proptest::option::of(u16::arbitrary());
let retry_attempts = proptest::option::of(u16::arbitrary());
let cache_size = proptest::option::of(u32::arbitrary());
let use_hosts_file = proptest::option::of(bool::arbitrary());
let max_concurrent_requests_for_query = proptest::option::of(u16::arbitrary());
let preserve_intermediates = proptest::option::of(bool::arbitrary());
let shuffle_dns_servers = proptest::option::of(bool::arbitrary());
(
local_domain,
search_domains,
name_servers,
timeout_in_seconds,
retry_attempts,
cache_size,
use_hosts_file,
max_concurrent_requests_for_query,
preserve_intermediates,
shuffle_dns_servers,
)
.prop_map(
move |(
local_domain,
search_domains,
name_servers,
timeout_in_seconds,
retry_attempts,
cache_size,
use_hosts_file,
max_concurrent_requests_for_query,
preserve_intermediates,
shuffle_dns_servers,
)| DnsConfig {
built_in,
local_domain,
search_domains,
name_servers,
timeout_in_seconds,
retry_attempts,
cache_size,
use_hosts_file,
max_concurrent_requests_for_query,
preserve_intermediates,
shuffle_dns_servers,
},
)
})
.boxed()
}
}
}
fn domain_name_strat() -> BoxedStrategy<String> {
let chunk = proptest::string::string_regex("[a-zA-Z0-9]{2,32}").unwrap();
let sets = proptest::collection::vec(chunk, 2..6);
sets.prop_map(|set| {
let mut output = String::new();
for x in set.into_iter() {
if !output.is_empty() {
output.push('.');
}
output.push_str(&x);
}
output
})
.boxed()
}
#[derive(Clone, Copy, Debug, PartialEq, Deserialize, Serialize)]
enum BuiltinDnsOption {
Google,
Cloudflare,
Quad9,
}
impl Arbitrary for BuiltinDnsOption {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with((): Self::Parameters) -> Self::Strategy {
proptest::prop_oneof![
Just(BuiltinDnsOption::Google),
Just(BuiltinDnsOption::Cloudflare),
Just(BuiltinDnsOption::Quad9),
]
.boxed()
}
}
#[derive(Debug, PartialEq, Deserialize, Serialize)]
struct ServerConfig {
address: SocketAddr,
#[serde(default)]
trust_negatives: bool,
#[serde(default, skip_serializing_if = "Option::is_none")]
bind_address: Option<SocketAddr>,
}
impl Arbitrary for ServerConfig {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with((): Self::Parameters) -> Self::Strategy {
(
SocketAddr::arbitrary(),
bool::arbitrary(),
proptest::option::of(SocketAddr::arbitrary()),
)
.prop_map(|(mut address, trust_negatives, mut bind_address)| {
clear_flow_and_scope_info(&mut address);
if let Some(bind_address) = bind_address.as_mut() {
clear_flow_and_scope_info(bind_address);
}
ServerConfig {
address,
trust_negatives,
bind_address,
}
})
.boxed()
}
}
fn clear_flow_and_scope_info(address: &mut SocketAddr) {
if let SocketAddr::V6(addr) = address {
addr.set_flowinfo(0);
addr.set_scope_id(0);
}
}
#[derive(Debug, Error)]
pub enum ResolverConfigError {
#[error("Bad local domain name '{name}' provided: {error}")]
BadDomainName { name: String, error: ProtoError },
#[error("Bad domain name for search '{name}' provided: {error}")]
BadSearchName { name: String, error: ProtoError },
#[error("No DNS hosts found to search")]
NoHosts,
}
impl DnsConfig {
/// Convert this resolver configuration into an actual ResolverConfig, or say
/// why it's bad.
pub fn resolver(&self) -> error_stack::Result<TokioAsyncResolver, ResolverConfigError> {
let mut config = match &self.built_in {
None => ResolverConfig::new(),
Some(BuiltinDnsOption::Cloudflare) => ResolverConfig::cloudflare(),
Some(BuiltinDnsOption::Google) => ResolverConfig::google(),
Some(BuiltinDnsOption::Quad9) => ResolverConfig::quad9(),
};
if let Some(local_domain) = &self.local_domain {
let name = Name::from_utf8(local_domain).map_err(|error| {
report!(ResolverConfigError::BadDomainName {
name: local_domain.clone(),
error,
})
})?;
config.set_domain(name);
}
for name in self.search_domains.iter() {
let name = Name::from_utf8(name).map_err(|error| {
report!(ResolverConfigError::BadSearchName {
name: name.clone(),
error,
})
})?;
config.add_search(name);
}
for ns in self.name_servers.iter() {
let mut nsconfig = NameServerConfig::new(ns.address, Protocol::Udp);
nsconfig.trust_negative_responses = ns.trust_negatives;
nsconfig.bind_addr = ns.bind_address;
config.add_name_server(nsconfig);
}
if config.name_servers().is_empty() {
return Err(report!(ResolverConfigError::NoHosts));
}
let mut options = ResolverOpts::default();
if let Some(seconds) = self.timeout_in_seconds {
options.timeout = tokio::time::Duration::from_secs(seconds as u64);
}
if let Some(retries) = self.retry_attempts {
options.attempts = retries as usize;
}
if let Some(cache_size) = self.cache_size {
options.cache_size = cache_size as usize;
}
options.use_hosts_file = self.use_hosts_file.unwrap_or(true);
if let Some(max) = self.max_concurrent_requests_for_query {
options.num_concurrent_reqs = max as usize;
}
if let Some(preserve) = self.preserve_intermediates {
options.preserve_intermediates = preserve;
}
if let Some(shuffle) = self.shuffle_dns_servers {
options.shuffle_dns_servers = shuffle;
}
Ok(TokioAsyncResolver::tokio(config, options))
}
}
proptest::proptest! {
#[test]
fn valid_configs_parse(config in DnsConfig::arbitrary_with(false)) {
let toml = toml::to_string(&config).unwrap();
let reversed: DnsConfig = toml::from_str(&toml).unwrap();
assert_eq!(config, reversed);
}
}