Checkpoint with resolver tests including request/response.

This commit is contained in:
2025-05-03 13:50:16 -07:00
parent 31cd34d280
commit 9fe5b78962
20 changed files with 4012 additions and 1093 deletions

1076
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -15,42 +15,41 @@ unexpected_cfgs = { level = "warn", check-cfg = ['cfg(tarpaulin_include)'] }
aes = { version = "0.8.4", features = ["zeroize"] } aes = { version = "0.8.4", features = ["zeroize"] }
base64 = "0.22.1" base64 = "0.22.1"
bcrypt-pbkdf = "0.10.0" bcrypt-pbkdf = "0.10.0"
bytes = "1.6.0" bytes = "1.10.1"
cipher = { version = "0.4.4", features = ["alloc", "block-padding", "rand_core", "std", "zeroize"] } clap = { version = "4.5.35", features = ["derive"] }
clap = { version = "4.5.7", features = ["derive"] } console-subscriber = "0.4.1"
console-subscriber = "0.3.0"
ctr = "0.9.2" ctr = "0.9.2"
ed25519-dalek = "2.1.1" ed25519-dalek = "2.1.1"
elliptic-curve = { version = "0.13.8", features = ["alloc", "digest", "ecdh", "pem", "pkcs8", "sec1", "serde", "std", "hash2curve", "voprf"] } elliptic-curve = { version = "0.13.8", features = ["alloc", "digest", "ecdh", "pem", "pkcs8", "sec1", "serde", "std", "hash2curve", "voprf"] }
error-stack = "0.5.0" error-stack = "0.5.0"
futures = "0.3.31" futures = "0.3.31"
generic-array = "0.14.7" generic-array = "0.14.7"
hexdump = "0.1.2" getrandom = "0.3.2"
hickory-client = { version = "0.24.1", features = ["mdns"] } internment = { version = "0.8.6", features = ["arc"] }
hickory-proto = "0.24.1" itertools = "0.14.0"
itertools = "0.13.0" nix = { version = "0.29.0", features = ["user"] }
moka = { version = "0.12.10", features = ["future"] }
nix = { version = "0.28.0", features = ["user"] }
num-bigint-dig = { version = "0.8.4", features = ["arbitrary", "i128", "zeroize", "prime", "rand"] } num-bigint-dig = { version = "0.8.4", features = ["arbitrary", "i128", "zeroize", "prime", "rand"] }
num-integer = { version = "0.1.46", features = ["i128"] } num-integer = { version = "0.1.46", features = ["i128"] }
num-traits = { version = "0.2.19", features = ["i128"] } num-traits = { version = "0.2.19", features = ["i128"] }
num_enum = "0.7.2" num_enum = "0.7.3"
p256 = { version = "0.13.2", features = ["ecdh", "ecdsa-core", "hash2curve", "serde", "test-vectors"] } p256 = { version = "0.13.2", features = ["ecdh", "ecdsa-core", "hash2curve", "serde", "test-vectors"] }
p384 = { version = "0.13.0", features = ["ecdh", "ecdsa-core", "hash2curve", "serde", "test-vectors"] } p384 = { version = "0.13.1", features = ["ecdh", "ecdsa-core", "hash2curve", "serde", "test-vectors"] }
p521 = { version = "0.13.3", features = ["ecdh", "ecdsa-core", "hash2curve", "serde", "test-vectors"] } p521 = { version = "0.13.3", features = ["ecdh", "ecdsa-core", "hash2curve", "serde", "test-vectors"] }
proptest = "1.5.0" proptest = "1.6.0"
rand = "0.8.5" proptest-derive = "0.5.1"
rand_chacha = "0.3.1" rand = "0.9.0"
rustix = "0.38.41" rand_chacha = "0.9.0"
rustix = "1.0.5"
sec1 = "0.7.3" sec1 = "0.7.3"
serde = { version = "1.0.203", features = ["derive"] } serde = { version = "1.0.219", features = ["derive"] }
tempfile = "3.12.0" tempfile = "3.19.1"
thiserror = "2.0.3" thiserror = "2.0.12"
tokio = { version = "1.38.0", features = ["full", "tracing"] } tokio = { version = "1.44.2", features = ["full", "tracing"] }
toml = "0.8.14" toml = "0.8.20"
tracing = "0.1.40" tracing = "0.1.41"
tracing-core = "0.1.32" tracing-core = "0.1.33"
tracing-subscriber = { version = "0.3.18", features = ["env-filter", "tracing", "json"] } tracing-subscriber = { version = "0.3.19", features = ["env-filter", "tracing", "json"] }
whoami = { version = "1.5.2", default-features = false } url = "2.5.4"
whoami = { version = "1.6.0", default-features = false }
xdg = "2.5.2" xdg = "2.5.2"
zeroize = "1.8.1" zeroize = "1.8.1"

View File

@@ -159,7 +159,7 @@ async fn connect(
"received server preamble" "received server preamble"
); );
let stream = ssh::SshChannel::new(stream); let stream = ssh::SshChannel::new(stream).expect("can build new SSH channel");
let their_initial = stream let their_initial = stream
.read() .read()
.await .await

View File

