Workspacify

This commit is contained in:
2025-05-03 17:30:01 -07:00
parent 9fe5b78962
commit d036997de3
60 changed files with 450 additions and 212 deletions

18
ssh/Cargo.toml Normal file
View File

@@ -0,0 +1,18 @@
[package]
name = "ssh"
edition = "2024"
[dependencies]
bytes = { workspace = true }
configuration = { workspace = true }
error-stack = { workspace = true }
getrandom = { workspace = true }
itertools = { workspace = true }
num_enum = { workspace = true }
proptest = { workspace = true }
proptest-derive = { workspace = true }
rand = { workspace = true }
rand_chacha = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }

428
ssh/src/channel.rs Normal file
View File

@@ -0,0 +1,428 @@
use bytes::{BufMut, Bytes, BytesMut};
use proptest::arbitrary::Arbitrary;
use proptest::strategy::{BoxedStrategy, Strategy};
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha20Rng;
use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf};
use tokio::sync::Mutex;
const MAX_BUFFER_SIZE: usize = 64 * 1024;
#[derive(Clone, Debug, PartialEq)]
pub struct SshPacket {
pub buffer: Bytes,
}
pub struct SshChannel<Stream> {
read_side: Mutex<ReadSide<Stream>>,
write_side: Mutex<WriteSide<Stream>>,
cipher_block_size: usize,
mac_length: usize,
channel_is_closed: bool,
}
struct ReadSide<Stream> {
stream: ReadHalf<Stream>,
buffer: BytesMut,
}
struct WriteSide<Stream> {
stream: WriteHalf<Stream>,
buffer: BytesMut,
rng: ChaCha20Rng,
}
impl<Stream> SshChannel<Stream>
where
Stream: AsyncReadExt + AsyncWriteExt,
Stream: Send + Sync,
Stream: Unpin,
{
/// Create a new SSH channel.
///
/// SshChannels are designed to make sure the various SSH channel read /
/// write operations are cancel- and concurrency-safe. They take ownership
/// of the underlying stream once established, where "established" means
/// that we have read and written the initial SSH banners.
pub fn new(stream: Stream) -> Result<SshChannel<Stream>, getrandom::Error> {
let (read_half, write_half) = tokio::io::split(stream);
Ok(SshChannel {
read_side: Mutex::new(ReadSide {
stream: read_half,
buffer: BytesMut::with_capacity(MAX_BUFFER_SIZE),
}),
write_side: Mutex::new(WriteSide {
stream: write_half,
buffer: BytesMut::with_capacity(MAX_BUFFER_SIZE),
rng: ChaCha20Rng::try_from_os_rng()?,
}),
mac_length: 0,
cipher_block_size: 8,
channel_is_closed: false,
})
}
/// Read an SshPacket from the wire.
///
/// This function is cancel safe, and can be used in `select` (or similar)
/// without problems. It is also safe to be used in a multitasking setting.
/// Returns Ok(Some(...)) if a packet is found, or Ok(None) if we have
/// successfully reached the end of stream. It will also return Ok(None)
/// repeatedly after the stream is closed.
pub async fn read(&self) -> Result<Option<SshPacket>, std::io::Error> {
if self.channel_is_closed {
return Ok(None);
}
let mut reader = self.read_side.lock().await;
let mut local_buffer = vec![0; 4096];
// First, let's try to at least get a size in there.
while reader.buffer.len() < 5 {
let amt_read = reader.stream.read(&mut local_buffer).await?;
reader.buffer.extend_from_slice(&local_buffer[0..amt_read]);
}
let packet_size = ((reader.buffer[0] as usize) << 24)
| ((reader.buffer[1] as usize) << 16)
| ((reader.buffer[2] as usize) << 8)
| reader.buffer[3] as usize;
let padding_size = reader.buffer[4] as usize;
let total_size = 4 + 1 + packet_size + padding_size + self.mac_length;
tracing::trace!(
packet_size,
padding_size,
total_size,
"Initial packet information determined"
);
// Now we need to make sure that the buffer contains at least that
// many bytes. We do this transfer -- from the wire to an internal
// buffer -- to ensure cancel safety. If, at any point, this computation
// is cancelled, the lock will be released and the buffer will be in
// a reasonable place. A subsequent call should be able to pick up
// wherever we left off.
while reader.buffer.len() < total_size {
let amt_read = reader.stream.read(&mut local_buffer).await?;
reader.buffer.extend_from_slice(&local_buffer[0..amt_read]);
}
let mut new_packet = reader.buffer.split_to(total_size);
let _header = new_packet.split_to(5);
let payload = new_packet.split_to(total_size - padding_size - 5);
let _mac = new_packet.split_off(padding_size);
Ok(Some(SshPacket {
buffer: payload.freeze(),
}))
}
fn encode(&self, rng: &mut ChaCha20Rng, packet: SshPacket) -> Option<Bytes> {
let mut encoded_packet = BytesMut::new();
// Arbitrary-length padding, such that the total length of
// (packet_length || padding_length || payload || random padding)
// is a multiple of the cipher block size or 8, whichever is
// larger. There MUST be at least four bytes of padding. The
// padding SHOULD consist of random bytes. The maximum amount of
// padding is 255 bytes.
let paddingless_length = 4 + 1 + packet.buffer.len();
// the padding we need to get to an even multiple of the cipher
// block size, naturally jumping to the cipher block size if we're
// already aligned. (this is just easier, and since we can't have
// 0 as the final padding, seems reasonable to do.)
let mut rounded_padding =
self.cipher_block_size - (paddingless_length % self.cipher_block_size);
// now we enforce the must be greater than or equal to 4 rule
if rounded_padding < 4 {
rounded_padding += self.cipher_block_size;
}
// if this ends up being > 256, then we've run into something terrible
if rounded_padding > (u8::MAX as usize) {
tracing::error!(
payload_length = packet.buffer.len(),
cipher_block_size = ?self.cipher_block_size,
computed_padding = ?rounded_padding,
"generated incoherent padding value in write"
);
return None;
}
encoded_packet.put_u32(packet.buffer.len() as u32);
encoded_packet.put_u8(rounded_padding as u8);
encoded_packet.put(packet.buffer);
for _ in 0..rounded_padding {
encoded_packet.put_u8(rng.random());
}
Some(encoded_packet.freeze())
}
/// Write an SshPacket to the wire.
///
/// This function is cancel safe, and can be used in `select` (or similar).
/// By cancel safe, we mean that one of the following outcomes is guaranteed
/// to occur if the operation is cancelled:
///
/// 1. The whole packet is written to the channel.
/// 2. No part of the packet is written to the channel.
/// 3. The channel is dead, and no further data can be written to it.
///
/// Note that this means that you cannot assume that the packet is not
/// written if the operation is cancelled, it just ensures that you will
/// not be in a place in which only part of the packet has been written.
pub async fn write(&self, packet: SshPacket) -> Result<(), std::io::Error> {
let mut final_data = { self.encode(&mut self.write_side.lock().await.rng, packet) };
loop {
if self.channel_is_closed {
return Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof));
}
let mut writer = self.write_side.lock().await;
if let Some(bytes) = final_data.take() {
if bytes.len() + writer.buffer.len() < MAX_BUFFER_SIZE {
writer.buffer.put(bytes);
} else if !bytes.is_empty() {
final_data = Some(bytes);
}
}
let mut current_buffer = std::mem::take(&mut writer.buffer);
let _written = writer.stream.write_buf(&mut current_buffer).await?;
writer.buffer = current_buffer;
if writer.buffer.is_empty() && final_data.is_none() {
return Ok(());
}
}
}
}
impl Arbitrary for SshPacket {
type Parameters = bool;
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(start_with_real_message: Self::Parameters) -> Self::Strategy {
if start_with_real_message {
unimplemented!()
} else {
let data = proptest::collection::vec(u8::arbitrary(), 0..35000);
data.prop_map(|x| SshPacket { buffer: x.into() }).boxed()
}
}
}
#[cfg(test)]
#[derive(Debug, Clone, PartialEq)]
enum SendData {
Left(SshPacket),
Right(SshPacket),
}
#[cfg(test)]
impl Arbitrary for SendData {
type Parameters = <SshPacket as Arbitrary>::Parameters;
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
(bool::arbitrary(), SshPacket::arbitrary_with(args))
.prop_map(|(is_left, packet)| {
if is_left {
SendData::Left(packet)
} else {
SendData::Right(packet)
}
})
.boxed()
}
}
proptest::proptest! {
#[test]
fn can_read_back_anything(packet in SshPacket::arbitrary()) {
let result = tokio::runtime::Runtime::new().unwrap().block_on(async {
let (left, right) = tokio::io::duplex(8192);
let leftssh = SshChannel::new(left).unwrap();
let rightssh = SshChannel::new(right).unwrap();
let packet_copy = packet.clone();
tokio::task::spawn(async move {
leftssh.write(packet_copy).await.unwrap();
});
rightssh.read().await.unwrap()
});
assert_eq!(packet, result.unwrap());
}
#[test]
fn sequences_send_correctly_serial(sequence in proptest::collection::vec(SendData::arbitrary(), 0..100)) {
tokio::runtime::Runtime::new().unwrap().block_on(async {
let (left, right) = tokio::io::duplex(8192);
let leftssh = SshChannel::new(left).unwrap();
let rightssh = SshChannel::new(right).unwrap();
let sequence_left = sequence.clone();
let sequence_right = sequence;
let left_task = tokio::task::spawn(async move {
let mut errored = false;
for item in sequence_left.into_iter() {
match item {
SendData::Left(packet) => {
let result = leftssh.write(packet).await;
errored = result.is_err();
}
SendData::Right(packet) => {
if let Ok(Some(item)) = leftssh.read().await {
errored = item != packet;
} else {
errored = true;
}
}
}
if errored {
break
}
}
!errored
});
let right_task = tokio::task::spawn(async move {
let mut errored = false;
for item in sequence_right.into_iter() {
match item {
SendData::Right(packet) => {
let result = rightssh.write(packet).await;
errored = result.is_err();
}
SendData::Left(packet) => {
if let Ok(Some(item)) = rightssh.read().await {
errored = item != packet;
} else {
errored = true;
}
}
}
if errored {
break
}
}
!errored
});
assert!(left_task.await.unwrap());
assert!(right_task.await.unwrap());
});
}
#[test]
fn sequences_send_correctly_parallel(sequence in proptest::collection::vec(SendData::arbitrary(), 0..100)) {
use std::sync::Arc;
tokio::runtime::Runtime::new().unwrap().block_on(async {
let (left, right) = tokio::io::duplex(8192);
let leftsshw = Arc::new(SshChannel::new(left).unwrap());
let leftsshr = leftsshw.clone();
let rightsshw = Arc::new(SshChannel::new(right).unwrap());
let rightsshr = rightsshw.clone();
let sequence_left_write = sequence.clone();
let sequence_left_read = sequence.clone();
let sequence_right_write = sequence.clone();
let sequence_right_read = sequence.clone();
let left_task_write = tokio::task::spawn(async move {
let mut errored = false;
for item in sequence_left_write.into_iter() {
if let SendData::Left(packet) = item {
let result = leftsshw.write(packet).await;
errored = result.is_err();
}
if errored {
break
}
}
!errored
});
let right_task_write = tokio::task::spawn(async move {
let mut errored = false;
for item in sequence_right_write.into_iter() {
if let SendData::Right(packet) = item {
let result = rightsshw.write(packet).await;
errored = result.is_err();
}
if errored {
break
}
}
!errored
});
let left_task_read = tokio::task::spawn(async move {
let mut errored = false;
for item in sequence_left_read.into_iter() {
if let SendData::Right(packet) = item {
if let Ok(Some(item)) = leftsshr.read().await {
errored = item != packet;
} else {
errored = true;
}
}
if errored {
break
}
}
!errored
});
let right_task_read = tokio::task::spawn(async move {
let mut errored = false;
for item in sequence_right_read.into_iter() {
if let SendData::Left(packet) = item {
if let Ok(Some(item)) = rightsshr.read().await {
errored = item != packet;
} else {
errored = true;
}
}
if errored {
break
}
}
!errored
});
assert!(left_task_write.await.unwrap());
assert!(right_task_write.await.unwrap());
assert!(left_task_read.await.unwrap());
assert!(right_task_read.await.unwrap());
});
}
}

