diff --git a/Cargo.toml b/Cargo.toml index 84c356f..d3b3219 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,3 +34,4 @@ futures-util = "0.3.1" lazy_static = "1.1" webpki-roots = "0.26" rustls-pemfile = "2" +pin-project-lite = "0.2.14" diff --git a/src/server.rs b/src/server.rs index 02debac..f371f46 100644 --- a/src/server.rs +++ b/src/server.rs @@ -96,6 +96,65 @@ where } } +#[cfg(feature = "early-data")] +impl TlsStream +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + pub fn poll_read_early_data( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + use std::io::Read; + + let this = self.get_mut(); + let mut stream = + Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); + + match &this.state { + TlsState::Stream | TlsState::WriteShutdown => { + { + let mut stream = stream.as_mut_pin(); + + while !stream.eof && stream.session.wants_read() { + match stream.read_io(cx) { + Poll::Ready(Ok(0)) => { + break; + } + Poll::Ready(Ok(_)) => (), + Poll::Pending => { + break; + } + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + } + } + } + + if let Some(mut early_data) = stream.session.early_data() { + match early_data.read(buf.initialize_unfilled()) { + Ok(n) => { + if n > 0 { + buf.advance(n); + return Poll::Ready(Ok(())); + } + } + Err(err) => return Poll::Ready(Err(err)), + } + } + + if stream.session.is_handshaking() { + return Poll::Pending; + } + + Poll::Ready(Ok(())) + } + TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(())), + s => unreachable!("server TLS can not hit this state: {:?}", s), + } + } +} + impl AsyncWrite for TlsStream where IO: AsyncRead + AsyncWrite + Unpin, diff --git a/tests/early-data.rs b/tests/early-data.rs index 42faad3..aa8e070 100644 --- a/tests/early-data.rs +++ b/tests/early-data.rs @@ -1,17 +1,17 @@ #![cfg(feature = "early-data")] -use std::io::{self, BufReader, Cursor, Read, Write}; -use std::net::{SocketAddr, TcpListener}; +use std::io::{self, BufReader, Cursor}; +use std::net::SocketAddr; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use std::thread; use futures_util::{future::Future, ready}; -use rustls::{self, ClientConfig, RootCertStore, ServerConfig, ServerConnection, Stream}; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, ReadBuf}; -use tokio::net::TcpStream; -use tokio_rustls::{client::TlsStream, TlsConnector}; +use pin_project_lite::pin_project; +use rustls::{self, ClientConfig, RootCertStore, ServerConfig}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; +use tokio::net::{TcpListener, TcpStream}; +use tokio_rustls::{client, server, TlsAcceptor, TlsConnector}; struct Read1(T); @@ -33,12 +33,32 @@ impl Future for Read1 { } } +pin_project! { + struct TlsStreamEarlyWrapper { + #[pin] + inner: server::TlsStream + } +} + +impl AsyncRead for TlsStreamEarlyWrapper +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + return self.project().inner.poll_read_early_data(cx, buf); + } +} + async fn send( config: Arc, addr: SocketAddr, data: &[u8], vectored: bool, -) -> io::Result<(TlsStream, Vec)> { +) -> io::Result<(client::TlsStream, Vec)> { let connector = TlsConnector::from(config).early_data(true); let stream = TcpStream::connect(&addr).await?; let domain = pki_types::ServerName::try_from("foobar.com").unwrap(); @@ -75,38 +95,33 @@ async fn test_0rtt_impl(vectored: bool) -> io::Result<()> { .unwrap(); server.max_early_data_size = 8192; let server = Arc::new(server); + let acceptor = Arc::new(TlsAcceptor::from(server)); - let listener = TcpListener::bind("127.0.0.1:0")?; + let listener = TcpListener::bind("127.0.0.1:0").await?; let server_port = listener.local_addr().unwrap().port(); - thread::spawn(move || loop { - let (mut sock, _addr) = listener.accept().unwrap(); + tokio::spawn(async move { + loop { + let (mut sock, _addr) = listener.accept().await.unwrap(); + + let acceptor = acceptor.clone(); + tokio::spawn(async move { + let stream = acceptor.accept(&mut sock).await.unwrap(); - let server = Arc::clone(&server); - thread::spawn(move || { - let mut conn = ServerConnection::new(server).unwrap(); - conn.complete_io(&mut sock).unwrap(); + let mut buf = Vec::new(); + let mut stream_wrapper = TlsStreamEarlyWrapper { inner: stream }; + stream_wrapper.read_to_end(&mut buf).await.unwrap(); + let mut stream = stream_wrapper.inner; + stream.write_all(b"EARLY:").await.unwrap(); + stream.write_all(&buf).await.unwrap(); - if let Some(mut early_data) = conn.early_data() { let mut buf = Vec::new(); - early_data.read_to_end(&mut buf).unwrap(); - let mut stream = Stream::new(&mut conn, &mut sock); - stream.write_all(b"EARLY:").unwrap(); - stream.write_all(&buf).unwrap(); - } - - let mut stream = Stream::new(&mut conn, &mut sock); - stream.write_all(b"LATE:").unwrap(); - loop { - let mut buf = [0; 1024]; - let n = stream.read(&mut buf).unwrap(); - if n == 0 { - conn.send_close_notify(); - conn.complete_io(&mut sock).unwrap(); - break; - } - stream.write_all(&buf[..n]).unwrap(); - } - }); + stream.read_to_end(&mut buf).await.unwrap(); + stream.write_all(b"LATE:").await.unwrap(); + stream.write_all(&buf).await.unwrap(); + + stream.shutdown().await.unwrap(); + }); + } }); let mut chain = BufReader::new(Cursor::new(include_str!("end.chain"))); @@ -125,7 +140,7 @@ async fn test_0rtt_impl(vectored: bool) -> io::Result<()> { let (io, buf) = send(config.clone(), addr, b"hello", vectored).await?; assert!(!io.get_ref().1.is_early_data_accepted()); - assert_eq!("LATE:hello", String::from_utf8_lossy(&buf)); + assert_eq!("EARLY:LATE:hello", String::from_utf8_lossy(&buf)); let (io, buf) = send(config, addr, b"world!", vectored).await?; assert!(io.get_ref().1.is_early_data_accepted());