Merge permessage-deflate support into master

This commit is contained in:
Alex Bakon 2025-09-12 13:40:34 -04:00
commit 691b1e712f
22 changed files with 3677 additions and 414 deletions

2
.gitignore vendored
View File

@ -1,3 +1,5 @@
target
Cargo.lock
.vscode
autobahn/client/
autobahn/server/

View File

@ -19,17 +19,20 @@ all-features = true
[features]
default = ["handshake"]
handshake = ["data-encoding", "http", "httparse", "sha1"]
handshake = ["data-encoding", "headers", "httparse", "sha1"]
headers = ["http", "dep:headers"]
url = ["dep:url"]
native-tls = ["native-tls-crate"]
native-tls-vendored = ["native-tls", "native-tls-crate/vendored"]
rustls-tls-native-roots = ["__rustls-tls", "rustls-native-certs"]
rustls-tls-webpki-roots = ["__rustls-tls", "webpki-roots"]
__rustls-tls = ["rustls", "rustls-pki-types"]
deflate = ["headers", "flate2"]
[dependencies]
data-encoding = { version = "2", optional = true }
bytes = "1.9.0"
headers = { version = "0.4.0", optional = true }
http = { version = "1.0", optional = true }
httparse = { version = "1.3.4", optional = true }
log = "0.4.8"
@ -39,6 +42,16 @@ thiserror = "2.0.7"
url = { version = "2.1.0", optional = true }
utf-8 = "0.7.5"
[dependencies.flate2]
optional = true
version = "1.0.27"
default-features = false
# We need some zlib-compatible backend from flate2 to support setting the
# context window size. Enabling the "zlib" feature is enough, and, per the
# flate2 documentation, if other crates in the build graph enable the "zlib-ng"
# or "zlib-ng-compat" features, those will take precedence.
features = ["zlib"]
[dependencies.native-tls-crate]
optional = true
package = "native-tls"
@ -63,6 +76,8 @@ optional = true
version = "0.26"
[dev-dependencies]
http = "1.0"
httparse = "1.3.4"
criterion = "0.6"
env_logger = "0.11"
input_buffer = "0.5.0"
@ -99,11 +114,11 @@ required-features = ["handshake"]
[[example]]
name = "autobahn-client"
required-features = ["handshake"]
required-features = ["handshake", "deflate"]
[[example]]
name = "autobahn-server"
required-features = ["handshake"]
required-features = ["handshake", "deflate"]
[[example]]
name = "callback-error"
@ -111,4 +126,4 @@ required-features = ["handshake"]
[[example]]
name = "srv_accept_unmasked_frames"
required-features = ["handshake"]
required-features = ["handshake"]

View File

