diff --git a/crates/scion-proto/src/packet/headers.rs b/crates/scion-proto/src/packet/headers.rs index c1cd6e9..7b67325 100644 --- a/crates/scion-proto/src/packet/headers.rs +++ b/crates/scion-proto/src/packet/headers.rs @@ -103,7 +103,15 @@ pub struct ByEndpoint { impl ByEndpoint { /// Swaps source and destination. - pub fn reverse(&mut self) -> &mut Self { + pub fn reverse(self) -> Self { + Self { + source: self.destination, + destination: self.source, + } + } + + /// Swaps source and destination in place. + pub fn reverse_in_place(&mut self) -> &mut Self { std::mem::swap(&mut self.source, &mut self.destination); self } diff --git a/crates/scion/src/udp_socket.rs b/crates/scion/src/udp_socket.rs index 5980b9f..a959316 100644 --- a/crates/scion/src/udp_socket.rs +++ b/crates/scion/src/udp_socket.rs @@ -113,6 +113,9 @@ impl UdpSocket { /// - the path over which the packet was received. For supported path types, this path is /// already reversed such that it can be used directly to send reply packets; for unsupported /// path types, the path is unmodified. + /// + /// Note that copying/reversing the path requires allocating memory; if you do not need the path + /// information, consider using the method [`Self::recv_from_without_path`] instead. pub async fn recv_from( &self, buffer: &mut [u8], @@ -120,6 +123,20 @@ impl UdpSocket { self.inner.recv_from(buffer).await } + /// Receive a SCION UDP packet from a remote endpoint. + /// + /// The UDP payload is written into the provided buffer. If there is insufficient space, excess + /// data is dropped. The returned number of bytes always refers to the amount of data in the UDP + /// payload. + /// + /// Additionally returns the remote SCION socket address. + pub async fn recv_from_without_path( + &self, + buffer: &mut [u8], + ) -> Result<(usize, SocketAddr), ReceiveError> { + self.inner.recv_from_without_path(buffer).await + } + /// Returns the remote SCION address set for this socket, if any. pub fn remote_addr(&self) -> Option { self.remote_address @@ -258,6 +275,42 @@ impl UdpSocketInner { } async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr, Path), ReceiveError> { + let (packet_len, sender, last_host, scion_packet) = self.recv_from_loop(buf).await?; + let path = { + // Explicit match here in case we add other errors to the `reverse` method at some point + let dataplane_path = match scion_packet.headers.path.reverse() { + Ok(p) => p, + Err(UnsupportedPathType(_)) => scion_packet.headers.path.deep_copy(), + }; + Path::new( + dataplane_path, + scion_packet.headers.address.ia.reverse(), + last_host, + ) + }; + Ok((packet_len, sender, path)) + } + + async fn recv_from_without_path( + &self, + buf: &mut [u8], + ) -> Result<(usize, SocketAddr), ReceiveError> { + let (packet_len, sender, ..) = self.recv_from_loop(buf).await?; + Ok((packet_len, sender)) + } + + async fn recv_from_loop( + &self, + buf: &mut [u8], + ) -> Result< + ( + usize, + SocketAddr, + Option, + ScionPacketRaw, + ), + ReceiveError, + > { loop { let receive_result = { let state = &mut *self.state.lock().await; @@ -266,8 +319,11 @@ impl UdpSocketInner { match receive_result { Ok(packet) => { - if let Some(result) = self.parse_incoming(packet, buf) { - return Ok(result); + let last_host = packet.last_host; + if let Some((packet_len, sender, scion_packet)) = + self.parse_incoming(packet, buf) + { + return Ok((packet_len, sender, last_host, scion_packet)); } else { continue; } @@ -281,7 +337,7 @@ impl UdpSocketInner { &self, mut packet: Packet, buf: &mut [u8], - ) -> Option<(usize, SocketAddr, Path)> { + ) -> Option<(usize, SocketAddr, ScionPacketRaw)> { // TODO(jsmith): Need a representation of the packets for logging purposes let mut scion_packet = ScionPacketRaw::decode(&mut packet.content) .map_err(log_err!("failed to decode SCION packet")) @@ -303,24 +359,11 @@ impl UdpSocketInner { return None; }; - let path = { - // Explicit match here in case we add other errors to the `reverse` method at some point - let dataplane_path = match scion_packet.headers.path.reverse() { - Ok(p) => p, - Err(UnsupportedPathType(_)) => scion_packet.headers.path.deep_copy(), - }; - Path::new( - dataplane_path, - *scion_packet.headers.address.ia.reverse(), - packet.last_host, - ) - }; - let payload_len = udp_datagram.payload.len(); let copy_length = cmp::min(payload_len, buf.len()); buf[..copy_length].copy_from_slice(&udp_datagram.payload[..copy_length]); - Some((payload_len, source, path)) + Some((payload_len, source, scion_packet)) } } diff --git a/crates/scion/tests/test_udp_socket.rs b/crates/scion/tests/test_udp_socket.rs index 15add9d..5e9257a 100644 --- a/crates/scion/tests/test_udp_socket.rs +++ b/crates/scion/tests/test_udp_socket.rs @@ -1,62 +1,99 @@ +use std::{sync::OnceLock, time::Duration}; + use bytes::Bytes; use scion::{ daemon::{get_daemon_address, DaemonClient}, udp_socket::UdpSocket, }; -use scion_proto::{address::SocketAddr, packet::ByEndpoint}; +use scion_proto::{address::SocketAddr, packet::ByEndpoint, path::Path}; +use tokio::sync::Mutex; type TestError = Result<(), Box>; static MESSAGE: Bytes = Bytes::from_static(b"Hello SCION!"); +const TIMEOUT: Duration = std::time::Duration::from_secs(1); macro_rules! test_send_receive_reply { ($name:ident, $source:expr, $destination:expr) => { - #[tokio::test] - #[ignore = "requires daemon and dispatcher"] - async fn $name() -> TestError { - let endpoints: ByEndpoint = ByEndpoint { - source: $source.parse().unwrap(), - destination: $destination.parse().unwrap(), - }; - let daemon_client_source = DaemonClient::connect(&get_daemon_address()) - .await - .expect("should be able to connect"); - let mut socket_source = UdpSocket::bind(endpoints.source).await?; - let socket_destination = UdpSocket::bind(endpoints.destination).await?; - - socket_source.connect(endpoints.destination); - let path_forward = daemon_client_source - .paths_to(endpoints.destination.isd_asn()) - .await? - .next() - .unwrap(); - println!("Forward path: {:?}", path_forward.dataplane_path); - socket_source.set_path(path_forward.clone()); - socket_source.send(MESSAGE.clone()).await?; - - let mut buffer = [0_u8; 100]; - let (length, sender, path) = tokio::time::timeout( - std::time::Duration::from_secs(1), - socket_destination.recv_from(&mut buffer), - ) - .await??; - assert_eq!(sender, endpoints.source); - assert_eq!(buffer[..length], MESSAGE[..]); - - println!("Reply path: {:?}", path.dataplane_path); - socket_destination - .send_to_with(MESSAGE.clone(), sender, &path) - .await?; - - let (_, _, path_return) = tokio::time::timeout( - std::time::Duration::from_secs(1), - socket_source.recv_from(&mut buffer), - ) - .await??; - assert_eq!(path_return.isd_asn, path_forward.isd_asn); - assert_eq!(path_return.dataplane_path, path_forward.dataplane_path); - - Ok(()) + mod $name { + use super::*; + + // Prevent tests running simultaneously to avoid registration errors from the dispatcher + fn lock() -> &'static Mutex<()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::default()) + } + + async fn get_sockets( + ) -> Result<(UdpSocket, UdpSocket, Path), Box> { + let endpoints: ByEndpoint = ByEndpoint { + source: $source.parse().unwrap(), + destination: $destination.parse().unwrap(), + }; + let daemon_client_source = DaemonClient::connect(&get_daemon_address()) + .await + .expect("should be able to connect"); + let mut socket_source = UdpSocket::bind(endpoints.source).await?; + let socket_destination = UdpSocket::bind(endpoints.destination).await?; + + socket_source.connect(endpoints.destination); + let path_forward = daemon_client_source + .paths_to(endpoints.destination.isd_asn()) + .await? + .next() + .unwrap(); + println!("Forward path: {:?}", path_forward.dataplane_path); + socket_source.set_path(path_forward.clone()); + + Ok((socket_source, socket_destination, path_forward)) + } + + #[tokio::test] + #[ignore = "requires daemon and dispatcher"] + async fn message() -> TestError { + let _lock = lock().lock().await; + + let (socket_source, socket_destination, ..) = get_sockets().await?; + socket_source.send(MESSAGE.clone()).await?; + + let mut buffer = [0_u8; 100]; + let (length, sender) = tokio::time::timeout( + TIMEOUT, + socket_destination.recv_from_without_path(&mut buffer), + ) + .await??; + assert_eq!(sender, socket_source.local_addr()); + assert_eq!(buffer[..length], MESSAGE[..]); + Ok(()) + } + + #[tokio::test] + #[ignore = "requires daemon and dispatcher"] + async fn message_and_response() -> TestError { + let _lock = lock().lock().await; + + let (socket_source, socket_destination, path_forward) = get_sockets().await?; + socket_source.send(MESSAGE.clone()).await?; + + let mut buffer = [0_u8; 100]; + let (length, sender, path) = + tokio::time::timeout(TIMEOUT, socket_destination.recv_from(&mut buffer)) + .await??; + assert_eq!(sender, socket_source.local_addr()); + assert_eq!(buffer[..length], MESSAGE[..]); + + println!("Reply path: {:?}", path.dataplane_path); + socket_destination + .send_to_with(MESSAGE.clone(), sender, &path) + .await?; + + let (_, _, path_return) = + tokio::time::timeout(TIMEOUT, socket_source.recv_from(&mut buffer)).await??; + assert_eq!(path_return.isd_asn, path_forward.isd_asn); + assert_eq!(path_return.dataplane_path, path_forward.dataplane_path); + + Ok(()) + } } }; }