Skip to content

Commit

Permalink
add poll_read_early_data
Browse files Browse the repository at this point in the history
  • Loading branch information
tahmid-23 committed May 21, 2024
1 parent 7448a86 commit d152300
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 36 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ exclude = ["/.github", "/examples", "/scripts"]
[dependencies]
tokio = "1.0"
rustls = { version = "0.23.5", default-features = false, features = ["std"] }
pin-project-lite = "0.2.14"
pki-types = { package = "rustls-pki-types", version = "1" }

[features]
Expand Down
56 changes: 56 additions & 0 deletions src/server.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::io;
use std::io::Read;
#[cfg(unix)]
use std::os::unix::io::{AsRawFd, RawFd};
#[cfg(windows)]
Expand Down Expand Up @@ -96,6 +97,61 @@ where
}
}

#[cfg(feature = "early-data")]
impl<IO> TlsStream<IO>
where
IO: AsyncRead + AsyncWrite + Unpin,
{
pub fn poll_read_early_data(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let this = self.get_mut();
let mut stream =
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());

match &this.state {
TlsState::Stream | TlsState::WriteShutdown => {
{
let mut stream = stream.as_mut_pin();

while !stream.eof && stream.session.wants_read() {
match stream.read_io(cx) {
Poll::Ready(Ok(0)) => {
break;
}
Poll::Ready(Ok(_)) => (),
Poll::Pending => {
break;
}
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
}
}
}

if let Some(mut early_data) = stream.session.early_data() {
match early_data.read(buf.initialize_unfilled()) {
Ok(n) => if n > 0 {
buf.advance(n);
return Poll::Ready(Ok(()));
}
Err(err) => return Poll::Ready(Err(err))
}
}

if stream.session.is_handshaking() {
return Poll::Pending;
}

return Poll::Ready(Ok(()));
}
TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(())),
s => unreachable!("server TLS can not hit this state: {:?}", s),
}
}
}

impl<IO> AsyncWrite for TlsStream<IO>
where
IO: AsyncRead + AsyncWrite + Unpin,
Expand Down
82 changes: 46 additions & 36 deletions tests/early-data.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
#![cfg(feature = "early-data")]

use std::io::{self, BufReader, Cursor, Read, Write};
use std::net::{SocketAddr, TcpListener};
use std::io::{self, BufReader, Cursor};
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::thread;

use futures_util::{future::Future, ready};
use rustls::{self, ClientConfig, RootCertStore, ServerConfig, ServerConnection, Stream};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, ReadBuf};
use tokio::net::TcpStream;
use tokio_rustls::{client::TlsStream, TlsConnector};
use pin_project_lite::pin_project;
use rustls::{self, ClientConfig, RootCertStore, ServerConfig};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
use tokio::net::{TcpListener, TcpStream};
use tokio_rustls::{client, server, TlsAcceptor, TlsConnector};

struct Read1<T>(T);

Expand All @@ -33,12 +33,27 @@ impl<T: AsyncRead + Unpin> Future for Read1<T> {
}
}

pin_project! {
struct TlsStreamEarlyWrapper<IO> {
#[pin]
inner: server::TlsStream<IO>
}
}

impl<IO> AsyncRead for TlsStreamEarlyWrapper<IO>
where
IO: AsyncRead + AsyncWrite + Unpin {
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
return self.project().inner.poll_read_early_data(cx, buf);
}
}

async fn send(
config: Arc<ClientConfig>,
addr: SocketAddr,
data: &[u8],
vectored: bool,
) -> io::Result<(TlsStream<TcpStream>, Vec<u8>)> {
) -> io::Result<(client::TlsStream<TcpStream>, Vec<u8>)> {
let connector = TlsConnector::from(config).early_data(true);
let stream = TcpStream::connect(&addr).await?;
let domain = pki_types::ServerName::try_from("foobar.com").unwrap();
Expand Down Expand Up @@ -75,38 +90,33 @@ async fn test_0rtt_impl(vectored: bool) -> io::Result<()> {
.unwrap();
server.max_early_data_size = 8192;
let server = Arc::new(server);
let acceptor = Arc::new(TlsAcceptor::from(server));

let listener = TcpListener::bind("127.0.0.1:0")?;
let listener = TcpListener::bind("127.0.0.1:0").await?;
let server_port = listener.local_addr().unwrap().port();
thread::spawn(move || loop {
let (mut sock, _addr) = listener.accept().unwrap();
tokio::spawn(async move {
loop {
let (mut sock, _addr) = listener.accept().await.unwrap();

let acceptor = acceptor.clone();
tokio::spawn(async move {
let stream = acceptor.accept(&mut sock).await.unwrap();

let server = Arc::clone(&server);
thread::spawn(move || {
let mut conn = ServerConnection::new(server).unwrap();
conn.complete_io(&mut sock).unwrap();
let mut buf = Vec::new();
let mut stream_wrapper = TlsStreamEarlyWrapper{ inner: stream };
stream_wrapper.read_to_end(&mut buf).await.unwrap();
let mut stream = stream_wrapper.inner;
stream.write_all(b"EARLY:").await.unwrap();
stream.write_all(&buf).await.unwrap();

if let Some(mut early_data) = conn.early_data() {
let mut buf = Vec::new();
early_data.read_to_end(&mut buf).unwrap();
let mut stream = Stream::new(&mut conn, &mut sock);
stream.write_all(b"EARLY:").unwrap();
stream.write_all(&buf).unwrap();
}

let mut stream = Stream::new(&mut conn, &mut sock);
stream.write_all(b"LATE:").unwrap();
loop {
let mut buf = [0; 1024];
let n = stream.read(&mut buf).unwrap();
if n == 0 {
conn.send_close_notify();
conn.complete_io(&mut sock).unwrap();
break;
}
stream.write_all(&buf[..n]).unwrap();
}
});
stream.read_to_end(&mut buf).await.unwrap();
stream.write_all(b"LATE:").await.unwrap();
stream.write_all(&buf).await.unwrap();

stream.shutdown().await.unwrap();
});
}
});

let mut chain = BufReader::new(Cursor::new(include_str!("end.chain")));
Expand All @@ -125,7 +135,7 @@ async fn test_0rtt_impl(vectored: bool) -> io::Result<()> {

let (io, buf) = send(config.clone(), addr, b"hello", vectored).await?;
assert!(!io.get_ref().1.is_early_data_accepted());
assert_eq!("LATE:hello", String::from_utf8_lossy(&buf));
assert_eq!("EARLY:LATE:hello", String::from_utf8_lossy(&buf));

let (io, buf) = send(config, addr, b"world!", vectored).await?;
assert!(io.get_ref().1.is_early_data_accepted());
Expand Down

0 comments on commit d152300

Please sign in to comment.