@@ -1,7 +1,9 @@
use crate::network::resolver::name::Name;
use proptest::arbitrary::Arbitrary; use proptest::arbitrary::Arbitrary;
use proptest::strategy::{BoxedStrategy, Just, Strategy}; use proptest::strategy::{BoxedStrategy, Just, Strategy};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; use std::net::SocketAddr;
use url::Url;
#[derive(Debug, PartialEq, Deserialize, Serialize)] #[derive(Debug, PartialEq, Deserialize, Serialize)]
pub struct DnsConfig { pub struct DnsConfig {
@@ -179,25 +181,80 @@ impl Arbitrary for BuiltinDnsOption {
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)] #[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct NameServerConfig { pub struct NameServerConfig {
pub address: SocketAddr, #[serde(serialize_with = "serialize_url", deserialize_with = "deserialize_url")]
pub address: Url,
#[serde(default)] #[serde(default)]
pub timeout_in_seconds: Option<u64>, pub timeout_in_seconds: Option<u64>,
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub bind_address: Option<SocketAddr>, pub bind_address: Option<SocketAddr>,
} }
fn serialize_url<S: serde::Serializer>(url: &Url, serializer: S) -> Result<S::Ok, S::Error> {
serializer.collect_str(url)
}
fn deserialize_url<'de, D: serde::Deserializer<'de>>(deserializer: D) -> Result<Url, D::Error> {
struct Visitor {}
impl<'de> serde::de::Visitor<'de> for Visitor {
type Value = Url;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(formatter, "a legal URL")
}
fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<Self::Value, E> {
Url::parse(v).map_err(|e| E::custom(e))
}
fn visit_borrowed_str<E: serde::de::Error>(self, v: &'de str) -> Result<Self::Value, E> {
Url::parse(v).map_err(|e| E::custom(e))
}
}
deserializer.deserialize_str(Visitor {})
}
impl Arbitrary for NameServerConfig { impl Arbitrary for NameServerConfig {
type Parameters = (); type Parameters = ();
type Strategy = BoxedStrategy<Self>; type Strategy = BoxedStrategy<Self>;
fn arbitrary_with((): Self::Parameters) -> Self::Strategy { fn arbitrary_with((): Self::Parameters) -> Self::Strategy {
let scheme = proptest::prop_oneof![
Just("http".to_string()),
Just("https".to_string()),
Just("tcp".to_string()),
Just("udp".to_string()),
];
let domain = Name::arbitrary();
let port = proptest::option::of(u16::arbitrary());
let user = proptest::option::of(proptest::string::string_regex("[A-Za-z]{1,30}").unwrap());
let password =
proptest::option::of(proptest::string::string_regex(":[A-Za-z]{1,30}").unwrap());
let path =
proptest::option::of(proptest::string::string_regex("/[A-Za-z/]{1,30}").unwrap());
let uri_strategy = (scheme, domain, user, password, port, path).prop_map(
|(scheme, domain, user, password, port, path)| {
let userpass_prefix = match (user, password) {
(None, None) => String::new(),
(Some(u), None) => format!("{u}@"),
(None, Some(p)) => format!(":{p}@"),
(Some(u), Some(p)) => format!("{u}:{p}@"),
};
let path = path.unwrap_or_default();
let port = port.map(|x| format!(":{x}")).unwrap_or_default();
let uri_str = format!("{scheme}://{userpass_prefix}{domain}{port}{path}");
Url::parse(&uri_str).unwrap()
},
);
( (
SocketAddr::arbitrary(), uri_strategy,
proptest::option::of(u64::arbitrary()), proptest::option::of(u64::arbitrary()),
proptest::option::of(SocketAddr::arbitrary()), proptest::option::of(SocketAddr::arbitrary()),
) )
.prop_map(|(mut address, mut timeout_in_seconds, mut bind_address)| { .prop_map(|(address, mut timeout_in_seconds, mut bind_address)| {
clear_flow_and_scope_info(&mut address);
if let Some(bind_address) = bind_address.as_mut() { if let Some(bind_address) = bind_address.as_mut() {
clear_flow_and_scope_info(bind_address); clear_flow_and_scope_info(bind_address);
} }
@@ -242,81 +299,56 @@ impl DnsConfig {
None => {} None => {}
Some(BuiltinDnsOption::Cloudflare) => { Some(BuiltinDnsOption::Cloudflare) => {
results.push(NameServerConfig { results.push(NameServerConfig {
address: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(1, 1, 1, 1), 53)), address: Url::parse("udp://1.1.1.1:53").unwrap(),
timeout_in_seconds: None, timeout_in_seconds: None,
bind_address: None, bind_address: None,
}); });
results.push(NameServerConfig { results.push(NameServerConfig {
address: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(1, 0, 0, 1), 53)), address: Url::parse("udp://1.0.0.1").unwrap(),
timeout_in_seconds: None, timeout_in_seconds: None,
bind_address: None, bind_address: None,
}); });
results.push(NameServerConfig { results.push(NameServerConfig {
address: SocketAddr::V6(SocketAddrV6::new( address: Url::parse("udp://2606:4700:4700::1111").unwrap(),
Ipv6Addr::new(0x2606, 0x4700, 0x4700, 0, 0, 0, 0, 0x1111),
53,
0,
0,
)),
timeout_in_seconds: None, timeout_in_seconds: None,
bind_address: None, bind_address: None,
}); });
results.push(NameServerConfig { results.push(NameServerConfig {
address: SocketAddr::V6(SocketAddrV6::new( address: Url::parse("udp://2606:4700:4700::1001:53").unwrap(),
Ipv6Addr::new(0x2606, 0x4700, 0x4700, 0, 0, 0, 0, 0x1001),
53,
0,
0,
)),
timeout_in_seconds: None, timeout_in_seconds: None,
bind_address: None, bind_address: None,
}); });
} }
Some(BuiltinDnsOption::Google) => { Some(BuiltinDnsOption::Google) => {
results.push(NameServerConfig { results.push(NameServerConfig {
address: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(8, 8, 8, 8), 53)), address: Url::parse("udp://8.8.8.8:53").unwrap(),
timeout_in_seconds: None, timeout_in_seconds: None,
bind_address: None, bind_address: None,
}); });
results.push(NameServerConfig { results.push(NameServerConfig {
address: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(8, 8, 4, 4), 53)), address: Url::parse("udp://8.8.4.4").unwrap(),
timeout_in_seconds: None, timeout_in_seconds: None,
bind_address: None, bind_address: None,
}); });
results.push(NameServerConfig { results.push(NameServerConfig {
address: SocketAddr::V6(SocketAddrV6::new( address: Url::parse("udp://2001:4860:4860::8888:53").unwrap(),
Ipv6Addr::new(0x2001, 0x4860, 0x4860, 0, 0, 0, 0, 0x8888),
53,
0,
0,
)),
timeout_in_seconds: None, timeout_in_seconds: None,
bind_address: None, bind_address: None,
}); });
results.push(NameServerConfig { results.push(NameServerConfig {
address: SocketAddr::V6(SocketAddrV6::new( address: Url::parse("udp://2001:4860:4860::8844").unwrap(),
Ipv6Addr::new(0x2001, 0x4860, 0x4860, 0, 0, 0, 0, 0x8844),
53,
0,
0,
)),
timeout_in_seconds: None, timeout_in_seconds: None,
bind_address: None, bind_address: None,
}); });
} }
Some(BuiltinDnsOption::Quad9) => { Some(BuiltinDnsOption::Quad9) => {
results.push(NameServerConfig { results.push(NameServerConfig {
address: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(9, 9, 9, 9), 53)), address: Url::parse("udp://9.9.9.9:53").unwrap(),
timeout_in_seconds: None, timeout_in_seconds: None,
bind_address: None, bind_address: None,
}); });
results.push(NameServerConfig { results.push(NameServerConfig {
address: SocketAddr::V6(SocketAddrV6::new( address: Url::parse("udp://2620::00fe:00f3").unwrap(),
Ipv6Addr::new(0x2620, 0, 0, 0, 0, 0, 0xfe, 0xfe),
53,
0,
0,
)),
timeout_in_seconds: None, timeout_in_seconds: None,
bind_address: None, bind_address: None,
}); });

View File

@@ -1,7 +1,7 @@
use crate::network::resolver::name::Name;
use crate::network::resolver::{ResolveError, Resolver}; use crate::network::resolver::{ResolveError, Resolver};
use error_stack::{report, ResultExt}; use error_stack::{report, ResultExt};
use futures::stream::{FuturesUnordered, StreamExt}; use futures::stream::{FuturesUnordered, StreamExt};
use hickory_client::rr::Name;
use std::collections::HashSet; use std::collections::HashSet;
use std::fmt; use std::fmt;
use std::net::{AddrParseError, IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::net::{AddrParseError, IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
@@ -63,7 +63,7 @@ impl FromStr for Host {
} }
} }
if let Ok(name) = Name::from_utf8(s) { if let Ok(name) = Name::from_str(s) {
return Ok(Host::Hostname(name)); return Ok(Host::Hostname(name));
} }
@@ -76,11 +76,8 @@ impl FromStr for Host {
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum ConnectionError { pub enum ConnectionError {
#[error("Failed to resolve host: {error}")] #[error("Connection error: failed to resolve host")]
ResolveError { ResolveError,
#[from]
error: ResolveError,
},
#[error("No valid IP addresses found")] #[error("No valid IP addresses found")]
NoAddresses, NoAddresses,
#[error("Error connecting to host: {error}")] #[error("Error connecting to host: {error}")]
@@ -98,7 +95,10 @@ impl Host {
/// are no relevant records for us to use for IPv4 or IPv6 connections. There /// are no relevant records for us to use for IPv4 or IPv6 connections. There
/// is also no guarantee that the host will have both IPv4 and IPv6 addresses, /// is also no guarantee that the host will have both IPv4 and IPv6 addresses,
/// so you may only see one or the other. /// so you may only see one or the other.
pub async fn resolve(&self, resolver: &mut Resolver) -> Result<HashSet<IpAddr>, ResolveError> { pub async fn resolve(
&self,
resolver: &mut Resolver,
) -> error_stack::Result<HashSet<IpAddr>, ResolveError> {
match self { match self {
Host::IPv4(addr) => Ok(HashSet::from([IpAddr::V4(*addr)])), Host::IPv4(addr) => Ok(HashSet::from([IpAddr::V4(*addr)])),
Host::IPv6(addr) => Ok(HashSet::from([IpAddr::V6(*addr)])), Host::IPv6(addr) => Ok(HashSet::from([IpAddr::V6(*addr)])),
@@ -121,7 +121,7 @@ impl Host {
let addresses = self let addresses = self
.resolve(resolver) .resolve(resolver)
.await .await
.map_err(ConnectionError::from) .change_context(ConnectionError::ResolveError)
.attach_printable_lazy(|| format!("target address {}", self))?; .attach_printable_lazy(|| format!("target address {}", self))?;
let mut connectors = FuturesUnordered::new(); let mut connectors = FuturesUnordered::new();

View File

@@ -1,39 +1,27 @@
pub mod name;
mod protocol;
mod resolution_table;
use crate::config::resolver::{DnsConfig, NameServerConfig}; use crate::config::resolver::{DnsConfig, NameServerConfig};
use error_stack::report; use crate::network::resolver::name::Name;
use futures::future::select; use crate::network::resolver::resolution_table::ResolutionTable;
use futures::stream::{SelectAll, StreamExt}; use error_stack::{report, ResultExt};
use hickory_client::rr::Name;
use hickory_proto::error::ProtoError;
use hickory_proto::op::message::Message;
use hickory_proto::op::query::Query;
use hickory_proto::rr::record_data::RData;
use hickory_proto::rr::resource::Record;
use hickory_proto::rr::RecordType;
use hickory_proto::udp::UdpClientStream;
use hickory_proto::xfer::dns_request::DnsRequest;
use hickory_proto::xfer::{DnsRequestSender, DnsResponseStream};
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
#[cfg(test)]
use std::str::FromStr; use std::str::FromStr;
use std::sync::Arc; use std::sync::Arc;
use thiserror::Error; use thiserror::Error;
use tokio::net::UdpSocket; use tokio::net::{TcpSocket, UdpSocket};
use tokio::sync::{oneshot, Mutex}; use tokio::sync::Mutex;
use tokio::task::JoinSet; use tokio::task::JoinSet;
use tokio::time::{timeout_at, Duration, Instant}; use tokio::time::{Duration, Instant};
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum ResolverConfigError { pub enum ResolverConfigError {
#[error("Bad local domain name '{name}' provided: '{error}'")] #[error("Bad local domain name provided")]
BadDomainName { name: String, error: ProtoError }, BadDomainName,
#[error("Bad domain name for search '{name}' provided: '{error}'")] #[error("Couldn't create a DNS client for the given address, port, and protocol.")]
BadSearchName { name: String, error: ProtoError }, FailedToCreateDnsClient,
#[error("Couldn't set up client for server at '{address}': '{error}'")]
FailedToCreateDnsClient {
address: SocketAddr,
error: ProtoError,
},
#[error("No DNS servers found to search, and mDNS not enabled")] #[error("No DNS servers found to search, and mDNS not enabled")]
NoHosts, NoHosts,
} }
@@ -44,8 +32,8 @@ pub enum ResolveError {
NoServersAvailable, NoServersAvailable,
#[error("No responses found for query")] #[error("No responses found for query")]
NoResponses, NoResponses,
#[error("Error reading response: {error}")] #[error("Error reading response from server")]
ResponseError { error: ProtoError }, ResponseError,
} }
pub struct Resolver { pub struct Resolver {
@@ -53,11 +41,13 @@ pub struct Resolver {
max_time_to_wait_for_initial: Duration, max_time_to_wait_for_initial: Duration,
time_to_wait_after_first: Duration, time_to_wait_after_first: Duration,
time_to_wait_for_lingering: Duration, time_to_wait_for_lingering: Duration,
state: Arc<Mutex<ResolverState>>, connections: Arc<Mutex<Vec<(NameServerConfig, protocol::Client)>>>,
table: Arc<Mutex<ResolutionTable>>,
tasks: JoinSet<()>,
} }
pub struct ResolverState { pub struct ResolverState {
client_connections: Vec<(NameServerConfig, UdpClientStream<UdpSocket>)>, client_connections: Vec<(NameServerConfig, protocol::Client)>,
cache: HashMap<Name, Vec<DnsResolution>>, cache: HashMap<Name, Vec<DnsResolution>>,
} }
@@ -70,40 +60,151 @@ impl Resolver {
/// Create a new DNS resolution engine for use by some part of the system. /// Create a new DNS resolution engine for use by some part of the system.
pub async fn new(config: &DnsConfig) -> error_stack::Result<Self, ResolverConfigError> { pub async fn new(config: &DnsConfig) -> error_stack::Result<Self, ResolverConfigError> {
let mut search_domains = Vec::new(); let mut search_domains = Vec::new();
let mut tasks = JoinSet::new();
if let Some(local) = config.local_domain.as_ref() { if let Some(local) = config.local_domain.as_ref() {
search_domains.push(Name::from_utf8(local).map_err(|e| { let name = Name::from_str(local)
report!(ResolverConfigError::BadDomainName { .change_context(ResolverConfigError::BadDomainName)
name: local.clone(), .attach_printable("Trying to add local domain.")
error: e, .attach_printable_lazy(|| "Offending non-name: '{local}'")?;
})
})?); search_domains.push(name);
} }
for local_domain in config.search_domains.iter() { for search_domain in config.search_domains.iter() {
search_domains.push(Name::from_utf8(local_domain).map_err(|e| { let name = Name::from_str(search_domain)
report!(ResolverConfigError::BadSearchName { .change_context(ResolverConfigError::BadDomainName)
name: local_domain.clone(), .attach_printable("Trying to add search domain.")
error: e, .attach_printable_lazy(|| "Offending non-name: '{search_domain}'")?;
})
})?); search_domains.push(name);
} }
let mut client_connections = Vec::new(); let mut client_connections = Vec::new();
for name_server_config in config.name_servers() { for target in config.name_servers.iter() {
let stream = UdpClientStream::with_bind_addr_and_timeout( let port = target.address.port().unwrap_or(53);
name_server_config.address, let Some(address) = target.address.host() else {
name_server_config.bind_address, return Err(report!(ResolverConfigError::FailedToCreateDnsClient))
Duration::from_secs(name_server_config.timeout_in_seconds.unwrap_or(3)), .attach_printable("No address to connect to?")
) .attach_printable_lazy(|| format!("Target address was {}", target.address));
.await };
.map_err(|error| ResolverConfigError::FailedToCreateDnsClient {
address: name_server_config.address, let address = match address {
error, url::Host::Ipv4(addr) => IpAddr::V4(addr),
url::Host::Ipv6(addr) => IpAddr::V6(addr),
url::Host::Domain(name) => {
return Err(report!(ResolverConfigError::FailedToCreateDnsClient))
.attach_printable("Cannot use domain names to identify domain servers")
.attach_printable_lazy(|| format!("Target address was {name}"));
}
};
let target_addr = SocketAddr::new(address, port);
match target.address.scheme() {
"tcp" => {
let socket = if target_addr.is_ipv4() {
TcpSocket::new_v4()
} else {
TcpSocket::new_v6()
};
let socket = socket
.change_context(ResolverConfigError::FailedToCreateDnsClient)
.attach_printable("Could not create a socket")
.attach_printable_lazy(|| {
format!("For target DNS server {}", target.address)
})?; })?;
client_connections.push((name_server_config, stream)); if let Some(bind_address) = target.bind_address {
socket
.bind(bind_address)
.change_context(ResolverConfigError::FailedToCreateDnsClient)
.attach_printable("Could not bind local address for socket.")
.attach_printable_lazy(|| {
format!("Binding to TCP address {}", bind_address)
})
.attach_printable_lazy(|| {
format!("For target DNS server {}", target.address)
})?;
}
let stream = socket
.connect(target_addr)
.await
.change_context(ResolverConfigError::FailedToCreateDnsClient)
.attach_printable_lazy(|| {
format!("Connecting to target {}", target_addr)
})?;
let client = protocol::Client::from_tcp(stream, &mut tasks)
.await
.change_context(ResolverConfigError::FailedToCreateDnsClient)
.attach_printable_lazy(|| {
format!("Connecting to target {}", target_addr)
})?;
client_connections.push((target.clone(), client));
}
"udp" => {
let port = target.address.port().unwrap_or(53);
let Some(address) = target.address.host() else {
tracing::warn!(address = %target.address, "proposed domain server has no host");
continue;
};
let address = match address {
url::Host::Ipv4(addr) => IpAddr::V4(addr),
url::Host::Ipv6(addr) => IpAddr::V6(addr),
url::Host::Domain(name) => {
tracing::warn!(
address = %target.address,
hostname = name,
"currently, we can't use hostnames for domain servers"
);
continue;
}
};
let sock_addr = SocketAddr::new(address, port);
let bind_address = target.bind_address.unwrap_or_else(|| {
if sock_addr.is_ipv4() {
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0)
} else {
SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0)
}
});
let udp_socket = UdpSocket::bind(bind_address)
.await
.change_context(ResolverConfigError::FailedToCreateDnsClient)
.attach_printable("Generating UDP socket")
.attach_printable_lazy(|| format!("binding to address {bind_address}"))
.attach_printable_lazy(|| format!("targeting {}", target.address))?;
udp_socket
.connect(target_addr)
.await
.change_context(ResolverConfigError::FailedToCreateDnsClient)
.attach_printable("Connecting UDP socket")
.attach_printable_lazy(|| format!("binding to address {bind_address}"))
.attach_printable_lazy(|| format!("targeting {}", target.address))?;
let client = protocol::Client::from_udp(udp_socket, &mut tasks).await;
client_connections.push((target.clone(), client));
}
"unix" => unimplemented!(),
"unixd" => unimplemented!(),
"http" => unimplemented!(),
"https" => unimplemented!(),
_ => {
tracing::warn!(address = %target.address, "Unknown scheme for building DNS connections");
continue;
}
}
} }
Ok(Resolver { Ok(Resolver {
@@ -112,321 +213,35 @@ impl Resolver {
max_time_to_wait_for_initial: Duration::from_millis(150), max_time_to_wait_for_initial: Duration::from_millis(150),
time_to_wait_after_first: Duration::from_millis(50), time_to_wait_after_first: Duration::from_millis(50),
time_to_wait_for_lingering: Duration::from_secs(2), time_to_wait_for_lingering: Duration::from_secs(2),
state: Arc::new(Mutex::new(ResolverState { connections: Arc::new(Mutex::new(client_connections)),
client_connections, table: Arc::new(Mutex::new(ResolutionTable::new())),
cache: HashMap::new(), tasks,
})),
}) })
} }
/// Look up the address of the given name, returning either a set of results pub async fn lookup(&self, _name: &Name) -> error_stack::Result<HashSet<IpAddr>, ResolveError> {
/// we've received in a reasonable amount of time, or an error. unimplemented!()
pub async fn lookup(&self, name: &Name) -> Result<HashSet<IpAddr>, ResolveError> {
let names = self.expand_name(name);
let cached_values = self.cached_lookup(&names).await;
if !cached_values.is_empty() {
return Ok(cached_values);
}
let mut response_stream = self.create_response_stream(&names).await;
if response_stream.is_empty() {
return Err(ResolveError::NoServersAvailable);
}
let get_first_by = Instant::now() + self.max_time_to_wait_for_initial;
let got_responses =
get_responses_by(&self.state, get_first_by, name, &mut response_stream).await?;
if got_responses {
let get_rest_by = Instant::now() + self.time_to_wait_after_first;
let _ = get_responses_by(&self.state, get_rest_by, name, &mut response_stream).await;
}
let lingering_deadline = Instant::now() + self.time_to_wait_for_lingering;
let state_copy = self.state.clone();
let name_copy = name.clone();
tokio::task::spawn(async move {
let _ = get_responses_by(
&state_copy,
lingering_deadline,
&name_copy,
&mut response_stream,
)
.await;
});
let after_search_lookup = self.cached_lookup(&names).await;
if after_search_lookup.is_empty() {
Err(ResolveError::NoResponses)
} else {
Ok(after_search_lookup)
}
}
/// Look up any cached IP address we might have for the given names.
///
/// As a side effect, this function will clean out any expired entries in the
/// cache.
async fn cached_lookup(&self, names: &[Name]) -> HashSet<IpAddr> {
let mut retval = HashSet::new();
let mut state = self.state.lock().await;
for name in names {
if let Some(mut existing) = state.cache.remove(name) {
let now = Instant::now();
existing.retain(|item| {
let keeper = item.expires_at > now;
if keeper {
retval.insert(item.address);
}
keeper
});
if !existing.is_empty() {
state.cache.insert(name.clone(), existing);
}
}
}
retval
}
/// Reach out to all of our current connections, and start the process of getting
/// responses.
///
/// If the clients are shut down, we'll try to create new connections to them, or
/// remove them from our connection list if we can't. If there are no connections,
/// the `SelectAll` will be empty, and callers are advised to check for this
/// condition and inform the user that they're out of useful DNS servers.
async fn create_response_stream(&self, names: &[Name]) -> SelectAll<DnsResponseStream> {
let mut state = self.state.lock().await;
let connections = std::mem::take(&mut state.client_connections);
let mut response_stream = futures::stream::SelectAll::new();
for (config, mut client) in connections.into_iter() {
if client.is_shutdown() {
let stream = UdpClientStream::with_bind_addr_and_timeout(
config.address,
config.bind_address,
Duration::from_secs(config.timeout_in_seconds.unwrap_or(3)),
)
.await;
match stream {
Ok(stream) => client = stream,
Err(_) => continue,
}
}
let mut message = Message::new();
for name in names.iter() {
message.add_query(Query::query(name.clone(), RecordType::A));
message.add_query(Query::query(name.clone(), RecordType::AAAA));
}
message.set_recursion_desired(true);
let request = DnsRequest::new(message, Default::default());
response_stream.push(client.send_message(request));
state.client_connections.push((config, client));
}
response_stream
}
/// Expand the given input name into the complete set of names that we should search for.
fn expand_name(&self, name: &Name) -> Vec<Name> {
let mut names = vec![name.clone()];
for search_domain in self.search_domains.iter() {
if let Ok(combined) = name.clone().append_name(search_domain) {
names.push(combined);
}
}
names
}
/// Run a cleaner task on this Resolver object.
///
/// The task will run in the given JoinSet, set it can be easily killed, tracked, waited
/// for as desired. Every `period` interval, it will grab the state lock and clean out
/// any entries that are timed out. Ideally, you shouldn't set this too frequently, so
/// that this process doesn't interfere with other tasks trying to use the resolver.
///
/// This task will only weakly hang on to the resolver state, so that if the Resolver
/// object is dropped it will quietly stop running. It will also stop running if it sees
/// a message on the kill signal, if one is provided, or if one is provided and the other
/// end drops the sender. (So if you don't want a kill signal, send `None`, or you might
/// accidentally kill your GC task too early.)
pub async fn run_resolver_gc(
&self,
task_set: &mut JoinSet<Result<(), ()>>,
mut kill_signal: Option<oneshot::Receiver<()>>,
interval: Duration,
) {
let weak_locked_state = Arc::downgrade(&self.state);
task_set.spawn(async move {
loop {
tokio::time::sleep(interval).await;
let Some(locked_state) = weak_locked_state.upgrade() else {
break Ok(());
};
let mut state = if let Some(kill_signal) = kill_signal.as_mut() {
let locker = locked_state.lock();
futures::pin_mut!(locker);
match select(locker, kill_signal).await {
futures::future::Either::Left((state, _)) => state,
futures::future::Either::Right((_, _)) => break Ok(()),
}
} else {
locked_state.lock().await
};
let now = Instant::now();
state.cache.retain(|_, value| {
value.retain(|resolution| resolution.expires_at < now);
!value.is_empty()
});
drop(state);
}
});
}
/// Inject a mapping into the resolver, with the given TTL.
///
/// This is largely used for testing purposes, but could be used if there are static
/// addresses you want to always resolve in a particular way. For that use case, you
/// probably want to add absurdly-long TTLs for those names.
pub async fn inject_resolution(&self, name: Name, address: IpAddr, ttl: Duration) {
let resolution = DnsResolution {
address,
expires_at: Instant::now() + ttl,
};
let mut state = self.state.lock().await;
match state.cache.entry(name) {
std::collections::hash_map::Entry::Vacant(vac) => {
vac.insert(vec![resolution]);
}
std::collections::hash_map::Entry::Occupied(mut occ) => {
occ.get_mut().push(resolution);
}
}
drop(state);
} }
} }
/// Get all the responses into cache that occur before the given timestamp. //#[tokio::test]
/// //async fn fetch_cached() {
/// Returns an error if all we get are errors while trying to read from these servers, // let resolver = Resolver::new(&DnsConfig::empty()).await.unwrap();
/// otherwise returns Ok(true) if we received some interesting responses, or Ok(false) // let name = Name::from_utf8("name.foo").unwrap();
/// if we didn't. Note that receiving good responses wins out over receiving errors, so // let addr = IpAddr::from_str("1.2.4.5").unwrap();
/// if we receive 2 good responses and 3 errors, this function will return Ok(true). //
async fn get_responses_by( // resolver
state: &Arc<Mutex<ResolverState>>, // .inject_resolution(name.clone(), addr.clone(), Duration::from_secs(100000))
deadline: Instant, // .await;
name: &Name, // let read = resolver.lookup(&name).await.unwrap();
stream: &mut SelectAll<DnsResponseStream>, // assert!(read.contains(&addr));
) -> Result<bool, ResolveError> { //}
let mut received_error = None;
let mut got_something = false;
loop { //#[tokio::test]
match timeout_at(deadline, stream.next()).await { //async fn uhsure() {
Err(_) => return received_error.unwrap_or(Ok(got_something)), // let resolver = Resolver::new(&DnsConfig::default()).await.unwrap();
// let name = Name::from_ascii("uhsure.com").unwrap();
Ok(None) if got_something => return Ok(got_something), // let result = resolver.lookup(&name).await.unwrap();
Ok(None) => return received_error.unwrap_or(Ok(false)), // println!("result = {:?}", result);
// assert!(!result.is_empty());
Ok(Some(Err(error))) => { //}
if received_error.is_none() {
received_error = Some(Err(ResolveError::ResponseError { error }));
}
}
Ok(Some(Ok(response))) => {
for answer in response.into_message().take_answers() {
got_something |= handle_response(state, name, answer).await;
}
}
}
}
}
/// Handle an individual response from a server.
///
/// Returns true if we got an answer from the server, and updates the internal cache.
/// This function will never return an error, although there may be some odd processing
/// situations that could lead it to printing warnings to the logs.
async fn handle_response(state: &Arc<Mutex<ResolverState>>, query: &Name, record: Record) -> bool {
let ttl = record.ttl();
let Some(rdata) = record.into_data() else {
tracing::error!("for some reason, couldn't process incoming message");
return false;
};
let address: IpAddr = match rdata {
RData::A(arec) => arec.0.into(),
RData::AAAA(arec) => arec.0.into(),
_ => {
tracing::warn!(record_type = %rdata.record_type(), "skipping unknown / unexpected record type");
return false;
}
};
let now = Instant::now();
let expires_at = now + Duration::from_secs(ttl as u64);
let resolution = DnsResolution {
address,
expires_at,
};
let mut state = state.lock().await;
match state.cache.entry(query.clone()) {
std::collections::hash_map::Entry::Vacant(vec) => {
vec.insert(vec![resolution]);
}
std::collections::hash_map::Entry::Occupied(mut occ) => {
let vec = occ.get_mut();
vec.retain(|x| x.expires_at > now);
vec.push(resolution);
}
}
drop(state);
true
}
#[tokio::test]
async fn fetch_cached() {
let resolver = Resolver::new(&DnsConfig::empty()).await.unwrap();
let name = Name::from_utf8("name.foo").unwrap();
let addr = IpAddr::from_str("1.2.4.5").unwrap();
resolver
.inject_resolution(name.clone(), addr.clone(), Duration::from_secs(100000))
.await;
let read = resolver.lookup(&name).await.unwrap();
assert!(read.contains(&addr));
}
#[tokio::test]
async fn uhsure() {
let resolver = Resolver::new(&DnsConfig::default()).await.unwrap();
let name = Name::from_ascii("uhsure.com").unwrap();
let result = resolver.lookup(&name).await.unwrap();
println!("result = {:?}", result);
assert!(!result.is_empty());
}

View File

@@ -0,0 +1,505 @@
use bytes::{Buf, BufMut};
use error_stack::{report, ResultExt};
use internment::ArcIntern;
use proptest::arbitrary::Arbitrary;
use proptest::char::CharStrategy;
use proptest::strategy::{BoxedStrategy, Strategy};
use std::borrow::Cow;
use std::fmt;
use std::hash::Hash;
use std::ops::{Range, RangeInclusive};
use std::str::FromStr;
use thiserror::Error;
#[derive(Clone, Hash, PartialEq, Eq)]
pub struct Name {
labels: Vec<Label>,
}
impl fmt::Debug for Name {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut first_section = true;
for &Label(ref label) in self.labels.iter() {
if first_section {
first_section = false;
write!(f, "{}", label.as_str())?;
} else {
write!(f, ".{}", label.as_str())?;
}
}
Ok(())
}
}
impl fmt::Display for Name {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
<Self as fmt::Debug>::fmt(self, f)
}
}
#[derive(Clone, Hash, Eq)]
pub struct Label(ArcIntern<String>);
impl fmt::Debug for Label {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.as_str().fmt(f)
}
}
impl fmt::Display for Label {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.as_str().fmt(f)
}
}
impl PartialEq for Label {
fn eq(&self, other: &Self) -> bool {
self.0.eq_ignore_ascii_case(other.0.as_str())
}
}
#[derive(Debug, Error)]
pub enum NameParseError {
#[error("Provided name '{name}' is too long ({observed_length} bytes); maximum length of a DNS name is 255 octets.")]
NameTooLong {
name: String,
observed_length: usize,
},
#[error("Provided name '{name}' contains illegal character '{illegal_character}'.")]
NonAsciiCharacter {
name: String,
illegal_character: char,
},
#[error("Provided name '{name}' contains an empty section/label, which isn't allowed.")]
EmptyLabel { name: String },
#[error("Provided name '{name}' contains a label ('{label}') that is too long ({observed_length} letters, which is more than 63).")]
LabelTooLong {
observed_length: usize,
name: String,
label: String,
},
#[error("Provided name '{name}' contains a label ('{label}') that begins with an illegal character; it must be a letter.")]
LabelStartsWrong { name: String, label: String },
#[error("Provided name '{name}' contains a label ('{label}') that ends with an illegal character; it must be a letter or number.")]
LabelEndsWrong { name: String, label: String },
#[error("Provided name '{name}' contains a label ('{label}') that contains a non-letter, non-number, and non-dash.")]
IllegalInternalCharacter { name: String, label: String },
}
impl FromStr for Name {
type Err = NameParseError;
fn from_str(value: &str) -> Result<Self, Self::Err> {
let observed_length = value.as_bytes().len();
if observed_length > 255 {
return Err(NameParseError::NameTooLong {
name: value.to_string(),
observed_length,
});
}
for char in value.chars() {
if !(char.is_ascii_alphanumeric() || char == '.' || char == '-') {
return Err(NameParseError::NonAsciiCharacter {
name: value.to_string(),
illegal_character: char,
});
}
}
let mut labels = vec![];
for label_str in value.split('.') {
if label_str.is_empty() {
return Err(NameParseError::EmptyLabel {
name: value.to_string(),
});
}
if label_str.len() > 63 {
return Err(NameParseError::LabelTooLong {
name: value.to_string(),
label: label_str.to_string(),
observed_length: label_str.len(),
});
}
let letter = |x| ('a'..='z').contains(&x) || ('A'..='Z').contains(&x);
let letter_or_num = |x| letter(x) || ('0'..='9').contains(&x);
let letter_num_dash = |x| letter_or_num(x) || (x == '-');
if !label_str.starts_with(letter) {
return Err(NameParseError::LabelStartsWrong {
name: value.to_string(),
label: label_str.to_string(),
});
}
if !label_str.ends_with(letter_or_num) {
return Err(NameParseError::LabelEndsWrong {
name: value.to_string(),
label: label_str.to_string(),
});
}
if label_str.contains(|x| !letter_num_dash(x)) {
return Err(NameParseError::IllegalInternalCharacter {
name: value.to_string(),
label: label_str.to_string(),
});
}
// RFC 1035 says that all domain names are case-insensitive. We
// arbitrarily normalize to lowercase here, because it shouldn't
// matter to anyone.
labels.push(Label(ArcIntern::new(label_str.to_string())));
}
Ok(Name { labels })
}
}
#[derive(Debug, Error)]
pub enum NameReadError {
#[error("Could not read name field out of an empty buffer.")]
EmptyBuffer,
#[error("Buffer truncated before we could read the last label")]
TruncatedBuffer,
#[error("Provided label value too long; must be 63 octets or less.")]
LabelTooLong,
#[error("Label truncated while reading. Broken stream?")]
LabelTruncated,
#[error("Label starts with an illegal character (must be [A-Za-z])")]
WrongFirstByte,
#[error("Label ends with an illegal character (must be [A-Za-z0-9]")]
WrongLastByte,
#[error("Label contains an illegal character (must be [A-Za-z0-9] or a dash)")]
WrongInnerByte,
}
impl Name {
/// Read a name frm a record, or return an error describing what went wrong.
///
/// This function will advance the read pointer on the record. If the result is
/// an error, the read pointer is not guaranteed to be in a good place, so you
/// may need to manage that externally if you want to implement some sort of
/// "try" functionality.
pub fn read<B: Buf>(buffer: &mut B) -> error_stack::Result<Name, NameReadError> {
let mut labels = Vec::new();
let name_so_far = |labels: Vec<Label>| {
let mut result = String::new();
for Label(label) in labels.into_iter() {
if result.is_empty() {
result.push_str(label.as_str());
} else {
result.push('.');
result.push_str(label.as_str());
}
}
result
};
loop {
if !buffer.has_remaining() && labels.is_empty() {
return Err(report!(NameReadError::EmptyBuffer));
}
if !buffer.has_remaining() {
return Err(report!(NameReadError::TruncatedBuffer))
.attach_printable_lazy(|| format!("name thus far: '{}'", name_so_far(labels)));
}
let label_octet_length = buffer.get_u8() as usize;
if label_octet_length == 0 {
break;
}
if label_octet_length > 63 {
return Err(report!(NameReadError::LabelTooLong)).attach_printable_lazy(|| {
format!(
"label length too big; max is supposed to be 63, saw {label_octet_length}"
)
});
}
if !buffer.remaining() < label_octet_length {
let remaining = buffer.copy_to_bytes(buffer.remaining());
let partial = String::from_utf8_lossy(&remaining).to_string();
return Err(report!(NameReadError::LabelTruncated))
.attach_printable_lazy(|| format!("name thus far: '{}'", name_so_far(labels)))
.attach_printable_lazy(|| {
format!(
"expected {} octets, but only found {}",
label_octet_length,
buffer.remaining()
)
})
.attach_printable_lazy(|| format!("partial read: '{partial}'"));
}
let label_bytes = buffer.copy_to_bytes(label_octet_length);
let Some(first_byte) = label_bytes.first() else {
panic!(
"INTERNAL ERROR: Should have at least one byte, we checked this previously."
);
};
let Some(last_byte) = label_bytes.last() else {
panic!(
"INTERNAL ERROR: Should have at least one byte, we checked this previously."
);
};
let letter = |x| (b'a'..=b'z').contains(&x) || (b'A'..=b'Z').contains(&x);
let letter_or_num = |x| letter(x) || (b'0'..=b'9').contains(&x);
let letter_num_dash = |x| letter_or_num(x) || (x == b'-');
if !letter(*first_byte) {
return Err(report!(NameReadError::WrongFirstByte))
.attach_printable_lazy(|| format!("name thus far: '{}'", name_so_far(labels)))
.attach_printable_lazy(|| {
format!("bad label: '{}'", String::from_utf8_lossy(&label_bytes))
});
}
if !letter_or_num(*last_byte) {
return Err(report!(NameReadError::WrongLastByte))
.attach_printable_lazy(|| format!("name thus far: '{}'", name_so_far(labels)))
.attach_printable_lazy(|| {
format!("bad label: '{}'", String::from_utf8_lossy(&label_bytes))
});
}
if label_bytes.iter().any(|x| !letter_num_dash(*x)) {
return Err(report!(NameReadError::WrongInnerByte))
.attach_printable_lazy(|| format!("name thus far: '{}'", name_so_far(labels)))
.attach_printable_lazy(|| {
format!("bad label: '{}'", String::from_utf8_lossy(&label_bytes))
});
}
let label = label_bytes.into_iter().map(|x| x as char).collect();
labels.push(Label(ArcIntern::new(label)));
}
Ok(Name { labels })
}
/// Write a name out to the given buffer.
///
/// This will try as hard as it can to write the value to the given buffer. If an
/// error occurs, you may end up with partially-written data, so if you're worried
/// about that you should be careful to mark where you started in the output buffer.
pub fn write<B: BufMut>(&self, buffer: &mut B) -> error_stack::Result<(), NameWriteError> {
for &Label(ref label) in self.labels.iter() {
let bytes = label.as_bytes();
if buffer.remaining_mut() < (bytes.len() + 1) {
return Err(report!(NameWriteError::NoRoomForLabel))
.attach_printable_lazy(|| format!("Writing name {self}"))
.attach_printable_lazy(|| format!("For label {label}"));
}
if bytes.is_empty() || bytes.len() > 63 {
return Err(report!(NameWriteError::IllegalLabel))
.attach_printable_lazy(|| format!("Writing name {self}"))
.attach_printable_lazy(|| format!("For label {label}"));
}
buffer.put_u8(bytes.len() as u8);
buffer.put_slice(bytes);
}
if !buffer.has_remaining_mut() {
return Err(report!(NameWriteError::NoRoomForNull))
.attach_printable_lazy(|| format!("Writing name {self}"));
}
buffer.put_u8(0);
Ok(())
}
}
#[derive(Debug, Error)]
pub enum NameWriteError {
#[error("Not enough room to write label in name")]
NoRoomForLabel,
#[error("Ran out of room writing the terminating NULL for the name")]
NoRoomForNull,
#[error("Internal error: Illegal label (this shouldn't happen)")]
IllegalLabel,
}
#[derive(Debug)]
pub struct ArbitraryDomainNameSpecifications {
total_length_range: Range<usize>,
label_length_range: Range<usize>,
number_labels: Range<usize>,
}
impl Default for ArbitraryDomainNameSpecifications {
fn default() -> Self {
ArbitraryDomainNameSpecifications {
total_length_range: 5..256,
label_length_range: 1..64,
number_labels: 2..5,
}
}
}
impl Arbitrary for Name {
type Parameters = ArbitraryDomainNameSpecifications;
type Strategy = BoxedStrategy<Name>;
fn arbitrary_with(spec: Self::Parameters) -> Self::Strategy {
(spec.number_labels.clone(), spec.total_length_range.clone())
.prop_flat_map(move |(mut labels, mut total_length)| {
// we need to make sure that our total length and our label count are,
// at the very minimum, compatible. If they're not, we'll need to adjust
// them.
//
// in general, we prefer to update our number of labels, if we can, but
// only to the extent that the labels count is at least the minimum of
// our input specification. if we try to go below that, we increase total
// length as required. If we're forced into a place where we need to take
// the label length below it's minimum and/or the total_length over its
// maximum, we just give up and panic.
//
// Note that the minimum length of n labels is (n * label_minimum) + n - 1.
// Consider, for example, a label_minimum of 1 and a label length of 3.
// A minimum string is "a.b.c", which is (3 * 1) + 3 - 1 = 5 characters
// long.
//
// This loop does the first part, lowering the number of labels until we
// either get below the total length or reach the minimum number of labels.
while (labels * spec.label_length_range.start) + (labels - 1) > total_length {
if labels == spec.number_labels.start {
break;
} else {
labels -= 1;
}
}
// At this point, if it's not right, we just set it to be right.
if (labels * spec.total_length_range.start) + (labels - 1) > total_length {
total_length = (labels * spec.total_length_range.start) + (labels - 1);
}
// And if this takes us over our limit, just panic.
if total_length >= spec.total_length_range.end {
panic!("Unresolvable generation condition; couldn't resolve label count {} with total_length {}, with specification {:?}", labels, total_length, spec);
}
proptest::collection::vec(Label::arbitrary_with(spec.label_length_range.clone()), labels)
.prop_map(|labels| Name{ labels })
}).boxed()
}
}
impl Arbitrary for Label {
type Parameters = Range<usize>;
type Strategy = BoxedStrategy<Self>;
fn arbitrary() -> Self::Strategy {
Self::arbitrary_with(1..64)
}
fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
args.prop_flat_map(|length| {
let first = char_selector(&['a'..='z', 'A'..='Z']);
let middle = char_selector(&['a'..='z', 'A'..='Z', '-'..='-', '0'..='9']);
let last = char_selector(&['a'..='z', 'A'..='Z', '0'..='9']);
match length {
0 => panic!("Should not be able to generate a label of length 0"),
1 => first.prop_map(|x| x.into()).boxed(),
2 => (first, last).prop_map(|(a, b)| format!("{a}{b}")).boxed(),
_ => (first, proptest::collection::vec(middle, length - 2), last)
.prop_map(move |(first, middle, last)| {
let mut result = String::with_capacity(length);
result.push(first);
for c in middle.iter() {
result.push(*c);
}
result.push(last);
result
})
.boxed(),
}
})
.prop_map(|x| Label(ArcIntern::new(x)))
.boxed()
}
}
fn char_selector<'a>(ranges: &'a [RangeInclusive<char>]) -> CharStrategy<'a> {
CharStrategy::new(
Cow::Borrowed(&[]),
Cow::Borrowed(&[]),
Cow::Borrowed(ranges),
)
}
proptest::proptest! {
#[test]
fn any_random_names_parses(name: Name) {
let name_str = name.to_string();
let new_name = Name::from_str(&name_str).expect("can re-parse name");
assert_eq!(name, new_name);
}
#[test]
fn any_random_name_roundtrips(name: Name) {
let mut write_buffer = bytes::BytesMut::with_capacity(512);
name.write(&mut write_buffer).expect("can write name");
let mut read_buffer = write_buffer.freeze();
let new_name = Name::read(&mut read_buffer).expect("can read name");
assert_eq!(name, new_name);
}
}
#[test]
fn illegal_names_generate_errors() {
assert!(Name::from_str("").is_err());
assert!(Name::from_str(".").is_err());
assert!(Name::from_str(".com").is_err());
assert!(Name::from_str("com.").is_err());
assert!(Name::from_str("9.com").is_err());
assert!(Name::from_str("foo-.com").is_err());
assert!(Name::from_str("-foo.com").is_err());
assert!(Name::from_str(".foo.com").is_err());
assert!(Name::from_str("fo*o.com").is_err());
assert!(Name::from_str(
"foo.abcdefghiabcdefghiabcdefghiabcdefghiabcdefghiabcdefghijjjjjjabcdefghij.com"
)
.is_err());
assert!(Name::from_str("abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij").is_err());
}
#[test]
fn names_ignore_case() {
assert_eq!(
Name::from_str("UHSURE.COM").unwrap(),
Name::from_str("uhsure.com").unwrap()
);
}

