Switch to basic tokio; will expand later to arbitrary backends.

This commit is contained in:
2022-05-14 17:59:28 -07:00
parent d284f60d67
commit c8279cfc5f
29 changed files with 1472 additions and 2671 deletions

View File

@@ -1,16 +1,12 @@
#[cfg(test)]
use crate::errors::AuthenticationDeserializationError;
use crate::errors::{DeserializationError, SerializationError};
use crate::messages::AuthenticationMethod;
use crate::standard_roundtrip;
#[cfg(test)]
use async_std::task;
#[cfg(test)]
use futures::io::Cursor;
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use proptest::proptest;
use crate::messages::authentication_method::{
AuthenticationMethod, AuthenticationMethodReadError, AuthenticationMethodWriteError,
};
#[cfg(test)]
use proptest_derive::Arbitrary;
#[cfg(test)]
use std::io::Cursor;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
#[derive(Clone, Debug, Eq, PartialEq)]
#[cfg_attr(test, derive(Arbitrary))]
@@ -18,6 +14,36 @@ pub struct ServerChoice {
pub chosen_method: AuthenticationMethod,
}
#[derive(Clone, Debug, Error, PartialEq)]
pub enum ServerChoiceReadError {
#[error(transparent)]
AuthMethodError(#[from] AuthenticationMethodReadError),
#[error("Error in underlying buffer: {0}")]
ReadError(String),
#[error("Invalid version; expected 5, got {0}")]
InvalidVersion(u8),
}
impl From<std::io::Error> for ServerChoiceReadError {
fn from(x: std::io::Error) -> ServerChoiceReadError {
ServerChoiceReadError::ReadError(format!("{}", x))
}
}
#[derive(Clone, Debug, Error, PartialEq)]
pub enum ServerChoiceWriteError {
#[error(transparent)]
AuthMethodError(#[from] AuthenticationMethodWriteError),
#[error("Error in underlying buffer: {0}")]
WriteError(String),
}
impl From<std::io::Error> for ServerChoiceWriteError {
fn from(x: std::io::Error) -> ServerChoiceWriteError {
ServerChoiceWriteError::WriteError(format!("{}", x))
}
}
impl ServerChoice {
pub fn rejection() -> ServerChoice {
ServerChoice {
@@ -33,15 +59,11 @@ impl ServerChoice {
pub async fn read<R: AsyncRead + Send + Unpin>(
r: &mut R,
) -> Result<Self, DeserializationError> {
let mut buffer = [0; 1];
) -> Result<Self, ServerChoiceReadError> {
let version = r.read_u8().await?;
if r.read(&mut buffer).await? == 0 {
return Err(DeserializationError::NotEnoughData);
}
if buffer[0] != 5 {
return Err(DeserializationError::InvalidVersion(5, buffer[0]));
if version != 5 {
return Err(ServerChoiceReadError::InvalidVersion(version));
}
let chosen_method = AuthenticationMethod::read(r).await?;
@@ -52,39 +74,32 @@ impl ServerChoice {
pub async fn write<W: AsyncWrite + Send + Unpin>(
&self,
w: &mut W,
) -> Result<(), SerializationError> {
w.write_all(&[5]).await?;
self.chosen_method.write(w).await
) -> Result<(), ServerChoiceWriteError> {
w.write_u8(5).await?;
self.chosen_method.write(w).await?;
Ok(())
}
}
standard_roundtrip!(server_choice_roundtrips, ServerChoice);
crate::standard_roundtrip!(server_choice_roundtrips, ServerChoice);
#[test]
fn check_short_reads() {
#[tokio::test]
async fn check_short_reads() {
let empty = vec![];
let mut cursor = Cursor::new(empty);
let ys = ServerChoice::read(&mut cursor);
assert_eq!(Err(DeserializationError::NotEnoughData), task::block_on(ys));
let ys = ServerChoice::read(&mut cursor).await;
assert!(matches!(ys, Err(ServerChoiceReadError::ReadError(_))));
let bad_len = vec![5];
let mut cursor = Cursor::new(bad_len);
let ys = ServerChoice::read(&mut cursor);
assert_eq!(
Err(DeserializationError::AuthenticationMethodError(
AuthenticationDeserializationError::NoDataFound
)),
task::block_on(ys)
);
let ys = ServerChoice::read(&mut cursor).await;
assert!(matches!(ys, Err(ServerChoiceReadError::AuthMethodError(_))));
}
#[test]
fn check_bad_version() {
#[tokio::test]
async fn check_bad_version() {
let no_len = vec![9, 1];
let mut cursor = Cursor::new(no_len);
let ys = ServerChoice::read(&mut cursor);
assert_eq!(
Err(DeserializationError::InvalidVersion(5, 9)),
task::block_on(ys)
);
let ys = ServerChoice::read(&mut cursor).await;
assert_eq!(Err(ServerChoiceReadError::InvalidVersion(9)), ys);
}