Skip to content

Commit

Permalink
feat(udp_socket): add multiple receive variants with/without path inf…
Browse files Browse the repository at this point in the history
…ormation to avoid allocating memory (#92)
  • Loading branch information
mlegner authored Dec 12, 2023
1 parent 9a574bd commit 96b9dbd
Show file tree
Hide file tree
Showing 5 changed files with 180 additions and 78 deletions.
8 changes: 8 additions & 0 deletions crates/scion-proto/src/packet/headers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,14 @@ pub struct ByEndpoint<T> {

impl<T> ByEndpoint<T> {
/// Swaps source and destination.
pub fn into_reversed(self) -> Self {
Self {
source: self.destination,
destination: self.source,
}
}

/// Swaps source and destination in place.
pub fn reverse(&mut self) -> &mut Self {
std::mem::swap(&mut self.source, &mut self.destination);
self
Expand Down
24 changes: 18 additions & 6 deletions crates/scion-proto/src/path/dataplane.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,26 @@ impl DataplanePath {
}

/// Reverses the path.
pub fn reverse(&self) -> Result<Self, UnsupportedPathType> {
pub fn to_reversed(&self) -> Result<Self, UnsupportedPathType> {
match self {
Self::EmptyPath => Ok(Self::EmptyPath),
Self::Standard(standard_path) => Ok(Self::Standard(standard_path.reverse())),
Self::Standard(standard_path) => Ok(Self::Standard(standard_path.to_reversed())),
Self::Unsupported { path_type, .. } => Err(UnsupportedPathType(u8::from(*path_type))),
}
}

/// Reverses the path in place.
pub fn reverse(&mut self) -> Result<&mut Self, UnsupportedPathType> {
match self {
Self::EmptyPath => (),
Self::Standard(standard_path) => *standard_path = standard_path.to_reversed(),
Self::Unsupported { path_type, .. } => {
return Err(UnsupportedPathType(u8::from(*path_type)))
}
}
Ok(self)
}

/// Returns true iff the path is a [`DataplanePath::EmptyPath`]
pub fn is_empty(&self) -> bool {
self == &Self::EmptyPath
Expand Down Expand Up @@ -205,9 +217,9 @@ mod tests {
#[test]
fn reverse_empty() {
let dataplane_path = DataplanePath::EmptyPath;
let reverse_path = dataplane_path.reverse().unwrap();
let reverse_path = dataplane_path.to_reversed().unwrap();
assert_eq!(dataplane_path, reverse_path);
assert_eq!(reverse_path.reverse().unwrap(), dataplane_path);
assert_eq!(reverse_path.to_reversed().unwrap(), dataplane_path);
}

test_path_create_encode_decode!(
Expand All @@ -231,8 +243,8 @@ mod tests {
#[test]
fn reverse_standard() {
let dataplane_path = standard_path();
let reverse_path = dataplane_path.reverse().unwrap();
let reverse_path = dataplane_path.to_reversed().unwrap();
assert!(dataplane_path != reverse_path);
assert_eq!(reverse_path.reverse().unwrap(), dataplane_path);
assert_eq!(reverse_path.to_reversed().unwrap(), dataplane_path);
}
}
6 changes: 3 additions & 3 deletions crates/scion-proto/src/path/standard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ impl StandardPath {
///
/// Can panic if the meta header is inconsistent with the encoded path or the encoded path
/// itself is inconsistent (e.g., the `current_info_field` points to an empty segment).
pub fn reverse(&self) -> Self {
pub fn to_reversed(&self) -> Self {
let meta_header = PathMetaHeader {
current_info_field: (self.meta_header.info_fields_count() as u8
- self.meta_header.current_info_field.get()
Expand Down Expand Up @@ -390,9 +390,9 @@ mod tests {
let mut data = $encoded_path;
let header = StandardPath::decode(&mut data).expect("valid decode");

let reverse_path = header.reverse();
let reverse_path = header.to_reversed();
assert!(header != reverse_path);
assert_eq!(header, reverse_path.reverse());
assert_eq!(header, reverse_path.to_reversed());
}
}
};
Expand Down
88 changes: 66 additions & 22 deletions crates/scion/src/udp_socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,22 +102,51 @@ impl UdpSocket {
self.local_address
}

/// Receive a SCION UDP packet from a remote endpoint.
/// Receive a SCION UDP packet.
///
/// The UDP payload is written into the provided buffer. If there is insufficient space, excess
/// data is dropped. The returned number of bytes always refers to the amount of data in the UDP
/// payload.
pub async fn recv(&self, buffer: &mut [u8]) -> Result<usize, ReceiveError> {
let (packet_len, _) = self.inner.recv_from(buffer).await?;
Ok(packet_len)
}

/// Receive a SCION UDP packet from a remote endpoint.
///
/// This behaves like [`Self::recv`] but additionally returns the remote SCION socket address.
pub async fn recv_from(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), ReceiveError> {
self.inner.recv_from(buffer).await
}

/// Receive a SCION UDP packet from a remote endpoint with path information.
///
/// Additionally returns
/// This behaves like [`Self::recv`] but additionally returns
/// - the remote SCION socket address and
/// - the path over which the packet was received. For supported path types, this path is
/// already reversed such that it can be used directly to send reply packets; for unsupported
/// path types, the path is unmodified.
pub async fn recv_from(
/// path types, the path is copied unmodified.
///
/// Note that copying/reversing the path requires allocating memory; if you do not need the path
/// information, consider using the method [`Self::recv_from`] instead.
pub async fn recv_with_path_from(
&self,
buffer: &mut [u8],
) -> Result<(usize, SocketAddr, Path), ReceiveError> {
self.inner.recv_from(buffer).await
self.inner.recv_with_path_from(buffer).await
}

/// Receive a SCION UDP packet with path information.
///
/// This behaves like [`Self::recv`] but additionally returns the path over which the packet was
/// received. For supported path types, this path is already reversed such that it can be used
/// directly to send reply packets; for unsupported path types, the path is copied unmodified.
///
/// Note that copying/reversing the path requires allocating memory; if you do not need the path
/// information, consider using the method [`Self::recv`] instead.
pub async fn recv_with_path(&self, buffer: &mut [u8]) -> Result<(usize, Path), ReceiveError> {
let (packet_len, _, path) = self.inner.recv_with_path_from(buffer).await?;
Ok((packet_len, path))
}

/// Returns the remote SCION address set for this socket, if any.
Expand Down Expand Up @@ -257,7 +286,27 @@ impl UdpSocketInner {
Ok(())
}

async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr, Path), ReceiveError> {
async fn recv_with_path_from(
&self,
buf: &mut [u8],
) -> Result<(usize, SocketAddr, Path), ReceiveError> {
let (packet_len, sender, mut path) = self.recv_loop(buf).await?;
// Explicit match here in case we add other errors to the `reverse` method at some point
match path.dataplane_path.reverse() {
Ok(_) => {
path.isd_asn.reverse();
}
Err(UnsupportedPathType(_)) => path.dataplane_path = path.dataplane_path.deep_copy(),
};
Ok((packet_len, sender, path))
}

async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr), ReceiveError> {
let (packet_len, sender, ..) = self.recv_loop(buf).await?;
Ok((packet_len, sender))
}

async fn recv_loop(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr, Path), ReceiveError> {
loop {
let receive_result = {
let state = &mut *self.state.lock().await;
Expand All @@ -266,8 +315,8 @@ impl UdpSocketInner {

match receive_result {
Ok(packet) => {
if let Some(result) = self.parse_incoming(packet, buf) {
return Ok(result);
if let Some((packet_len, sender, path)) = self.parse_incoming(packet, buf) {
return Ok((packet_len, sender, path));
} else {
continue;
}
Expand Down Expand Up @@ -303,24 +352,19 @@ impl UdpSocketInner {
return None;
};

let path = {
// Explicit match here in case we add other errors to the `reverse` method at some point
let dataplane_path = match scion_packet.headers.path.reverse() {
Ok(p) => p,
Err(UnsupportedPathType(_)) => scion_packet.headers.path.deep_copy(),
};
Path::new(
dataplane_path,
*scion_packet.headers.address.ia.reverse(),
packet.last_host,
)
};

let payload_len = udp_datagram.payload.len();
let copy_length = cmp::min(payload_len, buf.len());
buf[..copy_length].copy_from_slice(&udp_datagram.payload[..copy_length]);

Some((payload_len, source, path))
Some((
payload_len,
source,
Path::new(
scion_packet.headers.path,
scion_packet.headers.address.ia,
packet.last_host,
),
))
}
}

Expand Down
132 changes: 85 additions & 47 deletions crates/scion/tests/test_udp_socket.rs
Original file line number Diff line number Diff line change
@@ -1,62 +1,100 @@
use std::{sync::OnceLock, time::Duration};

use bytes::Bytes;
use scion::{
daemon::{get_daemon_address, DaemonClient},
udp_socket::UdpSocket,
};
use scion_proto::{address::SocketAddr, packet::ByEndpoint};
use scion_proto::{address::SocketAddr, packet::ByEndpoint, path::Path};
use tokio::sync::Mutex;

type TestError = Result<(), Box<dyn std::error::Error>>;

static MESSAGE: Bytes = Bytes::from_static(b"Hello SCION!");
const TIMEOUT: Duration = std::time::Duration::from_secs(1);

macro_rules! test_send_receive_reply {
($name:ident, $source:expr, $destination:expr) => {
#[tokio::test]
#[ignore = "requires daemon and dispatcher"]
async fn $name() -> TestError {
let endpoints: ByEndpoint<SocketAddr> = ByEndpoint {
source: $source.parse().unwrap(),
destination: $destination.parse().unwrap(),
};
let daemon_client_source = DaemonClient::connect(&get_daemon_address())
.await
.expect("should be able to connect");
let mut socket_source = UdpSocket::bind(endpoints.source).await?;
let socket_destination = UdpSocket::bind(endpoints.destination).await?;

socket_source.connect(endpoints.destination);
let path_forward = daemon_client_source
.paths_to(endpoints.destination.isd_asn())
.await?
.next()
.unwrap();
println!("Forward path: {:?}", path_forward.dataplane_path);
socket_source.set_path(path_forward.clone());
socket_source.send(MESSAGE.clone()).await?;

let mut buffer = [0_u8; 100];
let (length, sender, path) = tokio::time::timeout(
std::time::Duration::from_secs(1),
socket_destination.recv_from(&mut buffer),
)
.await??;
assert_eq!(sender, endpoints.source);
assert_eq!(buffer[..length], MESSAGE[..]);

println!("Reply path: {:?}", path.dataplane_path);
socket_destination
.send_to_with(MESSAGE.clone(), sender, &path)
.await?;

let (_, _, path_return) = tokio::time::timeout(
std::time::Duration::from_secs(1),
socket_source.recv_from(&mut buffer),
)
.await??;
assert_eq!(path_return.isd_asn, path_forward.isd_asn);
assert_eq!(path_return.dataplane_path, path_forward.dataplane_path);

Ok(())
mod $name {
use super::*;

// Prevent tests running simultaneously to avoid registration errors from the dispatcher
fn lock() -> &'static Mutex<()> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::default())
}

async fn get_sockets(
) -> Result<(UdpSocket, UdpSocket, Path), Box<dyn std::error::Error>> {
let endpoints: ByEndpoint<SocketAddr> = ByEndpoint {
source: $source.parse().unwrap(),
destination: $destination.parse().unwrap(),
};
let daemon_client_source = DaemonClient::connect(&get_daemon_address())
.await
.expect("should be able to connect");
let mut socket_source = UdpSocket::bind(endpoints.source).await?;
let socket_destination = UdpSocket::bind(endpoints.destination).await?;

socket_source.connect(endpoints.destination);
let path_forward = daemon_client_source
.paths_to(endpoints.destination.isd_asn())
.await?
.next()
.unwrap();
println!("Forward path: {:?}", path_forward.dataplane_path);
socket_source.set_path(path_forward.clone());

Ok((socket_source, socket_destination, path_forward))
}

#[tokio::test]
#[ignore = "requires daemon and dispatcher"]
async fn message() -> TestError {
let _lock = lock().lock().await;

let (socket_source, socket_destination, ..) = get_sockets().await?;
socket_source.send(MESSAGE.clone()).await?;

let mut buffer = [0_u8; 100];
let (length, sender) =
tokio::time::timeout(TIMEOUT, socket_destination.recv_from(&mut buffer))
.await??;
assert_eq!(sender, socket_source.local_addr());
assert_eq!(buffer[..length], MESSAGE[..]);
Ok(())
}

#[tokio::test]
#[ignore = "requires daemon and dispatcher"]
async fn message_and_response() -> TestError {
let _lock = lock().lock().await;

let (socket_source, socket_destination, path_forward) = get_sockets().await?;
socket_source.send(MESSAGE.clone()).await?;

let mut buffer = [0_u8; 100];
let (length, sender, path) = tokio::time::timeout(
TIMEOUT,
socket_destination.recv_with_path_from(&mut buffer),
)
.await??;
assert_eq!(sender, socket_source.local_addr());
assert_eq!(buffer[..length], MESSAGE[..]);

println!("Reply path: {:?}", path.dataplane_path);
socket_destination
.send_to_with(MESSAGE.clone(), sender, &path)
.await?;

let (_, path_return) =
tokio::time::timeout(TIMEOUT, socket_source.recv_with_path(&mut buffer))
.await??;
assert_eq!(path_return.isd_asn, path_forward.isd_asn);
assert_eq!(path_return.dataplane_path, path_forward.dataplane_path);

Ok(())
}
}
};
}
Expand Down

0 comments on commit 96b9dbd

Please sign in to comment.