ffi: Support multiple return values for bridge_callbacks

...and use it to avoid re-deriving the public key when fetching the
local identity key pair.
This commit is contained in:
Jordan Rose 2026-03-06 12:01:41 -08:00 committed by GitHub
parent 6a7cc67173
commit 55b233d43c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 51 additions and 15 deletions

View File

@ -1267,6 +1267,17 @@ optional_callback_result_type!(SenderKeyRecord);
optional_callback_result_type!(SignedPreKeyRecord);
optional_callback_result_type!(SessionRecord);
impl<A: CallbackResultTypeInfo, B: CallbackResultTypeInfo> CallbackResultTypeInfo for (A, B) {
type ResultType = PairOf<A::ResultType, B::ResultType>;
fn convert_from_callback(foreign: Self::ResultType) -> SignalFfiResult<Self> {
Ok((
A::convert_from_callback(foreign.first)?,
B::convert_from_callback(foreign.second)?,
))
}
}
macro_rules! trivial {
($typ:ty) => {
impl SimpleArgTypeInfo for $typ {
@ -1358,6 +1369,11 @@ macro_rules! ffi_arg_type {
(Result<$typ:tt $(, $ignored:ty)?>) => (ffi_arg_type!($typ));
(Result<$typ:tt<$($args:tt),+> $(, $ignored:ty)?>) => (ffi_arg_type!($typ<$($args),+>));
(Option<$typ:ty>) => (ffi::MutPointer< $typ >);
// Like Result, we can't use `:ty` here because we need the resulting tokens to be matched
// recursively. We can at least match several tokens in the second component though.
(($a:tt, $($b:tt)+)) => (ffi::PairOf<ffi_arg_type!($a), ffi_arg_type!($($b)+)>);
($typ:ty) => (ffi::MutPointer< $typ >);
}

View File

@ -17,9 +17,7 @@ use crate::support::{BridgedCallbacks, ResultLike, WithContext};
/// A bridge-friendly version of [`IdentityKeyStore`].
#[bridge_callbacks(jni = false, node = false)]
pub trait BridgeIdentityKeyStore {
// We ask for just the private key because IdentityKeyPair isn't a single bridge_handle; it's a
// pair of objects. This is easier to bridge.
fn get_local_identity_private_key(&self) -> Result<PrivateKey, SignalProtocolError>;
fn get_local_identity_key_pair(&self) -> Result<(PrivateKey, PublicKey), SignalProtocolError>;
fn get_local_registration_id(&self) -> Result<u32, SignalProtocolError>;
fn get_identity_key(
&self,
@ -53,8 +51,7 @@ pub enum FfiDirection {
#[async_trait(?Send)]
impl<T: BridgeIdentityKeyStore> IdentityKeyStore for BridgedCallbacks<T> {
async fn get_identity_key_pair(&self) -> Result<IdentityKeyPair, SignalProtocolError> {
let priv_key = self.0.get_local_identity_private_key()?;
let pub_key = priv_key.public_key()?;
let (priv_key, pub_key) = self.0.get_local_identity_key_pair()?;
Ok(IdentityKeyPair::new(IdentityKey::new(pub_key), priv_key))
}

View File

@ -11,16 +11,34 @@ internal func withIdentityKeyStore<Result>(
_ context: StoreContext,
_ body: (SignalConstPointerFfiIdentityKeyStoreStruct) throws -> Result
) throws -> Result {
func ffiShimGetIdentityPrivateKey(
func ffiShimGetLocalIdentityKeyPair(
storeCtx: UnsafeMutableRawPointer?,
keyp: UnsafeMutablePointer<SignalMutPointerPrivateKey>?
out: UnsafeMutablePointer<SignalPairOfMutPointerPrivateKeyMutPointerPublicKey>?
) -> Int32 {
let storeContext = storeCtx!.assumingMemoryBound(
to: ErrorHandlingContext<(IdentityKeyStore, StoreContext)>.self
)
return storeContext.pointee.catchCallbackErrors { store, context in
var privateKey = try store.identityKeyPair(context: context).privateKey
keyp!.pointee = try cloneOrTakeHandle(from: &privateKey)
let keyPair = try store.identityKeyPair(context: context)
var privateKey = keyPair.privateKey
var publicKey = keyPair.publicKey
// Increase the chances of avoiding a clone.
_ = consume keyPair
// Zero out the output first.
out!.pointee = .init()
out!.pointee.first = try cloneOrTakeHandle(from: &privateKey)
do {
out!.pointee.second = try cloneOrTakeHandle(from: &publicKey)
} catch {
// Just in case cloning fails, destroy any fields already filled in.
// (If *this* fails, we're in a really bad state and recovery isn't worth it.)
if let first = NonNull(out!.pointee.first) {
failOnError(PrivateKey.destroyNativeHandle(first))
}
out!.pointee = .init()
throw error
}
}
}
@ -112,7 +130,7 @@ internal func withIdentityKeyStore<Result>(
return try rethrowCallbackErrors((store, context)) {
var ffiStore = SignalIdentityKeyStore(
ctx: $0,
get_local_identity_private_key: ffiShimGetIdentityPrivateKey,
get_local_identity_key_pair: ffiShimGetLocalIdentityKeyPair,
get_local_registration_id: ffiShimGetLocalRegistrationId,
get_identity_key: ffiShimGetIdentity,
save_identity_key: ffiShimSaveIdentity,

View File

@ -784,14 +784,19 @@ typedef struct {
const SignalSessionStore *raw;
} SignalConstPointerFfiSessionStoreStruct;
typedef int (*SignalFfiBridgeIdentityKeyStoreGetLocalIdentityPrivateKey)(void *ctx, SignalMutPointerPrivateKey *out);
typedef int (*SignalFfiBridgeIdentityKeyStoreGetLocalRegistrationId)(void *ctx, uint32_t *out);
typedef struct {
SignalPublicKey *raw;
} SignalMutPointerPublicKey;
typedef struct {
SignalMutPointerPrivateKey first;
SignalMutPointerPublicKey second;
} SignalPairOfMutPointerPrivateKeyMutPointerPublicKey;
typedef int (*SignalFfiBridgeIdentityKeyStoreGetLocalIdentityKeyPair)(void *ctx, SignalPairOfMutPointerPrivateKeyMutPointerPublicKey *out);
typedef int (*SignalFfiBridgeIdentityKeyStoreGetLocalRegistrationId)(void *ctx, uint32_t *out);
typedef int (*SignalFfiBridgeIdentityKeyStoreGetIdentityKey)(void *ctx, SignalMutPointerPublicKey *out, SignalMutPointerProtocolAddress address);
typedef int (*SignalFfiBridgeIdentityKeyStoreSaveIdentityKey)(void *ctx, uint8_t *out, SignalMutPointerProtocolAddress address, SignalMutPointerPublicKey public_key);
@ -802,7 +807,7 @@ typedef void (*SignalFfiBridgeIdentityKeyStoreDestroy)(void *ctx);
typedef struct {
void *ctx;
SignalFfiBridgeIdentityKeyStoreGetLocalIdentityPrivateKey get_local_identity_private_key;
SignalFfiBridgeIdentityKeyStoreGetLocalIdentityKeyPair get_local_identity_key_pair;
SignalFfiBridgeIdentityKeyStoreGetLocalRegistrationId get_local_registration_id;
SignalFfiBridgeIdentityKeyStoreGetIdentityKey get_identity_key;
SignalFfiBridgeIdentityKeyStoreSaveIdentityKey save_identity_key;