Compare commits

17 Commits

Author SHA1 Message Date
65e79b4237 Try being multiplatform with our config files. 2022-11-23 19:49:58 -08:00
1d182a150f Checkpoint, or something. 2022-11-22 20:13:14 -08:00
277125e1a0 Make write() consume the objects. 2022-05-14 20:37:18 -07:00
c8279cfc5f Switch to basic tokio; will expand later to arbitrary backends. 2022-05-14 20:28:09 -07:00
d284f60d67 Whoops! Had swapped the IPv6 and Name tags; now Firefox works. 2022-01-21 20:24:23 -08:00
811580c64f Just to have a chance to try it out: Switch to proptest. 2022-01-08 16:34:40 -08:00
aa414fd527 More clippy fixin's. 2021-12-31 10:32:15 -08:00
8ac3f52546 Make clippy happier. 2021-12-31 09:35:17 -08:00
bac2c33aee First attempt at implementing remote TCP port binding. 2021-12-20 20:40:25 -08:00
3737d0739d Get TCP forwarding working with itself and an external client.
As it happens, this is a pretty major change, because I misunderstood
how the protocol actually works. Rather than having a single core
command channel and then a series of offshoots, SOCKSv5 does a separate
handshake for each individual command, and then uses the command stream
as a data stream. So ... whoops. So now the `SOCKSv5Server` sits on a
listener, instead, and farms each of the connections out to a task.
2021-11-21 21:18:55 -08:00
74f66ef747 Add a separate trait for converting errors into server responses. 2021-11-21 21:18:44 -08:00
774591cb54 Whoops! Missed a reserved byte in client requests. 2021-11-21 21:18:39 -08:00
c05b0f2b74 Wire up, and add a test for, connecting via a client through a proxy server. 2021-11-21 21:18:28 -08:00
abff1a4ec1 Remove a duplicated firewall check. 2021-11-21 21:18:21 -08:00
ac11ae64a8 Add a test that our little mini-firewall works as intended. 2021-11-21 21:18:07 -08:00
67b2acab25 Add support for knowing when the write end of a testing stream drops the reference, and then triggering errors on the read side. 2021-11-21 21:17:59 -08:00
58c04adeb7 Merge pull request #1 from acw/feature/github-actions
Integrate some CI into this project, to see how it works on GitHub.
2021-10-29 20:56:22 -07:00
35 changed files with 2120 additions and 2439 deletions

View File

@@ -7,11 +7,26 @@ edition = "2018"
[lib] [lib]
name = "async_socks5" name = "async_socks5"
[[bin]]
name="socks5-server"
path="server/main.rs"
[dependencies] [dependencies]
async-std = { version = "1.9.0", features = ["attributes"] } anyhow = "^1.0.57"
async-trait = "0.1.50" clap = { version = "^3.1.18", features = ["derive"] }
futures = "0.3.15" etcetera = "^0.4.0"
log = "0.4.8" futures = "0.3.21"
quickcheck = "1.0.3" if-addrs = "0.7.0"
simplelog = "0.10.0" lazy_static = "1.4.0"
thiserror = "1.0.24" proptest = "^1.0.0"
serde = "^1.0.137"
serde_derive = "^1.0.137"
thiserror = "^1.0.31"
tokio = { version = "^1", features = ["full"] }
toml = "^0.5.9"
tracing = "^0.1.34"
tracing-subscriber = { version = "^0.3.11", features = ["env-filter"] }
[dev-dependencies]
proptest = "1.0.0"
proptest-derive = "0.3.0"

2
TODO Normal file
View File

@@ -0,0 +1,2 @@
* [ ] Turn `write` from &self to self
* [ ] Turn `read`/`write` into a typeclass of some kind.

15
sample.toml Normal file
View File

@@ -0,0 +1,15 @@
# Unless otherwise sepcified, use this log level
log_level = "TRACE"
# Unless the command line specifies servers (or includes --validate),
# start the following named servers.
start_servers = "loopback*,ethernet"
[loopback4]
interface="lo0"
log_level="TRACE"
address="127.0.0.1"
[loopback6]
interface="lo0"
log_level="DEBUG"
address="::1"

176
server/config.rs Normal file
View File

@@ -0,0 +1,176 @@
mod cmdline;
mod config_file;
use self::cmdline::Arguments;
use self::config_file::ConfigFile;
use clap::Parser;
use if_addrs::IfAddr;
use std::io;
use std::net::{AddrParseError, IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4};
use std::str::FromStr;
use thiserror::Error;
use tracing::metadata::LevelFilter;
#[derive(Debug, Error)]
pub enum ConfigError {
#[error(transparent)]
CommandLineError(#[from] clap::Error),
#[error(transparent)]
IOError(#[from] io::Error),
#[error("TOML processing error: {0}")]
TomlError(#[from] toml::de::Error),
#[error("Server '{0}' specifies an interface ({1}) with no addresses")]
NoAddressForInterface(String, String),
#[error("Server '{0}' specifies an address we couldn't parse: {1}")]
AddressParseError(String, AddrParseError),
#[error("Host directory error {0}")]
HostDirectoryError(String),
}
#[derive(Debug)]
pub struct Config {
pub log_level: LevelFilter,
pub server_definitions: Vec<ServerDefinition>,
}
#[derive(Debug)]
pub struct ServerDefinition {
pub name: String,
pub start: bool,
pub interface: Option<String>,
pub address: SocketAddr,
pub log_level: LevelFilter,
}
impl Config {
/// Generate a configuration by reading the command line arguments and any
/// defined config file, generating the actual arguments that we'll use for
/// operating the daemon.
pub fn derive() -> Result<Self, ConfigError> {
let command_line = Arguments::try_parse()?;
let mut config_file = ConfigFile::read(command_line.config_file)?;
let nic_addresses = if_addrs::get_if_addrs()?;
let log_level = command_line
.log_level
.or(config_file.log_level)
.unwrap_or(LevelFilter::ERROR);
let mut server_definitions = Vec::new();
let servers_to_start: Vec<String> = config_file
.start_servers
.map(|x| x.split(',').map(|v| v.to_string()).collect())
.unwrap_or_default();
for (name, config_info) in config_file.servers.drain() {
let start = servers_to_start.contains(&name);
let log_level = config_info.log_level.unwrap_or(log_level);
let port = config_info.port.unwrap_or(1080);
let mut interface = None;
let mut address = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port));
match (config_info.interface, config_info.address) {
// if the user provides us nothing, we'll just use a blank address and
// no interface association
(None, None) => {}
// if the user provides us an interface but no address, we'll see if we can
// find the interface and pull a reasonable address from it.
(Some(given_interface), None) => {
let mut found_it = false;
for card_interface in nic_addresses.iter() {
if card_interface.name == given_interface {
interface = Some(given_interface.clone());
address = SocketAddr::new(addr_convert(&card_interface.addr), port);
found_it = true;
break;
}
}
if !found_it {
return Err(ConfigError::NoAddressForInterface(name, given_interface));
}
}
// if the user provides us an address but no interface, we'll quickly see if
// we can find that address in our interface list ... but we won't insist on
// it.
(None, Some(address_string)) => {
let read_address = IpAddr::from_str(&address_string)
.map_err(|x| ConfigError::AddressParseError(name.clone(), x))?;
interface = None;
address = SocketAddr::new(read_address, port);
for card_interface in nic_addresses.iter() {
if addrs_match(&card_interface.addr, &read_address) {
interface = Some(card_interface.name.clone());
break;
}
}
}
// if the user provides both, we'll check to make sure that they match.
(Some(given_interface), Some(address_string)) => {
let read_address = IpAddr::from_str(&address_string)
.map_err(|x| ConfigError::AddressParseError(name.clone(), x))?;
let mut inferred_interface = None;
let mut good_to_go = false;
address = SocketAddr::new(read_address, port);
for card_interface in nic_addresses.iter() {
if addrs_match(&card_interface.addr, &read_address) {
if card_interface.name == given_interface {
interface = Some(given_interface.clone());
good_to_go = true;
break;
} else {
inferred_interface = Some(card_interface.name.clone());
}
}
}
if !good_to_go {
if let Some(inferred_interface) = inferred_interface {
tracing::warn!("Address {} is associated with interface {}, not {}; using it instead", read_address, inferred_interface, given_interface);
} else {
tracing::warn!(
"Address {} is not associated with interface {}, or any interface.",
read_address,
given_interface
);
}
}
}
}
server_definitions.push(ServerDefinition {
name,
start,
interface,
address,
log_level,
});
}
Ok(Config {
log_level,
server_definitions,
})
}
}
fn addr_convert(x: &if_addrs::IfAddr) -> IpAddr {
match x {
if_addrs::IfAddr::V4(x) => IpAddr::V4(x.ip),
if_addrs::IfAddr::V6(x) => IpAddr::V6(x.ip),
}
}
fn addrs_match(x: &if_addrs::IfAddr, y: &IpAddr) -> bool {
match (x, y) {
(IfAddr::V4(x), IpAddr::V4(y)) => &x.ip == y,
(IfAddr::V6(x), IpAddr::V6(y)) => &x.ip == y,
_ => false,
}
}

35
server/config/cmdline.rs Normal file
View File

@@ -0,0 +1,35 @@
use clap::Parser;
use std::path::PathBuf;
use tracing::metadata::LevelFilter;
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
pub struct Arguments {
#[clap(
short,
long,
help = "Use the given config file, rather than $XDG_CONFIG_DIR/socks5.toml"
)]
pub config_file: Option<PathBuf>,
#[clap(
short,
long,
help = "Default logging to the given level. (Defaults to ERROR if not given)"
)]
pub log_level: Option<LevelFilter>,
#[clap(
short,
long,
help = "Start only the named server(s) from the config file. For more than one, use comma-separated values or multiple instances of --start"
)]
pub start: Vec<String>,
#[clap(
short,
long = "validate",
help = "Do not actually start any servers; just validate the config file."
)]
pub validate_only: bool,
}

View File

@@ -0,0 +1,64 @@
use super::ConfigError;
use etcetera::base_strategy::{choose_base_strategy, BaseStrategy};
use serde::{Deserialize, Deserializer};
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
use tracing::metadata::LevelFilter;
#[derive(serde_derive::Deserialize, Default)]
pub struct ConfigFile {
#[serde(deserialize_with = "parse_log_level")]
pub log_level: Option<LevelFilter>,
pub start_servers: Option<String>,
#[serde(flatten)]
pub servers: HashMap<String, ServerDefinition>,
}
#[derive(serde_derive::Deserialize)]
pub struct ServerDefinition {
pub interface: Option<String>,
#[serde(deserialize_with = "parse_log_level")]
pub log_level: Option<LevelFilter>,
pub address: Option<String>,
pub port: Option<u16>,
}
fn parse_log_level<'de, D>(deserializer: D) -> Result<Option<LevelFilter>, D::Error>
where
D: Deserializer<'de>,
{
let possible_string: Option<&str> = Deserialize::deserialize(deserializer)?;
if let Some(s) = possible_string {
Ok(Some(s.parse().map_err(|e| {
serde::de::Error::custom(format!("Couldn't parse log level '{}': {}", s, e))
})?))
} else {
Ok(None)
}
}
impl ConfigFile {
pub fn read(mut config_file_path: Option<PathBuf>) -> Result<ConfigFile, ConfigError> {
if config_file_path.is_none() {
let base_dirs = choose_base_strategy()
.map_err(|e| ConfigError::HostDirectoryError(e.to_string()))?;
let mut proposed_path = base_dirs.config_dir();
proposed_path.push("socks5");
if let Ok(attributes) = fs::metadata(proposed_path.clone()) {
if attributes.is_file() {
config_file_path = Some(proposed_path);
}
}
}
match config_file_path {
None => Ok(ConfigFile::default()),
Some(path) => {
let content = fs::read(path)?;
Ok(toml::from_slice(&content)?)
}
}
}
}

74
server/main.rs Normal file
View File

@@ -0,0 +1,74 @@
mod config;
use async_socks5::server::{SOCKSv5Server, SecurityParameters};
use config::Config;
use tracing::Instrument;
use tracing_subscriber::filter::EnvFilter;
use tracing_subscriber::fmt;
use tracing_subscriber::prelude::*;
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
let config = Config::derive()?;
let fmt_layer = fmt::layer().with_target(false);
let filter_layer = EnvFilter::builder()
.with_default_directive(config.log_level.into())
.from_env()?;
tracing_subscriber::registry()
.with(filter_layer)
.with(fmt_layer)
.init();
tracing::trace!("Parsed configuration: {:?}", config);
let core_server = SOCKSv5Server::new(SecurityParameters {
allow_unauthenticated: true,
allow_connection: None,
check_password: None,
connect_tls: None,
});
let mut running_servers = vec![];
for server_def in config.server_definitions {
let span = tracing::trace_span!(
"",
server_name = %server_def.name,
interface = ?server_def.interface,
address = %server_def.address,
);
let result = core_server
.start(server_def.address.ip(), server_def.address.port())
.instrument(span)
.await;
match result {
Ok(x) => running_servers.push(x),
Err(e) => tracing::error!(
server = %server_def.name,
interface = ?server_def.interface,
address = %server_def.address,
"Failure in launching server: {}",
e
),
}
}
while !running_servers.is_empty() {
let (initial_result, _idx, next_runners) =
futures::future::select_all(running_servers).await;
match initial_result {
Ok(Ok(())) => tracing::info!("server completed successfully"),
Ok(Err(e)) => tracing::error!("error in running server: {}", e),
Err(e) => tracing::error!("error joining server: {}", e),
}
running_servers = next_runners;
}
Ok(())
}

174
src/address.rs Normal file
View File

