Workspacify
This commit is contained in:
20
resolver/Cargo.toml
Normal file
20
resolver/Cargo.toml
Normal file
@@ -0,0 +1,20 @@
|
||||
[package]
|
||||
name = "resolver"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
bytes = { workspace = true }
|
||||
configuration = { workspace = true }
|
||||
crypto = { workspace = true }
|
||||
error-stack = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
internment = { workspace = true }
|
||||
num_enum = { workspace = true }
|
||||
proptest = { workspace = true }
|
||||
proptest-derive = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
toml = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
url = { workspace = true }
|
||||
247
resolver/src/lib.rs
Normal file
247
resolver/src/lib.rs
Normal file
@@ -0,0 +1,247 @@
|
||||
pub mod name;
|
||||
mod protocol;
|
||||
mod resolution_table;
|
||||
|
||||
use configuration::resolver::{DnsConfig, NameServerConfig};
|
||||
use crate::name::Name;
|
||||
use crate::resolution_table::ResolutionTable;
|
||||
use error_stack::{report, ResultExt};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use thiserror::Error;
|
||||
use tokio::net::{TcpSocket, UdpSocket};
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::task::JoinSet;
|
||||
use tokio::time::{Duration, Instant};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ResolverConfigError {
|
||||
#[error("Bad local domain name provided")]
|
||||
BadDomainName,
|
||||
#[error("Couldn't create a DNS client for the given address, port, and protocol.")]
|
||||
FailedToCreateDnsClient,
|
||||
#[error("No DNS servers found to search, and mDNS not enabled")]
|
||||
NoHosts,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ResolveError {
|
||||
#[error("No servers available for query")]
|
||||
NoServersAvailable,
|
||||
#[error("No responses found for query")]
|
||||
NoResponses,
|
||||
#[error("Error reading response from server")]
|
||||
ResponseError,
|
||||
}
|
||||
|
||||
pub struct Resolver {
|
||||
search_domains: Vec<Name>,
|
||||
max_time_to_wait_for_initial: Duration,
|
||||
time_to_wait_after_first: Duration,
|
||||
time_to_wait_for_lingering: Duration,
|
||||
connections: Arc<Mutex<Vec<(NameServerConfig, protocol::Client)>>>,
|
||||
table: Arc<Mutex<ResolutionTable>>,
|
||||
tasks: JoinSet<()>,
|
||||
}
|
||||
|
||||
pub struct ResolverState {
|
||||
client_connections: Vec<(NameServerConfig, protocol::Client)>,
|
||||
cache: HashMap<Name, Vec<DnsResolution>>,
|
||||
}
|
||||
|
||||
pub struct DnsResolution {
|
||||
address: IpAddr,
|
||||
expires_at: Instant,
|
||||
}
|
||||
|
||||
impl Resolver {
|
||||
/// Create a new DNS resolution engine for use by some part of the system.
|
||||
pub async fn new(config: &DnsConfig) -> error_stack::Result<Self, ResolverConfigError> {
|
||||
let mut search_domains = Vec::new();
|
||||
let mut tasks = JoinSet::new();
|
||||
|
||||
if let Some(local) = config.local_domain.as_ref() {
|
||||
let name = Name::from_str(local)
|
||||
.change_context(ResolverConfigError::BadDomainName)
|
||||
.attach_printable("Trying to add local domain.")
|
||||
.attach_printable_lazy(|| "Offending non-name: '{local}'")?;
|
||||
|
||||
search_domains.push(name);
|
||||
}
|
||||
|
||||
for search_domain in config.search_domains.iter() {
|
||||
let name = Name::from_str(search_domain)
|
||||
.change_context(ResolverConfigError::BadDomainName)
|
||||
.attach_printable("Trying to add search domain.")
|
||||
.attach_printable_lazy(|| "Offending non-name: '{search_domain}'")?;
|
||||
|
||||
search_domains.push(name);
|
||||
}
|
||||
|
||||
let mut client_connections = Vec::new();
|
||||
|
||||
for target in config.name_servers.iter() {
|
||||
let port = target.address.port().unwrap_or(53);
|
||||
let Some(address) = target.address.host() else {
|
||||
return Err(report!(ResolverConfigError::FailedToCreateDnsClient))
|
||||
.attach_printable("No address to connect to?")
|
||||
.attach_printable_lazy(|| format!("Target address was {}", target.address));
|
||||
};
|
||||
|
||||
let address = match address {
|
||||
url::Host::Ipv4(addr) => IpAddr::V4(addr),
|
||||
url::Host::Ipv6(addr) => IpAddr::V6(addr),
|
||||
url::Host::Domain(name) => {
|
||||
return Err(report!(ResolverConfigError::FailedToCreateDnsClient))
|
||||
.attach_printable("Cannot use domain names to identify domain servers")
|
||||
.attach_printable_lazy(|| format!("Target address was {name}"));
|
||||
}
|
||||
};
|
||||
let target_addr = SocketAddr::new(address, port);
|
||||
|
||||
match target.address.scheme() {
|
||||
"tcp" => {
|
||||
let socket = if target_addr.is_ipv4() {
|
||||
TcpSocket::new_v4()
|
||||
} else {
|
||||
TcpSocket::new_v6()
|
||||
};
|
||||
|
||||
let socket = socket
|
||||
.change_context(ResolverConfigError::FailedToCreateDnsClient)
|
||||
.attach_printable("Could not create a socket")
|
||||
.attach_printable_lazy(|| {
|
||||
format!("For target DNS server {}", target.address)
|
||||
})?;
|
||||
|
||||
if let Some(bind_address) = target.bind_address {
|
||||
socket
|
||||
.bind(bind_address)
|
||||
.change_context(ResolverConfigError::FailedToCreateDnsClient)
|
||||
.attach_printable("Could not bind local address for socket.")
|
||||
.attach_printable_lazy(|| {
|
||||
format!("Binding to TCP address {}", bind_address)
|
||||
})
|
||||
.attach_printable_lazy(|| {
|
||||
format!("For target DNS server {}", target.address)
|
||||
})?;
|
||||
}
|
||||
|
||||
let stream = socket
|
||||
.connect(target_addr)
|
||||
.await
|
||||
.change_context(ResolverConfigError::FailedToCreateDnsClient)
|
||||
.attach_printable_lazy(|| {
|
||||
format!("Connecting to target {}", target_addr)
|
||||
})?;
|
||||
let client = protocol::Client::from_tcp(stream, &mut tasks)
|
||||
.await
|
||||
.change_context(ResolverConfigError::FailedToCreateDnsClient)
|
||||
.attach_printable_lazy(|| {
|
||||
format!("Connecting to target {}", target_addr)
|
||||
})?;
|
||||
|
||||
client_connections.push((target.clone(), client));
|
||||
}
|
||||
|
||||
"udp" => {
|
||||
let port = target.address.port().unwrap_or(53);
|
||||
let Some(address) = target.address.host() else {
|
||||
tracing::warn!(address = %target.address, "proposed domain server has no host");
|
||||
continue;
|
||||
};
|
||||
let address = match address {
|
||||
url::Host::Ipv4(addr) => IpAddr::V4(addr),
|
||||
url::Host::Ipv6(addr) => IpAddr::V6(addr),
|
||||
url::Host::Domain(name) => {
|
||||
tracing::warn!(
|
||||
address = %target.address,
|
||||
hostname = name,
|
||||
"currently, we can't use hostnames for domain servers"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let sock_addr = SocketAddr::new(address, port);
|
||||
let bind_address = target.bind_address.unwrap_or_else(|| {
|
||||
if sock_addr.is_ipv4() {
|
||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0)
|
||||
} else {
|
||||
SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0)
|
||||
}
|
||||
});
|
||||
|
||||
let udp_socket = UdpSocket::bind(bind_address)
|
||||
.await
|
||||
.change_context(ResolverConfigError::FailedToCreateDnsClient)
|
||||
.attach_printable("Generating UDP socket")
|
||||
.attach_printable_lazy(|| format!("binding to address {bind_address}"))
|
||||
.attach_printable_lazy(|| format!("targeting {}", target.address))?;
|
||||
|
||||
udp_socket
|
||||
.connect(target_addr)
|
||||
.await
|
||||
.change_context(ResolverConfigError::FailedToCreateDnsClient)
|
||||
.attach_printable("Connecting UDP socket")
|
||||
.attach_printable_lazy(|| format!("binding to address {bind_address}"))
|
||||
.attach_printable_lazy(|| format!("targeting {}", target.address))?;
|
||||
|
||||
let client = protocol::Client::from_udp(udp_socket, &mut tasks).await;
|
||||
|
||||
client_connections.push((target.clone(), client));
|
||||
}
|
||||
|
||||
"unix" => unimplemented!(),
|
||||
"unixd" => unimplemented!(),
|
||||
|
||||
"http" => unimplemented!(),
|
||||
"https" => unimplemented!(),
|
||||
|
||||
_ => {
|
||||
tracing::warn!(address = %target.address, "Unknown scheme for building DNS connections");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Resolver {
|
||||
search_domains,
|
||||
// FIXME: All of these should be configurable
|
||||
max_time_to_wait_for_initial: Duration::from_millis(150),
|
||||
time_to_wait_after_first: Duration::from_millis(50),
|
||||
time_to_wait_for_lingering: Duration::from_secs(2),
|
||||
connections: Arc::new(Mutex::new(client_connections)),
|
||||
table: Arc::new(Mutex::new(ResolutionTable::new())),
|
||||
tasks,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn lookup(&self, _name: &Name) -> error_stack::Result<HashSet<IpAddr>, ResolveError> {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
//#[tokio::test]
|
||||
//async fn fetch_cached() {
|
||||
// let resolver = Resolver::new(&DnsConfig::empty()).await.unwrap();
|
||||
// let name = Name::from_utf8("name.foo").unwrap();
|
||||
// let addr = IpAddr::from_str("1.2.4.5").unwrap();
|
||||
//
|
||||
// resolver
|
||||
// .inject_resolution(name.clone(), addr.clone(), Duration::from_secs(100000))
|
||||
// .await;
|
||||
// let read = resolver.lookup(&name).await.unwrap();
|
||||
// assert!(read.contains(&addr));
|
||||
//}
|
||||
|
||||
//#[tokio::test]
|
||||
//async fn uhsure() {
|
||||
// let resolver = Resolver::new(&DnsConfig::default()).await.unwrap();
|
||||
// let name = Name::from_ascii("uhsure.com").unwrap();
|
||||
// let result = resolver.lookup(&name).await.unwrap();
|
||||
// println!("result = {:?}", result);
|
||||
// assert!(!result.is_empty());
|
||||
//}
|
||||
505
resolver/src/name.rs
Normal file
505
resolver/src/name.rs
Normal file
@@ -0,0 +1,505 @@
|
||||
use bytes::{Buf, BufMut};
|
||||
use error_stack::{report, ResultExt};
|
||||
use internment::ArcIntern;
|
||||
use proptest::arbitrary::Arbitrary;
|
||||
use proptest::char::CharStrategy;
|
||||
use proptest::strategy::{BoxedStrategy, Strategy};
|
||||
use std::borrow::Cow;
|
||||
use std::fmt;
|
||||
use std::hash::Hash;
|
||||
use std::ops::{Range, RangeInclusive};
|
||||
use std::str::FromStr;
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Clone, Hash, PartialEq, Eq)]
|
||||
pub struct Name {
|
||||
labels: Vec<Label>,
|
||||
}
|
||||
|
||||
impl fmt::Debug for Name {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let mut first_section = true;
|
||||
|
||||
for &Label(ref label) in self.labels.iter() {
|
||||
if first_section {
|
||||
first_section = false;
|
||||
write!(f, "{}", label.as_str())?;
|
||||
} else {
|
||||
write!(f, ".{}", label.as_str())?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Name {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
<Self as fmt::Debug>::fmt(self, f)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Hash, Eq)]
|
||||
pub struct Label(ArcIntern<String>);
|
||||
|
||||
impl fmt::Debug for Label {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
self.0.as_str().fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Label {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
self.0.as_str().fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for Label {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.0.eq_ignore_ascii_case(other.0.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum NameParseError {
|
||||
#[error("Provided name '{name}' is too long ({observed_length} bytes); maximum length of a DNS name is 255 octets.")]
|
||||
NameTooLong {
|
||||
name: String,
|
||||
observed_length: usize,
|
||||
},
|
||||
|
||||
#[error("Provided name '{name}' contains illegal character '{illegal_character}'.")]
|
||||
NonAsciiCharacter {
|
||||
name: String,
|
||||
illegal_character: char,
|
||||
},
|
||||
|
||||
#[error("Provided name '{name}' contains an empty section/label, which isn't allowed.")]
|
||||
EmptyLabel { name: String },
|
||||
|
||||
#[error("Provided name '{name}' contains a label ('{label}') that is too long ({observed_length} letters, which is more than 63).")]
|
||||
LabelTooLong {
|
||||
observed_length: usize,
|
||||
name: String,
|
||||
label: String,
|
||||
},
|
||||
|
||||
#[error("Provided name '{name}' contains a label ('{label}') that begins with an illegal character; it must be a letter.")]
|
||||
LabelStartsWrong { name: String, label: String },
|
||||
|
||||
#[error("Provided name '{name}' contains a label ('{label}') that ends with an illegal character; it must be a letter or number.")]
|
||||
LabelEndsWrong { name: String, label: String },
|
||||
|
||||
#[error("Provided name '{name}' contains a label ('{label}') that contains a non-letter, non-number, and non-dash.")]
|
||||
IllegalInternalCharacter { name: String, label: String },
|
||||
}
|
||||
|
||||
impl FromStr for Name {
|
||||
type Err = NameParseError;
|
||||
|
||||
fn from_str(value: &str) -> Result<Self, Self::Err> {
|
||||
let observed_length = value.as_bytes().len();
|
||||
|
||||
if observed_length > 255 {
|
||||
return Err(NameParseError::NameTooLong {
|
||||
name: value.to_string(),
|
||||
observed_length,
|
||||
});
|
||||
}
|
||||
|
||||
for char in value.chars() {
|
||||
if !(char.is_ascii_alphanumeric() || char == '.' || char == '-') {
|
||||
return Err(NameParseError::NonAsciiCharacter {
|
||||
name: value.to_string(),
|
||||
illegal_character: char,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let mut labels = vec![];
|
||||
|
||||
for label_str in value.split('.') {
|
||||
if label_str.is_empty() {
|
||||
return Err(NameParseError::EmptyLabel {
|
||||
name: value.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if label_str.len() > 63 {
|
||||
return Err(NameParseError::LabelTooLong {
|
||||
name: value.to_string(),
|
||||
label: label_str.to_string(),
|
||||
observed_length: label_str.len(),
|
||||
});
|
||||
}
|
||||
|
||||
let letter = |x| ('a'..='z').contains(&x) || ('A'..='Z').contains(&x);
|
||||
let letter_or_num = |x| letter(x) || ('0'..='9').contains(&x);
|
||||
let letter_num_dash = |x| letter_or_num(x) || (x == '-');
|
||||
|
||||
if !label_str.starts_with(letter) {
|
||||
return Err(NameParseError::LabelStartsWrong {
|
||||
name: value.to_string(),
|
||||
label: label_str.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if !label_str.ends_with(letter_or_num) {
|
||||
return Err(NameParseError::LabelEndsWrong {
|
||||
name: value.to_string(),
|
||||
label: label_str.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if label_str.contains(|x| !letter_num_dash(x)) {
|
||||
return Err(NameParseError::IllegalInternalCharacter {
|
||||
name: value.to_string(),
|
||||
label: label_str.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// RFC 1035 says that all domain names are case-insensitive. We
|
||||
// arbitrarily normalize to lowercase here, because it shouldn't
|
||||
// matter to anyone.
|
||||
labels.push(Label(ArcIntern::new(label_str.to_string())));
|
||||
}
|
||||
|
||||
Ok(Name { labels })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum NameReadError {
|
||||
#[error("Could not read name field out of an empty buffer.")]
|
||||
EmptyBuffer,
|
||||
|
||||
#[error("Buffer truncated before we could read the last label")]
|
||||
TruncatedBuffer,
|
||||
|
||||
#[error("Provided label value too long; must be 63 octets or less.")]
|
||||
LabelTooLong,
|
||||
|
||||
#[error("Label truncated while reading. Broken stream?")]
|
||||
LabelTruncated,
|
||||
|
||||
#[error("Label starts with an illegal character (must be [A-Za-z])")]
|
||||
WrongFirstByte,
|
||||
|
||||
#[error("Label ends with an illegal character (must be [A-Za-z0-9]")]
|
||||
WrongLastByte,
|
||||
|
||||
#[error("Label contains an illegal character (must be [A-Za-z0-9] or a dash)")]
|
||||
WrongInnerByte,
|
||||
}
|
||||
|
||||
impl Name {
|
||||
/// Read a name frm a record, or return an error describing what went wrong.
|
||||
///
|
||||
/// This function will advance the read pointer on the record. If the result is
|
||||
/// an error, the read pointer is not guaranteed to be in a good place, so you
|
||||
/// may need to manage that externally if you want to implement some sort of
|
||||
/// "try" functionality.
|
||||
pub fn read<B: Buf>(buffer: &mut B) -> error_stack::Result<Name, NameReadError> {
|
||||
let mut labels = Vec::new();
|
||||
let name_so_far = |labels: Vec<Label>| {
|
||||
let mut result = String::new();
|
||||
|
||||
for Label(label) in labels.into_iter() {
|
||||
if result.is_empty() {
|
||||
result.push_str(label.as_str());
|
||||
} else {
|
||||
result.push('.');
|
||||
result.push_str(label.as_str());
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
};
|
||||
|
||||
loop {
|
||||
if !buffer.has_remaining() && labels.is_empty() {
|
||||
return Err(report!(NameReadError::EmptyBuffer));
|
||||
}
|
||||
|
||||
if !buffer.has_remaining() {
|
||||
return Err(report!(NameReadError::TruncatedBuffer))
|
||||
.attach_printable_lazy(|| format!("name thus far: '{}'", name_so_far(labels)));
|
||||
}
|
||||
|
||||
let label_octet_length = buffer.get_u8() as usize;
|
||||
if label_octet_length == 0 {
|
||||
break;
|
||||
}
|
||||
|
||||
if label_octet_length > 63 {
|
||||
return Err(report!(NameReadError::LabelTooLong)).attach_printable_lazy(|| {
|
||||
format!(
|
||||
"label length too big; max is supposed to be 63, saw {label_octet_length}"
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
if !buffer.remaining() < label_octet_length {
|
||||
let remaining = buffer.copy_to_bytes(buffer.remaining());
|
||||
let partial = String::from_utf8_lossy(&remaining).to_string();
|
||||
|
||||
return Err(report!(NameReadError::LabelTruncated))
|
||||
.attach_printable_lazy(|| format!("name thus far: '{}'", name_so_far(labels)))
|
||||
.attach_printable_lazy(|| {
|
||||
format!(
|
||||
"expected {} octets, but only found {}",
|
||||
label_octet_length,
|
||||
buffer.remaining()
|
||||
)
|
||||
})
|
||||
.attach_printable_lazy(|| format!("partial read: '{partial}'"));
|
||||
}
|
||||
|
||||
let label_bytes = buffer.copy_to_bytes(label_octet_length);
|
||||
let Some(first_byte) = label_bytes.first() else {
|
||||
panic!(
|
||||
"INTERNAL ERROR: Should have at least one byte, we checked this previously."
|
||||
);
|
||||
};
|
||||
let Some(last_byte) = label_bytes.last() else {
|
||||
panic!(
|
||||
"INTERNAL ERROR: Should have at least one byte, we checked this previously."
|
||||
);
|
||||
};
|
||||
|
||||
let letter = |x| (b'a'..=b'z').contains(&x) || (b'A'..=b'Z').contains(&x);
|
||||
let letter_or_num = |x| letter(x) || (b'0'..=b'9').contains(&x);
|
||||
let letter_num_dash = |x| letter_or_num(x) || (x == b'-');
|
||||
|
||||
if !letter(*first_byte) {
|
||||
return Err(report!(NameReadError::WrongFirstByte))
|
||||
.attach_printable_lazy(|| format!("name thus far: '{}'", name_so_far(labels)))
|
||||
.attach_printable_lazy(|| {
|
||||
format!("bad label: '{}'", String::from_utf8_lossy(&label_bytes))
|
||||
});
|
||||
}
|
||||
|
||||
if !letter_or_num(*last_byte) {
|
||||
return Err(report!(NameReadError::WrongLastByte))
|
||||
.attach_printable_lazy(|| format!("name thus far: '{}'", name_so_far(labels)))
|
||||
.attach_printable_lazy(|| {
|
||||
format!("bad label: '{}'", String::from_utf8_lossy(&label_bytes))
|
||||
});
|
||||
}
|
||||
|
||||
if label_bytes.iter().any(|x| !letter_num_dash(*x)) {
|
||||
return Err(report!(NameReadError::WrongInnerByte))
|
||||
.attach_printable_lazy(|| format!("name thus far: '{}'", name_so_far(labels)))
|
||||
.attach_printable_lazy(|| {
|
||||
format!("bad label: '{}'", String::from_utf8_lossy(&label_bytes))
|
||||
});
|
||||
}
|
||||
|
||||
let label = label_bytes.into_iter().map(|x| x as char).collect();
|
||||
labels.push(Label(ArcIntern::new(label)));
|
||||
}
|
||||
|
||||
Ok(Name { labels })
|
||||
}
|
||||
|
||||
/// Write a name out to the given buffer.
|
||||
///
|
||||
/// This will try as hard as it can to write the value to the given buffer. If an
|
||||
/// error occurs, you may end up with partially-written data, so if you're worried
|
||||
/// about that you should be careful to mark where you started in the output buffer.
|
||||
pub fn write<B: BufMut>(&self, buffer: &mut B) -> error_stack::Result<(), NameWriteError> {
|
||||
for &Label(ref label) in self.labels.iter() {
|
||||
let bytes = label.as_bytes();
|
||||
|
||||
if buffer.remaining_mut() < (bytes.len() + 1) {
|
||||
return Err(report!(NameWriteError::NoRoomForLabel))
|
||||
.attach_printable_lazy(|| format!("Writing name {self}"))
|
||||
.attach_printable_lazy(|| format!("For label {label}"));
|
||||
}
|
||||
|
||||
if bytes.is_empty() || bytes.len() > 63 {
|
||||
return Err(report!(NameWriteError::IllegalLabel))
|
||||
.attach_printable_lazy(|| format!("Writing name {self}"))
|
||||
.attach_printable_lazy(|| format!("For label {label}"));
|
||||
}
|
||||
|
||||
buffer.put_u8(bytes.len() as u8);
|
||||
buffer.put_slice(bytes);
|
||||
}
|
||||
|
||||
if !buffer.has_remaining_mut() {
|
||||
return Err(report!(NameWriteError::NoRoomForNull))
|
||||
.attach_printable_lazy(|| format!("Writing name {self}"));
|
||||
}
|
||||
buffer.put_u8(0);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum NameWriteError {
|
||||
#[error("Not enough room to write label in name")]
|
||||
NoRoomForLabel,
|
||||
|
||||
#[error("Ran out of room writing the terminating NULL for the name")]
|
||||
NoRoomForNull,
|
||||
|
||||
#[error("Internal error: Illegal label (this shouldn't happen)")]
|
||||
IllegalLabel,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ArbitraryDomainNameSpecifications {
|
||||
total_length_range: Range<usize>,
|
||||
label_length_range: Range<usize>,
|
||||
number_labels: Range<usize>,
|
||||
}
|
||||
|
||||
impl Default for ArbitraryDomainNameSpecifications {
|
||||
fn default() -> Self {
|
||||
ArbitraryDomainNameSpecifications {
|
||||
total_length_range: 5..256,
|
||||
label_length_range: 1..64,
|
||||
number_labels: 2..5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Arbitrary for Name {
|
||||
type Parameters = ArbitraryDomainNameSpecifications;
|
||||
type Strategy = BoxedStrategy<Name>;
|
||||
|
||||
fn arbitrary_with(spec: Self::Parameters) -> Self::Strategy {
|
||||
(spec.number_labels.clone(), spec.total_length_range.clone())
|
||||
.prop_flat_map(move |(mut labels, mut total_length)| {
|
||||
// we need to make sure that our total length and our label count are,
|
||||
// at the very minimum, compatible. If they're not, we'll need to adjust
|
||||
// them.
|
||||
//
|
||||
// in general, we prefer to update our number of labels, if we can, but
|
||||
// only to the extent that the labels count is at least the minimum of
|
||||
// our input specification. if we try to go below that, we increase total
|
||||
// length as required. If we're forced into a place where we need to take
|
||||
// the label length below it's minimum and/or the total_length over its
|
||||
// maximum, we just give up and panic.
|
||||
//
|
||||
// Note that the minimum length of n labels is (n * label_minimum) + n - 1.
|
||||
// Consider, for example, a label_minimum of 1 and a label length of 3.
|
||||
// A minimum string is "a.b.c", which is (3 * 1) + 3 - 1 = 5 characters
|
||||
// long.
|
||||
//
|
||||
// This loop does the first part, lowering the number of labels until we
|
||||
// either get below the total length or reach the minimum number of labels.
|
||||
while (labels * spec.label_length_range.start) + (labels - 1) > total_length {
|
||||
if labels == spec.number_labels.start {
|
||||
break;
|
||||
} else {
|
||||
labels -= 1;
|
||||
}
|
||||
}
|
||||
|
||||
// At this point, if it's not right, we just set it to be right.
|
||||
if (labels * spec.total_length_range.start) + (labels - 1) > total_length {
|
||||
total_length = (labels * spec.total_length_range.start) + (labels - 1);
|
||||
}
|
||||
|
||||
// And if this takes us over our limit, just panic.
|
||||
if total_length >= spec.total_length_range.end {
|
||||
panic!("Unresolvable generation condition; couldn't resolve label count {} with total_length {}, with specification {:?}", labels, total_length, spec);
|
||||
}
|
||||
|
||||
proptest::collection::vec(Label::arbitrary_with(spec.label_length_range.clone()), labels)
|
||||
.prop_map(|labels| Name{ labels })
|
||||
}).boxed()
|
||||
}
|
||||
}
|
||||
|
||||
impl Arbitrary for Label {
|
||||
type Parameters = Range<usize>;
|
||||
type Strategy = BoxedStrategy<Self>;
|
||||
|
||||
fn arbitrary() -> Self::Strategy {
|
||||
Self::arbitrary_with(1..64)
|
||||
}
|
||||
|
||||
fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
|
||||
args.prop_flat_map(|length| {
|
||||
let first = char_selector(&['a'..='z', 'A'..='Z']);
|
||||
let middle = char_selector(&['a'..='z', 'A'..='Z', '-'..='-', '0'..='9']);
|
||||
let last = char_selector(&['a'..='z', 'A'..='Z', '0'..='9']);
|
||||
|
||||
match length {
|
||||
0 => panic!("Should not be able to generate a label of length 0"),
|
||||
1 => first.prop_map(|x| x.into()).boxed(),
|
||||
2 => (first, last).prop_map(|(a, b)| format!("{a}{b}")).boxed(),
|
||||
_ => (first, proptest::collection::vec(middle, length - 2), last)
|
||||
.prop_map(move |(first, middle, last)| {
|
||||
let mut result = String::with_capacity(length);
|
||||
result.push(first);
|
||||
for c in middle.iter() {
|
||||
result.push(*c);
|
||||
}
|
||||
result.push(last);
|
||||
result
|
||||
})
|
||||
.boxed(),
|
||||
}
|
||||
})
|
||||
.prop_map(|x| Label(ArcIntern::new(x)))
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
fn char_selector<'a>(ranges: &'a [RangeInclusive<char>]) -> CharStrategy<'a> {
|
||||
CharStrategy::new(
|
||||
Cow::Borrowed(&[]),
|
||||
Cow::Borrowed(&[]),
|
||||
Cow::Borrowed(ranges),
|
||||
)
|
||||
}
|
||||
|
||||
proptest::proptest! {
|
||||
#[test]
|
||||
fn any_random_names_parses(name: Name) {
|
||||
let name_str = name.to_string();
|
||||
let new_name = Name::from_str(&name_str).expect("can re-parse name");
|
||||
assert_eq!(name, new_name);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn any_random_name_roundtrips(name: Name) {
|
||||
let mut write_buffer = bytes::BytesMut::with_capacity(512);
|
||||
name.write(&mut write_buffer).expect("can write name");
|
||||
let mut read_buffer = write_buffer.freeze();
|
||||
let new_name = Name::read(&mut read_buffer).expect("can read name");
|
||||
assert_eq!(name, new_name);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn illegal_names_generate_errors() {
|
||||
assert!(Name::from_str("").is_err());
|
||||
assert!(Name::from_str(".").is_err());
|
||||
assert!(Name::from_str(".com").is_err());
|
||||
assert!(Name::from_str("com.").is_err());
|
||||
assert!(Name::from_str("9.com").is_err());
|
||||
assert!(Name::from_str("foo-.com").is_err());
|
||||
assert!(Name::from_str("-foo.com").is_err());
|
||||
assert!(Name::from_str(".foo.com").is_err());
|
||||
assert!(Name::from_str("fo*o.com").is_err());
|
||||
assert!(Name::from_str(
|
||||
"foo.abcdefghiabcdefghiabcdefghiabcdefghiabcdefghiabcdefghijjjjjjabcdefghij.com"
|
||||
)
|
||||
.is_err());
|
||||
assert!(Name::from_str("abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn names_ignore_case() {
|
||||
assert_eq!(
|
||||
Name::from_str("UHSURE.COM").unwrap(),
|
||||
Name::from_str("uhsure.com").unwrap()
|
||||
);
|
||||
}
|
||||
9
resolver/src/protocol.rs
Normal file
9
resolver/src/protocol.rs
Normal file
@@ -0,0 +1,9 @@
|
||||
mod client;
|
||||
mod header;
|
||||
pub mod question;
|
||||
mod request;
|
||||
mod resource_record;
|
||||
mod response;
|
||||
mod server;
|
||||
|
||||
pub use client::Client;
|
||||
273
resolver/src/protocol/client.rs
Normal file
273
resolver/src/protocol/client.rs
Normal file
@@ -0,0 +1,273 @@
|
||||
use crate::protocol::header::Header;
|
||||
use crate::protocol::question::Question;
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use error_stack::ResultExt;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncReadExt, WriteHalf};
|
||||
use tokio::net::{TcpStream, UdpSocket, UnixDatagram, UnixStream};
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::task::JoinSet;
|
||||
|
||||
type Callback = fn() -> ();
|
||||
|
||||
pub struct Client {
|
||||
callback: Arc<RwLock<Callback>>,
|
||||
channel: DnsChannel,
|
||||
}
|
||||
|
||||
pub enum DnsChannel {
|
||||
Tcp(WriteHalf<TcpStream>),
|
||||
Udp(Arc<UdpSocket>),
|
||||
UnixData(Arc<UnixDatagram>),
|
||||
Unix(WriteHalf<UnixStream>),
|
||||
}
|
||||
|
||||
fn empty_callback() {}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum SendError {}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
enum GeneralAddr {
|
||||
Network(SocketAddr),
|
||||
Unix(std::os::unix::net::SocketAddr),
|
||||
}
|
||||
|
||||
impl From<SocketAddr> for GeneralAddr {
|
||||
fn from(value: SocketAddr) -> Self {
|
||||
GeneralAddr::Network(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<tokio::net::unix::SocketAddr> for GeneralAddr {
|
||||
fn from(value: tokio::net::unix::SocketAddr) -> Self {
|
||||
GeneralAddr::Unix(value.into())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
enum ServerProcessorError {
|
||||
#[error("Could not read message header from server.")]
|
||||
CouldNotReadHeader,
|
||||
}
|
||||
|
||||
async fn process_server_response(
|
||||
callback: &Arc<RwLock<Callback>>,
|
||||
mut bytes: Bytes,
|
||||
source: GeneralAddr,
|
||||
) -> error_stack::Result<(), ServerProcessorError> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
async fn run_response_processing_loop(
|
||||
server: String,
|
||||
maximum_consecurive_errors: u64,
|
||||
callback: Arc<RwLock<Callback>>,
|
||||
mut fetcher: impl AsyncFnMut() -> Result<(Bytes, GeneralAddr), std::io::Error>,
|
||||
) {
|
||||
let mut consecutive_errors = 0;
|
||||
|
||||
loop {
|
||||
match fetcher().await {
|
||||
Ok((bytes, source)) => {
|
||||
if let Err(e) = process_server_response(&callback, bytes, source).await {
|
||||
consecutive_errors += 1;
|
||||
tracing::warn!(
|
||||
server,
|
||||
error = %e,
|
||||
maximum_consecurive_errors,
|
||||
consecutive_errors,
|
||||
"error processing DNS server response"
|
||||
);
|
||||
} else {
|
||||
consecutive_errors = 0;
|
||||
}
|
||||
}
|
||||
|
||||
Err(e) => {
|
||||
consecutive_errors += 1;
|
||||
tracing::warn!(
|
||||
server,
|
||||
error = %e,
|
||||
maximum_consecurive_errors,
|
||||
consecutive_errors,
|
||||
"failed to read response from DNS server"
|
||||
);
|
||||
|
||||
if consecutive_errors >= maximum_consecurive_errors {
|
||||
break;
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
tracing::error!(
|
||||
server,
|
||||
"quitting DNS response processing loop due to too many consecutive errors"
|
||||
);
|
||||
}
|
||||
|
||||
impl Client {
|
||||
/// Create a new DNS client from the given, targeted UnixDatagram socket.
|
||||
///
|
||||
/// By "targeted", we mean that this socket should've had `connect` called on
|
||||
/// it, so that the client can just send datagrams to the server without having
|
||||
/// to know where to send them. It is unspecified what will happen to DNS clients
|
||||
/// if you do not call `connect` beforehand.
|
||||
pub async fn from_unix_datagram(socket: UnixDatagram, group: &mut JoinSet<()>) -> Self {
|
||||
let callback = Arc::new(RwLock::new(empty_callback as Callback));
|
||||
let socket = Arc::new(socket);
|
||||
|
||||
let reader_socket = socket.clone();
|
||||
let reader_callback = callback.clone();
|
||||
let server_addr = socket.local_addr()
|
||||
.ok()
|
||||
.map(|x| x.as_pathname().map(|x| x.display().to_string()))
|
||||
.flatten()
|
||||
.unwrap_or_else(|| "<unknown>".into());
|
||||
let server = format!("unixd://{}", server_addr);
|
||||
|
||||
group.spawn(async move {
|
||||
let fetcher = async move || {
|
||||
let mut buffer = BytesMut::with_capacity(16384);
|
||||
|
||||
match reader_socket.recv_buf_from(&mut buffer).await {
|
||||
Err(x) => Err(x),
|
||||
Ok((size, from)) => {
|
||||
unsafe {
|
||||
buffer.set_len(size);
|
||||
};
|
||||
Ok((buffer.freeze(), from.into()))
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
run_response_processing_loop(server, 5, reader_callback, fetcher).await;
|
||||
});
|
||||
|
||||
Client {
|
||||
callback,
|
||||
channel: DnsChannel::UnixData(socket),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn from_unix_stream(
|
||||
socket: UnixStream,
|
||||
group: &mut JoinSet<()>,
|
||||
) -> std::io::Result<Self> {
|
||||
let callback = Arc::new(RwLock::new(empty_callback as Callback));
|
||||
let server_addr = socket.local_addr()
|
||||
.ok()
|
||||
.map(|x| x.as_pathname().map(|x| x.display().to_string()))
|
||||
.flatten()
|
||||
.unwrap_or_else(|| "<unknown>".into());
|
||||
let server = format!("unix://{}", server_addr);
|
||||
let other_side = GeneralAddr::Unix(socket.peer_addr()?.into());
|
||||
let (mut reader, writer) = tokio::io::split(socket);
|
||||
|
||||
let reader_callback = callback.clone();
|
||||
group.spawn(async move {
|
||||
let fetcher = async move || {
|
||||
let size = reader.read_u16().await?;
|
||||
|
||||
let mut buffer = vec![0u8; size as usize];
|
||||
reader.read_exact(&mut buffer).await?;
|
||||
|
||||
Ok((Bytes::from(buffer), other_side.clone()))
|
||||
};
|
||||
|
||||
run_response_processing_loop(server, 5, reader_callback, fetcher).await;
|
||||
});
|
||||
|
||||
Ok(Client {
|
||||
callback,
|
||||
channel: DnsChannel::Unix(writer),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn from_udp(socket: UdpSocket, group: &mut JoinSet<()>) -> Self {
|
||||
let callback = Arc::new(RwLock::new(empty_callback as Callback));
|
||||
let socket = Arc::new(socket);
|
||||
|
||||
let reader_callback = callback.clone();
|
||||
let reader_socket = socket.clone();
|
||||
let server_addr = socket.local_addr()
|
||||
.ok()
|
||||
.map(|x| x.to_string())
|
||||
.unwrap_or_else(|| "<unknown>".into());
|
||||
let server = format!("udp://{}", server_addr);
|
||||
|
||||
group.spawn(async move {
|
||||
let fetcher = async move || {
|
||||
let mut buffer = BytesMut::with_capacity(16384);
|
||||
|
||||
match reader_socket.recv_buf_from(&mut buffer).await {
|
||||
Err(x) => Err(x),
|
||||
Ok((size, from)) => {
|
||||
unsafe {
|
||||
buffer.set_len(size);
|
||||
};
|
||||
Ok((buffer.freeze(), from.into()))
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
run_response_processing_loop(server, 5, reader_callback, fetcher).await;
|
||||
});
|
||||
|
||||
Client {
|
||||
callback,
|
||||
channel: DnsChannel::Udp(socket),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn from_tcp(socket: TcpStream, group: &mut JoinSet<()>) -> std::io::Result<Self> {
|
||||
let callback = Arc::new(RwLock::new(empty_callback as Callback));
|
||||
let other_side = GeneralAddr::Network(socket.peer_addr()?);
|
||||
let server_addr = socket.local_addr()
|
||||
.ok()
|
||||
.map(|x| x.to_string())
|
||||
.unwrap_or_else(|| "<unknown>".into());
|
||||
let server = format!("tcp://{}", server_addr);
|
||||
let (mut reader, writer) = tokio::io::split(socket);
|
||||
|
||||
let reader_callback = callback.clone();
|
||||
group.spawn(async move {
|
||||
let fetcher = async move || {
|
||||
let size = reader.read_u16().await?;
|
||||
|
||||
let mut buffer = vec![0u8; size as usize];
|
||||
reader.read_exact(&mut buffer).await?;
|
||||
|
||||
Ok((Bytes::from(buffer), other_side.clone()))
|
||||
};
|
||||
|
||||
run_response_processing_loop(server, 5, reader_callback, fetcher).await;
|
||||
});
|
||||
|
||||
Ok(Client {
|
||||
callback,
|
||||
channel: DnsChannel::Tcp(writer),
|
||||
})
|
||||
}
|
||||
|
||||
/// Send a set of questions to the upstream server.
|
||||
///
|
||||
/// Any response(s) that is/are sent will be handled as part of the callback
|
||||
/// scheme. This function will thus only return an error if there's a problem
|
||||
/// sending the question to the server.
|
||||
pub async fn send_questions(_questions: Vec<Question>) -> error_stack::Result<(), SendError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Set the callback handler for when we receive responses from the server.
|
||||
///
|
||||
/// This function may take awhile to execute, depending on how busy we are
|
||||
/// taking responses, as it takes write ownership of a read-write lock that's
|
||||
/// almost always written.
|
||||
pub async fn set_callback(&self, callback: Callback) {
|
||||
*self.callback.write().await = callback;
|
||||
}
|
||||
}
|
||||
225
resolver/src/protocol/header.rs
Normal file
225
resolver/src/protocol/header.rs
Normal file
@@ -0,0 +1,225 @@
|
||||
use bytes::{Buf, BufMut};
|
||||
use error_stack::{report, ResultExt};
|
||||
use num_enum::{FromPrimitive, IntoPrimitive};
|
||||
use proptest::arbitrary::Arbitrary;
|
||||
use proptest::strategy::{BoxedStrategy, Just, Strategy};
|
||||
use std::fmt;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, proptest_derive::Arbitrary)]
|
||||
pub struct Header {
|
||||
pub message_id: u16,
|
||||
pub is_response: bool,
|
||||
pub opcode: OpCode,
|
||||
pub authoritative_answer: bool,
|
||||
pub message_truncated: bool,
|
||||
pub recursion_desired: bool,
|
||||
pub recursion_available: bool,
|
||||
pub response_code: ResponseCode,
|
||||
pub question_count: u16,
|
||||
pub answer_count: u16,
|
||||
pub name_server_count: u16,
|
||||
pub additional_record_count: u16,
|
||||
}
|
||||
|
||||
#[derive(FromPrimitive, IntoPrimitive, PartialEq, PartialOrd, Eq, Ord, Debug, Copy, Clone)]
|
||||
#[repr(u8)]
|
||||
pub enum OpCode {
|
||||
StandardQuery = 0,
|
||||
InverseQuery = 1,
|
||||
ServiceStatusRequest = 2,
|
||||
#[num_enum(catch_all)]
|
||||
Other(u8),
|
||||
}
|
||||
|
||||
impl Arbitrary for OpCode {
|
||||
type Parameters = ();
|
||||
type Strategy = BoxedStrategy<Self>;
|
||||
|
||||
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
|
||||
// while the type is 8 bits in rust, it's 4 bits in the protocol,
|
||||
// and dealing with the run-off is messy. so this only generates
|
||||
// valid-sized values here. it also biases toward the legit values.
|
||||
// both things might want to be reconsidered in the future.
|
||||
proptest::prop_oneof![
|
||||
Just(OpCode::StandardQuery),
|
||||
Just(OpCode::InverseQuery),
|
||||
Just(OpCode::ServiceStatusRequest),
|
||||
(3u8..=15).prop_map(OpCode::Other),
|
||||
]
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(FromPrimitive, IntoPrimitive, PartialEq, PartialOrd, Eq, Ord, Debug, Copy, Clone)]
|
||||
#[repr(u8)]
|
||||
pub enum ResponseCode {
|
||||
NoErrorConditions = 0,
|
||||
FormatError = 1,
|
||||
ServerFailure = 2,
|
||||
NameError = 3,
|
||||
NotImplemented = 4,
|
||||
Refused = 5,
|
||||
#[num_enum(catch_all)]
|
||||
Other(u8),
|
||||
}
|
||||
|
||||
impl fmt::Display for ResponseCode {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
ResponseCode::NoErrorConditions => write!(f, "No errors"),
|
||||
ResponseCode::FormatError => write!(f, "Illegal format"),
|
||||
ResponseCode::ServerFailure => write!(f, "Server failure"),
|
||||
ResponseCode::NameError => write!(f, "Name error"),
|
||||
ResponseCode::NotImplemented => write!(f, "Not implemented"),
|
||||
ResponseCode::Refused => write!(f, "Refused"),
|
||||
ResponseCode::Other(x) => write!(f, "unknown error {x}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Arbitrary for ResponseCode {
|
||||
type Parameters = ();
|
||||
type Strategy = BoxedStrategy<Self>;
|
||||
|
||||
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
|
||||
// while the type is 8 bits in rust, it's 4 bits in the protocol,
|
||||
// and dealing with the run-off is messy. so this only generates
|
||||
// valid-sized values here. it also biases toward the legit values.
|
||||
// both things might want to be reconsidered in the future.
|
||||
proptest::prop_oneof![
|
||||
Just(ResponseCode::NoErrorConditions),
|
||||
Just(ResponseCode::FormatError),
|
||||
Just(ResponseCode::ServerFailure),
|
||||
Just(ResponseCode::NameError),
|
||||
Just(ResponseCode::NotImplemented),
|
||||
Just(ResponseCode::Refused),
|
||||
(6u8..=15).prop_map(ResponseCode::Other),
|
||||
]
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum HeaderReadError {
|
||||
#[error("Buffer not large enough to have a DNS message header in it.")]
|
||||
BufferTooSmall,
|
||||
|
||||
#[error("Invalid data in zero-filled space.")]
|
||||
NonZeroInZeroes,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum HeaderWriteError {
|
||||
#[error("Buffer not large enough to write header.")]
|
||||
BufferTooSmall,
|
||||
}
|
||||
|
||||
impl Header {
|
||||
pub fn read<B: Buf>(buffer: &mut B) -> error_stack::Result<Self, HeaderReadError> {
|
||||
if buffer.remaining() < 12
|
||||
/* 6 16-bit fields = 6 * 2 = 1 */
|
||||
{
|
||||
return Err(report!(HeaderReadError::BufferTooSmall)).attach_printable_lazy(|| {
|
||||
format!(
|
||||
"Need at least {} bytes, but only had {}",
|
||||
6 * 12,
|
||||
buffer.remaining()
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
let message_id = buffer.get_u16();
|
||||
let flags = buffer.get_u16();
|
||||
let question_count = buffer.get_u16();
|
||||
let answer_count = buffer.get_u16();
|
||||
let name_server_count = buffer.get_u16();
|
||||
let additional_record_count = buffer.get_u16();
|
||||
|
||||
let is_response = (0x8000 & flags) != 0;
|
||||
let opcode = OpCode::from(((flags >> 11) & 0xF) as u8);
|
||||
let authoritative_answer = (0x0400 & flags) != 0;
|
||||
let message_truncated = (0x0200 & flags) != 0;
|
||||
let recursion_desired = (0x0100 & flags) != 0;
|
||||
let recursion_available = (0x0080 & flags) != 0;
|
||||
let zeroes = 0x0070 & flags;
|
||||
let response_code = ResponseCode::from((flags & 0x000F) as u8);
|
||||
|
||||
if zeroes != 0 {
|
||||
return Err(report!(HeaderReadError::NonZeroInZeroes))
|
||||
.attach_printable_lazy(|| format!("Saw {:#x} instead.", zeroes >> 4));
|
||||
}
|
||||
|
||||
Ok(Header {
|
||||
message_id,
|
||||
is_response,
|
||||
opcode,
|
||||
authoritative_answer,
|
||||
message_truncated,
|
||||
recursion_desired,
|
||||
recursion_available,
|
||||
response_code,
|
||||
question_count,
|
||||
answer_count,
|
||||
name_server_count,
|
||||
additional_record_count,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn write<B: BufMut>(self, buffer: &mut B) -> error_stack::Result<(), HeaderWriteError> {
|
||||
if buffer.remaining_mut() < 12
|
||||
/* 6 16-bit fields = 6 * 2 = 1 */
|
||||
{
|
||||
return Err(report!(HeaderWriteError::BufferTooSmall)).attach_printable_lazy(|| {
|
||||
format!(
|
||||
"Need at least {} to write DNS header, only have {}",
|
||||
6 * 12,
|
||||
buffer.remaining_mut()
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
let mut flags: u16 = 0;
|
||||
|
||||
if self.is_response {
|
||||
flags |= 0x8000;
|
||||
}
|
||||
let opcode: u8 = self.opcode.into();
|
||||
flags |= (opcode as u16) << 11;
|
||||
if self.authoritative_answer {
|
||||
flags |= 0x0400;
|
||||
}
|
||||
if self.message_truncated {
|
||||
flags |= 0x0200;
|
||||
}
|
||||
if self.recursion_desired {
|
||||
flags |= 0x0100;
|
||||
}
|
||||
if self.recursion_available {
|
||||
flags |= 0x0080;
|
||||
}
|
||||
let response_code: u8 = self.response_code.into();
|
||||
flags |= response_code as u16;
|
||||
|
||||
buffer.put_u16(self.message_id);
|
||||
buffer.put_u16(flags);
|
||||
buffer.put_u16(self.question_count);
|
||||
buffer.put_u16(self.answer_count);
|
||||
buffer.put_u16(self.name_server_count);
|
||||
buffer.put_u16(self.additional_record_count);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
proptest::proptest! {
|
||||
#[test]
|
||||
fn headers_roundtrip(header: Header) {
|
||||
let mut write_buffer = bytes::BytesMut::with_capacity(128 * 1024);
|
||||
let safe_header = header.clone();
|
||||
header.write(&mut write_buffer).expect("can write name");
|
||||
let mut read_buffer = write_buffer.freeze();
|
||||
let new_header = Header::read(&mut read_buffer).expect("can read name");
|
||||
assert_eq!(safe_header, new_header);
|
||||
}
|
||||
}
|
||||
91
resolver/src/protocol/question.rs
Normal file
91
resolver/src/protocol/question.rs
Normal file
@@ -0,0 +1,91 @@
|
||||
use crate::name::Name;
|
||||
use crate::protocol::resource_record::raw::{RecordClass, RecordType};
|
||||
use bytes::{Buf, BufMut, TryGetError};
|
||||
use error_stack::{report, ResultExt};
|
||||
use std::fmt;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, proptest_derive::Arbitrary)]
|
||||
pub struct Question {
|
||||
name: Name,
|
||||
record_type: RecordType,
|
||||
record_class: RecordClass,
|
||||
}
|
||||
|
||||
impl fmt::Display for Question {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "<Q:{}@{}.{}>", self.name, self.record_type, self.record_class)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum QuestionReadError {
|
||||
#[error("Could not read name for question.")]
|
||||
CouldNotReadName,
|
||||
|
||||
#[error("Could not read the record type for the question: {0}")]
|
||||
CouldNotReadType(TryGetError),
|
||||
|
||||
#[error("Could not read the record class for the question.")]
|
||||
CouldNotReadClass(TryGetError),
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum QuestionWriteError {
|
||||
#[error("Could not write name for the question.")]
|
||||
CouldNotWriteName,
|
||||
|
||||
#[error("Buffer not large enough to write question type and class.")]
|
||||
BufferTooSmall,
|
||||
}
|
||||
|
||||
impl Question {
|
||||
pub fn read<B: Buf>(buffer: &mut B) -> error_stack::Result<Self, QuestionReadError> {
|
||||
let name = Name::read(buffer).change_context(QuestionReadError::CouldNotReadName)?;
|
||||
let record_type_u16 = buffer
|
||||
.try_get_u16()
|
||||
.map_err(|e| report!(QuestionReadError::CouldNotReadType(e)))
|
||||
.attach_printable_lazy(|| format!("question was about '{name}'"))?;
|
||||
let record_class_u16 = buffer
|
||||
.try_get_u16()
|
||||
.map_err(|e| report!(QuestionReadError::CouldNotReadClass(e)))
|
||||
.attach_printable_lazy(|| format!("question was about '{name}'"))?;
|
||||
|
||||
let record_type = RecordType::from(record_type_u16);
|
||||
let record_class = RecordClass::from(record_class_u16);
|
||||
|
||||
Ok(Question {
|
||||
name,
|
||||
record_type,
|
||||
record_class,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn write<B: BufMut>(&self, buffer: &mut B) -> error_stack::Result<(), QuestionWriteError> {
|
||||
self.name
|
||||
.write(buffer)
|
||||
.change_context(QuestionWriteError::CouldNotWriteName)
|
||||
.attach_printable_lazy(|| format!("Question was about '{}'", self.name))?;
|
||||
|
||||
if buffer.remaining_mut() < 4 {
|
||||
return Err(report!(QuestionWriteError::BufferTooSmall))
|
||||
.attach_printable_lazy(|| format!("Question was about '{}'", self.name));
|
||||
}
|
||||
|
||||
buffer.put_u16(self.record_type.into());
|
||||
buffer.put_u16(self.record_class.into());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
proptest::proptest! {
|
||||
#[test]
|
||||
fn questions_roundtrip(question: Question) {
|
||||
let mut write_buffer = bytes::BytesMut::with_capacity(128 * 1024);
|
||||
question.write(&mut write_buffer).expect("can write name");
|
||||
let mut read_buffer = write_buffer.freeze();
|
||||
let new_question = Question::read(&mut read_buffer).expect("can read name");
|
||||
assert_eq!(question, new_question);
|
||||
}
|
||||
}
|
||||
156
resolver/src/protocol/request.rs
Normal file
156
resolver/src/protocol/request.rs
Normal file
@@ -0,0 +1,156 @@
|
||||
use bytes::{Buf, BufMut};
|
||||
use error_stack::{report, ResultExt};
|
||||
use crate::protocol::header::{Header, OpCode, ResponseCode};
|
||||
use crate::protocol::question::Question;
|
||||
use crate::protocol::resource_record::ResourceRecord;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, proptest_derive::Arbitrary)]
|
||||
pub struct Request {
|
||||
source_message_id: u16,
|
||||
opcode: OpCode,
|
||||
recursion_desired: bool,
|
||||
questions: Vec<Question>,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum RequestReadError {
|
||||
#[error("Error reading request header.")]
|
||||
Header,
|
||||
#[error("Message isn't a request.")]
|
||||
NotRequest,
|
||||
#[error("Illegal message.")]
|
||||
IllegalMessage,
|
||||
#[error("Error reading request question.")]
|
||||
Question,
|
||||
#[error("Error reading request answers (?!)")]
|
||||
Answer,
|
||||
#[error("Error reading requst name servers.")]
|
||||
NameServers,
|
||||
#[error("Reading request additional records.")]
|
||||
AdditionalRecords,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum RequestWriteError {
|
||||
#[error("Error writing request header.")]
|
||||
Header,
|
||||
#[error("Error writing request question.")]
|
||||
Question,
|
||||
#[error("Request had too many questions.")]
|
||||
TooManyQuestions,
|
||||
}
|
||||
|
||||
impl Request {
|
||||
pub fn read<B: Buf>(buffer: &mut B) -> error_stack::Result<Self, RequestReadError> {
|
||||
let header = Header::read(buffer)
|
||||
.change_context(RequestReadError::Header)?;
|
||||
|
||||
if header.is_response {
|
||||
return Err(report!(RequestReadError::NotRequest))
|
||||
.attach_printable_lazy(|| format!("message id is {}", header.message_id));
|
||||
}
|
||||
|
||||
if header.authoritative_answer {
|
||||
return Err(report!(RequestReadError::IllegalMessage))
|
||||
.attach_printable("Request messages are not allowed to set the 'authoritative answer' bit.")
|
||||
.attach_printable_lazy(|| format!("message id is {}", header.message_id));
|
||||
}
|
||||
|
||||
if header.response_code != ResponseCode::NoErrorConditions {
|
||||
return Err(report!(RequestReadError::IllegalMessage))
|
||||
.attach_printable("Request messages are not allowed to set the response code.")
|
||||
.attach_printable_lazy(|| format!("message id is {}", header.message_id));
|
||||
}
|
||||
|
||||
if header.answer_count != 0 {
|
||||
return Err(report!(RequestReadError::IllegalMessage))
|
||||
.attach_printable("Request messages are not allowed to include answers.")
|
||||
.attach_printable_lazy(|| format!("{} answers declared", header.answer_count))
|
||||
.attach_printable_lazy(|| format!("message id is {}", header.message_id));
|
||||
}
|
||||
|
||||
if header.name_server_count != 0 {
|
||||
return Err(report!(RequestReadError::IllegalMessage))
|
||||
.attach_printable("Request messages are not allowed to include name servers.")
|
||||
.attach_printable_lazy(|| format!("{} name servers declared", header.name_server_count))
|
||||
.attach_printable_lazy(|| format!("message id is {}", header.message_id));
|
||||
}
|
||||
|
||||
if header.additional_record_count != 0 {
|
||||
return Err(report!(RequestReadError::IllegalMessage))
|
||||
.attach_printable("Request messages are not allowed to include additional records.")
|
||||
.attach_printable_lazy(|| format!("{} aditional records declared", header.additional_record_count))
|
||||
.attach_printable_lazy(|| format!("message id is {}", header.message_id));
|
||||
}
|
||||
|
||||
|
||||
let mut questions = vec![];
|
||||
|
||||
for i in 0..header.question_count {
|
||||
let question = Question::read(buffer)
|
||||
.change_context(RequestReadError::Question)
|
||||
.attach_printable_lazy(|| format!("for question #{} of {}", i+1, header.question_count))
|
||||
.attach_printable_lazy(|| format!("message id is {}", header.message_id))?;
|
||||
|
||||
questions.push(question);
|
||||
}
|
||||
|
||||
|
||||
Ok(Request {
|
||||
source_message_id: header.message_id,
|
||||
opcode: header.opcode,
|
||||
recursion_desired: header.recursion_desired,
|
||||
questions,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn write<B: BufMut>(self, buffer: &mut B) -> error_stack::Result<(), RequestWriteError> {
|
||||
let question_count = self.questions.len();
|
||||
|
||||
if question_count > (u16::MAX as usize) {
|
||||
return Err(report!(RequestWriteError::TooManyQuestions))
|
||||
.attach_printable(format!("message_id is {}", self.source_message_id));
|
||||
}
|
||||
|
||||
let header = Header {
|
||||
message_id: self.source_message_id,
|
||||
is_response: false,
|
||||
opcode: self.opcode,
|
||||
authoritative_answer: false,
|
||||
message_truncated: false,
|
||||
recursion_desired: self.recursion_desired,
|
||||
recursion_available: false,
|
||||
response_code: ResponseCode::NoErrorConditions,
|
||||
question_count: question_count as u16,
|
||||
answer_count: 0,
|
||||
name_server_count: 0,
|
||||
additional_record_count: 0,
|
||||
};
|
||||
|
||||
header.write(buffer)
|
||||
.change_context(RequestWriteError::Header)
|
||||
.attach_printable_lazy(|| format!("message ID is {}", self.source_message_id))?;
|
||||
|
||||
for (index, question) in self.questions.into_iter().enumerate() {
|
||||
question.write(buffer)
|
||||
.change_context(RequestWriteError::Question)
|
||||
.attach_printable_lazy(|| format!("message ID is {}", self.source_message_id))
|
||||
.attach_printable_lazy(|| format!("question #{} of {}", index+1, question_count))
|
||||
.attach_printable_lazy(|| format!("{}", question))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
proptest::proptest! {
|
||||
#[test]
|
||||
fn request_roundtrip(request: Request) {
|
||||
let mut write_buffer = bytes::BytesMut::with_capacity(128 * 1024);
|
||||
let safe_request = request.clone();
|
||||
request.write(&mut write_buffer).expect("can write name");
|
||||
let mut read_buffer = write_buffer.freeze();
|
||||
let new_request = Request::read(&mut read_buffer).expect("can read name");
|
||||
assert_eq!(safe_request, new_request);
|
||||
}
|
||||
}
|
||||
1273
resolver/src/protocol/resource_record.rs
Normal file
1273
resolver/src/protocol/resource_record.rs
Normal file
File diff suppressed because it is too large
Load Diff
304
resolver/src/protocol/resource_record/raw.rs
Normal file
304
resolver/src/protocol/resource_record/raw.rs
Normal file
@@ -0,0 +1,304 @@
|
||||
use crate::name::Name;
|
||||
use bytes::{Buf, BufMut, Bytes};
|
||||
use error_stack::{report, ResultExt};
|
||||
use num_enum::{FromPrimitive, IntoPrimitive};
|
||||
use proptest::arbitrary::Arbitrary;
|
||||
use proptest::strategy::{BoxedStrategy, Just, Strategy};
|
||||
use std::fmt;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub struct RawResourceRecord {
|
||||
pub name: Name,
|
||||
pub record_type: RecordType,
|
||||
pub record_class: RecordClass,
|
||||
pub ttl: u32,
|
||||
pub data: Bytes,
|
||||
}
|
||||
|
||||
impl Arbitrary for RawResourceRecord {
|
||||
type Parameters = ();
|
||||
type Strategy = BoxedStrategy<Self>;
|
||||
|
||||
fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
|
||||
(
|
||||
Name::arbitrary(),
|
||||
RecordType::arbitrary(),
|
||||
RecordClass::arbitrary(),
|
||||
u32::arbitrary(),
|
||||
proptest::collection::vec(u8::arbitrary(), 0..65535),
|
||||
)
|
||||
.prop_map(
|
||||
|(name, record_type, record_class, ttl, data)| RawResourceRecord {
|
||||
name,
|
||||
record_type,
|
||||
record_class,
|
||||
ttl,
|
||||
data: data.into(),
|
||||
},
|
||||
)
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(FromPrimitive, IntoPrimitive, PartialEq, PartialOrd, Eq, Ord, Debug, Copy, Clone)]
|
||||
#[repr(u16)]
|
||||
pub enum RecordType {
|
||||
A = 1,
|
||||
AAAA = 28,
|
||||
NS = 2,
|
||||
MD = 3,
|
||||
MF = 4,
|
||||
CNAME = 5,
|
||||
SOA = 6,
|
||||
MB = 7,
|
||||
MG = 8,
|
||||
MR = 9,
|
||||
NULL = 10,
|
||||
WKS = 11,
|
||||
PTR = 12,
|
||||
HINFO = 13,
|
||||
MINFO = 14,
|
||||
MX = 15,
|
||||
TXT = 16,
|
||||
URI = 256,
|
||||
#[num_enum(catch_all)]
|
||||
Other(u16),
|
||||
}
|
||||
|
||||
impl fmt::Display for RecordType {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
RecordType::A => write!(f, "A"),
|
||||
RecordType::AAAA => write!(f, "AAAA"),
|
||||
RecordType::NS => write!(f, "NS"),
|
||||
RecordType::MD => write!(f, "MD"),
|
||||
RecordType::MF => write!(f, "MF"),
|
||||
RecordType::CNAME => write!(f, "CNAME"),
|
||||
RecordType::SOA => write!(f, "SOA"),
|
||||
RecordType::MB => write!(f, "MB"),
|
||||
RecordType::MG => write!(f, "MG"),
|
||||
RecordType::MR => write!(f, "MR"),
|
||||
RecordType::NULL => write!(f, "NULL"),
|
||||
RecordType::WKS => write!(f, "WKS"),
|
||||
RecordType::PTR => write!(f, "PTR"),
|
||||
RecordType::HINFO => write!(f, "HINFO"),
|
||||
RecordType::MINFO => write!(f, "MINFO"),
|
||||
RecordType::MX => write!(f, "MX"),
|
||||
RecordType::TXT => write!(f, "TXT"),
|
||||
RecordType::URI => write!(f, "URI"),
|
||||
RecordType::Other(x) => write!(f, "UNKNOWN<{x}>"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Arbitrary for RecordType {
|
||||
type Parameters = ();
|
||||
type Strategy = BoxedStrategy<Self>;
|
||||
|
||||
fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
|
||||
// this is intentionally biased towards the legit values
|
||||
proptest::prop_oneof![
|
||||
Just(RecordType::A),
|
||||
Just(RecordType::AAAA),
|
||||
Just(RecordType::NS),
|
||||
Just(RecordType::MD),
|
||||
Just(RecordType::MF),
|
||||
Just(RecordType::CNAME),
|
||||
Just(RecordType::SOA),
|
||||
Just(RecordType::MB),
|
||||
Just(RecordType::MG),
|
||||
Just(RecordType::MR),
|
||||
Just(RecordType::NULL),
|
||||
Just(RecordType::WKS),
|
||||
Just(RecordType::PTR),
|
||||
Just(RecordType::HINFO),
|
||||
Just(RecordType::MINFO),
|
||||
Just(RecordType::MX),
|
||||
Just(RecordType::TXT),
|
||||
proptest::prop_oneof![
|
||||
(17u16..28).prop_map(|x| RecordType::Other(x)),
|
||||
(29u16..256).prop_map(|x| RecordType::Other(x)),
|
||||
(257u16..=65535).prop_map(|x| RecordType::Other(x)),
|
||||
Just(RecordType::Other(0)),
|
||||
],
|
||||
]
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(FromPrimitive, IntoPrimitive, PartialEq, PartialOrd, Eq, Ord, Debug, Copy, Clone)]
|
||||
#[repr(u16)]
|
||||
pub enum RecordClass {
|
||||
IN = 1,
|
||||
CS = 2,
|
||||
CH = 3,
|
||||
HS = 4,
|
||||
#[num_enum(catch_all)]
|
||||
Other(u16),
|
||||
}
|
||||
|
||||
impl fmt::Display for RecordClass {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
RecordClass::IN => write!(f, "IN"),
|
||||
RecordClass::CS => write!(f, "CS"),
|
||||
RecordClass::CH => write!(f, "CH"),
|
||||
RecordClass::HS => write!(f, "HS"),
|
||||
RecordClass::Other(x) => write!(f, "UNKNOWN<{x}>"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Arbitrary for RecordClass {
|
||||
type Parameters = ();
|
||||
type Strategy = BoxedStrategy<Self>;
|
||||
|
||||
fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
|
||||
// this is intentionally biased towards the legit values
|
||||
proptest::prop_oneof![
|
||||
Just(RecordClass::IN),
|
||||
Just(RecordClass::CS),
|
||||
Just(RecordClass::CH),
|
||||
Just(RecordClass::HS),
|
||||
(5u16..=65535).prop_map(|x| RecordClass::Other(x)),
|
||||
Just(RecordClass::Other(0)),
|
||||
]
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ResourceRecordReadError {
|
||||
#[error("Failed to read initial record name.")]
|
||||
InitialRecord,
|
||||
|
||||
#[error("Resource record truncated; couldn't find its type field.")]
|
||||
NoTypeField,
|
||||
|
||||
#[error("Resource record truncated; couldn't find its class field.")]
|
||||
NoClassField,
|
||||
|
||||
#[error("Resource record truncated; couldn't find its TTL field.")]
|
||||
NoTtl,
|
||||
|
||||
#[error("Resource record truncated; couldn't find its data length.")]
|
||||
NoDataLength,
|
||||
|
||||
#[error("Resource record truncated; couldn't read its entire data field.")]
|
||||
DataTruncated,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ResourceRecordWriteError {
|
||||
#[error("Could not write name to the output record.")]
|
||||
CouldNotWriteName,
|
||||
|
||||
#[error("Could not write resource record type and class to output record.")]
|
||||
CouldNotWriteTypeClass,
|
||||
|
||||
#[error("Could not write TTL to output record.")]
|
||||
CountNotWriteTtl,
|
||||
|
||||
#[error("Could not write resource record data length to output record.")]
|
||||
CountNotWriteDataLength,
|
||||
|
||||
#[error("Could not write resource record data to output record.")]
|
||||
CountNotWriteData,
|
||||
|
||||
#[error("Input data was too large to write to output stream.")]
|
||||
InputDataTooLarge,
|
||||
}
|
||||
|
||||
impl RawResourceRecord {
|
||||
pub fn read<B: Buf>(buffer: &mut B) -> error_stack::Result<Self, ResourceRecordReadError> {
|
||||
let name = Name::read(buffer).change_context(ResourceRecordReadError::InitialRecord)?;
|
||||
|
||||
let record_type = buffer
|
||||
.try_get_u16()
|
||||
.map_err(|_| report!(ResourceRecordReadError::NoTypeField))?
|
||||
.into();
|
||||
let record_class = buffer
|
||||
.try_get_u16()
|
||||
.map_err(|_| report!(ResourceRecordReadError::NoClassField))?
|
||||
.into();
|
||||
let ttl = buffer
|
||||
.try_get_u32()
|
||||
.map_err(|_| report!(ResourceRecordReadError::NoTtl))?;
|
||||
let rdata_length = buffer
|
||||
.try_get_u16()
|
||||
.map_err(|_| report!(ResourceRecordReadError::NoDataLength))?;
|
||||
|
||||
if buffer.remaining() < (rdata_length as usize) {
|
||||
return Err(report!(ResourceRecordReadError::DataTruncated)).attach_printable_lazy(
|
||||
|| {
|
||||
format!(
|
||||
"Expected {rdata_length} bytes, but only saw {}",
|
||||
buffer.remaining()
|
||||
)
|
||||
},
|
||||
);
|
||||
}
|
||||
let data = buffer.copy_to_bytes(rdata_length as usize);
|
||||
|
||||
Ok(RawResourceRecord {
|
||||
name,
|
||||
record_type,
|
||||
record_class,
|
||||
ttl,
|
||||
data,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn write<B: BufMut>(
|
||||
&self,
|
||||
buffer: &mut B,
|
||||
) -> error_stack::Result<(), ResourceRecordWriteError> {
|
||||
self.name
|
||||
.write(buffer)
|
||||
.change_context(ResourceRecordWriteError::CouldNotWriteName)?;
|
||||
|
||||
if buffer.remaining_mut() < 4 {
|
||||
return Err(report!(ResourceRecordWriteError::CouldNotWriteTypeClass));
|
||||
}
|
||||
buffer.put_u16(self.record_type.into());
|
||||
buffer.put_u16(self.record_class.into());
|
||||
|
||||
if buffer.remaining_mut() < 4 {
|
||||
return Err(report!(ResourceRecordWriteError::CountNotWriteTtl));
|
||||
}
|
||||
buffer.put_u32(self.ttl);
|
||||
|
||||
if buffer.remaining_mut() < 2 {
|
||||
return Err(report!(ResourceRecordWriteError::CountNotWriteDataLength));
|
||||
}
|
||||
if self.data.len() > (u16::MAX as usize) {
|
||||
return Err(report!(ResourceRecordWriteError::InputDataTooLarge))
|
||||
.attach_printable_lazy(|| {
|
||||
format!(
|
||||
"Incoming data was {} bytes, needs to be < 2^16",
|
||||
self.data.len()
|
||||
)
|
||||
});
|
||||
}
|
||||
buffer.put_u16(self.data.len() as u16);
|
||||
|
||||
if buffer.remaining_mut() < self.data.len() {
|
||||
return Err(report!(ResourceRecordWriteError::CountNotWriteData));
|
||||
}
|
||||
buffer.put_slice(&self.data);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
proptest::proptest! {
|
||||
#[test]
|
||||
fn any_random_record_roundtrips(record: RawResourceRecord) {
|
||||
let mut write_buffer = bytes::BytesMut::with_capacity(128 * 1024);
|
||||
record.write(&mut write_buffer).expect("can write name");
|
||||
let mut read_buffer = write_buffer.freeze();
|
||||
let new_record = RawResourceRecord::read(&mut read_buffer).expect("can read name");
|
||||
assert_eq!(record, new_record);
|
||||
}
|
||||
}
|
||||
258
resolver/src/protocol/response.rs
Normal file
258
resolver/src/protocol/response.rs
Normal file
@@ -0,0 +1,258 @@
|
||||
use bytes::{Buf, BufMut};
|
||||
use error_stack::ResultExt;
|
||||
use crate::protocol::header::{Header, ResponseCode};
|
||||
use crate::protocol::question::Question;
|
||||
use crate::protocol::resource_record::ResourceRecord;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, proptest_derive::Arbitrary)]
|
||||
pub enum Response {
|
||||
Valid {
|
||||
source_message_id: u16,
|
||||
authoritative: bool,
|
||||
truncated: bool,
|
||||
answers: Vec<ResourceRecord>,
|
||||
name_servers: Vec<ResourceRecord>,
|
||||
additional_records: Vec<ResourceRecord>,
|
||||
},
|
||||
FormatError {
|
||||
source_message_id: u16,
|
||||
},
|
||||
ServerFailure {
|
||||
source_message_id: u16,
|
||||
},
|
||||
NameError {
|
||||
source_message_id: u16,
|
||||
},
|
||||
NotImplemented {
|
||||
source_message_id: u16,
|
||||
},
|
||||
Refused {
|
||||
source_message_id: u16,
|
||||
},
|
||||
UnknownError {
|
||||
#[proptest(strategy="6u8..=15")]
|
||||
error_code: u8,
|
||||
source_message_id: u16,
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ResponseReadError {
|
||||
#[error("Could not read response header.")]
|
||||
HeaderReadError,
|
||||
#[error("Could not read question included in response.")]
|
||||
QuestionReadError,
|
||||
#[error("Could not read resource record included as an answer.")]
|
||||
AnswerReadError,
|
||||
#[error("Could not read name server record included in response.")]
|
||||
NameServerReadError,
|
||||
#[error("Could not read supplemental record included in response.")]
|
||||
AdditionalInfoReadError,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ResponseWriteError {
|
||||
#[error("Could not write response header.")]
|
||||
Header,
|
||||
#[error("Could not write answer.")]
|
||||
Answer,
|
||||
#[error("Could not write name server.")]
|
||||
NameServer,
|
||||
#[error("Could not write additional record.")]
|
||||
AdditionalRecord,
|
||||
}
|
||||
|
||||
impl Response {
|
||||
pub fn message_id(&self) -> u16 {
|
||||
match self {
|
||||
Response::Valid { source_message_id, .. } => *source_message_id,
|
||||
Response::FormatError { source_message_id } => *source_message_id,
|
||||
Response::ServerFailure { source_message_id } => *source_message_id,
|
||||
Response::NameError { source_message_id } => *source_message_id,
|
||||
Response::NotImplemented { source_message_id } => *source_message_id,
|
||||
Response::Refused { source_message_id } => *source_message_id,
|
||||
Response::UnknownError { source_message_id, .. } => *source_message_id,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn read<B: Buf>(buffer: &mut B) -> error_stack::Result<Response, ResponseReadError> {
|
||||
let header = Header::read(buffer)
|
||||
.change_context(ResponseReadError::HeaderReadError)?;
|
||||
|
||||
// check for errors, and short-cut out if we find any
|
||||
match header.response_code {
|
||||
ResponseCode::NoErrorConditions => {}
|
||||
|
||||
ResponseCode::FormatError => return Ok(Response::FormatError {
|
||||
source_message_id: header.message_id,
|
||||
}),
|
||||
|
||||
ResponseCode::ServerFailure => return Ok(Response::ServerFailure {
|
||||
source_message_id: header.message_id,
|
||||
}),
|
||||
|
||||
ResponseCode::NameError => return Ok(Response::NameError {
|
||||
source_message_id: header.message_id,
|
||||
}),
|
||||
|
||||
ResponseCode::NotImplemented => return Ok(Response::NotImplemented {
|
||||
source_message_id: header.message_id,
|
||||
}),
|
||||
|
||||
ResponseCode::Refused => return Ok(Response::Refused {
|
||||
source_message_id: header.message_id,
|
||||
}),
|
||||
|
||||
ResponseCode::Other(error_code) => return Ok(Response::UnknownError {
|
||||
error_code,
|
||||
source_message_id: header.message_id,
|
||||
}),
|
||||
}
|
||||
|
||||
// it seems weird to get questions in a response, but we need to parse
|
||||
// them out if they exist.
|
||||
for _ in 0..header.question_count {
|
||||
let question = Question::read(buffer)
|
||||
.change_context(ResponseReadError::QuestionReadError)?;
|
||||
|
||||
tracing::warn!(
|
||||
%question,
|
||||
"got question during server response."
|
||||
);
|
||||
}
|
||||
|
||||
let mut answers = vec![];
|
||||
for idx in 0..header.answer_count {
|
||||
let answer = ResourceRecord::read(buffer)
|
||||
.change_context(ResponseReadError::AnswerReadError)
|
||||
.attach_printable_lazy(|| format!("In answer {} of {}", idx + 1, header.answer_count))?;
|
||||
|
||||
answers.push(answer);
|
||||
}
|
||||
|
||||
let mut name_servers = vec![];
|
||||
for idx in 0..header.name_server_count {
|
||||
let name_server = ResourceRecord::read(buffer)
|
||||
.change_context(ResponseReadError::NameServerReadError)
|
||||
.attach_printable_lazy(|| format!("In answer {} of {}", idx + 1, header.name_server_count))?;
|
||||
|
||||
name_servers.push(name_server);
|
||||
}
|
||||
|
||||
let mut additional_records = vec![];
|
||||
for idx in 0..header.additional_record_count {
|
||||
let extra = ResourceRecord::read(buffer)
|
||||
.change_context(ResponseReadError::AnswerReadError)
|
||||
.attach_printable_lazy(|| format!("In answer {} of {}", idx + 1, header.additional_record_count))?;
|
||||
|
||||
additional_records.push(extra);
|
||||
}
|
||||
|
||||
Ok(Response::Valid {
|
||||
source_message_id: header.message_id,
|
||||
authoritative: header.authoritative_answer,
|
||||
truncated: header.message_truncated,
|
||||
answers,
|
||||
name_servers,
|
||||
additional_records,
|
||||
})
|
||||
}
|
||||
|
||||
fn write_error<B: BufMut>(source_message_id: u16, error_code: ResponseCode, buffer: &mut B) -> error_stack::Result<(), ResponseWriteError> {
|
||||
let header = Header {
|
||||
message_id: source_message_id,
|
||||
is_response: true,
|
||||
opcode: super::header::OpCode::StandardQuery,
|
||||
authoritative_answer: false,
|
||||
message_truncated: false,
|
||||
recursion_desired: false,
|
||||
recursion_available: false,
|
||||
response_code: error_code,
|
||||
question_count: 0,
|
||||
answer_count: 0,
|
||||
name_server_count: 0,
|
||||
additional_record_count: 0,
|
||||
};
|
||||
|
||||
header
|
||||
.write(buffer)
|
||||
.change_context(ResponseWriteError::Header)
|
||||
.attach_printable_lazy(|| format!("responding to message {source_message_id} with error code {error_code}"))
|
||||
}
|
||||
|
||||
pub fn write<B: BufMut>(self, buffer: &mut B) -> error_stack::Result<(), ResponseWriteError> {
|
||||
let (source_message_id, authoritative, truncated, answers, name_servers, additional_records) = match self {
|
||||
Response::FormatError { source_message_id } => return Self::write_error(source_message_id, ResponseCode::FormatError, buffer),
|
||||
|
||||
Response::ServerFailure { source_message_id } => return Self::write_error(source_message_id, ResponseCode::ServerFailure, buffer),
|
||||
|
||||
Response::NameError { source_message_id } => return Self::write_error(source_message_id, ResponseCode::NameError, buffer),
|
||||
|
||||
Response::NotImplemented { source_message_id } => return Self::write_error(source_message_id, ResponseCode::NotImplemented, buffer),
|
||||
|
||||
Response::Refused { source_message_id } => return Self::write_error(source_message_id, ResponseCode::Refused, buffer),
|
||||
|
||||
Response::UnknownError { error_code, source_message_id } => return Self::write_error(source_message_id, ResponseCode::Other(error_code), buffer),
|
||||
|
||||
Response::Valid { source_message_id, authoritative, truncated, answers, name_servers, additional_records } => {
|
||||
(source_message_id, authoritative, truncated, answers, name_servers, additional_records)
|
||||
}
|
||||
};
|
||||
|
||||
let header = Header {
|
||||
message_id: source_message_id,
|
||||
is_response: true,
|
||||
opcode: super::header::OpCode::StandardQuery,
|
||||
authoritative_answer: authoritative,
|
||||
message_truncated: truncated,
|
||||
recursion_desired: false,
|
||||
recursion_available: false,
|
||||
response_code: ResponseCode::NoErrorConditions,
|
||||
question_count: 0,
|
||||
answer_count: answers.len() as u16,
|
||||
name_server_count: name_servers.len() as u16,
|
||||
additional_record_count: additional_records.len() as u16,
|
||||
};
|
||||
|
||||
header.write(buffer)
|
||||
.change_context(ResponseWriteError::Header)
|
||||
.attach_printable_lazy(|| format!("Writing clean response to {source_message_id}"))?;
|
||||
|
||||
let answer_count = answers.len();
|
||||
for (item, answer) in answers.into_iter().enumerate() {
|
||||
answer.write(buffer)
|
||||
.change_context(ResponseWriteError::Answer)
|
||||
.attach_printable_lazy(|| format!("Writing clean response to {source_message_id}"))
|
||||
.attach_printable_lazy(|| format!("Writing answer {} of {}", item+1, answer_count))?;
|
||||
}
|
||||
|
||||
let ns_count = name_servers.len();
|
||||
for (item, answer) in name_servers.into_iter().enumerate() {
|
||||
answer.write(buffer)
|
||||
.change_context(ResponseWriteError::NameServer)
|
||||
.attach_printable_lazy(|| format!("Writing clean response to {source_message_id}"))
|
||||
.attach_printable_lazy(|| format!("Writing name server {} of {}", item+1, ns_count))?;
|
||||
}
|
||||
|
||||
let ar_count = additional_records.len();
|
||||
for (item, answer) in additional_records.into_iter().enumerate() {
|
||||
answer.write(buffer)
|
||||
.change_context(ResponseWriteError::AdditionalRecord)
|
||||
.attach_printable_lazy(|| format!("Writing clean response to {source_message_id}"))
|
||||
.attach_printable_lazy(|| format!("Writing additional record {} of {}", item+1, ar_count))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
proptest::proptest! {
|
||||
#[test]
|
||||
fn any_random_response_roundtrips(record: Response) {
|
||||
let mut write_buffer = bytes::BytesMut::with_capacity(128 * 1024);
|
||||
record.clone().write(&mut write_buffer).expect("can write name");
|
||||
let mut read_buffer = write_buffer.freeze();
|
||||
let new_record = Response::read(&mut read_buffer).expect("can read name");
|
||||
assert_eq!(record, new_record);
|
||||
}
|
||||
}
|
||||
1
resolver/src/protocol/server.rs
Normal file
1
resolver/src/protocol/server.rs
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
176
resolver/src/resolution_table.rs
Normal file
176
resolver/src/resolution_table.rs
Normal file
@@ -0,0 +1,176 @@
|
||||
use crate::name::Name;
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::net::IpAddr;
|
||||
use tokio::time::{Duration, Instant};
|
||||
|
||||
pub struct ResolutionTable {
|
||||
inner: HashMap<Name, Vec<Resolution>>,
|
||||
}
|
||||
|
||||
struct Resolution {
|
||||
result: IpAddr,
|
||||
expiration: Instant,
|
||||
}
|
||||
|
||||
impl Default for ResolutionTable {
|
||||
fn default() -> Self {
|
||||
ResolutionTable::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ResolutionTable {
|
||||
/// Generate a new, empty resolution table to use in a new DNS implementation,
|
||||
/// or shared by a bunch of them.
|
||||
pub fn new() -> Self {
|
||||
ResolutionTable {
|
||||
inner: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Clean the table of expired entries.
|
||||
pub fn garbage_collect(&mut self) {
|
||||
let now = Instant::now();
|
||||
|
||||
self.inner.retain(|_, items| {
|
||||
items.retain(|x| x.expiration > now);
|
||||
!items.is_empty()
|
||||
});
|
||||
}
|
||||
|
||||
/// Add a new entry to the resolution table, with a TTL on it.
|
||||
pub fn add_entry(&mut self, name: Name, maps_to: IpAddr, ttl: Duration) {
|
||||
let now = Instant::now();
|
||||
|
||||
let new_entry = Resolution {
|
||||
result: maps_to,
|
||||
expiration: now + ttl,
|
||||
};
|
||||
|
||||
match self.inner.entry(name) {
|
||||
Entry::Vacant(vac) => {
|
||||
vac.insert(vec![new_entry]);
|
||||
}
|
||||
|
||||
Entry::Occupied(mut occ) => {
|
||||
occ.get_mut().push(new_entry);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Look up an entry in the resolution map. This will only return
|
||||
/// unexpired items.
|
||||
pub fn lookup(&mut self, name: &Name) -> HashSet<IpAddr> {
|
||||
let mut result = HashSet::new();
|
||||
let now = Instant::now();
|
||||
|
||||
if let Some(entry) = self.inner.get_mut(name) {
|
||||
entry.retain(|x| {
|
||||
let retain = x.expiration > now;
|
||||
|
||||
if retain {
|
||||
result.insert(x.result);
|
||||
}
|
||||
|
||||
retain
|
||||
});
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||
#[cfg(test)]
|
||||
use std::str::FromStr;
|
||||
|
||||
#[test]
|
||||
fn empty_set_gets_fail() {
|
||||
let mut empty = ResolutionTable::default();
|
||||
assert!(empty.lookup(&Name::from_str("foo").unwrap()).is_empty());
|
||||
assert!(empty.lookup(&Name::from_str("bar").unwrap()).is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn basic_lookups() {
|
||||
let mut table = ResolutionTable::new();
|
||||
let foo = Name::from_str("foo").unwrap();
|
||||
let bar = Name::from_str("bar").unwrap();
|
||||
let baz = Name::from_str("baz").unwrap();
|
||||
let localhost = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
|
||||
let other = IpAddr::V6(Ipv6Addr::from_str("2001:18e8:2:e::11e").unwrap());
|
||||
let long_time = Duration::from_secs(10000000000);
|
||||
|
||||
table.add_entry(foo.clone(), localhost, long_time);
|
||||
table.add_entry(bar.clone(), localhost, long_time);
|
||||
table.add_entry(bar.clone(), other, long_time);
|
||||
|
||||
assert_eq!(1, table.lookup(&foo).len());
|
||||
assert_eq!(2, table.lookup(&bar).len());
|
||||
assert!(table.lookup(&baz).is_empty());
|
||||
assert!(table.lookup(&foo).contains(&localhost));
|
||||
assert!(!table.lookup(&foo).contains(&other));
|
||||
assert!(table.lookup(&bar).contains(&localhost));
|
||||
assert!(table.lookup(&bar).contains(&other));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lookup_cleans_up() {
|
||||
let mut table = ResolutionTable::new();
|
||||
let foo = Name::from_str("foo").unwrap();
|
||||
let localhost = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
|
||||
let other = IpAddr::V6(Ipv6Addr::from_str("2001:18e8:2:e::11e").unwrap());
|
||||
let short_time = Duration::from_millis(100);
|
||||
let long_time = Duration::from_secs(10000000000);
|
||||
|
||||
table.add_entry(foo.clone(), localhost, long_time);
|
||||
table.add_entry(foo.clone(), other, short_time);
|
||||
let wait_until = Instant::now() + (2 * short_time);
|
||||
while Instant::now() < wait_until {
|
||||
std::thread::sleep(short_time);
|
||||
}
|
||||
|
||||
assert_eq!(1, table.lookup(&foo).len());
|
||||
assert!(table.lookup(&foo).contains(&localhost));
|
||||
assert!(!table.lookup(&foo).contains(&other));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn garbage_collection_works() {
|
||||
let mut table = ResolutionTable::new();
|
||||
let foo = Name::from_str("foo").unwrap();
|
||||
let localhost = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
|
||||
let other = IpAddr::V6(Ipv6Addr::from_str("2001:18e8:2:e::11e").unwrap());
|
||||
let short_time = Duration::from_millis(100);
|
||||
let long_time = Duration::from_secs(10000000000);
|
||||
|
||||
table.add_entry(foo.clone(), localhost, long_time);
|
||||
table.add_entry(foo.clone(), other, short_time);
|
||||
let wait_until = Instant::now() + (2 * short_time);
|
||||
while Instant::now() < wait_until {
|
||||
std::thread::sleep(short_time);
|
||||
}
|
||||
table.garbage_collect();
|
||||
|
||||
assert_eq!(1, table.inner.get(&foo).unwrap().len());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn garbage_collection_clears_empties() {
|
||||
let mut table = ResolutionTable::new();
|
||||
let foo = Name::from_str("foo").unwrap();
|
||||
table.inner.insert(foo.clone(), vec![]);
|
||||
table.garbage_collect();
|
||||
assert!(table.inner.is_empty());
|
||||
|
||||
let localhost = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
|
||||
let short_time = Duration::from_millis(100);
|
||||
table.add_entry(foo.clone(), localhost, short_time);
|
||||
let wait_until = Instant::now() + (2 * short_time);
|
||||
while Instant::now() < wait_until {
|
||||
std::thread::sleep(short_time);
|
||||
}
|
||||
table.garbage_collect();
|
||||
assert!(table.inner.is_empty());
|
||||
}
|
||||
Reference in New Issue
Block a user