Skip to content

Commit

Permalink
switch to new database design
Browse files Browse the repository at this point in the history
Under this design, *all* databases are a
`DatabaseImpl<U>`, where the `U` implements
`UserData` (you can use `()` if there is none).

Code would default to `&dyn salsa::Database` but
if you want to give access to the userdata, you
can define a custom database trait
`MyDatabase: salsa::Databse` so long as you

* annotate `MyDatabase` trait definition of
  impls of `MyDatabase` with `#[salsa::db]`
* implement `MyDatabase` for `DatabaseImpl<U>`
  where `U` is your userdata (this could be a
  blanket impl, if you don't know the precise
  userdata type).

The `tests/common/mod.rs` shows the pattern.
  • Loading branch information
nikomatsakis committed Jul 27, 2024
1 parent 9f05061 commit 3d01f41
Show file tree
Hide file tree
Showing 75 changed files with 484 additions and 1,088 deletions.
52 changes: 1 addition & 51 deletions components/salsa-macros/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,6 @@ struct DbMacro {
impl DbMacro {
fn try_db(self, input: syn::Item) -> syn::Result<TokenStream> {
match input {
syn::Item::Struct(input) => {
let has_storage_impl = self.zalsa_database_impl(&input)?;
Ok(quote! {
#has_storage_impl
#input
})
}
syn::Item::Trait(mut input) => {
self.add_salsa_view_method(&mut input)?;
Ok(quote! {
Expand All @@ -53,54 +46,11 @@ impl DbMacro {
}
_ => Err(syn::Error::new_spanned(
input,
"`db` must be applied to a struct, trait, or impl",
"`db` must be applied to a trait or impl",
)),
}
}

fn find_storage_field(&self, input: &syn::ItemStruct) -> syn::Result<syn::Ident> {
let storage = "storage";
for field in input.fields.iter() {
if let Some(i) = &field.ident {
if i == storage {
return Ok(i.clone());
}
} else {
return Err(syn::Error::new_spanned(
field,
"database struct must be a braced struct (`{}`) with a field named `storage`",
));
}
}

Err(syn::Error::new_spanned(
&input.ident,
"database struct must be a braced struct (`{}`) with a field named `storage`",
))
}

fn zalsa_database_impl(&self, input: &syn::ItemStruct) -> syn::Result<TokenStream> {
let storage = self.find_storage_field(input)?;
let db = &input.ident;
let zalsa = self.hygiene.ident("zalsa");

Ok(quote! {
const _: () = {
use salsa::plumbing as #zalsa;

unsafe impl #zalsa::ZalsaDatabase for #db {
fn zalsa(&self) -> &dyn #zalsa::Zalsa {
&self.#storage
}

fn zalsa_mut(&mut self) -> &mut dyn #zalsa::Zalsa {
&mut self.#storage
}
}
};
})
}

fn add_salsa_view_method(&self, input: &mut syn::ItemTrait) -> syn::Result<()> {
input.items.push(parse_quote! {
#[doc(hidden)]
Expand Down
43 changes: 21 additions & 22 deletions examples/calc/db.rs
Original file line number Diff line number Diff line change
@@ -1,49 +1,48 @@
use std::sync::{Arc, Mutex};

use salsa::UserData;

pub type CalcDatabaseImpl = salsa::DatabaseImpl<Calc>;

// ANCHOR: db_struct
#[derive(Default)]
#[salsa::db]
pub(crate) struct Database {
storage: salsa::Storage<Self>,

pub struct Calc {
// The logs are only used for testing and demonstrating reuse:
//
logs: Option<Arc<Mutex<Vec<String>>>>,
logs: Arc<Mutex<Option<Vec<String>>>>,
}
// ANCHOR_END: db_struct

impl Database {
impl Calc {
/// Enable logging of each salsa event.
#[cfg(test)]
pub fn enable_logging(self) -> Self {
assert!(self.logs.is_none());
Self {
storage: self.storage,
logs: Some(Default::default()),
pub fn enable_logging(&self) {
let mut logs = self.logs.lock().unwrap();
if logs.is_none() {
*logs = Some(vec![]);
}
}

#[cfg(test)]
pub fn take_logs(&mut self) -> Vec<String> {
if let Some(logs) = &self.logs {
std::mem::take(&mut *logs.lock().unwrap())
pub fn take_logs(&self) -> Vec<String> {
let mut logs = self.logs.lock().unwrap();
if let Some(logs) = &mut *logs {
std::mem::take(logs)
} else {
panic!("logs not enabled");
vec![]
}
}
}

// ANCHOR: db_impl
#[salsa::db]
impl salsa::Database for Database {
fn salsa_event(&self, event: &dyn Fn() -> salsa::Event) {
impl UserData for Calc {
fn salsa_event(db: &CalcDatabaseImpl, event: &dyn Fn() -> salsa::Event) {
let event = event();
eprintln!("Event: {event:?}");
// Log interesting events, if logging is enabled
if let Some(logs) = &self.logs {
// don't log boring events
if let Some(logs) = &mut *db.logs.lock().unwrap() {
// only log interesting events
if let salsa::EventKind::WillExecute { .. } = event.kind {
logs.lock().unwrap().push(format!("Event: {event:?}"));
logs.push(format!("Event: {event:?}"));
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion examples/calc/main.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use db::CalcDatabaseImpl;
use ir::{Diagnostic, SourceProgram};
use salsa::Database as Db;

Expand All @@ -8,7 +9,7 @@ mod parser;
mod type_check;

pub fn main() {
let db = db::Database::default();
let db: CalcDatabaseImpl = Default::default();
let source_program = SourceProgram::new(&db, String::new());
compile::compile(&db, source_program);
let diagnostics = compile::compile::accumulated::<Diagnostic>(&db, source_program);
Expand Down
6 changes: 4 additions & 2 deletions examples/calc/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,9 +351,11 @@ impl<'db> Parser<'_, 'db> {
/// Returns the statements and the diagnostics generated.
#[cfg(test)]
fn parse_string(source_text: &str) -> String {
use salsa::Database as _;
use salsa::Database;

crate::db::Database::default().attach(|db| {
use crate::db::CalcDatabaseImpl;

CalcDatabaseImpl::default().attach(|db| {
// Create the source program
let source_program = SourceProgram::new(db, source_text.to_string());

Expand Down
9 changes: 4 additions & 5 deletions examples/calc/type_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ use derive_new::new;
use expect_test::expect;
use salsa::Accumulator;
#[cfg(test)]
use salsa::Database as _;
#[cfg(test)]
use test_log::test;

// ANCHOR: parse_statements
Expand Down Expand Up @@ -100,12 +98,13 @@ fn check_string(
expected_diagnostics: expect_test::Expect,
edits: &[(&str, expect_test::Expect, expect_test::Expect)],
) {
use salsa::Setter;
use salsa::{Database, Setter};

use crate::{db::Database, ir::SourceProgram, parser::parse_statements};
use crate::{db::CalcDatabaseImpl, ir::SourceProgram, parser::parse_statements};

// Create the database
let mut db = Database::default().enable_logging();
let mut db = CalcDatabaseImpl::default();
db.enable_logging();

// Create the source program
let source_program = SourceProgram::new(&db, source_text.to_string());
Expand Down
35 changes: 15 additions & 20 deletions examples/lazy-input/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ use notify_debouncer_mini::{
notify::{RecommendedWatcher, RecursiveMode},
DebounceEventResult, Debouncer,
};
use salsa::{Accumulator, Setter};
use salsa::{Accumulator, DatabaseImpl, Setter, UserData};

// ANCHOR: main
fn main() -> Result<()> {
// Create the channel to receive file change events.
let (tx, rx) = unbounded();
let mut db = Database::new(tx);
let mut db = DatabaseImpl::with(LazyInput::new(tx));

let initial_file_path = std::env::args_os()
.nth(1)
Expand Down Expand Up @@ -74,28 +74,34 @@ trait Db: salsa::Database {
fn input(&self, path: PathBuf) -> Result<File>;
}

#[salsa::db]
struct Database {
storage: salsa::Storage<Self>,
struct LazyInput {
logs: Mutex<Vec<String>>,
files: DashMap<PathBuf, File>,
file_watcher: Mutex<Debouncer<RecommendedWatcher>>,
}

impl Database {
impl LazyInput {
fn new(tx: Sender<DebounceEventResult>) -> Self {
let storage = Default::default();
Self {
storage,
logs: Default::default(),
files: DashMap::new(),
file_watcher: Mutex::new(new_debouncer(Duration::from_secs(1), None, tx).unwrap()),
}
}
}

impl UserData for LazyInput {
fn salsa_event(db: &DatabaseImpl<Self>, event: &dyn Fn() -> salsa::Event) {
// don't log boring events
let event = event();
if let salsa::EventKind::WillExecute { .. } = event.kind {
db.logs.lock().unwrap().push(format!("{:?}", event));
}
}
}

#[salsa::db]
impl Db for Database {
impl Db for DatabaseImpl<LazyInput> {
fn input(&self, path: PathBuf) -> Result<File> {
let path = path
.canonicalize()
Expand All @@ -122,17 +128,6 @@ impl Db for Database {
}
// ANCHOR_END: db

#[salsa::db]
impl salsa::Database for Database {
fn salsa_event(&self, event: &dyn Fn() -> salsa::Event) {
// don't log boring events
let event = event();
if let salsa::EventKind::WillExecute { .. } = event.kind {
self.logs.lock().unwrap().push(format!("{:?}", event));
}
}
}

#[salsa::accumulator]
struct Diagnostic(String);

Expand Down
Loading

0 comments on commit 3d01f41

Please sign in to comment.