Add HasLocalAddress for querying a socket's address, and a test.

This commit is contained in:
2021-07-25 17:18:29 -07:00
parent 1436b02323
commit 58cb384afd
6 changed files with 112 additions and 104 deletions

View File

@@ -168,6 +168,10 @@ impl SOCKSv5Address {
}
}
pub trait HasLocalAddress {
fn local_addr(&self) -> (SOCKSv5Address, u16);
}
#[cfg(test)]
impl Arbitrary for SOCKSv5Address {
fn arbitrary(g: &mut Gen) -> Self {

View File

@@ -1,8 +1,8 @@
use crate::network::address::SOCKSv5Address;
use crate::network::address::{HasLocalAddress, SOCKSv5Address};
use async_trait::async_trait;
#[async_trait]
pub trait Datagramlike: Send + Sync {
pub trait Datagramlike: Send + Sync + HasLocalAddress {
type Error;
async fn send_to(
@@ -35,3 +35,9 @@ impl<E> Datagramlike for GenericDatagramSocket<E> {
Ok(self.internal.recv_from(buf).await?)
}
}
impl<E> HasLocalAddress for GenericDatagramSocket<E> {
fn local_addr(&self) -> (SOCKSv5Address, u16) {
self.internal.local_addr()
}
}

View File

@@ -1,13 +1,12 @@
use crate::network::address::SOCKSv5Address;
use crate::network::address::{HasLocalAddress, SOCKSv5Address};
use crate::network::stream::GenericStream;
use async_trait::async_trait;
#[async_trait]
pub trait Listenerlike: Send + Sync {
pub trait Listenerlike: Send + Sync + HasLocalAddress {
type Error;
async fn accept(&self) -> Result<(GenericStream, SOCKSv5Address, u16), Self::Error>;
fn info(&self) -> (SOCKSv5Address, u16);
}
pub struct GenericListener<E> {
@@ -21,8 +20,10 @@ impl<E> Listenerlike for GenericListener<E> {
async fn accept(&self) -> Result<(GenericStream, SOCKSv5Address, u16), Self::Error> {
Ok(self.internal.accept().await?)
}
}
fn info(&self) -> (SOCKSv5Address, u16) {
self.internal.info()
impl<E> HasLocalAddress for GenericListener<E> {
fn local_addr(&self) -> (SOCKSv5Address, u16) {
self.internal.local_addr()
}
}

View File

@@ -1,14 +1,17 @@
use crate::messages::ServerResponseStatus;
use crate::network::address::SOCKSv5Address;
use crate::network::address::{HasLocalAddress, SOCKSv5Address};
use crate::network::datagram::{Datagramlike, GenericDatagramSocket};
use crate::network::generic::Networklike;
use crate::network::listener::{GenericListener, Listenerlike};
use crate::network::stream::{GenericStream, Streamlike};
use async_std::io;
#[cfg(test)]
use async_std::io::ReadExt;
use async_std::net::{TcpListener, TcpStream, UdpSocket};
use async_trait::async_trait;
#[cfg(test)]
use futures::AsyncWriteExt;
use log::error;
use std::net::Ipv4Addr;
pub struct Builtin {}
@@ -18,6 +21,27 @@ impl Builtin {
}
}
macro_rules! local_address_impl {
($t: ty) => {
impl HasLocalAddress for $t {
fn local_addr(&self) -> (SOCKSv5Address, u16) {
match self.local_addr() {
Ok(a) =>
(SOCKSv5Address::from(a.ip()), a.port()),
Err(e) => {
error!("Couldn't translate (Streamlike) local address to SOCKS local address: {}", e);
(SOCKSv5Address::from("localhost"), 0)
}
}
}
}
};
}
local_address_impl!(TcpStream);
local_address_impl!(TcpListener);
local_address_impl!(UdpSocket);
impl Streamlike for TcpStream {}
#[async_trait]
@@ -28,21 +52,7 @@ impl Listenerlike for TcpListener {
let (base, addrport) = self.accept().await?;
let addr = addrport.ip();
let port = addrport.port();
Ok((GenericStream::from(base), SOCKSv5Address::from(addr), port))
}
fn info(&self) -> (SOCKSv5Address, u16) {
match self.local_addr() {
Ok(x) => {
let addr = SOCKSv5Address::from(x.ip());
let port = x.port();
(addr, port)
}
Err(e) => {
error!("Someone asked for a listener address, and we got an error ({}); returning 0.0.0.0:0", e);
(SOCKSv5Address::IP4(Ipv4Addr::from(0)), 0)
}
}
Ok((GenericStream::new(base), SOCKSv5Address::from(addr), port))
}
}
@@ -88,7 +98,7 @@ impl Networklike for Builtin {
SOCKSv5Address::Name(n) => TcpStream::connect((n.as_str(), port)).await?,
};
Ok(GenericStream::from(base_stream))
Ok(GenericStream::new(base_stream))
}
async fn listen<A: Send + Into<SOCKSv5Address>>(
@@ -128,14 +138,59 @@ impl Networklike for Builtin {
}
}
// pub struct StandardNetworking {}
//
// impl StandardNetworking {
// pub fn new() -> StandardNetworking {
// StandardNetworking {}
// }
// }
//
#[test]
fn check_sanity() {
async_std::task::block_on(async {
// Technically, this is UDP, and UDP is lossy. We're going to assume we're not
// going to get any dropped data along here ... which is a very questionable
// assumption, morally speaking, but probably fine for most purposes.
let mut network = Builtin::new();
let receiver = network.bind("localhost", 0).await.expect("Failed to bind receiver socket.");
let sender = network.bind("localhost", 0).await.expect("Failed to bind sender socket.");
let buffer = [0xde, 0xea, 0xbe, 0xef];
let (receiver_addr, receiver_port) = receiver.local_addr();
sender.send_to(&buffer, receiver_addr, receiver_port).await.expect("Failure sending datagram!");
let mut recvbuffer = [0; 4];
let (s, f, p) = receiver.recv_from(&mut recvbuffer).await.expect("Didn't receive UDP message?");
let (sender_addr, sender_port) = sender.local_addr();
assert_eq!(s, 4);
assert_eq!(f, sender_addr);
assert_eq!(p, sender_port);
assert_eq!(recvbuffer, buffer);
});
// This whole block should be pretty solid, though, unless the system we're
// on is in a pretty weird place.
let mut network = Builtin::new();
let listener = async_std::task::block_on(network.listen("localhost", 0)).expect("Couldn't set up listener on localhost");
let (listener_address, listener_port) = listener.local_addr();
let listener_task_handle = async_std::task::spawn(async move {
let (mut stream, addr, port) = listener.accept().await.expect("Didn't get connection");
let mut result_buffer = [0u8; 4];
println!("Starting read!");
stream.read_exact(&mut result_buffer).await.expect("Read failure in TCP test");
(result_buffer, addr, port)
});
let sender_task_handle = async_std::task::spawn(async move {
let mut sender = network.connect(listener_address, listener_port).await.expect("Coudln't connect to listener?");
let (sender_address, sender_port) = sender.local_addr();
let send_buffer = [0xa, 0xff, 0xab, 0x1e];
sender.write_all(&send_buffer).await.expect("Couldn't send the write buffer");
(sender_address, sender_port)
});
async_std::task::block_on(async {
let (result, result_from, result_from_port) = listener_task_handle.await;
assert_eq!(result, [0xa, 0xff, 0xab, 0x1e]);
let (sender_address, sender_port) = sender_task_handle.await;
assert_eq!(result_from, sender_address);
assert_eq!(result_from_port, sender_port);
});
}
impl From<io::Error> for ServerResponseStatus {
fn from(e: io::Error) -> ServerResponseStatus {
match e.kind() {
@@ -144,71 +199,4 @@ impl From<io::Error> for ServerResponseStatus {
_ => ServerResponseStatus::GeneralFailure,
}
}
}
//
// #[async_trait]
// impl Network for StandardNetworking {
// type Stream = TcpStream;
// type Listener = TcpListener;
// type UdpSocket = UdpSocket;
// type Error = io::Error;
//
// async fn connect<A: ToSOCKSAddress>(
// &mut self,
// addr: A,
// port: u16,
// ) -> Result<Self::Stream, Self::Error> {
// let target = addr.into();
//
// match target {
// SOCKSv5Address::IP4(a) => TcpStream::connect((a, port)).await,
// SOCKSv5Address::IP6(a) => TcpStream::connect((a, port)).await,
// SOCKSv5Address::Name(n) => TcpStream::connect((n.as_str(), port)).await,
// }
// }
//
// async fn udp_socket<A: ToSOCKSAddress>(
// &mut self,
// addr: A,
// port: Option<u16>,
// ) -> Result<Self::UdpSocket, Self::Error> {
// let me = addr.into();
// let real_port = port.unwrap_or(0);
//
// match me {
// SOCKSv5Address::IP4(a) => UdpSocket::bind((a, real_port)).await,
// SOCKSv5Address::IP6(a) => UdpSocket::bind((a, real_port)).await,
// SOCKSv5Address::Name(n) => UdpSocket::bind((n.as_str(), real_port)).await,
// }
// }
//
// async fn listen<A: ToSOCKSAddress>(
// &mut self,
// addr: A,
// port: Option<u16>,
// ) -> Result<Self::Listener, Self::Error> {
// let me = addr.into();
// let real_port = port.unwrap_or(0);
//
// match me {
// SOCKSv5Address::IP4(a) => TcpListener::bind((a, real_port)).await,
// SOCKSv5Address::IP6(a) => TcpListener::bind((a, real_port)).await,
// SOCKSv5Address::Name(n) => TcpListener::bind((n.as_str(), real_port)).await,
// }
// }
// }
//
// #[async_trait]
// impl SingleShotListener<TcpStream, io::Error> for TcpListener {
// async fn accept(self) -> Result<TcpStream, io::Error> {
// self.accept().await
// }
//
// fn info(&self) -> Result<(SOCKSv5Address, u16), io::Error> {
// match self.local_addr()? {
// SocketAddr::V4(a) => Ok((SOCKSv5Address::IP4(*a.ip()), a.port())),
// SocketAddr::V6(a) => Ok((SOCKSv5Address::IP6(*a.ip()), a.port())),
// }
// }
// }
//
}

View File

@@ -1,10 +1,13 @@
use async_std::task::{Context, Poll};
use futures::io;
use crate::network::SOCKSv5Address;
use futures::io::{AsyncRead, AsyncWrite};
use std::pin::Pin;
use std::sync::{Arc, Mutex};
pub trait Streamlike: AsyncRead + AsyncWrite + Send + Sync + Unpin {}
use super::address::HasLocalAddress;
pub trait Streamlike: AsyncRead + AsyncWrite + HasLocalAddress + Send + Sync + Unpin {}
#[derive(Clone)]
pub struct GenericStream {
@@ -19,6 +22,11 @@ impl GenericStream {
}
}
impl HasLocalAddress for GenericStream {
fn local_addr(&self) -> (SOCKSv5Address, u16) {
let item = self.internal.lock().unwrap();
item.local_addr()
}
}
impl AsyncRead for GenericStream {

View File

@@ -3,6 +3,7 @@ use crate::messages::{
AuthenticationMethod, ClientConnectionCommand, ClientConnectionRequest, ClientGreeting,
ClientUsernamePassword, ServerChoice, ServerResponse, ServerResponseStatus,
};
use crate::network::address::HasLocalAddress;
use crate::network::generic::Networklike;
use crate::network::listener::{GenericListener, Listenerlike};
use crate::network::stream::GenericStream;
@@ -45,7 +46,7 @@ impl<N: Networklike + Send + 'static> SOCKSv5Server<N> {
}
pub async fn run(self) -> Result<(), N::Error> {
let (my_addr, my_port) = self.listener.info();
let (my_addr, my_port) = self.listener.local_addr();
info!("Starting SOCKSv5 server on {}:{}", my_addr, my_port);
let locked_network = Arc::new(Mutex::new(self.network));
@@ -241,7 +242,7 @@ where
continue;
}
};
let (bound_address, bound_port) = incoming_listener.info();
let (bound_address, bound_port) = incoming_listener.local_addr();
trace!(
"Set up {}:{} to address request for {}:{}",
bound_address,