Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: add cancellation to the read #17

Merged
merged 2 commits into from
Nov 15, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 18 additions & 51 deletions src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1001,6 +1001,7 @@ pub(super) mod tests {
use tokio::net::TcpSocket;
use tokio::spawn;
use tokio::sync::Barrier;
use tokio::time::timeout;

type TestResult = Result<(), std::io::Error>;

Expand Down Expand Up @@ -1235,13 +1236,16 @@ pub(super) mod tests {
tls_pair_handshake_buffer_size(None, None).await
}

async fn expect_eof_read(stm: &mut TlsStream) {
async fn expect_eof_read(stm: &mut (impl AsyncReadExt + Unpin)) {
let mut buf = [0_u8; 1];
let e = stm.read(&mut buf).await.expect("Expected no error");
assert_eq!(e, 0, "expected eof");
}

async fn expect_io_error_read(stm: &mut TlsStream, kind: io::ErrorKind) {
async fn expect_io_error_read(
stm: &mut (impl AsyncReadExt + Unpin),
kind: io::ErrorKind,
) {
let mut buf = [0_u8; 1];
let e = stm.read(&mut buf).await.expect_err("Expected error");
assert_eq!(e.kind(), kind);
Expand Down Expand Up @@ -1815,16 +1819,15 @@ pub(super) mod tests {
#[case(true)]
#[case(false)]
#[tokio::test]
async fn large_transfer_with_buffer_limit_split(#[case] swap: bool) -> TestResult {
async fn large_transfer_with_buffer_limit_split(
#[case] swap: bool,
) -> TestResult {
const BUF_SIZE: usize = 10 * 1024;
const BUF_COUNT: usize = 1024;

let (server, client) = tls_pair_buffer_size(
BUF_SIZE.try_into().ok(),
)
.await;
let (server, client) = tls_pair_buffer_size(BUF_SIZE.try_into().ok()).await;

let (server, mut client) = if swap {
let (server, client) = if swap {
(client, server)
} else {
(server, client)
Expand All @@ -1833,25 +1836,29 @@ pub(super) mod tests {
let a = spawn(async move {
let (mut r, mut w) = server.into_split();
let a = spawn(async move {
r.read_u8().await.expect_err("");
timeout(Duration::from_millis(1), r.read_u8())
.await
.expect_err("");
});
let b = spawn(async move {
// Heap allocate a large buffer and send it
let buf = vec![42; BUF_COUNT * BUF_SIZE];
w.write_all(&buf).await.unwrap();
w.flush().await.unwrap();
w.shutdown().await.unwrap();
});

a.await.unwrap();
b.await.unwrap();
});
let b = spawn(async move {
let (mut r, _w) = client.into_split();
for _ in 0..BUF_COUNT {
tokio::time::sleep(Duration::from_millis(1)).await;
let mut buf = [0; BUF_SIZE];
assert_eq!(BUF_SIZE, client.read_exact(&mut buf).await.unwrap());
assert_eq!(BUF_SIZE, r.read_exact(&mut buf).await.unwrap());
}
expect_eof_read(&mut client).await;
expect_eof_read(&mut r).await;
});
a.await?;
b.await?;
Expand Down Expand Up @@ -1968,44 +1975,4 @@ pub(super) mod tests {
assert_eq!(n, expected);
Ok(())
}

// #[tokio::test(flavor = "current_thread")]
// async fn large_transfer_drop_socket_after_flush() -> TestResult {
// const BUF_SIZE: usize = 10 * 1024;
// const BUF_COUNT: usize = 1024;
// const LAST_COUNT: usize = 512;

// let (mut server, mut client) = tls_pair_handshake().await;
// // let (tx, rx) = tokio::sync::oneshot::channel();
// let a = spawn(async move {
// // Heap allocate a large buffer and send it
// let buf = vec![42; BUF_COUNT * BUF_SIZE];
// let (mut rd, mut wr) = server.into_split();
// let rd = spawn(async move { rd.read_u8().await });
// wr.write_all(&buf).await.unwrap();
// wr.flush().await.unwrap();
// // let (mut tcp, _tls) = server.into_inner();
// wr.shutdown().await.unwrap();
// drop(wr);
// // drop(tcp);
// rd.await;
// });
// let b = spawn(async move {
// tokio::time::sleep(Duration::from_millis(109)).await;
// for i in 0..BUF_COUNT {
// tokio::time::sleep(Duration::from_millis(10)).await;
// let mut buf = [0; BUF_SIZE];
// assert_eq!(
// BUF_SIZE,
// client
// .read_exact(&mut buf)
// .await
// .expect(&format!("After reading {i} packets"))
// );
// }
// });
// a.await?;
// b.await?;
// Ok(())
// }
}