View File

@@ -0,0 +1,9 @@
mod client;
mod header;
pub mod question;
mod request;
mod resource_record;
mod response;
mod server;
pub use client::Client;

View File

@@ -0,0 +1,273 @@
use crate::network::resolver::protocol::header::Header;
use crate::network::resolver::protocol::question::Question;
use bytes::{Bytes, BytesMut};
use error_stack::ResultExt;
use std::net::SocketAddr;
use std::sync::Arc;
use thiserror::Error;
use tokio::io::{AsyncReadExt, WriteHalf};
use tokio::net::{TcpStream, UdpSocket, UnixDatagram, UnixStream};
use tokio::sync::RwLock;
use tokio::task::JoinSet;
type Callback = fn() -> ();
pub struct Client {
callback: Arc<RwLock<Callback>>,
channel: DnsChannel,
}
pub enum DnsChannel {
Tcp(WriteHalf<TcpStream>),
Udp(Arc<UdpSocket>),
UnixData(Arc<UnixDatagram>),
Unix(WriteHalf<UnixStream>),
}
fn empty_callback() {}
#[derive(Debug, thiserror::Error)]
pub enum SendError {}
#[derive(Clone, Debug)]
enum GeneralAddr {
Network(SocketAddr),
Unix(std::os::unix::net::SocketAddr),
}
impl From<SocketAddr> for GeneralAddr {
fn from(value: SocketAddr) -> Self {
GeneralAddr::Network(value)
}
}
impl From<tokio::net::unix::SocketAddr> for GeneralAddr {
fn from(value: tokio::net::unix::SocketAddr) -> Self {
GeneralAddr::Unix(value.into())
}
}
#[derive(Debug, Error)]
enum ServerProcessorError {
#[error("Could not read message header from server.")]
CouldNotReadHeader,
}
async fn process_server_response(
callback: &Arc<RwLock<Callback>>,
mut bytes: Bytes,
source: GeneralAddr,
) -> error_stack::Result<(), ServerProcessorError> {
unimplemented!()
}
async fn run_response_processing_loop(
server: String,
maximum_consecurive_errors: u64,
callback: Arc<RwLock<Callback>>,
mut fetcher: impl AsyncFnMut() -> Result<(Bytes, GeneralAddr), std::io::Error>,
) {
let mut consecutive_errors = 0;
loop {
match fetcher().await {
Ok((bytes, source)) => {
if let Err(e) = process_server_response(&callback, bytes, source).await {
consecutive_errors += 1;
tracing::warn!(
server,
error = %e,
maximum_consecurive_errors,
consecutive_errors,
"error processing DNS server response"
);
} else {
consecutive_errors = 0;
}
}
Err(e) => {
consecutive_errors += 1;
tracing::warn!(
server,
error = %e,
maximum_consecurive_errors,
consecutive_errors,
"failed to read response from DNS server"
);
if consecutive_errors >= maximum_consecurive_errors {
break;
}
}
};
}
tracing::error!(
server,
"quitting DNS response processing loop due to too many consecutive errors"
);
}
impl Client {
/// Create a new DNS client from the given, targeted UnixDatagram socket.
///
/// By "targeted", we mean that this socket should've had `connect` called on
/// it, so that the client can just send datagrams to the server without having
/// to know where to send them. It is unspecified what will happen to DNS clients
/// if you do not call `connect` beforehand.
pub async fn from_unix_datagram(socket: UnixDatagram, group: &mut JoinSet<()>) -> Self {
let callback = Arc::new(RwLock::new(empty_callback as Callback));
let socket = Arc::new(socket);
let reader_socket = socket.clone();
let reader_callback = callback.clone();
let server_addr = socket.local_addr()
.ok()
.map(|x| x.as_pathname().map(|x| x.display().to_string()))
.flatten()
.unwrap_or_else(|| "<unknown>".into());
let server = format!("unixd://{}", server_addr);
group.spawn(async move {
let fetcher = async move || {
let mut buffer = BytesMut::with_capacity(16384);
match reader_socket.recv_buf_from(&mut buffer).await {
Err(x) => Err(x),
Ok((size, from)) => {
unsafe {
buffer.set_len(size);
};
Ok((buffer.freeze(), from.into()))
}
}
};
run_response_processing_loop(server, 5, reader_callback, fetcher).await;
});
Client {
callback,
channel: DnsChannel::UnixData(socket),
}
}
pub async fn from_unix_stream(
socket: UnixStream,
group: &mut JoinSet<()>,
) -> std::io::Result<Self> {
let callback = Arc::new(RwLock::new(empty_callback as Callback));
let server_addr = socket.local_addr()
.ok()
.map(|x| x.as_pathname().map(|x| x.display().to_string()))
.flatten()
.unwrap_or_else(|| "<unknown>".into());
let server = format!("unix://{}", server_addr);
let other_side = GeneralAddr::Unix(socket.peer_addr()?.into());
let (mut reader, writer) = tokio::io::split(socket);
let reader_callback = callback.clone();
group.spawn(async move {
let fetcher = async move || {
let size = reader.read_u16().await?;
let mut buffer = vec![0u8; size as usize];
reader.read_exact(&mut buffer).await?;
Ok((Bytes::from(buffer), other_side.clone()))
};
run_response_processing_loop(server, 5, reader_callback, fetcher).await;
});
Ok(Client {
callback,
channel: DnsChannel::Unix(writer),
})
}
pub async fn from_udp(socket: UdpSocket, group: &mut JoinSet<()>) -> Self {
let callback = Arc::new(RwLock::new(empty_callback as Callback));
let socket = Arc::new(socket);
let reader_callback = callback.clone();
let reader_socket = socket.clone();
let server_addr = socket.local_addr()
.ok()
.map(|x| x.to_string())
.unwrap_or_else(|| "<unknown>".into());
let server = format!("udp://{}", server_addr);
group.spawn(async move {
let fetcher = async move || {
let mut buffer = BytesMut::with_capacity(16384);
match reader_socket.recv_buf_from(&mut buffer).await {
Err(x) => Err(x),
Ok((size, from)) => {
unsafe {
buffer.set_len(size);
};
Ok((buffer.freeze(), from.into()))
}
}
};
run_response_processing_loop(server, 5, reader_callback, fetcher).await;
});
Client {
callback,
channel: DnsChannel::Udp(socket),
}
}
pub async fn from_tcp(socket: TcpStream, group: &mut JoinSet<()>) -> std::io::Result<Self> {
let callback = Arc::new(RwLock::new(empty_callback as Callback));
let other_side = GeneralAddr::Network(socket.peer_addr()?);
let server_addr = socket.local_addr()
.ok()
.map(|x| x.to_string())
.unwrap_or_else(|| "<unknown>".into());
let server = format!("tcp://{}", server_addr);
let (mut reader, writer) = tokio::io::split(socket);
let reader_callback = callback.clone();
group.spawn(async move {
let fetcher = async move || {
let size = reader.read_u16().await?;
let mut buffer = vec![0u8; size as usize];
reader.read_exact(&mut buffer).await?;
Ok((Bytes::from(buffer), other_side.clone()))
};
run_response_processing_loop(server, 5, reader_callback, fetcher).await;
});
Ok(Client {
callback,
channel: DnsChannel::Tcp(writer),
})
}
/// Send a set of questions to the upstream server.
///
/// Any response(s) that is/are sent will be handled as part of the callback
/// scheme. This function will thus only return an error if there's a problem
/// sending the question to the server.
pub async fn send_questions(_questions: Vec<Question>) -> error_stack::Result<(), SendError> {
Ok(())
}
/// Set the callback handler for when we receive responses from the server.
///
/// This function may take awhile to execute, depending on how busy we are
/// taking responses, as it takes write ownership of a read-write lock that's
/// almost always written.
pub async fn set_callback(&self, callback: Callback) {
*self.callback.write().await = callback;
}
}

