From 7717aa40fca7ced7c6bb4b04c5c7ee938adade1d Mon Sep 17 00:00:00 2001
From: Sam Thomas
Date: Wed, 15 May 2024 02:09:22 +0100
Subject: [PATCH] add MmapTable to back large immutable data structures
 associated with an analysed binary (i.e., memory map, instruction table)

---
 fugue-high/Cargo.toml        |   3 +
 fugue-high/src/util/mod.rs   |   1 +
 fugue-high/src/util/table.rs | 314 +++++++++++++++++++++++++++++++++++
 3 files changed, 318 insertions(+)
 create mode 100644 fugue-high/src/util/table.rs

diff --git a/fugue-high/Cargo.toml b/fugue-high/Cargo.toml
index ea8fa72..00c6527 100644
--- a/fugue-high/Cargo.toml
+++ b/fugue-high/Cargo.toml
@@ -10,15 +10,18 @@ fugue-arch = { version = "0.3", path = "../fugue-arch" }
 fugue-bv = { version = "0.3", path = "../fugue-bv" }
 fugue-bytes = { version = "0.3", path = "../fugue-bytes" }
 fugue-ir = { version = "0.3", path = "../fugue-ir" }
+heed = { version = "0.20", features = ["read-txn-no-tls", "posix-sem"] }
 nom = "7"
 memmap2 = "0.9"
 object = "0.35"
 ouroboros = "0.18"
 regex = "1"
+rkyv = "0.7"
 rustc-hash = "1.1"
 serde = { version = "1", features = ["derive"] }
 serde_yaml = "0.9"
 static_init = "1"
+tempfile = "3"
 thiserror = "1"
 yaxpeax-arch = { version = "0.2", default-features = false }
 yaxpeax-arm = "0.2"
diff --git a/fugue-high/src/util/mod.rs b/fugue-high/src/util/mod.rs
index 6274733..46514ad 100644
--- a/fugue-high/src/util/mod.rs
+++ b/fugue-high/src/util/mod.rs
@@ -10,6 +10,7 @@ use memmap2::Mmap;
 use object::ReadRef;
 
 pub mod patfind;
+pub mod table;
 
 pub enum OwnedOrRef<'a, T> {
     Owned(T),
