well, tests pass now
This commit is contained in:
304
src/config/resolver.rs
Normal file
304
src/config/resolver.rs
Normal file
@@ -0,0 +1,304 @@
|
||||
use error_stack::report;
|
||||
use hickory_proto::error::ProtoError;
|
||||
use hickory_resolver::config::{NameServerConfig, Protocol, ResolverConfig, ResolverOpts};
|
||||
use hickory_resolver::{Name, TokioAsyncResolver};
|
||||
use proptest::arbitrary::Arbitrary;
|
||||
use proptest::strategy::{BoxedStrategy, Just, Strategy};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::net::SocketAddr;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, PartialEq, Deserialize, Serialize)]
|
||||
pub struct DnsConfig {
|
||||
built_in: Option<BuiltinDnsOption>,
|
||||
local_domain: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
search_domains: Vec<String>,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
name_servers: Vec<ServerConfig>,
|
||||
#[serde(default)]
|
||||
timeout_in_seconds: Option<u16>,
|
||||
#[serde(default)]
|
||||
retry_attempts: Option<u16>,
|
||||
#[serde(default)]
|
||||
cache_size: Option<u32>,
|
||||
#[serde(default)]
|
||||
use_hosts_file: Option<bool>,
|
||||
#[serde(default)]
|
||||
max_concurrent_requests_for_query: Option<u16>,
|
||||
#[serde(default)]
|
||||
preserve_intermediates: Option<bool>,
|
||||
#[serde(default)]
|
||||
shuffle_dns_servers: Option<bool>,
|
||||
}
|
||||
|
||||
impl Default for DnsConfig {
|
||||
fn default() -> Self {
|
||||
DnsConfig {
|
||||
built_in: Some(BuiltinDnsOption::Cloudflare),
|
||||
local_domain: None,
|
||||
search_domains: vec![],
|
||||
name_servers: vec![],
|
||||
timeout_in_seconds: None,
|
||||
retry_attempts: None,
|
||||
cache_size: None,
|
||||
use_hosts_file: None,
|
||||
max_concurrent_requests_for_query: None,
|
||||
preserve_intermediates: None,
|
||||
shuffle_dns_servers: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Arbitrary for DnsConfig {
|
||||
type Parameters = bool;
|
||||
type Strategy = BoxedStrategy<Self>;
|
||||
|
||||
fn arbitrary_with(always_use_builtin: Self::Parameters) -> Self::Strategy {
|
||||
if always_use_builtin {
|
||||
BuiltinDnsOption::arbitrary()
|
||||
.prop_map(|x| DnsConfig {
|
||||
built_in: Some(x),
|
||||
local_domain: None,
|
||||
search_domains: vec![],
|
||||
name_servers: vec![],
|
||||
timeout_in_seconds: None,
|
||||
retry_attempts: None,
|
||||
cache_size: None,
|
||||
use_hosts_file: None,
|
||||
max_concurrent_requests_for_query: None,
|
||||
preserve_intermediates: None,
|
||||
shuffle_dns_servers: None,
|
||||
})
|
||||
.boxed()
|
||||
} else {
|
||||
let built_in = proptest::option::of(BuiltinDnsOption::arbitrary());
|
||||
built_in
|
||||
.prop_flat_map(|built_in| {
|
||||
let local_domain = proptest::option::of(domain_name_strat());
|
||||
let search_domains = proptest::collection::vec(domain_name_strat(), 0..10);
|
||||
let min_servers = if built_in.is_some() { 0 } else { 1 };
|
||||
let name_servers =
|
||||
proptest::collection::vec(ServerConfig::arbitrary(), min_servers..6);
|
||||
let timeout_in_seconds = proptest::option::of(u16::arbitrary());
|
||||
let retry_attempts = proptest::option::of(u16::arbitrary());
|
||||
let cache_size = proptest::option::of(u32::arbitrary());
|
||||
let use_hosts_file = proptest::option::of(bool::arbitrary());
|
||||
let max_concurrent_requests_for_query = proptest::option::of(u16::arbitrary());
|
||||
let preserve_intermediates = proptest::option::of(bool::arbitrary());
|
||||
let shuffle_dns_servers = proptest::option::of(bool::arbitrary());
|
||||
|
||||
(
|
||||
local_domain,
|
||||
search_domains,
|
||||
name_servers,
|
||||
timeout_in_seconds,
|
||||
retry_attempts,
|
||||
cache_size,
|
||||
use_hosts_file,
|
||||
max_concurrent_requests_for_query,
|
||||
preserve_intermediates,
|
||||
shuffle_dns_servers,
|
||||
)
|
||||
.prop_map(
|
||||
move |(
|
||||
local_domain,
|
||||
search_domains,
|
||||
name_servers,
|
||||
timeout_in_seconds,
|
||||
retry_attempts,
|
||||
cache_size,
|
||||
use_hosts_file,
|
||||
max_concurrent_requests_for_query,
|
||||
preserve_intermediates,
|
||||
shuffle_dns_servers,
|
||||
)| DnsConfig {
|
||||
built_in,
|
||||
local_domain,
|
||||
search_domains,
|
||||
name_servers,
|
||||
timeout_in_seconds,
|
||||
retry_attempts,
|
||||
cache_size,
|
||||
use_hosts_file,
|
||||
max_concurrent_requests_for_query,
|
||||
preserve_intermediates,
|
||||
shuffle_dns_servers,
|
||||
},
|
||||
)
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn domain_name_strat() -> BoxedStrategy<String> {
|
||||
let chunk = proptest::string::string_regex("[a-zA-Z0-9]{2,32}").unwrap();
|
||||
let sets = proptest::collection::vec(chunk, 2..6);
|
||||
sets.prop_map(|set| {
|
||||
let mut output = String::new();
|
||||
|
||||
for x in set.into_iter() {
|
||||
if !output.is_empty() {
|
||||
output.push('.');
|
||||
}
|
||||
|
||||
output.push_str(&x);
|
||||
}
|
||||
|
||||
output
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Deserialize, Serialize)]
|
||||
enum BuiltinDnsOption {
|
||||
Google,
|
||||
Cloudflare,
|
||||
Quad9,
|
||||
}
|
||||
|
||||
impl Arbitrary for BuiltinDnsOption {
|
||||
type Parameters = ();
|
||||
type Strategy = BoxedStrategy<Self>;
|
||||
|
||||
fn arbitrary_with((): Self::Parameters) -> Self::Strategy {
|
||||
proptest::prop_oneof![
|
||||
Just(BuiltinDnsOption::Google),
|
||||
Just(BuiltinDnsOption::Cloudflare),
|
||||
Just(BuiltinDnsOption::Quad9),
|
||||
]
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Deserialize, Serialize)]
|
||||
struct ServerConfig {
|
||||
address: SocketAddr,
|
||||
#[serde(default)]
|
||||
trust_negatives: bool,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
bind_address: Option<SocketAddr>,
|
||||
}
|
||||
|
||||
impl Arbitrary for ServerConfig {
|
||||
type Parameters = ();
|
||||
type Strategy = BoxedStrategy<Self>;
|
||||
|
||||
fn arbitrary_with((): Self::Parameters) -> Self::Strategy {
|
||||
(
|
||||
SocketAddr::arbitrary(),
|
||||
bool::arbitrary(),
|
||||
proptest::option::of(SocketAddr::arbitrary()),
|
||||
)
|
||||
.prop_map(|(mut address, trust_negatives, mut bind_address)| {
|
||||
clear_flow_and_scope_info(&mut address);
|
||||
if let Some(bind_address) = bind_address.as_mut() {
|
||||
clear_flow_and_scope_info(bind_address);
|
||||
}
|
||||
|
||||
ServerConfig {
|
||||
address,
|
||||
trust_negatives,
|
||||
bind_address,
|
||||
}
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
fn clear_flow_and_scope_info(address: &mut SocketAddr) {
|
||||
if let SocketAddr::V6(addr) = address {
|
||||
addr.set_flowinfo(0);
|
||||
addr.set_scope_id(0);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ResolverConfigError {
|
||||
#[error("Bad local domain name '{name}' provided: {error}")]
|
||||
BadDomainName { name: String, error: ProtoError },
|
||||
#[error("Bad domain name for search '{name}' provided: {error}")]
|
||||
BadSearchName { name: String, error: ProtoError },
|
||||
#[error("No DNS hosts found to search")]
|
||||
NoHosts,
|
||||
}
|
||||
|
||||
impl DnsConfig {
|
||||
/// Convert this resolver configuration into an actual ResolverConfig, or say
|
||||
/// why it's bad.
|
||||
pub fn resolver(&self) -> error_stack::Result<TokioAsyncResolver, ResolverConfigError> {
|
||||
let mut config = match &self.built_in {
|
||||
None => ResolverConfig::new(),
|
||||
Some(BuiltinDnsOption::Cloudflare) => ResolverConfig::cloudflare(),
|
||||
Some(BuiltinDnsOption::Google) => ResolverConfig::google(),
|
||||
Some(BuiltinDnsOption::Quad9) => ResolverConfig::quad9(),
|
||||
};
|
||||
|
||||
if let Some(local_domain) = &self.local_domain {
|
||||
let name = Name::from_utf8(local_domain).map_err(|error| {
|
||||
report!(ResolverConfigError::BadDomainName {
|
||||
name: local_domain.clone(),
|
||||
error,
|
||||
})
|
||||
})?;
|
||||
config.set_domain(name);
|
||||
}
|
||||
|
||||
for name in self.search_domains.iter() {
|
||||
let name = Name::from_utf8(name).map_err(|error| {
|
||||
report!(ResolverConfigError::BadSearchName {
|
||||
name: name.clone(),
|
||||
error,
|
||||
})
|
||||
})?;
|
||||
config.add_search(name);
|
||||
}
|
||||
|
||||
for ns in self.name_servers.iter() {
|
||||
let mut nsconfig = NameServerConfig::new(ns.address, Protocol::Udp);
|
||||
|
||||
nsconfig.trust_negative_responses = ns.trust_negatives;
|
||||
nsconfig.bind_addr = ns.bind_address;
|
||||
|
||||
config.add_name_server(nsconfig);
|
||||
}
|
||||
|
||||
if config.name_servers().is_empty() {
|
||||
return Err(report!(ResolverConfigError::NoHosts));
|
||||
}
|
||||
|
||||
let mut options = ResolverOpts::default();
|
||||
|
||||
if let Some(seconds) = self.timeout_in_seconds {
|
||||
options.timeout = tokio::time::Duration::from_secs(seconds as u64);
|
||||
}
|
||||
if let Some(retries) = self.retry_attempts {
|
||||
options.attempts = retries as usize;
|
||||
}
|
||||
if let Some(cache_size) = self.cache_size {
|
||||
options.cache_size = cache_size as usize;
|
||||
}
|
||||
options.use_hosts_file = self.use_hosts_file.unwrap_or(true);
|
||||
if let Some(max) = self.max_concurrent_requests_for_query {
|
||||
options.num_concurrent_reqs = max as usize;
|
||||
}
|
||||
if let Some(preserve) = self.preserve_intermediates {
|
||||
options.preserve_intermediates = preserve;
|
||||
}
|
||||
if let Some(shuffle) = self.shuffle_dns_servers {
|
||||
options.shuffle_dns_servers = shuffle;
|
||||
}
|
||||
|
||||
Ok(TokioAsyncResolver::tokio(config, options))
|
||||
}
|
||||
}
|
||||
|
||||
proptest::proptest! {
|
||||
#[test]
|
||||
fn valid_configs_parse(config in DnsConfig::arbitrary_with(false)) {
|
||||
let toml = toml::to_string(&config).unwrap();
|
||||
let reversed: DnsConfig = toml::from_str(&toml).unwrap();
|
||||
assert_eq!(config, reversed);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user