Add support for knowing when the write end of a testing stream drops the reference, and then triggering errors on the read side.

This commit is contained in:
2021-11-06 20:55:58 -07:00
parent 58c04adeb7
commit 67b2acab25

View File

@@ -22,6 +22,7 @@ unsafe impl Sync for TestingStream {}
struct TestingStreamData { struct TestingStreamData {
lock: AtomicBool, lock: AtomicBool,
writer_dead: AtomicBool,
waiters: UnsafeCell<Vec<Waker>>, waiters: UnsafeCell<Vec<Waker>>,
buffer: UnsafeCell<Vec<u8>>, buffer: UnsafeCell<Vec<u8>>,
} }
@@ -36,12 +37,14 @@ impl TestingStream {
pub fn new(address: SOCKSv5Address, port: u16) -> TestingStream { pub fn new(address: SOCKSv5Address, port: u16) -> TestingStream {
let read_side_data = TestingStreamData { let read_side_data = TestingStreamData {
lock: AtomicBool::new(false), lock: AtomicBool::new(false),
writer_dead: AtomicBool::new(false),
waiters: UnsafeCell::new(Vec::new()), waiters: UnsafeCell::new(Vec::new()),
buffer: UnsafeCell::new(Vec::with_capacity(16 * 1024)), buffer: UnsafeCell::new(Vec::with_capacity(16 * 1024)),
}; };
let write_side_data = TestingStreamData { let write_side_data = TestingStreamData {
lock: AtomicBool::new(false), lock: AtomicBool::new(false),
writer_dead: AtomicBool::new(false),
waiters: UnsafeCell::new(Vec::new()), waiters: UnsafeCell::new(Vec::new()),
buffer: UnsafeCell::new(Vec::with_capacity(16 * 1024)), buffer: UnsafeCell::new(Vec::with_capacity(16 * 1024)),
}; };
@@ -107,14 +110,25 @@ impl Read for TestingStream {
let internals = unsafe { self.read_side.as_mut() }; let internals = unsafe { self.read_side.as_mut() };
internals.acquire(); internals.acquire();
let stream_buffer = internals.buffer.get_mut(); let stream_buffer = internals.buffer.get_mut();
let amount_available = stream_buffer.len(); let amount_available = stream_buffer.len();
if amount_available == 0 { if amount_available == 0 {
let waker = cx.waker().clone(); // we wait to do this check until we've determined the buffer is empty,
internals.waiters.get_mut().push(waker); // so that we make sure to drain any residual stuff in there.
internals.release(); if internals.writer_dead.load(Ordering::SeqCst) {
return Poll::Pending; internals.release();
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::ConnectionReset,
"Writer closed the socket.",
)));
} else {
let waker = cx.waker().clone();
internals.waiters.get_mut().push(waker);
internals.release();
return Poll::Pending;
}
} }
let amt_written = if buf.len() >= amount_available { let amt_written = if buf.len() >= amount_available {
@@ -164,3 +178,15 @@ impl Write for TestingStream {
} }
impl Streamlike for TestingStream {} impl Streamlike for TestingStream {}
impl Drop for TestingStream {
fn drop(&mut self) {
let internals = unsafe { self.write_side.as_mut() };
internals.writer_dead.store(true, Ordering::SeqCst);
internals.acquire();
for waiter in internals.waiters.get_mut().drain(0..) {
waiter.wake();
}
internals.release();
}
}