diff --git a/boring/src/mlkem.rs b/boring/src/mlkem.rs index b10a388c..18569535 100644 --- a/boring/src/mlkem.rs +++ b/boring/src/mlkem.rs @@ -7,12 +7,11 @@ //! Provides ML-KEM-768 (recommended) and ML-KEM-1024 variants via [`MlKem`]. //! //! ``` -//! use boring::mlkem::{MlKem, MlKemParams}; +//! use boring::mlkem::MlKem; //! -//! let kem = MlKem::new(MlKemParams::MlKem768); -//! let (public_key, private_key) = kem.generate_key().unwrap(); -//! let (ciphertext, shared_secret) = kem.encapsulate(&public_key).unwrap(); -//! let decrypted = kem.decapsulate(&private_key, &ciphertext).unwrap(); +//! let (public_key, private_key) = MlKem::MlKem768.generate_key().unwrap(); +//! let (ciphertext, shared_secret) = MlKem::MlKem768.encapsulate(&public_key).unwrap(); +//! let decrypted = MlKem::MlKem768.decapsulate(&private_key, &ciphertext).unwrap(); //! assert_eq!(shared_secret, decrypted); //! ``` @@ -44,22 +43,31 @@ pub type MlKemPrivateKeySeed = [u8; PRIVATE_KEY_SEED_BYTES]; /// Raw bytes of the shared secret ([`SHARED_SECRET_BYTES`] long) pub type MlKemSharedSecret = [u8; SHARED_SECRET_BYTES]; -/// ML-KEM variant selection. +/// ML-KEM with runtime algorithm selection. Works with byte slices. +/// +/// ``` +/// use boring::mlkem::MlKem; +/// +/// let (public_key, private_key) = MlKem::MlKem768.generate_key().unwrap(); +/// let (ciphertext, shared_secret) = MlKem::MlKem768.encapsulate(&public_key).unwrap(); +/// let decrypted = kem.decapsulate(&private_key, &ciphertext).unwrap(); +/// assert_eq!(shared_secret, decrypted); +/// ``` #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum MlKemParams { +pub enum MlKem { /// Recommended. AES-192 equivalent security. MlKem768, /// AES-256 equivalent security. MlKem1024, } -impl MlKemParams { +impl MlKem { /// Returns 1184 for ML-KEM-768, 1568 for ML-KEM-1024. #[must_use] pub const fn public_key_bytes(&self) -> usize { match self { - MlKemParams::MlKem768 => MlKem768PublicKey::PUBLIC_KEY_BYTES, - MlKemParams::MlKem1024 => MlKem1024PublicKey::PUBLIC_KEY_BYTES, + Self::MlKem768 => MlKem768PublicKey::PUBLIC_KEY_BYTES, + Self::MlKem1024 => MlKem1024PublicKey::PUBLIC_KEY_BYTES, } } @@ -67,61 +75,21 @@ impl MlKemParams { #[must_use] pub const fn ciphertext_bytes(&self) -> usize { match self { - MlKemParams::MlKem768 => MlKem768PrivateKey::CIPHERTEXT_BYTES, - MlKemParams::MlKem1024 => MlKem1024PrivateKey::CIPHERTEXT_BYTES, + Self::MlKem768 => MlKem768PrivateKey::CIPHERTEXT_BYTES, + Self::MlKem1024 => MlKem1024PrivateKey::CIPHERTEXT_BYTES, } } -} - -/// ML-KEM with runtime algorithm selection. Works with byte slices. -/// -/// ``` -/// use boring::mlkem::{MlKem, MlKemParams}; -/// -/// let kem = MlKem::new(MlKemParams::MlKem768); -/// let (public_key, private_key) = kem.generate_key().unwrap(); -/// let (ciphertext, shared_secret) = kem.encapsulate(&public_key).unwrap(); -/// let decrypted = kem.decapsulate(&private_key, &ciphertext).unwrap(); -/// assert_eq!(shared_secret, decrypted); -/// ``` -#[derive(Debug, Clone, Copy)] -pub struct MlKem { - params: MlKemParams, -} - -impl MlKem { - /// Creates a new context for the given parameter set. - #[must_use] - pub fn new(params: MlKemParams) -> Self { - ffi::init(); - Self { params } - } - - #[must_use] - pub fn params(&self) -> MlKemParams { - self.params - } - - #[must_use] - pub fn public_key_bytes(&self) -> usize { - self.params.public_key_bytes() - } - - #[must_use] - pub fn ciphertext_bytes(&self) -> usize { - self.params.ciphertext_bytes() - } /// Generates a new key pair, returning `(public_key, private_key)`. /// /// The private key is a 64-byte seed. Keep it secret. pub fn generate_key(&self) -> Result<(Vec, MlKemPrivateKeySeed), ErrorStack> { - match self.params { - MlKemParams::MlKem768 => { + match self { + Self::MlKem768 => { let (sk, pk) = MlKem768PrivateKey::generate(); Ok((pk.bytes.to_vec(), sk.seed)) } - MlKemParams::MlKem1024 => { + Self::MlKem1024 => { let (sk, pk) = MlKem1024PrivateKey::generate(); Ok((pk.bytes.to_vec(), sk.seed)) } @@ -134,13 +102,13 @@ impl MlKem { &self, public_key: &[u8], ) -> Result<(Vec, MlKemSharedSecret), ErrorStack> { - match self.params { - MlKemParams::MlKem768 => { + match self { + Self::MlKem768 => { let pk = MlKem768PublicKey::from_slice(public_key)?; let (ct, ss) = pk.encapsulate(); Ok((ct.to_vec(), ss)) } - MlKemParams::MlKem1024 => { + Self::MlKem1024 => { let pk = MlKem1024PublicKey::from_slice(public_key)?; let (ct, ss) = pk.encapsulate(); Ok((ct.to_vec(), ss)) @@ -159,15 +127,15 @@ impl MlKem { } let seed_arr: MlKemPrivateKeySeed = private_key.try_into().unwrap(); - match self.params { - MlKemParams::MlKem768 => { + match self { + Self::MlKem768 => { let ct: &[u8; MlKem768PrivateKey::CIPHERTEXT_BYTES] = ciphertext .try_into() .map_err(|_| ErrorStack::internal_error_str("invalid ciphertext length"))?; let sk = MlKem768PrivateKey::from_seed(seed_arr)?; Ok(sk.decapsulate(ct)) } - MlKemParams::MlKem1024 => { + Self::MlKem1024 => { let ct: &[u8; MlKem1024PrivateKey::CIPHERTEXT_BYTES] = ciphertext .try_into() .map_err(|_| ErrorStack::internal_error_str("invalid ciphertext length"))?; @@ -686,13 +654,13 @@ mod tests { use super::*; macro_rules! unified_tests { - ($name:ident, $params:expr, $pk_len:expr, $ct_len:expr) => { + ($name:ident, $algorithm:expr, $pk_len:expr, $ct_len:expr) => { mod $name { use super::*; #[test] fn roundtrip() { - let kem = MlKem::new($params); + let kem = $algorithm; let (pk, seed) = kem.generate_key().unwrap(); let (ct, ss1) = kem.encapsulate(&pk).unwrap(); let ss2 = kem.decapsulate(&seed, &ct).unwrap(); @@ -701,7 +669,7 @@ mod tests { #[test] fn key_sizes() { - let kem = MlKem::new($params); + let kem = $algorithm; assert_eq!(kem.public_key_bytes(), $pk_len); assert_eq!(kem.ciphertext_bytes(), $ct_len); @@ -716,14 +684,14 @@ mod tests { #[test] fn invalid_public_key_length() { - let kem = MlKem::new($params); + let kem = $algorithm; let result = kem.encapsulate(&[0u8; 100]); assert!(result.is_err()); } #[test] fn invalid_private_key_length() { - let kem = MlKem::new($params); + let kem = $algorithm; let (pk, _) = kem.generate_key().unwrap(); let (ct, _) = kem.encapsulate(&pk).unwrap(); let result = kem.decapsulate(&[0u8; 32], &ct); @@ -732,37 +700,31 @@ mod tests { #[test] fn invalid_ciphertext_length() { - let kem = MlKem::new($params); + let kem = $algorithm; let (_, private_key) = kem.generate_key().unwrap(); let result = kem.decapsulate(&private_key, &[0u8; 100]); assert!(result.is_err()); } - - #[test] - fn params_accessor() { - let kem = MlKem::new($params); - assert_eq!(kem.params(), $params); - } } }; } - unified_tests!(mlkem768, MlKemParams::MlKem768, 1184, 1088); - unified_tests!(mlkem1024, MlKemParams::MlKem1024, 1568, 1568); + unified_tests!(mlkem768, MlKem::MlKem768, 1184, 1088); + unified_tests!(mlkem1024, MlKem::MlKem1024, 1568, 1568); #[test] fn params_constants() { - assert_eq!(MlKemParams::MlKem768.public_key_bytes(), 1184); - assert_eq!(MlKemParams::MlKem768.ciphertext_bytes(), 1088); - assert_eq!(MlKemParams::MlKem1024.public_key_bytes(), 1568); - assert_eq!(MlKemParams::MlKem1024.ciphertext_bytes(), 1568); + assert_eq!(MlKem::MlKem768.public_key_bytes(), 1184); + assert_eq!(MlKem::MlKem768.ciphertext_bytes(), 1088); + assert_eq!(MlKem::MlKem1024.public_key_bytes(), 1568); + assert_eq!(MlKem::MlKem1024.ciphertext_bytes(), 1568); } #[test] fn cross_kem_incompatibility() { // Keys from one KEM variant should not work with another - let kem768 = MlKem::new(MlKemParams::MlKem768); - let kem1024 = MlKem::new(MlKemParams::MlKem1024); + let kem768 = MlKem::MlKem768; + let kem1024 = MlKem::MlKem1024; let (pk768, _) = kem768.generate_key().unwrap(); let (pk1024, _) = kem1024.generate_key().unwrap();