Workspacify

This commit is contained in:
2025-05-03 17:30:01 -07:00
parent 9fe5b78962
commit d036997de3
60 changed files with 450 additions and 212 deletions

20
resolver/Cargo.toml Normal file
View File

@@ -0,0 +1,20 @@
[package]
name = "resolver"
edition = "2024"
[dependencies]
bytes = { workspace = true }
configuration = { workspace = true }
crypto = { workspace = true }
error-stack = { workspace = true }
futures = { workspace = true }
internment = { workspace = true }
num_enum = { workspace = true }
proptest = { workspace = true }
proptest-derive = { workspace = true }
serde = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true }
toml = { workspace = true }
tracing = { workspace = true }
url = { workspace = true }

247
resolver/src/lib.rs Normal file
View File

@@ -0,0 +1,247 @@
pub mod name;
mod protocol;
mod resolution_table;
use configuration::resolver::{DnsConfig, NameServerConfig};
use crate::name::Name;
use crate::resolution_table::ResolutionTable;
use error_stack::{report, ResultExt};
use std::collections::{HashMap, HashSet};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::str::FromStr;
use std::sync::Arc;
use thiserror::Error;
use tokio::net::{TcpSocket, UdpSocket};
use tokio::sync::Mutex;
use tokio::task::JoinSet;
use tokio::time::{Duration, Instant};
#[derive(Debug, Error)]
pub enum ResolverConfigError {
#[error("Bad local domain name provided")]
BadDomainName,
#[error("Couldn't create a DNS client for the given address, port, and protocol.")]
FailedToCreateDnsClient,
#[error("No DNS servers found to search, and mDNS not enabled")]
NoHosts,
}
#[derive(Debug, Error)]
pub enum ResolveError {
#[error("No servers available for query")]
NoServersAvailable,
#[error("No responses found for query")]
NoResponses,
#[error("Error reading response from server")]
ResponseError,
}
pub struct Resolver {
search_domains: Vec<Name>,
max_time_to_wait_for_initial: Duration,
time_to_wait_after_first: Duration,
time_to_wait_for_lingering: Duration,
connections: Arc<Mutex<Vec<(NameServerConfig, protocol::Client)>>>,
table: Arc<Mutex<ResolutionTable>>,
tasks: JoinSet<()>,
}
pub struct ResolverState {
client_connections: Vec<(NameServerConfig, protocol::Client)>,
cache: HashMap<Name, Vec<DnsResolution>>,
}
pub struct DnsResolution {
address: IpAddr,
expires_at: Instant,
}
impl Resolver {
/// Create a new DNS resolution engine for use by some part of the system.
pub async fn new(config: &DnsConfig) -> error_stack::Result<Self, ResolverConfigError> {
let mut search_domains = Vec::new();
let mut tasks = JoinSet::new();
if let Some(local) = config.local_domain.as_ref() {
let name = Name::from_str(local)
.change_context(ResolverConfigError::BadDomainName)
.attach_printable("Trying to add local domain.")
.attach_printable_lazy(|| "Offending non-name: '{local}'")?;
search_domains.push(name);
}
for search_domain in config.search_domains.iter() {
let name = Name::from_str(search_domain)
.change_context(ResolverConfigError::BadDomainName)
.attach_printable("Trying to add search domain.")
.attach_printable_lazy(|| "Offending non-name: '{search_domain}'")?;
search_domains.push(name);
}
let mut client_connections = Vec::new();
for target in config.name_servers.iter() {
let port = target.address.port().unwrap_or(53);
let Some(address) = target.address.host() else {
return Err(report!(ResolverConfigError::FailedToCreateDnsClient))
.attach_printable("No address to connect to?")
.attach_printable_lazy(|| format!("Target address was {}", target.address));
};
let address = match address {
url::Host::Ipv4(addr) => IpAddr::V4(addr),
url::Host::Ipv6(addr) => IpAddr::V6(addr),
url::Host::Domain(name) => {
return Err(report!(ResolverConfigError::FailedToCreateDnsClient))
.attach_printable("Cannot use domain names to identify domain servers")
.attach_printable_lazy(|| format!("Target address was {name}"));
}
};
let target_addr = SocketAddr::new(address, port);
match target.address.scheme() {
"tcp" => {
let socket = if target_addr.is_ipv4() {
TcpSocket::new_v4()
} else {
TcpSocket::new_v6()
};
let socket = socket
.change_context(ResolverConfigError::FailedToCreateDnsClient)
.attach_printable("Could not create a socket")
.attach_printable_lazy(|| {
format!("For target DNS server {}", target.address)
})?;
if let Some(bind_address) = target.bind_address {
socket
.bind(bind_address)
.change_context(ResolverConfigError::FailedToCreateDnsClient)
.attach_printable("Could not bind local address for socket.")
.attach_printable_lazy(|| {
format!("Binding to TCP address {}", bind_address)
})
.attach_printable_lazy(|| {
format!("For target DNS server {}", target.address)
})?;
}
let stream = socket
.connect(target_addr)
.await
.change_context(ResolverConfigError::FailedToCreateDnsClient)
.attach_printable_lazy(|| {
format!("Connecting to target {}", target_addr)
})?;
let client = protocol::Client::from_tcp(stream, &mut tasks)
.await
.change_context(ResolverConfigError::FailedToCreateDnsClient)
.attach_printable_lazy(|| {
format!("Connecting to target {}", target_addr)
})?;
client_connections.push((target.clone(), client));
}
"udp" => {
let port = target.address.port().unwrap_or(53);
let Some(address) = target.address.host() else {
tracing::warn!(address = %target.address, "proposed domain server has no host");
continue;
};
let address = match address {
url::Host::Ipv4(addr) => IpAddr::V4(addr),
url::Host::Ipv6(addr) => IpAddr::V6(addr),
url::Host::Domain(name) => {
tracing::warn!(
address = %target.address,
hostname = name,
"currently, we can't use hostnames for domain servers"
);
continue;
}
};
let sock_addr = SocketAddr::new(address, port);
let bind_address = target.bind_address.unwrap_or_else(|| {
if sock_addr.is_ipv4() {
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0)
} else {
SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0)
}
});
let udp_socket = UdpSocket::bind(bind_address)
.await
.change_context(ResolverConfigError::FailedToCreateDnsClient)
.attach_printable("Generating UDP socket")
.attach_printable_lazy(|| format!("binding to address {bind_address}"))
.attach_printable_lazy(|| format!("targeting {}", target.address))?;
udp_socket
.connect(target_addr)
.await
.change_context(ResolverConfigError::FailedToCreateDnsClient)
.attach_printable("Connecting UDP socket")
.attach_printable_lazy(|| format!("binding to address {bind_address}"))
.attach_printable_lazy(|| format!("targeting {}", target.address))?;
let client = protocol::Client::from_udp(udp_socket, &mut tasks).await;
client_connections.push((target.clone(), client));
}
"unix" => unimplemented!(),
"unixd" => unimplemented!(),
"http" => unimplemented!(),
"https" => unimplemented!(),
_ => {
tracing::warn!(address = %target.address, "Unknown scheme for building DNS connections");
continue;
}
}
}
Ok(Resolver {
search_domains,
// FIXME: All of these should be configurable
max_time_to_wait_for_initial: Duration::from_millis(150),
time_to_wait_after_first: Duration::from_millis(50),
time_to_wait_for_lingering: Duration::from_secs(2),
connections: Arc::new(Mutex::new(client_connections)),
table: Arc::new(Mutex::new(ResolutionTable::new())),
tasks,
})
}
pub async fn lookup(&self, _name: &Name) -> error_stack::Result<HashSet<IpAddr>, ResolveError> {
unimplemented!()
}
}
//#[tokio::test]
//async fn fetch_cached() {
// let resolver = Resolver::new(&DnsConfig::empty()).await.unwrap();
// let name = Name::from_utf8("name.foo").unwrap();
// let addr = IpAddr::from_str("1.2.4.5").unwrap();
//
// resolver
// .inject_resolution(name.clone(), addr.clone(), Duration::from_secs(100000))
// .await;
// let read = resolver.lookup(&name).await.unwrap();
// assert!(read.contains(&addr));
//}
//#[tokio::test]
//async fn uhsure() {
// let resolver = Resolver::new(&DnsConfig::default()).await.unwrap();
// let name = Name::from_ascii("uhsure.com").unwrap();
// let result = resolver.lookup(&name).await.unwrap();
// println!("result = {:?}", result);
// assert!(!result.is_empty());
//}

505
resolver/src/name.rs Normal file
View File

