From 1bf6f62d4eccaee062b96cfc6cc424f6da5ed44e Mon Sep 17 00:00:00 2001 From: Adam Wick Date: Thu, 24 Jun 2021 19:18:16 -0700 Subject: [PATCH] checkpoint --- .gitignore | 2 + Cargo.toml | 17 ++ src/bin/socks-server.rs | 30 ++ src/client.rs | 133 +++++++++ src/errors.rs | 51 ++++ src/lib.rs | 6 + src/messages.rs | 604 ++++++++++++++++++++++++++++++++++++++++ src/network.rs | 43 +++ src/network/address.rs | 151 ++++++++++ src/network/datagram.rs | 37 +++ src/network/generic.rs | 48 ++++ src/network/listener.rs | 28 ++ src/network/standard.rs | 214 ++++++++++++++ src/network/stream.rs | 52 ++++ src/serialize.rs | 60 ++++ src/server.rs | 313 +++++++++++++++++++++ 16 files changed, 1789 insertions(+) create mode 100644 .gitignore create mode 100644 Cargo.toml create mode 100644 src/bin/socks-server.rs create mode 100644 src/client.rs create mode 100644 src/errors.rs create mode 100644 src/lib.rs create mode 100644 src/messages.rs create mode 100644 src/network.rs create mode 100644 src/network/address.rs create mode 100644 src/network/datagram.rs create mode 100644 src/network/generic.rs create mode 100644 src/network/listener.rs create mode 100644 src/network/standard.rs create mode 100644 src/network/stream.rs create mode 100644 src/serialize.rs create mode 100644 src/server.rs diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..96ef6c0 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/target +Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..ea5bf45 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "async-socks5" +version = "0.1.0" +authors = ["Adam Wick "] +edition = "2018" + +[lib] +name = "async_socks5" + +[dependencies] +async-std = { version = "1.9.0", features = ["attributes"] } +async-trait = "0.1.50" +futures = "0.3.15" +log = "0.4.8" +quickcheck = "1.0.3" +simplelog = "0.10.0" +thiserror = "1.0.24" \ No newline at end of file diff --git a/src/bin/socks-server.rs b/src/bin/socks-server.rs new file mode 100644 index 0000000..0b704cd --- /dev/null +++ b/src/bin/socks-server.rs @@ -0,0 +1,30 @@ +use async_socks5::network::Builtin; +use async_socks5::server::{SOCKSv5Server, SecurityParameters}; +use async_std::io; +use async_std::net::TcpListener; +use simplelog::{ColorChoice, CombinedLogger, Config, LevelFilter, TermLogger, TerminalMode}; + +#[async_std::main] +async fn main() -> Result<(), io::Error> { + CombinedLogger::init(vec![TermLogger::new( + LevelFilter::Debug, + Config::default(), + TerminalMode::Mixed, + ColorChoice::Auto, + )]) + .expect("Couldn't initialize logger"); + + let main_listener = TcpListener::bind("127.0.0.1:0").await?; + let params = SecurityParameters { + allow_unauthenticated: false, + allow_connection: None, + check_password: None, + connect_tls: None, + }; + + let server = SOCKSv5Server::new(Builtin::new(), params, main_listener); + + server.run().await?; + + Ok(()) +} diff --git a/src/client.rs b/src/client.rs new file mode 100644 index 0000000..6c46f1b --- /dev/null +++ b/src/client.rs @@ -0,0 +1,133 @@ +use crate::errors::{DeserializationError, SerializationError}; +use crate::messages::{ + AuthenticationMethod, ClientConnectionCommand, ClientConnectionRequest, ClientGreeting, + ClientUsernamePassword, ServerAuthResponse, ServerChoice, ServerResponse, ServerResponseStatus, +}; +use crate::network::{Network, SOCKSv5Address}; +use async_std::net::IpAddr; +use futures::io::{AsyncRead, AsyncWrite}; +use std::pin::Pin; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum SOCKSv5Error { + #[error("SOCKSv5 serialization error: {0}")] + SerializationError(#[from] SerializationError), + #[error("SOCKSv5 deserialization error: {0}")] + DeserializationError(#[from] DeserializationError), + #[error("No acceptable authentication methods available")] + NoAuthMethodsAllowed, + #[error("Authentication failed")] + AuthenticationFailed, + #[error("Server chose an unsupported authentication method ({0}")] + UnsupportedAuthMethodChosen(AuthenticationMethod), + #[error("Server said no: {0}")] + ServerFailure(#[from] ServerResponseStatus), +} + +pub struct SOCKSv5Client +where + S: AsyncRead + AsyncWrite, + N: Network, +{ + _network: N, + stream: S, +} + +pub struct LoginInfo { + username_password: Option, +} + +pub struct UsernamePassword { + username: String, + password: String, +} + +impl SOCKSv5Client +where + S: AsyncRead + AsyncWrite + Send + Unpin, + N: Network, +{ + /// Create a new SOCKSv5 client connection over the given steam, using the given + /// authentication information. + pub async fn new(_network: N, mut stream: S, login: &LoginInfo) -> Result { + let mut acceptable_methods = vec![AuthenticationMethod::None]; + + if login.username_password.is_some() { + acceptable_methods.push(AuthenticationMethod::UsernameAndPassword); + } + + let client_greeting = ClientGreeting { acceptable_methods }; + + client_greeting.write(&mut stream).await?; + let server_choice = ServerChoice::read(Pin::new(&mut stream)).await?; + + match server_choice.chosen_method { + AuthenticationMethod::None => {} + + AuthenticationMethod::UsernameAndPassword => { + let (username, password) = if let Some(ref linfo) = login.username_password { + (linfo.username.clone(), linfo.password.clone()) + } else { + ("".to_string(), "".to_string()) + }; + + let auth_request = ClientUsernamePassword { username, password }; + + auth_request.write(&mut stream).await?; + let server_response = ServerAuthResponse::read(Pin::new(&mut stream)).await?; + + if !server_response.success { + return Err(SOCKSv5Error::AuthenticationFailed); + } + } + + AuthenticationMethod::NoAcceptableMethods => { + return Err(SOCKSv5Error::NoAuthMethodsAllowed) + } + + x => return Err(SOCKSv5Error::UnsupportedAuthMethodChosen(x)), + } + + Ok(SOCKSv5Client { _network, stream }) + } + + async fn connect_internal( + &mut self, + addr: SOCKSv5Address, + port: u16, + ) -> Result { + let request = ClientConnectionRequest { + command_code: ClientConnectionCommand::EstablishTCPStream, + destination_address: addr, + destination_port: port, + }; + + request.write(&mut self.stream).await?; + let response = ServerResponse::read(Pin::new(&mut self.stream)).await?; + + if response.status == ServerResponseStatus::RequestGranted { + unimplemented!() + } else { + Err(SOCKSv5Error::from(response.status)) + } + } + + pub async fn connect(&mut self, addr: IpAddr, port: u16) -> Result { + assert!(port != 0); + match addr { + IpAddr::V4(a) => self.connect_internal(SOCKSv5Address::IP4(a), port).await, + IpAddr::V6(a) => self.connect_internal(SOCKSv5Address::IP6(a), port).await, + } + } + + pub async fn connect_name( + &mut self, + name: String, + port: u16, + ) -> Result { + format!("hello {}", 'a'); + self.connect_internal(SOCKSv5Address::Name(name), port) + .await + } +} diff --git a/src/errors.rs b/src/errors.rs new file mode 100644 index 0000000..aa1a73d --- /dev/null +++ b/src/errors.rs @@ -0,0 +1,51 @@ +use std::io; +use std::string::FromUtf8Error; +use thiserror::Error; + +/// All the errors that can pop up when trying to turn raw bytes into SOCKSv5 +/// messages. +#[derive(Error, Debug)] +pub enum DeserializationError { + #[error("Invalid protocol version for packet ({1} is not {0}!)")] + InvalidVersion(u8, u8), + #[error("Not enough data found")] + NotEnoughData, + #[error("Ooops! Found an empty string where I shouldn't")] + InvalidEmptyString, + #[error("IO error: {0}")] + IOError(#[from] io::Error), + #[error("SOCKS authentication format parse error: {0}")] + AuthenticationMethodError(#[from] AuthenticationDeserializationError), + #[error("Error converting from UTF-8: {0}")] + UTF8Error(#[from] FromUtf8Error), + #[error("Invalid address type; wanted 1, 3, or 4, got {0}")] + InvalidAddressType(u8), + #[error("Invalid client command {0}; expected 1, 2, or 3")] + InvalidClientCommand(u8), + #[error("Invalid server status {0}; expected 0-8")] + InvalidServerResponse(u8), +} + +/// All the errors that can occur trying to turn SOCKSv5 message structures +/// into raw bytes. There's a few places that the message structures allow +/// for information that can't be serialized; often, you have to be careful +/// about how long your strings are ... +#[derive(Error, Debug)] +pub enum SerializationError { + #[error("Too many authentication methods for serialization ({0} > 255)")] + TooManyAuthMethods(usize), + #[error("Invalid length for string: {0}")] + InvalidStringLength(String), + #[error("IO error: {0}")] + IOError(#[from] io::Error), +} + +#[derive(Error, Debug)] +pub enum AuthenticationDeserializationError { + #[error("No data found deserializing SOCKS authentication type")] + NoDataFound, + #[error("Invalid authentication type value: {0}")] + InvalidAuthenticationByte(u8), + #[error("IO error reading SOCKS authentication type: {0}")] + IOError(#[from] io::Error), +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..9dc1882 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,6 @@ +pub mod client; +pub mod errors; +pub mod messages; +pub mod network; +mod serialize; +pub mod server; diff --git a/src/messages.rs b/src/messages.rs new file mode 100644 index 0000000..f3632e5 --- /dev/null +++ b/src/messages.rs @@ -0,0 +1,604 @@ +use crate::errors::{AuthenticationDeserializationError, DeserializationError, SerializationError}; +use crate::network::SOCKSv5Address; +use crate::serialize::{read_amt, read_string, write_string}; +#[cfg(test)] +use async_std::task; +#[cfg(test)] +use futures::io::Cursor; +use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use log::warn; +#[cfg(test)] +use quickcheck::{quickcheck, Arbitrary, Gen}; +use std::fmt; +use std::net::Ipv4Addr; +use std::pin::Pin; +use thiserror::Error; + +/// Client greetings are the first message sent in a SOCKSv5 session. They +/// identify that there's a client that wants to talk to a server, and that +/// they can support any of the provided mechanisms for authenticating to +/// said server. (It feels weird that the offer/choice goes this way instead +/// of the reverse, but whatever.) +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct ClientGreeting { + pub acceptable_methods: Vec, +} + +impl ClientGreeting { + pub async fn read( + r: Pin<&mut R>, + ) -> Result { + let mut buffer = [0; 1]; + let raw_r = Pin::into_inner(r); + + if raw_r.read(&mut buffer).await? == 0 { + return Err(DeserializationError::NotEnoughData); + } + + if buffer[0] != 5 { + return Err(DeserializationError::InvalidVersion(5, buffer[0])); + } + + if raw_r.read(&mut buffer).await? == 0 { + return Err(DeserializationError::NotEnoughData); + } + + let mut acceptable_methods = Vec::with_capacity(buffer[0] as usize); + for _ in 0..buffer[0] { + acceptable_methods.push(AuthenticationMethod::read(Pin::new(raw_r)).await?); + } + + Ok(ClientGreeting { acceptable_methods }) + } + + pub async fn write( + &self, + w: &mut W, + ) -> Result<(), SerializationError> { + if self.acceptable_methods.len() > 255 { + return Err(SerializationError::TooManyAuthMethods( + self.acceptable_methods.len(), + )); + } + + let mut buffer = Vec::with_capacity(self.acceptable_methods.len() + 2); + buffer.push(5); + buffer.push(self.acceptable_methods.len() as u8); + w.write_all(&buffer).await?; + for authmeth in self.acceptable_methods.iter() { + authmeth.write(w).await?; + } + Ok(()) + } +} + +#[cfg(test)] +impl Arbitrary for ClientGreeting { + fn arbitrary(g: &mut Gen) -> ClientGreeting { + let amt = u8::arbitrary(g); + let mut acceptable_methods = Vec::with_capacity(amt as usize); + + for _ in 0..amt { + acceptable_methods.push(AuthenticationMethod::arbitrary(g)); + } + + ClientGreeting { acceptable_methods } + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct ServerChoice { + pub chosen_method: AuthenticationMethod, +} + +impl ServerChoice { + pub async fn read( + mut r: Pin<&mut R>, + ) -> Result { + let mut buffer = [0; 1]; + + if r.read(&mut buffer).await? == 0 { + return Err(DeserializationError::NotEnoughData); + } + + if buffer[0] != 5 { + return Err(DeserializationError::InvalidVersion(5, buffer[0])); + } + + let chosen_method = AuthenticationMethod::read(r).await?; + + Ok(ServerChoice { chosen_method }) + } + + pub async fn write( + &self, + w: &mut W, + ) -> Result<(), SerializationError> { + w.write_all(&[5]).await?; + self.chosen_method.write(w).await + } +} + +#[cfg(test)] +impl Arbitrary for ServerChoice { + fn arbitrary(g: &mut Gen) -> ServerChoice { + ServerChoice { + chosen_method: AuthenticationMethod::arbitrary(g), + } + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct ClientUsernamePassword { + pub username: String, + pub password: String, +} + +impl ClientUsernamePassword { + pub async fn read( + r: Pin<&mut R>, + ) -> Result { + let mut buffer = [0; 1]; + let raw_r = Pin::into_inner(r); + + if raw_r.read(&mut buffer).await? == 0 { + return Err(DeserializationError::NotEnoughData); + } + + if buffer[0] != 1 { + return Err(DeserializationError::InvalidVersion(1, buffer[0])); + } + + let username = read_string(Pin::new(raw_r)).await?; + let password = read_string(Pin::new(raw_r)).await?; + + Ok(ClientUsernamePassword { username, password }) + } + + pub async fn write( + &self, + w: &mut W, + ) -> Result<(), SerializationError> { + w.write_all(&[1]).await?; + write_string(&self.username, w).await?; + write_string(&self.password, w).await + } +} + +#[cfg(test)] +impl Arbitrary for ClientUsernamePassword { + fn arbitrary(g: &mut Gen) -> Self { + let username = arbitrary_socks_string(g); + let password = arbitrary_socks_string(g); + + ClientUsernamePassword { username, password } + } +} + +#[cfg(test)] +pub fn arbitrary_socks_string(g: &mut Gen) -> String { + loop { + let mut potential = String::arbitrary(g); + + potential.truncate(255); + let bytestring = potential.as_bytes(); + + if bytestring.len() > 0 && bytestring.len() < 256 { + return potential; + } + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct ServerAuthResponse { + pub success: bool, +} + +impl ServerAuthResponse { + pub async fn read( + mut r: Pin<&mut R>, + ) -> Result { + let mut buffer = [0; 1]; + + if r.read(&mut buffer).await? == 0 { + return Err(DeserializationError::NotEnoughData); + } + + if buffer[0] != 1 { + return Err(DeserializationError::InvalidVersion(1, buffer[0])); + } + + if r.read(&mut buffer).await? == 0 { + return Err(DeserializationError::NotEnoughData); + } + + Ok(ServerAuthResponse { + success: buffer[0] == 0, + }) + } + + pub async fn write( + &self, + w: &mut W, + ) -> Result<(), SerializationError> { + w.write_all(&[1]).await?; + w.write_all(&[if self.success { 0x00 } else { 0xde }]) + .await?; + Ok(()) + } +} + +#[cfg(test)] +impl Arbitrary for ServerAuthResponse { + fn arbitrary(g: &mut Gen) -> ServerAuthResponse { + let success = bool::arbitrary(g); + ServerAuthResponse { success } + } +} + +#[allow(clippy::upper_case_acronyms)] +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum AuthenticationMethod { + None, + GSSAPI, + UsernameAndPassword, + ChallengeHandshake, + ChallengeResponse, + SSL, + NDS, + MultiAuthenticationFramework, + JSONPropertyBlock, + PrivateMethod(u8), + NoAcceptableMethods, +} + +impl fmt::Display for AuthenticationMethod { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + AuthenticationMethod::None => write!(f, "No authentication"), + AuthenticationMethod::GSSAPI => write!(f, "GSS-API"), + AuthenticationMethod::UsernameAndPassword => write!(f, "Username and password"), + AuthenticationMethod::ChallengeHandshake => write!(f, "Challenge/Handshake"), + AuthenticationMethod::ChallengeResponse => write!(f, "Challenge/Response"), + AuthenticationMethod::SSL => write!(f, "SSL"), + AuthenticationMethod::NDS => write!(f, "NDS Authentication"), + AuthenticationMethod::MultiAuthenticationFramework => { + write!(f, "Multi-Authentication Framework") + } + AuthenticationMethod::JSONPropertyBlock => write!(f, "JSON Property Block"), + AuthenticationMethod::PrivateMethod(m) => write!(f, "Private Method {:x}", m), + AuthenticationMethod::NoAcceptableMethods => write!(f, "No Acceptable Methods"), + } + } +} + +impl AuthenticationMethod { + pub async fn read( + mut r: Pin<&mut R>, + ) -> Result { + let mut byte_buffer = [0u8; 1]; + let amount_read = r.read(&mut byte_buffer).await?; + + if amount_read == 0 { + return Err(AuthenticationDeserializationError::NoDataFound.into()); + } + + match byte_buffer[0] { + 0 => Ok(AuthenticationMethod::None), + 1 => Ok(AuthenticationMethod::GSSAPI), + 2 => Ok(AuthenticationMethod::UsernameAndPassword), + 3 => Ok(AuthenticationMethod::ChallengeHandshake), + 5 => Ok(AuthenticationMethod::ChallengeResponse), + 6 => Ok(AuthenticationMethod::SSL), + 7 => Ok(AuthenticationMethod::NDS), + 8 => Ok(AuthenticationMethod::MultiAuthenticationFramework), + 9 => Ok(AuthenticationMethod::JSONPropertyBlock), + x if (0x80..=0xfe).contains(&x) => Ok(AuthenticationMethod::PrivateMethod(x)), + 0xff => Ok(AuthenticationMethod::NoAcceptableMethods), + e => Err(AuthenticationDeserializationError::InvalidAuthenticationByte(e).into()), + } + } + + pub async fn write( + &self, + w: &mut W, + ) -> Result<(), SerializationError> { + let value = match self { + AuthenticationMethod::None => 0, + AuthenticationMethod::GSSAPI => 1, + AuthenticationMethod::UsernameAndPassword => 2, + AuthenticationMethod::ChallengeHandshake => 3, + AuthenticationMethod::ChallengeResponse => 5, + AuthenticationMethod::SSL => 6, + AuthenticationMethod::NDS => 7, + AuthenticationMethod::MultiAuthenticationFramework => 8, + AuthenticationMethod::JSONPropertyBlock => 9, + AuthenticationMethod::PrivateMethod(pm) => *pm, + AuthenticationMethod::NoAcceptableMethods => 0xff, + }; + + Ok(w.write_all(&[value]).await?) + } +} + +#[cfg(test)] +impl Arbitrary for AuthenticationMethod { + fn arbitrary(g: &mut Gen) -> AuthenticationMethod { + let mut vals = vec![ + AuthenticationMethod::None, + AuthenticationMethod::GSSAPI, + AuthenticationMethod::UsernameAndPassword, + AuthenticationMethod::ChallengeHandshake, + AuthenticationMethod::ChallengeResponse, + AuthenticationMethod::SSL, + AuthenticationMethod::NDS, + AuthenticationMethod::MultiAuthenticationFramework, + AuthenticationMethod::JSONPropertyBlock, + AuthenticationMethod::NoAcceptableMethods, + ]; + for x in 0x80..0xffu8 { + vals.push(AuthenticationMethod::PrivateMethod(x)); + } + g.choose(&vals).unwrap().clone() + } +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum ClientConnectionCommand { + EstablishTCPStream, + EstablishTCPPortBinding, + AssociateUDPPort, +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct ClientConnectionRequest { + pub command_code: ClientConnectionCommand, + pub destination_address: SOCKSv5Address, + pub destination_port: u16, +} + +impl ClientConnectionRequest { + pub async fn read( + r: Pin<&mut R>, + ) -> Result { + let mut buffer = [0; 2]; + let raw_r = Pin::into_inner(r); + + read_amt(Pin::new(raw_r), 2, &mut buffer).await?; + + if buffer[0] != 5 { + return Err(DeserializationError::InvalidVersion(5, buffer[0])); + } + + let command_code = match buffer[1] { + 0x01 => ClientConnectionCommand::EstablishTCPStream, + 0x02 => ClientConnectionCommand::EstablishTCPPortBinding, + 0x03 => ClientConnectionCommand::AssociateUDPPort, + x => return Err(DeserializationError::InvalidClientCommand(x)), + }; + + let destination_address = SOCKSv5Address::read(Pin::new(raw_r)).await?; + + read_amt(Pin::new(raw_r), 2, &mut buffer).await?; + let destination_port = ((buffer[0] as u16) << 8) + (buffer[1] as u16); + + Ok(ClientConnectionRequest { + command_code, + destination_address, + destination_port, + }) + } + + pub async fn write( + &self, + w: &mut W, + ) -> Result<(), SerializationError> { + let command = match self.command_code { + ClientConnectionCommand::EstablishTCPStream => 1, + ClientConnectionCommand::EstablishTCPPortBinding => 2, + ClientConnectionCommand::AssociateUDPPort => 3, + }; + + w.write_all(&[5, command]).await?; + self.destination_address.write(w).await?; + w.write_all(&[ + (self.destination_port >> 8) as u8, + (self.destination_port & 0xffu16) as u8, + ]) + .await + .map_err(SerializationError::IOError) + } +} + +#[cfg(test)] +impl Arbitrary for ClientConnectionCommand { + fn arbitrary(g: &mut Gen) -> ClientConnectionCommand { + g.choose(&[ + ClientConnectionCommand::EstablishTCPStream, + ClientConnectionCommand::EstablishTCPPortBinding, + ClientConnectionCommand::AssociateUDPPort, + ]) + .unwrap() + .clone() + } +} + +#[cfg(test)] +impl Arbitrary for ClientConnectionRequest { + fn arbitrary(g: &mut Gen) -> Self { + let command_code = ClientConnectionCommand::arbitrary(g); + let destination_address = SOCKSv5Address::arbitrary(g); + let destination_port = u16::arbitrary(g); + + ClientConnectionRequest { + command_code, + destination_address, + destination_port, + } + } +} + +#[derive(Clone, Debug, Eq, Error, PartialEq)] +pub enum ServerResponseStatus { + #[error("Actually, everything's fine (weird to see this in an error)")] + RequestGranted, + #[error("General server failure")] + GeneralFailure, + #[error("Connection not allowed by policy rule")] + ConnectionNotAllowedByRule, + #[error("Network unreachable")] + NetworkUnreachable, + #[error("Host unreachable")] + HostUnreachable, + #[error("Connection refused")] + ConnectionRefused, + #[error("TTL expired")] + TTLExpired, + #[error("Command not supported")] + CommandNotSupported, + #[error("Address type not supported")] + AddressTypeNotSupported, +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct ServerResponse { + pub status: ServerResponseStatus, + pub bound_address: SOCKSv5Address, + pub bound_port: u16, +} + +impl ServerResponse { + pub fn error>(resp: E) -> ServerResponse { + ServerResponse { + status: resp.into(), + bound_address: SOCKSv5Address::IP4(Ipv4Addr::new(0, 0, 0, 0)), + bound_port: 0, + } + } +} + +impl ServerResponse { + pub async fn read( + r: Pin<&mut R>, + ) -> Result { + let mut buffer = [0; 3]; + let raw_r = Pin::into_inner(r); + + read_amt(Pin::new(raw_r), 3, &mut buffer).await?; + + if buffer[0] != 5 { + return Err(DeserializationError::InvalidVersion(5, buffer[0])); + } + + if buffer[2] != 0 { + warn!(target: "async-socks5", "Hey, this isn't terrible, but the server is sending invalid reserved bytes."); + } + + let status = match buffer[1] { + 0x00 => ServerResponseStatus::RequestGranted, + 0x01 => ServerResponseStatus::GeneralFailure, + 0x02 => ServerResponseStatus::ConnectionNotAllowedByRule, + 0x03 => ServerResponseStatus::NetworkUnreachable, + 0x04 => ServerResponseStatus::HostUnreachable, + 0x05 => ServerResponseStatus::ConnectionRefused, + 0x06 => ServerResponseStatus::TTLExpired, + 0x07 => ServerResponseStatus::CommandNotSupported, + 0x08 => ServerResponseStatus::AddressTypeNotSupported, + x => return Err(DeserializationError::InvalidServerResponse(x)), + }; + + let bound_address = SOCKSv5Address::read(Pin::new(raw_r)).await?; + read_amt(Pin::new(raw_r), 2, &mut buffer).await?; + let bound_port = ((buffer[0] as u16) << 8) + (buffer[1] as u16); + + Ok(ServerResponse { + status, + bound_address, + bound_port, + }) + } + + pub async fn write( + &self, + w: &mut W, + ) -> Result<(), SerializationError> { + let status_code = match self.status { + ServerResponseStatus::RequestGranted => 0x00, + ServerResponseStatus::GeneralFailure => 0x01, + ServerResponseStatus::ConnectionNotAllowedByRule => 0x02, + ServerResponseStatus::NetworkUnreachable => 0x03, + ServerResponseStatus::HostUnreachable => 0x04, + ServerResponseStatus::ConnectionRefused => 0x05, + ServerResponseStatus::TTLExpired => 0x06, + ServerResponseStatus::CommandNotSupported => 0x07, + ServerResponseStatus::AddressTypeNotSupported => 0x08, + }; + + w.write_all(&[5, status_code, 0]).await?; + self.bound_address.write(w).await?; + w.write_all(&[ + (self.bound_port >> 8) as u8, + (self.bound_port & 0xffu16) as u8, + ]) + .await + .map_err(SerializationError::IOError) + } +} + +#[cfg(test)] +impl Arbitrary for ServerResponseStatus { + fn arbitrary(g: &mut Gen) -> ServerResponseStatus { + g.choose(&[ + ServerResponseStatus::RequestGranted, + ServerResponseStatus::GeneralFailure, + ServerResponseStatus::ConnectionNotAllowedByRule, + ServerResponseStatus::NetworkUnreachable, + ServerResponseStatus::HostUnreachable, + ServerResponseStatus::ConnectionRefused, + ServerResponseStatus::TTLExpired, + ServerResponseStatus::CommandNotSupported, + ServerResponseStatus::AddressTypeNotSupported, + ]) + .unwrap() + .clone() + } +} + +#[cfg(test)] +impl Arbitrary for ServerResponse { + fn arbitrary(g: &mut Gen) -> Self { + let status = ServerResponseStatus::arbitrary(g); + let bound_address = SOCKSv5Address::arbitrary(g); + let bound_port = u16::arbitrary(g); + + ServerResponse { + status, + bound_address, + bound_port, + } + } +} + +macro_rules! standard_roundtrip { + ($name: ident, $t: ty) => { + #[cfg(test)] + quickcheck! { + fn $name(xs: $t) -> bool { + let mut buffer = vec![]; + task::block_on(xs.write(&mut buffer)).unwrap(); + let mut cursor = Cursor::new(buffer); + let ys = <$t>::read(Pin::new(&mut cursor)); + xs == task::block_on(ys).unwrap() + } + } + }; +} + +standard_roundtrip!(auth_byte_roundtrips, AuthenticationMethod); +standard_roundtrip!(client_greeting_roundtrips, ClientGreeting); +standard_roundtrip!(server_choice_roundtrips, ServerChoice); +standard_roundtrip!(username_password_roundtrips, ClientUsernamePassword); +standard_roundtrip!(server_auth_response, ServerAuthResponse); +standard_roundtrip!(address_roundtrips, SOCKSv5Address); +standard_roundtrip!(client_request_roundtrips, ClientConnectionRequest); +standard_roundtrip!(server_response_roundtrips, ServerResponse); diff --git a/src/network.rs b/src/network.rs new file mode 100644 index 0000000..c1b22ba --- /dev/null +++ b/src/network.rs @@ -0,0 +1,43 @@ +pub mod address; +pub mod datagram; +pub mod generic; +pub mod listener; +pub mod standard; +pub mod stream; + +use crate::messages::ServerResponseStatus; +pub use crate::network::address::{SOCKSv5Address, ToSOCKSAddress}; +pub use crate::network::standard::Builtin; +use async_trait::async_trait; +use futures::{AsyncRead, AsyncWrite}; +use std::fmt; + +#[async_trait] +pub trait Network { + type Stream: AsyncRead + AsyncWrite + Clone + Send + Sync + Unpin + 'static; + type Listener: SingleShotListener + Send + Sync + 'static; + type UdpSocket; + type Error: fmt::Debug + fmt::Display + Into; + + async fn connect( + &mut self, + addr: A, + port: u16, + ) -> Result; + async fn udp_socket( + &mut self, + addr: A, + port: Option, + ) -> Result; + async fn listen( + &mut self, + addr: A, + port: Option, + ) -> Result; +} + +#[async_trait] +pub trait SingleShotListener { + async fn accept(self) -> Result; + fn info(&self) -> Result<(SOCKSv5Address, u16), Error>; +} diff --git a/src/network/address.rs b/src/network/address.rs new file mode 100644 index 0000000..fa1c584 --- /dev/null +++ b/src/network/address.rs @@ -0,0 +1,151 @@ +use crate::errors::{DeserializationError, SerializationError}; +#[cfg(test)] +use crate::messages::arbitrary_socks_string; +use crate::serialize::{read_amt, read_string, write_string}; +use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +#[cfg(test)] +use quickcheck::{Arbitrary, Gen}; +use std::fmt; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; +use std::pin::Pin; + +pub trait ToSOCKSAddress: Send { + fn to_socks_address(&self) -> SOCKSv5Address; +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum SOCKSv5Address { + IP4(Ipv4Addr), + IP6(Ipv6Addr), + Name(String), +} + +impl ToSOCKSAddress for SOCKSv5Address { + fn to_socks_address(&self) -> SOCKSv5Address { + self.clone() + } +} + +impl ToSOCKSAddress for IpAddr { + fn to_socks_address(&self) -> SOCKSv5Address { + match self { + IpAddr::V4(a) => SOCKSv5Address::IP4(*a), + IpAddr::V6(a) => SOCKSv5Address::IP6(*a), + } + } +} + +impl ToSOCKSAddress for Ipv4Addr { + fn to_socks_address(&self) -> SOCKSv5Address { + SOCKSv5Address::IP4(*self) + } +} + +impl ToSOCKSAddress for Ipv6Addr { + fn to_socks_address(&self) -> SOCKSv5Address { + SOCKSv5Address::IP6(*self) + } +} + +impl ToSOCKSAddress for String { + fn to_socks_address(&self) -> SOCKSv5Address { + SOCKSv5Address::Name(self.clone()) + } +} + +impl<'a> ToSOCKSAddress for &'a str { + fn to_socks_address(&self) -> SOCKSv5Address { + SOCKSv5Address::Name(self.to_string()) + } +} + +impl fmt::Display for SOCKSv5Address { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + SOCKSv5Address::IP4(a) => write!(f, "{}", a), + SOCKSv5Address::IP6(a) => write!(f, "{}", a), + SOCKSv5Address::Name(a) => write!(f, "{}", a), + } + } +} + +impl From for SOCKSv5Address { + fn from(addr: IpAddr) -> SOCKSv5Address { + match addr { + IpAddr::V4(a) => SOCKSv5Address::IP4(a), + IpAddr::V6(a) => SOCKSv5Address::IP6(a), + } + } +} + +impl SOCKSv5Address { + pub async fn read( + mut r: Pin<&mut R>, + ) -> Result { + let mut byte_buffer = [0u8; 1]; + let amount_read = r.read(&mut byte_buffer).await?; + + if amount_read == 0 { + return Err(DeserializationError::NotEnoughData); + } + + match byte_buffer[0] { + 1 => { + let mut addr_buffer = [0; 4]; + read_amt(r, 4, &mut addr_buffer).await?; + Ok(SOCKSv5Address::IP4(Ipv4Addr::from(addr_buffer))) + } + 3 => { + let mut addr_buffer = [0; 16]; + read_amt(r, 16, &mut addr_buffer).await?; + Ok(SOCKSv5Address::IP6(Ipv6Addr::from(addr_buffer))) + } + 4 => { + let name = read_string(r).await?; + Ok(SOCKSv5Address::Name(name)) + } + x => Err(DeserializationError::InvalidAddressType(x)), + } + } + + pub async fn write( + &self, + w: &mut W, + ) -> Result<(), SerializationError> { + match self { + SOCKSv5Address::IP4(x) => { + w.write_all(&[1]).await?; + w.write_all(&x.octets()) + .await + .map_err(SerializationError::IOError) + } + SOCKSv5Address::IP6(x) => { + w.write_all(&[3]).await?; + w.write_all(&x.octets()) + .await + .map_err(SerializationError::IOError) + } + SOCKSv5Address::Name(x) => { + w.write_all(&[4]).await?; + write_string(x, w).await + } + } + } +} + +#[cfg(test)] +impl Arbitrary for SOCKSv5Address { + fn arbitrary(g: &mut Gen) -> Self { + let ip4 = Ipv4Addr::arbitrary(g); + let ip6 = Ipv6Addr::arbitrary(g); + let nm = arbitrary_socks_string(g); + + g.choose(&[ + SOCKSv5Address::IP4(ip4), + SOCKSv5Address::IP6(ip6), + SOCKSv5Address::Name(nm), + ]) + .unwrap() + .clone() + } +} diff --git a/src/network/datagram.rs b/src/network/datagram.rs new file mode 100644 index 0000000..fd74f13 --- /dev/null +++ b/src/network/datagram.rs @@ -0,0 +1,37 @@ +use crate::network::address::SOCKSv5Address; +use async_trait::async_trait; + +#[async_trait] +pub trait Datagramlike: Send + Sync { + type Error; + + async fn send_to( + &self, + buf: &[u8], + addr: SOCKSv5Address, + port: u16, + ) -> Result; + async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SOCKSv5Address, u16), Self::Error>; +} + +pub struct GenericDatagramSocket { + pub internal: Box>, +} + +#[async_trait] +impl Datagramlike for GenericDatagramSocket { + type Error = E; + + async fn send_to( + &self, + buf: &[u8], + addr: SOCKSv5Address, + port: u16, + ) -> Result { + Ok(self.internal.send_to(buf, addr, port).await?) + } + + async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SOCKSv5Address, u16), Self::Error> { + Ok(self.internal.recv_from(buf).await?) + } +} diff --git a/src/network/generic.rs b/src/network/generic.rs new file mode 100644 index 0000000..00e2a64 --- /dev/null +++ b/src/network/generic.rs @@ -0,0 +1,48 @@ +use crate::messages::ServerResponseStatus; +use crate::network::address::ToSOCKSAddress; +use crate::network::datagram::GenericDatagramSocket; +use crate::network::listener::GenericListener; +use crate::network::stream::GenericStream; +use async_trait::async_trait; +use std::fmt::Display; + +#[async_trait] +pub trait Networklike { + /// The error type for things that fail on this network. Apologies in advance + /// for using only one; if you have a use case for separating your errors, + /// please shoot the author(s) and email to split this into multiple types, one + /// for each trait function. + type Error: Display + Into + Send; + + /// Connect to the given address and port, over this kind of network. The + /// underlying stream should behave somewhat like a TCP stream ... which + /// may be exactly what you're using. However, in order to support tunnelling + /// scenarios (i.e., using another proxy, going through Tor or SSH, etc.) we + /// work generically over any stream-like object. + async fn connect( + &mut self, + addr: A, + port: u16, + ) -> Result; + + /// Listen for connections on the given address and port, returning a generic + /// listener socket to use in the future. + async fn listen( + &mut self, + addr: A, + port: u16, + ) -> Result, Self::Error>; + + /// Bind a socket for the purposes of doing some datagram communication. NOTE! + /// this is only for UDP-like communication, not for generic connecting or + /// listening! Maybe obvious from the types, but POSIX has overtrained many + /// of us. + /// + /// Recall when using these functions that datagram protocols allow for packet + /// loss and out-of-order delivery. So ... be warned. + async fn bind( + &mut self, + addr: A, + port: u16, + ) -> Result, Self::Error>; +} diff --git a/src/network/listener.rs b/src/network/listener.rs new file mode 100644 index 0000000..f125948 --- /dev/null +++ b/src/network/listener.rs @@ -0,0 +1,28 @@ +use crate::network::address::SOCKSv5Address; +use crate::network::stream::GenericStream; +use async_trait::async_trait; + +#[async_trait] +pub trait Listenerlike: Send + Sync { + type Error; + + async fn accept(&self) -> Result<(GenericStream, SOCKSv5Address, u16), Self::Error>; + fn info(&self) -> (SOCKSv5Address, u16); +} + +pub struct GenericListener { + pub internal: Box>, +} + +#[async_trait] +impl Listenerlike for GenericListener { + type Error = E; + + async fn accept(&self) -> Result<(GenericStream, SOCKSv5Address, u16), Self::Error> { + Ok(self.internal.accept().await?) + } + + fn info(&self) -> (SOCKSv5Address, u16) { + self.internal.info() + } +} diff --git a/src/network/standard.rs b/src/network/standard.rs new file mode 100644 index 0000000..8eae7de --- /dev/null +++ b/src/network/standard.rs @@ -0,0 +1,214 @@ +use crate::messages::ServerResponseStatus; +use crate::network::address::{SOCKSv5Address, ToSOCKSAddress}; +use crate::network::datagram::{Datagramlike, GenericDatagramSocket}; +use crate::network::generic::Networklike; +use crate::network::listener::{GenericListener, Listenerlike}; +use crate::network::stream::{GenericStream, Streamlike}; +use async_std::io; +use async_std::net::{TcpListener, TcpStream, UdpSocket}; +use async_trait::async_trait; +use log::error; +use std::net::Ipv4Addr; + +pub struct Builtin {} + +impl Builtin { + pub fn new() -> Builtin { + Builtin {} + } +} + +impl Streamlike for TcpStream {} + +#[async_trait] +impl Listenerlike for TcpListener { + type Error = io::Error; + + async fn accept(&self) -> Result<(GenericStream, SOCKSv5Address, u16), Self::Error> { + let (base, addrport) = self.accept().await?; + let addr = addrport.ip(); + let port = addrport.port(); + Ok((GenericStream::from(base), SOCKSv5Address::from(addr), port)) + } + + fn info(&self) -> (SOCKSv5Address, u16) { + match self.local_addr() { + Ok(x) => { + let addr = SOCKSv5Address::from(x.ip()); + let port = x.port(); + (addr, port) + } + Err(e) => { + error!("Someone asked for a listener address, and we got an error ({}); returning 0.0.0.0:0", e); + (SOCKSv5Address::IP4(Ipv4Addr::from(0)), 0) + } + } + } +} + +#[async_trait] +impl Datagramlike for UdpSocket { + type Error = io::Error; + + async fn send_to( + &self, + buf: &[u8], + addr: SOCKSv5Address, + port: u16, + ) -> Result { + match addr { + SOCKSv5Address::IP4(a) => self.send_to(buf, (a, port)).await, + SOCKSv5Address::IP6(a) => self.send_to(buf, (a, port)).await, + SOCKSv5Address::Name(n) => self.send_to(buf, (n.as_str(), port)).await, + } + } + + async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SOCKSv5Address, u16), Self::Error> { + let (amt, addrport) = self.recv_from(buf).await?; + let addr = addrport.ip(); + let port = addrport.port(); + Ok((amt, SOCKSv5Address::from(addr), port)) + } +} + +#[async_trait] +impl Networklike for Builtin { + type Error = io::Error; + + async fn connect( + &mut self, + addr: A, + port: u16, + ) -> Result { + let target = addr.to_socks_address(); + + let base_stream = match target { + SOCKSv5Address::IP4(a) => TcpStream::connect((a, port)).await?, + SOCKSv5Address::IP6(a) => TcpStream::connect((a, port)).await?, + SOCKSv5Address::Name(n) => TcpStream::connect((n.as_str(), port)).await?, + }; + + Ok(GenericStream::from(base_stream)) + } + + async fn listen( + &mut self, + addr: A, + port: u16, + ) -> Result, Self::Error> { + let target = addr.to_socks_address(); + + let base_stream = match target { + SOCKSv5Address::IP4(a) => TcpListener::bind((a, port)).await?, + SOCKSv5Address::IP6(a) => TcpListener::bind((a, port)).await?, + SOCKSv5Address::Name(n) => TcpListener::bind((n.as_str(), port)).await?, + }; + + Ok(GenericListener { + internal: Box::new(base_stream), + }) + } + + async fn bind( + &mut self, + addr: A, + port: u16, + ) -> Result, Self::Error> { + let target = addr.to_socks_address(); + + let base_socket = match target { + SOCKSv5Address::IP4(a) => UdpSocket::bind((a, port)).await?, + SOCKSv5Address::IP6(a) => UdpSocket::bind((a, port)).await?, + SOCKSv5Address::Name(n) => UdpSocket::bind((n.as_str(), port)).await?, + }; + + Ok(GenericDatagramSocket { + internal: Box::new(base_socket), + }) + } +} + +// pub struct StandardNetworking {} +// +// impl StandardNetworking { +// pub fn new() -> StandardNetworking { +// StandardNetworking {} +// } +// } +// +impl From for ServerResponseStatus { + fn from(e: io::Error) -> ServerResponseStatus { + match e.kind() { + io::ErrorKind::ConnectionRefused => ServerResponseStatus::ConnectionRefused, + io::ErrorKind::NotFound => ServerResponseStatus::HostUnreachable, + _ => ServerResponseStatus::GeneralFailure, + } + } +} +// +// #[async_trait] +// impl Network for StandardNetworking { +// type Stream = TcpStream; +// type Listener = TcpListener; +// type UdpSocket = UdpSocket; +// type Error = io::Error; +// +// async fn connect( +// &mut self, +// addr: A, +// port: u16, +// ) -> Result { +// let target = addr.to_socks_address(); +// +// match target { +// SOCKSv5Address::IP4(a) => TcpStream::connect((a, port)).await, +// SOCKSv5Address::IP6(a) => TcpStream::connect((a, port)).await, +// SOCKSv5Address::Name(n) => TcpStream::connect((n.as_str(), port)).await, +// } +// } +// +// async fn udp_socket( +// &mut self, +// addr: A, +// port: Option, +// ) -> Result { +// let me = addr.to_socks_address(); +// let real_port = port.unwrap_or(0); +// +// match me { +// SOCKSv5Address::IP4(a) => UdpSocket::bind((a, real_port)).await, +// SOCKSv5Address::IP6(a) => UdpSocket::bind((a, real_port)).await, +// SOCKSv5Address::Name(n) => UdpSocket::bind((n.as_str(), real_port)).await, +// } +// } +// +// async fn listen( +// &mut self, +// addr: A, +// port: Option, +// ) -> Result { +// let me = addr.to_socks_address(); +// let real_port = port.unwrap_or(0); +// +// match me { +// SOCKSv5Address::IP4(a) => TcpListener::bind((a, real_port)).await, +// SOCKSv5Address::IP6(a) => TcpListener::bind((a, real_port)).await, +// SOCKSv5Address::Name(n) => TcpListener::bind((n.as_str(), real_port)).await, +// } +// } +// } +// +// #[async_trait] +// impl SingleShotListener for TcpListener { +// async fn accept(self) -> Result { +// self.accept().await +// } +// +// fn info(&self) -> Result<(SOCKSv5Address, u16), io::Error> { +// match self.local_addr()? { +// SocketAddr::V4(a) => Ok((SOCKSv5Address::IP4(*a.ip()), a.port())), +// SocketAddr::V6(a) => Ok((SOCKSv5Address::IP6(*a.ip()), a.port())), +// } +// } +// } +// diff --git a/src/network/stream.rs b/src/network/stream.rs new file mode 100644 index 0000000..9887fbb --- /dev/null +++ b/src/network/stream.rs @@ -0,0 +1,52 @@ +use async_std::task::{Context, Poll}; +use futures::io; +use futures::io::{AsyncRead, AsyncWrite}; +use std::pin::Pin; +use std::sync::Arc; + +pub trait Streamlike: AsyncRead + AsyncWrite + Send + Sync + Unpin {} + +#[derive(Clone)] +pub struct GenericStream { + internal: Arc>, +} + +impl AsyncRead for GenericStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let base = Pin::into_inner(self); + Pin::new(base).poll_read(cx, buf) + } +} + +impl AsyncWrite for GenericStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let base = Pin::into_inner(self); + Pin::new(base).poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let base = Pin::into_inner(self); + Pin::new(base).poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let base = Pin::into_inner(self); + Pin::new(base).poll_close(cx) + } +} + +impl From for GenericStream { + fn from(x: T) -> GenericStream { + GenericStream { + internal: Arc::new(Box::new(x)), + } + } +} diff --git a/src/serialize.rs b/src/serialize.rs new file mode 100644 index 0000000..5f5e424 --- /dev/null +++ b/src/serialize.rs @@ -0,0 +1,60 @@ +use crate::errors::{DeserializationError, SerializationError}; +use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use std::pin::Pin; + +pub async fn read_string( + mut r: Pin<&mut R>, +) -> Result { + let mut length_buffer = [0; 1]; + + if r.read(&mut length_buffer).await? == 0 { + return Err(DeserializationError::NotEnoughData); + } + + let target = length_buffer[0] as usize; + + if target == 0 { + return Err(DeserializationError::InvalidEmptyString); + } + + let mut bytestring = vec![0; target]; + read_amt(r, target, &mut bytestring).await?; + + Ok(String::from_utf8(bytestring)?) +} + +pub async fn write_string( + s: &str, + w: &mut W, +) -> Result<(), SerializationError> { + let bytestring = s.as_bytes(); + + if bytestring.is_empty() || bytestring.len() > 255 { + return Err(SerializationError::InvalidStringLength(s.to_string())); + } + + w.write_all(&[bytestring.len() as u8]).await?; + w.write_all(bytestring) + .await + .map_err(SerializationError::IOError) +} + +pub async fn read_amt( + mut r: Pin<&mut R>, + amt: usize, + buffer: &mut [u8], +) -> Result<(), DeserializationError> { + let mut amt_read = 0; + + while amt_read < amt { + let chunk_amt = r.read(&mut buffer[amt_read..]).await?; + + if chunk_amt == 0 { + return Err(DeserializationError::NotEnoughData); + } + + amt_read += chunk_amt; + } + + Ok(()) +} diff --git a/src/server.rs b/src/server.rs new file mode 100644 index 0000000..de8611d --- /dev/null +++ b/src/server.rs @@ -0,0 +1,313 @@ +use crate::errors::{DeserializationError, SerializationError}; +use crate::messages::{ + AuthenticationMethod, ClientConnectionCommand, ClientConnectionRequest, ClientGreeting, + ClientUsernamePassword, ServerChoice, ServerResponse, ServerResponseStatus, +}; +use crate::network::generic::Networklike; +use crate::network::listener::{GenericListener, Listenerlike}; +use crate::network::stream::GenericStream; +use crate::network::SOCKSv5Address; +use async_std::io; +use async_std::io::prelude::WriteExt; +use async_std::sync::{Arc, Mutex}; +use async_std::task; +use log::{error, info, trace, warn}; +use std::pin::Pin; +use thiserror::Error; + +pub struct SOCKSv5Server { + network: N, + security_parameters: SecurityParameters, + listener: GenericListener, +} + +#[derive(Clone)] +pub struct SecurityParameters { + pub allow_unauthenticated: bool, + pub allow_connection: Option bool>, + pub check_password: Option bool>, + pub connect_tls: Option Option>, +} + +impl SOCKSv5Server { + pub fn new + 'static>( + network: N, + security_parameters: SecurityParameters, + stream: S, + ) -> SOCKSv5Server { + SOCKSv5Server { + network, + security_parameters, + listener: GenericListener { + internal: Box::new(stream), + }, + } + } + + pub async fn run(self) -> Result<(), N::Error> { + let (my_addr, my_port) = self.listener.info(); + info!("Starting SOCKSv5 server on {}:{}", my_addr, my_port); + let locked_network = Arc::new(Mutex::new(self.network)); + + loop { + let (stream, their_addr, their_port) = self.listener.accept().await?; + + trace!( + "Initial accept of connection from {}:{}", + their_addr, + their_port + ); + if let Some(checker) = &self.security_parameters.allow_connection { + if !checker(&their_addr, their_port) { + info!( + "Rejecting attempted connection from {}:{}", + their_addr, their_port + ); + continue; + } + } + + let params = self.security_parameters.clone(); + let network_mutex_copy = locked_network.clone(); + task::spawn(async move { + if let Some(authed_stream) = + run_authentication(params, stream, their_addr.clone(), their_port).await + { + if let Err(e) = run_main_loop(network_mutex_copy, authed_stream).await { + warn!("Failure in main loop: {}", e); + } + } + }); + } + } +} + +async fn run_authentication( + params: SecurityParameters, + mut stream: GenericStream, + addr: SOCKSv5Address, + port: u16, +) -> Option { + match ClientGreeting::read(Pin::new(&mut stream)).await { + Err(e) => { + error!( + "Client hello deserialization error from {}:{}: {}", + addr, port, e + ); + None + } + + // So we get opinionated here, based on what we think should be our first choice if the + // server offers something up. So we'll first see if we can make this a TLS connection. + Ok(cg) + if cg.acceptable_methods.contains(&AuthenticationMethod::SSL) + && params.connect_tls.is_some() => + { + match params.connect_tls { + None => { + error!("Internal error: TLS handler was there, but is now gone"); + None + } + Some(converter) => match converter(stream) { + None => { + info!("Rejecting bad TLS handshake from {}:{}", addr, port); + None + } + Some(new_stream) => Some(new_stream), + }, + } + } + + // if we can't do that, we'll see if we can get a username and password + Ok(cg) + if cg + .acceptable_methods + .contains(&&AuthenticationMethod::UsernameAndPassword) + && params.check_password.is_some() => + { + match ClientUsernamePassword::read(Pin::new(&mut stream)).await { + Err(e) => { + warn!( + "Error reading username/password from {}:{}: {}", + addr, port, e + ); + None + } + Ok(userinfo) => { + let checker = params.check_password.unwrap_or(|_, _| false); + if checker(&userinfo.username, &userinfo.password) { + Some(stream) + } else { + None + } + } + } + } + + // and, in the worst case, we'll see if our user is cool with unauthenticated connections + Ok(cg) + if cg.acceptable_methods.contains(&&AuthenticationMethod::None) + && params.allow_unauthenticated => + { + Some(stream) + } + + Ok(_) => { + let rejection_letter = ServerChoice { + chosen_method: AuthenticationMethod::NoAcceptableMethods, + }; + + if let Err(e) = rejection_letter.write(&mut stream).await { + warn!( + "Error sending rejection letter in authentication response: {}", + e + ); + } + + if let Err(e) = stream.flush().await { + warn!( + "Error flushing buffer after rejection latter in authentication response: {}", + e + ); + } + + None + } + } +} + +#[derive(Error, Debug)] +enum ServerError { + #[error("Error in deserialization: {0}")] + DeserializationError(#[from] DeserializationError), + #[error("Error in serialization: {0}")] + SerializationError(#[from] SerializationError), +} + +async fn run_main_loop( + network: Arc>, + mut stream: GenericStream, +) -> Result<(), ServerError> +where + N: Networklike, + N::Error: 'static, +{ + loop { + let ccr = ClientConnectionRequest::read(Pin::new(&mut stream)).await?; + + match ccr.command_code { + ClientConnectionCommand::AssociateUDPPort => {} + + ClientConnectionCommand::EstablishTCPPortBinding => {} + + ClientConnectionCommand::EstablishTCPStream => { + let target = format!("{}:{}", ccr.destination_address, ccr.destination_port); + + info!( + "Client requested connection to {}:{}", + ccr.destination_address, ccr.destination_port + ); + let connection_res = { + let mut network = network.lock().await; + network + .connect(ccr.destination_address.clone(), ccr.destination_port) + .await + }; + let outgoing_stream = match connection_res { + Ok(x) => x, + Err(e) => { + error!("Failed to connect to {}: {}", target, e); + let response = ServerResponse::error(e); + response.write(&mut stream).await?; + continue; + } + }; + trace!( + "Connection established to {}:{}", + ccr.destination_address, + ccr.destination_port + ); + + let incoming_res = { + let mut network = network.lock().await; + network.listen("127.0.0.1", 0).await + }; + let incoming_listener = match incoming_res { + Ok(x) => x, + Err(e) => { + error!("Failed to bind server port for new TCP stream: {}", e); + let response = ServerResponse::error(e); + response.write(&mut stream).await?; + continue; + } + }; + let (bound_address, bound_port) = incoming_listener.info(); + trace!( + "Set up {}:{} to address request for {}:{}", + bound_address, + bound_port, + ccr.destination_address, + ccr.destination_port + ); + + let response = ServerResponse { + status: ServerResponseStatus::RequestGranted, + bound_address, + bound_port, + }; + response.write(&mut stream).await?; + + task::spawn(async move { + let (incoming_stream, from_addr, from_port) = match incoming_listener + .accept() + .await + { + Err(e) => { + error!("Miscellaneous error waiting for someone to connect for proxying: {}", e); + return; + } + Ok(s) => s, + }; + trace!( + "Accepted connection from {}:{} to attach to {}:{}", + from_addr, + from_port, + ccr.destination_address, + ccr.destination_port + ); + + let mut from_left = incoming_stream.clone(); + let mut from_right = outgoing_stream.clone(); + let mut to_left = incoming_stream; + let mut to_right = outgoing_stream; + let from = format!("{}:{}", from_addr, from_port); + let to = format!("{}:{}", ccr.destination_address, ccr.destination_port); + + task::spawn(async move { + info!( + "Spawned {}:{} >--> {}:{} task", + from_addr, from_port, ccr.destination_address, ccr.destination_port + ); + if let Err(e) = io::copy(&mut from_left, &mut to_right).await { + warn!( + "{}:{} >--> {}:{} connection failed with: {}", + from_addr, + from_port, + ccr.destination_address, + ccr.destination_port, + e + ); + } + }); + + task::spawn(async move { + info!("Spawned {} <--< {} task", from, to); + if let Err(e) = io::copy(&mut from_right, &mut to_left).await { + warn!("{} <--< {} connection failed with: {}", from, to, e); + } + }); + }); + } + } + } +} \ No newline at end of file