Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add async way to read early data from TLSAcceptor #73

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ exclude = ["/.github", "/examples", "/scripts"]
[dependencies]
tokio = "1.0"
rustls = { version = "0.23.5", default-features = false, features = ["std"] }
pin-project-lite = "0.2.14"
tahmid-23 marked this conversation as resolved.
Show resolved Hide resolved
pki-types = { package = "rustls-pki-types", version = "1" }

[features]
Expand Down
56 changes: 56 additions & 0 deletions src/server.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::io;
use std::io::Read;
#[cfg(unix)]
use std::os::unix::io::{AsRawFd, RawFd};
#[cfg(windows)]
Expand Down Expand Up @@ -96,6 +97,61 @@ where
}
}

#[cfg(feature = "early-data")]
impl<IO> TlsStream<IO>
where
IO: AsyncRead + AsyncWrite + Unpin,
{
pub fn poll_read_early_data(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let this = self.get_mut();
let mut stream =
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());

match &this.state {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: would be nice to restructure this code more like this:

match &this.state {
    TlsState::Stream | TlsState::WriteShutdown => {}
    TlsState::ReadShutdown | TlsState::FullyShutdown => return Poll::Ready(Ok(())),
    s => unreachable!("server TLS can not hit this state: {:?}", s),
}

let mut stream = stream.as_mut_pin();
..

TlsState::Stream | TlsState::WriteShutdown => {
{
let mut stream = stream.as_mut_pin();

while !stream.eof && stream.session.wants_read() {
match stream.read_io(cx) {
Poll::Ready(Ok(0)) => {
break;
}
Poll::Ready(Ok(_)) => (),
Poll::Pending => {
break;
}
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
}
}
}
Comment on lines +120 to +132
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code looks it's duplicated from somewhere else. Can we abstract over it instead? If not, why not?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

some of the branches update different variables, I don't really know how to extract it


if let Some(mut early_data) = stream.session.early_data() {
match early_data.read(buf.initialize_unfilled()) {
Ok(n) => if n > 0 {
buf.advance(n);
return Poll::Ready(Ok(()));
}
Err(err) => return Poll::Ready(Err(err))
}
}

if stream.session.is_handshaking() {
return Poll::Pending;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we return pending here? or should I handshake? I'm worried about a hang due to missing wake.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think there would be a missing wake? since handshake is done once the client sends it's finished, and this code depends on read. I'm not fully sure though.

but now that I think about it, the test code doesn't cover this. since I do acceptor.accept(&mut sock).await, the handshake already finishes, so reading the early data "async" effectively did nothing.
should I change some of the server stream logic to use the early data state (kind of like client)? acceptor.accept would no longer wait for the handshake to complete, and writing/reading would need to wait on the handshake to complete.
if I do that, then maybe it would make sense to further the handshake here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we got Pending from stream.read_io() it's fine to return Pending here, can is_handshaking() return true after read_io() has yielded Ok(0)?

}

return Poll::Ready(Ok(()));
}
TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(())),
s => unreachable!("server TLS can not hit this state: {:?}", s),
}
}
}

impl<IO> AsyncWrite for TlsStream<IO>
where
IO: AsyncRead + AsyncWrite + Unpin,
Expand Down
82 changes: 46 additions & 36 deletions tests/early-data.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
#![cfg(feature = "early-data")]

use std::io::{self, BufReader, Cursor, Read, Write};
use std::net::{SocketAddr, TcpListener};
use std::io::{self, BufReader, Cursor};
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::thread;

use futures_util::{future::Future, ready};
use rustls::{self, ClientConfig, RootCertStore, ServerConfig, ServerConnection, Stream};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, ReadBuf};
use tokio::net::TcpStream;
use tokio_rustls::{client::TlsStream, TlsConnector};
use pin_project_lite::pin_project;
use rustls::{self, ClientConfig, RootCertStore, ServerConfig};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
use tokio::net::{TcpListener, TcpStream};
use tokio_rustls::{client, server, TlsAcceptor, TlsConnector};

struct Read1<T>(T);

Expand All @@ -33,12 +33,27 @@ impl<T: AsyncRead + Unpin> Future for Read1<T> {
}
}

