Skip to content

Commit

Permalink
refactor(iroh-relay): improve overall server structure (#2922)
Browse files Browse the repository at this point in the history
## Description

Trying to make the server relay code a little less crazy.
- moves `stun_metrics` from `server` to `server::metrics` which is a
more appropriate place for this.
- moves `ServerMessage` (the messages the server actor receives) from
`server` to `server::actor`
- removes `ServerMessage::Shutdown` since there was already a
`CancellationToken`, thus reducing the ways to shutdown to simply one.
- removes the boolean field `ServerActorTask::closed`. We don't need
this if we correctly handle shutdown on drop.
- Makes `write_timeout` not an option. This was never `None`.
- inline where possible to simplify
- reduce types
- some renaming of types, to make them easier to understand

## Breaking Changes

- `iroh-relay` now uses `NodeGone` instead of `PeerGone` in some enums,
but `iroh-relay` is unpublished, so technically not a breaking change

---------

Co-authored-by: Floris Bruynooghe <flub@n0.computer>
  • Loading branch information
dignifiedquire and flub authored Nov 13, 2024
1 parent 289b4cf commit 0e57292
Show file tree
Hide file tree
Showing 13 changed files with 998 additions and 1,276 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion iroh-net/src/magicsock/relay_actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ impl ActiveRelay {
ReadResult::Continue
}
ReceivedMessage::Health { .. } => ReadResult::Continue,
ReceivedMessage::PeerGone(key) => {
ReceivedMessage::NodeGone(key) => {
self.node_present.remove(&key);
ReadResult::Continue
}
Expand Down
1 change: 1 addition & 0 deletions iroh-relay/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ base64 = "0.22.1"
bytes = "1.7"
clap = { version = "4", features = ["derive"], optional = true }
derive_more = { version = "1.0.0", features = ["debug", "display", "from", "try_into", "deref"] }
futures-buffered = "0.2.9"
futures-lite = "2.3"
futures-sink = "0.3.25"
futures-util = "0.3.25"
Expand Down
27 changes: 15 additions & 12 deletions iroh-relay/src/client/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use futures_util::{
stream::{SplitSink, SplitStream, StreamExt},
SinkExt,
};
use iroh_base::key::{PublicKey, SecretKey};
use iroh_base::key::{NodeId, SecretKey};
use tokio::sync::mpsc;
use tokio_tungstenite_wasm::WebSocketStream;
use tokio_util::{
Expand Down Expand Up @@ -97,12 +97,12 @@ impl Conn {
/// Sends a packet to the node identified by `dstkey`
///
/// Errors if the packet is larger than [`MAX_PACKET_SIZE`]
pub async fn send(&self, dstkey: PublicKey, packet: Bytes) -> Result<()> {
trace!(%dstkey, len = packet.len(), "[RELAY] send");
pub async fn send(&self, dst: NodeId, packet: Bytes) -> Result<()> {
trace!(%dst, len = packet.len(), "[RELAY] send");

self.inner
.writer_channel
.send(ConnWriterMessage::Packet((dstkey, packet)))
.send(ConnWriterMessage::Packet((dst, packet)))
.await?;
Ok(())
}
Expand Down Expand Up @@ -176,7 +176,7 @@ fn process_incoming_frame(frame: Frame) -> Result<ReceivedMessage> {
// This predated FrameType::Ping/FrameType::Pong.
Ok(ReceivedMessage::KeepAlive)
}
Frame::PeerGone { peer } => Ok(ReceivedMessage::PeerGone(peer)),
Frame::NodeGone { node_id } => Ok(ReceivedMessage::NodeGone(node_id)),
Frame::RecvPacket { src_key, content } => {
let packet = ReceivedMessage::ReceivedPacket {
source: src_key,
Expand Down Expand Up @@ -209,8 +209,8 @@ fn process_incoming_frame(frame: Frame) -> Result<ReceivedMessage> {
/// The kinds of messages we can send to the [`Server`](crate::server::Server)
#[derive(Debug)]
enum ConnWriterMessage {
/// Send a packet (addressed to the [`PublicKey`]) to the server
Packet((PublicKey, Bytes)),
/// Send a packet (addressed to the [`NodeId`]) to the server
Packet((NodeId, Bytes)),
/// Send a pong to the server
Pong([u8; 8]),
/// Send a ping to the server
Expand Down Expand Up @@ -450,15 +450,15 @@ impl ConnBuilder {
pub enum ReceivedMessage {
/// Represents an incoming packet.
ReceivedPacket {
/// The [`PublicKey`] of the packet sender.
source: PublicKey,
/// The [`NodeId`] of the packet sender.
source: NodeId,
/// The received packet bytes.
#[debug(skip)]
data: Bytes, // TODO: ref
},
/// Indicates that the client identified by the underlying public key had previously sent you a
/// packet but has now disconnected from the server.
PeerGone(PublicKey),
NodeGone(NodeId),
/// Request from a client or server to reply to the
/// other side with a [`ReceivedMessage::Pong`] with the given payload.
Ping([u8; 8]),
Expand Down Expand Up @@ -495,7 +495,7 @@ pub enum ReceivedMessage {
pub(crate) async fn send_packet<S: Sink<Frame, Error = std::io::Error> + Unpin>(
mut writer: S,
rate_limiter: &Option<RateLimiter>,
dst_key: PublicKey,
dst: NodeId,
packet: Bytes,
) -> Result<()> {
ensure!(
Expand All @@ -504,7 +504,10 @@ pub(crate) async fn send_packet<S: Sink<Frame, Error = std::io::Error> + Unpin>(
packet.len()
);

let frame = Frame::SendPacket { dst_key, packet };
let frame = Frame::SendPacket {
dst_key: dst,
packet,
};
if let Some(rate_limiter) = rate_limiter {
if rate_limiter.check_n(frame.len()).is_err() {
tracing::warn!("dropping send: rate limit reached");
Expand Down
18 changes: 9 additions & 9 deletions iroh-relay/src/protos/relay.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,8 @@ pub(crate) enum Frame {
NotePreferred {
preferred: bool,
},
PeerGone {
peer: PublicKey,
NodeGone {
node_id: PublicKey,
},
Ping {
data: [u8; 8],
Expand All @@ -248,7 +248,7 @@ impl Frame {
Frame::RecvPacket { .. } => FrameType::RecvPacket,
Frame::KeepAlive => FrameType::KeepAlive,
Frame::NotePreferred { .. } => FrameType::NotePreferred,
Frame::PeerGone { .. } => FrameType::PeerGone,
Frame::NodeGone { .. } => FrameType::PeerGone,
Frame::Ping { .. } => FrameType::Ping,
Frame::Pong { .. } => FrameType::Pong,
Frame::Health { .. } => FrameType::Health,
Expand All @@ -271,7 +271,7 @@ impl Frame {
} => PUBLIC_KEY_LENGTH + content.len(),
Frame::KeepAlive => 0,
Frame::NotePreferred { .. } => 1,
Frame::PeerGone { .. } => PUBLIC_KEY_LENGTH,
Frame::NodeGone { .. } => PUBLIC_KEY_LENGTH,
Frame::Ping { .. } => 8,
Frame::Pong { .. } => 8,
Frame::Health { problem } => problem.len(),
Expand Down Expand Up @@ -331,7 +331,7 @@ impl Frame {
dst.put_u8(NOT_PREFERRED);
}
}
Frame::PeerGone { peer } => {
Frame::NodeGone { node_id: peer } => {
dst.put(peer.as_ref());
}
Frame::Ping { data } => {
Expand Down Expand Up @@ -430,7 +430,7 @@ impl Frame {
"invalid peer gone frame length"
);
let peer = PublicKey::try_from(&content[..32])?;
Self::PeerGone { peer }
Self::NodeGone { node_id: peer }
}
FrameType::Ping => {
anyhow::ensure!(content.len() == 8, "invalid ping frame length");
Expand Down Expand Up @@ -636,8 +636,8 @@ mod tests {
(Frame::KeepAlive, "06"),
(Frame::NotePreferred { preferred: true }, "07 01"),
(
Frame::PeerGone {
peer: client_key.public(),
Frame::NodeGone {
node_id: client_key.public(),
},
"08 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e
a7 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d
Expand Down Expand Up @@ -731,7 +731,7 @@ mod proptests {
(key(), data(32)).prop_map(|(src_key, content)| Frame::RecvPacket { src_key, content });
let keep_alive = Just(Frame::KeepAlive);
let note_preferred = any::<bool>().prop_map(|preferred| Frame::NotePreferred { preferred });
let peer_gone = key().prop_map(|peer| Frame::PeerGone { peer });
let peer_gone = key().prop_map(|peer| Frame::NodeGone { node_id: peer });
let ping = prop::array::uniform8(any::<u8>()).prop_map(|data| Frame::Ping { data });
let pong = prop::array::uniform8(any::<u8>()).prop_map(|data| Frame::Pong { data });
let health = data(0).prop_map(|problem| Frame::Health { problem });
Expand Down
104 changes: 25 additions & 79 deletions iroh-relay/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,28 +25,24 @@ use http::{
};
use hyper::body::Incoming;
use iroh_metrics::inc;
// Module defined in this file.
use stun_metrics::StunMetrics;
use tokio::{
net::{TcpListener, UdpSocket},
task::JoinSet,
};
use tokio_util::task::AbortOnDropHandle;
use tracing::{debug, error, info, info_span, instrument, trace, warn, Instrument};

use crate::{http::RELAY_PROBE_PATH, protos::stun};
use crate::{http::RELAY_PROBE_PATH, protos};

pub(crate) mod actor;
pub(crate) mod client_conn;
mod clients;
mod http_server;
mod metrics;
pub(crate) mod streams;
pub(crate) mod types;

pub use self::{
actor::{ClientConnHandler, ServerActorTask},
metrics::Metrics,
metrics::{Metrics, StunMetrics},
streams::MaybeTlsStream as MaybeTlsStreamServer,
};

Expand Down Expand Up @@ -401,17 +397,15 @@ async fn relay_supervisor(
mut relay_http_server: Option<http_server::Server>,
) -> Result<()> {
let res = match (relay_http_server.as_mut(), tasks.len()) {
(None, _) => tasks
.join_next()
.await
.unwrap_or_else(|| Ok(Err(anyhow!("Nothing to supervise")))),
(None, 0) => Ok(Err(anyhow!("Nothing to supervise"))),
(None, _) => tasks.join_next().await.expect("checked"),
(Some(relay), 0) => relay.task_handle().await.map(anyhow::Ok),
(Some(relay), _) => {
tokio::select! {
biased;
Some(ret) = tasks.join_next() => ret,

ret = tasks.join_next() => ret.expect("checked"),
ret = relay.task_handle() => ret.map(anyhow::Ok),
else => Ok(Err(anyhow!("Empty JoinSet (unreachable)"))),
}
}
};
Expand All @@ -435,11 +429,12 @@ async fn relay_supervisor(
};

// Ensure the HTTP server terminated, there is no harm in calling this after it is
// already shut down. The JoinSet is aborted on drop.
// already shut down.
if let Some(server) = relay_http_server {
server.shutdown();
}

// Stop all remaining tasks
tasks.shutdown().await;

ret
Expand Down Expand Up @@ -469,7 +464,7 @@ async fn server_stun_listener(sock: UdpSocket) -> Result<()> {
Ok((n, src_addr)) => {
inc!(StunMetrics, requests);
let pkt = &buffer[..n];
if !stun::is(pkt) {
if !protos::stun::is(pkt) {
debug!(%src_addr, "STUN: ignoring non stun packet");
inc!(StunMetrics, bad_requests);
continue;
Expand All @@ -489,20 +484,19 @@ async fn server_stun_listener(sock: UdpSocket) -> Result<()> {

/// Handles a single STUN request, doing all logging required.
async fn handle_stun_request(src_addr: SocketAddr, pkt: Vec<u8>, sock: Arc<UdpSocket>) {
let handle =
AbortOnDropHandle::new(tokio::task::spawn_blocking(
move || match stun::parse_binding_request(&pkt) {
Ok(txid) => {
debug!(%src_addr, %txid, "STUN: received binding request");
Some((txid, stun::response(txid, src_addr)))
}
Err(err) => {
inc!(StunMetrics, bad_requests);
warn!(%src_addr, "STUN: invalid binding request: {:?}", err);
None
}
},
));
let handle = AbortOnDropHandle::new(tokio::task::spawn_blocking(move || {
match protos::stun::parse_binding_request(&pkt) {
Ok(txid) => {
debug!(%src_addr, %txid, "STUN: received binding request");
Some((txid, protos::stun::response(txid, src_addr)))
}
Err(err) => {
inc!(StunMetrics, bad_requests);
warn!(%src_addr, "STUN: invalid binding request: {:?}", err);
None
}
}
}));
let (txid, response) = match handle.await {
Ok(Some(val)) => val,
Ok(None) => return,
Expand Down Expand Up @@ -679,54 +673,6 @@ impl hyper::service::Service<Request<Incoming>> for CaptivePortalService {
}
}

mod stun_metrics {
use iroh_metrics::{
core::{Counter, Metric},
struct_iterable::Iterable,
};

/// Counters describing the STUN requests handled by the server.
///
/// The metric group name reported by [`Metric::name`] is `"stun"`.
#[allow(missing_docs)]
#[derive(Debug, Clone, Iterable)]
pub struct StunMetrics {
/*
* Metrics about STUN requests (both ipv4 and ipv6)
*/
/// Number of stun requests made
pub requests: Counter,
/// Number of successful requests over ipv4
pub ipv4_success: Counter,
/// Number of successful requests over ipv6
pub ipv6_success: Counter,

/// Number of bad requests, either non-stun packets or incorrect binding request
pub bad_requests: Counter,
/// Number of failures
pub failures: Counter,
}

impl Default for StunMetrics {
/// Builds every counter with its human-readable description string.
fn default() -> Self {
Self {
/*
* Metrics about STUN requests
*/
requests: Counter::new("Number of STUN requests made to the server."),
ipv4_success: Counter::new("Number of successful ipv4 STUN requests served."),
ipv6_success: Counter::new("Number of successful ipv6 STUN requests served."),
bad_requests: Counter::new("Number of bad requests made to the STUN endpoint."),
failures: Counter::new("Number of STUN requests that end in failure."),
}
}
}

impl Metric for StunMetrics {
/// Returns the name of this metric group, `"stun"`.
fn name() -> &'static str {
"stun"
}
}
}

#[cfg(test)]
mod tests {
use std::{net::Ipv4Addr, time::Duration};
Expand Down Expand Up @@ -1052,8 +998,8 @@ mod tests {
.await
.unwrap();

let txid = stun::TransactionId::default();
let req = stun::request(txid);
let txid = protos::stun::TransactionId::default();
let req = protos::stun::request(txid);
let socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
socket
.send_to(&req, server.stun_addr().unwrap())
Expand All @@ -1065,7 +1011,7 @@ mod tests {
let (len, addr) = socket.recv_from(&mut buf).await.unwrap();
assert_eq!(addr, server.stun_addr().unwrap());
buf.truncate(len);
let (txid_back, response_addr) = stun::parse_response(&buf).unwrap();
let (txid_back, response_addr) = protos::stun::parse_response(&buf).unwrap();
assert_eq!(txid, txid_back);
assert_eq!(response_addr, socket.local_addr().unwrap());
}
Expand Down
Loading

0 comments on commit 0e57292

Please sign in to comment.