Switch to basic tokio; will expand later to arbitrary backends.

This commit is contained in:
2022-05-14 17:59:28 -07:00
parent d284f60d67
commit c8279cfc5f
29 changed files with 1472 additions and 2671 deletions

View File

@@ -8,13 +8,11 @@ edition = "2018"
name = "async_socks5" name = "async_socks5"
[dependencies] [dependencies]
async-std = { version = "1.9.0", features = ["attributes"] } anyhow = "^1.0.57"
async-trait = "0.1.50" proptest = "^1.0.0"
futures = "0.3.15" thiserror = "^1.0.31"
log = "0.4.8" tokio = { version = "^1", features = ["full"] }
proptest = "1.0.0" tracing = "^0.1.34"
simplelog = "0.10.0"
thiserror = "1.0.24"
[dev-dependencies] [dev-dependencies]
proptest = "1.0.0" proptest = "1.0.0"

174
src/address.rs Normal file
View File

@@ -0,0 +1,174 @@
use crate::messages::string::{SOCKSv5String, SOCKSv5StringReadError, SOCKSv5StringWriteError};
use std::fmt;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum SOCKSv5Address {
IP4(Ipv4Addr),
IP6(Ipv6Addr),
Hostname(String),
}
impl From<IpAddr> for SOCKSv5Address {
fn from(x: IpAddr) -> SOCKSv5Address {
match x {
IpAddr::V4(a) => SOCKSv5Address::IP4(a),
IpAddr::V6(a) => SOCKSv5Address::IP6(a),
}
}
}
impl From<Ipv4Addr> for SOCKSv5Address {
fn from(x: Ipv4Addr) -> SOCKSv5Address {
SOCKSv5Address::IP4(x)
}
}
impl From<Ipv6Addr> for SOCKSv5Address {
fn from(x: Ipv6Addr) -> SOCKSv5Address {
SOCKSv5Address::IP6(x)
}
}
impl From<SOCKSv5String> for SOCKSv5Address {
fn from(x: SOCKSv5String) -> SOCKSv5Address {
SOCKSv5Address::Hostname(x.into())
}
}
impl<'a> From<&'a str> for SOCKSv5Address {
fn from(x: &str) -> SOCKSv5Address {
SOCKSv5Address::Hostname(x.to_string())
}
}
impl From<String> for SOCKSv5Address {
fn from(x: String) -> SOCKSv5Address {
SOCKSv5Address::Hostname(x)
}
}
impl fmt::Display for SOCKSv5Address {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
SOCKSv5Address::IP4(a) => write!(f, "{}", a),
SOCKSv5Address::IP6(a) => write!(f, "{}", a),
SOCKSv5Address::Hostname(a) => write!(f, "{}", a),
}
}
}
#[cfg(test)]
const HOSTNAME_REGEX: &str = "[a-zA-Z0-9_.]+";
#[cfg(test)]
use proptest::prelude::{any, prop_oneof, Arbitrary, BoxedStrategy, Strategy};
#[cfg(test)]
impl Arbitrary for SOCKSv5Address {
type Parameters = Option<u16>;
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
let max_len = args.unwrap_or(32) as usize;
prop_oneof![
any::<Ipv4Addr>().prop_map(SOCKSv5Address::IP4),
any::<Ipv6Addr>().prop_map(SOCKSv5Address::IP6),
HOSTNAME_REGEX.prop_map(move |mut hostname| {
hostname.shrink_to(max_len);
SOCKSv5Address::Hostname(hostname)
}),
]
.boxed()
}
}
#[derive(Clone, Debug, Error, PartialEq)]
pub enum SOCKSv5AddressReadError {
#[error("Bad address type {0} (expected 1, 3, or 4)")]
BadAddressType(u8),
#[error("Read buffer error: {0}")]
ReadError(String),
#[error(transparent)]
SOCKSv5StringError(#[from] SOCKSv5StringReadError),
}
impl From<std::io::Error> for SOCKSv5AddressReadError {
fn from(x: std::io::Error) -> SOCKSv5AddressReadError {
SOCKSv5AddressReadError::ReadError(format!("{}", x))
}
}
#[derive(Clone, Debug, Error, PartialEq)]
pub enum SOCKSv5AddressWriteError {
#[error(transparent)]
SOCKSv5StringError(#[from] SOCKSv5StringWriteError),
#[error("Write buffer error: {0}")]
WriteError(String),
}
impl From<std::io::Error> for SOCKSv5AddressWriteError {
fn from(x: std::io::Error) -> SOCKSv5AddressWriteError {
SOCKSv5AddressWriteError::WriteError(format!("{}", x))
}
}
impl SOCKSv5Address {
pub async fn read<R: AsyncRead + Send + Unpin>(
r: &mut R,
) -> Result<Self, SOCKSv5AddressReadError> {
match r.read_u8().await? {
1 => {
let mut addr_buffer = [0; 4];
r.read_exact(&mut addr_buffer).await?;
let ip4 = Ipv4Addr::from(addr_buffer);
Ok(SOCKSv5Address::IP4(ip4))
}
3 => {
let string = SOCKSv5String::read(r).await?;
Ok(SOCKSv5Address::from(string))
}
4 => {
let mut addr_buffer = [0; 16];
r.read_exact(&mut addr_buffer).await?;
let ip6 = Ipv6Addr::from(addr_buffer);
Ok(SOCKSv5Address::IP6(ip6))
}
x => Err(SOCKSv5AddressReadError::BadAddressType(x)),
}
}
pub async fn write<W: AsyncWrite + Send + Unpin>(
&self,
w: &mut W,
) -> Result<(), SOCKSv5AddressWriteError> {
match self {
SOCKSv5Address::IP4(x) => {
w.write_u8(1).await?;
w.write_all(&x.octets()).await?;
Ok(())
}
SOCKSv5Address::IP6(x) => {
w.write_u8(4).await?;
w.write_all(&x.octets()).await?;
Ok(())
}
SOCKSv5Address::Hostname(x) => {
w.write_u8(3).await?;
let string = SOCKSv5String::from(x.clone());
string.write(w).await?;
Ok(())
}
}
}
}
crate::standard_roundtrip!(socks_address_roundtrips, SOCKSv5Address);

View File

@@ -1,36 +1,3 @@
use async_socks5::network::Builtin; fn main() -> Result<(), ()> {
use async_socks5::server::{SOCKSv5Server, SecurityParameters};
use async_std::io;
use futures::stream::StreamExt;
use simplelog::{ColorChoice, CombinedLogger, Config, LevelFilter, TermLogger, TerminalMode};
#[async_std::main]
async fn main() -> Result<(), io::Error> {
CombinedLogger::init(vec![TermLogger::new(
LevelFilter::Debug,
Config::default(),
TerminalMode::Mixed,
ColorChoice::Auto,
)])
.expect("Couldn't initialize logger");
let params = SecurityParameters {
allow_unauthenticated: true,
allow_connection: None,
check_password: None,
connect_tls: None,
};
let mut server = SOCKSv5Server::new(Builtin::new(), params);
server.start("127.0.0.1", 9999).await?;
let mut responses = Box::pin(server.subserver_results());
while let Some(response) = responses.next().await {
if let Err(e) = response {
println!("Server failed with: {}", e);
}
}
Ok(()) Ok(())
} }

View File

@@ -1,57 +1,47 @@
use crate::errors::{DeserializationError, SerializationError}; use crate::address::SOCKSv5Address;
use crate::messages::{ use crate::messages::{
AuthenticationMethod, ClientConnectionCommand, ClientConnectionRequest, ClientGreeting, AuthenticationMethod, ClientConnectionCommand, ClientConnectionCommandWriteError,
ClientUsernamePassword, ServerAuthResponse, ServerChoice, ServerResponse, ServerResponseStatus, ClientConnectionRequest, ClientGreeting, ClientGreetingWriteError, ClientUsernamePassword,
ClientUsernamePasswordWriteError, ServerAuthResponse, ServerAuthResponseReadError,
ServerChoice, ServerChoiceReadError, ServerResponse, ServerResponseReadError,
ServerResponseStatus,
}; };
use crate::network::datagram::GenericDatagramSocket; use std::future::Future;
use crate::network::generic::{IntoErrorResponse, Networklike};
use crate::network::listener::GenericListener;
use crate::network::stream::GenericStream;
use crate::network::SOCKSv5Address;
use async_std::io;
use async_std::sync::{Arc, Mutex};
use async_trait::async_trait;
use futures::Future;
use log::{info, trace, warn};
use std::fmt::{Debug, Display};
use thiserror::Error; use thiserror::Error;
use tokio::net::TcpStream;
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum SOCKSv5Error<E: Debug + Display> { pub enum SOCKSv5ClientError {
#[error("SOCKSv5 serialization error: {0}")] #[error("Underlying networking error: {0}")]
SerializationError(#[from] SerializationError), NetworkingError(String),
#[error("SOCKSv5 deserialization error: {0}")] #[error("Client greeting write error: {0}")]
DeserializationError(#[from] DeserializationError), ClientWriteError(#[from] ClientGreetingWriteError),
#[error("No acceptable authentication methods available")] #[error("Server choice error: {0}")]
NoAuthMethodsAllowed, ServerChoiceError(#[from] ServerChoiceReadError),
#[error("Error writing credentials: {0}")]
CredentialWriteError(#[from] ClientUsernamePasswordWriteError),
#[error("Server auth read error: {0}")]
AuthResponseError(#[from] ServerAuthResponseReadError),
#[error("Authentication failed")] #[error("Authentication failed")]
AuthenticationFailed, AuthenticationFailed,
#[error("Server chose an unsupported authentication method ({0}")] #[error("No authentication methods allowed")]
NoAuthMethodsAllowed,
#[error("Unsupported authentication method chosen ({0})")]
UnsupportedAuthMethodChosen(AuthenticationMethod), UnsupportedAuthMethodChosen(AuthenticationMethod),
#[error("Client connection command write error: {0}")]
ClientCommandWriteError(#[from] ClientConnectionCommandWriteError),
#[error("Server said no: {0}")] #[error("Server said no: {0}")]
ServerFailure(#[from] ServerResponseStatus), ServerRejected(#[from] ServerResponseStatus),
#[error("Connection error: {0}")] #[error("Server response read failure: {0}")]
ConnectionError(#[from] io::Error), ServerResponseError(#[from] ServerResponseReadError),
#[error("Underlying network error: {0}")]
UnderlyingNetwork(E),
} }
impl<E: Debug + Display> IntoErrorResponse for SOCKSv5Error<E> { impl From<std::io::Error> for SOCKSv5ClientError {
fn into_response(&self) -> ServerResponseStatus { fn from(x: std::io::Error) -> SOCKSv5ClientError {
match self { SOCKSv5ClientError::NetworkingError(format!("{}", x))
SOCKSv5Error::ServerFailure(v) => v.clone(),
_ => ServerResponseStatus::GeneralFailure,
}
} }
} }
pub struct SOCKSv5Client<N: Networklike + Sync> {
network: Arc<Mutex<N>>,
login_info: LoginInfo,
address: SOCKSv5Address,
port: u16,
}
pub struct LoginInfo { pub struct LoginInfo {
pub username_password: Option<UsernamePassword>, pub username_password: Option<UsernamePassword>,
} }
@@ -64,7 +54,7 @@ impl Default for LoginInfo {
impl LoginInfo { impl LoginInfo {
/// Generate an empty bit of login information. /// Generate an empty bit of login information.
fn new() -> LoginInfo { pub fn new() -> LoginInfo {
LoginInfo { LoginInfo {
username_password: None, username_password: None,
} }
@@ -89,22 +79,24 @@ pub struct UsernamePassword {
pub password: String, pub password: String,
} }
impl<N> SOCKSv5Client<N> pub struct SOCKSv5Client {
where login_info: LoginInfo,
N: Networklike + Sync, address: SOCKSv5Address,
{ port: u16,
}
impl SOCKSv5Client {
/// Create a new SOCKSv5 client connection over the given steam, using the given /// Create a new SOCKSv5 client connection over the given steam, using the given
/// authentication information. As part of the process of building this object, we /// authentication information. As part of the process of building this object, we
/// do a little test run to make sure that we can login effectively; this should save /// do a little test run to make sure that we can login effectively; this should save
/// from *some* surprises later on. If you'd rather *not* do that, though, you can /// from *some* surprises later on. If you'd rather *not* do that, though, you can
/// try `unchecked_new`. /// try `unchecked_new`.
pub async fn new<A: Into<SOCKSv5Address>>( pub async fn new<A: Into<SOCKSv5Address>>(
network: N,
login: LoginInfo, login: LoginInfo,
server_addr: A, server_addr: A,
server_port: u16, server_port: u16,
) -> Result<Self, SOCKSv5Error<N::Error>> { ) -> Result<Self, SOCKSv5ClientError> {
let base_version = SOCKSv5Client::unchecked_new(network, login, server_addr, server_port); let base_version = SOCKSv5Client::unchecked_new(login, server_addr, server_port);
let _ = base_version.start_session().await?; let _ = base_version.start_session().await?;
Ok(base_version) Ok(base_version)
} }
@@ -113,13 +105,11 @@ where
/// connection sequence at the expense of an increased possibility of an error /// connection sequence at the expense of an increased possibility of an error
/// later on down the road. /// later on down the road.
pub fn unchecked_new<A: Into<SOCKSv5Address>>( pub fn unchecked_new<A: Into<SOCKSv5Address>>(
network: N,
login_info: LoginInfo, login_info: LoginInfo,
address: A, address: A,
port: u16, port: u16,
) -> Self { ) -> Self {
SOCKSv5Client { SOCKSv5Client {
network: Arc::new(Mutex::new(network)),
login_info, login_info,
address: address.into(), address: address.into(),
port, port,
@@ -128,17 +118,17 @@ where
/// This runs the connection and negotiates login, as required, and then returns /// This runs the connection and negotiates login, as required, and then returns
/// the stream the caller should use to do ... whatever it wants to do. /// the stream the caller should use to do ... whatever it wants to do.
async fn start_session(&self) -> Result<GenericStream, SOCKSv5Error<N::Error>> { async fn start_session(&self) -> Result<TcpStream, SOCKSv5ClientError> {
// create the initial stream // create the initial stream
let mut stream = { let mut stream = match &self.address {
let mut network = self.network.lock().await; SOCKSv5Address::IP4(x) => TcpStream::connect((*x, self.port)).await?,
network.connect(self.address.clone(), self.port).await SOCKSv5Address::IP6(x) => TcpStream::connect((*x, self.port)).await?,
} SOCKSv5Address::Hostname(x) => TcpStream::connect((x.as_ref(), self.port)).await?,
.map_err(SOCKSv5Error::UnderlyingNetwork)?; };
// compute how we can log in // compute how we can log in
let acceptable_methods = self.login_info.acceptable_methods(); let acceptable_methods = self.login_info.acceptable_methods();
trace!( tracing::trace!(
"Computed acceptable methods -- {:?} -- sending client greeting.", "Computed acceptable methods -- {:?} -- sending client greeting.",
acceptable_methods acceptable_methods
); );
@@ -146,9 +136,9 @@ where
// Negotiate with the server. Well. "Negotiate." // Negotiate with the server. Well. "Negotiate."
let client_greeting = ClientGreeting { acceptable_methods }; let client_greeting = ClientGreeting { acceptable_methods };
client_greeting.write(&mut stream).await?; client_greeting.write(&mut stream).await?;
trace!("Write client greeting, waiting for server's choice."); tracing::trace!("Write client greeting, waiting for server's choice.");
let server_choice = ServerChoice::read(&mut stream).await?; let server_choice = ServerChoice::read(&mut stream).await?;
trace!("Received server's choice: {}", server_choice.chosen_method); tracing::trace!("Received server's choice: {}", server_choice.chosen_method);
// Let's do it! // Let's do it!
match server_choice.chosen_method { match server_choice.chosen_method {
@@ -158,30 +148,32 @@ where
let (username, password) = if let Some(ref linfo) = let (username, password) = if let Some(ref linfo) =
self.login_info.username_password self.login_info.username_password
{ {
trace!("Server requested username/password, getting data from login info."); tracing::trace!(
"Server requested username/password, getting data from login info."
);
(linfo.username.clone(), linfo.password.clone()) (linfo.username.clone(), linfo.password.clone())
} else { } else {
warn!("Server requested username/password, but we weren't provided one. Very weird."); tracing::warn!("Server requested username/password, but we weren't provided one. Very weird.");
("".to_string(), "".to_string()) ("".to_string(), "".to_string())
}; };
let auth_request = ClientUsernamePassword { username, password }; let auth_request = ClientUsernamePassword { username, password };
trace!("Writing password information."); tracing::trace!("Writing password information.");
auth_request.write(&mut stream).await?; auth_request.write(&mut stream).await?;
let server_response = ServerAuthResponse::read(&mut stream).await?; let server_response = ServerAuthResponse::read(&mut stream).await?;
trace!("Got server response: {}", server_response.success); tracing::trace!("Got server response: {}", server_response.success);
if !server_response.success { if !server_response.success {
return Err(SOCKSv5Error::AuthenticationFailed); return Err(SOCKSv5ClientError::AuthenticationFailed);
} }
} }
AuthenticationMethod::NoAcceptableMethods => { AuthenticationMethod::NoAcceptableMethods => {
return Err(SOCKSv5Error::NoAuthMethodsAllowed) return Err(SOCKSv5ClientError::NoAuthMethodsAllowed)
} }
x => return Err(SOCKSv5Error::UnsupportedAuthMethodChosen(x)), x => return Err(SOCKSv5ClientError::UnsupportedAuthMethodChosen(x)),
} }
Ok(stream) Ok(stream)
@@ -193,12 +185,12 @@ where
/// person to listen on. So this function takes an async function, which it /// person to listen on. So this function takes an async function, which it
/// will pass this information to once it has it. It's up to that function, /// will pass this information to once it has it. It's up to that function,
/// then, to communicate this to its peer. /// then, to communicate this to its peer.
pub async fn remote_listen<A, Fut: Future<Output = Result<(), SOCKSv5Error<N::Error>>>>( pub async fn remote_listen<A, Fut: Future<Output = Result<(), SOCKSv5ClientError>>>(
self, self,
addr: A, addr: A,
port: u16, port: u16,
callback: impl FnOnce(SOCKSv5Address, u16) -> Fut, callback: impl FnOnce(SOCKSv5Address, u16) -> Fut,
) -> Result<(SOCKSv5Address, u16, GenericStream), SOCKSv5Error<N::Error>> ) -> Result<(SOCKSv5Address, u16, TcpStream), SOCKSv5ClientError>
where where
A: Into<SOCKSv5Address>, A: Into<SOCKSv5Address>,
{ {
@@ -217,9 +209,12 @@ where
return Err(initial_response.status.into()); return Err(initial_response.status.into());
} }
info!( tracing::info!(
"Proxy port binding of {}:{} established; server listening on {}:{}", "Proxy port binding of {}:{} established; server listening on {}:{}",
target, port, initial_response.bound_address, initial_response.bound_port target,
port,
initial_response.bound_address,
initial_response.bound_port
); );
callback(initial_response.bound_address, initial_response.bound_port).await?; callback(initial_response.bound_address, initial_response.bound_port).await?;
@@ -229,9 +224,10 @@ where
return Err(secondary_response.status.into()); return Err(secondary_response.status.into());
} }
info!( tracing::info!(
"Proxy bind got a connection from {}:{}", "Proxy bind got a connection from {}:{}",
secondary_response.bound_address, secondary_response.bound_port secondary_response.bound_address,
secondary_response.bound_port
); );
Ok(( Ok((
@@ -240,20 +236,12 @@ where
stream, stream,
)) ))
} }
}
#[async_trait] pub async fn connect<A: Send + Into<SOCKSv5Address>>(
impl<N> Networklike for SOCKSv5Client<N>
where
N: Networklike + Sync + Send,
{
type Error = SOCKSv5Error<N::Error>;
async fn connect<A: Send + Into<SOCKSv5Address>>(
&mut self, &mut self,
addr: A, addr: A,
port: u16, port: u16,
) -> Result<GenericStream, Self::Error> { ) -> Result<TcpStream, SOCKSv5ClientError> {
let mut stream = self.start_session().await?; let mut stream = self.start_session().await?;
let target = addr.into(); let target = addr.into();
@@ -266,29 +254,16 @@ where
let response = ServerResponse::read(&mut stream).await?; let response = ServerResponse::read(&mut stream).await?;
if response.status == ServerResponseStatus::RequestGranted { if response.status == ServerResponseStatus::RequestGranted {
info!( tracing::info!(
"Proxy connection to {}:{} established; server is using {}:{}", "Proxy connection to {}:{} established; server is using {}:{}",
target, port, response.bound_address, response.bound_port target,
port,
response.bound_address,
response.bound_port
); );
Ok(stream) Ok(stream)
} else { } else {
Err(response.status.into()) Err(response.status.into())
} }
} }
async fn listen<A: Send + Into<SOCKSv5Address>>(
&mut self,
_addr: A,
_port: u16,
) -> Result<GenericListener<Self::Error>, Self::Error> {
unimplemented!()
}
async fn bind<A: Send + Into<SOCKSv5Address>>(
&mut self,
_addr: A,
_port: u16,
) -> Result<GenericDatagramSocket<Self::Error>, Self::Error> {
unimplemented!()
}
} }

View File

@@ -1,217 +0,0 @@
use std::io;
use std::string::FromUtf8Error;
use thiserror::Error;
use crate::network::SOCKSv5Address;
/// All the errors that can pop up when trying to turn raw bytes into SOCKSv5
/// messages.
#[derive(Error, Debug)]
pub enum DeserializationError {
#[error("Invalid protocol version for packet ({1} is not {0}!)")]
InvalidVersion(u8, u8),
#[error("Not enough data found")]
NotEnoughData,
#[error("Ooops! Found an empty string where I shouldn't")]
InvalidEmptyString,
#[error("IO error: {0}")]
IOError(#[from] io::Error),
#[error("SOCKS authentication format parse error: {0}")]
AuthenticationMethodError(#[from] AuthenticationDeserializationError),
#[error("Error converting from UTF-8: {0}")]
UTF8Error(#[from] FromUtf8Error),
#[error("Invalid address type; wanted 1, 3, or 4, got {0}")]
InvalidAddressType(u8),
#[error("Invalid client command {0}; expected 1, 2, or 3")]
InvalidClientCommand(u8),
#[error("Invalid server status {0}; expected 0-8")]
InvalidServerResponse(u8),
#[error("Invalid byte in reserved byte ({0})")]
InvalidReservedByte(u8),
}
#[test]
fn des_error_reasonable_equals() {
let invalid_version1 = DeserializationError::InvalidVersion(1, 2);
let invalid_version2 = DeserializationError::InvalidVersion(1, 2);
assert_eq!(invalid_version1, invalid_version2);
let not_enough1 = DeserializationError::NotEnoughData;
let not_enough2 = DeserializationError::NotEnoughData;
assert_eq!(not_enough1, not_enough2);
let invalid_empty1 = DeserializationError::InvalidEmptyString;
let invalid_empty2 = DeserializationError::InvalidEmptyString;
assert_eq!(invalid_empty1, invalid_empty2);
let auth_method1 = DeserializationError::AuthenticationMethodError(
AuthenticationDeserializationError::NoDataFound,
);
let auth_method2 = DeserializationError::AuthenticationMethodError(
AuthenticationDeserializationError::NoDataFound,
);
assert_eq!(auth_method1, auth_method2);
let utf8a = DeserializationError::UTF8Error(String::from_utf8(vec![0, 159]).unwrap_err());
let utf8b = DeserializationError::UTF8Error(String::from_utf8(vec![0, 159]).unwrap_err());
assert_eq!(utf8a, utf8b);
let invalid_address1 = DeserializationError::InvalidAddressType(3);
let invalid_address2 = DeserializationError::InvalidAddressType(3);
assert_eq!(invalid_address1, invalid_address2);
let invalid_client_cmd1 = DeserializationError::InvalidClientCommand(32);
let invalid_client_cmd2 = DeserializationError::InvalidClientCommand(32);
assert_eq!(invalid_client_cmd1, invalid_client_cmd2);
let invalid_server_resp1 = DeserializationError::InvalidServerResponse(42);
let invalid_server_resp2 = DeserializationError::InvalidServerResponse(42);
assert_eq!(invalid_server_resp1, invalid_server_resp2);
assert_ne!(invalid_version1, invalid_address1);
assert_ne!(not_enough1, invalid_empty1);
assert_ne!(auth_method1, invalid_client_cmd1);
assert_ne!(utf8a, invalid_server_resp1);
}
impl PartialEq for DeserializationError {
fn eq(&self, other: &DeserializationError) -> bool {
match (self, other) {
(
&DeserializationError::InvalidVersion(a, b),
&DeserializationError::InvalidVersion(x, y),
) => (a == x) && (b == y),
(&DeserializationError::NotEnoughData, &DeserializationError::NotEnoughData) => true,
(
&DeserializationError::InvalidEmptyString,
&DeserializationError::InvalidEmptyString,
) => true,
(
&DeserializationError::AuthenticationMethodError(ref a),
&DeserializationError::AuthenticationMethodError(ref b),
) => a == b,
(&DeserializationError::UTF8Error(ref a), &DeserializationError::UTF8Error(ref b)) => {
a == b
}
(
&DeserializationError::InvalidAddressType(a),
&DeserializationError::InvalidAddressType(b),
) => a == b,
(
&DeserializationError::InvalidClientCommand(a),
&DeserializationError::InvalidClientCommand(b),
) => a == b,
(
&DeserializationError::InvalidServerResponse(a),
&DeserializationError::InvalidServerResponse(b),
) => a == b,
(
&DeserializationError::InvalidReservedByte(a),
&DeserializationError::InvalidReservedByte(b),
) => a == b,
(_, _) => false,
}
}
}
/// All the errors that can occur trying to turn SOCKSv5 message structures
/// into raw bytes. There's a few places that the message structures allow
/// for information that can't be serialized; often, you have to be careful
/// about how long your strings are ...
#[derive(Error, Debug)]
pub enum SerializationError {
#[error("Too many authentication methods for serialization ({0} > 255)")]
TooManyAuthMethods(usize),
#[error("Invalid length for string: {0}")]
InvalidStringLength(String),
#[error("IO error: {0}")]
IOError(#[from] io::Error),
}
#[test]
fn ser_err_reasonable_equals() {
let too_many1 = SerializationError::TooManyAuthMethods(512);
let too_many2 = SerializationError::TooManyAuthMethods(512);
assert_eq!(too_many1, too_many2);
let invalid_str1 = SerializationError::InvalidStringLength("Whoopsy!".to_string());
let invalid_str2 = SerializationError::InvalidStringLength("Whoopsy!".to_string());
assert_eq!(invalid_str1, invalid_str2);
assert_ne!(too_many1, invalid_str1);
}
impl PartialEq for SerializationError {
fn eq(&self, other: &SerializationError) -> bool {
match (self, other) {
(
&SerializationError::TooManyAuthMethods(a),
&SerializationError::TooManyAuthMethods(b),
) => a == b,
(
&SerializationError::InvalidStringLength(ref a),
&SerializationError::InvalidStringLength(ref b),
) => a == b,
(_, _) => false,
}
}
}
#[derive(Error, Debug)]
pub enum AuthenticationDeserializationError {
#[error("No data found deserializing SOCKS authentication type")]
NoDataFound,
#[error("Invalid authentication type value: {0}")]
InvalidAuthenticationByte(u8),
#[error("IO error reading SOCKS authentication type: {0}")]
IOError(#[from] io::Error),
}
#[test]
fn auth_des_err_reasonable_equals() {
let no_data1 = AuthenticationDeserializationError::NoDataFound;
let no_data2 = AuthenticationDeserializationError::NoDataFound;
assert_eq!(no_data1, no_data2);
let invalid_auth1 = AuthenticationDeserializationError::InvalidAuthenticationByte(39);
let invalid_auth2 = AuthenticationDeserializationError::InvalidAuthenticationByte(39);
assert_eq!(invalid_auth1, invalid_auth2);
assert_ne!(no_data1, invalid_auth1);
}
impl PartialEq for AuthenticationDeserializationError {
fn eq(&self, other: &AuthenticationDeserializationError) -> bool {
match (self, other) {
(
&AuthenticationDeserializationError::NoDataFound,
&AuthenticationDeserializationError::NoDataFound,
) => true,
(
&AuthenticationDeserializationError::InvalidAuthenticationByte(x),
&AuthenticationDeserializationError::InvalidAuthenticationByte(y),
) => x == y,
(_, _) => false,
}
}
}
/// The errors that can happen, as a server, when we're negotiating the start
/// of a SOCKS session.
#[derive(Debug, Error)]
pub enum AuthenticationError {
#[error("Firewall disallowed connection from {0}:{1}")]
FirewallRejected(SOCKSv5Address, u16),
#[error("Could not agree on an authentication method with the client")]
ItsNotUsItsYou,
#[error("Failure in serializing response message: {0}")]
SerializationError(#[from] SerializationError),
#[error("Failed TLS handshake")]
FailedTLSHandshake,
#[error("IO error writing response message: {0}")]
IOError(#[from] io::Error),
#[error("Failure in reading client message: {0}")]
DeserializationError(#[from] DeserializationError),
#[error("Username/password check failed (username was {0})")]
FailedUsernamePassword(String),
}

View File

@@ -1,192 +1,172 @@
pub mod client; pub mod client;
pub mod errors;
pub mod messages;
pub mod network;
mod serialize;
pub mod server; pub mod server;
mod address;
mod messages;
mod security_parameters;
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use crate::address::SOCKSv5Address;
use crate::client::{LoginInfo, SOCKSv5Client, UsernamePassword}; use crate::client::{LoginInfo, SOCKSv5Client, UsernamePassword};
use crate::network::generic::Networklike; use crate::security_parameters::SecurityParameters;
use crate::network::listener::Listenerlike; use crate::server::SOCKSv5Server;
use crate::network::testing::TestingStack; use std::io;
use crate::server::{SOCKSv5Server, SecurityParameters}; use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use async_std::channel::bounded; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use async_std::io::prelude::WriteExt; use tokio::net::{TcpSocket, TcpStream};
use async_std::task; use tokio::sync::oneshot;
use futures::AsyncReadExt; use tokio::task;
#[test] #[tokio::test]
fn unrestricted_login() { async fn unrestricted_login() {
task::block_on(async { // generate the server
let network_stack = TestingStack::default(); let security_parameters = SecurityParameters::unrestricted();
let server = SOCKSv5Server::new(security_parameters);
server.start("localhost", 9999).await.unwrap();
// generate the server let login_info = LoginInfo {
let security_parameters = SecurityParameters::unrestricted(); username_password: None,
let server = SOCKSv5Server::new(network_stack.clone(), security_parameters); };
server.start("localhost", 9999).await.unwrap(); let client = SOCKSv5Client::new(login_info, "localhost", 9999).await;
let login_info = LoginInfo { assert!(client.is_ok());
username_password: None,
};
let client = SOCKSv5Client::new(network_stack, login_info, "localhost", 9999).await;
assert!(client.is_ok());
})
} }
#[test] #[tokio::test]
fn disallow_unrestricted() { async fn disallow_unrestricted() {
task::block_on(async { // generate the server
let network_stack = TestingStack::default(); let mut security_parameters = SecurityParameters::unrestricted();
security_parameters.allow_unauthenticated = false;
let server = SOCKSv5Server::new(security_parameters);
server.start("localhost", 9998).await.unwrap();
// generate the server let login_info = LoginInfo::default();
let mut security_parameters = SecurityParameters::unrestricted(); let client = SOCKSv5Client::new(login_info, "localhost", 9998).await;
security_parameters.allow_unauthenticated = false;
let server = SOCKSv5Server::new(network_stack.clone(), security_parameters);
server.start("localhost", 9998).await.unwrap();
let login_info = LoginInfo { assert!(client.is_err());
username_password: None,
};
let client = SOCKSv5Client::new(network_stack, login_info, "localhost", 9998).await;
assert!(client.is_err());
})
} }
#[test] #[tokio::test]
fn password_checks() { async fn password_checks() {
task::block_on(async { // generate the server
let network_stack = TestingStack::default(); let security_parameters = SecurityParameters {
allow_unauthenticated: false,
allow_connection: None,
connect_tls: None,
check_password: Some(|username, password| {
username == "awick" && password == "password"
}),
};
let server = SOCKSv5Server::new(security_parameters);
server.start("localhost", 9997).await.unwrap();
// generate the server // try the positive side
let security_parameters = SecurityParameters { let login_info = LoginInfo {
allow_unauthenticated: false, username_password: Some(UsernamePassword {
allow_connection: None, username: "awick".to_string(),
connect_tls: None, password: "password".to_string(),
check_password: Some(|username, password| { }),
username == "awick" && password == "password" };
}), let client = SOCKSv5Client::new(login_info, "localhost", 9997).await;
}; assert!(client.is_ok());
let server = SOCKSv5Server::new(network_stack.clone(), security_parameters);
server.start("localhost", 9997).await.unwrap();
// try the positive side // try the negative side
let login_info = LoginInfo { let login_info = LoginInfo {
username_password: Some(UsernamePassword { username_password: Some(UsernamePassword {
username: "awick".to_string(), username: "adamw".to_string(),
password: "password".to_string(), password: "password".to_string(),
}), }),
}; };
let client = let client = SOCKSv5Client::new(login_info, "localhost", 9997).await;
SOCKSv5Client::new(network_stack.clone(), login_info, "localhost", 9997).await; assert!(client.is_err());
assert!(client.is_ok());
// try the negative side
let login_info = LoginInfo {
username_password: Some(UsernamePassword {
username: "adamw".to_string(),
password: "password".to_string(),
}),
};
let client = SOCKSv5Client::new(network_stack, login_info, "localhost", 9997).await;
assert!(client.is_err());
})
} }
#[test] #[tokio::test]
fn firewall_blocks() { async fn firewall_blocks() {
task::block_on(async { // generate the server
let network_stack = TestingStack::default(); let mut security_parameters = SecurityParameters::unrestricted();
security_parameters.allow_connection = Some(|_| false);
let server = SOCKSv5Server::new(security_parameters);
server.start("localhost", 9996).await.unwrap();
// generate the server let login_info = LoginInfo::new();
let mut security_parameters = SecurityParameters::unrestricted(); let client = SOCKSv5Client::new(login_info, "localhost", 9996).await;
security_parameters.allow_connection = Some(|_, _| false);
let server = SOCKSv5Server::new(network_stack.clone(), security_parameters);
server.start("localhost", 9996).await.unwrap();
let login_info = LoginInfo { assert!(client.is_err());
username_password: None,
};
let client = SOCKSv5Client::new(network_stack, login_info, "localhost", 9996).await;
assert!(client.is_err());
})
} }
#[test] #[tokio::test]
fn establish_stream() { async fn establish_stream() -> io::Result<()> {
task::block_on(async { let target_socket = TcpSocket::new_v4()?;
let mut network_stack = TestingStack::default(); target_socket.bind(SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
1337,
))?;
let target_port = target_socket.listen(1)?;
let target_port = network_stack.listen("localhost", 1337).await.unwrap(); // generate the server
let security_parameters = SecurityParameters::unrestricted();
let server = SOCKSv5Server::new(security_parameters);
server.start("localhost", 9995).await.unwrap();
// generate the server let login_info = LoginInfo {
let security_parameters = SecurityParameters::unrestricted(); username_password: None,
let server = SOCKSv5Server::new(network_stack.clone(), security_parameters); };
server.start("localhost", 9995).await.unwrap();
let login_info = LoginInfo { let mut client = SOCKSv5Client::new(login_info, "localhost", 9995)
username_password: None, .await
}; .unwrap();
let mut client = SOCKSv5Client::new(network_stack, login_info, "localhost", 9995) task::spawn(async move {
let mut conn = client.connect("localhost", 1337).await.unwrap();
conn.write_all(&[1, 3, 3, 7, 9]).await.unwrap();
});
let (mut target_connection, _) = target_port.accept().await.unwrap();
let mut read_buffer = [0; 4];
target_connection
.read_exact(&mut read_buffer)
.await
.unwrap();
assert_eq!(read_buffer, [1, 3, 3, 7]);
Ok(())
}
#[tokio::test]
async fn bind_test() -> io::Result<()> {
let security_parameters = SecurityParameters::unrestricted();
let server = SOCKSv5Server::new(security_parameters);
server.start("localhost", 9994).await.unwrap();
let login_info = LoginInfo::default();
let client = SOCKSv5Client::new(login_info, "localhost", 9994)
.await
.unwrap();
let (target_sender, target_receiver) = oneshot::channel();
task::spawn(async move {
let (_, _, mut conn) = client
.remote_listen("localhost", 9993, |addr, port| async move {
target_sender.send((addr, port)).unwrap();
Ok(())
})
.await .await
.unwrap(); .unwrap();
task::spawn(async move { conn.write_all(&[2, 3, 5, 7]).await.unwrap();
let mut conn = client.connect("localhost", 1337).await.unwrap(); });
conn.write_all(&[1, 3, 3, 7, 9]).await.unwrap();
});
let (mut target_connection, _, _) = target_port.accept().await.unwrap(); let (target_addr, target_port) = target_receiver.await.unwrap();
let mut read_buffer = [0; 4]; let mut stream = match target_addr {
target_connection SOCKSv5Address::IP4(x) => TcpStream::connect((x, target_port)).await?,
.read_exact(&mut read_buffer) SOCKSv5Address::IP6(x) => TcpStream::connect((x, target_port)).await?,
.await SOCKSv5Address::Hostname(x) => TcpStream::connect((x, target_port)).await?,
.unwrap(); };
assert_eq!(read_buffer, [1, 3, 3, 7]); let mut read_buffer = [0; 4];
}) stream.read_exact(&mut read_buffer).await.unwrap();
} assert_eq!(read_buffer, [2, 3, 5, 7]);
Ok(())
#[test]
fn bind_test() {
task::block_on(async {
let mut network_stack = TestingStack::default();
let security_parameters = SecurityParameters::unrestricted();
let server = SOCKSv5Server::new(network_stack.clone(), security_parameters);
server.start("localhost", 9994).await.unwrap();
let login_info = LoginInfo::default();
let client = SOCKSv5Client::new(network_stack.clone(), login_info, "localhost", 9994)
.await
.unwrap();
let (target_sender, target_receiver) = bounded(1);
task::spawn(async move {
let (_, _, mut conn) = client
.remote_listen("localhost", 9993, |addr, port| async move {
target_sender.send((addr, port)).await.unwrap();
Ok(())
})
.await
.unwrap();
conn.write_all(&[2, 3, 5, 7]).await.unwrap();
});
let (target_addr, target_port) = target_receiver.recv().await.unwrap();
let mut stream = network_stack
.connect(target_addr, target_port)
.await
.unwrap();
let mut read_buffer = [0; 4];
stream.read_exact(&mut read_buffer).await.unwrap();
assert_eq!(read_buffer, [2, 3, 5, 7]);
})
} }
} }

View File

@@ -5,12 +5,51 @@ mod client_username_password;
mod server_auth_response; mod server_auth_response;
mod server_choice; mod server_choice;
mod server_response; mod server_response;
pub(crate) mod utils;
pub use crate::messages::authentication_method::AuthenticationMethod; pub(crate) mod string;
pub use crate::messages::client_command::{ClientConnectionCommand, ClientConnectionRequest};
pub use crate::messages::client_greeting::ClientGreeting; pub use crate::messages::authentication_method::{
pub use crate::messages::client_username_password::ClientUsernamePassword; AuthenticationMethod, AuthenticationMethodReadError, AuthenticationMethodWriteError,
pub use crate::messages::server_auth_response::ServerAuthResponse; };
pub use crate::messages::server_choice::ServerChoice; pub use crate::messages::client_command::{
pub use crate::messages::server_response::{ServerResponse, ServerResponseStatus}; ClientConnectionCommand, ClientConnectionCommandReadError, ClientConnectionCommandWriteError,
ClientConnectionRequest, ClientConnectionRequestReadError,
};
pub use crate::messages::client_greeting::{
ClientGreeting, ClientGreetingReadError, ClientGreetingWriteError,
};
pub use crate::messages::client_username_password::{
ClientUsernamePassword, ClientUsernamePasswordReadError, ClientUsernamePasswordWriteError,
};
pub use crate::messages::server_auth_response::{
ServerAuthResponse, ServerAuthResponseReadError, ServerAuthResponseWriteError,
};
pub use crate::messages::server_choice::{
ServerChoice, ServerChoiceReadError, ServerChoiceWriteError,
};
pub use crate::messages::server_response::{
ServerResponse, ServerResponseReadError, ServerResponseStatus, ServerResponseWriteError,
};
#[doc(hidden)]
#[macro_export]
macro_rules! standard_roundtrip {
($name: ident, $t: ty) => {
proptest::proptest! {
#[test]
fn $name(xs: $t) {
tokio::runtime::Runtime::new().unwrap().block_on(async {
use std::io::Cursor;
let buffer = vec![];
let mut write_cursor = Cursor::new(buffer);
xs.write(&mut write_cursor).await.unwrap();
let serialized_form = write_cursor.into_inner();
let mut read_cursor = Cursor::new(serialized_form);
let ys = <$t>::read(&mut read_cursor);
assert_eq!(xs, ys.await.unwrap());
})
}
}
};
}

View File

@@ -1,16 +1,12 @@
use crate::errors::{AuthenticationDeserializationError, DeserializationError, SerializationError};
use crate::standard_roundtrip;
#[cfg(test)] #[cfg(test)]
use async_std::task; use proptest::prelude::{prop_oneof, Arbitrary, Just, Strategy};
#[cfg(test)]
use futures::io::Cursor;
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use proptest::proptest;
#[cfg(test)]
use proptest::prelude::{Arbitrary, Just, Strategy, prop_oneof};
#[cfg(test)] #[cfg(test)]
use proptest::strategy::BoxedStrategy; use proptest::strategy::BoxedStrategy;
use std::fmt; use std::fmt;
#[cfg(test)]
use std::io::Cursor;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
#[allow(clippy::upper_case_acronyms)] #[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Debug, Eq, PartialEq)] #[derive(Clone, Debug, Eq, PartialEq)]
@@ -28,6 +24,34 @@ pub enum AuthenticationMethod {
NoAcceptableMethods, NoAcceptableMethods,
} }
#[derive(Clone, Debug, Error, PartialEq)]
pub enum AuthenticationMethodReadError {
#[error("Invalid authentication method #{0}")]
UnknownAuthenticationMethod(u8),
#[error("Error in underlying buffer: {0}")]
ReadError(String),
}
impl From<std::io::Error> for AuthenticationMethodReadError {
fn from(x: std::io::Error) -> AuthenticationMethodReadError {
AuthenticationMethodReadError::ReadError(format!("{}", x))
}
}
#[derive(Clone, Debug, Error, PartialEq)]
pub enum AuthenticationMethodWriteError {
#[error("Trying to write invalid authentication method #{0}")]
InvalidAuthMethod(u8),
#[error("Error in underlying buffer: {0}")]
WriteError(String),
}
impl From<std::io::Error> for AuthenticationMethodWriteError {
fn from(x: std::io::Error) -> AuthenticationMethodWriteError {
AuthenticationMethodWriteError::WriteError(format!("{}", x))
}
}
impl fmt::Display for AuthenticationMethod { impl fmt::Display for AuthenticationMethod {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self { match self {
@@ -65,26 +89,17 @@ impl Arbitrary for AuthenticationMethod {
Just(AuthenticationMethod::MultiAuthenticationFramework), Just(AuthenticationMethod::MultiAuthenticationFramework),
Just(AuthenticationMethod::JSONPropertyBlock), Just(AuthenticationMethod::JSONPropertyBlock),
Just(AuthenticationMethod::NoAcceptableMethods), Just(AuthenticationMethod::NoAcceptableMethods),
(0x80u8..=0xfe).prop_map(AuthenticationMethod::PrivateMethod), (0x80u8..=0xfe).prop_map(AuthenticationMethod::PrivateMethod),
].boxed() ]
.boxed()
} }
} }
impl AuthenticationMethod { impl AuthenticationMethod {
pub async fn read<R: AsyncRead + Send + Unpin>( pub async fn read<R: AsyncRead + Send + Unpin>(
r: &mut R, r: &mut R,
) -> Result<AuthenticationMethod, DeserializationError> { ) -> Result<AuthenticationMethod, AuthenticationMethodReadError> {
let mut byte_buffer = [0u8; 1]; match r.read_u8().await? {
let amount_read = r.read(&mut byte_buffer).await?;
if amount_read == 0 {
return Err(AuthenticationDeserializationError::NoDataFound.into());
}
match byte_buffer[0] {
0 => Ok(AuthenticationMethod::None), 0 => Ok(AuthenticationMethod::None),
1 => Ok(AuthenticationMethod::GSSAPI), 1 => Ok(AuthenticationMethod::GSSAPI),
2 => Ok(AuthenticationMethod::UsernameAndPassword), 2 => Ok(AuthenticationMethod::UsernameAndPassword),
@@ -96,14 +111,16 @@ impl AuthenticationMethod {
9 => Ok(AuthenticationMethod::JSONPropertyBlock), 9 => Ok(AuthenticationMethod::JSONPropertyBlock),
x if (0x80..=0xfe).contains(&x) => Ok(AuthenticationMethod::PrivateMethod(x)), x if (0x80..=0xfe).contains(&x) => Ok(AuthenticationMethod::PrivateMethod(x)),
0xff => Ok(AuthenticationMethod::NoAcceptableMethods), 0xff => Ok(AuthenticationMethod::NoAcceptableMethods),
e => Err(AuthenticationDeserializationError::InvalidAuthenticationByte(e).into()), e => Err(AuthenticationMethodReadError::UnknownAuthenticationMethod(
e,
)),
} }
} }
pub async fn write<W: AsyncWrite + Send + Unpin>( pub async fn write<W: AsyncWrite + Send + Unpin>(
&self, &self,
w: &mut W, w: &mut W,
) -> Result<(), SerializationError> { ) -> Result<(), AuthenticationMethodWriteError> {
let value = match self { let value = match self {
AuthenticationMethod::None => 0, AuthenticationMethod::None => 0,
AuthenticationMethod::GSSAPI => 1, AuthenticationMethod::GSSAPI => 1,
@@ -114,31 +131,32 @@ impl AuthenticationMethod {
AuthenticationMethod::NDS => 7, AuthenticationMethod::NDS => 7,
AuthenticationMethod::MultiAuthenticationFramework => 8, AuthenticationMethod::MultiAuthenticationFramework => 8,
AuthenticationMethod::JSONPropertyBlock => 9, AuthenticationMethod::JSONPropertyBlock => 9,
AuthenticationMethod::PrivateMethod(pm) => *pm, AuthenticationMethod::PrivateMethod(pm) if (0x80..=0xfe).contains(pm) => *pm,
AuthenticationMethod::PrivateMethod(pm) => {
return Err(AuthenticationMethodWriteError::InvalidAuthMethod(*pm))
}
AuthenticationMethod::NoAcceptableMethods => 0xff, AuthenticationMethod::NoAcceptableMethods => 0xff,
}; };
Ok(w.write_all(&[value]).await?) Ok(w.write_u8(value).await?)
} }
} }
standard_roundtrip!(auth_byte_roundtrips, AuthenticationMethod); crate::standard_roundtrip!(auth_byte_roundtrips, AuthenticationMethod);
#[test] #[tokio::test]
fn bad_byte() { async fn bad_byte() {
let no_len = vec![42]; let no_len = vec![42];
let mut cursor = Cursor::new(no_len); let mut cursor = Cursor::new(no_len);
let ys = AuthenticationMethod::read(&mut cursor); let ys = AuthenticationMethod::read(&mut cursor).await.unwrap_err();
assert_eq!( assert_eq!(
Err(DeserializationError::AuthenticationMethodError( AuthenticationMethodReadError::UnknownAuthenticationMethod(42),
AuthenticationDeserializationError::InvalidAuthenticationByte(42) ys
)),
task::block_on(ys)
); );
} }
#[test] #[tokio::test]
fn display_isnt_empty() { async fn display_isnt_empty() {
let vals = vec![ let vals = vec![
AuthenticationMethod::None, AuthenticationMethod::None,
AuthenticationMethod::GSSAPI, AuthenticationMethod::GSSAPI,

View File

@@ -1,20 +1,10 @@
use crate::errors::{DeserializationError, SerializationError}; use crate::address::{SOCKSv5Address, SOCKSv5AddressReadError, SOCKSv5AddressWriteError};
use crate::network::SOCKSv5Address;
use crate::serialize::read_amt;
use crate::standard_roundtrip;
#[cfg(test)]
use async_std::io::ErrorKind;
#[cfg(test)]
use async_std::task;
#[cfg(test)]
use futures::io::Cursor;
use futures::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use log::debug;
use proptest::proptest;
#[cfg(test)] #[cfg(test)]
use proptest_derive::Arbitrary; use proptest_derive::Arbitrary;
#[cfg(test)] #[cfg(test)]
use std::net::Ipv4Addr; use std::io::Cursor;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
#[derive(Clone, Copy, Debug, Eq, PartialEq)] #[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[cfg_attr(test, derive(Arbitrary))] #[cfg_attr(test, derive(Arbitrary))]
@@ -24,6 +14,60 @@ pub enum ClientConnectionCommand {
AssociateUDPPort, AssociateUDPPort,
} }
#[derive(Clone, Debug, Error, PartialEq)]
pub enum ClientConnectionCommandReadError {
#[error("Invalid client connection command code: {0}")]
InvalidClientConnectionCommand(u8),
#[error("Underlying buffer read error: {0}")]
ReadError(String),
}
impl From<std::io::Error> for ClientConnectionCommandReadError {
fn from(x: std::io::Error) -> ClientConnectionCommandReadError {
ClientConnectionCommandReadError::ReadError(format!("{}", x))
}
}
#[derive(Clone, Debug, Error, PartialEq)]
pub enum ClientConnectionCommandWriteError {
#[error("Underlying buffer write error: {0}")]
WriteError(String),
#[error(transparent)]
SOCKSAddressWriteError(#[from] SOCKSv5AddressWriteError),
}
impl From<std::io::Error> for ClientConnectionCommandWriteError {
fn from(x: std::io::Error) -> ClientConnectionCommandWriteError {
ClientConnectionCommandWriteError::WriteError(format!("{}", x))
}
}
impl ClientConnectionCommand {
pub async fn read<R: AsyncRead + Send + Unpin>(
r: &mut R,
) -> Result<ClientConnectionCommand, ClientConnectionCommandReadError> {
match r.read_u8().await? {
0x01 => Ok(ClientConnectionCommand::EstablishTCPStream),
0x02 => Ok(ClientConnectionCommand::EstablishTCPPortBinding),
0x03 => Ok(ClientConnectionCommand::AssociateUDPPort),
x => Err(ClientConnectionCommandReadError::InvalidClientConnectionCommand(x)),
}
}
pub async fn write<W: AsyncWrite + Send + Unpin>(
&self,
w: &mut W,
) -> Result<(), std::io::Error> {
match self {
ClientConnectionCommand::EstablishTCPStream => w.write_u8(0x01).await,
ClientConnectionCommand::EstablishTCPPortBinding => w.write_u8(0x02).await,
ClientConnectionCommand::AssociateUDPPort => w.write_u8(0x03).await,
}
}
}
crate::standard_roundtrip!(client_command_roundtrips, ClientConnectionCommand);
#[derive(Clone, Debug, Eq, PartialEq)] #[derive(Clone, Debug, Eq, PartialEq)]
#[cfg_attr(test, derive(Arbitrary))] #[cfg_attr(test, derive(Arbitrary))]
pub struct ClientConnectionRequest { pub struct ClientConnectionRequest {
@@ -32,37 +76,46 @@ pub struct ClientConnectionRequest {
pub destination_port: u16, pub destination_port: u16,
} }
#[derive(Clone, Debug, Error, PartialEq)]
pub enum ClientConnectionRequestReadError {
#[error("Invalid version in client request: {0} (expected 5)")]
InvalidVersion(u8),
#[error("Invalid command for client request: {0}")]
InvalidCommand(#[from] ClientConnectionCommandReadError),
#[error("Invalid reserved byte: {0} (expected 0)")]
InvalidReservedByte(u8),
#[error("Underlying read error: {0}")]
ReadError(String),
#[error(transparent)]
AddressReadError(#[from] SOCKSv5AddressReadError),
}
impl From<std::io::Error> for ClientConnectionRequestReadError {
fn from(x: std::io::Error) -> ClientConnectionRequestReadError {
ClientConnectionRequestReadError::ReadError(format!("{}", x))
}
}
impl ClientConnectionRequest { impl ClientConnectionRequest {
pub async fn read<R: AsyncRead + Send + Unpin>( pub async fn read<R: AsyncRead + Send + Unpin>(
r: &mut R, r: &mut R,
) -> Result<Self, DeserializationError> { ) -> Result<Self, ClientConnectionRequestReadError> {
let mut buffer = [0; 3]; let version = r.read_u8().await?;
if version != 5 {
debug!("Starting to read request."); return Err(ClientConnectionRequestReadError::InvalidVersion(version));
read_amt(r, 3, &mut buffer).await?;
debug!("Read three opening bytes: {:?}", buffer);
if buffer[0] != 5 {
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
} }
let command_code = match buffer[1] { let command_code = ClientConnectionCommand::read(r).await?;
0x01 => ClientConnectionCommand::EstablishTCPStream,
0x02 => ClientConnectionCommand::EstablishTCPPortBinding,
0x03 => ClientConnectionCommand::AssociateUDPPort,
x => return Err(DeserializationError::InvalidClientCommand(x)),
};
debug!("Command code: {:?}", command_code);
if buffer[2] != 0 { let reserved = r.read_u8().await?;
return Err(DeserializationError::InvalidReservedByte(buffer[2])); if reserved != 0 {
return Err(ClientConnectionRequestReadError::InvalidReservedByte(
reserved,
));
} }
let destination_address = SOCKSv5Address::read(r).await?; let destination_address = SOCKSv5Address::read(r).await?;
debug!("Destination address: {}", destination_address); let destination_port = r.read_u16().await?;
read_amt(r, 2, &mut buffer).await?;
let destination_port = ((buffer[0] as u16) << 8) + (buffer[1] as u16);
debug!("Destination port: {}", destination_port);
Ok(ClientConnectionRequest { Ok(ClientConnectionRequest {
command_code, command_code,
@@ -74,63 +127,62 @@ impl ClientConnectionRequest {
pub async fn write<W: AsyncWrite + Send + Unpin>( pub async fn write<W: AsyncWrite + Send + Unpin>(
&self, &self,
w: &mut W, w: &mut W,
) -> Result<(), SerializationError> { ) -> Result<(), ClientConnectionCommandWriteError> {
let command = match self.command_code { w.write_u8(5).await?;
ClientConnectionCommand::EstablishTCPStream => 1, self.command_code.write(w).await?;
ClientConnectionCommand::EstablishTCPPortBinding => 2, w.write_u8(0).await?;
ClientConnectionCommand::AssociateUDPPort => 3,
};
w.write_all(&[5, command, 0]).await?;
self.destination_address.write(w).await?; self.destination_address.write(w).await?;
w.write_all(&[ w.write_u16(self.destination_port).await?;
(self.destination_port >> 8) as u8, Ok(())
(self.destination_port & 0xffu16) as u8,
])
.await
.map_err(SerializationError::IOError)
} }
} }
standard_roundtrip!(client_request_roundtrips, ClientConnectionRequest); crate::standard_roundtrip!(client_request_roundtrips, ClientConnectionRequest);
#[test] #[tokio::test]
fn check_short_reads() { async fn check_short_reads() {
let empty = vec![]; let empty = vec![];
let mut cursor = Cursor::new(empty); let mut cursor = Cursor::new(empty);
let ys = ClientConnectionRequest::read(&mut cursor); let ys = ClientConnectionRequest::read(&mut cursor).await;
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); assert!(matches!(
ys,
Err(ClientConnectionRequestReadError::ReadError(_))
));
let no_len = vec![5, 1]; let no_len = vec![5, 1];
let mut cursor = Cursor::new(no_len); let mut cursor = Cursor::new(no_len);
let ys = ClientConnectionRequest::read(&mut cursor); let ys = ClientConnectionRequest::read(&mut cursor).await;
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); assert!(matches!(
ys,
Err(ClientConnectionRequestReadError::ReadError(_))
));
} }
#[test] #[tokio::test]
fn check_bad_version() { async fn check_bad_version() {
let bad_ver = vec![6, 1, 1]; let bad_ver = vec![6, 1, 1];
let mut cursor = Cursor::new(bad_ver); let mut cursor = Cursor::new(bad_ver);
let ys = ClientConnectionRequest::read(&mut cursor); let ys = ClientConnectionRequest::read(&mut cursor).await;
assert_eq!( assert_eq!(Err(ClientConnectionRequestReadError::InvalidVersion(6)), ys);
Err(DeserializationError::InvalidVersion(5, 6)),
task::block_on(ys)
);
} }
#[test] #[tokio::test]
fn check_bad_command() { async fn check_bad_command() {
let bad_cmd = vec![5, 32, 1]; let bad_cmd = vec![5, 32, 1];
let mut cursor = Cursor::new(bad_cmd); let mut cursor = Cursor::new(bad_cmd);
let ys = ClientConnectionRequest::read(&mut cursor); let ys = ClientConnectionRequest::read(&mut cursor).await;
assert_eq!( assert_eq!(
Err(DeserializationError::InvalidClientCommand(32)), Err(ClientConnectionRequestReadError::InvalidCommand(
task::block_on(ys) ClientConnectionCommandReadError::InvalidClientConnectionCommand(32)
)),
ys
); );
} }
#[test] #[tokio::test]
fn short_write_fails_right() { async fn short_write_fails_right() {
use std::net::Ipv4Addr;
let mut buffer = [0u8; 2]; let mut buffer = [0u8; 2];
let cmd = ClientConnectionRequest { let cmd = ClientConnectionRequest {
command_code: ClientConnectionCommand::AssociateUDPPort, command_code: ClientConnectionCommand::AssociateUDPPort,
@@ -138,10 +190,12 @@ fn short_write_fails_right() {
destination_port: 22, destination_port: 22,
}; };
let mut cursor = Cursor::new(&mut buffer as &mut [u8]); let mut cursor = Cursor::new(&mut buffer as &mut [u8]);
let result = task::block_on(cmd.write(&mut cursor)); let result = cmd.write(&mut cursor).await;
match result { match result {
Ok(_) => panic!("Mysteriously able to fit > 2 bytes in 2 bytes."), Ok(_) => panic!("Mysteriously able to fit > 2 bytes in 2 bytes."),
Err(SerializationError::IOError(x)) => assert_eq!(ErrorKind::WriteZero, x.kind()), Err(ClientConnectionCommandWriteError::WriteError(x)) => {
assert!(x.contains("write zero"));
}
Err(e) => panic!("Got the wrong error writing too much data: {}", e), Err(e) => panic!("Got the wrong error writing too much data: {}", e),
} }
} }

View File

@@ -1,16 +1,12 @@
#[cfg(test)] use crate::messages::authentication_method::{
use crate::errors::AuthenticationDeserializationError; AuthenticationMethod, AuthenticationMethodReadError, AuthenticationMethodWriteError,
use crate::errors::{DeserializationError, SerializationError}; };
use crate::messages::AuthenticationMethod;
use crate::standard_roundtrip;
#[cfg(test)]
use async_std::task;
#[cfg(test)]
use futures::io::Cursor;
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use proptest::proptest;
#[cfg(test)] #[cfg(test)]
use proptest_derive::Arbitrary; use proptest_derive::Arbitrary;
#[cfg(test)]
use std::io::Cursor;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
/// Client greetings are the first message sent in a SOCKSv5 session. They /// Client greetings are the first message sent in a SOCKSv5 session. They
/// identify that there's a client that wants to talk to a server, and that /// identify that there's a client that wants to talk to a server, and that
@@ -23,26 +19,52 @@ pub struct ClientGreeting {
pub acceptable_methods: Vec<AuthenticationMethod>, pub acceptable_methods: Vec<AuthenticationMethod>,
} }
#[derive(Clone, Debug, Error, PartialEq)]
pub enum ClientGreetingReadError {
#[error("Invalid version in client request: {0} (expected 5)")]
InvalidVersion(u8),
#[error(transparent)]
AuthMethodReadError(#[from] AuthenticationMethodReadError),
#[error("Underlying read error: {0}")]
ReadError(String),
}
impl From<std::io::Error> for ClientGreetingReadError {
fn from(x: std::io::Error) -> ClientGreetingReadError {
ClientGreetingReadError::ReadError(format!("{}", x))
}
}
#[derive(Clone, Debug, Error, PartialEq)]
pub enum ClientGreetingWriteError {
#[error("Too many methods provided; need <256, saw {0}")]
TooManyMethods(usize),
#[error(transparent)]
AuthMethodWriteError(#[from] AuthenticationMethodWriteError),
#[error("Underlying write error: {0}")]
WriteError(String),
}
impl From<std::io::Error> for ClientGreetingWriteError {
fn from(x: std::io::Error) -> ClientGreetingWriteError {
ClientGreetingWriteError::WriteError(format!("{}", x))
}
}
impl ClientGreeting { impl ClientGreeting {
pub async fn read<R: AsyncRead + Send + Unpin>( pub async fn read<R: AsyncRead + Send + Unpin>(
r: &mut R, r: &mut R,
) -> Result<ClientGreeting, DeserializationError> { ) -> Result<ClientGreeting, ClientGreetingReadError> {
let mut buffer = [0; 1]; let version = r.read_u8().await?;
if r.read(&mut buffer).await? == 0 { if version != 5 {
return Err(DeserializationError::NotEnoughData); return Err(ClientGreetingReadError::InvalidVersion(version));
} }
if buffer[0] != 5 { let num_methods = r.read_u8().await? as usize;
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
}
if r.read(&mut buffer).await? == 0 { let mut acceptable_methods = Vec::with_capacity(num_methods);
return Err(DeserializationError::NotEnoughData); for _ in 0..num_methods {
}
let mut acceptable_methods = Vec::with_capacity(buffer[0] as usize);
for _ in 0..buffer[0] {
acceptable_methods.push(AuthenticationMethod::read(r).await?); acceptable_methods.push(AuthenticationMethod::read(r).await?);
} }
@@ -52,9 +74,9 @@ impl ClientGreeting {
pub async fn write<W: AsyncWrite + Send + Unpin>( pub async fn write<W: AsyncWrite + Send + Unpin>(
&self, &self,
w: &mut W, w: &mut W,
) -> Result<(), SerializationError> { ) -> Result<(), ClientGreetingWriteError> {
if self.acceptable_methods.len() > 255 { if self.acceptable_methods.len() > 255 {
return Err(SerializationError::TooManyAuthMethods( return Err(ClientGreetingWriteError::TooManyMethods(
self.acceptable_methods.len(), self.acceptable_methods.len(),
)); ));
} }
@@ -70,44 +92,41 @@ impl ClientGreeting {
} }
} }
standard_roundtrip!(client_greeting_roundtrips, ClientGreeting); crate::standard_roundtrip!(client_greeting_roundtrips, ClientGreeting);
#[test] #[tokio::test]
fn check_short_reads() { async fn check_short_reads() {
let empty = vec![]; let empty = vec![];
let mut cursor = Cursor::new(empty); let mut cursor = Cursor::new(empty);
let ys = ClientGreeting::read(&mut cursor); let ys = ClientGreeting::read(&mut cursor).await;
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); assert!(matches!(ys, Err(ClientGreetingReadError::ReadError(_))));
let no_len = vec![5]; let no_len = vec![5];
let mut cursor = Cursor::new(no_len); let mut cursor = Cursor::new(no_len);
let ys = ClientGreeting::read(&mut cursor); let ys = ClientGreeting::read(&mut cursor).await;
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); assert!(matches!(ys, Err(ClientGreetingReadError::ReadError(_))));
let bad_len = vec![5, 9]; let bad_len = vec![5, 9];
let mut cursor = Cursor::new(bad_len); let mut cursor = Cursor::new(bad_len);
let ys = ClientGreeting::read(&mut cursor); let ys = ClientGreeting::read(&mut cursor).await;
assert_eq!( assert!(matches!(
Err(DeserializationError::AuthenticationMethodError( ys,
AuthenticationDeserializationError::NoDataFound Err(ClientGreetingReadError::AuthMethodReadError(
)), AuthenticationMethodReadError::ReadError(_)
task::block_on(ys) ))
); ));
} }
#[test] #[tokio::test]
fn check_bad_version() { async fn check_bad_version() {
let no_len = vec![6, 1, 1]; let no_len = vec![6, 1, 1];
let mut cursor = Cursor::new(no_len); let mut cursor = Cursor::new(no_len);
let ys = ClientGreeting::read(&mut cursor); let ys = ClientGreeting::read(&mut cursor).await;
assert_eq!( assert_eq!(Err(ClientGreetingReadError::InvalidVersion(6)), ys);
Err(DeserializationError::InvalidVersion(5, 6)),
task::block_on(ys)
);
} }
#[test] #[tokio::test]
fn check_too_many() { async fn check_too_many() {
let mut auth_methods = Vec::with_capacity(512); let mut auth_methods = Vec::with_capacity(512);
auth_methods.resize(512, AuthenticationMethod::ChallengeHandshake); auth_methods.resize(512, AuthenticationMethod::ChallengeHandshake);
let greet = ClientGreeting { let greet = ClientGreeting {
@@ -115,7 +134,7 @@ fn check_too_many() {
}; };
let mut output = vec![0; 1024]; let mut output = vec![0; 1024];
assert_eq!( assert_eq!(
Err(SerializationError::TooManyAuthMethods(512)), Err(ClientGreetingWriteError::TooManyMethods(512)),
task::block_on(greet.write(&mut output)) greet.write(&mut output).await
); );
} }

View File

@@ -1,16 +1,10 @@
use crate::errors::{DeserializationError, SerializationError}; use crate::messages::string::{SOCKSv5String, SOCKSv5StringReadError, SOCKSv5StringWriteError};
use crate::serialize::{read_string, write_string};
use crate::standard_roundtrip;
#[cfg(test)] #[cfg(test)]
use async_std::task; use proptest::prelude::{Arbitrary, BoxedStrategy, Strategy};
#[cfg(test)] #[cfg(test)]
use futures::io::Cursor; use std::io::Cursor;
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use thiserror::Error;
#[cfg(test)] use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use proptest::prelude::{Arbitrary, BoxedStrategy};
use proptest::proptest;
#[cfg(test)]
use proptest::strategy::Strategy;
#[derive(Clone, Debug, Eq, PartialEq)] #[derive(Clone, Debug, Eq, PartialEq)]
pub struct ClientUsernamePassword { pub struct ClientUsernamePassword {
@@ -30,30 +24,58 @@ impl Arbitrary for ClientUsernamePassword {
fn arbitrary_with(args: Self::Parameters) -> Self::Strategy { fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
let max_len = args.unwrap_or(12) as usize; let max_len = args.unwrap_or(12) as usize;
(USERNAME_REGEX, PASSWORD_REGEX).prop_map(move |(mut username, mut password)| { (USERNAME_REGEX, PASSWORD_REGEX)
username.shrink_to(max_len); .prop_map(move |(mut username, mut password)| {
password.shrink_to(max_len); username.shrink_to(max_len);
ClientUsernamePassword { username, password } password.shrink_to(max_len);
}).boxed() ClientUsernamePassword { username, password }
})
.boxed()
}
}
#[derive(Clone, Debug, Error, PartialEq)]
pub enum ClientUsernamePasswordReadError {
#[error("Underlying buffer read error: {0}")]
ReadError(String),
#[error("Invalid username/password version; expected 1, saw {0}")]
InvalidVersion(u8),
#[error(transparent)]
StringError(#[from] SOCKSv5StringReadError),
}
impl From<std::io::Error> for ClientUsernamePasswordReadError {
fn from(x: std::io::Error) -> ClientUsernamePasswordReadError {
ClientUsernamePasswordReadError::ReadError(format!("{}", x))
}
}
#[derive(Clone, Debug, Error, PartialEq)]
pub enum ClientUsernamePasswordWriteError {
#[error("Underlying buffer read error: {0}")]
WriteError(String),
#[error(transparent)]
StringError(#[from] SOCKSv5StringWriteError),
}
impl From<std::io::Error> for ClientUsernamePasswordWriteError {
fn from(x: std::io::Error) -> ClientUsernamePasswordWriteError {
ClientUsernamePasswordWriteError::WriteError(format!("{}", x))
} }
} }
impl ClientUsernamePassword { impl ClientUsernamePassword {
pub async fn read<R: AsyncRead + Send + Unpin>( pub async fn read<R: AsyncRead + Send + Unpin>(
r: &mut R, r: &mut R,
) -> Result<Self, DeserializationError> { ) -> Result<Self, ClientUsernamePasswordReadError> {
let mut buffer = [0; 1]; let version = r.read_u8().await?;
if r.read(&mut buffer).await? == 0 { if version != 1 {
return Err(DeserializationError::NotEnoughData); return Err(ClientUsernamePasswordReadError::InvalidVersion(version));
} }
if buffer[0] != 1 { let username = SOCKSv5String::read(r).await?.into();
return Err(DeserializationError::InvalidVersion(1, buffer[0])); let password = SOCKSv5String::read(r).await?.into();
}
let username = read_string(r).await?;
let password = read_string(r).await?;
Ok(ClientUsernamePassword { username, password }) Ok(ClientUsernamePassword { username, password })
} }
@@ -61,35 +83,40 @@ impl ClientUsernamePassword {
pub async fn write<W: AsyncWrite + Send + Unpin>( pub async fn write<W: AsyncWrite + Send + Unpin>(
&self, &self,
w: &mut W, w: &mut W,
) -> Result<(), SerializationError> { ) -> Result<(), ClientUsernamePasswordWriteError> {
w.write_all(&[1]).await?; w.write_u8(1).await?;
write_string(&self.username, w).await?; SOCKSv5String::from(self.username.as_str()).write(w).await?;
write_string(&self.password, w).await SOCKSv5String::from(self.password.as_str()).write(w).await?;
Ok(())
} }
} }
standard_roundtrip!(username_password_roundtrips, ClientUsernamePassword); crate::standard_roundtrip!(username_password_roundtrips, ClientUsernamePassword);
#[test] #[tokio::test]
fn check_short_reads() { async fn heck_short_reads() {
let empty = vec![]; let empty = vec![];
let mut cursor = Cursor::new(empty); let mut cursor = Cursor::new(empty);
let ys = ClientUsernamePassword::read(&mut cursor); let ys = ClientUsernamePassword::read(&mut cursor).await;
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); assert!(matches!(
ys,
Err(ClientUsernamePasswordReadError::ReadError(_))
));
let user_only = vec![1, 3, 102, 111, 111]; let user_only = vec![1, 3, 102, 111, 111];
let mut cursor = Cursor::new(user_only); let mut cursor = Cursor::new(user_only);
let ys = ClientUsernamePassword::read(&mut cursor); let ys = ClientUsernamePassword::read(&mut cursor).await;
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); println!("ys: {:?}", ys);
assert!(matches!(
ys,
Err(ClientUsernamePasswordReadError::StringError(_))
));
} }
#[test] #[tokio::test]
fn check_bad_version() { async fn check_bad_version() {
let bad_len = vec![5]; let bad_len = vec![5];
let mut cursor = Cursor::new(bad_len); let mut cursor = Cursor::new(bad_len);
let ys = ClientUsernamePassword::read(&mut cursor); let ys = ClientUsernamePassword::read(&mut cursor).await;
assert_eq!( assert_eq!(Err(ClientUsernamePasswordReadError::InvalidVersion(5)), ys);
Err(DeserializationError::InvalidVersion(1, 5)),
task::block_on(ys)
);
} }

View File

@@ -1,13 +1,7 @@
use crate::errors::{DeserializationError, SerializationError};
use crate::standard_roundtrip;
#[cfg(test)]
use async_std::task;
#[cfg(test)]
use futures::io::Cursor;
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use proptest::proptest;
#[cfg(test)] #[cfg(test)]
use proptest_derive::Arbitrary; use proptest_derive::Arbitrary;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
#[derive(Clone, Debug, Eq, PartialEq)] #[derive(Clone, Debug, Eq, PartialEq)]
#[cfg_attr(test, derive(Arbitrary))] #[cfg_attr(test, derive(Arbitrary))]
@@ -15,6 +9,32 @@ pub struct ServerAuthResponse {
pub success: bool, pub success: bool,
} }
#[derive(Clone, Debug, Error, PartialEq)]
pub enum ServerAuthResponseReadError {
#[error("Underlying buffer read error: {0}")]
ReadError(String),
#[error("Invalid username/password version; expected 1, saw {0}")]
InvalidVersion(u8),
}
impl From<std::io::Error> for ServerAuthResponseReadError {
fn from(x: std::io::Error) -> ServerAuthResponseReadError {
ServerAuthResponseReadError::ReadError(format!("{}", x))
}
}
#[derive(Clone, Debug, Error, PartialEq)]
pub enum ServerAuthResponseWriteError {
#[error("Underlying buffer read error: {0}")]
WriteError(String),
}
impl From<std::io::Error> for ServerAuthResponseWriteError {
fn from(x: std::io::Error) -> ServerAuthResponseWriteError {
ServerAuthResponseWriteError::WriteError(format!("{}", x))
}
}
impl ServerAuthResponse { impl ServerAuthResponse {
pub fn success() -> ServerAuthResponse { pub fn success() -> ServerAuthResponse {
ServerAuthResponse { success: true } ServerAuthResponse { success: true }
@@ -26,30 +46,22 @@ impl ServerAuthResponse {
pub async fn read<R: AsyncRead + Send + Unpin>( pub async fn read<R: AsyncRead + Send + Unpin>(
r: &mut R, r: &mut R,
) -> Result<Self, DeserializationError> { ) -> Result<Self, ServerAuthResponseReadError> {
let mut buffer = [0; 1]; let version = r.read_u8().await?;
if r.read(&mut buffer).await? == 0 { if version != 1 {
return Err(DeserializationError::NotEnoughData); return Err(ServerAuthResponseReadError::InvalidVersion(version));
}
if buffer[0] != 1 {
return Err(DeserializationError::InvalidVersion(1, buffer[0]));
}
if r.read(&mut buffer).await? == 0 {
return Err(DeserializationError::NotEnoughData);
} }
Ok(ServerAuthResponse { Ok(ServerAuthResponse {
success: buffer[0] == 0, success: r.read_u8().await? == 0,
}) })
} }
pub async fn write<W: AsyncWrite + Send + Unpin>( pub async fn write<W: AsyncWrite + Send + Unpin>(
&self, &self,
w: &mut W, w: &mut W,
) -> Result<(), SerializationError> { ) -> Result<(), ServerAuthResponseWriteError> {
w.write_all(&[1]).await?; w.write_all(&[1]).await?;
w.write_all(&[if self.success { 0x00 } else { 0xde }]) w.write_all(&[if self.success { 0x00 } else { 0xde }])
.await?; .await?;
@@ -57,28 +69,29 @@ impl ServerAuthResponse {
} }
} }
standard_roundtrip!(server_auth_response, ServerAuthResponse); crate::standard_roundtrip!(server_auth_response, ServerAuthResponse);
#[tokio::test]
async fn check_short_reads() {
use std::io::Cursor;
#[test]
fn check_short_reads() {
let empty = vec![]; let empty = vec![];
let mut cursor = Cursor::new(empty); let mut cursor = Cursor::new(empty);
let ys = ServerAuthResponse::read(&mut cursor); let ys = ServerAuthResponse::read(&mut cursor).await;
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); assert!(matches!(ys, Err(ServerAuthResponseReadError::ReadError(_))));
let no_len = vec![1]; let no_len = vec![1];
let mut cursor = Cursor::new(no_len); let mut cursor = Cursor::new(no_len);
let ys = ServerAuthResponse::read(&mut cursor); let ys = ServerAuthResponse::read(&mut cursor).await;
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); assert!(matches!(ys, Err(ServerAuthResponseReadError::ReadError(_))));
} }
#[test] #[tokio::test]
fn check_bad_version() { async fn check_bad_version() {
use std::io::Cursor;
let no_len = vec![6, 1]; let no_len = vec![6, 1];
let mut cursor = Cursor::new(no_len); let mut cursor = Cursor::new(no_len);
let ys = ServerAuthResponse::read(&mut cursor); let ys = ServerAuthResponse::read(&mut cursor).await;
assert_eq!( assert_eq!(Err(ServerAuthResponseReadError::InvalidVersion(6)), ys);
Err(DeserializationError::InvalidVersion(1, 6)),
task::block_on(ys)
);
} }

View File

@@ -1,16 +1,12 @@
#[cfg(test)] use crate::messages::authentication_method::{
use crate::errors::AuthenticationDeserializationError; AuthenticationMethod, AuthenticationMethodReadError, AuthenticationMethodWriteError,
use crate::errors::{DeserializationError, SerializationError}; };
use crate::messages::AuthenticationMethod;
use crate::standard_roundtrip;
#[cfg(test)]
use async_std::task;
#[cfg(test)]
use futures::io::Cursor;
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use proptest::proptest;
#[cfg(test)] #[cfg(test)]
use proptest_derive::Arbitrary; use proptest_derive::Arbitrary;
#[cfg(test)]
use std::io::Cursor;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
#[derive(Clone, Debug, Eq, PartialEq)] #[derive(Clone, Debug, Eq, PartialEq)]
#[cfg_attr(test, derive(Arbitrary))] #[cfg_attr(test, derive(Arbitrary))]
@@ -18,6 +14,36 @@ pub struct ServerChoice {
pub chosen_method: AuthenticationMethod, pub chosen_method: AuthenticationMethod,
} }
#[derive(Clone, Debug, Error, PartialEq)]
pub enum ServerChoiceReadError {
#[error(transparent)]
AuthMethodError(#[from] AuthenticationMethodReadError),
#[error("Error in underlying buffer: {0}")]
ReadError(String),
#[error("Invalid version; expected 5, got {0}")]
InvalidVersion(u8),
}
impl From<std::io::Error> for ServerChoiceReadError {
fn from(x: std::io::Error) -> ServerChoiceReadError {
ServerChoiceReadError::ReadError(format!("{}", x))
}
}
#[derive(Clone, Debug, Error, PartialEq)]
pub enum ServerChoiceWriteError {
#[error(transparent)]
AuthMethodError(#[from] AuthenticationMethodWriteError),
#[error("Error in underlying buffer: {0}")]
WriteError(String),
}
impl From<std::io::Error> for ServerChoiceWriteError {
fn from(x: std::io::Error) -> ServerChoiceWriteError {
ServerChoiceWriteError::WriteError(format!("{}", x))
}
}
impl ServerChoice { impl ServerChoice {
pub fn rejection() -> ServerChoice { pub fn rejection() -> ServerChoice {
ServerChoice { ServerChoice {
@@ -33,15 +59,11 @@ impl ServerChoice {
pub async fn read<R: AsyncRead + Send + Unpin>( pub async fn read<R: AsyncRead + Send + Unpin>(
r: &mut R, r: &mut R,
) -> Result<Self, DeserializationError> { ) -> Result<Self, ServerChoiceReadError> {
let mut buffer = [0; 1]; let version = r.read_u8().await?;
if r.read(&mut buffer).await? == 0 { if version != 5 {
return Err(DeserializationError::NotEnoughData); return Err(ServerChoiceReadError::InvalidVersion(version));
}
if buffer[0] != 5 {
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
} }
let chosen_method = AuthenticationMethod::read(r).await?; let chosen_method = AuthenticationMethod::read(r).await?;
@@ -52,39 +74,32 @@ impl ServerChoice {
pub async fn write<W: AsyncWrite + Send + Unpin>( pub async fn write<W: AsyncWrite + Send + Unpin>(
&self, &self,
w: &mut W, w: &mut W,
) -> Result<(), SerializationError> { ) -> Result<(), ServerChoiceWriteError> {
w.write_all(&[5]).await?; w.write_u8(5).await?;
self.chosen_method.write(w).await self.chosen_method.write(w).await?;
Ok(())
} }
} }
standard_roundtrip!(server_choice_roundtrips, ServerChoice); crate::standard_roundtrip!(server_choice_roundtrips, ServerChoice);
#[test] #[tokio::test]
fn check_short_reads() { async fn check_short_reads() {
let empty = vec![]; let empty = vec![];
let mut cursor = Cursor::new(empty); let mut cursor = Cursor::new(empty);
let ys = ServerChoice::read(&mut cursor); let ys = ServerChoice::read(&mut cursor).await;
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); assert!(matches!(ys, Err(ServerChoiceReadError::ReadError(_))));
let bad_len = vec![5]; let bad_len = vec![5];
let mut cursor = Cursor::new(bad_len); let mut cursor = Cursor::new(bad_len);
let ys = ServerChoice::read(&mut cursor); let ys = ServerChoice::read(&mut cursor).await;
assert_eq!( assert!(matches!(ys, Err(ServerChoiceReadError::AuthMethodError(_))));
Err(DeserializationError::AuthenticationMethodError(
AuthenticationDeserializationError::NoDataFound
)),
task::block_on(ys)
);
} }
#[test] #[tokio::test]
fn check_bad_version() { async fn check_bad_version() {
let no_len = vec![9, 1]; let no_len = vec![9, 1];
let mut cursor = Cursor::new(no_len); let mut cursor = Cursor::new(no_len);
let ys = ServerChoice::read(&mut cursor); let ys = ServerChoice::read(&mut cursor).await;
assert_eq!( assert_eq!(Err(ServerChoiceReadError::InvalidVersion(9)), ys);
Err(DeserializationError::InvalidVersion(5, 9)),
task::block_on(ys)
);
} }

View File

@@ -1,21 +1,10 @@
use crate::errors::{DeserializationError, SerializationError}; use crate::address::{SOCKSv5Address, SOCKSv5AddressReadError, SOCKSv5AddressWriteError};
use crate::network::generic::IntoErrorResponse;
use crate::network::SOCKSv5Address;
use crate::serialize::read_amt;
use crate::standard_roundtrip;
#[cfg(test)]
use async_std::io::ErrorKind;
#[cfg(test)]
use async_std::task;
#[cfg(test)]
use futures::io::Cursor;
use futures::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use log::warn;
use proptest::proptest;
#[cfg(test)] #[cfg(test)]
use proptest_derive::Arbitrary; use proptest_derive::Arbitrary;
use std::net::Ipv4Addr; #[cfg(test)]
use std::io::Cursor;
use thiserror::Error; use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
#[derive(Clone, Debug, Eq, Error, PartialEq)] #[derive(Clone, Debug, Eq, Error, PartialEq)]
#[cfg_attr(test, derive(Arbitrary))] #[cfg_attr(test, derive(Arbitrary))]
@@ -40,12 +29,6 @@ pub enum ServerResponseStatus {
AddressTypeNotSupported, AddressTypeNotSupported,
} }
impl IntoErrorResponse for ServerResponseStatus {
fn into_response(&self) -> ServerResponseStatus {
self.clone()
}
}
#[derive(Clone, Debug, Eq, PartialEq)] #[derive(Clone, Debug, Eq, PartialEq)]
#[cfg_attr(test, derive(Arbitrary))] #[cfg_attr(test, derive(Arbitrary))]
pub struct ServerResponse { pub struct ServerResponse {
@@ -54,33 +37,57 @@ pub struct ServerResponse {
pub bound_port: u16, pub bound_port: u16,
} }
impl ServerResponse { #[derive(Clone, Debug, Error, PartialEq)]
pub fn error<E: IntoErrorResponse>(resp: &E) -> ServerResponse { pub enum ServerResponseReadError {
ServerResponse { #[error("Error reading from underlying buffer: {0}")]
status: resp.into_response(), ReadError(String),
bound_address: SOCKSv5Address::IP4(Ipv4Addr::new(0, 0, 0, 0)), #[error(transparent)]
bound_port: 0, AddressReadError(#[from] SOCKSv5AddressReadError),
} #[error("Invalid version; expected 5, got {0}")]
InvalidVersion(u8),
#[error("Invalid reserved byte; saw {0}, should be 0")]
InvalidReservedByte(u8),
#[error("Invalid (or just unknown) server response value {0}")]
InvalidServerResponse(u8),
}
impl From<std::io::Error> for ServerResponseReadError {
fn from(x: std::io::Error) -> ServerResponseReadError {
ServerResponseReadError::ReadError(format!("{}", x))
}
}
#[derive(Clone, Debug, Error, PartialEq)]
pub enum ServerResponseWriteError {
#[error("Error reading from underlying buffer: {0}")]
WriteError(String),
#[error(transparent)]
AddressWriteError(#[from] SOCKSv5AddressWriteError),
}
impl From<std::io::Error> for ServerResponseWriteError {
fn from(x: std::io::Error) -> ServerResponseWriteError {
ServerResponseWriteError::WriteError(format!("{}", x))
} }
} }
impl ServerResponse { impl ServerResponse {
pub async fn read<R: AsyncRead + Send + Unpin>( pub async fn read<R: AsyncRead + Send + Unpin>(
r: &mut R, r: &mut R,
) -> Result<Self, DeserializationError> { ) -> Result<Self, ServerResponseReadError> {
let mut buffer = [0; 3]; let version = r.read_u8().await?;
if version != 5 {
read_amt(r, 3, &mut buffer).await?; return Err(ServerResponseReadError::InvalidVersion(version));
if buffer[0] != 5 {
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
} }
if buffer[2] != 0 { let status_byte = r.read_u8().await?;
warn!(target: "async-socks5", "Hey, this isn't terrible, but the server is sending invalid reserved bytes.");
let reserved_byte = r.read_u8().await?;
if reserved_byte != 0 {
return Err(ServerResponseReadError::InvalidReservedByte(reserved_byte));
} }
let status = match buffer[1] { let status = match status_byte {
0x00 => ServerResponseStatus::RequestGranted, 0x00 => ServerResponseStatus::RequestGranted,
0x01 => ServerResponseStatus::GeneralFailure, 0x01 => ServerResponseStatus::GeneralFailure,
0x02 => ServerResponseStatus::ConnectionNotAllowedByRule, 0x02 => ServerResponseStatus::ConnectionNotAllowedByRule,
@@ -90,12 +97,11 @@ impl ServerResponse {
0x06 => ServerResponseStatus::TTLExpired, 0x06 => ServerResponseStatus::TTLExpired,
0x07 => ServerResponseStatus::CommandNotSupported, 0x07 => ServerResponseStatus::CommandNotSupported,
0x08 => ServerResponseStatus::AddressTypeNotSupported, 0x08 => ServerResponseStatus::AddressTypeNotSupported,
x => return Err(DeserializationError::InvalidServerResponse(x)), x => return Err(ServerResponseReadError::InvalidServerResponse(x)),
}; };
let bound_address = SOCKSv5Address::read(r).await?; let bound_address = SOCKSv5Address::read(r).await?;
read_amt(r, 2, &mut buffer).await?; let bound_port = r.read_u16().await?;
let bound_port = ((buffer[0] as u16) << 8) + (buffer[1] as u16);
Ok(ServerResponse { Ok(ServerResponse {
status, status,
@@ -107,7 +113,9 @@ impl ServerResponse {
pub async fn write<W: AsyncWrite + Send + Unpin>( pub async fn write<W: AsyncWrite + Send + Unpin>(
&self, &self,
w: &mut W, w: &mut W,
) -> Result<(), SerializationError> { ) -> Result<(), ServerResponseWriteError> {
w.write_u8(5).await?;
let status_code = match self.status { let status_code = match self.status {
ServerResponseStatus::RequestGranted => 0x00, ServerResponseStatus::RequestGranted => 0x00,
ServerResponseStatus::GeneralFailure => 0x01, ServerResponseStatus::GeneralFailure => 0x01,
@@ -119,59 +127,61 @@ impl ServerResponse {
ServerResponseStatus::CommandNotSupported => 0x07, ServerResponseStatus::CommandNotSupported => 0x07,
ServerResponseStatus::AddressTypeNotSupported => 0x08, ServerResponseStatus::AddressTypeNotSupported => 0x08,
}; };
w.write_u8(status_code).await?;
w.write_all(&[5, status_code, 0]).await?; w.write_u8(0).await?;
self.bound_address.write(w).await?; self.bound_address.write(w).await?;
w.write_all(&[ w.write_u16(self.bound_port).await?;
(self.bound_port >> 8) as u8,
(self.bound_port & 0xffu16) as u8, Ok(())
])
.await
.map_err(SerializationError::IOError)
} }
} }
standard_roundtrip!(server_response_roundtrips, ServerResponse); crate::standard_roundtrip!(server_response_roundtrips, ServerResponse);
#[test] #[tokio::test]
fn check_short_reads() { async fn check_short_reads() {
let empty = vec![]; let empty = vec![];
let mut cursor = Cursor::new(empty); let mut cursor = Cursor::new(empty);
let ys = ServerResponse::read(&mut cursor); let ys = ServerResponse::read(&mut cursor).await;
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); assert!(matches!(ys, Err(ServerResponseReadError::ReadError(_))));
} }
#[test] #[tokio::test]
fn check_bad_version() { async fn check_bad_version() {
let bad_ver = vec![6, 1, 1]; let bad_ver = vec![6, 1, 1];
let mut cursor = Cursor::new(bad_ver); let mut cursor = Cursor::new(bad_ver);
let ys = ServerResponse::read(&mut cursor); let ys = ServerResponse::read(&mut cursor).await;
assert_eq!( assert_eq!(Err(ServerResponseReadError::InvalidVersion(6)), ys);
Err(DeserializationError::InvalidVersion(5, 6)),
task::block_on(ys)
);
} }
#[test] #[tokio::test]
fn check_bad_command() { async fn check_bad_reserved() {
let bad_cmd = vec![5, 32, 0x42]; let bad_cmd = vec![5, 32, 0x42];
let mut cursor = Cursor::new(bad_cmd); let mut cursor = Cursor::new(bad_cmd);
let ys = ServerResponse::read(&mut cursor); let ys = ServerResponse::read(&mut cursor).await;
assert_eq!( assert_eq!(Err(ServerResponseReadError::InvalidReservedByte(0x42)), ys);
Err(DeserializationError::InvalidServerResponse(32)),
task::block_on(ys)
);
} }
#[test] #[tokio::test]
fn short_write_fails_right() { async fn check_bad_command() {
let mut buffer = [0u8; 2]; let bad_cmd = vec![5, 32, 0];
let cmd = ServerResponse::error(&ServerResponseStatus::AddressTypeNotSupported); let mut cursor = Cursor::new(bad_cmd);
let mut cursor = Cursor::new(&mut buffer as &mut [u8]); let ys = ServerResponse::read(&mut cursor).await;
let result = task::block_on(cmd.write(&mut cursor)); assert_eq!(Err(ServerResponseReadError::InvalidServerResponse(32)), ys);
match result { }
Ok(_) => panic!("Mysteriously able to fit > 2 bytes in 2 bytes."),
Err(SerializationError::IOError(x)) => assert_eq!(ErrorKind::WriteZero, x.kind()), #[tokio::test]
Err(e) => panic!("Got the wrong error writing too much data: {}", e), async fn short_write_fails_right() {
} let mut buffer = [0u8; 2];
let cmd = ServerResponse {
status: ServerResponseStatus::AddressTypeNotSupported,
bound_address: SOCKSv5Address::Hostname("tester.com".to_string()),
bound_port: 99,
};
let mut cursor = Cursor::new(&mut buffer as &mut [u8]);
let result = cmd.write(&mut cursor).await;
assert!(matches!(
result,
Err(ServerResponseWriteError::WriteError(_))
));
} }

117
src/messages/string.rs Normal file
View File

@@ -0,0 +1,117 @@
#[cfg(test)]
use proptest::prelude::{Arbitrary, BoxedStrategy, Strategy};
use std::convert::TryFrom;
use std::string::FromUtf8Error;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
#[derive(Debug, PartialEq)]
pub struct SOCKSv5String(String);
#[cfg(test)]
const STRING_REGEX: &str = "[a-zA-Z0-9_.|!@#$%^]+";
#[cfg(test)]
impl Arbitrary for SOCKSv5String {
type Parameters = Option<u16>;
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
let max_len = args.unwrap_or(32) as usize;
STRING_REGEX
.prop_map(move |mut str| {
str.shrink_to(max_len);
SOCKSv5String(str)
})
.boxed()
}
}
#[derive(Clone, Debug, Error, PartialEq)]
pub enum SOCKSv5StringReadError {
#[error("Underlying buffer read error: {0}")]
ReadError(String),
#[error("SOCKSv5 string encoding error; encountered empty string (?)")]
ZeroStringLength,
#[error("Invalid UTF-8 string: {0}")]
InvalidUtf8Error(#[from] FromUtf8Error),
}
impl From<std::io::Error> for SOCKSv5StringReadError {
fn from(x: std::io::Error) -> SOCKSv5StringReadError {
SOCKSv5StringReadError::ReadError(format!("{}", x))
}
}
#[derive(Clone, Debug, Error, PartialEq)]
pub enum SOCKSv5StringWriteError {
#[error("Underlying buffer write error: {0}")]
WriteError(String),
#[error("String too large to encode according to SOCKSv5 reuls ({0} bytes long)")]
TooBig(usize),
#[error("Cannot serialize the empty string in SOCKSv5")]
ZeroStringLength,
}
impl From<std::io::Error> for SOCKSv5StringWriteError {
fn from(x: std::io::Error) -> SOCKSv5StringWriteError {
SOCKSv5StringWriteError::WriteError(format!("{}", x))
}
}
impl SOCKSv5String {
pub async fn read<R: AsyncRead + Unpin>(r: &mut R) -> Result<Self, SOCKSv5StringReadError> {
let length = r.read_u8().await? as usize;
if length == 0 {
return Err(SOCKSv5StringReadError::ZeroStringLength);
}
let mut bytestring = vec![0; length];
r.read_exact(&mut bytestring).await?;
Ok(SOCKSv5String(String::from_utf8(bytestring)?))
}
pub async fn write<W: AsyncWrite + Unpin>(
&self,
w: &mut W,
) -> Result<(), SOCKSv5StringWriteError> {
let bytestring = self.0.as_bytes();
if bytestring.is_empty() {
return Err(SOCKSv5StringWriteError::ZeroStringLength);
}
let length = match u8::try_from(bytestring.len()) {
Err(_) => return Err(SOCKSv5StringWriteError::TooBig(bytestring.len())),
Ok(x) => x,
};
w.write_u8(length).await?;
w.write_all(bytestring).await?;
Ok(())
}
}
impl From<String> for SOCKSv5String {
fn from(x: String) -> Self {
SOCKSv5String(x)
}
}
impl<'a> From<&'a str> for SOCKSv5String {
fn from(x: &str) -> Self {
SOCKSv5String(x.to_string())
}
}
impl From<SOCKSv5String> for String {
fn from(x: SOCKSv5String) -> Self {
x.0
}
}
crate::standard_roundtrip!(socks_string_roundtrips, SOCKSv5String);

View File

@@ -1,16 +0,0 @@
#[doc(hidden)]
#[macro_export]
macro_rules! standard_roundtrip {
($name: ident, $t: ty) => {
proptest! {
#[test]
fn $name(xs: $t) {
let mut buffer = vec![];
task::block_on(xs.write(&mut buffer)).unwrap();
let mut cursor = Cursor::new(buffer);
let ys = <$t>::read(&mut cursor);
assert_eq!(xs, task::block_on(ys).unwrap());
}
}
};
}

View File

@@ -1,10 +0,0 @@
pub mod address;
pub mod datagram;
pub mod generic;
pub mod listener;
pub mod standard;
pub mod stream;
pub mod testing;
pub use crate::network::address::SOCKSv5Address;
pub use crate::network::standard::Builtin;

View File

@@ -1,264 +0,0 @@
use crate::errors::{DeserializationError, SerializationError};
use crate::serialize::{read_amt, read_string, write_string};
use crate::standard_roundtrip;
#[cfg(test)]
use async_std::task;
#[cfg(test)]
use futures::io::Cursor;
use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use proptest::prelude::proptest;
#[cfg(test)]
use proptest::prelude::{Arbitrary, BoxedStrategy, Strategy, any, prop_oneof};
use std::convert::TryFrom;
use std::fmt;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use thiserror::Error;
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum SOCKSv5Address {
IP4(Ipv4Addr),
IP6(Ipv6Addr),
Name(String),
}
#[cfg(test)]
const HOSTNAME_REGEX: &str = "[a-zA-Z0-9_.]+";
#[cfg(test)]
impl Arbitrary for SOCKSv5Address {
type Parameters = Option<u16>;
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
let max_len = args.unwrap_or(32) as usize;
prop_oneof![
any::<Ipv4Addr>().prop_map(SOCKSv5Address::IP4),
any::<Ipv6Addr>().prop_map(SOCKSv5Address::IP6),
HOSTNAME_REGEX.prop_map(move |mut hostname| {
hostname.shrink_to(max_len);
SOCKSv5Address::Name(hostname)
}),
].boxed()
}
}
#[derive(Error, Debug, PartialEq)]
pub enum AddressConversionError {
#[error("Couldn't convert IPv4 address into destination type")]
CouldntConvertIP4,
#[error("Couldn't convert IPv6 address into destination type")]
CouldntConvertIP6,
#[error("Couldn't convert name into destination type")]
CouldntConvertName,
}
impl From<IpAddr> for SOCKSv5Address {
fn from(x: IpAddr) -> SOCKSv5Address {
match x {
IpAddr::V4(a) => SOCKSv5Address::IP4(a),
IpAddr::V6(a) => SOCKSv5Address::IP6(a),
}
}
}
impl TryFrom<SOCKSv5Address> for IpAddr {
type Error = AddressConversionError;
fn try_from(value: SOCKSv5Address) -> Result<Self, Self::Error> {
match value {
SOCKSv5Address::IP4(a) => Ok(IpAddr::V4(a)),
SOCKSv5Address::IP6(a) => Ok(IpAddr::V6(a)),
SOCKSv5Address::Name(_) => Err(AddressConversionError::CouldntConvertName),
}
}
}
impl From<Ipv4Addr> for SOCKSv5Address {
fn from(x: Ipv4Addr) -> Self {
SOCKSv5Address::IP4(x)
}
}
impl TryFrom<SOCKSv5Address> for Ipv4Addr {
type Error = AddressConversionError;
fn try_from(value: SOCKSv5Address) -> Result<Self, Self::Error> {
match value {
SOCKSv5Address::IP4(a) => Ok(a),
SOCKSv5Address::IP6(_) => Err(AddressConversionError::CouldntConvertIP6),
SOCKSv5Address::Name(_) => Err(AddressConversionError::CouldntConvertName),
}
}
}
impl From<Ipv6Addr> for SOCKSv5Address {
fn from(x: Ipv6Addr) -> Self {
SOCKSv5Address::IP6(x)
}
}
impl TryFrom<SOCKSv5Address> for Ipv6Addr {
type Error = AddressConversionError;
fn try_from(value: SOCKSv5Address) -> Result<Self, Self::Error> {
match value {
SOCKSv5Address::IP4(_) => Err(AddressConversionError::CouldntConvertIP4),
SOCKSv5Address::IP6(a) => Ok(a),
SOCKSv5Address::Name(_) => Err(AddressConversionError::CouldntConvertName),
}
}
}
impl From<String> for SOCKSv5Address {
fn from(x: String) -> Self {
SOCKSv5Address::Name(x)
}
}
impl<'a> From<&'a str> for SOCKSv5Address {
fn from(x: &str) -> SOCKSv5Address {
SOCKSv5Address::Name(x.to_string())
}
}
impl fmt::Display for SOCKSv5Address {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
SOCKSv5Address::IP4(a) => write!(f, "{}", a),
SOCKSv5Address::IP6(a) => write!(f, "{}", a),
SOCKSv5Address::Name(a) => write!(f, "{}", a),
}
}
}
impl SOCKSv5Address {
pub async fn read<R: AsyncRead + Send + Unpin>(
r: &mut R,
) -> Result<Self, DeserializationError> {
let mut byte_buffer = [0u8; 1];
let amount_read = r.read(&mut byte_buffer).await?;
if amount_read == 0 {
return Err(DeserializationError::NotEnoughData);
}
match byte_buffer[0] {
1 => {
let mut addr_buffer = [0; 4];
read_amt(r, 4, &mut addr_buffer).await?;
Ok(SOCKSv5Address::IP4(Ipv4Addr::from(addr_buffer)))
}
3 => {
let name = read_string(r).await?;
Ok(SOCKSv5Address::Name(name))
}
4 => {
let mut addr_buffer = [0; 16];
read_amt(r, 16, &mut addr_buffer).await?;
Ok(SOCKSv5Address::IP6(Ipv6Addr::from(addr_buffer)))
}
x => Err(DeserializationError::InvalidAddressType(x)),
}
}
pub async fn write<W: AsyncWrite + Send + Unpin>(
&self,
w: &mut W,
) -> Result<(), SerializationError> {
match self {
SOCKSv5Address::IP4(x) => {
w.write_all(&[1]).await?;
w.write_all(&x.octets())
.await
.map_err(SerializationError::IOError)
}
SOCKSv5Address::IP6(x) => {
w.write_all(&[4]).await?;
w.write_all(&x.octets())
.await
.map_err(SerializationError::IOError)
}
SOCKSv5Address::Name(x) => {
w.write_all(&[3]).await?;
write_string(x, w).await
}
}
}
}
pub trait HasLocalAddress {
fn local_addr(&self) -> (SOCKSv5Address, u16);
}
standard_roundtrip!(address_roundtrips, SOCKSv5Address);
proptest! {
#[test]
fn ip_conversion(x: IpAddr) {
match x {
IpAddr::V4(ref a) =>
assert_eq!(Err(AddressConversionError::CouldntConvertIP4),
Ipv6Addr::try_from(SOCKSv5Address::from(*a))),
IpAddr::V6(ref a) =>
assert_eq!(Err(AddressConversionError::CouldntConvertIP6),
Ipv4Addr::try_from(SOCKSv5Address::from(*a))),
}
assert_eq!(x, IpAddr::try_from(SOCKSv5Address::from(x)).unwrap());
}
#[test]
fn ip4_conversion(x: Ipv4Addr) {
assert_eq!(x, Ipv4Addr::try_from(SOCKSv5Address::from(x)).unwrap());
}
#[test]
fn ip6_conversion(x: Ipv6Addr) {
assert_eq!(x, Ipv6Addr::try_from(SOCKSv5Address::from(x)).unwrap());
}
#[test]
fn display_matches(x: SOCKSv5Address) {
match x {
SOCKSv5Address::IP4(a) => assert_eq!(format!("{}", a), format!("{}", x)),
SOCKSv5Address::IP6(a) => assert_eq!(format!("{}", a), format!("{}", x)),
SOCKSv5Address::Name(ref a) => assert_eq!(*a, x.to_string()),
}
}
#[test]
fn bad_read_key(x: u8) {
match x {
1 | 3 | 4 => {}
_ => {
let buffer = [x, 0, 1, 2, 9, 10];
let mut cursor = Cursor::new(buffer);
let meh = SOCKSv5Address::read(&mut cursor);
assert_eq!(Err(DeserializationError::InvalidAddressType(x)), task::block_on(meh));
}
}
}
}
#[test]
fn domain_name_sanity() {
let name = "uhsure.com";
let strname = name.to_string();
let addr1 = SOCKSv5Address::from(name);
let addr2 = SOCKSv5Address::from(strname);
assert_eq!(addr1, addr2);
assert_eq!(
Err(AddressConversionError::CouldntConvertName),
IpAddr::try_from(addr1.clone())
);
assert_eq!(
Err(AddressConversionError::CouldntConvertName),
Ipv4Addr::try_from(addr1.clone())
);
assert_eq!(
Err(AddressConversionError::CouldntConvertName),
Ipv6Addr::try_from(addr1)
);
}

View File

@@ -1,43 +0,0 @@
use crate::network::address::{HasLocalAddress, SOCKSv5Address};
use async_trait::async_trait;
#[async_trait]
pub trait Datagramlike: Send + Sync + HasLocalAddress {
type Error;
async fn send_to(
&self,
buf: &[u8],
addr: SOCKSv5Address,
port: u16,
) -> Result<usize, Self::Error>;
async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SOCKSv5Address, u16), Self::Error>;
}
pub struct GenericDatagramSocket<E> {
pub internal: Box<dyn Datagramlike<Error = E>>,
}
#[async_trait]
impl<E> Datagramlike for GenericDatagramSocket<E> {
type Error = E;
async fn send_to(
&self,
buf: &[u8],
addr: SOCKSv5Address,
port: u16,
) -> Result<usize, Self::Error> {
Ok(self.internal.send_to(buf, addr, port).await?)
}
async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SOCKSv5Address, u16), Self::Error> {
Ok(self.internal.recv_from(buf).await?)
}
}
impl<E> HasLocalAddress for GenericDatagramSocket<E> {
fn local_addr(&self) -> (SOCKSv5Address, u16) {
self.internal.local_addr()
}
}

View File

@@ -1,61 +0,0 @@
use crate::messages::ServerResponseStatus;
use crate::network::address::SOCKSv5Address;
use crate::network::datagram::GenericDatagramSocket;
use crate::network::listener::GenericListener;
use crate::network::stream::GenericStream;
use async_trait::async_trait;
use std::fmt::{Debug, Display};
#[async_trait]
pub trait Networklike {
/// The error type for things that fail on this network. Apologies in advance
/// for using only one; if you have a use case for separating your errors,
/// please shoot the author(s) and email to split this into multiple types, one
/// for each trait function.
type Error: Debug + Display + IntoErrorResponse + Send;
/// Connect to the given address and port, over this kind of network. The
/// underlying stream should behave somewhat like a TCP stream ... which
/// may be exactly what you're using. However, in order to support tunnelling
/// scenarios (i.e., using another proxy, going through Tor or SSH, etc.) we
/// work generically over any stream-like object.
async fn connect<A: Send + Into<SOCKSv5Address>>(
&mut self,
addr: A,
port: u16,
) -> Result<GenericStream, Self::Error>;
/// Listen for connections on the given address and port, returning a generic
/// listener socket to use in the future.
async fn listen<A: Send + Into<SOCKSv5Address>>(
&mut self,
addr: A,
port: u16,
) -> Result<GenericListener<Self::Error>, Self::Error>;
/// Bind a socket for the purposes of doing some datagram communication. NOTE!
/// this is only for UDP-like communication, not for generic connecting or
/// listening! Maybe obvious from the types, but POSIX has overtrained many
/// of us.
///
/// Recall when using these functions that datagram protocols allow for packet
/// loss and out-of-order delivery. So ... be warned.
async fn bind<A: Send + Into<SOCKSv5Address>>(
&mut self,
addr: A,
port: u16,
) -> Result<GenericDatagramSocket<Self::Error>, Self::Error>;
}
/// This trait is a hack; sorry about that. The thing is, we want to be able to
/// convert Errors from the `Networklike` trait into a `ServerResponseStatus`,
/// but want to do so on references to the error object rather than the actual
/// object. This is for the paired reason that (a) we want to be able to use
/// the errors in multiple places -- for example, to return a value to the client
/// and then also to whoever called the function -- and (b) some common errors
/// (I'm looking at you, `io::Error`) aren't `Clone`. So ... hence this overly-
/// specific trait.
pub trait IntoErrorResponse {
#[allow(clippy::wrong_self_convention)]
fn into_response(&self) -> ServerResponseStatus;
}

View File

@@ -1,29 +0,0 @@
use crate::network::address::{HasLocalAddress, SOCKSv5Address};
use crate::network::stream::GenericStream;
use async_trait::async_trait;
#[async_trait]
pub trait Listenerlike: Send + Sync + HasLocalAddress {
type Error;
async fn accept(&self) -> Result<(GenericStream, SOCKSv5Address, u16), Self::Error>;
}
pub struct GenericListener<E> {
pub internal: Box<dyn Listenerlike<Error = E>>,
}
#[async_trait]
impl<E> Listenerlike for GenericListener<E> {
type Error = E;
async fn accept(&self) -> Result<(GenericStream, SOCKSv5Address, u16), Self::Error> {
Ok(self.internal.accept().await?)
}
}
impl<E> HasLocalAddress for GenericListener<E> {
fn local_addr(&self) -> (SOCKSv5Address, u16) {
self.internal.local_addr()
}
}

View File

@@ -1,240 +0,0 @@
use crate::messages::ServerResponseStatus;
use crate::network::address::{HasLocalAddress, SOCKSv5Address};
use crate::network::datagram::{Datagramlike, GenericDatagramSocket};
use crate::network::generic::Networklike;
use crate::network::listener::{GenericListener, Listenerlike};
use crate::network::stream::{GenericStream, Streamlike};
use async_std::io;
#[cfg(test)]
use async_std::io::ReadExt;
use async_std::net::{TcpListener, TcpStream, UdpSocket};
use async_trait::async_trait;
#[cfg(test)]
use futures::AsyncWriteExt;
use log::error;
use super::generic::IntoErrorResponse;
#[derive(Clone)]
pub struct Builtin {}
impl Builtin {
pub fn new() -> Builtin {
Builtin {}
}
}
impl Default for Builtin {
fn default() -> Self {
Self::new()
}
}
macro_rules! local_address_impl {
($t: ty) => {
impl HasLocalAddress for $t {
fn local_addr(&self) -> (SOCKSv5Address, u16) {
match self.local_addr() {
Ok(a) =>
(SOCKSv5Address::from(a.ip()), a.port()),
Err(e) => {
error!("Couldn't translate (Streamlike) local address to SOCKS local address: {}", e);
(SOCKSv5Address::from("localhost"), 0)
}
}
}
}
};
}
local_address_impl!(TcpStream);
local_address_impl!(TcpListener);
local_address_impl!(UdpSocket);
impl Streamlike for TcpStream {}
#[async_trait]
impl Listenerlike for TcpListener {
type Error = io::Error;
async fn accept(&self) -> Result<(GenericStream, SOCKSv5Address, u16), Self::Error> {
let (base, addrport) = self.accept().await?;
let addr = addrport.ip();
let port = addrport.port();
Ok((GenericStream::new(base), SOCKSv5Address::from(addr), port))
}
}
#[async_trait]
impl Datagramlike for UdpSocket {
type Error = io::Error;
async fn send_to(
&self,
buf: &[u8],
addr: SOCKSv5Address,
port: u16,
) -> Result<usize, Self::Error> {
match addr {
SOCKSv5Address::IP4(a) => self.send_to(buf, (a, port)).await,
SOCKSv5Address::IP6(a) => self.send_to(buf, (a, port)).await,
SOCKSv5Address::Name(n) => self.send_to(buf, (n.as_str(), port)).await,
}
}
async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SOCKSv5Address, u16), Self::Error> {
let (amt, addrport) = self.recv_from(buf).await?;
let addr = addrport.ip();
let port = addrport.port();
Ok((amt, SOCKSv5Address::from(addr), port))
}
}
#[async_trait]
impl Networklike for Builtin {
type Error = io::Error;
async fn connect<A: Send + Into<SOCKSv5Address>>(
&mut self,
addr: A,
port: u16,
) -> Result<GenericStream, Self::Error> {
let target = addr.into();
let base_stream = match target {
SOCKSv5Address::IP4(a) => TcpStream::connect((a, port)).await?,
SOCKSv5Address::IP6(a) => TcpStream::connect((a, port)).await?,
SOCKSv5Address::Name(n) => TcpStream::connect((n.as_str(), port)).await?,
};
Ok(GenericStream::from(base_stream))
}
async fn listen<A: Send + Into<SOCKSv5Address>>(
&mut self,
addr: A,
port: u16,
) -> Result<GenericListener<Self::Error>, Self::Error> {
let target = addr.into();
let base_stream = match target {
SOCKSv5Address::IP4(a) => TcpListener::bind((a, port)).await?,
SOCKSv5Address::IP6(a) => TcpListener::bind((a, port)).await?,
SOCKSv5Address::Name(n) => TcpListener::bind((n.as_str(), port)).await?,
};
Ok(GenericListener {
internal: Box::new(base_stream),
})
}
async fn bind<A: Send + Into<SOCKSv5Address>>(
&mut self,
addr: A,
port: u16,
) -> Result<GenericDatagramSocket<Self::Error>, Self::Error> {
let target = addr.into();
let base_socket = match target {
SOCKSv5Address::IP4(a) => UdpSocket::bind((a, port)).await?,
SOCKSv5Address::IP6(a) => UdpSocket::bind((a, port)).await?,
SOCKSv5Address::Name(n) => UdpSocket::bind((n.as_str(), port)).await?,
};
Ok(GenericDatagramSocket {
internal: Box::new(base_socket),
})
}
}
#[test]
fn check_sanity() {
async_std::task::block_on(async {
// Technically, this is UDP, and UDP is lossy. We're going to assume we're not
// going to get any dropped data along here ... which is a very questionable
// assumption, morally speaking, but probably fine for most purposes.
let mut network = Builtin::new();
let receiver = network
.bind("localhost", 0)
.await
.expect("Failed to bind receiver socket.");
let sender = network
.bind("localhost", 0)
.await
.expect("Failed to bind sender socket.");
let buffer = [0xde, 0xea, 0xbe, 0xef];
let (receiver_addr, receiver_port) = receiver.local_addr();
sender
.send_to(&buffer, receiver_addr, receiver_port)
.await
.expect("Failure sending datagram!");
let mut recvbuffer = [0; 4];
let (s, f, p) = receiver
.recv_from(&mut recvbuffer)
.await
.expect("Didn't receive UDP message?");
let (sender_addr, sender_port) = sender.local_addr();
assert_eq!(s, 4);
assert_eq!(f, sender_addr);
assert_eq!(p, sender_port);
assert_eq!(recvbuffer, buffer);
});
// This whole block should be pretty solid, though, unless the system we're
// on is in a pretty weird place.
let mut network = Builtin::new();
let listener = async_std::task::block_on(network.listen("localhost", 0))
.expect("Couldn't set up listener on localhost");
let (listener_address, listener_port) = listener.local_addr();
let listener_task_handle = async_std::task::spawn(async move {
let (mut stream, addr, port) = listener.accept().await.expect("Didn't get connection");
let mut result_buffer = [0u8; 4];
stream
.read_exact(&mut result_buffer)
.await
.expect("Read failure in TCP test");
(result_buffer, addr, port)
});
let sender_task_handle = async_std::task::spawn(async move {
let mut sender = network
.connect(listener_address, listener_port)
.await
.expect("Coudln't connect to listener?");
let (sender_address, sender_port) = sender.local_addr();
let send_buffer = [0xa, 0xff, 0xab, 0x1e];
sender
.write_all(&send_buffer)
.await
.expect("Couldn't send the write buffer");
sender
.flush()
.await
.expect("Couldn't flush the write buffer");
sender
.close()
.await
.expect("Couldn't close the write buffer");
(sender_address, sender_port)
});
async_std::task::block_on(async {
let (result, result_from, result_from_port) = listener_task_handle.await;
assert_eq!(result, [0xa, 0xff, 0xab, 0x1e]);
let (sender_address, sender_port) = sender_task_handle.await;
assert_eq!(result_from, sender_address);
assert_eq!(result_from_port, sender_port);
});
}
impl IntoErrorResponse for io::Error {
fn into_response(&self) -> ServerResponseStatus {
match self.kind() {
io::ErrorKind::ConnectionRefused => ServerResponseStatus::ConnectionRefused,
io::ErrorKind::NotFound => ServerResponseStatus::HostUnreachable,
_ => ServerResponseStatus::GeneralFailure,
}
}
}

View File

@@ -1,74 +0,0 @@
use crate::network::SOCKSv5Address;
use async_std::task::{Context, Poll};
use futures::io;
use futures::io::{AsyncRead, AsyncWrite};
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use super::address::HasLocalAddress;
pub trait Streamlike: AsyncRead + AsyncWrite + HasLocalAddress + Send + Sync + Unpin {}
#[derive(Clone)]
pub struct GenericStream {
internal: Arc<Mutex<dyn Streamlike>>,
}
impl GenericStream {
pub fn new<T: Streamlike + 'static>(x: T) -> GenericStream {
GenericStream {
internal: Arc::new(Mutex::new(x)),
}
}
}
impl HasLocalAddress for GenericStream {
fn local_addr(&self) -> (SOCKSv5Address, u16) {
let item = self.internal.lock().unwrap();
item.local_addr()
}
}
impl AsyncRead for GenericStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let mut item = self.internal.lock().unwrap();
let pinned = Pin::new(&mut *item);
pinned.poll_read(cx, buf)
}
}
impl AsyncWrite for GenericStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let mut item = self.internal.lock().unwrap();
let pinned = Pin::new(&mut *item);
pinned.poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let mut item = self.internal.lock().unwrap();
let pinned = Pin::new(&mut *item);
pinned.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let mut item = self.internal.lock().unwrap();
let pinned = Pin::new(&mut *item);
pinned.poll_close(cx)
}
}
impl<T: Streamlike + 'static> From<T> for GenericStream {
fn from(x: T) -> GenericStream {
GenericStream {
internal: Arc::new(Mutex::new(x)),
}
}
}

View File

@@ -1,276 +0,0 @@
mod datagram;
mod stream;
use crate::messages::ServerResponseStatus;
use crate::network::address::{HasLocalAddress, SOCKSv5Address};
#[cfg(test)]
use crate::network::datagram::Datagramlike;
use crate::network::datagram::GenericDatagramSocket;
use crate::network::generic::{IntoErrorResponse, Networklike};
use crate::network::listener::{GenericListener, Listenerlike};
use crate::network::stream::GenericStream;
use crate::network::testing::datagram::TestDatagram;
use crate::network::testing::stream::TestingStream;
use async_std::channel::{bounded, Receiver, Sender};
use async_std::sync::{Arc, Mutex};
#[cfg(test)]
use async_std::task;
use async_trait::async_trait;
#[cfg(test)]
use futures::{AsyncReadExt, AsyncWriteExt};
use std::collections::HashMap;
use std::fmt;
/// A "network", based purely on internal Rust datatypes, for testing
/// networking code. This stack operates purely in memory, so shouldn't
/// suffer from any weird networking effects ... which makes it a good
/// functional test, but not great at actually testing real-world failure
/// modes.
#[allow(clippy::type_complexity)]
#[derive(Clone)]
pub struct TestingStack {
tcp_listeners: Arc<Mutex<HashMap<(SOCKSv5Address, u16), Sender<TestingStream>>>>,
udp_sockets: Arc<Mutex<HashMap<(SOCKSv5Address, u16), Sender<(SOCKSv5Address, u16, Vec<u8>)>>>>,
next_random_socket: u16,
}
impl TestingStack {
pub fn new() -> TestingStack {
TestingStack {
tcp_listeners: Arc::new(Mutex::new(HashMap::new())),
udp_sockets: Arc::new(Mutex::new(HashMap::new())),
next_random_socket: 23,
}
}
}
impl Default for TestingStack {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug)]
pub enum TestStackError {
AcceptFailed,
AddressBusy(SOCKSv5Address, u16),
ConnectionFailed,
FailureToSend,
NoTCPHostFound(SOCKSv5Address, u16),
ReceiveFailure,
}
impl fmt::Display for TestStackError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
TestStackError::AcceptFailed => write!(f, "Accept failed; the other side died (?)"),
TestStackError::AddressBusy(ref addr, port) => {
write!(f, "Address {}:{} already in use", addr, port)
}
TestStackError::ConnectionFailed => write!(f, "Couldn't connect to host."),
TestStackError::FailureToSend => write!(
f,
"Weird internal error in testing infrastructure; channel send failed"
),
TestStackError::NoTCPHostFound(ref addr, port) => {
write!(f, "No host found at {} for TCP port {}", addr, port)
}
TestStackError::ReceiveFailure => {
write!(f, "Failed to process a UDP receive (this is weird)")
}
}
}
}
impl IntoErrorResponse for TestStackError {
fn into_response(&self) -> ServerResponseStatus {
ServerResponseStatus::GeneralFailure
}
}
#[async_trait]
impl Networklike for TestingStack {
type Error = TestStackError;
async fn connect<A: Send + Into<SOCKSv5Address>>(
&mut self,
addr: A,
port: u16,
) -> Result<GenericStream, Self::Error> {
let table = self.tcp_listeners.lock().await;
let target = addr.into();
match table.get(&(target.clone(), port)) {
None => Err(TestStackError::NoTCPHostFound(target, port)),
Some(result) => {
let stream = TestingStream::new(target, port);
let retval = stream.invert();
match result.send(stream).await {
Ok(()) => Ok(GenericStream::new(retval)),
Err(_) => Err(TestStackError::FailureToSend),
}
}
}
}
async fn listen<A: Send + Into<SOCKSv5Address>>(
&mut self,
addr: A,
mut port: u16,
) -> Result<GenericListener<Self::Error>, Self::Error> {
let mut table = self.tcp_listeners.lock().await;
let target = addr.into();
let (sender, receiver) = bounded(5);
if port == 0 {
port = self.next_random_socket;
self.next_random_socket += 1;
}
table.insert((target.clone(), port), sender);
Ok(GenericListener {
internal: Box::new(TestListener::new(target, port, receiver)),
})
}
async fn bind<A: Send + Into<SOCKSv5Address>>(
&mut self,
addr: A,
mut port: u16,
) -> Result<GenericDatagramSocket<Self::Error>, Self::Error> {
let mut table = self.udp_sockets.lock().await;
let target = addr.into();
let (sender, receiver) = bounded(5);
if port == 0 {
port = self.next_random_socket;
self.next_random_socket += 1;
}
table.insert((target.clone(), port), sender);
Ok(GenericDatagramSocket {
internal: Box::new(TestDatagram::new(self.clone(), target, port, receiver)),
})
}
}
struct TestListener {
address: SOCKSv5Address,
port: u16,
receiver: Receiver<TestingStream>,
}
impl TestListener {
fn new(address: SOCKSv5Address, port: u16, receiver: Receiver<TestingStream>) -> Self {
TestListener {
address,
port,
receiver,
}
}
}
impl HasLocalAddress for TestListener {
fn local_addr(&self) -> (SOCKSv5Address, u16) {
(self.address.clone(), self.port)
}
}
#[async_trait]
impl Listenerlike for TestListener {
type Error = TestStackError;
async fn accept(&self) -> Result<(GenericStream, SOCKSv5Address, u16), Self::Error> {
match self.receiver.recv().await {
Ok(next) => {
let (addr, port) = next.local_addr();
Ok((GenericStream::new(next), addr, port))
}
Err(_) => Err(TestStackError::AcceptFailed),
}
}
}
#[test]
fn check_udp_sanity() {
task::block_on(async {
let mut network = TestingStack::new();
let receiver = network
.bind("localhost", 0)
.await
.expect("Failed to bind receiver socket.");
let sender = network
.bind("localhost", 0)
.await
.expect("Failed to bind sender socket.");
let buffer = [0xde, 0xea, 0xbe, 0xef];
let (receiver_addr, receiver_port) = receiver.local_addr();
sender
.send_to(&buffer, receiver_addr, receiver_port)
.await
.expect("Failure sending datagram!");
let mut recvbuffer = [0; 4];
let (s, f, p) = receiver
.recv_from(&mut recvbuffer)
.await
.expect("Didn't receive UDP message?");
let (sender_addr, sender_port) = sender.local_addr();
assert_eq!(s, 4);
assert_eq!(f, sender_addr);
assert_eq!(p, sender_port);
assert_eq!(recvbuffer, buffer);
});
}
#[test]
fn check_basic_tcp() {
task::block_on(async {
let mut network = TestingStack::new();
let listener = network
.listen("localhost", 0)
.await
.expect("Couldn't set up listener on localhost");
let (listener_address, listener_port) = listener.local_addr();
let listener_task_handle = task::spawn(async move {
dbg!("Starting listener task!!");
let (mut stream, addr, port) = listener.accept().await.expect("Didn't get connection");
let mut result_buffer = [0u8; 4];
if let Err(e) = stream.read_exact(&mut result_buffer).await {
dbg!("Error reading buffer from stream: {}", e);
} else {
dbg!("made it through read_exact");
}
(result_buffer, addr, port)
});
let sender_task_handle = task::spawn(async move {
let mut sender = network
.connect(listener_address, listener_port)
.await
.expect("Coudln't connect to listener?");
let (sender_address, sender_port) = sender.local_addr();
let send_buffer = [0xa, 0xff, 0xab, 0x1e];
sender
.write_all(&send_buffer)
.await
.expect("Couldn't send the write buffer");
sender
.flush()
.await
.expect("Couldn't flush the write buffer");
sender
.close()
.await
.expect("Couldn't close the write buffer");
(sender_address, sender_port)
});
let (result, result_from, result_from_port) = listener_task_handle.await;
assert_eq!(result, [0xa, 0xff, 0xab, 0x1e]);
let (sender_address, sender_port) = sender_task_handle.await;
assert_eq!(result_from, sender_address);
assert_eq!(result_from_port, sender_port);
});
}

View File

@@ -1,88 +0,0 @@
use crate::network::address::HasLocalAddress;
use crate::network::datagram::Datagramlike;
use crate::network::testing::{TestStackError, TestingStack};
use crate::network::SOCKSv5Address;
use async_std::channel::Receiver;
use async_trait::async_trait;
use std::cmp::Ordering;
pub struct TestDatagram {
context: TestingStack,
my_address: SOCKSv5Address,
my_port: u16,
input_stream: Receiver<(SOCKSv5Address, u16, Vec<u8>)>,
}
impl TestDatagram {
pub fn new(
context: TestingStack,
my_address: SOCKSv5Address,
my_port: u16,
input_stream: Receiver<(SOCKSv5Address, u16, Vec<u8>)>,
) -> Self {
TestDatagram {
context,
my_address,
my_port,
input_stream,
}
}
}
impl HasLocalAddress for TestDatagram {
fn local_addr(&self) -> (SOCKSv5Address, u16) {
(self.my_address.clone(), self.my_port)
}
}
#[async_trait]
impl Datagramlike for TestDatagram {
type Error = TestStackError;
async fn send_to(
&self,
buf: &[u8],
target: SOCKSv5Address,
port: u16,
) -> Result<usize, Self::Error> {
let table = self.context.udp_sockets.lock().await;
match table.get(&(target, port)) {
None => Ok(buf.len()),
Some(sender) => {
sender
.send((self.my_address.clone(), self.my_port, buf.to_vec()))
.await
.map_err(|_| TestStackError::FailureToSend)?;
Ok(buf.len())
}
}
}
async fn recv_from(
&self,
buffer: &mut [u8],
) -> Result<(usize, SOCKSv5Address, u16), Self::Error> {
let (from_addr, from_port, message) = self
.input_stream
.recv()
.await
.map_err(|_| TestStackError::ReceiveFailure)?;
match message.len().cmp(&buffer.len()) {
Ordering::Greater => {
buffer.copy_from_slice(&message[..buffer.len()]);
Ok((message.len(), from_addr, from_port))
}
Ordering::Less => {
(&mut buffer[..message.len()]).copy_from_slice(&message);
Ok((message.len(), from_addr, from_port))
}
Ordering::Equal => {
buffer.copy_from_slice(message.as_ref());
Ok((message.len(), from_addr, from_port))
}
}
}
}

View File

@@ -1,192 +0,0 @@
use crate::network::address::HasLocalAddress;
use crate::network::stream::Streamlike;
use crate::network::SOCKSv5Address;
use async_std::io;
use async_std::io::{Read, Write};
use async_std::task::{Context, Poll, Waker};
use std::cell::UnsafeCell;
use std::pin::Pin;
use std::ptr::NonNull;
use std::sync::atomic::{AtomicBool, Ordering};
#[derive(Clone)]
pub struct TestingStream {
address: SOCKSv5Address,
port: u16,
read_side: NonNull<TestingStreamData>,
write_side: NonNull<TestingStreamData>,
}
unsafe impl Send for TestingStream {}
unsafe impl Sync for TestingStream {}
struct TestingStreamData {
lock: AtomicBool,
writer_dead: AtomicBool,
waiters: UnsafeCell<Vec<Waker>>,
buffer: UnsafeCell<Vec<u8>>,
}
unsafe impl Send for TestingStreamData {}
unsafe impl Sync for TestingStreamData {}
impl TestingStream {
/// Generate a testing stream. Note that this is directional. So, if you want to
/// talk to this stream, you should also generate an `invert()` and pass that to
/// the other thread(s).
pub fn new(address: SOCKSv5Address, port: u16) -> TestingStream {
let read_side_data = TestingStreamData {
lock: AtomicBool::new(false),
writer_dead: AtomicBool::new(false),
waiters: UnsafeCell::new(Vec::new()),
buffer: UnsafeCell::new(Vec::with_capacity(16 * 1024)),
};
let write_side_data = TestingStreamData {
lock: AtomicBool::new(false),
writer_dead: AtomicBool::new(false),
waiters: UnsafeCell::new(Vec::new()),
buffer: UnsafeCell::new(Vec::with_capacity(16 * 1024)),
};
let boxed_rsd = Box::new(read_side_data);
let boxed_wsd = Box::new(write_side_data);
let raw_read_ptr = Box::leak(boxed_rsd);
let raw_write_ptr = Box::leak(boxed_wsd);
TestingStream {
address,
port,
read_side: NonNull::new(raw_read_ptr).unwrap(),
write_side: NonNull::new(raw_write_ptr).unwrap(),
}
}
/// Get the flip side of this stream; reads from the inverted side will catch the writes
/// of the original, etc.
pub fn invert(&self) -> TestingStream {
TestingStream {
address: self.address.clone(),
port: self.port,
read_side: self.write_side,
write_side: self.read_side,
}
}
}
impl TestingStreamData {
fn acquire(&mut self) {
loop {
match self
.lock
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
{
Err(_) => continue,
Ok(_) => return,
}
}
}
fn release(&mut self) {
self.lock.store(false, Ordering::SeqCst);
}
}
impl HasLocalAddress for TestingStream {
fn local_addr(&self) -> (SOCKSv5Address, u16) {
(self.address.clone(), self.port)
}
}
impl Read for TestingStream {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
// so, we're going to spin here, which is less than ideal but should work fine
// in practice. we'll obviously need to be very careful to ensure that we keep
// the stuff internal to this spin really short.
let internals = unsafe { self.read_side.as_mut() };
internals.acquire();
let stream_buffer = internals.buffer.get_mut();
let amount_available = stream_buffer.len();
if amount_available == 0 {
// we wait to do this check until we've determined the buffer is empty,
// so that we make sure to drain any residual stuff in there.
if internals.writer_dead.load(Ordering::SeqCst) {
internals.release();
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::ConnectionReset,
"Writer closed the socket.",
)));
} else {
let waker = cx.waker().clone();
internals.waiters.get_mut().push(waker);
internals.release();
return Poll::Pending;
}
}
let amt_written = if buf.len() >= amount_available {
(&mut buf[0..amount_available]).copy_from_slice(stream_buffer);
stream_buffer.clear();
amount_available
} else {
let amt_to_copy = buf.len();
buf.copy_from_slice(&stream_buffer[0..amt_to_copy]);
stream_buffer.copy_within(amt_to_copy.., 0);
let amt_left = amount_available - amt_to_copy;
stream_buffer.resize(amt_left, 0);
amt_to_copy
};
internals.release();
Poll::Ready(Ok(amt_written))
}
}
impl Write for TestingStream {
fn poll_write(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let internals = unsafe { self.write_side.as_mut() };
internals.acquire();
let stream_buffer = internals.buffer.get_mut();
stream_buffer.extend_from_slice(buf);
for waiter in internals.waiters.get_mut().drain(0..) {
waiter.wake();
}
internals.release();
Poll::Ready(Ok(buf.len()))
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(())) // FIXME: Might consider having this wait until the buffer is empty
}
fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(())) // FIXME: Might consider putting in some open/closed logic here
}
}
impl Streamlike for TestingStream {}
impl Drop for TestingStream {
fn drop(&mut self) {
let internals = unsafe { self.write_side.as_mut() };
internals.writer_dead.store(true, Ordering::SeqCst);
internals.acquire();
for waiter in internals.waiters.get_mut().drain(0..) {
waiter.wake();
}
internals.release();
}
}

View File

@@ -0,0 +1,84 @@
use std::net::SocketAddr;
/// The security parameters that you can assign to the server, to make decisions
/// about the weirdos it accepts as users. It is recommended that you only use
/// wide open connections when you're 100% sure that the server will only be
/// accessible locally.
#[derive(Clone)]
pub struct SecurityParameters {
/// Allow completely unauthenticated connections. You should be very, very
/// careful about setting this to true, especially if you don't provide a
/// guard to ensure that you're getting connections from reasonable places.
pub allow_unauthenticated: bool,
/// An optional function that can serve as a firewall for new connections.
/// Return true if the connection should be allowed to continue, false if
/// it shouldn't. This check happens before any data is read from or written
/// to the connecting party.
pub allow_connection: Option<fn(&SocketAddr) -> bool>,
/// An optional function to check a user name (first argument) and password
/// (second argument). Return true if the username / password is good, false
/// if not.
pub check_password: Option<fn(&str, &str) -> bool>,
/// An optional function to transition the stream from an unencrypted one to
/// an encrypted on. The assumption is you're using something like `rustls`
/// to make this happen; the exact mechanism is outside the scope of this
/// particular crate. If the connection shouldn't be allowed for some reason
/// (a bad certificate or handshake, for example), then return None; otherwise,
/// return the new stream.
pub connect_tls: Option<fn() -> Option<()>>,
}
impl SecurityParameters {
/// Generates a `SecurityParameters` object that's empty. It won't accept
/// anything, because it has no mechanisms it can use to actually authenticate
/// a user and yet won't allow unauthenticated connections.
pub fn new() -> SecurityParameters {
SecurityParameters {
allow_unauthenticated: false,
allow_connection: None,
check_password: None,
connect_tls: None,
}
}
/// Generates a `SecurityParameters` object that does not, in any way,
/// restrict who can log in. It also will not induce any transition into
/// TLS. Use this at your own risk ... or, really, just don't use this,
/// ever, and certainly not in production.
pub fn unrestricted() -> SecurityParameters {
SecurityParameters {
allow_unauthenticated: true,
allow_connection: None,
check_password: None,
connect_tls: None,
}
}
/// Use the provided function to check incoming connections before proceeding
/// with the rest of the handshake.
pub fn check_connections(mut self, checker: fn(&SocketAddr) -> bool) -> SecurityParameters {
self.allow_connection = Some(checker);
self
}
/// Use the provided function to check usernames and passwords provided
/// to the server.
pub fn password_check(mut self, checker: fn(&str, &str) -> bool) -> SecurityParameters {
self.check_password = Some(checker);
self
}
/// Use the provide function to validate a TLS connection, and transition it
/// to the new stream type. If the handshake fails, return `None` instead of
/// `Some`. (And maybe log it somewhere, you know.)
pub fn tls_converter(mut self, converter: fn() -> Option<()>) -> SecurityParameters {
self.connect_tls = Some(converter);
self
}
}
impl Default for SecurityParameters {
fn default() -> Self {
Self::new()
}
}

View File

@@ -1,59 +0,0 @@
use crate::errors::{DeserializationError, SerializationError};
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
pub async fn read_string<R: AsyncRead + Send + Unpin>(
r: &mut R,
) -> Result<String, DeserializationError> {
let mut length_buffer = [0; 1];
if r.read(&mut length_buffer).await? == 0 {
return Err(DeserializationError::NotEnoughData);
}
let target = length_buffer[0] as usize;
if target == 0 {
return Err(DeserializationError::InvalidEmptyString);
}
let mut bytestring = vec![0; target];
read_amt(r, target, &mut bytestring).await?;
Ok(String::from_utf8(bytestring)?)
}
pub async fn write_string<W: AsyncWrite + Send + Unpin>(
s: &str,
w: &mut W,
) -> Result<(), SerializationError> {
let bytestring = s.as_bytes();
if bytestring.is_empty() || bytestring.len() > 255 {
return Err(SerializationError::InvalidStringLength(s.to_string()));
}
w.write_all(&[bytestring.len() as u8]).await?;
w.write_all(bytestring)
.await
.map_err(SerializationError::IOError)
}
pub async fn read_amt<R: AsyncRead + Send + Unpin>(
r: &mut R,
amt: usize,
buffer: &mut [u8],
) -> Result<(), DeserializationError> {
let mut amt_read = 0;
while amt_read < amt {
let chunk_amt = r.read(&mut buffer[amt_read..]).await?;
if chunk_amt == 0 {
return Err(DeserializationError::NotEnoughData);
}
amt_read += chunk_amt;
}
Ok(())
}

View File

@@ -1,157 +1,59 @@
//! An implementation of a SOCKSv5 server, parameterizable by the security parameters use std::net::SocketAddr;
//! and network stack you want to use. You should implement the server by first
//! setting up the `SecurityParameters`, then initializing the server object, and use crate::address::SOCKSv5Address;
//! then running it, as follows:
//!
//! ```
//! use async_socks5::network::Builtin;
//! use async_socks5::server::{SecurityParameters, SOCKSv5Server};
//! use std::io;
//!
//! async {
//! let parameters = SecurityParameters::new()
//! .password_check(|u,p| { u == "adam" && p == "evil" });
//! let network = Builtin::new();
//! let server = SOCKSv5Server::new(network, parameters);
//! server.start("localhost", 9999).await;
//! // ... do other stuff ...
//! };
//!
//! ```
use crate::errors::{AuthenticationError, DeserializationError, SerializationError};
use crate::messages::{ use crate::messages::{
AuthenticationMethod, ClientConnectionCommand, ClientConnectionRequest, ClientGreeting, AuthenticationMethod, ClientConnectionCommand, ClientConnectionCommandReadError,
ClientUsernamePassword, ServerAuthResponse, ServerChoice, ServerResponse, ServerResponseStatus, ClientConnectionRequest, ClientConnectionRequestReadError, ClientGreeting,
ClientGreetingReadError, ClientUsernamePassword, ClientUsernamePasswordReadError,
ServerAuthResponse, ServerAuthResponseWriteError, ServerChoice, ServerChoiceWriteError,
ServerResponse, ServerResponseStatus, ServerResponseWriteError,
}; };
use crate::network::address::HasLocalAddress; use crate::security_parameters::SecurityParameters;
use crate::network::generic::Networklike;
use crate::network::listener::{GenericListener, Listenerlike};
use crate::network::stream::GenericStream;
use crate::network::SOCKSv5Address;
use async_std::io;
use async_std::io::prelude::WriteExt;
use async_std::sync::{Arc, Mutex};
use async_std::task;
use futures::Stream;
use log::{error, info, trace, warn};
use std::collections::HashMap;
use std::default::Default;
use std::fmt::{Debug, Display};
use thiserror::Error; use thiserror::Error;
use tokio::io::{copy_bidirectional, AsyncWriteExt};
use tokio::net::{TcpListener, TcpSocket, TcpStream, UdpSocket};
/// A convenient bit of shorthand for an address and port
pub type AddressAndPort = (SOCKSv5Address, u16);
// Just some shorthand for us.
type ResultHandle = task::JoinHandle<Result<(), String>>;
/// A handle representing a SOCKSv5 server, parameterized by the underlying network
/// stack it runs over.
#[derive(Clone)] #[derive(Clone)]
pub struct SOCKSv5Server<N: Networklike> { pub struct SOCKSv5Server {
network: Arc<Mutex<N>>,
running_servers: Arc<Mutex<HashMap<AddressAndPort, ResultHandle>>>,
security_parameters: SecurityParameters, security_parameters: SecurityParameters,
} }
/// The security parameters that you can assign to the server, to make decisions #[derive(Clone, Debug, Error, PartialEq)]
/// about the weirdos it accepts as users. It is recommended that you only use pub enum SOCKSv5ServerError {
/// wide open connections when you're 100% sure that the server will only be #[error("Underlying networking error: {0}")]
/// accessible locally. NetworkingError(String),
#[derive(Clone)] #[error("Couldn't negotiate authentication with client.")]
pub struct SecurityParameters { ItsNotUsItsYou,
/// Allow completely unauthenticated connections. You should be very, very #[error("Client greeting read problem: {0}")]
/// careful about setting this to true, especially if you don't provide a GreetingReadProblem(#[from] ClientGreetingReadError),
/// guard to ensure that you're getting connections from reasonable places. #[error("Server choice write problem: {0}")]
pub allow_unauthenticated: bool, ChoiceWriteProblem(#[from] ServerChoiceWriteError),
/// An optional function that can serve as a firewall for new connections. #[error("Failed username/password authentication for user {0}")]
/// Return true if the connection should be allowed to continue, false if FailedUsernamePassword(String),
/// it shouldn't. This check happens before any data is read from or written #[error("Server authentication response problem: {0}")]
/// to the connecting party. ServerAuthWriteProblem(#[from] ServerAuthResponseWriteError),
pub allow_connection: Option<fn(&SOCKSv5Address, u16) -> bool>, #[error("Error reading client username/password: {0}")]
/// An optional function to check a user name (first argument) and password UserPassReadProblem(#[from] ClientUsernamePasswordReadError),
/// (second argument). Return true if the username / password is good, false #[error("Error reading client connection command: {0}")]
/// if not. ClientConnReadProblem(#[from] ClientConnectionCommandReadError),
pub check_password: Option<fn(&str, &str) -> bool>, #[error("Error reading client connection request: {0}")]
/// An optional function to transition the stream from an unencrypted one to ClientRequestReadProblem(#[from] ClientConnectionRequestReadError),
/// an encrypted on. The assumption is you're using something like `rustls` #[error("Error writing server response: {0}")]
/// to make this happen; the exact mechanism is outside the scope of this ServerResponseWriteProblem(#[from] ServerResponseWriteError),
/// particular crate. If the connection shouldn't be allowed for some reason
/// (a bad certificate or handshake, for example), then return None; otherwise,
/// return the new stream.
pub connect_tls: Option<fn(GenericStream) -> Option<GenericStream>>,
} }
impl SecurityParameters { impl From<std::io::Error> for SOCKSv5ServerError {
/// Generates a `SecurityParameters` object that's empty. It won't accept fn from(x: std::io::Error) -> SOCKSv5ServerError {
/// anything, because it has no mechanisms it can use to actually authenticate SOCKSv5ServerError::NetworkingError(format!("{}", x))
/// a user and yet won't allow unauthenticated connections.
pub fn new() -> SecurityParameters {
SecurityParameters {
allow_unauthenticated: false,
allow_connection: None,
check_password: None,
connect_tls: None,
}
}
/// Generates a `SecurityParameters` object that does not, in any way,
/// restrict who can log in. It also will not induce any transition into
/// TLS. Use this at your own risk ... or, really, just don't use this,
/// ever, and certainly not in production.
pub fn unrestricted() -> SecurityParameters {
SecurityParameters {
allow_unauthenticated: true,
allow_connection: None,
check_password: None,
connect_tls: None,
}
}
/// Use the provided function to check incoming connections before proceeding
/// with the rest of the handshake.
pub fn check_connections(
mut self,
checker: fn(&SOCKSv5Address, u16) -> bool,
) -> SecurityParameters {
self.allow_connection = Some(checker);
self
}
/// Use the provided function to check usernames and passwords provided
/// to the server.
pub fn password_check(mut self, checker: fn(&str, &str) -> bool) -> SecurityParameters {
self.check_password = Some(checker);
self
}
/// Use the provide function to validate a TLS connection, and transition it
/// to the new stream type. If the handshake fails, return `None` instead of
/// `Some`. (And maybe log it somewhere, you know.)
pub fn tls_converter(
mut self,
converter: fn(GenericStream) -> Option<GenericStream>,
) -> SecurityParameters {
self.connect_tls = Some(converter);
self
} }
} }
impl Default for SecurityParameters { impl SOCKSv5Server {
fn default() -> Self { /// Initialize a SOCKSv5 server for use later on. Once initialized, you can listen
Self::new() /// on as many addresses and ports as you like; the metadata about the server will
} /// be synced across all the instances.
} pub fn new(security_parameters: SecurityParameters) -> Self {
impl<N: Networklike + Clone + Send + 'static> SOCKSv5Server<N> {
/// Initialize a SOCKSv5 server for use later on. Once initialize, you can listen on
/// as many addresses and ports as you like; the metadata about the server will be
/// sync'd across all of the instances, should you want to gather that data for some
/// reason.
pub fn new(network: N, security_parameters: SecurityParameters) -> SOCKSv5Server<N> {
SOCKSv5Server { SOCKSv5Server {
network: Arc::new(Mutex::new(network)),
running_servers: Arc::new(Mutex::new(HashMap::new())),
security_parameters, security_parameters,
} }
} }
@@ -159,191 +61,241 @@ impl<N: Networklike + Clone + Send + 'static> SOCKSv5Server<N> {
/// Start a server on the given address and port. This function returns when it has /// Start a server on the given address and port. This function returns when it has
/// set up its listening socket, but spawns a separate task to actually wait for /// set up its listening socket, but spawns a separate task to actually wait for
/// connections. You can query which ones are still active, or see which ones have /// connections. You can query which ones are still active, or see which ones have
/// failed, using some of the other items in this structure. /// failed, using some of the other functions for this structure.
///
/// If you don't care what port is assigned to this server, pass 0 in as the port
/// number and one will be chosen for you by the OS.
///
pub async fn start<A: Send + Into<SOCKSv5Address>>( pub async fn start<A: Send + Into<SOCKSv5Address>>(
&self, &self,
addr: A, addr: A,
port: u16, port: u16,
) -> Result<(), N::Error> { ) -> Result<(), std::io::Error> {
// This might seem a little weird, but we do this in a separate block to make it let listener = match addr.into() {
// as clear as possible to the borrow checker (and the reader) that we only want SOCKSv5Address::IP4(x) => TcpListener::bind((x, port)).await?,
// to hold the lock while we're actually calling listen. SOCKSv5Address::IP6(x) => TcpListener::bind((x, port)).await?,
let listener = { SOCKSv5Address::Hostname(x) => TcpListener::bind((x, port)).await?,
let mut network = self.network.lock().await; };
network.listen(addr, port).await
}?;
// this should really be the same as the input, but technically they could've let sockaddr = listener.local_addr()?;
// thrown some zeros in there and let the underlying network stack decide. So tracing::info!(
// we'll just pull this information post-initialization, and maybe get something "Starting SOCKSv5 server on {}:{}",
// a bit more detailed. sockaddr.ip(),
let (my_addr, my_port) = listener.local_addr(); sockaddr.port()
info!("Starting SOCKSv5 server on {}:{}", my_addr, my_port); );
// OK, spawn off the server loop, and then we'll register this in our list of let second_life = self.clone();
// things running.
let new_self = self.clone(); tokio::task::spawn(async move {
let task_id = task::spawn(async move { if let Err(e) = second_life.server_loop(listener).await {
new_self tracing::error!(
.server_loop(listener) "{}:{}: server network error: {}",
.await sockaddr.ip(),
.map_err(|x| format!("Server network error: {}", x)) sockaddr.port(),
e
);
}
}); });
let mut server_map = self.running_servers.lock().await;
server_map.insert((my_addr, my_port), task_id);
Ok(()) Ok(())
} }
/// Provide a list of open sockets on the server. /// Run the server loop for a particular listener. This routine will never actually
pub async fn open_sockets(&self) -> Vec<AddressAndPort> { /// return except in error conditions.
let server_map = self.running_servers.lock().await; async fn server_loop(self, listener: TcpListener) -> Result<(), std::io::Error> {
server_map.keys().cloned().collect()
}
pub fn subserver_results(&mut self) -> impl Stream<Item = Result<(), String>> {
futures::stream::unfold(self.running_servers.clone(), |locked_map| async move {
let first_server = {
let mut server_map = locked_map.lock().await;
let first_key = server_map.keys().next().cloned()?;
server_map.remove(&first_key)
}?;
let result = first_server.await;
Some((result, locked_map))
})
}
async fn server_loop(self, listener: GenericListener<N::Error>) -> Result<(), N::Error> {
loop { loop {
let (stream, their_addr, their_port) = listener.accept().await?; let (socket, their_addr) = listener.accept().await?;
trace!(
"Initial accept of connection from {}:{}",
their_addr,
their_port
);
// before we do anything, make sure this connection is cool. we don't want to // before we do anything of note, make sure this connection is cool. we don't want
// waste resources (or parse any data) if this isn't someone we actually care // to waste any resources (and certainly don't want to handle any data!) if this
// about it. // isn't someone we want to accept connections from.
if let Some(checker) = &self.security_parameters.allow_connection { tracing::trace!("Initial accept of connection from {}", their_addr);
if !checker(&their_addr, their_port) { if let Some(checker) = self.security_parameters.allow_connection {
info!( if !checker(&their_addr) {
"Rejecting attempted connection from {}:{}", tracing::info!("Rejecting attempted connection from {}", their_addr,);
their_addr, their_port
);
continue;
} }
continue;
} }
// throw this off into another task to take from here. We could to the rest // continue this work in another task. we could absolutely do this work here,
// of this handshake here, but there's a chance that an adversarial connection // but just in case someone starts doing slow responses (or other nasty things),
// could just stall us out, and keep us from doing the next connection. So ... // we want to make sure that that doesn't slow down our ability to accept other
// we'll potentially spin off the task early. // requests.
let me_again = self.clone(); let me_again = self.clone();
task::spawn(async move { tokio::task::spawn(async move {
me_again if let Err(e) = me_again.start_authentication(their_addr, socket).await {
.authenticate_step(their_addr, their_port, stream) tracing::error!("{}: server handler failure: {}", their_addr, e);
.await; }
}); });
} }
} }
async fn authenticate_step( /// Start the authentication phase of the SOCKS handshake. This may be very short, and
/// is the first stage of handling a request. This will only really return on errors.
async fn start_authentication(
self, self,
their_addr: SOCKSv5Address, their_addr: SocketAddr,
their_port: u16, mut socket: TcpStream,
base_stream: GenericStream, ) -> Result<(), SOCKSv5ServerError> {
) { let greeting = ClientGreeting::read(&mut socket).await?;
// Turn this stream into one where we've authenticated the other side. Or, you
// know, don't, and just restart this loop.
let mut authenticated_stream =
match run_authentication(&self.security_parameters, base_stream).await {
Ok(authed_stream) => authed_stream,
Err(e) => {
warn!(
"Failure running authentication from {}:{}: {}",
their_addr, their_port, e
);
return;
}
};
// Figure out what the client actually wants from this connection, and match choose_authentication_method(&self.security_parameters, &greeting.acceptable_methods)
// then dispatch a task to deal with that. {
let mccr = ClientConnectionRequest::read(&mut authenticated_stream).await; // it's not us, it's you. (we're just going to say no.)
match mccr { None => {
Err(e) => warn!("Failure figuring out what the client wanted: {}", e), tracing::trace!(
Ok(ccr) => match ccr.command_code { "{}: Failed to find acceptable authentication method.",
ClientConnectionCommand::AssociateUDPPort => self their_addr,
.handle_udp_request(authenticated_stream, ccr, their_addr, their_port) );
.await let rejection_letter = ServerChoice::rejection();
.unwrap_or_else(|e| warn!("Internal server error in UDP association: {}", e)),
ClientConnectionCommand::EstablishTCPPortBinding => self rejection_letter.write(&mut socket).await?;
.handle_tcp_bind(authenticated_stream, ccr, their_addr, their_port) socket.flush().await?;
.await
.unwrap_or_else(|e| warn!("Internal server error in TCP bind: {}", e)), Err(SOCKSv5ServerError::ItsNotUsItsYou)
ClientConnectionCommand::EstablishTCPStream => self }
.handle_tcp_forward(authenticated_stream, ccr, their_addr, their_port)
.await // the gold standard. great choice.
.unwrap_or_else(|e| warn!("Internal server error in TCP forward: {}", e)), Some(ChosenMethod::TLS(_converter)) => {
}, unimplemented!()
}
// well, I guess this is something?
Some(ChosenMethod::Password(checker)) => {
tracing::trace!(
"{}: Choosing username/password for authentication.",
their_addr,
);
let ok_lets_do_password =
ServerChoice::option(AuthenticationMethod::UsernameAndPassword);
ok_lets_do_password.write(&mut socket).await?;
socket.flush().await?;
let their_info = ClientUsernamePassword::read(&mut socket).await?;
if checker(&their_info.username, &their_info.password) {
let its_all_good = ServerAuthResponse::success();
its_all_good.write(&mut socket).await?;
socket.flush().await?;
self.choose_mode(socket, their_addr).await
} else {
let yeah_no = ServerAuthResponse::failure();
yeah_no.write(&mut socket).await?;
socket.flush().await?;
Err(SOCKSv5ServerError::FailedUsernamePassword(
their_info.username,
))
}
}
// Um. I guess we're doing this unchecked. Yay?
Some(ChosenMethod::None) => {
tracing::trace!(
"{}: Just skipping the whole authentication thing.",
their_addr,
);
let nothin_i_guess = ServerChoice::option(AuthenticationMethod::None);
nothin_i_guess.write(&mut socket).await?;
socket.flush().await?;
self.choose_mode(socket, their_addr).await
}
} }
} }
async fn handle_udp_request( /// Determine which of the modes we might want this particular connection to run
/// in.
async fn choose_mode(
self, self,
stream: GenericStream, mut socket: TcpStream,
ccr: ClientConnectionRequest, their_addr: SocketAddr,
their_addr: SOCKSv5Address, ) -> Result<(), SOCKSv5ServerError> {
their_port: u16, let ccr = ClientConnectionRequest::read(&mut socket).await?;
) -> Result<(), ServerError<N::Error>> { match ccr.command_code {
// Let the user know that we're maybe making progress ClientConnectionCommand::AssociateUDPPort => {
let (my_addr, my_port) = stream.local_addr(); self.handle_udp_request(socket, their_addr, ccr).await?
info!( }
"[{}:{}] Handling UDP bind request from {}:{}, seeking to bind {}:{}", ClientConnectionCommand::EstablishTCPStream => {
my_addr, my_port, their_addr, their_port, ccr.destination_address, ccr.destination_port self.handle_tcp_request(socket, their_addr, ccr).await?
); }
ClientConnectionCommand::EstablishTCPPortBinding => {
unimplemented!() self.handle_tcp_binding_request(socket, their_addr, ccr)
.await?
}
}
Ok(())
} }
async fn handle_tcp_forward( /// Handle UDP forwarding requests
#[allow(unreachable_code)]
async fn handle_udp_request(
self, self,
mut stream: GenericStream, stream: TcpStream,
their_addr: SocketAddr,
ccr: ClientConnectionRequest, ccr: ClientConnectionRequest,
their_addr: SOCKSv5Address, ) -> Result<(), SOCKSv5ServerError> {
their_port: u16, let my_addr = stream.local_addr()?;
) -> Result<(), ServerError<N::Error>> { tracing::info!(
"[{}:{}] Handling UDP bind request from {}:{}, seeking to bind towards {}:{}",
my_addr.ip(),
my_addr.port(),
their_addr.ip(),
their_addr.port(),
ccr.destination_address,
ccr.destination_port
);
let _socket = match ccr.destination_address.clone() {
SOCKSv5Address::IP4(x) => UdpSocket::bind((x, ccr.destination_port)).await?,
SOCKSv5Address::IP6(x) => UdpSocket::bind((x, ccr.destination_port)).await?,
SOCKSv5Address::Hostname(x) => UdpSocket::bind((x, ccr.destination_port)).await?,
};
// OK, it worked. In order to mitigate an infinitesimal chance of a race condition, we're
// going to set up our forwarding tasks first, and then return the result to the user. (Note,
// we'd have to be slightly more precious in order to ensure a lack of race conditions, as
// the runtime could take forever to actually start these tasks, but I'm not ready to be
// bothered by this, yet. FIXME.)
unimplemented!();
// Cool; now we can get the result out to the user.
let bound_address = _socket.local_addr()?;
let response = ServerResponse {
status: ServerResponseStatus::RequestGranted,
bound_address: bound_address.ip().into(),
bound_port: bound_address.port(),
};
response.write(&mut stream).await?;
Ok(())
}
/// Handle TCP forwarding requests
async fn handle_tcp_request(
self,
mut stream: TcpStream,
their_addr: SocketAddr,
ccr: ClientConnectionRequest,
) -> Result<(), SOCKSv5ServerError> {
// Let the user know that we're maybe making progress // Let the user know that we're maybe making progress
let (my_addr, my_port) = stream.local_addr(); let my_addr = stream.local_addr()?;
info!( tracing::info!(
"[{}:{}] Handling TCP forward request from {}:{}, seeking to connect to {}:{}", "[{}] Handling TCP forward request from {}, seeking to connect to {}:{}",
my_addr, my_port, their_addr, their_port, ccr.destination_address, ccr.destination_port my_addr,
their_addr,
ccr.destination_address,
ccr.destination_port
); );
// OK, first thing's first: We need to actually connect to the server that the user // OK, first thing's first: We need to actually connect to the server that the user
// wants us to connect to. // wants us to connect to.
let connection_res = { let outgoing_stream = match &ccr.destination_address {
let mut network = self.network.lock().await; SOCKSv5Address::IP4(x) => TcpStream::connect((*x, ccr.destination_port)).await?,
network SOCKSv5Address::IP6(x) => TcpStream::connect((*x, ccr.destination_port)).await?,
.connect(ccr.destination_address.clone(), ccr.destination_port) SOCKSv5Address::Hostname(x) => {
.await TcpStream::connect((x.as_ref(), ccr.destination_port)).await?
};
let outgoing_stream = match connection_res {
Ok(x) => x,
Err(e) => {
error!("Failed to connect to {}: {}", ccr.destination_address, e);
let response = ServerResponse::error(&e);
response.write(&mut stream).await?;
return Err(ServerError::NetworkError(e));
} }
}; };
trace!( tracing::trace!(
"Connection established to {}:{}", "Connection established to {}:{}",
ccr.destination_address, ccr.destination_address,
ccr.destination_port ccr.destination_port
@@ -352,117 +304,117 @@ impl<N: Networklike + Clone + Send + 'static> SOCKSv5Server<N> {
// Now, for whatever reason -- and this whole thing sent me down a garden path // Now, for whatever reason -- and this whole thing sent me down a garden path
// in understanding how this whole protocol works -- we tell the user what address // in understanding how this whole protocol works -- we tell the user what address
// and port we bound for that connection. // and port we bound for that connection.
let (bound_address, bound_port) = outgoing_stream.local_addr(); let bound_address = outgoing_stream.local_addr()?;
let response = ServerResponse { let response = ServerResponse {
status: ServerResponseStatus::RequestGranted, status: ServerResponseStatus::RequestGranted,
bound_address, bound_address: bound_address.ip().into(),
bound_port, bound_port: bound_address.port(),
}; };
response.write(&mut stream).await?; response.write(&mut stream).await?;
// so now tie our streams together, and we're good to go // so now tie our streams together, and we're good to go
tie_streams( tie_streams(stream, outgoing_stream).await;
format!("{}:{}", their_addr, their_port),
stream,
format!("{}:{}", ccr.destination_address, ccr.destination_port),
outgoing_stream,
)
.await;
Ok(()) Ok(())
} }
async fn handle_tcp_bind( /// Handle TCP binding requests
async fn handle_tcp_binding_request(
self, self,
mut stream: GenericStream, mut stream: TcpStream,
their_addr: SocketAddr,
ccr: ClientConnectionRequest, ccr: ClientConnectionRequest,
their_addr: SOCKSv5Address, ) -> Result<(), SOCKSv5ServerError> {
their_port: u16,
) -> Result<(), ServerError<N::Error>> {
// Let the user know that we're maybe making progress // Let the user know that we're maybe making progress
let (my_addr, my_port) = stream.local_addr(); let my_addr = stream.local_addr()?;
info!( tracing::info!(
"[{}:{}] Handling TCP bind request from {}:{}, seeking to bind {}:{}", "[{}] Handling TCP bind request from {}, seeking to bind {}:{}",
my_addr, my_port, their_addr, their_port, ccr.destination_address, ccr.destination_port my_addr,
their_addr,
ccr.destination_address,
ccr.destination_port
); );
// OK, we have to bind the darn socket first. // OK, we have to bind the darn socket first.
let port_binding = { let listener_port = match &their_addr {
let mut network = self.network.lock().await; SocketAddr::V4(_) => TcpSocket::new_v4(),
network.listen(their_addr.clone(), their_port).await SocketAddr::V6(_) => TcpSocket::new_v6(),
} }?;
.map_err(ServerError::NetworkError)?; // FIXME: Might want to bind on a particular interface, based on a
// config flag, at some point.
let listener = listener_port.listen(1)?;
// Tell them what we bound, just in case they want to inform anyone. // Tell them what we bound, just in case they want to inform anyone.
let (bound_address, bound_port) = port_binding.local_addr(); let bound_address = listener.local_addr()?;
let response = ServerResponse { let response = ServerResponse {
status: ServerResponseStatus::RequestGranted, status: ServerResponseStatus::RequestGranted,
bound_address, bound_address: bound_address.ip().into(),
bound_port, bound_port: bound_address.port(),
}; };
response.write(&mut stream).await?; response.write(&mut stream).await?;
// Wait politely for someone to talk to us. // Wait politely for someone to talk to us.
let (other, other_addr, other_port) = port_binding let (other, other_addr) = listener.accept().await?;
.accept()
.await
.map_err(ServerError::NetworkError)?;
let info = ServerResponse { let info = ServerResponse {
status: ServerResponseStatus::RequestGranted, status: ServerResponseStatus::RequestGranted,
bound_address: other_addr.clone(), bound_address: other_addr.ip().into(),
bound_port: other_port, bound_port: other_addr.port(),
}; };
info.write(&mut stream).await?; info.write(&mut stream).await?;
tie_streams( tie_streams(stream, other).await;
format!("{}:{}", their_addr, their_port),
stream,
format!("{}:{}", other_addr, other_port),
other,
)
.await;
Ok(()) Ok(())
} }
} }
async fn tie_streams( async fn tie_streams(mut left: TcpStream, mut right: TcpStream) {
left_name: String, let left_local_addr = left
left: GenericStream, .local_addr()
right_name: String, .expect("couldn't get left local address in tie_streams");
right: GenericStream, let left_peer_addr = left
) { .peer_addr()
// Now that we've informed them of that, we set up one task to transfer information .expect("couldn't get left peer address in tie_streams");
// from the current stream (`stream`) to the connection (`outgoing_stream`), and let right_local_addr = right
// another task that goes in the reverse direction. .local_addr()
// .expect("couldn't get right local address in tie_streams");
// I've chosen to start two fresh tasks and let this one die; I'm not sure that let right_peer_addr = right
// this is the right approach. My only rationale is that this might let some .peer_addr()
// memory we might have accumulated along the way drop more easily, but that .expect("couldn't get right peer address in tie_streams");
// might not actually matter.
let mut from_left = left.clone();
let mut from_right = right.clone();
let mut to_left = left;
let mut to_right = right;
let left_right_name = format!("{} >--> {}", left_name, right_name);
let right_left_name = format!("{} <--< {}", left_name, right_name);
task::spawn(async move { tokio::task::spawn(async move {
info!("Spawned {} task", left_right_name); tracing::info!(
if let Err(e) = io::copy(&mut from_left, &mut to_right).await { "Setting up linkage {}/{} <-> {}/{}",
warn!("{} connection failed with: {}", left_right_name, e); left_peer_addr,
} left_local_addr,
}); right_local_addr,
right_peer_addr
task::spawn(async move { );
info!("Spawned {} task", right_left_name); match copy_bidirectional(&mut left, &mut right).await {
if let Err(e) = io::copy(&mut from_right, &mut to_left).await { Ok((l2r, r2l)) => tracing::info!(
warn!("{} connection failed with: {}", right_left_name, e); "Shutting down linkage {}/{} <-> {}/{} (sent {} and {} bytes, respectively)",
left_peer_addr,
left_local_addr,
right_local_addr,
right_peer_addr,
l2r,
r2l
),
Err(e) => tracing::warn!(
"Shutting down linkage {}/{} <-> {}/{} with error: {}",
left_peer_addr,
left_local_addr,
right_local_addr,
right_peer_addr,
e
),
} }
}); });
} }
#[allow(clippy::upper_case_acronyms)] #[allow(clippy::upper_case_acronyms)]
enum ChosenMethod { enum ChosenMethod {
TLS(fn(GenericStream) -> Option<GenericStream>), TLS(fn() -> Option<()>),
Password(fn(&str, &str) -> bool), Password(fn(&str, &str) -> bool),
None, None,
} }
@@ -567,7 +519,7 @@ fn reasonable_auth_method_choices() {
); );
// OK, cool. If we have a TLS handler, that shouldn't actually make a difference. // OK, cool. If we have a TLS handler, that shouldn't actually make a difference.
params.connect_tls = Some(|_| unimplemented!()); params.connect_tls = Some(|| unimplemented!());
assert_eq!( assert_eq!(
choose_authentication_method(&params, &client_suggestions).map(AuthenticationMethod::from), choose_authentication_method(&params, &client_suggestions).map(AuthenticationMethod::from),
None None
@@ -580,7 +532,7 @@ fn reasonable_auth_method_choices() {
None None
); );
// but if we have a handler, and they go for it, we use it. // but if we have a handler, and they go for it, we use it.
params.connect_tls = Some(|_| unimplemented!()); params.connect_tls = Some(|| unimplemented!());
assert_eq!( assert_eq!(
choose_authentication_method(&params, &client_suggestions).map(AuthenticationMethod::from), choose_authentication_method(&params, &client_suggestions).map(AuthenticationMethod::from),
Some(AuthenticationMethod::SSL) Some(AuthenticationMethod::SSL)
@@ -600,75 +552,3 @@ fn reasonable_auth_method_choices() {
Some(AuthenticationMethod::SSL) Some(AuthenticationMethod::SSL)
); );
} }
async fn run_authentication(
params: &SecurityParameters,
mut stream: GenericStream,
) -> Result<GenericStream, AuthenticationError> {
let greeting = ClientGreeting::read(&mut stream).await?;
match choose_authentication_method(params, &greeting.acceptable_methods) {
// it's not us, it's you
None => {
trace!("Failed to find acceptable authentication method.");
let rejection_letter = ServerChoice::rejection();
rejection_letter.write(&mut stream).await?;
stream.flush().await?;
Err(AuthenticationError::ItsNotUsItsYou)
}
// the gold standard. great choice.
Some(ChosenMethod::TLS(converter)) => {
trace!("Choosing TLS for authentication.");
let lets_do_this = ServerChoice::option(AuthenticationMethod::SSL);
lets_do_this.write(&mut stream).await?;
stream.flush().await?;
converter(stream).ok_or(AuthenticationError::FailedTLSHandshake)
}
// well, I guess this is something?
Some(ChosenMethod::Password(checker)) => {
trace!("Choosing Username/Password for authentication.");
let ok_lets_do_password =
ServerChoice::option(AuthenticationMethod::UsernameAndPassword);
ok_lets_do_password.write(&mut stream).await?;
stream.flush().await?;
let their_info = ClientUsernamePassword::read(&mut stream).await?;
if checker(&their_info.username, &their_info.password) {
let its_all_good = ServerAuthResponse::success();
its_all_good.write(&mut stream).await?;
stream.flush().await?;
Ok(stream)
} else {
let yeah_no = ServerAuthResponse::failure();
yeah_no.write(&mut stream).await?;
stream.flush().await?;
Err(AuthenticationError::FailedUsernamePassword(
their_info.username,
))
}
}
Some(ChosenMethod::None) => {
trace!("Just skipping the whole authentication thing.");
let nothin_i_guess = ServerChoice::option(AuthenticationMethod::None);
nothin_i_guess.write(&mut stream).await?;
stream.flush().await?;
Ok(stream)
}
}
}
#[derive(Error, Debug)]
pub enum ServerError<E: Debug + Display> {
#[error("Error in deserialization: {0}")]
DeserializationError(#[from] DeserializationError),
#[error("Error in serialization: {0}")]
SerializationError(#[from] SerializationError),
#[error("Underlying network error: {0}")]
NetworkError(E),
}