diff --git a/Cargo.toml b/Cargo.toml index f4ff11b..309235b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,14 +8,12 @@ edition = "2018" name = "async_socks5" [dependencies] -async-std = { version = "1.9.0", features = ["attributes"] } -async-trait = "0.1.50" -futures = "0.3.15" -log = "0.4.8" -proptest = "1.0.0" -simplelog = "0.10.0" -thiserror = "1.0.24" +anyhow = "^1.0.57" +proptest = "^1.0.0" +thiserror = "^1.0.31" +tokio = { version = "^1", features = ["full"] } +tracing = "^0.1.34" [dev-dependencies] proptest = "1.0.0" -proptest-derive = "0.3.0" \ No newline at end of file +proptest-derive = "0.3.0" diff --git a/src/address.rs b/src/address.rs new file mode 100644 index 0000000..28d2588 --- /dev/null +++ b/src/address.rs @@ -0,0 +1,174 @@ +use crate::messages::string::{SOCKSv5String, SOCKSv5StringReadError, SOCKSv5StringWriteError}; +use std::fmt; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; +use thiserror::Error; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +#[derive(Clone, Debug, Eq, Hash, PartialEq)] +pub enum SOCKSv5Address { + IP4(Ipv4Addr), + IP6(Ipv6Addr), + Hostname(String), +} + +impl From for SOCKSv5Address { + fn from(x: IpAddr) -> SOCKSv5Address { + match x { + IpAddr::V4(a) => SOCKSv5Address::IP4(a), + IpAddr::V6(a) => SOCKSv5Address::IP6(a), + } + } +} + +impl From for SOCKSv5Address { + fn from(x: Ipv4Addr) -> SOCKSv5Address { + SOCKSv5Address::IP4(x) + } +} + +impl From for SOCKSv5Address { + fn from(x: Ipv6Addr) -> SOCKSv5Address { + SOCKSv5Address::IP6(x) + } +} + +impl From for SOCKSv5Address { + fn from(x: SOCKSv5String) -> SOCKSv5Address { + SOCKSv5Address::Hostname(x.into()) + } +} + +impl<'a> From<&'a str> for SOCKSv5Address { + fn from(x: &str) -> SOCKSv5Address { + SOCKSv5Address::Hostname(x.to_string()) + } +} + +impl From for SOCKSv5Address { + fn from(x: String) -> SOCKSv5Address { + SOCKSv5Address::Hostname(x) + } +} + +impl fmt::Display for SOCKSv5Address { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + SOCKSv5Address::IP4(a) => write!(f, "{}", a), + SOCKSv5Address::IP6(a) => write!(f, "{}", a), + SOCKSv5Address::Hostname(a) => write!(f, "{}", a), + } + } +} + +#[cfg(test)] +const HOSTNAME_REGEX: &str = "[a-zA-Z0-9_.]+"; + +#[cfg(test)] +use proptest::prelude::{any, prop_oneof, Arbitrary, BoxedStrategy, Strategy}; + +#[cfg(test)] +impl Arbitrary for SOCKSv5Address { + type Parameters = Option; + type Strategy = BoxedStrategy; + + fn arbitrary_with(args: Self::Parameters) -> Self::Strategy { + let max_len = args.unwrap_or(32) as usize; + + prop_oneof![ + any::().prop_map(SOCKSv5Address::IP4), + any::().prop_map(SOCKSv5Address::IP6), + HOSTNAME_REGEX.prop_map(move |mut hostname| { + hostname.shrink_to(max_len); + SOCKSv5Address::Hostname(hostname) + }), + ] + .boxed() + } +} + +#[derive(Clone, Debug, Error, PartialEq)] +pub enum SOCKSv5AddressReadError { + #[error("Bad address type {0} (expected 1, 3, or 4)")] + BadAddressType(u8), + #[error("Read buffer error: {0}")] + ReadError(String), + #[error(transparent)] + SOCKSv5StringError(#[from] SOCKSv5StringReadError), +} + +impl From for SOCKSv5AddressReadError { + fn from(x: std::io::Error) -> SOCKSv5AddressReadError { + SOCKSv5AddressReadError::ReadError(format!("{}", x)) + } +} + +#[derive(Clone, Debug, Error, PartialEq)] +pub enum SOCKSv5AddressWriteError { + #[error(transparent)] + SOCKSv5StringError(#[from] SOCKSv5StringWriteError), + #[error("Write buffer error: {0}")] + WriteError(String), +} + +impl From for SOCKSv5AddressWriteError { + fn from(x: std::io::Error) -> SOCKSv5AddressWriteError { + SOCKSv5AddressWriteError::WriteError(format!("{}", x)) + } +} + +impl SOCKSv5Address { + pub async fn read( + r: &mut R, + ) -> Result { + match r.read_u8().await? { + 1 => { + let mut addr_buffer = [0; 4]; + r.read_exact(&mut addr_buffer).await?; + let ip4 = Ipv4Addr::from(addr_buffer); + Ok(SOCKSv5Address::IP4(ip4)) + } + + 3 => { + let string = SOCKSv5String::read(r).await?; + Ok(SOCKSv5Address::from(string)) + } + + 4 => { + let mut addr_buffer = [0; 16]; + r.read_exact(&mut addr_buffer).await?; + let ip6 = Ipv6Addr::from(addr_buffer); + Ok(SOCKSv5Address::IP6(ip6)) + } + + x => Err(SOCKSv5AddressReadError::BadAddressType(x)), + } + } + + pub async fn write( + &self, + w: &mut W, + ) -> Result<(), SOCKSv5AddressWriteError> { + match self { + SOCKSv5Address::IP4(x) => { + w.write_u8(1).await?; + w.write_all(&x.octets()).await?; + Ok(()) + } + + SOCKSv5Address::IP6(x) => { + w.write_u8(4).await?; + w.write_all(&x.octets()).await?; + Ok(()) + } + + SOCKSv5Address::Hostname(x) => { + w.write_u8(3).await?; + let string = SOCKSv5String::from(x.clone()); + string.write(w).await?; + Ok(()) + } + } + } +} + +crate::standard_roundtrip!(socks_address_roundtrips, SOCKSv5Address); diff --git a/src/bin/socks-server.rs b/src/bin/socks-server.rs index 63d47f2..c1e5c88 100644 --- a/src/bin/socks-server.rs +++ b/src/bin/socks-server.rs @@ -1,36 +1,3 @@ -use async_socks5::network::Builtin; -use async_socks5::server::{SOCKSv5Server, SecurityParameters}; -use async_std::io; -use futures::stream::StreamExt; -use simplelog::{ColorChoice, CombinedLogger, Config, LevelFilter, TermLogger, TerminalMode}; - -#[async_std::main] -async fn main() -> Result<(), io::Error> { - CombinedLogger::init(vec![TermLogger::new( - LevelFilter::Debug, - Config::default(), - TerminalMode::Mixed, - ColorChoice::Auto, - )]) - .expect("Couldn't initialize logger"); - - let params = SecurityParameters { - allow_unauthenticated: true, - allow_connection: None, - check_password: None, - connect_tls: None, - }; - - let mut server = SOCKSv5Server::new(Builtin::new(), params); - server.start("127.0.0.1", 9999).await?; - - let mut responses = Box::pin(server.subserver_results()); - - while let Some(response) = responses.next().await { - if let Err(e) = response { - println!("Server failed with: {}", e); - } - } - +fn main() -> Result<(), ()> { Ok(()) } diff --git a/src/client.rs b/src/client.rs index 3836295..be6f296 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,57 +1,47 @@ -use crate::errors::{DeserializationError, SerializationError}; +use crate::address::SOCKSv5Address; use crate::messages::{ - AuthenticationMethod, ClientConnectionCommand, ClientConnectionRequest, ClientGreeting, - ClientUsernamePassword, ServerAuthResponse, ServerChoice, ServerResponse, ServerResponseStatus, + AuthenticationMethod, ClientConnectionCommand, ClientConnectionCommandWriteError, + ClientConnectionRequest, ClientGreeting, ClientGreetingWriteError, ClientUsernamePassword, + ClientUsernamePasswordWriteError, ServerAuthResponse, ServerAuthResponseReadError, + ServerChoice, ServerChoiceReadError, ServerResponse, ServerResponseReadError, + ServerResponseStatus, }; -use crate::network::datagram::GenericDatagramSocket; -use crate::network::generic::{IntoErrorResponse, Networklike}; -use crate::network::listener::GenericListener; -use crate::network::stream::GenericStream; -use crate::network::SOCKSv5Address; -use async_std::io; -use async_std::sync::{Arc, Mutex}; -use async_trait::async_trait; -use futures::Future; -use log::{info, trace, warn}; -use std::fmt::{Debug, Display}; +use std::future::Future; use thiserror::Error; +use tokio::net::TcpStream; #[derive(Debug, Error)] -pub enum SOCKSv5Error { - #[error("SOCKSv5 serialization error: {0}")] - SerializationError(#[from] SerializationError), - #[error("SOCKSv5 deserialization error: {0}")] - DeserializationError(#[from] DeserializationError), - #[error("No acceptable authentication methods available")] - NoAuthMethodsAllowed, +pub enum SOCKSv5ClientError { + #[error("Underlying networking error: {0}")] + NetworkingError(String), + #[error("Client greeting write error: {0}")] + ClientWriteError(#[from] ClientGreetingWriteError), + #[error("Server choice error: {0}")] + ServerChoiceError(#[from] ServerChoiceReadError), + #[error("Error writing credentials: {0}")] + CredentialWriteError(#[from] ClientUsernamePasswordWriteError), + #[error("Server auth read error: {0}")] + AuthResponseError(#[from] ServerAuthResponseReadError), #[error("Authentication failed")] AuthenticationFailed, - #[error("Server chose an unsupported authentication method ({0}")] + #[error("No authentication methods allowed")] + NoAuthMethodsAllowed, + #[error("Unsupported authentication method chosen ({0})")] UnsupportedAuthMethodChosen(AuthenticationMethod), + #[error("Client connection command write error: {0}")] + ClientCommandWriteError(#[from] ClientConnectionCommandWriteError), #[error("Server said no: {0}")] - ServerFailure(#[from] ServerResponseStatus), - #[error("Connection error: {0}")] - ConnectionError(#[from] io::Error), - #[error("Underlying network error: {0}")] - UnderlyingNetwork(E), + ServerRejected(#[from] ServerResponseStatus), + #[error("Server response read failure: {0}")] + ServerResponseError(#[from] ServerResponseReadError), } -impl IntoErrorResponse for SOCKSv5Error { - fn into_response(&self) -> ServerResponseStatus { - match self { - SOCKSv5Error::ServerFailure(v) => v.clone(), - _ => ServerResponseStatus::GeneralFailure, - } +impl From for SOCKSv5ClientError { + fn from(x: std::io::Error) -> SOCKSv5ClientError { + SOCKSv5ClientError::NetworkingError(format!("{}", x)) } } -pub struct SOCKSv5Client { - network: Arc>, - login_info: LoginInfo, - address: SOCKSv5Address, - port: u16, -} - pub struct LoginInfo { pub username_password: Option, } @@ -64,7 +54,7 @@ impl Default for LoginInfo { impl LoginInfo { /// Generate an empty bit of login information. - fn new() -> LoginInfo { + pub fn new() -> LoginInfo { LoginInfo { username_password: None, } @@ -89,22 +79,24 @@ pub struct UsernamePassword { pub password: String, } -impl SOCKSv5Client -where - N: Networklike + Sync, -{ +pub struct SOCKSv5Client { + login_info: LoginInfo, + address: SOCKSv5Address, + port: u16, +} + +impl SOCKSv5Client { /// Create a new SOCKSv5 client connection over the given steam, using the given /// authentication information. As part of the process of building this object, we /// do a little test run to make sure that we can login effectively; this should save /// from *some* surprises later on. If you'd rather *not* do that, though, you can /// try `unchecked_new`. pub async fn new>( - network: N, login: LoginInfo, server_addr: A, server_port: u16, - ) -> Result> { - let base_version = SOCKSv5Client::unchecked_new(network, login, server_addr, server_port); + ) -> Result { + let base_version = SOCKSv5Client::unchecked_new(login, server_addr, server_port); let _ = base_version.start_session().await?; Ok(base_version) } @@ -113,13 +105,11 @@ where /// connection sequence at the expense of an increased possibility of an error /// later on down the road. pub fn unchecked_new>( - network: N, login_info: LoginInfo, address: A, port: u16, ) -> Self { SOCKSv5Client { - network: Arc::new(Mutex::new(network)), login_info, address: address.into(), port, @@ -128,17 +118,17 @@ where /// This runs the connection and negotiates login, as required, and then returns /// the stream the caller should use to do ... whatever it wants to do. - async fn start_session(&self) -> Result> { + async fn start_session(&self) -> Result { // create the initial stream - let mut stream = { - let mut network = self.network.lock().await; - network.connect(self.address.clone(), self.port).await - } - .map_err(SOCKSv5Error::UnderlyingNetwork)?; + let mut stream = match &self.address { + SOCKSv5Address::IP4(x) => TcpStream::connect((*x, self.port)).await?, + SOCKSv5Address::IP6(x) => TcpStream::connect((*x, self.port)).await?, + SOCKSv5Address::Hostname(x) => TcpStream::connect((x.as_ref(), self.port)).await?, + }; // compute how we can log in let acceptable_methods = self.login_info.acceptable_methods(); - trace!( + tracing::trace!( "Computed acceptable methods -- {:?} -- sending client greeting.", acceptable_methods ); @@ -146,9 +136,9 @@ where // Negotiate with the server. Well. "Negotiate." let client_greeting = ClientGreeting { acceptable_methods }; client_greeting.write(&mut stream).await?; - trace!("Write client greeting, waiting for server's choice."); + tracing::trace!("Write client greeting, waiting for server's choice."); let server_choice = ServerChoice::read(&mut stream).await?; - trace!("Received server's choice: {}", server_choice.chosen_method); + tracing::trace!("Received server's choice: {}", server_choice.chosen_method); // Let's do it! match server_choice.chosen_method { @@ -158,30 +148,32 @@ where let (username, password) = if let Some(ref linfo) = self.login_info.username_password { - trace!("Server requested username/password, getting data from login info."); + tracing::trace!( + "Server requested username/password, getting data from login info." + ); (linfo.username.clone(), linfo.password.clone()) } else { - warn!("Server requested username/password, but we weren't provided one. Very weird."); + tracing::warn!("Server requested username/password, but we weren't provided one. Very weird."); ("".to_string(), "".to_string()) }; let auth_request = ClientUsernamePassword { username, password }; - trace!("Writing password information."); + tracing::trace!("Writing password information."); auth_request.write(&mut stream).await?; let server_response = ServerAuthResponse::read(&mut stream).await?; - trace!("Got server response: {}", server_response.success); + tracing::trace!("Got server response: {}", server_response.success); if !server_response.success { - return Err(SOCKSv5Error::AuthenticationFailed); + return Err(SOCKSv5ClientError::AuthenticationFailed); } } AuthenticationMethod::NoAcceptableMethods => { - return Err(SOCKSv5Error::NoAuthMethodsAllowed) + return Err(SOCKSv5ClientError::NoAuthMethodsAllowed) } - x => return Err(SOCKSv5Error::UnsupportedAuthMethodChosen(x)), + x => return Err(SOCKSv5ClientError::UnsupportedAuthMethodChosen(x)), } Ok(stream) @@ -193,12 +185,12 @@ where /// person to listen on. So this function takes an async function, which it /// will pass this information to once it has it. It's up to that function, /// then, to communicate this to its peer. - pub async fn remote_listen>>>( + pub async fn remote_listen>>( self, addr: A, port: u16, callback: impl FnOnce(SOCKSv5Address, u16) -> Fut, - ) -> Result<(SOCKSv5Address, u16, GenericStream), SOCKSv5Error> + ) -> Result<(SOCKSv5Address, u16, TcpStream), SOCKSv5ClientError> where A: Into, { @@ -217,9 +209,12 @@ where return Err(initial_response.status.into()); } - info!( + tracing::info!( "Proxy port binding of {}:{} established; server listening on {}:{}", - target, port, initial_response.bound_address, initial_response.bound_port + target, + port, + initial_response.bound_address, + initial_response.bound_port ); callback(initial_response.bound_address, initial_response.bound_port).await?; @@ -229,9 +224,10 @@ where return Err(secondary_response.status.into()); } - info!( + tracing::info!( "Proxy bind got a connection from {}:{}", - secondary_response.bound_address, secondary_response.bound_port + secondary_response.bound_address, + secondary_response.bound_port ); Ok(( @@ -240,20 +236,12 @@ where stream, )) } -} -#[async_trait] -impl Networklike for SOCKSv5Client -where - N: Networklike + Sync + Send, -{ - type Error = SOCKSv5Error; - - async fn connect>( + pub async fn connect>( &mut self, addr: A, port: u16, - ) -> Result { + ) -> Result { let mut stream = self.start_session().await?; let target = addr.into(); @@ -266,29 +254,16 @@ where let response = ServerResponse::read(&mut stream).await?; if response.status == ServerResponseStatus::RequestGranted { - info!( + tracing::info!( "Proxy connection to {}:{} established; server is using {}:{}", - target, port, response.bound_address, response.bound_port + target, + port, + response.bound_address, + response.bound_port ); Ok(stream) } else { Err(response.status.into()) } } - - async fn listen>( - &mut self, - _addr: A, - _port: u16, - ) -> Result, Self::Error> { - unimplemented!() - } - - async fn bind>( - &mut self, - _addr: A, - _port: u16, - ) -> Result, Self::Error> { - unimplemented!() - } } diff --git a/src/errors.rs b/src/errors.rs deleted file mode 100644 index 9733529..0000000 --- a/src/errors.rs +++ /dev/null @@ -1,217 +0,0 @@ -use std::io; -use std::string::FromUtf8Error; -use thiserror::Error; - -use crate::network::SOCKSv5Address; - -/// All the errors that can pop up when trying to turn raw bytes into SOCKSv5 -/// messages. -#[derive(Error, Debug)] -pub enum DeserializationError { - #[error("Invalid protocol version for packet ({1} is not {0}!)")] - InvalidVersion(u8, u8), - #[error("Not enough data found")] - NotEnoughData, - #[error("Ooops! Found an empty string where I shouldn't")] - InvalidEmptyString, - #[error("IO error: {0}")] - IOError(#[from] io::Error), - #[error("SOCKS authentication format parse error: {0}")] - AuthenticationMethodError(#[from] AuthenticationDeserializationError), - #[error("Error converting from UTF-8: {0}")] - UTF8Error(#[from] FromUtf8Error), - #[error("Invalid address type; wanted 1, 3, or 4, got {0}")] - InvalidAddressType(u8), - #[error("Invalid client command {0}; expected 1, 2, or 3")] - InvalidClientCommand(u8), - #[error("Invalid server status {0}; expected 0-8")] - InvalidServerResponse(u8), - #[error("Invalid byte in reserved byte ({0})")] - InvalidReservedByte(u8), -} - -#[test] -fn des_error_reasonable_equals() { - let invalid_version1 = DeserializationError::InvalidVersion(1, 2); - let invalid_version2 = DeserializationError::InvalidVersion(1, 2); - assert_eq!(invalid_version1, invalid_version2); - - let not_enough1 = DeserializationError::NotEnoughData; - let not_enough2 = DeserializationError::NotEnoughData; - assert_eq!(not_enough1, not_enough2); - - let invalid_empty1 = DeserializationError::InvalidEmptyString; - let invalid_empty2 = DeserializationError::InvalidEmptyString; - assert_eq!(invalid_empty1, invalid_empty2); - - let auth_method1 = DeserializationError::AuthenticationMethodError( - AuthenticationDeserializationError::NoDataFound, - ); - let auth_method2 = DeserializationError::AuthenticationMethodError( - AuthenticationDeserializationError::NoDataFound, - ); - assert_eq!(auth_method1, auth_method2); - - let utf8a = DeserializationError::UTF8Error(String::from_utf8(vec![0, 159]).unwrap_err()); - let utf8b = DeserializationError::UTF8Error(String::from_utf8(vec![0, 159]).unwrap_err()); - assert_eq!(utf8a, utf8b); - - let invalid_address1 = DeserializationError::InvalidAddressType(3); - let invalid_address2 = DeserializationError::InvalidAddressType(3); - assert_eq!(invalid_address1, invalid_address2); - - let invalid_client_cmd1 = DeserializationError::InvalidClientCommand(32); - let invalid_client_cmd2 = DeserializationError::InvalidClientCommand(32); - assert_eq!(invalid_client_cmd1, invalid_client_cmd2); - - let invalid_server_resp1 = DeserializationError::InvalidServerResponse(42); - let invalid_server_resp2 = DeserializationError::InvalidServerResponse(42); - assert_eq!(invalid_server_resp1, invalid_server_resp2); - - assert_ne!(invalid_version1, invalid_address1); - assert_ne!(not_enough1, invalid_empty1); - assert_ne!(auth_method1, invalid_client_cmd1); - assert_ne!(utf8a, invalid_server_resp1); -} - -impl PartialEq for DeserializationError { - fn eq(&self, other: &DeserializationError) -> bool { - match (self, other) { - ( - &DeserializationError::InvalidVersion(a, b), - &DeserializationError::InvalidVersion(x, y), - ) => (a == x) && (b == y), - (&DeserializationError::NotEnoughData, &DeserializationError::NotEnoughData) => true, - ( - &DeserializationError::InvalidEmptyString, - &DeserializationError::InvalidEmptyString, - ) => true, - ( - &DeserializationError::AuthenticationMethodError(ref a), - &DeserializationError::AuthenticationMethodError(ref b), - ) => a == b, - (&DeserializationError::UTF8Error(ref a), &DeserializationError::UTF8Error(ref b)) => { - a == b - } - ( - &DeserializationError::InvalidAddressType(a), - &DeserializationError::InvalidAddressType(b), - ) => a == b, - ( - &DeserializationError::InvalidClientCommand(a), - &DeserializationError::InvalidClientCommand(b), - ) => a == b, - ( - &DeserializationError::InvalidServerResponse(a), - &DeserializationError::InvalidServerResponse(b), - ) => a == b, - ( - &DeserializationError::InvalidReservedByte(a), - &DeserializationError::InvalidReservedByte(b), - ) => a == b, - (_, _) => false, - } - } -} - -/// All the errors that can occur trying to turn SOCKSv5 message structures -/// into raw bytes. There's a few places that the message structures allow -/// for information that can't be serialized; often, you have to be careful -/// about how long your strings are ... -#[derive(Error, Debug)] -pub enum SerializationError { - #[error("Too many authentication methods for serialization ({0} > 255)")] - TooManyAuthMethods(usize), - #[error("Invalid length for string: {0}")] - InvalidStringLength(String), - #[error("IO error: {0}")] - IOError(#[from] io::Error), -} - -#[test] -fn ser_err_reasonable_equals() { - let too_many1 = SerializationError::TooManyAuthMethods(512); - let too_many2 = SerializationError::TooManyAuthMethods(512); - assert_eq!(too_many1, too_many2); - - let invalid_str1 = SerializationError::InvalidStringLength("Whoopsy!".to_string()); - let invalid_str2 = SerializationError::InvalidStringLength("Whoopsy!".to_string()); - assert_eq!(invalid_str1, invalid_str2); - - assert_ne!(too_many1, invalid_str1); -} - -impl PartialEq for SerializationError { - fn eq(&self, other: &SerializationError) -> bool { - match (self, other) { - ( - &SerializationError::TooManyAuthMethods(a), - &SerializationError::TooManyAuthMethods(b), - ) => a == b, - ( - &SerializationError::InvalidStringLength(ref a), - &SerializationError::InvalidStringLength(ref b), - ) => a == b, - (_, _) => false, - } - } -} - -#[derive(Error, Debug)] -pub enum AuthenticationDeserializationError { - #[error("No data found deserializing SOCKS authentication type")] - NoDataFound, - #[error("Invalid authentication type value: {0}")] - InvalidAuthenticationByte(u8), - #[error("IO error reading SOCKS authentication type: {0}")] - IOError(#[from] io::Error), -} - -#[test] -fn auth_des_err_reasonable_equals() { - let no_data1 = AuthenticationDeserializationError::NoDataFound; - let no_data2 = AuthenticationDeserializationError::NoDataFound; - assert_eq!(no_data1, no_data2); - - let invalid_auth1 = AuthenticationDeserializationError::InvalidAuthenticationByte(39); - let invalid_auth2 = AuthenticationDeserializationError::InvalidAuthenticationByte(39); - assert_eq!(invalid_auth1, invalid_auth2); - - assert_ne!(no_data1, invalid_auth1); -} - -impl PartialEq for AuthenticationDeserializationError { - fn eq(&self, other: &AuthenticationDeserializationError) -> bool { - match (self, other) { - ( - &AuthenticationDeserializationError::NoDataFound, - &AuthenticationDeserializationError::NoDataFound, - ) => true, - ( - &AuthenticationDeserializationError::InvalidAuthenticationByte(x), - &AuthenticationDeserializationError::InvalidAuthenticationByte(y), - ) => x == y, - (_, _) => false, - } - } -} - -/// The errors that can happen, as a server, when we're negotiating the start -/// of a SOCKS session. -#[derive(Debug, Error)] -pub enum AuthenticationError { - #[error("Firewall disallowed connection from {0}:{1}")] - FirewallRejected(SOCKSv5Address, u16), - #[error("Could not agree on an authentication method with the client")] - ItsNotUsItsYou, - #[error("Failure in serializing response message: {0}")] - SerializationError(#[from] SerializationError), - #[error("Failed TLS handshake")] - FailedTLSHandshake, - #[error("IO error writing response message: {0}")] - IOError(#[from] io::Error), - #[error("Failure in reading client message: {0}")] - DeserializationError(#[from] DeserializationError), - #[error("Username/password check failed (username was {0})")] - FailedUsernamePassword(String), -} diff --git a/src/lib.rs b/src/lib.rs index 66bdf88..9a8260d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,192 +1,172 @@ pub mod client; -pub mod errors; -pub mod messages; -pub mod network; -mod serialize; pub mod server; +mod address; +mod messages; +mod security_parameters; + #[cfg(test)] mod test { + use crate::address::SOCKSv5Address; use crate::client::{LoginInfo, SOCKSv5Client, UsernamePassword}; - use crate::network::generic::Networklike; - use crate::network::listener::Listenerlike; - use crate::network::testing::TestingStack; - use crate::server::{SOCKSv5Server, SecurityParameters}; - use async_std::channel::bounded; - use async_std::io::prelude::WriteExt; - use async_std::task; - use futures::AsyncReadExt; + use crate::security_parameters::SecurityParameters; + use crate::server::SOCKSv5Server; + use std::io; + use std::net::{IpAddr, Ipv4Addr, SocketAddr}; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::net::{TcpSocket, TcpStream}; + use tokio::sync::oneshot; + use tokio::task; - #[test] - fn unrestricted_login() { - task::block_on(async { - let network_stack = TestingStack::default(); + #[tokio::test] + async fn unrestricted_login() { + // generate the server + let security_parameters = SecurityParameters::unrestricted(); + let server = SOCKSv5Server::new(security_parameters); + server.start("localhost", 9999).await.unwrap(); - // generate the server - let security_parameters = SecurityParameters::unrestricted(); - let server = SOCKSv5Server::new(network_stack.clone(), security_parameters); - server.start("localhost", 9999).await.unwrap(); + let login_info = LoginInfo { + username_password: None, + }; + let client = SOCKSv5Client::new(login_info, "localhost", 9999).await; - let login_info = LoginInfo { - username_password: None, - }; - let client = SOCKSv5Client::new(network_stack, login_info, "localhost", 9999).await; - - assert!(client.is_ok()); - }) + assert!(client.is_ok()); } - #[test] - fn disallow_unrestricted() { - task::block_on(async { - let network_stack = TestingStack::default(); + #[tokio::test] + async fn disallow_unrestricted() { + // generate the server + let mut security_parameters = SecurityParameters::unrestricted(); + security_parameters.allow_unauthenticated = false; + let server = SOCKSv5Server::new(security_parameters); + server.start("localhost", 9998).await.unwrap(); - // generate the server - let mut security_parameters = SecurityParameters::unrestricted(); - security_parameters.allow_unauthenticated = false; - let server = SOCKSv5Server::new(network_stack.clone(), security_parameters); - server.start("localhost", 9998).await.unwrap(); + let login_info = LoginInfo::default(); + let client = SOCKSv5Client::new(login_info, "localhost", 9998).await; - let login_info = LoginInfo { - username_password: None, - }; - let client = SOCKSv5Client::new(network_stack, login_info, "localhost", 9998).await; - - assert!(client.is_err()); - }) + assert!(client.is_err()); } - #[test] - fn password_checks() { - task::block_on(async { - let network_stack = TestingStack::default(); + #[tokio::test] + async fn password_checks() { + // generate the server + let security_parameters = SecurityParameters { + allow_unauthenticated: false, + allow_connection: None, + connect_tls: None, + check_password: Some(|username, password| { + username == "awick" && password == "password" + }), + }; + let server = SOCKSv5Server::new(security_parameters); + server.start("localhost", 9997).await.unwrap(); - // generate the server - let security_parameters = SecurityParameters { - allow_unauthenticated: false, - allow_connection: None, - connect_tls: None, - check_password: Some(|username, password| { - username == "awick" && password == "password" - }), - }; - let server = SOCKSv5Server::new(network_stack.clone(), security_parameters); - server.start("localhost", 9997).await.unwrap(); + // try the positive side + let login_info = LoginInfo { + username_password: Some(UsernamePassword { + username: "awick".to_string(), + password: "password".to_string(), + }), + }; + let client = SOCKSv5Client::new(login_info, "localhost", 9997).await; + assert!(client.is_ok()); - // try the positive side - let login_info = LoginInfo { - username_password: Some(UsernamePassword { - username: "awick".to_string(), - password: "password".to_string(), - }), - }; - let client = - SOCKSv5Client::new(network_stack.clone(), login_info, "localhost", 9997).await; - assert!(client.is_ok()); - - // try the negative side - let login_info = LoginInfo { - username_password: Some(UsernamePassword { - username: "adamw".to_string(), - password: "password".to_string(), - }), - }; - let client = SOCKSv5Client::new(network_stack, login_info, "localhost", 9997).await; - assert!(client.is_err()); - }) + // try the negative side + let login_info = LoginInfo { + username_password: Some(UsernamePassword { + username: "adamw".to_string(), + password: "password".to_string(), + }), + }; + let client = SOCKSv5Client::new(login_info, "localhost", 9997).await; + assert!(client.is_err()); } - #[test] - fn firewall_blocks() { - task::block_on(async { - let network_stack = TestingStack::default(); + #[tokio::test] + async fn firewall_blocks() { + // generate the server + let mut security_parameters = SecurityParameters::unrestricted(); + security_parameters.allow_connection = Some(|_| false); + let server = SOCKSv5Server::new(security_parameters); + server.start("localhost", 9996).await.unwrap(); - // generate the server - let mut security_parameters = SecurityParameters::unrestricted(); - security_parameters.allow_connection = Some(|_, _| false); - let server = SOCKSv5Server::new(network_stack.clone(), security_parameters); - server.start("localhost", 9996).await.unwrap(); + let login_info = LoginInfo::new(); + let client = SOCKSv5Client::new(login_info, "localhost", 9996).await; - let login_info = LoginInfo { - username_password: None, - }; - let client = SOCKSv5Client::new(network_stack, login_info, "localhost", 9996).await; - - assert!(client.is_err()); - }) + assert!(client.is_err()); } - #[test] - fn establish_stream() { - task::block_on(async { - let mut network_stack = TestingStack::default(); + #[tokio::test] + async fn establish_stream() -> io::Result<()> { + let target_socket = TcpSocket::new_v4()?; + target_socket.bind(SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), + 1337, + ))?; + let target_port = target_socket.listen(1)?; - let target_port = network_stack.listen("localhost", 1337).await.unwrap(); + // generate the server + let security_parameters = SecurityParameters::unrestricted(); + let server = SOCKSv5Server::new(security_parameters); + server.start("localhost", 9995).await.unwrap(); - // generate the server - let security_parameters = SecurityParameters::unrestricted(); - let server = SOCKSv5Server::new(network_stack.clone(), security_parameters); - server.start("localhost", 9995).await.unwrap(); + let login_info = LoginInfo { + username_password: None, + }; - let login_info = LoginInfo { - username_password: None, - }; + let mut client = SOCKSv5Client::new(login_info, "localhost", 9995) + .await + .unwrap(); - let mut client = SOCKSv5Client::new(network_stack, login_info, "localhost", 9995) + task::spawn(async move { + let mut conn = client.connect("localhost", 1337).await.unwrap(); + conn.write_all(&[1, 3, 3, 7, 9]).await.unwrap(); + }); + + let (mut target_connection, _) = target_port.accept().await.unwrap(); + let mut read_buffer = [0; 4]; + target_connection + .read_exact(&mut read_buffer) + .await + .unwrap(); + assert_eq!(read_buffer, [1, 3, 3, 7]); + Ok(()) + } + + #[tokio::test] + async fn bind_test() -> io::Result<()> { + let security_parameters = SecurityParameters::unrestricted(); + let server = SOCKSv5Server::new(security_parameters); + server.start("localhost", 9994).await.unwrap(); + + let login_info = LoginInfo::default(); + let client = SOCKSv5Client::new(login_info, "localhost", 9994) + .await + .unwrap(); + + let (target_sender, target_receiver) = oneshot::channel(); + + task::spawn(async move { + let (_, _, mut conn) = client + .remote_listen("localhost", 9993, |addr, port| async move { + target_sender.send((addr, port)).unwrap(); + Ok(()) + }) .await .unwrap(); - task::spawn(async move { - let mut conn = client.connect("localhost", 1337).await.unwrap(); - conn.write_all(&[1, 3, 3, 7, 9]).await.unwrap(); - }); + conn.write_all(&[2, 3, 5, 7]).await.unwrap(); + }); - let (mut target_connection, _, _) = target_port.accept().await.unwrap(); - let mut read_buffer = [0; 4]; - target_connection - .read_exact(&mut read_buffer) - .await - .unwrap(); - assert_eq!(read_buffer, [1, 3, 3, 7]); - }) - } - - #[test] - fn bind_test() { - task::block_on(async { - let mut network_stack = TestingStack::default(); - - let security_parameters = SecurityParameters::unrestricted(); - let server = SOCKSv5Server::new(network_stack.clone(), security_parameters); - server.start("localhost", 9994).await.unwrap(); - - let login_info = LoginInfo::default(); - let client = SOCKSv5Client::new(network_stack.clone(), login_info, "localhost", 9994) - .await - .unwrap(); - - let (target_sender, target_receiver) = bounded(1); - - task::spawn(async move { - let (_, _, mut conn) = client - .remote_listen("localhost", 9993, |addr, port| async move { - target_sender.send((addr, port)).await.unwrap(); - Ok(()) - }) - .await - .unwrap(); - - conn.write_all(&[2, 3, 5, 7]).await.unwrap(); - }); - - let (target_addr, target_port) = target_receiver.recv().await.unwrap(); - let mut stream = network_stack - .connect(target_addr, target_port) - .await - .unwrap(); - let mut read_buffer = [0; 4]; - stream.read_exact(&mut read_buffer).await.unwrap(); - assert_eq!(read_buffer, [2, 3, 5, 7]); - }) + let (target_addr, target_port) = target_receiver.await.unwrap(); + let mut stream = match target_addr { + SOCKSv5Address::IP4(x) => TcpStream::connect((x, target_port)).await?, + SOCKSv5Address::IP6(x) => TcpStream::connect((x, target_port)).await?, + SOCKSv5Address::Hostname(x) => TcpStream::connect((x, target_port)).await?, + }; + let mut read_buffer = [0; 4]; + stream.read_exact(&mut read_buffer).await.unwrap(); + assert_eq!(read_buffer, [2, 3, 5, 7]); + Ok(()) } } diff --git a/src/messages.rs b/src/messages.rs index b85209a..8856a18 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -5,12 +5,51 @@ mod client_username_password; mod server_auth_response; mod server_choice; mod server_response; -pub(crate) mod utils; -pub use crate::messages::authentication_method::AuthenticationMethod; -pub use crate::messages::client_command::{ClientConnectionCommand, ClientConnectionRequest}; -pub use crate::messages::client_greeting::ClientGreeting; -pub use crate::messages::client_username_password::ClientUsernamePassword; -pub use crate::messages::server_auth_response::ServerAuthResponse; -pub use crate::messages::server_choice::ServerChoice; -pub use crate::messages::server_response::{ServerResponse, ServerResponseStatus}; +pub(crate) mod string; + +pub use crate::messages::authentication_method::{ + AuthenticationMethod, AuthenticationMethodReadError, AuthenticationMethodWriteError, +}; +pub use crate::messages::client_command::{ + ClientConnectionCommand, ClientConnectionCommandReadError, ClientConnectionCommandWriteError, + ClientConnectionRequest, ClientConnectionRequestReadError, +}; +pub use crate::messages::client_greeting::{ + ClientGreeting, ClientGreetingReadError, ClientGreetingWriteError, +}; +pub use crate::messages::client_username_password::{ + ClientUsernamePassword, ClientUsernamePasswordReadError, ClientUsernamePasswordWriteError, +}; +pub use crate::messages::server_auth_response::{ + ServerAuthResponse, ServerAuthResponseReadError, ServerAuthResponseWriteError, +}; +pub use crate::messages::server_choice::{ + ServerChoice, ServerChoiceReadError, ServerChoiceWriteError, +}; +pub use crate::messages::server_response::{ + ServerResponse, ServerResponseReadError, ServerResponseStatus, ServerResponseWriteError, +}; + +#[doc(hidden)] +#[macro_export] +macro_rules! standard_roundtrip { + ($name: ident, $t: ty) => { + proptest::proptest! { + #[test] + fn $name(xs: $t) { + tokio::runtime::Runtime::new().unwrap().block_on(async { + use std::io::Cursor; + + let buffer = vec![]; + let mut write_cursor = Cursor::new(buffer); + xs.write(&mut write_cursor).await.unwrap(); + let serialized_form = write_cursor.into_inner(); + let mut read_cursor = Cursor::new(serialized_form); + let ys = <$t>::read(&mut read_cursor); + assert_eq!(xs, ys.await.unwrap()); + }) + } + } + }; +} diff --git a/src/messages/authentication_method.rs b/src/messages/authentication_method.rs index eedaf8d..292bf0b 100644 --- a/src/messages/authentication_method.rs +++ b/src/messages/authentication_method.rs @@ -1,16 +1,12 @@ -use crate::errors::{AuthenticationDeserializationError, DeserializationError, SerializationError}; -use crate::standard_roundtrip; #[cfg(test)] -use async_std::task; -#[cfg(test)] -use futures::io::Cursor; -use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use proptest::proptest; -#[cfg(test)] -use proptest::prelude::{Arbitrary, Just, Strategy, prop_oneof}; +use proptest::prelude::{prop_oneof, Arbitrary, Just, Strategy}; #[cfg(test)] use proptest::strategy::BoxedStrategy; use std::fmt; +#[cfg(test)] +use std::io::Cursor; +use thiserror::Error; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; #[allow(clippy::upper_case_acronyms)] #[derive(Clone, Debug, Eq, PartialEq)] @@ -28,6 +24,34 @@ pub enum AuthenticationMethod { NoAcceptableMethods, } +#[derive(Clone, Debug, Error, PartialEq)] +pub enum AuthenticationMethodReadError { + #[error("Invalid authentication method #{0}")] + UnknownAuthenticationMethod(u8), + #[error("Error in underlying buffer: {0}")] + ReadError(String), +} + +impl From for AuthenticationMethodReadError { + fn from(x: std::io::Error) -> AuthenticationMethodReadError { + AuthenticationMethodReadError::ReadError(format!("{}", x)) + } +} + +#[derive(Clone, Debug, Error, PartialEq)] +pub enum AuthenticationMethodWriteError { + #[error("Trying to write invalid authentication method #{0}")] + InvalidAuthMethod(u8), + #[error("Error in underlying buffer: {0}")] + WriteError(String), +} + +impl From for AuthenticationMethodWriteError { + fn from(x: std::io::Error) -> AuthenticationMethodWriteError { + AuthenticationMethodWriteError::WriteError(format!("{}", x)) + } +} + impl fmt::Display for AuthenticationMethod { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { @@ -65,26 +89,17 @@ impl Arbitrary for AuthenticationMethod { Just(AuthenticationMethod::MultiAuthenticationFramework), Just(AuthenticationMethod::JSONPropertyBlock), Just(AuthenticationMethod::NoAcceptableMethods), - (0x80u8..=0xfe).prop_map(AuthenticationMethod::PrivateMethod), - ].boxed() + ] + .boxed() } } - - impl AuthenticationMethod { pub async fn read( r: &mut R, - ) -> Result { - let mut byte_buffer = [0u8; 1]; - let amount_read = r.read(&mut byte_buffer).await?; - - if amount_read == 0 { - return Err(AuthenticationDeserializationError::NoDataFound.into()); - } - - match byte_buffer[0] { + ) -> Result { + match r.read_u8().await? { 0 => Ok(AuthenticationMethod::None), 1 => Ok(AuthenticationMethod::GSSAPI), 2 => Ok(AuthenticationMethod::UsernameAndPassword), @@ -96,14 +111,16 @@ impl AuthenticationMethod { 9 => Ok(AuthenticationMethod::JSONPropertyBlock), x if (0x80..=0xfe).contains(&x) => Ok(AuthenticationMethod::PrivateMethod(x)), 0xff => Ok(AuthenticationMethod::NoAcceptableMethods), - e => Err(AuthenticationDeserializationError::InvalidAuthenticationByte(e).into()), + e => Err(AuthenticationMethodReadError::UnknownAuthenticationMethod( + e, + )), } } pub async fn write( &self, w: &mut W, - ) -> Result<(), SerializationError> { + ) -> Result<(), AuthenticationMethodWriteError> { let value = match self { AuthenticationMethod::None => 0, AuthenticationMethod::GSSAPI => 1, @@ -114,31 +131,32 @@ impl AuthenticationMethod { AuthenticationMethod::NDS => 7, AuthenticationMethod::MultiAuthenticationFramework => 8, AuthenticationMethod::JSONPropertyBlock => 9, - AuthenticationMethod::PrivateMethod(pm) => *pm, + AuthenticationMethod::PrivateMethod(pm) if (0x80..=0xfe).contains(pm) => *pm, + AuthenticationMethod::PrivateMethod(pm) => { + return Err(AuthenticationMethodWriteError::InvalidAuthMethod(*pm)) + } AuthenticationMethod::NoAcceptableMethods => 0xff, }; - Ok(w.write_all(&[value]).await?) + Ok(w.write_u8(value).await?) } } -standard_roundtrip!(auth_byte_roundtrips, AuthenticationMethod); +crate::standard_roundtrip!(auth_byte_roundtrips, AuthenticationMethod); -#[test] -fn bad_byte() { +#[tokio::test] +async fn bad_byte() { let no_len = vec![42]; let mut cursor = Cursor::new(no_len); - let ys = AuthenticationMethod::read(&mut cursor); + let ys = AuthenticationMethod::read(&mut cursor).await.unwrap_err(); assert_eq!( - Err(DeserializationError::AuthenticationMethodError( - AuthenticationDeserializationError::InvalidAuthenticationByte(42) - )), - task::block_on(ys) + AuthenticationMethodReadError::UnknownAuthenticationMethod(42), + ys ); } -#[test] -fn display_isnt_empty() { +#[tokio::test] +async fn display_isnt_empty() { let vals = vec![ AuthenticationMethod::None, AuthenticationMethod::GSSAPI, diff --git a/src/messages/client_command.rs b/src/messages/client_command.rs index 141dcee..021d0f9 100644 --- a/src/messages/client_command.rs +++ b/src/messages/client_command.rs @@ -1,20 +1,10 @@ -use crate::errors::{DeserializationError, SerializationError}; -use crate::network::SOCKSv5Address; -use crate::serialize::read_amt; -use crate::standard_roundtrip; -#[cfg(test)] -use async_std::io::ErrorKind; -#[cfg(test)] -use async_std::task; -#[cfg(test)] -use futures::io::Cursor; -use futures::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; -use log::debug; -use proptest::proptest; +use crate::address::{SOCKSv5Address, SOCKSv5AddressReadError, SOCKSv5AddressWriteError}; #[cfg(test)] use proptest_derive::Arbitrary; #[cfg(test)] -use std::net::Ipv4Addr; +use std::io::Cursor; +use thiserror::Error; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; #[derive(Clone, Copy, Debug, Eq, PartialEq)] #[cfg_attr(test, derive(Arbitrary))] @@ -24,6 +14,60 @@ pub enum ClientConnectionCommand { AssociateUDPPort, } +#[derive(Clone, Debug, Error, PartialEq)] +pub enum ClientConnectionCommandReadError { + #[error("Invalid client connection command code: {0}")] + InvalidClientConnectionCommand(u8), + #[error("Underlying buffer read error: {0}")] + ReadError(String), +} + +impl From for ClientConnectionCommandReadError { + fn from(x: std::io::Error) -> ClientConnectionCommandReadError { + ClientConnectionCommandReadError::ReadError(format!("{}", x)) + } +} + +#[derive(Clone, Debug, Error, PartialEq)] +pub enum ClientConnectionCommandWriteError { + #[error("Underlying buffer write error: {0}")] + WriteError(String), + #[error(transparent)] + SOCKSAddressWriteError(#[from] SOCKSv5AddressWriteError), +} + +impl From for ClientConnectionCommandWriteError { + fn from(x: std::io::Error) -> ClientConnectionCommandWriteError { + ClientConnectionCommandWriteError::WriteError(format!("{}", x)) + } +} + +impl ClientConnectionCommand { + pub async fn read( + r: &mut R, + ) -> Result { + match r.read_u8().await? { + 0x01 => Ok(ClientConnectionCommand::EstablishTCPStream), + 0x02 => Ok(ClientConnectionCommand::EstablishTCPPortBinding), + 0x03 => Ok(ClientConnectionCommand::AssociateUDPPort), + x => Err(ClientConnectionCommandReadError::InvalidClientConnectionCommand(x)), + } + } + + pub async fn write( + &self, + w: &mut W, + ) -> Result<(), std::io::Error> { + match self { + ClientConnectionCommand::EstablishTCPStream => w.write_u8(0x01).await, + ClientConnectionCommand::EstablishTCPPortBinding => w.write_u8(0x02).await, + ClientConnectionCommand::AssociateUDPPort => w.write_u8(0x03).await, + } + } +} + +crate::standard_roundtrip!(client_command_roundtrips, ClientConnectionCommand); + #[derive(Clone, Debug, Eq, PartialEq)] #[cfg_attr(test, derive(Arbitrary))] pub struct ClientConnectionRequest { @@ -32,37 +76,46 @@ pub struct ClientConnectionRequest { pub destination_port: u16, } +#[derive(Clone, Debug, Error, PartialEq)] +pub enum ClientConnectionRequestReadError { + #[error("Invalid version in client request: {0} (expected 5)")] + InvalidVersion(u8), + #[error("Invalid command for client request: {0}")] + InvalidCommand(#[from] ClientConnectionCommandReadError), + #[error("Invalid reserved byte: {0} (expected 0)")] + InvalidReservedByte(u8), + #[error("Underlying read error: {0}")] + ReadError(String), + #[error(transparent)] + AddressReadError(#[from] SOCKSv5AddressReadError), +} + +impl From for ClientConnectionRequestReadError { + fn from(x: std::io::Error) -> ClientConnectionRequestReadError { + ClientConnectionRequestReadError::ReadError(format!("{}", x)) + } +} + impl ClientConnectionRequest { pub async fn read( r: &mut R, - ) -> Result { - let mut buffer = [0; 3]; - - debug!("Starting to read request."); - read_amt(r, 3, &mut buffer).await?; - debug!("Read three opening bytes: {:?}", buffer); - if buffer[0] != 5 { - return Err(DeserializationError::InvalidVersion(5, buffer[0])); + ) -> Result { + let version = r.read_u8().await?; + if version != 5 { + return Err(ClientConnectionRequestReadError::InvalidVersion(version)); } - let command_code = match buffer[1] { - 0x01 => ClientConnectionCommand::EstablishTCPStream, - 0x02 => ClientConnectionCommand::EstablishTCPPortBinding, - 0x03 => ClientConnectionCommand::AssociateUDPPort, - x => return Err(DeserializationError::InvalidClientCommand(x)), - }; - debug!("Command code: {:?}", command_code); + let command_code = ClientConnectionCommand::read(r).await?; - if buffer[2] != 0 { - return Err(DeserializationError::InvalidReservedByte(buffer[2])); + let reserved = r.read_u8().await?; + if reserved != 0 { + return Err(ClientConnectionRequestReadError::InvalidReservedByte( + reserved, + )); } let destination_address = SOCKSv5Address::read(r).await?; - debug!("Destination address: {}", destination_address); - - read_amt(r, 2, &mut buffer).await?; - let destination_port = ((buffer[0] as u16) << 8) + (buffer[1] as u16); - debug!("Destination port: {}", destination_port); + let destination_port = r.read_u16().await?; Ok(ClientConnectionRequest { command_code, @@ -74,63 +127,62 @@ impl ClientConnectionRequest { pub async fn write( &self, w: &mut W, - ) -> Result<(), SerializationError> { - let command = match self.command_code { - ClientConnectionCommand::EstablishTCPStream => 1, - ClientConnectionCommand::EstablishTCPPortBinding => 2, - ClientConnectionCommand::AssociateUDPPort => 3, - }; - - w.write_all(&[5, command, 0]).await?; + ) -> Result<(), ClientConnectionCommandWriteError> { + w.write_u8(5).await?; + self.command_code.write(w).await?; + w.write_u8(0).await?; self.destination_address.write(w).await?; - w.write_all(&[ - (self.destination_port >> 8) as u8, - (self.destination_port & 0xffu16) as u8, - ]) - .await - .map_err(SerializationError::IOError) + w.write_u16(self.destination_port).await?; + Ok(()) } } -standard_roundtrip!(client_request_roundtrips, ClientConnectionRequest); +crate::standard_roundtrip!(client_request_roundtrips, ClientConnectionRequest); -#[test] -fn check_short_reads() { +#[tokio::test] +async fn check_short_reads() { let empty = vec![]; let mut cursor = Cursor::new(empty); - let ys = ClientConnectionRequest::read(&mut cursor); - assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); + let ys = ClientConnectionRequest::read(&mut cursor).await; + assert!(matches!( + ys, + Err(ClientConnectionRequestReadError::ReadError(_)) + )); let no_len = vec![5, 1]; let mut cursor = Cursor::new(no_len); - let ys = ClientConnectionRequest::read(&mut cursor); - assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); + let ys = ClientConnectionRequest::read(&mut cursor).await; + assert!(matches!( + ys, + Err(ClientConnectionRequestReadError::ReadError(_)) + )); } -#[test] -fn check_bad_version() { +#[tokio::test] +async fn check_bad_version() { let bad_ver = vec![6, 1, 1]; let mut cursor = Cursor::new(bad_ver); - let ys = ClientConnectionRequest::read(&mut cursor); - assert_eq!( - Err(DeserializationError::InvalidVersion(5, 6)), - task::block_on(ys) - ); + let ys = ClientConnectionRequest::read(&mut cursor).await; + assert_eq!(Err(ClientConnectionRequestReadError::InvalidVersion(6)), ys); } -#[test] -fn check_bad_command() { +#[tokio::test] +async fn check_bad_command() { let bad_cmd = vec![5, 32, 1]; let mut cursor = Cursor::new(bad_cmd); - let ys = ClientConnectionRequest::read(&mut cursor); + let ys = ClientConnectionRequest::read(&mut cursor).await; assert_eq!( - Err(DeserializationError::InvalidClientCommand(32)), - task::block_on(ys) + Err(ClientConnectionRequestReadError::InvalidCommand( + ClientConnectionCommandReadError::InvalidClientConnectionCommand(32) + )), + ys ); } -#[test] -fn short_write_fails_right() { +#[tokio::test] +async fn short_write_fails_right() { + use std::net::Ipv4Addr; + let mut buffer = [0u8; 2]; let cmd = ClientConnectionRequest { command_code: ClientConnectionCommand::AssociateUDPPort, @@ -138,10 +190,12 @@ fn short_write_fails_right() { destination_port: 22, }; let mut cursor = Cursor::new(&mut buffer as &mut [u8]); - let result = task::block_on(cmd.write(&mut cursor)); + let result = cmd.write(&mut cursor).await; match result { Ok(_) => panic!("Mysteriously able to fit > 2 bytes in 2 bytes."), - Err(SerializationError::IOError(x)) => assert_eq!(ErrorKind::WriteZero, x.kind()), + Err(ClientConnectionCommandWriteError::WriteError(x)) => { + assert!(x.contains("write zero")); + } Err(e) => panic!("Got the wrong error writing too much data: {}", e), } } diff --git a/src/messages/client_greeting.rs b/src/messages/client_greeting.rs index 63787a7..af8faec 100644 --- a/src/messages/client_greeting.rs +++ b/src/messages/client_greeting.rs @@ -1,16 +1,12 @@ -#[cfg(test)] -use crate::errors::AuthenticationDeserializationError; -use crate::errors::{DeserializationError, SerializationError}; -use crate::messages::AuthenticationMethod; -use crate::standard_roundtrip; -#[cfg(test)] -use async_std::task; -#[cfg(test)] -use futures::io::Cursor; -use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use proptest::proptest; +use crate::messages::authentication_method::{ + AuthenticationMethod, AuthenticationMethodReadError, AuthenticationMethodWriteError, +}; #[cfg(test)] use proptest_derive::Arbitrary; +#[cfg(test)] +use std::io::Cursor; +use thiserror::Error; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; /// Client greetings are the first message sent in a SOCKSv5 session. They /// identify that there's a client that wants to talk to a server, and that @@ -23,26 +19,52 @@ pub struct ClientGreeting { pub acceptable_methods: Vec, } +#[derive(Clone, Debug, Error, PartialEq)] +pub enum ClientGreetingReadError { + #[error("Invalid version in client request: {0} (expected 5)")] + InvalidVersion(u8), + #[error(transparent)] + AuthMethodReadError(#[from] AuthenticationMethodReadError), + #[error("Underlying read error: {0}")] + ReadError(String), +} + +impl From for ClientGreetingReadError { + fn from(x: std::io::Error) -> ClientGreetingReadError { + ClientGreetingReadError::ReadError(format!("{}", x)) + } +} + +#[derive(Clone, Debug, Error, PartialEq)] +pub enum ClientGreetingWriteError { + #[error("Too many methods provided; need <256, saw {0}")] + TooManyMethods(usize), + #[error(transparent)] + AuthMethodWriteError(#[from] AuthenticationMethodWriteError), + #[error("Underlying write error: {0}")] + WriteError(String), +} + +impl From for ClientGreetingWriteError { + fn from(x: std::io::Error) -> ClientGreetingWriteError { + ClientGreetingWriteError::WriteError(format!("{}", x)) + } +} + impl ClientGreeting { pub async fn read( r: &mut R, - ) -> Result { - let mut buffer = [0; 1]; + ) -> Result { + let version = r.read_u8().await?; - if r.read(&mut buffer).await? == 0 { - return Err(DeserializationError::NotEnoughData); + if version != 5 { + return Err(ClientGreetingReadError::InvalidVersion(version)); } - if buffer[0] != 5 { - return Err(DeserializationError::InvalidVersion(5, buffer[0])); - } + let num_methods = r.read_u8().await? as usize; - if r.read(&mut buffer).await? == 0 { - return Err(DeserializationError::NotEnoughData); - } - - let mut acceptable_methods = Vec::with_capacity(buffer[0] as usize); - for _ in 0..buffer[0] { + let mut acceptable_methods = Vec::with_capacity(num_methods); + for _ in 0..num_methods { acceptable_methods.push(AuthenticationMethod::read(r).await?); } @@ -52,9 +74,9 @@ impl ClientGreeting { pub async fn write( &self, w: &mut W, - ) -> Result<(), SerializationError> { + ) -> Result<(), ClientGreetingWriteError> { if self.acceptable_methods.len() > 255 { - return Err(SerializationError::TooManyAuthMethods( + return Err(ClientGreetingWriteError::TooManyMethods( self.acceptable_methods.len(), )); } @@ -70,44 +92,41 @@ impl ClientGreeting { } } -standard_roundtrip!(client_greeting_roundtrips, ClientGreeting); +crate::standard_roundtrip!(client_greeting_roundtrips, ClientGreeting); -#[test] -fn check_short_reads() { +#[tokio::test] +async fn check_short_reads() { let empty = vec![]; let mut cursor = Cursor::new(empty); - let ys = ClientGreeting::read(&mut cursor); - assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); + let ys = ClientGreeting::read(&mut cursor).await; + assert!(matches!(ys, Err(ClientGreetingReadError::ReadError(_)))); let no_len = vec![5]; let mut cursor = Cursor::new(no_len); - let ys = ClientGreeting::read(&mut cursor); - assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); + let ys = ClientGreeting::read(&mut cursor).await; + assert!(matches!(ys, Err(ClientGreetingReadError::ReadError(_)))); let bad_len = vec![5, 9]; let mut cursor = Cursor::new(bad_len); - let ys = ClientGreeting::read(&mut cursor); - assert_eq!( - Err(DeserializationError::AuthenticationMethodError( - AuthenticationDeserializationError::NoDataFound - )), - task::block_on(ys) - ); + let ys = ClientGreeting::read(&mut cursor).await; + assert!(matches!( + ys, + Err(ClientGreetingReadError::AuthMethodReadError( + AuthenticationMethodReadError::ReadError(_) + )) + )); } -#[test] -fn check_bad_version() { +#[tokio::test] +async fn check_bad_version() { let no_len = vec![6, 1, 1]; let mut cursor = Cursor::new(no_len); - let ys = ClientGreeting::read(&mut cursor); - assert_eq!( - Err(DeserializationError::InvalidVersion(5, 6)), - task::block_on(ys) - ); + let ys = ClientGreeting::read(&mut cursor).await; + assert_eq!(Err(ClientGreetingReadError::InvalidVersion(6)), ys); } -#[test] -fn check_too_many() { +#[tokio::test] +async fn check_too_many() { let mut auth_methods = Vec::with_capacity(512); auth_methods.resize(512, AuthenticationMethod::ChallengeHandshake); let greet = ClientGreeting { @@ -115,7 +134,7 @@ fn check_too_many() { }; let mut output = vec![0; 1024]; assert_eq!( - Err(SerializationError::TooManyAuthMethods(512)), - task::block_on(greet.write(&mut output)) + Err(ClientGreetingWriteError::TooManyMethods(512)), + greet.write(&mut output).await ); } diff --git a/src/messages/client_username_password.rs b/src/messages/client_username_password.rs index d071a82..c2fe5b3 100644 --- a/src/messages/client_username_password.rs +++ b/src/messages/client_username_password.rs @@ -1,16 +1,10 @@ -use crate::errors::{DeserializationError, SerializationError}; -use crate::serialize::{read_string, write_string}; -use crate::standard_roundtrip; +use crate::messages::string::{SOCKSv5String, SOCKSv5StringReadError, SOCKSv5StringWriteError}; #[cfg(test)] -use async_std::task; +use proptest::prelude::{Arbitrary, BoxedStrategy, Strategy}; #[cfg(test)] -use futures::io::Cursor; -use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -#[cfg(test)] -use proptest::prelude::{Arbitrary, BoxedStrategy}; -use proptest::proptest; -#[cfg(test)] -use proptest::strategy::Strategy; +use std::io::Cursor; +use thiserror::Error; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; #[derive(Clone, Debug, Eq, PartialEq)] pub struct ClientUsernamePassword { @@ -30,30 +24,58 @@ impl Arbitrary for ClientUsernamePassword { fn arbitrary_with(args: Self::Parameters) -> Self::Strategy { let max_len = args.unwrap_or(12) as usize; - (USERNAME_REGEX, PASSWORD_REGEX).prop_map(move |(mut username, mut password)| { - username.shrink_to(max_len); - password.shrink_to(max_len); - ClientUsernamePassword { username, password } - }).boxed() + (USERNAME_REGEX, PASSWORD_REGEX) + .prop_map(move |(mut username, mut password)| { + username.shrink_to(max_len); + password.shrink_to(max_len); + ClientUsernamePassword { username, password } + }) + .boxed() + } +} + +#[derive(Clone, Debug, Error, PartialEq)] +pub enum ClientUsernamePasswordReadError { + #[error("Underlying buffer read error: {0}")] + ReadError(String), + #[error("Invalid username/password version; expected 1, saw {0}")] + InvalidVersion(u8), + #[error(transparent)] + StringError(#[from] SOCKSv5StringReadError), +} + +impl From for ClientUsernamePasswordReadError { + fn from(x: std::io::Error) -> ClientUsernamePasswordReadError { + ClientUsernamePasswordReadError::ReadError(format!("{}", x)) + } +} + +#[derive(Clone, Debug, Error, PartialEq)] +pub enum ClientUsernamePasswordWriteError { + #[error("Underlying buffer read error: {0}")] + WriteError(String), + #[error(transparent)] + StringError(#[from] SOCKSv5StringWriteError), +} + +impl From for ClientUsernamePasswordWriteError { + fn from(x: std::io::Error) -> ClientUsernamePasswordWriteError { + ClientUsernamePasswordWriteError::WriteError(format!("{}", x)) } } impl ClientUsernamePassword { pub async fn read( r: &mut R, - ) -> Result { - let mut buffer = [0; 1]; + ) -> Result { + let version = r.read_u8().await?; - if r.read(&mut buffer).await? == 0 { - return Err(DeserializationError::NotEnoughData); + if version != 1 { + return Err(ClientUsernamePasswordReadError::InvalidVersion(version)); } - if buffer[0] != 1 { - return Err(DeserializationError::InvalidVersion(1, buffer[0])); - } - - let username = read_string(r).await?; - let password = read_string(r).await?; + let username = SOCKSv5String::read(r).await?.into(); + let password = SOCKSv5String::read(r).await?.into(); Ok(ClientUsernamePassword { username, password }) } @@ -61,35 +83,40 @@ impl ClientUsernamePassword { pub async fn write( &self, w: &mut W, - ) -> Result<(), SerializationError> { - w.write_all(&[1]).await?; - write_string(&self.username, w).await?; - write_string(&self.password, w).await + ) -> Result<(), ClientUsernamePasswordWriteError> { + w.write_u8(1).await?; + SOCKSv5String::from(self.username.as_str()).write(w).await?; + SOCKSv5String::from(self.password.as_str()).write(w).await?; + Ok(()) } } -standard_roundtrip!(username_password_roundtrips, ClientUsernamePassword); +crate::standard_roundtrip!(username_password_roundtrips, ClientUsernamePassword); -#[test] -fn check_short_reads() { +#[tokio::test] +async fn heck_short_reads() { let empty = vec![]; let mut cursor = Cursor::new(empty); - let ys = ClientUsernamePassword::read(&mut cursor); - assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); + let ys = ClientUsernamePassword::read(&mut cursor).await; + assert!(matches!( + ys, + Err(ClientUsernamePasswordReadError::ReadError(_)) + )); let user_only = vec![1, 3, 102, 111, 111]; let mut cursor = Cursor::new(user_only); - let ys = ClientUsernamePassword::read(&mut cursor); - assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); + let ys = ClientUsernamePassword::read(&mut cursor).await; + println!("ys: {:?}", ys); + assert!(matches!( + ys, + Err(ClientUsernamePasswordReadError::StringError(_)) + )); } -#[test] -fn check_bad_version() { +#[tokio::test] +async fn check_bad_version() { let bad_len = vec![5]; let mut cursor = Cursor::new(bad_len); - let ys = ClientUsernamePassword::read(&mut cursor); - assert_eq!( - Err(DeserializationError::InvalidVersion(1, 5)), - task::block_on(ys) - ); + let ys = ClientUsernamePassword::read(&mut cursor).await; + assert_eq!(Err(ClientUsernamePasswordReadError::InvalidVersion(5)), ys); } diff --git a/src/messages/server_auth_response.rs b/src/messages/server_auth_response.rs index 0963265..a527ef5 100644 --- a/src/messages/server_auth_response.rs +++ b/src/messages/server_auth_response.rs @@ -1,13 +1,7 @@ -use crate::errors::{DeserializationError, SerializationError}; -use crate::standard_roundtrip; -#[cfg(test)] -use async_std::task; -#[cfg(test)] -use futures::io::Cursor; -use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use proptest::proptest; #[cfg(test)] use proptest_derive::Arbitrary; +use thiserror::Error; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; #[derive(Clone, Debug, Eq, PartialEq)] #[cfg_attr(test, derive(Arbitrary))] @@ -15,6 +9,32 @@ pub struct ServerAuthResponse { pub success: bool, } +#[derive(Clone, Debug, Error, PartialEq)] +pub enum ServerAuthResponseReadError { + #[error("Underlying buffer read error: {0}")] + ReadError(String), + #[error("Invalid username/password version; expected 1, saw {0}")] + InvalidVersion(u8), +} + +impl From for ServerAuthResponseReadError { + fn from(x: std::io::Error) -> ServerAuthResponseReadError { + ServerAuthResponseReadError::ReadError(format!("{}", x)) + } +} + +#[derive(Clone, Debug, Error, PartialEq)] +pub enum ServerAuthResponseWriteError { + #[error("Underlying buffer read error: {0}")] + WriteError(String), +} + +impl From for ServerAuthResponseWriteError { + fn from(x: std::io::Error) -> ServerAuthResponseWriteError { + ServerAuthResponseWriteError::WriteError(format!("{}", x)) + } +} + impl ServerAuthResponse { pub fn success() -> ServerAuthResponse { ServerAuthResponse { success: true } @@ -26,30 +46,22 @@ impl ServerAuthResponse { pub async fn read( r: &mut R, - ) -> Result { - let mut buffer = [0; 1]; + ) -> Result { + let version = r.read_u8().await?; - if r.read(&mut buffer).await? == 0 { - return Err(DeserializationError::NotEnoughData); - } - - if buffer[0] != 1 { - return Err(DeserializationError::InvalidVersion(1, buffer[0])); - } - - if r.read(&mut buffer).await? == 0 { - return Err(DeserializationError::NotEnoughData); + if version != 1 { + return Err(ServerAuthResponseReadError::InvalidVersion(version)); } Ok(ServerAuthResponse { - success: buffer[0] == 0, + success: r.read_u8().await? == 0, }) } pub async fn write( &self, w: &mut W, - ) -> Result<(), SerializationError> { + ) -> Result<(), ServerAuthResponseWriteError> { w.write_all(&[1]).await?; w.write_all(&[if self.success { 0x00 } else { 0xde }]) .await?; @@ -57,28 +69,29 @@ impl ServerAuthResponse { } } -standard_roundtrip!(server_auth_response, ServerAuthResponse); +crate::standard_roundtrip!(server_auth_response, ServerAuthResponse); + +#[tokio::test] +async fn check_short_reads() { + use std::io::Cursor; -#[test] -fn check_short_reads() { let empty = vec![]; let mut cursor = Cursor::new(empty); - let ys = ServerAuthResponse::read(&mut cursor); - assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); + let ys = ServerAuthResponse::read(&mut cursor).await; + assert!(matches!(ys, Err(ServerAuthResponseReadError::ReadError(_)))); let no_len = vec![1]; let mut cursor = Cursor::new(no_len); - let ys = ServerAuthResponse::read(&mut cursor); - assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); + let ys = ServerAuthResponse::read(&mut cursor).await; + assert!(matches!(ys, Err(ServerAuthResponseReadError::ReadError(_)))); } -#[test] -fn check_bad_version() { +#[tokio::test] +async fn check_bad_version() { + use std::io::Cursor; + let no_len = vec![6, 1]; let mut cursor = Cursor::new(no_len); - let ys = ServerAuthResponse::read(&mut cursor); - assert_eq!( - Err(DeserializationError::InvalidVersion(1, 6)), - task::block_on(ys) - ); + let ys = ServerAuthResponse::read(&mut cursor).await; + assert_eq!(Err(ServerAuthResponseReadError::InvalidVersion(6)), ys); } diff --git a/src/messages/server_choice.rs b/src/messages/server_choice.rs index 2b452e4..e6d089c 100644 --- a/src/messages/server_choice.rs +++ b/src/messages/server_choice.rs @@ -1,16 +1,12 @@ -#[cfg(test)] -use crate::errors::AuthenticationDeserializationError; -use crate::errors::{DeserializationError, SerializationError}; -use crate::messages::AuthenticationMethod; -use crate::standard_roundtrip; -#[cfg(test)] -use async_std::task; -#[cfg(test)] -use futures::io::Cursor; -use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use proptest::proptest; +use crate::messages::authentication_method::{ + AuthenticationMethod, AuthenticationMethodReadError, AuthenticationMethodWriteError, +}; #[cfg(test)] use proptest_derive::Arbitrary; +#[cfg(test)] +use std::io::Cursor; +use thiserror::Error; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; #[derive(Clone, Debug, Eq, PartialEq)] #[cfg_attr(test, derive(Arbitrary))] @@ -18,6 +14,36 @@ pub struct ServerChoice { pub chosen_method: AuthenticationMethod, } +#[derive(Clone, Debug, Error, PartialEq)] +pub enum ServerChoiceReadError { + #[error(transparent)] + AuthMethodError(#[from] AuthenticationMethodReadError), + #[error("Error in underlying buffer: {0}")] + ReadError(String), + #[error("Invalid version; expected 5, got {0}")] + InvalidVersion(u8), +} + +impl From for ServerChoiceReadError { + fn from(x: std::io::Error) -> ServerChoiceReadError { + ServerChoiceReadError::ReadError(format!("{}", x)) + } +} + +#[derive(Clone, Debug, Error, PartialEq)] +pub enum ServerChoiceWriteError { + #[error(transparent)] + AuthMethodError(#[from] AuthenticationMethodWriteError), + #[error("Error in underlying buffer: {0}")] + WriteError(String), +} + +impl From for ServerChoiceWriteError { + fn from(x: std::io::Error) -> ServerChoiceWriteError { + ServerChoiceWriteError::WriteError(format!("{}", x)) + } +} + impl ServerChoice { pub fn rejection() -> ServerChoice { ServerChoice { @@ -33,15 +59,11 @@ impl ServerChoice { pub async fn read( r: &mut R, - ) -> Result { - let mut buffer = [0; 1]; + ) -> Result { + let version = r.read_u8().await?; - if r.read(&mut buffer).await? == 0 { - return Err(DeserializationError::NotEnoughData); - } - - if buffer[0] != 5 { - return Err(DeserializationError::InvalidVersion(5, buffer[0])); + if version != 5 { + return Err(ServerChoiceReadError::InvalidVersion(version)); } let chosen_method = AuthenticationMethod::read(r).await?; @@ -52,39 +74,32 @@ impl ServerChoice { pub async fn write( &self, w: &mut W, - ) -> Result<(), SerializationError> { - w.write_all(&[5]).await?; - self.chosen_method.write(w).await + ) -> Result<(), ServerChoiceWriteError> { + w.write_u8(5).await?; + self.chosen_method.write(w).await?; + Ok(()) } } -standard_roundtrip!(server_choice_roundtrips, ServerChoice); +crate::standard_roundtrip!(server_choice_roundtrips, ServerChoice); -#[test] -fn check_short_reads() { +#[tokio::test] +async fn check_short_reads() { let empty = vec![]; let mut cursor = Cursor::new(empty); - let ys = ServerChoice::read(&mut cursor); - assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); + let ys = ServerChoice::read(&mut cursor).await; + assert!(matches!(ys, Err(ServerChoiceReadError::ReadError(_)))); let bad_len = vec![5]; let mut cursor = Cursor::new(bad_len); - let ys = ServerChoice::read(&mut cursor); - assert_eq!( - Err(DeserializationError::AuthenticationMethodError( - AuthenticationDeserializationError::NoDataFound - )), - task::block_on(ys) - ); + let ys = ServerChoice::read(&mut cursor).await; + assert!(matches!(ys, Err(ServerChoiceReadError::AuthMethodError(_)))); } -#[test] -fn check_bad_version() { +#[tokio::test] +async fn check_bad_version() { let no_len = vec![9, 1]; let mut cursor = Cursor::new(no_len); - let ys = ServerChoice::read(&mut cursor); - assert_eq!( - Err(DeserializationError::InvalidVersion(5, 9)), - task::block_on(ys) - ); + let ys = ServerChoice::read(&mut cursor).await; + assert_eq!(Err(ServerChoiceReadError::InvalidVersion(9)), ys); } diff --git a/src/messages/server_response.rs b/src/messages/server_response.rs index bf35d8f..5de5d58 100644 --- a/src/messages/server_response.rs +++ b/src/messages/server_response.rs @@ -1,21 +1,10 @@ -use crate::errors::{DeserializationError, SerializationError}; -use crate::network::generic::IntoErrorResponse; -use crate::network::SOCKSv5Address; -use crate::serialize::read_amt; -use crate::standard_roundtrip; -#[cfg(test)] -use async_std::io::ErrorKind; -#[cfg(test)] -use async_std::task; -#[cfg(test)] -use futures::io::Cursor; -use futures::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; -use log::warn; -use proptest::proptest; +use crate::address::{SOCKSv5Address, SOCKSv5AddressReadError, SOCKSv5AddressWriteError}; #[cfg(test)] use proptest_derive::Arbitrary; -use std::net::Ipv4Addr; +#[cfg(test)] +use std::io::Cursor; use thiserror::Error; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; #[derive(Clone, Debug, Eq, Error, PartialEq)] #[cfg_attr(test, derive(Arbitrary))] @@ -40,12 +29,6 @@ pub enum ServerResponseStatus { AddressTypeNotSupported, } -impl IntoErrorResponse for ServerResponseStatus { - fn into_response(&self) -> ServerResponseStatus { - self.clone() - } -} - #[derive(Clone, Debug, Eq, PartialEq)] #[cfg_attr(test, derive(Arbitrary))] pub struct ServerResponse { @@ -54,33 +37,57 @@ pub struct ServerResponse { pub bound_port: u16, } -impl ServerResponse { - pub fn error(resp: &E) -> ServerResponse { - ServerResponse { - status: resp.into_response(), - bound_address: SOCKSv5Address::IP4(Ipv4Addr::new(0, 0, 0, 0)), - bound_port: 0, - } +#[derive(Clone, Debug, Error, PartialEq)] +pub enum ServerResponseReadError { + #[error("Error reading from underlying buffer: {0}")] + ReadError(String), + #[error(transparent)] + AddressReadError(#[from] SOCKSv5AddressReadError), + #[error("Invalid version; expected 5, got {0}")] + InvalidVersion(u8), + #[error("Invalid reserved byte; saw {0}, should be 0")] + InvalidReservedByte(u8), + #[error("Invalid (or just unknown) server response value {0}")] + InvalidServerResponse(u8), +} + +impl From for ServerResponseReadError { + fn from(x: std::io::Error) -> ServerResponseReadError { + ServerResponseReadError::ReadError(format!("{}", x)) + } +} + +#[derive(Clone, Debug, Error, PartialEq)] +pub enum ServerResponseWriteError { + #[error("Error reading from underlying buffer: {0}")] + WriteError(String), + #[error(transparent)] + AddressWriteError(#[from] SOCKSv5AddressWriteError), +} + +impl From for ServerResponseWriteError { + fn from(x: std::io::Error) -> ServerResponseWriteError { + ServerResponseWriteError::WriteError(format!("{}", x)) } } impl ServerResponse { pub async fn read( r: &mut R, - ) -> Result { - let mut buffer = [0; 3]; - - read_amt(r, 3, &mut buffer).await?; - - if buffer[0] != 5 { - return Err(DeserializationError::InvalidVersion(5, buffer[0])); + ) -> Result { + let version = r.read_u8().await?; + if version != 5 { + return Err(ServerResponseReadError::InvalidVersion(version)); } - if buffer[2] != 0 { - warn!(target: "async-socks5", "Hey, this isn't terrible, but the server is sending invalid reserved bytes."); + let status_byte = r.read_u8().await?; + + let reserved_byte = r.read_u8().await?; + if reserved_byte != 0 { + return Err(ServerResponseReadError::InvalidReservedByte(reserved_byte)); } - let status = match buffer[1] { + let status = match status_byte { 0x00 => ServerResponseStatus::RequestGranted, 0x01 => ServerResponseStatus::GeneralFailure, 0x02 => ServerResponseStatus::ConnectionNotAllowedByRule, @@ -90,12 +97,11 @@ impl ServerResponse { 0x06 => ServerResponseStatus::TTLExpired, 0x07 => ServerResponseStatus::CommandNotSupported, 0x08 => ServerResponseStatus::AddressTypeNotSupported, - x => return Err(DeserializationError::InvalidServerResponse(x)), + x => return Err(ServerResponseReadError::InvalidServerResponse(x)), }; let bound_address = SOCKSv5Address::read(r).await?; - read_amt(r, 2, &mut buffer).await?; - let bound_port = ((buffer[0] as u16) << 8) + (buffer[1] as u16); + let bound_port = r.read_u16().await?; Ok(ServerResponse { status, @@ -107,7 +113,9 @@ impl ServerResponse { pub async fn write( &self, w: &mut W, - ) -> Result<(), SerializationError> { + ) -> Result<(), ServerResponseWriteError> { + w.write_u8(5).await?; + let status_code = match self.status { ServerResponseStatus::RequestGranted => 0x00, ServerResponseStatus::GeneralFailure => 0x01, @@ -119,59 +127,61 @@ impl ServerResponse { ServerResponseStatus::CommandNotSupported => 0x07, ServerResponseStatus::AddressTypeNotSupported => 0x08, }; - - w.write_all(&[5, status_code, 0]).await?; + w.write_u8(status_code).await?; + w.write_u8(0).await?; self.bound_address.write(w).await?; - w.write_all(&[ - (self.bound_port >> 8) as u8, - (self.bound_port & 0xffu16) as u8, - ]) - .await - .map_err(SerializationError::IOError) + w.write_u16(self.bound_port).await?; + + Ok(()) } } -standard_roundtrip!(server_response_roundtrips, ServerResponse); +crate::standard_roundtrip!(server_response_roundtrips, ServerResponse); -#[test] -fn check_short_reads() { +#[tokio::test] +async fn check_short_reads() { let empty = vec![]; let mut cursor = Cursor::new(empty); - let ys = ServerResponse::read(&mut cursor); - assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); + let ys = ServerResponse::read(&mut cursor).await; + assert!(matches!(ys, Err(ServerResponseReadError::ReadError(_)))); } -#[test] -fn check_bad_version() { +#[tokio::test] +async fn check_bad_version() { let bad_ver = vec![6, 1, 1]; let mut cursor = Cursor::new(bad_ver); - let ys = ServerResponse::read(&mut cursor); - assert_eq!( - Err(DeserializationError::InvalidVersion(5, 6)), - task::block_on(ys) - ); + let ys = ServerResponse::read(&mut cursor).await; + assert_eq!(Err(ServerResponseReadError::InvalidVersion(6)), ys); } -#[test] -fn check_bad_command() { +#[tokio::test] +async fn check_bad_reserved() { let bad_cmd = vec![5, 32, 0x42]; let mut cursor = Cursor::new(bad_cmd); - let ys = ServerResponse::read(&mut cursor); - assert_eq!( - Err(DeserializationError::InvalidServerResponse(32)), - task::block_on(ys) - ); + let ys = ServerResponse::read(&mut cursor).await; + assert_eq!(Err(ServerResponseReadError::InvalidReservedByte(0x42)), ys); } -#[test] -fn short_write_fails_right() { - let mut buffer = [0u8; 2]; - let cmd = ServerResponse::error(&ServerResponseStatus::AddressTypeNotSupported); - let mut cursor = Cursor::new(&mut buffer as &mut [u8]); - let result = task::block_on(cmd.write(&mut cursor)); - match result { - Ok(_) => panic!("Mysteriously able to fit > 2 bytes in 2 bytes."), - Err(SerializationError::IOError(x)) => assert_eq!(ErrorKind::WriteZero, x.kind()), - Err(e) => panic!("Got the wrong error writing too much data: {}", e), - } +#[tokio::test] +async fn check_bad_command() { + let bad_cmd = vec![5, 32, 0]; + let mut cursor = Cursor::new(bad_cmd); + let ys = ServerResponse::read(&mut cursor).await; + assert_eq!(Err(ServerResponseReadError::InvalidServerResponse(32)), ys); +} + +#[tokio::test] +async fn short_write_fails_right() { + let mut buffer = [0u8; 2]; + let cmd = ServerResponse { + status: ServerResponseStatus::AddressTypeNotSupported, + bound_address: SOCKSv5Address::Hostname("tester.com".to_string()), + bound_port: 99, + }; + let mut cursor = Cursor::new(&mut buffer as &mut [u8]); + let result = cmd.write(&mut cursor).await; + assert!(matches!( + result, + Err(ServerResponseWriteError::WriteError(_)) + )); } diff --git a/src/messages/string.rs b/src/messages/string.rs new file mode 100644 index 0000000..7af0fe7 --- /dev/null +++ b/src/messages/string.rs @@ -0,0 +1,117 @@ +#[cfg(test)] +use proptest::prelude::{Arbitrary, BoxedStrategy, Strategy}; +use std::convert::TryFrom; +use std::string::FromUtf8Error; +use thiserror::Error; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +#[derive(Debug, PartialEq)] +pub struct SOCKSv5String(String); + +#[cfg(test)] +const STRING_REGEX: &str = "[a-zA-Z0-9_.|!@#$%^]+"; + +#[cfg(test)] +impl Arbitrary for SOCKSv5String { + type Parameters = Option; + type Strategy = BoxedStrategy; + + fn arbitrary_with(args: Self::Parameters) -> Self::Strategy { + let max_len = args.unwrap_or(32) as usize; + + STRING_REGEX + .prop_map(move |mut str| { + str.shrink_to(max_len); + SOCKSv5String(str) + }) + .boxed() + } +} + +#[derive(Clone, Debug, Error, PartialEq)] +pub enum SOCKSv5StringReadError { + #[error("Underlying buffer read error: {0}")] + ReadError(String), + #[error("SOCKSv5 string encoding error; encountered empty string (?)")] + ZeroStringLength, + #[error("Invalid UTF-8 string: {0}")] + InvalidUtf8Error(#[from] FromUtf8Error), +} + +impl From for SOCKSv5StringReadError { + fn from(x: std::io::Error) -> SOCKSv5StringReadError { + SOCKSv5StringReadError::ReadError(format!("{}", x)) + } +} + +#[derive(Clone, Debug, Error, PartialEq)] +pub enum SOCKSv5StringWriteError { + #[error("Underlying buffer write error: {0}")] + WriteError(String), + #[error("String too large to encode according to SOCKSv5 reuls ({0} bytes long)")] + TooBig(usize), + #[error("Cannot serialize the empty string in SOCKSv5")] + ZeroStringLength, +} + +impl From for SOCKSv5StringWriteError { + fn from(x: std::io::Error) -> SOCKSv5StringWriteError { + SOCKSv5StringWriteError::WriteError(format!("{}", x)) + } +} + +impl SOCKSv5String { + pub async fn read(r: &mut R) -> Result { + let length = r.read_u8().await? as usize; + + if length == 0 { + return Err(SOCKSv5StringReadError::ZeroStringLength); + } + + let mut bytestring = vec![0; length]; + r.read_exact(&mut bytestring).await?; + + Ok(SOCKSv5String(String::from_utf8(bytestring)?)) + } + + pub async fn write( + &self, + w: &mut W, + ) -> Result<(), SOCKSv5StringWriteError> { + let bytestring = self.0.as_bytes(); + + if bytestring.is_empty() { + return Err(SOCKSv5StringWriteError::ZeroStringLength); + } + + let length = match u8::try_from(bytestring.len()) { + Err(_) => return Err(SOCKSv5StringWriteError::TooBig(bytestring.len())), + Ok(x) => x, + }; + + w.write_u8(length).await?; + w.write_all(bytestring).await?; + + Ok(()) + } +} + +impl From for SOCKSv5String { + fn from(x: String) -> Self { + SOCKSv5String(x) + } +} + +impl<'a> From<&'a str> for SOCKSv5String { + fn from(x: &str) -> Self { + SOCKSv5String(x.to_string()) + } +} + +impl From for String { + fn from(x: SOCKSv5String) -> Self { + x.0 + } +} + +crate::standard_roundtrip!(socks_string_roundtrips, SOCKSv5String); diff --git a/src/messages/utils.rs b/src/messages/utils.rs deleted file mode 100644 index 8c44441..0000000 --- a/src/messages/utils.rs +++ /dev/null @@ -1,16 +0,0 @@ -#[doc(hidden)] -#[macro_export] -macro_rules! standard_roundtrip { - ($name: ident, $t: ty) => { - proptest! { - #[test] - fn $name(xs: $t) { - let mut buffer = vec![]; - task::block_on(xs.write(&mut buffer)).unwrap(); - let mut cursor = Cursor::new(buffer); - let ys = <$t>::read(&mut cursor); - assert_eq!(xs, task::block_on(ys).unwrap()); - } - } - }; -} diff --git a/src/network.rs b/src/network.rs deleted file mode 100644 index 37d9917..0000000 --- a/src/network.rs +++ /dev/null @@ -1,10 +0,0 @@ -pub mod address; -pub mod datagram; -pub mod generic; -pub mod listener; -pub mod standard; -pub mod stream; -pub mod testing; - -pub use crate::network::address::SOCKSv5Address; -pub use crate::network::standard::Builtin; diff --git a/src/network/address.rs b/src/network/address.rs deleted file mode 100644 index 7c40855..0000000 --- a/src/network/address.rs +++ /dev/null @@ -1,264 +0,0 @@ -use crate::errors::{DeserializationError, SerializationError}; -use crate::serialize::{read_amt, read_string, write_string}; -use crate::standard_roundtrip; -#[cfg(test)] -use async_std::task; -#[cfg(test)] -use futures::io::Cursor; -use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use proptest::prelude::proptest; -#[cfg(test)] -use proptest::prelude::{Arbitrary, BoxedStrategy, Strategy, any, prop_oneof}; -use std::convert::TryFrom; -use std::fmt; -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; -use thiserror::Error; - -#[derive(Clone, Debug, Eq, Hash, PartialEq)] -pub enum SOCKSv5Address { - IP4(Ipv4Addr), - IP6(Ipv6Addr), - Name(String), -} - -#[cfg(test)] -const HOSTNAME_REGEX: &str = "[a-zA-Z0-9_.]+"; - -#[cfg(test)] -impl Arbitrary for SOCKSv5Address { - type Parameters = Option; - type Strategy = BoxedStrategy; - - fn arbitrary_with(args: Self::Parameters) -> Self::Strategy { - let max_len = args.unwrap_or(32) as usize; - - prop_oneof![ - any::().prop_map(SOCKSv5Address::IP4), - any::().prop_map(SOCKSv5Address::IP6), - HOSTNAME_REGEX.prop_map(move |mut hostname| { - hostname.shrink_to(max_len); - SOCKSv5Address::Name(hostname) - }), - ].boxed() - } -} - -#[derive(Error, Debug, PartialEq)] -pub enum AddressConversionError { - #[error("Couldn't convert IPv4 address into destination type")] - CouldntConvertIP4, - #[error("Couldn't convert IPv6 address into destination type")] - CouldntConvertIP6, - #[error("Couldn't convert name into destination type")] - CouldntConvertName, -} - -impl From for SOCKSv5Address { - fn from(x: IpAddr) -> SOCKSv5Address { - match x { - IpAddr::V4(a) => SOCKSv5Address::IP4(a), - IpAddr::V6(a) => SOCKSv5Address::IP6(a), - } - } -} - -impl TryFrom for IpAddr { - type Error = AddressConversionError; - - fn try_from(value: SOCKSv5Address) -> Result { - match value { - SOCKSv5Address::IP4(a) => Ok(IpAddr::V4(a)), - SOCKSv5Address::IP6(a) => Ok(IpAddr::V6(a)), - SOCKSv5Address::Name(_) => Err(AddressConversionError::CouldntConvertName), - } - } -} - -impl From for SOCKSv5Address { - fn from(x: Ipv4Addr) -> Self { - SOCKSv5Address::IP4(x) - } -} - -impl TryFrom for Ipv4Addr { - type Error = AddressConversionError; - - fn try_from(value: SOCKSv5Address) -> Result { - match value { - SOCKSv5Address::IP4(a) => Ok(a), - SOCKSv5Address::IP6(_) => Err(AddressConversionError::CouldntConvertIP6), - SOCKSv5Address::Name(_) => Err(AddressConversionError::CouldntConvertName), - } - } -} - -impl From for SOCKSv5Address { - fn from(x: Ipv6Addr) -> Self { - SOCKSv5Address::IP6(x) - } -} - -impl TryFrom for Ipv6Addr { - type Error = AddressConversionError; - - fn try_from(value: SOCKSv5Address) -> Result { - match value { - SOCKSv5Address::IP4(_) => Err(AddressConversionError::CouldntConvertIP4), - SOCKSv5Address::IP6(a) => Ok(a), - SOCKSv5Address::Name(_) => Err(AddressConversionError::CouldntConvertName), - } - } -} - -impl From for SOCKSv5Address { - fn from(x: String) -> Self { - SOCKSv5Address::Name(x) - } -} - -impl<'a> From<&'a str> for SOCKSv5Address { - fn from(x: &str) -> SOCKSv5Address { - SOCKSv5Address::Name(x.to_string()) - } -} - -impl fmt::Display for SOCKSv5Address { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - SOCKSv5Address::IP4(a) => write!(f, "{}", a), - SOCKSv5Address::IP6(a) => write!(f, "{}", a), - SOCKSv5Address::Name(a) => write!(f, "{}", a), - } - } -} - -impl SOCKSv5Address { - pub async fn read( - r: &mut R, - ) -> Result { - let mut byte_buffer = [0u8; 1]; - let amount_read = r.read(&mut byte_buffer).await?; - - if amount_read == 0 { - return Err(DeserializationError::NotEnoughData); - } - - match byte_buffer[0] { - 1 => { - let mut addr_buffer = [0; 4]; - read_amt(r, 4, &mut addr_buffer).await?; - Ok(SOCKSv5Address::IP4(Ipv4Addr::from(addr_buffer))) - } - 3 => { - let name = read_string(r).await?; - Ok(SOCKSv5Address::Name(name)) - } - 4 => { - let mut addr_buffer = [0; 16]; - read_amt(r, 16, &mut addr_buffer).await?; - Ok(SOCKSv5Address::IP6(Ipv6Addr::from(addr_buffer))) - } - x => Err(DeserializationError::InvalidAddressType(x)), - } - } - - pub async fn write( - &self, - w: &mut W, - ) -> Result<(), SerializationError> { - match self { - SOCKSv5Address::IP4(x) => { - w.write_all(&[1]).await?; - w.write_all(&x.octets()) - .await - .map_err(SerializationError::IOError) - } - SOCKSv5Address::IP6(x) => { - w.write_all(&[4]).await?; - w.write_all(&x.octets()) - .await - .map_err(SerializationError::IOError) - } - SOCKSv5Address::Name(x) => { - w.write_all(&[3]).await?; - write_string(x, w).await - } - } - } -} - -pub trait HasLocalAddress { - fn local_addr(&self) -> (SOCKSv5Address, u16); -} - -standard_roundtrip!(address_roundtrips, SOCKSv5Address); - -proptest! { - #[test] - fn ip_conversion(x: IpAddr) { - match x { - IpAddr::V4(ref a) => - assert_eq!(Err(AddressConversionError::CouldntConvertIP4), - Ipv6Addr::try_from(SOCKSv5Address::from(*a))), - IpAddr::V6(ref a) => - assert_eq!(Err(AddressConversionError::CouldntConvertIP6), - Ipv4Addr::try_from(SOCKSv5Address::from(*a))), - } - assert_eq!(x, IpAddr::try_from(SOCKSv5Address::from(x)).unwrap()); - } - - #[test] - fn ip4_conversion(x: Ipv4Addr) { - assert_eq!(x, Ipv4Addr::try_from(SOCKSv5Address::from(x)).unwrap()); - } - - #[test] - fn ip6_conversion(x: Ipv6Addr) { - assert_eq!(x, Ipv6Addr::try_from(SOCKSv5Address::from(x)).unwrap()); - } - - #[test] - fn display_matches(x: SOCKSv5Address) { - match x { - SOCKSv5Address::IP4(a) => assert_eq!(format!("{}", a), format!("{}", x)), - SOCKSv5Address::IP6(a) => assert_eq!(format!("{}", a), format!("{}", x)), - SOCKSv5Address::Name(ref a) => assert_eq!(*a, x.to_string()), - } - } - - #[test] - fn bad_read_key(x: u8) { - match x { - 1 | 3 | 4 => {} - _ => { - let buffer = [x, 0, 1, 2, 9, 10]; - let mut cursor = Cursor::new(buffer); - let meh = SOCKSv5Address::read(&mut cursor); - assert_eq!(Err(DeserializationError::InvalidAddressType(x)), task::block_on(meh)); - } - } - } -} - -#[test] -fn domain_name_sanity() { - let name = "uhsure.com"; - let strname = name.to_string(); - - let addr1 = SOCKSv5Address::from(name); - let addr2 = SOCKSv5Address::from(strname); - - assert_eq!(addr1, addr2); - assert_eq!( - Err(AddressConversionError::CouldntConvertName), - IpAddr::try_from(addr1.clone()) - ); - assert_eq!( - Err(AddressConversionError::CouldntConvertName), - Ipv4Addr::try_from(addr1.clone()) - ); - assert_eq!( - Err(AddressConversionError::CouldntConvertName), - Ipv6Addr::try_from(addr1) - ); -} diff --git a/src/network/datagram.rs b/src/network/datagram.rs deleted file mode 100644 index 6f029d3..0000000 --- a/src/network/datagram.rs +++ /dev/null @@ -1,43 +0,0 @@ -use crate::network::address::{HasLocalAddress, SOCKSv5Address}; -use async_trait::async_trait; - -#[async_trait] -pub trait Datagramlike: Send + Sync + HasLocalAddress { - type Error; - - async fn send_to( - &self, - buf: &[u8], - addr: SOCKSv5Address, - port: u16, - ) -> Result; - async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SOCKSv5Address, u16), Self::Error>; -} - -pub struct GenericDatagramSocket { - pub internal: Box>, -} - -#[async_trait] -impl Datagramlike for GenericDatagramSocket { - type Error = E; - - async fn send_to( - &self, - buf: &[u8], - addr: SOCKSv5Address, - port: u16, - ) -> Result { - Ok(self.internal.send_to(buf, addr, port).await?) - } - - async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SOCKSv5Address, u16), Self::Error> { - Ok(self.internal.recv_from(buf).await?) - } -} - -impl HasLocalAddress for GenericDatagramSocket { - fn local_addr(&self) -> (SOCKSv5Address, u16) { - self.internal.local_addr() - } -} diff --git a/src/network/generic.rs b/src/network/generic.rs deleted file mode 100644 index 0a82759..0000000 --- a/src/network/generic.rs +++ /dev/null @@ -1,61 +0,0 @@ -use crate::messages::ServerResponseStatus; -use crate::network::address::SOCKSv5Address; -use crate::network::datagram::GenericDatagramSocket; -use crate::network::listener::GenericListener; -use crate::network::stream::GenericStream; -use async_trait::async_trait; -use std::fmt::{Debug, Display}; - -#[async_trait] -pub trait Networklike { - /// The error type for things that fail on this network. Apologies in advance - /// for using only one; if you have a use case for separating your errors, - /// please shoot the author(s) and email to split this into multiple types, one - /// for each trait function. - type Error: Debug + Display + IntoErrorResponse + Send; - - /// Connect to the given address and port, over this kind of network. The - /// underlying stream should behave somewhat like a TCP stream ... which - /// may be exactly what you're using. However, in order to support tunnelling - /// scenarios (i.e., using another proxy, going through Tor or SSH, etc.) we - /// work generically over any stream-like object. - async fn connect>( - &mut self, - addr: A, - port: u16, - ) -> Result; - - /// Listen for connections on the given address and port, returning a generic - /// listener socket to use in the future. - async fn listen>( - &mut self, - addr: A, - port: u16, - ) -> Result, Self::Error>; - - /// Bind a socket for the purposes of doing some datagram communication. NOTE! - /// this is only for UDP-like communication, not for generic connecting or - /// listening! Maybe obvious from the types, but POSIX has overtrained many - /// of us. - /// - /// Recall when using these functions that datagram protocols allow for packet - /// loss and out-of-order delivery. So ... be warned. - async fn bind>( - &mut self, - addr: A, - port: u16, - ) -> Result, Self::Error>; -} - -/// This trait is a hack; sorry about that. The thing is, we want to be able to -/// convert Errors from the `Networklike` trait into a `ServerResponseStatus`, -/// but want to do so on references to the error object rather than the actual -/// object. This is for the paired reason that (a) we want to be able to use -/// the errors in multiple places -- for example, to return a value to the client -/// and then also to whoever called the function -- and (b) some common errors -/// (I'm looking at you, `io::Error`) aren't `Clone`. So ... hence this overly- -/// specific trait. -pub trait IntoErrorResponse { - #[allow(clippy::wrong_self_convention)] - fn into_response(&self) -> ServerResponseStatus; -} diff --git a/src/network/listener.rs b/src/network/listener.rs deleted file mode 100644 index 231e43a..0000000 --- a/src/network/listener.rs +++ /dev/null @@ -1,29 +0,0 @@ -use crate::network::address::{HasLocalAddress, SOCKSv5Address}; -use crate::network::stream::GenericStream; -use async_trait::async_trait; - -#[async_trait] -pub trait Listenerlike: Send + Sync + HasLocalAddress { - type Error; - - async fn accept(&self) -> Result<(GenericStream, SOCKSv5Address, u16), Self::Error>; -} - -pub struct GenericListener { - pub internal: Box>, -} - -#[async_trait] -impl Listenerlike for GenericListener { - type Error = E; - - async fn accept(&self) -> Result<(GenericStream, SOCKSv5Address, u16), Self::Error> { - Ok(self.internal.accept().await?) - } -} - -impl HasLocalAddress for GenericListener { - fn local_addr(&self) -> (SOCKSv5Address, u16) { - self.internal.local_addr() - } -} diff --git a/src/network/standard.rs b/src/network/standard.rs deleted file mode 100644 index e55f8e1..0000000 --- a/src/network/standard.rs +++ /dev/null @@ -1,240 +0,0 @@ -use crate::messages::ServerResponseStatus; -use crate::network::address::{HasLocalAddress, SOCKSv5Address}; -use crate::network::datagram::{Datagramlike, GenericDatagramSocket}; -use crate::network::generic::Networklike; -use crate::network::listener::{GenericListener, Listenerlike}; -use crate::network::stream::{GenericStream, Streamlike}; -use async_std::io; -#[cfg(test)] -use async_std::io::ReadExt; -use async_std::net::{TcpListener, TcpStream, UdpSocket}; -use async_trait::async_trait; -#[cfg(test)] -use futures::AsyncWriteExt; -use log::error; - -use super::generic::IntoErrorResponse; - -#[derive(Clone)] -pub struct Builtin {} - -impl Builtin { - pub fn new() -> Builtin { - Builtin {} - } -} - -impl Default for Builtin { - fn default() -> Self { - Self::new() - } -} - -macro_rules! local_address_impl { - ($t: ty) => { - impl HasLocalAddress for $t { - fn local_addr(&self) -> (SOCKSv5Address, u16) { - match self.local_addr() { - Ok(a) => - (SOCKSv5Address::from(a.ip()), a.port()), - Err(e) => { - error!("Couldn't translate (Streamlike) local address to SOCKS local address: {}", e); - (SOCKSv5Address::from("localhost"), 0) - } - } - } - } - }; -} - -local_address_impl!(TcpStream); -local_address_impl!(TcpListener); -local_address_impl!(UdpSocket); - -impl Streamlike for TcpStream {} - -#[async_trait] -impl Listenerlike for TcpListener { - type Error = io::Error; - - async fn accept(&self) -> Result<(GenericStream, SOCKSv5Address, u16), Self::Error> { - let (base, addrport) = self.accept().await?; - let addr = addrport.ip(); - let port = addrport.port(); - Ok((GenericStream::new(base), SOCKSv5Address::from(addr), port)) - } -} - -#[async_trait] -impl Datagramlike for UdpSocket { - type Error = io::Error; - - async fn send_to( - &self, - buf: &[u8], - addr: SOCKSv5Address, - port: u16, - ) -> Result { - match addr { - SOCKSv5Address::IP4(a) => self.send_to(buf, (a, port)).await, - SOCKSv5Address::IP6(a) => self.send_to(buf, (a, port)).await, - SOCKSv5Address::Name(n) => self.send_to(buf, (n.as_str(), port)).await, - } - } - - async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SOCKSv5Address, u16), Self::Error> { - let (amt, addrport) = self.recv_from(buf).await?; - let addr = addrport.ip(); - let port = addrport.port(); - Ok((amt, SOCKSv5Address::from(addr), port)) - } -} - -#[async_trait] -impl Networklike for Builtin { - type Error = io::Error; - - async fn connect>( - &mut self, - addr: A, - port: u16, - ) -> Result { - let target = addr.into(); - - let base_stream = match target { - SOCKSv5Address::IP4(a) => TcpStream::connect((a, port)).await?, - SOCKSv5Address::IP6(a) => TcpStream::connect((a, port)).await?, - SOCKSv5Address::Name(n) => TcpStream::connect((n.as_str(), port)).await?, - }; - - Ok(GenericStream::from(base_stream)) - } - - async fn listen>( - &mut self, - addr: A, - port: u16, - ) -> Result, Self::Error> { - let target = addr.into(); - - let base_stream = match target { - SOCKSv5Address::IP4(a) => TcpListener::bind((a, port)).await?, - SOCKSv5Address::IP6(a) => TcpListener::bind((a, port)).await?, - SOCKSv5Address::Name(n) => TcpListener::bind((n.as_str(), port)).await?, - }; - - Ok(GenericListener { - internal: Box::new(base_stream), - }) - } - - async fn bind>( - &mut self, - addr: A, - port: u16, - ) -> Result, Self::Error> { - let target = addr.into(); - - let base_socket = match target { - SOCKSv5Address::IP4(a) => UdpSocket::bind((a, port)).await?, - SOCKSv5Address::IP6(a) => UdpSocket::bind((a, port)).await?, - SOCKSv5Address::Name(n) => UdpSocket::bind((n.as_str(), port)).await?, - }; - - Ok(GenericDatagramSocket { - internal: Box::new(base_socket), - }) - } -} - -#[test] -fn check_sanity() { - async_std::task::block_on(async { - // Technically, this is UDP, and UDP is lossy. We're going to assume we're not - // going to get any dropped data along here ... which is a very questionable - // assumption, morally speaking, but probably fine for most purposes. - let mut network = Builtin::new(); - let receiver = network - .bind("localhost", 0) - .await - .expect("Failed to bind receiver socket."); - let sender = network - .bind("localhost", 0) - .await - .expect("Failed to bind sender socket."); - let buffer = [0xde, 0xea, 0xbe, 0xef]; - let (receiver_addr, receiver_port) = receiver.local_addr(); - sender - .send_to(&buffer, receiver_addr, receiver_port) - .await - .expect("Failure sending datagram!"); - let mut recvbuffer = [0; 4]; - let (s, f, p) = receiver - .recv_from(&mut recvbuffer) - .await - .expect("Didn't receive UDP message?"); - let (sender_addr, sender_port) = sender.local_addr(); - assert_eq!(s, 4); - assert_eq!(f, sender_addr); - assert_eq!(p, sender_port); - assert_eq!(recvbuffer, buffer); - }); - - // This whole block should be pretty solid, though, unless the system we're - // on is in a pretty weird place. - let mut network = Builtin::new(); - - let listener = async_std::task::block_on(network.listen("localhost", 0)) - .expect("Couldn't set up listener on localhost"); - let (listener_address, listener_port) = listener.local_addr(); - - let listener_task_handle = async_std::task::spawn(async move { - let (mut stream, addr, port) = listener.accept().await.expect("Didn't get connection"); - let mut result_buffer = [0u8; 4]; - stream - .read_exact(&mut result_buffer) - .await - .expect("Read failure in TCP test"); - (result_buffer, addr, port) - }); - - let sender_task_handle = async_std::task::spawn(async move { - let mut sender = network - .connect(listener_address, listener_port) - .await - .expect("Coudln't connect to listener?"); - let (sender_address, sender_port) = sender.local_addr(); - let send_buffer = [0xa, 0xff, 0xab, 0x1e]; - sender - .write_all(&send_buffer) - .await - .expect("Couldn't send the write buffer"); - sender - .flush() - .await - .expect("Couldn't flush the write buffer"); - sender - .close() - .await - .expect("Couldn't close the write buffer"); - (sender_address, sender_port) - }); - - async_std::task::block_on(async { - let (result, result_from, result_from_port) = listener_task_handle.await; - assert_eq!(result, [0xa, 0xff, 0xab, 0x1e]); - let (sender_address, sender_port) = sender_task_handle.await; - assert_eq!(result_from, sender_address); - assert_eq!(result_from_port, sender_port); - }); -} - -impl IntoErrorResponse for io::Error { - fn into_response(&self) -> ServerResponseStatus { - match self.kind() { - io::ErrorKind::ConnectionRefused => ServerResponseStatus::ConnectionRefused, - io::ErrorKind::NotFound => ServerResponseStatus::HostUnreachable, - _ => ServerResponseStatus::GeneralFailure, - } - } -} diff --git a/src/network/stream.rs b/src/network/stream.rs deleted file mode 100644 index 8bddf57..0000000 --- a/src/network/stream.rs +++ /dev/null @@ -1,74 +0,0 @@ -use crate::network::SOCKSv5Address; -use async_std::task::{Context, Poll}; -use futures::io; -use futures::io::{AsyncRead, AsyncWrite}; -use std::pin::Pin; -use std::sync::{Arc, Mutex}; - -use super::address::HasLocalAddress; - -pub trait Streamlike: AsyncRead + AsyncWrite + HasLocalAddress + Send + Sync + Unpin {} - -#[derive(Clone)] -pub struct GenericStream { - internal: Arc>, -} - -impl GenericStream { - pub fn new(x: T) -> GenericStream { - GenericStream { - internal: Arc::new(Mutex::new(x)), - } - } -} - -impl HasLocalAddress for GenericStream { - fn local_addr(&self) -> (SOCKSv5Address, u16) { - let item = self.internal.lock().unwrap(); - item.local_addr() - } -} - -impl AsyncRead for GenericStream { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - let mut item = self.internal.lock().unwrap(); - let pinned = Pin::new(&mut *item); - pinned.poll_read(cx, buf) - } -} - -impl AsyncWrite for GenericStream { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - let mut item = self.internal.lock().unwrap(); - let pinned = Pin::new(&mut *item); - pinned.poll_write(cx, buf) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut item = self.internal.lock().unwrap(); - let pinned = Pin::new(&mut *item); - pinned.poll_flush(cx) - } - - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut item = self.internal.lock().unwrap(); - let pinned = Pin::new(&mut *item); - pinned.poll_close(cx) - } -} - -impl From for GenericStream { - fn from(x: T) -> GenericStream { - GenericStream { - internal: Arc::new(Mutex::new(x)), - } - } -} diff --git a/src/network/testing.rs b/src/network/testing.rs deleted file mode 100644 index 44310fd..0000000 --- a/src/network/testing.rs +++ /dev/null @@ -1,276 +0,0 @@ -mod datagram; -mod stream; - -use crate::messages::ServerResponseStatus; -use crate::network::address::{HasLocalAddress, SOCKSv5Address}; -#[cfg(test)] -use crate::network::datagram::Datagramlike; -use crate::network::datagram::GenericDatagramSocket; -use crate::network::generic::{IntoErrorResponse, Networklike}; -use crate::network::listener::{GenericListener, Listenerlike}; -use crate::network::stream::GenericStream; -use crate::network::testing::datagram::TestDatagram; -use crate::network::testing::stream::TestingStream; -use async_std::channel::{bounded, Receiver, Sender}; -use async_std::sync::{Arc, Mutex}; -#[cfg(test)] -use async_std::task; -use async_trait::async_trait; -#[cfg(test)] -use futures::{AsyncReadExt, AsyncWriteExt}; -use std::collections::HashMap; -use std::fmt; - -/// A "network", based purely on internal Rust datatypes, for testing -/// networking code. This stack operates purely in memory, so shouldn't -/// suffer from any weird networking effects ... which makes it a good -/// functional test, but not great at actually testing real-world failure -/// modes. -#[allow(clippy::type_complexity)] -#[derive(Clone)] -pub struct TestingStack { - tcp_listeners: Arc>>>, - udp_sockets: Arc)>>>>, - next_random_socket: u16, -} - -impl TestingStack { - pub fn new() -> TestingStack { - TestingStack { - tcp_listeners: Arc::new(Mutex::new(HashMap::new())), - udp_sockets: Arc::new(Mutex::new(HashMap::new())), - next_random_socket: 23, - } - } -} - -impl Default for TestingStack { - fn default() -> Self { - Self::new() - } -} - -#[derive(Debug)] -pub enum TestStackError { - AcceptFailed, - AddressBusy(SOCKSv5Address, u16), - ConnectionFailed, - FailureToSend, - NoTCPHostFound(SOCKSv5Address, u16), - ReceiveFailure, -} - -impl fmt::Display for TestStackError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - TestStackError::AcceptFailed => write!(f, "Accept failed; the other side died (?)"), - TestStackError::AddressBusy(ref addr, port) => { - write!(f, "Address {}:{} already in use", addr, port) - } - TestStackError::ConnectionFailed => write!(f, "Couldn't connect to host."), - TestStackError::FailureToSend => write!( - f, - "Weird internal error in testing infrastructure; channel send failed" - ), - TestStackError::NoTCPHostFound(ref addr, port) => { - write!(f, "No host found at {} for TCP port {}", addr, port) - } - TestStackError::ReceiveFailure => { - write!(f, "Failed to process a UDP receive (this is weird)") - } - } - } -} - -impl IntoErrorResponse for TestStackError { - fn into_response(&self) -> ServerResponseStatus { - ServerResponseStatus::GeneralFailure - } -} - -#[async_trait] -impl Networklike for TestingStack { - type Error = TestStackError; - - async fn connect>( - &mut self, - addr: A, - port: u16, - ) -> Result { - let table = self.tcp_listeners.lock().await; - let target = addr.into(); - - match table.get(&(target.clone(), port)) { - None => Err(TestStackError::NoTCPHostFound(target, port)), - Some(result) => { - let stream = TestingStream::new(target, port); - let retval = stream.invert(); - match result.send(stream).await { - Ok(()) => Ok(GenericStream::new(retval)), - Err(_) => Err(TestStackError::FailureToSend), - } - } - } - } - - async fn listen>( - &mut self, - addr: A, - mut port: u16, - ) -> Result, Self::Error> { - let mut table = self.tcp_listeners.lock().await; - let target = addr.into(); - let (sender, receiver) = bounded(5); - - if port == 0 { - port = self.next_random_socket; - self.next_random_socket += 1; - } - - table.insert((target.clone(), port), sender); - Ok(GenericListener { - internal: Box::new(TestListener::new(target, port, receiver)), - }) - } - - async fn bind>( - &mut self, - addr: A, - mut port: u16, - ) -> Result, Self::Error> { - let mut table = self.udp_sockets.lock().await; - let target = addr.into(); - let (sender, receiver) = bounded(5); - - if port == 0 { - port = self.next_random_socket; - self.next_random_socket += 1; - } - - table.insert((target.clone(), port), sender); - Ok(GenericDatagramSocket { - internal: Box::new(TestDatagram::new(self.clone(), target, port, receiver)), - }) - } -} - -struct TestListener { - address: SOCKSv5Address, - port: u16, - receiver: Receiver, -} - -impl TestListener { - fn new(address: SOCKSv5Address, port: u16, receiver: Receiver) -> Self { - TestListener { - address, - port, - receiver, - } - } -} - -impl HasLocalAddress for TestListener { - fn local_addr(&self) -> (SOCKSv5Address, u16) { - (self.address.clone(), self.port) - } -} - -#[async_trait] -impl Listenerlike for TestListener { - type Error = TestStackError; - - async fn accept(&self) -> Result<(GenericStream, SOCKSv5Address, u16), Self::Error> { - match self.receiver.recv().await { - Ok(next) => { - let (addr, port) = next.local_addr(); - Ok((GenericStream::new(next), addr, port)) - } - Err(_) => Err(TestStackError::AcceptFailed), - } - } -} - -#[test] -fn check_udp_sanity() { - task::block_on(async { - let mut network = TestingStack::new(); - let receiver = network - .bind("localhost", 0) - .await - .expect("Failed to bind receiver socket."); - let sender = network - .bind("localhost", 0) - .await - .expect("Failed to bind sender socket."); - let buffer = [0xde, 0xea, 0xbe, 0xef]; - let (receiver_addr, receiver_port) = receiver.local_addr(); - sender - .send_to(&buffer, receiver_addr, receiver_port) - .await - .expect("Failure sending datagram!"); - let mut recvbuffer = [0; 4]; - let (s, f, p) = receiver - .recv_from(&mut recvbuffer) - .await - .expect("Didn't receive UDP message?"); - let (sender_addr, sender_port) = sender.local_addr(); - assert_eq!(s, 4); - assert_eq!(f, sender_addr); - assert_eq!(p, sender_port); - assert_eq!(recvbuffer, buffer); - }); -} - -#[test] -fn check_basic_tcp() { - task::block_on(async { - let mut network = TestingStack::new(); - - let listener = network - .listen("localhost", 0) - .await - .expect("Couldn't set up listener on localhost"); - let (listener_address, listener_port) = listener.local_addr(); - - let listener_task_handle = task::spawn(async move { - dbg!("Starting listener task!!"); - let (mut stream, addr, port) = listener.accept().await.expect("Didn't get connection"); - let mut result_buffer = [0u8; 4]; - if let Err(e) = stream.read_exact(&mut result_buffer).await { - dbg!("Error reading buffer from stream: {}", e); - } else { - dbg!("made it through read_exact"); - } - (result_buffer, addr, port) - }); - - let sender_task_handle = task::spawn(async move { - let mut sender = network - .connect(listener_address, listener_port) - .await - .expect("Coudln't connect to listener?"); - let (sender_address, sender_port) = sender.local_addr(); - let send_buffer = [0xa, 0xff, 0xab, 0x1e]; - sender - .write_all(&send_buffer) - .await - .expect("Couldn't send the write buffer"); - sender - .flush() - .await - .expect("Couldn't flush the write buffer"); - sender - .close() - .await - .expect("Couldn't close the write buffer"); - (sender_address, sender_port) - }); - - let (result, result_from, result_from_port) = listener_task_handle.await; - assert_eq!(result, [0xa, 0xff, 0xab, 0x1e]); - let (sender_address, sender_port) = sender_task_handle.await; - assert_eq!(result_from, sender_address); - assert_eq!(result_from_port, sender_port); - }); -} diff --git a/src/network/testing/datagram.rs b/src/network/testing/datagram.rs deleted file mode 100644 index e274797..0000000 --- a/src/network/testing/datagram.rs +++ /dev/null @@ -1,88 +0,0 @@ -use crate::network::address::HasLocalAddress; -use crate::network::datagram::Datagramlike; -use crate::network::testing::{TestStackError, TestingStack}; -use crate::network::SOCKSv5Address; -use async_std::channel::Receiver; -use async_trait::async_trait; -use std::cmp::Ordering; - -pub struct TestDatagram { - context: TestingStack, - my_address: SOCKSv5Address, - my_port: u16, - input_stream: Receiver<(SOCKSv5Address, u16, Vec)>, -} - -impl TestDatagram { - pub fn new( - context: TestingStack, - my_address: SOCKSv5Address, - my_port: u16, - input_stream: Receiver<(SOCKSv5Address, u16, Vec)>, - ) -> Self { - TestDatagram { - context, - my_address, - my_port, - input_stream, - } - } -} - -impl HasLocalAddress for TestDatagram { - fn local_addr(&self) -> (SOCKSv5Address, u16) { - (self.my_address.clone(), self.my_port) - } -} - -#[async_trait] -impl Datagramlike for TestDatagram { - type Error = TestStackError; - - async fn send_to( - &self, - buf: &[u8], - target: SOCKSv5Address, - port: u16, - ) -> Result { - let table = self.context.udp_sockets.lock().await; - match table.get(&(target, port)) { - None => Ok(buf.len()), - Some(sender) => { - sender - .send((self.my_address.clone(), self.my_port, buf.to_vec())) - .await - .map_err(|_| TestStackError::FailureToSend)?; - Ok(buf.len()) - } - } - } - - async fn recv_from( - &self, - buffer: &mut [u8], - ) -> Result<(usize, SOCKSv5Address, u16), Self::Error> { - let (from_addr, from_port, message) = self - .input_stream - .recv() - .await - .map_err(|_| TestStackError::ReceiveFailure)?; - - match message.len().cmp(&buffer.len()) { - Ordering::Greater => { - buffer.copy_from_slice(&message[..buffer.len()]); - Ok((message.len(), from_addr, from_port)) - } - - Ordering::Less => { - (&mut buffer[..message.len()]).copy_from_slice(&message); - Ok((message.len(), from_addr, from_port)) - } - - Ordering::Equal => { - buffer.copy_from_slice(message.as_ref()); - Ok((message.len(), from_addr, from_port)) - } - } - } -} diff --git a/src/network/testing/stream.rs b/src/network/testing/stream.rs deleted file mode 100644 index 2e66423..0000000 --- a/src/network/testing/stream.rs +++ /dev/null @@ -1,192 +0,0 @@ -use crate::network::address::HasLocalAddress; -use crate::network::stream::Streamlike; -use crate::network::SOCKSv5Address; -use async_std::io; -use async_std::io::{Read, Write}; -use async_std::task::{Context, Poll, Waker}; -use std::cell::UnsafeCell; -use std::pin::Pin; -use std::ptr::NonNull; -use std::sync::atomic::{AtomicBool, Ordering}; - -#[derive(Clone)] -pub struct TestingStream { - address: SOCKSv5Address, - port: u16, - read_side: NonNull, - write_side: NonNull, -} - -unsafe impl Send for TestingStream {} -unsafe impl Sync for TestingStream {} - -struct TestingStreamData { - lock: AtomicBool, - writer_dead: AtomicBool, - waiters: UnsafeCell>, - buffer: UnsafeCell>, -} - -unsafe impl Send for TestingStreamData {} -unsafe impl Sync for TestingStreamData {} - -impl TestingStream { - /// Generate a testing stream. Note that this is directional. So, if you want to - /// talk to this stream, you should also generate an `invert()` and pass that to - /// the other thread(s). - pub fn new(address: SOCKSv5Address, port: u16) -> TestingStream { - let read_side_data = TestingStreamData { - lock: AtomicBool::new(false), - writer_dead: AtomicBool::new(false), - waiters: UnsafeCell::new(Vec::new()), - buffer: UnsafeCell::new(Vec::with_capacity(16 * 1024)), - }; - - let write_side_data = TestingStreamData { - lock: AtomicBool::new(false), - writer_dead: AtomicBool::new(false), - waiters: UnsafeCell::new(Vec::new()), - buffer: UnsafeCell::new(Vec::with_capacity(16 * 1024)), - }; - - let boxed_rsd = Box::new(read_side_data); - let boxed_wsd = Box::new(write_side_data); - let raw_read_ptr = Box::leak(boxed_rsd); - let raw_write_ptr = Box::leak(boxed_wsd); - - TestingStream { - address, - port, - read_side: NonNull::new(raw_read_ptr).unwrap(), - write_side: NonNull::new(raw_write_ptr).unwrap(), - } - } - - /// Get the flip side of this stream; reads from the inverted side will catch the writes - /// of the original, etc. - pub fn invert(&self) -> TestingStream { - TestingStream { - address: self.address.clone(), - port: self.port, - read_side: self.write_side, - write_side: self.read_side, - } - } -} - -impl TestingStreamData { - fn acquire(&mut self) { - loop { - match self - .lock - .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst) - { - Err(_) => continue, - Ok(_) => return, - } - } - } - - fn release(&mut self) { - self.lock.store(false, Ordering::SeqCst); - } -} - -impl HasLocalAddress for TestingStream { - fn local_addr(&self) -> (SOCKSv5Address, u16) { - (self.address.clone(), self.port) - } -} - -impl Read for TestingStream { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - // so, we're going to spin here, which is less than ideal but should work fine - // in practice. we'll obviously need to be very careful to ensure that we keep - // the stuff internal to this spin really short. - let internals = unsafe { self.read_side.as_mut() }; - - internals.acquire(); - - let stream_buffer = internals.buffer.get_mut(); - let amount_available = stream_buffer.len(); - - if amount_available == 0 { - // we wait to do this check until we've determined the buffer is empty, - // so that we make sure to drain any residual stuff in there. - if internals.writer_dead.load(Ordering::SeqCst) { - internals.release(); - return Poll::Ready(Err(io::Error::new( - io::ErrorKind::ConnectionReset, - "Writer closed the socket.", - ))); - } else { - let waker = cx.waker().clone(); - internals.waiters.get_mut().push(waker); - internals.release(); - return Poll::Pending; - } - } - - let amt_written = if buf.len() >= amount_available { - (&mut buf[0..amount_available]).copy_from_slice(stream_buffer); - stream_buffer.clear(); - amount_available - } else { - let amt_to_copy = buf.len(); - buf.copy_from_slice(&stream_buffer[0..amt_to_copy]); - stream_buffer.copy_within(amt_to_copy.., 0); - let amt_left = amount_available - amt_to_copy; - stream_buffer.resize(amt_left, 0); - amt_to_copy - }; - - internals.release(); - - Poll::Ready(Ok(amt_written)) - } -} - -impl Write for TestingStream { - fn poll_write( - mut self: Pin<&mut Self>, - _cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - let internals = unsafe { self.write_side.as_mut() }; - internals.acquire(); - let stream_buffer = internals.buffer.get_mut(); - stream_buffer.extend_from_slice(buf); - for waiter in internals.waiters.get_mut().drain(0..) { - waiter.wake(); - } - internals.release(); - - Poll::Ready(Ok(buf.len())) - } - - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) // FIXME: Might consider having this wait until the buffer is empty - } - - fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) // FIXME: Might consider putting in some open/closed logic here - } -} - -impl Streamlike for TestingStream {} - -impl Drop for TestingStream { - fn drop(&mut self) { - let internals = unsafe { self.write_side.as_mut() }; - internals.writer_dead.store(true, Ordering::SeqCst); - internals.acquire(); - for waiter in internals.waiters.get_mut().drain(0..) { - waiter.wake(); - } - internals.release(); - } -} diff --git a/src/security_parameters.rs b/src/security_parameters.rs new file mode 100644 index 0000000..24c3365 --- /dev/null +++ b/src/security_parameters.rs @@ -0,0 +1,84 @@ +use std::net::SocketAddr; + +/// The security parameters that you can assign to the server, to make decisions +/// about the weirdos it accepts as users. It is recommended that you only use +/// wide open connections when you're 100% sure that the server will only be +/// accessible locally. +#[derive(Clone)] +pub struct SecurityParameters { + /// Allow completely unauthenticated connections. You should be very, very + /// careful about setting this to true, especially if you don't provide a + /// guard to ensure that you're getting connections from reasonable places. + pub allow_unauthenticated: bool, + /// An optional function that can serve as a firewall for new connections. + /// Return true if the connection should be allowed to continue, false if + /// it shouldn't. This check happens before any data is read from or written + /// to the connecting party. + pub allow_connection: Option bool>, + /// An optional function to check a user name (first argument) and password + /// (second argument). Return true if the username / password is good, false + /// if not. + pub check_password: Option bool>, + /// An optional function to transition the stream from an unencrypted one to + /// an encrypted on. The assumption is you're using something like `rustls` + /// to make this happen; the exact mechanism is outside the scope of this + /// particular crate. If the connection shouldn't be allowed for some reason + /// (a bad certificate or handshake, for example), then return None; otherwise, + /// return the new stream. + pub connect_tls: Option Option<()>>, +} + +impl SecurityParameters { + /// Generates a `SecurityParameters` object that's empty. It won't accept + /// anything, because it has no mechanisms it can use to actually authenticate + /// a user and yet won't allow unauthenticated connections. + pub fn new() -> SecurityParameters { + SecurityParameters { + allow_unauthenticated: false, + allow_connection: None, + check_password: None, + connect_tls: None, + } + } + + /// Generates a `SecurityParameters` object that does not, in any way, + /// restrict who can log in. It also will not induce any transition into + /// TLS. Use this at your own risk ... or, really, just don't use this, + /// ever, and certainly not in production. + pub fn unrestricted() -> SecurityParameters { + SecurityParameters { + allow_unauthenticated: true, + allow_connection: None, + check_password: None, + connect_tls: None, + } + } + + /// Use the provided function to check incoming connections before proceeding + /// with the rest of the handshake. + pub fn check_connections(mut self, checker: fn(&SocketAddr) -> bool) -> SecurityParameters { + self.allow_connection = Some(checker); + self + } + + /// Use the provided function to check usernames and passwords provided + /// to the server. + pub fn password_check(mut self, checker: fn(&str, &str) -> bool) -> SecurityParameters { + self.check_password = Some(checker); + self + } + + /// Use the provide function to validate a TLS connection, and transition it + /// to the new stream type. If the handshake fails, return `None` instead of + /// `Some`. (And maybe log it somewhere, you know.) + pub fn tls_converter(mut self, converter: fn() -> Option<()>) -> SecurityParameters { + self.connect_tls = Some(converter); + self + } +} + +impl Default for SecurityParameters { + fn default() -> Self { + Self::new() + } +} diff --git a/src/serialize.rs b/src/serialize.rs deleted file mode 100644 index 0ec7fec..0000000 --- a/src/serialize.rs +++ /dev/null @@ -1,59 +0,0 @@ -use crate::errors::{DeserializationError, SerializationError}; -use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; - -pub async fn read_string( - r: &mut R, -) -> Result { - let mut length_buffer = [0; 1]; - - if r.read(&mut length_buffer).await? == 0 { - return Err(DeserializationError::NotEnoughData); - } - - let target = length_buffer[0] as usize; - - if target == 0 { - return Err(DeserializationError::InvalidEmptyString); - } - - let mut bytestring = vec![0; target]; - read_amt(r, target, &mut bytestring).await?; - - Ok(String::from_utf8(bytestring)?) -} - -pub async fn write_string( - s: &str, - w: &mut W, -) -> Result<(), SerializationError> { - let bytestring = s.as_bytes(); - - if bytestring.is_empty() || bytestring.len() > 255 { - return Err(SerializationError::InvalidStringLength(s.to_string())); - } - - w.write_all(&[bytestring.len() as u8]).await?; - w.write_all(bytestring) - .await - .map_err(SerializationError::IOError) -} - -pub async fn read_amt( - r: &mut R, - amt: usize, - buffer: &mut [u8], -) -> Result<(), DeserializationError> { - let mut amt_read = 0; - - while amt_read < amt { - let chunk_amt = r.read(&mut buffer[amt_read..]).await?; - - if chunk_amt == 0 { - return Err(DeserializationError::NotEnoughData); - } - - amt_read += chunk_amt; - } - - Ok(()) -} diff --git a/src/server.rs b/src/server.rs index 0976a6f..ed7a521 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,157 +1,59 @@ -//! An implementation of a SOCKSv5 server, parameterizable by the security parameters -//! and network stack you want to use. You should implement the server by first -//! setting up the `SecurityParameters`, then initializing the server object, and -//! then running it, as follows: -//! -//! ``` -//! use async_socks5::network::Builtin; -//! use async_socks5::server::{SecurityParameters, SOCKSv5Server}; -//! use std::io; -//! -//! async { -//! let parameters = SecurityParameters::new() -//! .password_check(|u,p| { u == "adam" && p == "evil" }); -//! let network = Builtin::new(); -//! let server = SOCKSv5Server::new(network, parameters); -//! server.start("localhost", 9999).await; -//! // ... do other stuff ... -//! }; -//! -//! ``` -use crate::errors::{AuthenticationError, DeserializationError, SerializationError}; +use std::net::SocketAddr; + +use crate::address::SOCKSv5Address; use crate::messages::{ - AuthenticationMethod, ClientConnectionCommand, ClientConnectionRequest, ClientGreeting, - ClientUsernamePassword, ServerAuthResponse, ServerChoice, ServerResponse, ServerResponseStatus, + AuthenticationMethod, ClientConnectionCommand, ClientConnectionCommandReadError, + ClientConnectionRequest, ClientConnectionRequestReadError, ClientGreeting, + ClientGreetingReadError, ClientUsernamePassword, ClientUsernamePasswordReadError, + ServerAuthResponse, ServerAuthResponseWriteError, ServerChoice, ServerChoiceWriteError, + ServerResponse, ServerResponseStatus, ServerResponseWriteError, }; -use crate::network::address::HasLocalAddress; -use crate::network::generic::Networklike; -use crate::network::listener::{GenericListener, Listenerlike}; -use crate::network::stream::GenericStream; -use crate::network::SOCKSv5Address; -use async_std::io; -use async_std::io::prelude::WriteExt; -use async_std::sync::{Arc, Mutex}; -use async_std::task; -use futures::Stream; -use log::{error, info, trace, warn}; -use std::collections::HashMap; -use std::default::Default; -use std::fmt::{Debug, Display}; +use crate::security_parameters::SecurityParameters; use thiserror::Error; +use tokio::io::{copy_bidirectional, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpSocket, TcpStream, UdpSocket}; -/// A convenient bit of shorthand for an address and port -pub type AddressAndPort = (SOCKSv5Address, u16); - -// Just some shorthand for us. -type ResultHandle = task::JoinHandle>; - -/// A handle representing a SOCKSv5 server, parameterized by the underlying network -/// stack it runs over. #[derive(Clone)] -pub struct SOCKSv5Server { - network: Arc>, - running_servers: Arc>>, +pub struct SOCKSv5Server { security_parameters: SecurityParameters, } -/// The security parameters that you can assign to the server, to make decisions -/// about the weirdos it accepts as users. It is recommended that you only use -/// wide open connections when you're 100% sure that the server will only be -/// accessible locally. -#[derive(Clone)] -pub struct SecurityParameters { - /// Allow completely unauthenticated connections. You should be very, very - /// careful about setting this to true, especially if you don't provide a - /// guard to ensure that you're getting connections from reasonable places. - pub allow_unauthenticated: bool, - /// An optional function that can serve as a firewall for new connections. - /// Return true if the connection should be allowed to continue, false if - /// it shouldn't. This check happens before any data is read from or written - /// to the connecting party. - pub allow_connection: Option bool>, - /// An optional function to check a user name (first argument) and password - /// (second argument). Return true if the username / password is good, false - /// if not. - pub check_password: Option bool>, - /// An optional function to transition the stream from an unencrypted one to - /// an encrypted on. The assumption is you're using something like `rustls` - /// to make this happen; the exact mechanism is outside the scope of this - /// particular crate. If the connection shouldn't be allowed for some reason - /// (a bad certificate or handshake, for example), then return None; otherwise, - /// return the new stream. - pub connect_tls: Option Option>, +#[derive(Clone, Debug, Error, PartialEq)] +pub enum SOCKSv5ServerError { + #[error("Underlying networking error: {0}")] + NetworkingError(String), + #[error("Couldn't negotiate authentication with client.")] + ItsNotUsItsYou, + #[error("Client greeting read problem: {0}")] + GreetingReadProblem(#[from] ClientGreetingReadError), + #[error("Server choice write problem: {0}")] + ChoiceWriteProblem(#[from] ServerChoiceWriteError), + #[error("Failed username/password authentication for user {0}")] + FailedUsernamePassword(String), + #[error("Server authentication response problem: {0}")] + ServerAuthWriteProblem(#[from] ServerAuthResponseWriteError), + #[error("Error reading client username/password: {0}")] + UserPassReadProblem(#[from] ClientUsernamePasswordReadError), + #[error("Error reading client connection command: {0}")] + ClientConnReadProblem(#[from] ClientConnectionCommandReadError), + #[error("Error reading client connection request: {0}")] + ClientRequestReadProblem(#[from] ClientConnectionRequestReadError), + #[error("Error writing server response: {0}")] + ServerResponseWriteProblem(#[from] ServerResponseWriteError), } -impl SecurityParameters { - /// Generates a `SecurityParameters` object that's empty. It won't accept - /// anything, because it has no mechanisms it can use to actually authenticate - /// a user and yet won't allow unauthenticated connections. - pub fn new() -> SecurityParameters { - SecurityParameters { - allow_unauthenticated: false, - allow_connection: None, - check_password: None, - connect_tls: None, - } - } - - /// Generates a `SecurityParameters` object that does not, in any way, - /// restrict who can log in. It also will not induce any transition into - /// TLS. Use this at your own risk ... or, really, just don't use this, - /// ever, and certainly not in production. - pub fn unrestricted() -> SecurityParameters { - SecurityParameters { - allow_unauthenticated: true, - allow_connection: None, - check_password: None, - connect_tls: None, - } - } - - /// Use the provided function to check incoming connections before proceeding - /// with the rest of the handshake. - pub fn check_connections( - mut self, - checker: fn(&SOCKSv5Address, u16) -> bool, - ) -> SecurityParameters { - self.allow_connection = Some(checker); - self - } - - /// Use the provided function to check usernames and passwords provided - /// to the server. - pub fn password_check(mut self, checker: fn(&str, &str) -> bool) -> SecurityParameters { - self.check_password = Some(checker); - self - } - - /// Use the provide function to validate a TLS connection, and transition it - /// to the new stream type. If the handshake fails, return `None` instead of - /// `Some`. (And maybe log it somewhere, you know.) - pub fn tls_converter( - mut self, - converter: fn(GenericStream) -> Option, - ) -> SecurityParameters { - self.connect_tls = Some(converter); - self +impl From for SOCKSv5ServerError { + fn from(x: std::io::Error) -> SOCKSv5ServerError { + SOCKSv5ServerError::NetworkingError(format!("{}", x)) } } -impl Default for SecurityParameters { - fn default() -> Self { - Self::new() - } -} - -impl SOCKSv5Server { - /// Initialize a SOCKSv5 server for use later on. Once initialize, you can listen on - /// as many addresses and ports as you like; the metadata about the server will be - /// sync'd across all of the instances, should you want to gather that data for some - /// reason. - pub fn new(network: N, security_parameters: SecurityParameters) -> SOCKSv5Server { +impl SOCKSv5Server { + /// Initialize a SOCKSv5 server for use later on. Once initialized, you can listen + /// on as many addresses and ports as you like; the metadata about the server will + /// be synced across all the instances. + pub fn new(security_parameters: SecurityParameters) -> Self { SOCKSv5Server { - network: Arc::new(Mutex::new(network)), - running_servers: Arc::new(Mutex::new(HashMap::new())), security_parameters, } } @@ -159,191 +61,241 @@ impl SOCKSv5Server { /// Start a server on the given address and port. This function returns when it has /// set up its listening socket, but spawns a separate task to actually wait for /// connections. You can query which ones are still active, or see which ones have - /// failed, using some of the other items in this structure. + /// failed, using some of the other functions for this structure. + /// + /// If you don't care what port is assigned to this server, pass 0 in as the port + /// number and one will be chosen for you by the OS. + /// pub async fn start>( &self, addr: A, port: u16, - ) -> Result<(), N::Error> { - // This might seem a little weird, but we do this in a separate block to make it - // as clear as possible to the borrow checker (and the reader) that we only want - // to hold the lock while we're actually calling listen. - let listener = { - let mut network = self.network.lock().await; - network.listen(addr, port).await - }?; + ) -> Result<(), std::io::Error> { + let listener = match addr.into() { + SOCKSv5Address::IP4(x) => TcpListener::bind((x, port)).await?, + SOCKSv5Address::IP6(x) => TcpListener::bind((x, port)).await?, + SOCKSv5Address::Hostname(x) => TcpListener::bind((x, port)).await?, + }; - // this should really be the same as the input, but technically they could've - // thrown some zeros in there and let the underlying network stack decide. So - // we'll just pull this information post-initialization, and maybe get something - // a bit more detailed. - let (my_addr, my_port) = listener.local_addr(); - info!("Starting SOCKSv5 server on {}:{}", my_addr, my_port); + let sockaddr = listener.local_addr()?; + tracing::info!( + "Starting SOCKSv5 server on {}:{}", + sockaddr.ip(), + sockaddr.port() + ); - // OK, spawn off the server loop, and then we'll register this in our list of - // things running. - let new_self = self.clone(); - let task_id = task::spawn(async move { - new_self - .server_loop(listener) - .await - .map_err(|x| format!("Server network error: {}", x)) + let second_life = self.clone(); + + tokio::task::spawn(async move { + if let Err(e) = second_life.server_loop(listener).await { + tracing::error!( + "{}:{}: server network error: {}", + sockaddr.ip(), + sockaddr.port(), + e + ); + } }); - let mut server_map = self.running_servers.lock().await; - server_map.insert((my_addr, my_port), task_id); - Ok(()) } - /// Provide a list of open sockets on the server. - pub async fn open_sockets(&self) -> Vec { - let server_map = self.running_servers.lock().await; - server_map.keys().cloned().collect() - } - - pub fn subserver_results(&mut self) -> impl Stream> { - futures::stream::unfold(self.running_servers.clone(), |locked_map| async move { - let first_server = { - let mut server_map = locked_map.lock().await; - let first_key = server_map.keys().next().cloned()?; - - server_map.remove(&first_key) - }?; - - let result = first_server.await; - Some((result, locked_map)) - }) - } - - async fn server_loop(self, listener: GenericListener) -> Result<(), N::Error> { + /// Run the server loop for a particular listener. This routine will never actually + /// return except in error conditions. + async fn server_loop(self, listener: TcpListener) -> Result<(), std::io::Error> { loop { - let (stream, their_addr, their_port) = listener.accept().await?; - trace!( - "Initial accept of connection from {}:{}", - their_addr, - their_port - ); + let (socket, their_addr) = listener.accept().await?; - // before we do anything, make sure this connection is cool. we don't want to - // waste resources (or parse any data) if this isn't someone we actually care - // about it. - if let Some(checker) = &self.security_parameters.allow_connection { - if !checker(&their_addr, their_port) { - info!( - "Rejecting attempted connection from {}:{}", - their_addr, their_port - ); - continue; + // before we do anything of note, make sure this connection is cool. we don't want + // to waste any resources (and certainly don't want to handle any data!) if this + // isn't someone we want to accept connections from. + tracing::trace!("Initial accept of connection from {}", their_addr); + if let Some(checker) = self.security_parameters.allow_connection { + if !checker(&their_addr) { + tracing::info!("Rejecting attempted connection from {}", their_addr,); } + continue; } - // throw this off into another task to take from here. We could to the rest - // of this handshake here, but there's a chance that an adversarial connection - // could just stall us out, and keep us from doing the next connection. So ... - // we'll potentially spin off the task early. + // continue this work in another task. we could absolutely do this work here, + // but just in case someone starts doing slow responses (or other nasty things), + // we want to make sure that that doesn't slow down our ability to accept other + // requests. let me_again = self.clone(); - task::spawn(async move { - me_again - .authenticate_step(their_addr, their_port, stream) - .await; + tokio::task::spawn(async move { + if let Err(e) = me_again.start_authentication(their_addr, socket).await { + tracing::error!("{}: server handler failure: {}", their_addr, e); + } }); } } - async fn authenticate_step( + /// Start the authentication phase of the SOCKS handshake. This may be very short, and + /// is the first stage of handling a request. This will only really return on errors. + async fn start_authentication( self, - their_addr: SOCKSv5Address, - their_port: u16, - base_stream: GenericStream, - ) { - // Turn this stream into one where we've authenticated the other side. Or, you - // know, don't, and just restart this loop. - let mut authenticated_stream = - match run_authentication(&self.security_parameters, base_stream).await { - Ok(authed_stream) => authed_stream, - Err(e) => { - warn!( - "Failure running authentication from {}:{}: {}", - their_addr, their_port, e - ); - return; - } - }; + their_addr: SocketAddr, + mut socket: TcpStream, + ) -> Result<(), SOCKSv5ServerError> { + let greeting = ClientGreeting::read(&mut socket).await?; - // Figure out what the client actually wants from this connection, and - // then dispatch a task to deal with that. - let mccr = ClientConnectionRequest::read(&mut authenticated_stream).await; - match mccr { - Err(e) => warn!("Failure figuring out what the client wanted: {}", e), - Ok(ccr) => match ccr.command_code { - ClientConnectionCommand::AssociateUDPPort => self - .handle_udp_request(authenticated_stream, ccr, their_addr, their_port) - .await - .unwrap_or_else(|e| warn!("Internal server error in UDP association: {}", e)), - ClientConnectionCommand::EstablishTCPPortBinding => self - .handle_tcp_bind(authenticated_stream, ccr, their_addr, their_port) - .await - .unwrap_or_else(|e| warn!("Internal server error in TCP bind: {}", e)), - ClientConnectionCommand::EstablishTCPStream => self - .handle_tcp_forward(authenticated_stream, ccr, their_addr, their_port) - .await - .unwrap_or_else(|e| warn!("Internal server error in TCP forward: {}", e)), - }, + match choose_authentication_method(&self.security_parameters, &greeting.acceptable_methods) + { + // it's not us, it's you. (we're just going to say no.) + None => { + tracing::trace!( + "{}: Failed to find acceptable authentication method.", + their_addr, + ); + let rejection_letter = ServerChoice::rejection(); + + rejection_letter.write(&mut socket).await?; + socket.flush().await?; + + Err(SOCKSv5ServerError::ItsNotUsItsYou) + } + + // the gold standard. great choice. + Some(ChosenMethod::TLS(_converter)) => { + unimplemented!() + } + + // well, I guess this is something? + Some(ChosenMethod::Password(checker)) => { + tracing::trace!( + "{}: Choosing username/password for authentication.", + their_addr, + ); + let ok_lets_do_password = + ServerChoice::option(AuthenticationMethod::UsernameAndPassword); + ok_lets_do_password.write(&mut socket).await?; + socket.flush().await?; + + let their_info = ClientUsernamePassword::read(&mut socket).await?; + if checker(&their_info.username, &their_info.password) { + let its_all_good = ServerAuthResponse::success(); + its_all_good.write(&mut socket).await?; + socket.flush().await?; + self.choose_mode(socket, their_addr).await + } else { + let yeah_no = ServerAuthResponse::failure(); + yeah_no.write(&mut socket).await?; + socket.flush().await?; + Err(SOCKSv5ServerError::FailedUsernamePassword( + their_info.username, + )) + } + } + + // Um. I guess we're doing this unchecked. Yay? + Some(ChosenMethod::None) => { + tracing::trace!( + "{}: Just skipping the whole authentication thing.", + their_addr, + ); + let nothin_i_guess = ServerChoice::option(AuthenticationMethod::None); + nothin_i_guess.write(&mut socket).await?; + socket.flush().await?; + self.choose_mode(socket, their_addr).await + } } } - async fn handle_udp_request( + /// Determine which of the modes we might want this particular connection to run + /// in. + async fn choose_mode( self, - stream: GenericStream, - ccr: ClientConnectionRequest, - their_addr: SOCKSv5Address, - their_port: u16, - ) -> Result<(), ServerError> { - // Let the user know that we're maybe making progress - let (my_addr, my_port) = stream.local_addr(); - info!( - "[{}:{}] Handling UDP bind request from {}:{}, seeking to bind {}:{}", - my_addr, my_port, their_addr, their_port, ccr.destination_address, ccr.destination_port - ); - - unimplemented!() + mut socket: TcpStream, + their_addr: SocketAddr, + ) -> Result<(), SOCKSv5ServerError> { + let ccr = ClientConnectionRequest::read(&mut socket).await?; + match ccr.command_code { + ClientConnectionCommand::AssociateUDPPort => { + self.handle_udp_request(socket, their_addr, ccr).await? + } + ClientConnectionCommand::EstablishTCPStream => { + self.handle_tcp_request(socket, their_addr, ccr).await? + } + ClientConnectionCommand::EstablishTCPPortBinding => { + self.handle_tcp_binding_request(socket, their_addr, ccr) + .await? + } + } + Ok(()) } - async fn handle_tcp_forward( + /// Handle UDP forwarding requests + #[allow(unreachable_code)] + async fn handle_udp_request( self, - mut stream: GenericStream, + stream: TcpStream, + their_addr: SocketAddr, ccr: ClientConnectionRequest, - their_addr: SOCKSv5Address, - their_port: u16, - ) -> Result<(), ServerError> { + ) -> Result<(), SOCKSv5ServerError> { + let my_addr = stream.local_addr()?; + tracing::info!( + "[{}:{}] Handling UDP bind request from {}:{}, seeking to bind towards {}:{}", + my_addr.ip(), + my_addr.port(), + their_addr.ip(), + their_addr.port(), + ccr.destination_address, + ccr.destination_port + ); + + let _socket = match ccr.destination_address.clone() { + SOCKSv5Address::IP4(x) => UdpSocket::bind((x, ccr.destination_port)).await?, + SOCKSv5Address::IP6(x) => UdpSocket::bind((x, ccr.destination_port)).await?, + SOCKSv5Address::Hostname(x) => UdpSocket::bind((x, ccr.destination_port)).await?, + }; + + // OK, it worked. In order to mitigate an infinitesimal chance of a race condition, we're + // going to set up our forwarding tasks first, and then return the result to the user. (Note, + // we'd have to be slightly more precious in order to ensure a lack of race conditions, as + // the runtime could take forever to actually start these tasks, but I'm not ready to be + // bothered by this, yet. FIXME.) + unimplemented!(); + + // Cool; now we can get the result out to the user. + let bound_address = _socket.local_addr()?; + let response = ServerResponse { + status: ServerResponseStatus::RequestGranted, + bound_address: bound_address.ip().into(), + bound_port: bound_address.port(), + }; + + response.write(&mut stream).await?; + Ok(()) + } + + /// Handle TCP forwarding requests + async fn handle_tcp_request( + self, + mut stream: TcpStream, + their_addr: SocketAddr, + ccr: ClientConnectionRequest, + ) -> Result<(), SOCKSv5ServerError> { // Let the user know that we're maybe making progress - let (my_addr, my_port) = stream.local_addr(); - info!( - "[{}:{}] Handling TCP forward request from {}:{}, seeking to connect to {}:{}", - my_addr, my_port, their_addr, their_port, ccr.destination_address, ccr.destination_port + let my_addr = stream.local_addr()?; + tracing::info!( + "[{}] Handling TCP forward request from {}, seeking to connect to {}:{}", + my_addr, + their_addr, + ccr.destination_address, + ccr.destination_port ); // OK, first thing's first: We need to actually connect to the server that the user // wants us to connect to. - let connection_res = { - let mut network = self.network.lock().await; - network - .connect(ccr.destination_address.clone(), ccr.destination_port) - .await - }; - - let outgoing_stream = match connection_res { - Ok(x) => x, - Err(e) => { - error!("Failed to connect to {}: {}", ccr.destination_address, e); - let response = ServerResponse::error(&e); - response.write(&mut stream).await?; - return Err(ServerError::NetworkError(e)); + let outgoing_stream = match &ccr.destination_address { + SOCKSv5Address::IP4(x) => TcpStream::connect((*x, ccr.destination_port)).await?, + SOCKSv5Address::IP6(x) => TcpStream::connect((*x, ccr.destination_port)).await?, + SOCKSv5Address::Hostname(x) => { + TcpStream::connect((x.as_ref(), ccr.destination_port)).await? } }; - trace!( + tracing::trace!( "Connection established to {}:{}", ccr.destination_address, ccr.destination_port @@ -352,117 +304,117 @@ impl SOCKSv5Server { // Now, for whatever reason -- and this whole thing sent me down a garden path // in understanding how this whole protocol works -- we tell the user what address // and port we bound for that connection. - let (bound_address, bound_port) = outgoing_stream.local_addr(); + let bound_address = outgoing_stream.local_addr()?; let response = ServerResponse { status: ServerResponseStatus::RequestGranted, - bound_address, - bound_port, + bound_address: bound_address.ip().into(), + bound_port: bound_address.port(), }; response.write(&mut stream).await?; // so now tie our streams together, and we're good to go - tie_streams( - format!("{}:{}", their_addr, their_port), - stream, - format!("{}:{}", ccr.destination_address, ccr.destination_port), - outgoing_stream, - ) - .await; + tie_streams(stream, outgoing_stream).await; + Ok(()) } - async fn handle_tcp_bind( + /// Handle TCP binding requests + async fn handle_tcp_binding_request( self, - mut stream: GenericStream, + mut stream: TcpStream, + their_addr: SocketAddr, ccr: ClientConnectionRequest, - their_addr: SOCKSv5Address, - their_port: u16, - ) -> Result<(), ServerError> { + ) -> Result<(), SOCKSv5ServerError> { // Let the user know that we're maybe making progress - let (my_addr, my_port) = stream.local_addr(); - info!( - "[{}:{}] Handling TCP bind request from {}:{}, seeking to bind {}:{}", - my_addr, my_port, their_addr, their_port, ccr.destination_address, ccr.destination_port + let my_addr = stream.local_addr()?; + tracing::info!( + "[{}] Handling TCP bind request from {}, seeking to bind {}:{}", + my_addr, + their_addr, + ccr.destination_address, + ccr.destination_port ); // OK, we have to bind the darn socket first. - let port_binding = { - let mut network = self.network.lock().await; - network.listen(their_addr.clone(), their_port).await - } - .map_err(ServerError::NetworkError)?; + let listener_port = match &their_addr { + SocketAddr::V4(_) => TcpSocket::new_v4(), + SocketAddr::V6(_) => TcpSocket::new_v6(), + }?; + // FIXME: Might want to bind on a particular interface, based on a + // config flag, at some point. + let listener = listener_port.listen(1)?; // Tell them what we bound, just in case they want to inform anyone. - let (bound_address, bound_port) = port_binding.local_addr(); + let bound_address = listener.local_addr()?; let response = ServerResponse { status: ServerResponseStatus::RequestGranted, - bound_address, - bound_port, + bound_address: bound_address.ip().into(), + bound_port: bound_address.port(), }; response.write(&mut stream).await?; // Wait politely for someone to talk to us. - let (other, other_addr, other_port) = port_binding - .accept() - .await - .map_err(ServerError::NetworkError)?; + let (other, other_addr) = listener.accept().await?; let info = ServerResponse { status: ServerResponseStatus::RequestGranted, - bound_address: other_addr.clone(), - bound_port: other_port, + bound_address: other_addr.ip().into(), + bound_port: other_addr.port(), }; info.write(&mut stream).await?; - tie_streams( - format!("{}:{}", their_addr, their_port), - stream, - format!("{}:{}", other_addr, other_port), - other, - ) - .await; + tie_streams(stream, other).await; + Ok(()) } } -async fn tie_streams( - left_name: String, - left: GenericStream, - right_name: String, - right: GenericStream, -) { - // Now that we've informed them of that, we set up one task to transfer information - // from the current stream (`stream`) to the connection (`outgoing_stream`), and - // another task that goes in the reverse direction. - // - // I've chosen to start two fresh tasks and let this one die; I'm not sure that - // this is the right approach. My only rationale is that this might let some - // memory we might have accumulated along the way drop more easily, but that - // might not actually matter. - let mut from_left = left.clone(); - let mut from_right = right.clone(); - let mut to_left = left; - let mut to_right = right; - let left_right_name = format!("{} >--> {}", left_name, right_name); - let right_left_name = format!("{} <--< {}", left_name, right_name); +async fn tie_streams(mut left: TcpStream, mut right: TcpStream) { + let left_local_addr = left + .local_addr() + .expect("couldn't get left local address in tie_streams"); + let left_peer_addr = left + .peer_addr() + .expect("couldn't get left peer address in tie_streams"); + let right_local_addr = right + .local_addr() + .expect("couldn't get right local address in tie_streams"); + let right_peer_addr = right + .peer_addr() + .expect("couldn't get right peer address in tie_streams"); - task::spawn(async move { - info!("Spawned {} task", left_right_name); - if let Err(e) = io::copy(&mut from_left, &mut to_right).await { - warn!("{} connection failed with: {}", left_right_name, e); - } - }); - - task::spawn(async move { - info!("Spawned {} task", right_left_name); - if let Err(e) = io::copy(&mut from_right, &mut to_left).await { - warn!("{} connection failed with: {}", right_left_name, e); + tokio::task::spawn(async move { + tracing::info!( + "Setting up linkage {}/{} <-> {}/{}", + left_peer_addr, + left_local_addr, + right_local_addr, + right_peer_addr + ); + match copy_bidirectional(&mut left, &mut right).await { + Ok((l2r, r2l)) => tracing::info!( + "Shutting down linkage {}/{} <-> {}/{} (sent {} and {} bytes, respectively)", + left_peer_addr, + left_local_addr, + right_local_addr, + right_peer_addr, + l2r, + r2l + ), + Err(e) => tracing::warn!( + "Shutting down linkage {}/{} <-> {}/{} with error: {}", + left_peer_addr, + left_local_addr, + right_local_addr, + right_peer_addr, + e + ), } }); } #[allow(clippy::upper_case_acronyms)] enum ChosenMethod { - TLS(fn(GenericStream) -> Option), + TLS(fn() -> Option<()>), Password(fn(&str, &str) -> bool), None, } @@ -567,7 +519,7 @@ fn reasonable_auth_method_choices() { ); // OK, cool. If we have a TLS handler, that shouldn't actually make a difference. - params.connect_tls = Some(|_| unimplemented!()); + params.connect_tls = Some(|| unimplemented!()); assert_eq!( choose_authentication_method(¶ms, &client_suggestions).map(AuthenticationMethod::from), None @@ -580,7 +532,7 @@ fn reasonable_auth_method_choices() { None ); // but if we have a handler, and they go for it, we use it. - params.connect_tls = Some(|_| unimplemented!()); + params.connect_tls = Some(|| unimplemented!()); assert_eq!( choose_authentication_method(¶ms, &client_suggestions).map(AuthenticationMethod::from), Some(AuthenticationMethod::SSL) @@ -600,75 +552,3 @@ fn reasonable_auth_method_choices() { Some(AuthenticationMethod::SSL) ); } - -async fn run_authentication( - params: &SecurityParameters, - mut stream: GenericStream, -) -> Result { - let greeting = ClientGreeting::read(&mut stream).await?; - - match choose_authentication_method(params, &greeting.acceptable_methods) { - // it's not us, it's you - None => { - trace!("Failed to find acceptable authentication method."); - let rejection_letter = ServerChoice::rejection(); - - rejection_letter.write(&mut stream).await?; - stream.flush().await?; - - Err(AuthenticationError::ItsNotUsItsYou) - } - - // the gold standard. great choice. - Some(ChosenMethod::TLS(converter)) => { - trace!("Choosing TLS for authentication."); - let lets_do_this = ServerChoice::option(AuthenticationMethod::SSL); - lets_do_this.write(&mut stream).await?; - stream.flush().await?; - - converter(stream).ok_or(AuthenticationError::FailedTLSHandshake) - } - - // well, I guess this is something? - Some(ChosenMethod::Password(checker)) => { - trace!("Choosing Username/Password for authentication."); - let ok_lets_do_password = - ServerChoice::option(AuthenticationMethod::UsernameAndPassword); - ok_lets_do_password.write(&mut stream).await?; - stream.flush().await?; - - let their_info = ClientUsernamePassword::read(&mut stream).await?; - if checker(&their_info.username, &their_info.password) { - let its_all_good = ServerAuthResponse::success(); - its_all_good.write(&mut stream).await?; - stream.flush().await?; - Ok(stream) - } else { - let yeah_no = ServerAuthResponse::failure(); - yeah_no.write(&mut stream).await?; - stream.flush().await?; - Err(AuthenticationError::FailedUsernamePassword( - their_info.username, - )) - } - } - - Some(ChosenMethod::None) => { - trace!("Just skipping the whole authentication thing."); - let nothin_i_guess = ServerChoice::option(AuthenticationMethod::None); - nothin_i_guess.write(&mut stream).await?; - stream.flush().await?; - Ok(stream) - } - } -} - -#[derive(Error, Debug)] -pub enum ServerError { - #[error("Error in deserialization: {0}")] - DeserializationError(#[from] DeserializationError), - #[error("Error in serialization: {0}")] - SerializationError(#[from] SerializationError), - #[error("Underlying network error: {0}")] - NetworkError(E), -}