From c62d48b596a7587c7e40e0cedcfd91f89d850b32 Mon Sep 17 00:00:00 2001 From: Matt Mastracci Date: Fri, 27 Oct 2023 10:39:58 -0600 Subject: [PATCH] Add fastwebsockets test and fix pre-handshake buffer issue --- Cargo.lock | 93 +++++++++++++++++++++++++++++++ Cargo.toml | 1 + src/lib.rs | 25 +++++---- src/stream.rs | 22 ++++++-- src/system_test/fastwebsockets.rs | 78 ++++++++++++++++++++++++++ src/system_test/mod.rs | 1 + 6 files changed, 204 insertions(+), 16 deletions(-) create mode 100644 src/system_test/fastwebsockets.rs create mode 100644 src/system_test/mod.rs diff --git a/Cargo.lock b/Cargo.lock index df2dc3d..1a6159a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -92,6 +92,19 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" +[[package]] +name = "fastwebsockets" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e6185b6dc9dddc4db0dedd2e213047e93bcbf7a0fb092abc4c4e4f3195efdb4" +dependencies = [ + "rand", + "simdutf8", + "thiserror", + "tokio", + "utf-8", +] + [[package]] name = "futures" version = "0.3.28" @@ -187,6 +200,17 @@ dependencies = [ "slab", ] +[[package]] +name = "getrandom" +version = "0.2.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + [[package]] name = "gimli" version = "0.28.0" @@ -371,6 +395,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + [[package]] name = "proc-macro-crate" version = "1.3.1" @@ -399,6 +429,36 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + [[package]] name = "redox_syscall" version = "0.3.5" @@ -527,6 +587,7 @@ dependencies = [ name = "rustls-tokio-stream" version = "0.2.1" dependencies = [ + "fastwebsockets", "futures", "ntest", "rstest", @@ -576,6 +637,12 @@ dependencies = [ "libc", ] +[[package]] +name = "simdutf8" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f27f6278552951f1f2b8cf9da965d10969b2efdea95a6ec47987ab46edfe263a" + [[package]] name = "slab" version = "0.4.9" @@ -629,6 +696,26 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "thiserror" +version = "1.0.50" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9a7210f5c9a7156bb50aa36aed4c95afb51df0df00713949448cf9e97d382d2" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.50" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "266b2e40bc00e5a6c09c3584011e08b06f123c00362c92b975ba9843aaaa14b8" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.37", +] + [[package]] name = "tokio" version = "1.32.0" @@ -688,6 +775,12 @@ version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" diff --git a/Cargo.toml b/Cargo.toml index 4b405b0..cc0d74e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,3 +22,4 @@ rustls-pemfile = "1.0" rustls = { version = "0.21", features = [ "dangerous_configuration" ] } ntest = "0.9" rstest = "0.18" +fastwebsockets = "0.4.4" diff --git a/src/lib.rs b/src/lib.rs index 6f1fbe1..1c9b601 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,9 @@ mod connection_stream; mod handshake; mod stream; +#[cfg(test)] +mod system_test; + pub use handshake::handshake_task; pub use stream::TlsHandshake; pub use stream::TlsStream; @@ -21,22 +24,24 @@ struct TestOptions { } macro_rules! trace { - ($($args:expr),+) => { - #[cfg(feature="trace")] - { - println!($($args),+); - } - #[cfg(not(feature="trace"))] - { - format!($($args),+); - } - }; + ($($args:expr),+) => { + #[cfg(feature="trace")] + { + println!($($args),+); + } + #[cfg(not(feature="trace"))] + { + format!($($args),+); + } + }; } pub(crate) use trace; #[cfg(test)] mod tests { + pub use super::stream::tests::tls_pair; + pub use super::stream::tests::tls_pair_buffer_size; use rustls::client::ServerCertVerified; use rustls::client::ServerCertVerifier; use rustls::Certificate; diff --git a/src/stream.rs b/src/stream.rs index 677d503..712f2a2 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -512,9 +512,13 @@ impl AsyncWrite for TlsStream { Poll::Pending } else { trace!("write buf"); - let buf = &buf[0..remaining]; - write_buf.extend_from_slice(buf); - Poll::Ready(Ok(buf.len())) + if buf.len() <= remaining { + write_buf.extend_from_slice(buf); + Poll::Ready(Ok(buf.len())) + } else { + write_buf.extend_from_slice(&buf[0..remaining]); + Poll::Ready(Ok(remaining)) + } } } else { trace!("write buf"); @@ -594,7 +598,7 @@ impl Drop for TlsStream { } #[cfg(test)] -mod tests { +pub(super) mod tests { use super::*; use futures::stream::FuturesUnordered; use futures::FutureExt; @@ -705,7 +709,13 @@ mod tests { (server, client) } - async fn tls_pair() -> (TlsStream, TlsStream) { + pub async fn tls_pair() -> (TlsStream, TlsStream) { + tls_pair_buffer_size(None).await + } + + pub async fn tls_pair_buffer_size( + buffer_size: Option, + ) -> (TlsStream, TlsStream) { let (server, client) = tcp_pair().await; let server = TlsStream::new_server_side(server, server_config(&[]).into(), None); @@ -713,7 +723,7 @@ mod tests { client, client_config(&[]).into(), "example.com".try_into().unwrap(), - None, + buffer_size, ); (server, client) diff --git a/src/system_test/fastwebsockets.rs b/src/system_test/fastwebsockets.rs new file mode 100644 index 0000000..c6a41d0 --- /dev/null +++ b/src/system_test/fastwebsockets.rs @@ -0,0 +1,78 @@ +use std::num::NonZeroUsize; + +use fastwebsockets::Frame; +use fastwebsockets::OpCode; +use fastwebsockets::Payload; +use fastwebsockets::Role; +use fastwebsockets::WebSocket; +use rstest::rstest; + +const LARGE_PAYLOAD: [u8; 48 * 1024] = [0xff; 48 * 1024]; +const SMALL_PAYLOAD: [u8; 16] = [0xff; 16]; + +#[rstest] +#[case(false, false, false)] +#[case(false, false, true)] +#[case(false, true, false)] +#[case(false, true, true)] +#[case(true, false, false)] +#[case(true, false, true)] +#[case(true, true, false)] +#[case(true, true, true)] +#[tokio::test] +async fn test_fastwebsockets( + #[case] handshake: bool, + #[case] buffer_limit: bool, + #[case] large_payload: bool, +) { + let payload = if large_payload { + LARGE_PAYLOAD.as_slice() + } else { + SMALL_PAYLOAD.as_slice() + }; + let (mut client, mut server) = if buffer_limit { + crate::tests::tls_pair_buffer_size(Some( + NonZeroUsize::try_from(1024).unwrap(), + )) + .await + } else { + crate::tests::tls_pair().await + }; + if handshake { + client.handshake().await.expect("failed handshake"); + server.handshake().await.expect("failed handshake"); + } + + let a = tokio::spawn(async { + let mut ws = WebSocket::after_handshake(server, Role::Server); + ws.set_auto_close(true); + for _ in 0..1000 { + ws.write_frame(Frame::binary(Payload::Borrowed(payload))) + .await + .expect("failed to write"); + } + let frame = ws.read_frame().await.expect("failed to read"); + assert_eq!(frame.payload.len(), payload.len()); + ws.write_frame(Frame::close(1000, &[])) + .await + .expect("failed to close"); + let frame = ws.read_frame().await.expect("failed to read"); + assert_eq!(frame.opcode, OpCode::Close); + }); + let b = tokio::spawn(async { + let mut ws = WebSocket::after_handshake(client, Role::Client); + ws.set_auto_close(true); + for _ in 0..1000 { + let frame = ws.read_frame().await.expect("failed to read"); + assert_eq!(frame.payload.len(), payload.len()); + } + ws.write_frame(Frame::binary(Payload::Borrowed(payload))) + .await + .expect("failed to write"); + let frame = ws.read_frame().await.expect("failed to read"); + assert_eq!(frame.opcode, OpCode::Close); + }); + + a.await.expect("failed to join"); + b.await.expect("failed to join"); +} diff --git a/src/system_test/mod.rs b/src/system_test/mod.rs new file mode 100644 index 0000000..8bf2cbb --- /dev/null +++ b/src/system_test/mod.rs @@ -0,0 +1 @@ +mod fastwebsockets;