From 67b2acab2557a935bf94f41609cd356d1ff72d04 Mon Sep 17 00:00:00 2001 From: Adam Wick Date: Sat, 6 Nov 2021 20:55:58 -0700 Subject: [PATCH] Add support for knowing when the write end of a testing stream drops the reference, and then triggering errors on the read side. --- src/network/testing/stream.rs | 34 ++++++++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/src/network/testing/stream.rs b/src/network/testing/stream.rs index 568ba0e..2e66423 100644 --- a/src/network/testing/stream.rs +++ b/src/network/testing/stream.rs @@ -22,6 +22,7 @@ unsafe impl Sync for TestingStream {} struct TestingStreamData { lock: AtomicBool, + writer_dead: AtomicBool, waiters: UnsafeCell>, buffer: UnsafeCell>, } @@ -36,12 +37,14 @@ impl TestingStream { pub fn new(address: SOCKSv5Address, port: u16) -> TestingStream { let read_side_data = TestingStreamData { lock: AtomicBool::new(false), + writer_dead: AtomicBool::new(false), waiters: UnsafeCell::new(Vec::new()), buffer: UnsafeCell::new(Vec::with_capacity(16 * 1024)), }; let write_side_data = TestingStreamData { lock: AtomicBool::new(false), + writer_dead: AtomicBool::new(false), waiters: UnsafeCell::new(Vec::new()), buffer: UnsafeCell::new(Vec::with_capacity(16 * 1024)), }; @@ -107,14 +110,25 @@ impl Read for TestingStream { let internals = unsafe { self.read_side.as_mut() }; internals.acquire(); + let stream_buffer = internals.buffer.get_mut(); let amount_available = stream_buffer.len(); if amount_available == 0 { - let waker = cx.waker().clone(); - internals.waiters.get_mut().push(waker); - internals.release(); - return Poll::Pending; + // we wait to do this check until we've determined the buffer is empty, + // so that we make sure to drain any residual stuff in there. + if internals.writer_dead.load(Ordering::SeqCst) { + internals.release(); + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::ConnectionReset, + "Writer closed the socket.", + ))); + } else { + let waker = cx.waker().clone(); + internals.waiters.get_mut().push(waker); + internals.release(); + return Poll::Pending; + } } let amt_written = if buf.len() >= amount_available { @@ -164,3 +178,15 @@ impl Write for TestingStream { } impl Streamlike for TestingStream {} + +impl Drop for TestingStream { + fn drop(&mut self) { + let internals = unsafe { self.write_side.as_mut() }; + internals.writer_dead.store(true, Ordering::SeqCst); + internals.acquire(); + for waiter in internals.waiters.get_mut().drain(0..) { + waiter.wake(); + } + internals.release(); + } +}