use error_stack::{report, ResultExt}; use futures::stream::{FuturesUnordered, StreamExt}; use hickory_resolver::error::ResolveError; use hickory_resolver::name_server::ConnectionProvider; use hickory_resolver::AsyncResolver; use std::collections::HashSet; use std::fmt; use std::net::{AddrParseError, IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::str::FromStr; use thiserror::Error; use tokio::net::TcpStream; pub enum Host { IPv4(Ipv4Addr), IPv6(Ipv6Addr), Hostname(String), } #[derive(Debug, Error)] pub enum HostParseError { #[error("Could not parse IPv6 address {address:?}: {error}")] CouldNotParseIPv6 { address: String, error: AddrParseError, }, #[error("Invalid hostname {hostname:?}")] InvalidHostname { hostname: String }, } impl fmt::Display for Host { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Host::IPv4(x) => x.fmt(f), Host::IPv6(x) => x.fmt(f), Host::Hostname(x) => x.fmt(f), } } } impl FromStr for Host { type Err = HostParseError; fn from_str(s: &str) -> Result { if let Ok(addr) = Ipv4Addr::from_str(s) { return Ok(Host::IPv4(addr)); } if let Ok(addr) = Ipv6Addr::from_str(s) { return Ok(Host::IPv6(addr)); } if s.starts_with('[') && s.ends_with(']') { match s.trim_start_matches('[').trim_end_matches(']').parse() { Ok(x) => return Ok(Host::IPv6(x)), Err(e) => { return Err(HostParseError::CouldNotParseIPv6 { address: s.to_string(), error: e, }) } } } if hostname_validator::is_valid(s) { return Ok(Host::Hostname(s.to_string())); } Err(HostParseError::InvalidHostname { hostname: s.to_string(), }) } } #[derive(Debug, Error)] pub enum ConnectionError { #[error("Failed to resolve host: {error}")] ResolveError { #[from] error: ResolveError, }, #[error("No valid IP addresses found")] NoAddresses, #[error("Error connecting to host: {error}")] ConnectionError { #[from] error: std::io::Error, }, } impl Host { /// Resolve this host address to a set of underlying IP addresses. /// /// It is possible that the set of addresses provided may be empty, if the /// address properly resolves (as in, we get a good DNS response) but there /// are no relevant records for us to use for IPv4 or IPv6 connections. There /// is also no guarantee that the host will have both IPv4 and IPv6 addresses, /// so you may only see one or the other. pub async fn resolve( &self, resolver: &AsyncResolver

, ) -> Result, ResolveError> { match self { Host::IPv4(addr) => Ok(HashSet::from([IpAddr::V4(*addr)])), Host::IPv6(addr) => Ok(HashSet::from([IpAddr::V6(*addr)])), Host::Hostname(name) => { let resolve_result = resolver.lookup_ip(name).await?; let possibilities = resolve_result.iter().collect(); Ok(possibilities) } } } /// Connect to this host and port. /// /// This routine will attempt to connect to every address provided by the /// resolver, and return the first successful connection. If all of the /// connections fail, it will return the first error it receives. This routine /// will also return an error if there are no addresses to connect to (which /// can happen in cases in which [`Host::resolve`] would return an empty set. pub async fn connect( &self, resolver: &AsyncResolver

, port: u16, ) -> error_stack::Result { let addresses = self .resolve(resolver) .await .map_err(ConnectionError::from) .attach_printable_lazy(|| format!("target address {}", self))?; let mut connectors = FuturesUnordered::new(); for address in addresses.into_iter() { let connect_future = TcpStream::connect(SocketAddr::new(address, port)); connectors.push(connect_future); } let mut error = None; while let Some(result) = connectors.next().await { match result { Err(e) if error.is_none() => error = Some(e), Err(_) => {} Ok(v) => return Ok(v), } } let final_error = if let Some(e) = error { ConnectionError::ConnectionError { error: e } } else { ConnectionError::NoAddresses }; Err(report!(final_error)).attach_printable_lazy(|| format!("target address {}", self)) } } #[test] fn ip4_hosts_work() { assert!( matches!(Host::from_str("127.0.0.1"), Ok(Host::IPv4(addr)) if addr == Ipv4Addr::new(127, 0, 0, 1)) ); } #[test] fn bare_ip6_hosts_work() { assert!(matches!( Host::from_str("2001:0db8:85a3:0000:0000:8a2e:0370:7334"), Ok(Host::IPv6(_)) )); assert!(matches!(Host::from_str("2001:db8::1"), Ok(Host::IPv6(_)))); assert!(matches!(Host::from_str("2001:DB8::1"), Ok(Host::IPv6(_)))); assert!(matches!(Host::from_str("::1"), Ok(Host::IPv6(_)))); assert!(matches!(Host::from_str("::"), Ok(Host::IPv6(_)))); } #[test] fn wrapped_ip6_hosts_work() { assert!(matches!( Host::from_str("[2001:0db8:85a3:0000:0000:8a2e:0370:7334]"), Ok(Host::IPv6(_)) )); assert!(matches!(Host::from_str("[2001:db8::1]"), Ok(Host::IPv6(_)))); assert!(matches!(Host::from_str("[2001:DB8::1]"), Ok(Host::IPv6(_)))); assert!(matches!(Host::from_str("[::1]"), Ok(Host::IPv6(_)))); assert!(matches!(Host::from_str("[::]"), Ok(Host::IPv6(_)))); } #[test] fn valid_domains_work() { assert!(matches!( Host::from_str("uhsure.com"), Ok(Host::Hostname(_)) )); assert!(matches!( Host::from_str("www.cs.indiana.edu"), Ok(Host::Hostname(_)) )); } #[test] fn invalid_inputs_fail() { assert!(matches!( Host::from_str("[uhsure.com]"), Err(HostParseError::CouldNotParseIPv6 { .. }) )); assert!(matches!( Host::from_str("-uhsure.com"), Err(HostParseError::InvalidHostname { .. }) )); }