diff --git a/crates/scion/src/reliable.rs b/crates/scion/src/reliable.rs index 6709ad8..5bbecab 100644 --- a/crates/scion/src/reliable.rs +++ b/crates/scion/src/reliable.rs @@ -1,14 +1,20 @@ mod common_header; -pub use common_header::DecodeError; +use std::net::SocketAddr; -mod error; -pub use error::ReliableRelayError; +use bytes::Bytes; +pub use common_header::DecodeError; mod relay_protocol; -pub use relay_protocol::ReliableRelayProtocol; +pub use relay_protocol::{ReceiveError, ReliableRelayProtocol, SendError}; mod parser; mod registration; mod wire_utils; const ADDRESS_TYPE_OCTETS: usize = 1; + +#[derive(Debug)] +pub struct Packet { + pub last_hop: Option, + pub content: Vec, +} diff --git a/crates/scion/src/reliable/common_header.rs b/crates/scion/src/reliable/common_header.rs index 69a5fb5..dcd368a 100644 --- a/crates/scion/src/reliable/common_header.rs +++ b/crates/scion/src/reliable/common_header.rs @@ -10,7 +10,7 @@ use super::{ use crate::address::{HostAddress, HostType}; /// Errors occurring during decoding of packets received over the reliable-relay protocol. -#[derive(Error, Debug, Eq, PartialEq)] +#[derive(Error, Debug, Eq, PartialEq, Clone, Copy)] pub enum DecodeError { /// The decoded packet started with an incorrect token. This indicates a /// synchronisation issue with the relay. @@ -26,20 +26,20 @@ pub enum DecodeError { } /// Partial or fully decoded commonHeader +#[derive(Debug)] pub(super) enum DecodedHeader { Partial(PartialHeader), Full(CommonHeader), } impl DecodedHeader { - #[allow(dead_code)] pub fn is_fully_decoded(&self) -> bool { matches!(self, DecodedHeader::Full(..)) } } /// A partially decoded common header -#[derive(Copy, Clone)] +#[derive(Copy, Clone, Debug)] pub(super) struct PartialHeader { pub host_type: HostType, pub payload_length: u32, @@ -126,7 +126,6 @@ impl CommonHeader { /// The minimum length of a common header. pub const MIN_LENGTH: usize = Self::COOKIE_LENGTH + ADDRESS_TYPE_OCTETS + Self::PAYLOAD_SIZE_LENGTH; - #[allow(dead_code)] /// The maximum length of the common header. pub const MAX_LENGTH: usize = Self::MIN_LENGTH + IPV6_OCTETS + LAYER4_PORT_OCTETS; @@ -134,13 +133,7 @@ impl CommonHeader { const COOKIE_LENGTH: usize = 8; const PAYLOAD_SIZE_LENGTH: usize = 4; - #[allow(dead_code)] - pub fn new() -> Self { - Self::default() - } - /// The size of the payload as a usize. - #[allow(dead_code)] #[inline] pub fn payload_size(&self) -> usize { self.payload_length @@ -210,6 +203,7 @@ impl CommonHeader { /// Panics if there is insufficient data in the buffer to decode the entire header. /// To avoid the panic, ensure there is [`Self::MAX_LENGTH`] bytes available, or /// use [`Self::partial_decode()`] instead. + #[cfg(test)] pub fn decode(buffer: &mut impl Buf) -> Result { PartialHeader::decode(buffer).map(|header| header.finish_decoding(buffer)) } diff --git a/crates/scion/src/reliable/error.rs b/crates/scion/src/reliable/error.rs index cf18690..62f7e42 100644 --- a/crates/scion/src/reliable/error.rs +++ b/crates/scion/src/reliable/error.rs @@ -1,11 +1 @@ use thiserror::Error; - -#[derive(Error, Debug, Eq, PartialEq)] -pub enum ReliableRelayError { - #[error("provided destination address must be specified, not 0.0.0.0 or ::0")] - DestinationUnspecified, - #[error("provided destination port mmust be specified")] - DestinationPortUnspecified, - #[error("port mismatch, requested port {requested}, received port {assigned}")] - PortMismatch { requested: u16, assigned: u16 }, -} diff --git a/crates/scion/src/reliable/parser.rs b/crates/scion/src/reliable/parser.rs index c2aefc6..9daafe2 100644 --- a/crates/scion/src/reliable/parser.rs +++ b/crates/scion/src/reliable/parser.rs @@ -1,116 +1,208 @@ -use std::{collections::VecDeque, ops::Deref}; - use bytes::{Buf, Bytes}; -use super::common_header::{CommonHeader, DecodeError, DecodedHeader}; +use super::{ + common_header::{CommonHeader, DecodeError, DecodedHeader}, + wire_utils::BytesQueue, + Packet, +}; -pub(super) struct StreamParser { - // INV: byte objects are always non-empty - byte_queue: VecDeque, - bytes_remaining: usize, - next_header: Option, +#[derive(Debug, thiserror::Error)] +pub(super) enum ParseError { + #[error("parser is blocked awaiting data")] + Blocked, + #[error("parsing the next packet resulted in an error")] + Decode(#[from] DecodeError), + #[error("parsing the packet after the next resulted in an error")] + PacketWithError { packet: Packet, error: DecodeError }, } -impl StreamParser { - pub fn new() -> Self { - Self { - byte_queue: VecDeque::new(), - bytes_remaining: 0, - next_header: None, - } - } +#[derive(Debug)] +enum State { + Good(StreamParserInner), + Bad(DecodeError), +} - pub fn append_data(&mut self, data: Bytes) { - if !data.is_empty() { - self.bytes_remaining += data.len(); - self.byte_queue.push_back(data); +impl State { + #[must_use] + fn to_bad(&self, error: DecodeError) -> Self { + match self { + Self::Good(_) => Self::Bad(error), + Self::Bad(_) => panic!("should not be called in the bad state"), } } +} - pub fn next_packet(&mut self) -> Result)>, DecodeError> { - match &self.next_header { - None if self.remaining() >= CommonHeader::MIN_LENGTH => { - self.next_header = Some(CommonHeader::partial_decode(self)?); - - // Recursively try to get the payload if we parsed a full common header - if self.next_header.as_ref().unwrap().is_fully_decoded() { - self.next_packet() - } else { - Ok(None) - } - } - Some(DecodedHeader::Partial(header)) if self.remaining() >= header.required_bytes() => { - self.next_header = Some(DecodedHeader::Full(header.finish_decoding(self))); - self.next_packet() - } - Some(DecodedHeader::Full(header)) if self.remaining() >= header.payload_size() => { - let header = *header; - let payload = self.get_payload(header.payload_size()); +#[derive(Default, Debug)] +struct StreamParserInner { + byte_queue: BytesQueue, + // INV: Always contains the most decodable version of the next header. + next_header: Option, +} - self.next_header = None; +impl StreamParserInner { + fn remaining(&self) -> usize { + self.byte_queue.remaining() + } - Ok(Some((header, payload))) + /// Decode and store the next partial or common header available from the data. + /// + /// Does nothing is a full common header is already decoded, or if there is insufficient data. + fn decode_next_header(&mut self) -> Result<(), DecodeError> { + match &mut self.next_header { + None if self.byte_queue.remaining() >= CommonHeader::MIN_LENGTH => { + self.next_header = Some(CommonHeader::partial_decode(&mut self.byte_queue)?); } - _ => Ok(None), + Some(DecodedHeader::Partial(header)) + if self.byte_queue.remaining() >= header.required_bytes() => + { + self.next_header = Some(DecodedHeader::Full( + header.finish_decoding(&mut self.byte_queue), + )); + } + _ => (), } + + Ok(()) } + /// Get the payload for the currently decoded common header. + /// + /// Requires a common header to be decoded and there to be sufficient data remaining. fn get_payload(&mut self, payload_size: usize) -> Vec { - let mut result = vec![]; + assert!(self.remaining() >= payload_size); - let mut payload_bytes_needed = payload_size; + let mut payload = vec![]; + let mut bytes_needed = payload_size; - while payload_bytes_needed > 0 { + while bytes_needed > 0 { let mut data = self.byte_queue.pop_front().expect("there must be data"); - if data.len() > payload_bytes_needed { - self.byte_queue - .push_front(data.split_off(payload_bytes_needed)); + if data.len() > bytes_needed { + self.byte_queue.push_front(data.split_off(bytes_needed)); } - assert!(data.len() <= payload_bytes_needed); - - payload_bytes_needed -= data.len(); - result.push(data); + assert!(data.len() <= bytes_needed); + bytes_needed -= data.len(); + payload.push(data); } - result + payload + } + + /// Returns true if a packet is available to be retrieved. + pub fn is_packet_available(&self) -> bool { + if let Some(DecodedHeader::Full(header)) = &self.next_header { + self.remaining() >= header.payload_size() + } else { + false + } } } -impl Buf for StreamParser { - fn remaining(&self) -> usize { - self.bytes_remaining +/// A parser to decode [`CommonHeader`]s and payloads from a sequence of [`Bytes`] +/// with arbitrary boundaries. +#[derive(Debug)] +pub(super) struct StreamParser { + state: State, +} + +impl StreamParser { + pub fn new() -> Self { + Self::default() } - fn chunk(&self) -> &[u8] { - self.byte_queue.front().map_or(&[], |data| data.deref()) + #[cfg(test)] + pub fn remaining(&self) -> usize { + match &self.state { + State::Good(inner) => inner.remaining(), + State::Bad(_) => 0, + } } - fn advance(&mut self, cnt: usize) { - if cnt == 0 { - return; + /// Returns the StreamParserInner if in a good state, otherwise returns the error + /// that moved the parser to a bad state. + fn inner_mut(&mut self) -> Result<&mut StreamParserInner, DecodeError> { + match &mut self.state { + State::Good(inner) => Ok(inner), + State::Bad(err) => Err(*err), } - if cnt > self.bytes_remaining { - panic!( - "cnt > self.remaining() ({} > {})", - cnt, self.bytes_remaining - ); + } + + /// Add the data to the queue, and attempts to decode the next packet. + /// + /// Returns true if the addition of the data resulted in a packet being available. + /// + /// # Errors + /// + /// Returns the error that caused the packet to fail to be decoded or the last + /// error encountered if the stream is in a bad state. + pub fn append_data(&mut self, data: Bytes) -> Result { + if data.is_empty() { + return Ok(false); } - let mut advance_by = cnt; - while advance_by > 0 { - let mut data = self.byte_queue.pop_front().expect("there must be data"); + let inner = self.inner_mut()?; + let packet_already_available = inner.is_packet_available(); + + inner.byte_queue.push_back(data); - if data.len() > advance_by { - self.byte_queue.push_front(data.split_off(advance_by)); + if !packet_already_available { + match inner.decode_next_header() { + Ok(()) => Ok(self.is_packet_available()), + Err(err) => { + self.state = self.state.to_bad(err); + Err(err) + } } - assert!(data.len() <= advance_by); + } else { + Ok(false) + } + } + + /// Returns the next available packet if any. + /// + /// # Errors + /// + /// Returns an error if blocked or if decoding for the requested packet failed. + /// If a packet is successfully parsed, but the parsing of the packet after that + /// fails, then both the packet and an error are returned. + pub fn next_packet(&mut self) -> Result { + let inner = self.inner_mut()?; + + if let Some(DecodedHeader::Full(header)) = inner.next_header { + inner.next_header.take(); + + let packet = Packet { + last_hop: header.destination, + content: inner.get_payload(header.payload_size()), + }; - advance_by -= data.len(); + match inner.decode_next_header() { + Ok(()) => Ok(packet), + Err(error) => { + self.state = self.state.to_bad(error); + Err(ParseError::PacketWithError { packet, error }) + } + } + } else { + Err(ParseError::Blocked) } + } - self.bytes_remaining -= cnt; + /// Returns true if a packet is available to be retrieved. + pub fn is_packet_available(&self) -> bool { + match &self.state { + State::Good(inner) => inner.is_packet_available(), + State::Bad(_) => false, + } + } +} + +impl Default for StreamParser { + fn default() -> Self { + Self { + state: State::Good(StreamParserInner::default()), + } } } @@ -118,16 +210,149 @@ impl Buf for StreamParser { mod tests { use super::*; - #[test] - fn has_available_multiple() { - let mut parser = StreamParser::new(); + const PACKET: [u8; 35] = [ + 0xde, 0, 0xad, 1, 0xbe, 2, 0xef, 3, 2, 0, 0, 0, 4, 0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 1, 0, 80, b'R', b'U', b'S', b'T', + ]; + const BAD_PACKET: [u8; 35] = [ + 0xbe, 2, 0xef, 3, 0xde, 0, 0xad, 1, 2, 0, 0, 0, 4, 0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 1, 0, 80, b'R', b'U', b'S', b'T', + ]; + + mod append_data { + use super::*; + + #[test] + fn from_empty() { + let mut parser = StreamParser::new(); - parser.append_data(Bytes::from_static(&[0, 1, 2])); - parser.append_data(Bytes::from_static(&[4, 5, 6])); + const MIN_LENGTH: usize = CommonHeader::MIN_LENGTH; + const AVAILABLE: bool = true; + + let mut range_lower: usize = 0; + for (range_upper, expected_remaining, expected_available) in [ + (MIN_LENGTH - 1, MIN_LENGTH - 1, !AVAILABLE), + (MIN_LENGTH + 2, 2, !AVAILABLE), + (PACKET.len() - 1, 3, !AVAILABLE), + (PACKET.len(), 4, AVAILABLE), + ] { + let is_available = parser + .append_data(PACKET[range_lower..range_upper].into()) + .unwrap(); + + let context = format!("at append_data(PACKET[{}..{}])", range_lower, range_upper); + + assert_eq!(is_available, expected_available, "{}", context); + assert_eq!(parser.is_packet_available(), is_available, "{}", context); + assert_eq!(parser.remaining(), expected_remaining, "{}", context); + + range_lower = range_upper; + } + } + + #[test] + fn packet_already_available() { + let mut parser = StreamParser::new(); + + let newly_available = parser.append_data(PACKET.as_slice().into()).unwrap(); + assert!(newly_available); + assert_eq!(newly_available, parser.is_packet_available()); + + let newly_available = parser.append_data(PACKET.as_slice().into()).unwrap(); + assert!(!newly_available); + assert_ne!(newly_available, parser.is_packet_available()); + } + + #[test] + fn bad_packet() { + let mut parser = StreamParser::new(); + parser + .append_data(BAD_PACKET.as_slice().into()) + .expect_err("expected invalid packet"); + + assert!(!parser.is_packet_available()); + } + + #[test] + fn bad_packet_when_already_available() { + let mut parser = StreamParser::new(); + + parser.append_data(PACKET.as_slice().into()).unwrap(); + assert!(parser.is_packet_available()); + + parser + .append_data(BAD_PACKET.as_slice().into()) + .expect("should not yet parse the bad packet as another is pending"); + } - let mut buffer = [0u8; 6]; - parser.copy_to_slice(&mut buffer); + #[test] + fn append_good_packet_to_bad_state() { + let mut parser = StreamParser::new(); - assert_eq!(buffer, [0, 1, 2, 4, 5, 6]); + parser + .append_data(BAD_PACKET.as_slice().into()) + .unwrap_err(); + assert!(!parser.is_packet_available()); + + parser + .append_data(PACKET.as_slice().into()) + .expect_err("should fail due to appending error to bad state"); + assert!(!parser.is_packet_available()); + } + } + + mod next_packet { + use super::*; + use crate::test_utils::parse; + + #[test] + fn available() { + let mut parser = StreamParser::new(); + + assert!(!parser.is_packet_available()); + assert!(matches!(parser.next_packet(), Err(ParseError::Blocked))); + + parser.append_data(PACKET.as_slice().into()).unwrap(); + + assert!(parser.is_packet_available()); + + let packet = parser.next_packet().unwrap(); + + assert_eq!(packet.last_hop, Some(parse!("[2001:db8::1]:80"))); + assert_eq!(packet.content, vec![Bytes::from_static(b"RUST")]); + assert!(!parser.is_packet_available()); + } + + #[test] + fn multiple_available() { + let mut parser = StreamParser::new(); + + assert!(!parser.is_packet_available()); + assert!(matches!(parser.next_packet(), Err(ParseError::Blocked))); + + const NUMBER_PACKETS: usize = 3; + + for _ in 0..NUMBER_PACKETS { + parser.append_data(PACKET.as_slice().into()).unwrap(); + } + + for _ in 0..NUMBER_PACKETS { + assert!(parser.is_packet_available()); + let _ = parser.next_packet().unwrap(); + } + assert!(!parser.is_packet_available()); + } + + #[test] + fn packet_with_error() { + let mut parser = StreamParser::new(); + parser.append_data(PACKET.as_slice().into()).unwrap(); + parser.append_data(BAD_PACKET.as_slice().into()).unwrap(); + + assert!(matches!( + parser.next_packet().expect_err("should return an error"), + ParseError::PacketWithError { .. } + )); + } } } diff --git a/crates/scion/src/reliable/registration.rs b/crates/scion/src/reliable/registration.rs index 30ddf95..e3d1306 100644 --- a/crates/scion/src/reliable/registration.rs +++ b/crates/scion/src/reliable/registration.rs @@ -72,7 +72,6 @@ impl RegistrationRequest { self } - #[allow(dead_code)] pub fn encoded_length(&self) -> usize { CommonHeader::MIN_LENGTH + self.encoded_request_length() } @@ -82,7 +81,6 @@ impl RegistrationRequest { /// # Panics /// /// Panics if there is not enough space in the buffer to encode the request. - #[allow(dead_code)] pub fn encode_to(&self, buffer: &mut impl BufMut) { self.encode_common_header(buffer); self.encode_request(buffer); @@ -111,7 +109,6 @@ impl RegistrationRequest { assert_eq!(written, self.encoded_request_length()); } - #[allow(dead_code)] #[inline] fn encode_common_header(&self, buffer: &mut impl BufMut) { CommonHeader { @@ -153,19 +150,16 @@ impl RegistrationRequest { /// this always contains the assigned port number. pub(super) struct RegistrationResponse { /// The port assigned by the dispatcher. - #[allow(dead_code)] pub assigned_port: u16, } impl RegistrationResponse { /// The length of the encoded registration response. - #[allow(dead_code)] pub const ENCODED_LENGTH: usize = LAYER4_PORT_OCTETS; /// Decode a registration response from the provided buffer. /// /// Returns None if the buffer contains less than 2 bytes. - #[allow(dead_code)] pub fn decode(buffer: &mut impl Buf) -> Option { if buffer.remaining() >= Self::ENCODED_LENGTH { Some(Self { diff --git a/crates/scion/src/reliable/relay_protocol.rs b/crates/scion/src/reliable/relay_protocol.rs index 6eb212d..285b44d 100644 --- a/crates/scion/src/reliable/relay_protocol.rs +++ b/crates/scion/src/reliable/relay_protocol.rs @@ -1,47 +1,80 @@ -use std::{collections::VecDeque, net::SocketAddr}; +use std::{cmp::min, collections::VecDeque, net::SocketAddr}; -use bytes::{BufMut, Bytes, BytesMut}; +use bytes::{Buf, Bytes, BytesMut}; use super::{ common_header::{CommonHeader, DecodeError}, - error::ReliableRelayError, - parser::StreamParser, + parser::{ParseError, StreamParser}, registration::{RegistrationRequest, RegistrationResponse}, + wire_utils::BytesQueue, + Packet, }; use crate::address::{IsdAsn, ServiceAddress}; enum State { Initial, RegistrationRequested { - request: RegistrationRequest, + bytes_received: BytesQueue, is_sent: bool, + request: RegistrationRequest, }, Registered { - transmit_queue: VecDeque<(CommonHeader, Bytes)>, + /// The port on which the instance is registered port: u16, + /// Packets waiting to be sent to the dispatcher + transmit_queue: VecDeque<(CommonHeader, Bytes)>, + parser: StreamParser, }, Terminated, } +#[derive(Debug, PartialEq, Eq)] +pub enum Event { + Registered, + Terminated { reason: ReliableRelayError }, + PacketsAvailable, +} + +/// The SCION client-to-dispatcher relay protocol. +/// +/// A reliable relay protocol to be used with the SCION dispatcher for sending pub struct ReliableRelayProtocol { state: State, - parser: StreamParser, + events: VecDeque, } impl ReliableRelayProtocol { - const MAX_TRANSMIT_BUFFER_SIZE: usize = 1_048_576; // 1 MiB + /// Maximum number of packets to send in a single [`Self::poll_transmit()`] call. + pub const MAX_TRANSMIT_BURST: usize = 100; + /// Create a new protocol instance. + /// + /// The instance must first be registered to a port before it can be used to + /// send or receive packets. pub fn new() -> Self { Self { state: State::Initial, - parser: StreamParser::new(), + events: VecDeque::new(), } } + /// Register to receive SCION packets destined for the given address and port. + /// + /// Should only be called once on a protocol instance, that has not been previously + /// registered. + /// + /// # Panics + /// + /// Panics if repeated registrations are attempted. + // TODO(jsmith): Clarify if we want to instead return an error instead of panic.. pub fn register(&mut self, isd_asn: IsdAsn, public_address: SocketAddr) { self.register_with_dispatcher(RegistrationRequest::new(isd_asn, public_address)) } + /// Register to receive SCION packets destined for the given address and port, or + /// for a specific SCION service. + /// + /// See [`Self::register()`] for more details. pub fn register_service( &mut self, isd_asn: IsdAsn, @@ -60,19 +93,27 @@ impl ReliableRelayProtocol { self.state = State::RegistrationRequested { request, is_sent: false, + bytes_received: BytesQueue::default(), } } + // None of these states can be reached without a previous registration call. + // Therefore, being in these states implies a repeated call to register. State::RegistrationRequested { .. } => panic!("registration already requested"), State::Registered { .. } => panic!("already registered with the dispatcher"), State::Terminated => panic!("protocol has already terminated"), } } + /// Poll for data pending to be sent to the dispatcher. + /// + /// Returns at most [`Self::MAX_TRANSMIT_BURST`] packets in a vector with length of at most twice + /// that value. Therefore, repeated calls may be necessary to fully drain the pending packets. pub fn poll_transmit(&mut self) -> Option> { match &mut self.state { State::RegistrationRequested { request, is_sent: is_sent @ false, + .. } => { let mut buffer = BytesMut::with_capacity(request.encoded_length()); request.encode_to(&mut buffer); @@ -81,41 +122,77 @@ impl ReliableRelayProtocol { Some(vec![buffer.freeze()]) } - State::Registered { transmit_queue, .. } => { - if transmit_queue.is_empty() { - None - } else { - let buffer_length = std::cmp::min( - CommonHeader::MAX_LENGTH.saturating_mul(transmit_queue.len()), - Self::MAX_TRANSMIT_BUFFER_SIZE, - ); - let mut buffer = BytesMut::with_capacity(buffer_length); - let mut output_bytes = Vec::new(); - - while let Some((header, bytes)) = transmit_queue.pop_front() { - let header_buffer = buffer.split_to(header.encoded_length()); - - output_bytes.push(header_buffer.freeze()); - output_bytes.push(bytes); - - if buffer.remaining_mut() < CommonHeader::MAX_LENGTH { - break; - } - } + State::Registered { transmit_queue, .. } if !transmit_queue.is_empty() => { + // The value 2 * MAX_TRANSMIT_BURST must be at most isize, which is the limit for + // number of Vec elements. + assert!(Self::MAX_TRANSMIT_BURST <= (isize::MAX >> 1) as usize); + + let transmit_burst = min(transmit_queue.len(), Self::MAX_TRANSMIT_BURST); + let buffer_length = transmit_queue + .iter() + .take(transmit_burst) + .fold(0, |sum, (header, _)| sum + header.encoded_length()); - Some(output_bytes) + let mut buffer = BytesMut::with_capacity(buffer_length); + let mut to_transmit = Vec::new(); + + for (header, bytes) in transmit_queue.drain(..transmit_burst) { + let header_buffer = buffer.split_to(header.encoded_length()); + + to_transmit.push(header_buffer.freeze()); + to_transmit.push(bytes); } + + Some(to_transmit) } - State::Initial | State::Terminated | State::RegistrationRequested { .. } => None, + _ => None, } } - pub fn send( - &mut self, - scion_packet_data: Bytes, - destination: SocketAddr, - ) -> Result<(), SendError> { + /// Returns application-facing events. + /// + /// The instance should be polled after one or more calls made to [`Self::handle_incoming()`]. + pub fn poll(&mut self) -> Option { + self.events.pop_front() + } + + /// Returns any packets that have been received from the dispatcher, + /// along with the last hop traversed. + /// + /// # Errors + /// + /// Returns an error if the protocol has already terminated, is not yet registered, + /// or is blocked waiting on more data. + pub fn receive(&mut self) -> Result { + match &mut self.state { + State::Initial | State::RegistrationRequested { .. } => { + Err(ReceiveError::NotRegistered) + } + State::Terminated => Err(ReceiveError::ProtocolTerminated), + State::Registered { parser, .. } => match parser.next_packet() { + Ok(packet) => Ok(packet), + Err(ParseError::Blocked) => Err(ReceiveError::Blocked), + Err(ParseError::PacketWithError { packet, error }) => { + self.terminate_protocol(error.into()); + Ok(packet) + } + Err(ParseError::Decode(error)) => { + self.terminate_protocol(error.into()); + Err(ReceiveError::ProtocolError(error)) + } + }, + } + } + + /// Send data to the specified destination. + /// + /// # Errors + /// + /// Returns an error if the destination is an unspecified IPv4 address (e.g., 0.0.0.0), + /// if the destination port is 0, or if the packet is larger than [`u32::MAX`] bytes instead + /// length. + pub fn send(&mut self, packet: Bytes, destination: SocketAddr) -> Result<(), SendError> { match &mut self.state { State::Initial | State::RegistrationRequested { .. } => Err(SendError::NotRegistered), State::Terminated => Err(SendError::ProtocolTerminated), @@ -125,96 +202,138 @@ impl ReliableRelayProtocol { } else if destination.port() == 0 { Err(SendError::DestinationPortUnspecified) } else { - transmit_queue.push_back(( - CommonHeader { - destination: Some(destination), - payload_length: u32::try_from(scion_packet_data.len()) - .or(Err(SendError::PayloadTooLarge(scion_packet_data.len())))?, - }, - scion_packet_data, - )); + let header = CommonHeader { + destination: Some(destination), + payload_length: u32::try_from(packet.len()) + .or(Err(SendError::PacketTooLarge(packet.len())))?, + }; + transmit_queue.push_back((header, packet)); Ok(()) } } } } - pub fn handle_incoming(&mut self, data: Bytes) -> Result<(), ReliableRelayError> { - match self.state { + /// Process stream data arriving from the dispatcher, and execute protocol logic on the data. + /// + /// This can result in events being generated or packets being available, which can be extracted + /// via the methods [`Self::poll`] and [`Self::receive`]. + pub fn handle_incoming(&mut self, data: Bytes) { + match &mut self.state { State::Initial | State::RegistrationRequested { is_sent: false, .. } => { - panic!("not yet registered, cannot handle incoming data") + // The current SCION dispatcher does not send data to the client before + // receiving a registration. However, there is a possibility that the + // source of the data being provided is not following the expected protocol. + // We therefore treat this as a protocol error and not a programmer error. + self.terminate_protocol(ReliableRelayError::DataBeforeRegistration) } State::RegistrationRequested { is_sent: true, .. } => { - self.parser.append_data(data); - self.maybe_complete_registration() + let mut data = data; + if self.maybe_complete_registration(&mut data) { + // Registration completed successfully and the state has advanced, + // handle the remaining data. + self.handle_incoming(data); + } else { + // Discard the data as it no longer usable. + assert!(data.is_empty() || self.is_terminated()); + } } - State::Registered { .. } => { - self.parser.append_data(data); - Ok(()) + State::Registered { parser, .. } => { + match parser.append_data(data) { + Err(err) => { + // Appended data that resulted in an invalid next packet, move to + // the terminated state. + self.terminate_protocol(err.into()); + } + Ok(true) => self.events.push_back(Event::PacketsAvailable), + Ok(false) => (), + } } - State::Terminated => panic!("protocol already terminated"), + State::Terminated => (), // Discard the data } } - fn maybe_complete_registration(&mut self) -> Result<(), ReliableRelayError> { + fn terminate_protocol(&mut self, reason: ReliableRelayError) { + self.state = State::Terminated; + self.events.push_back(Event::Terminated { reason }); + } + + /// Return True if a port has been successfully registered. + pub fn is_registered(&self) -> bool { + matches!(&self.state, State::Registered { .. }) + } + + pub fn is_terminated(&self) -> bool { + matches!(&self.state, State::Terminated) + } + + /// Completes the port registration, if possible, and advances to the next state. + /// + /// Takes only as much data as required to complete the registration, and returns true + /// if the registration completed successfully, false otherwise. + /// + /// If registration did not complete, either all the data was consumed or the protocol + /// terminated. + fn maybe_complete_registration(&mut self, data: &mut Bytes) -> bool { let State::RegistrationRequested { is_sent: true, request, - } = &self.state else { - panic!("must only be called while awaiting registration response"); + bytes_received, + } = &mut self.state + else { + unreachable!("only called while awaiting registration response"); }; - if let Some(response) = RegistrationResponse::decode(&mut self.parser) { + assert!(bytes_received.remaining() < RegistrationResponse::ENCODED_LENGTH); + + // Add only as much bytes as required and leave the rest in data. + let bytes_required = RegistrationResponse::ENCODED_LENGTH - bytes_received.remaining(); + let bytes_to_take = min(bytes_required, data.len()); + bytes_received.push_back(data.split_to(bytes_to_take)); + + if let Some(response) = RegistrationResponse::decode(bytes_received) { let requested_port = request.public_address.port(); if requested_port != response.assigned_port && requested_port != 0 { - self.state = State::Terminated; - - Err(ReliableRelayError::PortMismatch { + self.terminate_protocol(ReliableRelayError::PortMismatch { requested: requested_port, assigned: response.assigned_port, - }) + }); + + false } else { + // We added only as much as was required for decoding the response, so this is empty + assert_eq!(bytes_received.remaining(), 0); + self.state = State::Registered { transmit_queue: VecDeque::new(), port: response.assigned_port, + parser: StreamParser::new(), }; + self.events.push_back(Event::Registered); - Ok(()) + true } } else { - Ok(()) // Need more data, do nothing. + false } } + /// Returns the port registered to this protocol instance, if any. pub fn port(&self) -> Option { match self.state { State::Registered { port, .. } => Some(port), _ => None, } } - - pub fn receive(&mut self) -> Result<(SocketAddr, Vec), ReceiveError> { - match self.state { - State::Initial | State::RegistrationRequested { .. } => { - Err(ReceiveError::NotRegistered) - } - State::Terminated => Err(ReceiveError::ProtocolTerminated), - State::Registered { .. } => match self.parser.next_packet()? { - Some((header, bytes)) => Ok(( - // TODO(jsmith): Determine in which cases we do not receive an address - // TODO(jsmith): Handle when we do not receive an address with an error - header - .destination - .expect("there is always a desintation address"), - bytes, - )), - None => Err(ReceiveError::Blocked), - }, - } - } } +// TODO(jsmith): Determine how to simplify the errors +// We have two kinds of errors that we can encounter. First, errors like calling +// register twice on an object can be avoided by the programmer, whereas others, +// like those due to invalid data from the wire, or tranitions to the terminated +// state may be a 'surprise' to the programmer. + #[derive(thiserror::Error, Debug, Eq, PartialEq)] pub enum ReceiveError { #[error("currently no packets available, try again later")] @@ -231,14 +350,24 @@ pub enum ReceiveError { pub enum SendError { #[error("not yet registered")] NotRegistered, - #[error("protocol already terminated, receive is no longer possible")] + #[error("protocol already terminated, send is no longer possible")] ProtocolTerminated, #[error("provided destination address must be specified, not 0.0.0.0 or ::0")] DestinationUnspecified, #[error("provided destination port mmust be specified")] DestinationPortUnspecified, #[error("payload size too large ({0}), should be at most {}", u32::MAX)] - PayloadTooLarge(usize), + PacketTooLarge(usize), +} + +#[derive(thiserror::Error, Debug, Eq, PartialEq)] +pub enum ReliableRelayError { + #[error("port mismatch, requested port {requested}, received port {assigned}")] + PortMismatch { requested: u16, assigned: u16 }, + #[error("the protocol received data before the registration request was sent")] + DataBeforeRegistration, + #[error("failed to decode the a message")] + DecodeError(#[from] DecodeError), } impl Default for ReliableRelayProtocol { @@ -277,40 +406,43 @@ mod tests { #[test] fn success() { let mut relay = send_registration(); - relay - .handle_incoming(Bytes::from_static(&[0, 80])) - .expect("no error for valid response"); + + relay.handle_incoming(Bytes::from_static(&[0, 80])); + + assert_eq!(relay.poll(), Some(Event::Registered)); assert_eq!(relay.port(), Some(80)); } #[test] fn port_mismatch() { let mut relay = send_registration(); - let error = relay - .handle_incoming(Bytes::from_static(&[0, 81])) - .expect_err("expected port mismatch error"); + + relay.handle_incoming(Bytes::from_static(&[0, 81])); assert_eq!( - error, - ReliableRelayError::PortMismatch { - requested: 80, - assigned: 81 - } + relay.poll(), + Some(Event::Terminated { + reason: ReliableRelayError::PortMismatch { + requested: 80, + assigned: 81 + } + }) ); + assert_eq!(relay.port(), None); } #[test] fn incremental_data() { let mut relay = send_registration(); - relay - .handle_incoming(Bytes::from_static(&[0])) - .expect("no error for partial response"); + relay.handle_incoming(Bytes::from_static(&[0])); + + assert_eq!(relay.poll(), None); assert_eq!(relay.port(), None); - relay - .handle_incoming(Bytes::from_static(&[80])) - .expect("no error for valid total response"); + relay.handle_incoming(Bytes::from_static(&[80])); + + assert_eq!(relay.poll(), Some(Event::Registered)); assert_eq!(relay.port(), Some(80)); } } @@ -326,25 +458,24 @@ mod tests { state: State::Registered { transmit_queue: VecDeque::new(), port: 80, + parser: StreamParser::default(), }, - parser: StreamParser::new(), + events: VecDeque::new(), }; let Err(ReceiveError::Blocked) = relay.receive() else { panic!("expected to be blocked"); }; - relay - .handle_incoming(Bytes::from_static(&[ - 0xde, 0, 0xad, 1, 0xbe, 2, 0xef, 3, 1, 0, 0, 0, 5, 10, 2, 3, 4, 0, 80, b'H', - b'E', b'L', b'L', b'O', - ])) - .expect("should not err"); - - let (address, data_bytes) = relay.receive().expect("data to be available"); - assert_eq!(address, parse!("10.2.3.4:80")); - assert_eq!(data_bytes.len(), 1); - assert_eq!(data_bytes[0], b"HELLO".as_slice()); + relay.handle_incoming(Bytes::from_static(&[ + 0xde, 0, 0xad, 1, 0xbe, 2, 0xef, 3, 1, 0, 0, 0, 5, 10, 2, 3, 4, 0, 80, b'H', b'E', + b'L', b'L', b'O', + ])); + + let packet = relay.receive().expect("data to be available"); + assert_eq!(packet.last_hop, Some(parse!("10.2.3.4:80"))); + assert_eq!(packet.content.len(), 1); + assert_eq!(packet.content[0], b"HELLO".as_slice()); } #[test] @@ -353,8 +484,9 @@ mod tests { state: State::Registered { transmit_queue: VecDeque::new(), port: 80, + parser: StreamParser::default(), }, - parser: StreamParser::new(), + events: VecDeque::new(), }; let parts = [ @@ -368,16 +500,14 @@ mod tests { let Err(ReceiveError::Blocked) = relay.receive() else { panic!("expected to be blocked"); }; - relay - .handle_incoming(Bytes::from(data)) - .expect("should not err"); + relay.handle_incoming(Bytes::from(data)); } - let (address, data_bytes) = relay.receive().expect("data to be available"); - assert_eq!(address, parse!("10.2.3.4:80")); - assert_eq!(data_bytes.len(), 2); - assert_eq!(data_bytes[0], b"H".as_slice()); - assert_eq!(data_bytes[1], b"ELLO".as_slice()); + let packet = relay.receive().expect("data to be available"); + assert_eq!(packet.last_hop, Some(parse!("10.2.3.4:80"))); + assert_eq!(packet.content.len(), 2); + assert_eq!(packet.content[0], b"H".as_slice()); + assert_eq!(packet.content[1], b"ELLO".as_slice()); } } } diff --git a/crates/scion/src/reliable/wire_utils.rs b/crates/scion/src/reliable/wire_utils.rs index 82bbd99..0d336e3 100644 --- a/crates/scion/src/reliable/wire_utils.rs +++ b/crates/scion/src/reliable/wire_utils.rs @@ -1,3 +1,7 @@ +use std::collections::VecDeque; + +use bytes::{Buf, Bytes}; + use crate::address::{HostType, ServiceAddress}; pub(super) const IPV4_OCTETS: usize = 4; @@ -23,3 +27,102 @@ pub(super) fn encoded_port_length(host_type: HostType) -> usize { pub(super) fn encoded_address_and_port_length(host_type: HostType) -> usize { encoded_address_length(host_type) + encoded_port_length(host_type) } + +/// A queue of Bytes objects implementing the [`bytes::Buf`] trait. +#[derive(Default, Debug)] +pub(super) struct BytesQueue { + // INV: byte objects are always non-empty + queue: VecDeque, + bytes_remaining: usize, +} + +impl BytesQueue { + pub fn new() -> Self { + Self::default() + } + + pub fn pop_front(&mut self) -> Option { + if let Some(bytes) = self.queue.pop_front() { + self.bytes_remaining -= bytes.len(); + Some(bytes) + } else { + None + } + } + + pub fn push_front(&mut self, value: Bytes) { + if !value.is_empty() { + self.increase_remaining(value.len()); + self.queue.push_front(value) + } + } + + /// Append a Bytes to the queue, discarding it if it is empty. + pub fn push_back(&mut self, value: Bytes) { + if !value.is_empty() { + self.increase_remaining(value.len()); + self.queue.push_back(value) + } + } + + fn increase_remaining(&mut self, value: usize) { + self.bytes_remaining = self + .bytes_remaining + .checked_add(value) + .expect("never more than usize bytes in total"); + } +} + +impl Buf for BytesQueue { + fn remaining(&self) -> usize { + self.bytes_remaining + } + + fn chunk(&self) -> &[u8] { + self.queue.front().map_or(&[], |data| data) + } + + fn advance(&mut self, cnt: usize) { + if cnt == 0 { + return; + } + if cnt > self.bytes_remaining { + panic!( + "cnt > self.remaining() ({} > {})", + cnt, self.bytes_remaining + ); + } + + let mut advance_by = cnt; + while advance_by > 0 { + let mut data = self.queue.pop_front().expect("there must be data"); + + if data.len() > advance_by { + self.queue.push_front(data.split_off(advance_by)); + } + assert!(data.len() <= advance_by); + + advance_by -= data.len(); + } + + self.bytes_remaining -= cnt; + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn has_available_multiple() { + let mut bytes_queue = BytesQueue::new(); + + bytes_queue.push_back(Bytes::from_static(&[0, 1, 2])); + bytes_queue.push_back(Bytes::from_static(&[4, 5, 6])); + + let mut buffer = [0u8; 6]; + bytes_queue.copy_to_slice(&mut buffer); + + assert_eq!(buffer, [0, 1, 2, 4, 5, 6]); + } +}