Skip to content

Commit

Permalink
actually shut down sockets
Browse files Browse the repository at this point in the history
  • Loading branch information
dignifiedquire committed Nov 21, 2024
1 parent 283f785 commit 0b5966c
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 38 deletions.
5 changes: 5 additions & 0 deletions iroh-net/src/magicsock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1513,6 +1513,11 @@ impl Handle {
}
self.msock.closing.store(true, Ordering::Relaxed);
self.msock.actor_sender.send(ActorMessage::Shutdown).await?;
self.msock.pconn4.close().await;
if let Some(ref conn) = self.msock.pconn6 {
conn.close().await;
}

self.msock.closed.store(true, Ordering::SeqCst);
self.msock.direct_addrs.addrs.shutdown();

Expand Down
26 changes: 17 additions & 9 deletions iroh-net/src/magicsock/udp_conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ impl UdpConn {
let sock = bind(addr)?;
let state = sock.with_socket(|socket| {
quinn_udp::UdpSocketState::new(quinn_udp::UdpSockRef::from(socket))
})?;
})??;

Ok(Self {
io: Arc::new(sock),
Expand All @@ -45,7 +45,7 @@ impl UdpConn {
// update socket state
let new_state = self.io.with_socket(|socket| {
quinn_udp::UdpSocketState::new(quinn_udp::UdpSockRef::from(socket))
})?;
})??;
*self.inner.write().unwrap() = new_state;
Ok(())
}
Expand All @@ -59,6 +59,11 @@ impl UdpConn {
io: self.io.clone(),
})
}

/// Closes the socket for good
pub async fn close(&self) {
self.io.close().await;
}
}