@@ -0,0 +1,505 @@
use bytes::{Buf, BufMut};
use error_stack::{report, ResultExt};
use internment::ArcIntern;
use proptest::arbitrary::Arbitrary;
use proptest::char::CharStrategy;
use proptest::strategy::{BoxedStrategy, Strategy};
use std::borrow::Cow;
use std::fmt;
use std::hash::Hash;
use std::ops::{Range, RangeInclusive};
use std::str::FromStr;
use thiserror::Error;
#[derive(Clone, Hash, PartialEq, Eq)]
pub struct Name {
labels: Vec<Label>,
}
impl fmt::Debug for Name {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut first_section = true;
for &Label(ref label) in self.labels.iter() {
if first_section {
first_section = false;
write!(f, "{}", label.as_str())?;
} else {
write!(f, ".{}", label.as_str())?;
}
}
Ok(())
}
}
impl fmt::Display for Name {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
<Self as fmt::Debug>::fmt(self, f)
}
}
#[derive(Clone, Hash, Eq)]
pub struct Label(ArcIntern<String>);
impl fmt::Debug for Label {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.as_str().fmt(f)
}
}
impl fmt::Display for Label {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.as_str().fmt(f)
}
}
impl PartialEq for Label {
fn eq(&self, other: &Self) -> bool {
self.0.eq_ignore_ascii_case(other.0.as_str())
}
}
#[derive(Debug, Error)]
pub enum NameParseError {
#[error("Provided name '{name}' is too long ({observed_length} bytes); maximum length of a DNS name is 255 octets.")]
NameTooLong {
name: String,
observed_length: usize,
},
#[error("Provided name '{name}' contains illegal character '{illegal_character}'.")]
NonAsciiCharacter {
name: String,
illegal_character: char,
},
#[error("Provided name '{name}' contains an empty section/label, which isn't allowed.")]
EmptyLabel { name: String },
#[error("Provided name '{name}' contains a label ('{label}') that is too long ({observed_length} letters, which is more than 63).")]
LabelTooLong {
observed_length: usize,
name: String,
label: String,
},
#[error("Provided name '{name}' contains a label ('{label}') that begins with an illegal character; it must be a letter.")]
LabelStartsWrong { name: String, label: String },
#[error("Provided name '{name}' contains a label ('{label}') that ends with an illegal character; it must be a letter or number.")]
LabelEndsWrong { name: String, label: String },
#[error("Provided name '{name}' contains a label ('{label}') that contains a non-letter, non-number, and non-dash.")]
IllegalInternalCharacter { name: String, label: String },
}
impl FromStr for Name {
type Err = NameParseError;
fn from_str(value: &str) -> Result<Self, Self::Err> {
let observed_length = value.as_bytes().len();
if observed_length > 255 {
return Err(NameParseError::NameTooLong {
name: value.to_string(),
observed_length,
});
}
for char in value.chars() {
if !(char.is_ascii_alphanumeric() || char == '.' || char == '-') {
return Err(NameParseError::NonAsciiCharacter {
name: value.to_string(),
illegal_character: char,
});
}
}
let mut labels = vec![];
for label_str in value.split('.') {
if label_str.is_empty() {
return Err(NameParseError::EmptyLabel {
name: value.to_string(),
});
}
if label_str.len() > 63 {
return Err(NameParseError::LabelTooLong {
name: value.to_string(),
label: label_str.to_string(),
observed_length: label_str.len(),
});
}
let letter = |x| ('a'..='z').contains(&x) || ('A'..='Z').contains(&x);
let letter_or_num = |x| letter(x) || ('0'..='9').contains(&x);
let letter_num_dash = |x| letter_or_num(x) || (x == '-');
if !label_str.starts_with(letter) {
return Err(NameParseError::LabelStartsWrong {
name: value.to_string(),
label: label_str.to_string(),
});
}
if !label_str.ends_with(letter_or_num) {
return Err(NameParseError::LabelEndsWrong {
name: value.to_string(),
label: label_str.to_string(),
});
}
if label_str.contains(|x| !letter_num_dash(x)) {
return Err(NameParseError::IllegalInternalCharacter {
name: value.to_string(),
label: label_str.to_string(),
});
}
// RFC 1035 says that all domain names are case-insensitive. We
// arbitrarily normalize to lowercase here, because it shouldn't
// matter to anyone.
labels.push(Label(ArcIntern::new(label_str.to_string())));
}
Ok(Name { labels })
}
}
#[derive(Debug, Error)]
pub enum NameReadError {
#[error("Could not read name field out of an empty buffer.")]
EmptyBuffer,
#[error("Buffer truncated before we could read the last label")]
TruncatedBuffer,
#[error("Provided label value too long; must be 63 octets or less.")]
LabelTooLong,
#[error("Label truncated while reading. Broken stream?")]
LabelTruncated,
#[error("Label starts with an illegal character (must be [A-Za-z])")]
WrongFirstByte,
#[error("Label ends with an illegal character (must be [A-Za-z0-9]")]
WrongLastByte,
#[error("Label contains an illegal character (must be [A-Za-z0-9] or a dash)")]
WrongInnerByte,
}
impl Name {
/// Read a name frm a record, or return an error describing what went wrong.
///
/// This function will advance the read pointer on the record. If the result is
/// an error, the read pointer is not guaranteed to be in a good place, so you
/// may need to manage that externally if you want to implement some sort of
/// "try" functionality.
pub fn read<B: Buf>(buffer: &mut B) -> error_stack::Result<Name, NameReadError> {
let mut labels = Vec::new();
let name_so_far = |labels: Vec<Label>| {
let mut result = String::new();
for Label(label) in labels.into_iter() {
if result.is_empty() {
result.push_str(label.as_str());
} else {
result.push('.');
result.push_str(label.as_str());
}
}
result
};
loop {
if !buffer.has_remaining() && labels.is_empty() {
return Err(report!(NameReadError::EmptyBuffer));
}
if !buffer.has_remaining() {
return Err(report!(NameReadError::TruncatedBuffer))
.attach_printable_lazy(|| format!("name thus far: '{}'", name_so_far(labels)));
}
let label_octet_length = buffer.get_u8() as usize;
if label_octet_length == 0 {
break;
}
if label_octet_length > 63 {
return Err(report!(NameReadError::LabelTooLong)).attach_printable_lazy(|| {
format!(
"label length too big; max is supposed to be 63, saw {label_octet_length}"
)
});
}
if !buffer.remaining() < label_octet_length {
let remaining = buffer.copy_to_bytes(buffer.remaining());
let partial = String::from_utf8_lossy(&remaining).to_string();
return Err(report!(NameReadError::LabelTruncated))
.attach_printable_lazy(|| format!("name thus far: '{}'", name_so_far(labels)))
.attach_printable_lazy(|| {
format!(
"expected {} octets, but only found {}",
label_octet_length,
buffer.remaining()
)
})
.attach_printable_lazy(|| format!("partial read: '{partial}'"));
}
let label_bytes = buffer.copy_to_bytes(label_octet_length);
let Some(first_byte) = label_bytes.first() else {
panic!(
"INTERNAL ERROR: Should have at least one byte, we checked this previously."
);
};
let Some(last_byte) = label_bytes.last() else {
panic!(
"INTERNAL ERROR: Should have at least one byte, we checked this previously."
);
};
let letter = |x| (b'a'..=b'z').contains(&x) || (b'A'..=b'Z').contains(&x);
let letter_or_num = |x| letter(x) || (b'0'..=b'9').contains(&x);
let letter_num_dash = |x| letter_or_num(x) || (x == b'-');
if !letter(*first_byte) {
return Err(report!(NameReadError::WrongFirstByte))
.attach_printable_lazy(|| format!("name thus far: '{}'", name_so_far(labels)))
.attach_printable_lazy(|| {
format!("bad label: '{}'", String::from_utf8_lossy(&label_bytes))
});
}
if !letter_or_num(*last_byte) {
return Err(report!(NameReadError::WrongLastByte))
.attach_printable_lazy(|| format!("name thus far: '{}'", name_so_far(labels)))
.attach_printable_lazy(|| {
format!("bad label: '{}'", String::from_utf8_lossy(&label_bytes))
});
}
if label_bytes.iter().any(|x| !letter_num_dash(*x)) {
return Err(report!(NameReadError::WrongInnerByte))
.attach_printable_lazy(|| format!("name thus far: '{}'", name_so_far(labels)))
.attach_printable_lazy(|| {
format!("bad label: '{}'", String::from_utf8_lossy(&label_bytes))
});
}
let label = label_bytes.into_iter().map(|x| x as char).collect();
labels.push(Label(ArcIntern::new(label)));
}
Ok(Name { labels })
}
/// Write a name out to the given buffer.
///
/// This will try as hard as it can to write the value to the given buffer. If an
/// error occurs, you may end up with partially-written data, so if you're worried
/// about that you should be careful to mark where you started in the output buffer.
pub fn write<B: BufMut>(&self, buffer: &mut B) -> error_stack::Result<(), NameWriteError> {
for &Label(ref label) in self.labels.iter() {
let bytes = label.as_bytes();
if buffer.remaining_mut() < (bytes.len() + 1) {
return Err(report!(NameWriteError::NoRoomForLabel))
.attach_printable_lazy(|| format!("Writing name {self}"))
.attach_printable_lazy(|| format!("For label {label}"));
}
if bytes.is_empty() || bytes.len() > 63 {
return Err(report!(NameWriteError::IllegalLabel))
.attach_printable_lazy(|| format!("Writing name {self}"))
.attach_printable_lazy(|| format!("For label {label}"));
}
buffer.put_u8(bytes.len() as u8);
buffer.put_slice(bytes);
}
if !buffer.has_remaining_mut() {
return Err(report!(NameWriteError::NoRoomForNull))
.attach_printable_lazy(|| format!("Writing name {self}"));
}
buffer.put_u8(0);
Ok(())
}
}
#[derive(Debug, Error)]
pub enum NameWriteError {
#[error("Not enough room to write label in name")]
NoRoomForLabel,
#[error("Ran out of room writing the terminating NULL for the name")]
NoRoomForNull,
#[error("Internal error: Illegal label (this shouldn't happen)")]
IllegalLabel,
}
#[derive(Debug)]
pub struct ArbitraryDomainNameSpecifications {
total_length_range: Range<usize>,
label_length_range: Range<usize>,
number_labels: Range<usize>,
}
impl Default for ArbitraryDomainNameSpecifications {
fn default() -> Self {
ArbitraryDomainNameSpecifications {
total_length_range: 5..256,
label_length_range: 1..64,
number_labels: 2..5,
}
}
}
impl Arbitrary for Name {
type Parameters = ArbitraryDomainNameSpecifications;
type Strategy = BoxedStrategy<Name>;
fn arbitrary_with(spec: Self::Parameters) -> Self::Strategy {
(spec.number_labels.clone(), spec.total_length_range.clone())
.prop_flat_map(move |(mut labels, mut total_length)| {
// we need to make sure that our total length and our label count are,
// at the very minimum, compatible. If they're not, we'll need to adjust
// them.
//
// in general, we prefer to update our number of labels, if we can, but
// only to the extent that the labels count is at least the minimum of
// our input specification. if we try to go below that, we increase total
// length as required. If we're forced into a place where we need to take
// the label length below it's minimum and/or the total_length over its
// maximum, we just give up and panic.
//
// Note that the minimum length of n labels is (n * label_minimum) + n - 1.
// Consider, for example, a label_minimum of 1 and a label length of 3.
// A minimum string is "a.b.c", which is (3 * 1) + 3 - 1 = 5 characters
// long.
//
// This loop does the first part, lowering the number of labels until we
// either get below the total length or reach the minimum number of labels.
while (labels * spec.label_length_range.start) + (labels - 1) > total_length {
if labels == spec.number_labels.start {
break;
} else {
labels -= 1;
}
}
// At this point, if it's not right, we just set it to be right.
if (labels * spec.total_length_range.start) + (labels - 1) > total_length {
total_length = (labels * spec.total_length_range.start) + (labels - 1);
}
// And if this takes us over our limit, just panic.
if total_length >= spec.total_length_range.end {
panic!("Unresolvable generation condition; couldn't resolve label count {} with total_length {}, with specification {:?}", labels, total_length, spec);
}
proptest::collection::vec(Label::arbitrary_with(spec.label_length_range.clone()), labels)
.prop_map(|labels| Name{ labels })
}).boxed()
}
}
impl Arbitrary for Label {
type Parameters = Range<usize>;
type Strategy = BoxedStrategy<Self>;
fn arbitrary() -> Self::Strategy {
Self::arbitrary_with(1..64)
}
fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
args.prop_flat_map(|length| {
let first = char_selector(&['a'..='z', 'A'..='Z']);
let middle = char_selector(&['a'..='z', 'A'..='Z', '-'..='-', '0'..='9']);
let last = char_selector(&['a'..='z', 'A'..='Z', '0'..='9']);
match length {
0 => panic!("Should not be able to generate a label of length 0"),
1 => first.prop_map(|x| x.into()).boxed(),
2 => (first, last).prop_map(|(a, b)| format!("{a}{b}")).boxed(),
_ => (first, proptest::collection::vec(middle, length - 2), last)
.prop_map(move |(first, middle, last)| {
let mut result = String::with_capacity(length);
result.push(first);
for c in middle.iter() {
result.push(*c);
}
result.push(last);
result
})
.boxed(),
}
})
.prop_map(|x| Label(ArcIntern::new(x)))
.boxed()
}
}
fn char_selector<'a>(ranges: &'a [RangeInclusive<char>]) -> CharStrategy<'a> {
CharStrategy::new(
Cow::Borrowed(&[]),
Cow::Borrowed(&[]),
Cow::Borrowed(ranges),
)
}
proptest::proptest! {
#[test]
fn any_random_names_parses(name: Name) {
let name_str = name.to_string();
let new_name = Name::from_str(&name_str).expect("can re-parse name");
assert_eq!(name, new_name);
}
#[test]
fn any_random_name_roundtrips(name: Name) {
let mut write_buffer = bytes::BytesMut::with_capacity(512);
name.write(&mut write_buffer).expect("can write name");
let mut read_buffer = write_buffer.freeze();
let new_name = Name::read(&mut read_buffer).expect("can read name");
assert_eq!(name, new_name);
}
}
#[test]
fn illegal_names_generate_errors() {
assert!(Name::from_str("").is_err());
assert!(Name::from_str(".").is_err());
assert!(Name::from_str(".com").is_err());
assert!(Name::from_str("com.").is_err());
assert!(Name::from_str("9.com").is_err());
assert!(Name::from_str("foo-.com").is_err());
assert!(Name::from_str("-foo.com").is_err());
assert!(Name::from_str(".foo.com").is_err());
assert!(Name::from_str("fo*o.com").is_err());
assert!(Name::from_str(
"foo.abcdefghiabcdefghiabcdefghiabcdefghiabcdefghiabcdefghijjjjjjabcdefghij.com"
)
.is_err());
assert!(Name::from_str("abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij").is_err());
}
#[test]
fn names_ignore_case() {
assert_eq!(
Name::from_str("UHSURE.COM").unwrap(),
Name::from_str("uhsure.com").unwrap()
);
}