11
ssh/src/lib.rs Normal file
View File

@@ -0,0 +1,11 @@
mod channel;
mod message_ids;
mod operational_error;
mod packets;
mod preamble;
pub use channel::SshChannel;
pub use message_ids::SshMessageID;
pub use operational_error::OperationalError;
pub use packets::{SshKeyExchange, SshKeyExchangeProcessingError};
pub use preamble::Preamble;

173
ssh/src/message_ids.rs Normal file
View File

@@ -0,0 +1,173 @@
use crate::operational_error::OperationalError;
use num_enum::{FromPrimitive, IntoPrimitive};
use proptest::arbitrary::Arbitrary;
use proptest::strategy::{BoxedStrategy, Just, Strategy};
use std::fmt;
#[allow(non_camel_case_types)]
#[derive(Clone, Copy, Debug, Eq, PartialEq, FromPrimitive, IntoPrimitive)]
#[repr(u8)]
pub enum SshMessageID {
SSH_MSG_DISCONNECT = 1,
SSH_MSG_IGNORE = 2,
SSH_MSG_UNIMPLEMENTED = 3,
SSH_MSG_DEBUG = 4,
SSH_MSG_SERVICE_REQUEST = 5,
SSH_MSG_SERVICE_ACCEPT = 6,
SSH_MSG_KEXINIT = 20,
SSH_MSG_NEWKEYS = 21,
SSH_MSG_USERAUTH_REQUEST = 50,
SSH_MSG_USERAUTH_FAILURE = 51,
SSH_MSG_USERAUTH_SUCCESS = 52,
SSH_MSG_USERAUTH_BANNER = 53,
SSH_MSG_GLOBAL_REQUEST = 80,
SSH_MSG_REQUEST_SUCCESS = 81,
SSH_MSG_REQUEST_FAILURE = 82,
SSH_MSG_CHANNEL_OPEN = 90,
SSH_MSG_CHANNEL_OPEN_CONFIRMATION = 91,
SSH_MSG_CHANNEL_OPEN_FAILURE = 92,
SSH_MSG_CHANNEL_WINDOW_ADJUST = 93,
SSH_MSG_CHANNEL_DATA = 94,
SSH_MSG_CHANNEL_EXTENDED_DATA = 95,
SSH_MSG_CHANNEL_EOF = 96,
SSH_MSG_CHANNEL_CLOSE = 97,
SSH_MSG_CHANNEL_REQUEST = 98,
SSH_MSG_CHANNEL_SUCCESS = 99,
SSH_MSG_CHANNEL_FAILURE = 100,
#[num_enum(catch_all)]
Unknown(u8),
}
impl fmt::Display for SshMessageID {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
SshMessageID::SSH_MSG_DISCONNECT => write!(f, "SSH_MSG_DISCONNECT"),
SshMessageID::SSH_MSG_IGNORE => write!(f, "SSH_MSG_IGNORE"),
SshMessageID::SSH_MSG_UNIMPLEMENTED => write!(f, "SSH_MSG_UNIMPLEMENTED"),
SshMessageID::SSH_MSG_DEBUG => write!(f, "SSH_MSG_DEBUG"),
SshMessageID::SSH_MSG_SERVICE_REQUEST => write!(f, "SSH_MSG_SERVICE_REQUEST"),
SshMessageID::SSH_MSG_SERVICE_ACCEPT => write!(f, "SSH_MSG_SERVICE_ACCEPT"),
SshMessageID::SSH_MSG_KEXINIT => write!(f, "SSH_MSG_KEXINIT"),
SshMessageID::SSH_MSG_NEWKEYS => write!(f, "SSH_MSG_NEWKEYS"),
SshMessageID::SSH_MSG_USERAUTH_REQUEST => write!(f, "SSH_MSG_USERAUTH_REQUEST"),
SshMessageID::SSH_MSG_USERAUTH_FAILURE => write!(f, "SSH_MSG_USERAUTH_FAILURE"),
SshMessageID::SSH_MSG_USERAUTH_SUCCESS => write!(f, "SSH_MSG_USERAUTH_SUCCESS"),
SshMessageID::SSH_MSG_USERAUTH_BANNER => write!(f, "SSH_MSG_USERAUTH_BANNER"),
SshMessageID::SSH_MSG_GLOBAL_REQUEST => write!(f, "SSH_MSG_GLOBAL_REQUEST"),
SshMessageID::SSH_MSG_REQUEST_SUCCESS => write!(f, "SSH_MSG_REQUEST_SUCCESS"),
SshMessageID::SSH_MSG_REQUEST_FAILURE => write!(f, "SSH_MSG_REQUEST_FAILURE"),
SshMessageID::SSH_MSG_CHANNEL_OPEN => write!(f, "SSH_MSG_CHANNEL_OPEN"),
SshMessageID::SSH_MSG_CHANNEL_OPEN_CONFIRMATION => {
write!(f, "SSH_MSG_CHANNEL_OPEN_CONFIRMATION")
}
SshMessageID::SSH_MSG_CHANNEL_OPEN_FAILURE => write!(f, "SSH_MSG_CHANNEL_OPEN_FAILURE"),
SshMessageID::SSH_MSG_CHANNEL_WINDOW_ADJUST => {
write!(f, "SSH_MSG_CHANNEL_WINDOW_ADJUST")
}
SshMessageID::SSH_MSG_CHANNEL_DATA => write!(f, "SSH_MSG_CHANNEL_DATA"),
SshMessageID::SSH_MSG_CHANNEL_EXTENDED_DATA => {
write!(f, "SSH_MSG_CHANNEL_EXTENDED_DATA")
}
SshMessageID::SSH_MSG_CHANNEL_EOF => write!(f, "SSH_MSG_CHANNEL_EOF"),
SshMessageID::SSH_MSG_CHANNEL_CLOSE => write!(f, "SSH_MSG_CHANNEL_CLOSE"),
SshMessageID::SSH_MSG_CHANNEL_REQUEST => write!(f, "SSH_MSG_CHANNEL_REQUEST"),
SshMessageID::SSH_MSG_CHANNEL_SUCCESS => write!(f, "SSH_MSG_CHANNEL_SUCCESS"),
SshMessageID::SSH_MSG_CHANNEL_FAILURE => write!(f, "SSH_MSG_CHANNEL_FAILURE"),
SshMessageID::Unknown(x) => write!(f, "SSH_MSG_UNKNOWN{}", x),
}
}
}
#[test]
fn no_duplicate_messages() {
let mut found = std::collections::HashSet::new();
for i in u8::MIN..=u8::MAX {
let id = SshMessageID::from_primitive(i);
let display = id.to_string();
assert!(!found.contains(&display));
found.insert(display);
}
}
impl From<SshMessageID> for OperationalError {
fn from(message: SshMessageID) -> Self {
match message {
SshMessageID::SSH_MSG_DISCONNECT => OperationalError::Disconnect,
SshMessageID::SSH_MSG_USERAUTH_FAILURE => OperationalError::UserAuthFailed,
SshMessageID::SSH_MSG_REQUEST_FAILURE => OperationalError::RequestFailed,
SshMessageID::SSH_MSG_CHANNEL_OPEN_FAILURE => OperationalError::OpenChannelFailure,
SshMessageID::SSH_MSG_CHANNEL_EOF => OperationalError::OtherEof,
SshMessageID::SSH_MSG_CHANNEL_CLOSE => OperationalError::OtherClosed,
SshMessageID::SSH_MSG_CHANNEL_FAILURE => OperationalError::ChannelFailure,
_ => OperationalError::UnexpectedMessage { message },
}
}
}
impl TryFrom<OperationalError> for SshMessageID {
type Error = OperationalError;
fn try_from(value: OperationalError) -> Result<Self, Self::Error> {
match value {
OperationalError::Disconnect => Ok(SshMessageID::SSH_MSG_DISCONNECT),
OperationalError::UserAuthFailed => Ok(SshMessageID::SSH_MSG_USERAUTH_FAILURE),
OperationalError::RequestFailed => Ok(SshMessageID::SSH_MSG_REQUEST_FAILURE),
OperationalError::OpenChannelFailure => Ok(SshMessageID::SSH_MSG_CHANNEL_OPEN_FAILURE),
OperationalError::OtherEof => Ok(SshMessageID::SSH_MSG_CHANNEL_EOF),
OperationalError::OtherClosed => Ok(SshMessageID::SSH_MSG_CHANNEL_CLOSE),
OperationalError::ChannelFailure => Ok(SshMessageID::SSH_MSG_CHANNEL_FAILURE),
OperationalError::UnexpectedMessage { message } => Ok(message),
_ => Err(value),
}
}
}
impl Arbitrary for SshMessageID {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
proptest::prop_oneof![
Just(SshMessageID::SSH_MSG_DISCONNECT),
Just(SshMessageID::SSH_MSG_IGNORE),
Just(SshMessageID::SSH_MSG_UNIMPLEMENTED),
Just(SshMessageID::SSH_MSG_DEBUG),
Just(SshMessageID::SSH_MSG_SERVICE_REQUEST),
Just(SshMessageID::SSH_MSG_SERVICE_ACCEPT),
Just(SshMessageID::SSH_MSG_KEXINIT),
Just(SshMessageID::SSH_MSG_NEWKEYS),
Just(SshMessageID::SSH_MSG_USERAUTH_REQUEST),
Just(SshMessageID::SSH_MSG_USERAUTH_FAILURE),
Just(SshMessageID::SSH_MSG_USERAUTH_SUCCESS),
Just(SshMessageID::SSH_MSG_USERAUTH_BANNER),
Just(SshMessageID::SSH_MSG_GLOBAL_REQUEST),
Just(SshMessageID::SSH_MSG_REQUEST_SUCCESS),
Just(SshMessageID::SSH_MSG_REQUEST_FAILURE),
Just(SshMessageID::SSH_MSG_CHANNEL_OPEN),
Just(SshMessageID::SSH_MSG_CHANNEL_OPEN_CONFIRMATION),
Just(SshMessageID::SSH_MSG_CHANNEL_OPEN_FAILURE),
Just(SshMessageID::SSH_MSG_CHANNEL_WINDOW_ADJUST),
Just(SshMessageID::SSH_MSG_CHANNEL_DATA),
Just(SshMessageID::SSH_MSG_CHANNEL_EXTENDED_DATA),
Just(SshMessageID::SSH_MSG_CHANNEL_EOF),
Just(SshMessageID::SSH_MSG_CHANNEL_CLOSE),
Just(SshMessageID::SSH_MSG_CHANNEL_REQUEST),
Just(SshMessageID::SSH_MSG_CHANNEL_SUCCESS),
Just(SshMessageID::SSH_MSG_CHANNEL_FAILURE),
]
.boxed()
}
}
proptest::proptest! {
#[test]
fn error_encodings_invert(message in SshMessageID::arbitrary()) {
let error_version = OperationalError::from(message);
let back_to_message = SshMessageID::try_from(error_version).unwrap();
assert_eq!(message, back_to_message);
}
}

