177 lines
5.2 KiB
Rust
177 lines
5.2 KiB
Rust
use crate::network::resolver::name::Name;
|
|
use std::collections::hash_map::Entry;
|
|
use std::collections::{HashMap, HashSet};
|
|
use std::net::IpAddr;
|
|
use tokio::time::{Duration, Instant};
|
|
|
|
pub struct ResolutionTable {
|
|
inner: HashMap<Name, Vec<Resolution>>,
|
|
}
|
|
|
|
struct Resolution {
|
|
result: IpAddr,
|
|
expiration: Instant,
|
|
}
|
|
|
|
impl Default for ResolutionTable {
|
|
fn default() -> Self {
|
|
ResolutionTable::new()
|
|
}
|
|
}
|
|
|
|
impl ResolutionTable {
|
|
/// Generate a new, empty resolution table to use in a new DNS implementation,
|
|
/// or shared by a bunch of them.
|
|
pub fn new() -> Self {
|
|
ResolutionTable {
|
|
inner: HashMap::new(),
|
|
}
|
|
}
|
|
|
|
/// Clean the table of expired entries.
|
|
pub fn garbage_collect(&mut self) {
|
|
let now = Instant::now();
|
|
|
|
self.inner.retain(|_, items| {
|
|
items.retain(|x| x.expiration > now);
|
|
!items.is_empty()
|
|
});
|
|
}
|
|
|
|
/// Add a new entry to the resolution table, with a TTL on it.
|
|
pub fn add_entry(&mut self, name: Name, maps_to: IpAddr, ttl: Duration) {
|
|
let now = Instant::now();
|
|
|
|
let new_entry = Resolution {
|
|
result: maps_to,
|
|
expiration: now + ttl,
|
|
};
|
|
|
|
match self.inner.entry(name) {
|
|
Entry::Vacant(vac) => {
|
|
vac.insert(vec![new_entry]);
|
|
}
|
|
|
|
Entry::Occupied(mut occ) => {
|
|
occ.get_mut().push(new_entry);
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Look up an entry in the resolution map. This will only return
|
|
/// unexpired items.
|
|
pub fn lookup(&mut self, name: &Name) -> HashSet<IpAddr> {
|
|
let mut result = HashSet::new();
|
|
let now = Instant::now();
|
|
|
|
if let Some(entry) = self.inner.get_mut(name) {
|
|
entry.retain(|x| {
|
|
let retain = x.expiration > now;
|
|
|
|
if retain {
|
|
result.insert(x.result);
|
|
}
|
|
|
|
retain
|
|
});
|
|
}
|
|
|
|
result
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
use std::net::{Ipv4Addr, Ipv6Addr};
|
|
#[cfg(test)]
|
|
use std::str::FromStr;
|
|
|
|
#[test]
fn empty_set_gets_fail() {
    // A freshly defaulted table holds no resolutions, so every lookup must
    // come back with an empty address set.
    let mut table = ResolutionTable::default();
    for label in &["foo", "bar"] {
        let name = Name::from_str(label).unwrap();
        assert!(table.lookup(&name).is_empty());
    }
}
|
|
|
|
#[test]
fn basic_lookups() {
    // Names resolve to exactly the addresses registered for them, and names
    // that were never registered resolve to nothing.
    let parse = |label: &str| Name::from_str(label).unwrap();
    let (foo, bar, baz) = (parse("foo"), parse("bar"), parse("baz"));
    let localhost = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
    let other = IpAddr::V6(Ipv6Addr::from_str("2001:18e8:2:e::11e").unwrap());
    let ttl = Duration::from_secs(10000000000);

    let mut table = ResolutionTable::new();
    table.add_entry(foo.clone(), localhost, ttl);
    table.add_entry(bar.clone(), localhost, ttl);
    table.add_entry(bar.clone(), other, ttl);

    let foo_addrs = table.lookup(&foo);
    assert_eq!(1, foo_addrs.len());
    assert!(foo_addrs.contains(&localhost));
    assert!(!foo_addrs.contains(&other));

    let bar_addrs = table.lookup(&bar);
    assert_eq!(2, bar_addrs.len());
    assert!(bar_addrs.contains(&localhost) && bar_addrs.contains(&other));

    assert!(table.lookup(&baz).is_empty());
}
|
|
|
|
#[test]
fn lookup_cleans_up() {
    let mut table = ResolutionTable::new();
    let foo = Name::from_str("foo").unwrap();
    let localhost = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
    let other = IpAddr::V6(Ipv6Addr::from_str("2001:18e8:2:e::11e").unwrap());
    let short_time = Duration::from_millis(100);
    let long_time = Duration::from_secs(10000000000);

    table.add_entry(foo.clone(), localhost, long_time);
    table.add_entry(foo.clone(), other, short_time);
    // Wait until the short-lived `other` record has definitely expired.
    let wait_until = Instant::now() + (2 * short_time);
    while Instant::now() < wait_until {
        std::thread::sleep(short_time);
    }

    let found = table.lookup(&foo);
    assert_eq!(1, found.len());
    assert!(found.contains(&localhost));
    assert!(!found.contains(&other));

    // `lookup` must also have pruned the expired record from the table itself,
    // not merely filtered it out of the returned set.
    assert_eq!(1, table.inner.get(&foo).map_or(0, |records| records.len()));
}
|
|
|
|
#[test]
fn garbage_collection_works() {
    // After the short-lived record expires, a sweep leaves only the
    // long-lived one behind.
    let foo = Name::from_str("foo").unwrap();
    let localhost = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
    let other = IpAddr::V6(Ipv6Addr::from_str("2001:18e8:2:e::11e").unwrap());
    let short_ttl = Duration::from_millis(100);
    let long_ttl = Duration::from_secs(10000000000);

    let mut table = ResolutionTable::new();
    table.add_entry(foo.clone(), localhost, long_ttl);
    table.add_entry(foo.clone(), other, short_ttl);

    // Block until the short-lived record is guaranteed stale.
    let deadline = Instant::now() + short_ttl * 2;
    loop {
        if Instant::now() >= deadline {
            break;
        }
        std::thread::sleep(short_ttl);
    }

    table.garbage_collect();
    let survivors = table.inner.get(&foo).unwrap();
    assert_eq!(survivors.len(), 1);
}
|
|
|
|
#[test]
fn garbage_collection_clears_empties() {
    let foo = Name::from_str("foo").unwrap();
    let mut table = ResolutionTable::new();

    // A name whose record list is already empty is dropped outright.
    table.inner.insert(foo.clone(), Vec::new());
    table.garbage_collect();
    assert!(table.inner.is_empty());

    // A name whose only record has expired is dropped as well.
    let localhost = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
    let short_ttl = Duration::from_millis(100);
    table.add_entry(foo.clone(), localhost, short_ttl);
    let deadline = Instant::now() + short_ttl * 2;
    while deadline > Instant::now() {
        std::thread::sleep(short_ttl);
    }
    table.garbage_collect();
    assert!(table.inner.is_empty());
}
|