From d1143a414c575a9c7fbde6ad4b33e7160f5b7221 Mon Sep 17 00:00:00 2001 From: Adam Wick Date: Sun, 27 Jun 2021 16:53:57 -0700 Subject: [PATCH] Split out the messages into individual files,, and add negative tests, so we can aspire towards good coverage. --- .gitignore | 1 + src/errors.rs | 115 +++++ src/messages.rs | 620 +---------------------- src/messages/authentication_method.rs | 156 ++++++ src/messages/client_command.rs | 164 ++++++ src/messages/client_greeting.rs | 135 +++++ src/messages/client_username_password.rs | 86 ++++ src/messages/server_auth_response.rs | 83 +++ src/messages/server_choice.rs | 86 ++++ src/messages/server_response.rs | 202 ++++++++ src/messages/utils.rs | 33 ++ src/network/address.rs | 11 +- src/server.rs | 2 +- 13 files changed, 1087 insertions(+), 607 deletions(-) create mode 100644 src/messages/authentication_method.rs create mode 100644 src/messages/client_command.rs create mode 100644 src/messages/client_greeting.rs create mode 100644 src/messages/client_username_password.rs create mode 100644 src/messages/server_auth_response.rs create mode 100644 src/messages/server_choice.rs create mode 100644 src/messages/server_response.rs create mode 100644 src/messages/utils.rs diff --git a/.gitignore b/.gitignore index 96ef6c0..53715f4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /target Cargo.lock +tarpaulin-report.html \ No newline at end of file diff --git a/src/errors.rs b/src/errors.rs index aa1a73d..974a45e 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -26,6 +26,69 @@ pub enum DeserializationError { InvalidServerResponse(u8), } +#[test] +fn des_error_reasonable_equals() { + let invalid_version = DeserializationError::InvalidVersion(1, 2); + assert_eq!(invalid_version, invalid_version); + let not_enough = DeserializationError::NotEnoughData; + assert_eq!(not_enough, not_enough); + let invalid_empty = DeserializationError::InvalidEmptyString; + assert_eq!(invalid_empty, invalid_empty); + let auth_method = DeserializationError::AuthenticationMethodError( + AuthenticationDeserializationError::NoDataFound, + ); + assert_eq!(auth_method, auth_method); + let utf8 = DeserializationError::UTF8Error(String::from_utf8(vec![0, 159]).unwrap_err()); + assert_eq!(utf8, utf8); + let invalid_address = DeserializationError::InvalidAddressType(3); + assert_eq!(invalid_address, invalid_address); + let invalid_client_cmd = DeserializationError::InvalidClientCommand(32); + assert_eq!(invalid_client_cmd, invalid_client_cmd); + let invalid_server_resp = DeserializationError::InvalidServerResponse(42); + assert_eq!(invalid_server_resp, invalid_server_resp); + + assert_ne!(invalid_version, invalid_address); + assert_ne!(not_enough, invalid_empty); + assert_ne!(auth_method, invalid_client_cmd); + assert_ne!(utf8, invalid_server_resp); +} + +impl PartialEq for DeserializationError { + fn eq(&self, other: &DeserializationError) -> bool { + match (self, other) { + ( + &DeserializationError::InvalidVersion(a, b), + &DeserializationError::InvalidVersion(x, y), + ) => (a == x) && (b == y), + (&DeserializationError::NotEnoughData, &DeserializationError::NotEnoughData) => true, + ( + &DeserializationError::InvalidEmptyString, + &DeserializationError::InvalidEmptyString, + ) => true, + ( + &DeserializationError::AuthenticationMethodError(ref a), + &DeserializationError::AuthenticationMethodError(ref b), + ) => a == b, + (&DeserializationError::UTF8Error(ref a), &DeserializationError::UTF8Error(ref b)) => { + a == b + } + ( + &DeserializationError::InvalidAddressType(a), + &DeserializationError::InvalidAddressType(b), + ) => a == b, + ( + &DeserializationError::InvalidClientCommand(a), + &DeserializationError::InvalidClientCommand(b), + ) => a == b, + ( + &DeserializationError::InvalidServerResponse(a), + &DeserializationError::InvalidServerResponse(b), + ) => a == b, + (_, _) => false, + } + } +} + /// All the errors that can occur trying to turn SOCKSv5 message structures /// into raw bytes. There's a few places that the message structures allow /// for information that can't be serialized; often, you have to be careful @@ -40,6 +103,32 @@ pub enum SerializationError { IOError(#[from] io::Error), } +#[test] +fn ser_err_reasonable_equals() { + let too_many = SerializationError::TooManyAuthMethods(512); + assert_eq!(too_many, too_many); + let invalid_str = SerializationError::InvalidStringLength("Whoopsy!".to_string()); + assert_eq!(invalid_str, invalid_str); + + assert_ne!(too_many, invalid_str); +} + +impl PartialEq for SerializationError { + fn eq(&self, other: &SerializationError) -> bool { + match (self, other) { + ( + &SerializationError::TooManyAuthMethods(a), + &SerializationError::TooManyAuthMethods(b), + ) => a == b, + ( + &SerializationError::InvalidStringLength(ref a), + &SerializationError::InvalidStringLength(ref b), + ) => a == b, + (_, _) => false, + } + } +} + #[derive(Error, Debug)] pub enum AuthenticationDeserializationError { #[error("No data found deserializing SOCKS authentication type")] @@ -49,3 +138,29 @@ pub enum AuthenticationDeserializationError { #[error("IO error reading SOCKS authentication type: {0}")] IOError(#[from] io::Error), } + +#[test] +fn auth_des_err_reasonable_equals() { + let no_data = AuthenticationDeserializationError::NoDataFound; + assert_eq!(no_data, no_data); + let invalid_auth = AuthenticationDeserializationError::InvalidAuthenticationByte(39); + assert_eq!(invalid_auth, invalid_auth); + + assert_ne!(no_data, invalid_auth); +} + +impl PartialEq for AuthenticationDeserializationError { + fn eq(&self, other: &AuthenticationDeserializationError) -> bool { + match (self, other) { + ( + &AuthenticationDeserializationError::NoDataFound, + &AuthenticationDeserializationError::NoDataFound, + ) => true, + ( + &AuthenticationDeserializationError::InvalidAuthenticationByte(x), + &AuthenticationDeserializationError::InvalidAuthenticationByte(y), + ) => x == y, + (_, _) => false, + } + } +} diff --git a/src/messages.rs b/src/messages.rs index f3632e5..b85209a 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -1,604 +1,16 @@ -use crate::errors::{AuthenticationDeserializationError, DeserializationError, SerializationError}; -use crate::network::SOCKSv5Address; -use crate::serialize::{read_amt, read_string, write_string}; -#[cfg(test)] -use async_std::task; -#[cfg(test)] -use futures::io::Cursor; -use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use log::warn; -#[cfg(test)] -use quickcheck::{quickcheck, Arbitrary, Gen}; -use std::fmt; -use std::net::Ipv4Addr; -use std::pin::Pin; -use thiserror::Error; - -/// Client greetings are the first message sent in a SOCKSv5 session. They -/// identify that there's a client that wants to talk to a server, and that -/// they can support any of the provided mechanisms for authenticating to -/// said server. (It feels weird that the offer/choice goes this way instead -/// of the reverse, but whatever.) -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct ClientGreeting { - pub acceptable_methods: Vec, -} - -impl ClientGreeting { - pub async fn read( - r: Pin<&mut R>, - ) -> Result { - let mut buffer = [0; 1]; - let raw_r = Pin::into_inner(r); - - if raw_r.read(&mut buffer).await? == 0 { - return Err(DeserializationError::NotEnoughData); - } - - if buffer[0] != 5 { - return Err(DeserializationError::InvalidVersion(5, buffer[0])); - } - - if raw_r.read(&mut buffer).await? == 0 { - return Err(DeserializationError::NotEnoughData); - } - - let mut acceptable_methods = Vec::with_capacity(buffer[0] as usize); - for _ in 0..buffer[0] { - acceptable_methods.push(AuthenticationMethod::read(Pin::new(raw_r)).await?); - } - - Ok(ClientGreeting { acceptable_methods }) - } - - pub async fn write( - &self, - w: &mut W, - ) -> Result<(), SerializationError> { - if self.acceptable_methods.len() > 255 { - return Err(SerializationError::TooManyAuthMethods( - self.acceptable_methods.len(), - )); - } - - let mut buffer = Vec::with_capacity(self.acceptable_methods.len() + 2); - buffer.push(5); - buffer.push(self.acceptable_methods.len() as u8); - w.write_all(&buffer).await?; - for authmeth in self.acceptable_methods.iter() { - authmeth.write(w).await?; - } - Ok(()) - } -} - -#[cfg(test)] -impl Arbitrary for ClientGreeting { - fn arbitrary(g: &mut Gen) -> ClientGreeting { - let amt = u8::arbitrary(g); - let mut acceptable_methods = Vec::with_capacity(amt as usize); - - for _ in 0..amt { - acceptable_methods.push(AuthenticationMethod::arbitrary(g)); - } - - ClientGreeting { acceptable_methods } - } -} - -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct ServerChoice { - pub chosen_method: AuthenticationMethod, -} - -impl ServerChoice { - pub async fn read( - mut r: Pin<&mut R>, - ) -> Result { - let mut buffer = [0; 1]; - - if r.read(&mut buffer).await? == 0 { - return Err(DeserializationError::NotEnoughData); - } - - if buffer[0] != 5 { - return Err(DeserializationError::InvalidVersion(5, buffer[0])); - } - - let chosen_method = AuthenticationMethod::read(r).await?; - - Ok(ServerChoice { chosen_method }) - } - - pub async fn write( - &self, - w: &mut W, - ) -> Result<(), SerializationError> { - w.write_all(&[5]).await?; - self.chosen_method.write(w).await - } -} - -#[cfg(test)] -impl Arbitrary for ServerChoice { - fn arbitrary(g: &mut Gen) -> ServerChoice { - ServerChoice { - chosen_method: AuthenticationMethod::arbitrary(g), - } - } -} - -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct ClientUsernamePassword { - pub username: String, - pub password: String, -} - -impl ClientUsernamePassword { - pub async fn read( - r: Pin<&mut R>, - ) -> Result { - let mut buffer = [0; 1]; - let raw_r = Pin::into_inner(r); - - if raw_r.read(&mut buffer).await? == 0 { - return Err(DeserializationError::NotEnoughData); - } - - if buffer[0] != 1 { - return Err(DeserializationError::InvalidVersion(1, buffer[0])); - } - - let username = read_string(Pin::new(raw_r)).await?; - let password = read_string(Pin::new(raw_r)).await?; - - Ok(ClientUsernamePassword { username, password }) - } - - pub async fn write( - &self, - w: &mut W, - ) -> Result<(), SerializationError> { - w.write_all(&[1]).await?; - write_string(&self.username, w).await?; - write_string(&self.password, w).await - } -} - -#[cfg(test)] -impl Arbitrary for ClientUsernamePassword { - fn arbitrary(g: &mut Gen) -> Self { - let username = arbitrary_socks_string(g); - let password = arbitrary_socks_string(g); - - ClientUsernamePassword { username, password } - } -} - -#[cfg(test)] -pub fn arbitrary_socks_string(g: &mut Gen) -> String { - loop { - let mut potential = String::arbitrary(g); - - potential.truncate(255); - let bytestring = potential.as_bytes(); - - if bytestring.len() > 0 && bytestring.len() < 256 { - return potential; - } - } -} - -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct ServerAuthResponse { - pub success: bool, -} - -impl ServerAuthResponse { - pub async fn read( - mut r: Pin<&mut R>, - ) -> Result { - let mut buffer = [0; 1]; - - if r.read(&mut buffer).await? == 0 { - return Err(DeserializationError::NotEnoughData); - } - - if buffer[0] != 1 { - return Err(DeserializationError::InvalidVersion(1, buffer[0])); - } - - if r.read(&mut buffer).await? == 0 { - return Err(DeserializationError::NotEnoughData); - } - - Ok(ServerAuthResponse { - success: buffer[0] == 0, - }) - } - - pub async fn write( - &self, - w: &mut W, - ) -> Result<(), SerializationError> { - w.write_all(&[1]).await?; - w.write_all(&[if self.success { 0x00 } else { 0xde }]) - .await?; - Ok(()) - } -} - -#[cfg(test)] -impl Arbitrary for ServerAuthResponse { - fn arbitrary(g: &mut Gen) -> ServerAuthResponse { - let success = bool::arbitrary(g); - ServerAuthResponse { success } - } -} - -#[allow(clippy::upper_case_acronyms)] -#[derive(Clone, Debug, Eq, PartialEq)] -pub enum AuthenticationMethod { - None, - GSSAPI, - UsernameAndPassword, - ChallengeHandshake, - ChallengeResponse, - SSL, - NDS, - MultiAuthenticationFramework, - JSONPropertyBlock, - PrivateMethod(u8), - NoAcceptableMethods, -} - -impl fmt::Display for AuthenticationMethod { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - AuthenticationMethod::None => write!(f, "No authentication"), - AuthenticationMethod::GSSAPI => write!(f, "GSS-API"), - AuthenticationMethod::UsernameAndPassword => write!(f, "Username and password"), - AuthenticationMethod::ChallengeHandshake => write!(f, "Challenge/Handshake"), - AuthenticationMethod::ChallengeResponse => write!(f, "Challenge/Response"), - AuthenticationMethod::SSL => write!(f, "SSL"), - AuthenticationMethod::NDS => write!(f, "NDS Authentication"), - AuthenticationMethod::MultiAuthenticationFramework => { - write!(f, "Multi-Authentication Framework") - } - AuthenticationMethod::JSONPropertyBlock => write!(f, "JSON Property Block"), - AuthenticationMethod::PrivateMethod(m) => write!(f, "Private Method {:x}", m), - AuthenticationMethod::NoAcceptableMethods => write!(f, "No Acceptable Methods"), - } - } -} - -impl AuthenticationMethod { - pub async fn read( - mut r: Pin<&mut R>, - ) -> Result { - let mut byte_buffer = [0u8; 1]; - let amount_read = r.read(&mut byte_buffer).await?; - - if amount_read == 0 { - return Err(AuthenticationDeserializationError::NoDataFound.into()); - } - - match byte_buffer[0] { - 0 => Ok(AuthenticationMethod::None), - 1 => Ok(AuthenticationMethod::GSSAPI), - 2 => Ok(AuthenticationMethod::UsernameAndPassword), - 3 => Ok(AuthenticationMethod::ChallengeHandshake), - 5 => Ok(AuthenticationMethod::ChallengeResponse), - 6 => Ok(AuthenticationMethod::SSL), - 7 => Ok(AuthenticationMethod::NDS), - 8 => Ok(AuthenticationMethod::MultiAuthenticationFramework), - 9 => Ok(AuthenticationMethod::JSONPropertyBlock), - x if (0x80..=0xfe).contains(&x) => Ok(AuthenticationMethod::PrivateMethod(x)), - 0xff => Ok(AuthenticationMethod::NoAcceptableMethods), - e => Err(AuthenticationDeserializationError::InvalidAuthenticationByte(e).into()), - } - } - - pub async fn write( - &self, - w: &mut W, - ) -> Result<(), SerializationError> { - let value = match self { - AuthenticationMethod::None => 0, - AuthenticationMethod::GSSAPI => 1, - AuthenticationMethod::UsernameAndPassword => 2, - AuthenticationMethod::ChallengeHandshake => 3, - AuthenticationMethod::ChallengeResponse => 5, - AuthenticationMethod::SSL => 6, - AuthenticationMethod::NDS => 7, - AuthenticationMethod::MultiAuthenticationFramework => 8, - AuthenticationMethod::JSONPropertyBlock => 9, - AuthenticationMethod::PrivateMethod(pm) => *pm, - AuthenticationMethod::NoAcceptableMethods => 0xff, - }; - - Ok(w.write_all(&[value]).await?) - } -} - -#[cfg(test)] -impl Arbitrary for AuthenticationMethod { - fn arbitrary(g: &mut Gen) -> AuthenticationMethod { - let mut vals = vec![ - AuthenticationMethod::None, - AuthenticationMethod::GSSAPI, - AuthenticationMethod::UsernameAndPassword, - AuthenticationMethod::ChallengeHandshake, - AuthenticationMethod::ChallengeResponse, - AuthenticationMethod::SSL, - AuthenticationMethod::NDS, - AuthenticationMethod::MultiAuthenticationFramework, - AuthenticationMethod::JSONPropertyBlock, - AuthenticationMethod::NoAcceptableMethods, - ]; - for x in 0x80..0xffu8 { - vals.push(AuthenticationMethod::PrivateMethod(x)); - } - g.choose(&vals).unwrap().clone() - } -} - -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -pub enum ClientConnectionCommand { - EstablishTCPStream, - EstablishTCPPortBinding, - AssociateUDPPort, -} - -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct ClientConnectionRequest { - pub command_code: ClientConnectionCommand, - pub destination_address: SOCKSv5Address, - pub destination_port: u16, -} - -impl ClientConnectionRequest { - pub async fn read( - r: Pin<&mut R>, - ) -> Result { - let mut buffer = [0; 2]; - let raw_r = Pin::into_inner(r); - - read_amt(Pin::new(raw_r), 2, &mut buffer).await?; - - if buffer[0] != 5 { - return Err(DeserializationError::InvalidVersion(5, buffer[0])); - } - - let command_code = match buffer[1] { - 0x01 => ClientConnectionCommand::EstablishTCPStream, - 0x02 => ClientConnectionCommand::EstablishTCPPortBinding, - 0x03 => ClientConnectionCommand::AssociateUDPPort, - x => return Err(DeserializationError::InvalidClientCommand(x)), - }; - - let destination_address = SOCKSv5Address::read(Pin::new(raw_r)).await?; - - read_amt(Pin::new(raw_r), 2, &mut buffer).await?; - let destination_port = ((buffer[0] as u16) << 8) + (buffer[1] as u16); - - Ok(ClientConnectionRequest { - command_code, - destination_address, - destination_port, - }) - } - - pub async fn write( - &self, - w: &mut W, - ) -> Result<(), SerializationError> { - let command = match self.command_code { - ClientConnectionCommand::EstablishTCPStream => 1, - ClientConnectionCommand::EstablishTCPPortBinding => 2, - ClientConnectionCommand::AssociateUDPPort => 3, - }; - - w.write_all(&[5, command]).await?; - self.destination_address.write(w).await?; - w.write_all(&[ - (self.destination_port >> 8) as u8, - (self.destination_port & 0xffu16) as u8, - ]) - .await - .map_err(SerializationError::IOError) - } -} - -#[cfg(test)] -impl Arbitrary for ClientConnectionCommand { - fn arbitrary(g: &mut Gen) -> ClientConnectionCommand { - g.choose(&[ - ClientConnectionCommand::EstablishTCPStream, - ClientConnectionCommand::EstablishTCPPortBinding, - ClientConnectionCommand::AssociateUDPPort, - ]) - .unwrap() - .clone() - } -} - -#[cfg(test)] -impl Arbitrary for ClientConnectionRequest { - fn arbitrary(g: &mut Gen) -> Self { - let command_code = ClientConnectionCommand::arbitrary(g); - let destination_address = SOCKSv5Address::arbitrary(g); - let destination_port = u16::arbitrary(g); - - ClientConnectionRequest { - command_code, - destination_address, - destination_port, - } - } -} - -#[derive(Clone, Debug, Eq, Error, PartialEq)] -pub enum ServerResponseStatus { - #[error("Actually, everything's fine (weird to see this in an error)")] - RequestGranted, - #[error("General server failure")] - GeneralFailure, - #[error("Connection not allowed by policy rule")] - ConnectionNotAllowedByRule, - #[error("Network unreachable")] - NetworkUnreachable, - #[error("Host unreachable")] - HostUnreachable, - #[error("Connection refused")] - ConnectionRefused, - #[error("TTL expired")] - TTLExpired, - #[error("Command not supported")] - CommandNotSupported, - #[error("Address type not supported")] - AddressTypeNotSupported, -} - -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct ServerResponse { - pub status: ServerResponseStatus, - pub bound_address: SOCKSv5Address, - pub bound_port: u16, -} - -impl ServerResponse { - pub fn error>(resp: E) -> ServerResponse { - ServerResponse { - status: resp.into(), - bound_address: SOCKSv5Address::IP4(Ipv4Addr::new(0, 0, 0, 0)), - bound_port: 0, - } - } -} - -impl ServerResponse { - pub async fn read( - r: Pin<&mut R>, - ) -> Result { - let mut buffer = [0; 3]; - let raw_r = Pin::into_inner(r); - - read_amt(Pin::new(raw_r), 3, &mut buffer).await?; - - if buffer[0] != 5 { - return Err(DeserializationError::InvalidVersion(5, buffer[0])); - } - - if buffer[2] != 0 { - warn!(target: "async-socks5", "Hey, this isn't terrible, but the server is sending invalid reserved bytes."); - } - - let status = match buffer[1] { - 0x00 => ServerResponseStatus::RequestGranted, - 0x01 => ServerResponseStatus::GeneralFailure, - 0x02 => ServerResponseStatus::ConnectionNotAllowedByRule, - 0x03 => ServerResponseStatus::NetworkUnreachable, - 0x04 => ServerResponseStatus::HostUnreachable, - 0x05 => ServerResponseStatus::ConnectionRefused, - 0x06 => ServerResponseStatus::TTLExpired, - 0x07 => ServerResponseStatus::CommandNotSupported, - 0x08 => ServerResponseStatus::AddressTypeNotSupported, - x => return Err(DeserializationError::InvalidServerResponse(x)), - }; - - let bound_address = SOCKSv5Address::read(Pin::new(raw_r)).await?; - read_amt(Pin::new(raw_r), 2, &mut buffer).await?; - let bound_port = ((buffer[0] as u16) << 8) + (buffer[1] as u16); - - Ok(ServerResponse { - status, - bound_address, - bound_port, - }) - } - - pub async fn write( - &self, - w: &mut W, - ) -> Result<(), SerializationError> { - let status_code = match self.status { - ServerResponseStatus::RequestGranted => 0x00, - ServerResponseStatus::GeneralFailure => 0x01, - ServerResponseStatus::ConnectionNotAllowedByRule => 0x02, - ServerResponseStatus::NetworkUnreachable => 0x03, - ServerResponseStatus::HostUnreachable => 0x04, - ServerResponseStatus::ConnectionRefused => 0x05, - ServerResponseStatus::TTLExpired => 0x06, - ServerResponseStatus::CommandNotSupported => 0x07, - ServerResponseStatus::AddressTypeNotSupported => 0x08, - }; - - w.write_all(&[5, status_code, 0]).await?; - self.bound_address.write(w).await?; - w.write_all(&[ - (self.bound_port >> 8) as u8, - (self.bound_port & 0xffu16) as u8, - ]) - .await - .map_err(SerializationError::IOError) - } -} - -#[cfg(test)] -impl Arbitrary for ServerResponseStatus { - fn arbitrary(g: &mut Gen) -> ServerResponseStatus { - g.choose(&[ - ServerResponseStatus::RequestGranted, - ServerResponseStatus::GeneralFailure, - ServerResponseStatus::ConnectionNotAllowedByRule, - ServerResponseStatus::NetworkUnreachable, - ServerResponseStatus::HostUnreachable, - ServerResponseStatus::ConnectionRefused, - ServerResponseStatus::TTLExpired, - ServerResponseStatus::CommandNotSupported, - ServerResponseStatus::AddressTypeNotSupported, - ]) - .unwrap() - .clone() - } -} - -#[cfg(test)] -impl Arbitrary for ServerResponse { - fn arbitrary(g: &mut Gen) -> Self { - let status = ServerResponseStatus::arbitrary(g); - let bound_address = SOCKSv5Address::arbitrary(g); - let bound_port = u16::arbitrary(g); - - ServerResponse { - status, - bound_address, - bound_port, - } - } -} - -macro_rules! standard_roundtrip { - ($name: ident, $t: ty) => { - #[cfg(test)] - quickcheck! { - fn $name(xs: $t) -> bool { - let mut buffer = vec![]; - task::block_on(xs.write(&mut buffer)).unwrap(); - let mut cursor = Cursor::new(buffer); - let ys = <$t>::read(Pin::new(&mut cursor)); - xs == task::block_on(ys).unwrap() - } - } - }; -} - -standard_roundtrip!(auth_byte_roundtrips, AuthenticationMethod); -standard_roundtrip!(client_greeting_roundtrips, ClientGreeting); -standard_roundtrip!(server_choice_roundtrips, ServerChoice); -standard_roundtrip!(username_password_roundtrips, ClientUsernamePassword); -standard_roundtrip!(server_auth_response, ServerAuthResponse); -standard_roundtrip!(address_roundtrips, SOCKSv5Address); -standard_roundtrip!(client_request_roundtrips, ClientConnectionRequest); -standard_roundtrip!(server_response_roundtrips, ServerResponse); +mod authentication_method; +mod client_command; +mod client_greeting; +mod client_username_password; +mod server_auth_response; +mod server_choice; +mod server_response; +pub(crate) mod utils; + +pub use crate::messages::authentication_method::AuthenticationMethod; +pub use crate::messages::client_command::{ClientConnectionCommand, ClientConnectionRequest}; +pub use crate::messages::client_greeting::ClientGreeting; +pub use crate::messages::client_username_password::ClientUsernamePassword; +pub use crate::messages::server_auth_response::ServerAuthResponse; +pub use crate::messages::server_choice::ServerChoice; +pub use crate::messages::server_response::{ServerResponse, ServerResponseStatus}; diff --git a/src/messages/authentication_method.rs b/src/messages/authentication_method.rs new file mode 100644 index 0000000..050d876 --- /dev/null +++ b/src/messages/authentication_method.rs @@ -0,0 +1,156 @@ +use crate::errors::{AuthenticationDeserializationError, DeserializationError, SerializationError}; +use crate::standard_roundtrip; +#[cfg(test)] +use async_std::task; +#[cfg(test)] +use futures::io::Cursor; +use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +#[cfg(test)] +use quickcheck::{quickcheck, Arbitrary, Gen}; +use std::fmt; +use std::pin::Pin; + +#[allow(clippy::upper_case_acronyms)] +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum AuthenticationMethod { + None, + GSSAPI, + UsernameAndPassword, + ChallengeHandshake, + ChallengeResponse, + SSL, + NDS, + MultiAuthenticationFramework, + JSONPropertyBlock, + PrivateMethod(u8), + NoAcceptableMethods, +} + +impl fmt::Display for AuthenticationMethod { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + AuthenticationMethod::None => write!(f, "No authentication"), + AuthenticationMethod::GSSAPI => write!(f, "GSS-API"), + AuthenticationMethod::UsernameAndPassword => write!(f, "Username and password"), + AuthenticationMethod::ChallengeHandshake => write!(f, "Challenge/Handshake"), + AuthenticationMethod::ChallengeResponse => write!(f, "Challenge/Response"), + AuthenticationMethod::SSL => write!(f, "SSL"), + AuthenticationMethod::NDS => write!(f, "NDS Authentication"), + AuthenticationMethod::MultiAuthenticationFramework => { + write!(f, "Multi-Authentication Framework") + } + AuthenticationMethod::JSONPropertyBlock => write!(f, "JSON Property Block"), + AuthenticationMethod::PrivateMethod(m) => write!(f, "Private Method {:x}", m), + AuthenticationMethod::NoAcceptableMethods => write!(f, "No Acceptable Methods"), + } + } +} + +impl AuthenticationMethod { + pub async fn read( + mut r: Pin<&mut R>, + ) -> Result { + let mut byte_buffer = [0u8; 1]; + let amount_read = r.read(&mut byte_buffer).await?; + + if amount_read == 0 { + return Err(AuthenticationDeserializationError::NoDataFound.into()); + } + + match byte_buffer[0] { + 0 => Ok(AuthenticationMethod::None), + 1 => Ok(AuthenticationMethod::GSSAPI), + 2 => Ok(AuthenticationMethod::UsernameAndPassword), + 3 => Ok(AuthenticationMethod::ChallengeHandshake), + 5 => Ok(AuthenticationMethod::ChallengeResponse), + 6 => Ok(AuthenticationMethod::SSL), + 7 => Ok(AuthenticationMethod::NDS), + 8 => Ok(AuthenticationMethod::MultiAuthenticationFramework), + 9 => Ok(AuthenticationMethod::JSONPropertyBlock), + x if (0x80..=0xfe).contains(&x) => Ok(AuthenticationMethod::PrivateMethod(x)), + 0xff => Ok(AuthenticationMethod::NoAcceptableMethods), + e => Err(AuthenticationDeserializationError::InvalidAuthenticationByte(e).into()), + } + } + + pub async fn write( + &self, + w: &mut W, + ) -> Result<(), SerializationError> { + let value = match self { + AuthenticationMethod::None => 0, + AuthenticationMethod::GSSAPI => 1, + AuthenticationMethod::UsernameAndPassword => 2, + AuthenticationMethod::ChallengeHandshake => 3, + AuthenticationMethod::ChallengeResponse => 5, + AuthenticationMethod::SSL => 6, + AuthenticationMethod::NDS => 7, + AuthenticationMethod::MultiAuthenticationFramework => 8, + AuthenticationMethod::JSONPropertyBlock => 9, + AuthenticationMethod::PrivateMethod(pm) => *pm, + AuthenticationMethod::NoAcceptableMethods => 0xff, + }; + + Ok(w.write_all(&[value]).await?) + } +} + +#[cfg(test)] +impl Arbitrary for AuthenticationMethod { + fn arbitrary(g: &mut Gen) -> AuthenticationMethod { + let mut vals = vec![ + AuthenticationMethod::None, + AuthenticationMethod::GSSAPI, + AuthenticationMethod::UsernameAndPassword, + AuthenticationMethod::ChallengeHandshake, + AuthenticationMethod::ChallengeResponse, + AuthenticationMethod::SSL, + AuthenticationMethod::NDS, + AuthenticationMethod::MultiAuthenticationFramework, + AuthenticationMethod::JSONPropertyBlock, + AuthenticationMethod::NoAcceptableMethods, + ]; + for x in 0x80..0xffu8 { + vals.push(AuthenticationMethod::PrivateMethod(x)); + } + g.choose(&vals).unwrap().clone() + } +} + +standard_roundtrip!(auth_byte_roundtrips, AuthenticationMethod); + +#[test] +fn bad_byte() { + let no_len = vec![42]; + let mut cursor = Cursor::new(no_len); + let ys = AuthenticationMethod::read(Pin::new(&mut cursor)); + assert_eq!( + Err(DeserializationError::AuthenticationMethodError( + AuthenticationDeserializationError::InvalidAuthenticationByte(42) + )), + task::block_on(ys) + ); +} + +#[test] +fn display_isnt_empty() { + let vals = vec![ + AuthenticationMethod::None, + AuthenticationMethod::GSSAPI, + AuthenticationMethod::UsernameAndPassword, + AuthenticationMethod::ChallengeHandshake, + AuthenticationMethod::ChallengeResponse, + AuthenticationMethod::SSL, + AuthenticationMethod::NDS, + AuthenticationMethod::MultiAuthenticationFramework, + AuthenticationMethod::JSONPropertyBlock, + AuthenticationMethod::NoAcceptableMethods, + AuthenticationMethod::PrivateMethod(42), + ]; + + for method in vals.iter() { + let str = format!("{}", method); + assert!(str.is_ascii()); + assert!(!str.is_empty()); + } +} diff --git a/src/messages/client_command.rs b/src/messages/client_command.rs new file mode 100644 index 0000000..2d5456a --- /dev/null +++ b/src/messages/client_command.rs @@ -0,0 +1,164 @@ +use crate::errors::{DeserializationError, SerializationError}; +use crate::network::SOCKSv5Address; +use crate::serialize::read_amt; +use crate::standard_roundtrip; +#[cfg(test)] +use async_std::io::ErrorKind; +#[cfg(test)] +use async_std::task; +#[cfg(test)] +use futures::io::Cursor; +use futures::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; +#[cfg(test)] +use quickcheck::{quickcheck, Arbitrary, Gen}; +#[cfg(test)] +use std::net::Ipv4Addr; +use std::pin::Pin; + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum ClientConnectionCommand { + EstablishTCPStream, + EstablishTCPPortBinding, + AssociateUDPPort, +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct ClientConnectionRequest { + pub command_code: ClientConnectionCommand, + pub destination_address: SOCKSv5Address, + pub destination_port: u16, +} + +impl ClientConnectionRequest { + pub async fn read( + r: Pin<&mut R>, + ) -> Result { + let mut buffer = [0; 2]; + let raw_r = Pin::into_inner(r); + + read_amt(Pin::new(raw_r), 2, &mut buffer).await?; + + if buffer[0] != 5 { + return Err(DeserializationError::InvalidVersion(5, buffer[0])); + } + + let command_code = match buffer[1] { + 0x01 => ClientConnectionCommand::EstablishTCPStream, + 0x02 => ClientConnectionCommand::EstablishTCPPortBinding, + 0x03 => ClientConnectionCommand::AssociateUDPPort, + x => return Err(DeserializationError::InvalidClientCommand(x)), + }; + + let destination_address = SOCKSv5Address::read(Pin::new(raw_r)).await?; + + read_amt(Pin::new(raw_r), 2, &mut buffer).await?; + let destination_port = ((buffer[0] as u16) << 8) + (buffer[1] as u16); + + Ok(ClientConnectionRequest { + command_code, + destination_address, + destination_port, + }) + } + + pub async fn write( + &self, + w: &mut W, + ) -> Result<(), SerializationError> { + let command = match self.command_code { + ClientConnectionCommand::EstablishTCPStream => 1, + ClientConnectionCommand::EstablishTCPPortBinding => 2, + ClientConnectionCommand::AssociateUDPPort => 3, + }; + + w.write_all(&[5, command]).await?; + self.destination_address.write(w).await?; + w.write_all(&[ + (self.destination_port >> 8) as u8, + (self.destination_port & 0xffu16) as u8, + ]) + .await + .map_err(SerializationError::IOError) + } +} + +#[cfg(test)] +impl Arbitrary for ClientConnectionCommand { + fn arbitrary(g: &mut Gen) -> ClientConnectionCommand { + let options = [ + ClientConnectionCommand::EstablishTCPStream, + ClientConnectionCommand::EstablishTCPPortBinding, + ClientConnectionCommand::AssociateUDPPort, + ]; + g.choose(&options).unwrap().clone() + } +} + +#[cfg(test)] +impl Arbitrary for ClientConnectionRequest { + fn arbitrary(g: &mut Gen) -> Self { + let command_code = ClientConnectionCommand::arbitrary(g); + let destination_address = SOCKSv5Address::arbitrary(g); + let destination_port = u16::arbitrary(g); + + ClientConnectionRequest { + command_code, + destination_address, + destination_port, + } + } +} + +standard_roundtrip!(client_request_roundtrips, ClientConnectionRequest); + +#[test] +fn check_short_reads() { + let empty = vec![]; + let mut cursor = Cursor::new(empty); + let ys = ClientConnectionRequest::read(Pin::new(&mut cursor)); + assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); + + let no_len = vec![5, 1]; + let mut cursor = Cursor::new(no_len); + let ys = ClientConnectionRequest::read(Pin::new(&mut cursor)); + assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); +} + +#[test] +fn check_bad_version() { + let bad_ver = vec![6, 1, 1]; + let mut cursor = Cursor::new(bad_ver); + let ys = ClientConnectionRequest::read(Pin::new(&mut cursor)); + assert_eq!( + Err(DeserializationError::InvalidVersion(5, 6)), + task::block_on(ys) + ); +} + +#[test] +fn check_bad_command() { + let bad_cmd = vec![5, 32, 1]; + let mut cursor = Cursor::new(bad_cmd); + let ys = ClientConnectionRequest::read(Pin::new(&mut cursor)); + assert_eq!( + Err(DeserializationError::InvalidClientCommand(32)), + task::block_on(ys) + ); +} + +#[test] +fn short_write_fails_right() { + let mut buffer = [0u8; 2]; + let cmd = ClientConnectionRequest { + command_code: ClientConnectionCommand::AssociateUDPPort, + destination_address: SOCKSv5Address::IP4(Ipv4Addr::from(0)), + destination_port: 22, + }; + let mut cursor = Cursor::new(&mut buffer as &mut [u8]); + let result = task::block_on(cmd.write(&mut cursor)); + match result { + Ok(_) => assert!(false, "Mysteriously able to fit > 2 bytes in 2 bytes."), + Err(SerializationError::IOError(x)) => assert_eq!(ErrorKind::WriteZero, x.kind()), + Err(e) => assert!(false, "Got the wrong error writing too much data: {}", e), + } +} diff --git a/src/messages/client_greeting.rs b/src/messages/client_greeting.rs new file mode 100644 index 0000000..be474e4 --- /dev/null +++ b/src/messages/client_greeting.rs @@ -0,0 +1,135 @@ +#[cfg(test)] +use crate::errors::AuthenticationDeserializationError; +use crate::errors::{DeserializationError, SerializationError}; +use crate::messages::AuthenticationMethod; +use crate::standard_roundtrip; +#[cfg(test)] +use async_std::task; +#[cfg(test)] +use futures::io::Cursor; +use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +#[cfg(test)] +use quickcheck::{quickcheck, Arbitrary, Gen}; +use std::pin::Pin; + +/// Client greetings are the first message sent in a SOCKSv5 session. They +/// identify that there's a client that wants to talk to a server, and that +/// they can support any of the provided mechanisms for authenticating to +/// said server. (It feels weird that the offer/choice goes this way instead +/// of the reverse, but whatever.) +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct ClientGreeting { + pub acceptable_methods: Vec, +} + +impl ClientGreeting { + pub async fn read( + r: Pin<&mut R>, + ) -> Result { + let mut buffer = [0; 1]; + let raw_r = Pin::into_inner(r); + + if raw_r.read(&mut buffer).await? == 0 { + return Err(DeserializationError::NotEnoughData); + } + + if buffer[0] != 5 { + return Err(DeserializationError::InvalidVersion(5, buffer[0])); + } + + if raw_r.read(&mut buffer).await? == 0 { + return Err(DeserializationError::NotEnoughData); + } + + let mut acceptable_methods = Vec::with_capacity(buffer[0] as usize); + for _ in 0..buffer[0] { + acceptable_methods.push(AuthenticationMethod::read(Pin::new(raw_r)).await?); + } + + Ok(ClientGreeting { acceptable_methods }) + } + + pub async fn write( + &self, + w: &mut W, + ) -> Result<(), SerializationError> { + if self.acceptable_methods.len() > 255 { + return Err(SerializationError::TooManyAuthMethods( + self.acceptable_methods.len(), + )); + } + + let mut buffer = Vec::with_capacity(self.acceptable_methods.len() + 2); + buffer.push(5); + buffer.push(self.acceptable_methods.len() as u8); + w.write_all(&buffer).await?; + for authmeth in self.acceptable_methods.iter() { + authmeth.write(w).await?; + } + Ok(()) + } +} + +#[cfg(test)] +impl Arbitrary for ClientGreeting { + fn arbitrary(g: &mut Gen) -> ClientGreeting { + let amt = u8::arbitrary(g); + let mut acceptable_methods = Vec::with_capacity(amt as usize); + + for _ in 0..amt { + acceptable_methods.push(AuthenticationMethod::arbitrary(g)); + } + + ClientGreeting { acceptable_methods } + } +} + +standard_roundtrip!(client_greeting_roundtrips, ClientGreeting); + +#[test] +fn check_short_reads() { + let empty = vec![]; + let mut cursor = Cursor::new(empty); + let ys = ClientGreeting::read(Pin::new(&mut cursor)); + assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); + + let no_len = vec![5]; + let mut cursor = Cursor::new(no_len); + let ys = ClientGreeting::read(Pin::new(&mut cursor)); + assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); + + let bad_len = vec![5, 9]; + let mut cursor = Cursor::new(bad_len); + let ys = ClientGreeting::read(Pin::new(&mut cursor)); + assert_eq!( + Err(DeserializationError::AuthenticationMethodError( + AuthenticationDeserializationError::NoDataFound + )), + task::block_on(ys) + ); +} + +#[test] +fn check_bad_version() { + let no_len = vec![6, 1, 1]; + let mut cursor = Cursor::new(no_len); + let ys = ClientGreeting::read(Pin::new(&mut cursor)); + assert_eq!( + Err(DeserializationError::InvalidVersion(5, 6)), + task::block_on(ys) + ); +} + +#[test] +fn check_too_many() { + let mut auth_methods = Vec::with_capacity(512); + auth_methods.resize(512, AuthenticationMethod::ChallengeHandshake); + let greet = ClientGreeting { + acceptable_methods: auth_methods, + }; + let mut output = vec![0; 1024]; + assert_eq!( + Err(SerializationError::TooManyAuthMethods(512)), + task::block_on(greet.write(&mut output)) + ); +} diff --git a/src/messages/client_username_password.rs b/src/messages/client_username_password.rs new file mode 100644 index 0000000..d5c7dae --- /dev/null +++ b/src/messages/client_username_password.rs @@ -0,0 +1,86 @@ +use crate::errors::{DeserializationError, SerializationError}; +#[cfg(test)] +use crate::messages::utils::arbitrary_socks_string; +use crate::serialize::{read_string, write_string}; +use crate::standard_roundtrip; +#[cfg(test)] +use async_std::task; +#[cfg(test)] +use futures::io::Cursor; +use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +#[cfg(test)] +use quickcheck::{quickcheck, Arbitrary, Gen}; +use std::pin::Pin; + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct ClientUsernamePassword { + pub username: String, + pub password: String, +} + +impl ClientUsernamePassword { + pub async fn read( + r: Pin<&mut R>, + ) -> Result { + let mut buffer = [0; 1]; + let raw_r = Pin::into_inner(r); + + if raw_r.read(&mut buffer).await? == 0 { + return Err(DeserializationError::NotEnoughData); + } + + if buffer[0] != 1 { + return Err(DeserializationError::InvalidVersion(1, buffer[0])); + } + + let username = read_string(Pin::new(raw_r)).await?; + let password = read_string(Pin::new(raw_r)).await?; + + Ok(ClientUsernamePassword { username, password }) + } + + pub async fn write( + &self, + w: &mut W, + ) -> Result<(), SerializationError> { + w.write_all(&[1]).await?; + write_string(&self.username, w).await?; + write_string(&self.password, w).await + } +} + +#[cfg(test)] +impl Arbitrary for ClientUsernamePassword { + fn arbitrary(g: &mut Gen) -> Self { + let username = arbitrary_socks_string(g); + let password = arbitrary_socks_string(g); + + ClientUsernamePassword { username, password } + } +} + +standard_roundtrip!(username_password_roundtrips, ClientUsernamePassword); + +#[test] +fn check_short_reads() { + let empty = vec![]; + let mut cursor = Cursor::new(empty); + let ys = ClientUsernamePassword::read(Pin::new(&mut cursor)); + assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); + + let user_only = vec![1, 3, 102, 111, 111]; + let mut cursor = Cursor::new(user_only); + let ys = ClientUsernamePassword::read(Pin::new(&mut cursor)); + assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); +} + +#[test] +fn check_bad_version() { + let bad_len = vec![5]; + let mut cursor = Cursor::new(bad_len); + let ys = ClientUsernamePassword::read(Pin::new(&mut cursor)); + assert_eq!( + Err(DeserializationError::InvalidVersion(1, 5)), + task::block_on(ys) + ); +} diff --git a/src/messages/server_auth_response.rs b/src/messages/server_auth_response.rs new file mode 100644 index 0000000..ec70af1 --- /dev/null +++ b/src/messages/server_auth_response.rs @@ -0,0 +1,83 @@ +use crate::errors::{DeserializationError, SerializationError}; +use crate::standard_roundtrip; +#[cfg(test)] +use async_std::task; +#[cfg(test)] +use futures::io::Cursor; +use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +#[cfg(test)] +use quickcheck::{quickcheck, Arbitrary, Gen}; +use std::pin::Pin; + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct ServerAuthResponse { + pub success: bool, +} + +impl ServerAuthResponse { + pub async fn read( + mut r: Pin<&mut R>, + ) -> Result { + let mut buffer = [0; 1]; + + if r.read(&mut buffer).await? == 0 { + return Err(DeserializationError::NotEnoughData); + } + + if buffer[0] != 1 { + return Err(DeserializationError::InvalidVersion(1, buffer[0])); + } + + if r.read(&mut buffer).await? == 0 { + return Err(DeserializationError::NotEnoughData); + } + + Ok(ServerAuthResponse { + success: buffer[0] == 0, + }) + } + + pub async fn write( + &self, + w: &mut W, + ) -> Result<(), SerializationError> { + w.write_all(&[1]).await?; + w.write_all(&[if self.success { 0x00 } else { 0xde }]) + .await?; + Ok(()) + } +} + +#[cfg(test)] +impl Arbitrary for ServerAuthResponse { + fn arbitrary(g: &mut Gen) -> ServerAuthResponse { + let success = bool::arbitrary(g); + ServerAuthResponse { success } + } +} + +standard_roundtrip!(server_auth_response, ServerAuthResponse); + +#[test] +fn check_short_reads() { + let empty = vec![]; + let mut cursor = Cursor::new(empty); + let ys = ServerAuthResponse::read(Pin::new(&mut cursor)); + assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); + + let no_len = vec![1]; + let mut cursor = Cursor::new(no_len); + let ys = ServerAuthResponse::read(Pin::new(&mut cursor)); + assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); +} + +#[test] +fn check_bad_version() { + let no_len = vec![6, 1]; + let mut cursor = Cursor::new(no_len); + let ys = ServerAuthResponse::read(Pin::new(&mut cursor)); + assert_eq!( + Err(DeserializationError::InvalidVersion(1, 6)), + task::block_on(ys) + ); +} diff --git a/src/messages/server_choice.rs b/src/messages/server_choice.rs new file mode 100644 index 0000000..d3fe169 --- /dev/null +++ b/src/messages/server_choice.rs @@ -0,0 +1,86 @@ +#[cfg(test)] +use crate::errors::AuthenticationDeserializationError; +use crate::errors::{DeserializationError, SerializationError}; +use crate::messages::AuthenticationMethod; +use crate::standard_roundtrip; +#[cfg(test)] +use async_std::task; +#[cfg(test)] +use futures::io::Cursor; +use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +#[cfg(test)] +use quickcheck::{quickcheck, Arbitrary, Gen}; +use std::pin::Pin; + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct ServerChoice { + pub chosen_method: AuthenticationMethod, +} + +impl ServerChoice { + pub async fn read( + mut r: Pin<&mut R>, + ) -> Result { + let mut buffer = [0; 1]; + + if r.read(&mut buffer).await? == 0 { + return Err(DeserializationError::NotEnoughData); + } + + if buffer[0] != 5 { + return Err(DeserializationError::InvalidVersion(5, buffer[0])); + } + + let chosen_method = AuthenticationMethod::read(r).await?; + + Ok(ServerChoice { chosen_method }) + } + + pub async fn write( + &self, + w: &mut W, + ) -> Result<(), SerializationError> { + w.write_all(&[5]).await?; + self.chosen_method.write(w).await + } +} + +#[cfg(test)] +impl Arbitrary for ServerChoice { + fn arbitrary(g: &mut Gen) -> ServerChoice { + ServerChoice { + chosen_method: AuthenticationMethod::arbitrary(g), + } + } +} + +standard_roundtrip!(server_choice_roundtrips, ServerChoice); + +#[test] +fn check_short_reads() { + let empty = vec![]; + let mut cursor = Cursor::new(empty); + let ys = ServerChoice::read(Pin::new(&mut cursor)); + assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); + + let bad_len = vec![5]; + let mut cursor = Cursor::new(bad_len); + let ys = ServerChoice::read(Pin::new(&mut cursor)); + assert_eq!( + Err(DeserializationError::AuthenticationMethodError( + AuthenticationDeserializationError::NoDataFound + )), + task::block_on(ys) + ); +} + +#[test] +fn check_bad_version() { + let no_len = vec![9, 1]; + let mut cursor = Cursor::new(no_len); + let ys = ServerChoice::read(Pin::new(&mut cursor)); + assert_eq!( + Err(DeserializationError::InvalidVersion(5, 9)), + task::block_on(ys) + ); +} diff --git a/src/messages/server_response.rs b/src/messages/server_response.rs new file mode 100644 index 0000000..6127116 --- /dev/null +++ b/src/messages/server_response.rs @@ -0,0 +1,202 @@ +use crate::errors::{DeserializationError, SerializationError}; +use crate::network::SOCKSv5Address; +use crate::serialize::read_amt; +use crate::standard_roundtrip; +#[cfg(test)] +use async_std::io::ErrorKind; +#[cfg(test)] +use async_std::task; +#[cfg(test)] +use futures::io::Cursor; +use futures::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; +use log::warn; +#[cfg(test)] +use quickcheck::{quickcheck, Arbitrary, Gen}; +use std::net::Ipv4Addr; +use std::pin::Pin; +use thiserror::Error; + +#[derive(Clone, Debug, Eq, Error, PartialEq)] +pub enum ServerResponseStatus { + #[error("Actually, everything's fine (weird to see this in an error)")] + RequestGranted, + #[error("General server failure")] + GeneralFailure, + #[error("Connection not allowed by policy rule")] + ConnectionNotAllowedByRule, + #[error("Network unreachable")] + NetworkUnreachable, + #[error("Host unreachable")] + HostUnreachable, + #[error("Connection refused")] + ConnectionRefused, + #[error("TTL expired")] + TTLExpired, + #[error("Command not supported")] + CommandNotSupported, + #[error("Address type not supported")] + AddressTypeNotSupported, +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct ServerResponse { + pub status: ServerResponseStatus, + pub bound_address: SOCKSv5Address, + pub bound_port: u16, +} + +impl ServerResponse { + pub fn error>(resp: E) -> ServerResponse { + ServerResponse { + status: resp.into(), + bound_address: SOCKSv5Address::IP4(Ipv4Addr::new(0, 0, 0, 0)), + bound_port: 0, + } + } +} + +impl ServerResponse { + pub async fn read( + r: Pin<&mut R>, + ) -> Result { + let mut buffer = [0; 3]; + let raw_r = Pin::into_inner(r); + + read_amt(Pin::new(raw_r), 3, &mut buffer).await?; + + if buffer[0] != 5 { + return Err(DeserializationError::InvalidVersion(5, buffer[0])); + } + + if buffer[2] != 0 { + warn!(target: "async-socks5", "Hey, this isn't terrible, but the server is sending invalid reserved bytes."); + } + + let status = match buffer[1] { + 0x00 => ServerResponseStatus::RequestGranted, + 0x01 => ServerResponseStatus::GeneralFailure, + 0x02 => ServerResponseStatus::ConnectionNotAllowedByRule, + 0x03 => ServerResponseStatus::NetworkUnreachable, + 0x04 => ServerResponseStatus::HostUnreachable, + 0x05 => ServerResponseStatus::ConnectionRefused, + 0x06 => ServerResponseStatus::TTLExpired, + 0x07 => ServerResponseStatus::CommandNotSupported, + 0x08 => ServerResponseStatus::AddressTypeNotSupported, + x => return Err(DeserializationError::InvalidServerResponse(x)), + }; + + let bound_address = SOCKSv5Address::read(Pin::new(raw_r)).await?; + read_amt(Pin::new(raw_r), 2, &mut buffer).await?; + let bound_port = ((buffer[0] as u16) << 8) + (buffer[1] as u16); + + Ok(ServerResponse { + status, + bound_address, + bound_port, + }) + } + + pub async fn write( + &self, + w: &mut W, + ) -> Result<(), SerializationError> { + let status_code = match self.status { + ServerResponseStatus::RequestGranted => 0x00, + ServerResponseStatus::GeneralFailure => 0x01, + ServerResponseStatus::ConnectionNotAllowedByRule => 0x02, + ServerResponseStatus::NetworkUnreachable => 0x03, + ServerResponseStatus::HostUnreachable => 0x04, + ServerResponseStatus::ConnectionRefused => 0x05, + ServerResponseStatus::TTLExpired => 0x06, + ServerResponseStatus::CommandNotSupported => 0x07, + ServerResponseStatus::AddressTypeNotSupported => 0x08, + }; + + w.write_all(&[5, status_code, 0]).await?; + self.bound_address.write(w).await?; + w.write_all(&[ + (self.bound_port >> 8) as u8, + (self.bound_port & 0xffu16) as u8, + ]) + .await + .map_err(SerializationError::IOError) + } +} + +#[cfg(test)] +impl Arbitrary for ServerResponseStatus { + fn arbitrary(g: &mut Gen) -> ServerResponseStatus { + let options = [ + ServerResponseStatus::RequestGranted, + ServerResponseStatus::GeneralFailure, + ServerResponseStatus::ConnectionNotAllowedByRule, + ServerResponseStatus::NetworkUnreachable, + ServerResponseStatus::HostUnreachable, + ServerResponseStatus::ConnectionRefused, + ServerResponseStatus::TTLExpired, + ServerResponseStatus::CommandNotSupported, + ServerResponseStatus::AddressTypeNotSupported, + ]; + g.choose(&options).unwrap().clone() + } +} + +#[cfg(test)] +impl Arbitrary for ServerResponse { + fn arbitrary(g: &mut Gen) -> Self { + let status = ServerResponseStatus::arbitrary(g); + let bound_address = SOCKSv5Address::arbitrary(g); + let bound_port = u16::arbitrary(g); + + ServerResponse { + status, + bound_address, + bound_port, + } + } +} + +standard_roundtrip!(server_response_roundtrips, ServerResponse); + +#[test] +fn check_short_reads() { + let empty = vec![]; + let mut cursor = Cursor::new(empty); + let ys = ServerResponse::read(Pin::new(&mut cursor)); + assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); +} + +#[test] +fn check_bad_version() { + let bad_ver = vec![6, 1, 1]; + let mut cursor = Cursor::new(bad_ver); + let ys = ServerResponse::read(Pin::new(&mut cursor)); + assert_eq!( + Err(DeserializationError::InvalidVersion(5, 6)), + task::block_on(ys) + ); +} + +#[test] +fn check_bad_command() { + let bad_cmd = vec![5, 32, 0x42]; + let mut cursor = Cursor::new(bad_cmd); + let ys = ServerResponse::read(Pin::new(&mut cursor)); + assert_eq!( + Err(DeserializationError::InvalidServerResponse(32)), + task::block_on(ys) + ); +} + +#[test] +fn short_write_fails_right() { + let mut buffer = [0u8; 2]; + let cmd = ServerResponse::error(ServerResponseStatus::AddressTypeNotSupported); + let mut cursor = Cursor::new(&mut buffer as &mut [u8]); + let result = task::block_on(cmd.write(&mut cursor)); + match result { + Ok(_) => assert!(false, "Mysteriously able to fit > 2 bytes in 2 bytes."), + Err(SerializationError::IOError(x)) => assert_eq!(ErrorKind::WriteZero, x.kind()), + Err(e) => assert!(false, "Got the wrong error writing too much data: {}", e), + } +} diff --git a/src/messages/utils.rs b/src/messages/utils.rs new file mode 100644 index 0000000..52a7c61 --- /dev/null +++ b/src/messages/utils.rs @@ -0,0 +1,33 @@ +#[cfg(test)] +use quickcheck::{Arbitrary, Gen}; + +#[cfg(test)] +pub fn arbitrary_socks_string(g: &mut Gen) -> String { + loop { + let mut potential = String::arbitrary(g); + + potential.truncate(255); + let bytestring = potential.as_bytes(); + + if bytestring.len() > 0 && bytestring.len() < 256 { + return potential; + } + } +} + +#[doc(hidden)] +#[macro_export] +macro_rules! standard_roundtrip { + ($name: ident, $t: ty) => { + #[cfg(test)] + quickcheck! { + fn $name(xs: $t) -> bool { + let mut buffer = vec![]; + task::block_on(xs.write(&mut buffer)).unwrap(); + let mut cursor = Cursor::new(buffer); + let ys = <$t>::read(Pin::new(&mut cursor)); + xs == task::block_on(ys).unwrap() + } + } + }; +} diff --git a/src/network/address.rs b/src/network/address.rs index fa1c584..f30db52 100644 --- a/src/network/address.rs +++ b/src/network/address.rs @@ -1,10 +1,15 @@ use crate::errors::{DeserializationError, SerializationError}; #[cfg(test)] -use crate::messages::arbitrary_socks_string; +use crate::messages::utils::arbitrary_socks_string; use crate::serialize::{read_amt, read_string, write_string}; +use crate::standard_roundtrip; +#[cfg(test)] +use async_std::task; +#[cfg(test)] +use futures::io::Cursor; use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; #[cfg(test)] -use quickcheck::{Arbitrary, Gen}; +use quickcheck::{quickcheck, Arbitrary, Gen}; use std::fmt; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::pin::Pin; @@ -149,3 +154,5 @@ impl Arbitrary for SOCKSv5Address { .clone() } } + +standard_roundtrip!(address_roundtrips, SOCKSv5Address); diff --git a/src/server.rs b/src/server.rs index de8611d..1d75603 100644 --- a/src/server.rs +++ b/src/server.rs @@ -310,4 +310,4 @@ where } } } -} \ No newline at end of file +}