9
resolver/src/protocol.rs Normal file
View File

@@ -0,0 +1,9 @@
mod client;
mod header;
pub mod question;
mod request;
mod resource_record;
mod response;
mod server;
pub use client::Client;

View File

@@ -0,0 +1,273 @@
use crate::protocol::header::Header;
use crate::protocol::question::Question;
use bytes::{Bytes, BytesMut};
use error_stack::ResultExt;
use std::net::SocketAddr;
use std::sync::Arc;
use thiserror::Error;
use tokio::io::{AsyncReadExt, WriteHalf};
use tokio::net::{TcpStream, UdpSocket, UnixDatagram, UnixStream};
use tokio::sync::RwLock;
use tokio::task::JoinSet;
type Callback = fn() -> ();
pub struct Client {
callback: Arc<RwLock<Callback>>,
channel: DnsChannel,
}
pub enum DnsChannel {
Tcp(WriteHalf<TcpStream>),
Udp(Arc<UdpSocket>),
UnixData(Arc<UnixDatagram>),
Unix(WriteHalf<UnixStream>),
}
fn empty_callback() {}
#[derive(Debug, thiserror::Error)]
pub enum SendError {}
#[derive(Clone, Debug)]
enum GeneralAddr {
Network(SocketAddr),
Unix(std::os::unix::net::SocketAddr),
}
impl From<SocketAddr> for GeneralAddr {
fn from(value: SocketAddr) -> Self {
GeneralAddr::Network(value)
}
}
impl From<tokio::net::unix::SocketAddr> for GeneralAddr {
fn from(value: tokio::net::unix::SocketAddr) -> Self {
GeneralAddr::Unix(value.into())
}
}
#[derive(Debug, Error)]
enum ServerProcessorError {
#[error("Could not read message header from server.")]
CouldNotReadHeader,
}
async fn process_server_response(
callback: &Arc<RwLock<Callback>>,
mut bytes: Bytes,
source: GeneralAddr,
) -> error_stack::Result<(), ServerProcessorError> {
unimplemented!()
}
async fn run_response_processing_loop(
server: String,
maximum_consecurive_errors: u64,
callback: Arc<RwLock<Callback>>,
mut fetcher: impl AsyncFnMut() -> Result<(Bytes, GeneralAddr), std::io::Error>,
) {
let mut consecutive_errors = 0;
loop {
match fetcher().await {
Ok((bytes, source)) => {
if let Err(e) = process_server_response(&callback, bytes, source).await {
consecutive_errors += 1;
tracing::warn!(
server,
error = %e,
maximum_consecurive_errors,
consecutive_errors,
"error processing DNS server response"
);
} else {
consecutive_errors = 0;
}
}
Err(e) => {
consecutive_errors += 1;
tracing::warn!(
server,
error = %e,
maximum_consecurive_errors,
consecutive_errors,
"failed to read response from DNS server"
);
if consecutive_errors >= maximum_consecurive_errors {
break;
}
}
};
}
tracing::error!(
server,
"quitting DNS response processing loop due to too many consecutive errors"
);
}
impl Client {
/// Create a new DNS client from the given, targeted UnixDatagram socket.
///
/// By "targeted", we mean that this socket should've had `connect` called on
/// it, so that the client can just send datagrams to the server without having
/// to know where to send them. It is unspecified what will happen to DNS clients
/// if you do not call `connect` beforehand.
pub async fn from_unix_datagram(socket: UnixDatagram, group: &mut JoinSet<()>) -> Self {
let callback = Arc::new(RwLock::new(empty_callback as Callback));
let socket = Arc::new(socket);
let reader_socket = socket.clone();
let reader_callback = callback.clone();
let server_addr = socket.local_addr()
.ok()
.map(|x| x.as_pathname().map(|x| x.display().to_string()))
.flatten()
.unwrap_or_else(|| "<unknown>".into());
let server = format!("unixd://{}", server_addr);
group.spawn(async move {
let fetcher = async move || {
let mut buffer = BytesMut::with_capacity(16384);
match reader_socket.recv_buf_from(&mut buffer).await {
Err(x) => Err(x),
Ok((size, from)) => {
unsafe {
buffer.set_len(size);
};
Ok((buffer.freeze(), from.into()))
}
}
};
run_response_processing_loop(server, 5, reader_callback, fetcher).await;
});
Client {
callback,
channel: DnsChannel::UnixData(socket),
}
}
pub async fn from_unix_stream(
socket: UnixStream,
group: &mut JoinSet<()>,
) -> std::io::Result<Self> {
let callback = Arc::new(RwLock::new(empty_callback as Callback));
let server_addr = socket.local_addr()
.ok()
.map(|x| x.as_pathname().map(|x| x.display().to_string()))
.flatten()
.unwrap_or_else(|| "<unknown>".into());
let server = format!("unix://{}", server_addr);
let other_side = GeneralAddr::Unix(socket.peer_addr()?.into());
let (mut reader, writer) = tokio::io::split(socket);
let reader_callback = callback.clone();
group.spawn(async move {
let fetcher = async move || {
let size = reader.read_u16().await?;
let mut buffer = vec![0u8; size as usize];
reader.read_exact(&mut buffer).await?;
Ok((Bytes::from(buffer), other_side.clone()))
};
run_response_processing_loop(server, 5, reader_callback, fetcher).await;
});
Ok(Client {
callback,
channel: DnsChannel::Unix(writer),
})
}
pub async fn from_udp(socket: UdpSocket, group: &mut JoinSet<()>) -> Self {
let callback = Arc::new(RwLock::new(empty_callback as Callback));
let socket = Arc::new(socket);
let reader_callback = callback.clone();
let reader_socket = socket.clone();
let server_addr = socket.local_addr()
.ok()
.map(|x| x.to_string())
.unwrap_or_else(|| "<unknown>".into());
let server = format!("udp://{}", server_addr);
group.spawn(async move {
let fetcher = async move || {
let mut buffer = BytesMut::with_capacity(16384);
match reader_socket.recv_buf_from(&mut buffer).await {
Err(x) => Err(x),
Ok((size, from)) => {
unsafe {
buffer.set_len(size);
};
Ok((buffer.freeze(), from.into()))
}
}
};
run_response_processing_loop(server, 5, reader_callback, fetcher).await;
});
Client {
callback,
channel: DnsChannel::Udp(socket),
}
}
pub async fn from_tcp(socket: TcpStream, group: &mut JoinSet<()>) -> std::io::Result<Self> {
let callback = Arc::new(RwLock::new(empty_callback as Callback));
let other_side = GeneralAddr::Network(socket.peer_addr()?);
let server_addr = socket.local_addr()
.ok()
.map(|x| x.to_string())
.unwrap_or_else(|| "<unknown>".into());
let server = format!("tcp://{}", server_addr);
let (mut reader, writer) = tokio::io::split(socket);
let reader_callback = callback.clone();
group.spawn(async move {
let fetcher = async move || {
let size = reader.read_u16().await?;
let mut buffer = vec![0u8; size as usize];
reader.read_exact(&mut buffer).await?;
Ok((Bytes::from(buffer), other_side.clone()))
};
run_response_processing_loop(server, 5, reader_callback, fetcher).await;
});
Ok(Client {
callback,
channel: DnsChannel::Tcp(writer),
})
}
/// Send a set of questions to the upstream server.
///
/// Any response(s) that is/are sent will be handled as part of the callback
/// scheme. This function will thus only return an error if there's a problem
/// sending the question to the server.
pub async fn send_questions(_questions: Vec<Question>) -> error_stack::Result<(), SendError> {
Ok(())
}
/// Set the callback handler for when we receive responses from the server.
///
/// This function may take awhile to execute, depending on how busy we are
/// taking responses, as it takes write ownership of a read-write lock that's
/// almost always written.
pub async fn set_callback(&self, callback: Callback) {
*self.callback.write().await = callback;
}
}

