checkpoint

This commit is contained in:
Adam Wick
2021-06-24 19:18:16 -07:00
commit 1bf6f62d4e
16 changed files with 1789 additions and 0 deletions

30
src/bin/socks-server.rs Normal file
View File

@@ -0,0 +1,30 @@
use async_socks5::network::Builtin;
use async_socks5::server::{SOCKSv5Server, SecurityParameters};
use async_std::io;
use async_std::net::TcpListener;
use simplelog::{ColorChoice, CombinedLogger, Config, LevelFilter, TermLogger, TerminalMode};
#[async_std::main]
async fn main() -> Result<(), io::Error> {
CombinedLogger::init(vec![TermLogger::new(
LevelFilter::Debug,
Config::default(),
TerminalMode::Mixed,
ColorChoice::Auto,
)])
.expect("Couldn't initialize logger");
let main_listener = TcpListener::bind("127.0.0.1:0").await?;
let params = SecurityParameters {
allow_unauthenticated: false,
allow_connection: None,
check_password: None,
connect_tls: None,
};
let server = SOCKSv5Server::new(Builtin::new(), params, main_listener);
server.run().await?;
Ok(())
}

133
src/client.rs Normal file
View File

@@ -0,0 +1,133 @@
use crate::errors::{DeserializationError, SerializationError};
use crate::messages::{
AuthenticationMethod, ClientConnectionCommand, ClientConnectionRequest, ClientGreeting,
ClientUsernamePassword, ServerAuthResponse, ServerChoice, ServerResponse, ServerResponseStatus,
};
use crate::network::{Network, SOCKSv5Address};
use async_std::net::IpAddr;
use futures::io::{AsyncRead, AsyncWrite};
use std::pin::Pin;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum SOCKSv5Error {
#[error("SOCKSv5 serialization error: {0}")]
SerializationError(#[from] SerializationError),
#[error("SOCKSv5 deserialization error: {0}")]
DeserializationError(#[from] DeserializationError),
#[error("No acceptable authentication methods available")]
NoAuthMethodsAllowed,
#[error("Authentication failed")]
AuthenticationFailed,
#[error("Server chose an unsupported authentication method ({0}")]
UnsupportedAuthMethodChosen(AuthenticationMethod),
#[error("Server said no: {0}")]
ServerFailure(#[from] ServerResponseStatus),
}
pub struct SOCKSv5Client<S, N>
where
S: AsyncRead + AsyncWrite,
N: Network,
{
_network: N,
stream: S,
}
pub struct LoginInfo {
username_password: Option<UsernamePassword>,
}
pub struct UsernamePassword {
username: String,
password: String,
}
impl<S, N> SOCKSv5Client<S, N>
where
S: AsyncRead + AsyncWrite + Send + Unpin,
N: Network,
{
/// Create a new SOCKSv5 client connection over the given steam, using the given
/// authentication information.
pub async fn new(_network: N, mut stream: S, login: &LoginInfo) -> Result<Self, SOCKSv5Error> {
let mut acceptable_methods = vec![AuthenticationMethod::None];
if login.username_password.is_some() {
acceptable_methods.push(AuthenticationMethod::UsernameAndPassword);
}
let client_greeting = ClientGreeting { acceptable_methods };
client_greeting.write(&mut stream).await?;
let server_choice = ServerChoice::read(Pin::new(&mut stream)).await?;
match server_choice.chosen_method {
AuthenticationMethod::None => {}
AuthenticationMethod::UsernameAndPassword => {
let (username, password) = if let Some(ref linfo) = login.username_password {
(linfo.username.clone(), linfo.password.clone())
} else {
("".to_string(), "".to_string())
};
let auth_request = ClientUsernamePassword { username, password };
auth_request.write(&mut stream).await?;
let server_response = ServerAuthResponse::read(Pin::new(&mut stream)).await?;
if !server_response.success {
return Err(SOCKSv5Error::AuthenticationFailed);
}
}
AuthenticationMethod::NoAcceptableMethods => {
return Err(SOCKSv5Error::NoAuthMethodsAllowed)
}
x => return Err(SOCKSv5Error::UnsupportedAuthMethodChosen(x)),
}
Ok(SOCKSv5Client { _network, stream })
}
async fn connect_internal(
&mut self,
addr: SOCKSv5Address,
port: u16,
) -> Result<N::Stream, SOCKSv5Error> {
let request = ClientConnectionRequest {
command_code: ClientConnectionCommand::EstablishTCPStream,
destination_address: addr,
destination_port: port,
};
request.write(&mut self.stream).await?;
let response = ServerResponse::read(Pin::new(&mut self.stream)).await?;
if response.status == ServerResponseStatus::RequestGranted {
unimplemented!()
} else {
Err(SOCKSv5Error::from(response.status))
}
}
pub async fn connect(&mut self, addr: IpAddr, port: u16) -> Result<N::Stream, SOCKSv5Error> {
assert!(port != 0);
match addr {
IpAddr::V4(a) => self.connect_internal(SOCKSv5Address::IP4(a), port).await,
IpAddr::V6(a) => self.connect_internal(SOCKSv5Address::IP6(a), port).await,
}
}
pub async fn connect_name(
&mut self,
name: String,
port: u16,
) -> Result<N::Stream, SOCKSv5Error> {
format!("hello {}", 'a');
self.connect_internal(SOCKSv5Address::Name(name), port)
.await
}
}

51
src/errors.rs Normal file
View File

@@ -0,0 +1,51 @@
use std::io;
use std::string::FromUtf8Error;
use thiserror::Error;
/// All the errors that can pop up when trying to turn raw bytes into SOCKSv5
/// messages.
#[derive(Error, Debug)]
pub enum DeserializationError {
#[error("Invalid protocol version for packet ({1} is not {0}!)")]
InvalidVersion(u8, u8),
#[error("Not enough data found")]
NotEnoughData,
#[error("Ooops! Found an empty string where I shouldn't")]
InvalidEmptyString,
#[error("IO error: {0}")]
IOError(#[from] io::Error),
#[error("SOCKS authentication format parse error: {0}")]
AuthenticationMethodError(#[from] AuthenticationDeserializationError),
#[error("Error converting from UTF-8: {0}")]
UTF8Error(#[from] FromUtf8Error),
#[error("Invalid address type; wanted 1, 3, or 4, got {0}")]
InvalidAddressType(u8),
#[error("Invalid client command {0}; expected 1, 2, or 3")]
InvalidClientCommand(u8),
#[error("Invalid server status {0}; expected 0-8")]
InvalidServerResponse(u8),
}
/// All the errors that can occur trying to turn SOCKSv5 message structures
/// into raw bytes. There's a few places that the message structures allow
/// for information that can't be serialized; often, you have to be careful
/// about how long your strings are ...
#[derive(Error, Debug)]
pub enum SerializationError {
#[error("Too many authentication methods for serialization ({0} > 255)")]
TooManyAuthMethods(usize),
#[error("Invalid length for string: {0}")]
InvalidStringLength(String),
#[error("IO error: {0}")]
IOError(#[from] io::Error),
}
#[derive(Error, Debug)]
pub enum AuthenticationDeserializationError {
#[error("No data found deserializing SOCKS authentication type")]
NoDataFound,
#[error("Invalid authentication type value: {0}")]
InvalidAuthenticationByte(u8),
#[error("IO error reading SOCKS authentication type: {0}")]
IOError(#[from] io::Error),
}

6
src/lib.rs Normal file
View File

@@ -0,0 +1,6 @@
pub mod client;
pub mod errors;
pub mod messages;
pub mod network;
mod serialize;
pub mod server;

604
src/messages.rs Normal file
View File

@@ -0,0 +1,604 @@
use crate::errors::{AuthenticationDeserializationError, DeserializationError, SerializationError};
use crate::network::SOCKSv5Address;
use crate::serialize::{read_amt, read_string, write_string};
#[cfg(test)]
use async_std::task;
#[cfg(test)]
use futures::io::Cursor;
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use log::warn;
#[cfg(test)]
use quickcheck::{quickcheck, Arbitrary, Gen};
use std::fmt;
use std::net::Ipv4Addr;
use std::pin::Pin;
use thiserror::Error;
/// Client greetings are the first message sent in a SOCKSv5 session. They
/// identify that there's a client that wants to talk to a server, and that
/// they can support any of the provided mechanisms for authenticating to
/// said server. (It feels weird that the offer/choice goes this way instead
/// of the reverse, but whatever.)
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ClientGreeting {
pub acceptable_methods: Vec<AuthenticationMethod>,
}
impl ClientGreeting {
pub async fn read<R: AsyncRead + Send + Unpin>(
r: Pin<&mut R>,
) -> Result<ClientGreeting, DeserializationError> {
let mut buffer = [0; 1];
let raw_r = Pin::into_inner(r);
if raw_r.read(&mut buffer).await? == 0 {
return Err(DeserializationError::NotEnoughData);
}
if buffer[0] != 5 {
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
}
if raw_r.read(&mut buffer).await? == 0 {
return Err(DeserializationError::NotEnoughData);
}
let mut acceptable_methods = Vec::with_capacity(buffer[0] as usize);
for _ in 0..buffer[0] {
acceptable_methods.push(AuthenticationMethod::read(Pin::new(raw_r)).await?);
}
Ok(ClientGreeting { acceptable_methods })
}
pub async fn write<W: AsyncWrite + Send + Unpin>(
&self,
w: &mut W,
) -> Result<(), SerializationError> {
if self.acceptable_methods.len() > 255 {
return Err(SerializationError::TooManyAuthMethods(
self.acceptable_methods.len(),
));
}
let mut buffer = Vec::with_capacity(self.acceptable_methods.len() + 2);
buffer.push(5);
buffer.push(self.acceptable_methods.len() as u8);
w.write_all(&buffer).await?;
for authmeth in self.acceptable_methods.iter() {
authmeth.write(w).await?;
}
Ok(())
}
}
#[cfg(test)]
impl Arbitrary for ClientGreeting {
fn arbitrary(g: &mut Gen) -> ClientGreeting {
let amt = u8::arbitrary(g);
let mut acceptable_methods = Vec::with_capacity(amt as usize);
for _ in 0..amt {
acceptable_methods.push(AuthenticationMethod::arbitrary(g));
}
ClientGreeting { acceptable_methods }
}
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ServerChoice {
pub chosen_method: AuthenticationMethod,
}
impl ServerChoice {
pub async fn read<R: AsyncRead + Send + Unpin>(
mut r: Pin<&mut R>,
) -> Result<Self, DeserializationError> {
let mut buffer = [0; 1];
if r.read(&mut buffer).await? == 0 {
return Err(DeserializationError::NotEnoughData);
}
if buffer[0] != 5 {
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
}
let chosen_method = AuthenticationMethod::read(r).await?;
Ok(ServerChoice { chosen_method })
}
pub async fn write<W: AsyncWrite + Send + Unpin>(
&self,
w: &mut W,
) -> Result<(), SerializationError> {
w.write_all(&[5]).await?;
self.chosen_method.write(w).await
}
}
#[cfg(test)]
impl Arbitrary for ServerChoice {
fn arbitrary(g: &mut Gen) -> ServerChoice {
ServerChoice {
chosen_method: AuthenticationMethod::arbitrary(g),
}
}
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ClientUsernamePassword {
pub username: String,
pub password: String,
}
impl ClientUsernamePassword {
pub async fn read<R: AsyncRead + Send + Unpin>(
r: Pin<&mut R>,
) -> Result<Self, DeserializationError> {
let mut buffer = [0; 1];
let raw_r = Pin::into_inner(r);
if raw_r.read(&mut buffer).await? == 0 {
return Err(DeserializationError::NotEnoughData);
}
if buffer[0] != 1 {
return Err(DeserializationError::InvalidVersion(1, buffer[0]));
}
let username = read_string(Pin::new(raw_r)).await?;
let password = read_string(Pin::new(raw_r)).await?;
Ok(ClientUsernamePassword { username, password })
}
pub async fn write<W: AsyncWrite + Send + Unpin>(
&self,
w: &mut W,
) -> Result<(), SerializationError> {
w.write_all(&[1]).await?;
write_string(&self.username, w).await?;
write_string(&self.password, w).await
}
}
#[cfg(test)]
impl Arbitrary for ClientUsernamePassword {
fn arbitrary(g: &mut Gen) -> Self {
let username = arbitrary_socks_string(g);
let password = arbitrary_socks_string(g);
ClientUsernamePassword { username, password }
}
}
#[cfg(test)]
pub fn arbitrary_socks_string(g: &mut Gen) -> String {
loop {
let mut potential = String::arbitrary(g);
potential.truncate(255);
let bytestring = potential.as_bytes();
if bytestring.len() > 0 && bytestring.len() < 256 {
return potential;
}
}
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ServerAuthResponse {
pub success: bool,
}
impl ServerAuthResponse {
pub async fn read<R: AsyncRead + Send + Unpin>(
mut r: Pin<&mut R>,
) -> Result<Self, DeserializationError> {
let mut buffer = [0; 1];
if r.read(&mut buffer).await? == 0 {
return Err(DeserializationError::NotEnoughData);
}
if buffer[0] != 1 {
return Err(DeserializationError::InvalidVersion(1, buffer[0]));
}
if r.read(&mut buffer).await? == 0 {
return Err(DeserializationError::NotEnoughData);
}
Ok(ServerAuthResponse {
success: buffer[0] == 0,
})
}
pub async fn write<W: AsyncWrite + Send + Unpin>(
&self,
w: &mut W,
) -> Result<(), SerializationError> {
w.write_all(&[1]).await?;
w.write_all(&[if self.success { 0x00 } else { 0xde }])
.await?;
Ok(())
}
}
#[cfg(test)]
impl Arbitrary for ServerAuthResponse {
fn arbitrary(g: &mut Gen) -> ServerAuthResponse {
let success = bool::arbitrary(g);
ServerAuthResponse { success }
}
}
#[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AuthenticationMethod {
None,
GSSAPI,
UsernameAndPassword,
ChallengeHandshake,
ChallengeResponse,
SSL,
NDS,
MultiAuthenticationFramework,
JSONPropertyBlock,
PrivateMethod(u8),
NoAcceptableMethods,
}
impl fmt::Display for AuthenticationMethod {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
AuthenticationMethod::None => write!(f, "No authentication"),
AuthenticationMethod::GSSAPI => write!(f, "GSS-API"),
AuthenticationMethod::UsernameAndPassword => write!(f, "Username and password"),
AuthenticationMethod::ChallengeHandshake => write!(f, "Challenge/Handshake"),
AuthenticationMethod::ChallengeResponse => write!(f, "Challenge/Response"),
AuthenticationMethod::SSL => write!(f, "SSL"),
AuthenticationMethod::NDS => write!(f, "NDS Authentication"),
AuthenticationMethod::MultiAuthenticationFramework => {
write!(f, "Multi-Authentication Framework")
}
AuthenticationMethod::JSONPropertyBlock => write!(f, "JSON Property Block"),
AuthenticationMethod::PrivateMethod(m) => write!(f, "Private Method {:x}", m),
AuthenticationMethod::NoAcceptableMethods => write!(f, "No Acceptable Methods"),
}
}
}
impl AuthenticationMethod {
pub async fn read<R: AsyncRead + Send + Unpin>(
mut r: Pin<&mut R>,
) -> Result<AuthenticationMethod, DeserializationError> {
let mut byte_buffer = [0u8; 1];
let amount_read = r.read(&mut byte_buffer).await?;
if amount_read == 0 {
return Err(AuthenticationDeserializationError::NoDataFound.into());
}
match byte_buffer[0] {
0 => Ok(AuthenticationMethod::None),
1 => Ok(AuthenticationMethod::GSSAPI),
2 => Ok(AuthenticationMethod::UsernameAndPassword),
3 => Ok(AuthenticationMethod::ChallengeHandshake),
5 => Ok(AuthenticationMethod::ChallengeResponse),
6 => Ok(AuthenticationMethod::SSL),
7 => Ok(AuthenticationMethod::NDS),
8 => Ok(AuthenticationMethod::MultiAuthenticationFramework),
9 => Ok(AuthenticationMethod::JSONPropertyBlock),
x if (0x80..=0xfe).contains(&x) => Ok(AuthenticationMethod::PrivateMethod(x)),
0xff => Ok(AuthenticationMethod::NoAcceptableMethods),
e => Err(AuthenticationDeserializationError::InvalidAuthenticationByte(e).into()),
}
}
pub async fn write<W: AsyncWrite + Send + Unpin>(
&self,
w: &mut W,
) -> Result<(), SerializationError> {
let value = match self {
AuthenticationMethod::None => 0,
AuthenticationMethod::GSSAPI => 1,
AuthenticationMethod::UsernameAndPassword => 2,
AuthenticationMethod::ChallengeHandshake => 3,
AuthenticationMethod::ChallengeResponse => 5,
AuthenticationMethod::SSL => 6,
AuthenticationMethod::NDS => 7,
AuthenticationMethod::MultiAuthenticationFramework => 8,
AuthenticationMethod::JSONPropertyBlock => 9,
AuthenticationMethod::PrivateMethod(pm) => *pm,
AuthenticationMethod::NoAcceptableMethods => 0xff,
};
Ok(w.write_all(&[value]).await?)
}
}
#[cfg(test)]
impl Arbitrary for AuthenticationMethod {
fn arbitrary(g: &mut Gen) -> AuthenticationMethod {
let mut vals = vec![
AuthenticationMethod::None,
AuthenticationMethod::GSSAPI,
AuthenticationMethod::UsernameAndPassword,
AuthenticationMethod::ChallengeHandshake,
AuthenticationMethod::ChallengeResponse,
AuthenticationMethod::SSL,
AuthenticationMethod::NDS,
AuthenticationMethod::MultiAuthenticationFramework,
AuthenticationMethod::JSONPropertyBlock,
AuthenticationMethod::NoAcceptableMethods,
];
for x in 0x80..0xffu8 {
vals.push(AuthenticationMethod::PrivateMethod(x));
}
g.choose(&vals).unwrap().clone()
}
}
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ClientConnectionCommand {
EstablishTCPStream,
EstablishTCPPortBinding,
AssociateUDPPort,
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ClientConnectionRequest {
pub command_code: ClientConnectionCommand,
pub destination_address: SOCKSv5Address,
pub destination_port: u16,
}
impl ClientConnectionRequest {
pub async fn read<R: AsyncRead + Send + Unpin>(
r: Pin<&mut R>,
) -> Result<Self, DeserializationError> {
let mut buffer = [0; 2];
let raw_r = Pin::into_inner(r);
read_amt(Pin::new(raw_r), 2, &mut buffer).await?;
if buffer[0] != 5 {
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
}
let command_code = match buffer[1] {
0x01 => ClientConnectionCommand::EstablishTCPStream,
0x02 => ClientConnectionCommand::EstablishTCPPortBinding,
0x03 => ClientConnectionCommand::AssociateUDPPort,
x => return Err(DeserializationError::InvalidClientCommand(x)),
};
let destination_address = SOCKSv5Address::read(Pin::new(raw_r)).await?;
read_amt(Pin::new(raw_r), 2, &mut buffer).await?;
let destination_port = ((buffer[0] as u16) << 8) + (buffer[1] as u16);
Ok(ClientConnectionRequest {
command_code,
destination_address,
destination_port,
})
}
pub async fn write<W: AsyncWrite + Send + Unpin>(
&self,
w: &mut W,
) -> Result<(), SerializationError> {
let command = match self.command_code {
ClientConnectionCommand::EstablishTCPStream => 1,
ClientConnectionCommand::EstablishTCPPortBinding => 2,
ClientConnectionCommand::AssociateUDPPort => 3,
};
w.write_all(&[5, command]).await?;
self.destination_address.write(w).await?;
w.write_all(&[
(self.destination_port >> 8) as u8,
(self.destination_port & 0xffu16) as u8,
])
.await
.map_err(SerializationError::IOError)
}
}
#[cfg(test)]
impl Arbitrary for ClientConnectionCommand {
fn arbitrary(g: &mut Gen) -> ClientConnectionCommand {
g.choose(&[
ClientConnectionCommand::EstablishTCPStream,
ClientConnectionCommand::EstablishTCPPortBinding,
ClientConnectionCommand::AssociateUDPPort,
])
.unwrap()
.clone()
}
}
#[cfg(test)]
impl Arbitrary for ClientConnectionRequest {
fn arbitrary(g: &mut Gen) -> Self {
let command_code = ClientConnectionCommand::arbitrary(g);
let destination_address = SOCKSv5Address::arbitrary(g);
let destination_port = u16::arbitrary(g);
ClientConnectionRequest {
command_code,
destination_address,
destination_port,
}
}
}
#[derive(Clone, Debug, Eq, Error, PartialEq)]
pub enum ServerResponseStatus {
#[error("Actually, everything's fine (weird to see this in an error)")]
RequestGranted,
#[error("General server failure")]
GeneralFailure,
#[error("Connection not allowed by policy rule")]
ConnectionNotAllowedByRule,
#[error("Network unreachable")]
NetworkUnreachable,
#[error("Host unreachable")]
HostUnreachable,
#[error("Connection refused")]
ConnectionRefused,
#[error("TTL expired")]
TTLExpired,
#[error("Command not supported")]
CommandNotSupported,
#[error("Address type not supported")]
AddressTypeNotSupported,
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ServerResponse {
pub status: ServerResponseStatus,
pub bound_address: SOCKSv5Address,
pub bound_port: u16,
}
impl ServerResponse {
pub fn error<E: Into<ServerResponseStatus>>(resp: E) -> ServerResponse {
ServerResponse {
status: resp.into(),
bound_address: SOCKSv5Address::IP4(Ipv4Addr::new(0, 0, 0, 0)),
bound_port: 0,
}
}
}
impl ServerResponse {
pub async fn read<R: AsyncRead + Send + Unpin>(
r: Pin<&mut R>,
) -> Result<Self, DeserializationError> {
let mut buffer = [0; 3];
let raw_r = Pin::into_inner(r);
read_amt(Pin::new(raw_r), 3, &mut buffer).await?;
if buffer[0] != 5 {
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
}
if buffer[2] != 0 {
warn!(target: "async-socks5", "Hey, this isn't terrible, but the server is sending invalid reserved bytes.");
}
let status = match buffer[1] {
0x00 => ServerResponseStatus::RequestGranted,
0x01 => ServerResponseStatus::GeneralFailure,
0x02 => ServerResponseStatus::ConnectionNotAllowedByRule,
0x03 => ServerResponseStatus::NetworkUnreachable,
0x04 => ServerResponseStatus::HostUnreachable,
0x05 => ServerResponseStatus::ConnectionRefused,
0x06 => ServerResponseStatus::TTLExpired,
0x07 => ServerResponseStatus::CommandNotSupported,
0x08 => ServerResponseStatus::AddressTypeNotSupported,
x => return Err(DeserializationError::InvalidServerResponse(x)),
};
let bound_address = SOCKSv5Address::read(Pin::new(raw_r)).await?;
read_amt(Pin::new(raw_r), 2, &mut buffer).await?;
let bound_port = ((buffer[0] as u16) << 8) + (buffer[1] as u16);
Ok(ServerResponse {
status,
bound_address,
bound_port,
})
}
pub async fn write<W: AsyncWrite + Send + Unpin>(
&self,
w: &mut W,
) -> Result<(), SerializationError> {
let status_code = match self.status {
ServerResponseStatus::RequestGranted => 0x00,
ServerResponseStatus::GeneralFailure => 0x01,
ServerResponseStatus::ConnectionNotAllowedByRule => 0x02,
ServerResponseStatus::NetworkUnreachable => 0x03,
ServerResponseStatus::HostUnreachable => 0x04,
ServerResponseStatus::ConnectionRefused => 0x05,
ServerResponseStatus::TTLExpired => 0x06,
ServerResponseStatus::CommandNotSupported => 0x07,
ServerResponseStatus::AddressTypeNotSupported => 0x08,
};
w.write_all(&[5, status_code, 0]).await?;
self.bound_address.write(w).await?;
w.write_all(&[
(self.bound_port >> 8) as u8,
(self.bound_port & 0xffu16) as u8,
])
.await
.map_err(SerializationError::IOError)
}
}
#[cfg(test)]
impl Arbitrary for ServerResponseStatus {
fn arbitrary(g: &mut Gen) -> ServerResponseStatus {
g.choose(&[
ServerResponseStatus::RequestGranted,
ServerResponseStatus::GeneralFailure,
ServerResponseStatus::ConnectionNotAllowedByRule,
ServerResponseStatus::NetworkUnreachable,
ServerResponseStatus::HostUnreachable,
ServerResponseStatus::ConnectionRefused,
ServerResponseStatus::TTLExpired,
ServerResponseStatus::CommandNotSupported,
ServerResponseStatus::AddressTypeNotSupported,
])
.unwrap()
.clone()
}
}
#[cfg(test)]
impl Arbitrary for ServerResponse {
fn arbitrary(g: &mut Gen) -> Self {
let status = ServerResponseStatus::arbitrary(g);
let bound_address = SOCKSv5Address::arbitrary(g);
let bound_port = u16::arbitrary(g);
ServerResponse {
status,
bound_address,
bound_port,
}
}
}
macro_rules! standard_roundtrip {
($name: ident, $t: ty) => {
#[cfg(test)]
quickcheck! {
fn $name(xs: $t) -> bool {
let mut buffer = vec![];
task::block_on(xs.write(&mut buffer)).unwrap();
let mut cursor = Cursor::new(buffer);
let ys = <$t>::read(Pin::new(&mut cursor));
xs == task::block_on(ys).unwrap()
}
}
};
}
standard_roundtrip!(auth_byte_roundtrips, AuthenticationMethod);
standard_roundtrip!(client_greeting_roundtrips, ClientGreeting);
standard_roundtrip!(server_choice_roundtrips, ServerChoice);
standard_roundtrip!(username_password_roundtrips, ClientUsernamePassword);
standard_roundtrip!(server_auth_response, ServerAuthResponse);
standard_roundtrip!(address_roundtrips, SOCKSv5Address);
standard_roundtrip!(client_request_roundtrips, ClientConnectionRequest);
standard_roundtrip!(server_response_roundtrips, ServerResponse);

43
src/network.rs Normal file
View File

@@ -0,0 +1,43 @@
pub mod address;
pub mod datagram;
pub mod generic;
pub mod listener;
pub mod standard;
pub mod stream;
use crate::messages::ServerResponseStatus;
pub use crate::network::address::{SOCKSv5Address, ToSOCKSAddress};
pub use crate::network::standard::Builtin;
use async_trait::async_trait;
use futures::{AsyncRead, AsyncWrite};
use std::fmt;
#[async_trait]
pub trait Network {
type Stream: AsyncRead + AsyncWrite + Clone + Send + Sync + Unpin + 'static;
type Listener: SingleShotListener<Self::Stream, Self::Error> + Send + Sync + 'static;
type UdpSocket;
type Error: fmt::Debug + fmt::Display + Into<ServerResponseStatus>;
async fn connect<A: ToSOCKSAddress>(
&mut self,
addr: A,
port: u16,
) -> Result<Self::Stream, Self::Error>;
async fn udp_socket<A: ToSOCKSAddress>(
&mut self,
addr: A,
port: Option<u16>,
) -> Result<Self::UdpSocket, Self::Error>;
async fn listen<A: ToSOCKSAddress>(
&mut self,
addr: A,
port: Option<u16>,
) -> Result<Self::Listener, Self::Error>;
}
#[async_trait]
pub trait SingleShotListener<Stream, Error> {
async fn accept(self) -> Result<Stream, Error>;
fn info(&self) -> Result<(SOCKSv5Address, u16), Error>;
}

151
src/network/address.rs Normal file
View File

@@ -0,0 +1,151 @@
use crate::errors::{DeserializationError, SerializationError};
#[cfg(test)]
use crate::messages::arbitrary_socks_string;
use crate::serialize::{read_amt, read_string, write_string};
use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
#[cfg(test)]
use quickcheck::{Arbitrary, Gen};
use std::fmt;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::pin::Pin;
pub trait ToSOCKSAddress: Send {
fn to_socks_address(&self) -> SOCKSv5Address;
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum SOCKSv5Address {
IP4(Ipv4Addr),
IP6(Ipv6Addr),
Name(String),
}
impl ToSOCKSAddress for SOCKSv5Address {
fn to_socks_address(&self) -> SOCKSv5Address {
self.clone()
}
}
impl ToSOCKSAddress for IpAddr {
fn to_socks_address(&self) -> SOCKSv5Address {
match self {
IpAddr::V4(a) => SOCKSv5Address::IP4(*a),
IpAddr::V6(a) => SOCKSv5Address::IP6(*a),
}
}
}
impl ToSOCKSAddress for Ipv4Addr {
fn to_socks_address(&self) -> SOCKSv5Address {
SOCKSv5Address::IP4(*self)
}
}
impl ToSOCKSAddress for Ipv6Addr {
fn to_socks_address(&self) -> SOCKSv5Address {
SOCKSv5Address::IP6(*self)
}
}
impl ToSOCKSAddress for String {
fn to_socks_address(&self) -> SOCKSv5Address {
SOCKSv5Address::Name(self.clone())
}
}
impl<'a> ToSOCKSAddress for &'a str {
fn to_socks_address(&self) -> SOCKSv5Address {
SOCKSv5Address::Name(self.to_string())
}
}
impl fmt::Display for SOCKSv5Address {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
SOCKSv5Address::IP4(a) => write!(f, "{}", a),
SOCKSv5Address::IP6(a) => write!(f, "{}", a),
SOCKSv5Address::Name(a) => write!(f, "{}", a),
}
}
}
impl From<IpAddr> for SOCKSv5Address {
fn from(addr: IpAddr) -> SOCKSv5Address {
match addr {
IpAddr::V4(a) => SOCKSv5Address::IP4(a),
IpAddr::V6(a) => SOCKSv5Address::IP6(a),
}
}
}
impl SOCKSv5Address {
pub async fn read<R: AsyncRead + Send + Unpin>(
mut r: Pin<&mut R>,
) -> Result<Self, DeserializationError> {
let mut byte_buffer = [0u8; 1];
let amount_read = r.read(&mut byte_buffer).await?;
if amount_read == 0 {
return Err(DeserializationError::NotEnoughData);
}
match byte_buffer[0] {
1 => {
let mut addr_buffer = [0; 4];
read_amt(r, 4, &mut addr_buffer).await?;
Ok(SOCKSv5Address::IP4(Ipv4Addr::from(addr_buffer)))
}
3 => {
let mut addr_buffer = [0; 16];
read_amt(r, 16, &mut addr_buffer).await?;
Ok(SOCKSv5Address::IP6(Ipv6Addr::from(addr_buffer)))
}
4 => {
let name = read_string(r).await?;
Ok(SOCKSv5Address::Name(name))
}
x => Err(DeserializationError::InvalidAddressType(x)),
}
}
pub async fn write<W: AsyncWrite + Send + Unpin>(
&self,
w: &mut W,
) -> Result<(), SerializationError> {
match self {
SOCKSv5Address::IP4(x) => {
w.write_all(&[1]).await?;
w.write_all(&x.octets())
.await
.map_err(SerializationError::IOError)
}
SOCKSv5Address::IP6(x) => {
w.write_all(&[3]).await?;
w.write_all(&x.octets())
.await
.map_err(SerializationError::IOError)
}
SOCKSv5Address::Name(x) => {
w.write_all(&[4]).await?;
write_string(x, w).await
}
}
}
}
#[cfg(test)]
impl Arbitrary for SOCKSv5Address {
fn arbitrary(g: &mut Gen) -> Self {
let ip4 = Ipv4Addr::arbitrary(g);
let ip6 = Ipv6Addr::arbitrary(g);
let nm = arbitrary_socks_string(g);
g.choose(&[
SOCKSv5Address::IP4(ip4),
SOCKSv5Address::IP6(ip6),
SOCKSv5Address::Name(nm),
])
.unwrap()
.clone()
}
}

37
src/network/datagram.rs Normal file
View File

@@ -0,0 +1,37 @@
use crate::network::address::SOCKSv5Address;
use async_trait::async_trait;
#[async_trait]
pub trait Datagramlike: Send + Sync {
type Error;
async fn send_to(
&self,
buf: &[u8],
addr: SOCKSv5Address,
port: u16,
) -> Result<usize, Self::Error>;
async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SOCKSv5Address, u16), Self::Error>;
}
pub struct GenericDatagramSocket<E> {
pub internal: Box<dyn Datagramlike<Error = E>>,
}
#[async_trait]
impl<E> Datagramlike for GenericDatagramSocket<E> {
type Error = E;
async fn send_to(
&self,
buf: &[u8],
addr: SOCKSv5Address,
port: u16,
) -> Result<usize, Self::Error> {
Ok(self.internal.send_to(buf, addr, port).await?)
}
async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SOCKSv5Address, u16), Self::Error> {
Ok(self.internal.recv_from(buf).await?)
}
}

48
src/network/generic.rs Normal file
View File

@@ -0,0 +1,48 @@
use crate::messages::ServerResponseStatus;
use crate::network::address::ToSOCKSAddress;
use crate::network::datagram::GenericDatagramSocket;
use crate::network::listener::GenericListener;
use crate::network::stream::GenericStream;
use async_trait::async_trait;
use std::fmt::Display;
#[async_trait]
pub trait Networklike {
/// The error type for things that fail on this network. Apologies in advance
/// for using only one; if you have a use case for separating your errors,
/// please shoot the author(s) and email to split this into multiple types, one
/// for each trait function.
type Error: Display + Into<ServerResponseStatus> + Send;
/// Connect to the given address and port, over this kind of network. The
/// underlying stream should behave somewhat like a TCP stream ... which
/// may be exactly what you're using. However, in order to support tunnelling
/// scenarios (i.e., using another proxy, going through Tor or SSH, etc.) we
/// work generically over any stream-like object.
async fn connect<A: ToSOCKSAddress>(
&mut self,
addr: A,
port: u16,
) -> Result<GenericStream, Self::Error>;
/// Listen for connections on the given address and port, returning a generic
/// listener socket to use in the future.
async fn listen<A: ToSOCKSAddress>(
&mut self,
addr: A,
port: u16,
) -> Result<GenericListener<Self::Error>, Self::Error>;
/// Bind a socket for the purposes of doing some datagram communication. NOTE!
/// this is only for UDP-like communication, not for generic connecting or
/// listening! Maybe obvious from the types, but POSIX has overtrained many
/// of us.
///
/// Recall when using these functions that datagram protocols allow for packet
/// loss and out-of-order delivery. So ... be warned.
async fn bind<A: ToSOCKSAddress>(
&mut self,
addr: A,
port: u16,
) -> Result<GenericDatagramSocket<Self::Error>, Self::Error>;
}

28
src/network/listener.rs Normal file
View File

@@ -0,0 +1,28 @@
use crate::network::address::SOCKSv5Address;
use crate::network::stream::GenericStream;
use async_trait::async_trait;
#[async_trait]
pub trait Listenerlike: Send + Sync {
type Error;
async fn accept(&self) -> Result<(GenericStream, SOCKSv5Address, u16), Self::Error>;
fn info(&self) -> (SOCKSv5Address, u16);
}
pub struct GenericListener<E> {
pub internal: Box<dyn Listenerlike<Error = E>>,
}
#[async_trait]
impl<E> Listenerlike for GenericListener<E> {
type Error = E;
async fn accept(&self) -> Result<(GenericStream, SOCKSv5Address, u16), Self::Error> {
Ok(self.internal.accept().await?)
}
fn info(&self) -> (SOCKSv5Address, u16) {
self.internal.info()
}
}

214
src/network/standard.rs Normal file
View File

@@ -0,0 +1,214 @@
use crate::messages::ServerResponseStatus;
use crate::network::address::{SOCKSv5Address, ToSOCKSAddress};
use crate::network::datagram::{Datagramlike, GenericDatagramSocket};
use crate::network::generic::Networklike;
use crate::network::listener::{GenericListener, Listenerlike};
use crate::network::stream::{GenericStream, Streamlike};
use async_std::io;
use async_std::net::{TcpListener, TcpStream, UdpSocket};
use async_trait::async_trait;
use log::error;
use std::net::Ipv4Addr;
pub struct Builtin {}
impl Builtin {
pub fn new() -> Builtin {
Builtin {}
}
}
impl Streamlike for TcpStream {}
#[async_trait]
impl Listenerlike for TcpListener {
type Error = io::Error;
async fn accept(&self) -> Result<(GenericStream, SOCKSv5Address, u16), Self::Error> {
let (base, addrport) = self.accept().await?;
let addr = addrport.ip();
let port = addrport.port();
Ok((GenericStream::from(base), SOCKSv5Address::from(addr), port))
}
fn info(&self) -> (SOCKSv5Address, u16) {
match self.local_addr() {
Ok(x) => {
let addr = SOCKSv5Address::from(x.ip());
let port = x.port();
(addr, port)
}
Err(e) => {
error!("Someone asked for a listener address, and we got an error ({}); returning 0.0.0.0:0", e);
(SOCKSv5Address::IP4(Ipv4Addr::from(0)), 0)
}
}
}
}
#[async_trait]
impl Datagramlike for UdpSocket {
type Error = io::Error;
async fn send_to(
&self,
buf: &[u8],
addr: SOCKSv5Address,
port: u16,
) -> Result<usize, Self::Error> {
match addr {
SOCKSv5Address::IP4(a) => self.send_to(buf, (a, port)).await,
SOCKSv5Address::IP6(a) => self.send_to(buf, (a, port)).await,
SOCKSv5Address::Name(n) => self.send_to(buf, (n.as_str(), port)).await,
}
}
async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SOCKSv5Address, u16), Self::Error> {
let (amt, addrport) = self.recv_from(buf).await?;
let addr = addrport.ip();
let port = addrport.port();
Ok((amt, SOCKSv5Address::from(addr), port))
}
}
#[async_trait]
impl Networklike for Builtin {
type Error = io::Error;
async fn connect<A: ToSOCKSAddress>(
&mut self,
addr: A,
port: u16,
) -> Result<GenericStream, Self::Error> {
let target = addr.to_socks_address();
let base_stream = match target {
SOCKSv5Address::IP4(a) => TcpStream::connect((a, port)).await?,
SOCKSv5Address::IP6(a) => TcpStream::connect((a, port)).await?,
SOCKSv5Address::Name(n) => TcpStream::connect((n.as_str(), port)).await?,
};
Ok(GenericStream::from(base_stream))
}
async fn listen<A: ToSOCKSAddress>(
&mut self,
addr: A,
port: u16,
) -> Result<GenericListener<Self::Error>, Self::Error> {
let target = addr.to_socks_address();
let base_stream = match target {
SOCKSv5Address::IP4(a) => TcpListener::bind((a, port)).await?,
SOCKSv5Address::IP6(a) => TcpListener::bind((a, port)).await?,
SOCKSv5Address::Name(n) => TcpListener::bind((n.as_str(), port)).await?,
};
Ok(GenericListener {
internal: Box::new(base_stream),
})
}
async fn bind<A: ToSOCKSAddress>(
&mut self,
addr: A,
port: u16,
) -> Result<GenericDatagramSocket<Self::Error>, Self::Error> {
let target = addr.to_socks_address();
let base_socket = match target {
SOCKSv5Address::IP4(a) => UdpSocket::bind((a, port)).await?,
SOCKSv5Address::IP6(a) => UdpSocket::bind((a, port)).await?,
SOCKSv5Address::Name(n) => UdpSocket::bind((n.as_str(), port)).await?,
};
Ok(GenericDatagramSocket {
internal: Box::new(base_socket),
})
}
}
// pub struct StandardNetworking {}
//
// impl StandardNetworking {
// pub fn new() -> StandardNetworking {
// StandardNetworking {}
// }
// }
//
impl From<io::Error> for ServerResponseStatus {
fn from(e: io::Error) -> ServerResponseStatus {
match e.kind() {
io::ErrorKind::ConnectionRefused => ServerResponseStatus::ConnectionRefused,
io::ErrorKind::NotFound => ServerResponseStatus::HostUnreachable,
_ => ServerResponseStatus::GeneralFailure,
}
}
}
//
// #[async_trait]
// impl Network for StandardNetworking {
// type Stream = TcpStream;
// type Listener = TcpListener;
// type UdpSocket = UdpSocket;
// type Error = io::Error;
//
// async fn connect<A: ToSOCKSAddress>(
// &mut self,
// addr: A,
// port: u16,
// ) -> Result<Self::Stream, Self::Error> {
// let target = addr.to_socks_address();
//
// match target {
// SOCKSv5Address::IP4(a) => TcpStream::connect((a, port)).await,
// SOCKSv5Address::IP6(a) => TcpStream::connect((a, port)).await,
// SOCKSv5Address::Name(n) => TcpStream::connect((n.as_str(), port)).await,
// }
// }
//
// async fn udp_socket<A: ToSOCKSAddress>(
// &mut self,
// addr: A,
// port: Option<u16>,
// ) -> Result<Self::UdpSocket, Self::Error> {
// let me = addr.to_socks_address();
// let real_port = port.unwrap_or(0);
//
// match me {
// SOCKSv5Address::IP4(a) => UdpSocket::bind((a, real_port)).await,
// SOCKSv5Address::IP6(a) => UdpSocket::bind((a, real_port)).await,
// SOCKSv5Address::Name(n) => UdpSocket::bind((n.as_str(), real_port)).await,
// }
// }
//
// async fn listen<A: ToSOCKSAddress>(
// &mut self,
// addr: A,
// port: Option<u16>,
// ) -> Result<Self::Listener, Self::Error> {
// let me = addr.to_socks_address();
// let real_port = port.unwrap_or(0);
//
// match me {
// SOCKSv5Address::IP4(a) => TcpListener::bind((a, real_port)).await,
// SOCKSv5Address::IP6(a) => TcpListener::bind((a, real_port)).await,
// SOCKSv5Address::Name(n) => TcpListener::bind((n.as_str(), real_port)).await,
// }
// }
// }
//
// #[async_trait]
// impl SingleShotListener<TcpStream, io::Error> for TcpListener {
// async fn accept(self) -> Result<TcpStream, io::Error> {
// self.accept().await
// }
//
// fn info(&self) -> Result<(SOCKSv5Address, u16), io::Error> {
// match self.local_addr()? {
// SocketAddr::V4(a) => Ok((SOCKSv5Address::IP4(*a.ip()), a.port())),
// SocketAddr::V6(a) => Ok((SOCKSv5Address::IP6(*a.ip()), a.port())),
// }
// }
// }
//

52
src/network/stream.rs Normal file
View File

@@ -0,0 +1,52 @@
use async_std::task::{Context, Poll};
use futures::io;
use futures::io::{AsyncRead, AsyncWrite};
use std::pin::Pin;
use std::sync::Arc;
pub trait Streamlike: AsyncRead + AsyncWrite + Send + Sync + Unpin {}
#[derive(Clone)]
pub struct GenericStream {
internal: Arc<Box<dyn Streamlike>>,
}
impl AsyncRead for GenericStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let base = Pin::into_inner(self);
Pin::new(base).poll_read(cx, buf)
}
}
impl AsyncWrite for GenericStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let base = Pin::into_inner(self);
Pin::new(base).poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let base = Pin::into_inner(self);
Pin::new(base).poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let base = Pin::into_inner(self);
Pin::new(base).poll_close(cx)
}
}
impl<T: Streamlike + 'static> From<T> for GenericStream {
fn from(x: T) -> GenericStream {
GenericStream {
internal: Arc::new(Box::new(x)),
}
}
}

60
src/serialize.rs Normal file
View File

@@ -0,0 +1,60 @@
use crate::errors::{DeserializationError, SerializationError};
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use std::pin::Pin;
pub async fn read_string<R: AsyncRead + Send + Unpin>(
mut r: Pin<&mut R>,
) -> Result<String, DeserializationError> {
let mut length_buffer = [0; 1];
if r.read(&mut length_buffer).await? == 0 {
return Err(DeserializationError::NotEnoughData);
}
let target = length_buffer[0] as usize;
if target == 0 {
return Err(DeserializationError::InvalidEmptyString);
}
let mut bytestring = vec![0; target];
read_amt(r, target, &mut bytestring).await?;
Ok(String::from_utf8(bytestring)?)
}
pub async fn write_string<W: AsyncWrite + Send + Unpin>(
s: &str,
w: &mut W,
) -> Result<(), SerializationError> {
let bytestring = s.as_bytes();
if bytestring.is_empty() || bytestring.len() > 255 {
return Err(SerializationError::InvalidStringLength(s.to_string()));
}
w.write_all(&[bytestring.len() as u8]).await?;
w.write_all(bytestring)
.await
.map_err(SerializationError::IOError)
}
pub async fn read_amt<R: AsyncRead + Send + Unpin>(
mut r: Pin<&mut R>,
amt: usize,
buffer: &mut [u8],
) -> Result<(), DeserializationError> {
let mut amt_read = 0;
while amt_read < amt {
let chunk_amt = r.read(&mut buffer[amt_read..]).await?;
if chunk_amt == 0 {
return Err(DeserializationError::NotEnoughData);
}
amt_read += chunk_amt;
}
Ok(())
}

313
src/server.rs Normal file
View File

@@ -0,0 +1,313 @@
use crate::errors::{DeserializationError, SerializationError};
use crate::messages::{
AuthenticationMethod, ClientConnectionCommand, ClientConnectionRequest, ClientGreeting,
ClientUsernamePassword, ServerChoice, ServerResponse, ServerResponseStatus,
};
use crate::network::generic::Networklike;
use crate::network::listener::{GenericListener, Listenerlike};
use crate::network::stream::GenericStream;
use crate::network::SOCKSv5Address;
use async_std::io;
use async_std::io::prelude::WriteExt;
use async_std::sync::{Arc, Mutex};
use async_std::task;
use log::{error, info, trace, warn};
use std::pin::Pin;
use thiserror::Error;
pub struct SOCKSv5Server<N: Networklike> {
network: N,
security_parameters: SecurityParameters,
listener: GenericListener<N::Error>,
}
#[derive(Clone)]
pub struct SecurityParameters {
pub allow_unauthenticated: bool,
pub allow_connection: Option<fn(&SOCKSv5Address, u16) -> bool>,
pub check_password: Option<fn(&str, &str) -> bool>,
pub connect_tls: Option<fn(GenericStream) -> Option<GenericStream>>,
}
impl<N: Networklike + Send + 'static> SOCKSv5Server<N> {
pub fn new<S: Listenerlike<Error = N::Error> + 'static>(
network: N,
security_parameters: SecurityParameters,
stream: S,
) -> SOCKSv5Server<N> {
SOCKSv5Server {
network,
security_parameters,
listener: GenericListener {
internal: Box::new(stream),
},
}
}
pub async fn run(self) -> Result<(), N::Error> {
let (my_addr, my_port) = self.listener.info();
info!("Starting SOCKSv5 server on {}:{}", my_addr, my_port);
let locked_network = Arc::new(Mutex::new(self.network));
loop {
let (stream, their_addr, their_port) = self.listener.accept().await?;
trace!(
"Initial accept of connection from {}:{}",
their_addr,
their_port
);
if let Some(checker) = &self.security_parameters.allow_connection {
if !checker(&their_addr, their_port) {
info!(
"Rejecting attempted connection from {}:{}",
their_addr, their_port
);
continue;
}
}
let params = self.security_parameters.clone();
let network_mutex_copy = locked_network.clone();
task::spawn(async move {
if let Some(authed_stream) =
run_authentication(params, stream, their_addr.clone(), their_port).await
{
if let Err(e) = run_main_loop(network_mutex_copy, authed_stream).await {
warn!("Failure in main loop: {}", e);
}
}
});
}
}
}
async fn run_authentication(
params: SecurityParameters,
mut stream: GenericStream,
addr: SOCKSv5Address,
port: u16,
) -> Option<GenericStream> {
match ClientGreeting::read(Pin::new(&mut stream)).await {
Err(e) => {
error!(
"Client hello deserialization error from {}:{}: {}",
addr, port, e
);
None
}
// So we get opinionated here, based on what we think should be our first choice if the
// server offers something up. So we'll first see if we can make this a TLS connection.
Ok(cg)
if cg.acceptable_methods.contains(&AuthenticationMethod::SSL)
&& params.connect_tls.is_some() =>
{
match params.connect_tls {
None => {
error!("Internal error: TLS handler was there, but is now gone");
None
}
Some(converter) => match converter(stream) {
None => {
info!("Rejecting bad TLS handshake from {}:{}", addr, port);
None
}
Some(new_stream) => Some(new_stream),
},
}
}
// if we can't do that, we'll see if we can get a username and password
Ok(cg)
if cg
.acceptable_methods
.contains(&&AuthenticationMethod::UsernameAndPassword)
&& params.check_password.is_some() =>
{
match ClientUsernamePassword::read(Pin::new(&mut stream)).await {
Err(e) => {
warn!(
"Error reading username/password from {}:{}: {}",
addr, port, e
);
None
}
Ok(userinfo) => {
let checker = params.check_password.unwrap_or(|_, _| false);
if checker(&userinfo.username, &userinfo.password) {
Some(stream)
} else {
None
}
}
}
}
// and, in the worst case, we'll see if our user is cool with unauthenticated connections
Ok(cg)
if cg.acceptable_methods.contains(&&AuthenticationMethod::None)
&& params.allow_unauthenticated =>
{
Some(stream)
}
Ok(_) => {
let rejection_letter = ServerChoice {
chosen_method: AuthenticationMethod::NoAcceptableMethods,
};
if let Err(e) = rejection_letter.write(&mut stream).await {
warn!(
"Error sending rejection letter in authentication response: {}",
e
);
}
if let Err(e) = stream.flush().await {
warn!(
"Error flushing buffer after rejection latter in authentication response: {}",
e
);
}
None
}
}
}
#[derive(Error, Debug)]
enum ServerError {
#[error("Error in deserialization: {0}")]
DeserializationError(#[from] DeserializationError),
#[error("Error in serialization: {0}")]
SerializationError(#[from] SerializationError),
}
async fn run_main_loop<N>(
network: Arc<Mutex<N>>,
mut stream: GenericStream,
) -> Result<(), ServerError>
where
N: Networklike,
N::Error: 'static,
{
loop {
let ccr = ClientConnectionRequest::read(Pin::new(&mut stream)).await?;
match ccr.command_code {
ClientConnectionCommand::AssociateUDPPort => {}
ClientConnectionCommand::EstablishTCPPortBinding => {}
ClientConnectionCommand::EstablishTCPStream => {
let target = format!("{}:{}", ccr.destination_address, ccr.destination_port);
info!(
"Client requested connection to {}:{}",
ccr.destination_address, ccr.destination_port
);
let connection_res = {
let mut network = network.lock().await;
network
.connect(ccr.destination_address.clone(), ccr.destination_port)
.await
};
let outgoing_stream = match connection_res {
Ok(x) => x,
Err(e) => {
error!("Failed to connect to {}: {}", target, e);
let response = ServerResponse::error(e);
response.write(&mut stream).await?;
continue;
}
};
trace!(
"Connection established to {}:{}",
ccr.destination_address,
ccr.destination_port
);
let incoming_res = {
let mut network = network.lock().await;
network.listen("127.0.0.1", 0).await
};
let incoming_listener = match incoming_res {
Ok(x) => x,
Err(e) => {
error!("Failed to bind server port for new TCP stream: {}", e);
let response = ServerResponse::error(e);
response.write(&mut stream).await?;
continue;
}
};
let (bound_address, bound_port) = incoming_listener.info();
trace!(
"Set up {}:{} to address request for {}:{}",
bound_address,
bound_port,
ccr.destination_address,
ccr.destination_port
);
let response = ServerResponse {
status: ServerResponseStatus::RequestGranted,
bound_address,
bound_port,
};
response.write(&mut stream).await?;
task::spawn(async move {
let (incoming_stream, from_addr, from_port) = match incoming_listener
.accept()
.await
{
Err(e) => {
error!("Miscellaneous error waiting for someone to connect for proxying: {}", e);
return;
}
Ok(s) => s,
};
trace!(
"Accepted connection from {}:{} to attach to {}:{}",
from_addr,
from_port,
ccr.destination_address,
ccr.destination_port
);
let mut from_left = incoming_stream.clone();
let mut from_right = outgoing_stream.clone();
let mut to_left = incoming_stream;
let mut to_right = outgoing_stream;
let from = format!("{}:{}", from_addr, from_port);
let to = format!("{}:{}", ccr.destination_address, ccr.destination_port);
task::spawn(async move {
info!(
"Spawned {}:{} >--> {}:{} task",
from_addr, from_port, ccr.destination_address, ccr.destination_port
);
if let Err(e) = io::copy(&mut from_left, &mut to_right).await {
warn!(
"{}:{} >--> {}:{} connection failed with: {}",
from_addr,
from_port,
ccr.destination_address,
ccr.destination_port,
e
);
}
});
task::spawn(async move {
info!("Spawned {} <--< {} task", from, to);
if let Err(e) = io::copy(&mut from_right, &mut to_left).await {
warn!("{} <--< {} connection failed with: {}", from, to, e);
}
});
});
}
}
}
}