View File

@@ -0,0 +1,51 @@
use crate::{SshKeyExchangeProcessingError, SshMessageID};
use thiserror::Error;
#[derive(Debug, Error)]
pub enum OperationalError {
#[error("Configuration error")]
ConfigurationError,
#[error("DNS client configuration error")]
DnsConfig,
#[error("Failed to connect to target address")]
Connection,
#[error("Failure during key exchange / agreement protocol")]
KeyExchange,
#[error("Failed to complete initial read: {0}")]
InitialRead(std::io::Error),
#[error("SSH banner was not formatted in UTF-8: {0}")]
BannerError(std::str::Utf8Error),
#[error("Invalid initial SSH versionling line: {line}")]
InvalidHeaderLine { line: String },
#[error("Error writing initial banner: {0}")]
WriteBanner(std::io::Error),
#[error("Unexpected disconnect from other side.")]
Disconnect,
#[error("{message} in unexpected place.")]
UnexpectedMessage { message: SshMessageID },
#[error("User authorization failed.")]
UserAuthFailed,
#[error("Request failed.")]
RequestFailed,
#[error("Failed to open channel.")]
OpenChannelFailure,
#[error("Other side closed connection.")]
OtherClosed,
#[error("Other side sent EOF.")]
OtherEof,
#[error("Channel failed.")]
ChannelFailure,
#[error("Error in initial handshake: {0}")]
KeyxProcessingError(#[from] SshKeyExchangeProcessingError),
#[error("Invalid port number '{port_string}': {error}")]
InvalidPort {
port_string: String,
error: std::num::ParseIntError,
},
#[error("Invalid hostname '{0}'")]
InvalidHostname(String),
#[error("Unable to parse host address")]
UnableToParseHostAddress,
#[error("Unable to configure resolver")]
Resolver,
}

3
ssh/src/packets.rs Normal file
View File

@@ -0,0 +1,3 @@
mod key_exchange;
pub use key_exchange::{SshKeyExchange, SshKeyExchangeProcessingError};

View File

@@ -0,0 +1,332 @@
use configuration::connection::ClientConnectionOpts;
use crate::channel::SshPacket;
use crate::message_ids::SshMessageID;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use itertools::Itertools;
use proptest::arbitrary::Arbitrary;
use proptest::strategy::{BoxedStrategy, Strategy};
use rand::{CryptoRng, Rng, SeedableRng};
use std::string::FromUtf8Error;
use thiserror::Error;
#[derive(Clone, Debug, PartialEq)]
pub struct SshKeyExchange {
cookie: [u8; 16],
keyx_algorithms: Vec<String>,
server_host_key_algorithms: Vec<String>,
encryption_algorithms_client_to_server: Vec<String>,
encryption_algorithms_server_to_client: Vec<String>,
mac_algorithms_client_to_server: Vec<String>,
mac_algorithms_server_to_client: Vec<String>,
compression_algorithms_client_to_server: Vec<String>,
compression_algorithms_server_to_client: Vec<String>,
languages_client_to_server: Vec<String>,
languages_server_to_client: Vec<String>,
first_kex_packet_follows: bool,
}
#[derive(Debug, Error)]
pub enum SshKeyExchangeProcessingError {
#[error("Message not appropriately tagged as SSH_MSG_KEXINIT")]
TaggedWrong,
#[error("Initial key exchange message was too short.")]
TooShort,
#[error("Invalid string encoding for name-list (not ASCII)")]
NotAscii,
#[error("Invalid conversion (from ASCII to UTF-8??): {0}")]
NotUtf8(FromUtf8Error),
#[error("Extraneous data at the end of key exchange message")]
ExtraneousData,
#[error("Received invalid reserved word ({0} != 0)")]
InvalidReservedWord(u32),
}
impl TryFrom<SshPacket> for SshKeyExchange {
type Error = SshKeyExchangeProcessingError;
fn try_from(mut value: SshPacket) -> Result<Self, Self::Error> {
if SshMessageID::from(value.buffer.get_u8()) != SshMessageID::SSH_MSG_KEXINIT {
return Err(SshKeyExchangeProcessingError::TaggedWrong);
}
let mut cookie = [0; 16];
check_length(&mut value.buffer, 16)?;
value.buffer.copy_to_slice(&mut cookie);
let keyx_algorithms = name_list(&mut value.buffer)?;
let server_host_key_algorithms = name_list(&mut value.buffer)?;
let encryption_algorithms_client_to_server = name_list(&mut value.buffer)?;
let encryption_algorithms_server_to_client = name_list(&mut value.buffer)?;
let mac_algorithms_client_to_server = name_list(&mut value.buffer)?;
let mac_algorithms_server_to_client = name_list(&mut value.buffer)?;
let compression_algorithms_client_to_server = name_list(&mut value.buffer)?;
let compression_algorithms_server_to_client = name_list(&mut value.buffer)?;
let languages_client_to_server = name_list(&mut value.buffer)?;
let languages_server_to_client = name_list(&mut value.buffer)?;
check_length(&mut value.buffer, 5)?;
let first_kex_packet_follows = value.buffer.get_u8() != 0;
let reserved = value.buffer.get_u32();
if reserved != 0 {
return Err(SshKeyExchangeProcessingError::InvalidReservedWord(reserved));
}
if value.buffer.remaining() > 0 {
return Err(SshKeyExchangeProcessingError::ExtraneousData);
}
Ok(SshKeyExchange {
cookie,
keyx_algorithms,
server_host_key_algorithms,
encryption_algorithms_client_to_server,
encryption_algorithms_server_to_client,
mac_algorithms_client_to_server,
mac_algorithms_server_to_client,
compression_algorithms_client_to_server,
compression_algorithms_server_to_client,
languages_client_to_server,
languages_server_to_client,
first_kex_packet_follows,
})
}
}
impl From<SshKeyExchange> for SshPacket {
fn from(value: SshKeyExchange) -> Self {
let mut buffer = BytesMut::new();
let put_options = |buffer: &mut BytesMut, vals: Vec<String>| {
let mut merged = String::new();
#[allow(unstable_name_collisions)]
let comma_sepped = vals.into_iter().intersperse(String::from(","));
merged.extend(comma_sepped);
let bytes = merged.as_bytes();
buffer.put_u32(bytes.len() as u32);
buffer.put_slice(bytes);
};
buffer.put_u8(SshMessageID::SSH_MSG_KEXINIT.into());
buffer.put_slice(&value.cookie);
put_options(&mut buffer, value.keyx_algorithms);
put_options(&mut buffer, value.server_host_key_algorithms);
put_options(&mut buffer, value.encryption_algorithms_client_to_server);
put_options(&mut buffer, value.encryption_algorithms_server_to_client);
put_options(&mut buffer, value.mac_algorithms_client_to_server);
put_options(&mut buffer, value.mac_algorithms_server_to_client);
put_options(&mut buffer, value.compression_algorithms_client_to_server);
put_options(&mut buffer, value.compression_algorithms_server_to_client);
put_options(&mut buffer, value.languages_client_to_server);
put_options(&mut buffer, value.languages_server_to_client);
buffer.put_u8(value.first_kex_packet_follows as u8);
buffer.put_u32(0);
SshPacket {
buffer: buffer.freeze(),
}
}
}
impl SshKeyExchange {
/// Create a new SshKeyExchange message for this client or server based
/// on the given connection options.
///
/// This function takes a random number generator because it needs to
/// seed the message with a random cookie, but is otherwise deterministic.
/// It will fail only in the case that the underlying random number
/// generator fails, and return exactly that error.
pub fn new<R>(rng: &mut R, value: ClientConnectionOpts) -> Result<Self, ()>
where
R: CryptoRng + Rng,
{
let mut result = SshKeyExchange {
cookie: [0; 16],
keyx_algorithms: value
.key_exchange_algorithms
.iter()
.map(|x| x.to_string())
.collect(),
server_host_key_algorithms: value
.server_host_key_algorithms
.iter()
.map(|x| x.to_string())
.collect(),
encryption_algorithms_client_to_server: value
.encryption_algorithms
.iter()
.map(|x| x.to_string())
.collect(),
encryption_algorithms_server_to_client: value
.encryption_algorithms
.iter()
.map(|x| x.to_string())
.collect(),
mac_algorithms_client_to_server: value
.mac_algorithms
.iter()
.map(|x| x.to_string())
.collect(),
mac_algorithms_server_to_client: value
.mac_algorithms
.iter()
.map(|x| x.to_string())
.collect(),
compression_algorithms_client_to_server: value
.compression_algorithms
.iter()
.map(|x| x.to_string())
.collect(),
compression_algorithms_server_to_client: value
.compression_algorithms
.iter()
.map(|x| x.to_string())
.collect(),
languages_client_to_server: value.languages.to_vec(),
languages_server_to_client: value.languages.to_vec(),
first_kex_packet_follows: value.predict.is_some(),
};
rng.fill(&mut result.cookie);
Ok(result)
}
}
impl Arbitrary for SshKeyExchange {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_: Self::Parameters) -> BoxedStrategy<Self> {
let client_config = ClientConnectionOpts::arbitrary();
let seed = <[u8; 32]>::arbitrary();
(client_config, seed)
.prop_map(|(config, seed)| {
let mut rng = rand_chacha::ChaCha20Rng::from_seed(seed);
SshKeyExchange::new(&mut rng, config).unwrap()
})
.boxed()
}
}
proptest::proptest! {
#[test]
fn valid_kex_messages_parse(kex in SshKeyExchange::arbitrary()) {
let as_packet: SshPacket = kex.clone().try_into().expect("can generate packet");
let as_message = as_packet.try_into().expect("can regenerate message");
assert_eq!(kex, as_message);
}
}
fn check_length(buffer: &mut Bytes, length: usize) -> Result<(), SshKeyExchangeProcessingError> {
if buffer.remaining() < length {
Err(SshKeyExchangeProcessingError::TooShort)
} else {
Ok(())
}
}
fn name_list(buffer: &mut Bytes) -> Result<Vec<String>, SshKeyExchangeProcessingError> {
check_length(buffer, 4)?;
let list_length = buffer.get_u32() as usize;
if list_length == 0 {
return Ok(vec![]);
}
check_length(buffer, list_length)?;
let mut raw_bytes = vec![0u8; list_length];
buffer.copy_to_slice(&mut raw_bytes);
if !raw_bytes.iter().all(|c| c.is_ascii()) {
return Err(SshKeyExchangeProcessingError::NotAscii);
}
let mut result = Vec::new();
for split in raw_bytes.split(|b| char::from(*b) == ',') {
let str =
String::from_utf8(split.to_vec()).map_err(SshKeyExchangeProcessingError::NotUtf8)?;
result.push(str);
}
Ok(result)
}
#[cfg(test)]
fn standard_kex_message() -> SshPacket {
let seed = [0u8; 32];
let mut rng = rand_chacha::ChaCha20Rng::from_seed(seed);
let config = ClientConnectionOpts::default();
let message = SshKeyExchange::new(&mut rng, config).expect("default settings work");
message.try_into().expect("default settings serialize")
}
#[test]
fn can_see_bad_tag() {
let standard = standard_kex_message();
let mut buffer = standard.buffer.to_vec();
buffer[0] += 1;
let bad = SshPacket {
buffer: buffer.into(),
};
assert!(matches!(
SshKeyExchange::try_from(bad),
Err(SshKeyExchangeProcessingError::TaggedWrong)
));
}
#[test]
fn checks_for_extraneous_data() {
let standard = standard_kex_message();
let mut buffer = standard.buffer.to_vec();
buffer.push(3);
let bad = SshPacket {
buffer: buffer.into(),
};
assert!(matches!(
SshKeyExchange::try_from(bad),
Err(SshKeyExchangeProcessingError::ExtraneousData)
));
}
#[test]
fn checks_for_short_packets() {
let standard = standard_kex_message();
let mut buffer = standard.buffer.to_vec();
let _ = buffer.pop();
let bad = SshPacket {
buffer: buffer.into(),
};
assert!(matches!(
SshKeyExchange::try_from(bad),
Err(SshKeyExchangeProcessingError::TooShort)
));
}
#[test]
fn checks_for_invalid_data() {
let standard = standard_kex_message();
let mut buffer = standard.buffer.to_vec();
buffer[22] = 0xc3;
buffer[23] = 0x28;
let bad = SshPacket {
buffer: buffer.into(),
};
assert!(matches!(
SshKeyExchange::try_from(bad),
Err(SshKeyExchangeProcessingError::NotAscii)
));
}
#[test]
fn checks_for_bad_reserved_word() {
let standard = standard_kex_message();
let mut buffer = standard.buffer.to_vec();
let _ = buffer.pop();
buffer.push(1);
let bad = SshPacket {
buffer: buffer.into(),
};
assert!(matches!(
SshKeyExchange::try_from(bad),
Err(SshKeyExchangeProcessingError::InvalidReservedWord(1))
));
}

