Merge permessage-deflate support into master
This commit is contained in:
commit
691b1e712f
2
.gitignore
vendored
2
.gitignore
vendored
@ -1,3 +1,5 @@
|
||||
target
|
||||
Cargo.lock
|
||||
.vscode
|
||||
autobahn/client/
|
||||
autobahn/server/
|
||||
|
||||
23
Cargo.toml
23
Cargo.toml
@ -19,17 +19,20 @@ all-features = true
|
||||
|
||||
[features]
|
||||
default = ["handshake"]
|
||||
handshake = ["data-encoding", "http", "httparse", "sha1"]
|
||||
handshake = ["data-encoding", "headers", "httparse", "sha1"]
|
||||
headers = ["http", "dep:headers"]
|
||||
url = ["dep:url"]
|
||||
native-tls = ["native-tls-crate"]
|
||||
native-tls-vendored = ["native-tls", "native-tls-crate/vendored"]
|
||||
rustls-tls-native-roots = ["__rustls-tls", "rustls-native-certs"]
|
||||
rustls-tls-webpki-roots = ["__rustls-tls", "webpki-roots"]
|
||||
__rustls-tls = ["rustls", "rustls-pki-types"]
|
||||
deflate = ["headers", "flate2"]
|
||||
|
||||
[dependencies]
|
||||
data-encoding = { version = "2", optional = true }
|
||||
bytes = "1.9.0"
|
||||
headers = { version = "0.4.0", optional = true }
|
||||
http = { version = "1.0", optional = true }
|
||||
httparse = { version = "1.3.4", optional = true }
|
||||
log = "0.4.8"
|
||||
@ -39,6 +42,16 @@ thiserror = "2.0.7"
|
||||
url = { version = "2.1.0", optional = true }
|
||||
utf-8 = "0.7.5"
|
||||
|
||||
[dependencies.flate2]
|
||||
optional = true
|
||||
version = "1.0.27"
|
||||
default-features = false
|
||||
# We need some zlib-compatible backend from flate2 to support setting the
|
||||
# context window size. Enabling the "zlib" feature is enough, and, per the
|
||||
# flate2 documentation, if other crates in the build graph enable the "zlib-ng"
|
||||
# or "zlib-ng-compat" features, those will take precedence.
|
||||
features = ["zlib"]
|
||||
|
||||
[dependencies.native-tls-crate]
|
||||
optional = true
|
||||
package = "native-tls"
|
||||
@ -63,6 +76,8 @@ optional = true
|
||||
version = "0.26"
|
||||
|
||||
[dev-dependencies]
|
||||
http = "1.0"
|
||||
httparse = "1.3.4"
|
||||
criterion = "0.6"
|
||||
env_logger = "0.11"
|
||||
input_buffer = "0.5.0"
|
||||
@ -99,11 +114,11 @@ required-features = ["handshake"]
|
||||
|
||||
[[example]]
|
||||
name = "autobahn-client"
|
||||
required-features = ["handshake"]
|
||||
required-features = ["handshake", "deflate"]
|
||||
|
||||
[[example]]
|
||||
name = "autobahn-server"
|
||||
required-features = ["handshake"]
|
||||
required-features = ["handshake", "deflate"]
|
||||
|
||||
[[example]]
|
||||
name = "callback-error"
|
||||
@ -111,4 +126,4 @@ required-features = ["handshake"]
|
||||
|
||||
[[example]]
|
||||
name = "srv_accept_unmasked_frames"
|
||||
required-features = ["handshake"]
|
||||
required-features = ["handshake"]
|
||||
@ -72,8 +72,6 @@ Choose the one that is appropriate for your needs.
|
||||
By default **no TLS feature is activated**, so make sure you use one of the TLS features,
|
||||
otherwise you won't be able to communicate with the TLS endpoints.
|
||||
|
||||
There is no support for permessage-deflate at the moment, but the PRs are welcome :wink:
|
||||
|
||||
Testing
|
||||
-------
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,6 +1,9 @@
|
||||
use log::*;
|
||||
|
||||
use tungstenite::{connect, Error, Message, Result};
|
||||
use tungstenite::{
|
||||
client::connect_with_config, connect, extensions::compression::deflate::DeflateConfig,
|
||||
protocol::WebSocketConfig, Error, Message, Result,
|
||||
};
|
||||
|
||||
const AGENT: &str = "Tungstenite";
|
||||
|
||||
@ -20,7 +23,11 @@ fn update_reports() -> Result<()> {
|
||||
fn run_test(case: u32) -> Result<()> {
|
||||
info!("Running test case {case}");
|
||||
let case_url = format!("ws://localhost:9001/runCase?case={case}&agent={AGENT}");
|
||||
let (mut socket, _) = connect(case_url)?;
|
||||
|
||||
let mut config = WebSocketConfig::default();
|
||||
config.extensions.permessage_deflate = Some(DeflateConfig::default());
|
||||
|
||||
let (mut socket, _) = connect_with_config(case_url, Some(config), 3)?;
|
||||
loop {
|
||||
match socket.read()? {
|
||||
msg @ Message::Text(_) | msg @ Message::Binary(_) => {
|
||||
|
||||
@ -4,7 +4,10 @@ use std::{
|
||||
};
|
||||
|
||||
use log::*;
|
||||
use tungstenite::{accept, handshake::HandshakeRole, Error, HandshakeError, Message, Result};
|
||||
use tungstenite::{
|
||||
accept_with_config, extensions::compression::deflate::DeflateConfig, handshake::HandshakeRole,
|
||||
protocol::WebSocketConfig, Error, HandshakeError, Message, Result,
|
||||
};
|
||||
|
||||
fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error {
|
||||
match err {
|
||||
@ -14,7 +17,10 @@ fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error {
|
||||
}
|
||||
|
||||
fn handle_client(stream: TcpStream) -> Result<()> {
|
||||
let mut socket = accept(stream).map_err(must_not_block)?;
|
||||
let mut config = WebSocketConfig::default();
|
||||
config.extensions.permessage_deflate = Some(DeflateConfig::default());
|
||||
|
||||
let mut socket = accept_with_config(stream, Some(config)).map_err(must_not_block)?;
|
||||
info!("Running test");
|
||||
loop {
|
||||
match socket.read()? {
|
||||
|
||||
@ -32,5 +32,5 @@ docker run -d --rm \
|
||||
wstest -m fuzzingserver -s 'autobahn/fuzzingserver.json'
|
||||
|
||||
sleep 3
|
||||
cargo run --release --example autobahn-client
|
||||
cargo run --release --example autobahn-client --features=deflate
|
||||
test_diff
|
||||
|
||||
@ -22,7 +22,7 @@ function test_diff() {
|
||||
fi
|
||||
}
|
||||
|
||||
cargo run --release --example autobahn-server & WSSERVER_PID=$!
|
||||
cargo run --release --example autobahn-server --features=deflate & WSSERVER_PID=$!
|
||||
sleep 3
|
||||
|
||||
docker run --rm \
|
||||
|
||||
19
src/error.rs
19
src/error.rs
@ -2,7 +2,10 @@
|
||||
|
||||
use std::{io, result, str, string};
|
||||
|
||||
use crate::protocol::{frame::coding::Data, Message};
|
||||
use crate::{
|
||||
extensions::{compression::CompressionError, ExtensionsError},
|
||||
protocol::{frame::coding::Data, Message},
|
||||
};
|
||||
#[cfg(feature = "handshake")]
|
||||
use http::{header::HeaderName, Response};
|
||||
use thiserror::Error;
|
||||
@ -194,6 +197,9 @@ pub enum ProtocolError {
|
||||
/// The `Sec-WebSocket-Protocol` header was invalid
|
||||
#[error("SubProtocol error: {0}")]
|
||||
SecWebSocketSubProtocolError(SubProtocolError),
|
||||
/// The `Sec-WebSocket-Extensions` header is invalid.
|
||||
#[error("Invalid \"Sec-WebSocket-Extensions\" header: {0}")]
|
||||
InvalidExtensionsHeader(#[from] ExtensionsError),
|
||||
/// Garbage data encountered after client request.
|
||||
#[error("Junk after client request")]
|
||||
JunkAfterRequest,
|
||||
@ -229,6 +235,9 @@ pub enum ProtocolError {
|
||||
/// Control frames must not be fragmented.
|
||||
#[error("Fragmented control frame")]
|
||||
FragmentedControlFrame,
|
||||
/// Control frames must not be compressed.
|
||||
#[error("Compressed control frame")]
|
||||
CompressedControlFrame,
|
||||
/// Control frames must have a payload of 125 bytes or less.
|
||||
#[error("Control frame too big (payload must be 125 bytes or less)")]
|
||||
ControlFrameTooBig,
|
||||
@ -241,6 +250,9 @@ pub enum ProtocolError {
|
||||
/// Received a continue frame despite there being nothing to continue.
|
||||
#[error("Continue frame but nothing to continue")]
|
||||
UnexpectedContinueFrame,
|
||||
/// Received a compressed continue frame.
|
||||
#[error("Continue frame must not have compress bit set")]
|
||||
CompressedContinueFrame,
|
||||
/// Received data while waiting for more fragments.
|
||||
#[error("While waiting for more fragments received: {0}")]
|
||||
ExpectedFragment(Data),
|
||||
@ -253,6 +265,9 @@ pub enum ProtocolError {
|
||||
/// The payload for the closing frame is invalid.
|
||||
#[error("Invalid close sequence")]
|
||||
InvalidCloseSequence,
|
||||
/// Compression or decompression failure.
|
||||
#[error("Compression/decompression failed: {0}")]
|
||||
CompressionFailure(#[from] CompressionError),
|
||||
}
|
||||
|
||||
/// Indicates the specific type/cause of URL error.
|
||||
@ -331,6 +346,6 @@ mod test {
|
||||
#[test]
|
||||
fn protocol_error_size() {
|
||||
let size = std::mem::size_of::<crate::error::ProtocolError>();
|
||||
assert!(size <= 16, "ProtocolError is large: {size}");
|
||||
assert!(size <= 24, "ProtocolError is large: {size}");
|
||||
}
|
||||
}
|
||||
|
||||
1055
src/extensions/compression/deflate/config.rs
Normal file
1055
src/extensions/compression/deflate/config.rs
Normal file
File diff suppressed because it is too large
Load Diff
713
src/extensions/compression/deflate/mod.rs
Normal file
713
src/extensions/compression/deflate/mod.rs
Normal file
@ -0,0 +1,713 @@
|
||||
//! Implements "permessage-deflate" PMCE defined in [RFC 7692 Section 7]
|
||||
//!
|
||||
//! [RFC 7692 Section 7]: https://tools.ietf.org/html/rfc7692#section-7
|
||||
use bytes::Bytes;
|
||||
use flate2::{Compress, Decompress, FlushCompress, FlushDecompress, Status};
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::{extensions::compression::DecompressionError, protocol::Role};
|
||||
|
||||
mod config;
|
||||
#[cfg_attr(not(feature = "handshake"), allow(unused_imports))]
|
||||
pub(crate) use config::ParameterError as DeflateParameterError;
|
||||
pub use config::{
|
||||
DeflateConfig, NegotiationError as DeflateNegotiationError, PermessageDeflateConfig,
|
||||
PER_MESSAGE_DEFLATE as EXTENSION_NAME,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
/// Manages per message compression using DEFLATE.
|
||||
pub struct DeflateContext {
|
||||
compress: DeflateCompress,
|
||||
decompress: DeflateDecompress,
|
||||
}
|
||||
|
||||
/// Errors from `permessage-deflate` extension.
|
||||
#[derive(Copy, Clone, Debug, Error, PartialEq, Eq)]
|
||||
pub enum DeflateError {
|
||||
/// Compress failed
|
||||
#[error("Failed to compress")]
|
||||
Compress,
|
||||
/// Decompress failed
|
||||
#[error("Failed to decompress")]
|
||||
Decompress,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct DeflateCompress {
|
||||
own_context_takeover: bool,
|
||||
/// The actual compressor to run payloads through.
|
||||
///
|
||||
/// Use the low-level [`Compress`] API instead of the higher-level
|
||||
/// [`flate2::zlib::write::ZlibEncoder`] so we can compress directly into
|
||||
/// the output buffer instead of the intermediate one that that type holds.
|
||||
compressor: Compress,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct DeflateDecompress {
|
||||
/// The actual decompressor to run payloads through.
|
||||
///
|
||||
/// Use the low-level [`Decompress`] API instead of the higher-level
|
||||
/// [`flate2::zlib::write::ZlibDecoder`] so we can decompress directly into
|
||||
/// the output buffer instead of the intermediate one that that type holds.
|
||||
/// This also lets us avoid some decompression errors that the higher-level
|
||||
/// version exhibited with certain highly-compressed payloads.
|
||||
decompressor: Decompress,
|
||||
peer_context_takeover: bool,
|
||||
}
|
||||
|
||||
impl DeflateContext {
|
||||
pub(crate) fn new(role: Role, config: DeflateConfig) -> Self {
|
||||
let DeflateConfig {
|
||||
server_no_context_takeover,
|
||||
client_no_context_takeover,
|
||||
compression,
|
||||
..
|
||||
} = config;
|
||||
|
||||
// Per RFC 7692 Section 7:
|
||||
//
|
||||
// These parameters enable two methods (no_context_takeover and
|
||||
// max_window_bits) of constraining memory usage that may be
|
||||
// applied independently to either direction of WebSocket traffic.
|
||||
// The extension parameters with the "client_" prefix are used by
|
||||
// the client to configure its compressor and by the server to
|
||||
// configure its decompressor. The extension parameters with the
|
||||
// "server_" prefix are used by the server to configure its
|
||||
// compressor and by the client to configure its decompressor. All
|
||||
// four parameters are defined for both a client's extension
|
||||
// negotiation offer and a server's extension negotiation response.
|
||||
//
|
||||
// Here `role` is for our own end of the connection, as opposed to the
|
||||
// peer end.
|
||||
let (own_no_context_takeover, peer_no_context_takeover) = match role {
|
||||
Role::Client => (client_no_context_takeover, server_no_context_takeover),
|
||||
Role::Server => (server_no_context_takeover, client_no_context_takeover),
|
||||
};
|
||||
|
||||
// Both ends of the connection act as both compressor and decompressor.
|
||||
// We compress with the window size for our role and decompress with the
|
||||
// size for the opposite role.
|
||||
let (compressor_window_bits, decompressor_window_bits) = match role {
|
||||
Role::Client => (config.client_max_window_bits(), config.server_max_window_bits()),
|
||||
Role::Server => (config.server_max_window_bits(), config.client_max_window_bits()),
|
||||
};
|
||||
|
||||
DeflateContext {
|
||||
compress: DeflateCompress {
|
||||
own_context_takeover: !own_no_context_takeover,
|
||||
compressor: Compress::new_with_window_bits(
|
||||
compression,
|
||||
false,
|
||||
compressor_window_bits.get(),
|
||||
),
|
||||
},
|
||||
decompress: DeflateDecompress {
|
||||
peer_context_takeover: !peer_no_context_takeover,
|
||||
decompressor: Decompress::new_with_window_bits(
|
||||
false,
|
||||
decompressor_window_bits.get(),
|
||||
),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Compress the payload of an outgoing message.
|
||||
pub(crate) fn compress(&mut self, data: &[u8]) -> Result<Bytes, DeflateError> {
|
||||
self.compress.compress(data).map_err(|e| {
|
||||
log::debug!("compression failed: {e}");
|
||||
DeflateError::Compress
|
||||
})
|
||||
}
|
||||
|
||||
/// Decompress the payload in a received frame.
|
||||
///
|
||||
/// The `is_final` argument should only be set when calling with the contents of the last frame in a message.
|
||||
pub(crate) fn decompress(
|
||||
&mut self,
|
||||
data: &[u8],
|
||||
is_final: bool,
|
||||
size_limit: usize,
|
||||
) -> Result<Bytes, DecompressionError<DeflateError>> {
|
||||
self.decompress.decompress(data, is_final, size_limit).map_err(|e| {
|
||||
e.map(|e: std::io::Error| {
|
||||
log::debug!("decompression failed: {e}");
|
||||
DeflateError::Decompress
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const ELIDED_TRAILER_BLOCK_CONTENTS: &[u8] = &[0x00, 0x00, 0xff, 0xff];
|
||||
|
||||
impl DeflateCompress {
|
||||
/// Compress the contents of an entire message.
|
||||
///
|
||||
/// This is asymmetric with [`DeflateDecompress::decompress`] in that it
|
||||
/// operates on the contents of an entire message, not the comprising frames.
|
||||
fn compress(&mut self, mut data: &[u8]) -> Result<Bytes, std::io::Error> {
|
||||
log::trace!("compressing message payload with {} bytes", data.len());
|
||||
if data.is_empty() {
|
||||
// Fast path for an empty payload: it gets DEFLATE compressed to a
|
||||
// zero-length uncompressed block, which conveniently is
|
||||
// concat([0x00], ELIDED_TRAILER_BLOCK_CONTENTS). Then, per the RFC,
|
||||
// we elide the trailing 4 bytes to get a single 0x00 byte as the
|
||||
// compressed payload.
|
||||
return Ok(Bytes::from_static(&[0x00]));
|
||||
}
|
||||
|
||||
let mut output = Vec::new();
|
||||
|
||||
// The amount of space that should be available in `output` before
|
||||
// attempting to compress data into it.
|
||||
const REQUIRED_OUTPUT_SPACE: usize = 4096;
|
||||
|
||||
// Per RFC 7692 Section 7.2.1:
|
||||
//
|
||||
// An endpoint uses the following algorithm to compress a message.
|
||||
//
|
||||
// 1. Compress all the octets of the payload of the message using
|
||||
// DEFLATE.
|
||||
//
|
||||
|
||||
{
|
||||
let mut total_read = self.compressor.total_in();
|
||||
loop {
|
||||
// Make sure there's space for compress_vec to write to.
|
||||
output.reserve(REQUIRED_OUTPUT_SPACE);
|
||||
|
||||
let r = self.compressor.compress_vec(data, &mut output, FlushCompress::None)?;
|
||||
|
||||
let read_before = std::mem::replace(&mut total_read, self.compressor.total_in());
|
||||
let read = (total_read - read_before) as usize;
|
||||
|
||||
data = &data[read..];
|
||||
log::trace!(
|
||||
"compressed {read} bytes, {} remaining; partial output is {} bytes",
|
||||
data.len(),
|
||||
output.len()
|
||||
);
|
||||
|
||||
match r {
|
||||
Status::Ok => continue,
|
||||
Status::BufError if read == 0 => {
|
||||
// We made no progress, so this BufError means that
|
||||
// we're out of input.
|
||||
break;
|
||||
}
|
||||
Status::BufError => {
|
||||
// We made some progress, so we can continue after
|
||||
// making more output space.
|
||||
continue;
|
||||
}
|
||||
Status::StreamEnd => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log::trace!("flushing compressed data");
|
||||
|
||||
// 2. If the resulting data does not end with an empty DEFLATE
|
||||
// block with no compression (the "BTYPE" bits are set to 00),
|
||||
// append an empty DEFLATE block with no compression to the tail
|
||||
// end.
|
||||
|
||||
// Ideally, at this point, we'd be able to just call compress_vec once
|
||||
// with an empty slice, FlushCompress::Sync, and a vector with more than
|
||||
// enough output space, and then we'd get an empty block and be done.
|
||||
// After all, compress_vec is documented to output "as much output as
|
||||
// possible". Unfortunately, compress_vec does not actually do that for
|
||||
// all backends. See:
|
||||
// - https://github.com/rust-lang/flate2-rs/blob/1.1.2/src/ffi/rust.rs#L169
|
||||
// - https://github.com/Frommi/miniz_oxide/blob/0.8.8/miniz_oxide/src/deflate/stream.rs#L82
|
||||
// - https://github.com/Frommi/miniz_oxide/issues/105
|
||||
//
|
||||
// This causes compress_vec to return Ok as soon as the compressor
|
||||
// writes *any* output when called with an empty slice.
|
||||
//
|
||||
// So, instead, we need to keep calling compress_vec with an empty slice
|
||||
// until we stop making progress.
|
||||
//
|
||||
// Once we have done that properly, we should always have an empty block
|
||||
// at the end of the output, and then we can truncate the output to
|
||||
// remove the empty block, per the RFC.
|
||||
{
|
||||
let mut total_out = self.compressor.total_out();
|
||||
loop {
|
||||
output.reserve(REQUIRED_OUTPUT_SPACE);
|
||||
let output_len_before = output.len();
|
||||
let output_available_before = output.capacity() - output_len_before;
|
||||
|
||||
let _ = self.compressor.compress_vec(&[], &mut output, FlushCompress::Sync)?;
|
||||
log::trace!(
|
||||
"flushed {} bytes into an available {output_available_before} bytes",
|
||||
output.len() - output_len_before,
|
||||
);
|
||||
let out_before = std::mem::replace(&mut total_out, self.compressor.total_out());
|
||||
if total_out == out_before {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Remove 4 octets (that are 0x00 0x00 0xff 0xff) from the tail
|
||||
// end. After this step, the last octet of the compressed data
|
||||
// contains (possibly part of) the DEFLATE header bits with the
|
||||
// "BTYPE" bits set to 00.
|
||||
|
||||
debug_assert!(output.ends_with(ELIDED_TRAILER_BLOCK_CONTENTS), "output is {output:02x?}");
|
||||
output.truncate(output.len() - ELIDED_TRAILER_BLOCK_CONTENTS.len());
|
||||
|
||||
if !self.own_context_takeover {
|
||||
// Reset if the next frame isn't supposed to be starting with the
|
||||
// same compression window.
|
||||
self.compressor.reset();
|
||||
}
|
||||
|
||||
log::trace!("finished compression into {} bytes", output.len());
|
||||
Ok(Bytes::from(output))
|
||||
}
|
||||
}
|
||||
|
||||
impl DeflateDecompress {
|
||||
/// Decompress the contents of a single frame.
|
||||
///
|
||||
/// The `is_final` argument must be `true` if and only if the frame is the
|
||||
/// last one in a message. The `size_limit` argument is the maximum number
|
||||
/// of bytes that can be decompressed. If the input `data` decompresses to
|
||||
/// more than `size_limit` bytes, [`DecompressionError::SizeLimitReached`]
|
||||
/// will be returned.
|
||||
fn decompress(
|
||||
&mut self,
|
||||
data: &[u8],
|
||||
is_final: bool,
|
||||
size_limit: usize,
|
||||
) -> Result<Bytes, DecompressionError<std::io::Error>> {
|
||||
// From RFC 7692 Section 7.2.2:
|
||||
//
|
||||
// An endpoint uses the following algorithm to decompress a message.
|
||||
//
|
||||
// 1. Append 4 octets of 0x00 0x00 0xff 0xff to the tail end of the
|
||||
// payload of the message.
|
||||
//
|
||||
// 2. Decompress the resulting data using DEFLATE.
|
||||
|
||||
let mut output = Vec::new();
|
||||
|
||||
log::trace!(
|
||||
"decompressing {} bytes in {} frame",
|
||||
data.len(),
|
||||
if is_final { "final" } else { "intermediate" }
|
||||
);
|
||||
let mut total_read = self.decompressor.total_in();
|
||||
|
||||
let mut decompress_from = |mut data: &[u8]| {
|
||||
loop {
|
||||
// Make sure there's some space to decompress into,
|
||||
// optimistically assuming a 50% compression ratio of the input.
|
||||
// This might put us slightly beyond the requested size limit
|
||||
// but it also might not all be used.
|
||||
output.reserve(2 * data.len());
|
||||
|
||||
let r =
|
||||
self.decompressor.decompress_vec(data, &mut output, FlushDecompress::None)?;
|
||||
|
||||
if output.len() > size_limit {
|
||||
return Err(DecompressionError::SizeLimitReached);
|
||||
}
|
||||
let read_before = std::mem::replace(&mut total_read, self.decompressor.total_in());
|
||||
|
||||
let read = (total_read - read_before) as usize;
|
||||
|
||||
data = &data[read..];
|
||||
|
||||
match r {
|
||||
Status::Ok => continue,
|
||||
Status::BufError => {
|
||||
// We've either run out of input data or output space.
|
||||
// Since we reserve space ahead of time, this must mean
|
||||
// we're out of input.
|
||||
break;
|
||||
}
|
||||
Status::StreamEnd => {
|
||||
// Finished a block with BFINAL set. This is legal; from
|
||||
// RFC 7692 Section 7.2.3.4:
|
||||
//
|
||||
// On platforms on which the flush method using an
|
||||
// empty DEFLATE block with no compression is not
|
||||
// available, implementors can choose to flush data
|
||||
// using DEFLATE blocks with "BFINAL" set to 1.
|
||||
//
|
||||
// On the decompression end we reset the compressor in
|
||||
// response. This relies on the assumption that the
|
||||
// client produced the block with BFINAL set by
|
||||
// informing their compressor that the stream was
|
||||
// ending, and so any blocks afterwards won't reference
|
||||
// any context from this block or earlier. It's
|
||||
// obviously not a perfect assumption, but it matches
|
||||
// the behavior of other widely-deployed
|
||||
// permessage-deflate implementations.
|
||||
self.decompressor.reset(false);
|
||||
total_read = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
};
|
||||
|
||||
decompress_from(data)?;
|
||||
|
||||
if is_final {
|
||||
// Decompress the final block that is part of the logical input to
|
||||
// DEFLATE but is elided from the message payloads. This implicitly
|
||||
// flushes out any pending bytes that were part of the previous
|
||||
// block and doesn't leave any others since the trailer is explicitly
|
||||
// an empty block.
|
||||
decompress_from(&ELIDED_TRAILER_BLOCK_CONTENTS)?;
|
||||
|
||||
if !self.peer_context_takeover {
|
||||
self.decompressor.reset(false);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Bytes::from(output))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<DeflateContext> for super::PerMessageCompressionContext {
|
||||
fn from(value: DeflateContext) -> Self {
|
||||
Self::Deflate(value)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod test {
|
||||
use rand::{distr::Distribution as _, RngCore, SeedableRng as _};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn interop() {
|
||||
let mut data = vec![0; 2048];
|
||||
rand::rngs::SmallRng::seed_from_u64(1023).fill_bytes(&mut data);
|
||||
|
||||
let configs = [
|
||||
DeflateConfig::default(),
|
||||
DeflateConfig::default().set_no_context_takeover(Role::Client, true),
|
||||
DeflateConfig::default()
|
||||
.set_no_context_takeover(Role::Client, true)
|
||||
.set_max_window_bits(Role::Client, 10)
|
||||
.unwrap(),
|
||||
DeflateConfig::default().set_max_window_bits(Role::Client, 10).unwrap(),
|
||||
];
|
||||
|
||||
let frame_sizes = [16, 64, data.len()];
|
||||
|
||||
for config in configs {
|
||||
for frame_size in frame_sizes {
|
||||
let mut client = DeflateContext::new(Role::Client, config);
|
||||
let mut server = DeflateContext::new(Role::Server, config);
|
||||
|
||||
let mut send_and_receive = |data| {
|
||||
let compressed = client.compress(data).unwrap();
|
||||
|
||||
let mut decompressed = Vec::<u8>::new();
|
||||
|
||||
let mut it = compressed.chunks(frame_size).peekable();
|
||||
while let Some(frame) = it.next() {
|
||||
decompressed.extend_from_slice(
|
||||
&server.decompress(frame, it.peek().is_none(), usize::MAX).unwrap(),
|
||||
);
|
||||
}
|
||||
decompressed
|
||||
};
|
||||
|
||||
let decompressed = send_and_receive(&data);
|
||||
assert_eq!(data, decompressed);
|
||||
|
||||
// Make sure we haven't broken compression or decompression for
|
||||
// the *next* message.
|
||||
let decompressed = send_and_receive(b"second message");
|
||||
assert_eq!(decompressed, b"second message");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn large_message_compression() {
|
||||
let mut data = vec![0; 1 << 19];
|
||||
rand::rngs::SmallRng::seed_from_u64(1023).fill_bytes(&mut data);
|
||||
|
||||
let mut context = DeflateContext::new(Role::Client, DeflateConfig::default());
|
||||
|
||||
let compressed = context.compress(&data).unwrap();
|
||||
|
||||
assert_eq!(&context.decompress(&compressed, true, usize::MAX).unwrap(), &data);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decompression_limits_applied() {
|
||||
let data = vec![0; 1 << 18];
|
||||
|
||||
let mut context = DeflateContext::new(Role::Client, DeflateConfig::default());
|
||||
let compressed = context.compress(&data).unwrap();
|
||||
|
||||
// A buffer of all zeros compresses very well.
|
||||
assert!(compressed.len() < data.len() / 500);
|
||||
|
||||
assert_eq!(
|
||||
context.decompress(&compressed, true, data.len() - 1),
|
||||
Err(DecompressionError::SizeLimitReached)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compressible_payload_prefixes() {
|
||||
let _ = env_logger::try_init();
|
||||
let data: Vec<u8> = rand::distr::Alphanumeric
|
||||
.sample_iter(&mut rand::rngs::SmallRng::from_seed([59; 32]))
|
||||
.take(1 << 16)
|
||||
.collect();
|
||||
|
||||
let prefixes =
|
||||
(5..).map(|i| 1 << i).take_while(|len| *len <= data.len()).map(|len| &data[..len]);
|
||||
|
||||
for prefix in prefixes {
|
||||
let mut context = DeflateContext::new(Role::Client, DeflateConfig::default());
|
||||
println!("compressing {} bytes of compressible data", prefix.len());
|
||||
|
||||
let compressed = context.compress(prefix).unwrap();
|
||||
assert_eq!(context.decompress(&compressed, true, usize::MAX).unwrap(), prefix);
|
||||
}
|
||||
}
|
||||
|
||||
/// Utilities for testing decomrpession of highly-compressed payloads.
|
||||
pub(crate) mod very_compressed {
|
||||
use bytes::Bytes;
|
||||
|
||||
// Compressed payload that decompresses to 50KB of zeroes. This was
|
||||
// specifically chosen so that its compressed form aligns with a byte
|
||||
// boundary, which lets us repeat it an arbitrary number of times to
|
||||
// form the payload of a single message.
|
||||
pub(crate) const FRAME_PAYLOAD: &[u8; 66] = &[
|
||||
0xec, 0xc1, 0x31, 0x01, 0x00, 0x00, 0x00, 0xc2, 0xa0, 0xf5, 0x4f, 0x6d, 0x0b, 0x2f,
|
||||
0xa0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xe0, 0x6f,
|
||||
];
|
||||
pub(crate) const DECOMPRESSED_LEN: usize = 50 * 1024;
|
||||
|
||||
pub(crate) fn make_frames(frame_count: usize) -> impl Iterator<Item = (Bytes, bool)> {
|
||||
std::iter::repeat_n(FRAME_PAYLOAD, frame_count).enumerate().map(move |(i, bytes)| {
|
||||
let is_final = i == frame_count - 1;
|
||||
let bytes = if is_final {
|
||||
bytes.iter().copied().chain(std::iter::once(0x00)).collect()
|
||||
} else {
|
||||
Bytes::from_static(bytes)
|
||||
};
|
||||
(bytes, is_final)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn large_message_decompression() {
|
||||
let _ = env_logger::try_init();
|
||||
|
||||
for frame_count in 1..=10 {
|
||||
let mut context = DeflateContext::new(Role::Client, DeflateConfig::default());
|
||||
|
||||
let decompressed: Bytes = very_compressed::make_frames(frame_count)
|
||||
.enumerate()
|
||||
.flat_map(|(i, (frame, is_final))| {
|
||||
context
|
||||
.decompress
|
||||
.decompress(&frame, is_final, usize::MAX)
|
||||
.unwrap_or_else(|e| panic!("deflating frame {i}/{frame_count} failed: {e}"))
|
||||
})
|
||||
.collect();
|
||||
assert!(decompressed.iter().all(|b| *b == 0));
|
||||
assert_eq!(decompressed.len(), frame_count * very_compressed::DECOMPRESSED_LEN);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decompress_multiple_messages_that_each_set_bfinal() {
|
||||
let _ = env_logger::try_init();
|
||||
|
||||
let mut rng = rand::rngs::SmallRng::from_seed([12; 32]);
|
||||
let uncompressed_payloads = std::iter::repeat_with(|| {
|
||||
let mut data: Vec<u8> = vec![0; 1 << 12];
|
||||
rng.fill_bytes(&mut data);
|
||||
data
|
||||
});
|
||||
|
||||
let mut context = DeflateContext::new(Role::Server, DeflateConfig::default());
|
||||
|
||||
for (i, payload) in uncompressed_payloads.enumerate().take(5) {
|
||||
let mut compressed = context.compress(&payload).unwrap().try_into_mut().unwrap();
|
||||
// The final block in the stream is a 5-byte uncompressed block, but
|
||||
// with the trailing 4 bytes of the body chopped off (per the RFC).
|
||||
// We don't know where in the last *byte* the final block begins
|
||||
// (since DEFLATE is a bit-oriented protocol), so to make sure the
|
||||
// payload ends with a block with BFINAL set we need to append
|
||||
// another block. First we reattach the chopped-off bytes from the
|
||||
// last block. Then we push *another* 5-byte uncompressed block with
|
||||
// BFINAL set. Lastly we chop off the trailing 4 bytes per the spec.
|
||||
compressed.extend_from_slice(ELIDED_TRAILER_BLOCK_CONTENTS);
|
||||
compressed.extend_from_slice(&[0x01, 0x00, 0x00, 0xff, 0xff]);
|
||||
compressed.truncate(compressed.len() - ELIDED_TRAILER_BLOCK_CONTENTS.len());
|
||||
|
||||
println!("decompressing block {i}");
|
||||
let decompressed = context.decompress(&compressed, true, usize::MAX).unwrap();
|
||||
assert_eq!(decompressed.len(), payload.len());
|
||||
assert_eq!(decompressed, payload);
|
||||
}
|
||||
}
|
||||
|
||||
mod rfc_7692_section_7_2_3_examples {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn one_block() {
|
||||
// From RFC 7692 Section 7.2.3.1:
|
||||
//
|
||||
// Suppose that an endpoint sends a text message "Hello". If the
|
||||
// endpoint uses one compressed DEFLATE block (compressed with fixed
|
||||
// Huffman code and the "BFINAL" bit not set) to compress the message,
|
||||
// the endpoint obtains the compressed data to use for the message
|
||||
// payload as follows.
|
||||
//
|
||||
// The endpoint compresses "Hello" into one compressed DEFLATE block and
|
||||
// flushes the resulting data into a byte array using an empty DEFLATE
|
||||
// block with no compression:
|
||||
//
|
||||
// 0xf2 0x48 0xcd 0xc9 0xc9 0x07 0x00 0x00 0x00 0xff 0xff
|
||||
//
|
||||
// By stripping 0x00 0x00 0xff 0xff from the tail end, the endpoint gets
|
||||
// the data to use for the message payload:
|
||||
//
|
||||
const EXPECTED_COMPRESSED_PAYLOAD: &[u8] = &[0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00];
|
||||
|
||||
let mut context = DeflateContext::new(Role::Server, DeflateConfig::default());
|
||||
let compressed = context.compress(b"Hello").unwrap();
|
||||
assert_eq!(&compressed[..], EXPECTED_COMPRESSED_PAYLOAD);
|
||||
//
|
||||
// ...
|
||||
//
|
||||
// Suppose that the endpoint sends the compressed message with
|
||||
// fragmentation. The endpoint splits the compressed data into
|
||||
// fragments and builds frames for each fragment. For example, if the
|
||||
// fragments are 3 and 4 octets,
|
||||
//
|
||||
const FRAGMENTED_FRAMES: &[&[u8]] = &[
|
||||
// the first frame is:
|
||||
&[0x41, 0x03, 0xf2, 0x48, 0xcd],
|
||||
// and the second frame is:
|
||||
&[0x80, 0x04, 0xc9, 0xc9, 0x07, 0x00],
|
||||
];
|
||||
//
|
||||
// Note that the RSV1 bit is set only on the first frame.
|
||||
|
||||
let frame_payloads =
|
||||
FRAGMENTED_FRAMES.iter().map(|frame| &frame[2..]).collect::<Vec<_>>();
|
||||
|
||||
let decompressed = frame_payloads
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(index, payload)| {
|
||||
context.decompress(payload, index == frame_payloads.len() - 1, usize::MAX)
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.unwrap()
|
||||
.concat();
|
||||
|
||||
assert_eq!(decompressed, b"Hello");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sharing_sliding_window() {
|
||||
const ROLE: Role = Role::Client;
|
||||
|
||||
// From RFC 7692 Section 7.2.3.2:
|
||||
//
|
||||
// Suppose that a client has sent a message "Hello" as a compressed
|
||||
// message and will send the same message "Hello" again as a compressed
|
||||
// message.
|
||||
//
|
||||
const FIRST_PAYLOAD: &[u8] = &[0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00];
|
||||
//
|
||||
// The above is the payload of the first message that the client has
|
||||
// sent. If the "agreed parameters" contain the
|
||||
// "client_no_context_takeover" extension parameter, the client
|
||||
// compresses the payload of the next message into the same bytes (if
|
||||
// the client uses the same "BTYPE" value and "BFINAL" value). So, the
|
||||
// payload of the second message will be:
|
||||
//
|
||||
const SECOND_PAYLOAD: &[u8] = &[0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00];
|
||||
|
||||
let mut context = DeflateContext::new(
|
||||
ROLE,
|
||||
DeflateConfig::default().set_no_context_takeover(ROLE, true),
|
||||
);
|
||||
assert_eq!(&context.compress(b"Hello").unwrap()[..], FIRST_PAYLOAD);
|
||||
assert_eq!(&context.compress(b"Hello").unwrap()[..], SECOND_PAYLOAD);
|
||||
|
||||
//
|
||||
// If the "agreed parameters" did not contain the
|
||||
// "client_no_context_takeover" extension parameter, the client can
|
||||
// compress the payload of the next message into fewer bytes by
|
||||
// referencing the history in the LZ77 sliding window. So, the payload
|
||||
// of the second message will be:
|
||||
//
|
||||
const NEW_SECOND_PAYLOAD: &[u8] = &[0xf2, 0x00, 0x11, 0x00, 0x00];
|
||||
|
||||
let mut context = DeflateContext::new(ROLE, DeflateConfig::default());
|
||||
assert_eq!(&context.compress(b"Hello").unwrap()[..], FIRST_PAYLOAD);
|
||||
assert_eq!(&context.compress(b"Hello").unwrap()[..], NEW_SECOND_PAYLOAD);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deflate_block_with_bfinal_set() {
|
||||
// From RFC 7692 Section 7.2.3.4:
|
||||
//
|
||||
// On platforms on which the flush method using an empty DEFLATE
|
||||
// block with no compression is not available, implementors can
|
||||
// choose to flush data using DEFLATE blocks with "BFINAL" set to
|
||||
// 1.
|
||||
|
||||
const PAYLOAD: &[u8] = &[0xf3, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00, 0x00];
|
||||
|
||||
// This is the payload of a message containing "Hello" compressed
|
||||
// using a DEFLATE block with "BFINAL" set to 1. The first 7
|
||||
// octets constitute a DEFLATE block with "BFINAL" set to 1 and
|
||||
// "BTYPE" set to 01 containing "Hello". The last 1 octet (0x00)
|
||||
// contains the header bits with "BFINAL" set to 0 and "BTYPE" set
|
||||
// to 00, and 5 padding bits of 0. This octet is necessary to
|
||||
// allow the payload to be decompressed in the same manner as
|
||||
// messages flushed using DEFLATE blocks with "BFINAL" unset.
|
||||
|
||||
let mut context = DeflateContext::new(Role::Client, DeflateConfig::default());
|
||||
assert_eq!(
|
||||
context.decompress(PAYLOAD, true, usize::MAX),
|
||||
Ok(Bytes::from_static(b"Hello"))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn two_deflate_blocks() {
|
||||
// From RFC 7692 Section 7.2.3.5:
|
||||
//
|
||||
// Two or more DEFLATE blocks may be used in one message.
|
||||
|
||||
const TWO_BLOCKS: &[u8] =
|
||||
&[0xf2, 0x48, 0x05, 0x00, 0x00, 0x00, 0xff, 0xff, 0xca, 0xc9, 0xc9, 0x07, 0x00];
|
||||
|
||||
let mut context = DeflateContext::new(Role::Client, DeflateConfig::new());
|
||||
|
||||
assert_eq!(&context.decompress(TWO_BLOCKS, true, usize::MAX).unwrap()[..], b"Hello");
|
||||
}
|
||||
}
|
||||
}
|
||||
89
src/extensions/compression/mod.rs
Normal file
89
src/extensions/compression/mod.rs
Normal file
@ -0,0 +1,89 @@
|
||||
//! [Per-Message Compression Extensions][rfc7692]
|
||||
//!
|
||||
//! [rfc7692]: https://tools.ietf.org/html/rfc7692
|
||||
|
||||
use bytes::Bytes;
|
||||
use thiserror::Error;
|
||||
|
||||
#[cfg(feature = "deflate")]
|
||||
pub mod deflate;
|
||||
|
||||
/// Active context for performing per-message compression.
|
||||
#[derive(Debug)]
|
||||
#[cfg_attr(not(feature = "deflate"), allow(missing_copy_implementations))] // This is only trivially copyable if compression is disabled.
|
||||
pub enum PerMessageCompressionContext {
|
||||
/// Context for compressing/decompressing with `permessage-deflate`.
|
||||
#[cfg(feature = "deflate")]
|
||||
Deflate(deflate::DeflateContext),
|
||||
}
|
||||
|
||||
/// Error encountered while compressing or decompressing.
|
||||
#[derive(Copy, Clone, Debug, Error, PartialEq, Eq)]
|
||||
pub enum CompressionError {
|
||||
/// Error encountered while deflating or inflating
|
||||
#[error("Deflate error: {0}")]
|
||||
#[cfg(feature = "deflate")]
|
||||
Deflate(deflate::DeflateError),
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[cfg_attr(test, derive(PartialEq))]
|
||||
pub(crate) enum DecompressionError<E = CompressionError> {
|
||||
/// The decompressed frame is larger than the configured limit.
|
||||
#[error("decompressed data is too large")]
|
||||
SizeLimitReached,
|
||||
/// An error was encountered while decompressing.
|
||||
#[error("{0}")]
|
||||
Decompression(E),
|
||||
}
|
||||
|
||||
impl PerMessageCompressionContext {
|
||||
#[inline]
|
||||
pub(crate) fn compressor<'s>(
|
||||
&'s mut self,
|
||||
) -> impl FnMut(&Bytes) -> Result<Bytes, CompressionError> + 's {
|
||||
move |payload: &Bytes| match self {
|
||||
#[cfg(feature = "deflate")]
|
||||
Self::Deflate(deflate_config) => {
|
||||
deflate_config.compress(payload).map_err(CompressionError::Deflate)
|
||||
}
|
||||
#[cfg(not(feature = "deflate"))]
|
||||
_ => {
|
||||
let _ = payload;
|
||||
unreachable!("*PerMessageCompressionContext is uninhabited")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(crate) fn decompressor<'s>(
|
||||
&'s mut self,
|
||||
) -> impl FnMut(&Bytes, bool, usize) -> Result<Bytes, DecompressionError> + 's {
|
||||
move |payload, is_final, size_limit| match self {
|
||||
#[cfg(feature = "deflate")]
|
||||
Self::Deflate(deflate_config) => deflate_config
|
||||
.decompress(payload, is_final, size_limit)
|
||||
.map_err(|e| e.map(CompressionError::Deflate)),
|
||||
#[cfg(not(feature = "deflate"))]
|
||||
_ => {
|
||||
let _ = (payload, is_final, size_limit);
|
||||
unreachable!("*PerMessageCompressionContext is uninhabited")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> DecompressionError<E> {
|
||||
pub(crate) fn map<T>(self, f: impl FnOnce(E) -> T) -> DecompressionError<T> {
|
||||
match self {
|
||||
Self::SizeLimitReached => DecompressionError::SizeLimitReached,
|
||||
Self::Decompression(e) => DecompressionError::Decompression(f(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: Into<std::io::Error>> From<E> for DecompressionError<std::io::Error> {
|
||||
fn from(value: E) -> Self {
|
||||
Self::Decompression(value.into())
|
||||
}
|
||||
}
|
||||
57
src/extensions/headers/mod.rs
Normal file
57
src/extensions/headers/mod.rs
Normal file
@ -0,0 +1,57 @@
|
||||
//! HTTP Request and response header handling.
|
||||
use headers::Error;
|
||||
use http::HeaderValue;
|
||||
|
||||
mod sec_websocket_extensions;
|
||||
#[allow(unused)]
|
||||
pub(crate) use sec_websocket_extensions::{
|
||||
SecWebsocketExtensions, WebsocketExtensionParam, WebsocketProtocolExtension,
|
||||
};
|
||||
|
||||
/// Reads a comma-delimited raw header into a Vec.
|
||||
fn from_comma_delimited<'i, I, T, E>(values: &mut I) -> Result<E, Error>
|
||||
where
|
||||
I: Iterator<Item = &'i HeaderValue>,
|
||||
T: ::std::str::FromStr,
|
||||
E: ::std::iter::FromIterator<T>,
|
||||
{
|
||||
from_delimited(&mut values.flat_map(|header_value| header_value.to_str()), ',')
|
||||
}
|
||||
|
||||
/// Reads a single-character-delimited raw header into a Vec.
|
||||
fn from_delimited<'i, I, T, E>(values: &mut I, delimiter: char) -> Result<E, Error>
|
||||
where
|
||||
I: Iterator<Item = &'i str>,
|
||||
T: ::std::str::FromStr,
|
||||
E: ::std::iter::FromIterator<T>,
|
||||
{
|
||||
values
|
||||
.flat_map(|string| {
|
||||
let mut in_quotes = false;
|
||||
string
|
||||
.split(move |c| {
|
||||
#[allow(clippy::collapsible_else_if)]
|
||||
if in_quotes {
|
||||
if c == '"' {
|
||||
in_quotes = false;
|
||||
}
|
||||
false // dont split
|
||||
} else {
|
||||
if c == delimiter {
|
||||
true // split
|
||||
} else {
|
||||
if c == '"' {
|
||||
in_quotes = true;
|
||||
}
|
||||
false // dont split
|
||||
}
|
||||
}
|
||||
})
|
||||
.filter_map(|x| match x.trim() {
|
||||
"" => None,
|
||||
y => Some(y),
|
||||
})
|
||||
.map(|x| x.parse().map_err(|_| Error::invalid()))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
376
src/extensions/headers/sec_websocket_extensions.rs
Normal file
376
src/extensions/headers/sec_websocket_extensions.rs
Normal file
@ -0,0 +1,376 @@
|
||||
use std::{borrow::Cow, fmt::Debug, iter::FromIterator, str::FromStr};
|
||||
|
||||
use bytes::BytesMut;
|
||||
use http::HeaderValue;
|
||||
|
||||
use super::{from_comma_delimited, from_delimited};
|
||||
|
||||
/// The `Sec-Websocket-Extensions` header.
|
||||
///
|
||||
/// This header is used in the Websocket handshake, sent by the client to the
|
||||
/// server and then from the server to the client. It is a proposed and
|
||||
/// agreed-upon list of websocket protocol extensions to use.
|
||||
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
|
||||
pub struct SecWebsocketExtensions(Vec<WebsocketProtocolExtension>);
|
||||
|
||||
/// An extension listed in a [`SecWebsocketExtensions`] header.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||
pub struct WebsocketProtocolExtension {
|
||||
name: Cow<'static, str>,
|
||||
params: Vec<WebsocketExtensionParam>,
|
||||
}
|
||||
|
||||
/// Named parameter for an extension in a `Sec-Websocket-Extensions` header.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||
pub struct WebsocketExtensionParam {
|
||||
name: Cow<'static, str>,
|
||||
value: Option<String>,
|
||||
}
|
||||
|
||||
impl SecWebsocketExtensions {
|
||||
/// Constructs a new header with the provided extensions.
|
||||
pub fn new(extensions: impl IntoIterator<Item = WebsocketProtocolExtension>) -> Self {
|
||||
Self(extensions.into_iter().collect())
|
||||
}
|
||||
|
||||
/// Returns an iterator over the extensions in this header.
|
||||
pub fn iter(&self) -> <&Self as IntoIterator>::IntoIter {
|
||||
self.into_iter()
|
||||
}
|
||||
|
||||
/// Returns the number of extensions in this header.
|
||||
#[inline]
|
||||
pub fn len(&self) -> usize {
|
||||
self.0.len()
|
||||
}
|
||||
|
||||
/// Returns a [`HeaderValue`] with the encoded contents of this header.
|
||||
pub fn header_value(&self) -> HeaderValue {
|
||||
let extensions = CommaDelimited(self.0.as_slice());
|
||||
let mut buffer = BytesMut::with_capacity(extensions.encoded_len());
|
||||
|
||||
extensions.write_with(&mut |slice| buffer.extend_from_slice(slice));
|
||||
|
||||
HeaderValue::from_maybe_shared(buffer).expect("valid construction")
|
||||
}
|
||||
}
|
||||
|
||||
impl WebsocketProtocolExtension {
|
||||
/// Constructs a new extension directive with the given name and parameters.
|
||||
pub fn new(
|
||||
name: impl Into<Cow<'static, str>>,
|
||||
params: impl IntoIterator<Item = WebsocketExtensionParam>,
|
||||
) -> Self {
|
||||
Self { name: name.into(), params: params.into_iter().collect() }
|
||||
}
|
||||
|
||||
/// The name of this extension directive.
|
||||
pub fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
/// Returns an iterator over the parameters for this extension directive.
|
||||
pub fn params(&self) -> impl Iterator<Item = &WebsocketExtensionParam> {
|
||||
self.params.iter()
|
||||
}
|
||||
}
|
||||
|
||||
impl WebsocketExtensionParam {
|
||||
/// Constructs a new parameter with the given name and optional value.
|
||||
#[inline]
|
||||
pub fn new(name: impl Into<Cow<'static, str>>, value: Option<String>) -> Self {
|
||||
Self { name: name.into(), value }
|
||||
}
|
||||
|
||||
/// The name of the parameter.
|
||||
#[inline]
|
||||
pub fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
/// The parameter value, if there is one.
|
||||
#[inline]
|
||||
pub fn value(&self) -> Option<&str> {
|
||||
self.value.as_deref()
|
||||
}
|
||||
}
|
||||
|
||||
impl headers::Header for SecWebsocketExtensions {
|
||||
fn name() -> &'static ::http::header::HeaderName {
|
||||
&::http::header::SEC_WEBSOCKET_EXTENSIONS
|
||||
}
|
||||
|
||||
fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error>
|
||||
where
|
||||
I: Iterator<Item = &'i HeaderValue>,
|
||||
{
|
||||
from_comma_delimited(values).map(SecWebsocketExtensions)
|
||||
}
|
||||
fn encode<E: Extend<headers::HeaderValue>>(&self, values: &mut E) {
|
||||
values.extend(std::iter::once(self.header_value()))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<WebsocketProtocolExtension> for SecWebsocketExtensions {
|
||||
fn from(value: WebsocketProtocolExtension) -> Self {
|
||||
Self(vec![value])
|
||||
}
|
||||
}
|
||||
|
||||
impl FromIterator<WebsocketProtocolExtension> for SecWebsocketExtensions {
|
||||
fn from_iter<T: IntoIterator<Item = WebsocketProtocolExtension>>(iter: T) -> Self {
|
||||
Self(iter.into_iter().collect())
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoIterator for SecWebsocketExtensions {
|
||||
type Item = WebsocketProtocolExtension;
|
||||
|
||||
type IntoIter = std::vec::IntoIter<Self::Item>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
self.0.into_iter()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> IntoIterator for &'a SecWebsocketExtensions {
|
||||
type Item = &'a WebsocketProtocolExtension;
|
||||
|
||||
type IntoIter = std::slice::Iter<'a, WebsocketProtocolExtension>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
self.0.iter()
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for WebsocketProtocolExtension {
|
||||
type Err = headers::Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
let (name, tail) = s.split_once(';').map(|(n, t)| (n, Some(t))).unwrap_or((s, None));
|
||||
|
||||
let params = from_delimited(&mut tail.into_iter(), ';')?;
|
||||
|
||||
Ok(Self { name: name.trim().to_owned().into(), params })
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for WebsocketProtocolExtension {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let Self { name, params } = self;
|
||||
|
||||
write!(f, "{name}")?;
|
||||
for param in params {
|
||||
write!(f, "; {param}")?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for WebsocketExtensionParam {
|
||||
type Err = ();
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
let (name, value) = s.split_once('=').map(|(n, t)| (n, Some(t))).unwrap_or((s, None));
|
||||
|
||||
let value = value.map(|value| value.trim().to_owned());
|
||||
|
||||
Ok(Self { name: name.trim().to_owned().into(), value })
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for WebsocketExtensionParam {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let Self { name, value } = self;
|
||||
|
||||
write!(f, "{name}")?;
|
||||
if let Some(value) = value {
|
||||
write!(f, "={value}")?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
trait WriteTo {
|
||||
fn encoded_len(&self) -> usize {
|
||||
let mut size = 0;
|
||||
self.write_with(&mut |slice| size += slice.len());
|
||||
size
|
||||
}
|
||||
|
||||
fn write_with(&self, write: &mut (impl FnMut(&[u8]) + ?Sized));
|
||||
}
|
||||
|
||||
impl WriteTo for WebsocketProtocolExtension {
|
||||
fn encoded_len(&self) -> usize {
|
||||
let Self { name, params } = self;
|
||||
|
||||
let params_len: usize = params.iter().map(|p| p.encoded_len() + 2).sum();
|
||||
|
||||
name.len() + params_len
|
||||
}
|
||||
|
||||
fn write_with(&self, write: &mut (impl FnMut(&[u8]) + ?Sized)) {
|
||||
let Self { name, params } = self;
|
||||
write(name.as_bytes());
|
||||
|
||||
for param in params {
|
||||
write(b"; ");
|
||||
param.write_with(write);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl WriteTo for WebsocketExtensionParam {
|
||||
fn write_with(&self, write: &mut (impl FnMut(&[u8]) + ?Sized)) {
|
||||
let Self { name, value } = self;
|
||||
write(name.as_bytes());
|
||||
|
||||
if let Some(value) = value {
|
||||
write(b"=");
|
||||
write(value.as_bytes());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct CommaDelimited<T>(T);
|
||||
|
||||
impl<T> CommaDelimited<T> {
|
||||
const SEPARATOR: &[u8] = b", ";
|
||||
}
|
||||
|
||||
impl<T: WriteTo> WriteTo for CommaDelimited<&[T]> {
|
||||
fn encoded_len(&self) -> usize {
|
||||
let all_encoded_len: usize = self.0.iter().map(T::encoded_len).sum();
|
||||
let all_separators_len = self.0.len().saturating_sub(1) * Self::SEPARATOR.len();
|
||||
all_encoded_len + all_separators_len
|
||||
}
|
||||
|
||||
fn write_with(&self, write: &mut (impl FnMut(&[u8]) + ?Sized)) {
|
||||
let mut is_first = true;
|
||||
for item in self.0 {
|
||||
let was_first = std::mem::replace(&mut is_first, false);
|
||||
if !was_first {
|
||||
write(Self::SEPARATOR);
|
||||
}
|
||||
item.write_with(write);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: WriteTo, const N: usize> WriteTo for CommaDelimited<[T; N]> {
|
||||
fn encoded_len(&self) -> usize {
|
||||
CommaDelimited(self.0.as_slice()).encoded_len()
|
||||
}
|
||||
|
||||
fn write_with(&self, write: &mut (impl FnMut(&[u8]) + ?Sized)) {
|
||||
CommaDelimited(self.0.as_slice()).write_with(write);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use headers::{Header, HeaderMapExt as _};
|
||||
|
||||
use super::*;
|
||||
|
||||
fn test_decode<T: Header>(values: &[&str]) -> Option<T> {
|
||||
let mut map = ::http::HeaderMap::new();
|
||||
for val in values {
|
||||
map.append(T::name(), val.parse().unwrap());
|
||||
}
|
||||
map.typed_get()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn test_encode<T: Header>(header: T) -> ::http::HeaderMap {
|
||||
let mut map = ::http::HeaderMap::new();
|
||||
map.typed_insert(header);
|
||||
map
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_separate_headers() {
|
||||
// From https://tools.ietf.org/html/rfc6455#section-9.1
|
||||
let extensions =
|
||||
test_decode::<SecWebsocketExtensions>(&["foo", "bar; baz=2"]).expect("valid");
|
||||
|
||||
assert_eq!(
|
||||
extensions,
|
||||
SecWebsocketExtensions(vec![
|
||||
WebsocketProtocolExtension { name: "foo".into(), params: vec![] },
|
||||
WebsocketProtocolExtension {
|
||||
name: "bar".into(),
|
||||
params: vec![WebsocketExtensionParam {
|
||||
name: "baz".into(),
|
||||
value: Some("2".to_owned())
|
||||
}],
|
||||
}
|
||||
])
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn round_trip_complex() {
|
||||
let extensions = test_decode::<SecWebsocketExtensions>(&[
|
||||
"deflate-stream",
|
||||
"mux; max-channels=4; flow-control, deflate-stream",
|
||||
"private-extension",
|
||||
])
|
||||
.expect("valid");
|
||||
|
||||
let headers = test_encode(extensions);
|
||||
assert_eq!(
|
||||
headers["sec-websocket-extensions"],
|
||||
"deflate-stream, mux; max-channels=4; flow-control, deflate-stream, private-extension"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_to_exact_encoded_len() {
|
||||
trait WriteToDyn: Debug {
|
||||
fn encoded_len(&self) -> usize;
|
||||
fn write_with(&self, write: &mut dyn FnMut(&[u8]));
|
||||
}
|
||||
|
||||
impl<W: WriteTo + Debug> WriteToDyn for W {
|
||||
fn encoded_len(&self) -> usize {
|
||||
WriteTo::encoded_len(self)
|
||||
}
|
||||
|
||||
fn write_with(&self, write: &mut dyn FnMut(&[u8])) {
|
||||
WriteTo::write_with(self, write);
|
||||
}
|
||||
}
|
||||
|
||||
// This isn't a required property for correctness but if the length
|
||||
// precomputation is wrong we'll over- or under-allocate during
|
||||
// conversion.
|
||||
let cases: &[Box<dyn WriteToDyn>] = &[
|
||||
Box::new(CommaDelimited([
|
||||
WebsocketProtocolExtension::from_str("extension-name").unwrap(),
|
||||
WebsocketProtocolExtension::from_str("with-params; a=5; b=8").unwrap(),
|
||||
])),
|
||||
Box::new(CommaDelimited::<[WebsocketProtocolExtension; 0]>([])),
|
||||
Box::new(CommaDelimited([
|
||||
WebsocketProtocolExtension::from_str("duplicate-name").unwrap(),
|
||||
WebsocketProtocolExtension::from_str("duplicate-name").unwrap(),
|
||||
WebsocketProtocolExtension::from_str("duplicate-name").unwrap(),
|
||||
])),
|
||||
Box::new(WebsocketProtocolExtension::new(
|
||||
"name",
|
||||
["foo=123".parse().unwrap(), "bar".parse().unwrap(), "baz=four".parse().unwrap()],
|
||||
)),
|
||||
];
|
||||
|
||||
for case in cases {
|
||||
let mut value = Vec::new();
|
||||
let expected_len = case.encoded_len();
|
||||
case.write_with(&mut |slice| value.extend_from_slice(slice));
|
||||
|
||||
assert_eq!(value.len(), expected_len, "for {case:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
422
src/extensions/mod.rs
Normal file
422
src/extensions/mod.rs
Normal file
@ -0,0 +1,422 @@
|
||||
//! WebSocket extensions.
|
||||
// Only `permessage-deflate` is supported at the moment.
|
||||
|
||||
use bytes::Bytes;
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::extensions::compression::{
|
||||
CompressionError, DecompressionError, PerMessageCompressionContext,
|
||||
};
|
||||
#[cfg(feature = "handshake")]
|
||||
use crate::extensions::headers::{SecWebsocketExtensions, WebsocketProtocolExtension};
|
||||
use crate::protocol::Role;
|
||||
|
||||
pub mod compression;
|
||||
#[cfg(feature = "headers")]
|
||||
pub(crate) mod headers;
|
||||
|
||||
/// Container for configured extensions for a connection.
|
||||
#[derive(Debug, Default)]
|
||||
#[allow(missing_copy_implementations)]
|
||||
pub struct Extensions {
|
||||
/// The Per-Message Compression extension configured for the connection, if
|
||||
/// any.
|
||||
per_message_compression: Option<PerMessageCompressionContext>,
|
||||
}
|
||||
|
||||
/// Configuration for extensions for a connection.
|
||||
#[derive(Copy, Clone, Default, Debug)]
|
||||
#[non_exhaustive]
|
||||
pub struct ExtensionsConfig {
|
||||
/// Configuration for the `permessage-deflate` PMCE as specified by [RFC 7692].
|
||||
///
|
||||
/// [RFC 7692]: https://tools.ietf.org/html/rfc7692
|
||||
#[cfg(feature = "deflate")]
|
||||
pub permessage_deflate: Option<compression::deflate::DeflateConfig>,
|
||||
}
|
||||
|
||||
/// Error encountered while handling extensions.
|
||||
#[derive(Debug, Error, PartialEq, Eq, Clone)]
|
||||
pub enum ExtensionsError {
|
||||
/// The header included an invalid extension.
|
||||
#[error("Extension header had invalid extension: {0}")]
|
||||
InvalidExtension(Box<str>),
|
||||
/// The negotiation response included an extension more than once.
|
||||
#[error("Extension negotiation response had conflicting extension: {0}")]
|
||||
ExtensionConflict(Box<str>),
|
||||
/// The header included an unparsable extension.
|
||||
#[error("Extension negotiation response had malformed extension: {0}")]
|
||||
MalformedExtension(&'static str),
|
||||
}
|
||||
|
||||
#[cfg(feature = "handshake")]
|
||||
impl ExtensionsConfig {
|
||||
pub(crate) fn generate_offers(&self) -> impl Iterator<Item = WebsocketProtocolExtension> {
|
||||
let Self {
|
||||
#[cfg(feature = "deflate")]
|
||||
permessage_deflate,
|
||||
} = self;
|
||||
|
||||
#[allow(unused_mut, unused_assignments)]
|
||||
let mut permessage_compression_offer = None;
|
||||
|
||||
#[cfg(feature = "deflate")]
|
||||
{
|
||||
permessage_compression_offer = permessage_deflate.as_ref().map(|p| p.as_offer().into());
|
||||
}
|
||||
|
||||
permessage_compression_offer.into_iter()
|
||||
}
|
||||
|
||||
/// Checks that the given extensions are compatible with the given config
|
||||
/// (if any).
|
||||
///
|
||||
/// Receives a [`SecWebsocketExtensions`] header in a handshake response and
|
||||
/// evaluates it against the given local configuration. Returns a new
|
||||
/// `Extensions` that can be used for the connection for the completed
|
||||
/// handshake.
|
||||
pub(crate) fn verify_agreed_on(
|
||||
&self,
|
||||
agreed: SecWebsocketExtensions,
|
||||
) -> Result<Extensions, ExtensionsError> {
|
||||
#[cfg_attr(not(feature = "deflate"), allow(unused_mut))]
|
||||
let mut per_message_compression = None;
|
||||
|
||||
for extension in agreed.iter() {
|
||||
match extension.name() {
|
||||
#[cfg(feature = "deflate")]
|
||||
compression::deflate::EXTENSION_NAME => {
|
||||
use compression::deflate::{
|
||||
DeflateContext, DeflateParameterError, PermessageDeflateConfig,
|
||||
EXTENSION_NAME,
|
||||
};
|
||||
|
||||
// Already had PMCE configured
|
||||
if per_message_compression.is_some() {
|
||||
return Err(ExtensionsError::ExtensionConflict(EXTENSION_NAME.into()));
|
||||
}
|
||||
|
||||
let deflate = self
|
||||
.permessage_deflate
|
||||
.ok_or_else(|| ExtensionsError::InvalidExtension(EXTENSION_NAME.into()))?;
|
||||
|
||||
let extension: PermessageDeflateConfig = PermessageDeflateConfig::parse_params(
|
||||
extension.params(),
|
||||
)
|
||||
.map_err(|_: DeflateParameterError| {
|
||||
ExtensionsError::MalformedExtension(EXTENSION_NAME)
|
||||
})?;
|
||||
|
||||
let deflate_config = deflate.accept_response(extension).map_err(|e| {
|
||||
ExtensionsError::InvalidExtension(format!("{EXTENSION_NAME}: {e}").into())
|
||||
})?;
|
||||
|
||||
per_message_compression =
|
||||
Some(DeflateContext::new(Role::Client, deflate_config).into());
|
||||
}
|
||||
name => return Err(ExtensionsError::InvalidExtension(name.into())),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Extensions { per_message_compression })
|
||||
}
|
||||
|
||||
/// Checks whether the given extension headers are compatible with the given
|
||||
/// config (if any).
|
||||
///
|
||||
/// Recieves a [`SecWebsocketExtensions`] header in a handshake request and
|
||||
/// evaluates it against the given local configuration. Returns a
|
||||
/// `SecWebsocketExtensions` header to be sent in the handshake response to
|
||||
/// the client, and a `Extensions` value to be used for the connection, once
|
||||
/// it is established.
|
||||
pub(crate) fn accept_offers(
|
||||
&self,
|
||||
extensions: &SecWebsocketExtensions,
|
||||
) -> Result<(Extensions, Option<SecWebsocketExtensions>), ExtensionsError> {
|
||||
#[cfg_attr(not(feature = "deflate"), allow(unused_mut))]
|
||||
let mut per_message_compression = None;
|
||||
|
||||
for extension in extensions.iter() {
|
||||
// Only one extension is currently supported. If that changes,
|
||||
// this will need to be updated to apply the extensions in the correct order.
|
||||
match extension.name() {
|
||||
#[cfg(feature = "deflate")]
|
||||
compression::deflate::EXTENSION_NAME => {
|
||||
use compression::deflate::{
|
||||
DeflateContext, PermessageDeflateConfig, EXTENSION_NAME,
|
||||
};
|
||||
|
||||
let deflate = match self.permessage_deflate {
|
||||
Some(deflate) => deflate,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
let extension = match PermessageDeflateConfig::parse_params(extension.params())
|
||||
{
|
||||
Ok(extension) => extension,
|
||||
Err(e) => {
|
||||
// Per RFC 7692 Section 7:
|
||||
//
|
||||
// A server MUST decline an extension negotiation
|
||||
// offer for this extension if any of the following
|
||||
// conditions are met:
|
||||
//
|
||||
// o The negotiation offer contains an extension
|
||||
// parameter not defined for use in an offer.
|
||||
//
|
||||
// Declining instead of rejecting the request
|
||||
// outright allows clients that conform to a
|
||||
// (currently hypothetical) RFC that supersedes RFC
|
||||
// 7692 to fall back to requesting to the behavior
|
||||
// specified in the latter.
|
||||
log::debug!("{EXTENSION_NAME} extension: {e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
// Per RFC 7692 Section 5:
|
||||
//
|
||||
// A client may also offer multiple PMCE choices to the server
|
||||
// by including multiple elements in the
|
||||
// "Sec-WebSocket-Extensions" header, one for each PMCE
|
||||
// offered. This set of elements MAY include multiple PMCEs
|
||||
// with the same extension name to offer the possibility to
|
||||
// use the same algorithm with different configuration
|
||||
// parameters. The order of elements is important as it
|
||||
// specifies the client's preference. An element preceding
|
||||
// another element has higher preference. It is recommended
|
||||
// that a server accepts PMCEs with higher preference if the
|
||||
// server supports them.
|
||||
//
|
||||
// Follow the RFC recommendation by not overwriting a PMCE that
|
||||
// is already configured.
|
||||
if per_message_compression.is_some() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some((config, response)) = deflate.accept_offer(extension) {
|
||||
per_message_compression = Some((
|
||||
DeflateContext::new(Role::Server, config).into(),
|
||||
response.into(),
|
||||
));
|
||||
}
|
||||
}
|
||||
// Ignore any unknown extensions in the offer.
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
let (per_message_compression, response) = match per_message_compression {
|
||||
Some((a, b)) => (Some(a), Some(b)),
|
||||
None => (None, None),
|
||||
};
|
||||
|
||||
Ok((
|
||||
Extensions { per_message_compression },
|
||||
response.map(|response| SecWebsocketExtensions::new(std::iter::once(response))),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl ExtensionsConfig {
|
||||
/// Bypasses negotiation of extension parameters and enables those that have
|
||||
/// been configured.
|
||||
///
|
||||
/// Returns an [`Extensions`] that has all the extensions enabled that this
|
||||
/// [`ExtensionsConfig`] was configured with.
|
||||
pub(crate) fn into_unnegotiated_context(self, role: Role) -> Extensions {
|
||||
// This can only be infallible while only one per-message compression
|
||||
// extension is supported. If more are added there will need to be some
|
||||
// resolution strategy for picking which one takes precedence.
|
||||
let Self {
|
||||
#[cfg(feature = "deflate")]
|
||||
permessage_deflate,
|
||||
} = self;
|
||||
|
||||
#[cfg_attr(feature = "deflate", allow(unused_assignments))]
|
||||
#[cfg_attr(not(feature = "deflate"), allow(unused_mut))]
|
||||
let mut per_message_compression = None;
|
||||
#[cfg(feature = "deflate")]
|
||||
{
|
||||
per_message_compression = permessage_deflate
|
||||
.map(|deflate| compression::deflate::DeflateContext::new(role, deflate).into());
|
||||
}
|
||||
let _ = role;
|
||||
|
||||
Extensions { per_message_compression }
|
||||
}
|
||||
}
|
||||
|
||||
impl Extensions {
|
||||
/// Returns a function that, if present, compresses a message payload.
|
||||
///
|
||||
/// The returned value will only be `Some` if a per-message compression
|
||||
/// extension, as specified by [RFC 7692], was configured for the connection
|
||||
/// state to which this `Extensions` applies.
|
||||
///
|
||||
/// [RFC 7692]: https://tools.ietf.org/html/rfc7692
|
||||
#[inline]
|
||||
pub(crate) fn per_message_compressor<'s>(
|
||||
&'s mut self,
|
||||
) -> Option<impl FnOnce(&Bytes) -> Result<Bytes, CompressionError> + 's> {
|
||||
let Self { per_message_compression } = self;
|
||||
|
||||
per_message_compression.as_mut().map(PerMessageCompressionContext::compressor)
|
||||
}
|
||||
|
||||
/// Returns a function that, if present, decompresses a frame payload.
|
||||
///
|
||||
/// The returned value will only be `Some` if a per-message compression
|
||||
/// extension, as specified by [RFC 7692], was configured for the connection
|
||||
/// state to which this `Extensions` applies. The closure takes as arguments
|
||||
/// the frame payload, in bytes, a boolean indicating whether the frame is
|
||||
/// the final one for a message, and the maximum number of uncompressed
|
||||
/// bytes to produce before returning an error.
|
||||
///
|
||||
/// [RFC 7692]: https://tools.ietf.org/html/rfc7692
|
||||
#[inline]
|
||||
pub(crate) fn per_message_decompressor<'s>(
|
||||
&'s mut self,
|
||||
) -> Option<impl FnMut(&Bytes, bool, usize) -> Result<Bytes, DecompressionError> + 's> {
|
||||
let Self { per_message_compression } = self;
|
||||
per_message_compression.as_mut().map(PerMessageCompressionContext::decompressor)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "handshake")]
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn accept_offers_ignores_unknown_extensions() {
|
||||
let (Extensions { per_message_compression }, response) = ExtensionsConfig::default()
|
||||
.accept_offers(&SecWebsocketExtensions::new([
|
||||
"unknown-1".parse().unwrap(),
|
||||
"other-unknown; a=5; b=3".parse().unwrap(),
|
||||
]))
|
||||
.unwrap();
|
||||
|
||||
assert!(matches!(per_message_compression, None));
|
||||
assert_eq!(response, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn accept_offers_with_deflate_disabled() {
|
||||
let extensions = ExtensionsConfig::default();
|
||||
|
||||
// With or without #[cfg(feature = "deflate")], the extension should be ignored.
|
||||
let (Extensions { per_message_compression }, response) =
|
||||
extensions.accept_offers(&SecWebsocketExtensions::new([])).unwrap();
|
||||
|
||||
assert!(matches!(per_message_compression, None));
|
||||
assert_eq!(response, None);
|
||||
}
|
||||
|
||||
#[cfg(feature = "deflate")]
|
||||
#[test]
|
||||
fn accept_offers_with_deflate_enabled() {
|
||||
let extensions = ExtensionsConfig { permessage_deflate: Some(Default::default()) };
|
||||
|
||||
{
|
||||
// If the client doesn't offer permessage-deflate support, the response
|
||||
// shouldn't include it.
|
||||
let (Extensions { per_message_compression }, response) =
|
||||
extensions.accept_offers(&SecWebsocketExtensions::new([])).unwrap();
|
||||
assert!(matches!(per_message_compression, None));
|
||||
assert_eq!(response, None);
|
||||
}
|
||||
|
||||
{
|
||||
// If the client does offer support, the response should include it.
|
||||
let (Extensions { per_message_compression }, response) = extensions
|
||||
.accept_offers(&SecWebsocketExtensions::new([
|
||||
WebsocketProtocolExtension::new(compression::deflate::EXTENSION_NAME, []),
|
||||
WebsocketProtocolExtension::new("some-other-extension", []),
|
||||
]))
|
||||
.unwrap();
|
||||
|
||||
assert!(matches!(per_message_compression, Some(_)));
|
||||
assert_eq!(
|
||||
response,
|
||||
Some(SecWebsocketExtensions::new([WebsocketProtocolExtension::new(
|
||||
compression::deflate::EXTENSION_NAME,
|
||||
[]
|
||||
)]))
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "deflate")]
|
||||
#[test]
|
||||
fn accept_offers_picks_first_acceptable_offer() {
|
||||
use compression::deflate::*;
|
||||
let extensions = ExtensionsConfig {
|
||||
permessage_deflate: Some(
|
||||
DeflateConfig::new().set_max_window_bits(Role::Client, 11).unwrap(),
|
||||
),
|
||||
};
|
||||
|
||||
let (Extensions { per_message_compression }, response) = extensions
|
||||
.accept_offers(&SecWebsocketExtensions::new([
|
||||
// These two offers are declined because they doesn't indicate
|
||||
// support for client_max_window_bits, which the server is
|
||||
// configured to require.
|
||||
"permessage-deflate".parse().unwrap(),
|
||||
"permessage-deflate; server_max_window_bits=12".parse().unwrap(),
|
||||
// This offer would be acceptable but it has a parameter that the server doesn't recognize.
|
||||
"permessage-deflate; client_max_window_bits=11; parameter-from-the-future=3"
|
||||
.parse()
|
||||
.unwrap(),
|
||||
// This offer is accepted.
|
||||
"permessage-deflate; client_no_context_takeover; client_max_window_bits=11"
|
||||
.parse()
|
||||
.unwrap(),
|
||||
// This offer is ignored since an earlier one was accepted.
|
||||
"permessage-deflate; client_max_window_bits=10".parse().unwrap(),
|
||||
]))
|
||||
.unwrap();
|
||||
|
||||
assert!(matches!(per_message_compression, Some(PerMessageCompressionContext::Deflate(_))));
|
||||
assert_eq!(
|
||||
response,
|
||||
Some(SecWebsocketExtensions::new([DeflateConfig::new()
|
||||
.set_no_context_takeover(Role::Client, true)
|
||||
.set_max_window_bits(Role::Client, 11)
|
||||
.unwrap()
|
||||
.as_offer()
|
||||
.into()]))
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(feature = "deflate")]
|
||||
#[test]
|
||||
fn verify_agreed_on_deflate_then_garbage() {
|
||||
let extensions = ExtensionsConfig { permessage_deflate: Some(Default::default()) };
|
||||
|
||||
let result = extensions.verify_agreed_on(SecWebsocketExtensions::new([
|
||||
WebsocketProtocolExtension::new(compression::deflate::EXTENSION_NAME, []),
|
||||
WebsocketProtocolExtension::new("unrecognized", []),
|
||||
]));
|
||||
|
||||
assert_eq!(result.unwrap_err(), ExtensionsError::InvalidExtension("unrecognized".into()));
|
||||
}
|
||||
|
||||
#[cfg(feature = "deflate")]
|
||||
#[test]
|
||||
fn verify_agreed_on_deflate_multiple_times() {
|
||||
let extensions = ExtensionsConfig { permessage_deflate: Some(Default::default()) };
|
||||
|
||||
let result = extensions.verify_agreed_on(SecWebsocketExtensions::new([
|
||||
WebsocketProtocolExtension::new(compression::deflate::EXTENSION_NAME, []),
|
||||
WebsocketProtocolExtension::new(
|
||||
compression::deflate::EXTENSION_NAME,
|
||||
["client_no_context_takeover".parse().unwrap()],
|
||||
),
|
||||
]));
|
||||
|
||||
assert_eq!(
|
||||
result.unwrap_err(),
|
||||
ExtensionsError::ExtensionConflict(compression::deflate::EXTENSION_NAME.into())
|
||||
);
|
||||
}
|
||||
}
|
||||
@ -5,6 +5,7 @@ use std::{
|
||||
marker::PhantomData,
|
||||
};
|
||||
|
||||
use headers::{Header, HeaderMapExt};
|
||||
use http::{
|
||||
header::HeaderName, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode,
|
||||
};
|
||||
@ -19,6 +20,7 @@ use super::{
|
||||
};
|
||||
use crate::{
|
||||
error::{Error, ProtocolError, Result, SubProtocolError, UrlError},
|
||||
extensions::{headers::SecWebsocketExtensions, Extensions, ExtensionsConfig},
|
||||
handshake::version_as_str,
|
||||
protocol::{Role, WebSocket, WebSocketConfig},
|
||||
};
|
||||
@ -59,7 +61,7 @@ impl<S: Read + Write> ClientHandshake<S> {
|
||||
|
||||
// Convert and verify the `http::Request` and turn it into the request as per RFC.
|
||||
// Also extract the key from it (it must be present in a correct request).
|
||||
let (request, key) = generate_request(request)?;
|
||||
let (request, key) = generate_request(request, config.as_ref().map(|w| &w.extensions))?;
|
||||
|
||||
let machine = HandshakeMachine::start_write(stream, request);
|
||||
|
||||
@ -90,7 +92,10 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
|
||||
ProcessingResult::Continue(HandshakeMachine::start_read(stream))
|
||||
}
|
||||
StageResult::DoneReading { stream, result, tail } => {
|
||||
let result = match self.verify_data.verify_response(result) {
|
||||
let (result, extensions) = match self
|
||||
.verify_data
|
||||
.verify_response(result, self.config.as_ref().map(|c| &c.extensions))
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(Error::Http(mut e)) => {
|
||||
*e.body_mut() = Some(tail);
|
||||
@ -100,8 +105,13 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
|
||||
};
|
||||
|
||||
debug!("Client handshake done.");
|
||||
let websocket =
|
||||
WebSocket::from_partially_read(stream, tail, Role::Client, self.config);
|
||||
let websocket = WebSocket::from_partially_read_with_extensions(
|
||||
stream,
|
||||
tail,
|
||||
Role::Client,
|
||||
self.config,
|
||||
extensions,
|
||||
);
|
||||
ProcessingResult::Done((websocket, result))
|
||||
}
|
||||
})
|
||||
@ -109,7 +119,10 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
|
||||
}
|
||||
|
||||
/// Verifies and generates a client WebSocket request from the original request and extracts a WebSocket key from it.
|
||||
pub fn generate_request(mut request: Request) -> Result<(Vec<u8>, String)> {
|
||||
pub fn generate_request(
|
||||
mut request: Request,
|
||||
extensions: Option<&ExtensionsConfig>,
|
||||
) -> Result<(Vec<u8>, String)> {
|
||||
let mut req = Vec::new();
|
||||
write!(
|
||||
req,
|
||||
@ -161,6 +174,14 @@ pub fn generate_request(mut request: Request) -> Result<(Vec<u8>, String)> {
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
if let Some(header) = extensions
|
||||
.map(ExtensionsConfig::generate_offers)
|
||||
.map(SecWebsocketExtensions::new)
|
||||
.filter(|header| header.len() != 0)
|
||||
{
|
||||
headers.append(SecWebsocketExtensions::name(), header.header_value());
|
||||
}
|
||||
|
||||
// Now we must ensure that the headers that we've written once are not anymore present in the map.
|
||||
// If they do, then the request is invalid (some headers are duplicated there for some reason).
|
||||
let websocket_headers_contains =
|
||||
@ -220,7 +241,11 @@ struct VerifyData {
|
||||
}
|
||||
|
||||
impl VerifyData {
|
||||
pub fn verify_response(&self, response: Response) -> Result<Response> {
|
||||
pub fn verify_response(
|
||||
&self,
|
||||
response: Response,
|
||||
extensions: Option<&ExtensionsConfig>,
|
||||
) -> Result<(Response, Extensions)> {
|
||||
// 1. If the status code received from the server is not 101, the
|
||||
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
|
||||
if response.status() != StatusCode::SWITCHING_PROTOCOLS {
|
||||
@ -265,7 +290,18 @@ impl VerifyData {
|
||||
// that was not present in the client's handshake (the server has
|
||||
// indicated an extension not requested by the client), the client
|
||||
// MUST _Fail the WebSocket Connection_. (RFC 6455)
|
||||
// TODO
|
||||
let extensions_header =
|
||||
headers.typed_try_get::<SecWebsocketExtensions>().map_err(|_| {
|
||||
ProtocolError::InvalidHeader(SecWebsocketExtensions::name().clone().into())
|
||||
})?;
|
||||
|
||||
let extensions = match extensions_header {
|
||||
None => Extensions::default(),
|
||||
Some(agreed) => extensions
|
||||
.ok_or(ProtocolError::InvalidHeader(SecWebsocketExtensions::name().clone().into()))?
|
||||
.verify_agreed_on(agreed)
|
||||
.map_err(ProtocolError::from)?,
|
||||
};
|
||||
|
||||
// 6. If the response includes a |Sec-WebSocket-Protocol| header field
|
||||
// and this header field indicates the use of a subprotocol that was
|
||||
@ -294,7 +330,7 @@ impl VerifyData {
|
||||
}
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
Ok((response, extensions))
|
||||
}
|
||||
}
|
||||
|
||||
@ -374,7 +410,7 @@ mod tests {
|
||||
#[test]
|
||||
fn request_formatting() {
|
||||
let request = "ws://localhost/getCaseCount".into_client_request().unwrap();
|
||||
let (request, key) = generate_request(request).unwrap();
|
||||
let (request, key) = generate_request(request, None).unwrap();
|
||||
let correct = construct_expected("localhost", &key);
|
||||
assert_eq!(&request[..], &correct[..]);
|
||||
}
|
||||
@ -382,7 +418,7 @@ mod tests {
|
||||
#[test]
|
||||
fn request_formatting_with_host() {
|
||||
let request = "wss://localhost:9001/getCaseCount".into_client_request().unwrap();
|
||||
let (request, key) = generate_request(request).unwrap();
|
||||
let (request, key) = generate_request(request, None).unwrap();
|
||||
let correct = construct_expected("localhost:9001", &key);
|
||||
assert_eq!(&request[..], &correct[..]);
|
||||
}
|
||||
@ -390,11 +426,41 @@ mod tests {
|
||||
#[test]
|
||||
fn request_formatting_with_at() {
|
||||
let request = "wss://user:pass@localhost:9001/getCaseCount".into_client_request().unwrap();
|
||||
let (request, key) = generate_request(request).unwrap();
|
||||
let (request, key) = generate_request(request, None).unwrap();
|
||||
let correct = construct_expected("localhost:9001", &key);
|
||||
assert_eq!(&request[..], &correct[..]);
|
||||
}
|
||||
|
||||
#[cfg(feature = "deflate")]
|
||||
#[test]
|
||||
fn request_with_compression() {
|
||||
use crate::extensions::{compression::deflate::DeflateConfig, ExtensionsConfig};
|
||||
|
||||
let request = "ws://localhost/getCaseCount".into_client_request().unwrap();
|
||||
let (request, key) = generate_request(
|
||||
request,
|
||||
Some(&ExtensionsConfig {
|
||||
permessage_deflate: Some(DeflateConfig::default()),
|
||||
..ExtensionsConfig::default()
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
let correct = format!(
|
||||
"\
|
||||
GET /getCaseCount HTTP/1.1\r\n\
|
||||
Host: {host}\r\n\
|
||||
Connection: Upgrade\r\n\
|
||||
Upgrade: websocket\r\n\
|
||||
Sec-WebSocket-Version: 13\r\n\
|
||||
Sec-WebSocket-Key: {key}\r\n\
|
||||
sec-websocket-extensions: permessage-deflate; client_max_window_bits\r\n\
|
||||
\r\n",
|
||||
host = "localhost",
|
||||
key = key
|
||||
);
|
||||
assert_eq!(String::try_from(request).unwrap(), &correct[..]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_parsing() {
|
||||
const DATA: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n";
|
||||
@ -406,6 +472,6 @@ mod tests {
|
||||
#[test]
|
||||
fn invalid_custom_request() {
|
||||
let request = http::Request::builder().method("GET").body(()).unwrap();
|
||||
assert!(generate_request(request).is_err());
|
||||
assert!(generate_request(request, None).is_err());
|
||||
}
|
||||
}
|
||||
|
||||
@ -6,6 +6,7 @@ use std::{
|
||||
result::Result as StdResult,
|
||||
};
|
||||
|
||||
use headers::{Header, HeaderMapExt};
|
||||
use http::{
|
||||
response::Builder, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode,
|
||||
};
|
||||
@ -20,6 +21,7 @@ use super::{
|
||||
};
|
||||
use crate::{
|
||||
error::{Error, ProtocolError, Result},
|
||||
extensions::{headers::SecWebsocketExtensions, Extensions},
|
||||
handshake::version_as_str,
|
||||
protocol::{Role, WebSocket, WebSocketConfig},
|
||||
};
|
||||
@ -203,6 +205,8 @@ pub struct ServerHandshake<S, C> {
|
||||
config: Option<WebSocketConfig>,
|
||||
/// Error code/flag. If set, an error will be returned after sending response to the client.
|
||||
error_response: Option<ErrorResponse>,
|
||||
// Negotiated extension context for server.
|
||||
extensions: Extensions,
|
||||
/// Internal stream type.
|
||||
_marker: PhantomData<S>,
|
||||
}
|
||||
@ -220,6 +224,7 @@ impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
|
||||
callback: Some(callback),
|
||||
config,
|
||||
error_response: None,
|
||||
extensions: Extensions::default(),
|
||||
_marker: PhantomData,
|
||||
},
|
||||
}
|
||||
@ -241,7 +246,30 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
|
||||
return Err(Error::Protocol(ProtocolError::JunkAfterRequest));
|
||||
}
|
||||
|
||||
let response = create_response(&result)?;
|
||||
let mut response = create_response(&result)?;
|
||||
if let Some(extensions) =
|
||||
result.headers().typed_try_get::<SecWebsocketExtensions>().map_err(|_| {
|
||||
ProtocolError::InvalidHeader(SecWebsocketExtensions::name().clone().into())
|
||||
})?
|
||||
{
|
||||
let extensions_config = self
|
||||
.config
|
||||
.ok_or_else(|| {
|
||||
ProtocolError::InvalidHeader(
|
||||
SecWebsocketExtensions::name().clone().into(),
|
||||
)
|
||||
})?
|
||||
.extensions;
|
||||
let (extensions, agreed) = extensions_config
|
||||
.accept_offers(&extensions)
|
||||
.map_err(ProtocolError::from)?;
|
||||
|
||||
if let Some(agreed) = agreed {
|
||||
response.headers_mut().typed_insert(agreed)
|
||||
};
|
||||
self.extensions = extensions;
|
||||
}
|
||||
|
||||
let callback_result = if let Some(callback) = self.callback.take() {
|
||||
callback.on_request(&result, response)
|
||||
} else {
|
||||
@ -284,7 +312,12 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
|
||||
return Err(Error::Http(http::Response::from_parts(parts, body).into()));
|
||||
} else {
|
||||
debug!("Server handshake done.");
|
||||
let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config);
|
||||
let websocket = WebSocket::from_raw_socket_with_extensions(
|
||||
stream,
|
||||
Role::Server,
|
||||
self.config,
|
||||
std::mem::take(&mut self.extensions),
|
||||
);
|
||||
ProcessingResult::Done(websocket)
|
||||
}
|
||||
}
|
||||
|
||||
@ -22,6 +22,7 @@ pub mod buffer;
|
||||
#[cfg(feature = "handshake")]
|
||||
pub mod client;
|
||||
pub mod error;
|
||||
pub mod extensions;
|
||||
#[cfg(feature = "handshake")]
|
||||
pub mod handshake;
|
||||
pub mod protocol;
|
||||
|
||||
@ -300,6 +300,32 @@ impl Frame {
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new compressed data frame.
|
||||
///
|
||||
/// `opcode` is of type `Data` because, per [RFC 7692 Section 6], "PMCEs
|
||||
/// operate only on data messages".
|
||||
///
|
||||
/// [RFC 7692 Section 6]: https://tools.ietf.org/html/rfc7692#section-6
|
||||
#[inline]
|
||||
pub(crate) fn compressed_message(data: Bytes, opcode: Data, is_final: bool) -> Frame {
|
||||
Frame {
|
||||
header: FrameHeader {
|
||||
is_final,
|
||||
opcode: OpCode::Data(opcode),
|
||||
// Per RFC 7692 Section 6:
|
||||
//
|
||||
// This document allocates the RSV1 bit of the WebSocket
|
||||
// header for PMCEs and calls the bit the "Per-Message
|
||||
// Compressed" bit. On a WebSocket connection where a PMCE is
|
||||
// in use, this bit indicates whether a message is compressed
|
||||
// or not.
|
||||
rsv1: true,
|
||||
..FrameHeader::default()
|
||||
},
|
||||
payload: data,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new Pong control frame.
|
||||
#[inline]
|
||||
pub fn pong(data: impl Into<Bytes>) -> Frame {
|
||||
|
||||
@ -301,7 +301,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn read_frames() {
|
||||
env_logger::init();
|
||||
let _ = env_logger::try_init();
|
||||
|
||||
let raw = Cursor::new(vec![
|
||||
0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x82, 0x03, 0x03, 0x02, 0x01,
|
||||
|
||||
@ -83,6 +83,8 @@ use bytes::Bytes;
|
||||
#[derive(Debug)]
|
||||
pub struct IncompleteMessage {
|
||||
collector: IncompleteMessageCollector,
|
||||
#[cfg(feature = "deflate")]
|
||||
compressed: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@ -101,9 +103,32 @@ impl IncompleteMessage {
|
||||
IncompleteMessageCollector::Text(StringCollector::new())
|
||||
}
|
||||
},
|
||||
#[cfg(feature = "deflate")]
|
||||
compressed: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create new instance that will hold compressed data.
|
||||
#[cfg(feature = "deflate")]
|
||||
pub fn new_compressed(message_type: IncompleteMessageType) -> Self {
|
||||
IncompleteMessage {
|
||||
collector: match message_type {
|
||||
IncompleteMessageType::Binary => IncompleteMessageCollector::Binary(Vec::new()),
|
||||
IncompleteMessageType::Text => {
|
||||
IncompleteMessageCollector::Text(StringCollector::new())
|
||||
}
|
||||
},
|
||||
compressed: true,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn compressed(&self) -> bool {
|
||||
#[cfg(feature = "deflate")]
|
||||
return self.compressed;
|
||||
#[cfg(not(feature = "deflate"))]
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Get the current filled size of the buffer.
|
||||
pub fn len(&self) -> usize {
|
||||
match self.collector {
|
||||
|
||||
@ -15,12 +15,14 @@ use self::{
|
||||
};
|
||||
use crate::{
|
||||
error::{CapacityError, Error, ProtocolError, Result},
|
||||
extensions::{compression::DecompressionError, Extensions, ExtensionsConfig},
|
||||
protocol::frame::Utf8Bytes,
|
||||
};
|
||||
use log::*;
|
||||
use std::{
|
||||
io::{self, Read, Write},
|
||||
mem::replace,
|
||||
usize,
|
||||
};
|
||||
|
||||
/// Indicates a Client or Server role of the websocket
|
||||
@ -90,6 +92,11 @@ pub struct WebSocketConfig {
|
||||
/// some popular libraries that are sending unmasked frames, ignoring the RFC.
|
||||
/// By default this option is set to `false`, i.e. according to RFC 6455.
|
||||
pub accept_unmasked_frames: bool,
|
||||
/// Configuration for optional extensions to the base websocket protocol.
|
||||
///
|
||||
/// Some extensions may require optional features to be enabled at build
|
||||
/// time to be supported.
|
||||
pub extensions: ExtensionsConfig,
|
||||
}
|
||||
|
||||
impl Default for WebSocketConfig {
|
||||
@ -101,6 +108,7 @@ impl Default for WebSocketConfig {
|
||||
max_message_size: Some(64 << 20),
|
||||
max_frame_size: Some(16 << 20),
|
||||
accept_unmasked_frames: false,
|
||||
extensions: ExtensionsConfig::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -179,6 +187,18 @@ impl<Stream> WebSocket<Stream> {
|
||||
WebSocket { socket: stream, context: WebSocketContext::new(role, config) }
|
||||
}
|
||||
|
||||
/// Convert a raw socket into a WebSocket without performing a handshake.
|
||||
pub fn from_raw_socket_with_extensions(
|
||||
stream: Stream,
|
||||
role: Role,
|
||||
config: Option<WebSocketConfig>,
|
||||
extensions: Extensions,
|
||||
) -> Self {
|
||||
let mut context = WebSocketContext::new(role, config);
|
||||
context.extensions = extensions;
|
||||
WebSocket { socket: stream, context }
|
||||
}
|
||||
|
||||
/// Convert a raw socket into a WebSocket without performing a handshake.
|
||||
///
|
||||
/// Call this function if you're using Tungstenite as a part of a web framework
|
||||
@ -192,10 +212,22 @@ impl<Stream> WebSocket<Stream> {
|
||||
part: Vec<u8>,
|
||||
role: Role,
|
||||
config: Option<WebSocketConfig>,
|
||||
) -> Self {
|
||||
Self::from_partially_read_with_extensions(stream, part, role, config, Extensions::default())
|
||||
}
|
||||
|
||||
pub(crate) fn from_partially_read_with_extensions(
|
||||
stream: Stream,
|
||||
part: Vec<u8>,
|
||||
role: Role,
|
||||
config: Option<WebSocketConfig>,
|
||||
extensions: Extensions,
|
||||
) -> Self {
|
||||
WebSocket {
|
||||
socket: stream,
|
||||
context: WebSocketContext::from_partially_read(part, role, config),
|
||||
context: WebSocketContext::from_partially_read_with_extensions(
|
||||
part, role, config, extensions,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@ -375,6 +407,8 @@ pub struct WebSocketContext {
|
||||
unflushed_additional: bool,
|
||||
/// The configuration for the websocket session.
|
||||
config: WebSocketConfig,
|
||||
// Container for extensions.
|
||||
extensions: Extensions,
|
||||
}
|
||||
|
||||
impl WebSocketContext {
|
||||
@ -384,7 +418,12 @@ impl WebSocketContext {
|
||||
/// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
|
||||
pub fn new(role: Role, config: Option<WebSocketConfig>) -> Self {
|
||||
let conf = config.unwrap_or_default();
|
||||
Self::_new(role, FrameCodec::new(conf.read_buffer_size), conf)
|
||||
Self::_new(
|
||||
role,
|
||||
FrameCodec::new(conf.read_buffer_size),
|
||||
conf,
|
||||
conf.extensions.into_unnegotiated_context(role),
|
||||
)
|
||||
}
|
||||
|
||||
/// Create a WebSocket context that manages an post-handshake stream.
|
||||
@ -393,10 +432,41 @@ impl WebSocketContext {
|
||||
/// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
|
||||
pub fn from_partially_read(part: Vec<u8>, role: Role, config: Option<WebSocketConfig>) -> Self {
|
||||
let conf = config.unwrap_or_default();
|
||||
Self::_new(role, FrameCodec::from_partially_read(part, conf.read_buffer_size), conf)
|
||||
let extensions = conf.extensions.into_unnegotiated_context(role);
|
||||
Self::_new(
|
||||
role,
|
||||
FrameCodec::from_partially_read(part, conf.read_buffer_size),
|
||||
conf,
|
||||
extensions,
|
||||
)
|
||||
}
|
||||
|
||||
fn _new(role: Role, mut frame: FrameCodec, config: WebSocketConfig) -> Self {
|
||||
/// Create a WebSocket context for a post-handshake stream with the enabled extensions.
|
||||
///
|
||||
/// Where [`WebSocketContext::from_partially_read`] infers the enabled
|
||||
/// extensions from the [`WebSocketConfig`], this allows the caller to
|
||||
/// explicitly sets the extensions in use for the connection.
|
||||
pub(crate) fn from_partially_read_with_extensions(
|
||||
part: Vec<u8>,
|
||||
role: Role,
|
||||
config: Option<WebSocketConfig>,
|
||||
extensions: Extensions,
|
||||
) -> Self {
|
||||
let conf = config.unwrap_or_default();
|
||||
Self::_new(
|
||||
role,
|
||||
FrameCodec::from_partially_read(part, conf.read_buffer_size),
|
||||
conf,
|
||||
extensions,
|
||||
)
|
||||
}
|
||||
|
||||
fn _new(
|
||||
role: Role,
|
||||
mut frame: FrameCodec,
|
||||
config: WebSocketConfig,
|
||||
extensions: Extensions,
|
||||
) -> Self {
|
||||
config.assert_valid();
|
||||
frame.set_max_out_buffer_len(config.max_write_buffer_size);
|
||||
frame.set_out_buffer_write_len(config.write_buffer_size);
|
||||
@ -408,6 +478,7 @@ impl WebSocketContext {
|
||||
additional_send: None,
|
||||
unflushed_additional: false,
|
||||
config,
|
||||
extensions,
|
||||
}
|
||||
}
|
||||
|
||||
@ -500,9 +571,18 @@ impl WebSocketContext {
|
||||
return Err(Error::Protocol(ProtocolError::SendAfterClosing));
|
||||
}
|
||||
|
||||
let mut prepare_data_frame = |data, opdata| -> Result<Frame, ProtocolError> {
|
||||
const IS_FINAL: bool = true;
|
||||
if let Some(compressor) = self.extensions.per_message_compressor() {
|
||||
let compressed = compressor(&data)?;
|
||||
return Ok(Frame::compressed_message(compressed, opdata, IS_FINAL));
|
||||
}
|
||||
Ok(Frame::message(data, OpCode::Data(opdata), IS_FINAL))
|
||||
};
|
||||
|
||||
let frame = match message {
|
||||
Message::Text(data) => Frame::message(data, OpCode::Data(OpData::Text), true),
|
||||
Message::Binary(data) => Frame::message(data, OpCode::Data(OpData::Binary), true),
|
||||
Message::Text(data) => prepare_data_frame(data.into(), OpData::Text)?,
|
||||
Message::Binary(data) => prepare_data_frame(data, OpData::Binary)?,
|
||||
Message::Ping(data) => Frame::ping(data),
|
||||
Message::Pong(data) => {
|
||||
self.set_additional(Frame::pong(data));
|
||||
@ -633,17 +713,53 @@ impl WebSocketContext {
|
||||
if !self.state.can_read() {
|
||||
return Err(Error::Protocol(ProtocolError::ReceivedAfterClosing));
|
||||
}
|
||||
// MUST be 0 unless an extension is negotiated that defines meanings
|
||||
// for non-zero values. If a nonzero value is received and none of
|
||||
// the negotiated extensions defines the meaning of such a nonzero
|
||||
// value, the receiving endpoint MUST _Fail the WebSocket
|
||||
// Connection_.
|
||||
{
|
||||
|
||||
let (is_compressed, decompressor) = {
|
||||
let decompressor = self.extensions.per_message_decompressor();
|
||||
let hdr = frame.header();
|
||||
if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 {
|
||||
// Per RFC 6455, the RSV1, RSV2, and RSV3 bits
|
||||
//
|
||||
// MUST be 0 unless an extension is negotiated that defines
|
||||
// meanings for non-zero values. If a nonzero value is
|
||||
// received and none of the negotiated extensions defines the
|
||||
// meaning of such a nonzero value, the receiving endpoint
|
||||
// MUST _Fail the WebSocket Connection_.
|
||||
//
|
||||
// Per RFC 7692:
|
||||
//
|
||||
// This document allocates the RSV1 bit of the WebSocket
|
||||
// header for PMCEs and calls the bit the "Per-Message
|
||||
// Compressed" bit. On a WebSocket connection where a PMCE is
|
||||
// in use, this bit indicates whether a message is compressed
|
||||
// or not.
|
||||
if (hdr.rsv1 && decompressor.is_none()) || hdr.rsv2 || hdr.rsv3 {
|
||||
return Err(Error::Protocol(ProtocolError::NonZeroReservedBits));
|
||||
}
|
||||
}
|
||||
|
||||
let decompressor_with_size_limit = decompressor.map(|mut f| {
|
||||
let incomplete_len =
|
||||
self.incomplete.as_ref().map(IncompleteMessage::len).unwrap_or(0);
|
||||
let message_max = self.config.max_message_size.unwrap_or(usize::MAX);
|
||||
|
||||
move |bytes, is_final| {
|
||||
let decompress_limit = message_max.saturating_sub(incomplete_len);
|
||||
|
||||
f(bytes, is_final, decompress_limit).map_err(|e| match e {
|
||||
DecompressionError::SizeLimitReached => {
|
||||
Error::Capacity(CapacityError::MessageTooLong {
|
||||
size: incomplete_len.saturating_add(decompress_limit),
|
||||
max_size: message_max,
|
||||
})
|
||||
}
|
||||
DecompressionError::Decompression(e) => {
|
||||
Error::Protocol(ProtocolError::CompressionFailure(e))
|
||||
}
|
||||
})
|
||||
}
|
||||
});
|
||||
|
||||
(hdr.rsv1, decompressor_with_size_limit)
|
||||
};
|
||||
|
||||
if self.role == Role::Client && frame.is_masked() {
|
||||
// A client MUST close a connection if it detects a masked frame. (RFC 6455)
|
||||
@ -652,6 +768,7 @@ impl WebSocketContext {
|
||||
|
||||
match frame.header().opcode {
|
||||
OpCode::Control(ctl) => {
|
||||
drop(decompressor);
|
||||
match ctl {
|
||||
// All control frames MUST have a payload length of 125 bytes or less
|
||||
// and MUST NOT be fragmented. (RFC 6455)
|
||||
@ -661,6 +778,15 @@ impl WebSocketContext {
|
||||
_ if frame.payload().len() > 125 => {
|
||||
Err(Error::Protocol(ProtocolError::ControlFrameTooBig))
|
||||
}
|
||||
// Per RFC 7692:
|
||||
//
|
||||
// An endpoint MUST NOT set the "Per-Message
|
||||
// Compressed" bit of control frames and non-first
|
||||
// fragments of a data message. An endpoint receiving
|
||||
// such a frame MUST _Fail the WebSocket Connection_.
|
||||
_ if is_compressed => {
|
||||
Err(Error::Protocol(ProtocolError::CompressedControlFrame))
|
||||
}
|
||||
OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)),
|
||||
OpCtl::Reserved(i) => {
|
||||
Err(Error::Protocol(ProtocolError::UnknownControlFrameType(i)))
|
||||
@ -679,29 +805,74 @@ impl WebSocketContext {
|
||||
|
||||
OpCode::Data(data) => {
|
||||
let fin = frame.header().is_final;
|
||||
|
||||
match data {
|
||||
OpData::Continue => {
|
||||
let msg = self
|
||||
let incomplete = self
|
||||
.incomplete
|
||||
.as_mut()
|
||||
.ok_or(Error::Protocol(ProtocolError::UnexpectedContinueFrame))?;
|
||||
msg.extend(frame.into_payload(), self.config.max_message_size)?;
|
||||
|
||||
// Per RFC 7692:
|
||||
//
|
||||
// An endpoint MUST NOT set the "Per-Message
|
||||
// Compressed" bit of control frames and non-first
|
||||
// fragments of a data message. An endpoint
|
||||
// receiving such a frame MUST _Fail the WebSocket
|
||||
// Connection_.
|
||||
if is_compressed {
|
||||
return Err(Error::Protocol(ProtocolError::CompressedContinueFrame));
|
||||
}
|
||||
|
||||
let mut payload = frame.into_payload();
|
||||
if incomplete.compressed() {
|
||||
let mut decompressor = decompressor.ok_or_else(|| {
|
||||
// This is a continuation frame that was
|
||||
// received with compression disabled, but
|
||||
// the initial frame of the message was
|
||||
// received with compression *enabled* and
|
||||
// RSV1 set.
|
||||
//
|
||||
// The only way to get here is to manually
|
||||
// disable compression for a stream after
|
||||
// it's been established, which is arguably
|
||||
// operator error. This is incorrect enough
|
||||
// that it's not worth spending a lot of
|
||||
// code on, but it's better to return an
|
||||
// error here than crash.
|
||||
log::debug!("compression was disabled between receiving frames");
|
||||
ProtocolError::CompressedContinueFrame
|
||||
})?;
|
||||
|
||||
payload = decompressor(&payload, fin)?;
|
||||
};
|
||||
|
||||
incomplete.extend(payload, self.config.max_message_size)?;
|
||||
|
||||
if fin {
|
||||
Ok(Some(self.incomplete.take().unwrap().complete()?))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
c if self.incomplete.is_some() => {
|
||||
Err(Error::Protocol(ProtocolError::ExpectedFragment(c)))
|
||||
}
|
||||
OpData::Text if fin => {
|
||||
check_max_size(frame.payload().len(), self.config.max_message_size)?;
|
||||
Ok(Some(Message::Text(frame.into_text()?)))
|
||||
}
|
||||
OpData::Binary if fin => {
|
||||
check_max_size(frame.payload().len(), self.config.max_message_size)?;
|
||||
Ok(Some(Message::Binary(frame.into_payload())))
|
||||
OpData::Text | OpData::Binary if fin => {
|
||||
let payload = frame.into_payload();
|
||||
let payload = match decompressor.filter(|_| is_compressed) {
|
||||
Some(mut frame_decompressor) => frame_decompressor(&payload, fin)?,
|
||||
None => payload,
|
||||
};
|
||||
|
||||
check_max_size(payload.len(), self.config.max_message_size)?;
|
||||
|
||||
match data {
|
||||
OpData::Text => Ok(Some(Message::Text(payload.try_into()?))),
|
||||
OpData::Binary => Ok(Some(Message::Binary(payload))),
|
||||
_ => panic!("Bug: message is not text nor binary"),
|
||||
}
|
||||
}
|
||||
OpData::Text | OpData::Binary => {
|
||||
let message_type = match data {
|
||||
@ -709,8 +880,19 @@ impl WebSocketContext {
|
||||
OpData::Binary => IncompleteMessageType::Binary,
|
||||
_ => panic!("Bug: message is not text nor binary"),
|
||||
};
|
||||
let mut incomplete = IncompleteMessage::new(message_type);
|
||||
incomplete.extend(frame.into_payload(), self.config.max_message_size)?;
|
||||
|
||||
let payload = frame.into_payload();
|
||||
let payload = match decompressor.filter(|_| is_compressed) {
|
||||
Some(mut frame_decompressor) => frame_decompressor(&payload, fin)?,
|
||||
None => payload,
|
||||
};
|
||||
let mut incomplete = match is_compressed {
|
||||
#[cfg(feature = "deflate")]
|
||||
true => IncompleteMessage::new_compressed(message_type),
|
||||
_ => IncompleteMessage::new(message_type),
|
||||
};
|
||||
incomplete.extend(payload, self.config.max_message_size)?;
|
||||
|
||||
self.incomplete = Some(incomplete);
|
||||
Ok(None)
|
||||
}
|
||||
@ -860,6 +1042,7 @@ impl<T> CheckConnectionReset for Result<T> {
|
||||
mod tests {
|
||||
use super::{Message, Role, WebSocket, WebSocketConfig};
|
||||
use crate::error::{CapacityError, Error};
|
||||
use crate::extensions::ExtensionsConfig;
|
||||
|
||||
use std::{io, io::Cursor};
|
||||
|
||||
@ -920,4 +1103,178 @@ mod tests {
|
||||
Err(Error::Capacity(CapacityError::MessageTooLong { size: 3, max_size: 2 }))
|
||||
));
|
||||
}
|
||||
|
||||
#[cfg(feature = "deflate")]
|
||||
#[test]
|
||||
fn per_message_deflate_compression() {
|
||||
// Example frames from RFC 7692 Section 7.2.3.2
|
||||
|
||||
use crate::{extensions::compression, protocol::FrameCodec};
|
||||
|
||||
let mut stream = Cursor::new(Vec::new());
|
||||
|
||||
let config = WebSocketConfig {
|
||||
extensions: ExtensionsConfig {
|
||||
permessage_deflate: Some(compression::deflate::DeflateConfig::default()),
|
||||
},
|
||||
..Default::default()
|
||||
};
|
||||
let mut socket = WebSocket::from_raw_socket(&mut stream, Role::Client, Some(config));
|
||||
|
||||
// The same message sent twice should compress better the second time
|
||||
// because context takeover is enabled.
|
||||
socket.write(Message::Text("Hello".into())).unwrap();
|
||||
socket.write(Message::Text("Hello".into())).unwrap();
|
||||
socket.flush().unwrap();
|
||||
|
||||
let written = stream.into_inner();
|
||||
let mut codec = FrameCodec::new(written.len());
|
||||
|
||||
let mut stream = Cursor::new(written);
|
||||
let first_frame = codec.read_frame(&mut stream, None, true, false).unwrap().unwrap();
|
||||
let second_frame = codec.read_frame(&mut stream, None, true, false).unwrap().unwrap();
|
||||
|
||||
assert_eq!(
|
||||
first_frame.payload(),
|
||||
// First frame payload, from the RFC
|
||||
&[0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00]
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
second_frame.payload(),
|
||||
// Second frame payload, from the RFC
|
||||
&[0xf2, 0x00, 0x11, 0x00, 0x00]
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(feature = "deflate")]
|
||||
#[test]
|
||||
fn per_message_deflate_decompression() {
|
||||
// Example frames from RFC 7692 Section 7.2.3.2
|
||||
|
||||
use crate::extensions::compression::deflate::DeflateConfig;
|
||||
|
||||
let incoming =
|
||||
Cursor::new(&[0x41, 0x03, 0xf2, 0x48, 0xcd, 0x80, 0x04, 0xc9, 0xc9, 0x07, 0x00]);
|
||||
let config = WebSocketConfig {
|
||||
extensions: ExtensionsConfig { permessage_deflate: Some(DeflateConfig::default()) },
|
||||
..Default::default()
|
||||
};
|
||||
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(config));
|
||||
|
||||
assert_eq!(socket.read().unwrap(), Message::Text("Hello".into()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn per_message_compression_not_recognized() {
|
||||
// Without the extension configuration, frames with the RSV1 bit set are rejected.
|
||||
|
||||
let incoming =
|
||||
Cursor::new(&[0x41, 0x03, 0xf2, 0x48, 0xcd, 0x80, 0x04, 0xc9, 0xc9, 0x07, 0x00]);
|
||||
let config =
|
||||
WebSocketConfig { extensions: ExtensionsConfig::default(), ..Default::default() };
|
||||
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(config));
|
||||
|
||||
assert!(matches!(
|
||||
socket.read().unwrap_err(),
|
||||
Error::Protocol(crate::error::ProtocolError::NonZeroReservedBits)
|
||||
));
|
||||
}
|
||||
|
||||
#[cfg(feature = "deflate")]
|
||||
#[test]
|
||||
fn per_message_compression_decompress_respects_message_size_limit() {
|
||||
use crate::extensions::compression::deflate::test::very_compressed;
|
||||
use crate::extensions::compression::deflate::DeflateConfig;
|
||||
use crate::protocol::frame::{
|
||||
coding::{Data, OpCode},
|
||||
FrameHeader,
|
||||
};
|
||||
|
||||
let _ = env_logger::try_init();
|
||||
|
||||
let base_config = WebSocketConfig {
|
||||
extensions: ExtensionsConfig {
|
||||
permessage_deflate: Some(DeflateConfig::default()),
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
fn make_message(frame_count: usize) -> Vec<u8> {
|
||||
let mut is_first = true;
|
||||
let mut output = Vec::new();
|
||||
|
||||
for (frame, is_final) in very_compressed::make_frames(frame_count) {
|
||||
let is_first = std::mem::replace(&mut is_first, false);
|
||||
let header = FrameHeader {
|
||||
opcode: OpCode::Data(if is_first { Data::Binary } else { Data::Continue }),
|
||||
rsv1: is_first,
|
||||
is_final,
|
||||
..Default::default()
|
||||
};
|
||||
header.format(frame.len() as u64, &mut output).unwrap();
|
||||
output.extend_from_slice(&frame);
|
||||
}
|
||||
output
|
||||
}
|
||||
|
||||
// With the default configuration, a short message of these frames is fine.
|
||||
{
|
||||
let input = Cursor::new(make_message(4));
|
||||
let mut socket =
|
||||
WebSocket::from_raw_socket(input, Role::Client, Some(base_config.clone()));
|
||||
|
||||
let message = socket.read().unwrap();
|
||||
assert_eq!(
|
||||
message,
|
||||
Message::Binary(
|
||||
bytes::BytesMut::zeroed(4 * very_compressed::DECOMPRESSED_LEN).into()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// The maximum frame size limits on-the-wire frame size, not
|
||||
// decompressed size, so this still decompresses.
|
||||
{
|
||||
let input = Cursor::new(make_message(2));
|
||||
const MAX_FRAME_SIZE: usize = very_compressed::DECOMPRESSED_LEN - 1;
|
||||
|
||||
let mut socket = WebSocket::from_raw_socket(
|
||||
input,
|
||||
Role::Client,
|
||||
Some(base_config.clone().max_frame_size(Some(MAX_FRAME_SIZE))),
|
||||
);
|
||||
|
||||
let message = socket.read().unwrap();
|
||||
assert_eq!(
|
||||
message,
|
||||
Message::Binary(
|
||||
bytes::BytesMut::zeroed(2 * very_compressed::DECOMPRESSED_LEN).into()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// With a reduced maximum message size, decompressing the whole message
|
||||
// fails.
|
||||
{
|
||||
let input = Cursor::new(make_message(5));
|
||||
const MAX_MESSAGE_SIZE: usize = 3 * very_compressed::DECOMPRESSED_LEN;
|
||||
|
||||
let mut socket = WebSocket::from_raw_socket(
|
||||
input,
|
||||
Role::Client,
|
||||
Some(base_config.clone().max_message_size(Some(MAX_MESSAGE_SIZE))),
|
||||
);
|
||||
|
||||
let error = socket.read().unwrap_err();
|
||||
assert!(matches!(
|
||||
error,
|
||||
Error::Capacity(CapacityError::MessageTooLong {
|
||||
size: _,
|
||||
max_size: MAX_MESSAGE_SIZE
|
||||
})
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user