From d308d59289e9a76f3a9b934766ab107d2a46e661 Mon Sep 17 00:00:00 2001 From: adel-signal Date: Wed, 15 Apr 2026 14:51:35 -0700 Subject: [PATCH] Fix bugs in signaling validations and SfuToDevice RTP handling --- mrp/src/stream.rs | 40 +++++-- src/rust/src/core/call_manager.rs | 111 ++++++++++--------- src/rust/src/core/call_rwlock.rs | 2 +- src/rust/src/core/connection.rs | 20 ++++ src/rust/src/core/group_call.rs | 171 ++++++++++++++++++++++++------ src/rust/src/webrtc/arc.rs | 2 +- src/rust/tests/outgoing.rs | 85 +++++++++++++++ 7 files changed, 338 insertions(+), 93 deletions(-) diff --git a/mrp/src/stream.rs b/mrp/src/stream.rs index 271b9e51..5f263b33 100644 --- a/mrp/src/stream.rs +++ b/mrp/src/stream.rs @@ -114,6 +114,8 @@ pub enum MrpReceiveError { ReceiveWindowFull(u64), #[error("Received unexpected num packets while merge already in progress")] PacketMergeConflict, + #[error("Specified num_packets is too large for buffer: {0}")] + InvalidNumPackets(u32), #[error("Unexpected error in merge")] InvalidMergeState, } @@ -209,6 +211,10 @@ where if num_packets <= 1 { // treat num_packets == 0 case the same as no num_packets result.push(data); + } else if usize::try_from(num_packets).map_or(true, |num_packets| { + num_packets > self.receive_buffer.capacity_limit() + }) { + return Err(MrpReceiveError::InvalidNumPackets(num_packets)); } else { let mut buffer = MergeBuffer::new(num_packets).unwrap(); let _ = buffer.push(data); @@ -1029,6 +1035,19 @@ mod tests { "Should have finished dropping failed packets" ); + let should_be_error = alice.receive_and_merge( + &MrpHeader { + seqnum: Some(alice.ack_seqnum()), + num_packets: Some(10_000), + ..Default::default() + }, + extendable_packet(None, vec![1]), + ); + assert_eq!( + should_be_error, + Err(MrpReceiveError::InvalidNumPackets(10_000)) + ); + let mut returned = None; let mut bob: MrpStream = MrpStream::with_capacity_limit(16); @@ -1375,12 +1394,13 @@ mod tests { results } - let alice_merge_intervals = generate_random_intervals(1, num_packets as u32); - let bob_merge_intervals = generate_random_intervals(1, num_packets as u32); + let buffer_size = 64; + let alice_merge_intervals = generate_random_intervals(1, num_packets as u32, buffer_size); + let bob_merge_intervals = generate_random_intervals(1, num_packets as u32, buffer_size); let bob_expected_results = expected_results(num_packets, &alice_merge_intervals); let alice_expected_results = expected_results(num_packets, &bob_merge_intervals); - let alice = MrpStream::with_capacity_limit(64); - let bob = MrpStream::with_capacity_limit(64); + let alice = MrpStream::with_capacity_limit(buffer_size); + let bob = MrpStream::with_capacity_limit(buffer_size); let (to_alice, alice_inbox) = mpsc::channel(); let (to_bob, bob_inbox) = mpsc::channel(); let alice_receiver = DelayReceiver::new( @@ -1545,11 +1565,16 @@ mod tests { }) } - fn generate_random_intervals(min_seqnum: u32, max_seqnum: u32) -> Vec<(u32, u32)> { + fn generate_random_intervals( + min_seqnum: u32, + max_seqnum: u32, + max_merge_size: usize, + ) -> Vec<(u32, u32)> { if min_seqnum >= max_seqnum { return vec![]; } + let max_merge_offset = max_merge_size as u32 - 1; let mut rng = rng(); let mut highest = min_seqnum; let mut intervals = Vec::new(); @@ -1560,7 +1585,10 @@ mod tests { break v; } }; - let end = rng.random_range((start + 1)..=max_seqnum); + let end = std::cmp::min( + start + max_merge_offset, + rng.random_range((start + 1)..=max_seqnum), + ); intervals.push((start, end)); highest = end + 1; } diff --git a/src/rust/src/core/call_manager.rs b/src/rust/src/core/call_manager.rs index 8b055333..3ce7a051 100644 --- a/src/rust/src/core/call_manager.rs +++ b/src/rust/src/core/call_manager.rs @@ -201,6 +201,7 @@ where /// Information about a received group ring that hasn't yet been accepted or cancelled. #[derive(Debug)] struct OutstandingGroupRing { + sender_id: UserId, ring_id: group_call::RingId, received: Instant, } @@ -646,14 +647,19 @@ where handle_active_call_api!(self, CallManager::handle_hangup) } + /// removes outstanding group ring. If expected user ID is specified, verifies the pending + /// ring's sender_id matches before removing fn remove_outstanding_group_ring( &mut self, group_id: group_call::GroupIdRef, + expected_user_id: Option<&UserId>, ring_id: group_call::RingId, ) -> Result<()> { let mut outstanding_group_rings = self.outstanding_group_rings.lock()?; if let Some(ring) = outstanding_group_rings.get(group_id) && ring.ring_id == ring_id + && expected_user_id + .is_none_or(|expected| ring.sender_id.as_slice() == expected.as_slice()) { outstanding_group_rings.remove(group_id); } @@ -669,7 +675,7 @@ where ) -> Result<()> { info!("cancel_group_ring(): ring_id: {}", ring_id); - self.remove_outstanding_group_ring(&group_id, ring_id)?; + self.remove_outstanding_group_ring(&group_id, None, ring_id)?; if let Some(reason) = reason { let self_uuid = self @@ -1277,64 +1283,58 @@ where /// Handle message_send_failure() API from application. fn handle_message_send_failure(&mut self, call_id: CallId) -> Result<()> { - let mut should_handle = true; + let (call, is_active) = self + .active_call() + .ok() + .filter(|call| call.call_id() == call_id) + .map(|call| (Some(call), true)) + .unwrap_or_else(|| { + if let Ok(call_map) = self.call_by_call_id.lock() { + (call_map.get(&call_id).cloned(), false) + } else { + (None, false) + } + }); - let active_call = self.active_call().ok().inspect(|active_call| { - if active_call.call_id() == call_id - && active_call - .state() - .is_ok_and(|state| state.connected_or_reconnecting()) - { + if let Some(call) = call { + if is_active { // Get the last sent message type and see if it was for ICE. // Since we are in a connected state, don't handle it if so. - if let Ok(message_queue) = self.message_queue.lock() - && message_queue.last_sent_message_type == Some(signaling::MessageType::Ice) + if call + .state() + .is_ok_and(|state| state.connected_or_reconnecting()) + && self.message_queue.lock().is_ok_and(|queue| { + queue.last_sent_message_type == Some(signaling::MessageType::Ice) + }) { - should_handle = false; + warn!( + "handle_message_send_failure(): id: {}, failed to send ICE message but staying in call", + call_id + ); + } else { + info!( + "handle_message_send_failure(): id: {}, terminating active call", + call_id + ); + let _ = self.terminate_active_call( + call.should_send_hangup_on_failure(), + CallEndReason::SignalingFailure, + ); } - } - }); - - if should_handle { - if let Some(active_call) = active_call { + } else { info!( - "handle_message_send_failure(): id: {}, terminating active call", + "handle_message_send_failure(): id: {}, terminating call", call_id ); - let _ = self.terminate_active_call( - active_call.should_send_hangup_on_failure(), - CallEndReason::SignalingFailure, - ); - } else { - // See if the associated call is in the call map. - let mut call = None; - { - if let Ok(call_map) = self.call_by_call_id.lock() - && let Some(v) = call_map.get(&call_id) - { - call = Some(v.clone()); - }; - } + let hangup = call + .should_send_hangup_on_failure() + .then_some(signaling::Hangup::Normal); - match call { - Some(call) => { - info!( - "handle_message_send_failure(): id: {}, terminating call", - call_id - ); - - let hangup = call - .should_send_hangup_on_failure() - .then_some(signaling::Hangup::Normal); - - self.terminate_call(call, hangup, Some(CallEndReason::SignalingFailure))?; - } - None => { - info!("handle_message_send_failure(): no matching call found"); - } - } + self.terminate_call(call, hangup, Some(CallEndReason::SignalingFailure))?; } + } else { + info!("handle_message_send_failure(): no matching call found"); } match self.message_queue.lock() { @@ -1804,7 +1804,11 @@ where } } IntentionType::Cancelled => { - self.remove_outstanding_group_ring(group_id, ring_id.into())?; + self.remove_outstanding_group_ring( + group_id, + Some(&sender_uuid), + ring_id.into(), + )?; group_call::RingUpdate::CancelledByRinger } }; @@ -1863,7 +1867,7 @@ where } ResponseType::Ringing => unreachable!("handled above"), }; - self.remove_outstanding_group_ring(group_id, ring_id.into())?; + self.remove_outstanding_group_ring(group_id, None, ring_id.into())?; self.platform.lock()?.group_call_ring_update( std::mem::take(group_id), ring_id.into(), @@ -1919,6 +1923,7 @@ where outstanding_group_rings.insert( group_id.clone(), OutstandingGroupRing { + sender_id: sender_uuid.clone(), ring_id, received: Instant::now(), }, @@ -1942,7 +1947,11 @@ where self.worker .send_delayed(*INCOMING_GROUP_CALL_RING_TIME, move |_| { let result = try_scoped(|| { - self_for_timeout.remove_outstanding_group_ring(&group_id, ring_id)?; + self_for_timeout.remove_outstanding_group_ring( + &group_id, + Some(&sender_uuid), + ring_id, + )?; self_for_timeout.platform.lock()?.group_call_ring_update( group_id, ring_id, diff --git a/src/rust/src/core/call_rwlock.rs b/src/rust/src/core/call_rwlock.rs index c69f4cf5..5465ac0c 100644 --- a/src/rust/src/core/call_rwlock.rs +++ b/src/rust/src/core/call_rwlock.rs @@ -21,7 +21,7 @@ pub struct CallRwLock { } unsafe impl Send for CallRwLock {} -unsafe impl Sync for CallRwLock {} +unsafe impl Sync for CallRwLock {} impl CallRwLock { /// Creates a new CallRwLock diff --git a/src/rust/src/core/connection.rs b/src/rust/src/core/connection.rs index ca5f6cd0..b93575db 100644 --- a/src/rust/src/core/connection.rs +++ b/src/rust/src/core/connection.rs @@ -2259,6 +2259,10 @@ fn negotiate_srtp_keys( .map_err(|_| RingRtcError::InvalidRemoteSrtpKey)?; let shared_secret = local_secret.diffie_hellman(&remote_public_key); + if !shared_secret.was_contributory() { + error!("remote secret was non-contributory, rejecting srtp negotiation"); + return Err(RingRtcError::InvalidRemoteSrtpKey.into()); + } let hkdf_salt = vec![0u8; 32]; let hkdf_info_prefix = "Signal_Calling_20200807_SignallingDH_SRTPKey_KDF"; @@ -2349,4 +2353,20 @@ mod tests { assert_eq!(expect(300_000), compute(Low, 1_000_000, true)); assert_eq!(expect(300_000), compute(Low, 300_000, true)); } + + #[test] + fn negotiate_srtp_keys_rejects_low_order_remote_key() { + let local_secret = StaticSecret::random_from_rng(OsRng); + let low_order_key = [0u8; 32]; + let result = + negotiate_srtp_keys(&local_secret, &low_order_key, b"caller_key", b"callee_key"); + let err = result + .err() + .expect("expected an error for a non-contributory remote key"); + assert!( + err.downcast_ref::() + .is_some_and(|e| matches!(e, RingRtcError::InvalidRemoteSrtpKey)), + "expected RingRtcError::InvalidRemoteSrtpKey, got: {err}" + ); + } } diff --git a/src/rust/src/core/group_call.rs b/src/rust/src/core/group_call.rs index 61c8ec56..57efc2a5 100644 --- a/src/rust/src/core/group_call.rs +++ b/src/rust/src/core/group_call.rs @@ -473,6 +473,9 @@ enum DheState { WaitingForServerPublicKey { client_secret: EphemeralSecret, }, + FailedToNegotiate { + reason: &'static str, + }, Negotiated { srtp_keys: SrtpKeys, }, @@ -495,21 +498,27 @@ impl DheState { } DheState::WaitingForServerPublicKey { client_secret } => { let shared_secret = client_secret.diffie_hellman(server_pub_key); - let mut master_key_material = [0u8; SrtpKeys::MASTER_KEY_MATERIAL_LEN]; - Hkdf::::new(Some(&[0u8; 32]), shared_secret.as_bytes()) - .expand_multi_info( - &[ - b"Signal_Group_Call_20211105_SignallingDH_SRTPKey_KDF", - hkdf_extra_info, - ], - &mut master_key_material, - ) - .expect("SRTP master key material expansion"); - DheState::Negotiated { - srtp_keys: SrtpKeys::from_master_key_material(&master_key_material), + if !shared_secret.was_contributory() { + DheState::FailedToNegotiate { + reason: "SFU provided remote secret was non-contributory, rejecting srtp negotiation", + } + } else { + let mut master_key_material = [0u8; SrtpKeys::MASTER_KEY_MATERIAL_LEN]; + Hkdf::::new(Some(&[0u8; 32]), shared_secret.as_bytes()) + .expand_multi_info( + &[ + b"Signal_Group_Call_20211105_SignallingDH_SRTPKey_KDF", + hkdf_extra_info, + ], + &mut master_key_material, + ) + .expect("SRTP master key material expansion"); + DheState::Negotiated { + srtp_keys: SrtpKeys::from_master_key_material(&master_key_material), + } } } - DheState::Negotiated { .. } => { + DheState::Negotiated { .. } | DheState::FailedToNegotiate { .. } => { warn!("Attempting to negotiated SRTP keys a second time."); self } @@ -2756,6 +2765,11 @@ impl Client { ); let srtp_keys = match &state.dhe_state { DheState::Negotiated { srtp_keys } => srtp_keys, + DheState::FailedToNegotiate { reason } => { + error!("join() failed: {reason}"); + Self::end(state, CallEndReason::FailedToNegotiatedSrtpKeys); + return; + } _ => { Self::end(state, CallEndReason::FailedToNegotiatedSrtpKeys); return; @@ -2978,7 +2992,7 @@ impl Client { .. } => { if group_id == state.group_id { - Self::handle_leaving_received(state, leaving_demux_id); + Self::handle_leaving_received(state, Some(sender_user_id), leaving_demux_id); } } _ => { @@ -4320,7 +4334,7 @@ impl Client { } if let Some(_leaving) = msg.leaving { self.actor.send(move |state| { - Self::handle_leaving_received(state, demux_id); + Self::handle_leaving_received(state, None, demux_id); }); } if let Some(reaction) = msg.reaction { @@ -4371,11 +4385,25 @@ impl Client { { Ok(ready_packets) => { for (buffered_header, sfu_to_device) in ready_packets { - Self::handle_sfu_to_device_inner( - &state.actor, - buffered_header, - sfu_to_device, - ) + // If content is present, we should not process any other fields on + // this proto because that would allow for nested protos. Nested protos + // cause thrashing, excessive updates, and hard to follow processing + // order. + if let Some(content) = sfu_to_device.content { + match SfuToDevice::decode(content.as_slice()) { + Ok(msg) => Self::handle_sfu_to_device_inner(&state.actor, buffered_header, msg), + Err(err) => { + error!("Failed to decode content buffer in SfuToDevice: {:?}", err); + } + } + return; + } else { + Self::handle_sfu_to_device_inner( + &state.actor, + buffered_header, + sfu_to_device, + ) + } } } err @ Err(MrpReceiveError::ReceiveWindowFull(_)) => { @@ -4411,21 +4439,10 @@ impl Client { removed, raised_hands, mrp_header: _, - content, + content: _, endorsements, } = msg; - if let Some(content) = content { - match SfuToDevice::decode(content.as_slice()) { - Ok(msg) => Self::handle_sfu_to_device_inner(actor, header, msg), - Err(err) => { - error!("Failed to decode content buffer in SfuToDevice: {:?}", err); - } - } - // ignore all other fields to prevent ordering issues - return; - } - if let Some(Speaker { demux_id: speaker_demux_id, }) = speaker @@ -4790,7 +4807,13 @@ impl Client { }); } - fn handle_leaving_received(state: &mut State, demux_id: DemuxId) { + /// Set state that device is leaving. If expected_user_id is provided, then validate the + /// demuxID's related userID against it. + fn handle_leaving_received( + state: &mut State, + expected_user_id: Option, + demux_id: DemuxId, + ) { // It's likely we haven't received an update from the SFU about this demux_id leaving. debug!( "Request devices because we just received a leaving message from demux_id = {}", @@ -4799,6 +4822,13 @@ impl Client { if let Some(device) = state.remote_devices.find_by_demux_id_mut(demux_id) && !device.leaving_received { + if expected_user_id.is_some_and(|expected| expected != device.user_id) { + warn!( + "Received Leaving message for demux ID {demux_id} but sender's user ID did not match expected user ID, so ignoring" + ); + + return; + } device.leaving_received = true; Self::request_remote_devices_as_soon_as_possible(state); @@ -5402,6 +5432,7 @@ mod tests { era_id: String, response_join_state: Arc>, joins_remaining: Option>, + server_dhe_pub_key: [u8; 32], } #[derive(Default)] @@ -5423,6 +5454,8 @@ mod tests { call_creator: Option, options: FakeSfuClientOptions, ) -> Self { + let server_secret = EphemeralSecret::random_from_rng(OsRng); + let server_dhe_pub_key = *PublicKey::from(&server_secret).as_bytes(); Self { sfu_info: SfuInfo { udp_addresses: Vec::new(), @@ -5440,6 +5473,7 @@ mod tests { joins_remaining: options .max_joins .map(|v| Arc::new(AtomicI64::new(v as i64))), + server_dhe_pub_key, } } @@ -5478,7 +5512,7 @@ mod tests { client.on_sfu_client_join_attempt_completed(Ok(Joined { sfu_info: self.sfu_info.clone(), local_demux_id: self.local_demux_id, - server_dhe_pub_key: [0u8; 32], + server_dhe_pub_key: self.server_dhe_pub_key, hkdf_extra_info: b"hkdf_extra_info".to_vec(), creator: self.call_creator.clone(), era_id: self.era_id.clone(), @@ -7676,6 +7710,60 @@ mod tests { ); } + #[test] + fn ignore_leaving_message_from_wrong_sender() { + use protobuf::group_call::{DeviceToDevice, device_to_device::Leaving}; + + let client1 = TestClient::new(vec![1], 1); + client1.connect_join_and_wait_until_joined(); + let client2 = TestClient::new(vec![2], 2); + client2.connect_join_and_wait_until_joined(); + + client1.set_remotes_and_wait_until_applied(&[&client2]); + + let fake_group_id = b"fake group ID".to_vec(); + + // Use actor task to get state to ensure ordering + let get_leaving_received = |client: &TestClient| { + let (tx, rx) = mpsc::channel(); + client.client.actor.send(move |state| { + let val = state + .remote_devices + .find_by_demux_id(client2.demux_id) + .map(|d| d.leaving_received) + .unwrap_or(false); + tx.send(val).unwrap(); + }); + rx.recv_timeout(Duration::from_secs(5)).unwrap() + }; + + let make_leaving_msg = |group_id: Vec| DeviceToDevice { + group_id: Some(group_id), + leaving: Some(Leaving { + demux_id: Some(client2.demux_id), + }), + ..DeviceToDevice::default() + }; + + let wrong_user_id: UserId = b"wrong_user_id".to_vec(); + client1 + .client + .on_signaling_message_received(wrong_user_id, make_leaving_msg(fake_group_id.clone())); + assert!( + !get_leaving_received(&client1), + "leaving from wrong sender should be rejected" + ); + + client1.client.on_signaling_message_received( + client2.user_id.clone(), + make_leaving_msg(fake_group_id), + ); + assert!( + get_leaving_received(&client1), + "leaving from correct sender should be accepted" + ); + } + #[test] fn device_to_sfu_remove() { use protobuf::group_call::{ @@ -9459,6 +9547,21 @@ mod remote_devices_tests { }; } + #[test] + fn dhe_state_fails_to_negotiate_with_low_order_server_key() { + let client_secret = EphemeralSecret::random_from_rng(OsRng); + let low_order_server_key = PublicKey::from([0u8; 32]); + + let state = DheState::start(client_secret); + assert!(matches!(state, DheState::WaitingForServerPublicKey { .. })); + + let state = state.negotiate(&low_order_server_key, b"hkdf_extra_info"); + assert!( + matches!(state, DheState::FailedToNegotiate { .. }), + "expected FailedToNegotiate, got a different DheState variant" + ); + } + #[test] fn test_mrp_max_size_limit() { let content = [5u8; MAX_MRP_FRAGMENT_BYTE_SIZE]; diff --git a/src/rust/src/webrtc/arc.rs b/src/rust/src/webrtc/arc.rs index 8179a058..4509b5be 100644 --- a/src/rust/src/webrtc/arc.rs +++ b/src/rust/src/webrtc/arc.rs @@ -89,4 +89,4 @@ impl Drop for Arc { } unsafe impl Send for Arc {} -unsafe impl Sync for Arc {} +unsafe impl Sync for Arc {} diff --git a/src/rust/tests/outgoing.rs b/src/rust/tests/outgoing.rs index 548b1d91..d87bd91f 100644 --- a/src/rust/tests/outgoing.rs +++ b/src/rust/tests/outgoing.rs @@ -3518,6 +3518,91 @@ fn group_call_ring_cancelled_by_ringer_before_join() { ); } +#[test] +fn group_call_ring_not_cancelled_by_different_sender() { + test_init(); + + let context = TestContext::new(); + let mut cm = context.cm(); + + let self_uuid = vec![1, 0, 1]; + cm.set_self_uuid(self_uuid.clone()).expect(error_line!()); + + let group_id = vec![1, 1, 1]; + let original_ringer = vec![1, 2, 3]; + let different_sender = vec![4, 5, 6]; + let ring_id = group_call::RingId::from(42); + + let ring_message = protobuf::signaling::CallMessage { + ring_intention: Some(protobuf::signaling::call_message::RingIntention { + group_id: Some(group_id.clone()), + ring_id: Some(ring_id.into()), + r#type: Some(protobuf::signaling::call_message::ring_intention::Type::Ring.into()), + }), + ..Default::default() + }; + let mut buf = Vec::new(); + ring_message + .encode(&mut buf) + .expect("cannot fail encoding to Vec"); + + cm.received_call_message(original_ringer, 1, 2, buf, Duration::ZERO) + .expect(error_line!()); + + // Receive a cancellation from a different sender + let cancel_message = protobuf::signaling::CallMessage { + ring_intention: Some(protobuf::signaling::call_message::RingIntention { + group_id: Some(group_id.clone()), + ring_id: Some(ring_id.into()), + r#type: Some(protobuf::signaling::call_message::ring_intention::Type::Cancelled.into()), + }), + ..Default::default() + }; + let mut buf = Vec::new(); + cancel_message + .encode(&mut buf) + .expect("cannot fail encoding to Vec"); + + cm.received_call_message(different_sender, 1, 2, buf, Duration::ZERO) + .expect(error_line!()); + + cm.synchronize().expect(error_line!()); + + // Join the group call. Because the cancellation was from a different sender, we + // expect the ring is active and an acceptance message will be sent. + let group_call_id = context + .create_group_call(group_id.clone()) + .expect(error_line!()); + cm.join(group_call_id); + cm.synchronize().expect(error_line!()); + + let messages = cm + .platform() + .expect(error_line!()) + .take_outgoing_call_messages(); + match &messages[..] { + [message] => { + assert_eq!(&self_uuid[..], &message.recipient_id[..]); + let call_message = protobuf::signaling::CallMessage::decode(&message.message[..]) + .expect(error_line!()); + assert_eq!( + protobuf::signaling::CallMessage { + ring_response: Some(protobuf::signaling::call_message::RingResponse { + group_id: Some(group_id), + ring_id: Some(ring_id.into()), + r#type: Some( + protobuf::signaling::call_message::ring_response::Type::Accepted.into() + ), + }), + ..Default::default() + }, + call_message + ); + } + _ => panic!("expected one acceptance message, got: {:?}", messages), + } +} + #[test] fn group_call_ring_cancelled_by_another_device_before_join() { test_init();