Workspacify
This commit is contained in:
18
ssh/Cargo.toml
Normal file
18
ssh/Cargo.toml
Normal file
@@ -0,0 +1,18 @@
|
||||
[package]
|
||||
name = "ssh"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
bytes = { workspace = true }
|
||||
configuration = { workspace = true }
|
||||
error-stack = { workspace = true }
|
||||
getrandom = { workspace = true }
|
||||
itertools = { workspace = true }
|
||||
num_enum = { workspace = true }
|
||||
proptest = { workspace = true }
|
||||
proptest-derive = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
rand_chacha = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
428
ssh/src/channel.rs
Normal file
428
ssh/src/channel.rs
Normal file
@@ -0,0 +1,428 @@
|
||||
use bytes::{BufMut, Bytes, BytesMut};
|
||||
use proptest::arbitrary::Arbitrary;
|
||||
use proptest::strategy::{BoxedStrategy, Strategy};
|
||||
use rand::{Rng, SeedableRng};
|
||||
use rand_chacha::ChaCha20Rng;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
const MAX_BUFFER_SIZE: usize = 64 * 1024;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub struct SshPacket {
|
||||
pub buffer: Bytes,
|
||||
}
|
||||
|
||||
pub struct SshChannel<Stream> {
|
||||
read_side: Mutex<ReadSide<Stream>>,
|
||||
write_side: Mutex<WriteSide<Stream>>,
|
||||
cipher_block_size: usize,
|
||||
mac_length: usize,
|
||||
channel_is_closed: bool,
|
||||
}
|
||||
|
||||
struct ReadSide<Stream> {
|
||||
stream: ReadHalf<Stream>,
|
||||
buffer: BytesMut,
|
||||
}
|
||||
|
||||
struct WriteSide<Stream> {
|
||||
stream: WriteHalf<Stream>,
|
||||
buffer: BytesMut,
|
||||
rng: ChaCha20Rng,
|
||||
}
|
||||
|
||||
impl<Stream> SshChannel<Stream>
|
||||
where
|
||||
Stream: AsyncReadExt + AsyncWriteExt,
|
||||
Stream: Send + Sync,
|
||||
Stream: Unpin,
|
||||
{
|
||||
/// Create a new SSH channel.
|
||||
///
|
||||
/// SshChannels are designed to make sure the various SSH channel read /
|
||||
/// write operations are cancel- and concurrency-safe. They take ownership
|
||||
/// of the underlying stream once established, where "established" means
|
||||
/// that we have read and written the initial SSH banners.
|
||||
pub fn new(stream: Stream) -> Result<SshChannel<Stream>, getrandom::Error> {
|
||||
let (read_half, write_half) = tokio::io::split(stream);
|
||||
|
||||
Ok(SshChannel {
|
||||
read_side: Mutex::new(ReadSide {
|
||||
stream: read_half,
|
||||
buffer: BytesMut::with_capacity(MAX_BUFFER_SIZE),
|
||||
}),
|
||||
write_side: Mutex::new(WriteSide {
|
||||
stream: write_half,
|
||||
buffer: BytesMut::with_capacity(MAX_BUFFER_SIZE),
|
||||
rng: ChaCha20Rng::try_from_os_rng()?,
|
||||
}),
|
||||
mac_length: 0,
|
||||
cipher_block_size: 8,
|
||||
channel_is_closed: false,
|
||||
})
|
||||
}
|
||||
|
||||
/// Read an SshPacket from the wire.
|
||||
///
|
||||
/// This function is cancel safe, and can be used in `select` (or similar)
|
||||
/// without problems. It is also safe to be used in a multitasking setting.
|
||||
/// Returns Ok(Some(...)) if a packet is found, or Ok(None) if we have
|
||||
/// successfully reached the end of stream. It will also return Ok(None)
|
||||
/// repeatedly after the stream is closed.
|
||||
pub async fn read(&self) -> Result<Option<SshPacket>, std::io::Error> {
|
||||
if self.channel_is_closed {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let mut reader = self.read_side.lock().await;
|
||||
let mut local_buffer = vec![0; 4096];
|
||||
|
||||
// First, let's try to at least get a size in there.
|
||||
while reader.buffer.len() < 5 {
|
||||
let amt_read = reader.stream.read(&mut local_buffer).await?;
|
||||
reader.buffer.extend_from_slice(&local_buffer[0..amt_read]);
|
||||
}
|
||||
|
||||
let packet_size = ((reader.buffer[0] as usize) << 24)
|
||||
| ((reader.buffer[1] as usize) << 16)
|
||||
| ((reader.buffer[2] as usize) << 8)
|
||||
| reader.buffer[3] as usize;
|
||||
let padding_size = reader.buffer[4] as usize;
|
||||
let total_size = 4 + 1 + packet_size + padding_size + self.mac_length;
|
||||
tracing::trace!(
|
||||
packet_size,
|
||||
padding_size,
|
||||
total_size,
|
||||
"Initial packet information determined"
|
||||
);
|
||||
|
||||
// Now we need to make sure that the buffer contains at least that
|
||||
// many bytes. We do this transfer -- from the wire to an internal
|
||||
// buffer -- to ensure cancel safety. If, at any point, this computation
|
||||
// is cancelled, the lock will be released and the buffer will be in
|
||||
// a reasonable place. A subsequent call should be able to pick up
|
||||
// wherever we left off.
|
||||
while reader.buffer.len() < total_size {
|
||||
let amt_read = reader.stream.read(&mut local_buffer).await?;
|
||||
reader.buffer.extend_from_slice(&local_buffer[0..amt_read]);
|
||||
}
|
||||
|
||||
let mut new_packet = reader.buffer.split_to(total_size);
|
||||
|
||||
let _header = new_packet.split_to(5);
|
||||
let payload = new_packet.split_to(total_size - padding_size - 5);
|
||||
let _mac = new_packet.split_off(padding_size);
|
||||
|
||||
Ok(Some(SshPacket {
|
||||
buffer: payload.freeze(),
|
||||
}))
|
||||
}
|
||||
|
||||
fn encode(&self, rng: &mut ChaCha20Rng, packet: SshPacket) -> Option<Bytes> {
|
||||
let mut encoded_packet = BytesMut::new();
|
||||
|
||||
// Arbitrary-length padding, such that the total length of
|
||||
// (packet_length || padding_length || payload || random padding)
|
||||
// is a multiple of the cipher block size or 8, whichever is
|
||||
// larger. There MUST be at least four bytes of padding. The
|
||||
// padding SHOULD consist of random bytes. The maximum amount of
|
||||
// padding is 255 bytes.
|
||||
let paddingless_length = 4 + 1 + packet.buffer.len();
|
||||
// the padding we need to get to an even multiple of the cipher
|
||||
// block size, naturally jumping to the cipher block size if we're
|
||||
// already aligned. (this is just easier, and since we can't have
|
||||
// 0 as the final padding, seems reasonable to do.)
|
||||
let mut rounded_padding =
|
||||
self.cipher_block_size - (paddingless_length % self.cipher_block_size);
|
||||
// now we enforce the must be greater than or equal to 4 rule
|
||||
if rounded_padding < 4 {
|
||||
rounded_padding += self.cipher_block_size;
|
||||
}
|
||||
// if this ends up being > 256, then we've run into something terrible
|
||||
if rounded_padding > (u8::MAX as usize) {
|
||||
tracing::error!(
|
||||
payload_length = packet.buffer.len(),
|
||||
cipher_block_size = ?self.cipher_block_size,
|
||||
computed_padding = ?rounded_padding,
|
||||
"generated incoherent padding value in write"
|
||||
);
|
||||
return None;
|
||||
}
|
||||
|
||||
encoded_packet.put_u32(packet.buffer.len() as u32);
|
||||
encoded_packet.put_u8(rounded_padding as u8);
|
||||
encoded_packet.put(packet.buffer);
|
||||
|
||||
for _ in 0..rounded_padding {
|
||||
encoded_packet.put_u8(rng.random());
|
||||
}
|
||||
|
||||
Some(encoded_packet.freeze())
|
||||
}
|
||||
|
||||
/// Write an SshPacket to the wire.
|
||||
///
|
||||
/// This function is cancel safe, and can be used in `select` (or similar).
|
||||
/// By cancel safe, we mean that one of the following outcomes is guaranteed
|
||||
/// to occur if the operation is cancelled:
|
||||
///
|
||||
/// 1. The whole packet is written to the channel.
|
||||
/// 2. No part of the packet is written to the channel.
|
||||
/// 3. The channel is dead, and no further data can be written to it.
|
||||
///
|
||||
/// Note that this means that you cannot assume that the packet is not
|
||||
/// written if the operation is cancelled, it just ensures that you will
|
||||
/// not be in a place in which only part of the packet has been written.
|
||||
pub async fn write(&self, packet: SshPacket) -> Result<(), std::io::Error> {
|
||||
let mut final_data = { self.encode(&mut self.write_side.lock().await.rng, packet) };
|
||||
|
||||
loop {
|
||||
if self.channel_is_closed {
|
||||
return Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof));
|
||||
}
|
||||
|
||||
let mut writer = self.write_side.lock().await;
|
||||
|
||||
if let Some(bytes) = final_data.take() {
|
||||
if bytes.len() + writer.buffer.len() < MAX_BUFFER_SIZE {
|
||||
writer.buffer.put(bytes);
|
||||
} else if !bytes.is_empty() {
|
||||
final_data = Some(bytes);
|
||||
}
|
||||
}
|
||||
|
||||
let mut current_buffer = std::mem::take(&mut writer.buffer);
|
||||
let _written = writer.stream.write_buf(&mut current_buffer).await?;
|
||||
writer.buffer = current_buffer;
|
||||
|
||||
if writer.buffer.is_empty() && final_data.is_none() {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Arbitrary for SshPacket {
|
||||
type Parameters = bool;
|
||||
type Strategy = BoxedStrategy<Self>;
|
||||
|
||||
fn arbitrary_with(start_with_real_message: Self::Parameters) -> Self::Strategy {
|
||||
if start_with_real_message {
|
||||
unimplemented!()
|
||||
} else {
|
||||
let data = proptest::collection::vec(u8::arbitrary(), 0..35000);
|
||||
data.prop_map(|x| SshPacket { buffer: x.into() }).boxed()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
enum SendData {
|
||||
Left(SshPacket),
|
||||
Right(SshPacket),
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl Arbitrary for SendData {
|
||||
type Parameters = <SshPacket as Arbitrary>::Parameters;
|
||||
type Strategy = BoxedStrategy<Self>;
|
||||
|
||||
fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
|
||||
(bool::arbitrary(), SshPacket::arbitrary_with(args))
|
||||
.prop_map(|(is_left, packet)| {
|
||||
if is_left {
|
||||
SendData::Left(packet)
|
||||
} else {
|
||||
SendData::Right(packet)
|
||||
}
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
proptest::proptest! {
|
||||
#[test]
|
||||
fn can_read_back_anything(packet in SshPacket::arbitrary()) {
|
||||
let result = tokio::runtime::Runtime::new().unwrap().block_on(async {
|
||||
let (left, right) = tokio::io::duplex(8192);
|
||||
let leftssh = SshChannel::new(left).unwrap();
|
||||
let rightssh = SshChannel::new(right).unwrap();
|
||||
let packet_copy = packet.clone();
|
||||
|
||||
tokio::task::spawn(async move {
|
||||
leftssh.write(packet_copy).await.unwrap();
|
||||
});
|
||||
rightssh.read().await.unwrap()
|
||||
});
|
||||
|
||||
assert_eq!(packet, result.unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sequences_send_correctly_serial(sequence in proptest::collection::vec(SendData::arbitrary(), 0..100)) {
|
||||
tokio::runtime::Runtime::new().unwrap().block_on(async {
|
||||
let (left, right) = tokio::io::duplex(8192);
|
||||
let leftssh = SshChannel::new(left).unwrap();
|
||||
let rightssh = SshChannel::new(right).unwrap();
|
||||
|
||||
let sequence_left = sequence.clone();
|
||||
let sequence_right = sequence;
|
||||
|
||||
let left_task = tokio::task::spawn(async move {
|
||||
let mut errored = false;
|
||||
|
||||
for item in sequence_left.into_iter() {
|
||||
match item {
|
||||
SendData::Left(packet) => {
|
||||
let result = leftssh.write(packet).await;
|
||||
errored = result.is_err();
|
||||
}
|
||||
|
||||
SendData::Right(packet) => {
|
||||
if let Ok(Some(item)) = leftssh.read().await {
|
||||
errored = item != packet;
|
||||
} else {
|
||||
errored = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if errored {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
!errored
|
||||
});
|
||||
|
||||
let right_task = tokio::task::spawn(async move {
|
||||
let mut errored = false;
|
||||
|
||||
for item in sequence_right.into_iter() {
|
||||
match item {
|
||||
SendData::Right(packet) => {
|
||||
let result = rightssh.write(packet).await;
|
||||
errored = result.is_err();
|
||||
}
|
||||
|
||||
SendData::Left(packet) => {
|
||||
if let Ok(Some(item)) = rightssh.read().await {
|
||||
errored = item != packet;
|
||||
} else {
|
||||
errored = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if errored {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
!errored
|
||||
});
|
||||
|
||||
assert!(left_task.await.unwrap());
|
||||
assert!(right_task.await.unwrap());
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sequences_send_correctly_parallel(sequence in proptest::collection::vec(SendData::arbitrary(), 0..100)) {
|
||||
use std::sync::Arc;
|
||||
|
||||
tokio::runtime::Runtime::new().unwrap().block_on(async {
|
||||
let (left, right) = tokio::io::duplex(8192);
|
||||
let leftsshw = Arc::new(SshChannel::new(left).unwrap());
|
||||
let leftsshr = leftsshw.clone();
|
||||
let rightsshw = Arc::new(SshChannel::new(right).unwrap());
|
||||
let rightsshr = rightsshw.clone();
|
||||
|
||||
let sequence_left_write = sequence.clone();
|
||||
let sequence_left_read = sequence.clone();
|
||||
let sequence_right_write = sequence.clone();
|
||||
let sequence_right_read = sequence.clone();
|
||||
|
||||
let left_task_write = tokio::task::spawn(async move {
|
||||
let mut errored = false;
|
||||
|
||||
for item in sequence_left_write.into_iter() {
|
||||
if let SendData::Left(packet) = item {
|
||||
let result = leftsshw.write(packet).await;
|
||||
errored = result.is_err();
|
||||
}
|
||||
|
||||
if errored {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
!errored
|
||||
});
|
||||
|
||||
let right_task_write = tokio::task::spawn(async move {
|
||||
let mut errored = false;
|
||||
|
||||
for item in sequence_right_write.into_iter() {
|
||||
if let SendData::Right(packet) = item {
|
||||
let result = rightsshw.write(packet).await;
|
||||
errored = result.is_err();
|
||||
}
|
||||
|
||||
if errored {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
!errored
|
||||
});
|
||||
|
||||
let left_task_read = tokio::task::spawn(async move {
|
||||
let mut errored = false;
|
||||
|
||||
for item in sequence_left_read.into_iter() {
|
||||
if let SendData::Right(packet) = item {
|
||||
if let Ok(Some(item)) = leftsshr.read().await {
|
||||
errored = item != packet;
|
||||
} else {
|
||||
errored = true;
|
||||
}
|
||||
}
|
||||
|
||||
if errored {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
!errored
|
||||
});
|
||||
|
||||
let right_task_read = tokio::task::spawn(async move {
|
||||
let mut errored = false;
|
||||
|
||||
for item in sequence_right_read.into_iter() {
|
||||
if let SendData::Left(packet) = item {
|
||||
if let Ok(Some(item)) = rightsshr.read().await {
|
||||
errored = item != packet;
|
||||
} else {
|
||||
errored = true;
|
||||
}
|
||||
}
|
||||
|
||||
if errored {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
!errored
|
||||
});
|
||||
|
||||
assert!(left_task_write.await.unwrap());
|
||||
assert!(right_task_write.await.unwrap());
|
||||
assert!(left_task_read.await.unwrap());
|
||||
assert!(right_task_read.await.unwrap());
|
||||
});
|
||||
}
|
||||
}
|
||||
11
ssh/src/lib.rs
Normal file
11
ssh/src/lib.rs
Normal file
@@ -0,0 +1,11 @@
|
||||
mod channel;
|
||||
mod message_ids;
|
||||
mod operational_error;
|
||||
mod packets;
|
||||
mod preamble;
|
||||
|
||||
pub use channel::SshChannel;
|
||||
pub use message_ids::SshMessageID;
|
||||
pub use operational_error::OperationalError;
|
||||
pub use packets::{SshKeyExchange, SshKeyExchangeProcessingError};
|
||||
pub use preamble::Preamble;
|
||||
173
ssh/src/message_ids.rs
Normal file
173
ssh/src/message_ids.rs
Normal file
@@ -0,0 +1,173 @@
|
||||
use crate::operational_error::OperationalError;
|
||||
use num_enum::{FromPrimitive, IntoPrimitive};
|
||||
use proptest::arbitrary::Arbitrary;
|
||||
use proptest::strategy::{BoxedStrategy, Just, Strategy};
|
||||
use std::fmt;
|
||||
|
||||
#[allow(non_camel_case_types)]
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq, FromPrimitive, IntoPrimitive)]
|
||||
#[repr(u8)]
|
||||
pub enum SshMessageID {
|
||||
SSH_MSG_DISCONNECT = 1,
|
||||
SSH_MSG_IGNORE = 2,
|
||||
SSH_MSG_UNIMPLEMENTED = 3,
|
||||
SSH_MSG_DEBUG = 4,
|
||||
SSH_MSG_SERVICE_REQUEST = 5,
|
||||
SSH_MSG_SERVICE_ACCEPT = 6,
|
||||
SSH_MSG_KEXINIT = 20,
|
||||
SSH_MSG_NEWKEYS = 21,
|
||||
SSH_MSG_USERAUTH_REQUEST = 50,
|
||||
SSH_MSG_USERAUTH_FAILURE = 51,
|
||||
SSH_MSG_USERAUTH_SUCCESS = 52,
|
||||
SSH_MSG_USERAUTH_BANNER = 53,
|
||||
SSH_MSG_GLOBAL_REQUEST = 80,
|
||||
SSH_MSG_REQUEST_SUCCESS = 81,
|
||||
SSH_MSG_REQUEST_FAILURE = 82,
|
||||
SSH_MSG_CHANNEL_OPEN = 90,
|
||||
SSH_MSG_CHANNEL_OPEN_CONFIRMATION = 91,
|
||||
SSH_MSG_CHANNEL_OPEN_FAILURE = 92,
|
||||
SSH_MSG_CHANNEL_WINDOW_ADJUST = 93,
|
||||
SSH_MSG_CHANNEL_DATA = 94,
|
||||
SSH_MSG_CHANNEL_EXTENDED_DATA = 95,
|
||||
SSH_MSG_CHANNEL_EOF = 96,
|
||||
SSH_MSG_CHANNEL_CLOSE = 97,
|
||||
SSH_MSG_CHANNEL_REQUEST = 98,
|
||||
SSH_MSG_CHANNEL_SUCCESS = 99,
|
||||
SSH_MSG_CHANNEL_FAILURE = 100,
|
||||
#[num_enum(catch_all)]
|
||||
Unknown(u8),
|
||||
}
|
||||
|
||||
impl fmt::Display for SshMessageID {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
SshMessageID::SSH_MSG_DISCONNECT => write!(f, "SSH_MSG_DISCONNECT"),
|
||||
SshMessageID::SSH_MSG_IGNORE => write!(f, "SSH_MSG_IGNORE"),
|
||||
SshMessageID::SSH_MSG_UNIMPLEMENTED => write!(f, "SSH_MSG_UNIMPLEMENTED"),
|
||||
SshMessageID::SSH_MSG_DEBUG => write!(f, "SSH_MSG_DEBUG"),
|
||||
SshMessageID::SSH_MSG_SERVICE_REQUEST => write!(f, "SSH_MSG_SERVICE_REQUEST"),
|
||||
SshMessageID::SSH_MSG_SERVICE_ACCEPT => write!(f, "SSH_MSG_SERVICE_ACCEPT"),
|
||||
SshMessageID::SSH_MSG_KEXINIT => write!(f, "SSH_MSG_KEXINIT"),
|
||||
SshMessageID::SSH_MSG_NEWKEYS => write!(f, "SSH_MSG_NEWKEYS"),
|
||||
SshMessageID::SSH_MSG_USERAUTH_REQUEST => write!(f, "SSH_MSG_USERAUTH_REQUEST"),
|
||||
SshMessageID::SSH_MSG_USERAUTH_FAILURE => write!(f, "SSH_MSG_USERAUTH_FAILURE"),
|
||||
SshMessageID::SSH_MSG_USERAUTH_SUCCESS => write!(f, "SSH_MSG_USERAUTH_SUCCESS"),
|
||||
SshMessageID::SSH_MSG_USERAUTH_BANNER => write!(f, "SSH_MSG_USERAUTH_BANNER"),
|
||||
SshMessageID::SSH_MSG_GLOBAL_REQUEST => write!(f, "SSH_MSG_GLOBAL_REQUEST"),
|
||||
SshMessageID::SSH_MSG_REQUEST_SUCCESS => write!(f, "SSH_MSG_REQUEST_SUCCESS"),
|
||||
SshMessageID::SSH_MSG_REQUEST_FAILURE => write!(f, "SSH_MSG_REQUEST_FAILURE"),
|
||||
SshMessageID::SSH_MSG_CHANNEL_OPEN => write!(f, "SSH_MSG_CHANNEL_OPEN"),
|
||||
SshMessageID::SSH_MSG_CHANNEL_OPEN_CONFIRMATION => {
|
||||
write!(f, "SSH_MSG_CHANNEL_OPEN_CONFIRMATION")
|
||||
}
|
||||
SshMessageID::SSH_MSG_CHANNEL_OPEN_FAILURE => write!(f, "SSH_MSG_CHANNEL_OPEN_FAILURE"),
|
||||
SshMessageID::SSH_MSG_CHANNEL_WINDOW_ADJUST => {
|
||||
write!(f, "SSH_MSG_CHANNEL_WINDOW_ADJUST")
|
||||
}
|
||||
SshMessageID::SSH_MSG_CHANNEL_DATA => write!(f, "SSH_MSG_CHANNEL_DATA"),
|
||||
SshMessageID::SSH_MSG_CHANNEL_EXTENDED_DATA => {
|
||||
write!(f, "SSH_MSG_CHANNEL_EXTENDED_DATA")
|
||||
}
|
||||
SshMessageID::SSH_MSG_CHANNEL_EOF => write!(f, "SSH_MSG_CHANNEL_EOF"),
|
||||
SshMessageID::SSH_MSG_CHANNEL_CLOSE => write!(f, "SSH_MSG_CHANNEL_CLOSE"),
|
||||
SshMessageID::SSH_MSG_CHANNEL_REQUEST => write!(f, "SSH_MSG_CHANNEL_REQUEST"),
|
||||
SshMessageID::SSH_MSG_CHANNEL_SUCCESS => write!(f, "SSH_MSG_CHANNEL_SUCCESS"),
|
||||
SshMessageID::SSH_MSG_CHANNEL_FAILURE => write!(f, "SSH_MSG_CHANNEL_FAILURE"),
|
||||
SshMessageID::Unknown(x) => write!(f, "SSH_MSG_UNKNOWN{}", x),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_duplicate_messages() {
|
||||
let mut found = std::collections::HashSet::new();
|
||||
|
||||
for i in u8::MIN..=u8::MAX {
|
||||
let id = SshMessageID::from_primitive(i);
|
||||
let display = id.to_string();
|
||||
|
||||
assert!(!found.contains(&display));
|
||||
found.insert(display);
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SshMessageID> for OperationalError {
|
||||
fn from(message: SshMessageID) -> Self {
|
||||
match message {
|
||||
SshMessageID::SSH_MSG_DISCONNECT => OperationalError::Disconnect,
|
||||
SshMessageID::SSH_MSG_USERAUTH_FAILURE => OperationalError::UserAuthFailed,
|
||||
SshMessageID::SSH_MSG_REQUEST_FAILURE => OperationalError::RequestFailed,
|
||||
SshMessageID::SSH_MSG_CHANNEL_OPEN_FAILURE => OperationalError::OpenChannelFailure,
|
||||
SshMessageID::SSH_MSG_CHANNEL_EOF => OperationalError::OtherEof,
|
||||
SshMessageID::SSH_MSG_CHANNEL_CLOSE => OperationalError::OtherClosed,
|
||||
SshMessageID::SSH_MSG_CHANNEL_FAILURE => OperationalError::ChannelFailure,
|
||||
|
||||
_ => OperationalError::UnexpectedMessage { message },
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<OperationalError> for SshMessageID {
|
||||
type Error = OperationalError;
|
||||
|
||||
fn try_from(value: OperationalError) -> Result<Self, Self::Error> {
|
||||
match value {
|
||||
OperationalError::Disconnect => Ok(SshMessageID::SSH_MSG_DISCONNECT),
|
||||
OperationalError::UserAuthFailed => Ok(SshMessageID::SSH_MSG_USERAUTH_FAILURE),
|
||||
OperationalError::RequestFailed => Ok(SshMessageID::SSH_MSG_REQUEST_FAILURE),
|
||||
OperationalError::OpenChannelFailure => Ok(SshMessageID::SSH_MSG_CHANNEL_OPEN_FAILURE),
|
||||
OperationalError::OtherEof => Ok(SshMessageID::SSH_MSG_CHANNEL_EOF),
|
||||
OperationalError::OtherClosed => Ok(SshMessageID::SSH_MSG_CHANNEL_CLOSE),
|
||||
OperationalError::ChannelFailure => Ok(SshMessageID::SSH_MSG_CHANNEL_FAILURE),
|
||||
OperationalError::UnexpectedMessage { message } => Ok(message),
|
||||
|
||||
_ => Err(value),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Arbitrary for SshMessageID {
|
||||
type Parameters = ();
|
||||
type Strategy = BoxedStrategy<Self>;
|
||||
|
||||
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
|
||||
proptest::prop_oneof![
|
||||
Just(SshMessageID::SSH_MSG_DISCONNECT),
|
||||
Just(SshMessageID::SSH_MSG_IGNORE),
|
||||
Just(SshMessageID::SSH_MSG_UNIMPLEMENTED),
|
||||
Just(SshMessageID::SSH_MSG_DEBUG),
|
||||
Just(SshMessageID::SSH_MSG_SERVICE_REQUEST),
|
||||
Just(SshMessageID::SSH_MSG_SERVICE_ACCEPT),
|
||||
Just(SshMessageID::SSH_MSG_KEXINIT),
|
||||
Just(SshMessageID::SSH_MSG_NEWKEYS),
|
||||
Just(SshMessageID::SSH_MSG_USERAUTH_REQUEST),
|
||||
Just(SshMessageID::SSH_MSG_USERAUTH_FAILURE),
|
||||
Just(SshMessageID::SSH_MSG_USERAUTH_SUCCESS),
|
||||
Just(SshMessageID::SSH_MSG_USERAUTH_BANNER),
|
||||
Just(SshMessageID::SSH_MSG_GLOBAL_REQUEST),
|
||||
Just(SshMessageID::SSH_MSG_REQUEST_SUCCESS),
|
||||
Just(SshMessageID::SSH_MSG_REQUEST_FAILURE),
|
||||
Just(SshMessageID::SSH_MSG_CHANNEL_OPEN),
|
||||
Just(SshMessageID::SSH_MSG_CHANNEL_OPEN_CONFIRMATION),
|
||||
Just(SshMessageID::SSH_MSG_CHANNEL_OPEN_FAILURE),
|
||||
Just(SshMessageID::SSH_MSG_CHANNEL_WINDOW_ADJUST),
|
||||
Just(SshMessageID::SSH_MSG_CHANNEL_DATA),
|
||||
Just(SshMessageID::SSH_MSG_CHANNEL_EXTENDED_DATA),
|
||||
Just(SshMessageID::SSH_MSG_CHANNEL_EOF),
|
||||
Just(SshMessageID::SSH_MSG_CHANNEL_CLOSE),
|
||||
Just(SshMessageID::SSH_MSG_CHANNEL_REQUEST),
|
||||
Just(SshMessageID::SSH_MSG_CHANNEL_SUCCESS),
|
||||
Just(SshMessageID::SSH_MSG_CHANNEL_FAILURE),
|
||||
]
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
proptest::proptest! {
|
||||
#[test]
|
||||
fn error_encodings_invert(message in SshMessageID::arbitrary()) {
|
||||
let error_version = OperationalError::from(message);
|
||||
let back_to_message = SshMessageID::try_from(error_version).unwrap();
|
||||
assert_eq!(message, back_to_message);
|
||||
}
|
||||
}
|
||||
51
ssh/src/operational_error.rs
Normal file
51
ssh/src/operational_error.rs
Normal file
@@ -0,0 +1,51 @@
|
||||
use crate::{SshKeyExchangeProcessingError, SshMessageID};
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum OperationalError {
|
||||
#[error("Configuration error")]
|
||||
ConfigurationError,
|
||||
#[error("DNS client configuration error")]
|
||||
DnsConfig,
|
||||
#[error("Failed to connect to target address")]
|
||||
Connection,
|
||||
#[error("Failure during key exchange / agreement protocol")]
|
||||
KeyExchange,
|
||||
#[error("Failed to complete initial read: {0}")]
|
||||
InitialRead(std::io::Error),
|
||||
#[error("SSH banner was not formatted in UTF-8: {0}")]
|
||||
BannerError(std::str::Utf8Error),
|
||||
#[error("Invalid initial SSH versionling line: {line}")]
|
||||
InvalidHeaderLine { line: String },
|
||||
#[error("Error writing initial banner: {0}")]
|
||||
WriteBanner(std::io::Error),
|
||||
#[error("Unexpected disconnect from other side.")]
|
||||
Disconnect,
|
||||
#[error("{message} in unexpected place.")]
|
||||
UnexpectedMessage { message: SshMessageID },
|
||||
#[error("User authorization failed.")]
|
||||
UserAuthFailed,
|
||||
#[error("Request failed.")]
|
||||
RequestFailed,
|
||||
#[error("Failed to open channel.")]
|
||||
OpenChannelFailure,
|
||||
#[error("Other side closed connection.")]
|
||||
OtherClosed,
|
||||
#[error("Other side sent EOF.")]
|
||||
OtherEof,
|
||||
#[error("Channel failed.")]
|
||||
ChannelFailure,
|
||||
#[error("Error in initial handshake: {0}")]
|
||||
KeyxProcessingError(#[from] SshKeyExchangeProcessingError),
|
||||
#[error("Invalid port number '{port_string}': {error}")]
|
||||
InvalidPort {
|
||||
port_string: String,
|
||||
error: std::num::ParseIntError,
|
||||
},
|
||||
#[error("Invalid hostname '{0}'")]
|
||||
InvalidHostname(String),
|
||||
#[error("Unable to parse host address")]
|
||||
UnableToParseHostAddress,
|
||||
#[error("Unable to configure resolver")]
|
||||
Resolver,
|
||||
}
|
||||
3
ssh/src/packets.rs
Normal file
3
ssh/src/packets.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
mod key_exchange;
|
||||
|
||||
pub use key_exchange::{SshKeyExchange, SshKeyExchangeProcessingError};
|
||||
332
ssh/src/packets/key_exchange.rs
Normal file
332
ssh/src/packets/key_exchange.rs
Normal file
@@ -0,0 +1,332 @@
|
||||
use configuration::connection::ClientConnectionOpts;
|
||||
use crate::channel::SshPacket;
|
||||
use crate::message_ids::SshMessageID;
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
use itertools::Itertools;
|
||||
use proptest::arbitrary::Arbitrary;
|
||||
use proptest::strategy::{BoxedStrategy, Strategy};
|
||||
use rand::{CryptoRng, Rng, SeedableRng};
|
||||
use std::string::FromUtf8Error;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub struct SshKeyExchange {
|
||||
cookie: [u8; 16],
|
||||
keyx_algorithms: Vec<String>,
|
||||
server_host_key_algorithms: Vec<String>,
|
||||
encryption_algorithms_client_to_server: Vec<String>,
|
||||
encryption_algorithms_server_to_client: Vec<String>,
|
||||
mac_algorithms_client_to_server: Vec<String>,
|
||||
mac_algorithms_server_to_client: Vec<String>,
|
||||
compression_algorithms_client_to_server: Vec<String>,
|
||||
compression_algorithms_server_to_client: Vec<String>,
|
||||
languages_client_to_server: Vec<String>,
|
||||
languages_server_to_client: Vec<String>,
|
||||
first_kex_packet_follows: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum SshKeyExchangeProcessingError {
|
||||
#[error("Message not appropriately tagged as SSH_MSG_KEXINIT")]
|
||||
TaggedWrong,
|
||||
#[error("Initial key exchange message was too short.")]
|
||||
TooShort,
|
||||
#[error("Invalid string encoding for name-list (not ASCII)")]
|
||||
NotAscii,
|
||||
#[error("Invalid conversion (from ASCII to UTF-8??): {0}")]
|
||||
NotUtf8(FromUtf8Error),
|
||||
#[error("Extraneous data at the end of key exchange message")]
|
||||
ExtraneousData,
|
||||
#[error("Received invalid reserved word ({0} != 0)")]
|
||||
InvalidReservedWord(u32),
|
||||
}
|
||||
|
||||
impl TryFrom<SshPacket> for SshKeyExchange {
|
||||
type Error = SshKeyExchangeProcessingError;
|
||||
|
||||
fn try_from(mut value: SshPacket) -> Result<Self, Self::Error> {
|
||||
if SshMessageID::from(value.buffer.get_u8()) != SshMessageID::SSH_MSG_KEXINIT {
|
||||
return Err(SshKeyExchangeProcessingError::TaggedWrong);
|
||||
}
|
||||
|
||||
let mut cookie = [0; 16];
|
||||
check_length(&mut value.buffer, 16)?;
|
||||
value.buffer.copy_to_slice(&mut cookie);
|
||||
|
||||
let keyx_algorithms = name_list(&mut value.buffer)?;
|
||||
let server_host_key_algorithms = name_list(&mut value.buffer)?;
|
||||
let encryption_algorithms_client_to_server = name_list(&mut value.buffer)?;
|
||||
let encryption_algorithms_server_to_client = name_list(&mut value.buffer)?;
|
||||
let mac_algorithms_client_to_server = name_list(&mut value.buffer)?;
|
||||
let mac_algorithms_server_to_client = name_list(&mut value.buffer)?;
|
||||
let compression_algorithms_client_to_server = name_list(&mut value.buffer)?;
|
||||
let compression_algorithms_server_to_client = name_list(&mut value.buffer)?;
|
||||
let languages_client_to_server = name_list(&mut value.buffer)?;
|
||||
let languages_server_to_client = name_list(&mut value.buffer)?;
|
||||
check_length(&mut value.buffer, 5)?;
|
||||
let first_kex_packet_follows = value.buffer.get_u8() != 0;
|
||||
let reserved = value.buffer.get_u32();
|
||||
if reserved != 0 {
|
||||
return Err(SshKeyExchangeProcessingError::InvalidReservedWord(reserved));
|
||||
}
|
||||
|
||||
if value.buffer.remaining() > 0 {
|
||||
return Err(SshKeyExchangeProcessingError::ExtraneousData);
|
||||
}
|
||||
|
||||
Ok(SshKeyExchange {
|
||||
cookie,
|
||||
keyx_algorithms,
|
||||
server_host_key_algorithms,
|
||||
encryption_algorithms_client_to_server,
|
||||
encryption_algorithms_server_to_client,
|
||||
mac_algorithms_client_to_server,
|
||||
mac_algorithms_server_to_client,
|
||||
compression_algorithms_client_to_server,
|
||||
compression_algorithms_server_to_client,
|
||||
languages_client_to_server,
|
||||
languages_server_to_client,
|
||||
first_kex_packet_follows,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SshKeyExchange> for SshPacket {
|
||||
fn from(value: SshKeyExchange) -> Self {
|
||||
let mut buffer = BytesMut::new();
|
||||
|
||||
let put_options = |buffer: &mut BytesMut, vals: Vec<String>| {
|
||||
let mut merged = String::new();
|
||||
#[allow(unstable_name_collisions)]
|
||||
let comma_sepped = vals.into_iter().intersperse(String::from(","));
|
||||
merged.extend(comma_sepped);
|
||||
let bytes = merged.as_bytes();
|
||||
buffer.put_u32(bytes.len() as u32);
|
||||
buffer.put_slice(bytes);
|
||||
};
|
||||
|
||||
buffer.put_u8(SshMessageID::SSH_MSG_KEXINIT.into());
|
||||
buffer.put_slice(&value.cookie);
|
||||
put_options(&mut buffer, value.keyx_algorithms);
|
||||
put_options(&mut buffer, value.server_host_key_algorithms);
|
||||
put_options(&mut buffer, value.encryption_algorithms_client_to_server);
|
||||
put_options(&mut buffer, value.encryption_algorithms_server_to_client);
|
||||
put_options(&mut buffer, value.mac_algorithms_client_to_server);
|
||||
put_options(&mut buffer, value.mac_algorithms_server_to_client);
|
||||
put_options(&mut buffer, value.compression_algorithms_client_to_server);
|
||||
put_options(&mut buffer, value.compression_algorithms_server_to_client);
|
||||
put_options(&mut buffer, value.languages_client_to_server);
|
||||
put_options(&mut buffer, value.languages_server_to_client);
|
||||
buffer.put_u8(value.first_kex_packet_follows as u8);
|
||||
buffer.put_u32(0);
|
||||
|
||||
SshPacket {
|
||||
buffer: buffer.freeze(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SshKeyExchange {
|
||||
/// Create a new SshKeyExchange message for this client or server based
|
||||
/// on the given connection options.
|
||||
///
|
||||
/// This function takes a random number generator because it needs to
|
||||
/// seed the message with a random cookie, but is otherwise deterministic.
|
||||
/// It will fail only in the case that the underlying random number
|
||||
/// generator fails, and return exactly that error.
|
||||
pub fn new<R>(rng: &mut R, value: ClientConnectionOpts) -> Result<Self, ()>
|
||||
where
|
||||
R: CryptoRng + Rng,
|
||||
{
|
||||
let mut result = SshKeyExchange {
|
||||
cookie: [0; 16],
|
||||
keyx_algorithms: value
|
||||
.key_exchange_algorithms
|
||||
.iter()
|
||||
.map(|x| x.to_string())
|
||||
.collect(),
|
||||
server_host_key_algorithms: value
|
||||
.server_host_key_algorithms
|
||||
.iter()
|
||||
.map(|x| x.to_string())
|
||||
.collect(),
|
||||
encryption_algorithms_client_to_server: value
|
||||
.encryption_algorithms
|
||||
.iter()
|
||||
.map(|x| x.to_string())
|
||||
.collect(),
|
||||
encryption_algorithms_server_to_client: value
|
||||
.encryption_algorithms
|
||||
.iter()
|
||||
.map(|x| x.to_string())
|
||||
.collect(),
|
||||
mac_algorithms_client_to_server: value
|
||||
.mac_algorithms
|
||||
.iter()
|
||||
.map(|x| x.to_string())
|
||||
.collect(),
|
||||
mac_algorithms_server_to_client: value
|
||||
.mac_algorithms
|
||||
.iter()
|
||||
.map(|x| x.to_string())
|
||||
.collect(),
|
||||
compression_algorithms_client_to_server: value
|
||||
.compression_algorithms
|
||||
.iter()
|
||||
.map(|x| x.to_string())
|
||||
.collect(),
|
||||
compression_algorithms_server_to_client: value
|
||||
.compression_algorithms
|
||||
.iter()
|
||||
.map(|x| x.to_string())
|
||||
.collect(),
|
||||
languages_client_to_server: value.languages.to_vec(),
|
||||
languages_server_to_client: value.languages.to_vec(),
|
||||
first_kex_packet_follows: value.predict.is_some(),
|
||||
};
|
||||
|
||||
rng.fill(&mut result.cookie);
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
impl Arbitrary for SshKeyExchange {
|
||||
type Parameters = ();
|
||||
type Strategy = BoxedStrategy<Self>;
|
||||
|
||||
fn arbitrary_with(_: Self::Parameters) -> BoxedStrategy<Self> {
|
||||
let client_config = ClientConnectionOpts::arbitrary();
|
||||
let seed = <[u8; 32]>::arbitrary();
|
||||
|
||||
(client_config, seed)
|
||||
.prop_map(|(config, seed)| {
|
||||
let mut rng = rand_chacha::ChaCha20Rng::from_seed(seed);
|
||||
SshKeyExchange::new(&mut rng, config).unwrap()
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
proptest::proptest! {
|
||||
#[test]
|
||||
fn valid_kex_messages_parse(kex in SshKeyExchange::arbitrary()) {
|
||||
let as_packet: SshPacket = kex.clone().try_into().expect("can generate packet");
|
||||
let as_message = as_packet.try_into().expect("can regenerate message");
|
||||
assert_eq!(kex, as_message);
|
||||
}
|
||||
}
|
||||
|
||||
fn check_length(buffer: &mut Bytes, length: usize) -> Result<(), SshKeyExchangeProcessingError> {
|
||||
if buffer.remaining() < length {
|
||||
Err(SshKeyExchangeProcessingError::TooShort)
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn name_list(buffer: &mut Bytes) -> Result<Vec<String>, SshKeyExchangeProcessingError> {
|
||||
check_length(buffer, 4)?;
|
||||
let list_length = buffer.get_u32() as usize;
|
||||
|
||||
if list_length == 0 {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
|
||||
check_length(buffer, list_length)?;
|
||||
let mut raw_bytes = vec![0u8; list_length];
|
||||
buffer.copy_to_slice(&mut raw_bytes);
|
||||
if !raw_bytes.iter().all(|c| c.is_ascii()) {
|
||||
return Err(SshKeyExchangeProcessingError::NotAscii);
|
||||
}
|
||||
|
||||
let mut result = Vec::new();
|
||||
for split in raw_bytes.split(|b| char::from(*b) == ',') {
|
||||
let str =
|
||||
String::from_utf8(split.to_vec()).map_err(SshKeyExchangeProcessingError::NotUtf8)?;
|
||||
result.push(str);
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn standard_kex_message() -> SshPacket {
|
||||
let seed = [0u8; 32];
|
||||
let mut rng = rand_chacha::ChaCha20Rng::from_seed(seed);
|
||||
let config = ClientConnectionOpts::default();
|
||||
let message = SshKeyExchange::new(&mut rng, config).expect("default settings work");
|
||||
message.try_into().expect("default settings serialize")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_see_bad_tag() {
|
||||
let standard = standard_kex_message();
|
||||
let mut buffer = standard.buffer.to_vec();
|
||||
buffer[0] += 1;
|
||||
let bad = SshPacket {
|
||||
buffer: buffer.into(),
|
||||
};
|
||||
assert!(matches!(
|
||||
SshKeyExchange::try_from(bad),
|
||||
Err(SshKeyExchangeProcessingError::TaggedWrong)
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn checks_for_extraneous_data() {
|
||||
let standard = standard_kex_message();
|
||||
let mut buffer = standard.buffer.to_vec();
|
||||
buffer.push(3);
|
||||
let bad = SshPacket {
|
||||
buffer: buffer.into(),
|
||||
};
|
||||
assert!(matches!(
|
||||
SshKeyExchange::try_from(bad),
|
||||
Err(SshKeyExchangeProcessingError::ExtraneousData)
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn checks_for_short_packets() {
|
||||
let standard = standard_kex_message();
|
||||
let mut buffer = standard.buffer.to_vec();
|
||||
let _ = buffer.pop();
|
||||
let bad = SshPacket {
|
||||
buffer: buffer.into(),
|
||||
};
|
||||
assert!(matches!(
|
||||
SshKeyExchange::try_from(bad),
|
||||
Err(SshKeyExchangeProcessingError::TooShort)
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn checks_for_invalid_data() {
|
||||
let standard = standard_kex_message();
|
||||
let mut buffer = standard.buffer.to_vec();
|
||||
buffer[22] = 0xc3;
|
||||
buffer[23] = 0x28;
|
||||
let bad = SshPacket {
|
||||
buffer: buffer.into(),
|
||||
};
|
||||
assert!(matches!(
|
||||
SshKeyExchange::try_from(bad),
|
||||
Err(SshKeyExchangeProcessingError::NotAscii)
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn checks_for_bad_reserved_word() {
|
||||
let standard = standard_kex_message();
|
||||
let mut buffer = standard.buffer.to_vec();
|
||||
let _ = buffer.pop();
|
||||
buffer.push(1);
|
||||
let bad = SshPacket {
|
||||
buffer: buffer.into(),
|
||||
};
|
||||
assert!(matches!(
|
||||
SshKeyExchange::try_from(bad),
|
||||
Err(SshKeyExchangeProcessingError::InvalidReservedWord(1))
|
||||
));
|
||||
}
|
||||
392
ssh/src/preamble.rs
Normal file
392
ssh/src/preamble.rs
Normal file
@@ -0,0 +1,392 @@
|
||||
use error_stack::{report, ResultExt};
|
||||
use proptest::arbitrary::Arbitrary;
|
||||
use proptest::strategy::{BoxedStrategy, Strategy};
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub struct Preamble {
|
||||
pub preamble: String,
|
||||
pub software_name: String,
|
||||
pub software_version: String,
|
||||
pub commentary: String,
|
||||
}
|
||||
|
||||
impl Arbitrary for Preamble {
|
||||
type Parameters = ();
|
||||
type Strategy = BoxedStrategy<Self>;
|
||||
|
||||
fn arbitrary_with((): Self::Parameters) -> Self::Strategy {
|
||||
let name = proptest::string::string_regex("[[:alpha:]][[:alnum:]]{0,32}").unwrap();
|
||||
let soft_major = u8::arbitrary();
|
||||
let soft_minor = u8::arbitrary();
|
||||
let soft_patch = proptest::option::of(u8::arbitrary());
|
||||
let commentary = proptest::option::of(
|
||||
proptest::string::string_regex("[[:alnum:]][[[:alnum:]][[:blank:]][[:punct:]]]{0,64}")
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
(name, soft_major, soft_minor, soft_patch, commentary)
|
||||
.prop_map(|(name, major, minor, patch, commentary)| Preamble {
|
||||
preamble: String::new(),
|
||||
software_name: name,
|
||||
software_version: if let Some(patch) = patch {
|
||||
format!("{}.{}.{}", major, minor, patch)
|
||||
} else {
|
||||
format!("{}.{}", major, minor)
|
||||
},
|
||||
commentary: commentary.unwrap_or_default(),
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Preamble {
|
||||
fn default() -> Self {
|
||||
Preamble {
|
||||
preamble: String::new(),
|
||||
software_name: env!("CARGO_PKG_NAME").to_string(),
|
||||
software_version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
commentary: String::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum PreambleReadError {
|
||||
#[error("Reading from the input stream failed")]
|
||||
Read,
|
||||
#[error("Illegal version number, expected '2.0' (saw characters {0})")]
|
||||
IllegalVersion(String),
|
||||
#[error("No dash found after seeing SSH version number")]
|
||||
NoDashAfterVersion,
|
||||
#[error("Illegal character in SSH software name")]
|
||||
IllegalSoftwareNameChar,
|
||||
#[error("Protocol error in preamble: No line feed for carriage return")]
|
||||
NoLineFeedForCarriage,
|
||||
#[error("Missing the final newline in the preamble")]
|
||||
MissingFinalNewline,
|
||||
#[error("Illegal UTF-8 in software name")]
|
||||
InvalidSoftwareName,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum PreambleWriteError {
|
||||
#[error("Could not write preamble to socket")]
|
||||
Write,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum PreambleState {
|
||||
StartOfLine,
|
||||
Preamble,
|
||||
CarriageReturn,
|
||||
InitialS,
|
||||
SecondS,
|
||||
InitialH,
|
||||
InitialDash,
|
||||
Version2,
|
||||
VersionDot,
|
||||
Version0,
|
||||
VersionDash,
|
||||
SoftwareName,
|
||||
SoftwareVersion,
|
||||
Commentary,
|
||||
FinalCarriageReturn,
|
||||
}
|
||||
|
||||
impl Preamble {
|
||||
/// Read an SSH preamble from the given read channel.
|
||||
///
|
||||
/// Will fail if the underlying read channel fails, or if the preamble does not
|
||||
/// meet the formatting requirements of the RFC.
|
||||
pub async fn read<R: AsyncReadExt + Unpin>(
|
||||
connection: &mut R,
|
||||
) -> error_stack::Result<Preamble, PreambleReadError> {
|
||||
let mut preamble = String::new();
|
||||
let mut software_name_bytes = Vec::new();
|
||||
let mut software_version = String::new();
|
||||
let mut commentary = String::new();
|
||||
let mut state = PreambleState::StartOfLine;
|
||||
|
||||
loop {
|
||||
let next_byte = connection
|
||||
.read_u8()
|
||||
.await
|
||||
.change_context(PreambleReadError::Read)?;
|
||||
let next_char = char::from(next_byte);
|
||||
|
||||
tracing::trace!(?next_char, ?state, "processing next preamble character");
|
||||
match state {
|
||||
PreambleState::StartOfLine => match next_char {
|
||||
'S' => state = PreambleState::InitialS,
|
||||
_ => {
|
||||
preamble.push(next_char);
|
||||
state = PreambleState::Preamble;
|
||||
}
|
||||
},
|
||||
|
||||
PreambleState::Preamble => match next_char {
|
||||
'\r' => state = PreambleState::CarriageReturn,
|
||||
_ => preamble.push(next_char),
|
||||
},
|
||||
|
||||
PreambleState::CarriageReturn => match next_char {
|
||||
'\n' => state = PreambleState::StartOfLine,
|
||||
_ => return Err(report!(PreambleReadError::NoLineFeedForCarriage)),
|
||||
},
|
||||
|
||||
PreambleState::InitialS => match next_char {
|
||||
'S' => state = PreambleState::SecondS,
|
||||
_ => {
|
||||
preamble.push('S');
|
||||
preamble.push(next_char);
|
||||
state = PreambleState::Preamble;
|
||||
}
|
||||
},
|
||||
|
||||
PreambleState::SecondS => match next_char {
|
||||
'H' => state = PreambleState::InitialH,
|
||||
_ => {
|
||||
preamble.push_str("SS");
|
||||
preamble.push(next_char);
|
||||
state = PreambleState::Preamble;
|
||||
}
|
||||
},
|
||||
|
||||
PreambleState::InitialH => match next_char {
|
||||
'-' => state = PreambleState::InitialDash,
|
||||
_ => {
|
||||
preamble.push_str("SSH");
|
||||
preamble.push(next_char);
|
||||
state = PreambleState::InitialDash;
|
||||
}
|
||||
},
|
||||
|
||||
PreambleState::InitialDash => match next_char {
|
||||
'2' => state = PreambleState::Version2,
|
||||
_ => {
|
||||
return Err(report!(PreambleReadError::IllegalVersion(String::from(
|
||||
next_char
|
||||
))))
|
||||
}
|
||||
},
|
||||
|
||||
PreambleState::Version2 => match next_char {
|
||||
'.' => state = PreambleState::VersionDot,
|
||||
_ => {
|
||||
return Err(report!(PreambleReadError::IllegalVersion(format!(
|
||||
"2{}",
|
||||
next_char
|
||||
))))
|
||||
}
|
||||
},
|
||||
|
||||
PreambleState::VersionDot => match next_char {
|
||||
'0' => state = PreambleState::Version0,
|
||||
_ => {
|
||||
return Err(report!(PreambleReadError::IllegalVersion(format!(
|
||||
"2.{}",
|
||||
next_char
|
||||
))))
|
||||
}
|
||||
},
|
||||
|
||||
PreambleState::Version0 => match next_char {
|
||||
'-' => state = PreambleState::VersionDash,
|
||||
_ => return Err(report!(PreambleReadError::NoDashAfterVersion)),
|
||||
},
|
||||
|
||||
PreambleState::VersionDash => {
|
||||
software_name_bytes.push(next_byte);
|
||||
state = PreambleState::SoftwareName;
|
||||
}
|
||||
|
||||
PreambleState::SoftwareName => match next_char {
|
||||
'_' => state = PreambleState::SoftwareVersion,
|
||||
x if x == '-' || x.is_ascii_whitespace() => {
|
||||
return Err(report!(PreambleReadError::IllegalSoftwareNameChar))
|
||||
}
|
||||
_ => software_name_bytes.push(next_byte),
|
||||
},
|
||||
|
||||
PreambleState::SoftwareVersion => match next_char {
|
||||
' ' => state = PreambleState::Commentary,
|
||||
'\r' => state = PreambleState::FinalCarriageReturn,
|
||||
'_' => state = PreambleState::SoftwareVersion,
|
||||
x if x == '-' || x.is_ascii_whitespace() => {
|
||||
return Err(report!(PreambleReadError::IllegalSoftwareNameChar))
|
||||
.attach_printable_lazy(|| {
|
||||
format!("saw {:?} / {}", next_char, next_byte)
|
||||
})?
|
||||
}
|
||||
_ => software_version.push(next_char),
|
||||
},
|
||||
|
||||
PreambleState::FinalCarriageReturn => match next_char {
|
||||
'\n' => break,
|
||||
_ => return Err(report!(PreambleReadError::MissingFinalNewline)),
|
||||
},
|
||||
|
||||
PreambleState::Commentary => match next_char {
|
||||
'\r' => state = PreambleState::FinalCarriageReturn,
|
||||
_ => commentary.push(next_char),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
let software_name = String::from_utf8(software_name_bytes)
|
||||
.change_context(PreambleReadError::InvalidSoftwareName)?;
|
||||
|
||||
Ok(Preamble {
|
||||
preamble,
|
||||
software_name,
|
||||
software_version,
|
||||
commentary,
|
||||
})
|
||||
}
|
||||
|
||||
// let mut read_buffer = vec![0; 4096];
|
||||
// let mut pre_message = String::new();
|
||||
// let protocol_version;
|
||||
// let software_name;
|
||||
// let software_version;
|
||||
// let commentary;
|
||||
// let mut prefix = String::new();
|
||||
//
|
||||
//
|
||||
// 'outer: loop {
|
||||
// let read_length = connection
|
||||
// .read(&mut read_buffer)
|
||||
// .await
|
||||
// .change_context(PreambleReadError::InitialRead)?;
|
||||
// let string_version = std::str::from_utf8(&read_buffer[0..read_length])
|
||||
// .change_context(PreambleReadError::BannerError)?;
|
||||
//
|
||||
// prefix.push_str(string_version);
|
||||
// let ends_with_newline = prefix.ends_with("\r\n");
|
||||
//
|
||||
// let new_prefix = if ends_with_newline {
|
||||
// // we are cleanly reading up to a \r\n, so our new prefix after
|
||||
// // this loop is empty
|
||||
// String::new()
|
||||
// } else if let Some((correct_bits, leftover)) = prefix.rsplit_once("\r\n") {
|
||||
// // there's some dangling bits in this read, so we'll cut this string
|
||||
// // at the final "\r\n" and then remember to use the leftover as the
|
||||
// // new prefix at the end of this loop
|
||||
// let result = leftover.to_string();
|
||||
// prefix = correct_bits.to_string();
|
||||
// result
|
||||
// } else {
|
||||
// // there's no "\r\n", so we don't have a full line yet, so keep reading
|
||||
// continue;
|
||||
// };
|
||||
//
|
||||
// for line in prefix.lines() {
|
||||
// if line.starts_with("SSH") {
|
||||
// let (_, interesting_bits) = line
|
||||
// .split_once('-')
|
||||
// .ok_or_else(||
|
||||
// report!(PreambleReadError::InvalidHeaderLine {
|
||||
// reason: "could not find dash after SSH",
|
||||
// line: line.to_string(),
|
||||
// }))?;
|
||||
//
|
||||
// let (protoversion, other_bits) = interesting_bits
|
||||
// .split_once('-')
|
||||
// .ok_or_else(|| report!(PreambleReadError::InvalidHeaderLine {
|
||||
// reason: "could not find dash after protocol version",
|
||||
// line: line.to_string(),
|
||||
// }))?;
|
||||
//
|
||||
// let (softwarever, comment) = match other_bits.split_once(' ') {
|
||||
// Some((s, c)) => (s, c),
|
||||
// None => (other_bits, ""),
|
||||
// };
|
||||
//
|
||||
// let (software_name_str, software_version_str) = softwarever
|
||||
// .split_once('_')
|
||||
// .ok_or_else(|| report!(PreambleReadError::InvalidHeaderLine {
|
||||
// reason: "could not find underscore between software name and version",
|
||||
// line: line.to_string(),
|
||||
// }))?;
|
||||
//
|
||||
// software_name = software_name_str.to_string();
|
||||
// software_version = software_version_str.to_string();
|
||||
// protocol_version = protoversion.to_string();
|
||||
// commentary = comment.to_string();
|
||||
// break 'outer;
|
||||
// } else {
|
||||
// pre_message.push_str(line);
|
||||
// pre_message.push('\n');
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// prefix = new_prefix;
|
||||
// }
|
||||
//
|
||||
// tracing::info!(
|
||||
// ?protocol_version,
|
||||
// ?software_version,
|
||||
// ?commentary,
|
||||
// "Got server information"
|
||||
// );
|
||||
//
|
||||
// Ok(Preamble {
|
||||
// protocol_version,
|
||||
// software_name,
|
||||
// software_version,
|
||||
// commentary,
|
||||
// })
|
||||
// }
|
||||
|
||||
/// Write a preamble to the given network socket.
|
||||
pub async fn write<W: AsyncWriteExt + Unpin>(
|
||||
&self,
|
||||
connection: &mut W,
|
||||
) -> error_stack::Result<(), PreambleWriteError> {
|
||||
let comment = if self.commentary.is_empty() {
|
||||
self.commentary.clone()
|
||||
} else {
|
||||
format!(" {}", self.commentary)
|
||||
};
|
||||
let output = format!(
|
||||
"SSH-2.0-{}_{}{}\r\n",
|
||||
self.software_name, self.software_version, comment
|
||||
);
|
||||
connection
|
||||
.write_all(output.as_bytes())
|
||||
.await
|
||||
.change_context(PreambleWriteError::Write)
|
||||
}
|
||||
}
|
||||
|
||||
proptest::proptest! {
|
||||
#[test]
|
||||
fn preamble_roundtrips(preamble in Preamble::arbitrary()) {
|
||||
let read_version = tokio::runtime::Runtime::new().unwrap().block_on(async {
|
||||
let (mut writer, mut reader) = tokio::io::duplex(4096);
|
||||
preamble.write(&mut writer).await.unwrap();
|
||||
Preamble::read(&mut reader).await.unwrap()
|
||||
});
|
||||
|
||||
assert_eq!(preamble, read_version);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn preamble_read_is_thrifty(preamble in Preamble::arbitrary(), b in u8::arbitrary()) {
|
||||
let (read_version, next) = tokio::runtime::Runtime::new().unwrap().block_on(async {
|
||||
let (mut writer, mut reader) = tokio::io::duplex(4096);
|
||||
preamble.write(&mut writer).await.unwrap();
|
||||
writer.write_u8(b).await.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
drop(writer);
|
||||
let preamble = Preamble::read(&mut reader).await.unwrap();
|
||||
let next = reader.read_u8().await.unwrap();
|
||||
(preamble, next)
|
||||
});
|
||||
|
||||
assert_eq!(preamble, read_version);
|
||||
assert_eq!(b, next);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user