use crate::network::address::HasLocalAddress; use crate::network::stream::Streamlike; use crate::network::SOCKSv5Address; use async_std::io; use async_std::io::{Read, Write}; use async_std::task::{Context, Poll, Waker}; use std::cell::UnsafeCell; use std::pin::Pin; use std::ptr::NonNull; use std::sync::atomic::{AtomicBool, Ordering}; #[derive(Clone)] pub struct TestingStream { address: SOCKSv5Address, port: u16, read_side: NonNull, write_side: NonNull, } unsafe impl Send for TestingStream {} unsafe impl Sync for TestingStream {} struct TestingStreamData { lock: AtomicBool, waiters: UnsafeCell>, buffer: UnsafeCell>, } unsafe impl Send for TestingStreamData {} unsafe impl Sync for TestingStreamData {} impl TestingStream { /// Generate a testing stream. Note that this is directional. So, if you want to /// talk to this stream, you should also generate an `invert()` and pass that to /// the other thread(s). pub fn new(address: SOCKSv5Address, port: u16) -> TestingStream { let read_side_data = TestingStreamData { lock: AtomicBool::new(false), waiters: UnsafeCell::new(Vec::new()), buffer: UnsafeCell::new(Vec::with_capacity(16 * 1024)), }; let write_side_data = TestingStreamData { lock: AtomicBool::new(false), waiters: UnsafeCell::new(Vec::new()), buffer: UnsafeCell::new(Vec::with_capacity(16 * 1024)), }; let boxed_rsd = Box::new(read_side_data); let boxed_wsd = Box::new(write_side_data); let raw_read_ptr = Box::leak(boxed_rsd); let raw_write_ptr = Box::leak(boxed_wsd); TestingStream { address, port, read_side: NonNull::new(raw_read_ptr).unwrap(), write_side: NonNull::new(raw_write_ptr).unwrap(), } } /// Get the flip side of this stream; reads from the inverted side will catch the writes /// of the original, etc. pub fn invert(&self) -> TestingStream { TestingStream { address: self.address.clone(), port: self.port, read_side: self.write_side, write_side: self.read_side, } } } impl TestingStreamData { fn acquire(&mut self) { loop { match self .lock .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst) { Err(_) => continue, Ok(_) => return, } } } fn release(&mut self) { self.lock.store(false, Ordering::SeqCst); } } impl HasLocalAddress for TestingStream { fn local_addr(&self) -> (SOCKSv5Address, u16) { (self.address.clone(), self.port) } } impl Read for TestingStream { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll> { // so, we're going to spin here, which is less than ideal but should work fine // in practice. we'll obviously need to be very careful to ensure that we keep // the stuff internal to this spin really short. let internals = unsafe { self.read_side.as_mut() }; internals.acquire(); let stream_buffer = internals.buffer.get_mut(); let amount_available = stream_buffer.len(); if amount_available == 0 { let waker = cx.waker().clone(); internals.waiters.get_mut().push(waker); internals.release(); return Poll::Pending; } let amt_written = if buf.len() >= amount_available { (&mut buf[0..amount_available]).copy_from_slice(stream_buffer); stream_buffer.clear(); amount_available } else { let amt_to_copy = buf.len(); buf.copy_from_slice(&stream_buffer[0..amt_to_copy]); stream_buffer.copy_within(amt_to_copy.., 0); let amt_left = amount_available - amt_to_copy; stream_buffer.resize(amt_left, 0); amt_to_copy }; internals.release(); Poll::Ready(Ok(amt_written)) } } impl Write for TestingStream { fn poll_write( mut self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { let internals = unsafe { self.write_side.as_mut() }; internals.acquire(); let stream_buffer = internals.buffer.get_mut(); stream_buffer.extend_from_slice(buf); for waiter in internals.waiters.get_mut().drain(0..) { waiter.wake(); } internals.release(); Poll::Ready(Ok(buf.len())) } fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) // FIXME: Might consider having this wait until the buffer is empty } fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) // FIXME: Might consider putting in some open/closed logic here } } impl Streamlike for TestingStream {}