View File

@@ -0,0 +1,225 @@
use bytes::{Buf, BufMut};
use error_stack::{report, ResultExt};
use num_enum::{FromPrimitive, IntoPrimitive};
use proptest::arbitrary::Arbitrary;
use proptest::strategy::{BoxedStrategy, Just, Strategy};
use std::fmt;
use thiserror::Error;
#[derive(Clone, Debug, PartialEq, proptest_derive::Arbitrary)]
pub struct Header {
pub message_id: u16,
pub is_response: bool,
pub opcode: OpCode,
pub authoritative_answer: bool,
pub message_truncated: bool,
pub recursion_desired: bool,
pub recursion_available: bool,
pub response_code: ResponseCode,
pub question_count: u16,
pub answer_count: u16,
pub name_server_count: u16,
pub additional_record_count: u16,
}
#[derive(FromPrimitive, IntoPrimitive, PartialEq, PartialOrd, Eq, Ord, Debug, Copy, Clone)]
#[repr(u8)]
pub enum OpCode {
StandardQuery = 0,
InverseQuery = 1,
ServiceStatusRequest = 2,
#[num_enum(catch_all)]
Other(u8),
}
impl Arbitrary for OpCode {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
// while the type is 8 bits in rust, it's 4 bits in the protocol,
// and dealing with the run-off is messy. so this only generates
// valid-sized values here. it also biases toward the legit values.
// both things might want to be reconsidered in the future.
proptest::prop_oneof![
Just(OpCode::StandardQuery),
Just(OpCode::InverseQuery),
Just(OpCode::ServiceStatusRequest),
(3u8..=15).prop_map(OpCode::Other),
]
.boxed()
}
}
#[derive(FromPrimitive, IntoPrimitive, PartialEq, PartialOrd, Eq, Ord, Debug, Copy, Clone)]
#[repr(u8)]
pub enum ResponseCode {
NoErrorConditions = 0,
FormatError = 1,
ServerFailure = 2,
NameError = 3,
NotImplemented = 4,
Refused = 5,
#[num_enum(catch_all)]
Other(u8),
}
impl fmt::Display for ResponseCode {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ResponseCode::NoErrorConditions => write!(f, "No errors"),
ResponseCode::FormatError => write!(f, "Illegal format"),
ResponseCode::ServerFailure => write!(f, "Server failure"),
ResponseCode::NameError => write!(f, "Name error"),
ResponseCode::NotImplemented => write!(f, "Not implemented"),
ResponseCode::Refused => write!(f, "Refused"),
ResponseCode::Other(x) => write!(f, "unknown error {x}"),
}
}
}
impl Arbitrary for ResponseCode {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
// while the type is 8 bits in rust, it's 4 bits in the protocol,
// and dealing with the run-off is messy. so this only generates
// valid-sized values here. it also biases toward the legit values.
// both things might want to be reconsidered in the future.
proptest::prop_oneof![
Just(ResponseCode::NoErrorConditions),
Just(ResponseCode::FormatError),
Just(ResponseCode::ServerFailure),
Just(ResponseCode::NameError),
Just(ResponseCode::NotImplemented),
Just(ResponseCode::Refused),
(6u8..=15).prop_map(ResponseCode::Other),
]
.boxed()
}
}
#[derive(Debug, Error)]
pub enum HeaderReadError {
#[error("Buffer not large enough to have a DNS message header in it.")]
BufferTooSmall,
#[error("Invalid data in zero-filled space.")]
NonZeroInZeroes,
}
#[derive(Debug, Error)]
pub enum HeaderWriteError {
#[error("Buffer not large enough to write header.")]
BufferTooSmall,
}
impl Header {
pub fn read<B: Buf>(buffer: &mut B) -> error_stack::Result<Self, HeaderReadError> {
if buffer.remaining() < 12
/* 6 16-bit fields = 6 * 2 = 1 */
{
return Err(report!(HeaderReadError::BufferTooSmall)).attach_printable_lazy(|| {
format!(
"Need at least {} bytes, but only had {}",
6 * 12,
buffer.remaining()
)
});
}
let message_id = buffer.get_u16();
let flags = buffer.get_u16();
let question_count = buffer.get_u16();
let answer_count = buffer.get_u16();
let name_server_count = buffer.get_u16();
let additional_record_count = buffer.get_u16();
let is_response = (0x8000 & flags) != 0;
let opcode = OpCode::from(((flags >> 11) & 0xF) as u8);
let authoritative_answer = (0x0400 & flags) != 0;
let message_truncated = (0x0200 & flags) != 0;
let recursion_desired = (0x0100 & flags) != 0;
let recursion_available = (0x0080 & flags) != 0;
let zeroes = 0x0070 & flags;
let response_code = ResponseCode::from((flags & 0x000F) as u8);
if zeroes != 0 {
return Err(report!(HeaderReadError::NonZeroInZeroes))
.attach_printable_lazy(|| format!("Saw {:#x} instead.", zeroes >> 4));
}
Ok(Header {
message_id,
is_response,
opcode,
authoritative_answer,
message_truncated,
recursion_desired,
recursion_available,
response_code,
question_count,
answer_count,
name_server_count,
additional_record_count,
})
}
pub fn write<B: BufMut>(self, buffer: &mut B) -> error_stack::Result<(), HeaderWriteError> {
if buffer.remaining_mut() < 12
/* 6 16-bit fields = 6 * 2 = 1 */
{
return Err(report!(HeaderWriteError::BufferTooSmall)).attach_printable_lazy(|| {
format!(
"Need at least {} to write DNS header, only have {}",
6 * 12,
buffer.remaining_mut()
)
});
}
let mut flags: u16 = 0;
if self.is_response {
flags |= 0x8000;
}
let opcode: u8 = self.opcode.into();
flags |= (opcode as u16) << 11;
if self.authoritative_answer {
flags |= 0x0400;
}
if self.message_truncated {
flags |= 0x0200;
}
if self.recursion_desired {
flags |= 0x0100;
}
if self.recursion_available {
flags |= 0x0080;
}
let response_code: u8 = self.response_code.into();
flags |= response_code as u16;
buffer.put_u16(self.message_id);
buffer.put_u16(flags);
buffer.put_u16(self.question_count);
buffer.put_u16(self.answer_count);
buffer.put_u16(self.name_server_count);
buffer.put_u16(self.additional_record_count);
Ok(())
}
}
proptest::proptest! {
#[test]
fn headers_roundtrip(header: Header) {
let mut write_buffer = bytes::BytesMut::with_capacity(128 * 1024);
let safe_header = header.clone();
header.write(&mut write_buffer).expect("can write name");
let mut read_buffer = write_buffer.freeze();
let new_header = Header::read(&mut read_buffer).expect("can read name");
assert_eq!(safe_header, new_header);
}
}

