Split out the messages into individual files,, and add negative tests, so we can aspire towards good coverage.
This commit is contained in:
115
src/errors.rs
115
src/errors.rs
@@ -26,6 +26,69 @@ pub enum DeserializationError {
|
||||
InvalidServerResponse(u8),
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn des_error_reasonable_equals() {
|
||||
let invalid_version = DeserializationError::InvalidVersion(1, 2);
|
||||
assert_eq!(invalid_version, invalid_version);
|
||||
let not_enough = DeserializationError::NotEnoughData;
|
||||
assert_eq!(not_enough, not_enough);
|
||||
let invalid_empty = DeserializationError::InvalidEmptyString;
|
||||
assert_eq!(invalid_empty, invalid_empty);
|
||||
let auth_method = DeserializationError::AuthenticationMethodError(
|
||||
AuthenticationDeserializationError::NoDataFound,
|
||||
);
|
||||
assert_eq!(auth_method, auth_method);
|
||||
let utf8 = DeserializationError::UTF8Error(String::from_utf8(vec![0, 159]).unwrap_err());
|
||||
assert_eq!(utf8, utf8);
|
||||
let invalid_address = DeserializationError::InvalidAddressType(3);
|
||||
assert_eq!(invalid_address, invalid_address);
|
||||
let invalid_client_cmd = DeserializationError::InvalidClientCommand(32);
|
||||
assert_eq!(invalid_client_cmd, invalid_client_cmd);
|
||||
let invalid_server_resp = DeserializationError::InvalidServerResponse(42);
|
||||
assert_eq!(invalid_server_resp, invalid_server_resp);
|
||||
|
||||
assert_ne!(invalid_version, invalid_address);
|
||||
assert_ne!(not_enough, invalid_empty);
|
||||
assert_ne!(auth_method, invalid_client_cmd);
|
||||
assert_ne!(utf8, invalid_server_resp);
|
||||
}
|
||||
|
||||
impl PartialEq for DeserializationError {
|
||||
fn eq(&self, other: &DeserializationError) -> bool {
|
||||
match (self, other) {
|
||||
(
|
||||
&DeserializationError::InvalidVersion(a, b),
|
||||
&DeserializationError::InvalidVersion(x, y),
|
||||
) => (a == x) && (b == y),
|
||||
(&DeserializationError::NotEnoughData, &DeserializationError::NotEnoughData) => true,
|
||||
(
|
||||
&DeserializationError::InvalidEmptyString,
|
||||
&DeserializationError::InvalidEmptyString,
|
||||
) => true,
|
||||
(
|
||||
&DeserializationError::AuthenticationMethodError(ref a),
|
||||
&DeserializationError::AuthenticationMethodError(ref b),
|
||||
) => a == b,
|
||||
(&DeserializationError::UTF8Error(ref a), &DeserializationError::UTF8Error(ref b)) => {
|
||||
a == b
|
||||
}
|
||||
(
|
||||
&DeserializationError::InvalidAddressType(a),
|
||||
&DeserializationError::InvalidAddressType(b),
|
||||
) => a == b,
|
||||
(
|
||||
&DeserializationError::InvalidClientCommand(a),
|
||||
&DeserializationError::InvalidClientCommand(b),
|
||||
) => a == b,
|
||||
(
|
||||
&DeserializationError::InvalidServerResponse(a),
|
||||
&DeserializationError::InvalidServerResponse(b),
|
||||
) => a == b,
|
||||
(_, _) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// All the errors that can occur trying to turn SOCKSv5 message structures
|
||||
/// into raw bytes. There's a few places that the message structures allow
|
||||
/// for information that can't be serialized; often, you have to be careful
|
||||
@@ -40,6 +103,32 @@ pub enum SerializationError {
|
||||
IOError(#[from] io::Error),
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ser_err_reasonable_equals() {
|
||||
let too_many = SerializationError::TooManyAuthMethods(512);
|
||||
assert_eq!(too_many, too_many);
|
||||
let invalid_str = SerializationError::InvalidStringLength("Whoopsy!".to_string());
|
||||
assert_eq!(invalid_str, invalid_str);
|
||||
|
||||
assert_ne!(too_many, invalid_str);
|
||||
}
|
||||
|
||||
impl PartialEq for SerializationError {
|
||||
fn eq(&self, other: &SerializationError) -> bool {
|
||||
match (self, other) {
|
||||
(
|
||||
&SerializationError::TooManyAuthMethods(a),
|
||||
&SerializationError::TooManyAuthMethods(b),
|
||||
) => a == b,
|
||||
(
|
||||
&SerializationError::InvalidStringLength(ref a),
|
||||
&SerializationError::InvalidStringLength(ref b),
|
||||
) => a == b,
|
||||
(_, _) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum AuthenticationDeserializationError {
|
||||
#[error("No data found deserializing SOCKS authentication type")]
|
||||
@@ -49,3 +138,29 @@ pub enum AuthenticationDeserializationError {
|
||||
#[error("IO error reading SOCKS authentication type: {0}")]
|
||||
IOError(#[from] io::Error),
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_des_err_reasonable_equals() {
|
||||
let no_data = AuthenticationDeserializationError::NoDataFound;
|
||||
assert_eq!(no_data, no_data);
|
||||
let invalid_auth = AuthenticationDeserializationError::InvalidAuthenticationByte(39);
|
||||
assert_eq!(invalid_auth, invalid_auth);
|
||||
|
||||
assert_ne!(no_data, invalid_auth);
|
||||
}
|
||||
|
||||
impl PartialEq for AuthenticationDeserializationError {
|
||||
fn eq(&self, other: &AuthenticationDeserializationError) -> bool {
|
||||
match (self, other) {
|
||||
(
|
||||
&AuthenticationDeserializationError::NoDataFound,
|
||||
&AuthenticationDeserializationError::NoDataFound,
|
||||
) => true,
|
||||
(
|
||||
&AuthenticationDeserializationError::InvalidAuthenticationByte(x),
|
||||
&AuthenticationDeserializationError::InvalidAuthenticationByte(y),
|
||||
) => x == y,
|
||||
(_, _) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
620
src/messages.rs
620
src/messages.rs
@@ -1,604 +1,16 @@
|
||||
use crate::errors::{AuthenticationDeserializationError, DeserializationError, SerializationError};
|
||||
use crate::network::SOCKSv5Address;
|
||||
use crate::serialize::{read_amt, read_string, write_string};
|
||||
#[cfg(test)]
|
||||
use async_std::task;
|
||||
#[cfg(test)]
|
||||
use futures::io::Cursor;
|
||||
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use log::warn;
|
||||
#[cfg(test)]
|
||||
use quickcheck::{quickcheck, Arbitrary, Gen};
|
||||
use std::fmt;
|
||||
use std::net::Ipv4Addr;
|
||||
use std::pin::Pin;
|
||||
use thiserror::Error;
|
||||
|
||||
/// Client greetings are the first message sent in a SOCKSv5 session. They
|
||||
/// identify that there's a client that wants to talk to a server, and that
|
||||
/// they can support any of the provided mechanisms for authenticating to
|
||||
/// said server. (It feels weird that the offer/choice goes this way instead
|
||||
/// of the reverse, but whatever.)
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub struct ClientGreeting {
|
||||
pub acceptable_methods: Vec<AuthenticationMethod>,
|
||||
}
|
||||
|
||||
impl ClientGreeting {
|
||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||
r: Pin<&mut R>,
|
||||
) -> Result<ClientGreeting, DeserializationError> {
|
||||
let mut buffer = [0; 1];
|
||||
let raw_r = Pin::into_inner(r);
|
||||
|
||||
if raw_r.read(&mut buffer).await? == 0 {
|
||||
return Err(DeserializationError::NotEnoughData);
|
||||
}
|
||||
|
||||
if buffer[0] != 5 {
|
||||
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
|
||||
}
|
||||
|
||||
if raw_r.read(&mut buffer).await? == 0 {
|
||||
return Err(DeserializationError::NotEnoughData);
|
||||
}
|
||||
|
||||
let mut acceptable_methods = Vec::with_capacity(buffer[0] as usize);
|
||||
for _ in 0..buffer[0] {
|
||||
acceptable_methods.push(AuthenticationMethod::read(Pin::new(raw_r)).await?);
|
||||
}
|
||||
|
||||
Ok(ClientGreeting { acceptable_methods })
|
||||
}
|
||||
|
||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||
&self,
|
||||
w: &mut W,
|
||||
) -> Result<(), SerializationError> {
|
||||
if self.acceptable_methods.len() > 255 {
|
||||
return Err(SerializationError::TooManyAuthMethods(
|
||||
self.acceptable_methods.len(),
|
||||
));
|
||||
}
|
||||
|
||||
let mut buffer = Vec::with_capacity(self.acceptable_methods.len() + 2);
|
||||
buffer.push(5);
|
||||
buffer.push(self.acceptable_methods.len() as u8);
|
||||
w.write_all(&buffer).await?;
|
||||
for authmeth in self.acceptable_methods.iter() {
|
||||
authmeth.write(w).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl Arbitrary for ClientGreeting {
|
||||
fn arbitrary(g: &mut Gen) -> ClientGreeting {
|
||||
let amt = u8::arbitrary(g);
|
||||
let mut acceptable_methods = Vec::with_capacity(amt as usize);
|
||||
|
||||
for _ in 0..amt {
|
||||
acceptable_methods.push(AuthenticationMethod::arbitrary(g));
|
||||
}
|
||||
|
||||
ClientGreeting { acceptable_methods }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub struct ServerChoice {
|
||||
pub chosen_method: AuthenticationMethod,
|
||||
}
|
||||
|
||||
impl ServerChoice {
|
||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||
mut r: Pin<&mut R>,
|
||||
) -> Result<Self, DeserializationError> {
|
||||
let mut buffer = [0; 1];
|
||||
|
||||
if r.read(&mut buffer).await? == 0 {
|
||||
return Err(DeserializationError::NotEnoughData);
|
||||
}
|
||||
|
||||
if buffer[0] != 5 {
|
||||
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
|
||||
}
|
||||
|
||||
let chosen_method = AuthenticationMethod::read(r).await?;
|
||||
|
||||
Ok(ServerChoice { chosen_method })
|
||||
}
|
||||
|
||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||
&self,
|
||||
w: &mut W,
|
||||
) -> Result<(), SerializationError> {
|
||||
w.write_all(&[5]).await?;
|
||||
self.chosen_method.write(w).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl Arbitrary for ServerChoice {
|
||||
fn arbitrary(g: &mut Gen) -> ServerChoice {
|
||||
ServerChoice {
|
||||
chosen_method: AuthenticationMethod::arbitrary(g),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub struct ClientUsernamePassword {
|
||||
pub username: String,
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
impl ClientUsernamePassword {
|
||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||
r: Pin<&mut R>,
|
||||
) -> Result<Self, DeserializationError> {
|
||||
let mut buffer = [0; 1];
|
||||
let raw_r = Pin::into_inner(r);
|
||||
|
||||
if raw_r.read(&mut buffer).await? == 0 {
|
||||
return Err(DeserializationError::NotEnoughData);
|
||||
}
|
||||
|
||||
if buffer[0] != 1 {
|
||||
return Err(DeserializationError::InvalidVersion(1, buffer[0]));
|
||||
}
|
||||
|
||||
let username = read_string(Pin::new(raw_r)).await?;
|
||||
let password = read_string(Pin::new(raw_r)).await?;
|
||||
|
||||
Ok(ClientUsernamePassword { username, password })
|
||||
}
|
||||
|
||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||
&self,
|
||||
w: &mut W,
|
||||
) -> Result<(), SerializationError> {
|
||||
w.write_all(&[1]).await?;
|
||||
write_string(&self.username, w).await?;
|
||||
write_string(&self.password, w).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl Arbitrary for ClientUsernamePassword {
|
||||
fn arbitrary(g: &mut Gen) -> Self {
|
||||
let username = arbitrary_socks_string(g);
|
||||
let password = arbitrary_socks_string(g);
|
||||
|
||||
ClientUsernamePassword { username, password }
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn arbitrary_socks_string(g: &mut Gen) -> String {
|
||||
loop {
|
||||
let mut potential = String::arbitrary(g);
|
||||
|
||||
potential.truncate(255);
|
||||
let bytestring = potential.as_bytes();
|
||||
|
||||
if bytestring.len() > 0 && bytestring.len() < 256 {
|
||||
return potential;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub struct ServerAuthResponse {
|
||||
pub success: bool,
|
||||
}
|
||||
|
||||
impl ServerAuthResponse {
|
||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||
mut r: Pin<&mut R>,
|
||||
) -> Result<Self, DeserializationError> {
|
||||
let mut buffer = [0; 1];
|
||||
|
||||
if r.read(&mut buffer).await? == 0 {
|
||||
return Err(DeserializationError::NotEnoughData);
|
||||
}
|
||||
|
||||
if buffer[0] != 1 {
|
||||
return Err(DeserializationError::InvalidVersion(1, buffer[0]));
|
||||
}
|
||||
|
||||
if r.read(&mut buffer).await? == 0 {
|
||||
return Err(DeserializationError::NotEnoughData);
|
||||
}
|
||||
|
||||
Ok(ServerAuthResponse {
|
||||
success: buffer[0] == 0,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||
&self,
|
||||
w: &mut W,
|
||||
) -> Result<(), SerializationError> {
|
||||
w.write_all(&[1]).await?;
|
||||
w.write_all(&[if self.success { 0x00 } else { 0xde }])
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl Arbitrary for ServerAuthResponse {
|
||||
fn arbitrary(g: &mut Gen) -> ServerAuthResponse {
|
||||
let success = bool::arbitrary(g);
|
||||
ServerAuthResponse { success }
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::upper_case_acronyms)]
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub enum AuthenticationMethod {
|
||||
None,
|
||||
GSSAPI,
|
||||
UsernameAndPassword,
|
||||
ChallengeHandshake,
|
||||
ChallengeResponse,
|
||||
SSL,
|
||||
NDS,
|
||||
MultiAuthenticationFramework,
|
||||
JSONPropertyBlock,
|
||||
PrivateMethod(u8),
|
||||
NoAcceptableMethods,
|
||||
}
|
||||
|
||||
impl fmt::Display for AuthenticationMethod {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self {
|
||||
AuthenticationMethod::None => write!(f, "No authentication"),
|
||||
AuthenticationMethod::GSSAPI => write!(f, "GSS-API"),
|
||||
AuthenticationMethod::UsernameAndPassword => write!(f, "Username and password"),
|
||||
AuthenticationMethod::ChallengeHandshake => write!(f, "Challenge/Handshake"),
|
||||
AuthenticationMethod::ChallengeResponse => write!(f, "Challenge/Response"),
|
||||
AuthenticationMethod::SSL => write!(f, "SSL"),
|
||||
AuthenticationMethod::NDS => write!(f, "NDS Authentication"),
|
||||
AuthenticationMethod::MultiAuthenticationFramework => {
|
||||
write!(f, "Multi-Authentication Framework")
|
||||
}
|
||||
AuthenticationMethod::JSONPropertyBlock => write!(f, "JSON Property Block"),
|
||||
AuthenticationMethod::PrivateMethod(m) => write!(f, "Private Method {:x}", m),
|
||||
AuthenticationMethod::NoAcceptableMethods => write!(f, "No Acceptable Methods"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AuthenticationMethod {
|
||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||
mut r: Pin<&mut R>,
|
||||
) -> Result<AuthenticationMethod, DeserializationError> {
|
||||
let mut byte_buffer = [0u8; 1];
|
||||
let amount_read = r.read(&mut byte_buffer).await?;
|
||||
|
||||
if amount_read == 0 {
|
||||
return Err(AuthenticationDeserializationError::NoDataFound.into());
|
||||
}
|
||||
|
||||
match byte_buffer[0] {
|
||||
0 => Ok(AuthenticationMethod::None),
|
||||
1 => Ok(AuthenticationMethod::GSSAPI),
|
||||
2 => Ok(AuthenticationMethod::UsernameAndPassword),
|
||||
3 => Ok(AuthenticationMethod::ChallengeHandshake),
|
||||
5 => Ok(AuthenticationMethod::ChallengeResponse),
|
||||
6 => Ok(AuthenticationMethod::SSL),
|
||||
7 => Ok(AuthenticationMethod::NDS),
|
||||
8 => Ok(AuthenticationMethod::MultiAuthenticationFramework),
|
||||
9 => Ok(AuthenticationMethod::JSONPropertyBlock),
|
||||
x if (0x80..=0xfe).contains(&x) => Ok(AuthenticationMethod::PrivateMethod(x)),
|
||||
0xff => Ok(AuthenticationMethod::NoAcceptableMethods),
|
||||
e => Err(AuthenticationDeserializationError::InvalidAuthenticationByte(e).into()),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||
&self,
|
||||
w: &mut W,
|
||||
) -> Result<(), SerializationError> {
|
||||
let value = match self {
|
||||
AuthenticationMethod::None => 0,
|
||||
AuthenticationMethod::GSSAPI => 1,
|
||||
AuthenticationMethod::UsernameAndPassword => 2,
|
||||
AuthenticationMethod::ChallengeHandshake => 3,
|
||||
AuthenticationMethod::ChallengeResponse => 5,
|
||||
AuthenticationMethod::SSL => 6,
|
||||
AuthenticationMethod::NDS => 7,
|
||||
AuthenticationMethod::MultiAuthenticationFramework => 8,
|
||||
AuthenticationMethod::JSONPropertyBlock => 9,
|
||||
AuthenticationMethod::PrivateMethod(pm) => *pm,
|
||||
AuthenticationMethod::NoAcceptableMethods => 0xff,
|
||||
};
|
||||
|
||||
Ok(w.write_all(&[value]).await?)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl Arbitrary for AuthenticationMethod {
|
||||
fn arbitrary(g: &mut Gen) -> AuthenticationMethod {
|
||||
let mut vals = vec![
|
||||
AuthenticationMethod::None,
|
||||
AuthenticationMethod::GSSAPI,
|
||||
AuthenticationMethod::UsernameAndPassword,
|
||||
AuthenticationMethod::ChallengeHandshake,
|
||||
AuthenticationMethod::ChallengeResponse,
|
||||
AuthenticationMethod::SSL,
|
||||
AuthenticationMethod::NDS,
|
||||
AuthenticationMethod::MultiAuthenticationFramework,
|
||||
AuthenticationMethod::JSONPropertyBlock,
|
||||
AuthenticationMethod::NoAcceptableMethods,
|
||||
];
|
||||
for x in 0x80..0xffu8 {
|
||||
vals.push(AuthenticationMethod::PrivateMethod(x));
|
||||
}
|
||||
g.choose(&vals).unwrap().clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
||||
pub enum ClientConnectionCommand {
|
||||
EstablishTCPStream,
|
||||
EstablishTCPPortBinding,
|
||||
AssociateUDPPort,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub struct ClientConnectionRequest {
|
||||
pub command_code: ClientConnectionCommand,
|
||||
pub destination_address: SOCKSv5Address,
|
||||
pub destination_port: u16,
|
||||
}
|
||||
|
||||
impl ClientConnectionRequest {
|
||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||
r: Pin<&mut R>,
|
||||
) -> Result<Self, DeserializationError> {
|
||||
let mut buffer = [0; 2];
|
||||
let raw_r = Pin::into_inner(r);
|
||||
|
||||
read_amt(Pin::new(raw_r), 2, &mut buffer).await?;
|
||||
|
||||
if buffer[0] != 5 {
|
||||
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
|
||||
}
|
||||
|
||||
let command_code = match buffer[1] {
|
||||
0x01 => ClientConnectionCommand::EstablishTCPStream,
|
||||
0x02 => ClientConnectionCommand::EstablishTCPPortBinding,
|
||||
0x03 => ClientConnectionCommand::AssociateUDPPort,
|
||||
x => return Err(DeserializationError::InvalidClientCommand(x)),
|
||||
};
|
||||
|
||||
let destination_address = SOCKSv5Address::read(Pin::new(raw_r)).await?;
|
||||
|
||||
read_amt(Pin::new(raw_r), 2, &mut buffer).await?;
|
||||
let destination_port = ((buffer[0] as u16) << 8) + (buffer[1] as u16);
|
||||
|
||||
Ok(ClientConnectionRequest {
|
||||
command_code,
|
||||
destination_address,
|
||||
destination_port,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||
&self,
|
||||
w: &mut W,
|
||||
) -> Result<(), SerializationError> {
|
||||
let command = match self.command_code {
|
||||
ClientConnectionCommand::EstablishTCPStream => 1,
|
||||
ClientConnectionCommand::EstablishTCPPortBinding => 2,
|
||||
ClientConnectionCommand::AssociateUDPPort => 3,
|
||||
};
|
||||
|
||||
w.write_all(&[5, command]).await?;
|
||||
self.destination_address.write(w).await?;
|
||||
w.write_all(&[
|
||||
(self.destination_port >> 8) as u8,
|
||||
(self.destination_port & 0xffu16) as u8,
|
||||
])
|
||||
.await
|
||||
.map_err(SerializationError::IOError)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl Arbitrary for ClientConnectionCommand {
|
||||
fn arbitrary(g: &mut Gen) -> ClientConnectionCommand {
|
||||
g.choose(&[
|
||||
ClientConnectionCommand::EstablishTCPStream,
|
||||
ClientConnectionCommand::EstablishTCPPortBinding,
|
||||
ClientConnectionCommand::AssociateUDPPort,
|
||||
])
|
||||
.unwrap()
|
||||
.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl Arbitrary for ClientConnectionRequest {
|
||||
fn arbitrary(g: &mut Gen) -> Self {
|
||||
let command_code = ClientConnectionCommand::arbitrary(g);
|
||||
let destination_address = SOCKSv5Address::arbitrary(g);
|
||||
let destination_port = u16::arbitrary(g);
|
||||
|
||||
ClientConnectionRequest {
|
||||
command_code,
|
||||
destination_address,
|
||||
destination_port,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, Error, PartialEq)]
|
||||
pub enum ServerResponseStatus {
|
||||
#[error("Actually, everything's fine (weird to see this in an error)")]
|
||||
RequestGranted,
|
||||
#[error("General server failure")]
|
||||
GeneralFailure,
|
||||
#[error("Connection not allowed by policy rule")]
|
||||
ConnectionNotAllowedByRule,
|
||||
#[error("Network unreachable")]
|
||||
NetworkUnreachable,
|
||||
#[error("Host unreachable")]
|
||||
HostUnreachable,
|
||||
#[error("Connection refused")]
|
||||
ConnectionRefused,
|
||||
#[error("TTL expired")]
|
||||
TTLExpired,
|
||||
#[error("Command not supported")]
|
||||
CommandNotSupported,
|
||||
#[error("Address type not supported")]
|
||||
AddressTypeNotSupported,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub struct ServerResponse {
|
||||
pub status: ServerResponseStatus,
|
||||
pub bound_address: SOCKSv5Address,
|
||||
pub bound_port: u16,
|
||||
}
|
||||
|
||||
impl ServerResponse {
|
||||
pub fn error<E: Into<ServerResponseStatus>>(resp: E) -> ServerResponse {
|
||||
ServerResponse {
|
||||
status: resp.into(),
|
||||
bound_address: SOCKSv5Address::IP4(Ipv4Addr::new(0, 0, 0, 0)),
|
||||
bound_port: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerResponse {
|
||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||
r: Pin<&mut R>,
|
||||
) -> Result<Self, DeserializationError> {
|
||||
let mut buffer = [0; 3];
|
||||
let raw_r = Pin::into_inner(r);
|
||||
|
||||
read_amt(Pin::new(raw_r), 3, &mut buffer).await?;
|
||||
|
||||
if buffer[0] != 5 {
|
||||
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
|
||||
}
|
||||
|
||||
if buffer[2] != 0 {
|
||||
warn!(target: "async-socks5", "Hey, this isn't terrible, but the server is sending invalid reserved bytes.");
|
||||
}
|
||||
|
||||
let status = match buffer[1] {
|
||||
0x00 => ServerResponseStatus::RequestGranted,
|
||||
0x01 => ServerResponseStatus::GeneralFailure,
|
||||
0x02 => ServerResponseStatus::ConnectionNotAllowedByRule,
|
||||
0x03 => ServerResponseStatus::NetworkUnreachable,
|
||||
0x04 => ServerResponseStatus::HostUnreachable,
|
||||
0x05 => ServerResponseStatus::ConnectionRefused,
|
||||
0x06 => ServerResponseStatus::TTLExpired,
|
||||
0x07 => ServerResponseStatus::CommandNotSupported,
|
||||
0x08 => ServerResponseStatus::AddressTypeNotSupported,
|
||||
x => return Err(DeserializationError::InvalidServerResponse(x)),
|
||||
};
|
||||
|
||||
let bound_address = SOCKSv5Address::read(Pin::new(raw_r)).await?;
|
||||
read_amt(Pin::new(raw_r), 2, &mut buffer).await?;
|
||||
let bound_port = ((buffer[0] as u16) << 8) + (buffer[1] as u16);
|
||||
|
||||
Ok(ServerResponse {
|
||||
status,
|
||||
bound_address,
|
||||
bound_port,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||
&self,
|
||||
w: &mut W,
|
||||
) -> Result<(), SerializationError> {
|
||||
let status_code = match self.status {
|
||||
ServerResponseStatus::RequestGranted => 0x00,
|
||||
ServerResponseStatus::GeneralFailure => 0x01,
|
||||
ServerResponseStatus::ConnectionNotAllowedByRule => 0x02,
|
||||
ServerResponseStatus::NetworkUnreachable => 0x03,
|
||||
ServerResponseStatus::HostUnreachable => 0x04,
|
||||
ServerResponseStatus::ConnectionRefused => 0x05,
|
||||
ServerResponseStatus::TTLExpired => 0x06,
|
||||
ServerResponseStatus::CommandNotSupported => 0x07,
|
||||
ServerResponseStatus::AddressTypeNotSupported => 0x08,
|
||||
};
|
||||
|
||||
w.write_all(&[5, status_code, 0]).await?;
|
||||
self.bound_address.write(w).await?;
|
||||
w.write_all(&[
|
||||
(self.bound_port >> 8) as u8,
|
||||
(self.bound_port & 0xffu16) as u8,
|
||||
])
|
||||
.await
|
||||
.map_err(SerializationError::IOError)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl Arbitrary for ServerResponseStatus {
|
||||
fn arbitrary(g: &mut Gen) -> ServerResponseStatus {
|
||||
g.choose(&[
|
||||
ServerResponseStatus::RequestGranted,
|
||||
ServerResponseStatus::GeneralFailure,
|
||||
ServerResponseStatus::ConnectionNotAllowedByRule,
|
||||
ServerResponseStatus::NetworkUnreachable,
|
||||
ServerResponseStatus::HostUnreachable,
|
||||
ServerResponseStatus::ConnectionRefused,
|
||||
ServerResponseStatus::TTLExpired,
|
||||
ServerResponseStatus::CommandNotSupported,
|
||||
ServerResponseStatus::AddressTypeNotSupported,
|
||||
])
|
||||
.unwrap()
|
||||
.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl Arbitrary for ServerResponse {
|
||||
fn arbitrary(g: &mut Gen) -> Self {
|
||||
let status = ServerResponseStatus::arbitrary(g);
|
||||
let bound_address = SOCKSv5Address::arbitrary(g);
|
||||
let bound_port = u16::arbitrary(g);
|
||||
|
||||
ServerResponse {
|
||||
status,
|
||||
bound_address,
|
||||
bound_port,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! standard_roundtrip {
|
||||
($name: ident, $t: ty) => {
|
||||
#[cfg(test)]
|
||||
quickcheck! {
|
||||
fn $name(xs: $t) -> bool {
|
||||
let mut buffer = vec![];
|
||||
task::block_on(xs.write(&mut buffer)).unwrap();
|
||||
let mut cursor = Cursor::new(buffer);
|
||||
let ys = <$t>::read(Pin::new(&mut cursor));
|
||||
xs == task::block_on(ys).unwrap()
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
standard_roundtrip!(auth_byte_roundtrips, AuthenticationMethod);
|
||||
standard_roundtrip!(client_greeting_roundtrips, ClientGreeting);
|
||||
standard_roundtrip!(server_choice_roundtrips, ServerChoice);
|
||||
standard_roundtrip!(username_password_roundtrips, ClientUsernamePassword);
|
||||
standard_roundtrip!(server_auth_response, ServerAuthResponse);
|
||||
standard_roundtrip!(address_roundtrips, SOCKSv5Address);
|
||||
standard_roundtrip!(client_request_roundtrips, ClientConnectionRequest);
|
||||
standard_roundtrip!(server_response_roundtrips, ServerResponse);
|
||||
mod authentication_method;
|
||||
mod client_command;
|
||||
mod client_greeting;
|
||||
mod client_username_password;
|
||||
mod server_auth_response;
|
||||
mod server_choice;
|
||||
mod server_response;
|
||||
pub(crate) mod utils;
|
||||
|
||||
pub use crate::messages::authentication_method::AuthenticationMethod;
|
||||
pub use crate::messages::client_command::{ClientConnectionCommand, ClientConnectionRequest};
|
||||
pub use crate::messages::client_greeting::ClientGreeting;
|
||||
pub use crate::messages::client_username_password::ClientUsernamePassword;
|
||||
pub use crate::messages::server_auth_response::ServerAuthResponse;
|
||||
pub use crate::messages::server_choice::ServerChoice;
|
||||
pub use crate::messages::server_response::{ServerResponse, ServerResponseStatus};
|
||||
|
||||
156
src/messages/authentication_method.rs
Normal file
156
src/messages/authentication_method.rs
Normal file
@@ -0,0 +1,156 @@
|
||||
use crate::errors::{AuthenticationDeserializationError, DeserializationError, SerializationError};
|
||||
use crate::standard_roundtrip;
|
||||
#[cfg(test)]
|
||||
use async_std::task;
|
||||
#[cfg(test)]
|
||||
use futures::io::Cursor;
|
||||
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
#[cfg(test)]
|
||||
use quickcheck::{quickcheck, Arbitrary, Gen};
|
||||
use std::fmt;
|
||||
use std::pin::Pin;
|
||||
|
||||
#[allow(clippy::upper_case_acronyms)]
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub enum AuthenticationMethod {
|
||||
None,
|
||||
GSSAPI,
|
||||
UsernameAndPassword,
|
||||
ChallengeHandshake,
|
||||
ChallengeResponse,
|
||||
SSL,
|
||||
NDS,
|
||||
MultiAuthenticationFramework,
|
||||
JSONPropertyBlock,
|
||||
PrivateMethod(u8),
|
||||
NoAcceptableMethods,
|
||||
}
|
||||
|
||||
impl fmt::Display for AuthenticationMethod {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self {
|
||||
AuthenticationMethod::None => write!(f, "No authentication"),
|
||||
AuthenticationMethod::GSSAPI => write!(f, "GSS-API"),
|
||||
AuthenticationMethod::UsernameAndPassword => write!(f, "Username and password"),
|
||||
AuthenticationMethod::ChallengeHandshake => write!(f, "Challenge/Handshake"),
|
||||
AuthenticationMethod::ChallengeResponse => write!(f, "Challenge/Response"),
|
||||
AuthenticationMethod::SSL => write!(f, "SSL"),
|
||||
AuthenticationMethod::NDS => write!(f, "NDS Authentication"),
|
||||
AuthenticationMethod::MultiAuthenticationFramework => {
|
||||
write!(f, "Multi-Authentication Framework")
|
||||
}
|
||||
AuthenticationMethod::JSONPropertyBlock => write!(f, "JSON Property Block"),
|
||||
AuthenticationMethod::PrivateMethod(m) => write!(f, "Private Method {:x}", m),
|
||||
AuthenticationMethod::NoAcceptableMethods => write!(f, "No Acceptable Methods"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AuthenticationMethod {
|
||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||
mut r: Pin<&mut R>,
|
||||
) -> Result<AuthenticationMethod, DeserializationError> {
|
||||
let mut byte_buffer = [0u8; 1];
|
||||
let amount_read = r.read(&mut byte_buffer).await?;
|
||||
|
||||
if amount_read == 0 {
|
||||
return Err(AuthenticationDeserializationError::NoDataFound.into());
|
||||
}
|
||||
|
||||
match byte_buffer[0] {
|
||||
0 => Ok(AuthenticationMethod::None),
|
||||
1 => Ok(AuthenticationMethod::GSSAPI),
|
||||
2 => Ok(AuthenticationMethod::UsernameAndPassword),
|
||||
3 => Ok(AuthenticationMethod::ChallengeHandshake),
|
||||
5 => Ok(AuthenticationMethod::ChallengeResponse),
|
||||
6 => Ok(AuthenticationMethod::SSL),
|
||||
7 => Ok(AuthenticationMethod::NDS),
|
||||
8 => Ok(AuthenticationMethod::MultiAuthenticationFramework),
|
||||
9 => Ok(AuthenticationMethod::JSONPropertyBlock),
|
||||
x if (0x80..=0xfe).contains(&x) => Ok(AuthenticationMethod::PrivateMethod(x)),
|
||||
0xff => Ok(AuthenticationMethod::NoAcceptableMethods),
|
||||
e => Err(AuthenticationDeserializationError::InvalidAuthenticationByte(e).into()),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||
&self,
|
||||
w: &mut W,
|
||||
) -> Result<(), SerializationError> {
|
||||
let value = match self {
|
||||
AuthenticationMethod::None => 0,
|
||||
AuthenticationMethod::GSSAPI => 1,
|
||||
AuthenticationMethod::UsernameAndPassword => 2,
|
||||
AuthenticationMethod::ChallengeHandshake => 3,
|
||||
AuthenticationMethod::ChallengeResponse => 5,
|
||||
AuthenticationMethod::SSL => 6,
|
||||
AuthenticationMethod::NDS => 7,
|
||||
AuthenticationMethod::MultiAuthenticationFramework => 8,
|
||||
AuthenticationMethod::JSONPropertyBlock => 9,
|
||||
AuthenticationMethod::PrivateMethod(pm) => *pm,
|
||||
AuthenticationMethod::NoAcceptableMethods => 0xff,
|
||||
};
|
||||
|
||||
Ok(w.write_all(&[value]).await?)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl Arbitrary for AuthenticationMethod {
|
||||
fn arbitrary(g: &mut Gen) -> AuthenticationMethod {
|
||||
let mut vals = vec![
|
||||
AuthenticationMethod::None,
|
||||
AuthenticationMethod::GSSAPI,
|
||||
AuthenticationMethod::UsernameAndPassword,
|
||||
AuthenticationMethod::ChallengeHandshake,
|
||||
AuthenticationMethod::ChallengeResponse,
|
||||
AuthenticationMethod::SSL,
|
||||
AuthenticationMethod::NDS,
|
||||
AuthenticationMethod::MultiAuthenticationFramework,
|
||||
AuthenticationMethod::JSONPropertyBlock,
|
||||
AuthenticationMethod::NoAcceptableMethods,
|
||||
];
|
||||
for x in 0x80..0xffu8 {
|
||||
vals.push(AuthenticationMethod::PrivateMethod(x));
|
||||
}
|
||||
g.choose(&vals).unwrap().clone()
|
||||
}
|
||||
}
|
||||
|
||||
standard_roundtrip!(auth_byte_roundtrips, AuthenticationMethod);
|
||||
|
||||
#[test]
|
||||
fn bad_byte() {
|
||||
let no_len = vec![42];
|
||||
let mut cursor = Cursor::new(no_len);
|
||||
let ys = AuthenticationMethod::read(Pin::new(&mut cursor));
|
||||
assert_eq!(
|
||||
Err(DeserializationError::AuthenticationMethodError(
|
||||
AuthenticationDeserializationError::InvalidAuthenticationByte(42)
|
||||
)),
|
||||
task::block_on(ys)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display_isnt_empty() {
|
||||
let vals = vec![
|
||||
AuthenticationMethod::None,
|
||||
AuthenticationMethod::GSSAPI,
|
||||
AuthenticationMethod::UsernameAndPassword,
|
||||
AuthenticationMethod::ChallengeHandshake,
|
||||
AuthenticationMethod::ChallengeResponse,
|
||||
AuthenticationMethod::SSL,
|
||||
AuthenticationMethod::NDS,
|
||||
AuthenticationMethod::MultiAuthenticationFramework,
|
||||
AuthenticationMethod::JSONPropertyBlock,
|
||||
AuthenticationMethod::NoAcceptableMethods,
|
||||
AuthenticationMethod::PrivateMethod(42),
|
||||
];
|
||||
|
||||
for method in vals.iter() {
|
||||
let str = format!("{}", method);
|
||||
assert!(str.is_ascii());
|
||||
assert!(!str.is_empty());
|
||||
}
|
||||
}
|
||||
164
src/messages/client_command.rs
Normal file
164
src/messages/client_command.rs
Normal file
@@ -0,0 +1,164 @@
|
||||
use crate::errors::{DeserializationError, SerializationError};
|
||||
use crate::network::SOCKSv5Address;
|
||||
use crate::serialize::read_amt;
|
||||
use crate::standard_roundtrip;
|
||||
#[cfg(test)]
|
||||
use async_std::io::ErrorKind;
|
||||
#[cfg(test)]
|
||||
use async_std::task;
|
||||
#[cfg(test)]
|
||||
use futures::io::Cursor;
|
||||
use futures::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
#[cfg(test)]
|
||||
use quickcheck::{quickcheck, Arbitrary, Gen};
|
||||
#[cfg(test)]
|
||||
use std::net::Ipv4Addr;
|
||||
use std::pin::Pin;
|
||||
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
||||
pub enum ClientConnectionCommand {
|
||||
EstablishTCPStream,
|
||||
EstablishTCPPortBinding,
|
||||
AssociateUDPPort,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub struct ClientConnectionRequest {
|
||||
pub command_code: ClientConnectionCommand,
|
||||
pub destination_address: SOCKSv5Address,
|
||||
pub destination_port: u16,
|
||||
}
|
||||
|
||||
impl ClientConnectionRequest {
|
||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||
r: Pin<&mut R>,
|
||||
) -> Result<Self, DeserializationError> {
|
||||
let mut buffer = [0; 2];
|
||||
let raw_r = Pin::into_inner(r);
|
||||
|
||||
read_amt(Pin::new(raw_r), 2, &mut buffer).await?;
|
||||
|
||||
if buffer[0] != 5 {
|
||||
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
|
||||
}
|
||||
|
||||
let command_code = match buffer[1] {
|
||||
0x01 => ClientConnectionCommand::EstablishTCPStream,
|
||||
0x02 => ClientConnectionCommand::EstablishTCPPortBinding,
|
||||
0x03 => ClientConnectionCommand::AssociateUDPPort,
|
||||
x => return Err(DeserializationError::InvalidClientCommand(x)),
|
||||
};
|
||||
|
||||
let destination_address = SOCKSv5Address::read(Pin::new(raw_r)).await?;
|
||||
|
||||
read_amt(Pin::new(raw_r), 2, &mut buffer).await?;
|
||||
let destination_port = ((buffer[0] as u16) << 8) + (buffer[1] as u16);
|
||||
|
||||
Ok(ClientConnectionRequest {
|
||||
command_code,
|
||||
destination_address,
|
||||
destination_port,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||
&self,
|
||||
w: &mut W,
|
||||
) -> Result<(), SerializationError> {
|
||||
let command = match self.command_code {
|
||||
ClientConnectionCommand::EstablishTCPStream => 1,
|
||||
ClientConnectionCommand::EstablishTCPPortBinding => 2,
|
||||
ClientConnectionCommand::AssociateUDPPort => 3,
|
||||
};
|
||||
|
||||
w.write_all(&[5, command]).await?;
|
||||
self.destination_address.write(w).await?;
|
||||
w.write_all(&[
|
||||
(self.destination_port >> 8) as u8,
|
||||
(self.destination_port & 0xffu16) as u8,
|
||||
])
|
||||
.await
|
||||
.map_err(SerializationError::IOError)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl Arbitrary for ClientConnectionCommand {
|
||||
fn arbitrary(g: &mut Gen) -> ClientConnectionCommand {
|
||||
let options = [
|
||||
ClientConnectionCommand::EstablishTCPStream,
|
||||
ClientConnectionCommand::EstablishTCPPortBinding,
|
||||
ClientConnectionCommand::AssociateUDPPort,
|
||||
];
|
||||
g.choose(&options).unwrap().clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl Arbitrary for ClientConnectionRequest {
|
||||
fn arbitrary(g: &mut Gen) -> Self {
|
||||
let command_code = ClientConnectionCommand::arbitrary(g);
|
||||
let destination_address = SOCKSv5Address::arbitrary(g);
|
||||
let destination_port = u16::arbitrary(g);
|
||||
|
||||
ClientConnectionRequest {
|
||||
command_code,
|
||||
destination_address,
|
||||
destination_port,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
standard_roundtrip!(client_request_roundtrips, ClientConnectionRequest);
|
||||
|
||||
#[test]
|
||||
fn check_short_reads() {
|
||||
let empty = vec![];
|
||||
let mut cursor = Cursor::new(empty);
|
||||
let ys = ClientConnectionRequest::read(Pin::new(&mut cursor));
|
||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
||||
|
||||
let no_len = vec![5, 1];
|
||||
let mut cursor = Cursor::new(no_len);
|
||||
let ys = ClientConnectionRequest::read(Pin::new(&mut cursor));
|
||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_bad_version() {
|
||||
let bad_ver = vec![6, 1, 1];
|
||||
let mut cursor = Cursor::new(bad_ver);
|
||||
let ys = ClientConnectionRequest::read(Pin::new(&mut cursor));
|
||||
assert_eq!(
|
||||
Err(DeserializationError::InvalidVersion(5, 6)),
|
||||
task::block_on(ys)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_bad_command() {
|
||||
let bad_cmd = vec![5, 32, 1];
|
||||
let mut cursor = Cursor::new(bad_cmd);
|
||||
let ys = ClientConnectionRequest::read(Pin::new(&mut cursor));
|
||||
assert_eq!(
|
||||
Err(DeserializationError::InvalidClientCommand(32)),
|
||||
task::block_on(ys)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn short_write_fails_right() {
|
||||
let mut buffer = [0u8; 2];
|
||||
let cmd = ClientConnectionRequest {
|
||||
command_code: ClientConnectionCommand::AssociateUDPPort,
|
||||
destination_address: SOCKSv5Address::IP4(Ipv4Addr::from(0)),
|
||||
destination_port: 22,
|
||||
};
|
||||
let mut cursor = Cursor::new(&mut buffer as &mut [u8]);
|
||||
let result = task::block_on(cmd.write(&mut cursor));
|
||||
match result {
|
||||
Ok(_) => assert!(false, "Mysteriously able to fit > 2 bytes in 2 bytes."),
|
||||
Err(SerializationError::IOError(x)) => assert_eq!(ErrorKind::WriteZero, x.kind()),
|
||||
Err(e) => assert!(false, "Got the wrong error writing too much data: {}", e),
|
||||
}
|
||||
}
|
||||
135
src/messages/client_greeting.rs
Normal file
135
src/messages/client_greeting.rs
Normal file
@@ -0,0 +1,135 @@
|
||||
#[cfg(test)]
|
||||
use crate::errors::AuthenticationDeserializationError;
|
||||
use crate::errors::{DeserializationError, SerializationError};
|
||||
use crate::messages::AuthenticationMethod;
|
||||
use crate::standard_roundtrip;
|
||||
#[cfg(test)]
|
||||
use async_std::task;
|
||||
#[cfg(test)]
|
||||
use futures::io::Cursor;
|
||||
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
#[cfg(test)]
|
||||
use quickcheck::{quickcheck, Arbitrary, Gen};
|
||||
use std::pin::Pin;
|
||||
|
||||
/// Client greetings are the first message sent in a SOCKSv5 session. They
|
||||
/// identify that there's a client that wants to talk to a server, and that
|
||||
/// they can support any of the provided mechanisms for authenticating to
|
||||
/// said server. (It feels weird that the offer/choice goes this way instead
|
||||
/// of the reverse, but whatever.)
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub struct ClientGreeting {
|
||||
pub acceptable_methods: Vec<AuthenticationMethod>,
|
||||
}
|
||||
|
||||
impl ClientGreeting {
|
||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||
r: Pin<&mut R>,
|
||||
) -> Result<ClientGreeting, DeserializationError> {
|
||||
let mut buffer = [0; 1];
|
||||
let raw_r = Pin::into_inner(r);
|
||||
|
||||
if raw_r.read(&mut buffer).await? == 0 {
|
||||
return Err(DeserializationError::NotEnoughData);
|
||||
}
|
||||
|
||||
if buffer[0] != 5 {
|
||||
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
|
||||
}
|
||||
|
||||
if raw_r.read(&mut buffer).await? == 0 {
|
||||
return Err(DeserializationError::NotEnoughData);
|
||||
}
|
||||
|
||||
let mut acceptable_methods = Vec::with_capacity(buffer[0] as usize);
|
||||
for _ in 0..buffer[0] {
|
||||
acceptable_methods.push(AuthenticationMethod::read(Pin::new(raw_r)).await?);
|
||||
}
|
||||
|
||||
Ok(ClientGreeting { acceptable_methods })
|
||||
}
|
||||
|
||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||
&self,
|
||||
w: &mut W,
|
||||
) -> Result<(), SerializationError> {
|
||||
if self.acceptable_methods.len() > 255 {
|
||||
return Err(SerializationError::TooManyAuthMethods(
|
||||
self.acceptable_methods.len(),
|
||||
));
|
||||
}
|
||||
|
||||
let mut buffer = Vec::with_capacity(self.acceptable_methods.len() + 2);
|
||||
buffer.push(5);
|
||||
buffer.push(self.acceptable_methods.len() as u8);
|
||||
w.write_all(&buffer).await?;
|
||||
for authmeth in self.acceptable_methods.iter() {
|
||||
authmeth.write(w).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl Arbitrary for ClientGreeting {
|
||||
fn arbitrary(g: &mut Gen) -> ClientGreeting {
|
||||
let amt = u8::arbitrary(g);
|
||||
let mut acceptable_methods = Vec::with_capacity(amt as usize);
|
||||
|
||||
for _ in 0..amt {
|
||||
acceptable_methods.push(AuthenticationMethod::arbitrary(g));
|
||||
}
|
||||
|
||||
ClientGreeting { acceptable_methods }
|
||||
}
|
||||
}
|
||||
|
||||
standard_roundtrip!(client_greeting_roundtrips, ClientGreeting);
|
||||
|
||||
#[test]
|
||||
fn check_short_reads() {
|
||||
let empty = vec![];
|
||||
let mut cursor = Cursor::new(empty);
|
||||
let ys = ClientGreeting::read(Pin::new(&mut cursor));
|
||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
||||
|
||||
let no_len = vec![5];
|
||||
let mut cursor = Cursor::new(no_len);
|
||||
let ys = ClientGreeting::read(Pin::new(&mut cursor));
|
||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
||||
|
||||
let bad_len = vec![5, 9];
|
||||
let mut cursor = Cursor::new(bad_len);
|
||||
let ys = ClientGreeting::read(Pin::new(&mut cursor));
|
||||
assert_eq!(
|
||||
Err(DeserializationError::AuthenticationMethodError(
|
||||
AuthenticationDeserializationError::NoDataFound
|
||||
)),
|
||||
task::block_on(ys)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_bad_version() {
|
||||
let no_len = vec![6, 1, 1];
|
||||
let mut cursor = Cursor::new(no_len);
|
||||
let ys = ClientGreeting::read(Pin::new(&mut cursor));
|
||||
assert_eq!(
|
||||
Err(DeserializationError::InvalidVersion(5, 6)),
|
||||
task::block_on(ys)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_too_many() {
|
||||
let mut auth_methods = Vec::with_capacity(512);
|
||||
auth_methods.resize(512, AuthenticationMethod::ChallengeHandshake);
|
||||
let greet = ClientGreeting {
|
||||
acceptable_methods: auth_methods,
|
||||
};
|
||||
let mut output = vec![0; 1024];
|
||||
assert_eq!(
|
||||
Err(SerializationError::TooManyAuthMethods(512)),
|
||||
task::block_on(greet.write(&mut output))
|
||||
);
|
||||
}
|
||||
86
src/messages/client_username_password.rs
Normal file
86
src/messages/client_username_password.rs
Normal file
@@ -0,0 +1,86 @@
|
||||
use crate::errors::{DeserializationError, SerializationError};
|
||||
#[cfg(test)]
|
||||
use crate::messages::utils::arbitrary_socks_string;
|
||||
use crate::serialize::{read_string, write_string};
|
||||
use crate::standard_roundtrip;
|
||||
#[cfg(test)]
|
||||
use async_std::task;
|
||||
#[cfg(test)]
|
||||
use futures::io::Cursor;
|
||||
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
#[cfg(test)]
|
||||
use quickcheck::{quickcheck, Arbitrary, Gen};
|
||||
use std::pin::Pin;
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub struct ClientUsernamePassword {
|
||||
pub username: String,
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
impl ClientUsernamePassword {
|
||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||
r: Pin<&mut R>,
|
||||
) -> Result<Self, DeserializationError> {
|
||||
let mut buffer = [0; 1];
|
||||
let raw_r = Pin::into_inner(r);
|
||||
|
||||
if raw_r.read(&mut buffer).await? == 0 {
|
||||
return Err(DeserializationError::NotEnoughData);
|
||||
}
|
||||
|
||||
if buffer[0] != 1 {
|
||||
return Err(DeserializationError::InvalidVersion(1, buffer[0]));
|
||||
}
|
||||
|
||||
let username = read_string(Pin::new(raw_r)).await?;
|
||||
let password = read_string(Pin::new(raw_r)).await?;
|
||||
|
||||
Ok(ClientUsernamePassword { username, password })
|
||||
}
|
||||
|
||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||
&self,
|
||||
w: &mut W,
|
||||
) -> Result<(), SerializationError> {
|
||||
w.write_all(&[1]).await?;
|
||||
write_string(&self.username, w).await?;
|
||||
write_string(&self.password, w).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl Arbitrary for ClientUsernamePassword {
|
||||
fn arbitrary(g: &mut Gen) -> Self {
|
||||
let username = arbitrary_socks_string(g);
|
||||
let password = arbitrary_socks_string(g);
|
||||
|
||||
ClientUsernamePassword { username, password }
|
||||
}
|
||||
}
|
||||
|
||||
standard_roundtrip!(username_password_roundtrips, ClientUsernamePassword);
|
||||
|
||||
#[test]
|
||||
fn check_short_reads() {
|
||||
let empty = vec![];
|
||||
let mut cursor = Cursor::new(empty);
|
||||
let ys = ClientUsernamePassword::read(Pin::new(&mut cursor));
|
||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
||||
|
||||
let user_only = vec![1, 3, 102, 111, 111];
|
||||
let mut cursor = Cursor::new(user_only);
|
||||
let ys = ClientUsernamePassword::read(Pin::new(&mut cursor));
|
||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_bad_version() {
|
||||
let bad_len = vec![5];
|
||||
let mut cursor = Cursor::new(bad_len);
|
||||
let ys = ClientUsernamePassword::read(Pin::new(&mut cursor));
|
||||
assert_eq!(
|
||||
Err(DeserializationError::InvalidVersion(1, 5)),
|
||||
task::block_on(ys)
|
||||
);
|
||||
}
|
||||
83
src/messages/server_auth_response.rs
Normal file
83
src/messages/server_auth_response.rs
Normal file
@@ -0,0 +1,83 @@
|
||||
use crate::errors::{DeserializationError, SerializationError};
|
||||
use crate::standard_roundtrip;
|
||||
#[cfg(test)]
|
||||
use async_std::task;
|
||||
#[cfg(test)]
|
||||
use futures::io::Cursor;
|
||||
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
#[cfg(test)]
|
||||
use quickcheck::{quickcheck, Arbitrary, Gen};
|
||||
use std::pin::Pin;
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub struct ServerAuthResponse {
|
||||
pub success: bool,
|
||||
}
|
||||
|
||||
impl ServerAuthResponse {
|
||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||
mut r: Pin<&mut R>,
|
||||
) -> Result<Self, DeserializationError> {
|
||||
let mut buffer = [0; 1];
|
||||
|
||||
if r.read(&mut buffer).await? == 0 {
|
||||
return Err(DeserializationError::NotEnoughData);
|
||||
}
|
||||
|
||||
if buffer[0] != 1 {
|
||||
return Err(DeserializationError::InvalidVersion(1, buffer[0]));
|
||||
}
|
||||
|
||||
if r.read(&mut buffer).await? == 0 {
|
||||
return Err(DeserializationError::NotEnoughData);
|
||||
}
|
||||
|
||||
Ok(ServerAuthResponse {
|
||||
success: buffer[0] == 0,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||
&self,
|
||||
w: &mut W,
|
||||
) -> Result<(), SerializationError> {
|
||||
w.write_all(&[1]).await?;
|
||||
w.write_all(&[if self.success { 0x00 } else { 0xde }])
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl Arbitrary for ServerAuthResponse {
|
||||
fn arbitrary(g: &mut Gen) -> ServerAuthResponse {
|
||||
let success = bool::arbitrary(g);
|
||||
ServerAuthResponse { success }
|
||||
}
|
||||
}
|
||||
|
||||
standard_roundtrip!(server_auth_response, ServerAuthResponse);
|
||||
|
||||
#[test]
|
||||
fn check_short_reads() {
|
||||
let empty = vec![];
|
||||
let mut cursor = Cursor::new(empty);
|
||||
let ys = ServerAuthResponse::read(Pin::new(&mut cursor));
|
||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
||||
|
||||
let no_len = vec![1];
|
||||
let mut cursor = Cursor::new(no_len);
|
||||
let ys = ServerAuthResponse::read(Pin::new(&mut cursor));
|
||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_bad_version() {
|
||||
let no_len = vec![6, 1];
|
||||
let mut cursor = Cursor::new(no_len);
|
||||
let ys = ServerAuthResponse::read(Pin::new(&mut cursor));
|
||||
assert_eq!(
|
||||
Err(DeserializationError::InvalidVersion(1, 6)),
|
||||
task::block_on(ys)
|
||||
);
|
||||
}
|
||||
86
src/messages/server_choice.rs
Normal file
86
src/messages/server_choice.rs
Normal file
@@ -0,0 +1,86 @@
|
||||
#[cfg(test)]
|
||||
use crate::errors::AuthenticationDeserializationError;
|
||||
use crate::errors::{DeserializationError, SerializationError};
|
||||
use crate::messages::AuthenticationMethod;
|
||||
use crate::standard_roundtrip;
|
||||
#[cfg(test)]
|
||||
use async_std::task;
|
||||
#[cfg(test)]
|
||||
use futures::io::Cursor;
|
||||
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
#[cfg(test)]
|
||||
use quickcheck::{quickcheck, Arbitrary, Gen};
|
||||
use std::pin::Pin;
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub struct ServerChoice {
|
||||
pub chosen_method: AuthenticationMethod,
|
||||
}
|
||||
|
||||
impl ServerChoice {
|
||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||
mut r: Pin<&mut R>,
|
||||
) -> Result<Self, DeserializationError> {
|
||||
let mut buffer = [0; 1];
|
||||
|
||||
if r.read(&mut buffer).await? == 0 {
|
||||
return Err(DeserializationError::NotEnoughData);
|
||||
}
|
||||
|
||||
if buffer[0] != 5 {
|
||||
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
|
||||
}
|
||||
|
||||
let chosen_method = AuthenticationMethod::read(r).await?;
|
||||
|
||||
Ok(ServerChoice { chosen_method })
|
||||
}
|
||||
|
||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||
&self,
|
||||
w: &mut W,
|
||||
) -> Result<(), SerializationError> {
|
||||
w.write_all(&[5]).await?;
|
||||
self.chosen_method.write(w).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl Arbitrary for ServerChoice {
|
||||
fn arbitrary(g: &mut Gen) -> ServerChoice {
|
||||
ServerChoice {
|
||||
chosen_method: AuthenticationMethod::arbitrary(g),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
standard_roundtrip!(server_choice_roundtrips, ServerChoice);
|
||||
|
||||
#[test]
|
||||
fn check_short_reads() {
|
||||
let empty = vec![];
|
||||
let mut cursor = Cursor::new(empty);
|
||||
let ys = ServerChoice::read(Pin::new(&mut cursor));
|
||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
||||
|
||||
let bad_len = vec![5];
|
||||
let mut cursor = Cursor::new(bad_len);
|
||||
let ys = ServerChoice::read(Pin::new(&mut cursor));
|
||||
assert_eq!(
|
||||
Err(DeserializationError::AuthenticationMethodError(
|
||||
AuthenticationDeserializationError::NoDataFound
|
||||
)),
|
||||
task::block_on(ys)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_bad_version() {
|
||||
let no_len = vec![9, 1];
|
||||
let mut cursor = Cursor::new(no_len);
|
||||
let ys = ServerChoice::read(Pin::new(&mut cursor));
|
||||
assert_eq!(
|
||||
Err(DeserializationError::InvalidVersion(5, 9)),
|
||||
task::block_on(ys)
|
||||
);
|
||||
}
|
||||
202
src/messages/server_response.rs
Normal file
202
src/messages/server_response.rs
Normal file
@@ -0,0 +1,202 @@
|
||||
use crate::errors::{DeserializationError, SerializationError};
|
||||
use crate::network::SOCKSv5Address;
|
||||
use crate::serialize::read_amt;
|
||||
use crate::standard_roundtrip;
|
||||
#[cfg(test)]
|
||||
use async_std::io::ErrorKind;
|
||||
#[cfg(test)]
|
||||
use async_std::task;
|
||||
#[cfg(test)]
|
||||
use futures::io::Cursor;
|
||||
use futures::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use log::warn;
|
||||
#[cfg(test)]
|
||||
use quickcheck::{quickcheck, Arbitrary, Gen};
|
||||
use std::net::Ipv4Addr;
|
||||
use std::pin::Pin;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Clone, Debug, Eq, Error, PartialEq)]
|
||||
pub enum ServerResponseStatus {
|
||||
#[error("Actually, everything's fine (weird to see this in an error)")]
|
||||
RequestGranted,
|
||||
#[error("General server failure")]
|
||||
GeneralFailure,
|
||||
#[error("Connection not allowed by policy rule")]
|
||||
ConnectionNotAllowedByRule,
|
||||
#[error("Network unreachable")]
|
||||
NetworkUnreachable,
|
||||
#[error("Host unreachable")]
|
||||
HostUnreachable,
|
||||
#[error("Connection refused")]
|
||||
ConnectionRefused,
|
||||
#[error("TTL expired")]
|
||||
TTLExpired,
|
||||
#[error("Command not supported")]
|
||||
CommandNotSupported,
|
||||
#[error("Address type not supported")]
|
||||
AddressTypeNotSupported,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub struct ServerResponse {
|
||||
pub status: ServerResponseStatus,
|
||||
pub bound_address: SOCKSv5Address,
|
||||
pub bound_port: u16,
|
||||
}
|
||||
|
||||
impl ServerResponse {
|
||||
pub fn error<E: Into<ServerResponseStatus>>(resp: E) -> ServerResponse {
|
||||
ServerResponse {
|
||||
status: resp.into(),
|
||||
bound_address: SOCKSv5Address::IP4(Ipv4Addr::new(0, 0, 0, 0)),
|
||||
bound_port: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerResponse {
|
||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||
r: Pin<&mut R>,
|
||||
) -> Result<Self, DeserializationError> {
|
||||
let mut buffer = [0; 3];
|
||||
let raw_r = Pin::into_inner(r);
|
||||
|
||||
read_amt(Pin::new(raw_r), 3, &mut buffer).await?;
|
||||
|
||||
if buffer[0] != 5 {
|
||||
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
|
||||
}
|
||||
|
||||
if buffer[2] != 0 {
|
||||
warn!(target: "async-socks5", "Hey, this isn't terrible, but the server is sending invalid reserved bytes.");
|
||||
}
|
||||
|
||||
let status = match buffer[1] {
|
||||
0x00 => ServerResponseStatus::RequestGranted,
|
||||
0x01 => ServerResponseStatus::GeneralFailure,
|
||||
0x02 => ServerResponseStatus::ConnectionNotAllowedByRule,
|
||||
0x03 => ServerResponseStatus::NetworkUnreachable,
|
||||
0x04 => ServerResponseStatus::HostUnreachable,
|
||||
0x05 => ServerResponseStatus::ConnectionRefused,
|
||||
0x06 => ServerResponseStatus::TTLExpired,
|
||||
0x07 => ServerResponseStatus::CommandNotSupported,
|
||||
0x08 => ServerResponseStatus::AddressTypeNotSupported,
|
||||
x => return Err(DeserializationError::InvalidServerResponse(x)),
|
||||
};
|
||||
|
||||
let bound_address = SOCKSv5Address::read(Pin::new(raw_r)).await?;
|
||||
read_amt(Pin::new(raw_r), 2, &mut buffer).await?;
|
||||
let bound_port = ((buffer[0] as u16) << 8) + (buffer[1] as u16);
|
||||
|
||||
Ok(ServerResponse {
|
||||
status,
|
||||
bound_address,
|
||||
bound_port,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||
&self,
|
||||
w: &mut W,
|
||||
) -> Result<(), SerializationError> {
|
||||
let status_code = match self.status {
|
||||
ServerResponseStatus::RequestGranted => 0x00,
|
||||
ServerResponseStatus::GeneralFailure => 0x01,
|
||||
ServerResponseStatus::ConnectionNotAllowedByRule => 0x02,
|
||||
ServerResponseStatus::NetworkUnreachable => 0x03,
|
||||
ServerResponseStatus::HostUnreachable => 0x04,
|
||||
ServerResponseStatus::ConnectionRefused => 0x05,
|
||||
ServerResponseStatus::TTLExpired => 0x06,
|
||||
ServerResponseStatus::CommandNotSupported => 0x07,
|
||||
ServerResponseStatus::AddressTypeNotSupported => 0x08,
|
||||
};
|
||||
|
||||
w.write_all(&[5, status_code, 0]).await?;
|
||||
self.bound_address.write(w).await?;
|
||||
w.write_all(&[
|
||||
(self.bound_port >> 8) as u8,
|
||||
(self.bound_port & 0xffu16) as u8,
|
||||
])
|
||||
.await
|
||||
.map_err(SerializationError::IOError)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl Arbitrary for ServerResponseStatus {
|
||||
fn arbitrary(g: &mut Gen) -> ServerResponseStatus {
|
||||
let options = [
|
||||
ServerResponseStatus::RequestGranted,
|
||||
ServerResponseStatus::GeneralFailure,
|
||||
ServerResponseStatus::ConnectionNotAllowedByRule,
|
||||
ServerResponseStatus::NetworkUnreachable,
|
||||
ServerResponseStatus::HostUnreachable,
|
||||
ServerResponseStatus::ConnectionRefused,
|
||||
ServerResponseStatus::TTLExpired,
|
||||
ServerResponseStatus::CommandNotSupported,
|
||||
ServerResponseStatus::AddressTypeNotSupported,
|
||||
];
|
||||
g.choose(&options).unwrap().clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl Arbitrary for ServerResponse {
|
||||
fn arbitrary(g: &mut Gen) -> Self {
|
||||
let status = ServerResponseStatus::arbitrary(g);
|
||||
let bound_address = SOCKSv5Address::arbitrary(g);
|
||||
let bound_port = u16::arbitrary(g);
|
||||
|
||||
ServerResponse {
|
||||
status,
|
||||
bound_address,
|
||||
bound_port,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
standard_roundtrip!(server_response_roundtrips, ServerResponse);
|
||||
|
||||
#[test]
|
||||
fn check_short_reads() {
|
||||
let empty = vec![];
|
||||
let mut cursor = Cursor::new(empty);
|
||||
let ys = ServerResponse::read(Pin::new(&mut cursor));
|
||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_bad_version() {
|
||||
let bad_ver = vec![6, 1, 1];
|
||||
let mut cursor = Cursor::new(bad_ver);
|
||||
let ys = ServerResponse::read(Pin::new(&mut cursor));
|
||||
assert_eq!(
|
||||
Err(DeserializationError::InvalidVersion(5, 6)),
|
||||
task::block_on(ys)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_bad_command() {
|
||||
let bad_cmd = vec![5, 32, 0x42];
|
||||
let mut cursor = Cursor::new(bad_cmd);
|
||||
let ys = ServerResponse::read(Pin::new(&mut cursor));
|
||||
assert_eq!(
|
||||
Err(DeserializationError::InvalidServerResponse(32)),
|
||||
task::block_on(ys)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn short_write_fails_right() {
|
||||
let mut buffer = [0u8; 2];
|
||||
let cmd = ServerResponse::error(ServerResponseStatus::AddressTypeNotSupported);
|
||||
let mut cursor = Cursor::new(&mut buffer as &mut [u8]);
|
||||
let result = task::block_on(cmd.write(&mut cursor));
|
||||
match result {
|
||||
Ok(_) => assert!(false, "Mysteriously able to fit > 2 bytes in 2 bytes."),
|
||||
Err(SerializationError::IOError(x)) => assert_eq!(ErrorKind::WriteZero, x.kind()),
|
||||
Err(e) => assert!(false, "Got the wrong error writing too much data: {}", e),
|
||||
}
|
||||
}
|
||||
33
src/messages/utils.rs
Normal file
33
src/messages/utils.rs
Normal file
@@ -0,0 +1,33 @@
|
||||
#[cfg(test)]
|
||||
use quickcheck::{Arbitrary, Gen};
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn arbitrary_socks_string(g: &mut Gen) -> String {
|
||||
loop {
|
||||
let mut potential = String::arbitrary(g);
|
||||
|
||||
potential.truncate(255);
|
||||
let bytestring = potential.as_bytes();
|
||||
|
||||
if bytestring.len() > 0 && bytestring.len() < 256 {
|
||||
return potential;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
#[macro_export]
|
||||
macro_rules! standard_roundtrip {
|
||||
($name: ident, $t: ty) => {
|
||||
#[cfg(test)]
|
||||
quickcheck! {
|
||||
fn $name(xs: $t) -> bool {
|
||||
let mut buffer = vec![];
|
||||
task::block_on(xs.write(&mut buffer)).unwrap();
|
||||
let mut cursor = Cursor::new(buffer);
|
||||
let ys = <$t>::read(Pin::new(&mut cursor));
|
||||
xs == task::block_on(ys).unwrap()
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
@@ -1,10 +1,15 @@
|
||||
use crate::errors::{DeserializationError, SerializationError};
|
||||
#[cfg(test)]
|
||||
use crate::messages::arbitrary_socks_string;
|
||||
use crate::messages::utils::arbitrary_socks_string;
|
||||
use crate::serialize::{read_amt, read_string, write_string};
|
||||
use crate::standard_roundtrip;
|
||||
#[cfg(test)]
|
||||
use async_std::task;
|
||||
#[cfg(test)]
|
||||
use futures::io::Cursor;
|
||||
use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
#[cfg(test)]
|
||||
use quickcheck::{Arbitrary, Gen};
|
||||
use quickcheck::{quickcheck, Arbitrary, Gen};
|
||||
use std::fmt;
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
||||
use std::pin::Pin;
|
||||
@@ -149,3 +154,5 @@ impl Arbitrary for SOCKSv5Address {
|
||||
.clone()
|
||||
}
|
||||
}
|
||||
|
||||
standard_roundtrip!(address_roundtrips, SOCKSv5Address);
|
||||
|
||||
@@ -310,4 +310,4 @@ where
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user