diff --git a/src/crypto.rs b/src/crypto.rs index 4c9060c..1282e51 100644 --- a/src/crypto.rs +++ b/src/crypto.rs @@ -6,31 +6,32 @@ use hmac::{Hmac, Mac}; use sha2::Sha256; pub fn aes_256_cbc_encrypt(ptext: &[u8], key: &[u8], iv: &[u8]) -> Result> { - if key.len() != 32 { - return Err(SignalProtocolError::InvalidCipherKeyLength(key.len())); + match Cbc::::new_var(key, iv) { + Ok(mode) => Ok(mode.encrypt_vec(&ptext)), + Err(block_modes::InvalidKeyIvLength) => { + return Err(SignalProtocolError::InvalidCipherCryptographicParameters( + key.len(), + iv.len(), + )) + } } - if iv.len() != 16 { - return Err(SignalProtocolError::InvalidCipherNonceLength(iv.len())); - } - - let mode = Cbc::::new_var(key, iv) - .map_err(|e| SignalProtocolError::InvalidArgument(format!("{}", e)))?; - Ok(mode.encrypt_vec(&ptext)) } pub fn aes_256_cbc_decrypt(ctext: &[u8], key: &[u8], iv: &[u8]) -> Result> { - if key.len() != 32 { - return Err(SignalProtocolError::InvalidCipherKeyLength(key.len())); - } - if iv.len() != 16 { - return Err(SignalProtocolError::InvalidCipherNonceLength(iv.len())); - } if ctext.len() == 0 || ctext.len() % 16 != 0 { return Err(SignalProtocolError::InvalidCiphertext); } - let mode = Cbc::::new_var(key, iv) - .map_err(|e| SignalProtocolError::InvalidArgument(format!("{}", e)))?; + let mode = match Cbc::::new_var(key, iv) { + Ok(mode) => mode, + Err(block_modes::InvalidKeyIvLength) => { + return Err(SignalProtocolError::InvalidCipherCryptographicParameters( + key.len(), + iv.len(), + )) + } + }; + Ok(mode .decrypt_vec(ctext) .map_err(|_| SignalProtocolError::InvalidCiphertext)?) diff --git a/src/error.rs b/src/error.rs index cf9fc1f..e221c6c 100644 --- a/src/error.rs +++ b/src/error.rs @@ -42,9 +42,8 @@ pub enum SignalProtocolError { InvalidRootKeyLength(usize), InvalidChainKeyLength(usize), - InvalidCipherKeyLength(usize), InvalidMacKeyLength(usize), - InvalidCipherNonceLength(usize), + InvalidCipherCryptographicParameters(usize, usize), InvalidCiphertext, NoSenderKeyState, @@ -137,15 +136,14 @@ impl fmt::Display for SignalProtocolError { SignalProtocolError::InvalidRootKeyLength(l) => { write!(f, "invalid root key length <{}>", l) } - SignalProtocolError::InvalidCipherKeyLength(l) => { - write!(f, "invalid cipher key length <{}>", l) - } + SignalProtocolError::InvalidCipherCryptographicParameters(kl, nl) => write!( + f, + "invalid cipher key length <{}> or nonce length <{}>", + kl, nl + ), SignalProtocolError::InvalidMacKeyLength(l) => { write!(f, "invalid MAC key length <{}>", l) } - SignalProtocolError::InvalidCipherNonceLength(l) => { - write!(f, "invalid cipher nonce length <{}>", l) - } SignalProtocolError::UntrustedIdentity(addr) => { write!(f, "untrusted identity for address {}", addr) } diff --git a/src/ratchet/keys.rs b/src/ratchet/keys.rs index ffb35a1..8399dcc 100644 --- a/src/ratchet/keys.rs +++ b/src/ratchet/keys.rs @@ -24,16 +24,14 @@ impl MessageKeys { } pub fn new(cipher_key: &[u8], mac_key: &[u8], iv: &[u8], counter: u32) -> Result { - if cipher_key.len() != 32 { - return Err(SignalProtocolError::InvalidCipherKeyLength( - cipher_key.len(), - )); - } if mac_key.len() != 32 { return Err(SignalProtocolError::InvalidMacKeyLength(mac_key.len())); } - if iv.len() != 16 { - return Err(SignalProtocolError::InvalidCipherNonceLength(iv.len())); + if cipher_key.len() != 32 || iv.len() != 16 { + return Err(SignalProtocolError::InvalidCipherCryptographicParameters( + cipher_key.len(), + iv.len(), + )); } Ok(MessageKeys {