View File

@@ -0,0 +1,91 @@
use crate::name::Name;
use crate::protocol::resource_record::raw::{RecordClass, RecordType};
use bytes::{Buf, BufMut, TryGetError};
use error_stack::{report, ResultExt};
use std::fmt;
use thiserror::Error;
#[derive(Clone, Debug, PartialEq, proptest_derive::Arbitrary)]
pub struct Question {
name: Name,
record_type: RecordType,
record_class: RecordClass,
}
impl fmt::Display for Question {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "<Q:{}@{}.{}>", self.name, self.record_type, self.record_class)
}
}
#[derive(Debug, Error)]
pub enum QuestionReadError {
#[error("Could not read name for question.")]
CouldNotReadName,
#[error("Could not read the record type for the question: {0}")]
CouldNotReadType(TryGetError),
#[error("Could not read the record class for the question.")]
CouldNotReadClass(TryGetError),
}
#[derive(Debug, Error)]
pub enum QuestionWriteError {
#[error("Could not write name for the question.")]
CouldNotWriteName,
#[error("Buffer not large enough to write question type and class.")]
BufferTooSmall,
}
impl Question {
pub fn read<B: Buf>(buffer: &mut B) -> error_stack::Result<Self, QuestionReadError> {
let name = Name::read(buffer).change_context(QuestionReadError::CouldNotReadName)?;
let record_type_u16 = buffer
.try_get_u16()
.map_err(|e| report!(QuestionReadError::CouldNotReadType(e)))
.attach_printable_lazy(|| format!("question was about '{name}'"))?;
let record_class_u16 = buffer
.try_get_u16()
.map_err(|e| report!(QuestionReadError::CouldNotReadClass(e)))
.attach_printable_lazy(|| format!("question was about '{name}'"))?;
let record_type = RecordType::from(record_type_u16);
let record_class = RecordClass::from(record_class_u16);
Ok(Question {
name,
record_type,
record_class,
})
}
pub fn write<B: BufMut>(&self, buffer: &mut B) -> error_stack::Result<(), QuestionWriteError> {
self.name
.write(buffer)
.change_context(QuestionWriteError::CouldNotWriteName)
.attach_printable_lazy(|| format!("Question was about '{}'", self.name))?;
if buffer.remaining_mut() < 4 {
return Err(report!(QuestionWriteError::BufferTooSmall))
.attach_printable_lazy(|| format!("Question was about '{}'", self.name));
}
buffer.put_u16(self.record_type.into());
buffer.put_u16(self.record_class.into());
Ok(())
}
}
proptest::proptest! {
#[test]
fn questions_roundtrip(question: Question) {
let mut write_buffer = bytes::BytesMut::with_capacity(128 * 1024);
question.write(&mut write_buffer).expect("can write name");
let mut read_buffer = write_buffer.freeze();
let new_question = Question::read(&mut read_buffer).expect("can read name");
assert_eq!(question, new_question);
}
}

View File