@@ -0,0 +1,174 @@
use crate::messages::string::{SOCKSv5String, SOCKSv5StringReadError, SOCKSv5StringWriteError};
use std::fmt;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum SOCKSv5Address {
IP4(Ipv4Addr),
IP6(Ipv6Addr),
Hostname(String),
}
impl From<IpAddr> for SOCKSv5Address {
fn from(x: IpAddr) -> SOCKSv5Address {
match x {
IpAddr::V4(a) => SOCKSv5Address::IP4(a),
IpAddr::V6(a) => SOCKSv5Address::IP6(a),
}
}
}
impl From<Ipv4Addr> for SOCKSv5Address {
fn from(x: Ipv4Addr) -> SOCKSv5Address {
SOCKSv5Address::IP4(x)
}
}
impl From<Ipv6Addr> for SOCKSv5Address {
fn from(x: Ipv6Addr) -> SOCKSv5Address {
SOCKSv5Address::IP6(x)
}
}
impl From<SOCKSv5String> for SOCKSv5Address {
fn from(x: SOCKSv5String) -> SOCKSv5Address {
SOCKSv5Address::Hostname(x.into())
}
}
impl<'a> From<&'a str> for SOCKSv5Address {
fn from(x: &str) -> SOCKSv5Address {
SOCKSv5Address::Hostname(x.to_string())
}
}
impl From<String> for SOCKSv5Address {
fn from(x: String) -> SOCKSv5Address {
SOCKSv5Address::Hostname(x)
}
}
impl fmt::Display for SOCKSv5Address {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
SOCKSv5Address::IP4(a) => write!(f, "{}", a),
SOCKSv5Address::IP6(a) => write!(f, "{}", a),
SOCKSv5Address::Hostname(a) => write!(f, "{}", a),
}
}
}
#[cfg(test)]
const HOSTNAME_REGEX: &str = "[a-zA-Z0-9_.]+";
#[cfg(test)]
use proptest::prelude::{any, prop_oneof, Arbitrary, BoxedStrategy, Strategy};
#[cfg(test)]
impl Arbitrary for SOCKSv5Address {
type Parameters = Option<u16>;
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
let max_len = args.unwrap_or(32) as usize;
prop_oneof![
any::<Ipv4Addr>().prop_map(SOCKSv5Address::IP4),
any::<Ipv6Addr>().prop_map(SOCKSv5Address::IP6),
HOSTNAME_REGEX.prop_map(move |mut hostname| {
hostname.shrink_to(max_len);
SOCKSv5Address::Hostname(hostname)
}),
]
.boxed()
}
}
#[derive(Clone, Debug, Error, PartialEq)]
pub enum SOCKSv5AddressReadError {
#[error("Bad address type {0} (expected 1, 3, or 4)")]
BadAddressType(u8),
#[error("Read buffer error: {0}")]
ReadError(String),
#[error(transparent)]
SOCKSv5StringError(#[from] SOCKSv5StringReadError),
}
impl From<std::io::Error> for SOCKSv5AddressReadError {
fn from(x: std::io::Error) -> SOCKSv5AddressReadError {
SOCKSv5AddressReadError::ReadError(format!("{}", x))
}
}
#[derive(Clone, Debug, Error, PartialEq)]
pub enum SOCKSv5AddressWriteError {
#[error(transparent)]
SOCKSv5StringError(#[from] SOCKSv5StringWriteError),
#[error("Write buffer error: {0}")]
WriteError(String),
}
impl From<std::io::Error> for SOCKSv5AddressWriteError {
fn from(x: std::io::Error) -> SOCKSv5AddressWriteError {
SOCKSv5AddressWriteError::WriteError(format!("{}", x))
}
}
impl SOCKSv5Address {
pub async fn read<R: AsyncRead + Send + Unpin>(
r: &mut R,
) -> Result<Self, SOCKSv5AddressReadError> {
match r.read_u8().await? {
1 => {
let mut addr_buffer = [0; 4];
r.read_exact(&mut addr_buffer).await?;
let ip4 = Ipv4Addr::from(addr_buffer);
Ok(SOCKSv5Address::IP4(ip4))
}
3 => {
let string = SOCKSv5String::read(r).await?;
Ok(SOCKSv5Address::from(string))
}
4 => {
let mut addr_buffer = [0; 16];
r.read_exact(&mut addr_buffer).await?;
let ip6 = Ipv6Addr::from(addr_buffer);
Ok(SOCKSv5Address::IP6(ip6))
}
x => Err(SOCKSv5AddressReadError::BadAddressType(x)),
}
}
pub async fn write<W: AsyncWrite + Send + Unpin>(
&self,
w: &mut W,
) -> Result<(), SOCKSv5AddressWriteError> {
match self {
SOCKSv5Address::IP4(x) => {
w.write_u8(1).await?;
w.write_all(&x.octets()).await?;
Ok(())
}
SOCKSv5Address::IP6(x) => {
w.write_u8(4).await?;
w.write_all(&x.octets()).await?;
Ok(())
}
SOCKSv5Address::Hostname(x) => {
w.write_u8(3).await?;
let string = SOCKSv5String::from(x.clone());
string.write(w).await?;
Ok(())
}
}
}
}
crate::standard_roundtrip!(socks_address_roundtrips, SOCKSv5Address);

View File

@@ -1,30 +0,0 @@
use async_socks5::network::Builtin;
use async_socks5::server::{SOCKSv5Server, SecurityParameters};
use async_std::io;
use async_std::net::TcpListener;
use simplelog::{ColorChoice, CombinedLogger, Config, LevelFilter, TermLogger, TerminalMode};
#[async_std::main]
async fn main() -> Result<(), io::Error> {
CombinedLogger::init(vec![TermLogger::new(
LevelFilter::Debug,
Config::default(),
TerminalMode::Mixed,
ColorChoice::Auto,
)])
.expect("Couldn't initialize logger");
let main_listener = TcpListener::bind("127.0.0.1:0").await?;
let params = SecurityParameters {
allow_unauthenticated: false,
allow_connection: None,
check_password: None,
connect_tls: None,
};
let server = SOCKSv5Server::new(Builtin::new(), params, main_listener);
server.run().await?;
Ok(())
}

View File

@@ -1,43 +1,65 @@
use crate::errors::{DeserializationError, SerializationError}; use crate::address::SOCKSv5Address;
use crate::messages::{ use crate::messages::{
AuthenticationMethod, ClientGreeting, ClientUsernamePassword, ServerAuthResponse, ServerChoice, AuthenticationMethod, ClientConnectionCommand, ClientConnectionCommandWriteError,
ClientConnectionRequest, ClientGreeting, ClientGreetingWriteError, ClientUsernamePassword,
ClientUsernamePasswordWriteError, ServerAuthResponse, ServerAuthResponseReadError,
ServerChoice, ServerChoiceReadError, ServerResponse, ServerResponseReadError,
ServerResponseStatus, ServerResponseStatus,
}; };
use crate::network::generic::Networklike; use std::future::Future;
use futures::io::{AsyncRead, AsyncWrite};
use log::{trace, warn};
use thiserror::Error; use thiserror::Error;
use tokio::net::TcpStream;
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum SOCKSv5Error { pub enum SOCKSv5ClientError {
#[error("SOCKSv5 serialization error: {0}")] #[error("Underlying networking error: {0}")]
SerializationError(#[from] SerializationError), NetworkingError(String),
#[error("SOCKSv5 deserialization error: {0}")] #[error("Client greeting write error: {0}")]
DeserializationError(#[from] DeserializationError), ClientWriteError(#[from] ClientGreetingWriteError),
#[error("No acceptable authentication methods available")] #[error("Server choice error: {0}")]
NoAuthMethodsAllowed, ServerChoiceError(#[from] ServerChoiceReadError),
#[error("Error writing credentials: {0}")]
CredentialWriteError(#[from] ClientUsernamePasswordWriteError),
#[error("Server auth read error: {0}")]
AuthResponseError(#[from] ServerAuthResponseReadError),
#[error("Authentication failed")] #[error("Authentication failed")]
AuthenticationFailed, AuthenticationFailed,
#[error("Server chose an unsupported authentication method ({0}")] #[error("No authentication methods allowed")]
NoAuthMethodsAllowed,
#[error("Unsupported authentication method chosen ({0})")]
UnsupportedAuthMethodChosen(AuthenticationMethod), UnsupportedAuthMethodChosen(AuthenticationMethod),
#[error("Client connection command write error: {0}")]
ClientCommandWriteError(#[from] ClientConnectionCommandWriteError),
#[error("Server said no: {0}")] #[error("Server said no: {0}")]
ServerFailure(#[from] ServerResponseStatus), ServerRejected(#[from] ServerResponseStatus),
#[error("Server response read failure: {0}")]
ServerResponseError(#[from] ServerResponseReadError),
} }
pub struct SOCKSv5Client<S, N> impl From<std::io::Error> for SOCKSv5ClientError {
where fn from(x: std::io::Error) -> SOCKSv5ClientError {
S: AsyncRead + AsyncWrite, SOCKSv5ClientError::NetworkingError(format!("{}", x))
N: Networklike, }
{
_network: N,
_stream: S,
} }
pub struct LoginInfo { pub struct LoginInfo {
pub username_password: Option<UsernamePassword>, pub username_password: Option<UsernamePassword>,
} }
impl Default for LoginInfo {
fn default() -> Self {
Self::new()
}
}
impl LoginInfo { impl LoginInfo {
/// Generate an empty bit of login information.
pub fn new() -> LoginInfo {
LoginInfo {
username_password: None,
}
}
/// Turn this information into a list of authentication methods that we can handle, /// Turn this information into a list of authentication methods that we can handle,
/// to send to the server. The RFC isn't super clear if the order of these matters /// to send to the server. The RFC isn't super clear if the order of these matters
/// at all, but we'll try to keep it in our preferred order. /// at all, but we'll try to keep it in our preferred order.
@@ -57,61 +79,191 @@ pub struct UsernamePassword {
pub password: String, pub password: String,
} }
impl<S, N> SOCKSv5Client<S, N> pub struct SOCKSv5Client {
where login_info: LoginInfo,
S: AsyncRead + AsyncWrite + Send + Unpin, address: SOCKSv5Address,
N: Networklike, port: u16,
{ }
impl SOCKSv5Client {
/// Create a new SOCKSv5 client connection over the given steam, using the given /// Create a new SOCKSv5 client connection over the given steam, using the given
/// authentication information. /// authentication information. As part of the process of building this object, we
pub async fn new(_network: N, mut stream: S, login: &LoginInfo) -> Result<Self, SOCKSv5Error> { /// do a little test run to make sure that we can login effectively; this should save
let acceptable_methods = login.acceptable_methods(); /// from *some* surprises later on. If you'd rather *not* do that, though, you can
trace!( /// try `unchecked_new`.
pub async fn new<A: Into<SOCKSv5Address>>(
login: LoginInfo,
server_addr: A,
server_port: u16,
) -> Result<Self, SOCKSv5ClientError> {
let base_version = SOCKSv5Client::unchecked_new(login, server_addr, server_port);
let _ = base_version.start_session().await?;
Ok(base_version)
}
/// Create a new SOCKSv5Client within the given parameters, but don't do a quick
/// check to see if this connection has a chance of working. This saves you a TCP
/// connection sequence at the expense of an increased possibility of an error
/// later on down the road.
pub fn unchecked_new<A: Into<SOCKSv5Address>>(
login_info: LoginInfo,
address: A,
port: u16,
) -> Self {
SOCKSv5Client {
login_info,
address: address.into(),
port,
}
}
/// This runs the connection and negotiates login, as required, and then returns
/// the stream the caller should use to do ... whatever it wants to do.
async fn start_session(&self) -> Result<TcpStream, SOCKSv5ClientError> {
// create the initial stream
let mut stream = match &self.address {
SOCKSv5Address::IP4(x) => TcpStream::connect((*x, self.port)).await?,
SOCKSv5Address::IP6(x) => TcpStream::connect((*x, self.port)).await?,
SOCKSv5Address::Hostname(x) => TcpStream::connect((x.as_ref(), self.port)).await?,
};
// compute how we can log in
let acceptable_methods = self.login_info.acceptable_methods();
tracing::trace!(
"Computed acceptable methods -- {:?} -- sending client greeting.", "Computed acceptable methods -- {:?} -- sending client greeting.",
acceptable_methods acceptable_methods
); );
// Negotiate with the server. Well. "Negotiate."
let client_greeting = ClientGreeting { acceptable_methods }; let client_greeting = ClientGreeting { acceptable_methods };
client_greeting.write(&mut stream).await?; client_greeting.write(&mut stream).await?;
trace!("Write client greeting, waiting for server's choice."); tracing::trace!("Write client greeting, waiting for server's choice.");
let server_choice = ServerChoice::read(&mut stream).await?; let server_choice = ServerChoice::read(&mut stream).await?;
trace!("Received server's choice: {}", server_choice.chosen_method); tracing::trace!("Received server's choice: {}", server_choice.chosen_method);
// Let's do it!
match server_choice.chosen_method { match server_choice.chosen_method {
AuthenticationMethod::None => {} AuthenticationMethod::None => {}
AuthenticationMethod::UsernameAndPassword => { AuthenticationMethod::UsernameAndPassword => {
let (username, password) = if let Some(ref linfo) = login.username_password { let (username, password) = if let Some(ref linfo) =
trace!("Server requested username/password, getting data from login info."); self.login_info.username_password
{
tracing::trace!(
"Server requested username/password, getting data from login info."
);
(linfo.username.clone(), linfo.password.clone()) (linfo.username.clone(), linfo.password.clone())
} else { } else {
warn!("Server requested username/password, but we weren't provided one. Very weird."); tracing::warn!("Server requested username/password, but we weren't provided one. Very weird.");
("".to_string(), "".to_string()) ("".to_string(), "".to_string())
}; };
let auth_request = ClientUsernamePassword { username, password }; let auth_request = ClientUsernamePassword { username, password };
trace!("Writing password information."); tracing::trace!("Writing password information.");
auth_request.write(&mut stream).await?; auth_request.write(&mut stream).await?;
let server_response = ServerAuthResponse::read(&mut stream).await?; let server_response = ServerAuthResponse::read(&mut stream).await?;
trace!("Got server response: {}", server_response.success); tracing::trace!("Got server response: {}", server_response.success);
if !server_response.success { if !server_response.success {
return Err(SOCKSv5Error::AuthenticationFailed); return Err(SOCKSv5ClientError::AuthenticationFailed);
} }
} }
AuthenticationMethod::NoAcceptableMethods => { AuthenticationMethod::NoAcceptableMethods => {
return Err(SOCKSv5Error::NoAuthMethodsAllowed) return Err(SOCKSv5ClientError::NoAuthMethodsAllowed)
} }
x => return Err(SOCKSv5Error::UnsupportedAuthMethodChosen(x)), x => return Err(SOCKSv5ClientError::UnsupportedAuthMethodChosen(x)),
} }
trace!("Returning new SOCKSv5Client object!"); Ok(stream)
Ok(SOCKSv5Client { }
_network,
_stream: stream, /// Listen for one connection on the proxy server, and then wire back a socket
}) /// that can talk to whoever connects. This handshake is a little odd, because
/// we don't necessarily know what port or address we should tell the other
/// person to listen on. So this function takes an async function, which it
/// will pass this information to once it has it. It's up to that function,
/// then, to communicate this to its peer.
pub async fn remote_listen<A, Fut: Future<Output = Result<(), SOCKSv5ClientError>>>(
self,
addr: A,
port: u16,
callback: impl FnOnce(SOCKSv5Address, u16) -> Fut,
) -> Result<(SOCKSv5Address, u16, TcpStream), SOCKSv5ClientError>
where
A: Into<SOCKSv5Address>,
{
let mut stream = self.start_session().await?;
let target = addr.into();
let ccr = ClientConnectionRequest {
command_code: ClientConnectionCommand::EstablishTCPPortBinding,
destination_address: target.clone(),
destination_port: port,
};
ccr.write(&mut stream).await?;
let initial_response = ServerResponse::read(&mut stream).await?;
if initial_response.status != ServerResponseStatus::RequestGranted {
return Err(initial_response.status.into());
}
tracing::info!(
"Proxy port binding of {}:{} established; server listening on {}:{}",
target,
port,
initial_response.bound_address,
initial_response.bound_port
);
callback(initial_response.bound_address, initial_response.bound_port).await?;
let secondary_response = ServerResponse::read(&mut stream).await?;
if secondary_response.status != ServerResponseStatus::RequestGranted {
return Err(secondary_response.status.into());
}
tracing::info!(
"Proxy bind got a connection from {}:{}",
secondary_response.bound_address,
secondary_response.bound_port
);
Ok((
secondary_response.bound_address,
secondary_response.bound_port,
stream,
))
}
pub async fn connect<A: Send + Into<SOCKSv5Address>>(
&mut self,
addr: A,
port: u16,
) -> Result<TcpStream, SOCKSv5ClientError> {
let mut stream = self.start_session().await?;
let target = addr.into();
let ccr = ClientConnectionRequest {
command_code: ClientConnectionCommand::EstablishTCPStream,
destination_address: target.clone(),
destination_port: port,
};
ccr.write(&mut stream).await?;
let response = ServerResponse::read(&mut stream).await?;
if response.status == ServerResponseStatus::RequestGranted {
tracing::info!(
"Proxy connection to {}:{} established; server is using {}:{}",
target,
port,
response.bound_address,
response.bound_port
);
Ok(stream)
} else {
Err(response.status.into())
}
} }
} }

View File

@@ -1,188 +0,0 @@
use std::io;
use std::string::FromUtf8Error;
use thiserror::Error;
use crate::network::SOCKSv5Address;
/// All the errors that can pop up when trying to turn raw bytes into SOCKSv5
/// messages.
#[derive(Error, Debug)]
pub enum DeserializationError {
#[error("Invalid protocol version for packet ({1} is not {0}!)")]
InvalidVersion(u8, u8),
#[error("Not enough data found")]
NotEnoughData,
#[error("Ooops! Found an empty string where I shouldn't")]
InvalidEmptyString,
#[error("IO error: {0}")]
IOError(#[from] io::Error),
#[error("SOCKS authentication format parse error: {0}")]
AuthenticationMethodError(#[from] AuthenticationDeserializationError),
#[error("Error converting from UTF-8: {0}")]
UTF8Error(#[from] FromUtf8Error),
#[error("Invalid address type; wanted 1, 3, or 4, got {0}")]
InvalidAddressType(u8),
#[error("Invalid client command {0}; expected 1, 2, or 3")]
InvalidClientCommand(u8),
#[error("Invalid server status {0}; expected 0-8")]
InvalidServerResponse(u8),
}
#[test]
fn des_error_reasonable_equals() {
let invalid_version = DeserializationError::InvalidVersion(1, 2);
assert_eq!(invalid_version, invalid_version);
let not_enough = DeserializationError::NotEnoughData;
assert_eq!(not_enough, not_enough);
let invalid_empty = DeserializationError::InvalidEmptyString;
assert_eq!(invalid_empty, invalid_empty);
let auth_method = DeserializationError::AuthenticationMethodError(
AuthenticationDeserializationError::NoDataFound,
);
assert_eq!(auth_method, auth_method);
let utf8 = DeserializationError::UTF8Error(String::from_utf8(vec![0, 159]).unwrap_err());
assert_eq!(utf8, utf8);
let invalid_address = DeserializationError::InvalidAddressType(3);
assert_eq!(invalid_address, invalid_address);
let invalid_client_cmd = DeserializationError::InvalidClientCommand(32);
assert_eq!(invalid_client_cmd, invalid_client_cmd);
let invalid_server_resp = DeserializationError::InvalidServerResponse(42);
assert_eq!(invalid_server_resp, invalid_server_resp);
assert_ne!(invalid_version, invalid_address);
assert_ne!(not_enough, invalid_empty);
assert_ne!(auth_method, invalid_client_cmd);
assert_ne!(utf8, invalid_server_resp);
}
impl PartialEq for DeserializationError {
fn eq(&self, other: &DeserializationError) -> bool {
match (self, other) {
(
&DeserializationError::InvalidVersion(a, b),
&DeserializationError::InvalidVersion(x, y),
) => (a == x) && (b == y),
(&DeserializationError::NotEnoughData, &DeserializationError::NotEnoughData) => true,
(
&DeserializationError::InvalidEmptyString,
&DeserializationError::InvalidEmptyString,
) => true,
(
&DeserializationError::AuthenticationMethodError(ref a),
&DeserializationError::AuthenticationMethodError(ref b),
) => a == b,
(&DeserializationError::UTF8Error(ref a), &DeserializationError::UTF8Error(ref b)) => {
a == b
}
(
&DeserializationError::InvalidAddressType(a),
&DeserializationError::InvalidAddressType(b),
) => a == b,
(
&DeserializationError::InvalidClientCommand(a),
&DeserializationError::InvalidClientCommand(b),
) => a == b,
(
&DeserializationError::InvalidServerResponse(a),
&DeserializationError::InvalidServerResponse(b),
) => a == b,
(_, _) => false,
}
}
}
/// All the errors that can occur trying to turn SOCKSv5 message structures
/// into raw bytes. There's a few places that the message structures allow
/// for information that can't be serialized; often, you have to be careful
/// about how long your strings are ...
#[derive(Error, Debug)]
pub enum SerializationError {
#[error("Too many authentication methods for serialization ({0} > 255)")]
TooManyAuthMethods(usize),
#[error("Invalid length for string: {0}")]
InvalidStringLength(String),
#[error("IO error: {0}")]
IOError(#[from] io::Error),
}
#[test]
fn ser_err_reasonable_equals() {
let too_many = SerializationError::TooManyAuthMethods(512);
assert_eq!(too_many, too_many);
let invalid_str = SerializationError::InvalidStringLength("Whoopsy!".to_string());
assert_eq!(invalid_str, invalid_str);
assert_ne!(too_many, invalid_str);
}
impl PartialEq for SerializationError {
fn eq(&self, other: &SerializationError) -> bool {
match (self, other) {
(
&SerializationError::TooManyAuthMethods(a),
&SerializationError::TooManyAuthMethods(b),
) => a == b,
(
&SerializationError::InvalidStringLength(ref a),
&SerializationError::InvalidStringLength(ref b),
) => a == b,
(_, _) => false,
}
}
}
#[derive(Error, Debug)]
pub enum AuthenticationDeserializationError {
#[error("No data found deserializing SOCKS authentication type")]
NoDataFound,
#[error("Invalid authentication type value: {0}")]
InvalidAuthenticationByte(u8),
#[error("IO error reading SOCKS authentication type: {0}")]
IOError(#[from] io::Error),
}
#[test]
fn auth_des_err_reasonable_equals() {
let no_data = AuthenticationDeserializationError::NoDataFound;
assert_eq!(no_data, no_data);
let invalid_auth = AuthenticationDeserializationError::InvalidAuthenticationByte(39);
assert_eq!(invalid_auth, invalid_auth);
assert_ne!(no_data, invalid_auth);
}
impl PartialEq for AuthenticationDeserializationError {
fn eq(&self, other: &AuthenticationDeserializationError) -> bool {
match (self, other) {
(
&AuthenticationDeserializationError::NoDataFound,
&AuthenticationDeserializationError::NoDataFound,
) => true,
(
&AuthenticationDeserializationError::InvalidAuthenticationByte(x),
&AuthenticationDeserializationError::InvalidAuthenticationByte(y),
) => x == y,
(_, _) => false,
}
}
}
/// The errors that can happen, as a server, when we're negotiating the start
/// of a SOCKS session.
#[derive(Debug, Error)]
pub enum AuthenticationError {
#[error("Firewall disallowed connection from {0}:{1}")]
FirewallRejected(SOCKSv5Address, u16),
#[error("Could not agree on an authentication method with the client")]
ItsNotUsItsYou,
#[error("Failure in serializing response message: {0}")]
SerializationError(#[from] SerializationError),
#[error("Failed TLS handshake")]
FailedTLSHandshake,
#[error("IO error writing response message: {0}")]
IOError(#[from] io::Error),
#[error("Failure in reading client message: {0}")]
DeserializationError(#[from] DeserializationError),
#[error("Username/password check failed (username was {0})")]
FailedUsernamePassword(String),
}

View File

@@ -1,70 +1,54 @@
pub mod client; pub mod client;
pub mod errors;
pub mod messages;
pub mod network;
mod serialize;
pub mod server; pub mod server;
mod address;
mod messages;
mod security_parameters;
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use crate::address::SOCKSv5Address;
use crate::client::{LoginInfo, SOCKSv5Client, UsernamePassword}; use crate::client::{LoginInfo, SOCKSv5Client, UsernamePassword};
use crate::network::generic::Networklike; use crate::security_parameters::SecurityParameters;
use crate::network::testing::TestingStack; use crate::server::SOCKSv5Server;
use crate::server::{SOCKSv5Server, SecurityParameters}; use std::io;
use async_std::task; use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
#[test] use tokio::net::{TcpSocket, TcpStream};
fn unrestricted_login() { use tokio::sync::oneshot;
task::block_on(async { use tokio::task;
let mut network_stack = TestingStack::default();
#[tokio::test]
async fn unrestricted_login() {
// generate the server // generate the server
let security_parameters = SecurityParameters::unrestricted(); let security_parameters = SecurityParameters::unrestricted();
let default_port = network_stack.listen("localhost", 9999).await.unwrap(); let server = SOCKSv5Server::new(security_parameters);
let server = server.start("localhost", 9999).await.unwrap();
SOCKSv5Server::new(network_stack.clone(), security_parameters, default_port);
let _server_task = task::spawn(async move { server.run().await });
let stream = network_stack.connect("localhost", 9999).await.unwrap();
let login_info = LoginInfo { let login_info = LoginInfo {
username_password: None, username_password: None,
}; };
let client = SOCKSv5Client::new(network_stack, stream, &login_info).await; let client = SOCKSv5Client::new(login_info, "localhost", 9999).await;
assert!(client.is_ok()); assert!(client.is_ok());
})
} }
#[test] #[tokio::test]
fn disallow_unrestricted() { async fn disallow_unrestricted() {
task::block_on(async {
let mut network_stack = TestingStack::default();
// generate the server // generate the server
let mut security_parameters = SecurityParameters::unrestricted(); let mut security_parameters = SecurityParameters::unrestricted();
security_parameters.allow_unauthenticated = false; security_parameters.allow_unauthenticated = false;
let default_port = network_stack.listen("localhost", 9999).await.unwrap(); let server = SOCKSv5Server::new(security_parameters);
let server = server.start("localhost", 9998).await.unwrap();
SOCKSv5Server::new(network_stack.clone(), security_parameters, default_port);
let _server_task = task::spawn(async move { server.run().await }); let login_info = LoginInfo::default();
let client = SOCKSv5Client::new(login_info, "localhost", 9998).await;
let stream = network_stack.connect("localhost", 9999).await.unwrap();
let login_info = LoginInfo {
username_password: None,
};
let client = SOCKSv5Client::new(network_stack, stream, &login_info).await;
assert!(client.is_err()); assert!(client.is_err());
})
} }
#[test] #[tokio::test]
fn password_checks() { async fn password_checks() {
task::block_on(async {
let mut network_stack = TestingStack::default();
// generate the server // generate the server
let security_parameters = SecurityParameters { let security_parameters = SecurityParameters {
allow_unauthenticated: false, allow_unauthenticated: false,
@@ -74,33 +58,115 @@ mod test {
username == "awick" && password == "password" username == "awick" && password == "password"
}), }),
}; };
let default_port = network_stack.listen("localhost", 9999).await.unwrap(); let server = SOCKSv5Server::new(security_parameters);
let server = server.start("localhost", 9997).await.unwrap();
SOCKSv5Server::new(network_stack.clone(), security_parameters, default_port);
let _server_task = task::spawn(async move { server.run().await });
// try the positive side // try the positive side
let stream = network_stack.connect("localhost", 9999).await.unwrap();
let login_info = LoginInfo { let login_info = LoginInfo {
username_password: Some(UsernamePassword { username_password: Some(UsernamePassword {
username: "awick".to_string(), username: "awick".to_string(),
password: "password".to_string(), password: "password".to_string(),
}), }),
}; };
let client = SOCKSv5Client::new(network_stack.clone(), stream, &login_info).await; let client = SOCKSv5Client::new(login_info, "localhost", 9997).await;
assert!(client.is_ok()); assert!(client.is_ok());
// try the negative side // try the negative side
let stream = network_stack.connect("localhost", 9999).await.unwrap();
let login_info = LoginInfo { let login_info = LoginInfo {
username_password: Some(UsernamePassword { username_password: Some(UsernamePassword {
username: "adamw".to_string(), username: "adamw".to_string(),
password: "password".to_string(), password: "password".to_string(),
}), }),
}; };
let client = SOCKSv5Client::new(network_stack, stream, &login_info).await; let client = SOCKSv5Client::new(login_info, "localhost", 9997).await;
assert!(client.is_err()); assert!(client.is_err());
}
#[tokio::test]
async fn firewall_blocks() {
// generate the server
let mut security_parameters = SecurityParameters::unrestricted();
security_parameters.allow_connection = Some(|_| false);
let server = SOCKSv5Server::new(security_parameters);
server.start("localhost", 9996).await.unwrap();
let login_info = LoginInfo::new();
let client = SOCKSv5Client::new(login_info, "localhost", 9996).await;
assert!(client.is_err());
}
#[tokio::test]
async fn establish_stream() -> io::Result<()> {
let target_socket = TcpSocket::new_v4()?;
target_socket.bind(SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
1337,
))?;
let target_port = target_socket.listen(1)?;
// generate the server
let security_parameters = SecurityParameters::unrestricted();
let server = SOCKSv5Server::new(security_parameters);
server.start("localhost", 9995).await.unwrap();
let login_info = LoginInfo {
username_password: None,
};
let mut client = SOCKSv5Client::new(login_info, "localhost", 9995)
.await
.unwrap();
task::spawn(async move {
let mut conn = client.connect("localhost", 1337).await.unwrap();
conn.write_all(&[1, 3, 3, 7, 9]).await.unwrap();
});
let (mut target_connection, _) = target_port.accept().await.unwrap();
let mut read_buffer = [0; 4];
target_connection
.read_exact(&mut read_buffer)
.await
.unwrap();
assert_eq!(read_buffer, [1, 3, 3, 7]);
Ok(())
}
#[tokio::test]
async fn bind_test() -> io::Result<()> {
let security_parameters = SecurityParameters::unrestricted();
let server = SOCKSv5Server::new(security_parameters);
server.start("localhost", 9994).await.unwrap();
let login_info = LoginInfo::default();
let client = SOCKSv5Client::new(login_info, "localhost", 9994)
.await
.unwrap();
let (target_sender, target_receiver) = oneshot::channel();
task::spawn(async move {
let (_, _, mut conn) = client
.remote_listen("localhost", 9993, |addr, port| async move {
target_sender.send((addr, port)).unwrap();
Ok(())
}) })
.await
.unwrap();
conn.write_all(&[2, 3, 5, 7]).await.unwrap();
});
let (target_addr, target_port) = target_receiver.await.unwrap();
let mut stream = match target_addr {
SOCKSv5Address::IP4(x) => TcpStream::connect((x, target_port)).await?,
SOCKSv5Address::IP6(x) => TcpStream::connect((x, target_port)).await?,
SOCKSv5Address::Hostname(x) => TcpStream::connect((x, target_port)).await?,
};
let mut read_buffer = [0; 4];
stream.read_exact(&mut read_buffer).await.unwrap();
assert_eq!(read_buffer, [2, 3, 5, 7]);
Ok(())
} }
} }

View File

@@ -5,12 +5,52 @@ mod client_username_password;
mod server_auth_response; mod server_auth_response;
mod server_choice; mod server_choice;
mod server_response; mod server_response;
pub(crate) mod utils;
pub use crate::messages::authentication_method::AuthenticationMethod; pub(crate) mod string;
pub use crate::messages::client_command::{ClientConnectionCommand, ClientConnectionRequest};
pub use crate::messages::client_greeting::ClientGreeting; pub use crate::messages::authentication_method::{
pub use crate::messages::client_username_password::ClientUsernamePassword; AuthenticationMethod, AuthenticationMethodReadError, AuthenticationMethodWriteError,
pub use crate::messages::server_auth_response::ServerAuthResponse; };
pub use crate::messages::server_choice::ServerChoice; pub use crate::messages::client_command::{
pub use crate::messages::server_response::{ServerResponse, ServerResponseStatus}; ClientConnectionCommand, ClientConnectionCommandReadError, ClientConnectionCommandWriteError,
ClientConnectionRequest, ClientConnectionRequestReadError,
};
pub use crate::messages::client_greeting::{
ClientGreeting, ClientGreetingReadError, ClientGreetingWriteError,
};
pub use crate::messages::client_username_password::{
ClientUsernamePassword, ClientUsernamePasswordReadError, ClientUsernamePasswordWriteError,
};
pub use crate::messages::server_auth_response::{
ServerAuthResponse, ServerAuthResponseReadError, ServerAuthResponseWriteError,
};
pub use crate::messages::server_choice::{
ServerChoice, ServerChoiceReadError, ServerChoiceWriteError,
};
pub use crate::messages::server_response::{
ServerResponse, ServerResponseReadError, ServerResponseStatus, ServerResponseWriteError,
};
#[doc(hidden)]
#[macro_export]
macro_rules! standard_roundtrip {
($name: ident, $t: ty) => {
proptest::proptest! {
#[test]
fn $name(xs: $t) {
tokio::runtime::Runtime::new().unwrap().block_on(async {
use std::io::Cursor;
let originals = xs.clone();
let buffer = vec![];
let mut write_cursor = Cursor::new(buffer);
xs.write(&mut write_cursor).await.unwrap();
let serialized_form = write_cursor.into_inner();
let mut read_cursor = Cursor::new(serialized_form);
let ys = <$t>::read(&mut read_cursor);
assert_eq!(originals, ys.await.unwrap());
})
}
}
};
}

View File

@@ -1,13 +1,12 @@
use crate::errors::{AuthenticationDeserializationError, DeserializationError, SerializationError};
use crate::standard_roundtrip;
#[cfg(test)] #[cfg(test)]
use async_std::task; use proptest::prelude::{prop_oneof, Arbitrary, Just, Strategy};
#[cfg(test)] #[cfg(test)]
use futures::io::Cursor; use proptest::strategy::BoxedStrategy;
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
#[cfg(test)]
use quickcheck::{quickcheck, Arbitrary, Gen};
use std::fmt; use std::fmt;
#[cfg(test)]
use std::io::Cursor;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
#[allow(clippy::upper_case_acronyms)] #[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Debug, Eq, PartialEq)] #[derive(Clone, Debug, Eq, PartialEq)]
@@ -25,6 +24,34 @@ pub enum AuthenticationMethod {
NoAcceptableMethods, NoAcceptableMethods,
} }
#[derive(Clone, Debug, Error, PartialEq)]
pub enum AuthenticationMethodReadError {
#[error("Invalid authentication method #{0}")]
UnknownAuthenticationMethod(u8),
#[error("Error in underlying buffer: {0}")]
ReadError(String),
}
impl From<std::io::Error> for AuthenticationMethodReadError {
fn from(x: std::io::Error) -> AuthenticationMethodReadError {
AuthenticationMethodReadError::ReadError(format!("{}", x))
}
}
#[derive(Clone, Debug, Error, PartialEq)]
pub enum AuthenticationMethodWriteError {
#[error("Trying to write invalid authentication method #{0}")]
InvalidAuthMethod(u8),
#[error("Error in underlying buffer: {0}")]
WriteError(String),
}
impl From<std::io::Error> for AuthenticationMethodWriteError {
fn from(x: std::io::Error) -> AuthenticationMethodWriteError {
AuthenticationMethodWriteError::WriteError(format!("{}", x))
}
}
impl fmt::Display for AuthenticationMethod { impl fmt::Display for AuthenticationMethod {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self { match self {
@@ -45,18 +72,34 @@ impl fmt::Display for AuthenticationMethod {
} }
} }
#[cfg(test)]
impl Arbitrary for AuthenticationMethod {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_args: Self::Parameters) -> BoxedStrategy<Self> {
prop_oneof![
Just(AuthenticationMethod::None),
Just(AuthenticationMethod::GSSAPI),
Just(AuthenticationMethod::UsernameAndPassword),
Just(AuthenticationMethod::ChallengeHandshake),
Just(AuthenticationMethod::ChallengeResponse),
Just(AuthenticationMethod::SSL),
Just(AuthenticationMethod::NDS),
Just(AuthenticationMethod::MultiAuthenticationFramework),
Just(AuthenticationMethod::JSONPropertyBlock),
Just(AuthenticationMethod::NoAcceptableMethods),
(0x80u8..=0xfe).prop_map(AuthenticationMethod::PrivateMethod),
]
.boxed()
}
}
impl AuthenticationMethod { impl AuthenticationMethod {
pub async fn read<R: AsyncRead + Send + Unpin>( pub async fn read<R: AsyncRead + Send + Unpin>(
r: &mut R, r: &mut R,
) -> Result<AuthenticationMethod, DeserializationError> { ) -> Result<AuthenticationMethod, AuthenticationMethodReadError> {
let mut byte_buffer = [0u8; 1]; match r.read_u8().await? {
let amount_read = r.read(&mut byte_buffer).await?;
if amount_read == 0 {
return Err(AuthenticationDeserializationError::NoDataFound.into());
}
match byte_buffer[0] {
0 => Ok(AuthenticationMethod::None), 0 => Ok(AuthenticationMethod::None),
1 => Ok(AuthenticationMethod::GSSAPI), 1 => Ok(AuthenticationMethod::GSSAPI),
2 => Ok(AuthenticationMethod::UsernameAndPassword), 2 => Ok(AuthenticationMethod::UsernameAndPassword),
@@ -68,14 +111,16 @@ impl AuthenticationMethod {
9 => Ok(AuthenticationMethod::JSONPropertyBlock), 9 => Ok(AuthenticationMethod::JSONPropertyBlock),
x if (0x80..=0xfe).contains(&x) => Ok(AuthenticationMethod::PrivateMethod(x)), x if (0x80..=0xfe).contains(&x) => Ok(AuthenticationMethod::PrivateMethod(x)),
0xff => Ok(AuthenticationMethod::NoAcceptableMethods), 0xff => Ok(AuthenticationMethod::NoAcceptableMethods),
e => Err(AuthenticationDeserializationError::InvalidAuthenticationByte(e).into()), e => Err(AuthenticationMethodReadError::UnknownAuthenticationMethod(
e,
)),
} }
} }
pub async fn write<W: AsyncWrite + Send + Unpin>( pub async fn write<W: AsyncWrite + Send + Unpin>(
&self, self,
w: &mut W, w: &mut W,
) -> Result<(), SerializationError> { ) -> Result<(), AuthenticationMethodWriteError> {
let value = match self { let value = match self {
AuthenticationMethod::None => 0, AuthenticationMethod::None => 0,
AuthenticationMethod::GSSAPI => 1, AuthenticationMethod::GSSAPI => 1,
@@ -86,53 +131,32 @@ impl AuthenticationMethod {
AuthenticationMethod::NDS => 7, AuthenticationMethod::NDS => 7,
AuthenticationMethod::MultiAuthenticationFramework => 8, AuthenticationMethod::MultiAuthenticationFramework => 8,
AuthenticationMethod::JSONPropertyBlock => 9, AuthenticationMethod::JSONPropertyBlock => 9,
AuthenticationMethod::PrivateMethod(pm) => *pm, AuthenticationMethod::PrivateMethod(pm) if (0x80..=0xfe).contains(&pm) => pm,
AuthenticationMethod::PrivateMethod(pm) => {
return Err(AuthenticationMethodWriteError::InvalidAuthMethod(pm))
}
AuthenticationMethod::NoAcceptableMethods => 0xff, AuthenticationMethod::NoAcceptableMethods => 0xff,
}; };
Ok(w.write_all(&[value]).await?) Ok(w.write_u8(value).await?)
} }
} }
#[cfg(test)] crate::standard_roundtrip!(auth_byte_roundtrips, AuthenticationMethod);
impl Arbitrary for AuthenticationMethod {
fn arbitrary(g: &mut Gen) -> AuthenticationMethod {
let mut vals = vec![
AuthenticationMethod::None,
AuthenticationMethod::GSSAPI,
AuthenticationMethod::UsernameAndPassword,
AuthenticationMethod::ChallengeHandshake,
AuthenticationMethod::ChallengeResponse,
AuthenticationMethod::SSL,
AuthenticationMethod::NDS,
AuthenticationMethod::MultiAuthenticationFramework,
AuthenticationMethod::JSONPropertyBlock,
AuthenticationMethod::NoAcceptableMethods,
];
for x in 0x80..0xffu8 {
vals.push(AuthenticationMethod::PrivateMethod(x));
}
g.choose(&vals).unwrap().clone()
}
}
standard_roundtrip!(auth_byte_roundtrips, AuthenticationMethod); #[tokio::test]
async fn bad_byte() {
#[test]
fn bad_byte() {
let no_len = vec![42]; let no_len = vec![42];
let mut cursor = Cursor::new(no_len); let mut cursor = Cursor::new(no_len);
let ys = AuthenticationMethod::read(&mut cursor); let ys = AuthenticationMethod::read(&mut cursor).await.unwrap_err();
assert_eq!( assert_eq!(
Err(DeserializationError::AuthenticationMethodError( AuthenticationMethodReadError::UnknownAuthenticationMethod(42),
AuthenticationDeserializationError::InvalidAuthenticationByte(42) ys
)),
task::block_on(ys)
); );
} }
#[test] #[tokio::test]
fn display_isnt_empty() { async fn display_isnt_empty() {
let vals = vec![ let vals = vec![
AuthenticationMethod::None, AuthenticationMethod::None,
AuthenticationMethod::GSSAPI, AuthenticationMethod::GSSAPI,

View File

@@ -1,56 +1,121 @@
use crate::errors::{DeserializationError, SerializationError}; use crate::address::{SOCKSv5Address, SOCKSv5AddressReadError, SOCKSv5AddressWriteError};
use crate::network::SOCKSv5Address;
use crate::serialize::read_amt;
use crate::standard_roundtrip;
#[cfg(test)] #[cfg(test)]
use async_std::io::ErrorKind; use proptest_derive::Arbitrary;
#[cfg(test)] #[cfg(test)]
use async_std::task; use std::io::Cursor;
#[cfg(test)] use thiserror::Error;
use futures::io::Cursor; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use futures::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
#[cfg(test)]
use quickcheck::{quickcheck, Arbitrary, Gen};
#[cfg(test)]
use std::net::Ipv4Addr;
#[derive(Clone, Copy, Debug, Eq, PartialEq)] #[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[cfg_attr(test, derive(Arbitrary))]
pub enum ClientConnectionCommand { pub enum ClientConnectionCommand {
EstablishTCPStream, EstablishTCPStream,
EstablishTCPPortBinding, EstablishTCPPortBinding,
AssociateUDPPort, AssociateUDPPort,
} }
#[derive(Clone, Debug, Error, PartialEq)]
pub enum ClientConnectionCommandReadError {
#[error("Invalid client connection command code: {0}")]
InvalidClientConnectionCommand(u8),
#[error("Underlying buffer read error: {0}")]
ReadError(String),
}
impl From<std::io::Error> for ClientConnectionCommandReadError {
fn from(x: std::io::Error) -> ClientConnectionCommandReadError {
ClientConnectionCommandReadError::ReadError(format!("{}", x))
}
}
#[derive(Clone, Debug, Error, PartialEq)]
pub enum ClientConnectionCommandWriteError {
#[error("Underlying buffer write error: {0}")]
WriteError(String),
#[error(transparent)]
SOCKSAddressWriteError(#[from] SOCKSv5AddressWriteError),
}
impl From<std::io::Error> for ClientConnectionCommandWriteError {
fn from(x: std::io::Error) -> ClientConnectionCommandWriteError {
ClientConnectionCommandWriteError::WriteError(format!("{}", x))
}
}
impl ClientConnectionCommand {
pub async fn read<R: AsyncRead + Send + Unpin>(
r: &mut R,
) -> Result<ClientConnectionCommand, ClientConnectionCommandReadError> {
match r.read_u8().await? {
0x01 => Ok(ClientConnectionCommand::EstablishTCPStream),
0x02 => Ok(ClientConnectionCommand::EstablishTCPPortBinding),
0x03 => Ok(ClientConnectionCommand::AssociateUDPPort),
x => Err(ClientConnectionCommandReadError::InvalidClientConnectionCommand(x)),
}
}
pub async fn write<W: AsyncWrite + Send + Unpin>(
self,
w: &mut W,
) -> Result<(), std::io::Error> {
match self {
ClientConnectionCommand::EstablishTCPStream => w.write_u8(0x01).await,
ClientConnectionCommand::EstablishTCPPortBinding => w.write_u8(0x02).await,
ClientConnectionCommand::AssociateUDPPort => w.write_u8(0x03).await,
}
}
}
crate::standard_roundtrip!(client_command_roundtrips, ClientConnectionCommand);
#[derive(Clone, Debug, Eq, PartialEq)] #[derive(Clone, Debug, Eq, PartialEq)]
#[cfg_attr(test, derive(Arbitrary))]
pub struct ClientConnectionRequest { pub struct ClientConnectionRequest {
pub command_code: ClientConnectionCommand, pub command_code: ClientConnectionCommand,
pub destination_address: SOCKSv5Address, pub destination_address: SOCKSv5Address,
pub destination_port: u16, pub destination_port: u16,
} }
#[derive(Clone, Debug, Error, PartialEq)]
pub enum ClientConnectionRequestReadError {
#[error("Invalid version in client request: {0} (expected 5)")]
InvalidVersion(u8),
#[error("Invalid command for client request: {0}")]
InvalidCommand(#[from] ClientConnectionCommandReadError),
#[error("Invalid reserved byte: {0} (expected 0)")]
InvalidReservedByte(u8),
#[error("Underlying read error: {0}")]
ReadError(String),
#[error(transparent)]
AddressReadError(#[from] SOCKSv5AddressReadError),
}
impl From<std::io::Error> for ClientConnectionRequestReadError {
fn from(x: std::io::Error) -> ClientConnectionRequestReadError {
ClientConnectionRequestReadError::ReadError(format!("{}", x))
}
}
impl ClientConnectionRequest { impl ClientConnectionRequest {
pub async fn read<R: AsyncRead + Send + Unpin>( pub async fn read<R: AsyncRead + Send + Unpin>(
r: &mut R, r: &mut R,
) -> Result<Self, DeserializationError> { ) -> Result<Self, ClientConnectionRequestReadError> {
let mut buffer = [0; 2]; let version = r.read_u8().await?;
if version != 5 {
read_amt(r, 2, &mut buffer).await?; return Err(ClientConnectionRequestReadError::InvalidVersion(version));
if buffer[0] != 5 {
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
} }
let command_code = match buffer[1] { let command_code = ClientConnectionCommand::read(r).await?;
0x01 => ClientConnectionCommand::EstablishTCPStream,
0x02 => ClientConnectionCommand::EstablishTCPPortBinding, let reserved = r.read_u8().await?;
0x03 => ClientConnectionCommand::AssociateUDPPort, if reserved != 0 {
x => return Err(DeserializationError::InvalidClientCommand(x)), return Err(ClientConnectionRequestReadError::InvalidReservedByte(
}; reserved,
));
}
let destination_address = SOCKSv5Address::read(r).await?; let destination_address = SOCKSv5Address::read(r).await?;
let destination_port = r.read_u16().await?;
read_amt(r, 2, &mut buffer).await?;
let destination_port = ((buffer[0] as u16) << 8) + (buffer[1] as u16);
Ok(ClientConnectionRequest { Ok(ClientConnectionRequest {
command_code, command_code,
@@ -60,92 +125,64 @@ impl ClientConnectionRequest {
} }
pub async fn write<W: AsyncWrite + Send + Unpin>( pub async fn write<W: AsyncWrite + Send + Unpin>(
&self, self,
w: &mut W, w: &mut W,
) -> Result<(), SerializationError> { ) -> Result<(), ClientConnectionCommandWriteError> {
let command = match self.command_code { w.write_u8(5).await?;
ClientConnectionCommand::EstablishTCPStream => 1, self.command_code.write(w).await?;
ClientConnectionCommand::EstablishTCPPortBinding => 2, w.write_u8(0).await?;
ClientConnectionCommand::AssociateUDPPort => 3,
};
w.write_all(&[5, command]).await?;
self.destination_address.write(w).await?; self.destination_address.write(w).await?;
w.write_all(&[ w.write_u16(self.destination_port).await?;
(self.destination_port >> 8) as u8, Ok(())
(self.destination_port & 0xffu16) as u8,
])
.await
.map_err(SerializationError::IOError)
} }
} }
#[cfg(test)] crate::standard_roundtrip!(client_request_roundtrips, ClientConnectionRequest);
impl Arbitrary for ClientConnectionCommand {
fn arbitrary(g: &mut Gen) -> ClientConnectionCommand {
let options = [
ClientConnectionCommand::EstablishTCPStream,
ClientConnectionCommand::EstablishTCPPortBinding,
ClientConnectionCommand::AssociateUDPPort,
];
g.choose(&options).unwrap().clone()
}
}
#[cfg(test)] #[tokio::test]
impl Arbitrary for ClientConnectionRequest { async fn check_short_reads() {
fn arbitrary(g: &mut Gen) -> Self {
let command_code = ClientConnectionCommand::arbitrary(g);
let destination_address = SOCKSv5Address::arbitrary(g);
let destination_port = u16::arbitrary(g);
ClientConnectionRequest {
command_code,
destination_address,
destination_port,
}
}
}
standard_roundtrip!(client_request_roundtrips, ClientConnectionRequest);
#[test]
fn check_short_reads() {
let empty = vec![]; let empty = vec![];
let mut cursor = Cursor::new(empty); let mut cursor = Cursor::new(empty);
let ys = ClientConnectionRequest::read(&mut cursor); let ys = ClientConnectionRequest::read(&mut cursor).await;
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); assert!(matches!(
ys,
Err(ClientConnectionRequestReadError::ReadError(_))
));
let no_len = vec![5, 1]; let no_len = vec![5, 1];
let mut cursor = Cursor::new(no_len); let mut cursor = Cursor::new(no_len);
let ys = ClientConnectionRequest::read(&mut cursor); let ys = ClientConnectionRequest::read(&mut cursor).await;
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); assert!(matches!(
ys,
Err(ClientConnectionRequestReadError::ReadError(_))
));
} }
#[test] #[tokio::test]
fn check_bad_version() { async fn check_bad_version() {
let bad_ver = vec![6, 1, 1]; let bad_ver = vec![6, 1, 1];
let mut cursor = Cursor::new(bad_ver); let mut cursor = Cursor::new(bad_ver);
let ys = ClientConnectionRequest::read(&mut cursor); let ys = ClientConnectionRequest::read(&mut cursor).await;
assert_eq!( assert_eq!(Err(ClientConnectionRequestReadError::InvalidVersion(6)), ys);
Err(DeserializationError::InvalidVersion(5, 6)),
task::block_on(ys)
);
} }
#[test] #[tokio::test]
fn check_bad_command() { async fn check_bad_command() {
let bad_cmd = vec![5, 32, 1]; let bad_cmd = vec![5, 32, 1];
let mut cursor = Cursor::new(bad_cmd); let mut cursor = Cursor::new(bad_cmd);
let ys = ClientConnectionRequest::read(&mut cursor); let ys = ClientConnectionRequest::read(&mut cursor).await;
assert_eq!( assert_eq!(
Err(DeserializationError::InvalidClientCommand(32)), Err(ClientConnectionRequestReadError::InvalidCommand(
task::block_on(ys) ClientConnectionCommandReadError::InvalidClientConnectionCommand(32)
)),
ys
); );
} }
#[test] #[tokio::test]
fn short_write_fails_right() { async fn short_write_fails_right() {
use std::net::Ipv4Addr;
let mut buffer = [0u8; 2]; let mut buffer = [0u8; 2];
let cmd = ClientConnectionRequest { let cmd = ClientConnectionRequest {
command_code: ClientConnectionCommand::AssociateUDPPort, command_code: ClientConnectionCommand::AssociateUDPPort,
@@ -153,10 +190,12 @@ fn short_write_fails_right() {
destination_port: 22, destination_port: 22,
}; };
let mut cursor = Cursor::new(&mut buffer as &mut [u8]); let mut cursor = Cursor::new(&mut buffer as &mut [u8]);
let result = task::block_on(cmd.write(&mut cursor)); let result = cmd.write(&mut cursor).await;
match result { match result {
Ok(_) => assert!(false, "Mysteriously able to fit > 2 bytes in 2 bytes."), Ok(_) => panic!("Mysteriously able to fit > 2 bytes in 2 bytes."),
Err(SerializationError::IOError(x)) => assert_eq!(ErrorKind::WriteZero, x.kind()), Err(ClientConnectionCommandWriteError::WriteError(x)) => {
Err(e) => assert!(false, "Got the wrong error writing too much data: {}", e), assert!(x.contains("write zero"));
}
Err(e) => panic!("Got the wrong error writing too much data: {}", e),
} }
} }

View File

@@ -1,15 +1,12 @@
use crate::messages::authentication_method::{
AuthenticationMethod, AuthenticationMethodReadError, AuthenticationMethodWriteError,
};
#[cfg(test)] #[cfg(test)]
use crate::errors::AuthenticationDeserializationError; use proptest_derive::Arbitrary;
use crate::errors::{DeserializationError, SerializationError};
use crate::messages::AuthenticationMethod;
use crate::standard_roundtrip;
#[cfg(test)] #[cfg(test)]
use async_std::task; use std::io::Cursor;
#[cfg(test)] use thiserror::Error;
use futures::io::Cursor; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
#[cfg(test)]
use quickcheck::{quickcheck, Arbitrary, Gen};
/// Client greetings are the first message sent in a SOCKSv5 session. They /// Client greetings are the first message sent in a SOCKSv5 session. They
/// identify that there's a client that wants to talk to a server, and that /// identify that there's a client that wants to talk to a server, and that
@@ -17,30 +14,57 @@ use quickcheck::{quickcheck, Arbitrary, Gen};
/// said server. (It feels weird that the offer/choice goes this way instead /// said server. (It feels weird that the offer/choice goes this way instead
/// of the reverse, but whatever.) /// of the reverse, but whatever.)
#[derive(Clone, Debug, Eq, PartialEq)] #[derive(Clone, Debug, Eq, PartialEq)]
#[cfg_attr(test, derive(Arbitrary))]
pub struct ClientGreeting { pub struct ClientGreeting {
pub acceptable_methods: Vec<AuthenticationMethod>, pub acceptable_methods: Vec<AuthenticationMethod>,
} }
#[derive(Clone, Debug, Error, PartialEq)]
pub enum ClientGreetingReadError {
#[error("Invalid version in client request: {0} (expected 5)")]
InvalidVersion(u8),
#[error(transparent)]
AuthMethodReadError(#[from] AuthenticationMethodReadError),
#[error("Underlying read error: {0}")]
ReadError(String),
}
impl From<std::io::Error> for ClientGreetingReadError {
fn from(x: std::io::Error) -> ClientGreetingReadError {
ClientGreetingReadError::ReadError(format!("{}", x))
}
}
#[derive(Clone, Debug, Error, PartialEq)]
pub enum ClientGreetingWriteError {
#[error("Too many methods provided; need <256, saw {0}")]
TooManyMethods(usize),
#[error(transparent)]
AuthMethodWriteError(#[from] AuthenticationMethodWriteError),
#[error("Underlying write error: {0}")]
WriteError(String),
}
impl From<std::io::Error> for ClientGreetingWriteError {
fn from(x: std::io::Error) -> ClientGreetingWriteError {
ClientGreetingWriteError::WriteError(format!("{}", x))
}
}
impl ClientGreeting { impl ClientGreeting {
pub async fn read<R: AsyncRead + Send + Unpin>( pub async fn read<R: AsyncRead + Send + Unpin>(
r: &mut R, r: &mut R,
) -> Result<ClientGreeting, DeserializationError> { ) -> Result<ClientGreeting, ClientGreetingReadError> {
let mut buffer = [0; 1]; let version = r.read_u8().await?;
if r.read(&mut buffer).await? == 0 { if version != 5 {
return Err(DeserializationError::NotEnoughData); return Err(ClientGreetingReadError::InvalidVersion(version));
} }
if buffer[0] != 5 { let num_methods = r.read_u8().await? as usize;
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
}
if r.read(&mut buffer).await? == 0 { let mut acceptable_methods = Vec::with_capacity(num_methods);
return Err(DeserializationError::NotEnoughData); for _ in 0..num_methods {
}
let mut acceptable_methods = Vec::with_capacity(buffer[0] as usize);
for _ in 0..buffer[0] {
acceptable_methods.push(AuthenticationMethod::read(r).await?); acceptable_methods.push(AuthenticationMethod::read(r).await?);
} }
@@ -48,11 +72,11 @@ impl ClientGreeting {
} }
pub async fn write<W: AsyncWrite + Send + Unpin>( pub async fn write<W: AsyncWrite + Send + Unpin>(
&self, mut self,
w: &mut W, w: &mut W,
) -> Result<(), SerializationError> { ) -> Result<(), ClientGreetingWriteError> {
if self.acceptable_methods.len() > 255 { if self.acceptable_methods.len() > 255 {
return Err(SerializationError::TooManyAuthMethods( return Err(ClientGreetingWriteError::TooManyMethods(
self.acceptable_methods.len(), self.acceptable_methods.len(),
)); ));
} }
@@ -61,65 +85,48 @@ impl ClientGreeting {
buffer.push(5); buffer.push(5);
buffer.push(self.acceptable_methods.len() as u8); buffer.push(self.acceptable_methods.len() as u8);
w.write_all(&buffer).await?; w.write_all(&buffer).await?;
for authmeth in self.acceptable_methods.iter() { for authmeth in self.acceptable_methods.drain(..) {
authmeth.write(w).await?; authmeth.write(w).await?;
} }
Ok(()) Ok(())
} }
} }
#[cfg(test)] crate::standard_roundtrip!(client_greeting_roundtrips, ClientGreeting);
impl Arbitrary for ClientGreeting {
fn arbitrary(g: &mut Gen) -> ClientGreeting {
let amt = u8::arbitrary(g);
let mut acceptable_methods = Vec::with_capacity(amt as usize);
for _ in 0..amt { #[tokio::test]
acceptable_methods.push(AuthenticationMethod::arbitrary(g)); async fn check_short_reads() {
}
ClientGreeting { acceptable_methods }
}
}
standard_roundtrip!(client_greeting_roundtrips, ClientGreeting);
#[test]
fn check_short_reads() {
let empty = vec![]; let empty = vec![];
let mut cursor = Cursor::new(empty); let mut cursor = Cursor::new(empty);
let ys = ClientGreeting::read(&mut cursor); let ys = ClientGreeting::read(&mut cursor).await;
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); assert!(matches!(ys, Err(ClientGreetingReadError::ReadError(_))));
let no_len = vec![5]; let no_len = vec![5];
let mut cursor = Cursor::new(no_len); let mut cursor = Cursor::new(no_len);
let ys = ClientGreeting::read(&mut cursor); let ys = ClientGreeting::read(&mut cursor).await;
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); assert!(matches!(ys, Err(ClientGreetingReadError::ReadError(_))));
let bad_len = vec![5, 9]; let bad_len = vec![5, 9];
let mut cursor = Cursor::new(bad_len); let mut cursor = Cursor::new(bad_len);
let ys = ClientGreeting::read(&mut cursor); let ys = ClientGreeting::read(&mut cursor).await;
assert_eq!( assert!(matches!(
Err(DeserializationError::AuthenticationMethodError( ys,
AuthenticationDeserializationError::NoDataFound Err(ClientGreetingReadError::AuthMethodReadError(
)), AuthenticationMethodReadError::ReadError(_)
task::block_on(ys) ))
); ));
} }
#[test] #[tokio::test]
fn check_bad_version() { async fn check_bad_version() {
let no_len = vec![6, 1, 1]; let no_len = vec![6, 1, 1];
let mut cursor = Cursor::new(no_len); let mut cursor = Cursor::new(no_len);
let ys = ClientGreeting::read(&mut cursor); let ys = ClientGreeting::read(&mut cursor).await;
assert_eq!( assert_eq!(Err(ClientGreetingReadError::InvalidVersion(6)), ys);
Err(DeserializationError::InvalidVersion(5, 6)),
task::block_on(ys)
);
} }
#[test] #[tokio::test]
fn check_too_many() { async fn check_too_many() {
let mut auth_methods = Vec::with_capacity(512); let mut auth_methods = Vec::with_capacity(512);
auth_methods.resize(512, AuthenticationMethod::ChallengeHandshake); auth_methods.resize(512, AuthenticationMethod::ChallengeHandshake);
let greet = ClientGreeting { let greet = ClientGreeting {
@@ -127,7 +134,7 @@ fn check_too_many() {
}; };
let mut output = vec![0; 1024]; let mut output = vec![0; 1024];
assert_eq!( assert_eq!(
Err(SerializationError::TooManyAuthMethods(512)), Err(ClientGreetingWriteError::TooManyMethods(512)),
task::block_on(greet.write(&mut output)) greet.write(&mut output).await
); );
} }

View File

@@ -1,15 +1,10 @@
use crate::errors::{DeserializationError, SerializationError}; use crate::messages::string::{SOCKSv5String, SOCKSv5StringReadError, SOCKSv5StringWriteError};
#[cfg(test)] #[cfg(test)]
use crate::messages::utils::arbitrary_socks_string; use proptest::prelude::{Arbitrary, BoxedStrategy, Strategy};
use crate::serialize::{read_string, write_string};
use crate::standard_roundtrip;
#[cfg(test)] #[cfg(test)]
use async_std::task; use std::io::Cursor;
#[cfg(test)] use thiserror::Error;
use futures::io::Cursor; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
#[cfg(test)]
use quickcheck::{quickcheck, Arbitrary, Gen};
#[derive(Clone, Debug, Eq, PartialEq)] #[derive(Clone, Debug, Eq, PartialEq)]
pub struct ClientUsernamePassword { pub struct ClientUsernamePassword {
@@ -17,68 +12,111 @@ pub struct ClientUsernamePassword {
pub password: String, pub password: String,
} }
#[cfg(test)]
const USERNAME_REGEX: &str = "[a-zA-Z0-9~!@#$%^&*_\\-+=:;?<>]+";
#[cfg(test)]
const PASSWORD_REGEX: &str = "[a-zA-Z0-9~!@#$%^&*_\\-+=:;?<>]+";
#[cfg(test)]
impl Arbitrary for ClientUsernamePassword {
type Parameters = Option<u8>;
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
let max_len = args.unwrap_or(12) as usize;
(USERNAME_REGEX, PASSWORD_REGEX)
.prop_map(move |(mut username, mut password)| {
username.shrink_to(max_len);
password.shrink_to(max_len);
ClientUsernamePassword { username, password }
})
.boxed()
}
}
#[derive(Clone, Debug, Error, PartialEq)]
pub enum ClientUsernamePasswordReadError {
#[error("Underlying buffer read error: {0}")]
ReadError(String),
#[error("Invalid username/password version; expected 1, saw {0}")]
InvalidVersion(u8),
#[error(transparent)]
StringError(#[from] SOCKSv5StringReadError),
}
impl From<std::io::Error> for ClientUsernamePasswordReadError {
fn from(x: std::io::Error) -> ClientUsernamePasswordReadError {
ClientUsernamePasswordReadError::ReadError(format!("{}", x))
}
}
#[derive(Clone, Debug, Error, PartialEq)]
pub enum ClientUsernamePasswordWriteError {
#[error("Underlying buffer read error: {0}")]
WriteError(String),
#[error(transparent)]
StringError(#[from] SOCKSv5StringWriteError),
}
impl From<std::io::Error> for ClientUsernamePasswordWriteError {
fn from(x: std::io::Error) -> ClientUsernamePasswordWriteError {
ClientUsernamePasswordWriteError::WriteError(format!("{}", x))
}
}
impl ClientUsernamePassword { impl ClientUsernamePassword {
pub async fn read<R: AsyncRead + Send + Unpin>( pub async fn read<R: AsyncRead + Send + Unpin>(
r: &mut R, r: &mut R,
) -> Result<Self, DeserializationError> { ) -> Result<Self, ClientUsernamePasswordReadError> {
let mut buffer = [0; 1]; let version = r.read_u8().await?;
if r.read(&mut buffer).await? == 0 { if version != 1 {
return Err(DeserializationError::NotEnoughData); return Err(ClientUsernamePasswordReadError::InvalidVersion(version));
} }
if buffer[0] != 1 { let username = SOCKSv5String::read(r).await?.into();
return Err(DeserializationError::InvalidVersion(1, buffer[0])); let password = SOCKSv5String::read(r).await?.into();
}
let username = read_string(r).await?;
let password = read_string(r).await?;
Ok(ClientUsernamePassword { username, password }) Ok(ClientUsernamePassword { username, password })
} }
pub async fn write<W: AsyncWrite + Send + Unpin>( pub async fn write<W: AsyncWrite + Send + Unpin>(
&self, self,
w: &mut W, w: &mut W,
) -> Result<(), SerializationError> { ) -> Result<(), ClientUsernamePasswordWriteError> {
w.write_all(&[1]).await?; w.write_u8(1).await?;
write_string(&self.username, w).await?; SOCKSv5String::from(self.username.as_str()).write(w).await?;
write_string(&self.password, w).await SOCKSv5String::from(self.password.as_str()).write(w).await?;
Ok(())
} }
} }
#[cfg(test)] crate::standard_roundtrip!(username_password_roundtrips, ClientUsernamePassword);
impl Arbitrary for ClientUsernamePassword {
fn arbitrary(g: &mut Gen) -> Self {
let username = arbitrary_socks_string(g);
let password = arbitrary_socks_string(g);
ClientUsernamePassword { username, password } #[tokio::test]
} async fn heck_short_reads() {
}
standard_roundtrip!(username_password_roundtrips, ClientUsernamePassword);
#[test]
fn check_short_reads() {
let empty = vec![]; let empty = vec![];
let mut cursor = Cursor::new(empty); let mut cursor = Cursor::new(empty);
let ys = ClientUsernamePassword::read(&mut cursor); let ys = ClientUsernamePassword::read(&mut cursor).await;
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); assert!(matches!(
ys,
Err(ClientUsernamePasswordReadError::ReadError(_))
));
let user_only = vec![1, 3, 102, 111, 111]; let user_only = vec![1, 3, 102, 111, 111];
let mut cursor = Cursor::new(user_only); let mut cursor = Cursor::new(user_only);
let ys = ClientUsernamePassword::read(&mut cursor); let ys = ClientUsernamePassword::read(&mut cursor).await;
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); println!("ys: {:?}", ys);
assert!(matches!(
ys,
Err(ClientUsernamePasswordReadError::StringError(_))
));
} }
#[test] #[tokio::test]
fn check_bad_version() { async fn check_bad_version() {
let bad_len = vec![5]; let bad_len = vec![5];
let mut cursor = Cursor::new(bad_len); let mut cursor = Cursor::new(bad_len);
let ys = ClientUsernamePassword::read(&mut cursor); let ys = ClientUsernamePassword::read(&mut cursor).await;
assert_eq!( assert_eq!(Err(ClientUsernamePasswordReadError::InvalidVersion(5)), ys);
Err(DeserializationError::InvalidVersion(1, 5)),
task::block_on(ys)
);
} }

View File

@@ -1,18 +1,40 @@
use crate::errors::{DeserializationError, SerializationError};
use crate::standard_roundtrip;
#[cfg(test)] #[cfg(test)]
use async_std::task; use proptest_derive::Arbitrary;
#[cfg(test)] use thiserror::Error;
use futures::io::Cursor; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
#[cfg(test)]
use quickcheck::{quickcheck, Arbitrary, Gen};
#[derive(Clone, Debug, Eq, PartialEq)] #[derive(Clone, Debug, Eq, PartialEq)]
#[cfg_attr(test, derive(Arbitrary))]
pub struct ServerAuthResponse { pub struct ServerAuthResponse {
pub success: bool, pub success: bool,
} }
#[derive(Clone, Debug, Error, PartialEq)]
pub enum ServerAuthResponseReadError {
#[error("Underlying buffer read error: {0}")]
ReadError(String),
#[error("Invalid username/password version; expected 1, saw {0}")]
InvalidVersion(u8),
}
impl From<std::io::Error> for ServerAuthResponseReadError {
fn from(x: std::io::Error) -> ServerAuthResponseReadError {
ServerAuthResponseReadError::ReadError(format!("{}", x))
}
}
#[derive(Clone, Debug, Error, PartialEq)]
pub enum ServerAuthResponseWriteError {
#[error("Underlying buffer read error: {0}")]
WriteError(String),
}
impl From<std::io::Error> for ServerAuthResponseWriteError {
fn from(x: std::io::Error) -> ServerAuthResponseWriteError {
ServerAuthResponseWriteError::WriteError(format!("{}", x))
}
}
impl ServerAuthResponse { impl ServerAuthResponse {
pub fn success() -> ServerAuthResponse { pub fn success() -> ServerAuthResponse {
ServerAuthResponse { success: true } ServerAuthResponse { success: true }
@@ -24,30 +46,22 @@ impl ServerAuthResponse {
pub async fn read<R: AsyncRead + Send + Unpin>( pub async fn read<R: AsyncRead + Send + Unpin>(
r: &mut R, r: &mut R,
) -> Result<Self, DeserializationError> { ) -> Result<Self, ServerAuthResponseReadError> {
let mut buffer = [0; 1]; let version = r.read_u8().await?;
if r.read(&mut buffer).await? == 0 { if version != 1 {
return Err(DeserializationError::NotEnoughData); return Err(ServerAuthResponseReadError::InvalidVersion(version));
}
if buffer[0] != 1 {
return Err(DeserializationError::InvalidVersion(1, buffer[0]));
}
if r.read(&mut buffer).await? == 0 {
return Err(DeserializationError::NotEnoughData);
} }
Ok(ServerAuthResponse { Ok(ServerAuthResponse {
success: buffer[0] == 0, success: r.read_u8().await? == 0,
}) })
} }
pub async fn write<W: AsyncWrite + Send + Unpin>( pub async fn write<W: AsyncWrite + Send + Unpin>(
&self, self,
w: &mut W, w: &mut W,
) -> Result<(), SerializationError> { ) -> Result<(), ServerAuthResponseWriteError> {
w.write_all(&[1]).await?; w.write_all(&[1]).await?;
w.write_all(&[if self.success { 0x00 } else { 0xde }]) w.write_all(&[if self.success { 0x00 } else { 0xde }])
.await?; .await?;
@@ -55,36 +69,29 @@ impl ServerAuthResponse {
} }
} }
#[cfg(test)] crate::standard_roundtrip!(server_auth_response, ServerAuthResponse);
impl Arbitrary for ServerAuthResponse {
fn arbitrary(g: &mut Gen) -> ServerAuthResponse {
let success = bool::arbitrary(g);
ServerAuthResponse { success }
}
}
standard_roundtrip!(server_auth_response, ServerAuthResponse); #[tokio::test]
async fn check_short_reads() {
use std::io::Cursor;
#[test]
fn check_short_reads() {
let empty = vec![]; let empty = vec![];
let mut cursor = Cursor::new(empty); let mut cursor = Cursor::new(empty);
let ys = ServerAuthResponse::read(&mut cursor); let ys = ServerAuthResponse::read(&mut cursor).await;
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); assert!(matches!(ys, Err(ServerAuthResponseReadError::ReadError(_))));
let no_len = vec![1]; let no_len = vec![1];
let mut cursor = Cursor::new(no_len); let mut cursor = Cursor::new(no_len);
let ys = ServerAuthResponse::read(&mut cursor); let ys = ServerAuthResponse::read(&mut cursor).await;
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); assert!(matches!(ys, Err(ServerAuthResponseReadError::ReadError(_))));
} }
#[test] #[tokio::test]
fn check_bad_version() { async fn check_bad_version() {
use std::io::Cursor;
let no_len = vec![6, 1]; let no_len = vec![6, 1];
let mut cursor = Cursor::new(no_len); let mut cursor = Cursor::new(no_len);
let ys = ServerAuthResponse::read(&mut cursor); let ys = ServerAuthResponse::read(&mut cursor).await;
assert_eq!( assert_eq!(Err(ServerAuthResponseReadError::InvalidVersion(6)), ys);
Err(DeserializationError::InvalidVersion(1, 6)),
task::block_on(ys)
);
} }

View File

@@ -1,21 +1,49 @@
use crate::messages::authentication_method::{
AuthenticationMethod, AuthenticationMethodReadError, AuthenticationMethodWriteError,
};
#[cfg(test)] #[cfg(test)]
use crate::errors::AuthenticationDeserializationError; use proptest_derive::Arbitrary;
use crate::errors::{DeserializationError, SerializationError};
use crate::messages::AuthenticationMethod;
use crate::standard_roundtrip;
#[cfg(test)] #[cfg(test)]
use async_std::task; use std::io::Cursor;
#[cfg(test)] use thiserror::Error;
use futures::io::Cursor; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
#[cfg(test)]
use quickcheck::{quickcheck, Arbitrary, Gen};
#[derive(Clone, Debug, Eq, PartialEq)] #[derive(Clone, Debug, Eq, PartialEq)]
#[cfg_attr(test, derive(Arbitrary))]
pub struct ServerChoice { pub struct ServerChoice {
pub chosen_method: AuthenticationMethod, pub chosen_method: AuthenticationMethod,
} }
#[derive(Clone, Debug, Error, PartialEq)]
pub enum ServerChoiceReadError {
#[error(transparent)]
AuthMethodError(#[from] AuthenticationMethodReadError),
#[error("Error in underlying buffer: {0}")]
ReadError(String),
#[error("Invalid version; expected 5, got {0}")]
InvalidVersion(u8),
}
impl From<std::io::Error> for ServerChoiceReadError {
fn from(x: std::io::Error) -> ServerChoiceReadError {
ServerChoiceReadError::ReadError(format!("{}", x))
}
}
#[derive(Clone, Debug, Error, PartialEq)]
pub enum ServerChoiceWriteError {
#[error(transparent)]
AuthMethodError(#[from] AuthenticationMethodWriteError),
#[error("Error in underlying buffer: {0}")]
WriteError(String),
}
impl From<std::io::Error> for ServerChoiceWriteError {
fn from(x: std::io::Error) -> ServerChoiceWriteError {
ServerChoiceWriteError::WriteError(format!("{}", x))
}
}
impl ServerChoice { impl ServerChoice {
pub fn rejection() -> ServerChoice { pub fn rejection() -> ServerChoice {
ServerChoice { ServerChoice {
@@ -31,15 +59,11 @@ impl ServerChoice {
pub async fn read<R: AsyncRead + Send + Unpin>( pub async fn read<R: AsyncRead + Send + Unpin>(
r: &mut R, r: &mut R,
) -> Result<Self, DeserializationError> { ) -> Result<Self, ServerChoiceReadError> {
let mut buffer = [0; 1]; let version = r.read_u8().await?;
if r.read(&mut buffer).await? == 0 { if version != 5 {
return Err(DeserializationError::NotEnoughData); return Err(ServerChoiceReadError::InvalidVersion(version));
}
if buffer[0] != 5 {
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
} }
let chosen_method = AuthenticationMethod::read(r).await?; let chosen_method = AuthenticationMethod::read(r).await?;
@@ -48,50 +72,34 @@ impl ServerChoice {
} }
pub async fn write<W: AsyncWrite + Send + Unpin>( pub async fn write<W: AsyncWrite + Send + Unpin>(
&self, self,
w: &mut W, w: &mut W,
) -> Result<(), SerializationError> { ) -> Result<(), ServerChoiceWriteError> {
w.write_all(&[5]).await?; w.write_u8(5).await?;
self.chosen_method.write(w).await self.chosen_method.write(w).await?;
Ok(())
} }
} }
#[cfg(test)] crate::standard_roundtrip!(server_choice_roundtrips, ServerChoice);
impl Arbitrary for ServerChoice {
fn arbitrary(g: &mut Gen) -> ServerChoice {
ServerChoice {
chosen_method: AuthenticationMethod::arbitrary(g),
}
}
}
standard_roundtrip!(server_choice_roundtrips, ServerChoice); #[tokio::test]
async fn check_short_reads() {
#[test]
fn check_short_reads() {
let empty = vec![]; let empty = vec![];
let mut cursor = Cursor::new(empty); let mut cursor = Cursor::new(empty);
let ys = ServerChoice::read(&mut cursor); let ys = ServerChoice::read(&mut cursor).await;
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); assert!(matches!(ys, Err(ServerChoiceReadError::ReadError(_))));
let bad_len = vec![5]; let bad_len = vec![5];
let mut cursor = Cursor::new(bad_len); let mut cursor = Cursor::new(bad_len);
let ys = ServerChoice::read(&mut cursor); let ys = ServerChoice::read(&mut cursor).await;
assert_eq!( assert!(matches!(ys, Err(ServerChoiceReadError::AuthMethodError(_))));
Err(DeserializationError::AuthenticationMethodError(
AuthenticationDeserializationError::NoDataFound
)),
task::block_on(ys)
);
} }
#[test] #[tokio::test]
fn check_bad_version() { async fn check_bad_version() {
let no_len = vec![9, 1]; let no_len = vec![9, 1];
let mut cursor = Cursor::new(no_len); let mut cursor = Cursor::new(no_len);
let ys = ServerChoice::read(&mut cursor); let ys = ServerChoice::read(&mut cursor).await;
assert_eq!( assert_eq!(Err(ServerChoiceReadError::InvalidVersion(9)), ys);
Err(DeserializationError::InvalidVersion(5, 9)),
task::block_on(ys)
);
} }

View File

@@ -1,21 +1,13 @@
use crate::errors::{DeserializationError, SerializationError}; use crate::address::{SOCKSv5Address, SOCKSv5AddressReadError, SOCKSv5AddressWriteError};
use crate::network::SOCKSv5Address;
use crate::serialize::read_amt;
use crate::standard_roundtrip;
#[cfg(test)] #[cfg(test)]
use async_std::io::ErrorKind; use proptest_derive::Arbitrary;
#[cfg(test)] #[cfg(test)]
use async_std::task; use std::io::Cursor;
#[cfg(test)]
use futures::io::Cursor;
use futures::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use log::warn;
#[cfg(test)]
use quickcheck::{quickcheck, Arbitrary, Gen};
use std::net::Ipv4Addr;
use thiserror::Error; use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
#[derive(Clone, Debug, Eq, Error, PartialEq)] #[derive(Clone, Debug, Eq, Error, PartialEq)]
#[cfg_attr(test, derive(Arbitrary))]
pub enum ServerResponseStatus { pub enum ServerResponseStatus {
#[error("Actually, everything's fine (weird to see this in an error)")] #[error("Actually, everything's fine (weird to see this in an error)")]
RequestGranted, RequestGranted,
@@ -38,39 +30,64 @@ pub enum ServerResponseStatus {
} }
#[derive(Clone, Debug, Eq, PartialEq)] #[derive(Clone, Debug, Eq, PartialEq)]
#[cfg_attr(test, derive(Arbitrary))]
pub struct ServerResponse { pub struct ServerResponse {
pub status: ServerResponseStatus, pub status: ServerResponseStatus,
pub bound_address: SOCKSv5Address, pub bound_address: SOCKSv5Address,
pub bound_port: u16, pub bound_port: u16,
} }
impl ServerResponse { #[derive(Clone, Debug, Error, PartialEq)]
pub fn error<E: Into<ServerResponseStatus>>(resp: E) -> ServerResponse { pub enum ServerResponseReadError {
ServerResponse { #[error("Error reading from underlying buffer: {0}")]
status: resp.into(), ReadError(String),
bound_address: SOCKSv5Address::IP4(Ipv4Addr::new(0, 0, 0, 0)), #[error(transparent)]
bound_port: 0, AddressReadError(#[from] SOCKSv5AddressReadError),
#[error("Invalid version; expected 5, got {0}")]
InvalidVersion(u8),
#[error("Invalid reserved byte; saw {0}, should be 0")]
InvalidReservedByte(u8),
#[error("Invalid (or just unknown) server response value {0}")]
InvalidServerResponse(u8),
} }
impl From<std::io::Error> for ServerResponseReadError {
fn from(x: std::io::Error) -> ServerResponseReadError {
ServerResponseReadError::ReadError(format!("{}", x))
}
}
#[derive(Clone, Debug, Error, PartialEq)]
pub enum ServerResponseWriteError {
#[error("Error reading from underlying buffer: {0}")]
WriteError(String),
#[error(transparent)]
AddressWriteError(#[from] SOCKSv5AddressWriteError),
}
impl From<std::io::Error> for ServerResponseWriteError {
fn from(x: std::io::Error) -> ServerResponseWriteError {
ServerResponseWriteError::WriteError(format!("{}", x))
} }
} }
impl ServerResponse { impl ServerResponse {
pub async fn read<R: AsyncRead + Send + Unpin>( pub async fn read<R: AsyncRead + Send + Unpin>(
r: &mut R, r: &mut R,
) -> Result<Self, DeserializationError> { ) -> Result<Self, ServerResponseReadError> {
let mut buffer = [0; 3]; let version = r.read_u8().await?;
if version != 5 {
read_amt(r, 3, &mut buffer).await?; return Err(ServerResponseReadError::InvalidVersion(version));
if buffer[0] != 5 {
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
} }
if buffer[2] != 0 { let status_byte = r.read_u8().await?;
warn!(target: "async-socks5", "Hey, this isn't terrible, but the server is sending invalid reserved bytes.");
let reserved_byte = r.read_u8().await?;
if reserved_byte != 0 {
return Err(ServerResponseReadError::InvalidReservedByte(reserved_byte));
} }
let status = match buffer[1] { let status = match status_byte {
0x00 => ServerResponseStatus::RequestGranted, 0x00 => ServerResponseStatus::RequestGranted,
0x01 => ServerResponseStatus::GeneralFailure, 0x01 => ServerResponseStatus::GeneralFailure,
0x02 => ServerResponseStatus::ConnectionNotAllowedByRule, 0x02 => ServerResponseStatus::ConnectionNotAllowedByRule,
@@ -80,12 +97,11 @@ impl ServerResponse {
0x06 => ServerResponseStatus::TTLExpired, 0x06 => ServerResponseStatus::TTLExpired,
0x07 => ServerResponseStatus::CommandNotSupported, 0x07 => ServerResponseStatus::CommandNotSupported,
0x08 => ServerResponseStatus::AddressTypeNotSupported, 0x08 => ServerResponseStatus::AddressTypeNotSupported,
x => return Err(DeserializationError::InvalidServerResponse(x)), x => return Err(ServerResponseReadError::InvalidServerResponse(x)),
}; };
let bound_address = SOCKSv5Address::read(r).await?; let bound_address = SOCKSv5Address::read(r).await?;
read_amt(r, 2, &mut buffer).await?; let bound_port = r.read_u16().await?;
let bound_port = ((buffer[0] as u16) << 8) + (buffer[1] as u16);
Ok(ServerResponse { Ok(ServerResponse {
status, status,
@@ -95,9 +111,11 @@ impl ServerResponse {
} }
pub async fn write<W: AsyncWrite + Send + Unpin>( pub async fn write<W: AsyncWrite + Send + Unpin>(
&self, self,
w: &mut W, w: &mut W,
) -> Result<(), SerializationError> { ) -> Result<(), ServerResponseWriteError> {
w.write_u8(5).await?;
let status_code = match self.status { let status_code = match self.status {
ServerResponseStatus::RequestGranted => 0x00, ServerResponseStatus::RequestGranted => 0x00,
ServerResponseStatus::GeneralFailure => 0x01, ServerResponseStatus::GeneralFailure => 0x01,
@@ -109,92 +127,61 @@ impl ServerResponse {
ServerResponseStatus::CommandNotSupported => 0x07, ServerResponseStatus::CommandNotSupported => 0x07,
ServerResponseStatus::AddressTypeNotSupported => 0x08, ServerResponseStatus::AddressTypeNotSupported => 0x08,
}; };
w.write_u8(status_code).await?;
w.write_all(&[5, status_code, 0]).await?; w.write_u8(0).await?;
self.bound_address.write(w).await?; self.bound_address.write(w).await?;
w.write_all(&[ w.write_u16(self.bound_port).await?;
(self.bound_port >> 8) as u8,
(self.bound_port & 0xffu16) as u8, Ok(())
])
.await
.map_err(SerializationError::IOError)
} }
} }
#[cfg(test)] crate::standard_roundtrip!(server_response_roundtrips, ServerResponse);
impl Arbitrary for ServerResponseStatus {
fn arbitrary(g: &mut Gen) -> ServerResponseStatus {
let options = [
ServerResponseStatus::RequestGranted,
ServerResponseStatus::GeneralFailure,
ServerResponseStatus::ConnectionNotAllowedByRule,
ServerResponseStatus::NetworkUnreachable,
ServerResponseStatus::HostUnreachable,
ServerResponseStatus::ConnectionRefused,
ServerResponseStatus::TTLExpired,
ServerResponseStatus::CommandNotSupported,
ServerResponseStatus::AddressTypeNotSupported,
];
g.choose(&options).unwrap().clone()
}
}
#[cfg(test)] #[tokio::test]
impl Arbitrary for ServerResponse { async fn check_short_reads() {
fn arbitrary(g: &mut Gen) -> Self {
let status = ServerResponseStatus::arbitrary(g);
let bound_address = SOCKSv5Address::arbitrary(g);
let bound_port = u16::arbitrary(g);
ServerResponse {
status,
bound_address,
bound_port,
}
}
}
standard_roundtrip!(server_response_roundtrips, ServerResponse);
#[test]
fn check_short_reads() {
let empty = vec![]; let empty = vec![];
let mut cursor = Cursor::new(empty); let mut cursor = Cursor::new(empty);
let ys = ServerResponse::read(&mut cursor); let ys = ServerResponse::read(&mut cursor).await;
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys)); assert!(matches!(ys, Err(ServerResponseReadError::ReadError(_))));
} }
#[test] #[tokio::test]
fn check_bad_version() { async fn check_bad_version() {
let bad_ver = vec![6, 1, 1]; let bad_ver = vec![6, 1, 1];
let mut cursor = Cursor::new(bad_ver); let mut cursor = Cursor::new(bad_ver);
let ys = ServerResponse::read(&mut cursor); let ys = ServerResponse::read(&mut cursor).await;
assert_eq!( assert_eq!(Err(ServerResponseReadError::InvalidVersion(6)), ys);
Err(DeserializationError::InvalidVersion(5, 6)),
task::block_on(ys)
);
} }
#[test] #[tokio::test]
fn check_bad_command() { async fn check_bad_reserved() {
let bad_cmd = vec![5, 32, 0x42]; let bad_cmd = vec![5, 32, 0x42];
let mut cursor = Cursor::new(bad_cmd); let mut cursor = Cursor::new(bad_cmd);
let ys = ServerResponse::read(&mut cursor); let ys = ServerResponse::read(&mut cursor).await;
assert_eq!( assert_eq!(Err(ServerResponseReadError::InvalidReservedByte(0x42)), ys);
Err(DeserializationError::InvalidServerResponse(32)),
task::block_on(ys)
);
} }
#[test] #[tokio::test]
fn short_write_fails_right() { async fn check_bad_command() {
let bad_cmd = vec![5, 32, 0];
let mut cursor = Cursor::new(bad_cmd);
let ys = ServerResponse::read(&mut cursor).await;
assert_eq!(Err(ServerResponseReadError::InvalidServerResponse(32)), ys);
}
#[tokio::test]
async fn short_write_fails_right() {
let mut buffer = [0u8; 2]; let mut buffer = [0u8; 2];
let cmd = ServerResponse::error(ServerResponseStatus::AddressTypeNotSupported); let cmd = ServerResponse {
status: ServerResponseStatus::AddressTypeNotSupported,
bound_address: SOCKSv5Address::Hostname("tester.com".to_string()),
bound_port: 99,
};
let mut cursor = Cursor::new(&mut buffer as &mut [u8]); let mut cursor = Cursor::new(&mut buffer as &mut [u8]);
let result = task::block_on(cmd.write(&mut cursor)); let result = cmd.write(&mut cursor).await;
match result { assert!(matches!(
Ok(_) => assert!(false, "Mysteriously able to fit > 2 bytes in 2 bytes."), result,
Err(SerializationError::IOError(x)) => assert_eq!(ErrorKind::WriteZero, x.kind()), Err(ServerResponseWriteError::WriteError(_))
Err(e) => assert!(false, "Got the wrong error writing too much data: {}", e), ));
}
} }

117
src/messages/string.rs Normal file
View File

@@ -0,0 +1,117 @@
#[cfg(test)]
use proptest::prelude::{Arbitrary, BoxedStrategy, Strategy};
use std::convert::TryFrom;
use std::string::FromUtf8Error;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
#[derive(Clone, Debug, PartialEq)]
pub struct SOCKSv5String(String);
#[cfg(test)]
const STRING_REGEX: &str = "[a-zA-Z0-9_.|!@#$%^]+";
#[cfg(test)]
impl Arbitrary for SOCKSv5String {
type Parameters = Option<u16>;
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
let max_len = args.unwrap_or(32) as usize;
STRING_REGEX
.prop_map(move |mut str| {
str.shrink_to(max_len);
SOCKSv5String(str)
})
.boxed()
}
}
#[derive(Clone, Debug, Error, PartialEq)]
pub enum SOCKSv5StringReadError {
#[error("Underlying buffer read error: {0}")]
ReadError(String),
#[error("SOCKSv5 string encoding error; encountered empty string (?)")]
ZeroStringLength,
#[error("Invalid UTF-8 string: {0}")]
InvalidUtf8Error(#[from] FromUtf8Error),
}
impl From<std::io::Error> for SOCKSv5StringReadError {
fn from(x: std::io::Error) -> SOCKSv5StringReadError {
SOCKSv5StringReadError::ReadError(format!("{}", x))
}
}
#[derive(Clone, Debug, Error, PartialEq)]
pub enum SOCKSv5StringWriteError {
#[error("Underlying buffer write error: {0}")]
WriteError(String),
#[error("String too large to encode according to SOCKSv5 reuls ({0} bytes long)")]
TooBig(usize),
#[error("Cannot serialize the empty string in SOCKSv5")]
ZeroStringLength,
}
impl From<std::io::Error> for SOCKSv5StringWriteError {
fn from(x: std::io::Error) -> SOCKSv5StringWriteError {
SOCKSv5StringWriteError::WriteError(format!("{}", x))
}
}
impl SOCKSv5String {
pub async fn read<R: AsyncRead + Unpin>(r: &mut R) -> Result<Self, SOCKSv5StringReadError> {
let length = r.read_u8().await? as usize;
if length == 0 {
return Err(SOCKSv5StringReadError::ZeroStringLength);
}
let mut bytestring = vec![0; length];
r.read_exact(&mut bytestring).await?;
Ok(SOCKSv5String(String::from_utf8(bytestring)?))
}
pub async fn write<W: AsyncWrite + Unpin>(
self,
w: &mut W,
) -> Result<(), SOCKSv5StringWriteError> {
let bytestring = self.0.as_bytes();
if bytestring.is_empty() {
return Err(SOCKSv5StringWriteError::ZeroStringLength);
}
let length = match u8::try_from(bytestring.len()) {
Err(_) => return Err(SOCKSv5StringWriteError::TooBig(bytestring.len())),
Ok(x) => x,
};
w.write_u8(length).await?;
w.write_all(bytestring).await?;
Ok(())
}
}
impl From<String> for SOCKSv5String {
fn from(x: String) -> Self {
SOCKSv5String(x)
}
}
impl<'a> From<&'a str> for SOCKSv5String {
fn from(x: &str) -> Self {
SOCKSv5String(x.to_string())
}
}
impl From<SOCKSv5String> for String {
fn from(x: SOCKSv5String) -> Self {
x.0
}
}
crate::standard_roundtrip!(socks_string_roundtrips, SOCKSv5String);

View File

@@ -1,33 +0,0 @@
#[cfg(test)]
use quickcheck::{Arbitrary, Gen};
#[cfg(test)]
pub fn arbitrary_socks_string(g: &mut Gen) -> String {
loop {
let mut potential = String::arbitrary(g);
potential.truncate(255);
let bytestring = potential.as_bytes();
if bytestring.len() > 0 && bytestring.len() < 256 {
return potential;
}
}
}
#[doc(hidden)]
#[macro_export]
macro_rules! standard_roundtrip {
($name: ident, $t: ty) => {
#[cfg(test)]
quickcheck! {
fn $name(xs: $t) -> bool {
let mut buffer = vec![];
task::block_on(xs.write(&mut buffer)).unwrap();
let mut cursor = Cursor::new(buffer);
let ys = <$t>::read(&mut cursor);
xs == task::block_on(ys).unwrap()
}
}
};
}

View File

@@ -1,10 +0,0 @@
pub mod address;
pub mod datagram;
pub mod generic;
pub mod listener;
pub mod standard;
pub mod stream;
pub mod testing;
pub use crate::network::address::SOCKSv5Address;
pub use crate::network::standard::Builtin;

View File

@@ -1,258 +0,0 @@
use crate::errors::{DeserializationError, SerializationError};
#[cfg(test)]
use crate::messages::utils::arbitrary_socks_string;
use crate::serialize::{read_amt, read_string, write_string};
use crate::standard_roundtrip;
#[cfg(test)]
use async_std::task;
#[cfg(test)]
use futures::io::Cursor;
use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
#[cfg(test)]
use quickcheck::{quickcheck, Arbitrary, Gen};
use std::convert::TryFrom;
use std::fmt;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use thiserror::Error;
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum SOCKSv5Address {
IP4(Ipv4Addr),
IP6(Ipv6Addr),
Name(String),
}
#[derive(Error, Debug, PartialEq)]
pub enum AddressConversionError {
#[error("Couldn't convert IPv4 address into destination type")]
CouldntConvertIP4,
#[error("Couldn't convert IPv6 address into destination type")]
CouldntConvertIP6,
#[error("Couldn't convert name into destination type")]
CouldntConvertName,
}
impl From<IpAddr> for SOCKSv5Address {
fn from(x: IpAddr) -> SOCKSv5Address {
match x {
IpAddr::V4(a) => SOCKSv5Address::IP4(a),
IpAddr::V6(a) => SOCKSv5Address::IP6(a),
}
}
}
impl TryFrom<SOCKSv5Address> for IpAddr {
type Error = AddressConversionError;
fn try_from(value: SOCKSv5Address) -> Result<Self, Self::Error> {
match value {
SOCKSv5Address::IP4(a) => Ok(IpAddr::V4(a)),
SOCKSv5Address::IP6(a) => Ok(IpAddr::V6(a)),
SOCKSv5Address::Name(_) => Err(AddressConversionError::CouldntConvertName),
}
}
}
impl From<Ipv4Addr> for SOCKSv5Address {
fn from(x: Ipv4Addr) -> Self {
SOCKSv5Address::IP4(x)
}
}
impl TryFrom<SOCKSv5Address> for Ipv4Addr {
type Error = AddressConversionError;
fn try_from(value: SOCKSv5Address) -> Result<Self, Self::Error> {
match value {
SOCKSv5Address::IP4(a) => Ok(a),
SOCKSv5Address::IP6(_) => Err(AddressConversionError::CouldntConvertIP6),
SOCKSv5Address::Name(_) => Err(AddressConversionError::CouldntConvertName),
}
}
}
impl From<Ipv6Addr> for SOCKSv5Address {
fn from(x: Ipv6Addr) -> Self {
SOCKSv5Address::IP6(x)
}
}
impl TryFrom<SOCKSv5Address> for Ipv6Addr {
type Error = AddressConversionError;
fn try_from(value: SOCKSv5Address) -> Result<Self, Self::Error> {
match value {
SOCKSv5Address::IP4(_) => Err(AddressConversionError::CouldntConvertIP4),
SOCKSv5Address::IP6(a) => Ok(a),
SOCKSv5Address::Name(_) => Err(AddressConversionError::CouldntConvertName),
}
}
}
impl From<String> for SOCKSv5Address {
fn from(x: String) -> Self {
SOCKSv5Address::Name(x)
}
}
impl<'a> From<&'a str> for SOCKSv5Address {
fn from(x: &str) -> SOCKSv5Address {
SOCKSv5Address::Name(x.to_string())
}
}
impl fmt::Display for SOCKSv5Address {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
SOCKSv5Address::IP4(a) => write!(f, "{}", a),
SOCKSv5Address::IP6(a) => write!(f, "{}", a),
SOCKSv5Address::Name(a) => write!(f, "{}", a),
}
}
}
impl SOCKSv5Address {
pub async fn read<R: AsyncRead + Send + Unpin>(
r: &mut R,
) -> Result<Self, DeserializationError> {
let mut byte_buffer = [0u8; 1];
let amount_read = r.read(&mut byte_buffer).await?;
if amount_read == 0 {
return Err(DeserializationError::NotEnoughData);
}
match byte_buffer[0] {
1 => {
let mut addr_buffer = [0; 4];
read_amt(r, 4, &mut addr_buffer).await?;
Ok(SOCKSv5Address::IP4(Ipv4Addr::from(addr_buffer)))
}
3 => {
let mut addr_buffer = [0; 16];
read_amt(r, 16, &mut addr_buffer).await?;
Ok(SOCKSv5Address::IP6(Ipv6Addr::from(addr_buffer)))
}
4 => {
let name = read_string(r).await?;
Ok(SOCKSv5Address::Name(name))
}
x => Err(DeserializationError::InvalidAddressType(x)),
}
}
pub async fn write<W: AsyncWrite + Send + Unpin>(
&self,
w: &mut W,
) -> Result<(), SerializationError> {
match self {
SOCKSv5Address::IP4(x) => {
w.write_all(&[1]).await?;
w.write_all(&x.octets())
.await
.map_err(SerializationError::IOError)
}
SOCKSv5Address::IP6(x) => {
w.write_all(&[3]).await?;
w.write_all(&x.octets())
.await
.map_err(SerializationError::IOError)
}
SOCKSv5Address::Name(x) => {
w.write_all(&[4]).await?;
write_string(x, w).await
}
}
}
}
pub trait HasLocalAddress {
fn local_addr(&self) -> (SOCKSv5Address, u16);
}
#[cfg(test)]
impl Arbitrary for SOCKSv5Address {
fn arbitrary(g: &mut Gen) -> Self {
let ip4 = Ipv4Addr::arbitrary(g);
let ip6 = Ipv6Addr::arbitrary(g);
let nm = arbitrary_socks_string(g);
g.choose(&[
SOCKSv5Address::IP4(ip4),
SOCKSv5Address::IP6(ip6),
SOCKSv5Address::Name(nm),
])
.unwrap()
.clone()
}
}
standard_roundtrip!(address_roundtrips, SOCKSv5Address);
#[cfg(test)]
quickcheck! {
fn ip_conversion(x: IpAddr) -> bool {
match x {
IpAddr::V4(ref a) =>
assert_eq!(Err(AddressConversionError::CouldntConvertIP4),
Ipv6Addr::try_from(SOCKSv5Address::from(a.clone()))),
IpAddr::V6(ref a) =>
assert_eq!(Err(AddressConversionError::CouldntConvertIP6),
Ipv4Addr::try_from(SOCKSv5Address::from(a.clone()))),
}
x == IpAddr::try_from(SOCKSv5Address::from(x.clone())).unwrap()
}
fn ip4_conversion(x: Ipv4Addr) -> bool {
x == Ipv4Addr::try_from(SOCKSv5Address::from(x.clone())).unwrap()
}
fn ip6_conversion(x: Ipv6Addr) -> bool {
x == Ipv6Addr::try_from(SOCKSv5Address::from(x.clone())).unwrap()
}
fn display_matches(x: SOCKSv5Address) -> bool {
match x {
SOCKSv5Address::IP4(a) => format!("{}", a) == format!("{}", x),
SOCKSv5Address::IP6(a) => format!("{}", a) == format!("{}", x),
SOCKSv5Address::Name(ref a) => format!("{}", a) == format!("{}", x),
}
}
fn bad_read_key(x: u8) -> bool {
match x {
1 => true,
3 => true,
4 => true,
_ => {
let buffer = [x, 0, 1, 2, 9, 10];
let mut cursor = Cursor::new(buffer);
let meh = SOCKSv5Address::read(&mut cursor);
Err(DeserializationError::InvalidAddressType(x)) == task::block_on(meh)
}
}
}
}
#[test]
fn domain_name_sanity() {
let name = "uhsure.com";
let strname = name.to_string();
let addr1 = SOCKSv5Address::from(name);
let addr2 = SOCKSv5Address::from(strname);
assert_eq!(addr1, addr2);
assert_eq!(
Err(AddressConversionError::CouldntConvertName),
IpAddr::try_from(addr1.clone())
);
assert_eq!(
Err(AddressConversionError::CouldntConvertName),
Ipv4Addr::try_from(addr1.clone())
);
assert_eq!(
Err(AddressConversionError::CouldntConvertName),
Ipv6Addr::try_from(addr1.clone())
);
}

View File

@@ -1,43 +0,0 @@
use crate::network::address::{HasLocalAddress, SOCKSv5Address};
use async_trait::async_trait;
#[async_trait]
pub trait Datagramlike: Send + Sync + HasLocalAddress {
type Error;
async fn send_to(
&self,
buf: &[u8],
addr: SOCKSv5Address,
port: u16,
) -> Result<usize, Self::Error>;
async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SOCKSv5Address, u16), Self::Error>;
}
pub struct GenericDatagramSocket<E> {
pub internal: Box<dyn Datagramlike<Error = E>>,
}
#[async_trait]
impl<E> Datagramlike for GenericDatagramSocket<E> {
type Error = E;
async fn send_to(
&self,
buf: &[u8],
addr: SOCKSv5Address,
port: u16,
) -> Result<usize, Self::Error> {
Ok(self.internal.send_to(buf, addr, port).await?)
}
async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SOCKSv5Address, u16), Self::Error> {
Ok(self.internal.recv_from(buf).await?)
}
}
impl<E> HasLocalAddress for GenericDatagramSocket<E> {
fn local_addr(&self) -> (SOCKSv5Address, u16) {
self.internal.local_addr()
}
}

View File

@@ -1,48 +0,0 @@
use crate::messages::ServerResponseStatus;
use crate::network::address::SOCKSv5Address;
use crate::network::datagram::GenericDatagramSocket;
use crate::network::listener::GenericListener;
use crate::network::stream::GenericStream;
use async_trait::async_trait;
use std::fmt::Display;
#[async_trait]
pub trait Networklike {
/// The error type for things that fail on this network. Apologies in advance
/// for using only one; if you have a use case for separating your errors,
/// please shoot the author(s) and email to split this into multiple types, one
/// for each trait function.
type Error: Display + Into<ServerResponseStatus> + Send;
/// Connect to the given address and port, over this kind of network. The
/// underlying stream should behave somewhat like a TCP stream ... which
/// may be exactly what you're using. However, in order to support tunnelling
/// scenarios (i.e., using another proxy, going through Tor or SSH, etc.) we
/// work generically over any stream-like object.
async fn connect<A: Send + Into<SOCKSv5Address>>(
&mut self,
addr: A,
port: u16,
) -> Result<GenericStream, Self::Error>;
/// Listen for connections on the given address and port, returning a generic
/// listener socket to use in the future.
async fn listen<A: Send + Into<SOCKSv5Address>>(
&mut self,
addr: A,
port: u16,
) -> Result<GenericListener<Self::Error>, Self::Error>;
/// Bind a socket for the purposes of doing some datagram communication. NOTE!
/// this is only for UDP-like communication, not for generic connecting or
/// listening! Maybe obvious from the types, but POSIX has overtrained many
/// of us.
///
/// Recall when using these functions that datagram protocols allow for packet
/// loss and out-of-order delivery. So ... be warned.
async fn bind<A: Send + Into<SOCKSv5Address>>(
&mut self,
addr: A,
port: u16,
) -> Result<GenericDatagramSocket<Self::Error>, Self::Error>;
}

View File

@@ -1,29 +0,0 @@
use crate::network::address::{HasLocalAddress, SOCKSv5Address};
use crate::network::stream::GenericStream;
use async_trait::async_trait;
#[async_trait]
pub trait Listenerlike: Send + Sync + HasLocalAddress {
type Error;
async fn accept(&self) -> Result<(GenericStream, SOCKSv5Address, u16), Self::Error>;
}
pub struct GenericListener<E> {
pub internal: Box<dyn Listenerlike<Error = E>>,
}
#[async_trait]
impl<E> Listenerlike for GenericListener<E> {
type Error = E;
async fn accept(&self) -> Result<(GenericStream, SOCKSv5Address, u16), Self::Error> {
Ok(self.internal.accept().await?)
}
}
impl<E> HasLocalAddress for GenericListener<E> {
fn local_addr(&self) -> (SOCKSv5Address, u16) {
self.internal.local_addr()
}
}

View File

@@ -1,237 +0,0 @@
use crate::messages::ServerResponseStatus;
use crate::network::address::{HasLocalAddress, SOCKSv5Address};
use crate::network::datagram::{Datagramlike, GenericDatagramSocket};
use crate::network::generic::Networklike;
use crate::network::listener::{GenericListener, Listenerlike};
use crate::network::stream::{GenericStream, Streamlike};
use async_std::io;
#[cfg(test)]
use async_std::io::ReadExt;
use async_std::net::{TcpListener, TcpStream, UdpSocket};
use async_trait::async_trait;
#[cfg(test)]
use futures::AsyncWriteExt;
use log::error;
pub struct Builtin {}
impl Builtin {
pub fn new() -> Builtin {
Builtin {}
}
}
impl Default for Builtin {
fn default() -> Self {
Self::new()
}
}
macro_rules! local_address_impl {
($t: ty) => {
impl HasLocalAddress for $t {
fn local_addr(&self) -> (SOCKSv5Address, u16) {
match self.local_addr() {
Ok(a) =>
(SOCKSv5Address::from(a.ip()), a.port()),
Err(e) => {
error!("Couldn't translate (Streamlike) local address to SOCKS local address: {}", e);
(SOCKSv5Address::from("localhost"), 0)
}
}
}
}
};
}
local_address_impl!(TcpStream);
local_address_impl!(TcpListener);
local_address_impl!(UdpSocket);
impl Streamlike for TcpStream {}
#[async_trait]
impl Listenerlike for TcpListener {
type Error = io::Error;
async fn accept(&self) -> Result<(GenericStream, SOCKSv5Address, u16), Self::Error> {
let (base, addrport) = self.accept().await?;
let addr = addrport.ip();
let port = addrport.port();
Ok((GenericStream::new(base), SOCKSv5Address::from(addr), port))
}
}
#[async_trait]
impl Datagramlike for UdpSocket {
type Error = io::Error;
async fn send_to(
&self,
buf: &[u8],
addr: SOCKSv5Address,
port: u16,
) -> Result<usize, Self::Error> {
match addr {
SOCKSv5Address::IP4(a) => self.send_to(buf, (a, port)).await,
SOCKSv5Address::IP6(a) => self.send_to(buf, (a, port)).await,
SOCKSv5Address::Name(n) => self.send_to(buf, (n.as_str(), port)).await,
}
}
async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SOCKSv5Address, u16), Self::Error> {
let (amt, addrport) = self.recv_from(buf).await?;
let addr = addrport.ip();
let port = addrport.port();
Ok((amt, SOCKSv5Address::from(addr), port))
}
}
#[async_trait]
impl Networklike for Builtin {
type Error = io::Error;
async fn connect<A: Send + Into<SOCKSv5Address>>(
&mut self,
addr: A,
port: u16,
) -> Result<GenericStream, Self::Error> {
let target = addr.into();
let base_stream = match target {
SOCKSv5Address::IP4(a) => TcpStream::connect((a, port)).await?,
SOCKSv5Address::IP6(a) => TcpStream::connect((a, port)).await?,
SOCKSv5Address::Name(n) => TcpStream::connect((n.as_str(), port)).await?,
};
Ok(GenericStream::from(base_stream))
}
async fn listen<A: Send + Into<SOCKSv5Address>>(
&mut self,
addr: A,
port: u16,
) -> Result<GenericListener<Self::Error>, Self::Error> {
let target = addr.into();
let base_stream = match target {
SOCKSv5Address::IP4(a) => TcpListener::bind((a, port)).await?,
SOCKSv5Address::IP6(a) => TcpListener::bind((a, port)).await?,
SOCKSv5Address::Name(n) => TcpListener::bind((n.as_str(), port)).await?,
};
Ok(GenericListener {
internal: Box::new(base_stream),
})
}
async fn bind<A: Send + Into<SOCKSv5Address>>(
&mut self,
addr: A,
port: u16,
) -> Result<GenericDatagramSocket<Self::Error>, Self::Error> {
let target = addr.into();
let base_socket = match target {
SOCKSv5Address::IP4(a) => UdpSocket::bind((a, port)).await?,
SOCKSv5Address::IP6(a) => UdpSocket::bind((a, port)).await?,
SOCKSv5Address::Name(n) => UdpSocket::bind((n.as_str(), port)).await?,
};
Ok(GenericDatagramSocket {
internal: Box::new(base_socket),
})
}
}
#[test]
fn check_sanity() {
async_std::task::block_on(async {
// Technically, this is UDP, and UDP is lossy. We're going to assume we're not
// going to get any dropped data along here ... which is a very questionable
// assumption, morally speaking, but probably fine for most purposes.
let mut network = Builtin::new();
let receiver = network
.bind("localhost", 0)
.await
.expect("Failed to bind receiver socket.");
let sender = network
.bind("localhost", 0)
.await
.expect("Failed to bind sender socket.");
let buffer = [0xde, 0xea, 0xbe, 0xef];
let (receiver_addr, receiver_port) = receiver.local_addr();
sender
.send_to(&buffer, receiver_addr, receiver_port)
.await
.expect("Failure sending datagram!");
let mut recvbuffer = [0; 4];
let (s, f, p) = receiver
.recv_from(&mut recvbuffer)
.await
.expect("Didn't receive UDP message?");
let (sender_addr, sender_port) = sender.local_addr();
assert_eq!(s, 4);
assert_eq!(f, sender_addr);
assert_eq!(p, sender_port);
assert_eq!(recvbuffer, buffer);
});
// This whole block should be pretty solid, though, unless the system we're
// on is in a pretty weird place.
let mut network = Builtin::new();
let listener = async_std::task::block_on(network.listen("localhost", 0))
.expect("Couldn't set up listener on localhost");
let (listener_address, listener_port) = listener.local_addr();
let listener_task_handle = async_std::task::spawn(async move {
let (mut stream, addr, port) = listener.accept().await.expect("Didn't get connection");
let mut result_buffer = [0u8; 4];
stream
.read_exact(&mut result_buffer)
.await
.expect("Read failure in TCP test");
(result_buffer, addr, port)
});
let sender_task_handle = async_std::task::spawn(async move {
let mut sender = network
.connect(listener_address, listener_port)
.await
.expect("Coudln't connect to listener?");
let (sender_address, sender_port) = sender.local_addr();
let send_buffer = [0xa, 0xff, 0xab, 0x1e];
sender
.write_all(&send_buffer)
.await
.expect("Couldn't send the write buffer");
sender
.flush()
.await
.expect("Couldn't flush the write buffer");
sender
.close()
.await
.expect("Couldn't close the write buffer");
(sender_address, sender_port)
});
async_std::task::block_on(async {
let (result, result_from, result_from_port) = listener_task_handle.await;
assert_eq!(result, [0xa, 0xff, 0xab, 0x1e]);
let (sender_address, sender_port) = sender_task_handle.await;
assert_eq!(result_from, sender_address);
assert_eq!(result_from_port, sender_port);
});
}
impl From<io::Error> for ServerResponseStatus {
fn from(e: io::Error) -> ServerResponseStatus {
match e.kind() {
io::ErrorKind::ConnectionRefused => ServerResponseStatus::ConnectionRefused,
io::ErrorKind::NotFound => ServerResponseStatus::HostUnreachable,
_ => ServerResponseStatus::GeneralFailure,
}
}
}

View File

@@ -1,74 +0,0 @@
use crate::network::SOCKSv5Address;
use async_std::task::{Context, Poll};
use futures::io;
use futures::io::{AsyncRead, AsyncWrite};
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use super::address::HasLocalAddress;
pub trait Streamlike: AsyncRead + AsyncWrite + HasLocalAddress + Send + Sync + Unpin {}
#[derive(Clone)]
pub struct GenericStream {
internal: Arc<Mutex<dyn Streamlike>>,
}
impl GenericStream {
pub fn new<T: Streamlike + 'static>(x: T) -> GenericStream {
GenericStream {
internal: Arc::new(Mutex::new(x)),
}
}
}
impl HasLocalAddress for GenericStream {
fn local_addr(&self) -> (SOCKSv5Address, u16) {
let item = self.internal.lock().unwrap();
item.local_addr()
}
}
impl AsyncRead for GenericStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let mut item = self.internal.lock().unwrap();
let pinned = Pin::new(&mut *item);
pinned.poll_read(cx, buf)
}
}
impl AsyncWrite for GenericStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let mut item = self.internal.lock().unwrap();
let pinned = Pin::new(&mut *item);
pinned.poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let mut item = self.internal.lock().unwrap();
let pinned = Pin::new(&mut *item);
pinned.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let mut item = self.internal.lock().unwrap();
let pinned = Pin::new(&mut *item);
pinned.poll_close(cx)
}
}
impl<T: Streamlike + 'static> From<T> for GenericStream {
fn from(x: T) -> GenericStream {
GenericStream {
internal: Arc::new(Mutex::new(x)),
}
}
}

View File

@@ -1,276 +0,0 @@
mod datagram;
mod stream;
use crate::messages::ServerResponseStatus;
use crate::network::address::{HasLocalAddress, SOCKSv5Address};
#[cfg(test)]
use crate::network::datagram::Datagramlike;
use crate::network::datagram::GenericDatagramSocket;
use crate::network::generic::Networklike;
use crate::network::listener::{GenericListener, Listenerlike};
use crate::network::stream::GenericStream;
use crate::network::testing::datagram::TestDatagram;
use crate::network::testing::stream::TestingStream;
use async_std::channel::{bounded, Receiver, Sender};
use async_std::sync::{Arc, Mutex};
#[cfg(test)]
use async_std::task;
use async_trait::async_trait;
#[cfg(test)]
use futures::{AsyncReadExt, AsyncWriteExt};
use std::collections::HashMap;
use std::fmt;
/// A "network", based purely on internal Rust datatypes, for testing
/// networking code. This stack operates purely in memory, so shouldn't
/// suffer from any weird networking effects ... which makes it a good
/// functional test, but not great at actually testing real-world failure
/// modes.
#[allow(clippy::type_complexity)]
#[derive(Clone)]
pub struct TestingStack {
tcp_listeners: Arc<Mutex<HashMap<(SOCKSv5Address, u16), Sender<TestingStream>>>>,
udp_sockets: Arc<Mutex<HashMap<(SOCKSv5Address, u16), Sender<(SOCKSv5Address, u16, Vec<u8>)>>>>,
next_random_socket: u16,
}
impl TestingStack {
pub fn new() -> TestingStack {
TestingStack {
tcp_listeners: Arc::new(Mutex::new(HashMap::new())),
udp_sockets: Arc::new(Mutex::new(HashMap::new())),
next_random_socket: 23,
}
}
}
impl Default for TestingStack {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug)]
pub enum TestStackError {
AcceptFailed,
AddressBusy(SOCKSv5Address, u16),
ConnectionFailed,
FailureToSend,
NoTCPHostFound(SOCKSv5Address, u16),
ReceiveFailure,
}
impl fmt::Display for TestStackError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
TestStackError::AcceptFailed => write!(f, "Accept failed; the other side died (?)"),
TestStackError::AddressBusy(ref addr, port) => {
write!(f, "Address {}:{} already in use", addr, port)
}
TestStackError::ConnectionFailed => write!(f, "Couldn't connect to host."),
TestStackError::FailureToSend => write!(
f,
"Weird internal error in testing infrastructure; channel send failed"
),
TestStackError::NoTCPHostFound(ref addr, port) => {
write!(f, "No host found at {} for TCP port {}", addr, port)
}
TestStackError::ReceiveFailure => {
write!(f, "Failed to process a UDP receive (this is weird)")
}
}
}
}
impl From<TestStackError> for ServerResponseStatus {
fn from(_: TestStackError) -> Self {
ServerResponseStatus::GeneralFailure
}
}
#[async_trait]
impl Networklike for TestingStack {
type Error = TestStackError;
async fn connect<A: Send + Into<SOCKSv5Address>>(
&mut self,
addr: A,
port: u16,
) -> Result<GenericStream, Self::Error> {
let table = self.tcp_listeners.lock().await;
let target = addr.into();
match table.get(&(target.clone(), port)) {
None => Err(TestStackError::NoTCPHostFound(target, port)),
Some(result) => {
let stream = TestingStream::new(target, port);
let retval = stream.invert();
match result.send(stream).await {
Ok(()) => Ok(GenericStream::new(retval)),
Err(_) => Err(TestStackError::FailureToSend),
}
}
}
}
async fn listen<A: Send + Into<SOCKSv5Address>>(
&mut self,
addr: A,
mut port: u16,
) -> Result<GenericListener<Self::Error>, Self::Error> {
let mut table = self.tcp_listeners.lock().await;
let target = addr.into();
let (sender, receiver) = bounded(5);
if port == 0 {
port = self.next_random_socket;
self.next_random_socket += 1;
}
table.insert((target.clone(), port), sender);
Ok(GenericListener {
internal: Box::new(TestListener::new(target, port, receiver)),
})
}
async fn bind<A: Send + Into<SOCKSv5Address>>(
&mut self,
addr: A,
mut port: u16,
) -> Result<GenericDatagramSocket<Self::Error>, Self::Error> {
let mut table = self.udp_sockets.lock().await;
let target = addr.into();
let (sender, receiver) = bounded(5);
if port == 0 {
port = self.next_random_socket;
self.next_random_socket += 1;
}
table.insert((target.clone(), port), sender);
Ok(GenericDatagramSocket {
internal: Box::new(TestDatagram::new(self.clone(), target, port, receiver)),
})
}
}
struct TestListener {
address: SOCKSv5Address,
port: u16,
receiver: Receiver<TestingStream>,
}
impl TestListener {
fn new(address: SOCKSv5Address, port: u16, receiver: Receiver<TestingStream>) -> Self {
TestListener {
address,
port,
receiver,
}
}
}
impl HasLocalAddress for TestListener {
fn local_addr(&self) -> (SOCKSv5Address, u16) {
(self.address.clone(), self.port)
}
}
#[async_trait]
impl Listenerlike for TestListener {
type Error = TestStackError;
async fn accept(&self) -> Result<(GenericStream, SOCKSv5Address, u16), Self::Error> {
match self.receiver.recv().await {
Ok(next) => {
let (addr, port) = next.local_addr();
Ok((GenericStream::new(next), addr, port))
}
Err(_) => Err(TestStackError::AcceptFailed),
}
}
}
#[test]
fn check_udp_sanity() {
task::block_on(async {
let mut network = TestingStack::new();
let receiver = network
.bind("localhost", 0)
.await
.expect("Failed to bind receiver socket.");
let sender = network
.bind("localhost", 0)
.await
.expect("Failed to bind sender socket.");
let buffer = [0xde, 0xea, 0xbe, 0xef];
let (receiver_addr, receiver_port) = receiver.local_addr();
sender
.send_to(&buffer, receiver_addr, receiver_port)
.await
.expect("Failure sending datagram!");
let mut recvbuffer = [0; 4];
let (s, f, p) = receiver
.recv_from(&mut recvbuffer)
.await
.expect("Didn't receive UDP message?");
let (sender_addr, sender_port) = sender.local_addr();
assert_eq!(s, 4);
assert_eq!(f, sender_addr);
assert_eq!(p, sender_port);
assert_eq!(recvbuffer, buffer);
});
}
#[test]
fn check_basic_tcp() {
task::block_on(async {
let mut network = TestingStack::new();
let listener = network
.listen("localhost", 0)
.await
.expect("Couldn't set up listener on localhost");
let (listener_address, listener_port) = listener.local_addr();
let listener_task_handle = task::spawn(async move {
dbg!("Starting listener task!!");
let (mut stream, addr, port) = listener.accept().await.expect("Didn't get connection");
let mut result_buffer = [0u8; 4];
if let Err(e) = stream.read_exact(&mut result_buffer).await {
dbg!("Error reading buffer from stream: {}", e);
} else {
dbg!("made it through read_exact");
}
(result_buffer, addr, port)
});
let sender_task_handle = task::spawn(async move {
let mut sender = network
.connect(listener_address, listener_port)
.await
.expect("Coudln't connect to listener?");
let (sender_address, sender_port) = sender.local_addr();
let send_buffer = [0xa, 0xff, 0xab, 0x1e];
sender
.write_all(&send_buffer)
.await
.expect("Couldn't send the write buffer");
sender
.flush()
.await
.expect("Couldn't flush the write buffer");
sender
.close()
.await
.expect("Couldn't close the write buffer");
(sender_address, sender_port)
});
let (result, result_from, result_from_port) = listener_task_handle.await;
assert_eq!(result, [0xa, 0xff, 0xab, 0x1e]);
let (sender_address, sender_port) = sender_task_handle.await;
assert_eq!(result_from, sender_address);
assert_eq!(result_from_port, sender_port);
});
}

View File

@@ -1,88 +0,0 @@
use crate::network::address::HasLocalAddress;
use crate::network::datagram::Datagramlike;
use crate::network::testing::{TestStackError, TestingStack};
use crate::network::SOCKSv5Address;
use async_std::channel::Receiver;
use async_trait::async_trait;
use std::cmp::Ordering;
pub struct TestDatagram {
context: TestingStack,
my_address: SOCKSv5Address,
my_port: u16,
input_stream: Receiver<(SOCKSv5Address, u16, Vec<u8>)>,
}
impl TestDatagram {
pub fn new(
context: TestingStack,
my_address: SOCKSv5Address,
my_port: u16,
input_stream: Receiver<(SOCKSv5Address, u16, Vec<u8>)>,
) -> Self {
TestDatagram {
context,
my_address,
my_port,
input_stream,
}
}
}
impl HasLocalAddress for TestDatagram {
fn local_addr(&self) -> (SOCKSv5Address, u16) {
(self.my_address.clone(), self.my_port)
}
}
#[async_trait]
impl Datagramlike for TestDatagram {
type Error = TestStackError;
async fn send_to(
&self,
buf: &[u8],
target: SOCKSv5Address,
port: u16,
) -> Result<usize, Self::Error> {
let table = self.context.udp_sockets.lock().await;
match table.get(&(target, port)) {
None => Ok(buf.len()),
Some(sender) => {
sender
.send((self.my_address.clone(), self.my_port, buf.to_vec()))
.await
.map_err(|_| TestStackError::FailureToSend)?;
Ok(buf.len())
}
}
}
async fn recv_from(
&self,
buffer: &mut [u8],
) -> Result<(usize, SOCKSv5Address, u16), Self::Error> {
let (from_addr, from_port, message) = self
.input_stream
.recv()
.await
.map_err(|_| TestStackError::ReceiveFailure)?;
match message.len().cmp(&buffer.len()) {
Ordering::Greater => {
buffer.copy_from_slice(&message[..buffer.len()]);
Ok((message.len(), from_addr, from_port))
}
Ordering::Less => {
(&mut buffer[..message.len()]).copy_from_slice(&message);
Ok((message.len(), from_addr, from_port))
}
Ordering::Equal => {
buffer.copy_from_slice(message.as_ref());
Ok((message.len(), from_addr, from_port))
}
}
}
}

View File

@@ -1,166 +0,0 @@
use crate::network::address::HasLocalAddress;
use crate::network::stream::Streamlike;
use crate::network::SOCKSv5Address;
use async_std::io;
use async_std::io::{Read, Write};
use async_std::task::{Context, Poll, Waker};
use std::cell::UnsafeCell;
use std::pin::Pin;
use std::ptr::NonNull;
use std::sync::atomic::{AtomicBool, Ordering};
#[derive(Clone)]
pub struct TestingStream {
address: SOCKSv5Address,
port: u16,
read_side: NonNull<TestingStreamData>,
write_side: NonNull<TestingStreamData>,
}
unsafe impl Send for TestingStream {}
unsafe impl Sync for TestingStream {}
struct TestingStreamData {
lock: AtomicBool,
waiters: UnsafeCell<Vec<Waker>>,
buffer: UnsafeCell<Vec<u8>>,
}
unsafe impl Send for TestingStreamData {}
unsafe impl Sync for TestingStreamData {}
impl TestingStream {
/// Generate a testing stream. Note that this is directional. So, if you want to
/// talk to this stream, you should also generate an `invert()` and pass that to
/// the other thread(s).
pub fn new(address: SOCKSv5Address, port: u16) -> TestingStream {
let read_side_data = TestingStreamData {
lock: AtomicBool::new(false),
waiters: UnsafeCell::new(Vec::new()),
buffer: UnsafeCell::new(Vec::with_capacity(16 * 1024)),
};
let write_side_data = TestingStreamData {
lock: AtomicBool::new(false),
waiters: UnsafeCell::new(Vec::new()),
buffer: UnsafeCell::new(Vec::with_capacity(16 * 1024)),
};
let boxed_rsd = Box::new(read_side_data);
let boxed_wsd = Box::new(write_side_data);
let raw_read_ptr = Box::leak(boxed_rsd);
let raw_write_ptr = Box::leak(boxed_wsd);
TestingStream {
address,
port,
read_side: NonNull::new(raw_read_ptr).unwrap(),
write_side: NonNull::new(raw_write_ptr).unwrap(),
}
}
/// Get the flip side of this stream; reads from the inverted side will catch the writes
/// of the original, etc.
pub fn invert(&self) -> TestingStream {
TestingStream {
address: self.address.clone(),
port: self.port,
read_side: self.write_side,
write_side: self.read_side,
}
}
}
impl TestingStreamData {
fn acquire(&mut self) {
loop {
match self
.lock
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
{
Err(_) => continue,
Ok(_) => return,
}
}
}
fn release(&mut self) {
self.lock.store(false, Ordering::SeqCst);
}
}
impl HasLocalAddress for TestingStream {
fn local_addr(&self) -> (SOCKSv5Address, u16) {
(self.address.clone(), self.port)
}
}
impl Read for TestingStream {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
// so, we're going to spin here, which is less than ideal but should work fine
// in practice. we'll obviously need to be very careful to ensure that we keep
// the stuff internal to this spin really short.
let internals = unsafe { self.read_side.as_mut() };
internals.acquire();
let stream_buffer = internals.buffer.get_mut();
let amount_available = stream_buffer.len();
if amount_available == 0 {
let waker = cx.waker().clone();
internals.waiters.get_mut().push(waker);
internals.release();
return Poll::Pending;
}
let amt_written = if buf.len() >= amount_available {
(&mut buf[0..amount_available]).copy_from_slice(stream_buffer);
stream_buffer.clear();
amount_available
} else {
let amt_to_copy = buf.len();
buf.copy_from_slice(&stream_buffer[0..amt_to_copy]);
stream_buffer.copy_within(amt_to_copy.., 0);
let amt_left = amount_available - amt_to_copy;
stream_buffer.resize(amt_left, 0);
amt_to_copy
};
internals.release();
Poll::Ready(Ok(amt_written))
}
}
impl Write for TestingStream {
fn poll_write(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let internals = unsafe { self.write_side.as_mut() };
internals.acquire();
let stream_buffer = internals.buffer.get_mut();
stream_buffer.extend_from_slice(buf);
for waiter in internals.waiters.get_mut().drain(0..) {
waiter.wake();
}
internals.release();
Poll::Ready(Ok(buf.len()))
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(())) // FIXME: Might consider having this wait until the buffer is empty
}
fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(())) // FIXME: Might consider putting in some open/closed logic here
}
}
impl Streamlike for TestingStream {}

View File

@@ -0,0 +1,84 @@
use std::net::SocketAddr;
/// The security parameters that you can assign to the server, to make decisions
/// about the weirdos it accepts as users. It is recommended that you only use
/// wide open connections when you're 100% sure that the server will only be
/// accessible locally.
#[derive(Clone)]
pub struct SecurityParameters {
/// Allow completely unauthenticated connections. You should be very, very
/// careful about setting this to true, especially if you don't provide a
/// guard to ensure that you're getting connections from reasonable places.
pub allow_unauthenticated: bool,
/// An optional function that can serve as a firewall for new connections.
/// Return true if the connection should be allowed to continue, false if
/// it shouldn't. This check happens before any data is read from or written
/// to the connecting party.
pub allow_connection: Option<fn(&SocketAddr) -> bool>,
/// An optional function to check a user name (first argument) and password
/// (second argument). Return true if the username / password is good, false
/// if not.
pub check_password: Option<fn(&str, &str) -> bool>,
/// An optional function to transition the stream from an unencrypted one to
/// an encrypted on. The assumption is you're using something like `rustls`
/// to make this happen; the exact mechanism is outside the scope of this
/// particular crate. If the connection shouldn't be allowed for some reason
/// (a bad certificate or handshake, for example), then return None; otherwise,
/// return the new stream.
pub connect_tls: Option<fn() -> Option<()>>,
}
impl SecurityParameters {
/// Generates a `SecurityParameters` object that's empty. It won't accept
/// anything, because it has no mechanisms it can use to actually authenticate
/// a user and yet won't allow unauthenticated connections.
pub fn new() -> SecurityParameters {
SecurityParameters {
allow_unauthenticated: false,
allow_connection: None,
check_password: None,
connect_tls: None,
}
}
/// Generates a `SecurityParameters` object that does not, in any way,
/// restrict who can log in. It also will not induce any transition into
/// TLS. Use this at your own risk ... or, really, just don't use this,
/// ever, and certainly not in production.
pub fn unrestricted() -> SecurityParameters {
SecurityParameters {
allow_unauthenticated: true,
allow_connection: None,
check_password: None,
connect_tls: None,
}
}
/// Use the provided function to check incoming connections before proceeding
/// with the rest of the handshake.
pub fn check_connections(mut self, checker: fn(&SocketAddr) -> bool) -> SecurityParameters {
self.allow_connection = Some(checker);
self
}
/// Use the provided function to check usernames and passwords provided
/// to the server.
pub fn password_check(mut self, checker: fn(&str, &str) -> bool) -> SecurityParameters {
self.check_password = Some(checker);
self
}
/// Use the provide function to validate a TLS connection, and transition it
/// to the new stream type. If the handshake fails, return `None` instead of
/// `Some`. (And maybe log it somewhere, you know.)
pub fn tls_converter(mut self, converter: fn() -> Option<()>) -> SecurityParameters {
self.connect_tls = Some(converter);
self
}
}
impl Default for SecurityParameters {
fn default() -> Self {
Self::new()
}
}

View File

@@ -1,59 +0,0 @@
use crate::errors::{DeserializationError, SerializationError};
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
pub async fn read_string<R: AsyncRead + Send + Unpin>(
r: &mut R,
) -> Result<String, DeserializationError> {
let mut length_buffer = [0; 1];
if r.read(&mut length_buffer).await? == 0 {
return Err(DeserializationError::NotEnoughData);
}
let target = length_buffer[0] as usize;
if target == 0 {
return Err(DeserializationError::InvalidEmptyString);
}
let mut bytestring = vec![0; target];
read_amt(r, target, &mut bytestring).await?;
Ok(String::from_utf8(bytestring)?)
}
pub async fn write_string<W: AsyncWrite + Send + Unpin>(
s: &str,
w: &mut W,
) -> Result<(), SerializationError> {
let bytestring = s.as_bytes();
if bytestring.is_empty() || bytestring.len() > 255 {
return Err(SerializationError::InvalidStringLength(s.to_string()));
}
w.write_all(&[bytestring.len() as u8]).await?;
w.write_all(bytestring)
.await
.map_err(SerializationError::IOError)
}
pub async fn read_amt<R: AsyncRead + Send + Unpin>(
r: &mut R,
amt: usize,
buffer: &mut [u8],
) -> Result<(), DeserializationError> {
let mut amt_read = 0;
while amt_read < amt {
let chunk_amt = r.read(&mut buffer[amt_read..]).await?;
if chunk_amt == 0 {
return Err(DeserializationError::NotEnoughData);
}
amt_read += chunk_amt;
}
Ok(())
}

View File

@@ -1,110 +1,415 @@
use crate::errors::{AuthenticationError, DeserializationError, SerializationError}; use crate::address::SOCKSv5Address;
use crate::messages::{ use crate::messages::{
AuthenticationMethod, ClientConnectionCommand, ClientConnectionRequest, ClientGreeting, AuthenticationMethod, ClientConnectionCommand, ClientConnectionCommandReadError,
ClientUsernamePassword, ServerAuthResponse, ServerChoice, ServerResponse, ServerResponseStatus, ClientConnectionRequest, ClientConnectionRequestReadError, ClientGreeting,
ClientGreetingReadError, ClientUsernamePassword, ClientUsernamePasswordReadError,
ServerAuthResponse, ServerAuthResponseWriteError, ServerChoice, ServerChoiceWriteError,
ServerResponse, ServerResponseStatus, ServerResponseWriteError,
}; };
use crate::network::address::HasLocalAddress; pub use crate::security_parameters::SecurityParameters;
use crate::network::generic::Networklike; use std::net::SocketAddr;
use crate::network::listener::{GenericListener, Listenerlike}; use std::sync::atomic::{AtomicU64, Ordering};
use crate::network::stream::GenericStream; use std::sync::Arc;
use crate::network::SOCKSv5Address;
use async_std::io;
use async_std::io::prelude::WriteExt;
use async_std::sync::{Arc, Mutex};
use async_std::task;
use log::{error, info, trace, warn};
use thiserror::Error; use thiserror::Error;
use tokio::io::{copy_bidirectional, AsyncWriteExt};
pub struct SOCKSv5Server<N: Networklike> { use tokio::net::{TcpListener, TcpSocket, TcpStream, UdpSocket};
network: N, use tokio::task::JoinHandle;
security_parameters: SecurityParameters, use tracing::{field, info_span, Instrument, Span};
listener: GenericListener<N::Error>,
}
#[derive(Clone)] #[derive(Clone)]
pub struct SecurityParameters { pub struct SOCKSv5Server {
pub allow_unauthenticated: bool, info: Arc<ServerInfo>,
pub allow_connection: Option<fn(&SOCKSv5Address, u16) -> bool>,
pub check_password: Option<fn(&str, &str) -> bool>,
pub connect_tls: Option<fn(GenericStream) -> Option<GenericStream>>,
} }
impl SecurityParameters { struct ServerInfo {
/// Generates a `SecurityParameters` object that does not, in any way,
/// restrict who can log in. It also will not induce any transition into
/// TLS. Use this at your own risk ... or, really, just don't use this,
/// ever, and certainly not in production.
pub fn unrestricted() -> SecurityParameters {
SecurityParameters {
allow_unauthenticated: true,
allow_connection: None,
check_password: None,
connect_tls: None,
}
}
}
impl<N: Networklike + Send + 'static> SOCKSv5Server<N> {
pub fn new<S: Listenerlike<Error = N::Error> + 'static>(
network: N,
security_parameters: SecurityParameters, security_parameters: SecurityParameters,
stream: S, next_id: AtomicU64,
) -> SOCKSv5Server<N> { }
#[derive(Clone, Debug, Error, PartialEq)]
pub enum SOCKSv5ServerError {
#[error("Underlying networking error: {0}")]
NetworkingError(String),
#[error("Couldn't negotiate authentication with client.")]
ItsNotUsItsYou,
#[error("Client greeting read problem: {0}")]
GreetingReadProblem(#[from] ClientGreetingReadError),
#[error("Server choice write problem: {0}")]
ChoiceWriteProblem(#[from] ServerChoiceWriteError),
#[error("Failed username/password authentication for user {0}")]
FailedUsernamePassword(String),
#[error("Server authentication response problem: {0}")]
ServerAuthWriteProblem(#[from] ServerAuthResponseWriteError),
#[error("Error reading client username/password: {0}")]
UserPassReadProblem(#[from] ClientUsernamePasswordReadError),
#[error("Error reading client connection command: {0}")]
ClientConnReadProblem(#[from] ClientConnectionCommandReadError),
#[error("Error reading client connection request: {0}")]
ClientRequestReadProblem(#[from] ClientConnectionRequestReadError),
#[error("Error writing server response: {0}")]
ServerResponseWriteProblem(#[from] ServerResponseWriteError),
}
impl From<std::io::Error> for SOCKSv5ServerError {
fn from(x: std::io::Error) -> SOCKSv5ServerError {
SOCKSv5ServerError::NetworkingError(format!("{}", x))
}
}
impl SOCKSv5Server {
/// Initialize a SOCKSv5 server for use later on. Once initialized, you can listen
/// on as many addresses and ports as you like; the metadata about the server will
/// be synced across all the instances.
pub fn new(security_parameters: SecurityParameters) -> Self {
SOCKSv5Server { SOCKSv5Server {
network, info: Arc::new(ServerInfo {
security_parameters, security_parameters,
listener: GenericListener { next_id: AtomicU64::new(1),
internal: Box::new(stream), }),
},
} }
} }
pub async fn run(self) -> Result<(), N::Error> { /// Start a server on the given address and port. This function returns when it has
let (my_addr, my_port) = self.listener.local_addr(); /// set up its listening socket, but spawns a separate task to actually wait for
info!("Starting SOCKSv5 server on {}:{}", my_addr, my_port); /// connections. You can query which ones are still active, or see which ones have
let locked_network = Arc::new(Mutex::new(self.network)); /// failed, using some of the other functions for this structure.
///
/// If you don't care what port is assigned to this server, pass 0 in as the port
/// number and one will be chosen for you by the OS.
///
pub async fn start<A: Send + Into<SOCKSv5Address>>(
&self,
addr: A,
port: u16,
) -> Result<JoinHandle<Result<(), std::io::Error>>, std::io::Error> {
let listener = match addr.into() {
SOCKSv5Address::IP4(x) => TcpListener::bind((x, port)).await?,
SOCKSv5Address::IP6(x) => TcpListener::bind((x, port)).await?,
SOCKSv5Address::Hostname(x) => TcpListener::bind((x, port)).await?,
};
let sockaddr = listener.local_addr()?;
tracing::info!(
"Starting SOCKSv5 server on {}:{}",
sockaddr.ip(),
sockaddr.port()
);
let second_life = self.clone();
Ok(tokio::task::spawn(async move {
second_life.server_loop(listener).await
}))
}
/// Run the server loop for a particular listener. This routine will never actually
/// return except in error conditions.
async fn server_loop(self, listener: TcpListener) -> Result<(), std::io::Error> {
let local_addr = listener.local_addr()?;
loop { loop {
let (stream, their_addr, their_port) = self.listener.accept().await?; let (socket, their_addr) = listener.accept().await?;
let accepted_span = info_span!(
trace!( "session",
"Initial accept of connection from {}:{}", server_address=?local_addr,
their_addr, remote_address=?their_addr,
their_port auth_method=field::Empty,
ident=field::Empty,
); );
if let Some(checker) = &self.security_parameters.allow_connection {
if !checker(&their_addr, their_port) {
info!(
"Rejecting attempted connection from {}:{}",
their_addr, their_port
);
continue;
}
}
let params = self.security_parameters.clone(); accepted_span.in_scope(|| {
let network_mutex_copy = locked_network.clone(); // before we do anything of note, make sure this connection is cool. we don't want
task::spawn(async move { // to waste any resources (and certainly don't want to handle any data!) if this
match run_authentication(params, stream, their_addr.clone(), their_port).await { // isn't someone we want to accept connections from.
Ok(authed_stream) => { if let Some(checker) = self.info.security_parameters.allow_connection {
match run_main_loop(network_mutex_copy, authed_stream).await { if !checker(&their_addr) {
Ok(_) => {} tracing::info!("Rejecting attempted connection from {}", their_addr,);
Err(e) => warn!("Failure in main loop: {}", e),
} }
} else {
// continue this work in another task. we could absolutely do this work here,
// but just in case someone starts doing slow responses (or other nasty things),
// we want to make sure that that doesn't slow down our ability to accept other
// requests.
let span_again = accepted_span.clone();
let me_again = self.clone();
tokio::task::spawn(async move {
let session_identifier =
me_again.info.next_id.fetch_add(1, Ordering::SeqCst);
span_again.record("ident", &session_identifier);
if let Err(e) = me_again
.start_authentication(their_addr, socket)
.instrument(span_again)
.await
{
tracing::error!("{}: server handler failure: {}", their_addr, e);
} }
Err(e) => warn!( });
"Failure running authentication from {}:{}: {}",
their_addr, their_port, e
),
} }
}); });
} }
} }
/// Start the authentication phase of the SOCKS handshake. This may be very short, and
/// is the first stage of handling a request. This will only really return on errors.
async fn start_authentication(
self,
their_addr: SocketAddr,
mut socket: TcpStream,
) -> Result<(), SOCKSv5ServerError> {
let greeting = ClientGreeting::read(&mut socket).await?;
match choose_authentication_method(
&self.info.security_parameters,
&greeting.acceptable_methods,
) {
// it's not us, it's you. (we're just going to say no.)
None => {
tracing::trace!(
"{}: Failed to find acceptable authentication method.",
their_addr,
);
let rejection_letter = ServerChoice::rejection();
rejection_letter.write(&mut socket).await?;
socket.flush().await?;
Err(SOCKSv5ServerError::ItsNotUsItsYou)
}
// the gold standard. great choice.
Some(ChosenMethod::TLS(_converter)) => {
Span::current().record("auth_method", &"TLS");
unimplemented!()
}
// well, I guess this is something?
Some(ChosenMethod::Password(checker)) => {
tracing::trace!(
"{}: Choosing username/password for authentication.",
their_addr,
);
let ok_lets_do_password =
ServerChoice::option(AuthenticationMethod::UsernameAndPassword);
ok_lets_do_password.write(&mut socket).await?;
socket.flush().await?;
let their_info = ClientUsernamePassword::read(&mut socket).await?;
if checker(&their_info.username, &their_info.password) {
let its_all_good = ServerAuthResponse::success();
its_all_good.write(&mut socket).await?;
socket.flush().await?;
Span::current().record("auth_method", &"password");
self.choose_mode(socket, their_addr).await
} else {
let yeah_no = ServerAuthResponse::failure();
yeah_no.write(&mut socket).await?;
socket.flush().await?;
Err(SOCKSv5ServerError::FailedUsernamePassword(
their_info.username,
))
}
}
// Um. I guess we're doing this unchecked. Yay?
Some(ChosenMethod::None) => {
let nothin_i_guess = ServerChoice::option(AuthenticationMethod::None);
nothin_i_guess.write(&mut socket).await?;
socket.flush().await?;
Span::current().record("auth_method", &"unauthenticated");
self.choose_mode(socket, their_addr).await
}
}
}
/// Determine which of the modes we might want this particular connection to run
/// in.
async fn choose_mode(
self,
mut socket: TcpStream,
their_addr: SocketAddr,
) -> Result<(), SOCKSv5ServerError> {
let ccr = ClientConnectionRequest::read(&mut socket).await?;
match ccr.command_code {
ClientConnectionCommand::AssociateUDPPort => {
self.handle_udp_request(socket, their_addr, ccr).await?
}
ClientConnectionCommand::EstablishTCPStream => {
self.handle_tcp_request(socket, ccr).await?
}
ClientConnectionCommand::EstablishTCPPortBinding => {
self.handle_tcp_binding_request(socket, their_addr, ccr)
.await?
}
}
Ok(())
}
/// Handle UDP forwarding requests
#[allow(unreachable_code)]
async fn handle_udp_request(
self,
stream: TcpStream,
their_addr: SocketAddr,
ccr: ClientConnectionRequest,
) -> Result<(), SOCKSv5ServerError> {
let my_addr = stream.local_addr()?;
tracing::info!(
"[{}:{}] Handling UDP bind request from {}:{}, seeking to bind towards {}:{}",
my_addr.ip(),
my_addr.port(),
their_addr.ip(),
their_addr.port(),
ccr.destination_address,
ccr.destination_port
);
let _socket = match ccr.destination_address.clone() {
SOCKSv5Address::IP4(x) => UdpSocket::bind((x, ccr.destination_port)).await?,
SOCKSv5Address::IP6(x) => UdpSocket::bind((x, ccr.destination_port)).await?,
SOCKSv5Address::Hostname(x) => UdpSocket::bind((x, ccr.destination_port)).await?,
};
// OK, it worked. In order to mitigate an infinitesimal chance of a race condition, we're
// going to set up our forwarding tasks first, and then return the result to the user. (Note,
// we'd have to be slightly more precious in order to ensure a lack of race conditions, as
// the runtime could take forever to actually start these tasks, but I'm not ready to be
// bothered by this, yet. FIXME.)
unimplemented!();
// Cool; now we can get the result out to the user.
let bound_address = _socket.local_addr()?;
let response = ServerResponse {
status: ServerResponseStatus::RequestGranted,
bound_address: bound_address.ip().into(),
bound_port: bound_address.port(),
};
response.write(&mut stream).await?;
Ok(())
}
/// Handle TCP forwarding requests
async fn handle_tcp_request(
self,
mut stream: TcpStream,
ccr: ClientConnectionRequest,
) -> Result<(), SOCKSv5ServerError> {
// Let the user know that we're maybe making progress
tracing::info!(
"Handling TCP forward request to {}:{}",
ccr.destination_address,
ccr.destination_port
);
// OK, first thing's first: We need to actually connect to the server that the user
// wants us to connect to.
let outgoing_stream = match &ccr.destination_address {
SOCKSv5Address::IP4(x) => TcpStream::connect((*x, ccr.destination_port)).await?,
SOCKSv5Address::IP6(x) => TcpStream::connect((*x, ccr.destination_port)).await?,
SOCKSv5Address::Hostname(x) => {
TcpStream::connect((x.as_ref(), ccr.destination_port)).await?
}
};
let tcp_forwarding_span = info_span!(
"tcp_forwarding",
target_address=?ccr.destination_address,
target_port=ccr.destination_port
);
// Now, for whatever reason -- and this whole thing sent me down a garden path
// in understanding how this whole protocol works -- we tell the user what address
// and port we bound for that connection.
let bound_address = outgoing_stream.local_addr()?;
let response = ServerResponse {
status: ServerResponseStatus::RequestGranted,
bound_address: bound_address.ip().into(),
bound_port: bound_address.port(),
};
response
.write(&mut stream)
.instrument(tcp_forwarding_span.clone())
.await?;
// so now tie our streams together, and we're good to go
tie_streams(stream, outgoing_stream)
.instrument(tcp_forwarding_span)
.await;
Ok(())
}
/// Handle TCP binding requests
async fn handle_tcp_binding_request(
self,
mut stream: TcpStream,
their_addr: SocketAddr,
ccr: ClientConnectionRequest,
) -> Result<(), SOCKSv5ServerError> {
// Let the user know that we're maybe making progress
let my_addr = stream.local_addr()?;
tracing::info!(
"[{}] Handling TCP bind request from {}, seeking to bind {}:{}",
my_addr,
their_addr,
ccr.destination_address,
ccr.destination_port
);
// OK, we have to bind the darn socket first.
let listener_port = match &their_addr {
SocketAddr::V4(_) => TcpSocket::new_v4(),
SocketAddr::V6(_) => TcpSocket::new_v6(),
}?;
// FIXME: Might want to bind on a particular interface, based on a
// config flag, at some point.
let listener = listener_port.listen(1)?;
// Tell them what we bound, just in case they want to inform anyone.
let bound_address = listener.local_addr()?;
let response = ServerResponse {
status: ServerResponseStatus::RequestGranted,
bound_address: bound_address.ip().into(),
bound_port: bound_address.port(),
};
response.write(&mut stream).await?;
// Wait politely for someone to talk to us.
let (other, other_addr) = listener.accept().await?;
let info = ServerResponse {
status: ServerResponseStatus::RequestGranted,
bound_address: other_addr.ip().into(),
bound_port: other_addr.port(),
};
info.write(&mut stream).await?;
tie_streams(stream, other).await;
Ok(())
}
}
async fn tie_streams(mut left: TcpStream, mut right: TcpStream) {
let span_copy = Span::current();
tracing::info!("linking forwarding streams");
tokio::task::spawn(
async move {
match copy_bidirectional(&mut left, &mut right)
.instrument(span_copy)
.await
{
Ok((l2r, r2l)) => {
tracing::info!(sent = l2r, received = r2l, "shutting down streams")
}
Err(e) => tracing::warn!("Linked streams shut down with error: {}", e),
}
}
.instrument(Span::current()),
);
} }
#[allow(clippy::upper_case_acronyms)] #[allow(clippy::upper_case_acronyms)]
enum ChosenMethod { enum ChosenMethod {
TLS(fn(GenericStream) -> Option<GenericStream>), TLS(fn() -> Option<()>),
Password(fn(&str, &str) -> bool), Password(fn(&str, &str) -> bool),
None, None,
} }
@@ -209,7 +514,7 @@ fn reasonable_auth_method_choices() {
); );
// OK, cool. If we have a TLS handler, that shouldn't actually make a difference. // OK, cool. If we have a TLS handler, that shouldn't actually make a difference.
params.connect_tls = Some(|_| unimplemented!()); params.connect_tls = Some(|| unimplemented!());
assert_eq!( assert_eq!(
choose_authentication_method(&params, &client_suggestions).map(AuthenticationMethod::from), choose_authentication_method(&params, &client_suggestions).map(AuthenticationMethod::from),
None None
@@ -222,7 +527,7 @@ fn reasonable_auth_method_choices() {
None None
); );
// but if we have a handler, and they go for it, we use it. // but if we have a handler, and they go for it, we use it.
params.connect_tls = Some(|_| unimplemented!()); params.connect_tls = Some(|| unimplemented!());
assert_eq!( assert_eq!(
choose_authentication_method(&params, &client_suggestions).map(AuthenticationMethod::from), choose_authentication_method(&params, &client_suggestions).map(AuthenticationMethod::from),
Some(AuthenticationMethod::SSL) Some(AuthenticationMethod::SSL)
@@ -242,212 +547,3 @@ fn reasonable_auth_method_choices() {
Some(AuthenticationMethod::SSL) Some(AuthenticationMethod::SSL)
); );
} }
async fn run_authentication(
params: SecurityParameters,
mut stream: GenericStream,
addr: SOCKSv5Address,
port: u16,
) -> Result<GenericStream, AuthenticationError> {
// before we do anything at all, we check to see if we just want to blindly reject
// this connection, utterly and completely.
if let Some(firewall_allows) = params.allow_connection {
if !firewall_allows(&addr, port) {
return Err(AuthenticationError::FirewallRejected(addr, port));
}
}
// OK, I guess we'll listen to you
let greeting = ClientGreeting::read(&mut stream).await?;
match choose_authentication_method(&params, &greeting.acceptable_methods) {
// it's not us, it's you
None => {
trace!("Failed to find acceptable authentication method.");
let rejection_letter = ServerChoice::rejection();
rejection_letter.write(&mut stream).await?;
stream.flush().await?;
Err(AuthenticationError::ItsNotUsItsYou)
}
// the gold standard. great choice.
Some(ChosenMethod::TLS(converter)) => {
trace!("Choosing TLS for authentication.");
let lets_do_this = ServerChoice::option(AuthenticationMethod::SSL);
lets_do_this.write(&mut stream).await?;
stream.flush().await?;
converter(stream).ok_or(AuthenticationError::FailedTLSHandshake)
}
// well, I guess this is something?
Some(ChosenMethod::Password(checker)) => {
trace!("Choosing Username/Password for authentication.");
let ok_lets_do_password =
ServerChoice::option(AuthenticationMethod::UsernameAndPassword);
ok_lets_do_password.write(&mut stream).await?;
stream.flush().await?;
let their_info = ClientUsernamePassword::read(&mut stream).await?;
if checker(&their_info.username, &their_info.password) {
let its_all_good = ServerAuthResponse::success();
its_all_good.write(&mut stream).await?;
stream.flush().await?;
Ok(stream)
} else {
let yeah_no = ServerAuthResponse::failure();
yeah_no.write(&mut stream).await?;
stream.flush().await?;
Err(AuthenticationError::FailedUsernamePassword(
their_info.username,
))
}
}
Some(ChosenMethod::None) => {
trace!("Just skipping the whole authentication thing.");
let nothin_i_guess = ServerChoice::option(AuthenticationMethod::None);
nothin_i_guess.write(&mut stream).await?;
stream.flush().await?;
Ok(stream)
}
}
}
#[derive(Error, Debug)]
enum ServerError {
#[error("Error in deserialization: {0}")]
DeserializationError(#[from] DeserializationError),
#[error("Error in serialization: {0}")]
SerializationError(#[from] SerializationError),
}
async fn run_main_loop<N>(
network: Arc<Mutex<N>>,
mut stream: GenericStream,
) -> Result<(), ServerError>
where
N: Networklike,
N::Error: 'static,
{
loop {
let ccr = ClientConnectionRequest::read(&mut stream).await?;
match ccr.command_code {
ClientConnectionCommand::AssociateUDPPort => {}
ClientConnectionCommand::EstablishTCPPortBinding => {}
ClientConnectionCommand::EstablishTCPStream => {
let target = format!("{}:{}", ccr.destination_address, ccr.destination_port);
info!(
"Client requested connection to {}:{}",
ccr.destination_address, ccr.destination_port
);
let connection_res = {
let mut network = network.lock().await;
network
.connect(ccr.destination_address.clone(), ccr.destination_port)
.await
};
let outgoing_stream = match connection_res {
Ok(x) => x,
Err(e) => {
error!("Failed to connect to {}: {}", target, e);
let response = ServerResponse::error(e);
response.write(&mut stream).await?;
continue;
}
};
trace!(
"Connection established to {}:{}",
ccr.destination_address,
ccr.destination_port
);
let incoming_res = {
let mut network = network.lock().await;
network.listen("127.0.0.1", 0).await
};
let incoming_listener = match incoming_res {
Ok(x) => x,
Err(e) => {
error!("Failed to bind server port for new TCP stream: {}", e);
let response = ServerResponse::error(e);
response.write(&mut stream).await?;
continue;
}
};
let (bound_address, bound_port) = incoming_listener.local_addr();
trace!(
"Set up {}:{} to address request for {}:{}",
bound_address,
bound_port,
ccr.destination_address,
ccr.destination_port
);
let response = ServerResponse {
status: ServerResponseStatus::RequestGranted,
bound_address,
bound_port,
};
response.write(&mut stream).await?;
task::spawn(async move {
let (incoming_stream, from_addr, from_port) = match incoming_listener
.accept()
.await
{
Err(e) => {
error!("Miscellaneous error waiting for someone to connect for proxying: {}", e);
return;
}
Ok(s) => s,
};
trace!(
"Accepted connection from {}:{} to attach to {}:{}",
from_addr,
from_port,
ccr.destination_address,
ccr.destination_port
);
let mut from_left = incoming_stream.clone();
let mut from_right = outgoing_stream.clone();
let mut to_left = incoming_stream;
let mut to_right = outgoing_stream;
let from = format!("{}:{}", from_addr, from_port);
let to = format!("{}:{}", ccr.destination_address, ccr.destination_port);
task::spawn(async move {
info!(
"Spawned {}:{} >--> {}:{} task",
from_addr, from_port, ccr.destination_address, ccr.destination_port
);
if let Err(e) = io::copy(&mut from_left, &mut to_right).await {
warn!(
"{}:{} >--> {}:{} connection failed with: {}",
from_addr,
from_port,
ccr.destination_address,
ccr.destination_port,
e
);
}
});
task::spawn(async move {
info!("Spawned {} <--< {} task", from, to);
if let Err(e) = io::copy(&mut from_right, &mut to_left).await {
warn!("{} <--< {} connection failed with: {}", from, to, e);
}
});
});
}
}
}
}