Perform extension negotiation during handshaking

Also enable construction of WebSocket instances with configured
extensions.

Based on work by Benjamin Swart <Benjaminswart@email.cz>
This commit is contained in:
Alex Bakon 2025-08-22 13:47:38 -04:00
parent 8a7ef29ef7
commit 9e8e06816e
6 changed files with 162 additions and 19 deletions

View File

@ -72,8 +72,6 @@ Choose the one that is appropriate for your needs.
By default **no TLS feature is activated**, so make sure you use one of the TLS features,
otherwise you won't be able to communicate with the TLS endpoints.
There is no support for permessage-deflate at the moment, but the PRs are welcome :wink:
Testing
-------

View File

@ -3,7 +3,7 @@
use std::{io, result, str, string};
use crate::{
extensions::compression::CompressionError,
extensions::{compression::CompressionError, ExtensionsError},
protocol::{frame::coding::Data, Message},
};
#[cfg(feature = "handshake")]
@ -197,6 +197,9 @@ pub enum ProtocolError {
/// The `Sec-WebSocket-Protocol` header was invalid
#[error("SubProtocol error: {0}")]
SecWebSocketSubProtocolError(SubProtocolError),
/// The `Sec-WebSocket-Extensions` header is invalid.
#[error("Invalid \"Sec-WebSocket-Extensions\" header: {0}")]
InvalidExtensionsHeader(#[from] ExtensionsError),
/// Garbage data encountered after client request.
#[error("Junk after client request")]
JunkAfterRequest,

View File

@ -1,6 +1,5 @@
//! WebSocket extensions.
// Only `permessage-deflate` is supported at the moment.
#![allow(dead_code)]
use bytes::Bytes;
use thiserror::Error;

View File

@ -5,6 +5,7 @@ use std::{
marker::PhantomData,
};
use headers::{Header, HeaderMapExt};
use http::{
header::HeaderName, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode,
};
@ -19,6 +20,7 @@ use super::{
};
use crate::{
error::{Error, ProtocolError, Result, SubProtocolError, UrlError},
extensions::{headers::SecWebsocketExtensions, Extensions, ExtensionsConfig},
protocol::{Role, WebSocket, WebSocketConfig},
};
@ -58,7 +60,7 @@ impl<S: Read + Write> ClientHandshake<S> {
// Convert and verify the `http::Request` and turn it into the request as per RFC.
// Also extract the key from it (it must be present in a correct request).
let (request, key) = generate_request(request)?;
let (request, key) = generate_request(request, config.as_ref().map(|w| &w.extensions))?;
let machine = HandshakeMachine::start_write(stream, request);
@ -89,7 +91,10 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
ProcessingResult::Continue(HandshakeMachine::start_read(stream))
}
StageResult::DoneReading { stream, result, tail } => {
let result = match self.verify_data.verify_response(result) {
let (result, extensions) = match self
.verify_data
.verify_response(result, self.config.as_ref().map(|c| &c.extensions))
{
Ok(r) => r,
Err(Error::Http(mut e)) => {
*e.body_mut() = Some(tail);
@ -99,8 +104,13 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
};
debug!("Client handshake done.");
let websocket =
WebSocket::from_partially_read(stream, tail, Role::Client, self.config);
let websocket = WebSocket::from_partially_read_with_extensions(
stream,
tail,
Role::Client,
self.config,
extensions,
);
ProcessingResult::Done((websocket, result))
}
})
@ -108,7 +118,10 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
}
/// Verifies and generates a client WebSocket request from the original request and extracts a WebSocket key from it.
pub fn generate_request(mut request: Request) -> Result<(Vec<u8>, String)> {
pub fn generate_request(
mut request: Request,
extensions: Option<&ExtensionsConfig>,
) -> Result<(Vec<u8>, String)> {
let mut req = Vec::new();
write!(
req,
@ -160,6 +173,14 @@ pub fn generate_request(mut request: Request) -> Result<(Vec<u8>, String)> {
.unwrap();
}
if let Some(header) = extensions
.map(ExtensionsConfig::generate_offers)
.map(SecWebsocketExtensions::new)
.filter(|header| header.len() != 0)
{
headers.append(SecWebsocketExtensions::name(), header.header_value());
}
// Now we must ensure that the headers that we've written once are not anymore present in the map.
// If they do, then the request is invalid (some headers are duplicated there for some reason).
let insensitive: Vec<String> =
@ -219,7 +240,11 @@ struct VerifyData {
}
impl VerifyData {
pub fn verify_response(&self, response: Response) -> Result<Response> {
pub fn verify_response(
&self,
response: Response,
extensions: Option<&ExtensionsConfig>,
) -> Result<(Response, Extensions)> {
// 1. If the status code received from the server is not 101, the
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
if response.status() != StatusCode::SWITCHING_PROTOCOLS {
@ -264,7 +289,18 @@ impl VerifyData {
// that was not present in the client's handshake (the server has
// indicated an extension not requested by the client), the client
// MUST _Fail the WebSocket Connection_. (RFC 6455)
// TODO
let extensions_header =
headers.typed_try_get::<SecWebsocketExtensions>().map_err(|_| {
ProtocolError::InvalidHeader(SecWebsocketExtensions::name().clone().into())
})?;
let extensions = match extensions_header {
None => Extensions::default(),
Some(agreed) => extensions
.ok_or(ProtocolError::InvalidHeader(SecWebsocketExtensions::name().clone().into()))?
.verify_agreed_on(agreed)
.map_err(ProtocolError::from)?,
};
// 6. If the response includes a |Sec-WebSocket-Protocol| header field
// and this header field indicates the use of a subprotocol that was
@ -293,7 +329,7 @@ impl VerifyData {
}
}
Ok(response)
Ok((response, extensions))
}
}
@ -373,7 +409,7 @@ mod tests {
#[test]
fn request_formatting() {
let request = "ws://localhost/getCaseCount".into_client_request().unwrap();
let (request, key) = generate_request(request).unwrap();
let (request, key) = generate_request(request, None).unwrap();
let correct = construct_expected("localhost", &key);
assert_eq!(&request[..], &correct[..]);
}
@ -381,7 +417,7 @@ mod tests {
#[test]
fn request_formatting_with_host() {
let request = "wss://localhost:9001/getCaseCount".into_client_request().unwrap();
let (request, key) = generate_request(request).unwrap();
let (request, key) = generate_request(request, None).unwrap();
let correct = construct_expected("localhost:9001", &key);
assert_eq!(&request[..], &correct[..]);
}
@ -389,11 +425,41 @@ mod tests {
#[test]
fn request_formatting_with_at() {
let request = "wss://user:pass@localhost:9001/getCaseCount".into_client_request().unwrap();
let (request, key) = generate_request(request).unwrap();
let (request, key) = generate_request(request, None).unwrap();
let correct = construct_expected("localhost:9001", &key);
assert_eq!(&request[..], &correct[..]);
}
#[cfg(feature = "deflate")]
#[test]
fn request_with_compression() {
use crate::extensions::{compression::deflate::DeflateConfig, ExtensionsConfig};
let request = "ws://localhost/getCaseCount".into_client_request().unwrap();
let (request, key) = generate_request(
request,
Some(&ExtensionsConfig {
permessage_deflate: Some(DeflateConfig::default()),
..ExtensionsConfig::default()
}),
)
.unwrap();
let correct = format!(
"\
GET /getCaseCount HTTP/1.1\r\n\
Host: {host}\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n\
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: {key}\r\n\
sec-websocket-extensions: permessage-deflate; client_max_window_bits\r\n\
\r\n",
host = "localhost",
key = key
);
assert_eq!(String::try_from(request).unwrap(), &correct[..]);
}
#[test]
fn response_parsing() {
const DATA: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n";
@ -405,6 +471,6 @@ mod tests {
#[test]
fn invalid_custom_request() {
let request = http::Request::builder().method("GET").body(()).unwrap();
assert!(generate_request(request).is_err());
assert!(generate_request(request, None).is_err());
}
}

View File

@ -6,6 +6,7 @@ use std::{
result::Result as StdResult,
};
use headers::{Header, HeaderMapExt};
use http::{
response::Builder, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode,
};
@ -20,6 +21,7 @@ use super::{
};
use crate::{
error::{Error, ProtocolError, Result},
extensions::{headers::SecWebsocketExtensions, Extensions},
protocol::{Role, WebSocket, WebSocketConfig},
};
@ -202,6 +204,8 @@ pub struct ServerHandshake<S, C> {
config: Option<WebSocketConfig>,
/// Error code/flag. If set, an error will be returned after sending response to the client.
error_response: Option<ErrorResponse>,
// Negotiated extension context for server.
extensions: Extensions,
/// Internal stream type.
_marker: PhantomData<S>,
}
@ -219,6 +223,7 @@ impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
callback: Some(callback),
config,
error_response: None,
extensions: Extensions::default(),
_marker: PhantomData,
},
}
@ -240,7 +245,30 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
return Err(Error::Protocol(ProtocolError::JunkAfterRequest));
}
let response = create_response(&result)?;
let mut response = create_response(&result)?;
if let Some(extensions) =
result.headers().typed_try_get::<SecWebsocketExtensions>().map_err(|_| {
ProtocolError::InvalidHeader(SecWebsocketExtensions::name().clone().into())
})?
{
let extensions_config = self
.config
.ok_or_else(|| {
ProtocolError::InvalidHeader(
SecWebsocketExtensions::name().clone().into(),
)
})?
.extensions;
let (extensions, agreed) = extensions_config
.accept_offers(&extensions)
.map_err(ProtocolError::from)?;
if let Some(agreed) = agreed {
response.headers_mut().typed_insert(agreed)
};
self.extensions = extensions;
}
let callback_result = if let Some(callback) = self.callback.take() {
callback.on_request(&result, response)
} else {
@ -283,7 +311,12 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
return Err(Error::Http(http::Response::from_parts(parts, body)));
} else {
debug!("Server handshake done.");
let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config);
let websocket = WebSocket::from_raw_socket_with_extensions(
stream,
Role::Server,
self.config,
std::mem::take(&mut self.extensions),
);
ProcessingResult::Done(websocket)
}
}

View File

@ -186,6 +186,18 @@ impl<Stream> WebSocket<Stream> {
WebSocket { socket: stream, context: WebSocketContext::new(role, config) }
}
/// Convert a raw socket into a WebSocket without performing a handshake.
pub fn from_raw_socket_with_extensions(
stream: Stream,
role: Role,
config: Option<WebSocketConfig>,
extensions: Extensions,
) -> Self {
let mut context = WebSocketContext::new(role, config);
context.extensions = extensions;
WebSocket { socket: stream, context }
}
/// Convert a raw socket into a WebSocket without performing a handshake.
///
/// Call this function if you're using Tungstenite as a part of a web framework
@ -199,10 +211,22 @@ impl<Stream> WebSocket<Stream> {
part: Vec<u8>,
role: Role,
config: Option<WebSocketConfig>,
) -> Self {
Self::from_partially_read_with_extensions(stream, part, role, config, Extensions::default())
}
pub(crate) fn from_partially_read_with_extensions(
stream: Stream,
part: Vec<u8>,
role: Role,
config: Option<WebSocketConfig>,
extensions: Extensions,
) -> Self {
WebSocket {
socket: stream,
context: WebSocketContext::from_partially_read(part, role, config),
context: WebSocketContext::from_partially_read_with_extensions(
part, role, config, extensions,
),
}
}
@ -411,6 +435,26 @@ impl WebSocketContext {
)
}
/// Create a WebSocket context for a post-handshake stream with the enabled extensions.
///
/// Where [`WebSocketContext::from_partially_read`] infers the enabled
/// extensions from the [`WebSocketConfig`], this allows the caller to
/// explicitly sets the extensions in use for the connection.
pub(crate) fn from_partially_read_with_extensions(
part: Vec<u8>,
role: Role,
config: Option<WebSocketConfig>,
extensions: Extensions,
) -> Self {
let conf = config.unwrap_or_default();
Self::_new(
role,
FrameCodec::from_partially_read(part, conf.read_buffer_size),
conf,
extensions,
)
}
fn _new(
role: Role,
mut frame: FrameCodec,