diff --git a/src/connection_stream.rs b/src/connection_stream.rs index f053086..1b000e9 100644 --- a/src/connection_stream.rs +++ b/src/connection_stream.rs @@ -288,6 +288,7 @@ impl ConnectionStream { StreamProgress::NoInterest => { // Write it let n = self.tls.writer().write(buf).expect("Write will never fail"); + trace!("w={n}"); assert!(n > 0); // Drain what we can while self.poll_write_only(cx) == StreamProgress::MadeProgress {} @@ -303,6 +304,12 @@ impl ConnectionStream { res } + /// Fully write a buffer to the TLS stream, expecting it to write fully and not fail. + pub(crate) fn write_buf_fully(&mut self, buf: &[u8]) { + let n = self.tls.writer().write(buf).expect("Write will never fail"); + assert_eq!(n, buf.len()) + } + /// Polls for completion of all the writes in the rustls [`Connection`]. Does not progress on /// reads at all. pub fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll> { diff --git a/src/stream.rs b/src/stream.rs index ff801e5..677d503 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -13,6 +13,7 @@ use futures::task::Poll; use futures::task::Waker; use futures::FutureExt; +use futures::task::noop_waker_ref; use rustls::ClientConfig; use rustls::ClientConnection; use rustls::Connection; @@ -23,8 +24,11 @@ use std::cell::Cell; use std::fmt::Debug; use std::io; use std::io::ErrorKind; +use tokio::task::JoinError; +use std::num::NonZeroUsize; use std::pin::Pin; +use std::rc::Rc; use std::sync::Arc; use std::task::ready; @@ -38,14 +42,31 @@ use tokio::spawn; use tokio::sync::watch; use tokio::task::JoinHandle; +#[derive(Clone, Default)] +struct SharedWaker(Rc>>); +unsafe impl Send for SharedWaker {} + +impl SharedWaker { + pub fn wake(&self) { + if let Some(waker) = self.0.take() { + waker.wake(); + } + } + + pub fn set_waker(&self, waker: &Waker) { + self.0.set(Some(waker.clone())) + } +} + enum TlsStreamState { /// If we are handshaking, writes are buffered and reads block. // TODO(mmastrac): We should be buffered in the Connection, not the Vec, as this results in a double-copy. - Handshaking( - JoinHandle>, - Cell>, - Vec, - ), + Handshaking { + handle: JoinHandle>, + read_waker: SharedWaker, + write_waker: SharedWaker, + write_buf: Vec, + }, /// The connection is open. Open(ConnectionStream), /// The connection is closed. @@ -56,13 +77,15 @@ enum TlsStreamState { pub struct TlsStream { state: TlsStreamState, + handshake: watch::Receiver>>, + buffer_size: Option, } impl Debug for TlsStream { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match &self.state { - TlsStreamState::Handshaking(..) => { + TlsStreamState::Handshaking { .. } => { f.write_str("TlsStream { Handshaking }") } TlsStreamState::Open(..) => f.write_fmt(format_args!( @@ -86,13 +109,18 @@ impl TlsStream { fn new( tcp: TcpStream, mut tls: Connection, + buffer_size: Option, test_options: TestOptions, ) -> Self { - tls.set_buffer_limit(None); + tls.set_buffer_limit(buffer_size.map(|s| s.get())); let (tx, handshake) = watch::channel(None); + let read_waker = SharedWaker::default(); + let write_waker = SharedWaker::default(); + let read_waker_clone = read_waker.clone(); + let write_waker_clone = write_waker.clone(); - // TODO(mmastrac): We're using a oneshot to notify the reader, but this could be more efficient + // This task does nothing but yield a TlsHandshake to a oneshot let handle = spawn(async move { #[cfg(test)] if test_options.delay_handshake { @@ -110,12 +138,22 @@ impl TlsStream { } } + // We may have read/writes blocked on the handshake, so wake them all up + read_waker_clone.wake(); + write_waker_clone.wake(); + res }); Self { - state: TlsStreamState::Handshaking(handle, Cell::new(None), vec![]), + state: TlsStreamState::Handshaking { + handle: handle, + read_waker, + write_waker, + write_buf: vec![], + }, handshake, + buffer_size, } } @@ -123,9 +161,15 @@ impl TlsStream { tcp: TcpStream, tls_config: Arc, server_name: ServerName, + buffer_size: Option, ) -> Self { let tls = ClientConnection::new(tls_config, server_name).unwrap(); - Self::new(tcp, Connection::Client(tls), TestOptions::default()) + Self::new( + tcp, + Connection::Client(tls), + buffer_size, + TestOptions::default(), + ) } #[cfg(test)] @@ -133,42 +177,62 @@ impl TlsStream { tcp: TcpStream, tls_config: Arc, server_name: ServerName, + buffer_size: Option, test_options: TestOptions, ) -> Self { let tls = ClientConnection::new(tls_config, server_name).unwrap(); - Self::new(tcp, Connection::Client(tls), test_options) + Self::new(tcp, Connection::Client(tls), buffer_size, test_options) } pub fn new_client_side_from( tcp: TcpStream, connection: ClientConnection, + buffer_size: Option, ) -> Self { - Self::new(tcp, Connection::Client(connection), TestOptions::default()) + Self::new( + tcp, + Connection::Client(connection), + buffer_size, + TestOptions::default(), + ) } #[cfg(test)] pub(crate) fn new_server_side_test_options( tcp: TcpStream, tls_config: Arc, + buffer_size: Option, test_options: TestOptions, ) -> Self { let tls = ServerConnection::new(tls_config).unwrap(); - Self::new(tcp, Connection::Server(tls), test_options) + Self::new(tcp, Connection::Server(tls), buffer_size, test_options) } pub fn new_server_side( tcp: TcpStream, tls_config: Arc, + buffer_size: Option, ) -> Self { let tls = ServerConnection::new(tls_config).unwrap(); - Self::new(tcp, Connection::Server(tls), TestOptions::default()) + Self::new( + tcp, + Connection::Server(tls), + buffer_size, + TestOptions::default(), + ) } pub fn new_server_side_from( tcp: TcpStream, connection: ServerConnection, + buffer_size: Option, ) -> Self { - Self::new(tcp, Connection::Server(connection), TestOptions::default()) + Self::new( + tcp, + Connection::Server(connection), + buffer_size, + TestOptions::default(), + ) } /// Attempt to retrieve the inner stream and connection. @@ -190,7 +254,7 @@ impl TlsStream { TlsStreamState::Open(stm) => Ok(stm.into_inner()), TlsStreamState::Closed => Err(ErrorKind::NotConnected.into()), TlsStreamState::ClosedError(err) => Err(err.into()), - TlsStreamState::Handshaking(..) => unreachable!(), + TlsStreamState::Handshaking { .. } => unreachable!(), } } @@ -236,44 +300,58 @@ impl TlsStream { } } + fn finalize_handshake( + &mut self, + join_result: Result, JoinError>, + ) -> io::Result<()> { + trace!("finalize handshake"); + match &mut self.state { + TlsStreamState::Handshaking { + read_waker, + write_waker, + write_buf: buf, + .. + } => { + match join_result { + Err(err) => { + if err.is_panic() { + // Resume the panic on the main task + std::panic::resume_unwind(err.into_panic()); + } else { + unreachable!("Task should not have been cancelled"); + } + } + Ok(Err(err)) => { + return Err(err); + } + Ok(Ok((tcp, tls))) => { + let mut stm = ConnectionStream::new(tcp, tls); + // We need to save all the data we wrote before the connection. The stream has an internal buffer + // that matches our buffer, so it can accept it all. + stm.write_buf_fully(buf); + read_waker.wake(); + write_waker.wake(); + self.state = TlsStreamState::Open(stm); + return Ok(()); + } + } + } + _ => unreachable!(), + } + } + /// If the handshake is complete, migrate from a pending handshake to the open state. fn poll_pending_handshake( &mut self, cx: &mut Context<'_>, ) -> Poll> { - loop { - match &mut self.state { - TlsStreamState::Handshaking(ref mut handle, ref _waker, buf) => { - let res = ready!(handle.poll_unpin(cx)); - match res { - Err(err) => { - if err.is_panic() { - // Resume the panic on the main task - std::panic::resume_unwind(err.into_panic()); - } else { - unreachable!("Task should not have been cancelled"); - } - } - Ok(Err(err)) => { - return Poll::Ready(Err(err)); - } - Ok(Ok((tcp, tls))) => { - let mut stm = ConnectionStream::new(tcp, tls); - // We need to save all the data we wrote before the connection. The stream has an internal buffer - // that matches our buffer, so it can accept it all. - if let Poll::Ready(Ok(len)) = stm.poll_write(cx, buf) { - assert_eq!(len, buf.len()); - } else { - unreachable!("TLS stream should have accepted entire buffer"); - } - self.state = TlsStreamState::Open(stm); - continue; - } - } - } - _ => { - return Poll::Ready(Ok(())); - } + match &mut self.state { + TlsStreamState::Handshaking { handle, .. } => { + let res = ready!(handle.poll_unpin(cx)); + Poll::Ready(self.finalize_handshake(res)) + } + _ => { + return Poll::Ready(Ok(())); } } } @@ -303,7 +381,7 @@ impl TlsStream { match &mut self.state { // Handshaking: drop the handshake and return ready. - TlsStreamState::Handshaking(..) => { + TlsStreamState::Handshaking { .. } => { unreachable!() } TlsStreamState::Open(stm) => { @@ -323,7 +401,14 @@ impl TlsStream { let state = std::mem::replace(&mut self.state, TlsStreamState::Closed); trace!("closing {self:?}"); match state { - TlsStreamState::Handshaking(handle, _, buf) => { + TlsStreamState::Handshaking { + handle, + read_waker, + write_waker, + write_buf: buf, + } => { + read_waker.wake(); + write_waker.wake(); match handle.await { Ok(Ok((tcp, tls))) => { let mut stm = ConnectionStream::new(tcp, tls); @@ -361,28 +446,36 @@ impl AsyncRead for TlsStream { cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - match ready!(self.poll_pending_handshake(cx)) { - Ok(()) => {} - Err(err) => { - self.state = TlsStreamState::ClosedError(err.kind()); - return Poll::Ready(Err(err)); - } - }; - match &mut self.state { - TlsStreamState::Handshaking(..) => { - unreachable!() - } - TlsStreamState::Open(ref mut stm) => { - match std::task::ready!(stm.poll_read(cx, buf)) { - Ok(_n) => { - // TODO: n? - Poll::Ready(Ok(())) + loop { + break match &mut self.state { + TlsStreamState::Handshaking { + handle, read_waker, .. + } => { + // If the handshake completed, we want to finalize it and then continue + if handle.is_finished() { + let Poll::Ready(res) = handle.poll_unpin(&mut Context::from_waker(noop_waker_ref())) else { + unreachable!() + }; + self.finalize_handshake(res)?; + continue; } - Err(err) => Poll::Ready(Err(err)), + + // Handshake is still blocking us + read_waker.set_waker(cx.waker()); + Poll::Pending } - } - TlsStreamState::Closed => Poll::Ready(Ok(())), - TlsStreamState::ClosedError(err) => Poll::Ready(Err((*err).into())), + TlsStreamState::Open(ref mut stm) => { + match std::task::ready!(stm.poll_read(cx, buf)) { + Ok(_n) => { + // TODO: n? + Poll::Ready(Ok(())) + } + Err(err) => Poll::Ready(Err(err)), + } + } + TlsStreamState::Closed => Poll::Ready(Ok(())), + TlsStreamState::ClosedError(err) => Poll::Ready(Err((*err).into())), + }; } } } @@ -393,17 +486,48 @@ impl AsyncWrite for TlsStream { cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - // TODO: upgrade from handshaking if done - match &mut self.state { - TlsStreamState::Handshaking(_, _, write_buf) => { - write_buf.extend_from_slice(buf); - Poll::Ready(Ok(buf.len())) - } - TlsStreamState::Open(ref mut stm) => stm.poll_write(cx, buf), - TlsStreamState::Closed => { - Poll::Ready(Err(ErrorKind::NotConnected.into())) - } - TlsStreamState::ClosedError(err) => Poll::Ready(Err((*err).into())), + let buffer_size = self.buffer_size; + loop { + break match &mut self.state { + TlsStreamState::Handshaking { + handle, + write_waker, + write_buf, + .. + } => { + // If the handshake completed, we want to finalize it and then continue + if handle.is_finished() { + let Poll::Ready(res) = handle.poll_unpin(&mut Context::from_waker(noop_waker_ref())) else { + unreachable!() + }; + self.finalize_handshake(res)?; + continue; + } + if let Some(buffer_size) = buffer_size { + let remaining = buffer_size.get() - write_buf.len(); + if remaining == 0 { + // No room to write, so store the waker for whenever the handshake is done + write_waker.set_waker(cx.waker()); + trace!("write limit"); + Poll::Pending + } else { + trace!("write buf"); + let buf = &buf[0..remaining]; + write_buf.extend_from_slice(buf); + Poll::Ready(Ok(buf.len())) + } + } else { + trace!("write buf"); + write_buf.extend_from_slice(buf); + Poll::Ready(Ok(buf.len())) + } + } + TlsStreamState::Open(ref mut stm) => stm.poll_write(cx, buf), + TlsStreamState::Closed => { + Poll::Ready(Err(ErrorKind::NotConnected.into())) + } + TlsStreamState::ClosedError(err) => Poll::Ready(Err((*err).into())), + }; } } @@ -412,7 +536,7 @@ impl AsyncWrite for TlsStream { cx: &mut Context<'_>, ) -> Poll> { match &mut self.state { - TlsStreamState::Handshaking(..) => Poll::Ready(Ok(())), + TlsStreamState::Handshaking { .. } => Poll::Ready(Ok(())), TlsStreamState::Open(stm) => stm.poll_flush(cx), TlsStreamState::Closed => { Poll::Ready(Err(ErrorKind::NotConnected.into())) @@ -434,13 +558,15 @@ impl Drop for TlsStream { trace!("dropping {self:?}"); let state = std::mem::replace(&mut self.state, TlsStreamState::Closed); match state { - TlsStreamState::Handshaking(handle, _, buf) => { + TlsStreamState::Handshaking { + handle, write_buf, .. + } => { spawn(async move { trace!("in task"); match handle.await { Ok(Ok((tcp, tls))) => { let mut stm = ConnectionStream::new(tcp, tls); - trace!("{:?}", poll_fn(|cx| stm.poll_write(cx, &buf)).await); + stm.write_buf_fully(&write_buf); trace!("{:?}", poll_fn(|cx| stm.poll_shutdown(cx)).await); } x @ Err(_) => { @@ -581,11 +707,13 @@ mod tests { async fn tls_pair() -> (TlsStream, TlsStream) { let (server, client) = tcp_pair().await; - let server = TlsStream::new_server_side(server, server_config(&[]).into()); + let server = + TlsStream::new_server_side(server, server_config(&[]).into(), None); let client = TlsStream::new_client_side( client, client_config(&[]).into(), "example.com".try_into().unwrap(), + None, ); (server, client) @@ -593,7 +721,8 @@ mod tests { async fn tls_with_tcp_client() -> (TlsStream, TcpStream) { let (server, client) = tcp_pair().await; - let server = TlsStream::new_server_side(server, server_config(&[]).into()); + let server = + TlsStream::new_server_side(server, server_config(&[]).into(), None); (server, client) } @@ -603,6 +732,7 @@ mod tests { client, client_config(&[]).into(), "example.com".try_into().unwrap(), + None, ); (server, client) } @@ -626,12 +756,14 @@ mod tests { let server = TlsStream::new_server_side_test_options( server, server_config(&[]).into(), + None, server_test_options, ); let client = TlsStream::new_client_side_test_options( client, client_config(&[]).into(), "example.com".try_into().unwrap(), + None, client_test_options, ); @@ -640,22 +772,32 @@ mod tests { async fn tls_pair_alpn( server_alpn: &[&str], + server_buffer_size: Option, client_alpn: &[&str], + client_buffer_size: Option, ) -> (TlsStream, TlsStream) { let (server, client) = tcp_pair().await; - let server = - TlsStream::new_server_side(server, server_config(server_alpn).into()); + let server = TlsStream::new_server_side( + server, + server_config(server_alpn).into(), + server_buffer_size, + ); let client = TlsStream::new_client_side( client, client_config(client_alpn).into(), "example.com".try_into().unwrap(), + client_buffer_size, ); (server, client) } - async fn tls_pair_handshake() -> (TlsStream, TlsStream) { - let (mut server, mut client) = tls_pair_alpn(&[], &[]).await; + async fn tls_pair_handshake_buffer_size( + server_buffer_size: Option, + client_buffer_size: Option, + ) -> (TlsStream, TlsStream) { + let (mut server, mut client) = + tls_pair_alpn(&[], server_buffer_size, &[], client_buffer_size).await; let a = spawn(async move { server.handshake().await.unwrap(); server @@ -667,6 +809,10 @@ mod tests { (a.await.unwrap(), b.await.unwrap()) } + async fn tls_pair_handshake() -> (TlsStream, TlsStream) { + tls_pair_handshake_buffer_size(None, None).await + } + fn expect_io_error( e: Result, kind: io::ErrorKind, @@ -722,7 +868,7 @@ mod tests { #[ntest::timeout(60000)] async fn test_client_server_alpn() -> TestResult { let (mut server, mut client) = - tls_pair_alpn(&["a", "b", "c"], &["b"]).await; + tls_pair_alpn(&["a", "b", "c"], None, &["b"], None).await; let a = spawn(async move { let handshake = server.handshake().await.unwrap(); assert_eq!(handshake.alpn, Some("b".as_bytes().to_vec())); @@ -1051,6 +1197,36 @@ mod tests { Ok(()) } + #[tokio::test] + async fn large_transfer_with_buffer_limit() -> TestResult { + const BUF_SIZE: usize = 10 * 1024; + const BUF_COUNT: usize = 1024; + + let (mut server, mut client) = tls_pair_handshake_buffer_size( + BUF_SIZE.try_into().ok(), + BUF_SIZE.try_into().ok(), + ) + .await; + let a = spawn(async move { + // Heap allocate a large buffer and send it + let buf = vec![42; BUF_COUNT * BUF_SIZE]; + server.write_all(&buf).await.unwrap(); + server.shutdown().await.unwrap(); + server.close().await.unwrap(); + }); + let b = spawn(async move { + for _ in 0..BUF_COUNT { + tokio::time::sleep(Duration::from_millis(1)).await; + let mut buf = [0; BUF_SIZE]; + assert_eq!(BUF_SIZE, client.read_exact(&mut buf).await.unwrap()); + } + expect_eof_read(&mut client).await; + }); + a.await?; + b.await?; + Ok(()) + } + #[tokio::test(flavor = "current_thread")] async fn large_transfer_with_shutdown() -> TestResult { const BUF_SIZE: usize = 10 * 1024;