Skip to content

Commit

Permalink
Remove some unsafe code (#400)
Browse files Browse the repository at this point in the history
  • Loading branch information
rkuris authored Dec 1, 2023
1 parent a773750 commit fedc673
Showing 1 changed file with 18 additions and 26 deletions.
44 changes: 18 additions & 26 deletions growth-ring/src/wal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// See the file LICENSE.md for licensing terms.

use async_trait::async_trait;
use bytemuck::{cast_slice, AnyBitPattern};
use bytemuck::{cast_slice, AnyBitPattern, NoUninit};
use futures::{
future::{self, FutureExt, TryFutureExt},
stream::StreamExt,
Expand Down Expand Up @@ -34,7 +34,8 @@ enum WalRingType {
Last,
}

#[repr(packed)]
#[repr(C, packed)]
#[derive(NoUninit, Copy, Clone, Debug, AnyBitPattern)]
struct WalRingBlob {
counter: u32,
crc32: u32,
Expand Down Expand Up @@ -111,8 +112,8 @@ const fn counter_lt(a: u32, b: u32) -> bool {
}
}

#[repr(C)]
#[derive(Debug, Clone, Copy, AnyBitPattern)]
#[repr(transparent)]
#[derive(Debug, Clone, Copy, AnyBitPattern, NoUninit)]
struct Header {
/// all preceding files (<fid) could be removed if not yet
recover_fid: u64,
Expand Down Expand Up @@ -231,7 +232,7 @@ pub trait WalStore<F: WalFile> {
struct WalFileHandle<'a, F: WalFile + 'static, S: WalStore<F>> {
fid: WalFileId,
handle: &'a dyn WalFile,
pool: *const WalFilePool<F, S>,
pool: &'a WalFilePool<F, S>,
wal_file: PhantomData<F>,
}

Expand All @@ -244,9 +245,7 @@ impl<'a, F: WalFile, S: WalStore<F>> std::ops::Deref for WalFileHandle<'a, F, S>

impl<'a, F: WalFile + 'static, S: WalStore<F>> Drop for WalFileHandle<'a, F, S> {
fn drop(&mut self) {
unsafe {
(*self.pool).release_file(self.fid);
}
(self.pool).release_file(self.fid);
}
}

Expand Down Expand Up @@ -302,17 +301,16 @@ impl<F: WalFile + 'static, S: WalStore<F>> WalFilePool<F, S> {
.ok_or(WalError::Other("short read".to_string()))
}

async fn write_header(&self, header: &Header) -> Result<(), WalError> {
let base = header as *const Header as usize as *const u8;
let bytes = unsafe { std::slice::from_raw_parts(base, HEADER_SIZE) };
self.header_file.write(0, bytes.into()).await?;
async fn write_header(&self, header: Header) -> Result<(), WalError> {
self.header_file
.write(0, cast_slice(&[header]).into())
.await?;
Ok(())
}

#[allow(clippy::await_holding_refcell_ref)]
// TODO: Refactor to remove mutable reference from being awaited.
async fn get_file(&self, fid: u64, touch: bool) -> Result<WalFileHandle<F, S>, WalError> {
let pool = self as *const WalFilePool<F, S>;
if let Some(h) = self.handle_cache.borrow_mut().pop(&fid) {
let handle = match self.handle_used.borrow_mut().entry(fid) {
hash_map::Entry::Vacant(e) => unsafe {
Expand All @@ -323,7 +321,7 @@ impl<F: WalFile + 'static, S: WalStore<F>> WalFilePool<F, S> {
Ok(WalFileHandle {
fid,
handle,
pool,
pool: self,
wal_file: PhantomData,
})
} else {
Expand All @@ -341,7 +339,7 @@ impl<F: WalFile + 'static, S: WalStore<F>> WalFilePool<F, S> {
Ok(WalFileHandle {
fid,
handle: &*v.0,
pool,
pool: self,
wal_file: PhantomData,
})
}
Expand Down Expand Up @@ -510,8 +508,6 @@ pub struct WalWriter<F: WalFile, S: WalStore<F>> {
msize: usize,
}

unsafe impl<F: WalFile, S> Send for WalWriter<F, S> where S: WalStore<F> + Send {}

impl<F: WalFile + 'static, S: WalStore<F>> WalWriter<F, S> {
fn new(state: WalState, file_pool: WalFilePool<F, S>) -> Self {
let mut b = Vec::new();
Expand Down Expand Up @@ -1001,13 +997,9 @@ impl WalLoader {
None => _yield!(),
};
v.off += msize as u64;
let header = unsafe { &*header_raw.as_ptr().cast::<WalRingBlob>() };
let header = WalRingBlob {
counter: header.counter,
crc32: header.crc32,
rsize: header.rsize,
rtype: header.rtype,
};
let header: &[WalRingBlob] = cast_slice(&header_raw);
let header = *header.get(0)?;

let payload;
match header.rtype.try_into() {
Ok(WalRingType::Full)
Expand Down Expand Up @@ -1115,7 +1107,7 @@ impl WalLoader {
};
let ringid_start = (fid << file_nbit) + v.off;
v.off += msize as u64;
let header = unsafe { &*header_raw.as_ptr().cast::<WalRingBlob>() };
let header: WalRingBlob = *cast_slice(&header_raw).get(0)?;
let rsize = header.rsize;
match header.rtype.try_into() {
Ok(WalRingType::Full) => {
Expand Down Expand Up @@ -1292,7 +1284,7 @@ impl WalLoader {
None => 0,
};

file_pool.write_header(&Header { recover_fid }).await?;
file_pool.write_header(Header { recover_fid }).await?;

let mut skip_remove = false;
for (fname, f) in scanned.into_iter() {
Expand Down

0 comments on commit fedc673

Please sign in to comment.