Skip to content

Commit

Permalink
feat: write_vectored
Browse files Browse the repository at this point in the history
  • Loading branch information
mmastrac committed Oct 30, 2023
1 parent 70e6af3 commit c432167
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 10 deletions.
71 changes: 71 additions & 0 deletions src/connection_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,77 @@ impl ConnectionStream {
res
}

pub fn poll_write_vectored(
&mut self,
cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
// Zero-length writes always succeed
if bufs.is_empty() {
self.wr_waker.take();
return Poll::Ready(Ok(0));
}

// Writes after shutdown return NotConnected
if self.wants_close_sent {
self.wr_waker.take();
return Poll::Ready(Err(ErrorKind::NotConnected.into()));
}

// First prepare to write
let res = loop {
let write = self.poll_write_only(PollContext::Explicit(cx));
match write {
// No room to write
StreamProgress::RegisteredWaker => break Poll::Pending,
// We wrote something, so let's loop again
StreamProgress::MadeProgress => continue,
// Wedged on an error
StreamProgress::Error => {
break Poll::Ready(Err(self.wr_error.unwrap().into()))
}
// No current write interest, so let's generate some
StreamProgress::NoInterest => {
// Write it
let n = self
.tls
.writer()
.write_vectored(bufs)
.expect("Write will never fail");
trace!("w={n}");
assert!(n > 0);
// Drain what we can
while self.poll_write_only(PollContext::Explicit(cx))
== StreamProgress::MadeProgress
{}
// And then return what we wrote
break Poll::Ready(Ok(n));
}
};
};

// Then read until we lose interest
while self.poll_read_only(PollContext::Implicit(cx))
== StreamProgress::MadeProgress
{}

if res.is_ready() {
self.wr_waker.take();
} else {
// Replace the waker unless we already have it
if !self
.wr_waker
.as_ref()
.map(|w| cx.waker().will_wake(w))
.unwrap_or_default()
{
self.wr_waker = Some(cx.waker().clone());
}
}

res
}

/// Fully write a buffer to the TLS stream, expecting it to write fully and not fail.
pub(crate) fn write_buf_fully(&mut self, buf: &[u8]) {
let n = self.tls.writer().write(buf).expect("Write will never fail");
Expand Down
68 changes: 68 additions & 0 deletions src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use std::cell::Cell;
use std::fmt::Debug;
use std::io;
use std::io::ErrorKind;
use std::io::Write;
use tokio::task::JoinError;

use std::num::NonZeroUsize;
Expand Down Expand Up @@ -552,6 +553,73 @@ impl AsyncWrite for TlsStream {
}
}

fn is_write_vectored(&self) -> bool {
true
}

fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<Result<usize, io::Error>> {
let buffer_size = self.buffer_size;
loop {
break match &mut self.state {
TlsStreamState::Handshaking {
handle,
write_waker,
write_buf,
..
} => {
// If the handshake completed, we want to finalize it and then continue
if handle.is_finished() {
let Poll::Ready(res) =
handle.poll_unpin(&mut Context::from_waker(noop_waker_ref()))
else {
unreachable!()
};
self.finalize_handshake(res)?;
continue;
}
if let Some(buffer_size) = buffer_size {
let mut remaining = buffer_size.get() - write_buf.len();
if remaining == 0 {
// No room to write, so store the waker for whenever the handshake is done
write_waker.set_waker(cx.waker());
trace!("write limit");
Poll::Pending
} else {
trace!("write buf");
let mut wrote = 0;
for buf in bufs {
if buf.len() <= remaining {
write_buf.extend_from_slice(buf);
wrote += buf.len();
remaining -= buf.len();
} else {
write_buf.extend_from_slice(&buf[0..remaining]);
wrote += remaining;
break;
}
}

// TODO(mmastrac): this currently ignores remaining size
Poll::Ready(Ok(wrote))
}
} else {
trace!("write buf");
Poll::Ready(Ok(write_buf.write_vectored(bufs).unwrap()))
}
}
TlsStreamState::Open(ref mut stm) => stm.poll_write_vectored(cx, bufs),
TlsStreamState::Closed => {
Poll::Ready(Err(ErrorKind::NotConnected.into()))
}
TlsStreamState::ClosedError(err) => Poll::Ready(Err((*err).into())),
};
}
}

fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
Expand Down
37 changes: 27 additions & 10 deletions src/system_test/fastwebsockets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,28 @@ const LARGE_PAYLOAD: [u8; 48 * 1024] = [0xff; 48 * 1024];
const SMALL_PAYLOAD: [u8; 16] = [0xff; 16];

#[rstest]
#[case(false, false, false)]
#[case(false, false, true)]
#[case(false, true, false)]
#[case(false, true, true)]
#[case(true, false, false)]
#[case(true, false, true)]
#[case(true, true, false)]
#[case(true, true, true)]
#[case(false, false, false, false)]
#[case(false, false, false, true)]
#[case(false, false, true, false)]
#[case(false, false, true, true)]
#[case(false, true, false, false)]
#[case(false, true, false, true)]
#[case(false, true, true, false)]
#[case(false, true, true, true)]
#[case(true, false, false, false)]
#[case(true, false, false, true)]
#[case(true, false, true, false)]
#[case(true, false, true, true)]
#[case(true, true, false, false)]
#[case(true, true, false, true)]
#[case(true, true, true, false)]
#[case(true, true, true, true)]
#[tokio::test]
async fn test_fastwebsockets(
#[case] handshake: bool,
#[case] buffer_limit: bool,
#[case] large_payload: bool,
#[case] use_writev: bool,
) {
let payload = if large_payload {
LARGE_PAYLOAD.as_slice()
Expand All @@ -45,9 +54,13 @@ async fn test_fastwebsockets(
server.handshake().await.expect("failed handshake");
}

let a = tokio::spawn(async {
let a = tokio::spawn(async move {
let mut ws = WebSocket::after_handshake(server, Role::Server);
ws.set_auto_close(true);
if use_writev {
ws.set_writev(true);
ws.set_writev_threshold(0);
}
for i in 0..1000 {
println!("send {i}");
ws.write_frame(Frame::binary(Payload::Borrowed(payload)))
Expand All @@ -62,8 +75,12 @@ async fn test_fastwebsockets(
let frame = ws.read_frame().await.expect("failed to read");
assert_eq!(frame.opcode, OpCode::Close);
});
let b = tokio::spawn(async {
let b = tokio::spawn(async move {
let mut ws = WebSocket::after_handshake(client, Role::Client);
if use_writev {
ws.set_writev(true);
ws.set_writev_threshold(0);
}
ws.set_auto_close(true);
for i in 0..1000 {
println!("recv {i}");
Expand Down

0 comments on commit c432167

Please sign in to comment.