From 1b5dd80fd259d61ba8336c0773352e526d025503 Mon Sep 17 00:00:00 2001 From: Matt Mastracci Date: Wed, 15 Nov 2023 16:41:01 -0700 Subject: [PATCH] Use nonblocking_drop everywhere --- src/connection_stream.rs | 13 +++++-------- src/stream.rs | 41 ++++++++++++++++++++-------------------- 2 files changed, 25 insertions(+), 29 deletions(-) diff --git a/src/connection_stream.rs b/src/connection_stream.rs index 3a85ec1..2dd5872 100644 --- a/src/connection_stream.rs +++ b/src/connection_stream.rs @@ -6,14 +6,11 @@ use std::io; use std::io::ErrorKind; use std::io::Read; use std::io::Write; -use std::pin::Pin; -use std::ptr::NonNull; use std::sync::Arc; use std::task::ready; use std::task::Context; use std::task::Poll; use std::task::Waker; -use tokio::io::AsyncWrite; use tokio::io::ReadBuf; use tokio::net::TcpStream; @@ -476,7 +473,7 @@ impl ConnectionStream { #[cfg(test)] impl tokio::io::AsyncRead for ConnectionStream { fn poll_read( - self: Pin<&mut Self>, + self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { @@ -487,7 +484,7 @@ impl tokio::io::AsyncRead for ConnectionStream { #[cfg(test)] impl tokio::io::AsyncWrite for ConnectionStream { fn poll_write( - self: Pin<&mut Self>, + self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { @@ -495,7 +492,7 @@ impl tokio::io::AsyncWrite for ConnectionStream { } fn poll_write_vectored( - self: Pin<&mut Self>, + self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[futures::io::IoSlice<'_>], ) -> Poll> { @@ -509,14 +506,14 @@ impl tokio::io::AsyncWrite for ConnectionStream { } fn poll_flush( - self: Pin<&mut Self>, + self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { ConnectionStream::poll_flush(self.get_mut(), cx) } fn poll_shutdown( - self: Pin<&mut Self>, + self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { ConnectionStream::poll_shutdown(self.get_mut(), cx) diff --git a/src/stream.rs b/src/stream.rs index 3817350..2f98366 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -541,6 +541,8 @@ impl TlsStream { let mut stm = ConnectionStream::new(tcp, tls); poll_fn(|cx| stm.poll_write(cx, &buf)).await?; poll_fn(|cx| stm.poll_shutdown(cx)).await?; + let (tcp, _) = stm.into_inner(); + nonblocking_tcp_drop(tcp); } Err(err) => { if err.is_panic() { @@ -557,6 +559,8 @@ impl TlsStream { } TlsStreamState::Open(mut stm) => { poll_fn(|cx| stm.poll_shutdown(cx)).await?; + let (tcp, _) = stm.into_inner(); + nonblocking_tcp_drop(tcp); } TlsStreamState::Closed | TlsStreamState::ClosedError(_) => { // Nothing @@ -607,6 +611,20 @@ async fn send_handshake( res } +fn nonblocking_tcp_drop(tcp: TcpStream) { + if let Ok(tcp) = tcp.into_std() { + spawn_blocking(move || { + // TODO(mmastrac): this should not be necessary with SO_LINGER but I cannot get that working + trace!("in drop tcp task"); + // Drop the TCP stream here just in case close() blocks + _ = tcp.set_nonblocking(false); + sleep(Duration::from_secs(1)); + drop(tcp); + trace!("done drop tcp task"); + }); + } +} + impl AsyncRead for TlsStream { fn poll_read( mut self: Pin<&mut Self>, @@ -840,17 +858,7 @@ impl Drop for TlsStream { let res = poll_fn(|cx| stm.poll_shutdown(cx)).await; trace!("shutdown handshake {:?}", res); let (tcp, _) = stm.into_inner(); - if let Ok(tcp) = tcp.into_std() { - spawn_blocking(move || { - // TODO(mmastrac): this should not be necessary with SO_LINGER but I cannot get that working - trace!("in drop tcp task"); - // Drop the TCP stream here just in case close() blocks - _ = tcp.set_nonblocking(false); - sleep(Duration::from_secs(1)); - drop(tcp); - trace!("done drop tcp task"); - }); - } + nonblocking_tcp_drop(tcp); } x @ Err(_) => { trace!("{x:?}"); @@ -868,16 +876,7 @@ impl Drop for TlsStream { let res = poll_fn(|cx| stm.poll_shutdown(cx)).await; trace!("shutdown open {:?}", res); let (tcp, _) = stm.into_inner(); - if let Ok(tcp) = tcp.into_std() { - spawn_blocking(move || { - trace!("in drop tcp task"); - // Drop the TCP stream here just in case close() blocks - _ = tcp.set_nonblocking(false); - sleep(Duration::from_secs(1)); - drop(tcp); - trace!("done drop tcp task"); - }); - } + nonblocking_tcp_drop(tcp); trace!("done drop task"); }); }