// File: hushd/src/network/host.rs (snapshot 2024-10-14 17:40:18 -07:00; 212 lines, 6.4 KiB, Rust)

use error_stack::{report, ResultExt};
use futures::stream::{FuturesUnordered, StreamExt};
use hickory_resolver::error::ResolveError;
use hickory_resolver::name_server::ConnectionProvider;
use hickory_resolver::AsyncResolver;
use std::collections::HashSet;
use std::fmt;
use std::net::{AddrParseError, IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::str::FromStr;
use thiserror::Error;
use tokio::net::TcpStream;
/// A network host: either a literal IP address or a hostname that still has
/// to be resolved (see [`Host::resolve`]) before a connection can be made.
///
/// Values are normally obtained by parsing a string via [`FromStr`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Host {
    /// A literal IPv4 address, e.g. `127.0.0.1`.
    IPv4(Ipv4Addr),
    /// A literal IPv6 address, e.g. `::1` (parsed from bare or `[...]` form).
    IPv6(Ipv6Addr),
    /// A hostname such as `example.com`, stored exactly as it was given.
    Hostname(String),
}
/// Reasons why a string could not be parsed into a [`Host`].
#[derive(Debug, Error)]
pub enum HostParseError {
    /// The input was wrapped in square brackets, but the text between them
    /// was not a valid IPv6 address.
    #[error("Could not parse IPv6 address {address:?}: {error}")]
    CouldNotParseIPv6 {
        /// The full input string, brackets included.
        address: String,
        /// Why the standard-library IPv6 parser rejected the bracketed text.
        error: AddrParseError,
    },
    /// The input was neither an IP address nor a hostname accepted by
    /// `hostname_validator::is_valid`; `hostname` holds the input verbatim.
    #[error("Invalid hostname {hostname:?}")]
    InvalidHostname { hostname: String },
}
impl fmt::Display for Host {
    /// Writes the bare address or hostname wrapped by this `Host`.
    ///
    /// IPv6 addresses are written without surrounding brackets.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick out the wrapped value, then hand the formatter (and any
        // width/alignment flags it carries) to that value's own `Display`.
        let inner: &dyn fmt::Display = match self {
            Host::IPv4(v4) => v4,
            Host::IPv6(v6) => v6,
            Host::Hostname(name) => name,
        };
        fmt::Display::fmt(inner, f)
    }
}
impl FromStr for Host {
type Err = HostParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
if let Ok(addr) = Ipv4Addr::from_str(s) {
return Ok(Host::IPv4(addr));
}
if let Ok(addr) = Ipv6Addr::from_str(s) {
return Ok(Host::IPv6(addr));
}
if s.starts_with('[') && s.ends_with(']') {
match s.trim_start_matches('[').trim_end_matches(']').parse() {
Ok(x) => return Ok(Host::IPv6(x)),
Err(e) => {
return Err(HostParseError::CouldNotParseIPv6 {
address: s.to_string(),
error: e,
})
}
}
}
if hostname_validator::is_valid(s) {
return Ok(Host::Hostname(s.to_string()));
}
Err(HostParseError::InvalidHostname {
hostname: s.to_string(),
})
}
}
/// Errors produced by [`Host::connect`].
#[derive(Debug, Error)]
pub enum ConnectionError {
    /// Looking up the host's addresses failed (see [`Host::resolve`]).
    /// `#[from]` generates the `From<ResolveError>` impl that `Host::connect`
    /// uses to wrap the resolver's error.
    #[error("Failed to resolve host: {error}")]
    ResolveError {
        #[from]
        error: ResolveError,
    },
    /// Resolution succeeded but produced no addresses, so no connection was
    /// attempted.
    #[error("No valid IP addresses found")]
    NoAddresses,
    /// A TCP connection attempt failed. `Host::connect` reports this only
    /// after every address has failed, carrying the first error it saw.
    #[error("Error connecting to host: {error}")]
    ConnectionError {
        #[from]
        error: std::io::Error,
    },
}
impl Host {
/// Resolve this host address to a set of underlying IP addresses.
///
/// It is possible that the set of addresses provided may be empty, if the
/// address properly resolves (as in, we get a good DNS response) but there
/// are no relevant records for us to use for IPv4 or IPv6 connections. There
/// is also no guarantee that the host will have both IPv4 and IPv6 addresses,
/// so you may only see one or the other.
pub async fn resolve<P: ConnectionProvider>(
&self,
resolver: &AsyncResolver<P>,
) -> Result<HashSet<IpAddr>, ResolveError> {
match self {
Host::IPv4(addr) => Ok(HashSet::from([IpAddr::V4(*addr)])),
Host::IPv6(addr) => Ok(HashSet::from([IpAddr::V6(*addr)])),
Host::Hostname(name) => {
let resolve_result = resolver.lookup_ip(name).await?;
let possibilities = resolve_result.iter().collect();
Ok(possibilities)
}
}
}
/// Connect to this host and port.
///
/// This routine will attempt to connect to every address provided by the
/// resolver, and return the first successful connection. If all of the
/// connections fail, it will return the first error it receives. This routine
/// will also return an error if there are no addresses to connect to (which
/// can happen in cases in which [`Host::resolve`] would return an empty set.
pub async fn connect<P: ConnectionProvider>(
&self,
resolver: &AsyncResolver<P>,
port: u16,
) -> error_stack::Result<TcpStream, ConnectionError> {
let addresses = self
.resolve(resolver)
.await
.map_err(ConnectionError::from)
.attach_printable_lazy(|| format!("target address {}", self))?;
let mut connectors = FuturesUnordered::new();
for address in addresses.into_iter() {
let connect_future = TcpStream::connect(SocketAddr::new(address, port));
connectors.push(connect_future);
}
let mut error = None;
while let Some(result) = connectors.next().await {
match result {
Err(e) if error.is_none() => error = Some(e),
Err(_) => {}
Ok(v) => return Ok(v),
}
}
let final_error = if let Some(e) = error {
ConnectionError::ConnectionError { error: e }
} else {
ConnectionError::NoAddresses
};
Err(report!(final_error)).attach_printable_lazy(|| format!("target address {}", self))
}
}
#[test]
fn ip4_hosts_work() {
    // The loopback address must come back as the IPv4 variant, unchanged.
    let parsed = Host::from_str("127.0.0.1");
    assert!(matches!(parsed, Ok(Host::IPv4(addr)) if addr == Ipv4Addr::LOCALHOST));
}
#[test]
fn bare_ip6_hosts_work() {
    // Full, compressed, upper-case, loopback and unspecified spellings.
    let inputs = [
        "2001:0db8:85a3:0000:0000:8a2e:0370:7334",
        "2001:db8::1",
        "2001:DB8::1",
        "::1",
        "::",
    ];
    for input in inputs {
        assert!(matches!(Host::from_str(input), Ok(Host::IPv6(_))));
    }
}
#[test]
fn wrapped_ip6_hosts_work() {
    // The same spellings as the bare-address test, each in `[...]` form.
    let inputs = [
        "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]",
        "[2001:db8::1]",
        "[2001:DB8::1]",
        "[::1]",
        "[::]",
    ];
    for input in inputs {
        assert!(matches!(Host::from_str(input), Ok(Host::IPv6(_))));
    }
}
#[test]
fn valid_domains_work() {
    // Two- and four-label names should both be kept as hostnames.
    let names = ["uhsure.com", "www.cs.indiana.edu"];
    for name in names {
        assert!(matches!(Host::from_str(name), Ok(Host::Hostname(_))));
    }
}
#[test]
fn invalid_inputs_fail() {
    // Brackets promise an IPv6 literal; a domain name inside them is an error.
    let bracketed_name = Host::from_str("[uhsure.com]");
    assert!(matches!(
        bracketed_name,
        Err(HostParseError::CouldNotParseIPv6 { .. })
    ));

    // A leading hyphen is rejected by hostname validation.
    let leading_hyphen = Host::from_str("-uhsure.com");
    assert!(matches!(
        leading_hyphen,
        Err(HostParseError::InvalidHostname { .. })
    ));
}