Checkpoint with resolver tests including request/response.

This commit is contained in:
2025-05-03 13:50:16 -07:00
parent 31cd34d280
commit 9fe5b78962
20 changed files with 4012 additions and 1093 deletions

View File

@@ -0,0 +1,176 @@
use crate::network::resolver::name::Name;
use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet};
use std::net::IpAddr;
use tokio::time::{Duration, Instant};
pub struct ResolutionTable {
inner: HashMap<Name, Vec<Resolution>>,
}
struct Resolution {
result: IpAddr,
expiration: Instant,
}
impl Default for ResolutionTable {
fn default() -> Self {
ResolutionTable::new()
}
}
impl ResolutionTable {
/// Generate a new, empty resolution table to use in a new DNS implementation,
/// or shared by a bunch of them.
pub fn new() -> Self {
ResolutionTable {
inner: HashMap::new(),
}
}
/// Clean the table of expired entries.
pub fn garbage_collect(&mut self) {
let now = Instant::now();
self.inner.retain(|_, items| {
items.retain(|x| x.expiration > now);
!items.is_empty()
});
}
/// Add a new entry to the resolution table, with a TTL on it.
pub fn add_entry(&mut self, name: Name, maps_to: IpAddr, ttl: Duration) {
let now = Instant::now();
let new_entry = Resolution {
result: maps_to,
expiration: now + ttl,
};
match self.inner.entry(name) {
Entry::Vacant(vac) => {
vac.insert(vec![new_entry]);
}
Entry::Occupied(mut occ) => {
occ.get_mut().push(new_entry);
}
}
}
/// Look up an entry in the resolution map. This will only return
/// unexpired items.
pub fn lookup(&mut self, name: &Name) -> HashSet<IpAddr> {
let mut result = HashSet::new();
let now = Instant::now();
if let Some(entry) = self.inner.get_mut(name) {
entry.retain(|x| {
let retain = x.expiration > now;
if retain {
result.insert(x.result);
}
retain
});
}
result
}
}
#[cfg(test)]
use std::net::{Ipv4Addr, Ipv6Addr};
#[cfg(test)]
use std::str::FromStr;
#[test]
fn empty_set_gets_fail() {
let mut empty = ResolutionTable::default();
assert!(empty.lookup(&Name::from_str("foo").unwrap()).is_empty());
assert!(empty.lookup(&Name::from_str("bar").unwrap()).is_empty());
}
#[test]
fn basic_lookups() {
let mut table = ResolutionTable::new();
let foo = Name::from_str("foo").unwrap();
let bar = Name::from_str("bar").unwrap();
let baz = Name::from_str("baz").unwrap();
let localhost = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
let other = IpAddr::V6(Ipv6Addr::from_str("2001:18e8:2:e::11e").unwrap());
let long_time = Duration::from_secs(10000000000);
table.add_entry(foo.clone(), localhost, long_time);
table.add_entry(bar.clone(), localhost, long_time);
table.add_entry(bar.clone(), other, long_time);
assert_eq!(1, table.lookup(&foo).len());
assert_eq!(2, table.lookup(&bar).len());
assert!(table.lookup(&baz).is_empty());
assert!(table.lookup(&foo).contains(&localhost));
assert!(!table.lookup(&foo).contains(&other));
assert!(table.lookup(&bar).contains(&localhost));
assert!(table.lookup(&bar).contains(&other));
}
#[test]
fn lookup_cleans_up() {
let mut table = ResolutionTable::new();
let foo = Name::from_str("foo").unwrap();
let localhost = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
let other = IpAddr::V6(Ipv6Addr::from_str("2001:18e8:2:e::11e").unwrap());
let short_time = Duration::from_millis(100);
let long_time = Duration::from_secs(10000000000);
table.add_entry(foo.clone(), localhost, long_time);
table.add_entry(foo.clone(), other, short_time);
let wait_until = Instant::now() + (2 * short_time);
while Instant::now() < wait_until {
std::thread::sleep(short_time);
}
assert_eq!(1, table.lookup(&foo).len());
assert!(table.lookup(&foo).contains(&localhost));
assert!(!table.lookup(&foo).contains(&other));
}
#[test]
fn garbage_collection_works() {
let mut table = ResolutionTable::new();
let foo = Name::from_str("foo").unwrap();
let localhost = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
let other = IpAddr::V6(Ipv6Addr::from_str("2001:18e8:2:e::11e").unwrap());
let short_time = Duration::from_millis(100);
let long_time = Duration::from_secs(10000000000);
table.add_entry(foo.clone(), localhost, long_time);
table.add_entry(foo.clone(), other, short_time);
let wait_until = Instant::now() + (2 * short_time);
while Instant::now() < wait_until {
std::thread::sleep(short_time);
}
table.garbage_collect();
assert_eq!(1, table.inner.get(&foo).unwrap().len());
}
#[test]
fn garbage_collection_clears_empties() {
let mut table = ResolutionTable::new();
let foo = Name::from_str("foo").unwrap();
table.inner.insert(foo.clone(), vec![]);
table.garbage_collect();
assert!(table.inner.is_empty());
let localhost = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
let short_time = Duration::from_millis(100);
table.add_entry(foo.clone(), localhost, short_time);
let wait_until = Instant::now() + (2 * short_time);
while Instant::now() < wait_until {
std::thread::sleep(short_time);
}
table.garbage_collect();
assert!(table.inner.is_empty());
}