From d152300a7b2ddbec1fd91e00398542cda709eb57 Mon Sep 17 00:00:00 2001 From: tahmid-23 <60953955+tahmid-23@users.noreply.github.com> Date: Mon, 20 May 2024 20:13:55 -0400 Subject: [PATCH 1/2] add poll_read_early_data --- Cargo.toml | 1 + src/server.rs | 56 +++++++++++++++++++++++++++++++ tests/early-data.rs | 82 +++++++++++++++++++++++++-------------------- 3 files changed, 103 insertions(+), 36 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 84c356f4..4d8e7f48 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ exclude = ["/.github", "/examples", "/scripts"] [dependencies] tokio = "1.0" rustls = { version = "0.23.5", default-features = false, features = ["std"] } +pin-project-lite = "0.2.14" pki-types = { package = "rustls-pki-types", version = "1" } [features] diff --git a/src/server.rs b/src/server.rs index 02debac3..4ce4f3d9 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,4 +1,5 @@ use std::io; +use std::io::Read; #[cfg(unix)] use std::os::unix::io::{AsRawFd, RawFd}; #[cfg(windows)] @@ -96,6 +97,61 @@ where } } +#[cfg(feature = "early-data")] +impl TlsStream +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + pub fn poll_read_early_data( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let this = self.get_mut(); + let mut stream = + Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); + + match &this.state { + TlsState::Stream | TlsState::WriteShutdown => { + { + let mut stream = stream.as_mut_pin(); + + while !stream.eof && stream.session.wants_read() { + match stream.read_io(cx) { + Poll::Ready(Ok(0)) => { + break; + } + Poll::Ready(Ok(_)) => (), + Poll::Pending => { + break; + } + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + } + } + } + + if let Some(mut early_data) = stream.session.early_data() { + match early_data.read(buf.initialize_unfilled()) { + Ok(n) => if n > 0 { + buf.advance(n); + return Poll::Ready(Ok(())); + } + Err(err) => return Poll::Ready(Err(err)) + } + } + + if stream.session.is_handshaking() { + return Poll::Pending; + } + + return Poll::Ready(Ok(())); + } + TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(())), + s => unreachable!("server TLS can not hit this state: {:?}", s), + } + } +} + impl AsyncWrite for TlsStream where IO: AsyncRead + AsyncWrite + Unpin, diff --git a/tests/early-data.rs b/tests/early-data.rs index 42faad32..c7540548 100644 --- a/tests/early-data.rs +++ b/tests/early-data.rs @@ -1,17 +1,17 @@ #![cfg(feature = "early-data")] -use std::io::{self, BufReader, Cursor, Read, Write}; -use std::net::{SocketAddr, TcpListener}; +use std::io::{self, BufReader, Cursor}; +use std::net::SocketAddr; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use std::thread; use futures_util::{future::Future, ready}; -use rustls::{self, ClientConfig, RootCertStore, ServerConfig, ServerConnection, Stream}; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, ReadBuf}; -use tokio::net::TcpStream; -use tokio_rustls::{client::TlsStream, TlsConnector}; +use pin_project_lite::pin_project; +use rustls::{self, ClientConfig, RootCertStore, ServerConfig}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; +use tokio::net::{TcpListener, TcpStream}; +use tokio_rustls::{client, server, TlsAcceptor, TlsConnector}; struct Read1(T); @@ -33,12 +33,27 @@ impl Future for Read1 { } } +pin_project! { + struct TlsStreamEarlyWrapper { + #[pin] + inner: server::TlsStream + } +} + +impl AsyncRead for TlsStreamEarlyWrapper +where + IO: AsyncRead + AsyncWrite + Unpin { + fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { + return self.project().inner.poll_read_early_data(cx, buf); + } +} + async fn send( config: Arc, addr: SocketAddr, data: &[u8], vectored: bool, -) -> io::Result<(TlsStream, Vec)> { +) -> io::Result<(client::TlsStream, Vec)> { let connector = TlsConnector::from(config).early_data(true); let stream = TcpStream::connect(&addr).await?; let domain = pki_types::ServerName::try_from("foobar.com").unwrap(); @@ -75,38 +90,33 @@ async fn test_0rtt_impl(vectored: bool) -> io::Result<()> { .unwrap(); server.max_early_data_size = 8192; let server = Arc::new(server); + let acceptor = Arc::new(TlsAcceptor::from(server)); - let listener = TcpListener::bind("127.0.0.1:0")?; + let listener = TcpListener::bind("127.0.0.1:0").await?; let server_port = listener.local_addr().unwrap().port(); - thread::spawn(move || loop { - let (mut sock, _addr) = listener.accept().unwrap(); + tokio::spawn(async move { + loop { + let (mut sock, _addr) = listener.accept().await.unwrap(); + + let acceptor = acceptor.clone(); + tokio::spawn(async move { + let stream = acceptor.accept(&mut sock).await.unwrap(); - let server = Arc::clone(&server); - thread::spawn(move || { - let mut conn = ServerConnection::new(server).unwrap(); - conn.complete_io(&mut sock).unwrap(); + let mut buf = Vec::new(); + let mut stream_wrapper = TlsStreamEarlyWrapper{ inner: stream }; + stream_wrapper.read_to_end(&mut buf).await.unwrap(); + let mut stream = stream_wrapper.inner; + stream.write_all(b"EARLY:").await.unwrap(); + stream.write_all(&buf).await.unwrap(); - if let Some(mut early_data) = conn.early_data() { let mut buf = Vec::new(); - early_data.read_to_end(&mut buf).unwrap(); - let mut stream = Stream::new(&mut conn, &mut sock); - stream.write_all(b"EARLY:").unwrap(); - stream.write_all(&buf).unwrap(); - } - - let mut stream = Stream::new(&mut conn, &mut sock); - stream.write_all(b"LATE:").unwrap(); - loop { - let mut buf = [0; 1024]; - let n = stream.read(&mut buf).unwrap(); - if n == 0 { - conn.send_close_notify(); - conn.complete_io(&mut sock).unwrap(); - break; - } - stream.write_all(&buf[..n]).unwrap(); - } - }); + stream.read_to_end(&mut buf).await.unwrap(); + stream.write_all(b"LATE:").await.unwrap(); + stream.write_all(&buf).await.unwrap(); + + stream.shutdown().await.unwrap(); + }); + } }); let mut chain = BufReader::new(Cursor::new(include_str!("end.chain"))); @@ -125,7 +135,7 @@ async fn test_0rtt_impl(vectored: bool) -> io::Result<()> { let (io, buf) = send(config.clone(), addr, b"hello", vectored).await?; assert!(!io.get_ref().1.is_early_data_accepted()); - assert_eq!("LATE:hello", String::from_utf8_lossy(&buf)); + assert_eq!("EARLY:LATE:hello", String::from_utf8_lossy(&buf)); let (io, buf) = send(config, addr, b"world!", vectored).await?; assert!(io.get_ref().1.is_early_data_accepted()); From 2565c62510f54d15376daf3ee9adf172409af306 Mon Sep 17 00:00:00 2001 From: tahmid-23 <60953955+tahmid-23@users.noreply.github.com> Date: Tue, 21 May 2024 13:33:27 -0400 Subject: [PATCH 2/2] fix CI --- Cargo.toml | 2 +- src/server.rs | 15 +++++++++------ tests/early-data.rs | 11 ++++++++--- 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4d8e7f48..d3b32192 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,6 @@ exclude = ["/.github", "/examples", "/scripts"] [dependencies] tokio = "1.0" rustls = { version = "0.23.5", default-features = false, features = ["std"] } -pin-project-lite = "0.2.14" pki-types = { package = "rustls-pki-types", version = "1" } [features] @@ -35,3 +34,4 @@ futures-util = "0.3.1" lazy_static = "1.1" webpki-roots = "0.26" rustls-pemfile = "2" +pin-project-lite = "0.2.14" diff --git a/src/server.rs b/src/server.rs index 4ce4f3d9..f371f466 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,5 +1,4 @@ use std::io; -use std::io::Read; #[cfg(unix)] use std::os::unix::io::{AsRawFd, RawFd}; #[cfg(windows)] @@ -107,6 +106,8 @@ where cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { + use std::io::Read; + let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); @@ -132,11 +133,13 @@ where if let Some(mut early_data) = stream.session.early_data() { match early_data.read(buf.initialize_unfilled()) { - Ok(n) => if n > 0 { - buf.advance(n); - return Poll::Ready(Ok(())); + Ok(n) => { + if n > 0 { + buf.advance(n); + return Poll::Ready(Ok(())); + } } - Err(err) => return Poll::Ready(Err(err)) + Err(err) => return Poll::Ready(Err(err)), } } @@ -144,7 +147,7 @@ where return Poll::Pending; } - return Poll::Ready(Ok(())); + Poll::Ready(Ok(())) } TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(())), s => unreachable!("server TLS can not hit this state: {:?}", s), diff --git a/tests/early-data.rs b/tests/early-data.rs index c7540548..aa8e070d 100644 --- a/tests/early-data.rs +++ b/tests/early-data.rs @@ -42,8 +42,13 @@ pin_project! { impl AsyncRead for TlsStreamEarlyWrapper where - IO: AsyncRead + AsyncWrite + Unpin { - fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { + IO: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { return self.project().inner.poll_read_early_data(cx, buf); } } @@ -103,7 +108,7 @@ async fn test_0rtt_impl(vectored: bool) -> io::Result<()> { let stream = acceptor.accept(&mut sock).await.unwrap(); let mut buf = Vec::new(); - let mut stream_wrapper = TlsStreamEarlyWrapper{ inner: stream }; + let mut stream_wrapper = TlsStreamEarlyWrapper { inner: stream }; stream_wrapper.read_to_end(&mut buf).await.unwrap(); let mut stream = stream_wrapper.inner; stream.write_all(b"EARLY:").await.unwrap();