Whoops! TCP streams are dual-buffered!

This adjusts the way that TestingStream is implemented to allow for
two, separate buffers for each of the two directions. In the prior
implementation, if you called `write` and then `read`, you would
`read` the data you just wrote. Which is not what you want; you want
to block until you get data back from the other side.
This commit is contained in:
2021-10-09 18:29:33 -07:00
parent 748dc33a36
commit 3364031c18
2 changed files with 47 additions and 26 deletions

View File

@@ -104,7 +104,7 @@ impl Networklike for TestingStack {
None => Err(TestStackError::NoTCPHostFound(target, port)),
Some(result) => {
let stream = TestingStream::new(target, port);
let retval = stream.clone();
let retval = stream.invert();
match result.send(stream).await {
Ok(()) => Ok(GenericStream::new(retval)),
Err(_) => Err(TestStackError::FailureToSend),
@@ -192,11 +192,8 @@ impl Listenerlike for TestListener {
}
#[test]
fn check_sanity() {
fn check_udp_sanity() {
task::block_on(async {
// Technically, this is UDP, and UDP is lossy. We're going to assume we're not
// going to get any dropped data along here ... which is a very questionable
// assumption, morally speaking, but probably fine for most purposes.
let mut network = TestingStack::new();
let receiver = network
.bind("localhost", 0)
@@ -223,7 +220,10 @@ fn check_sanity() {
assert_eq!(p, sender_port);
assert_eq!(recvbuffer, buffer);
});
}
#[test]
fn check_basic_tcp() {
task::block_on(async {
let mut network = TestingStack::new();

View File

@@ -13,7 +13,8 @@ use std::sync::atomic::{AtomicBool, Ordering};
pub struct TestingStream {
address: SOCKSv5Address,
port: u16,
internals: NonNull<TestingStreamData>,
read_side: NonNull<TestingStreamData>,
write_side: NonNull<TestingStreamData>,
}
unsafe impl Send for TestingStream {}
@@ -29,28 +30,51 @@ unsafe impl Send for TestingStreamData {}
unsafe impl Sync for TestingStreamData {}
impl TestingStream {
/// Generate a testing stream. Note that this is directional. So, if you want to
/// talk to this stream, you should also generate an `invert()` and pass that to
/// the other thread(s).
pub fn new(address: SOCKSv5Address, port: u16) -> TestingStream {
let tsd = TestingStreamData {
let read_side_data = TestingStreamData {
lock: AtomicBool::new(false),
waiters: UnsafeCell::new(Vec::new()),
buffer: UnsafeCell::new(Vec::with_capacity(16 * 1024)),
};
let boxed_tsd = Box::new(tsd);
let raw_ptr = Box::leak(boxed_tsd);
let write_side_data = TestingStreamData {
lock: AtomicBool::new(false),
waiters: UnsafeCell::new(Vec::new()),
buffer: UnsafeCell::new(Vec::with_capacity(16 * 1024)),
};
let boxed_rsd = Box::new(read_side_data);
let boxed_wsd = Box::new(write_side_data);
let raw_read_ptr = Box::leak(boxed_rsd);
let raw_write_ptr = Box::leak(boxed_wsd);
TestingStream {
address,
port,
internals: NonNull::new(raw_ptr).unwrap(),
read_side: NonNull::new(raw_read_ptr).unwrap(),
write_side: NonNull::new(raw_write_ptr).unwrap(),
}
}
pub fn acquire_lock(&mut self) {
/// Get the flip side of this stream; reads from the inverted side will catch the writes
/// of the original, etc.
pub fn invert(&self) -> TestingStream {
TestingStream {
address: self.address.clone(),
port: self.port,
read_side: self.write_side.clone(),
write_side: self.read_side.clone(),
}
}
}
impl TestingStreamData {
fn acquire(&mut self) {
loop {
let internals = unsafe { self.internals.as_mut() };
match internals
match self
.lock
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
{
@@ -60,9 +84,8 @@ impl TestingStream {
}
}
pub fn release_lock(&mut self) {
let internals = unsafe { self.internals.as_mut() };
internals.lock.store(false, Ordering::SeqCst);
fn release(&mut self) {
self.lock.store(false, Ordering::SeqCst);
}
}
@@ -81,17 +104,16 @@ impl Read for TestingStream {
// so, we're going to spin here, which is less than ideal but should work fine
// in practice. we'll obviously need to be very careful to ensure that we keep
// the stuff internal to this spin really short.
self.acquire_lock();
let internals = unsafe { self.read_side.as_mut() };
let internals = unsafe { self.internals.as_mut() };
internals.acquire();
let stream_buffer = internals.buffer.get_mut();
let amount_available = stream_buffer.len();
if amount_available == 0 {
let waker = cx.waker().clone();
internals.waiters.get_mut().push(waker);
self.release_lock();
internals.release();
return Poll::Pending;
}
@@ -108,7 +130,7 @@ impl Read for TestingStream {
amt_to_copy
};
self.release_lock();
internals.release();
Poll::Ready(Ok(amt_written))
}
@@ -120,15 +142,14 @@ impl Write for TestingStream {
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.acquire_lock();
let internals = unsafe { self.internals.as_mut() };
let internals = unsafe { self.write_side.as_mut() };
internals.acquire();
let stream_buffer = internals.buffer.get_mut();
stream_buffer.extend_from_slice(buf);
for waiter in internals.waiters.get_mut().drain(0..) {
waiter.wake();
}
self.release_lock();
internals.release();
Poll::Ready(Ok(buf.len()))
}