View File

@@ -0,0 +1,225 @@
use bytes::{Buf, BufMut};
use error_stack::{report, ResultExt};
use num_enum::{FromPrimitive, IntoPrimitive};
use proptest::arbitrary::Arbitrary;
use proptest::strategy::{BoxedStrategy, Just, Strategy};
use std::fmt;
use thiserror::Error;
#[derive(Clone, Debug, PartialEq, proptest_derive::Arbitrary)]
pub struct Header {
pub message_id: u16,
pub is_response: bool,
pub opcode: OpCode,
pub authoritative_answer: bool,
pub message_truncated: bool,
pub recursion_desired: bool,
pub recursion_available: bool,
pub response_code: ResponseCode,
pub question_count: u16,
pub answer_count: u16,
pub name_server_count: u16,
pub additional_record_count: u16,
}
#[derive(FromPrimitive, IntoPrimitive, PartialEq, PartialOrd, Eq, Ord, Debug, Copy, Clone)]
#[repr(u8)]
pub enum OpCode {
StandardQuery = 0,
InverseQuery = 1,
ServiceStatusRequest = 2,
#[num_enum(catch_all)]
Other(u8),
}
impl Arbitrary for OpCode {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
// while the type is 8 bits in rust, it's 4 bits in the protocol,
// and dealing with the run-off is messy. so this only generates
// valid-sized values here. it also biases toward the legit values.
// both things might want to be reconsidered in the future.
proptest::prop_oneof![
Just(OpCode::StandardQuery),
Just(OpCode::InverseQuery),
Just(OpCode::ServiceStatusRequest),
(3u8..=15).prop_map(OpCode::Other),
]
.boxed()
}
}
#[derive(FromPrimitive, IntoPrimitive, PartialEq, PartialOrd, Eq, Ord, Debug, Copy, Clone)]
#[repr(u8)]
pub enum ResponseCode {
NoErrorConditions = 0,
FormatError = 1,
ServerFailure = 2,
NameError = 3,
NotImplemented = 4,
Refused = 5,
#[num_enum(catch_all)]
Other(u8),
}
impl fmt::Display for ResponseCode {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ResponseCode::NoErrorConditions => write!(f, "No errors"),
ResponseCode::FormatError => write!(f, "Illegal format"),
ResponseCode::ServerFailure => write!(f, "Server failure"),
ResponseCode::NameError => write!(f, "Name error"),
ResponseCode::NotImplemented => write!(f, "Not implemented"),
ResponseCode::Refused => write!(f, "Refused"),
ResponseCode::Other(x) => write!(f, "unknown error {x}"),
}
}
}
impl Arbitrary for ResponseCode {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
// while the type is 8 bits in rust, it's 4 bits in the protocol,
// and dealing with the run-off is messy. so this only generates
// valid-sized values here. it also biases toward the legit values.
// both things might want to be reconsidered in the future.
proptest::prop_oneof![
Just(ResponseCode::NoErrorConditions),
Just(ResponseCode::FormatError),
Just(ResponseCode::ServerFailure),
Just(ResponseCode::NameError),
Just(ResponseCode::NotImplemented),
Just(ResponseCode::Refused),
(6u8..=15).prop_map(ResponseCode::Other),
]
.boxed()
}
}
#[derive(Debug, Error)]
pub enum HeaderReadError {
#[error("Buffer not large enough to have a DNS message header in it.")]
BufferTooSmall,
#[error("Invalid data in zero-filled space.")]
NonZeroInZeroes,
}
#[derive(Debug, Error)]
pub enum HeaderWriteError {
#[error("Buffer not large enough to write header.")]
BufferTooSmall,
}
impl Header {
pub fn read<B: Buf>(buffer: &mut B) -> error_stack::Result<Self, HeaderReadError> {
if buffer.remaining() < 12
/* 6 16-bit fields = 6 * 2 = 1 */
{
return Err(report!(HeaderReadError::BufferTooSmall)).attach_printable_lazy(|| {
format!(
"Need at least {} bytes, but only had {}",
6 * 12,
buffer.remaining()
)
});
}
let message_id = buffer.get_u16();
let flags = buffer.get_u16();
let question_count = buffer.get_u16();
let answer_count = buffer.get_u16();
let name_server_count = buffer.get_u16();
let additional_record_count = buffer.get_u16();
let is_response = (0x8000 & flags) != 0;
let opcode = OpCode::from(((flags >> 11) & 0xF) as u8);
let authoritative_answer = (0x0400 & flags) != 0;
let message_truncated = (0x0200 & flags) != 0;
let recursion_desired = (0x0100 & flags) != 0;
let recursion_available = (0x0080 & flags) != 0;
let zeroes = 0x0070 & flags;
let response_code = ResponseCode::from((flags & 0x000F) as u8);
if zeroes != 0 {
return Err(report!(HeaderReadError::NonZeroInZeroes))
.attach_printable_lazy(|| format!("Saw {:#x} instead.", zeroes >> 4));
}
Ok(Header {
message_id,
is_response,
opcode,
authoritative_answer,
message_truncated,
recursion_desired,
recursion_available,
response_code,
question_count,
answer_count,
name_server_count,
additional_record_count,
})
}
pub fn write<B: BufMut>(self, buffer: &mut B) -> error_stack::Result<(), HeaderWriteError> {
if buffer.remaining_mut() < 12
/* 6 16-bit fields = 6 * 2 = 1 */
{
return Err(report!(HeaderWriteError::BufferTooSmall)).attach_printable_lazy(|| {
format!(
"Need at least {} to write DNS header, only have {}",
6 * 12,
buffer.remaining_mut()
)
});
}
let mut flags: u16 = 0;
if self.is_response {
flags |= 0x8000;
}
let opcode: u8 = self.opcode.into();
flags |= (opcode as u16) << 11;
if self.authoritative_answer {
flags |= 0x0400;
}
if self.message_truncated {
flags |= 0x0200;
}
if self.recursion_desired {
flags |= 0x0100;
}
if self.recursion_available {
flags |= 0x0080;
}
let response_code: u8 = self.response_code.into();
flags |= response_code as u16;
buffer.put_u16(self.message_id);
buffer.put_u16(flags);
buffer.put_u16(self.question_count);
buffer.put_u16(self.answer_count);
buffer.put_u16(self.name_server_count);
buffer.put_u16(self.additional_record_count);
Ok(())
}
}
proptest::proptest! {
#[test]
fn headers_roundtrip(header: Header) {
let mut write_buffer = bytes::BytesMut::with_capacity(128 * 1024);
let safe_header = header.clone();
header.write(&mut write_buffer).expect("can write name");
let mut read_buffer = write_buffer.freeze();
let new_header = Header::read(&mut read_buffer).expect("can read name");
assert_eq!(safe_header, new_header);
}
}

