use crate::network::resolver::name::Name; use std::collections::hash_map::Entry; use std::collections::{HashMap, HashSet}; use std::net::IpAddr; use tokio::time::{Duration, Instant}; pub struct ResolutionTable { inner: HashMap>, } struct Resolution { result: IpAddr, expiration: Instant, } impl Default for ResolutionTable { fn default() -> Self { ResolutionTable::new() } } impl ResolutionTable { /// Generate a new, empty resolution table to use in a new DNS implementation, /// or shared by a bunch of them. pub fn new() -> Self { ResolutionTable { inner: HashMap::new(), } } /// Clean the table of expired entries. pub fn garbage_collect(&mut self) { let now = Instant::now(); self.inner.retain(|_, items| { items.retain(|x| x.expiration > now); !items.is_empty() }); } /// Add a new entry to the resolution table, with a TTL on it. pub fn add_entry(&mut self, name: Name, maps_to: IpAddr, ttl: Duration) { let now = Instant::now(); let new_entry = Resolution { result: maps_to, expiration: now + ttl, }; match self.inner.entry(name) { Entry::Vacant(vac) => { vac.insert(vec![new_entry]); } Entry::Occupied(mut occ) => { occ.get_mut().push(new_entry); } } } /// Look up an entry in the resolution map. This will only return /// unexpired items. pub fn lookup(&mut self, name: &Name) -> HashSet { let mut result = HashSet::new(); let now = Instant::now(); if let Some(entry) = self.inner.get_mut(name) { entry.retain(|x| { let retain = x.expiration > now; if retain { result.insert(x.result); } retain }); } result } } #[cfg(test)] use std::net::{Ipv4Addr, Ipv6Addr}; #[cfg(test)] use std::str::FromStr; #[test] fn empty_set_gets_fail() { let mut empty = ResolutionTable::default(); assert!(empty.lookup(&Name::from_str("foo").unwrap()).is_empty()); assert!(empty.lookup(&Name::from_str("bar").unwrap()).is_empty()); } #[test] fn basic_lookups() { let mut table = ResolutionTable::new(); let foo = Name::from_str("foo").unwrap(); let bar = Name::from_str("bar").unwrap(); let baz = Name::from_str("baz").unwrap(); let localhost = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)); let other = IpAddr::V6(Ipv6Addr::from_str("2001:18e8:2:e::11e").unwrap()); let long_time = Duration::from_secs(10000000000); table.add_entry(foo.clone(), localhost, long_time); table.add_entry(bar.clone(), localhost, long_time); table.add_entry(bar.clone(), other, long_time); assert_eq!(1, table.lookup(&foo).len()); assert_eq!(2, table.lookup(&bar).len()); assert!(table.lookup(&baz).is_empty()); assert!(table.lookup(&foo).contains(&localhost)); assert!(!table.lookup(&foo).contains(&other)); assert!(table.lookup(&bar).contains(&localhost)); assert!(table.lookup(&bar).contains(&other)); } #[test] fn lookup_cleans_up() { let mut table = ResolutionTable::new(); let foo = Name::from_str("foo").unwrap(); let localhost = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)); let other = IpAddr::V6(Ipv6Addr::from_str("2001:18e8:2:e::11e").unwrap()); let short_time = Duration::from_millis(100); let long_time = Duration::from_secs(10000000000); table.add_entry(foo.clone(), localhost, long_time); table.add_entry(foo.clone(), other, short_time); let wait_until = Instant::now() + (2 * short_time); while Instant::now() < wait_until { std::thread::sleep(short_time); } assert_eq!(1, table.lookup(&foo).len()); assert!(table.lookup(&foo).contains(&localhost)); assert!(!table.lookup(&foo).contains(&other)); } #[test] fn garbage_collection_works() { let mut table = ResolutionTable::new(); let foo = Name::from_str("foo").unwrap(); let localhost = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)); let other = IpAddr::V6(Ipv6Addr::from_str("2001:18e8:2:e::11e").unwrap()); let short_time = Duration::from_millis(100); let long_time = Duration::from_secs(10000000000); table.add_entry(foo.clone(), localhost, long_time); table.add_entry(foo.clone(), other, short_time); let wait_until = Instant::now() + (2 * short_time); while Instant::now() < wait_until { std::thread::sleep(short_time); } table.garbage_collect(); assert_eq!(1, table.inner.get(&foo).unwrap().len()); } #[test] fn garbage_collection_clears_empties() { let mut table = ResolutionTable::new(); let foo = Name::from_str("foo").unwrap(); table.inner.insert(foo.clone(), vec![]); table.garbage_collect(); assert!(table.inner.is_empty()); let localhost = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)); let short_time = Duration::from_millis(100); table.add_entry(foo.clone(), localhost, short_time); let wait_until = Instant::now() + (2 * short_time); while Instant::now() < wait_until { std::thread::sleep(short_time); } table.garbage_collect(); assert!(table.inner.is_empty()); }