diff --git a/CHANGELOG.md b/CHANGELOG.md index 99975bf89..ebb3de2ee 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,6 @@ +### Fixed +- `Take::chunks_vectored` returns as much as possible. + # 1.0.1 (January 11, 2021) ### Changed diff --git a/src/buf/take.rs b/src/buf/take.rs index 2f4c436ac..3a3afbc53 100644 --- a/src/buf/take.rs +++ b/src/buf/take.rs @@ -1,6 +1,8 @@ use crate::Buf; use core::cmp; +#[cfg(feature = "std")] +use std::io::IoSlice; /// A `Buf` adapter which limits the bytes read from an underlying buffer. /// @@ -144,4 +146,31 @@ impl Buf for Take { self.inner.advance(cnt); self.limit -= cnt; } + + #[cfg(feature = "std")] + fn chunks_vectored<'a>(&'a self, dst: &mut [IoSlice<'a>]) -> usize { + let filled = self.inner.chunks_vectored(dst); + + // There's a change the inner provided more than our limit, truncate. + if self.limit < self.inner.remaining() { + let mut limit = self.limit; + for (idx, s) in dst.iter_mut().enumerate() { + if limit <= s.len() { + // Safety: + // The actual slice inside the IoSlice comes from self.inner, so it really has + // the 'a lifetime. But s comes from dst and the only way how to get to that + // slice again is through its deref, so that shortens the lifetime to that of + // dst artificially. So we have to extend it back again, but it won't be + // extended beyond that of 'a. + let slice = unsafe { std::slice::from_raw_parts(s.as_ptr(), limit) }; + *s = IoSlice::new(slice); + return idx + 1; + } else { + limit -= s.len(); + } + } + } + + filled + } } diff --git a/tests/test_take.rs b/tests/test_take.rs index a23a29edb..2b99e8aec 100644 --- a/tests/test_take.rs +++ b/tests/test_take.rs @@ -1,5 +1,8 @@ #![warn(rust_2018_idioms)] +#[cfg(feature = "std")] +use std::io::IoSlice; + use bytes::buf::Buf; #[test] @@ -10,3 +13,50 @@ fn long_take() { assert_eq!(11, buf.remaining()); assert_eq!(b"hello world", buf.chunk()); } + +// Provide a buf with two slices. +#[cfg(feature = "std")] +fn chained() -> impl Buf { + let a: &[u8] = b"Hello "; + let b: &[u8] = b"World"; + a.chain(b) +} + +#[test] +#[cfg(feature = "std")] +fn take_vectored_doesnt_fit() { + // When there are not enough io slices. + let mut slices = [IoSlice::new(&[]); 1]; + let buf = chained().take(10); + assert_eq!(1, buf.chunks_vectored(&mut slices)); + assert_eq!(b"Hello ", &slices[0] as &[u8]); +} + +#[test] +#[cfg(feature = "std")] +fn take_vectored_long() { + let mut slices = [IoSlice::new(&[]); 2]; + let buf = chained().take(20); + assert_eq!(2, buf.chunks_vectored(&mut slices)); + assert_eq!(b"Hello ", &slices[0] as &[u8]); + assert_eq!(b"World", &slices[1] as &[u8]); +} + +#[test] +#[cfg(feature = "std")] +fn take_vectored_many_slices() { + let mut slices = [IoSlice::new(&[]); 3]; + let buf = chained().take(10); + assert_eq!(2, buf.chunks_vectored(&mut slices)); + assert_eq!(b"Hello ", &slices[0] as &[u8]); + assert_eq!(b"Worl", &slices[1] as &[u8]); +} + +#[test] +#[cfg(feature = "std")] +fn take_vectored_short() { + let mut slices = [IoSlice::new(&[]); 3]; + let buf = chained().take(3); + assert_eq!(1, buf.chunks_vectored(&mut slices)); + assert_eq!(b"Hel", &slices[0] as &[u8]); +}