Switch to basic tokio; will expand later to arbitrary backends.
This commit is contained in:
@@ -1,16 +1,12 @@
|
||||
use crate::errors::{AuthenticationDeserializationError, DeserializationError, SerializationError};
|
||||
use crate::standard_roundtrip;
|
||||
#[cfg(test)]
|
||||
use async_std::task;
|
||||
#[cfg(test)]
|
||||
use futures::io::Cursor;
|
||||
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use proptest::proptest;
|
||||
#[cfg(test)]
|
||||
use proptest::prelude::{Arbitrary, Just, Strategy, prop_oneof};
|
||||
use proptest::prelude::{prop_oneof, Arbitrary, Just, Strategy};
|
||||
#[cfg(test)]
|
||||
use proptest::strategy::BoxedStrategy;
|
||||
use std::fmt;
|
||||
#[cfg(test)]
|
||||
use std::io::Cursor;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
#[allow(clippy::upper_case_acronyms)]
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
@@ -28,6 +24,34 @@ pub enum AuthenticationMethod {
|
||||
NoAcceptableMethods,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Error, PartialEq)]
|
||||
pub enum AuthenticationMethodReadError {
|
||||
#[error("Invalid authentication method #{0}")]
|
||||
UnknownAuthenticationMethod(u8),
|
||||
#[error("Error in underlying buffer: {0}")]
|
||||
ReadError(String),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for AuthenticationMethodReadError {
|
||||
fn from(x: std::io::Error) -> AuthenticationMethodReadError {
|
||||
AuthenticationMethodReadError::ReadError(format!("{}", x))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Error, PartialEq)]
|
||||
pub enum AuthenticationMethodWriteError {
|
||||
#[error("Trying to write invalid authentication method #{0}")]
|
||||
InvalidAuthMethod(u8),
|
||||
#[error("Error in underlying buffer: {0}")]
|
||||
WriteError(String),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for AuthenticationMethodWriteError {
|
||||
fn from(x: std::io::Error) -> AuthenticationMethodWriteError {
|
||||
AuthenticationMethodWriteError::WriteError(format!("{}", x))
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for AuthenticationMethod {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self {
|
||||
@@ -65,26 +89,17 @@ impl Arbitrary for AuthenticationMethod {
|
||||
Just(AuthenticationMethod::MultiAuthenticationFramework),
|
||||
Just(AuthenticationMethod::JSONPropertyBlock),
|
||||
Just(AuthenticationMethod::NoAcceptableMethods),
|
||||
|
||||
(0x80u8..=0xfe).prop_map(AuthenticationMethod::PrivateMethod),
|
||||
].boxed()
|
||||
]
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
impl AuthenticationMethod {
|
||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||
r: &mut R,
|
||||
) -> Result<AuthenticationMethod, DeserializationError> {
|
||||
let mut byte_buffer = [0u8; 1];
|
||||
let amount_read = r.read(&mut byte_buffer).await?;
|
||||
|
||||
if amount_read == 0 {
|
||||
return Err(AuthenticationDeserializationError::NoDataFound.into());
|
||||
}
|
||||
|
||||
match byte_buffer[0] {
|
||||
) -> Result<AuthenticationMethod, AuthenticationMethodReadError> {
|
||||
match r.read_u8().await? {
|
||||
0 => Ok(AuthenticationMethod::None),
|
||||
1 => Ok(AuthenticationMethod::GSSAPI),
|
||||
2 => Ok(AuthenticationMethod::UsernameAndPassword),
|
||||
@@ -96,14 +111,16 @@ impl AuthenticationMethod {
|
||||
9 => Ok(AuthenticationMethod::JSONPropertyBlock),
|
||||
x if (0x80..=0xfe).contains(&x) => Ok(AuthenticationMethod::PrivateMethod(x)),
|
||||
0xff => Ok(AuthenticationMethod::NoAcceptableMethods),
|
||||
e => Err(AuthenticationDeserializationError::InvalidAuthenticationByte(e).into()),
|
||||
e => Err(AuthenticationMethodReadError::UnknownAuthenticationMethod(
|
||||
e,
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||
&self,
|
||||
w: &mut W,
|
||||
) -> Result<(), SerializationError> {
|
||||
) -> Result<(), AuthenticationMethodWriteError> {
|
||||
let value = match self {
|
||||
AuthenticationMethod::None => 0,
|
||||
AuthenticationMethod::GSSAPI => 1,
|
||||
@@ -114,31 +131,32 @@ impl AuthenticationMethod {
|
||||
AuthenticationMethod::NDS => 7,
|
||||
AuthenticationMethod::MultiAuthenticationFramework => 8,
|
||||
AuthenticationMethod::JSONPropertyBlock => 9,
|
||||
AuthenticationMethod::PrivateMethod(pm) => *pm,
|
||||
AuthenticationMethod::PrivateMethod(pm) if (0x80..=0xfe).contains(pm) => *pm,
|
||||
AuthenticationMethod::PrivateMethod(pm) => {
|
||||
return Err(AuthenticationMethodWriteError::InvalidAuthMethod(*pm))
|
||||
}
|
||||
AuthenticationMethod::NoAcceptableMethods => 0xff,
|
||||
};
|
||||
|
||||
Ok(w.write_all(&[value]).await?)
|
||||
Ok(w.write_u8(value).await?)
|
||||
}
|
||||
}
|
||||
|
||||
standard_roundtrip!(auth_byte_roundtrips, AuthenticationMethod);
|
||||
crate::standard_roundtrip!(auth_byte_roundtrips, AuthenticationMethod);
|
||||
|
||||
#[test]
|
||||
fn bad_byte() {
|
||||
#[tokio::test]
|
||||
async fn bad_byte() {
|
||||
let no_len = vec![42];
|
||||
let mut cursor = Cursor::new(no_len);
|
||||
let ys = AuthenticationMethod::read(&mut cursor);
|
||||
let ys = AuthenticationMethod::read(&mut cursor).await.unwrap_err();
|
||||
assert_eq!(
|
||||
Err(DeserializationError::AuthenticationMethodError(
|
||||
AuthenticationDeserializationError::InvalidAuthenticationByte(42)
|
||||
)),
|
||||
task::block_on(ys)
|
||||
AuthenticationMethodReadError::UnknownAuthenticationMethod(42),
|
||||
ys
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display_isnt_empty() {
|
||||
#[tokio::test]
|
||||
async fn display_isnt_empty() {
|
||||
let vals = vec![
|
||||
AuthenticationMethod::None,
|
||||
AuthenticationMethod::GSSAPI,
|
||||
|
||||
@@ -1,20 +1,10 @@
|
||||
use crate::errors::{DeserializationError, SerializationError};
|
||||
use crate::network::SOCKSv5Address;
|
||||
use crate::serialize::read_amt;
|
||||
use crate::standard_roundtrip;
|
||||
#[cfg(test)]
|
||||
use async_std::io::ErrorKind;
|
||||
#[cfg(test)]
|
||||
use async_std::task;
|
||||
#[cfg(test)]
|
||||
use futures::io::Cursor;
|
||||
use futures::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use log::debug;
|
||||
use proptest::proptest;
|
||||
use crate::address::{SOCKSv5Address, SOCKSv5AddressReadError, SOCKSv5AddressWriteError};
|
||||
#[cfg(test)]
|
||||
use proptest_derive::Arbitrary;
|
||||
#[cfg(test)]
|
||||
use std::net::Ipv4Addr;
|
||||
use std::io::Cursor;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
||||
#[cfg_attr(test, derive(Arbitrary))]
|
||||
@@ -24,6 +14,60 @@ pub enum ClientConnectionCommand {
|
||||
AssociateUDPPort,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Error, PartialEq)]
|
||||
pub enum ClientConnectionCommandReadError {
|
||||
#[error("Invalid client connection command code: {0}")]
|
||||
InvalidClientConnectionCommand(u8),
|
||||
#[error("Underlying buffer read error: {0}")]
|
||||
ReadError(String),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for ClientConnectionCommandReadError {
|
||||
fn from(x: std::io::Error) -> ClientConnectionCommandReadError {
|
||||
ClientConnectionCommandReadError::ReadError(format!("{}", x))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Error, PartialEq)]
|
||||
pub enum ClientConnectionCommandWriteError {
|
||||
#[error("Underlying buffer write error: {0}")]
|
||||
WriteError(String),
|
||||
#[error(transparent)]
|
||||
SOCKSAddressWriteError(#[from] SOCKSv5AddressWriteError),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for ClientConnectionCommandWriteError {
|
||||
fn from(x: std::io::Error) -> ClientConnectionCommandWriteError {
|
||||
ClientConnectionCommandWriteError::WriteError(format!("{}", x))
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientConnectionCommand {
|
||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||
r: &mut R,
|
||||
) -> Result<ClientConnectionCommand, ClientConnectionCommandReadError> {
|
||||
match r.read_u8().await? {
|
||||
0x01 => Ok(ClientConnectionCommand::EstablishTCPStream),
|
||||
0x02 => Ok(ClientConnectionCommand::EstablishTCPPortBinding),
|
||||
0x03 => Ok(ClientConnectionCommand::AssociateUDPPort),
|
||||
x => Err(ClientConnectionCommandReadError::InvalidClientConnectionCommand(x)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||
&self,
|
||||
w: &mut W,
|
||||
) -> Result<(), std::io::Error> {
|
||||
match self {
|
||||
ClientConnectionCommand::EstablishTCPStream => w.write_u8(0x01).await,
|
||||
ClientConnectionCommand::EstablishTCPPortBinding => w.write_u8(0x02).await,
|
||||
ClientConnectionCommand::AssociateUDPPort => w.write_u8(0x03).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
crate::standard_roundtrip!(client_command_roundtrips, ClientConnectionCommand);
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
#[cfg_attr(test, derive(Arbitrary))]
|
||||
pub struct ClientConnectionRequest {
|
||||
@@ -32,37 +76,46 @@ pub struct ClientConnectionRequest {
|
||||
pub destination_port: u16,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Error, PartialEq)]
|
||||
pub enum ClientConnectionRequestReadError {
|
||||
#[error("Invalid version in client request: {0} (expected 5)")]
|
||||
InvalidVersion(u8),
|
||||
#[error("Invalid command for client request: {0}")]
|
||||
InvalidCommand(#[from] ClientConnectionCommandReadError),
|
||||
#[error("Invalid reserved byte: {0} (expected 0)")]
|
||||
InvalidReservedByte(u8),
|
||||
#[error("Underlying read error: {0}")]
|
||||
ReadError(String),
|
||||
#[error(transparent)]
|
||||
AddressReadError(#[from] SOCKSv5AddressReadError),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for ClientConnectionRequestReadError {
|
||||
fn from(x: std::io::Error) -> ClientConnectionRequestReadError {
|
||||
ClientConnectionRequestReadError::ReadError(format!("{}", x))
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientConnectionRequest {
|
||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||
r: &mut R,
|
||||
) -> Result<Self, DeserializationError> {
|
||||
let mut buffer = [0; 3];
|
||||
|
||||
debug!("Starting to read request.");
|
||||
read_amt(r, 3, &mut buffer).await?;
|
||||
debug!("Read three opening bytes: {:?}", buffer);
|
||||
if buffer[0] != 5 {
|
||||
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
|
||||
) -> Result<Self, ClientConnectionRequestReadError> {
|
||||
let version = r.read_u8().await?;
|
||||
if version != 5 {
|
||||
return Err(ClientConnectionRequestReadError::InvalidVersion(version));
|
||||
}
|
||||
|
||||
let command_code = match buffer[1] {
|
||||
0x01 => ClientConnectionCommand::EstablishTCPStream,
|
||||
0x02 => ClientConnectionCommand::EstablishTCPPortBinding,
|
||||
0x03 => ClientConnectionCommand::AssociateUDPPort,
|
||||
x => return Err(DeserializationError::InvalidClientCommand(x)),
|
||||
};
|
||||
debug!("Command code: {:?}", command_code);
|
||||
let command_code = ClientConnectionCommand::read(r).await?;
|
||||
|
||||
if buffer[2] != 0 {
|
||||
return Err(DeserializationError::InvalidReservedByte(buffer[2]));
|
||||
let reserved = r.read_u8().await?;
|
||||
if reserved != 0 {
|
||||
return Err(ClientConnectionRequestReadError::InvalidReservedByte(
|
||||
reserved,
|
||||
));
|
||||
}
|
||||
|
||||
let destination_address = SOCKSv5Address::read(r).await?;
|
||||
debug!("Destination address: {}", destination_address);
|
||||
|
||||
read_amt(r, 2, &mut buffer).await?;
|
||||
let destination_port = ((buffer[0] as u16) << 8) + (buffer[1] as u16);
|
||||
debug!("Destination port: {}", destination_port);
|
||||
let destination_port = r.read_u16().await?;
|
||||
|
||||
Ok(ClientConnectionRequest {
|
||||
command_code,
|
||||
@@ -74,63 +127,62 @@ impl ClientConnectionRequest {
|
||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||
&self,
|
||||
w: &mut W,
|
||||
) -> Result<(), SerializationError> {
|
||||
let command = match self.command_code {
|
||||
ClientConnectionCommand::EstablishTCPStream => 1,
|
||||
ClientConnectionCommand::EstablishTCPPortBinding => 2,
|
||||
ClientConnectionCommand::AssociateUDPPort => 3,
|
||||
};
|
||||
|
||||
w.write_all(&[5, command, 0]).await?;
|
||||
) -> Result<(), ClientConnectionCommandWriteError> {
|
||||
w.write_u8(5).await?;
|
||||
self.command_code.write(w).await?;
|
||||
w.write_u8(0).await?;
|
||||
self.destination_address.write(w).await?;
|
||||
w.write_all(&[
|
||||
(self.destination_port >> 8) as u8,
|
||||
(self.destination_port & 0xffu16) as u8,
|
||||
])
|
||||
.await
|
||||
.map_err(SerializationError::IOError)
|
||||
w.write_u16(self.destination_port).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
standard_roundtrip!(client_request_roundtrips, ClientConnectionRequest);
|
||||
crate::standard_roundtrip!(client_request_roundtrips, ClientConnectionRequest);
|
||||
|
||||
#[test]
|
||||
fn check_short_reads() {
|
||||
#[tokio::test]
|
||||
async fn check_short_reads() {
|
||||
let empty = vec![];
|
||||
let mut cursor = Cursor::new(empty);
|
||||
let ys = ClientConnectionRequest::read(&mut cursor);
|
||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
||||
let ys = ClientConnectionRequest::read(&mut cursor).await;
|
||||
assert!(matches!(
|
||||
ys,
|
||||
Err(ClientConnectionRequestReadError::ReadError(_))
|
||||
));
|
||||
|
||||
let no_len = vec![5, 1];
|
||||
let mut cursor = Cursor::new(no_len);
|
||||
let ys = ClientConnectionRequest::read(&mut cursor);
|
||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
||||
let ys = ClientConnectionRequest::read(&mut cursor).await;
|
||||
assert!(matches!(
|
||||
ys,
|
||||
Err(ClientConnectionRequestReadError::ReadError(_))
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_bad_version() {
|
||||
#[tokio::test]
|
||||
async fn check_bad_version() {
|
||||
let bad_ver = vec![6, 1, 1];
|
||||
let mut cursor = Cursor::new(bad_ver);
|
||||
let ys = ClientConnectionRequest::read(&mut cursor);
|
||||
assert_eq!(
|
||||
Err(DeserializationError::InvalidVersion(5, 6)),
|
||||
task::block_on(ys)
|
||||
);
|
||||
let ys = ClientConnectionRequest::read(&mut cursor).await;
|
||||
assert_eq!(Err(ClientConnectionRequestReadError::InvalidVersion(6)), ys);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_bad_command() {
|
||||
#[tokio::test]
|
||||
async fn check_bad_command() {
|
||||
let bad_cmd = vec![5, 32, 1];
|
||||
let mut cursor = Cursor::new(bad_cmd);
|
||||
let ys = ClientConnectionRequest::read(&mut cursor);
|
||||
let ys = ClientConnectionRequest::read(&mut cursor).await;
|
||||
assert_eq!(
|
||||
Err(DeserializationError::InvalidClientCommand(32)),
|
||||
task::block_on(ys)
|
||||
Err(ClientConnectionRequestReadError::InvalidCommand(
|
||||
ClientConnectionCommandReadError::InvalidClientConnectionCommand(32)
|
||||
)),
|
||||
ys
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn short_write_fails_right() {
|
||||
#[tokio::test]
|
||||
async fn short_write_fails_right() {
|
||||
use std::net::Ipv4Addr;
|
||||
|
||||
let mut buffer = [0u8; 2];
|
||||
let cmd = ClientConnectionRequest {
|
||||
command_code: ClientConnectionCommand::AssociateUDPPort,
|
||||
@@ -138,10 +190,12 @@ fn short_write_fails_right() {
|
||||
destination_port: 22,
|
||||
};
|
||||
let mut cursor = Cursor::new(&mut buffer as &mut [u8]);
|
||||
let result = task::block_on(cmd.write(&mut cursor));
|
||||
let result = cmd.write(&mut cursor).await;
|
||||
match result {
|
||||
Ok(_) => panic!("Mysteriously able to fit > 2 bytes in 2 bytes."),
|
||||
Err(SerializationError::IOError(x)) => assert_eq!(ErrorKind::WriteZero, x.kind()),
|
||||
Err(ClientConnectionCommandWriteError::WriteError(x)) => {
|
||||
assert!(x.contains("write zero"));
|
||||
}
|
||||
Err(e) => panic!("Got the wrong error writing too much data: {}", e),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,16 +1,12 @@
|
||||
#[cfg(test)]
|
||||
use crate::errors::AuthenticationDeserializationError;
|
||||
use crate::errors::{DeserializationError, SerializationError};
|
||||
use crate::messages::AuthenticationMethod;
|
||||
use crate::standard_roundtrip;
|
||||
#[cfg(test)]
|
||||
use async_std::task;
|
||||
#[cfg(test)]
|
||||
use futures::io::Cursor;
|
||||
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use proptest::proptest;
|
||||
use crate::messages::authentication_method::{
|
||||
AuthenticationMethod, AuthenticationMethodReadError, AuthenticationMethodWriteError,
|
||||
};
|
||||
#[cfg(test)]
|
||||
use proptest_derive::Arbitrary;
|
||||
#[cfg(test)]
|
||||
use std::io::Cursor;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
/// Client greetings are the first message sent in a SOCKSv5 session. They
|
||||
/// identify that there's a client that wants to talk to a server, and that
|
||||
@@ -23,26 +19,52 @@ pub struct ClientGreeting {
|
||||
pub acceptable_methods: Vec<AuthenticationMethod>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Error, PartialEq)]
|
||||
pub enum ClientGreetingReadError {
|
||||
#[error("Invalid version in client request: {0} (expected 5)")]
|
||||
InvalidVersion(u8),
|
||||
#[error(transparent)]
|
||||
AuthMethodReadError(#[from] AuthenticationMethodReadError),
|
||||
#[error("Underlying read error: {0}")]
|
||||
ReadError(String),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for ClientGreetingReadError {
|
||||
fn from(x: std::io::Error) -> ClientGreetingReadError {
|
||||
ClientGreetingReadError::ReadError(format!("{}", x))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Error, PartialEq)]
|
||||
pub enum ClientGreetingWriteError {
|
||||
#[error("Too many methods provided; need <256, saw {0}")]
|
||||
TooManyMethods(usize),
|
||||
#[error(transparent)]
|
||||
AuthMethodWriteError(#[from] AuthenticationMethodWriteError),
|
||||
#[error("Underlying write error: {0}")]
|
||||
WriteError(String),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for ClientGreetingWriteError {
|
||||
fn from(x: std::io::Error) -> ClientGreetingWriteError {
|
||||
ClientGreetingWriteError::WriteError(format!("{}", x))
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientGreeting {
|
||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||
r: &mut R,
|
||||
) -> Result<ClientGreeting, DeserializationError> {
|
||||
let mut buffer = [0; 1];
|
||||
) -> Result<ClientGreeting, ClientGreetingReadError> {
|
||||
let version = r.read_u8().await?;
|
||||
|
||||
if r.read(&mut buffer).await? == 0 {
|
||||
return Err(DeserializationError::NotEnoughData);
|
||||
if version != 5 {
|
||||
return Err(ClientGreetingReadError::InvalidVersion(version));
|
||||
}
|
||||
|
||||
if buffer[0] != 5 {
|
||||
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
|
||||
}
|
||||
let num_methods = r.read_u8().await? as usize;
|
||||
|
||||
if r.read(&mut buffer).await? == 0 {
|
||||
return Err(DeserializationError::NotEnoughData);
|
||||
}
|
||||
|
||||
let mut acceptable_methods = Vec::with_capacity(buffer[0] as usize);
|
||||
for _ in 0..buffer[0] {
|
||||
let mut acceptable_methods = Vec::with_capacity(num_methods);
|
||||
for _ in 0..num_methods {
|
||||
acceptable_methods.push(AuthenticationMethod::read(r).await?);
|
||||
}
|
||||
|
||||
@@ -52,9 +74,9 @@ impl ClientGreeting {
|
||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||
&self,
|
||||
w: &mut W,
|
||||
) -> Result<(), SerializationError> {
|
||||
) -> Result<(), ClientGreetingWriteError> {
|
||||
if self.acceptable_methods.len() > 255 {
|
||||
return Err(SerializationError::TooManyAuthMethods(
|
||||
return Err(ClientGreetingWriteError::TooManyMethods(
|
||||
self.acceptable_methods.len(),
|
||||
));
|
||||
}
|
||||
@@ -70,44 +92,41 @@ impl ClientGreeting {
|
||||
}
|
||||
}
|
||||
|
||||
standard_roundtrip!(client_greeting_roundtrips, ClientGreeting);
|
||||
crate::standard_roundtrip!(client_greeting_roundtrips, ClientGreeting);
|
||||
|
||||
#[test]
|
||||
fn check_short_reads() {
|
||||
#[tokio::test]
|
||||
async fn check_short_reads() {
|
||||
let empty = vec![];
|
||||
let mut cursor = Cursor::new(empty);
|
||||
let ys = ClientGreeting::read(&mut cursor);
|
||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
||||
let ys = ClientGreeting::read(&mut cursor).await;
|
||||
assert!(matches!(ys, Err(ClientGreetingReadError::ReadError(_))));
|
||||
|
||||
let no_len = vec![5];
|
||||
let mut cursor = Cursor::new(no_len);
|
||||
let ys = ClientGreeting::read(&mut cursor);
|
||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
||||
let ys = ClientGreeting::read(&mut cursor).await;
|
||||
assert!(matches!(ys, Err(ClientGreetingReadError::ReadError(_))));
|
||||
|
||||
let bad_len = vec![5, 9];
|
||||
let mut cursor = Cursor::new(bad_len);
|
||||
let ys = ClientGreeting::read(&mut cursor);
|
||||
assert_eq!(
|
||||
Err(DeserializationError::AuthenticationMethodError(
|
||||
AuthenticationDeserializationError::NoDataFound
|
||||
)),
|
||||
task::block_on(ys)
|
||||
);
|
||||
let ys = ClientGreeting::read(&mut cursor).await;
|
||||
assert!(matches!(
|
||||
ys,
|
||||
Err(ClientGreetingReadError::AuthMethodReadError(
|
||||
AuthenticationMethodReadError::ReadError(_)
|
||||
))
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_bad_version() {
|
||||
#[tokio::test]
|
||||
async fn check_bad_version() {
|
||||
let no_len = vec![6, 1, 1];
|
||||
let mut cursor = Cursor::new(no_len);
|
||||
let ys = ClientGreeting::read(&mut cursor);
|
||||
assert_eq!(
|
||||
Err(DeserializationError::InvalidVersion(5, 6)),
|
||||
task::block_on(ys)
|
||||
);
|
||||
let ys = ClientGreeting::read(&mut cursor).await;
|
||||
assert_eq!(Err(ClientGreetingReadError::InvalidVersion(6)), ys);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_too_many() {
|
||||
#[tokio::test]
|
||||
async fn check_too_many() {
|
||||
let mut auth_methods = Vec::with_capacity(512);
|
||||
auth_methods.resize(512, AuthenticationMethod::ChallengeHandshake);
|
||||
let greet = ClientGreeting {
|
||||
@@ -115,7 +134,7 @@ fn check_too_many() {
|
||||
};
|
||||
let mut output = vec![0; 1024];
|
||||
assert_eq!(
|
||||
Err(SerializationError::TooManyAuthMethods(512)),
|
||||
task::block_on(greet.write(&mut output))
|
||||
Err(ClientGreetingWriteError::TooManyMethods(512)),
|
||||
greet.write(&mut output).await
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,16 +1,10 @@
|
||||
use crate::errors::{DeserializationError, SerializationError};
|
||||
use crate::serialize::{read_string, write_string};
|
||||
use crate::standard_roundtrip;
|
||||
use crate::messages::string::{SOCKSv5String, SOCKSv5StringReadError, SOCKSv5StringWriteError};
|
||||
#[cfg(test)]
|
||||
use async_std::task;
|
||||
use proptest::prelude::{Arbitrary, BoxedStrategy, Strategy};
|
||||
#[cfg(test)]
|
||||
use futures::io::Cursor;
|
||||
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
#[cfg(test)]
|
||||
use proptest::prelude::{Arbitrary, BoxedStrategy};
|
||||
use proptest::proptest;
|
||||
#[cfg(test)]
|
||||
use proptest::strategy::Strategy;
|
||||
use std::io::Cursor;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub struct ClientUsernamePassword {
|
||||
@@ -30,30 +24,58 @@ impl Arbitrary for ClientUsernamePassword {
|
||||
|
||||
fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
|
||||
let max_len = args.unwrap_or(12) as usize;
|
||||
(USERNAME_REGEX, PASSWORD_REGEX).prop_map(move |(mut username, mut password)| {
|
||||
username.shrink_to(max_len);
|
||||
password.shrink_to(max_len);
|
||||
ClientUsernamePassword { username, password }
|
||||
}).boxed()
|
||||
(USERNAME_REGEX, PASSWORD_REGEX)
|
||||
.prop_map(move |(mut username, mut password)| {
|
||||
username.shrink_to(max_len);
|
||||
password.shrink_to(max_len);
|
||||
ClientUsernamePassword { username, password }
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Error, PartialEq)]
|
||||
pub enum ClientUsernamePasswordReadError {
|
||||
#[error("Underlying buffer read error: {0}")]
|
||||
ReadError(String),
|
||||
#[error("Invalid username/password version; expected 1, saw {0}")]
|
||||
InvalidVersion(u8),
|
||||
#[error(transparent)]
|
||||
StringError(#[from] SOCKSv5StringReadError),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for ClientUsernamePasswordReadError {
|
||||
fn from(x: std::io::Error) -> ClientUsernamePasswordReadError {
|
||||
ClientUsernamePasswordReadError::ReadError(format!("{}", x))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Error, PartialEq)]
|
||||
pub enum ClientUsernamePasswordWriteError {
|
||||
#[error("Underlying buffer read error: {0}")]
|
||||
WriteError(String),
|
||||
#[error(transparent)]
|
||||
StringError(#[from] SOCKSv5StringWriteError),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for ClientUsernamePasswordWriteError {
|
||||
fn from(x: std::io::Error) -> ClientUsernamePasswordWriteError {
|
||||
ClientUsernamePasswordWriteError::WriteError(format!("{}", x))
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientUsernamePassword {
|
||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||
r: &mut R,
|
||||
) -> Result<Self, DeserializationError> {
|
||||
let mut buffer = [0; 1];
|
||||
) -> Result<Self, ClientUsernamePasswordReadError> {
|
||||
let version = r.read_u8().await?;
|
||||
|
||||
if r.read(&mut buffer).await? == 0 {
|
||||
return Err(DeserializationError::NotEnoughData);
|
||||
if version != 1 {
|
||||
return Err(ClientUsernamePasswordReadError::InvalidVersion(version));
|
||||
}
|
||||
|
||||
if buffer[0] != 1 {
|
||||
return Err(DeserializationError::InvalidVersion(1, buffer[0]));
|
||||
}
|
||||
|
||||
let username = read_string(r).await?;
|
||||
let password = read_string(r).await?;
|
||||
let username = SOCKSv5String::read(r).await?.into();
|
||||
let password = SOCKSv5String::read(r).await?.into();
|
||||
|
||||
Ok(ClientUsernamePassword { username, password })
|
||||
}
|
||||
@@ -61,35 +83,40 @@ impl ClientUsernamePassword {
|
||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||
&self,
|
||||
w: &mut W,
|
||||
) -> Result<(), SerializationError> {
|
||||
w.write_all(&[1]).await?;
|
||||
write_string(&self.username, w).await?;
|
||||
write_string(&self.password, w).await
|
||||
) -> Result<(), ClientUsernamePasswordWriteError> {
|
||||
w.write_u8(1).await?;
|
||||
SOCKSv5String::from(self.username.as_str()).write(w).await?;
|
||||
SOCKSv5String::from(self.password.as_str()).write(w).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
standard_roundtrip!(username_password_roundtrips, ClientUsernamePassword);
|
||||
crate::standard_roundtrip!(username_password_roundtrips, ClientUsernamePassword);
|
||||
|
||||
#[test]
|
||||
fn check_short_reads() {
|
||||
#[tokio::test]
|
||||
async fn heck_short_reads() {
|
||||
let empty = vec![];
|
||||
let mut cursor = Cursor::new(empty);
|
||||
let ys = ClientUsernamePassword::read(&mut cursor);
|
||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
||||
let ys = ClientUsernamePassword::read(&mut cursor).await;
|
||||
assert!(matches!(
|
||||
ys,
|
||||
Err(ClientUsernamePasswordReadError::ReadError(_))
|
||||
));
|
||||
|
||||
let user_only = vec![1, 3, 102, 111, 111];
|
||||
let mut cursor = Cursor::new(user_only);
|
||||
let ys = ClientUsernamePassword::read(&mut cursor);
|
||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
||||
let ys = ClientUsernamePassword::read(&mut cursor).await;
|
||||
println!("ys: {:?}", ys);
|
||||
assert!(matches!(
|
||||
ys,
|
||||
Err(ClientUsernamePasswordReadError::StringError(_))
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_bad_version() {
|
||||
#[tokio::test]
|
||||
async fn check_bad_version() {
|
||||
let bad_len = vec![5];
|
||||
let mut cursor = Cursor::new(bad_len);
|
||||
let ys = ClientUsernamePassword::read(&mut cursor);
|
||||
assert_eq!(
|
||||
Err(DeserializationError::InvalidVersion(1, 5)),
|
||||
task::block_on(ys)
|
||||
);
|
||||
let ys = ClientUsernamePassword::read(&mut cursor).await;
|
||||
assert_eq!(Err(ClientUsernamePasswordReadError::InvalidVersion(5)), ys);
|
||||
}
|
||||
|
||||
@@ -1,13 +1,7 @@
|
||||
use crate::errors::{DeserializationError, SerializationError};
|
||||
use crate::standard_roundtrip;
|
||||
#[cfg(test)]
|
||||
use async_std::task;
|
||||
#[cfg(test)]
|
||||
use futures::io::Cursor;
|
||||
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use proptest::proptest;
|
||||
#[cfg(test)]
|
||||
use proptest_derive::Arbitrary;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
#[cfg_attr(test, derive(Arbitrary))]
|
||||
@@ -15,6 +9,32 @@ pub struct ServerAuthResponse {
|
||||
pub success: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Error, PartialEq)]
|
||||
pub enum ServerAuthResponseReadError {
|
||||
#[error("Underlying buffer read error: {0}")]
|
||||
ReadError(String),
|
||||
#[error("Invalid username/password version; expected 1, saw {0}")]
|
||||
InvalidVersion(u8),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for ServerAuthResponseReadError {
|
||||
fn from(x: std::io::Error) -> ServerAuthResponseReadError {
|
||||
ServerAuthResponseReadError::ReadError(format!("{}", x))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Error, PartialEq)]
|
||||
pub enum ServerAuthResponseWriteError {
|
||||
#[error("Underlying buffer read error: {0}")]
|
||||
WriteError(String),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for ServerAuthResponseWriteError {
|
||||
fn from(x: std::io::Error) -> ServerAuthResponseWriteError {
|
||||
ServerAuthResponseWriteError::WriteError(format!("{}", x))
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerAuthResponse {
|
||||
pub fn success() -> ServerAuthResponse {
|
||||
ServerAuthResponse { success: true }
|
||||
@@ -26,30 +46,22 @@ impl ServerAuthResponse {
|
||||
|
||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||
r: &mut R,
|
||||
) -> Result<Self, DeserializationError> {
|
||||
let mut buffer = [0; 1];
|
||||
) -> Result<Self, ServerAuthResponseReadError> {
|
||||
let version = r.read_u8().await?;
|
||||
|
||||
if r.read(&mut buffer).await? == 0 {
|
||||
return Err(DeserializationError::NotEnoughData);
|
||||
}
|
||||
|
||||
if buffer[0] != 1 {
|
||||
return Err(DeserializationError::InvalidVersion(1, buffer[0]));
|
||||
}
|
||||
|
||||
if r.read(&mut buffer).await? == 0 {
|
||||
return Err(DeserializationError::NotEnoughData);
|
||||
if version != 1 {
|
||||
return Err(ServerAuthResponseReadError::InvalidVersion(version));
|
||||
}
|
||||
|
||||
Ok(ServerAuthResponse {
|
||||
success: buffer[0] == 0,
|
||||
success: r.read_u8().await? == 0,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||
&self,
|
||||
w: &mut W,
|
||||
) -> Result<(), SerializationError> {
|
||||
) -> Result<(), ServerAuthResponseWriteError> {
|
||||
w.write_all(&[1]).await?;
|
||||
w.write_all(&[if self.success { 0x00 } else { 0xde }])
|
||||
.await?;
|
||||
@@ -57,28 +69,29 @@ impl ServerAuthResponse {
|
||||
}
|
||||
}
|
||||
|
||||
standard_roundtrip!(server_auth_response, ServerAuthResponse);
|
||||
crate::standard_roundtrip!(server_auth_response, ServerAuthResponse);
|
||||
|
||||
#[tokio::test]
|
||||
async fn check_short_reads() {
|
||||
use std::io::Cursor;
|
||||
|
||||
#[test]
|
||||
fn check_short_reads() {
|
||||
let empty = vec![];
|
||||
let mut cursor = Cursor::new(empty);
|
||||
let ys = ServerAuthResponse::read(&mut cursor);
|
||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
||||
let ys = ServerAuthResponse::read(&mut cursor).await;
|
||||
assert!(matches!(ys, Err(ServerAuthResponseReadError::ReadError(_))));
|
||||
|
||||
let no_len = vec![1];
|
||||
let mut cursor = Cursor::new(no_len);
|
||||
let ys = ServerAuthResponse::read(&mut cursor);
|
||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
||||
let ys = ServerAuthResponse::read(&mut cursor).await;
|
||||
assert!(matches!(ys, Err(ServerAuthResponseReadError::ReadError(_))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_bad_version() {
|
||||
#[tokio::test]
|
||||
async fn check_bad_version() {
|
||||
use std::io::Cursor;
|
||||
|
||||
let no_len = vec![6, 1];
|
||||
let mut cursor = Cursor::new(no_len);
|
||||
let ys = ServerAuthResponse::read(&mut cursor);
|
||||
assert_eq!(
|
||||
Err(DeserializationError::InvalidVersion(1, 6)),
|
||||
task::block_on(ys)
|
||||
);
|
||||
let ys = ServerAuthResponse::read(&mut cursor).await;
|
||||
assert_eq!(Err(ServerAuthResponseReadError::InvalidVersion(6)), ys);
|
||||
}
|
||||
|
||||
@@ -1,16 +1,12 @@
|
||||
#[cfg(test)]
|
||||
use crate::errors::AuthenticationDeserializationError;
|
||||
use crate::errors::{DeserializationError, SerializationError};
|
||||
use crate::messages::AuthenticationMethod;
|
||||
use crate::standard_roundtrip;
|
||||
#[cfg(test)]
|
||||
use async_std::task;
|
||||
#[cfg(test)]
|
||||
use futures::io::Cursor;
|
||||
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use proptest::proptest;
|
||||
use crate::messages::authentication_method::{
|
||||
AuthenticationMethod, AuthenticationMethodReadError, AuthenticationMethodWriteError,
|
||||
};
|
||||
#[cfg(test)]
|
||||
use proptest_derive::Arbitrary;
|
||||
#[cfg(test)]
|
||||
use std::io::Cursor;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
#[cfg_attr(test, derive(Arbitrary))]
|
||||
@@ -18,6 +14,36 @@ pub struct ServerChoice {
|
||||
pub chosen_method: AuthenticationMethod,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Error, PartialEq)]
|
||||
pub enum ServerChoiceReadError {
|
||||
#[error(transparent)]
|
||||
AuthMethodError(#[from] AuthenticationMethodReadError),
|
||||
#[error("Error in underlying buffer: {0}")]
|
||||
ReadError(String),
|
||||
#[error("Invalid version; expected 5, got {0}")]
|
||||
InvalidVersion(u8),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for ServerChoiceReadError {
|
||||
fn from(x: std::io::Error) -> ServerChoiceReadError {
|
||||
ServerChoiceReadError::ReadError(format!("{}", x))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Error, PartialEq)]
|
||||
pub enum ServerChoiceWriteError {
|
||||
#[error(transparent)]
|
||||
AuthMethodError(#[from] AuthenticationMethodWriteError),
|
||||
#[error("Error in underlying buffer: {0}")]
|
||||
WriteError(String),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for ServerChoiceWriteError {
|
||||
fn from(x: std::io::Error) -> ServerChoiceWriteError {
|
||||
ServerChoiceWriteError::WriteError(format!("{}", x))
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerChoice {
|
||||
pub fn rejection() -> ServerChoice {
|
||||
ServerChoice {
|
||||
@@ -33,15 +59,11 @@ impl ServerChoice {
|
||||
|
||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||
r: &mut R,
|
||||
) -> Result<Self, DeserializationError> {
|
||||
let mut buffer = [0; 1];
|
||||
) -> Result<Self, ServerChoiceReadError> {
|
||||
let version = r.read_u8().await?;
|
||||
|
||||
if r.read(&mut buffer).await? == 0 {
|
||||
return Err(DeserializationError::NotEnoughData);
|
||||
}
|
||||
|
||||
if buffer[0] != 5 {
|
||||
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
|
||||
if version != 5 {
|
||||
return Err(ServerChoiceReadError::InvalidVersion(version));
|
||||
}
|
||||
|
||||
let chosen_method = AuthenticationMethod::read(r).await?;
|
||||
@@ -52,39 +74,32 @@ impl ServerChoice {
|
||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||
&self,
|
||||
w: &mut W,
|
||||
) -> Result<(), SerializationError> {
|
||||
w.write_all(&[5]).await?;
|
||||
self.chosen_method.write(w).await
|
||||
) -> Result<(), ServerChoiceWriteError> {
|
||||
w.write_u8(5).await?;
|
||||
self.chosen_method.write(w).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
standard_roundtrip!(server_choice_roundtrips, ServerChoice);
|
||||
crate::standard_roundtrip!(server_choice_roundtrips, ServerChoice);
|
||||
|
||||
#[test]
|
||||
fn check_short_reads() {
|
||||
#[tokio::test]
|
||||
async fn check_short_reads() {
|
||||
let empty = vec![];
|
||||
let mut cursor = Cursor::new(empty);
|
||||
let ys = ServerChoice::read(&mut cursor);
|
||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
||||
let ys = ServerChoice::read(&mut cursor).await;
|
||||
assert!(matches!(ys, Err(ServerChoiceReadError::ReadError(_))));
|
||||
|
||||
let bad_len = vec![5];
|
||||
let mut cursor = Cursor::new(bad_len);
|
||||
let ys = ServerChoice::read(&mut cursor);
|
||||
assert_eq!(
|
||||
Err(DeserializationError::AuthenticationMethodError(
|
||||
AuthenticationDeserializationError::NoDataFound
|
||||
)),
|
||||
task::block_on(ys)
|
||||
);
|
||||
let ys = ServerChoice::read(&mut cursor).await;
|
||||
assert!(matches!(ys, Err(ServerChoiceReadError::AuthMethodError(_))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_bad_version() {
|
||||
#[tokio::test]
|
||||
async fn check_bad_version() {
|
||||
let no_len = vec![9, 1];
|
||||
let mut cursor = Cursor::new(no_len);
|
||||
let ys = ServerChoice::read(&mut cursor);
|
||||
assert_eq!(
|
||||
Err(DeserializationError::InvalidVersion(5, 9)),
|
||||
task::block_on(ys)
|
||||
);
|
||||
let ys = ServerChoice::read(&mut cursor).await;
|
||||
assert_eq!(Err(ServerChoiceReadError::InvalidVersion(9)), ys);
|
||||
}
|
||||
|
||||
@@ -1,21 +1,10 @@
|
||||
use crate::errors::{DeserializationError, SerializationError};
|
||||
use crate::network::generic::IntoErrorResponse;
|
||||
use crate::network::SOCKSv5Address;
|
||||
use crate::serialize::read_amt;
|
||||
use crate::standard_roundtrip;
|
||||
#[cfg(test)]
|
||||
use async_std::io::ErrorKind;
|
||||
#[cfg(test)]
|
||||
use async_std::task;
|
||||
#[cfg(test)]
|
||||
use futures::io::Cursor;
|
||||
use futures::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use log::warn;
|
||||
use proptest::proptest;
|
||||
use crate::address::{SOCKSv5Address, SOCKSv5AddressReadError, SOCKSv5AddressWriteError};
|
||||
#[cfg(test)]
|
||||
use proptest_derive::Arbitrary;
|
||||
use std::net::Ipv4Addr;
|
||||
#[cfg(test)]
|
||||
use std::io::Cursor;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
#[derive(Clone, Debug, Eq, Error, PartialEq)]
|
||||
#[cfg_attr(test, derive(Arbitrary))]
|
||||
@@ -40,12 +29,6 @@ pub enum ServerResponseStatus {
|
||||
AddressTypeNotSupported,
|
||||
}
|
||||
|
||||
impl IntoErrorResponse for ServerResponseStatus {
|
||||
fn into_response(&self) -> ServerResponseStatus {
|
||||
self.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
#[cfg_attr(test, derive(Arbitrary))]
|
||||
pub struct ServerResponse {
|
||||
@@ -54,33 +37,57 @@ pub struct ServerResponse {
|
||||
pub bound_port: u16,
|
||||
}
|
||||
|
||||
impl ServerResponse {
|
||||
pub fn error<E: IntoErrorResponse>(resp: &E) -> ServerResponse {
|
||||
ServerResponse {
|
||||
status: resp.into_response(),
|
||||
bound_address: SOCKSv5Address::IP4(Ipv4Addr::new(0, 0, 0, 0)),
|
||||
bound_port: 0,
|
||||
}
|
||||
#[derive(Clone, Debug, Error, PartialEq)]
|
||||
pub enum ServerResponseReadError {
|
||||
#[error("Error reading from underlying buffer: {0}")]
|
||||
ReadError(String),
|
||||
#[error(transparent)]
|
||||
AddressReadError(#[from] SOCKSv5AddressReadError),
|
||||
#[error("Invalid version; expected 5, got {0}")]
|
||||
InvalidVersion(u8),
|
||||
#[error("Invalid reserved byte; saw {0}, should be 0")]
|
||||
InvalidReservedByte(u8),
|
||||
#[error("Invalid (or just unknown) server response value {0}")]
|
||||
InvalidServerResponse(u8),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for ServerResponseReadError {
|
||||
fn from(x: std::io::Error) -> ServerResponseReadError {
|
||||
ServerResponseReadError::ReadError(format!("{}", x))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Error, PartialEq)]
|
||||
pub enum ServerResponseWriteError {
|
||||
#[error("Error reading from underlying buffer: {0}")]
|
||||
WriteError(String),
|
||||
#[error(transparent)]
|
||||
AddressWriteError(#[from] SOCKSv5AddressWriteError),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for ServerResponseWriteError {
|
||||
fn from(x: std::io::Error) -> ServerResponseWriteError {
|
||||
ServerResponseWriteError::WriteError(format!("{}", x))
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerResponse {
|
||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||
r: &mut R,
|
||||
) -> Result<Self, DeserializationError> {
|
||||
let mut buffer = [0; 3];
|
||||
|
||||
read_amt(r, 3, &mut buffer).await?;
|
||||
|
||||
if buffer[0] != 5 {
|
||||
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
|
||||
) -> Result<Self, ServerResponseReadError> {
|
||||
let version = r.read_u8().await?;
|
||||
if version != 5 {
|
||||
return Err(ServerResponseReadError::InvalidVersion(version));
|
||||
}
|
||||
|
||||
if buffer[2] != 0 {
|
||||
warn!(target: "async-socks5", "Hey, this isn't terrible, but the server is sending invalid reserved bytes.");
|
||||
let status_byte = r.read_u8().await?;
|
||||
|
||||
let reserved_byte = r.read_u8().await?;
|
||||
if reserved_byte != 0 {
|
||||
return Err(ServerResponseReadError::InvalidReservedByte(reserved_byte));
|
||||
}
|
||||
|
||||
let status = match buffer[1] {
|
||||
let status = match status_byte {
|
||||
0x00 => ServerResponseStatus::RequestGranted,
|
||||
0x01 => ServerResponseStatus::GeneralFailure,
|
||||
0x02 => ServerResponseStatus::ConnectionNotAllowedByRule,
|
||||
@@ -90,12 +97,11 @@ impl ServerResponse {
|
||||
0x06 => ServerResponseStatus::TTLExpired,
|
||||
0x07 => ServerResponseStatus::CommandNotSupported,
|
||||
0x08 => ServerResponseStatus::AddressTypeNotSupported,
|
||||
x => return Err(DeserializationError::InvalidServerResponse(x)),
|
||||
x => return Err(ServerResponseReadError::InvalidServerResponse(x)),
|
||||
};
|
||||
|
||||
let bound_address = SOCKSv5Address::read(r).await?;
|
||||
read_amt(r, 2, &mut buffer).await?;
|
||||
let bound_port = ((buffer[0] as u16) << 8) + (buffer[1] as u16);
|
||||
let bound_port = r.read_u16().await?;
|
||||
|
||||
Ok(ServerResponse {
|
||||
status,
|
||||
@@ -107,7 +113,9 @@ impl ServerResponse {
|
||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||
&self,
|
||||
w: &mut W,
|
||||
) -> Result<(), SerializationError> {
|
||||
) -> Result<(), ServerResponseWriteError> {
|
||||
w.write_u8(5).await?;
|
||||
|
||||
let status_code = match self.status {
|
||||
ServerResponseStatus::RequestGranted => 0x00,
|
||||
ServerResponseStatus::GeneralFailure => 0x01,
|
||||
@@ -119,59 +127,61 @@ impl ServerResponse {
|
||||
ServerResponseStatus::CommandNotSupported => 0x07,
|
||||
ServerResponseStatus::AddressTypeNotSupported => 0x08,
|
||||
};
|
||||
|
||||
w.write_all(&[5, status_code, 0]).await?;
|
||||
w.write_u8(status_code).await?;
|
||||
w.write_u8(0).await?;
|
||||
self.bound_address.write(w).await?;
|
||||
w.write_all(&[
|
||||
(self.bound_port >> 8) as u8,
|
||||
(self.bound_port & 0xffu16) as u8,
|
||||
])
|
||||
.await
|
||||
.map_err(SerializationError::IOError)
|
||||
w.write_u16(self.bound_port).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
standard_roundtrip!(server_response_roundtrips, ServerResponse);
|
||||
crate::standard_roundtrip!(server_response_roundtrips, ServerResponse);
|
||||
|
||||
#[test]
|
||||
fn check_short_reads() {
|
||||
#[tokio::test]
|
||||
async fn check_short_reads() {
|
||||
let empty = vec![];
|
||||
let mut cursor = Cursor::new(empty);
|
||||
let ys = ServerResponse::read(&mut cursor);
|
||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
||||
let ys = ServerResponse::read(&mut cursor).await;
|
||||
assert!(matches!(ys, Err(ServerResponseReadError::ReadError(_))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_bad_version() {
|
||||
#[tokio::test]
|
||||
async fn check_bad_version() {
|
||||
let bad_ver = vec![6, 1, 1];
|
||||
let mut cursor = Cursor::new(bad_ver);
|
||||
let ys = ServerResponse::read(&mut cursor);
|
||||
assert_eq!(
|
||||
Err(DeserializationError::InvalidVersion(5, 6)),
|
||||
task::block_on(ys)
|
||||
);
|
||||
let ys = ServerResponse::read(&mut cursor).await;
|
||||
assert_eq!(Err(ServerResponseReadError::InvalidVersion(6)), ys);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_bad_command() {
|
||||
#[tokio::test]
|
||||
async fn check_bad_reserved() {
|
||||
let bad_cmd = vec![5, 32, 0x42];
|
||||
let mut cursor = Cursor::new(bad_cmd);
|
||||
let ys = ServerResponse::read(&mut cursor);
|
||||
assert_eq!(
|
||||
Err(DeserializationError::InvalidServerResponse(32)),
|
||||
task::block_on(ys)
|
||||
);
|
||||
let ys = ServerResponse::read(&mut cursor).await;
|
||||
assert_eq!(Err(ServerResponseReadError::InvalidReservedByte(0x42)), ys);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn short_write_fails_right() {
|
||||
let mut buffer = [0u8; 2];
|
||||
let cmd = ServerResponse::error(&ServerResponseStatus::AddressTypeNotSupported);
|
||||
let mut cursor = Cursor::new(&mut buffer as &mut [u8]);
|
||||
let result = task::block_on(cmd.write(&mut cursor));
|
||||
match result {
|
||||
Ok(_) => panic!("Mysteriously able to fit > 2 bytes in 2 bytes."),
|
||||
Err(SerializationError::IOError(x)) => assert_eq!(ErrorKind::WriteZero, x.kind()),
|
||||
Err(e) => panic!("Got the wrong error writing too much data: {}", e),
|
||||
}
|
||||
#[tokio::test]
|
||||
async fn check_bad_command() {
|
||||
let bad_cmd = vec![5, 32, 0];
|
||||
let mut cursor = Cursor::new(bad_cmd);
|
||||
let ys = ServerResponse::read(&mut cursor).await;
|
||||
assert_eq!(Err(ServerResponseReadError::InvalidServerResponse(32)), ys);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn short_write_fails_right() {
|
||||
let mut buffer = [0u8; 2];
|
||||
let cmd = ServerResponse {
|
||||
status: ServerResponseStatus::AddressTypeNotSupported,
|
||||
bound_address: SOCKSv5Address::Hostname("tester.com".to_string()),
|
||||
bound_port: 99,
|
||||
};
|
||||
let mut cursor = Cursor::new(&mut buffer as &mut [u8]);
|
||||
let result = cmd.write(&mut cursor).await;
|
||||
assert!(matches!(
|
||||
result,
|
||||
Err(ServerResponseWriteError::WriteError(_))
|
||||
));
|
||||
}
|
||||
|
||||
117
src/messages/string.rs
Normal file
117
src/messages/string.rs
Normal file
@@ -0,0 +1,117 @@
|
||||
#[cfg(test)]
|
||||
use proptest::prelude::{Arbitrary, BoxedStrategy, Strategy};
|
||||
use std::convert::TryFrom;
|
||||
use std::string::FromUtf8Error;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub struct SOCKSv5String(String);
|
||||
|
||||
#[cfg(test)]
|
||||
const STRING_REGEX: &str = "[a-zA-Z0-9_.|!@#$%^]+";
|
||||
|
||||
#[cfg(test)]
|
||||
impl Arbitrary for SOCKSv5String {
|
||||
type Parameters = Option<u16>;
|
||||
type Strategy = BoxedStrategy<Self>;
|
||||
|
||||
fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
|
||||
let max_len = args.unwrap_or(32) as usize;
|
||||
|
||||
STRING_REGEX
|
||||
.prop_map(move |mut str| {
|
||||
str.shrink_to(max_len);
|
||||
SOCKSv5String(str)
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Error, PartialEq)]
|
||||
pub enum SOCKSv5StringReadError {
|
||||
#[error("Underlying buffer read error: {0}")]
|
||||
ReadError(String),
|
||||
#[error("SOCKSv5 string encoding error; encountered empty string (?)")]
|
||||
ZeroStringLength,
|
||||
#[error("Invalid UTF-8 string: {0}")]
|
||||
InvalidUtf8Error(#[from] FromUtf8Error),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for SOCKSv5StringReadError {
|
||||
fn from(x: std::io::Error) -> SOCKSv5StringReadError {
|
||||
SOCKSv5StringReadError::ReadError(format!("{}", x))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Error, PartialEq)]
|
||||
pub enum SOCKSv5StringWriteError {
|
||||
#[error("Underlying buffer write error: {0}")]
|
||||
WriteError(String),
|
||||
#[error("String too large to encode according to SOCKSv5 reuls ({0} bytes long)")]
|
||||
TooBig(usize),
|
||||
#[error("Cannot serialize the empty string in SOCKSv5")]
|
||||
ZeroStringLength,
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for SOCKSv5StringWriteError {
|
||||
fn from(x: std::io::Error) -> SOCKSv5StringWriteError {
|
||||
SOCKSv5StringWriteError::WriteError(format!("{}", x))
|
||||
}
|
||||
}
|
||||
|
||||
impl SOCKSv5String {
|
||||
pub async fn read<R: AsyncRead + Unpin>(r: &mut R) -> Result<Self, SOCKSv5StringReadError> {
|
||||
let length = r.read_u8().await? as usize;
|
||||
|
||||
if length == 0 {
|
||||
return Err(SOCKSv5StringReadError::ZeroStringLength);
|
||||
}
|
||||
|
||||
let mut bytestring = vec![0; length];
|
||||
r.read_exact(&mut bytestring).await?;
|
||||
|
||||
Ok(SOCKSv5String(String::from_utf8(bytestring)?))
|
||||
}
|
||||
|
||||
pub async fn write<W: AsyncWrite + Unpin>(
|
||||
&self,
|
||||
w: &mut W,
|
||||
) -> Result<(), SOCKSv5StringWriteError> {
|
||||
let bytestring = self.0.as_bytes();
|
||||
|
||||
if bytestring.is_empty() {
|
||||
return Err(SOCKSv5StringWriteError::ZeroStringLength);
|
||||
}
|
||||
|
||||
let length = match u8::try_from(bytestring.len()) {
|
||||
Err(_) => return Err(SOCKSv5StringWriteError::TooBig(bytestring.len())),
|
||||
Ok(x) => x,
|
||||
};
|
||||
|
||||
w.write_u8(length).await?;
|
||||
w.write_all(bytestring).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for SOCKSv5String {
|
||||
fn from(x: String) -> Self {
|
||||
SOCKSv5String(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a str> for SOCKSv5String {
|
||||
fn from(x: &str) -> Self {
|
||||
SOCKSv5String(x.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SOCKSv5String> for String {
|
||||
fn from(x: SOCKSv5String) -> Self {
|
||||
x.0
|
||||
}
|
||||
}
|
||||
|
||||
crate::standard_roundtrip!(socks_string_roundtrips, SOCKSv5String);
|
||||
@@ -1,16 +0,0 @@
|
||||
#[doc(hidden)]
|
||||
#[macro_export]
|
||||
macro_rules! standard_roundtrip {
|
||||
($name: ident, $t: ty) => {
|
||||
proptest! {
|
||||
#[test]
|
||||
fn $name(xs: $t) {
|
||||
let mut buffer = vec![];
|
||||
task::block_on(xs.write(&mut buffer)).unwrap();
|
||||
let mut cursor = Cursor::new(buffer);
|
||||
let ys = <$t>::read(&mut cursor);
|
||||
assert_eq!(xs, task::block_on(ys).unwrap());
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
Reference in New Issue
Block a user