Make write() consume the objects.

This commit is contained in:
2022-05-14 20:37:18 -07:00
parent c8279cfc5f
commit 277125e1a0
9 changed files with 15 additions and 14 deletions

View File

@@ -41,13 +41,14 @@ macro_rules! standard_roundtrip {
tokio::runtime::Runtime::new().unwrap().block_on(async { tokio::runtime::Runtime::new().unwrap().block_on(async {
use std::io::Cursor; use std::io::Cursor;
let originals = xs.clone();
let buffer = vec![]; let buffer = vec![];
let mut write_cursor = Cursor::new(buffer); let mut write_cursor = Cursor::new(buffer);
xs.write(&mut write_cursor).await.unwrap(); xs.write(&mut write_cursor).await.unwrap();
let serialized_form = write_cursor.into_inner(); let serialized_form = write_cursor.into_inner();
let mut read_cursor = Cursor::new(serialized_form); let mut read_cursor = Cursor::new(serialized_form);
let ys = <$t>::read(&mut read_cursor); let ys = <$t>::read(&mut read_cursor);
assert_eq!(xs, ys.await.unwrap()); assert_eq!(originals, ys.await.unwrap());
}) })
} }
} }

View File

@@ -118,7 +118,7 @@ impl AuthenticationMethod {
} }
pub async fn write<W: AsyncWrite + Send + Unpin>( pub async fn write<W: AsyncWrite + Send + Unpin>(
&self, self,
w: &mut W, w: &mut W,
) -> Result<(), AuthenticationMethodWriteError> { ) -> Result<(), AuthenticationMethodWriteError> {
let value = match self { let value = match self {
@@ -131,9 +131,9 @@ impl AuthenticationMethod {
AuthenticationMethod::NDS => 7, AuthenticationMethod::NDS => 7,
AuthenticationMethod::MultiAuthenticationFramework => 8, AuthenticationMethod::MultiAuthenticationFramework => 8,
AuthenticationMethod::JSONPropertyBlock => 9, AuthenticationMethod::JSONPropertyBlock => 9,
AuthenticationMethod::PrivateMethod(pm) if (0x80..=0xfe).contains(pm) => *pm, AuthenticationMethod::PrivateMethod(pm) if (0x80..=0xfe).contains(&pm) => pm,
AuthenticationMethod::PrivateMethod(pm) => { AuthenticationMethod::PrivateMethod(pm) => {
return Err(AuthenticationMethodWriteError::InvalidAuthMethod(*pm)) return Err(AuthenticationMethodWriteError::InvalidAuthMethod(pm))
} }
AuthenticationMethod::NoAcceptableMethods => 0xff, AuthenticationMethod::NoAcceptableMethods => 0xff,
}; };

View File

@@ -55,7 +55,7 @@ impl ClientConnectionCommand {
} }
pub async fn write<W: AsyncWrite + Send + Unpin>( pub async fn write<W: AsyncWrite + Send + Unpin>(
&self, self,
w: &mut W, w: &mut W,
) -> Result<(), std::io::Error> { ) -> Result<(), std::io::Error> {
match self { match self {
@@ -125,7 +125,7 @@ impl ClientConnectionRequest {
} }
pub async fn write<W: AsyncWrite + Send + Unpin>( pub async fn write<W: AsyncWrite + Send + Unpin>(
&self, self,
w: &mut W, w: &mut W,
) -> Result<(), ClientConnectionCommandWriteError> { ) -> Result<(), ClientConnectionCommandWriteError> {
w.write_u8(5).await?; w.write_u8(5).await?;

View File

@@ -72,7 +72,7 @@ impl ClientGreeting {
} }
pub async fn write<W: AsyncWrite + Send + Unpin>( pub async fn write<W: AsyncWrite + Send + Unpin>(
&self, mut self,
w: &mut W, w: &mut W,
) -> Result<(), ClientGreetingWriteError> { ) -> Result<(), ClientGreetingWriteError> {
if self.acceptable_methods.len() > 255 { if self.acceptable_methods.len() > 255 {
@@ -85,7 +85,7 @@ impl ClientGreeting {
buffer.push(5); buffer.push(5);
buffer.push(self.acceptable_methods.len() as u8); buffer.push(self.acceptable_methods.len() as u8);
w.write_all(&buffer).await?; w.write_all(&buffer).await?;
for authmeth in self.acceptable_methods.iter() { for authmeth in self.acceptable_methods.drain(..) {
authmeth.write(w).await?; authmeth.write(w).await?;
} }
Ok(()) Ok(())

View File

@@ -81,7 +81,7 @@ impl ClientUsernamePassword {
} }
pub async fn write<W: AsyncWrite + Send + Unpin>( pub async fn write<W: AsyncWrite + Send + Unpin>(
&self, self,
w: &mut W, w: &mut W,
) -> Result<(), ClientUsernamePasswordWriteError> { ) -> Result<(), ClientUsernamePasswordWriteError> {
w.write_u8(1).await?; w.write_u8(1).await?;

View File

@@ -59,7 +59,7 @@ impl ServerAuthResponse {
} }
pub async fn write<W: AsyncWrite + Send + Unpin>( pub async fn write<W: AsyncWrite + Send + Unpin>(
&self, self,
w: &mut W, w: &mut W,
) -> Result<(), ServerAuthResponseWriteError> { ) -> Result<(), ServerAuthResponseWriteError> {
w.write_all(&[1]).await?; w.write_all(&[1]).await?;

View File

@@ -72,7 +72,7 @@ impl ServerChoice {
} }
pub async fn write<W: AsyncWrite + Send + Unpin>( pub async fn write<W: AsyncWrite + Send + Unpin>(
&self, self,
w: &mut W, w: &mut W,
) -> Result<(), ServerChoiceWriteError> { ) -> Result<(), ServerChoiceWriteError> {
w.write_u8(5).await?; w.write_u8(5).await?;

View File

@@ -111,7 +111,7 @@ impl ServerResponse {
} }
pub async fn write<W: AsyncWrite + Send + Unpin>( pub async fn write<W: AsyncWrite + Send + Unpin>(
&self, self,
w: &mut W, w: &mut W,
) -> Result<(), ServerResponseWriteError> { ) -> Result<(), ServerResponseWriteError> {
w.write_u8(5).await?; w.write_u8(5).await?;

View File

@@ -5,7 +5,7 @@ use std::string::FromUtf8Error;
use thiserror::Error; use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
#[derive(Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
pub struct SOCKSv5String(String); pub struct SOCKSv5String(String);
#[cfg(test)] #[cfg(test)]
@@ -75,7 +75,7 @@ impl SOCKSv5String {
} }
pub async fn write<W: AsyncWrite + Unpin>( pub async fn write<W: AsyncWrite + Unpin>(
&self, self,
w: &mut W, w: &mut W,
) -> Result<(), SOCKSv5StringWriteError> { ) -> Result<(), SOCKSv5StringWriteError> {
let bytestring = self.0.as_bytes(); let bytestring = self.0.as_bytes();