Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(iroh-relay): cleanup client connections in all cases #3105

Merged
merged 2 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 44 additions & 25 deletions iroh-relay/src/server/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ pub(super) struct Config {
pub(super) struct Client {
/// Identity of the connected peer.
node_id: NodeId,
/// Connection identifier.
connection_id: u64,
/// Used to close the connection loop.
done: CancellationToken,
/// Actor handle.
Expand All @@ -64,7 +66,7 @@ impl Client {
/// Creates a client from a connection & starts a read and write loop to handle io to and from
/// the client
/// Call [`Client::shutdown`] to close the read and write loops before dropping the [`Client`]
pub(super) fn new(config: Config, clients: &Clients) -> Client {
pub(super) fn new(config: Config, connection_id: u64, clients: &Clients) -> Client {
let Config {
node_id,
stream: io,
Expand Down Expand Up @@ -98,29 +100,21 @@ impl Client {
disco_send_queue: disco_send_queue_r,
node_gone: peer_gone_r,
node_id,
connection_id,
clients: clients.clone(),
};

// start io loop
let io_done = done.clone();
let handle = tokio::task::spawn(
async move {
match actor.run(io_done).await {
Err(e) => {
warn!("writer closed in error {e:#?}");
}
Ok(()) => {
debug!("writer closed");
}
}
}
.instrument(
tracing::info_span!("client connection actor", remote_node = %node_id.fmt_short()),
),
);
let handle = tokio::task::spawn(actor.run(io_done).instrument(tracing::info_span!(
"client connection actor",
remote_node = %node_id.fmt_short(),
connection_id = connection_id
)));

Client {
node_id,
connection_id,
handle: AbortOnDropHandle::new(handle),
done,
send_queue: send_queue_s,
Expand All @@ -129,11 +123,15 @@ impl Client {
}
}

pub(super) fn connection_id(&self) -> u64 {
self.connection_id
}

/// Shutdown the reader and writer loops and closes the connection.
///
/// Any shutdown errors will be logged as warnings.
pub(super) async fn shutdown(self) {
self.done.cancel();
self.start_shutdown();
if let Err(e) = self.handle.await {
warn!(
remote_node = %self.node_id.fmt_short(),
Expand All @@ -142,6 +140,11 @@ impl Client {
};
}

/// Starts the process of shutdown.
pub(super) fn start_shutdown(&self) {
self.done.cancel();
}

pub(super) fn try_send_packet(
&self,
src: NodeId,
Expand Down Expand Up @@ -194,12 +197,29 @@ struct Actor {
node_gone: mpsc::Receiver<NodeId>,
/// [`NodeId`] of this client
node_id: NodeId,
/// Connection identifier.
connection_id: u64,
/// Reference to the other connected clients.
clients: Clients,
}

impl Actor {
async fn run(mut self, done: CancellationToken) -> Result<()> {
async fn run(mut self, done: CancellationToken) {
match self.run_inner(done).await {
Err(e) => {
warn!("actor errored {e:#?}, exiting");
}
Ok(()) => {
debug!("actor finished, exiting");
}
}

self.clients
.unregister(self.connection_id, self.node_id)
.await;
}

async fn run_inner(&mut self, done: CancellationToken) -> Result<()> {
let jitter = Duration::from_secs(5);
let mut keep_alive = tokio::time::interval(KEEP_ALIVE + jitter);
// ticks immediately
Expand Down Expand Up @@ -304,7 +324,7 @@ impl Actor {
match frame {
Frame::SendPacket { dst_key, packet } => {
let packet_len = packet.len();
self.handle_frame_send_packet(dst_key, packet).await?;
self.handle_frame_send_packet(dst_key, packet)?;
inc_by!(Metrics, bytes_recv, packet_len as u64);
}
Frame::Ping { data } => {
Expand All @@ -323,15 +343,13 @@ impl Actor {
Ok(())
}

async fn handle_frame_send_packet(&self, dst: NodeId, data: Bytes) -> Result<()> {
fn handle_frame_send_packet(&self, dst: NodeId, data: Bytes) -> Result<()> {
if disco::looks_like_disco_wrapper(&data) {
inc!(Metrics, disco_packets_recv);
self.clients
.send_disco_packet(dst, data, self.node_id)
.await?;
self.clients.send_disco_packet(dst, data, self.node_id)?;
} else {
inc!(Metrics, send_packets_recv);
self.clients.send_packet(dst, data, self.node_id).await?;
self.clients.send_packet(dst, data, self.node_id)?;
}
Ok(())
}
Expand Down Expand Up @@ -546,6 +564,7 @@ mod tests {
send_queue: send_queue_r,
disco_send_queue: disco_send_queue_r,
node_gone: peer_gone_r,
connection_id: 0,
node_id,
clients: clients.clone(),
};
Expand Down Expand Up @@ -630,7 +649,7 @@ mod tests {
.await?;

done.cancel();
handle.await??;
handle.await?;
Ok(())
}

Expand Down
83 changes: 50 additions & 33 deletions iroh-relay/src/server/clients.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
//! The "Server" side of the client. Uses the `ClientConnManager`.
// Based on tailscale/derp/derp_server.go

use std::{collections::HashSet, sync::Arc};
use std::{
collections::HashSet,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
};

use anyhow::{bail, Result};
use bytes::Bytes;
Expand All @@ -24,6 +30,8 @@ struct Inner {
clients: DashMap<NodeId, Client>,
/// Map of which client has sent where
sent_to: DashMap<NodeId, HashSet<NodeId>>,
/// Connection ID Counter
next_connection_id: AtomicU64,
}

impl Clients {
Expand All @@ -41,9 +49,10 @@ impl Clients {
/// Builds the client handler and starts the read & write loops for the connection.
pub async fn register(&self, client_config: Config) {
let node_id = client_config.node_id;
let connection_id = self.get_connection_id();
trace!(remote_node = node_id.fmt_short(), "registering client");

let client = Client::new(client_config, self);
let client = Client::new(client_config, connection_id, self);
if let Some(old_client) = self.0.clients.insert(node_id, client) {
debug!(
remote_node = node_id.fmt_short(),
Expand All @@ -53,20 +62,27 @@ impl Clients {
}
}

fn get_connection_id(&self) -> u64 {
self.0.next_connection_id.fetch_add(1, Ordering::Relaxed)
}

/// Removes the client from the map of clients, & sends a notification
/// to each client that peers has sent data to, to let them know that
/// peer is gone from the network.
///
/// Explicitly drops the reference to the client to avoid deadlock.
async fn unregister<'a>(
&self,
client: dashmap::mapref::one::Ref<'a, iroh_base::PublicKey, Client>,
node_id: NodeId,
) {
drop(client); // avoid deadlock
trace!(node_id = node_id.fmt_short(), "unregistering client");

if let Some((_, client)) = self.0.clients.remove(&node_id) {
/// Must be passed a matching connection_id.
pub(super) async fn unregister<'a>(&self, connection_id: u64, node_id: NodeId) {
trace!(
node_id = node_id.fmt_short(),
connection_id,
"unregistering client"
);

if let Some((_, client)) = self
.0
.clients
.remove_if(&node_id, |_, c| c.connection_id() == connection_id)
{
if let Some((_, sent_to)) = self.0.sent_to.remove(&node_id) {
for key in sent_to {
match client.try_send_peer_gone(key) {
Expand All @@ -91,7 +107,7 @@ impl Clients {
}

/// Attempt to send a packet to client with [`NodeId`] `dst`.
pub(super) async fn send_packet(&self, dst: NodeId, data: Bytes, src: NodeId) -> Result<()> {
pub(super) fn send_packet(&self, dst: NodeId, data: Bytes, src: NodeId) -> Result<()> {
let Some(client) = self.0.clients.get(&dst) else {
debug!(dst = dst.fmt_short(), "no connected client, dropped packet");
inc!(Metrics, send_packets_dropped);
Expand All @@ -115,19 +131,14 @@ impl Clients {
dst = dst.fmt_short(),
"can no longer write to client, dropping message and pruning connection"
);
self.unregister(client, dst).await;
client.start_shutdown();
bail!("failed to send message: gone");
}
}
}

/// Attempt to send a disco packet to client with [`NodeId`] `dst`.
pub(super) async fn send_disco_packet(
&self,
dst: NodeId,
data: Bytes,
src: NodeId,
) -> Result<()> {
pub(super) fn send_disco_packet(&self, dst: NodeId, data: Bytes, src: NodeId) -> Result<()> {
let Some(client) = self.0.clients.get(&dst) else {
debug!(
dst = dst.fmt_short(),
Expand All @@ -154,7 +165,7 @@ impl Clients {
dst = dst.fmt_short(),
"can no longer write to client, dropping disco message and pruning connection"
);
self.unregister(client, dst).await;
client.start_shutdown();
bail!("failed to send message: gone");
}
}
Expand Down Expand Up @@ -205,9 +216,7 @@ mod tests {

// send packet
let data = b"hello world!";
clients
.send_packet(a_key, Bytes::from(&data[..]), b_key)
.await?;
clients.send_packet(a_key, Bytes::from(&data[..]), b_key)?;
let frame = recv_frame(FrameType::RecvPacket, &mut a_rw).await?;
assert_eq!(
frame,
Expand All @@ -218,9 +227,7 @@ mod tests {
);

// send disco packet
clients
.send_disco_packet(a_key, Bytes::from(&data[..]), b_key)
.await?;
clients.send_disco_packet(a_key, Bytes::from(&data[..]), b_key)?;
let frame = recv_frame(FrameType::RecvPacket, &mut a_rw).await?;
assert_eq!(
frame,
Expand All @@ -230,13 +237,23 @@ mod tests {
}
);

let client = clients.0.clients.get(&a_key).unwrap();

// send peer_gone. Also, tests that we do not get a deadlock
// when unregistering.
clients.unregister(client, a_key).await;
{
let client = clients.0.clients.get(&a_key).unwrap();
// shutdown client a, this should trigger the removal from the clients list
client.start_shutdown();
}

assert!(!clients.0.clients.contains_key(&a_key));
// need to wait a moment for the removal to be processed
let c = clients.clone();
tokio::time::timeout(Duration::from_secs(1), async move {
loop {
if !c.0.clients.contains_key(&a_key) {
break;
}
tokio::time::sleep(Duration::from_millis(100)).await;
}
})
.await?;
clients.shutdown().await;

Ok(())
Expand Down
Loading