@@ -0,0 +1,156 @@
use bytes::{Buf, BufMut};
use error_stack::{report, ResultExt};
use crate::protocol::header::{Header, OpCode, ResponseCode};
use crate::protocol::question::Question;
use crate::protocol::resource_record::ResourceRecord;
#[derive(Clone, Debug, PartialEq, proptest_derive::Arbitrary)]
pub struct Request {
source_message_id: u16,
opcode: OpCode,
recursion_desired: bool,
questions: Vec<Question>,
}
#[derive(Debug, thiserror::Error)]
pub enum RequestReadError {
#[error("Error reading request header.")]
Header,
#[error("Message isn't a request.")]
NotRequest,
#[error("Illegal message.")]
IllegalMessage,
#[error("Error reading request question.")]
Question,
#[error("Error reading request answers (?!)")]
Answer,
#[error("Error reading requst name servers.")]
NameServers,
#[error("Reading request additional records.")]
AdditionalRecords,
}
#[derive(Debug, thiserror::Error)]
pub enum RequestWriteError {
#[error("Error writing request header.")]
Header,
#[error("Error writing request question.")]
Question,
#[error("Request had too many questions.")]
TooManyQuestions,
}
impl Request {
pub fn read<B: Buf>(buffer: &mut B) -> error_stack::Result<Self, RequestReadError> {
let header = Header::read(buffer)
.change_context(RequestReadError::Header)?;
if header.is_response {
return Err(report!(RequestReadError::NotRequest))
.attach_printable_lazy(|| format!("message id is {}", header.message_id));
}
if header.authoritative_answer {
return Err(report!(RequestReadError::IllegalMessage))
.attach_printable("Request messages are not allowed to set the 'authoritative answer' bit.")
.attach_printable_lazy(|| format!("message id is {}", header.message_id));
}
if header.response_code != ResponseCode::NoErrorConditions {
return Err(report!(RequestReadError::IllegalMessage))
.attach_printable("Request messages are not allowed to set the response code.")
.attach_printable_lazy(|| format!("message id is {}", header.message_id));
}
if header.answer_count != 0 {
return Err(report!(RequestReadError::IllegalMessage))
.attach_printable("Request messages are not allowed to include answers.")
.attach_printable_lazy(|| format!("{} answers declared", header.answer_count))
.attach_printable_lazy(|| format!("message id is {}", header.message_id));
}
if header.name_server_count != 0 {
return Err(report!(RequestReadError::IllegalMessage))
.attach_printable("Request messages are not allowed to include name servers.")
.attach_printable_lazy(|| format!("{} name servers declared", header.name_server_count))
.attach_printable_lazy(|| format!("message id is {}", header.message_id));
}
if header.additional_record_count != 0 {
return Err(report!(RequestReadError::IllegalMessage))
.attach_printable("Request messages are not allowed to include additional records.")
.attach_printable_lazy(|| format!("{} aditional records declared", header.additional_record_count))
.attach_printable_lazy(|| format!("message id is {}", header.message_id));
}
let mut questions = vec![];
for i in 0..header.question_count {
let question = Question::read(buffer)
.change_context(RequestReadError::Question)
.attach_printable_lazy(|| format!("for question #{} of {}", i+1, header.question_count))
.attach_printable_lazy(|| format!("message id is {}", header.message_id))?;
questions.push(question);
}
Ok(Request {
source_message_id: header.message_id,
opcode: header.opcode,
recursion_desired: header.recursion_desired,
questions,
})
}
pub fn write<B: BufMut>(self, buffer: &mut B) -> error_stack::Result<(), RequestWriteError> {
let question_count = self.questions.len();
if question_count > (u16::MAX as usize) {
return Err(report!(RequestWriteError::TooManyQuestions))
.attach_printable(format!("message_id is {}", self.source_message_id));
}
let header = Header {
message_id: self.source_message_id,
is_response: false,
opcode: self.opcode,
authoritative_answer: false,
message_truncated: false,
recursion_desired: self.recursion_desired,
recursion_available: false,
response_code: ResponseCode::NoErrorConditions,
question_count: question_count as u16,
answer_count: 0,
name_server_count: 0,
additional_record_count: 0,
};
header.write(buffer)
.change_context(RequestWriteError::Header)
.attach_printable_lazy(|| format!("message ID is {}", self.source_message_id))?;
for (index, question) in self.questions.into_iter().enumerate() {
question.write(buffer)
.change_context(RequestWriteError::Question)
.attach_printable_lazy(|| format!("message ID is {}", self.source_message_id))
.attach_printable_lazy(|| format!("question #{} of {}", index+1, question_count))
.attach_printable_lazy(|| format!("{}", question))?;
}
Ok(())
}
}
proptest::proptest! {
#[test]
fn request_roundtrip(request: Request) {
let mut write_buffer = bytes::BytesMut::with_capacity(128 * 1024);
let safe_request = request.clone();
request.write(&mut write_buffer).expect("can write name");
let mut read_buffer = write_buffer.freeze();
let new_request = Request::read(&mut read_buffer).expect("can read name");
assert_eq!(safe_request, new_request);
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,304 @@
use crate::name::Name;
use bytes::{Buf, BufMut, Bytes};
use error_stack::{report, ResultExt};
use num_enum::{FromPrimitive, IntoPrimitive};
use proptest::arbitrary::Arbitrary;
use proptest::strategy::{BoxedStrategy, Just, Strategy};
use std::fmt;
use thiserror::Error;
#[derive(Debug, PartialEq)]
pub struct RawResourceRecord {
pub name: Name,
pub record_type: RecordType,
pub record_class: RecordClass,
pub ttl: u32,
pub data: Bytes,
}
impl Arbitrary for RawResourceRecord {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
(
Name::arbitrary(),
RecordType::arbitrary(),
RecordClass::arbitrary(),
u32::arbitrary(),
proptest::collection::vec(u8::arbitrary(), 0..65535),
)
.prop_map(
|(name, record_type, record_class, ttl, data)| RawResourceRecord {
name,
record_type,
record_class,
ttl,
data: data.into(),
},
)
.boxed()
}
}
#[derive(FromPrimitive, IntoPrimitive, PartialEq, PartialOrd, Eq, Ord, Debug, Copy, Clone)]
#[repr(u16)]
pub enum RecordType {
A = 1,
AAAA = 28,
NS = 2,
MD = 3,
MF = 4,
CNAME = 5,
SOA = 6,
MB = 7,
MG = 8,
MR = 9,
NULL = 10,
WKS = 11,
PTR = 12,
HINFO = 13,
MINFO = 14,
MX = 15,
TXT = 16,
URI = 256,
#[num_enum(catch_all)]
Other(u16),
}
impl fmt::Display for RecordType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
RecordType::A => write!(f, "A"),
RecordType::AAAA => write!(f, "AAAA"),
RecordType::NS => write!(f, "NS"),
RecordType::MD => write!(f, "MD"),
RecordType::MF => write!(f, "MF"),
RecordType::CNAME => write!(f, "CNAME"),
RecordType::SOA => write!(f, "SOA"),
RecordType::MB => write!(f, "MB"),
RecordType::MG => write!(f, "MG"),
RecordType::MR => write!(f, "MR"),
RecordType::NULL => write!(f, "NULL"),
RecordType::WKS => write!(f, "WKS"),
RecordType::PTR => write!(f, "PTR"),
RecordType::HINFO => write!(f, "HINFO"),
RecordType::MINFO => write!(f, "MINFO"),
RecordType::MX => write!(f, "MX"),
RecordType::TXT => write!(f, "TXT"),
RecordType::URI => write!(f, "URI"),
RecordType::Other(x) => write!(f, "UNKNOWN<{x}>"),
}
}
}
impl Arbitrary for RecordType {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
// this is intentionally biased towards the legit values
proptest::prop_oneof![
Just(RecordType::A),
Just(RecordType::AAAA),
Just(RecordType::NS),
Just(RecordType::MD),
Just(RecordType::MF),
Just(RecordType::CNAME),
Just(RecordType::SOA),
Just(RecordType::MB),
Just(RecordType::MG),
Just(RecordType::MR),
Just(RecordType::NULL),
Just(RecordType::WKS),
Just(RecordType::PTR),
Just(RecordType::HINFO),
Just(RecordType::MINFO),
Just(RecordType::MX),
Just(RecordType::TXT),
proptest::prop_oneof![
(17u16..28).prop_map(|x| RecordType::Other(x)),
(29u16..256).prop_map(|x| RecordType::Other(x)),
(257u16..=65535).prop_map(|x| RecordType::Other(x)),
Just(RecordType::Other(0)),
],
]
.boxed()
}
}
#[derive(FromPrimitive, IntoPrimitive, PartialEq, PartialOrd, Eq, Ord, Debug, Copy, Clone)]
#[repr(u16)]
pub enum RecordClass {
IN = 1,
CS = 2,
CH = 3,
HS = 4,
#[num_enum(catch_all)]
Other(u16),
}
impl fmt::Display for RecordClass {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
RecordClass::IN => write!(f, "IN"),
RecordClass::CS => write!(f, "CS"),
RecordClass::CH => write!(f, "CH"),
RecordClass::HS => write!(f, "HS"),
RecordClass::Other(x) => write!(f, "UNKNOWN<{x}>"),
}
}
}
impl Arbitrary for RecordClass {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
// this is intentionally biased towards the legit values
proptest::prop_oneof![
Just(RecordClass::IN),
Just(RecordClass::CS),
Just(RecordClass::CH),
Just(RecordClass::HS),
(5u16..=65535).prop_map(|x| RecordClass::Other(x)),
Just(RecordClass::Other(0)),
]
.boxed()
}
}
#[derive(Debug, Error)]
pub enum ResourceRecordReadError {
#[error("Failed to read initial record name.")]
InitialRecord,
#[error("Resource record truncated; couldn't find its type field.")]
NoTypeField,
#[error("Resource record truncated; couldn't find its class field.")]
NoClassField,
#[error("Resource record truncated; couldn't find its TTL field.")]
NoTtl,
#[error("Resource record truncated; couldn't find its data length.")]
NoDataLength,
#[error("Resource record truncated; couldn't read its entire data field.")]
DataTruncated,
}
#[derive(Debug, Error)]
pub enum ResourceRecordWriteError {
#[error("Could not write name to the output record.")]
CouldNotWriteName,
#[error("Could not write resource record type and class to output record.")]
CouldNotWriteTypeClass,
#[error("Could not write TTL to output record.")]
CountNotWriteTtl,
#[error("Could not write resource record data length to output record.")]
CountNotWriteDataLength,
#[error("Could not write resource record data to output record.")]
CountNotWriteData,
#[error("Input data was too large to write to output stream.")]
InputDataTooLarge,
}
impl RawResourceRecord {
pub fn read<B: Buf>(buffer: &mut B) -> error_stack::Result<Self, ResourceRecordReadError> {
let name = Name::read(buffer).change_context(ResourceRecordReadError::InitialRecord)?;
let record_type = buffer
.try_get_u16()
.map_err(|_| report!(ResourceRecordReadError::NoTypeField))?
.into();
let record_class = buffer
.try_get_u16()
.map_err(|_| report!(ResourceRecordReadError::NoClassField))?
.into();
let ttl = buffer
.try_get_u32()
.map_err(|_| report!(ResourceRecordReadError::NoTtl))?;
let rdata_length = buffer
.try_get_u16()
.map_err(|_| report!(ResourceRecordReadError::NoDataLength))?;
if buffer.remaining() < (rdata_length as usize) {
return Err(report!(ResourceRecordReadError::DataTruncated)).attach_printable_lazy(
|| {
format!(
"Expected {rdata_length} bytes, but only saw {}",
buffer.remaining()
)
},
);
}
let data = buffer.copy_to_bytes(rdata_length as usize);
Ok(RawResourceRecord {
name,
record_type,
record_class,
ttl,
data,
})
}
pub fn write<B: BufMut>(
&self,
buffer: &mut B,
) -> error_stack::Result<(), ResourceRecordWriteError> {
self.name
.write(buffer)
.change_context(ResourceRecordWriteError::CouldNotWriteName)?;
if buffer.remaining_mut() < 4 {
return Err(report!(ResourceRecordWriteError::CouldNotWriteTypeClass));
}
buffer.put_u16(self.record_type.into());
buffer.put_u16(self.record_class.into());
if buffer.remaining_mut() < 4 {
return Err(report!(ResourceRecordWriteError::CountNotWriteTtl));
}
buffer.put_u32(self.ttl);
if buffer.remaining_mut() < 2 {
return Err(report!(ResourceRecordWriteError::CountNotWriteDataLength));
}
if self.data.len() > (u16::MAX as usize) {
return Err(report!(ResourceRecordWriteError::InputDataTooLarge))
.attach_printable_lazy(|| {
format!(
"Incoming data was {} bytes, needs to be < 2^16",
self.data.len()
)
});
}
buffer.put_u16(self.data.len() as u16);
if buffer.remaining_mut() < self.data.len() {
return Err(report!(ResourceRecordWriteError::CountNotWriteData));
}
buffer.put_slice(&self.data);
Ok(())
}
}
proptest::proptest! {
#[test]
fn any_random_record_roundtrips(record: RawResourceRecord) {
let mut write_buffer = bytes::BytesMut::with_capacity(128 * 1024);
record.write(&mut write_buffer).expect("can write name");
let mut read_buffer = write_buffer.freeze();
let new_record = RawResourceRecord::read(&mut read_buffer).expect("can read name");
assert_eq!(record, new_record);
}
}

