Skip to content

Commit

Permalink
fix: use a hybrid waker approach in connection_stream
Browse files Browse the repository at this point in the history
  • Loading branch information
mmastrac committed Oct 30, 2023
1 parent 1cc619b commit 295cf6b
Show file tree
Hide file tree
Showing 7 changed files with 274 additions and 186 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ rustls-pemfile = "1.0"
rustls = { version = "0.21", features = [ "dangerous_configuration" ] }
ntest = "0.9"
rstest = "0.18"
fastwebsockets = "0.4.4"
fastwebsockets = { version = "=0.5.0", features = [ "unstable-split" ] }
172 changes: 161 additions & 11 deletions src/connection_stream.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use futures::task::noop_waker_ref;
use rustls::Connection;
use rustls::IoState;
use std::io;
Expand All @@ -8,6 +9,7 @@ use std::pin::Pin;
use std::task::ready;
use std::task::Context;
use std::task::Poll;
use std::task::Waker;
use tokio::io::AsyncWrite;
use tokio::io::ReadBuf;
use tokio::net::TcpStream;
Expand All @@ -16,6 +18,40 @@ use crate::adapter::read_tls;
use crate::adapter::write_tls;
use crate::trace;

/// A poll context may be explicit (ie: passed by the external user) or implicit (ie: using whatever stashed waker
/// we have stored).
///
/// Note that if we have no waker stored, we'll use the explicit context (which means we'll wake up a reader for writes
/// to progress, and a writer for reads to progress).
enum PollContext<'a, 'b: 'a> {
Explicit(&'a mut Context<'b>),
Implicit(&'a mut Context<'b>),
}

impl<'a, 'b: 'a> PollContext<'a, 'b> {
pub fn cx<'f, 'g, R>(
&'_ mut self,
waker_ref: &'g mut Option<Waker>,
f: impl Fn(&mut Context<'_>) -> R,
) -> R
where
'b: 'f,
'a: 'f,
'g: 'f,
{
match self {
Self::Explicit(ref mut cx) => f(cx),
Self::Implicit(ref mut cx) => match waker_ref {
Some(w) => {
let mut cx = Context::from_waker(w);
f(&mut cx)
}
None => f(cx),
},
}
}
}

pub struct ConnectionStream {
tls: Connection,
tcp: TcpStream,
Expand All @@ -30,6 +66,10 @@ pub struct ConnectionStream {
rd_error: Option<io::ErrorKind>,
/// An error on the underlying socket's write side.
wr_error: Option<io::ErrorKind>,
/// The last read waker.
rd_waker: Option<Waker>,
/// The last write waker.
wr_waker: Option<Waker>,
}

#[derive(Debug, PartialEq, Eq)]
Expand All @@ -51,6 +91,8 @@ impl ConnectionStream {
rd_proto_error: None,
rd_error: None,
wr_error: None,
rd_waker: None,
wr_waker: None,
}
}

Expand All @@ -76,7 +118,7 @@ impl ConnectionStream {
.unwrap_or_default()
}

