Whoops! TCP streams are dual-buffered!

This adjusts the way that TestingStream is implemented to allow for
two, separate buffers for each of the two directions. In the prior
implementation, if you called `write` and then `read`, you would
`read` the data you just wrote. Which is not what you want; you want
to block until you get data back from the other side.
This commit is contained in:
2021-10-09 18:29:33 -07:00
parent 748dc33a36
commit 3364031c18
2 changed files with 47 additions and 26 deletions

View File

@@ -104,7 +104,7 @@ impl Networklike for TestingStack {
None => Err(TestStackError::NoTCPHostFound(target, port)), None => Err(TestStackError::NoTCPHostFound(target, port)),
Some(result) => { Some(result) => {
let stream = TestingStream::new(target, port); let stream = TestingStream::new(target, port);
let retval = stream.clone(); let retval = stream.invert();
match result.send(stream).await { match result.send(stream).await {
Ok(()) => Ok(GenericStream::new(retval)), Ok(()) => Ok(GenericStream::new(retval)),
Err(_) => Err(TestStackError::FailureToSend), Err(_) => Err(TestStackError::FailureToSend),
@@ -192,11 +192,8 @@ impl Listenerlike for TestListener {
} }
#[test] #[test]
fn check_sanity() { fn check_udp_sanity() {
task::block_on(async { task::block_on(async {
// Technically, this is UDP, and UDP is lossy. We're going to assume we're not
// going to get any dropped data along here ... which is a very questionable
// assumption, morally speaking, but probably fine for most purposes.
let mut network = TestingStack::new(); let mut network = TestingStack::new();
let receiver = network let receiver = network
.bind("localhost", 0) .bind("localhost", 0)
@@ -223,7 +220,10 @@ fn check_sanity() {
assert_eq!(p, sender_port); assert_eq!(p, sender_port);
assert_eq!(recvbuffer, buffer); assert_eq!(recvbuffer, buffer);
}); });
}
#[test]
fn check_basic_tcp() {
task::block_on(async { task::block_on(async {
let mut network = TestingStack::new(); let mut network = TestingStack::new();

View File

@@ -13,7 +13,8 @@ use std::sync::atomic::{AtomicBool, Ordering};
pub struct TestingStream { pub struct TestingStream {
address: SOCKSv5Address, address: SOCKSv5Address,
port: u16, port: u16,
internals: NonNull<TestingStreamData>, read_side: NonNull<TestingStreamData>,
write_side: NonNull<TestingStreamData>,
} }
unsafe impl Send for TestingStream {} unsafe impl Send for TestingStream {}
@@ -29,28 +30,51 @@ unsafe impl Send for TestingStreamData {}
unsafe impl Sync for TestingStreamData {} unsafe impl Sync for TestingStreamData {}
impl TestingStream { impl TestingStream {
/// Generate a testing stream. Note that this is directional. So, if you want to
/// talk to this stream, you should also generate an `invert()` and pass that to
/// the other thread(s).
pub fn new(address: SOCKSv5Address, port: u16) -> TestingStream { pub fn new(address: SOCKSv5Address, port: u16) -> TestingStream {
let tsd = TestingStreamData { let read_side_data = TestingStreamData {
lock: AtomicBool::new(false), lock: AtomicBool::new(false),
waiters: UnsafeCell::new(Vec::new()), waiters: UnsafeCell::new(Vec::new()),
buffer: UnsafeCell::new(Vec::with_capacity(16 * 1024)), buffer: UnsafeCell::new(Vec::with_capacity(16 * 1024)),
}; };
let boxed_tsd = Box::new(tsd); let write_side_data = TestingStreamData {
let raw_ptr = Box::leak(boxed_tsd); lock: AtomicBool::new(false),
waiters: UnsafeCell::new(Vec::new()),
buffer: UnsafeCell::new(Vec::with_capacity(16 * 1024)),
};
let boxed_rsd = Box::new(read_side_data);
let boxed_wsd = Box::new(write_side_data);
let raw_read_ptr = Box::leak(boxed_rsd);
let raw_write_ptr = Box::leak(boxed_wsd);
TestingStream { TestingStream {
address, address,
port, port,
internals: NonNull::new(raw_ptr).unwrap(), read_side: NonNull::new(raw_read_ptr).unwrap(),
write_side: NonNull::new(raw_write_ptr).unwrap(),
} }
} }
pub fn acquire_lock(&mut self) { /// Get the flip side of this stream; reads from the inverted side will catch the writes
loop { /// of the original, etc.
let internals = unsafe { self.internals.as_mut() }; pub fn invert(&self) -> TestingStream {
TestingStream {
address: self.address.clone(),
port: self.port,
read_side: self.write_side.clone(),
write_side: self.read_side.clone(),
}
}
}
match internals impl TestingStreamData {
fn acquire(&mut self) {
loop {
match self
.lock .lock
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst) .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
{ {
@@ -60,9 +84,8 @@ impl TestingStream {
} }
} }
pub fn release_lock(&mut self) { fn release(&mut self) {
let internals = unsafe { self.internals.as_mut() }; self.lock.store(false, Ordering::SeqCst);
internals.lock.store(false, Ordering::SeqCst);
} }
} }
@@ -81,17 +104,16 @@ impl Read for TestingStream {
// so, we're going to spin here, which is less than ideal but should work fine // so, we're going to spin here, which is less than ideal but should work fine
// in practice. we'll obviously need to be very careful to ensure that we keep // in practice. we'll obviously need to be very careful to ensure that we keep
// the stuff internal to this spin really short. // the stuff internal to this spin really short.
self.acquire_lock(); let internals = unsafe { self.read_side.as_mut() };
let internals = unsafe { self.internals.as_mut() }; internals.acquire();
let stream_buffer = internals.buffer.get_mut(); let stream_buffer = internals.buffer.get_mut();
let amount_available = stream_buffer.len(); let amount_available = stream_buffer.len();
if amount_available == 0 { if amount_available == 0 {
let waker = cx.waker().clone(); let waker = cx.waker().clone();
internals.waiters.get_mut().push(waker); internals.waiters.get_mut().push(waker);
self.release_lock(); internals.release();
return Poll::Pending; return Poll::Pending;
} }
@@ -108,7 +130,7 @@ impl Read for TestingStream {
amt_to_copy amt_to_copy
}; };
self.release_lock(); internals.release();
Poll::Ready(Ok(amt_written)) Poll::Ready(Ok(amt_written))
} }
@@ -120,15 +142,14 @@ impl Write for TestingStream {
_cx: &mut Context<'_>, _cx: &mut Context<'_>,
buf: &[u8], buf: &[u8],
) -> Poll<io::Result<usize>> { ) -> Poll<io::Result<usize>> {
self.acquire_lock(); let internals = unsafe { self.write_side.as_mut() };
let internals = unsafe { self.internals.as_mut() }; internals.acquire();
let stream_buffer = internals.buffer.get_mut(); let stream_buffer = internals.buffer.get_mut();
stream_buffer.extend_from_slice(buf); stream_buffer.extend_from_slice(buf);
for waiter in internals.waiters.get_mut().drain(0..) { for waiter in internals.waiters.get_mut().drain(0..) {
waiter.wake(); waiter.wake();
} }
self.release_lock(); internals.release();
Poll::Ready(Ok(buf.len())) Poll::Ready(Ok(buf.len()))
} }