From 0c54c52f3507bf069fb7d4bce1644bb218578475 Mon Sep 17 00:00:00 2001 From: Markus Legner Date: Mon, 11 Dec 2023 18:19:28 +0100 Subject: [PATCH] feat: add recv_from_without_path to avoid allocating memory --- crates/scion/src/udp_socket.rs | 65 +++++++++--- crates/scion/tests/test_udp_socket.rs | 136 +++++++++++++++++--------- 2 files changed, 142 insertions(+), 59 deletions(-) diff --git a/crates/scion/src/udp_socket.rs b/crates/scion/src/udp_socket.rs index 8c111e0..1fecf96 100644 --- a/crates/scion/src/udp_socket.rs +++ b/crates/scion/src/udp_socket.rs @@ -109,7 +109,8 @@ impl UdpSocket { /// payload. /// /// Additionally returns the remote SCION socket address, and the path over which the packet was - /// received. + /// received. Note that copying the path requires allocating memory; consider using the method + /// [`Self::recv_from_without_path`] instead, if you do not need the path information. pub async fn recv_from( &self, buffer: &mut [u8], @@ -117,6 +118,20 @@ impl UdpSocket { self.inner.recv_from(buffer).await } + /// Receive a SCION UDP packet from a remote endpoint. + /// + /// The UDP payload is written into the provided buffer. If there is insufficient space, excess + /// data is dropped. The returned number of bytes always refers to the amount of data in the UDP + /// payload. + /// + /// Additionally returns the remote SCION socket address. + pub async fn recv_from_without_path( + &self, + buffer: &mut [u8], + ) -> Result<(usize, SocketAddr), ReceiveError> { + self.inner.recv_from_without_path(buffer).await + } + /// Returns the remote SCION address set for this socket, if any. pub fn remote_addr(&self) -> Option { self.remote_address @@ -255,6 +270,34 @@ impl UdpSocketInner { } async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr, Path), ReceiveError> { + let (packet_len, sender, last_host, scion_packet) = self.recv_from_loop(buf).await?; + let path = { + let dataplane_path = scion_packet.headers.path.deep_copy(); + Path::new(dataplane_path, scion_packet.headers.address.ia, last_host) + }; + Ok((packet_len, sender, path)) + } + + async fn recv_from_without_path( + &self, + buf: &mut [u8], + ) -> Result<(usize, SocketAddr), ReceiveError> { + let (packet_len, sender, ..) = self.recv_from_loop(buf).await?; + Ok((packet_len, sender)) + } + + async fn recv_from_loop( + &self, + buf: &mut [u8], + ) -> Result< + ( + usize, + SocketAddr, + Option, + ScionPacketRaw, + ), + ReceiveError, + > { loop { let receive_result = { let state = &mut *self.state.lock().await; @@ -263,8 +306,11 @@ impl UdpSocketInner { match receive_result { Ok(packet) => { - if let Some(result) = self.parse_incoming(packet, buf) { - return Ok(result); + let last_host = packet.last_host; + if let Some((packet_len, sender, scion_packet)) = + self.parse_incoming(packet, buf) + { + return Ok((packet_len, sender, last_host, scion_packet)); } else { continue; } @@ -278,7 +324,7 @@ impl UdpSocketInner { &self, mut packet: Packet, buf: &mut [u8], - ) -> Option<(usize, SocketAddr, Path)> { + ) -> Option<(usize, SocketAddr, ScionPacketRaw)> { // TODO(jsmith): Need a representation of the packets for logging purposes let mut scion_packet = ScionPacketRaw::decode(&mut packet.content) .map_err(log_err!("failed to decode SCION packet")) @@ -300,20 +346,11 @@ impl UdpSocketInner { return None; }; - let path = { - let dataplane_path = scion_packet.headers.path.deep_copy(); - Path::new( - dataplane_path, - scion_packet.headers.address.ia, - packet.last_host, - ) - }; - let payload_len = udp_datagram.payload.len(); let copy_length = cmp::min(payload_len, buf.len()); buf[..copy_length].copy_from_slice(&udp_datagram.payload[..copy_length]); - Some((payload_len, source, path)) + Some((payload_len, source, scion_packet)) } } diff --git a/crates/scion/tests/test_udp_socket.rs b/crates/scion/tests/test_udp_socket.rs index 8ad8892..2df1f35 100644 --- a/crates/scion/tests/test_udp_socket.rs +++ b/crates/scion/tests/test_udp_socket.rs @@ -1,9 +1,12 @@ +use std::sync::OnceLock; + use bytes::Bytes; use scion::{ daemon::{get_daemon_address, DaemonClient}, udp_socket::UdpSocket, }; -use scion_proto::{address::SocketAddr, packet::ByEndpoint}; +use scion_proto::{address::SocketAddr, packet::ByEndpoint, path::Path}; +use tokio::sync::Mutex; type TestError = Result<(), Box>; @@ -11,50 +14,93 @@ static MESSAGE: Bytes = Bytes::from_static(b"Hello SCION!"); macro_rules! test_send_receive_reply { ($name:ident, $source:expr, $destination:expr) => { - #[tokio::test] - #[ignore = "requires daemon and dispatcher"] - async fn $name() -> TestError { - let endpoints: ByEndpoint = ByEndpoint { - source: $source.parse().unwrap(), - destination: $destination.parse().unwrap(), - }; - let daemon_client_source = DaemonClient::connect(&get_daemon_address()) - .await - .expect("should be able to connect"); - let mut socket_source = UdpSocket::bind(endpoints.source).await?; - let socket_destination = UdpSocket::bind(endpoints.destination).await?; - - socket_source.connect(endpoints.destination); - let path_forward = daemon_client_source - .paths_to(endpoints.destination.isd_asn()) - .await? - .next() - .unwrap(); - println!("Forward path: {:?}", path_forward.dataplane_path); - socket_source.set_path(path_forward); - socket_source.send(MESSAGE.clone()).await?; - - let mut buffer = [0_u8; 100]; - let (length, sender, mut path) = tokio::time::timeout( - std::time::Duration::from_secs(1), - socket_destination.recv_from(&mut buffer), - ) - .await??; - assert_eq!(sender, endpoints.source); - assert_eq!(buffer[..length], MESSAGE[..]); - - path.reverse()?; - println!("Reply path: {:?}", path.dataplane_path); - socket_destination - .send_to_with(MESSAGE.clone(), sender, &path) - .await?; - let _ = tokio::time::timeout( - std::time::Duration::from_secs(1), - socket_source.recv_from(&mut buffer), - ) - .await??; - - Ok(()) + mod $name { + use super::*; + + // Prevent tests running simultaneously to avoid registration errors from the dispatcher + fn lock() -> &'static Mutex<()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::default()) + } + + async fn get_sockets( + ) -> Result<(UdpSocket, UdpSocket, Path), Box> { + let endpoints: ByEndpoint = ByEndpoint { + source: $source.parse().unwrap(), + destination: $destination.parse().unwrap(), + }; + let daemon_client_source = DaemonClient::connect(&get_daemon_address()) + .await + .expect("should be able to connect"); + let mut socket_source = UdpSocket::bind(endpoints.source).await?; + let socket_destination = UdpSocket::bind(endpoints.destination).await?; + + socket_source.connect(endpoints.destination); + let path_forward = daemon_client_source + .paths_to(endpoints.destination.isd_asn()) + .await? + .next() + .unwrap(); + println!("Forward path: {:?}", path_forward.dataplane_path); + socket_source.set_path(path_forward.clone()); + + Ok((socket_source, socket_destination, path_forward)) + } + + #[tokio::test] + #[ignore = "requires daemon and dispatcher"] + async fn message() -> TestError { + let _lock = lock().lock().await; + + let (socket_source, socket_destination, ..) = get_sockets().await?; + socket_source.send(MESSAGE.clone()).await?; + + let mut buffer = [0_u8; 100]; + let (length, sender) = tokio::time::timeout( + std::time::Duration::from_secs(1), + socket_destination.recv_from_without_path(&mut buffer), + ) + .await??; + assert_eq!(sender, socket_source.local_addr()); + assert_eq!(buffer[..length], MESSAGE[..]); + Ok(()) + } + + #[tokio::test] + #[ignore = "requires daemon and dispatcher"] + async fn message_and_response() -> TestError { + let _lock = lock().lock().await; + + let (socket_source, socket_destination, path_forward) = get_sockets().await?; + socket_source.send(MESSAGE.clone()).await?; + + let mut buffer = [0_u8; 100]; + let (length, sender, mut path) = tokio::time::timeout( + std::time::Duration::from_secs(1), + socket_destination.recv_from(&mut buffer), + ) + .await??; + assert_eq!(sender, socket_source.local_addr()); + assert_eq!(buffer[..length], MESSAGE[..]); + + path.reverse()?; + println!("Reply path: {:?}", path.dataplane_path); + socket_destination + .send_to_with(MESSAGE.clone(), sender, &path) + .await?; + + let (_, _, mut path_return) = tokio::time::timeout( + std::time::Duration::from_secs(1), + socket_source.recv_from(&mut buffer), + ) + .await??; + assert_eq!( + path_return.reverse()?.dataplane_path, + path_forward.dataplane_path + ); + + Ok(()) + } } }; }