Checkpoint with resolver tests including request/response.
This commit is contained in:
505
src/network/resolver/name.rs
Normal file
505
src/network/resolver/name.rs
Normal file
@@ -0,0 +1,505 @@
|
||||
use bytes::{Buf, BufMut};
|
||||
use error_stack::{report, ResultExt};
|
||||
use internment::ArcIntern;
|
||||
use proptest::arbitrary::Arbitrary;
|
||||
use proptest::char::CharStrategy;
|
||||
use proptest::strategy::{BoxedStrategy, Strategy};
|
||||
use std::borrow::Cow;
|
||||
use std::fmt;
|
||||
use std::hash::Hash;
|
||||
use std::ops::{Range, RangeInclusive};
|
||||
use std::str::FromStr;
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Clone, Hash, PartialEq, Eq)]
|
||||
pub struct Name {
|
||||
labels: Vec<Label>,
|
||||
}
|
||||
|
||||
impl fmt::Debug for Name {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let mut first_section = true;
|
||||
|
||||
for &Label(ref label) in self.labels.iter() {
|
||||
if first_section {
|
||||
first_section = false;
|
||||
write!(f, "{}", label.as_str())?;
|
||||
} else {
|
||||
write!(f, ".{}", label.as_str())?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Name {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
<Self as fmt::Debug>::fmt(self, f)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Hash, Eq)]
|
||||
pub struct Label(ArcIntern<String>);
|
||||
|
||||
impl fmt::Debug for Label {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
self.0.as_str().fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Label {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
self.0.as_str().fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for Label {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.0.eq_ignore_ascii_case(other.0.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum NameParseError {
|
||||
#[error("Provided name '{name}' is too long ({observed_length} bytes); maximum length of a DNS name is 255 octets.")]
|
||||
NameTooLong {
|
||||
name: String,
|
||||
observed_length: usize,
|
||||
},
|
||||
|
||||
#[error("Provided name '{name}' contains illegal character '{illegal_character}'.")]
|
||||
NonAsciiCharacter {
|
||||
name: String,
|
||||
illegal_character: char,
|
||||
},
|
||||
|
||||
#[error("Provided name '{name}' contains an empty section/label, which isn't allowed.")]
|
||||
EmptyLabel { name: String },
|
||||
|
||||
#[error("Provided name '{name}' contains a label ('{label}') that is too long ({observed_length} letters, which is more than 63).")]
|
||||
LabelTooLong {
|
||||
observed_length: usize,
|
||||
name: String,
|
||||
label: String,
|
||||
},
|
||||
|
||||
#[error("Provided name '{name}' contains a label ('{label}') that begins with an illegal character; it must be a letter.")]
|
||||
LabelStartsWrong { name: String, label: String },
|
||||
|
||||
#[error("Provided name '{name}' contains a label ('{label}') that ends with an illegal character; it must be a letter or number.")]
|
||||
LabelEndsWrong { name: String, label: String },
|
||||
|
||||
#[error("Provided name '{name}' contains a label ('{label}') that contains a non-letter, non-number, and non-dash.")]
|
||||
IllegalInternalCharacter { name: String, label: String },
|
||||
}
|
||||
|
||||
impl FromStr for Name {
|
||||
type Err = NameParseError;
|
||||
|
||||
fn from_str(value: &str) -> Result<Self, Self::Err> {
|
||||
let observed_length = value.as_bytes().len();
|
||||
|
||||
if observed_length > 255 {
|
||||
return Err(NameParseError::NameTooLong {
|
||||
name: value.to_string(),
|
||||
observed_length,
|
||||
});
|
||||
}
|
||||
|
||||
for char in value.chars() {
|
||||
if !(char.is_ascii_alphanumeric() || char == '.' || char == '-') {
|
||||
return Err(NameParseError::NonAsciiCharacter {
|
||||
name: value.to_string(),
|
||||
illegal_character: char,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let mut labels = vec![];
|
||||
|
||||
for label_str in value.split('.') {
|
||||
if label_str.is_empty() {
|
||||
return Err(NameParseError::EmptyLabel {
|
||||
name: value.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if label_str.len() > 63 {
|
||||
return Err(NameParseError::LabelTooLong {
|
||||
name: value.to_string(),
|
||||
label: label_str.to_string(),
|
||||
observed_length: label_str.len(),
|
||||
});
|
||||
}
|
||||
|
||||
let letter = |x| ('a'..='z').contains(&x) || ('A'..='Z').contains(&x);
|
||||
let letter_or_num = |x| letter(x) || ('0'..='9').contains(&x);
|
||||
let letter_num_dash = |x| letter_or_num(x) || (x == '-');
|
||||
|
||||
if !label_str.starts_with(letter) {
|
||||
return Err(NameParseError::LabelStartsWrong {
|
||||
name: value.to_string(),
|
||||
label: label_str.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if !label_str.ends_with(letter_or_num) {
|
||||
return Err(NameParseError::LabelEndsWrong {
|
||||
name: value.to_string(),
|
||||
label: label_str.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if label_str.contains(|x| !letter_num_dash(x)) {
|
||||
return Err(NameParseError::IllegalInternalCharacter {
|
||||
name: value.to_string(),
|
||||
label: label_str.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// RFC 1035 says that all domain names are case-insensitive. We
|
||||
// arbitrarily normalize to lowercase here, because it shouldn't
|
||||
// matter to anyone.
|
||||
labels.push(Label(ArcIntern::new(label_str.to_string())));
|
||||
}
|
||||
|
||||
Ok(Name { labels })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum NameReadError {
|
||||
#[error("Could not read name field out of an empty buffer.")]
|
||||
EmptyBuffer,
|
||||
|
||||
#[error("Buffer truncated before we could read the last label")]
|
||||
TruncatedBuffer,
|
||||
|
||||
#[error("Provided label value too long; must be 63 octets or less.")]
|
||||
LabelTooLong,
|
||||
|
||||
#[error("Label truncated while reading. Broken stream?")]
|
||||
LabelTruncated,
|
||||
|
||||
#[error("Label starts with an illegal character (must be [A-Za-z])")]
|
||||
WrongFirstByte,
|
||||
|
||||
#[error("Label ends with an illegal character (must be [A-Za-z0-9]")]
|
||||
WrongLastByte,
|
||||
|
||||
#[error("Label contains an illegal character (must be [A-Za-z0-9] or a dash)")]
|
||||
WrongInnerByte,
|
||||
}
|
||||
|
||||
impl Name {
|
||||
/// Read a name frm a record, or return an error describing what went wrong.
|
||||
///
|
||||
/// This function will advance the read pointer on the record. If the result is
|
||||
/// an error, the read pointer is not guaranteed to be in a good place, so you
|
||||
/// may need to manage that externally if you want to implement some sort of
|
||||
/// "try" functionality.
|
||||
pub fn read<B: Buf>(buffer: &mut B) -> error_stack::Result<Name, NameReadError> {
|
||||
let mut labels = Vec::new();
|
||||
let name_so_far = |labels: Vec<Label>| {
|
||||
let mut result = String::new();
|
||||
|
||||
for Label(label) in labels.into_iter() {
|
||||
if result.is_empty() {
|
||||
result.push_str(label.as_str());
|
||||
} else {
|
||||
result.push('.');
|
||||
result.push_str(label.as_str());
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
};
|
||||
|
||||
loop {
|
||||
if !buffer.has_remaining() && labels.is_empty() {
|
||||
return Err(report!(NameReadError::EmptyBuffer));
|
||||
}
|
||||
|
||||
if !buffer.has_remaining() {
|
||||
return Err(report!(NameReadError::TruncatedBuffer))
|
||||
.attach_printable_lazy(|| format!("name thus far: '{}'", name_so_far(labels)));
|
||||
}
|
||||
|
||||
let label_octet_length = buffer.get_u8() as usize;
|
||||
if label_octet_length == 0 {
|
||||
break;
|
||||
}
|
||||
|
||||
if label_octet_length > 63 {
|
||||
return Err(report!(NameReadError::LabelTooLong)).attach_printable_lazy(|| {
|
||||
format!(
|
||||
"label length too big; max is supposed to be 63, saw {label_octet_length}"
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
if !buffer.remaining() < label_octet_length {
|
||||
let remaining = buffer.copy_to_bytes(buffer.remaining());
|
||||
let partial = String::from_utf8_lossy(&remaining).to_string();
|
||||
|
||||
return Err(report!(NameReadError::LabelTruncated))
|
||||
.attach_printable_lazy(|| format!("name thus far: '{}'", name_so_far(labels)))
|
||||
.attach_printable_lazy(|| {
|
||||
format!(
|
||||
"expected {} octets, but only found {}",
|
||||
label_octet_length,
|
||||
buffer.remaining()
|
||||
)
|
||||
})
|
||||
.attach_printable_lazy(|| format!("partial read: '{partial}'"));
|
||||
}
|
||||
|
||||
let label_bytes = buffer.copy_to_bytes(label_octet_length);
|
||||
let Some(first_byte) = label_bytes.first() else {
|
||||
panic!(
|
||||
"INTERNAL ERROR: Should have at least one byte, we checked this previously."
|
||||
);
|
||||
};
|
||||
let Some(last_byte) = label_bytes.last() else {
|
||||
panic!(
|
||||
"INTERNAL ERROR: Should have at least one byte, we checked this previously."
|
||||
);
|
||||
};
|
||||
|
||||
let letter = |x| (b'a'..=b'z').contains(&x) || (b'A'..=b'Z').contains(&x);
|
||||
let letter_or_num = |x| letter(x) || (b'0'..=b'9').contains(&x);
|
||||
let letter_num_dash = |x| letter_or_num(x) || (x == b'-');
|
||||
|
||||
if !letter(*first_byte) {
|
||||
return Err(report!(NameReadError::WrongFirstByte))
|
||||
.attach_printable_lazy(|| format!("name thus far: '{}'", name_so_far(labels)))
|
||||
.attach_printable_lazy(|| {
|
||||
format!("bad label: '{}'", String::from_utf8_lossy(&label_bytes))
|
||||
});
|
||||
}
|
||||
|
||||
if !letter_or_num(*last_byte) {
|
||||
return Err(report!(NameReadError::WrongLastByte))
|
||||
.attach_printable_lazy(|| format!("name thus far: '{}'", name_so_far(labels)))
|
||||
.attach_printable_lazy(|| {
|
||||
format!("bad label: '{}'", String::from_utf8_lossy(&label_bytes))
|
||||
});
|
||||
}
|
||||
|
||||
if label_bytes.iter().any(|x| !letter_num_dash(*x)) {
|
||||
return Err(report!(NameReadError::WrongInnerByte))
|
||||
.attach_printable_lazy(|| format!("name thus far: '{}'", name_so_far(labels)))
|
||||
.attach_printable_lazy(|| {
|
||||
format!("bad label: '{}'", String::from_utf8_lossy(&label_bytes))
|
||||
});
|
||||
}
|
||||
|
||||
let label = label_bytes.into_iter().map(|x| x as char).collect();
|
||||
labels.push(Label(ArcIntern::new(label)));
|
||||
}
|
||||
|
||||
Ok(Name { labels })
|
||||
}
|
||||
|
||||
/// Write a name out to the given buffer.
|
||||
///
|
||||
/// This will try as hard as it can to write the value to the given buffer. If an
|
||||
/// error occurs, you may end up with partially-written data, so if you're worried
|
||||
/// about that you should be careful to mark where you started in the output buffer.
|
||||
pub fn write<B: BufMut>(&self, buffer: &mut B) -> error_stack::Result<(), NameWriteError> {
|
||||
for &Label(ref label) in self.labels.iter() {
|
||||
let bytes = label.as_bytes();
|
||||
|
||||
if buffer.remaining_mut() < (bytes.len() + 1) {
|
||||
return Err(report!(NameWriteError::NoRoomForLabel))
|
||||
.attach_printable_lazy(|| format!("Writing name {self}"))
|
||||
.attach_printable_lazy(|| format!("For label {label}"));
|
||||
}
|
||||
|
||||
if bytes.is_empty() || bytes.len() > 63 {
|
||||
return Err(report!(NameWriteError::IllegalLabel))
|
||||
.attach_printable_lazy(|| format!("Writing name {self}"))
|
||||
.attach_printable_lazy(|| format!("For label {label}"));
|
||||
}
|
||||
|
||||
buffer.put_u8(bytes.len() as u8);
|
||||
buffer.put_slice(bytes);
|
||||
}
|
||||
|
||||
if !buffer.has_remaining_mut() {
|
||||
return Err(report!(NameWriteError::NoRoomForNull))
|
||||
.attach_printable_lazy(|| format!("Writing name {self}"));
|
||||
}
|
||||
buffer.put_u8(0);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum NameWriteError {
|
||||
#[error("Not enough room to write label in name")]
|
||||
NoRoomForLabel,
|
||||
|
||||
#[error("Ran out of room writing the terminating NULL for the name")]
|
||||
NoRoomForNull,
|
||||
|
||||
#[error("Internal error: Illegal label (this shouldn't happen)")]
|
||||
IllegalLabel,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ArbitraryDomainNameSpecifications {
|
||||
total_length_range: Range<usize>,
|
||||
label_length_range: Range<usize>,
|
||||
number_labels: Range<usize>,
|
||||
}
|
||||
|
||||
impl Default for ArbitraryDomainNameSpecifications {
|
||||
fn default() -> Self {
|
||||
ArbitraryDomainNameSpecifications {
|
||||
total_length_range: 5..256,
|
||||
label_length_range: 1..64,
|
||||
number_labels: 2..5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Arbitrary for Name {
|
||||
type Parameters = ArbitraryDomainNameSpecifications;
|
||||
type Strategy = BoxedStrategy<Name>;
|
||||
|
||||
fn arbitrary_with(spec: Self::Parameters) -> Self::Strategy {
|
||||
(spec.number_labels.clone(), spec.total_length_range.clone())
|
||||
.prop_flat_map(move |(mut labels, mut total_length)| {
|
||||
// we need to make sure that our total length and our label count are,
|
||||
// at the very minimum, compatible. If they're not, we'll need to adjust
|
||||
// them.
|
||||
//
|
||||
// in general, we prefer to update our number of labels, if we can, but
|
||||
// only to the extent that the labels count is at least the minimum of
|
||||
// our input specification. if we try to go below that, we increase total
|
||||
// length as required. If we're forced into a place where we need to take
|
||||
// the label length below it's minimum and/or the total_length over its
|
||||
// maximum, we just give up and panic.
|
||||
//
|
||||
// Note that the minimum length of n labels is (n * label_minimum) + n - 1.
|
||||
// Consider, for example, a label_minimum of 1 and a label length of 3.
|
||||
// A minimum string is "a.b.c", which is (3 * 1) + 3 - 1 = 5 characters
|
||||
// long.
|
||||
//
|
||||
// This loop does the first part, lowering the number of labels until we
|
||||
// either get below the total length or reach the minimum number of labels.
|
||||
while (labels * spec.label_length_range.start) + (labels - 1) > total_length {
|
||||
if labels == spec.number_labels.start {
|
||||
break;
|
||||
} else {
|
||||
labels -= 1;
|
||||
}
|
||||
}
|
||||
|
||||
// At this point, if it's not right, we just set it to be right.
|
||||
if (labels * spec.total_length_range.start) + (labels - 1) > total_length {
|
||||
total_length = (labels * spec.total_length_range.start) + (labels - 1);
|
||||
}
|
||||
|
||||
// And if this takes us over our limit, just panic.
|
||||
if total_length >= spec.total_length_range.end {
|
||||
panic!("Unresolvable generation condition; couldn't resolve label count {} with total_length {}, with specification {:?}", labels, total_length, spec);
|
||||
}
|
||||
|
||||
proptest::collection::vec(Label::arbitrary_with(spec.label_length_range.clone()), labels)
|
||||
.prop_map(|labels| Name{ labels })
|
||||
}).boxed()
|
||||
}
|
||||
}
|
||||
|
||||
impl Arbitrary for Label {
|
||||
type Parameters = Range<usize>;
|
||||
type Strategy = BoxedStrategy<Self>;
|
||||
|
||||
fn arbitrary() -> Self::Strategy {
|
||||
Self::arbitrary_with(1..64)
|
||||
}
|
||||
|
||||
fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
|
||||
args.prop_flat_map(|length| {
|
||||
let first = char_selector(&['a'..='z', 'A'..='Z']);
|
||||
let middle = char_selector(&['a'..='z', 'A'..='Z', '-'..='-', '0'..='9']);
|
||||
let last = char_selector(&['a'..='z', 'A'..='Z', '0'..='9']);
|
||||
|
||||
match length {
|
||||
0 => panic!("Should not be able to generate a label of length 0"),
|
||||
1 => first.prop_map(|x| x.into()).boxed(),
|
||||
2 => (first, last).prop_map(|(a, b)| format!("{a}{b}")).boxed(),
|
||||
_ => (first, proptest::collection::vec(middle, length - 2), last)
|
||||
.prop_map(move |(first, middle, last)| {
|
||||
let mut result = String::with_capacity(length);
|
||||
result.push(first);
|
||||
for c in middle.iter() {
|
||||
result.push(*c);
|
||||
}
|
||||
result.push(last);
|
||||
result
|
||||
})
|
||||
.boxed(),
|
||||
}
|
||||
})
|
||||
.prop_map(|x| Label(ArcIntern::new(x)))
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
fn char_selector<'a>(ranges: &'a [RangeInclusive<char>]) -> CharStrategy<'a> {
|
||||
CharStrategy::new(
|
||||
Cow::Borrowed(&[]),
|
||||
Cow::Borrowed(&[]),
|
||||
Cow::Borrowed(ranges),
|
||||
)
|
||||
}
|
||||
|
||||
proptest::proptest! {
|
||||
#[test]
|
||||
fn any_random_names_parses(name: Name) {
|
||||
let name_str = name.to_string();
|
||||
let new_name = Name::from_str(&name_str).expect("can re-parse name");
|
||||
assert_eq!(name, new_name);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn any_random_name_roundtrips(name: Name) {
|
||||
let mut write_buffer = bytes::BytesMut::with_capacity(512);
|
||||
name.write(&mut write_buffer).expect("can write name");
|
||||
let mut read_buffer = write_buffer.freeze();
|
||||
let new_name = Name::read(&mut read_buffer).expect("can read name");
|
||||
assert_eq!(name, new_name);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn illegal_names_generate_errors() {
|
||||
assert!(Name::from_str("").is_err());
|
||||
assert!(Name::from_str(".").is_err());
|
||||
assert!(Name::from_str(".com").is_err());
|
||||
assert!(Name::from_str("com.").is_err());
|
||||
assert!(Name::from_str("9.com").is_err());
|
||||
assert!(Name::from_str("foo-.com").is_err());
|
||||
assert!(Name::from_str("-foo.com").is_err());
|
||||
assert!(Name::from_str(".foo.com").is_err());
|
||||
assert!(Name::from_str("fo*o.com").is_err());
|
||||
assert!(Name::from_str(
|
||||
"foo.abcdefghiabcdefghiabcdefghiabcdefghiabcdefghiabcdefghijjjjjjabcdefghij.com"
|
||||
)
|
||||
.is_err());
|
||||
assert!(Name::from_str("abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij.abcdefghij").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn names_ignore_case() {
|
||||
assert_eq!(
|
||||
Name::from_str("UHSURE.COM").unwrap(),
|
||||
Name::from_str("uhsure.com").unwrap()
|
||||
);
|
||||
}
|
||||
Reference in New Issue
Block a user