Add support for knowing when the write end of a testing stream drops the reference, and then triggering errors on the read side.
This commit is contained in:
@@ -22,6 +22,7 @@ unsafe impl Sync for TestingStream {}
|
||||
|
||||
struct TestingStreamData {
|
||||
lock: AtomicBool,
|
||||
writer_dead: AtomicBool,
|
||||
waiters: UnsafeCell<Vec<Waker>>,
|
||||
buffer: UnsafeCell<Vec<u8>>,
|
||||
}
|
||||
@@ -36,12 +37,14 @@ impl TestingStream {
|
||||
pub fn new(address: SOCKSv5Address, port: u16) -> TestingStream {
|
||||
let read_side_data = TestingStreamData {
|
||||
lock: AtomicBool::new(false),
|
||||
writer_dead: AtomicBool::new(false),
|
||||
waiters: UnsafeCell::new(Vec::new()),
|
||||
buffer: UnsafeCell::new(Vec::with_capacity(16 * 1024)),
|
||||
};
|
||||
|
||||
let write_side_data = TestingStreamData {
|
||||
lock: AtomicBool::new(false),
|
||||
writer_dead: AtomicBool::new(false),
|
||||
waiters: UnsafeCell::new(Vec::new()),
|
||||
buffer: UnsafeCell::new(Vec::with_capacity(16 * 1024)),
|
||||
};
|
||||
@@ -107,14 +110,25 @@ impl Read for TestingStream {
|
||||
let internals = unsafe { self.read_side.as_mut() };
|
||||
|
||||
internals.acquire();
|
||||
|
||||
let stream_buffer = internals.buffer.get_mut();
|
||||
let amount_available = stream_buffer.len();
|
||||
|
||||
if amount_available == 0 {
|
||||
let waker = cx.waker().clone();
|
||||
internals.waiters.get_mut().push(waker);
|
||||
internals.release();
|
||||
return Poll::Pending;
|
||||
// we wait to do this check until we've determined the buffer is empty,
|
||||
// so that we make sure to drain any residual stuff in there.
|
||||
if internals.writer_dead.load(Ordering::SeqCst) {
|
||||
internals.release();
|
||||
return Poll::Ready(Err(io::Error::new(
|
||||
io::ErrorKind::ConnectionReset,
|
||||
"Writer closed the socket.",
|
||||
)));
|
||||
} else {
|
||||
let waker = cx.waker().clone();
|
||||
internals.waiters.get_mut().push(waker);
|
||||
internals.release();
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
|
||||
let amt_written = if buf.len() >= amount_available {
|
||||
@@ -164,3 +178,15 @@ impl Write for TestingStream {
|
||||
}
|
||||
|
||||
impl Streamlike for TestingStream {}
|
||||
|
||||
impl Drop for TestingStream {
|
||||
fn drop(&mut self) {
|
||||
let internals = unsafe { self.write_side.as_mut() };
|
||||
internals.writer_dead.store(true, Ordering::SeqCst);
|
||||
internals.acquire();
|
||||
for waiter in internals.waiters.get_mut().drain(0..) {
|
||||
waiter.wake();
|
||||
}
|
||||
internals.release();
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user