Compare commits
17 Commits
feature/gi
...
develop
| Author | SHA1 | Date | |
|---|---|---|---|
| 65e79b4237 | |||
| 1d182a150f | |||
| 277125e1a0 | |||
| c8279cfc5f | |||
| d284f60d67 | |||
| 811580c64f | |||
| aa414fd527 | |||
| 8ac3f52546 | |||
| bac2c33aee | |||
| 3737d0739d | |||
| 74f66ef747 | |||
| 774591cb54 | |||
| c05b0f2b74 | |||
| abff1a4ec1 | |||
| ac11ae64a8 | |||
| 67b2acab25 | |||
| 58c04adeb7 |
29
Cargo.toml
29
Cargo.toml
@@ -7,11 +7,26 @@ edition = "2018"
|
||||
[lib]
|
||||
name = "async_socks5"
|
||||
|
||||
[[bin]]
|
||||
name="socks5-server"
|
||||
path="server/main.rs"
|
||||
|
||||
[dependencies]
|
||||
async-std = { version = "1.9.0", features = ["attributes"] }
|
||||
async-trait = "0.1.50"
|
||||
futures = "0.3.15"
|
||||
log = "0.4.8"
|
||||
quickcheck = "1.0.3"
|
||||
simplelog = "0.10.0"
|
||||
thiserror = "1.0.24"
|
||||
anyhow = "^1.0.57"
|
||||
clap = { version = "^3.1.18", features = ["derive"] }
|
||||
etcetera = "^0.4.0"
|
||||
futures = "0.3.21"
|
||||
if-addrs = "0.7.0"
|
||||
lazy_static = "1.4.0"
|
||||
proptest = "^1.0.0"
|
||||
serde = "^1.0.137"
|
||||
serde_derive = "^1.0.137"
|
||||
thiserror = "^1.0.31"
|
||||
tokio = { version = "^1", features = ["full"] }
|
||||
toml = "^0.5.9"
|
||||
tracing = "^0.1.34"
|
||||
tracing-subscriber = { version = "^0.3.11", features = ["env-filter"] }
|
||||
|
||||
[dev-dependencies]
|
||||
proptest = "1.0.0"
|
||||
proptest-derive = "0.3.0"
|
||||
|
||||
2
TODO
Normal file
2
TODO
Normal file
@@ -0,0 +1,2 @@
|
||||
* [ ] Turn `write` from &self to self
|
||||
* [ ] Turn `read`/`write` into a typeclass of some kind.
|
||||
15
sample.toml
Normal file
15
sample.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
# Unless otherwise sepcified, use this log level
|
||||
log_level = "TRACE"
|
||||
# Unless the command line specifies servers (or includes --validate),
|
||||
# start the following named servers.
|
||||
start_servers = "loopback*,ethernet"
|
||||
|
||||
[loopback4]
|
||||
interface="lo0"
|
||||
log_level="TRACE"
|
||||
address="127.0.0.1"
|
||||
|
||||
[loopback6]
|
||||
interface="lo0"
|
||||
log_level="DEBUG"
|
||||
address="::1"
|
||||
176
server/config.rs
Normal file
176
server/config.rs
Normal file
@@ -0,0 +1,176 @@
|
||||
mod cmdline;
|
||||
mod config_file;
|
||||
|
||||
use self::cmdline::Arguments;
|
||||
use self::config_file::ConfigFile;
|
||||
use clap::Parser;
|
||||
use if_addrs::IfAddr;
|
||||
use std::io;
|
||||
use std::net::{AddrParseError, IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4};
|
||||
use std::str::FromStr;
|
||||
use thiserror::Error;
|
||||
use tracing::metadata::LevelFilter;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ConfigError {
|
||||
#[error(transparent)]
|
||||
CommandLineError(#[from] clap::Error),
|
||||
#[error(transparent)]
|
||||
IOError(#[from] io::Error),
|
||||
#[error("TOML processing error: {0}")]
|
||||
TomlError(#[from] toml::de::Error),
|
||||
#[error("Server '{0}' specifies an interface ({1}) with no addresses")]
|
||||
NoAddressForInterface(String, String),
|
||||
#[error("Server '{0}' specifies an address we couldn't parse: {1}")]
|
||||
AddressParseError(String, AddrParseError),
|
||||
#[error("Host directory error {0}")]
|
||||
HostDirectoryError(String),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Config {
|
||||
pub log_level: LevelFilter,
|
||||
pub server_definitions: Vec<ServerDefinition>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ServerDefinition {
|
||||
pub name: String,
|
||||
pub start: bool,
|
||||
pub interface: Option<String>,
|
||||
pub address: SocketAddr,
|
||||
pub log_level: LevelFilter,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
/// Generate a configuration by reading the command line arguments and any
|
||||
/// defined config file, generating the actual arguments that we'll use for
|
||||
/// operating the daemon.
|
||||
pub fn derive() -> Result<Self, ConfigError> {
|
||||
let command_line = Arguments::try_parse()?;
|
||||
let mut config_file = ConfigFile::read(command_line.config_file)?;
|
||||
let nic_addresses = if_addrs::get_if_addrs()?;
|
||||
|
||||
let log_level = command_line
|
||||
.log_level
|
||||
.or(config_file.log_level)
|
||||
.unwrap_or(LevelFilter::ERROR);
|
||||
|
||||
let mut server_definitions = Vec::new();
|
||||
let servers_to_start: Vec<String> = config_file
|
||||
.start_servers
|
||||
.map(|x| x.split(',').map(|v| v.to_string()).collect())
|
||||
.unwrap_or_default();
|
||||
|
||||
for (name, config_info) in config_file.servers.drain() {
|
||||
let start = servers_to_start.contains(&name);
|
||||
let log_level = config_info.log_level.unwrap_or(log_level);
|
||||
let port = config_info.port.unwrap_or(1080);
|
||||
let mut interface = None;
|
||||
let mut address = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port));
|
||||
|
||||
match (config_info.interface, config_info.address) {
|
||||
// if the user provides us nothing, we'll just use a blank address and
|
||||
// no interface association
|
||||
(None, None) => {}
|
||||
|
||||
// if the user provides us an interface but no address, we'll see if we can
|
||||
// find the interface and pull a reasonable address from it.
|
||||
(Some(given_interface), None) => {
|
||||
let mut found_it = false;
|
||||
|
||||
for card_interface in nic_addresses.iter() {
|
||||
if card_interface.name == given_interface {
|
||||
interface = Some(given_interface.clone());
|
||||
address = SocketAddr::new(addr_convert(&card_interface.addr), port);
|
||||
found_it = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if !found_it {
|
||||
return Err(ConfigError::NoAddressForInterface(name, given_interface));
|
||||
}
|
||||
}
|
||||
|
||||
// if the user provides us an address but no interface, we'll quickly see if
|
||||
// we can find that address in our interface list ... but we won't insist on
|
||||
// it.
|
||||
(None, Some(address_string)) => {
|
||||
let read_address = IpAddr::from_str(&address_string)
|
||||
.map_err(|x| ConfigError::AddressParseError(name.clone(), x))?;
|
||||
|
||||
interface = None;
|
||||
address = SocketAddr::new(read_address, port);
|
||||
for card_interface in nic_addresses.iter() {
|
||||
if addrs_match(&card_interface.addr, &read_address) {
|
||||
interface = Some(card_interface.name.clone());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// if the user provides both, we'll check to make sure that they match.
|
||||
(Some(given_interface), Some(address_string)) => {
|
||||
let read_address = IpAddr::from_str(&address_string)
|
||||
.map_err(|x| ConfigError::AddressParseError(name.clone(), x))?;
|
||||
let mut inferred_interface = None;
|
||||
let mut good_to_go = false;
|
||||
|
||||
address = SocketAddr::new(read_address, port);
|
||||
for card_interface in nic_addresses.iter() {
|
||||
if addrs_match(&card_interface.addr, &read_address) {
|
||||
if card_interface.name == given_interface {
|
||||
interface = Some(given_interface.clone());
|
||||
good_to_go = true;
|
||||
break;
|
||||
} else {
|
||||
inferred_interface = Some(card_interface.name.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !good_to_go {
|
||||
if let Some(inferred_interface) = inferred_interface {
|
||||
tracing::warn!("Address {} is associated with interface {}, not {}; using it instead", read_address, inferred_interface, given_interface);
|
||||
} else {
|
||||
tracing::warn!(
|
||||
"Address {} is not associated with interface {}, or any interface.",
|
||||
read_address,
|
||||
given_interface
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
server_definitions.push(ServerDefinition {
|
||||
name,
|
||||
start,
|
||||
interface,
|
||||
address,
|
||||
log_level,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(Config {
|
||||
log_level,
|
||||
server_definitions,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn addr_convert(x: &if_addrs::IfAddr) -> IpAddr {
|
||||
match x {
|
||||
if_addrs::IfAddr::V4(x) => IpAddr::V4(x.ip),
|
||||
if_addrs::IfAddr::V6(x) => IpAddr::V6(x.ip),
|
||||
}
|
||||
}
|
||||
|
||||
fn addrs_match(x: &if_addrs::IfAddr, y: &IpAddr) -> bool {
|
||||
match (x, y) {
|
||||
(IfAddr::V4(x), IpAddr::V4(y)) => &x.ip == y,
|
||||
(IfAddr::V6(x), IpAddr::V6(y)) => &x.ip == y,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
35
server/config/cmdline.rs
Normal file
35
server/config/cmdline.rs
Normal file
@@ -0,0 +1,35 @@
|
||||
use clap::Parser;
|
||||
use std::path::PathBuf;
|
||||
use tracing::metadata::LevelFilter;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[clap(author, version, about, long_about = None)]
|
||||
pub struct Arguments {
|
||||
#[clap(
|
||||
short,
|
||||
long,
|
||||
help = "Use the given config file, rather than $XDG_CONFIG_DIR/socks5.toml"
|
||||
)]
|
||||
pub config_file: Option<PathBuf>,
|
||||
|
||||
#[clap(
|
||||
short,
|
||||
long,
|
||||
help = "Default logging to the given level. (Defaults to ERROR if not given)"
|
||||
)]
|
||||
pub log_level: Option<LevelFilter>,
|
||||
|
||||
#[clap(
|
||||
short,
|
||||
long,
|
||||
help = "Start only the named server(s) from the config file. For more than one, use comma-separated values or multiple instances of --start"
|
||||
)]
|
||||
pub start: Vec<String>,
|
||||
|
||||
#[clap(
|
||||
short,
|
||||
long = "validate",
|
||||
help = "Do not actually start any servers; just validate the config file."
|
||||
)]
|
||||
pub validate_only: bool,
|
||||
}
|
||||
64
server/config/config_file.rs
Normal file
64
server/config/config_file.rs
Normal file
@@ -0,0 +1,64 @@
|
||||
use super::ConfigError;
|
||||
use etcetera::base_strategy::{choose_base_strategy, BaseStrategy};
|
||||
use serde::{Deserialize, Deserializer};
|
||||
use std::collections::HashMap;
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
use tracing::metadata::LevelFilter;
|
||||
|
||||
#[derive(serde_derive::Deserialize, Default)]
|
||||
pub struct ConfigFile {
|
||||
#[serde(deserialize_with = "parse_log_level")]
|
||||
pub log_level: Option<LevelFilter>,
|
||||
pub start_servers: Option<String>,
|
||||
#[serde(flatten)]
|
||||
pub servers: HashMap<String, ServerDefinition>,
|
||||
}
|
||||
|
||||
#[derive(serde_derive::Deserialize)]
|
||||
pub struct ServerDefinition {
|
||||
pub interface: Option<String>,
|
||||
#[serde(deserialize_with = "parse_log_level")]
|
||||
pub log_level: Option<LevelFilter>,
|
||||
pub address: Option<String>,
|
||||
pub port: Option<u16>,
|
||||
}
|
||||
|
||||
fn parse_log_level<'de, D>(deserializer: D) -> Result<Option<LevelFilter>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let possible_string: Option<&str> = Deserialize::deserialize(deserializer)?;
|
||||
|
||||
if let Some(s) = possible_string {
|
||||
Ok(Some(s.parse().map_err(|e| {
|
||||
serde::de::Error::custom(format!("Couldn't parse log level '{}': {}", s, e))
|
||||
})?))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
impl ConfigFile {
|
||||
pub fn read(mut config_file_path: Option<PathBuf>) -> Result<ConfigFile, ConfigError> {
|
||||
if config_file_path.is_none() {
|
||||
let base_dirs = choose_base_strategy()
|
||||
.map_err(|e| ConfigError::HostDirectoryError(e.to_string()))?;
|
||||
let mut proposed_path = base_dirs.config_dir();
|
||||
proposed_path.push("socks5");
|
||||
if let Ok(attributes) = fs::metadata(proposed_path.clone()) {
|
||||
if attributes.is_file() {
|
||||
config_file_path = Some(proposed_path);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match config_file_path {
|
||||
None => Ok(ConfigFile::default()),
|
||||
Some(path) => {
|
||||
let content = fs::read(path)?;
|
||||
Ok(toml::from_slice(&content)?)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
74
server/main.rs
Normal file
74
server/main.rs
Normal file
@@ -0,0 +1,74 @@
|
||||
mod config;
|
||||
|
||||
use async_socks5::server::{SOCKSv5Server, SecurityParameters};
|
||||
use config::Config;
|
||||
use tracing::Instrument;
|
||||
use tracing_subscriber::filter::EnvFilter;
|
||||
use tracing_subscriber::fmt;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), anyhow::Error> {
|
||||
let config = Config::derive()?;
|
||||
|
||||
let fmt_layer = fmt::layer().with_target(false);
|
||||
let filter_layer = EnvFilter::builder()
|
||||
.with_default_directive(config.log_level.into())
|
||||
.from_env()?;
|
||||
|
||||
tracing_subscriber::registry()
|
||||
.with(filter_layer)
|
||||
.with(fmt_layer)
|
||||
.init();
|
||||
|
||||
tracing::trace!("Parsed configuration: {:?}", config);
|
||||
|
||||
let core_server = SOCKSv5Server::new(SecurityParameters {
|
||||
allow_unauthenticated: true,
|
||||
allow_connection: None,
|
||||
check_password: None,
|
||||
connect_tls: None,
|
||||
});
|
||||
|
||||
let mut running_servers = vec![];
|
||||
|
||||
for server_def in config.server_definitions {
|
||||
let span = tracing::trace_span!(
|
||||
"",
|
||||
server_name = %server_def.name,
|
||||
interface = ?server_def.interface,
|
||||
address = %server_def.address,
|
||||
);
|
||||
|
||||
let result = core_server
|
||||
.start(server_def.address.ip(), server_def.address.port())
|
||||
.instrument(span)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(x) => running_servers.push(x),
|
||||
Err(e) => tracing::error!(
|
||||
server = %server_def.name,
|
||||
interface = ?server_def.interface,
|
||||
address = %server_def.address,
|
||||
"Failure in launching server: {}",
|
||||
e
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
while !running_servers.is_empty() {
|
||||
let (initial_result, _idx, next_runners) =
|
||||
futures::future::select_all(running_servers).await;
|
||||
|
||||
match initial_result {
|
||||
Ok(Ok(())) => tracing::info!("server completed successfully"),
|
||||
Ok(Err(e)) => tracing::error!("error in running server: {}", e),
|
||||
Err(e) => tracing::error!("error joining server: {}", e),
|
||||
}
|
||||
|
||||
running_servers = next_runners;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
174
src/address.rs
Normal file
174
src/address.rs
Normal file
@@ -0,0 +1,174 @@
|
||||
use crate::messages::string::{SOCKSv5String, SOCKSv5StringReadError, SOCKSv5StringWriteError};
|
||||
use std::fmt;
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
|
||||
pub enum SOCKSv5Address {
|
||||
IP4(Ipv4Addr),
|
||||
IP6(Ipv6Addr),
|
||||
Hostname(String),
|
||||
}
|
||||
|
||||
impl From<IpAddr> for SOCKSv5Address {
|
||||
fn from(x: IpAddr) -> SOCKSv5Address {
|
||||
match x {
|
||||
IpAddr::V4(a) => SOCKSv5Address::IP4(a),
|
||||
IpAddr::V6(a) => SOCKSv5Address::IP6(a),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Ipv4Addr> for SOCKSv5Address {
|
||||
fn from(x: Ipv4Addr) -> SOCKSv5Address {
|
||||
SOCKSv5Address::IP4(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Ipv6Addr> for SOCKSv5Address {
|
||||
fn from(x: Ipv6Addr) -> SOCKSv5Address {
|
||||
SOCKSv5Address::IP6(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SOCKSv5String> for SOCKSv5Address {
|
||||
fn from(x: SOCKSv5String) -> SOCKSv5Address {
|
||||
SOCKSv5Address::Hostname(x.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a str> for SOCKSv5Address {
|
||||
fn from(x: &str) -> SOCKSv5Address {
|
||||
SOCKSv5Address::Hostname(x.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for SOCKSv5Address {
|
||||
fn from(x: String) -> SOCKSv5Address {
|
||||
SOCKSv5Address::Hostname(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for SOCKSv5Address {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self {
|
||||
SOCKSv5Address::IP4(a) => write!(f, "{}", a),
|
||||
SOCKSv5Address::IP6(a) => write!(f, "{}", a),
|
||||
SOCKSv5Address::Hostname(a) => write!(f, "{}", a),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
const HOSTNAME_REGEX: &str = "[a-zA-Z0-9_.]+";
|
||||
|
||||
#[cfg(test)]
|
||||
use proptest::prelude::{any, prop_oneof, Arbitrary, BoxedStrategy, Strategy};
|
||||
|
||||
#[cfg(test)]
|
||||
impl Arbitrary for SOCKSv5Address {
|
||||
type Parameters = Option<u16>;
|
||||
type Strategy = BoxedStrategy<Self>;
|
||||
|
||||
fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
|
||||
let max_len = args.unwrap_or(32) as usize;
|
||||
|
||||
prop_oneof![
|
||||
any::<Ipv4Addr>().prop_map(SOCKSv5Address::IP4),
|
||||
any::<Ipv6Addr>().prop_map(SOCKSv5Address::IP6),
|
||||
HOSTNAME_REGEX.prop_map(move |mut hostname| {
|
||||
hostname.shrink_to(max_len);
|
||||
SOCKSv5Address::Hostname(hostname)
|
||||
}),
|
||||
]
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Error, PartialEq)]
|
||||
pub enum SOCKSv5AddressReadError {
|
||||
#[error("Bad address type {0} (expected 1, 3, or 4)")]
|
||||
BadAddressType(u8),
|
||||
#[error("Read buffer error: {0}")]
|
||||
ReadError(String),
|
||||
#[error(transparent)]
|
||||
SOCKSv5StringError(#[from] SOCKSv5StringReadError),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for SOCKSv5AddressReadError {
|
||||
fn from(x: std::io::Error) -> SOCKSv5AddressReadError {
|
||||
SOCKSv5AddressReadError::ReadError(format!("{}", x))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Error, PartialEq)]
|
||||
pub enum SOCKSv5AddressWriteError {
|
||||
#[error(transparent)]
|
||||
SOCKSv5StringError(#[from] SOCKSv5StringWriteError),
|
||||
#[error("Write buffer error: {0}")]
|
||||
WriteError(String),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for SOCKSv5AddressWriteError {
|
||||
fn from(x: std::io::Error) -> SOCKSv5AddressWriteError {
|
||||
SOCKSv5AddressWriteError::WriteError(format!("{}", x))
|
||||
}
|
||||
}
|
||||
|
||||
impl SOCKSv5Address {
|
||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||
r: &mut R,
|
||||
) -> Result<Self, SOCKSv5AddressReadError> {
|
||||
match r.read_u8().await? {
|
||||
1 => {
|
||||
let mut addr_buffer = [0; 4];
|
||||
r.read_exact(&mut addr_buffer).await?;
|
||||
let ip4 = Ipv4Addr::from(addr_buffer);
|
||||
Ok(SOCKSv5Address::IP4(ip4))
|
||||
}
|
||||
|
||||
3 => {
|
||||
let string = SOCKSv5String::read(r).await?;
|
||||
Ok(SOCKSv5Address::from(string))
|
||||
}
|
||||
|
||||
4 => {
|
||||
let mut addr_buffer = [0; 16];
|
||||
r.read_exact(&mut addr_buffer).await?;
|
||||
let ip6 = Ipv6Addr::from(addr_buffer);
|
||||
Ok(SOCKSv5Address::IP6(ip6))
|
||||
}
|
||||
|
||||
x => Err(SOCKSv5AddressReadError::BadAddressType(x)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||
&self,
|
||||
w: &mut W,
|
||||
) -> Result<(), SOCKSv5AddressWriteError> {
|
||||
match self {
|
||||
SOCKSv5Address::IP4(x) => {
|
||||
w.write_u8(1).await?;
|
||||
w.write_all(&x.octets()).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
SOCKSv5Address::IP6(x) => {
|
||||
w.write_u8(4).await?;
|
||||
w.write_all(&x.octets()).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
SOCKSv5Address::Hostname(x) => {
|
||||
w.write_u8(3).await?;
|
||||
let string = SOCKSv5String::from(x.clone());
|
||||
string.write(w).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
crate::standard_roundtrip!(socks_address_roundtrips, SOCKSv5Address);
|
||||
@@ -1,30 +0,0 @@
|
||||
use async_socks5::network::Builtin;
|
||||
use async_socks5::server::{SOCKSv5Server, SecurityParameters};
|
||||
use async_std::io;
|
||||
use async_std::net::TcpListener;
|
||||
use simplelog::{ColorChoice, CombinedLogger, Config, LevelFilter, TermLogger, TerminalMode};
|
||||
|
||||
#[async_std::main]
|
||||
async fn main() -> Result<(), io::Error> {
|
||||
CombinedLogger::init(vec![TermLogger::new(
|
||||
LevelFilter::Debug,
|
||||
Config::default(),
|
||||
TerminalMode::Mixed,
|
||||
ColorChoice::Auto,
|
||||
)])
|
||||
.expect("Couldn't initialize logger");
|
||||
|
||||
let main_listener = TcpListener::bind("127.0.0.1:0").await?;
|
||||
let params = SecurityParameters {
|
||||
allow_unauthenticated: false,
|
||||
allow_connection: None,
|
||||
check_password: None,
|
||||
connect_tls: None,
|
||||
};
|
||||
|
||||
let server = SOCKSv5Server::new(Builtin::new(), params, main_listener);
|
||||
|
||||
server.run().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
242
src/client.rs
242
src/client.rs
@@ -1,43 +1,65 @@
|
||||
use crate::errors::{DeserializationError, SerializationError};
|
||||
use crate::address::SOCKSv5Address;
|
||||
use crate::messages::{
|
||||
AuthenticationMethod, ClientGreeting, ClientUsernamePassword, ServerAuthResponse, ServerChoice,
|
||||
AuthenticationMethod, ClientConnectionCommand, ClientConnectionCommandWriteError,
|
||||
ClientConnectionRequest, ClientGreeting, ClientGreetingWriteError, ClientUsernamePassword,
|
||||
ClientUsernamePasswordWriteError, ServerAuthResponse, ServerAuthResponseReadError,
|
||||
ServerChoice, ServerChoiceReadError, ServerResponse, ServerResponseReadError,
|
||||
ServerResponseStatus,
|
||||
};
|
||||
use crate::network::generic::Networklike;
|
||||
use futures::io::{AsyncRead, AsyncWrite};
|
||||
use log::{trace, warn};
|
||||
use std::future::Future;
|
||||
use thiserror::Error;
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum SOCKSv5Error {
|
||||
#[error("SOCKSv5 serialization error: {0}")]
|
||||
SerializationError(#[from] SerializationError),
|
||||
#[error("SOCKSv5 deserialization error: {0}")]
|
||||
DeserializationError(#[from] DeserializationError),
|
||||
#[error("No acceptable authentication methods available")]
|
||||
NoAuthMethodsAllowed,
|
||||
pub enum SOCKSv5ClientError {
|
||||
#[error("Underlying networking error: {0}")]
|
||||
NetworkingError(String),
|
||||
#[error("Client greeting write error: {0}")]
|
||||
ClientWriteError(#[from] ClientGreetingWriteError),
|
||||
#[error("Server choice error: {0}")]
|
||||
ServerChoiceError(#[from] ServerChoiceReadError),
|
||||
#[error("Error writing credentials: {0}")]
|
||||
CredentialWriteError(#[from] ClientUsernamePasswordWriteError),
|
||||
#[error("Server auth read error: {0}")]
|
||||
AuthResponseError(#[from] ServerAuthResponseReadError),
|
||||
#[error("Authentication failed")]
|
||||
AuthenticationFailed,
|
||||
#[error("Server chose an unsupported authentication method ({0}")]
|
||||
#[error("No authentication methods allowed")]
|
||||
NoAuthMethodsAllowed,
|
||||
#[error("Unsupported authentication method chosen ({0})")]
|
||||
UnsupportedAuthMethodChosen(AuthenticationMethod),
|
||||
#[error("Client connection command write error: {0}")]
|
||||
ClientCommandWriteError(#[from] ClientConnectionCommandWriteError),
|
||||
#[error("Server said no: {0}")]
|
||||
ServerFailure(#[from] ServerResponseStatus),
|
||||
ServerRejected(#[from] ServerResponseStatus),
|
||||
#[error("Server response read failure: {0}")]
|
||||
ServerResponseError(#[from] ServerResponseReadError),
|
||||
}
|
||||
|
||||
pub struct SOCKSv5Client<S, N>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite,
|
||||
N: Networklike,
|
||||
{
|
||||
_network: N,
|
||||
_stream: S,
|
||||
impl From<std::io::Error> for SOCKSv5ClientError {
|
||||
fn from(x: std::io::Error) -> SOCKSv5ClientError {
|
||||
SOCKSv5ClientError::NetworkingError(format!("{}", x))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct LoginInfo {
|
||||
pub username_password: Option<UsernamePassword>,
|
||||
}
|
||||
|
||||
impl Default for LoginInfo {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl LoginInfo {
|
||||
/// Generate an empty bit of login information.
|
||||
pub fn new() -> LoginInfo {
|
||||
LoginInfo {
|
||||
username_password: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Turn this information into a list of authentication methods that we can handle,
|
||||
/// to send to the server. The RFC isn't super clear if the order of these matters
|
||||
/// at all, but we'll try to keep it in our preferred order.
|
||||
@@ -57,61 +79,191 @@ pub struct UsernamePassword {
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
impl<S, N> SOCKSv5Client<S, N>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Send + Unpin,
|
||||
N: Networklike,
|
||||
{
|
||||
pub struct SOCKSv5Client {
|
||||
login_info: LoginInfo,
|
||||
address: SOCKSv5Address,
|
||||
port: u16,
|
||||
}
|
||||
|
||||
impl SOCKSv5Client {
|
||||
/// Create a new SOCKSv5 client connection over the given steam, using the given
|
||||
/// authentication information.
|
||||
pub async fn new(_network: N, mut stream: S, login: &LoginInfo) -> Result<Self, SOCKSv5Error> {
|
||||
let acceptable_methods = login.acceptable_methods();
|
||||
trace!(
|
||||
/// authentication information. As part of the process of building this object, we
|
||||
/// do a little test run to make sure that we can login effectively; this should save
|
||||
/// from *some* surprises later on. If you'd rather *not* do that, though, you can
|
||||
/// try `unchecked_new`.
|
||||
pub async fn new<A: Into<SOCKSv5Address>>(
|
||||
login: LoginInfo,
|
||||
server_addr: A,
|
||||
server_port: u16,
|
||||
) -> Result<Self, SOCKSv5ClientError> {
|
||||
let base_version = SOCKSv5Client::unchecked_new(login, server_addr, server_port);
|
||||
let _ = base_version.start_session().await?;
|
||||
Ok(base_version)
|
||||
}
|
||||
/// Create a new SOCKSv5Client within the given parameters, but don't do a quick
|
||||
/// check to see if this connection has a chance of working. This saves you a TCP
|
||||
/// connection sequence at the expense of an increased possibility of an error
|
||||
/// later on down the road.
|
||||
pub fn unchecked_new<A: Into<SOCKSv5Address>>(
|
||||
login_info: LoginInfo,
|
||||
address: A,
|
||||
port: u16,
|
||||
) -> Self {
|
||||
SOCKSv5Client {
|
||||
login_info,
|
||||
address: address.into(),
|
||||
port,
|
||||
}
|
||||
}
|
||||
|
||||
/// This runs the connection and negotiates login, as required, and then returns
|
||||
/// the stream the caller should use to do ... whatever it wants to do.
|
||||
async fn start_session(&self) -> Result<TcpStream, SOCKSv5ClientError> {
|
||||
// create the initial stream
|
||||
let mut stream = match &self.address {
|
||||
SOCKSv5Address::IP4(x) => TcpStream::connect((*x, self.port)).await?,
|
||||
SOCKSv5Address::IP6(x) => TcpStream::connect((*x, self.port)).await?,
|
||||
SOCKSv5Address::Hostname(x) => TcpStream::connect((x.as_ref(), self.port)).await?,
|
||||
};
|
||||
|
||||
// compute how we can log in
|
||||
let acceptable_methods = self.login_info.acceptable_methods();
|
||||
tracing::trace!(
|
||||
"Computed acceptable methods -- {:?} -- sending client greeting.",
|
||||
acceptable_methods
|
||||
);
|
||||
|
||||
// Negotiate with the server. Well. "Negotiate."
|
||||
let client_greeting = ClientGreeting { acceptable_methods };
|
||||
client_greeting.write(&mut stream).await?;
|
||||
trace!("Write client greeting, waiting for server's choice.");
|
||||
tracing::trace!("Write client greeting, waiting for server's choice.");
|
||||
let server_choice = ServerChoice::read(&mut stream).await?;
|
||||
trace!("Received server's choice: {}", server_choice.chosen_method);
|
||||
tracing::trace!("Received server's choice: {}", server_choice.chosen_method);
|
||||
|
||||
// Let's do it!
|
||||
match server_choice.chosen_method {
|
||||
AuthenticationMethod::None => {}
|
||||
|
||||
AuthenticationMethod::UsernameAndPassword => {
|
||||
let (username, password) = if let Some(ref linfo) = login.username_password {
|
||||
trace!("Server requested username/password, getting data from login info.");
|
||||
let (username, password) = if let Some(ref linfo) =
|
||||
self.login_info.username_password
|
||||
{
|
||||
tracing::trace!(
|
||||
"Server requested username/password, getting data from login info."
|
||||
);
|
||||
(linfo.username.clone(), linfo.password.clone())
|
||||
} else {
|
||||
warn!("Server requested username/password, but we weren't provided one. Very weird.");
|
||||
tracing::warn!("Server requested username/password, but we weren't provided one. Very weird.");
|
||||
("".to_string(), "".to_string())
|
||||
};
|
||||
|
||||
let auth_request = ClientUsernamePassword { username, password };
|
||||
|
||||
trace!("Writing password information.");
|
||||
tracing::trace!("Writing password information.");
|
||||
auth_request.write(&mut stream).await?;
|
||||
let server_response = ServerAuthResponse::read(&mut stream).await?;
|
||||
trace!("Got server response: {}", server_response.success);
|
||||
tracing::trace!("Got server response: {}", server_response.success);
|
||||
|
||||
if !server_response.success {
|
||||
return Err(SOCKSv5Error::AuthenticationFailed);
|
||||
return Err(SOCKSv5ClientError::AuthenticationFailed);
|
||||
}
|
||||
}
|
||||
|
||||
AuthenticationMethod::NoAcceptableMethods => {
|
||||
return Err(SOCKSv5Error::NoAuthMethodsAllowed)
|
||||
return Err(SOCKSv5ClientError::NoAuthMethodsAllowed)
|
||||
}
|
||||
|
||||
x => return Err(SOCKSv5Error::UnsupportedAuthMethodChosen(x)),
|
||||
x => return Err(SOCKSv5ClientError::UnsupportedAuthMethodChosen(x)),
|
||||
}
|
||||
|
||||
trace!("Returning new SOCKSv5Client object!");
|
||||
Ok(SOCKSv5Client {
|
||||
_network,
|
||||
_stream: stream,
|
||||
})
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
/// Listen for one connection on the proxy server, and then wire back a socket
|
||||
/// that can talk to whoever connects. This handshake is a little odd, because
|
||||
/// we don't necessarily know what port or address we should tell the other
|
||||
/// person to listen on. So this function takes an async function, which it
|
||||
/// will pass this information to once it has it. It's up to that function,
|
||||
/// then, to communicate this to its peer.
|
||||
pub async fn remote_listen<A, Fut: Future<Output = Result<(), SOCKSv5ClientError>>>(
|
||||
self,
|
||||
addr: A,
|
||||
port: u16,
|
||||
callback: impl FnOnce(SOCKSv5Address, u16) -> Fut,
|
||||
) -> Result<(SOCKSv5Address, u16, TcpStream), SOCKSv5ClientError>
|
||||
where
|
||||
A: Into<SOCKSv5Address>,
|
||||
{
|
||||
let mut stream = self.start_session().await?;
|
||||
let target = addr.into();
|
||||
let ccr = ClientConnectionRequest {
|
||||
command_code: ClientConnectionCommand::EstablishTCPPortBinding,
|
||||
destination_address: target.clone(),
|
||||
destination_port: port,
|
||||
};
|
||||
|
||||
ccr.write(&mut stream).await?;
|
||||
|
||||
let initial_response = ServerResponse::read(&mut stream).await?;
|
||||
if initial_response.status != ServerResponseStatus::RequestGranted {
|
||||
return Err(initial_response.status.into());
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
"Proxy port binding of {}:{} established; server listening on {}:{}",
|
||||
target,
|
||||
port,
|
||||
initial_response.bound_address,
|
||||
initial_response.bound_port
|
||||
);
|
||||
|
||||
callback(initial_response.bound_address, initial_response.bound_port).await?;
|
||||
|
||||
let secondary_response = ServerResponse::read(&mut stream).await?;
|
||||
if secondary_response.status != ServerResponseStatus::RequestGranted {
|
||||
return Err(secondary_response.status.into());
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
"Proxy bind got a connection from {}:{}",
|
||||
secondary_response.bound_address,
|
||||
secondary_response.bound_port
|
||||
);
|
||||
|
||||
Ok((
|
||||
secondary_response.bound_address,
|
||||
secondary_response.bound_port,
|
||||
stream,
|
||||
))
|
||||
}
|
||||
|
||||
pub async fn connect<A: Send + Into<SOCKSv5Address>>(
|
||||
&mut self,
|
||||
addr: A,
|
||||
port: u16,
|
||||
) -> Result<TcpStream, SOCKSv5ClientError> {
|
||||
let mut stream = self.start_session().await?;
|
||||
let target = addr.into();
|
||||
|
||||
let ccr = ClientConnectionRequest {
|
||||
command_code: ClientConnectionCommand::EstablishTCPStream,
|
||||
destination_address: target.clone(),
|
||||
destination_port: port,
|
||||
};
|
||||
ccr.write(&mut stream).await?;
|
||||
let response = ServerResponse::read(&mut stream).await?;
|
||||
|
||||
if response.status == ServerResponseStatus::RequestGranted {
|
||||
tracing::info!(
|
||||
"Proxy connection to {}:{} established; server is using {}:{}",
|
||||
target,
|
||||
port,
|
||||
response.bound_address,
|
||||
response.bound_port
|
||||
);
|
||||
Ok(stream)
|
||||
} else {
|
||||
Err(response.status.into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
188
src/errors.rs
188
src/errors.rs
@@ -1,188 +0,0 @@
|
||||
use std::io;
|
||||
use std::string::FromUtf8Error;
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::network::SOCKSv5Address;
|
||||
|
||||
/// All the errors that can pop up when trying to turn raw bytes into SOCKSv5
|
||||
/// messages.
|
||||
#[derive(Error, Debug)]
|
||||
pub enum DeserializationError {
|
||||
#[error("Invalid protocol version for packet ({1} is not {0}!)")]
|
||||
InvalidVersion(u8, u8),
|
||||
#[error("Not enough data found")]
|
||||
NotEnoughData,
|
||||
#[error("Ooops! Found an empty string where I shouldn't")]
|
||||
InvalidEmptyString,
|
||||
#[error("IO error: {0}")]
|
||||
IOError(#[from] io::Error),
|
||||
#[error("SOCKS authentication format parse error: {0}")]
|
||||
AuthenticationMethodError(#[from] AuthenticationDeserializationError),
|
||||
#[error("Error converting from UTF-8: {0}")]
|
||||
UTF8Error(#[from] FromUtf8Error),
|
||||
#[error("Invalid address type; wanted 1, 3, or 4, got {0}")]
|
||||
InvalidAddressType(u8),
|
||||
#[error("Invalid client command {0}; expected 1, 2, or 3")]
|
||||
InvalidClientCommand(u8),
|
||||
#[error("Invalid server status {0}; expected 0-8")]
|
||||
InvalidServerResponse(u8),
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn des_error_reasonable_equals() {
|
||||
let invalid_version = DeserializationError::InvalidVersion(1, 2);
|
||||
assert_eq!(invalid_version, invalid_version);
|
||||
let not_enough = DeserializationError::NotEnoughData;
|
||||
assert_eq!(not_enough, not_enough);
|
||||
let invalid_empty = DeserializationError::InvalidEmptyString;
|
||||
assert_eq!(invalid_empty, invalid_empty);
|
||||
let auth_method = DeserializationError::AuthenticationMethodError(
|
||||
AuthenticationDeserializationError::NoDataFound,
|
||||
);
|
||||
assert_eq!(auth_method, auth_method);
|
||||
let utf8 = DeserializationError::UTF8Error(String::from_utf8(vec![0, 159]).unwrap_err());
|
||||
assert_eq!(utf8, utf8);
|
||||
let invalid_address = DeserializationError::InvalidAddressType(3);
|
||||
assert_eq!(invalid_address, invalid_address);
|
||||
let invalid_client_cmd = DeserializationError::InvalidClientCommand(32);
|
||||
assert_eq!(invalid_client_cmd, invalid_client_cmd);
|
||||
let invalid_server_resp = DeserializationError::InvalidServerResponse(42);
|
||||
assert_eq!(invalid_server_resp, invalid_server_resp);
|
||||
|
||||
assert_ne!(invalid_version, invalid_address);
|
||||
assert_ne!(not_enough, invalid_empty);
|
||||
assert_ne!(auth_method, invalid_client_cmd);
|
||||
assert_ne!(utf8, invalid_server_resp);
|
||||
}
|
||||
|
||||
impl PartialEq for DeserializationError {
|
||||
fn eq(&self, other: &DeserializationError) -> bool {
|
||||
match (self, other) {
|
||||
(
|
||||
&DeserializationError::InvalidVersion(a, b),
|
||||
&DeserializationError::InvalidVersion(x, y),
|
||||
) => (a == x) && (b == y),
|
||||
(&DeserializationError::NotEnoughData, &DeserializationError::NotEnoughData) => true,
|
||||
(
|
||||
&DeserializationError::InvalidEmptyString,
|
||||
&DeserializationError::InvalidEmptyString,
|
||||
) => true,
|
||||
(
|
||||
&DeserializationError::AuthenticationMethodError(ref a),
|
||||
&DeserializationError::AuthenticationMethodError(ref b),
|
||||
) => a == b,
|
||||
(&DeserializationError::UTF8Error(ref a), &DeserializationError::UTF8Error(ref b)) => {
|
||||
a == b
|
||||
}
|
||||
(
|
||||
&DeserializationError::InvalidAddressType(a),
|
||||
&DeserializationError::InvalidAddressType(b),
|
||||
) => a == b,
|
||||
(
|
||||
&DeserializationError::InvalidClientCommand(a),
|
||||
&DeserializationError::InvalidClientCommand(b),
|
||||
) => a == b,
|
||||
(
|
||||
&DeserializationError::InvalidServerResponse(a),
|
||||
&DeserializationError::InvalidServerResponse(b),
|
||||
) => a == b,
|
||||
(_, _) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// All the errors that can occur trying to turn SOCKSv5 message structures
|
||||
/// into raw bytes. There's a few places that the message structures allow
|
||||
/// for information that can't be serialized; often, you have to be careful
|
||||
/// about how long your strings are ...
|
||||
#[derive(Error, Debug)]
|
||||
pub enum SerializationError {
|
||||
#[error("Too many authentication methods for serialization ({0} > 255)")]
|
||||
TooManyAuthMethods(usize),
|
||||
#[error("Invalid length for string: {0}")]
|
||||
InvalidStringLength(String),
|
||||
#[error("IO error: {0}")]
|
||||
IOError(#[from] io::Error),
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ser_err_reasonable_equals() {
|
||||
let too_many = SerializationError::TooManyAuthMethods(512);
|
||||
assert_eq!(too_many, too_many);
|
||||
let invalid_str = SerializationError::InvalidStringLength("Whoopsy!".to_string());
|
||||
assert_eq!(invalid_str, invalid_str);
|
||||
|
||||
assert_ne!(too_many, invalid_str);
|
||||
}
|
||||
|
||||
impl PartialEq for SerializationError {
|
||||
fn eq(&self, other: &SerializationError) -> bool {
|
||||
match (self, other) {
|
||||
(
|
||||
&SerializationError::TooManyAuthMethods(a),
|
||||
&SerializationError::TooManyAuthMethods(b),
|
||||
) => a == b,
|
||||
(
|
||||
&SerializationError::InvalidStringLength(ref a),
|
||||
&SerializationError::InvalidStringLength(ref b),
|
||||
) => a == b,
|
||||
(_, _) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum AuthenticationDeserializationError {
|
||||
#[error("No data found deserializing SOCKS authentication type")]
|
||||
NoDataFound,
|
||||
#[error("Invalid authentication type value: {0}")]
|
||||
InvalidAuthenticationByte(u8),
|
||||
#[error("IO error reading SOCKS authentication type: {0}")]
|
||||
IOError(#[from] io::Error),
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_des_err_reasonable_equals() {
|
||||
let no_data = AuthenticationDeserializationError::NoDataFound;
|
||||
assert_eq!(no_data, no_data);
|
||||
let invalid_auth = AuthenticationDeserializationError::InvalidAuthenticationByte(39);
|
||||
assert_eq!(invalid_auth, invalid_auth);
|
||||
|
||||
assert_ne!(no_data, invalid_auth);
|
||||
}
|
||||
|
||||
impl PartialEq for AuthenticationDeserializationError {
|
||||
fn eq(&self, other: &AuthenticationDeserializationError) -> bool {
|
||||
match (self, other) {
|
||||
(
|
||||
&AuthenticationDeserializationError::NoDataFound,
|
||||
&AuthenticationDeserializationError::NoDataFound,
|
||||
) => true,
|
||||
(
|
||||
&AuthenticationDeserializationError::InvalidAuthenticationByte(x),
|
||||
&AuthenticationDeserializationError::InvalidAuthenticationByte(y),
|
||||
) => x == y,
|
||||
(_, _) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The errors that can happen, as a server, when we're negotiating the start
|
||||
/// of a SOCKS session.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum AuthenticationError {
|
||||
#[error("Firewall disallowed connection from {0}:{1}")]
|
||||
FirewallRejected(SOCKSv5Address, u16),
|
||||
#[error("Could not agree on an authentication method with the client")]
|
||||
ItsNotUsItsYou,
|
||||
#[error("Failure in serializing response message: {0}")]
|
||||
SerializationError(#[from] SerializationError),
|
||||
#[error("Failed TLS handshake")]
|
||||
FailedTLSHandshake,
|
||||
#[error("IO error writing response message: {0}")]
|
||||
IOError(#[from] io::Error),
|
||||
#[error("Failure in reading client message: {0}")]
|
||||
DeserializationError(#[from] DeserializationError),
|
||||
#[error("Username/password check failed (username was {0})")]
|
||||
FailedUsernamePassword(String),
|
||||
}
|
||||
168
src/lib.rs
168
src/lib.rs
@@ -1,70 +1,54 @@
|
||||
pub mod client;
|
||||
pub mod errors;
|
||||
pub mod messages;
|
||||
pub mod network;
|
||||
mod serialize;
|
||||
pub mod server;
|
||||
|
||||
mod address;
|
||||
mod messages;
|
||||
mod security_parameters;
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use crate::address::SOCKSv5Address;
|
||||
use crate::client::{LoginInfo, SOCKSv5Client, UsernamePassword};
|
||||
use crate::network::generic::Networklike;
|
||||
use crate::network::testing::TestingStack;
|
||||
use crate::server::{SOCKSv5Server, SecurityParameters};
|
||||
use async_std::task;
|
||||
|
||||
#[test]
|
||||
fn unrestricted_login() {
|
||||
task::block_on(async {
|
||||
let mut network_stack = TestingStack::default();
|
||||
use crate::security_parameters::SecurityParameters;
|
||||
use crate::server::SOCKSv5Server;
|
||||
use std::io;
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::{TcpSocket, TcpStream};
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::task;
|
||||
|
||||
#[tokio::test]
|
||||
async fn unrestricted_login() {
|
||||
// generate the server
|
||||
let security_parameters = SecurityParameters::unrestricted();
|
||||
let default_port = network_stack.listen("localhost", 9999).await.unwrap();
|
||||
let server =
|
||||
SOCKSv5Server::new(network_stack.clone(), security_parameters, default_port);
|
||||
let server = SOCKSv5Server::new(security_parameters);
|
||||
server.start("localhost", 9999).await.unwrap();
|
||||
|
||||
let _server_task = task::spawn(async move { server.run().await });
|
||||
|
||||
let stream = network_stack.connect("localhost", 9999).await.unwrap();
|
||||
let login_info = LoginInfo {
|
||||
username_password: None,
|
||||
};
|
||||
let client = SOCKSv5Client::new(network_stack, stream, &login_info).await;
|
||||
let client = SOCKSv5Client::new(login_info, "localhost", 9999).await;
|
||||
|
||||
assert!(client.is_ok());
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn disallow_unrestricted() {
|
||||
task::block_on(async {
|
||||
let mut network_stack = TestingStack::default();
|
||||
|
||||
#[tokio::test]
|
||||
async fn disallow_unrestricted() {
|
||||
// generate the server
|
||||
let mut security_parameters = SecurityParameters::unrestricted();
|
||||
security_parameters.allow_unauthenticated = false;
|
||||
let default_port = network_stack.listen("localhost", 9999).await.unwrap();
|
||||
let server =
|
||||
SOCKSv5Server::new(network_stack.clone(), security_parameters, default_port);
|
||||
let server = SOCKSv5Server::new(security_parameters);
|
||||
server.start("localhost", 9998).await.unwrap();
|
||||
|
||||
let _server_task = task::spawn(async move { server.run().await });
|
||||
|
||||
let stream = network_stack.connect("localhost", 9999).await.unwrap();
|
||||
let login_info = LoginInfo {
|
||||
username_password: None,
|
||||
};
|
||||
let client = SOCKSv5Client::new(network_stack, stream, &login_info).await;
|
||||
let login_info = LoginInfo::default();
|
||||
let client = SOCKSv5Client::new(login_info, "localhost", 9998).await;
|
||||
|
||||
assert!(client.is_err());
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn password_checks() {
|
||||
task::block_on(async {
|
||||
let mut network_stack = TestingStack::default();
|
||||
|
||||
#[tokio::test]
|
||||
async fn password_checks() {
|
||||
// generate the server
|
||||
let security_parameters = SecurityParameters {
|
||||
allow_unauthenticated: false,
|
||||
@@ -74,33 +58,115 @@ mod test {
|
||||
username == "awick" && password == "password"
|
||||
}),
|
||||
};
|
||||
let default_port = network_stack.listen("localhost", 9999).await.unwrap();
|
||||
let server =
|
||||
SOCKSv5Server::new(network_stack.clone(), security_parameters, default_port);
|
||||
|
||||
let _server_task = task::spawn(async move { server.run().await });
|
||||
let server = SOCKSv5Server::new(security_parameters);
|
||||
server.start("localhost", 9997).await.unwrap();
|
||||
|
||||
// try the positive side
|
||||
let stream = network_stack.connect("localhost", 9999).await.unwrap();
|
||||
let login_info = LoginInfo {
|
||||
username_password: Some(UsernamePassword {
|
||||
username: "awick".to_string(),
|
||||
password: "password".to_string(),
|
||||
}),
|
||||
};
|
||||
let client = SOCKSv5Client::new(network_stack.clone(), stream, &login_info).await;
|
||||
let client = SOCKSv5Client::new(login_info, "localhost", 9997).await;
|
||||
assert!(client.is_ok());
|
||||
|
||||
// try the negative side
|
||||
let stream = network_stack.connect("localhost", 9999).await.unwrap();
|
||||
let login_info = LoginInfo {
|
||||
username_password: Some(UsernamePassword {
|
||||
username: "adamw".to_string(),
|
||||
password: "password".to_string(),
|
||||
}),
|
||||
};
|
||||
let client = SOCKSv5Client::new(network_stack, stream, &login_info).await;
|
||||
let client = SOCKSv5Client::new(login_info, "localhost", 9997).await;
|
||||
assert!(client.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn firewall_blocks() {
|
||||
// generate the server
|
||||
let mut security_parameters = SecurityParameters::unrestricted();
|
||||
security_parameters.allow_connection = Some(|_| false);
|
||||
let server = SOCKSv5Server::new(security_parameters);
|
||||
server.start("localhost", 9996).await.unwrap();
|
||||
|
||||
let login_info = LoginInfo::new();
|
||||
let client = SOCKSv5Client::new(login_info, "localhost", 9996).await;
|
||||
|
||||
assert!(client.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn establish_stream() -> io::Result<()> {
|
||||
let target_socket = TcpSocket::new_v4()?;
|
||||
target_socket.bind(SocketAddr::new(
|
||||
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
|
||||
1337,
|
||||
))?;
|
||||
let target_port = target_socket.listen(1)?;
|
||||
|
||||
// generate the server
|
||||
let security_parameters = SecurityParameters::unrestricted();
|
||||
let server = SOCKSv5Server::new(security_parameters);
|
||||
server.start("localhost", 9995).await.unwrap();
|
||||
|
||||
let login_info = LoginInfo {
|
||||
username_password: None,
|
||||
};
|
||||
|
||||
let mut client = SOCKSv5Client::new(login_info, "localhost", 9995)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
task::spawn(async move {
|
||||
let mut conn = client.connect("localhost", 1337).await.unwrap();
|
||||
conn.write_all(&[1, 3, 3, 7, 9]).await.unwrap();
|
||||
});
|
||||
|
||||
let (mut target_connection, _) = target_port.accept().await.unwrap();
|
||||
let mut read_buffer = [0; 4];
|
||||
target_connection
|
||||
.read_exact(&mut read_buffer)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(read_buffer, [1, 3, 3, 7]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn bind_test() -> io::Result<()> {
|
||||
let security_parameters = SecurityParameters::unrestricted();
|
||||
let server = SOCKSv5Server::new(security_parameters);
|
||||
server.start("localhost", 9994).await.unwrap();
|
||||
|
||||
let login_info = LoginInfo::default();
|
||||
let client = SOCKSv5Client::new(login_info, "localhost", 9994)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let (target_sender, target_receiver) = oneshot::channel();
|
||||
|
||||
task::spawn(async move {
|
||||
let (_, _, mut conn) = client
|
||||
.remote_listen("localhost", 9993, |addr, port| async move {
|
||||
target_sender.send((addr, port)).unwrap();
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
conn.write_all(&[2, 3, 5, 7]).await.unwrap();
|
||||
});
|
||||
|
||||
let (target_addr, target_port) = target_receiver.await.unwrap();
|
||||
let mut stream = match target_addr {
|
||||
SOCKSv5Address::IP4(x) => TcpStream::connect((x, target_port)).await?,
|
||||
SOCKSv5Address::IP6(x) => TcpStream::connect((x, target_port)).await?,
|
||||
SOCKSv5Address::Hostname(x) => TcpStream::connect((x, target_port)).await?,
|
||||
};
|
||||
let mut read_buffer = [0; 4];
|
||||
stream.read_exact(&mut read_buffer).await.unwrap();
|
||||
assert_eq!(read_buffer, [2, 3, 5, 7]);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,12 +5,52 @@ mod client_username_password;
|
||||
mod server_auth_response;
|
||||
mod server_choice;
|
||||
mod server_response;
|
||||
pub(crate) mod utils;
|
||||
|
||||
pub use crate::messages::authentication_method::AuthenticationMethod;
|
||||
pub use crate::messages::client_command::{ClientConnectionCommand, ClientConnectionRequest};
|
||||
pub use crate::messages::client_greeting::ClientGreeting;
|
||||
pub use crate::messages::client_username_password::ClientUsernamePassword;
|
||||
pub use crate::messages::server_auth_response::ServerAuthResponse;
|
||||
pub use crate::messages::server_choice::ServerChoice;
|
||||
pub use crate::messages::server_response::{ServerResponse, ServerResponseStatus};
|
||||
pub(crate) mod string;
|
||||
|
||||
pub use crate::messages::authentication_method::{
|
||||
AuthenticationMethod, AuthenticationMethodReadError, AuthenticationMethodWriteError,
|
||||
};
|
||||
pub use crate::messages::client_command::{
|
||||
ClientConnectionCommand, ClientConnectionCommandReadError, ClientConnectionCommandWriteError,
|
||||
ClientConnectionRequest, ClientConnectionRequestReadError,
|
||||
};
|
||||
pub use crate::messages::client_greeting::{
|
||||
ClientGreeting, ClientGreetingReadError, ClientGreetingWriteError,
|
||||
};
|
||||
pub use crate::messages::client_username_password::{
|
||||
ClientUsernamePassword, ClientUsernamePasswordReadError, ClientUsernamePasswordWriteError,
|
||||
};
|
||||
pub use crate::messages::server_auth_response::{
|
||||
ServerAuthResponse, ServerAuthResponseReadError, ServerAuthResponseWriteError,
|
||||
};
|
||||
pub use crate::messages::server_choice::{
|
||||
ServerChoice, ServerChoiceReadError, ServerChoiceWriteError,
|
||||
};
|
||||
pub use crate::messages::server_response::{
|
||||
ServerResponse, ServerResponseReadError, ServerResponseStatus, ServerResponseWriteError,
|
||||
};
|
||||
|
||||
#[doc(hidden)]
|
||||
#[macro_export]
|
||||
macro_rules! standard_roundtrip {
|
||||
($name: ident, $t: ty) => {
|
||||
proptest::proptest! {
|
||||
#[test]
|
||||
fn $name(xs: $t) {
|
||||
tokio::runtime::Runtime::new().unwrap().block_on(async {
|
||||
use std::io::Cursor;
|
||||
|
||||
let originals = xs.clone();
|
||||
let buffer = vec![];
|
||||
let mut write_cursor = Cursor::new(buffer);
|
||||
xs.write(&mut write_cursor).await.unwrap();
|
||||
let serialized_form = write_cursor.into_inner();
|
||||
let mut read_cursor = Cursor::new(serialized_form);
|
||||
let ys = <$t>::read(&mut read_cursor);
|
||||
assert_eq!(originals, ys.await.unwrap());
|
||||
})
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
use crate::errors::{AuthenticationDeserializationError, DeserializationError, SerializationError};
|
||||
use crate::standard_roundtrip;
|
||||
#[cfg(test)]
|
||||
use async_std::task;
|
||||
use proptest::prelude::{prop_oneof, Arbitrary, Just, Strategy};
|
||||
#[cfg(test)]
|
||||
use futures::io::Cursor;
|
||||
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
#[cfg(test)]
|
||||
use quickcheck::{quickcheck, Arbitrary, Gen};
|
||||
use proptest::strategy::BoxedStrategy;
|
||||
use std::fmt;
|
||||
#[cfg(test)]
|
||||
use std::io::Cursor;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
#[allow(clippy::upper_case_acronyms)]
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
@@ -25,6 +24,34 @@ pub enum AuthenticationMethod {
|
||||
NoAcceptableMethods,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Error, PartialEq)]
|
||||
pub enum AuthenticationMethodReadError {
|
||||
#[error("Invalid authentication method #{0}")]
|
||||
UnknownAuthenticationMethod(u8),
|
||||
#[error("Error in underlying buffer: {0}")]
|
||||
ReadError(String),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for AuthenticationMethodReadError {
|
||||
fn from(x: std::io::Error) -> AuthenticationMethodReadError {
|
||||
AuthenticationMethodReadError::ReadError(format!("{}", x))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Error, PartialEq)]
|
||||
pub enum AuthenticationMethodWriteError {
|
||||
#[error("Trying to write invalid authentication method #{0}")]
|
||||
InvalidAuthMethod(u8),
|
||||
#[error("Error in underlying buffer: {0}")]
|
||||
WriteError(String),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for AuthenticationMethodWriteError {
|
||||
fn from(x: std::io::Error) -> AuthenticationMethodWriteError {
|
||||
AuthenticationMethodWriteError::WriteError(format!("{}", x))
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for AuthenticationMethod {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self {
|
||||
@@ -45,18 +72,34 @@ impl fmt::Display for AuthenticationMethod {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl Arbitrary for AuthenticationMethod {
|
||||
type Parameters = ();
|
||||
type Strategy = BoxedStrategy<Self>;
|
||||
|
||||
fn arbitrary_with(_args: Self::Parameters) -> BoxedStrategy<Self> {
|
||||
prop_oneof![
|
||||
Just(AuthenticationMethod::None),
|
||||
Just(AuthenticationMethod::GSSAPI),
|
||||
Just(AuthenticationMethod::UsernameAndPassword),
|
||||
Just(AuthenticationMethod::ChallengeHandshake),
|
||||
Just(AuthenticationMethod::ChallengeResponse),
|
||||
Just(AuthenticationMethod::SSL),
|
||||
Just(AuthenticationMethod::NDS),
|
||||
Just(AuthenticationMethod::MultiAuthenticationFramework),
|
||||
Just(AuthenticationMethod::JSONPropertyBlock),
|
||||
Just(AuthenticationMethod::NoAcceptableMethods),
|
||||
(0x80u8..=0xfe).prop_map(AuthenticationMethod::PrivateMethod),
|
||||
]
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
impl AuthenticationMethod {
|
||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||
r: &mut R,
|
||||
) -> Result<AuthenticationMethod, DeserializationError> {
|
||||
let mut byte_buffer = [0u8; 1];
|
||||
let amount_read = r.read(&mut byte_buffer).await?;
|
||||
|
||||
if amount_read == 0 {
|
||||
return Err(AuthenticationDeserializationError::NoDataFound.into());
|
||||
}
|
||||
|
||||
match byte_buffer[0] {
|
||||
) -> Result<AuthenticationMethod, AuthenticationMethodReadError> {
|
||||
match r.read_u8().await? {
|
||||
0 => Ok(AuthenticationMethod::None),
|
||||
1 => Ok(AuthenticationMethod::GSSAPI),
|
||||
2 => Ok(AuthenticationMethod::UsernameAndPassword),
|
||||
@@ -68,14 +111,16 @@ impl AuthenticationMethod {
|
||||
9 => Ok(AuthenticationMethod::JSONPropertyBlock),
|
||||
x if (0x80..=0xfe).contains(&x) => Ok(AuthenticationMethod::PrivateMethod(x)),
|
||||
0xff => Ok(AuthenticationMethod::NoAcceptableMethods),
|
||||
e => Err(AuthenticationDeserializationError::InvalidAuthenticationByte(e).into()),
|
||||
e => Err(AuthenticationMethodReadError::UnknownAuthenticationMethod(
|
||||
e,
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||
&self,
|
||||
self,
|
||||
w: &mut W,
|
||||
) -> Result<(), SerializationError> {
|
||||
) -> Result<(), AuthenticationMethodWriteError> {
|
||||
let value = match self {
|
||||
AuthenticationMethod::None => 0,
|
||||
AuthenticationMethod::GSSAPI => 1,
|
||||
@@ -86,53 +131,32 @@ impl AuthenticationMethod {
|
||||
AuthenticationMethod::NDS => 7,
|
||||
AuthenticationMethod::MultiAuthenticationFramework => 8,
|
||||
AuthenticationMethod::JSONPropertyBlock => 9,
|
||||
AuthenticationMethod::PrivateMethod(pm) => *pm,
|
||||
AuthenticationMethod::PrivateMethod(pm) if (0x80..=0xfe).contains(&pm) => pm,
|
||||
AuthenticationMethod::PrivateMethod(pm) => {
|
||||
return Err(AuthenticationMethodWriteError::InvalidAuthMethod(pm))
|
||||
}
|
||||
AuthenticationMethod::NoAcceptableMethods => 0xff,
|
||||
};
|
||||
|
||||
Ok(w.write_all(&[value]).await?)
|
||||
Ok(w.write_u8(value).await?)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl Arbitrary for AuthenticationMethod {
|
||||
fn arbitrary(g: &mut Gen) -> AuthenticationMethod {
|
||||
let mut vals = vec![
|
||||
AuthenticationMethod::None,
|
||||
AuthenticationMethod::GSSAPI,
|
||||
AuthenticationMethod::UsernameAndPassword,
|
||||
AuthenticationMethod::ChallengeHandshake,
|
||||
AuthenticationMethod::ChallengeResponse,
|
||||
AuthenticationMethod::SSL,
|
||||
AuthenticationMethod::NDS,
|
||||
AuthenticationMethod::MultiAuthenticationFramework,
|
||||
AuthenticationMethod::JSONPropertyBlock,
|
||||
AuthenticationMethod::NoAcceptableMethods,
|
||||
];
|
||||
for x in 0x80..0xffu8 {
|
||||
vals.push(AuthenticationMethod::PrivateMethod(x));
|
||||
}
|
||||
g.choose(&vals).unwrap().clone()
|
||||
}
|
||||
}
|
||||
crate::standard_roundtrip!(auth_byte_roundtrips, AuthenticationMethod);
|
||||
|
||||
standard_roundtrip!(auth_byte_roundtrips, AuthenticationMethod);
|
||||
|
||||
#[test]
|
||||
fn bad_byte() {
|
||||
#[tokio::test]
|
||||
async fn bad_byte() {
|
||||
let no_len = vec![42];
|
||||
let mut cursor = Cursor::new(no_len);
|
||||
let ys = AuthenticationMethod::read(&mut cursor);
|
||||
let ys = AuthenticationMethod::read(&mut cursor).await.unwrap_err();
|
||||
assert_eq!(
|
||||
Err(DeserializationError::AuthenticationMethodError(
|
||||
AuthenticationDeserializationError::InvalidAuthenticationByte(42)
|
||||
)),
|
||||
task::block_on(ys)
|
||||
AuthenticationMethodReadError::UnknownAuthenticationMethod(42),
|
||||
ys
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display_isnt_empty() {
|
||||
#[tokio::test]
|
||||
async fn display_isnt_empty() {
|
||||
let vals = vec![
|
||||
AuthenticationMethod::None,
|
||||
AuthenticationMethod::GSSAPI,
|
||||
|
||||
@@ -1,56 +1,121 @@
|
||||
use crate::errors::{DeserializationError, SerializationError};
|
||||
use crate::network::SOCKSv5Address;
|
||||
use crate::serialize::read_amt;
|
||||
use crate::standard_roundtrip;
|
||||
use crate::address::{SOCKSv5Address, SOCKSv5AddressReadError, SOCKSv5AddressWriteError};
|
||||
#[cfg(test)]
|
||||
use async_std::io::ErrorKind;
|
||||
use proptest_derive::Arbitrary;
|
||||
#[cfg(test)]
|
||||
use async_std::task;
|
||||
#[cfg(test)]
|
||||
use futures::io::Cursor;
|
||||
use futures::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
#[cfg(test)]
|
||||
use quickcheck::{quickcheck, Arbitrary, Gen};
|
||||
#[cfg(test)]
|
||||
use std::net::Ipv4Addr;
|
||||
use std::io::Cursor;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
||||
#[cfg_attr(test, derive(Arbitrary))]
|
||||
pub enum ClientConnectionCommand {
|
||||
EstablishTCPStream,
|
||||
EstablishTCPPortBinding,
|
||||
AssociateUDPPort,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Error, PartialEq)]
|
||||
pub enum ClientConnectionCommandReadError {
|
||||
#[error("Invalid client connection command code: {0}")]
|
||||
InvalidClientConnectionCommand(u8),
|
||||
#[error("Underlying buffer read error: {0}")]
|
||||
ReadError(String),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for ClientConnectionCommandReadError {
|
||||
fn from(x: std::io::Error) -> ClientConnectionCommandReadError {
|
||||
ClientConnectionCommandReadError::ReadError(format!("{}", x))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Error, PartialEq)]
|
||||
pub enum ClientConnectionCommandWriteError {
|
||||
#[error("Underlying buffer write error: {0}")]
|
||||
WriteError(String),
|
||||
#[error(transparent)]
|
||||
SOCKSAddressWriteError(#[from] SOCKSv5AddressWriteError),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for ClientConnectionCommandWriteError {
|
||||
fn from(x: std::io::Error) -> ClientConnectionCommandWriteError {
|
||||
ClientConnectionCommandWriteError::WriteError(format!("{}", x))
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientConnectionCommand {
|
||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||
r: &mut R,
|
||||
) -> Result<ClientConnectionCommand, ClientConnectionCommandReadError> {
|
||||
match r.read_u8().await? {
|
||||
0x01 => Ok(ClientConnectionCommand::EstablishTCPStream),
|
||||
0x02 => Ok(ClientConnectionCommand::EstablishTCPPortBinding),
|
||||
0x03 => Ok(ClientConnectionCommand::AssociateUDPPort),
|
||||
x => Err(ClientConnectionCommandReadError::InvalidClientConnectionCommand(x)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||
self,
|
||||
w: &mut W,
|
||||
) -> Result<(), std::io::Error> {
|
||||
match self {
|
||||
ClientConnectionCommand::EstablishTCPStream => w.write_u8(0x01).await,
|
||||
ClientConnectionCommand::EstablishTCPPortBinding => w.write_u8(0x02).await,
|
||||
ClientConnectionCommand::AssociateUDPPort => w.write_u8(0x03).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
crate::standard_roundtrip!(client_command_roundtrips, ClientConnectionCommand);
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
#[cfg_attr(test, derive(Arbitrary))]
|
||||
pub struct ClientConnectionRequest {
|
||||
pub command_code: ClientConnectionCommand,
|
||||
pub destination_address: SOCKSv5Address,
|
||||
pub destination_port: u16,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Error, PartialEq)]
|
||||
pub enum ClientConnectionRequestReadError {
|
||||
#[error("Invalid version in client request: {0} (expected 5)")]
|
||||
InvalidVersion(u8),
|
||||
#[error("Invalid command for client request: {0}")]
|
||||
InvalidCommand(#[from] ClientConnectionCommandReadError),
|
||||
#[error("Invalid reserved byte: {0} (expected 0)")]
|
||||
InvalidReservedByte(u8),
|
||||
#[error("Underlying read error: {0}")]
|
||||
ReadError(String),
|
||||
#[error(transparent)]
|
||||
AddressReadError(#[from] SOCKSv5AddressReadError),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for ClientConnectionRequestReadError {
|
||||
fn from(x: std::io::Error) -> ClientConnectionRequestReadError {
|
||||
ClientConnectionRequestReadError::ReadError(format!("{}", x))
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientConnectionRequest {
|
||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||
r: &mut R,
|
||||
) -> Result<Self, DeserializationError> {
|
||||
let mut buffer = [0; 2];
|
||||
|
||||
read_amt(r, 2, &mut buffer).await?;
|
||||
|
||||
if buffer[0] != 5 {
|
||||
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
|
||||
) -> Result<Self, ClientConnectionRequestReadError> {
|
||||
let version = r.read_u8().await?;
|
||||
if version != 5 {
|
||||
return Err(ClientConnectionRequestReadError::InvalidVersion(version));
|
||||
}
|
||||
|
||||
let command_code = match buffer[1] {
|
||||
0x01 => ClientConnectionCommand::EstablishTCPStream,
|
||||
0x02 => ClientConnectionCommand::EstablishTCPPortBinding,
|
||||
0x03 => ClientConnectionCommand::AssociateUDPPort,
|
||||
x => return Err(DeserializationError::InvalidClientCommand(x)),
|
||||
};
|
||||
let command_code = ClientConnectionCommand::read(r).await?;
|
||||
|
||||
let reserved = r.read_u8().await?;
|
||||
if reserved != 0 {
|
||||
return Err(ClientConnectionRequestReadError::InvalidReservedByte(
|
||||
reserved,
|
||||
));
|
||||
}
|
||||
|
||||
let destination_address = SOCKSv5Address::read(r).await?;
|
||||
|
||||
read_amt(r, 2, &mut buffer).await?;
|
||||
let destination_port = ((buffer[0] as u16) << 8) + (buffer[1] as u16);
|
||||
let destination_port = r.read_u16().await?;
|
||||
|
||||
Ok(ClientConnectionRequest {
|
||||
command_code,
|
||||
@@ -60,92 +125,64 @@ impl ClientConnectionRequest {
|
||||
}
|
||||
|
||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||
&self,
|
||||
self,
|
||||
w: &mut W,
|
||||
) -> Result<(), SerializationError> {
|
||||
let command = match self.command_code {
|
||||
ClientConnectionCommand::EstablishTCPStream => 1,
|
||||
ClientConnectionCommand::EstablishTCPPortBinding => 2,
|
||||
ClientConnectionCommand::AssociateUDPPort => 3,
|
||||
};
|
||||
|
||||
w.write_all(&[5, command]).await?;
|
||||
) -> Result<(), ClientConnectionCommandWriteError> {
|
||||
w.write_u8(5).await?;
|
||||
self.command_code.write(w).await?;
|
||||
w.write_u8(0).await?;
|
||||
self.destination_address.write(w).await?;
|
||||
w.write_all(&[
|
||||
(self.destination_port >> 8) as u8,
|
||||
(self.destination_port & 0xffu16) as u8,
|
||||
])
|
||||
.await
|
||||
.map_err(SerializationError::IOError)
|
||||
w.write_u16(self.destination_port).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl Arbitrary for ClientConnectionCommand {
|
||||
fn arbitrary(g: &mut Gen) -> ClientConnectionCommand {
|
||||
let options = [
|
||||
ClientConnectionCommand::EstablishTCPStream,
|
||||
ClientConnectionCommand::EstablishTCPPortBinding,
|
||||
ClientConnectionCommand::AssociateUDPPort,
|
||||
];
|
||||
g.choose(&options).unwrap().clone()
|
||||
}
|
||||
}
|
||||
crate::standard_roundtrip!(client_request_roundtrips, ClientConnectionRequest);
|
||||
|
||||
#[cfg(test)]
|
||||
impl Arbitrary for ClientConnectionRequest {
|
||||
fn arbitrary(g: &mut Gen) -> Self {
|
||||
let command_code = ClientConnectionCommand::arbitrary(g);
|
||||
let destination_address = SOCKSv5Address::arbitrary(g);
|
||||
let destination_port = u16::arbitrary(g);
|
||||
|
||||
ClientConnectionRequest {
|
||||
command_code,
|
||||
destination_address,
|
||||
destination_port,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
standard_roundtrip!(client_request_roundtrips, ClientConnectionRequest);
|
||||
|
||||
#[test]
|
||||
fn check_short_reads() {
|
||||
#[tokio::test]
|
||||
async fn check_short_reads() {
|
||||
let empty = vec![];
|
||||
let mut cursor = Cursor::new(empty);
|
||||
let ys = ClientConnectionRequest::read(&mut cursor);
|
||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
||||
let ys = ClientConnectionRequest::read(&mut cursor).await;
|
||||
assert!(matches!(
|
||||
ys,
|
||||
Err(ClientConnectionRequestReadError::ReadError(_))
|
||||
));
|
||||
|
||||
let no_len = vec![5, 1];
|
||||
let mut cursor = Cursor::new(no_len);
|
||||
let ys = ClientConnectionRequest::read(&mut cursor);
|
||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
||||
let ys = ClientConnectionRequest::read(&mut cursor).await;
|
||||
assert!(matches!(
|
||||
ys,
|
||||
Err(ClientConnectionRequestReadError::ReadError(_))
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_bad_version() {
|
||||
#[tokio::test]
|
||||
async fn check_bad_version() {
|
||||
let bad_ver = vec![6, 1, 1];
|
||||
let mut cursor = Cursor::new(bad_ver);
|
||||
let ys = ClientConnectionRequest::read(&mut cursor);
|
||||
assert_eq!(
|
||||
Err(DeserializationError::InvalidVersion(5, 6)),
|
||||
task::block_on(ys)
|
||||
);
|
||||
let ys = ClientConnectionRequest::read(&mut cursor).await;
|
||||
assert_eq!(Err(ClientConnectionRequestReadError::InvalidVersion(6)), ys);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_bad_command() {
|
||||
#[tokio::test]
|
||||
async fn check_bad_command() {
|
||||
let bad_cmd = vec![5, 32, 1];
|
||||
let mut cursor = Cursor::new(bad_cmd);
|
||||
let ys = ClientConnectionRequest::read(&mut cursor);
|
||||
let ys = ClientConnectionRequest::read(&mut cursor).await;
|
||||
assert_eq!(
|
||||
Err(DeserializationError::InvalidClientCommand(32)),
|
||||
task::block_on(ys)
|
||||
Err(ClientConnectionRequestReadError::InvalidCommand(
|
||||
ClientConnectionCommandReadError::InvalidClientConnectionCommand(32)
|
||||
)),
|
||||
ys
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn short_write_fails_right() {
|
||||
#[tokio::test]
|
||||
async fn short_write_fails_right() {
|
||||
use std::net::Ipv4Addr;
|
||||
|
||||
let mut buffer = [0u8; 2];
|
||||
let cmd = ClientConnectionRequest {
|
||||
command_code: ClientConnectionCommand::AssociateUDPPort,
|
||||
@@ -153,10 +190,12 @@ fn short_write_fails_right() {
|
||||
destination_port: 22,
|
||||
};
|
||||
let mut cursor = Cursor::new(&mut buffer as &mut [u8]);
|
||||
let result = task::block_on(cmd.write(&mut cursor));
|
||||
let result = cmd.write(&mut cursor).await;
|
||||
match result {
|
||||
Ok(_) => assert!(false, "Mysteriously able to fit > 2 bytes in 2 bytes."),
|
||||
Err(SerializationError::IOError(x)) => assert_eq!(ErrorKind::WriteZero, x.kind()),
|
||||
Err(e) => assert!(false, "Got the wrong error writing too much data: {}", e),
|
||||
Ok(_) => panic!("Mysteriously able to fit > 2 bytes in 2 bytes."),
|
||||
Err(ClientConnectionCommandWriteError::WriteError(x)) => {
|
||||
assert!(x.contains("write zero"));
|
||||
}
|
||||
Err(e) => panic!("Got the wrong error writing too much data: {}", e),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
use crate::messages::authentication_method::{
|
||||
AuthenticationMethod, AuthenticationMethodReadError, AuthenticationMethodWriteError,
|
||||
};
|
||||
#[cfg(test)]
|
||||
use crate::errors::AuthenticationDeserializationError;
|
||||
use crate::errors::{DeserializationError, SerializationError};
|
||||
use crate::messages::AuthenticationMethod;
|
||||
use crate::standard_roundtrip;
|
||||
use proptest_derive::Arbitrary;
|
||||
#[cfg(test)]
|
||||
use async_std::task;
|
||||
#[cfg(test)]
|
||||
use futures::io::Cursor;
|
||||
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
#[cfg(test)]
|
||||
use quickcheck::{quickcheck, Arbitrary, Gen};
|
||||
use std::io::Cursor;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
/// Client greetings are the first message sent in a SOCKSv5 session. They
|
||||
/// identify that there's a client that wants to talk to a server, and that
|
||||
@@ -17,30 +14,57 @@ use quickcheck::{quickcheck, Arbitrary, Gen};
|
||||
/// said server. (It feels weird that the offer/choice goes this way instead
|
||||
/// of the reverse, but whatever.)
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
#[cfg_attr(test, derive(Arbitrary))]
|
||||
pub struct ClientGreeting {
|
||||
pub acceptable_methods: Vec<AuthenticationMethod>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Error, PartialEq)]
|
||||
pub enum ClientGreetingReadError {
|
||||
#[error("Invalid version in client request: {0} (expected 5)")]
|
||||
InvalidVersion(u8),
|
||||
#[error(transparent)]
|
||||
AuthMethodReadError(#[from] AuthenticationMethodReadError),
|
||||
#[error("Underlying read error: {0}")]
|
||||
ReadError(String),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for ClientGreetingReadError {
|
||||
fn from(x: std::io::Error) -> ClientGreetingReadError {
|
||||
ClientGreetingReadError::ReadError(format!("{}", x))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Error, PartialEq)]
|
||||
pub enum ClientGreetingWriteError {
|
||||
#[error("Too many methods provided; need <256, saw {0}")]
|
||||
TooManyMethods(usize),
|
||||
#[error(transparent)]
|
||||
AuthMethodWriteError(#[from] AuthenticationMethodWriteError),
|
||||
#[error("Underlying write error: {0}")]
|
||||
WriteError(String),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for ClientGreetingWriteError {
|
||||
fn from(x: std::io::Error) -> ClientGreetingWriteError {
|
||||
ClientGreetingWriteError::WriteError(format!("{}", x))
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientGreeting {
|
||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||
r: &mut R,
|
||||
) -> Result<ClientGreeting, DeserializationError> {
|
||||
let mut buffer = [0; 1];
|
||||
) -> Result<ClientGreeting, ClientGreetingReadError> {
|
||||
let version = r.read_u8().await?;
|
||||
|
||||
if r.read(&mut buffer).await? == 0 {
|
||||
return Err(DeserializationError::NotEnoughData);
|
||||
if version != 5 {
|
||||
return Err(ClientGreetingReadError::InvalidVersion(version));
|
||||
}
|
||||
|
||||
if buffer[0] != 5 {
|
||||
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
|
||||
}
|
||||
let num_methods = r.read_u8().await? as usize;
|
||||
|
||||
if r.read(&mut buffer).await? == 0 {
|
||||
return Err(DeserializationError::NotEnoughData);
|
||||
}
|
||||
|
||||
let mut acceptable_methods = Vec::with_capacity(buffer[0] as usize);
|
||||
for _ in 0..buffer[0] {
|
||||
let mut acceptable_methods = Vec::with_capacity(num_methods);
|
||||
for _ in 0..num_methods {
|
||||
acceptable_methods.push(AuthenticationMethod::read(r).await?);
|
||||
}
|
||||
|
||||
@@ -48,11 +72,11 @@ impl ClientGreeting {
|
||||
}
|
||||
|
||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||
&self,
|
||||
mut self,
|
||||
w: &mut W,
|
||||
) -> Result<(), SerializationError> {
|
||||
) -> Result<(), ClientGreetingWriteError> {
|
||||
if self.acceptable_methods.len() > 255 {
|
||||
return Err(SerializationError::TooManyAuthMethods(
|
||||
return Err(ClientGreetingWriteError::TooManyMethods(
|
||||
self.acceptable_methods.len(),
|
||||
));
|
||||
}
|
||||
@@ -61,65 +85,48 @@ impl ClientGreeting {
|
||||
buffer.push(5);
|
||||
buffer.push(self.acceptable_methods.len() as u8);
|
||||
w.write_all(&buffer).await?;
|
||||
for authmeth in self.acceptable_methods.iter() {
|
||||
for authmeth in self.acceptable_methods.drain(..) {
|
||||
authmeth.write(w).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl Arbitrary for ClientGreeting {
|
||||
fn arbitrary(g: &mut Gen) -> ClientGreeting {
|
||||
let amt = u8::arbitrary(g);
|
||||
let mut acceptable_methods = Vec::with_capacity(amt as usize);
|
||||
crate::standard_roundtrip!(client_greeting_roundtrips, ClientGreeting);
|
||||
|
||||
for _ in 0..amt {
|
||||
acceptable_methods.push(AuthenticationMethod::arbitrary(g));
|
||||
}
|
||||
|
||||
ClientGreeting { acceptable_methods }
|
||||
}
|
||||
}
|
||||
|
||||
standard_roundtrip!(client_greeting_roundtrips, ClientGreeting);
|
||||
|
||||
#[test]
|
||||
fn check_short_reads() {
|
||||
#[tokio::test]
|
||||
async fn check_short_reads() {
|
||||
let empty = vec![];
|
||||
let mut cursor = Cursor::new(empty);
|
||||
let ys = ClientGreeting::read(&mut cursor);
|
||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
||||
let ys = ClientGreeting::read(&mut cursor).await;
|
||||
assert!(matches!(ys, Err(ClientGreetingReadError::ReadError(_))));
|
||||
|
||||
let no_len = vec![5];
|
||||
let mut cursor = Cursor::new(no_len);
|
||||
let ys = ClientGreeting::read(&mut cursor);
|
||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
||||
let ys = ClientGreeting::read(&mut cursor).await;
|
||||
assert!(matches!(ys, Err(ClientGreetingReadError::ReadError(_))));
|
||||
|
||||
let bad_len = vec![5, 9];
|
||||
let mut cursor = Cursor::new(bad_len);
|
||||
let ys = ClientGreeting::read(&mut cursor);
|
||||
assert_eq!(
|
||||
Err(DeserializationError::AuthenticationMethodError(
|
||||
AuthenticationDeserializationError::NoDataFound
|
||||
)),
|
||||
task::block_on(ys)
|
||||
);
|
||||
let ys = ClientGreeting::read(&mut cursor).await;
|
||||
assert!(matches!(
|
||||
ys,
|
||||
Err(ClientGreetingReadError::AuthMethodReadError(
|
||||
AuthenticationMethodReadError::ReadError(_)
|
||||
))
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_bad_version() {
|
||||
#[tokio::test]
|
||||
async fn check_bad_version() {
|
||||
let no_len = vec![6, 1, 1];
|
||||
let mut cursor = Cursor::new(no_len);
|
||||
let ys = ClientGreeting::read(&mut cursor);
|
||||
assert_eq!(
|
||||
Err(DeserializationError::InvalidVersion(5, 6)),
|
||||
task::block_on(ys)
|
||||
);
|
||||
let ys = ClientGreeting::read(&mut cursor).await;
|
||||
assert_eq!(Err(ClientGreetingReadError::InvalidVersion(6)), ys);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_too_many() {
|
||||
#[tokio::test]
|
||||
async fn check_too_many() {
|
||||
let mut auth_methods = Vec::with_capacity(512);
|
||||
auth_methods.resize(512, AuthenticationMethod::ChallengeHandshake);
|
||||
let greet = ClientGreeting {
|
||||
@@ -127,7 +134,7 @@ fn check_too_many() {
|
||||
};
|
||||
let mut output = vec![0; 1024];
|
||||
assert_eq!(
|
||||
Err(SerializationError::TooManyAuthMethods(512)),
|
||||
task::block_on(greet.write(&mut output))
|
||||
Err(ClientGreetingWriteError::TooManyMethods(512)),
|
||||
greet.write(&mut output).await
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,15 +1,10 @@
|
||||
use crate::errors::{DeserializationError, SerializationError};
|
||||
use crate::messages::string::{SOCKSv5String, SOCKSv5StringReadError, SOCKSv5StringWriteError};
|
||||
#[cfg(test)]
|
||||
use crate::messages::utils::arbitrary_socks_string;
|
||||
use crate::serialize::{read_string, write_string};
|
||||
use crate::standard_roundtrip;
|
||||
use proptest::prelude::{Arbitrary, BoxedStrategy, Strategy};
|
||||
#[cfg(test)]
|
||||
use async_std::task;
|
||||
#[cfg(test)]
|
||||
use futures::io::Cursor;
|
||||
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
#[cfg(test)]
|
||||
use quickcheck::{quickcheck, Arbitrary, Gen};
|
||||
use std::io::Cursor;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub struct ClientUsernamePassword {
|
||||
@@ -17,68 +12,111 @@ pub struct ClientUsernamePassword {
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
const USERNAME_REGEX: &str = "[a-zA-Z0-9~!@#$%^&*_\\-+=:;?<>]+";
|
||||
#[cfg(test)]
|
||||
const PASSWORD_REGEX: &str = "[a-zA-Z0-9~!@#$%^&*_\\-+=:;?<>]+";
|
||||
|
||||
#[cfg(test)]
|
||||
impl Arbitrary for ClientUsernamePassword {
|
||||
type Parameters = Option<u8>;
|
||||
type Strategy = BoxedStrategy<Self>;
|
||||
|
||||
fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
|
||||
let max_len = args.unwrap_or(12) as usize;
|
||||
(USERNAME_REGEX, PASSWORD_REGEX)
|
||||
.prop_map(move |(mut username, mut password)| {
|
||||
username.shrink_to(max_len);
|
||||
password.shrink_to(max_len);
|
||||
ClientUsernamePassword { username, password }
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Error, PartialEq)]
|
||||
pub enum ClientUsernamePasswordReadError {
|
||||
#[error("Underlying buffer read error: {0}")]
|
||||
ReadError(String),
|
||||
#[error("Invalid username/password version; expected 1, saw {0}")]
|
||||
InvalidVersion(u8),
|
||||
#[error(transparent)]
|
||||
StringError(#[from] SOCKSv5StringReadError),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for ClientUsernamePasswordReadError {
|
||||
fn from(x: std::io::Error) -> ClientUsernamePasswordReadError {
|
||||
ClientUsernamePasswordReadError::ReadError(format!("{}", x))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Error, PartialEq)]
|
||||
pub enum ClientUsernamePasswordWriteError {
|
||||
#[error("Underlying buffer read error: {0}")]
|
||||
WriteError(String),
|
||||
#[error(transparent)]
|
||||
StringError(#[from] SOCKSv5StringWriteError),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for ClientUsernamePasswordWriteError {
|
||||
fn from(x: std::io::Error) -> ClientUsernamePasswordWriteError {
|
||||
ClientUsernamePasswordWriteError::WriteError(format!("{}", x))
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientUsernamePassword {
|
||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||
r: &mut R,
|
||||
) -> Result<Self, DeserializationError> {
|
||||
let mut buffer = [0; 1];
|
||||
) -> Result<Self, ClientUsernamePasswordReadError> {
|
||||
let version = r.read_u8().await?;
|
||||
|
||||
if r.read(&mut buffer).await? == 0 {
|
||||
return Err(DeserializationError::NotEnoughData);
|
||||
if version != 1 {
|
||||
return Err(ClientUsernamePasswordReadError::InvalidVersion(version));
|
||||
}
|
||||
|
||||
if buffer[0] != 1 {
|
||||
return Err(DeserializationError::InvalidVersion(1, buffer[0]));
|
||||
}
|
||||
|
||||
let username = read_string(r).await?;
|
||||
let password = read_string(r).await?;
|
||||
let username = SOCKSv5String::read(r).await?.into();
|
||||
let password = SOCKSv5String::read(r).await?.into();
|
||||
|
||||
Ok(ClientUsernamePassword { username, password })
|
||||
}
|
||||
|
||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||
&self,
|
||||
self,
|
||||
w: &mut W,
|
||||
) -> Result<(), SerializationError> {
|
||||
w.write_all(&[1]).await?;
|
||||
write_string(&self.username, w).await?;
|
||||
write_string(&self.password, w).await
|
||||
) -> Result<(), ClientUsernamePasswordWriteError> {
|
||||
w.write_u8(1).await?;
|
||||
SOCKSv5String::from(self.username.as_str()).write(w).await?;
|
||||
SOCKSv5String::from(self.password.as_str()).write(w).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl Arbitrary for ClientUsernamePassword {
|
||||
fn arbitrary(g: &mut Gen) -> Self {
|
||||
let username = arbitrary_socks_string(g);
|
||||
let password = arbitrary_socks_string(g);
|
||||
crate::standard_roundtrip!(username_password_roundtrips, ClientUsernamePassword);
|
||||
|
||||
ClientUsernamePassword { username, password }
|
||||
}
|
||||
}
|
||||
|
||||
standard_roundtrip!(username_password_roundtrips, ClientUsernamePassword);
|
||||
|
||||
#[test]
|
||||
fn check_short_reads() {
|
||||
#[tokio::test]
|
||||
async fn heck_short_reads() {
|
||||
let empty = vec![];
|
||||
let mut cursor = Cursor::new(empty);
|
||||
let ys = ClientUsernamePassword::read(&mut cursor);
|
||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
||||
let ys = ClientUsernamePassword::read(&mut cursor).await;
|
||||
assert!(matches!(
|
||||
ys,
|
||||
Err(ClientUsernamePasswordReadError::ReadError(_))
|
||||
));
|
||||
|
||||
let user_only = vec![1, 3, 102, 111, 111];
|
||||
let mut cursor = Cursor::new(user_only);
|
||||
let ys = ClientUsernamePassword::read(&mut cursor);
|
||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
||||
let ys = ClientUsernamePassword::read(&mut cursor).await;
|
||||
println!("ys: {:?}", ys);
|
||||
assert!(matches!(
|
||||
ys,
|
||||
Err(ClientUsernamePasswordReadError::StringError(_))
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_bad_version() {
|
||||
#[tokio::test]
|
||||
async fn check_bad_version() {
|
||||
let bad_len = vec![5];
|
||||
let mut cursor = Cursor::new(bad_len);
|
||||
let ys = ClientUsernamePassword::read(&mut cursor);
|
||||
assert_eq!(
|
||||
Err(DeserializationError::InvalidVersion(1, 5)),
|
||||
task::block_on(ys)
|
||||
);
|
||||
let ys = ClientUsernamePassword::read(&mut cursor).await;
|
||||
assert_eq!(Err(ClientUsernamePasswordReadError::InvalidVersion(5)), ys);
|
||||
}
|
||||
|
||||
@@ -1,18 +1,40 @@
|
||||
use crate::errors::{DeserializationError, SerializationError};
|
||||
use crate::standard_roundtrip;
|
||||
#[cfg(test)]
|
||||
use async_std::task;
|
||||
#[cfg(test)]
|
||||
use futures::io::Cursor;
|
||||
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
#[cfg(test)]
|
||||
use quickcheck::{quickcheck, Arbitrary, Gen};
|
||||
use proptest_derive::Arbitrary;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
#[cfg_attr(test, derive(Arbitrary))]
|
||||
pub struct ServerAuthResponse {
|
||||
pub success: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Error, PartialEq)]
|
||||
pub enum ServerAuthResponseReadError {
|
||||
#[error("Underlying buffer read error: {0}")]
|
||||
ReadError(String),
|
||||
#[error("Invalid username/password version; expected 1, saw {0}")]
|
||||
InvalidVersion(u8),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for ServerAuthResponseReadError {
|
||||
fn from(x: std::io::Error) -> ServerAuthResponseReadError {
|
||||
ServerAuthResponseReadError::ReadError(format!("{}", x))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Error, PartialEq)]
|
||||
pub enum ServerAuthResponseWriteError {
|
||||
#[error("Underlying buffer read error: {0}")]
|
||||
WriteError(String),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for ServerAuthResponseWriteError {
|
||||
fn from(x: std::io::Error) -> ServerAuthResponseWriteError {
|
||||
ServerAuthResponseWriteError::WriteError(format!("{}", x))
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerAuthResponse {
|
||||
pub fn success() -> ServerAuthResponse {
|
||||
ServerAuthResponse { success: true }
|
||||
@@ -24,30 +46,22 @@ impl ServerAuthResponse {
|
||||
|
||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||
r: &mut R,
|
||||
) -> Result<Self, DeserializationError> {
|
||||
let mut buffer = [0; 1];
|
||||
) -> Result<Self, ServerAuthResponseReadError> {
|
||||
let version = r.read_u8().await?;
|
||||
|
||||
if r.read(&mut buffer).await? == 0 {
|
||||
return Err(DeserializationError::NotEnoughData);
|
||||
}
|
||||
|
||||
if buffer[0] != 1 {
|
||||
return Err(DeserializationError::InvalidVersion(1, buffer[0]));
|
||||
}
|
||||
|
||||
if r.read(&mut buffer).await? == 0 {
|
||||
return Err(DeserializationError::NotEnoughData);
|
||||
if version != 1 {
|
||||
return Err(ServerAuthResponseReadError::InvalidVersion(version));
|
||||
}
|
||||
|
||||
Ok(ServerAuthResponse {
|
||||
success: buffer[0] == 0,
|
||||
success: r.read_u8().await? == 0,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||
&self,
|
||||
self,
|
||||
w: &mut W,
|
||||
) -> Result<(), SerializationError> {
|
||||
) -> Result<(), ServerAuthResponseWriteError> {
|
||||
w.write_all(&[1]).await?;
|
||||
w.write_all(&[if self.success { 0x00 } else { 0xde }])
|
||||
.await?;
|
||||
@@ -55,36 +69,29 @@ impl ServerAuthResponse {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl Arbitrary for ServerAuthResponse {
|
||||
fn arbitrary(g: &mut Gen) -> ServerAuthResponse {
|
||||
let success = bool::arbitrary(g);
|
||||
ServerAuthResponse { success }
|
||||
}
|
||||
}
|
||||
crate::standard_roundtrip!(server_auth_response, ServerAuthResponse);
|
||||
|
||||
standard_roundtrip!(server_auth_response, ServerAuthResponse);
|
||||
#[tokio::test]
|
||||
async fn check_short_reads() {
|
||||
use std::io::Cursor;
|
||||
|
||||
#[test]
|
||||
fn check_short_reads() {
|
||||
let empty = vec![];
|
||||
let mut cursor = Cursor::new(empty);
|
||||
let ys = ServerAuthResponse::read(&mut cursor);
|
||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
||||
let ys = ServerAuthResponse::read(&mut cursor).await;
|
||||
assert!(matches!(ys, Err(ServerAuthResponseReadError::ReadError(_))));
|
||||
|
||||
let no_len = vec![1];
|
||||
let mut cursor = Cursor::new(no_len);
|
||||
let ys = ServerAuthResponse::read(&mut cursor);
|
||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
||||
let ys = ServerAuthResponse::read(&mut cursor).await;
|
||||
assert!(matches!(ys, Err(ServerAuthResponseReadError::ReadError(_))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_bad_version() {
|
||||
#[tokio::test]
|
||||
async fn check_bad_version() {
|
||||
use std::io::Cursor;
|
||||
|
||||
let no_len = vec![6, 1];
|
||||
let mut cursor = Cursor::new(no_len);
|
||||
let ys = ServerAuthResponse::read(&mut cursor);
|
||||
assert_eq!(
|
||||
Err(DeserializationError::InvalidVersion(1, 6)),
|
||||
task::block_on(ys)
|
||||
);
|
||||
let ys = ServerAuthResponse::read(&mut cursor).await;
|
||||
assert_eq!(Err(ServerAuthResponseReadError::InvalidVersion(6)), ys);
|
||||
}
|
||||
|
||||
@@ -1,21 +1,49 @@
|
||||
use crate::messages::authentication_method::{
|
||||
AuthenticationMethod, AuthenticationMethodReadError, AuthenticationMethodWriteError,
|
||||
};
|
||||
#[cfg(test)]
|
||||
use crate::errors::AuthenticationDeserializationError;
|
||||
use crate::errors::{DeserializationError, SerializationError};
|
||||
use crate::messages::AuthenticationMethod;
|
||||
use crate::standard_roundtrip;
|
||||
use proptest_derive::Arbitrary;
|
||||
#[cfg(test)]
|
||||
use async_std::task;
|
||||
#[cfg(test)]
|
||||
use futures::io::Cursor;
|
||||
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
#[cfg(test)]
|
||||
use quickcheck::{quickcheck, Arbitrary, Gen};
|
||||
use std::io::Cursor;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
#[cfg_attr(test, derive(Arbitrary))]
|
||||
pub struct ServerChoice {
|
||||
pub chosen_method: AuthenticationMethod,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Error, PartialEq)]
|
||||
pub enum ServerChoiceReadError {
|
||||
#[error(transparent)]
|
||||
AuthMethodError(#[from] AuthenticationMethodReadError),
|
||||
#[error("Error in underlying buffer: {0}")]
|
||||
ReadError(String),
|
||||
#[error("Invalid version; expected 5, got {0}")]
|
||||
InvalidVersion(u8),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for ServerChoiceReadError {
|
||||
fn from(x: std::io::Error) -> ServerChoiceReadError {
|
||||
ServerChoiceReadError::ReadError(format!("{}", x))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Error, PartialEq)]
|
||||
pub enum ServerChoiceWriteError {
|
||||
#[error(transparent)]
|
||||
AuthMethodError(#[from] AuthenticationMethodWriteError),
|
||||
#[error("Error in underlying buffer: {0}")]
|
||||
WriteError(String),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for ServerChoiceWriteError {
|
||||
fn from(x: std::io::Error) -> ServerChoiceWriteError {
|
||||
ServerChoiceWriteError::WriteError(format!("{}", x))
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerChoice {
|
||||
pub fn rejection() -> ServerChoice {
|
||||
ServerChoice {
|
||||
@@ -31,15 +59,11 @@ impl ServerChoice {
|
||||
|
||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||
r: &mut R,
|
||||
) -> Result<Self, DeserializationError> {
|
||||
let mut buffer = [0; 1];
|
||||
) -> Result<Self, ServerChoiceReadError> {
|
||||
let version = r.read_u8().await?;
|
||||
|
||||
if r.read(&mut buffer).await? == 0 {
|
||||
return Err(DeserializationError::NotEnoughData);
|
||||
}
|
||||
|
||||
if buffer[0] != 5 {
|
||||
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
|
||||
if version != 5 {
|
||||
return Err(ServerChoiceReadError::InvalidVersion(version));
|
||||
}
|
||||
|
||||
let chosen_method = AuthenticationMethod::read(r).await?;
|
||||
@@ -48,50 +72,34 @@ impl ServerChoice {
|
||||
}
|
||||
|
||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||
&self,
|
||||
self,
|
||||
w: &mut W,
|
||||
) -> Result<(), SerializationError> {
|
||||
w.write_all(&[5]).await?;
|
||||
self.chosen_method.write(w).await
|
||||
) -> Result<(), ServerChoiceWriteError> {
|
||||
w.write_u8(5).await?;
|
||||
self.chosen_method.write(w).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl Arbitrary for ServerChoice {
|
||||
fn arbitrary(g: &mut Gen) -> ServerChoice {
|
||||
ServerChoice {
|
||||
chosen_method: AuthenticationMethod::arbitrary(g),
|
||||
}
|
||||
}
|
||||
}
|
||||
crate::standard_roundtrip!(server_choice_roundtrips, ServerChoice);
|
||||
|
||||
standard_roundtrip!(server_choice_roundtrips, ServerChoice);
|
||||
|
||||
#[test]
|
||||
fn check_short_reads() {
|
||||
#[tokio::test]
|
||||
async fn check_short_reads() {
|
||||
let empty = vec![];
|
||||
let mut cursor = Cursor::new(empty);
|
||||
let ys = ServerChoice::read(&mut cursor);
|
||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
||||
let ys = ServerChoice::read(&mut cursor).await;
|
||||
assert!(matches!(ys, Err(ServerChoiceReadError::ReadError(_))));
|
||||
|
||||
let bad_len = vec![5];
|
||||
let mut cursor = Cursor::new(bad_len);
|
||||
let ys = ServerChoice::read(&mut cursor);
|
||||
assert_eq!(
|
||||
Err(DeserializationError::AuthenticationMethodError(
|
||||
AuthenticationDeserializationError::NoDataFound
|
||||
)),
|
||||
task::block_on(ys)
|
||||
);
|
||||
let ys = ServerChoice::read(&mut cursor).await;
|
||||
assert!(matches!(ys, Err(ServerChoiceReadError::AuthMethodError(_))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_bad_version() {
|
||||
#[tokio::test]
|
||||
async fn check_bad_version() {
|
||||
let no_len = vec![9, 1];
|
||||
let mut cursor = Cursor::new(no_len);
|
||||
let ys = ServerChoice::read(&mut cursor);
|
||||
assert_eq!(
|
||||
Err(DeserializationError::InvalidVersion(5, 9)),
|
||||
task::block_on(ys)
|
||||
);
|
||||
let ys = ServerChoice::read(&mut cursor).await;
|
||||
assert_eq!(Err(ServerChoiceReadError::InvalidVersion(9)), ys);
|
||||
}
|
||||
|
||||
@@ -1,21 +1,13 @@
|
||||
use crate::errors::{DeserializationError, SerializationError};
|
||||
use crate::network::SOCKSv5Address;
|
||||
use crate::serialize::read_amt;
|
||||
use crate::standard_roundtrip;
|
||||
use crate::address::{SOCKSv5Address, SOCKSv5AddressReadError, SOCKSv5AddressWriteError};
|
||||
#[cfg(test)]
|
||||
use async_std::io::ErrorKind;
|
||||
use proptest_derive::Arbitrary;
|
||||
#[cfg(test)]
|
||||
use async_std::task;
|
||||
#[cfg(test)]
|
||||
use futures::io::Cursor;
|
||||
use futures::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use log::warn;
|
||||
#[cfg(test)]
|
||||
use quickcheck::{quickcheck, Arbitrary, Gen};
|
||||
use std::net::Ipv4Addr;
|
||||
use std::io::Cursor;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
#[derive(Clone, Debug, Eq, Error, PartialEq)]
|
||||
#[cfg_attr(test, derive(Arbitrary))]
|
||||
pub enum ServerResponseStatus {
|
||||
#[error("Actually, everything's fine (weird to see this in an error)")]
|
||||
RequestGranted,
|
||||
@@ -38,39 +30,64 @@ pub enum ServerResponseStatus {
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
#[cfg_attr(test, derive(Arbitrary))]
|
||||
pub struct ServerResponse {
|
||||
pub status: ServerResponseStatus,
|
||||
pub bound_address: SOCKSv5Address,
|
||||
pub bound_port: u16,
|
||||
}
|
||||
|
||||
impl ServerResponse {
|
||||
pub fn error<E: Into<ServerResponseStatus>>(resp: E) -> ServerResponse {
|
||||
ServerResponse {
|
||||
status: resp.into(),
|
||||
bound_address: SOCKSv5Address::IP4(Ipv4Addr::new(0, 0, 0, 0)),
|
||||
bound_port: 0,
|
||||
#[derive(Clone, Debug, Error, PartialEq)]
|
||||
pub enum ServerResponseReadError {
|
||||
#[error("Error reading from underlying buffer: {0}")]
|
||||
ReadError(String),
|
||||
#[error(transparent)]
|
||||
AddressReadError(#[from] SOCKSv5AddressReadError),
|
||||
#[error("Invalid version; expected 5, got {0}")]
|
||||
InvalidVersion(u8),
|
||||
#[error("Invalid reserved byte; saw {0}, should be 0")]
|
||||
InvalidReservedByte(u8),
|
||||
#[error("Invalid (or just unknown) server response value {0}")]
|
||||
InvalidServerResponse(u8),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for ServerResponseReadError {
|
||||
fn from(x: std::io::Error) -> ServerResponseReadError {
|
||||
ServerResponseReadError::ReadError(format!("{}", x))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Error, PartialEq)]
|
||||
pub enum ServerResponseWriteError {
|
||||
#[error("Error reading from underlying buffer: {0}")]
|
||||
WriteError(String),
|
||||
#[error(transparent)]
|
||||
AddressWriteError(#[from] SOCKSv5AddressWriteError),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for ServerResponseWriteError {
|
||||
fn from(x: std::io::Error) -> ServerResponseWriteError {
|
||||
ServerResponseWriteError::WriteError(format!("{}", x))
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerResponse {
|
||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||
r: &mut R,
|
||||
) -> Result<Self, DeserializationError> {
|
||||
let mut buffer = [0; 3];
|
||||
|
||||
read_amt(r, 3, &mut buffer).await?;
|
||||
|
||||
if buffer[0] != 5 {
|
||||
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
|
||||
) -> Result<Self, ServerResponseReadError> {
|
||||
let version = r.read_u8().await?;
|
||||
if version != 5 {
|
||||
return Err(ServerResponseReadError::InvalidVersion(version));
|
||||
}
|
||||
|
||||
if buffer[2] != 0 {
|
||||
warn!(target: "async-socks5", "Hey, this isn't terrible, but the server is sending invalid reserved bytes.");
|
||||
let status_byte = r.read_u8().await?;
|
||||
|
||||
let reserved_byte = r.read_u8().await?;
|
||||
if reserved_byte != 0 {
|
||||
return Err(ServerResponseReadError::InvalidReservedByte(reserved_byte));
|
||||
}
|
||||
|
||||
let status = match buffer[1] {
|
||||
let status = match status_byte {
|
||||
0x00 => ServerResponseStatus::RequestGranted,
|
||||
0x01 => ServerResponseStatus::GeneralFailure,
|
||||
0x02 => ServerResponseStatus::ConnectionNotAllowedByRule,
|
||||
@@ -80,12 +97,11 @@ impl ServerResponse {
|
||||
0x06 => ServerResponseStatus::TTLExpired,
|
||||
0x07 => ServerResponseStatus::CommandNotSupported,
|
||||
0x08 => ServerResponseStatus::AddressTypeNotSupported,
|
||||
x => return Err(DeserializationError::InvalidServerResponse(x)),
|
||||
x => return Err(ServerResponseReadError::InvalidServerResponse(x)),
|
||||
};
|
||||
|
||||
let bound_address = SOCKSv5Address::read(r).await?;
|
||||
read_amt(r, 2, &mut buffer).await?;
|
||||
let bound_port = ((buffer[0] as u16) << 8) + (buffer[1] as u16);
|
||||
let bound_port = r.read_u16().await?;
|
||||
|
||||
Ok(ServerResponse {
|
||||
status,
|
||||
@@ -95,9 +111,11 @@ impl ServerResponse {
|
||||
}
|
||||
|
||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||
&self,
|
||||
self,
|
||||
w: &mut W,
|
||||
) -> Result<(), SerializationError> {
|
||||
) -> Result<(), ServerResponseWriteError> {
|
||||
w.write_u8(5).await?;
|
||||
|
||||
let status_code = match self.status {
|
||||
ServerResponseStatus::RequestGranted => 0x00,
|
||||
ServerResponseStatus::GeneralFailure => 0x01,
|
||||
@@ -109,92 +127,61 @@ impl ServerResponse {
|
||||
ServerResponseStatus::CommandNotSupported => 0x07,
|
||||
ServerResponseStatus::AddressTypeNotSupported => 0x08,
|
||||
};
|
||||
|
||||
w.write_all(&[5, status_code, 0]).await?;
|
||||
w.write_u8(status_code).await?;
|
||||
w.write_u8(0).await?;
|
||||
self.bound_address.write(w).await?;
|
||||
w.write_all(&[
|
||||
(self.bound_port >> 8) as u8,
|
||||
(self.bound_port & 0xffu16) as u8,
|
||||
])
|
||||
.await
|
||||
.map_err(SerializationError::IOError)
|
||||
w.write_u16(self.bound_port).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl Arbitrary for ServerResponseStatus {
|
||||
fn arbitrary(g: &mut Gen) -> ServerResponseStatus {
|
||||
let options = [
|
||||
ServerResponseStatus::RequestGranted,
|
||||
ServerResponseStatus::GeneralFailure,
|
||||
ServerResponseStatus::ConnectionNotAllowedByRule,
|
||||
ServerResponseStatus::NetworkUnreachable,
|
||||
ServerResponseStatus::HostUnreachable,
|
||||
ServerResponseStatus::ConnectionRefused,
|
||||
ServerResponseStatus::TTLExpired,
|
||||
ServerResponseStatus::CommandNotSupported,
|
||||
ServerResponseStatus::AddressTypeNotSupported,
|
||||
];
|
||||
g.choose(&options).unwrap().clone()
|
||||
}
|
||||
}
|
||||
crate::standard_roundtrip!(server_response_roundtrips, ServerResponse);
|
||||
|
||||
#[cfg(test)]
|
||||
impl Arbitrary for ServerResponse {
|
||||
fn arbitrary(g: &mut Gen) -> Self {
|
||||
let status = ServerResponseStatus::arbitrary(g);
|
||||
let bound_address = SOCKSv5Address::arbitrary(g);
|
||||
let bound_port = u16::arbitrary(g);
|
||||
|
||||
ServerResponse {
|
||||
status,
|
||||
bound_address,
|
||||
bound_port,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
standard_roundtrip!(server_response_roundtrips, ServerResponse);
|
||||
|
||||
#[test]
|
||||
fn check_short_reads() {
|
||||
#[tokio::test]
|
||||
async fn check_short_reads() {
|
||||
let empty = vec![];
|
||||
let mut cursor = Cursor::new(empty);
|
||||
let ys = ServerResponse::read(&mut cursor);
|
||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
||||
let ys = ServerResponse::read(&mut cursor).await;
|
||||
assert!(matches!(ys, Err(ServerResponseReadError::ReadError(_))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_bad_version() {
|
||||
#[tokio::test]
|
||||
async fn check_bad_version() {
|
||||
let bad_ver = vec![6, 1, 1];
|
||||
let mut cursor = Cursor::new(bad_ver);
|
||||
let ys = ServerResponse::read(&mut cursor);
|
||||
assert_eq!(
|
||||
Err(DeserializationError::InvalidVersion(5, 6)),
|
||||
task::block_on(ys)
|
||||
);
|
||||
let ys = ServerResponse::read(&mut cursor).await;
|
||||
assert_eq!(Err(ServerResponseReadError::InvalidVersion(6)), ys);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_bad_command() {
|
||||
#[tokio::test]
|
||||
async fn check_bad_reserved() {
|
||||
let bad_cmd = vec![5, 32, 0x42];
|
||||
let mut cursor = Cursor::new(bad_cmd);
|
||||
let ys = ServerResponse::read(&mut cursor);
|
||||
assert_eq!(
|
||||
Err(DeserializationError::InvalidServerResponse(32)),
|
||||
task::block_on(ys)
|
||||
);
|
||||
let ys = ServerResponse::read(&mut cursor).await;
|
||||
assert_eq!(Err(ServerResponseReadError::InvalidReservedByte(0x42)), ys);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn short_write_fails_right() {
|
||||
let mut buffer = [0u8; 2];
|
||||
let cmd = ServerResponse::error(ServerResponseStatus::AddressTypeNotSupported);
|
||||
let mut cursor = Cursor::new(&mut buffer as &mut [u8]);
|
||||
let result = task::block_on(cmd.write(&mut cursor));
|
||||
match result {
|
||||
Ok(_) => assert!(false, "Mysteriously able to fit > 2 bytes in 2 bytes."),
|
||||
Err(SerializationError::IOError(x)) => assert_eq!(ErrorKind::WriteZero, x.kind()),
|
||||
Err(e) => assert!(false, "Got the wrong error writing too much data: {}", e),
|
||||
}
|
||||
#[tokio::test]
|
||||
async fn check_bad_command() {
|
||||
let bad_cmd = vec![5, 32, 0];
|
||||
let mut cursor = Cursor::new(bad_cmd);
|
||||
let ys = ServerResponse::read(&mut cursor).await;
|
||||
assert_eq!(Err(ServerResponseReadError::InvalidServerResponse(32)), ys);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn short_write_fails_right() {
|
||||
let mut buffer = [0u8; 2];
|
||||
let cmd = ServerResponse {
|
||||
status: ServerResponseStatus::AddressTypeNotSupported,
|
||||
bound_address: SOCKSv5Address::Hostname("tester.com".to_string()),
|
||||
bound_port: 99,
|
||||
};
|
||||
let mut cursor = Cursor::new(&mut buffer as &mut [u8]);
|
||||
let result = cmd.write(&mut cursor).await;
|
||||
assert!(matches!(
|
||||
result,
|
||||
Err(ServerResponseWriteError::WriteError(_))
|
||||
));
|
||||
}
|
||||
|
||||
117
src/messages/string.rs
Normal file
117
src/messages/string.rs
Normal file
@@ -0,0 +1,117 @@
|
||||
#[cfg(test)]
|
||||
use proptest::prelude::{Arbitrary, BoxedStrategy, Strategy};
|
||||
use std::convert::TryFrom;
|
||||
use std::string::FromUtf8Error;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub struct SOCKSv5String(String);
|
||||
|
||||
#[cfg(test)]
|
||||
const STRING_REGEX: &str = "[a-zA-Z0-9_.|!@#$%^]+";
|
||||
|
||||
#[cfg(test)]
|
||||
impl Arbitrary for SOCKSv5String {
|
||||
type Parameters = Option<u16>;
|
||||
type Strategy = BoxedStrategy<Self>;
|
||||
|
||||
fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
|
||||
let max_len = args.unwrap_or(32) as usize;
|
||||
|
||||
STRING_REGEX
|
||||
.prop_map(move |mut str| {
|
||||
str.shrink_to(max_len);
|
||||
SOCKSv5String(str)
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Error, PartialEq)]
|
||||
pub enum SOCKSv5StringReadError {
|
||||
#[error("Underlying buffer read error: {0}")]
|
||||
ReadError(String),
|
||||
#[error("SOCKSv5 string encoding error; encountered empty string (?)")]
|
||||
ZeroStringLength,
|
||||
#[error("Invalid UTF-8 string: {0}")]
|
||||
InvalidUtf8Error(#[from] FromUtf8Error),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for SOCKSv5StringReadError {
|
||||
fn from(x: std::io::Error) -> SOCKSv5StringReadError {
|
||||
SOCKSv5StringReadError::ReadError(format!("{}", x))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Error, PartialEq)]
|
||||
pub enum SOCKSv5StringWriteError {
|
||||
#[error("Underlying buffer write error: {0}")]
|
||||
WriteError(String),
|
||||
#[error("String too large to encode according to SOCKSv5 reuls ({0} bytes long)")]
|
||||
TooBig(usize),
|
||||
#[error("Cannot serialize the empty string in SOCKSv5")]
|
||||
ZeroStringLength,
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for SOCKSv5StringWriteError {
|
||||
fn from(x: std::io::Error) -> SOCKSv5StringWriteError {
|
||||
SOCKSv5StringWriteError::WriteError(format!("{}", x))
|
||||
}
|
||||
}
|
||||
|
||||
impl SOCKSv5String {
|
||||
pub async fn read<R: AsyncRead + Unpin>(r: &mut R) -> Result<Self, SOCKSv5StringReadError> {
|
||||
let length = r.read_u8().await? as usize;
|
||||
|
||||
if length == 0 {
|
||||
return Err(SOCKSv5StringReadError::ZeroStringLength);
|
||||
}
|
||||
|
||||
let mut bytestring = vec![0; length];
|
||||
r.read_exact(&mut bytestring).await?;
|
||||
|
||||
Ok(SOCKSv5String(String::from_utf8(bytestring)?))
|
||||
}
|
||||
|
||||
pub async fn write<W: AsyncWrite + Unpin>(
|
||||
self,
|
||||
w: &mut W,
|
||||
) -> Result<(), SOCKSv5StringWriteError> {
|
||||
let bytestring = self.0.as_bytes();
|
||||
|
||||
if bytestring.is_empty() {
|
||||
return Err(SOCKSv5StringWriteError::ZeroStringLength);
|
||||
}
|
||||
|
||||
let length = match u8::try_from(bytestring.len()) {
|
||||
Err(_) => return Err(SOCKSv5StringWriteError::TooBig(bytestring.len())),
|
||||
Ok(x) => x,
|
||||
};
|
||||
|
||||
w.write_u8(length).await?;
|
||||
w.write_all(bytestring).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for SOCKSv5String {
|
||||
fn from(x: String) -> Self {
|
||||
SOCKSv5String(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a str> for SOCKSv5String {
|
||||
fn from(x: &str) -> Self {
|
||||
SOCKSv5String(x.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SOCKSv5String> for String {
|
||||
fn from(x: SOCKSv5String) -> Self {
|
||||
x.0
|
||||
}
|
||||
}
|
||||
|
||||
crate::standard_roundtrip!(socks_string_roundtrips, SOCKSv5String);
|
||||
@@ -1,33 +0,0 @@
|
||||
#[cfg(test)]
|
||||
use quickcheck::{Arbitrary, Gen};
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn arbitrary_socks_string(g: &mut Gen) -> String {
|
||||
loop {
|
||||
let mut potential = String::arbitrary(g);
|
||||
|
||||
potential.truncate(255);
|
||||
let bytestring = potential.as_bytes();
|
||||
|
||||
if bytestring.len() > 0 && bytestring.len() < 256 {
|
||||
return potential;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
#[macro_export]
|
||||
macro_rules! standard_roundtrip {
|
||||
($name: ident, $t: ty) => {
|
||||
#[cfg(test)]
|
||||
quickcheck! {
|
||||
fn $name(xs: $t) -> bool {
|
||||
let mut buffer = vec![];
|
||||
task::block_on(xs.write(&mut buffer)).unwrap();
|
||||
let mut cursor = Cursor::new(buffer);
|
||||
let ys = <$t>::read(&mut cursor);
|
||||
xs == task::block_on(ys).unwrap()
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
@@ -1,10 +0,0 @@
|
||||
pub mod address;
|
||||
pub mod datagram;
|
||||
pub mod generic;
|
||||
pub mod listener;
|
||||
pub mod standard;
|
||||
pub mod stream;
|
||||
pub mod testing;
|
||||
|
||||
pub use crate::network::address::SOCKSv5Address;
|
||||
pub use crate::network::standard::Builtin;
|
||||
@@ -1,258 +0,0 @@
|
||||
use crate::errors::{DeserializationError, SerializationError};
|
||||
#[cfg(test)]
|
||||
use crate::messages::utils::arbitrary_socks_string;
|
||||
use crate::serialize::{read_amt, read_string, write_string};
|
||||
use crate::standard_roundtrip;
|
||||
#[cfg(test)]
|
||||
use async_std::task;
|
||||
#[cfg(test)]
|
||||
use futures::io::Cursor;
|
||||
use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
#[cfg(test)]
|
||||
use quickcheck::{quickcheck, Arbitrary, Gen};
|
||||
use std::convert::TryFrom;
|
||||
use std::fmt;
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
|
||||
pub enum SOCKSv5Address {
|
||||
IP4(Ipv4Addr),
|
||||
IP6(Ipv6Addr),
|
||||
Name(String),
|
||||
}
|
||||
|
||||
#[derive(Error, Debug, PartialEq)]
|
||||
pub enum AddressConversionError {
|
||||
#[error("Couldn't convert IPv4 address into destination type")]
|
||||
CouldntConvertIP4,
|
||||
#[error("Couldn't convert IPv6 address into destination type")]
|
||||
CouldntConvertIP6,
|
||||
#[error("Couldn't convert name into destination type")]
|
||||
CouldntConvertName,
|
||||
}
|
||||
|
||||
impl From<IpAddr> for SOCKSv5Address {
|
||||
fn from(x: IpAddr) -> SOCKSv5Address {
|
||||
match x {
|
||||
IpAddr::V4(a) => SOCKSv5Address::IP4(a),
|
||||
IpAddr::V6(a) => SOCKSv5Address::IP6(a),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<SOCKSv5Address> for IpAddr {
|
||||
type Error = AddressConversionError;
|
||||
|
||||
fn try_from(value: SOCKSv5Address) -> Result<Self, Self::Error> {
|
||||
match value {
|
||||
SOCKSv5Address::IP4(a) => Ok(IpAddr::V4(a)),
|
||||
SOCKSv5Address::IP6(a) => Ok(IpAddr::V6(a)),
|
||||
SOCKSv5Address::Name(_) => Err(AddressConversionError::CouldntConvertName),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Ipv4Addr> for SOCKSv5Address {
|
||||
fn from(x: Ipv4Addr) -> Self {
|
||||
SOCKSv5Address::IP4(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<SOCKSv5Address> for Ipv4Addr {
|
||||
type Error = AddressConversionError;
|
||||
|
||||
fn try_from(value: SOCKSv5Address) -> Result<Self, Self::Error> {
|
||||
match value {
|
||||
SOCKSv5Address::IP4(a) => Ok(a),
|
||||
SOCKSv5Address::IP6(_) => Err(AddressConversionError::CouldntConvertIP6),
|
||||
SOCKSv5Address::Name(_) => Err(AddressConversionError::CouldntConvertName),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Ipv6Addr> for SOCKSv5Address {
|
||||
fn from(x: Ipv6Addr) -> Self {
|
||||
SOCKSv5Address::IP6(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<SOCKSv5Address> for Ipv6Addr {
|
||||
type Error = AddressConversionError;
|
||||
|
||||
fn try_from(value: SOCKSv5Address) -> Result<Self, Self::Error> {
|
||||
match value {
|
||||
SOCKSv5Address::IP4(_) => Err(AddressConversionError::CouldntConvertIP4),
|
||||
SOCKSv5Address::IP6(a) => Ok(a),
|
||||
SOCKSv5Address::Name(_) => Err(AddressConversionError::CouldntConvertName),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for SOCKSv5Address {
|
||||
fn from(x: String) -> Self {
|
||||
SOCKSv5Address::Name(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a str> for SOCKSv5Address {
|
||||
fn from(x: &str) -> SOCKSv5Address {
|
||||
SOCKSv5Address::Name(x.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for SOCKSv5Address {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self {
|
||||
SOCKSv5Address::IP4(a) => write!(f, "{}", a),
|
||||
SOCKSv5Address::IP6(a) => write!(f, "{}", a),
|
||||
SOCKSv5Address::Name(a) => write!(f, "{}", a),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SOCKSv5Address {
|
||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||
r: &mut R,
|
||||
) -> Result<Self, DeserializationError> {
|
||||
let mut byte_buffer = [0u8; 1];
|
||||
let amount_read = r.read(&mut byte_buffer).await?;
|
||||
|
||||
if amount_read == 0 {
|
||||
return Err(DeserializationError::NotEnoughData);
|
||||
}
|
||||
|
||||
match byte_buffer[0] {
|
||||
1 => {
|
||||
let mut addr_buffer = [0; 4];
|
||||
read_amt(r, 4, &mut addr_buffer).await?;
|
||||
Ok(SOCKSv5Address::IP4(Ipv4Addr::from(addr_buffer)))
|
||||
}
|
||||
3 => {
|
||||
let mut addr_buffer = [0; 16];
|
||||
read_amt(r, 16, &mut addr_buffer).await?;
|
||||
Ok(SOCKSv5Address::IP6(Ipv6Addr::from(addr_buffer)))
|
||||
}
|
||||
4 => {
|
||||
let name = read_string(r).await?;
|
||||
Ok(SOCKSv5Address::Name(name))
|
||||
}
|
||||
x => Err(DeserializationError::InvalidAddressType(x)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||
&self,
|
||||
w: &mut W,
|
||||
) -> Result<(), SerializationError> {
|
||||
match self {
|
||||
SOCKSv5Address::IP4(x) => {
|
||||
w.write_all(&[1]).await?;
|
||||
w.write_all(&x.octets())
|
||||
.await
|
||||
.map_err(SerializationError::IOError)
|
||||
}
|
||||
SOCKSv5Address::IP6(x) => {
|
||||
w.write_all(&[3]).await?;
|
||||
w.write_all(&x.octets())
|
||||
.await
|
||||
.map_err(SerializationError::IOError)
|
||||
}
|
||||
SOCKSv5Address::Name(x) => {
|
||||
w.write_all(&[4]).await?;
|
||||
write_string(x, w).await
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait HasLocalAddress {
|
||||
fn local_addr(&self) -> (SOCKSv5Address, u16);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl Arbitrary for SOCKSv5Address {
|
||||
fn arbitrary(g: &mut Gen) -> Self {
|
||||
let ip4 = Ipv4Addr::arbitrary(g);
|
||||
let ip6 = Ipv6Addr::arbitrary(g);
|
||||
let nm = arbitrary_socks_string(g);
|
||||
|
||||
g.choose(&[
|
||||
SOCKSv5Address::IP4(ip4),
|
||||
SOCKSv5Address::IP6(ip6),
|
||||
SOCKSv5Address::Name(nm),
|
||||
])
|
||||
.unwrap()
|
||||
.clone()
|
||||
}
|
||||
}
|
||||
|
||||
standard_roundtrip!(address_roundtrips, SOCKSv5Address);
|
||||
|
||||
#[cfg(test)]
|
||||
quickcheck! {
|
||||
fn ip_conversion(x: IpAddr) -> bool {
|
||||
match x {
|
||||
IpAddr::V4(ref a) =>
|
||||
assert_eq!(Err(AddressConversionError::CouldntConvertIP4),
|
||||
Ipv6Addr::try_from(SOCKSv5Address::from(a.clone()))),
|
||||
IpAddr::V6(ref a) =>
|
||||
assert_eq!(Err(AddressConversionError::CouldntConvertIP6),
|
||||
Ipv4Addr::try_from(SOCKSv5Address::from(a.clone()))),
|
||||
}
|
||||
x == IpAddr::try_from(SOCKSv5Address::from(x.clone())).unwrap()
|
||||
}
|
||||
|
||||
fn ip4_conversion(x: Ipv4Addr) -> bool {
|
||||
x == Ipv4Addr::try_from(SOCKSv5Address::from(x.clone())).unwrap()
|
||||
}
|
||||
|
||||
fn ip6_conversion(x: Ipv6Addr) -> bool {
|
||||
x == Ipv6Addr::try_from(SOCKSv5Address::from(x.clone())).unwrap()
|
||||
}
|
||||
|
||||
fn display_matches(x: SOCKSv5Address) -> bool {
|
||||
match x {
|
||||
SOCKSv5Address::IP4(a) => format!("{}", a) == format!("{}", x),
|
||||
SOCKSv5Address::IP6(a) => format!("{}", a) == format!("{}", x),
|
||||
SOCKSv5Address::Name(ref a) => format!("{}", a) == format!("{}", x),
|
||||
}
|
||||
}
|
||||
|
||||
fn bad_read_key(x: u8) -> bool {
|
||||
match x {
|
||||
1 => true,
|
||||
3 => true,
|
||||
4 => true,
|
||||
_ => {
|
||||
let buffer = [x, 0, 1, 2, 9, 10];
|
||||
let mut cursor = Cursor::new(buffer);
|
||||
let meh = SOCKSv5Address::read(&mut cursor);
|
||||
Err(DeserializationError::InvalidAddressType(x)) == task::block_on(meh)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn domain_name_sanity() {
|
||||
let name = "uhsure.com";
|
||||
let strname = name.to_string();
|
||||
|
||||
let addr1 = SOCKSv5Address::from(name);
|
||||
let addr2 = SOCKSv5Address::from(strname);
|
||||
|
||||
assert_eq!(addr1, addr2);
|
||||
assert_eq!(
|
||||
Err(AddressConversionError::CouldntConvertName),
|
||||
IpAddr::try_from(addr1.clone())
|
||||
);
|
||||
assert_eq!(
|
||||
Err(AddressConversionError::CouldntConvertName),
|
||||
Ipv4Addr::try_from(addr1.clone())
|
||||
);
|
||||
assert_eq!(
|
||||
Err(AddressConversionError::CouldntConvertName),
|
||||
Ipv6Addr::try_from(addr1.clone())
|
||||
);
|
||||
}
|
||||
@@ -1,43 +0,0 @@
|
||||
use crate::network::address::{HasLocalAddress, SOCKSv5Address};
|
||||
use async_trait::async_trait;
|
||||
|
||||
#[async_trait]
|
||||
pub trait Datagramlike: Send + Sync + HasLocalAddress {
|
||||
type Error;
|
||||
|
||||
async fn send_to(
|
||||
&self,
|
||||
buf: &[u8],
|
||||
addr: SOCKSv5Address,
|
||||
port: u16,
|
||||
) -> Result<usize, Self::Error>;
|
||||
async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SOCKSv5Address, u16), Self::Error>;
|
||||
}
|
||||
|
||||
pub struct GenericDatagramSocket<E> {
|
||||
pub internal: Box<dyn Datagramlike<Error = E>>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<E> Datagramlike for GenericDatagramSocket<E> {
|
||||
type Error = E;
|
||||
|
||||
async fn send_to(
|
||||
&self,
|
||||
buf: &[u8],
|
||||
addr: SOCKSv5Address,
|
||||
port: u16,
|
||||
) -> Result<usize, Self::Error> {
|
||||
Ok(self.internal.send_to(buf, addr, port).await?)
|
||||
}
|
||||
|
||||
async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SOCKSv5Address, u16), Self::Error> {
|
||||
Ok(self.internal.recv_from(buf).await?)
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> HasLocalAddress for GenericDatagramSocket<E> {
|
||||
fn local_addr(&self) -> (SOCKSv5Address, u16) {
|
||||
self.internal.local_addr()
|
||||
}
|
||||
}
|
||||
@@ -1,48 +0,0 @@
|
||||
use crate::messages::ServerResponseStatus;
|
||||
use crate::network::address::SOCKSv5Address;
|
||||
use crate::network::datagram::GenericDatagramSocket;
|
||||
use crate::network::listener::GenericListener;
|
||||
use crate::network::stream::GenericStream;
|
||||
use async_trait::async_trait;
|
||||
use std::fmt::Display;
|
||||
|
||||
#[async_trait]
|
||||
pub trait Networklike {
|
||||
/// The error type for things that fail on this network. Apologies in advance
|
||||
/// for using only one; if you have a use case for separating your errors,
|
||||
/// please shoot the author(s) and email to split this into multiple types, one
|
||||
/// for each trait function.
|
||||
type Error: Display + Into<ServerResponseStatus> + Send;
|
||||
|
||||
/// Connect to the given address and port, over this kind of network. The
|
||||
/// underlying stream should behave somewhat like a TCP stream ... which
|
||||
/// may be exactly what you're using. However, in order to support tunnelling
|
||||
/// scenarios (i.e., using another proxy, going through Tor or SSH, etc.) we
|
||||
/// work generically over any stream-like object.
|
||||
async fn connect<A: Send + Into<SOCKSv5Address>>(
|
||||
&mut self,
|
||||
addr: A,
|
||||
port: u16,
|
||||
) -> Result<GenericStream, Self::Error>;
|
||||
|
||||
/// Listen for connections on the given address and port, returning a generic
|
||||
/// listener socket to use in the future.
|
||||
async fn listen<A: Send + Into<SOCKSv5Address>>(
|
||||
&mut self,
|
||||
addr: A,
|
||||
port: u16,
|
||||
) -> Result<GenericListener<Self::Error>, Self::Error>;
|
||||
|
||||
/// Bind a socket for the purposes of doing some datagram communication. NOTE!
|
||||
/// this is only for UDP-like communication, not for generic connecting or
|
||||
/// listening! Maybe obvious from the types, but POSIX has overtrained many
|
||||
/// of us.
|
||||
///
|
||||
/// Recall when using these functions that datagram protocols allow for packet
|
||||
/// loss and out-of-order delivery. So ... be warned.
|
||||
async fn bind<A: Send + Into<SOCKSv5Address>>(
|
||||
&mut self,
|
||||
addr: A,
|
||||
port: u16,
|
||||
) -> Result<GenericDatagramSocket<Self::Error>, Self::Error>;
|
||||
}
|
||||
@@ -1,29 +0,0 @@
|
||||
use crate::network::address::{HasLocalAddress, SOCKSv5Address};
|
||||
use crate::network::stream::GenericStream;
|
||||
use async_trait::async_trait;
|
||||
|
||||
#[async_trait]
|
||||
pub trait Listenerlike: Send + Sync + HasLocalAddress {
|
||||
type Error;
|
||||
|
||||
async fn accept(&self) -> Result<(GenericStream, SOCKSv5Address, u16), Self::Error>;
|
||||
}
|
||||
|
||||
pub struct GenericListener<E> {
|
||||
pub internal: Box<dyn Listenerlike<Error = E>>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<E> Listenerlike for GenericListener<E> {
|
||||
type Error = E;
|
||||
|
||||
async fn accept(&self) -> Result<(GenericStream, SOCKSv5Address, u16), Self::Error> {
|
||||
Ok(self.internal.accept().await?)
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> HasLocalAddress for GenericListener<E> {
|
||||
fn local_addr(&self) -> (SOCKSv5Address, u16) {
|
||||
self.internal.local_addr()
|
||||
}
|
||||
}
|
||||
@@ -1,237 +0,0 @@
|
||||
use crate::messages::ServerResponseStatus;
|
||||
use crate::network::address::{HasLocalAddress, SOCKSv5Address};
|
||||
use crate::network::datagram::{Datagramlike, GenericDatagramSocket};
|
||||
use crate::network::generic::Networklike;
|
||||
use crate::network::listener::{GenericListener, Listenerlike};
|
||||
use crate::network::stream::{GenericStream, Streamlike};
|
||||
use async_std::io;
|
||||
#[cfg(test)]
|
||||
use async_std::io::ReadExt;
|
||||
use async_std::net::{TcpListener, TcpStream, UdpSocket};
|
||||
use async_trait::async_trait;
|
||||
#[cfg(test)]
|
||||
use futures::AsyncWriteExt;
|
||||
use log::error;
|
||||
|
||||
pub struct Builtin {}
|
||||
|
||||
impl Builtin {
|
||||
pub fn new() -> Builtin {
|
||||
Builtin {}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Builtin {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! local_address_impl {
|
||||
($t: ty) => {
|
||||
impl HasLocalAddress for $t {
|
||||
fn local_addr(&self) -> (SOCKSv5Address, u16) {
|
||||
match self.local_addr() {
|
||||
Ok(a) =>
|
||||
(SOCKSv5Address::from(a.ip()), a.port()),
|
||||
Err(e) => {
|
||||
error!("Couldn't translate (Streamlike) local address to SOCKS local address: {}", e);
|
||||
(SOCKSv5Address::from("localhost"), 0)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
local_address_impl!(TcpStream);
|
||||
local_address_impl!(TcpListener);
|
||||
local_address_impl!(UdpSocket);
|
||||
|
||||
impl Streamlike for TcpStream {}
|
||||
|
||||
#[async_trait]
|
||||
impl Listenerlike for TcpListener {
|
||||
type Error = io::Error;
|
||||
|
||||
async fn accept(&self) -> Result<(GenericStream, SOCKSv5Address, u16), Self::Error> {
|
||||
let (base, addrport) = self.accept().await?;
|
||||
let addr = addrport.ip();
|
||||
let port = addrport.port();
|
||||
Ok((GenericStream::new(base), SOCKSv5Address::from(addr), port))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Datagramlike for UdpSocket {
|
||||
type Error = io::Error;
|
||||
|
||||
async fn send_to(
|
||||
&self,
|
||||
buf: &[u8],
|
||||
addr: SOCKSv5Address,
|
||||
port: u16,
|
||||
) -> Result<usize, Self::Error> {
|
||||
match addr {
|
||||
SOCKSv5Address::IP4(a) => self.send_to(buf, (a, port)).await,
|
||||
SOCKSv5Address::IP6(a) => self.send_to(buf, (a, port)).await,
|
||||
SOCKSv5Address::Name(n) => self.send_to(buf, (n.as_str(), port)).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SOCKSv5Address, u16), Self::Error> {
|
||||
let (amt, addrport) = self.recv_from(buf).await?;
|
||||
let addr = addrport.ip();
|
||||
let port = addrport.port();
|
||||
Ok((amt, SOCKSv5Address::from(addr), port))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Networklike for Builtin {
|
||||
type Error = io::Error;
|
||||
|
||||
async fn connect<A: Send + Into<SOCKSv5Address>>(
|
||||
&mut self,
|
||||
addr: A,
|
||||
port: u16,
|
||||
) -> Result<GenericStream, Self::Error> {
|
||||
let target = addr.into();
|
||||
|
||||
let base_stream = match target {
|
||||
SOCKSv5Address::IP4(a) => TcpStream::connect((a, port)).await?,
|
||||
SOCKSv5Address::IP6(a) => TcpStream::connect((a, port)).await?,
|
||||
SOCKSv5Address::Name(n) => TcpStream::connect((n.as_str(), port)).await?,
|
||||
};
|
||||
|
||||
Ok(GenericStream::from(base_stream))
|
||||
}
|
||||
|
||||
async fn listen<A: Send + Into<SOCKSv5Address>>(
|
||||
&mut self,
|
||||
addr: A,
|
||||
port: u16,
|
||||
) -> Result<GenericListener<Self::Error>, Self::Error> {
|
||||
let target = addr.into();
|
||||
|
||||
let base_stream = match target {
|
||||
SOCKSv5Address::IP4(a) => TcpListener::bind((a, port)).await?,
|
||||
SOCKSv5Address::IP6(a) => TcpListener::bind((a, port)).await?,
|
||||
SOCKSv5Address::Name(n) => TcpListener::bind((n.as_str(), port)).await?,
|
||||
};
|
||||
|
||||
Ok(GenericListener {
|
||||
internal: Box::new(base_stream),
|
||||
})
|
||||
}
|
||||
|
||||
async fn bind<A: Send + Into<SOCKSv5Address>>(
|
||||
&mut self,
|
||||
addr: A,
|
||||
port: u16,
|
||||
) -> Result<GenericDatagramSocket<Self::Error>, Self::Error> {
|
||||
let target = addr.into();
|
||||
|
||||
let base_socket = match target {
|
||||
SOCKSv5Address::IP4(a) => UdpSocket::bind((a, port)).await?,
|
||||
SOCKSv5Address::IP6(a) => UdpSocket::bind((a, port)).await?,
|
||||
SOCKSv5Address::Name(n) => UdpSocket::bind((n.as_str(), port)).await?,
|
||||
};
|
||||
|
||||
Ok(GenericDatagramSocket {
|
||||
internal: Box::new(base_socket),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_sanity() {
|
||||
async_std::task::block_on(async {
|
||||
// Technically, this is UDP, and UDP is lossy. We're going to assume we're not
|
||||
// going to get any dropped data along here ... which is a very questionable
|
||||
// assumption, morally speaking, but probably fine for most purposes.
|
||||
let mut network = Builtin::new();
|
||||
let receiver = network
|
||||
.bind("localhost", 0)
|
||||
.await
|
||||
.expect("Failed to bind receiver socket.");
|
||||
let sender = network
|
||||
.bind("localhost", 0)
|
||||
.await
|
||||
.expect("Failed to bind sender socket.");
|
||||
let buffer = [0xde, 0xea, 0xbe, 0xef];
|
||||
let (receiver_addr, receiver_port) = receiver.local_addr();
|
||||
sender
|
||||
.send_to(&buffer, receiver_addr, receiver_port)
|
||||
.await
|
||||
.expect("Failure sending datagram!");
|
||||
let mut recvbuffer = [0; 4];
|
||||
let (s, f, p) = receiver
|
||||
.recv_from(&mut recvbuffer)
|
||||
.await
|
||||
.expect("Didn't receive UDP message?");
|
||||
let (sender_addr, sender_port) = sender.local_addr();
|
||||
assert_eq!(s, 4);
|
||||
assert_eq!(f, sender_addr);
|
||||
assert_eq!(p, sender_port);
|
||||
assert_eq!(recvbuffer, buffer);
|
||||
});
|
||||
|
||||
// This whole block should be pretty solid, though, unless the system we're
|
||||
// on is in a pretty weird place.
|
||||
let mut network = Builtin::new();
|
||||
|
||||
let listener = async_std::task::block_on(network.listen("localhost", 0))
|
||||
.expect("Couldn't set up listener on localhost");
|
||||
let (listener_address, listener_port) = listener.local_addr();
|
||||
|
||||
let listener_task_handle = async_std::task::spawn(async move {
|
||||
let (mut stream, addr, port) = listener.accept().await.expect("Didn't get connection");
|
||||
let mut result_buffer = [0u8; 4];
|
||||
stream
|
||||
.read_exact(&mut result_buffer)
|
||||
.await
|
||||
.expect("Read failure in TCP test");
|
||||
(result_buffer, addr, port)
|
||||
});
|
||||
|
||||
let sender_task_handle = async_std::task::spawn(async move {
|
||||
let mut sender = network
|
||||
.connect(listener_address, listener_port)
|
||||
.await
|
||||
.expect("Coudln't connect to listener?");
|
||||
let (sender_address, sender_port) = sender.local_addr();
|
||||
let send_buffer = [0xa, 0xff, 0xab, 0x1e];
|
||||
sender
|
||||
.write_all(&send_buffer)
|
||||
.await
|
||||
.expect("Couldn't send the write buffer");
|
||||
sender
|
||||
.flush()
|
||||
.await
|
||||
.expect("Couldn't flush the write buffer");
|
||||
sender
|
||||
.close()
|
||||
.await
|
||||
.expect("Couldn't close the write buffer");
|
||||
(sender_address, sender_port)
|
||||
});
|
||||
|
||||
async_std::task::block_on(async {
|
||||
let (result, result_from, result_from_port) = listener_task_handle.await;
|
||||
assert_eq!(result, [0xa, 0xff, 0xab, 0x1e]);
|
||||
let (sender_address, sender_port) = sender_task_handle.await;
|
||||
assert_eq!(result_from, sender_address);
|
||||
assert_eq!(result_from_port, sender_port);
|
||||
});
|
||||
}
|
||||
|
||||
impl From<io::Error> for ServerResponseStatus {
|
||||
fn from(e: io::Error) -> ServerResponseStatus {
|
||||
match e.kind() {
|
||||
io::ErrorKind::ConnectionRefused => ServerResponseStatus::ConnectionRefused,
|
||||
io::ErrorKind::NotFound => ServerResponseStatus::HostUnreachable,
|
||||
_ => ServerResponseStatus::GeneralFailure,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,74 +0,0 @@
|
||||
use crate::network::SOCKSv5Address;
|
||||
use async_std::task::{Context, Poll};
|
||||
use futures::io;
|
||||
use futures::io::{AsyncRead, AsyncWrite};
|
||||
use std::pin::Pin;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use super::address::HasLocalAddress;
|
||||
|
||||
pub trait Streamlike: AsyncRead + AsyncWrite + HasLocalAddress + Send + Sync + Unpin {}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct GenericStream {
|
||||
internal: Arc<Mutex<dyn Streamlike>>,
|
||||
}
|
||||
|
||||
impl GenericStream {
|
||||
pub fn new<T: Streamlike + 'static>(x: T) -> GenericStream {
|
||||
GenericStream {
|
||||
internal: Arc::new(Mutex::new(x)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl HasLocalAddress for GenericStream {
|
||||
fn local_addr(&self) -> (SOCKSv5Address, u16) {
|
||||
let item = self.internal.lock().unwrap();
|
||||
item.local_addr()
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for GenericStream {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut [u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
let mut item = self.internal.lock().unwrap();
|
||||
let pinned = Pin::new(&mut *item);
|
||||
pinned.poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for GenericStream {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
let mut item = self.internal.lock().unwrap();
|
||||
let pinned = Pin::new(&mut *item);
|
||||
pinned.poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
let mut item = self.internal.lock().unwrap();
|
||||
let pinned = Pin::new(&mut *item);
|
||||
pinned.poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
let mut item = self.internal.lock().unwrap();
|
||||
let pinned = Pin::new(&mut *item);
|
||||
pinned.poll_close(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Streamlike + 'static> From<T> for GenericStream {
|
||||
fn from(x: T) -> GenericStream {
|
||||
GenericStream {
|
||||
internal: Arc::new(Mutex::new(x)),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,276 +0,0 @@
|
||||
mod datagram;
|
||||
mod stream;
|
||||
|
||||
use crate::messages::ServerResponseStatus;
|
||||
use crate::network::address::{HasLocalAddress, SOCKSv5Address};
|
||||
#[cfg(test)]
|
||||
use crate::network::datagram::Datagramlike;
|
||||
use crate::network::datagram::GenericDatagramSocket;
|
||||
use crate::network::generic::Networklike;
|
||||
use crate::network::listener::{GenericListener, Listenerlike};
|
||||
use crate::network::stream::GenericStream;
|
||||
use crate::network::testing::datagram::TestDatagram;
|
||||
use crate::network::testing::stream::TestingStream;
|
||||
use async_std::channel::{bounded, Receiver, Sender};
|
||||
use async_std::sync::{Arc, Mutex};
|
||||
#[cfg(test)]
|
||||
use async_std::task;
|
||||
use async_trait::async_trait;
|
||||
#[cfg(test)]
|
||||
use futures::{AsyncReadExt, AsyncWriteExt};
|
||||
use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
|
||||
/// A "network", based purely on internal Rust datatypes, for testing
|
||||
/// networking code. This stack operates purely in memory, so shouldn't
|
||||
/// suffer from any weird networking effects ... which makes it a good
|
||||
/// functional test, but not great at actually testing real-world failure
|
||||
/// modes.
|
||||
#[allow(clippy::type_complexity)]
|
||||
#[derive(Clone)]
|
||||
pub struct TestingStack {
|
||||
tcp_listeners: Arc<Mutex<HashMap<(SOCKSv5Address, u16), Sender<TestingStream>>>>,
|
||||
udp_sockets: Arc<Mutex<HashMap<(SOCKSv5Address, u16), Sender<(SOCKSv5Address, u16, Vec<u8>)>>>>,
|
||||
next_random_socket: u16,
|
||||
}
|
||||
|
||||
impl TestingStack {
|
||||
pub fn new() -> TestingStack {
|
||||
TestingStack {
|
||||
tcp_listeners: Arc::new(Mutex::new(HashMap::new())),
|
||||
udp_sockets: Arc::new(Mutex::new(HashMap::new())),
|
||||
next_random_socket: 23,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for TestingStack {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum TestStackError {
|
||||
AcceptFailed,
|
||||
AddressBusy(SOCKSv5Address, u16),
|
||||
ConnectionFailed,
|
||||
FailureToSend,
|
||||
NoTCPHostFound(SOCKSv5Address, u16),
|
||||
ReceiveFailure,
|
||||
}
|
||||
|
||||
impl fmt::Display for TestStackError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
TestStackError::AcceptFailed => write!(f, "Accept failed; the other side died (?)"),
|
||||
TestStackError::AddressBusy(ref addr, port) => {
|
||||
write!(f, "Address {}:{} already in use", addr, port)
|
||||
}
|
||||
TestStackError::ConnectionFailed => write!(f, "Couldn't connect to host."),
|
||||
TestStackError::FailureToSend => write!(
|
||||
f,
|
||||
"Weird internal error in testing infrastructure; channel send failed"
|
||||
),
|
||||
TestStackError::NoTCPHostFound(ref addr, port) => {
|
||||
write!(f, "No host found at {} for TCP port {}", addr, port)
|
||||
}
|
||||
TestStackError::ReceiveFailure => {
|
||||
write!(f, "Failed to process a UDP receive (this is weird)")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<TestStackError> for ServerResponseStatus {
|
||||
fn from(_: TestStackError) -> Self {
|
||||
ServerResponseStatus::GeneralFailure
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Networklike for TestingStack {
|
||||
type Error = TestStackError;
|
||||
|
||||
async fn connect<A: Send + Into<SOCKSv5Address>>(
|
||||
&mut self,
|
||||
addr: A,
|
||||
port: u16,
|
||||
) -> Result<GenericStream, Self::Error> {
|
||||
let table = self.tcp_listeners.lock().await;
|
||||
let target = addr.into();
|
||||
|
||||
match table.get(&(target.clone(), port)) {
|
||||
None => Err(TestStackError::NoTCPHostFound(target, port)),
|
||||
Some(result) => {
|
||||
let stream = TestingStream::new(target, port);
|
||||
let retval = stream.invert();
|
||||
match result.send(stream).await {
|
||||
Ok(()) => Ok(GenericStream::new(retval)),
|
||||
Err(_) => Err(TestStackError::FailureToSend),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn listen<A: Send + Into<SOCKSv5Address>>(
|
||||
&mut self,
|
||||
addr: A,
|
||||
mut port: u16,
|
||||
) -> Result<GenericListener<Self::Error>, Self::Error> {
|
||||
let mut table = self.tcp_listeners.lock().await;
|
||||
let target = addr.into();
|
||||
let (sender, receiver) = bounded(5);
|
||||
|
||||
if port == 0 {
|
||||
port = self.next_random_socket;
|
||||
self.next_random_socket += 1;
|
||||
}
|
||||
|
||||
table.insert((target.clone(), port), sender);
|
||||
Ok(GenericListener {
|
||||
internal: Box::new(TestListener::new(target, port, receiver)),
|
||||
})
|
||||
}
|
||||
|
||||
async fn bind<A: Send + Into<SOCKSv5Address>>(
|
||||
&mut self,
|
||||
addr: A,
|
||||
mut port: u16,
|
||||
) -> Result<GenericDatagramSocket<Self::Error>, Self::Error> {
|
||||
let mut table = self.udp_sockets.lock().await;
|
||||
let target = addr.into();
|
||||
let (sender, receiver) = bounded(5);
|
||||
|
||||
if port == 0 {
|
||||
port = self.next_random_socket;
|
||||
self.next_random_socket += 1;
|
||||
}
|
||||
|
||||
table.insert((target.clone(), port), sender);
|
||||
Ok(GenericDatagramSocket {
|
||||
internal: Box::new(TestDatagram::new(self.clone(), target, port, receiver)),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct TestListener {
|
||||
address: SOCKSv5Address,
|
||||
port: u16,
|
||||
receiver: Receiver<TestingStream>,
|
||||
}
|
||||
|
||||
impl TestListener {
|
||||
fn new(address: SOCKSv5Address, port: u16, receiver: Receiver<TestingStream>) -> Self {
|
||||
TestListener {
|
||||
address,
|
||||
port,
|
||||
receiver,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl HasLocalAddress for TestListener {
|
||||
fn local_addr(&self) -> (SOCKSv5Address, u16) {
|
||||
(self.address.clone(), self.port)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Listenerlike for TestListener {
|
||||
type Error = TestStackError;
|
||||
|
||||
async fn accept(&self) -> Result<(GenericStream, SOCKSv5Address, u16), Self::Error> {
|
||||
match self.receiver.recv().await {
|
||||
Ok(next) => {
|
||||
let (addr, port) = next.local_addr();
|
||||
Ok((GenericStream::new(next), addr, port))
|
||||
}
|
||||
Err(_) => Err(TestStackError::AcceptFailed),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_udp_sanity() {
|
||||
task::block_on(async {
|
||||
let mut network = TestingStack::new();
|
||||
let receiver = network
|
||||
.bind("localhost", 0)
|
||||
.await
|
||||
.expect("Failed to bind receiver socket.");
|
||||
let sender = network
|
||||
.bind("localhost", 0)
|
||||
.await
|
||||
.expect("Failed to bind sender socket.");
|
||||
let buffer = [0xde, 0xea, 0xbe, 0xef];
|
||||
let (receiver_addr, receiver_port) = receiver.local_addr();
|
||||
sender
|
||||
.send_to(&buffer, receiver_addr, receiver_port)
|
||||
.await
|
||||
.expect("Failure sending datagram!");
|
||||
let mut recvbuffer = [0; 4];
|
||||
let (s, f, p) = receiver
|
||||
.recv_from(&mut recvbuffer)
|
||||
.await
|
||||
.expect("Didn't receive UDP message?");
|
||||
let (sender_addr, sender_port) = sender.local_addr();
|
||||
assert_eq!(s, 4);
|
||||
assert_eq!(f, sender_addr);
|
||||
assert_eq!(p, sender_port);
|
||||
assert_eq!(recvbuffer, buffer);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_basic_tcp() {
|
||||
task::block_on(async {
|
||||
let mut network = TestingStack::new();
|
||||
|
||||
let listener = network
|
||||
.listen("localhost", 0)
|
||||
.await
|
||||
.expect("Couldn't set up listener on localhost");
|
||||
let (listener_address, listener_port) = listener.local_addr();
|
||||
|
||||
let listener_task_handle = task::spawn(async move {
|
||||
dbg!("Starting listener task!!");
|
||||
let (mut stream, addr, port) = listener.accept().await.expect("Didn't get connection");
|
||||
let mut result_buffer = [0u8; 4];
|
||||
if let Err(e) = stream.read_exact(&mut result_buffer).await {
|
||||
dbg!("Error reading buffer from stream: {}", e);
|
||||
} else {
|
||||
dbg!("made it through read_exact");
|
||||
}
|
||||
(result_buffer, addr, port)
|
||||
});
|
||||
|
||||
let sender_task_handle = task::spawn(async move {
|
||||
let mut sender = network
|
||||
.connect(listener_address, listener_port)
|
||||
.await
|
||||
.expect("Coudln't connect to listener?");
|
||||
let (sender_address, sender_port) = sender.local_addr();
|
||||
let send_buffer = [0xa, 0xff, 0xab, 0x1e];
|
||||
sender
|
||||
.write_all(&send_buffer)
|
||||
.await
|
||||
.expect("Couldn't send the write buffer");
|
||||
sender
|
||||
.flush()
|
||||
.await
|
||||
.expect("Couldn't flush the write buffer");
|
||||
sender
|
||||
.close()
|
||||
.await
|
||||
.expect("Couldn't close the write buffer");
|
||||
(sender_address, sender_port)
|
||||
});
|
||||
|
||||
let (result, result_from, result_from_port) = listener_task_handle.await;
|
||||
assert_eq!(result, [0xa, 0xff, 0xab, 0x1e]);
|
||||
let (sender_address, sender_port) = sender_task_handle.await;
|
||||
assert_eq!(result_from, sender_address);
|
||||
assert_eq!(result_from_port, sender_port);
|
||||
});
|
||||
}
|
||||
@@ -1,88 +0,0 @@
|
||||
use crate::network::address::HasLocalAddress;
|
||||
use crate::network::datagram::Datagramlike;
|
||||
use crate::network::testing::{TestStackError, TestingStack};
|
||||
use crate::network::SOCKSv5Address;
|
||||
use async_std::channel::Receiver;
|
||||
use async_trait::async_trait;
|
||||
use std::cmp::Ordering;
|
||||
|
||||
pub struct TestDatagram {
|
||||
context: TestingStack,
|
||||
my_address: SOCKSv5Address,
|
||||
my_port: u16,
|
||||
input_stream: Receiver<(SOCKSv5Address, u16, Vec<u8>)>,
|
||||
}
|
||||
|
||||
impl TestDatagram {
|
||||
pub fn new(
|
||||
context: TestingStack,
|
||||
my_address: SOCKSv5Address,
|
||||
my_port: u16,
|
||||
input_stream: Receiver<(SOCKSv5Address, u16, Vec<u8>)>,
|
||||
) -> Self {
|
||||
TestDatagram {
|
||||
context,
|
||||
my_address,
|
||||
my_port,
|
||||
input_stream,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl HasLocalAddress for TestDatagram {
|
||||
fn local_addr(&self) -> (SOCKSv5Address, u16) {
|
||||
(self.my_address.clone(), self.my_port)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Datagramlike for TestDatagram {
|
||||
type Error = TestStackError;
|
||||
|
||||
async fn send_to(
|
||||
&self,
|
||||
buf: &[u8],
|
||||
target: SOCKSv5Address,
|
||||
port: u16,
|
||||
) -> Result<usize, Self::Error> {
|
||||
let table = self.context.udp_sockets.lock().await;
|
||||
match table.get(&(target, port)) {
|
||||
None => Ok(buf.len()),
|
||||
Some(sender) => {
|
||||
sender
|
||||
.send((self.my_address.clone(), self.my_port, buf.to_vec()))
|
||||
.await
|
||||
.map_err(|_| TestStackError::FailureToSend)?;
|
||||
Ok(buf.len())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn recv_from(
|
||||
&self,
|
||||
buffer: &mut [u8],
|
||||
) -> Result<(usize, SOCKSv5Address, u16), Self::Error> {
|
||||
let (from_addr, from_port, message) = self
|
||||
.input_stream
|
||||
.recv()
|
||||
.await
|
||||
.map_err(|_| TestStackError::ReceiveFailure)?;
|
||||
|
||||
match message.len().cmp(&buffer.len()) {
|
||||
Ordering::Greater => {
|
||||
buffer.copy_from_slice(&message[..buffer.len()]);
|
||||
Ok((message.len(), from_addr, from_port))
|
||||
}
|
||||
|
||||
Ordering::Less => {
|
||||
(&mut buffer[..message.len()]).copy_from_slice(&message);
|
||||
Ok((message.len(), from_addr, from_port))
|
||||
}
|
||||
|
||||
Ordering::Equal => {
|
||||
buffer.copy_from_slice(message.as_ref());
|
||||
Ok((message.len(), from_addr, from_port))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,166 +0,0 @@
|
||||
use crate::network::address::HasLocalAddress;
|
||||
use crate::network::stream::Streamlike;
|
||||
use crate::network::SOCKSv5Address;
|
||||
use async_std::io;
|
||||
use async_std::io::{Read, Write};
|
||||
use async_std::task::{Context, Poll, Waker};
|
||||
use std::cell::UnsafeCell;
|
||||
use std::pin::Pin;
|
||||
use std::ptr::NonNull;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct TestingStream {
|
||||
address: SOCKSv5Address,
|
||||
port: u16,
|
||||
read_side: NonNull<TestingStreamData>,
|
||||
write_side: NonNull<TestingStreamData>,
|
||||
}
|
||||
|
||||
unsafe impl Send for TestingStream {}
|
||||
unsafe impl Sync for TestingStream {}
|
||||
|
||||
struct TestingStreamData {
|
||||
lock: AtomicBool,
|
||||
waiters: UnsafeCell<Vec<Waker>>,
|
||||
buffer: UnsafeCell<Vec<u8>>,
|
||||
}
|
||||
|
||||
unsafe impl Send for TestingStreamData {}
|
||||
unsafe impl Sync for TestingStreamData {}
|
||||
|
||||
impl TestingStream {
|
||||
/// Generate a testing stream. Note that this is directional. So, if you want to
|
||||
/// talk to this stream, you should also generate an `invert()` and pass that to
|
||||
/// the other thread(s).
|
||||
pub fn new(address: SOCKSv5Address, port: u16) -> TestingStream {
|
||||
let read_side_data = TestingStreamData {
|
||||
lock: AtomicBool::new(false),
|
||||
waiters: UnsafeCell::new(Vec::new()),
|
||||
buffer: UnsafeCell::new(Vec::with_capacity(16 * 1024)),
|
||||
};
|
||||
|
||||
let write_side_data = TestingStreamData {
|
||||
lock: AtomicBool::new(false),
|
||||
waiters: UnsafeCell::new(Vec::new()),
|
||||
buffer: UnsafeCell::new(Vec::with_capacity(16 * 1024)),
|
||||
};
|
||||
|
||||
let boxed_rsd = Box::new(read_side_data);
|
||||
let boxed_wsd = Box::new(write_side_data);
|
||||
let raw_read_ptr = Box::leak(boxed_rsd);
|
||||
let raw_write_ptr = Box::leak(boxed_wsd);
|
||||
|
||||
TestingStream {
|
||||
address,
|
||||
port,
|
||||
read_side: NonNull::new(raw_read_ptr).unwrap(),
|
||||
write_side: NonNull::new(raw_write_ptr).unwrap(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the flip side of this stream; reads from the inverted side will catch the writes
|
||||
/// of the original, etc.
|
||||
pub fn invert(&self) -> TestingStream {
|
||||
TestingStream {
|
||||
address: self.address.clone(),
|
||||
port: self.port,
|
||||
read_side: self.write_side,
|
||||
write_side: self.read_side,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TestingStreamData {
|
||||
fn acquire(&mut self) {
|
||||
loop {
|
||||
match self
|
||||
.lock
|
||||
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
|
||||
{
|
||||
Err(_) => continue,
|
||||
Ok(_) => return,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn release(&mut self) {
|
||||
self.lock.store(false, Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
|
||||
impl HasLocalAddress for TestingStream {
|
||||
fn local_addr(&self) -> (SOCKSv5Address, u16) {
|
||||
(self.address.clone(), self.port)
|
||||
}
|
||||
}
|
||||
|
||||
impl Read for TestingStream {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut [u8],
|
||||
) -> Poll<std::io::Result<usize>> {
|
||||
// so, we're going to spin here, which is less than ideal but should work fine
|
||||
// in practice. we'll obviously need to be very careful to ensure that we keep
|
||||
// the stuff internal to this spin really short.
|
||||
let internals = unsafe { self.read_side.as_mut() };
|
||||
|
||||
internals.acquire();
|
||||
let stream_buffer = internals.buffer.get_mut();
|
||||
let amount_available = stream_buffer.len();
|
||||
|
||||
if amount_available == 0 {
|
||||
let waker = cx.waker().clone();
|
||||
internals.waiters.get_mut().push(waker);
|
||||
internals.release();
|
||||
return Poll::Pending;
|
||||
}
|
||||
|
||||
let amt_written = if buf.len() >= amount_available {
|
||||
(&mut buf[0..amount_available]).copy_from_slice(stream_buffer);
|
||||
stream_buffer.clear();
|
||||
amount_available
|
||||
} else {
|
||||
let amt_to_copy = buf.len();
|
||||
buf.copy_from_slice(&stream_buffer[0..amt_to_copy]);
|
||||
stream_buffer.copy_within(amt_to_copy.., 0);
|
||||
let amt_left = amount_available - amt_to_copy;
|
||||
stream_buffer.resize(amt_left, 0);
|
||||
amt_to_copy
|
||||
};
|
||||
|
||||
internals.release();
|
||||
|
||||
Poll::Ready(Ok(amt_written))
|
||||
}
|
||||
}
|
||||
|
||||
impl Write for TestingStream {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
let internals = unsafe { self.write_side.as_mut() };
|
||||
internals.acquire();
|
||||
let stream_buffer = internals.buffer.get_mut();
|
||||
stream_buffer.extend_from_slice(buf);
|
||||
for waiter in internals.waiters.get_mut().drain(0..) {
|
||||
waiter.wake();
|
||||
}
|
||||
internals.release();
|
||||
|
||||
Poll::Ready(Ok(buf.len()))
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Poll::Ready(Ok(())) // FIXME: Might consider having this wait until the buffer is empty
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Poll::Ready(Ok(())) // FIXME: Might consider putting in some open/closed logic here
|
||||
}
|
||||
}
|
||||
|
||||
impl Streamlike for TestingStream {}
|
||||
84
src/security_parameters.rs
Normal file
84
src/security_parameters.rs
Normal file
@@ -0,0 +1,84 @@
|
||||
use std::net::SocketAddr;
|
||||
|
||||
/// The security parameters that you can assign to the server, to make decisions
|
||||
/// about the weirdos it accepts as users. It is recommended that you only use
|
||||
/// wide open connections when you're 100% sure that the server will only be
|
||||
/// accessible locally.
|
||||
#[derive(Clone)]
|
||||
pub struct SecurityParameters {
|
||||
/// Allow completely unauthenticated connections. You should be very, very
|
||||
/// careful about setting this to true, especially if you don't provide a
|
||||
/// guard to ensure that you're getting connections from reasonable places.
|
||||
pub allow_unauthenticated: bool,
|
||||
/// An optional function that can serve as a firewall for new connections.
|
||||
/// Return true if the connection should be allowed to continue, false if
|
||||
/// it shouldn't. This check happens before any data is read from or written
|
||||
/// to the connecting party.
|
||||
pub allow_connection: Option<fn(&SocketAddr) -> bool>,
|
||||
/// An optional function to check a user name (first argument) and password
|
||||
/// (second argument). Return true if the username / password is good, false
|
||||
/// if not.
|
||||
pub check_password: Option<fn(&str, &str) -> bool>,
|
||||
/// An optional function to transition the stream from an unencrypted one to
|
||||
/// an encrypted on. The assumption is you're using something like `rustls`
|
||||
/// to make this happen; the exact mechanism is outside the scope of this
|
||||
/// particular crate. If the connection shouldn't be allowed for some reason
|
||||
/// (a bad certificate or handshake, for example), then return None; otherwise,
|
||||
/// return the new stream.
|
||||
pub connect_tls: Option<fn() -> Option<()>>,
|
||||
}
|
||||
|
||||
impl SecurityParameters {
|
||||
/// Generates a `SecurityParameters` object that's empty. It won't accept
|
||||
/// anything, because it has no mechanisms it can use to actually authenticate
|
||||
/// a user and yet won't allow unauthenticated connections.
|
||||
pub fn new() -> SecurityParameters {
|
||||
SecurityParameters {
|
||||
allow_unauthenticated: false,
|
||||
allow_connection: None,
|
||||
check_password: None,
|
||||
connect_tls: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates a `SecurityParameters` object that does not, in any way,
|
||||
/// restrict who can log in. It also will not induce any transition into
|
||||
/// TLS. Use this at your own risk ... or, really, just don't use this,
|
||||
/// ever, and certainly not in production.
|
||||
pub fn unrestricted() -> SecurityParameters {
|
||||
SecurityParameters {
|
||||
allow_unauthenticated: true,
|
||||
allow_connection: None,
|
||||
check_password: None,
|
||||
connect_tls: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Use the provided function to check incoming connections before proceeding
|
||||
/// with the rest of the handshake.
|
||||
pub fn check_connections(mut self, checker: fn(&SocketAddr) -> bool) -> SecurityParameters {
|
||||
self.allow_connection = Some(checker);
|
||||
self
|
||||
}
|
||||
|
||||
/// Use the provided function to check usernames and passwords provided
|
||||
/// to the server.
|
||||
pub fn password_check(mut self, checker: fn(&str, &str) -> bool) -> SecurityParameters {
|
||||
self.check_password = Some(checker);
|
||||
self
|
||||
}
|
||||
|
||||
/// Use the provide function to validate a TLS connection, and transition it
|
||||
/// to the new stream type. If the handshake fails, return `None` instead of
|
||||
/// `Some`. (And maybe log it somewhere, you know.)
|
||||
pub fn tls_converter(mut self, converter: fn() -> Option<()>) -> SecurityParameters {
|
||||
self.connect_tls = Some(converter);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SecurityParameters {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
@@ -1,59 +0,0 @@
|
||||
use crate::errors::{DeserializationError, SerializationError};
|
||||
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
pub async fn read_string<R: AsyncRead + Send + Unpin>(
|
||||
r: &mut R,
|
||||
) -> Result<String, DeserializationError> {
|
||||
let mut length_buffer = [0; 1];
|
||||
|
||||
if r.read(&mut length_buffer).await? == 0 {
|
||||
return Err(DeserializationError::NotEnoughData);
|
||||
}
|
||||
|
||||
let target = length_buffer[0] as usize;
|
||||
|
||||
if target == 0 {
|
||||
return Err(DeserializationError::InvalidEmptyString);
|
||||
}
|
||||
|
||||
let mut bytestring = vec![0; target];
|
||||
read_amt(r, target, &mut bytestring).await?;
|
||||
|
||||
Ok(String::from_utf8(bytestring)?)
|
||||
}
|
||||
|
||||
pub async fn write_string<W: AsyncWrite + Send + Unpin>(
|
||||
s: &str,
|
||||
w: &mut W,
|
||||
) -> Result<(), SerializationError> {
|
||||
let bytestring = s.as_bytes();
|
||||
|
||||
if bytestring.is_empty() || bytestring.len() > 255 {
|
||||
return Err(SerializationError::InvalidStringLength(s.to_string()));
|
||||
}
|
||||
|
||||
w.write_all(&[bytestring.len() as u8]).await?;
|
||||
w.write_all(bytestring)
|
||||
.await
|
||||
.map_err(SerializationError::IOError)
|
||||
}
|
||||
|
||||
pub async fn read_amt<R: AsyncRead + Send + Unpin>(
|
||||
r: &mut R,
|
||||
amt: usize,
|
||||
buffer: &mut [u8],
|
||||
) -> Result<(), DeserializationError> {
|
||||
let mut amt_read = 0;
|
||||
|
||||
while amt_read < amt {
|
||||
let chunk_amt = r.read(&mut buffer[amt_read..]).await?;
|
||||
|
||||
if chunk_amt == 0 {
|
||||
return Err(DeserializationError::NotEnoughData);
|
||||
}
|
||||
|
||||
amt_read += chunk_amt;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
678
src/server.rs
678
src/server.rs
@@ -1,110 +1,415 @@
|
||||
use crate::errors::{AuthenticationError, DeserializationError, SerializationError};
|
||||
use crate::address::SOCKSv5Address;
|
||||
use crate::messages::{
|
||||
AuthenticationMethod, ClientConnectionCommand, ClientConnectionRequest, ClientGreeting,
|
||||
ClientUsernamePassword, ServerAuthResponse, ServerChoice, ServerResponse, ServerResponseStatus,
|
||||
AuthenticationMethod, ClientConnectionCommand, ClientConnectionCommandReadError,
|
||||
ClientConnectionRequest, ClientConnectionRequestReadError, ClientGreeting,
|
||||
ClientGreetingReadError, ClientUsernamePassword, ClientUsernamePasswordReadError,
|
||||
ServerAuthResponse, ServerAuthResponseWriteError, ServerChoice, ServerChoiceWriteError,
|
||||
ServerResponse, ServerResponseStatus, ServerResponseWriteError,
|
||||
};
|
||||
use crate::network::address::HasLocalAddress;
|
||||
use crate::network::generic::Networklike;
|
||||
use crate::network::listener::{GenericListener, Listenerlike};
|
||||
use crate::network::stream::GenericStream;
|
||||
use crate::network::SOCKSv5Address;
|
||||
use async_std::io;
|
||||
use async_std::io::prelude::WriteExt;
|
||||
use async_std::sync::{Arc, Mutex};
|
||||
use async_std::task;
|
||||
use log::{error, info, trace, warn};
|
||||
pub use crate::security_parameters::SecurityParameters;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use thiserror::Error;
|
||||
|
||||
pub struct SOCKSv5Server<N: Networklike> {
|
||||
network: N,
|
||||
security_parameters: SecurityParameters,
|
||||
listener: GenericListener<N::Error>,
|
||||
}
|
||||
use tokio::io::{copy_bidirectional, AsyncWriteExt};
|
||||
use tokio::net::{TcpListener, TcpSocket, TcpStream, UdpSocket};
|
||||
use tokio::task::JoinHandle;
|
||||
use tracing::{field, info_span, Instrument, Span};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SecurityParameters {
|
||||
pub allow_unauthenticated: bool,
|
||||
pub allow_connection: Option<fn(&SOCKSv5Address, u16) -> bool>,
|
||||
pub check_password: Option<fn(&str, &str) -> bool>,
|
||||
pub connect_tls: Option<fn(GenericStream) -> Option<GenericStream>>,
|
||||
pub struct SOCKSv5Server {
|
||||
info: Arc<ServerInfo>,
|
||||
}
|
||||
|
||||
impl SecurityParameters {
|
||||
/// Generates a `SecurityParameters` object that does not, in any way,
|
||||
/// restrict who can log in. It also will not induce any transition into
|
||||
/// TLS. Use this at your own risk ... or, really, just don't use this,
|
||||
/// ever, and certainly not in production.
|
||||
pub fn unrestricted() -> SecurityParameters {
|
||||
SecurityParameters {
|
||||
allow_unauthenticated: true,
|
||||
allow_connection: None,
|
||||
check_password: None,
|
||||
connect_tls: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: Networklike + Send + 'static> SOCKSv5Server<N> {
|
||||
pub fn new<S: Listenerlike<Error = N::Error> + 'static>(
|
||||
network: N,
|
||||
struct ServerInfo {
|
||||
security_parameters: SecurityParameters,
|
||||
stream: S,
|
||||
) -> SOCKSv5Server<N> {
|
||||
next_id: AtomicU64,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Error, PartialEq)]
|
||||
pub enum SOCKSv5ServerError {
|
||||
#[error("Underlying networking error: {0}")]
|
||||
NetworkingError(String),
|
||||
#[error("Couldn't negotiate authentication with client.")]
|
||||
ItsNotUsItsYou,
|
||||
#[error("Client greeting read problem: {0}")]
|
||||
GreetingReadProblem(#[from] ClientGreetingReadError),
|
||||
#[error("Server choice write problem: {0}")]
|
||||
ChoiceWriteProblem(#[from] ServerChoiceWriteError),
|
||||
#[error("Failed username/password authentication for user {0}")]
|
||||
FailedUsernamePassword(String),
|
||||
#[error("Server authentication response problem: {0}")]
|
||||
ServerAuthWriteProblem(#[from] ServerAuthResponseWriteError),
|
||||
#[error("Error reading client username/password: {0}")]
|
||||
UserPassReadProblem(#[from] ClientUsernamePasswordReadError),
|
||||
#[error("Error reading client connection command: {0}")]
|
||||
ClientConnReadProblem(#[from] ClientConnectionCommandReadError),
|
||||
#[error("Error reading client connection request: {0}")]
|
||||
ClientRequestReadProblem(#[from] ClientConnectionRequestReadError),
|
||||
#[error("Error writing server response: {0}")]
|
||||
ServerResponseWriteProblem(#[from] ServerResponseWriteError),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for SOCKSv5ServerError {
|
||||
fn from(x: std::io::Error) -> SOCKSv5ServerError {
|
||||
SOCKSv5ServerError::NetworkingError(format!("{}", x))
|
||||
}
|
||||
}
|
||||
|
||||
impl SOCKSv5Server {
|
||||
/// Initialize a SOCKSv5 server for use later on. Once initialized, you can listen
|
||||
/// on as many addresses and ports as you like; the metadata about the server will
|
||||
/// be synced across all the instances.
|
||||
pub fn new(security_parameters: SecurityParameters) -> Self {
|
||||
SOCKSv5Server {
|
||||
network,
|
||||
info: Arc::new(ServerInfo {
|
||||
security_parameters,
|
||||
listener: GenericListener {
|
||||
internal: Box::new(stream),
|
||||
},
|
||||
next_id: AtomicU64::new(1),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run(self) -> Result<(), N::Error> {
|
||||
let (my_addr, my_port) = self.listener.local_addr();
|
||||
info!("Starting SOCKSv5 server on {}:{}", my_addr, my_port);
|
||||
let locked_network = Arc::new(Mutex::new(self.network));
|
||||
/// Start a server on the given address and port. This function returns when it has
|
||||
/// set up its listening socket, but spawns a separate task to actually wait for
|
||||
/// connections. You can query which ones are still active, or see which ones have
|
||||
/// failed, using some of the other functions for this structure.
|
||||
///
|
||||
/// If you don't care what port is assigned to this server, pass 0 in as the port
|
||||
/// number and one will be chosen for you by the OS.
|
||||
///
|
||||
pub async fn start<A: Send + Into<SOCKSv5Address>>(
|
||||
&self,
|
||||
addr: A,
|
||||
port: u16,
|
||||
) -> Result<JoinHandle<Result<(), std::io::Error>>, std::io::Error> {
|
||||
let listener = match addr.into() {
|
||||
SOCKSv5Address::IP4(x) => TcpListener::bind((x, port)).await?,
|
||||
SOCKSv5Address::IP6(x) => TcpListener::bind((x, port)).await?,
|
||||
SOCKSv5Address::Hostname(x) => TcpListener::bind((x, port)).await?,
|
||||
};
|
||||
|
||||
let sockaddr = listener.local_addr()?;
|
||||
tracing::info!(
|
||||
"Starting SOCKSv5 server on {}:{}",
|
||||
sockaddr.ip(),
|
||||
sockaddr.port()
|
||||
);
|
||||
|
||||
let second_life = self.clone();
|
||||
|
||||
Ok(tokio::task::spawn(async move {
|
||||
second_life.server_loop(listener).await
|
||||
}))
|
||||
}
|
||||
|
||||
/// Run the server loop for a particular listener. This routine will never actually
|
||||
/// return except in error conditions.
|
||||
async fn server_loop(self, listener: TcpListener) -> Result<(), std::io::Error> {
|
||||
let local_addr = listener.local_addr()?;
|
||||
loop {
|
||||
let (stream, their_addr, their_port) = self.listener.accept().await?;
|
||||
|
||||
trace!(
|
||||
"Initial accept of connection from {}:{}",
|
||||
their_addr,
|
||||
their_port
|
||||
let (socket, their_addr) = listener.accept().await?;
|
||||
let accepted_span = info_span!(
|
||||
"session",
|
||||
server_address=?local_addr,
|
||||
remote_address=?their_addr,
|
||||
auth_method=field::Empty,
|
||||
ident=field::Empty,
|
||||
);
|
||||
if let Some(checker) = &self.security_parameters.allow_connection {
|
||||
if !checker(&their_addr, their_port) {
|
||||
info!(
|
||||
"Rejecting attempted connection from {}:{}",
|
||||
their_addr, their_port
|
||||
);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
let params = self.security_parameters.clone();
|
||||
let network_mutex_copy = locked_network.clone();
|
||||
task::spawn(async move {
|
||||
match run_authentication(params, stream, their_addr.clone(), their_port).await {
|
||||
Ok(authed_stream) => {
|
||||
match run_main_loop(network_mutex_copy, authed_stream).await {
|
||||
Ok(_) => {}
|
||||
Err(e) => warn!("Failure in main loop: {}", e),
|
||||
accepted_span.in_scope(|| {
|
||||
// before we do anything of note, make sure this connection is cool. we don't want
|
||||
// to waste any resources (and certainly don't want to handle any data!) if this
|
||||
// isn't someone we want to accept connections from.
|
||||
if let Some(checker) = self.info.security_parameters.allow_connection {
|
||||
if !checker(&their_addr) {
|
||||
tracing::info!("Rejecting attempted connection from {}", their_addr,);
|
||||
}
|
||||
} else {
|
||||
// continue this work in another task. we could absolutely do this work here,
|
||||
// but just in case someone starts doing slow responses (or other nasty things),
|
||||
// we want to make sure that that doesn't slow down our ability to accept other
|
||||
// requests.
|
||||
let span_again = accepted_span.clone();
|
||||
let me_again = self.clone();
|
||||
tokio::task::spawn(async move {
|
||||
let session_identifier =
|
||||
me_again.info.next_id.fetch_add(1, Ordering::SeqCst);
|
||||
span_again.record("ident", &session_identifier);
|
||||
if let Err(e) = me_again
|
||||
.start_authentication(their_addr, socket)
|
||||
.instrument(span_again)
|
||||
.await
|
||||
{
|
||||
tracing::error!("{}: server handler failure: {}", their_addr, e);
|
||||
}
|
||||
Err(e) => warn!(
|
||||
"Failure running authentication from {}:{}: {}",
|
||||
their_addr, their_port, e
|
||||
),
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// Start the authentication phase of the SOCKS handshake. This may be very short, and
|
||||
/// is the first stage of handling a request. This will only really return on errors.
|
||||
async fn start_authentication(
|
||||
self,
|
||||
their_addr: SocketAddr,
|
||||
mut socket: TcpStream,
|
||||
) -> Result<(), SOCKSv5ServerError> {
|
||||
let greeting = ClientGreeting::read(&mut socket).await?;
|
||||
|
||||
match choose_authentication_method(
|
||||
&self.info.security_parameters,
|
||||
&greeting.acceptable_methods,
|
||||
) {
|
||||
// it's not us, it's you. (we're just going to say no.)
|
||||
None => {
|
||||
tracing::trace!(
|
||||
"{}: Failed to find acceptable authentication method.",
|
||||
their_addr,
|
||||
);
|
||||
let rejection_letter = ServerChoice::rejection();
|
||||
|
||||
rejection_letter.write(&mut socket).await?;
|
||||
socket.flush().await?;
|
||||
|
||||
Err(SOCKSv5ServerError::ItsNotUsItsYou)
|
||||
}
|
||||
|
||||
// the gold standard. great choice.
|
||||
Some(ChosenMethod::TLS(_converter)) => {
|
||||
Span::current().record("auth_method", &"TLS");
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
// well, I guess this is something?
|
||||
Some(ChosenMethod::Password(checker)) => {
|
||||
tracing::trace!(
|
||||
"{}: Choosing username/password for authentication.",
|
||||
their_addr,
|
||||
);
|
||||
let ok_lets_do_password =
|
||||
ServerChoice::option(AuthenticationMethod::UsernameAndPassword);
|
||||
ok_lets_do_password.write(&mut socket).await?;
|
||||
socket.flush().await?;
|
||||
|
||||
let their_info = ClientUsernamePassword::read(&mut socket).await?;
|
||||
if checker(&their_info.username, &their_info.password) {
|
||||
let its_all_good = ServerAuthResponse::success();
|
||||
its_all_good.write(&mut socket).await?;
|
||||
socket.flush().await?;
|
||||
Span::current().record("auth_method", &"password");
|
||||
self.choose_mode(socket, their_addr).await
|
||||
} else {
|
||||
let yeah_no = ServerAuthResponse::failure();
|
||||
yeah_no.write(&mut socket).await?;
|
||||
socket.flush().await?;
|
||||
Err(SOCKSv5ServerError::FailedUsernamePassword(
|
||||
their_info.username,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
// Um. I guess we're doing this unchecked. Yay?
|
||||
Some(ChosenMethod::None) => {
|
||||
let nothin_i_guess = ServerChoice::option(AuthenticationMethod::None);
|
||||
nothin_i_guess.write(&mut socket).await?;
|
||||
socket.flush().await?;
|
||||
Span::current().record("auth_method", &"unauthenticated");
|
||||
self.choose_mode(socket, their_addr).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Determine which of the modes we might want this particular connection to run
|
||||
/// in.
|
||||
async fn choose_mode(
|
||||
self,
|
||||
mut socket: TcpStream,
|
||||
their_addr: SocketAddr,
|
||||
) -> Result<(), SOCKSv5ServerError> {
|
||||
let ccr = ClientConnectionRequest::read(&mut socket).await?;
|
||||
match ccr.command_code {
|
||||
ClientConnectionCommand::AssociateUDPPort => {
|
||||
self.handle_udp_request(socket, their_addr, ccr).await?
|
||||
}
|
||||
ClientConnectionCommand::EstablishTCPStream => {
|
||||
self.handle_tcp_request(socket, ccr).await?
|
||||
}
|
||||
ClientConnectionCommand::EstablishTCPPortBinding => {
|
||||
self.handle_tcp_binding_request(socket, their_addr, ccr)
|
||||
.await?
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle UDP forwarding requests
|
||||
#[allow(unreachable_code)]
|
||||
async fn handle_udp_request(
|
||||
self,
|
||||
stream: TcpStream,
|
||||
their_addr: SocketAddr,
|
||||
ccr: ClientConnectionRequest,
|
||||
) -> Result<(), SOCKSv5ServerError> {
|
||||
let my_addr = stream.local_addr()?;
|
||||
tracing::info!(
|
||||
"[{}:{}] Handling UDP bind request from {}:{}, seeking to bind towards {}:{}",
|
||||
my_addr.ip(),
|
||||
my_addr.port(),
|
||||
their_addr.ip(),
|
||||
their_addr.port(),
|
||||
ccr.destination_address,
|
||||
ccr.destination_port
|
||||
);
|
||||
|
||||
let _socket = match ccr.destination_address.clone() {
|
||||
SOCKSv5Address::IP4(x) => UdpSocket::bind((x, ccr.destination_port)).await?,
|
||||
SOCKSv5Address::IP6(x) => UdpSocket::bind((x, ccr.destination_port)).await?,
|
||||
SOCKSv5Address::Hostname(x) => UdpSocket::bind((x, ccr.destination_port)).await?,
|
||||
};
|
||||
|
||||
// OK, it worked. In order to mitigate an infinitesimal chance of a race condition, we're
|
||||
// going to set up our forwarding tasks first, and then return the result to the user. (Note,
|
||||
// we'd have to be slightly more precious in order to ensure a lack of race conditions, as
|
||||
// the runtime could take forever to actually start these tasks, but I'm not ready to be
|
||||
// bothered by this, yet. FIXME.)
|
||||
unimplemented!();
|
||||
|
||||
// Cool; now we can get the result out to the user.
|
||||
let bound_address = _socket.local_addr()?;
|
||||
let response = ServerResponse {
|
||||
status: ServerResponseStatus::RequestGranted,
|
||||
bound_address: bound_address.ip().into(),
|
||||
bound_port: bound_address.port(),
|
||||
};
|
||||
|
||||
response.write(&mut stream).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle TCP forwarding requests
|
||||
async fn handle_tcp_request(
|
||||
self,
|
||||
mut stream: TcpStream,
|
||||
ccr: ClientConnectionRequest,
|
||||
) -> Result<(), SOCKSv5ServerError> {
|
||||
// Let the user know that we're maybe making progress
|
||||
tracing::info!(
|
||||
"Handling TCP forward request to {}:{}",
|
||||
ccr.destination_address,
|
||||
ccr.destination_port
|
||||
);
|
||||
|
||||
// OK, first thing's first: We need to actually connect to the server that the user
|
||||
// wants us to connect to.
|
||||
let outgoing_stream = match &ccr.destination_address {
|
||||
SOCKSv5Address::IP4(x) => TcpStream::connect((*x, ccr.destination_port)).await?,
|
||||
SOCKSv5Address::IP6(x) => TcpStream::connect((*x, ccr.destination_port)).await?,
|
||||
SOCKSv5Address::Hostname(x) => {
|
||||
TcpStream::connect((x.as_ref(), ccr.destination_port)).await?
|
||||
}
|
||||
};
|
||||
|
||||
let tcp_forwarding_span = info_span!(
|
||||
"tcp_forwarding",
|
||||
target_address=?ccr.destination_address,
|
||||
target_port=ccr.destination_port
|
||||
);
|
||||
|
||||
// Now, for whatever reason -- and this whole thing sent me down a garden path
|
||||
// in understanding how this whole protocol works -- we tell the user what address
|
||||
// and port we bound for that connection.
|
||||
let bound_address = outgoing_stream.local_addr()?;
|
||||
let response = ServerResponse {
|
||||
status: ServerResponseStatus::RequestGranted,
|
||||
bound_address: bound_address.ip().into(),
|
||||
bound_port: bound_address.port(),
|
||||
};
|
||||
response
|
||||
.write(&mut stream)
|
||||
.instrument(tcp_forwarding_span.clone())
|
||||
.await?;
|
||||
|
||||
// so now tie our streams together, and we're good to go
|
||||
tie_streams(stream, outgoing_stream)
|
||||
.instrument(tcp_forwarding_span)
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle TCP binding requests
|
||||
async fn handle_tcp_binding_request(
|
||||
self,
|
||||
mut stream: TcpStream,
|
||||
their_addr: SocketAddr,
|
||||
ccr: ClientConnectionRequest,
|
||||
) -> Result<(), SOCKSv5ServerError> {
|
||||
// Let the user know that we're maybe making progress
|
||||
let my_addr = stream.local_addr()?;
|
||||
tracing::info!(
|
||||
"[{}] Handling TCP bind request from {}, seeking to bind {}:{}",
|
||||
my_addr,
|
||||
their_addr,
|
||||
ccr.destination_address,
|
||||
ccr.destination_port
|
||||
);
|
||||
|
||||
// OK, we have to bind the darn socket first.
|
||||
let listener_port = match &their_addr {
|
||||
SocketAddr::V4(_) => TcpSocket::new_v4(),
|
||||
SocketAddr::V6(_) => TcpSocket::new_v6(),
|
||||
}?;
|
||||
// FIXME: Might want to bind on a particular interface, based on a
|
||||
// config flag, at some point.
|
||||
let listener = listener_port.listen(1)?;
|
||||
|
||||
// Tell them what we bound, just in case they want to inform anyone.
|
||||
let bound_address = listener.local_addr()?;
|
||||
let response = ServerResponse {
|
||||
status: ServerResponseStatus::RequestGranted,
|
||||
bound_address: bound_address.ip().into(),
|
||||
bound_port: bound_address.port(),
|
||||
};
|
||||
response.write(&mut stream).await?;
|
||||
|
||||
// Wait politely for someone to talk to us.
|
||||
let (other, other_addr) = listener.accept().await?;
|
||||
let info = ServerResponse {
|
||||
status: ServerResponseStatus::RequestGranted,
|
||||
bound_address: other_addr.ip().into(),
|
||||
bound_port: other_addr.port(),
|
||||
};
|
||||
info.write(&mut stream).await?;
|
||||
|
||||
tie_streams(stream, other).await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
async fn tie_streams(mut left: TcpStream, mut right: TcpStream) {
|
||||
let span_copy = Span::current();
|
||||
tracing::info!("linking forwarding streams");
|
||||
tokio::task::spawn(
|
||||
async move {
|
||||
match copy_bidirectional(&mut left, &mut right)
|
||||
.instrument(span_copy)
|
||||
.await
|
||||
{
|
||||
Ok((l2r, r2l)) => {
|
||||
tracing::info!(sent = l2r, received = r2l, "shutting down streams")
|
||||
}
|
||||
Err(e) => tracing::warn!("Linked streams shut down with error: {}", e),
|
||||
}
|
||||
}
|
||||
.instrument(Span::current()),
|
||||
);
|
||||
}
|
||||
|
||||
#[allow(clippy::upper_case_acronyms)]
|
||||
enum ChosenMethod {
|
||||
TLS(fn(GenericStream) -> Option<GenericStream>),
|
||||
TLS(fn() -> Option<()>),
|
||||
Password(fn(&str, &str) -> bool),
|
||||
None,
|
||||
}
|
||||
@@ -209,7 +514,7 @@ fn reasonable_auth_method_choices() {
|
||||
);
|
||||
|
||||
// OK, cool. If we have a TLS handler, that shouldn't actually make a difference.
|
||||
params.connect_tls = Some(|_| unimplemented!());
|
||||
params.connect_tls = Some(|| unimplemented!());
|
||||
assert_eq!(
|
||||
choose_authentication_method(¶ms, &client_suggestions).map(AuthenticationMethod::from),
|
||||
None
|
||||
@@ -222,7 +527,7 @@ fn reasonable_auth_method_choices() {
|
||||
None
|
||||
);
|
||||
// but if we have a handler, and they go for it, we use it.
|
||||
params.connect_tls = Some(|_| unimplemented!());
|
||||
params.connect_tls = Some(|| unimplemented!());
|
||||
assert_eq!(
|
||||
choose_authentication_method(¶ms, &client_suggestions).map(AuthenticationMethod::from),
|
||||
Some(AuthenticationMethod::SSL)
|
||||
@@ -242,212 +547,3 @@ fn reasonable_auth_method_choices() {
|
||||
Some(AuthenticationMethod::SSL)
|
||||
);
|
||||
}
|
||||
|
||||
async fn run_authentication(
|
||||
params: SecurityParameters,
|
||||
mut stream: GenericStream,
|
||||
addr: SOCKSv5Address,
|
||||
port: u16,
|
||||
) -> Result<GenericStream, AuthenticationError> {
|
||||
// before we do anything at all, we check to see if we just want to blindly reject
|
||||
// this connection, utterly and completely.
|
||||
if let Some(firewall_allows) = params.allow_connection {
|
||||
if !firewall_allows(&addr, port) {
|
||||
return Err(AuthenticationError::FirewallRejected(addr, port));
|
||||
}
|
||||
}
|
||||
|
||||
// OK, I guess we'll listen to you
|
||||
let greeting = ClientGreeting::read(&mut stream).await?;
|
||||
|
||||
match choose_authentication_method(¶ms, &greeting.acceptable_methods) {
|
||||
// it's not us, it's you
|
||||
None => {
|
||||
trace!("Failed to find acceptable authentication method.");
|
||||
let rejection_letter = ServerChoice::rejection();
|
||||
|
||||
rejection_letter.write(&mut stream).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
Err(AuthenticationError::ItsNotUsItsYou)
|
||||
}
|
||||
|
||||
// the gold standard. great choice.
|
||||
Some(ChosenMethod::TLS(converter)) => {
|
||||
trace!("Choosing TLS for authentication.");
|
||||
let lets_do_this = ServerChoice::option(AuthenticationMethod::SSL);
|
||||
lets_do_this.write(&mut stream).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
converter(stream).ok_or(AuthenticationError::FailedTLSHandshake)
|
||||
}
|
||||
|
||||
// well, I guess this is something?
|
||||
Some(ChosenMethod::Password(checker)) => {
|
||||
trace!("Choosing Username/Password for authentication.");
|
||||
let ok_lets_do_password =
|
||||
ServerChoice::option(AuthenticationMethod::UsernameAndPassword);
|
||||
ok_lets_do_password.write(&mut stream).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
let their_info = ClientUsernamePassword::read(&mut stream).await?;
|
||||
if checker(&their_info.username, &their_info.password) {
|
||||
let its_all_good = ServerAuthResponse::success();
|
||||
its_all_good.write(&mut stream).await?;
|
||||
stream.flush().await?;
|
||||
Ok(stream)
|
||||
} else {
|
||||
let yeah_no = ServerAuthResponse::failure();
|
||||
yeah_no.write(&mut stream).await?;
|
||||
stream.flush().await?;
|
||||
Err(AuthenticationError::FailedUsernamePassword(
|
||||
their_info.username,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
Some(ChosenMethod::None) => {
|
||||
trace!("Just skipping the whole authentication thing.");
|
||||
let nothin_i_guess = ServerChoice::option(AuthenticationMethod::None);
|
||||
nothin_i_guess.write(&mut stream).await?;
|
||||
stream.flush().await?;
|
||||
Ok(stream)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
enum ServerError {
|
||||
#[error("Error in deserialization: {0}")]
|
||||
DeserializationError(#[from] DeserializationError),
|
||||
#[error("Error in serialization: {0}")]
|
||||
SerializationError(#[from] SerializationError),
|
||||
}
|
||||
|
||||
async fn run_main_loop<N>(
|
||||
network: Arc<Mutex<N>>,
|
||||
mut stream: GenericStream,
|
||||
) -> Result<(), ServerError>
|
||||
where
|
||||
N: Networklike,
|
||||
N::Error: 'static,
|
||||
{
|
||||
loop {
|
||||
let ccr = ClientConnectionRequest::read(&mut stream).await?;
|
||||
|
||||
match ccr.command_code {
|
||||
ClientConnectionCommand::AssociateUDPPort => {}
|
||||
|
||||
ClientConnectionCommand::EstablishTCPPortBinding => {}
|
||||
|
||||
ClientConnectionCommand::EstablishTCPStream => {
|
||||
let target = format!("{}:{}", ccr.destination_address, ccr.destination_port);
|
||||
|
||||
info!(
|
||||
"Client requested connection to {}:{}",
|
||||
ccr.destination_address, ccr.destination_port
|
||||
);
|
||||
let connection_res = {
|
||||
let mut network = network.lock().await;
|
||||
network
|
||||
.connect(ccr.destination_address.clone(), ccr.destination_port)
|
||||
.await
|
||||
};
|
||||
let outgoing_stream = match connection_res {
|
||||
Ok(x) => x,
|
||||
Err(e) => {
|
||||
error!("Failed to connect to {}: {}", target, e);
|
||||
let response = ServerResponse::error(e);
|
||||
response.write(&mut stream).await?;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
trace!(
|
||||
"Connection established to {}:{}",
|
||||
ccr.destination_address,
|
||||
ccr.destination_port
|
||||
);
|
||||
|
||||
let incoming_res = {
|
||||
let mut network = network.lock().await;
|
||||
network.listen("127.0.0.1", 0).await
|
||||
};
|
||||
let incoming_listener = match incoming_res {
|
||||
Ok(x) => x,
|
||||
Err(e) => {
|
||||
error!("Failed to bind server port for new TCP stream: {}", e);
|
||||
let response = ServerResponse::error(e);
|
||||
response.write(&mut stream).await?;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let (bound_address, bound_port) = incoming_listener.local_addr();
|
||||
trace!(
|
||||
"Set up {}:{} to address request for {}:{}",
|
||||
bound_address,
|
||||
bound_port,
|
||||
ccr.destination_address,
|
||||
ccr.destination_port
|
||||
);
|
||||
|
||||
let response = ServerResponse {
|
||||
status: ServerResponseStatus::RequestGranted,
|
||||
bound_address,
|
||||
bound_port,
|
||||
};
|
||||
response.write(&mut stream).await?;
|
||||
|
||||
task::spawn(async move {
|
||||
let (incoming_stream, from_addr, from_port) = match incoming_listener
|
||||
.accept()
|
||||
.await
|
||||
{
|
||||
Err(e) => {
|
||||
error!("Miscellaneous error waiting for someone to connect for proxying: {}", e);
|
||||
return;
|
||||
}
|
||||
Ok(s) => s,
|
||||
};
|
||||
trace!(
|
||||
"Accepted connection from {}:{} to attach to {}:{}",
|
||||
from_addr,
|
||||
from_port,
|
||||
ccr.destination_address,
|
||||
ccr.destination_port
|
||||
);
|
||||
|
||||
let mut from_left = incoming_stream.clone();
|
||||
let mut from_right = outgoing_stream.clone();
|
||||
let mut to_left = incoming_stream;
|
||||
let mut to_right = outgoing_stream;
|
||||
let from = format!("{}:{}", from_addr, from_port);
|
||||
let to = format!("{}:{}", ccr.destination_address, ccr.destination_port);
|
||||
|
||||
task::spawn(async move {
|
||||
info!(
|
||||
"Spawned {}:{} >--> {}:{} task",
|
||||
from_addr, from_port, ccr.destination_address, ccr.destination_port
|
||||
);
|
||||
if let Err(e) = io::copy(&mut from_left, &mut to_right).await {
|
||||
warn!(
|
||||
"{}:{} >--> {}:{} connection failed with: {}",
|
||||
from_addr,
|
||||
from_port,
|
||||
ccr.destination_address,
|
||||
ccr.destination_port,
|
||||
e
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
task::spawn(async move {
|
||||
info!("Spawned {} <--< {} task", from, to);
|
||||
if let Err(e) = io::copy(&mut from_right, &mut to_left).await {
|
||||
warn!("{} <--< {} connection failed with: {}", from, to, e);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user