Skip to content

Commit

Permalink
refactor: allow setters to be used on shared socket
Browse files Browse the repository at this point in the history
  • Loading branch information
jpcsmith committed Dec 13, 2023
1 parent 96b9dbd commit 902933b
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 86 deletions.
17 changes: 11 additions & 6 deletions crates/scion/src/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,18 +97,23 @@ pub struct DispatcherStream {
}

impl DispatcherStream {
/// Create a new DispatcherStream over an already connected UnixStream.
pub fn new(stream: UnixStream) -> Self {
Self {
inner: stream,
send_buffer: BytesMut::with_capacity(SEND_BUFFER_LEN),
recv_buffer: BytesMut::with_capacity(RECV_BUFFER_LEN),
parser: StreamParser::new(),
}
}

/// Connects to the dispatcher over a Unix socket at the provided path.
pub async fn connect<P: AsRef<Path> + std::fmt::Debug>(path: P) -> Result<Self, io::Error> {
tracing::trace!(?path, "connecting to dispatcher");
let inner = UnixStream::connect(path).await?;
tracing::trace!("successfully connected");

Ok(Self {
inner,
send_buffer: BytesMut::with_capacity(SEND_BUFFER_LEN),
recv_buffer: BytesMut::with_capacity(RECV_BUFFER_LEN),
parser: StreamParser::new(),
})
Ok(Self::new(inner))
}

/// Register to receive SCION packet for the given address and port.
Expand Down
222 changes: 150 additions & 72 deletions crates/scion/src/udp_socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

//! A socket to send UDP datagrams via SCION.

use std::{cmp, io, sync::Arc};
use std::{
cmp,
io,
sync::{Arc, RwLock},
};

use bytes::Bytes;
use chrono::Utc;
Expand All @@ -20,7 +24,7 @@ use crate::dispatcher::{self, get_dispatcher_path, DispatcherStream, Registratio

#[allow(missing_docs)]
#[derive(Debug, thiserror::Error)]
pub enum ConnectError {
pub enum BindError {
#[error("failed to connect to the dispatcher, reason: {0}")]
DispatcherConnectFailed(#[from] io::Error),
#[error("failed to bind to the requested port")]
Expand All @@ -37,7 +41,9 @@ pub enum SendError {
#[error("path is expired")]
PathExpired,
#[error("remote address is not set")]
NoRemoteAddress,
NotConnected,
#[error("socket is already connected")]
AlreadyConnected,
#[error("path is not set")]
NoPath,
#[error("no underlay next hop provided by path")]
Expand Down Expand Up @@ -72,34 +78,32 @@ impl From<packet::EncodeError> for SendError {
#[derive(Debug)]
pub struct UdpSocket {
inner: Arc<UdpSocketInner>,
local_address: SocketAddr,
remote_address: Option<SocketAddr>,
path: Option<Path>,
}

impl UdpSocket {
pub async fn bind(address: SocketAddr) -> Result<Self, ConnectError> {
pub async fn bind(address: SocketAddr) -> Result<Self, BindError> {
Self::bind_with_dispatcher(address, get_dispatcher_path()).await
}

pub async fn bind_with_dispatcher<P: AsRef<std::path::Path> + std::fmt::Debug>(
address: SocketAddr,
dispatcher_path: P,
) -> Result<Self, ConnectError> {
) -> Result<Self, BindError> {
let mut stream = DispatcherStream::connect(dispatcher_path).await?;
let local_address = stream.register(address).await?;

Ok(Self {
inner: Arc::new(UdpSocketInner::new(stream)),
local_address,
remote_address: None,
path: None,
})
Ok(Self::new(stream, local_address))
}

fn new(stream: DispatcherStream, local_addr: SocketAddr) -> Self {
Self {
inner: Arc::new(UdpSocketInner::new(stream, local_addr)),
}
}

/// Returns the local SCION address to which this socket is bound.
pub fn local_addr(&self) -> SocketAddr {
self.local_address
self.inner.local_addr()
}

/// Receive a SCION UDP packet.
Expand Down Expand Up @@ -151,91 +155,64 @@ impl UdpSocket {

/// Returns the remote SCION address set for this socket, if any.
pub fn remote_addr(&self) -> Option<SocketAddr> {
self.remote_address
self.inner.remote_addr()
}

/// Returns the SCION path set for this socket, if any.
pub fn path(&self) -> Option<&Path> {
self.path.as_ref()
pub fn path(&self) -> Option<Path> {
self.inner.path()
}

/// Registers a remote address for this socket.
pub fn connect(&mut self, remote_address: SocketAddr) -> &mut Self {
self.remote_address = Some(remote_address);
self
pub fn connect(&self, remote_address: SocketAddr) {
self.inner.set_remote_address(remote_address);
}

/// Registers a path for this socket.
pub fn set_path(&mut self, path: Path) -> &mut Self {
self.path = Some(path);
/// Registers or clears a path for this socket.
pub fn set_path(&self, path: Option<Path>) -> &Self {
self.inner.set_path(path);
self
}

/// Sends the payload using the registered remote address and path
///
/// Returns an error if the remote address or path are unset
pub async fn send(&self, payload: Bytes) -> Result<(), SendError> {
self.send_to_with(
payload,
self.remote_address.ok_or(SendError::NoRemoteAddress)?,
self.path.as_ref().ok_or(SendError::NoPath)?,
)
.await
self.inner.send_with_to(payload, None, None).await
}

/// Sends the payload to the specified destination using the registered path
///
/// Returns an error if the path is unset
pub async fn send_to(&self, payload: Bytes, destination: SocketAddr) -> Result<(), SendError> {
self.send_to_with(
payload,
destination,
self.path.as_ref().ok_or(SendError::NoPath)?,
)
.await
self.inner
.send_with_to(payload, Some(destination), None)
.await
}

/// Sends the payload to the registered destination using the specified path
///
/// Returns an error if the remote address is unset
pub async fn send_with(&self, payload: Bytes, path: &Path) -> Result<(), SendError> {
self.send_to_with(
payload,
self.remote_address.ok_or(SendError::NoRemoteAddress)?,
path,
)
.await
self.inner.send_with_to(payload, None, Some(path)).await
}

/// Sends the payload to the specified remote address and path
pub async fn send_to_with(
pub async fn send_with_to(
&self,
payload: Bytes,
destination: SocketAddr,
path: &Path,
) -> Result<(), SendError> {
self.inner
.send_between_with(
payload,
&ByEndpoint {
destination,
source: self.local_addr(),
},
path,
)
.await?;
Ok(())
.send_with_to(payload, Some(destination), Some(path))
.await
}
}

/// Error messages returned from the UDP socket.
pub type ReceiveError = std::convert::Infallible;

#[derive(Debug)]
struct UdpSocketInner {
state: Mutex<State>,
}

macro_rules! log_err {
($message:expr) => {
|err| {
Expand All @@ -245,19 +222,41 @@ macro_rules! log_err {
};
}

#[derive(Debug)]
struct UdpSocketInner {
stream: Mutex<DispatcherStream>,
state: RwLock<Arc<State>>,
}

impl UdpSocketInner {
fn new(stream: DispatcherStream) -> Self {
fn new(stream: DispatcherStream, local_address: SocketAddr) -> Self {
Self {
state: Mutex::new(State { stream }),
state: RwLock::new(Arc::new(State {
local_address,
remote_address: None,
path: None,
})),
stream: Mutex::new(stream),
}
}

async fn send_between_with(
async fn send_with_to(
&self,
payload: Bytes,
endhosts: &ByEndpoint<SocketAddr>,
path: &Path,
destination: Option<SocketAddr>,
path: Option<&Path>,
) -> Result<(), SendError> {
let state = self.state.read().unwrap().clone();
let path = path.or(state.path.as_ref()).ok_or(SendError::NoPath)?;
let Some(destination) = destination.xor(state.remote_address) else {
// Either both are None or both are Some
return if state.remote_address.is_none() {
Err(SendError::NotConnected)
} else {
Err(SendError::AlreadyConnected)
};
};

if let Some(metadata) = &path.metadata {
if metadata.expiration < Utc::now() {
return Err(SendError::PathExpired);
Expand All @@ -266,21 +265,27 @@ impl UdpSocketInner {

let relay = if path.underlay_next_hop.is_some() {
path.underlay_next_hop
} else if endhosts.source.isd_asn() == endhosts.destination.isd_asn() {
endhosts.destination.local_address().map(|mut socket_addr| {
} else if state.local_address.isd_asn() == destination.isd_asn() {
destination.local_address().map(|mut socket_addr| {
socket_addr.set_port(dispatcher::UNDERLAY_PORT);
socket_addr
})
} else {
return Err(SendError::NoUnderlayNextHop);
};

let packet = ScionPacketUdp::new(endhosts, path, payload)?;
let packet = ScionPacketUdp::new(
&ByEndpoint {
destination,
source: state.local_address,
},
path,
payload,
)?;

self.state
self.stream
.lock()
.await
.stream
.send_packet_via(relay, packet)
.await?;
Ok(())
Expand Down Expand Up @@ -309,8 +314,8 @@ impl UdpSocketInner {
async fn recv_loop(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr, Path), ReceiveError> {
loop {
let receive_result = {
let state = &mut *self.state.lock().await;
state.stream.receive_packet().await
let stream = &mut *self.stream.lock().await;
stream.receive_packet().await
};

match receive_result {
Expand Down Expand Up @@ -366,9 +371,82 @@ impl UdpSocketInner {
),
))
}

pub fn local_addr(&self) -> SocketAddr {
self.state.read().unwrap().local_address
}

pub fn remote_addr(&self) -> Option<SocketAddr> {
self.state.read().unwrap().remote_address
}

pub fn set_remote_address(&self, remote_address: SocketAddr) {
Arc::make_mut(&mut *self.state.write().unwrap()).remote_address = Some(remote_address);
}

pub fn path(&self) -> Option<Path> {
self.state.read().unwrap().path.clone()
}

pub fn set_path(&self, path: Option<Path>) {
Arc::make_mut(&mut *self.state.write().unwrap()).path = path;
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
struct State {
stream: DispatcherStream,
local_address: SocketAddr,
remote_address: Option<SocketAddr>,
path: Option<Path>,
}

#[cfg(test)]
mod tests {
use tokio::{net::UnixStream, sync::Notify};

use super::*;

fn new_socket() -> Result<(SocketAddr, UdpSocket), Box<dyn std::error::Error>> {
let (inner, _) = UnixStream::pair()?;
let stream = DispatcherStream::new(inner);
let local_addr: SocketAddr = "[1-ff00:0:111,127.0.0.17]:12300".parse()?;

Ok((local_addr, UdpSocket::new(stream, local_addr)))
}

#[tokio::test]
async fn set_path() -> Result<(), Box<dyn std::error::Error>> {
let (local_addr, socket) = new_socket()?;

let path = Path::empty(ByEndpoint::with_cloned(local_addr.isd_asn()));

let notify = Arc::new(Notify::new());
let notify2 = Arc::new(Notify::new());

let (result1, result2) = tokio::join!(
async {
let initial = socket.path();
socket.set_path(Some(path.clone()));
notify.notify_one();

notify2.notified().await;
let last_set = socket.path();

(initial, last_set)
},
async {
notify.notified().await;
let first_set = socket.path();
socket.set_path(None);
notify2.notify_one();

first_set
}
);

assert_eq!(result1, (None, None));
assert_eq!(result2, Some(path));

Ok(())
}
}
Loading

0 comments on commit 902933b

Please sign in to comment.