Skip to content

Commit

Permalink
Add Options for controlling max frame size of incoming messages
Browse files Browse the repository at this point in the history
  • Loading branch information
emilk committed Feb 26, 2024
1 parent b503681 commit 8a94828
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 40 deletions.
41 changes: 33 additions & 8 deletions ewebsock/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
//!
//! Usage:
//! ``` no_run
//! let (mut sender, receiver) = ewebsock::connect("ws://example.com").unwrap();
//! let options = ewebsock::Options::default();
//! let (mut sender, receiver) = ewebsock::connect("ws://example.com", options).unwrap();
//! sender.send(ewebsock::WsMessage::Text("Hello!".into()));
//! while let Some(event) = receiver.try_recv() {
//! println!("Received {:?}", event);
Expand Down Expand Up @@ -31,6 +32,9 @@ mod native_tungstenite_tokio;
#[cfg(feature = "tokio")]
pub use native_tungstenite_tokio::*;

#[cfg(not(target_arch = "wasm32"))]
mod tungstenite_common;

#[cfg(target_arch = "wasm32")]
mod web;

Expand Down Expand Up @@ -117,6 +121,26 @@ pub type Result<T> = std::result::Result<T, Error>;

pub(crate) type EventHandler = Box<dyn Send + Fn(WsEvent) -> std::ops::ControlFlow<()>>;

/// Options for a connection.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Options {
/// The maximum size of a single incoming message frame, in bytes.
///
/// The primary reason for setting this to something other than [`usize::MAX`] is
/// to prevent a malicious server from eating up all your RAM.
///
/// Ignored on Web.
pub max_incoming_frame_size: usize,
}

impl Default for Options {
fn default() -> Self {
Self {
max_incoming_frame_size: 64 * 1024 * 1024,
}
}
}

