From 3364031c1870992123cd0b85698fbecebd2cc3a4 Mon Sep 17 00:00:00 2001 From: Adam Wick Date: Sat, 9 Oct 2021 18:29:33 -0700 Subject: [PATCH] Whoops! TCP streams are dual-buffered! This adjusts the way that TestingStream is implemented to allow for two, separate buffers for each of the two directions. In the prior implementation, if you called `write` and then `read`, you would `read` the data you just wrote. Which is not what you want; you want to block until you get data back from the other side. --- src/network/testing.rs | 10 +++--- src/network/testing/stream.rs | 63 +++++++++++++++++++++++------------ 2 files changed, 47 insertions(+), 26 deletions(-) diff --git a/src/network/testing.rs b/src/network/testing.rs index d361fe2..3f5ad1f 100644 --- a/src/network/testing.rs +++ b/src/network/testing.rs @@ -104,7 +104,7 @@ impl Networklike for TestingStack { None => Err(TestStackError::NoTCPHostFound(target, port)), Some(result) => { let stream = TestingStream::new(target, port); - let retval = stream.clone(); + let retval = stream.invert(); match result.send(stream).await { Ok(()) => Ok(GenericStream::new(retval)), Err(_) => Err(TestStackError::FailureToSend), @@ -192,11 +192,8 @@ impl Listenerlike for TestListener { } #[test] -fn check_sanity() { +fn check_udp_sanity() { task::block_on(async { - // Technically, this is UDP, and UDP is lossy. We're going to assume we're not - // going to get any dropped data along here ... which is a very questionable - // assumption, morally speaking, but probably fine for most purposes. let mut network = TestingStack::new(); let receiver = network .bind("localhost", 0) @@ -223,7 +220,10 @@ fn check_sanity() { assert_eq!(p, sender_port); assert_eq!(recvbuffer, buffer); }); +} +#[test] +fn check_basic_tcp() { task::block_on(async { let mut network = TestingStack::new(); diff --git a/src/network/testing/stream.rs b/src/network/testing/stream.rs index b926d44..b793ccc 100644 --- a/src/network/testing/stream.rs +++ b/src/network/testing/stream.rs @@ -13,7 +13,8 @@ use std::sync::atomic::{AtomicBool, Ordering}; pub struct TestingStream { address: SOCKSv5Address, port: u16, - internals: NonNull, + read_side: NonNull, + write_side: NonNull, } unsafe impl Send for TestingStream {} @@ -29,28 +30,51 @@ unsafe impl Send for TestingStreamData {} unsafe impl Sync for TestingStreamData {} impl TestingStream { + /// Generate a testing stream. Note that this is directional. So, if you want to + /// talk to this stream, you should also generate an `invert()` and pass that to + /// the other thread(s). pub fn new(address: SOCKSv5Address, port: u16) -> TestingStream { - let tsd = TestingStreamData { + let read_side_data = TestingStreamData { lock: AtomicBool::new(false), waiters: UnsafeCell::new(Vec::new()), buffer: UnsafeCell::new(Vec::with_capacity(16 * 1024)), }; - let boxed_tsd = Box::new(tsd); - let raw_ptr = Box::leak(boxed_tsd); + let write_side_data = TestingStreamData { + lock: AtomicBool::new(false), + waiters: UnsafeCell::new(Vec::new()), + buffer: UnsafeCell::new(Vec::with_capacity(16 * 1024)), + }; + + let boxed_rsd = Box::new(read_side_data); + let boxed_wsd = Box::new(write_side_data); + let raw_read_ptr = Box::leak(boxed_rsd); + let raw_write_ptr = Box::leak(boxed_wsd); TestingStream { address, port, - internals: NonNull::new(raw_ptr).unwrap(), + read_side: NonNull::new(raw_read_ptr).unwrap(), + write_side: NonNull::new(raw_write_ptr).unwrap(), } } - pub fn acquire_lock(&mut self) { - loop { - let internals = unsafe { self.internals.as_mut() }; + /// Get the flip side of this stream; reads from the inverted side will catch the writes + /// of the original, etc. + pub fn invert(&self) -> TestingStream { + TestingStream { + address: self.address.clone(), + port: self.port, + read_side: self.write_side.clone(), + write_side: self.read_side.clone(), + } + } +} - match internals +impl TestingStreamData { + fn acquire(&mut self) { + loop { + match self .lock .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst) { @@ -60,9 +84,8 @@ impl TestingStream { } } - pub fn release_lock(&mut self) { - let internals = unsafe { self.internals.as_mut() }; - internals.lock.store(false, Ordering::SeqCst); + fn release(&mut self) { + self.lock.store(false, Ordering::SeqCst); } } @@ -81,17 +104,16 @@ impl Read for TestingStream { // so, we're going to spin here, which is less than ideal but should work fine // in practice. we'll obviously need to be very careful to ensure that we keep // the stuff internal to this spin really short. - self.acquire_lock(); + let internals = unsafe { self.read_side.as_mut() }; - let internals = unsafe { self.internals.as_mut() }; + internals.acquire(); let stream_buffer = internals.buffer.get_mut(); - let amount_available = stream_buffer.len(); if amount_available == 0 { let waker = cx.waker().clone(); internals.waiters.get_mut().push(waker); - self.release_lock(); + internals.release(); return Poll::Pending; } @@ -108,7 +130,7 @@ impl Read for TestingStream { amt_to_copy }; - self.release_lock(); + internals.release(); Poll::Ready(Ok(amt_written)) } @@ -120,15 +142,14 @@ impl Write for TestingStream { _cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - self.acquire_lock(); - let internals = unsafe { self.internals.as_mut() }; + let internals = unsafe { self.write_side.as_mut() }; + internals.acquire(); let stream_buffer = internals.buffer.get_mut(); - stream_buffer.extend_from_slice(buf); for waiter in internals.waiters.get_mut().drain(0..) { waiter.wake(); } - self.release_lock(); + internals.release(); Poll::Ready(Ok(buf.len())) }