Skip to content

Commit

Permalink
feat: add recv_from_without_path to avoid allocating memory
Browse files Browse the repository at this point in the history
  • Loading branch information
mlegner committed Dec 11, 2023
1 parent 83933b1 commit 0c54c52
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 59 deletions.
65 changes: 51 additions & 14 deletions crates/scion/src/udp_socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,29 @@ impl UdpSocket {
/// payload.
///
/// Additionally returns the remote SCION socket address, and the path over which the packet was
/// received.
/// received. Note that copying the path requires allocating memory; consider using the method
/// [`Self::recv_from_without_path`] instead, if you do not need the path information.
pub async fn recv_from(
&self,
buffer: &mut [u8],
) -> Result<(usize, SocketAddr, Path), ReceiveError> {
self.inner.recv_from(buffer).await
}

/// Receive a SCION UDP packet from a remote endpoint.
///
/// The UDP payload is written into the provided buffer. If there is insufficient space, excess
/// data is dropped. The returned number of bytes always refers to the amount of data in the UDP
/// payload.
///
/// Additionally returns the remote SCION socket address.
pub async fn recv_from_without_path(
&self,
buffer: &mut [u8],
) -> Result<(usize, SocketAddr), ReceiveError> {
self.inner.recv_from_without_path(buffer).await
}

/// Returns the remote SCION address set for this socket, if any.
pub fn remote_addr(&self) -> Option<SocketAddr> {
self.remote_address
Expand Down Expand Up @@ -255,6 +270,34 @@ impl UdpSocketInner {
}

async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr, Path), ReceiveError> {
let (packet_len, sender, last_host, scion_packet) = self.recv_from_loop(buf).await?;
let path = {
let dataplane_path = scion_packet.headers.path.deep_copy();
Path::new(dataplane_path, scion_packet.headers.address.ia, last_host)
};
Ok((packet_len, sender, path))
}

async fn recv_from_without_path(
&self,
buf: &mut [u8],
) -> Result<(usize, SocketAddr), ReceiveError> {
let (packet_len, sender, ..) = self.recv_from_loop(buf).await?;
Ok((packet_len, sender))
}

async fn recv_from_loop(
&self,
buf: &mut [u8],
) -> Result<
(
usize,
SocketAddr,
Option<std::net::SocketAddr>,
ScionPacketRaw,
),
ReceiveError,
> {
loop {
let receive_result = {
let state = &mut *self.state.lock().await;
Expand All @@ -263,8 +306,11 @@ impl UdpSocketInner {

match receive_result {
Ok(packet) => {
if let Some(result) = self.parse_incoming(packet, buf) {
return Ok(result);
let last_host = packet.last_host;
if let Some((packet_len, sender, scion_packet)) =
self.parse_incoming(packet, buf)
{
return Ok((packet_len, sender, last_host, scion_packet));
} else {
continue;
}
Expand All @@ -278,7 +324,7 @@ impl UdpSocketInner {
&self,
mut packet: Packet,
buf: &mut [u8],
) -> Option<(usize, SocketAddr, Path)> {
) -> Option<(usize, SocketAddr, ScionPacketRaw)> {
// TODO(jsmith): Need a representation of the packets for logging purposes
let mut scion_packet = ScionPacketRaw::decode(&mut packet.content)
.map_err(log_err!("failed to decode SCION packet"))
Expand All @@ -300,20 +346,11 @@ impl UdpSocketInner {
return None;
};

let path = {
let dataplane_path = scion_packet.headers.path.deep_copy();
Path::new(
dataplane_path,
scion_packet.headers.address.ia,
packet.last_host,
)
};

let payload_len = udp_datagram.payload.len();
let copy_length = cmp::min(payload_len, buf.len());
buf[..copy_length].copy_from_slice(&udp_datagram.payload[..copy_length]);

Some((payload_len, source, path))
Some((payload_len, source, scion_packet))
}
}

Expand Down
136 changes: 91 additions & 45 deletions crates/scion/tests/test_udp_socket.rs
Original file line number Diff line number Diff line change
@@ -1,60 +1,106 @@
use std::sync::OnceLock;

use bytes::Bytes;
use scion::{
daemon::{get_daemon_address, DaemonClient},
udp_socket::UdpSocket,
};
use scion_proto::{address::SocketAddr, packet::ByEndpoint};
use scion_proto::{address::SocketAddr, packet::ByEndpoint, path::Path};
use tokio::sync::Mutex;

type TestError = Result<(), Box<dyn std::error::Error>>;

static MESSAGE: Bytes = Bytes::from_static(b"Hello SCION!");

macro_rules! test_send_receive_reply {
($name:ident, $source:expr, $destination:expr) => {
#[tokio::test]
#[ignore = "requires daemon and dispatcher"]
async fn $name() -> TestError {
let endpoints: ByEndpoint<SocketAddr> = ByEndpoint {
source: $source.parse().unwrap(),
destination: $destination.parse().unwrap(),
};
let daemon_client_source = DaemonClient::connect(&get_daemon_address())
.await
.expect("should be able to connect");
let mut socket_source = UdpSocket::bind(endpoints.source).await?;
let socket_destination = UdpSocket::bind(endpoints.destination).await?;

socket_source.connect(endpoints.destination);
let path_forward = daemon_client_source
.paths_to(endpoints.destination.isd_asn())
.await?
.next()
.unwrap();
println!("Forward path: {:?}", path_forward.dataplane_path);
socket_source.set_path(path_forward);
socket_source.send(MESSAGE.clone()).await?;

let mut buffer = [0_u8; 100];
let (length, sender, mut path) = tokio::time::timeout(
std::time::Duration::from_secs(1),
socket_destination.recv_from(&mut buffer),
)
.await??;
assert_eq!(sender, endpoints.source);
assert_eq!(buffer[..length], MESSAGE[..]);

path.reverse()?;
println!("Reply path: {:?}", path.dataplane_path);
socket_destination
.send_to_with(MESSAGE.clone(), sender, &path)
.await?;
let _ = tokio::time::timeout(
std::time::Duration::from_secs(1),
socket_source.recv_from(&mut buffer),
)
.await??;

Ok(())
mod $name {
use super::*;

// Prevent tests running simultaneously to avoid registration errors from the dispatcher
fn lock() -> &'static Mutex<()> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::default())
}

async fn get_sockets(
) -> Result<(UdpSocket, UdpSocket, Path), Box<dyn std::error::Error>> {
let endpoints: ByEndpoint<SocketAddr> = ByEndpoint {
source: $source.parse().unwrap(),
destination: $destination.parse().unwrap(),
};
let daemon_client_source = DaemonClient::connect(&get_daemon_address())
.await
.expect("should be able to connect");
let mut socket_source = UdpSocket::bind(endpoints.source).await?;
let socket_destination = UdpSocket::bind(endpoints.destination).await?;

socket_source.connect(endpoints.destination);
let path_forward = daemon_client_source
.paths_to(endpoints.destination.isd_asn())
.await?
.next()
.unwrap();
println!("Forward path: {:?}", path_forward.dataplane_path);
socket_source.set_path(path_forward.clone());

Ok((socket_source, socket_destination, path_forward))
}

#[tokio::test]
#[ignore = "requires daemon and dispatcher"]
async fn message() -> TestError {
let _lock = lock().lock().await;

let (socket_source, socket_destination, ..) = get_sockets().await?;
socket_source.send(MESSAGE.clone()).await?;

let mut buffer = [0_u8; 100];
let (length, sender) = tokio::time::timeout(
std::time::Duration::from_secs(1),
socket_destination.recv_from_without_path(&mut buffer),
)
.await??;
assert_eq!(sender, socket_source.local_addr());
assert_eq!(buffer[..length], MESSAGE[..]);
Ok(())
}

#[tokio::test]
#[ignore = "requires daemon and dispatcher"]
async fn message_and_response() -> TestError {
let _lock = lock().lock().await;

let (socket_source, socket_destination, path_forward) = get_sockets().await?;
socket_source.send(MESSAGE.clone()).await?;

let mut buffer = [0_u8; 100];
let (length, sender, mut path) = tokio::time::timeout(
std::time::Duration::from_secs(1),
socket_destination.recv_from(&mut buffer),
)
.await??;
assert_eq!(sender, socket_source.local_addr());
assert_eq!(buffer[..length], MESSAGE[..]);

path.reverse()?;
println!("Reply path: {:?}", path.dataplane_path);
socket_destination
.send_to_with(MESSAGE.clone(), sender, &path)
.await?;

let (_, _, mut path_return) = tokio::time::timeout(
std::time::Duration::from_secs(1),
socket_source.recv_from(&mut buffer),
)
.await??;
assert_eq!(
path_return.reverse()?.dataplane_path,
path_forward.dataplane_path
);

Ok(())
}
}
};
}
Expand Down

0 comments on commit 0c54c52

Please sign in to comment.