Skip to content

Commit

Permalink
Use nonblocking_drop everywhere
Browse files Browse the repository at this point in the history
  • Loading branch information
mmastrac committed Nov 15, 2023
1 parent 5db23e3 commit 1b5dd80
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 29 deletions.
13 changes: 5 additions & 8 deletions src/connection_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,11 @@ use std::io;
use std::io::ErrorKind;
use std::io::Read;
use std::io::Write;
use std::pin::Pin;
use std::ptr::NonNull;
use std::sync::Arc;
use std::task::ready;
use std::task::Context;
use std::task::Poll;
use std::task::Waker;
use tokio::io::AsyncWrite;
use tokio::io::ReadBuf;
use tokio::net::TcpStream;

Expand Down Expand Up @@ -476,7 +473,7 @@ impl ConnectionStream {
#[cfg(test)]
impl tokio::io::AsyncRead for ConnectionStream {
fn poll_read(
self: Pin<&mut Self>,
self: std::pin::Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Expand All @@ -487,15 +484,15 @@ impl tokio::io::AsyncRead for ConnectionStream {
#[cfg(test)]
impl tokio::io::AsyncWrite for ConnectionStream {
fn poll_write(
self: Pin<&mut Self>,
self: std::pin::Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
ConnectionStream::poll_write(self.get_mut(), cx, buf)
}

fn poll_write_vectored(
self: Pin<&mut Self>,
self: std::pin::Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[futures::io::IoSlice<'_>],
) -> Poll<Result<usize, io::Error>> {
Expand All @@ -509,14 +506,14 @@ impl tokio::io::AsyncWrite for ConnectionStream {
}

fn poll_flush(
self: Pin<&mut Self>,
self: std::pin::Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), io::Error>> {
ConnectionStream::poll_flush(self.get_mut(), cx)
}

fn poll_shutdown(
self: Pin<&mut Self>,
self: std::pin::Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), io::Error>> {
ConnectionStream::poll_shutdown(self.get_mut(), cx)
Expand Down
41 changes: 20 additions & 21 deletions src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,8 @@ impl TlsStream {
let mut stm = ConnectionStream::new(tcp, tls);
poll_fn(|cx| stm.poll_write(cx, &buf)).await?;
poll_fn(|cx| stm.poll_shutdown(cx)).await?;
let (tcp, _) = stm.into_inner();
nonblocking_tcp_drop(tcp);
}
Err(err) => {
if err.is_panic() {
Expand All @@ -557,6 +559,8 @@ impl TlsStream {
}
TlsStreamState::Open(mut stm) => {
poll_fn(|cx| stm.poll_shutdown(cx)).await?;
let (tcp, _) = stm.into_inner();
nonblocking_tcp_drop(tcp);
}
TlsStreamState::Closed | TlsStreamState::ClosedError(_) => {
// Nothing
Expand Down Expand Up @@ -607,6 +611,20 @@ async fn send_handshake(
res
}

fn nonblocking_tcp_drop(tcp: TcpStream) {
if let Ok(tcp) = tcp.into_std() {
spawn_blocking(move || {
// TODO(mmastrac): this should not be necessary with SO_LINGER but I cannot get that working
trace!("in drop tcp task");
// Drop the TCP stream here just in case close() blocks
_ = tcp.set_nonblocking(false);
sleep(Duration::from_secs(1));
drop(tcp);
trace!("done drop tcp task");
});
}
}

impl AsyncRead for TlsStream {
fn poll_read(
mut self: Pin<&mut Self>,
Expand Down Expand Up @@ -840,17 +858,7 @@ impl Drop for TlsStream {
let res = poll_fn(|cx| stm.poll_shutdown(cx)).await;
trace!("shutdown handshake {:?}", res);
let (tcp, _) = stm.into_inner();
if let Ok(tcp) = tcp.into_std() {
spawn_blocking(move || {
// TODO(mmastrac): this should not be necessary with SO_LINGER but I cannot get that working
trace!("in drop tcp task");
// Drop the TCP stream here just in case close() blocks
_ = tcp.set_nonblocking(false);
sleep(Duration::from_secs(1));
drop(tcp);
trace!("done drop tcp task");
});
}
nonblocking_tcp_drop(tcp);
}
x @ Err(_) => {
trace!("{x:?}");
Expand All @@ -868,16 +876,7 @@ impl Drop for TlsStream {
let res = poll_fn(|cx| stm.poll_shutdown(cx)).await;
trace!("shutdown open {:?}", res);
let (tcp, _) = stm.into_inner();
if let Ok(tcp) = tcp.into_std() {
spawn_blocking(move || {
trace!("in drop tcp task");
// Drop the TCP stream here just in case close() blocks
_ = tcp.set_nonblocking(false);
sleep(Duration::from_secs(1));
drop(tcp);
trace!("done drop tcp task");
});
}
nonblocking_tcp_drop(tcp);
trace!("done drop task");
});
}
Expand Down

0 comments on commit 1b5dd80

Please sign in to comment.