Add derive_arrays helper, for getting multiple buffers out of HKDF

(and similar derivations, like password hashing)

Packages up a zerocopy-based pattern originally added by Alex into
a more conventional interface.
This commit is contained in:
Jordan Rose 2026-01-12 15:02:42 -08:00
parent bc7d2af953
commit a24341b044
9 changed files with 103 additions and 98 deletions

1
Cargo.lock generated
View File

@ -2513,6 +2513,7 @@ dependencies = [
"thiserror 2.0.17",
"uuid",
"x25519-dalek",
"zerocopy",
]
[[package]]

View File

@ -11,8 +11,8 @@
//! SVR, so that a restorer can reconstruct the `BackupId`.
use hkdf::Hkdf;
use libsignal_core::Aci;
use libsignal_core::curve::PrivateKey;
use libsignal_core::{Aci, derive_arrays};
use partial_default::PartialDefault;
use sha2::Sha256;
@ -156,21 +156,16 @@ impl<const VERSION: u8> BackupKey<VERSION> {
&self,
salt: &[u8],
) -> BackupForwardSecrecyEncryptionKey {
let mut bytes = [0u8; BACKUP_FORWARD_SECRECY_ENCRYPTION_KEY_CIPHER_KEY_SIZE
+ BACKUP_FORWARD_SECRECY_ENCRYPTION_KEY_HMAC_KEY_SIZE];
const INFO: &[u8] =
b"Signal Message Backup 20250627:BackupForwardSecrecyToken Encryption Key";
Hkdf::<Sha256>::new(Some(salt), &self.0)
.expand(INFO, &mut bytes)
.expect("valid length");
let (cipher_key, hmac_key, []) = derive_arrays(|bytes| {
const INFO: &[u8] =
b"Signal Message Backup 20250627:BackupForwardSecrecyToken Encryption Key";
Hkdf::<Sha256>::new(Some(salt), &self.0)
.expand(INFO, bytes)
.expect("valid length");
});
BackupForwardSecrecyEncryptionKey {
cipher_key: bytes[..BACKUP_FORWARD_SECRECY_ENCRYPTION_KEY_CIPHER_KEY_SIZE]
.try_into()
.expect("should have enough bytes"),
hmac_key: bytes[BACKUP_FORWARD_SECRECY_ENCRYPTION_KEY_CIPHER_KEY_SIZE..]
[..BACKUP_FORWARD_SECRECY_ENCRYPTION_KEY_HMAC_KEY_SIZE]
.try_into()
.expect("should have enough bytes"),
cipher_key,
hmac_key,
}
}
}

View File

@ -26,6 +26,7 @@ use argon2::{
Algorithm, Argon2, ParamsBuilder, PasswordHash, PasswordHasher, PasswordVerifier, Version,
};
use hkdf::Hkdf;
use libsignal_core::try_derive_arrays;
use sha2::Sha256;
use crate::error::Result;
@ -60,15 +61,12 @@ impl PinHash {
.build()
.expect("valid params"),
);
let mut output_key_material = [0u8; 64];
hasher.hash_password_into(pin, salt, &mut output_key_material)?;
let (encryption_key, access_key, []) = try_derive_arrays(|output_key_material| {
hasher.hash_password_into(pin, salt, output_key_material)
})?;
Ok(PinHash {
encryption_key: output_key_material[..32]
.try_into()
.expect("target length 32"),
access_key: output_key_material[32..]
.try_into()
.expect("target length 32"),
encryption_key,
access_key,
})
}

View File

@ -24,6 +24,7 @@ subtle = { workspace = true }
thiserror = { workspace = true }
uuid = { workspace = true }
x25519-dalek = { workspace = true, features = ["static_secrets"] }
zerocopy = { workspace = true }
[dev-dependencies]
assert_matches = { workspace = true }

View File

@ -28,3 +28,38 @@ pub use version::VERSION;
pub fn try_scoped<T, E>(f: impl FnOnce() -> Result<T, E>) -> Result<T, E> {
f()
}
/// Produces three arrays of output based on a callback that fills a single buffer of bytes.
///
/// The length of the buffer will always be `N1 + N2 + N3` (it would be possible to guarantee this
/// using `generic_array`, but the added complexity doesn't actually buy us anything). If you don't
/// need all three arrays, you can pattern-match the last one as `[]` to both discard it and infer
/// its length as 0.
#[inline]
pub fn derive_arrays<const N1: usize, const N2: usize, const N3: usize>(
derive: impl FnOnce(&mut [u8]),
) -> ([u8; N1], [u8; N2], [u8; N3]) {
let Ok(result) = try_derive_arrays::<N1, N2, N3, std::convert::Infallible>(move |buffer| {
derive(buffer);
Ok(())
});
result
}
/// Like [`derive_arrays`], but the callback is permitted to fail.
#[inline]
#[allow(clippy::type_complexity)]
pub fn try_derive_arrays<const N1: usize, const N2: usize, const N3: usize, E>(
derive: impl FnOnce(&mut [u8]) -> Result<(), E>,
) -> Result<([u8; N1], [u8; N2], [u8; N3]), E> {
#[derive(zerocopy::KnownLayout, zerocopy::FromBytes, zerocopy::IntoBytes)]
#[repr(C)]
struct DerivedValues<const N1: usize, const N2: usize, const N3: usize>(
[u8; N1],
[u8; N2],
[u8; N3],
);
let mut derived_values: DerivedValues<N1, N2, N3> = zerocopy::FromZeros::new_zeroed();
derive(zerocopy::IntoBytes::as_mut_bytes(&mut derived_values))?;
Ok((derived_values.0, derived_values.1, derived_values.2))
}

