diff --git a/src/errors.rs b/src/errors.rs index 0ab6784..c997fd3 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -26,6 +26,8 @@ pub enum DeserializationError { InvalidClientCommand(u8), #[error("Invalid server status {0}; expected 0-8")] InvalidServerResponse(u8), + #[error("Invalid byte in reserved byte ({0})")] + InvalidReservedByte(u8), } #[test] @@ -86,6 +88,10 @@ impl PartialEq for DeserializationError { &DeserializationError::InvalidServerResponse(a), &DeserializationError::InvalidServerResponse(b), ) => a == b, + ( + &DeserializationError::InvalidReservedByte(a), + &DeserializationError::InvalidReservedByte(b), + ) => a == b, (_, _) => false, } } diff --git a/src/messages/client_command.rs b/src/messages/client_command.rs index a8beabc..44ac669 100644 --- a/src/messages/client_command.rs +++ b/src/messages/client_command.rs @@ -32,9 +32,9 @@ impl ClientConnectionRequest { pub async fn read( r: &mut R, ) -> Result { - let mut buffer = [0; 2]; + let mut buffer = [0; 3]; - read_amt(r, 2, &mut buffer).await?; + read_amt(r, 3, &mut buffer).await?; if buffer[0] != 5 { return Err(DeserializationError::InvalidVersion(5, buffer[0])); @@ -47,6 +47,10 @@ impl ClientConnectionRequest { x => return Err(DeserializationError::InvalidClientCommand(x)), }; + if buffer[2] != 0 { + return Err(DeserializationError::InvalidReservedByte(buffer[2])); + } + let destination_address = SOCKSv5Address::read(r).await?; read_amt(r, 2, &mut buffer).await?; @@ -69,7 +73,7 @@ impl ClientConnectionRequest { ClientConnectionCommand::AssociateUDPPort => 3, }; - w.write_all(&[5, command]).await?; + w.write_all(&[5, command, 0]).await?; self.destination_address.write(w).await?; w.write_all(&[ (self.destination_port >> 8) as u8,