diff --git a/iroh-net/src/magicsock.rs b/iroh-net/src/magicsock.rs index ff59e8548bb..6046f819c8e 100644 --- a/iroh-net/src/magicsock.rs +++ b/iroh-net/src/magicsock.rs @@ -1513,6 +1513,11 @@ impl Handle { } self.msock.closing.store(true, Ordering::Relaxed); self.msock.actor_sender.send(ActorMessage::Shutdown).await?; + self.msock.pconn4.close().await; + if let Some(ref conn) = self.msock.pconn6 { + conn.close().await; + } + self.msock.closed.store(true, Ordering::SeqCst); self.msock.direct_addrs.addrs.shutdown(); diff --git a/iroh-net/src/magicsock/udp_conn.rs b/iroh-net/src/magicsock/udp_conn.rs index f622d9d6da7..f537f6265fd 100644 --- a/iroh-net/src/magicsock/udp_conn.rs +++ b/iroh-net/src/magicsock/udp_conn.rs @@ -30,7 +30,7 @@ impl UdpConn { let sock = bind(addr)?; let state = sock.with_socket(|socket| { quinn_udp::UdpSocketState::new(quinn_udp::UdpSockRef::from(socket)) - })?; + })??; Ok(Self { io: Arc::new(sock), @@ -45,7 +45,7 @@ impl UdpConn { // update socket state let new_state = self.io.with_socket(|socket| { quinn_udp::UdpSocketState::new(quinn_udp::UdpSockRef::from(socket)) - })?; + })??; *self.inner.write().unwrap() = new_state; Ok(()) } @@ -59,6 +59,11 @@ impl UdpConn { io: self.io.clone(), }) } + + /// Closes the socket for good + pub async fn close(&self) { + self.io.close().await; + } } impl AsyncUdpSocket for UdpConn { @@ -90,8 +95,9 @@ impl AsyncUdpSocket for UdpConn { }) }); match res { - Ok(()) => return Ok(()), - Err(err) => { + Ok(Ok(())) => return Ok(()), + Err(err) => return Err(err), // closed error + Ok(Err(err)) => { if err.kind() == std::io::ErrorKind::WouldBlock { continue; } @@ -129,14 +135,15 @@ impl AsyncUdpSocket for UdpConn { } match self.io.with_socket(|io| io.poll_recv_ready(cx)) { - Poll::Pending => return Poll::Pending, - Poll::Ready(Ok(())) => {} - Poll::Ready(Err(err)) => match self.io.handle_read_error(err) { + Ok(Poll::Pending) => return Poll::Pending, + Ok(Poll::Ready(Ok(()))) => {} + Ok(Poll::Ready(Err(err))) => match self.io.handle_read_error(err) { Some(err) => return Poll::Ready(Err(err)), None => { continue; } }, + Err(err) => return Poll::Ready(Err(err)), } let res = self.io.try_io(Interest::READABLE, || { @@ -144,7 +151,7 @@ impl AsyncUdpSocket for UdpConn { .with_socket(|io| self.inner.read().unwrap().recv(io.into(), bufs, meta)) }); match res { - Ok(count) => { + Ok(Ok(count)) => { for meta in meta.iter().take(count) { trace!( src = %meta.addr, @@ -156,7 +163,7 @@ impl AsyncUdpSocket for UdpConn { } return Poll::Ready(Ok(count)); } - Err(err) => { + Ok(Err(err)) => { // ignore spurious wakeups if err.kind() == std::io::ErrorKind::WouldBlock { continue; @@ -168,6 +175,7 @@ impl AsyncUdpSocket for UdpConn { } } } + Err(err) => return Poll::Ready(Err(err)), } } } diff --git a/net-tools/netwatch/src/udp.rs b/net-tools/netwatch/src/udp.rs index 06a5c127c5a..12b78be1d55 100644 --- a/net-tools/netwatch/src/udp.rs +++ b/net-tools/netwatch/src/udp.rs @@ -7,7 +7,7 @@ use std::{ task::Poll, }; -use anyhow::{ensure, Context, Result}; +use anyhow::{bail, ensure, Context, Result}; use tracing::{debug, warn}; use super::IpFamily; @@ -82,7 +82,9 @@ impl UdpSocket { // Remove old socket let mut guard = self.socket.write().unwrap(); { - let socket = guard.take().expect("not yet dropped"); + let Some(socket) = guard.take() else { + bail!("cannot rebind closed socket"); + }; drop(socket); } @@ -113,13 +115,18 @@ impl UdpSocket { } /// Use the socket - pub fn with_socket(&self, f: F) -> T + pub fn with_socket(&self, f: F) -> std::io::Result where F: FnOnce(&tokio::net::UdpSocket) -> T, { let guard = self.socket.read().unwrap(); - let socket = guard.as_ref().expect("missing socket"); - f(socket) + let Some(socket) = guard.as_ref() else { + return Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "socket closed", + )); + }; + Ok(f(socket)) } pub fn try_io( @@ -128,7 +135,12 @@ impl UdpSocket { f: impl FnOnce() -> std::io::Result, ) -> std::io::Result { let guard = self.socket.read().unwrap(); - let socket = guard.as_ref().expect("missing socket"); + let Some(socket) = guard.as_ref() else { + return Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "socket closed", + )); + }; socket.try_io(interest, f) } @@ -173,7 +185,13 @@ impl UdpSocket { pub fn connect(&self, addr: SocketAddr) -> std::io::Result<()> { let mut guard = self.socket.write().unwrap(); // dance around to make non async connect work - let socket_tokio = guard.take().expect("missing socket"); + let Some(socket_tokio) = guard.take() else { + return Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "socket closed", + )); + }; + let socket_std = socket_tokio.into_std()?; socket_std.connect(addr)?; let socket_tokio = tokio::net::UdpSocket::from_std(socket_std)?; @@ -184,30 +202,38 @@ impl UdpSocket { /// Returns the local address of this socket. pub fn local_addr(&self) -> std::io::Result { let guard = self.socket.read().unwrap(); - let socket = guard.as_ref().expect("missing socket"); + let Some(socket) = guard.as_ref() else { + return Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "socket closed", + )); + }; + socket.local_addr() } /// Closes the socket, and waits for the underlying `libc::close` call to be finished. - pub async fn close(self) { - let std_sock = self - .socket - .write() - .unwrap() - .take() - .expect("not yet dropped") - .into_std(); - let res = tokio::runtime::Handle::current() - .spawn_blocking(move || { - // Calls libc::close, which can block - drop(std_sock); - }) - .await; - if let Err(err) = res { - warn!("failed to close socket: {:?}", err); + pub async fn close(&self) { + let socket = self.socket.write().unwrap().take(); + if let Some(sock) = socket { + let std_sock = sock.into_std(); + let res = tokio::runtime::Handle::current() + .spawn_blocking(move || { + // Calls libc::close, which can block + drop(std_sock); + }) + .await; + if let Err(err) = res { + warn!("failed to close socket: {:?}", err); + } } } + /// Check if this socket is closed. + pub fn is_closed(&self) -> bool { + self.socket.read().unwrap().is_none() + } + /// Handle potential read errors, updating internal state. /// /// Returns `Some(error)` if the error is fatal otherwise `None. @@ -255,7 +281,12 @@ impl UdpSocket { } } let guard = self.socket.read().unwrap(); - let socket = guard.as_ref().expect("missing socket"); + let Some(socket) = guard.as_ref() else { + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "socket closed", + ))); + }; match socket.poll_send_ready(cx) { Poll::Pending => return Poll::Pending, @@ -302,7 +333,12 @@ impl Future for RecvFut<'_, '_> { } let guard = socket.socket.read().unwrap(); - let inner_socket = guard.as_ref().expect("missing socket"); + let Some(inner_socket) = guard.as_ref() else { + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "socket closed", + ))); + }; match inner_socket.poll_recv_ready(cx) { Poll::Pending => return Poll::Pending, @@ -360,7 +396,12 @@ impl Future for RecvFromFut<'_, '_> { } } let guard = socket.socket.read().unwrap(); - let inner_socket = guard.as_ref().expect("missing socket"); + let Some(inner_socket) = guard.as_ref() else { + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "socket closed", + ))); + }; match inner_socket.poll_recv_ready(cx) { Poll::Pending => return Poll::Pending, @@ -430,7 +471,12 @@ impl Future for SendFut<'_, '_> { } } let guard = self.socket.socket.read().unwrap(); - let socket = guard.as_ref().expect("missing socket"); + let Some(socket) = guard.as_ref() else { + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "socket closed", + ))); + }; match socket.poll_send_ready(c) { Poll::Pending => return Poll::Pending, @@ -488,7 +534,12 @@ impl Future for SendToFut<'_, '_> { } let guard = self.socket.socket.read().unwrap(); - let socket = guard.as_ref().expect("missing socket"); + let Some(socket) = guard.as_ref() else { + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "socket closed", + ))); + }; match socket.poll_send_ready(cx) { Poll::Pending => return Poll::Pending,