Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(udp_socket): add multiple receive variants with/without path information to avoid allocating memory #92

Merged
merged 3 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions crates/scion-proto/src/packet/headers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,14 @@ pub struct ByEndpoint<T> {

impl<T> ByEndpoint<T> {
/// Swaps source and destination.
pub fn into_reversed(self) -> Self {
Self {
source: self.destination,
destination: self.source,
}
}

/// Swaps source and destination in place.
pub fn reverse(&mut self) -> &mut Self {
std::mem::swap(&mut self.source, &mut self.destination);
self
Expand Down
24 changes: 18 additions & 6 deletions crates/scion-proto/src/path/dataplane.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,26 @@ impl DataplanePath {
}

/// Reverses the path.
pub fn reverse(&self) -> Result<Self, UnsupportedPathType> {
pub fn to_reversed(&self) -> Result<Self, UnsupportedPathType> {
match self {
Self::EmptyPath => Ok(Self::EmptyPath),
Self::Standard(standard_path) => Ok(Self::Standard(standard_path.reverse())),
Self::Standard(standard_path) => Ok(Self::Standard(standard_path.to_reversed())),
Self::Unsupported { path_type, .. } => Err(UnsupportedPathType(u8::from(*path_type))),
}
}

/// Reverses the path in place.
pub fn reverse(&mut self) -> Result<&mut Self, UnsupportedPathType> {
match self {
Self::EmptyPath => (),
Self::Standard(standard_path) => *standard_path = standard_path.to_reversed(),
Self::Unsupported { path_type, .. } => {
return Err(UnsupportedPathType(u8::from(*path_type)))
}
}
Ok(self)
}

/// Returns true iff the path is a [`DataplanePath::EmptyPath`]
pub fn is_empty(&self) -> bool {
self == &Self::EmptyPath
Expand Down Expand Up @@ -205,9 +217,9 @@ mod tests {
#[test]
fn reverse_empty() {
let dataplane_path = DataplanePath::EmptyPath;
let reverse_path = dataplane_path.reverse().unwrap();
let reverse_path = dataplane_path.to_reversed().unwrap();
assert_eq!(dataplane_path, reverse_path);
assert_eq!(reverse_path.reverse().unwrap(), dataplane_path);
assert_eq!(reverse_path.to_reversed().unwrap(), dataplane_path);
}

test_path_create_encode_decode!(
Expand All @@ -231,8 +243,8 @@ mod tests {
#[test]
fn reverse_standard() {
let dataplane_path = standard_path();
let reverse_path = dataplane_path.reverse().unwrap();
let reverse_path = dataplane_path.to_reversed().unwrap();
assert!(dataplane_path != reverse_path);
assert_eq!(reverse_path.reverse().unwrap(), dataplane_path);
assert_eq!(reverse_path.to_reversed().unwrap(), dataplane_path);
}
}
6 changes: 3 additions & 3 deletions crates/scion-proto/src/path/standard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ impl StandardPath {
///
/// Can panic if the meta header is inconsistent with the encoded path or the encoded path
/// itself is inconsistent (e.g., the `current_info_field` points to an empty segment).
pub fn reverse(&self) -> Self {
pub fn to_reversed(&self) -> Self {
let meta_header = PathMetaHeader {
current_info_field: (self.meta_header.info_fields_count() as u8
- self.meta_header.current_info_field.get()
Expand Down Expand Up @@ -390,9 +390,9 @@ mod tests {
let mut data = $encoded_path;
let header = StandardPath::decode(&mut data).expect("valid decode");

let reverse_path = header.reverse();
let reverse_path = header.to_reversed();
assert!(header != reverse_path);
assert_eq!(header, reverse_path.reverse());
assert_eq!(header, reverse_path.to_reversed());
}
}
};
Expand Down
88 changes: 66 additions & 22 deletions crates/scion/src/udp_socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,22 +102,51 @@ impl UdpSocket {
self.local_address
}

/// Receive a SCION UDP packet from a remote endpoint.
/// Receive a SCION UDP packet.
///
/// The UDP payload is written into the provided buffer. If there is insufficient space, excess
/// data is dropped. The returned number of bytes always refers to the amount of data in the UDP
/// payload.
pub async fn recv(&self, buffer: &mut [u8]) -> Result<usize, ReceiveError> {
let (packet_len, _) = self.inner.recv_from(buffer).await?;
Ok(packet_len)
}
mlegner marked this conversation as resolved.
Show resolved Hide resolved

/// Receive a SCION UDP packet from a remote endpoint.
///
/// This behaves like [`Self::recv`] but additionally returns the remote SCION socket address.
pub async fn recv_from(&self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), ReceiveError> {
self.inner.recv_from(buffer).await
}

/// Receive a SCION UDP packet from a remote endpoint with path information.
///
/// Additionally returns
/// This behaves like [`Self::recv`] but additionally returns
/// - the remote SCION socket address and
/// - the path over which the packet was received. For supported path types, this path is
/// already reversed such that it can be used directly to send reply packets; for unsupported
/// path types, the path is unmodified.
pub async fn recv_from(
/// path types, the path is copied unmodified.
///
/// Note that copying/reversing the path requires allocating memory; if you do not need the path
/// information, consider using the method [`Self::recv_from`] instead.
pub async fn recv_with_path_from(
&self,
buffer: &mut [u8],
) -> Result<(usize, SocketAddr, Path), ReceiveError> {
self.inner.recv_from(buffer).await
self.inner.recv_with_path_from(buffer).await
}

/// Receive a SCION UDP packet with path information.
///
/// This behaves like [`Self::recv`] but additionally returns the path over which the packet was
/// received. For supported path types, this path is already reversed such that it can be used
/// directly to send reply packets; for unsupported path types, the path is copied unmodified.
///
/// Note that copying/reversing the path requires allocating memory; if you do not need the path
/// information, consider using the method [`Self::recv`] instead.
pub async fn recv_with_path(&self, buffer: &mut [u8]) -> Result<(usize, Path), ReceiveError> {
let (packet_len, _, path) = self.inner.recv_with_path_from(buffer).await?;
Ok((packet_len, path))
}
mlegner marked this conversation as resolved.
Show resolved Hide resolved

/// Returns the remote SCION address set for this socket, if any.
Expand Down Expand Up @@ -257,7 +286,27 @@ impl UdpSocketInner {
Ok(())
}

async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr, Path), ReceiveError> {
async fn recv_with_path_from(
&self,
buf: &mut [u8],
) -> Result<(usize, SocketAddr, Path), ReceiveError> {
let (packet_len, sender, mut path) = self.recv_loop(buf).await?;
// Explicit match here in case we add other errors to the `reverse` method at some point
match path.dataplane_path.reverse() {
Ok(_) => {
path.isd_asn.reverse();
}
Err(UnsupportedPathType(_)) => path.dataplane_path = path.dataplane_path.deep_copy(),
};
Ok((packet_len, sender, path))
}

async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr), ReceiveError> {
let (packet_len, sender, ..) = self.recv_loop(buf).await?;
Ok((packet_len, sender))
}

async fn recv_loop(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr, Path), ReceiveError> {
loop {
let receive_result = {
let state = &mut *self.state.lock().await;
Expand All @@ -266,8 +315,8 @@ impl UdpSocketInner {

match receive_result {
Ok(packet) => {
if let Some(result) = self.parse_incoming(packet, buf) {
return Ok(result);
if let Some((packet_len, sender, path)) = self.parse_incoming(packet, buf) {
return Ok((packet_len, sender, path));
} else {
continue;
}
Expand Down Expand Up @@ -303,24 +352,19 @@ impl UdpSocketInner {
return None;
};

let path = {
// Explicit match here in case we add other errors to the `reverse` method at some point
let dataplane_path = match scion_packet.headers.path.reverse() {
Ok(p) => p,
Err(UnsupportedPathType(_)) => scion_packet.headers.path.deep_copy(),
};
Path::new(
dataplane_path,
*scion_packet.headers.address.ia.reverse(),
packet.last_host,
)
};

let payload_len = udp_datagram.payload.len();
let copy_length = cmp::min(payload_len, buf.len());
buf[..copy_length].copy_from_slice(&udp_datagram.payload[..copy_length]);

Some((payload_len, source, path))
Some((
payload_len,
source,
Path::new(
scion_packet.headers.path,
scion_packet.headers.address.ia,
packet.last_host,
),
))
}
}

Expand Down
132 changes: 85 additions & 47 deletions crates/scion/tests/test_udp_socket.rs
Original file line number Diff line number Diff line change
@@ -1,62 +1,100 @@
use std::{sync::OnceLock, time::Duration};

use bytes::Bytes;
use scion::{
daemon::{get_daemon_address, DaemonClient},
udp_socket::UdpSocket,
};
use scion_proto::{address::SocketAddr, packet::ByEndpoint};
use scion_proto::{address::SocketAddr, packet::ByEndpoint, path::Path};
use tokio::sync::Mutex;

type TestError = Result<(), Box<dyn std::error::Error>>;

static MESSAGE: Bytes = Bytes::from_static(b"Hello SCION!");
const TIMEOUT: Duration = std::time::Duration::from_secs(1);

macro_rules! test_send_receive_reply {
($name:ident, $source:expr, $destination:expr) => {
#[tokio::test]
#[ignore = "requires daemon and dispatcher"]
async fn $name() -> TestError {
let endpoints: ByEndpoint<SocketAddr> = ByEndpoint {
source: $source.parse().unwrap(),
destination: $destination.parse().unwrap(),
};
let daemon_client_source = DaemonClient::connect(&get_daemon_address())
.await
.expect("should be able to connect");
let mut socket_source = UdpSocket::bind(endpoints.source).await?;
let socket_destination = UdpSocket::bind(endpoints.destination).await?;

socket_source.connect(endpoints.destination);
let path_forward = daemon_client_source
.paths_to(endpoints.destination.isd_asn())
.await?
.next()
.unwrap();
println!("Forward path: {:?}", path_forward.dataplane_path);
socket_source.set_path(path_forward.clone());
socket_source.send(MESSAGE.clone()).await?;

let mut buffer = [0_u8; 100];
let (length, sender, path) = tokio::time::timeout(
std::time::Duration::from_secs(1),
socket_destination.recv_from(&mut buffer),
)
.await??;
assert_eq!(sender, endpoints.source);
assert_eq!(buffer[..length], MESSAGE[..]);

println!("Reply path: {:?}", path.dataplane_path);
socket_destination
.send_to_with(MESSAGE.clone(), sender, &path)
.await?;

let (_, _, path_return) = tokio::time::timeout(
std::time::Duration::from_secs(1),
socket_source.recv_from(&mut buffer),
)
.await??;
assert_eq!(path_return.isd_asn, path_forward.isd_asn);
assert_eq!(path_return.dataplane_path, path_forward.dataplane_path);

Ok(())
mod $name {
use super::*;

// Prevent tests running simultaneously to avoid registration errors from the dispatcher
fn lock() -> &'static Mutex<()> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::default())
}
jpcsmith marked this conversation as resolved.
Show resolved Hide resolved

async fn get_sockets(
) -> Result<(UdpSocket, UdpSocket, Path), Box<dyn std::error::Error>> {
let endpoints: ByEndpoint<SocketAddr> = ByEndpoint {
source: $source.parse().unwrap(),
destination: $destination.parse().unwrap(),
};
let daemon_client_source = DaemonClient::connect(&get_daemon_address())
.await
.expect("should be able to connect");
let mut socket_source = UdpSocket::bind(endpoints.source).await?;
let socket_destination = UdpSocket::bind(endpoints.destination).await?;

socket_source.connect(endpoints.destination);
let path_forward = daemon_client_source
.paths_to(endpoints.destination.isd_asn())
.await?
.next()
.unwrap();
println!("Forward path: {:?}", path_forward.dataplane_path);
socket_source.set_path(path_forward.clone());

Ok((socket_source, socket_destination, path_forward))
}

#[tokio::test]
#[ignore = "requires daemon and dispatcher"]
async fn message() -> TestError {
let _lock = lock().lock().await;

let (socket_source, socket_destination, ..) = get_sockets().await?;
socket_source.send(MESSAGE.clone()).await?;

let mut buffer = [0_u8; 100];
let (length, sender) =
tokio::time::timeout(TIMEOUT, socket_destination.recv_from(&mut buffer))
.await??;
assert_eq!(sender, socket_source.local_addr());
assert_eq!(buffer[..length], MESSAGE[..]);
Ok(())
}

#[tokio::test]
#[ignore = "requires daemon and dispatcher"]
async fn message_and_response() -> TestError {
let _lock = lock().lock().await;

let (socket_source, socket_destination, path_forward) = get_sockets().await?;
socket_source.send(MESSAGE.clone()).await?;

let mut buffer = [0_u8; 100];
let (length, sender, path) = tokio::time::timeout(
TIMEOUT,
socket_destination.recv_with_path_from(&mut buffer),
)
.await??;
assert_eq!(sender, socket_source.local_addr());
assert_eq!(buffer[..length], MESSAGE[..]);

println!("Reply path: {:?}", path.dataplane_path);
socket_destination
.send_to_with(MESSAGE.clone(), sender, &path)
.await?;

let (_, path_return) =
tokio::time::timeout(TIMEOUT, socket_source.recv_with_path(&mut buffer))
.await??;
assert_eq!(path_return.isd_asn, path_forward.isd_asn);
assert_eq!(path_return.dataplane_path, path_forward.dataplane_path);

Ok(())
}
}
};
}
Expand Down