/// Connect to the given URL, and return a sender and receiver.
///
/// This is a wrapper around [`ws_connect`].
Expand All @@ -127,9 +151,9 @@ pub(crate) type EventHandler = Box<dyn Send + Fn(WsEvent) -> std::ops::ControlFl
///
/// See also the [`connect_with_wakeup`] function,
/// and the more advanced [`ws_connect`].
pub fn connect(url: impl Into<String>) -> Result<(WsSender, WsReceiver)> {
pub fn connect(url: impl Into<String>, options: Options) -> Result<(WsSender, WsReceiver)> {
let (ws_receiver, on_event) = WsReceiver::new();
let ws_sender = ws_connect(url.into(), on_event)?;
let ws_sender = ws_connect(url.into(), options, on_event)?;
Ok((ws_sender, ws_receiver))
}

Expand All @@ -146,10 +170,11 @@ pub fn connect(url: impl Into<String>) -> Result<(WsSender, WsReceiver)> {
/// Note that you have to wait for [`WsEvent::Opened`] before sending messages.
pub fn connect_with_wakeup(
url: impl Into<String>,
options: Options,
wake_up: impl Fn() + Send + Sync + 'static,
) -> Result<(WsSender, WsReceiver)> {
let (receiver, on_event) = WsReceiver::new_with_callback(wake_up);
let sender = ws_connect(url.into(), on_event)?;
let sender = ws_connect(url.into(), options, on_event)?;
Ok((sender, receiver))
}

Expand All @@ -160,8 +185,8 @@ pub fn connect_with_wakeup(
/// # Errors
/// * On native: failure to spawn a thread.
/// * On web: failure to use `WebSocket` API.
pub fn ws_connect(url: String, on_event: EventHandler) -> Result<WsSender> {
ws_connect_impl(url, on_event)
pub fn ws_connect(url: String, options: Options, on_event: EventHandler) -> Result<WsSender> {
ws_connect_impl(url, options, on_event)
}

/// Connect and call the given event handler on each received event.
Expand All @@ -174,6 +199,6 @@ pub fn ws_connect(url: String, on_event: EventHandler) -> Result<WsSender> {
/// # Errors
/// * On native: failure to spawn receiver thread.
/// * On web: failure to use `WebSocket` API.
pub fn ws_receive(url: String, on_event: EventHandler) -> Result<()> {
ws_receive_impl(url, on_event)
pub fn ws_receive(url: String, options: Options, on_event: EventHandler) -> Result<()> {
ws_receive_impl(url, options, on_event)
}
50 changes: 31 additions & 19 deletions ewebsock/src/native_tungstenite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use std::sync::mpsc::{Receiver, TryRecvError};

use crate::{EventHandler, Result, WsEvent, WsMessage};
use crate::{EventHandler, Options, Result, WsEvent, WsMessage};

/// This is how you send [`WsMessage`]s to the server.
///
Expand Down Expand Up @@ -47,11 +47,11 @@ impl WsSender {
}
}

pub(crate) fn ws_receive_impl(url: String, on_event: EventHandler) -> Result<()> {
pub(crate) fn ws_receive_impl(url: String, options: Options, on_event: EventHandler) -> Result<()> {
std::thread::Builder::new()
.name("ewebsock".to_owned())
.spawn(move || {
if let Err(err) = ws_receiver_blocking(&url, &on_event) {
if let Err(err) = ws_receiver_blocking(&url, options, &on_event) {
on_event(WsEvent::Error(err));
} else {
log::debug!("WebSocket connection closed.");
Expand All @@ -64,17 +64,21 @@ pub(crate) fn ws_receive_impl(url: String, on_event: EventHandler) -> Result<()>

/// Connect and call the given event handler on each received event.
///
/// Blocking version of [`ws_receive`], only avilable on native.
/// Blocking version of [`ws_receive`], only available on native.
///
/// # Errors
/// All errors are returned to the caller, and NOT reported via `on_event`.
pub fn ws_receiver_blocking(url: &str, on_event: &EventHandler) -> Result<()> {
let (mut socket, response) = match tungstenite::connect(url) {
Ok(result) => result,
Err(err) => {
return Err(format!("Connect: {err}"));
}
};
pub fn ws_receiver_blocking(url: &str, options: Options, on_event: &EventHandler) -> Result<()> {
let config = tungstenite::protocol::WebSocketConfig::from(options);
let max_redirects = 3; // tungstenite default

let (mut socket, response) =
match tungstenite::client::connect_with_config(url, Some(config), max_redirects) {
Ok(result) => result,
Err(err) => {
return Err(format!("Connect: {err}"));
}
};

log::debug!("WebSocket HTTP response code: {}", response.status());
log::trace!(
Expand Down Expand Up @@ -115,13 +119,17 @@ pub fn ws_receiver_blocking(url: &str, on_event: &EventHandler) -> Result<()> {
}
}

pub(crate) fn ws_connect_impl(url: String, on_event: EventHandler) -> Result<WsSender> {
pub(crate) fn ws_connect_impl(
url: String,
options: Options,
on_event: EventHandler,
) -> Result<WsSender> {
let (tx, rx) = std::sync::mpsc::channel();

std::thread::Builder::new()
.name("ewebsock".to_owned())
.spawn(move || {
if let Err(err) = ws_connect_blocking(&url, &on_event, &rx) {
if let Err(err) = ws_connect_blocking(&url, options, &on_event, &rx) {
on_event(WsEvent::Error(err));
} else {
log::debug!("WebSocket connection closed.");
Expand All @@ -140,15 +148,19 @@ pub(crate) fn ws_connect_impl(url: String, on_event: EventHandler) -> Result<WsS
/// All errors are returned to the caller, and NOT reported via `on_event`.
pub fn ws_connect_blocking(
url: &str,
options: Options,
on_event: &EventHandler,
rx: &Receiver<WsMessage>,
) -> Result<()> {
let (mut socket, response) = match tungstenite::connect(url) {
Ok(result) => result,
Err(err) => {
return Err(format!("Connect: {err}"));
}
};
let config = tungstenite::protocol::WebSocketConfig::from(options);
let max_redirects = 3; // tungstenite default
let (mut socket, response) =
match tungstenite::client::connect_with_config(url, Some(config), max_redirects) {
Ok(result) => result,
Err(err) => {
return Err(format!("Connect: {err}"));
}
};

log::debug!("WebSocket HTTP response code: {}", response.status());
log::trace!(
Expand Down
29 changes: 21 additions & 8 deletions ewebsock/src/native_tungstenite_tokio.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{EventHandler, Result, WsEvent, WsMessage};
use crate::{EventHandler, Options, Result, WsEvent, WsMessage};

/// This is how you send [`WsMessage`]s to the server.
///
Expand Down Expand Up @@ -45,12 +45,21 @@ impl WsSender {

async fn ws_connect_async(
url: String,
options: Options,
outgoing_messages_stream: impl futures::Stream<Item = WsMessage>,
on_event: EventHandler,
) {
use futures::StreamExt as _;

let (ws_stream, _) = match tokio_tungstenite::connect_async(url).await {
let config = tungstenite::protocol::WebSocketConfig::from(options);
let disable_nagle = false; // God damn everyone who adds negations to the names of their variables
let (ws_stream, _) = match tokio_tungstenite::connect_async_with_config(
url,
Some(config),
disable_nagle,
)
.await
{
Ok(result) => result,
Err(err) => {
on_event(WsEvent::Error(err.to_string()));
Expand Down Expand Up @@ -106,12 +115,16 @@ async fn ws_connect_async(
}

#[allow(clippy::unnecessary_wraps)]
pub(crate) fn ws_connect_impl(url: String, on_event: EventHandler) -> Result<WsSender> {
Ok(ws_connect_native(url, on_event))
pub(crate) fn ws_connect_impl(
url: String,
options: Options,
on_event: EventHandler,
) -> Result<WsSender> {
Ok(ws_connect_native(url, options, on_event))
}

/// Like [`ws_connect`], but cannot fail. Only available on native builds.

Check warning on line 126 in ewebsock/src/native_tungstenite_tokio.rs

View workflow job for this annotation

GitHub Actions / Rust format, cranky, check, test, doc

unresolved link to `ws_connect`

Check warning on line 126 in ewebsock/src/native_tungstenite_tokio.rs

View workflow job for this annotation

GitHub Actions / Rust format, cranky, check, test, doc

unresolved link to `ws_connect`
fn ws_connect_native(url: String, on_event: EventHandler) -> WsSender {
fn ws_connect_native(url: String, options: Options, on_event: EventHandler) -> WsSender {
let (tx, mut rx) = tokio::sync::mpsc::channel(1000);

let outgoing_messages_stream = async_stream::stream! {
Expand All @@ -122,12 +135,12 @@ fn ws_connect_native(url: String, on_event: EventHandler) -> WsSender {
};

tokio::spawn(async move {
ws_connect_async(url.clone(), outgoing_messages_stream, on_event).await;
ws_connect_async(url.clone(), options, outgoing_messages_stream, on_event).await;
log::debug!("WS connection finished.");
});
WsSender { tx: Some(tx) }
}

pub(crate) fn ws_receive_impl(url: String, on_event: EventHandler) -> Result<()> {
ws_connect_impl(url, on_event).map(|sender| sender.forget())
pub(crate) fn ws_receive_impl(url: String, options: Options, on_event: EventHandler) -> Result<()> {
ws_connect_impl(url, options, on_event).map(|sender| sender.forget())
}
16 changes: 16 additions & 0 deletions ewebsock/src/tungstenite_common.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
impl From<crate::Options> for tungstenite::protocol::WebSocketConfig {
fn from(options: crate::Options) -> Self {
let crate::Options {
max_incoming_frame_size,
} = options;

tungstenite::protocol::WebSocketConfig {
max_frame_size: if max_incoming_frame_size == usize::MAX {
None
} else {
Some(max_incoming_frame_size)
},
..Default::default()
}
}
}
12 changes: 8 additions & 4 deletions ewebsock/src/web.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{EventHandler, Result, WsEvent, WsMessage};
use crate::{EventHandler, Options, Result, WsEvent, WsMessage};

#[allow(clippy::needless_pass_by_value)]
fn string_from_js_value(s: wasm_bindgen::JsValue) -> String {
Expand Down Expand Up @@ -63,11 +63,15 @@ impl WsSender {
}
}

pub(crate) fn ws_receive_impl(url: String, on_event: EventHandler) -> Result<()> {
ws_connect_impl(url, on_event).map(|sender| sender.forget())
pub(crate) fn ws_receive_impl(url: String, options: Options, on_event: EventHandler) -> Result<()> {
ws_connect_impl(url, options, on_event).map(|sender| sender.forget())
}

pub(crate) fn ws_connect_impl(url: String, on_event: EventHandler) -> Result<WsSender> {
pub(crate) fn ws_connect_impl(
url: String,
_ignored_options: Options,
on_event: EventHandler,
) -> Result<WsSender> {
// Based on https://rustwasm.github.io/wasm-bindgen/examples/websockets.html

use wasm_bindgen::closure::Closure;
Expand Down
2 changes: 1 addition & 1 deletion example_app/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ impl eframe::App for ExampleApp {
impl ExampleApp {
fn connect(&mut self, ctx: egui::Context) {
let wakeup = move || ctx.request_repaint(); // wake up UI thread on new message
match ewebsock::connect_with_wakeup(&self.url, wakeup) {
match ewebsock::connect_with_wakeup(&self.url, Default::default(), wakeup) {
Ok((ws_sender, ws_receiver)) => {
self.frontend = Some(FrontEnd::new(ws_sender, ws_receiver));
self.error.clear();
Expand Down

0 comments on commit 8a94828

Please sign in to comment.