View File

@@ -0,0 +1,91 @@
use crate::network::resolver::name::Name;
use crate::network::resolver::protocol::resource_record::raw::{RecordClass, RecordType};
use bytes::{Buf, BufMut, TryGetError};
use error_stack::{report, ResultExt};
use std::fmt;
use thiserror::Error;
#[derive(Clone, Debug, PartialEq, proptest_derive::Arbitrary)]
pub struct Question {
name: Name,
record_type: RecordType,
record_class: RecordClass,
}
impl fmt::Display for Question {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "<Q:{}@{}.{}>", self.name, self.record_type, self.record_class)
}
}
#[derive(Debug, Error)]
pub enum QuestionReadError {
#[error("Could not read name for question.")]
CouldNotReadName,
#[error("Could not read the record type for the question: {0}")]
CouldNotReadType(TryGetError),
#[error("Could not read the record class for the question.")]
CouldNotReadClass(TryGetError),
}
#[derive(Debug, Error)]
pub enum QuestionWriteError {
#[error("Could not write name for the question.")]
CouldNotWriteName,
#[error("Buffer not large enough to write question type and class.")]
BufferTooSmall,
}
impl Question {
pub fn read<B: Buf>(buffer: &mut B) -> error_stack::Result<Self, QuestionReadError> {
let name = Name::read(buffer).change_context(QuestionReadError::CouldNotReadName)?;
let record_type_u16 = buffer
.try_get_u16()
.map_err(|e| report!(QuestionReadError::CouldNotReadType(e)))
.attach_printable_lazy(|| format!("question was about '{name}'"))?;
let record_class_u16 = buffer
.try_get_u16()
.map_err(|e| report!(QuestionReadError::CouldNotReadClass(e)))
.attach_printable_lazy(|| format!("question was about '{name}'"))?;
let record_type = RecordType::from(record_type_u16);
let record_class = RecordClass::from(record_class_u16);
Ok(Question {
name,
record_type,
record_class,
})
}
pub fn write<B: BufMut>(&self, buffer: &mut B) -> error_stack::Result<(), QuestionWriteError> {
self.name
.write(buffer)
.change_context(QuestionWriteError::CouldNotWriteName)
.attach_printable_lazy(|| format!("Question was about '{}'", self.name))?;
if buffer.remaining_mut() < 4 {
return Err(report!(QuestionWriteError::BufferTooSmall))
.attach_printable_lazy(|| format!("Question was about '{}'", self.name));
}
buffer.put_u16(self.record_type.into());
buffer.put_u16(self.record_class.into());
Ok(())
}
}
proptest::proptest! {
#[test]
fn questions_roundtrip(question: Question) {
let mut write_buffer = bytes::BytesMut::with_capacity(128 * 1024);
question.write(&mut write_buffer).expect("can write name");
let mut read_buffer = write_buffer.freeze();
let new_question = Question::read(&mut read_buffer).expect("can read name");
assert_eq!(question, new_question);
}
}