View File

@ -7,6 +7,7 @@
use hkdf::Hkdf;
use libsignal_account_keys::{BackupForwardSecrecyToken, BackupId, BackupKey};
use libsignal_core::derive_arrays;
use sha2::Sha256;
#[derive(Debug)]
@ -33,35 +34,27 @@ impl MessageBackupKey {
backup_id: &BackupId,
backup_nonce: Option<&BackupForwardSecrecyToken>,
) -> Self {
let mut full_bytes = [0; MessageBackupKey::LEN];
let (hmac_key, aes_key, []) = derive_arrays(|full_bytes| {
// See [`BackupKey::derive_backup_id`] for an explanation of this pattern.
match VERSION {
// Disable inference by using explicit type syntax <>, giving us the latest version.
<BackupKey>::VERSION => {
const OLD_DST: &[u8] = b"20241007_SIGNAL_BACKUP_ENCRYPT_MESSAGE_BACKUP:";
const NEW_DST: &[u8] = b"20250708_SIGNAL_BACKUP_ENCRYPT_MESSAGE_BACKUP:";
// See [`BackupKey::derive_backup_id`] for an explanation of this pattern.
match VERSION {
// Disable inference by using explicit type syntax <>, giving us the latest version.
<BackupKey>::VERSION => {
const OLD_DST: &[u8] = b"20241007_SIGNAL_BACKUP_ENCRYPT_MESSAGE_BACKUP:";
const NEW_DST: &[u8] = b"20250708_SIGNAL_BACKUP_ENCRYPT_MESSAGE_BACKUP:";
let (salt, dst) = match backup_nonce {
Some(nonce) => (Some(&nonce.0[..]), NEW_DST),
None => (None, OLD_DST),
};
let (salt, dst) = match backup_nonce {
Some(nonce) => (Some(&nonce.0[..]), NEW_DST),
None => (None, OLD_DST),
};
Hkdf::<Sha256>::new(salt, &backup_key.0)
.expand_multi_info(&[dst, &backup_id.0], &mut full_bytes)
.expect("valid length");
Hkdf::<Sha256>::new(salt, &backup_key.0)
.expand_multi_info(&[dst, &backup_id.0], full_bytes)
.expect("valid length");
}
_ => unreachable!("invalid backup key version"),
}
_ => unreachable!("invalid backup key version"),
}
// TODO split into arrays instead of slices when the API for that is
// stabilized. See https://github.com/rust-lang/rust/issues/90091
let (hmac_key, aes_key) = full_bytes.split_at(Self::HMAC_KEY_LEN);
Self {
hmac_key: hmac_key.try_into().expect("correct length"),
aes_key: aes_key.try_into().expect("correct length"),
}
});
Self { hmac_key, aes_key }
}
}

View File

@ -6,6 +6,7 @@
mod keys;
mod params;
use libsignal_core::derive_arrays;
use rand::{CryptoRng, Rng};
pub(crate) use self::keys::{ChainKey, MessageKeyGenerator, RootKey};
@ -24,16 +25,15 @@ fn derive_keys(secret_input: &[u8]) -> (RootKey, ChainKey, InitialPQRKey) {
}
fn derive_keys_with_label(label: &[u8], secret_input: &[u8]) -> (RootKey, ChainKey, InitialPQRKey) {
let mut secrets = [0; 96];
hkdf::Hkdf::<sha2::Sha256>::new(None, secret_input)
.expand(label, &mut secrets)
.expect("valid length");
let (root_key_bytes, chain_key_bytes, pqr_bytes) =
(&secrets[0..32], &secrets[32..64], &secrets[64..96]);
let (root_key_bytes, chain_key_bytes, pqr_bytes) = derive_arrays(|bytes| {
hkdf::Hkdf::<sha2::Sha256>::new(None, secret_input)
.expand(label, bytes)
.expect("valid length")
});
let root_key = RootKey::new(root_key_bytes.try_into().expect("correct length"));
let chain_key = ChainKey::new(chain_key_bytes.try_into().expect("correct length"), 0);
let pqr_key: InitialPQRKey = pqr_bytes.try_into().expect("correct length");
let root_key = RootKey::new(root_key_bytes);
let chain_key = ChainKey::new(chain_key_bytes, 0);
let pqr_key: InitialPQRKey = pqr_bytes;
(root_key, chain_key, pqr_key)
}

