From 033b27d9a9c52b04b8aeedfeb5a0c9475cc3bfad Mon Sep 17 00:00:00 2001
From: Michael Tautschnig <tautschn@amazon.com>
Date: Wed, 24 Jul 2024 12:04:18 +0000
Subject: [PATCH] Work around CaDiCaL performance regression

---
 .../quic/s2n-quic-core/src/inet/checksum.rs   | 488 ++++++++++++++++++
 .../s2n-quic/quic/s2n-quic-core/src/slice.rs  | 254 +++++++++
 .../quic/s2n-quic-platform/src/message.rs     | 190 +++++++
 .../src/message/cmsg/tests.rs                 | 123 +++++
 .../src/message/msg/tests.rs                  | 111 ++++
 5 files changed, 1166 insertions(+)
 create mode 100644 tests/perf/overlays/s2n-quic/quic/s2n-quic-core/src/inet/checksum.rs
 create mode 100644 tests/perf/overlays/s2n-quic/quic/s2n-quic-core/src/slice.rs
 create mode 100644 tests/perf/overlays/s2n-quic/quic/s2n-quic-platform/src/message.rs
 create mode 100644 tests/perf/overlays/s2n-quic/quic/s2n-quic-platform/src/message/cmsg/tests.rs
 create mode 100644 tests/perf/overlays/s2n-quic/quic/s2n-quic-platform/src/message/msg/tests.rs

diff --git a/tests/perf/overlays/s2n-quic/quic/s2n-quic-core/src/inet/checksum.rs b/tests/perf/overlays/s2n-quic/quic/s2n-quic-core/src/inet/checksum.rs
new file mode 100644
index 000000000000..76466bacd486
--- /dev/null
+++ b/tests/perf/overlays/s2n-quic/quic/s2n-quic-core/src/inet/checksum.rs
@@ -0,0 +1,488 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use core::{fmt, hash::Hasher, num::Wrapping};
+
+#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+mod x86;
+
+/// Computes the [IP checksum](https://www.rfc-editor.org/rfc/rfc1071) over the given slice of bytes
+#[inline]
+pub fn checksum(data: &[u8]) -> u16 {
+    let mut checksum = Checksum::default();
+    checksum.write(data);
+    checksum.finish()
+}
+
+/// Minimum size for a payload to be considered for platform-specific code
+const LARGE_WRITE_LEN: usize = 32;
+
+type Accumulator = u64;
+type State = Wrapping<Accumulator>;
+
+/// Platform-specific function for computing a checksum
+type LargeWriteFn = for<'a> unsafe fn(&mut State, bytes: &'a [u8]) -> &'a [u8];
+
+#[inline(always)]
+fn write_sized_generic<'a, const MAX_LEN: usize, const CHUNK_LEN: usize>(
+    state: &mut State,
+    mut bytes: &'a [u8],
+    on_chunk: impl Fn(&[u8; CHUNK_LEN], &mut Accumulator),
+) -> &'a [u8] {
+    //= https://www.rfc-editor.org/rfc/rfc1071#section-4.1
+    //# The following "C" code algorithm computes the checksum with an inner
+    //# loop that sums 16-bits at a time in a 32-bit accumulator.
+    //#
+    //# in 6
+    //# {
+    //#     /* Compute Internet Checksum for "count" bytes
+    //#      *  beginning at location "addr".
+    //#      */
+    //#     register long sum = 0;
+    //#
+    //#     while( count > 1 ) {
+    //#         /* This is the inner loop */
+    //#         sum += * (unsigned short) addr++;
+    //#         count -= 2;
+    //#     }
+    //#
+    //#     /* Add left-over byte, if any */
+    //#     if( count > 0 )
+    //#         sum += * (unsigned char *) addr;
+    //#
+    //#     /* Fold 32-bit sum to 16 bits */
+    //#     while (sum>>16)
+    //#         sum = (sum & 0xffff) + (sum >> 16);
+    //#
+    //#     checksum = ~sum;
+    //# }
+
+    while bytes.len() >= MAX_LEN {
+        // use `get_unchecked` to make it easier for kani to analyze
+        let chunks = unsafe { bytes.get_unchecked(..MAX_LEN) };
+        bytes = unsafe { bytes.get_unchecked(MAX_LEN..) };
+
+        let mut sum = 0;
+        // for each pair of bytes, interpret them as integers and sum them up
+        for chunk in chunks.chunks_exact(CHUNK_LEN) {
+            let chunk = unsafe {
+                // SAFETY: chunks_exact always produces a slice of CHUNK_LEN
+                debug_assert_eq!(chunk.len(), CHUNK_LEN);
+                &*(chunk.as_ptr() as *const [u8; CHUNK_LEN])
+            };
+            on_chunk(chunk, &mut sum);
+        }
+        *state += sum;
+    }
+
+    bytes
+}
+
+/// Generic implementation of a function that computes a checksum over the given slice
+#[inline(always)]
+fn write_sized_generic_u16<'a, const LEN: usize>(state: &mut State, bytes: &'a [u8]) -> &'a [u8] {
+    write_sized_generic::<LEN, 2>(
+        state,
+        bytes,
+        #[inline(always)]
+        |&bytes, acc| {
+            *acc += u16::from_ne_bytes(bytes) as Accumulator;
+        },
+    )
+}
+
+#[inline(always)]
+fn write_sized_generic_u32<'a, const LEN: usize>(state: &mut State, bytes: &'a [u8]) -> &'a [u8] {
+    write_sized_generic::<LEN, 4>(
+        state,
+        bytes,
+        #[inline(always)]
+        |&bytes, acc| {
+            *acc += u32::from_ne_bytes(bytes) as Accumulator;
+        },
+    )
+}
+
+/// Returns the most optimized function implementation for the current platform
+#[inline]
+#[cfg(all(feature = "once_cell", not(any(kani, miri))))]
+fn probe_write_large() -> LargeWriteFn {
+    static LARGE_WRITE_FN: once_cell::sync::Lazy<LargeWriteFn> = once_cell::sync::Lazy::new(|| {
+        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+        {
+            if let Some(fun) = x86::probe() {
+                return fun;
+            }
+        }
+
+        write_sized_generic_u32::<16>
+    });
+
+    *LARGE_WRITE_FN
+}
+
+#[inline]
+#[cfg(not(all(feature = "once_cell", not(any(kani, miri)))))]
+fn probe_write_large() -> LargeWriteFn {
+    write_sized_generic_u32::<16>
+}
+
+/// Computes the [IP checksum](https://www.rfc-editor.org/rfc/rfc1071) over an arbitrary set of inputs
+#[derive(Clone, Copy)]
+pub struct Checksum {
+    state: State,
+    partial_write: bool,
+    write_large: LargeWriteFn,
+}
+
+impl Default for Checksum {
+    fn default() -> Self {
+        Self {
+            state: Default::default(),
+            partial_write: false,
+            write_large: probe_write_large(),
+        }
+    }
+}
+
+impl fmt::Debug for Checksum {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        let mut v = *self;
+        v.carry();
+        f.debug_tuple("Checksum").field(&v.finish()).finish()
+    }
+}
+
+impl Checksum {
+    /// Creates a checksum instance without enabling the native implementation
+    #[inline]
+    pub fn generic() -> Self {
+        Self {
+            state: Default::default(),
+            partial_write: false,
+            write_large: write_sized_generic_u32::<16>,
+        }
+    }
+
+    /// Writes a single byte to the checksum state
+    #[inline]
+    fn write_byte(&mut self, byte: u8, shift: bool) {
+        if shift {
+            self.state += (byte as Accumulator) << 8;
+        } else {
+            self.state += byte as Accumulator;
+        }
+    }
+
+    /// Carries all of the bits into a single 16 bit range
+    #[inline]
+    fn carry(&mut self) {
+        #[cfg(kani)]
+        self.carry_rfc();
+        #[cfg(not(kani))]
+        self.carry_optimized();
+    }
+
+    /// Carries all of the bits into a single 16 bit range
+    ///
+    /// This implementation is very similar to the way the RFC is written.
+    #[inline]
+    #[allow(dead_code)]
+    fn carry_rfc(&mut self) {
+        let mut state = self.state.0;
+
+        for _ in 0..core::mem::size_of::<Accumulator>() {
+            state = (state & 0xffff) + (state >> 16);
+        }
+
+        self.state.0 = state;
+    }
+
+    /// Carries all of the bits into a single 16 bit range
+    ///
+    /// This implementation was written after some optimization on the RFC version. It needs
+    /// about half as many instructions as the RFC version.
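+    ///
+    /// It treats the `u64` accumulator as four `u16` words that are summed with end-around
+    /// carry. With the RFC example bytes, `0x0001 + 0xf203 + 0xf4f5 + 0xf6f7 = 0x2ddf0`;
+    /// folding the two carries back in yields `0xddf2`.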
+    #[inline]
+    #[allow(dead_code)]
+    fn carry_optimized(&mut self) {
+        let values: [u16; core::mem::size_of::<Accumulator>() / 2] = unsafe {
+            // SAFETY: alignment of the State is >= of u16
+            debug_assert!(core::mem::align_of::<State>() >= core::mem::align_of::<u16>());
+            core::mem::transmute(self.state.0)
+        };
+
+        let mut sum = 0u16;
+
+        for value in values {
+            let (res, overflowed) = sum.overflowing_add(value);
+            sum = res;
+            if overflowed {
+                sum += 1;
+            }
+        }
+
+        self.state.0 = sum as _;
+    }
+
+    /// Writes bytes to the checksum and ensures any single byte remainders are padded
+    #[inline]
+    pub fn write_padded(&mut self, bytes: &[u8]) {
+        self.write(bytes);
+
+        // write a null byte if `bytes` wasn't 16-bit aligned
+        if core::mem::take(&mut self.partial_write) {
+            self.write_byte(0, cfg!(target_endian = "little"));
+        }
+    }
+
+    /// Computes the final checksum
+    #[inline]
+    pub fn finish(self) -> u16 {
+        self.finish_be().to_be()
+    }
+
+    #[inline]
+    pub fn finish_be(mut self) -> u16 {
+        self.carry();
+
+        let value = self.state.0 as u16;
+        let value = !value;
+
+        // if value is 0, we need to set it to the max value to indicate the checksum was actually
+        // computed
+        if value == 0 {
+            return 0xffff;
+        }
+
+        value
+    }
+}
+
+impl Hasher for Checksum {
+    #[inline]
+    fn write(&mut self, mut bytes: &[u8]) {
+        if bytes.is_empty() {
+            return;
+        }
+
+        // Check to see if we have a partial write to flush
+        if core::mem::take(&mut self.partial_write) {
+            let (chunk, remaining) = bytes.split_at(1);
+            bytes = remaining;
+
+            // shift the byte if we're on little endian
+            self.write_byte(chunk[0], cfg!(target_endian = "little"));
+        }
+
+        // Only delegate to the optimized platform function if the payload is big enough
+        if bytes.len() >= LARGE_WRITE_LEN {
+            bytes = unsafe { (self.write_large)(&mut self.state, bytes) };
+        }
+
+        // Fall back on the generic implementation to wrap things up
+        //
+        // NOTE: We don't use the u32 version with kani as it causes the verification time to
+        // increase by quite a bit. We have a separate proof for the functional equivalence of
+        // these two configurations.
+        #[cfg(not(kani))]
+        {
+            bytes = write_sized_generic_u32::<4>(&mut self.state, bytes);
+        }
+
+        bytes = write_sized_generic_u16::<2>(&mut self.state, bytes);
+
+        // if we only have a single byte left, write it to the state and mark it as a partial write
+        if let Some(byte) = bytes.first().copied() {
+            self.partial_write = true;
+            self.write_byte(byte, cfg!(target_endian = "big"));
+        }
+    }
+
+    #[inline]
+    fn finish(&self) -> u64 {
+        Self::finish(*self) as _
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use bolero::check;
+
+    #[test]
+    fn rfc_example_test() {
+        //= https://www.rfc-editor.org/rfc/rfc1071#section-3
+        //= type=test
+        //# We now present explicit examples of calculating a simple 1's
+        //# complement sum on a 2's complement machine.  The examples show the
+        //# same sum calculated byte by bye, by 16-bits words in normal and
+        //# swapped order, and 32 bits at a time in 3 different orders.  All
+        //# numbers are in hex.
+        //#
+        //#               Byte-by-byte    "Normal"    Swapped
+        //#                                 Order      Order
+        //#
+        //#     Byte 0/1:    00   01        0001        0100
+        //#     Byte 2/3:    f2   03        f203        03f2
+        //#     Byte 4/5:    f4   f5        f4f5        f5f4
+        //#     Byte 6/7:    f6   f7        f6f7        f7f6
+        //#                  ---  ---       -----       -----
+        //#     Sum1:        2dc  1f0       2ddf0       1f2dc
+        //#
+        //#                  dc   f0        ddf0        f2dc
+        //#     Carrys:      1    2         2           1
+        //#                  --   --        ----        ----
+        //#     Sum2:        dd   f2        ddf2        f2dd
+        //#
+        //#     Final Swap:  dd   f2        ddf2        ddf2
+        let bytes = [0x00, 0x01, 0xf2, 0x03, 0xf4, 0xf5, 0xf6, 0xf7];
+
+        let mut checksum = Checksum::default();
+        checksum.write(&bytes);
+        checksum.carry();
+
+        assert_eq!((checksum.state.0 as u16).to_le_bytes(), [0xdd, 0xf2]);
+        assert_eq!((!rfc_c_port(&bytes)).to_be_bytes(), [0xdd, 0xf2]);
+    }
+
+    fn rfc_c_port(data: &[u8]) -> u16 {
+        //= https://www.rfc-editor.org/rfc/rfc1071#section-4.1
+        //= type=test
+        //# The following "C" code algorithm computes the checksum with an inner
+        //# loop that sums 16-bits at a time in a 32-bit accumulator.
+        //#
+        //# in 6
+        //# {
+        //#     /* Compute Internet Checksum for "count" bytes
+        //#      *  beginning at location "addr".
+        //#      */
+        //#     register long sum = 0;
+        //#
+        //#     while( count > 1 ) {
+        //#         /* This is the inner loop */
+        //#         sum += * (unsigned short) addr++;
+        //#         count -= 2;
+        //#     }
+        //#
+        //#     /* Add left-over byte, if any */
+        //#     if( count > 0 )
+        //#         sum += * (unsigned char *) addr;
+        //#
+        //#     /* Fold 32-bit sum to 16 bits */
+        //#     while (sum>>16)
+        //#         sum = (sum & 0xffff) + (sum >> 16);
+        //#
+        //#     checksum = ~sum;
+        //# }
+
+        let mut addr = data.as_ptr();
+        let mut count = data.len();
+
+        unsafe {
+            let mut sum = 0u32;
+
+            while count > 1 {
+                let value = u16::from_be_bytes([*addr, *addr.add(1)]);
+                sum = sum.wrapping_add(value as u32);
+                addr = addr.add(2);
+                count -= 2;
+            }
+
+            if count > 0 {
+                let value = u16::from_be_bytes([*addr, 0]);
+                sum = sum.wrapping_add(value as u32);
+            }
+
+            while sum >> 16 != 0 {
+                sum = (sum & 0xffff) + (sum >> 16);
+            }
+
+            !(sum as u16)
+        }
+    }
+
+    // Reduce the length to 4 for Kani until
+    // https://github.com/model-checking/kani/issues/3030 is fixed
+    #[cfg(any(kani, miri))]
+    const LEN: usize = if cfg!(kani) { 4 } else { 32 };
+
+    /// * Compares the implementation to a port of the C code defined in the RFC
+    /// * Ensures partial writes are correctly handled, even if they're not at a 16 bit boundary
+    #[test]
+    #[cfg_attr(kani, kani::proof, kani::unwind(9), kani::solver(minisat))]
+    fn differential() {
+        #[cfg(any(kani, miri))]
+        type Bytes = crate::testing::InlineVec<u8, LEN>;
+        #[cfg(not(any(kani, miri)))]
+        type Bytes = Vec<u8>;
+
+        check!()
+            .with_type::<(usize, Bytes)>()
+            .for_each(|(index, bytes)| {
+                let index = if bytes.is_empty() {
+                    0
+                } else {
+                    *index % bytes.len()
+                };
+                let (a, b) = bytes.split_at(index);
+                let mut cs = Checksum::default();
+                cs.write(a);
+                cs.write(b);
+
+                let mut rfc_value = rfc_c_port(bytes);
+                if rfc_value == 0 {
+                    rfc_value = 0xffff;
+                }
+
+                assert_eq!(rfc_value.to_be_bytes(), cs.finish().to_be_bytes());
+            });
+    }
+
+    /// Shows that using the u32+u16 methods is the same as only using u16
+    #[test]
+    #[cfg_attr(kani, kani::proof, kani::unwind(9), kani::solver(kissat))]
+    fn u32_u16_differential() {
+        #[cfg(any(kani, miri))]
+        type Bytes = crate::testing::InlineVec<u8, LEN>;
+        #[cfg(not(any(kani, miri)))]
+        type Bytes = Vec<u8>;
+
+        check!().with_type::<Bytes>().for_each(|bytes| {
+            let a = {
+                let mut cs = Checksum::generic();
+                let bytes = write_sized_generic_u32::<4>(&mut cs.state, bytes);
+                write_sized_generic_u16::<2>(&mut cs.state, bytes);
+                cs.finish()
+            };
+
+            let b = {
+                let mut cs = Checksum::generic();
+                write_sized_generic_u16::<2>(&mut cs.state, bytes);
+                cs.finish()
+            };
+
+            assert_eq!(a, b);
+        });
+    }
+
+    /// Shows that RFC carry implementation is the same as the optimized version
+    #[test]
+    #[cfg_attr(kani, kani::proof, kani::unwind(9), kani::solver(kissat))]
+    fn carry_differential() {
+        check!().with_type::<Accumulator>().cloned().for_each(|state| {
+            let mut opt = Checksum::generic();
+            opt.state.0 = state;
+            opt.carry_optimized();
+
+            let mut rfc = Checksum::generic();
+            rfc.state.0 = state;
+            rfc.carry_rfc();
+
+            assert_eq!(opt.state.0, rfc.state.0);
+        });
+    }
+}
diff --git a/tests/perf/overlays/s2n-quic/quic/s2n-quic-core/src/slice.rs b/tests/perf/overlays/s2n-quic/quic/s2n-quic-core/src/slice.rs
new file mode 100644
index 000000000000..563f28200551
--- /dev/null
+++ b/tests/perf/overlays/s2n-quic/quic/s2n-quic-core/src/slice.rs
@@ -0,0 +1,254 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use core::ops::{Deref, DerefMut};
+
+pub mod deque;
+
+/// Copies vectored slices from one slice into another
+///
+/// The number of copied items is limited by the minimum of the lengths of each of the slices.
+///
+/// Returns the number of entries that were copied
+#[inline]
+pub fn vectored_copy<A, B, T>(from: &[A], to: &mut [B]) -> usize
+where
+    A: Deref<Target = [T]>,
+    B: Deref<Target = [T]> + DerefMut,
+    T: Copy,
+{
+    zip_chunks(from, to, |a, b| {
+        b.copy_from_slice(a);
+    })
+}
+
+/// Zips entries from one slice to another
+///
+/// The number of copied items is limited by the minimum of the lengths of each of the slices.
+///
+/// Returns the number of entries that were processed
+#[inline]
+pub fn zip<A, B, At, Bt, F>(from: &[A], to: &mut [B], mut on_item: F) -> usize
+where
+    A: Deref<Target = [At]>,
+    B: Deref<Target = [Bt]> + DerefMut,
+    F: FnMut(&At, &mut Bt),
+{
+    zip_chunks(from, to, |a, b| {
+        for (a, b) in a.iter().zip(b) {
+            on_item(a, b);
+        }
+    })
+}
+
+/// Zips overlapping chunks from one slice to another
+///
+/// The number of copied items is limited by the minimum of the lengths of each of the slices.
+///
+/// Returns the number of entries that were processed
+#[inline]
+pub fn zip_chunks<A, B, At, Bt, F>(from: &[A], to: &mut [B], mut on_slice: F) -> usize
+where
+    A: Deref<Target = [At]>,
+    B: Deref<Target = [Bt]> + DerefMut,
+    F: FnMut(&[At], &mut [Bt]),
+{
+    let mut count = 0;
+
+    let mut from_index = 0;
+    let mut from_offset = 0;
+
+    let mut to_index = 0;
+    let mut to_offset = 0;
+
+    // The compiler isn't smart enough to remove all of the bounds checks so we resort to
+    // `get_unchecked`.
+    //
+    // https://godbolt.org/z/45cG1v
+
+    // iterate until we reach one of the ends
+    while from_index < from.len() && to_index < to.len() {
+        let from = unsafe {
+            // Safety: this length is already checked in the while condition
+            debug_assert!(from.len() > from_index);
+            from.get_unchecked(from_index)
+        };
+
+        let to = unsafe {
+            // Safety: this length is already checked in the while condition
+            debug_assert!(to.len() > to_index);
+            to.get_unchecked_mut(to_index)
+        };
+
+        {
+            // calculate the current views
+            let from = unsafe {
+                // Safety: the slice offsets are checked at the end of the while loop
+                debug_assert!(from.len() >= from_offset);
+                from.get_unchecked(from_offset..)
+            };
+
+            let to = unsafe {
+                // Safety: the slice offsets are checked at the end of the while loop
+                debug_assert!(to.len() >= to_offset);
+                to.get_unchecked_mut(to_offset..)
+ }; + + let len = from.len().min(to.len()); + + unsafe { + // Safety: by using the min of the two lengths we will never exceed + // either slice's buffer + debug_assert!(from.len() >= len); + debug_assert!(to.len() >= len); + + let at = from.get_unchecked(..len); + let bt = to.get_unchecked_mut(..len); + + on_slice(at, bt); + } + + // increment the offsets + from_offset += len; + to_offset += len; + count += len; + } + + // check if the `from` is done + if from.len() == from_offset { + from_index += 1; + from_offset = 0; + } + + // check if the `to` is done + if to.len() == to_offset { + to_index += 1; + to_offset = 0; + } + } + + count +} + +/// Deduplicates elements in a slice +/// +/// # Note +/// +/// Items must be sorted before performing this function +#[inline] +pub fn partition_dedup(slice: &mut [T]) -> (&mut [T], &mut [T]) +where + T: PartialEq, +{ + // TODO replace with + // https://doc.rust-lang.org/std/primitive.slice.html#method.partition_dedup + // when stable + // + // For now, we've just inlined their implementation + + let len = slice.len(); + if len <= 1 { + return (slice, &mut []); + } + + let ptr = slice.as_mut_ptr(); + let mut next_read: usize = 1; + let mut next_write: usize = 1; + + // SAFETY: the `while` condition guarantees `next_read` and `next_write` + // are less than `len`, thus are inside `self`. `prev_ptr_write` points to + // one element before `ptr_write`, but `next_write` starts at 1, so + // `prev_ptr_write` is never less than 0 and is inside the slice. + // This fulfils the requirements for dereferencing `ptr_read`, `prev_ptr_write` + // and `ptr_write`, and for using `ptr.add(next_read)`, `ptr.add(next_write - 1)` + // and `prev_ptr_write.offset(1)`. + // + // `next_write` is also incremented at most once per loop at most meaning + // no element is skipped when it may need to be swapped. + // + // `ptr_read` and `prev_ptr_write` never point to the same element. This + // is required for `&mut *ptr_read`, `&mut *prev_ptr_write` to be safe. + // The explanation is simply that `next_read >= next_write` is always true, + // thus `next_read > next_write - 1` is too. + unsafe { + // Avoid bounds checks by using raw pointers. 
+        while next_read < len {
+            let ptr_read = ptr.add(next_read);
+            let prev_ptr_write = ptr.add(next_write - 1);
+            if *ptr_read != *prev_ptr_write {
+                if next_read != next_write {
+                    let ptr_write = prev_ptr_write.add(1);
+                    core::ptr::swap(ptr_read, ptr_write);
+                }
+                next_write += 1;
+            }
+            next_read += 1;
+        }
+    }
+
+    slice.split_at_mut(next_write)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::testing::InlineVec;
+    use bolero::check;
+
+    fn assert_eq_slices<A, B, T>(a: &[A], b: &[B])
+    where
+        A: Deref<Target = [T]>,
+        B: Deref<Target = [T]>,
+        T: PartialEq + core::fmt::Debug,
+    {
+        let a = a.iter().flat_map(|a| a.iter());
+        let b = b.iter().flat_map(|b| b.iter());
+
+        // make sure all of the values match
+        //
+        // Note: this doesn't use Iterator::eq, as the slice lengths may be different
+        for (a, b) in a.zip(b) {
+            assert_eq!(a, b);
+        }
+    }
+
+    #[test]
+    fn vectored_copy_test() {
+        let from = [
+            &[0][..],
+            &[1, 2, 3][..],
+            &[4, 5, 6, 7][..],
+            &[][..],
+            &[8, 9, 10, 11][..],
+        ];
+
+        for len in 0..6 {
+            let mut to = vec![vec![0; 2]; len];
+            let copied_len = vectored_copy(&from, &mut to);
+            assert_eq!(copied_len, len * 2);
+            assert_eq_slices(&from, &to);
+        }
+    }
+
+    const LEN: usize = if cfg!(kani) { 2 } else { 32 };
+
+    #[test]
+    #[cfg_attr(kani, kani::proof, kani::unwind(5), kani::solver(kissat))]
+    #[cfg_attr(miri, ignore)] // This test is too expensive for miri to complete in a reasonable amount of time
+    fn vectored_copy_fuzz_test() {
+        check!()
+            .with_type::<(
+                InlineVec<InlineVec<u8, LEN>, LEN>,
+                InlineVec<InlineVec<u8, LEN>, LEN>,
+            )>()
+            .cloned()
+            .for_each(|(from, mut to)| {
+                vectored_copy(&from, &mut to);
+                assert_eq_slices(&from, &to);
+            })
+    }
+}
diff --git a/tests/perf/overlays/s2n-quic/quic/s2n-quic-platform/src/message.rs b/tests/perf/overlays/s2n-quic/quic/s2n-quic-platform/src/message.rs
new file mode 100644
index 000000000000..1e7e58ad53ff
--- /dev/null
+++ b/tests/perf/overlays/s2n-quic/quic/s2n-quic-platform/src/message.rs
@@ -0,0 +1,190 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use core::{alloc::Layout, ptr::NonNull};
+use s2n_quic_core::{inet::datagram, io::tx, path};
+
+#[cfg(s2n_quic_platform_cmsg)]
+pub mod cmsg;
+#[cfg(s2n_quic_platform_socket_mmsg)]
+pub mod mmsg;
+#[cfg(s2n_quic_platform_socket_msg)]
+pub mod msg;
+pub mod simple;
+
+pub mod default {
+    cfg_if::cfg_if! {
+        if #[cfg(s2n_quic_platform_socket_mmsg)] {
+            pub use super::mmsg::*;
+        } else if #[cfg(s2n_quic_platform_socket_msg)] {
+            pub use super::msg::*;
+        } else {
+            pub use super::simple::*;
+        }
+    }
+}
+
+/// Tracks allocations of message ring buffer state
+pub struct Storage {
+    ptr: NonNull<u8>,
+    layout: Layout,
+}
+
+/// Safety: the ring buffer controls access to the underlying storage
+unsafe impl Send for Storage {}
+/// Safety: the ring buffer controls access to the underlying storage
+unsafe impl Sync for Storage {}
+
+impl Storage {
+    #[inline]
+    pub fn new(layout: Layout) -> Self {
+        unsafe {
+            let ptr = alloc::alloc::alloc_zeroed(layout);
+            let ptr = NonNull::new(ptr).expect("could not allocate message storage");
+            Self { layout, ptr }
+        }
+    }
+
+    #[inline]
+    pub fn as_ptr(&self) -> *mut u8 {
+        self.ptr.as_ptr()
+    }
+
+    /// Asserts that the pointer is in bounds of the allocation
+    #[inline]
+    pub fn check_bounds<T>(&self, ptr: *mut T) {
+        let start = self.as_ptr();
+        let end = unsafe {
+            // Safety: pointer is allocated with the self.layout
+            start.add(self.layout.size())
+        };
+        let allocation_range = start..=end;
+        let actual_end_ptr = ptr as *mut u8;
+        debug_assert!(allocation_range.contains(&actual_end_ptr));
+    }
+}
+
+impl Drop for Storage {
+    fn drop(&mut self) {
+        unsafe {
+            // Safety: pointer was allocated with self.layout
+            alloc::alloc::dealloc(self.as_ptr(), self.layout)
+        }
+    }
+}
+
+/// An abstract message that can be sent and received on a network
+pub trait Message: 'static + Copy {
+    type Handle: path::Handle;
+
+    const SUPPORTS_GSO: bool;
+    const SUPPORTS_ECN: bool;
+    const SUPPORTS_FLOW_LABELS: bool;
+
+    /// Allocates `entries` messages, each with `payload_len` bytes
+    fn alloc(entries: u32, payload_len: u32, offset: usize) -> Storage;
+
+    /// Returns the length of the payload
+    fn payload_len(&self) -> usize;
+
+    /// Sets the payload length for the message
+    ///
+    /// # Safety
+    /// This method should only set the payload less than or
+    /// equal to its initially allocated size.
+    unsafe fn set_payload_len(&mut self, payload_len: usize);
+
+    /// Validates that the `source` message can be replicated to `dest`.
+    ///
+    /// # Panics
+    ///
+    /// This panics when the messages cannot be replicated
+    fn validate_replication(source: &Self, dest: &Self);
+
+    /// Returns a mutable pointer for the message payload
+    fn payload_ptr_mut(&mut self) -> *mut u8;
+
+    /// Returns a mutable slice for the message payload
+    #[inline]
+    fn payload_mut(&mut self) -> &mut [u8] {
+        unsafe { core::slice::from_raw_parts_mut(self.payload_ptr_mut(), self.payload_len()) }
+    }
+
+    /// Sets the segment size for the message payload
+    fn set_segment_size(&mut self, _size: usize) {
+        panic!("cannot use GSO on the current platform");
+    }
+
+    /// Resets the message for future use
+    ///
+    /// # Safety
+    /// This method should only set the MTU to the original value
+    unsafe fn reset(&mut self, mtu: usize);
+
+    /// Reads the message as an RX packet
+    fn rx_read(&mut self, local_address: &path::LocalAddress) -> Option<RxMessage<Self::Handle>>;
+
+    /// Writes the message into the TX packet
+    fn tx_write<M: tx::Message<Handle = Self::Handle>>(
+        &mut self,
+        message: M,
+    ) -> Result<usize, tx::Error>;
+}
+
+pub struct RxMessage<'a, Handle: Copy> {
+    /// The received header for the message
+    pub header: datagram::Header<Handle>,
+    /// The number of segments inside the message
+    pub segment_size: usize,
+    /// The full payload of the message
+    pub payload: &'a mut [u8],
+}
+
+impl<'a, Handle: Copy> RxMessage<'a, Handle> {
+    #[inline]
+    pub fn for_each<F: FnMut(datagram::Header<Handle>, &mut [u8])>(self, mut on_packet: F) {
+        // `chunks_mut` doesn't know what to do with zero-sized segments so return early
+        if self.segment_size == 0 {
+            return;
+        }
+
+        for segment in self.payload.chunks_mut(self.segment_size) {
+            on_packet(self.header, segment);
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use bolero::check;
+
+    #[test]
+    #[cfg_attr(kani, kani::proof, kani::unwind(17), kani::solver(minisat))]
+    fn rx_message_test() {
+        let path = bolero::gen::<path::Tuple>();
+        let ecn = bolero::gen();
+        let segment_size = bolero::gen();
+        let max_payload_len = if cfg!(kani) { 16 } else { u16::MAX as usize };
+        let payload_len = 0..=max_payload_len;
+
+        check!()
+            .with_generator((path, ecn, segment_size, payload_len))
+            .cloned()
+            .for_each(|(path, ecn, segment_size, payload_len)| {
+                let mut payload = vec![0u8; payload_len];
+                let rx_message = RxMessage {
+                    header: datagram::Header { path, ecn },
+                    segment_size,
+                    payload: &mut payload,
+                };
+
+                rx_message.for_each(|header, segment| {
+                    assert_eq!(header.path, path);
+                    assert_eq!(header.ecn, ecn);
+                    assert!(segment.len() <= payload_len);
+                    assert!(segment.len() <= segment_size);
+                })
+            })
+    }
+}
diff --git a/tests/perf/overlays/s2n-quic/quic/s2n-quic-platform/src/message/cmsg/tests.rs b/tests/perf/overlays/s2n-quic/quic/s2n-quic-platform/src/message/cmsg/tests.rs
new file mode 100644
index 000000000000..7555db7fb523
--- /dev/null
+++ b/tests/perf/overlays/s2n-quic/quic/s2n-quic-platform/src/message/cmsg/tests.rs
@@ -0,0 +1,123 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use super::*;
+use bolero::{check, TypeGenerator};
+use core::mem::align_of;
+use libc::c_int;
+
+#[inline]
+fn aligned_iter(bytes: &[u8], f: impl FnOnce(decode::Iter)) {
+    // the bytes needs to be aligned to a cmsghdr
+    let offset = bytes.as_ptr().align_offset(align_of::<cmsghdr>());
+
+    if let Some(bytes) = bytes.get(offset..) {
+        let iter = unsafe {
+            // SAFETY: bytes are aligned above
+            decode::Iter::from_bytes(bytes)
+        };
+
+        f(iter)
+    }
+}
+
+/// Ensures the cmsg iterator doesn't crash or segfault
+#[test]
+#[cfg_attr(kani, kani::proof, kani::solver(minisat), kani::unwind(17))]
+fn iter_test() {
+    check!().for_each(|bytes| {
+        aligned_iter(bytes, |iter| {
+            for (cmsghdr, value) in iter {
+                let _ = cmsghdr;
+                let _ = value;
+            }
+        })
+    });
+}
+
+/// Ensures the `decode::Iter::collect` doesn't crash or segfault
+#[test]
+#[cfg_attr(kani, kani::proof, kani::solver(minisat), kani::unwind(17))]
fn collect_test() {
+    check!().for_each(|bytes| {
+        aligned_iter(bytes, |iter| {
+            let _ = iter.collect();
+        })
+    });
+}
+
+#[derive(Clone, Copy, Debug, TypeGenerator)]
+struct Op {
+    level: c_int,
+    ty: c_int,
+    value: Value,
+}
+
+#[derive(Clone, Copy, Debug, TypeGenerator)]
+enum Value {
+    U8(u8),
+    U16(u16),
+    U32(u32),
+    // alignment can't exceed that of cmsghdr
+    U64([u32; 2]),
+    U128([u32; 4]),
+}
+
+impl Value {
+    fn check_value(&self, bytes: &[u8]) {
+        let expected_len = match self {
+            Self::U8(_) => 1,
+            Self::U16(_) => 2,
+            Self::U32(_) => 4,
+            Self::U64(_) => 8,
+            Self::U128(_) => 16,
+        };
+        assert_eq!(expected_len, bytes.len());
+    }
+}
+
+fn round_trip(ops: &[Op]) {
+    let mut storage = Storage::<32>::default();
+    let mut encoder = storage.encoder();
+
+    let mut expected_encoded_count = 0;
+
+    for op in ops {
+        let res = match op.value {
+            Value::U8(value) => encoder.encode_cmsg(op.level, op.ty, value),
+            Value::U16(value) => encoder.encode_cmsg(op.level, op.ty, value),
+            Value::U32(value) => encoder.encode_cmsg(op.level, op.ty, value),
+            Value::U64(value) => encoder.encode_cmsg(op.level, op.ty, value),
+            Value::U128(value) => encoder.encode_cmsg(op.level, op.ty, value),
+        };
+
+        match res {
+            Ok(_) => expected_encoded_count += 1,
+            Err(_) => break,
+        }
+    }
+
+    let mut actual_decoded_count = 0;
+    let mut iter = encoder.iter();
+
+    for (op, (cmsghdr, value)) in ops.iter().zip(&mut iter) {
+        assert_eq!(op.level, cmsghdr.cmsg_level);
+        assert_eq!(op.ty, cmsghdr.cmsg_type);
+        op.value.check_value(value);
+        actual_decoded_count += 1;
+    }
+
+    assert_eq!(expected_encoded_count, actual_decoded_count);
+    assert!(iter.next().is_none());
+}
+
+#[cfg(not(kani))]
+type Ops = Vec<Op>;
+#[cfg(kani)]
+type Ops = s2n_quic_core::testing::InlineVec<Op, 8>;
+
+#[test]
+#[cfg_attr(kani, kani::proof, kani::solver(kissat), kani::unwind(9))]
+fn round_trip_test() {
+    check!().with_type::<Ops>().for_each(|ops| round_trip(ops));
+}
diff --git a/tests/perf/overlays/s2n-quic/quic/s2n-quic-platform/src/message/msg/tests.rs b/tests/perf/overlays/s2n-quic/quic/s2n-quic-platform/src/message/msg/tests.rs
new file mode 100644
index 000000000000..28c443d18eb1
--- /dev/null
+++ b/tests/perf/overlays/s2n-quic/quic/s2n-quic-platform/src/message/msg/tests.rs
@@ -0,0 +1,111 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use super::*;
+use bolero::check;
+use core::mem::zeroed;
+use s2n_quic_core::inet::{SocketAddress, Unspecified};
+
+fn test_msghdr<F: FnOnce(&mut msghdr)>(f: F) {
+    const PAYLOAD_LEN: usize = 16;
+
+    let mut msghdr = unsafe { zeroed::<msghdr>() };
+
+    let mut msgname = unsafe { zeroed::<sockaddr_in6>() };
+    msghdr.msg_name = &mut msgname as *mut _ as *mut _;
+    msghdr.msg_namelen = size_of::<sockaddr_in6>() as _;
+
+    let mut iovec = unsafe { zeroed::<iovec>() };
+
+    let mut payload = [0u8; PAYLOAD_LEN];
+    iovec.iov_base = &mut payload as *mut _ as *mut _;
+    iovec.iov_len = 1;
+
+    msghdr.msg_iov = &mut iovec;
+
+    let mut msg_control = <cmsg::Storage<{ cmsg::MAX_LEN }>>::default();
+    msghdr.msg_controllen = msg_control.len() as _;
+    msghdr.msg_control = msg_control.as_mut_ptr() as *mut _;
+
+    unsafe {
+        msghdr.reset(PAYLOAD_LEN);
+    }
+
+    f(&mut msghdr);
+}
+
+#[cfg(kani)]
+#[allow(dead_code)] // Avoid warning when using stubs.
+mod stubs {
+    use s2n_quic_core::inet::AncillaryData;
+
+    pub fn collect(_iter: crate::message::cmsg::decode::Iter) -> AncillaryData {
+        let ancillary_data = kani::any();
+
+        ancillary_data
+    }
+}
+
+#[test]
+#[cfg_attr(kani, kani::proof, kani::solver(cadical), kani::unwind(17))]
+fn address_inverse_pair_test() {
+    check!()
+        .with_type::<SocketAddress>()
+        .cloned()
+        .for_each(|addr| {
+            test_msghdr(|message| {
+                message.set_remote_address(&addr);
+
+                assert_eq!(message.remote_address(), Some(addr));
+            });
+        });
+}
+
+#[test]
+#[cfg_attr(
+    kani,
+    kani::proof,
+    kani::solver(minisat),
+    kani::unwind(65),
+    // it's safe to stub out cmsg::decode since the cmsg result isn't actually checked in this particular test
+    kani::stub(cmsg::decode::collect, stubs::collect)
+)]
+fn handle_get_set_test() {
+    check!()
+        .with_generator((
+            gen::<Handle>(),
+            1..=crate::features::gso::MaxSegments::MAX.into(),
+        ))
+        .cloned()
+        .for_each(|(handle, segment_size)| {
+            test_msghdr(|message| {
+                handle.update_msg_hdr(message);
+
+                if segment_size > 1 {
+                    message.set_segment_size(segment_size);
+                }
+
+                let (header, _cmsg) = message.header().unwrap();
+
+                assert_eq!(header.path.remote_address, handle.remote_address);
+
+                // no need to check this on kani since we abstract the decode() function to avoid performance issues
+                #[cfg(not(kani))]
+                {
+                    if features::pktinfo::IS_SUPPORTED
+                        && !handle.local_address.ip().is_unspecified()
+                    {
+                        assert_eq!(header.path.local_address.ip(), handle.local_address.ip());
+                    }
+                }
+
+                // reset the message and ensure everything is zeroed
+                unsafe {
+                    message.reset(0);
+                }
+
+                let (header, _cmsg) = message.header().unwrap();
+                assert!(header.path.remote_address.is_unspecified());
+            });
+        });
+}