View File

@@ -0,0 +1,258 @@
use bytes::{Buf, BufMut};
use error_stack::ResultExt;
use crate::protocol::header::{Header, ResponseCode};
use crate::protocol::question::Question;
use crate::protocol::resource_record::ResourceRecord;
#[derive(Clone, Debug, PartialEq, proptest_derive::Arbitrary)]
pub enum Response {
Valid {
source_message_id: u16,
authoritative: bool,
truncated: bool,
answers: Vec<ResourceRecord>,
name_servers: Vec<ResourceRecord>,
additional_records: Vec<ResourceRecord>,
},
FormatError {
source_message_id: u16,
},
ServerFailure {
source_message_id: u16,
},
NameError {
source_message_id: u16,
},
NotImplemented {
source_message_id: u16,
},
Refused {
source_message_id: u16,
},
UnknownError {
#[proptest(strategy="6u8..=15")]
error_code: u8,
source_message_id: u16,
}
}
#[derive(Debug, thiserror::Error)]
pub enum ResponseReadError {
#[error("Could not read response header.")]
HeaderReadError,
#[error("Could not read question included in response.")]
QuestionReadError,
#[error("Could not read resource record included as an answer.")]
AnswerReadError,
#[error("Could not read name server record included in response.")]
NameServerReadError,
#[error("Could not read supplemental record included in response.")]
AdditionalInfoReadError,
}
#[derive(Debug, thiserror::Error)]
pub enum ResponseWriteError {
#[error("Could not write response header.")]
Header,
#[error("Could not write answer.")]
Answer,
#[error("Could not write name server.")]
NameServer,
#[error("Could not write additional record.")]
AdditionalRecord,
}
impl Response {
pub fn message_id(&self) -> u16 {
match self {
Response::Valid { source_message_id, .. } => *source_message_id,
Response::FormatError { source_message_id } => *source_message_id,
Response::ServerFailure { source_message_id } => *source_message_id,
Response::NameError { source_message_id } => *source_message_id,
Response::NotImplemented { source_message_id } => *source_message_id,
Response::Refused { source_message_id } => *source_message_id,
Response::UnknownError { source_message_id, .. } => *source_message_id,
}
}
pub fn read<B: Buf>(buffer: &mut B) -> error_stack::Result<Response, ResponseReadError> {
let header = Header::read(buffer)
.change_context(ResponseReadError::HeaderReadError)?;
// check for errors, and short-cut out if we find any
match header.response_code {
ResponseCode::NoErrorConditions => {}
ResponseCode::FormatError => return Ok(Response::FormatError {
source_message_id: header.message_id,
}),
ResponseCode::ServerFailure => return Ok(Response::ServerFailure {
source_message_id: header.message_id,
}),
ResponseCode::NameError => return Ok(Response::NameError {
source_message_id: header.message_id,
}),
ResponseCode::NotImplemented => return Ok(Response::NotImplemented {
source_message_id: header.message_id,
}),
ResponseCode::Refused => return Ok(Response::Refused {
source_message_id: header.message_id,
}),
ResponseCode::Other(error_code) => return Ok(Response::UnknownError {
error_code,
source_message_id: header.message_id,
}),
}
// it seems weird to get questions in a response, but we need to parse
// them out if they exist.
for _ in 0..header.question_count {
let question = Question::read(buffer)
.change_context(ResponseReadError::QuestionReadError)?;
tracing::warn!(
%question,
"got question during server response."
);
}
let mut answers = vec![];
for idx in 0..header.answer_count {
let answer = ResourceRecord::read(buffer)
.change_context(ResponseReadError::AnswerReadError)
.attach_printable_lazy(|| format!("In answer {} of {}", idx + 1, header.answer_count))?;
answers.push(answer);
}
let mut name_servers = vec![];
for idx in 0..header.name_server_count {
let name_server = ResourceRecord::read(buffer)
.change_context(ResponseReadError::NameServerReadError)
.attach_printable_lazy(|| format!("In answer {} of {}", idx + 1, header.name_server_count))?;
name_servers.push(name_server);
}
let mut additional_records = vec![];
for idx in 0..header.additional_record_count {
let extra = ResourceRecord::read(buffer)
.change_context(ResponseReadError::AnswerReadError)
.attach_printable_lazy(|| format!("In answer {} of {}", idx + 1, header.additional_record_count))?;
additional_records.push(extra);
}
Ok(Response::Valid {
source_message_id: header.message_id,
authoritative: header.authoritative_answer,
truncated: header.message_truncated,
answers,
name_servers,
additional_records,
})
}
fn write_error<B: BufMut>(source_message_id: u16, error_code: ResponseCode, buffer: &mut B) -> error_stack::Result<(), ResponseWriteError> {
let header = Header {
message_id: source_message_id,
is_response: true,
opcode: super::header::OpCode::StandardQuery,
authoritative_answer: false,
message_truncated: false,
recursion_desired: false,
recursion_available: false,
response_code: error_code,
question_count: 0,
answer_count: 0,
name_server_count: 0,
additional_record_count: 0,
};
header
.write(buffer)
.change_context(ResponseWriteError::Header)
.attach_printable_lazy(|| format!("responding to message {source_message_id} with error code {error_code}"))
}
pub fn write<B: BufMut>(self, buffer: &mut B) -> error_stack::Result<(), ResponseWriteError> {
let (source_message_id, authoritative, truncated, answers, name_servers, additional_records) = match self {
Response::FormatError { source_message_id } => return Self::write_error(source_message_id, ResponseCode::FormatError, buffer),
Response::ServerFailure { source_message_id } => return Self::write_error(source_message_id, ResponseCode::ServerFailure, buffer),
Response::NameError { source_message_id } => return Self::write_error(source_message_id, ResponseCode::NameError, buffer),
Response::NotImplemented { source_message_id } => return Self::write_error(source_message_id, ResponseCode::NotImplemented, buffer),
Response::Refused { source_message_id } => return Self::write_error(source_message_id, ResponseCode::Refused, buffer),
Response::UnknownError { error_code, source_message_id } => return Self::write_error(source_message_id, ResponseCode::Other(error_code), buffer),
Response::Valid { source_message_id, authoritative, truncated, answers, name_servers, additional_records } => {
(source_message_id, authoritative, truncated, answers, name_servers, additional_records)
}
};
let header = Header {
message_id: source_message_id,
is_response: true,
opcode: super::header::OpCode::StandardQuery,
authoritative_answer: authoritative,
message_truncated: truncated,
recursion_desired: false,
recursion_available: false,
response_code: ResponseCode::NoErrorConditions,
question_count: 0,
answer_count: answers.len() as u16,
name_server_count: name_servers.len() as u16,
additional_record_count: additional_records.len() as u16,
};
header.write(buffer)
.change_context(ResponseWriteError::Header)
.attach_printable_lazy(|| format!("Writing clean response to {source_message_id}"))?;
let answer_count = answers.len();
for (item, answer) in answers.into_iter().enumerate() {
answer.write(buffer)
.change_context(ResponseWriteError::Answer)
.attach_printable_lazy(|| format!("Writing clean response to {source_message_id}"))
.attach_printable_lazy(|| format!("Writing answer {} of {}", item+1, answer_count))?;
}
let ns_count = name_servers.len();
for (item, answer) in name_servers.into_iter().enumerate() {
answer.write(buffer)
.change_context(ResponseWriteError::NameServer)
.attach_printable_lazy(|| format!("Writing clean response to {source_message_id}"))
.attach_printable_lazy(|| format!("Writing name server {} of {}", item+1, ns_count))?;
}
let ar_count = additional_records.len();
for (item, answer) in additional_records.into_iter().enumerate() {
answer.write(buffer)
.change_context(ResponseWriteError::AdditionalRecord)
.attach_printable_lazy(|| format!("Writing clean response to {source_message_id}"))
.attach_printable_lazy(|| format!("Writing additional record {} of {}", item+1, ar_count))?;
}
Ok(())
}
}
proptest::proptest! {
#[test]
fn any_random_response_roundtrips(record: Response) {
let mut write_buffer = bytes::BytesMut::with_capacity(128 * 1024);
record.clone().write(&mut write_buffer).expect("can write name");
let mut read_buffer = write_buffer.freeze();
let new_record = Response::read(&mut read_buffer).expect("can read name");
assert_eq!(record, new_record);
}
}

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,176 @@
use crate::name::Name;
use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet};
use std::net::IpAddr;
use tokio::time::{Duration, Instant};
pub struct ResolutionTable {
inner: HashMap<Name, Vec<Resolution>>,
}
struct Resolution {
result: IpAddr,
expiration: Instant,
}
impl Default for ResolutionTable {
fn default() -> Self {
ResolutionTable::new()
}
}
impl ResolutionTable {
/// Generate a new, empty resolution table to use in a new DNS implementation,
/// or shared by a bunch of them.
pub fn new() -> Self {
ResolutionTable {
inner: HashMap::new(),
}
}
/// Clean the table of expired entries.
pub fn garbage_collect(&mut self) {
let now = Instant::now();
self.inner.retain(|_, items| {
items.retain(|x| x.expiration > now);
!items.is_empty()
});
}
/// Add a new entry to the resolution table, with a TTL on it.
pub fn add_entry(&mut self, name: Name, maps_to: IpAddr, ttl: Duration) {
let now = Instant::now();
let new_entry = Resolution {
result: maps_to,
expiration: now + ttl,
};
match self.inner.entry(name) {
Entry::Vacant(vac) => {
vac.insert(vec![new_entry]);
}
Entry::Occupied(mut occ) => {
occ.get_mut().push(new_entry);
}
}
}
/// Look up an entry in the resolution map. This will only return
/// unexpired items.
pub fn lookup(&mut self, name: &Name) -> HashSet<IpAddr> {
let mut result = HashSet::new();
let now = Instant::now();
if let Some(entry) = self.inner.get_mut(name) {
entry.retain(|x| {
let retain = x.expiration > now;
if retain {
result.insert(x.result);
}
retain
});
}
result
}
}
#[cfg(test)]
use std::net::{Ipv4Addr, Ipv6Addr};
#[cfg(test)]
use std::str::FromStr;
#[test]
fn empty_set_gets_fail() {
let mut empty = ResolutionTable::default();
assert!(empty.lookup(&Name::from_str("foo").unwrap()).is_empty());
assert!(empty.lookup(&Name::from_str("bar").unwrap()).is_empty());
}
#[test]
fn basic_lookups() {
let mut table = ResolutionTable::new();
let foo = Name::from_str("foo").unwrap();
let bar = Name::from_str("bar").unwrap();
let baz = Name::from_str("baz").unwrap();
let localhost = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
let other = IpAddr::V6(Ipv6Addr::from_str("2001:18e8:2:e::11e").unwrap());
let long_time = Duration::from_secs(10000000000);
table.add_entry(foo.clone(), localhost, long_time);
table.add_entry(bar.clone(), localhost, long_time);
table.add_entry(bar.clone(), other, long_time);
assert_eq!(1, table.lookup(&foo).len());
assert_eq!(2, table.lookup(&bar).len());
assert!(table.lookup(&baz).is_empty());
assert!(table.lookup(&foo).contains(&localhost));
assert!(!table.lookup(&foo).contains(&other));
assert!(table.lookup(&bar).contains(&localhost));
assert!(table.lookup(&bar).contains(&other));
}
#[test]
fn lookup_cleans_up() {
let mut table = ResolutionTable::new();
let foo = Name::from_str("foo").unwrap();
let localhost = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
let other = IpAddr::V6(Ipv6Addr::from_str("2001:18e8:2:e::11e").unwrap());
let short_time = Duration::from_millis(100);
let long_time = Duration::from_secs(10000000000);
table.add_entry(foo.clone(), localhost, long_time);
table.add_entry(foo.clone(), other, short_time);
let wait_until = Instant::now() + (2 * short_time);
while Instant::now() < wait_until {
std::thread::sleep(short_time);
}
assert_eq!(1, table.lookup(&foo).len());
assert!(table.lookup(&foo).contains(&localhost));
assert!(!table.lookup(&foo).contains(&other));
}
#[test]
fn garbage_collection_works() {
let mut table = ResolutionTable::new();
let foo = Name::from_str("foo").unwrap();
let localhost = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
let other = IpAddr::V6(Ipv6Addr::from_str("2001:18e8:2:e::11e").unwrap());
let short_time = Duration::from_millis(100);
let long_time = Duration::from_secs(10000000000);
table.add_entry(foo.clone(), localhost, long_time);
table.add_entry(foo.clone(), other, short_time);
let wait_until = Instant::now() + (2 * short_time);
while Instant::now() < wait_until {
std::thread::sleep(short_time);
}
table.garbage_collect();
assert_eq!(1, table.inner.get(&foo).unwrap().len());
}
#[test]
fn garbage_collection_clears_empties() {
let mut table = ResolutionTable::new();
let foo = Name::from_str("foo").unwrap();
table.inner.insert(foo.clone(), vec![]);
table.garbage_collect();
assert!(table.inner.is_empty());
let localhost = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
let short_time = Duration::from_millis(100);
table.add_entry(foo.clone(), localhost, short_time);
let wait_until = Instant::now() + (2 * short_time);
while Instant::now() < wait_until {
std::thread::sleep(short_time);
}
table.garbage_collect();
assert!(table.inner.is_empty());
}