Switch to basic tokio; will expand later to arbitrary backends.
This commit is contained in:
14
Cargo.toml
14
Cargo.toml
@@ -8,14 +8,12 @@ edition = "2018"
|
|||||||
name = "async_socks5"
|
name = "async_socks5"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
async-std = { version = "1.9.0", features = ["attributes"] }
|
anyhow = "^1.0.57"
|
||||||
async-trait = "0.1.50"
|
proptest = "^1.0.0"
|
||||||
futures = "0.3.15"
|
thiserror = "^1.0.31"
|
||||||
log = "0.4.8"
|
tokio = { version = "^1", features = ["full"] }
|
||||||
proptest = "1.0.0"
|
tracing = "^0.1.34"
|
||||||
simplelog = "0.10.0"
|
|
||||||
thiserror = "1.0.24"
|
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
proptest = "1.0.0"
|
proptest = "1.0.0"
|
||||||
proptest-derive = "0.3.0"
|
proptest-derive = "0.3.0"
|
||||||
|
|||||||
174
src/address.rs
Normal file
174
src/address.rs
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
use crate::messages::string::{SOCKSv5String, SOCKSv5StringReadError, SOCKSv5StringWriteError};
|
||||||
|
use std::fmt;
|
||||||
|
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
||||||
|
use thiserror::Error;
|
||||||
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
|
||||||
|
pub enum SOCKSv5Address {
|
||||||
|
IP4(Ipv4Addr),
|
||||||
|
IP6(Ipv6Addr),
|
||||||
|
Hostname(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<IpAddr> for SOCKSv5Address {
|
||||||
|
fn from(x: IpAddr) -> SOCKSv5Address {
|
||||||
|
match x {
|
||||||
|
IpAddr::V4(a) => SOCKSv5Address::IP4(a),
|
||||||
|
IpAddr::V6(a) => SOCKSv5Address::IP6(a),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<Ipv4Addr> for SOCKSv5Address {
|
||||||
|
fn from(x: Ipv4Addr) -> SOCKSv5Address {
|
||||||
|
SOCKSv5Address::IP4(x)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<Ipv6Addr> for SOCKSv5Address {
|
||||||
|
fn from(x: Ipv6Addr) -> SOCKSv5Address {
|
||||||
|
SOCKSv5Address::IP6(x)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<SOCKSv5String> for SOCKSv5Address {
|
||||||
|
fn from(x: SOCKSv5String) -> SOCKSv5Address {
|
||||||
|
SOCKSv5Address::Hostname(x.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> From<&'a str> for SOCKSv5Address {
|
||||||
|
fn from(x: &str) -> SOCKSv5Address {
|
||||||
|
SOCKSv5Address::Hostname(x.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<String> for SOCKSv5Address {
|
||||||
|
fn from(x: String) -> SOCKSv5Address {
|
||||||
|
SOCKSv5Address::Hostname(x)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for SOCKSv5Address {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
|
match self {
|
||||||
|
SOCKSv5Address::IP4(a) => write!(f, "{}", a),
|
||||||
|
SOCKSv5Address::IP6(a) => write!(f, "{}", a),
|
||||||
|
SOCKSv5Address::Hostname(a) => write!(f, "{}", a),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
const HOSTNAME_REGEX: &str = "[a-zA-Z0-9_.]+";
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
use proptest::prelude::{any, prop_oneof, Arbitrary, BoxedStrategy, Strategy};
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
impl Arbitrary for SOCKSv5Address {
|
||||||
|
type Parameters = Option<u16>;
|
||||||
|
type Strategy = BoxedStrategy<Self>;
|
||||||
|
|
||||||
|
fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
|
||||||
|
let max_len = args.unwrap_or(32) as usize;
|
||||||
|
|
||||||
|
prop_oneof![
|
||||||
|
any::<Ipv4Addr>().prop_map(SOCKSv5Address::IP4),
|
||||||
|
any::<Ipv6Addr>().prop_map(SOCKSv5Address::IP6),
|
||||||
|
HOSTNAME_REGEX.prop_map(move |mut hostname| {
|
||||||
|
hostname.shrink_to(max_len);
|
||||||
|
SOCKSv5Address::Hostname(hostname)
|
||||||
|
}),
|
||||||
|
]
|
||||||
|
.boxed()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Error, PartialEq)]
|
||||||
|
pub enum SOCKSv5AddressReadError {
|
||||||
|
#[error("Bad address type {0} (expected 1, 3, or 4)")]
|
||||||
|
BadAddressType(u8),
|
||||||
|
#[error("Read buffer error: {0}")]
|
||||||
|
ReadError(String),
|
||||||
|
#[error(transparent)]
|
||||||
|
SOCKSv5StringError(#[from] SOCKSv5StringReadError),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<std::io::Error> for SOCKSv5AddressReadError {
|
||||||
|
fn from(x: std::io::Error) -> SOCKSv5AddressReadError {
|
||||||
|
SOCKSv5AddressReadError::ReadError(format!("{}", x))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Error, PartialEq)]
|
||||||
|
pub enum SOCKSv5AddressWriteError {
|
||||||
|
#[error(transparent)]
|
||||||
|
SOCKSv5StringError(#[from] SOCKSv5StringWriteError),
|
||||||
|
#[error("Write buffer error: {0}")]
|
||||||
|
WriteError(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<std::io::Error> for SOCKSv5AddressWriteError {
|
||||||
|
fn from(x: std::io::Error) -> SOCKSv5AddressWriteError {
|
||||||
|
SOCKSv5AddressWriteError::WriteError(format!("{}", x))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SOCKSv5Address {
|
||||||
|
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||||
|
r: &mut R,
|
||||||
|
) -> Result<Self, SOCKSv5AddressReadError> {
|
||||||
|
match r.read_u8().await? {
|
||||||
|
1 => {
|
||||||
|
let mut addr_buffer = [0; 4];
|
||||||
|
r.read_exact(&mut addr_buffer).await?;
|
||||||
|
let ip4 = Ipv4Addr::from(addr_buffer);
|
||||||
|
Ok(SOCKSv5Address::IP4(ip4))
|
||||||
|
}
|
||||||
|
|
||||||
|
3 => {
|
||||||
|
let string = SOCKSv5String::read(r).await?;
|
||||||
|
Ok(SOCKSv5Address::from(string))
|
||||||
|
}
|
||||||
|
|
||||||
|
4 => {
|
||||||
|
let mut addr_buffer = [0; 16];
|
||||||
|
r.read_exact(&mut addr_buffer).await?;
|
||||||
|
let ip6 = Ipv6Addr::from(addr_buffer);
|
||||||
|
Ok(SOCKSv5Address::IP6(ip6))
|
||||||
|
}
|
||||||
|
|
||||||
|
x => Err(SOCKSv5AddressReadError::BadAddressType(x)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||||
|
&self,
|
||||||
|
w: &mut W,
|
||||||
|
) -> Result<(), SOCKSv5AddressWriteError> {
|
||||||
|
match self {
|
||||||
|
SOCKSv5Address::IP4(x) => {
|
||||||
|
w.write_u8(1).await?;
|
||||||
|
w.write_all(&x.octets()).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
SOCKSv5Address::IP6(x) => {
|
||||||
|
w.write_u8(4).await?;
|
||||||
|
w.write_all(&x.octets()).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
SOCKSv5Address::Hostname(x) => {
|
||||||
|
w.write_u8(3).await?;
|
||||||
|
let string = SOCKSv5String::from(x.clone());
|
||||||
|
string.write(w).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
crate::standard_roundtrip!(socks_address_roundtrips, SOCKSv5Address);
|
||||||
@@ -1,36 +1,3 @@
|
|||||||
use async_socks5::network::Builtin;
|
fn main() -> Result<(), ()> {
|
||||||
use async_socks5::server::{SOCKSv5Server, SecurityParameters};
|
|
||||||
use async_std::io;
|
|
||||||
use futures::stream::StreamExt;
|
|
||||||
use simplelog::{ColorChoice, CombinedLogger, Config, LevelFilter, TermLogger, TerminalMode};
|
|
||||||
|
|
||||||
#[async_std::main]
|
|
||||||
async fn main() -> Result<(), io::Error> {
|
|
||||||
CombinedLogger::init(vec![TermLogger::new(
|
|
||||||
LevelFilter::Debug,
|
|
||||||
Config::default(),
|
|
||||||
TerminalMode::Mixed,
|
|
||||||
ColorChoice::Auto,
|
|
||||||
)])
|
|
||||||
.expect("Couldn't initialize logger");
|
|
||||||
|
|
||||||
let params = SecurityParameters {
|
|
||||||
allow_unauthenticated: true,
|
|
||||||
allow_connection: None,
|
|
||||||
check_password: None,
|
|
||||||
connect_tls: None,
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut server = SOCKSv5Server::new(Builtin::new(), params);
|
|
||||||
server.start("127.0.0.1", 9999).await?;
|
|
||||||
|
|
||||||
let mut responses = Box::pin(server.subserver_results());
|
|
||||||
|
|
||||||
while let Some(response) = responses.next().await {
|
|
||||||
if let Err(e) = response {
|
|
||||||
println!("Server failed with: {}", e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
175
src/client.rs
175
src/client.rs
@@ -1,57 +1,47 @@
|
|||||||
use crate::errors::{DeserializationError, SerializationError};
|
use crate::address::SOCKSv5Address;
|
||||||
use crate::messages::{
|
use crate::messages::{
|
||||||
AuthenticationMethod, ClientConnectionCommand, ClientConnectionRequest, ClientGreeting,
|
AuthenticationMethod, ClientConnectionCommand, ClientConnectionCommandWriteError,
|
||||||
ClientUsernamePassword, ServerAuthResponse, ServerChoice, ServerResponse, ServerResponseStatus,
|
ClientConnectionRequest, ClientGreeting, ClientGreetingWriteError, ClientUsernamePassword,
|
||||||
|
ClientUsernamePasswordWriteError, ServerAuthResponse, ServerAuthResponseReadError,
|
||||||
|
ServerChoice, ServerChoiceReadError, ServerResponse, ServerResponseReadError,
|
||||||
|
ServerResponseStatus,
|
||||||
};
|
};
|
||||||
use crate::network::datagram::GenericDatagramSocket;
|
use std::future::Future;
|
||||||
use crate::network::generic::{IntoErrorResponse, Networklike};
|
|
||||||
use crate::network::listener::GenericListener;
|
|
||||||
use crate::network::stream::GenericStream;
|
|
||||||
use crate::network::SOCKSv5Address;
|
|
||||||
use async_std::io;
|
|
||||||
use async_std::sync::{Arc, Mutex};
|
|
||||||
use async_trait::async_trait;
|
|
||||||
use futures::Future;
|
|
||||||
use log::{info, trace, warn};
|
|
||||||
use std::fmt::{Debug, Display};
|
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
use tokio::net::TcpStream;
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
pub enum SOCKSv5Error<E: Debug + Display> {
|
pub enum SOCKSv5ClientError {
|
||||||
#[error("SOCKSv5 serialization error: {0}")]
|
#[error("Underlying networking error: {0}")]
|
||||||
SerializationError(#[from] SerializationError),
|
NetworkingError(String),
|
||||||
#[error("SOCKSv5 deserialization error: {0}")]
|
#[error("Client greeting write error: {0}")]
|
||||||
DeserializationError(#[from] DeserializationError),
|
ClientWriteError(#[from] ClientGreetingWriteError),
|
||||||
#[error("No acceptable authentication methods available")]
|
#[error("Server choice error: {0}")]
|
||||||
NoAuthMethodsAllowed,
|
ServerChoiceError(#[from] ServerChoiceReadError),
|
||||||
|
#[error("Error writing credentials: {0}")]
|
||||||
|
CredentialWriteError(#[from] ClientUsernamePasswordWriteError),
|
||||||
|
#[error("Server auth read error: {0}")]
|
||||||
|
AuthResponseError(#[from] ServerAuthResponseReadError),
|
||||||
#[error("Authentication failed")]
|
#[error("Authentication failed")]
|
||||||
AuthenticationFailed,
|
AuthenticationFailed,
|
||||||
#[error("Server chose an unsupported authentication method ({0}")]
|
#[error("No authentication methods allowed")]
|
||||||
|
NoAuthMethodsAllowed,
|
||||||
|
#[error("Unsupported authentication method chosen ({0})")]
|
||||||
UnsupportedAuthMethodChosen(AuthenticationMethod),
|
UnsupportedAuthMethodChosen(AuthenticationMethod),
|
||||||
|
#[error("Client connection command write error: {0}")]
|
||||||
|
ClientCommandWriteError(#[from] ClientConnectionCommandWriteError),
|
||||||
#[error("Server said no: {0}")]
|
#[error("Server said no: {0}")]
|
||||||
ServerFailure(#[from] ServerResponseStatus),
|
ServerRejected(#[from] ServerResponseStatus),
|
||||||
#[error("Connection error: {0}")]
|
#[error("Server response read failure: {0}")]
|
||||||
ConnectionError(#[from] io::Error),
|
ServerResponseError(#[from] ServerResponseReadError),
|
||||||
#[error("Underlying network error: {0}")]
|
|
||||||
UnderlyingNetwork(E),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<E: Debug + Display> IntoErrorResponse for SOCKSv5Error<E> {
|
impl From<std::io::Error> for SOCKSv5ClientError {
|
||||||
fn into_response(&self) -> ServerResponseStatus {
|
fn from(x: std::io::Error) -> SOCKSv5ClientError {
|
||||||
match self {
|
SOCKSv5ClientError::NetworkingError(format!("{}", x))
|
||||||
SOCKSv5Error::ServerFailure(v) => v.clone(),
|
|
||||||
_ => ServerResponseStatus::GeneralFailure,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct SOCKSv5Client<N: Networklike + Sync> {
|
|
||||||
network: Arc<Mutex<N>>,
|
|
||||||
login_info: LoginInfo,
|
|
||||||
address: SOCKSv5Address,
|
|
||||||
port: u16,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct LoginInfo {
|
pub struct LoginInfo {
|
||||||
pub username_password: Option<UsernamePassword>,
|
pub username_password: Option<UsernamePassword>,
|
||||||
}
|
}
|
||||||
@@ -64,7 +54,7 @@ impl Default for LoginInfo {
|
|||||||
|
|
||||||
impl LoginInfo {
|
impl LoginInfo {
|
||||||
/// Generate an empty bit of login information.
|
/// Generate an empty bit of login information.
|
||||||
fn new() -> LoginInfo {
|
pub fn new() -> LoginInfo {
|
||||||
LoginInfo {
|
LoginInfo {
|
||||||
username_password: None,
|
username_password: None,
|
||||||
}
|
}
|
||||||
@@ -89,22 +79,24 @@ pub struct UsernamePassword {
|
|||||||
pub password: String,
|
pub password: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<N> SOCKSv5Client<N>
|
pub struct SOCKSv5Client {
|
||||||
where
|
login_info: LoginInfo,
|
||||||
N: Networklike + Sync,
|
address: SOCKSv5Address,
|
||||||
{
|
port: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SOCKSv5Client {
|
||||||
/// Create a new SOCKSv5 client connection over the given steam, using the given
|
/// Create a new SOCKSv5 client connection over the given steam, using the given
|
||||||
/// authentication information. As part of the process of building this object, we
|
/// authentication information. As part of the process of building this object, we
|
||||||
/// do a little test run to make sure that we can login effectively; this should save
|
/// do a little test run to make sure that we can login effectively; this should save
|
||||||
/// from *some* surprises later on. If you'd rather *not* do that, though, you can
|
/// from *some* surprises later on. If you'd rather *not* do that, though, you can
|
||||||
/// try `unchecked_new`.
|
/// try `unchecked_new`.
|
||||||
pub async fn new<A: Into<SOCKSv5Address>>(
|
pub async fn new<A: Into<SOCKSv5Address>>(
|
||||||
network: N,
|
|
||||||
login: LoginInfo,
|
login: LoginInfo,
|
||||||
server_addr: A,
|
server_addr: A,
|
||||||
server_port: u16,
|
server_port: u16,
|
||||||
) -> Result<Self, SOCKSv5Error<N::Error>> {
|
) -> Result<Self, SOCKSv5ClientError> {
|
||||||
let base_version = SOCKSv5Client::unchecked_new(network, login, server_addr, server_port);
|
let base_version = SOCKSv5Client::unchecked_new(login, server_addr, server_port);
|
||||||
let _ = base_version.start_session().await?;
|
let _ = base_version.start_session().await?;
|
||||||
Ok(base_version)
|
Ok(base_version)
|
||||||
}
|
}
|
||||||
@@ -113,13 +105,11 @@ where
|
|||||||
/// connection sequence at the expense of an increased possibility of an error
|
/// connection sequence at the expense of an increased possibility of an error
|
||||||
/// later on down the road.
|
/// later on down the road.
|
||||||
pub fn unchecked_new<A: Into<SOCKSv5Address>>(
|
pub fn unchecked_new<A: Into<SOCKSv5Address>>(
|
||||||
network: N,
|
|
||||||
login_info: LoginInfo,
|
login_info: LoginInfo,
|
||||||
address: A,
|
address: A,
|
||||||
port: u16,
|
port: u16,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
SOCKSv5Client {
|
SOCKSv5Client {
|
||||||
network: Arc::new(Mutex::new(network)),
|
|
||||||
login_info,
|
login_info,
|
||||||
address: address.into(),
|
address: address.into(),
|
||||||
port,
|
port,
|
||||||
@@ -128,17 +118,17 @@ where
|
|||||||
|
|
||||||
/// This runs the connection and negotiates login, as required, and then returns
|
/// This runs the connection and negotiates login, as required, and then returns
|
||||||
/// the stream the caller should use to do ... whatever it wants to do.
|
/// the stream the caller should use to do ... whatever it wants to do.
|
||||||
async fn start_session(&self) -> Result<GenericStream, SOCKSv5Error<N::Error>> {
|
async fn start_session(&self) -> Result<TcpStream, SOCKSv5ClientError> {
|
||||||
// create the initial stream
|
// create the initial stream
|
||||||
let mut stream = {
|
let mut stream = match &self.address {
|
||||||
let mut network = self.network.lock().await;
|
SOCKSv5Address::IP4(x) => TcpStream::connect((*x, self.port)).await?,
|
||||||
network.connect(self.address.clone(), self.port).await
|
SOCKSv5Address::IP6(x) => TcpStream::connect((*x, self.port)).await?,
|
||||||
}
|
SOCKSv5Address::Hostname(x) => TcpStream::connect((x.as_ref(), self.port)).await?,
|
||||||
.map_err(SOCKSv5Error::UnderlyingNetwork)?;
|
};
|
||||||
|
|
||||||
// compute how we can log in
|
// compute how we can log in
|
||||||
let acceptable_methods = self.login_info.acceptable_methods();
|
let acceptable_methods = self.login_info.acceptable_methods();
|
||||||
trace!(
|
tracing::trace!(
|
||||||
"Computed acceptable methods -- {:?} -- sending client greeting.",
|
"Computed acceptable methods -- {:?} -- sending client greeting.",
|
||||||
acceptable_methods
|
acceptable_methods
|
||||||
);
|
);
|
||||||
@@ -146,9 +136,9 @@ where
|
|||||||
// Negotiate with the server. Well. "Negotiate."
|
// Negotiate with the server. Well. "Negotiate."
|
||||||
let client_greeting = ClientGreeting { acceptable_methods };
|
let client_greeting = ClientGreeting { acceptable_methods };
|
||||||
client_greeting.write(&mut stream).await?;
|
client_greeting.write(&mut stream).await?;
|
||||||
trace!("Write client greeting, waiting for server's choice.");
|
tracing::trace!("Write client greeting, waiting for server's choice.");
|
||||||
let server_choice = ServerChoice::read(&mut stream).await?;
|
let server_choice = ServerChoice::read(&mut stream).await?;
|
||||||
trace!("Received server's choice: {}", server_choice.chosen_method);
|
tracing::trace!("Received server's choice: {}", server_choice.chosen_method);
|
||||||
|
|
||||||
// Let's do it!
|
// Let's do it!
|
||||||
match server_choice.chosen_method {
|
match server_choice.chosen_method {
|
||||||
@@ -158,30 +148,32 @@ where
|
|||||||
let (username, password) = if let Some(ref linfo) =
|
let (username, password) = if let Some(ref linfo) =
|
||||||
self.login_info.username_password
|
self.login_info.username_password
|
||||||
{
|
{
|
||||||
trace!("Server requested username/password, getting data from login info.");
|
tracing::trace!(
|
||||||
|
"Server requested username/password, getting data from login info."
|
||||||
|
);
|
||||||
(linfo.username.clone(), linfo.password.clone())
|
(linfo.username.clone(), linfo.password.clone())
|
||||||
} else {
|
} else {
|
||||||
warn!("Server requested username/password, but we weren't provided one. Very weird.");
|
tracing::warn!("Server requested username/password, but we weren't provided one. Very weird.");
|
||||||
("".to_string(), "".to_string())
|
("".to_string(), "".to_string())
|
||||||
};
|
};
|
||||||
|
|
||||||
let auth_request = ClientUsernamePassword { username, password };
|
let auth_request = ClientUsernamePassword { username, password };
|
||||||
|
|
||||||
trace!("Writing password information.");
|
tracing::trace!("Writing password information.");
|
||||||
auth_request.write(&mut stream).await?;
|
auth_request.write(&mut stream).await?;
|
||||||
let server_response = ServerAuthResponse::read(&mut stream).await?;
|
let server_response = ServerAuthResponse::read(&mut stream).await?;
|
||||||
trace!("Got server response: {}", server_response.success);
|
tracing::trace!("Got server response: {}", server_response.success);
|
||||||
|
|
||||||
if !server_response.success {
|
if !server_response.success {
|
||||||
return Err(SOCKSv5Error::AuthenticationFailed);
|
return Err(SOCKSv5ClientError::AuthenticationFailed);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
AuthenticationMethod::NoAcceptableMethods => {
|
AuthenticationMethod::NoAcceptableMethods => {
|
||||||
return Err(SOCKSv5Error::NoAuthMethodsAllowed)
|
return Err(SOCKSv5ClientError::NoAuthMethodsAllowed)
|
||||||
}
|
}
|
||||||
|
|
||||||
x => return Err(SOCKSv5Error::UnsupportedAuthMethodChosen(x)),
|
x => return Err(SOCKSv5ClientError::UnsupportedAuthMethodChosen(x)),
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(stream)
|
Ok(stream)
|
||||||
@@ -193,12 +185,12 @@ where
|
|||||||
/// person to listen on. So this function takes an async function, which it
|
/// person to listen on. So this function takes an async function, which it
|
||||||
/// will pass this information to once it has it. It's up to that function,
|
/// will pass this information to once it has it. It's up to that function,
|
||||||
/// then, to communicate this to its peer.
|
/// then, to communicate this to its peer.
|
||||||
pub async fn remote_listen<A, Fut: Future<Output = Result<(), SOCKSv5Error<N::Error>>>>(
|
pub async fn remote_listen<A, Fut: Future<Output = Result<(), SOCKSv5ClientError>>>(
|
||||||
self,
|
self,
|
||||||
addr: A,
|
addr: A,
|
||||||
port: u16,
|
port: u16,
|
||||||
callback: impl FnOnce(SOCKSv5Address, u16) -> Fut,
|
callback: impl FnOnce(SOCKSv5Address, u16) -> Fut,
|
||||||
) -> Result<(SOCKSv5Address, u16, GenericStream), SOCKSv5Error<N::Error>>
|
) -> Result<(SOCKSv5Address, u16, TcpStream), SOCKSv5ClientError>
|
||||||
where
|
where
|
||||||
A: Into<SOCKSv5Address>,
|
A: Into<SOCKSv5Address>,
|
||||||
{
|
{
|
||||||
@@ -217,9 +209,12 @@ where
|
|||||||
return Err(initial_response.status.into());
|
return Err(initial_response.status.into());
|
||||||
}
|
}
|
||||||
|
|
||||||
info!(
|
tracing::info!(
|
||||||
"Proxy port binding of {}:{} established; server listening on {}:{}",
|
"Proxy port binding of {}:{} established; server listening on {}:{}",
|
||||||
target, port, initial_response.bound_address, initial_response.bound_port
|
target,
|
||||||
|
port,
|
||||||
|
initial_response.bound_address,
|
||||||
|
initial_response.bound_port
|
||||||
);
|
);
|
||||||
|
|
||||||
callback(initial_response.bound_address, initial_response.bound_port).await?;
|
callback(initial_response.bound_address, initial_response.bound_port).await?;
|
||||||
@@ -229,9 +224,10 @@ where
|
|||||||
return Err(secondary_response.status.into());
|
return Err(secondary_response.status.into());
|
||||||
}
|
}
|
||||||
|
|
||||||
info!(
|
tracing::info!(
|
||||||
"Proxy bind got a connection from {}:{}",
|
"Proxy bind got a connection from {}:{}",
|
||||||
secondary_response.bound_address, secondary_response.bound_port
|
secondary_response.bound_address,
|
||||||
|
secondary_response.bound_port
|
||||||
);
|
);
|
||||||
|
|
||||||
Ok((
|
Ok((
|
||||||
@@ -240,20 +236,12 @@ where
|
|||||||
stream,
|
stream,
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
pub async fn connect<A: Send + Into<SOCKSv5Address>>(
|
||||||
impl<N> Networklike for SOCKSv5Client<N>
|
|
||||||
where
|
|
||||||
N: Networklike + Sync + Send,
|
|
||||||
{
|
|
||||||
type Error = SOCKSv5Error<N::Error>;
|
|
||||||
|
|
||||||
async fn connect<A: Send + Into<SOCKSv5Address>>(
|
|
||||||
&mut self,
|
&mut self,
|
||||||
addr: A,
|
addr: A,
|
||||||
port: u16,
|
port: u16,
|
||||||
) -> Result<GenericStream, Self::Error> {
|
) -> Result<TcpStream, SOCKSv5ClientError> {
|
||||||
let mut stream = self.start_session().await?;
|
let mut stream = self.start_session().await?;
|
||||||
let target = addr.into();
|
let target = addr.into();
|
||||||
|
|
||||||
@@ -266,29 +254,16 @@ where
|
|||||||
let response = ServerResponse::read(&mut stream).await?;
|
let response = ServerResponse::read(&mut stream).await?;
|
||||||
|
|
||||||
if response.status == ServerResponseStatus::RequestGranted {
|
if response.status == ServerResponseStatus::RequestGranted {
|
||||||
info!(
|
tracing::info!(
|
||||||
"Proxy connection to {}:{} established; server is using {}:{}",
|
"Proxy connection to {}:{} established; server is using {}:{}",
|
||||||
target, port, response.bound_address, response.bound_port
|
target,
|
||||||
|
port,
|
||||||
|
response.bound_address,
|
||||||
|
response.bound_port
|
||||||
);
|
);
|
||||||
Ok(stream)
|
Ok(stream)
|
||||||
} else {
|
} else {
|
||||||
Err(response.status.into())
|
Err(response.status.into())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn listen<A: Send + Into<SOCKSv5Address>>(
|
|
||||||
&mut self,
|
|
||||||
_addr: A,
|
|
||||||
_port: u16,
|
|
||||||
) -> Result<GenericListener<Self::Error>, Self::Error> {
|
|
||||||
unimplemented!()
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn bind<A: Send + Into<SOCKSv5Address>>(
|
|
||||||
&mut self,
|
|
||||||
_addr: A,
|
|
||||||
_port: u16,
|
|
||||||
) -> Result<GenericDatagramSocket<Self::Error>, Self::Error> {
|
|
||||||
unimplemented!()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
217
src/errors.rs
217
src/errors.rs
@@ -1,217 +0,0 @@
|
|||||||
use std::io;
|
|
||||||
use std::string::FromUtf8Error;
|
|
||||||
use thiserror::Error;
|
|
||||||
|
|
||||||
use crate::network::SOCKSv5Address;
|
|
||||||
|
|
||||||
/// All the errors that can pop up when trying to turn raw bytes into SOCKSv5
|
|
||||||
/// messages.
|
|
||||||
#[derive(Error, Debug)]
|
|
||||||
pub enum DeserializationError {
|
|
||||||
#[error("Invalid protocol version for packet ({1} is not {0}!)")]
|
|
||||||
InvalidVersion(u8, u8),
|
|
||||||
#[error("Not enough data found")]
|
|
||||||
NotEnoughData,
|
|
||||||
#[error("Ooops! Found an empty string where I shouldn't")]
|
|
||||||
InvalidEmptyString,
|
|
||||||
#[error("IO error: {0}")]
|
|
||||||
IOError(#[from] io::Error),
|
|
||||||
#[error("SOCKS authentication format parse error: {0}")]
|
|
||||||
AuthenticationMethodError(#[from] AuthenticationDeserializationError),
|
|
||||||
#[error("Error converting from UTF-8: {0}")]
|
|
||||||
UTF8Error(#[from] FromUtf8Error),
|
|
||||||
#[error("Invalid address type; wanted 1, 3, or 4, got {0}")]
|
|
||||||
InvalidAddressType(u8),
|
|
||||||
#[error("Invalid client command {0}; expected 1, 2, or 3")]
|
|
||||||
InvalidClientCommand(u8),
|
|
||||||
#[error("Invalid server status {0}; expected 0-8")]
|
|
||||||
InvalidServerResponse(u8),
|
|
||||||
#[error("Invalid byte in reserved byte ({0})")]
|
|
||||||
InvalidReservedByte(u8),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn des_error_reasonable_equals() {
|
|
||||||
let invalid_version1 = DeserializationError::InvalidVersion(1, 2);
|
|
||||||
let invalid_version2 = DeserializationError::InvalidVersion(1, 2);
|
|
||||||
assert_eq!(invalid_version1, invalid_version2);
|
|
||||||
|
|
||||||
let not_enough1 = DeserializationError::NotEnoughData;
|
|
||||||
let not_enough2 = DeserializationError::NotEnoughData;
|
|
||||||
assert_eq!(not_enough1, not_enough2);
|
|
||||||
|
|
||||||
let invalid_empty1 = DeserializationError::InvalidEmptyString;
|
|
||||||
let invalid_empty2 = DeserializationError::InvalidEmptyString;
|
|
||||||
assert_eq!(invalid_empty1, invalid_empty2);
|
|
||||||
|
|
||||||
let auth_method1 = DeserializationError::AuthenticationMethodError(
|
|
||||||
AuthenticationDeserializationError::NoDataFound,
|
|
||||||
);
|
|
||||||
let auth_method2 = DeserializationError::AuthenticationMethodError(
|
|
||||||
AuthenticationDeserializationError::NoDataFound,
|
|
||||||
);
|
|
||||||
assert_eq!(auth_method1, auth_method2);
|
|
||||||
|
|
||||||
let utf8a = DeserializationError::UTF8Error(String::from_utf8(vec![0, 159]).unwrap_err());
|
|
||||||
let utf8b = DeserializationError::UTF8Error(String::from_utf8(vec![0, 159]).unwrap_err());
|
|
||||||
assert_eq!(utf8a, utf8b);
|
|
||||||
|
|
||||||
let invalid_address1 = DeserializationError::InvalidAddressType(3);
|
|
||||||
let invalid_address2 = DeserializationError::InvalidAddressType(3);
|
|
||||||
assert_eq!(invalid_address1, invalid_address2);
|
|
||||||
|
|
||||||
let invalid_client_cmd1 = DeserializationError::InvalidClientCommand(32);
|
|
||||||
let invalid_client_cmd2 = DeserializationError::InvalidClientCommand(32);
|
|
||||||
assert_eq!(invalid_client_cmd1, invalid_client_cmd2);
|
|
||||||
|
|
||||||
let invalid_server_resp1 = DeserializationError::InvalidServerResponse(42);
|
|
||||||
let invalid_server_resp2 = DeserializationError::InvalidServerResponse(42);
|
|
||||||
assert_eq!(invalid_server_resp1, invalid_server_resp2);
|
|
||||||
|
|
||||||
assert_ne!(invalid_version1, invalid_address1);
|
|
||||||
assert_ne!(not_enough1, invalid_empty1);
|
|
||||||
assert_ne!(auth_method1, invalid_client_cmd1);
|
|
||||||
assert_ne!(utf8a, invalid_server_resp1);
|
|
||||||
}
|
|
||||||
|
|
||||||
impl PartialEq for DeserializationError {
|
|
||||||
fn eq(&self, other: &DeserializationError) -> bool {
|
|
||||||
match (self, other) {
|
|
||||||
(
|
|
||||||
&DeserializationError::InvalidVersion(a, b),
|
|
||||||
&DeserializationError::InvalidVersion(x, y),
|
|
||||||
) => (a == x) && (b == y),
|
|
||||||
(&DeserializationError::NotEnoughData, &DeserializationError::NotEnoughData) => true,
|
|
||||||
(
|
|
||||||
&DeserializationError::InvalidEmptyString,
|
|
||||||
&DeserializationError::InvalidEmptyString,
|
|
||||||
) => true,
|
|
||||||
(
|
|
||||||
&DeserializationError::AuthenticationMethodError(ref a),
|
|
||||||
&DeserializationError::AuthenticationMethodError(ref b),
|
|
||||||
) => a == b,
|
|
||||||
(&DeserializationError::UTF8Error(ref a), &DeserializationError::UTF8Error(ref b)) => {
|
|
||||||
a == b
|
|
||||||
}
|
|
||||||
(
|
|
||||||
&DeserializationError::InvalidAddressType(a),
|
|
||||||
&DeserializationError::InvalidAddressType(b),
|
|
||||||
) => a == b,
|
|
||||||
(
|
|
||||||
&DeserializationError::InvalidClientCommand(a),
|
|
||||||
&DeserializationError::InvalidClientCommand(b),
|
|
||||||
) => a == b,
|
|
||||||
(
|
|
||||||
&DeserializationError::InvalidServerResponse(a),
|
|
||||||
&DeserializationError::InvalidServerResponse(b),
|
|
||||||
) => a == b,
|
|
||||||
(
|
|
||||||
&DeserializationError::InvalidReservedByte(a),
|
|
||||||
&DeserializationError::InvalidReservedByte(b),
|
|
||||||
) => a == b,
|
|
||||||
(_, _) => false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// All the errors that can occur trying to turn SOCKSv5 message structures
|
|
||||||
/// into raw bytes. There's a few places that the message structures allow
|
|
||||||
/// for information that can't be serialized; often, you have to be careful
|
|
||||||
/// about how long your strings are ...
|
|
||||||
#[derive(Error, Debug)]
|
|
||||||
pub enum SerializationError {
|
|
||||||
#[error("Too many authentication methods for serialization ({0} > 255)")]
|
|
||||||
TooManyAuthMethods(usize),
|
|
||||||
#[error("Invalid length for string: {0}")]
|
|
||||||
InvalidStringLength(String),
|
|
||||||
#[error("IO error: {0}")]
|
|
||||||
IOError(#[from] io::Error),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn ser_err_reasonable_equals() {
|
|
||||||
let too_many1 = SerializationError::TooManyAuthMethods(512);
|
|
||||||
let too_many2 = SerializationError::TooManyAuthMethods(512);
|
|
||||||
assert_eq!(too_many1, too_many2);
|
|
||||||
|
|
||||||
let invalid_str1 = SerializationError::InvalidStringLength("Whoopsy!".to_string());
|
|
||||||
let invalid_str2 = SerializationError::InvalidStringLength("Whoopsy!".to_string());
|
|
||||||
assert_eq!(invalid_str1, invalid_str2);
|
|
||||||
|
|
||||||
assert_ne!(too_many1, invalid_str1);
|
|
||||||
}
|
|
||||||
|
|
||||||
impl PartialEq for SerializationError {
|
|
||||||
fn eq(&self, other: &SerializationError) -> bool {
|
|
||||||
match (self, other) {
|
|
||||||
(
|
|
||||||
&SerializationError::TooManyAuthMethods(a),
|
|
||||||
&SerializationError::TooManyAuthMethods(b),
|
|
||||||
) => a == b,
|
|
||||||
(
|
|
||||||
&SerializationError::InvalidStringLength(ref a),
|
|
||||||
&SerializationError::InvalidStringLength(ref b),
|
|
||||||
) => a == b,
|
|
||||||
(_, _) => false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Error, Debug)]
|
|
||||||
pub enum AuthenticationDeserializationError {
|
|
||||||
#[error("No data found deserializing SOCKS authentication type")]
|
|
||||||
NoDataFound,
|
|
||||||
#[error("Invalid authentication type value: {0}")]
|
|
||||||
InvalidAuthenticationByte(u8),
|
|
||||||
#[error("IO error reading SOCKS authentication type: {0}")]
|
|
||||||
IOError(#[from] io::Error),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn auth_des_err_reasonable_equals() {
|
|
||||||
let no_data1 = AuthenticationDeserializationError::NoDataFound;
|
|
||||||
let no_data2 = AuthenticationDeserializationError::NoDataFound;
|
|
||||||
assert_eq!(no_data1, no_data2);
|
|
||||||
|
|
||||||
let invalid_auth1 = AuthenticationDeserializationError::InvalidAuthenticationByte(39);
|
|
||||||
let invalid_auth2 = AuthenticationDeserializationError::InvalidAuthenticationByte(39);
|
|
||||||
assert_eq!(invalid_auth1, invalid_auth2);
|
|
||||||
|
|
||||||
assert_ne!(no_data1, invalid_auth1);
|
|
||||||
}
|
|
||||||
|
|
||||||
impl PartialEq for AuthenticationDeserializationError {
|
|
||||||
fn eq(&self, other: &AuthenticationDeserializationError) -> bool {
|
|
||||||
match (self, other) {
|
|
||||||
(
|
|
||||||
&AuthenticationDeserializationError::NoDataFound,
|
|
||||||
&AuthenticationDeserializationError::NoDataFound,
|
|
||||||
) => true,
|
|
||||||
(
|
|
||||||
&AuthenticationDeserializationError::InvalidAuthenticationByte(x),
|
|
||||||
&AuthenticationDeserializationError::InvalidAuthenticationByte(y),
|
|
||||||
) => x == y,
|
|
||||||
(_, _) => false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// The errors that can happen, as a server, when we're negotiating the start
|
|
||||||
/// of a SOCKS session.
|
|
||||||
#[derive(Debug, Error)]
|
|
||||||
pub enum AuthenticationError {
|
|
||||||
#[error("Firewall disallowed connection from {0}:{1}")]
|
|
||||||
FirewallRejected(SOCKSv5Address, u16),
|
|
||||||
#[error("Could not agree on an authentication method with the client")]
|
|
||||||
ItsNotUsItsYou,
|
|
||||||
#[error("Failure in serializing response message: {0}")]
|
|
||||||
SerializationError(#[from] SerializationError),
|
|
||||||
#[error("Failed TLS handshake")]
|
|
||||||
FailedTLSHandshake,
|
|
||||||
#[error("IO error writing response message: {0}")]
|
|
||||||
IOError(#[from] io::Error),
|
|
||||||
#[error("Failure in reading client message: {0}")]
|
|
||||||
DeserializationError(#[from] DeserializationError),
|
|
||||||
#[error("Username/password check failed (username was {0})")]
|
|
||||||
FailedUsernamePassword(String),
|
|
||||||
}
|
|
||||||
298
src/lib.rs
298
src/lib.rs
@@ -1,192 +1,172 @@
|
|||||||
pub mod client;
|
pub mod client;
|
||||||
pub mod errors;
|
|
||||||
pub mod messages;
|
|
||||||
pub mod network;
|
|
||||||
mod serialize;
|
|
||||||
pub mod server;
|
pub mod server;
|
||||||
|
|
||||||
|
mod address;
|
||||||
|
mod messages;
|
||||||
|
mod security_parameters;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test {
|
mod test {
|
||||||
|
use crate::address::SOCKSv5Address;
|
||||||
use crate::client::{LoginInfo, SOCKSv5Client, UsernamePassword};
|
use crate::client::{LoginInfo, SOCKSv5Client, UsernamePassword};
|
||||||
use crate::network::generic::Networklike;
|
use crate::security_parameters::SecurityParameters;
|
||||||
use crate::network::listener::Listenerlike;
|
use crate::server::SOCKSv5Server;
|
||||||
use crate::network::testing::TestingStack;
|
use std::io;
|
||||||
use crate::server::{SOCKSv5Server, SecurityParameters};
|
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||||
use async_std::channel::bounded;
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||||
use async_std::io::prelude::WriteExt;
|
use tokio::net::{TcpSocket, TcpStream};
|
||||||
use async_std::task;
|
use tokio::sync::oneshot;
|
||||||
use futures::AsyncReadExt;
|
use tokio::task;
|
||||||
|
|
||||||
#[test]
|
#[tokio::test]
|
||||||
fn unrestricted_login() {
|
async fn unrestricted_login() {
|
||||||
task::block_on(async {
|
// generate the server
|
||||||
let network_stack = TestingStack::default();
|
let security_parameters = SecurityParameters::unrestricted();
|
||||||
|
let server = SOCKSv5Server::new(security_parameters);
|
||||||
|
server.start("localhost", 9999).await.unwrap();
|
||||||
|
|
||||||
// generate the server
|
let login_info = LoginInfo {
|
||||||
let security_parameters = SecurityParameters::unrestricted();
|
username_password: None,
|
||||||
let server = SOCKSv5Server::new(network_stack.clone(), security_parameters);
|
};
|
||||||
server.start("localhost", 9999).await.unwrap();
|
let client = SOCKSv5Client::new(login_info, "localhost", 9999).await;
|
||||||
|
|
||||||
let login_info = LoginInfo {
|
assert!(client.is_ok());
|
||||||
username_password: None,
|
|
||||||
};
|
|
||||||
let client = SOCKSv5Client::new(network_stack, login_info, "localhost", 9999).await;
|
|
||||||
|
|
||||||
assert!(client.is_ok());
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[tokio::test]
|
||||||
fn disallow_unrestricted() {
|
async fn disallow_unrestricted() {
|
||||||
task::block_on(async {
|
// generate the server
|
||||||
let network_stack = TestingStack::default();
|
let mut security_parameters = SecurityParameters::unrestricted();
|
||||||
|
security_parameters.allow_unauthenticated = false;
|
||||||
|
let server = SOCKSv5Server::new(security_parameters);
|
||||||
|
server.start("localhost", 9998).await.unwrap();
|
||||||
|
|
||||||
// generate the server
|
let login_info = LoginInfo::default();
|
||||||
let mut security_parameters = SecurityParameters::unrestricted();
|
let client = SOCKSv5Client::new(login_info, "localhost", 9998).await;
|
||||||
security_parameters.allow_unauthenticated = false;
|
|
||||||
let server = SOCKSv5Server::new(network_stack.clone(), security_parameters);
|
|
||||||
server.start("localhost", 9998).await.unwrap();
|
|
||||||
|
|
||||||
let login_info = LoginInfo {
|
assert!(client.is_err());
|
||||||
username_password: None,
|
|
||||||
};
|
|
||||||
let client = SOCKSv5Client::new(network_stack, login_info, "localhost", 9998).await;
|
|
||||||
|
|
||||||
assert!(client.is_err());
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[tokio::test]
|
||||||
fn password_checks() {
|
async fn password_checks() {
|
||||||
task::block_on(async {
|
// generate the server
|
||||||
let network_stack = TestingStack::default();
|
let security_parameters = SecurityParameters {
|
||||||
|
allow_unauthenticated: false,
|
||||||
|
allow_connection: None,
|
||||||
|
connect_tls: None,
|
||||||
|
check_password: Some(|username, password| {
|
||||||
|
username == "awick" && password == "password"
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
let server = SOCKSv5Server::new(security_parameters);
|
||||||
|
server.start("localhost", 9997).await.unwrap();
|
||||||
|
|
||||||
// generate the server
|
// try the positive side
|
||||||
let security_parameters = SecurityParameters {
|
let login_info = LoginInfo {
|
||||||
allow_unauthenticated: false,
|
username_password: Some(UsernamePassword {
|
||||||
allow_connection: None,
|
username: "awick".to_string(),
|
||||||
connect_tls: None,
|
password: "password".to_string(),
|
||||||
check_password: Some(|username, password| {
|
}),
|
||||||
username == "awick" && password == "password"
|
};
|
||||||
}),
|
let client = SOCKSv5Client::new(login_info, "localhost", 9997).await;
|
||||||
};
|
assert!(client.is_ok());
|
||||||
let server = SOCKSv5Server::new(network_stack.clone(), security_parameters);
|
|
||||||
server.start("localhost", 9997).await.unwrap();
|
|
||||||
|
|
||||||
// try the positive side
|
// try the negative side
|
||||||
let login_info = LoginInfo {
|
let login_info = LoginInfo {
|
||||||
username_password: Some(UsernamePassword {
|
username_password: Some(UsernamePassword {
|
||||||
username: "awick".to_string(),
|
username: "adamw".to_string(),
|
||||||
password: "password".to_string(),
|
password: "password".to_string(),
|
||||||
}),
|
}),
|
||||||
};
|
};
|
||||||
let client =
|
let client = SOCKSv5Client::new(login_info, "localhost", 9997).await;
|
||||||
SOCKSv5Client::new(network_stack.clone(), login_info, "localhost", 9997).await;
|
assert!(client.is_err());
|
||||||
assert!(client.is_ok());
|
|
||||||
|
|
||||||
// try the negative side
|
|
||||||
let login_info = LoginInfo {
|
|
||||||
username_password: Some(UsernamePassword {
|
|
||||||
username: "adamw".to_string(),
|
|
||||||
password: "password".to_string(),
|
|
||||||
}),
|
|
||||||
};
|
|
||||||
let client = SOCKSv5Client::new(network_stack, login_info, "localhost", 9997).await;
|
|
||||||
assert!(client.is_err());
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[tokio::test]
|
||||||
fn firewall_blocks() {
|
async fn firewall_blocks() {
|
||||||
task::block_on(async {
|
// generate the server
|
||||||
let network_stack = TestingStack::default();
|
let mut security_parameters = SecurityParameters::unrestricted();
|
||||||
|
security_parameters.allow_connection = Some(|_| false);
|
||||||
|
let server = SOCKSv5Server::new(security_parameters);
|
||||||
|
server.start("localhost", 9996).await.unwrap();
|
||||||
|
|
||||||
// generate the server
|
let login_info = LoginInfo::new();
|
||||||
let mut security_parameters = SecurityParameters::unrestricted();
|
let client = SOCKSv5Client::new(login_info, "localhost", 9996).await;
|
||||||
security_parameters.allow_connection = Some(|_, _| false);
|
|
||||||
let server = SOCKSv5Server::new(network_stack.clone(), security_parameters);
|
|
||||||
server.start("localhost", 9996).await.unwrap();
|
|
||||||
|
|
||||||
let login_info = LoginInfo {
|
assert!(client.is_err());
|
||||||
username_password: None,
|
|
||||||
};
|
|
||||||
let client = SOCKSv5Client::new(network_stack, login_info, "localhost", 9996).await;
|
|
||||||
|
|
||||||
assert!(client.is_err());
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[tokio::test]
|
||||||
fn establish_stream() {
|
async fn establish_stream() -> io::Result<()> {
|
||||||
task::block_on(async {
|
let target_socket = TcpSocket::new_v4()?;
|
||||||
let mut network_stack = TestingStack::default();
|
target_socket.bind(SocketAddr::new(
|
||||||
|
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
|
||||||
|
1337,
|
||||||
|
))?;
|
||||||
|
let target_port = target_socket.listen(1)?;
|
||||||
|
|
||||||
let target_port = network_stack.listen("localhost", 1337).await.unwrap();
|
// generate the server
|
||||||
|
let security_parameters = SecurityParameters::unrestricted();
|
||||||
|
let server = SOCKSv5Server::new(security_parameters);
|
||||||
|
server.start("localhost", 9995).await.unwrap();
|
||||||
|
|
||||||
// generate the server
|
let login_info = LoginInfo {
|
||||||
let security_parameters = SecurityParameters::unrestricted();
|
username_password: None,
|
||||||
let server = SOCKSv5Server::new(network_stack.clone(), security_parameters);
|
};
|
||||||
server.start("localhost", 9995).await.unwrap();
|
|
||||||
|
|
||||||
let login_info = LoginInfo {
|
let mut client = SOCKSv5Client::new(login_info, "localhost", 9995)
|
||||||
username_password: None,
|
.await
|
||||||
};
|
.unwrap();
|
||||||
|
|
||||||
let mut client = SOCKSv5Client::new(network_stack, login_info, "localhost", 9995)
|
task::spawn(async move {
|
||||||
|
let mut conn = client.connect("localhost", 1337).await.unwrap();
|
||||||
|
conn.write_all(&[1, 3, 3, 7, 9]).await.unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
let (mut target_connection, _) = target_port.accept().await.unwrap();
|
||||||
|
let mut read_buffer = [0; 4];
|
||||||
|
target_connection
|
||||||
|
.read_exact(&mut read_buffer)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(read_buffer, [1, 3, 3, 7]);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn bind_test() -> io::Result<()> {
|
||||||
|
let security_parameters = SecurityParameters::unrestricted();
|
||||||
|
let server = SOCKSv5Server::new(security_parameters);
|
||||||
|
server.start("localhost", 9994).await.unwrap();
|
||||||
|
|
||||||
|
let login_info = LoginInfo::default();
|
||||||
|
let client = SOCKSv5Client::new(login_info, "localhost", 9994)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let (target_sender, target_receiver) = oneshot::channel();
|
||||||
|
|
||||||
|
task::spawn(async move {
|
||||||
|
let (_, _, mut conn) = client
|
||||||
|
.remote_listen("localhost", 9993, |addr, port| async move {
|
||||||
|
target_sender.send((addr, port)).unwrap();
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
task::spawn(async move {
|
conn.write_all(&[2, 3, 5, 7]).await.unwrap();
|
||||||
let mut conn = client.connect("localhost", 1337).await.unwrap();
|
});
|
||||||
conn.write_all(&[1, 3, 3, 7, 9]).await.unwrap();
|
|
||||||
});
|
|
||||||
|
|
||||||
let (mut target_connection, _, _) = target_port.accept().await.unwrap();
|
let (target_addr, target_port) = target_receiver.await.unwrap();
|
||||||
let mut read_buffer = [0; 4];
|
let mut stream = match target_addr {
|
||||||
target_connection
|
SOCKSv5Address::IP4(x) => TcpStream::connect((x, target_port)).await?,
|
||||||
.read_exact(&mut read_buffer)
|
SOCKSv5Address::IP6(x) => TcpStream::connect((x, target_port)).await?,
|
||||||
.await
|
SOCKSv5Address::Hostname(x) => TcpStream::connect((x, target_port)).await?,
|
||||||
.unwrap();
|
};
|
||||||
assert_eq!(read_buffer, [1, 3, 3, 7]);
|
let mut read_buffer = [0; 4];
|
||||||
})
|
stream.read_exact(&mut read_buffer).await.unwrap();
|
||||||
}
|
assert_eq!(read_buffer, [2, 3, 5, 7]);
|
||||||
|
Ok(())
|
||||||
#[test]
|
|
||||||
fn bind_test() {
|
|
||||||
task::block_on(async {
|
|
||||||
let mut network_stack = TestingStack::default();
|
|
||||||
|
|
||||||
let security_parameters = SecurityParameters::unrestricted();
|
|
||||||
let server = SOCKSv5Server::new(network_stack.clone(), security_parameters);
|
|
||||||
server.start("localhost", 9994).await.unwrap();
|
|
||||||
|
|
||||||
let login_info = LoginInfo::default();
|
|
||||||
let client = SOCKSv5Client::new(network_stack.clone(), login_info, "localhost", 9994)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let (target_sender, target_receiver) = bounded(1);
|
|
||||||
|
|
||||||
task::spawn(async move {
|
|
||||||
let (_, _, mut conn) = client
|
|
||||||
.remote_listen("localhost", 9993, |addr, port| async move {
|
|
||||||
target_sender.send((addr, port)).await.unwrap();
|
|
||||||
Ok(())
|
|
||||||
})
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
conn.write_all(&[2, 3, 5, 7]).await.unwrap();
|
|
||||||
});
|
|
||||||
|
|
||||||
let (target_addr, target_port) = target_receiver.recv().await.unwrap();
|
|
||||||
let mut stream = network_stack
|
|
||||||
.connect(target_addr, target_port)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
let mut read_buffer = [0; 4];
|
|
||||||
stream.read_exact(&mut read_buffer).await.unwrap();
|
|
||||||
assert_eq!(read_buffer, [2, 3, 5, 7]);
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,12 +5,51 @@ mod client_username_password;
|
|||||||
mod server_auth_response;
|
mod server_auth_response;
|
||||||
mod server_choice;
|
mod server_choice;
|
||||||
mod server_response;
|
mod server_response;
|
||||||
pub(crate) mod utils;
|
|
||||||
|
|
||||||
pub use crate::messages::authentication_method::AuthenticationMethod;
|
pub(crate) mod string;
|
||||||
pub use crate::messages::client_command::{ClientConnectionCommand, ClientConnectionRequest};
|
|
||||||
pub use crate::messages::client_greeting::ClientGreeting;
|
pub use crate::messages::authentication_method::{
|
||||||
pub use crate::messages::client_username_password::ClientUsernamePassword;
|
AuthenticationMethod, AuthenticationMethodReadError, AuthenticationMethodWriteError,
|
||||||
pub use crate::messages::server_auth_response::ServerAuthResponse;
|
};
|
||||||
pub use crate::messages::server_choice::ServerChoice;
|
pub use crate::messages::client_command::{
|
||||||
pub use crate::messages::server_response::{ServerResponse, ServerResponseStatus};
|
ClientConnectionCommand, ClientConnectionCommandReadError, ClientConnectionCommandWriteError,
|
||||||
|
ClientConnectionRequest, ClientConnectionRequestReadError,
|
||||||
|
};
|
||||||
|
pub use crate::messages::client_greeting::{
|
||||||
|
ClientGreeting, ClientGreetingReadError, ClientGreetingWriteError,
|
||||||
|
};
|
||||||
|
pub use crate::messages::client_username_password::{
|
||||||
|
ClientUsernamePassword, ClientUsernamePasswordReadError, ClientUsernamePasswordWriteError,
|
||||||
|
};
|
||||||
|
pub use crate::messages::server_auth_response::{
|
||||||
|
ServerAuthResponse, ServerAuthResponseReadError, ServerAuthResponseWriteError,
|
||||||
|
};
|
||||||
|
pub use crate::messages::server_choice::{
|
||||||
|
ServerChoice, ServerChoiceReadError, ServerChoiceWriteError,
|
||||||
|
};
|
||||||
|
pub use crate::messages::server_response::{
|
||||||
|
ServerResponse, ServerResponseReadError, ServerResponseStatus, ServerResponseWriteError,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[doc(hidden)]
|
||||||
|
#[macro_export]
|
||||||
|
macro_rules! standard_roundtrip {
|
||||||
|
($name: ident, $t: ty) => {
|
||||||
|
proptest::proptest! {
|
||||||
|
#[test]
|
||||||
|
fn $name(xs: $t) {
|
||||||
|
tokio::runtime::Runtime::new().unwrap().block_on(async {
|
||||||
|
use std::io::Cursor;
|
||||||
|
|
||||||
|
let buffer = vec![];
|
||||||
|
let mut write_cursor = Cursor::new(buffer);
|
||||||
|
xs.write(&mut write_cursor).await.unwrap();
|
||||||
|
let serialized_form = write_cursor.into_inner();
|
||||||
|
let mut read_cursor = Cursor::new(serialized_form);
|
||||||
|
let ys = <$t>::read(&mut read_cursor);
|
||||||
|
assert_eq!(xs, ys.await.unwrap());
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,16 +1,12 @@
|
|||||||
use crate::errors::{AuthenticationDeserializationError, DeserializationError, SerializationError};
|
|
||||||
use crate::standard_roundtrip;
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
use async_std::task;
|
use proptest::prelude::{prop_oneof, Arbitrary, Just, Strategy};
|
||||||
#[cfg(test)]
|
|
||||||
use futures::io::Cursor;
|
|
||||||
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
|
||||||
use proptest::proptest;
|
|
||||||
#[cfg(test)]
|
|
||||||
use proptest::prelude::{Arbitrary, Just, Strategy, prop_oneof};
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
use proptest::strategy::BoxedStrategy;
|
use proptest::strategy::BoxedStrategy;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
#[cfg(test)]
|
||||||
|
use std::io::Cursor;
|
||||||
|
use thiserror::Error;
|
||||||
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||||
|
|
||||||
#[allow(clippy::upper_case_acronyms)]
|
#[allow(clippy::upper_case_acronyms)]
|
||||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||||
@@ -28,6 +24,34 @@ pub enum AuthenticationMethod {
|
|||||||
NoAcceptableMethods,
|
NoAcceptableMethods,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Error, PartialEq)]
|
||||||
|
pub enum AuthenticationMethodReadError {
|
||||||
|
#[error("Invalid authentication method #{0}")]
|
||||||
|
UnknownAuthenticationMethod(u8),
|
||||||
|
#[error("Error in underlying buffer: {0}")]
|
||||||
|
ReadError(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<std::io::Error> for AuthenticationMethodReadError {
|
||||||
|
fn from(x: std::io::Error) -> AuthenticationMethodReadError {
|
||||||
|
AuthenticationMethodReadError::ReadError(format!("{}", x))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Error, PartialEq)]
|
||||||
|
pub enum AuthenticationMethodWriteError {
|
||||||
|
#[error("Trying to write invalid authentication method #{0}")]
|
||||||
|
InvalidAuthMethod(u8),
|
||||||
|
#[error("Error in underlying buffer: {0}")]
|
||||||
|
WriteError(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<std::io::Error> for AuthenticationMethodWriteError {
|
||||||
|
fn from(x: std::io::Error) -> AuthenticationMethodWriteError {
|
||||||
|
AuthenticationMethodWriteError::WriteError(format!("{}", x))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl fmt::Display for AuthenticationMethod {
|
impl fmt::Display for AuthenticationMethod {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
match self {
|
match self {
|
||||||
@@ -65,26 +89,17 @@ impl Arbitrary for AuthenticationMethod {
|
|||||||
Just(AuthenticationMethod::MultiAuthenticationFramework),
|
Just(AuthenticationMethod::MultiAuthenticationFramework),
|
||||||
Just(AuthenticationMethod::JSONPropertyBlock),
|
Just(AuthenticationMethod::JSONPropertyBlock),
|
||||||
Just(AuthenticationMethod::NoAcceptableMethods),
|
Just(AuthenticationMethod::NoAcceptableMethods),
|
||||||
|
|
||||||
(0x80u8..=0xfe).prop_map(AuthenticationMethod::PrivateMethod),
|
(0x80u8..=0xfe).prop_map(AuthenticationMethod::PrivateMethod),
|
||||||
].boxed()
|
]
|
||||||
|
.boxed()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
impl AuthenticationMethod {
|
impl AuthenticationMethod {
|
||||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||||
r: &mut R,
|
r: &mut R,
|
||||||
) -> Result<AuthenticationMethod, DeserializationError> {
|
) -> Result<AuthenticationMethod, AuthenticationMethodReadError> {
|
||||||
let mut byte_buffer = [0u8; 1];
|
match r.read_u8().await? {
|
||||||
let amount_read = r.read(&mut byte_buffer).await?;
|
|
||||||
|
|
||||||
if amount_read == 0 {
|
|
||||||
return Err(AuthenticationDeserializationError::NoDataFound.into());
|
|
||||||
}
|
|
||||||
|
|
||||||
match byte_buffer[0] {
|
|
||||||
0 => Ok(AuthenticationMethod::None),
|
0 => Ok(AuthenticationMethod::None),
|
||||||
1 => Ok(AuthenticationMethod::GSSAPI),
|
1 => Ok(AuthenticationMethod::GSSAPI),
|
||||||
2 => Ok(AuthenticationMethod::UsernameAndPassword),
|
2 => Ok(AuthenticationMethod::UsernameAndPassword),
|
||||||
@@ -96,14 +111,16 @@ impl AuthenticationMethod {
|
|||||||
9 => Ok(AuthenticationMethod::JSONPropertyBlock),
|
9 => Ok(AuthenticationMethod::JSONPropertyBlock),
|
||||||
x if (0x80..=0xfe).contains(&x) => Ok(AuthenticationMethod::PrivateMethod(x)),
|
x if (0x80..=0xfe).contains(&x) => Ok(AuthenticationMethod::PrivateMethod(x)),
|
||||||
0xff => Ok(AuthenticationMethod::NoAcceptableMethods),
|
0xff => Ok(AuthenticationMethod::NoAcceptableMethods),
|
||||||
e => Err(AuthenticationDeserializationError::InvalidAuthenticationByte(e).into()),
|
e => Err(AuthenticationMethodReadError::UnknownAuthenticationMethod(
|
||||||
|
e,
|
||||||
|
)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||||
&self,
|
&self,
|
||||||
w: &mut W,
|
w: &mut W,
|
||||||
) -> Result<(), SerializationError> {
|
) -> Result<(), AuthenticationMethodWriteError> {
|
||||||
let value = match self {
|
let value = match self {
|
||||||
AuthenticationMethod::None => 0,
|
AuthenticationMethod::None => 0,
|
||||||
AuthenticationMethod::GSSAPI => 1,
|
AuthenticationMethod::GSSAPI => 1,
|
||||||
@@ -114,31 +131,32 @@ impl AuthenticationMethod {
|
|||||||
AuthenticationMethod::NDS => 7,
|
AuthenticationMethod::NDS => 7,
|
||||||
AuthenticationMethod::MultiAuthenticationFramework => 8,
|
AuthenticationMethod::MultiAuthenticationFramework => 8,
|
||||||
AuthenticationMethod::JSONPropertyBlock => 9,
|
AuthenticationMethod::JSONPropertyBlock => 9,
|
||||||
AuthenticationMethod::PrivateMethod(pm) => *pm,
|
AuthenticationMethod::PrivateMethod(pm) if (0x80..=0xfe).contains(pm) => *pm,
|
||||||
|
AuthenticationMethod::PrivateMethod(pm) => {
|
||||||
|
return Err(AuthenticationMethodWriteError::InvalidAuthMethod(*pm))
|
||||||
|
}
|
||||||
AuthenticationMethod::NoAcceptableMethods => 0xff,
|
AuthenticationMethod::NoAcceptableMethods => 0xff,
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(w.write_all(&[value]).await?)
|
Ok(w.write_u8(value).await?)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
standard_roundtrip!(auth_byte_roundtrips, AuthenticationMethod);
|
crate::standard_roundtrip!(auth_byte_roundtrips, AuthenticationMethod);
|
||||||
|
|
||||||
#[test]
|
#[tokio::test]
|
||||||
fn bad_byte() {
|
async fn bad_byte() {
|
||||||
let no_len = vec![42];
|
let no_len = vec![42];
|
||||||
let mut cursor = Cursor::new(no_len);
|
let mut cursor = Cursor::new(no_len);
|
||||||
let ys = AuthenticationMethod::read(&mut cursor);
|
let ys = AuthenticationMethod::read(&mut cursor).await.unwrap_err();
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
Err(DeserializationError::AuthenticationMethodError(
|
AuthenticationMethodReadError::UnknownAuthenticationMethod(42),
|
||||||
AuthenticationDeserializationError::InvalidAuthenticationByte(42)
|
ys
|
||||||
)),
|
|
||||||
task::block_on(ys)
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[tokio::test]
|
||||||
fn display_isnt_empty() {
|
async fn display_isnt_empty() {
|
||||||
let vals = vec![
|
let vals = vec![
|
||||||
AuthenticationMethod::None,
|
AuthenticationMethod::None,
|
||||||
AuthenticationMethod::GSSAPI,
|
AuthenticationMethod::GSSAPI,
|
||||||
|
|||||||
@@ -1,20 +1,10 @@
|
|||||||
use crate::errors::{DeserializationError, SerializationError};
|
use crate::address::{SOCKSv5Address, SOCKSv5AddressReadError, SOCKSv5AddressWriteError};
|
||||||
use crate::network::SOCKSv5Address;
|
|
||||||
use crate::serialize::read_amt;
|
|
||||||
use crate::standard_roundtrip;
|
|
||||||
#[cfg(test)]
|
|
||||||
use async_std::io::ErrorKind;
|
|
||||||
#[cfg(test)]
|
|
||||||
use async_std::task;
|
|
||||||
#[cfg(test)]
|
|
||||||
use futures::io::Cursor;
|
|
||||||
use futures::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
|
||||||
use log::debug;
|
|
||||||
use proptest::proptest;
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
use proptest_derive::Arbitrary;
|
use proptest_derive::Arbitrary;
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
use std::net::Ipv4Addr;
|
use std::io::Cursor;
|
||||||
|
use thiserror::Error;
|
||||||
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
||||||
#[cfg_attr(test, derive(Arbitrary))]
|
#[cfg_attr(test, derive(Arbitrary))]
|
||||||
@@ -24,6 +14,60 @@ pub enum ClientConnectionCommand {
|
|||||||
AssociateUDPPort,
|
AssociateUDPPort,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Error, PartialEq)]
|
||||||
|
pub enum ClientConnectionCommandReadError {
|
||||||
|
#[error("Invalid client connection command code: {0}")]
|
||||||
|
InvalidClientConnectionCommand(u8),
|
||||||
|
#[error("Underlying buffer read error: {0}")]
|
||||||
|
ReadError(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<std::io::Error> for ClientConnectionCommandReadError {
|
||||||
|
fn from(x: std::io::Error) -> ClientConnectionCommandReadError {
|
||||||
|
ClientConnectionCommandReadError::ReadError(format!("{}", x))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Error, PartialEq)]
|
||||||
|
pub enum ClientConnectionCommandWriteError {
|
||||||
|
#[error("Underlying buffer write error: {0}")]
|
||||||
|
WriteError(String),
|
||||||
|
#[error(transparent)]
|
||||||
|
SOCKSAddressWriteError(#[from] SOCKSv5AddressWriteError),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<std::io::Error> for ClientConnectionCommandWriteError {
|
||||||
|
fn from(x: std::io::Error) -> ClientConnectionCommandWriteError {
|
||||||
|
ClientConnectionCommandWriteError::WriteError(format!("{}", x))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ClientConnectionCommand {
|
||||||
|
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||||
|
r: &mut R,
|
||||||
|
) -> Result<ClientConnectionCommand, ClientConnectionCommandReadError> {
|
||||||
|
match r.read_u8().await? {
|
||||||
|
0x01 => Ok(ClientConnectionCommand::EstablishTCPStream),
|
||||||
|
0x02 => Ok(ClientConnectionCommand::EstablishTCPPortBinding),
|
||||||
|
0x03 => Ok(ClientConnectionCommand::AssociateUDPPort),
|
||||||
|
x => Err(ClientConnectionCommandReadError::InvalidClientConnectionCommand(x)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||||
|
&self,
|
||||||
|
w: &mut W,
|
||||||
|
) -> Result<(), std::io::Error> {
|
||||||
|
match self {
|
||||||
|
ClientConnectionCommand::EstablishTCPStream => w.write_u8(0x01).await,
|
||||||
|
ClientConnectionCommand::EstablishTCPPortBinding => w.write_u8(0x02).await,
|
||||||
|
ClientConnectionCommand::AssociateUDPPort => w.write_u8(0x03).await,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
crate::standard_roundtrip!(client_command_roundtrips, ClientConnectionCommand);
|
||||||
|
|
||||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||||
#[cfg_attr(test, derive(Arbitrary))]
|
#[cfg_attr(test, derive(Arbitrary))]
|
||||||
pub struct ClientConnectionRequest {
|
pub struct ClientConnectionRequest {
|
||||||
@@ -32,37 +76,46 @@ pub struct ClientConnectionRequest {
|
|||||||
pub destination_port: u16,
|
pub destination_port: u16,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Error, PartialEq)]
|
||||||
|
pub enum ClientConnectionRequestReadError {
|
||||||
|
#[error("Invalid version in client request: {0} (expected 5)")]
|
||||||
|
InvalidVersion(u8),
|
||||||
|
#[error("Invalid command for client request: {0}")]
|
||||||
|
InvalidCommand(#[from] ClientConnectionCommandReadError),
|
||||||
|
#[error("Invalid reserved byte: {0} (expected 0)")]
|
||||||
|
InvalidReservedByte(u8),
|
||||||
|
#[error("Underlying read error: {0}")]
|
||||||
|
ReadError(String),
|
||||||
|
#[error(transparent)]
|
||||||
|
AddressReadError(#[from] SOCKSv5AddressReadError),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<std::io::Error> for ClientConnectionRequestReadError {
|
||||||
|
fn from(x: std::io::Error) -> ClientConnectionRequestReadError {
|
||||||
|
ClientConnectionRequestReadError::ReadError(format!("{}", x))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl ClientConnectionRequest {
|
impl ClientConnectionRequest {
|
||||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||||
r: &mut R,
|
r: &mut R,
|
||||||
) -> Result<Self, DeserializationError> {
|
) -> Result<Self, ClientConnectionRequestReadError> {
|
||||||
let mut buffer = [0; 3];
|
let version = r.read_u8().await?;
|
||||||
|
if version != 5 {
|
||||||
debug!("Starting to read request.");
|
return Err(ClientConnectionRequestReadError::InvalidVersion(version));
|
||||||
read_amt(r, 3, &mut buffer).await?;
|
|
||||||
debug!("Read three opening bytes: {:?}", buffer);
|
|
||||||
if buffer[0] != 5 {
|
|
||||||
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let command_code = match buffer[1] {
|
let command_code = ClientConnectionCommand::read(r).await?;
|
||||||
0x01 => ClientConnectionCommand::EstablishTCPStream,
|
|
||||||
0x02 => ClientConnectionCommand::EstablishTCPPortBinding,
|
|
||||||
0x03 => ClientConnectionCommand::AssociateUDPPort,
|
|
||||||
x => return Err(DeserializationError::InvalidClientCommand(x)),
|
|
||||||
};
|
|
||||||
debug!("Command code: {:?}", command_code);
|
|
||||||
|
|
||||||
if buffer[2] != 0 {
|
let reserved = r.read_u8().await?;
|
||||||
return Err(DeserializationError::InvalidReservedByte(buffer[2]));
|
if reserved != 0 {
|
||||||
|
return Err(ClientConnectionRequestReadError::InvalidReservedByte(
|
||||||
|
reserved,
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
let destination_address = SOCKSv5Address::read(r).await?;
|
let destination_address = SOCKSv5Address::read(r).await?;
|
||||||
debug!("Destination address: {}", destination_address);
|
let destination_port = r.read_u16().await?;
|
||||||
|
|
||||||
read_amt(r, 2, &mut buffer).await?;
|
|
||||||
let destination_port = ((buffer[0] as u16) << 8) + (buffer[1] as u16);
|
|
||||||
debug!("Destination port: {}", destination_port);
|
|
||||||
|
|
||||||
Ok(ClientConnectionRequest {
|
Ok(ClientConnectionRequest {
|
||||||
command_code,
|
command_code,
|
||||||
@@ -74,63 +127,62 @@ impl ClientConnectionRequest {
|
|||||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||||
&self,
|
&self,
|
||||||
w: &mut W,
|
w: &mut W,
|
||||||
) -> Result<(), SerializationError> {
|
) -> Result<(), ClientConnectionCommandWriteError> {
|
||||||
let command = match self.command_code {
|
w.write_u8(5).await?;
|
||||||
ClientConnectionCommand::EstablishTCPStream => 1,
|
self.command_code.write(w).await?;
|
||||||
ClientConnectionCommand::EstablishTCPPortBinding => 2,
|
w.write_u8(0).await?;
|
||||||
ClientConnectionCommand::AssociateUDPPort => 3,
|
|
||||||
};
|
|
||||||
|
|
||||||
w.write_all(&[5, command, 0]).await?;
|
|
||||||
self.destination_address.write(w).await?;
|
self.destination_address.write(w).await?;
|
||||||
w.write_all(&[
|
w.write_u16(self.destination_port).await?;
|
||||||
(self.destination_port >> 8) as u8,
|
Ok(())
|
||||||
(self.destination_port & 0xffu16) as u8,
|
|
||||||
])
|
|
||||||
.await
|
|
||||||
.map_err(SerializationError::IOError)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
standard_roundtrip!(client_request_roundtrips, ClientConnectionRequest);
|
crate::standard_roundtrip!(client_request_roundtrips, ClientConnectionRequest);
|
||||||
|
|
||||||
#[test]
|
#[tokio::test]
|
||||||
fn check_short_reads() {
|
async fn check_short_reads() {
|
||||||
let empty = vec![];
|
let empty = vec![];
|
||||||
let mut cursor = Cursor::new(empty);
|
let mut cursor = Cursor::new(empty);
|
||||||
let ys = ClientConnectionRequest::read(&mut cursor);
|
let ys = ClientConnectionRequest::read(&mut cursor).await;
|
||||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
assert!(matches!(
|
||||||
|
ys,
|
||||||
|
Err(ClientConnectionRequestReadError::ReadError(_))
|
||||||
|
));
|
||||||
|
|
||||||
let no_len = vec![5, 1];
|
let no_len = vec![5, 1];
|
||||||
let mut cursor = Cursor::new(no_len);
|
let mut cursor = Cursor::new(no_len);
|
||||||
let ys = ClientConnectionRequest::read(&mut cursor);
|
let ys = ClientConnectionRequest::read(&mut cursor).await;
|
||||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
assert!(matches!(
|
||||||
|
ys,
|
||||||
|
Err(ClientConnectionRequestReadError::ReadError(_))
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[tokio::test]
|
||||||
fn check_bad_version() {
|
async fn check_bad_version() {
|
||||||
let bad_ver = vec![6, 1, 1];
|
let bad_ver = vec![6, 1, 1];
|
||||||
let mut cursor = Cursor::new(bad_ver);
|
let mut cursor = Cursor::new(bad_ver);
|
||||||
let ys = ClientConnectionRequest::read(&mut cursor);
|
let ys = ClientConnectionRequest::read(&mut cursor).await;
|
||||||
assert_eq!(
|
assert_eq!(Err(ClientConnectionRequestReadError::InvalidVersion(6)), ys);
|
||||||
Err(DeserializationError::InvalidVersion(5, 6)),
|
|
||||||
task::block_on(ys)
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[tokio::test]
|
||||||
fn check_bad_command() {
|
async fn check_bad_command() {
|
||||||
let bad_cmd = vec![5, 32, 1];
|
let bad_cmd = vec![5, 32, 1];
|
||||||
let mut cursor = Cursor::new(bad_cmd);
|
let mut cursor = Cursor::new(bad_cmd);
|
||||||
let ys = ClientConnectionRequest::read(&mut cursor);
|
let ys = ClientConnectionRequest::read(&mut cursor).await;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
Err(DeserializationError::InvalidClientCommand(32)),
|
Err(ClientConnectionRequestReadError::InvalidCommand(
|
||||||
task::block_on(ys)
|
ClientConnectionCommandReadError::InvalidClientConnectionCommand(32)
|
||||||
|
)),
|
||||||
|
ys
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[tokio::test]
|
||||||
fn short_write_fails_right() {
|
async fn short_write_fails_right() {
|
||||||
|
use std::net::Ipv4Addr;
|
||||||
|
|
||||||
let mut buffer = [0u8; 2];
|
let mut buffer = [0u8; 2];
|
||||||
let cmd = ClientConnectionRequest {
|
let cmd = ClientConnectionRequest {
|
||||||
command_code: ClientConnectionCommand::AssociateUDPPort,
|
command_code: ClientConnectionCommand::AssociateUDPPort,
|
||||||
@@ -138,10 +190,12 @@ fn short_write_fails_right() {
|
|||||||
destination_port: 22,
|
destination_port: 22,
|
||||||
};
|
};
|
||||||
let mut cursor = Cursor::new(&mut buffer as &mut [u8]);
|
let mut cursor = Cursor::new(&mut buffer as &mut [u8]);
|
||||||
let result = task::block_on(cmd.write(&mut cursor));
|
let result = cmd.write(&mut cursor).await;
|
||||||
match result {
|
match result {
|
||||||
Ok(_) => panic!("Mysteriously able to fit > 2 bytes in 2 bytes."),
|
Ok(_) => panic!("Mysteriously able to fit > 2 bytes in 2 bytes."),
|
||||||
Err(SerializationError::IOError(x)) => assert_eq!(ErrorKind::WriteZero, x.kind()),
|
Err(ClientConnectionCommandWriteError::WriteError(x)) => {
|
||||||
|
assert!(x.contains("write zero"));
|
||||||
|
}
|
||||||
Err(e) => panic!("Got the wrong error writing too much data: {}", e),
|
Err(e) => panic!("Got the wrong error writing too much data: {}", e),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,16 +1,12 @@
|
|||||||
#[cfg(test)]
|
use crate::messages::authentication_method::{
|
||||||
use crate::errors::AuthenticationDeserializationError;
|
AuthenticationMethod, AuthenticationMethodReadError, AuthenticationMethodWriteError,
|
||||||
use crate::errors::{DeserializationError, SerializationError};
|
};
|
||||||
use crate::messages::AuthenticationMethod;
|
|
||||||
use crate::standard_roundtrip;
|
|
||||||
#[cfg(test)]
|
|
||||||
use async_std::task;
|
|
||||||
#[cfg(test)]
|
|
||||||
use futures::io::Cursor;
|
|
||||||
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
|
||||||
use proptest::proptest;
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
use proptest_derive::Arbitrary;
|
use proptest_derive::Arbitrary;
|
||||||
|
#[cfg(test)]
|
||||||
|
use std::io::Cursor;
|
||||||
|
use thiserror::Error;
|
||||||
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||||
|
|
||||||
/// Client greetings are the first message sent in a SOCKSv5 session. They
|
/// Client greetings are the first message sent in a SOCKSv5 session. They
|
||||||
/// identify that there's a client that wants to talk to a server, and that
|
/// identify that there's a client that wants to talk to a server, and that
|
||||||
@@ -23,26 +19,52 @@ pub struct ClientGreeting {
|
|||||||
pub acceptable_methods: Vec<AuthenticationMethod>,
|
pub acceptable_methods: Vec<AuthenticationMethod>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Error, PartialEq)]
|
||||||
|
pub enum ClientGreetingReadError {
|
||||||
|
#[error("Invalid version in client request: {0} (expected 5)")]
|
||||||
|
InvalidVersion(u8),
|
||||||
|
#[error(transparent)]
|
||||||
|
AuthMethodReadError(#[from] AuthenticationMethodReadError),
|
||||||
|
#[error("Underlying read error: {0}")]
|
||||||
|
ReadError(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<std::io::Error> for ClientGreetingReadError {
|
||||||
|
fn from(x: std::io::Error) -> ClientGreetingReadError {
|
||||||
|
ClientGreetingReadError::ReadError(format!("{}", x))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Error, PartialEq)]
|
||||||
|
pub enum ClientGreetingWriteError {
|
||||||
|
#[error("Too many methods provided; need <256, saw {0}")]
|
||||||
|
TooManyMethods(usize),
|
||||||
|
#[error(transparent)]
|
||||||
|
AuthMethodWriteError(#[from] AuthenticationMethodWriteError),
|
||||||
|
#[error("Underlying write error: {0}")]
|
||||||
|
WriteError(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<std::io::Error> for ClientGreetingWriteError {
|
||||||
|
fn from(x: std::io::Error) -> ClientGreetingWriteError {
|
||||||
|
ClientGreetingWriteError::WriteError(format!("{}", x))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl ClientGreeting {
|
impl ClientGreeting {
|
||||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||||
r: &mut R,
|
r: &mut R,
|
||||||
) -> Result<ClientGreeting, DeserializationError> {
|
) -> Result<ClientGreeting, ClientGreetingReadError> {
|
||||||
let mut buffer = [0; 1];
|
let version = r.read_u8().await?;
|
||||||
|
|
||||||
if r.read(&mut buffer).await? == 0 {
|
if version != 5 {
|
||||||
return Err(DeserializationError::NotEnoughData);
|
return Err(ClientGreetingReadError::InvalidVersion(version));
|
||||||
}
|
}
|
||||||
|
|
||||||
if buffer[0] != 5 {
|
let num_methods = r.read_u8().await? as usize;
|
||||||
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.read(&mut buffer).await? == 0 {
|
let mut acceptable_methods = Vec::with_capacity(num_methods);
|
||||||
return Err(DeserializationError::NotEnoughData);
|
for _ in 0..num_methods {
|
||||||
}
|
|
||||||
|
|
||||||
let mut acceptable_methods = Vec::with_capacity(buffer[0] as usize);
|
|
||||||
for _ in 0..buffer[0] {
|
|
||||||
acceptable_methods.push(AuthenticationMethod::read(r).await?);
|
acceptable_methods.push(AuthenticationMethod::read(r).await?);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -52,9 +74,9 @@ impl ClientGreeting {
|
|||||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||||
&self,
|
&self,
|
||||||
w: &mut W,
|
w: &mut W,
|
||||||
) -> Result<(), SerializationError> {
|
) -> Result<(), ClientGreetingWriteError> {
|
||||||
if self.acceptable_methods.len() > 255 {
|
if self.acceptable_methods.len() > 255 {
|
||||||
return Err(SerializationError::TooManyAuthMethods(
|
return Err(ClientGreetingWriteError::TooManyMethods(
|
||||||
self.acceptable_methods.len(),
|
self.acceptable_methods.len(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
@@ -70,44 +92,41 @@ impl ClientGreeting {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
standard_roundtrip!(client_greeting_roundtrips, ClientGreeting);
|
crate::standard_roundtrip!(client_greeting_roundtrips, ClientGreeting);
|
||||||
|
|
||||||
#[test]
|
#[tokio::test]
|
||||||
fn check_short_reads() {
|
async fn check_short_reads() {
|
||||||
let empty = vec![];
|
let empty = vec![];
|
||||||
let mut cursor = Cursor::new(empty);
|
let mut cursor = Cursor::new(empty);
|
||||||
let ys = ClientGreeting::read(&mut cursor);
|
let ys = ClientGreeting::read(&mut cursor).await;
|
||||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
assert!(matches!(ys, Err(ClientGreetingReadError::ReadError(_))));
|
||||||
|
|
||||||
let no_len = vec![5];
|
let no_len = vec![5];
|
||||||
let mut cursor = Cursor::new(no_len);
|
let mut cursor = Cursor::new(no_len);
|
||||||
let ys = ClientGreeting::read(&mut cursor);
|
let ys = ClientGreeting::read(&mut cursor).await;
|
||||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
assert!(matches!(ys, Err(ClientGreetingReadError::ReadError(_))));
|
||||||
|
|
||||||
let bad_len = vec![5, 9];
|
let bad_len = vec![5, 9];
|
||||||
let mut cursor = Cursor::new(bad_len);
|
let mut cursor = Cursor::new(bad_len);
|
||||||
let ys = ClientGreeting::read(&mut cursor);
|
let ys = ClientGreeting::read(&mut cursor).await;
|
||||||
assert_eq!(
|
assert!(matches!(
|
||||||
Err(DeserializationError::AuthenticationMethodError(
|
ys,
|
||||||
AuthenticationDeserializationError::NoDataFound
|
Err(ClientGreetingReadError::AuthMethodReadError(
|
||||||
)),
|
AuthenticationMethodReadError::ReadError(_)
|
||||||
task::block_on(ys)
|
))
|
||||||
);
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[tokio::test]
|
||||||
fn check_bad_version() {
|
async fn check_bad_version() {
|
||||||
let no_len = vec![6, 1, 1];
|
let no_len = vec![6, 1, 1];
|
||||||
let mut cursor = Cursor::new(no_len);
|
let mut cursor = Cursor::new(no_len);
|
||||||
let ys = ClientGreeting::read(&mut cursor);
|
let ys = ClientGreeting::read(&mut cursor).await;
|
||||||
assert_eq!(
|
assert_eq!(Err(ClientGreetingReadError::InvalidVersion(6)), ys);
|
||||||
Err(DeserializationError::InvalidVersion(5, 6)),
|
|
||||||
task::block_on(ys)
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[tokio::test]
|
||||||
fn check_too_many() {
|
async fn check_too_many() {
|
||||||
let mut auth_methods = Vec::with_capacity(512);
|
let mut auth_methods = Vec::with_capacity(512);
|
||||||
auth_methods.resize(512, AuthenticationMethod::ChallengeHandshake);
|
auth_methods.resize(512, AuthenticationMethod::ChallengeHandshake);
|
||||||
let greet = ClientGreeting {
|
let greet = ClientGreeting {
|
||||||
@@ -115,7 +134,7 @@ fn check_too_many() {
|
|||||||
};
|
};
|
||||||
let mut output = vec![0; 1024];
|
let mut output = vec![0; 1024];
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
Err(SerializationError::TooManyAuthMethods(512)),
|
Err(ClientGreetingWriteError::TooManyMethods(512)),
|
||||||
task::block_on(greet.write(&mut output))
|
greet.write(&mut output).await
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,16 +1,10 @@
|
|||||||
use crate::errors::{DeserializationError, SerializationError};
|
use crate::messages::string::{SOCKSv5String, SOCKSv5StringReadError, SOCKSv5StringWriteError};
|
||||||
use crate::serialize::{read_string, write_string};
|
|
||||||
use crate::standard_roundtrip;
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
use async_std::task;
|
use proptest::prelude::{Arbitrary, BoxedStrategy, Strategy};
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
use futures::io::Cursor;
|
use std::io::Cursor;
|
||||||
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
use thiserror::Error;
|
||||||
#[cfg(test)]
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||||
use proptest::prelude::{Arbitrary, BoxedStrategy};
|
|
||||||
use proptest::proptest;
|
|
||||||
#[cfg(test)]
|
|
||||||
use proptest::strategy::Strategy;
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||||
pub struct ClientUsernamePassword {
|
pub struct ClientUsernamePassword {
|
||||||
@@ -30,30 +24,58 @@ impl Arbitrary for ClientUsernamePassword {
|
|||||||
|
|
||||||
fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
|
fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
|
||||||
let max_len = args.unwrap_or(12) as usize;
|
let max_len = args.unwrap_or(12) as usize;
|
||||||
(USERNAME_REGEX, PASSWORD_REGEX).prop_map(move |(mut username, mut password)| {
|
(USERNAME_REGEX, PASSWORD_REGEX)
|
||||||
username.shrink_to(max_len);
|
.prop_map(move |(mut username, mut password)| {
|
||||||
password.shrink_to(max_len);
|
username.shrink_to(max_len);
|
||||||
ClientUsernamePassword { username, password }
|
password.shrink_to(max_len);
|
||||||
}).boxed()
|
ClientUsernamePassword { username, password }
|
||||||
|
})
|
||||||
|
.boxed()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Error, PartialEq)]
|
||||||
|
pub enum ClientUsernamePasswordReadError {
|
||||||
|
#[error("Underlying buffer read error: {0}")]
|
||||||
|
ReadError(String),
|
||||||
|
#[error("Invalid username/password version; expected 1, saw {0}")]
|
||||||
|
InvalidVersion(u8),
|
||||||
|
#[error(transparent)]
|
||||||
|
StringError(#[from] SOCKSv5StringReadError),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<std::io::Error> for ClientUsernamePasswordReadError {
|
||||||
|
fn from(x: std::io::Error) -> ClientUsernamePasswordReadError {
|
||||||
|
ClientUsernamePasswordReadError::ReadError(format!("{}", x))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Error, PartialEq)]
|
||||||
|
pub enum ClientUsernamePasswordWriteError {
|
||||||
|
#[error("Underlying buffer read error: {0}")]
|
||||||
|
WriteError(String),
|
||||||
|
#[error(transparent)]
|
||||||
|
StringError(#[from] SOCKSv5StringWriteError),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<std::io::Error> for ClientUsernamePasswordWriteError {
|
||||||
|
fn from(x: std::io::Error) -> ClientUsernamePasswordWriteError {
|
||||||
|
ClientUsernamePasswordWriteError::WriteError(format!("{}", x))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ClientUsernamePassword {
|
impl ClientUsernamePassword {
|
||||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||||
r: &mut R,
|
r: &mut R,
|
||||||
) -> Result<Self, DeserializationError> {
|
) -> Result<Self, ClientUsernamePasswordReadError> {
|
||||||
let mut buffer = [0; 1];
|
let version = r.read_u8().await?;
|
||||||
|
|
||||||
if r.read(&mut buffer).await? == 0 {
|
if version != 1 {
|
||||||
return Err(DeserializationError::NotEnoughData);
|
return Err(ClientUsernamePasswordReadError::InvalidVersion(version));
|
||||||
}
|
}
|
||||||
|
|
||||||
if buffer[0] != 1 {
|
let username = SOCKSv5String::read(r).await?.into();
|
||||||
return Err(DeserializationError::InvalidVersion(1, buffer[0]));
|
let password = SOCKSv5String::read(r).await?.into();
|
||||||
}
|
|
||||||
|
|
||||||
let username = read_string(r).await?;
|
|
||||||
let password = read_string(r).await?;
|
|
||||||
|
|
||||||
Ok(ClientUsernamePassword { username, password })
|
Ok(ClientUsernamePassword { username, password })
|
||||||
}
|
}
|
||||||
@@ -61,35 +83,40 @@ impl ClientUsernamePassword {
|
|||||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||||
&self,
|
&self,
|
||||||
w: &mut W,
|
w: &mut W,
|
||||||
) -> Result<(), SerializationError> {
|
) -> Result<(), ClientUsernamePasswordWriteError> {
|
||||||
w.write_all(&[1]).await?;
|
w.write_u8(1).await?;
|
||||||
write_string(&self.username, w).await?;
|
SOCKSv5String::from(self.username.as_str()).write(w).await?;
|
||||||
write_string(&self.password, w).await
|
SOCKSv5String::from(self.password.as_str()).write(w).await?;
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
standard_roundtrip!(username_password_roundtrips, ClientUsernamePassword);
|
crate::standard_roundtrip!(username_password_roundtrips, ClientUsernamePassword);
|
||||||
|
|
||||||
#[test]
|
#[tokio::test]
|
||||||
fn check_short_reads() {
|
async fn heck_short_reads() {
|
||||||
let empty = vec![];
|
let empty = vec![];
|
||||||
let mut cursor = Cursor::new(empty);
|
let mut cursor = Cursor::new(empty);
|
||||||
let ys = ClientUsernamePassword::read(&mut cursor);
|
let ys = ClientUsernamePassword::read(&mut cursor).await;
|
||||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
assert!(matches!(
|
||||||
|
ys,
|
||||||
|
Err(ClientUsernamePasswordReadError::ReadError(_))
|
||||||
|
));
|
||||||
|
|
||||||
let user_only = vec![1, 3, 102, 111, 111];
|
let user_only = vec![1, 3, 102, 111, 111];
|
||||||
let mut cursor = Cursor::new(user_only);
|
let mut cursor = Cursor::new(user_only);
|
||||||
let ys = ClientUsernamePassword::read(&mut cursor);
|
let ys = ClientUsernamePassword::read(&mut cursor).await;
|
||||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
println!("ys: {:?}", ys);
|
||||||
|
assert!(matches!(
|
||||||
|
ys,
|
||||||
|
Err(ClientUsernamePasswordReadError::StringError(_))
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[tokio::test]
|
||||||
fn check_bad_version() {
|
async fn check_bad_version() {
|
||||||
let bad_len = vec![5];
|
let bad_len = vec![5];
|
||||||
let mut cursor = Cursor::new(bad_len);
|
let mut cursor = Cursor::new(bad_len);
|
||||||
let ys = ClientUsernamePassword::read(&mut cursor);
|
let ys = ClientUsernamePassword::read(&mut cursor).await;
|
||||||
assert_eq!(
|
assert_eq!(Err(ClientUsernamePasswordReadError::InvalidVersion(5)), ys);
|
||||||
Err(DeserializationError::InvalidVersion(1, 5)),
|
|
||||||
task::block_on(ys)
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,13 +1,7 @@
|
|||||||
use crate::errors::{DeserializationError, SerializationError};
|
|
||||||
use crate::standard_roundtrip;
|
|
||||||
#[cfg(test)]
|
|
||||||
use async_std::task;
|
|
||||||
#[cfg(test)]
|
|
||||||
use futures::io::Cursor;
|
|
||||||
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
|
||||||
use proptest::proptest;
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
use proptest_derive::Arbitrary;
|
use proptest_derive::Arbitrary;
|
||||||
|
use thiserror::Error;
|
||||||
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||||
|
|
||||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||||
#[cfg_attr(test, derive(Arbitrary))]
|
#[cfg_attr(test, derive(Arbitrary))]
|
||||||
@@ -15,6 +9,32 @@ pub struct ServerAuthResponse {
|
|||||||
pub success: bool,
|
pub success: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Error, PartialEq)]
|
||||||
|
pub enum ServerAuthResponseReadError {
|
||||||
|
#[error("Underlying buffer read error: {0}")]
|
||||||
|
ReadError(String),
|
||||||
|
#[error("Invalid username/password version; expected 1, saw {0}")]
|
||||||
|
InvalidVersion(u8),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<std::io::Error> for ServerAuthResponseReadError {
|
||||||
|
fn from(x: std::io::Error) -> ServerAuthResponseReadError {
|
||||||
|
ServerAuthResponseReadError::ReadError(format!("{}", x))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Error, PartialEq)]
|
||||||
|
pub enum ServerAuthResponseWriteError {
|
||||||
|
#[error("Underlying buffer read error: {0}")]
|
||||||
|
WriteError(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<std::io::Error> for ServerAuthResponseWriteError {
|
||||||
|
fn from(x: std::io::Error) -> ServerAuthResponseWriteError {
|
||||||
|
ServerAuthResponseWriteError::WriteError(format!("{}", x))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl ServerAuthResponse {
|
impl ServerAuthResponse {
|
||||||
pub fn success() -> ServerAuthResponse {
|
pub fn success() -> ServerAuthResponse {
|
||||||
ServerAuthResponse { success: true }
|
ServerAuthResponse { success: true }
|
||||||
@@ -26,30 +46,22 @@ impl ServerAuthResponse {
|
|||||||
|
|
||||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||||
r: &mut R,
|
r: &mut R,
|
||||||
) -> Result<Self, DeserializationError> {
|
) -> Result<Self, ServerAuthResponseReadError> {
|
||||||
let mut buffer = [0; 1];
|
let version = r.read_u8().await?;
|
||||||
|
|
||||||
if r.read(&mut buffer).await? == 0 {
|
if version != 1 {
|
||||||
return Err(DeserializationError::NotEnoughData);
|
return Err(ServerAuthResponseReadError::InvalidVersion(version));
|
||||||
}
|
|
||||||
|
|
||||||
if buffer[0] != 1 {
|
|
||||||
return Err(DeserializationError::InvalidVersion(1, buffer[0]));
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.read(&mut buffer).await? == 0 {
|
|
||||||
return Err(DeserializationError::NotEnoughData);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(ServerAuthResponse {
|
Ok(ServerAuthResponse {
|
||||||
success: buffer[0] == 0,
|
success: r.read_u8().await? == 0,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||||
&self,
|
&self,
|
||||||
w: &mut W,
|
w: &mut W,
|
||||||
) -> Result<(), SerializationError> {
|
) -> Result<(), ServerAuthResponseWriteError> {
|
||||||
w.write_all(&[1]).await?;
|
w.write_all(&[1]).await?;
|
||||||
w.write_all(&[if self.success { 0x00 } else { 0xde }])
|
w.write_all(&[if self.success { 0x00 } else { 0xde }])
|
||||||
.await?;
|
.await?;
|
||||||
@@ -57,28 +69,29 @@ impl ServerAuthResponse {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
standard_roundtrip!(server_auth_response, ServerAuthResponse);
|
crate::standard_roundtrip!(server_auth_response, ServerAuthResponse);
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn check_short_reads() {
|
||||||
|
use std::io::Cursor;
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn check_short_reads() {
|
|
||||||
let empty = vec![];
|
let empty = vec![];
|
||||||
let mut cursor = Cursor::new(empty);
|
let mut cursor = Cursor::new(empty);
|
||||||
let ys = ServerAuthResponse::read(&mut cursor);
|
let ys = ServerAuthResponse::read(&mut cursor).await;
|
||||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
assert!(matches!(ys, Err(ServerAuthResponseReadError::ReadError(_))));
|
||||||
|
|
||||||
let no_len = vec![1];
|
let no_len = vec![1];
|
||||||
let mut cursor = Cursor::new(no_len);
|
let mut cursor = Cursor::new(no_len);
|
||||||
let ys = ServerAuthResponse::read(&mut cursor);
|
let ys = ServerAuthResponse::read(&mut cursor).await;
|
||||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
assert!(matches!(ys, Err(ServerAuthResponseReadError::ReadError(_))));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[tokio::test]
|
||||||
fn check_bad_version() {
|
async fn check_bad_version() {
|
||||||
|
use std::io::Cursor;
|
||||||
|
|
||||||
let no_len = vec![6, 1];
|
let no_len = vec![6, 1];
|
||||||
let mut cursor = Cursor::new(no_len);
|
let mut cursor = Cursor::new(no_len);
|
||||||
let ys = ServerAuthResponse::read(&mut cursor);
|
let ys = ServerAuthResponse::read(&mut cursor).await;
|
||||||
assert_eq!(
|
assert_eq!(Err(ServerAuthResponseReadError::InvalidVersion(6)), ys);
|
||||||
Err(DeserializationError::InvalidVersion(1, 6)),
|
|
||||||
task::block_on(ys)
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,16 +1,12 @@
|
|||||||
#[cfg(test)]
|
use crate::messages::authentication_method::{
|
||||||
use crate::errors::AuthenticationDeserializationError;
|
AuthenticationMethod, AuthenticationMethodReadError, AuthenticationMethodWriteError,
|
||||||
use crate::errors::{DeserializationError, SerializationError};
|
};
|
||||||
use crate::messages::AuthenticationMethod;
|
|
||||||
use crate::standard_roundtrip;
|
|
||||||
#[cfg(test)]
|
|
||||||
use async_std::task;
|
|
||||||
#[cfg(test)]
|
|
||||||
use futures::io::Cursor;
|
|
||||||
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
|
||||||
use proptest::proptest;
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
use proptest_derive::Arbitrary;
|
use proptest_derive::Arbitrary;
|
||||||
|
#[cfg(test)]
|
||||||
|
use std::io::Cursor;
|
||||||
|
use thiserror::Error;
|
||||||
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||||
|
|
||||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||||
#[cfg_attr(test, derive(Arbitrary))]
|
#[cfg_attr(test, derive(Arbitrary))]
|
||||||
@@ -18,6 +14,36 @@ pub struct ServerChoice {
|
|||||||
pub chosen_method: AuthenticationMethod,
|
pub chosen_method: AuthenticationMethod,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Error, PartialEq)]
|
||||||
|
pub enum ServerChoiceReadError {
|
||||||
|
#[error(transparent)]
|
||||||
|
AuthMethodError(#[from] AuthenticationMethodReadError),
|
||||||
|
#[error("Error in underlying buffer: {0}")]
|
||||||
|
ReadError(String),
|
||||||
|
#[error("Invalid version; expected 5, got {0}")]
|
||||||
|
InvalidVersion(u8),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<std::io::Error> for ServerChoiceReadError {
|
||||||
|
fn from(x: std::io::Error) -> ServerChoiceReadError {
|
||||||
|
ServerChoiceReadError::ReadError(format!("{}", x))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Error, PartialEq)]
|
||||||
|
pub enum ServerChoiceWriteError {
|
||||||
|
#[error(transparent)]
|
||||||
|
AuthMethodError(#[from] AuthenticationMethodWriteError),
|
||||||
|
#[error("Error in underlying buffer: {0}")]
|
||||||
|
WriteError(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<std::io::Error> for ServerChoiceWriteError {
|
||||||
|
fn from(x: std::io::Error) -> ServerChoiceWriteError {
|
||||||
|
ServerChoiceWriteError::WriteError(format!("{}", x))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl ServerChoice {
|
impl ServerChoice {
|
||||||
pub fn rejection() -> ServerChoice {
|
pub fn rejection() -> ServerChoice {
|
||||||
ServerChoice {
|
ServerChoice {
|
||||||
@@ -33,15 +59,11 @@ impl ServerChoice {
|
|||||||
|
|
||||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||||
r: &mut R,
|
r: &mut R,
|
||||||
) -> Result<Self, DeserializationError> {
|
) -> Result<Self, ServerChoiceReadError> {
|
||||||
let mut buffer = [0; 1];
|
let version = r.read_u8().await?;
|
||||||
|
|
||||||
if r.read(&mut buffer).await? == 0 {
|
if version != 5 {
|
||||||
return Err(DeserializationError::NotEnoughData);
|
return Err(ServerChoiceReadError::InvalidVersion(version));
|
||||||
}
|
|
||||||
|
|
||||||
if buffer[0] != 5 {
|
|
||||||
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let chosen_method = AuthenticationMethod::read(r).await?;
|
let chosen_method = AuthenticationMethod::read(r).await?;
|
||||||
@@ -52,39 +74,32 @@ impl ServerChoice {
|
|||||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||||
&self,
|
&self,
|
||||||
w: &mut W,
|
w: &mut W,
|
||||||
) -> Result<(), SerializationError> {
|
) -> Result<(), ServerChoiceWriteError> {
|
||||||
w.write_all(&[5]).await?;
|
w.write_u8(5).await?;
|
||||||
self.chosen_method.write(w).await
|
self.chosen_method.write(w).await?;
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
standard_roundtrip!(server_choice_roundtrips, ServerChoice);
|
crate::standard_roundtrip!(server_choice_roundtrips, ServerChoice);
|
||||||
|
|
||||||
#[test]
|
#[tokio::test]
|
||||||
fn check_short_reads() {
|
async fn check_short_reads() {
|
||||||
let empty = vec![];
|
let empty = vec![];
|
||||||
let mut cursor = Cursor::new(empty);
|
let mut cursor = Cursor::new(empty);
|
||||||
let ys = ServerChoice::read(&mut cursor);
|
let ys = ServerChoice::read(&mut cursor).await;
|
||||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
assert!(matches!(ys, Err(ServerChoiceReadError::ReadError(_))));
|
||||||
|
|
||||||
let bad_len = vec![5];
|
let bad_len = vec![5];
|
||||||
let mut cursor = Cursor::new(bad_len);
|
let mut cursor = Cursor::new(bad_len);
|
||||||
let ys = ServerChoice::read(&mut cursor);
|
let ys = ServerChoice::read(&mut cursor).await;
|
||||||
assert_eq!(
|
assert!(matches!(ys, Err(ServerChoiceReadError::AuthMethodError(_))));
|
||||||
Err(DeserializationError::AuthenticationMethodError(
|
|
||||||
AuthenticationDeserializationError::NoDataFound
|
|
||||||
)),
|
|
||||||
task::block_on(ys)
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[tokio::test]
|
||||||
fn check_bad_version() {
|
async fn check_bad_version() {
|
||||||
let no_len = vec![9, 1];
|
let no_len = vec![9, 1];
|
||||||
let mut cursor = Cursor::new(no_len);
|
let mut cursor = Cursor::new(no_len);
|
||||||
let ys = ServerChoice::read(&mut cursor);
|
let ys = ServerChoice::read(&mut cursor).await;
|
||||||
assert_eq!(
|
assert_eq!(Err(ServerChoiceReadError::InvalidVersion(9)), ys);
|
||||||
Err(DeserializationError::InvalidVersion(5, 9)),
|
|
||||||
task::block_on(ys)
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,21 +1,10 @@
|
|||||||
use crate::errors::{DeserializationError, SerializationError};
|
use crate::address::{SOCKSv5Address, SOCKSv5AddressReadError, SOCKSv5AddressWriteError};
|
||||||
use crate::network::generic::IntoErrorResponse;
|
|
||||||
use crate::network::SOCKSv5Address;
|
|
||||||
use crate::serialize::read_amt;
|
|
||||||
use crate::standard_roundtrip;
|
|
||||||
#[cfg(test)]
|
|
||||||
use async_std::io::ErrorKind;
|
|
||||||
#[cfg(test)]
|
|
||||||
use async_std::task;
|
|
||||||
#[cfg(test)]
|
|
||||||
use futures::io::Cursor;
|
|
||||||
use futures::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
|
||||||
use log::warn;
|
|
||||||
use proptest::proptest;
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
use proptest_derive::Arbitrary;
|
use proptest_derive::Arbitrary;
|
||||||
use std::net::Ipv4Addr;
|
#[cfg(test)]
|
||||||
|
use std::io::Cursor;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||||
|
|
||||||
#[derive(Clone, Debug, Eq, Error, PartialEq)]
|
#[derive(Clone, Debug, Eq, Error, PartialEq)]
|
||||||
#[cfg_attr(test, derive(Arbitrary))]
|
#[cfg_attr(test, derive(Arbitrary))]
|
||||||
@@ -40,12 +29,6 @@ pub enum ServerResponseStatus {
|
|||||||
AddressTypeNotSupported,
|
AddressTypeNotSupported,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl IntoErrorResponse for ServerResponseStatus {
|
|
||||||
fn into_response(&self) -> ServerResponseStatus {
|
|
||||||
self.clone()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||||
#[cfg_attr(test, derive(Arbitrary))]
|
#[cfg_attr(test, derive(Arbitrary))]
|
||||||
pub struct ServerResponse {
|
pub struct ServerResponse {
|
||||||
@@ -54,33 +37,57 @@ pub struct ServerResponse {
|
|||||||
pub bound_port: u16,
|
pub bound_port: u16,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ServerResponse {
|
#[derive(Clone, Debug, Error, PartialEq)]
|
||||||
pub fn error<E: IntoErrorResponse>(resp: &E) -> ServerResponse {
|
pub enum ServerResponseReadError {
|
||||||
ServerResponse {
|
#[error("Error reading from underlying buffer: {0}")]
|
||||||
status: resp.into_response(),
|
ReadError(String),
|
||||||
bound_address: SOCKSv5Address::IP4(Ipv4Addr::new(0, 0, 0, 0)),
|
#[error(transparent)]
|
||||||
bound_port: 0,
|
AddressReadError(#[from] SOCKSv5AddressReadError),
|
||||||
}
|
#[error("Invalid version; expected 5, got {0}")]
|
||||||
|
InvalidVersion(u8),
|
||||||
|
#[error("Invalid reserved byte; saw {0}, should be 0")]
|
||||||
|
InvalidReservedByte(u8),
|
||||||
|
#[error("Invalid (or just unknown) server response value {0}")]
|
||||||
|
InvalidServerResponse(u8),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<std::io::Error> for ServerResponseReadError {
|
||||||
|
fn from(x: std::io::Error) -> ServerResponseReadError {
|
||||||
|
ServerResponseReadError::ReadError(format!("{}", x))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Error, PartialEq)]
|
||||||
|
pub enum ServerResponseWriteError {
|
||||||
|
#[error("Error reading from underlying buffer: {0}")]
|
||||||
|
WriteError(String),
|
||||||
|
#[error(transparent)]
|
||||||
|
AddressWriteError(#[from] SOCKSv5AddressWriteError),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<std::io::Error> for ServerResponseWriteError {
|
||||||
|
fn from(x: std::io::Error) -> ServerResponseWriteError {
|
||||||
|
ServerResponseWriteError::WriteError(format!("{}", x))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ServerResponse {
|
impl ServerResponse {
|
||||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
pub async fn read<R: AsyncRead + Send + Unpin>(
|
||||||
r: &mut R,
|
r: &mut R,
|
||||||
) -> Result<Self, DeserializationError> {
|
) -> Result<Self, ServerResponseReadError> {
|
||||||
let mut buffer = [0; 3];
|
let version = r.read_u8().await?;
|
||||||
|
if version != 5 {
|
||||||
read_amt(r, 3, &mut buffer).await?;
|
return Err(ServerResponseReadError::InvalidVersion(version));
|
||||||
|
|
||||||
if buffer[0] != 5 {
|
|
||||||
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if buffer[2] != 0 {
|
let status_byte = r.read_u8().await?;
|
||||||
warn!(target: "async-socks5", "Hey, this isn't terrible, but the server is sending invalid reserved bytes.");
|
|
||||||
|
let reserved_byte = r.read_u8().await?;
|
||||||
|
if reserved_byte != 0 {
|
||||||
|
return Err(ServerResponseReadError::InvalidReservedByte(reserved_byte));
|
||||||
}
|
}
|
||||||
|
|
||||||
let status = match buffer[1] {
|
let status = match status_byte {
|
||||||
0x00 => ServerResponseStatus::RequestGranted,
|
0x00 => ServerResponseStatus::RequestGranted,
|
||||||
0x01 => ServerResponseStatus::GeneralFailure,
|
0x01 => ServerResponseStatus::GeneralFailure,
|
||||||
0x02 => ServerResponseStatus::ConnectionNotAllowedByRule,
|
0x02 => ServerResponseStatus::ConnectionNotAllowedByRule,
|
||||||
@@ -90,12 +97,11 @@ impl ServerResponse {
|
|||||||
0x06 => ServerResponseStatus::TTLExpired,
|
0x06 => ServerResponseStatus::TTLExpired,
|
||||||
0x07 => ServerResponseStatus::CommandNotSupported,
|
0x07 => ServerResponseStatus::CommandNotSupported,
|
||||||
0x08 => ServerResponseStatus::AddressTypeNotSupported,
|
0x08 => ServerResponseStatus::AddressTypeNotSupported,
|
||||||
x => return Err(DeserializationError::InvalidServerResponse(x)),
|
x => return Err(ServerResponseReadError::InvalidServerResponse(x)),
|
||||||
};
|
};
|
||||||
|
|
||||||
let bound_address = SOCKSv5Address::read(r).await?;
|
let bound_address = SOCKSv5Address::read(r).await?;
|
||||||
read_amt(r, 2, &mut buffer).await?;
|
let bound_port = r.read_u16().await?;
|
||||||
let bound_port = ((buffer[0] as u16) << 8) + (buffer[1] as u16);
|
|
||||||
|
|
||||||
Ok(ServerResponse {
|
Ok(ServerResponse {
|
||||||
status,
|
status,
|
||||||
@@ -107,7 +113,9 @@ impl ServerResponse {
|
|||||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
||||||
&self,
|
&self,
|
||||||
w: &mut W,
|
w: &mut W,
|
||||||
) -> Result<(), SerializationError> {
|
) -> Result<(), ServerResponseWriteError> {
|
||||||
|
w.write_u8(5).await?;
|
||||||
|
|
||||||
let status_code = match self.status {
|
let status_code = match self.status {
|
||||||
ServerResponseStatus::RequestGranted => 0x00,
|
ServerResponseStatus::RequestGranted => 0x00,
|
||||||
ServerResponseStatus::GeneralFailure => 0x01,
|
ServerResponseStatus::GeneralFailure => 0x01,
|
||||||
@@ -119,59 +127,61 @@ impl ServerResponse {
|
|||||||
ServerResponseStatus::CommandNotSupported => 0x07,
|
ServerResponseStatus::CommandNotSupported => 0x07,
|
||||||
ServerResponseStatus::AddressTypeNotSupported => 0x08,
|
ServerResponseStatus::AddressTypeNotSupported => 0x08,
|
||||||
};
|
};
|
||||||
|
w.write_u8(status_code).await?;
|
||||||
w.write_all(&[5, status_code, 0]).await?;
|
w.write_u8(0).await?;
|
||||||
self.bound_address.write(w).await?;
|
self.bound_address.write(w).await?;
|
||||||
w.write_all(&[
|
w.write_u16(self.bound_port).await?;
|
||||||
(self.bound_port >> 8) as u8,
|
|
||||||
(self.bound_port & 0xffu16) as u8,
|
Ok(())
|
||||||
])
|
|
||||||
.await
|
|
||||||
.map_err(SerializationError::IOError)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
standard_roundtrip!(server_response_roundtrips, ServerResponse);
|
crate::standard_roundtrip!(server_response_roundtrips, ServerResponse);
|
||||||
|
|
||||||
#[test]
|
#[tokio::test]
|
||||||
fn check_short_reads() {
|
async fn check_short_reads() {
|
||||||
let empty = vec![];
|
let empty = vec![];
|
||||||
let mut cursor = Cursor::new(empty);
|
let mut cursor = Cursor::new(empty);
|
||||||
let ys = ServerResponse::read(&mut cursor);
|
let ys = ServerResponse::read(&mut cursor).await;
|
||||||
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
|
assert!(matches!(ys, Err(ServerResponseReadError::ReadError(_))));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[tokio::test]
|
||||||
fn check_bad_version() {
|
async fn check_bad_version() {
|
||||||
let bad_ver = vec![6, 1, 1];
|
let bad_ver = vec![6, 1, 1];
|
||||||
let mut cursor = Cursor::new(bad_ver);
|
let mut cursor = Cursor::new(bad_ver);
|
||||||
let ys = ServerResponse::read(&mut cursor);
|
let ys = ServerResponse::read(&mut cursor).await;
|
||||||
assert_eq!(
|
assert_eq!(Err(ServerResponseReadError::InvalidVersion(6)), ys);
|
||||||
Err(DeserializationError::InvalidVersion(5, 6)),
|
|
||||||
task::block_on(ys)
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[tokio::test]
|
||||||
fn check_bad_command() {
|
async fn check_bad_reserved() {
|
||||||
let bad_cmd = vec![5, 32, 0x42];
|
let bad_cmd = vec![5, 32, 0x42];
|
||||||
let mut cursor = Cursor::new(bad_cmd);
|
let mut cursor = Cursor::new(bad_cmd);
|
||||||
let ys = ServerResponse::read(&mut cursor);
|
let ys = ServerResponse::read(&mut cursor).await;
|
||||||
assert_eq!(
|
assert_eq!(Err(ServerResponseReadError::InvalidReservedByte(0x42)), ys);
|
||||||
Err(DeserializationError::InvalidServerResponse(32)),
|
|
||||||
task::block_on(ys)
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[tokio::test]
|
||||||
fn short_write_fails_right() {
|
async fn check_bad_command() {
|
||||||
let mut buffer = [0u8; 2];
|
let bad_cmd = vec![5, 32, 0];
|
||||||
let cmd = ServerResponse::error(&ServerResponseStatus::AddressTypeNotSupported);
|
let mut cursor = Cursor::new(bad_cmd);
|
||||||
let mut cursor = Cursor::new(&mut buffer as &mut [u8]);
|
let ys = ServerResponse::read(&mut cursor).await;
|
||||||
let result = task::block_on(cmd.write(&mut cursor));
|
assert_eq!(Err(ServerResponseReadError::InvalidServerResponse(32)), ys);
|
||||||
match result {
|
}
|
||||||
Ok(_) => panic!("Mysteriously able to fit > 2 bytes in 2 bytes."),
|
|
||||||
Err(SerializationError::IOError(x)) => assert_eq!(ErrorKind::WriteZero, x.kind()),
|
#[tokio::test]
|
||||||
Err(e) => panic!("Got the wrong error writing too much data: {}", e),
|
async fn short_write_fails_right() {
|
||||||
}
|
let mut buffer = [0u8; 2];
|
||||||
|
let cmd = ServerResponse {
|
||||||
|
status: ServerResponseStatus::AddressTypeNotSupported,
|
||||||
|
bound_address: SOCKSv5Address::Hostname("tester.com".to_string()),
|
||||||
|
bound_port: 99,
|
||||||
|
};
|
||||||
|
let mut cursor = Cursor::new(&mut buffer as &mut [u8]);
|
||||||
|
let result = cmd.write(&mut cursor).await;
|
||||||
|
assert!(matches!(
|
||||||
|
result,
|
||||||
|
Err(ServerResponseWriteError::WriteError(_))
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|||||||
117
src/messages/string.rs
Normal file
117
src/messages/string.rs
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
#[cfg(test)]
|
||||||
|
use proptest::prelude::{Arbitrary, BoxedStrategy, Strategy};
|
||||||
|
use std::convert::TryFrom;
|
||||||
|
use std::string::FromUtf8Error;
|
||||||
|
use thiserror::Error;
|
||||||
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq)]
|
||||||
|
pub struct SOCKSv5String(String);
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
const STRING_REGEX: &str = "[a-zA-Z0-9_.|!@#$%^]+";
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
impl Arbitrary for SOCKSv5String {
|
||||||
|
type Parameters = Option<u16>;
|
||||||
|
type Strategy = BoxedStrategy<Self>;
|
||||||
|
|
||||||
|
fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
|
||||||
|
let max_len = args.unwrap_or(32) as usize;
|
||||||
|
|
||||||
|
STRING_REGEX
|
||||||
|
.prop_map(move |mut str| {
|
||||||
|
str.shrink_to(max_len);
|
||||||
|
SOCKSv5String(str)
|
||||||
|
})
|
||||||
|
.boxed()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Error, PartialEq)]
|
||||||
|
pub enum SOCKSv5StringReadError {
|
||||||
|
#[error("Underlying buffer read error: {0}")]
|
||||||
|
ReadError(String),
|
||||||
|
#[error("SOCKSv5 string encoding error; encountered empty string (?)")]
|
||||||
|
ZeroStringLength,
|
||||||
|
#[error("Invalid UTF-8 string: {0}")]
|
||||||
|
InvalidUtf8Error(#[from] FromUtf8Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<std::io::Error> for SOCKSv5StringReadError {
|
||||||
|
fn from(x: std::io::Error) -> SOCKSv5StringReadError {
|
||||||
|
SOCKSv5StringReadError::ReadError(format!("{}", x))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Error, PartialEq)]
|
||||||
|
pub enum SOCKSv5StringWriteError {
|
||||||
|
#[error("Underlying buffer write error: {0}")]
|
||||||
|
WriteError(String),
|
||||||
|
#[error("String too large to encode according to SOCKSv5 reuls ({0} bytes long)")]
|
||||||
|
TooBig(usize),
|
||||||
|
#[error("Cannot serialize the empty string in SOCKSv5")]
|
||||||
|
ZeroStringLength,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<std::io::Error> for SOCKSv5StringWriteError {
|
||||||
|
fn from(x: std::io::Error) -> SOCKSv5StringWriteError {
|
||||||
|
SOCKSv5StringWriteError::WriteError(format!("{}", x))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SOCKSv5String {
|
||||||
|
pub async fn read<R: AsyncRead + Unpin>(r: &mut R) -> Result<Self, SOCKSv5StringReadError> {
|
||||||
|
let length = r.read_u8().await? as usize;
|
||||||
|
|
||||||
|
if length == 0 {
|
||||||
|
return Err(SOCKSv5StringReadError::ZeroStringLength);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut bytestring = vec![0; length];
|
||||||
|
r.read_exact(&mut bytestring).await?;
|
||||||
|
|
||||||
|
Ok(SOCKSv5String(String::from_utf8(bytestring)?))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn write<W: AsyncWrite + Unpin>(
|
||||||
|
&self,
|
||||||
|
w: &mut W,
|
||||||
|
) -> Result<(), SOCKSv5StringWriteError> {
|
||||||
|
let bytestring = self.0.as_bytes();
|
||||||
|
|
||||||
|
if bytestring.is_empty() {
|
||||||
|
return Err(SOCKSv5StringWriteError::ZeroStringLength);
|
||||||
|
}
|
||||||
|
|
||||||
|
let length = match u8::try_from(bytestring.len()) {
|
||||||
|
Err(_) => return Err(SOCKSv5StringWriteError::TooBig(bytestring.len())),
|
||||||
|
Ok(x) => x,
|
||||||
|
};
|
||||||
|
|
||||||
|
w.write_u8(length).await?;
|
||||||
|
w.write_all(bytestring).await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<String> for SOCKSv5String {
|
||||||
|
fn from(x: String) -> Self {
|
||||||
|
SOCKSv5String(x)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> From<&'a str> for SOCKSv5String {
|
||||||
|
fn from(x: &str) -> Self {
|
||||||
|
SOCKSv5String(x.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<SOCKSv5String> for String {
|
||||||
|
fn from(x: SOCKSv5String) -> Self {
|
||||||
|
x.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
crate::standard_roundtrip!(socks_string_roundtrips, SOCKSv5String);
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
#[doc(hidden)]
|
|
||||||
#[macro_export]
|
|
||||||
macro_rules! standard_roundtrip {
|
|
||||||
($name: ident, $t: ty) => {
|
|
||||||
proptest! {
|
|
||||||
#[test]
|
|
||||||
fn $name(xs: $t) {
|
|
||||||
let mut buffer = vec![];
|
|
||||||
task::block_on(xs.write(&mut buffer)).unwrap();
|
|
||||||
let mut cursor = Cursor::new(buffer);
|
|
||||||
let ys = <$t>::read(&mut cursor);
|
|
||||||
assert_eq!(xs, task::block_on(ys).unwrap());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
@@ -1,10 +0,0 @@
|
|||||||
pub mod address;
|
|
||||||
pub mod datagram;
|
|
||||||
pub mod generic;
|
|
||||||
pub mod listener;
|
|
||||||
pub mod standard;
|
|
||||||
pub mod stream;
|
|
||||||
pub mod testing;
|
|
||||||
|
|
||||||
pub use crate::network::address::SOCKSv5Address;
|
|
||||||
pub use crate::network::standard::Builtin;
|
|
||||||
@@ -1,264 +0,0 @@
|
|||||||
use crate::errors::{DeserializationError, SerializationError};
|
|
||||||
use crate::serialize::{read_amt, read_string, write_string};
|
|
||||||
use crate::standard_roundtrip;
|
|
||||||
#[cfg(test)]
|
|
||||||
use async_std::task;
|
|
||||||
#[cfg(test)]
|
|
||||||
use futures::io::Cursor;
|
|
||||||
use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
|
||||||
use proptest::prelude::proptest;
|
|
||||||
#[cfg(test)]
|
|
||||||
use proptest::prelude::{Arbitrary, BoxedStrategy, Strategy, any, prop_oneof};
|
|
||||||
use std::convert::TryFrom;
|
|
||||||
use std::fmt;
|
|
||||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
|
||||||
use thiserror::Error;
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
|
|
||||||
pub enum SOCKSv5Address {
|
|
||||||
IP4(Ipv4Addr),
|
|
||||||
IP6(Ipv6Addr),
|
|
||||||
Name(String),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
const HOSTNAME_REGEX: &str = "[a-zA-Z0-9_.]+";
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
impl Arbitrary for SOCKSv5Address {
|
|
||||||
type Parameters = Option<u16>;
|
|
||||||
type Strategy = BoxedStrategy<Self>;
|
|
||||||
|
|
||||||
fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
|
|
||||||
let max_len = args.unwrap_or(32) as usize;
|
|
||||||
|
|
||||||
prop_oneof![
|
|
||||||
any::<Ipv4Addr>().prop_map(SOCKSv5Address::IP4),
|
|
||||||
any::<Ipv6Addr>().prop_map(SOCKSv5Address::IP6),
|
|
||||||
HOSTNAME_REGEX.prop_map(move |mut hostname| {
|
|
||||||
hostname.shrink_to(max_len);
|
|
||||||
SOCKSv5Address::Name(hostname)
|
|
||||||
}),
|
|
||||||
].boxed()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Error, Debug, PartialEq)]
|
|
||||||
pub enum AddressConversionError {
|
|
||||||
#[error("Couldn't convert IPv4 address into destination type")]
|
|
||||||
CouldntConvertIP4,
|
|
||||||
#[error("Couldn't convert IPv6 address into destination type")]
|
|
||||||
CouldntConvertIP6,
|
|
||||||
#[error("Couldn't convert name into destination type")]
|
|
||||||
CouldntConvertName,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<IpAddr> for SOCKSv5Address {
|
|
||||||
fn from(x: IpAddr) -> SOCKSv5Address {
|
|
||||||
match x {
|
|
||||||
IpAddr::V4(a) => SOCKSv5Address::IP4(a),
|
|
||||||
IpAddr::V6(a) => SOCKSv5Address::IP6(a),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TryFrom<SOCKSv5Address> for IpAddr {
|
|
||||||
type Error = AddressConversionError;
|
|
||||||
|
|
||||||
fn try_from(value: SOCKSv5Address) -> Result<Self, Self::Error> {
|
|
||||||
match value {
|
|
||||||
SOCKSv5Address::IP4(a) => Ok(IpAddr::V4(a)),
|
|
||||||
SOCKSv5Address::IP6(a) => Ok(IpAddr::V6(a)),
|
|
||||||
SOCKSv5Address::Name(_) => Err(AddressConversionError::CouldntConvertName),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<Ipv4Addr> for SOCKSv5Address {
|
|
||||||
fn from(x: Ipv4Addr) -> Self {
|
|
||||||
SOCKSv5Address::IP4(x)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TryFrom<SOCKSv5Address> for Ipv4Addr {
|
|
||||||
type Error = AddressConversionError;
|
|
||||||
|
|
||||||
fn try_from(value: SOCKSv5Address) -> Result<Self, Self::Error> {
|
|
||||||
match value {
|
|
||||||
SOCKSv5Address::IP4(a) => Ok(a),
|
|
||||||
SOCKSv5Address::IP6(_) => Err(AddressConversionError::CouldntConvertIP6),
|
|
||||||
SOCKSv5Address::Name(_) => Err(AddressConversionError::CouldntConvertName),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<Ipv6Addr> for SOCKSv5Address {
|
|
||||||
fn from(x: Ipv6Addr) -> Self {
|
|
||||||
SOCKSv5Address::IP6(x)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TryFrom<SOCKSv5Address> for Ipv6Addr {
|
|
||||||
type Error = AddressConversionError;
|
|
||||||
|
|
||||||
fn try_from(value: SOCKSv5Address) -> Result<Self, Self::Error> {
|
|
||||||
match value {
|
|
||||||
SOCKSv5Address::IP4(_) => Err(AddressConversionError::CouldntConvertIP4),
|
|
||||||
SOCKSv5Address::IP6(a) => Ok(a),
|
|
||||||
SOCKSv5Address::Name(_) => Err(AddressConversionError::CouldntConvertName),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<String> for SOCKSv5Address {
|
|
||||||
fn from(x: String) -> Self {
|
|
||||||
SOCKSv5Address::Name(x)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a> From<&'a str> for SOCKSv5Address {
|
|
||||||
fn from(x: &str) -> SOCKSv5Address {
|
|
||||||
SOCKSv5Address::Name(x.to_string())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl fmt::Display for SOCKSv5Address {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
|
||||||
match self {
|
|
||||||
SOCKSv5Address::IP4(a) => write!(f, "{}", a),
|
|
||||||
SOCKSv5Address::IP6(a) => write!(f, "{}", a),
|
|
||||||
SOCKSv5Address::Name(a) => write!(f, "{}", a),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl SOCKSv5Address {
|
|
||||||
pub async fn read<R: AsyncRead + Send + Unpin>(
|
|
||||||
r: &mut R,
|
|
||||||
) -> Result<Self, DeserializationError> {
|
|
||||||
let mut byte_buffer = [0u8; 1];
|
|
||||||
let amount_read = r.read(&mut byte_buffer).await?;
|
|
||||||
|
|
||||||
if amount_read == 0 {
|
|
||||||
return Err(DeserializationError::NotEnoughData);
|
|
||||||
}
|
|
||||||
|
|
||||||
match byte_buffer[0] {
|
|
||||||
1 => {
|
|
||||||
let mut addr_buffer = [0; 4];
|
|
||||||
read_amt(r, 4, &mut addr_buffer).await?;
|
|
||||||
Ok(SOCKSv5Address::IP4(Ipv4Addr::from(addr_buffer)))
|
|
||||||
}
|
|
||||||
3 => {
|
|
||||||
let name = read_string(r).await?;
|
|
||||||
Ok(SOCKSv5Address::Name(name))
|
|
||||||
}
|
|
||||||
4 => {
|
|
||||||
let mut addr_buffer = [0; 16];
|
|
||||||
read_amt(r, 16, &mut addr_buffer).await?;
|
|
||||||
Ok(SOCKSv5Address::IP6(Ipv6Addr::from(addr_buffer)))
|
|
||||||
}
|
|
||||||
x => Err(DeserializationError::InvalidAddressType(x)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn write<W: AsyncWrite + Send + Unpin>(
|
|
||||||
&self,
|
|
||||||
w: &mut W,
|
|
||||||
) -> Result<(), SerializationError> {
|
|
||||||
match self {
|
|
||||||
SOCKSv5Address::IP4(x) => {
|
|
||||||
w.write_all(&[1]).await?;
|
|
||||||
w.write_all(&x.octets())
|
|
||||||
.await
|
|
||||||
.map_err(SerializationError::IOError)
|
|
||||||
}
|
|
||||||
SOCKSv5Address::IP6(x) => {
|
|
||||||
w.write_all(&[4]).await?;
|
|
||||||
w.write_all(&x.octets())
|
|
||||||
.await
|
|
||||||
.map_err(SerializationError::IOError)
|
|
||||||
}
|
|
||||||
SOCKSv5Address::Name(x) => {
|
|
||||||
w.write_all(&[3]).await?;
|
|
||||||
write_string(x, w).await
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait HasLocalAddress {
|
|
||||||
fn local_addr(&self) -> (SOCKSv5Address, u16);
|
|
||||||
}
|
|
||||||
|
|
||||||
standard_roundtrip!(address_roundtrips, SOCKSv5Address);
|
|
||||||
|
|
||||||
proptest! {
|
|
||||||
#[test]
|
|
||||||
fn ip_conversion(x: IpAddr) {
|
|
||||||
match x {
|
|
||||||
IpAddr::V4(ref a) =>
|
|
||||||
assert_eq!(Err(AddressConversionError::CouldntConvertIP4),
|
|
||||||
Ipv6Addr::try_from(SOCKSv5Address::from(*a))),
|
|
||||||
IpAddr::V6(ref a) =>
|
|
||||||
assert_eq!(Err(AddressConversionError::CouldntConvertIP6),
|
|
||||||
Ipv4Addr::try_from(SOCKSv5Address::from(*a))),
|
|
||||||
}
|
|
||||||
assert_eq!(x, IpAddr::try_from(SOCKSv5Address::from(x)).unwrap());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn ip4_conversion(x: Ipv4Addr) {
|
|
||||||
assert_eq!(x, Ipv4Addr::try_from(SOCKSv5Address::from(x)).unwrap());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn ip6_conversion(x: Ipv6Addr) {
|
|
||||||
assert_eq!(x, Ipv6Addr::try_from(SOCKSv5Address::from(x)).unwrap());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn display_matches(x: SOCKSv5Address) {
|
|
||||||
match x {
|
|
||||||
SOCKSv5Address::IP4(a) => assert_eq!(format!("{}", a), format!("{}", x)),
|
|
||||||
SOCKSv5Address::IP6(a) => assert_eq!(format!("{}", a), format!("{}", x)),
|
|
||||||
SOCKSv5Address::Name(ref a) => assert_eq!(*a, x.to_string()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn bad_read_key(x: u8) {
|
|
||||||
match x {
|
|
||||||
1 | 3 | 4 => {}
|
|
||||||
_ => {
|
|
||||||
let buffer = [x, 0, 1, 2, 9, 10];
|
|
||||||
let mut cursor = Cursor::new(buffer);
|
|
||||||
let meh = SOCKSv5Address::read(&mut cursor);
|
|
||||||
assert_eq!(Err(DeserializationError::InvalidAddressType(x)), task::block_on(meh));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn domain_name_sanity() {
|
|
||||||
let name = "uhsure.com";
|
|
||||||
let strname = name.to_string();
|
|
||||||
|
|
||||||
let addr1 = SOCKSv5Address::from(name);
|
|
||||||
let addr2 = SOCKSv5Address::from(strname);
|
|
||||||
|
|
||||||
assert_eq!(addr1, addr2);
|
|
||||||
assert_eq!(
|
|
||||||
Err(AddressConversionError::CouldntConvertName),
|
|
||||||
IpAddr::try_from(addr1.clone())
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
Err(AddressConversionError::CouldntConvertName),
|
|
||||||
Ipv4Addr::try_from(addr1.clone())
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
Err(AddressConversionError::CouldntConvertName),
|
|
||||||
Ipv6Addr::try_from(addr1)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -1,43 +0,0 @@
|
|||||||
use crate::network::address::{HasLocalAddress, SOCKSv5Address};
|
|
||||||
use async_trait::async_trait;
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
pub trait Datagramlike: Send + Sync + HasLocalAddress {
|
|
||||||
type Error;
|
|
||||||
|
|
||||||
async fn send_to(
|
|
||||||
&self,
|
|
||||||
buf: &[u8],
|
|
||||||
addr: SOCKSv5Address,
|
|
||||||
port: u16,
|
|
||||||
) -> Result<usize, Self::Error>;
|
|
||||||
async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SOCKSv5Address, u16), Self::Error>;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct GenericDatagramSocket<E> {
|
|
||||||
pub internal: Box<dyn Datagramlike<Error = E>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl<E> Datagramlike for GenericDatagramSocket<E> {
|
|
||||||
type Error = E;
|
|
||||||
|
|
||||||
async fn send_to(
|
|
||||||
&self,
|
|
||||||
buf: &[u8],
|
|
||||||
addr: SOCKSv5Address,
|
|
||||||
port: u16,
|
|
||||||
) -> Result<usize, Self::Error> {
|
|
||||||
Ok(self.internal.send_to(buf, addr, port).await?)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SOCKSv5Address, u16), Self::Error> {
|
|
||||||
Ok(self.internal.recv_from(buf).await?)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<E> HasLocalAddress for GenericDatagramSocket<E> {
|
|
||||||
fn local_addr(&self) -> (SOCKSv5Address, u16) {
|
|
||||||
self.internal.local_addr()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,61 +0,0 @@
|
|||||||
use crate::messages::ServerResponseStatus;
|
|
||||||
use crate::network::address::SOCKSv5Address;
|
|
||||||
use crate::network::datagram::GenericDatagramSocket;
|
|
||||||
use crate::network::listener::GenericListener;
|
|
||||||
use crate::network::stream::GenericStream;
|
|
||||||
use async_trait::async_trait;
|
|
||||||
use std::fmt::{Debug, Display};
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
pub trait Networklike {
|
|
||||||
/// The error type for things that fail on this network. Apologies in advance
|
|
||||||
/// for using only one; if you have a use case for separating your errors,
|
|
||||||
/// please shoot the author(s) and email to split this into multiple types, one
|
|
||||||
/// for each trait function.
|
|
||||||
type Error: Debug + Display + IntoErrorResponse + Send;
|
|
||||||
|
|
||||||
/// Connect to the given address and port, over this kind of network. The
|
|
||||||
/// underlying stream should behave somewhat like a TCP stream ... which
|
|
||||||
/// may be exactly what you're using. However, in order to support tunnelling
|
|
||||||
/// scenarios (i.e., using another proxy, going through Tor or SSH, etc.) we
|
|
||||||
/// work generically over any stream-like object.
|
|
||||||
async fn connect<A: Send + Into<SOCKSv5Address>>(
|
|
||||||
&mut self,
|
|
||||||
addr: A,
|
|
||||||
port: u16,
|
|
||||||
) -> Result<GenericStream, Self::Error>;
|
|
||||||
|
|
||||||
/// Listen for connections on the given address and port, returning a generic
|
|
||||||
/// listener socket to use in the future.
|
|
||||||
async fn listen<A: Send + Into<SOCKSv5Address>>(
|
|
||||||
&mut self,
|
|
||||||
addr: A,
|
|
||||||
port: u16,
|
|
||||||
) -> Result<GenericListener<Self::Error>, Self::Error>;
|
|
||||||
|
|
||||||
/// Bind a socket for the purposes of doing some datagram communication. NOTE!
|
|
||||||
/// this is only for UDP-like communication, not for generic connecting or
|
|
||||||
/// listening! Maybe obvious from the types, but POSIX has overtrained many
|
|
||||||
/// of us.
|
|
||||||
///
|
|
||||||
/// Recall when using these functions that datagram protocols allow for packet
|
|
||||||
/// loss and out-of-order delivery. So ... be warned.
|
|
||||||
async fn bind<A: Send + Into<SOCKSv5Address>>(
|
|
||||||
&mut self,
|
|
||||||
addr: A,
|
|
||||||
port: u16,
|
|
||||||
) -> Result<GenericDatagramSocket<Self::Error>, Self::Error>;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// This trait is a hack; sorry about that. The thing is, we want to be able to
|
|
||||||
/// convert Errors from the `Networklike` trait into a `ServerResponseStatus`,
|
|
||||||
/// but want to do so on references to the error object rather than the actual
|
|
||||||
/// object. This is for the paired reason that (a) we want to be able to use
|
|
||||||
/// the errors in multiple places -- for example, to return a value to the client
|
|
||||||
/// and then also to whoever called the function -- and (b) some common errors
|
|
||||||
/// (I'm looking at you, `io::Error`) aren't `Clone`. So ... hence this overly-
|
|
||||||
/// specific trait.
|
|
||||||
pub trait IntoErrorResponse {
|
|
||||||
#[allow(clippy::wrong_self_convention)]
|
|
||||||
fn into_response(&self) -> ServerResponseStatus;
|
|
||||||
}
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
use crate::network::address::{HasLocalAddress, SOCKSv5Address};
|
|
||||||
use crate::network::stream::GenericStream;
|
|
||||||
use async_trait::async_trait;
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
pub trait Listenerlike: Send + Sync + HasLocalAddress {
|
|
||||||
type Error;
|
|
||||||
|
|
||||||
async fn accept(&self) -> Result<(GenericStream, SOCKSv5Address, u16), Self::Error>;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct GenericListener<E> {
|
|
||||||
pub internal: Box<dyn Listenerlike<Error = E>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl<E> Listenerlike for GenericListener<E> {
|
|
||||||
type Error = E;
|
|
||||||
|
|
||||||
async fn accept(&self) -> Result<(GenericStream, SOCKSv5Address, u16), Self::Error> {
|
|
||||||
Ok(self.internal.accept().await?)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<E> HasLocalAddress for GenericListener<E> {
|
|
||||||
fn local_addr(&self) -> (SOCKSv5Address, u16) {
|
|
||||||
self.internal.local_addr()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,240 +0,0 @@
|
|||||||
use crate::messages::ServerResponseStatus;
|
|
||||||
use crate::network::address::{HasLocalAddress, SOCKSv5Address};
|
|
||||||
use crate::network::datagram::{Datagramlike, GenericDatagramSocket};
|
|
||||||
use crate::network::generic::Networklike;
|
|
||||||
use crate::network::listener::{GenericListener, Listenerlike};
|
|
||||||
use crate::network::stream::{GenericStream, Streamlike};
|
|
||||||
use async_std::io;
|
|
||||||
#[cfg(test)]
|
|
||||||
use async_std::io::ReadExt;
|
|
||||||
use async_std::net::{TcpListener, TcpStream, UdpSocket};
|
|
||||||
use async_trait::async_trait;
|
|
||||||
#[cfg(test)]
|
|
||||||
use futures::AsyncWriteExt;
|
|
||||||
use log::error;
|
|
||||||
|
|
||||||
use super::generic::IntoErrorResponse;
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct Builtin {}
|
|
||||||
|
|
||||||
impl Builtin {
|
|
||||||
pub fn new() -> Builtin {
|
|
||||||
Builtin {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for Builtin {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self::new()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
macro_rules! local_address_impl {
|
|
||||||
($t: ty) => {
|
|
||||||
impl HasLocalAddress for $t {
|
|
||||||
fn local_addr(&self) -> (SOCKSv5Address, u16) {
|
|
||||||
match self.local_addr() {
|
|
||||||
Ok(a) =>
|
|
||||||
(SOCKSv5Address::from(a.ip()), a.port()),
|
|
||||||
Err(e) => {
|
|
||||||
error!("Couldn't translate (Streamlike) local address to SOCKS local address: {}", e);
|
|
||||||
(SOCKSv5Address::from("localhost"), 0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
local_address_impl!(TcpStream);
|
|
||||||
local_address_impl!(TcpListener);
|
|
||||||
local_address_impl!(UdpSocket);
|
|
||||||
|
|
||||||
impl Streamlike for TcpStream {}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl Listenerlike for TcpListener {
|
|
||||||
type Error = io::Error;
|
|
||||||
|
|
||||||
async fn accept(&self) -> Result<(GenericStream, SOCKSv5Address, u16), Self::Error> {
|
|
||||||
let (base, addrport) = self.accept().await?;
|
|
||||||
let addr = addrport.ip();
|
|
||||||
let port = addrport.port();
|
|
||||||
Ok((GenericStream::new(base), SOCKSv5Address::from(addr), port))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl Datagramlike for UdpSocket {
|
|
||||||
type Error = io::Error;
|
|
||||||
|
|
||||||
async fn send_to(
|
|
||||||
&self,
|
|
||||||
buf: &[u8],
|
|
||||||
addr: SOCKSv5Address,
|
|
||||||
port: u16,
|
|
||||||
) -> Result<usize, Self::Error> {
|
|
||||||
match addr {
|
|
||||||
SOCKSv5Address::IP4(a) => self.send_to(buf, (a, port)).await,
|
|
||||||
SOCKSv5Address::IP6(a) => self.send_to(buf, (a, port)).await,
|
|
||||||
SOCKSv5Address::Name(n) => self.send_to(buf, (n.as_str(), port)).await,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SOCKSv5Address, u16), Self::Error> {
|
|
||||||
let (amt, addrport) = self.recv_from(buf).await?;
|
|
||||||
let addr = addrport.ip();
|
|
||||||
let port = addrport.port();
|
|
||||||
Ok((amt, SOCKSv5Address::from(addr), port))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl Networklike for Builtin {
|
|
||||||
type Error = io::Error;
|
|
||||||
|
|
||||||
async fn connect<A: Send + Into<SOCKSv5Address>>(
|
|
||||||
&mut self,
|
|
||||||
addr: A,
|
|
||||||
port: u16,
|
|
||||||
) -> Result<GenericStream, Self::Error> {
|
|
||||||
let target = addr.into();
|
|
||||||
|
|
||||||
let base_stream = match target {
|
|
||||||
SOCKSv5Address::IP4(a) => TcpStream::connect((a, port)).await?,
|
|
||||||
SOCKSv5Address::IP6(a) => TcpStream::connect((a, port)).await?,
|
|
||||||
SOCKSv5Address::Name(n) => TcpStream::connect((n.as_str(), port)).await?,
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(GenericStream::from(base_stream))
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn listen<A: Send + Into<SOCKSv5Address>>(
|
|
||||||
&mut self,
|
|
||||||
addr: A,
|
|
||||||
port: u16,
|
|
||||||
) -> Result<GenericListener<Self::Error>, Self::Error> {
|
|
||||||
let target = addr.into();
|
|
||||||
|
|
||||||
let base_stream = match target {
|
|
||||||
SOCKSv5Address::IP4(a) => TcpListener::bind((a, port)).await?,
|
|
||||||
SOCKSv5Address::IP6(a) => TcpListener::bind((a, port)).await?,
|
|
||||||
SOCKSv5Address::Name(n) => TcpListener::bind((n.as_str(), port)).await?,
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(GenericListener {
|
|
||||||
internal: Box::new(base_stream),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn bind<A: Send + Into<SOCKSv5Address>>(
|
|
||||||
&mut self,
|
|
||||||
addr: A,
|
|
||||||
port: u16,
|
|
||||||
) -> Result<GenericDatagramSocket<Self::Error>, Self::Error> {
|
|
||||||
let target = addr.into();
|
|
||||||
|
|
||||||
let base_socket = match target {
|
|
||||||
SOCKSv5Address::IP4(a) => UdpSocket::bind((a, port)).await?,
|
|
||||||
SOCKSv5Address::IP6(a) => UdpSocket::bind((a, port)).await?,
|
|
||||||
SOCKSv5Address::Name(n) => UdpSocket::bind((n.as_str(), port)).await?,
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(GenericDatagramSocket {
|
|
||||||
internal: Box::new(base_socket),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn check_sanity() {
|
|
||||||
async_std::task::block_on(async {
|
|
||||||
// Technically, this is UDP, and UDP is lossy. We're going to assume we're not
|
|
||||||
// going to get any dropped data along here ... which is a very questionable
|
|
||||||
// assumption, morally speaking, but probably fine for most purposes.
|
|
||||||
let mut network = Builtin::new();
|
|
||||||
let receiver = network
|
|
||||||
.bind("localhost", 0)
|
|
||||||
.await
|
|
||||||
.expect("Failed to bind receiver socket.");
|
|
||||||
let sender = network
|
|
||||||
.bind("localhost", 0)
|
|
||||||
.await
|
|
||||||
.expect("Failed to bind sender socket.");
|
|
||||||
let buffer = [0xde, 0xea, 0xbe, 0xef];
|
|
||||||
let (receiver_addr, receiver_port) = receiver.local_addr();
|
|
||||||
sender
|
|
||||||
.send_to(&buffer, receiver_addr, receiver_port)
|
|
||||||
.await
|
|
||||||
.expect("Failure sending datagram!");
|
|
||||||
let mut recvbuffer = [0; 4];
|
|
||||||
let (s, f, p) = receiver
|
|
||||||
.recv_from(&mut recvbuffer)
|
|
||||||
.await
|
|
||||||
.expect("Didn't receive UDP message?");
|
|
||||||
let (sender_addr, sender_port) = sender.local_addr();
|
|
||||||
assert_eq!(s, 4);
|
|
||||||
assert_eq!(f, sender_addr);
|
|
||||||
assert_eq!(p, sender_port);
|
|
||||||
assert_eq!(recvbuffer, buffer);
|
|
||||||
});
|
|
||||||
|
|
||||||
// This whole block should be pretty solid, though, unless the system we're
|
|
||||||
// on is in a pretty weird place.
|
|
||||||
let mut network = Builtin::new();
|
|
||||||
|
|
||||||
let listener = async_std::task::block_on(network.listen("localhost", 0))
|
|
||||||
.expect("Couldn't set up listener on localhost");
|
|
||||||
let (listener_address, listener_port) = listener.local_addr();
|
|
||||||
|
|
||||||
let listener_task_handle = async_std::task::spawn(async move {
|
|
||||||
let (mut stream, addr, port) = listener.accept().await.expect("Didn't get connection");
|
|
||||||
let mut result_buffer = [0u8; 4];
|
|
||||||
stream
|
|
||||||
.read_exact(&mut result_buffer)
|
|
||||||
.await
|
|
||||||
.expect("Read failure in TCP test");
|
|
||||||
(result_buffer, addr, port)
|
|
||||||
});
|
|
||||||
|
|
||||||
let sender_task_handle = async_std::task::spawn(async move {
|
|
||||||
let mut sender = network
|
|
||||||
.connect(listener_address, listener_port)
|
|
||||||
.await
|
|
||||||
.expect("Coudln't connect to listener?");
|
|
||||||
let (sender_address, sender_port) = sender.local_addr();
|
|
||||||
let send_buffer = [0xa, 0xff, 0xab, 0x1e];
|
|
||||||
sender
|
|
||||||
.write_all(&send_buffer)
|
|
||||||
.await
|
|
||||||
.expect("Couldn't send the write buffer");
|
|
||||||
sender
|
|
||||||
.flush()
|
|
||||||
.await
|
|
||||||
.expect("Couldn't flush the write buffer");
|
|
||||||
sender
|
|
||||||
.close()
|
|
||||||
.await
|
|
||||||
.expect("Couldn't close the write buffer");
|
|
||||||
(sender_address, sender_port)
|
|
||||||
});
|
|
||||||
|
|
||||||
async_std::task::block_on(async {
|
|
||||||
let (result, result_from, result_from_port) = listener_task_handle.await;
|
|
||||||
assert_eq!(result, [0xa, 0xff, 0xab, 0x1e]);
|
|
||||||
let (sender_address, sender_port) = sender_task_handle.await;
|
|
||||||
assert_eq!(result_from, sender_address);
|
|
||||||
assert_eq!(result_from_port, sender_port);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
impl IntoErrorResponse for io::Error {
|
|
||||||
fn into_response(&self) -> ServerResponseStatus {
|
|
||||||
match self.kind() {
|
|
||||||
io::ErrorKind::ConnectionRefused => ServerResponseStatus::ConnectionRefused,
|
|
||||||
io::ErrorKind::NotFound => ServerResponseStatus::HostUnreachable,
|
|
||||||
_ => ServerResponseStatus::GeneralFailure,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,74 +0,0 @@
|
|||||||
use crate::network::SOCKSv5Address;
|
|
||||||
use async_std::task::{Context, Poll};
|
|
||||||
use futures::io;
|
|
||||||
use futures::io::{AsyncRead, AsyncWrite};
|
|
||||||
use std::pin::Pin;
|
|
||||||
use std::sync::{Arc, Mutex};
|
|
||||||
|
|
||||||
use super::address::HasLocalAddress;
|
|
||||||
|
|
||||||
pub trait Streamlike: AsyncRead + AsyncWrite + HasLocalAddress + Send + Sync + Unpin {}
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct GenericStream {
|
|
||||||
internal: Arc<Mutex<dyn Streamlike>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl GenericStream {
|
|
||||||
pub fn new<T: Streamlike + 'static>(x: T) -> GenericStream {
|
|
||||||
GenericStream {
|
|
||||||
internal: Arc::new(Mutex::new(x)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl HasLocalAddress for GenericStream {
|
|
||||||
fn local_addr(&self) -> (SOCKSv5Address, u16) {
|
|
||||||
let item = self.internal.lock().unwrap();
|
|
||||||
item.local_addr()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl AsyncRead for GenericStream {
|
|
||||||
fn poll_read(
|
|
||||||
self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>,
|
|
||||||
buf: &mut [u8],
|
|
||||||
) -> Poll<io::Result<usize>> {
|
|
||||||
let mut item = self.internal.lock().unwrap();
|
|
||||||
let pinned = Pin::new(&mut *item);
|
|
||||||
pinned.poll_read(cx, buf)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl AsyncWrite for GenericStream {
|
|
||||||
fn poll_write(
|
|
||||||
self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>,
|
|
||||||
buf: &[u8],
|
|
||||||
) -> Poll<io::Result<usize>> {
|
|
||||||
let mut item = self.internal.lock().unwrap();
|
|
||||||
let pinned = Pin::new(&mut *item);
|
|
||||||
pinned.poll_write(cx, buf)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
|
||||||
let mut item = self.internal.lock().unwrap();
|
|
||||||
let pinned = Pin::new(&mut *item);
|
|
||||||
pinned.poll_flush(cx)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
|
||||||
let mut item = self.internal.lock().unwrap();
|
|
||||||
let pinned = Pin::new(&mut *item);
|
|
||||||
pinned.poll_close(cx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: Streamlike + 'static> From<T> for GenericStream {
|
|
||||||
fn from(x: T) -> GenericStream {
|
|
||||||
GenericStream {
|
|
||||||
internal: Arc::new(Mutex::new(x)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,276 +0,0 @@
|
|||||||
mod datagram;
|
|
||||||
mod stream;
|
|
||||||
|
|
||||||
use crate::messages::ServerResponseStatus;
|
|
||||||
use crate::network::address::{HasLocalAddress, SOCKSv5Address};
|
|
||||||
#[cfg(test)]
|
|
||||||
use crate::network::datagram::Datagramlike;
|
|
||||||
use crate::network::datagram::GenericDatagramSocket;
|
|
||||||
use crate::network::generic::{IntoErrorResponse, Networklike};
|
|
||||||
use crate::network::listener::{GenericListener, Listenerlike};
|
|
||||||
use crate::network::stream::GenericStream;
|
|
||||||
use crate::network::testing::datagram::TestDatagram;
|
|
||||||
use crate::network::testing::stream::TestingStream;
|
|
||||||
use async_std::channel::{bounded, Receiver, Sender};
|
|
||||||
use async_std::sync::{Arc, Mutex};
|
|
||||||
#[cfg(test)]
|
|
||||||
use async_std::task;
|
|
||||||
use async_trait::async_trait;
|
|
||||||
#[cfg(test)]
|
|
||||||
use futures::{AsyncReadExt, AsyncWriteExt};
|
|
||||||
use std::collections::HashMap;
|
|
||||||
use std::fmt;
|
|
||||||
|
|
||||||
/// A "network", based purely on internal Rust datatypes, for testing
|
|
||||||
/// networking code. This stack operates purely in memory, so shouldn't
|
|
||||||
/// suffer from any weird networking effects ... which makes it a good
|
|
||||||
/// functional test, but not great at actually testing real-world failure
|
|
||||||
/// modes.
|
|
||||||
#[allow(clippy::type_complexity)]
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct TestingStack {
|
|
||||||
tcp_listeners: Arc<Mutex<HashMap<(SOCKSv5Address, u16), Sender<TestingStream>>>>,
|
|
||||||
udp_sockets: Arc<Mutex<HashMap<(SOCKSv5Address, u16), Sender<(SOCKSv5Address, u16, Vec<u8>)>>>>,
|
|
||||||
next_random_socket: u16,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TestingStack {
|
|
||||||
pub fn new() -> TestingStack {
|
|
||||||
TestingStack {
|
|
||||||
tcp_listeners: Arc::new(Mutex::new(HashMap::new())),
|
|
||||||
udp_sockets: Arc::new(Mutex::new(HashMap::new())),
|
|
||||||
next_random_socket: 23,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for TestingStack {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self::new()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub enum TestStackError {
|
|
||||||
AcceptFailed,
|
|
||||||
AddressBusy(SOCKSv5Address, u16),
|
|
||||||
ConnectionFailed,
|
|
||||||
FailureToSend,
|
|
||||||
NoTCPHostFound(SOCKSv5Address, u16),
|
|
||||||
ReceiveFailure,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl fmt::Display for TestStackError {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
||||||
match self {
|
|
||||||
TestStackError::AcceptFailed => write!(f, "Accept failed; the other side died (?)"),
|
|
||||||
TestStackError::AddressBusy(ref addr, port) => {
|
|
||||||
write!(f, "Address {}:{} already in use", addr, port)
|
|
||||||
}
|
|
||||||
TestStackError::ConnectionFailed => write!(f, "Couldn't connect to host."),
|
|
||||||
TestStackError::FailureToSend => write!(
|
|
||||||
f,
|
|
||||||
"Weird internal error in testing infrastructure; channel send failed"
|
|
||||||
),
|
|
||||||
TestStackError::NoTCPHostFound(ref addr, port) => {
|
|
||||||
write!(f, "No host found at {} for TCP port {}", addr, port)
|
|
||||||
}
|
|
||||||
TestStackError::ReceiveFailure => {
|
|
||||||
write!(f, "Failed to process a UDP receive (this is weird)")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl IntoErrorResponse for TestStackError {
|
|
||||||
fn into_response(&self) -> ServerResponseStatus {
|
|
||||||
ServerResponseStatus::GeneralFailure
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl Networklike for TestingStack {
|
|
||||||
type Error = TestStackError;
|
|
||||||
|
|
||||||
async fn connect<A: Send + Into<SOCKSv5Address>>(
|
|
||||||
&mut self,
|
|
||||||
addr: A,
|
|
||||||
port: u16,
|
|
||||||
) -> Result<GenericStream, Self::Error> {
|
|
||||||
let table = self.tcp_listeners.lock().await;
|
|
||||||
let target = addr.into();
|
|
||||||
|
|
||||||
match table.get(&(target.clone(), port)) {
|
|
||||||
None => Err(TestStackError::NoTCPHostFound(target, port)),
|
|
||||||
Some(result) => {
|
|
||||||
let stream = TestingStream::new(target, port);
|
|
||||||
let retval = stream.invert();
|
|
||||||
match result.send(stream).await {
|
|
||||||
Ok(()) => Ok(GenericStream::new(retval)),
|
|
||||||
Err(_) => Err(TestStackError::FailureToSend),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn listen<A: Send + Into<SOCKSv5Address>>(
|
|
||||||
&mut self,
|
|
||||||
addr: A,
|
|
||||||
mut port: u16,
|
|
||||||
) -> Result<GenericListener<Self::Error>, Self::Error> {
|
|
||||||
let mut table = self.tcp_listeners.lock().await;
|
|
||||||
let target = addr.into();
|
|
||||||
let (sender, receiver) = bounded(5);
|
|
||||||
|
|
||||||
if port == 0 {
|
|
||||||
port = self.next_random_socket;
|
|
||||||
self.next_random_socket += 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
table.insert((target.clone(), port), sender);
|
|
||||||
Ok(GenericListener {
|
|
||||||
internal: Box::new(TestListener::new(target, port, receiver)),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn bind<A: Send + Into<SOCKSv5Address>>(
|
|
||||||
&mut self,
|
|
||||||
addr: A,
|
|
||||||
mut port: u16,
|
|
||||||
) -> Result<GenericDatagramSocket<Self::Error>, Self::Error> {
|
|
||||||
let mut table = self.udp_sockets.lock().await;
|
|
||||||
let target = addr.into();
|
|
||||||
let (sender, receiver) = bounded(5);
|
|
||||||
|
|
||||||
if port == 0 {
|
|
||||||
port = self.next_random_socket;
|
|
||||||
self.next_random_socket += 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
table.insert((target.clone(), port), sender);
|
|
||||||
Ok(GenericDatagramSocket {
|
|
||||||
internal: Box::new(TestDatagram::new(self.clone(), target, port, receiver)),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct TestListener {
|
|
||||||
address: SOCKSv5Address,
|
|
||||||
port: u16,
|
|
||||||
receiver: Receiver<TestingStream>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TestListener {
|
|
||||||
fn new(address: SOCKSv5Address, port: u16, receiver: Receiver<TestingStream>) -> Self {
|
|
||||||
TestListener {
|
|
||||||
address,
|
|
||||||
port,
|
|
||||||
receiver,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl HasLocalAddress for TestListener {
|
|
||||||
fn local_addr(&self) -> (SOCKSv5Address, u16) {
|
|
||||||
(self.address.clone(), self.port)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl Listenerlike for TestListener {
|
|
||||||
type Error = TestStackError;
|
|
||||||
|
|
||||||
async fn accept(&self) -> Result<(GenericStream, SOCKSv5Address, u16), Self::Error> {
|
|
||||||
match self.receiver.recv().await {
|
|
||||||
Ok(next) => {
|
|
||||||
let (addr, port) = next.local_addr();
|
|
||||||
Ok((GenericStream::new(next), addr, port))
|
|
||||||
}
|
|
||||||
Err(_) => Err(TestStackError::AcceptFailed),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn check_udp_sanity() {
|
|
||||||
task::block_on(async {
|
|
||||||
let mut network = TestingStack::new();
|
|
||||||
let receiver = network
|
|
||||||
.bind("localhost", 0)
|
|
||||||
.await
|
|
||||||
.expect("Failed to bind receiver socket.");
|
|
||||||
let sender = network
|
|
||||||
.bind("localhost", 0)
|
|
||||||
.await
|
|
||||||
.expect("Failed to bind sender socket.");
|
|
||||||
let buffer = [0xde, 0xea, 0xbe, 0xef];
|
|
||||||
let (receiver_addr, receiver_port) = receiver.local_addr();
|
|
||||||
sender
|
|
||||||
.send_to(&buffer, receiver_addr, receiver_port)
|
|
||||||
.await
|
|
||||||
.expect("Failure sending datagram!");
|
|
||||||
let mut recvbuffer = [0; 4];
|
|
||||||
let (s, f, p) = receiver
|
|
||||||
.recv_from(&mut recvbuffer)
|
|
||||||
.await
|
|
||||||
.expect("Didn't receive UDP message?");
|
|
||||||
let (sender_addr, sender_port) = sender.local_addr();
|
|
||||||
assert_eq!(s, 4);
|
|
||||||
assert_eq!(f, sender_addr);
|
|
||||||
assert_eq!(p, sender_port);
|
|
||||||
assert_eq!(recvbuffer, buffer);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn check_basic_tcp() {
|
|
||||||
task::block_on(async {
|
|
||||||
let mut network = TestingStack::new();
|
|
||||||
|
|
||||||
let listener = network
|
|
||||||
.listen("localhost", 0)
|
|
||||||
.await
|
|
||||||
.expect("Couldn't set up listener on localhost");
|
|
||||||
let (listener_address, listener_port) = listener.local_addr();
|
|
||||||
|
|
||||||
let listener_task_handle = task::spawn(async move {
|
|
||||||
dbg!("Starting listener task!!");
|
|
||||||
let (mut stream, addr, port) = listener.accept().await.expect("Didn't get connection");
|
|
||||||
let mut result_buffer = [0u8; 4];
|
|
||||||
if let Err(e) = stream.read_exact(&mut result_buffer).await {
|
|
||||||
dbg!("Error reading buffer from stream: {}", e);
|
|
||||||
} else {
|
|
||||||
dbg!("made it through read_exact");
|
|
||||||
}
|
|
||||||
(result_buffer, addr, port)
|
|
||||||
});
|
|
||||||
|
|
||||||
let sender_task_handle = task::spawn(async move {
|
|
||||||
let mut sender = network
|
|
||||||
.connect(listener_address, listener_port)
|
|
||||||
.await
|
|
||||||
.expect("Coudln't connect to listener?");
|
|
||||||
let (sender_address, sender_port) = sender.local_addr();
|
|
||||||
let send_buffer = [0xa, 0xff, 0xab, 0x1e];
|
|
||||||
sender
|
|
||||||
.write_all(&send_buffer)
|
|
||||||
.await
|
|
||||||
.expect("Couldn't send the write buffer");
|
|
||||||
sender
|
|
||||||
.flush()
|
|
||||||
.await
|
|
||||||
.expect("Couldn't flush the write buffer");
|
|
||||||
sender
|
|
||||||
.close()
|
|
||||||
.await
|
|
||||||
.expect("Couldn't close the write buffer");
|
|
||||||
(sender_address, sender_port)
|
|
||||||
});
|
|
||||||
|
|
||||||
let (result, result_from, result_from_port) = listener_task_handle.await;
|
|
||||||
assert_eq!(result, [0xa, 0xff, 0xab, 0x1e]);
|
|
||||||
let (sender_address, sender_port) = sender_task_handle.await;
|
|
||||||
assert_eq!(result_from, sender_address);
|
|
||||||
assert_eq!(result_from_port, sender_port);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
@@ -1,88 +0,0 @@
|
|||||||
use crate::network::address::HasLocalAddress;
|
|
||||||
use crate::network::datagram::Datagramlike;
|
|
||||||
use crate::network::testing::{TestStackError, TestingStack};
|
|
||||||
use crate::network::SOCKSv5Address;
|
|
||||||
use async_std::channel::Receiver;
|
|
||||||
use async_trait::async_trait;
|
|
||||||
use std::cmp::Ordering;
|
|
||||||
|
|
||||||
pub struct TestDatagram {
|
|
||||||
context: TestingStack,
|
|
||||||
my_address: SOCKSv5Address,
|
|
||||||
my_port: u16,
|
|
||||||
input_stream: Receiver<(SOCKSv5Address, u16, Vec<u8>)>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TestDatagram {
|
|
||||||
pub fn new(
|
|
||||||
context: TestingStack,
|
|
||||||
my_address: SOCKSv5Address,
|
|
||||||
my_port: u16,
|
|
||||||
input_stream: Receiver<(SOCKSv5Address, u16, Vec<u8>)>,
|
|
||||||
) -> Self {
|
|
||||||
TestDatagram {
|
|
||||||
context,
|
|
||||||
my_address,
|
|
||||||
my_port,
|
|
||||||
input_stream,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl HasLocalAddress for TestDatagram {
|
|
||||||
fn local_addr(&self) -> (SOCKSv5Address, u16) {
|
|
||||||
(self.my_address.clone(), self.my_port)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl Datagramlike for TestDatagram {
|
|
||||||
type Error = TestStackError;
|
|
||||||
|
|
||||||
async fn send_to(
|
|
||||||
&self,
|
|
||||||
buf: &[u8],
|
|
||||||
target: SOCKSv5Address,
|
|
||||||
port: u16,
|
|
||||||
) -> Result<usize, Self::Error> {
|
|
||||||
let table = self.context.udp_sockets.lock().await;
|
|
||||||
match table.get(&(target, port)) {
|
|
||||||
None => Ok(buf.len()),
|
|
||||||
Some(sender) => {
|
|
||||||
sender
|
|
||||||
.send((self.my_address.clone(), self.my_port, buf.to_vec()))
|
|
||||||
.await
|
|
||||||
.map_err(|_| TestStackError::FailureToSend)?;
|
|
||||||
Ok(buf.len())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn recv_from(
|
|
||||||
&self,
|
|
||||||
buffer: &mut [u8],
|
|
||||||
) -> Result<(usize, SOCKSv5Address, u16), Self::Error> {
|
|
||||||
let (from_addr, from_port, message) = self
|
|
||||||
.input_stream
|
|
||||||
.recv()
|
|
||||||
.await
|
|
||||||
.map_err(|_| TestStackError::ReceiveFailure)?;
|
|
||||||
|
|
||||||
match message.len().cmp(&buffer.len()) {
|
|
||||||
Ordering::Greater => {
|
|
||||||
buffer.copy_from_slice(&message[..buffer.len()]);
|
|
||||||
Ok((message.len(), from_addr, from_port))
|
|
||||||
}
|
|
||||||
|
|
||||||
Ordering::Less => {
|
|
||||||
(&mut buffer[..message.len()]).copy_from_slice(&message);
|
|
||||||
Ok((message.len(), from_addr, from_port))
|
|
||||||
}
|
|
||||||
|
|
||||||
Ordering::Equal => {
|
|
||||||
buffer.copy_from_slice(message.as_ref());
|
|
||||||
Ok((message.len(), from_addr, from_port))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,192 +0,0 @@
|
|||||||
use crate::network::address::HasLocalAddress;
|
|
||||||
use crate::network::stream::Streamlike;
|
|
||||||
use crate::network::SOCKSv5Address;
|
|
||||||
use async_std::io;
|
|
||||||
use async_std::io::{Read, Write};
|
|
||||||
use async_std::task::{Context, Poll, Waker};
|
|
||||||
use std::cell::UnsafeCell;
|
|
||||||
use std::pin::Pin;
|
|
||||||
use std::ptr::NonNull;
|
|
||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct TestingStream {
|
|
||||||
address: SOCKSv5Address,
|
|
||||||
port: u16,
|
|
||||||
read_side: NonNull<TestingStreamData>,
|
|
||||||
write_side: NonNull<TestingStreamData>,
|
|
||||||
}
|
|
||||||
|
|
||||||
unsafe impl Send for TestingStream {}
|
|
||||||
unsafe impl Sync for TestingStream {}
|
|
||||||
|
|
||||||
struct TestingStreamData {
|
|
||||||
lock: AtomicBool,
|
|
||||||
writer_dead: AtomicBool,
|
|
||||||
waiters: UnsafeCell<Vec<Waker>>,
|
|
||||||
buffer: UnsafeCell<Vec<u8>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
unsafe impl Send for TestingStreamData {}
|
|
||||||
unsafe impl Sync for TestingStreamData {}
|
|
||||||
|
|
||||||
impl TestingStream {
|
|
||||||
/// Generate a testing stream. Note that this is directional. So, if you want to
|
|
||||||
/// talk to this stream, you should also generate an `invert()` and pass that to
|
|
||||||
/// the other thread(s).
|
|
||||||
pub fn new(address: SOCKSv5Address, port: u16) -> TestingStream {
|
|
||||||
let read_side_data = TestingStreamData {
|
|
||||||
lock: AtomicBool::new(false),
|
|
||||||
writer_dead: AtomicBool::new(false),
|
|
||||||
waiters: UnsafeCell::new(Vec::new()),
|
|
||||||
buffer: UnsafeCell::new(Vec::with_capacity(16 * 1024)),
|
|
||||||
};
|
|
||||||
|
|
||||||
let write_side_data = TestingStreamData {
|
|
||||||
lock: AtomicBool::new(false),
|
|
||||||
writer_dead: AtomicBool::new(false),
|
|
||||||
waiters: UnsafeCell::new(Vec::new()),
|
|
||||||
buffer: UnsafeCell::new(Vec::with_capacity(16 * 1024)),
|
|
||||||
};
|
|
||||||
|
|
||||||
let boxed_rsd = Box::new(read_side_data);
|
|
||||||
let boxed_wsd = Box::new(write_side_data);
|
|
||||||
let raw_read_ptr = Box::leak(boxed_rsd);
|
|
||||||
let raw_write_ptr = Box::leak(boxed_wsd);
|
|
||||||
|
|
||||||
TestingStream {
|
|
||||||
address,
|
|
||||||
port,
|
|
||||||
read_side: NonNull::new(raw_read_ptr).unwrap(),
|
|
||||||
write_side: NonNull::new(raw_write_ptr).unwrap(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the flip side of this stream; reads from the inverted side will catch the writes
|
|
||||||
/// of the original, etc.
|
|
||||||
pub fn invert(&self) -> TestingStream {
|
|
||||||
TestingStream {
|
|
||||||
address: self.address.clone(),
|
|
||||||
port: self.port,
|
|
||||||
read_side: self.write_side,
|
|
||||||
write_side: self.read_side,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TestingStreamData {
|
|
||||||
fn acquire(&mut self) {
|
|
||||||
loop {
|
|
||||||
match self
|
|
||||||
.lock
|
|
||||||
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
|
|
||||||
{
|
|
||||||
Err(_) => continue,
|
|
||||||
Ok(_) => return,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn release(&mut self) {
|
|
||||||
self.lock.store(false, Ordering::SeqCst);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl HasLocalAddress for TestingStream {
|
|
||||||
fn local_addr(&self) -> (SOCKSv5Address, u16) {
|
|
||||||
(self.address.clone(), self.port)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Read for TestingStream {
|
|
||||||
fn poll_read(
|
|
||||||
mut self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>,
|
|
||||||
buf: &mut [u8],
|
|
||||||
) -> Poll<std::io::Result<usize>> {
|
|
||||||
// so, we're going to spin here, which is less than ideal but should work fine
|
|
||||||
// in practice. we'll obviously need to be very careful to ensure that we keep
|
|
||||||
// the stuff internal to this spin really short.
|
|
||||||
let internals = unsafe { self.read_side.as_mut() };
|
|
||||||
|
|
||||||
internals.acquire();
|
|
||||||
|
|
||||||
let stream_buffer = internals.buffer.get_mut();
|
|
||||||
let amount_available = stream_buffer.len();
|
|
||||||
|
|
||||||
if amount_available == 0 {
|
|
||||||
// we wait to do this check until we've determined the buffer is empty,
|
|
||||||
// so that we make sure to drain any residual stuff in there.
|
|
||||||
if internals.writer_dead.load(Ordering::SeqCst) {
|
|
||||||
internals.release();
|
|
||||||
return Poll::Ready(Err(io::Error::new(
|
|
||||||
io::ErrorKind::ConnectionReset,
|
|
||||||
"Writer closed the socket.",
|
|
||||||
)));
|
|
||||||
} else {
|
|
||||||
let waker = cx.waker().clone();
|
|
||||||
internals.waiters.get_mut().push(waker);
|
|
||||||
internals.release();
|
|
||||||
return Poll::Pending;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let amt_written = if buf.len() >= amount_available {
|
|
||||||
(&mut buf[0..amount_available]).copy_from_slice(stream_buffer);
|
|
||||||
stream_buffer.clear();
|
|
||||||
amount_available
|
|
||||||
} else {
|
|
||||||
let amt_to_copy = buf.len();
|
|
||||||
buf.copy_from_slice(&stream_buffer[0..amt_to_copy]);
|
|
||||||
stream_buffer.copy_within(amt_to_copy.., 0);
|
|
||||||
let amt_left = amount_available - amt_to_copy;
|
|
||||||
stream_buffer.resize(amt_left, 0);
|
|
||||||
amt_to_copy
|
|
||||||
};
|
|
||||||
|
|
||||||
internals.release();
|
|
||||||
|
|
||||||
Poll::Ready(Ok(amt_written))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Write for TestingStream {
|
|
||||||
fn poll_write(
|
|
||||||
mut self: Pin<&mut Self>,
|
|
||||||
_cx: &mut Context<'_>,
|
|
||||||
buf: &[u8],
|
|
||||||
) -> Poll<io::Result<usize>> {
|
|
||||||
let internals = unsafe { self.write_side.as_mut() };
|
|
||||||
internals.acquire();
|
|
||||||
let stream_buffer = internals.buffer.get_mut();
|
|
||||||
stream_buffer.extend_from_slice(buf);
|
|
||||||
for waiter in internals.waiters.get_mut().drain(0..) {
|
|
||||||
waiter.wake();
|
|
||||||
}
|
|
||||||
internals.release();
|
|
||||||
|
|
||||||
Poll::Ready(Ok(buf.len()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
|
||||||
Poll::Ready(Ok(())) // FIXME: Might consider having this wait until the buffer is empty
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
|
||||||
Poll::Ready(Ok(())) // FIXME: Might consider putting in some open/closed logic here
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Streamlike for TestingStream {}
|
|
||||||
|
|
||||||
impl Drop for TestingStream {
|
|
||||||
fn drop(&mut self) {
|
|
||||||
let internals = unsafe { self.write_side.as_mut() };
|
|
||||||
internals.writer_dead.store(true, Ordering::SeqCst);
|
|
||||||
internals.acquire();
|
|
||||||
for waiter in internals.waiters.get_mut().drain(0..) {
|
|
||||||
waiter.wake();
|
|
||||||
}
|
|
||||||
internals.release();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
84
src/security_parameters.rs
Normal file
84
src/security_parameters.rs
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
use std::net::SocketAddr;
|
||||||
|
|
||||||
|
/// The security parameters that you can assign to the server, to make decisions
|
||||||
|
/// about the weirdos it accepts as users. It is recommended that you only use
|
||||||
|
/// wide open connections when you're 100% sure that the server will only be
|
||||||
|
/// accessible locally.
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct SecurityParameters {
|
||||||
|
/// Allow completely unauthenticated connections. You should be very, very
|
||||||
|
/// careful about setting this to true, especially if you don't provide a
|
||||||
|
/// guard to ensure that you're getting connections from reasonable places.
|
||||||
|
pub allow_unauthenticated: bool,
|
||||||
|
/// An optional function that can serve as a firewall for new connections.
|
||||||
|
/// Return true if the connection should be allowed to continue, false if
|
||||||
|
/// it shouldn't. This check happens before any data is read from or written
|
||||||
|
/// to the connecting party.
|
||||||
|
pub allow_connection: Option<fn(&SocketAddr) -> bool>,
|
||||||
|
/// An optional function to check a user name (first argument) and password
|
||||||
|
/// (second argument). Return true if the username / password is good, false
|
||||||
|
/// if not.
|
||||||
|
pub check_password: Option<fn(&str, &str) -> bool>,
|
||||||
|
/// An optional function to transition the stream from an unencrypted one to
|
||||||
|
/// an encrypted on. The assumption is you're using something like `rustls`
|
||||||
|
/// to make this happen; the exact mechanism is outside the scope of this
|
||||||
|
/// particular crate. If the connection shouldn't be allowed for some reason
|
||||||
|
/// (a bad certificate or handshake, for example), then return None; otherwise,
|
||||||
|
/// return the new stream.
|
||||||
|
pub connect_tls: Option<fn() -> Option<()>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SecurityParameters {
|
||||||
|
/// Generates a `SecurityParameters` object that's empty. It won't accept
|
||||||
|
/// anything, because it has no mechanisms it can use to actually authenticate
|
||||||
|
/// a user and yet won't allow unauthenticated connections.
|
||||||
|
pub fn new() -> SecurityParameters {
|
||||||
|
SecurityParameters {
|
||||||
|
allow_unauthenticated: false,
|
||||||
|
allow_connection: None,
|
||||||
|
check_password: None,
|
||||||
|
connect_tls: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates a `SecurityParameters` object that does not, in any way,
|
||||||
|
/// restrict who can log in. It also will not induce any transition into
|
||||||
|
/// TLS. Use this at your own risk ... or, really, just don't use this,
|
||||||
|
/// ever, and certainly not in production.
|
||||||
|
pub fn unrestricted() -> SecurityParameters {
|
||||||
|
SecurityParameters {
|
||||||
|
allow_unauthenticated: true,
|
||||||
|
allow_connection: None,
|
||||||
|
check_password: None,
|
||||||
|
connect_tls: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Use the provided function to check incoming connections before proceeding
|
||||||
|
/// with the rest of the handshake.
|
||||||
|
pub fn check_connections(mut self, checker: fn(&SocketAddr) -> bool) -> SecurityParameters {
|
||||||
|
self.allow_connection = Some(checker);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Use the provided function to check usernames and passwords provided
|
||||||
|
/// to the server.
|
||||||
|
pub fn password_check(mut self, checker: fn(&str, &str) -> bool) -> SecurityParameters {
|
||||||
|
self.check_password = Some(checker);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Use the provide function to validate a TLS connection, and transition it
|
||||||
|
/// to the new stream type. If the handshake fails, return `None` instead of
|
||||||
|
/// `Some`. (And maybe log it somewhere, you know.)
|
||||||
|
pub fn tls_converter(mut self, converter: fn() -> Option<()>) -> SecurityParameters {
|
||||||
|
self.connect_tls = Some(converter);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for SecurityParameters {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,59 +0,0 @@
|
|||||||
use crate::errors::{DeserializationError, SerializationError};
|
|
||||||
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
|
||||||
|
|
||||||
pub async fn read_string<R: AsyncRead + Send + Unpin>(
|
|
||||||
r: &mut R,
|
|
||||||
) -> Result<String, DeserializationError> {
|
|
||||||
let mut length_buffer = [0; 1];
|
|
||||||
|
|
||||||
if r.read(&mut length_buffer).await? == 0 {
|
|
||||||
return Err(DeserializationError::NotEnoughData);
|
|
||||||
}
|
|
||||||
|
|
||||||
let target = length_buffer[0] as usize;
|
|
||||||
|
|
||||||
if target == 0 {
|
|
||||||
return Err(DeserializationError::InvalidEmptyString);
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut bytestring = vec![0; target];
|
|
||||||
read_amt(r, target, &mut bytestring).await?;
|
|
||||||
|
|
||||||
Ok(String::from_utf8(bytestring)?)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn write_string<W: AsyncWrite + Send + Unpin>(
|
|
||||||
s: &str,
|
|
||||||
w: &mut W,
|
|
||||||
) -> Result<(), SerializationError> {
|
|
||||||
let bytestring = s.as_bytes();
|
|
||||||
|
|
||||||
if bytestring.is_empty() || bytestring.len() > 255 {
|
|
||||||
return Err(SerializationError::InvalidStringLength(s.to_string()));
|
|
||||||
}
|
|
||||||
|
|
||||||
w.write_all(&[bytestring.len() as u8]).await?;
|
|
||||||
w.write_all(bytestring)
|
|
||||||
.await
|
|
||||||
.map_err(SerializationError::IOError)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn read_amt<R: AsyncRead + Send + Unpin>(
|
|
||||||
r: &mut R,
|
|
||||||
amt: usize,
|
|
||||||
buffer: &mut [u8],
|
|
||||||
) -> Result<(), DeserializationError> {
|
|
||||||
let mut amt_read = 0;
|
|
||||||
|
|
||||||
while amt_read < amt {
|
|
||||||
let chunk_amt = r.read(&mut buffer[amt_read..]).await?;
|
|
||||||
|
|
||||||
if chunk_amt == 0 {
|
|
||||||
return Err(DeserializationError::NotEnoughData);
|
|
||||||
}
|
|
||||||
|
|
||||||
amt_read += chunk_amt;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
744
src/server.rs
744
src/server.rs
@@ -1,157 +1,59 @@
|
|||||||
//! An implementation of a SOCKSv5 server, parameterizable by the security parameters
|
use std::net::SocketAddr;
|
||||||
//! and network stack you want to use. You should implement the server by first
|
|
||||||
//! setting up the `SecurityParameters`, then initializing the server object, and
|
use crate::address::SOCKSv5Address;
|
||||||
//! then running it, as follows:
|
|
||||||
//!
|
|
||||||
//! ```
|
|
||||||
//! use async_socks5::network::Builtin;
|
|
||||||
//! use async_socks5::server::{SecurityParameters, SOCKSv5Server};
|
|
||||||
//! use std::io;
|
|
||||||
//!
|
|
||||||
//! async {
|
|
||||||
//! let parameters = SecurityParameters::new()
|
|
||||||
//! .password_check(|u,p| { u == "adam" && p == "evil" });
|
|
||||||
//! let network = Builtin::new();
|
|
||||||
//! let server = SOCKSv5Server::new(network, parameters);
|
|
||||||
//! server.start("localhost", 9999).await;
|
|
||||||
//! // ... do other stuff ...
|
|
||||||
//! };
|
|
||||||
//!
|
|
||||||
//! ```
|
|
||||||
use crate::errors::{AuthenticationError, DeserializationError, SerializationError};
|
|
||||||
use crate::messages::{
|
use crate::messages::{
|
||||||
AuthenticationMethod, ClientConnectionCommand, ClientConnectionRequest, ClientGreeting,
|
AuthenticationMethod, ClientConnectionCommand, ClientConnectionCommandReadError,
|
||||||
ClientUsernamePassword, ServerAuthResponse, ServerChoice, ServerResponse, ServerResponseStatus,
|
ClientConnectionRequest, ClientConnectionRequestReadError, ClientGreeting,
|
||||||
|
ClientGreetingReadError, ClientUsernamePassword, ClientUsernamePasswordReadError,
|
||||||
|
ServerAuthResponse, ServerAuthResponseWriteError, ServerChoice, ServerChoiceWriteError,
|
||||||
|
ServerResponse, ServerResponseStatus, ServerResponseWriteError,
|
||||||
};
|
};
|
||||||
use crate::network::address::HasLocalAddress;
|
use crate::security_parameters::SecurityParameters;
|
||||||
use crate::network::generic::Networklike;
|
|
||||||
use crate::network::listener::{GenericListener, Listenerlike};
|
|
||||||
use crate::network::stream::GenericStream;
|
|
||||||
use crate::network::SOCKSv5Address;
|
|
||||||
use async_std::io;
|
|
||||||
use async_std::io::prelude::WriteExt;
|
|
||||||
use async_std::sync::{Arc, Mutex};
|
|
||||||
use async_std::task;
|
|
||||||
use futures::Stream;
|
|
||||||
use log::{error, info, trace, warn};
|
|
||||||
use std::collections::HashMap;
|
|
||||||
use std::default::Default;
|
|
||||||
use std::fmt::{Debug, Display};
|
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
use tokio::io::{copy_bidirectional, AsyncWriteExt};
|
||||||
|
use tokio::net::{TcpListener, TcpSocket, TcpStream, UdpSocket};
|
||||||
|
|
||||||
/// A convenient bit of shorthand for an address and port
|
|
||||||
pub type AddressAndPort = (SOCKSv5Address, u16);
|
|
||||||
|
|
||||||
// Just some shorthand for us.
|
|
||||||
type ResultHandle = task::JoinHandle<Result<(), String>>;
|
|
||||||
|
|
||||||
/// A handle representing a SOCKSv5 server, parameterized by the underlying network
|
|
||||||
/// stack it runs over.
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct SOCKSv5Server<N: Networklike> {
|
pub struct SOCKSv5Server {
|
||||||
network: Arc<Mutex<N>>,
|
|
||||||
running_servers: Arc<Mutex<HashMap<AddressAndPort, ResultHandle>>>,
|
|
||||||
security_parameters: SecurityParameters,
|
security_parameters: SecurityParameters,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The security parameters that you can assign to the server, to make decisions
|
#[derive(Clone, Debug, Error, PartialEq)]
|
||||||
/// about the weirdos it accepts as users. It is recommended that you only use
|
pub enum SOCKSv5ServerError {
|
||||||
/// wide open connections when you're 100% sure that the server will only be
|
#[error("Underlying networking error: {0}")]
|
||||||
/// accessible locally.
|
NetworkingError(String),
|
||||||
#[derive(Clone)]
|
#[error("Couldn't negotiate authentication with client.")]
|
||||||
pub struct SecurityParameters {
|
ItsNotUsItsYou,
|
||||||
/// Allow completely unauthenticated connections. You should be very, very
|
#[error("Client greeting read problem: {0}")]
|
||||||
/// careful about setting this to true, especially if you don't provide a
|
GreetingReadProblem(#[from] ClientGreetingReadError),
|
||||||
/// guard to ensure that you're getting connections from reasonable places.
|
#[error("Server choice write problem: {0}")]
|
||||||
pub allow_unauthenticated: bool,
|
ChoiceWriteProblem(#[from] ServerChoiceWriteError),
|
||||||
/// An optional function that can serve as a firewall for new connections.
|
#[error("Failed username/password authentication for user {0}")]
|
||||||
/// Return true if the connection should be allowed to continue, false if
|
FailedUsernamePassword(String),
|
||||||
/// it shouldn't. This check happens before any data is read from or written
|
#[error("Server authentication response problem: {0}")]
|
||||||
/// to the connecting party.
|
ServerAuthWriteProblem(#[from] ServerAuthResponseWriteError),
|
||||||
pub allow_connection: Option<fn(&SOCKSv5Address, u16) -> bool>,
|
#[error("Error reading client username/password: {0}")]
|
||||||
/// An optional function to check a user name (first argument) and password
|
UserPassReadProblem(#[from] ClientUsernamePasswordReadError),
|
||||||
/// (second argument). Return true if the username / password is good, false
|
#[error("Error reading client connection command: {0}")]
|
||||||
/// if not.
|
ClientConnReadProblem(#[from] ClientConnectionCommandReadError),
|
||||||
pub check_password: Option<fn(&str, &str) -> bool>,
|
#[error("Error reading client connection request: {0}")]
|
||||||
/// An optional function to transition the stream from an unencrypted one to
|
ClientRequestReadProblem(#[from] ClientConnectionRequestReadError),
|
||||||
/// an encrypted on. The assumption is you're using something like `rustls`
|
#[error("Error writing server response: {0}")]
|
||||||
/// to make this happen; the exact mechanism is outside the scope of this
|
ServerResponseWriteProblem(#[from] ServerResponseWriteError),
|
||||||
/// particular crate. If the connection shouldn't be allowed for some reason
|
|
||||||
/// (a bad certificate or handshake, for example), then return None; otherwise,
|
|
||||||
/// return the new stream.
|
|
||||||
pub connect_tls: Option<fn(GenericStream) -> Option<GenericStream>>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SecurityParameters {
|
impl From<std::io::Error> for SOCKSv5ServerError {
|
||||||
/// Generates a `SecurityParameters` object that's empty. It won't accept
|
fn from(x: std::io::Error) -> SOCKSv5ServerError {
|
||||||
/// anything, because it has no mechanisms it can use to actually authenticate
|
SOCKSv5ServerError::NetworkingError(format!("{}", x))
|
||||||
/// a user and yet won't allow unauthenticated connections.
|
|
||||||
pub fn new() -> SecurityParameters {
|
|
||||||
SecurityParameters {
|
|
||||||
allow_unauthenticated: false,
|
|
||||||
allow_connection: None,
|
|
||||||
check_password: None,
|
|
||||||
connect_tls: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates a `SecurityParameters` object that does not, in any way,
|
|
||||||
/// restrict who can log in. It also will not induce any transition into
|
|
||||||
/// TLS. Use this at your own risk ... or, really, just don't use this,
|
|
||||||
/// ever, and certainly not in production.
|
|
||||||
pub fn unrestricted() -> SecurityParameters {
|
|
||||||
SecurityParameters {
|
|
||||||
allow_unauthenticated: true,
|
|
||||||
allow_connection: None,
|
|
||||||
check_password: None,
|
|
||||||
connect_tls: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Use the provided function to check incoming connections before proceeding
|
|
||||||
/// with the rest of the handshake.
|
|
||||||
pub fn check_connections(
|
|
||||||
mut self,
|
|
||||||
checker: fn(&SOCKSv5Address, u16) -> bool,
|
|
||||||
) -> SecurityParameters {
|
|
||||||
self.allow_connection = Some(checker);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Use the provided function to check usernames and passwords provided
|
|
||||||
/// to the server.
|
|
||||||
pub fn password_check(mut self, checker: fn(&str, &str) -> bool) -> SecurityParameters {
|
|
||||||
self.check_password = Some(checker);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Use the provide function to validate a TLS connection, and transition it
|
|
||||||
/// to the new stream type. If the handshake fails, return `None` instead of
|
|
||||||
/// `Some`. (And maybe log it somewhere, you know.)
|
|
||||||
pub fn tls_converter(
|
|
||||||
mut self,
|
|
||||||
converter: fn(GenericStream) -> Option<GenericStream>,
|
|
||||||
) -> SecurityParameters {
|
|
||||||
self.connect_tls = Some(converter);
|
|
||||||
self
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for SecurityParameters {
|
impl SOCKSv5Server {
|
||||||
fn default() -> Self {
|
/// Initialize a SOCKSv5 server for use later on. Once initialized, you can listen
|
||||||
Self::new()
|
/// on as many addresses and ports as you like; the metadata about the server will
|
||||||
}
|
/// be synced across all the instances.
|
||||||
}
|
pub fn new(security_parameters: SecurityParameters) -> Self {
|
||||||
|
|
||||||
impl<N: Networklike + Clone + Send + 'static> SOCKSv5Server<N> {
|
|
||||||
/// Initialize a SOCKSv5 server for use later on. Once initialize, you can listen on
|
|
||||||
/// as many addresses and ports as you like; the metadata about the server will be
|
|
||||||
/// sync'd across all of the instances, should you want to gather that data for some
|
|
||||||
/// reason.
|
|
||||||
pub fn new(network: N, security_parameters: SecurityParameters) -> SOCKSv5Server<N> {
|
|
||||||
SOCKSv5Server {
|
SOCKSv5Server {
|
||||||
network: Arc::new(Mutex::new(network)),
|
|
||||||
running_servers: Arc::new(Mutex::new(HashMap::new())),
|
|
||||||
security_parameters,
|
security_parameters,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -159,191 +61,241 @@ impl<N: Networklike + Clone + Send + 'static> SOCKSv5Server<N> {
|
|||||||
/// Start a server on the given address and port. This function returns when it has
|
/// Start a server on the given address and port. This function returns when it has
|
||||||
/// set up its listening socket, but spawns a separate task to actually wait for
|
/// set up its listening socket, but spawns a separate task to actually wait for
|
||||||
/// connections. You can query which ones are still active, or see which ones have
|
/// connections. You can query which ones are still active, or see which ones have
|
||||||
/// failed, using some of the other items in this structure.
|
/// failed, using some of the other functions for this structure.
|
||||||
|
///
|
||||||
|
/// If you don't care what port is assigned to this server, pass 0 in as the port
|
||||||
|
/// number and one will be chosen for you by the OS.
|
||||||
|
///
|
||||||
pub async fn start<A: Send + Into<SOCKSv5Address>>(
|
pub async fn start<A: Send + Into<SOCKSv5Address>>(
|
||||||
&self,
|
&self,
|
||||||
addr: A,
|
addr: A,
|
||||||
port: u16,
|
port: u16,
|
||||||
) -> Result<(), N::Error> {
|
) -> Result<(), std::io::Error> {
|
||||||
// This might seem a little weird, but we do this in a separate block to make it
|
let listener = match addr.into() {
|
||||||
// as clear as possible to the borrow checker (and the reader) that we only want
|
SOCKSv5Address::IP4(x) => TcpListener::bind((x, port)).await?,
|
||||||
// to hold the lock while we're actually calling listen.
|
SOCKSv5Address::IP6(x) => TcpListener::bind((x, port)).await?,
|
||||||
let listener = {
|
SOCKSv5Address::Hostname(x) => TcpListener::bind((x, port)).await?,
|
||||||
let mut network = self.network.lock().await;
|
};
|
||||||
network.listen(addr, port).await
|
|
||||||
}?;
|
|
||||||
|
|
||||||
// this should really be the same as the input, but technically they could've
|
let sockaddr = listener.local_addr()?;
|
||||||
// thrown some zeros in there and let the underlying network stack decide. So
|
tracing::info!(
|
||||||
// we'll just pull this information post-initialization, and maybe get something
|
"Starting SOCKSv5 server on {}:{}",
|
||||||
// a bit more detailed.
|
sockaddr.ip(),
|
||||||
let (my_addr, my_port) = listener.local_addr();
|
sockaddr.port()
|
||||||
info!("Starting SOCKSv5 server on {}:{}", my_addr, my_port);
|
);
|
||||||
|
|
||||||
// OK, spawn off the server loop, and then we'll register this in our list of
|
let second_life = self.clone();
|
||||||
// things running.
|
|
||||||
let new_self = self.clone();
|
tokio::task::spawn(async move {
|
||||||
let task_id = task::spawn(async move {
|
if let Err(e) = second_life.server_loop(listener).await {
|
||||||
new_self
|
tracing::error!(
|
||||||
.server_loop(listener)
|
"{}:{}: server network error: {}",
|
||||||
.await
|
sockaddr.ip(),
|
||||||
.map_err(|x| format!("Server network error: {}", x))
|
sockaddr.port(),
|
||||||
|
e
|
||||||
|
);
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
let mut server_map = self.running_servers.lock().await;
|
|
||||||
server_map.insert((my_addr, my_port), task_id);
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Provide a list of open sockets on the server.
|
/// Run the server loop for a particular listener. This routine will never actually
|
||||||
pub async fn open_sockets(&self) -> Vec<AddressAndPort> {
|
/// return except in error conditions.
|
||||||
let server_map = self.running_servers.lock().await;
|
async fn server_loop(self, listener: TcpListener) -> Result<(), std::io::Error> {
|
||||||
server_map.keys().cloned().collect()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn subserver_results(&mut self) -> impl Stream<Item = Result<(), String>> {
|
|
||||||
futures::stream::unfold(self.running_servers.clone(), |locked_map| async move {
|
|
||||||
let first_server = {
|
|
||||||
let mut server_map = locked_map.lock().await;
|
|
||||||
let first_key = server_map.keys().next().cloned()?;
|
|
||||||
|
|
||||||
server_map.remove(&first_key)
|
|
||||||
}?;
|
|
||||||
|
|
||||||
let result = first_server.await;
|
|
||||||
Some((result, locked_map))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn server_loop(self, listener: GenericListener<N::Error>) -> Result<(), N::Error> {
|
|
||||||
loop {
|
loop {
|
||||||
let (stream, their_addr, their_port) = listener.accept().await?;
|
let (socket, their_addr) = listener.accept().await?;
|
||||||
trace!(
|
|
||||||
"Initial accept of connection from {}:{}",
|
|
||||||
their_addr,
|
|
||||||
their_port
|
|
||||||
);
|
|
||||||
|
|
||||||
// before we do anything, make sure this connection is cool. we don't want to
|
// before we do anything of note, make sure this connection is cool. we don't want
|
||||||
// waste resources (or parse any data) if this isn't someone we actually care
|
// to waste any resources (and certainly don't want to handle any data!) if this
|
||||||
// about it.
|
// isn't someone we want to accept connections from.
|
||||||
if let Some(checker) = &self.security_parameters.allow_connection {
|
tracing::trace!("Initial accept of connection from {}", their_addr);
|
||||||
if !checker(&their_addr, their_port) {
|
if let Some(checker) = self.security_parameters.allow_connection {
|
||||||
info!(
|
if !checker(&their_addr) {
|
||||||
"Rejecting attempted connection from {}:{}",
|
tracing::info!("Rejecting attempted connection from {}", their_addr,);
|
||||||
their_addr, their_port
|
|
||||||
);
|
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// throw this off into another task to take from here. We could to the rest
|
// continue this work in another task. we could absolutely do this work here,
|
||||||
// of this handshake here, but there's a chance that an adversarial connection
|
// but just in case someone starts doing slow responses (or other nasty things),
|
||||||
// could just stall us out, and keep us from doing the next connection. So ...
|
// we want to make sure that that doesn't slow down our ability to accept other
|
||||||
// we'll potentially spin off the task early.
|
// requests.
|
||||||
let me_again = self.clone();
|
let me_again = self.clone();
|
||||||
task::spawn(async move {
|
tokio::task::spawn(async move {
|
||||||
me_again
|
if let Err(e) = me_again.start_authentication(their_addr, socket).await {
|
||||||
.authenticate_step(their_addr, their_port, stream)
|
tracing::error!("{}: server handler failure: {}", their_addr, e);
|
||||||
.await;
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn authenticate_step(
|
/// Start the authentication phase of the SOCKS handshake. This may be very short, and
|
||||||
|
/// is the first stage of handling a request. This will only really return on errors.
|
||||||
|
async fn start_authentication(
|
||||||
self,
|
self,
|
||||||
their_addr: SOCKSv5Address,
|
their_addr: SocketAddr,
|
||||||
their_port: u16,
|
mut socket: TcpStream,
|
||||||
base_stream: GenericStream,
|
) -> Result<(), SOCKSv5ServerError> {
|
||||||
) {
|
let greeting = ClientGreeting::read(&mut socket).await?;
|
||||||
// Turn this stream into one where we've authenticated the other side. Or, you
|
|
||||||
// know, don't, and just restart this loop.
|
|
||||||
let mut authenticated_stream =
|
|
||||||
match run_authentication(&self.security_parameters, base_stream).await {
|
|
||||||
Ok(authed_stream) => authed_stream,
|
|
||||||
Err(e) => {
|
|
||||||
warn!(
|
|
||||||
"Failure running authentication from {}:{}: {}",
|
|
||||||
their_addr, their_port, e
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Figure out what the client actually wants from this connection, and
|
match choose_authentication_method(&self.security_parameters, &greeting.acceptable_methods)
|
||||||
// then dispatch a task to deal with that.
|
{
|
||||||
let mccr = ClientConnectionRequest::read(&mut authenticated_stream).await;
|
// it's not us, it's you. (we're just going to say no.)
|
||||||
match mccr {
|
None => {
|
||||||
Err(e) => warn!("Failure figuring out what the client wanted: {}", e),
|
tracing::trace!(
|
||||||
Ok(ccr) => match ccr.command_code {
|
"{}: Failed to find acceptable authentication method.",
|
||||||
ClientConnectionCommand::AssociateUDPPort => self
|
their_addr,
|
||||||
.handle_udp_request(authenticated_stream, ccr, their_addr, their_port)
|
);
|
||||||
.await
|
let rejection_letter = ServerChoice::rejection();
|
||||||
.unwrap_or_else(|e| warn!("Internal server error in UDP association: {}", e)),
|
|
||||||
ClientConnectionCommand::EstablishTCPPortBinding => self
|
rejection_letter.write(&mut socket).await?;
|
||||||
.handle_tcp_bind(authenticated_stream, ccr, their_addr, their_port)
|
socket.flush().await?;
|
||||||
.await
|
|
||||||
.unwrap_or_else(|e| warn!("Internal server error in TCP bind: {}", e)),
|
Err(SOCKSv5ServerError::ItsNotUsItsYou)
|
||||||
ClientConnectionCommand::EstablishTCPStream => self
|
}
|
||||||
.handle_tcp_forward(authenticated_stream, ccr, their_addr, their_port)
|
|
||||||
.await
|
// the gold standard. great choice.
|
||||||
.unwrap_or_else(|e| warn!("Internal server error in TCP forward: {}", e)),
|
Some(ChosenMethod::TLS(_converter)) => {
|
||||||
},
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
|
// well, I guess this is something?
|
||||||
|
Some(ChosenMethod::Password(checker)) => {
|
||||||
|
tracing::trace!(
|
||||||
|
"{}: Choosing username/password for authentication.",
|
||||||
|
their_addr,
|
||||||
|
);
|
||||||
|
let ok_lets_do_password =
|
||||||
|
ServerChoice::option(AuthenticationMethod::UsernameAndPassword);
|
||||||
|
ok_lets_do_password.write(&mut socket).await?;
|
||||||
|
socket.flush().await?;
|
||||||
|
|
||||||
|
let their_info = ClientUsernamePassword::read(&mut socket).await?;
|
||||||
|
if checker(&their_info.username, &their_info.password) {
|
||||||
|
let its_all_good = ServerAuthResponse::success();
|
||||||
|
its_all_good.write(&mut socket).await?;
|
||||||
|
socket.flush().await?;
|
||||||
|
self.choose_mode(socket, their_addr).await
|
||||||
|
} else {
|
||||||
|
let yeah_no = ServerAuthResponse::failure();
|
||||||
|
yeah_no.write(&mut socket).await?;
|
||||||
|
socket.flush().await?;
|
||||||
|
Err(SOCKSv5ServerError::FailedUsernamePassword(
|
||||||
|
their_info.username,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Um. I guess we're doing this unchecked. Yay?
|
||||||
|
Some(ChosenMethod::None) => {
|
||||||
|
tracing::trace!(
|
||||||
|
"{}: Just skipping the whole authentication thing.",
|
||||||
|
their_addr,
|
||||||
|
);
|
||||||
|
let nothin_i_guess = ServerChoice::option(AuthenticationMethod::None);
|
||||||
|
nothin_i_guess.write(&mut socket).await?;
|
||||||
|
socket.flush().await?;
|
||||||
|
self.choose_mode(socket, their_addr).await
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_udp_request(
|
/// Determine which of the modes we might want this particular connection to run
|
||||||
|
/// in.
|
||||||
|
async fn choose_mode(
|
||||||
self,
|
self,
|
||||||
stream: GenericStream,
|
mut socket: TcpStream,
|
||||||
ccr: ClientConnectionRequest,
|
their_addr: SocketAddr,
|
||||||
their_addr: SOCKSv5Address,
|
) -> Result<(), SOCKSv5ServerError> {
|
||||||
their_port: u16,
|
let ccr = ClientConnectionRequest::read(&mut socket).await?;
|
||||||
) -> Result<(), ServerError<N::Error>> {
|
match ccr.command_code {
|
||||||
// Let the user know that we're maybe making progress
|
ClientConnectionCommand::AssociateUDPPort => {
|
||||||
let (my_addr, my_port) = stream.local_addr();
|
self.handle_udp_request(socket, their_addr, ccr).await?
|
||||||
info!(
|
}
|
||||||
"[{}:{}] Handling UDP bind request from {}:{}, seeking to bind {}:{}",
|
ClientConnectionCommand::EstablishTCPStream => {
|
||||||
my_addr, my_port, their_addr, their_port, ccr.destination_address, ccr.destination_port
|
self.handle_tcp_request(socket, their_addr, ccr).await?
|
||||||
);
|
}
|
||||||
|
ClientConnectionCommand::EstablishTCPPortBinding => {
|
||||||
unimplemented!()
|
self.handle_tcp_binding_request(socket, their_addr, ccr)
|
||||||
|
.await?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_tcp_forward(
|
/// Handle UDP forwarding requests
|
||||||
|
#[allow(unreachable_code)]
|
||||||
|
async fn handle_udp_request(
|
||||||
self,
|
self,
|
||||||
mut stream: GenericStream,
|
stream: TcpStream,
|
||||||
|
their_addr: SocketAddr,
|
||||||
ccr: ClientConnectionRequest,
|
ccr: ClientConnectionRequest,
|
||||||
their_addr: SOCKSv5Address,
|
) -> Result<(), SOCKSv5ServerError> {
|
||||||
their_port: u16,
|
let my_addr = stream.local_addr()?;
|
||||||
) -> Result<(), ServerError<N::Error>> {
|
tracing::info!(
|
||||||
|
"[{}:{}] Handling UDP bind request from {}:{}, seeking to bind towards {}:{}",
|
||||||
|
my_addr.ip(),
|
||||||
|
my_addr.port(),
|
||||||
|
their_addr.ip(),
|
||||||
|
their_addr.port(),
|
||||||
|
ccr.destination_address,
|
||||||
|
ccr.destination_port
|
||||||
|
);
|
||||||
|
|
||||||
|
let _socket = match ccr.destination_address.clone() {
|
||||||
|
SOCKSv5Address::IP4(x) => UdpSocket::bind((x, ccr.destination_port)).await?,
|
||||||
|
SOCKSv5Address::IP6(x) => UdpSocket::bind((x, ccr.destination_port)).await?,
|
||||||
|
SOCKSv5Address::Hostname(x) => UdpSocket::bind((x, ccr.destination_port)).await?,
|
||||||
|
};
|
||||||
|
|
||||||
|
// OK, it worked. In order to mitigate an infinitesimal chance of a race condition, we're
|
||||||
|
// going to set up our forwarding tasks first, and then return the result to the user. (Note,
|
||||||
|
// we'd have to be slightly more precious in order to ensure a lack of race conditions, as
|
||||||
|
// the runtime could take forever to actually start these tasks, but I'm not ready to be
|
||||||
|
// bothered by this, yet. FIXME.)
|
||||||
|
unimplemented!();
|
||||||
|
|
||||||
|
// Cool; now we can get the result out to the user.
|
||||||
|
let bound_address = _socket.local_addr()?;
|
||||||
|
let response = ServerResponse {
|
||||||
|
status: ServerResponseStatus::RequestGranted,
|
||||||
|
bound_address: bound_address.ip().into(),
|
||||||
|
bound_port: bound_address.port(),
|
||||||
|
};
|
||||||
|
|
||||||
|
response.write(&mut stream).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Handle TCP forwarding requests
|
||||||
|
async fn handle_tcp_request(
|
||||||
|
self,
|
||||||
|
mut stream: TcpStream,
|
||||||
|
their_addr: SocketAddr,
|
||||||
|
ccr: ClientConnectionRequest,
|
||||||
|
) -> Result<(), SOCKSv5ServerError> {
|
||||||
// Let the user know that we're maybe making progress
|
// Let the user know that we're maybe making progress
|
||||||
let (my_addr, my_port) = stream.local_addr();
|
let my_addr = stream.local_addr()?;
|
||||||
info!(
|
tracing::info!(
|
||||||
"[{}:{}] Handling TCP forward request from {}:{}, seeking to connect to {}:{}",
|
"[{}] Handling TCP forward request from {}, seeking to connect to {}:{}",
|
||||||
my_addr, my_port, their_addr, their_port, ccr.destination_address, ccr.destination_port
|
my_addr,
|
||||||
|
their_addr,
|
||||||
|
ccr.destination_address,
|
||||||
|
ccr.destination_port
|
||||||
);
|
);
|
||||||
|
|
||||||
// OK, first thing's first: We need to actually connect to the server that the user
|
// OK, first thing's first: We need to actually connect to the server that the user
|
||||||
// wants us to connect to.
|
// wants us to connect to.
|
||||||
let connection_res = {
|
let outgoing_stream = match &ccr.destination_address {
|
||||||
let mut network = self.network.lock().await;
|
SOCKSv5Address::IP4(x) => TcpStream::connect((*x, ccr.destination_port)).await?,
|
||||||
network
|
SOCKSv5Address::IP6(x) => TcpStream::connect((*x, ccr.destination_port)).await?,
|
||||||
.connect(ccr.destination_address.clone(), ccr.destination_port)
|
SOCKSv5Address::Hostname(x) => {
|
||||||
.await
|
TcpStream::connect((x.as_ref(), ccr.destination_port)).await?
|
||||||
};
|
|
||||||
|
|
||||||
let outgoing_stream = match connection_res {
|
|
||||||
Ok(x) => x,
|
|
||||||
Err(e) => {
|
|
||||||
error!("Failed to connect to {}: {}", ccr.destination_address, e);
|
|
||||||
let response = ServerResponse::error(&e);
|
|
||||||
response.write(&mut stream).await?;
|
|
||||||
return Err(ServerError::NetworkError(e));
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
trace!(
|
tracing::trace!(
|
||||||
"Connection established to {}:{}",
|
"Connection established to {}:{}",
|
||||||
ccr.destination_address,
|
ccr.destination_address,
|
||||||
ccr.destination_port
|
ccr.destination_port
|
||||||
@@ -352,117 +304,117 @@ impl<N: Networklike + Clone + Send + 'static> SOCKSv5Server<N> {
|
|||||||
// Now, for whatever reason -- and this whole thing sent me down a garden path
|
// Now, for whatever reason -- and this whole thing sent me down a garden path
|
||||||
// in understanding how this whole protocol works -- we tell the user what address
|
// in understanding how this whole protocol works -- we tell the user what address
|
||||||
// and port we bound for that connection.
|
// and port we bound for that connection.
|
||||||
let (bound_address, bound_port) = outgoing_stream.local_addr();
|
let bound_address = outgoing_stream.local_addr()?;
|
||||||
let response = ServerResponse {
|
let response = ServerResponse {
|
||||||
status: ServerResponseStatus::RequestGranted,
|
status: ServerResponseStatus::RequestGranted,
|
||||||
bound_address,
|
bound_address: bound_address.ip().into(),
|
||||||
bound_port,
|
bound_port: bound_address.port(),
|
||||||
};
|
};
|
||||||
response.write(&mut stream).await?;
|
response.write(&mut stream).await?;
|
||||||
|
|
||||||
// so now tie our streams together, and we're good to go
|
// so now tie our streams together, and we're good to go
|
||||||
tie_streams(
|
tie_streams(stream, outgoing_stream).await;
|
||||||
format!("{}:{}", their_addr, their_port),
|
|
||||||
stream,
|
|
||||||
format!("{}:{}", ccr.destination_address, ccr.destination_port),
|
|
||||||
outgoing_stream,
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_tcp_bind(
|
/// Handle TCP binding requests
|
||||||
|
async fn handle_tcp_binding_request(
|
||||||
self,
|
self,
|
||||||
mut stream: GenericStream,
|
mut stream: TcpStream,
|
||||||
|
their_addr: SocketAddr,
|
||||||
ccr: ClientConnectionRequest,
|
ccr: ClientConnectionRequest,
|
||||||
their_addr: SOCKSv5Address,
|
) -> Result<(), SOCKSv5ServerError> {
|
||||||
their_port: u16,
|
|
||||||
) -> Result<(), ServerError<N::Error>> {
|
|
||||||
// Let the user know that we're maybe making progress
|
// Let the user know that we're maybe making progress
|
||||||
let (my_addr, my_port) = stream.local_addr();
|
let my_addr = stream.local_addr()?;
|
||||||
info!(
|
tracing::info!(
|
||||||
"[{}:{}] Handling TCP bind request from {}:{}, seeking to bind {}:{}",
|
"[{}] Handling TCP bind request from {}, seeking to bind {}:{}",
|
||||||
my_addr, my_port, their_addr, their_port, ccr.destination_address, ccr.destination_port
|
my_addr,
|
||||||
|
their_addr,
|
||||||
|
ccr.destination_address,
|
||||||
|
ccr.destination_port
|
||||||
);
|
);
|
||||||
|
|
||||||
// OK, we have to bind the darn socket first.
|
// OK, we have to bind the darn socket first.
|
||||||
let port_binding = {
|
let listener_port = match &their_addr {
|
||||||
let mut network = self.network.lock().await;
|
SocketAddr::V4(_) => TcpSocket::new_v4(),
|
||||||
network.listen(their_addr.clone(), their_port).await
|
SocketAddr::V6(_) => TcpSocket::new_v6(),
|
||||||
}
|
}?;
|
||||||
.map_err(ServerError::NetworkError)?;
|
// FIXME: Might want to bind on a particular interface, based on a
|
||||||
|
// config flag, at some point.
|
||||||
|
let listener = listener_port.listen(1)?;
|
||||||
|
|
||||||
// Tell them what we bound, just in case they want to inform anyone.
|
// Tell them what we bound, just in case they want to inform anyone.
|
||||||
let (bound_address, bound_port) = port_binding.local_addr();
|
let bound_address = listener.local_addr()?;
|
||||||
let response = ServerResponse {
|
let response = ServerResponse {
|
||||||
status: ServerResponseStatus::RequestGranted,
|
status: ServerResponseStatus::RequestGranted,
|
||||||
bound_address,
|
bound_address: bound_address.ip().into(),
|
||||||
bound_port,
|
bound_port: bound_address.port(),
|
||||||
};
|
};
|
||||||
response.write(&mut stream).await?;
|
response.write(&mut stream).await?;
|
||||||
|
|
||||||
// Wait politely for someone to talk to us.
|
// Wait politely for someone to talk to us.
|
||||||
let (other, other_addr, other_port) = port_binding
|
let (other, other_addr) = listener.accept().await?;
|
||||||
.accept()
|
|
||||||
.await
|
|
||||||
.map_err(ServerError::NetworkError)?;
|
|
||||||
let info = ServerResponse {
|
let info = ServerResponse {
|
||||||
status: ServerResponseStatus::RequestGranted,
|
status: ServerResponseStatus::RequestGranted,
|
||||||
bound_address: other_addr.clone(),
|
bound_address: other_addr.ip().into(),
|
||||||
bound_port: other_port,
|
bound_port: other_addr.port(),
|
||||||
};
|
};
|
||||||
info.write(&mut stream).await?;
|
info.write(&mut stream).await?;
|
||||||
|
|
||||||
tie_streams(
|
tie_streams(stream, other).await;
|
||||||
format!("{}:{}", their_addr, their_port),
|
|
||||||
stream,
|
|
||||||
format!("{}:{}", other_addr, other_port),
|
|
||||||
other,
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn tie_streams(
|
async fn tie_streams(mut left: TcpStream, mut right: TcpStream) {
|
||||||
left_name: String,
|
let left_local_addr = left
|
||||||
left: GenericStream,
|
.local_addr()
|
||||||
right_name: String,
|
.expect("couldn't get left local address in tie_streams");
|
||||||
right: GenericStream,
|
let left_peer_addr = left
|
||||||
) {
|
.peer_addr()
|
||||||
// Now that we've informed them of that, we set up one task to transfer information
|
.expect("couldn't get left peer address in tie_streams");
|
||||||
// from the current stream (`stream`) to the connection (`outgoing_stream`), and
|
let right_local_addr = right
|
||||||
// another task that goes in the reverse direction.
|
.local_addr()
|
||||||
//
|
.expect("couldn't get right local address in tie_streams");
|
||||||
// I've chosen to start two fresh tasks and let this one die; I'm not sure that
|
let right_peer_addr = right
|
||||||
// this is the right approach. My only rationale is that this might let some
|
.peer_addr()
|
||||||
// memory we might have accumulated along the way drop more easily, but that
|
.expect("couldn't get right peer address in tie_streams");
|
||||||
// might not actually matter.
|
|
||||||
let mut from_left = left.clone();
|
|
||||||
let mut from_right = right.clone();
|
|
||||||
let mut to_left = left;
|
|
||||||
let mut to_right = right;
|
|
||||||
let left_right_name = format!("{} >--> {}", left_name, right_name);
|
|
||||||
let right_left_name = format!("{} <--< {}", left_name, right_name);
|
|
||||||
|
|
||||||
task::spawn(async move {
|
tokio::task::spawn(async move {
|
||||||
info!("Spawned {} task", left_right_name);
|
tracing::info!(
|
||||||
if let Err(e) = io::copy(&mut from_left, &mut to_right).await {
|
"Setting up linkage {}/{} <-> {}/{}",
|
||||||
warn!("{} connection failed with: {}", left_right_name, e);
|
left_peer_addr,
|
||||||
}
|
left_local_addr,
|
||||||
});
|
right_local_addr,
|
||||||
|
right_peer_addr
|
||||||
task::spawn(async move {
|
);
|
||||||
info!("Spawned {} task", right_left_name);
|
match copy_bidirectional(&mut left, &mut right).await {
|
||||||
if let Err(e) = io::copy(&mut from_right, &mut to_left).await {
|
Ok((l2r, r2l)) => tracing::info!(
|
||||||
warn!("{} connection failed with: {}", right_left_name, e);
|
"Shutting down linkage {}/{} <-> {}/{} (sent {} and {} bytes, respectively)",
|
||||||
|
left_peer_addr,
|
||||||
|
left_local_addr,
|
||||||
|
right_local_addr,
|
||||||
|
right_peer_addr,
|
||||||
|
l2r,
|
||||||
|
r2l
|
||||||
|
),
|
||||||
|
Err(e) => tracing::warn!(
|
||||||
|
"Shutting down linkage {}/{} <-> {}/{} with error: {}",
|
||||||
|
left_peer_addr,
|
||||||
|
left_local_addr,
|
||||||
|
right_local_addr,
|
||||||
|
right_peer_addr,
|
||||||
|
e
|
||||||
|
),
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::upper_case_acronyms)]
|
#[allow(clippy::upper_case_acronyms)]
|
||||||
enum ChosenMethod {
|
enum ChosenMethod {
|
||||||
TLS(fn(GenericStream) -> Option<GenericStream>),
|
TLS(fn() -> Option<()>),
|
||||||
Password(fn(&str, &str) -> bool),
|
Password(fn(&str, &str) -> bool),
|
||||||
None,
|
None,
|
||||||
}
|
}
|
||||||
@@ -567,7 +519,7 @@ fn reasonable_auth_method_choices() {
|
|||||||
);
|
);
|
||||||
|
|
||||||
// OK, cool. If we have a TLS handler, that shouldn't actually make a difference.
|
// OK, cool. If we have a TLS handler, that shouldn't actually make a difference.
|
||||||
params.connect_tls = Some(|_| unimplemented!());
|
params.connect_tls = Some(|| unimplemented!());
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
choose_authentication_method(¶ms, &client_suggestions).map(AuthenticationMethod::from),
|
choose_authentication_method(¶ms, &client_suggestions).map(AuthenticationMethod::from),
|
||||||
None
|
None
|
||||||
@@ -580,7 +532,7 @@ fn reasonable_auth_method_choices() {
|
|||||||
None
|
None
|
||||||
);
|
);
|
||||||
// but if we have a handler, and they go for it, we use it.
|
// but if we have a handler, and they go for it, we use it.
|
||||||
params.connect_tls = Some(|_| unimplemented!());
|
params.connect_tls = Some(|| unimplemented!());
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
choose_authentication_method(¶ms, &client_suggestions).map(AuthenticationMethod::from),
|
choose_authentication_method(¶ms, &client_suggestions).map(AuthenticationMethod::from),
|
||||||
Some(AuthenticationMethod::SSL)
|
Some(AuthenticationMethod::SSL)
|
||||||
@@ -600,75 +552,3 @@ fn reasonable_auth_method_choices() {
|
|||||||
Some(AuthenticationMethod::SSL)
|
Some(AuthenticationMethod::SSL)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn run_authentication(
|
|
||||||
params: &SecurityParameters,
|
|
||||||
mut stream: GenericStream,
|
|
||||||
) -> Result<GenericStream, AuthenticationError> {
|
|
||||||
let greeting = ClientGreeting::read(&mut stream).await?;
|
|
||||||
|
|
||||||
match choose_authentication_method(params, &greeting.acceptable_methods) {
|
|
||||||
// it's not us, it's you
|
|
||||||
None => {
|
|
||||||
trace!("Failed to find acceptable authentication method.");
|
|
||||||
let rejection_letter = ServerChoice::rejection();
|
|
||||||
|
|
||||||
rejection_letter.write(&mut stream).await?;
|
|
||||||
stream.flush().await?;
|
|
||||||
|
|
||||||
Err(AuthenticationError::ItsNotUsItsYou)
|
|
||||||
}
|
|
||||||
|
|
||||||
// the gold standard. great choice.
|
|
||||||
Some(ChosenMethod::TLS(converter)) => {
|
|
||||||
trace!("Choosing TLS for authentication.");
|
|
||||||
let lets_do_this = ServerChoice::option(AuthenticationMethod::SSL);
|
|
||||||
lets_do_this.write(&mut stream).await?;
|
|
||||||
stream.flush().await?;
|
|
||||||
|
|
||||||
converter(stream).ok_or(AuthenticationError::FailedTLSHandshake)
|
|
||||||
}
|
|
||||||
|
|
||||||
// well, I guess this is something?
|
|
||||||
Some(ChosenMethod::Password(checker)) => {
|
|
||||||
trace!("Choosing Username/Password for authentication.");
|
|
||||||
let ok_lets_do_password =
|
|
||||||
ServerChoice::option(AuthenticationMethod::UsernameAndPassword);
|
|
||||||
ok_lets_do_password.write(&mut stream).await?;
|
|
||||||
stream.flush().await?;
|
|
||||||
|
|
||||||
let their_info = ClientUsernamePassword::read(&mut stream).await?;
|
|
||||||
if checker(&their_info.username, &their_info.password) {
|
|
||||||
let its_all_good = ServerAuthResponse::success();
|
|
||||||
its_all_good.write(&mut stream).await?;
|
|
||||||
stream.flush().await?;
|
|
||||||
Ok(stream)
|
|
||||||
} else {
|
|
||||||
let yeah_no = ServerAuthResponse::failure();
|
|
||||||
yeah_no.write(&mut stream).await?;
|
|
||||||
stream.flush().await?;
|
|
||||||
Err(AuthenticationError::FailedUsernamePassword(
|
|
||||||
their_info.username,
|
|
||||||
))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Some(ChosenMethod::None) => {
|
|
||||||
trace!("Just skipping the whole authentication thing.");
|
|
||||||
let nothin_i_guess = ServerChoice::option(AuthenticationMethod::None);
|
|
||||||
nothin_i_guess.write(&mut stream).await?;
|
|
||||||
stream.flush().await?;
|
|
||||||
Ok(stream)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Error, Debug)]
|
|
||||||
pub enum ServerError<E: Debug + Display> {
|
|
||||||
#[error("Error in deserialization: {0}")]
|
|
||||||
DeserializationError(#[from] DeserializationError),
|
|
||||||
#[error("Error in serialization: {0}")]
|
|
||||||
SerializationError(#[from] SerializationError),
|
|
||||||
#[error("Underlying network error: {0}")]
|
|
||||||
NetworkError(E),
|
|
||||||
}
|
|
||||||
|
|||||||
Reference in New Issue
Block a user