use crate::config::connection::ClientConnectionOpts; use crate::ssh::channel::SshPacket; use crate::ssh::message_ids::SshMessageID; use bytes::{Buf, BufMut, Bytes, BytesMut}; use itertools::Itertools; use proptest::arbitrary::Arbitrary; use proptest::strategy::{BoxedStrategy, Strategy}; use rand::{CryptoRng, Rng, SeedableRng}; use std::string::FromUtf8Error; use thiserror::Error; #[derive(Clone, Debug, PartialEq)] pub struct SshKeyExchange { cookie: [u8; 16], keyx_algorithms: Vec, server_host_key_algorithms: Vec, encryption_algorithms_client_to_server: Vec, encryption_algorithms_server_to_client: Vec, mac_algorithms_client_to_server: Vec, mac_algorithms_server_to_client: Vec, compression_algorithms_client_to_server: Vec, compression_algorithms_server_to_client: Vec, languages_client_to_server: Vec, languages_server_to_client: Vec, first_kex_packet_follows: bool, } #[derive(Debug, Error)] pub enum SshKeyExchangeProcessingError { #[error("Message not appropriately tagged as SSH_MSG_KEXINIT")] TaggedWrong, #[error("Initial key exchange message was too short.")] TooShort, #[error("Invalid string encoding for name-list (not ASCII)")] NotAscii, #[error("Invalid conversion (from ASCII to UTF-8??): {0}")] NotUtf8(FromUtf8Error), #[error("Extraneous data at the end of key exchange message")] ExtraneousData, #[error("Received invalid reserved word ({0} != 0)")] InvalidReservedWord(u32), } impl TryFrom for SshKeyExchange { type Error = SshKeyExchangeProcessingError; fn try_from(mut value: SshPacket) -> Result { if SshMessageID::from(value.buffer.get_u8()) != SshMessageID::SSH_MSG_KEXINIT { return Err(SshKeyExchangeProcessingError::TaggedWrong); } let mut cookie = [0; 16]; check_length(&mut value.buffer, 16)?; value.buffer.copy_to_slice(&mut cookie); let keyx_algorithms = name_list(&mut value.buffer)?; let server_host_key_algorithms = name_list(&mut value.buffer)?; let encryption_algorithms_client_to_server = name_list(&mut value.buffer)?; let encryption_algorithms_server_to_client = name_list(&mut value.buffer)?; let mac_algorithms_client_to_server = name_list(&mut value.buffer)?; let mac_algorithms_server_to_client = name_list(&mut value.buffer)?; let compression_algorithms_client_to_server = name_list(&mut value.buffer)?; let compression_algorithms_server_to_client = name_list(&mut value.buffer)?; let languages_client_to_server = name_list(&mut value.buffer)?; let languages_server_to_client = name_list(&mut value.buffer)?; check_length(&mut value.buffer, 5)?; let first_kex_packet_follows = value.buffer.get_u8() != 0; let reserved = value.buffer.get_u32(); if reserved != 0 { return Err(SshKeyExchangeProcessingError::InvalidReservedWord(reserved)); } if value.buffer.remaining() > 0 { return Err(SshKeyExchangeProcessingError::ExtraneousData); } Ok(SshKeyExchange { cookie, keyx_algorithms, server_host_key_algorithms, encryption_algorithms_client_to_server, encryption_algorithms_server_to_client, mac_algorithms_client_to_server, mac_algorithms_server_to_client, compression_algorithms_client_to_server, compression_algorithms_server_to_client, languages_client_to_server, languages_server_to_client, first_kex_packet_follows, }) } } impl From for SshPacket { fn from(value: SshKeyExchange) -> Self { let mut buffer = BytesMut::new(); let put_options = |buffer: &mut BytesMut, vals: Vec| { let mut merged = String::new(); #[allow(unstable_name_collisions)] let comma_sepped = vals.into_iter().intersperse(String::from(",")); merged.extend(comma_sepped); let bytes = merged.as_bytes(); buffer.put_u32(bytes.len() as u32); buffer.put_slice(bytes); }; buffer.put_u8(SshMessageID::SSH_MSG_KEXINIT.into()); buffer.put_slice(&value.cookie); put_options(&mut buffer, value.keyx_algorithms); put_options(&mut buffer, value.server_host_key_algorithms); put_options(&mut buffer, value.encryption_algorithms_client_to_server); put_options(&mut buffer, value.encryption_algorithms_server_to_client); put_options(&mut buffer, value.mac_algorithms_client_to_server); put_options(&mut buffer, value.mac_algorithms_server_to_client); put_options(&mut buffer, value.compression_algorithms_client_to_server); put_options(&mut buffer, value.compression_algorithms_server_to_client); put_options(&mut buffer, value.languages_client_to_server); put_options(&mut buffer, value.languages_server_to_client); buffer.put_u8(value.first_kex_packet_follows as u8); buffer.put_u32(0); SshPacket { buffer: buffer.freeze(), } } } impl SshKeyExchange { /// Create a new SshKeyExchange message for this client or server based /// on the given connection options. /// /// This function takes a random number generator because it needs to /// seed the message with a random cookie, but is otherwise deterministic. /// It will fail only in the case that the underlying random number /// generator fails, and return exactly that error. pub fn new(rng: &mut R, value: ClientConnectionOpts) -> Result where R: CryptoRng + Rng, { let mut result = SshKeyExchange { cookie: [0; 16], keyx_algorithms: value .key_exchange_algorithms .iter() .map(|x| x.to_string()) .collect(), server_host_key_algorithms: value .server_host_key_algorithms .iter() .map(|x| x.to_string()) .collect(), encryption_algorithms_client_to_server: value .encryption_algorithms .iter() .map(|x| x.to_string()) .collect(), encryption_algorithms_server_to_client: value .encryption_algorithms .iter() .map(|x| x.to_string()) .collect(), mac_algorithms_client_to_server: value .mac_algorithms .iter() .map(|x| x.to_string()) .collect(), mac_algorithms_server_to_client: value .mac_algorithms .iter() .map(|x| x.to_string()) .collect(), compression_algorithms_client_to_server: value .compression_algorithms .iter() .map(|x| x.to_string()) .collect(), compression_algorithms_server_to_client: value .compression_algorithms .iter() .map(|x| x.to_string()) .collect(), languages_client_to_server: value.languages.to_vec(), languages_server_to_client: value.languages.to_vec(), first_kex_packet_follows: value.predict.is_some(), }; rng.try_fill(&mut result.cookie)?; Ok(result) } } impl Arbitrary for SshKeyExchange { type Parameters = (); type Strategy = BoxedStrategy; fn arbitrary_with(_: Self::Parameters) -> BoxedStrategy { let client_config = ClientConnectionOpts::arbitrary(); let seed = <[u8; 32]>::arbitrary(); (client_config, seed) .prop_map(|(config, seed)| { let mut rng = rand_chacha::ChaCha20Rng::from_seed(seed); SshKeyExchange::new(&mut rng, config).unwrap() }) .boxed() } } proptest::proptest! { #[test] fn valid_kex_messages_parse(kex in SshKeyExchange::arbitrary()) { let as_packet: SshPacket = kex.clone().try_into().expect("can generate packet"); let as_message = as_packet.try_into().expect("can regenerate message"); assert_eq!(kex, as_message); } } fn check_length(buffer: &mut Bytes, length: usize) -> Result<(), SshKeyExchangeProcessingError> { if buffer.remaining() < length { Err(SshKeyExchangeProcessingError::TooShort) } else { Ok(()) } } fn name_list(buffer: &mut Bytes) -> Result, SshKeyExchangeProcessingError> { check_length(buffer, 4)?; let list_length = buffer.get_u32() as usize; if list_length == 0 { return Ok(vec![]); } check_length(buffer, list_length)?; let mut raw_bytes = vec![0u8; list_length]; buffer.copy_to_slice(&mut raw_bytes); if !raw_bytes.iter().all(|c| c.is_ascii()) { return Err(SshKeyExchangeProcessingError::NotAscii); } let mut result = Vec::new(); for split in raw_bytes.split(|b| char::from(*b) == ',') { let str = String::from_utf8(split.to_vec()).map_err(SshKeyExchangeProcessingError::NotUtf8)?; result.push(str); } Ok(result) } #[cfg(test)] fn standard_kex_message() -> SshPacket { let seed = [0u8; 32]; let mut rng = rand_chacha::ChaCha20Rng::from_seed(seed); let config = ClientConnectionOpts::default(); let message = SshKeyExchange::new(&mut rng, config).expect("default settings work"); message.try_into().expect("default settings serialize") } #[test] fn can_see_bad_tag() { let standard = standard_kex_message(); let mut buffer = standard.buffer.to_vec(); buffer[0] += 1; let bad = SshPacket { buffer: buffer.into(), }; assert!(matches!( SshKeyExchange::try_from(bad), Err(SshKeyExchangeProcessingError::TaggedWrong) )); } #[test] fn checks_for_extraneous_data() { let standard = standard_kex_message(); let mut buffer = standard.buffer.to_vec(); buffer.push(3); let bad = SshPacket { buffer: buffer.into(), }; assert!(matches!( SshKeyExchange::try_from(bad), Err(SshKeyExchangeProcessingError::ExtraneousData) )); } #[test] fn checks_for_short_packets() { let standard = standard_kex_message(); let mut buffer = standard.buffer.to_vec(); let _ = buffer.pop(); let bad = SshPacket { buffer: buffer.into(), }; assert!(matches!( SshKeyExchange::try_from(bad), Err(SshKeyExchangeProcessingError::TooShort) )); } #[test] fn checks_for_invalid_data() { let standard = standard_kex_message(); let mut buffer = standard.buffer.to_vec(); buffer[22] = 0xc3; buffer[23] = 0x28; let bad = SshPacket { buffer: buffer.into(), }; assert!(matches!( SshKeyExchange::try_from(bad), Err(SshKeyExchangeProcessingError::NotAscii) )); } #[test] fn checks_for_bad_reserved_word() { let standard = standard_kex_message(); let mut buffer = standard.buffer.to_vec(); let _ = buffer.pop(); buffer.push(1); let bad = SshPacket { buffer: buffer.into(), }; assert!(matches!( SshKeyExchange::try_from(bad), Err(SshKeyExchangeProcessingError::InvalidReservedWord(1)) )); }