From 86694b05649c0c1666044b2ba5c386c2328aac18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89mile=20Fugulin?= Date: Sun, 5 May 2024 11:58:00 -0400 Subject: [PATCH] Add zero-copy make_mut (#695) --- src/bytes.rs | 150 +++++++++++++++++++++++++++++++++- src/bytes_mut.rs | 32 +++++++- tests/test_bytes.rs | 111 +++++++++++++++++++++++++ tests/test_bytes_odd_alloc.rs | 50 ++++++++++++ 4 files changed, 341 insertions(+), 2 deletions(-) diff --git a/src/bytes.rs b/src/bytes.rs index 908cee9ad..b4359b08d 100644 --- a/src/bytes.rs +++ b/src/bytes.rs @@ -15,7 +15,7 @@ use crate::buf::IntoIter; #[allow(unused)] use crate::loom::sync::atomic::AtomicMut; use crate::loom::sync::atomic::{AtomicPtr, AtomicUsize, Ordering}; -use crate::Buf; +use crate::{Buf, BytesMut}; /// A cheaply cloneable and sliceable chunk of contiguous memory. /// @@ -113,6 +113,7 @@ pub(crate) struct Vtable { /// /// takes `Bytes` to value pub to_vec: unsafe fn(&AtomicPtr<()>, *const u8, usize) -> Vec, + pub to_mut: unsafe fn(&AtomicPtr<()>, *const u8, usize) -> BytesMut, /// fn(data) pub is_unique: unsafe fn(&AtomicPtr<()>) -> bool, /// fn(data, ptr, len) @@ -507,6 +508,49 @@ impl Bytes { self.truncate(0); } + /// Try to convert self into `BytesMut`. + /// + /// If `self` is unique for the entire original buffer, this will succeed + /// and return a `BytesMut` with the contents of `self` without copying. + /// If `self` is not unique for the entire original buffer, this will fail + /// and return self. + /// + /// # Examples + /// + /// ``` + /// use bytes::{Bytes, BytesMut}; + /// + /// let bytes = Bytes::from(b"hello".to_vec()); + /// assert_eq!(bytes.try_into_mut(), Ok(BytesMut::from(&b"hello"[..]))); + /// ``` + pub fn try_into_mut(self) -> Result { + if self.is_unique() { + Ok(self.make_mut()) + } else { + Err(self) + } + } + + /// Convert self into `BytesMut`. + /// + /// If `self` is unique for the entire original buffer, this will return a + /// `BytesMut` with the contents of `self` without copying. + /// If `self` is not unique for the entire original buffer, this will make + /// a copy of `self` subset of the original buffer in a new `BytesMut`. + /// + /// # Examples + /// + /// ``` + /// use bytes::{Bytes, BytesMut}; + /// + /// let bytes = Bytes::from(b"hello".to_vec()); + /// assert_eq!(bytes.make_mut(), BytesMut::from(&b"hello"[..])); + /// ``` + pub fn make_mut(self) -> BytesMut { + let bytes = ManuallyDrop::new(self); + unsafe { (bytes.vtable.to_mut)(&bytes.data, bytes.ptr, bytes.len) } + } + #[inline] pub(crate) unsafe fn with_vtable( ptr: *const u8, @@ -917,6 +961,7 @@ impl fmt::Debug for Vtable { const STATIC_VTABLE: Vtable = Vtable { clone: static_clone, to_vec: static_to_vec, + to_mut: static_to_mut, is_unique: static_is_unique, drop: static_drop, }; @@ -931,6 +976,11 @@ unsafe fn static_to_vec(_: &AtomicPtr<()>, ptr: *const u8, len: usize) -> Vec, ptr: *const u8, len: usize) -> BytesMut { + let slice = slice::from_raw_parts(ptr, len); + BytesMut::from(slice) +} + fn static_is_unique(_: &AtomicPtr<()>) -> bool { false } @@ -944,6 +994,7 @@ unsafe fn static_drop(_: &mut AtomicPtr<()>, _: *const u8, _: usize) { static PROMOTABLE_EVEN_VTABLE: Vtable = Vtable { clone: promotable_even_clone, to_vec: promotable_even_to_vec, + to_mut: promotable_even_to_mut, is_unique: promotable_is_unique, drop: promotable_even_drop, }; @@ -951,6 +1002,7 @@ static PROMOTABLE_EVEN_VTABLE: Vtable = Vtable { static PROMOTABLE_ODD_VTABLE: Vtable = Vtable { clone: promotable_odd_clone, to_vec: promotable_odd_to_vec, + to_mut: promotable_odd_to_mut, is_unique: promotable_is_unique, drop: promotable_odd_drop, }; @@ -994,12 +1046,47 @@ unsafe fn promotable_to_vec( } } +unsafe fn promotable_to_mut( + data: &AtomicPtr<()>, + ptr: *const u8, + len: usize, + f: fn(*mut ()) -> *mut u8, +) -> BytesMut { + let shared = data.load(Ordering::Acquire); + let kind = shared as usize & KIND_MASK; + + if kind == KIND_ARC { + shared_to_mut_impl(shared.cast(), ptr, len) + } else { + // KIND_VEC is a view of an underlying buffer at a certain offset. + // The ptr + len always represents the end of that buffer. + // Before truncating it, it is first promoted to KIND_ARC. + // Thus, we can safely reconstruct a Vec from it without leaking memory. + debug_assert_eq!(kind, KIND_VEC); + + let buf = f(shared); + let off = offset_from(ptr, buf); + let cap = off + len; + let v = Vec::from_raw_parts(buf, cap, cap); + + let mut b = BytesMut::from_vec(v); + b.advance_unchecked(off); + b + } +} + unsafe fn promotable_even_to_vec(data: &AtomicPtr<()>, ptr: *const u8, len: usize) -> Vec { promotable_to_vec(data, ptr, len, |shared| { ptr_map(shared.cast(), |addr| addr & !KIND_MASK) }) } +unsafe fn promotable_even_to_mut(data: &AtomicPtr<()>, ptr: *const u8, len: usize) -> BytesMut { + promotable_to_mut(data, ptr, len, |shared| { + ptr_map(shared.cast(), |addr| addr & !KIND_MASK) + }) +} + unsafe fn promotable_even_drop(data: &mut AtomicPtr<()>, ptr: *const u8, len: usize) { data.with_mut(|shared| { let shared = *shared; @@ -1031,6 +1118,10 @@ unsafe fn promotable_odd_to_vec(data: &AtomicPtr<()>, ptr: *const u8, len: usize promotable_to_vec(data, ptr, len, |shared| shared.cast()) } +unsafe fn promotable_odd_to_mut(data: &AtomicPtr<()>, ptr: *const u8, len: usize) -> BytesMut { + promotable_to_mut(data, ptr, len, |shared| shared.cast()) +} + unsafe fn promotable_odd_drop(data: &mut AtomicPtr<()>, ptr: *const u8, len: usize) { data.with_mut(|shared| { let shared = *shared; @@ -1087,6 +1178,7 @@ const _: [(); 0 - mem::align_of::() % 2] = []; // Assert that the alignm static SHARED_VTABLE: Vtable = Vtable { clone: shared_clone, to_vec: shared_to_vec, + to_mut: shared_to_mut, is_unique: shared_is_unique, drop: shared_drop, }; @@ -1133,6 +1225,45 @@ unsafe fn shared_to_vec(data: &AtomicPtr<()>, ptr: *const u8, len: usize) -> Vec shared_to_vec_impl(data.load(Ordering::Relaxed).cast(), ptr, len) } +unsafe fn shared_to_mut_impl(shared: *mut Shared, ptr: *const u8, len: usize) -> BytesMut { + // The goal is to check if the current handle is the only handle + // that currently has access to the buffer. This is done by + // checking if the `ref_cnt` is currently 1. + // + // The `Acquire` ordering synchronizes with the `Release` as + // part of the `fetch_sub` in `release_shared`. The `fetch_sub` + // operation guarantees that any mutations done in other threads + // are ordered before the `ref_cnt` is decremented. As such, + // this `Acquire` will guarantee that those mutations are + // visible to the current thread. + // + // Otherwise, we take the other branch, copy the data and call `release_shared`. + if (*shared).ref_cnt.load(Ordering::Acquire) == 1 { + // Deallocate the `Shared` instance without running its destructor. + let shared = *Box::from_raw(shared); + let shared = ManuallyDrop::new(shared); + let buf = shared.buf; + let cap = shared.cap; + + // Rebuild Vec + let off = offset_from(ptr, buf); + let v = Vec::from_raw_parts(buf, len + off, cap); + + let mut b = BytesMut::from_vec(v); + b.advance_unchecked(off); + b + } else { + // Copy the data from Shared in a new Vec, then release it + let v = slice::from_raw_parts(ptr, len).to_vec(); + release_shared(shared); + BytesMut::from_vec(v) + } +} + +unsafe fn shared_to_mut(data: &AtomicPtr<()>, ptr: *const u8, len: usize) -> BytesMut { + shared_to_mut_impl(data.load(Ordering::Relaxed).cast(), ptr, len) +} + pub(crate) unsafe fn shared_is_unique(data: &AtomicPtr<()>) -> bool { let shared = data.load(Ordering::Acquire); let ref_cnt = (*shared.cast::()).ref_cnt.load(Ordering::Relaxed); @@ -1291,6 +1422,23 @@ where new_addr as *mut u8 } +/// Precondition: dst >= original +/// +/// The following line is equivalent to: +/// +/// ```rust,ignore +/// self.ptr.as_ptr().offset_from(ptr) as usize; +/// ``` +/// +/// But due to min rust is 1.39 and it is only stabilized +/// in 1.47, we cannot use it. +#[inline] +fn offset_from(dst: *const u8, original: *const u8) -> usize { + debug_assert!(dst >= original); + + dst as usize - original as usize +} + // compile-fails /// ```compile_fail diff --git a/src/bytes_mut.rs b/src/bytes_mut.rs index b01bb1adc..569f8be63 100644 --- a/src/bytes_mut.rs +++ b/src/bytes_mut.rs @@ -868,7 +868,7 @@ impl BytesMut { /// # SAFETY /// /// The caller must ensure that `count` <= `self.cap`. - unsafe fn advance_unchecked(&mut self, count: usize) { + pub(crate) unsafe fn advance_unchecked(&mut self, count: usize) { // Setting the start to 0 is a no-op, so return early if this is the // case. if count == 0 { @@ -1713,6 +1713,7 @@ unsafe fn rebuild_vec(ptr: *mut u8, mut len: usize, mut cap: usize, off: usize) static SHARED_VTABLE: Vtable = Vtable { clone: shared_v_clone, to_vec: shared_v_to_vec, + to_mut: shared_v_to_mut, is_unique: crate::bytes::shared_is_unique, drop: shared_v_drop, }; @@ -1747,6 +1748,35 @@ unsafe fn shared_v_to_vec(data: &AtomicPtr<()>, ptr: *const u8, len: usize) -> V } } +unsafe fn shared_v_to_mut(data: &AtomicPtr<()>, ptr: *const u8, len: usize) -> BytesMut { + let shared: *mut Shared = data.load(Ordering::Relaxed).cast(); + + if (*shared).is_unique() { + let shared = &mut *shared; + + // The capacity is always the original capacity of the buffer + // minus the offset from the start of the buffer + let v = &mut shared.vec; + let v_capacity = v.capacity(); + let v_ptr = v.as_mut_ptr(); + let offset = offset_from(ptr as *mut u8, v_ptr); + let cap = v_capacity - offset; + + let ptr = vptr(ptr as *mut u8); + + BytesMut { + ptr, + len, + cap, + data: shared, + } + } else { + let v = slice::from_raw_parts(ptr, len).to_vec(); + release_shared(shared); + BytesMut::from_vec(v) + } +} + unsafe fn shared_v_drop(data: &mut AtomicPtr<()>, _ptr: *const u8, _len: usize) { data.with_mut(|shared| { release_shared(*shared as *mut Shared); diff --git a/tests/test_bytes.rs b/tests/test_bytes.rs index 84c3d5a43..2f283af2f 100644 --- a/tests/test_bytes.rs +++ b/tests/test_bytes.rs @@ -1172,3 +1172,114 @@ fn shared_is_unique() { drop(b); assert!(c.is_unique()); } + +#[test] +fn test_bytes_make_mut_static() { + let bs = b"1b23exfcz3r"; + + // Test STATIC_VTABLE.to_mut + let bytes_mut = Bytes::from_static(bs).make_mut(); + assert_eq!(bytes_mut, bs[..]); +} + +#[test] +fn test_bytes_make_mut_bytes_mut_vec() { + let bs = b"1b23exfcz3r"; + let bs_long = b"1b23exfcz3r1b23exfcz3r"; + + // Test case where kind == KIND_VEC + let mut bytes_mut: BytesMut = bs[..].into(); + bytes_mut = bytes_mut.freeze().make_mut(); + assert_eq!(bytes_mut, bs[..]); + bytes_mut.extend_from_slice(&bs[..]); + assert_eq!(bytes_mut, bs_long[..]); +} + +#[test] +fn test_bytes_make_mut_bytes_mut_shared() { + let bs = b"1b23exfcz3r"; + + // Set kind to KIND_ARC so that after freeze, Bytes will use bytes_mut.SHARED_VTABLE + let mut bytes_mut: BytesMut = bs[..].into(); + drop(bytes_mut.split_off(bs.len())); + + let b1 = bytes_mut.freeze(); + let b2 = b1.clone(); + + // shared.is_unique() = False + let mut b1m = b1.make_mut(); + assert_eq!(b1m, bs[..]); + b1m[0] = b'9'; + + // shared.is_unique() = True + let b2m = b2.make_mut(); + assert_eq!(b2m, bs[..]); +} + +#[test] +fn test_bytes_make_mut_bytes_mut_offset() { + let bs = b"1b23exfcz3r"; + + // Test bytes_mut.SHARED_VTABLE.to_mut impl where offset != 0 + let mut bytes_mut1: BytesMut = bs[..].into(); + let bytes_mut2 = bytes_mut1.split_off(9); + + let b1 = bytes_mut1.freeze(); + let b2 = bytes_mut2.freeze(); + + let b1m = b1.make_mut(); + let b2m = b2.make_mut(); + + assert_eq!(b2m, bs[9..]); + assert_eq!(b1m, bs[..9]); +} + +#[test] +fn test_bytes_make_mut_promotable_even_vec() { + let vec = vec![33u8; 1024]; + + // Test case where kind == KIND_VEC + let b1 = Bytes::from(vec.clone()); + let b1m = b1.make_mut(); + assert_eq!(b1m, vec); +} + +#[test] +fn test_bytes_make_mut_promotable_even_arc_1() { + let vec = vec![33u8; 1024]; + + // Test case where kind == KIND_ARC, ref_cnt == 1 + let b1 = Bytes::from(vec.clone()); + drop(b1.clone()); + let b1m = b1.make_mut(); + assert_eq!(b1m, vec); +} + +#[test] +fn test_bytes_make_mut_promotable_even_arc_2() { + let vec = vec![33u8; 1024]; + + // Test case where kind == KIND_ARC, ref_cnt == 2 + let b1 = Bytes::from(vec.clone()); + let b2 = b1.clone(); + let b1m = b1.make_mut(); + assert_eq!(b1m, vec); + + // Test case where vtable = SHARED_VTABLE, kind == KIND_ARC, ref_cnt == 1 + let b2m = b2.make_mut(); + assert_eq!(b2m, vec); +} + +#[test] +fn test_bytes_make_mut_promotable_even_arc_offset() { + let vec = vec![33u8; 1024]; + + // Test case where offset != 0 + let mut b1 = Bytes::from(vec.clone()); + let b2 = b1.split_off(20); + let b1m = b1.make_mut(); + let b2m = b2.make_mut(); + + assert_eq!(b2m, vec[20..]); + assert_eq!(b1m, vec[..20]); +} diff --git a/tests/test_bytes_odd_alloc.rs b/tests/test_bytes_odd_alloc.rs index 27ed87736..8008a0e47 100644 --- a/tests/test_bytes_odd_alloc.rs +++ b/tests/test_bytes_odd_alloc.rs @@ -95,3 +95,53 @@ fn test_bytes_into_vec() { assert_eq!(Vec::from(b2), vec[20..]); assert_eq!(Vec::from(b1), vec[..20]); } + +#[test] +fn test_bytes_make_mut_vec() { + let vec = vec![33u8; 1024]; + + // Test case where kind == KIND_VEC + let b1 = Bytes::from(vec.clone()); + let b1m = b1.make_mut(); + assert_eq!(b1m, vec); +} + +#[test] +fn test_bytes_make_mut_arc_1() { + let vec = vec![33u8; 1024]; + + // Test case where kind == KIND_ARC, ref_cnt == 1 + let b1 = Bytes::from(vec.clone()); + drop(b1.clone()); + let b1m = b1.make_mut(); + assert_eq!(b1m, vec); +} + +#[test] +fn test_bytes_make_mut_arc_2() { + let vec = vec![33u8; 1024]; + + // Test case where kind == KIND_ARC, ref_cnt == 2 + let b1 = Bytes::from(vec.clone()); + let b2 = b1.clone(); + let b1m = b1.make_mut(); + assert_eq!(b1m, vec); + + // Test case where vtable = SHARED_VTABLE, kind == KIND_ARC, ref_cnt == 1 + let b2m = b2.make_mut(); + assert_eq!(b2m, vec); +} + +#[test] +fn test_bytes_make_mut_arc_offset() { + let vec = vec![33u8; 1024]; + + // Test case where offset != 0 + let mut b1 = Bytes::from(vec.clone()); + let b2 = b1.split_off(20); + let b1m = b1.make_mut(); + let b2m = b2.make_mut(); + + assert_eq!(b2m, vec[20..]); + assert_eq!(b1m, vec[..20]); +}