View File

@ -5,7 +5,7 @@
use std::fmt;
use zerocopy::{FromBytes, IntoBytes, KnownLayout};
use libsignal_core::derive_arrays;
use crate::proto::storage::session_structure;
use crate::{PrivateKey, PublicKey, Result, crypto};
@ -92,16 +92,11 @@ impl MessageKeys {
optional_salt: Option<&[u8]>,
counter: u32,
) -> Self {
#[derive(Default, KnownLayout, IntoBytes, FromBytes)]
#[repr(C, packed)]
struct DerivedSecretBytes([u8; 32], [u8; 32], [u8; 16]);
let mut okm = DerivedSecretBytes::default();
hkdf::Hkdf::<sha2::Sha256>::new(optional_salt, input_key_material)
.expand(b"WhisperMessageKeys", okm.as_mut_bytes())
.expect("valid output length");
let DerivedSecretBytes(cipher_key, mac_key, iv) = okm;
let (cipher_key, mac_key, iv) = derive_arrays(|okm| {
hkdf::Hkdf::<sha2::Sha256>::new(optional_salt, input_key_material)
.expand(b"WhisperMessageKeys", okm)
.expect("valid output length")
});
MessageKeys {
cipher_key,
@ -195,16 +190,11 @@ impl RootKey {
our_ratchet_key: &PrivateKey,
) -> Result<(RootKey, ChainKey)> {
let shared_secret = our_ratchet_key.calculate_agreement(their_ratchet_key)?;
#[derive(Default, KnownLayout, IntoBytes, FromBytes)]
#[repr(C, packed)]
struct DerivedSecretBytes([u8; 32], [u8; 32]);
let mut derived_secret_bytes = DerivedSecretBytes::default();
hkdf::Hkdf::<sha2::Sha256>::new(Some(&self.key), &shared_secret)
.expand(b"WhisperRatchet", derived_secret_bytes.as_mut_bytes())
.expect("valid output length");
let DerivedSecretBytes(root_key, chain_key) = derived_secret_bytes;
let (root_key, chain_key, []) = derive_arrays(|bytes| {
hkdf::Hkdf::<sha2::Sha256>::new(Some(&self.key), &shared_secret)
.expand(b"WhisperRatchet", bytes)
.expect("valid output length")
});
Ok((
RootKey { key: root_key },

View File

@ -701,7 +701,7 @@ mod sealed_sender_v1 {
#[cfg(test)]
use std::fmt;
use zerocopy::IntoBytes;
use libsignal_core::derive_arrays;
use super::*;
@ -732,15 +732,11 @@ mod sealed_sender_v1 {
.concat();
let shared_secret = our_keys.private_key.calculate_agreement(their_public)?;
#[derive(Default, KnownLayout, IntoBytes, FromBytes)]
#[repr(C, packed)]
struct DerivedValues([u8; 32], [u8; 32], [u8; 32]);
let mut derived_values = DerivedValues::default();
hkdf::Hkdf::<sha2::Sha256>::new(Some(&ephemeral_salt), &shared_secret)
.expand(&[], derived_values.as_mut_bytes())
.expect("valid output length");
let DerivedValues(chain_key, cipher_key, mac_key) = derived_values;
let (chain_key, cipher_key, mac_key) = derive_arrays(|bytes| {
hkdf::Hkdf::<sha2::Sha256>::new(Some(&ephemeral_salt), &shared_secret)
.expand(&[], bytes)
.expect("valid output length")
});
Ok(Self {
chain_key,
@ -794,15 +790,11 @@ mod sealed_sender_v1 {
// 96 bytes are derived, but the first 32 are discarded/unused. This is intended to
// mirror the way the EphemeralKeys are derived, even though StaticKeys does not end up
// requiring a third "chain key".
#[derive(Default, KnownLayout, IntoBytes, FromBytes)]
#[repr(C, packed)]
struct DerivedValues(#[allow(unused)] [u8; 32], [u8; 32], [u8; 32]);
let mut derived_values = DerivedValues::default();
hkdf::Hkdf::<sha2::Sha256>::new(Some(&salt), &shared_secret)
.expand(&[], derived_values.as_mut_bytes())
.expect("valid output length");
let DerivedValues(_, cipher_key, mac_key) = derived_values;
let (_, cipher_key, mac_key) = derive_arrays::<32, 32, 32>(|bytes| {
hkdf::Hkdf::<sha2::Sha256>::new(Some(&salt), &shared_secret)
.expand(&[], bytes)
.expect("valid output length")
});
Ok(Self {
cipher_key,