@ -72,8 +72,6 @@ Choose the one that is appropriate for your needs.
By default **no TLS feature is activated**, so make sure you use one of the TLS features,
otherwise you won't be able to communicate with the TLS endpoints.
There is no support for permessage-deflate at the moment, but the PRs are welcome :wink:
Testing
-------

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,9 @@
use log::*;
use tungstenite::{connect, Error, Message, Result};
use tungstenite::{
client::connect_with_config, connect, extensions::compression::deflate::DeflateConfig,
protocol::WebSocketConfig, Error, Message, Result,
};
const AGENT: &str = "Tungstenite";
@ -20,7 +23,11 @@ fn update_reports() -> Result<()> {
fn run_test(case: u32) -> Result<()> {
info!("Running test case {case}");
let case_url = format!("ws://localhost:9001/runCase?case={case}&agent={AGENT}");
let (mut socket, _) = connect(case_url)?;
let mut config = WebSocketConfig::default();
config.extensions.permessage_deflate = Some(DeflateConfig::default());
let (mut socket, _) = connect_with_config(case_url, Some(config), 3)?;
loop {
match socket.read()? {
msg @ Message::Text(_) | msg @ Message::Binary(_) => {

View File

@ -4,7 +4,10 @@ use std::{
};
use log::*;
use tungstenite::{accept, handshake::HandshakeRole, Error, HandshakeError, Message, Result};
use tungstenite::{
accept_with_config, extensions::compression::deflate::DeflateConfig, handshake::HandshakeRole,
protocol::WebSocketConfig, Error, HandshakeError, Message, Result,
};
fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error {
match err {
@ -14,7 +17,10 @@ fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error {
}
fn handle_client(stream: TcpStream) -> Result<()> {
let mut socket = accept(stream).map_err(must_not_block)?;
let mut config = WebSocketConfig::default();
config.extensions.permessage_deflate = Some(DeflateConfig::default());
let mut socket = accept_with_config(stream, Some(config)).map_err(must_not_block)?;
info!("Running test");
loop {
match socket.read()? {

View File

@ -32,5 +32,5 @@ docker run -d --rm \
wstest -m fuzzingserver -s 'autobahn/fuzzingserver.json'
sleep 3
cargo run --release --example autobahn-client
cargo run --release --example autobahn-client --features=deflate
test_diff

View File

@ -22,7 +22,7 @@ function test_diff() {
fi
}
cargo run --release --example autobahn-server & WSSERVER_PID=$!
cargo run --release --example autobahn-server --features=deflate & WSSERVER_PID=$!
sleep 3
docker run --rm \

View File

@ -2,7 +2,10 @@
use std::{io, result, str, string};
use crate::protocol::{frame::coding::Data, Message};
use crate::{
extensions::{compression::CompressionError, ExtensionsError},
protocol::{frame::coding::Data, Message},
};
#[cfg(feature = "handshake")]
use http::{header::HeaderName, Response};
use thiserror::Error;
@ -194,6 +197,9 @@ pub enum ProtocolError {
/// The `Sec-WebSocket-Protocol` header was invalid
#[error("SubProtocol error: {0}")]
SecWebSocketSubProtocolError(SubProtocolError),
/// The `Sec-WebSocket-Extensions` header is invalid.
#[error("Invalid \"Sec-WebSocket-Extensions\" header: {0}")]
InvalidExtensionsHeader(#[from] ExtensionsError),
/// Garbage data encountered after client request.
#[error("Junk after client request")]
JunkAfterRequest,
@ -229,6 +235,9 @@ pub enum ProtocolError {
/// Control frames must not be fragmented.
#[error("Fragmented control frame")]
FragmentedControlFrame,
/// Control frames must not be compressed.
#[error("Compressed control frame")]
CompressedControlFrame,
/// Control frames must have a payload of 125 bytes or less.
#[error("Control frame too big (payload must be 125 bytes or less)")]
ControlFrameTooBig,
@ -241,6 +250,9 @@ pub enum ProtocolError {
/// Received a continue frame despite there being nothing to continue.
#[error("Continue frame but nothing to continue")]
UnexpectedContinueFrame,
/// Received a compressed continue frame.
#[error("Continue frame must not have compress bit set")]
CompressedContinueFrame,
/// Received data while waiting for more fragments.
#[error("While waiting for more fragments received: {0}")]
ExpectedFragment(Data),
@ -253,6 +265,9 @@ pub enum ProtocolError {
/// The payload for the closing frame is invalid.
#[error("Invalid close sequence")]
InvalidCloseSequence,
/// Compression or decompression failure.
#[error("Compression/decompression failed: {0}")]
CompressionFailure(#[from] CompressionError),
}
/// Indicates the specific type/cause of URL error.
@ -331,6 +346,6 @@ mod test {
#[test]
fn protocol_error_size() {
let size = std::mem::size_of::<crate::error::ProtocolError>();
assert!(size <= 16, "ProtocolError is large: {size}");
assert!(size <= 24, "ProtocolError is large: {size}");
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,713 @@
//! Implements "permessage-deflate" PMCE defined in [RFC 7692 Section 7]
//!
//! [RFC 7692 Section 7]: https://tools.ietf.org/html/rfc7692#section-7
use bytes::Bytes;
use flate2::{Compress, Decompress, FlushCompress, FlushDecompress, Status};
use thiserror::Error;
use crate::{extensions::compression::DecompressionError, protocol::Role};
mod config;
#[cfg_attr(not(feature = "handshake"), allow(unused_imports))]
pub(crate) use config::ParameterError as DeflateParameterError;
pub use config::{
DeflateConfig, NegotiationError as DeflateNegotiationError, PermessageDeflateConfig,
PER_MESSAGE_DEFLATE as EXTENSION_NAME,
};
#[derive(Debug)]
/// Manages per message compression using DEFLATE.
pub struct DeflateContext {
compress: DeflateCompress,
decompress: DeflateDecompress,
}
/// Errors from `permessage-deflate` extension.
#[derive(Copy, Clone, Debug, Error, PartialEq, Eq)]
pub enum DeflateError {
/// Compress failed
#[error("Failed to compress")]
Compress,
/// Decompress failed
#[error("Failed to decompress")]
Decompress,
}
#[derive(Debug)]
struct DeflateCompress {
own_context_takeover: bool,
/// The actual compressor to run payloads through.
///
/// Use the low-level [`Compress`] API instead of the higher-level
/// [`flate2::zlib::write::ZlibEncoder`] so we can compress directly into
/// the output buffer instead of the intermediate one that that type holds.
compressor: Compress,
}
#[derive(Debug)]
struct DeflateDecompress {
/// The actual decompressor to run payloads through.
///
/// Use the low-level [`Decompress`] API instead of the higher-level
/// [`flate2::zlib::write::ZlibDecoder`] so we can decompress directly into
/// the output buffer instead of the intermediate one that that type holds.
/// This also lets us avoid some decompression errors that the higher-level
/// version exhibited with certain highly-compressed payloads.
decompressor: Decompress,
peer_context_takeover: bool,
}
impl DeflateContext {
pub(crate) fn new(role: Role, config: DeflateConfig) -> Self {
let DeflateConfig {
server_no_context_takeover,
client_no_context_takeover,
compression,
..
} = config;
// Per RFC 7692 Section 7:
//
// These parameters enable two methods (no_context_takeover and
// max_window_bits) of constraining memory usage that may be
// applied independently to either direction of WebSocket traffic.
// The extension parameters with the "client_" prefix are used by
// the client to configure its compressor and by the server to
// configure its decompressor. The extension parameters with the
// "server_" prefix are used by the server to configure its
// compressor and by the client to configure its decompressor. All
// four parameters are defined for both a client's extension
// negotiation offer and a server's extension negotiation response.
//
// Here `role` is for our own end of the connection, as opposed to the
// peer end.
let (own_no_context_takeover, peer_no_context_takeover) = match role {
Role::Client => (client_no_context_takeover, server_no_context_takeover),
Role::Server => (server_no_context_takeover, client_no_context_takeover),
};
// Both ends of the connection act as both compressor and decompressor.
// We compress with the window size for our role and decompress with the
// size for the opposite role.
let (compressor_window_bits, decompressor_window_bits) = match role {
Role::Client => (config.client_max_window_bits(), config.server_max_window_bits()),
Role::Server => (config.server_max_window_bits(), config.client_max_window_bits()),
};
DeflateContext {
compress: DeflateCompress {
own_context_takeover: !own_no_context_takeover,
compressor: Compress::new_with_window_bits(
compression,
false,
compressor_window_bits.get(),
),
},
decompress: DeflateDecompress {
peer_context_takeover: !peer_no_context_takeover,
decompressor: Decompress::new_with_window_bits(
false,
decompressor_window_bits.get(),
),
},
}
}
/// Compress the payload of an outgoing message.
pub(crate) fn compress(&mut self, data: &[u8]) -> Result<Bytes, DeflateError> {
self.compress.compress(data).map_err(|e| {
log::debug!("compression failed: {e}");
DeflateError::Compress
})
}
/// Decompress the payload in a received frame.
///
/// The `is_final` argument should only be set when calling with the contents of the last frame in a message.
pub(crate) fn decompress(
&mut self,
data: &[u8],
is_final: bool,
size_limit: usize,
) -> Result<Bytes, DecompressionError<DeflateError>> {
self.decompress.decompress(data, is_final, size_limit).map_err(|e| {
e.map(|e: std::io::Error| {
log::debug!("decompression failed: {e}");
DeflateError::Decompress
})
})
}
}
const ELIDED_TRAILER_BLOCK_CONTENTS: &[u8] = &[0x00, 0x00, 0xff, 0xff];
impl DeflateCompress {
/// Compress the contents of an entire message.
///
/// This is asymmetric with [`DeflateDecompress::decompress`] in that it
/// operates on the contents of an entire message, not the comprising frames.
fn compress(&mut self, mut data: &[u8]) -> Result<Bytes, std::io::Error> {
log::trace!("compressing message payload with {} bytes", data.len());
if data.is_empty() {
// Fast path for an empty payload: it gets DEFLATE compressed to a
// zero-length uncompressed block, which conveniently is
// concat([0x00], ELIDED_TRAILER_BLOCK_CONTENTS). Then, per the RFC,
// we elide the trailing 4 bytes to get a single 0x00 byte as the
// compressed payload.
return Ok(Bytes::from_static(&[0x00]));
}
let mut output = Vec::new();
// The amount of space that should be available in `output` before
// attempting to compress data into it.
const REQUIRED_OUTPUT_SPACE: usize = 4096;
// Per RFC 7692 Section 7.2.1:
//
// An endpoint uses the following algorithm to compress a message.
//
// 1. Compress all the octets of the payload of the message using
// DEFLATE.
//
{
let mut total_read = self.compressor.total_in();
loop {
// Make sure there's space for compress_vec to write to.
output.reserve(REQUIRED_OUTPUT_SPACE);
let r = self.compressor.compress_vec(data, &mut output, FlushCompress::None)?;
let read_before = std::mem::replace(&mut total_read, self.compressor.total_in());
let read = (total_read - read_before) as usize;
data = &data[read..];
log::trace!(
"compressed {read} bytes, {} remaining; partial output is {} bytes",
data.len(),
output.len()
);
match r {
Status::Ok => continue,
Status::BufError if read == 0 => {
// We made no progress, so this BufError means that
// we're out of input.
break;
}
Status::BufError => {
// We made some progress, so we can continue after
// making more output space.
continue;
}
Status::StreamEnd => break,
}
}
}
log::trace!("flushing compressed data");
// 2. If the resulting data does not end with an empty DEFLATE
// block with no compression (the "BTYPE" bits are set to 00),
// append an empty DEFLATE block with no compression to the tail
// end.
// Ideally, at this point, we'd be able to just call compress_vec once
// with an empty slice, FlushCompress::Sync, and a vector with more than
// enough output space, and then we'd get an empty block and be done.
// After all, compress_vec is documented to output "as much output as
// possible". Unfortunately, compress_vec does not actually do that for
// all backends. See:
// - https://github.com/rust-lang/flate2-rs/blob/1.1.2/src/ffi/rust.rs#L169
// - https://github.com/Frommi/miniz_oxide/blob/0.8.8/miniz_oxide/src/deflate/stream.rs#L82
// - https://github.com/Frommi/miniz_oxide/issues/105
//
// This causes compress_vec to return Ok as soon as the compressor
// writes *any* output when called with an empty slice.
//
// So, instead, we need to keep calling compress_vec with an empty slice
// until we stop making progress.
//
// Once we have done that properly, we should always have an empty block
// at the end of the output, and then we can truncate the output to
// remove the empty block, per the RFC.
{
let mut total_out = self.compressor.total_out();
loop {
output.reserve(REQUIRED_OUTPUT_SPACE);
let output_len_before = output.len();
let output_available_before = output.capacity() - output_len_before;
let _ = self.compressor.compress_vec(&[], &mut output, FlushCompress::Sync)?;
log::trace!(
"flushed {} bytes into an available {output_available_before} bytes",
output.len() - output_len_before,
);
let out_before = std::mem::replace(&mut total_out, self.compressor.total_out());
if total_out == out_before {
break;
}
}
}
// 3. Remove 4 octets (that are 0x00 0x00 0xff 0xff) from the tail
// end. After this step, the last octet of the compressed data
// contains (possibly part of) the DEFLATE header bits with the
// "BTYPE" bits set to 00.
debug_assert!(output.ends_with(ELIDED_TRAILER_BLOCK_CONTENTS), "output is {output:02x?}");
output.truncate(output.len() - ELIDED_TRAILER_BLOCK_CONTENTS.len());
if !self.own_context_takeover {
// Reset if the next frame isn't supposed to be starting with the
// same compression window.
self.compressor.reset();
}
log::trace!("finished compression into {} bytes", output.len());
Ok(Bytes::from(output))
}
}
impl DeflateDecompress {
/// Decompress the contents of a single frame.
///
/// The `is_final` argument must be `true` if and only if the frame is the
/// last one in a message. The `size_limit` argument is the maximum number
/// of bytes that can be decompressed. If the input `data` decompresses to
/// more than `size_limit` bytes, [`DecompressionError::SizeLimitReached`]
/// will be returned.
fn decompress(
&mut self,
data: &[u8],
is_final: bool,
size_limit: usize,
) -> Result<Bytes, DecompressionError<std::io::Error>> {
// From RFC 7692 Section 7.2.2:
//
// An endpoint uses the following algorithm to decompress a message.
//
// 1. Append 4 octets of 0x00 0x00 0xff 0xff to the tail end of the
// payload of the message.
//
// 2. Decompress the resulting data using DEFLATE.
let mut output = Vec::new();
log::trace!(
"decompressing {} bytes in {} frame",
data.len(),
if is_final { "final" } else { "intermediate" }
);
let mut total_read = self.decompressor.total_in();
let mut decompress_from = |mut data: &[u8]| {
loop {
// Make sure there's some space to decompress into,
// optimistically assuming a 50% compression ratio of the input.
// This might put us slightly beyond the requested size limit
// but it also might not all be used.
output.reserve(2 * data.len());
let r =
self.decompressor.decompress_vec(data, &mut output, FlushDecompress::None)?;
if output.len() > size_limit {
return Err(DecompressionError::SizeLimitReached);
}
let read_before = std::mem::replace(&mut total_read, self.decompressor.total_in());
let read = (total_read - read_before) as usize;
data = &data[read..];
match r {
Status::Ok => continue,
Status::BufError => {
// We've either run out of input data or output space.
// Since we reserve space ahead of time, this must mean
// we're out of input.
break;
}
Status::StreamEnd => {
// Finished a block with BFINAL set. This is legal; from
// RFC 7692 Section 7.2.3.4:
//
// On platforms on which the flush method using an
// empty DEFLATE block with no compression is not
// available, implementors can choose to flush data
// using DEFLATE blocks with "BFINAL" set to 1.
//
// On the decompression end we reset the compressor in
// response. This relies on the assumption that the
// client produced the block with BFINAL set by
// informing their compressor that the stream was
// ending, and so any blocks afterwards won't reference
// any context from this block or earlier. It's
// obviously not a perfect assumption, but it matches
// the behavior of other widely-deployed
// permessage-deflate implementations.
self.decompressor.reset(false);
total_read = 0;
}
}
}
Ok(())
};
decompress_from(data)?;
if is_final {
// Decompress the final block that is part of the logical input to
// DEFLATE but is elided from the message payloads. This implicitly
// flushes out any pending bytes that were part of the previous
// block and doesn't leave any others since the trailer is explicitly
// an empty block.
decompress_from(&ELIDED_TRAILER_BLOCK_CONTENTS)?;
if !self.peer_context_takeover {
self.decompressor.reset(false);
}
}
Ok(Bytes::from(output))
}
}
impl From<DeflateContext> for super::PerMessageCompressionContext {
fn from(value: DeflateContext) -> Self {
Self::Deflate(value)
}
}
#[cfg(test)]
pub(crate) mod test {
use rand::{distr::Distribution as _, RngCore, SeedableRng as _};
use super::*;
#[test]
fn interop() {
let mut data = vec![0; 2048];
rand::rngs::SmallRng::seed_from_u64(1023).fill_bytes(&mut data);
let configs = [
DeflateConfig::default(),
DeflateConfig::default().set_no_context_takeover(Role::Client, true),
DeflateConfig::default()
.set_no_context_takeover(Role::Client, true)
.set_max_window_bits(Role::Client, 10)
.unwrap(),
DeflateConfig::default().set_max_window_bits(Role::Client, 10).unwrap(),
];
let frame_sizes = [16, 64, data.len()];
for config in configs {
for frame_size in frame_sizes {
let mut client = DeflateContext::new(Role::Client, config);
let mut server = DeflateContext::new(Role::Server, config);
let mut send_and_receive = |data| {
let compressed = client.compress(data).unwrap();
let mut decompressed = Vec::<u8>::new();
let mut it = compressed.chunks(frame_size).peekable();
while let Some(frame) = it.next() {
decompressed.extend_from_slice(
&server.decompress(frame, it.peek().is_none(), usize::MAX).unwrap(),
);
}
decompressed
};
let decompressed = send_and_receive(&data);
assert_eq!(data, decompressed);
// Make sure we haven't broken compression or decompression for
// the *next* message.
let decompressed = send_and_receive(b"second message");
assert_eq!(decompressed, b"second message");
}
}
}
#[test]
fn large_message_compression() {
let mut data = vec![0; 1 << 19];
rand::rngs::SmallRng::seed_from_u64(1023).fill_bytes(&mut data);
let mut context = DeflateContext::new(Role::Client, DeflateConfig::default());
let compressed = context.compress(&data).unwrap();
assert_eq!(&context.decompress(&compressed, true, usize::MAX).unwrap(), &data);
}
#[test]
fn decompression_limits_applied() {
let data = vec![0; 1 << 18];
let mut context = DeflateContext::new(Role::Client, DeflateConfig::default());
let compressed = context.compress(&data).unwrap();
// A buffer of all zeros compresses very well.
assert!(compressed.len() < data.len() / 500);
assert_eq!(
context.decompress(&compressed, true, data.len() - 1),
Err(DecompressionError::SizeLimitReached)
);
}
#[test]
fn compressible_payload_prefixes() {
let _ = env_logger::try_init();
let data: Vec<u8> = rand::distr::Alphanumeric
.sample_iter(&mut rand::rngs::SmallRng::from_seed([59; 32]))
.take(1 << 16)
.collect();
let prefixes =
(5..).map(|i| 1 << i).take_while(|len| *len <= data.len()).map(|len| &data[..len]);
for prefix in prefixes {
let mut context = DeflateContext::new(Role::Client, DeflateConfig::default());
println!("compressing {} bytes of compressible data", prefix.len());
let compressed = context.compress(prefix).unwrap();
assert_eq!(context.decompress(&compressed, true, usize::MAX).unwrap(), prefix);
}
}
/// Utilities for testing decomrpession of highly-compressed payloads.
pub(crate) mod very_compressed {
use bytes::Bytes;
// Compressed payload that decompresses to 50KB of zeroes. This was
// specifically chosen so that its compressed form aligns with a byte
// boundary, which lets us repeat it an arbitrary number of times to
// form the payload of a single message.
pub(crate) const FRAME_PAYLOAD: &[u8; 66] = &[
0xec, 0xc1, 0x31, 0x01, 0x00, 0x00, 0x00, 0xc2, 0xa0, 0xf5, 0x4f, 0x6d, 0x0b, 0x2f,
0xa0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xe0, 0x6f,
];
pub(crate) const DECOMPRESSED_LEN: usize = 50 * 1024;
pub(crate) fn make_frames(frame_count: usize) -> impl Iterator<Item = (Bytes, bool)> {
std::iter::repeat_n(FRAME_PAYLOAD, frame_count).enumerate().map(move |(i, bytes)| {
let is_final = i == frame_count - 1;
let bytes = if is_final {
bytes.iter().copied().chain(std::iter::once(0x00)).collect()
} else {
Bytes::from_static(bytes)
};
(bytes, is_final)
})
}
}
#[test]
fn large_message_decompression() {
let _ = env_logger::try_init();
for frame_count in 1..=10 {
let mut context = DeflateContext::new(Role::Client, DeflateConfig::default());
let decompressed: Bytes = very_compressed::make_frames(frame_count)
.enumerate()
.flat_map(|(i, (frame, is_final))| {
context
.decompress
.decompress(&frame, is_final, usize::MAX)
.unwrap_or_else(|e| panic!("deflating frame {i}/{frame_count} failed: {e}"))
})
.collect();
assert!(decompressed.iter().all(|b| *b == 0));
assert_eq!(decompressed.len(), frame_count * very_compressed::DECOMPRESSED_LEN);
}
}
#[test]
fn decompress_multiple_messages_that_each_set_bfinal() {
let _ = env_logger::try_init();
let mut rng = rand::rngs::SmallRng::from_seed([12; 32]);
let uncompressed_payloads = std::iter::repeat_with(|| {
let mut data: Vec<u8> = vec![0; 1 << 12];
rng.fill_bytes(&mut data);
data
});
let mut context = DeflateContext::new(Role::Server, DeflateConfig::default());
for (i, payload) in uncompressed_payloads.enumerate().take(5) {
let mut compressed = context.compress(&payload).unwrap().try_into_mut().unwrap();
// The final block in the stream is a 5-byte uncompressed block, but
// with the trailing 4 bytes of the body chopped off (per the RFC).
// We don't know where in the last *byte* the final block begins
// (since DEFLATE is a bit-oriented protocol), so to make sure the
// payload ends with a block with BFINAL set we need to append
// another block. First we reattach the chopped-off bytes from the
// last block. Then we push *another* 5-byte uncompressed block with
// BFINAL set. Lastly we chop off the trailing 4 bytes per the spec.
compressed.extend_from_slice(ELIDED_TRAILER_BLOCK_CONTENTS);
compressed.extend_from_slice(&[0x01, 0x00, 0x00, 0xff, 0xff]);
compressed.truncate(compressed.len() - ELIDED_TRAILER_BLOCK_CONTENTS.len());
println!("decompressing block {i}");
let decompressed = context.decompress(&compressed, true, usize::MAX).unwrap();
assert_eq!(decompressed.len(), payload.len());
assert_eq!(decompressed, payload);
}
}
mod rfc_7692_section_7_2_3_examples {
use super::*;
#[test]
fn one_block() {
// From RFC 7692 Section 7.2.3.1:
//
// Suppose that an endpoint sends a text message "Hello". If the
// endpoint uses one compressed DEFLATE block (compressed with fixed
// Huffman code and the "BFINAL" bit not set) to compress the message,
// the endpoint obtains the compressed data to use for the message
// payload as follows.
//
// The endpoint compresses "Hello" into one compressed DEFLATE block and
// flushes the resulting data into a byte array using an empty DEFLATE
// block with no compression:
//
// 0xf2 0x48 0xcd 0xc9 0xc9 0x07 0x00 0x00 0x00 0xff 0xff
//
// By stripping 0x00 0x00 0xff 0xff from the tail end, the endpoint gets
// the data to use for the message payload:
//
const EXPECTED_COMPRESSED_PAYLOAD: &[u8] = &[0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00];
let mut context = DeflateContext::new(Role::Server, DeflateConfig::default());
let compressed = context.compress(b"Hello").unwrap();
assert_eq!(&compressed[..], EXPECTED_COMPRESSED_PAYLOAD);
//
// ...
//
// Suppose that the endpoint sends the compressed message with
// fragmentation. The endpoint splits the compressed data into
// fragments and builds frames for each fragment. For example, if the
// fragments are 3 and 4 octets,
//
const FRAGMENTED_FRAMES: &[&[u8]] = &[
// the first frame is:
&[0x41, 0x03, 0xf2, 0x48, 0xcd],
// and the second frame is:
&[0x80, 0x04, 0xc9, 0xc9, 0x07, 0x00],
];
//
// Note that the RSV1 bit is set only on the first frame.
let frame_payloads =
FRAGMENTED_FRAMES.iter().map(|frame| &frame[2..]).collect::<Vec<_>>();
let decompressed = frame_payloads
.iter()
.enumerate()
.map(|(index, payload)| {
context.decompress(payload, index == frame_payloads.len() - 1, usize::MAX)
})
.collect::<Result<Vec<_>, _>>()
.unwrap()
.concat();
assert_eq!(decompressed, b"Hello");
}
#[test]
fn sharing_sliding_window() {
const ROLE: Role = Role::Client;
// From RFC 7692 Section 7.2.3.2:
//
// Suppose that a client has sent a message "Hello" as a compressed
// message and will send the same message "Hello" again as a compressed
// message.
//
const FIRST_PAYLOAD: &[u8] = &[0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00];
//
// The above is the payload of the first message that the client has
// sent. If the "agreed parameters" contain the
// "client_no_context_takeover" extension parameter, the client
// compresses the payload of the next message into the same bytes (if
// the client uses the same "BTYPE" value and "BFINAL" value). So, the
// payload of the second message will be:
//
const SECOND_PAYLOAD: &[u8] = &[0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00];
let mut context = DeflateContext::new(
ROLE,
DeflateConfig::default().set_no_context_takeover(ROLE, true),
);
assert_eq!(&context.compress(b"Hello").unwrap()[..], FIRST_PAYLOAD);
assert_eq!(&context.compress(b"Hello").unwrap()[..], SECOND_PAYLOAD);
//
// If the "agreed parameters" did not contain the
// "client_no_context_takeover" extension parameter, the client can
// compress the payload of the next message into fewer bytes by
// referencing the history in the LZ77 sliding window. So, the payload
// of the second message will be:
//
const NEW_SECOND_PAYLOAD: &[u8] = &[0xf2, 0x00, 0x11, 0x00, 0x00];
let mut context = DeflateContext::new(ROLE, DeflateConfig::default());
assert_eq!(&context.compress(b"Hello").unwrap()[..], FIRST_PAYLOAD);
assert_eq!(&context.compress(b"Hello").unwrap()[..], NEW_SECOND_PAYLOAD);
}
#[test]
fn deflate_block_with_bfinal_set() {
// From RFC 7692 Section 7.2.3.4:
//
// On platforms on which the flush method using an empty DEFLATE
// block with no compression is not available, implementors can
// choose to flush data using DEFLATE blocks with "BFINAL" set to
// 1.
const PAYLOAD: &[u8] = &[0xf3, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00, 0x00];
// This is the payload of a message containing "Hello" compressed
// using a DEFLATE block with "BFINAL" set to 1. The first 7
// octets constitute a DEFLATE block with "BFINAL" set to 1 and
// "BTYPE" set to 01 containing "Hello". The last 1 octet (0x00)
// contains the header bits with "BFINAL" set to 0 and "BTYPE" set
// to 00, and 5 padding bits of 0. This octet is necessary to
// allow the payload to be decompressed in the same manner as
// messages flushed using DEFLATE blocks with "BFINAL" unset.
let mut context = DeflateContext::new(Role::Client, DeflateConfig::default());
assert_eq!(
context.decompress(PAYLOAD, true, usize::MAX),
Ok(Bytes::from_static(b"Hello"))
);
}
#[test]
fn two_deflate_blocks() {
// From RFC 7692 Section 7.2.3.5:
//
// Two or more DEFLATE blocks may be used in one message.
const TWO_BLOCKS: &[u8] =
&[0xf2, 0x48, 0x05, 0x00, 0x00, 0x00, 0xff, 0xff, 0xca, 0xc9, 0xc9, 0x07, 0x00];
let mut context = DeflateContext::new(Role::Client, DeflateConfig::new());
assert_eq!(&context.decompress(TWO_BLOCKS, true, usize::MAX).unwrap()[..], b"Hello");
}
}
}

View File

@ -0,0 +1,89 @@
//! [Per-Message Compression Extensions][rfc7692]
//!
//! [rfc7692]: https://tools.ietf.org/html/rfc7692
use bytes::Bytes;
use thiserror::Error;
#[cfg(feature = "deflate")]
pub mod deflate;
/// Active context for performing per-message compression.
#[derive(Debug)]
#[cfg_attr(not(feature = "deflate"), allow(missing_copy_implementations))] // This is only trivially copyable if compression is disabled.
pub enum PerMessageCompressionContext {
/// Context for compressing/decompressing with `permessage-deflate`.
#[cfg(feature = "deflate")]
Deflate(deflate::DeflateContext),
}
/// Error encountered while compressing or decompressing.
#[derive(Copy, Clone, Debug, Error, PartialEq, Eq)]
pub enum CompressionError {
/// Error encountered while deflating or inflating
#[error("Deflate error: {0}")]
#[cfg(feature = "deflate")]
Deflate(deflate::DeflateError),
}
#[derive(Debug, Error)]
#[cfg_attr(test, derive(PartialEq))]
pub(crate) enum DecompressionError<E = CompressionError> {
/// The decompressed frame is larger than the configured limit.
#[error("decompressed data is too large")]
SizeLimitReached,
/// An error was encountered while decompressing.
#[error("{0}")]
Decompression(E),
}
impl PerMessageCompressionContext {
#[inline]
pub(crate) fn compressor<'s>(
&'s mut self,
) -> impl FnMut(&Bytes) -> Result<Bytes, CompressionError> + 's {
move |payload: &Bytes| match self {
#[cfg(feature = "deflate")]
Self::Deflate(deflate_config) => {
deflate_config.compress(payload).map_err(CompressionError::Deflate)
}
#[cfg(not(feature = "deflate"))]
_ => {
let _ = payload;
unreachable!("*PerMessageCompressionContext is uninhabited")
}
}
}
#[inline]
pub(crate) fn decompressor<'s>(
&'s mut self,
) -> impl FnMut(&Bytes, bool, usize) -> Result<Bytes, DecompressionError> + 's {
move |payload, is_final, size_limit| match self {
#[cfg(feature = "deflate")]
Self::Deflate(deflate_config) => deflate_config
.decompress(payload, is_final, size_limit)
.map_err(|e| e.map(CompressionError::Deflate)),
#[cfg(not(feature = "deflate"))]
_ => {
let _ = (payload, is_final, size_limit);
unreachable!("*PerMessageCompressionContext is uninhabited")
}
}
}
}
impl<E> DecompressionError<E> {
pub(crate) fn map<T>(self, f: impl FnOnce(E) -> T) -> DecompressionError<T> {
match self {
Self::SizeLimitReached => DecompressionError::SizeLimitReached,
Self::Decompression(e) => DecompressionError::Decompression(f(e)),
}
}
}
impl<E: Into<std::io::Error>> From<E> for DecompressionError<std::io::Error> {
fn from(value: E) -> Self {
Self::Decompression(value.into())
}
}

View File

@ -0,0 +1,57 @@
//! HTTP Request and response header handling.
use headers::Error;
use http::HeaderValue;
mod sec_websocket_extensions;
#[allow(unused)]
pub(crate) use sec_websocket_extensions::{
SecWebsocketExtensions, WebsocketExtensionParam, WebsocketProtocolExtension,
};
/// Reads a comma-delimited raw header into a Vec.
fn from_comma_delimited<'i, I, T, E>(values: &mut I) -> Result<E, Error>
where
I: Iterator<Item = &'i HeaderValue>,
T: ::std::str::FromStr,
E: ::std::iter::FromIterator<T>,
{
from_delimited(&mut values.flat_map(|header_value| header_value.to_str()), ',')
}
/// Reads a single-character-delimited raw header into a Vec.
fn from_delimited<'i, I, T, E>(values: &mut I, delimiter: char) -> Result<E, Error>
where
I: Iterator<Item = &'i str>,
T: ::std::str::FromStr,
E: ::std::iter::FromIterator<T>,
{
values
.flat_map(|string| {
let mut in_quotes = false;
string
.split(move |c| {
#[allow(clippy::collapsible_else_if)]
if in_quotes {
if c == '"' {
in_quotes = false;
}
false // dont split
} else {
if c == delimiter {
true // split
} else {
if c == '"' {
in_quotes = true;
}
false // dont split
}
}
})
.filter_map(|x| match x.trim() {
"" => None,
y => Some(y),
})
.map(|x| x.parse().map_err(|_| Error::invalid()))
})
.collect()
}

View File

@ -0,0 +1,376 @@
use std::{borrow::Cow, fmt::Debug, iter::FromIterator, str::FromStr};
use bytes::BytesMut;
use http::HeaderValue;
use super::{from_comma_delimited, from_delimited};
/// The `Sec-Websocket-Extensions` header.
///
/// This header is used in the Websocket handshake, sent by the client to the
/// server and then from the server to the client. It is a proposed and
/// agreed-upon list of websocket protocol extensions to use.
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
pub struct SecWebsocketExtensions(Vec<WebsocketProtocolExtension>);
/// An extension listed in a [`SecWebsocketExtensions`] header.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct WebsocketProtocolExtension {
name: Cow<'static, str>,
params: Vec<WebsocketExtensionParam>,
}
/// Named parameter for an extension in a `Sec-Websocket-Extensions` header.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct WebsocketExtensionParam {
name: Cow<'static, str>,
value: Option<String>,
}
impl SecWebsocketExtensions {
/// Constructs a new header with the provided extensions.
pub fn new(extensions: impl IntoIterator<Item = WebsocketProtocolExtension>) -> Self {
Self(extensions.into_iter().collect())
}
/// Returns an iterator over the extensions in this header.
pub fn iter(&self) -> <&Self as IntoIterator>::IntoIter {
self.into_iter()
}
/// Returns the number of extensions in this header.
#[inline]
pub fn len(&self) -> usize {
self.0.len()
}
/// Returns a [`HeaderValue`] with the encoded contents of this header.
pub fn header_value(&self) -> HeaderValue {
let extensions = CommaDelimited(self.0.as_slice());
let mut buffer = BytesMut::with_capacity(extensions.encoded_len());
extensions.write_with(&mut |slice| buffer.extend_from_slice(slice));
HeaderValue::from_maybe_shared(buffer).expect("valid construction")
}
}
impl WebsocketProtocolExtension {
/// Constructs a new extension directive with the given name and parameters.
pub fn new(
name: impl Into<Cow<'static, str>>,
params: impl IntoIterator<Item = WebsocketExtensionParam>,
) -> Self {
Self { name: name.into(), params: params.into_iter().collect() }
}
/// The name of this extension directive.
pub fn name(&self) -> &str {
&self.name
}
/// Returns an iterator over the parameters for this extension directive.
pub fn params(&self) -> impl Iterator<Item = &WebsocketExtensionParam> {
self.params.iter()
}
}
impl WebsocketExtensionParam {
/// Constructs a new parameter with the given name and optional value.
#[inline]
pub fn new(name: impl Into<Cow<'static, str>>, value: Option<String>) -> Self {
Self { name: name.into(), value }
}
/// The name of the parameter.
#[inline]
pub fn name(&self) -> &str {
&self.name
}
/// The parameter value, if there is one.
#[inline]
pub fn value(&self) -> Option<&str> {
self.value.as_deref()
}
}
impl headers::Header for SecWebsocketExtensions {
fn name() -> &'static ::http::header::HeaderName {
&::http::header::SEC_WEBSOCKET_EXTENSIONS
}
fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error>
where
I: Iterator<Item = &'i HeaderValue>,
{
from_comma_delimited(values).map(SecWebsocketExtensions)
}
fn encode<E: Extend<headers::HeaderValue>>(&self, values: &mut E) {
values.extend(std::iter::once(self.header_value()))
}
}
impl From<WebsocketProtocolExtension> for SecWebsocketExtensions {
fn from(value: WebsocketProtocolExtension) -> Self {
Self(vec![value])
}
}
impl FromIterator<WebsocketProtocolExtension> for SecWebsocketExtensions {
fn from_iter<T: IntoIterator<Item = WebsocketProtocolExtension>>(iter: T) -> Self {
Self(iter.into_iter().collect())
}
}
impl IntoIterator for SecWebsocketExtensions {
type Item = WebsocketProtocolExtension;
type IntoIter = std::vec::IntoIter<Self::Item>;
fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}
impl<'a> IntoIterator for &'a SecWebsocketExtensions {
type Item = &'a WebsocketProtocolExtension;
type IntoIter = std::slice::Iter<'a, WebsocketProtocolExtension>;
fn into_iter(self) -> Self::IntoIter {
self.0.iter()
}
}
impl FromStr for WebsocketProtocolExtension {
type Err = headers::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let (name, tail) = s.split_once(';').map(|(n, t)| (n, Some(t))).unwrap_or((s, None));
let params = from_delimited(&mut tail.into_iter(), ';')?;
Ok(Self { name: name.trim().to_owned().into(), params })
}
}
impl std::fmt::Display for WebsocketProtocolExtension {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let Self { name, params } = self;
write!(f, "{name}")?;
for param in params {
write!(f, "; {param}")?;
}
Ok(())
}
}
impl FromStr for WebsocketExtensionParam {
type Err = ();
fn from_str(s: &str) -> Result<Self, Self::Err> {
let (name, value) = s.split_once('=').map(|(n, t)| (n, Some(t))).unwrap_or((s, None));
let value = value.map(|value| value.trim().to_owned());
Ok(Self { name: name.trim().to_owned().into(), value })
}
}
impl std::fmt::Display for WebsocketExtensionParam {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let Self { name, value } = self;
write!(f, "{name}")?;
if let Some(value) = value {
write!(f, "={value}")?;
}
Ok(())
}
}
trait WriteTo {
fn encoded_len(&self) -> usize {
let mut size = 0;
self.write_with(&mut |slice| size += slice.len());
size
}
fn write_with(&self, write: &mut (impl FnMut(&[u8]) + ?Sized));
}
impl WriteTo for WebsocketProtocolExtension {
fn encoded_len(&self) -> usize {
let Self { name, params } = self;
let params_len: usize = params.iter().map(|p| p.encoded_len() + 2).sum();
name.len() + params_len
}
fn write_with(&self, write: &mut (impl FnMut(&[u8]) + ?Sized)) {
let Self { name, params } = self;
write(name.as_bytes());
for param in params {
write(b"; ");
param.write_with(write);
}
}
}
impl WriteTo for WebsocketExtensionParam {
fn write_with(&self, write: &mut (impl FnMut(&[u8]) + ?Sized)) {
let Self { name, value } = self;
write(name.as_bytes());
if let Some(value) = value {
write(b"=");
write(value.as_bytes());
}
}
}
#[derive(Debug)]
struct CommaDelimited<T>(T);
impl<T> CommaDelimited<T> {
const SEPARATOR: &[u8] = b", ";
}
impl<T: WriteTo> WriteTo for CommaDelimited<&[T]> {
fn encoded_len(&self) -> usize {
let all_encoded_len: usize = self.0.iter().map(T::encoded_len).sum();
let all_separators_len = self.0.len().saturating_sub(1) * Self::SEPARATOR.len();
all_encoded_len + all_separators_len
}
fn write_with(&self, write: &mut (impl FnMut(&[u8]) + ?Sized)) {
let mut is_first = true;
for item in self.0 {
let was_first = std::mem::replace(&mut is_first, false);
if !was_first {
write(Self::SEPARATOR);
}
item.write_with(write);
}
}
}
impl<T: WriteTo, const N: usize> WriteTo for CommaDelimited<[T; N]> {
fn encoded_len(&self) -> usize {
CommaDelimited(self.0.as_slice()).encoded_len()
}
fn write_with(&self, write: &mut (impl FnMut(&[u8]) + ?Sized)) {
CommaDelimited(self.0.as_slice()).write_with(write);
}
}
#[cfg(test)]
mod tests {
use headers::{Header, HeaderMapExt as _};
use super::*;
fn test_decode<T: Header>(values: &[&str]) -> Option<T> {
let mut map = ::http::HeaderMap::new();
for val in values {
map.append(T::name(), val.parse().unwrap());
}
map.typed_get()
}
#[cfg(test)]
fn test_encode<T: Header>(header: T) -> ::http::HeaderMap {
let mut map = ::http::HeaderMap::new();
map.typed_insert(header);
map
}
#[test]
fn parse_separate_headers() {
// From https://tools.ietf.org/html/rfc6455#section-9.1
let extensions =
test_decode::<SecWebsocketExtensions>(&["foo", "bar; baz=2"]).expect("valid");
assert_eq!(
extensions,
SecWebsocketExtensions(vec![
WebsocketProtocolExtension { name: "foo".into(), params: vec![] },
WebsocketProtocolExtension {
name: "bar".into(),
params: vec![WebsocketExtensionParam {
name: "baz".into(),
value: Some("2".to_owned())
}],
}
])
);
}
#[test]
fn round_trip_complex() {
let extensions = test_decode::<SecWebsocketExtensions>(&[
"deflate-stream",
"mux; max-channels=4; flow-control, deflate-stream",
"private-extension",
])
.expect("valid");
let headers = test_encode(extensions);
assert_eq!(
headers["sec-websocket-extensions"],
"deflate-stream, mux; max-channels=4; flow-control, deflate-stream, private-extension"
);
}
#[test]
fn write_to_exact_encoded_len() {
trait WriteToDyn: Debug {
fn encoded_len(&self) -> usize;
fn write_with(&self, write: &mut dyn FnMut(&[u8]));
}
impl<W: WriteTo + Debug> WriteToDyn for W {
fn encoded_len(&self) -> usize {
WriteTo::encoded_len(self)
}
fn write_with(&self, write: &mut dyn FnMut(&[u8])) {
WriteTo::write_with(self, write);
}
}
// This isn't a required property for correctness but if the length
// precomputation is wrong we'll over- or under-allocate during
// conversion.
let cases: &[Box<dyn WriteToDyn>] = &[
Box::new(CommaDelimited([
WebsocketProtocolExtension::from_str("extension-name").unwrap(),
WebsocketProtocolExtension::from_str("with-params; a=5; b=8").unwrap(),
])),
Box::new(CommaDelimited::<[WebsocketProtocolExtension; 0]>([])),
Box::new(CommaDelimited([
WebsocketProtocolExtension::from_str("duplicate-name").unwrap(),
WebsocketProtocolExtension::from_str("duplicate-name").unwrap(),
WebsocketProtocolExtension::from_str("duplicate-name").unwrap(),
])),
Box::new(WebsocketProtocolExtension::new(
"name",
["foo=123".parse().unwrap(), "bar".parse().unwrap(), "baz=four".parse().unwrap()],
)),
];
for case in cases {
let mut value = Vec::new();
let expected_len = case.encoded_len();
case.write_with(&mut |slice| value.extend_from_slice(slice));
assert_eq!(value.len(), expected_len, "for {case:?}");
}
}
}

422
src/extensions/mod.rs Normal file
View File

@ -0,0 +1,422 @@
//! WebSocket extensions.
// Only `permessage-deflate` is supported at the moment.
use bytes::Bytes;
use thiserror::Error;
use crate::extensions::compression::{
CompressionError, DecompressionError, PerMessageCompressionContext,
};
#[cfg(feature = "handshake")]
use crate::extensions::headers::{SecWebsocketExtensions, WebsocketProtocolExtension};
use crate::protocol::Role;
pub mod compression;
#[cfg(feature = "headers")]
pub(crate) mod headers;
/// Container for configured extensions for a connection.
#[derive(Debug, Default)]
#[allow(missing_copy_implementations)]
pub struct Extensions {
/// The Per-Message Compression extension configured for the connection, if
/// any.
per_message_compression: Option<PerMessageCompressionContext>,
}
/// Configuration for extensions for a connection.
#[derive(Copy, Clone, Default, Debug)]
#[non_exhaustive]
pub struct ExtensionsConfig {
/// Configuration for the `permessage-deflate` PMCE as specified by [RFC 7692].
///
/// [RFC 7692]: https://tools.ietf.org/html/rfc7692
#[cfg(feature = "deflate")]
pub permessage_deflate: Option<compression::deflate::DeflateConfig>,
}
/// Error encountered while handling extensions.
#[derive(Debug, Error, PartialEq, Eq, Clone)]
pub enum ExtensionsError {
/// The header included an invalid extension.
#[error("Extension header had invalid extension: {0}")]
InvalidExtension(Box<str>),
/// The negotiation response included an extension more than once.
#[error("Extension negotiation response had conflicting extension: {0}")]
ExtensionConflict(Box<str>),
/// The header included an unparsable extension.
#[error("Extension negotiation response had malformed extension: {0}")]
MalformedExtension(&'static str),
}
#[cfg(feature = "handshake")]
impl ExtensionsConfig {
pub(crate) fn generate_offers(&self) -> impl Iterator<Item = WebsocketProtocolExtension> {
let Self {
#[cfg(feature = "deflate")]
permessage_deflate,
} = self;
#[allow(unused_mut, unused_assignments)]
let mut permessage_compression_offer = None;
#[cfg(feature = "deflate")]
{
permessage_compression_offer = permessage_deflate.as_ref().map(|p| p.as_offer().into());
}
permessage_compression_offer.into_iter()
}
/// Checks that the given extensions are compatible with the given config
/// (if any).
///
/// Receives a [`SecWebsocketExtensions`] header in a handshake response and
/// evaluates it against the given local configuration. Returns a new
/// `Extensions` that can be used for the connection for the completed
/// handshake.
pub(crate) fn verify_agreed_on(
&self,
agreed: SecWebsocketExtensions,
) -> Result<Extensions, ExtensionsError> {
#[cfg_attr(not(feature = "deflate"), allow(unused_mut))]
let mut per_message_compression = None;
for extension in agreed.iter() {
match extension.name() {
#[cfg(feature = "deflate")]
compression::deflate::EXTENSION_NAME => {
use compression::deflate::{
DeflateContext, DeflateParameterError, PermessageDeflateConfig,
EXTENSION_NAME,
};
// Already had PMCE configured
if per_message_compression.is_some() {
return Err(ExtensionsError::ExtensionConflict(EXTENSION_NAME.into()));
}
let deflate = self
.permessage_deflate
.ok_or_else(|| ExtensionsError::InvalidExtension(EXTENSION_NAME.into()))?;
let extension: PermessageDeflateConfig = PermessageDeflateConfig::parse_params(
extension.params(),
)
.map_err(|_: DeflateParameterError| {
ExtensionsError::MalformedExtension(EXTENSION_NAME)
})?;
let deflate_config = deflate.accept_response(extension).map_err(|e| {
ExtensionsError::InvalidExtension(format!("{EXTENSION_NAME}: {e}").into())
})?;
per_message_compression =
Some(DeflateContext::new(Role::Client, deflate_config).into());
}
name => return Err(ExtensionsError::InvalidExtension(name.into())),
}
}
Ok(Extensions { per_message_compression })
}
/// Checks whether the given extension headers are compatible with the given
/// config (if any).
///
/// Recieves a [`SecWebsocketExtensions`] header in a handshake request and
/// evaluates it against the given local configuration. Returns a
/// `SecWebsocketExtensions` header to be sent in the handshake response to
/// the client, and a `Extensions` value to be used for the connection, once
/// it is established.
pub(crate) fn accept_offers(
&self,
extensions: &SecWebsocketExtensions,
) -> Result<(Extensions, Option<SecWebsocketExtensions>), ExtensionsError> {
#[cfg_attr(not(feature = "deflate"), allow(unused_mut))]
let mut per_message_compression = None;
for extension in extensions.iter() {
// Only one extension is currently supported. If that changes,
// this will need to be updated to apply the extensions in the correct order.
match extension.name() {
#[cfg(feature = "deflate")]
compression::deflate::EXTENSION_NAME => {
use compression::deflate::{
DeflateContext, PermessageDeflateConfig, EXTENSION_NAME,
};
let deflate = match self.permessage_deflate {
Some(deflate) => deflate,
None => continue,
};
let extension = match PermessageDeflateConfig::parse_params(extension.params())
{
Ok(extension) => extension,
Err(e) => {
// Per RFC 7692 Section 7:
//
// A server MUST decline an extension negotiation
// offer for this extension if any of the following
// conditions are met:
//
// o The negotiation offer contains an extension
// parameter not defined for use in an offer.
//
// Declining instead of rejecting the request
// outright allows clients that conform to a
// (currently hypothetical) RFC that supersedes RFC
// 7692 to fall back to requesting to the behavior
// specified in the latter.
log::debug!("{EXTENSION_NAME} extension: {e}");
continue;
}
};
// Per RFC 7692 Section 5:
//
// A client may also offer multiple PMCE choices to the server
// by including multiple elements in the
// "Sec-WebSocket-Extensions" header, one for each PMCE
// offered. This set of elements MAY include multiple PMCEs
// with the same extension name to offer the possibility to
// use the same algorithm with different configuration
// parameters. The order of elements is important as it
// specifies the client's preference. An element preceding
// another element has higher preference. It is recommended
// that a server accepts PMCEs with higher preference if the
// server supports them.
//
// Follow the RFC recommendation by not overwriting a PMCE that
// is already configured.
if per_message_compression.is_some() {
continue;
}
if let Some((config, response)) = deflate.accept_offer(extension) {
per_message_compression = Some((
DeflateContext::new(Role::Server, config).into(),
response.into(),
));
}
}
// Ignore any unknown extensions in the offer.
_ => {}
}
}
let (per_message_compression, response) = match per_message_compression {
Some((a, b)) => (Some(a), Some(b)),
None => (None, None),
};
Ok((
Extensions { per_message_compression },
response.map(|response| SecWebsocketExtensions::new(std::iter::once(response))),
))
}
}
impl ExtensionsConfig {
/// Bypasses negotiation of extension parameters and enables those that have
/// been configured.
///
/// Returns an [`Extensions`] that has all the extensions enabled that this
/// [`ExtensionsConfig`] was configured with.
pub(crate) fn into_unnegotiated_context(self, role: Role) -> Extensions {
// This can only be infallible while only one per-message compression
// extension is supported. If more are added there will need to be some
// resolution strategy for picking which one takes precedence.
let Self {
#[cfg(feature = "deflate")]
permessage_deflate,
} = self;
#[cfg_attr(feature = "deflate", allow(unused_assignments))]
#[cfg_attr(not(feature = "deflate"), allow(unused_mut))]
let mut per_message_compression = None;
#[cfg(feature = "deflate")]
{
per_message_compression = permessage_deflate
.map(|deflate| compression::deflate::DeflateContext::new(role, deflate).into());
}
let _ = role;
Extensions { per_message_compression }
}
}
impl Extensions {
/// Returns a function that, if present, compresses a message payload.
///
/// The returned value will only be `Some` if a per-message compression
/// extension, as specified by [RFC 7692], was configured for the connection
/// state to which this `Extensions` applies.
///
/// [RFC 7692]: https://tools.ietf.org/html/rfc7692
#[inline]
pub(crate) fn per_message_compressor<'s>(
&'s mut self,
) -> Option<impl FnOnce(&Bytes) -> Result<Bytes, CompressionError> + 's> {
let Self { per_message_compression } = self;
per_message_compression.as_mut().map(PerMessageCompressionContext::compressor)
}
/// Returns a function that, if present, decompresses a frame payload.
///
/// The returned value will only be `Some` if a per-message compression
/// extension, as specified by [RFC 7692], was configured for the connection
/// state to which this `Extensions` applies. The closure takes as arguments
/// the frame payload, in bytes, a boolean indicating whether the frame is
/// the final one for a message, and the maximum number of uncompressed
/// bytes to produce before returning an error.
///
/// [RFC 7692]: https://tools.ietf.org/html/rfc7692
#[inline]
pub(crate) fn per_message_decompressor<'s>(
&'s mut self,
) -> Option<impl FnMut(&Bytes, bool, usize) -> Result<Bytes, DecompressionError> + 's> {
let Self { per_message_compression } = self;
per_message_compression.as_mut().map(PerMessageCompressionContext::decompressor)
}
}
#[cfg(feature = "handshake")]
#[cfg(test)]
mod test {
use super::*;
#[test]
fn accept_offers_ignores_unknown_extensions() {
let (Extensions { per_message_compression }, response) = ExtensionsConfig::default()
.accept_offers(&SecWebsocketExtensions::new([
"unknown-1".parse().unwrap(),
"other-unknown; a=5; b=3".parse().unwrap(),
]))
.unwrap();
assert!(matches!(per_message_compression, None));
assert_eq!(response, None);
}
#[test]
fn accept_offers_with_deflate_disabled() {
let extensions = ExtensionsConfig::default();
// With or without #[cfg(feature = "deflate")], the extension should be ignored.
let (Extensions { per_message_compression }, response) =
extensions.accept_offers(&SecWebsocketExtensions::new([])).unwrap();
assert!(matches!(per_message_compression, None));
assert_eq!(response, None);
}
#[cfg(feature = "deflate")]
#[test]
fn accept_offers_with_deflate_enabled() {
let extensions = ExtensionsConfig { permessage_deflate: Some(Default::default()) };
{
// If the client doesn't offer permessage-deflate support, the response
// shouldn't include it.
let (Extensions { per_message_compression }, response) =
extensions.accept_offers(&SecWebsocketExtensions::new([])).unwrap();
assert!(matches!(per_message_compression, None));
assert_eq!(response, None);
}
{
// If the client does offer support, the response should include it.
let (Extensions { per_message_compression }, response) = extensions
.accept_offers(&SecWebsocketExtensions::new([
WebsocketProtocolExtension::new(compression::deflate::EXTENSION_NAME, []),
WebsocketProtocolExtension::new("some-other-extension", []),
]))
.unwrap();
assert!(matches!(per_message_compression, Some(_)));
assert_eq!(
response,
Some(SecWebsocketExtensions::new([WebsocketProtocolExtension::new(
compression::deflate::EXTENSION_NAME,
[]
)]))
);
}
}
#[cfg(feature = "deflate")]
#[test]
fn accept_offers_picks_first_acceptable_offer() {
use compression::deflate::*;
let extensions = ExtensionsConfig {
permessage_deflate: Some(
DeflateConfig::new().set_max_window_bits(Role::Client, 11).unwrap(),
),
};
let (Extensions { per_message_compression }, response) = extensions
.accept_offers(&SecWebsocketExtensions::new([
// These two offers are declined because they doesn't indicate
// support for client_max_window_bits, which the server is
// configured to require.
"permessage-deflate".parse().unwrap(),
"permessage-deflate; server_max_window_bits=12".parse().unwrap(),
// This offer would be acceptable but it has a parameter that the server doesn't recognize.
"permessage-deflate; client_max_window_bits=11; parameter-from-the-future=3"
.parse()
.unwrap(),
// This offer is accepted.
"permessage-deflate; client_no_context_takeover; client_max_window_bits=11"
.parse()
.unwrap(),
// This offer is ignored since an earlier one was accepted.
"permessage-deflate; client_max_window_bits=10".parse().unwrap(),
]))
.unwrap();
assert!(matches!(per_message_compression, Some(PerMessageCompressionContext::Deflate(_))));
assert_eq!(
response,
Some(SecWebsocketExtensions::new([DeflateConfig::new()
.set_no_context_takeover(Role::Client, true)
.set_max_window_bits(Role::Client, 11)
.unwrap()
.as_offer()
.into()]))
)
}
#[cfg(feature = "deflate")]
#[test]
fn verify_agreed_on_deflate_then_garbage() {
let extensions = ExtensionsConfig { permessage_deflate: Some(Default::default()) };
let result = extensions.verify_agreed_on(SecWebsocketExtensions::new([
WebsocketProtocolExtension::new(compression::deflate::EXTENSION_NAME, []),
WebsocketProtocolExtension::new("unrecognized", []),
]));
assert_eq!(result.unwrap_err(), ExtensionsError::InvalidExtension("unrecognized".into()));
}
#[cfg(feature = "deflate")]
#[test]
fn verify_agreed_on_deflate_multiple_times() {
let extensions = ExtensionsConfig { permessage_deflate: Some(Default::default()) };
let result = extensions.verify_agreed_on(SecWebsocketExtensions::new([
WebsocketProtocolExtension::new(compression::deflate::EXTENSION_NAME, []),
WebsocketProtocolExtension::new(
compression::deflate::EXTENSION_NAME,
["client_no_context_takeover".parse().unwrap()],
),
]));
assert_eq!(
result.unwrap_err(),
ExtensionsError::ExtensionConflict(compression::deflate::EXTENSION_NAME.into())
);
}
}

View File

@ -5,6 +5,7 @@ use std::{
marker::PhantomData,
};
use headers::{Header, HeaderMapExt};
use http::{
header::HeaderName, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode,
};
@ -19,6 +20,7 @@ use super::{
};
use crate::{
error::{Error, ProtocolError, Result, SubProtocolError, UrlError},
extensions::{headers::SecWebsocketExtensions, Extensions, ExtensionsConfig},
handshake::version_as_str,
protocol::{Role, WebSocket, WebSocketConfig},
};
@ -59,7 +61,7 @@ impl<S: Read + Write> ClientHandshake<S> {
// Convert and verify the `http::Request` and turn it into the request as per RFC.
// Also extract the key from it (it must be present in a correct request).
let (request, key) = generate_request(request)?;
let (request, key) = generate_request(request, config.as_ref().map(|w| &w.extensions))?;
let machine = HandshakeMachine::start_write(stream, request);
@ -90,7 +92,10 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
ProcessingResult::Continue(HandshakeMachine::start_read(stream))
}
StageResult::DoneReading { stream, result, tail } => {
let result = match self.verify_data.verify_response(result) {
let (result, extensions) = match self
.verify_data
.verify_response(result, self.config.as_ref().map(|c| &c.extensions))
{
Ok(r) => r,
Err(Error::Http(mut e)) => {
*e.body_mut() = Some(tail);
@ -100,8 +105,13 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
};
debug!("Client handshake done.");
let websocket =
WebSocket::from_partially_read(stream, tail, Role::Client, self.config);
let websocket = WebSocket::from_partially_read_with_extensions(
stream,
tail,
Role::Client,
self.config,
extensions,
);
ProcessingResult::Done((websocket, result))
}
})
@ -109,7 +119,10 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
}
/// Verifies and generates a client WebSocket request from the original request and extracts a WebSocket key from it.
pub fn generate_request(mut request: Request) -> Result<(Vec<u8>, String)> {
pub fn generate_request(
mut request: Request,
extensions: Option<&ExtensionsConfig>,
) -> Result<(Vec<u8>, String)> {
let mut req = Vec::new();
write!(
req,
@ -161,6 +174,14 @@ pub fn generate_request(mut request: Request) -> Result<(Vec<u8>, String)> {
.unwrap();
}
if let Some(header) = extensions
.map(ExtensionsConfig::generate_offers)
.map(SecWebsocketExtensions::new)
.filter(|header| header.len() != 0)
{
headers.append(SecWebsocketExtensions::name(), header.header_value());
}
// Now we must ensure that the headers that we've written once are not anymore present in the map.
// If they do, then the request is invalid (some headers are duplicated there for some reason).
let websocket_headers_contains =
@ -220,7 +241,11 @@ struct VerifyData {
}
impl VerifyData {
pub fn verify_response(&self, response: Response) -> Result<Response> {
pub fn verify_response(
&self,
response: Response,
extensions: Option<&ExtensionsConfig>,
) -> Result<(Response, Extensions)> {
// 1. If the status code received from the server is not 101, the
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
if response.status() != StatusCode::SWITCHING_PROTOCOLS {
@ -265,7 +290,18 @@ impl VerifyData {
// that was not present in the client's handshake (the server has
// indicated an extension not requested by the client), the client
// MUST _Fail the WebSocket Connection_. (RFC 6455)
// TODO
let extensions_header =
headers.typed_try_get::<SecWebsocketExtensions>().map_err(|_| {
ProtocolError::InvalidHeader(SecWebsocketExtensions::name().clone().into())
})?;
let extensions = match extensions_header {
None => Extensions::default(),
Some(agreed) => extensions
.ok_or(ProtocolError::InvalidHeader(SecWebsocketExtensions::name().clone().into()))?
.verify_agreed_on(agreed)
.map_err(ProtocolError::from)?,
};
// 6. If the response includes a |Sec-WebSocket-Protocol| header field
// and this header field indicates the use of a subprotocol that was
@ -294,7 +330,7 @@ impl VerifyData {
}
}
Ok(response)
Ok((response, extensions))
}
}
@ -374,7 +410,7 @@ mod tests {
#[test]
fn request_formatting() {
let request = "ws://localhost/getCaseCount".into_client_request().unwrap();
let (request, key) = generate_request(request).unwrap();
let (request, key) = generate_request(request, None).unwrap();
let correct = construct_expected("localhost", &key);
assert_eq!(&request[..], &correct[..]);
}
@ -382,7 +418,7 @@ mod tests {
#[test]
fn request_formatting_with_host() {
let request = "wss://localhost:9001/getCaseCount".into_client_request().unwrap();
let (request, key) = generate_request(request).unwrap();
let (request, key) = generate_request(request, None).unwrap();
let correct = construct_expected("localhost:9001", &key);
assert_eq!(&request[..], &correct[..]);
}
@ -390,11 +426,41 @@ mod tests {
#[test]
fn request_formatting_with_at() {
let request = "wss://user:pass@localhost:9001/getCaseCount".into_client_request().unwrap();
let (request, key) = generate_request(request).unwrap();
let (request, key) = generate_request(request, None).unwrap();
let correct = construct_expected("localhost:9001", &key);
assert_eq!(&request[..], &correct[..]);
}
#[cfg(feature = "deflate")]
#[test]
fn request_with_compression() {
use crate::extensions::{compression::deflate::DeflateConfig, ExtensionsConfig};
let request = "ws://localhost/getCaseCount".into_client_request().unwrap();
let (request, key) = generate_request(
request,
Some(&ExtensionsConfig {
permessage_deflate: Some(DeflateConfig::default()),
..ExtensionsConfig::default()
}),
)
.unwrap();
let correct = format!(
"\
GET /getCaseCount HTTP/1.1\r\n\
Host: {host}\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n\
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: {key}\r\n\
sec-websocket-extensions: permessage-deflate; client_max_window_bits\r\n\
\r\n",
host = "localhost",
key = key
);
assert_eq!(String::try_from(request).unwrap(), &correct[..]);
}
#[test]
fn response_parsing() {
const DATA: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n";
@ -406,6 +472,6 @@ mod tests {
#[test]
fn invalid_custom_request() {
let request = http::Request::builder().method("GET").body(()).unwrap();
assert!(generate_request(request).is_err());
assert!(generate_request(request, None).is_err());
}
}

View File

@ -6,6 +6,7 @@ use std::{
result::Result as StdResult,
};
use headers::{Header, HeaderMapExt};
use http::{
response::Builder, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode,
};
@ -20,6 +21,7 @@ use super::{
};
use crate::{
error::{Error, ProtocolError, Result},
extensions::{headers::SecWebsocketExtensions, Extensions},
handshake::version_as_str,
protocol::{Role, WebSocket, WebSocketConfig},
};
@ -203,6 +205,8 @@ pub struct ServerHandshake<S, C> {
config: Option<WebSocketConfig>,
/// Error code/flag. If set, an error will be returned after sending response to the client.
error_response: Option<ErrorResponse>,
// Negotiated extension context for server.
extensions: Extensions,
/// Internal stream type.
_marker: PhantomData<S>,
}
@ -220,6 +224,7 @@ impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
callback: Some(callback),
config,
error_response: None,
extensions: Extensions::default(),
_marker: PhantomData,
},
}
@ -241,7 +246,30 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
return Err(Error::Protocol(ProtocolError::JunkAfterRequest));
}
let response = create_response(&result)?;
let mut response = create_response(&result)?;
if let Some(extensions) =
result.headers().typed_try_get::<SecWebsocketExtensions>().map_err(|_| {
ProtocolError::InvalidHeader(SecWebsocketExtensions::name().clone().into())
})?
{
let extensions_config = self
.config
.ok_or_else(|| {
ProtocolError::InvalidHeader(
SecWebsocketExtensions::name().clone().into(),
)
})?
.extensions;
let (extensions, agreed) = extensions_config
.accept_offers(&extensions)
.map_err(ProtocolError::from)?;
if let Some(agreed) = agreed {
response.headers_mut().typed_insert(agreed)
};
self.extensions = extensions;
}
let callback_result = if let Some(callback) = self.callback.take() {
callback.on_request(&result, response)
} else {
@ -284,7 +312,12 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
return Err(Error::Http(http::Response::from_parts(parts, body).into()));
} else {
debug!("Server handshake done.");
let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config);
let websocket = WebSocket::from_raw_socket_with_extensions(
stream,
Role::Server,
self.config,
std::mem::take(&mut self.extensions),
);
ProcessingResult::Done(websocket)
}
}

View File

@ -22,6 +22,7 @@ pub mod buffer;
#[cfg(feature = "handshake")]
pub mod client;
pub mod error;
pub mod extensions;
#[cfg(feature = "handshake")]
pub mod handshake;
pub mod protocol;

View File

@ -300,6 +300,32 @@ impl Frame {
}
}
/// Create a new compressed data frame.
///
/// `opcode` is of type `Data` because, per [RFC 7692 Section 6], "PMCEs
/// operate only on data messages".
///
/// [RFC 7692 Section 6]: https://tools.ietf.org/html/rfc7692#section-6
#[inline]
pub(crate) fn compressed_message(data: Bytes, opcode: Data, is_final: bool) -> Frame {
Frame {
header: FrameHeader {
is_final,
opcode: OpCode::Data(opcode),
// Per RFC 7692 Section 6:
//
// This document allocates the RSV1 bit of the WebSocket
// header for PMCEs and calls the bit the "Per-Message
// Compressed" bit. On a WebSocket connection where a PMCE is
// in use, this bit indicates whether a message is compressed
// or not.
rsv1: true,
..FrameHeader::default()
},
payload: data,
}
}
/// Create a new Pong control frame.
#[inline]
pub fn pong(data: impl Into<Bytes>) -> Frame {

View File

@ -301,7 +301,7 @@ mod tests {
#[test]
fn read_frames() {
env_logger::init();
let _ = env_logger::try_init();
let raw = Cursor::new(vec![
0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x82, 0x03, 0x03, 0x02, 0x01,

View File

@ -83,6 +83,8 @@ use bytes::Bytes;
#[derive(Debug)]
pub struct IncompleteMessage {
collector: IncompleteMessageCollector,
#[cfg(feature = "deflate")]
compressed: bool,
}
#[derive(Debug)]
@ -101,9 +103,32 @@ impl IncompleteMessage {
IncompleteMessageCollector::Text(StringCollector::new())
}
},
#[cfg(feature = "deflate")]
compressed: false,
}
}
/// Create new instance that will hold compressed data.
#[cfg(feature = "deflate")]
pub fn new_compressed(message_type: IncompleteMessageType) -> Self {
IncompleteMessage {
collector: match message_type {
IncompleteMessageType::Binary => IncompleteMessageCollector::Binary(Vec::new()),
IncompleteMessageType::Text => {
IncompleteMessageCollector::Text(StringCollector::new())
}
},
compressed: true,
}
}
pub fn compressed(&self) -> bool {
#[cfg(feature = "deflate")]
return self.compressed;
#[cfg(not(feature = "deflate"))]
return false;
}
/// Get the current filled size of the buffer.
pub fn len(&self) -> usize {
match self.collector {

View File

@ -15,12 +15,14 @@ use self::{
};
use crate::{
error::{CapacityError, Error, ProtocolError, Result},
extensions::{compression::DecompressionError, Extensions, ExtensionsConfig},
protocol::frame::Utf8Bytes,
};
use log::*;
use std::{
io::{self, Read, Write},
mem::replace,
usize,
};
/// Indicates a Client or Server role of the websocket
@ -90,6 +92,11 @@ pub struct WebSocketConfig {
/// some popular libraries that are sending unmasked frames, ignoring the RFC.
/// By default this option is set to `false`, i.e. according to RFC 6455.
pub accept_unmasked_frames: bool,
/// Configuration for optional extensions to the base websocket protocol.
///
/// Some extensions may require optional features to be enabled at build
/// time to be supported.
pub extensions: ExtensionsConfig,
}
impl Default for WebSocketConfig {
@ -101,6 +108,7 @@ impl Default for WebSocketConfig {
max_message_size: Some(64 << 20),
max_frame_size: Some(16 << 20),
accept_unmasked_frames: false,
extensions: ExtensionsConfig::default(),
}
}
}
@ -179,6 +187,18 @@ impl<Stream> WebSocket<Stream> {
WebSocket { socket: stream, context: WebSocketContext::new(role, config) }
}
/// Convert a raw socket into a WebSocket without performing a handshake.
pub fn from_raw_socket_with_extensions(
stream: Stream,
role: Role,
config: Option<WebSocketConfig>,
extensions: Extensions,
) -> Self {
let mut context = WebSocketContext::new(role, config);
context.extensions = extensions;
WebSocket { socket: stream, context }
}
/// Convert a raw socket into a WebSocket without performing a handshake.
///
/// Call this function if you're using Tungstenite as a part of a web framework
@ -192,10 +212,22 @@ impl<Stream> WebSocket<Stream> {
part: Vec<u8>,
role: Role,
config: Option<WebSocketConfig>,
) -> Self {
Self::from_partially_read_with_extensions(stream, part, role, config, Extensions::default())
}
pub(crate) fn from_partially_read_with_extensions(
stream: Stream,
part: Vec<u8>,
role: Role,
config: Option<WebSocketConfig>,
extensions: Extensions,
) -> Self {
WebSocket {
socket: stream,
context: WebSocketContext::from_partially_read(part, role, config),
context: WebSocketContext::from_partially_read_with_extensions(
part, role, config, extensions,
),
}
}
@ -375,6 +407,8 @@ pub struct WebSocketContext {
unflushed_additional: bool,
/// The configuration for the websocket session.
config: WebSocketConfig,
// Container for extensions.
extensions: Extensions,
}
impl WebSocketContext {
@ -384,7 +418,12 @@ impl WebSocketContext {
/// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
pub fn new(role: Role, config: Option<WebSocketConfig>) -> Self {
let conf = config.unwrap_or_default();
Self::_new(role, FrameCodec::new(conf.read_buffer_size), conf)
Self::_new(
role,
FrameCodec::new(conf.read_buffer_size),
conf,
conf.extensions.into_unnegotiated_context(role),
)
}
/// Create a WebSocket context that manages an post-handshake stream.
@ -393,10 +432,41 @@ impl WebSocketContext {
/// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
pub fn from_partially_read(part: Vec<u8>, role: Role, config: Option<WebSocketConfig>) -> Self {
let conf = config.unwrap_or_default();
Self::_new(role, FrameCodec::from_partially_read(part, conf.read_buffer_size), conf)
let extensions = conf.extensions.into_unnegotiated_context(role);
Self::_new(
role,
FrameCodec::from_partially_read(part, conf.read_buffer_size),
conf,
extensions,
)
}
fn _new(role: Role, mut frame: FrameCodec, config: WebSocketConfig) -> Self {
/// Create a WebSocket context for a post-handshake stream with the enabled extensions.
///
/// Where [`WebSocketContext::from_partially_read`] infers the enabled
/// extensions from the [`WebSocketConfig`], this allows the caller to
/// explicitly sets the extensions in use for the connection.
pub(crate) fn from_partially_read_with_extensions(
part: Vec<u8>,
role: Role,
config: Option<WebSocketConfig>,
extensions: Extensions,
) -> Self {
let conf = config.unwrap_or_default();
Self::_new(
role,
FrameCodec::from_partially_read(part, conf.read_buffer_size),
conf,
extensions,
)
}
fn _new(
role: Role,
mut frame: FrameCodec,
config: WebSocketConfig,
extensions: Extensions,
) -> Self {
config.assert_valid();
frame.set_max_out_buffer_len(config.max_write_buffer_size);
frame.set_out_buffer_write_len(config.write_buffer_size);
@ -408,6 +478,7 @@ impl WebSocketContext {
additional_send: None,
unflushed_additional: false,
config,
extensions,
}
}
@ -500,9 +571,18 @@ impl WebSocketContext {
return Err(Error::Protocol(ProtocolError::SendAfterClosing));
}
let mut prepare_data_frame = |data, opdata| -> Result<Frame, ProtocolError> {
const IS_FINAL: bool = true;
if let Some(compressor) = self.extensions.per_message_compressor() {
let compressed = compressor(&data)?;
return Ok(Frame::compressed_message(compressed, opdata, IS_FINAL));
}
Ok(Frame::message(data, OpCode::Data(opdata), IS_FINAL))
};
let frame = match message {
Message::Text(data) => Frame::message(data, OpCode::Data(OpData::Text), true),
Message::Binary(data) => Frame::message(data, OpCode::Data(OpData::Binary), true),
Message::Text(data) => prepare_data_frame(data.into(), OpData::Text)?,
Message::Binary(data) => prepare_data_frame(data, OpData::Binary)?,
Message::Ping(data) => Frame::ping(data),
Message::Pong(data) => {
self.set_additional(Frame::pong(data));
@ -633,17 +713,53 @@ impl WebSocketContext {
if !self.state.can_read() {
return Err(Error::Protocol(ProtocolError::ReceivedAfterClosing));
}
// MUST be 0 unless an extension is negotiated that defines meanings
// for non-zero values. If a nonzero value is received and none of
// the negotiated extensions defines the meaning of such a nonzero
// value, the receiving endpoint MUST _Fail the WebSocket
// Connection_.
{
let (is_compressed, decompressor) = {
let decompressor = self.extensions.per_message_decompressor();
let hdr = frame.header();
if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 {
// Per RFC 6455, the RSV1, RSV2, and RSV3 bits
//
// MUST be 0 unless an extension is negotiated that defines
// meanings for non-zero values. If a nonzero value is
// received and none of the negotiated extensions defines the
// meaning of such a nonzero value, the receiving endpoint
// MUST _Fail the WebSocket Connection_.
//
// Per RFC 7692:
//
// This document allocates the RSV1 bit of the WebSocket
// header for PMCEs and calls the bit the "Per-Message
// Compressed" bit. On a WebSocket connection where a PMCE is
// in use, this bit indicates whether a message is compressed
// or not.
if (hdr.rsv1 && decompressor.is_none()) || hdr.rsv2 || hdr.rsv3 {
return Err(Error::Protocol(ProtocolError::NonZeroReservedBits));
}
}
let decompressor_with_size_limit = decompressor.map(|mut f| {
let incomplete_len =
self.incomplete.as_ref().map(IncompleteMessage::len).unwrap_or(0);
let message_max = self.config.max_message_size.unwrap_or(usize::MAX);
move |bytes, is_final| {
let decompress_limit = message_max.saturating_sub(incomplete_len);
f(bytes, is_final, decompress_limit).map_err(|e| match e {
DecompressionError::SizeLimitReached => {
Error::Capacity(CapacityError::MessageTooLong {
size: incomplete_len.saturating_add(decompress_limit),
max_size: message_max,
})
}
DecompressionError::Decompression(e) => {
Error::Protocol(ProtocolError::CompressionFailure(e))
}
})
}
});
(hdr.rsv1, decompressor_with_size_limit)
};
if self.role == Role::Client && frame.is_masked() {
// A client MUST close a connection if it detects a masked frame. (RFC 6455)
@ -652,6 +768,7 @@ impl WebSocketContext {
match frame.header().opcode {
OpCode::Control(ctl) => {
drop(decompressor);
match ctl {
// All control frames MUST have a payload length of 125 bytes or less
// and MUST NOT be fragmented. (RFC 6455)
@ -661,6 +778,15 @@ impl WebSocketContext {
_ if frame.payload().len() > 125 => {
Err(Error::Protocol(ProtocolError::ControlFrameTooBig))
}
// Per RFC 7692:
//
// An endpoint MUST NOT set the "Per-Message
// Compressed" bit of control frames and non-first
// fragments of a data message. An endpoint receiving
// such a frame MUST _Fail the WebSocket Connection_.
_ if is_compressed => {
Err(Error::Protocol(ProtocolError::CompressedControlFrame))
}
OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)),
OpCtl::Reserved(i) => {
Err(Error::Protocol(ProtocolError::UnknownControlFrameType(i)))
@ -679,29 +805,74 @@ impl WebSocketContext {
OpCode::Data(data) => {
let fin = frame.header().is_final;
match data {
OpData::Continue => {
let msg = self
let incomplete = self
.incomplete
.as_mut()
.ok_or(Error::Protocol(ProtocolError::UnexpectedContinueFrame))?;
msg.extend(frame.into_payload(), self.config.max_message_size)?;
// Per RFC 7692:
//
// An endpoint MUST NOT set the "Per-Message
// Compressed" bit of control frames and non-first
// fragments of a data message. An endpoint
// receiving such a frame MUST _Fail the WebSocket
// Connection_.
if is_compressed {
return Err(Error::Protocol(ProtocolError::CompressedContinueFrame));
}
let mut payload = frame.into_payload();
if incomplete.compressed() {
let mut decompressor = decompressor.ok_or_else(|| {
// This is a continuation frame that was
// received with compression disabled, but
// the initial frame of the message was
// received with compression *enabled* and
// RSV1 set.
//
// The only way to get here is to manually
// disable compression for a stream after
// it's been established, which is arguably
// operator error. This is incorrect enough
// that it's not worth spending a lot of
// code on, but it's better to return an
// error here than crash.
log::debug!("compression was disabled between receiving frames");
ProtocolError::CompressedContinueFrame
})?;
payload = decompressor(&payload, fin)?;
};
incomplete.extend(payload, self.config.max_message_size)?;
if fin {
Ok(Some(self.incomplete.take().unwrap().complete()?))
} else {
Ok(None)
}
}
c if self.incomplete.is_some() => {
Err(Error::Protocol(ProtocolError::ExpectedFragment(c)))
}
OpData::Text if fin => {
check_max_size(frame.payload().len(), self.config.max_message_size)?;
Ok(Some(Message::Text(frame.into_text()?)))
}
OpData::Binary if fin => {
check_max_size(frame.payload().len(), self.config.max_message_size)?;
Ok(Some(Message::Binary(frame.into_payload())))
OpData::Text | OpData::Binary if fin => {
let payload = frame.into_payload();
let payload = match decompressor.filter(|_| is_compressed) {
Some(mut frame_decompressor) => frame_decompressor(&payload, fin)?,
None => payload,
};
check_max_size(payload.len(), self.config.max_message_size)?;
match data {
OpData::Text => Ok(Some(Message::Text(payload.try_into()?))),
OpData::Binary => Ok(Some(Message::Binary(payload))),
_ => panic!("Bug: message is not text nor binary"),
}
}
OpData::Text | OpData::Binary => {
let message_type = match data {
@ -709,8 +880,19 @@ impl WebSocketContext {
OpData::Binary => IncompleteMessageType::Binary,
_ => panic!("Bug: message is not text nor binary"),
};
let mut incomplete = IncompleteMessage::new(message_type);
incomplete.extend(frame.into_payload(), self.config.max_message_size)?;
let payload = frame.into_payload();
let payload = match decompressor.filter(|_| is_compressed) {
Some(mut frame_decompressor) => frame_decompressor(&payload, fin)?,
None => payload,
};
let mut incomplete = match is_compressed {
#[cfg(feature = "deflate")]
true => IncompleteMessage::new_compressed(message_type),
_ => IncompleteMessage::new(message_type),
};
incomplete.extend(payload, self.config.max_message_size)?;
self.incomplete = Some(incomplete);
Ok(None)
}
@ -860,6 +1042,7 @@ impl<T> CheckConnectionReset for Result<T> {
mod tests {
use super::{Message, Role, WebSocket, WebSocketConfig};
use crate::error::{CapacityError, Error};
use crate::extensions::ExtensionsConfig;
use std::{io, io::Cursor};
@ -920,4 +1103,178 @@ mod tests {
Err(Error::Capacity(CapacityError::MessageTooLong { size: 3, max_size: 2 }))
));
}
#[cfg(feature = "deflate")]
#[test]
fn per_message_deflate_compression() {
// Example frames from RFC 7692 Section 7.2.3.2
use crate::{extensions::compression, protocol::FrameCodec};
let mut stream = Cursor::new(Vec::new());
let config = WebSocketConfig {
extensions: ExtensionsConfig {
permessage_deflate: Some(compression::deflate::DeflateConfig::default()),
},
..Default::default()
};
let mut socket = WebSocket::from_raw_socket(&mut stream, Role::Client, Some(config));
// The same message sent twice should compress better the second time
// because context takeover is enabled.
socket.write(Message::Text("Hello".into())).unwrap();
socket.write(Message::Text("Hello".into())).unwrap();
socket.flush().unwrap();
let written = stream.into_inner();
let mut codec = FrameCodec::new(written.len());
let mut stream = Cursor::new(written);
let first_frame = codec.read_frame(&mut stream, None, true, false).unwrap().unwrap();
let second_frame = codec.read_frame(&mut stream, None, true, false).unwrap().unwrap();
assert_eq!(
first_frame.payload(),
// First frame payload, from the RFC
&[0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00]
);
assert_eq!(
second_frame.payload(),
// Second frame payload, from the RFC
&[0xf2, 0x00, 0x11, 0x00, 0x00]
);
}
#[cfg(feature = "deflate")]
#[test]
fn per_message_deflate_decompression() {
// Example frames from RFC 7692 Section 7.2.3.2
use crate::extensions::compression::deflate::DeflateConfig;
let incoming =
Cursor::new(&[0x41, 0x03, 0xf2, 0x48, 0xcd, 0x80, 0x04, 0xc9, 0xc9, 0x07, 0x00]);
let config = WebSocketConfig {
extensions: ExtensionsConfig { permessage_deflate: Some(DeflateConfig::default()) },
..Default::default()
};
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(config));
assert_eq!(socket.read().unwrap(), Message::Text("Hello".into()));
}
#[test]
fn per_message_compression_not_recognized() {
// Without the extension configuration, frames with the RSV1 bit set are rejected.
let incoming =
Cursor::new(&[0x41, 0x03, 0xf2, 0x48, 0xcd, 0x80, 0x04, 0xc9, 0xc9, 0x07, 0x00]);
let config =
WebSocketConfig { extensions: ExtensionsConfig::default(), ..Default::default() };
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(config));
assert!(matches!(
socket.read().unwrap_err(),
Error::Protocol(crate::error::ProtocolError::NonZeroReservedBits)
));
}
#[cfg(feature = "deflate")]
#[test]
fn per_message_compression_decompress_respects_message_size_limit() {
use crate::extensions::compression::deflate::test::very_compressed;
use crate::extensions::compression::deflate::DeflateConfig;
use crate::protocol::frame::{
coding::{Data, OpCode},
FrameHeader,
};
let _ = env_logger::try_init();
let base_config = WebSocketConfig {
extensions: ExtensionsConfig {
permessage_deflate: Some(DeflateConfig::default()),
..Default::default()
},
..Default::default()
};
fn make_message(frame_count: usize) -> Vec<u8> {
let mut is_first = true;
let mut output = Vec::new();
for (frame, is_final) in very_compressed::make_frames(frame_count) {
let is_first = std::mem::replace(&mut is_first, false);
let header = FrameHeader {
opcode: OpCode::Data(if is_first { Data::Binary } else { Data::Continue }),
rsv1: is_first,
is_final,
..Default::default()
};
header.format(frame.len() as u64, &mut output).unwrap();
output.extend_from_slice(&frame);
}
output
}
// With the default configuration, a short message of these frames is fine.
{
let input = Cursor::new(make_message(4));
let mut socket =
WebSocket::from_raw_socket(input, Role::Client, Some(base_config.clone()));
let message = socket.read().unwrap();
assert_eq!(
message,
Message::Binary(
bytes::BytesMut::zeroed(4 * very_compressed::DECOMPRESSED_LEN).into()
)
);
}
// The maximum frame size limits on-the-wire frame size, not
// decompressed size, so this still decompresses.
{
let input = Cursor::new(make_message(2));
const MAX_FRAME_SIZE: usize = very_compressed::DECOMPRESSED_LEN - 1;
let mut socket = WebSocket::from_raw_socket(
input,
Role::Client,
Some(base_config.clone().max_frame_size(Some(MAX_FRAME_SIZE))),
);
let message = socket.read().unwrap();
assert_eq!(
message,
Message::Binary(
bytes::BytesMut::zeroed(2 * very_compressed::DECOMPRESSED_LEN).into()
)
);
}
// With a reduced maximum message size, decompressing the whole message
// fails.
{
let input = Cursor::new(make_message(5));
const MAX_MESSAGE_SIZE: usize = 3 * very_compressed::DECOMPRESSED_LEN;
let mut socket = WebSocket::from_raw_socket(
input,
Role::Client,
Some(base_config.clone().max_message_size(Some(MAX_MESSAGE_SIZE))),
);
let error = socket.read().unwrap_err();
assert!(matches!(
error,
Error::Capacity(CapacityError::MessageTooLong {
size: _,
max_size: MAX_MESSAGE_SIZE
})
));
}
}
}