From 902933b5d9545ad52f486c2e2a23935ac776cab6 Mon Sep 17 00:00:00 2001 From: Jean-Pierre Smith Date: Tue, 12 Dec 2023 15:21:16 +0100 Subject: [PATCH] refactor: allow setters to be used on shared socket --- crates/scion/src/dispatcher.rs | 17 +- crates/scion/src/udp_socket.rs | 222 +++++++++++++++++--------- crates/scion/tests/test_udp_socket.rs | 15 +- 3 files changed, 168 insertions(+), 86 deletions(-) diff --git a/crates/scion/src/dispatcher.rs b/crates/scion/src/dispatcher.rs index aa5cfe4..a0d332e 100644 --- a/crates/scion/src/dispatcher.rs +++ b/crates/scion/src/dispatcher.rs @@ -97,18 +97,23 @@ pub struct DispatcherStream { } impl DispatcherStream { + /// Create a new DispatcherStream over an already connected UnixStream. + pub fn new(stream: UnixStream) -> Self { + Self { + inner: stream, + send_buffer: BytesMut::with_capacity(SEND_BUFFER_LEN), + recv_buffer: BytesMut::with_capacity(RECV_BUFFER_LEN), + parser: StreamParser::new(), + } + } + /// Connects to the dispatcher over a Unix socket at the provided path. pub async fn connect + std::fmt::Debug>(path: P) -> Result { tracing::trace!(?path, "connecting to dispatcher"); let inner = UnixStream::connect(path).await?; tracing::trace!("successfully connected"); - Ok(Self { - inner, - send_buffer: BytesMut::with_capacity(SEND_BUFFER_LEN), - recv_buffer: BytesMut::with_capacity(RECV_BUFFER_LEN), - parser: StreamParser::new(), - }) + Ok(Self::new(inner)) } /// Register to receive SCION packet for the given address and port. diff --git a/crates/scion/src/udp_socket.rs b/crates/scion/src/udp_socket.rs index 1d9f5d2..4b8dbd0 100644 --- a/crates/scion/src/udp_socket.rs +++ b/crates/scion/src/udp_socket.rs @@ -2,7 +2,11 @@ //! A socket to send UDP datagrams via SCION. -use std::{cmp, io, sync::Arc}; +use std::{ + cmp, + io, + sync::{Arc, RwLock}, +}; use bytes::Bytes; use chrono::Utc; @@ -20,7 +24,7 @@ use crate::dispatcher::{self, get_dispatcher_path, DispatcherStream, Registratio #[allow(missing_docs)] #[derive(Debug, thiserror::Error)] -pub enum ConnectError { +pub enum BindError { #[error("failed to connect to the dispatcher, reason: {0}")] DispatcherConnectFailed(#[from] io::Error), #[error("failed to bind to the requested port")] @@ -37,7 +41,9 @@ pub enum SendError { #[error("path is expired")] PathExpired, #[error("remote address is not set")] - NoRemoteAddress, + NotConnected, + #[error("socket is already connected")] + AlreadyConnected, #[error("path is not set")] NoPath, #[error("no underlay next hop provided by path")] @@ -72,34 +78,32 @@ impl From for SendError { #[derive(Debug)] pub struct UdpSocket { inner: Arc, - local_address: SocketAddr, - remote_address: Option, - path: Option, } impl UdpSocket { - pub async fn bind(address: SocketAddr) -> Result { + pub async fn bind(address: SocketAddr) -> Result { Self::bind_with_dispatcher(address, get_dispatcher_path()).await } pub async fn bind_with_dispatcher + std::fmt::Debug>( address: SocketAddr, dispatcher_path: P, - ) -> Result { + ) -> Result { let mut stream = DispatcherStream::connect(dispatcher_path).await?; let local_address = stream.register(address).await?; - Ok(Self { - inner: Arc::new(UdpSocketInner::new(stream)), - local_address, - remote_address: None, - path: None, - }) + Ok(Self::new(stream, local_address)) + } + + fn new(stream: DispatcherStream, local_addr: SocketAddr) -> Self { + Self { + inner: Arc::new(UdpSocketInner::new(stream, local_addr)), + } } /// Returns the local SCION address to which this socket is bound. pub fn local_addr(&self) -> SocketAddr { - self.local_address + self.inner.local_addr() } /// Receive a SCION UDP packet. @@ -151,23 +155,22 @@ impl UdpSocket { /// Returns the remote SCION address set for this socket, if any. pub fn remote_addr(&self) -> Option { - self.remote_address + self.inner.remote_addr() } /// Returns the SCION path set for this socket, if any. - pub fn path(&self) -> Option<&Path> { - self.path.as_ref() + pub fn path(&self) -> Option { + self.inner.path() } /// Registers a remote address for this socket. - pub fn connect(&mut self, remote_address: SocketAddr) -> &mut Self { - self.remote_address = Some(remote_address); - self + pub fn connect(&self, remote_address: SocketAddr) { + self.inner.set_remote_address(remote_address); } - /// Registers a path for this socket. - pub fn set_path(&mut self, path: Path) -> &mut Self { - self.path = Some(path); + /// Registers or clears a path for this socket. + pub fn set_path(&self, path: Option) -> &Self { + self.inner.set_path(path); self } @@ -175,67 +178,41 @@ impl UdpSocket { /// /// Returns an error if the remote address or path are unset pub async fn send(&self, payload: Bytes) -> Result<(), SendError> { - self.send_to_with( - payload, - self.remote_address.ok_or(SendError::NoRemoteAddress)?, - self.path.as_ref().ok_or(SendError::NoPath)?, - ) - .await + self.inner.send_with_to(payload, None, None).await } /// Sends the payload to the specified destination using the registered path /// /// Returns an error if the path is unset pub async fn send_to(&self, payload: Bytes, destination: SocketAddr) -> Result<(), SendError> { - self.send_to_with( - payload, - destination, - self.path.as_ref().ok_or(SendError::NoPath)?, - ) - .await + self.inner + .send_with_to(payload, Some(destination), None) + .await } /// Sends the payload to the registered destination using the specified path /// /// Returns an error if the remote address is unset pub async fn send_with(&self, payload: Bytes, path: &Path) -> Result<(), SendError> { - self.send_to_with( - payload, - self.remote_address.ok_or(SendError::NoRemoteAddress)?, - path, - ) - .await + self.inner.send_with_to(payload, None, Some(path)).await } /// Sends the payload to the specified remote address and path - pub async fn send_to_with( + pub async fn send_with_to( &self, payload: Bytes, destination: SocketAddr, path: &Path, ) -> Result<(), SendError> { self.inner - .send_between_with( - payload, - &ByEndpoint { - destination, - source: self.local_addr(), - }, - path, - ) - .await?; - Ok(()) + .send_with_to(payload, Some(destination), Some(path)) + .await } } /// Error messages returned from the UDP socket. pub type ReceiveError = std::convert::Infallible; -#[derive(Debug)] -struct UdpSocketInner { - state: Mutex, -} - macro_rules! log_err { ($message:expr) => { |err| { @@ -245,19 +222,41 @@ macro_rules! log_err { }; } +#[derive(Debug)] +struct UdpSocketInner { + stream: Mutex, + state: RwLock>, +} + impl UdpSocketInner { - fn new(stream: DispatcherStream) -> Self { + fn new(stream: DispatcherStream, local_address: SocketAddr) -> Self { Self { - state: Mutex::new(State { stream }), + state: RwLock::new(Arc::new(State { + local_address, + remote_address: None, + path: None, + })), + stream: Mutex::new(stream), } } - async fn send_between_with( + async fn send_with_to( &self, payload: Bytes, - endhosts: &ByEndpoint, - path: &Path, + destination: Option, + path: Option<&Path>, ) -> Result<(), SendError> { + let state = self.state.read().unwrap().clone(); + let path = path.or(state.path.as_ref()).ok_or(SendError::NoPath)?; + let Some(destination) = destination.xor(state.remote_address) else { + // Either both are None or both are Some + return if state.remote_address.is_none() { + Err(SendError::NotConnected) + } else { + Err(SendError::AlreadyConnected) + }; + }; + if let Some(metadata) = &path.metadata { if metadata.expiration < Utc::now() { return Err(SendError::PathExpired); @@ -266,8 +265,8 @@ impl UdpSocketInner { let relay = if path.underlay_next_hop.is_some() { path.underlay_next_hop - } else if endhosts.source.isd_asn() == endhosts.destination.isd_asn() { - endhosts.destination.local_address().map(|mut socket_addr| { + } else if state.local_address.isd_asn() == destination.isd_asn() { + destination.local_address().map(|mut socket_addr| { socket_addr.set_port(dispatcher::UNDERLAY_PORT); socket_addr }) @@ -275,12 +274,18 @@ impl UdpSocketInner { return Err(SendError::NoUnderlayNextHop); }; - let packet = ScionPacketUdp::new(endhosts, path, payload)?; + let packet = ScionPacketUdp::new( + &ByEndpoint { + destination, + source: state.local_address, + }, + path, + payload, + )?; - self.state + self.stream .lock() .await - .stream .send_packet_via(relay, packet) .await?; Ok(()) @@ -309,8 +314,8 @@ impl UdpSocketInner { async fn recv_loop(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr, Path), ReceiveError> { loop { let receive_result = { - let state = &mut *self.state.lock().await; - state.stream.receive_packet().await + let stream = &mut *self.stream.lock().await; + stream.receive_packet().await }; match receive_result { @@ -366,9 +371,82 @@ impl UdpSocketInner { ), )) } + + pub fn local_addr(&self) -> SocketAddr { + self.state.read().unwrap().local_address + } + + pub fn remote_addr(&self) -> Option { + self.state.read().unwrap().remote_address + } + + pub fn set_remote_address(&self, remote_address: SocketAddr) { + Arc::make_mut(&mut *self.state.write().unwrap()).remote_address = Some(remote_address); + } + + pub fn path(&self) -> Option { + self.state.read().unwrap().path.clone() + } + + pub fn set_path(&self, path: Option) { + Arc::make_mut(&mut *self.state.write().unwrap()).path = path; + } } -#[derive(Debug)] +#[derive(Debug, Clone)] struct State { - stream: DispatcherStream, + local_address: SocketAddr, + remote_address: Option, + path: Option, +} + +#[cfg(test)] +mod tests { + use tokio::{net::UnixStream, sync::Notify}; + + use super::*; + + fn new_socket() -> Result<(SocketAddr, UdpSocket), Box> { + let (inner, _) = UnixStream::pair()?; + let stream = DispatcherStream::new(inner); + let local_addr: SocketAddr = "[1-ff00:0:111,127.0.0.17]:12300".parse()?; + + Ok((local_addr, UdpSocket::new(stream, local_addr))) + } + + #[tokio::test] + async fn set_path() -> Result<(), Box> { + let (local_addr, socket) = new_socket()?; + + let path = Path::empty(ByEndpoint::with_cloned(local_addr.isd_asn())); + + let notify = Arc::new(Notify::new()); + let notify2 = Arc::new(Notify::new()); + + let (result1, result2) = tokio::join!( + async { + let initial = socket.path(); + socket.set_path(Some(path.clone())); + notify.notify_one(); + + notify2.notified().await; + let last_set = socket.path(); + + (initial, last_set) + }, + async { + notify.notified().await; + let first_set = socket.path(); + socket.set_path(None); + notify2.notify_one(); + + first_set + } + ); + + assert_eq!(result1, (None, None)); + assert_eq!(result2, Some(path)); + + Ok(()) + } } diff --git a/crates/scion/tests/test_udp_socket.rs b/crates/scion/tests/test_udp_socket.rs index 050b6c9..711bd73 100644 --- a/crates/scion/tests/test_udp_socket.rs +++ b/crates/scion/tests/test_udp_socket.rs @@ -8,7 +8,7 @@ use scion::{ use scion_proto::{address::SocketAddr, packet::ByEndpoint, path::Path}; use tokio::sync::Mutex; -type TestError = Result<(), Box>; +type TestResult = Result>; static MESSAGE: Bytes = Bytes::from_static(b"Hello SCION!"); const TIMEOUT: Duration = std::time::Duration::from_secs(1); @@ -24,8 +24,7 @@ macro_rules! test_send_receive_reply { LOCK.get_or_init(|| Mutex::default()) } - async fn get_sockets( - ) -> Result<(UdpSocket, UdpSocket, Path), Box> { + async fn get_sockets() -> TestResult<(UdpSocket, UdpSocket, Path)> { let endpoints: ByEndpoint = ByEndpoint { source: $source.parse().unwrap(), destination: $destination.parse().unwrap(), @@ -33,7 +32,7 @@ macro_rules! test_send_receive_reply { let daemon_client_source = DaemonClient::connect(&get_daemon_address()) .await .expect("should be able to connect"); - let mut socket_source = UdpSocket::bind(endpoints.source).await?; + let socket_source = UdpSocket::bind(endpoints.source).await?; let socket_destination = UdpSocket::bind(endpoints.destination).await?; socket_source.connect(endpoints.destination); @@ -43,14 +42,14 @@ macro_rules! test_send_receive_reply { .next() .unwrap(); println!("Forward path: {:?}", path_forward.dataplane_path); - socket_source.set_path(path_forward.clone()); + socket_source.set_path(Some(path_forward.clone())); Ok((socket_source, socket_destination, path_forward)) } #[tokio::test] #[ignore = "requires daemon and dispatcher"] - async fn message() -> TestError { + async fn message() -> TestResult { let _lock = lock().lock().await; let (socket_source, socket_destination, ..) = get_sockets().await?; @@ -67,7 +66,7 @@ macro_rules! test_send_receive_reply { #[tokio::test] #[ignore = "requires daemon and dispatcher"] - async fn message_and_response() -> TestError { + async fn message_and_response() -> TestResult { let _lock = lock().lock().await; let (socket_source, socket_destination, path_forward) = get_sockets().await?; @@ -84,7 +83,7 @@ macro_rules! test_send_receive_reply { println!("Reply path: {:?}", path.dataplane_path); socket_destination - .send_to_with(MESSAGE.clone(), sender, &path) + .send_with_to(MESSAGE.clone(), sender, &path) .await?; let (_, path_return) =