diff --git a/fugue-high/src/util/table.rs b/fugue-high/src/util/table.rs
new file mode 100644
index 0000000..24f0ea0
--- /dev/null
+++ b/fugue-high/src/util/table.rs
@@ -0,0 +1,314 @@
+//! Memory-mapped key/value tables for large, immutable analysis artefacts.
+//!
+//! [`MmapTable`] wraps a `heed` environment holding one named database with
+//! string keys. Values are serialised with `rkyv` when written and handed
+//! back as references to their archived form when read.
+
+use std::marker::PhantomData;
+use std::path::Path;
+
+use heed::types::{Bytes, Str};
+use heed::{Database, Env, EnvOpenOptions, RoTxn, RwTxn};
+use rkyv::ser::serializers::AllocSerializer;
+use rkyv::{Archive, Serialize};
+use tempfile::TempDir;
+use thiserror::Error;
+
+/// A named, string-keyed database stored in a `heed` environment.
+pub struct MmapTable {
+    environment: Env,
+    database: Database<Str, Bytes>,
+    /// Backing directory when created via [`MmapTable::temporary`]; declared
+    /// last so that it is dropped after `environment`.
+    temporary: Option<TempDir>,
+}
+
+/// Read-only view of an [`MmapTable`] whose values archive a `T`.
+pub struct MmapTableReader<'a, T>
+where
+    T: Archive,
+{
+    table: &'a MmapTable,
+    txn: RoTxn<'a>,
+    _marker: PhantomData<T>,
+}
+
+/// Read-write view of an [`MmapTable`] whose values archive a `T`.
+///
+/// All changes happen inside one write transaction; finish it with
+/// [`MmapTableWriter::commit`] or [`MmapTableWriter::abort`].
+pub struct MmapTableWriter<'a, T>
+where
+    T: Archive,
+{
+    table: &'a MmapTable,
+    txn: RwTxn<'a>,
+    _marker: PhantomData<T>,
+}
+
+/// Errors produced by [`MmapTable`] and its readers and writers.
+#[derive(Debug, Error)]
+pub enum MmapTableError {
+    /// Failure reported by the database or the value serialiser.
+    #[error(transparent)]
+    Database(anyhow::Error),
+    /// Failure while preparing the backing storage.
+    #[error(transparent)]
+    IO(anyhow::Error),
+}
+
+impl MmapTableError {
+    /// Wraps `e` as a [`MmapTableError::Database`].
+    pub fn database<E>(e: E) -> Self
+    where
+        E: std::error::Error + Send + Sync + 'static,
+    {
+        Self::Database(e.into())
+    }
+
+    /// Builds a [`MmapTableError::Database`] from the message `m`.
+    pub fn database_with<M>(m: M) -> Self
+    where
+        M: std::fmt::Debug + std::fmt::Display + Send + Sync + 'static,
+    {
+        Self::Database(anyhow::Error::msg(m))
+    }
+
+    /// Wraps `e` as a [`MmapTableError::IO`].
+    pub fn io<E>(e: E) -> Self
+    where
+        E: std::error::Error + Send + Sync + 'static,
+    {
+        Self::IO(e.into())
+    }
+}
+
+impl MmapTable {
+    /// Opens the environment at `backing` and creates the database `name`.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`MmapTableError::Database`] if the environment cannot be
+    /// opened or the database cannot be created.
+    pub fn new(name: impl AsRef<str>, backing: impl AsRef<Path>) -> Result<Self, MmapTableError> {
+        // SAFETY: NOTE(review): the environment memory-maps files under
+        // `backing`; presumably nothing else modifies them while it is open
+        // -- confirm against callers.
+        let environment = unsafe {
+            EnvOpenOptions::new()
+                .max_dbs(16)
+                .map_size(4 * 1024 * 1024 * 1024) // 4 GiB
+                .open(backing.as_ref())
+                .map_err(MmapTableError::database)?
+        };
+
+        let database = {
+            let mut txn = environment.write_txn().map_err(MmapTableError::database)?;
+            let database = environment
+                .create_database(&mut txn, Some(name.as_ref()))
+                .map_err(MmapTableError::database)?;
+            txn.commit().map_err(MmapTableError::database)?;
+            database
+        };
+
+        Ok(Self {
+            environment,
+            database,
+            temporary: None,
+        })
+    }
+
+    /// Creates the table `name` in a fresh temporary directory, which is
+    /// stored in the table so that it lives as long as the table does.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`MmapTableError::IO`] if the directory cannot be created, or
+    /// any error from [`MmapTable::new`].
+    pub fn temporary(name: impl AsRef<str>) -> Result<Self, MmapTableError> {
+        let backing = tempfile::tempdir().map_err(MmapTableError::io)?;
+
+        let mut slf = Self::new(name.as_ref(), backing.as_ref())?;
+
+        slf.temporary = Some(backing);
+
+        Ok(slf)
+    }
+
+    /// Starts a read transaction and returns a reader for values of type `T`.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`MmapTableError::Database`] if the transaction cannot start.
+    pub fn reader<'a, T>(&'a self) -> Result<MmapTableReader<'a, T>, MmapTableError>
+    where
+        T: Archive,
+    {
+        let txn = self
+            .environment
+            .read_txn()
+            .map_err(MmapTableError::database)?;
+        Ok(MmapTableReader {
+            table: self,
+            txn,
+            _marker: PhantomData,
+        })
+    }
+
+    /// Starts a write transaction and returns a writer for values of type `T`.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`MmapTableError::Database`] if the transaction cannot start.
+    pub fn writer<'a, T>(&'a mut self) -> Result<MmapTableWriter<'a, T>, MmapTableError>
+    where
+        T: Archive,
+    {
+        let txn = self
+            .environment
+            .write_txn()
+            .map_err(MmapTableError::database)?;
+        Ok(MmapTableWriter {
+            table: self,
+            txn,
+            _marker: PhantomData,
+        })
+    }
+}
+
+impl<'a, T> MmapTableReader<'a, T>
+where
+    T: Archive,
+{
+    /// Looks up `key`, returning the archived value if one is stored.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`MmapTableError::Database`] if the lookup fails.
+    pub fn get(&self, key: impl AsRef<str>) -> Result<Option<&T::Archived>, MmapTableError> {
+        let val = self
+            .table
+            .database
+            .get(&self.txn, key.as_ref())
+            .map_err(MmapTableError::database)?;
+
+        // SAFETY: NOTE(review): the bytes are not validated; this assumes they
+        // hold an archived `T` and are suitably aligned for `T::Archived` --
+        // confirm for every user of the table.
+        Ok(val.map(|val| unsafe { rkyv::archived_root::<T>(val) }))
+    }
+}
+
+impl<'a, T> MmapTableWriter<'a, T>
+where
+    T: Archive + Serialize<AllocSerializer<1024>>,
+{
+    /// Looks up `key` within the write transaction, returning the archived
+    /// value if one is stored.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`MmapTableError::Database`] if the lookup fails.
+    pub fn get(&self, key: impl AsRef<str>) -> Result<Option<&T::Archived>, MmapTableError> {
+        let val = self
+            .table
+            .database
+            .get(&self.txn, key.as_ref())
+            .map_err(MmapTableError::database)?;
+
+        // SAFETY: NOTE(review): the bytes are not validated; this assumes they
+        // hold an archived `T` and are suitably aligned for `T::Archived` --
+        // confirm for every user of the table.
+        Ok(val.map(|val| unsafe { rkyv::archived_root::<T>(val) }))
+    }
+
+    /// Serialises `val` and stores the resulting bytes under `key`.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`MmapTableError::Database`] if serialisation or the write
+    /// fails.
+    pub fn set(&mut self, key: impl AsRef<str>, val: impl AsRef<T>) -> Result<(), MmapTableError> {
+        let val = rkyv::to_bytes::<_, 1024>(val.as_ref()).map_err(MmapTableError::database)?;
+
+        self.table
+            .database
+            .put(&mut self.txn, key.as_ref(), val.as_ref())
+            .map_err(MmapTableError::database)?;
+
+        Ok(())
+    }
+
+    /// Removes every entry from the table.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`MmapTableError::Database`] if the operation fails.
+    pub fn clear(&mut self) -> Result<(), MmapTableError> {
+        self.table
+            .database
+            .clear(&mut self.txn)
+            .map_err(MmapTableError::database)?;
+
+        Ok(())
+    }
+
+    /// Deletes the entry stored under `key`; the result of the delete call
+    /// is discarded, so a missing key is not reported.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`MmapTableError::Database`] if the operation fails.
+    pub fn remove(&mut self, key: impl AsRef<str>) -> Result<(), MmapTableError> {
+        self.table
+            .database
+            .delete(&mut self.txn, key.as_ref())
+            .map_err(MmapTableError::database)?;
+
+        Ok(())
+    }
+
+    /// Aborts the write transaction, consuming the writer.
+    pub fn abort(self) {
+        self.txn.abort()
+    }
+
+    /// Commits the write transaction, consuming the writer.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`MmapTableError::Database`] if the commit fails.
+    pub fn commit(self) -> Result<(), MmapTableError> {
+        self.txn.commit().map_err(MmapTableError::database)
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    #[test]
+    fn test_project() -> Result<(), Box<dyn std::error::Error>> {
+        let mut pt = MmapTable::temporary("project")?;
+
+        {
+            let mut writer = pt.writer::<Vec<u8>>()?;
+
+            writer.set("mapping1", vec![0u8; 10])?;
+            writer.set("mapping2", vec![0u8; 100 * 1024 * 1024])?;
+            writer.set("mapping3", vec![0u8; 256 * 1024 * 1024])?;
+
+            writer.commit()?;
+        }
+
+        {
+            let reader = pt.reader::<Vec<u8>>()?;
+
+            let bytes = reader.get("mapping2")?.unwrap();
+
+            assert_eq!(bytes.len(), 100 * 1024 * 1024);
+        }
+
+        Ok(())
+    }
+}