From 23b26aa3137f15febe49fbc8e2563e49d0db4000 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Thu, 28 Nov 2024 09:44:45 +0100 Subject: [PATCH] refactor: move `Router` and `RouterBuilder` into `protocol.rs` --- iroh/src/discovery/local_swarm_discovery.rs | 3 +- iroh/src/lib.rs | 1 - iroh/src/protocol.rs | 218 ++++++++++++++++++- iroh/src/router.rs | 223 -------------------- 4 files changed, 217 insertions(+), 228 deletions(-) delete mode 100644 iroh/src/router.rs diff --git a/iroh/src/discovery/local_swarm_discovery.rs b/iroh/src/discovery/local_swarm_discovery.rs index fe9dfa45af5..753848b2129 100644 --- a/iroh/src/discovery/local_swarm_discovery.rs +++ b/iroh/src/discovery/local_swarm_discovery.rs @@ -20,8 +20,7 @@ //! .filter(|remote| { //! remote.sources().iter().any(|(source, duration)| { //! if let Source::Discovery { name } = source { -//! name == iroh::discovery::local_swarm_discovery::NAME -//! && *duration <= recent +//! name == iroh::discovery::local_swarm_discovery::NAME && *duration <= recent //! } else { //! false //! } diff --git a/iroh/src/lib.rs b/iroh/src/lib.rs index c112c025d49..db2a46203b4 100644 --- a/iroh/src/lib.rs +++ b/iroh/src/lib.rs @@ -242,7 +242,6 @@ pub mod endpoint; mod magicsock; pub mod metrics; pub mod protocol; -pub mod router; pub mod ticket; pub mod tls; diff --git a/iroh/src/protocol.rs b/iroh/src/protocol.rs index bbce8ca2141..91feecf3dae 100644 --- a/iroh/src/protocol.rs +++ b/iroh/src/protocol.rs @@ -1,11 +1,42 @@ //! TODO(matheus23) docs use std::{any::Any, collections::BTreeMap, sync::Arc}; -use anyhow::Result; +use anyhow::{anyhow, Result}; use futures_buffered::join_all; use futures_lite::future::Boxed as BoxedFuture; +use futures_util::{ + future::{MapErr, Shared}, + FutureExt, TryFutureExt, +}; +use tokio::task::{JoinError, JoinSet}; +use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle}; +use tracing::{debug, error, warn}; -use crate::endpoint::Connecting; +use crate::{endpoint::Connecting, Endpoint}; + +/// TODO(matheus23): docs +#[derive(Clone, Debug)] +pub struct Router { + endpoint: Endpoint, + protocols: Arc, + // `Router` needs to be `Clone + Send`, and we need to `task.await` in its `shutdown()` impl. + // So we need + // - `Shared` so we can `task.await` from all `Node` clones + // - `MapErr` to map the `JoinError` to a `String`, because `JoinError` is `!Clone` + // - `AbortOnDropHandle` to make sure that the `task` is cancelled when all `Node`s are dropped + // (`Shared` acts like an `Arc` around its inner future). + task: Shared, JoinErrToStr>>, + cancel_token: CancellationToken, +} + +type JoinErrToStr = Box String + Send + Sync + 'static>; + +/// TODO(matheus23): docs +#[derive(Debug)] +pub struct RouterBuilder { + endpoint: Endpoint, + protocols: ProtocolMap, +} /// Handler for incoming connections. /// @@ -78,3 +109,186 @@ impl ProtocolMap { join_all(handlers).await; } } + +impl Router { + /// TODO(matheus23): docs + pub fn builder(endpoint: Endpoint) -> RouterBuilder { + RouterBuilder::new(endpoint) + } + + /// Returns a protocol handler for an ALPN. + /// + /// This downcasts to the concrete type and returns `None` if the handler registered for `alpn` + /// does not match the passed type. + pub fn get_protocol(&self, alpn: &[u8]) -> Option> { + self.protocols.get_typed(alpn) + } + + /// TODO(matheus23): docs + pub fn endpoint(&self) -> &Endpoint { + &self.endpoint + } + + /// TODO(matheus23): docs + pub async fn shutdown(self) -> Result<()> { + // Trigger shutdown of the main run task by activating the cancel token. + self.cancel_token.cancel(); + + // Wait for the main task to terminate. + self.task.await.map_err(|err| anyhow!(err))?; + + Ok(()) + } +} + +impl RouterBuilder { + /// TODO(matheus23): docs + pub fn new(endpoint: Endpoint) -> Self { + Self { + endpoint, + protocols: ProtocolMap::default(), + } + } + + /// TODO(matheus23): docs + pub fn accept(mut self, alpn: impl AsRef<[u8]>, handler: Arc) -> Self { + self.protocols.insert(alpn.as_ref().to_vec(), handler); + self + } + + /// Returns the [`Endpoint`] of the node. + pub fn endpoint(&self) -> &Endpoint { + &self.endpoint + } + + /// Returns a protocol handler for an ALPN. + /// + /// This downcasts to the concrete type and returns `None` if the handler registered for `alpn` + /// does not match the passed type. + pub fn get_protocol(&self, alpn: &[u8]) -> Option> { + self.protocols.get_typed(alpn) + } + + /// TODO(matheus23): docs + pub async fn spawn(self) -> Result { + // Update the endpoint with our alpns. + let alpns = self + .protocols + .alpns() + .map(|alpn| alpn.to_vec()) + .collect::>(); + + let protocols = Arc::new(self.protocols); + if let Err(err) = self.endpoint.set_alpns(alpns) { + shutdown(&self.endpoint, protocols.clone()).await; + return Err(err); + } + + let mut join_set = JoinSet::new(); + let endpoint = self.endpoint.clone(); + let protos = protocols.clone(); + let cancel = CancellationToken::new(); + let cancel_token = cancel.clone(); + + let run_loop_fut = async move { + let protocols = protos; + loop { + tokio::select! { + biased; + _ = cancel_token.cancelled() => { + break; + }, + // handle incoming p2p connections. + incoming = endpoint.accept() => { + let Some(incoming) = incoming else { + break; + }; + + let protocols = protocols.clone(); + join_set.spawn(async move { + handle_connection(incoming, protocols).await; + anyhow::Ok(()) + }); + }, + // handle task terminations and quit on panics. + res = join_set.join_next(), if !join_set.is_empty() => { + match res { + Some(Err(outer)) => { + if outer.is_panic() { + error!("Task panicked: {outer:?}"); + break; + } else if outer.is_cancelled() { + debug!("Task cancelled: {outer:?}"); + } else { + error!("Task failed: {outer:?}"); + break; + } + } + Some(Ok(Err(inner))) => { + debug!("Task errored: {inner:?}"); + } + _ => {} + } + }, + } + } + + shutdown(&endpoint, protocols).await; + + // Abort remaining tasks. + tracing::info!("Shutting down remaining tasks"); + join_set.shutdown().await; + }; + let task = tokio::task::spawn(run_loop_fut); + let task = AbortOnDropHandle::new(task) + .map_err(Box::new(|e: JoinError| e.to_string()) as JoinErrToStr) + .shared(); + + Ok(Router { + endpoint: self.endpoint, + protocols, + task, + cancel_token: cancel, + }) + } +} + +/// Shutdown the different parts of the router concurrently. +async fn shutdown(endpoint: &Endpoint, protocols: Arc) { + let error_code = 1u16; + + // We ignore all errors during shutdown. + let _ = tokio::join!( + // Close the endpoint. + // Closing the Endpoint is the equivalent of calling Connection::close on all + // connections: Operations will immediately fail with ConnectionError::LocallyClosed. + // All streams are interrupted, this is not graceful. + endpoint.close(error_code.into(), b"provider terminating"), + // Shutdown protocol handlers. + protocols.shutdown(), + ); +} + +async fn handle_connection(incoming: crate::endpoint::Incoming, protocols: Arc) { + let mut connecting = match incoming.accept() { + Ok(conn) => conn, + Err(err) => { + warn!("Ignoring connection: accepting failed: {err:#}"); + return; + } + }; + let alpn = match connecting.alpn().await { + Ok(alpn) => alpn, + Err(err) => { + warn!("Ignoring connection: invalid handshake: {err:#}"); + return; + } + }; + let Some(handler) = protocols.get(&alpn) else { + warn!("Ignoring connection: unsupported ALPN protocol"); + return; + }; + if let Err(err) = handler.accept(connecting).await { + warn!("Handling incoming connection ended with error: {err}"); + } +} diff --git a/iroh/src/router.rs b/iroh/src/router.rs deleted file mode 100644 index 8094cbee3fa..00000000000 --- a/iroh/src/router.rs +++ /dev/null @@ -1,223 +0,0 @@ -//! TODO(matheus23) docs -use std::sync::Arc; - -use anyhow::{anyhow, Result}; -use futures_util::{ - future::{MapErr, Shared}, - FutureExt, TryFutureExt, -}; -use tokio::task::{JoinError, JoinSet}; -use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle}; -use tracing::{debug, error, warn}; - -use crate::{ - protocol::{ProtocolHandler, ProtocolMap}, - Endpoint, -}; - -/// TODO(matheus23): docs -#[derive(Clone, Debug)] -pub struct Router { - endpoint: Endpoint, - protocols: Arc, - // `Router` needs to be `Clone + Send`, and we need to `task.await` in its `shutdown()` impl. - // So we need - // - `Shared` so we can `task.await` from all `Node` clones - // - `MapErr` to map the `JoinError` to a `String`, because `JoinError` is `!Clone` - // - `AbortOnDropHandle` to make sure that the `task` is cancelled when all `Node`s are dropped - // (`Shared` acts like an `Arc` around its inner future). - task: Shared, JoinErrToStr>>, - cancel_token: CancellationToken, -} - -type JoinErrToStr = Box String + Send + Sync + 'static>; - -impl Router { - /// TODO(matheus23): docs - pub fn builder(endpoint: Endpoint) -> RouterBuilder { - RouterBuilder::new(endpoint) - } - - /// Returns a protocol handler for an ALPN. - /// - /// This downcasts to the concrete type and returns `None` if the handler registered for `alpn` - /// does not match the passed type. - pub fn get_protocol(&self, alpn: &[u8]) -> Option> { - self.protocols.get_typed(alpn) - } - - /// TODO(matheus23): docs - pub fn endpoint(&self) -> &Endpoint { - &self.endpoint - } - - /// TODO(matheus23): docs - pub async fn shutdown(self) -> Result<()> { - // Trigger shutdown of the main run task by activating the cancel token. - self.cancel_token.cancel(); - - // Wait for the main task to terminate. - self.task.await.map_err(|err| anyhow!(err))?; - - Ok(()) - } -} - -/// TODO(matheus23): docs -#[derive(Debug)] -pub struct RouterBuilder { - endpoint: Endpoint, - protocols: ProtocolMap, -} - -impl RouterBuilder { - /// TODO(matheus23): docs - pub fn new(endpoint: Endpoint) -> Self { - Self { - endpoint, - protocols: ProtocolMap::default(), - } - } - - /// TODO(matheus23): docs - pub fn accept(mut self, alpn: impl AsRef<[u8]>, handler: Arc) -> Self { - self.protocols.insert(alpn.as_ref().to_vec(), handler); - self - } - - /// Returns the [`Endpoint`] of the node. - pub fn endpoint(&self) -> &Endpoint { - &self.endpoint - } - - /// Returns a protocol handler for an ALPN. - /// - /// This downcasts to the concrete type and returns `None` if the handler registered for `alpn` - /// does not match the passed type. - pub fn get_protocol(&self, alpn: &[u8]) -> Option> { - self.protocols.get_typed(alpn) - } - - /// TODO(matheus23): docs - pub async fn spawn(self) -> Result { - // Update the endpoint with our alpns. - let alpns = self - .protocols - .alpns() - .map(|alpn| alpn.to_vec()) - .collect::>(); - - let protocols = Arc::new(self.protocols); - if let Err(err) = self.endpoint.set_alpns(alpns) { - shutdown(&self.endpoint, protocols.clone()).await; - return Err(err); - } - - let mut join_set = JoinSet::new(); - let endpoint = self.endpoint.clone(); - let protos = protocols.clone(); - let cancel = CancellationToken::new(); - let cancel_token = cancel.clone(); - - let run_loop_fut = async move { - let protocols = protos; - loop { - tokio::select! { - biased; - _ = cancel_token.cancelled() => { - break; - }, - // handle incoming p2p connections. - incoming = endpoint.accept() => { - let Some(incoming) = incoming else { - break; - }; - - let protocols = protocols.clone(); - join_set.spawn(async move { - handle_connection(incoming, protocols).await; - anyhow::Ok(()) - }); - }, - // handle task terminations and quit on panics. - res = join_set.join_next(), if !join_set.is_empty() => { - match res { - Some(Err(outer)) => { - if outer.is_panic() { - error!("Task panicked: {outer:?}"); - break; - } else if outer.is_cancelled() { - debug!("Task cancelled: {outer:?}"); - } else { - error!("Task failed: {outer:?}"); - break; - } - } - Some(Ok(Err(inner))) => { - debug!("Task errored: {inner:?}"); - } - _ => {} - } - }, - } - } - - shutdown(&endpoint, protocols).await; - - // Abort remaining tasks. - tracing::info!("Shutting down remaining tasks"); - join_set.shutdown().await; - }; - let task = tokio::task::spawn(run_loop_fut); - let task = AbortOnDropHandle::new(task) - .map_err(Box::new(|e: JoinError| e.to_string()) as JoinErrToStr) - .shared(); - - Ok(Router { - endpoint: self.endpoint, - protocols, - task, - cancel_token: cancel, - }) - } -} - -/// Shutdown the different parts of the router concurrently. -async fn shutdown(endpoint: &Endpoint, protocols: Arc) { - let error_code = 1u16; - - // We ignore all errors during shutdown. - let _ = tokio::join!( - // Close the endpoint. - // Closing the Endpoint is the equivalent of calling Connection::close on all - // connections: Operations will immediately fail with ConnectionError::LocallyClosed. - // All streams are interrupted, this is not graceful. - endpoint.close(error_code.into(), b"provider terminating"), - // Shutdown protocol handlers. - protocols.shutdown(), - ); -} - -async fn handle_connection(incoming: crate::endpoint::Incoming, protocols: Arc) { - let mut connecting = match incoming.accept() { - Ok(conn) => conn, - Err(err) => { - warn!("Ignoring connection: accepting failed: {err:#}"); - return; - } - }; - let alpn = match connecting.alpn().await { - Ok(alpn) => alpn, - Err(err) => { - warn!("Ignoring connection: invalid handshake: {err:#}"); - return; - } - }; - let Some(handler) = protocols.get(&alpn) else { - warn!("Ignoring connection: unsupported ALPN protocol"); - return; - }; - if let Err(err) = handler.accept(connecting).await { - warn!("Handling incoming connection ended with error: {err}"); - } -}