Perform extension negotiation during handshaking
Also enable construction of WebSocket instances with configured extensions. Based on work by Benjamin Swart <Benjaminswart@email.cz>
This commit is contained in:
parent
8a7ef29ef7
commit
9e8e06816e
@ -72,8 +72,6 @@ Choose the one that is appropriate for your needs.
|
||||
By default **no TLS feature is activated**, so make sure you use one of the TLS features,
|
||||
otherwise you won't be able to communicate with the TLS endpoints.
|
||||
|
||||
There is no support for permessage-deflate at the moment, but the PRs are welcome :wink:
|
||||
|
||||
Testing
|
||||
-------
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
use std::{io, result, str, string};
|
||||
|
||||
use crate::{
|
||||
extensions::compression::CompressionError,
|
||||
extensions::{compression::CompressionError, ExtensionsError},
|
||||
protocol::{frame::coding::Data, Message},
|
||||
};
|
||||
#[cfg(feature = "handshake")]
|
||||
@ -197,6 +197,9 @@ pub enum ProtocolError {
|
||||
/// The `Sec-WebSocket-Protocol` header was invalid
|
||||
#[error("SubProtocol error: {0}")]
|
||||
SecWebSocketSubProtocolError(SubProtocolError),
|
||||
/// The `Sec-WebSocket-Extensions` header is invalid.
|
||||
#[error("Invalid \"Sec-WebSocket-Extensions\" header: {0}")]
|
||||
InvalidExtensionsHeader(#[from] ExtensionsError),
|
||||
/// Garbage data encountered after client request.
|
||||
#[error("Junk after client request")]
|
||||
JunkAfterRequest,
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
//! WebSocket extensions.
|
||||
// Only `permessage-deflate` is supported at the moment.
|
||||
#![allow(dead_code)]
|
||||
|
||||
use bytes::Bytes;
|
||||
use thiserror::Error;
|
||||
|
||||
@ -5,6 +5,7 @@ use std::{
|
||||
marker::PhantomData,
|
||||
};
|
||||
|
||||
use headers::{Header, HeaderMapExt};
|
||||
use http::{
|
||||
header::HeaderName, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode,
|
||||
};
|
||||
@ -19,6 +20,7 @@ use super::{
|
||||
};
|
||||
use crate::{
|
||||
error::{Error, ProtocolError, Result, SubProtocolError, UrlError},
|
||||
extensions::{headers::SecWebsocketExtensions, Extensions, ExtensionsConfig},
|
||||
protocol::{Role, WebSocket, WebSocketConfig},
|
||||
};
|
||||
|
||||
@ -58,7 +60,7 @@ impl<S: Read + Write> ClientHandshake<S> {
|
||||
|
||||
// Convert and verify the `http::Request` and turn it into the request as per RFC.
|
||||
// Also extract the key from it (it must be present in a correct request).
|
||||
let (request, key) = generate_request(request)?;
|
||||
let (request, key) = generate_request(request, config.as_ref().map(|w| &w.extensions))?;
|
||||
|
||||
let machine = HandshakeMachine::start_write(stream, request);
|
||||
|
||||
@ -89,7 +91,10 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
|
||||
ProcessingResult::Continue(HandshakeMachine::start_read(stream))
|
||||
}
|
||||
StageResult::DoneReading { stream, result, tail } => {
|
||||
let result = match self.verify_data.verify_response(result) {
|
||||
let (result, extensions) = match self
|
||||
.verify_data
|
||||
.verify_response(result, self.config.as_ref().map(|c| &c.extensions))
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(Error::Http(mut e)) => {
|
||||
*e.body_mut() = Some(tail);
|
||||
@ -99,8 +104,13 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
|
||||
};
|
||||
|
||||
debug!("Client handshake done.");
|
||||
let websocket =
|
||||
WebSocket::from_partially_read(stream, tail, Role::Client, self.config);
|
||||
let websocket = WebSocket::from_partially_read_with_extensions(
|
||||
stream,
|
||||
tail,
|
||||
Role::Client,
|
||||
self.config,
|
||||
extensions,
|
||||
);
|
||||
ProcessingResult::Done((websocket, result))
|
||||
}
|
||||
})
|
||||
@ -108,7 +118,10 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
|
||||
}
|
||||
|
||||
/// Verifies and generates a client WebSocket request from the original request and extracts a WebSocket key from it.
|
||||
pub fn generate_request(mut request: Request) -> Result<(Vec<u8>, String)> {
|
||||
pub fn generate_request(
|
||||
mut request: Request,
|
||||
extensions: Option<&ExtensionsConfig>,
|
||||
) -> Result<(Vec<u8>, String)> {
|
||||
let mut req = Vec::new();
|
||||
write!(
|
||||
req,
|
||||
@ -160,6 +173,14 @@ pub fn generate_request(mut request: Request) -> Result<(Vec<u8>, String)> {
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
if let Some(header) = extensions
|
||||
.map(ExtensionsConfig::generate_offers)
|
||||
.map(SecWebsocketExtensions::new)
|
||||
.filter(|header| header.len() != 0)
|
||||
{
|
||||
headers.append(SecWebsocketExtensions::name(), header.header_value());
|
||||
}
|
||||
|
||||
// Now we must ensure that the headers that we've written once are not anymore present in the map.
|
||||
// If they do, then the request is invalid (some headers are duplicated there for some reason).
|
||||
let insensitive: Vec<String> =
|
||||
@ -219,7 +240,11 @@ struct VerifyData {
|
||||
}
|
||||
|
||||
impl VerifyData {
|
||||
pub fn verify_response(&self, response: Response) -> Result<Response> {
|
||||
pub fn verify_response(
|
||||
&self,
|
||||
response: Response,
|
||||
extensions: Option<&ExtensionsConfig>,
|
||||
) -> Result<(Response, Extensions)> {
|
||||
// 1. If the status code received from the server is not 101, the
|
||||
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
|
||||
if response.status() != StatusCode::SWITCHING_PROTOCOLS {
|
||||
@ -264,7 +289,18 @@ impl VerifyData {
|
||||
// that was not present in the client's handshake (the server has
|
||||
// indicated an extension not requested by the client), the client
|
||||
// MUST _Fail the WebSocket Connection_. (RFC 6455)
|
||||
// TODO
|
||||
let extensions_header =
|
||||
headers.typed_try_get::<SecWebsocketExtensions>().map_err(|_| {
|
||||
ProtocolError::InvalidHeader(SecWebsocketExtensions::name().clone().into())
|
||||
})?;
|
||||
|
||||
let extensions = match extensions_header {
|
||||
None => Extensions::default(),
|
||||
Some(agreed) => extensions
|
||||
.ok_or(ProtocolError::InvalidHeader(SecWebsocketExtensions::name().clone().into()))?
|
||||
.verify_agreed_on(agreed)
|
||||
.map_err(ProtocolError::from)?,
|
||||
};
|
||||
|
||||
// 6. If the response includes a |Sec-WebSocket-Protocol| header field
|
||||
// and this header field indicates the use of a subprotocol that was
|
||||
@ -293,7 +329,7 @@ impl VerifyData {
|
||||
}
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
Ok((response, extensions))
|
||||
}
|
||||
}
|
||||
|
||||
@ -373,7 +409,7 @@ mod tests {
|
||||
#[test]
|
||||
fn request_formatting() {
|
||||
let request = "ws://localhost/getCaseCount".into_client_request().unwrap();
|
||||
let (request, key) = generate_request(request).unwrap();
|
||||
let (request, key) = generate_request(request, None).unwrap();
|
||||
let correct = construct_expected("localhost", &key);
|
||||
assert_eq!(&request[..], &correct[..]);
|
||||
}
|
||||
@ -381,7 +417,7 @@ mod tests {
|
||||
#[test]
|
||||
fn request_formatting_with_host() {
|
||||
let request = "wss://localhost:9001/getCaseCount".into_client_request().unwrap();
|
||||
let (request, key) = generate_request(request).unwrap();
|
||||
let (request, key) = generate_request(request, None).unwrap();
|
||||
let correct = construct_expected("localhost:9001", &key);
|
||||
assert_eq!(&request[..], &correct[..]);
|
||||
}
|
||||
@ -389,11 +425,41 @@ mod tests {
|
||||
#[test]
|
||||
fn request_formatting_with_at() {
|
||||
let request = "wss://user:pass@localhost:9001/getCaseCount".into_client_request().unwrap();
|
||||
let (request, key) = generate_request(request).unwrap();
|
||||
let (request, key) = generate_request(request, None).unwrap();
|
||||
let correct = construct_expected("localhost:9001", &key);
|
||||
assert_eq!(&request[..], &correct[..]);
|
||||
}
|
||||
|
||||
#[cfg(feature = "deflate")]
|
||||
#[test]
|
||||
fn request_with_compression() {
|
||||
use crate::extensions::{compression::deflate::DeflateConfig, ExtensionsConfig};
|
||||
|
||||
let request = "ws://localhost/getCaseCount".into_client_request().unwrap();
|
||||
let (request, key) = generate_request(
|
||||
request,
|
||||
Some(&ExtensionsConfig {
|
||||
permessage_deflate: Some(DeflateConfig::default()),
|
||||
..ExtensionsConfig::default()
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
let correct = format!(
|
||||
"\
|
||||
GET /getCaseCount HTTP/1.1\r\n\
|
||||
Host: {host}\r\n\
|
||||
Connection: Upgrade\r\n\
|
||||
Upgrade: websocket\r\n\
|
||||
Sec-WebSocket-Version: 13\r\n\
|
||||
Sec-WebSocket-Key: {key}\r\n\
|
||||
sec-websocket-extensions: permessage-deflate; client_max_window_bits\r\n\
|
||||
\r\n",
|
||||
host = "localhost",
|
||||
key = key
|
||||
);
|
||||
assert_eq!(String::try_from(request).unwrap(), &correct[..]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_parsing() {
|
||||
const DATA: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n";
|
||||
@ -405,6 +471,6 @@ mod tests {
|
||||
#[test]
|
||||
fn invalid_custom_request() {
|
||||
let request = http::Request::builder().method("GET").body(()).unwrap();
|
||||
assert!(generate_request(request).is_err());
|
||||
assert!(generate_request(request, None).is_err());
|
||||
}
|
||||
}
|
||||
|
||||
@ -6,6 +6,7 @@ use std::{
|
||||
result::Result as StdResult,
|
||||
};
|
||||
|
||||
use headers::{Header, HeaderMapExt};
|
||||
use http::{
|
||||
response::Builder, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode,
|
||||
};
|
||||
@ -20,6 +21,7 @@ use super::{
|
||||
};
|
||||
use crate::{
|
||||
error::{Error, ProtocolError, Result},
|
||||
extensions::{headers::SecWebsocketExtensions, Extensions},
|
||||
protocol::{Role, WebSocket, WebSocketConfig},
|
||||
};
|
||||
|
||||
@ -202,6 +204,8 @@ pub struct ServerHandshake<S, C> {
|
||||
config: Option<WebSocketConfig>,
|
||||
/// Error code/flag. If set, an error will be returned after sending response to the client.
|
||||
error_response: Option<ErrorResponse>,
|
||||
// Negotiated extension context for server.
|
||||
extensions: Extensions,
|
||||
/// Internal stream type.
|
||||
_marker: PhantomData<S>,
|
||||
}
|
||||
@ -219,6 +223,7 @@ impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
|
||||
callback: Some(callback),
|
||||
config,
|
||||
error_response: None,
|
||||
extensions: Extensions::default(),
|
||||
_marker: PhantomData,
|
||||
},
|
||||
}
|
||||
@ -240,7 +245,30 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
|
||||
return Err(Error::Protocol(ProtocolError::JunkAfterRequest));
|
||||
}
|
||||
|
||||
let response = create_response(&result)?;
|
||||
let mut response = create_response(&result)?;
|
||||
if let Some(extensions) =
|
||||
result.headers().typed_try_get::<SecWebsocketExtensions>().map_err(|_| {
|
||||
ProtocolError::InvalidHeader(SecWebsocketExtensions::name().clone().into())
|
||||
})?
|
||||
{
|
||||
let extensions_config = self
|
||||
.config
|
||||
.ok_or_else(|| {
|
||||
ProtocolError::InvalidHeader(
|
||||
SecWebsocketExtensions::name().clone().into(),
|
||||
)
|
||||
})?
|
||||
.extensions;
|
||||
let (extensions, agreed) = extensions_config
|
||||
.accept_offers(&extensions)
|
||||
.map_err(ProtocolError::from)?;
|
||||
|
||||
if let Some(agreed) = agreed {
|
||||
response.headers_mut().typed_insert(agreed)
|
||||
};
|
||||
self.extensions = extensions;
|
||||
}
|
||||
|
||||
let callback_result = if let Some(callback) = self.callback.take() {
|
||||
callback.on_request(&result, response)
|
||||
} else {
|
||||
@ -283,7 +311,12 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
|
||||
return Err(Error::Http(http::Response::from_parts(parts, body)));
|
||||
} else {
|
||||
debug!("Server handshake done.");
|
||||
let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config);
|
||||
let websocket = WebSocket::from_raw_socket_with_extensions(
|
||||
stream,
|
||||
Role::Server,
|
||||
self.config,
|
||||
std::mem::take(&mut self.extensions),
|
||||
);
|
||||
ProcessingResult::Done(websocket)
|
||||
}
|
||||
}
|
||||
|
||||
@ -186,6 +186,18 @@ impl<Stream> WebSocket<Stream> {
|
||||
WebSocket { socket: stream, context: WebSocketContext::new(role, config) }
|
||||
}
|
||||
|
||||
/// Convert a raw socket into a WebSocket without performing a handshake.
|
||||
pub fn from_raw_socket_with_extensions(
|
||||
stream: Stream,
|
||||
role: Role,
|
||||
config: Option<WebSocketConfig>,
|
||||
extensions: Extensions,
|
||||
) -> Self {
|
||||
let mut context = WebSocketContext::new(role, config);
|
||||
context.extensions = extensions;
|
||||
WebSocket { socket: stream, context }
|
||||
}
|
||||
|
||||
/// Convert a raw socket into a WebSocket without performing a handshake.
|
||||
///
|
||||
/// Call this function if you're using Tungstenite as a part of a web framework
|
||||
@ -199,10 +211,22 @@ impl<Stream> WebSocket<Stream> {
|
||||
part: Vec<u8>,
|
||||
role: Role,
|
||||
config: Option<WebSocketConfig>,
|
||||
) -> Self {
|
||||
Self::from_partially_read_with_extensions(stream, part, role, config, Extensions::default())
|
||||
}
|
||||
|
||||
pub(crate) fn from_partially_read_with_extensions(
|
||||
stream: Stream,
|
||||
part: Vec<u8>,
|
||||
role: Role,
|
||||
config: Option<WebSocketConfig>,
|
||||
extensions: Extensions,
|
||||
) -> Self {
|
||||
WebSocket {
|
||||
socket: stream,
|
||||
context: WebSocketContext::from_partially_read(part, role, config),
|
||||
context: WebSocketContext::from_partially_read_with_extensions(
|
||||
part, role, config, extensions,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@ -411,6 +435,26 @@ impl WebSocketContext {
|
||||
)
|
||||
}
|
||||
|
||||
/// Create a WebSocket context for a post-handshake stream with the enabled extensions.
|
||||
///
|
||||
/// Where [`WebSocketContext::from_partially_read`] infers the enabled
|
||||
/// extensions from the [`WebSocketConfig`], this allows the caller to
|
||||
/// explicitly sets the extensions in use for the connection.
|
||||
pub(crate) fn from_partially_read_with_extensions(
|
||||
part: Vec<u8>,
|
||||
role: Role,
|
||||
config: Option<WebSocketConfig>,
|
||||
extensions: Extensions,
|
||||
) -> Self {
|
||||
let conf = config.unwrap_or_default();
|
||||
Self::_new(
|
||||
role,
|
||||
FrameCodec::from_partially_read(part, conf.read_buffer_size),
|
||||
conf,
|
||||
extensions,
|
||||
)
|
||||
}
|
||||
|
||||
fn _new(
|
||||
role: Role,
|
||||
mut frame: FrameCodec,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user