Split out the messages into individual files,, and add negative tests, so we can aspire towards good coverage.

This commit is contained in:
2021-06-27 16:53:57 -07:00
parent 1bf6f62d4e
commit d1143a414c
13 changed files with 1087 additions and 607 deletions

1
.gitignore vendored
View File

@@ -1,2 +1,3 @@
/target /target
Cargo.lock Cargo.lock
tarpaulin-report.html

View File

@@ -26,6 +26,69 @@ pub enum DeserializationError {
InvalidServerResponse(u8), InvalidServerResponse(u8),
} }
#[test]
fn des_error_reasonable_equals() {
let invalid_version = DeserializationError::InvalidVersion(1, 2);
assert_eq!(invalid_version, invalid_version);
let not_enough = DeserializationError::NotEnoughData;
assert_eq!(not_enough, not_enough);
let invalid_empty = DeserializationError::InvalidEmptyString;
assert_eq!(invalid_empty, invalid_empty);
let auth_method = DeserializationError::AuthenticationMethodError(
AuthenticationDeserializationError::NoDataFound,
);
assert_eq!(auth_method, auth_method);
let utf8 = DeserializationError::UTF8Error(String::from_utf8(vec![0, 159]).unwrap_err());
assert_eq!(utf8, utf8);
let invalid_address = DeserializationError::InvalidAddressType(3);
assert_eq!(invalid_address, invalid_address);
let invalid_client_cmd = DeserializationError::InvalidClientCommand(32);
assert_eq!(invalid_client_cmd, invalid_client_cmd);
let invalid_server_resp = DeserializationError::InvalidServerResponse(42);
assert_eq!(invalid_server_resp, invalid_server_resp);
assert_ne!(invalid_version, invalid_address);
assert_ne!(not_enough, invalid_empty);
assert_ne!(auth_method, invalid_client_cmd);
assert_ne!(utf8, invalid_server_resp);
}
impl PartialEq for DeserializationError {
fn eq(&self, other: &DeserializationError) -> bool {
match (self, other) {
(
&DeserializationError::InvalidVersion(a, b),
&DeserializationError::InvalidVersion(x, y),
) => (a == x) && (b == y),
(&DeserializationError::NotEnoughData, &DeserializationError::NotEnoughData) => true,
(
&DeserializationError::InvalidEmptyString,
&DeserializationError::InvalidEmptyString,
) => true,
(
&DeserializationError::AuthenticationMethodError(ref a),
&DeserializationError::AuthenticationMethodError(ref b),
) => a == b,
(&DeserializationError::UTF8Error(ref a), &DeserializationError::UTF8Error(ref b)) => {
a == b
}
(
&DeserializationError::InvalidAddressType(a),
&DeserializationError::InvalidAddressType(b),
) => a == b,
(
&DeserializationError::InvalidClientCommand(a),
&DeserializationError::InvalidClientCommand(b),
) => a == b,
(
&DeserializationError::InvalidServerResponse(a),
&DeserializationError::InvalidServerResponse(b),
) => a == b,
(_, _) => false,
}
}
}
/// All the errors that can occur trying to turn SOCKSv5 message structures /// All the errors that can occur trying to turn SOCKSv5 message structures
/// into raw bytes. There's a few places that the message structures allow /// into raw bytes. There's a few places that the message structures allow
/// for information that can't be serialized; often, you have to be careful /// for information that can't be serialized; often, you have to be careful
@@ -40,6 +103,32 @@ pub enum SerializationError {
IOError(#[from] io::Error), IOError(#[from] io::Error),
} }
#[test]
fn ser_err_reasonable_equals() {
let too_many = SerializationError::TooManyAuthMethods(512);
assert_eq!(too_many, too_many);
let invalid_str = SerializationError::InvalidStringLength("Whoopsy!".to_string());
assert_eq!(invalid_str, invalid_str);
assert_ne!(too_many, invalid_str);
}
impl PartialEq for SerializationError {
fn eq(&self, other: &SerializationError) -> bool {
match (self, other) {
(
&SerializationError::TooManyAuthMethods(a),
&SerializationError::TooManyAuthMethods(b),
) => a == b,
(
&SerializationError::InvalidStringLength(ref a),
&SerializationError::InvalidStringLength(ref b),
) => a == b,
(_, _) => false,
}
}
}
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum AuthenticationDeserializationError { pub enum AuthenticationDeserializationError {
#[error("No data found deserializing SOCKS authentication type")] #[error("No data found deserializing SOCKS authentication type")]
@@ -49,3 +138,29 @@ pub enum AuthenticationDeserializationError {
#[error("IO error reading SOCKS authentication type: {0}")] #[error("IO error reading SOCKS authentication type: {0}")]
IOError(#[from] io::Error), IOError(#[from] io::Error),
} }
#[test]
fn auth_des_err_reasonable_equals() {
let no_data = AuthenticationDeserializationError::NoDataFound;
assert_eq!(no_data, no_data);
let invalid_auth = AuthenticationDeserializationError::InvalidAuthenticationByte(39);
assert_eq!(invalid_auth, invalid_auth);
assert_ne!(no_data, invalid_auth);
}
impl PartialEq for AuthenticationDeserializationError {
fn eq(&self, other: &AuthenticationDeserializationError) -> bool {
match (self, other) {
(
&AuthenticationDeserializationError::NoDataFound,
&AuthenticationDeserializationError::NoDataFound,
) => true,
(
&AuthenticationDeserializationError::InvalidAuthenticationByte(x),
&AuthenticationDeserializationError::InvalidAuthenticationByte(y),
) => x == y,
(_, _) => false,
}
}
}

View File

@@ -1,604 +1,16 @@
use crate::errors::{AuthenticationDeserializationError, DeserializationError, SerializationError}; mod authentication_method;
use crate::network::SOCKSv5Address; mod client_command;
use crate::serialize::{read_amt, read_string, write_string}; mod client_greeting;
#[cfg(test)] mod client_username_password;
use async_std::task; mod server_auth_response;
#[cfg(test)] mod server_choice;
use futures::io::Cursor; mod server_response;
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; pub(crate) mod utils;
use log::warn;
#[cfg(test)] pub use crate::messages::authentication_method::AuthenticationMethod;
use quickcheck::{quickcheck, Arbitrary, Gen}; pub use crate::messages::client_command::{ClientConnectionCommand, ClientConnectionRequest};
use std::fmt; pub use crate::messages::client_greeting::ClientGreeting;
use std::net::Ipv4Addr; pub use crate::messages::client_username_password::ClientUsernamePassword;
use std::pin::Pin; pub use crate::messages::server_auth_response::ServerAuthResponse;
use thiserror::Error; pub use crate::messages::server_choice::ServerChoice;
pub use crate::messages::server_response::{ServerResponse, ServerResponseStatus};
/// Client greetings are the first message sent in a SOCKSv5 session. They
/// identify that there's a client that wants to talk to a server, and that
/// they can support any of the provided mechanisms for authenticating to
/// said server. (It feels weird that the offer/choice goes this way instead
/// of the reverse, but whatever.)
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ClientGreeting {
pub acceptable_methods: Vec<AuthenticationMethod>,
}
impl ClientGreeting {
pub async fn read<R: AsyncRead + Send + Unpin>(
r: Pin<&mut R>,
) -> Result<ClientGreeting, DeserializationError> {
let mut buffer = [0; 1];
let raw_r = Pin::into_inner(r);
if raw_r.read(&mut buffer).await? == 0 {
return Err(DeserializationError::NotEnoughData);
}
if buffer[0] != 5 {
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
}
if raw_r.read(&mut buffer).await? == 0 {
return Err(DeserializationError::NotEnoughData);
}
let mut acceptable_methods = Vec::with_capacity(buffer[0] as usize);
for _ in 0..buffer[0] {
acceptable_methods.push(AuthenticationMethod::read(Pin::new(raw_r)).await?);
}
Ok(ClientGreeting { acceptable_methods })
}
pub async fn write<W: AsyncWrite + Send + Unpin>(
&self,
w: &mut W,
) -> Result<(), SerializationError> {
if self.acceptable_methods.len() > 255 {
return Err(SerializationError::TooManyAuthMethods(
self.acceptable_methods.len(),
));
}
let mut buffer = Vec::with_capacity(self.acceptable_methods.len() + 2);
buffer.push(5);
buffer.push(self.acceptable_methods.len() as u8);
w.write_all(&buffer).await?;
for authmeth in self.acceptable_methods.iter() {
authmeth.write(w).await?;
}
Ok(())
}
}
#[cfg(test)]
impl Arbitrary for ClientGreeting {
fn arbitrary(g: &mut Gen) -> ClientGreeting {
let amt = u8::arbitrary(g);
let mut acceptable_methods = Vec::with_capacity(amt as usize);
for _ in 0..amt {
acceptable_methods.push(AuthenticationMethod::arbitrary(g));
}
ClientGreeting { acceptable_methods }
}
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ServerChoice {
pub chosen_method: AuthenticationMethod,
}
impl ServerChoice {
pub async fn read<R: AsyncRead + Send + Unpin>(
mut r: Pin<&mut R>,
) -> Result<Self, DeserializationError> {
let mut buffer = [0; 1];
if r.read(&mut buffer).await? == 0 {
return Err(DeserializationError::NotEnoughData);
}
if buffer[0] != 5 {
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
}
let chosen_method = AuthenticationMethod::read(r).await?;
Ok(ServerChoice { chosen_method })
}
pub async fn write<W: AsyncWrite + Send + Unpin>(
&self,
w: &mut W,
) -> Result<(), SerializationError> {
w.write_all(&[5]).await?;
self.chosen_method.write(w).await
}
}
#[cfg(test)]
impl Arbitrary for ServerChoice {
fn arbitrary(g: &mut Gen) -> ServerChoice {
ServerChoice {
chosen_method: AuthenticationMethod::arbitrary(g),
}
}
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ClientUsernamePassword {
pub username: String,
pub password: String,
}
impl ClientUsernamePassword {
pub async fn read<R: AsyncRead + Send + Unpin>(
r: Pin<&mut R>,
) -> Result<Self, DeserializationError> {
let mut buffer = [0; 1];
let raw_r = Pin::into_inner(r);
if raw_r.read(&mut buffer).await? == 0 {
return Err(DeserializationError::NotEnoughData);
}
if buffer[0] != 1 {
return Err(DeserializationError::InvalidVersion(1, buffer[0]));
}
let username = read_string(Pin::new(raw_r)).await?;
let password = read_string(Pin::new(raw_r)).await?;
Ok(ClientUsernamePassword { username, password })
}
pub async fn write<W: AsyncWrite + Send + Unpin>(
&self,
w: &mut W,
) -> Result<(), SerializationError> {
w.write_all(&[1]).await?;
write_string(&self.username, w).await?;
write_string(&self.password, w).await
}
}
#[cfg(test)]
impl Arbitrary for ClientUsernamePassword {
fn arbitrary(g: &mut Gen) -> Self {
let username = arbitrary_socks_string(g);
let password = arbitrary_socks_string(g);
ClientUsernamePassword { username, password }
}
}
#[cfg(test)]
pub fn arbitrary_socks_string(g: &mut Gen) -> String {
loop {
let mut potential = String::arbitrary(g);
potential.truncate(255);
let bytestring = potential.as_bytes();
if bytestring.len() > 0 && bytestring.len() < 256 {
return potential;
}
}
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ServerAuthResponse {
pub success: bool,
}
impl ServerAuthResponse {
pub async fn read<R: AsyncRead + Send + Unpin>(
mut r: Pin<&mut R>,
) -> Result<Self, DeserializationError> {
let mut buffer = [0; 1];
if r.read(&mut buffer).await? == 0 {
return Err(DeserializationError::NotEnoughData);
}
if buffer[0] != 1 {
return Err(DeserializationError::InvalidVersion(1, buffer[0]));
}
if r.read(&mut buffer).await? == 0 {
return Err(DeserializationError::NotEnoughData);
}
Ok(ServerAuthResponse {
success: buffer[0] == 0,
})
}
pub async fn write<W: AsyncWrite + Send + Unpin>(
&self,
w: &mut W,
) -> Result<(), SerializationError> {
w.write_all(&[1]).await?;
w.write_all(&[if self.success { 0x00 } else { 0xde }])
.await?;
Ok(())
}
}
#[cfg(test)]
impl Arbitrary for ServerAuthResponse {
fn arbitrary(g: &mut Gen) -> ServerAuthResponse {
let success = bool::arbitrary(g);
ServerAuthResponse { success }
}
}
#[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AuthenticationMethod {
None,
GSSAPI,
UsernameAndPassword,
ChallengeHandshake,
ChallengeResponse,
SSL,
NDS,
MultiAuthenticationFramework,
JSONPropertyBlock,
PrivateMethod(u8),
NoAcceptableMethods,
}
impl fmt::Display for AuthenticationMethod {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
AuthenticationMethod::None => write!(f, "No authentication"),
AuthenticationMethod::GSSAPI => write!(f, "GSS-API"),
AuthenticationMethod::UsernameAndPassword => write!(f, "Username and password"),
AuthenticationMethod::ChallengeHandshake => write!(f, "Challenge/Handshake"),
AuthenticationMethod::ChallengeResponse => write!(f, "Challenge/Response"),
AuthenticationMethod::SSL => write!(f, "SSL"),
AuthenticationMethod::NDS => write!(f, "NDS Authentication"),
AuthenticationMethod::MultiAuthenticationFramework => {
write!(f, "Multi-Authentication Framework")
}
AuthenticationMethod::JSONPropertyBlock => write!(f, "JSON Property Block"),
AuthenticationMethod::PrivateMethod(m) => write!(f, "Private Method {:x}", m),
AuthenticationMethod::NoAcceptableMethods => write!(f, "No Acceptable Methods"),
}
}
}
impl AuthenticationMethod {
pub async fn read<R: AsyncRead + Send + Unpin>(
mut r: Pin<&mut R>,
) -> Result<AuthenticationMethod, DeserializationError> {
let mut byte_buffer = [0u8; 1];
let amount_read = r.read(&mut byte_buffer).await?;
if amount_read == 0 {
return Err(AuthenticationDeserializationError::NoDataFound.into());
}
match byte_buffer[0] {
0 => Ok(AuthenticationMethod::None),
1 => Ok(AuthenticationMethod::GSSAPI),
2 => Ok(AuthenticationMethod::UsernameAndPassword),
3 => Ok(AuthenticationMethod::ChallengeHandshake),
5 => Ok(AuthenticationMethod::ChallengeResponse),
6 => Ok(AuthenticationMethod::SSL),
7 => Ok(AuthenticationMethod::NDS),
8 => Ok(AuthenticationMethod::MultiAuthenticationFramework),
9 => Ok(AuthenticationMethod::JSONPropertyBlock),
x if (0x80..=0xfe).contains(&x) => Ok(AuthenticationMethod::PrivateMethod(x)),
0xff => Ok(AuthenticationMethod::NoAcceptableMethods),
e => Err(AuthenticationDeserializationError::InvalidAuthenticationByte(e).into()),
}
}
pub async fn write<W: AsyncWrite + Send + Unpin>(
&self,
w: &mut W,
) -> Result<(), SerializationError> {
let value = match self {
AuthenticationMethod::None => 0,
AuthenticationMethod::GSSAPI => 1,
AuthenticationMethod::UsernameAndPassword => 2,
AuthenticationMethod::ChallengeHandshake => 3,
AuthenticationMethod::ChallengeResponse => 5,
AuthenticationMethod::SSL => 6,
AuthenticationMethod::NDS => 7,
AuthenticationMethod::MultiAuthenticationFramework => 8,
AuthenticationMethod::JSONPropertyBlock => 9,
AuthenticationMethod::PrivateMethod(pm) => *pm,
AuthenticationMethod::NoAcceptableMethods => 0xff,
};
Ok(w.write_all(&[value]).await?)
}
}
#[cfg(test)]
impl Arbitrary for AuthenticationMethod {
fn arbitrary(g: &mut Gen) -> AuthenticationMethod {
let mut vals = vec![
AuthenticationMethod::None,
AuthenticationMethod::GSSAPI,
AuthenticationMethod::UsernameAndPassword,
AuthenticationMethod::ChallengeHandshake,
AuthenticationMethod::ChallengeResponse,
AuthenticationMethod::SSL,
AuthenticationMethod::NDS,
AuthenticationMethod::MultiAuthenticationFramework,
AuthenticationMethod::JSONPropertyBlock,
AuthenticationMethod::NoAcceptableMethods,
];
for x in 0x80..0xffu8 {
vals.push(AuthenticationMethod::PrivateMethod(x));
}
g.choose(&vals).unwrap().clone()
}
}
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ClientConnectionCommand {
EstablishTCPStream,
EstablishTCPPortBinding,
AssociateUDPPort,
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ClientConnectionRequest {
pub command_code: ClientConnectionCommand,
pub destination_address: SOCKSv5Address,
pub destination_port: u16,
}
impl ClientConnectionRequest {
pub async fn read<R: AsyncRead + Send + Unpin>(
r: Pin<&mut R>,
) -> Result<Self, DeserializationError> {
let mut buffer = [0; 2];
let raw_r = Pin::into_inner(r);
read_amt(Pin::new(raw_r), 2, &mut buffer).await?;
if buffer[0] != 5 {
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
}
let command_code = match buffer[1] {
0x01 => ClientConnectionCommand::EstablishTCPStream,
0x02 => ClientConnectionCommand::EstablishTCPPortBinding,
0x03 => ClientConnectionCommand::AssociateUDPPort,
x => return Err(DeserializationError::InvalidClientCommand(x)),
};
let destination_address = SOCKSv5Address::read(Pin::new(raw_r)).await?;
read_amt(Pin::new(raw_r), 2, &mut buffer).await?;
let destination_port = ((buffer[0] as u16) << 8) + (buffer[1] as u16);
Ok(ClientConnectionRequest {
command_code,
destination_address,
destination_port,
})
}
pub async fn write<W: AsyncWrite + Send + Unpin>(
&self,
w: &mut W,
) -> Result<(), SerializationError> {
let command = match self.command_code {
ClientConnectionCommand::EstablishTCPStream => 1,
ClientConnectionCommand::EstablishTCPPortBinding => 2,
ClientConnectionCommand::AssociateUDPPort => 3,
};
w.write_all(&[5, command]).await?;
self.destination_address.write(w).await?;
w.write_all(&[
(self.destination_port >> 8) as u8,
(self.destination_port & 0xffu16) as u8,
])
.await
.map_err(SerializationError::IOError)
}
}
#[cfg(test)]
impl Arbitrary for ClientConnectionCommand {
fn arbitrary(g: &mut Gen) -> ClientConnectionCommand {
g.choose(&[
ClientConnectionCommand::EstablishTCPStream,
ClientConnectionCommand::EstablishTCPPortBinding,
ClientConnectionCommand::AssociateUDPPort,
])
.unwrap()
.clone()
}
}
#[cfg(test)]
impl Arbitrary for ClientConnectionRequest {
fn arbitrary(g: &mut Gen) -> Self {
let command_code = ClientConnectionCommand::arbitrary(g);
let destination_address = SOCKSv5Address::arbitrary(g);
let destination_port = u16::arbitrary(g);
ClientConnectionRequest {
command_code,
destination_address,
destination_port,
}
}
}
#[derive(Clone, Debug, Eq, Error, PartialEq)]
pub enum ServerResponseStatus {
#[error("Actually, everything's fine (weird to see this in an error)")]
RequestGranted,
#[error("General server failure")]
GeneralFailure,
#[error("Connection not allowed by policy rule")]
ConnectionNotAllowedByRule,
#[error("Network unreachable")]
NetworkUnreachable,
#[error("Host unreachable")]
HostUnreachable,
#[error("Connection refused")]
ConnectionRefused,
#[error("TTL expired")]
TTLExpired,
#[error("Command not supported")]
CommandNotSupported,
#[error("Address type not supported")]
AddressTypeNotSupported,
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ServerResponse {
pub status: ServerResponseStatus,
pub bound_address: SOCKSv5Address,
pub bound_port: u16,
}
impl ServerResponse {
pub fn error<E: Into<ServerResponseStatus>>(resp: E) -> ServerResponse {
ServerResponse {
status: resp.into(),
bound_address: SOCKSv5Address::IP4(Ipv4Addr::new(0, 0, 0, 0)),
bound_port: 0,
}
}
}
impl ServerResponse {
pub async fn read<R: AsyncRead + Send + Unpin>(
r: Pin<&mut R>,
) -> Result<Self, DeserializationError> {
let mut buffer = [0; 3];
let raw_r = Pin::into_inner(r);
read_amt(Pin::new(raw_r), 3, &mut buffer).await?;
if buffer[0] != 5 {
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
}
if buffer[2] != 0 {
warn!(target: "async-socks5", "Hey, this isn't terrible, but the server is sending invalid reserved bytes.");
}
let status = match buffer[1] {
0x00 => ServerResponseStatus::RequestGranted,
0x01 => ServerResponseStatus::GeneralFailure,
0x02 => ServerResponseStatus::ConnectionNotAllowedByRule,
0x03 => ServerResponseStatus::NetworkUnreachable,
0x04 => ServerResponseStatus::HostUnreachable,
0x05 => ServerResponseStatus::ConnectionRefused,
0x06 => ServerResponseStatus::TTLExpired,
0x07 => ServerResponseStatus::CommandNotSupported,
0x08 => ServerResponseStatus::AddressTypeNotSupported,
x => return Err(DeserializationError::InvalidServerResponse(x)),
};
let bound_address = SOCKSv5Address::read(Pin::new(raw_r)).await?;
read_amt(Pin::new(raw_r), 2, &mut buffer).await?;
let bound_port = ((buffer[0] as u16) << 8) + (buffer[1] as u16);
Ok(ServerResponse {
status,
bound_address,
bound_port,
})
}
pub async fn write<W: AsyncWrite + Send + Unpin>(
&self,
w: &mut W,
) -> Result<(), SerializationError> {
let status_code = match self.status {
ServerResponseStatus::RequestGranted => 0x00,
ServerResponseStatus::GeneralFailure => 0x01,
ServerResponseStatus::ConnectionNotAllowedByRule => 0x02,
ServerResponseStatus::NetworkUnreachable => 0x03,
ServerResponseStatus::HostUnreachable => 0x04,
ServerResponseStatus::ConnectionRefused => 0x05,
ServerResponseStatus::TTLExpired => 0x06,
ServerResponseStatus::CommandNotSupported => 0x07,
ServerResponseStatus::AddressTypeNotSupported => 0x08,
};
w.write_all(&[5, status_code, 0]).await?;
self.bound_address.write(w).await?;
w.write_all(&[
(self.bound_port >> 8) as u8,
(self.bound_port & 0xffu16) as u8,
])
.await
.map_err(SerializationError::IOError)
}
}
#[cfg(test)]
impl Arbitrary for ServerResponseStatus {
fn arbitrary(g: &mut Gen) -> ServerResponseStatus {
g.choose(&[
ServerResponseStatus::RequestGranted,
ServerResponseStatus::GeneralFailure,
ServerResponseStatus::ConnectionNotAllowedByRule,
ServerResponseStatus::NetworkUnreachable,
ServerResponseStatus::HostUnreachable,
ServerResponseStatus::ConnectionRefused,
ServerResponseStatus::TTLExpired,
ServerResponseStatus::CommandNotSupported,
ServerResponseStatus::AddressTypeNotSupported,
])
.unwrap()
.clone()
}
}
#[cfg(test)]
impl Arbitrary for ServerResponse {
fn arbitrary(g: &mut Gen) -> Self {
let status = ServerResponseStatus::arbitrary(g);
let bound_address = SOCKSv5Address::arbitrary(g);
let bound_port = u16::arbitrary(g);
ServerResponse {
status,
bound_address,
bound_port,
}
}
}
macro_rules! standard_roundtrip {
($name: ident, $t: ty) => {
#[cfg(test)]
quickcheck! {
fn $name(xs: $t) -> bool {
let mut buffer = vec![];
task::block_on(xs.write(&mut buffer)).unwrap();
let mut cursor = Cursor::new(buffer);
let ys = <$t>::read(Pin::new(&mut cursor));
xs == task::block_on(ys).unwrap()
}
}
};
}
standard_roundtrip!(auth_byte_roundtrips, AuthenticationMethod);
standard_roundtrip!(client_greeting_roundtrips, ClientGreeting);
standard_roundtrip!(server_choice_roundtrips, ServerChoice);
standard_roundtrip!(username_password_roundtrips, ClientUsernamePassword);
standard_roundtrip!(server_auth_response, ServerAuthResponse);
standard_roundtrip!(address_roundtrips, SOCKSv5Address);
standard_roundtrip!(client_request_roundtrips, ClientConnectionRequest);
standard_roundtrip!(server_response_roundtrips, ServerResponse);

View File

@@ -0,0 +1,156 @@
use crate::errors::{AuthenticationDeserializationError, DeserializationError, SerializationError};
use crate::standard_roundtrip;
#[cfg(test)]
use async_std::task;
#[cfg(test)]
use futures::io::Cursor;
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
#[cfg(test)]
use quickcheck::{quickcheck, Arbitrary, Gen};
use std::fmt;
use std::pin::Pin;
#[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AuthenticationMethod {
None,
GSSAPI,
UsernameAndPassword,
ChallengeHandshake,
ChallengeResponse,
SSL,
NDS,
MultiAuthenticationFramework,
JSONPropertyBlock,
PrivateMethod(u8),
NoAcceptableMethods,
}
impl fmt::Display for AuthenticationMethod {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
AuthenticationMethod::None => write!(f, "No authentication"),
AuthenticationMethod::GSSAPI => write!(f, "GSS-API"),
AuthenticationMethod::UsernameAndPassword => write!(f, "Username and password"),
AuthenticationMethod::ChallengeHandshake => write!(f, "Challenge/Handshake"),
AuthenticationMethod::ChallengeResponse => write!(f, "Challenge/Response"),
AuthenticationMethod::SSL => write!(f, "SSL"),
AuthenticationMethod::NDS => write!(f, "NDS Authentication"),
AuthenticationMethod::MultiAuthenticationFramework => {
write!(f, "Multi-Authentication Framework")
}
AuthenticationMethod::JSONPropertyBlock => write!(f, "JSON Property Block"),
AuthenticationMethod::PrivateMethod(m) => write!(f, "Private Method {:x}", m),
AuthenticationMethod::NoAcceptableMethods => write!(f, "No Acceptable Methods"),
}
}
}
impl AuthenticationMethod {
pub async fn read<R: AsyncRead + Send + Unpin>(
mut r: Pin<&mut R>,
) -> Result<AuthenticationMethod, DeserializationError> {
let mut byte_buffer = [0u8; 1];
let amount_read = r.read(&mut byte_buffer).await?;
if amount_read == 0 {
return Err(AuthenticationDeserializationError::NoDataFound.into());
}
match byte_buffer[0] {
0 => Ok(AuthenticationMethod::None),
1 => Ok(AuthenticationMethod::GSSAPI),
2 => Ok(AuthenticationMethod::UsernameAndPassword),
3 => Ok(AuthenticationMethod::ChallengeHandshake),
5 => Ok(AuthenticationMethod::ChallengeResponse),
6 => Ok(AuthenticationMethod::SSL),
7 => Ok(AuthenticationMethod::NDS),
8 => Ok(AuthenticationMethod::MultiAuthenticationFramework),
9 => Ok(AuthenticationMethod::JSONPropertyBlock),
x if (0x80..=0xfe).contains(&x) => Ok(AuthenticationMethod::PrivateMethod(x)),
0xff => Ok(AuthenticationMethod::NoAcceptableMethods),
e => Err(AuthenticationDeserializationError::InvalidAuthenticationByte(e).into()),
}
}
pub async fn write<W: AsyncWrite + Send + Unpin>(
&self,
w: &mut W,
) -> Result<(), SerializationError> {
let value = match self {
AuthenticationMethod::None => 0,
AuthenticationMethod::GSSAPI => 1,
AuthenticationMethod::UsernameAndPassword => 2,
AuthenticationMethod::ChallengeHandshake => 3,
AuthenticationMethod::ChallengeResponse => 5,
AuthenticationMethod::SSL => 6,
AuthenticationMethod::NDS => 7,
AuthenticationMethod::MultiAuthenticationFramework => 8,
AuthenticationMethod::JSONPropertyBlock => 9,
AuthenticationMethod::PrivateMethod(pm) => *pm,
AuthenticationMethod::NoAcceptableMethods => 0xff,
};
Ok(w.write_all(&[value]).await?)
}
}
#[cfg(test)]
impl Arbitrary for AuthenticationMethod {
fn arbitrary(g: &mut Gen) -> AuthenticationMethod {
let mut vals = vec![
AuthenticationMethod::None,
AuthenticationMethod::GSSAPI,
AuthenticationMethod::UsernameAndPassword,
AuthenticationMethod::ChallengeHandshake,
AuthenticationMethod::ChallengeResponse,
AuthenticationMethod::SSL,
AuthenticationMethod::NDS,
AuthenticationMethod::MultiAuthenticationFramework,
AuthenticationMethod::JSONPropertyBlock,
AuthenticationMethod::NoAcceptableMethods,
];
for x in 0x80..0xffu8 {
vals.push(AuthenticationMethod::PrivateMethod(x));
}
g.choose(&vals).unwrap().clone()
}
}
standard_roundtrip!(auth_byte_roundtrips, AuthenticationMethod);
#[test]
fn bad_byte() {
let no_len = vec![42];
let mut cursor = Cursor::new(no_len);
let ys = AuthenticationMethod::read(Pin::new(&mut cursor));
assert_eq!(
Err(DeserializationError::AuthenticationMethodError(
AuthenticationDeserializationError::InvalidAuthenticationByte(42)
)),
task::block_on(ys)
);
}
#[test]
fn display_isnt_empty() {
let vals = vec![
AuthenticationMethod::None,
AuthenticationMethod::GSSAPI,
AuthenticationMethod::UsernameAndPassword,
AuthenticationMethod::ChallengeHandshake,
AuthenticationMethod::ChallengeResponse,
AuthenticationMethod::SSL,
AuthenticationMethod::NDS,
AuthenticationMethod::MultiAuthenticationFramework,
AuthenticationMethod::JSONPropertyBlock,
AuthenticationMethod::NoAcceptableMethods,
AuthenticationMethod::PrivateMethod(42),
];
for method in vals.iter() {
let str = format!("{}", method);
assert!(str.is_ascii());
assert!(!str.is_empty());
}
}

View File

@@ -0,0 +1,164 @@
use crate::errors::{DeserializationError, SerializationError};
use crate::network::SOCKSv5Address;
use crate::serialize::read_amt;
use crate::standard_roundtrip;
#[cfg(test)]
use async_std::io::ErrorKind;
#[cfg(test)]
use async_std::task;
#[cfg(test)]
use futures::io::Cursor;
use futures::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
#[cfg(test)]
use quickcheck::{quickcheck, Arbitrary, Gen};
#[cfg(test)]
use std::net::Ipv4Addr;
use std::pin::Pin;
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ClientConnectionCommand {
EstablishTCPStream,
EstablishTCPPortBinding,
AssociateUDPPort,
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ClientConnectionRequest {
pub command_code: ClientConnectionCommand,
pub destination_address: SOCKSv5Address,
pub destination_port: u16,
}
impl ClientConnectionRequest {
pub async fn read<R: AsyncRead + Send + Unpin>(
r: Pin<&mut R>,
) -> Result<Self, DeserializationError> {
let mut buffer = [0; 2];
let raw_r = Pin::into_inner(r);
read_amt(Pin::new(raw_r), 2, &mut buffer).await?;
if buffer[0] != 5 {
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
}
let command_code = match buffer[1] {
0x01 => ClientConnectionCommand::EstablishTCPStream,
0x02 => ClientConnectionCommand::EstablishTCPPortBinding,
0x03 => ClientConnectionCommand::AssociateUDPPort,
x => return Err(DeserializationError::InvalidClientCommand(x)),
};
let destination_address = SOCKSv5Address::read(Pin::new(raw_r)).await?;
read_amt(Pin::new(raw_r), 2, &mut buffer).await?;
let destination_port = ((buffer[0] as u16) << 8) + (buffer[1] as u16);
Ok(ClientConnectionRequest {
command_code,
destination_address,
destination_port,
})
}
pub async fn write<W: AsyncWrite + Send + Unpin>(
&self,
w: &mut W,
) -> Result<(), SerializationError> {
let command = match self.command_code {
ClientConnectionCommand::EstablishTCPStream => 1,
ClientConnectionCommand::EstablishTCPPortBinding => 2,
ClientConnectionCommand::AssociateUDPPort => 3,
};
w.write_all(&[5, command]).await?;
self.destination_address.write(w).await?;
w.write_all(&[
(self.destination_port >> 8) as u8,
(self.destination_port & 0xffu16) as u8,
])
.await
.map_err(SerializationError::IOError)
}
}
#[cfg(test)]
impl Arbitrary for ClientConnectionCommand {
fn arbitrary(g: &mut Gen) -> ClientConnectionCommand {
let options = [
ClientConnectionCommand::EstablishTCPStream,
ClientConnectionCommand::EstablishTCPPortBinding,
ClientConnectionCommand::AssociateUDPPort,
];
g.choose(&options).unwrap().clone()
}
}
#[cfg(test)]
impl Arbitrary for ClientConnectionRequest {
fn arbitrary(g: &mut Gen) -> Self {
let command_code = ClientConnectionCommand::arbitrary(g);
let destination_address = SOCKSv5Address::arbitrary(g);
let destination_port = u16::arbitrary(g);
ClientConnectionRequest {
command_code,
destination_address,
destination_port,
}
}
}
standard_roundtrip!(client_request_roundtrips, ClientConnectionRequest);
#[test]
fn check_short_reads() {
let empty = vec![];
let mut cursor = Cursor::new(empty);
let ys = ClientConnectionRequest::read(Pin::new(&mut cursor));
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
let no_len = vec![5, 1];
let mut cursor = Cursor::new(no_len);
let ys = ClientConnectionRequest::read(Pin::new(&mut cursor));
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
}
#[test]
fn check_bad_version() {
let bad_ver = vec![6, 1, 1];
let mut cursor = Cursor::new(bad_ver);
let ys = ClientConnectionRequest::read(Pin::new(&mut cursor));
assert_eq!(
Err(DeserializationError::InvalidVersion(5, 6)),
task::block_on(ys)
);
}
#[test]
fn check_bad_command() {
let bad_cmd = vec![5, 32, 1];
let mut cursor = Cursor::new(bad_cmd);
let ys = ClientConnectionRequest::read(Pin::new(&mut cursor));
assert_eq!(
Err(DeserializationError::InvalidClientCommand(32)),
task::block_on(ys)
);
}
#[test]
fn short_write_fails_right() {
let mut buffer = [0u8; 2];
let cmd = ClientConnectionRequest {
command_code: ClientConnectionCommand::AssociateUDPPort,
destination_address: SOCKSv5Address::IP4(Ipv4Addr::from(0)),
destination_port: 22,
};
let mut cursor = Cursor::new(&mut buffer as &mut [u8]);
let result = task::block_on(cmd.write(&mut cursor));
match result {
Ok(_) => assert!(false, "Mysteriously able to fit > 2 bytes in 2 bytes."),
Err(SerializationError::IOError(x)) => assert_eq!(ErrorKind::WriteZero, x.kind()),
Err(e) => assert!(false, "Got the wrong error writing too much data: {}", e),
}
}

View File

@@ -0,0 +1,135 @@
#[cfg(test)]
use crate::errors::AuthenticationDeserializationError;
use crate::errors::{DeserializationError, SerializationError};
use crate::messages::AuthenticationMethod;
use crate::standard_roundtrip;
#[cfg(test)]
use async_std::task;
#[cfg(test)]
use futures::io::Cursor;
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
#[cfg(test)]
use quickcheck::{quickcheck, Arbitrary, Gen};
use std::pin::Pin;
/// Client greetings are the first message sent in a SOCKSv5 session. They
/// identify that there's a client that wants to talk to a server, and that
/// they can support any of the provided mechanisms for authenticating to
/// said server. (It feels weird that the offer/choice goes this way instead
/// of the reverse, but whatever.)
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ClientGreeting {
pub acceptable_methods: Vec<AuthenticationMethod>,
}
impl ClientGreeting {
pub async fn read<R: AsyncRead + Send + Unpin>(
r: Pin<&mut R>,
) -> Result<ClientGreeting, DeserializationError> {
let mut buffer = [0; 1];
let raw_r = Pin::into_inner(r);
if raw_r.read(&mut buffer).await? == 0 {
return Err(DeserializationError::NotEnoughData);
}
if buffer[0] != 5 {
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
}
if raw_r.read(&mut buffer).await? == 0 {
return Err(DeserializationError::NotEnoughData);
}
let mut acceptable_methods = Vec::with_capacity(buffer[0] as usize);
for _ in 0..buffer[0] {
acceptable_methods.push(AuthenticationMethod::read(Pin::new(raw_r)).await?);
}
Ok(ClientGreeting { acceptable_methods })
}
pub async fn write<W: AsyncWrite + Send + Unpin>(
&self,
w: &mut W,
) -> Result<(), SerializationError> {
if self.acceptable_methods.len() > 255 {
return Err(SerializationError::TooManyAuthMethods(
self.acceptable_methods.len(),
));
}
let mut buffer = Vec::with_capacity(self.acceptable_methods.len() + 2);
buffer.push(5);
buffer.push(self.acceptable_methods.len() as u8);
w.write_all(&buffer).await?;
for authmeth in self.acceptable_methods.iter() {
authmeth.write(w).await?;
}
Ok(())
}
}
#[cfg(test)]
impl Arbitrary for ClientGreeting {
fn arbitrary(g: &mut Gen) -> ClientGreeting {
let amt = u8::arbitrary(g);
let mut acceptable_methods = Vec::with_capacity(amt as usize);
for _ in 0..amt {
acceptable_methods.push(AuthenticationMethod::arbitrary(g));
}
ClientGreeting { acceptable_methods }
}
}
standard_roundtrip!(client_greeting_roundtrips, ClientGreeting);
#[test]
fn check_short_reads() {
let empty = vec![];
let mut cursor = Cursor::new(empty);
let ys = ClientGreeting::read(Pin::new(&mut cursor));
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
let no_len = vec![5];
let mut cursor = Cursor::new(no_len);
let ys = ClientGreeting::read(Pin::new(&mut cursor));
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
let bad_len = vec![5, 9];
let mut cursor = Cursor::new(bad_len);
let ys = ClientGreeting::read(Pin::new(&mut cursor));
assert_eq!(
Err(DeserializationError::AuthenticationMethodError(
AuthenticationDeserializationError::NoDataFound
)),
task::block_on(ys)
);
}
#[test]
fn check_bad_version() {
let no_len = vec![6, 1, 1];
let mut cursor = Cursor::new(no_len);
let ys = ClientGreeting::read(Pin::new(&mut cursor));
assert_eq!(
Err(DeserializationError::InvalidVersion(5, 6)),
task::block_on(ys)
);
}
#[test]
fn check_too_many() {
let mut auth_methods = Vec::with_capacity(512);
auth_methods.resize(512, AuthenticationMethod::ChallengeHandshake);
let greet = ClientGreeting {
acceptable_methods: auth_methods,
};
let mut output = vec![0; 1024];
assert_eq!(
Err(SerializationError::TooManyAuthMethods(512)),
task::block_on(greet.write(&mut output))
);
}

View File

@@ -0,0 +1,86 @@
use crate::errors::{DeserializationError, SerializationError};
#[cfg(test)]
use crate::messages::utils::arbitrary_socks_string;
use crate::serialize::{read_string, write_string};
use crate::standard_roundtrip;
#[cfg(test)]
use async_std::task;
#[cfg(test)]
use futures::io::Cursor;
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
#[cfg(test)]
use quickcheck::{quickcheck, Arbitrary, Gen};
use std::pin::Pin;
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ClientUsernamePassword {
pub username: String,
pub password: String,
}
impl ClientUsernamePassword {
pub async fn read<R: AsyncRead + Send + Unpin>(
r: Pin<&mut R>,
) -> Result<Self, DeserializationError> {
let mut buffer = [0; 1];
let raw_r = Pin::into_inner(r);
if raw_r.read(&mut buffer).await? == 0 {
return Err(DeserializationError::NotEnoughData);
}
if buffer[0] != 1 {
return Err(DeserializationError::InvalidVersion(1, buffer[0]));
}
let username = read_string(Pin::new(raw_r)).await?;
let password = read_string(Pin::new(raw_r)).await?;
Ok(ClientUsernamePassword { username, password })
}
pub async fn write<W: AsyncWrite + Send + Unpin>(
&self,
w: &mut W,
) -> Result<(), SerializationError> {
w.write_all(&[1]).await?;
write_string(&self.username, w).await?;
write_string(&self.password, w).await
}
}
#[cfg(test)]
impl Arbitrary for ClientUsernamePassword {
fn arbitrary(g: &mut Gen) -> Self {
let username = arbitrary_socks_string(g);
let password = arbitrary_socks_string(g);
ClientUsernamePassword { username, password }
}
}
standard_roundtrip!(username_password_roundtrips, ClientUsernamePassword);
#[test]
fn check_short_reads() {
let empty = vec![];
let mut cursor = Cursor::new(empty);
let ys = ClientUsernamePassword::read(Pin::new(&mut cursor));
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
let user_only = vec![1, 3, 102, 111, 111];
let mut cursor = Cursor::new(user_only);
let ys = ClientUsernamePassword::read(Pin::new(&mut cursor));
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
}
#[test]
fn check_bad_version() {
let bad_len = vec![5];
let mut cursor = Cursor::new(bad_len);
let ys = ClientUsernamePassword::read(Pin::new(&mut cursor));
assert_eq!(
Err(DeserializationError::InvalidVersion(1, 5)),
task::block_on(ys)
);
}

View File

@@ -0,0 +1,83 @@
use crate::errors::{DeserializationError, SerializationError};
use crate::standard_roundtrip;
#[cfg(test)]
use async_std::task;
#[cfg(test)]
use futures::io::Cursor;
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
#[cfg(test)]
use quickcheck::{quickcheck, Arbitrary, Gen};
use std::pin::Pin;
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ServerAuthResponse {
pub success: bool,
}
impl ServerAuthResponse {
pub async fn read<R: AsyncRead + Send + Unpin>(
mut r: Pin<&mut R>,
) -> Result<Self, DeserializationError> {
let mut buffer = [0; 1];
if r.read(&mut buffer).await? == 0 {
return Err(DeserializationError::NotEnoughData);
}
if buffer[0] != 1 {
return Err(DeserializationError::InvalidVersion(1, buffer[0]));
}
if r.read(&mut buffer).await? == 0 {
return Err(DeserializationError::NotEnoughData);
}
Ok(ServerAuthResponse {
success: buffer[0] == 0,
})
}
pub async fn write<W: AsyncWrite + Send + Unpin>(
&self,
w: &mut W,
) -> Result<(), SerializationError> {
w.write_all(&[1]).await?;
w.write_all(&[if self.success { 0x00 } else { 0xde }])
.await?;
Ok(())
}
}
#[cfg(test)]
impl Arbitrary for ServerAuthResponse {
fn arbitrary(g: &mut Gen) -> ServerAuthResponse {
let success = bool::arbitrary(g);
ServerAuthResponse { success }
}
}
standard_roundtrip!(server_auth_response, ServerAuthResponse);
#[test]
fn check_short_reads() {
let empty = vec![];
let mut cursor = Cursor::new(empty);
let ys = ServerAuthResponse::read(Pin::new(&mut cursor));
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
let no_len = vec![1];
let mut cursor = Cursor::new(no_len);
let ys = ServerAuthResponse::read(Pin::new(&mut cursor));
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
}
#[test]
fn check_bad_version() {
let no_len = vec![6, 1];
let mut cursor = Cursor::new(no_len);
let ys = ServerAuthResponse::read(Pin::new(&mut cursor));
assert_eq!(
Err(DeserializationError::InvalidVersion(1, 6)),
task::block_on(ys)
);
}

View File

@@ -0,0 +1,86 @@
#[cfg(test)]
use crate::errors::AuthenticationDeserializationError;
use crate::errors::{DeserializationError, SerializationError};
use crate::messages::AuthenticationMethod;
use crate::standard_roundtrip;
#[cfg(test)]
use async_std::task;
#[cfg(test)]
use futures::io::Cursor;
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
#[cfg(test)]
use quickcheck::{quickcheck, Arbitrary, Gen};
use std::pin::Pin;
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ServerChoice {
pub chosen_method: AuthenticationMethod,
}
impl ServerChoice {
pub async fn read<R: AsyncRead + Send + Unpin>(
mut r: Pin<&mut R>,
) -> Result<Self, DeserializationError> {
let mut buffer = [0; 1];
if r.read(&mut buffer).await? == 0 {
return Err(DeserializationError::NotEnoughData);
}
if buffer[0] != 5 {
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
}
let chosen_method = AuthenticationMethod::read(r).await?;
Ok(ServerChoice { chosen_method })
}
pub async fn write<W: AsyncWrite + Send + Unpin>(
&self,
w: &mut W,
) -> Result<(), SerializationError> {
w.write_all(&[5]).await?;
self.chosen_method.write(w).await
}
}
#[cfg(test)]
impl Arbitrary for ServerChoice {
fn arbitrary(g: &mut Gen) -> ServerChoice {
ServerChoice {
chosen_method: AuthenticationMethod::arbitrary(g),
}
}
}
standard_roundtrip!(server_choice_roundtrips, ServerChoice);
#[test]
fn check_short_reads() {
let empty = vec![];
let mut cursor = Cursor::new(empty);
let ys = ServerChoice::read(Pin::new(&mut cursor));
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
let bad_len = vec![5];
let mut cursor = Cursor::new(bad_len);
let ys = ServerChoice::read(Pin::new(&mut cursor));
assert_eq!(
Err(DeserializationError::AuthenticationMethodError(
AuthenticationDeserializationError::NoDataFound
)),
task::block_on(ys)
);
}
#[test]
fn check_bad_version() {
let no_len = vec![9, 1];
let mut cursor = Cursor::new(no_len);
let ys = ServerChoice::read(Pin::new(&mut cursor));
assert_eq!(
Err(DeserializationError::InvalidVersion(5, 9)),
task::block_on(ys)
);
}

View File

@@ -0,0 +1,202 @@
use crate::errors::{DeserializationError, SerializationError};
use crate::network::SOCKSv5Address;
use crate::serialize::read_amt;
use crate::standard_roundtrip;
#[cfg(test)]
use async_std::io::ErrorKind;
#[cfg(test)]
use async_std::task;
#[cfg(test)]
use futures::io::Cursor;
use futures::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use log::warn;
#[cfg(test)]
use quickcheck::{quickcheck, Arbitrary, Gen};
use std::net::Ipv4Addr;
use std::pin::Pin;
use thiserror::Error;
#[derive(Clone, Debug, Eq, Error, PartialEq)]
pub enum ServerResponseStatus {
#[error("Actually, everything's fine (weird to see this in an error)")]
RequestGranted,
#[error("General server failure")]
GeneralFailure,
#[error("Connection not allowed by policy rule")]
ConnectionNotAllowedByRule,
#[error("Network unreachable")]
NetworkUnreachable,
#[error("Host unreachable")]
HostUnreachable,
#[error("Connection refused")]
ConnectionRefused,
#[error("TTL expired")]
TTLExpired,
#[error("Command not supported")]
CommandNotSupported,
#[error("Address type not supported")]
AddressTypeNotSupported,
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ServerResponse {
pub status: ServerResponseStatus,
pub bound_address: SOCKSv5Address,
pub bound_port: u16,
}
impl ServerResponse {
pub fn error<E: Into<ServerResponseStatus>>(resp: E) -> ServerResponse {
ServerResponse {
status: resp.into(),
bound_address: SOCKSv5Address::IP4(Ipv4Addr::new(0, 0, 0, 0)),
bound_port: 0,
}
}
}
impl ServerResponse {
pub async fn read<R: AsyncRead + Send + Unpin>(
r: Pin<&mut R>,
) -> Result<Self, DeserializationError> {
let mut buffer = [0; 3];
let raw_r = Pin::into_inner(r);
read_amt(Pin::new(raw_r), 3, &mut buffer).await?;
if buffer[0] != 5 {
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
}
if buffer[2] != 0 {
warn!(target: "async-socks5", "Hey, this isn't terrible, but the server is sending invalid reserved bytes.");
}
let status = match buffer[1] {
0x00 => ServerResponseStatus::RequestGranted,
0x01 => ServerResponseStatus::GeneralFailure,
0x02 => ServerResponseStatus::ConnectionNotAllowedByRule,
0x03 => ServerResponseStatus::NetworkUnreachable,
0x04 => ServerResponseStatus::HostUnreachable,
0x05 => ServerResponseStatus::ConnectionRefused,
0x06 => ServerResponseStatus::TTLExpired,
0x07 => ServerResponseStatus::CommandNotSupported,
0x08 => ServerResponseStatus::AddressTypeNotSupported,
x => return Err(DeserializationError::InvalidServerResponse(x)),
};
let bound_address = SOCKSv5Address::read(Pin::new(raw_r)).await?;
read_amt(Pin::new(raw_r), 2, &mut buffer).await?;
let bound_port = ((buffer[0] as u16) << 8) + (buffer[1] as u16);
Ok(ServerResponse {
status,
bound_address,
bound_port,
})
}
pub async fn write<W: AsyncWrite + Send + Unpin>(
&self,
w: &mut W,
) -> Result<(), SerializationError> {
let status_code = match self.status {
ServerResponseStatus::RequestGranted => 0x00,
ServerResponseStatus::GeneralFailure => 0x01,
ServerResponseStatus::ConnectionNotAllowedByRule => 0x02,
ServerResponseStatus::NetworkUnreachable => 0x03,
ServerResponseStatus::HostUnreachable => 0x04,
ServerResponseStatus::ConnectionRefused => 0x05,
ServerResponseStatus::TTLExpired => 0x06,
ServerResponseStatus::CommandNotSupported => 0x07,
ServerResponseStatus::AddressTypeNotSupported => 0x08,
};
w.write_all(&[5, status_code, 0]).await?;
self.bound_address.write(w).await?;
w.write_all(&[
(self.bound_port >> 8) as u8,
(self.bound_port & 0xffu16) as u8,
])
.await
.map_err(SerializationError::IOError)
}
}
#[cfg(test)]
impl Arbitrary for ServerResponseStatus {
fn arbitrary(g: &mut Gen) -> ServerResponseStatus {
let options = [
ServerResponseStatus::RequestGranted,
ServerResponseStatus::GeneralFailure,
ServerResponseStatus::ConnectionNotAllowedByRule,
ServerResponseStatus::NetworkUnreachable,
ServerResponseStatus::HostUnreachable,
ServerResponseStatus::ConnectionRefused,
ServerResponseStatus::TTLExpired,
ServerResponseStatus::CommandNotSupported,
ServerResponseStatus::AddressTypeNotSupported,
];
g.choose(&options).unwrap().clone()
}
}
#[cfg(test)]
impl Arbitrary for ServerResponse {
fn arbitrary(g: &mut Gen) -> Self {
let status = ServerResponseStatus::arbitrary(g);
let bound_address = SOCKSv5Address::arbitrary(g);
let bound_port = u16::arbitrary(g);
ServerResponse {
status,
bound_address,
bound_port,
}
}
}
standard_roundtrip!(server_response_roundtrips, ServerResponse);
#[test]
fn check_short_reads() {
let empty = vec![];
let mut cursor = Cursor::new(empty);
let ys = ServerResponse::read(Pin::new(&mut cursor));
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
}
#[test]
fn check_bad_version() {
let bad_ver = vec![6, 1, 1];
let mut cursor = Cursor::new(bad_ver);
let ys = ServerResponse::read(Pin::new(&mut cursor));
assert_eq!(
Err(DeserializationError::InvalidVersion(5, 6)),
task::block_on(ys)
);
}
#[test]
fn check_bad_command() {
let bad_cmd = vec![5, 32, 0x42];
let mut cursor = Cursor::new(bad_cmd);
let ys = ServerResponse::read(Pin::new(&mut cursor));
assert_eq!(
Err(DeserializationError::InvalidServerResponse(32)),
task::block_on(ys)
);
}
#[test]
fn short_write_fails_right() {
let mut buffer = [0u8; 2];
let cmd = ServerResponse::error(ServerResponseStatus::AddressTypeNotSupported);
let mut cursor = Cursor::new(&mut buffer as &mut [u8]);
let result = task::block_on(cmd.write(&mut cursor));
match result {
Ok(_) => assert!(false, "Mysteriously able to fit > 2 bytes in 2 bytes."),
Err(SerializationError::IOError(x)) => assert_eq!(ErrorKind::WriteZero, x.kind()),
Err(e) => assert!(false, "Got the wrong error writing too much data: {}", e),
}
}

33
src/messages/utils.rs Normal file
View File

@@ -0,0 +1,33 @@
#[cfg(test)]
use quickcheck::{Arbitrary, Gen};
#[cfg(test)]
pub fn arbitrary_socks_string(g: &mut Gen) -> String {
loop {
let mut potential = String::arbitrary(g);
potential.truncate(255);
let bytestring = potential.as_bytes();
if bytestring.len() > 0 && bytestring.len() < 256 {
return potential;
}
}
}
#[doc(hidden)]
#[macro_export]
macro_rules! standard_roundtrip {
($name: ident, $t: ty) => {
#[cfg(test)]
quickcheck! {
fn $name(xs: $t) -> bool {
let mut buffer = vec![];
task::block_on(xs.write(&mut buffer)).unwrap();
let mut cursor = Cursor::new(buffer);
let ys = <$t>::read(Pin::new(&mut cursor));
xs == task::block_on(ys).unwrap()
}
}
};
}

View File

@@ -1,10 +1,15 @@
use crate::errors::{DeserializationError, SerializationError}; use crate::errors::{DeserializationError, SerializationError};
#[cfg(test)] #[cfg(test)]
use crate::messages::arbitrary_socks_string; use crate::messages::utils::arbitrary_socks_string;
use crate::serialize::{read_amt, read_string, write_string}; use crate::serialize::{read_amt, read_string, write_string};
use crate::standard_roundtrip;
#[cfg(test)]
use async_std::task;
#[cfg(test)]
use futures::io::Cursor;
use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
#[cfg(test)] #[cfg(test)]
use quickcheck::{Arbitrary, Gen}; use quickcheck::{quickcheck, Arbitrary, Gen};
use std::fmt; use std::fmt;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::pin::Pin; use std::pin::Pin;
@@ -149,3 +154,5 @@ impl Arbitrary for SOCKSv5Address {
.clone() .clone()
} }
} }
standard_roundtrip!(address_roundtrips, SOCKSv5Address);

View File

@@ -310,4 +310,4 @@ where
} }
} }
} }
} }