From 3320b016aacd499b3e28ad9bbe5e856789ffa92a Mon Sep 17 00:00:00 2001 From: Jean-Pierre Smith Date: Thu, 14 Dec 2023 00:27:32 +0100 Subject: [PATCH] fix: correct and test send and receive corner cases --- .../scion-proto/src/address/scion_address.rs | 1 - crates/scion-proto/src/packet/udp.rs | 33 ++ crates/scion-proto/src/path.rs | 10 + crates/scion/src/udp_socket.rs | 379 ++++++++++++++++-- 4 files changed, 398 insertions(+), 25 deletions(-) diff --git a/crates/scion-proto/src/address/scion_address.rs b/crates/scion-proto/src/address/scion_address.rs index 49d8aac..7c1b1e4 100644 --- a/crates/scion-proto/src/address/scion_address.rs +++ b/crates/scion-proto/src/address/scion_address.rs @@ -167,7 +167,6 @@ macro_rules! scion_address { fn from_str(s: &str) -> Result { s.split_once(',') .and_then(|(ia_str, host_str)| { - println!("{}, {}", ia_str, host_str); ia_str.parse().ok().zip(host_str.parse().ok()) }) .map(|(isd_asn, host)| Self {isd_asn, host}) diff --git a/crates/scion-proto/src/packet/udp.rs b/crates/scion-proto/src/packet/udp.rs index 6f87f64..558262f 100644 --- a/crates/scion-proto/src/packet/udp.rs +++ b/crates/scion-proto/src/packet/udp.rs @@ -19,6 +19,39 @@ pub struct ScionPacketUdp { pub datagram: UdpDatagram, } +impl ScionPacketUdp { + /// Returns the source socket address of the UDP packet. + pub fn source(&self) -> Option { + self.headers + .address + .source() + .map(|scion_addr| SocketAddr::new(scion_addr, self.src_port())) + } + + /// Returns the destination socket address of the UDP packet. + pub fn destination(&self) -> Option { + self.headers + .address + .destination() + .map(|scion_addr| SocketAddr::new(scion_addr, self.dst_port())) + } + + /// Returns the UDP packet payload. + pub fn payload(&self) -> &Bytes { + &self.datagram.payload + } + + /// Returns the UDP source port + pub fn src_port(&self) -> u16 { + self.datagram.port.source + } + + /// Returns the UDP destination port + pub fn dst_port(&self) -> u16 { + self.datagram.port.destination + } +} + impl ScionPacketUdp { /// Creates a new SCION UDP packet based on the UDP payload pub fn new( diff --git a/crates/scion-proto/src/path.rs b/crates/scion-proto/src/path.rs index 53b85ed..771637f 100644 --- a/crates/scion-proto/src/path.rs +++ b/crates/scion-proto/src/path.rs @@ -54,6 +54,16 @@ impl Path { } } + /// Returns a path for sending packets within the specified AS. + /// + /// # Panics + /// + /// Panics if the AS is a wildcard AS. + pub fn local(isd_asn: IsdAsn) -> Self { + assert!(!isd_asn.is_wildcard(), "no local path for wildcard AS"); + Self::empty(ByEndpoint::with_cloned(isd_asn)) + } + pub fn empty(isd_asn: ByEndpoint) -> Self { Self { dataplane_path: DataplanePath::EmptyPath, diff --git a/crates/scion/src/udp_socket.rs b/crates/scion/src/udp_socket.rs index ff1e075..c21c6f8 100644 --- a/crates/scion/src/udp_socket.rs +++ b/crates/scion/src/udp_socket.rs @@ -42,8 +42,6 @@ pub enum SendError { PathExpired, #[error("remote address is not set")] NotConnected, - #[error("socket is already connected")] - AlreadyConnected, #[error("path is not set")] NoPath, #[error("no underlay next hop provided by path")] @@ -165,7 +163,12 @@ impl UdpSocket { /// Registers a remote address for this socket. pub fn connect(&self, remote_address: SocketAddr) { - self.inner.set_remote_address(remote_address); + self.inner.set_remote_address(Some(remote_address)); + } + + /// Clears the association, if any, with the remote address. + pub fn disconnect(&self) { + self.inner.set_remote_address(None); } /// Registers or clears a path for this socket. @@ -211,7 +214,11 @@ impl UdpSocket { } /// Error messages returned from the UDP socket. -pub type ReceiveError = std::convert::Infallible; +#[derive(Debug, thiserror::Error, PartialEq, Eq)] +pub enum ReceiveError { + #[error("attempted to receive with a zero-length buffer")] + ZeroLengthBuffer, +} macro_rules! log_err { ($message:expr) => { @@ -248,13 +255,8 @@ impl UdpSocketInner { ) -> Result<(), SendError> { let state = self.state.read().unwrap().clone(); let path = path.or(state.path.as_ref()).ok_or(SendError::NoPath)?; - let Some(destination) = destination.xor(state.remote_address) else { - // Either both are None or both are Some - return if state.remote_address.is_none() { - Err(SendError::NotConnected) - } else { - Err(SendError::AlreadyConnected) - }; + let Some(destination) = destination.or(state.remote_address) else { + return Err(SendError::NotConnected); }; if let Some(metadata) = &path.metadata { @@ -312,6 +314,14 @@ impl UdpSocketInner { } async fn recv_loop(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr, Path), ReceiveError> { + if buf.is_empty() { + return Err(ReceiveError::ZeroLengthBuffer); + } + + // Keep a copy of the connection's remote_addr locally, so that the user connecting to a + // different destination does not affect what this call should return. + let remote_addr = self.state.read().unwrap().remote_address; + loop { let receive_result = { let stream = &mut *self.stream.lock().await; @@ -320,7 +330,9 @@ impl UdpSocketInner { match receive_result { Ok(packet) => { - if let Some((packet_len, sender, path)) = self.parse_incoming(packet, buf) { + if let Some((packet_len, sender, path)) = + self.parse_incoming_for(packet, buf, remote_addr) + { return Ok((packet_len, sender, path)); } else { continue; @@ -331,10 +343,11 @@ impl UdpSocketInner { } } - fn parse_incoming( + fn parse_incoming_for( &self, mut packet: Packet, buf: &mut [u8], + remote_addr: Option, ) -> Option<(usize, SocketAddr, Path)> { // TODO(jsmith): Need a representation of the packets for logging purposes let mut scion_packet = ScionPacketRaw::decode(&mut packet.content) @@ -357,6 +370,13 @@ impl UdpSocketInner { return None; }; + if let Some(remote_addr) = remote_addr { + if remote_addr != source { + tracing::debug!(%source, %remote_addr, "dropping packet not from connected remote"); + return None; + } + } + let payload_len = udp_datagram.payload.len(); let copy_length = cmp::min(payload_len, buf.len()); buf[..copy_length].copy_from_slice(&udp_datagram.payload[..copy_length]); @@ -380,8 +400,8 @@ impl UdpSocketInner { self.state.read().unwrap().remote_address } - pub fn set_remote_address(&self, remote_address: SocketAddr) { - Arc::make_mut(&mut *self.state.write().unwrap()).remote_address = Some(remote_address); + pub fn set_remote_address(&self, remote_address: Option) { + Arc::make_mut(&mut *self.state.write().unwrap()).remote_address = remote_address; } pub fn path(&self) -> Option { @@ -402,23 +422,334 @@ struct State { #[cfg(test)] mod tests { + use scion_proto::path::DataplanePath; use tokio::{net::UnixStream, sync::Notify}; use super::*; - fn new_socket() -> Result<(SocketAddr, UdpSocket), Box> { - let (inner, _) = UnixStream::pair()?; - let stream = DispatcherStream::new(inner); - let local_addr: SocketAddr = "[1-ff00:0:111,127.0.0.17]:12300".parse()?; + type TestResult = Result>; - Ok((local_addr, UdpSocket::new(stream, local_addr))) + mod utils { + + use super::*; + + pub fn socket_from(source: SocketAddr) -> TestResult<(UdpSocket, DispatcherStream)> { + let (inner, inner_remote) = UnixStream::pair()?; + Ok(( + UdpSocket::new(DispatcherStream::new(inner), source), + DispatcherStream::new(inner_remote), + )) + } + + pub async fn read_udp_packet( + dispatcher: &mut DispatcherStream, + ) -> TestResult { + let mut packet = dispatcher.receive_packet().await?; + let packet_raw = ScionPacketRaw::decode(&mut packet.content)?; + let packet_udp = ScionPacketUdp::try_from(packet_raw)?; + + Ok(packet_udp) + } + + pub async fn local_send_raw( + dispatcher: &mut DispatcherStream, + endpoints: ByEndpoint, + message: &[u8], + ) -> TestResult<()> { + debug_assert!( + endpoints.map(SocketAddr::isd_asn).are_equal(), + "expected intra-AS addresses" + ); + + let relay = endpoints + .destination + .local_address() + .map(|mut socket_addr| { + socket_addr.set_port(dispatcher::UNDERLAY_PORT); + socket_addr + }) + .expect("IPv4/6 local address"); + + let packet = ScionPacketUdp::new( + &endpoints, + &Path::local(endpoints.source.isd_asn()), + Bytes::copy_from_slice(message), + )?; + + dispatcher.send_packet_via(Some(relay), packet).await?; + + Ok(()) + } } - #[tokio::test] - async fn set_path() -> Result<(), Box> { - let (local_addr, socket) = new_socket()?; + macro_rules! async_test_case { + ($name:ident: $func:ident($arg1:expr$(, $arg:expr)*)) => { + #[tokio::test] + async fn $name() -> TestResult { + $func($arg1 $(, $arg)*).await + } + }; + } + + const MESSAGE: &[u8] = b"Hello World! Hello World! Hello World!"; + + mod send_to_via { + use super::*; + + async fn test_send_to_via( + local_addr: &str, + remote_addr: &str, + connect_addr: Option<&str>, + ) -> TestResult { + let local_addr = local_addr.parse()?; + let remote_addr = remote_addr.parse()?; + + let (socket, mut dispatcher) = utils::socket_from(local_addr)?; + let path = Path::local(local_addr.isd_asn()); + + if let Some(connect_addr) = connect_addr { + socket.connect(connect_addr.parse()?); + } + + socket + .send_to_via(Bytes::from_static(MESSAGE), remote_addr, &path) + .await?; + + let udp_packet = utils::read_udp_packet(&mut dispatcher).await?; + assert_eq!(udp_packet.source(), Some(local_addr)); + assert_eq!(udp_packet.destination(), Some(remote_addr)); + assert_eq!(udp_packet.payload().as_ref(), MESSAGE); + + Ok(()) + } + + async_test_case! { + unconnected: test_send_to_via( + "[1-ff00:0:111,10.0.0.1]:22472", "[1-ff00:0:111,10.0.0.2]:443", None + ) + } + + async_test_case! { + connected: test_send_to_via( + "[1-ff00:0:111,10.0.0.1]:22472", + "[1-ff00:0:111,10.32.32.32]:443", + Some("[1-ff00:0:111,10.64.64.64]:1024") + ) + } + + const REMOTE_ADDR: &str = "[1-ff00:0:111,10.32.32.32]:443"; + async_test_case! { + connected_same_destination: + test_send_to_via("[1-ff00:0:111,10.0.0.1]:22472", REMOTE_ADDR, Some(REMOTE_ADDR)) + } + } + + mod send_via { + use super::*; + + #[tokio::test] + async fn errs_when_unconnected() -> TestResult { + let local_addr = "[1-ff00:0:112,10.0.255.20]:2121".parse()?; + let (socket, _) = utils::socket_from(local_addr)?; + let path = Path::local(local_addr.isd_asn()); + + let err = socket + .send_via(Bytes::from_static(MESSAGE), &path) + .await + .expect_err("should fail on unconnected socket"); + + assert!( + matches!(err, SendError::NotConnected), + "expected {:?}, got {:?}", + SendError::NotConnected, + err + ); + + Ok(()) + } + + #[tokio::test] + async fn connected() -> TestResult { + let local_addr = "[1-ff00:0:112,10.0.255.20]:2020".parse()?; + let remote_addr = "[1-ff00:0:112,192.168.0.99]:9981".parse()?; + let (socket, mut dispatcher) = utils::socket_from(local_addr)?; + let path = Path::local(local_addr.isd_asn()); + + socket.connect(remote_addr); + socket.send_via(Bytes::from(MESSAGE), &path).await?; + + let udp_packet = utils::read_udp_packet(&mut dispatcher).await?; + + assert_eq!(udp_packet.source(), Some(local_addr)); + assert_eq!(udp_packet.destination(), Some(remote_addr)); + assert_eq!(udp_packet.payload().as_ref(), MESSAGE); + + Ok(()) + } + } + + async fn test_unconnected_recv( + local_addr: &str, + remote_addr: &str, + use_from: bool, + ) -> TestResult { + let endpoints = ByEndpoint:: { + source: remote_addr.parse()?, + destination: local_addr.parse()?, + }; + assert_eq!(endpoints.source.isd_asn(), endpoints.destination.isd_asn()); + + let mut buffer = [0u8; 64]; + let (socket, mut dispatcher) = utils::socket_from(endpoints.source)?; + utils::local_send_raw(&mut dispatcher, endpoints, MESSAGE).await?; + + let (length, incoming_remote_addr, incoming_path) = if use_from { + socket.recv_from_with_path(&mut buffer).await? + } else { + let res = socket.recv_with_path(&mut buffer).await?; + (res.0, endpoints.source, res.1) + }; + + assert_eq!(&buffer[..length], MESSAGE); + assert_eq!(incoming_remote_addr, endpoints.source); + assert_eq!(incoming_path.dataplane_path, DataplanePath::EmptyPath); + assert_eq!(incoming_path.isd_asn, endpoints.map(SocketAddr::isd_asn)); + assert_eq!(incoming_path.metadata, None); + assert_ne!(incoming_path.underlay_next_hop, None); + + Ok(()) + } + + async fn test_connected_recv( + local_addr: &str, + remote_addr: &str, + other_remote_addr: &str, + use_from: bool, + ) -> TestResult { + let endpoints = ByEndpoint:: { + source: remote_addr.parse()?, + destination: local_addr.parse()?, + }; + assert_eq!(endpoints.source.isd_asn(), endpoints.destination.isd_asn()); + + let messages = [ + b"Message 1!".as_slice(), + b"Message 2! Message 2!", + b"Message 3! Message 3! Message 3!", + ]; + let other_endpoints = ByEndpoint { + source: other_remote_addr.parse()?, + ..endpoints + }; + assert_eq!(other_endpoints.source.isd_asn(), endpoints.source.isd_asn()); + + let mut buffer = [0u8; 64]; + let (socket, mut dispatcher) = utils::socket_from(endpoints.source)?; + + // Write packets to be received + for (send_endpoints, message) in [ + (other_endpoints, messages[0]), + (endpoints, messages[1]), + (other_endpoints, messages[2]), + ] { + utils::local_send_raw(&mut dispatcher, send_endpoints, message).await?; + } + + // Connect to the remote source + socket.connect(endpoints.source); + + let length = if use_from { + let (length, remote_addr, _) = socket.recv_from_with_path(&mut buffer).await?; + assert_eq!(remote_addr, endpoints.source); + length + } else { + socket.recv_with_path(&mut buffer).await?.0 + }; + + // The first packet received is the second packet written. + assert_eq!(&buffer[..length], messages[1]); + + // Disconnect the association to receive packets with other addresses + socket.disconnect(); + + let length = if use_from { + let (length, remote_addr, _) = socket.recv_from_with_path(&mut buffer).await?; + assert_eq!(remote_addr, other_endpoints.source); + length + } else { + socket.recv_with_path(&mut buffer).await?.0 + }; - let path = Path::empty(ByEndpoint::with_cloned(local_addr.isd_asn())); + // The second packet packet received is the third packet written. + assert_eq!(&buffer[..length], messages[2]); + + Ok(()) + } + + mod recv_from_with_path { + use super::*; + + pub const USE_FROM: bool = true; + + async_test_case! { + connected: + test_connected_recv( + "[1-f:0:3,4.4.0.1]:80", "[1-f:0:3,11.10.13.7]:443", "[1-f:0:3,10.20.30.40]:981", USE_FROM + ) + } + async_test_case! { + unconnected: test_unconnected_recv("[1-f:0:3,4.4.0.1]:80", "[1-f:0:3,11.10.13.7]:443", USE_FROM) + } + + #[tokio::test] + async fn zero_length_buffer() -> TestResult { + let endpoints = ByEndpoint:: { + source: "[1-f:0:3,4.4.0.1]:80".parse()?, + destination: "[1-f:0:3,11.10.13.7]:443".parse()?, + }; + assert_eq!(endpoints.source.isd_asn(), endpoints.destination.isd_asn()); + + let mut buffer = [0u8; 64]; + let (socket, mut dispatcher) = utils::socket_from(endpoints.source)?; + utils::local_send_raw(&mut dispatcher, endpoints, MESSAGE).await?; + + let err = socket + .recv_from_with_path(&mut []) + .await + .expect_err("should fail due to zero-length buffer"); + assert_eq!(err, ReceiveError::ZeroLengthBuffer); + + // The data should still be available to read + let (length, incoming_remote_addr, _) = socket.recv_from_with_path(&mut buffer).await?; + + assert_eq!(&buffer[..length], MESSAGE); + assert_eq!(incoming_remote_addr, endpoints.source); + + Ok(()) + } + } + + mod recv_with_path { + use super::*; + + pub const USE_FROM: bool = true; + + async_test_case! { + connected: + test_connected_recv( + "[1-f:0:3,4.4.0.1]:80", "[1-f:0:3,11.10.13.7]:443", "[1-f:0:3,10.20.30.40]:981", !USE_FROM + ) + } + async_test_case! { + unconnected: test_unconnected_recv("[1-f:0:3,3.3.3.3]:80", "[1-f:0:3,9.9.9.81]:443", !USE_FROM) + } + } + + #[tokio::test] + async fn set_path() -> TestResult { + let local_addr: SocketAddr = "[1-f:0:1,9.8.7.6]:80".parse()?; + let (socket, _) = utils::socket_from(local_addr)?; + let path = Path::local(local_addr.isd_asn()); let notify = Arc::new(Notify::new()); let notify2 = Arc::new(Notify::new());