Whoops! TCP streams are dual-buffered!
This adjusts the way that TestingStream is implemented to allow for two, separate buffers for each of the two directions. In the prior implementation, if you called `write` and then `read`, you would `read` the data you just wrote. Which is not what you want; you want to block until you get data back from the other side.
This commit is contained in:
@@ -104,7 +104,7 @@ impl Networklike for TestingStack {
|
|||||||
None => Err(TestStackError::NoTCPHostFound(target, port)),
|
None => Err(TestStackError::NoTCPHostFound(target, port)),
|
||||||
Some(result) => {
|
Some(result) => {
|
||||||
let stream = TestingStream::new(target, port);
|
let stream = TestingStream::new(target, port);
|
||||||
let retval = stream.clone();
|
let retval = stream.invert();
|
||||||
match result.send(stream).await {
|
match result.send(stream).await {
|
||||||
Ok(()) => Ok(GenericStream::new(retval)),
|
Ok(()) => Ok(GenericStream::new(retval)),
|
||||||
Err(_) => Err(TestStackError::FailureToSend),
|
Err(_) => Err(TestStackError::FailureToSend),
|
||||||
@@ -192,11 +192,8 @@ impl Listenerlike for TestListener {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn check_sanity() {
|
fn check_udp_sanity() {
|
||||||
task::block_on(async {
|
task::block_on(async {
|
||||||
// Technically, this is UDP, and UDP is lossy. We're going to assume we're not
|
|
||||||
// going to get any dropped data along here ... which is a very questionable
|
|
||||||
// assumption, morally speaking, but probably fine for most purposes.
|
|
||||||
let mut network = TestingStack::new();
|
let mut network = TestingStack::new();
|
||||||
let receiver = network
|
let receiver = network
|
||||||
.bind("localhost", 0)
|
.bind("localhost", 0)
|
||||||
@@ -223,7 +220,10 @@ fn check_sanity() {
|
|||||||
assert_eq!(p, sender_port);
|
assert_eq!(p, sender_port);
|
||||||
assert_eq!(recvbuffer, buffer);
|
assert_eq!(recvbuffer, buffer);
|
||||||
});
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn check_basic_tcp() {
|
||||||
task::block_on(async {
|
task::block_on(async {
|
||||||
let mut network = TestingStack::new();
|
let mut network = TestingStack::new();
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,8 @@ use std::sync::atomic::{AtomicBool, Ordering};
|
|||||||
pub struct TestingStream {
|
pub struct TestingStream {
|
||||||
address: SOCKSv5Address,
|
address: SOCKSv5Address,
|
||||||
port: u16,
|
port: u16,
|
||||||
internals: NonNull<TestingStreamData>,
|
read_side: NonNull<TestingStreamData>,
|
||||||
|
write_side: NonNull<TestingStreamData>,
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe impl Send for TestingStream {}
|
unsafe impl Send for TestingStream {}
|
||||||
@@ -29,28 +30,51 @@ unsafe impl Send for TestingStreamData {}
|
|||||||
unsafe impl Sync for TestingStreamData {}
|
unsafe impl Sync for TestingStreamData {}
|
||||||
|
|
||||||
impl TestingStream {
|
impl TestingStream {
|
||||||
|
/// Generate a testing stream. Note that this is directional. So, if you want to
|
||||||
|
/// talk to this stream, you should also generate an `invert()` and pass that to
|
||||||
|
/// the other thread(s).
|
||||||
pub fn new(address: SOCKSv5Address, port: u16) -> TestingStream {
|
pub fn new(address: SOCKSv5Address, port: u16) -> TestingStream {
|
||||||
let tsd = TestingStreamData {
|
let read_side_data = TestingStreamData {
|
||||||
lock: AtomicBool::new(false),
|
lock: AtomicBool::new(false),
|
||||||
waiters: UnsafeCell::new(Vec::new()),
|
waiters: UnsafeCell::new(Vec::new()),
|
||||||
buffer: UnsafeCell::new(Vec::with_capacity(16 * 1024)),
|
buffer: UnsafeCell::new(Vec::with_capacity(16 * 1024)),
|
||||||
};
|
};
|
||||||
|
|
||||||
let boxed_tsd = Box::new(tsd);
|
let write_side_data = TestingStreamData {
|
||||||
let raw_ptr = Box::leak(boxed_tsd);
|
lock: AtomicBool::new(false),
|
||||||
|
waiters: UnsafeCell::new(Vec::new()),
|
||||||
|
buffer: UnsafeCell::new(Vec::with_capacity(16 * 1024)),
|
||||||
|
};
|
||||||
|
|
||||||
|
let boxed_rsd = Box::new(read_side_data);
|
||||||
|
let boxed_wsd = Box::new(write_side_data);
|
||||||
|
let raw_read_ptr = Box::leak(boxed_rsd);
|
||||||
|
let raw_write_ptr = Box::leak(boxed_wsd);
|
||||||
|
|
||||||
TestingStream {
|
TestingStream {
|
||||||
address,
|
address,
|
||||||
port,
|
port,
|
||||||
internals: NonNull::new(raw_ptr).unwrap(),
|
read_side: NonNull::new(raw_read_ptr).unwrap(),
|
||||||
|
write_side: NonNull::new(raw_write_ptr).unwrap(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn acquire_lock(&mut self) {
|
/// Get the flip side of this stream; reads from the inverted side will catch the writes
|
||||||
|
/// of the original, etc.
|
||||||
|
pub fn invert(&self) -> TestingStream {
|
||||||
|
TestingStream {
|
||||||
|
address: self.address.clone(),
|
||||||
|
port: self.port,
|
||||||
|
read_side: self.write_side.clone(),
|
||||||
|
write_side: self.read_side.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TestingStreamData {
|
||||||
|
fn acquire(&mut self) {
|
||||||
loop {
|
loop {
|
||||||
let internals = unsafe { self.internals.as_mut() };
|
match self
|
||||||
|
|
||||||
match internals
|
|
||||||
.lock
|
.lock
|
||||||
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
|
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
|
||||||
{
|
{
|
||||||
@@ -60,9 +84,8 @@ impl TestingStream {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn release_lock(&mut self) {
|
fn release(&mut self) {
|
||||||
let internals = unsafe { self.internals.as_mut() };
|
self.lock.store(false, Ordering::SeqCst);
|
||||||
internals.lock.store(false, Ordering::SeqCst);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -81,17 +104,16 @@ impl Read for TestingStream {
|
|||||||
// so, we're going to spin here, which is less than ideal but should work fine
|
// so, we're going to spin here, which is less than ideal but should work fine
|
||||||
// in practice. we'll obviously need to be very careful to ensure that we keep
|
// in practice. we'll obviously need to be very careful to ensure that we keep
|
||||||
// the stuff internal to this spin really short.
|
// the stuff internal to this spin really short.
|
||||||
self.acquire_lock();
|
let internals = unsafe { self.read_side.as_mut() };
|
||||||
|
|
||||||
let internals = unsafe { self.internals.as_mut() };
|
internals.acquire();
|
||||||
let stream_buffer = internals.buffer.get_mut();
|
let stream_buffer = internals.buffer.get_mut();
|
||||||
|
|
||||||
let amount_available = stream_buffer.len();
|
let amount_available = stream_buffer.len();
|
||||||
|
|
||||||
if amount_available == 0 {
|
if amount_available == 0 {
|
||||||
let waker = cx.waker().clone();
|
let waker = cx.waker().clone();
|
||||||
internals.waiters.get_mut().push(waker);
|
internals.waiters.get_mut().push(waker);
|
||||||
self.release_lock();
|
internals.release();
|
||||||
return Poll::Pending;
|
return Poll::Pending;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -108,7 +130,7 @@ impl Read for TestingStream {
|
|||||||
amt_to_copy
|
amt_to_copy
|
||||||
};
|
};
|
||||||
|
|
||||||
self.release_lock();
|
internals.release();
|
||||||
|
|
||||||
Poll::Ready(Ok(amt_written))
|
Poll::Ready(Ok(amt_written))
|
||||||
}
|
}
|
||||||
@@ -120,15 +142,14 @@ impl Write for TestingStream {
|
|||||||
_cx: &mut Context<'_>,
|
_cx: &mut Context<'_>,
|
||||||
buf: &[u8],
|
buf: &[u8],
|
||||||
) -> Poll<io::Result<usize>> {
|
) -> Poll<io::Result<usize>> {
|
||||||
self.acquire_lock();
|
let internals = unsafe { self.write_side.as_mut() };
|
||||||
let internals = unsafe { self.internals.as_mut() };
|
internals.acquire();
|
||||||
let stream_buffer = internals.buffer.get_mut();
|
let stream_buffer = internals.buffer.get_mut();
|
||||||
|
|
||||||
stream_buffer.extend_from_slice(buf);
|
stream_buffer.extend_from_slice(buf);
|
||||||
for waiter in internals.waiters.get_mut().drain(0..) {
|
for waiter in internals.waiters.get_mut().drain(0..) {
|
||||||
waiter.wake();
|
waiter.wake();
|
||||||
}
|
}
|
||||||
self.release_lock();
|
internals.release();
|
||||||
|
|
||||||
Poll::Ready(Ok(buf.len()))
|
Poll::Ready(Ok(buf.len()))
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user