Skip to content

Commit

Permalink
Add fastwebsockets test and fix pre-handshake buffer issue
Browse files Browse the repository at this point in the history
  • Loading branch information
mmastrac committed Oct 27, 2023
1 parent cafcb1c commit c62d48b
Show file tree
Hide file tree
Showing 6 changed files with 204 additions and 16 deletions.
93 changes: 93 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ rustls-pemfile = "1.0"
rustls = { version = "0.21", features = [ "dangerous_configuration" ] }
ntest = "0.9"
rstest = "0.18"
fastwebsockets = "0.4.4"
25 changes: 15 additions & 10 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ mod connection_stream;
mod handshake;
mod stream;

#[cfg(test)]
mod system_test;

pub use handshake::handshake_task;
pub use stream::TlsHandshake;
pub use stream::TlsStream;
Expand All @@ -21,22 +24,24 @@ struct TestOptions {
}

macro_rules! trace {
($($args:expr),+) => {
#[cfg(feature="trace")]
{
println!($($args),+);
}
#[cfg(not(feature="trace"))]
{
format!($($args),+);
}
};
($($args:expr),+) => {
#[cfg(feature="trace")]
{
println!($($args),+);
}
#[cfg(not(feature="trace"))]
{
format!($($args),+);
}
};
}

pub(crate) use trace;

#[cfg(test)]
mod tests {
pub use super::stream::tests::tls_pair;
pub use super::stream::tests::tls_pair_buffer_size;
use rustls::client::ServerCertVerified;
use rustls::client::ServerCertVerifier;
use rustls::Certificate;
Expand Down
22 changes: 16 additions & 6 deletions src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -512,9 +512,13 @@ impl AsyncWrite for TlsStream {
Poll::Pending
} else {
trace!("write buf");
let buf = &buf[0..remaining];
write_buf.extend_from_slice(buf);
Poll::Ready(Ok(buf.len()))
if buf.len() <= remaining {
write_buf.extend_from_slice(buf);
Poll::Ready(Ok(buf.len()))
} else {
write_buf.extend_from_slice(&buf[0..remaining]);
Poll::Ready(Ok(remaining))
}
}
} else {
trace!("write buf");
Expand Down Expand Up @@ -594,7 +598,7 @@ impl Drop for TlsStream {
}

#[cfg(test)]
mod tests {
pub(super) mod tests {
use super::*;
use futures::stream::FuturesUnordered;
use futures::FutureExt;
Expand Down Expand Up @@ -705,15 +709,21 @@ mod tests {
(server, client)
}

async fn tls_pair() -> (TlsStream, TlsStream) {
pub async fn tls_pair() -> (TlsStream, TlsStream) {
tls_pair_buffer_size(None).await
}

pub async fn tls_pair_buffer_size(
buffer_size: Option<NonZeroUsize>,
) -> (TlsStream, TlsStream) {
let (server, client) = tcp_pair().await;
let server =
TlsStream::new_server_side(server, server_config(&[]).into(), None);
let client = TlsStream::new_client_side(
client,
client_config(&[]).into(),
"example.com".try_into().unwrap(),
None,
buffer_size,
);

(server, client)
Expand Down
78 changes: 78 additions & 0 deletions src/system_test/fastwebsockets.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
use std::num::NonZeroUsize;

use fastwebsockets::Frame;
use fastwebsockets::OpCode;
use fastwebsockets::Payload;
use fastwebsockets::Role;
use fastwebsockets::WebSocket;
use rstest::rstest;

const LARGE_PAYLOAD: [u8; 48 * 1024] = [0xff; 48 * 1024];
const SMALL_PAYLOAD: [u8; 16] = [0xff; 16];

#[rstest]
#[case(false, false, false)]
#[case(false, false, true)]
#[case(false, true, false)]
#[case(false, true, true)]
#[case(true, false, false)]
#[case(true, false, true)]
#[case(true, true, false)]
#[case(true, true, true)]
#[tokio::test]
async fn test_fastwebsockets(
#[case] handshake: bool,
#[case] buffer_limit: bool,
#[case] large_payload: bool,
) {
let payload = if large_payload {
LARGE_PAYLOAD.as_slice()
} else {
SMALL_PAYLOAD.as_slice()
};
let (mut client, mut server) = if buffer_limit {
crate::tests::tls_pair_buffer_size(Some(
NonZeroUsize::try_from(1024).unwrap(),
))
.await
} else {
crate::tests::tls_pair().await
};
if handshake {
client.handshake().await.expect("failed handshake");
server.handshake().await.expect("failed handshake");
}

let a = tokio::spawn(async {
let mut ws = WebSocket::after_handshake(server, Role::Server);
ws.set_auto_close(true);
for _ in 0..1000 {
ws.write_frame(Frame::binary(Payload::Borrowed(payload)))
.await
.expect("failed to write");
}
let frame = ws.read_frame().await.expect("failed to read");
assert_eq!(frame.payload.len(), payload.len());
ws.write_frame(Frame::close(1000, &[]))
.await
.expect("failed to close");
let frame = ws.read_frame().await.expect("failed to read");
assert_eq!(frame.opcode, OpCode::Close);
});
let b = tokio::spawn(async {
let mut ws = WebSocket::after_handshake(client, Role::Client);
ws.set_auto_close(true);
for _ in 0..1000 {
let frame = ws.read_frame().await.expect("failed to read");
assert_eq!(frame.payload.len(), payload.len());
}
ws.write_frame(Frame::binary(Payload::Borrowed(payload)))
.await
.expect("failed to write");
let frame = ws.read_frame().await.expect("failed to read");
assert_eq!(frame.opcode, OpCode::Close);
});

a.await.expect("failed to join");
b.await.expect("failed to join");
}
1 change: 1 addition & 0 deletions src/system_test/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
mod fastwebsockets;

0 comments on commit c62d48b

Please sign in to comment.