View File

@@ -0,0 +1,156 @@
use bytes::{Buf, BufMut};
use error_stack::{report, ResultExt};
use crate::network::resolver::protocol::header::{Header, OpCode, ResponseCode};
use crate::network::resolver::protocol::question::Question;
use crate::network::resolver::protocol::resource_record::ResourceRecord;
#[derive(Clone, Debug, PartialEq, proptest_derive::Arbitrary)]
pub struct Request {
source_message_id: u16,
opcode: OpCode,
recursion_desired: bool,
questions: Vec<Question>,
}
#[derive(Debug, thiserror::Error)]
pub enum RequestReadError {
#[error("Error reading request header.")]
Header,
#[error("Message isn't a request.")]
NotRequest,
#[error("Illegal message.")]
IllegalMessage,
#[error("Error reading request question.")]
Question,
#[error("Error reading request answers (?!)")]
Answer,
#[error("Error reading requst name servers.")]
NameServers,
#[error("Reading request additional records.")]
AdditionalRecords,
}
#[derive(Debug, thiserror::Error)]
pub enum RequestWriteError {
#[error("Error writing request header.")]
Header,
#[error("Error writing request question.")]
Question,
#[error("Request had too many questions.")]
TooManyQuestions,
}
impl Request {
pub fn read<B: Buf>(buffer: &mut B) -> error_stack::Result<Self, RequestReadError> {
let header = Header::read(buffer)
.change_context(RequestReadError::Header)?;
if header.is_response {
return Err(report!(RequestReadError::NotRequest))
.attach_printable_lazy(|| format!("message id is {}", header.message_id));
}
if header.authoritative_answer {
return Err(report!(RequestReadError::IllegalMessage))
.attach_printable("Request messages are not allowed to set the 'authoritative answer' bit.")
.attach_printable_lazy(|| format!("message id is {}", header.message_id));
}
if header.response_code != ResponseCode::NoErrorConditions {
return Err(report!(RequestReadError::IllegalMessage))
.attach_printable("Request messages are not allowed to set the response code.")
.attach_printable_lazy(|| format!("message id is {}", header.message_id));
}
if header.answer_count != 0 {
return Err(report!(RequestReadError::IllegalMessage))
.attach_printable("Request messages are not allowed to include answers.")
.attach_printable_lazy(|| format!("{} answers declared", header.answer_count))
.attach_printable_lazy(|| format!("message id is {}", header.message_id));
}
if header.name_server_count != 0 {
return Err(report!(RequestReadError::IllegalMessage))
.attach_printable("Request messages are not allowed to include name servers.")
.attach_printable_lazy(|| format!("{} name servers declared", header.name_server_count))
.attach_printable_lazy(|| format!("message id is {}", header.message_id));
}
if header.additional_record_count != 0 {
return Err(report!(RequestReadError::IllegalMessage))
.attach_printable("Request messages are not allowed to include additional records.")
.attach_printable_lazy(|| format!("{} aditional records declared", header.additional_record_count))
.attach_printable_lazy(|| format!("message id is {}", header.message_id));
}
let mut questions = vec![];
for i in 0..header.question_count {
let question = Question::read(buffer)
.change_context(RequestReadError::Question)
.attach_printable_lazy(|| format!("for question #{} of {}", i+1, header.question_count))
.attach_printable_lazy(|| format!("message id is {}", header.message_id))?;
questions.push(question);
}
Ok(Request {
source_message_id: header.message_id,
opcode: header.opcode,
recursion_desired: header.recursion_desired,
questions,
})
}
pub fn write<B: BufMut>(self, buffer: &mut B) -> error_stack::Result<(), RequestWriteError> {
let question_count = self.questions.len();
if question_count > (u16::MAX as usize) {
return Err(report!(RequestWriteError::TooManyQuestions))
.attach_printable(format!("message_id is {}", self.source_message_id));
}
let header = Header {
message_id: self.source_message_id,
is_response: false,
opcode: self.opcode,
authoritative_answer: false,
message_truncated: false,
recursion_desired: self.recursion_desired,
recursion_available: false,
response_code: ResponseCode::NoErrorConditions,
question_count: question_count as u16,
answer_count: 0,
name_server_count: 0,
additional_record_count: 0,
};
header.write(buffer)
.change_context(RequestWriteError::Header)
.attach_printable_lazy(|| format!("message ID is {}", self.source_message_id))?;
for (index, question) in self.questions.into_iter().enumerate() {
question.write(buffer)
.change_context(RequestWriteError::Question)
.attach_printable_lazy(|| format!("message ID is {}", self.source_message_id))
.attach_printable_lazy(|| format!("question #{} of {}", index+1, question_count))
.attach_printable_lazy(|| format!("{}", question))?;
}
Ok(())
}
}
proptest::proptest! {
#[test]
fn request_roundtrip(request: Request) {
let mut write_buffer = bytes::BytesMut::with_capacity(128 * 1024);
let safe_request = request.clone();
request.write(&mut write_buffer).expect("can write name");
let mut read_buffer = write_buffer.freeze();
let new_request = Request::read(&mut read_buffer).expect("can read name");
assert_eq!(safe_request, new_request);
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,304 @@
use crate::network::resolver::name::Name;
use bytes::{Buf, BufMut, Bytes};
use error_stack::{report, ResultExt};
use num_enum::{FromPrimitive, IntoPrimitive};
use proptest::arbitrary::Arbitrary;
use proptest::strategy::{BoxedStrategy, Just, Strategy};
use std::fmt;
use thiserror::Error;
#[derive(Debug, PartialEq)]
pub struct RawResourceRecord {
pub name: Name,
pub record_type: RecordType,
pub record_class: RecordClass,
pub ttl: u32,
pub data: Bytes,
}
impl Arbitrary for RawResourceRecord {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
(
Name::arbitrary(),
RecordType::arbitrary(),
RecordClass::arbitrary(),
u32::arbitrary(),
proptest::collection::vec(u8::arbitrary(), 0..65535),
)
.prop_map(
|(name, record_type, record_class, ttl, data)| RawResourceRecord {
name,
record_type,
record_class,
ttl,
data: data.into(),
},
)
.boxed()
}
}
#[derive(FromPrimitive, IntoPrimitive, PartialEq, PartialOrd, Eq, Ord, Debug, Copy, Clone)]
#[repr(u16)]
pub enum RecordType {
A = 1,
AAAA = 28,
NS = 2,
MD = 3,
MF = 4,
CNAME = 5,
SOA = 6,
MB = 7,
MG = 8,
MR = 9,
NULL = 10,
WKS = 11,
PTR = 12,
HINFO = 13,
MINFO = 14,
MX = 15,
TXT = 16,
URI = 256,
#[num_enum(catch_all)]
Other(u16),
}
impl fmt::Display for RecordType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
RecordType::A => write!(f, "A"),
RecordType::AAAA => write!(f, "AAAA"),
RecordType::NS => write!(f, "NS"),
RecordType::MD => write!(f, "MD"),
RecordType::MF => write!(f, "MF"),
RecordType::CNAME => write!(f, "CNAME"),
RecordType::SOA => write!(f, "SOA"),
RecordType::MB => write!(f, "MB"),
RecordType::MG => write!(f, "MG"),
RecordType::MR => write!(f, "MR"),
RecordType::NULL => write!(f, "NULL"),
RecordType::WKS => write!(f, "WKS"),
RecordType::PTR => write!(f, "PTR"),
RecordType::HINFO => write!(f, "HINFO"),
RecordType::MINFO => write!(f, "MINFO"),
RecordType::MX => write!(f, "MX"),
RecordType::TXT => write!(f, "TXT"),
RecordType::URI => write!(f, "URI"),
RecordType::Other(x) => write!(f, "UNKNOWN<{x}>"),
}
}
}
impl Arbitrary for RecordType {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
// this is intentionally biased towards the legit values
proptest::prop_oneof![
Just(RecordType::A),
Just(RecordType::AAAA),
Just(RecordType::NS),
Just(RecordType::MD),
Just(RecordType::MF),
Just(RecordType::CNAME),
Just(RecordType::SOA),
Just(RecordType::MB),
Just(RecordType::MG),
Just(RecordType::MR),
Just(RecordType::NULL),
Just(RecordType::WKS),
Just(RecordType::PTR),
Just(RecordType::HINFO),
Just(RecordType::MINFO),
Just(RecordType::MX),
Just(RecordType::TXT),
proptest::prop_oneof![
(17u16..28).prop_map(|x| RecordType::Other(x)),
(29u16..256).prop_map(|x| RecordType::Other(x)),
(257u16..=65535).prop_map(|x| RecordType::Other(x)),
Just(RecordType::Other(0)),
],
]
.boxed()
}
}
#[derive(FromPrimitive, IntoPrimitive, PartialEq, PartialOrd, Eq, Ord, Debug, Copy, Clone)]
#[repr(u16)]
pub enum RecordClass {
IN = 1,
CS = 2,
CH = 3,
HS = 4,
#[num_enum(catch_all)]
Other(u16),
}
impl fmt::Display for RecordClass {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
RecordClass::IN => write!(f, "IN"),
RecordClass::CS => write!(f, "CS"),
RecordClass::CH => write!(f, "CH"),
RecordClass::HS => write!(f, "HS"),
RecordClass::Other(x) => write!(f, "UNKNOWN<{x}>"),
}
}
}
impl Arbitrary for RecordClass {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
// this is intentionally biased towards the legit values
proptest::prop_oneof![
Just(RecordClass::IN),
Just(RecordClass::CS),
Just(RecordClass::CH),
Just(RecordClass::HS),
(5u16..=65535).prop_map(|x| RecordClass::Other(x)),
Just(RecordClass::Other(0)),
]
.boxed()
}
}
#[derive(Debug, Error)]
pub enum ResourceRecordReadError {
#[error("Failed to read initial record name.")]
InitialRecord,
#[error("Resource record truncated; couldn't find its type field.")]
NoTypeField,
#[error("Resource record truncated; couldn't find its class field.")]
NoClassField,
#[error("Resource record truncated; couldn't find its TTL field.")]
NoTtl,
#[error("Resource record truncated; couldn't find its data length.")]
NoDataLength,
#[error("Resource record truncated; couldn't read its entire data field.")]
DataTruncated,
}
#[derive(Debug, Error)]
pub enum ResourceRecordWriteError {
#[error("Could not write name to the output record.")]
CouldNotWriteName,
#[error("Could not write resource record type and class to output record.")]
CouldNotWriteTypeClass,
#[error("Could not write TTL to output record.")]
CountNotWriteTtl,
#[error("Could not write resource record data length to output record.")]
CountNotWriteDataLength,
#[error("Could not write resource record data to output record.")]
CountNotWriteData,
#[error("Input data was too large to write to output stream.")]
InputDataTooLarge,
}
impl RawResourceRecord {
pub fn read<B: Buf>(buffer: &mut B) -> error_stack::Result<Self, ResourceRecordReadError> {
let name = Name::read(buffer).change_context(ResourceRecordReadError::InitialRecord)?;
let record_type = buffer
.try_get_u16()
.map_err(|_| report!(ResourceRecordReadError::NoTypeField))?
.into();
let record_class = buffer
.try_get_u16()
.map_err(|_| report!(ResourceRecordReadError::NoClassField))?
.into();
let ttl = buffer
.try_get_u32()
.map_err(|_| report!(ResourceRecordReadError::NoTtl))?;
let rdata_length = buffer
.try_get_u16()
.map_err(|_| report!(ResourceRecordReadError::NoDataLength))?;
if buffer.remaining() < (rdata_length as usize) {
return Err(report!(ResourceRecordReadError::DataTruncated)).attach_printable_lazy(
|| {
format!(
"Expected {rdata_length} bytes, but only saw {}",
buffer.remaining()
)
},
);
}
let data = buffer.copy_to_bytes(rdata_length as usize);
Ok(RawResourceRecord {
name,
record_type,
record_class,
ttl,
data,
})
}
pub fn write<B: BufMut>(
&self,
buffer: &mut B,
) -> error_stack::Result<(), ResourceRecordWriteError> {
self.name
.write(buffer)
.change_context(ResourceRecordWriteError::CouldNotWriteName)?;
if buffer.remaining_mut() < 4 {
return Err(report!(ResourceRecordWriteError::CouldNotWriteTypeClass));
}
buffer.put_u16(self.record_type.into());
buffer.put_u16(self.record_class.into());
if buffer.remaining_mut() < 4 {
return Err(report!(ResourceRecordWriteError::CountNotWriteTtl));
}
buffer.put_u32(self.ttl);
if buffer.remaining_mut() < 2 {
return Err(report!(ResourceRecordWriteError::CountNotWriteDataLength));
}
if self.data.len() > (u16::MAX as usize) {
return Err(report!(ResourceRecordWriteError::InputDataTooLarge))
.attach_printable_lazy(|| {
format!(
"Incoming data was {} bytes, needs to be < 2^16",
self.data.len()
)
});
}
buffer.put_u16(self.data.len() as u16);
if buffer.remaining_mut() < self.data.len() {
return Err(report!(ResourceRecordWriteError::CountNotWriteData));
}
buffer.put_slice(&self.data);
Ok(())
}
}
proptest::proptest! {
#[test]
fn any_random_record_roundtrips(record: RawResourceRecord) {
let mut write_buffer = bytes::BytesMut::with_capacity(128 * 1024);
record.write(&mut write_buffer).expect("can write name");
let mut read_buffer = write_buffer.freeze();
let new_record = RawResourceRecord::read(&mut read_buffer).expect("can read name");
assert_eq!(record, new_record);
}
}