fn poll_read_only(&mut self, cx: &mut Context<'_>) -> StreamProgress {
fn poll_read_only(&mut self, mut cx: PollContext) -> StreamProgress {
if self.rd_error.is_some() || self.rd_proto_error.is_some() {
StreamProgress::Error
} else if self.tls.wants_read() {
Expand All @@ -98,7 +140,7 @@ impl ConnectionStream {
}
}
Err(err) if err.kind() == ErrorKind::WouldBlock => {
match self.tcp.poll_read_ready(cx) {
match cx.cx(&mut self.rd_waker, |cx| self.tcp.poll_read_ready(cx)) {
Poll::Pending => break StreamProgress::RegisteredWaker,
Poll::Ready(Err(err)) => {
self.rd_error = Some(err.kind());
Expand All @@ -121,7 +163,7 @@ impl ConnectionStream {
}
}

fn poll_write_only(&mut self, cx: &mut Context<'_>) -> StreamProgress {
fn poll_write_only(&mut self, mut cx: PollContext) -> StreamProgress {
if self.wr_error.is_some() {
StreamProgress::Error
} else if self.tls.wants_write() {
Expand All @@ -133,7 +175,8 @@ impl ConnectionStream {
break StreamProgress::MadeProgress;
}
Err(err) if err.kind() == ErrorKind::WouldBlock => {
match self.tcp.poll_write_ready(cx) {
match cx.cx(&mut self.wr_waker, |cx| self.tcp.poll_write_ready(cx))
{
Poll::Pending => break StreamProgress::RegisteredWaker,
Poll::Ready(Err(err)) => {
self.wr_error = Some(err.kind());
Expand Down Expand Up @@ -218,13 +261,13 @@ impl ConnectionStream {
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<usize>> {
loop {
let res = loop {
// First prepare to read
let read = self.poll_read_only(cx);
let read = self.poll_read_only(PollContext::Explicit(cx));

// Then write until we lose interest
loop {
let write = self.poll_write_only(cx);
let write = self.poll_write_only(PollContext::Implicit(cx));
if write != StreamProgress::MadeProgress {
break;
}
Expand All @@ -245,7 +288,23 @@ impl ConnectionStream {
}
}
}
};

if res.is_ready() {
self.rd_waker.take();
} else {
// Replace the waker unless we already have it
if !self
.rd_waker
.as_ref()
.map(|w| cx.waker().will_wake(w))
.unwrap_or_default()
{
self.rd_waker = Some(cx.waker().clone());
}
}

res
}

/// Polls the connection for writes, reading as needed. As TLS may need to pump writes during reads, or
Expand All @@ -264,17 +323,19 @@ impl ConnectionStream {
) -> Poll<io::Result<usize>> {
// Zero-length writes always succeed
if buf.is_empty() {
self.wr_waker.take();
return Poll::Ready(Ok(0));
}

// Writes after shutdown return NotConnected
if self.wants_close_sent {
self.wr_waker.take();
return Poll::Ready(Err(ErrorKind::NotConnected.into()));
}

// First prepare to write
let res = loop {
let write = self.poll_write_only(cx);
let write = self.poll_write_only(PollContext::Explicit(cx));
match write {
// No room to write
StreamProgress::RegisteredWaker => break Poll::Pending,
Expand All @@ -291,15 +352,33 @@ impl ConnectionStream {
trace!("w={n}");
assert!(n > 0);
// Drain what we can
while self.poll_write_only(cx) == StreamProgress::MadeProgress {}
while self.poll_write_only(PollContext::Explicit(cx))
== StreamProgress::MadeProgress
{}
// And then return what we wrote
break Poll::Ready(Ok(n));
}
};
};

// Then read until we lose interest
while self.poll_read_only(cx) == StreamProgress::MadeProgress {}
while self.poll_read_only(PollContext::Implicit(cx))
== StreamProgress::MadeProgress
{}

if res.is_ready() {
self.wr_waker.take();
} else {
// Replace the waker unless we already have it
if !self
.wr_waker
.as_ref()
.map(|w| cx.waker().will_wake(w))
.unwrap_or_default()
{
self.wr_waker = Some(cx.waker().clone());
}
}

res
}
Expand All @@ -314,7 +393,7 @@ impl ConnectionStream {
/// reads at all.
pub fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
loop {
match self.poll_write_only(cx) {
match self.poll_write_only(PollContext::Explicit(cx)) {
StreamProgress::RegisteredWaker => break Poll::Pending,
StreamProgress::MadeProgress => continue,
StreamProgress::NoInterest => break Poll::Ready(Ok(())),
Expand Down Expand Up @@ -348,6 +427,50 @@ impl ConnectionStream {
}
}

#[cfg(test)]
impl tokio::io::AsyncRead for ConnectionStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
ConnectionStream::poll_read(self.get_mut(), cx, buf).map(|r| r.map(|_| ()))
}
}

#[cfg(test)]
impl tokio::io::AsyncWrite for ConnectionStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
ConnectionStream::poll_write(self.get_mut(), cx, buf)
}

fn poll_write_vectored(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_bufs: &[futures::io::IoSlice<'_>],
) -> Poll<Result<usize, io::Error>> {
unimplemented!()
}

fn poll_flush(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<(), io::Error>> {
unimplemented!()
}

fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), io::Error>> {
ConnectionStream::poll_shutdown(self.get_mut(), cx)
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -685,4 +808,31 @@ mod tests {

Ok(())
}

#[tokio::test]
async fn test_split() {
use tokio::io::AsyncReadExt;
let (mut server, client) = tls_pair().await;
println!("==================");
let a = tokio::task::spawn(async move {
let mut buf = [0; 16];
server.read_exact(&mut buf).await.unwrap();
server.write_all(&buf).await.unwrap();
});

let (mut rx, mut tx) = tokio::io::split(client);
let b = tokio::task::spawn(async move {
let mut buf = [0; 16];
rx.read_exact(&mut buf).await.unwrap();
});
tokio::task::yield_now().await;
let c = tokio::task::spawn(async move {
let buf = [0; 16];
tx.write_all(&buf).await.unwrap();
});

a.await.expect("failed to join");
b.await.expect("failed to join");
c.await.expect("failed to join");
}
}
Loading

0 comments on commit 295cf6b

Please sign in to comment.