pin_project! {
struct TlsStreamEarlyWrapper<IO> {
#[pin]
inner: server::TlsStream<IO>
}
}

impl<IO> AsyncRead for TlsStreamEarlyWrapper<IO>
where
IO: AsyncRead + AsyncWrite + Unpin {
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
return self.project().inner.poll_read_early_data(cx, buf);
tahmid-23 marked this conversation as resolved.
Show resolved Hide resolved
}
}

async fn send(
config: Arc<ClientConfig>,
addr: SocketAddr,
data: &[u8],
vectored: bool,
) -> io::Result<(TlsStream<TcpStream>, Vec<u8>)> {
) -> io::Result<(client::TlsStream<TcpStream>, Vec<u8>)> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change seems unrelated? If so, prefer to avoid it.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I gave a more qualified name because I use server::TlsStream as well. should I still revert it?

let connector = TlsConnector::from(config).early_data(true);
let stream = TcpStream::connect(&addr).await?;
let domain = pki_types::ServerName::try_from("foobar.com").unwrap();
Expand Down Expand Up @@ -75,38 +90,33 @@ async fn test_0rtt_impl(vectored: bool) -> io::Result<()> {
.unwrap();
server.max_early_data_size = 8192;
let server = Arc::new(server);
let acceptor = Arc::new(TlsAcceptor::from(server));

let listener = TcpListener::bind("127.0.0.1:0")?;
let listener = TcpListener::bind("127.0.0.1:0").await?;
let server_port = listener.local_addr().unwrap().port();
thread::spawn(move || loop {
let (mut sock, _addr) = listener.accept().unwrap();
tokio::spawn(async move {
loop {
let (mut sock, _addr) = listener.accept().await.unwrap();

let acceptor = acceptor.clone();
tokio::spawn(async move {
let stream = acceptor.accept(&mut sock).await.unwrap();

let server = Arc::clone(&server);
thread::spawn(move || {
let mut conn = ServerConnection::new(server).unwrap();
conn.complete_io(&mut sock).unwrap();
let mut buf = Vec::new();
let mut stream_wrapper = TlsStreamEarlyWrapper{ inner: stream };
stream_wrapper.read_to_end(&mut buf).await.unwrap();
let mut stream = stream_wrapper.inner;
stream.write_all(b"EARLY:").await.unwrap();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should there be a test that explicitly tests for being able to read actual data off of the early data stream?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't follow. This is already reading the early data?

stream.write_all(&buf).await.unwrap();

if let Some(mut early_data) = conn.early_data() {
let mut buf = Vec::new();
early_data.read_to_end(&mut buf).unwrap();
let mut stream = Stream::new(&mut conn, &mut sock);
stream.write_all(b"EARLY:").unwrap();
stream.write_all(&buf).unwrap();
}

let mut stream = Stream::new(&mut conn, &mut sock);
stream.write_all(b"LATE:").unwrap();
loop {
let mut buf = [0; 1024];
let n = stream.read(&mut buf).unwrap();
if n == 0 {
conn.send_close_notify();
conn.complete_io(&mut sock).unwrap();
break;
}
stream.write_all(&buf[..n]).unwrap();
}
});
stream.read_to_end(&mut buf).await.unwrap();
stream.write_all(b"LATE:").await.unwrap();
stream.write_all(&buf).await.unwrap();

stream.shutdown().await.unwrap();
});
}
});

let mut chain = BufReader::new(Cursor::new(include_str!("end.chain")));
Expand All @@ -125,7 +135,7 @@ async fn test_0rtt_impl(vectored: bool) -> io::Result<()> {

let (io, buf) = send(config.clone(), addr, b"hello", vectored).await?;
assert!(!io.get_ref().1.is_early_data_accepted());
assert_eq!("LATE:hello", String::from_utf8_lossy(&buf));
assert_eq!("EARLY:LATE:hello", String::from_utf8_lossy(&buf));

let (io, buf) = send(config, addr, b"world!", vectored).await?;
assert!(io.get_ref().1.is_early_data_accepted());
Expand Down
Loading