212 lines
6.4 KiB
Rust
212 lines
6.4 KiB
Rust
use error_stack::{report, ResultExt};
|
|
use futures::stream::{FuturesUnordered, StreamExt};
|
|
use hickory_resolver::error::ResolveError;
|
|
use hickory_resolver::name_server::ConnectionProvider;
|
|
use hickory_resolver::AsyncResolver;
|
|
use std::collections::HashSet;
|
|
use std::fmt;
|
|
use std::net::{AddrParseError, IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
|
|
use std::str::FromStr;
|
|
use thiserror::Error;
|
|
use tokio::net::TcpStream;
|
|
|
|
/// A network host to connect to: a literal IPv4 address, a literal IPv6
/// address, or a hostname that must be resolved before use.
///
/// Values are normally produced by parsing a string (see the `FromStr`
/// implementation for the accepted syntaxes).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Host {
    /// A literal IPv4 address, e.g. `127.0.0.1`.
    IPv4(Ipv4Addr),
    /// A literal IPv6 address, e.g. `::1` (given bare or as `[::1]`).
    IPv6(Ipv6Addr),
    /// A hostname that passed validation and still needs DNS resolution.
    Hostname(String),
}
|
|
|
|
#[derive(Debug, Error)]
|
|
pub enum HostParseError {
|
|
#[error("Could not parse IPv6 address {address:?}: {error}")]
|
|
CouldNotParseIPv6 {
|
|
address: String,
|
|
error: AddrParseError,
|
|
},
|
|
#[error("Invalid hostname {hostname:?}")]
|
|
InvalidHostname { hostname: String },
|
|
}
|
|
|
|
impl fmt::Display for Host {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
match self {
|
|
Host::IPv4(x) => x.fmt(f),
|
|
Host::IPv6(x) => x.fmt(f),
|
|
Host::Hostname(x) => x.fmt(f),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl FromStr for Host {
|
|
type Err = HostParseError;
|
|
|
|
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
|
if let Ok(addr) = Ipv4Addr::from_str(s) {
|
|
return Ok(Host::IPv4(addr));
|
|
}
|
|
|
|
if let Ok(addr) = Ipv6Addr::from_str(s) {
|
|
return Ok(Host::IPv6(addr));
|
|
}
|
|
|
|
if s.starts_with('[') && s.ends_with(']') {
|
|
match s.trim_start_matches('[').trim_end_matches(']').parse() {
|
|
Ok(x) => return Ok(Host::IPv6(x)),
|
|
Err(e) => {
|
|
return Err(HostParseError::CouldNotParseIPv6 {
|
|
address: s.to_string(),
|
|
error: e,
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
if hostname_validator::is_valid(s) {
|
|
return Ok(Host::Hostname(s.to_string()));
|
|
}
|
|
|
|
Err(HostParseError::InvalidHostname {
|
|
hostname: s.to_string(),
|
|
})
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Error)]
|
|
pub enum ConnectionError {
|
|
#[error("Failed to resolve host: {error}")]
|
|
ResolveError {
|
|
#[from]
|
|
error: ResolveError,
|
|
},
|
|
#[error("No valid IP addresses found")]
|
|
NoAddresses,
|
|
#[error("Error connecting to host: {error}")]
|
|
ConnectionError {
|
|
#[from]
|
|
error: std::io::Error,
|
|
},
|
|
}
|
|
|
|
impl Host {
|
|
/// Resolve this host address to a set of underlying IP addresses.
|
|
///
|
|
/// It is possible that the set of addresses provided may be empty, if the
|
|
/// address properly resolves (as in, we get a good DNS response) but there
|
|
/// are no relevant records for us to use for IPv4 or IPv6 connections. There
|
|
/// is also no guarantee that the host will have both IPv4 and IPv6 addresses,
|
|
/// so you may only see one or the other.
|
|
pub async fn resolve<P: ConnectionProvider>(
|
|
&self,
|
|
resolver: &AsyncResolver<P>,
|
|
) -> Result<HashSet<IpAddr>, ResolveError> {
|
|
match self {
|
|
Host::IPv4(addr) => Ok(HashSet::from([IpAddr::V4(*addr)])),
|
|
Host::IPv6(addr) => Ok(HashSet::from([IpAddr::V6(*addr)])),
|
|
Host::Hostname(name) => {
|
|
let resolve_result = resolver.lookup_ip(name).await?;
|
|
let possibilities = resolve_result.iter().collect();
|
|
Ok(possibilities)
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Connect to this host and port.
|
|
///
|
|
/// This routine will attempt to connect to every address provided by the
|
|
/// resolver, and return the first successful connection. If all of the
|
|
/// connections fail, it will return the first error it receives. This routine
|
|
/// will also return an error if there are no addresses to connect to (which
|
|
/// can happen in cases in which [`Host::resolve`] would return an empty set.
|
|
pub async fn connect<P: ConnectionProvider>(
|
|
&self,
|
|
resolver: &AsyncResolver<P>,
|
|
port: u16,
|
|
) -> error_stack::Result<TcpStream, ConnectionError> {
|
|
let addresses = self
|
|
.resolve(resolver)
|
|
.await
|
|
.map_err(ConnectionError::from)
|
|
.attach_printable_lazy(|| format!("target address {}", self))?;
|
|
|
|
let mut connectors = FuturesUnordered::new();
|
|
|
|
for address in addresses.into_iter() {
|
|
let connect_future = TcpStream::connect(SocketAddr::new(address, port));
|
|
connectors.push(connect_future);
|
|
}
|
|
|
|
let mut error = None;
|
|
|
|
while let Some(result) = connectors.next().await {
|
|
match result {
|
|
Err(e) if error.is_none() => error = Some(e),
|
|
Err(_) => {}
|
|
Ok(v) => return Ok(v),
|
|
}
|
|
}
|
|
|
|
let final_error = if let Some(e) = error {
|
|
ConnectionError::ConnectionError { error: e }
|
|
} else {
|
|
ConnectionError::NoAddresses
|
|
};
|
|
|
|
Err(report!(final_error)).attach_printable_lazy(|| format!("target address {}", self))
|
|
}
|
|
}
|
|
|
|
/// A dotted-quad string parses as an IPv4 literal with the expected value.
#[test]
fn ip4_hosts_work() {
    let parsed = Host::from_str("127.0.0.1");
    assert!(matches!(parsed, Ok(Host::IPv4(addr)) if addr == Ipv4Addr::LOCALHOST));
}
|
|
|
|
/// Unbracketed IPv6 literals (full, compressed, mixed-case, loopback and
/// unspecified forms) all parse as IPv6 hosts.
#[test]
fn bare_ip6_hosts_work() {
    let inputs = [
        "2001:0db8:85a3:0000:0000:8a2e:0370:7334",
        "2001:db8::1",
        "2001:DB8::1",
        "::1",
        "::",
    ];
    for input in inputs {
        assert!(
            matches!(Host::from_str(input), Ok(Host::IPv6(_))),
            "{:?} should parse as IPv6",
            input
        );
    }
}
|
|
|
|
/// The same IPv6 literals, wrapped in square brackets, also parse as IPv6
/// hosts.
#[test]
fn wrapped_ip6_hosts_work() {
    let inputs = [
        "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]",
        "[2001:db8::1]",
        "[2001:DB8::1]",
        "[::1]",
        "[::]",
    ];
    for input in inputs {
        assert!(
            matches!(Host::from_str(input), Ok(Host::IPv6(_))),
            "{:?} should parse as bracketed IPv6",
            input
        );
    }
}
|
|
|
|
/// Ordinary DNS names are accepted as hostnames.
#[test]
fn valid_domains_work() {
    for name in ["uhsure.com", "www.cs.indiana.edu"] {
        assert!(
            matches!(Host::from_str(name), Ok(Host::Hostname(_))),
            "{:?} should parse as a hostname",
            name
        );
    }
}
|
|
|
|
/// Bracketed non-IPv6 text and malformed hostnames are rejected with the
/// matching error variant.
#[test]
fn invalid_inputs_fail() {
    let bracketed_name = Host::from_str("[uhsure.com]");
    assert!(matches!(
        bracketed_name,
        Err(HostParseError::CouldNotParseIPv6 { .. })
    ));

    let leading_dash = Host::from_str("-uhsure.com");
    assert!(matches!(
        leading_dash,
        Err(HostParseError::InvalidHostname { .. })
    ));
}
|