impl AsyncUdpSocket for UdpConn {
Expand Down Expand Up @@ -90,8 +95,9 @@ impl AsyncUdpSocket for UdpConn {
})
});
match res {
Ok(()) => return Ok(()),
Err(err) => {
Ok(Ok(())) => return Ok(()),
Err(err) => return Err(err), // closed error
Ok(Err(err)) => {
if err.kind() == std::io::ErrorKind::WouldBlock {
continue;
}
Expand Down Expand Up @@ -129,22 +135,23 @@ impl AsyncUdpSocket for UdpConn {
}

match self.io.with_socket(|io| io.poll_recv_ready(cx)) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Ok(())) => {}
Poll::Ready(Err(err)) => match self.io.handle_read_error(err) {
Ok(Poll::Pending) => return Poll::Pending,
Ok(Poll::Ready(Ok(()))) => {}
Ok(Poll::Ready(Err(err))) => match self.io.handle_read_error(err) {
Some(err) => return Poll::Ready(Err(err)),
None => {
continue;
}
},
Err(err) => return Poll::Ready(Err(err)),
}

let res = self.io.try_io(Interest::READABLE, || {
self.io
.with_socket(|io| self.inner.read().unwrap().recv(io.into(), bufs, meta))
});
match res {
Ok(count) => {
Ok(Ok(count)) => {
for meta in meta.iter().take(count) {
trace!(
src = %meta.addr,
Expand All @@ -156,7 +163,7 @@ impl AsyncUdpSocket for UdpConn {
}
return Poll::Ready(Ok(count));
}
Err(err) => {
Ok(Err(err)) => {
// ignore spurious wakeups
if err.kind() == std::io::ErrorKind::WouldBlock {
continue;
Expand All @@ -168,6 +175,7 @@ impl AsyncUdpSocket for UdpConn {
}
}
}
Err(err) => return Poll::Ready(Err(err)),
}
}
}
Expand Down
109 changes: 80 additions & 29 deletions net-tools/netwatch/src/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::{
task::Poll,
};

use anyhow::{ensure, Context, Result};
use anyhow::{bail, ensure, Context, Result};
use tracing::{debug, warn};

use super::IpFamily;
Expand Down Expand Up @@ -82,7 +82,9 @@ impl UdpSocket {
// Remove old socket
let mut guard = self.socket.write().unwrap();
{
let socket = guard.take().expect("not yet dropped");
let Some(socket) = guard.take() else {
bail!("cannot rebind closed socket");
};
drop(socket);
}

Expand Down Expand Up @@ -113,13 +115,18 @@ impl UdpSocket {
}

/// Use the socket
pub fn with_socket<F, T>(&self, f: F) -> T
pub fn with_socket<F, T>(&self, f: F) -> std::io::Result<T>
where
F: FnOnce(&tokio::net::UdpSocket) -> T,
{
let guard = self.socket.read().unwrap();
let socket = guard.as_ref().expect("missing socket");
f(socket)
let Some(socket) = guard.as_ref() else {
return Err(std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
"socket closed",
));
};
Ok(f(socket))
}

pub fn try_io<R>(
Expand All @@ -128,7 +135,12 @@ impl UdpSocket {
f: impl FnOnce() -> std::io::Result<R>,
) -> std::io::Result<R> {
let guard = self.socket.read().unwrap();
let socket = guard.as_ref().expect("missing socket");
let Some(socket) = guard.as_ref() else {
return Err(std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
"socket closed",
));
};
socket.try_io(interest, f)
}

Expand Down Expand Up @@ -173,7 +185,13 @@ impl UdpSocket {
pub fn connect(&self, addr: SocketAddr) -> std::io::Result<()> {
let mut guard = self.socket.write().unwrap();
// dance around to make non async connect work
let socket_tokio = guard.take().expect("missing socket");
let Some(socket_tokio) = guard.take() else {
return Err(std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
"socket closed",
));
};

let socket_std = socket_tokio.into_std()?;
socket_std.connect(addr)?;
let socket_tokio = tokio::net::UdpSocket::from_std(socket_std)?;
Expand All @@ -184,30 +202,38 @@ impl UdpSocket {
/// Returns the local address of this socket.
pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
let guard = self.socket.read().unwrap();
let socket = guard.as_ref().expect("missing socket");
let Some(socket) = guard.as_ref() else {
return Err(std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
"socket closed",
));
};

socket.local_addr()
}

/// Closes the socket, and waits for the underlying `libc::close` call to be finished.
pub async fn close(self) {
let std_sock = self
.socket
.write()
.unwrap()
.take()
.expect("not yet dropped")
.into_std();
let res = tokio::runtime::Handle::current()
.spawn_blocking(move || {
// Calls libc::close, which can block
drop(std_sock);
})
.await;
if let Err(err) = res {
warn!("failed to close socket: {:?}", err);
pub async fn close(&self) {
let socket = self.socket.write().unwrap().take();
if let Some(sock) = socket {
let std_sock = sock.into_std();
let res = tokio::runtime::Handle::current()
.spawn_blocking(move || {
// Calls libc::close, which can block
drop(std_sock);
})
.await;
if let Err(err) = res {
warn!("failed to close socket: {:?}", err);
}
}
}

/// Check if this socket is closed.
pub fn is_closed(&self) -> bool {
self.socket.read().unwrap().is_none()
}

/// Handle potential read errors, updating internal state.
///
/// Returns `Some(error)` if the error is fatal otherwise `None.
Expand Down Expand Up @@ -255,7 +281,12 @@ impl UdpSocket {
}
}
let guard = self.socket.read().unwrap();
let socket = guard.as_ref().expect("missing socket");
let Some(socket) = guard.as_ref() else {
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
"socket closed",
)));
};

match socket.poll_send_ready(cx) {
Poll::Pending => return Poll::Pending,
Expand Down Expand Up @@ -302,7 +333,12 @@ impl Future for RecvFut<'_, '_> {
}

let guard = socket.socket.read().unwrap();
let inner_socket = guard.as_ref().expect("missing socket");
let Some(inner_socket) = guard.as_ref() else {
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
"socket closed",
)));
};

match inner_socket.poll_recv_ready(cx) {
Poll::Pending => return Poll::Pending,
Expand Down Expand Up @@ -360,7 +396,12 @@ impl Future for RecvFromFut<'_, '_> {
}
}
let guard = socket.socket.read().unwrap();
let inner_socket = guard.as_ref().expect("missing socket");
let Some(inner_socket) = guard.as_ref() else {
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
"socket closed",
)));
};

match inner_socket.poll_recv_ready(cx) {
Poll::Pending => return Poll::Pending,
Expand Down Expand Up @@ -430,7 +471,12 @@ impl Future for SendFut<'_, '_> {
}
}
let guard = self.socket.socket.read().unwrap();
let socket = guard.as_ref().expect("missing socket");
let Some(socket) = guard.as_ref() else {
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
"socket closed",
)));
};

match socket.poll_send_ready(c) {
Poll::Pending => return Poll::Pending,
Expand Down Expand Up @@ -488,7 +534,12 @@ impl Future for SendToFut<'_, '_> {
}

let guard = self.socket.socket.read().unwrap();
let socket = guard.as_ref().expect("missing socket");
let Some(socket) = guard.as_ref() else {
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
"socket closed",
)));
};

match socket.poll_send_ready(cx) {
Poll::Pending => return Poll::Pending,
Expand Down

0 comments on commit 0b5966c

Please sign in to comment.