Checkpoint with resolver tests including request/response.

This commit is contained in:
2025-05-03 13:50:16 -07:00
parent 31cd34d280
commit 9fe5b78962
20 changed files with 4012 additions and 1093 deletions

View File

@@ -0,0 +1,505 @@
use bytes::{Buf, BufMut};
use error_stack::{report, ResultExt};
use internment::ArcIntern;
use proptest::arbitrary::Arbitrary;
use proptest::char::CharStrategy;
use proptest::strategy::{BoxedStrategy, Strategy};
use std::borrow::Cow;
use std::fmt;
use std::hash::Hash;
use std::ops::{Range, RangeInclusive};
use std::str::FromStr;
use thiserror::Error;
#[derive(Clone, Hash, PartialEq, Eq)]
pub struct Name {
labels: Vec<Label>,
}
impl fmt::Debug for Name {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut first_section = true;
for &Label(ref label) in self.labels.iter() {
if first_section {
first_section = false;
write!(f, "{}", label.as_str())?;
} else {
write!(f, ".{}", label.as_str())?;
}
}
Ok(())
}
}
impl fmt::Display for Name {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
<Self as fmt::Debug>::fmt(self, f)
}
}
#[derive(Clone, Hash, Eq)]
pub struct Label(ArcIntern<String>);
impl fmt::Debug for Label {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.as_str().fmt(f)
}
}
impl fmt::Display for Label {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.as_str().fmt(f)
}
}
impl PartialEq for Label {
fn eq(&self, other: &Self) -> bool {
self.0.eq_ignore_ascii_case(other.0.as_str())
}
}
#[derive(Debug, Error)]
pub enum NameParseError {
#[error("Provided name '{name}' is too long ({observed_length} bytes); maximum length of a DNS name is 255 octets.")]
NameTooLong {
name: String,
observed_length: usize,
},
#[error("Provided name '{name}' contains illegal character '{illegal_character}'.")]
NonAsciiCharacter {
name: String,
illegal_character: char,
},
#[error("Provided name '{name}' contains an empty section/label, which isn't allowed.")]
EmptyLabel { name: String },
#[error("Provided name '{name}' contains a label ('{label}') that is too long ({observed_length} letters, which is more than 63).")]
LabelTooLong {
observed_length: usize,
name: String,
label: String,
},
#[error("Provided name '{name}' contains a label ('{label}') that begins with an illegal character; it must be a letter.")]
LabelStartsWrong { name: String, label: String },
#[error("Provided name '{name}' contains a label ('{label}') that ends with an illegal character; it must be a letter or number.")]
LabelEndsWrong { name: String, label: String },
#[error("Provided name '{name}' contains a label ('{label}') that contains a non-letter, non-number, and non-dash.")]
IllegalInternalCharacter { name: String, label: String },
}
impl FromStr for Name {
type Err = NameParseError;
fn from_str(value: &str) -> Result<Self, Self::Err> {
let observed_length = value.as_bytes().len();
if observed_length > 255 {
return Err(NameParseError::NameTooLong {
name: value.to_string(),
observed_length,
});
}
for char in value.chars() {
if !(char.is_ascii_alphanumeric() || char == '.' || char == '-') {
return Err(NameParseError::NonAsciiCharacter {
name: value.to_string(),
illegal_character: char,
});
}
}
let mut labels = vec![];
for label_str in value.split('.') {
if label_str.is_empty() {
return Err(NameParseError::EmptyLabel {
name: value.to_string(),
});
}
if label_str.len() > 63 {
return Err(NameParseError::LabelTooLong {
name: value.to_string(),
label: label_str.to_string(),
observed_length: label_str.len(),
});
}
let letter = |x| ('a'..='z').contains(&x) || ('A'..='Z').contains(&x);
let letter_or_num = |x| letter(x) || ('0'..='9').contains(&x);
let letter_num_dash = |x| letter_or_num(x) || (x == '-');
if !label_str.starts_with(letter) {
return Err(NameParseError::LabelStartsWrong {
name: value.to_string(),
label: label_str.to_string(),
});
}
if !label_str.ends_with(letter_or_num) {
return Err(NameParseError::LabelEndsWrong {
name: value.to_string(),
label: label_str.to_string(),
});
}
if label_str.contains(|x| !letter_num_dash(x)) {
return Err(NameParseError::IllegalInternalCharacter {
name: value.to_string(),
label: label_str.to_string(),
});
}
// RFC 1035 says that all domain names are case-insensitive. We
// arbitrarily normalize to lowercase here, because it shouldn't
// matter to anyone.
labels.push(Label(ArcIntern::new(label_str.to_string())));
}
Ok(Name { labels })
}
}
#[derive(Debug, Error)]
pub enum NameReadError {
#[error("Could not read name field out of an empty buffer.")]
EmptyBuffer,
#[error("Buffer truncated before we could read the last label")]
TruncatedBuffer,
#[error("Provided label value too long; must be 63 octets or less.")]
LabelTooLong,
#[error("Label truncated while reading. Broken stream?")]
LabelTruncated,
#[error("Label starts with an illegal character (must be [A-Za-z])")]
WrongFirstByte,
#[error("Label ends with an illegal character (must be [A-Za-z0-9]")]
WrongLastByte,
#[error("Label contains an illegal character (must be [A-Za-z0-9] or a dash)")]
WrongInnerByte,
}
impl Name {
/// Read a name frm a record, or return an error describing what went wrong.
///
/// This function will advance the read pointer on the record. If the result is
/// an error, the read pointer is not guaranteed to be in a good place, so you
/// may need to manage that externally if you want to implement some sort of
/// "try" functionality.
pub fn read<B: Buf>(buffer: &mut B) -> error_stack::Result<Name, NameReadError> {
let mut labels = Vec::new();
let name_so_far = |labels: Vec<Label>| {
let mut result = String::new();
for Label(label) in labels.into_iter() {
if result.is_empty() {
result.push_str(label.as_str());
} else {
result.push('.');
result.push_str(label.as_str());
}
}
result
};
loop {
if !buffer.has_remaining() && labels.is_empty() {
return Err(report!(NameReadError::EmptyBuffer));
}
if !buffer.has_remaining() {
return Err(report!(NameReadError::TruncatedBuffer))
.attach_printable_lazy(|| format!("name thus far: '{}'", name_so_far(labels)));
}
let label_octet_length = buffer.get_u8() as usize;
if label_octet_length == 0 {
break;
}
if label_octet_length > 63 {
return Err(report!(NameReadError::LabelTooLong)).attach_printable_lazy(|| {
format!(
"label length too big; max is supposed to be 63, saw {label_octet_length}"
)
});
}
if !buffer.remaining() < label_octet_length {
let remaining = buffer.copy_to_bytes(buffer.remaining());
let partial = String::from_utf8_lossy(&remaining).to_string();
return Err(report!(NameReadError::LabelTruncated))
.attach_printable_lazy(|| format!("name thus far: '{}'", name_so_far(labels)))
.attach_printable_lazy(|| {
format!(
"expected {} octets, but only found {}",
label_octet_length,
buffer.remaining()
)
})
.attach_printable_lazy(|| format!("partial read: '{partial}'"));
}
let label_bytes = buffer.copy_to_bytes(label_octet_length);
let Some(first_byte) = label_bytes.first() else {
panic!(
"INTERNAL ERROR: Should have at least one byte, we checked this previously."
);
};
let Some(last_byte) = label_bytes.last() else {
panic!(
"INTERNAL ERROR: Should have at least one byte, we checked this previously."
);
};
let letter = |x| (b'a'..=b'z').contains(&x) || (b'A'..=b'Z').contains(&x);
let letter_or_num = |x| letter(x) || (b'0'..=b'9').contains(&x);
let letter_num_dash = |x| letter_or_num(x) || (x == b'-');
if !letter(*first_byte) {
return Err(report!(NameReadError::WrongFirstByte))
.attach_printable_lazy(|| format!("name thus far: '{}'", name_so_far(labels)))
.attach_printable_lazy(|| {
format!("bad label: '{}'", String::from_utf8_lossy(&label_bytes))
});
}
if !letter_or_num(*last_byte) {
return Err(report!(NameReadError::WrongLastByte))
.attach_printable_lazy(|| format!("name thus far: '{}'", name_so_far(labels)))
.attach_printable_lazy(|| {
format!("bad label: '{}'", String::from_utf8_lossy(&label_bytes))
});
}
if label_bytes.iter().any(|x| !letter_num_dash(*x)) {
return Err(report!(NameReadError::WrongInnerByte))
.attach_printable_lazy(|| format!("name thus far: '{}'", name_so_far(labels)))
.attach_printable_lazy(|| {
format!("bad label: '{}'", String::from_utf8_lossy(&label_bytes))
});
}
let label = label_bytes.into_iter().map(|x| x as char).collect();
labels.push(Label(ArcIntern::new(label)));
}
Ok(Name { labels })
}
/// Write a name out to the given buffer.
///
/// This will try as hard as it can to write the value to the given buffer. If an
/// error occurs, you may end up with partially-written data, so if you're worried
/// about that you should be careful to mark where you started in the output buffer.
pub fn write<B: BufMut>(&self, buffer: &mut B) -> error_stack::Result<(), NameWriteError> {
for &Label(ref label) in self.labels.iter() {
let bytes = label.as_bytes();
if buffer.remaining_mut() < (bytes.len() + 1) {
return Err(report!(NameWriteError::NoRoomForLabel))
.attach_printable_lazy(|| format!("Writing name {self}"))
.attach_printable_lazy(|| format!("For label {label}"));
}
if bytes.is_empty() || bytes.len() > 63 {
return Err(report!(NameWriteError::IllegalLabel))
.attach_printable_lazy(|| format!("Writing name {self}"))
.attach_printable_lazy(|| format!("For label {label}"));
}
buffer.put_u8(bytes.len() as u8);
buffer.put_slice(bytes);
}
if !buffer.has_remaining_mut() {
return Err(report!(NameWriteError::NoRoomForNull))
.attach_printable_lazy(|| format!("Writing name {self}"));
}
buffer.put_u8(0);
Ok(())
}
}
#[derive(Debug, Error)]
pub enum NameWriteError {
#[error("Not enough room to write label in name")]
NoRoomForLabel,
#[error("Ran out of room writing the terminating NULL for the name")]
NoRoomForNull,
#[error("Internal error: Illegal label (this shouldn't happen)")]
IllegalLabel,
}
#[derive(Debug)]
pub struct ArbitraryDomainNameSpecifications {
total_length_range: Range<usize>,
label_length_range: Range<usize>,
number_labels: Range<usize>,
}
impl Default for ArbitraryDomainNameSpecifications {
fn default() -> Self {
ArbitraryDomainNameSpecifications {
total_length_range: 5..256,
label_length_range: 1..64,
number_labels: 2..5,
}
}
}
impl Arbitrary for Name {
type Parameters = ArbitraryDomainNameSpecifications;
type Strategy = BoxedStrategy<Name>;
fn arbitrary_with(spec: Self::Parameters) -> Self::Strategy {
(spec.number_labels.clone(), spec.total_length_range.clone())
.prop_flat_map(move |(mut labels, mut total_length)| {
// we need to make sure that our total length and our label count are,
// at the very minimum, compatible. If they're not, we'll need to adjust
// them.
//
// in general, we prefer to update our number of labels, if we can, but
// only to the extent that the labels count is at least the minimum of
// our input specification. if we try to go below that, we increase total
// length as required. If we're forced into a place where we need to take
// the label length below it's minimum and/or the total_length over its
// maximum, we just give up and panic.
//
// Note that the minimum length of n labels is (n * label_minimum) + n - 1.
// Consider, for example, a label_minimum of 1 and a label length of 3.
// A minimum string is "a.b.c", which is (3 * 1) + 3 - 1 = 5 characters
// long.
//
// This loop does the first part, lowering the number of labels until we
// either get below the total length or reach the minimum number of labels.
while (labels * spec.label_length_range.start) + (labels - 1) > total_length {
if labels == spec.number_labels.start {
break;
} else {
labels -= 1;
}
}
// At this point, if it's not right, we just set it to be right.
if (labels * spec.total_length_range.start) + (labels - 1) > total_length {
total_length = (labels * spec.total_length_range.start) + (labels - 1);
}
// And if this takes us over our limit, just panic.
if total_length >= spec.total_length_range.end {
panic!("Unresolvable generation condition; couldn't resolve label count {} with total_length {}, with specification {:?}", labels, total_length, spec);
}
proptest::collection::vec(Label::arbitrary_with(spec.label_length_range.clone()), labels)
.prop_map(|labels| Name{ labels })
}).boxed()
}
}
impl Arbitrary for Label {
type Parameters = Range<usize>;
type Strategy = BoxedStrategy<Self>;
fn arbitrary() -> Self::Strategy {
Self::arbitrary_with(1..64)
}
fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
args.prop_flat_map(|length| {
let first = char_selector(&['a'..='z', 'A'..='Z']);
let middle = char_selector(&['a'..='z', 'A'..='Z', '-'..='-', '0'..='9']);
let last = char_selector(&['a'..='z', 'A'..='Z', '0'..='9']);
match length {
0 => panic!("Should not be able to generate a label of length 0"),
1 => first.prop_map(|x| x.into()).boxed(),
2 => (first, last).prop_map(|(a, b)| format!("{a}{b}")).boxed(),
_ => (first, proptest::collection::vec(middle, length - 2), last)
.prop_map(move |(first, middle, last)| {
let mut result = String::with_capacity(length);
result.push(first);
for c in middle.iter() {
result.push(*c);
}
result.push(last);
result
})
.boxed(),
}
})
.prop_map(|x| Label(ArcIntern::new(x)))
.boxed()
}
}
fn char_selector<'a>(ranges: &'a [RangeInclusive<char>]) -> CharStrategy<'a> {
CharStrategy::new(
Cow::Borrowed(&[]),
Cow::Borrowed(&[]),
Cow::Borrowed(ranges),
)
}
proptest::proptest! {
#[test]
fn any_random_names_parses(name: Name) {
let name_str = name.to_string();
let new_name = Name::from_str(&name_str).expect("can re-parse name");
assert_eq!(name, new_name);
}
#[test]
fn any_random_name_roundtrips(name: Name) {
let mut write_buffer = bytes::BytesMut::with_capacity(512);
name.write(&mut write_buffer).expect("can write name");
let mut read_buffer = write_buffer.freeze();
let new_name = Name::read(&mut read_buffer).expect("can read name");
assert_eq!(name, new_name);
}
}
#[test]
fn illegal_names_generate_errors() {
assert!(Name::from_str("").is_err());
assert!(Name::from_str(".").is_err());
assert!(Name::from_str(".com").is_err());
assert!(Name::from_str("com.").is_err());
assert!(Name::from_str("9.com").is_err());
assert!(Name::from_str("foo-.com").is_err());
assert!(Name::from_str("-foo.com").is_err());
assert!(Name::from_str(".foo.com").is_err());
assert!(Name::from_str("fo*o.com").is_err());
assert!(Name::from_str(
"foo.abcdefghiabcdefghiabcdefghiabcdefghiabcdefghiabcdefghijjjjjjabcdefghij.com"
)
.is_err());
assert!(Name::from_str("abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij").is_err());
}
#[test]
fn names_ignore_case() {
assert_eq!(
Name::from_str("UHSURE.COM").unwrap(),
Name::from_str("uhsure.com").unwrap()
);
}