View File

@@ -0,0 +1,258 @@
use bytes::{Buf, BufMut};
use error_stack::ResultExt;
use crate::network::resolver::protocol::header::{Header, ResponseCode};
use crate::network::resolver::protocol::question::Question;
use crate::network::resolver::protocol::resource_record::ResourceRecord;
#[derive(Clone, Debug, PartialEq, proptest_derive::Arbitrary)]
pub enum Response {
Valid {
source_message_id: u16,
authoritative: bool,
truncated: bool,
answers: Vec<ResourceRecord>,
name_servers: Vec<ResourceRecord>,
additional_records: Vec<ResourceRecord>,
},
FormatError {
source_message_id: u16,
},
ServerFailure {
source_message_id: u16,
},
NameError {
source_message_id: u16,
},
NotImplemented {
source_message_id: u16,
},
Refused {
source_message_id: u16,
},
UnknownError {
#[proptest(strategy="6u8..=15")]
error_code: u8,
source_message_id: u16,
}
}
#[derive(Debug, thiserror::Error)]
pub enum ResponseReadError {
#[error("Could not read response header.")]
HeaderReadError,
#[error("Could not read question included in response.")]
QuestionReadError,
#[error("Could not read resource record included as an answer.")]
AnswerReadError,
#[error("Could not read name server record included in response.")]
NameServerReadError,
#[error("Could not read supplemental record included in response.")]
AdditionalInfoReadError,
}
#[derive(Debug, thiserror::Error)]
pub enum ResponseWriteError {
#[error("Could not write response header.")]
Header,
#[error("Could not write answer.")]
Answer,
#[error("Could not write name server.")]
NameServer,
#[error("Could not write additional record.")]
AdditionalRecord,
}
impl Response {
pub fn message_id(&self) -> u16 {
match self {
Response::Valid { source_message_id, .. } => *source_message_id,
Response::FormatError { source_message_id } => *source_message_id,
Response::ServerFailure { source_message_id } => *source_message_id,
Response::NameError { source_message_id } => *source_message_id,
Response::NotImplemented { source_message_id } => *source_message_id,
Response::Refused { source_message_id } => *source_message_id,
Response::UnknownError { source_message_id, .. } => *source_message_id,
}
}
pub fn read<B: Buf>(buffer: &mut B) -> error_stack::Result<Response, ResponseReadError> {
let header = Header::read(buffer)
.change_context(ResponseReadError::HeaderReadError)?;
// check for errors, and short-cut out if we find any
match header.response_code {
ResponseCode::NoErrorConditions => {}
ResponseCode::FormatError => return Ok(Response::FormatError {
source_message_id: header.message_id,
}),
ResponseCode::ServerFailure => return Ok(Response::ServerFailure {
source_message_id: header.message_id,
}),
ResponseCode::NameError => return Ok(Response::NameError {
source_message_id: header.message_id,
}),
ResponseCode::NotImplemented => return Ok(Response::NotImplemented {
source_message_id: header.message_id,
}),
ResponseCode::Refused => return Ok(Response::Refused {
source_message_id: header.message_id,
}),
ResponseCode::Other(error_code) => return Ok(Response::UnknownError {
error_code,
source_message_id: header.message_id,
}),
}
// it seems weird to get questions in a response, but we need to parse
// them out if they exist.
for _ in 0..header.question_count {
let question = Question::read(buffer)
.change_context(ResponseReadError::QuestionReadError)?;
tracing::warn!(
%question,
"got question during server response."
);
}
let mut answers = vec![];
for idx in 0..header.answer_count {
let answer = ResourceRecord::read(buffer)
.change_context(ResponseReadError::AnswerReadError)
.attach_printable_lazy(|| format!("In answer {} of {}", idx + 1, header.answer_count))?;
answers.push(answer);
}
let mut name_servers = vec![];
for idx in 0..header.name_server_count {
let name_server = ResourceRecord::read(buffer)
.change_context(ResponseReadError::NameServerReadError)
.attach_printable_lazy(|| format!("In answer {} of {}", idx + 1, header.name_server_count))?;
name_servers.push(name_server);
}
let mut additional_records = vec![];
for idx in 0..header.additional_record_count {
let extra = ResourceRecord::read(buffer)
.change_context(ResponseReadError::AnswerReadError)
.attach_printable_lazy(|| format!("In answer {} of {}", idx + 1, header.additional_record_count))?;
additional_records.push(extra);
}
Ok(Response::Valid {
source_message_id: header.message_id,
authoritative: header.authoritative_answer,
truncated: header.message_truncated,
answers,
name_servers,
additional_records,
})
}
fn write_error<B: BufMut>(source_message_id: u16, error_code: ResponseCode, buffer: &mut B) -> error_stack::Result<(), ResponseWriteError> {
let header = Header {
message_id: source_message_id,
is_response: true,
opcode: super::header::OpCode::StandardQuery,
authoritative_answer: false,
message_truncated: false,
recursion_desired: false,
recursion_available: false,
response_code: error_code,
question_count: 0,
answer_count: 0,
name_server_count: 0,
additional_record_count: 0,
};
header
.write(buffer)
.change_context(ResponseWriteError::Header)
.attach_printable_lazy(|| format!("responding to message {source_message_id} with error code {error_code}"))
}
pub fn write<B: BufMut>(self, buffer: &mut B) -> error_stack::Result<(), ResponseWriteError> {
let (source_message_id, authoritative, truncated, answers, name_servers, additional_records) = match self {
Response::FormatError { source_message_id } => return Self::write_error(source_message_id, ResponseCode::FormatError, buffer),
Response::ServerFailure { source_message_id } => return Self::write_error(source_message_id, ResponseCode::ServerFailure, buffer),
Response::NameError { source_message_id } => return Self::write_error(source_message_id, ResponseCode::NameError, buffer),
Response::NotImplemented { source_message_id } => return Self::write_error(source_message_id, ResponseCode::NotImplemented, buffer),
Response::Refused { source_message_id } => return Self::write_error(source_message_id, ResponseCode::Refused, buffer),
Response::UnknownError { error_code, source_message_id } => return Self::write_error(source_message_id, ResponseCode::Other(error_code), buffer),
Response::Valid { source_message_id, authoritative, truncated, answers, name_servers, additional_records } => {
(source_message_id, authoritative, truncated, answers, name_servers, additional_records)
}
};
let header = Header {
message_id: source_message_id,
is_response: true,
opcode: super::header::OpCode::StandardQuery,
authoritative_answer: authoritative,
message_truncated: truncated,
recursion_desired: false,
recursion_available: false,
response_code: ResponseCode::NoErrorConditions,
question_count: 0,
answer_count: answers.len() as u16,
name_server_count: name_servers.len() as u16,
additional_record_count: additional_records.len() as u16,
};
header.write(buffer)
.change_context(ResponseWriteError::Header)
.attach_printable_lazy(|| format!("Writing clean response to {source_message_id}"))?;
let answer_count = answers.len();
for (item, answer) in answers.into_iter().enumerate() {
answer.write(buffer)
.change_context(ResponseWriteError::Answer)
.attach_printable_lazy(|| format!("Writing clean response to {source_message_id}"))
.attach_printable_lazy(|| format!("Writing answer {} of {}", item+1, answer_count))?;
}
let ns_count = name_servers.len();
for (item, answer) in name_servers.into_iter().enumerate() {
answer.write(buffer)
.change_context(ResponseWriteError::NameServer)
.attach_printable_lazy(|| format!("Writing clean response to {source_message_id}"))
.attach_printable_lazy(|| format!("Writing name server {} of {}", item+1, ns_count))?;
}
let ar_count = additional_records.len();
for (item, answer) in additional_records.into_iter().enumerate() {
answer.write(buffer)
.change_context(ResponseWriteError::AdditionalRecord)
.attach_printable_lazy(|| format!("Writing clean response to {source_message_id}"))
.attach_printable_lazy(|| format!("Writing additional record {} of {}", item+1, ar_count))?;
}
Ok(())
}
}
proptest::proptest! {
#[test]
fn any_random_response_roundtrips(record: Response) {
let mut write_buffer = bytes::BytesMut::with_capacity(128 * 1024);
record.clone().write(&mut write_buffer).expect("can write name");
let mut read_buffer = write_buffer.freeze();
let new_record = Response::read(&mut read_buffer).expect("can read name");
assert_eq!(record, new_record);
}
}

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,176 @@
use crate::network::resolver::name::Name;
use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet};
use std::net::IpAddr;
use tokio::time::{Duration, Instant};
pub struct ResolutionTable {
inner: HashMap<Name, Vec<Resolution>>,
}
struct Resolution {
result: IpAddr,
expiration: Instant,
}
impl Default for ResolutionTable {
fn default() -> Self {
ResolutionTable::new()
}
}
impl ResolutionTable {
/// Generate a new, empty resolution table to use in a new DNS implementation,
/// or shared by a bunch of them.
pub fn new() -> Self {
ResolutionTable {
inner: HashMap::new(),
}
}
/// Clean the table of expired entries.
pub fn garbage_collect(&mut self) {
let now = Instant::now();
self.inner.retain(|_, items| {
items.retain(|x| x.expiration > now);
!items.is_empty()
});
}
/// Add a new entry to the resolution table, with a TTL on it.
pub fn add_entry(&mut self, name: Name, maps_to: IpAddr, ttl: Duration) {
let now = Instant::now();
let new_entry = Resolution {
result: maps_to,
expiration: now + ttl,
};
match self.inner.entry(name) {
Entry::Vacant(vac) => {
vac.insert(vec![new_entry]);
}
Entry::Occupied(mut occ) => {
occ.get_mut().push(new_entry);
}
}
}
/// Look up an entry in the resolution map. This will only return
/// unexpired items.
pub fn lookup(&mut self, name: &Name) -> HashSet<IpAddr> {
let mut result = HashSet::new();
let now = Instant::now();
if let Some(entry) = self.inner.get_mut(name) {
entry.retain(|x| {
let retain = x.expiration > now;
if retain {
result.insert(x.result);
}
retain
});
}
result
}
}
#[cfg(test)]
use std::net::{Ipv4Addr, Ipv6Addr};
#[cfg(test)]
use std::str::FromStr;
#[test]
fn empty_set_gets_fail() {
let mut empty = ResolutionTable::default();
assert!(empty.lookup(&Name::from_str("foo").unwrap()).is_empty());
assert!(empty.lookup(&Name::from_str("bar").unwrap()).is_empty());
}
#[test]
fn basic_lookups() {
let mut table = ResolutionTable::new();
let foo = Name::from_str("foo").unwrap();
let bar = Name::from_str("bar").unwrap();
let baz = Name::from_str("baz").unwrap();
let localhost = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
let other = IpAddr::V6(Ipv6Addr::from_str("2001:18e8:2:e::11e").unwrap());
let long_time = Duration::from_secs(10000000000);
table.add_entry(foo.clone(), localhost, long_time);
table.add_entry(bar.clone(), localhost, long_time);
table.add_entry(bar.clone(), other, long_time);
assert_eq!(1, table.lookup(&foo).len());
assert_eq!(2, table.lookup(&bar).len());
assert!(table.lookup(&baz).is_empty());
assert!(table.lookup(&foo).contains(&localhost));
assert!(!table.lookup(&foo).contains(&other));
assert!(table.lookup(&bar).contains(&localhost));
assert!(table.lookup(&bar).contains(&other));
}
#[test]
fn lookup_cleans_up() {
let mut table = ResolutionTable::new();
let foo = Name::from_str("foo").unwrap();
let localhost = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
let other = IpAddr::V6(Ipv6Addr::from_str("2001:18e8:2:e::11e").unwrap());
let short_time = Duration::from_millis(100);
let long_time = Duration::from_secs(10000000000);
table.add_entry(foo.clone(), localhost, long_time);
table.add_entry(foo.clone(), other, short_time);
let wait_until = Instant::now() + (2 * short_time);
while Instant::now() < wait_until {
std::thread::sleep(short_time);
}
assert_eq!(1, table.lookup(&foo).len());
assert!(table.lookup(&foo).contains(&localhost));
assert!(!table.lookup(&foo).contains(&other));
}
#[test]
fn garbage_collection_works() {
let mut table = ResolutionTable::new();
let foo = Name::from_str("foo").unwrap();
let localhost = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
let other = IpAddr::V6(Ipv6Addr::from_str("2001:18e8:2:e::11e").unwrap());
let short_time = Duration::from_millis(100);
let long_time = Duration::from_secs(10000000000);
table.add_entry(foo.clone(), localhost, long_time);
table.add_entry(foo.clone(), other, short_time);
let wait_until = Instant::now() + (2 * short_time);
while Instant::now() < wait_until {
std::thread::sleep(short_time);
}
table.garbage_collect();
assert_eq!(1, table.inner.get(&foo).unwrap().len());
}
#[test]
fn garbage_collection_clears_empties() {
let mut table = ResolutionTable::new();
let foo = Name::from_str("foo").unwrap();
table.inner.insert(foo.clone(), vec![]);
table.garbage_collect();
assert!(table.inner.is_empty());
let localhost = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
let short_time = Duration::from_millis(100);
table.add_entry(foo.clone(), localhost, short_time);
let wait_until = Instant::now() + (2 * short_time);
while Instant::now() < wait_until {
std::thread::sleep(short_time);
}
table.garbage_collect();
assert!(table.inner.is_empty());
}

