Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(udp_socket): add multiple receive variants with/without path information to avoid allocating memory #92

Merged
merged 3 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion crates/scion-proto/src/packet/headers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,15 @@ pub struct ByEndpoint<T> {

impl<T> ByEndpoint<T> {
/// Swaps source and destination.
pub fn reverse(&mut self) -> &mut Self {
pub fn reverse(self) -> Self {
mlegner marked this conversation as resolved.
Show resolved Hide resolved
Self {
source: self.destination,
destination: self.source,
}
}

/// Swaps source and destination in place.
pub fn reverse_in_place(&mut self) -> &mut Self {
mlegner marked this conversation as resolved.
Show resolved Hide resolved
std::mem::swap(&mut self.source, &mut self.destination);
self
}
Expand Down
99 changes: 77 additions & 22 deletions crates/scion/src/udp_socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,22 +102,51 @@ impl UdpSocket {
self.local_address
}

/// Receive a SCION UDP packet from a remote endpoint.
/// Receive a SCION UDP packet.
///
/// The UDP payload is written into the provided buffer. If there is insufficient space, excess
/// data is dropped. The returned number of bytes always refers to the amount of data in the UDP
/// payload.
pub async fn recv(&self, buffer: &mut [u8]) -> Result<usize, ReceiveError> {
let (packet_len, _) = self.inner.recv_from(buffer).await?;
Ok(packet_len)
}
mlegner marked this conversation as resolved.
Show resolved Hide resolved

/// Receive a SCION UDP packet from a remote endpoint.
///
/// Additionally returns
/// This behaves like [`Self::recv`] but additionally returns the remote SCION socket address.
pub async fn recv_from(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), ReceiveError> {
self.inner.recv_from(buffer).await
}

/// Receive a SCION UDP packet from a remote endpoint with path information.
///
/// This behaves like [`Self::recv`] but additionally returns
/// - the remote SCION socket address and
/// - the path over which the packet was received. For supported path types, this path is
/// already reversed such that it can be used directly to send reply packets; for unsupported
/// path types, the path is unmodified.
mlegner marked this conversation as resolved.
Show resolved Hide resolved
pub async fn recv_from(
///
/// Note that copying/reversing the path requires allocating memory; if you do not need the path
/// information, consider using the method [`Self::recv_from`] instead.
pub async fn recv_with_path_from(
&self,
buffer: &mut [u8],
) -> Result<(usize, SocketAddr, Path), ReceiveError> {
self.inner.recv_from(buffer).await
self.inner.recv_with_path_from(buffer).await
}

/// Receive a SCION UDP packet with path information.
///
/// This behaves like [`Self::recv`] but additionally returns the path over which the packet was
/// received. For supported path types, this path is already reversed such that it can be used
/// directly to send reply packets; for unsupported path types, the path is unmodified.
///
/// Note that copying/reversing the path requires allocating memory; if you do not need the path
/// information, consider using the method [`Self::recv`] instead.
pub async fn recv_with_path(&self, buffer: &mut [u8]) -> Result<(usize, Path), ReceiveError> {
let (packet_len, _, path) = self.inner.recv_with_path_from(buffer).await?;
Ok((packet_len, path))
}
mlegner marked this conversation as resolved.
Show resolved Hide resolved

/// Returns the remote SCION address set for this socket, if any.
Expand Down Expand Up @@ -257,7 +286,43 @@ impl UdpSocketInner {
Ok(())
}

async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr, Path), ReceiveError> {
async fn recv_with_path_from(
&self,
buf: &mut [u8],
) -> Result<(usize, SocketAddr, Path), ReceiveError> {
let (packet_len, sender, last_host, scion_packet) = self.recv_loop(buf).await?;
let path = {
// Explicit match here in case we add other errors to the `reverse` method at some point
let dataplane_path = match scion_packet.headers.path.reverse() {
Ok(p) => p,
Err(UnsupportedPathType(_)) => scion_packet.headers.path.deep_copy(),
};
Path::new(
dataplane_path,
scion_packet.headers.address.ia.reverse(),
jpcsmith marked this conversation as resolved.
Show resolved Hide resolved
last_host,
)
};
Ok((packet_len, sender, path))
}

async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr), ReceiveError> {
let (packet_len, sender, ..) = self.recv_loop(buf).await?;
Ok((packet_len, sender))
}

async fn recv_loop(
&self,
buf: &mut [u8],
) -> Result<
(
usize,
SocketAddr,
Option<std::net::SocketAddr>,
ScionPacketRaw,
),
ReceiveError,
> {
loop {
let receive_result = {
let state = &mut *self.state.lock().await;
Expand All @@ -266,8 +331,11 @@ impl UdpSocketInner {

match receive_result {
Ok(packet) => {
if let Some(result) = self.parse_incoming(packet, buf) {
return Ok(result);
let last_host = packet.last_host;
if let Some((packet_len, sender, scion_packet)) =
self.parse_incoming(packet, buf)
{
return Ok((packet_len, sender, last_host, scion_packet));
} else {
continue;
}
Expand All @@ -281,7 +349,7 @@ impl UdpSocketInner {
&self,
mut packet: Packet,
buf: &mut [u8],
) -> Option<(usize, SocketAddr, Path)> {
) -> Option<(usize, SocketAddr, ScionPacketRaw)> {
mlegner marked this conversation as resolved.
Show resolved Hide resolved
// TODO(jsmith): Need a representation of the packets for logging purposes
let mut scion_packet = ScionPacketRaw::decode(&mut packet.content)
.map_err(log_err!("failed to decode SCION packet"))
Expand All @@ -303,24 +371,11 @@ impl UdpSocketInner {
return None;
};

let path = {
// Explicit match here in case we add other errors to the `reverse` method at some point
let dataplane_path = match scion_packet.headers.path.reverse() {
Ok(p) => p,
Err(UnsupportedPathType(_)) => scion_packet.headers.path.deep_copy(),
};
Path::new(
dataplane_path,
*scion_packet.headers.address.ia.reverse(),
packet.last_host,
)
};

let payload_len = udp_datagram.payload.len();
let copy_length = cmp::min(payload_len, buf.len());
buf[..copy_length].copy_from_slice(&udp_datagram.payload[..copy_length]);

Some((payload_len, source, path))
Some((payload_len, source, scion_packet))
}
}

Expand Down
132 changes: 85 additions & 47 deletions crates/scion/tests/test_udp_socket.rs
Original file line number Diff line number Diff line change
@@ -1,62 +1,100 @@
use std::{sync::OnceLock, time::Duration};

use bytes::Bytes;
use scion::{
daemon::{get_daemon_address, DaemonClient},
udp_socket::UdpSocket,
};
use scion_proto::{address::SocketAddr, packet::ByEndpoint};
use scion_proto::{address::SocketAddr, packet::ByEndpoint, path::Path};
use tokio::sync::Mutex;

type TestError = Result<(), Box<dyn std::error::Error>>;

static MESSAGE: Bytes = Bytes::from_static(b"Hello SCION!");
const TIMEOUT: Duration = std::time::Duration::from_secs(1);

macro_rules! test_send_receive_reply {
($name:ident, $source:expr, $destination:expr) => {
#[tokio::test]
#[ignore = "requires daemon and dispatcher"]
async fn $name() -> TestError {
let endpoints: ByEndpoint<SocketAddr> = ByEndpoint {
source: $source.parse().unwrap(),
destination: $destination.parse().unwrap(),
};
let daemon_client_source = DaemonClient::connect(&get_daemon_address())
.await
.expect("should be able to connect");
let mut socket_source = UdpSocket::bind(endpoints.source).await?;
let socket_destination = UdpSocket::bind(endpoints.destination).await?;

socket_source.connect(endpoints.destination);
let path_forward = daemon_client_source
.paths_to(endpoints.destination.isd_asn())
.await?
.next()
.unwrap();
println!("Forward path: {:?}", path_forward.dataplane_path);
socket_source.set_path(path_forward.clone());
socket_source.send(MESSAGE.clone()).await?;

let mut buffer = [0_u8; 100];
let (length, sender, path) = tokio::time::timeout(
std::time::Duration::from_secs(1),
socket_destination.recv_from(&mut buffer),
)
.await??;
assert_eq!(sender, endpoints.source);
assert_eq!(buffer[..length], MESSAGE[..]);

println!("Reply path: {:?}", path.dataplane_path);
socket_destination
.send_to_with(MESSAGE.clone(), sender, &path)
.await?;

let (_, _, path_return) = tokio::time::timeout(
std::time::Duration::from_secs(1),
socket_source.recv_from(&mut buffer),
)
.await??;
assert_eq!(path_return.isd_asn, path_forward.isd_asn);
assert_eq!(path_return.dataplane_path, path_forward.dataplane_path);

Ok(())
mod $name {
use super::*;

// Prevent tests running simultaneously to avoid registration errors from the dispatcher
fn lock() -> &'static Mutex<()> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::default())
}
jpcsmith marked this conversation as resolved.
Show resolved Hide resolved

async fn get_sockets(
) -> Result<(UdpSocket, UdpSocket, Path), Box<dyn std::error::Error>> {
let endpoints: ByEndpoint<SocketAddr> = ByEndpoint {
source: $source.parse().unwrap(),
destination: $destination.parse().unwrap(),
};
let daemon_client_source = DaemonClient::connect(&get_daemon_address())
.await
.expect("should be able to connect");
let mut socket_source = UdpSocket::bind(endpoints.source).await?;
let socket_destination = UdpSocket::bind(endpoints.destination).await?;

socket_source.connect(endpoints.destination);
let path_forward = daemon_client_source
.paths_to(endpoints.destination.isd_asn())
.await?
.next()
.unwrap();
println!("Forward path: {:?}", path_forward.dataplane_path);
socket_source.set_path(path_forward.clone());

Ok((socket_source, socket_destination, path_forward))
}

#[tokio::test]
#[ignore = "requires daemon and dispatcher"]
async fn message() -> TestError {
let _lock = lock().lock().await;

let (socket_source, socket_destination, ..) = get_sockets().await?;
socket_source.send(MESSAGE.clone()).await?;

let mut buffer = [0_u8; 100];
let (length, sender) =
tokio::time::timeout(TIMEOUT, socket_destination.recv_from(&mut buffer))
.await??;
assert_eq!(sender, socket_source.local_addr());
assert_eq!(buffer[..length], MESSAGE[..]);
Ok(())
}

#[tokio::test]
#[ignore = "requires daemon and dispatcher"]
async fn message_and_response() -> TestError {
let _lock = lock().lock().await;

let (socket_source, socket_destination, path_forward) = get_sockets().await?;
socket_source.send(MESSAGE.clone()).await?;

let mut buffer = [0_u8; 100];
let (length, sender, path) = tokio::time::timeout(
TIMEOUT,
socket_destination.recv_with_path_from(&mut buffer),
)
.await??;
assert_eq!(sender, socket_source.local_addr());
assert_eq!(buffer[..length], MESSAGE[..]);

println!("Reply path: {:?}", path.dataplane_path);
socket_destination
.send_to_with(MESSAGE.clone(), sender, &path)
.await?;

let (_, path_return) =
tokio::time::timeout(TIMEOUT, socket_source.recv_with_path(&mut buffer))
.await??;
assert_eq!(path_return.isd_asn, path_forward.isd_asn);
assert_eq!(path_return.dataplane_path, path_forward.dataplane_path);

Ok(())
}
}
};
}
Expand Down