392
ssh/src/preamble.rs Normal file
View File

@@ -0,0 +1,392 @@
use error_stack::{report, ResultExt};
use proptest::arbitrary::Arbitrary;
use proptest::strategy::{BoxedStrategy, Strategy};
use thiserror::Error;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
#[derive(Debug, PartialEq)]
pub struct Preamble {
pub preamble: String,
pub software_name: String,
pub software_version: String,
pub commentary: String,
}
impl Arbitrary for Preamble {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with((): Self::Parameters) -> Self::Strategy {
let name = proptest::string::string_regex("[[:alpha:]][[:alnum:]]{0,32}").unwrap();
let soft_major = u8::arbitrary();
let soft_minor = u8::arbitrary();
let soft_patch = proptest::option::of(u8::arbitrary());
let commentary = proptest::option::of(
proptest::string::string_regex("[[:alnum:]][[[:alnum:]][[:blank:]][[:punct:]]]{0,64}")
.unwrap(),
);
(name, soft_major, soft_minor, soft_patch, commentary)
.prop_map(|(name, major, minor, patch, commentary)| Preamble {
preamble: String::new(),
software_name: name,
software_version: if let Some(patch) = patch {
format!("{}.{}.{}", major, minor, patch)
} else {
format!("{}.{}", major, minor)
},
commentary: commentary.unwrap_or_default(),
})
.boxed()
}
}
impl Default for Preamble {
fn default() -> Self {
Preamble {
preamble: String::new(),
software_name: env!("CARGO_PKG_NAME").to_string(),
software_version: env!("CARGO_PKG_VERSION").to_string(),
commentary: String::new(),
}
}
}
#[derive(Debug, Error)]
pub enum PreambleReadError {
#[error("Reading from the input stream failed")]
Read,
#[error("Illegal version number, expected '2.0' (saw characters {0})")]
IllegalVersion(String),
#[error("No dash found after seeing SSH version number")]
NoDashAfterVersion,
#[error("Illegal character in SSH software name")]
IllegalSoftwareNameChar,
#[error("Protocol error in preamble: No line feed for carriage return")]
NoLineFeedForCarriage,
#[error("Missing the final newline in the preamble")]
MissingFinalNewline,
#[error("Illegal UTF-8 in software name")]
InvalidSoftwareName,
}
#[derive(Debug, Error)]
pub enum PreambleWriteError {
#[error("Could not write preamble to socket")]
Write,
}
#[derive(Debug)]
enum PreambleState {
StartOfLine,
Preamble,
CarriageReturn,
InitialS,
SecondS,
InitialH,
InitialDash,
Version2,
VersionDot,
Version0,
VersionDash,
SoftwareName,
SoftwareVersion,
Commentary,
FinalCarriageReturn,
}
impl Preamble {
/// Read an SSH preamble from the given read channel.
///
/// Will fail if the underlying read channel fails, or if the preamble does not
/// meet the formatting requirements of the RFC.
pub async fn read<R: AsyncReadExt + Unpin>(
connection: &mut R,
) -> error_stack::Result<Preamble, PreambleReadError> {
let mut preamble = String::new();
let mut software_name_bytes = Vec::new();
let mut software_version = String::new();
let mut commentary = String::new();
let mut state = PreambleState::StartOfLine;
loop {
let next_byte = connection
.read_u8()
.await
.change_context(PreambleReadError::Read)?;
let next_char = char::from(next_byte);
tracing::trace!(?next_char, ?state, "processing next preamble character");
match state {
PreambleState::StartOfLine => match next_char {
'S' => state = PreambleState::InitialS,
_ => {
preamble.push(next_char);
state = PreambleState::Preamble;
}
},
PreambleState::Preamble => match next_char {
'\r' => state = PreambleState::CarriageReturn,
_ => preamble.push(next_char),
},
PreambleState::CarriageReturn => match next_char {
'\n' => state = PreambleState::StartOfLine,
_ => return Err(report!(PreambleReadError::NoLineFeedForCarriage)),
},
PreambleState::InitialS => match next_char {
'S' => state = PreambleState::SecondS,
_ => {
preamble.push('S');
preamble.push(next_char);
state = PreambleState::Preamble;
}
},
PreambleState::SecondS => match next_char {
'H' => state = PreambleState::InitialH,
_ => {
preamble.push_str("SS");
preamble.push(next_char);
state = PreambleState::Preamble;
}
},
PreambleState::InitialH => match next_char {
'-' => state = PreambleState::InitialDash,
_ => {
preamble.push_str("SSH");
preamble.push(next_char);
state = PreambleState::InitialDash;
}
},
PreambleState::InitialDash => match next_char {
'2' => state = PreambleState::Version2,
_ => {
return Err(report!(PreambleReadError::IllegalVersion(String::from(
next_char
))))
}
},
PreambleState::Version2 => match next_char {
'.' => state = PreambleState::VersionDot,
_ => {
return Err(report!(PreambleReadError::IllegalVersion(format!(
"2{}",
next_char
))))
}
},
PreambleState::VersionDot => match next_char {
'0' => state = PreambleState::Version0,
_ => {
return Err(report!(PreambleReadError::IllegalVersion(format!(
"2.{}",
next_char
))))
}
},
PreambleState::Version0 => match next_char {
'-' => state = PreambleState::VersionDash,
_ => return Err(report!(PreambleReadError::NoDashAfterVersion)),
},
PreambleState::VersionDash => {
software_name_bytes.push(next_byte);
state = PreambleState::SoftwareName;
}
PreambleState::SoftwareName => match next_char {
'_' => state = PreambleState::SoftwareVersion,
x if x == '-' || x.is_ascii_whitespace() => {
return Err(report!(PreambleReadError::IllegalSoftwareNameChar))
}
_ => software_name_bytes.push(next_byte),
},
PreambleState::SoftwareVersion => match next_char {
' ' => state = PreambleState::Commentary,
'\r' => state = PreambleState::FinalCarriageReturn,
'_' => state = PreambleState::SoftwareVersion,
x if x == '-' || x.is_ascii_whitespace() => {
return Err(report!(PreambleReadError::IllegalSoftwareNameChar))
.attach_printable_lazy(|| {
format!("saw {:?} / {}", next_char, next_byte)
})?
}
_ => software_version.push(next_char),
},
PreambleState::FinalCarriageReturn => match next_char {
'\n' => break,
_ => return Err(report!(PreambleReadError::MissingFinalNewline)),
},
PreambleState::Commentary => match next_char {
'\r' => state = PreambleState::FinalCarriageReturn,
_ => commentary.push(next_char),
},
}
}
let software_name = String::from_utf8(software_name_bytes)
.change_context(PreambleReadError::InvalidSoftwareName)?;
Ok(Preamble {
preamble,
software_name,
software_version,
commentary,
})
}
// let mut read_buffer = vec![0; 4096];
// let mut pre_message = String::new();
// let protocol_version;
// let software_name;
// let software_version;
// let commentary;
// let mut prefix = String::new();
//
//
// 'outer: loop {
// let read_length = connection
// .read(&mut read_buffer)
// .await
// .change_context(PreambleReadError::InitialRead)?;
// let string_version = std::str::from_utf8(&read_buffer[0..read_length])
// .change_context(PreambleReadError::BannerError)?;
//
// prefix.push_str(string_version);
// let ends_with_newline = prefix.ends_with("\r\n");
//
// let new_prefix = if ends_with_newline {
// // we are cleanly reading up to a \r\n, so our new prefix after
// // this loop is empty
// String::new()
// } else if let Some((correct_bits, leftover)) = prefix.rsplit_once("\r\n") {
// // there's some dangling bits in this read, so we'll cut this string
// // at the final "\r\n" and then remember to use the leftover as the
// // new prefix at the end of this loop
// let result = leftover.to_string();
// prefix = correct_bits.to_string();
// result
// } else {
// // there's no "\r\n", so we don't have a full line yet, so keep reading
// continue;
// };
//
// for line in prefix.lines() {
// if line.starts_with("SSH") {
// let (_, interesting_bits) = line
// .split_once('-')
// .ok_or_else(||
// report!(PreambleReadError::InvalidHeaderLine {
// reason: "could not find dash after SSH",
// line: line.to_string(),
// }))?;
//
// let (protoversion, other_bits) = interesting_bits
// .split_once('-')
// .ok_or_else(|| report!(PreambleReadError::InvalidHeaderLine {
// reason: "could not find dash after protocol version",
// line: line.to_string(),
// }))?;
//
// let (softwarever, comment) = match other_bits.split_once(' ') {
// Some((s, c)) => (s, c),
// None => (other_bits, ""),
// };
//
// let (software_name_str, software_version_str) = softwarever
// .split_once('_')
// .ok_or_else(|| report!(PreambleReadError::InvalidHeaderLine {
// reason: "could not find underscore between software name and version",
// line: line.to_string(),
// }))?;
//
// software_name = software_name_str.to_string();
// software_version = software_version_str.to_string();
// protocol_version = protoversion.to_string();
// commentary = comment.to_string();
// break 'outer;
// } else {
// pre_message.push_str(line);
// pre_message.push('\n');
// }
// }
//
// prefix = new_prefix;
// }
//
// tracing::info!(
// ?protocol_version,
// ?software_version,
// ?commentary,
// "Got server information"
// );
//
// Ok(Preamble {
// protocol_version,
// software_name,
// software_version,
// commentary,
// })
// }
/// Write a preamble to the given network socket.
pub async fn write<W: AsyncWriteExt + Unpin>(
&self,
connection: &mut W,
) -> error_stack::Result<(), PreambleWriteError> {
let comment = if self.commentary.is_empty() {
self.commentary.clone()
} else {
format!(" {}", self.commentary)
};
let output = format!(
"SSH-2.0-{}_{}{}\r\n",
self.software_name, self.software_version, comment
);
connection
.write_all(output.as_bytes())
.await
.change_context(PreambleWriteError::Write)
}
}
proptest::proptest! {
#[test]
fn preamble_roundtrips(preamble in Preamble::arbitrary()) {
let read_version = tokio::runtime::Runtime::new().unwrap().block_on(async {
let (mut writer, mut reader) = tokio::io::duplex(4096);
preamble.write(&mut writer).await.unwrap();
Preamble::read(&mut reader).await.unwrap()
});
assert_eq!(preamble, read_version);
}
#[test]
fn preamble_read_is_thrifty(preamble in Preamble::arbitrary(), b in u8::arbitrary()) {
let (read_version, next) = tokio::runtime::Runtime::new().unwrap().block_on(async {
let (mut writer, mut reader) = tokio::io::duplex(4096);
preamble.write(&mut writer).await.unwrap();
writer.write_u8(b).await.unwrap();
writer.flush().await.unwrap();
drop(writer);
let preamble = Preamble::read(&mut reader).await.unwrap();
let next = reader.read_u8().await.unwrap();
(preamble, next)
});
assert_eq!(preamble, read_version);
assert_eq!(b, next);
}
}