diff --git a/Cargo.lock b/Cargo.lock index 852a5ff..751402f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -843,6 +843,7 @@ dependencies = [ "serde_with", "sha-1", "sha2", + "smallvec", "strum", "strum_macros", "thiserror 2.0.17", diff --git a/backend/Cargo.toml b/backend/Cargo.toml index 125462d..375d740 100644 --- a/backend/Cargo.toml +++ b/backend/Cargo.toml @@ -16,6 +16,9 @@ mrp = { git = "https://github.com/signalapp/ringrtc", tag = "v2.59.1" } calling_common = { path = "../common" } metrics = { path = "../metrics" } +# For small, inline collections +smallvec = "1.15.1" + # For error handling anyhow = "1.0.100" thiserror = "2.0.17" diff --git a/backend/fuzz/Cargo.toml b/backend/fuzz/Cargo.toml index 76c31d5..f11c283 100644 --- a/backend/fuzz/Cargo.toml +++ b/backend/fuzz/Cargo.toml @@ -60,6 +60,12 @@ path = "fuzz_targets/googcc.rs" test = false doc = false +[[bin]] +name = "dependency-descriptor" +path = "fuzz_targets/dependency_descriptor.rs" +test = false +doc = false + # keep in sync with workspace patches [patch.crates-io] # Use our fork of curve25519-dalek because we're using zkgroup. 
diff --git a/backend/fuzz/fuzz_targets/dependency_descriptor.rs b/backend/fuzz/fuzz_targets/dependency_descriptor.rs new file mode 100644 index 0000000..a84b8a7 --- /dev/null +++ b/backend/fuzz/fuzz_targets/dependency_descriptor.rs @@ -0,0 +1,14 @@ +// +// Copyright 2021 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only +// + +#![no_main] + +use calling_backend::*; +use libfuzzer_sys::fuzz_target; + +fuzz_target!(|data: Vec| { + _ = rtp::DependencyDescriptor::read(&data, None); +}); + diff --git a/backend/src/bitstream.rs b/backend/src/bitstream.rs new file mode 100644 index 0000000..00695b4 --- /dev/null +++ b/backend/src/bitstream.rs @@ -0,0 +1,425 @@ +// +// Copyright 2026 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only +// + +use anyhow::bail; +use smallvec::SmallVec; + +#[derive(Debug)] +pub struct BitstreamReader<'a> { + bytes: &'a [u8], + /// The index into `bytes` of the next byte to read. + byte_index: usize, + /// The offset into `bytes[byte_index]` of the next bit to read. In the range 0..=7. 
+ bit_offset: u8, +} + +impl<'a> BitstreamReader<'a> { + pub fn new(bytes: &'a [u8]) -> Self { + Self { + bytes, + byte_index: 0, + bit_offset: 0, + } + } + + pub fn read_u64(&mut self, bits: usize) -> anyhow::Result { + assert!(bits <= 64); + let mut result = 0; + if bits > 0 { + let byte_count = bits / 8; + if byte_count > 0 { + result |= u64::from(self.read_u8(8)?); + for _ in 0..byte_count - 1 { + result <<= 8; + result |= u64::from(self.read_u8(8)?); + } + } + let bits_left = bits - (byte_count * 8); + result <<= bits_left; + result |= u64::from(self.read_u8(bits_left as u8)?); + } + Ok(result) + } + + pub fn read_u32(&mut self, bits: usize) -> anyhow::Result { + assert!(bits <= 32); + self.read_u64(bits).map(|v| v as u32) + } + + pub fn read_u16(&mut self, bits: usize) -> anyhow::Result { + assert!(bits <= 16); + self.read_u64(bits).map(|v| v as u16) + } + + /// An implementation of the `f(n)` function in the spec, where 0 < n <= 8: + /// https://aomediacodec.github.io/av1-rtp-spec/#a82-syntax + pub fn read_u8(&mut self, bits: u8) -> anyhow::Result { + if bits == 0 { + return Ok(0); + } + + assert!(bits <= 8); + + let bytes_len = self.bytes.len(); + if self.byte_index >= bytes_len + || (self.bit_offset + bits > 8 && self.byte_index + 1 == bytes_len) + { + bail!( + "out of bounds access: byte_index={}, bit_offset={}, bits={bits}, bytes_len={}", + self.byte_index, + self.bit_offset, + bytes_len, + ); + } + + let mut byte: u8; + if self.bit_offset + bits >= 8 { + // Need to read the remainder of the current byte, and potentially some of the + // following byte. 
+ byte = self.bytes[self.byte_index]; + + let num_bits_in_current_byte = 8 - self.bit_offset; + if num_bits_in_current_byte < 8 { + byte &= (1 << num_bits_in_current_byte) - 1; + } + let num_bits_in_next_byte = bits - num_bits_in_current_byte; + byte <<= num_bits_in_next_byte; + + if num_bits_in_next_byte > 0 { + let next_byte = self.bytes[self.byte_index + 1]; + let mask = ((1 << num_bits_in_next_byte) - 1) << (8 - num_bits_in_next_byte); + byte |= (next_byte & mask) >> (8 - num_bits_in_next_byte); + } + + self.byte_index += 1; + self.bit_offset = (self.bit_offset + bits) % 8; + } else { + // Only need to look at the current byte. + byte = self.bytes[self.byte_index]; + byte &= ((1 << bits) - 1) << (8 - self.bit_offset - bits); + byte >>= 8 - self.bit_offset - bits; + + self.bit_offset += bits; + } + + Ok(byte) + } + + /// An implementation of the `ns(n)` function in the spec; reads a value in the range 0..n. + /// Returns an error when `n` is 0 because that range is empty, or when the stream runs out. + pub fn read_non_symmetric(&mut self, n: u8) -> anyhow::Result { + if n == 0 { + bail!("non-symmetric read requires n > 0"); + } + + // Number of bits needed to represent `n`; in the range 1..=8. + let w = (u8::BITS - n.leading_zeros()) as u8; + // Widen before shifting: `1u8 << 8` would overflow when n >= 128. + let m = ((1u16 << w) - u16::from(n)) as u8; + let v = self.read_u8(w - 1)?; + if v < m { + return Ok(v); + } + + let extra_bit = self.read_u8(1)?; + Ok((v << 1) - m + extra_bit) + } + + pub fn has_more(&mut self) -> bool { + self.byte_index < self.bytes.len() + } + + pub fn zero_pad(&mut self) { + if self.bit_offset > 0 { + self.bit_offset = 0; + self.byte_index += 1; + } + } + + /// An implementation of https://aomediacodec.github.io/av1-spec/#leb128 + pub fn read_leb128(&mut self) -> anyhow::Result { + let mut value = 0; + for i in 0..8 { + let byte = self.read_u8(8)? as u128; + value |= (byte & 0x7f) << (i * 7); + if byte & 0x80 == 0 { + break; + } + } + Ok(value) + } +} + +macro_rules! 
impl_bit_writer_for_type { + ($type:ty, $func_name:ident) => { + pub fn $func_name(&mut self, value: $type, bits: usize) { + if bits > 0 { + let len = std::mem::size_of::<$type>(); + assert!(bits <= len * 8); + let bytes = value.to_be_bytes(); + let index = (len * 8 - bits) / 8; + let topmost_bits = bits - (len - index - 1) * 8; + self.write_u8(bytes[index], topmost_bits); + for i in index + 1..len { + self.write_u8(bytes[i], 8); + } + } + } + }; +} + +#[derive(Debug, Default)] +pub struct BitstreamWriter { + storage: SmallVec<[u8; N]>, + /// Number of unused low bits in the last byte of `storage`; 0 when `storage` is empty. + free_bits: usize, +} + +impl BitstreamWriter { + impl_bit_writer_for_type!(u16, write_u16); + impl_bit_writer_for_type!(u32, write_u32); + impl_bit_writer_for_type!(u64, write_u64); + + /// Appends the low `bits` bits of `value` to the last byte of `storage`. + /// Callers must ensure `storage` is non-empty and `bits <= self.free_bits`. + fn push_bits(&mut self, value: u8, bits: usize) { + let i = self.storage.len() - 1; + // Drop anything above the low `bits` bits so it cannot clobber bits already written. + let mask = ((1u16 << bits) - 1) as u8; + self.free_bits -= bits; + self.storage[i] |= (value & mask) << self.free_bits; + } + + fn extend(&mut self) { + self.storage.push(0); + self.free_bits = 8; + } + + pub fn write_u8(&mut self, value: u8, bits: usize) { + if bits > 0 { + if self.free_bits == 0 { + self.extend(); + self.push_bits(value, bits); + } else if bits <= self.free_bits { + self.push_bits(value, bits); + } else { + let overflow = bits - self.free_bits; + self.push_bits(value >> overflow, self.free_bits); + self.extend(); + self.push_bits(value & ((1 << overflow) - 1), overflow); + } + } + } + + pub fn write(&mut self, bytes: &[u8]) { + for byte in bytes { + self.write_u8(*byte, 8); + } + } + + pub fn write_padding(&mut self) { + self.write_u8(0, self.free_bits); + } + + /// Writes `v`, expected to be in the range 0..n, using the non-symmetric `ns(n)` encoding. + /// Writes nothing when `n` is 0 or 1 because no bits are needed. + pub fn write_non_symmetric(&mut self, n: usize, v: u8) { + if n <= 1 { + return; + } + let mut w = 0; + let mut x = n; + while x != 0 { + x >>= 1; + w += 1; + } + // Compute in usize: a `u8` shift would overflow when n >= 128. + let m = (1usize << w) - n; + let v = usize::from(v); + if v < m { + self.write_u8(v as u8, w - 1); + } else { + self.write_u8((v + m) as u8, w); + } + } + + /// Returns the number of bits written so far. + pub fn len(&self) -> usize { + 8 * self.storage.len() - self.free_bits + } + + /// Returns true when no bits have been written. + pub fn is_empty(&self) -> bool { + self.storage.is_empty()
+ } + + pub fn as_slice(&self) -> &[u8] { + self.storage.as_slice() + } +} + +#[cfg(test)] +mod bitstream_reader_tests { + use super::*; + + #[test] + fn read_u8() -> anyhow::Result<()> { + let bytes = [0b0000_0010, 0b1010_0000]; + let mut rdr = BitstreamReader::new(&bytes); + + assert_eq!(rdr.read_u8(1)?, 0); + + rdr.bit_offset = 1; + assert_eq!(rdr.read_u8(1)?, 0); + + rdr.bit_offset = 6; + assert_eq!(rdr.read_u8(1)?, 1); + + rdr.bit_offset = 3; + assert_eq!(rdr.read_u8(5)?, 0b10); + + rdr.byte_index = 0; + rdr.bit_offset = 6; + assert_eq!(rdr.read_u8(3)?, 0b101); + + Ok(()) + } + + #[test] + fn read_u8_error() -> anyhow::Result<()> { + let bytes = []; + let mut rdr = BitstreamReader::new(&bytes); + assert!(!rdr.has_more()); + assert!(rdr.read_u8(8).is_err()); + assert!(rdr.read_u8(1).is_err()); + + let bytes = [0b1000_0000]; + let mut rdr = BitstreamReader::new(&bytes); + assert!(rdr.has_more()); + assert_eq!(rdr.read_u8(1)?, 1); + + assert!(rdr.has_more()); + assert!(rdr.read_u8(8).is_err()); + + assert_eq!(rdr.read_u8(1)?, 0); + assert!(rdr.has_more()); + assert!(rdr.read_u8(8).is_err()); + assert!(rdr.read_u8(7).is_err()); + + assert_eq!(rdr.read_u8(1)?, 0); + assert!(rdr.has_more()); + assert!(rdr.read_u8(8).is_err()); + assert!(rdr.read_u8(7).is_err()); + assert!(rdr.read_u8(6).is_err()); + + assert_eq!(rdr.read_u8(1)?, 0); + assert!(rdr.has_more()); + assert!(rdr.read_u8(8).is_err()); + assert!(rdr.read_u8(7).is_err()); + assert!(rdr.read_u8(6).is_err()); + assert!(rdr.read_u8(5).is_err()); + + assert_eq!(rdr.read_u8(1)?, 0); + assert!(rdr.has_more()); + assert!(rdr.read_u8(8).is_err()); + assert!(rdr.read_u8(7).is_err()); + assert!(rdr.read_u8(6).is_err()); + assert!(rdr.read_u8(5).is_err()); + assert!(rdr.read_u8(4).is_err()); + + assert_eq!(rdr.read_u8(1)?, 0); + assert!(rdr.has_more()); + assert!(rdr.read_u8(8).is_err()); + assert!(rdr.read_u8(7).is_err()); + assert!(rdr.read_u8(6).is_err()); + assert!(rdr.read_u8(5).is_err()); + 
assert!(rdr.read_u8(4).is_err()); + assert!(rdr.read_u8(3).is_err()); + + assert_eq!(rdr.read_u8(1)?, 0); + assert!(rdr.has_more()); + assert!(rdr.read_u8(8).is_err()); + assert!(rdr.read_u8(7).is_err()); + assert!(rdr.read_u8(6).is_err()); + assert!(rdr.read_u8(5).is_err()); + assert!(rdr.read_u8(4).is_err()); + assert!(rdr.read_u8(3).is_err()); + assert!(rdr.read_u8(2).is_err()); + + assert_eq!(rdr.read_u8(1)?, 0); + assert!(!rdr.has_more()); + assert!(rdr.read_u8(8).is_err()); + assert!(rdr.read_u8(7).is_err()); + assert!(rdr.read_u8(6).is_err()); + assert!(rdr.read_u8(5).is_err()); + assert!(rdr.read_u8(4).is_err()); + assert!(rdr.read_u8(3).is_err()); + assert!(rdr.read_u8(2).is_err()); + assert!(rdr.read_u8(1).is_err()); + + Ok(()) + } + + #[test] + fn read_u8_two_bytes() -> anyhow::Result<()> { + let bytes = [0b0000_0010, 0b1010_0011]; + let mut rdr = BitstreamReader::new(&bytes); + + assert_eq!(rdr.read_u8(1)?, 0); + assert_eq!(rdr.read_u8(1)?, 0); + assert_eq!(rdr.read_u8(5)?, 0b1); + assert_eq!(rdr.read_u8(1)?, 0); + assert_eq!(rdr.read_u8(4)?, 0b1010); + assert_eq!(rdr.read_u8(2)?, 0b0); + assert_eq!(rdr.read_u8(1)?, 0b1); + assert_eq!(rdr.read_u8(1)?, 0b1); + + assert!(rdr.read_u8(1).is_err()); + + Ok(()) + } + + #[test] + fn read_u8_one_byte() -> anyhow::Result<()> { + let bytes = [0b0001_1011]; + let mut rdr = BitstreamReader::new(&bytes); + + assert_eq!(rdr.read_u8(1)?, 0); + assert_eq!(rdr.read_u8(1)?, 0); + assert_eq!(rdr.read_u8(1)?, 0); + assert_eq!(rdr.read_u8(1)?, 1); + assert_eq!(rdr.read_u8(1)?, 1); + assert_eq!(rdr.read_u8(1)?, 0); + assert_eq!(rdr.read_u8(1)?, 1); + assert_eq!(rdr.read_u8(1)?, 1); + + rdr.byte_index = 0; + rdr.bit_offset = 0; + assert_eq!(rdr.read_u8(2)?, 0b0); + assert_eq!(rdr.read_u8(2)?, 0b1); + assert_eq!(rdr.read_u8(2)?, 0b10); + assert_eq!(rdr.read_u8(2)?, 0b11); + + rdr.byte_index = 0; + rdr.bit_offset = 0; + assert_eq!(rdr.read_u8(4)?, 0b1); + assert_eq!(rdr.read_u8(4)?, 0b1011); + + rdr.byte_index = 0; + 
rdr.bit_offset = 0; + assert_eq!(rdr.read_u8(5)?, 0b11); + assert_eq!(rdr.read_u8(3)?, 0b11); + + rdr.byte_index = 0; + rdr.bit_offset = 0; + assert_eq!(rdr.read_u8(8)?, 0b0001_1011); + + Ok(()) + } +} + +#[cfg(test)] +mod bitstream_writer_tests { + use crate::bitstream::BitstreamWriter; + + #[test] + fn test_bitwriter_u8_write() { + let mut writer = BitstreamWriter::<32>::default(); + writer.write_u8(0b00000011, 2); + assert_eq!(writer.as_slice()[0], 0b11000000); + writer.write_u8(0b00000101, 3); + assert_eq!(writer.as_slice()[0], 0b11101000); + writer.write_u8(0b00001111, 4); + assert_eq!(writer.as_slice()[0], 0b11101111); + assert_eq!(writer.as_slice()[1], 0b10000000); + } +} diff --git a/backend/src/call.rs b/backend/src/call.rs index c95c634..0aff35a 100644 --- a/backend/src/call.rs +++ b/backend/src/call.rs @@ -1567,21 +1567,12 @@ impl CallInner { sender.is_maybe_in_dtx = false; } } - let dependency_descriptor = if incoming_rtp.is_vp8() { - time_scope_us!("calling.call.handle_rtp.vp8_header"); - if let Some((dependency_descriptor, need_reallocation)) = - sender.update_incoming_video_rate_and_resolution(&incoming_rtp, now)? - { - if need_reallocation { - self.reallocate_target_send_rates(now); - } - Some(dependency_descriptor) - } else { - return Ok(vec![]); - } - } else { - None - }; + + if incoming_rtp.is_vp8() + && sender.update_incoming_video_rate_and_resolution(&incoming_rtp, now)? 
+ { + self.reallocate_target_send_rates(now); + } let mut rtp_to_send = vec![]; @@ -1598,12 +1589,17 @@ impl CallInner { LayerId::Audio => receiver.forward_audio_rtp(&incoming_rtp), LayerId::RtpData => receiver.forward_data_rtp(&incoming_rtp), LayerId::Video0 | LayerId::Video1 | LayerId::Video2 => { - receiver.forward_video_rtp(&incoming_rtp, dependency_descriptor.as_ref()) + if incoming_rtp.is_vp8() { + receiver.forward_video_rtp_vp8(&incoming_rtp) + } else { + None + } } } { rtp_to_send.push((receiver.demux_id, rtp_to_forward)); } } + Ok(rtp_to_send) } @@ -2921,13 +2917,10 @@ impl Client { &mut self, incoming_rtp: &rtp::Packet<&[u8]>, now: Instant, - ) -> Result, Error> { - let dependency_descriptor = - if let Some((descriptor, _)) = incoming_rtp.dependency_descriptor { - descriptor - } else { - return Err(Error::MissingDependencyDescriptor); - }; + ) -> Result { + let Some((dependency_descriptor, _)) = incoming_rtp.dependency_descriptor.as_ref() else { + return Err(Error::MissingDependencyDescriptor); + }; let incoming_layer_index = LayerId::layer_index_from_ssrc(incoming_rtp.ssrc()).ok_or(Error::InvalidRtpLayerId)?; let incoming_video = &mut self.incoming_video[incoming_layer_index]; @@ -2949,7 +2942,7 @@ impl Client { } let old_resolution = incoming_video.original_resolution; - if let Some(resolution) = dependency_descriptor.resolution { + if let Some(resolution) = dependency_descriptor.resolution() { incoming_video.apply_resolution(resolution, self.video_rotation); } else if old_resolution.is_none() && !incoming_video.needs_resolution { // Record that we have data on the stream, when the resolution has been cleared. @@ -2963,7 +2956,7 @@ impl Client { // If this is a key frame, and it was not allocatable before, update the bitrate and run // allocation; this allows for switching to a new stream on the first key frame. 
- if dependency_descriptor.is_key_frame + if dependency_descriptor.is_key_frame() && (old_resolution.is_none() || incoming_video.rate().unwrap_or_default().as_bps() == 0) { incoming_video.rate_tracker.update(now); @@ -3040,7 +3033,7 @@ impl Client { } } - Ok(Some((dependency_descriptor, need_reallocation))) + Ok(need_reallocation) } fn forward_audio_rtp( @@ -3060,28 +3053,21 @@ impl Client { Some(outgoing_rtp) } - fn forward_video_rtp( + fn forward_video_rtp_vp8( &mut self, incoming_rtp: &rtp::Packet<&[u8]>, - dependency_descriptor: Option<&rtp::DependencyDescriptor>, ) -> Option>> { - let dependency_descriptor = dependency_descriptor?; - let sender_demux_id = DemuxId::from_ssrc(incoming_rtp.ssrc()); let forwarder = self .video_forwarder_by_sender_demux_id .get_mut(&sender_demux_id)?; - let (outgoing_ssrc, outgoing) = - forwarder.forward_vp8_rtp(incoming_rtp, dependency_descriptor)?; + let (outgoing_ssrc, outgoing) = forwarder.forward_vp8_rtp(incoming_rtp)?; let mut outgoing_rtp = incoming_rtp.rewrite( outgoing_ssrc, outgoing.seqnum, outgoing.timestamp as rtp::TruncatedTimestamp, ); - if let Some((descriptor, _)) = &mut outgoing_rtp.dependency_descriptor { - descriptor.truncated_frame_number = outgoing.frame_number as rtp::TruncatedFrameNumber; - } outgoing_rtp.set_frame_number_in_header(outgoing.frame_number); Some(outgoing_rtp) } @@ -3679,10 +3665,10 @@ impl Vp8SimulcastRtpForwarder { fn forward_vp8_rtp( &mut self, incoming_rtp: &rtp::Packet<&[u8]>, - dependency_descriptor: &rtp::DependencyDescriptor, ) -> Option<(rtp::Ssrc, VideoRewrittenIds)> { + let (dependency_descriptor, _) = incoming_rtp.dependency_descriptor.as_ref()?; if self.switching_ssrc() == Some(incoming_rtp.ssrc()) - && dependency_descriptor.is_key_frame + && dependency_descriptor.is_key_frame() && (incoming_rtp.is_max_seqnum || self.forwarding_ssrc().is_none()) { // When switching from forwarding one SSRC to another, we only @@ -3718,7 +3704,7 @@ impl Vp8SimulcastRtpForwarder { // In other words, 
we are only tracking the ROC since the switching point, // and that is now, so the ROC is 0. incoming_rtp.timestamp as rtp::FullTimestamp, - dependency_descriptor.truncated_frame_number as rtp::FullFrameNumber, + dependency_descriptor.truncated_frame_number() as rtp::FullFrameNumber, ); // We make two simplifying assumptions here: // 1. The first packet we received is the first packet of the key frame. @@ -3745,7 +3731,7 @@ impl Vp8SimulcastRtpForwarder { self.switching = Vp8SimulcastRtpSwitchingState::DoNotSwitch; self.max_outgoing = first_outgoing; } else if self.switching_ssrc() == Some(incoming_rtp.ssrc()) - && dependency_descriptor.is_key_frame + && dependency_descriptor.is_key_frame() { event!("calling.forwarding.layer_switch.wait_for_in_order_key_frame"); trace!( @@ -3767,7 +3753,7 @@ impl Vp8SimulcastRtpForwarder { { if *incoming_ssrc == incoming_rtp.ssrc() { let expanded_frame_number = rtp::expand_frame_number( - dependency_descriptor.truncated_frame_number, + dependency_descriptor.truncated_frame_number(), &mut max_incoming.frame_number, ); @@ -3781,7 +3767,7 @@ impl Vp8SimulcastRtpForwarder { first_outgoing.checked_add(&incoming.checked_sub(first_incoming)?)?; self.max_outgoing = self.max_outgoing.max(&outgoing); - if dependency_descriptor.is_key_frame { + if dependency_descriptor.is_key_frame() { *needs_key_frame = false; } trace!( @@ -3898,7 +3884,13 @@ mod call_tests { use mrp::MrpHeader; use super::*; - use crate::protos::sfu_to_device::{peek_info::PeekDeviceInfo, PeekInfo}; + use crate::{ + protos::sfu_to_device::{peek_info::PeekDeviceInfo, PeekInfo}, + rtp::{ + DependencyDescriptor, ExtendedDescriptorFields, MandatoryDescriptorFields, + TemplateDependencyStructure, TemplateDependencyStructureFields, + }, + }; static CALL_ID: &[u8; 7] = b"call_id"; @@ -3943,7 +3935,30 @@ mod call_tests { ssrc: u32, index: u32, rtp: rtp::Packet>, - dependency_descriptor: rtp::DependencyDescriptor, + resolution: Option, + } + + fn dummy_dependency_descriptor( + 
is_key_frame: bool, + frame_number: u16, + resolution: Option, + ) -> DependencyDescriptor { + DependencyDescriptor { + mandatory_fields: MandatoryDescriptorFields { + start_of_frame: true, + frame_number, + ..Default::default() + }, + extended_fields: Some(ExtendedDescriptorFields { + template_dependency_structure: is_key_frame.then_some( + TemplateDependencyStructure::new(TemplateDependencyStructureFields { + resolutions: resolution.map(|resolution| [resolution.into()].into()), + ..Default::default() + }), + ), + ..Default::default() + }), + } } impl Incoming { @@ -3965,7 +3980,7 @@ mod call_tests { self.rtp.ssrc(), self.index + 1, is_key_frame, - self.dependency_descriptor.resolution, + self.resolution, ) } @@ -3981,15 +3996,19 @@ mod call_tests { let mut rtp = rtp::Packet::with_empty_tag(pt, seqnum, timestamp, ssrc, None, None, &[]); rtp.is_max_seqnum = true; + rtp.dependency_descriptor = Some(( + dummy_dependency_descriptor( + is_key_frame, + ((1000 * ssrc) + index) as u16, + resolution, + ), + 0..0, + )); Self { ssrc, index, rtp, - dependency_descriptor: rtp::DependencyDescriptor { - truncated_frame_number: ((1000 * ssrc) + index) as u16, - is_key_frame, - resolution, - }, + resolution, } } @@ -4002,12 +4021,10 @@ mod call_tests { let mut rtp = self.rtp.clone(); rtp.set_seqnum_in_header(seqnum); rtp.set_timestamp_in_header(timestamp); + let (dependency_descriptor, _) = rtp.dependency_descriptor.as_mut().unwrap(); + dependency_descriptor.mandatory_fields.frame_number = truncated_frame_number; Self { rtp, - dependency_descriptor: rtp::DependencyDescriptor { - truncated_frame_number, - ..self.dependency_descriptor - }, ..self.clone() } } @@ -4016,7 +4033,7 @@ mod call_tests { &self, forwarder: &mut Vp8SimulcastRtpForwarder, ) -> Option<(rtp::Ssrc, VideoRewrittenIds)> { - forwarder.forward_vp8_rtp(&self.rtp.borrow(), &self.dependency_descriptor) + forwarder.forward_vp8_rtp(&self.rtp.borrow()) } } @@ -5058,6 +5075,46 @@ mod call_tests { create_rtp(sender_demux_id, 
layer_id, seqnum, &payload[..]) } + fn create_dependency_descriptor( + frame_number: u16, + size: Option, + ) -> DependencyDescriptor { + let mut descriptor = if let Some(size) = size { + let width_minus_1 = size.width - 1; + let height_minus_1 = size.height - 1; + let encoded = [ + 0b10000000u8, + 0b00000000, + 0b00000001, + 0b10000000, // The first bit in this byte indicates that this is for a key frame. + 0b00000010, + 0b00000100, + 0b01001110, + 0b10101010, + 0b10101111, + 0b00101000, + 0b01100000, + 0b01000001, + 0b01001101, + 0b00110100, + 0b01010011, + 0b10001010, + 0b00001001, + 0b01000000, + 0b0100_0000 | ((width_minus_1 >> 10) as u8), // width - 1 from 3rd bit + ((width_minus_1 >> 2) & 0xFF) as u8, + (((width_minus_1 & 0b0000_0011) << 6) as u8) | ((height_minus_1 >> 10) as u8), // height - 1 from 3rd bit + ((height_minus_1 >> 2) & 0xFF) as u8, + ((height_minus_1 & 0b0000_0011) << 6) as u8, + ]; + DependencyDescriptor::read(&encoded, None).expect("parse the dependency descriptor") + } else { + DependencyDescriptor::default() + }; + descriptor.mandatory_fields.frame_number = frame_number; + descriptor + } + fn create_video_rtp( sender_demux_id: DemuxId, layer_id: LayerId, @@ -5071,16 +5128,14 @@ mod call_tests { let ssrc = layer_id.to_ssrc(sender_demux_id); let pt = 108; let timestamp = seqnum as rtp::TruncatedTimestamp; + let dependency_descriptor = + create_dependency_descriptor(truncated_frame_number, key_frame_size); rtp::Packet::with_dependency_descriptor( pt, seqnum, timestamp, ssrc, - rtp::DependencyDescriptor { - is_key_frame: key_frame_size.is_some(), - resolution: key_frame_size, - truncated_frame_number, - }, + dependency_descriptor, &payload, ) } diff --git a/backend/src/connection.rs b/backend/src/connection.rs index 58a8050..e03d382 100644 --- a/backend/src/connection.rs +++ b/backend/src/connection.rs @@ -22,7 +22,7 @@ use crate::{ pacer::{self, Pacer}, packet_server::{AddressType, SocketLocator}, region::RegionRelation, - rtp::{self, 
TruncatedSequenceNumber}, + rtp::{self, TemplateDependencyStructure, TruncatedSequenceNumber}, sfu::ConnectionId, }; @@ -282,9 +282,12 @@ impl Connection { pub fn handle_rtp_packet<'packet>( &self, incoming_packet: &'packet mut [u8], + template_dependency_structure: Option<&TemplateDependencyStructure>, now: Instant, ) -> Result>, Error> { - self.inner.write().handle_rtp_packet(incoming_packet, now) + self.inner + .write() + .handle_rtp_packet(incoming_packet, template_dependency_structure, now) } /// Decrypts an incoming RTCP packet and processes it. @@ -725,6 +728,7 @@ impl ConnectionInner { fn handle_rtp_packet<'packet>( &mut self, incoming_packet: &'packet mut [u8], + template_dependency_structure: Option<&TemplateDependencyStructure>, now: Instant, ) -> Result>, Error> { if self.closed { @@ -733,7 +737,7 @@ impl ConnectionInner { let rtp_endpoint = &mut self.rtp.endpoint; let size = incoming_packet.len(); - match rtp_endpoint.receive_rtp(incoming_packet, now) { + match rtp_endpoint.receive_rtp(incoming_packet, template_dependency_structure, now) { Some(packet) => { if packet.is_rtx() { event!("calling.bandwidth.incoming.rtx_bytes", size); @@ -1553,7 +1557,7 @@ mod connection_tests { assert_eq!( expected_decrypted_rtp.to_owned(), connection - .handle_rtp_packet(&mut encrypted_rtp.into_serialized(), now) + .handle_rtp_packet(&mut encrypted_rtp.into_serialized(), None, now) .unwrap() .unwrap() .to_owned() @@ -1562,7 +1566,7 @@ mod connection_tests { let encrypted_rtp = new_encrypted_rtp(2, None, &encrypt, now); assert_eq!( Err(Error::ReceivedInvalidRtp), - connection.handle_rtp_packet(&mut encrypted_rtp.into_serialized(), now) + connection.handle_rtp_packet(&mut encrypted_rtp.into_serialized(), None, now) ); let encrypted_rtp = new_encrypted_rtx_rtp(5, 2, None, &decrypt, now); @@ -1570,7 +1574,7 @@ mod connection_tests { assert_eq!( expected_decrypted_rtp.borrow().to_owned(), connection - .handle_rtp_packet(&mut encrypted_rtp.into_serialized(), now) + 
.handle_rtp_packet(&mut encrypted_rtp.into_serialized(), None, now) .unwrap() .unwrap() .to_owned() @@ -1665,7 +1669,7 @@ mod connection_tests { let (buf, _addr) = &packets_to_send[0]; assert_eq!(1172, buf.len()); - let actual_padding_header = rtp::Header::parse(buf).unwrap(); + let actual_padding_header = rtp::Header::parse(buf, None).unwrap(); assert_eq!(padding_ssrc, actual_padding_header.ssrc); assert_eq!(99, actual_padding_header.payload_type); assert_eq!(1136, actual_padding_header.payload_range.len()); @@ -1710,7 +1714,7 @@ mod connection_tests { let (buf, _addr) = &packets_to_send[0]; assert_eq!(1172, buf.len()); - let actual_padding_header = rtp::Header::parse(buf).unwrap(); + let actual_padding_header = rtp::Header::parse(buf, None).unwrap(); assert_eq!(padding_ssrc, actual_padding_header.ssrc); assert_eq!(99, actual_padding_header.payload_type); assert_eq!(1136, actual_padding_header.payload_range.len()); @@ -1815,6 +1819,7 @@ mod connection_tests { connection .handle_rtp_packet( &mut new_encrypted_rtp(1, Some(101), &decrypt, at(1)).into_serialized(), + None, at(1), ) .unwrap() @@ -1823,6 +1828,7 @@ mod connection_tests { connection .handle_rtp_packet( &mut new_encrypted_rtp(3, Some(103), &decrypt, at(3)).into_serialized(), + None, at(3), ) .unwrap() @@ -1875,6 +1881,7 @@ mod connection_tests { connection .handle_rtp_packet( &mut new_encrypted_rtp(2, Some(102), &decrypt, at(10002)).into_serialized(), + None, at(10002), ) .unwrap() diff --git a/backend/src/lib.rs b/backend/src/lib.rs index ce4af91..5a46e69 100644 --- a/backend/src/lib.rs +++ b/backend/src/lib.rs @@ -4,6 +4,7 @@ // pub mod audio; +pub mod bitstream; pub mod call; pub mod call_lifecycle; pub mod candidate_selector; diff --git a/backend/src/rtp.rs b/backend/src/rtp.rs index 8bc5d1a..f93cd18 100644 --- a/backend/src/rtp.rs +++ b/backend/src/rtp.rs @@ -6,6 +6,7 @@ //! Implementation of RTP/SRTP. See https://tools.ietf.org/html/rfc3550 and //! https://tools.ietf.org/html/rfc7714. 
Assumes AES-GCM 128. +mod dependency_descriptor; mod nack; mod packet; mod rtcp; @@ -16,13 +17,15 @@ mod types; use std::{collections::HashMap, convert::TryInto}; use calling_common::{expand_truncated_counter, read_u16, Bits, Duration, Instant, Writer}; +pub use dependency_descriptor::*; use log::*; use metrics::*; use nack::*; pub use nack::{write_nack, Nack}; use packet::*; -pub use packet::{DependencyDescriptor, Header, Packet}; -pub use rtcp::{ControlPacket, KeyFrameRequest, *}; +pub use packet::{Header, Packet}; +use rtcp::*; +pub use rtcp::{ControlPacket, KeyFrameRequest}; pub use rtx::to_rtx_ssrc; use rtx::*; use srtp::*; @@ -38,6 +41,7 @@ const PADDING_PAYLOAD_TYPE: PayloadType = 99; const CLIENT_SERVER_DATA_PAYLOAD_TYPE: PayloadType = 101; pub const OPUS_PAYLOAD_TYPE: PayloadType = 102; pub const VP8_PAYLOAD_TYPE: PayloadType = 108; +pub const VP9_PAYLOAD_TYPE: PayloadType = 109; // Discard outgoing packets after this time. // 3 second lifetime matches WebRTC's RTX history @@ -114,7 +118,7 @@ fn is_audio_payload_type(pt: PayloadType) -> bool { } fn is_video_payload_type(pt: PayloadType) -> bool { - pt == VP8_PAYLOAD_TYPE + pt == VP8_PAYLOAD_TYPE || pt == VP9_PAYLOAD_TYPE } fn is_padding_payload_type(pt: PayloadType) -> bool { @@ -228,10 +232,11 @@ impl Endpoint { pub fn receive_rtp<'packet>( &mut self, encrypted: &'packet mut [u8], + template_dependency_structure: Option<&TemplateDependencyStructure>, now: Instant, ) -> Option> { // Header::parse will log a warning for every place where it fails to parse. 
- let header = Header::parse(encrypted)?; + let header = Header::parse(encrypted, template_dependency_structure)?; let tcc_seqnum = header .tcc_seqnum @@ -859,7 +864,7 @@ pub mod fuzz { } pub fn parse_and_forward_rtp_for_fuzzing(data: Vec) -> Option> { - let header = Header::parse(&data)?; + let header = Header::parse(&data, None)?; let mut incoming = Packet::new( &header, @@ -946,7 +951,7 @@ mod test { ) .unwrap(); let received1 = receiver - .receive_rtp(sent1.serialized.borrow_mut(), at(10)) + .receive_rtp(sent1.serialized.borrow_mut(), None, at(10)) .unwrap(); sent1.encrypted = false; // Got decrypted by the above let received1 = received1.to_owned(); @@ -994,7 +999,7 @@ mod test { ) .unwrap(); let received3 = receiver - .receive_rtp(sent3.serialized.borrow_mut(), at(20)) + .receive_rtp(sent3.serialized.borrow_mut(), None, at(20)) .unwrap(); sent3.encrypted = false; // Got decrypted by the above let received3 = received3.to_owned(); @@ -1018,7 +1023,7 @@ mod test { let mut resent2 = sender.resend_rtp(3, 2, at(50)).unwrap(); let received2 = receiver - .receive_rtp(resent2.serialized.borrow_mut(), at(60)) + .receive_rtp(resent2.serialized.borrow_mut(), None, at(60)) .unwrap(); resent2.encrypted = false; // Got decrypted by the above let mut received2 = received2.to_owned(); @@ -1047,7 +1052,7 @@ mod test { .send_rtp(sent2.rewrite(33, 22, 44), at(70)) .unwrap(); let forwarded2 = sender - .receive_rtp(forwarded2.serialized.borrow_mut(), at(80)) + .receive_rtp(forwarded2.serialized.borrow_mut(), None, at(80)) .unwrap(); let forwarded2 = forwarded2.to_owned(); assert_eq!(sent2.payload_type(), forwarded2.payload_type()); @@ -1063,13 +1068,13 @@ mod test { .unwrap(); assert_eq!(1, reforwarded2.seqnum_in_header); assert_eq!(Some(22), reforwarded2.seqnum_in_payload); - let reforwarded2 = sender.receive_rtp(reforwarded2.serialized.borrow_mut(), at(90)); + let reforwarded2 = sender.receive_rtp(reforwarded2.serialized.borrow_mut(), None, at(90)); assert_eq!(reforwarded2, 
None); // Padding let mut padding = sender.send_padding(4, at(100)).unwrap(); let received_padding = receiver - .receive_rtp(padding.serialized.borrow_mut(), at(110)) + .receive_rtp(padding.serialized.borrow_mut(), None, at(110)) .unwrap(); assert_eq!(99, received_padding.payload_type()); assert_eq!(99, received_padding.payload_type_in_header); @@ -1179,14 +1184,14 @@ mod test { .unwrap(); let mut sent200b = sent200a.clone(); - let received1a = receiver.receive_rtp(sent1a.serialized.borrow_mut(), at(10)); - let received1b = receiver.receive_rtp(sent1b.serialized.borrow_mut(), at(10)); - let received2a = receiver.receive_rtp(sent2a.serialized.borrow_mut(), at(10)); - let received2b = receiver.receive_rtp(sent2b.serialized.borrow_mut(), at(10)); - let received200a = receiver.receive_rtp(sent200a.serialized.borrow_mut(), at(10)); - let received200b = receiver.receive_rtp(sent200b.serialized.borrow_mut(), at(10)); - let received1c = receiver.receive_rtp(sent1c.serialized.borrow_mut(), at(10)); - let received2c = receiver.receive_rtp(sent2c.serialized.borrow_mut(), at(10)); + let received1a = receiver.receive_rtp(sent1a.serialized.borrow_mut(), None, at(10)); + let received1b = receiver.receive_rtp(sent1b.serialized.borrow_mut(), None, at(10)); + let received2a = receiver.receive_rtp(sent2a.serialized.borrow_mut(), None, at(10)); + let received2b = receiver.receive_rtp(sent2b.serialized.borrow_mut(), None, at(10)); + let received200a = receiver.receive_rtp(sent200a.serialized.borrow_mut(), None, at(10)); + let received200b = receiver.receive_rtp(sent200b.serialized.borrow_mut(), None, at(10)); + let received1c = receiver.receive_rtp(sent1c.serialized.borrow_mut(), None, at(10)); + let received2c = receiver.receive_rtp(sent2c.serialized.borrow_mut(), None, at(10)); assert!(received1a.is_some()); assert!(received1b.is_none()); diff --git a/backend/src/rtp/dependency_descriptor.rs b/backend/src/rtp/dependency_descriptor.rs new file mode 100644 index 0000000..70120b0 --- 
/dev/null +++ b/backend/src/rtp/dependency_descriptor.rs @@ -0,0 +1,1109 @@ +// +// Copyright 2026 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only +// + +// Implements parsing and serialization of the Dependency Descriptor as defined in Appendix A +// of the AV1 RTP specification (https://aomediacodec.github.io/av1-rtp-spec/). + +use std::{ + fmt::Debug, + ops::{Deref, DerefMut}, + sync::Arc, +}; + +use anyhow::{anyhow, bail, Result}; +use calling_common::PixelSize; +use smallvec::SmallVec; + +use crate::bitstream::{BitstreamReader, BitstreamWriter}; + +pub type DefaultBitstreamWriter = BitstreamWriter<128>; + +/// RTP header extension containing frame dependency metadata for scalable video streams. +/// +/// # Structure +/// +/// - **Mandatory fields** (3 bytes minimum): Present in every descriptor, containing +/// frame boundaries, template ID, and frame number. +/// - **Extended fields** (optional): Present when descriptor size > 3 bytes, containing +/// template dependency structure, active decode targets, and custom overrides. +/// +/// # Key Frame vs Delta Frame +/// +/// - **Key frames** include a `template_dependency_structure` that defines the entire +/// scalability structure (templates, layers, decode targets, chains). +/// - **Delta frames** reference templates from the most recent key frame, optionally +/// overriding DTIs, Fdiffs, or Chains with custom values. +#[derive(Debug, Default, PartialEq, Eq, Clone)] +pub struct DependencyDescriptor { + pub mandatory_fields: MandatoryDescriptorFields, + pub extended_fields: Option, +} + +#[derive(Debug, Default, PartialEq, Eq, Clone)] +pub struct MandatoryDescriptorFields { + /// MUST be set to `true` if the first payload byte of the RTP packet is the beginning of a new + /// frame, and MUST be set to `false` otherwise. Note that this frame might not be the first + /// frame of a temporal unit. 
+ pub start_of_frame: bool, + /// MUST be set to `true` for the final RTP packet of a frame, and MUST be set to `false` otherwise. + /// Note that, if spatial scalability is in use, more frames from the same temporal unit may + /// follow. + pub end_of_frame: bool, + /// ID of the Frame dependency template to use. MUST be in the range of template_id_offset to + /// (template_id_offset + TemplateCnt - 1), inclusive. frame_dependency_template_id MUST be + /// the same for all packets of the same frame. + pub frame_dependency_template_id: u8, + /// The frame number is represented using 16 bits and increases strictly monotonically in decode + /// order. frame_number MAY start on a random number, and MUST wrap after reaching the maximum + /// value. All packets of the same frame MUST have the same frame_number value. + pub frame_number: u16, +} + +#[derive(Debug, Default, PartialEq, Eq, Clone)] +pub struct ExtendedDescriptorFields { + /// This field defines a set of frame templates that describe how frames relate to each other + /// in a scalable video stream. This field is transmitted in key frames to establish + /// the decoding framework. + pub template_dependency_structure: Option, + /// active_decode_targets_bitmask contains a bitmask that indicates which Decode targets are + /// available for decoding. Bit i is equal to 1 if Decode target i is available for decoding, + /// 0 otherwise. The least significant bit corresponds to Decode target 0. + pub active_decode_targets_bitmask: ActiveDecodeTargetsBitmask, + /// Frame DTIs, if present + pub custom_dtis: Option, + /// Frame Fdiffs, if present + pub custom_fdiffs: Option, + /// Frame chains, if present + pub custom_chains: Option, +} + +/// Represents the state and availability of decode targets in the dependency descriptor. +/// - **Uninitialized**: Default state before any decode targets are configured. +/// - **AllImplicitlyActive**: All decode targets are implicitly active based on template +/// dependency structure.
Contains the bitmask derived from decode target count where all +/// targets are enabled. +/// - **Available**: Explicitly specifies which decode targets are active via a custom bitmask. +/// Present when the active decode targets present bit is set in extended descriptor fields. +#[derive(Debug, Default, PartialEq, Eq, Clone)] +pub enum ActiveDecodeTargetsBitmask { + #[default] + Uninitialized, + AllImplicitlyActive { + bitmask: u32, + size: usize, + }, + Available { + bitmask: u32, + size: usize, + }, +} + +#[derive(Clone, Debug, Default, PartialEq, Eq)] +pub struct Template { + pub layer: Layer, + pub dtis: Dtis, + pub fdiffs: Fdiffs, + pub chains: Chains, +} + +// This is a simple macro that implements Deref and DerefMut for simple containers that wrap +// SmallVec. Specifically, this is used for Dtis, Fdiffs, Chains, Resolutions, and Layers. +// Improves legibility. +macro_rules! impl_smallvec_container { + ($name:ident, $inner_type:ty, $initial_size:tt) => { + #[derive(Debug, Default, Clone, PartialEq, Eq)] + pub struct $name(SmallVec<[$inner_type; $initial_size]>); + + impl $name { + fn push(&mut self, item: $inner_type) { + self.0.push(item); + } + } + + impl Deref for $name { + type Target = [$inner_type]; + fn deref(&self) -> &Self::Target { + &self.0 + } + } + + impl DerefMut for $name { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } + } + + #[cfg(any(test, feature = "load_test"))] + impl From<[$inner_type; N]> for $name { + fn from(v: [$inner_type; N]) -> Self { + Self(SmallVec::from_iter(v.into_iter())) + } + } + }; +} + +impl_smallvec_container!(Layers, Layer, 16); +impl_smallvec_container!(DecodeTargetChainIndices, u8, 16); +impl_smallvec_container!(Resolutions, Resolution, 8); +impl_smallvec_container!(Chains, u8, 16); +impl_smallvec_container!(CustomChains, u8, 16); +impl_smallvec_container!(Fdiffs, u8, 16); +impl_smallvec_container!(CustomFdiffs, u16, 16); +impl_smallvec_container!(Dtis, Dti, 16); + +#[derive(Debug, Default, 
Clone, Copy, PartialEq, Eq)] +pub struct Layer { + pub spatial_id: u8, + pub temporal_id: u8, +} + +impl Layer { + pub const fn zero() -> Self { + Self { + spatial_id: 0, + temporal_id: 0, + } + } +} + +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] +pub struct Resolution { + pub width: u16, + pub height: u16, +} + +impl Resolution { + pub const fn zero() -> Self { + Self { + width: 0, + height: 0, + } + } +} + +impl From for PixelSize { + fn from(resolution: Resolution) -> Self { + let Resolution { width, height } = resolution; + Self { width, height } + } +} + +impl From for Resolution { + fn from(pixel_size: PixelSize) -> Self { + let PixelSize { width, height } = pixel_size; + Self { width, height } + } +} + +#[derive(Debug, Default, PartialEq, Eq)] +pub struct TemplateDependencyStructureFields { + pub template_id_offset: usize, // Starting template ID (6 bits, 0-63) + pub decode_target_count: usize, // Number of decode targets + pub chain_count: usize, // Number of chains for loss detection + pub max_layer: Layer, // Highest spatial/temporal layer + pub layers: Layers, // All layer combinations + pub templates: Vec