Checkpoint with resolver tests including request/response.
This commit is contained in:
1076
Cargo.lock
generated
1076
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
49
Cargo.toml
49
Cargo.toml
@@ -15,42 +15,41 @@ unexpected_cfgs = { level = "warn", check-cfg = ['cfg(tarpaulin_include)'] }
|
|||||||
aes = { version = "0.8.4", features = ["zeroize"] }
|
aes = { version = "0.8.4", features = ["zeroize"] }
|
||||||
base64 = "0.22.1"
|
base64 = "0.22.1"
|
||||||
bcrypt-pbkdf = "0.10.0"
|
bcrypt-pbkdf = "0.10.0"
|
||||||
bytes = "1.6.0"
|
bytes = "1.10.1"
|
||||||
cipher = { version = "0.4.4", features = ["alloc", "block-padding", "rand_core", "std", "zeroize"] }
|
clap = { version = "4.5.35", features = ["derive"] }
|
||||||
clap = { version = "4.5.7", features = ["derive"] }
|
console-subscriber = "0.4.1"
|
||||||
console-subscriber = "0.3.0"
|
|
||||||
ctr = "0.9.2"
|
ctr = "0.9.2"
|
||||||
ed25519-dalek = "2.1.1"
|
ed25519-dalek = "2.1.1"
|
||||||
elliptic-curve = { version = "0.13.8", features = ["alloc", "digest", "ecdh", "pem", "pkcs8", "sec1", "serde", "std", "hash2curve", "voprf"] }
|
elliptic-curve = { version = "0.13.8", features = ["alloc", "digest", "ecdh", "pem", "pkcs8", "sec1", "serde", "std", "hash2curve", "voprf"] }
|
||||||
error-stack = "0.5.0"
|
error-stack = "0.5.0"
|
||||||
futures = "0.3.31"
|
futures = "0.3.31"
|
||||||
generic-array = "0.14.7"
|
generic-array = "0.14.7"
|
||||||
hexdump = "0.1.2"
|
getrandom = "0.3.2"
|
||||||
hickory-client = { version = "0.24.1", features = ["mdns"] }
|
internment = { version = "0.8.6", features = ["arc"] }
|
||||||
hickory-proto = "0.24.1"
|
itertools = "0.14.0"
|
||||||
itertools = "0.13.0"
|
nix = { version = "0.29.0", features = ["user"] }
|
||||||
moka = { version = "0.12.10", features = ["future"] }
|
|
||||||
nix = { version = "0.28.0", features = ["user"] }
|
|
||||||
num-bigint-dig = { version = "0.8.4", features = ["arbitrary", "i128", "zeroize", "prime", "rand"] }
|
num-bigint-dig = { version = "0.8.4", features = ["arbitrary", "i128", "zeroize", "prime", "rand"] }
|
||||||
num-integer = { version = "0.1.46", features = ["i128"] }
|
num-integer = { version = "0.1.46", features = ["i128"] }
|
||||||
num-traits = { version = "0.2.19", features = ["i128"] }
|
num-traits = { version = "0.2.19", features = ["i128"] }
|
||||||
num_enum = "0.7.2"
|
num_enum = "0.7.3"
|
||||||
p256 = { version = "0.13.2", features = ["ecdh", "ecdsa-core", "hash2curve", "serde", "test-vectors"] }
|
p256 = { version = "0.13.2", features = ["ecdh", "ecdsa-core", "hash2curve", "serde", "test-vectors"] }
|
||||||
p384 = { version = "0.13.0", features = ["ecdh", "ecdsa-core", "hash2curve", "serde", "test-vectors"] }
|
p384 = { version = "0.13.1", features = ["ecdh", "ecdsa-core", "hash2curve", "serde", "test-vectors"] }
|
||||||
p521 = { version = "0.13.3", features = ["ecdh", "ecdsa-core", "hash2curve", "serde", "test-vectors"] }
|
p521 = { version = "0.13.3", features = ["ecdh", "ecdsa-core", "hash2curve", "serde", "test-vectors"] }
|
||||||
proptest = "1.5.0"
|
proptest = "1.6.0"
|
||||||
rand = "0.8.5"
|
proptest-derive = "0.5.1"
|
||||||
rand_chacha = "0.3.1"
|
rand = "0.9.0"
|
||||||
rustix = "0.38.41"
|
rand_chacha = "0.9.0"
|
||||||
|
rustix = "1.0.5"
|
||||||
sec1 = "0.7.3"
|
sec1 = "0.7.3"
|
||||||
serde = { version = "1.0.203", features = ["derive"] }
|
serde = { version = "1.0.219", features = ["derive"] }
|
||||||
tempfile = "3.12.0"
|
tempfile = "3.19.1"
|
||||||
thiserror = "2.0.3"
|
thiserror = "2.0.12"
|
||||||
tokio = { version = "1.38.0", features = ["full", "tracing"] }
|
tokio = { version = "1.44.2", features = ["full", "tracing"] }
|
||||||
toml = "0.8.14"
|
toml = "0.8.20"
|
||||||
tracing = "0.1.40"
|
tracing = "0.1.41"
|
||||||
tracing-core = "0.1.32"
|
tracing-core = "0.1.33"
|
||||||
tracing-subscriber = { version = "0.3.18", features = ["env-filter", "tracing", "json"] }
|
tracing-subscriber = { version = "0.3.19", features = ["env-filter", "tracing", "json"] }
|
||||||
whoami = { version = "1.5.2", default-features = false }
|
url = "2.5.4"
|
||||||
|
whoami = { version = "1.6.0", default-features = false }
|
||||||
xdg = "2.5.2"
|
xdg = "2.5.2"
|
||||||
zeroize = "1.8.1"
|
zeroize = "1.8.1"
|
||||||
|
|||||||
@@ -159,7 +159,7 @@ async fn connect(
|
|||||||
"received server preamble"
|
"received server preamble"
|
||||||
);
|
);
|
||||||
|
|
||||||
let stream = ssh::SshChannel::new(stream);
|
let stream = ssh::SshChannel::new(stream).expect("can build new SSH channel");
|
||||||
let their_initial = stream
|
let their_initial = stream
|
||||||
.read()
|
.read()
|
||||||
.await
|
.await
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
|
use crate::network::resolver::name::Name;
|
||||||
use proptest::arbitrary::Arbitrary;
|
use proptest::arbitrary::Arbitrary;
|
||||||
use proptest::strategy::{BoxedStrategy, Just, Strategy};
|
use proptest::strategy::{BoxedStrategy, Just, Strategy};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
|
use std::net::SocketAddr;
|
||||||
|
use url::Url;
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Deserialize, Serialize)]
|
#[derive(Debug, PartialEq, Deserialize, Serialize)]
|
||||||
pub struct DnsConfig {
|
pub struct DnsConfig {
|
||||||
@@ -179,25 +181,80 @@ impl Arbitrary for BuiltinDnsOption {
|
|||||||
|
|
||||||
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
|
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
|
||||||
pub struct NameServerConfig {
|
pub struct NameServerConfig {
|
||||||
pub address: SocketAddr,
|
#[serde(serialize_with = "serialize_url", deserialize_with = "deserialize_url")]
|
||||||
|
pub address: Url,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub timeout_in_seconds: Option<u64>,
|
pub timeout_in_seconds: Option<u64>,
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
pub bind_address: Option<SocketAddr>,
|
pub bind_address: Option<SocketAddr>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn serialize_url<S: serde::Serializer>(url: &Url, serializer: S) -> Result<S::Ok, S::Error> {
|
||||||
|
serializer.collect_str(url)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn deserialize_url<'de, D: serde::Deserializer<'de>>(deserializer: D) -> Result<Url, D::Error> {
|
||||||
|
struct Visitor {}
|
||||||
|
|
||||||
|
impl<'de> serde::de::Visitor<'de> for Visitor {
|
||||||
|
type Value = Url;
|
||||||
|
|
||||||
|
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||||
|
write!(formatter, "a legal URL")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<Self::Value, E> {
|
||||||
|
Url::parse(v).map_err(|e| E::custom(e))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_borrowed_str<E: serde::de::Error>(self, v: &'de str) -> Result<Self::Value, E> {
|
||||||
|
Url::parse(v).map_err(|e| E::custom(e))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
deserializer.deserialize_str(Visitor {})
|
||||||
|
}
|
||||||
|
|
||||||
impl Arbitrary for NameServerConfig {
|
impl Arbitrary for NameServerConfig {
|
||||||
type Parameters = ();
|
type Parameters = ();
|
||||||
type Strategy = BoxedStrategy<Self>;
|
type Strategy = BoxedStrategy<Self>;
|
||||||
|
|
||||||
fn arbitrary_with((): Self::Parameters) -> Self::Strategy {
|
fn arbitrary_with((): Self::Parameters) -> Self::Strategy {
|
||||||
|
let scheme = proptest::prop_oneof![
|
||||||
|
Just("http".to_string()),
|
||||||
|
Just("https".to_string()),
|
||||||
|
Just("tcp".to_string()),
|
||||||
|
Just("udp".to_string()),
|
||||||
|
];
|
||||||
|
let domain = Name::arbitrary();
|
||||||
|
let port = proptest::option::of(u16::arbitrary());
|
||||||
|
let user = proptest::option::of(proptest::string::string_regex("[A-Za-z]{1,30}").unwrap());
|
||||||
|
let password =
|
||||||
|
proptest::option::of(proptest::string::string_regex(":[A-Za-z]{1,30}").unwrap());
|
||||||
|
let path =
|
||||||
|
proptest::option::of(proptest::string::string_regex("/[A-Za-z/]{1,30}").unwrap());
|
||||||
|
|
||||||
|
let uri_strategy = (scheme, domain, user, password, port, path).prop_map(
|
||||||
|
|(scheme, domain, user, password, port, path)| {
|
||||||
|
let userpass_prefix = match (user, password) {
|
||||||
|
(None, None) => String::new(),
|
||||||
|
(Some(u), None) => format!("{u}@"),
|
||||||
|
(None, Some(p)) => format!(":{p}@"),
|
||||||
|
(Some(u), Some(p)) => format!("{u}:{p}@"),
|
||||||
|
};
|
||||||
|
let path = path.unwrap_or_default();
|
||||||
|
let port = port.map(|x| format!(":{x}")).unwrap_or_default();
|
||||||
|
let uri_str = format!("{scheme}://{userpass_prefix}{domain}{port}{path}");
|
||||||
|
Url::parse(&uri_str).unwrap()
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
(
|
(
|
||||||
SocketAddr::arbitrary(),
|
uri_strategy,
|
||||||
proptest::option::of(u64::arbitrary()),
|
proptest::option::of(u64::arbitrary()),
|
||||||
proptest::option::of(SocketAddr::arbitrary()),
|
proptest::option::of(SocketAddr::arbitrary()),
|
||||||
)
|
)
|
||||||
.prop_map(|(mut address, mut timeout_in_seconds, mut bind_address)| {
|
.prop_map(|(address, mut timeout_in_seconds, mut bind_address)| {
|
||||||
clear_flow_and_scope_info(&mut address);
|
|
||||||
if let Some(bind_address) = bind_address.as_mut() {
|
if let Some(bind_address) = bind_address.as_mut() {
|
||||||
clear_flow_and_scope_info(bind_address);
|
clear_flow_and_scope_info(bind_address);
|
||||||
}
|
}
|
||||||
@@ -242,81 +299,56 @@ impl DnsConfig {
|
|||||||
None => {}
|
None => {}
|
||||||
Some(BuiltinDnsOption::Cloudflare) => {
|
Some(BuiltinDnsOption::Cloudflare) => {
|
||||||
results.push(NameServerConfig {
|
results.push(NameServerConfig {
|
||||||
address: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(1, 1, 1, 1), 53)),
|
address: Url::parse("udp://1.1.1.1:53").unwrap(),
|
||||||
timeout_in_seconds: None,
|
timeout_in_seconds: None,
|
||||||
bind_address: None,
|
bind_address: None,
|
||||||
});
|
});
|
||||||
results.push(NameServerConfig {
|
results.push(NameServerConfig {
|
||||||
address: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(1, 0, 0, 1), 53)),
|
address: Url::parse("udp://1.0.0.1").unwrap(),
|
||||||
timeout_in_seconds: None,
|
timeout_in_seconds: None,
|
||||||
bind_address: None,
|
bind_address: None,
|
||||||
});
|
});
|
||||||
results.push(NameServerConfig {
|
results.push(NameServerConfig {
|
||||||
address: SocketAddr::V6(SocketAddrV6::new(
|
address: Url::parse("udp://2606:4700:4700::1111").unwrap(),
|
||||||
Ipv6Addr::new(0x2606, 0x4700, 0x4700, 0, 0, 0, 0, 0x1111),
|
|
||||||
53,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
)),
|
|
||||||
timeout_in_seconds: None,
|
timeout_in_seconds: None,
|
||||||
bind_address: None,
|
bind_address: None,
|
||||||
});
|
});
|
||||||
results.push(NameServerConfig {
|
results.push(NameServerConfig {
|
||||||
address: SocketAddr::V6(SocketAddrV6::new(
|
address: Url::parse("udp://2606:4700:4700::1001:53").unwrap(),
|
||||||
Ipv6Addr::new(0x2606, 0x4700, 0x4700, 0, 0, 0, 0, 0x1001),
|
|
||||||
53,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
)),
|
|
||||||
timeout_in_seconds: None,
|
timeout_in_seconds: None,
|
||||||
bind_address: None,
|
bind_address: None,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
Some(BuiltinDnsOption::Google) => {
|
Some(BuiltinDnsOption::Google) => {
|
||||||
results.push(NameServerConfig {
|
results.push(NameServerConfig {
|
||||||
address: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(8, 8, 8, 8), 53)),
|
address: Url::parse("udp://8.8.8.8:53").unwrap(),
|
||||||
timeout_in_seconds: None,
|
timeout_in_seconds: None,
|
||||||
bind_address: None,
|
bind_address: None,
|
||||||
});
|
});
|
||||||
results.push(NameServerConfig {
|
results.push(NameServerConfig {
|
||||||
address: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(8, 8, 4, 4), 53)),
|
address: Url::parse("udp://8.8.4.4").unwrap(),
|
||||||
timeout_in_seconds: None,
|
timeout_in_seconds: None,
|
||||||
bind_address: None,
|
bind_address: None,
|
||||||
});
|
});
|
||||||
results.push(NameServerConfig {
|
results.push(NameServerConfig {
|
||||||
address: SocketAddr::V6(SocketAddrV6::new(
|
address: Url::parse("udp://2001:4860:4860::8888:53").unwrap(),
|
||||||
Ipv6Addr::new(0x2001, 0x4860, 0x4860, 0, 0, 0, 0, 0x8888),
|
|
||||||
53,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
)),
|
|
||||||
timeout_in_seconds: None,
|
timeout_in_seconds: None,
|
||||||
bind_address: None,
|
bind_address: None,
|
||||||
});
|
});
|
||||||
results.push(NameServerConfig {
|
results.push(NameServerConfig {
|
||||||
address: SocketAddr::V6(SocketAddrV6::new(
|
address: Url::parse("udp://2001:4860:4860::8844").unwrap(),
|
||||||
Ipv6Addr::new(0x2001, 0x4860, 0x4860, 0, 0, 0, 0, 0x8844),
|
|
||||||
53,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
)),
|
|
||||||
timeout_in_seconds: None,
|
timeout_in_seconds: None,
|
||||||
bind_address: None,
|
bind_address: None,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
Some(BuiltinDnsOption::Quad9) => {
|
Some(BuiltinDnsOption::Quad9) => {
|
||||||
results.push(NameServerConfig {
|
results.push(NameServerConfig {
|
||||||
address: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(9, 9, 9, 9), 53)),
|
address: Url::parse("udp://9.9.9.9:53").unwrap(),
|
||||||
timeout_in_seconds: None,
|
timeout_in_seconds: None,
|
||||||
bind_address: None,
|
bind_address: None,
|
||||||
});
|
});
|
||||||
results.push(NameServerConfig {
|
results.push(NameServerConfig {
|
||||||
address: SocketAddr::V6(SocketAddrV6::new(
|
address: Url::parse("udp://2620::00fe:00f3").unwrap(),
|
||||||
Ipv6Addr::new(0x2620, 0, 0, 0, 0, 0, 0xfe, 0xfe),
|
|
||||||
53,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
)),
|
|
||||||
timeout_in_seconds: None,
|
timeout_in_seconds: None,
|
||||||
bind_address: None,
|
bind_address: None,
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
|
use crate::network::resolver::name::Name;
|
||||||
use crate::network::resolver::{ResolveError, Resolver};
|
use crate::network::resolver::{ResolveError, Resolver};
|
||||||
use error_stack::{report, ResultExt};
|
use error_stack::{report, ResultExt};
|
||||||
use futures::stream::{FuturesUnordered, StreamExt};
|
use futures::stream::{FuturesUnordered, StreamExt};
|
||||||
use hickory_client::rr::Name;
|
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
use std::net::{AddrParseError, IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
|
use std::net::{AddrParseError, IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
|
||||||
@@ -63,7 +63,7 @@ impl FromStr for Host {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Ok(name) = Name::from_utf8(s) {
|
if let Ok(name) = Name::from_str(s) {
|
||||||
return Ok(Host::Hostname(name));
|
return Ok(Host::Hostname(name));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -76,11 +76,8 @@ impl FromStr for Host {
|
|||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
pub enum ConnectionError {
|
pub enum ConnectionError {
|
||||||
#[error("Failed to resolve host: {error}")]
|
#[error("Connection error: failed to resolve host")]
|
||||||
ResolveError {
|
ResolveError,
|
||||||
#[from]
|
|
||||||
error: ResolveError,
|
|
||||||
},
|
|
||||||
#[error("No valid IP addresses found")]
|
#[error("No valid IP addresses found")]
|
||||||
NoAddresses,
|
NoAddresses,
|
||||||
#[error("Error connecting to host: {error}")]
|
#[error("Error connecting to host: {error}")]
|
||||||
@@ -98,7 +95,10 @@ impl Host {
|
|||||||
/// are no relevant records for us to use for IPv4 or IPv6 connections. There
|
/// are no relevant records for us to use for IPv4 or IPv6 connections. There
|
||||||
/// is also no guarantee that the host will have both IPv4 and IPv6 addresses,
|
/// is also no guarantee that the host will have both IPv4 and IPv6 addresses,
|
||||||
/// so you may only see one or the other.
|
/// so you may only see one or the other.
|
||||||
pub async fn resolve(&self, resolver: &mut Resolver) -> Result<HashSet<IpAddr>, ResolveError> {
|
pub async fn resolve(
|
||||||
|
&self,
|
||||||
|
resolver: &mut Resolver,
|
||||||
|
) -> error_stack::Result<HashSet<IpAddr>, ResolveError> {
|
||||||
match self {
|
match self {
|
||||||
Host::IPv4(addr) => Ok(HashSet::from([IpAddr::V4(*addr)])),
|
Host::IPv4(addr) => Ok(HashSet::from([IpAddr::V4(*addr)])),
|
||||||
Host::IPv6(addr) => Ok(HashSet::from([IpAddr::V6(*addr)])),
|
Host::IPv6(addr) => Ok(HashSet::from([IpAddr::V6(*addr)])),
|
||||||
@@ -121,7 +121,7 @@ impl Host {
|
|||||||
let addresses = self
|
let addresses = self
|
||||||
.resolve(resolver)
|
.resolve(resolver)
|
||||||
.await
|
.await
|
||||||
.map_err(ConnectionError::from)
|
.change_context(ConnectionError::ResolveError)
|
||||||
.attach_printable_lazy(|| format!("target address {}", self))?;
|
.attach_printable_lazy(|| format!("target address {}", self))?;
|
||||||
|
|
||||||
let mut connectors = FuturesUnordered::new();
|
let mut connectors = FuturesUnordered::new();
|
||||||
|
|||||||
@@ -1,39 +1,27 @@
|
|||||||
|
pub mod name;
|
||||||
|
mod protocol;
|
||||||
|
mod resolution_table;
|
||||||
|
|
||||||
use crate::config::resolver::{DnsConfig, NameServerConfig};
|
use crate::config::resolver::{DnsConfig, NameServerConfig};
|
||||||
use error_stack::report;
|
use crate::network::resolver::name::Name;
|
||||||
use futures::future::select;
|
use crate::network::resolver::resolution_table::ResolutionTable;
|
||||||
use futures::stream::{SelectAll, StreamExt};
|
use error_stack::{report, ResultExt};
|
||||||
use hickory_client::rr::Name;
|
|
||||||
use hickory_proto::error::ProtoError;
|
|
||||||
use hickory_proto::op::message::Message;
|
|
||||||
use hickory_proto::op::query::Query;
|
|
||||||
use hickory_proto::rr::record_data::RData;
|
|
||||||
use hickory_proto::rr::resource::Record;
|
|
||||||
use hickory_proto::rr::RecordType;
|
|
||||||
use hickory_proto::udp::UdpClientStream;
|
|
||||||
use hickory_proto::xfer::dns_request::DnsRequest;
|
|
||||||
use hickory_proto::xfer::{DnsRequestSender, DnsResponseStream};
|
|
||||||
use std::collections::{HashMap, HashSet};
|
use std::collections::{HashMap, HashSet};
|
||||||
use std::net::{IpAddr, SocketAddr};
|
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
|
||||||
#[cfg(test)]
|
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokio::net::UdpSocket;
|
use tokio::net::{TcpSocket, UdpSocket};
|
||||||
use tokio::sync::{oneshot, Mutex};
|
use tokio::sync::Mutex;
|
||||||
use tokio::task::JoinSet;
|
use tokio::task::JoinSet;
|
||||||
use tokio::time::{timeout_at, Duration, Instant};
|
use tokio::time::{Duration, Instant};
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
pub enum ResolverConfigError {
|
pub enum ResolverConfigError {
|
||||||
#[error("Bad local domain name '{name}' provided: '{error}'")]
|
#[error("Bad local domain name provided")]
|
||||||
BadDomainName { name: String, error: ProtoError },
|
BadDomainName,
|
||||||
#[error("Bad domain name for search '{name}' provided: '{error}'")]
|
#[error("Couldn't create a DNS client for the given address, port, and protocol.")]
|
||||||
BadSearchName { name: String, error: ProtoError },
|
FailedToCreateDnsClient,
|
||||||
#[error("Couldn't set up client for server at '{address}': '{error}'")]
|
|
||||||
FailedToCreateDnsClient {
|
|
||||||
address: SocketAddr,
|
|
||||||
error: ProtoError,
|
|
||||||
},
|
|
||||||
#[error("No DNS servers found to search, and mDNS not enabled")]
|
#[error("No DNS servers found to search, and mDNS not enabled")]
|
||||||
NoHosts,
|
NoHosts,
|
||||||
}
|
}
|
||||||
@@ -44,8 +32,8 @@ pub enum ResolveError {
|
|||||||
NoServersAvailable,
|
NoServersAvailable,
|
||||||
#[error("No responses found for query")]
|
#[error("No responses found for query")]
|
||||||
NoResponses,
|
NoResponses,
|
||||||
#[error("Error reading response: {error}")]
|
#[error("Error reading response from server")]
|
||||||
ResponseError { error: ProtoError },
|
ResponseError,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct Resolver {
|
pub struct Resolver {
|
||||||
@@ -53,11 +41,13 @@ pub struct Resolver {
|
|||||||
max_time_to_wait_for_initial: Duration,
|
max_time_to_wait_for_initial: Duration,
|
||||||
time_to_wait_after_first: Duration,
|
time_to_wait_after_first: Duration,
|
||||||
time_to_wait_for_lingering: Duration,
|
time_to_wait_for_lingering: Duration,
|
||||||
state: Arc<Mutex<ResolverState>>,
|
connections: Arc<Mutex<Vec<(NameServerConfig, protocol::Client)>>>,
|
||||||
|
table: Arc<Mutex<ResolutionTable>>,
|
||||||
|
tasks: JoinSet<()>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct ResolverState {
|
pub struct ResolverState {
|
||||||
client_connections: Vec<(NameServerConfig, UdpClientStream<UdpSocket>)>,
|
client_connections: Vec<(NameServerConfig, protocol::Client)>,
|
||||||
cache: HashMap<Name, Vec<DnsResolution>>,
|
cache: HashMap<Name, Vec<DnsResolution>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -70,40 +60,151 @@ impl Resolver {
|
|||||||
/// Create a new DNS resolution engine for use by some part of the system.
|
/// Create a new DNS resolution engine for use by some part of the system.
|
||||||
pub async fn new(config: &DnsConfig) -> error_stack::Result<Self, ResolverConfigError> {
|
pub async fn new(config: &DnsConfig) -> error_stack::Result<Self, ResolverConfigError> {
|
||||||
let mut search_domains = Vec::new();
|
let mut search_domains = Vec::new();
|
||||||
|
let mut tasks = JoinSet::new();
|
||||||
|
|
||||||
if let Some(local) = config.local_domain.as_ref() {
|
if let Some(local) = config.local_domain.as_ref() {
|
||||||
search_domains.push(Name::from_utf8(local).map_err(|e| {
|
let name = Name::from_str(local)
|
||||||
report!(ResolverConfigError::BadDomainName {
|
.change_context(ResolverConfigError::BadDomainName)
|
||||||
name: local.clone(),
|
.attach_printable("Trying to add local domain.")
|
||||||
error: e,
|
.attach_printable_lazy(|| "Offending non-name: '{local}'")?;
|
||||||
})
|
|
||||||
})?);
|
search_domains.push(name);
|
||||||
}
|
}
|
||||||
|
|
||||||
for local_domain in config.search_domains.iter() {
|
for search_domain in config.search_domains.iter() {
|
||||||
search_domains.push(Name::from_utf8(local_domain).map_err(|e| {
|
let name = Name::from_str(search_domain)
|
||||||
report!(ResolverConfigError::BadSearchName {
|
.change_context(ResolverConfigError::BadDomainName)
|
||||||
name: local_domain.clone(),
|
.attach_printable("Trying to add search domain.")
|
||||||
error: e,
|
.attach_printable_lazy(|| "Offending non-name: '{search_domain}'")?;
|
||||||
})
|
|
||||||
})?);
|
search_domains.push(name);
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut client_connections = Vec::new();
|
let mut client_connections = Vec::new();
|
||||||
|
|
||||||
for name_server_config in config.name_servers() {
|
for target in config.name_servers.iter() {
|
||||||
let stream = UdpClientStream::with_bind_addr_and_timeout(
|
let port = target.address.port().unwrap_or(53);
|
||||||
name_server_config.address,
|
let Some(address) = target.address.host() else {
|
||||||
name_server_config.bind_address,
|
return Err(report!(ResolverConfigError::FailedToCreateDnsClient))
|
||||||
Duration::from_secs(name_server_config.timeout_in_seconds.unwrap_or(3)),
|
.attach_printable("No address to connect to?")
|
||||||
)
|
.attach_printable_lazy(|| format!("Target address was {}", target.address));
|
||||||
.await
|
};
|
||||||
.map_err(|error| ResolverConfigError::FailedToCreateDnsClient {
|
|
||||||
address: name_server_config.address,
|
let address = match address {
|
||||||
error,
|
url::Host::Ipv4(addr) => IpAddr::V4(addr),
|
||||||
|
url::Host::Ipv6(addr) => IpAddr::V6(addr),
|
||||||
|
url::Host::Domain(name) => {
|
||||||
|
return Err(report!(ResolverConfigError::FailedToCreateDnsClient))
|
||||||
|
.attach_printable("Cannot use domain names to identify domain servers")
|
||||||
|
.attach_printable_lazy(|| format!("Target address was {name}"));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let target_addr = SocketAddr::new(address, port);
|
||||||
|
|
||||||
|
match target.address.scheme() {
|
||||||
|
"tcp" => {
|
||||||
|
let socket = if target_addr.is_ipv4() {
|
||||||
|
TcpSocket::new_v4()
|
||||||
|
} else {
|
||||||
|
TcpSocket::new_v6()
|
||||||
|
};
|
||||||
|
|
||||||
|
let socket = socket
|
||||||
|
.change_context(ResolverConfigError::FailedToCreateDnsClient)
|
||||||
|
.attach_printable("Could not create a socket")
|
||||||
|
.attach_printable_lazy(|| {
|
||||||
|
format!("For target DNS server {}", target.address)
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
client_connections.push((name_server_config, stream));
|
if let Some(bind_address) = target.bind_address {
|
||||||
|
socket
|
||||||
|
.bind(bind_address)
|
||||||
|
.change_context(ResolverConfigError::FailedToCreateDnsClient)
|
||||||
|
.attach_printable("Could not bind local address for socket.")
|
||||||
|
.attach_printable_lazy(|| {
|
||||||
|
format!("Binding to TCP address {}", bind_address)
|
||||||
|
})
|
||||||
|
.attach_printable_lazy(|| {
|
||||||
|
format!("For target DNS server {}", target.address)
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let stream = socket
|
||||||
|
.connect(target_addr)
|
||||||
|
.await
|
||||||
|
.change_context(ResolverConfigError::FailedToCreateDnsClient)
|
||||||
|
.attach_printable_lazy(|| {
|
||||||
|
format!("Connecting to target {}", target_addr)
|
||||||
|
})?;
|
||||||
|
let client = protocol::Client::from_tcp(stream, &mut tasks)
|
||||||
|
.await
|
||||||
|
.change_context(ResolverConfigError::FailedToCreateDnsClient)
|
||||||
|
.attach_printable_lazy(|| {
|
||||||
|
format!("Connecting to target {}", target_addr)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
client_connections.push((target.clone(), client));
|
||||||
|
}
|
||||||
|
|
||||||
|
"udp" => {
|
||||||
|
let port = target.address.port().unwrap_or(53);
|
||||||
|
let Some(address) = target.address.host() else {
|
||||||
|
tracing::warn!(address = %target.address, "proposed domain server has no host");
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
let address = match address {
|
||||||
|
url::Host::Ipv4(addr) => IpAddr::V4(addr),
|
||||||
|
url::Host::Ipv6(addr) => IpAddr::V6(addr),
|
||||||
|
url::Host::Domain(name) => {
|
||||||
|
tracing::warn!(
|
||||||
|
address = %target.address,
|
||||||
|
hostname = name,
|
||||||
|
"currently, we can't use hostnames for domain servers"
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let sock_addr = SocketAddr::new(address, port);
|
||||||
|
let bind_address = target.bind_address.unwrap_or_else(|| {
|
||||||
|
if sock_addr.is_ipv4() {
|
||||||
|
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0)
|
||||||
|
} else {
|
||||||
|
SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0)
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let udp_socket = UdpSocket::bind(bind_address)
|
||||||
|
.await
|
||||||
|
.change_context(ResolverConfigError::FailedToCreateDnsClient)
|
||||||
|
.attach_printable("Generating UDP socket")
|
||||||
|
.attach_printable_lazy(|| format!("binding to address {bind_address}"))
|
||||||
|
.attach_printable_lazy(|| format!("targeting {}", target.address))?;
|
||||||
|
|
||||||
|
udp_socket
|
||||||
|
.connect(target_addr)
|
||||||
|
.await
|
||||||
|
.change_context(ResolverConfigError::FailedToCreateDnsClient)
|
||||||
|
.attach_printable("Connecting UDP socket")
|
||||||
|
.attach_printable_lazy(|| format!("binding to address {bind_address}"))
|
||||||
|
.attach_printable_lazy(|| format!("targeting {}", target.address))?;
|
||||||
|
|
||||||
|
let client = protocol::Client::from_udp(udp_socket, &mut tasks).await;
|
||||||
|
|
||||||
|
client_connections.push((target.clone(), client));
|
||||||
|
}
|
||||||
|
|
||||||
|
"unix" => unimplemented!(),
|
||||||
|
"unixd" => unimplemented!(),
|
||||||
|
|
||||||
|
"http" => unimplemented!(),
|
||||||
|
"https" => unimplemented!(),
|
||||||
|
|
||||||
|
_ => {
|
||||||
|
tracing::warn!(address = %target.address, "Unknown scheme for building DNS connections");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(Resolver {
|
Ok(Resolver {
|
||||||
@@ -112,321 +213,35 @@ impl Resolver {
|
|||||||
max_time_to_wait_for_initial: Duration::from_millis(150),
|
max_time_to_wait_for_initial: Duration::from_millis(150),
|
||||||
time_to_wait_after_first: Duration::from_millis(50),
|
time_to_wait_after_first: Duration::from_millis(50),
|
||||||
time_to_wait_for_lingering: Duration::from_secs(2),
|
time_to_wait_for_lingering: Duration::from_secs(2),
|
||||||
state: Arc::new(Mutex::new(ResolverState {
|
connections: Arc::new(Mutex::new(client_connections)),
|
||||||
client_connections,
|
table: Arc::new(Mutex::new(ResolutionTable::new())),
|
||||||
cache: HashMap::new(),
|
tasks,
|
||||||
})),
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Look up the address of the given name, returning either a set of results
|
pub async fn lookup(&self, _name: &Name) -> error_stack::Result<HashSet<IpAddr>, ResolveError> {
|
||||||
/// we've received in a reasonable amount of time, or an error.
|
unimplemented!()
|
||||||
pub async fn lookup(&self, name: &Name) -> Result<HashSet<IpAddr>, ResolveError> {
|
|
||||||
let names = self.expand_name(name);
|
|
||||||
|
|
||||||
let cached_values = self.cached_lookup(&names).await;
|
|
||||||
if !cached_values.is_empty() {
|
|
||||||
return Ok(cached_values);
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut response_stream = self.create_response_stream(&names).await;
|
|
||||||
if response_stream.is_empty() {
|
|
||||||
return Err(ResolveError::NoServersAvailable);
|
|
||||||
}
|
|
||||||
|
|
||||||
let get_first_by = Instant::now() + self.max_time_to_wait_for_initial;
|
|
||||||
let got_responses =
|
|
||||||
get_responses_by(&self.state, get_first_by, name, &mut response_stream).await?;
|
|
||||||
|
|
||||||
if got_responses {
|
|
||||||
let get_rest_by = Instant::now() + self.time_to_wait_after_first;
|
|
||||||
let _ = get_responses_by(&self.state, get_rest_by, name, &mut response_stream).await;
|
|
||||||
}
|
|
||||||
|
|
||||||
let lingering_deadline = Instant::now() + self.time_to_wait_for_lingering;
|
|
||||||
let state_copy = self.state.clone();
|
|
||||||
let name_copy = name.clone();
|
|
||||||
tokio::task::spawn(async move {
|
|
||||||
let _ = get_responses_by(
|
|
||||||
&state_copy,
|
|
||||||
lingering_deadline,
|
|
||||||
&name_copy,
|
|
||||||
&mut response_stream,
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
});
|
|
||||||
|
|
||||||
let after_search_lookup = self.cached_lookup(&names).await;
|
|
||||||
if after_search_lookup.is_empty() {
|
|
||||||
Err(ResolveError::NoResponses)
|
|
||||||
} else {
|
|
||||||
Ok(after_search_lookup)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Look up any cached IP address we might have for the given names.
|
|
||||||
///
|
|
||||||
/// As a side effect, this function will clean out any expired entries in the
|
|
||||||
/// cache.
|
|
||||||
async fn cached_lookup(&self, names: &[Name]) -> HashSet<IpAddr> {
|
|
||||||
let mut retval = HashSet::new();
|
|
||||||
let mut state = self.state.lock().await;
|
|
||||||
|
|
||||||
for name in names {
|
|
||||||
if let Some(mut existing) = state.cache.remove(name) {
|
|
||||||
let now = Instant::now();
|
|
||||||
|
|
||||||
existing.retain(|item| {
|
|
||||||
let keeper = item.expires_at > now;
|
|
||||||
|
|
||||||
if keeper {
|
|
||||||
retval.insert(item.address);
|
|
||||||
}
|
|
||||||
|
|
||||||
keeper
|
|
||||||
});
|
|
||||||
|
|
||||||
if !existing.is_empty() {
|
|
||||||
state.cache.insert(name.clone(), existing);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
retval
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Reach out to all of our current connections, and start the process of getting
|
|
||||||
/// responses.
|
|
||||||
///
|
|
||||||
/// If the clients are shut down, we'll try to create new connections to them, or
|
|
||||||
/// remove them from our connection list if we can't. If there are no connections,
|
|
||||||
/// the `SelectAll` will be empty, and callers are advised to check for this
|
|
||||||
/// condition and inform the user that they're out of useful DNS servers.
|
|
||||||
async fn create_response_stream(&self, names: &[Name]) -> SelectAll<DnsResponseStream> {
|
|
||||||
let mut state = self.state.lock().await;
|
|
||||||
let connections = std::mem::take(&mut state.client_connections);
|
|
||||||
let mut response_stream = futures::stream::SelectAll::new();
|
|
||||||
|
|
||||||
for (config, mut client) in connections.into_iter() {
|
|
||||||
if client.is_shutdown() {
|
|
||||||
let stream = UdpClientStream::with_bind_addr_and_timeout(
|
|
||||||
config.address,
|
|
||||||
config.bind_address,
|
|
||||||
Duration::from_secs(config.timeout_in_seconds.unwrap_or(3)),
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
match stream {
|
|
||||||
Ok(stream) => client = stream,
|
|
||||||
Err(_) => continue,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut message = Message::new();
|
|
||||||
|
|
||||||
for name in names.iter() {
|
|
||||||
message.add_query(Query::query(name.clone(), RecordType::A));
|
|
||||||
message.add_query(Query::query(name.clone(), RecordType::AAAA));
|
|
||||||
}
|
|
||||||
message.set_recursion_desired(true);
|
|
||||||
|
|
||||||
let request = DnsRequest::new(message, Default::default());
|
|
||||||
response_stream.push(client.send_message(request));
|
|
||||||
state.client_connections.push((config, client));
|
|
||||||
}
|
|
||||||
|
|
||||||
response_stream
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Expand the given input name into the complete set of names that we should search for.
|
|
||||||
fn expand_name(&self, name: &Name) -> Vec<Name> {
|
|
||||||
let mut names = vec![name.clone()];
|
|
||||||
|
|
||||||
for search_domain in self.search_domains.iter() {
|
|
||||||
if let Ok(combined) = name.clone().append_name(search_domain) {
|
|
||||||
names.push(combined);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
names
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Run a cleaner task on this Resolver object.
|
|
||||||
///
|
|
||||||
/// The task will run in the given JoinSet, set it can be easily killed, tracked, waited
|
|
||||||
/// for as desired. Every `period` interval, it will grab the state lock and clean out
|
|
||||||
/// any entries that are timed out. Ideally, you shouldn't set this too frequently, so
|
|
||||||
/// that this process doesn't interfere with other tasks trying to use the resolver.
|
|
||||||
///
|
|
||||||
/// This task will only weakly hang on to the resolver state, so that if the Resolver
|
|
||||||
/// object is dropped it will quietly stop running. It will also stop running if it sees
|
|
||||||
/// a message on the kill signal, if one is provided, or if one is provided and the other
|
|
||||||
/// end drops the sender. (So if you don't want a kill signal, send `None`, or you might
|
|
||||||
/// accidentally kill your GC task too early.)
|
|
||||||
pub async fn run_resolver_gc(
|
|
||||||
&self,
|
|
||||||
task_set: &mut JoinSet<Result<(), ()>>,
|
|
||||||
mut kill_signal: Option<oneshot::Receiver<()>>,
|
|
||||||
interval: Duration,
|
|
||||||
) {
|
|
||||||
let weak_locked_state = Arc::downgrade(&self.state);
|
|
||||||
|
|
||||||
task_set.spawn(async move {
|
|
||||||
loop {
|
|
||||||
tokio::time::sleep(interval).await;
|
|
||||||
|
|
||||||
let Some(locked_state) = weak_locked_state.upgrade() else {
|
|
||||||
break Ok(());
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut state = if let Some(kill_signal) = kill_signal.as_mut() {
|
|
||||||
let locker = locked_state.lock();
|
|
||||||
|
|
||||||
futures::pin_mut!(locker);
|
|
||||||
match select(locker, kill_signal).await {
|
|
||||||
futures::future::Either::Left((state, _)) => state,
|
|
||||||
futures::future::Either::Right((_, _)) => break Ok(()),
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
locked_state.lock().await
|
|
||||||
};
|
|
||||||
|
|
||||||
let now = Instant::now();
|
|
||||||
state.cache.retain(|_, value| {
|
|
||||||
value.retain(|resolution| resolution.expires_at < now);
|
|
||||||
!value.is_empty()
|
|
||||||
});
|
|
||||||
|
|
||||||
drop(state);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Inject a mapping into the resolver, with the given TTL.
|
|
||||||
///
|
|
||||||
/// This is largely used for testing purposes, but could be used if there are static
|
|
||||||
/// addresses you want to always resolve in a particular way. For that use case, you
|
|
||||||
/// probably want to add absurdly-long TTLs for those names.
|
|
||||||
pub async fn inject_resolution(&self, name: Name, address: IpAddr, ttl: Duration) {
|
|
||||||
let resolution = DnsResolution {
|
|
||||||
address,
|
|
||||||
expires_at: Instant::now() + ttl,
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut state = self.state.lock().await;
|
|
||||||
match state.cache.entry(name) {
|
|
||||||
std::collections::hash_map::Entry::Vacant(vac) => {
|
|
||||||
vac.insert(vec![resolution]);
|
|
||||||
}
|
|
||||||
std::collections::hash_map::Entry::Occupied(mut occ) => {
|
|
||||||
occ.get_mut().push(resolution);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
drop(state);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get all the responses into cache that occur before the given timestamp.
|
//#[tokio::test]
|
||||||
///
|
//async fn fetch_cached() {
|
||||||
/// Returns an error if all we get are errors while trying to read from these servers,
|
// let resolver = Resolver::new(&DnsConfig::empty()).await.unwrap();
|
||||||
/// otherwise returns Ok(true) if we received some interesting responses, or Ok(false)
|
// let name = Name::from_utf8("name.foo").unwrap();
|
||||||
/// if we didn't. Note that receiving good responses wins out over receiving errors, so
|
// let addr = IpAddr::from_str("1.2.4.5").unwrap();
|
||||||
/// if we receive 2 good responses and 3 errors, this function will return Ok(true).
|
//
|
||||||
async fn get_responses_by(
|
// resolver
|
||||||
state: &Arc<Mutex<ResolverState>>,
|
// .inject_resolution(name.clone(), addr.clone(), Duration::from_secs(100000))
|
||||||
deadline: Instant,
|
// .await;
|
||||||
name: &Name,
|
// let read = resolver.lookup(&name).await.unwrap();
|
||||||
stream: &mut SelectAll<DnsResponseStream>,
|
// assert!(read.contains(&addr));
|
||||||
) -> Result<bool, ResolveError> {
|
//}
|
||||||
let mut received_error = None;
|
|
||||||
let mut got_something = false;
|
|
||||||
|
|
||||||
loop {
|
//#[tokio::test]
|
||||||
match timeout_at(deadline, stream.next()).await {
|
//async fn uhsure() {
|
||||||
Err(_) => return received_error.unwrap_or(Ok(got_something)),
|
// let resolver = Resolver::new(&DnsConfig::default()).await.unwrap();
|
||||||
|
// let name = Name::from_ascii("uhsure.com").unwrap();
|
||||||
Ok(None) if got_something => return Ok(got_something),
|
// let result = resolver.lookup(&name).await.unwrap();
|
||||||
Ok(None) => return received_error.unwrap_or(Ok(false)),
|
// println!("result = {:?}", result);
|
||||||
|
// assert!(!result.is_empty());
|
||||||
Ok(Some(Err(error))) => {
|
//}
|
||||||
if received_error.is_none() {
|
|
||||||
received_error = Some(Err(ResolveError::ResponseError { error }));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(Some(Ok(response))) => {
|
|
||||||
for answer in response.into_message().take_answers() {
|
|
||||||
got_something |= handle_response(state, name, answer).await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Handle an individual response from a server.
|
|
||||||
///
|
|
||||||
/// Returns true if we got an answer from the server, and updates the internal cache.
|
|
||||||
/// This function will never return an error, although there may be some odd processing
|
|
||||||
/// situations that could lead it to printing warnings to the logs.
|
|
||||||
async fn handle_response(state: &Arc<Mutex<ResolverState>>, query: &Name, record: Record) -> bool {
|
|
||||||
let ttl = record.ttl();
|
|
||||||
|
|
||||||
let Some(rdata) = record.into_data() else {
|
|
||||||
tracing::error!("for some reason, couldn't process incoming message");
|
|
||||||
return false;
|
|
||||||
};
|
|
||||||
|
|
||||||
let address: IpAddr = match rdata {
|
|
||||||
RData::A(arec) => arec.0.into(),
|
|
||||||
RData::AAAA(arec) => arec.0.into(),
|
|
||||||
|
|
||||||
_ => {
|
|
||||||
tracing::warn!(record_type = %rdata.record_type(), "skipping unknown / unexpected record type");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let now = Instant::now();
|
|
||||||
let expires_at = now + Duration::from_secs(ttl as u64);
|
|
||||||
let resolution = DnsResolution {
|
|
||||||
address,
|
|
||||||
expires_at,
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut state = state.lock().await;
|
|
||||||
match state.cache.entry(query.clone()) {
|
|
||||||
std::collections::hash_map::Entry::Vacant(vec) => {
|
|
||||||
vec.insert(vec![resolution]);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::collections::hash_map::Entry::Occupied(mut occ) => {
|
|
||||||
let vec = occ.get_mut();
|
|
||||||
vec.retain(|x| x.expires_at > now);
|
|
||||||
vec.push(resolution);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
drop(state);
|
|
||||||
|
|
||||||
true
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn fetch_cached() {
|
|
||||||
let resolver = Resolver::new(&DnsConfig::empty()).await.unwrap();
|
|
||||||
let name = Name::from_utf8("name.foo").unwrap();
|
|
||||||
let addr = IpAddr::from_str("1.2.4.5").unwrap();
|
|
||||||
|
|
||||||
resolver
|
|
||||||
.inject_resolution(name.clone(), addr.clone(), Duration::from_secs(100000))
|
|
||||||
.await;
|
|
||||||
let read = resolver.lookup(&name).await.unwrap();
|
|
||||||
assert!(read.contains(&addr));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn uhsure() {
|
|
||||||
let resolver = Resolver::new(&DnsConfig::default()).await.unwrap();
|
|
||||||
let name = Name::from_ascii("uhsure.com").unwrap();
|
|
||||||
let result = resolver.lookup(&name).await.unwrap();
|
|
||||||
println!("result = {:?}", result);
|
|
||||||
assert!(!result.is_empty());
|
|
||||||
}
|
|
||||||
|
|||||||
505
src/network/resolver/name.rs
Normal file
505
src/network/resolver/name.rs
Normal file
@@ -0,0 +1,505 @@
|
|||||||
|
use bytes::{Buf, BufMut};
|
||||||
|
use error_stack::{report, ResultExt};
|
||||||
|
use internment::ArcIntern;
|
||||||
|
use proptest::arbitrary::Arbitrary;
|
||||||
|
use proptest::char::CharStrategy;
|
||||||
|
use proptest::strategy::{BoxedStrategy, Strategy};
|
||||||
|
use std::borrow::Cow;
|
||||||
|
use std::fmt;
|
||||||
|
use std::hash::Hash;
|
||||||
|
use std::ops::{Range, RangeInclusive};
|
||||||
|
use std::str::FromStr;
|
||||||
|
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
#[derive(Clone, Hash, PartialEq, Eq)]
|
||||||
|
pub struct Name {
|
||||||
|
labels: Vec<Label>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Debug for Name {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
let mut first_section = true;
|
||||||
|
|
||||||
|
for &Label(ref label) in self.labels.iter() {
|
||||||
|
if first_section {
|
||||||
|
first_section = false;
|
||||||
|
write!(f, "{}", label.as_str())?;
|
||||||
|
} else {
|
||||||
|
write!(f, ".{}", label.as_str())?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for Name {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
<Self as fmt::Debug>::fmt(self, f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Hash, Eq)]
|
||||||
|
pub struct Label(ArcIntern<String>);
|
||||||
|
|
||||||
|
impl fmt::Debug for Label {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
self.0.as_str().fmt(f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for Label {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
self.0.as_str().fmt(f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PartialEq for Label {
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
self.0.eq_ignore_ascii_case(other.0.as_str())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum NameParseError {
|
||||||
|
#[error("Provided name '{name}' is too long ({observed_length} bytes); maximum length of a DNS name is 255 octets.")]
|
||||||
|
NameTooLong {
|
||||||
|
name: String,
|
||||||
|
observed_length: usize,
|
||||||
|
},
|
||||||
|
|
||||||
|
#[error("Provided name '{name}' contains illegal character '{illegal_character}'.")]
|
||||||
|
NonAsciiCharacter {
|
||||||
|
name: String,
|
||||||
|
illegal_character: char,
|
||||||
|
},
|
||||||
|
|
||||||
|
#[error("Provided name '{name}' contains an empty section/label, which isn't allowed.")]
|
||||||
|
EmptyLabel { name: String },
|
||||||
|
|
||||||
|
#[error("Provided name '{name}' contains a label ('{label}') that is too long ({observed_length} letters, which is more than 63).")]
|
||||||
|
LabelTooLong {
|
||||||
|
observed_length: usize,
|
||||||
|
name: String,
|
||||||
|
label: String,
|
||||||
|
},
|
||||||
|
|
||||||
|
#[error("Provided name '{name}' contains a label ('{label}') that begins with an illegal character; it must be a letter.")]
|
||||||
|
LabelStartsWrong { name: String, label: String },
|
||||||
|
|
||||||
|
#[error("Provided name '{name}' contains a label ('{label}') that ends with an illegal character; it must be a letter or number.")]
|
||||||
|
LabelEndsWrong { name: String, label: String },
|
||||||
|
|
||||||
|
#[error("Provided name '{name}' contains a label ('{label}') that contains a non-letter, non-number, and non-dash.")]
|
||||||
|
IllegalInternalCharacter { name: String, label: String },
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromStr for Name {
|
||||||
|
type Err = NameParseError;
|
||||||
|
|
||||||
|
fn from_str(value: &str) -> Result<Self, Self::Err> {
|
||||||
|
let observed_length = value.as_bytes().len();
|
||||||
|
|
||||||
|
if observed_length > 255 {
|
||||||
|
return Err(NameParseError::NameTooLong {
|
||||||
|
name: value.to_string(),
|
||||||
|
observed_length,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
for char in value.chars() {
|
||||||
|
if !(char.is_ascii_alphanumeric() || char == '.' || char == '-') {
|
||||||
|
return Err(NameParseError::NonAsciiCharacter {
|
||||||
|
name: value.to_string(),
|
||||||
|
illegal_character: char,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut labels = vec![];
|
||||||
|
|
||||||
|
for label_str in value.split('.') {
|
||||||
|
if label_str.is_empty() {
|
||||||
|
return Err(NameParseError::EmptyLabel {
|
||||||
|
name: value.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if label_str.len() > 63 {
|
||||||
|
return Err(NameParseError::LabelTooLong {
|
||||||
|
name: value.to_string(),
|
||||||
|
label: label_str.to_string(),
|
||||||
|
observed_length: label_str.len(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let letter = |x| ('a'..='z').contains(&x) || ('A'..='Z').contains(&x);
|
||||||
|
let letter_or_num = |x| letter(x) || ('0'..='9').contains(&x);
|
||||||
|
let letter_num_dash = |x| letter_or_num(x) || (x == '-');
|
||||||
|
|
||||||
|
if !label_str.starts_with(letter) {
|
||||||
|
return Err(NameParseError::LabelStartsWrong {
|
||||||
|
name: value.to_string(),
|
||||||
|
label: label_str.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if !label_str.ends_with(letter_or_num) {
|
||||||
|
return Err(NameParseError::LabelEndsWrong {
|
||||||
|
name: value.to_string(),
|
||||||
|
label: label_str.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if label_str.contains(|x| !letter_num_dash(x)) {
|
||||||
|
return Err(NameParseError::IllegalInternalCharacter {
|
||||||
|
name: value.to_string(),
|
||||||
|
label: label_str.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// RFC 1035 says that all domain names are case-insensitive. We
|
||||||
|
// arbitrarily normalize to lowercase here, because it shouldn't
|
||||||
|
// matter to anyone.
|
||||||
|
labels.push(Label(ArcIntern::new(label_str.to_string())));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Name { labels })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum NameReadError {
|
||||||
|
#[error("Could not read name field out of an empty buffer.")]
|
||||||
|
EmptyBuffer,
|
||||||
|
|
||||||
|
#[error("Buffer truncated before we could read the last label")]
|
||||||
|
TruncatedBuffer,
|
||||||
|
|
||||||
|
#[error("Provided label value too long; must be 63 octets or less.")]
|
||||||
|
LabelTooLong,
|
||||||
|
|
||||||
|
#[error("Label truncated while reading. Broken stream?")]
|
||||||
|
LabelTruncated,
|
||||||
|
|
||||||
|
#[error("Label starts with an illegal character (must be [A-Za-z])")]
|
||||||
|
WrongFirstByte,
|
||||||
|
|
||||||
|
#[error("Label ends with an illegal character (must be [A-Za-z0-9]")]
|
||||||
|
WrongLastByte,
|
||||||
|
|
||||||
|
#[error("Label contains an illegal character (must be [A-Za-z0-9] or a dash)")]
|
||||||
|
WrongInnerByte,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Name {
|
||||||
|
/// Read a name frm a record, or return an error describing what went wrong.
|
||||||
|
///
|
||||||
|
/// This function will advance the read pointer on the record. If the result is
|
||||||
|
/// an error, the read pointer is not guaranteed to be in a good place, so you
|
||||||
|
/// may need to manage that externally if you want to implement some sort of
|
||||||
|
/// "try" functionality.
|
||||||
|
pub fn read<B: Buf>(buffer: &mut B) -> error_stack::Result<Name, NameReadError> {
|
||||||
|
let mut labels = Vec::new();
|
||||||
|
let name_so_far = |labels: Vec<Label>| {
|
||||||
|
let mut result = String::new();
|
||||||
|
|
||||||
|
for Label(label) in labels.into_iter() {
|
||||||
|
if result.is_empty() {
|
||||||
|
result.push_str(label.as_str());
|
||||||
|
} else {
|
||||||
|
result.push('.');
|
||||||
|
result.push_str(label.as_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result
|
||||||
|
};
|
||||||
|
|
||||||
|
loop {
|
||||||
|
if !buffer.has_remaining() && labels.is_empty() {
|
||||||
|
return Err(report!(NameReadError::EmptyBuffer));
|
||||||
|
}
|
||||||
|
|
||||||
|
if !buffer.has_remaining() {
|
||||||
|
return Err(report!(NameReadError::TruncatedBuffer))
|
||||||
|
.attach_printable_lazy(|| format!("name thus far: '{}'", name_so_far(labels)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let label_octet_length = buffer.get_u8() as usize;
|
||||||
|
if label_octet_length == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if label_octet_length > 63 {
|
||||||
|
return Err(report!(NameReadError::LabelTooLong)).attach_printable_lazy(|| {
|
||||||
|
format!(
|
||||||
|
"label length too big; max is supposed to be 63, saw {label_octet_length}"
|
||||||
|
)
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if !buffer.remaining() < label_octet_length {
|
||||||
|
let remaining = buffer.copy_to_bytes(buffer.remaining());
|
||||||
|
let partial = String::from_utf8_lossy(&remaining).to_string();
|
||||||
|
|
||||||
|
return Err(report!(NameReadError::LabelTruncated))
|
||||||
|
.attach_printable_lazy(|| format!("name thus far: '{}'", name_so_far(labels)))
|
||||||
|
.attach_printable_lazy(|| {
|
||||||
|
format!(
|
||||||
|
"expected {} octets, but only found {}",
|
||||||
|
label_octet_length,
|
||||||
|
buffer.remaining()
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.attach_printable_lazy(|| format!("partial read: '{partial}'"));
|
||||||
|
}
|
||||||
|
|
||||||
|
let label_bytes = buffer.copy_to_bytes(label_octet_length);
|
||||||
|
let Some(first_byte) = label_bytes.first() else {
|
||||||
|
panic!(
|
||||||
|
"INTERNAL ERROR: Should have at least one byte, we checked this previously."
|
||||||
|
);
|
||||||
|
};
|
||||||
|
let Some(last_byte) = label_bytes.last() else {
|
||||||
|
panic!(
|
||||||
|
"INTERNAL ERROR: Should have at least one byte, we checked this previously."
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
let letter = |x| (b'a'..=b'z').contains(&x) || (b'A'..=b'Z').contains(&x);
|
||||||
|
let letter_or_num = |x| letter(x) || (b'0'..=b'9').contains(&x);
|
||||||
|
let letter_num_dash = |x| letter_or_num(x) || (x == b'-');
|
||||||
|
|
||||||
|
if !letter(*first_byte) {
|
||||||
|
return Err(report!(NameReadError::WrongFirstByte))
|
||||||
|
.attach_printable_lazy(|| format!("name thus far: '{}'", name_so_far(labels)))
|
||||||
|
.attach_printable_lazy(|| {
|
||||||
|
format!("bad label: '{}'", String::from_utf8_lossy(&label_bytes))
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if !letter_or_num(*last_byte) {
|
||||||
|
return Err(report!(NameReadError::WrongLastByte))
|
||||||
|
.attach_printable_lazy(|| format!("name thus far: '{}'", name_so_far(labels)))
|
||||||
|
.attach_printable_lazy(|| {
|
||||||
|
format!("bad label: '{}'", String::from_utf8_lossy(&label_bytes))
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if label_bytes.iter().any(|x| !letter_num_dash(*x)) {
|
||||||
|
return Err(report!(NameReadError::WrongInnerByte))
|
||||||
|
.attach_printable_lazy(|| format!("name thus far: '{}'", name_so_far(labels)))
|
||||||
|
.attach_printable_lazy(|| {
|
||||||
|
format!("bad label: '{}'", String::from_utf8_lossy(&label_bytes))
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let label = label_bytes.into_iter().map(|x| x as char).collect();
|
||||||
|
labels.push(Label(ArcIntern::new(label)));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Name { labels })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Write a name out to the given buffer.
|
||||||
|
///
|
||||||
|
/// This will try as hard as it can to write the value to the given buffer. If an
|
||||||
|
/// error occurs, you may end up with partially-written data, so if you're worried
|
||||||
|
/// about that you should be careful to mark where you started in the output buffer.
|
||||||
|
pub fn write<B: BufMut>(&self, buffer: &mut B) -> error_stack::Result<(), NameWriteError> {
|
||||||
|
for &Label(ref label) in self.labels.iter() {
|
||||||
|
let bytes = label.as_bytes();
|
||||||
|
|
||||||
|
if buffer.remaining_mut() < (bytes.len() + 1) {
|
||||||
|
return Err(report!(NameWriteError::NoRoomForLabel))
|
||||||
|
.attach_printable_lazy(|| format!("Writing name {self}"))
|
||||||
|
.attach_printable_lazy(|| format!("For label {label}"));
|
||||||
|
}
|
||||||
|
|
||||||
|
if bytes.is_empty() || bytes.len() > 63 {
|
||||||
|
return Err(report!(NameWriteError::IllegalLabel))
|
||||||
|
.attach_printable_lazy(|| format!("Writing name {self}"))
|
||||||
|
.attach_printable_lazy(|| format!("For label {label}"));
|
||||||
|
}
|
||||||
|
|
||||||
|
buffer.put_u8(bytes.len() as u8);
|
||||||
|
buffer.put_slice(bytes);
|
||||||
|
}
|
||||||
|
|
||||||
|
if !buffer.has_remaining_mut() {
|
||||||
|
return Err(report!(NameWriteError::NoRoomForNull))
|
||||||
|
.attach_printable_lazy(|| format!("Writing name {self}"));
|
||||||
|
}
|
||||||
|
buffer.put_u8(0);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum NameWriteError {
|
||||||
|
#[error("Not enough room to write label in name")]
|
||||||
|
NoRoomForLabel,
|
||||||
|
|
||||||
|
#[error("Ran out of room writing the terminating NULL for the name")]
|
||||||
|
NoRoomForNull,
|
||||||
|
|
||||||
|
#[error("Internal error: Illegal label (this shouldn't happen)")]
|
||||||
|
IllegalLabel,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ArbitraryDomainNameSpecifications {
|
||||||
|
total_length_range: Range<usize>,
|
||||||
|
label_length_range: Range<usize>,
|
||||||
|
number_labels: Range<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for ArbitraryDomainNameSpecifications {
|
||||||
|
fn default() -> Self {
|
||||||
|
ArbitraryDomainNameSpecifications {
|
||||||
|
total_length_range: 5..256,
|
||||||
|
label_length_range: 1..64,
|
||||||
|
number_labels: 2..5,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Arbitrary for Name {
|
||||||
|
type Parameters = ArbitraryDomainNameSpecifications;
|
||||||
|
type Strategy = BoxedStrategy<Name>;
|
||||||
|
|
||||||
|
fn arbitrary_with(spec: Self::Parameters) -> Self::Strategy {
|
||||||
|
(spec.number_labels.clone(), spec.total_length_range.clone())
|
||||||
|
.prop_flat_map(move |(mut labels, mut total_length)| {
|
||||||
|
// we need to make sure that our total length and our label count are,
|
||||||
|
// at the very minimum, compatible. If they're not, we'll need to adjust
|
||||||
|
// them.
|
||||||
|
//
|
||||||
|
// in general, we prefer to update our number of labels, if we can, but
|
||||||
|
// only to the extent that the labels count is at least the minimum of
|
||||||
|
// our input specification. if we try to go below that, we increase total
|
||||||
|
// length as required. If we're forced into a place where we need to take
|
||||||
|
// the label length below it's minimum and/or the total_length over its
|
||||||
|
// maximum, we just give up and panic.
|
||||||
|
//
|
||||||
|
// Note that the minimum length of n labels is (n * label_minimum) + n - 1.
|
||||||
|
// Consider, for example, a label_minimum of 1 and a label length of 3.
|
||||||
|
// A minimum string is "a.b.c", which is (3 * 1) + 3 - 1 = 5 characters
|
||||||
|
// long.
|
||||||
|
//
|
||||||
|
// This loop does the first part, lowering the number of labels until we
|
||||||
|
// either get below the total length or reach the minimum number of labels.
|
||||||
|
while (labels * spec.label_length_range.start) + (labels - 1) > total_length {
|
||||||
|
if labels == spec.number_labels.start {
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
labels -= 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// At this point, if it's not right, we just set it to be right.
|
||||||
|
if (labels * spec.total_length_range.start) + (labels - 1) > total_length {
|
||||||
|
total_length = (labels * spec.total_length_range.start) + (labels - 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// And if this takes us over our limit, just panic.
|
||||||
|
if total_length >= spec.total_length_range.end {
|
||||||
|
panic!("Unresolvable generation condition; couldn't resolve label count {} with total_length {}, with specification {:?}", labels, total_length, spec);
|
||||||
|
}
|
||||||
|
|
||||||
|
proptest::collection::vec(Label::arbitrary_with(spec.label_length_range.clone()), labels)
|
||||||
|
.prop_map(|labels| Name{ labels })
|
||||||
|
}).boxed()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Arbitrary for Label {
|
||||||
|
type Parameters = Range<usize>;
|
||||||
|
type Strategy = BoxedStrategy<Self>;
|
||||||
|
|
||||||
|
fn arbitrary() -> Self::Strategy {
|
||||||
|
Self::arbitrary_with(1..64)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
|
||||||
|
args.prop_flat_map(|length| {
|
||||||
|
let first = char_selector(&['a'..='z', 'A'..='Z']);
|
||||||
|
let middle = char_selector(&['a'..='z', 'A'..='Z', '-'..='-', '0'..='9']);
|
||||||
|
let last = char_selector(&['a'..='z', 'A'..='Z', '0'..='9']);
|
||||||
|
|
||||||
|
match length {
|
||||||
|
0 => panic!("Should not be able to generate a label of length 0"),
|
||||||
|
1 => first.prop_map(|x| x.into()).boxed(),
|
||||||
|
2 => (first, last).prop_map(|(a, b)| format!("{a}{b}")).boxed(),
|
||||||
|
_ => (first, proptest::collection::vec(middle, length - 2), last)
|
||||||
|
.prop_map(move |(first, middle, last)| {
|
||||||
|
let mut result = String::with_capacity(length);
|
||||||
|
result.push(first);
|
||||||
|
for c in middle.iter() {
|
||||||
|
result.push(*c);
|
||||||
|
}
|
||||||
|
result.push(last);
|
||||||
|
result
|
||||||
|
})
|
||||||
|
.boxed(),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.prop_map(|x| Label(ArcIntern::new(x)))
|
||||||
|
.boxed()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn char_selector<'a>(ranges: &'a [RangeInclusive<char>]) -> CharStrategy<'a> {
|
||||||
|
CharStrategy::new(
|
||||||
|
Cow::Borrowed(&[]),
|
||||||
|
Cow::Borrowed(&[]),
|
||||||
|
Cow::Borrowed(ranges),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
proptest::proptest! {
|
||||||
|
#[test]
|
||||||
|
fn any_random_names_parses(name: Name) {
|
||||||
|
let name_str = name.to_string();
|
||||||
|
let new_name = Name::from_str(&name_str).expect("can re-parse name");
|
||||||
|
assert_eq!(name, new_name);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn any_random_name_roundtrips(name: Name) {
|
||||||
|
let mut write_buffer = bytes::BytesMut::with_capacity(512);
|
||||||
|
name.write(&mut write_buffer).expect("can write name");
|
||||||
|
let mut read_buffer = write_buffer.freeze();
|
||||||
|
let new_name = Name::read(&mut read_buffer).expect("can read name");
|
||||||
|
assert_eq!(name, new_name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn illegal_names_generate_errors() {
|
||||||
|
assert!(Name::from_str("").is_err());
|
||||||
|
assert!(Name::from_str(".").is_err());
|
||||||
|
assert!(Name::from_str(".com").is_err());
|
||||||
|
assert!(Name::from_str("com.").is_err());
|
||||||
|
assert!(Name::from_str("9.com").is_err());
|
||||||
|
assert!(Name::from_str("foo-.com").is_err());
|
||||||
|
assert!(Name::from_str("-foo.com").is_err());
|
||||||
|
assert!(Name::from_str(".foo.com").is_err());
|
||||||
|
assert!(Name::from_str("fo*o.com").is_err());
|
||||||
|
assert!(Name::from_str(
|
||||||
|
"foo.abcdefghiabcdefghiabcdefghiabcdefghiabcdefghiabcdefghijjjjjjabcdefghij.com"
|
||||||
|
)
|
||||||
|
.is_err());
|
||||||
|
assert!(Name::from_str("abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij").is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn names_ignore_case() {
|
||||||
|
assert_eq!(
|
||||||
|
Name::from_str("UHSURE.COM").unwrap(),
|
||||||
|
Name::from_str("uhsure.com").unwrap()
|
||||||
|
);
|
||||||
|
}
|
||||||
9
src/network/resolver/protocol.rs
Normal file
9
src/network/resolver/protocol.rs
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
mod client;
|
||||||
|
mod header;
|
||||||
|
pub mod question;
|
||||||
|
mod request;
|
||||||
|
mod resource_record;
|
||||||
|
mod response;
|
||||||
|
mod server;
|
||||||
|
|
||||||
|
pub use client::Client;
|
||||||
273
src/network/resolver/protocol/client.rs
Normal file
273
src/network/resolver/protocol/client.rs
Normal file
@@ -0,0 +1,273 @@
|
|||||||
|
use crate::network::resolver::protocol::header::Header;
|
||||||
|
use crate::network::resolver::protocol::question::Question;
|
||||||
|
use bytes::{Bytes, BytesMut};
|
||||||
|
use error_stack::ResultExt;
|
||||||
|
use std::net::SocketAddr;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use thiserror::Error;
|
||||||
|
use tokio::io::{AsyncReadExt, WriteHalf};
|
||||||
|
use tokio::net::{TcpStream, UdpSocket, UnixDatagram, UnixStream};
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
use tokio::task::JoinSet;
|
||||||
|
|
||||||
|
type Callback = fn() -> ();
|
||||||
|
|
||||||
|
pub struct Client {
|
||||||
|
callback: Arc<RwLock<Callback>>,
|
||||||
|
channel: DnsChannel,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum DnsChannel {
|
||||||
|
Tcp(WriteHalf<TcpStream>),
|
||||||
|
Udp(Arc<UdpSocket>),
|
||||||
|
UnixData(Arc<UnixDatagram>),
|
||||||
|
Unix(WriteHalf<UnixStream>),
|
||||||
|
}
|
||||||
|
|
||||||
|
fn empty_callback() {}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum SendError {}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
enum GeneralAddr {
|
||||||
|
Network(SocketAddr),
|
||||||
|
Unix(std::os::unix::net::SocketAddr),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<SocketAddr> for GeneralAddr {
|
||||||
|
fn from(value: SocketAddr) -> Self {
|
||||||
|
GeneralAddr::Network(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<tokio::net::unix::SocketAddr> for GeneralAddr {
|
||||||
|
fn from(value: tokio::net::unix::SocketAddr) -> Self {
|
||||||
|
GeneralAddr::Unix(value.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
enum ServerProcessorError {
|
||||||
|
#[error("Could not read message header from server.")]
|
||||||
|
CouldNotReadHeader,
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn process_server_response(
|
||||||
|
callback: &Arc<RwLock<Callback>>,
|
||||||
|
mut bytes: Bytes,
|
||||||
|
source: GeneralAddr,
|
||||||
|
) -> error_stack::Result<(), ServerProcessorError> {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn run_response_processing_loop(
|
||||||
|
server: String,
|
||||||
|
maximum_consecurive_errors: u64,
|
||||||
|
callback: Arc<RwLock<Callback>>,
|
||||||
|
mut fetcher: impl AsyncFnMut() -> Result<(Bytes, GeneralAddr), std::io::Error>,
|
||||||
|
) {
|
||||||
|
let mut consecutive_errors = 0;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
match fetcher().await {
|
||||||
|
Ok((bytes, source)) => {
|
||||||
|
if let Err(e) = process_server_response(&callback, bytes, source).await {
|
||||||
|
consecutive_errors += 1;
|
||||||
|
tracing::warn!(
|
||||||
|
server,
|
||||||
|
error = %e,
|
||||||
|
maximum_consecurive_errors,
|
||||||
|
consecutive_errors,
|
||||||
|
"error processing DNS server response"
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
consecutive_errors = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(e) => {
|
||||||
|
consecutive_errors += 1;
|
||||||
|
tracing::warn!(
|
||||||
|
server,
|
||||||
|
error = %e,
|
||||||
|
maximum_consecurive_errors,
|
||||||
|
consecutive_errors,
|
||||||
|
"failed to read response from DNS server"
|
||||||
|
);
|
||||||
|
|
||||||
|
if consecutive_errors >= maximum_consecurive_errors {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::error!(
|
||||||
|
server,
|
||||||
|
"quitting DNS response processing loop due to too many consecutive errors"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Client {
|
||||||
|
/// Create a new DNS client from the given, targeted UnixDatagram socket.
|
||||||
|
///
|
||||||
|
/// By "targeted", we mean that this socket should've had `connect` called on
|
||||||
|
/// it, so that the client can just send datagrams to the server without having
|
||||||
|
/// to know where to send them. It is unspecified what will happen to DNS clients
|
||||||
|
/// if you do not call `connect` beforehand.
|
||||||
|
pub async fn from_unix_datagram(socket: UnixDatagram, group: &mut JoinSet<()>) -> Self {
|
||||||
|
let callback = Arc::new(RwLock::new(empty_callback as Callback));
|
||||||
|
let socket = Arc::new(socket);
|
||||||
|
|
||||||
|
let reader_socket = socket.clone();
|
||||||
|
let reader_callback = callback.clone();
|
||||||
|
let server_addr = socket.local_addr()
|
||||||
|
.ok()
|
||||||
|
.map(|x| x.as_pathname().map(|x| x.display().to_string()))
|
||||||
|
.flatten()
|
||||||
|
.unwrap_or_else(|| "<unknown>".into());
|
||||||
|
let server = format!("unixd://{}", server_addr);
|
||||||
|
|
||||||
|
group.spawn(async move {
|
||||||
|
let fetcher = async move || {
|
||||||
|
let mut buffer = BytesMut::with_capacity(16384);
|
||||||
|
|
||||||
|
match reader_socket.recv_buf_from(&mut buffer).await {
|
||||||
|
Err(x) => Err(x),
|
||||||
|
Ok((size, from)) => {
|
||||||
|
unsafe {
|
||||||
|
buffer.set_len(size);
|
||||||
|
};
|
||||||
|
Ok((buffer.freeze(), from.into()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
run_response_processing_loop(server, 5, reader_callback, fetcher).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
Client {
|
||||||
|
callback,
|
||||||
|
channel: DnsChannel::UnixData(socket),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn from_unix_stream(
|
||||||
|
socket: UnixStream,
|
||||||
|
group: &mut JoinSet<()>,
|
||||||
|
) -> std::io::Result<Self> {
|
||||||
|
let callback = Arc::new(RwLock::new(empty_callback as Callback));
|
||||||
|
let server_addr = socket.local_addr()
|
||||||
|
.ok()
|
||||||
|
.map(|x| x.as_pathname().map(|x| x.display().to_string()))
|
||||||
|
.flatten()
|
||||||
|
.unwrap_or_else(|| "<unknown>".into());
|
||||||
|
let server = format!("unix://{}", server_addr);
|
||||||
|
let other_side = GeneralAddr::Unix(socket.peer_addr()?.into());
|
||||||
|
let (mut reader, writer) = tokio::io::split(socket);
|
||||||
|
|
||||||
|
let reader_callback = callback.clone();
|
||||||
|
group.spawn(async move {
|
||||||
|
let fetcher = async move || {
|
||||||
|
let size = reader.read_u16().await?;
|
||||||
|
|
||||||
|
let mut buffer = vec![0u8; size as usize];
|
||||||
|
reader.read_exact(&mut buffer).await?;
|
||||||
|
|
||||||
|
Ok((Bytes::from(buffer), other_side.clone()))
|
||||||
|
};
|
||||||
|
|
||||||
|
run_response_processing_loop(server, 5, reader_callback, fetcher).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(Client {
|
||||||
|
callback,
|
||||||
|
channel: DnsChannel::Unix(writer),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn from_udp(socket: UdpSocket, group: &mut JoinSet<()>) -> Self {
|
||||||
|
let callback = Arc::new(RwLock::new(empty_callback as Callback));
|
||||||
|
let socket = Arc::new(socket);
|
||||||
|
|
||||||
|
let reader_callback = callback.clone();
|
||||||
|
let reader_socket = socket.clone();
|
||||||
|
let server_addr = socket.local_addr()
|
||||||
|
.ok()
|
||||||
|
.map(|x| x.to_string())
|
||||||
|
.unwrap_or_else(|| "<unknown>".into());
|
||||||
|
let server = format!("udp://{}", server_addr);
|
||||||
|
|
||||||
|
group.spawn(async move {
|
||||||
|
let fetcher = async move || {
|
||||||
|
let mut buffer = BytesMut::with_capacity(16384);
|
||||||
|
|
||||||
|
match reader_socket.recv_buf_from(&mut buffer).await {
|
||||||
|
Err(x) => Err(x),
|
||||||
|
Ok((size, from)) => {
|
||||||
|
unsafe {
|
||||||
|
buffer.set_len(size);
|
||||||
|
};
|
||||||
|
Ok((buffer.freeze(), from.into()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
run_response_processing_loop(server, 5, reader_callback, fetcher).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
Client {
|
||||||
|
callback,
|
||||||
|
channel: DnsChannel::Udp(socket),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn from_tcp(socket: TcpStream, group: &mut JoinSet<()>) -> std::io::Result<Self> {
|
||||||
|
let callback = Arc::new(RwLock::new(empty_callback as Callback));
|
||||||
|
let other_side = GeneralAddr::Network(socket.peer_addr()?);
|
||||||
|
let server_addr = socket.local_addr()
|
||||||
|
.ok()
|
||||||
|
.map(|x| x.to_string())
|
||||||
|
.unwrap_or_else(|| "<unknown>".into());
|
||||||
|
let server = format!("tcp://{}", server_addr);
|
||||||
|
let (mut reader, writer) = tokio::io::split(socket);
|
||||||
|
|
||||||
|
let reader_callback = callback.clone();
|
||||||
|
group.spawn(async move {
|
||||||
|
let fetcher = async move || {
|
||||||
|
let size = reader.read_u16().await?;
|
||||||
|
|
||||||
|
let mut buffer = vec![0u8; size as usize];
|
||||||
|
reader.read_exact(&mut buffer).await?;
|
||||||
|
|
||||||
|
Ok((Bytes::from(buffer), other_side.clone()))
|
||||||
|
};
|
||||||
|
|
||||||
|
run_response_processing_loop(server, 5, reader_callback, fetcher).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(Client {
|
||||||
|
callback,
|
||||||
|
channel: DnsChannel::Tcp(writer),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send a set of questions to the upstream server.
|
||||||
|
///
|
||||||
|
/// Any response(s) that is/are sent will be handled as part of the callback
|
||||||
|
/// scheme. This function will thus only return an error if there's a problem
|
||||||
|
/// sending the question to the server.
|
||||||
|
pub async fn send_questions(_questions: Vec<Question>) -> error_stack::Result<(), SendError> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the callback handler for when we receive responses from the server.
|
||||||
|
///
|
||||||
|
/// This function may take awhile to execute, depending on how busy we are
|
||||||
|
/// taking responses, as it takes write ownership of a read-write lock that's
|
||||||
|
/// almost always written.
|
||||||
|
pub async fn set_callback(&self, callback: Callback) {
|
||||||
|
*self.callback.write().await = callback;
|
||||||
|
}
|
||||||
|
}
|
||||||
225
src/network/resolver/protocol/header.rs
Normal file
225
src/network/resolver/protocol/header.rs
Normal file
@@ -0,0 +1,225 @@
|
|||||||
|
use bytes::{Buf, BufMut};
|
||||||
|
use error_stack::{report, ResultExt};
|
||||||
|
use num_enum::{FromPrimitive, IntoPrimitive};
|
||||||
|
use proptest::arbitrary::Arbitrary;
|
||||||
|
use proptest::strategy::{BoxedStrategy, Just, Strategy};
|
||||||
|
use std::fmt;
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, PartialEq, proptest_derive::Arbitrary)]
|
||||||
|
pub struct Header {
|
||||||
|
pub message_id: u16,
|
||||||
|
pub is_response: bool,
|
||||||
|
pub opcode: OpCode,
|
||||||
|
pub authoritative_answer: bool,
|
||||||
|
pub message_truncated: bool,
|
||||||
|
pub recursion_desired: bool,
|
||||||
|
pub recursion_available: bool,
|
||||||
|
pub response_code: ResponseCode,
|
||||||
|
pub question_count: u16,
|
||||||
|
pub answer_count: u16,
|
||||||
|
pub name_server_count: u16,
|
||||||
|
pub additional_record_count: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(FromPrimitive, IntoPrimitive, PartialEq, PartialOrd, Eq, Ord, Debug, Copy, Clone)]
|
||||||
|
#[repr(u8)]
|
||||||
|
pub enum OpCode {
|
||||||
|
StandardQuery = 0,
|
||||||
|
InverseQuery = 1,
|
||||||
|
ServiceStatusRequest = 2,
|
||||||
|
#[num_enum(catch_all)]
|
||||||
|
Other(u8),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Arbitrary for OpCode {
|
||||||
|
type Parameters = ();
|
||||||
|
type Strategy = BoxedStrategy<Self>;
|
||||||
|
|
||||||
|
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
|
||||||
|
// while the type is 8 bits in rust, it's 4 bits in the protocol,
|
||||||
|
// and dealing with the run-off is messy. so this only generates
|
||||||
|
// valid-sized values here. it also biases toward the legit values.
|
||||||
|
// both things might want to be reconsidered in the future.
|
||||||
|
proptest::prop_oneof![
|
||||||
|
Just(OpCode::StandardQuery),
|
||||||
|
Just(OpCode::InverseQuery),
|
||||||
|
Just(OpCode::ServiceStatusRequest),
|
||||||
|
(3u8..=15).prop_map(OpCode::Other),
|
||||||
|
]
|
||||||
|
.boxed()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(FromPrimitive, IntoPrimitive, PartialEq, PartialOrd, Eq, Ord, Debug, Copy, Clone)]
|
||||||
|
#[repr(u8)]
|
||||||
|
pub enum ResponseCode {
|
||||||
|
NoErrorConditions = 0,
|
||||||
|
FormatError = 1,
|
||||||
|
ServerFailure = 2,
|
||||||
|
NameError = 3,
|
||||||
|
NotImplemented = 4,
|
||||||
|
Refused = 5,
|
||||||
|
#[num_enum(catch_all)]
|
||||||
|
Other(u8),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for ResponseCode {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
match self {
|
||||||
|
ResponseCode::NoErrorConditions => write!(f, "No errors"),
|
||||||
|
ResponseCode::FormatError => write!(f, "Illegal format"),
|
||||||
|
ResponseCode::ServerFailure => write!(f, "Server failure"),
|
||||||
|
ResponseCode::NameError => write!(f, "Name error"),
|
||||||
|
ResponseCode::NotImplemented => write!(f, "Not implemented"),
|
||||||
|
ResponseCode::Refused => write!(f, "Refused"),
|
||||||
|
ResponseCode::Other(x) => write!(f, "unknown error {x}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Arbitrary for ResponseCode {
|
||||||
|
type Parameters = ();
|
||||||
|
type Strategy = BoxedStrategy<Self>;
|
||||||
|
|
||||||
|
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
|
||||||
|
// while the type is 8 bits in rust, it's 4 bits in the protocol,
|
||||||
|
// and dealing with the run-off is messy. so this only generates
|
||||||
|
// valid-sized values here. it also biases toward the legit values.
|
||||||
|
// both things might want to be reconsidered in the future.
|
||||||
|
proptest::prop_oneof![
|
||||||
|
Just(ResponseCode::NoErrorConditions),
|
||||||
|
Just(ResponseCode::FormatError),
|
||||||
|
Just(ResponseCode::ServerFailure),
|
||||||
|
Just(ResponseCode::NameError),
|
||||||
|
Just(ResponseCode::NotImplemented),
|
||||||
|
Just(ResponseCode::Refused),
|
||||||
|
(6u8..=15).prop_map(ResponseCode::Other),
|
||||||
|
]
|
||||||
|
.boxed()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum HeaderReadError {
|
||||||
|
#[error("Buffer not large enough to have a DNS message header in it.")]
|
||||||
|
BufferTooSmall,
|
||||||
|
|
||||||
|
#[error("Invalid data in zero-filled space.")]
|
||||||
|
NonZeroInZeroes,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum HeaderWriteError {
|
||||||
|
#[error("Buffer not large enough to write header.")]
|
||||||
|
BufferTooSmall,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Header {
|
||||||
|
pub fn read<B: Buf>(buffer: &mut B) -> error_stack::Result<Self, HeaderReadError> {
|
||||||
|
if buffer.remaining() < 12
|
||||||
|
/* 6 16-bit fields = 6 * 2 = 1 */
|
||||||
|
{
|
||||||
|
return Err(report!(HeaderReadError::BufferTooSmall)).attach_printable_lazy(|| {
|
||||||
|
format!(
|
||||||
|
"Need at least {} bytes, but only had {}",
|
||||||
|
6 * 12,
|
||||||
|
buffer.remaining()
|
||||||
|
)
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let message_id = buffer.get_u16();
|
||||||
|
let flags = buffer.get_u16();
|
||||||
|
let question_count = buffer.get_u16();
|
||||||
|
let answer_count = buffer.get_u16();
|
||||||
|
let name_server_count = buffer.get_u16();
|
||||||
|
let additional_record_count = buffer.get_u16();
|
||||||
|
|
||||||
|
let is_response = (0x8000 & flags) != 0;
|
||||||
|
let opcode = OpCode::from(((flags >> 11) & 0xF) as u8);
|
||||||
|
let authoritative_answer = (0x0400 & flags) != 0;
|
||||||
|
let message_truncated = (0x0200 & flags) != 0;
|
||||||
|
let recursion_desired = (0x0100 & flags) != 0;
|
||||||
|
let recursion_available = (0x0080 & flags) != 0;
|
||||||
|
let zeroes = 0x0070 & flags;
|
||||||
|
let response_code = ResponseCode::from((flags & 0x000F) as u8);
|
||||||
|
|
||||||
|
if zeroes != 0 {
|
||||||
|
return Err(report!(HeaderReadError::NonZeroInZeroes))
|
||||||
|
.attach_printable_lazy(|| format!("Saw {:#x} instead.", zeroes >> 4));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Header {
|
||||||
|
message_id,
|
||||||
|
is_response,
|
||||||
|
opcode,
|
||||||
|
authoritative_answer,
|
||||||
|
message_truncated,
|
||||||
|
recursion_desired,
|
||||||
|
recursion_available,
|
||||||
|
response_code,
|
||||||
|
question_count,
|
||||||
|
answer_count,
|
||||||
|
name_server_count,
|
||||||
|
additional_record_count,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn write<B: BufMut>(self, buffer: &mut B) -> error_stack::Result<(), HeaderWriteError> {
|
||||||
|
if buffer.remaining_mut() < 12
|
||||||
|
/* 6 16-bit fields = 6 * 2 = 1 */
|
||||||
|
{
|
||||||
|
return Err(report!(HeaderWriteError::BufferTooSmall)).attach_printable_lazy(|| {
|
||||||
|
format!(
|
||||||
|
"Need at least {} to write DNS header, only have {}",
|
||||||
|
6 * 12,
|
||||||
|
buffer.remaining_mut()
|
||||||
|
)
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut flags: u16 = 0;
|
||||||
|
|
||||||
|
if self.is_response {
|
||||||
|
flags |= 0x8000;
|
||||||
|
}
|
||||||
|
let opcode: u8 = self.opcode.into();
|
||||||
|
flags |= (opcode as u16) << 11;
|
||||||
|
if self.authoritative_answer {
|
||||||
|
flags |= 0x0400;
|
||||||
|
}
|
||||||
|
if self.message_truncated {
|
||||||
|
flags |= 0x0200;
|
||||||
|
}
|
||||||
|
if self.recursion_desired {
|
||||||
|
flags |= 0x0100;
|
||||||
|
}
|
||||||
|
if self.recursion_available {
|
||||||
|
flags |= 0x0080;
|
||||||
|
}
|
||||||
|
let response_code: u8 = self.response_code.into();
|
||||||
|
flags |= response_code as u16;
|
||||||
|
|
||||||
|
buffer.put_u16(self.message_id);
|
||||||
|
buffer.put_u16(flags);
|
||||||
|
buffer.put_u16(self.question_count);
|
||||||
|
buffer.put_u16(self.answer_count);
|
||||||
|
buffer.put_u16(self.name_server_count);
|
||||||
|
buffer.put_u16(self.additional_record_count);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
proptest::proptest! {
|
||||||
|
#[test]
|
||||||
|
fn headers_roundtrip(header: Header) {
|
||||||
|
let mut write_buffer = bytes::BytesMut::with_capacity(128 * 1024);
|
||||||
|
let safe_header = header.clone();
|
||||||
|
header.write(&mut write_buffer).expect("can write name");
|
||||||
|
let mut read_buffer = write_buffer.freeze();
|
||||||
|
let new_header = Header::read(&mut read_buffer).expect("can read name");
|
||||||
|
assert_eq!(safe_header, new_header);
|
||||||
|
}
|
||||||
|
}
|
||||||
91
src/network/resolver/protocol/question.rs
Normal file
91
src/network/resolver/protocol/question.rs
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
use crate::network::resolver::name::Name;
|
||||||
|
use crate::network::resolver::protocol::resource_record::raw::{RecordClass, RecordType};
|
||||||
|
use bytes::{Buf, BufMut, TryGetError};
|
||||||
|
use error_stack::{report, ResultExt};
|
||||||
|
use std::fmt;
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, PartialEq, proptest_derive::Arbitrary)]
|
||||||
|
pub struct Question {
|
||||||
|
name: Name,
|
||||||
|
record_type: RecordType,
|
||||||
|
record_class: RecordClass,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for Question {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
write!(f, "<Q:{}@{}.{}>", self.name, self.record_type, self.record_class)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum QuestionReadError {
|
||||||
|
#[error("Could not read name for question.")]
|
||||||
|
CouldNotReadName,
|
||||||
|
|
||||||
|
#[error("Could not read the record type for the question: {0}")]
|
||||||
|
CouldNotReadType(TryGetError),
|
||||||
|
|
||||||
|
#[error("Could not read the record class for the question.")]
|
||||||
|
CouldNotReadClass(TryGetError),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum QuestionWriteError {
|
||||||
|
#[error("Could not write name for the question.")]
|
||||||
|
CouldNotWriteName,
|
||||||
|
|
||||||
|
#[error("Buffer not large enough to write question type and class.")]
|
||||||
|
BufferTooSmall,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Question {
|
||||||
|
pub fn read<B: Buf>(buffer: &mut B) -> error_stack::Result<Self, QuestionReadError> {
|
||||||
|
let name = Name::read(buffer).change_context(QuestionReadError::CouldNotReadName)?;
|
||||||
|
let record_type_u16 = buffer
|
||||||
|
.try_get_u16()
|
||||||
|
.map_err(|e| report!(QuestionReadError::CouldNotReadType(e)))
|
||||||
|
.attach_printable_lazy(|| format!("question was about '{name}'"))?;
|
||||||
|
let record_class_u16 = buffer
|
||||||
|
.try_get_u16()
|
||||||
|
.map_err(|e| report!(QuestionReadError::CouldNotReadClass(e)))
|
||||||
|
.attach_printable_lazy(|| format!("question was about '{name}'"))?;
|
||||||
|
|
||||||
|
let record_type = RecordType::from(record_type_u16);
|
||||||
|
let record_class = RecordClass::from(record_class_u16);
|
||||||
|
|
||||||
|
Ok(Question {
|
||||||
|
name,
|
||||||
|
record_type,
|
||||||
|
record_class,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn write<B: BufMut>(&self, buffer: &mut B) -> error_stack::Result<(), QuestionWriteError> {
|
||||||
|
self.name
|
||||||
|
.write(buffer)
|
||||||
|
.change_context(QuestionWriteError::CouldNotWriteName)
|
||||||
|
.attach_printable_lazy(|| format!("Question was about '{}'", self.name))?;
|
||||||
|
|
||||||
|
if buffer.remaining_mut() < 4 {
|
||||||
|
return Err(report!(QuestionWriteError::BufferTooSmall))
|
||||||
|
.attach_printable_lazy(|| format!("Question was about '{}'", self.name));
|
||||||
|
}
|
||||||
|
|
||||||
|
buffer.put_u16(self.record_type.into());
|
||||||
|
buffer.put_u16(self.record_class.into());
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
proptest::proptest! {
|
||||||
|
#[test]
|
||||||
|
fn questions_roundtrip(question: Question) {
|
||||||
|
let mut write_buffer = bytes::BytesMut::with_capacity(128 * 1024);
|
||||||
|
question.write(&mut write_buffer).expect("can write name");
|
||||||
|
let mut read_buffer = write_buffer.freeze();
|
||||||
|
let new_question = Question::read(&mut read_buffer).expect("can read name");
|
||||||
|
assert_eq!(question, new_question);
|
||||||
|
}
|
||||||
|
}
|
||||||
156
src/network/resolver/protocol/request.rs
Normal file
156
src/network/resolver/protocol/request.rs
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
use bytes::{Buf, BufMut};
|
||||||
|
use error_stack::{report, ResultExt};
|
||||||
|
use crate::network::resolver::protocol::header::{Header, OpCode, ResponseCode};
|
||||||
|
use crate::network::resolver::protocol::question::Question;
|
||||||
|
use crate::network::resolver::protocol::resource_record::ResourceRecord;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, PartialEq, proptest_derive::Arbitrary)]
|
||||||
|
pub struct Request {
|
||||||
|
source_message_id: u16,
|
||||||
|
opcode: OpCode,
|
||||||
|
recursion_desired: bool,
|
||||||
|
questions: Vec<Question>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum RequestReadError {
|
||||||
|
#[error("Error reading request header.")]
|
||||||
|
Header,
|
||||||
|
#[error("Message isn't a request.")]
|
||||||
|
NotRequest,
|
||||||
|
#[error("Illegal message.")]
|
||||||
|
IllegalMessage,
|
||||||
|
#[error("Error reading request question.")]
|
||||||
|
Question,
|
||||||
|
#[error("Error reading request answers (?!)")]
|
||||||
|
Answer,
|
||||||
|
#[error("Error reading requst name servers.")]
|
||||||
|
NameServers,
|
||||||
|
#[error("Reading request additional records.")]
|
||||||
|
AdditionalRecords,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum RequestWriteError {
|
||||||
|
#[error("Error writing request header.")]
|
||||||
|
Header,
|
||||||
|
#[error("Error writing request question.")]
|
||||||
|
Question,
|
||||||
|
#[error("Request had too many questions.")]
|
||||||
|
TooManyQuestions,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Request {
|
||||||
|
pub fn read<B: Buf>(buffer: &mut B) -> error_stack::Result<Self, RequestReadError> {
|
||||||
|
let header = Header::read(buffer)
|
||||||
|
.change_context(RequestReadError::Header)?;
|
||||||
|
|
||||||
|
if header.is_response {
|
||||||
|
return Err(report!(RequestReadError::NotRequest))
|
||||||
|
.attach_printable_lazy(|| format!("message id is {}", header.message_id));
|
||||||
|
}
|
||||||
|
|
||||||
|
if header.authoritative_answer {
|
||||||
|
return Err(report!(RequestReadError::IllegalMessage))
|
||||||
|
.attach_printable("Request messages are not allowed to set the 'authoritative answer' bit.")
|
||||||
|
.attach_printable_lazy(|| format!("message id is {}", header.message_id));
|
||||||
|
}
|
||||||
|
|
||||||
|
if header.response_code != ResponseCode::NoErrorConditions {
|
||||||
|
return Err(report!(RequestReadError::IllegalMessage))
|
||||||
|
.attach_printable("Request messages are not allowed to set the response code.")
|
||||||
|
.attach_printable_lazy(|| format!("message id is {}", header.message_id));
|
||||||
|
}
|
||||||
|
|
||||||
|
if header.answer_count != 0 {
|
||||||
|
return Err(report!(RequestReadError::IllegalMessage))
|
||||||
|
.attach_printable("Request messages are not allowed to include answers.")
|
||||||
|
.attach_printable_lazy(|| format!("{} answers declared", header.answer_count))
|
||||||
|
.attach_printable_lazy(|| format!("message id is {}", header.message_id));
|
||||||
|
}
|
||||||
|
|
||||||
|
if header.name_server_count != 0 {
|
||||||
|
return Err(report!(RequestReadError::IllegalMessage))
|
||||||
|
.attach_printable("Request messages are not allowed to include name servers.")
|
||||||
|
.attach_printable_lazy(|| format!("{} name servers declared", header.name_server_count))
|
||||||
|
.attach_printable_lazy(|| format!("message id is {}", header.message_id));
|
||||||
|
}
|
||||||
|
|
||||||
|
if header.additional_record_count != 0 {
|
||||||
|
return Err(report!(RequestReadError::IllegalMessage))
|
||||||
|
.attach_printable("Request messages are not allowed to include additional records.")
|
||||||
|
.attach_printable_lazy(|| format!("{} aditional records declared", header.additional_record_count))
|
||||||
|
.attach_printable_lazy(|| format!("message id is {}", header.message_id));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
let mut questions = vec![];
|
||||||
|
|
||||||
|
for i in 0..header.question_count {
|
||||||
|
let question = Question::read(buffer)
|
||||||
|
.change_context(RequestReadError::Question)
|
||||||
|
.attach_printable_lazy(|| format!("for question #{} of {}", i+1, header.question_count))
|
||||||
|
.attach_printable_lazy(|| format!("message id is {}", header.message_id))?;
|
||||||
|
|
||||||
|
questions.push(question);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Ok(Request {
|
||||||
|
source_message_id: header.message_id,
|
||||||
|
opcode: header.opcode,
|
||||||
|
recursion_desired: header.recursion_desired,
|
||||||
|
questions,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn write<B: BufMut>(self, buffer: &mut B) -> error_stack::Result<(), RequestWriteError> {
|
||||||
|
let question_count = self.questions.len();
|
||||||
|
|
||||||
|
if question_count > (u16::MAX as usize) {
|
||||||
|
return Err(report!(RequestWriteError::TooManyQuestions))
|
||||||
|
.attach_printable(format!("message_id is {}", self.source_message_id));
|
||||||
|
}
|
||||||
|
|
||||||
|
let header = Header {
|
||||||
|
message_id: self.source_message_id,
|
||||||
|
is_response: false,
|
||||||
|
opcode: self.opcode,
|
||||||
|
authoritative_answer: false,
|
||||||
|
message_truncated: false,
|
||||||
|
recursion_desired: self.recursion_desired,
|
||||||
|
recursion_available: false,
|
||||||
|
response_code: ResponseCode::NoErrorConditions,
|
||||||
|
question_count: question_count as u16,
|
||||||
|
answer_count: 0,
|
||||||
|
name_server_count: 0,
|
||||||
|
additional_record_count: 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
header.write(buffer)
|
||||||
|
.change_context(RequestWriteError::Header)
|
||||||
|
.attach_printable_lazy(|| format!("message ID is {}", self.source_message_id))?;
|
||||||
|
|
||||||
|
for (index, question) in self.questions.into_iter().enumerate() {
|
||||||
|
question.write(buffer)
|
||||||
|
.change_context(RequestWriteError::Question)
|
||||||
|
.attach_printable_lazy(|| format!("message ID is {}", self.source_message_id))
|
||||||
|
.attach_printable_lazy(|| format!("question #{} of {}", index+1, question_count))
|
||||||
|
.attach_printable_lazy(|| format!("{}", question))?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
proptest::proptest! {
|
||||||
|
#[test]
|
||||||
|
fn request_roundtrip(request: Request) {
|
||||||
|
let mut write_buffer = bytes::BytesMut::with_capacity(128 * 1024);
|
||||||
|
let safe_request = request.clone();
|
||||||
|
request.write(&mut write_buffer).expect("can write name");
|
||||||
|
let mut read_buffer = write_buffer.freeze();
|
||||||
|
let new_request = Request::read(&mut read_buffer).expect("can read name");
|
||||||
|
assert_eq!(safe_request, new_request);
|
||||||
|
}
|
||||||
|
}
|
||||||
1273
src/network/resolver/protocol/resource_record.rs
Normal file
1273
src/network/resolver/protocol/resource_record.rs
Normal file
File diff suppressed because it is too large
Load Diff
304
src/network/resolver/protocol/resource_record/raw.rs
Normal file
304
src/network/resolver/protocol/resource_record/raw.rs
Normal file
@@ -0,0 +1,304 @@
|
|||||||
|
use crate::network::resolver::name::Name;
|
||||||
|
use bytes::{Buf, BufMut, Bytes};
|
||||||
|
use error_stack::{report, ResultExt};
|
||||||
|
use num_enum::{FromPrimitive, IntoPrimitive};
|
||||||
|
use proptest::arbitrary::Arbitrary;
|
||||||
|
use proptest::strategy::{BoxedStrategy, Just, Strategy};
|
||||||
|
use std::fmt;
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq)]
|
||||||
|
pub struct RawResourceRecord {
|
||||||
|
pub name: Name,
|
||||||
|
pub record_type: RecordType,
|
||||||
|
pub record_class: RecordClass,
|
||||||
|
pub ttl: u32,
|
||||||
|
pub data: Bytes,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Arbitrary for RawResourceRecord {
|
||||||
|
type Parameters = ();
|
||||||
|
type Strategy = BoxedStrategy<Self>;
|
||||||
|
|
||||||
|
fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
|
||||||
|
(
|
||||||
|
Name::arbitrary(),
|
||||||
|
RecordType::arbitrary(),
|
||||||
|
RecordClass::arbitrary(),
|
||||||
|
u32::arbitrary(),
|
||||||
|
proptest::collection::vec(u8::arbitrary(), 0..65535),
|
||||||
|
)
|
||||||
|
.prop_map(
|
||||||
|
|(name, record_type, record_class, ttl, data)| RawResourceRecord {
|
||||||
|
name,
|
||||||
|
record_type,
|
||||||
|
record_class,
|
||||||
|
ttl,
|
||||||
|
data: data.into(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.boxed()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(FromPrimitive, IntoPrimitive, PartialEq, PartialOrd, Eq, Ord, Debug, Copy, Clone)]
|
||||||
|
#[repr(u16)]
|
||||||
|
pub enum RecordType {
|
||||||
|
A = 1,
|
||||||
|
AAAA = 28,
|
||||||
|
NS = 2,
|
||||||
|
MD = 3,
|
||||||
|
MF = 4,
|
||||||
|
CNAME = 5,
|
||||||
|
SOA = 6,
|
||||||
|
MB = 7,
|
||||||
|
MG = 8,
|
||||||
|
MR = 9,
|
||||||
|
NULL = 10,
|
||||||
|
WKS = 11,
|
||||||
|
PTR = 12,
|
||||||
|
HINFO = 13,
|
||||||
|
MINFO = 14,
|
||||||
|
MX = 15,
|
||||||
|
TXT = 16,
|
||||||
|
URI = 256,
|
||||||
|
#[num_enum(catch_all)]
|
||||||
|
Other(u16),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for RecordType {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
match self {
|
||||||
|
RecordType::A => write!(f, "A"),
|
||||||
|
RecordType::AAAA => write!(f, "AAAA"),
|
||||||
|
RecordType::NS => write!(f, "NS"),
|
||||||
|
RecordType::MD => write!(f, "MD"),
|
||||||
|
RecordType::MF => write!(f, "MF"),
|
||||||
|
RecordType::CNAME => write!(f, "CNAME"),
|
||||||
|
RecordType::SOA => write!(f, "SOA"),
|
||||||
|
RecordType::MB => write!(f, "MB"),
|
||||||
|
RecordType::MG => write!(f, "MG"),
|
||||||
|
RecordType::MR => write!(f, "MR"),
|
||||||
|
RecordType::NULL => write!(f, "NULL"),
|
||||||
|
RecordType::WKS => write!(f, "WKS"),
|
||||||
|
RecordType::PTR => write!(f, "PTR"),
|
||||||
|
RecordType::HINFO => write!(f, "HINFO"),
|
||||||
|
RecordType::MINFO => write!(f, "MINFO"),
|
||||||
|
RecordType::MX => write!(f, "MX"),
|
||||||
|
RecordType::TXT => write!(f, "TXT"),
|
||||||
|
RecordType::URI => write!(f, "URI"),
|
||||||
|
RecordType::Other(x) => write!(f, "UNKNOWN<{x}>"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Arbitrary for RecordType {
|
||||||
|
type Parameters = ();
|
||||||
|
type Strategy = BoxedStrategy<Self>;
|
||||||
|
|
||||||
|
fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
|
||||||
|
// this is intentionally biased towards the legit values
|
||||||
|
proptest::prop_oneof![
|
||||||
|
Just(RecordType::A),
|
||||||
|
Just(RecordType::AAAA),
|
||||||
|
Just(RecordType::NS),
|
||||||
|
Just(RecordType::MD),
|
||||||
|
Just(RecordType::MF),
|
||||||
|
Just(RecordType::CNAME),
|
||||||
|
Just(RecordType::SOA),
|
||||||
|
Just(RecordType::MB),
|
||||||
|
Just(RecordType::MG),
|
||||||
|
Just(RecordType::MR),
|
||||||
|
Just(RecordType::NULL),
|
||||||
|
Just(RecordType::WKS),
|
||||||
|
Just(RecordType::PTR),
|
||||||
|
Just(RecordType::HINFO),
|
||||||
|
Just(RecordType::MINFO),
|
||||||
|
Just(RecordType::MX),
|
||||||
|
Just(RecordType::TXT),
|
||||||
|
proptest::prop_oneof![
|
||||||
|
(17u16..28).prop_map(|x| RecordType::Other(x)),
|
||||||
|
(29u16..256).prop_map(|x| RecordType::Other(x)),
|
||||||
|
(257u16..=65535).prop_map(|x| RecordType::Other(x)),
|
||||||
|
Just(RecordType::Other(0)),
|
||||||
|
],
|
||||||
|
]
|
||||||
|
.boxed()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(FromPrimitive, IntoPrimitive, PartialEq, PartialOrd, Eq, Ord, Debug, Copy, Clone)]
|
||||||
|
#[repr(u16)]
|
||||||
|
pub enum RecordClass {
|
||||||
|
IN = 1,
|
||||||
|
CS = 2,
|
||||||
|
CH = 3,
|
||||||
|
HS = 4,
|
||||||
|
#[num_enum(catch_all)]
|
||||||
|
Other(u16),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for RecordClass {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
match self {
|
||||||
|
RecordClass::IN => write!(f, "IN"),
|
||||||
|
RecordClass::CS => write!(f, "CS"),
|
||||||
|
RecordClass::CH => write!(f, "CH"),
|
||||||
|
RecordClass::HS => write!(f, "HS"),
|
||||||
|
RecordClass::Other(x) => write!(f, "UNKNOWN<{x}>"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Arbitrary for RecordClass {
|
||||||
|
type Parameters = ();
|
||||||
|
type Strategy = BoxedStrategy<Self>;
|
||||||
|
|
||||||
|
fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
|
||||||
|
// this is intentionally biased towards the legit values
|
||||||
|
proptest::prop_oneof![
|
||||||
|
Just(RecordClass::IN),
|
||||||
|
Just(RecordClass::CS),
|
||||||
|
Just(RecordClass::CH),
|
||||||
|
Just(RecordClass::HS),
|
||||||
|
(5u16..=65535).prop_map(|x| RecordClass::Other(x)),
|
||||||
|
Just(RecordClass::Other(0)),
|
||||||
|
]
|
||||||
|
.boxed()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum ResourceRecordReadError {
|
||||||
|
#[error("Failed to read initial record name.")]
|
||||||
|
InitialRecord,
|
||||||
|
|
||||||
|
#[error("Resource record truncated; couldn't find its type field.")]
|
||||||
|
NoTypeField,
|
||||||
|
|
||||||
|
#[error("Resource record truncated; couldn't find its class field.")]
|
||||||
|
NoClassField,
|
||||||
|
|
||||||
|
#[error("Resource record truncated; couldn't find its TTL field.")]
|
||||||
|
NoTtl,
|
||||||
|
|
||||||
|
#[error("Resource record truncated; couldn't find its data length.")]
|
||||||
|
NoDataLength,
|
||||||
|
|
||||||
|
#[error("Resource record truncated; couldn't read its entire data field.")]
|
||||||
|
DataTruncated,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum ResourceRecordWriteError {
|
||||||
|
#[error("Could not write name to the output record.")]
|
||||||
|
CouldNotWriteName,
|
||||||
|
|
||||||
|
#[error("Could not write resource record type and class to output record.")]
|
||||||
|
CouldNotWriteTypeClass,
|
||||||
|
|
||||||
|
#[error("Could not write TTL to output record.")]
|
||||||
|
CountNotWriteTtl,
|
||||||
|
|
||||||
|
#[error("Could not write resource record data length to output record.")]
|
||||||
|
CountNotWriteDataLength,
|
||||||
|
|
||||||
|
#[error("Could not write resource record data to output record.")]
|
||||||
|
CountNotWriteData,
|
||||||
|
|
||||||
|
#[error("Input data was too large to write to output stream.")]
|
||||||
|
InputDataTooLarge,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RawResourceRecord {
|
||||||
|
pub fn read<B: Buf>(buffer: &mut B) -> error_stack::Result<Self, ResourceRecordReadError> {
|
||||||
|
let name = Name::read(buffer).change_context(ResourceRecordReadError::InitialRecord)?;
|
||||||
|
|
||||||
|
let record_type = buffer
|
||||||
|
.try_get_u16()
|
||||||
|
.map_err(|_| report!(ResourceRecordReadError::NoTypeField))?
|
||||||
|
.into();
|
||||||
|
let record_class = buffer
|
||||||
|
.try_get_u16()
|
||||||
|
.map_err(|_| report!(ResourceRecordReadError::NoClassField))?
|
||||||
|
.into();
|
||||||
|
let ttl = buffer
|
||||||
|
.try_get_u32()
|
||||||
|
.map_err(|_| report!(ResourceRecordReadError::NoTtl))?;
|
||||||
|
let rdata_length = buffer
|
||||||
|
.try_get_u16()
|
||||||
|
.map_err(|_| report!(ResourceRecordReadError::NoDataLength))?;
|
||||||
|
|
||||||
|
if buffer.remaining() < (rdata_length as usize) {
|
||||||
|
return Err(report!(ResourceRecordReadError::DataTruncated)).attach_printable_lazy(
|
||||||
|
|| {
|
||||||
|
format!(
|
||||||
|
"Expected {rdata_length} bytes, but only saw {}",
|
||||||
|
buffer.remaining()
|
||||||
|
)
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let data = buffer.copy_to_bytes(rdata_length as usize);
|
||||||
|
|
||||||
|
Ok(RawResourceRecord {
|
||||||
|
name,
|
||||||
|
record_type,
|
||||||
|
record_class,
|
||||||
|
ttl,
|
||||||
|
data,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn write<B: BufMut>(
|
||||||
|
&self,
|
||||||
|
buffer: &mut B,
|
||||||
|
) -> error_stack::Result<(), ResourceRecordWriteError> {
|
||||||
|
self.name
|
||||||
|
.write(buffer)
|
||||||
|
.change_context(ResourceRecordWriteError::CouldNotWriteName)?;
|
||||||
|
|
||||||
|
if buffer.remaining_mut() < 4 {
|
||||||
|
return Err(report!(ResourceRecordWriteError::CouldNotWriteTypeClass));
|
||||||
|
}
|
||||||
|
buffer.put_u16(self.record_type.into());
|
||||||
|
buffer.put_u16(self.record_class.into());
|
||||||
|
|
||||||
|
if buffer.remaining_mut() < 4 {
|
||||||
|
return Err(report!(ResourceRecordWriteError::CountNotWriteTtl));
|
||||||
|
}
|
||||||
|
buffer.put_u32(self.ttl);
|
||||||
|
|
||||||
|
if buffer.remaining_mut() < 2 {
|
||||||
|
return Err(report!(ResourceRecordWriteError::CountNotWriteDataLength));
|
||||||
|
}
|
||||||
|
if self.data.len() > (u16::MAX as usize) {
|
||||||
|
return Err(report!(ResourceRecordWriteError::InputDataTooLarge))
|
||||||
|
.attach_printable_lazy(|| {
|
||||||
|
format!(
|
||||||
|
"Incoming data was {} bytes, needs to be < 2^16",
|
||||||
|
self.data.len()
|
||||||
|
)
|
||||||
|
});
|
||||||
|
}
|
||||||
|
buffer.put_u16(self.data.len() as u16);
|
||||||
|
|
||||||
|
if buffer.remaining_mut() < self.data.len() {
|
||||||
|
return Err(report!(ResourceRecordWriteError::CountNotWriteData));
|
||||||
|
}
|
||||||
|
buffer.put_slice(&self.data);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
proptest::proptest! {
|
||||||
|
#[test]
|
||||||
|
fn any_random_record_roundtrips(record: RawResourceRecord) {
|
||||||
|
let mut write_buffer = bytes::BytesMut::with_capacity(128 * 1024);
|
||||||
|
record.write(&mut write_buffer).expect("can write name");
|
||||||
|
let mut read_buffer = write_buffer.freeze();
|
||||||
|
let new_record = RawResourceRecord::read(&mut read_buffer).expect("can read name");
|
||||||
|
assert_eq!(record, new_record);
|
||||||
|
}
|
||||||
|
}
|
||||||
258
src/network/resolver/protocol/response.rs
Normal file
258
src/network/resolver/protocol/response.rs
Normal file
@@ -0,0 +1,258 @@
|
|||||||
|
use bytes::{Buf, BufMut};
|
||||||
|
use error_stack::ResultExt;
|
||||||
|
use crate::network::resolver::protocol::header::{Header, ResponseCode};
|
||||||
|
use crate::network::resolver::protocol::question::Question;
|
||||||
|
use crate::network::resolver::protocol::resource_record::ResourceRecord;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, PartialEq, proptest_derive::Arbitrary)]
|
||||||
|
pub enum Response {
|
||||||
|
Valid {
|
||||||
|
source_message_id: u16,
|
||||||
|
authoritative: bool,
|
||||||
|
truncated: bool,
|
||||||
|
answers: Vec<ResourceRecord>,
|
||||||
|
name_servers: Vec<ResourceRecord>,
|
||||||
|
additional_records: Vec<ResourceRecord>,
|
||||||
|
},
|
||||||
|
FormatError {
|
||||||
|
source_message_id: u16,
|
||||||
|
},
|
||||||
|
ServerFailure {
|
||||||
|
source_message_id: u16,
|
||||||
|
},
|
||||||
|
NameError {
|
||||||
|
source_message_id: u16,
|
||||||
|
},
|
||||||
|
NotImplemented {
|
||||||
|
source_message_id: u16,
|
||||||
|
},
|
||||||
|
Refused {
|
||||||
|
source_message_id: u16,
|
||||||
|
},
|
||||||
|
UnknownError {
|
||||||
|
#[proptest(strategy="6u8..=15")]
|
||||||
|
error_code: u8,
|
||||||
|
source_message_id: u16,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum ResponseReadError {
|
||||||
|
#[error("Could not read response header.")]
|
||||||
|
HeaderReadError,
|
||||||
|
#[error("Could not read question included in response.")]
|
||||||
|
QuestionReadError,
|
||||||
|
#[error("Could not read resource record included as an answer.")]
|
||||||
|
AnswerReadError,
|
||||||
|
#[error("Could not read name server record included in response.")]
|
||||||
|
NameServerReadError,
|
||||||
|
#[error("Could not read supplemental record included in response.")]
|
||||||
|
AdditionalInfoReadError,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum ResponseWriteError {
|
||||||
|
#[error("Could not write response header.")]
|
||||||
|
Header,
|
||||||
|
#[error("Could not write answer.")]
|
||||||
|
Answer,
|
||||||
|
#[error("Could not write name server.")]
|
||||||
|
NameServer,
|
||||||
|
#[error("Could not write additional record.")]
|
||||||
|
AdditionalRecord,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Response {
|
||||||
|
pub fn message_id(&self) -> u16 {
|
||||||
|
match self {
|
||||||
|
Response::Valid { source_message_id, .. } => *source_message_id,
|
||||||
|
Response::FormatError { source_message_id } => *source_message_id,
|
||||||
|
Response::ServerFailure { source_message_id } => *source_message_id,
|
||||||
|
Response::NameError { source_message_id } => *source_message_id,
|
||||||
|
Response::NotImplemented { source_message_id } => *source_message_id,
|
||||||
|
Response::Refused { source_message_id } => *source_message_id,
|
||||||
|
Response::UnknownError { source_message_id, .. } => *source_message_id,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn read<B: Buf>(buffer: &mut B) -> error_stack::Result<Response, ResponseReadError> {
|
||||||
|
let header = Header::read(buffer)
|
||||||
|
.change_context(ResponseReadError::HeaderReadError)?;
|
||||||
|
|
||||||
|
// check for errors, and short-cut out if we find any
|
||||||
|
match header.response_code {
|
||||||
|
ResponseCode::NoErrorConditions => {}
|
||||||
|
|
||||||
|
ResponseCode::FormatError => return Ok(Response::FormatError {
|
||||||
|
source_message_id: header.message_id,
|
||||||
|
}),
|
||||||
|
|
||||||
|
ResponseCode::ServerFailure => return Ok(Response::ServerFailure {
|
||||||
|
source_message_id: header.message_id,
|
||||||
|
}),
|
||||||
|
|
||||||
|
ResponseCode::NameError => return Ok(Response::NameError {
|
||||||
|
source_message_id: header.message_id,
|
||||||
|
}),
|
||||||
|
|
||||||
|
ResponseCode::NotImplemented => return Ok(Response::NotImplemented {
|
||||||
|
source_message_id: header.message_id,
|
||||||
|
}),
|
||||||
|
|
||||||
|
ResponseCode::Refused => return Ok(Response::Refused {
|
||||||
|
source_message_id: header.message_id,
|
||||||
|
}),
|
||||||
|
|
||||||
|
ResponseCode::Other(error_code) => return Ok(Response::UnknownError {
|
||||||
|
error_code,
|
||||||
|
source_message_id: header.message_id,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
|
||||||
|
// it seems weird to get questions in a response, but we need to parse
|
||||||
|
// them out if they exist.
|
||||||
|
for _ in 0..header.question_count {
|
||||||
|
let question = Question::read(buffer)
|
||||||
|
.change_context(ResponseReadError::QuestionReadError)?;
|
||||||
|
|
||||||
|
tracing::warn!(
|
||||||
|
%question,
|
||||||
|
"got question during server response."
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut answers = vec![];
|
||||||
|
for idx in 0..header.answer_count {
|
||||||
|
let answer = ResourceRecord::read(buffer)
|
||||||
|
.change_context(ResponseReadError::AnswerReadError)
|
||||||
|
.attach_printable_lazy(|| format!("In answer {} of {}", idx + 1, header.answer_count))?;
|
||||||
|
|
||||||
|
answers.push(answer);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut name_servers = vec![];
|
||||||
|
for idx in 0..header.name_server_count {
|
||||||
|
let name_server = ResourceRecord::read(buffer)
|
||||||
|
.change_context(ResponseReadError::NameServerReadError)
|
||||||
|
.attach_printable_lazy(|| format!("In answer {} of {}", idx + 1, header.name_server_count))?;
|
||||||
|
|
||||||
|
name_servers.push(name_server);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut additional_records = vec![];
|
||||||
|
for idx in 0..header.additional_record_count {
|
||||||
|
let extra = ResourceRecord::read(buffer)
|
||||||
|
.change_context(ResponseReadError::AnswerReadError)
|
||||||
|
.attach_printable_lazy(|| format!("In answer {} of {}", idx + 1, header.additional_record_count))?;
|
||||||
|
|
||||||
|
additional_records.push(extra);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Response::Valid {
|
||||||
|
source_message_id: header.message_id,
|
||||||
|
authoritative: header.authoritative_answer,
|
||||||
|
truncated: header.message_truncated,
|
||||||
|
answers,
|
||||||
|
name_servers,
|
||||||
|
additional_records,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn write_error<B: BufMut>(source_message_id: u16, error_code: ResponseCode, buffer: &mut B) -> error_stack::Result<(), ResponseWriteError> {
|
||||||
|
let header = Header {
|
||||||
|
message_id: source_message_id,
|
||||||
|
is_response: true,
|
||||||
|
opcode: super::header::OpCode::StandardQuery,
|
||||||
|
authoritative_answer: false,
|
||||||
|
message_truncated: false,
|
||||||
|
recursion_desired: false,
|
||||||
|
recursion_available: false,
|
||||||
|
response_code: error_code,
|
||||||
|
question_count: 0,
|
||||||
|
answer_count: 0,
|
||||||
|
name_server_count: 0,
|
||||||
|
additional_record_count: 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
header
|
||||||
|
.write(buffer)
|
||||||
|
.change_context(ResponseWriteError::Header)
|
||||||
|
.attach_printable_lazy(|| format!("responding to message {source_message_id} with error code {error_code}"))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn write<B: BufMut>(self, buffer: &mut B) -> error_stack::Result<(), ResponseWriteError> {
|
||||||
|
let (source_message_id, authoritative, truncated, answers, name_servers, additional_records) = match self {
|
||||||
|
Response::FormatError { source_message_id } => return Self::write_error(source_message_id, ResponseCode::FormatError, buffer),
|
||||||
|
|
||||||
|
Response::ServerFailure { source_message_id } => return Self::write_error(source_message_id, ResponseCode::ServerFailure, buffer),
|
||||||
|
|
||||||
|
Response::NameError { source_message_id } => return Self::write_error(source_message_id, ResponseCode::NameError, buffer),
|
||||||
|
|
||||||
|
Response::NotImplemented { source_message_id } => return Self::write_error(source_message_id, ResponseCode::NotImplemented, buffer),
|
||||||
|
|
||||||
|
Response::Refused { source_message_id } => return Self::write_error(source_message_id, ResponseCode::Refused, buffer),
|
||||||
|
|
||||||
|
Response::UnknownError { error_code, source_message_id } => return Self::write_error(source_message_id, ResponseCode::Other(error_code), buffer),
|
||||||
|
|
||||||
|
Response::Valid { source_message_id, authoritative, truncated, answers, name_servers, additional_records } => {
|
||||||
|
(source_message_id, authoritative, truncated, answers, name_servers, additional_records)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let header = Header {
|
||||||
|
message_id: source_message_id,
|
||||||
|
is_response: true,
|
||||||
|
opcode: super::header::OpCode::StandardQuery,
|
||||||
|
authoritative_answer: authoritative,
|
||||||
|
message_truncated: truncated,
|
||||||
|
recursion_desired: false,
|
||||||
|
recursion_available: false,
|
||||||
|
response_code: ResponseCode::NoErrorConditions,
|
||||||
|
question_count: 0,
|
||||||
|
answer_count: answers.len() as u16,
|
||||||
|
name_server_count: name_servers.len() as u16,
|
||||||
|
additional_record_count: additional_records.len() as u16,
|
||||||
|
};
|
||||||
|
|
||||||
|
header.write(buffer)
|
||||||
|
.change_context(ResponseWriteError::Header)
|
||||||
|
.attach_printable_lazy(|| format!("Writing clean response to {source_message_id}"))?;
|
||||||
|
|
||||||
|
let answer_count = answers.len();
|
||||||
|
for (item, answer) in answers.into_iter().enumerate() {
|
||||||
|
answer.write(buffer)
|
||||||
|
.change_context(ResponseWriteError::Answer)
|
||||||
|
.attach_printable_lazy(|| format!("Writing clean response to {source_message_id}"))
|
||||||
|
.attach_printable_lazy(|| format!("Writing answer {} of {}", item+1, answer_count))?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let ns_count = name_servers.len();
|
||||||
|
for (item, answer) in name_servers.into_iter().enumerate() {
|
||||||
|
answer.write(buffer)
|
||||||
|
.change_context(ResponseWriteError::NameServer)
|
||||||
|
.attach_printable_lazy(|| format!("Writing clean response to {source_message_id}"))
|
||||||
|
.attach_printable_lazy(|| format!("Writing name server {} of {}", item+1, ns_count))?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let ar_count = additional_records.len();
|
||||||
|
for (item, answer) in additional_records.into_iter().enumerate() {
|
||||||
|
answer.write(buffer)
|
||||||
|
.change_context(ResponseWriteError::AdditionalRecord)
|
||||||
|
.attach_printable_lazy(|| format!("Writing clean response to {source_message_id}"))
|
||||||
|
.attach_printable_lazy(|| format!("Writing additional record {} of {}", item+1, ar_count))?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
proptest::proptest! {
|
||||||
|
#[test]
|
||||||
|
fn any_random_response_roundtrips(record: Response) {
|
||||||
|
let mut write_buffer = bytes::BytesMut::with_capacity(128 * 1024);
|
||||||
|
record.clone().write(&mut write_buffer).expect("can write name");
|
||||||
|
let mut read_buffer = write_buffer.freeze();
|
||||||
|
let new_record = Response::read(&mut read_buffer).expect("can read name");
|
||||||
|
assert_eq!(record, new_record);
|
||||||
|
}
|
||||||
|
}
|
||||||
1
src/network/resolver/protocol/server.rs
Normal file
1
src/network/resolver/protocol/server.rs
Normal file
@@ -0,0 +1 @@
|
|||||||
|
|
||||||
176
src/network/resolver/resolution_table.rs
Normal file
176
src/network/resolver/resolution_table.rs
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
use crate::network::resolver::name::Name;
|
||||||
|
use std::collections::hash_map::Entry;
|
||||||
|
use std::collections::{HashMap, HashSet};
|
||||||
|
use std::net::IpAddr;
|
||||||
|
use tokio::time::{Duration, Instant};
|
||||||
|
|
||||||
|
pub struct ResolutionTable {
|
||||||
|
inner: HashMap<Name, Vec<Resolution>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Resolution {
|
||||||
|
result: IpAddr,
|
||||||
|
expiration: Instant,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for ResolutionTable {
|
||||||
|
fn default() -> Self {
|
||||||
|
ResolutionTable::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ResolutionTable {
|
||||||
|
/// Generate a new, empty resolution table to use in a new DNS implementation,
|
||||||
|
/// or shared by a bunch of them.
|
||||||
|
pub fn new() -> Self {
|
||||||
|
ResolutionTable {
|
||||||
|
inner: HashMap::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clean the table of expired entries.
|
||||||
|
pub fn garbage_collect(&mut self) {
|
||||||
|
let now = Instant::now();
|
||||||
|
|
||||||
|
self.inner.retain(|_, items| {
|
||||||
|
items.retain(|x| x.expiration > now);
|
||||||
|
!items.is_empty()
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add a new entry to the resolution table, with a TTL on it.
|
||||||
|
pub fn add_entry(&mut self, name: Name, maps_to: IpAddr, ttl: Duration) {
|
||||||
|
let now = Instant::now();
|
||||||
|
|
||||||
|
let new_entry = Resolution {
|
||||||
|
result: maps_to,
|
||||||
|
expiration: now + ttl,
|
||||||
|
};
|
||||||
|
|
||||||
|
match self.inner.entry(name) {
|
||||||
|
Entry::Vacant(vac) => {
|
||||||
|
vac.insert(vec![new_entry]);
|
||||||
|
}
|
||||||
|
|
||||||
|
Entry::Occupied(mut occ) => {
|
||||||
|
occ.get_mut().push(new_entry);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Look up an entry in the resolution map. This will only return
|
||||||
|
/// unexpired items.
|
||||||
|
pub fn lookup(&mut self, name: &Name) -> HashSet<IpAddr> {
|
||||||
|
let mut result = HashSet::new();
|
||||||
|
let now = Instant::now();
|
||||||
|
|
||||||
|
if let Some(entry) = self.inner.get_mut(name) {
|
||||||
|
entry.retain(|x| {
|
||||||
|
let retain = x.expiration > now;
|
||||||
|
|
||||||
|
if retain {
|
||||||
|
result.insert(x.result);
|
||||||
|
}
|
||||||
|
|
||||||
|
retain
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||||
|
#[cfg(test)]
|
||||||
|
use std::str::FromStr;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn empty_set_gets_fail() {
|
||||||
|
let mut empty = ResolutionTable::default();
|
||||||
|
assert!(empty.lookup(&Name::from_str("foo").unwrap()).is_empty());
|
||||||
|
assert!(empty.lookup(&Name::from_str("bar").unwrap()).is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn basic_lookups() {
|
||||||
|
let mut table = ResolutionTable::new();
|
||||||
|
let foo = Name::from_str("foo").unwrap();
|
||||||
|
let bar = Name::from_str("bar").unwrap();
|
||||||
|
let baz = Name::from_str("baz").unwrap();
|
||||||
|
let localhost = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
|
||||||
|
let other = IpAddr::V6(Ipv6Addr::from_str("2001:18e8:2:e::11e").unwrap());
|
||||||
|
let long_time = Duration::from_secs(10000000000);
|
||||||
|
|
||||||
|
table.add_entry(foo.clone(), localhost, long_time);
|
||||||
|
table.add_entry(bar.clone(), localhost, long_time);
|
||||||
|
table.add_entry(bar.clone(), other, long_time);
|
||||||
|
|
||||||
|
assert_eq!(1, table.lookup(&foo).len());
|
||||||
|
assert_eq!(2, table.lookup(&bar).len());
|
||||||
|
assert!(table.lookup(&baz).is_empty());
|
||||||
|
assert!(table.lookup(&foo).contains(&localhost));
|
||||||
|
assert!(!table.lookup(&foo).contains(&other));
|
||||||
|
assert!(table.lookup(&bar).contains(&localhost));
|
||||||
|
assert!(table.lookup(&bar).contains(&other));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn lookup_cleans_up() {
|
||||||
|
let mut table = ResolutionTable::new();
|
||||||
|
let foo = Name::from_str("foo").unwrap();
|
||||||
|
let localhost = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
|
||||||
|
let other = IpAddr::V6(Ipv6Addr::from_str("2001:18e8:2:e::11e").unwrap());
|
||||||
|
let short_time = Duration::from_millis(100);
|
||||||
|
let long_time = Duration::from_secs(10000000000);
|
||||||
|
|
||||||
|
table.add_entry(foo.clone(), localhost, long_time);
|
||||||
|
table.add_entry(foo.clone(), other, short_time);
|
||||||
|
let wait_until = Instant::now() + (2 * short_time);
|
||||||
|
while Instant::now() < wait_until {
|
||||||
|
std::thread::sleep(short_time);
|
||||||
|
}
|
||||||
|
|
||||||
|
assert_eq!(1, table.lookup(&foo).len());
|
||||||
|
assert!(table.lookup(&foo).contains(&localhost));
|
||||||
|
assert!(!table.lookup(&foo).contains(&other));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn garbage_collection_works() {
|
||||||
|
let mut table = ResolutionTable::new();
|
||||||
|
let foo = Name::from_str("foo").unwrap();
|
||||||
|
let localhost = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
|
||||||
|
let other = IpAddr::V6(Ipv6Addr::from_str("2001:18e8:2:e::11e").unwrap());
|
||||||
|
let short_time = Duration::from_millis(100);
|
||||||
|
let long_time = Duration::from_secs(10000000000);
|
||||||
|
|
||||||
|
table.add_entry(foo.clone(), localhost, long_time);
|
||||||
|
table.add_entry(foo.clone(), other, short_time);
|
||||||
|
let wait_until = Instant::now() + (2 * short_time);
|
||||||
|
while Instant::now() < wait_until {
|
||||||
|
std::thread::sleep(short_time);
|
||||||
|
}
|
||||||
|
table.garbage_collect();
|
||||||
|
|
||||||
|
assert_eq!(1, table.inner.get(&foo).unwrap().len());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn garbage_collection_clears_empties() {
|
||||||
|
let mut table = ResolutionTable::new();
|
||||||
|
let foo = Name::from_str("foo").unwrap();
|
||||||
|
table.inner.insert(foo.clone(), vec![]);
|
||||||
|
table.garbage_collect();
|
||||||
|
assert!(table.inner.is_empty());
|
||||||
|
|
||||||
|
let localhost = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
|
||||||
|
let short_time = Duration::from_millis(100);
|
||||||
|
table.add_entry(foo.clone(), localhost, short_time);
|
||||||
|
let wait_until = Instant::now() + (2 * short_time);
|
||||||
|
while Instant::now() < wait_until {
|
||||||
|
std::thread::sleep(short_time);
|
||||||
|
}
|
||||||
|
table.garbage_collect();
|
||||||
|
assert!(table.inner.is_empty());
|
||||||
|
}
|
||||||
@@ -37,8 +37,6 @@ pub enum OperationalError {
|
|||||||
ChannelFailure,
|
ChannelFailure,
|
||||||
#[error("Error in initial handshake: {0}")]
|
#[error("Error in initial handshake: {0}")]
|
||||||
KeyxProcessingError(#[from] SshKeyExchangeProcessingError),
|
KeyxProcessingError(#[from] SshKeyExchangeProcessingError),
|
||||||
#[error("Call into random number generator failed: {0}")]
|
|
||||||
RngFailure(#[from] rand::Error),
|
|
||||||
#[error("Invalid port number '{port_string}': {error}")]
|
#[error("Invalid port number '{port_string}': {error}")]
|
||||||
InvalidPort {
|
InvalidPort {
|
||||||
port_string: String,
|
port_string: String,
|
||||||
|
|||||||
@@ -44,10 +44,10 @@ where
|
|||||||
/// write operations are cancel- and concurrency-safe. They take ownership
|
/// write operations are cancel- and concurrency-safe. They take ownership
|
||||||
/// of the underlying stream once established, where "established" means
|
/// of the underlying stream once established, where "established" means
|
||||||
/// that we have read and written the initial SSH banners.
|
/// that we have read and written the initial SSH banners.
|
||||||
pub fn new(stream: Stream) -> SshChannel<Stream> {
|
pub fn new(stream: Stream) -> Result<SshChannel<Stream>, getrandom::Error> {
|
||||||
let (read_half, write_half) = tokio::io::split(stream);
|
let (read_half, write_half) = tokio::io::split(stream);
|
||||||
|
|
||||||
SshChannel {
|
Ok(SshChannel {
|
||||||
read_side: Mutex::new(ReadSide {
|
read_side: Mutex::new(ReadSide {
|
||||||
stream: read_half,
|
stream: read_half,
|
||||||
buffer: BytesMut::with_capacity(MAX_BUFFER_SIZE),
|
buffer: BytesMut::with_capacity(MAX_BUFFER_SIZE),
|
||||||
@@ -55,12 +55,12 @@ where
|
|||||||
write_side: Mutex::new(WriteSide {
|
write_side: Mutex::new(WriteSide {
|
||||||
stream: write_half,
|
stream: write_half,
|
||||||
buffer: BytesMut::with_capacity(MAX_BUFFER_SIZE),
|
buffer: BytesMut::with_capacity(MAX_BUFFER_SIZE),
|
||||||
rng: ChaCha20Rng::from_entropy(),
|
rng: ChaCha20Rng::try_from_os_rng()?,
|
||||||
}),
|
}),
|
||||||
mac_length: 0,
|
mac_length: 0,
|
||||||
cipher_block_size: 8,
|
cipher_block_size: 8,
|
||||||
channel_is_closed: false,
|
channel_is_closed: false,
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Read an SshPacket from the wire.
|
/// Read an SshPacket from the wire.
|
||||||
@@ -155,7 +155,7 @@ where
|
|||||||
encoded_packet.put(packet.buffer);
|
encoded_packet.put(packet.buffer);
|
||||||
|
|
||||||
for _ in 0..rounded_padding {
|
for _ in 0..rounded_padding {
|
||||||
encoded_packet.put_u8(rng.gen());
|
encoded_packet.put_u8(rng.random());
|
||||||
}
|
}
|
||||||
|
|
||||||
Some(encoded_packet.freeze())
|
Some(encoded_packet.freeze())
|
||||||
@@ -247,8 +247,8 @@ proptest::proptest! {
|
|||||||
fn can_read_back_anything(packet in SshPacket::arbitrary()) {
|
fn can_read_back_anything(packet in SshPacket::arbitrary()) {
|
||||||
let result = tokio::runtime::Runtime::new().unwrap().block_on(async {
|
let result = tokio::runtime::Runtime::new().unwrap().block_on(async {
|
||||||
let (left, right) = tokio::io::duplex(8192);
|
let (left, right) = tokio::io::duplex(8192);
|
||||||
let leftssh = SshChannel::new(left);
|
let leftssh = SshChannel::new(left).unwrap();
|
||||||
let rightssh = SshChannel::new(right);
|
let rightssh = SshChannel::new(right).unwrap();
|
||||||
let packet_copy = packet.clone();
|
let packet_copy = packet.clone();
|
||||||
|
|
||||||
tokio::task::spawn(async move {
|
tokio::task::spawn(async move {
|
||||||
@@ -264,8 +264,8 @@ proptest::proptest! {
|
|||||||
fn sequences_send_correctly_serial(sequence in proptest::collection::vec(SendData::arbitrary(), 0..100)) {
|
fn sequences_send_correctly_serial(sequence in proptest::collection::vec(SendData::arbitrary(), 0..100)) {
|
||||||
tokio::runtime::Runtime::new().unwrap().block_on(async {
|
tokio::runtime::Runtime::new().unwrap().block_on(async {
|
||||||
let (left, right) = tokio::io::duplex(8192);
|
let (left, right) = tokio::io::duplex(8192);
|
||||||
let leftssh = SshChannel::new(left);
|
let leftssh = SshChannel::new(left).unwrap();
|
||||||
let rightssh = SshChannel::new(right);
|
let rightssh = SshChannel::new(right).unwrap();
|
||||||
|
|
||||||
let sequence_left = sequence.clone();
|
let sequence_left = sequence.clone();
|
||||||
let sequence_right = sequence;
|
let sequence_right = sequence;
|
||||||
@@ -335,9 +335,9 @@ proptest::proptest! {
|
|||||||
|
|
||||||
tokio::runtime::Runtime::new().unwrap().block_on(async {
|
tokio::runtime::Runtime::new().unwrap().block_on(async {
|
||||||
let (left, right) = tokio::io::duplex(8192);
|
let (left, right) = tokio::io::duplex(8192);
|
||||||
let leftsshw = Arc::new(SshChannel::new(left));
|
let leftsshw = Arc::new(SshChannel::new(left).unwrap());
|
||||||
let leftsshr = leftsshw.clone();
|
let leftsshr = leftsshw.clone();
|
||||||
let rightsshw = Arc::new(SshChannel::new(right));
|
let rightsshw = Arc::new(SshChannel::new(right).unwrap());
|
||||||
let rightsshr = rightsshw.clone();
|
let rightsshr = rightsshw.clone();
|
||||||
|
|
||||||
let sequence_left_write = sequence.clone();
|
let sequence_left_write = sequence.clone();
|
||||||
|
|||||||
@@ -134,7 +134,7 @@ impl SshKeyExchange {
|
|||||||
/// seed the message with a random cookie, but is otherwise deterministic.
|
/// seed the message with a random cookie, but is otherwise deterministic.
|
||||||
/// It will fail only in the case that the underlying random number
|
/// It will fail only in the case that the underlying random number
|
||||||
/// generator fails, and return exactly that error.
|
/// generator fails, and return exactly that error.
|
||||||
pub fn new<R>(rng: &mut R, value: ClientConnectionOpts) -> Result<Self, rand::Error>
|
pub fn new<R>(rng: &mut R, value: ClientConnectionOpts) -> Result<Self, ()>
|
||||||
where
|
where
|
||||||
R: CryptoRng + Rng,
|
R: CryptoRng + Rng,
|
||||||
{
|
{
|
||||||
@@ -185,7 +185,7 @@ impl SshKeyExchange {
|
|||||||
first_kex_packet_follows: value.predict.is_some(),
|
first_kex_packet_follows: value.predict.is_some(),
|
||||||
};
|
};
|
||||||
|
|
||||||
rng.try_fill(&mut result.cookie)?;
|
rng.fill(&mut result.cookie);
|
||||||
|
|
||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user