View File

@@ -37,8 +37,6 @@ pub enum OperationalError {
ChannelFailure, ChannelFailure,
#[error("Error in initial handshake: {0}")] #[error("Error in initial handshake: {0}")]
KeyxProcessingError(#[from] SshKeyExchangeProcessingError), KeyxProcessingError(#[from] SshKeyExchangeProcessingError),
#[error("Call into random number generator failed: {0}")]
RngFailure(#[from] rand::Error),
#[error("Invalid port number '{port_string}': {error}")] #[error("Invalid port number '{port_string}': {error}")]
InvalidPort { InvalidPort {
port_string: String, port_string: String,

View File

@@ -44,10 +44,10 @@ where
/// write operations are cancel- and concurrency-safe. They take ownership /// write operations are cancel- and concurrency-safe. They take ownership
/// of the underlying stream once established, where "established" means /// of the underlying stream once established, where "established" means
/// that we have read and written the initial SSH banners. /// that we have read and written the initial SSH banners.
pub fn new(stream: Stream) -> SshChannel<Stream> { pub fn new(stream: Stream) -> Result<SshChannel<Stream>, getrandom::Error> {
let (read_half, write_half) = tokio::io::split(stream); let (read_half, write_half) = tokio::io::split(stream);
SshChannel { Ok(SshChannel {
read_side: Mutex::new(ReadSide { read_side: Mutex::new(ReadSide {
stream: read_half, stream: read_half,
buffer: BytesMut::with_capacity(MAX_BUFFER_SIZE), buffer: BytesMut::with_capacity(MAX_BUFFER_SIZE),
@@ -55,12 +55,12 @@ where
write_side: Mutex::new(WriteSide { write_side: Mutex::new(WriteSide {
stream: write_half, stream: write_half,
buffer: BytesMut::with_capacity(MAX_BUFFER_SIZE), buffer: BytesMut::with_capacity(MAX_BUFFER_SIZE),
rng: ChaCha20Rng::from_entropy(), rng: ChaCha20Rng::try_from_os_rng()?,
}), }),
mac_length: 0, mac_length: 0,
cipher_block_size: 8, cipher_block_size: 8,
channel_is_closed: false, channel_is_closed: false,
} })
} }
/// Read an SshPacket from the wire. /// Read an SshPacket from the wire.
@@ -155,7 +155,7 @@ where
encoded_packet.put(packet.buffer); encoded_packet.put(packet.buffer);
for _ in 0..rounded_padding { for _ in 0..rounded_padding {
encoded_packet.put_u8(rng.gen()); encoded_packet.put_u8(rng.random());
} }
Some(encoded_packet.freeze()) Some(encoded_packet.freeze())
@@ -247,8 +247,8 @@ proptest::proptest! {
fn can_read_back_anything(packet in SshPacket::arbitrary()) { fn can_read_back_anything(packet in SshPacket::arbitrary()) {
let result = tokio::runtime::Runtime::new().unwrap().block_on(async { let result = tokio::runtime::Runtime::new().unwrap().block_on(async {
let (left, right) = tokio::io::duplex(8192); let (left, right) = tokio::io::duplex(8192);
let leftssh = SshChannel::new(left); let leftssh = SshChannel::new(left).unwrap();
let rightssh = SshChannel::new(right); let rightssh = SshChannel::new(right).unwrap();
let packet_copy = packet.clone(); let packet_copy = packet.clone();
tokio::task::spawn(async move { tokio::task::spawn(async move {
@@ -264,8 +264,8 @@ proptest::proptest! {
fn sequences_send_correctly_serial(sequence in proptest::collection::vec(SendData::arbitrary(), 0..100)) { fn sequences_send_correctly_serial(sequence in proptest::collection::vec(SendData::arbitrary(), 0..100)) {
tokio::runtime::Runtime::new().unwrap().block_on(async { tokio::runtime::Runtime::new().unwrap().block_on(async {
let (left, right) = tokio::io::duplex(8192); let (left, right) = tokio::io::duplex(8192);
let leftssh = SshChannel::new(left); let leftssh = SshChannel::new(left).unwrap();
let rightssh = SshChannel::new(right); let rightssh = SshChannel::new(right).unwrap();
let sequence_left = sequence.clone(); let sequence_left = sequence.clone();
let sequence_right = sequence; let sequence_right = sequence;
@@ -335,9 +335,9 @@ proptest::proptest! {
tokio::runtime::Runtime::new().unwrap().block_on(async { tokio::runtime::Runtime::new().unwrap().block_on(async {
let (left, right) = tokio::io::duplex(8192); let (left, right) = tokio::io::duplex(8192);
let leftsshw = Arc::new(SshChannel::new(left)); let leftsshw = Arc::new(SshChannel::new(left).unwrap());
let leftsshr = leftsshw.clone(); let leftsshr = leftsshw.clone();
let rightsshw = Arc::new(SshChannel::new(right)); let rightsshw = Arc::new(SshChannel::new(right).unwrap());
let rightsshr = rightsshw.clone(); let rightsshr = rightsshw.clone();
let sequence_left_write = sequence.clone(); let sequence_left_write = sequence.clone();

View File

@@ -134,7 +134,7 @@ impl SshKeyExchange {
/// seed the message with a random cookie, but is otherwise deterministic. /// seed the message with a random cookie, but is otherwise deterministic.
/// It will fail only in the case that the underlying random number /// It will fail only in the case that the underlying random number
/// generator fails, and return exactly that error. /// generator fails, and return exactly that error.
pub fn new<R>(rng: &mut R, value: ClientConnectionOpts) -> Result<Self, rand::Error> pub fn new<R>(rng: &mut R, value: ClientConnectionOpts) -> Result<Self, ()>
where where
R: CryptoRng + Rng, R: CryptoRng + Rng,
{ {
@@ -185,7 +185,7 @@ impl SshKeyExchange {
first_kex_packet_follows: value.predict.is_some(), first_kex_packet_follows: value.predict.is_some(),
}; };
rng.try_fill(&mut result.cookie)?; rng.fill(&mut result.cookie);
Ok(result) Ok(result)
} }