From f6314118da7168257ce65492d65577fcf1dc3ee8 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Sun, 1 Sep 2024 18:41:29 +0200 Subject: [PATCH 01/10] Use `Result` over `catch_unwind` for cancellation and cycle handling --- Cargo.toml | 1 + benches/compare.rs | 31 +- benches/incremental.rs | 12 +- .../src/setup_input_struct.rs | 8 +- .../src/setup_interned_struct.rs | 2 +- .../salsa-macro-rules/src/setup_tracked_fn.rs | 24 +- .../src/setup_tracked_struct.rs | 10 +- components/salsa-macros/src/tracked_impl.rs | 3 +- examples/calc/compile.rs | 8 +- examples/calc/ir.rs | 10 +- examples/calc/main.rs | 4 +- examples/calc/parser.rs | 212 +++-- examples/calc/type_check.rs | 104 ++- examples/lazy-input/main.rs | 82 +- src/accumulator.rs | 2 +- src/cancelled.rs | 57 -- src/cycle.rs | 17 +- src/function.rs | 15 +- src/function/accumulated.rs | 8 +- src/function/execute.rs | 56 +- src/function/fetch.rs | 47 +- src/function/maybe_changed_after.rs | 46 +- src/ingredient.rs | 2 +- src/input.rs | 10 +- src/input/input_field.rs | 4 +- src/interned.rs | 18 +- src/key.rs | 2 +- src/lib.rs | 5 +- src/result.rs | 107 +++ src/runtime.rs | 25 +- src/table/sync.rs | 10 +- src/tracked_struct.rs | 18 +- src/tracked_struct/tracked_field.rs | 4 +- src/zalsa.rs | 51 +- src/zalsa_local.rs | 26 +- tests/accumulate-chain.rs | 30 +- tests/accumulate-custom-clone.rs | 11 +- tests/accumulate-custom-debug.rs | 12 +- tests/accumulate-dag.rs | 28 +- tests/accumulate-execution-order.rs | 28 +- tests/accumulate-from-tracked-fn.rs | 28 +- tests/accumulate-no-duplicates.rs | 42 +- tests/accumulate-reuse-workaround.rs | 26 +- tests/accumulate-reuse.rs | 20 +- tests/accumulate.rs | 54 +- ...of-tracked-structs-from-older-revisions.rs | 8 +- ...racked-structs-from-older-revisions.stderr | 4 +- tests/compile-fail/span-tracked-getter.rs | 9 +- tests/compile-fail/span-tracked-getter.stderr | 8 +- ...es-not-work-if-the-key-is-a-salsa-input.rs | 4 +- ...not-work-if-the-key-is-a-salsa-interned.rs | 2 +- .../compile-fail/tracked_fn_incompatibles.rs | 19 +- .../tracked_fn_incompatibles.stderr | 9 +- .../tracked_method_on_untracked_impl.rs | 2 +- .../tracked_method_on_untracked_impl.stderr | 2 +- tests/cycles.rs | 872 +++++++++--------- tests/debug.rs | 36 +- tests/deletion-cascade.rs | 43 +- tests/deletion-drops.rs | 20 +- tests/deletion.rs | 31 +- tests/elided-lifetime-in-tracked-fn.rs | 18 +- ...truct_changes_but_fn_depends_on_field_y.rs | 24 +- ...input_changes_but_fn_depends_on_field_y.rs | 20 +- tests/hello_world.rs | 26 +- tests/input_default.rs | 23 +- tests/input_field_durability.rs | 15 +- tests/interned-struct-with-lifetime.rs | 16 +- tests/is_send_sync.rs | 13 +- tests/lru.rs | 46 +- tests/mutate_in_place.rs | 6 +- tests/override_new_get_set.rs | 8 +- ...ng-tracked-struct-outside-of-tracked-fn.rs | 2 +- tests/parallel/parallel_cancellation.rs | 28 +- tests/parallel/parallel_cycle_all_recover.rs | 208 ++--- tests/parallel/parallel_cycle_mid_recover.rs | 204 ++-- tests/parallel/parallel_cycle_none_recover.rs | 37 +- tests/parallel/parallel_cycle_one_recover.rs | 178 ++-- tests/preverify-struct-with-leaked-data-2.rs | 28 +- tests/preverify-struct-with-leaked-data.rs | 20 +- ...the-key-is-created-in-the-current-query.rs | 21 +- tests/synthetic_write.rs | 12 +- tests/tracked-struct-id-field-bad-eq.rs | 15 +- tests/tracked-struct-id-field-bad-hash.rs | 14 +- tests/tracked-struct-unchanged-in-new-rev.rs | 11 +- tests/tracked-struct-value-field-bad-eq.rs | 20 +- tests/tracked-struct-value-field-not-eq.rs | 15 +- tests/tracked_fn_constant.rs | 18 +- .../tracked_fn_high_durability_dependency.rs | 12 +- tests/tracked_fn_no_eq.rs | 20 +- tests/tracked_fn_on_input.rs | 10 +- ...racked_fn_on_input_with_high_durability.rs | 16 +- tests/tracked_fn_on_interned.rs | 12 +- tests/tracked_fn_on_tracked.rs | 9 +- tests/tracked_fn_on_tracked_specify.rs | 35 +- tests/tracked_fn_read_own_entity.rs | 30 +- tests/tracked_fn_read_own_specify.rs | 16 +- tests/tracked_fn_return_ref.rs | 11 +- tests/tracked_method.rs | 20 +- tests/tracked_method_inherent_return_ref.rs | 12 +- tests/tracked_method_on_tracked_struct.rs | 26 +- tests/tracked_method_trait_return_ref.rs | 14 +- tests/tracked_struct_durability.rs | 39 +- tests/tracked_with_struct_db.rs | 16 +- tests/warnings/needless_lifetimes.rs | 6 +- 104 files changed, 2037 insertions(+), 1702 deletions(-) delete mode 100644 src/cancelled.rs create mode 100644 src/result.rs diff --git a/Cargo.toml b/Cargo.toml index d3815b181..837c24efe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ description = "A generic framework for on-demand, incrementalized computation (e arc-swap = "1" crossbeam = "0.8" dashmap = "6" +drop_bomb = "0.1.5" hashlink = "0.9" indexmap = "2" append-only-vec = "0.1.5" diff --git a/benches/compare.rs b/benches/compare.rs index ac15ac978..405944de3 100644 --- a/benches/compare.rs +++ b/benches/compare.rs @@ -7,8 +7,8 @@ pub struct Input { } #[salsa::tracked] -pub fn length(db: &dyn salsa::Database, input: Input) -> usize { - input.text(db).len() +pub fn length(db: &dyn salsa::Database, input: Input) -> salsa::Result { + Ok(input.text(db)?.len()) } #[salsa::interned] @@ -17,8 +17,11 @@ pub struct InternedInput<'db> { } #[salsa::tracked] -pub fn interned_length<'db>(db: &'db dyn salsa::Database, input: InternedInput<'db>) -> usize { - input.text(db).len() +pub fn interned_length<'db>( + db: &'db dyn salsa::Database, + input: InternedInput<'db>, +) -> salsa::Result { + Ok(input.text(db).len()) } fn mutating_inputs(c: &mut Criterion) { @@ -38,11 +41,11 @@ fn mutating_inputs(c: &mut Criterion) { group.bench_function(BenchmarkId::new("mutating", n), |b| { b.iter(|| { let input = Input::new(&db, base_string.clone()); - let actual_len = length(&db, input); + let actual_len = length(&db, input).unwrap(); assert_eq!(base_len, actual_len); input.set_text(&mut db).to(string.clone()); - let actual_len = length(&db, input); + let actual_len = length(&db, input).unwrap(); assert_eq!(new_len, actual_len); }) }); @@ -60,30 +63,30 @@ fn inputs(c: &mut Criterion) { group.bench_function(BenchmarkId::new("new", "InternedInput"), |b| { b.iter(|| { - let input: InternedInput = InternedInput::new(&db, "hello, world!".to_owned()); - interned_length(&db, input); + let input: InternedInput = InternedInput::new(&db, "hello, world!".to_owned()).unwrap(); + interned_length(&db, input).unwrap(); }) }); group.bench_function(BenchmarkId::new("amortized", "InternedInput"), |b| { - let input = InternedInput::new(&db, "hello, world!".to_owned()); - let _ = interned_length(&db, input); + let input = InternedInput::new(&db, "hello, world!".to_owned()).unwrap(); + let _ = interned_length(&db, input).unwrap(); - b.iter(|| interned_length(&db, input)); + b.iter(|| interned_length(&db, input).unwrap()); }); group.bench_function(BenchmarkId::new("new", "Input"), |b| { b.iter(|| { let input = Input::new(&db, "hello, world!".to_owned()); - length(&db, input); + length(&db, input).unwrap(); }) }); group.bench_function(BenchmarkId::new("amortized", "Input"), |b| { let input = Input::new(&db, "hello, world!".to_owned()); - let _ = length(&db, input); + let _ = length(&db, input).unwrap(); - b.iter(|| length(&db, input)); + b.iter(|| length(&db, input).unwrap()); }); group.finish(); diff --git a/benches/incremental.rs b/benches/incremental.rs index 5e5aa5f42..76e416c5d 100644 --- a/benches/incremental.rs +++ b/benches/incremental.rs @@ -12,14 +12,14 @@ struct Tracked<'db> { } #[salsa::tracked(return_ref)] -fn index<'db>(db: &'db dyn salsa::Database, input: Input) -> Vec> { - (0..input.field(db)).map(|i| Tracked::new(db, i)).collect() +fn index<'db>(db: &'db dyn salsa::Database, input: Input) -> salsa::Result>> { + (0..input.field(db)?).map(|i| Tracked::new(db, i)).collect() } #[salsa::tracked] -fn root(db: &dyn salsa::Database, input: Input) -> usize { - let index = index(db, input); - index.len() +fn root(db: &dyn salsa::Database, input: Input) -> salsa::Result { + let index = index(db, input)?; + Ok(index.len()) } fn many_tracked_structs(criterion: &mut Criterion) { @@ -41,7 +41,7 @@ fn many_tracked_structs(criterion: &mut Criterion) { // Make a change, but fetch the result for the other input input2.set_field(db).to(2); - let result = root(db, *input); + let result = root(db, *input).unwrap(); assert_eq!(result, 1_000); }, diff --git a/components/salsa-macro-rules/src/setup_input_struct.rs b/components/salsa-macro-rules/src/setup_input_struct.rs index 51cd482bc..40291db3e 100644 --- a/components/salsa-macro-rules/src/setup_input_struct.rs +++ b/components/salsa-macro-rules/src/setup_input_struct.rs @@ -142,7 +142,7 @@ macro_rules! setup_input_struct { } $( - $field_getter_vis fn $field_getter_id<'db, $Db>(self, db: &'db $Db) -> $zalsa::maybe_cloned_ty!($field_option, 'db, $field_ty) + $field_getter_vis fn $field_getter_id<'db, $Db>(self, db: &'db $Db) -> salsa::Result<$zalsa::maybe_cloned_ty!($field_option, 'db, $field_ty)> where // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + $zalsa::Database, @@ -151,12 +151,12 @@ macro_rules! setup_input_struct { db.as_dyn_database(), self, $field_index, - ); - $zalsa::maybe_clone!( + )?; + Ok($zalsa::maybe_clone!( $field_option, $field_ty, &fields.$field_index, - ) + )) } )* diff --git a/components/salsa-macro-rules/src/setup_interned_struct.rs b/components/salsa-macro-rules/src/setup_interned_struct.rs index 2b7f8d8f4..c665b6a12 100644 --- a/components/salsa-macro-rules/src/setup_interned_struct.rs +++ b/components/salsa-macro-rules/src/setup_interned_struct.rs @@ -126,7 +126,7 @@ macro_rules! setup_interned_struct { } impl<$db_lt> $Struct<$db_lt> { - pub fn $new_fn<$Db>(db: &$db_lt $Db, $($field_id: $field_ty),*) -> Self + pub fn $new_fn<$Db>(db: &$db_lt $Db, $($field_id: $field_ty),*) -> salsa::Result where // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + salsa::Database, diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index 3d5862d48..fb48e7aa9 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -80,7 +80,7 @@ macro_rules! setup_tracked_fn { $($input_id: $input_ty,)* ) -> salsa::plumbing::macro_if! { if $return_ref { - &$db_lt $output_ty + salsa::Result<&$db_lt <$output_ty as salsa::plumbing::HasOutput>::Output> } else { $output_ty } @@ -156,7 +156,7 @@ macro_rules! setup_tracked_fn { type Input<$db_lt> = ($($input_ty),*); - type Output<$db_lt> = $output_ty; + type Output<$db_lt> = <$output_ty as $zalsa::HasOutput>::Output; const CYCLE_STRATEGY: $zalsa::CycleRecoveryStrategy = $zalsa::CycleRecoveryStrategy::$cycle_recovery_strategy; @@ -173,7 +173,7 @@ macro_rules! setup_tracked_fn { } } - fn execute<$db_lt>($db: &$db_lt Self::DbView, ($($input_id),*): ($($input_ty),*)) -> Self::Output<$db_lt> { + fn execute<$db_lt>($db: &$db_lt Self::DbView, ($($input_id),*): ($($input_ty),*)) -> salsa::Result> { $inner_fn $inner($db, $($input_id),*) @@ -231,11 +231,11 @@ macro_rules! setup_tracked_fn { pub fn accumulated<$db_lt, A: salsa::Accumulator>( $db: &$db_lt dyn $Db, $($input_id: $input_ty,)* - ) -> Vec { + ) -> salsa::Result> { use salsa::plumbing as $zalsa; let key = $zalsa::macro_if! { if $needs_interner { - $Configuration::intern_ingredient($db).intern_id($db.as_dyn_database(), ($($input_id),*)) + $Configuration::intern_ingredient($db).intern_id($db.as_dyn_database(), ($($input_id),*))? } else { $zalsa::AsId::as_id(&($($input_id),*)) } @@ -248,7 +248,7 @@ macro_rules! setup_tracked_fn { pub fn specify<$db_lt>( $db: &$db_lt dyn $Db, $($input_id: $input_ty,)* - value: $output_ty, + value: <$Configuration as $zalsa::function::Configuration>::Output<$db_lt>, ) { let key = $zalsa::AsId::as_id(&($($input_id),*)); $Configuration::fn_ingredient($db).specify_and_record( @@ -271,21 +271,21 @@ macro_rules! setup_tracked_fn { let result = $zalsa::macro_if! { if $needs_interner { { - let key = $Configuration::intern_ingredient($db).intern_id($db.as_dyn_database(), ($($input_id),*)); - $Configuration::fn_ingredient($db).fetch($db, key) + let key = $Configuration::intern_ingredient($db).intern_id($db.as_dyn_database(), ($($input_id),*))?; + $Configuration::fn_ingredient($db).fetch($db, key)? } } else { - $Configuration::fn_ingredient($db).fetch($db, $zalsa::AsId::as_id(&($($input_id),*))) + $Configuration::fn_ingredient($db).fetch($db, $zalsa::AsId::as_id(&($($input_id),*)))? } }; - $zalsa::macro_if! { + Ok($zalsa::macro_if! { if $return_ref { result } else { - <$output_ty as std::clone::Clone>::clone(result) + <<$Configuration as $zalsa::function::Configuration>::Output<$db_lt> as std::clone::Clone>::clone(result) } - } + }) }) } }; diff --git a/components/salsa-macro-rules/src/setup_tracked_struct.rs b/components/salsa-macro-rules/src/setup_tracked_struct.rs index d0d42c6da..b0bc1b74e 100644 --- a/components/salsa-macro-rules/src/setup_tracked_struct.rs +++ b/components/salsa-macro-rules/src/setup_tracked_struct.rs @@ -184,7 +184,7 @@ macro_rules! setup_tracked_struct { } impl<$db_lt> $Struct<$db_lt> { - pub fn $new_fn<$Db>(db: &$db_lt $Db, $($field_id: $field_ty),*) -> Self + pub fn $new_fn<$Db>(db: &$db_lt $Db, $($field_id: $field_ty),*) -> salsa::Result where // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + $zalsa::Database, @@ -196,18 +196,18 @@ macro_rules! setup_tracked_struct { } $( - $field_getter_vis fn $field_getter_id<$Db>(self, db: &$db_lt $Db) -> $crate::maybe_cloned_ty!($field_option, $db_lt, $field_ty) + $field_getter_vis fn $field_getter_id<$Db>(self, db: &$db_lt $Db) -> salsa::Result<$crate::maybe_cloned_ty!($field_option, $db_lt, $field_ty)> where // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + $zalsa::Database, { let db = db.as_dyn_database(); - let fields = $Configuration::ingredient(db).field(db, self, $field_index); - $crate::maybe_clone!( + let fields = $Configuration::ingredient(db).field(db, self, $field_index)?; + Ok($crate::maybe_clone!( $field_option, $field_ty, &fields.$field_index, - ) + )) } )* diff --git a/components/salsa-macros/src/tracked_impl.rs b/components/salsa-macros/src/tracked_impl.rs index 073f43ff4..3aaf2be60 100644 --- a/components/salsa-macros/src/tracked_impl.rs +++ b/components/salsa-macros/src/tracked_impl.rs @@ -274,7 +274,8 @@ impl Macro { ) -> syn::Result<()> { if let Some(return_ref) = &args.return_ref { if let syn::ReturnType::Type(_, t) = &mut sig.output { - **t = parse_quote!(& #db_lt #t) + **t = + parse_quote!(salsa::Result<&#db_lt <#t as salsa::plumbing::HasOutput>::Output>) } else { return Err(syn::Error::new_spanned( return_ref, diff --git a/examples/calc/compile.rs b/examples/calc/compile.rs index 2c88e6f15..427070039 100644 --- a/examples/calc/compile.rs +++ b/examples/calc/compile.rs @@ -1,7 +1,9 @@ use crate::{ir::SourceProgram, parser::parse_statements, type_check::type_check_program}; #[salsa::tracked] -pub fn compile(db: &dyn crate::Db, source_program: SourceProgram) { - let program = parse_statements(db, source_program); - type_check_program(db, program); +pub fn compile(db: &dyn crate::Db, source_program: SourceProgram) -> salsa::Result<()> { + let program = parse_statements(db, source_program)?; + type_check_program(db, program)?; + + Ok(()) } diff --git a/examples/calc/ir.rs b/examples/calc/ir.rs index 8767cc6fb..ba0ddd21b 100644 --- a/examples/calc/ir.rs +++ b/examples/calc/ir.rs @@ -108,18 +108,18 @@ pub struct Diagnostic { impl Diagnostic { #[cfg(test)] - pub fn render(&self, db: &dyn crate::Db, src: SourceProgram) -> String { + pub fn render(&self, db: &dyn crate::Db, src: SourceProgram) -> salsa::Result { use annotate_snippets::*; - let line_start = src.text(db)[..self.start].lines().count() + 1; - Renderer::plain() + let line_start = src.text(db)?[..self.start].lines().count() + 1; + Ok(Renderer::plain() .render( Level::Error.title(&self.message).snippet( - Snippet::source(src.text(db)) + Snippet::source(src.text(db)?) .line_start(line_start) .origin("input") .annotation(Level::Error.span(self.start..self.end).label("here")), ), ) - .to_string() + .to_string()) } } diff --git a/examples/calc/main.rs b/examples/calc/main.rs index 616dede67..14b395299 100644 --- a/examples/calc/main.rs +++ b/examples/calc/main.rs @@ -11,7 +11,7 @@ mod type_check; pub fn main() { let db: CalcDatabaseImpl = Default::default(); let source_program = SourceProgram::new(&db, String::new()); - compile::compile(&db, source_program); - let diagnostics = compile::compile::accumulated::(&db, source_program); + compile::compile(&db, source_program).unwrap(); + let diagnostics = compile::compile::accumulated::(&db, source_program).unwrap(); eprintln!("{diagnostics:?}"); } diff --git a/examples/calc/parser.rs b/examples/calc/parser.rs index d92e5d3a0..fcd227196 100644 --- a/examples/calc/parser.rs +++ b/examples/calc/parser.rs @@ -8,9 +8,9 @@ use crate::ir::{ // ANCHOR: parse_statements #[salsa::tracked] -pub fn parse_statements(db: &dyn crate::Db, source: SourceProgram) -> Program<'_> { +pub fn parse_statements(db: &dyn crate::Db, source: SourceProgram) -> salsa::Result> { // Get the source text from the database - let source_text = source.text(db); + let source_text = source.text(db)?; // Create the parser let mut parser = Parser { @@ -31,7 +31,7 @@ pub fn parse_statements(db: &dyn crate::Db, source: SourceProgram) -> Program<'_ } // Otherwise, there is more input, so parse a statement. - if let Some(statement) = parser.parse_statement() { + if let Some(statement) = parser.parse_statement()? { result.push(statement); } else { // If we failed, report an error at whatever position the parser @@ -98,7 +98,7 @@ impl<'db> Parser<'_, 'db> { } // Returns a span ranging from `start_position` until the current position (exclusive) - fn span_from(&self, start_position: usize) -> Span<'db> { + fn span_from(&self, start_position: usize) -> salsa::Result> { Span::new(self.db, start_position, self.position) } @@ -120,138 +120,169 @@ impl<'db> Parser<'_, 'db> { } // ANCHOR: parse_statement - fn parse_statement(&mut self) -> Option> { + fn parse_statement(&mut self) -> salsa::Result>> { let start_position = self.skip_whitespace(); - let word = self.word()?; + let Some(word) = self.word() else { + return Ok(None); + }; + if word == "fn" { - let func = self.parse_function()?; - Some(Statement::new( - self.span_from(start_position), + let Some(func) = self.parse_function()? else { + return Ok(None); + }; + Ok(Some(Statement::new( + self.span_from(start_position)?, StatementData::Function(func), - )) + ))) } else if word == "print" { - let expr = self.parse_expression()?; - Some(Statement::new( - self.span_from(start_position), + let Some(expr) = self.parse_expression()? else { + return Ok(None); + }; + Ok(Some(Statement::new( + self.span_from(start_position)?, StatementData::Print(expr), - )) + ))) } else { - None + Ok(None) } } // ANCHOR_END: parse_statement // ANCHOR: parse_function - fn parse_function(&mut self) -> Option> { + fn parse_function(&mut self) -> salsa::Result>> { let start_position = self.skip_whitespace(); - let name = self.word()?; - let name_span = self.span_from(start_position); - let name: FunctionId = FunctionId::new(self.db, name); + let Some(name) = self.word() else { + return Ok(None); + }; + + let name_span = self.span_from(start_position)?; + let name: FunctionId = FunctionId::new(self.db, name)?; // ^^^^^^^^^^^^^^^ // Create a new interned struct. - self.ch('(')?; - let args = self.parameters()?; - self.ch(')')?; - self.ch('=')?; - let body = self.parse_expression()?; - Some(Function::new(self.db, name, name_span, args, body)) + if self.ch('(')?.is_none() { + return Ok(None); + } + let Some(args) = self.parameters()? else { + return Ok(None); + }; + if self.ch(')')?.is_none() { + return Ok(None); + } + if self.ch('=')?.is_none() { + return Ok(None); + } + let Some(body) = self.parse_expression()? else { + return Ok(None); + }; + + Ok(Some(Function::new(self.db, name, name_span, args, body)?)) // ^^^^^^^^^^^^^ // Create a new entity struct. } // ANCHOR_END: parse_function - fn parse_expression(&mut self) -> Option> { + fn parse_expression(&mut self) -> salsa::Result>> { self.parse_op_expression(Self::parse_expression1, Self::low_op) } - fn low_op(&mut self) -> Option { - if self.ch('+').is_some() { - Some(Op::Add) - } else if self.ch('-').is_some() { - Some(Op::Subtract) + fn low_op(&mut self) -> salsa::Result> { + if self.ch('+')?.is_some() { + Ok(Some(Op::Add)) + } else if self.ch('-')?.is_some() { + Ok(Some(Op::Subtract)) } else { - None + Ok(None) } } /// Parses a high-precedence expression (times, div). /// /// On failure, skips arbitrary tokens. - fn parse_expression1(&mut self) -> Option> { + fn parse_expression1(&mut self) -> salsa::Result>> { self.parse_op_expression(Self::parse_expression2, Self::high_op) } - fn high_op(&mut self) -> Option { - if self.ch('*').is_some() { + fn high_op(&mut self) -> salsa::Result> { + Ok(if self.ch('*')?.is_some() { Some(Op::Multiply) - } else if self.ch('/').is_some() { + } else if self.ch('/')?.is_some() { Some(Op::Divide) } else { None - } + }) } fn parse_op_expression( &mut self, - mut parse_expr: impl FnMut(&mut Self) -> Option>, - mut op: impl FnMut(&mut Self) -> Option, - ) -> Option> { + mut parse_expr: impl FnMut(&mut Self) -> salsa::Result>>, + mut op: impl FnMut(&mut Self) -> salsa::Result>, + ) -> salsa::Result>> { let start_position = self.skip_whitespace(); - let mut expr1 = parse_expr(self)?; + let Some(mut expr1) = parse_expr(self)? else { + return Ok(None); + }; - while let Some(op) = op(self) { - let expr2 = parse_expr(self)?; + while let Some(op) = op(self)? { + let Some(expr2) = parse_expr(self)? else { + return Ok(None); + }; expr1 = Expression::new( - self.span_from(start_position), + self.span_from(start_position)?, ExpressionData::Op(Box::new(expr1), op, Box::new(expr2)), ); } - Some(expr1) + Ok(Some(expr1)) } /// Parses a "base expression" (no operators). /// /// On failure, skips arbitrary tokens. - fn parse_expression2(&mut self) -> Option> { + fn parse_expression2(&mut self) -> salsa::Result>> { let start_position = self.skip_whitespace(); if let Some(w) = self.word() { - if self.ch('(').is_some() { - let f = FunctionId::new(self.db, w); - let args = self.parse_expressions()?; + if self.ch('(')?.is_some() { + let f = FunctionId::new(self.db, w)?; + let Some(args) = self.parse_expressions()? else { + return Ok(None); + }; self.ch(')')?; - return Some(Expression::new( - self.span_from(start_position), + return Ok(Some(Expression::new( + self.span_from(start_position)?, ExpressionData::Call(f, args), - )); + ))); } - let v = VariableId::new(self.db, w); - Some(Expression::new( - self.span_from(start_position), + let v = VariableId::new(self.db, w)?; + Ok(Some(Expression::new( + self.span_from(start_position)?, ExpressionData::Variable(v), - )) + ))) } else if let Some(n) = self.number() { - Some(Expression::new( - self.span_from(start_position), + Ok(Some(Expression::new( + self.span_from(start_position)?, ExpressionData::Number(OrderedFloat::from(n)), - )) - } else if self.ch('(').is_some() { - let expr = self.parse_expression()?; + ))) + } else if self.ch('(')?.is_some() { + let Some(expr) = self.parse_expression()? else { + return Ok(None); + }; self.ch(')')?; - Some(expr) + Ok(Some(expr)) } else { - None + Ok(None) } } - fn parse_expressions(&mut self) -> Option>> { + fn parse_expressions(&mut self) -> salsa::Result>>> { let mut r = vec![]; loop { - let expr = self.parse_expression()?; + let Some(expr) = self.parse_expression()? else { + return Ok(None); + }; r.push(expr); - if self.ch(',').is_none() { - return Some(r); + if self.ch(',')?.is_none() { + return Ok(Some(r)); } } } @@ -260,15 +291,18 @@ impl<'db> Parser<'_, 'db> { /// No trailing commas because I am lazy. /// /// On failure, skips arbitrary tokens. - fn parameters(&mut self) -> Option>> { + fn parameters(&mut self) -> salsa::Result>>> { let mut r = vec![]; loop { - let name = self.word()?; - let vid = VariableId::new(self.db, name); + let Some(name) = self.word() else { + return Ok(None); + }; + + let vid = VariableId::new(self.db, name)?; r.push(vid); - if self.ch(',').is_none() { - return Some(r); + if self.ch(',')?.is_none() { + return Ok(Some(r)); } } } @@ -276,14 +310,14 @@ impl<'db> Parser<'_, 'db> { /// Parses a single character. /// /// Even on failure, only skips whitespace. - fn ch(&mut self, c: char) -> Option> { + fn ch(&mut self, c: char) -> salsa::Result>> { let start_position = self.skip_whitespace(); match self.peek() { Some(p) if c == p => { self.consume(c); - Some(self.span_from(start_position)) + Ok(Some(self.span_from(start_position)?)) } - _ => None, + _ => Ok(None), } } @@ -350,7 +384,7 @@ impl<'db> Parser<'_, 'db> { /// Create a new database with the given source text and parse the result. /// Returns the statements and the diagnostics generated. #[cfg(test)] -fn parse_string(source_text: &str) -> String { +fn parse_string(source_text: &str) -> salsa::Result { use salsa::Database; use crate::db::CalcDatabaseImpl; @@ -360,21 +394,21 @@ fn parse_string(source_text: &str) -> String { let source_program = SourceProgram::new(db, source_text.to_string()); // Invoke the parser - let statements = parse_statements(db, source_program); + let statements = parse_statements(db, source_program)?; // Read out any diagnostics - let accumulated = parse_statements::accumulated::(db, source_program); + let accumulated = parse_statements::accumulated::(db, source_program)?; // Format the result as a string and return it - format!("{:#?}", (statements, accumulated)) + Ok(format!("{:#?}", (statements, accumulated))) }) } // ANCHOR_END: parse_string // ANCHOR: parse_print #[test] -fn parse_print() { - let actual = parse_string("print 1 + 2"); +fn parse_print() -> salsa::Result<()> { + let actual = parse_string("print 1 + 2")?; let expected = expect_test::expect![[r#" ( Program { @@ -424,11 +458,12 @@ fn parse_print() { [], )"#]]; expected.assert_eq(&actual); + Ok(()) } // ANCHOR_END: parse_print #[test] -fn parse_example() { +fn parse_example() -> salsa::Result<()> { let actual = parse_string( " fn area_rectangle(w, h) = w * h @@ -437,7 +472,7 @@ fn parse_example() { print area_circle(1) print 11 * 2 ", - ); + )?; let expected = expect_test::expect![[r#" ( Program { @@ -704,13 +739,14 @@ fn parse_example() { [], )"#]]; expected.assert_eq(&actual); + Ok(()) } #[test] -fn parse_error() { +fn parse_error() -> salsa::Result<()> { let source_text: &str = "print 1 + + 2"; // 0123456789^ <-- this is the position 10, where the error is reported - let actual = parse_string(source_text); + let actual = parse_string(source_text)?; let expected = expect_test::expect![[r#" ( Program { @@ -726,13 +762,14 @@ fn parse_error() { ], )"#]]; expected.assert_eq(&actual); + Ok(()) } #[test] -fn parse_precedence() { +fn parse_precedence() -> salsa::Result<()> { // this parses as `(1 + (2 * 3)) + 4` let source_text: &str = "print 1 + 2 * 3 + 4"; - let actual = parse_string(source_text); + let actual = parse_string(source_text)?; let expected = expect_test::expect![[r#" ( Program { @@ -822,4 +859,5 @@ fn parse_precedence() { [], )"#]]; expected.assert_eq(&actual); + Ok(()) } diff --git a/examples/calc/type_check.rs b/examples/calc/type_check.rs index d73a552c2..143f46335 100644 --- a/examples/calc/type_check.rs +++ b/examples/calc/type_check.rs @@ -10,13 +10,15 @@ use test_log::test; // ANCHOR: parse_statements #[salsa::tracked] -pub fn type_check_program<'db>(db: &'db dyn crate::Db, program: Program<'db>) { - for statement in program.statements(db) { +pub fn type_check_program<'db>(db: &'db dyn crate::Db, program: Program<'db>) -> salsa::Result<()> { + for statement in program.statements(db)? { match &statement.data { - StatementData::Function(f) => type_check_function(db, *f, program), - StatementData::Print(e) => CheckExpression::new(db, program, &[]).check(e), + StatementData::Function(f) => type_check_function(db, *f, program)?, + StatementData::Print(e) => CheckExpression::new(db, program, &[]).check(e)?, } } + + Ok(()) } #[salsa::tracked] @@ -24,8 +26,8 @@ pub fn type_check_function<'db>( db: &'db dyn crate::Db, function: Function<'db>, program: Program<'db>, -) { - CheckExpression::new(db, program, function.args(db)).check(function.body(db)) +) -> salsa::Result<()> { + CheckExpression::new(db, program, function.args(db)?).check(function.body(db)?) } #[salsa::tracked] @@ -33,15 +35,16 @@ pub fn find_function<'db>( db: &'db dyn crate::Db, program: Program<'db>, name: FunctionId<'db>, -) -> Option> { - program - .statements(db) - .iter() - .flat_map(|s| match &s.data { - StatementData::Function(f) if f.name(db) == name => Some(*f), - _ => None, - }) - .next() +) -> salsa::Result>> { + for s in program.statements(db)? { + if let StatementData::Function(f) = &s.data { + if f.name(db)? == name { + return Ok(Some(*f)); + } + } + } + + Ok(None) } #[derive(new)] @@ -52,11 +55,11 @@ struct CheckExpression<'input, 'db> { } impl<'db> CheckExpression<'_, 'db> { - fn check(&self, expression: &Expression<'db>) { + fn check(&self, expression: &Expression<'db>) -> salsa::Result<()> { match &expression.data { crate::ir::ExpressionData::Op(left, _, right) => { - self.check(left); - self.check(right); + self.check(left)?; + self.check(right)?; } crate::ir::ExpressionData::Number(_) => {} crate::ir::ExpressionData::Variable(v) => { @@ -64,29 +67,32 @@ impl<'db> CheckExpression<'_, 'db> { self.report_error( expression.span, format!("the variable `{}` is not declared", v.text(self.db)), - ); + )?; } } crate::ir::ExpressionData::Call(f, args) => { - if self.find_function(*f).is_none() { + if self.find_function(*f)?.is_none() { self.report_error( expression.span, format!("the function `{}` is not declared", f.text(self.db)), - ); + )?; } for arg in args { - self.check(arg); + self.check(arg)?; } } } + + Ok(()) } - fn find_function(&self, f: FunctionId<'db>) -> Option> { + fn find_function(&self, f: FunctionId<'db>) -> salsa::Result>> { find_function(self.db, self.program, f) } - fn report_error(&self, span: Span, message: String) { - Diagnostic::new(span.start(self.db), span.end(self.db), message).accumulate(self.db); + fn report_error(&self, span: Span, message: String) -> salsa::Result<()> { + Diagnostic::new(span.start(self.db)?, span.end(self.db)?, message).accumulate(self.db); + Ok(()) } } @@ -97,7 +103,7 @@ fn check_string( source_text: &str, expected_diagnostics: expect_test::Expect, edits: &[(&str, expect_test::Expect, expect_test::Expect)], -) { +) -> salsa::Result<()> { use salsa::{Database, Setter}; use crate::{db::CalcDatabaseImpl, ir::SourceProgram, parser::parse_statements}; @@ -110,18 +116,19 @@ fn check_string( let source_program = SourceProgram::new(&db, source_text.to_string()); // Invoke the parser - let program = parse_statements(&db, source_program); + let program = parse_statements(&db, source_program)?; // Read out any diagnostics - db.attach(|db| { + db.attach(|db| -> salsa::Result<()> { let rendered_diagnostics: String = - type_check_program::accumulated::(db, program) + type_check_program::accumulated::(db, program)? .into_iter() .map(|d| d.render(db, source_program)) - .collect::>() + .collect::>>()? .join("\n"); expected_diagnostics.assert_eq(&rendered_diagnostics); - }); + Ok(()) + })?; // Clear logs db.take_logs(); @@ -132,23 +139,26 @@ fn check_string( .set_text(&mut db) .to(new_source_text.to_string()); - db.attach(|db| { - let program = parse_statements(db, source_program); + db.attach(|db| -> salsa::Result<()> { + let program = parse_statements(db, source_program)?; expected_diagnostics - .assert_debug_eq(&type_check_program::accumulated::(db, program)); - }); + .assert_debug_eq(&type_check_program::accumulated::(db, program)?); + + Ok(()) + })?; expected_logs.assert_debug_eq(&db.take_logs()); } + Ok(()) } #[test] -fn check_print() { - check_string("print 1 + 2", expect![""], &[]); +fn check_print() -> salsa::Result<()> { + check_string("print 1 + 2", expect![""], &[]) } #[test] -fn check_bad_variable_in_program() { +fn check_bad_variable_in_program() -> salsa::Result<()> { check_string( "print a + b", expect![[r#" @@ -165,11 +175,11 @@ fn check_bad_variable_in_program() { | ^ here |"#]], &[], - ); + ) } #[test] -fn check_bad_function_in_program() { +fn check_bad_function_in_program() -> salsa::Result<()> { check_string( "print a(22)", expect![[r#" @@ -180,11 +190,11 @@ fn check_bad_function_in_program() { | ^^^^^ here |"#]], &[], - ); + ) } #[test] -fn check_bad_variable_in_function() { +fn check_bad_variable_in_function() -> salsa::Result<()> { check_string( " fn add_one(a) = a + b @@ -202,11 +212,11 @@ fn check_bad_variable_in_function() { 6 | |"#]], &[], - ); + ) } #[test] -fn check_bad_function_in_function() { +fn check_bad_function_in_function() -> salsa::Result<()> { check_string( " fn add_one(a) = add_two(a) + b @@ -233,11 +243,11 @@ fn check_bad_function_in_function() { 6 | |"#]], &[], - ); + ) } #[test] -fn fix_bad_variable_in_function() { +fn fix_bad_variable_in_function() -> salsa::Result<()> { check_string( " fn double(a) = a * b @@ -272,5 +282,5 @@ fn fix_bad_variable_in_function() { ] "#]], )], - ); + ) } diff --git a/examples/lazy-input/main.rs b/examples/lazy-input/main.rs index 792b7f348..b2e7b18f2 100644 --- a/examples/lazy-input/main.rs +++ b/examples/lazy-input/main.rs @@ -28,8 +28,8 @@ fn main() -> Result<()> { loop { // Compile the code starting at the provided input, this will read other // needed files using the on-demand mechanism. - let sum = compile(&db, initial); - let diagnostics = compile::accumulated::(&db, initial); + let sum = compile(&db, initial)?; + let diagnostics = compile::accumulated::(&db, initial)?; if diagnostics.is_empty() { println!("Sum is: {}", sum); } else { @@ -138,16 +138,18 @@ impl Db for LazyInputDatabase { struct Diagnostic(String); impl Diagnostic { - fn push_error(db: &dyn Db, file: File, error: Report) { + fn push_error(db: &dyn Db, file: File, error: Report) -> salsa::Result<()> { Diagnostic(format!( "Error in file {}: {:?}\n", - file.path(db) + file.path(db)? .file_name() .unwrap_or_else(|| "".as_ref()) .to_string_lossy(), error, )) .accumulate(db); + + Ok(()) } } @@ -159,14 +161,14 @@ struct ParsedFile<'db> { } #[salsa::tracked] -fn compile(db: &dyn Db, input: File) -> u32 { - let parsed = parse(db, input); +fn compile(db: &dyn Db, input: File) -> salsa::Result { + let parsed = parse(db, input)?; sum(db, parsed) } #[salsa::tracked] -fn parse(db: &dyn Db, input: File) -> ParsedFile<'_> { - let mut lines = input.contents(db).lines(); +fn parse(db: &dyn Db, input: File) -> salsa::Result> { + let mut lines = input.contents(db)?.lines(); let value = match lines.next().map(|line| (line.parse::(), line)) { Some((Ok(num), _)) => num, Some((Err(e), line)) => { @@ -177,46 +179,48 @@ fn parse(db: &dyn Db, input: File) -> ParsedFile<'_> { "First line ({}) could not be parsed as an integer", line )), - ); + )?; 0 } None => { - Diagnostic::push_error(db, input, eyre!("File must contain an integer")); + Diagnostic::push_error(db, input, eyre!("File must contain an integer"))?; 0 } }; - let links = lines - .filter_map(|path| { - let relative_path = match path.parse::() { - Ok(path) => path, - Err(err) => { - Diagnostic::push_error( - db, - input, - Report::new(err).wrap_err(format!("Failed to parse path: {}", path)), - ); - return None; - } - }; - let link_path = input.path(db).parent().unwrap().join(relative_path); - match db.input(link_path) { - Ok(file) => Some(parse(db, file)), - Err(err) => { - Diagnostic::push_error(db, input, err); - None - } + + let mut links = Vec::new(); + + for path in lines { + let relative_path = match path.parse::() { + Ok(path) => path, + Err(err) => { + Diagnostic::push_error( + db, + input, + Report::new(err).wrap_err(format!("Failed to parse path: {}", path)), + )?; + continue; } - }) - .collect(); + }; + let link_path = input.path(db)?.parent().unwrap().join(relative_path); + match db.input(link_path) { + Ok(file) => links.push(parse(db, file)?), + Err(err) => { + Diagnostic::push_error(db, input, err)?; + } + } + } + ParsedFile::new(db, value, links) } #[salsa::tracked] -fn sum<'db>(db: &'db dyn Db, input: ParsedFile<'db>) -> u32 { - input.value(db) - + input - .links(db) - .iter() - .map(|&file| sum(db, file)) - .sum::() +fn sum<'db>(db: &'db dyn Db, input: ParsedFile<'db>) -> salsa::Result { + let mut links_sum = 0u32; + + for link in input.links(db)? { + links_sum += sum(db, *link)?; + } + + Ok(input.value(db)? + links_sum) } diff --git a/src/accumulator.rs b/src/accumulator.rs index 01df30fda..3be3cabe8 100644 --- a/src/accumulator.rs +++ b/src/accumulator.rs @@ -146,7 +146,7 @@ impl Ingredient for IngredientImpl { _db: &dyn Database, _input: Option, _revision: Revision, - ) -> bool { + ) -> crate::Result { panic!("nothing should ever depend on an accumulator directly") } diff --git a/src/cancelled.rs b/src/cancelled.rs deleted file mode 100644 index 6c5a6e4cf..000000000 --- a/src/cancelled.rs +++ /dev/null @@ -1,57 +0,0 @@ -use std::{ - fmt, - panic::{self, UnwindSafe}, -}; - -/// A panic payload indicating that execution of a salsa query was cancelled. -/// -/// This can occur for a few reasons: -/// * -/// * -/// * -#[derive(Debug)] -#[non_exhaustive] -pub enum Cancelled { - /// The query was operating on revision R, but there is a pending write to move to revision R+1. - #[non_exhaustive] - PendingWrite, - - /// The query was blocked on another thread, and that thread panicked. - #[non_exhaustive] - PropagatedPanic, -} - -impl Cancelled { - pub(crate) fn throw(self) -> ! { - // We use resume and not panic here to avoid running the panic - // hook (that is, to avoid collecting and printing backtrace). - std::panic::resume_unwind(Box::new(self)); - } - - /// Runs `f`, and catches any salsa cancellation. - pub fn catch(f: F) -> Result - where - F: FnOnce() -> T + UnwindSafe, - { - match panic::catch_unwind(f) { - Ok(t) => Ok(t), - Err(payload) => match payload.downcast() { - Ok(cancelled) => Err(*cancelled), - Err(payload) => panic::resume_unwind(payload), - }, - } - } -} - -impl std::fmt::Display for Cancelled { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let why = match self { - Cancelled::PendingWrite => "pending write", - Cancelled::PropagatedPanic => "propagated panic", - }; - f.write_str("cancelled because of ")?; - f.write_str(why) - } -} - -impl std::error::Error for Cancelled {} diff --git a/src/cycle.rs b/src/cycle.rs index 6071aa309..6f7306363 100644 --- a/src/cycle.rs +++ b/src/cycle.rs @@ -1,5 +1,5 @@ use crate::{key::DatabaseKeyIndex, Database}; -use std::{panic::AssertUnwindSafe, sync::Arc}; +use std::sync::Arc; /// Captures the participants of a cycle that occurred when executing a query. /// @@ -33,21 +33,6 @@ impl Cycle { Arc::ptr_eq(&self.participants, &cycle.participants) } - pub(crate) fn throw(self) -> ! { - tracing::debug!("throwing cycle {:?}", self); - std::panic::resume_unwind(Box::new(self)) - } - - pub(crate) fn catch(execute: impl FnOnce() -> T) -> Result { - match std::panic::catch_unwind(AssertUnwindSafe(execute)) { - Ok(v) => Ok(v), - Err(err) => match err.downcast::() { - Ok(cycle) => Err(*cycle), - Err(other) => std::panic::resume_unwind(other), - }, - } - } - /// Iterate over the [`DatabaseKeyIndex`] for each query participating /// in the cycle. The start point of this iteration within the cycle /// is arbitrary but deterministic, but the ordering is otherwise determined diff --git a/src/function.rs b/src/function.rs index 24cbb1308..1e9fe6301 100644 --- a/src/function.rs +++ b/src/function.rs @@ -8,7 +8,7 @@ use crate::{ salsa_struct::SalsaStructInDb, zalsa::{IngredientIndex, MemoIngredientIndex, Zalsa}, zalsa_local::QueryOrigin, - Cycle, Database, Id, Revision, + Cycle, Database, Id, Result, Revision, }; use self::delete::DeletedEntries; @@ -64,7 +64,7 @@ pub trait Configuration: Any { /// computed it before or because the old one relied on inputs that have changed. /// /// This invokes the function the user wrote. - fn execute<'db>(db: &'db Self::DbView, input: Self::Input<'db>) -> Self::Output<'db>; + fn execute<'db>(db: &'db Self::DbView, input: Self::Input<'db>) -> Result>; /// If the cycle strategy is `Fallback`, then invoked when `key` is a participant /// in a cycle to find out what value it should have. @@ -194,7 +194,7 @@ where db: &dyn Database, input: Option, revision: Revision, - ) -> bool { + ) -> crate::Result { let key = input.unwrap(); let db = db.as_view::(); self.maybe_changed_after(db, key, revision) @@ -256,3 +256,12 @@ where .finish() } } + +// Consider replacing with `Try` when it stabilizes. +pub trait HasOutput { + type Output; +} + +impl HasOutput for std::result::Result { + type Output = T; +} diff --git a/src/function/accumulated.rs b/src/function/accumulated.rs index 5f9ccc2a0..9c840caea 100644 --- a/src/function/accumulated.rs +++ b/src/function/accumulated.rs @@ -10,7 +10,7 @@ where { /// Helper used by `accumulate` functions. Computes the results accumulated by `database_key_index` /// and its inputs. - pub fn accumulated_by(&self, db: &C::DbView, key: Id) -> Vec + pub fn accumulated_by(&self, db: &C::DbView, key: Id) -> crate::Result> where A: accumulator::Accumulator, { @@ -19,12 +19,12 @@ where let current_revision = zalsa.current_revision(); let Some(accumulator) = >::from_db(db) else { - return vec![]; + return Ok(vec![]); }; let mut output = vec![]; // First ensure the result is up to date - self.fetch(db, key); + self.fetch(db, key)?; let db = db.as_dyn_database(); let db_key = self.database_key_index(key); @@ -50,6 +50,6 @@ where } } - output + Ok(output) } } diff --git a/src/function/execute.rs b/src/function/execute.rs index 0f1876b6e..65788bdf8 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -1,12 +1,12 @@ use std::sync::Arc; +use super::{memo::Memo, Configuration, IngredientImpl}; +use crate::result::Error; use crate::{ - runtime::StampedValue, zalsa::ZalsaDatabase, zalsa_local::ActiveQueryGuard, Cycle, Database, - Event, EventKind, + runtime::StampedValue, zalsa::ZalsaDatabase, zalsa_local::ActiveQueryGuard, Database, Event, + EventKind, }; -use super::{memo::Memo, Configuration, IngredientImpl}; - impl IngredientImpl where C: Configuration, @@ -25,7 +25,7 @@ where db: &'db C::DbView, active_query: ActiveQueryGuard<'_>, opt_old_memo: Option>>>, - ) -> StampedValue<&C::Output<'db>> { + ) -> crate::Result>> { let zalsa = db.zalsa(); let revision_now = zalsa.current_revision(); let database_key_index = active_query.database_key_index; @@ -49,29 +49,35 @@ where // stale, or value is absent. Let's execute! let database_key_index = active_query.database_key_index; let id = database_key_index.key_index; - let value = match Cycle::catch(|| C::execute(db, C::id_to_input(db, id))) { + let value = match C::execute(db, C::id_to_input(db, id)) { Ok(v) => v, - Err(cycle) => { - tracing::debug!( - "{database_key_index:?}: caught cycle {cycle:?}, have strategy {:?}", - C::CYCLE_STRATEGY - ); - match C::CYCLE_STRATEGY { - crate::cycle::CycleRecoveryStrategy::Panic => cycle.throw(), - crate::cycle::CycleRecoveryStrategy::Fallback => { - if let Some(c) = active_query.take_cycle() { - assert!(c.is(&cycle)); - C::recover_from_cycle(db, &cycle, C::id_to_input(db, id)) - } else { - // we are not a participant in this cycle - debug_assert!(!cycle - .participant_keys() - .any(|k| k == database_key_index)); - cycle.throw() + Err(error) => match error.into_cycle() { + Ok(cycle) => { + tracing::debug!( + "{database_key_index:?}: caught cycle {cycle:?}, have strategy {:?}", + C::CYCLE_STRATEGY + ); + match C::CYCLE_STRATEGY { + crate::cycle::CycleRecoveryStrategy::Panic => { + // Propagate the cycle to the parent query for recovery. + return Err(Error::cycle(cycle)); + } + crate::cycle::CycleRecoveryStrategy::Fallback => { + if let Some(c) = active_query.take_cycle() { + assert!(c.is(&cycle)); + C::recover_from_cycle(db, &cycle, C::id_to_input(db, id)) + } else { + // we are not a participant in this cycle + debug_assert!(!cycle + .participant_keys() + .any(|k| k == database_key_index)); + return Err(Error::cycle(cycle)); + } } } } - } + Err(error) => return Err(error), + }, }; let mut revisions = active_query.pop(); @@ -91,6 +97,6 @@ where .insert_memo(zalsa, id, Memo::new(Some(value), revision_now, revisions)) .unwrap(); - stamp_template.stamp(value) + Ok(stamp_template.stamp(value)) } } diff --git a/src/function/fetch.rs b/src/function/fetch.rs index ff77b95c2..d8016b1a7 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -6,23 +6,27 @@ impl IngredientImpl where C: Configuration, { - pub fn fetch<'db>(&'db self, db: &'db C::DbView, id: Id) -> &C::Output<'db> { + pub fn fetch<'db>(&'db self, db: &'db C::DbView, id: Id) -> crate::Result<&C::Output<'db>> { let (zalsa, zalsa_local) = db.zalsas(); - zalsa_local.unwind_if_revision_cancelled(db.as_dyn_database()); + zalsa_local.unwind_if_revision_cancelled(db.as_dyn_database())?; let StampedValue { value, durability, changed_at, - } = self.compute_value(db, id); + } = self.compute_value(db, id)?; if let Some(evicted) = self.lru.record_use(id) { self.evict_value_from_memo_for(zalsa, evicted); } - zalsa_local.report_tracked_read(self.database_key_index(id).into(), durability, changed_at); + zalsa_local.report_tracked_read( + self.database_key_index(id).into(), + durability, + changed_at, + )?; - value + Ok(value) } #[inline] @@ -30,10 +34,14 @@ where &'db self, db: &'db C::DbView, id: Id, - ) -> StampedValue<&'db C::Output<'db>> { + ) -> crate::Result>> { loop { - if let Some(value) = self.fetch_hot(db, id).or_else(|| self.fetch_cold(db, id)) { - return value; + if let Some(value) = self.fetch_hot(db, id) { + return Ok(value); + } + + if let Some(value) = self.fetch_cold(db, id)? { + return Ok(value); } } } @@ -64,11 +72,12 @@ where &'db self, db: &'db C::DbView, id: Id, - ) -> Option>> { + ) -> crate::Result>>> { let (zalsa, zalsa_local) = db.zalsas(); let database_key_index = self.database_key_index(id); // Try to claim this query: if someone else has claimed it already, go back and start again. + // FIXME: Handle error let _claim_guard = zalsa.sync_table_for(id).claim( db.as_dyn_database(), zalsa_local, @@ -83,15 +92,21 @@ where let zalsa = db.zalsa(); let opt_old_memo = self.get_memo_from_table_for(zalsa, id); if let Some(old_memo) = &opt_old_memo { - if old_memo.value.is_some() && self.deep_verify_memo(db, old_memo, &active_query) { - let value = unsafe { - // Unsafety invariant: memo is present in memo_map. - self.extend_memo_lifetime(old_memo).unwrap() - }; - return Some(old_memo.revisions.stamped_value(value)); + if old_memo.value.is_some() { + match self.deep_verify_memo(db, old_memo, &active_query) { + Ok(true) => { + let value = unsafe { + // Unsafety invariant: memo is present in memo_map. + self.extend_memo_lifetime(old_memo).unwrap() + }; + return Ok(Some(old_memo.revisions.stamped_value(value))); + } + Err(error) => return Err(error), + _ => {} + } } } - Some(self.execute(db, active_query, opt_old_memo)) + self.execute(db, active_query, opt_old_memo).map(Some) } } diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 741a34dd4..d8de633bb 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -17,9 +17,9 @@ where db: &'db C::DbView, id: Id, revision: Revision, - ) -> bool { + ) -> crate::Result { let (zalsa, zalsa_local) = db.zalsas(); - zalsa_local.unwind_if_revision_cancelled(db.as_dyn_database()); + zalsa_local.unwind_if_revision_cancelled(db.as_dyn_database())?; loop { let database_key_index = self.database_key_index(id); @@ -30,17 +30,17 @@ where let memo_guard = self.get_memo_from_table_for(zalsa, id); if let Some(memo) = &memo_guard { if self.shallow_verify_memo(db, zalsa, database_key_index, memo) { - return memo.revisions.changed_at > revision; + return Ok(memo.revisions.changed_at > revision); } drop(memo_guard); // release the arc-swap guard before cold path - if let Some(mcs) = self.maybe_changed_after_cold(db, id, revision) { - return mcs; + if let Some(mcs) = self.maybe_changed_after_cold(db, id, revision)? { + return Ok(mcs); } else { // We failed to claim, have to retry. } } else { // No memo? Assume has changed. - return true; + return Ok(true); } } } @@ -50,7 +50,7 @@ where db: &'db C::DbView, key_index: Id, revision: Revision, - ) -> Option { + ) -> crate::Result> { let (zalsa, zalsa_local) = db.zalsas(); let database_key_index = self.database_key_index(key_index); @@ -64,7 +64,7 @@ where // Load the current memo, if any. let Some(old_memo) = self.get_memo_from_table_for(zalsa, key_index) else { - return Some(true); + return Ok(Some(true)); }; tracing::debug!( @@ -74,8 +74,12 @@ where ); // Check if the inputs are still valid and we can just compare `changed_at`. - if self.deep_verify_memo(db, &old_memo, &active_query) { - return Some(old_memo.revisions.changed_at > revision); + match self.deep_verify_memo(db, &old_memo, &active_query) { + Ok(true) => { + return Ok(Some(old_memo.revisions.changed_at > revision)); + } + Err(error) => return Err(error), + _ => {} } // If inputs have changed, but we have an old value, we can re-execute. @@ -83,12 +87,12 @@ where // backdated. In that case, although we will have computed a new memo, // the value has not logically changed. if old_memo.value.is_some() { - let StampedValue { changed_at, .. } = self.execute(db, active_query, Some(old_memo)); - return Some(changed_at > revision); + let StampedValue { changed_at, .. } = self.execute(db, active_query, Some(old_memo))?; + return Ok(Some(changed_at > revision)); } // Otherwise, nothing for it: have to consider the value to have changed. - Some(true) + Ok(Some(true)) } /// True if the memo's value and `changed_at` time is still valid in this revision. @@ -138,7 +142,7 @@ where db: &C::DbView, old_memo: &Memo>, active_query: &ActiveQueryGuard<'_>, - ) -> bool { + ) -> crate::Result { let zalsa = db.zalsa(); let database_key_index = active_query.database_key_index; @@ -148,7 +152,7 @@ where ); if self.shallow_verify_memo(db, zalsa, database_key_index, old_memo) { - return true; + return Ok(true); } match &old_memo.revisions.origin { @@ -164,15 +168,15 @@ where // Conditionally specified queries // where the value is specified // in rev 1 but not in rev 2. - return false; + return Ok(false); } QueryOrigin::BaseInput => { // This value was `set` by the mutator thread -- ie, it's a base input and it cannot be out of date. - return true; + return Ok(true); } QueryOrigin::DerivedUntracked(_) => { // Untracked inputs? Have to assume that it changed. - return false; + return Ok(false); } QueryOrigin::Derived(edges) => { // Fully tracked inputs? Iterate over the inputs and check them, one by one. @@ -186,9 +190,9 @@ where match edge_kind { EdgeKind::Input => { if dependency_index - .maybe_changed_after(db.as_dyn_database(), last_verified_at) + .maybe_changed_after(db.as_dyn_database(), last_verified_at)? { - return false; + return Ok(false); } } EdgeKind::Output => { @@ -221,6 +225,6 @@ where zalsa.current_revision(), database_key_index, ); - true + Ok(true) } } diff --git a/src/ingredient.rs b/src/ingredient.rs index 2d6067158..531466b56 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -37,7 +37,7 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { db: &'db dyn Database, input: Option, revision: Revision, - ) -> bool; + ) -> crate::Result; /// What were the inputs (if any) that were used to create the value at `key_index`. fn origin(&self, db: &dyn Database, key_index: Id) -> Option; diff --git a/src/input.rs b/src/input.rs index ce35e1b9e..605aca5f5 100644 --- a/src/input.rs +++ b/src/input.rs @@ -179,7 +179,7 @@ impl IngredientImpl { db: &'db dyn crate::Database, id: C::Struct, field_index: usize, - ) -> &'db C::Fields { + ) -> crate::Result<&'db C::Fields> { let (zalsa, zalsa_local) = db.zalsas(); let field_ingredient_index = self.ingredient_index.successor(field_index); let id = id.as_id(); @@ -192,8 +192,8 @@ impl IngredientImpl { }, stamp.durability, stamp.changed_at, - ); - &value.fields + )?; + Ok(&value.fields) } /// Peek at the field values without recording any read dependency. @@ -216,10 +216,10 @@ impl Ingredient for IngredientImpl { _db: &dyn Database, _input: Option, _revision: Revision, - ) -> bool { + ) -> crate::Result { // Input ingredients are just a counter, they store no data, they are immortal. // Their *fields* are stored in function ingredients elsewhere. - false + Ok(false) } fn cycle_recovery_strategy(&self) -> CycleRecoveryStrategy { diff --git a/src/input/input_field.rs b/src/input/input_field.rs index 505aaf4f0..8f322a659 100644 --- a/src/input/input_field.rs +++ b/src/input/input_field.rs @@ -54,11 +54,11 @@ where db: &dyn Database, input: Option, revision: Revision, - ) -> bool { + ) -> crate::Result { let zalsa = db.zalsa(); let input = input.unwrap(); let value = >::data(zalsa, input); - value.stamps[self.field_index].changed_at > revision + Ok(value.stamps[self.field_index].changed_at > revision) } fn origin(&self, _db: &dyn Database, _key_index: Id) -> Option { diff --git a/src/interned.rs b/src/interned.rs index 9ff667a37..616e73595 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -118,8 +118,8 @@ where &'db self, db: &'db dyn crate::Database, data: C::Data<'db>, - ) -> crate::Id { - C::deref_struct(self.intern(db, data)).as_id() + ) -> crate::Result { + Ok(C::deref_struct(self.intern(db, data)?).as_id()) } /// Intern data to a unique reference. @@ -127,13 +127,13 @@ where &'db self, db: &'db dyn crate::Database, data: C::Data<'db>, - ) -> C::Struct<'db> { + ) -> crate::Result> { let zalsa_local = db.zalsa_local(); zalsa_local.report_tracked_read( DependencyIndex::for_table(self.ingredient_index), Durability::MAX, self.reset_at, - ); + )?; // Optimisation to only get read lock on the map if the data has already // been interned. @@ -141,10 +141,10 @@ where if let Some(guard) = self.key_map.get(&internal_data) { let id = *guard; drop(guard); - return C::struct_from_id(id); + return Ok(C::struct_from_id(id)); } - match self.key_map.entry(internal_data.clone()) { + Ok(match self.key_map.entry(internal_data.clone()) { // Data has been interned by a racing call, use that ID instead dashmap::mapref::entry::Entry::Occupied(entry) => { let id = *entry.get(); @@ -168,7 +168,7 @@ where entry.insert(next_id); C::struct_from_id(next_id) } - } + }) } /// Lookup the data for an interned value based on its id. @@ -205,8 +205,8 @@ where _db: &dyn Database, _input: Option, revision: Revision, - ) -> bool { - revision < self.reset_at + ) -> crate::Result { + Ok(revision < self.reset_at) } fn cycle_recovery_strategy(&self) -> crate::cycle::CycleRecoveryStrategy { diff --git a/src/key.rs b/src/key.rs index df49d047b..7e030a399 100644 --- a/src/key.rs +++ b/src/key.rs @@ -51,7 +51,7 @@ impl DependencyIndex { &self, db: &dyn Database, last_verified_at: crate::Revision, - ) -> bool { + ) -> crate::Result { db.zalsa() .lookup_ingredient(self.ingredient_index) .maybe_changed_after(db, self.key_index, last_verified_at) diff --git a/src/lib.rs b/src/lib.rs index c23d9c2e3..1947ec123 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,7 +2,6 @@ mod accumulator; mod active_query; mod array; mod attach; -mod cancelled; mod cycle; mod database; mod database_impl; @@ -16,6 +15,7 @@ mod input; mod interned; mod key; mod nonce; +mod result; mod revision; mod runtime; mod salsa_struct; @@ -28,7 +28,6 @@ mod zalsa; mod zalsa_local; pub use self::accumulator::Accumulator; -pub use self::cancelled::Cancelled; pub use self::cycle::Cycle; pub use self::database::AsDynDatabase; pub use self::database::Database; @@ -39,6 +38,7 @@ pub use self::event::EventKind; pub use self::id::Id; pub use self::input::setter::Setter; pub use self::key::DatabaseKeyIndex; +pub use self::result::Result; pub use self::revision::Revision; pub use self::runtime::Runtime; pub use self::storage::Storage; @@ -73,6 +73,7 @@ pub mod plumbing { pub use crate::database::current_revision; pub use crate::database::Database; pub use crate::function::should_backdate_value; + pub use crate::function::HasOutput; pub use crate::id::AsId; pub use crate::id::FromId; pub use crate::id::Id; diff --git a/src/result.rs b/src/result.rs new file mode 100644 index 000000000..a94781858 --- /dev/null +++ b/src/result.rs @@ -0,0 +1,107 @@ +use crate::Cycle; +use drop_bomb::DropBomb; +use std::fmt; +use std::fmt::Debug; + +pub type Result = std::result::Result; + +#[derive(Debug)] + +pub struct Error { + kind: ErrorKind, +} + +impl Error { + pub(crate) fn cancelled(reason: Cancelled) -> Self { + Error { + kind: ErrorKind::Cancelled(reason), + } + } + + pub(crate) fn cycle(cycle: Cycle) -> Self { + Self { + kind: ErrorKind::Cycle(CycleError { + cycle, + bomb: DropBomb::new("TODO"), + }), + } + } + + pub(crate) fn into_cycle(self) -> std::result::Result { + match self.kind { + ErrorKind::Cycle(cycle) => Ok(cycle.take_cycle()), + _ => Err(self), + } + } +} + +impl From for Error { + fn from(value: CycleError) -> Self { + Self { + kind: ErrorKind::Cycle(value), + } + } +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.kind { + ErrorKind::Cycle(cycle) => { + write!(f, "cycle detected: {:?}", cycle) + } + ErrorKind::Cancelled(cancelled) => std::fmt::Display::fmt(cancelled, f), + } + } +} + +impl std::error::Error for Error {} + +#[derive(Debug)] +pub(crate) enum ErrorKind { + Cycle(CycleError), + Cancelled(Cancelled), +} + +#[derive(Debug)] +pub(crate) struct CycleError { + cycle: Cycle, + bomb: DropBomb, +} + +impl CycleError { + pub(crate) fn take_cycle(mut self) -> Cycle { + self.bomb.defuse(); + self.cycle + } +} + +// FIXME implement drop for Cancelled. + +/// A panic payload indicating that execution of a salsa query was cancelled. +/// +/// This can occur for a few reasons: +/// * +/// * +/// * +#[derive(Debug)] +#[non_exhaustive] +pub(crate) enum Cancelled { + /// The query was operating on revision R, but there is a pending write to move to revision R+1. + #[non_exhaustive] + PendingWrite, + + /// The query was blocked on another thread, and that thread panicked. + #[non_exhaustive] + PropagatedPanic, +} + +impl std::fmt::Display for Cancelled { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let why = match self { + Cancelled::PendingWrite => "pending write", + Cancelled::PropagatedPanic => "propagated panic", + }; + f.write_str("cancelled because of ")?; + f.write_str(why) + } +} diff --git a/src/runtime.rs b/src/runtime.rs index ba35f09fc..0ff6e412e 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -7,14 +7,14 @@ use std::{ use crossbeam::atomic::AtomicCell; use parking_lot::Mutex; +use self::dependency_graph::DependencyGraph; +use crate::result::Cancelled; use crate::{ active_query::ActiveQuery, cycle::CycleRecoveryStrategy, durability::Durability, - key::DatabaseKeyIndex, revision::AtomicRevision, table::Table, zalsa_local::ZalsaLocal, - Cancelled, Cycle, Database, Event, EventKind, Revision, + key::DatabaseKeyIndex, revision::AtomicRevision, table::Table, zalsa_local::ZalsaLocal, Cycle, + Database, Event, EventKind, Revision, }; -use self::dependency_graph::DependencyGraph; - mod dependency_graph; pub struct Runtime { @@ -183,12 +183,12 @@ impl Runtime { database_key: DatabaseKeyIndex, other_id: ThreadId, query_mutex_guard: QueryMutexGuard, - ) { + ) -> crate::Result<()> { let mut dg = self.dependency_graph.lock(); let thread_id = std::thread::current().id(); if dg.depends_on(other_id, thread_id) { - self.unblock_cycle_and_maybe_throw(db, local_state, &mut dg, database_key, other_id); + self.unblock_cycle_and_maybe_throw(db, local_state, &mut dg, database_key, other_id)?; // If the above fn returns, then (via cycle recovery) it has unblocked the // cycle, so we can continue. @@ -217,14 +217,16 @@ impl Runtime { local_state.restore_query_stack(stack); match result { - WaitResult::Completed => (), + WaitResult::Completed => Ok(()), // If the other thread panicked, then we consider this thread // cancelled. The assumption is that the panic will be detected // by the other thread and responded to appropriately. - WaitResult::Panicked => Cancelled::PropagatedPanic.throw(), + WaitResult::Panicked => { + Err(crate::result::Error::cancelled(Cancelled::PropagatedPanic)) + } - WaitResult::Cycle(c) => c.throw(), + WaitResult::Cycle(c) => Err(crate::result::Error::cycle(c)), } } @@ -243,7 +245,7 @@ impl Runtime { dg: &mut DependencyGraph, database_key_index: DatabaseKeyIndex, to_id: ThreadId, - ) { + ) -> crate::Result<()> { tracing::debug!( "unblock_cycle_and_maybe_throw(database_key={:?})", database_key_index @@ -333,9 +335,10 @@ impl Runtime { if me_recovered { // If the current thread has recovery, we want to throw // so that it can begin. - cycle.throw() + Err(crate::result::Error::cycle(cycle)) } else if others_recovered { // If other threads have recovery but we didn't: return and we will block on them. + Ok(()) } else { // if nobody has recover, then we panic panic_any(cycle); diff --git a/src/table/sync.rs b/src/table/sync.rs index 6a02bc8cf..258f2abce 100644 --- a/src/table/sync.rs +++ b/src/table/sync.rs @@ -37,7 +37,7 @@ impl SyncTable { zalsa_local: &ZalsaLocal, database_key_index: DatabaseKeyIndex, memo_ingredient_index: MemoIngredientIndex, - ) -> Option> { + ) -> crate::Result>> { let mut syncs = self.syncs.write(); let zalsa = db.zalsa(); let thread_id = std::thread::current().id(); @@ -50,12 +50,12 @@ impl SyncTable { id: thread_id, anyone_waiting: AtomicBool::new(false), }); - Some(ClaimGuard { + Ok(Some(ClaimGuard { database_key_index, memo_ingredient_index, zalsa, sync_table: self, - }) + })) } Some(SyncState { id: other_id, @@ -68,8 +68,8 @@ impl SyncTable { // boolean is to decide *whether* to acquire the lock, // not to gate future atomic reads. anyone_waiting.store(true, Ordering::Relaxed); - zalsa.block_on_or_unwind(db, zalsa_local, database_key_index, *other_id, syncs); - None + zalsa.block_on_or_unwind(db, zalsa_local, database_key_index, *other_id, syncs)?; + Ok(None) } } } diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index 6b83d0ae3..4f2d71591 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -252,13 +252,13 @@ where &'db self, db: &'db dyn Database, fields: C::Fields<'db>, - ) -> C::Struct<'db> { + ) -> crate::Result> { let (zalsa, zalsa_local) = db.zalsas(); let data_hash = crate::hash::hash(&C::id_fields(&fields)); let (current_deps, disambiguator) = - zalsa_local.disambiguate(self.ingredient_index, Revision::start(), data_hash); + zalsa_local.disambiguate(self.ingredient_index, Revision::start(), data_hash)?; let key_struct = KeyStruct { disambiguator, @@ -266,7 +266,7 @@ where }; let current_revision = zalsa.current_revision(); - match zalsa_local.tracked_struct_id(&key_struct) { + Ok(match zalsa_local.tracked_struct_id(&key_struct) { Some(id) => { // The struct already exists in the intern map. zalsa_local.add_output(self.database_key_index(id).into()); @@ -281,7 +281,7 @@ where zalsa_local.store_tracked_struct_id(key_struct, id); C::struct_from_id(id) } - } + }) } fn allocate<'db>( @@ -521,7 +521,7 @@ where db: &'db dyn crate::Database, s: C::Struct<'db>, field_index: usize, - ) -> &'db C::Fields<'db> { + ) -> crate::Result<&'db C::Fields<'db>> { let (zalsa, zalsa_local) = db.zalsas(); let id = C::deref_struct(s); let field_ingredient_index = self.ingredient_index.successor(field_index); @@ -538,9 +538,9 @@ where }, data.durability, field_changed_at, - ); + )?; - unsafe { self.to_self_ref(&data.fields) } + Ok(unsafe { self.to_self_ref(&data.fields) }) } } @@ -557,8 +557,8 @@ where _db: &dyn Database, _input: Option, _revision: Revision, - ) -> bool { - false + ) -> crate::Result { + Ok(false) } fn cycle_recovery_strategy(&self) -> CycleRecoveryStrategy { diff --git a/src/tracked_struct/tracked_field.rs b/src/tracked_struct/tracked_field.rs index d5c214ade..42fc6e6de 100644 --- a/src/tracked_struct/tracked_field.rs +++ b/src/tracked_struct/tracked_field.rs @@ -53,12 +53,12 @@ where db: &'db dyn Database, input: Option, revision: crate::Revision, - ) -> bool { + ) -> crate::Result { let zalsa = db.zalsa(); let id = input.unwrap(); let data = >::data(zalsa.table(), id); let field_changed_at = data.revisions[self.field_index]; - field_changed_at > revision + Ok(field_changed_at > revision) } fn origin( diff --git a/src/zalsa.rs b/src/zalsa.rs index 2f8fa95f8..fb50c50c9 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -187,32 +187,31 @@ impl Zalsa { let jar_type_id = jar.type_id(); let mut jar_map = self.jar_map.lock(); *jar_map - .entry(jar_type_id) - .or_insert_with(|| { - let index = IngredientIndex::from(self.ingredients_vec.len()); - let ingredients = jar.create_ingredients(self, index); - for ingredient in ingredients { - let expected_index = ingredient.ingredient_index(); - - if ingredient.requires_reset_for_new_revision() { - self.ingredients_requiring_reset.push(expected_index); + .entry(jar_type_id) + .or_insert_with(|| { + let index = IngredientIndex::from(self.ingredients_vec.len()); + let ingredients = jar.create_ingredients(self, index); + for ingredient in ingredients { + let expected_index = ingredient.ingredient_index(); + + if ingredient.requires_reset_for_new_revision() { + self.ingredients_requiring_reset.push(expected_index); + } + + let actual_index = self + .ingredients_vec + .push(ingredient); + assert_eq!( + expected_index.as_usize(), + actual_index, + "ingredient `{:?}` was predicted to have index `{:?}` but actually has index `{:?}`", + self.ingredients_vec[actual_index], + expected_index, + actual_index, + ); } - - let actual_index = self - .ingredients_vec - .push(ingredient); - assert_eq!( - expected_index.as_usize(), - actual_index, - "ingredient `{:?}` was predicted to have index `{:?}` but actually has index `{:?}`", - self.ingredients_vec[actual_index], - expected_index, - actual_index, - ); - - } - index - }) + index + }) } } @@ -273,7 +272,7 @@ impl Zalsa { database_key: DatabaseKeyIndex, other_id: ThreadId, query_mutex_guard: QueryMutexGuard, - ) { + ) -> crate::Result<()> { self.runtime .block_on_or_unwind(db, local_state, database_key, other_id, query_mutex_guard) } diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index e70cd38f9..8a571faa1 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -5,6 +5,7 @@ use crate::active_query::ActiveQuery; use crate::durability::Durability; use crate::key::DatabaseKeyIndex; use crate::key::DependencyIndex; +use crate::result::Cancelled; use crate::runtime::StampedValue; use crate::table::PageIndex; use crate::table::Slot; @@ -12,7 +13,6 @@ use crate::table::Table; use crate::tracked_struct::Disambiguator; use crate::tracked_struct::KeyStruct; use crate::zalsa::IngredientIndex; -use crate::Cancelled; use crate::Cycle; use crate::Database; use crate::Event; @@ -151,7 +151,7 @@ impl ZalsaLocal { input: DependencyIndex, durability: Durability, changed_at: Revision, - ) { + ) -> crate::Result<()> { debug!( "report_tracked_read(input={:?}, durability={:?}, changed_at={:?})", input, durability, changed_at @@ -182,9 +182,10 @@ impl ZalsaLocal { // stack frames, so they will just read the fallback value // from `Ci+1` and continue on their merry way. if let Some(cycle) = &top_query.cycle { - cycle.clone().throw() + return Err(crate::result::Error::cycle(cycle.clone())); } } + Ok(()) }) } @@ -247,7 +248,7 @@ impl ZalsaLocal { entity_index: IngredientIndex, reset_at: Revision, data_hash: u64, - ) -> (StampedValue<()>, Disambiguator) { + ) -> crate::Result<(StampedValue<()>, Disambiguator)> { assert!( self.query_in_progress(), "cannot create a tracked struct disambiguator outside of a tracked function" @@ -257,9 +258,9 @@ impl ZalsaLocal { DependencyIndex::for_table(entity_index), Durability::MAX, reset_at, - ); + )?; - self.with_query_stack(|stack| { + Ok(self.with_query_stack(|stack| { let top_query = stack.last_mut().unwrap(); let disambiguator = top_query.disambiguate(data_hash); ( @@ -270,7 +271,7 @@ impl ZalsaLocal { }, disambiguator, ) - }) + })) } #[track_caller] @@ -313,23 +314,24 @@ impl ZalsaLocal { /// This method should not be overridden by `Database` implementors. A /// `salsa_event` is emitted when this method is called, so that should be /// used instead. - pub(crate) fn unwind_if_revision_cancelled(&self, db: &dyn Database) { + pub(crate) fn unwind_if_revision_cancelled(&self, db: &dyn Database) -> crate::Result<()> { let thread_id = std::thread::current().id(); db.salsa_event(&|| Event { thread_id, - kind: EventKind::WillCheckCancellation, }); let zalsa = db.zalsa(); if zalsa.load_cancellation_flag() { - self.unwind_cancelled(zalsa.current_revision()); + return self.unwind_cancelled(zalsa.current_revision()); } + + Ok(()) } #[cold] - pub(crate) fn unwind_cancelled(&self, current_revision: Revision) { + pub(crate) fn unwind_cancelled(&self, current_revision: Revision) -> crate::Result<()> { self.report_untracked_read(current_revision); - Cancelled::PendingWrite.throw(); + Err(crate::result::Error::cancelled(Cancelled::PendingWrite)) } } diff --git a/tests/accumulate-chain.rs b/tests/accumulate-chain.rs index b0d79bd48..f8cabfed9 100644 --- a/tests/accumulate-chain.rs +++ b/tests/accumulate-chain.rs @@ -11,37 +11,43 @@ use test_log::test; struct Log(#[allow(dead_code)] String); #[salsa::tracked] -fn push_logs(db: &dyn Database) { - push_a_logs(db); +fn push_logs(db: &dyn Database) -> salsa::Result<()> { + push_a_logs(db)?; + Ok(()) } #[salsa::tracked] -fn push_a_logs(db: &dyn Database) { +fn push_a_logs(db: &dyn Database) -> salsa::Result<()> { Log("log a".to_string()).accumulate(db); - push_b_logs(db); + push_b_logs(db)?; + + Ok(()) } #[salsa::tracked] -fn push_b_logs(db: &dyn Database) { +fn push_b_logs(db: &dyn Database) -> salsa::Result<()> { // No logs - push_c_logs(db); + push_c_logs(db)?; + Ok(()) } #[salsa::tracked] -fn push_c_logs(db: &dyn Database) { +fn push_c_logs(db: &dyn Database) -> salsa::Result<()> { // No logs - push_d_logs(db); + push_d_logs(db)?; + Ok(()) } #[salsa::tracked] -fn push_d_logs(db: &dyn Database) { +fn push_d_logs(db: &dyn Database) -> salsa::Result<()> { Log("log d".to_string()).accumulate(db); + Ok(()) } #[test] -fn accumulate_chain() { +fn accumulate_chain() -> salsa::Result<()> { DatabaseImpl::new().attach(|db| { - let logs = push_logs::accumulated::(db); + let logs = push_logs::accumulated::(db)?; // Check that we get all the logs. expect![[r#" [ @@ -53,5 +59,7 @@ fn accumulate_chain() { ), ]"#]] .assert_eq(&format!("{:#?}", logs)); + + Ok(()) }) } diff --git a/tests/accumulate-custom-clone.rs b/tests/accumulate-custom-clone.rs index 81612b318..1b98507f1 100644 --- a/tests/accumulate-custom-clone.rs +++ b/tests/accumulate-custom-clone.rs @@ -19,17 +19,19 @@ impl Clone for Log { } #[salsa::tracked] -fn push_logs(db: &dyn salsa::Database, input: MyInput) { - for i in 0..input.count(db) { +fn push_logs(db: &dyn salsa::Database, input: MyInput) -> salsa::Result<()> { + for i in 0..input.count(db)? { Log(format!("#{i}")).accumulate(db); } + + Ok(()) } #[test] -fn accumulate_custom_clone() { +fn accumulate_custom_clone() -> salsa::Result<()> { salsa::DatabaseImpl::new().attach(|db| { let input = MyInput::new(db, 2); - let logs = push_logs::accumulated::(db, input); + let logs = push_logs::accumulated::(db, input)?; expect![[r##" [ Log( @@ -41,5 +43,6 @@ fn accumulate_custom_clone() { ] "##]] .assert_debug_eq(&logs); + Ok(()) }) } diff --git a/tests/accumulate-custom-debug.rs b/tests/accumulate-custom-debug.rs index 71a4ba86e..297a3753b 100644 --- a/tests/accumulate-custom-debug.rs +++ b/tests/accumulate-custom-debug.rs @@ -19,17 +19,19 @@ impl std::fmt::Debug for Log { } #[salsa::tracked] -fn push_logs(db: &dyn salsa::Database, input: MyInput) { - for i in 0..input.count(db) { +fn push_logs(db: &dyn salsa::Database, input: MyInput) -> salsa::Result<()> { + for i in 0..input.count(db)? { Log(format!("#{i}")).accumulate(db); } + + Ok(()) } #[test] -fn accumulate_custom_debug() { +fn accumulate_custom_debug() -> salsa::Result<()> { salsa::DatabaseImpl::new().attach(|db| { let input = MyInput::new(db, 2); - let logs = push_logs::accumulated::(db, input); + let logs = push_logs::accumulated::(db, input)?; expect![[r##" [ CustomLog( @@ -41,5 +43,7 @@ fn accumulate_custom_debug() { ] "##]] .assert_debug_eq(&logs); + + Ok(()) }) } diff --git a/tests/accumulate-dag.rs b/tests/accumulate-dag.rs index e23050ba4..bf24c6250 100644 --- a/tests/accumulate-dag.rs +++ b/tests/accumulate-dag.rs @@ -14,34 +14,40 @@ struct MyInput { struct Log(#[allow(dead_code)] String); #[salsa::tracked] -fn push_logs(db: &dyn Database, input: MyInput) { - push_a_logs(db, input); - push_b_logs(db, input); +fn push_logs(db: &dyn Database, input: MyInput) -> salsa::Result<()> { + push_a_logs(db, input)?; + push_b_logs(db, input)?; + + Ok(()) } #[salsa::tracked] -fn push_a_logs(db: &dyn Database, input: MyInput) { - let count = input.field_a(db); +fn push_a_logs(db: &dyn Database, input: MyInput) -> salsa::Result<()> { + let count = input.field_a(db)?; for i in 0..count { Log(format!("log_a({} of {})", i, count)).accumulate(db); } + + Ok(()) } #[salsa::tracked] -fn push_b_logs(db: &dyn Database, input: MyInput) { +fn push_b_logs(db: &dyn Database, input: MyInput) -> salsa::Result<()> { // Note that b calls a - push_a_logs(db, input); - let count = input.field_b(db); + push_a_logs(db, input)?; + let count = input.field_b(db)?; for i in 0..count { Log(format!("log_b({} of {})", i, count)).accumulate(db); } + + Ok(()) } #[test] -fn accumulate_a_called_twice() { +fn accumulate_a_called_twice() -> salsa::Result<()> { salsa::DatabaseImpl::new().attach(|db| { let input = MyInput::new(db, 2, 3); - let logs = push_logs::accumulated::(db, input); + let logs = push_logs::accumulated::(db, input)?; // Check that we don't see logs from `a` appearing twice in the input. expect![[r#" [ @@ -62,5 +68,7 @@ fn accumulate_a_called_twice() { ), ]"#]] .assert_eq(&format!("{:#?}", logs)); + + Ok(()) }) } diff --git a/tests/accumulate-execution-order.rs b/tests/accumulate-execution-order.rs index edb82e487..72b7332ff 100644 --- a/tests/accumulate-execution-order.rs +++ b/tests/accumulate-execution-order.rs @@ -11,38 +11,40 @@ use test_log::test; struct Log(#[allow(dead_code)] String); #[salsa::tracked] -fn push_logs(db: &dyn Database) { - push_a_logs(db); +fn push_logs(db: &dyn Database) -> salsa::Result<()> { + push_a_logs(db) } #[salsa::tracked] -fn push_a_logs(db: &dyn Database) { +fn push_a_logs(db: &dyn Database) -> salsa::Result<()> { Log("log a".to_string()).accumulate(db); - push_b_logs(db); - push_c_logs(db); - push_d_logs(db); + push_b_logs(db)?; + push_c_logs(db)?; + push_d_logs(db) } #[salsa::tracked] -fn push_b_logs(db: &dyn Database) { +fn push_b_logs(db: &dyn Database) -> salsa::Result<()> { Log("log b".to_string()).accumulate(db); - push_d_logs(db); + push_d_logs(db) } #[salsa::tracked] -fn push_c_logs(db: &dyn Database) { +fn push_c_logs(db: &dyn Database) -> salsa::Result<()> { Log("log c".to_string()).accumulate(db); + Ok(()) } #[salsa::tracked] -fn push_d_logs(db: &dyn Database) { +fn push_d_logs(db: &dyn Database) -> salsa::Result<()> { Log("log d".to_string()).accumulate(db); + Ok(()) } #[test] -fn accumulate_execution_order() { +fn accumulate_execution_order() -> salsa::Result<()> { salsa::DatabaseImpl::new().attach(|db| { - let logs = push_logs::accumulated::(db); + let logs = push_logs::accumulated::(db)?; // Check that we get logs in execution order expect![[r#" [ @@ -60,5 +62,7 @@ fn accumulate_execution_order() { ), ]"#]] .assert_eq(&format!("{:#?}", logs)); + + Ok(()) }) } diff --git a/tests/accumulate-from-tracked-fn.rs b/tests/accumulate-from-tracked-fn.rs index 33d7bd3f5..33baba19a 100644 --- a/tests/accumulate-from-tracked-fn.rs +++ b/tests/accumulate-from-tracked-fn.rs @@ -17,34 +17,36 @@ struct List { struct Integers(u32); #[salsa::tracked] -fn compute(db: &dyn salsa::Database, input: List) { +fn compute(db: &dyn salsa::Database, input: List) -> salsa::Result<()> { eprintln!( "{:?}(value={:?}, next={:?})", input, - input.value(db), - input.next(db) + input.value(db)?, + input.next(db)? ); - let result = if let Some(next) = input.next(db) { - let next_integers = compute::accumulated::(db, next); + let result = if let Some(next) = input.next(db)? { + let next_integers = compute::accumulated::(db, next)?; eprintln!("{:?}", next_integers); - let v = input.value(db) + next_integers.iter().map(|a| a.0).sum::(); + let v = input.value(db)? + next_integers.iter().map(|a| a.0).sum::(); eprintln!("input={:?} v={:?}", input.value(db), v); v } else { - input.value(db) + input.value(db)? }; Integers(result).accumulate(db); eprintln!("pushed result {:?}", result); + + Ok(()) } #[test] -fn test1() { +fn test1() -> salsa::Result<()> { let mut db = salsa::DatabaseImpl::new(); let l0 = List::new(&db, 1, None); let l1 = List::new(&db, 10, Some(l0)); - compute(&db, l1); + compute(&db, l1)?; expect![[r#" [ Integers( @@ -55,10 +57,10 @@ fn test1() { ), ] "#]] - .assert_debug_eq(&compute::accumulated::(&db, l1)); + .assert_debug_eq(&compute::accumulated::(&db, l1)?); l0.set_value(&mut db).to(2); - compute(&db, l1); + compute(&db, l1)?; expect![[r#" [ Integers( @@ -69,5 +71,7 @@ fn test1() { ), ] "#]] - .assert_debug_eq(&compute::accumulated::(&db, l1)); + .assert_debug_eq(&compute::accumulated::(&db, l1)?); + + Ok(()) } diff --git a/tests/accumulate-no-duplicates.rs b/tests/accumulate-no-duplicates.rs index faf8c03af..8a1a6cf1f 100644 --- a/tests/accumulate-no-duplicates.rs +++ b/tests/accumulate-no-duplicates.rs @@ -30,51 +30,53 @@ struct MyInput { } #[salsa::tracked] -fn push_logs(db: &dyn Database) { - push_a_logs(db, MyInput::new(db, 1)); +fn push_logs(db: &dyn Database) -> salsa::Result<()> { + push_a_logs(db, MyInput::new(db, 1)) } #[salsa::tracked] -fn push_a_logs(db: &dyn Database, input: MyInput) { +fn push_a_logs(db: &dyn Database, input: MyInput) -> salsa::Result<()> { Log("log a".to_string()).accumulate(db); - if input.n(db) == 1 { - push_b_logs(db); - push_b_logs(db); - push_c_logs(db); - push_b_logs(db); + if input.n(db)? == 1 { + push_b_logs(db)?; + push_b_logs(db)?; + push_c_logs(db)?; + push_b_logs(db) } else { - push_b_logs(db); + push_b_logs(db) } } #[salsa::tracked] -fn push_b_logs(db: &dyn Database) { +fn push_b_logs(db: &dyn Database) -> salsa::Result<()> { Log("log b".to_string()).accumulate(db); + Ok(()) } #[salsa::tracked] -fn push_c_logs(db: &dyn Database) { +fn push_c_logs(db: &dyn Database) -> salsa::Result<()> { Log("log c".to_string()).accumulate(db); - push_d_logs(db); - push_e_logs(db); + push_d_logs(db)?; + push_e_logs(db) } // Note this isn't tracked -fn push_d_logs(db: &dyn Database) { +fn push_d_logs(db: &dyn Database) -> salsa::Result<()> { Log("log d".to_string()).accumulate(db); - push_a_logs(db, MyInput::new(db, 2)); - push_b_logs(db); + push_a_logs(db, MyInput::new(db, 2))?; + push_b_logs(db) } #[salsa::tracked] -fn push_e_logs(db: &dyn Database) { +fn push_e_logs(db: &dyn Database) -> salsa::Result<()> { Log("log e".to_string()).accumulate(db); + Ok(()) } #[test] -fn accumulate_no_duplicates() { +fn accumulate_no_duplicates() -> salsa::Result<()> { salsa::DatabaseImpl::new().attach(|db| { - let logs = push_logs::accumulated::(db); + let logs = push_logs::accumulated::(db)?; // Test that there aren't duplicate B logs. // Note that log A appears twice, because they both come // from different inputs. @@ -100,5 +102,7 @@ fn accumulate_no_duplicates() { ), ]"#]] .assert_eq(&format!("{:#?}", logs)); + + Ok(()) }) } diff --git a/tests/accumulate-reuse-workaround.rs b/tests/accumulate-reuse-workaround.rs index d72f971b0..7fa700ad0 100644 --- a/tests/accumulate-reuse-workaround.rs +++ b/tests/accumulate-reuse-workaround.rs @@ -20,41 +20,41 @@ struct List { struct Integers(u32); #[salsa::tracked] -fn compute(db: &dyn LogDatabase, input: List) -> u32 { +fn compute(db: &dyn LogDatabase, input: List) -> salsa::Result { db.push_log(format!("compute({:?})", input,)); // always pushes 0 Integers(0).accumulate(db); - let result = if let Some(next) = input.next(db) { - let next_integers = accumulated(db, next); - let v = input.value(db) + next_integers.iter().sum::(); + let result = if let Some(next) = input.next(db)? { + let next_integers = accumulated(db, next)?; + let v = input.value(db)? + next_integers.iter().sum::(); v } else { - input.value(db) + input.value(db)? }; // return value changes - result + Ok(result) } #[salsa::tracked(return_ref)] -fn accumulated(db: &dyn LogDatabase, input: List) -> Vec { +fn accumulated(db: &dyn LogDatabase, input: List) -> salsa::Result> { db.push_log(format!("accumulated({:?})", input)); - compute::accumulated::(db, input) + Ok(compute::accumulated::(db, input)? .into_iter() .map(|a| a.0) - .collect() + .collect()) } #[test] -fn test1() { +fn test1() -> salsa::Result<()> { let mut db = LoggerDatabase::default(); let l1 = List::new(&db, 1, None); let l2 = List::new(&db, 2, Some(l1)); - assert_eq!(compute(&db, l2), 2); + assert_eq!(compute(&db, l2)?, 2); db.assert_logs(expect![[r#" [ "compute(List { [salsa id]: Id(1), value: 2, next: Some(List { [salsa id]: Id(0), value: 1, next: None }) })", @@ -66,10 +66,12 @@ fn test1() { // and we re-execute accumulated for `l1`, but we do NOT re-execute // `compute` for `l2`. l1.set_value(&mut db).to(2); - assert_eq!(compute(&db, l2), 2); + assert_eq!(compute(&db, l2)?, 2); db.assert_logs(expect![[r#" [ "accumulated(List { [salsa id]: Id(0), value: 2, next: None })", "compute(List { [salsa id]: Id(0), value: 2, next: None })", ]"#]]); + + Ok(()) } diff --git a/tests/accumulate-reuse.rs b/tests/accumulate-reuse.rs index b9962849a..a8adc0c22 100644 --- a/tests/accumulate-reuse.rs +++ b/tests/accumulate-reuse.rs @@ -20,32 +20,32 @@ struct List { struct Integers(u32); #[salsa::tracked] -fn compute(db: &dyn LogDatabase, input: List) -> u32 { +fn compute(db: &dyn LogDatabase, input: List) -> salsa::Result { db.push_log(format!("compute({:?})", input,)); // always pushes 0 Integers(0).accumulate(db); - let result = if let Some(next) = input.next(db) { - let next_integers = compute::accumulated::(db, next); - let v = input.value(db) + next_integers.iter().map(|i| i.0).sum::(); + let result = if let Some(next) = input.next(db)? { + let next_integers = compute::accumulated::(db, next)?; + let v = input.value(db)? + next_integers.iter().map(|i| i.0).sum::(); v } else { - input.value(db) + input.value(db)? }; // return value changes - result + Ok(result) } #[test] -fn test1() { +fn test1() -> salsa::Result<()> { let mut db = LoggerDatabase::default(); let l1 = List::new(&db, 1, None); let l2 = List::new(&db, 2, Some(l1)); - assert_eq!(compute(&db, l2), 2); + assert_eq!(compute(&db, l2)?, 2); db.assert_logs(expect![[r#" [ "compute(List { [salsa id]: Id(1), value: 2, next: Some(List { [salsa id]: Id(0), value: 1, next: None }) })", @@ -57,10 +57,12 @@ fn test1() { // The only input for `compute(l1)` is the accumulated values from `l1`, // which have not changed. l1.set_value(&mut db).to(2); - assert_eq!(compute(&db, l2), 2); + assert_eq!(compute(&db, l2)?, 2); db.assert_logs(expect![[r#" [ "compute(List { [salsa id]: Id(1), value: 2, next: Some(List { [salsa id]: Id(0), value: 2, next: None }) })", "compute(List { [salsa id]: Id(0), value: 2, next: None })", ]"#]]); + + Ok(()) } diff --git a/tests/accumulate.rs b/tests/accumulate.rs index a69c27a1d..35877706d 100644 --- a/tests/accumulate.rs +++ b/tests/accumulate.rs @@ -15,51 +15,57 @@ struct MyInput { struct Log(#[allow(dead_code)] String); #[salsa::tracked] -fn push_logs(db: &dyn LogDatabase, input: MyInput) { +fn push_logs(db: &dyn LogDatabase, input: MyInput) -> salsa::Result<()> { db.push_log(format!( "push_logs(a = {}, b = {})", - input.field_a(db), - input.field_b(db) + input.field_a(db)?, + input.field_b(db)? )); // We don't invoke `push_a_logs` (or `push_b_logs`) with a value of 0. // This allows us to test what happens a change in inputs causes a function not to be called at all. - if input.field_a(db) > 0 { - push_a_logs(db, input); + if input.field_a(db)? > 0 { + push_a_logs(db, input)?; } - if input.field_b(db) > 0 { - push_b_logs(db, input); + if input.field_b(db)? > 0 { + push_b_logs(db, input)?; } + + Ok(()) } #[salsa::tracked] -fn push_a_logs(db: &dyn LogDatabase, input: MyInput) { - let field_a = input.field_a(db); +fn push_a_logs(db: &dyn LogDatabase, input: MyInput) -> salsa::Result<()> { + let field_a = input.field_a(db)?; db.push_log(format!("push_a_logs({})", field_a)); for i in 0..field_a { Log(format!("log_a({} of {})", i, field_a)).accumulate(db); } + + Ok(()) } #[salsa::tracked] -fn push_b_logs(db: &dyn LogDatabase, input: MyInput) { - let field_a = input.field_b(db); +fn push_b_logs(db: &dyn LogDatabase, input: MyInput) -> salsa::Result<()> { + let field_a = input.field_b(db)?; db.push_log(format!("push_b_logs({})", field_a)); for i in 0..field_a { Log(format!("log_b({} of {})", i, field_a)).accumulate(db); } + + Ok(()) } #[test] -fn accumulate_once() { +fn accumulate_once() -> salsa::Result<()> { let db = common::LoggerDatabase::default(); // Just call accumulate on a base input to see what happens. let input = MyInput::new(&db, 2, 3); - let logs = push_logs::accumulated::(&db, input); + let logs = push_logs::accumulated::(&db, input)?; db.assert_logs(expect![[r#" [ "push_logs(a = 2, b = 3)", @@ -87,15 +93,16 @@ fn accumulate_once() { ), ]"#]] .assert_eq(&format!("{:#?}", logs)); + Ok(()) } #[test] -fn change_a_from_2_to_0() { +fn change_a_from_2_to_0() -> salsa::Result<()> { let mut db = common::LoggerDatabase::default(); // Accumulate logs for `a = 2` and `b = 3` let input = MyInput::new(&db, 2, 3); - let logs = push_logs::accumulated::(&db, input); + let logs = push_logs::accumulated::(&db, input)?; expect![[r#" [ Log( @@ -124,7 +131,7 @@ fn change_a_from_2_to_0() { // Change to `a = 0`, which means `push_logs` does not call `push_a_logs` at all input.set_field_a(&mut db).to(0); - let logs = push_logs::accumulated::(&db, input); + let logs = push_logs::accumulated::(&db, input)?; expect![[r#" [ Log( @@ -142,15 +149,16 @@ fn change_a_from_2_to_0() { [ "push_logs(a = 0, b = 3)", ]"#]]); + Ok(()) } #[test] -fn change_a_from_2_to_1() { +fn change_a_from_2_to_1() -> salsa::Result<()> { let mut db = LoggerDatabase::default(); // Accumulate logs for `a = 2` and `b = 3` let input = MyInput::new(&db, 2, 3); - let logs = push_logs::accumulated::(&db, input); + let logs = push_logs::accumulated::(&db, input)?; expect![[r#" [ Log( @@ -179,7 +187,7 @@ fn change_a_from_2_to_1() { // Change to `a = 1`, which means `push_logs` does not call `push_a_logs` at all input.set_field_a(&mut db).to(1); - let logs = push_logs::accumulated::(&db, input); + let logs = push_logs::accumulated::(&db, input)?; expect![[r#" [ Log( @@ -201,15 +209,16 @@ fn change_a_from_2_to_1() { "push_logs(a = 1, b = 3)", "push_a_logs(1)", ]"#]]); + Ok(()) } #[test] -fn get_a_logs_after_changing_b() { +fn get_a_logs_after_changing_b() -> salsa::Result<()> { let mut db = common::LoggerDatabase::default(); // Invoke `push_a_logs` with `a = 2` and `b = 3` (but `b` doesn't matter) let input = MyInput::new(&db, 2, 3); - let logs = push_a_logs::accumulated::(&db, input); + let logs = push_a_logs::accumulated::(&db, input)?; expect![[r#" [ Log( @@ -228,7 +237,7 @@ fn get_a_logs_after_changing_b() { // Changing `b` does not cause `push_a_logs` to re-execute // and we still get the same result input.set_field_b(&mut db).to(5); - let logs = push_a_logs::accumulated::(&db, input); + let logs = push_a_logs::accumulated::(&db, input)?; expect![[r#" [ Log( @@ -241,4 +250,5 @@ fn get_a_logs_after_changing_b() { "#]] .assert_debug_eq(&logs); db.assert_logs(expect!["[]"]); + Ok(()) } diff --git a/tests/compile-fail/panic-when-reading-fields-of-tracked-structs-from-older-revisions.rs b/tests/compile-fail/panic-when-reading-fields-of-tracked-structs-from-older-revisions.rs index 3dbb4f2f5..ec0e20ec1 100644 --- a/tests/compile-fail/panic-when-reading-fields-of-tracked-structs-from-older-revisions.rs +++ b/tests/compile-fail/panic-when-reading-fields-of-tracked-structs-from-older-revisions.rs @@ -11,14 +11,14 @@ struct MyTracked<'db> { } #[salsa::tracked] -fn tracked_fn<'db>(db: &'db dyn salsa::Database, input: MyInput) -> MyTracked<'db> { - MyTracked::new(db, input.field(db) / 2) +fn tracked_fn<'db>(db: &'db dyn salsa::Database, input: MyInput) -> salsa::Result> { + MyTracked::new(db, input.field(db)? / 2) } fn main() { let mut db = salsa::DatabaseImpl::new(); let input = MyInput::new(&db, 22); - let tracked = tracked_fn(&db, input); + let tracked = tracked_fn(&db, input).unwrap(); input.set_field(&mut db).to(24); - tracked.field(&db); // tracked comes from prior revision + tracked.field(&db).unwrap(); // tracked comes from prior revision } diff --git a/tests/compile-fail/panic-when-reading-fields-of-tracked-structs-from-older-revisions.stderr b/tests/compile-fail/panic-when-reading-fields-of-tracked-structs-from-older-revisions.stderr index d353dd649..9f2d9109d 100644 --- a/tests/compile-fail/panic-when-reading-fields-of-tracked-structs-from-older-revisions.stderr +++ b/tests/compile-fail/panic-when-reading-fields-of-tracked-structs-from-older-revisions.stderr @@ -1,9 +1,9 @@ error[E0502]: cannot borrow `db` as mutable because it is also borrowed as immutable --> tests/compile-fail/panic-when-reading-fields-of-tracked-structs-from-older-revisions.rs:22:21 | -21 | let tracked = tracked_fn(&db, input); +21 | let tracked = tracked_fn(&db, input).unwrap(); | --- immutable borrow occurs here 22 | input.set_field(&mut db).to(24); | ^^^^^^^ mutable borrow occurs here -23 | tracked.field(&db); // tracked comes from prior revision +23 | tracked.field(&db).unwrap(); // tracked comes from prior revision | ------- immutable borrow later used here diff --git a/tests/compile-fail/span-tracked-getter.rs b/tests/compile-fail/span-tracked-getter.rs index 245fd1008..7a7555196 100644 --- a/tests/compile-fail/span-tracked-getter.rs +++ b/tests/compile-fail/span-tracked-getter.rs @@ -4,12 +4,13 @@ pub struct MyTracked<'db> { } #[salsa::tracked] -fn my_fn(db: &dyn salsa::Database) { - let x = MyTracked::new(db, 22); - x.field(22); +fn my_fn(db: &dyn salsa::Database) -> salsa::Result<()> { + let x = MyTracked::new(db, 22)?; + x.field(22)?; + Ok(()) } fn main() { let mut db = salsa::DatabaseImpl::new(); - my_fn(&db); + my_fn(&db).unwrap(); } diff --git a/tests/compile-fail/span-tracked-getter.stderr b/tests/compile-fail/span-tracked-getter.stderr index fcf546c72..0876b9e43 100644 --- a/tests/compile-fail/span-tracked-getter.stderr +++ b/tests/compile-fail/span-tracked-getter.stderr @@ -1,7 +1,7 @@ error[E0308]: mismatched types --> tests/compile-fail/span-tracked-getter.rs:9:13 | -9 | x.field(22); +9 | x.field(22)?; | ----- ^^ expected `&_`, found integer | | | arguments to this method are incorrect @@ -18,13 +18,13 @@ note: method defined here | ^^^^^ help: consider borrowing here | -9 | x.field(&22); +9 | x.field(&22)?; | + warning: variable does not need to be mutable - --> tests/compile-fail/span-tracked-getter.rs:13:9 + --> tests/compile-fail/span-tracked-getter.rs:14:9 | -13 | let mut db = salsa::DatabaseImpl::new(); +14 | let mut db = salsa::DatabaseImpl::new(); | ----^^ | | | help: remove this `mut` diff --git a/tests/compile-fail/specify-does-not-work-if-the-key-is-a-salsa-input.rs b/tests/compile-fail/specify-does-not-work-if-the-key-is-a-salsa-input.rs index 97e279cde..6bd5370bb 100644 --- a/tests/compile-fail/specify-does-not-work-if-the-key-is-a-salsa-input.rs +++ b/tests/compile-fail/specify-does-not-work-if-the-key-is-a-salsa-input.rs @@ -13,8 +13,8 @@ struct MyTracked<'db> { } #[salsa::tracked(specify)] -fn tracked_fn<'db>(db: &'db dyn salsa::Database, input: MyInput) -> MyTracked<'db> { - MyTracked::new(db, input.field(db) * 2) +fn tracked_fn<'db>(db: &'db dyn salsa::Database, input: MyInput) -> salsa::Result> { + MyTracked::new(db, input.field(db)? * 2) } fn main() {} diff --git a/tests/compile-fail/specify-does-not-work-if-the-key-is-a-salsa-interned.rs b/tests/compile-fail/specify-does-not-work-if-the-key-is-a-salsa-interned.rs index 5c5feef8f..a9a82fc02 100644 --- a/tests/compile-fail/specify-does-not-work-if-the-key-is-a-salsa-interned.rs +++ b/tests/compile-fail/specify-does-not-work-if-the-key-is-a-salsa-interned.rs @@ -13,7 +13,7 @@ struct MyTracked<'db> { } #[salsa::tracked(specify)] -fn tracked_fn<'db>(db: &'db dyn salsa::Database, input: MyInterned<'db>) -> MyTracked<'db> { +fn tracked_fn<'db>(db: &'db dyn salsa::Database, input: MyInterned<'db>) -> salsa::Result> { MyTracked::new(db, input.field(db) * 2) } diff --git a/tests/compile-fail/tracked_fn_incompatibles.rs b/tests/compile-fail/tracked_fn_incompatibles.rs index 309e4fba0..116e68dfe 100644 --- a/tests/compile-fail/tracked_fn_incompatibles.rs +++ b/tests/compile-fail/tracked_fn_incompatibles.rs @@ -6,32 +6,31 @@ struct MyInput { } #[salsa::tracked(data = Data)] -fn tracked_fn_with_data(db: &dyn Db, input: MyInput) -> u32 { - input.field(db) * 2 +fn tracked_fn_with_data(db: &dyn Db, input: MyInput) -> salsa::Result { + Ok(input.field(db)? * 2) } #[salsa::tracked(db = Db)] -fn tracked_fn_with_db(db: &dyn Db, input: MyInput) -> u32 { - input.field(db) * 2 +fn tracked_fn_with_db(db: &dyn Db, input: MyInput) -> salsa::Result { + Ok(input.field(db)? * 2) } #[salsa::tracked(constructor = TrackedFn3)] -fn tracked_fn_with_constructor(db: &dyn Db, input: MyInput) -> u32 { - input.field(db) * 2 +fn tracked_fn_with_constructor(db: &dyn Db, input: MyInput) -> salsa::Result { + Ok(input.field(db)? * 2) } #[salsa::tracked] -fn tracked_fn_with_one_input(db: &dyn Db) -> u32 {} +fn tracked_fn_with_one_input(db: &dyn Db) -> salsa::Result {} #[salsa::tracked] -fn tracked_fn_with_receiver_not_applied_to_impl_block(&self, db: &dyn Db) -> u32 {} +fn tracked_fn_with_receiver_not_applied_to_impl_block(&self, db: &dyn Db) -> salsa::Result {} #[salsa::tracked(specify)] fn tracked_fn_with_too_many_arguments_for_specify( db: &dyn Db, input: MyInput, input: MyInput, -) -> u32 { -} +) -> salsa::Result {} fn main() {} diff --git a/tests/compile-fail/tracked_fn_incompatibles.stderr b/tests/compile-fail/tracked_fn_incompatibles.stderr index 5851bf7eb..3bf12c77c 100644 --- a/tests/compile-fail/tracked_fn_incompatibles.stderr +++ b/tests/compile-fail/tracked_fn_incompatibles.stderr @@ -19,7 +19,7 @@ error: `constructor` option not allowed here error: #[salsa::tracked] must also be applied to the impl block for tracked methods --> tests/compile-fail/tracked_fn_incompatibles.rs:27:55 | -27 | fn tracked_fn_with_receiver_not_applied_to_impl_block(&self, db: &dyn Db) -> u32 {} +27 | fn tracked_fn_with_receiver_not_applied_to_impl_block(&self, db: &dyn Db) -> salsa::Result {} | ^^^^^ error: only functions with a single salsa struct as their input can be specified @@ -33,5 +33,8 @@ error[E0308]: mismatched types | 23 | #[salsa::tracked] | ----------------- implicitly returns `()` as its body has no tail or `return` expression -24 | fn tracked_fn_with_one_input(db: &dyn Db) -> u32 {} - | ^^^ expected `u32`, found `()` +24 | fn tracked_fn_with_one_input(db: &dyn Db) -> salsa::Result {} + | ^^^^^^^^^^^^^^^^^^ expected `Result`, found `()` + | + = note: expected enum `Result` + found unit type `()` diff --git a/tests/compile-fail/tracked_method_on_untracked_impl.rs b/tests/compile-fail/tracked_method_on_untracked_impl.rs index c9e897ee4..8ff4cdef1 100644 --- a/tests/compile-fail/tracked_method_on_untracked_impl.rs +++ b/tests/compile-fail/tracked_method_on_untracked_impl.rs @@ -5,7 +5,7 @@ struct MyInput { impl MyInput { #[salsa::tracked] - fn tracked_method_on_untracked_impl(self, db: &dyn Db) -> u32 { + fn tracked_method_on_untracked_impl(self, db: &dyn Db) -> salsa::Result { input.field(db) } } diff --git a/tests/compile-fail/tracked_method_on_untracked_impl.stderr b/tests/compile-fail/tracked_method_on_untracked_impl.stderr index 2807c74db..7d3cd0740 100644 --- a/tests/compile-fail/tracked_method_on_untracked_impl.stderr +++ b/tests/compile-fail/tracked_method_on_untracked_impl.stderr @@ -1,5 +1,5 @@ error: #[salsa::tracked] must also be applied to the impl block for tracked methods --> tests/compile-fail/tracked_method_on_untracked_impl.rs:8:41 | -8 | fn tracked_method_on_untracked_impl(self, db: &dyn Db) -> u32 { +8 | fn tracked_method_on_untracked_impl(self, db: &dyn Db) -> salsa::Result { | ^^^^ diff --git a/tests/cycles.rs b/tests/cycles.rs index f07484188..1b665bd40 100644 --- a/tests/cycles.rs +++ b/tests/cycles.rs @@ -1,436 +1,438 @@ -#![allow(warnings)] +// FIXME -use std::panic::{RefUnwindSafe, UnwindSafe}; - -use expect_test::expect; -use salsa::DatabaseImpl; -use salsa::Durability; - -// Axes: -// -// Threading -// * Intra-thread -// * Cross-thread -- part of cycle is on one thread, part on another -// -// Recovery strategies: -// * Panic -// * Fallback -// * Mixed -- multiple strategies within cycle participants -// -// Across revisions: -// * N/A -- only one revision -// * Present in new revision, not old -// * Present in old revision, not new -// * Present in both revisions -// -// Dependencies -// * Tracked -// * Untracked -- cycle participant(s) contain untracked reads -// -// Layers -// * Direct -- cycle participant is directly invoked from test -// * Indirect -- invoked a query that invokes the cycle -// -// -// | Thread | Recovery | Old, New | Dep style | Layers | Test Name | -// | ------ | -------- | -------- | --------- | ------ | --------- | -// | Intra | Panic | N/A | Tracked | direct | cycle_memoized | -// | Intra | Panic | N/A | Untracked | direct | cycle_volatile | -// | Intra | Fallback | N/A | Tracked | direct | cycle_cycle | -// | Intra | Fallback | N/A | Tracked | indirect | inner_cycle | -// | Intra | Fallback | Both | Tracked | direct | cycle_revalidate | -// | Intra | Fallback | New | Tracked | direct | cycle_appears | -// | Intra | Fallback | Old | Tracked | direct | cycle_disappears | -// | Intra | Fallback | Old | Tracked | direct | cycle_disappears_durability | -// | Intra | Mixed | N/A | Tracked | direct | cycle_mixed_1 | -// | Intra | Mixed | N/A | Tracked | direct | cycle_mixed_2 | -// | Cross | Panic | N/A | Tracked | both | parallel/parallel_cycle_none_recover.rs | -// | Cross | Fallback | N/A | Tracked | both | parallel/parallel_cycle_one_recover.rs | -// | Cross | Fallback | N/A | Tracked | both | parallel/parallel_cycle_mid_recover.rs | -// | Cross | Fallback | N/A | Tracked | both | parallel/parallel_cycle_all_recover.rs | - -#[derive(PartialEq, Eq, Hash, Clone, Debug)] -struct Error { - cycle: Vec, -} - -use salsa::Database as Db; -use salsa::Setter; - -#[salsa::input] -struct MyInput {} - -#[salsa::tracked] -fn memoized_a(db: &dyn Db, input: MyInput) { - memoized_b(db, input) -} - -#[salsa::tracked] -fn memoized_b(db: &dyn Db, input: MyInput) { - memoized_a(db, input) -} - -#[salsa::tracked] -fn volatile_a(db: &dyn Db, input: MyInput) { - db.report_untracked_read(); - volatile_b(db, input) -} - -#[salsa::tracked] -fn volatile_b(db: &dyn Db, input: MyInput) { - db.report_untracked_read(); - volatile_a(db, input) -} - -/// The queries A, B, and C in `Database` can be configured -/// to invoke one another in arbitrary ways using this -/// enum. -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -enum CycleQuery { - None, - A, - B, - C, - AthenC, -} - -#[salsa::input] -struct ABC { - a: CycleQuery, - b: CycleQuery, - c: CycleQuery, -} - -impl CycleQuery { - fn invoke(self, db: &dyn Db, abc: ABC) -> Result<(), Error> { - match self { - CycleQuery::A => cycle_a(db, abc), - CycleQuery::B => cycle_b(db, abc), - CycleQuery::C => cycle_c(db, abc), - CycleQuery::AthenC => { - let _ = cycle_a(db, abc); - cycle_c(db, abc) - } - CycleQuery::None => Ok(()), - } - } -} - -#[salsa::tracked(recovery_fn=recover_a)] -fn cycle_a(db: &dyn Db, abc: ABC) -> Result<(), Error> { - abc.a(db).invoke(db, abc) -} - -fn recover_a(db: &dyn Db, cycle: &salsa::Cycle, abc: ABC) -> Result<(), Error> { - Err(Error { - cycle: cycle.participant_keys().map(|k| format!("{k:?}")).collect(), - }) -} - -#[salsa::tracked(recovery_fn=recover_b)] -fn cycle_b(db: &dyn Db, abc: ABC) -> Result<(), Error> { - abc.b(db).invoke(db, abc) -} - -fn recover_b(db: &dyn Db, cycle: &salsa::Cycle, abc: ABC) -> Result<(), Error> { - Err(Error { - cycle: cycle.participant_keys().map(|k| format!("{k:?}")).collect(), - }) -} - -#[salsa::tracked] -fn cycle_c(db: &dyn Db, abc: ABC) -> Result<(), Error> { - abc.c(db).invoke(db, abc) -} - -#[track_caller] -fn extract_cycle(f: impl FnOnce() + UnwindSafe) -> salsa::Cycle { - let v = std::panic::catch_unwind(f); - if let Err(d) = &v { - if let Some(cycle) = d.downcast_ref::() { - return cycle.clone(); - } - } - panic!("unexpected value: {:?}", v) -} - -#[test] -fn cycle_memoized() { - salsa::DatabaseImpl::new().attach(|db| { - let input = MyInput::new(db); - let cycle = extract_cycle(|| memoized_a(db, input)); - let expected = expect![[r#" - [ - memoized_a(Id(0)), - memoized_b(Id(0)), - ] - "#]]; - expected.assert_debug_eq(&cycle.all_participants(db)); - }) -} - -#[test] -fn cycle_volatile() { - salsa::DatabaseImpl::new().attach(|db| { - let input = MyInput::new(db); - let cycle = extract_cycle(|| volatile_a(db, input)); - let expected = expect![[r#" - [ - volatile_a(Id(0)), - volatile_b(Id(0)), - ] - "#]]; - expected.assert_debug_eq(&cycle.all_participants(db)); - }); -} - -#[test] -fn expect_cycle() { - // A --> B - // ^ | - // +-----+ - - salsa::DatabaseImpl::new().attach(|db| { - let abc = ABC::new(db, CycleQuery::B, CycleQuery::A, CycleQuery::None); - assert!(cycle_a(db, abc).is_err()); - }) -} - -#[test] -fn inner_cycle() { - // A --> B <-- C - // ^ | - // +-----+ - salsa::DatabaseImpl::new().attach(|db| { - let abc = ABC::new(db, CycleQuery::B, CycleQuery::A, CycleQuery::B); - let err = cycle_c(db, abc); - assert!(err.is_err()); - let expected = expect![[r#" - [ - "cycle_a(Id(0))", - "cycle_b(Id(0))", - ] - "#]]; - expected.assert_debug_eq(&err.unwrap_err().cycle); - }) -} - -#[test] -fn cycle_revalidate() { - // A --> B - // ^ | - // +-----+ - let mut db = salsa::DatabaseImpl::new(); - let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None); - assert!(cycle_a(&db, abc).is_err()); - abc.set_b(&mut db).to(CycleQuery::A); // same value as default - assert!(cycle_a(&db, abc).is_err()); -} - -#[test] -fn cycle_recovery_unchanged_twice() { - // A --> B - // ^ | - // +-----+ - let mut db = salsa::DatabaseImpl::new(); - let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None); - assert!(cycle_a(&db, abc).is_err()); - - abc.set_c(&mut db).to(CycleQuery::A); // force new revision - assert!(cycle_a(&db, abc).is_err()); -} - -#[test] -fn cycle_appears() { - let mut db = salsa::DatabaseImpl::new(); - // A --> B - let abc = ABC::new(&db, CycleQuery::B, CycleQuery::None, CycleQuery::None); - assert!(cycle_a(&db, abc).is_ok()); - - // A --> B - // ^ | - // +-----+ - abc.set_b(&mut db).to(CycleQuery::A); - assert!(cycle_a(&db, abc).is_err()); -} - -#[test] -fn cycle_disappears() { - let mut db = salsa::DatabaseImpl::new(); - - // A --> B - // ^ | - // +-----+ - let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None); - assert!(cycle_a(&db, abc).is_err()); - - // A --> B - abc.set_b(&mut db).to(CycleQuery::None); - assert!(cycle_a(&db, abc).is_ok()); -} - -/// A variant on `cycle_disappears` in which the values of -/// `a` and `b` are set with durability values. -/// If we are not careful, this could cause us to overlook -/// the fact that the cycle will no longer occur. -#[test] -fn cycle_disappears_durability() { - let mut db = salsa::DatabaseImpl::new(); - let abc = ABC::new( - &mut db, - CycleQuery::None, - CycleQuery::None, - CycleQuery::None, - ); - abc.set_a(&mut db) - .with_durability(Durability::LOW) - .to(CycleQuery::B); - abc.set_b(&mut db) - .with_durability(Durability::HIGH) - .to(CycleQuery::A); - - assert!(cycle_a(&db, abc).is_err()); - - // At this point, `a` read `LOW` input, and `b` read `HIGH` input. However, - // because `b` participates in the same cycle as `a`, its final durability - // should be `LOW`. - // - // Check that setting a `LOW` input causes us to re-execute `b` query, and - // observe that the cycle goes away. - abc.set_a(&mut db) - .with_durability(Durability::LOW) - .to(CycleQuery::None); - - assert!(cycle_b(&mut db, abc).is_ok()); -} - -#[test] -fn cycle_mixed_1() { - salsa::DatabaseImpl::new().attach(|db| { - // A --> B <-- C - // | ^ - // +-----+ - let abc = ABC::new(db, CycleQuery::B, CycleQuery::C, CycleQuery::B); - - let expected = expect![[r#" - [ - "cycle_b(Id(0))", - "cycle_c(Id(0))", - ] - "#]]; - expected.assert_debug_eq(&cycle_c(db, abc).unwrap_err().cycle); - }) -} - -#[test] -fn cycle_mixed_2() { - salsa::DatabaseImpl::new().attach(|db| { - // Configuration: - // - // A --> B --> C - // ^ | - // +-----------+ - let abc = ABC::new(db, CycleQuery::B, CycleQuery::C, CycleQuery::A); - let expected = expect![[r#" - [ - "cycle_a(Id(0))", - "cycle_b(Id(0))", - "cycle_c(Id(0))", - ] - "#]]; - expected.assert_debug_eq(&cycle_a(db, abc).unwrap_err().cycle); - }) -} - -#[test] -fn cycle_deterministic_order() { - // No matter whether we start from A or B, we get the same set of participants: - let f = || { - let mut db = salsa::DatabaseImpl::new(); - - // A --> B - // ^ | - // +-----+ - let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None); - (db, abc) - }; - let (db, abc) = f(); - let a = cycle_a(&db, abc); - let (db, abc) = f(); - let b = cycle_b(&db, abc); - let expected = expect![[r#" - ( - [ - "cycle_a(Id(0))", - "cycle_b(Id(0))", - ], - [ - "cycle_a(Id(0))", - "cycle_b(Id(0))", - ], - ) - "#]]; - expected.assert_debug_eq(&(a.unwrap_err().cycle, b.unwrap_err().cycle)); -} - -#[test] -fn cycle_multiple() { - // No matter whether we start from A or B, we get the same set of participants: - let mut db = salsa::DatabaseImpl::new(); - - // Configuration: - // - // A --> B <-- C - // ^ | ^ - // +-----+ | - // | | - // +-----+ - // - // Here, conceptually, B encounters a cycle with A and then - // recovers. - let abc = ABC::new(&db, CycleQuery::B, CycleQuery::AthenC, CycleQuery::A); - - let c = cycle_c(&db, abc); - let b = cycle_b(&db, abc); - let a = cycle_a(&db, abc); - let expected = expect![[r#" - ( - [ - "cycle_a(Id(0))", - "cycle_b(Id(0))", - ], - [ - "cycle_a(Id(0))", - "cycle_b(Id(0))", - ], - [ - "cycle_a(Id(0))", - "cycle_b(Id(0))", - ], - ) - "#]]; - expected.assert_debug_eq(&( - c.unwrap_err().cycle, - b.unwrap_err().cycle, - a.unwrap_err().cycle, - )); -} - -#[test] -fn cycle_recovery_set_but_not_participating() { - salsa::DatabaseImpl::new().attach(|db| { - // A --> C -+ - // ^ | - // +--+ - let abc = ABC::new(db, CycleQuery::C, CycleQuery::None, CycleQuery::C); - - // Here we expect C to panic and A not to recover: - let r = extract_cycle(|| drop(cycle_a(db, abc))); - let expected = expect![[r#" - [ - cycle_c(Id(0)), - ] - "#]]; - expected.assert_debug_eq(&r.all_participants(db)); - }) -} +// #![allow(warnings)] +// +// use std::panic::{RefUnwindSafe, UnwindSafe}; +// +// use expect_test::expect; +// use salsa::DatabaseImpl; +// use salsa::Durability; +// +// // Axes: +// // +// // Threading +// // * Intra-thread +// // * Cross-thread -- part of cycle is on one thread, part on another +// // +// // Recovery strategies: +// // * Panic +// // * Fallback +// // * Mixed -- multiple strategies within cycle participants +// // +// // Across revisions: +// // * N/A -- only one revision +// // * Present in new revision, not old +// // * Present in old revision, not new +// // * Present in both revisions +// // +// // Dependencies +// // * Tracked +// // * Untracked -- cycle participant(s) contain untracked reads +// // +// // Layers +// // * Direct -- cycle participant is directly invoked from test +// // * Indirect -- invoked a query that invokes the cycle +// // +// // +// // | Thread | Recovery | Old, New | Dep style | Layers | Test Name | +// // | ------ | -------- | -------- | --------- | ------ | --------- | +// // | Intra | Panic | N/A | Tracked | direct | cycle_memoized | +// // | Intra | Panic | N/A | Untracked | direct | cycle_volatile | +// // | Intra | Fallback | N/A | Tracked | direct | cycle_cycle | +// // | Intra | Fallback | N/A | Tracked | indirect | inner_cycle | +// // | Intra | Fallback | Both | Tracked | direct | cycle_revalidate | +// // | Intra | Fallback | New | Tracked | direct | cycle_appears | +// // | Intra | Fallback | Old | Tracked | direct | cycle_disappears | +// // | Intra | Fallback | Old | Tracked | direct | cycle_disappears_durability | +// // | Intra | Mixed | N/A | Tracked | direct | cycle_mixed_1 | +// // | Intra | Mixed | N/A | Tracked | direct | cycle_mixed_2 | +// // | Cross | Panic | N/A | Tracked | both | parallel/parallel_cycle_none_recover.rs | +// // | Cross | Fallback | N/A | Tracked | both | parallel/parallel_cycle_one_recover.rs | +// // | Cross | Fallback | N/A | Tracked | both | parallel/parallel_cycle_mid_recover.rs | +// // | Cross | Fallback | N/A | Tracked | both | parallel/parallel_cycle_all_recover.rs | +// +// #[derive(PartialEq, Eq, Hash, Clone, Debug)] +// struct Error { +// cycle: Vec, +// } +// +// use salsa::Database as Db; +// use salsa::Setter; +// +// #[salsa::input] +// struct MyInput {} +// +// #[salsa::tracked] +// fn memoized_a(db: &dyn Db, input: MyInput) -> salsa::Result<()> { +// memoized_b(db, input) +// } +// +// #[salsa::tracked] +// fn memoized_b(db: &dyn Db, input: MyInput) -> salsa::Result<()> { +// memoized_a(db, input) +// } +// +// #[salsa::tracked] +// fn volatile_a(db: &dyn Db, input: MyInput) -> salsa::Result<()> { +// db.report_untracked_read(); +// volatile_b(db, input) +// } +// +// #[salsa::tracked] +// fn volatile_b(db: &dyn Db, input: MyInput) -> salsa::Result<()> { +// db.report_untracked_read(); +// volatile_a(db, input) +// } +// +// /// The queries A, B, and C in `Database` can be configured +// /// to invoke one another in arbitrary ways using this +// /// enum. +// #[derive(Debug, Copy, Clone, PartialEq, Eq)] +// enum CycleQuery { +// None, +// A, +// B, +// C, +// AthenC, +// } +// +// #[salsa::input] +// struct ABC { +// a: CycleQuery, +// b: CycleQuery, +// c: CycleQuery, +// } +// +// impl CycleQuery { +// fn invoke(self, db: &dyn Db, abc: ABC) -> Result<(), Error> { +// match self { +// CycleQuery::A => cycle_a(db, abc), +// CycleQuery::B => cycle_b(db, abc), +// CycleQuery::C => cycle_c(db, abc), +// CycleQuery::AthenC => { +// let _ = cycle_a(db, abc); +// cycle_c(db, abc) +// } +// CycleQuery::None => Ok(()), +// } +// } +// } +// +// #[salsa::tracked(recovery_fn=recover_a)] +// fn cycle_a(db: &dyn Db, abc: ABC) -> salsa::Result> { +// abc.a(db).invoke(db, abc) +// } +// +// fn recover_a(db: &dyn Db, cycle: &salsa::Cycle, abc: ABC) -> Result<(), Error> { +// Err(Error { +// cycle: cycle.participant_keys().map(|k| format!("{k:?}")).collect(), +// }) +// } +// +// #[salsa::tracked(recovery_fn=recover_b)] +// fn cycle_b(db: &dyn Db, abc: ABC) -> Result<(), Error> { +// abc.b(db).invoke(db, abc) +// } +// +// fn recover_b(db: &dyn Db, cycle: &salsa::Cycle, abc: ABC) -> Result<(), Error> { +// Err(Error { +// cycle: cycle.participant_keys().map(|k| format!("{k:?}")).collect(), +// }) +// } +// +// #[salsa::tracked] +// fn cycle_c(db: &dyn Db, abc: ABC) -> Result<(), Error> { +// abc.c(db).invoke(db, abc) +// } +// +// #[track_caller] +// fn extract_cycle(f: impl FnOnce() + UnwindSafe) -> salsa::Cycle { +// let v = std::panic::catch_unwind(f); +// if let Err(d) = &v { +// if let Some(cycle) = d.downcast_ref::() { +// return cycle.clone(); +// } +// } +// panic!("unexpected value: {:?}", v) +// } +// +// #[test] +// fn cycle_memoized() { +// salsa::DatabaseImpl::new().attach(|db| { +// let input = MyInput::new(db); +// let cycle = extract_cycle(|| memoized_a(db, input)); +// let expected = expect![[r#" +// [ +// memoized_a(Id(0)), +// memoized_b(Id(0)), +// ] +// "#]]; +// expected.assert_debug_eq(&cycle.all_participants(db)); +// }) +// } +// +// #[test] +// fn cycle_volatile() { +// salsa::DatabaseImpl::new().attach(|db| { +// let input = MyInput::new(db); +// let cycle = extract_cycle(|| volatile_a(db, input)); +// let expected = expect![[r#" +// [ +// volatile_a(Id(0)), +// volatile_b(Id(0)), +// ] +// "#]]; +// expected.assert_debug_eq(&cycle.all_participants(db)); +// }); +// } +// +// #[test] +// fn expect_cycle() { +// // A --> B +// // ^ | +// // +-----+ +// +// salsa::DatabaseImpl::new().attach(|db| { +// let abc = ABC::new(db, CycleQuery::B, CycleQuery::A, CycleQuery::None); +// assert!(cycle_a(db, abc).is_err()); +// }) +// } +// +// #[test] +// fn inner_cycle() { +// // A --> B <-- C +// // ^ | +// // +-----+ +// salsa::DatabaseImpl::new().attach(|db| { +// let abc = ABC::new(db, CycleQuery::B, CycleQuery::A, CycleQuery::B); +// let err = cycle_c(db, abc); +// assert!(err.is_err()); +// let expected = expect![[r#" +// [ +// "cycle_a(Id(0))", +// "cycle_b(Id(0))", +// ] +// "#]]; +// expected.assert_debug_eq(&err.unwrap_err().cycle); +// }) +// } +// +// #[test] +// fn cycle_revalidate() { +// // A --> B +// // ^ | +// // +-----+ +// let mut db = salsa::DatabaseImpl::new(); +// let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None); +// assert!(cycle_a(&db, abc).is_err()); +// abc.set_b(&mut db).to(CycleQuery::A); // same value as default +// assert!(cycle_a(&db, abc).is_err()); +// } +// +// #[test] +// fn cycle_recovery_unchanged_twice() { +// // A --> B +// // ^ | +// // +-----+ +// let mut db = salsa::DatabaseImpl::new(); +// let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None); +// assert!(cycle_a(&db, abc).is_err()); +// +// abc.set_c(&mut db).to(CycleQuery::A); // force new revision +// assert!(cycle_a(&db, abc).is_err()); +// } +// +// #[test] +// fn cycle_appears() { +// let mut db = salsa::DatabaseImpl::new(); +// // A --> B +// let abc = ABC::new(&db, CycleQuery::B, CycleQuery::None, CycleQuery::None); +// assert!(cycle_a(&db, abc).is_ok()); +// +// // A --> B +// // ^ | +// // +-----+ +// abc.set_b(&mut db).to(CycleQuery::A); +// assert!(cycle_a(&db, abc).is_err()); +// } +// +// #[test] +// fn cycle_disappears() { +// let mut db = salsa::DatabaseImpl::new(); +// +// // A --> B +// // ^ | +// // +-----+ +// let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None); +// assert!(cycle_a(&db, abc).is_err()); +// +// // A --> B +// abc.set_b(&mut db).to(CycleQuery::None); +// assert!(cycle_a(&db, abc).is_ok()); +// } +// +// /// A variant on `cycle_disappears` in which the values of +// /// `a` and `b` are set with durability values. +// /// If we are not careful, this could cause us to overlook +// /// the fact that the cycle will no longer occur. +// #[test] +// fn cycle_disappears_durability() { +// let mut db = salsa::DatabaseImpl::new(); +// let abc = ABC::new( +// &mut db, +// CycleQuery::None, +// CycleQuery::None, +// CycleQuery::None, +// ); +// abc.set_a(&mut db) +// .with_durability(Durability::LOW) +// .to(CycleQuery::B); +// abc.set_b(&mut db) +// .with_durability(Durability::HIGH) +// .to(CycleQuery::A); +// +// assert!(cycle_a(&db, abc).is_err()); +// +// // At this point, `a` read `LOW` input, and `b` read `HIGH` input. However, +// // because `b` participates in the same cycle as `a`, its final durability +// // should be `LOW`. +// // +// // Check that setting a `LOW` input causes us to re-execute `b` query, and +// // observe that the cycle goes away. +// abc.set_a(&mut db) +// .with_durability(Durability::LOW) +// .to(CycleQuery::None); +// +// assert!(cycle_b(&mut db, abc).is_ok()); +// } +// +// #[test] +// fn cycle_mixed_1() { +// salsa::DatabaseImpl::new().attach(|db| { +// // A --> B <-- C +// // | ^ +// // +-----+ +// let abc = ABC::new(db, CycleQuery::B, CycleQuery::C, CycleQuery::B); +// +// let expected = expect![[r#" +// [ +// "cycle_b(Id(0))", +// "cycle_c(Id(0))", +// ] +// "#]]; +// expected.assert_debug_eq(&cycle_c(db, abc).unwrap_err().cycle); +// }) +// } +// +// #[test] +// fn cycle_mixed_2() { +// salsa::DatabaseImpl::new().attach(|db| { +// // Configuration: +// // +// // A --> B --> C +// // ^ | +// // +-----------+ +// let abc = ABC::new(db, CycleQuery::B, CycleQuery::C, CycleQuery::A); +// let expected = expect![[r#" +// [ +// "cycle_a(Id(0))", +// "cycle_b(Id(0))", +// "cycle_c(Id(0))", +// ] +// "#]]; +// expected.assert_debug_eq(&cycle_a(db, abc).unwrap_err().cycle); +// }) +// } +// +// #[test] +// fn cycle_deterministic_order() { +// // No matter whether we start from A or B, we get the same set of participants: +// let f = || { +// let mut db = salsa::DatabaseImpl::new(); +// +// // A --> B +// // ^ | +// // +-----+ +// let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None); +// (db, abc) +// }; +// let (db, abc) = f(); +// let a = cycle_a(&db, abc); +// let (db, abc) = f(); +// let b = cycle_b(&db, abc); +// let expected = expect![[r#" +// ( +// [ +// "cycle_a(Id(0))", +// "cycle_b(Id(0))", +// ], +// [ +// "cycle_a(Id(0))", +// "cycle_b(Id(0))", +// ], +// ) +// "#]]; +// expected.assert_debug_eq(&(a.unwrap_err().cycle, b.unwrap_err().cycle)); +// } +// +// #[test] +// fn cycle_multiple() { +// // No matter whether we start from A or B, we get the same set of participants: +// let mut db = salsa::DatabaseImpl::new(); +// +// // Configuration: +// // +// // A --> B <-- C +// // ^ | ^ +// // +-----+ | +// // | | +// // +-----+ +// // +// // Here, conceptually, B encounters a cycle with A and then +// // recovers. +// let abc = ABC::new(&db, CycleQuery::B, CycleQuery::AthenC, CycleQuery::A); +// +// let c = cycle_c(&db, abc); +// let b = cycle_b(&db, abc); +// let a = cycle_a(&db, abc); +// let expected = expect![[r#" +// ( +// [ +// "cycle_a(Id(0))", +// "cycle_b(Id(0))", +// ], +// [ +// "cycle_a(Id(0))", +// "cycle_b(Id(0))", +// ], +// [ +// "cycle_a(Id(0))", +// "cycle_b(Id(0))", +// ], +// ) +// "#]]; +// expected.assert_debug_eq(&( +// c.unwrap_err().cycle, +// b.unwrap_err().cycle, +// a.unwrap_err().cycle, +// )); +// } +// +// #[test] +// fn cycle_recovery_set_but_not_participating() { +// salsa::DatabaseImpl::new().attach(|db| { +// // A --> C -+ +// // ^ | +// // +--+ +// let abc = ABC::new(db, CycleQuery::C, CycleQuery::None, CycleQuery::C); +// +// // Here we expect C to panic and A not to recover: +// let r = extract_cycle(|| drop(cycle_a(db, abc))); +// let expected = expect![[r#" +// [ +// cycle_c(Id(0)), +// ] +// "#]]; +// expected.assert_debug_eq(&r.all_participants(db)); +// }) +// } diff --git a/tests/debug.rs b/tests/debug.rs index 3c2c896c1..6d84eaa5c 100644 --- a/tests/debug.rs +++ b/tests/debug.rs @@ -2,6 +2,7 @@ use expect_test::expect; use salsa::{Database, Setter}; +use std::fmt::Error; #[salsa::input] struct MyInput { @@ -36,20 +37,20 @@ fn input() { } #[salsa::tracked] -fn leak_debug_string(_db: &dyn salsa::Database, input: MyInput) -> String { - format!("{input:?}") +fn leak_debug_string(_db: &dyn salsa::Database, input: MyInput) -> salsa::Result { + Ok(format!("{input:?}")) } /// Test that field reads that occur as part of `Debug` are not tracked. /// Intentionally leaks the debug string. /// Don't try this at home, kids. #[test] -fn untracked_dependencies() { +fn untracked_dependencies() -> salsa::Result<()> { let mut db = salsa::DatabaseImpl::new(); let input = MyInput::new(&db, 22); - let s = leak_debug_string(&db, input); + let s = leak_debug_string(&db, input)?; expect![[r#" "MyInput { [salsa id]: Id(0), field: 22 }" "#]] @@ -59,8 +60,10 @@ fn untracked_dependencies() { // check that we reuse the cached result for debug string // even though the dependency changed. - let s = leak_debug_string(&db, input); + let s = leak_debug_string(&db, input)?; assert!(s.contains(", field: 22 }")); + + Ok(()) } #[salsa::tracked(no_debug)] @@ -72,27 +75,38 @@ struct DerivedCustom<'db> { impl std::fmt::Debug for DerivedCustom<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { salsa::with_attached_database(|db| { - write!(f, "{:?} / {:?}", self.my_input(db), self.value(db)) + write!( + f, + "{:?} / {:?}", + self.my_input(db).map_err(|_| Error)?, + self.value(db).map_err(|_| Error)? + ) }) .unwrap_or_else(|| f.debug_tuple("DerivedCustom").finish()) } } #[salsa::tracked] -fn leak_derived_custom(db: &dyn salsa::Database, input: MyInput, value: u32) -> String { - let c = DerivedCustom::new(db, input, value); - format!("{c:?}") +fn leak_derived_custom( + db: &dyn salsa::Database, + input: MyInput, + value: u32, +) -> salsa::Result { + let c = DerivedCustom::new(db, input, value)?; + Ok(format!("{c:?}")) } #[test] -fn custom_debug_impl() { +fn custom_debug_impl() -> salsa::Result<()> { let db = salsa::DatabaseImpl::new(); let input = MyInput::new(&db, 22); - let s = leak_derived_custom(&db, input, 23); + let s = leak_derived_custom(&db, input, 23)?; expect![[r#" "MyInput { [salsa id]: Id(0), field: 22 } / 23" "#]] .assert_debug_eq(&s); + + Ok(()) } diff --git a/tests/deletion-cascade.rs b/tests/deletion-cascade.rs index 7b3b2211e..bdff44939 100644 --- a/tests/deletion-cascade.rs +++ b/tests/deletion-cascade.rs @@ -15,13 +15,13 @@ struct MyInput { } #[salsa::tracked] -fn final_result(db: &dyn LogDatabase, input: MyInput) -> u32 { +fn final_result(db: &dyn LogDatabase, input: MyInput) -> salsa::Result { db.push_log(format!("final_result({:?})", input)); let mut sum = 0; - for tracked_struct in create_tracked_structs(db, input) { - sum += contribution_from_struct(db, tracked_struct); + for tracked_struct in create_tracked_structs(db, input)? { + sum += contribution_from_struct(db, tracked_struct)?; } - sum + Ok(sum) } #[salsa::tracked] @@ -30,31 +30,42 @@ struct MyTracked<'db> { } #[salsa::tracked] -fn create_tracked_structs(db: &dyn LogDatabase, input: MyInput) -> Vec> { +fn create_tracked_structs( + db: &dyn LogDatabase, + input: MyInput, +) -> salsa::Result>> { db.push_log(format!("intermediate_result({:?})", input)); - (0..input.field(db)) - .map(|i| MyTracked::new(db, i)) - .collect() + + let mut result = Vec::new(); + + for i in 0..input.field(db)? { + result.push(MyTracked::new(db, i)?); + } + + Ok(result) } #[salsa::tracked] -fn contribution_from_struct<'db>(db: &'db dyn LogDatabase, tracked: MyTracked<'db>) -> u32 { - let m = MyTracked::new(db, tracked.field(db)); - copy_field(db, m) * 2 +fn contribution_from_struct<'db>( + db: &'db dyn LogDatabase, + tracked: MyTracked<'db>, +) -> salsa::Result { + let m = MyTracked::new(db, tracked.field(db)?)?; + Ok(copy_field(db, m)? * 2) } #[salsa::tracked] -fn copy_field<'db>(db: &'db dyn LogDatabase, tracked: MyTracked<'db>) -> u32 { +fn copy_field<'db>(db: &'db dyn LogDatabase, tracked: MyTracked<'db>) -> salsa::Result { tracked.field(db) } #[test] -fn basic() { +fn basic() -> salsa::Result<()> { let mut db = common::DiscardLoggerDatabase::default(); // Creates 3 tracked structs let input = MyInput::new(&db, 3); - assert_eq!(final_result(&db, input), 2 * 2 + 2); + assert_eq!(final_result(&db, input)?, 2 * 2 + 2); db.assert_logs(expect![[r#" [ "final_result(MyInput { [salsa id]: Id(0), field: 3 })", @@ -76,7 +87,7 @@ fn basic() { // * the `copy_field` result input.set_field(&mut db).to(2); - assert_eq!(final_result(&db, input), 2); + assert_eq!(final_result(&db, input)?, 2); db.assert_logs(expect![[r#" [ "intermediate_result(MyInput { [salsa id]: Id(0), field: 2 })", @@ -87,4 +98,6 @@ fn basic() { "salsa_event(DidDiscard { key: copy_field(Id(405)) })", "final_result(MyInput { [salsa id]: Id(0), field: 2 })", ]"#]]); + + Ok(()) } diff --git a/tests/deletion-drops.rs b/tests/deletion-drops.rs index 57811569c..947441e31 100644 --- a/tests/deletion-drops.rs +++ b/tests/deletion-drops.rs @@ -43,19 +43,19 @@ impl Drop for Bomb { #[salsa::tracked] impl MyInput { #[salsa::tracked] - fn create_tracked_struct(self, db: &dyn Database) -> MyTracked<'_> { + fn create_tracked_struct(self, db: &dyn Database) -> salsa::Result> { MyTracked::new( db, - self.identity(db), + self.identity(db)?, Bomb { - identity: self.identity(db), + identity: self.identity(db)?, }, ) } } #[test] -fn deletion_drops() { +fn deletion_drops() -> salsa::Result<()> { let mut db = salsa::DatabaseImpl::new(); let input = MyInput::new(&db, 22); @@ -65,8 +65,8 @@ fn deletion_drops() { "#]] .assert_debug_eq(&dropped()); - let tracked_struct = input.create_tracked_struct(&db); - assert_eq!(tracked_struct.field(&db).identity, 22); + let tracked_struct = input.create_tracked_struct(&db)?; + assert_eq!(tracked_struct.field(&db)?.identity, 22); expect_test::expect![[r#" [] @@ -81,8 +81,8 @@ fn deletion_drops() { .assert_debug_eq(&dropped()); // Now that we execute with rev = 44, the old id is put on the free list - let tracked_struct = input.create_tracked_struct(&db); - assert_eq!(tracked_struct.field(&db).identity, 44); + let tracked_struct = input.create_tracked_struct(&db)?; + assert_eq!(tracked_struct.field(&db)?.identity, 44); expect_test::expect![[r#" [] @@ -91,7 +91,7 @@ fn deletion_drops() { // When we execute again with `input1`, that id is re-used, so the old value is deleted let input1 = MyInput::new(&db, 66); - let _tracked_struct1 = input1.create_tracked_struct(&db); + let _tracked_struct1 = input1.create_tracked_struct(&db)?; expect_test::expect![[r#" [ @@ -99,4 +99,6 @@ fn deletion_drops() { ] "#]] .assert_debug_eq(&dropped()); + + Ok(()) } diff --git a/tests/deletion.rs b/tests/deletion.rs index 0aef146c1..0a7e52c32 100644 --- a/tests/deletion.rs +++ b/tests/deletion.rs @@ -15,13 +15,13 @@ struct MyInput { } #[salsa::tracked] -fn final_result(db: &dyn LogDatabase, input: MyInput) -> u32 { +fn final_result(db: &dyn LogDatabase, input: MyInput) -> salsa::Result { db.push_log(format!("final_result({:?})", input)); let mut sum = 0; - for tracked_struct in create_tracked_structs(db, input) { - sum += contribution_from_struct(db, tracked_struct); + for tracked_struct in create_tracked_structs(db, input)? { + sum += contribution_from_struct(db, tracked_struct)?; } - sum + Ok(sum) } #[salsa::tracked] @@ -30,25 +30,32 @@ struct MyTracked<'db> { } #[salsa::tracked] -fn create_tracked_structs(db: &dyn LogDatabase, input: MyInput) -> Vec> { +fn create_tracked_structs( + db: &dyn LogDatabase, + input: MyInput, +) -> salsa::Result>> { db.push_log(format!("intermediate_result({:?})", input)); - (0..input.field(db)) + + (0..input.field(db)?) .map(|i| MyTracked::new(db, i)) .collect() } #[salsa::tracked] -fn contribution_from_struct<'db>(db: &'db dyn LogDatabase, tracked: MyTracked<'db>) -> u32 { - tracked.field(db) * 2 +fn contribution_from_struct<'db>( + db: &'db dyn LogDatabase, + tracked: MyTracked<'db>, +) -> salsa::Result { + Ok(tracked.field(db)? * 2) } #[test] -fn basic() { +fn basic() -> salsa::Result<()> { let mut db = common::DiscardLoggerDatabase::default(); // Creates 3 tracked structs let input = MyInput::new(&db, 3); - assert_eq!(final_result(&db, input), 2 * 2 + 2); + assert_eq!(final_result(&db, input)?, 2 * 2 + 2); db.assert_logs(expect![[r#" [ "final_result(MyInput { [salsa id]: Id(0), field: 3 })", @@ -63,7 +70,7 @@ fn basic() { // * the struct's field // * the `contribution_from_struct` result input.set_field(&mut db).to(2); - assert_eq!(final_result(&db, input), 2); + assert_eq!(final_result(&db, input)?, 2); db.assert_logs(expect![[r#" [ "intermediate_result(MyInput { [salsa id]: Id(0), field: 2 })", @@ -72,4 +79,6 @@ fn basic() { "salsa_event(DidDiscard { key: contribution_from_struct(Id(402)) })", "final_result(MyInput { [salsa id]: Id(0), field: 2 })", ]"#]]); + + Ok(()) } diff --git a/tests/elided-lifetime-in-tracked-fn.rs b/tests/elided-lifetime-in-tracked-fn.rs index f62c23a87..9461e790a 100644 --- a/tests/elided-lifetime-in-tracked-fn.rs +++ b/tests/elided-lifetime-in-tracked-fn.rs @@ -14,9 +14,9 @@ struct MyInput { } #[salsa::tracked] -fn final_result(db: &dyn LogDatabase, input: MyInput) -> u32 { +fn final_result(db: &dyn LogDatabase, input: MyInput) -> salsa::Result { db.push_log(format!("final_result({:?})", input)); - intermediate_result(db, input).field(db) * 2 + Ok(intermediate_result(db, input)?.field(db)? * 2) } #[salsa::tracked] @@ -25,17 +25,17 @@ struct MyTracked<'db> { } #[salsa::tracked] -fn intermediate_result(db: &dyn LogDatabase, input: MyInput) -> MyTracked<'_> { +fn intermediate_result(db: &dyn LogDatabase, input: MyInput) -> salsa::Result> { db.push_log(format!("intermediate_result({:?})", input)); - MyTracked::new(db, input.field(db) / 2) + MyTracked::new(db, input.field(db)? / 2) } #[test] -fn execute() { +fn execute() -> salsa::Result<()> { let mut db = common::LoggerDatabase::default(); let input = MyInput::new(&db, 22); - assert_eq!(final_result(&db, input), 22); + assert_eq!(final_result(&db, input)?, 22); db.assert_logs(expect![[r#" [ "final_result(MyInput { [salsa id]: Id(0), field: 22 })", @@ -45,17 +45,19 @@ fn execute() { // Intermediate result is the same, so final result does // not need to be recomputed: input.set_field(&mut db).to(23); - assert_eq!(final_result(&db, input), 22); + assert_eq!(final_result(&db, input)?, 22); db.assert_logs(expect![[r#" [ "intermediate_result(MyInput { [salsa id]: Id(0), field: 23 })", ]"#]]); input.set_field(&mut db).to(24); - assert_eq!(final_result(&db, input), 24); + assert_eq!(final_result(&db, input)?, 24); db.assert_logs(expect![[r#" [ "intermediate_result(MyInput { [salsa id]: Id(0), field: 24 })", "final_result(MyInput { [salsa id]: Id(0), field: 24 })", ]"#]]); + + Ok(()) } diff --git a/tests/expect_reuse_field_x_of_a_tracked_struct_changes_but_fn_depends_on_field_y.rs b/tests/expect_reuse_field_x_of_a_tracked_struct_changes_but_fn_depends_on_field_y.rs index fb62e1c5d..2cc1fd4de 100644 --- a/tests/expect_reuse_field_x_of_a_tracked_struct_changes_but_fn_depends_on_field_y.rs +++ b/tests/expect_reuse_field_x_of_a_tracked_struct_changes_but_fn_depends_on_field_y.rs @@ -15,15 +15,15 @@ struct MyInput { } #[salsa::tracked] -fn final_result_depends_on_x(db: &dyn LogDatabase, input: MyInput) -> u32 { +fn final_result_depends_on_x(db: &dyn LogDatabase, input: MyInput) -> salsa::Result { db.push_log(format!("final_result_depends_on_x({:?})", input)); - intermediate_result(db, input).x(db) * 2 + Ok(intermediate_result(db, input)?.x(db)? * 2) } #[salsa::tracked] -fn final_result_depends_on_y(db: &dyn LogDatabase, input: MyInput) -> u32 { +fn final_result_depends_on_y(db: &dyn LogDatabase, input: MyInput) -> salsa::Result { db.push_log(format!("final_result_depends_on_y({:?})", input)); - intermediate_result(db, input).y(db) * 2 + Ok(intermediate_result(db, input)?.y(db)? * 2) } #[salsa::tracked] @@ -33,12 +33,12 @@ struct MyTracked<'db> { } #[salsa::tracked] -fn intermediate_result(db: &dyn LogDatabase, input: MyInput) -> MyTracked<'_> { - MyTracked::new(db, (input.field(db) + 1) / 2, input.field(db) / 2) +fn intermediate_result(db: &dyn LogDatabase, input: MyInput) -> salsa::Result> { + MyTracked::new(db, (input.field(db)? + 1) / 2, input.field(db)? / 2) } #[test] -fn execute() { +fn execute() -> salsa::Result<()> { // x = (input.field + 1) / 2 // y = input.field / 2 // final_result_depends_on_x = x * 2 = (input.field + 1) / 2 * 2 @@ -49,13 +49,13 @@ fn execute() { // x = (22 + 1) / 2 = 11 // y = 22 / 2 = 11 let input = MyInput::new(&db, 22); - assert_eq!(final_result_depends_on_x(&db, input), 22); + assert_eq!(final_result_depends_on_x(&db, input)?, 22); db.assert_logs(expect![[r#" [ "final_result_depends_on_x(MyInput { [salsa id]: Id(0), field: 22 })", ]"#]]); - assert_eq!(final_result_depends_on_y(&db, input), 22); + assert_eq!(final_result_depends_on_y(&db, input)?, 22); db.assert_logs(expect![[r#" [ "final_result_depends_on_y(MyInput { [salsa id]: Id(0), field: 22 })", @@ -65,7 +65,7 @@ fn execute() { // x = (23 + 1) / 2 = 12 // Intermediate result x changes, so final result depends on x // needs to be recomputed; - assert_eq!(final_result_depends_on_x(&db, input), 24); + assert_eq!(final_result_depends_on_x(&db, input)?, 24); db.assert_logs(expect![[r#" [ "final_result_depends_on_x(MyInput { [salsa id]: Id(0), field: 23 })", @@ -74,6 +74,8 @@ fn execute() { // y = 23 / 2 = 11 // Intermediate result y is the same, so final result depends on y // does not need to be recomputed; - assert_eq!(final_result_depends_on_y(&db, input), 22); + assert_eq!(final_result_depends_on_y(&db, input)?, 22); db.assert_logs(expect!["[]"]); + + Ok(()) } diff --git a/tests/expect_reuse_field_x_of_an_input_changes_but_fn_depends_on_field_y.rs b/tests/expect_reuse_field_x_of_an_input_changes_but_fn_depends_on_field_y.rs index a18447958..2dd893e46 100644 --- a/tests/expect_reuse_field_x_of_an_input_changes_but_fn_depends_on_field_y.rs +++ b/tests/expect_reuse_field_x_of_an_input_changes_but_fn_depends_on_field_y.rs @@ -16,30 +16,30 @@ struct MyInput { } #[salsa::tracked] -fn result_depends_on_x(db: &dyn LogDatabase, input: MyInput) -> u32 { +fn result_depends_on_x(db: &dyn LogDatabase, input: MyInput) -> salsa::Result { db.push_log(format!("result_depends_on_x({:?})", input)); - input.x(db) + 1 + Ok(input.x(db)? + 1) } #[salsa::tracked] -fn result_depends_on_y(db: &dyn LogDatabase, input: MyInput) -> u32 { +fn result_depends_on_y(db: &dyn LogDatabase, input: MyInput) -> salsa::Result { db.push_log(format!("result_depends_on_y({:?})", input)); - input.y(db) - 1 + Ok(input.y(db)? - 1) } #[test] -fn execute() { +fn execute() -> salsa::Result<()> { // result_depends_on_x = x + 1 // result_depends_on_y = y - 1 let mut db = common::LoggerDatabase::default(); let input = MyInput::new(&db, 22, 33); - assert_eq!(result_depends_on_x(&db, input), 23); + assert_eq!(result_depends_on_x(&db, input)?, 23); db.assert_logs(expect![[r#" [ "result_depends_on_x(MyInput { [salsa id]: Id(0), x: 22, y: 33 })", ]"#]]); - assert_eq!(result_depends_on_y(&db, input), 32); + assert_eq!(result_depends_on_y(&db, input)?, 32); db.assert_logs(expect![[r#" [ "result_depends_on_y(MyInput { [salsa id]: Id(0), x: 22, y: 33 })", @@ -47,7 +47,7 @@ fn execute() { input.set_x(&mut db).to(23); // input x changes, so result depends on x needs to be recomputed; - assert_eq!(result_depends_on_x(&db, input), 24); + assert_eq!(result_depends_on_x(&db, input)?, 24); db.assert_logs(expect![[r#" [ "result_depends_on_x(MyInput { [salsa id]: Id(0), x: 23, y: 33 })", @@ -55,6 +55,8 @@ fn execute() { // input y is the same, so result depends on y // does not need to be recomputed; - assert_eq!(result_depends_on_y(&db, input), 32); + assert_eq!(result_depends_on_y(&db, input)?, 32); db.assert_logs(expect!["[]"]); + + Ok(()) } diff --git a/tests/hello_world.rs b/tests/hello_world.rs index 04cfdee99..0aeba39fe 100644 --- a/tests/hello_world.rs +++ b/tests/hello_world.rs @@ -14,9 +14,9 @@ struct MyInput { } #[salsa::tracked] -fn final_result(db: &dyn LogDatabase, input: MyInput) -> u32 { +fn final_result(db: &dyn LogDatabase, input: MyInput) -> salsa::Result { db.push_log(format!("final_result({:?})", input)); - intermediate_result(db, input).field(db) * 2 + Ok(intermediate_result(db, input)?.field(db)? * 2) } #[salsa::tracked] @@ -25,17 +25,17 @@ struct MyTracked<'db> { } #[salsa::tracked] -fn intermediate_result(db: &dyn LogDatabase, input: MyInput) -> MyTracked<'_> { +fn intermediate_result(db: &dyn LogDatabase, input: MyInput) -> salsa::Result> { db.push_log(format!("intermediate_result({:?})", input)); - MyTracked::new(db, input.field(db) / 2) + MyTracked::new(db, input.field(db)? / 2) } #[test] -fn execute() { +fn execute() -> salsa::Result<()> { let mut db = common::LoggerDatabase::default(); let input = MyInput::new(&db, 22); - assert_eq!(final_result(&db, input), 22); + assert_eq!(final_result(&db, input)?, 22); db.assert_logs(expect![[r#" [ "final_result(MyInput { [salsa id]: Id(0), field: 22 })", @@ -45,28 +45,30 @@ fn execute() { // Intermediate result is the same, so final result does // not need to be recomputed: input.set_field(&mut db).to(23); - assert_eq!(final_result(&db, input), 22); + assert_eq!(final_result(&db, input)?, 22); db.assert_logs(expect![[r#" [ "intermediate_result(MyInput { [salsa id]: Id(0), field: 23 })", ]"#]]); input.set_field(&mut db).to(24); - assert_eq!(final_result(&db, input), 24); + assert_eq!(final_result(&db, input)?, 24); db.assert_logs(expect![[r#" [ "intermediate_result(MyInput { [salsa id]: Id(0), field: 24 })", "final_result(MyInput { [salsa id]: Id(0), field: 24 })", ]"#]]); + + Ok(()) } /// Create and mutate a distinct input. No re-execution required. #[test] -fn red_herring() { +fn red_herring() -> salsa::Result<()> { let mut db = common::LoggerDatabase::default(); let input = MyInput::new(&db, 22); - assert_eq!(final_result(&db, input), 22); + assert_eq!(final_result(&db, input)?, 22); db.assert_logs(expect![[r#" [ "final_result(MyInput { [salsa id]: Id(0), field: 22 })", @@ -80,7 +82,9 @@ fn red_herring() { input2.set_field(&mut db).to(66); // Re-run the query on the original input. Nothing re-executes! - assert_eq!(final_result(&db, input), 22); + assert_eq!(final_result(&db, input)?, 22); db.assert_logs(expect![[r#" []"#]]); + + Ok(()) } diff --git a/tests/input_default.rs b/tests/input_default.rs index 5a4d2bd54..59d21401d 100644 --- a/tests/input_default.rs +++ b/tests/input_default.rs @@ -11,33 +11,38 @@ struct MyInput { } #[test] -fn new_constructor() { +fn new_constructor() -> salsa::Result<()> { let db = salsa::DatabaseImpl::new(); let input = MyInput::new(&db, true); - assert!(input.required(&db)); - assert_eq!(input.optional(&db), 0); + assert!(input.required(&db)?); + assert_eq!(input.optional(&db)?, 0); + Ok(()) } #[test] -fn builder_specify_optional() { +fn builder_specify_optional() -> salsa::Result<()> { let db = salsa::DatabaseImpl::new(); let input = MyInput::builder(true).optional(20).new(&db); - assert!(input.required(&db)); - assert_eq!(input.optional(&db), 20); + assert!(input.required(&db)?); + assert_eq!(input.optional(&db)?, 20); + + Ok(()) } #[test] -fn builder_default_optional_value() { +fn builder_default_optional_value() -> salsa::Result<()> { let db = salsa::DatabaseImpl::new(); let input = MyInput::builder(true) .required_durability(Durability::HIGH) .new(&db); - assert!(input.required(&db)); - assert_eq!(input.optional(&db), 0); + assert!(input.required(&db)?); + assert_eq!(input.optional(&db)?, 0); + + Ok(()) } diff --git a/tests/input_field_durability.rs b/tests/input_field_durability.rs index b65a512e0..348eaf946 100644 --- a/tests/input_field_durability.rs +++ b/tests/input_field_durability.rs @@ -12,19 +12,21 @@ struct MyInput { } #[test] -fn required_field_durability() { +fn required_field_durability() -> salsa::Result<()> { let db = salsa::DatabaseImpl::new(); let input = MyInput::builder(true) .required_field_durability(Durability::HIGH) .new(&db); - assert!(input.required_field(&db)); - assert_eq!(input.optional_field(&db), 0); + assert!(input.required_field(&db)?); + assert_eq!(input.optional_field(&db)?, 0); + + Ok(()) } #[test] -fn optional_field_durability() { +fn optional_field_durability() -> salsa::Result<()> { let db = salsa::DatabaseImpl::new(); let input = MyInput::builder(true) @@ -32,6 +34,7 @@ fn optional_field_durability() { .optional_field_durability(Durability::HIGH) .new(&db); - assert!(input.required_field(&db)); - assert_eq!(input.optional_field(&db), 20); + assert!(input.required_field(&db)?); + assert_eq!(input.optional_field(&db)?, 20); + Ok(()) } diff --git a/tests/interned-struct-with-lifetime.rs b/tests/interned-struct-with-lifetime.rs index a74d2c42e..20707ff01 100644 --- a/tests/interned-struct-with-lifetime.rs +++ b/tests/interned-struct-with-lifetime.rs @@ -15,17 +15,19 @@ struct InternedPair<'db> { } #[salsa::tracked] -fn intern_stuff(db: &dyn salsa::Database) -> String { - let s1 = InternedString::new(db, "Hello, ".to_string()); - let s2 = InternedString::new(db, "World, ".to_string()); - let s3 = InternedPair::new(db, (s1, s2)); - format!("{s3:?}") +fn intern_stuff(db: &dyn salsa::Database) -> salsa::Result { + let s1 = InternedString::new(db, "Hello, ".to_string())?; + let s2 = InternedString::new(db, "World, ".to_string())?; + let s3 = InternedPair::new(db, (s1, s2))?; + Ok(format!("{s3:?}")) } #[test] -fn execute() { +fn execute() -> salsa::Result<()> { let db = salsa::DatabaseImpl::new(); expect![[r#" "InternedPair { data: (InternedString { data: \"Hello, \" }, InternedString { data: \"World, \" }) }" - "#]].assert_debug_eq(&intern_stuff(&db)); + "#]].assert_debug_eq(&intern_stuff(&db)?); + + Ok(()) } diff --git a/tests/is_send_sync.rs b/tests/is_send_sync.rs index 6ada1bacc..04c4ace33 100644 --- a/tests/is_send_sync.rs +++ b/tests/is_send_sync.rs @@ -20,10 +20,11 @@ struct MyInterned<'db> { } #[salsa::tracked] -fn test(db: &dyn Database, input: MyInput) { +fn test(db: &dyn Database, input: MyInput) -> salsa::Result<()> { let input = is_send_sync(input); - let interned = is_send_sync(MyInterned::new(db, input.field(db).clone())); - let _tracked_struct = is_send_sync(MyTracked::new(db, interned)); + let interned = is_send_sync(MyInterned::new(db, input.field(db)?.clone())?); + let _tracked_struct = is_send_sync(MyTracked::new(db, interned)?); + Ok(()) } fn is_send_sync(t: T) -> T { @@ -31,8 +32,10 @@ fn is_send_sync(t: T) -> T { } #[test] -fn execute() { +fn execute() -> salsa::Result<()> { let db = salsa::DatabaseImpl::new(); let input = MyInput::new(&db, "Hello".to_string()); - test(&db, input); + test(&db, input)?; + + Ok(()) } diff --git a/tests/lru.rs b/tests/lru.rs index 0cd360021..118ff9491 100644 --- a/tests/lru.rs +++ b/tests/lru.rs @@ -37,22 +37,22 @@ struct MyInput { } #[salsa::tracked(lru = 32)] -fn get_hot_potato(db: &dyn LogDatabase, input: MyInput) -> Arc { - db.push_log(format!("get_hot_potato({:?})", input.field(db))); - Arc::new(HotPotato::new(input.field(db))) +fn get_hot_potato(db: &dyn LogDatabase, input: MyInput) -> salsa::Result> { + db.push_log(format!("get_hot_potato({:?})", input.field(db)?)); + Ok(Arc::new(HotPotato::new(input.field(db)?))) } #[salsa::tracked] -fn get_hot_potato2(db: &dyn LogDatabase, input: MyInput) -> u32 { - db.push_log(format!("get_hot_potato2({:?})", input.field(db))); - get_hot_potato(db, input).0 +fn get_hot_potato2(db: &dyn LogDatabase, input: MyInput) -> salsa::Result { + db.push_log(format!("get_hot_potato2({:?})", input.field(db)?)); + Ok(get_hot_potato(db, input)?.0) } #[salsa::tracked(lru = 32)] -fn get_volatile(db: &dyn LogDatabase, _input: MyInput) -> usize { +fn get_volatile(db: &dyn LogDatabase, _input: MyInput) -> salsa::Result { static COUNTER: AtomicUsize = AtomicUsize::new(0); db.report_untracked_read(); - COUNTER.fetch_add(1, Ordering::SeqCst) + Ok(COUNTER.fetch_add(1, Ordering::SeqCst)) } fn load_n_potatoes() -> usize { @@ -60,23 +60,25 @@ fn load_n_potatoes() -> usize { } #[test] -fn lru_works() { +fn lru_works() -> salsa::Result<()> { let db = common::LoggerDatabase::default(); assert_eq!(load_n_potatoes(), 0); for i in 0..128u32 { let input = MyInput::new(&db, i); - let p = get_hot_potato(&db, input); + let p = get_hot_potato(&db, input)?; assert_eq!(p.0, i) } // Create a new input to change the revision, and trigger the GC MyInput::new(&db, 0); assert_eq!(load_n_potatoes(), 32); + + Ok(()) } #[test] -fn lru_doesnt_break_volatile_queries() { +fn lru_doesnt_break_volatile_queries() -> salsa::Result<()> { let db = common::LoggerDatabase::default(); // Create all inputs first, so that there are no revision changes among calls to `get_volatile` @@ -87,21 +89,23 @@ fn lru_doesnt_break_volatile_queries() { // but it's much better than inconsistent results from volatile queries! for _ in 0..3 { for (i, input) in inputs.iter().enumerate() { - let x = get_volatile(&db, *input); + let x = get_volatile(&db, *input)?; assert_eq!(x, i); } } + + Ok(()) } #[test] -fn lru_can_be_changed_at_runtime() { +fn lru_can_be_changed_at_runtime() -> salsa::Result<()> { let db = common::LoggerDatabase::default(); assert_eq!(load_n_potatoes(), 0); let inputs: Vec<(u32, MyInput)> = (0..128).map(|i| (i, MyInput::new(&db, i))).collect(); for &(i, input) in inputs.iter() { - let p = get_hot_potato(&db, input); + let p = get_hot_potato(&db, input)?; assert_eq!(p.0, i) } @@ -112,7 +116,7 @@ fn lru_can_be_changed_at_runtime() { get_hot_potato::set_lru_capacity(&db, 64); assert_eq!(load_n_potatoes(), 32); for &(i, input) in inputs.iter() { - let p = get_hot_potato(&db, input); + let p = get_hot_potato(&db, input)?; assert_eq!(p.0, i) } @@ -124,7 +128,7 @@ fn lru_can_be_changed_at_runtime() { get_hot_potato::set_lru_capacity(&db, 0); assert_eq!(load_n_potatoes(), 64); for &(i, input) in inputs.iter() { - let p = get_hot_potato(&db, input); + let p = get_hot_potato(&db, input)?; assert_eq!(p.0, i) } @@ -134,10 +138,12 @@ fn lru_can_be_changed_at_runtime() { drop(db); assert_eq!(load_n_potatoes(), 0); + + Ok(()) } #[test] -fn lru_keeps_dependency_info() { +fn lru_keeps_dependency_info() -> salsa::Result<()> { let mut db = common::LoggerDatabase::default(); let capacity = 32; @@ -148,7 +154,7 @@ fn lru_keeps_dependency_info() { .collect(); for (i, input) in inputs.iter().enumerate() { - let x = get_hot_potato2(&db, *input); + let x = get_hot_potato2(&db, *input)?; assert_eq!(x as usize, i); } @@ -160,7 +166,9 @@ fn lru_keeps_dependency_info() { // calling `get_hot_potato2(0)` has to check that `get_hot_potato(0)` is still valid; // even though we've evicted it (LRU), we find that it is still good - let p = get_hot_potato2(&db, *inputs.first().unwrap()); + let p = get_hot_potato2(&db, *inputs.first().unwrap())?; assert_eq!(p, 0); db.assert_logs_len(0); + + Ok(()) } diff --git a/tests/mutate_in_place.rs b/tests/mutate_in_place.rs index 047373ee5..7ac2f28c2 100644 --- a/tests/mutate_in_place.rs +++ b/tests/mutate_in_place.rs @@ -10,7 +10,7 @@ struct MyInput { } #[test] -fn execute() { +fn execute() -> salsa::Result<()> { let mut db = salsa::DatabaseImpl::new(); let input = MyInput::new(&db, "Hello".to_string()); @@ -25,5 +25,7 @@ fn execute() { assert_eq!(input.set_field(&mut db).to(my_string), ""); // Check if the stored String is the one we expected - assert_eq!(input.field(&db), "Hello World!"); + assert_eq!(input.field(&db)?, "Hello World!"); + + Ok(()) } diff --git a/tests/override_new_get_set.rs b/tests/override_new_get_set.rs index 9f3a87528..31662c085 100644 --- a/tests/override_new_get_set.rs +++ b/tests/override_new_get_set.rs @@ -22,7 +22,7 @@ impl MyInput { MyInput::from_string(db, s.to_string()) } - pub fn field(self, db: &dyn Db) -> String { + pub fn field(self, db: &dyn Db) -> salsa::Result { self.text(db) } @@ -39,7 +39,7 @@ struct MyInterned<'db> { } impl<'db> MyInterned<'db> { - pub fn new(db: &'db dyn Db, s: impl Display) -> MyInterned<'db> { + pub fn new(db: &'db dyn Db, s: impl Display) -> salsa::Result> { MyInterned::from_string(db, s.to_string()) } @@ -55,11 +55,11 @@ struct MyTracked<'db> { } impl<'db> MyTracked<'db> { - pub fn new(db: &'db dyn Db, s: impl Display) -> MyTracked<'db> { + pub fn new(db: &'db dyn Db, s: impl Display) -> salsa::Result> { MyTracked::from_string(db, s.to_string()) } - pub fn field(self, db: &'db dyn Db) -> String { + pub fn field(self, db: &'db dyn Db) -> salsa::Result { self.text(db) } } diff --git a/tests/panic-when-creating-tracked-struct-outside-of-tracked-fn.rs b/tests/panic-when-creating-tracked-struct-outside-of-tracked-fn.rs index 32b444c7f..e45959528 100644 --- a/tests/panic-when-creating-tracked-struct-outside-of-tracked-fn.rs +++ b/tests/panic-when-creating-tracked-struct-outside-of-tracked-fn.rs @@ -12,5 +12,5 @@ struct MyTracked<'db> { )] fn execute() { let db = salsa::DatabaseImpl::new(); - MyTracked::new(&db, 0); + MyTracked::new(&db, 0).unwrap(); } diff --git a/tests/parallel/parallel_cancellation.rs b/tests/parallel/parallel_cancellation.rs index a106ec7d7..365061cb4 100644 --- a/tests/parallel/parallel_cancellation.rs +++ b/tests/parallel/parallel_cancellation.rs @@ -2,7 +2,6 @@ //! See `../cycles.rs` for a complete listing of cycle tests, //! both intra and cross thread. -use salsa::Cancelled; use salsa::Setter; use crate::setup::Knobs; @@ -14,14 +13,14 @@ struct MyInput { } #[salsa::tracked] -fn a1(db: &dyn KnobsDatabase, input: MyInput) -> MyInput { +fn a1(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { db.signal(1); db.wait_for(2); dummy(db, input) } #[salsa::tracked] -fn dummy(_db: &dyn KnobsDatabase, _input: MyInput) -> MyInput { +fn dummy(_db: &dyn KnobsDatabase, _input: MyInput) -> salsa::Result { panic!("should never get here!") } @@ -46,24 +45,27 @@ fn execute() { let input = MyInput::new(&db, 1); - let thread_a = std::thread::spawn({ - let db = db.clone(); - move || a1(&db, input) - }); + let thread_a = std::thread::Builder::new() + .name("a".to_string()) + .spawn({ + let db = db.clone(); + move || a1(&db, input) + }) + .unwrap(); db.signal_on_did_cancel.store(2); input.set_field(&mut db).to(2); // Assert thread A *should* was cancelled - let cancelled = thread_a - .join() - .unwrap_err() - .downcast::() - .unwrap(); + let cancelled = thread_a.join().unwrap().unwrap_err(); // and inspect the output expect_test::expect![[r#" - PendingWrite + Error { + kind: Cancelled( + PendingWrite, + ), + } "#]] .assert_debug_eq(&cancelled); } diff --git a/tests/parallel/parallel_cycle_all_recover.rs b/tests/parallel/parallel_cycle_all_recover.rs index 9dc8c74e2..552a6b292 100644 --- a/tests/parallel/parallel_cycle_all_recover.rs +++ b/tests/parallel/parallel_cycle_all_recover.rs @@ -1,104 +1,104 @@ -//! Test for cycle recover spread across two threads. -//! See `../cycles.rs` for a complete listing of cycle tests, -//! both intra and cross thread. - -use crate::setup::Knobs; -use crate::setup::KnobsDatabase; - -#[salsa::input] -pub(crate) struct MyInput { - field: i32, -} - -#[salsa::tracked(recovery_fn = recover_a1)] -pub(crate) fn a1(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // Wait to create the cycle until both threads have entered - db.signal(1); - db.wait_for(2); - - a2(db, input) -} - -fn recover_a1(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { - dbg!("recover_a1"); - key.field(db) * 10 + 1 -} - -#[salsa::tracked(recovery_fn=recover_a2)] -pub(crate) fn a2(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - b1(db, input) -} - -fn recover_a2(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { - dbg!("recover_a2"); - key.field(db) * 10 + 2 -} - -#[salsa::tracked(recovery_fn=recover_b1)] -pub(crate) fn b1(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // Wait to create the cycle until both threads have entered - db.wait_for(1); - db.signal(2); - - // Wait for thread A to block on this thread - db.wait_for(3); - b2(db, input) -} - -fn recover_b1(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { - dbg!("recover_b1"); - key.field(db) * 20 + 1 -} - -#[salsa::tracked(recovery_fn=recover_b2)] -pub(crate) fn b2(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - a1(db, input) -} - -fn recover_b2(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { - dbg!("recover_b2"); - key.field(db) * 20 + 2 -} - -// Recover cycle test: -// -// The pattern is as follows. -// -// Thread A Thread B -// -------- -------- -// a1 b1 -// | wait for stage 1 (blocks) -// signal stage 1 | -// wait for stage 2 (blocks) (unblocked) -// | signal stage 2 -// (unblocked) wait for stage 3 (blocks) -// a2 | -// b1 (blocks -> stage 3) | -// | (unblocked) -// | b2 -// | a1 (cycle detected, recovers) -// | b2 completes, recovers -// | b1 completes, recovers -// a2 sees cycle, recovers -// a1 completes, recovers - -#[test] -fn execute() { - let db = Knobs::default(); - - let input = MyInput::new(&db, 1); - - let thread_a = std::thread::spawn({ - let db = db.clone(); - db.knobs().signal_on_will_block.store(3); - move || a1(&db, input) - }); - - let thread_b = std::thread::spawn({ - let db = db.clone(); - move || b1(&db, input) - }); - - assert_eq!(thread_a.join().unwrap(), 11); - assert_eq!(thread_b.join().unwrap(), 21); -} +// //! Test for cycle recover spread across two threads. +// //! See `../cycles.rs` for a complete listing of cycle tests, +// //! both intra and cross thread. +// +// use crate::setup::Knobs; +// use crate::setup::KnobsDatabase; +// +// #[salsa::input] +// pub(crate) struct MyInput { +// field: i32, +// } +// +// #[salsa::tracked(recovery_fn = recover_a1)] +// pub(crate) fn a1(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { +// // Wait to create the cycle until both threads have entered +// db.signal(1); +// db.wait_for(2); +// +// a2(db, input) +// } +// +// fn recover_a1(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { +// dbg!("recover_a1"); +// key.field(db) * 10 + 1 +// } +// +// #[salsa::tracked(recovery_fn=recover_a2)] +// pub(crate) fn a2(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { +// b1(db, input) +// } +// +// fn recover_a2(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { +// dbg!("recover_a2"); +// key.field(db) * 10 + 2 +// } +// +// #[salsa::tracked(recovery_fn=recover_b1)] +// pub(crate) fn b1(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { +// // Wait to create the cycle until both threads have entered +// db.wait_for(1); +// db.signal(2); +// +// // Wait for thread A to block on this thread +// db.wait_for(3); +// b2(db, input) +// } +// +// fn recover_b1(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { +// dbg!("recover_b1"); +// key.field(db) * 20 + 1 +// } +// +// #[salsa::tracked(recovery_fn=recover_b2)] +// pub(crate) fn b2(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { +// a1(db, input) +// } +// +// fn recover_b2(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { +// dbg!("recover_b2"); +// key.field(db) * 20 + 2 +// } +// +// // Recover cycle test: +// // +// // The pattern is as follows. +// // +// // Thread A Thread B +// // -------- -------- +// // a1 b1 +// // | wait for stage 1 (blocks) +// // signal stage 1 | +// // wait for stage 2 (blocks) (unblocked) +// // | signal stage 2 +// // (unblocked) wait for stage 3 (blocks) +// // a2 | +// // b1 (blocks -> stage 3) | +// // | (unblocked) +// // | b2 +// // | a1 (cycle detected, recovers) +// // | b2 completes, recovers +// // | b1 completes, recovers +// // a2 sees cycle, recovers +// // a1 completes, recovers +// +// #[test] +// fn execute() { +// let db = Knobs::default(); +// +// let input = MyInput::new(&db, 1); +// +// let thread_a = std::thread::spawn({ +// let db = db.clone(); +// db.knobs().signal_on_will_block.store(3); +// move || a1(&db, input).unwrap() +// }); +// +// let thread_b = std::thread::spawn({ +// let db = db.clone(); +// move || b1(&db, input).unwrap() +// }); +// +// assert_eq!(thread_a.join().unwrap(), 11); +// assert_eq!(thread_b.join().unwrap(), 21); +// } diff --git a/tests/parallel/parallel_cycle_mid_recover.rs b/tests/parallel/parallel_cycle_mid_recover.rs index 593d46a66..5764db750 100644 --- a/tests/parallel/parallel_cycle_mid_recover.rs +++ b/tests/parallel/parallel_cycle_mid_recover.rs @@ -1,102 +1,102 @@ -//! Test for cycle recover spread across two threads. -//! See `../cycles.rs` for a complete listing of cycle tests, -//! both intra and cross thread. - -use crate::setup::{Knobs, KnobsDatabase}; - -#[salsa::input] -pub(crate) struct MyInput { - field: i32, -} - -#[salsa::tracked] -pub(crate) fn a1(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // tell thread b we have started - db.signal(1); - - // wait for thread b to block on a1 - db.wait_for(2); - - a2(db, input) -} - -#[salsa::tracked] -pub(crate) fn a2(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // create the cycle - b1(db, input) -} - -#[salsa::tracked(recovery_fn=recover_b1)] -pub(crate) fn b1(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // wait for thread a to have started - db.wait_for(1); - b2(db, input) -} - -fn recover_b1(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { - dbg!("recover_b1"); - key.field(db) * 20 + 2 -} - -#[salsa::tracked] -pub(crate) fn b2(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // will encounter a cycle but recover - b3(db, input); - b1(db, input); // hasn't recovered yet - 0 -} - -#[salsa::tracked(recovery_fn=recover_b3)] -pub(crate) fn b3(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // will block on thread a, signaling stage 2 - a1(db, input) -} - -fn recover_b3(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { - dbg!("recover_b3"); - key.field(db) * 200 + 2 -} - -// Recover cycle test: -// -// The pattern is as follows. -// -// Thread A Thread B -// -------- -------- -// a1 b1 -// | wait for stage 1 (blocks) -// signal stage 1 | -// wait for stage 2 (blocks) (unblocked) -// | | -// | b2 -// | b3 -// | a1 (blocks -> stage 2) -// (unblocked) | -// a2 (cycle detected) | -// b3 recovers -// b2 resumes -// b1 recovers - -#[test] -fn execute() { - let db = Knobs::default(); - - let input = MyInput::new(&db, 1); - - let thread_a = std::thread::spawn({ - let db = db.clone(); - move || a1(&db, input) - }); - - let thread_b = std::thread::spawn({ - let db = db.clone(); - db.knobs().signal_on_will_block.store(3); - move || b1(&db, input) - }); - - // We expect that the recovery function yields - // `1 * 20 + 2`, which is returned (and forwarded) - // to b1, and from there to a2 and a1. - assert_eq!(thread_a.join().unwrap(), 22); - assert_eq!(thread_b.join().unwrap(), 22); -} +// //! Test for cycle recover spread across two threads. +// //! See `../cycles.rs` for a complete listing of cycle tests, +// //! both intra and cross thread. +// +// use crate::setup::{Knobs, KnobsDatabase}; +// +// #[salsa::input] +// pub(crate) struct MyInput { +// field: i32, +// } +// +// #[salsa::tracked] +// pub(crate) fn a1(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { +// // tell thread b we have started +// db.signal(1); +// +// // wait for thread b to block on a1 +// db.wait_for(2); +// +// a2(db, input) +// } +// +// #[salsa::tracked] +// pub(crate) fn a2(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { +// // create the cycle +// b1(db, input) +// } +// +// #[salsa::tracked(recovery_fn=recover_b1)] +// pub(crate) fn b1(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { +// // wait for thread a to have started +// db.wait_for(1); +// b2(db, input) +// } +// +// fn recover_b1(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { +// dbg!("recover_b1"); +// key.field(db) * 20 + 2 +// } +// +// #[salsa::tracked] +// pub(crate) fn b2(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { +// // will encounter a cycle but recover +// b3(db, input); +// b1(db, input); // hasn't recovered yet +// 0 +// } +// +// #[salsa::tracked(recovery_fn=recover_b3)] +// pub(crate) fn b3(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { +// // will block on thread a, signaling stage 2 +// a1(db, input) +// } +// +// fn recover_b3(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { +// dbg!("recover_b3"); +// key.field(db) * 200 + 2 +// } +// +// // Recover cycle test: +// // +// // The pattern is as follows. +// // +// // Thread A Thread B +// // -------- -------- +// // a1 b1 +// // | wait for stage 1 (blocks) +// // signal stage 1 | +// // wait for stage 2 (blocks) (unblocked) +// // | | +// // | b2 +// // | b3 +// // | a1 (blocks -> stage 2) +// // (unblocked) | +// // a2 (cycle detected) | +// // b3 recovers +// // b2 resumes +// // b1 recovers +// +// #[test] +// fn execute() { +// let db = Knobs::default(); +// +// let input = MyInput::new(&db, 1); +// +// let thread_a = std::thread::spawn({ +// let db = db.clone(); +// move || a1(&db, input).unwrap() +// }); +// +// let thread_b = std::thread::spawn({ +// let db = db.clone(); +// db.knobs().signal_on_will_block.store(3); +// move || b1(&db, input).unwrap() +// }); +// +// // We expect that the recovery function yields +// // `1 * 20 + 2`, which is returned (and forwarded) +// // to b1, and from there to a2 and a1. +// assert_eq!(thread_a.join().unwrap(), 22); +// assert_eq!(thread_b.join().unwrap(), 22); +// } diff --git a/tests/parallel/parallel_cycle_none_recover.rs b/tests/parallel/parallel_cycle_none_recover.rs index 89f1ecfb0..5458f8a19 100644 --- a/tests/parallel/parallel_cycle_none_recover.rs +++ b/tests/parallel/parallel_cycle_none_recover.rs @@ -13,7 +13,7 @@ pub(crate) struct MyInput { } #[salsa::tracked] -pub(crate) fn a(db: &dyn KnobsDatabase, input: MyInput) -> i32 { +pub(crate) fn a(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { // Wait to create the cycle until both threads have entered db.signal(1); db.wait_for(2); @@ -22,7 +22,7 @@ pub(crate) fn a(db: &dyn KnobsDatabase, input: MyInput) -> i32 { } #[salsa::tracked] -pub(crate) fn b(db: &dyn KnobsDatabase, input: MyInput) -> i32 { +pub(crate) fn b(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { // Wait to create the cycle until both threads have entered db.wait_for(1); db.signal(2); @@ -40,16 +40,22 @@ fn execute() { let input = MyInput::new(&db, -1); - let thread_a = std::thread::spawn({ - let db = db.clone(); - db.knobs().signal_on_will_block.store(3); - move || a(&db, input) - }); + let thread_a = std::thread::Builder::new() + .name("a".to_string()) + .spawn({ + let db = db.clone(); + db.knobs().signal_on_will_block.store(3); + move || a(&db, input) + }) + .unwrap(); - let thread_b = std::thread::spawn({ - let db = db.clone(); - move || b(&db, input) - }); + let thread_b = std::thread::Builder::new() + .name("b".to_string()) + .spawn({ + let db = db.clone(); + move || b(&db, input).unwrap() + }) + .unwrap(); // We expect B to panic because it detects a cycle (it is the one that calls A, ultimately). // Right now, it panics with a string. @@ -70,9 +76,8 @@ fn execute() { // We expect A to propagate a panic, which causes us to use the sentinel // type `Canceled`. - assert!(thread_a - .join() - .unwrap_err() - .downcast_ref::() - .is_some()); + assert_eq!( + thread_a.join().unwrap().unwrap_err().to_string(), + "cancelled because of propagated panic" + ); } diff --git a/tests/parallel/parallel_cycle_one_recover.rs b/tests/parallel/parallel_cycle_one_recover.rs index c03782821..d31819327 100644 --- a/tests/parallel/parallel_cycle_one_recover.rs +++ b/tests/parallel/parallel_cycle_one_recover.rs @@ -1,91 +1,91 @@ -//! Test for cycle recover spread across two threads. -//! See `../cycles.rs` for a complete listing of cycle tests, -//! both intra and cross thread. - -use crate::setup::{Knobs, KnobsDatabase}; - -#[salsa::input] -pub(crate) struct MyInput { - field: i32, -} - -#[salsa::tracked] -pub(crate) fn a1(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // Wait to create the cycle until both threads have entered - db.signal(1); - db.wait_for(2); - - a2(db, input) -} - -#[salsa::tracked(recovery_fn=recover)] -pub(crate) fn a2(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - b1(db, input) -} - -fn recover(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { - dbg!("recover"); - key.field(db) * 20 + 2 -} - -#[salsa::tracked] -pub(crate) fn b1(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // Wait to create the cycle until both threads have entered - db.wait_for(1); - db.signal(2); - - // Wait for thread A to block on this thread - db.wait_for(3); - b2(db, input) -} - -#[salsa::tracked] -pub(crate) fn b2(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - a1(db, input) -} - -// Recover cycle test: +// //! Test for cycle recover spread across two threads. +// //! See `../cycles.rs` for a complete listing of cycle tests, +// //! both intra and cross thread. // -// The pattern is as follows. +// use crate::setup::{Knobs, KnobsDatabase}; // -// Thread A Thread B -// -------- -------- -// a1 b1 -// | wait for stage 1 (blocks) -// signal stage 1 | -// wait for stage 2 (blocks) (unblocked) -// | signal stage 2 -// (unblocked) wait for stage 3 (blocks) -// a2 | -// b1 (blocks -> stage 3) | -// | (unblocked) -// | b2 -// | a1 (cycle detected) -// a2 recovery fn executes | -// a1 completes normally | -// b2 completes, recovers -// b1 completes, recovers - -#[test] -fn execute() { - let db = Knobs::default(); - - let input = MyInput::new(&db, 1); - - let thread_a = std::thread::spawn({ - let db = db.clone(); - db.knobs().signal_on_will_block.store(3); - move || a1(&db, input) - }); - - let thread_b = std::thread::spawn({ - let db = db.clone(); - move || b1(&db, input) - }); - - // We expect that the recovery function yields - // `1 * 20 + 2`, which is returned (and forwarded) - // to b1, and from there to a2 and a1. - assert_eq!(thread_a.join().unwrap(), 22); - assert_eq!(thread_b.join().unwrap(), 22); -} +// #[salsa::input] +// pub(crate) struct MyInput { +// field: i32, +// } +// +// #[salsa::tracked] +// pub(crate) fn a1(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { +// // Wait to create the cycle until both threads have entered +// db.signal(1); +// db.wait_for(2); +// +// a2(db, input) +// } +// +// #[salsa::tracked(recovery_fn=recover)] +// pub(crate) fn a2(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { +// b1(db, input) +// } +// +// fn recover(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { +// dbg!("recover"); +// key.field(db) * 20 + 2 +// } +// +// #[salsa::tracked] +// pub(crate) fn b1(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { +// // Wait to create the cycle until both threads have entered +// db.wait_for(1); +// db.signal(2); +// +// // Wait for thread A to block on this thread +// db.wait_for(3); +// b2(db, input) +// } +// +// #[salsa::tracked] +// pub(crate) fn b2(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { +// a1(db, input) +// } +// +// // Recover cycle test: +// // +// // The pattern is as follows. +// // +// // Thread A Thread B +// // -------- -------- +// // a1 b1 +// // | wait for stage 1 (blocks) +// // signal stage 1 | +// // wait for stage 2 (blocks) (unblocked) +// // | signal stage 2 +// // (unblocked) wait for stage 3 (blocks) +// // a2 | +// // b1 (blocks -> stage 3) | +// // | (unblocked) +// // | b2 +// // | a1 (cycle detected) +// // a2 recovery fn executes | +// // a1 completes normally | +// // b2 completes, recovers +// // b1 completes, recovers +// +// #[test] +// fn execute() { +// let db = Knobs::default(); +// +// let input = MyInput::new(&db, 1); +// +// let thread_a = std::thread::spawn({ +// let db = db.clone(); +// db.knobs().signal_on_will_block.store(3); +// move || a1(&db, input).unwrap() +// }); +// +// let thread_b = std::thread::spawn({ +// let db = db.clone(); +// move || b1(&db, input).unwrap() +// }); +// +// // We expect that the recovery function yields +// // `1 * 20 + 2`, which is returned (and forwarded) +// // to b1, and from there to a2 and a1. +// assert_eq!(thread_a.join().unwrap(), 22); +// assert_eq!(thread_b.join().unwrap(), 22); +// } diff --git a/tests/preverify-struct-with-leaked-data-2.rs b/tests/preverify-struct-with-leaked-data-2.rs index 5632e990f..5b8d431e4 100644 --- a/tests/preverify-struct-with-leaked-data-2.rs +++ b/tests/preverify-struct-with-leaked-data-2.rs @@ -25,40 +25,44 @@ struct MyTracked<'db> { } #[salsa::tracked] -fn function(db: &dyn Database, input: MyInput) -> (usize, usize) { +fn function(db: &dyn Database, input: MyInput) -> salsa::Result<(usize, usize)> { // Read input 1 - let _field1 = input.field1(db); + let _field1 = input.field1(db)?; // **BAD:** Leak in the value of the counter non-deterministically let counter = COUNTER.with(|c| c.get()); // Create the tracked struct, which (from salsa's POV), only depends on field1; // but which actually depends on the leaked value. - let tracked = MyTracked::new(db, counter); + let tracked = MyTracked::new(db, counter)?; // Read the tracked field - let result = counter_field(db, input, tracked); + let result = counter_field(db, input, tracked)?; // Read input 2. This will cause us to re-execute on revision 2. - let _field2 = input.field2(db); + let _field2 = input.field2(db)?; - (result, tracked.counter(db)) + Ok((result, tracked.counter(db)?)) } #[salsa::tracked] -fn counter_field<'db>(db: &'db dyn Database, input: MyInput, tracked: MyTracked<'db>) -> usize { +fn counter_field<'db>( + db: &'db dyn Database, + input: MyInput, + tracked: MyTracked<'db>, +) -> salsa::Result { // Read input 2. This will cause us to re-execute on revision 2. - let _field2 = input.field2(db); + let _field2 = input.field2(db)?; tracked.counter(db) } #[test] -fn test_leaked_inputs_ignored() { +fn test_leaked_inputs_ignored() -> salsa::Result<()> { let mut db = common::EventLoggerDatabase::default(); let input = MyInput::new(&db, 10, 20); - let result_in_rev_1 = function(&db, input); + let result_in_rev_1 = function(&db, input)?; db.assert_logs(expect![[r#" [ "Event { thread_id: ThreadId(2), kind: WillCheckCancellation }", @@ -76,7 +80,7 @@ fn test_leaked_inputs_ignored() { // Also modify the thread-local counter COUNTER.with(|c| c.set(100)); - let result_in_rev_2 = function(&db, input); + let result_in_rev_2 = function(&db, input)?; db.assert_logs(expect![[r#" [ "Event { thread_id: ThreadId(2), kind: DidSetCancellationFlag }", @@ -97,4 +101,6 @@ fn test_leaked_inputs_ignored() { // // Contrast with preverify-struct-with-leaked-data-2.rs. assert_eq!(result_in_rev_2, (0, 0)); + + Ok(()) } diff --git a/tests/preverify-struct-with-leaked-data.rs b/tests/preverify-struct-with-leaked-data.rs index f709890a0..6adfec7b4 100644 --- a/tests/preverify-struct-with-leaked-data.rs +++ b/tests/preverify-struct-with-leaked-data.rs @@ -25,37 +25,37 @@ struct MyTracked<'db> { } #[salsa::tracked] -fn function(db: &dyn Database, input: MyInput) -> (usize, usize) { +fn function(db: &dyn Database, input: MyInput) -> salsa::Result<(usize, usize)> { // Read input 1 - let _field1 = input.field1(db); + let _field1 = input.field1(db)?; // **BAD:** Leak in the value of the counter non-deterministically let counter = COUNTER.with(|c| c.get()); // Create the tracked struct, which (from salsa's POV), only depends on field1; // but which actually depends on the leaked value. - let tracked = MyTracked::new(db, counter); + let tracked = MyTracked::new(db, counter)?; // Read the tracked field - let result = counter_field(db, tracked); + let result = counter_field(db, tracked)?; // Read input 2. This will cause us to re-execute on revision 2. let _field2 = input.field2(db); - (result, tracked.counter(db)) + Ok((result, tracked.counter(db)?)) } #[salsa::tracked] -fn counter_field<'db>(db: &'db dyn Database, tracked: MyTracked<'db>) -> usize { +fn counter_field<'db>(db: &'db dyn Database, tracked: MyTracked<'db>) -> salsa::Result { tracked.counter(db) } #[test] -fn test_leaked_inputs_ignored() { +fn test_leaked_inputs_ignored() -> salsa::Result<()> { let mut db = common::EventLoggerDatabase::default(); let input = MyInput::new(&db, 10, 20); - let result_in_rev_1 = function(&db, input); + let result_in_rev_1 = function(&db, input)?; db.assert_logs(expect![[r#" [ "Event { thread_id: ThreadId(2), kind: WillCheckCancellation }", @@ -73,7 +73,7 @@ fn test_leaked_inputs_ignored() { // Also modify the thread-local counter COUNTER.with(|c| c.set(100)); - let result_in_rev_2 = function(&db, input); + let result_in_rev_2 = function(&db, input)?; db.assert_logs(expect![[r#" [ "Event { thread_id: ThreadId(2), kind: DidSetCancellationFlag }", @@ -91,4 +91,6 @@ fn test_leaked_inputs_ignored() { // // Contrast with preverify-struct-with-leaked-data-2.rs. assert_eq!(result_in_rev_2, (0, 0)); + + Ok(()) } diff --git a/tests/specify-only-works-if-the-key-is-created-in-the-current-query.rs b/tests/specify-only-works-if-the-key-is-created-in-the-current-query.rs index a407aee62..513e53140 100644 --- a/tests/specify-only-works-if-the-key-is-created-in-the-current-query.rs +++ b/tests/specify-only-works-if-the-key-is-created-in-the-current-query.rs @@ -16,22 +16,25 @@ struct MyTracked<'db> { fn tracked_struct_created_in_another_query<'db>( db: &'db dyn salsa::Database, input: MyInput, -) -> MyTracked<'db> { - MyTracked::new(db, input.field(db) * 2) +) -> salsa::Result> { + MyTracked::new(db, input.field(db)? * 2) } #[salsa::tracked] -fn tracked_fn<'db>(db: &'db dyn salsa::Database, input: MyInput) -> MyTracked<'db> { - let t = tracked_struct_created_in_another_query(db, input); - if input.field(db) != 0 { +fn tracked_fn<'db>(db: &'db dyn salsa::Database, input: MyInput) -> salsa::Result> { + let t = tracked_struct_created_in_another_query(db, input)?; + if input.field(db)? != 0 { tracked_fn_extra::specify(db, t, 2222); } - t + Ok(t) } #[salsa::tracked(specify)] -fn tracked_fn_extra<'db>(_db: &'db dyn salsa::Database, _input: MyTracked<'db>) -> u32 { - 0 +fn tracked_fn_extra<'db>( + _db: &'db dyn salsa::Database, + _input: MyTracked<'db>, +) -> salsa::Result { + Ok(0) } #[test] @@ -41,5 +44,5 @@ fn tracked_fn_extra<'db>(_db: &'db dyn salsa::Database, _input: MyTracked<'db>) fn execute_when_specified() { let mut db = salsa::DatabaseImpl::new(); let input = MyInput::new(&db, 22); - let tracked = tracked_fn(&db, input); + let tracked = tracked_fn(&db, input).unwrap(); } diff --git a/tests/synthetic_write.rs b/tests/synthetic_write.rs index 9e3c2f305..df5daa6cd 100644 --- a/tests/synthetic_write.rs +++ b/tests/synthetic_write.rs @@ -14,16 +14,16 @@ struct MyInput { } #[salsa::tracked] -fn tracked_fn(db: &dyn Database, input: MyInput) -> u32 { - input.field(db) * 2 +fn tracked_fn(db: &dyn Database, input: MyInput) -> salsa::Result { + Ok(input.field(db)? * 2) } #[test] -fn execute() { +fn execute() -> salsa::Result<()> { let mut db = common::ExecuteValidateLoggerDatabase::default(); let input = MyInput::new(&db, 22); - assert_eq!(tracked_fn(&db, input), 44); + assert_eq!(tracked_fn(&db, input)?, 44); db.assert_logs(expect![[r#" [ @@ -34,10 +34,12 @@ fn execute() { db.synthetic_write(Durability::LOW); // Query should re-run - assert_eq!(tracked_fn(&db, input), 44); + assert_eq!(tracked_fn(&db, input)?, 44); db.assert_logs(expect![[r#" [ "salsa_event(DidValidateMemoizedValue { database_key: tracked_fn(Id(0)) })", ]"#]]); + + Ok(()) } diff --git a/tests/tracked-struct-id-field-bad-eq.rs b/tests/tracked-struct-id-field-bad-eq.rs index b003a3053..9790aee70 100644 --- a/tests/tracked-struct-id-field-bad-eq.rs +++ b/tests/tracked-struct-id-field-bad-eq.rs @@ -33,16 +33,19 @@ struct MyTracked<'db> { } #[salsa::tracked] -fn the_fn(db: &dyn Database, input: MyInput) { - let tracked0 = MyTracked::new(db, BadEq::from(input.field(db))); - assert_eq!(tracked0.field(db).field, input.field(db)); +fn the_fn(db: &dyn Database, input: MyInput) -> salsa::Result<()> { + let tracked0 = MyTracked::new(db, BadEq::from(input.field(db)?))?; + assert_eq!(tracked0.field(db)?.field, input.field(db)?); + Ok(()) } #[test] -fn execute() { +fn execute() -> salsa::Result<()> { let mut db = salsa::DatabaseImpl::new(); let input = MyInput::new(&db, true); - the_fn(&db, input); + the_fn(&db, input)?; input.set_field(&mut db).to(false); - the_fn(&db, input); + the_fn(&db, input)?; + + Ok(()) } diff --git a/tests/tracked-struct-id-field-bad-hash.rs b/tests/tracked-struct-id-field-bad-hash.rs index 8a391b3b6..7da796ba7 100644 --- a/tests/tracked-struct-id-field-bad-hash.rs +++ b/tests/tracked-struct-id-field-bad-hash.rs @@ -37,17 +37,19 @@ struct MyTracked<'db> { } #[salsa::tracked] -fn the_fn(db: &dyn Db, input: MyInput) { - let tracked0 = MyTracked::new(db, BadHash::from(input.field(db))); - assert_eq!(tracked0.field(db).field, input.field(db)); +fn the_fn(db: &dyn Db, input: MyInput) -> salsa::Result<()> { + let tracked0 = MyTracked::new(db, BadHash::from(input.field(db)?))?; + assert_eq!(tracked0.field(db)?.field, input.field(db)?); + Ok(()) } #[test] -fn execute() { +fn execute() -> salsa::Result<()> { let mut db = salsa::DatabaseImpl::new(); let input = MyInput::new(&db, true); - the_fn(&db, input); + the_fn(&db, input)?; input.set_field(&mut db).to(false); - the_fn(&db, input); + the_fn(&db, input)?; + Ok(()) } diff --git a/tests/tracked-struct-unchanged-in-new-rev.rs b/tests/tracked-struct-unchanged-in-new-rev.rs index e4633740f..cd4f6b947 100644 --- a/tests/tracked-struct-unchanged-in-new-rev.rs +++ b/tests/tracked-struct-unchanged-in-new-rev.rs @@ -12,12 +12,12 @@ struct MyTracked<'db> { } #[salsa::tracked] -fn tracked_fn(db: &dyn Db, input: MyInput) -> MyTracked<'_> { - MyTracked::new(db, input.field(db) / 2) +fn tracked_fn(db: &dyn Db, input: MyInput) -> salsa::Result> { + MyTracked::new(db, input.field(db)? / 2) } #[test] -fn execute() { +fn execute() -> salsa::Result<()> { let mut db = salsa::DatabaseImpl::new(); let input1 = MyInput::new(&db, 22); @@ -27,8 +27,9 @@ fn execute() { // modify the input and change the revision input1.set_field(&mut db).to(24); - let tracked2 = tracked_fn(&db, input2); + let tracked2 = tracked_fn(&db, input2)?; // this should not panic - tracked2.field(&db); + tracked2.field(&db)?; + Ok(()) } diff --git a/tests/tracked-struct-value-field-bad-eq.rs b/tests/tracked-struct-value-field-bad-eq.rs index f7444faea..481134217 100644 --- a/tests/tracked-struct-value-field-bad-eq.rs +++ b/tests/tracked-struct-value-field-bad-eq.rs @@ -37,27 +37,27 @@ struct MyTracked<'db> { } #[salsa::tracked] -fn the_fn(db: &dyn Database, input: MyInput) -> bool { - let tracked = make_tracked_struct(db, input); +fn the_fn(db: &dyn Database, input: MyInput) -> salsa::Result { + let tracked = make_tracked_struct(db, input)?; read_tracked_struct(db, tracked) } #[salsa::tracked] -fn make_tracked_struct(db: &dyn Database, input: MyInput) -> MyTracked<'_> { - MyTracked::new(db, BadEq::from(input.field(db))) +fn make_tracked_struct(db: &dyn Database, input: MyInput) -> salsa::Result> { + MyTracked::new(db, BadEq::from(input.field(db)?)) } #[salsa::tracked] -fn read_tracked_struct<'db>(db: &'db dyn Database, tracked: MyTracked<'db>) -> bool { - tracked.field(db).field +fn read_tracked_struct<'db>(db: &'db dyn Database, tracked: MyTracked<'db>) -> salsa::Result { + Ok(tracked.field(db)?.field) } #[test] -fn execute() { +fn execute() -> salsa::Result<()> { let mut db = common::ExecuteValidateLoggerDatabase::default(); let input = MyInput::new(&db, true); - let result = the_fn(&db, input); + let result = the_fn(&db, input)?; assert!(result); db.assert_logs(expect![[r#" @@ -69,7 +69,7 @@ fn execute() { // Update the input to `false` and re-execute. input.set_field(&mut db).to(false); - let result = the_fn(&db, input); + let result = the_fn(&db, input)?; // If the `Eq` impl were working properly, we would // now return `false`. But because the `Eq` is considered @@ -82,4 +82,6 @@ fn execute() { "salsa_event(DidValidateMemoizedValue { database_key: read_tracked_struct(Id(400)) })", "salsa_event(DidValidateMemoizedValue { database_key: the_fn(Id(0)) })", ]"#]]); + + Ok(()) } diff --git a/tests/tracked-struct-value-field-not-eq.rs b/tests/tracked-struct-value-field-not-eq.rs index eaf4a30c1..11c92e996 100644 --- a/tests/tracked-struct-value-field-not-eq.rs +++ b/tests/tracked-struct-value-field-not-eq.rs @@ -28,17 +28,20 @@ struct MyTracked<'db> { } #[salsa::tracked] -fn the_fn(db: &dyn Database, input: MyInput) { - let tracked0 = MyTracked::new(db, NotEq::from(input.field(db))); - assert_eq!(tracked0.field(db).field, input.field(db)); +fn the_fn(db: &dyn Database, input: MyInput) -> salsa::Result<()> { + let tracked0 = MyTracked::new(db, NotEq::from(input.field(db)?))?; + assert_eq!(tracked0.field(db)?.field, input.field(db)?); + + Ok(()) } #[test] -fn execute() { +fn execute() -> salsa::Result<()> { let mut db = salsa::DatabaseImpl::new(); let input = MyInput::new(&db, true); - the_fn(&db, input); + the_fn(&db, input)?; input.set_field(&mut db).to(false); - the_fn(&db, input); + the_fn(&db, input)?; + Ok(()) } diff --git a/tests/tracked_fn_constant.rs b/tests/tracked_fn_constant.rs index c6753ebf4..7483ae237 100644 --- a/tests/tracked_fn_constant.rs +++ b/tests/tracked_fn_constant.rs @@ -7,23 +7,25 @@ use crate::common::LogDatabase; mod common; #[salsa::tracked] -fn tracked_fn(db: &dyn salsa::Database) -> u32 { - 44 +fn tracked_fn(db: &dyn salsa::Database) -> salsa::Result { + Ok(44) } #[salsa::tracked] -fn tracked_custom_db(db: &dyn LogDatabase) -> u32 { - 44 +fn tracked_custom_db(db: &dyn LogDatabase) -> salsa::Result { + Ok(44) } #[test] -fn execute() { +fn execute() -> salsa::Result<()> { let mut db = salsa::DatabaseImpl::new(); - assert_eq!(tracked_fn(&db), 44); + assert_eq!(tracked_fn(&db)?, 44); + Ok(()) } #[test] -fn execute_custom() { +fn execute_custom() -> salsa::Result<()> { let mut db = common::LoggerDatabase::default(); - assert_eq!(tracked_custom_db(&db), 44); + assert_eq!(tracked_custom_db(&db)?, 44); + Ok(()) } diff --git a/tests/tracked_fn_high_durability_dependency.rs b/tests/tracked_fn_high_durability_dependency.rs index a05be178f..2b6d68ac4 100644 --- a/tests/tracked_fn_high_durability_dependency.rs +++ b/tests/tracked_fn_high_durability_dependency.rs @@ -10,12 +10,12 @@ struct MyInput { } #[salsa::tracked] -fn tracked_fn(db: &dyn salsa::Database, input: MyInput) -> u32 { - input.field(db) * 2 +fn tracked_fn(db: &dyn salsa::Database, input: MyInput) -> salsa::Result { + Ok(input.field(db)? * 2) } #[test] -fn execute() { +fn execute() -> salsa::Result<()> { let mut db = salsa::DatabaseImpl::default(); let input_high = MyInput::new(&mut db, 0); @@ -24,7 +24,7 @@ fn execute() { .with_durability(Durability::HIGH) .to(2200); - assert_eq!(tracked_fn(&db, input_high), 4400); + assert_eq!(tracked_fn(&db, input_high)?, 4400); // Changing the value should re-execute the query input_high @@ -32,5 +32,7 @@ fn execute() { .with_durability(Durability::HIGH) .to(2201); - assert_eq!(tracked_fn(&db, input_high), 4402); + assert_eq!(tracked_fn(&db, input_high)?, 4402); + + Ok(()) } diff --git a/tests/tracked_fn_no_eq.rs b/tests/tracked_fn_no_eq.rs index 6f223b791..00e87611d 100644 --- a/tests/tracked_fn_no_eq.rs +++ b/tests/tracked_fn_no_eq.rs @@ -10,33 +10,33 @@ struct Input { } #[salsa::tracked(no_eq)] -fn abs_float(db: &dyn LogDatabase, input: Input) -> f32 { - let number = input.number(db); +fn abs_float(db: &dyn LogDatabase, input: Input) -> salsa::Result { + let number = input.number(db)?; db.push_log(format!("abs_float({number})")); - number.abs() as f32 + Ok(number.abs() as f32) } #[salsa::tracked] -fn derived(db: &dyn LogDatabase, input: Input) -> u32 { - let x = abs_float(db, input); +fn derived(db: &dyn LogDatabase, input: Input) -> salsa::Result { + let x = abs_float(db, input)?; db.push_log("derived".to_string()); - x as u32 + Ok(x as u32) } #[test] -fn invoke() { +fn invoke() -> salsa::Result<()> { let mut db = common::LoggerDatabase::default(); let input = Input::new(&db, 5); - let x = derived(&db, input); + let x = derived(&db, input)?; assert_eq!(x, 5); input.set_number(&mut db).to(-5); // Derived should re-execute even the result of `abs_float` is the same. - let x = derived(&db, input); + let x = derived(&db, input)?; assert_eq!(x, 5); db.assert_logs(expect![[r#" @@ -46,4 +46,6 @@ fn invoke() { "abs_float(-5)", "derived", ]"#]]); + + Ok(()) } diff --git a/tests/tracked_fn_on_input.rs b/tests/tracked_fn_on_input.rs index e588a40a9..1248f50e8 100644 --- a/tests/tracked_fn_on_input.rs +++ b/tests/tracked_fn_on_input.rs @@ -8,13 +8,15 @@ struct MyInput { } #[salsa::tracked] -fn tracked_fn(db: &dyn salsa::Database, input: MyInput) -> u32 { - input.field(db) * 2 +fn tracked_fn(db: &dyn salsa::Database, input: MyInput) -> salsa::Result { + Ok(input.field(db)? * 2) } #[test] -fn execute() { +fn execute() -> salsa::Result<()> { let mut db = salsa::DatabaseImpl::new(); let input = MyInput::new(&db, 22); - assert_eq!(tracked_fn(&db, input), 44); + assert_eq!(tracked_fn(&db, input)?, 44); + + Ok(()) } diff --git a/tests/tracked_fn_on_input_with_high_durability.rs b/tests/tracked_fn_on_input_with_high_durability.rs index 17a2dd9a7..b8fdcbda7 100644 --- a/tests/tracked_fn_on_input_with_high_durability.rs +++ b/tests/tracked_fn_on_input_with_high_durability.rs @@ -13,18 +13,18 @@ struct MyInput { } #[salsa::tracked] -fn tracked_fn(db: &dyn salsa::Database, input: MyInput) -> u32 { - input.field(db) * 2 +fn tracked_fn(db: &dyn salsa::Database, input: MyInput) -> salsa::Result { + Ok(input.field(db)? * 2) } #[test] -fn execute() { +fn execute() -> salsa::Result<()> { let mut db = EventLoggerDatabase::default(); let input_low = MyInput::new(&db, 22); let input_high = MyInput::builder(2200).durability(Durability::HIGH).new(&db); - assert_eq!(tracked_fn(&db, input_low), 44); - assert_eq!(tracked_fn(&db, input_high), 4400); + assert_eq!(tracked_fn(&db, input_low)?, 44); + assert_eq!(tracked_fn(&db, input_high)?, 4400); db.assert_logs(expect![[r#" [ @@ -36,8 +36,8 @@ fn execute() { db.synthetic_write(Durability::LOW); - assert_eq!(tracked_fn(&db, input_low), 44); - assert_eq!(tracked_fn(&db, input_high), 4400); + assert_eq!(tracked_fn(&db, input_low)?, 44); + assert_eq!(tracked_fn(&db, input_high)?, 4400); // FIXME: There's currently no good way to verify whether an input was validated using shallow or deep comparison. // All we can do for now is verify that the values were validated. @@ -52,4 +52,6 @@ fn execute() { "Event { thread_id: ThreadId(2), kind: WillCheckCancellation }", "Event { thread_id: ThreadId(2), kind: DidValidateMemoizedValue { database_key: tracked_fn(Id(1)) } }", ]"#]]); + + Ok(()) } diff --git a/tests/tracked_fn_on_interned.rs b/tests/tracked_fn_on_interned.rs index b551b880d..e791460c2 100644 --- a/tests/tracked_fn_on_interned.rs +++ b/tests/tracked_fn_on_interned.rs @@ -7,14 +7,16 @@ struct Name<'db> { } #[salsa::tracked] -fn tracked_fn<'db>(db: &'db dyn salsa::Database, name: Name<'db>) -> String { - name.name(db).clone() +fn tracked_fn<'db>(db: &'db dyn salsa::Database, name: Name<'db>) -> salsa::Result { + Ok(name.name(db).clone()) } #[test] -fn execute() { +fn execute() -> salsa::Result<()> { let db = salsa::DatabaseImpl::new(); - let name = Name::new(&db, "Salsa".to_string()); + let name = Name::new(&db, "Salsa".to_string())?; - assert_eq!(tracked_fn(&db, name), "Salsa"); + assert_eq!(tracked_fn(&db, name)?, "Salsa"); + + Ok(()) } diff --git a/tests/tracked_fn_on_tracked.rs b/tests/tracked_fn_on_tracked.rs index 967bbd558..ed6cc6d59 100644 --- a/tests/tracked_fn_on_tracked.rs +++ b/tests/tracked_fn_on_tracked.rs @@ -12,13 +12,14 @@ struct MyTracked<'db> { } #[salsa::tracked] -fn tracked_fn(db: &dyn salsa::Database, input: MyInput) -> MyTracked<'_> { - MyTracked::new(db, input.field(db) * 2) +fn tracked_fn(db: &dyn salsa::Database, input: MyInput) -> salsa::Result> { + MyTracked::new(db, input.field(db)? * 2) } #[test] -fn execute() { +fn execute() -> salsa::Result<()> { let db = salsa::DatabaseImpl::new(); let input = MyInput::new(&db, 22); - assert_eq!(tracked_fn(&db, input).field(&db), 44); + assert_eq!(tracked_fn(&db, input)?.field(&db)?, 44); + Ok(()) } diff --git a/tests/tracked_fn_on_tracked_specify.rs b/tests/tracked_fn_on_tracked_specify.rs index 70e4997a2..92adf12c1 100644 --- a/tests/tracked_fn_on_tracked_specify.rs +++ b/tests/tracked_fn_on_tracked_specify.rs @@ -13,33 +13,40 @@ struct MyTracked<'db> { } #[salsa::tracked] -fn tracked_fn<'db>(db: &'db dyn salsa::Database, input: MyInput) -> MyTracked<'db> { - let t = MyTracked::new(db, input.field(db) * 2); - if input.field(db) != 0 { +fn tracked_fn<'db>(db: &'db dyn salsa::Database, input: MyInput) -> salsa::Result> { + let t = MyTracked::new(db, input.field(db)? * 2)?; + if input.field(db)? != 0 { tracked_fn_extra::specify(db, t, 2222); } - t + Ok(t) } #[salsa::tracked(specify)] -fn tracked_fn_extra<'db>(_db: &'db dyn salsa::Database, _input: MyTracked<'db>) -> u32 { - 0 +fn tracked_fn_extra<'db>( + _db: &'db dyn salsa::Database, + _input: MyTracked<'db>, +) -> salsa::Result { + Ok(0) } #[test] -fn execute_when_specified() { +fn execute_when_specified() -> salsa::Result<()> { let mut db = salsa::DatabaseImpl::new(); let input = MyInput::new(&db, 22); - let tracked = tracked_fn(&db, input); - assert_eq!(tracked.field(&db), 44); - assert_eq!(tracked_fn_extra(&db, tracked), 2222); + let tracked = tracked_fn(&db, input)?; + assert_eq!(tracked.field(&db)?, 44); + assert_eq!(tracked_fn_extra(&db, tracked)?, 2222); + + Ok(()) } #[test] -fn execute_when_not_specified() { +fn execute_when_not_specified() -> salsa::Result<()> { let mut db = salsa::DatabaseImpl::new(); let input = MyInput::new(&db, 0); - let tracked = tracked_fn(&db, input); - assert_eq!(tracked.field(&db), 0); - assert_eq!(tracked_fn_extra(&db, tracked), 0); + let tracked = tracked_fn(&db, input)?; + assert_eq!(tracked.field(&db)?, 0); + assert_eq!(tracked_fn_extra(&db, tracked)?, 0); + + Ok(()) } diff --git a/tests/tracked_fn_read_own_entity.rs b/tests/tracked_fn_read_own_entity.rs index 48ed793d1..fab452838 100644 --- a/tests/tracked_fn_read_own_entity.rs +++ b/tests/tracked_fn_read_own_entity.rs @@ -13,9 +13,9 @@ struct MyInput { } #[salsa::tracked] -fn final_result(db: &dyn LogDatabase, input: MyInput) -> u32 { +fn final_result(db: &dyn LogDatabase, input: MyInput) -> salsa::Result { db.push_log(format!("final_result({:?})", input)); - intermediate_result(db, input).field(db) * 2 + Ok(intermediate_result(db, input)?.field(db)? * 2) } #[salsa::tracked] @@ -24,19 +24,19 @@ struct MyTracked<'db> { } #[salsa::tracked] -fn intermediate_result(db: &dyn LogDatabase, input: MyInput) -> MyTracked<'_> { +fn intermediate_result(db: &dyn LogDatabase, input: MyInput) -> salsa::Result> { db.push_log(format!("intermediate_result({:?})", input)); - let tracked = MyTracked::new(db, input.field(db) / 2); - let _ = tracked.field(db); // read the field of an entity we created - tracked + let tracked = MyTracked::new(db, input.field(db)? / 2)?; + let _ = tracked.field(db)?; // read the field of an entity we created + Ok(tracked) } #[test] -fn one_entity() { +fn one_entity() -> salsa::Result<()> { let mut db = common::LoggerDatabase::default(); let input = MyInput::new(&db, 22); - assert_eq!(final_result(&db, input), 22); + assert_eq!(final_result(&db, input)?, 22); db.assert_logs(expect![[r#" [ "final_result(MyInput { [salsa id]: Id(0), field: 22 })", @@ -46,28 +46,30 @@ fn one_entity() { // Intermediate result is the same, so final result does // not need to be recomputed: input.set_field(&mut db).to(23); - assert_eq!(final_result(&db, input), 22); + assert_eq!(final_result(&db, input)?, 22); db.assert_logs(expect![[r#" [ "intermediate_result(MyInput { [salsa id]: Id(0), field: 23 })", ]"#]]); input.set_field(&mut db).to(24); - assert_eq!(final_result(&db, input), 24); + assert_eq!(final_result(&db, input)?, 24); db.assert_logs(expect![[r#" [ "intermediate_result(MyInput { [salsa id]: Id(0), field: 24 })", "final_result(MyInput { [salsa id]: Id(0), field: 24 })", ]"#]]); + + Ok(()) } /// Create and mutate a distinct input. No re-execution required. #[test] -fn red_herring() { +fn red_herring() -> salsa::Result<()> { let mut db = common::LoggerDatabase::default(); let input = MyInput::new(&db, 22); - assert_eq!(final_result(&db, input), 22); + assert_eq!(final_result(&db, input)?, 22); db.assert_logs(expect![[r#" [ "final_result(MyInput { [salsa id]: Id(0), field: 22 })", @@ -81,7 +83,9 @@ fn red_herring() { input2.set_field(&mut db).to(66); // Re-run the query on the original input. Nothing re-executes! - assert_eq!(final_result(&db, input), 22); + assert_eq!(final_result(&db, input)?, 22); db.assert_logs(expect![[r#" []"#]]); + + Ok(()) } diff --git a/tests/tracked_fn_read_own_specify.rs b/tests/tracked_fn_read_own_specify.rs index 426d18a76..8b7413088 100644 --- a/tests/tracked_fn_read_own_specify.rs +++ b/tests/tracked_fn_read_own_specify.rs @@ -14,24 +14,24 @@ struct MyTracked<'db> { } #[salsa::tracked] -fn tracked_fn(db: &dyn LogDatabase, input: MyInput) -> u32 { +fn tracked_fn(db: &dyn LogDatabase, input: MyInput) -> salsa::Result { db.push_log(format!("tracked_fn({input:?})")); - let t = MyTracked::new(db, input.field(db) * 2); + let t = MyTracked::new(db, input.field(db)? * 2)?; tracked_fn_extra::specify(db, t, 2222); tracked_fn_extra(db, t) } #[salsa::tracked(specify)] -fn tracked_fn_extra<'db>(db: &dyn LogDatabase, input: MyTracked<'db>) -> u32 { +fn tracked_fn_extra<'db>(db: &dyn LogDatabase, input: MyTracked<'db>) -> salsa::Result { db.push_log(format!("tracked_fn_extra({input:?})")); - 0 + Ok(0) } #[test] -fn execute() { +fn execute() -> salsa::Result<()> { let mut db = common::LoggerDatabase::default(); let input = MyInput::new(&db, 22); - assert_eq!(tracked_fn(&db, input), 2222); + assert_eq!(tracked_fn(&db, input)?, 2222); db.assert_logs(expect![[r#" [ "tracked_fn(MyInput { [salsa id]: Id(0), field: 22 })", @@ -42,6 +42,8 @@ fn execute() { db.synthetic_write(salsa::Durability::LOW); // Re-run the query on the original input. Nothing re-executes! - assert_eq!(tracked_fn(&db, input), 2222); + assert_eq!(tracked_fn(&db, input)?, 2222); db.assert_logs(expect!["[]"]); + + Ok(()) } diff --git a/tests/tracked_fn_return_ref.rs b/tests/tracked_fn_return_ref.rs index ecd91a17c..cf7352cb7 100644 --- a/tests/tracked_fn_return_ref.rs +++ b/tests/tracked_fn_return_ref.rs @@ -6,17 +6,17 @@ struct Input { } #[salsa::tracked(return_ref)] -fn test(db: &dyn salsa::Database, input: Input) -> Vec { - (0..input.number(db)) +fn test(db: &dyn salsa::Database, input: Input) -> salsa::Result> { + Ok((0..input.number(db)?) .map(|i| format!("test {}", i)) - .collect() + .collect()) } #[test] -fn invoke() { +fn invoke() -> salsa::Result<()> { salsa::DatabaseImpl::new().attach(|db| { let input = Input::new(db, 3); - let x: &Vec = test(db, input); + let x: &Vec = test(db, input)?; expect_test::expect![[r#" [ "test 0", @@ -25,5 +25,6 @@ fn invoke() { ] "#]] .assert_debug_eq(x); + Ok(()) }) } diff --git a/tests/tracked_method.rs b/tests/tracked_method.rs index 0291a748f..d557192da 100644 --- a/tests/tracked_method.rs +++ b/tests/tracked_method.rs @@ -3,7 +3,7 @@ #![allow(warnings)] trait TrackedTrait { - fn tracked_trait_fn(self, db: &dyn salsa::Database) -> u32; + fn tracked_trait_fn(self, db: &dyn salsa::Database) -> salsa::Result; } #[salsa::input] @@ -14,29 +14,31 @@ struct MyInput { #[salsa::tracked] impl MyInput { #[salsa::tracked] - fn tracked_fn(self, db: &dyn salsa::Database) -> u32 { - self.field(db) * 2 + fn tracked_fn(self, db: &dyn salsa::Database) -> salsa::Result { + Ok(self.field(db)? * 2) } #[salsa::tracked(return_ref)] - fn tracked_fn_ref(self, db: &dyn salsa::Database) -> u32 { - self.field(db) * 3 + fn tracked_fn_ref(self, db: &dyn salsa::Database) -> salsa::Result { + Ok(self.field(db)? * 3) } } #[salsa::tracked] impl TrackedTrait for MyInput { #[salsa::tracked] - fn tracked_trait_fn(self, db: &dyn salsa::Database) -> u32 { - self.field(db) * 4 + fn tracked_trait_fn(self, db: &dyn salsa::Database) -> salsa::Result { + Ok(self.field(db)? * 4) } } #[test] -fn execute() { +fn execute() -> salsa::Result<()> { let mut db = salsa::DatabaseImpl::new(); let object = MyInput::new(&mut db, 22); // assert_eq!(object.tracked_fn(&db), 44); // assert_eq!(*object.tracked_fn_ref(&db), 66); - assert_eq!(object.tracked_trait_fn(&db), 88); + assert_eq!(object.tracked_trait_fn(&db)?, 88); + + Ok(()) } diff --git a/tests/tracked_method_inherent_return_ref.rs b/tests/tracked_method_inherent_return_ref.rs index 462a24da5..2329a23fc 100644 --- a/tests/tracked_method_inherent_return_ref.rs +++ b/tests/tracked_method_inherent_return_ref.rs @@ -8,18 +8,18 @@ struct Input { #[salsa::tracked] impl Input { #[salsa::tracked(return_ref)] - fn test(self, db: &dyn salsa::Database) -> Vec { - (0..self.number(db)) + fn test(self, db: &dyn salsa::Database) -> salsa::Result> { + Ok((0..self.number(db)?) .map(|i| format!("test {}", i)) - .collect() + .collect()) } } #[test] -fn invoke() { +fn invoke() -> salsa::Result<()> { salsa::DatabaseImpl::new().attach(|db| { let input = Input::new(db, 3); - let x: &Vec = input.test(db); + let x: &Vec = input.test(db)?; expect_test::expect![[r#" [ "test 0", @@ -28,5 +28,7 @@ fn invoke() { ] "#]] .assert_debug_eq(x); + + Ok(()) }) } diff --git a/tests/tracked_method_on_tracked_struct.rs b/tests/tracked_method_on_tracked_struct.rs index 1febcfd36..3e38df2db 100644 --- a/tests/tracked_method_on_tracked_struct.rs +++ b/tests/tracked_method_on_tracked_struct.rs @@ -11,8 +11,8 @@ pub struct Input { #[salsa::tracked] impl Input { #[salsa::tracked] - pub fn source_tree(self, db: &dyn Database) -> SourceTree<'_> { - SourceTree::new(db, self.name(db).clone()) + pub fn source_tree(self, db: &dyn Database) -> salsa::Result> { + SourceTree::new(db, self.name(db)?.clone()) } } @@ -24,43 +24,47 @@ pub struct SourceTree<'db> { #[salsa::tracked] impl<'db1> SourceTree<'db1> { #[salsa::tracked(return_ref)] - pub fn inherent_item_name(self, db: &'db1 dyn Database) -> String { + pub fn inherent_item_name(self, db: &'db1 dyn Database) -> salsa::Result { self.name(db) } } trait ItemName<'db1> { - fn trait_item_name(self, db: &'db1 dyn Database) -> &'db1 String; + fn trait_item_name(self, db: &'db1 dyn Database) -> salsa::Result<&'db1 String>; } #[salsa::tracked] impl<'db1> ItemName<'db1> for SourceTree<'db1> { #[salsa::tracked(return_ref)] - fn trait_item_name(self, db: &'db1 dyn Database) -> String { + fn trait_item_name(self, db: &'db1 dyn Database) -> salsa::Result { self.name(db) } } #[test] -fn test_inherent() { +fn test_inherent() -> salsa::Result<()> { salsa::DatabaseImpl::new().attach(|db| { let input = Input::new(db, "foo".to_string()); - let source_tree = input.source_tree(db); + let source_tree = input.source_tree(db)?; expect_test::expect![[r#" "foo" "#]] - .assert_debug_eq(source_tree.inherent_item_name(db)); + .assert_debug_eq(source_tree.inherent_item_name(db)?); + + Ok(()) }) } #[test] -fn test_trait() { +fn test_trait() -> salsa::Result<()> { salsa::DatabaseImpl::new().attach(|db| { let input = Input::new(db, "foo".to_string()); - let source_tree = input.source_tree(db); + let source_tree = input.source_tree(db)?; expect_test::expect![[r#" "foo" "#]] - .assert_debug_eq(source_tree.trait_item_name(db)); + .assert_debug_eq(source_tree.trait_item_name(db)?); + + Ok(()) }) } diff --git a/tests/tracked_method_trait_return_ref.rs b/tests/tracked_method_trait_return_ref.rs index 3c9fa5cc2..d43c6c072 100644 --- a/tests/tracked_method_trait_return_ref.rs +++ b/tests/tracked_method_trait_return_ref.rs @@ -6,24 +6,24 @@ struct Input { } trait Trait { - fn test(self, db: &dyn salsa::Database) -> &Vec; + fn test(self, db: &dyn salsa::Database) -> salsa::Result<&Vec>; } #[salsa::tracked] impl Trait for Input { #[salsa::tracked(return_ref)] - fn test(self, db: &dyn salsa::Database) -> Vec { - (0..self.number(db)) + fn test(self, db: &dyn salsa::Database) -> salsa::Result> { + Ok((0..self.number(db)?) .map(|i| format!("test {}", i)) - .collect() + .collect()) } } #[test] -fn invoke() { +fn invoke() -> salsa::Result<()> { salsa::DatabaseImpl::new().attach(|db| { let input = Input::new(db, 3); - let x: &Vec = input.test(db); + let x: &Vec = input.test(db)?; expect_test::expect![[r#" [ "test 0", @@ -32,5 +32,7 @@ fn invoke() { ] "#]] .assert_debug_eq(x); + + Ok(()) }) } diff --git a/tests/tracked_struct_durability.rs b/tests/tracked_struct_durability.rs index c1fb8b2b3..3b3979708 100644 --- a/tests/tracked_struct_durability.rs +++ b/tests/tracked_struct_durability.rs @@ -51,37 +51,37 @@ struct Inference<'db> { } #[salsa::tracked] -fn index<'db>(db: &'db dyn Db, file: File) -> Index<'db> { - let _ = file.field(db); - Index::new(db, Definitions::new(db, Definition::new(db, file))) +fn index<'db>(db: &'db dyn Db, file: File) -> salsa::Result> { + let _ = file.field(db)?; + Index::new(db, Definitions::new(db, Definition::new(db, file)?)?) } #[salsa::tracked] -fn definitions<'db>(db: &'db dyn Db, file: File) -> Definitions<'db> { - index(db, file).definitions(db) +fn definitions<'db>(db: &'db dyn Db, file: File) -> salsa::Result> { + index(db, file)?.definitions(db) } #[salsa::tracked] -fn infer<'db>(db: &'db dyn Db, definition: Definition<'db>) -> Inference<'db> { - let file = definition.file(db); - if file.field(db) < 1 { +fn infer<'db>(db: &'db dyn Db, definition: Definition<'db>) -> salsa::Result> { + let file = definition.file(db)?; + if file.field(db)? < 1 { let dependent_file = db.file(1); - infer(db, definitions(db, dependent_file).definition(db)) + infer(db, definitions(db, dependent_file)?.definition(db)?) } else { - db.file(0).field(db); - index(db, file); + db.file(0).field(db)?; + index(db, file)?; Inference::new(db, definition) } } #[salsa::tracked] -fn check<'db>(db: &'db dyn Db, file: File) -> Inference<'db> { - let defs = definitions(db, file); - infer(db, defs.definition(db)) +fn check<'db>(db: &'db dyn Db, file: File) -> salsa::Result> { + let defs = definitions(db, file)?; + infer(db, defs.definition(db)?) } #[test] -fn execute() { +fn execute() -> salsa::Result<()> { #[salsa::db] #[derive(Default)] struct Database { @@ -117,7 +117,10 @@ fn execute() { // check(0) -> infer(0) -> definitions(0) -> index(0) // \-> infer(1) -> definitions(1) -> index(1) - assert_eq!(check(&db, file0).definition(&db).file(&db).field(&db), 1); + assert_eq!( + check(&db, file0)?.definition(&db)?.file(&db)?.field(&db)?, + 1 + ); // update the low durability file 0 file0.set_field(&mut db).to(0); @@ -127,5 +130,7 @@ fn execute() { // Definition(1), so we never validate Definition(1) in R2, so when we try to verify // Definition.file(1) (as an input of infer(1) ) we hit a panic for trying to use a struct that // isn't validated in R2. - check(&db, file0); + check(&db, file0)?; + + Ok(()) } diff --git a/tests/tracked_with_struct_db.rs b/tests/tracked_with_struct_db.rs index 6c6d2ab84..821a6e9f5 100644 --- a/tests/tracked_with_struct_db.rs +++ b/tests/tracked_with_struct_db.rs @@ -22,18 +22,18 @@ enum MyList<'db> { } #[salsa::tracked] -fn create_tracked_list(db: &dyn Database, input: MyInput) -> MyTracked<'_> { - let t0 = MyTracked::new(db, input, MyList::None); - let t1 = MyTracked::new(db, input, MyList::Next(t0)); - t1 +fn create_tracked_list(db: &dyn Database, input: MyInput) -> salsa::Result> { + let t0 = MyTracked::new(db, input, MyList::None)?; + let t1 = MyTracked::new(db, input, MyList::Next(t0))?; + Ok(t1) } #[test] -fn execute() { +fn execute() -> salsa::Result<()> { DatabaseImpl::new().attach(|db| { let input = MyInput::new(db, "foo".to_string()); - let t0: MyTracked = create_tracked_list(db, input); - let t1 = create_tracked_list(db, input); + let t0: MyTracked = create_tracked_list(db, input)?; + let t1 = create_tracked_list(db, input)?; expect_test::expect![[r#" MyTracked { [salsa id]: Id(401), @@ -55,5 +55,7 @@ fn execute() { "#]] .assert_debug_eq(&t0); assert_eq!(t0, t1); + + Ok(()) }) } diff --git a/tests/warnings/needless_lifetimes.rs b/tests/warnings/needless_lifetimes.rs index 0eb9198d0..cf6dc8cca 100644 --- a/tests/warnings/needless_lifetimes.rs +++ b/tests/warnings/needless_lifetimes.rs @@ -10,13 +10,15 @@ pub struct SourceTree<'db> {} #[salsa::tracked] impl<'db> SourceTree<'db> { #[salsa::tracked(return_ref)] - pub fn all_items(self, _db: &'db dyn Db) -> Vec { + pub fn all_items(self, _db: &'db dyn Db) -> salsa::Result> { todo!() } } #[salsa::tracked(return_ref)] -fn use_tree<'db>(_db: &'db dyn Db, _tree: SourceTree<'db>) {} +fn use_tree<'db>(_db: &'db dyn Db, _tree: SourceTree<'db>) -> salsa::Result<()> { + Ok(()) +} #[allow(unused)] fn use_it(db: &dyn Db, tree: SourceTree) { From b037e7b10fc77612ab822fc7faf4e5fc3983e7e1 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Mon, 2 Sep 2024 08:03:56 +0200 Subject: [PATCH 02/10] Box error kind --- src/result.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/result.rs b/src/result.rs index a94781858..fbafe6a36 100644 --- a/src/result.rs +++ b/src/result.rs @@ -8,27 +8,27 @@ pub type Result = std::result::Result; #[derive(Debug)] pub struct Error { - kind: ErrorKind, + kind: Box, } impl Error { pub(crate) fn cancelled(reason: Cancelled) -> Self { Error { - kind: ErrorKind::Cancelled(reason), + kind: Box::new(ErrorKind::Cancelled(reason)), } } pub(crate) fn cycle(cycle: Cycle) -> Self { Self { - kind: ErrorKind::Cycle(CycleError { + kind: Box::new(ErrorKind::Cycle(CycleError { cycle, bomb: DropBomb::new("TODO"), - }), + })), } } pub(crate) fn into_cycle(self) -> std::result::Result { - match self.kind { + match *self.kind { ErrorKind::Cycle(cycle) => Ok(cycle.take_cycle()), _ => Err(self), } @@ -38,14 +38,14 @@ impl Error { impl From for Error { fn from(value: CycleError) -> Self { Self { - kind: ErrorKind::Cycle(value), + kind: Box::new(ErrorKind::Cycle(value)), } } } impl std::fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match &self.kind { + match &*self.kind { ErrorKind::Cycle(cycle) => { write!(f, "cycle detected: {:?}", cycle) } From 1530727cb3b23b6fa84800344c89973c98d6c8e4 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Mon, 2 Sep 2024 08:06:46 +0200 Subject: [PATCH 03/10] Fix new nightly clippy warnings --- src/function.rs | 2 +- src/function/execute.rs | 2 +- src/function/fetch.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/function.rs b/src/function.rs index 1e9fe6301..42ed227c4 100644 --- a/src/function.rs +++ b/src/function.rs @@ -165,7 +165,7 @@ where zalsa: &'db Zalsa, id: Id, memo: memo::Memo>, - ) -> Option<&C::Output<'db>> { + ) -> Option<&'db C::Output<'db>> { let memo = Arc::new(memo); let value = unsafe { // Unsafety conditions: memo must be in the map (it's not yet, but it will be by the time this diff --git a/src/function/execute.rs b/src/function/execute.rs index 65788bdf8..ce8576815 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -25,7 +25,7 @@ where db: &'db C::DbView, active_query: ActiveQueryGuard<'_>, opt_old_memo: Option>>>, - ) -> crate::Result>> { + ) -> crate::Result>> { let zalsa = db.zalsa(); let revision_now = zalsa.current_revision(); let database_key_index = active_query.database_key_index; diff --git a/src/function/fetch.rs b/src/function/fetch.rs index d8016b1a7..ee7c1f131 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -6,7 +6,7 @@ impl IngredientImpl where C: Configuration, { - pub fn fetch<'db>(&'db self, db: &'db C::DbView, id: Id) -> crate::Result<&C::Output<'db>> { + pub fn fetch<'db>(&'db self, db: &'db C::DbView, id: Id) -> crate::Result<&'db C::Output<'db>> { let (zalsa, zalsa_local) = db.zalsas(); zalsa_local.unwind_if_revision_cancelled(db.as_dyn_database())?; From b600f3d149509701ffb506b65fa7f7b908790c1f Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Mon, 2 Sep 2024 17:10:47 +0200 Subject: [PATCH 04/10] Migrate cycle handling test --- src/lib.rs | 1 + src/result.rs | 13 +- tests/cycles.rs | 903 +++++++++++++++++++++++++----------------------- 3 files changed, 474 insertions(+), 443 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 1947ec123..38085c575 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -81,6 +81,7 @@ pub mod plumbing { pub use crate::ingredient::Jar; pub use crate::ingredient::JarAux; pub use crate::key::DatabaseKeyIndex; + pub use crate::result::error_as_cycle; pub use crate::revision::Revision; pub use crate::runtime::stamp; pub use crate::runtime::Runtime; diff --git a/src/result.rs b/src/result.rs index fbafe6a36..a44ef50c1 100644 --- a/src/result.rs +++ b/src/result.rs @@ -5,8 +5,14 @@ use std::fmt::Debug; pub type Result = std::result::Result; -#[derive(Debug)] +pub fn error_as_cycle(error: &Error) -> Option<&Cycle> { + match &*error.kind { + ErrorKind::Cycle(error) => Some(&error.cycle), + _ => None, + } +} +#[derive(Debug)] pub struct Error { kind: Box, } @@ -78,11 +84,6 @@ impl CycleError { // FIXME implement drop for Cancelled. /// A panic payload indicating that execution of a salsa query was cancelled. -/// -/// This can occur for a few reasons: -/// * -/// * -/// * #[derive(Debug)] #[non_exhaustive] pub(crate) enum Cancelled { diff --git a/tests/cycles.rs b/tests/cycles.rs index 1b665bd40..634101bd2 100644 --- a/tests/cycles.rs +++ b/tests/cycles.rs @@ -1,438 +1,467 @@ -// FIXME +#![allow(warnings)] -// #![allow(warnings)] -// -// use std::panic::{RefUnwindSafe, UnwindSafe}; -// -// use expect_test::expect; -// use salsa::DatabaseImpl; -// use salsa::Durability; -// -// // Axes: -// // -// // Threading -// // * Intra-thread -// // * Cross-thread -- part of cycle is on one thread, part on another -// // -// // Recovery strategies: -// // * Panic -// // * Fallback -// // * Mixed -- multiple strategies within cycle participants -// // -// // Across revisions: -// // * N/A -- only one revision -// // * Present in new revision, not old -// // * Present in old revision, not new -// // * Present in both revisions -// // -// // Dependencies -// // * Tracked -// // * Untracked -- cycle participant(s) contain untracked reads -// // -// // Layers -// // * Direct -- cycle participant is directly invoked from test -// // * Indirect -- invoked a query that invokes the cycle -// // -// // -// // | Thread | Recovery | Old, New | Dep style | Layers | Test Name | -// // | ------ | -------- | -------- | --------- | ------ | --------- | -// // | Intra | Panic | N/A | Tracked | direct | cycle_memoized | -// // | Intra | Panic | N/A | Untracked | direct | cycle_volatile | -// // | Intra | Fallback | N/A | Tracked | direct | cycle_cycle | -// // | Intra | Fallback | N/A | Tracked | indirect | inner_cycle | -// // | Intra | Fallback | Both | Tracked | direct | cycle_revalidate | -// // | Intra | Fallback | New | Tracked | direct | cycle_appears | -// // | Intra | Fallback | Old | Tracked | direct | cycle_disappears | -// // | Intra | Fallback | Old | Tracked | direct | cycle_disappears_durability | -// // | Intra | Mixed | N/A | Tracked | direct | cycle_mixed_1 | -// // | Intra | Mixed | N/A | Tracked | direct | cycle_mixed_2 | -// // | Cross | Panic | N/A | Tracked | both | parallel/parallel_cycle_none_recover.rs | -// // | Cross | Fallback | N/A | Tracked | both | parallel/parallel_cycle_one_recover.rs | -// // | Cross | Fallback | N/A | Tracked | both | parallel/parallel_cycle_mid_recover.rs | -// // | Cross | Fallback | N/A | Tracked | both | parallel/parallel_cycle_all_recover.rs | -// -// #[derive(PartialEq, Eq, Hash, Clone, Debug)] -// struct Error { -// cycle: Vec, -// } -// -// use salsa::Database as Db; -// use salsa::Setter; -// -// #[salsa::input] -// struct MyInput {} -// -// #[salsa::tracked] -// fn memoized_a(db: &dyn Db, input: MyInput) -> salsa::Result<()> { -// memoized_b(db, input) -// } -// -// #[salsa::tracked] -// fn memoized_b(db: &dyn Db, input: MyInput) -> salsa::Result<()> { -// memoized_a(db, input) -// } -// -// #[salsa::tracked] -// fn volatile_a(db: &dyn Db, input: MyInput) -> salsa::Result<()> { -// db.report_untracked_read(); -// volatile_b(db, input) -// } -// -// #[salsa::tracked] -// fn volatile_b(db: &dyn Db, input: MyInput) -> salsa::Result<()> { -// db.report_untracked_read(); -// volatile_a(db, input) -// } -// -// /// The queries A, B, and C in `Database` can be configured -// /// to invoke one another in arbitrary ways using this -// /// enum. -// #[derive(Debug, Copy, Clone, PartialEq, Eq)] -// enum CycleQuery { -// None, -// A, -// B, -// C, -// AthenC, -// } -// -// #[salsa::input] -// struct ABC { -// a: CycleQuery, -// b: CycleQuery, -// c: CycleQuery, -// } -// -// impl CycleQuery { -// fn invoke(self, db: &dyn Db, abc: ABC) -> Result<(), Error> { -// match self { -// CycleQuery::A => cycle_a(db, abc), -// CycleQuery::B => cycle_b(db, abc), -// CycleQuery::C => cycle_c(db, abc), -// CycleQuery::AthenC => { -// let _ = cycle_a(db, abc); -// cycle_c(db, abc) -// } -// CycleQuery::None => Ok(()), -// } -// } -// } -// -// #[salsa::tracked(recovery_fn=recover_a)] -// fn cycle_a(db: &dyn Db, abc: ABC) -> salsa::Result> { -// abc.a(db).invoke(db, abc) -// } -// -// fn recover_a(db: &dyn Db, cycle: &salsa::Cycle, abc: ABC) -> Result<(), Error> { -// Err(Error { -// cycle: cycle.participant_keys().map(|k| format!("{k:?}")).collect(), -// }) -// } -// -// #[salsa::tracked(recovery_fn=recover_b)] -// fn cycle_b(db: &dyn Db, abc: ABC) -> Result<(), Error> { -// abc.b(db).invoke(db, abc) -// } -// -// fn recover_b(db: &dyn Db, cycle: &salsa::Cycle, abc: ABC) -> Result<(), Error> { -// Err(Error { -// cycle: cycle.participant_keys().map(|k| format!("{k:?}")).collect(), -// }) -// } -// -// #[salsa::tracked] -// fn cycle_c(db: &dyn Db, abc: ABC) -> Result<(), Error> { -// abc.c(db).invoke(db, abc) -// } -// -// #[track_caller] -// fn extract_cycle(f: impl FnOnce() + UnwindSafe) -> salsa::Cycle { -// let v = std::panic::catch_unwind(f); -// if let Err(d) = &v { -// if let Some(cycle) = d.downcast_ref::() { -// return cycle.clone(); -// } -// } -// panic!("unexpected value: {:?}", v) -// } -// -// #[test] -// fn cycle_memoized() { -// salsa::DatabaseImpl::new().attach(|db| { -// let input = MyInput::new(db); -// let cycle = extract_cycle(|| memoized_a(db, input)); -// let expected = expect![[r#" -// [ -// memoized_a(Id(0)), -// memoized_b(Id(0)), -// ] -// "#]]; -// expected.assert_debug_eq(&cycle.all_participants(db)); -// }) -// } -// -// #[test] -// fn cycle_volatile() { -// salsa::DatabaseImpl::new().attach(|db| { -// let input = MyInput::new(db); -// let cycle = extract_cycle(|| volatile_a(db, input)); -// let expected = expect![[r#" -// [ -// volatile_a(Id(0)), -// volatile_b(Id(0)), -// ] -// "#]]; -// expected.assert_debug_eq(&cycle.all_participants(db)); -// }); -// } -// -// #[test] -// fn expect_cycle() { -// // A --> B -// // ^ | -// // +-----+ -// -// salsa::DatabaseImpl::new().attach(|db| { -// let abc = ABC::new(db, CycleQuery::B, CycleQuery::A, CycleQuery::None); -// assert!(cycle_a(db, abc).is_err()); -// }) -// } -// -// #[test] -// fn inner_cycle() { -// // A --> B <-- C -// // ^ | -// // +-----+ -// salsa::DatabaseImpl::new().attach(|db| { -// let abc = ABC::new(db, CycleQuery::B, CycleQuery::A, CycleQuery::B); -// let err = cycle_c(db, abc); -// assert!(err.is_err()); -// let expected = expect![[r#" -// [ -// "cycle_a(Id(0))", -// "cycle_b(Id(0))", -// ] -// "#]]; -// expected.assert_debug_eq(&err.unwrap_err().cycle); -// }) -// } -// -// #[test] -// fn cycle_revalidate() { -// // A --> B -// // ^ | -// // +-----+ -// let mut db = salsa::DatabaseImpl::new(); -// let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None); -// assert!(cycle_a(&db, abc).is_err()); -// abc.set_b(&mut db).to(CycleQuery::A); // same value as default -// assert!(cycle_a(&db, abc).is_err()); -// } -// -// #[test] -// fn cycle_recovery_unchanged_twice() { -// // A --> B -// // ^ | -// // +-----+ -// let mut db = salsa::DatabaseImpl::new(); -// let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None); -// assert!(cycle_a(&db, abc).is_err()); -// -// abc.set_c(&mut db).to(CycleQuery::A); // force new revision -// assert!(cycle_a(&db, abc).is_err()); -// } -// -// #[test] -// fn cycle_appears() { -// let mut db = salsa::DatabaseImpl::new(); -// // A --> B -// let abc = ABC::new(&db, CycleQuery::B, CycleQuery::None, CycleQuery::None); -// assert!(cycle_a(&db, abc).is_ok()); -// -// // A --> B -// // ^ | -// // +-----+ -// abc.set_b(&mut db).to(CycleQuery::A); -// assert!(cycle_a(&db, abc).is_err()); -// } -// -// #[test] -// fn cycle_disappears() { -// let mut db = salsa::DatabaseImpl::new(); -// -// // A --> B -// // ^ | -// // +-----+ -// let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None); -// assert!(cycle_a(&db, abc).is_err()); -// -// // A --> B -// abc.set_b(&mut db).to(CycleQuery::None); -// assert!(cycle_a(&db, abc).is_ok()); -// } -// -// /// A variant on `cycle_disappears` in which the values of -// /// `a` and `b` are set with durability values. -// /// If we are not careful, this could cause us to overlook -// /// the fact that the cycle will no longer occur. -// #[test] -// fn cycle_disappears_durability() { -// let mut db = salsa::DatabaseImpl::new(); -// let abc = ABC::new( -// &mut db, -// CycleQuery::None, -// CycleQuery::None, -// CycleQuery::None, -// ); -// abc.set_a(&mut db) -// .with_durability(Durability::LOW) -// .to(CycleQuery::B); -// abc.set_b(&mut db) -// .with_durability(Durability::HIGH) -// .to(CycleQuery::A); -// -// assert!(cycle_a(&db, abc).is_err()); -// -// // At this point, `a` read `LOW` input, and `b` read `HIGH` input. However, -// // because `b` participates in the same cycle as `a`, its final durability -// // should be `LOW`. -// // -// // Check that setting a `LOW` input causes us to re-execute `b` query, and -// // observe that the cycle goes away. -// abc.set_a(&mut db) -// .with_durability(Durability::LOW) -// .to(CycleQuery::None); -// -// assert!(cycle_b(&mut db, abc).is_ok()); -// } -// -// #[test] -// fn cycle_mixed_1() { -// salsa::DatabaseImpl::new().attach(|db| { -// // A --> B <-- C -// // | ^ -// // +-----+ -// let abc = ABC::new(db, CycleQuery::B, CycleQuery::C, CycleQuery::B); -// -// let expected = expect![[r#" -// [ -// "cycle_b(Id(0))", -// "cycle_c(Id(0))", -// ] -// "#]]; -// expected.assert_debug_eq(&cycle_c(db, abc).unwrap_err().cycle); -// }) -// } -// -// #[test] -// fn cycle_mixed_2() { -// salsa::DatabaseImpl::new().attach(|db| { -// // Configuration: -// // -// // A --> B --> C -// // ^ | -// // +-----------+ -// let abc = ABC::new(db, CycleQuery::B, CycleQuery::C, CycleQuery::A); -// let expected = expect![[r#" -// [ -// "cycle_a(Id(0))", -// "cycle_b(Id(0))", -// "cycle_c(Id(0))", -// ] -// "#]]; -// expected.assert_debug_eq(&cycle_a(db, abc).unwrap_err().cycle); -// }) -// } -// -// #[test] -// fn cycle_deterministic_order() { -// // No matter whether we start from A or B, we get the same set of participants: -// let f = || { -// let mut db = salsa::DatabaseImpl::new(); -// -// // A --> B -// // ^ | -// // +-----+ -// let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None); -// (db, abc) -// }; -// let (db, abc) = f(); -// let a = cycle_a(&db, abc); -// let (db, abc) = f(); -// let b = cycle_b(&db, abc); -// let expected = expect![[r#" -// ( -// [ -// "cycle_a(Id(0))", -// "cycle_b(Id(0))", -// ], -// [ -// "cycle_a(Id(0))", -// "cycle_b(Id(0))", -// ], -// ) -// "#]]; -// expected.assert_debug_eq(&(a.unwrap_err().cycle, b.unwrap_err().cycle)); -// } -// -// #[test] -// fn cycle_multiple() { -// // No matter whether we start from A or B, we get the same set of participants: -// let mut db = salsa::DatabaseImpl::new(); -// -// // Configuration: -// // -// // A --> B <-- C -// // ^ | ^ -// // +-----+ | -// // | | -// // +-----+ -// // -// // Here, conceptually, B encounters a cycle with A and then -// // recovers. -// let abc = ABC::new(&db, CycleQuery::B, CycleQuery::AthenC, CycleQuery::A); -// -// let c = cycle_c(&db, abc); -// let b = cycle_b(&db, abc); -// let a = cycle_a(&db, abc); -// let expected = expect![[r#" -// ( -// [ -// "cycle_a(Id(0))", -// "cycle_b(Id(0))", -// ], -// [ -// "cycle_a(Id(0))", -// "cycle_b(Id(0))", -// ], -// [ -// "cycle_a(Id(0))", -// "cycle_b(Id(0))", -// ], -// ) -// "#]]; -// expected.assert_debug_eq(&( -// c.unwrap_err().cycle, -// b.unwrap_err().cycle, -// a.unwrap_err().cycle, -// )); -// } -// -// #[test] -// fn cycle_recovery_set_but_not_participating() { -// salsa::DatabaseImpl::new().attach(|db| { -// // A --> C -+ -// // ^ | -// // +--+ -// let abc = ABC::new(db, CycleQuery::C, CycleQuery::None, CycleQuery::C); -// -// // Here we expect C to panic and A not to recover: -// let r = extract_cycle(|| drop(cycle_a(db, abc))); -// let expected = expect![[r#" -// [ -// cycle_c(Id(0)), -// ] -// "#]]; -// expected.assert_debug_eq(&r.all_participants(db)); -// }) -// } +use std::panic::{RefUnwindSafe, UnwindSafe}; + +use expect_test::expect; +use salsa::DatabaseImpl; +use salsa::Durability; + +// Axes: +// +// Threading +// * Intra-thread +// * Cross-thread -- part of cycle is on one thread, part on another +// +// Recovery strategies: +// * Panic +// * Fallback +// * Mixed -- multiple strategies within cycle participants +// +// Across revisions: +// * N/A -- only one revision +// * Present in new revision, not old +// * Present in old revision, not new +// * Present in both revisions +// +// Dependencies +// * Tracked +// * Untracked -- cycle participant(s) contain untracked reads +// +// Layers +// * Direct -- cycle participant is directly invoked from test +// * Indirect -- invoked a query that invokes the cycle +// +// +// | Thread | Recovery | Old, New | Dep style | Layers | Test Name | +// | ------ | -------- | -------- | --------- | ------ | --------- | +// | Intra | Panic | N/A | Tracked | direct | cycle_memoized | +// | Intra | Panic | N/A | Untracked | direct | cycle_volatile | +// | Intra | Fallback | N/A | Tracked | direct | cycle_cycle | +// | Intra | Fallback | N/A | Tracked | indirect | inner_cycle | +// | Intra | Fallback | Both | Tracked | direct | cycle_revalidate | +// | Intra | Fallback | New | Tracked | direct | cycle_appears | +// | Intra | Fallback | Old | Tracked | direct | cycle_disappears | +// | Intra | Fallback | Old | Tracked | direct | cycle_disappears_durability | +// | Intra | Mixed | N/A | Tracked | direct | cycle_mixed_1 | +// | Intra | Mixed | N/A | Tracked | direct | cycle_mixed_2 | +// | Cross | Panic | N/A | Tracked | both | parallel/parallel_cycle_none_recover.rs | +// | Cross | Fallback | N/A | Tracked | both | parallel/parallel_cycle_one_recover.rs | +// | Cross | Fallback | N/A | Tracked | both | parallel/parallel_cycle_mid_recover.rs | +// | Cross | Fallback | N/A | Tracked | both | parallel/parallel_cycle_all_recover.rs | + +#[derive(PartialEq, Eq, Hash, Clone, Debug)] +struct CycleDescription { + participants: Vec, +} + +use salsa::Database as Db; +use salsa::Setter; + +#[salsa::input] +struct MyInput {} + +#[salsa::tracked] +fn memoized_a(db: &dyn Db, input: MyInput) -> salsa::Result<()> { + memoized_b(db, input) +} + +#[salsa::tracked] +fn memoized_b(db: &dyn Db, input: MyInput) -> salsa::Result<()> { + memoized_a(db, input) +} + +#[salsa::tracked] +fn volatile_a(db: &dyn Db, input: MyInput) -> salsa::Result<()> { + db.report_untracked_read(); + volatile_b(db, input) +} + +#[salsa::tracked] +fn volatile_b(db: &dyn Db, input: MyInput) -> salsa::Result<()> { + db.report_untracked_read(); + volatile_a(db, input) +} + +/// The queries A, B, and C in `Database` can be configured +/// to invoke one another in arbitrary ways using this +/// enum. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +enum CycleQuery { + None, + A, + B, + C, + AthenC, +} + +#[salsa::input] +struct ABC { + a: CycleQuery, + b: CycleQuery, + c: CycleQuery, +} + +impl CycleQuery { + fn invoke(self, db: &dyn Db, abc: ABC) -> salsa::Result> { + match self { + CycleQuery::A => cycle_a(db, abc), + CycleQuery::B => cycle_b(db, abc), + CycleQuery::C => cycle_c(db, abc), + CycleQuery::AthenC => { + let _ = cycle_a(db, abc)?; + cycle_c(db, abc) + } + CycleQuery::None => Ok(None), + } + } +} + +#[salsa::tracked(recovery_fn=recover_a)] +fn cycle_a(db: &dyn Db, abc: ABC) -> salsa::Result> { + abc.a(db)?.invoke(db, abc) +} + +fn recover_a(db: &dyn Db, cycle: &salsa::Cycle, abc: ABC) -> Option { + Some(CycleDescription { + participants: cycle.participant_keys().map(|k| format!("{k:?}")).collect(), + }) +} + +#[salsa::tracked(recovery_fn=recover_b)] +fn cycle_b(db: &dyn Db, abc: ABC) -> salsa::Result> { + abc.b(db)?.invoke(db, abc) +} + +fn recover_b(db: &dyn Db, cycle: &salsa::Cycle, abc: ABC) -> Option { + Some(CycleDescription { + participants: cycle.participant_keys().map(|k| format!("{k:?}")).collect(), + }) +} + +#[salsa::tracked] +fn cycle_c(db: &dyn Db, abc: ABC) -> salsa::Result> { + abc.c(db)?.invoke(db, abc) +} + +#[track_caller] +fn extract_cycle(f: F) -> salsa::Cycle +where + F: FnOnce() -> salsa::Result + UnwindSafe, + R: std::fmt::Debug, +{ + let v = std::panic::catch_unwind(f); + if let Err(d) = &v { + if let Some(cycle) = d.downcast_ref::() { + return cycle.clone(); + } + } + panic!("unexpected value: {:?}", v) +} + +#[test] +fn cycle_memoized() -> salsa::Result<()> { + salsa::DatabaseImpl::new().attach(|db| { + let input = MyInput::new(db); + let cycle = extract_cycle(|| memoized_a(db, input)); + let expected = expect![[r#" + [ + memoized_a(Id(0)), + memoized_b(Id(0)), + ] + "#]]; + expected.assert_debug_eq(&cycle.all_participants(db)); + + Ok(()) + }) +} + +#[test] +fn cycle_volatile() { + salsa::DatabaseImpl::new().attach(|db| { + let input = MyInput::new(db); + let cycle = extract_cycle(|| volatile_a(db, input)); + let expected = expect![[r#" + [ + volatile_a(Id(0)), + volatile_b(Id(0)), + ] + "#]]; + expected.assert_debug_eq(&cycle.all_participants(db)); + }); +} + +#[test] +fn expect_cycle() -> salsa::Result<()> { + // A --> B + // ^ | + // +-----+ + + salsa::DatabaseImpl::new().attach(|db| { + let abc = ABC::new(db, CycleQuery::B, CycleQuery::A, CycleQuery::None); + assert!(cycle_a(db, abc)?.is_some()); + Ok(()) + }) +} + +#[test] +fn inner_cycle() -> salsa::Result<()> { + // A --> B <-- C + // ^ | + // +-----+ + salsa::DatabaseImpl::new().attach(|db| { + let abc = ABC::new(db, CycleQuery::B, CycleQuery::A, CycleQuery::B); + let err = cycle_c(db, abc)?; + assert!(err.is_some()); + let expected = expect![[r#" + [ + "cycle_a(Id(0))", + "cycle_b(Id(0))", + ] + "#]]; + expected.assert_debug_eq(&err.unwrap().participants); + + Ok(()) + }) +} + +#[test] +fn cycle_revalidate() -> salsa::Result<()> { + // A --> B + // ^ | + // +-----+ + let mut db = salsa::DatabaseImpl::new(); + let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None); + assert!(cycle_a(&db, abc)?.is_some()); + abc.set_b(&mut db).to(CycleQuery::A); // same value as default + assert!(cycle_a(&db, abc)?.is_some()); + + Ok(()) +} + +#[test] +fn cycle_recovery_unchanged_twice() -> salsa::Result<()> { + // A --> B + // ^ | + // +-----+ + let mut db = salsa::DatabaseImpl::new(); + let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None); + assert!(cycle_a(&db, abc)?.is_some()); + + abc.set_c(&mut db).to(CycleQuery::A); // force new revision + assert!(cycle_a(&db, abc)?.is_some()); + + Ok(()) +} + +#[test] +fn cycle_appears() -> salsa::Result<()> { + let mut db = salsa::DatabaseImpl::new(); + // A --> B + let abc = ABC::new(&db, CycleQuery::B, CycleQuery::None, CycleQuery::None); + assert!(cycle_a(&db, abc)?.is_none()); + + // A --> B + // ^ | + // +-----+ + abc.set_b(&mut db).to(CycleQuery::A); + assert!(cycle_a(&db, abc)?.is_some()); + + Ok(()) +} + +#[test] +fn cycle_disappears() -> salsa::Result<()> { + let mut db = salsa::DatabaseImpl::new(); + + // A --> B + // ^ | + // +-----+ + let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None); + assert!(cycle_a(&db, abc)?.is_some()); + + // A --> B + abc.set_b(&mut db).to(CycleQuery::None); + assert!(cycle_a(&db, abc)?.is_none()); + + Ok(()) +} + +/// A variant on `cycle_disappears` in which the values of +/// `a` and `b` are set with durability values. +/// If we are not careful, this could cause us to overlook +/// the fact that the cycle will no longer occur. +#[test] +fn cycle_disappears_durability() -> salsa::Result<()> { + let mut db = salsa::DatabaseImpl::new(); + let abc = ABC::new( + &mut db, + CycleQuery::None, + CycleQuery::None, + CycleQuery::None, + ); + abc.set_a(&mut db) + .with_durability(Durability::LOW) + .to(CycleQuery::B); + abc.set_b(&mut db) + .with_durability(Durability::HIGH) + .to(CycleQuery::A); + + assert!(cycle_a(&db, abc)?.is_some()); + + // At this point, `a` read `LOW` input, and `b` read `HIGH` input. However, + // because `b` participates in the same cycle as `a`, its final durability + // should be `LOW`. + // + // Check that setting a `LOW` input causes us to re-execute `b` query, and + // observe that the cycle goes away. + abc.set_a(&mut db) + .with_durability(Durability::LOW) + .to(CycleQuery::None); + + assert!(cycle_b(&mut db, abc)?.is_none()); + + Ok(()) +} + +#[test] +fn cycle_mixed_1() -> salsa::Result<()> { + salsa::DatabaseImpl::new().attach(|db| { + // A --> B <-- C + // | ^ + // +-----+ + let abc = ABC::new(db, CycleQuery::B, CycleQuery::C, CycleQuery::B); + + let expected = expect![[r#" + [ + "cycle_b(Id(0))", + "cycle_c(Id(0))", + ] + "#]]; + expected.assert_debug_eq(&cycle_c(db, abc)?.unwrap().participants); + + Ok(()) + }) +} + +#[test] +fn cycle_mixed_2() -> salsa::Result<()> { + salsa::DatabaseImpl::new().attach(|db| { + // Configuration: + // + // A --> B --> C + // ^ | + // +-----------+ + let abc = ABC::new(db, CycleQuery::B, CycleQuery::C, CycleQuery::A); + let expected = expect![[r#" + [ + "cycle_a(Id(0))", + "cycle_b(Id(0))", + "cycle_c(Id(0))", + ] + "#]]; + expected.assert_debug_eq(&cycle_a(db, abc)?.unwrap().participants); + Ok(()) + }) +} + +#[test] +fn cycle_deterministic_order() -> salsa::Result<()> { + // No matter whether we start from A or B, we get the same set of participants: + let f = || { + let mut db = salsa::DatabaseImpl::new(); + + // A --> B + // ^ | + // +-----+ + let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None); + (db, abc) + }; + let (db, abc) = f(); + let a = cycle_a(&db, abc)?; + let (db, abc) = f(); + let b = cycle_b(&db, abc)?; + let expected = expect![[r#" + ( + [ + "cycle_a(Id(0))", + "cycle_b(Id(0))", + ], + [ + "cycle_a(Id(0))", + "cycle_b(Id(0))", + ], + ) + "#]]; + expected.assert_debug_eq(&(a.unwrap().participants, b.unwrap().participants)); + + Ok(()) +} + +#[test] +fn cycle_multiple() -> salsa::Result<()> { + // No matter whether we start from A or B, we get the same set of participants: + let mut db = salsa::DatabaseImpl::new(); + + // Configuration: + // + // A --> B <-- C + // ^ | ^ + // +-----+ | + // | | + // +-----+ + // + // Here, conceptually, B encounters a cycle with A and then + // recovers. + let abc = ABC::new(&db, CycleQuery::B, CycleQuery::AthenC, CycleQuery::A); + + let c = cycle_c(&db, abc)?; + let b = cycle_b(&db, abc)?; + let a = cycle_a(&db, abc)?; + let expected = expect![[r#" + ( + [ + "cycle_a(Id(0))", + "cycle_b(Id(0))", + ], + [ + "cycle_a(Id(0))", + "cycle_b(Id(0))", + ], + [ + "cycle_a(Id(0))", + "cycle_b(Id(0))", + ], + ) + "#]]; + expected.assert_debug_eq(&( + c.unwrap().participants, + b.unwrap().participants, + a.unwrap().participants, + )); + + Ok(()) +} + +#[test] +fn cycle_recovery_set_but_not_participating() -> salsa::Result<()> { + salsa::DatabaseImpl::new().attach(|db| { + // A --> C -+ + // ^ | + // +--+ + let abc = ABC::new(db, CycleQuery::C, CycleQuery::None, CycleQuery::C); + + // Here we expect C to panic and A not to recover: + let r = extract_cycle(|| { + drop(cycle_a(db, abc)?); + Ok(()) + }); + let expected = expect![[r#" + [ + cycle_c(Id(0)), + ] + "#]]; + expected.assert_debug_eq(&r.all_participants(db)); + + Ok(()) + }) +} From f167e19ea67c58b60199800029b235d579ba2151 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Mon, 2 Sep 2024 17:18:23 +0200 Subject: [PATCH 05/10] Migrate parallel-cycle tests --- .../salsa-macro-rules/src/setup_tracked_fn.rs | 2 +- src/function.rs | 2 +- src/function/execute.rs | 2 +- tests/cycles.rs | 20 +- tests/parallel/parallel_cycle_all_recover.rs | 208 +++++++++--------- tests/parallel/parallel_cycle_mid_recover.rs | 204 ++++++++--------- tests/parallel/parallel_cycle_one_recover.rs | 178 +++++++-------- 7 files changed, 312 insertions(+), 304 deletions(-) diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index fb48e7aa9..a68426451 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -183,7 +183,7 @@ macro_rules! setup_tracked_fn { db: &$db_lt dyn $Db, cycle: &$zalsa::Cycle, ($($input_id),*): ($($input_ty),*) - ) -> Self::Output<$db_lt> { + ) -> salsa::Result> { $($cycle_recovery_fn)*(db, cycle, $($input_id),*) } diff --git a/src/function.rs b/src/function.rs index 42ed227c4..fe1947dfa 100644 --- a/src/function.rs +++ b/src/function.rs @@ -74,7 +74,7 @@ pub trait Configuration: Any { db: &'db Self::DbView, cycle: &Cycle, input: Self::Input<'db>, - ) -> Self::Output<'db>; + ) -> crate::Result>; } /// Function ingredients are the "workhorse" of salsa. diff --git a/src/function/execute.rs b/src/function/execute.rs index ce8576815..fd9425b9e 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -65,7 +65,7 @@ where crate::cycle::CycleRecoveryStrategy::Fallback => { if let Some(c) = active_query.take_cycle() { assert!(c.is(&cycle)); - C::recover_from_cycle(db, &cycle, C::id_to_input(db, id)) + C::recover_from_cycle(db, &cycle, C::id_to_input(db, id))? } else { // we are not a participant in this cycle debug_assert!(!cycle diff --git a/tests/cycles.rs b/tests/cycles.rs index 634101bd2..7c2c3a3c1 100644 --- a/tests/cycles.rs +++ b/tests/cycles.rs @@ -121,10 +121,14 @@ fn cycle_a(db: &dyn Db, abc: ABC) -> salsa::Result> { abc.a(db)?.invoke(db, abc) } -fn recover_a(db: &dyn Db, cycle: &salsa::Cycle, abc: ABC) -> Option { - Some(CycleDescription { +fn recover_a( + db: &dyn Db, + cycle: &salsa::Cycle, + abc: ABC, +) -> salsa::Result> { + Ok(Some(CycleDescription { participants: cycle.participant_keys().map(|k| format!("{k:?}")).collect(), - }) + })) } #[salsa::tracked(recovery_fn=recover_b)] @@ -132,10 +136,14 @@ fn cycle_b(db: &dyn Db, abc: ABC) -> salsa::Result> { abc.b(db)?.invoke(db, abc) } -fn recover_b(db: &dyn Db, cycle: &salsa::Cycle, abc: ABC) -> Option { - Some(CycleDescription { +fn recover_b( + db: &dyn Db, + cycle: &salsa::Cycle, + abc: ABC, +) -> salsa::Result> { + Ok(Some(CycleDescription { participants: cycle.participant_keys().map(|k| format!("{k:?}")).collect(), - }) + })) } #[salsa::tracked] diff --git a/tests/parallel/parallel_cycle_all_recover.rs b/tests/parallel/parallel_cycle_all_recover.rs index 552a6b292..9d415158d 100644 --- a/tests/parallel/parallel_cycle_all_recover.rs +++ b/tests/parallel/parallel_cycle_all_recover.rs @@ -1,104 +1,104 @@ -// //! Test for cycle recover spread across two threads. -// //! See `../cycles.rs` for a complete listing of cycle tests, -// //! both intra and cross thread. -// -// use crate::setup::Knobs; -// use crate::setup::KnobsDatabase; -// -// #[salsa::input] -// pub(crate) struct MyInput { -// field: i32, -// } -// -// #[salsa::tracked(recovery_fn = recover_a1)] -// pub(crate) fn a1(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { -// // Wait to create the cycle until both threads have entered -// db.signal(1); -// db.wait_for(2); -// -// a2(db, input) -// } -// -// fn recover_a1(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { -// dbg!("recover_a1"); -// key.field(db) * 10 + 1 -// } -// -// #[salsa::tracked(recovery_fn=recover_a2)] -// pub(crate) fn a2(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { -// b1(db, input) -// } -// -// fn recover_a2(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { -// dbg!("recover_a2"); -// key.field(db) * 10 + 2 -// } -// -// #[salsa::tracked(recovery_fn=recover_b1)] -// pub(crate) fn b1(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { -// // Wait to create the cycle until both threads have entered -// db.wait_for(1); -// db.signal(2); -// -// // Wait for thread A to block on this thread -// db.wait_for(3); -// b2(db, input) -// } -// -// fn recover_b1(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { -// dbg!("recover_b1"); -// key.field(db) * 20 + 1 -// } -// -// #[salsa::tracked(recovery_fn=recover_b2)] -// pub(crate) fn b2(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { -// a1(db, input) -// } -// -// fn recover_b2(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { -// dbg!("recover_b2"); -// key.field(db) * 20 + 2 -// } -// -// // Recover cycle test: -// // -// // The pattern is as follows. -// // -// // Thread A Thread B -// // -------- -------- -// // a1 b1 -// // | wait for stage 1 (blocks) -// // signal stage 1 | -// // wait for stage 2 (blocks) (unblocked) -// // | signal stage 2 -// // (unblocked) wait for stage 3 (blocks) -// // a2 | -// // b1 (blocks -> stage 3) | -// // | (unblocked) -// // | b2 -// // | a1 (cycle detected, recovers) -// // | b2 completes, recovers -// // | b1 completes, recovers -// // a2 sees cycle, recovers -// // a1 completes, recovers -// -// #[test] -// fn execute() { -// let db = Knobs::default(); -// -// let input = MyInput::new(&db, 1); -// -// let thread_a = std::thread::spawn({ -// let db = db.clone(); -// db.knobs().signal_on_will_block.store(3); -// move || a1(&db, input).unwrap() -// }); -// -// let thread_b = std::thread::spawn({ -// let db = db.clone(); -// move || b1(&db, input).unwrap() -// }); -// -// assert_eq!(thread_a.join().unwrap(), 11); -// assert_eq!(thread_b.join().unwrap(), 21); -// } +//! Test for cycle recover spread across two threads. +//! See `../cycles.rs` for a complete listing of cycle tests, +//! both intra and cross thread. + +use crate::setup::Knobs; +use crate::setup::KnobsDatabase; + +#[salsa::input] +pub(crate) struct MyInput { + field: i32, +} + +#[salsa::tracked(recovery_fn = recover_a1)] +pub(crate) fn a1(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { + // Wait to create the cycle until both threads have entered + db.signal(1); + db.wait_for(2); + + a2(db, input) +} + +fn recover_a1(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> salsa::Result { + dbg!("recover_a1"); + Ok(key.field(db)? * 10 + 1) +} + +#[salsa::tracked(recovery_fn=recover_a2)] +pub(crate) fn a2(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { + b1(db, input) +} + +fn recover_a2(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> salsa::Result { + dbg!("recover_a2"); + Ok(key.field(db)? * 10 + 2) +} + +#[salsa::tracked(recovery_fn=recover_b1)] +pub(crate) fn b1(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { + // Wait to create the cycle until both threads have entered + db.wait_for(1); + db.signal(2); + + // Wait for thread A to block on this thread + db.wait_for(3); + b2(db, input) +} + +fn recover_b1(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> salsa::Result { + dbg!("recover_b1"); + Ok(key.field(db)? * 20 + 1) +} + +#[salsa::tracked(recovery_fn=recover_b2)] +pub(crate) fn b2(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { + a1(db, input) +} + +fn recover_b2(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> salsa::Result { + dbg!("recover_b2"); + Ok(key.field(db)? * 20 + 2) +} + +// Recover cycle test: +// +// The pattern is as follows. +// +// Thread A Thread B +// -------- -------- +// a1 b1 +// | wait for stage 1 (blocks) +// signal stage 1 | +// wait for stage 2 (blocks) (unblocked) +// | signal stage 2 +// (unblocked) wait for stage 3 (blocks) +// a2 | +// b1 (blocks -> stage 3) | +// | (unblocked) +// | b2 +// | a1 (cycle detected, recovers) +// | b2 completes, recovers +// | b1 completes, recovers +// a2 sees cycle, recovers +// a1 completes, recovers + +#[test] +fn execute() { + let db = Knobs::default(); + + let input = MyInput::new(&db, 1); + + let thread_a = std::thread::spawn({ + let db = db.clone(); + db.knobs().signal_on_will_block.store(3); + move || a1(&db, input).unwrap() + }); + + let thread_b = std::thread::spawn({ + let db = db.clone(); + move || b1(&db, input).unwrap() + }); + + assert_eq!(thread_a.join().unwrap(), 11); + assert_eq!(thread_b.join().unwrap(), 21); +} diff --git a/tests/parallel/parallel_cycle_mid_recover.rs b/tests/parallel/parallel_cycle_mid_recover.rs index 5764db750..39c2edc73 100644 --- a/tests/parallel/parallel_cycle_mid_recover.rs +++ b/tests/parallel/parallel_cycle_mid_recover.rs @@ -1,102 +1,102 @@ -// //! Test for cycle recover spread across two threads. -// //! See `../cycles.rs` for a complete listing of cycle tests, -// //! both intra and cross thread. -// -// use crate::setup::{Knobs, KnobsDatabase}; -// -// #[salsa::input] -// pub(crate) struct MyInput { -// field: i32, -// } -// -// #[salsa::tracked] -// pub(crate) fn a1(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { -// // tell thread b we have started -// db.signal(1); -// -// // wait for thread b to block on a1 -// db.wait_for(2); -// -// a2(db, input) -// } -// -// #[salsa::tracked] -// pub(crate) fn a2(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { -// // create the cycle -// b1(db, input) -// } -// -// #[salsa::tracked(recovery_fn=recover_b1)] -// pub(crate) fn b1(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { -// // wait for thread a to have started -// db.wait_for(1); -// b2(db, input) -// } -// -// fn recover_b1(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { -// dbg!("recover_b1"); -// key.field(db) * 20 + 2 -// } -// -// #[salsa::tracked] -// pub(crate) fn b2(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { -// // will encounter a cycle but recover -// b3(db, input); -// b1(db, input); // hasn't recovered yet -// 0 -// } -// -// #[salsa::tracked(recovery_fn=recover_b3)] -// pub(crate) fn b3(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { -// // will block on thread a, signaling stage 2 -// a1(db, input) -// } -// -// fn recover_b3(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { -// dbg!("recover_b3"); -// key.field(db) * 200 + 2 -// } -// -// // Recover cycle test: -// // -// // The pattern is as follows. -// // -// // Thread A Thread B -// // -------- -------- -// // a1 b1 -// // | wait for stage 1 (blocks) -// // signal stage 1 | -// // wait for stage 2 (blocks) (unblocked) -// // | | -// // | b2 -// // | b3 -// // | a1 (blocks -> stage 2) -// // (unblocked) | -// // a2 (cycle detected) | -// // b3 recovers -// // b2 resumes -// // b1 recovers -// -// #[test] -// fn execute() { -// let db = Knobs::default(); -// -// let input = MyInput::new(&db, 1); -// -// let thread_a = std::thread::spawn({ -// let db = db.clone(); -// move || a1(&db, input).unwrap() -// }); -// -// let thread_b = std::thread::spawn({ -// let db = db.clone(); -// db.knobs().signal_on_will_block.store(3); -// move || b1(&db, input).unwrap() -// }); -// -// // We expect that the recovery function yields -// // `1 * 20 + 2`, which is returned (and forwarded) -// // to b1, and from there to a2 and a1. -// assert_eq!(thread_a.join().unwrap(), 22); -// assert_eq!(thread_b.join().unwrap(), 22); -// } +//! Test for cycle recover spread across two threads. +//! See `../cycles.rs` for a complete listing of cycle tests, +//! both intra and cross thread. + +use crate::setup::{Knobs, KnobsDatabase}; + +#[salsa::input] +pub(crate) struct MyInput { + field: i32, +} + +#[salsa::tracked] +pub(crate) fn a1(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { + // tell thread b we have started + db.signal(1); + + // wait for thread b to block on a1 + db.wait_for(2); + + a2(db, input) +} + +#[salsa::tracked] +pub(crate) fn a2(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { + // create the cycle + b1(db, input) +} + +#[salsa::tracked(recovery_fn=recover_b1)] +pub(crate) fn b1(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { + // wait for thread a to have started + db.wait_for(1); + b2(db, input) +} + +fn recover_b1(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> salsa::Result { + dbg!("recover_b1"); + Ok(key.field(db)? * 20 + 2) +} + +#[salsa::tracked] +pub(crate) fn b2(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { + // will encounter a cycle but recover + b3(db, input)?; + b1(db, input)?; // hasn't recovered yet + Ok(0) +} + +#[salsa::tracked(recovery_fn=recover_b3)] +pub(crate) fn b3(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { + // will block on thread a, signaling stage 2 + a1(db, input) +} + +fn recover_b3(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> salsa::Result { + dbg!("recover_b3"); + Ok(key.field(db)? * 200 + 2) +} + +// Recover cycle test: +// +// The pattern is as follows. +// +// Thread A Thread B +// -------- -------- +// a1 b1 +// | wait for stage 1 (blocks) +// signal stage 1 | +// wait for stage 2 (blocks) (unblocked) +// | | +// | b2 +// | b3 +// | a1 (blocks -> stage 2) +// (unblocked) | +// a2 (cycle detected) | +// b3 recovers +// b2 resumes +// b1 recovers + +#[test] +fn execute() { + let db = Knobs::default(); + + let input = MyInput::new(&db, 1); + + let thread_a = std::thread::spawn({ + let db = db.clone(); + move || a1(&db, input).unwrap() + }); + + let thread_b = std::thread::spawn({ + let db = db.clone(); + db.knobs().signal_on_will_block.store(3); + move || b1(&db, input).unwrap() + }); + + // We expect that the recovery function yields + // `1 * 20 + 2`, which is returned (and forwarded) + // to b1, and from there to a2 and a1. + assert_eq!(thread_a.join().unwrap(), 22); + assert_eq!(thread_b.join().unwrap(), 22); +} diff --git a/tests/parallel/parallel_cycle_one_recover.rs b/tests/parallel/parallel_cycle_one_recover.rs index d31819327..9163b0c30 100644 --- a/tests/parallel/parallel_cycle_one_recover.rs +++ b/tests/parallel/parallel_cycle_one_recover.rs @@ -1,91 +1,91 @@ -// //! Test for cycle recover spread across two threads. -// //! See `../cycles.rs` for a complete listing of cycle tests, -// //! both intra and cross thread. +//! Test for cycle recover spread across two threads. +//! See `../cycles.rs` for a complete listing of cycle tests, +//! both intra and cross thread. + +use crate::setup::{Knobs, KnobsDatabase}; + +#[salsa::input] +pub(crate) struct MyInput { + field: i32, +} + +#[salsa::tracked] +pub(crate) fn a1(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { + // Wait to create the cycle until both threads have entered + db.signal(1); + db.wait_for(2); + + a2(db, input) +} + +#[salsa::tracked(recovery_fn=recover)] +pub(crate) fn a2(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { + b1(db, input) +} + +fn recover(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> salsa::Result { + dbg!("recover"); + Ok(key.field(db)? * 20 + 2) +} + +#[salsa::tracked] +pub(crate) fn b1(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { + // Wait to create the cycle until both threads have entered + db.wait_for(1); + db.signal(2); + + // Wait for thread A to block on this thread + db.wait_for(3); + b2(db, input) +} + +#[salsa::tracked] +pub(crate) fn b2(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { + a1(db, input) +} + +// Recover cycle test: // -// use crate::setup::{Knobs, KnobsDatabase}; +// The pattern is as follows. // -// #[salsa::input] -// pub(crate) struct MyInput { -// field: i32, -// } -// -// #[salsa::tracked] -// pub(crate) fn a1(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { -// // Wait to create the cycle until both threads have entered -// db.signal(1); -// db.wait_for(2); -// -// a2(db, input) -// } -// -// #[salsa::tracked(recovery_fn=recover)] -// pub(crate) fn a2(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { -// b1(db, input) -// } -// -// fn recover(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { -// dbg!("recover"); -// key.field(db) * 20 + 2 -// } -// -// #[salsa::tracked] -// pub(crate) fn b1(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { -// // Wait to create the cycle until both threads have entered -// db.wait_for(1); -// db.signal(2); -// -// // Wait for thread A to block on this thread -// db.wait_for(3); -// b2(db, input) -// } -// -// #[salsa::tracked] -// pub(crate) fn b2(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { -// a1(db, input) -// } -// -// // Recover cycle test: -// // -// // The pattern is as follows. -// // -// // Thread A Thread B -// // -------- -------- -// // a1 b1 -// // | wait for stage 1 (blocks) -// // signal stage 1 | -// // wait for stage 2 (blocks) (unblocked) -// // | signal stage 2 -// // (unblocked) wait for stage 3 (blocks) -// // a2 | -// // b1 (blocks -> stage 3) | -// // | (unblocked) -// // | b2 -// // | a1 (cycle detected) -// // a2 recovery fn executes | -// // a1 completes normally | -// // b2 completes, recovers -// // b1 completes, recovers -// -// #[test] -// fn execute() { -// let db = Knobs::default(); -// -// let input = MyInput::new(&db, 1); -// -// let thread_a = std::thread::spawn({ -// let db = db.clone(); -// db.knobs().signal_on_will_block.store(3); -// move || a1(&db, input).unwrap() -// }); -// -// let thread_b = std::thread::spawn({ -// let db = db.clone(); -// move || b1(&db, input).unwrap() -// }); -// -// // We expect that the recovery function yields -// // `1 * 20 + 2`, which is returned (and forwarded) -// // to b1, and from there to a2 and a1. -// assert_eq!(thread_a.join().unwrap(), 22); -// assert_eq!(thread_b.join().unwrap(), 22); -// } +// Thread A Thread B +// -------- -------- +// a1 b1 +// | wait for stage 1 (blocks) +// signal stage 1 | +// wait for stage 2 (blocks) (unblocked) +// | signal stage 2 +// (unblocked) wait for stage 3 (blocks) +// a2 | +// b1 (blocks -> stage 3) | +// | (unblocked) +// | b2 +// | a1 (cycle detected) +// a2 recovery fn executes | +// a1 completes normally | +// b2 completes, recovers +// b1 completes, recovers + +#[test] +fn execute() { + let db = Knobs::default(); + + let input = MyInput::new(&db, 1); + + let thread_a = std::thread::spawn({ + let db = db.clone(); + db.knobs().signal_on_will_block.store(3); + move || a1(&db, input).unwrap() + }); + + let thread_b = std::thread::spawn({ + let db = db.clone(); + move || b1(&db, input).unwrap() + }); + + // We expect that the recovery function yields + // `1 * 20 + 2`, which is returned (and forwarded) + // to b1, and from there to a2 and a1. + assert_eq!(thread_a.join().unwrap(), 22); + assert_eq!(thread_b.join().unwrap(), 22); +} From ccaa5d17b199ca7d82ce79c590f66bb99deb00b4 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Mon, 2 Sep 2024 17:50:40 +0200 Subject: [PATCH 06/10] Add test for cycle error propagation --- src/result.rs | 2 +- tests/cycle_dropping_error_panics.rs | 93 ++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+), 1 deletion(-) create mode 100644 tests/cycle_dropping_error_panics.rs diff --git a/src/result.rs b/src/result.rs index a44ef50c1..f0806000e 100644 --- a/src/result.rs +++ b/src/result.rs @@ -28,7 +28,7 @@ impl Error { Self { kind: Box::new(ErrorKind::Cycle(CycleError { cycle, - bomb: DropBomb::new("TODO"), + bomb: DropBomb::new("Cycle errors must be propagated so that Salsa can resolve the cycle. If you see this message outside a salsa query, please open an issue."), })), } } diff --git a/tests/cycle_dropping_error_panics.rs b/tests/cycle_dropping_error_panics.rs new file mode 100644 index 000000000..21fee1467 --- /dev/null +++ b/tests/cycle_dropping_error_panics.rs @@ -0,0 +1,93 @@ +#![allow(warnings)] + +use std::panic::{RefUnwindSafe, UnwindSafe}; +use std::sync::atomic::AtomicUsize; + +use expect_test::expect; +use salsa::Cycle; +use salsa::DatabaseImpl; +use salsa::Durability; + +use salsa::Database as Db; +use salsa::Setter; + +#[salsa::input] +struct MyInput { + field: u32, +} + +#[salsa::tracked(recovery_fn = recover_a)] +fn cycle_a(db: &dyn Db, input: MyInput) -> salsa::Result { + cycle_b(db, input) +} + +fn recover_a(db: &dyn Db, cycle: &Cycle, input: MyInput) -> salsa::Result { + Ok("recovered".to_string()) +} + +#[salsa::tracked] +fn cycle_b(db: &dyn Db, input: MyInput) -> salsa::Result { + Ok(cycle_a(db, input).unwrap_or_else(|error| format!("Suppressed error: {error}"))) +} + +#[test] +#[should_panic(expected = "Cycle errors must be propagated so that Salsa can resolve the cycle.")] +fn execute() { + salsa::DatabaseImpl::new().attach(|db| { + let input = MyInput::new(db, 2); + let result = cycle_a(db, input); + + panic!("Expected query to panic"); + }) +} + +#[salsa::tracked] +fn deferred_cycle_a(db: &dyn Db, input: MyInput) -> salsa::Result { + deferred_cycle_b(db, input) +} + +// Simulates some global state in the database that is updated during a query. +// An example of this is an input-map. +static EVEN_COUNT: AtomicUsize = AtomicUsize::new(0); + +#[salsa::tracked] +fn deferred_cycle_b(db: &dyn Db, input: MyInput) -> salsa::Result { + let is_even = input.field(db)? % 2 == 0; + if is_even { + EVEN_COUNT.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + } + + match deferred_cycle_c(db, input) { + Ok(result) => Ok(result), + Err(err) => { + if is_even { + EVEN_COUNT.fetch_sub(1, std::sync::atomic::Ordering::SeqCst); + } + + Err(err) + } + } +} + +#[salsa::tracked(recovery_fn = recover_c)] +fn deferred_cycle_c(db: &dyn Db, input: MyInput) -> salsa::Result { + deferred_cycle_a(db, input) +} + +fn recover_c(db: &dyn Db, cycle: &Cycle, input: MyInput) -> salsa::Result { + Ok("recovered C".to_string()) +} + +// A query captures the error but propagates it before completion. +#[test] +fn deferred_propagation() { + salsa::DatabaseImpl::new().attach(|db| { + let input = MyInput::new(db, 2); + let result = deferred_cycle_a(db, input); + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), "recovered C"); + + assert_eq!(EVEN_COUNT.load(std::sync::atomic::Ordering::SeqCst), 1); + }) +} From 79e0155b5fc9cb00303cf7f73889396067304f38 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Mon, 2 Sep 2024 18:20:36 +0200 Subject: [PATCH 07/10] Add test for cancellation error propagation --- src/result.rs | 28 +++- tests/cycle_dropping_error_panics.rs | 6 +- tests/parallel/main.rs | 1 + tests/parallel/parallel_cancellation.rs | 11 +- .../parallel_cancellation_capture_error.rs | 129 ++++++++++++++++++ 5 files changed, 166 insertions(+), 9 deletions(-) create mode 100644 tests/parallel/parallel_cancellation_capture_error.rs diff --git a/src/result.rs b/src/result.rs index f0806000e..89210b7b2 100644 --- a/src/result.rs +++ b/src/result.rs @@ -1,4 +1,4 @@ -use crate::Cycle; +use crate::{with_attached_database, Cycle}; use drop_bomb::DropBomb; use std::fmt; use std::fmt::Debug; @@ -20,7 +20,10 @@ pub struct Error { impl Error { pub(crate) fn cancelled(reason: Cancelled) -> Self { Error { - kind: Box::new(ErrorKind::Cancelled(reason)), + kind: Box::new(ErrorKind::Cancelled(CancelledError { + reason, + bomb: DropBomb::new("Cancellation errors must be propagated inside salsa queries. If you see this message outside a salsa query, please open an issue."), + })), } } @@ -65,7 +68,7 @@ impl std::error::Error for Error {} #[derive(Debug)] pub(crate) enum ErrorKind { Cycle(CycleError), - Cancelled(Cancelled), + Cancelled(CancelledError), } #[derive(Debug)] @@ -81,6 +84,21 @@ impl CycleError { } } +#[derive(Debug)] +pub(crate) struct CancelledError { + reason: Cancelled, + bomb: DropBomb, +} + +impl Drop for CancelledError { + fn drop(&mut self) { + if with_attached_database(|_| {}).is_none() { + // We are outside a query. It's okay if the user drops the error now + self.bomb.defuse(); + } + } +} + // FIXME implement drop for Cancelled. /// A panic payload indicating that execution of a salsa query was cancelled. @@ -96,9 +114,9 @@ pub(crate) enum Cancelled { PropagatedPanic, } -impl std::fmt::Display for Cancelled { +impl std::fmt::Display for CancelledError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let why = match self { + let why = match self.reason { Cancelled::PendingWrite => "pending write", Cancelled::PropagatedPanic => "propagated panic", }; diff --git a/tests/cycle_dropping_error_panics.rs b/tests/cycle_dropping_error_panics.rs index 21fee1467..90c784f43 100644 --- a/tests/cycle_dropping_error_panics.rs +++ b/tests/cycle_dropping_error_panics.rs @@ -54,14 +54,14 @@ static EVEN_COUNT: AtomicUsize = AtomicUsize::new(0); fn deferred_cycle_b(db: &dyn Db, input: MyInput) -> salsa::Result { let is_even = input.field(db)? % 2 == 0; if is_even { - EVEN_COUNT.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + EVEN_COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed); } match deferred_cycle_c(db, input) { Ok(result) => Ok(result), Err(err) => { if is_even { - EVEN_COUNT.fetch_sub(1, std::sync::atomic::Ordering::SeqCst); + EVEN_COUNT.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); } Err(err) @@ -88,6 +88,6 @@ fn deferred_propagation() { assert!(result.is_ok()); assert_eq!(result.unwrap(), "recovered C"); - assert_eq!(EVEN_COUNT.load(std::sync::atomic::Ordering::SeqCst), 1); + assert_eq!(EVEN_COUNT.load(std::sync::atomic::Ordering::Relaxed), 1); }) } diff --git a/tests/parallel/main.rs b/tests/parallel/main.rs index 578a83cb3..1bda47b66 100644 --- a/tests/parallel/main.rs +++ b/tests/parallel/main.rs @@ -1,6 +1,7 @@ mod setup; mod parallel_cancellation; +mod parallel_cancellation_capture_error; mod parallel_cycle_all_recover; mod parallel_cycle_mid_recover; mod parallel_cycle_none_recover; diff --git a/tests/parallel/parallel_cancellation.rs b/tests/parallel/parallel_cancellation.rs index 365061cb4..d07fa63a6 100644 --- a/tests/parallel/parallel_cancellation.rs +++ b/tests/parallel/parallel_cancellation.rs @@ -53,6 +53,7 @@ fn execute() { }) .unwrap(); + db.wait_for(1); db.signal_on_did_cancel.store(2); input.set_field(&mut db).to(2); @@ -63,7 +64,15 @@ fn execute() { expect_test::expect![[r#" Error { kind: Cancelled( - PendingWrite, + CancelledError { + reason: PendingWrite, + bomb: DropBomb( + RealBomb { + msg: "Cancellation errors must be propagated inside salsa queries. If you see this message outside a salsa query, please open an issue.", + defused: false, + }, + ), + }, ), } "#]] diff --git a/tests/parallel/parallel_cancellation_capture_error.rs b/tests/parallel/parallel_cancellation_capture_error.rs new file mode 100644 index 000000000..d19aeeb64 --- /dev/null +++ b/tests/parallel/parallel_cancellation_capture_error.rs @@ -0,0 +1,129 @@ +//! Test that suppressing a cancellation error inside a query +//! panics in debug mode. + +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering; + +use salsa::Setter; + +use crate::setup::Knobs; +use crate::setup::KnobsDatabase; + +#[salsa::input] +struct MyInput { + field: i32, +} + +#[salsa::tracked] +fn a1(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { + db.signal(1); + db.wait_for(2); + + match dummy(db, input) { + Ok(result) => Ok(result), + Err(_) => Ok("Suppressed cancellation".to_string()), + } +} + +#[salsa::tracked] +fn dummy(_db: &dyn KnobsDatabase, _input: MyInput) -> salsa::Result { + Ok("should never get here!".to_string()) +} + +// Cancellation signalling test +// +// The pattern is as follows. +// +// Thread A Thread B +// -------- -------- +// a1 +// | wait for stage 1 +// signal stage 1 set input, triggers cancellation +// wait for stage 2 (blocks) triggering cancellation sends stage 2 +// | +// (unblocked) +// dummy +// drops error -> panics + +#[test] +fn execute() { + let mut db = Knobs::default(); + + let input = MyInput::new(&db, 1); + + let thread_a = std::thread::Builder::new() + .name("a".to_string()) + .spawn({ + let db = db.clone(); + move || a1(&db, input) + }) + .unwrap(); + + db.wait_for(1); + db.signal_on_did_cancel.store(2); + input.set_field(&mut db).to(2); + + // Assert thread A panicked because it captured the error + let error = thread_a.join().unwrap_err(); + + if let Some(error) = error.downcast_ref::() { + assert_eq!(*error, "Cancellation errors must be propagated inside salsa queries. If you see this message outside a salsa query, please open an issue."); + } else { + panic!("Thread A should have panicked!") + } +} + +// Simulates some global state in the database that is updated during a query. +// An example of this is an input-map. +static EVEN_COUNT: AtomicUsize = AtomicUsize::new(0); + +#[salsa::tracked] +fn a1_deferred(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { + let is_even = input.field(db)? % 2 == 0; + if is_even { + EVEN_COUNT.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + } + + db.signal(1); + db.wait_for(2); + + match dummy(db, input) { + Ok(result) => Ok(result), + Err(error) => { + if is_even { + EVEN_COUNT.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + } + + Err(error) + } + } +} + +#[test] +fn rethrow() { + let mut db = Knobs::default(); + + let input = MyInput::new(&db, 2); + + let thread_a = std::thread::Builder::new() + .name("a".to_string()) + .spawn({ + let db = db.clone(); + move || a1_deferred(&db, input) + }) + .unwrap(); + + db.wait_for(1); + db.signal_on_did_cancel.store(2); + input.set_field(&mut db).to(2); + + // Assert thread A was cancelled. + let result = thread_a.join().unwrap(); + + assert!(result.is_err()); + assert_eq!( + result.unwrap_err().to_string(), + "cancelled because of pending write" + ); + assert_eq!(EVEN_COUNT.load(Ordering::Relaxed), 0); +} From 2f5ccf43ec678118b8b4c2897ddb6e9420a4b253 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Mon, 2 Sep 2024 18:27:45 +0200 Subject: [PATCH 08/10] Use DebugDropBomb --- src/result.rs | 12 ++++++------ tests/cycle_dropping_error_panics.rs | 1 + tests/parallel/parallel_cancellation.rs | 18 ++---------------- .../parallel_cancellation_capture_error.rs | 1 + 4 files changed, 10 insertions(+), 22 deletions(-) diff --git a/src/result.rs b/src/result.rs index 89210b7b2..74cfb43af 100644 --- a/src/result.rs +++ b/src/result.rs @@ -1,5 +1,5 @@ use crate::{with_attached_database, Cycle}; -use drop_bomb::DropBomb; +use drop_bomb::DebugDropBomb; use std::fmt; use std::fmt::Debug; @@ -22,7 +22,7 @@ impl Error { Error { kind: Box::new(ErrorKind::Cancelled(CancelledError { reason, - bomb: DropBomb::new("Cancellation errors must be propagated inside salsa queries. If you see this message outside a salsa query, please open an issue."), + bomb: DebugDropBomb::new("Cancellation errors must be propagated inside salsa queries. If you see this message outside a salsa query, please open an issue."), })), } } @@ -31,7 +31,7 @@ impl Error { Self { kind: Box::new(ErrorKind::Cycle(CycleError { cycle, - bomb: DropBomb::new("Cycle errors must be propagated so that Salsa can resolve the cycle. If you see this message outside a salsa query, please open an issue."), + bomb: DebugDropBomb::new("Cycle errors must be propagated so that Salsa can resolve the cycle. If you see this message outside a salsa query, please open an issue."), })), } } @@ -74,7 +74,7 @@ pub(crate) enum ErrorKind { #[derive(Debug)] pub(crate) struct CycleError { cycle: Cycle, - bomb: DropBomb, + bomb: DebugDropBomb, } impl CycleError { @@ -87,12 +87,12 @@ impl CycleError { #[derive(Debug)] pub(crate) struct CancelledError { reason: Cancelled, - bomb: DropBomb, + bomb: DebugDropBomb, } impl Drop for CancelledError { fn drop(&mut self) { - if with_attached_database(|_| {}).is_none() { + if !self.bomb.is_defused() && with_attached_database(|_| {}).is_none() { // We are outside a query. It's okay if the user drops the error now self.bomb.defuse(); } diff --git a/tests/cycle_dropping_error_panics.rs b/tests/cycle_dropping_error_panics.rs index 90c784f43..8338443fc 100644 --- a/tests/cycle_dropping_error_panics.rs +++ b/tests/cycle_dropping_error_panics.rs @@ -31,6 +31,7 @@ fn cycle_b(db: &dyn Db, input: MyInput) -> salsa::Result { } #[test] +#[cfg(debug_assertions)] #[should_panic(expected = "Cycle errors must be propagated so that Salsa can resolve the cycle.")] fn execute() { salsa::DatabaseImpl::new().attach(|db| { diff --git a/tests/parallel/parallel_cancellation.rs b/tests/parallel/parallel_cancellation.rs index d07fa63a6..3d015b27b 100644 --- a/tests/parallel/parallel_cancellation.rs +++ b/tests/parallel/parallel_cancellation.rs @@ -61,20 +61,6 @@ fn execute() { let cancelled = thread_a.join().unwrap().unwrap_err(); // and inspect the output - expect_test::expect![[r#" - Error { - kind: Cancelled( - CancelledError { - reason: PendingWrite, - bomb: DropBomb( - RealBomb { - msg: "Cancellation errors must be propagated inside salsa queries. If you see this message outside a salsa query, please open an issue.", - defused: false, - }, - ), - }, - ), - } - "#]] - .assert_debug_eq(&cancelled); + expect_test::expect![[r#"cancelled because of pending write"#]] + .assert_eq(&cancelled.to_string()); } diff --git a/tests/parallel/parallel_cancellation_capture_error.rs b/tests/parallel/parallel_cancellation_capture_error.rs index d19aeeb64..56a7ecc55 100644 --- a/tests/parallel/parallel_cancellation_capture_error.rs +++ b/tests/parallel/parallel_cancellation_capture_error.rs @@ -46,6 +46,7 @@ fn dummy(_db: &dyn KnobsDatabase, _input: MyInput) -> salsa::Result { // drops error -> panics #[test] +#[cfg(debug_assertions)] fn execute() { let mut db = Knobs::default(); From d8fe4a7f9a9bc7b1458287206b741aba9931a017 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Tue, 3 Sep 2024 09:50:57 +0200 Subject: [PATCH 09/10] Remove unused testing-only method --- src/lib.rs | 1 - src/result.rs | 7 ------- 2 files changed, 8 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 38085c575..1947ec123 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -81,7 +81,6 @@ pub mod plumbing { pub use crate::ingredient::Jar; pub use crate::ingredient::JarAux; pub use crate::key::DatabaseKeyIndex; - pub use crate::result::error_as_cycle; pub use crate::revision::Revision; pub use crate::runtime::stamp; pub use crate::runtime::Runtime; diff --git a/src/result.rs b/src/result.rs index 74cfb43af..21be17db2 100644 --- a/src/result.rs +++ b/src/result.rs @@ -5,13 +5,6 @@ use std::fmt::Debug; pub type Result = std::result::Result; -pub fn error_as_cycle(error: &Error) -> Option<&Cycle> { - match &*error.kind { - ErrorKind::Cycle(error) => Some(&error.cycle), - _ => None, - } -} - #[derive(Debug)] pub struct Error { kind: Box, From bb6bbdd0f745a1ccbbe5312a7ffcc3bf3c2c8cd4 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Tue, 3 Sep 2024 12:22:50 +0200 Subject: [PATCH 10/10] Remove fixmes --- src/function/fetch.rs | 1 - src/result.rs | 5 ----- tests/cycle_dropping_error_panics.rs | 20 +++++++++---------- .../parallel_cancellation_capture_error.rs | 13 +++++------- 4 files changed, 15 insertions(+), 24 deletions(-) diff --git a/src/function/fetch.rs b/src/function/fetch.rs index ee7c1f131..fb1e9b68e 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -77,7 +77,6 @@ where let database_key_index = self.database_key_index(id); // Try to claim this query: if someone else has claimed it already, go back and start again. - // FIXME: Handle error let _claim_guard = zalsa.sync_table_for(id).claim( db.as_dyn_database(), zalsa_local, diff --git a/src/result.rs b/src/result.rs index 21be17db2..b81f81bfe 100644 --- a/src/result.rs +++ b/src/result.rs @@ -92,18 +92,13 @@ impl Drop for CancelledError { } } -// FIXME implement drop for Cancelled. - /// A panic payload indicating that execution of a salsa query was cancelled. #[derive(Debug)] -#[non_exhaustive] pub(crate) enum Cancelled { /// The query was operating on revision R, but there is a pending write to move to revision R+1. - #[non_exhaustive] PendingWrite, /// The query was blocked on another thread, and that thread panicked. - #[non_exhaustive] PropagatedPanic, } diff --git a/tests/cycle_dropping_error_panics.rs b/tests/cycle_dropping_error_panics.rs index 8338443fc..82321d9ec 100644 --- a/tests/cycle_dropping_error_panics.rs +++ b/tests/cycle_dropping_error_panics.rs @@ -17,17 +17,17 @@ struct MyInput { } #[salsa::tracked(recovery_fn = recover_a)] -fn cycle_a(db: &dyn Db, input: MyInput) -> salsa::Result { +fn cycle_a(db: &dyn Db, input: MyInput) -> salsa::Result<&'static str> { cycle_b(db, input) } -fn recover_a(db: &dyn Db, cycle: &Cycle, input: MyInput) -> salsa::Result { - Ok("recovered".to_string()) +fn recover_a(db: &dyn Db, cycle: &Cycle, input: MyInput) -> salsa::Result<&'static str> { + Ok("recovered") } #[salsa::tracked] -fn cycle_b(db: &dyn Db, input: MyInput) -> salsa::Result { - Ok(cycle_a(db, input).unwrap_or_else(|error| format!("Suppressed error: {error}"))) +fn cycle_b(db: &dyn Db, input: MyInput) -> salsa::Result<&'static str> { + Ok(cycle_a(db, input).unwrap_or_else(|error| "Suppressed error")) } #[test] @@ -43,7 +43,7 @@ fn execute() { } #[salsa::tracked] -fn deferred_cycle_a(db: &dyn Db, input: MyInput) -> salsa::Result { +fn deferred_cycle_a(db: &dyn Db, input: MyInput) -> salsa::Result<&'static str> { deferred_cycle_b(db, input) } @@ -52,7 +52,7 @@ fn deferred_cycle_a(db: &dyn Db, input: MyInput) -> salsa::Result { static EVEN_COUNT: AtomicUsize = AtomicUsize::new(0); #[salsa::tracked] -fn deferred_cycle_b(db: &dyn Db, input: MyInput) -> salsa::Result { +fn deferred_cycle_b(db: &dyn Db, input: MyInput) -> salsa::Result<&'static str> { let is_even = input.field(db)? % 2 == 0; if is_even { EVEN_COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed); @@ -71,12 +71,12 @@ fn deferred_cycle_b(db: &dyn Db, input: MyInput) -> salsa::Result { } #[salsa::tracked(recovery_fn = recover_c)] -fn deferred_cycle_c(db: &dyn Db, input: MyInput) -> salsa::Result { +fn deferred_cycle_c(db: &dyn Db, input: MyInput) -> salsa::Result<&'static str> { deferred_cycle_a(db, input) } -fn recover_c(db: &dyn Db, cycle: &Cycle, input: MyInput) -> salsa::Result { - Ok("recovered C".to_string()) +fn recover_c(db: &dyn Db, cycle: &Cycle, input: MyInput) -> salsa::Result<&'static str> { + Ok("recovered C") } // A query captures the error but propagates it before completion. diff --git a/tests/parallel/parallel_cancellation_capture_error.rs b/tests/parallel/parallel_cancellation_capture_error.rs index 56a7ecc55..701f906fa 100644 --- a/tests/parallel/parallel_cancellation_capture_error.rs +++ b/tests/parallel/parallel_cancellation_capture_error.rs @@ -15,19 +15,16 @@ struct MyInput { } #[salsa::tracked] -fn a1(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { +fn a1(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result<&'static str> { db.signal(1); db.wait_for(2); - match dummy(db, input) { - Ok(result) => Ok(result), - Err(_) => Ok("Suppressed cancellation".to_string()), - } + Ok(dummy(db, input).unwrap_or("Suppressed cancellation")) } #[salsa::tracked] -fn dummy(_db: &dyn KnobsDatabase, _input: MyInput) -> salsa::Result { - Ok("should never get here!".to_string()) +fn dummy(_db: &dyn KnobsDatabase, _input: MyInput) -> salsa::Result<&'static str> { + Ok("should never get here!") } // Cancellation signalling test @@ -79,7 +76,7 @@ fn execute() { static EVEN_COUNT: AtomicUsize = AtomicUsize::new(0); #[salsa::tracked] -fn a1_deferred(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result { +fn a1_deferred(db: &dyn KnobsDatabase, input: MyInput) -> salsa::Result<&'static str> { let is_even = input.field(db)? % 2 == 0; if is_even { EVEN_COUNT.fetch_add(1, std::sync::atomic::Ordering::SeqCst);