Skip to content

Commit

Permalink
Merge pull request #524 from PhoebeSzmucer/ps/accumulate-chain
Browse files Browse the repository at this point in the history
Fix accumulator only accumulating direct children
  • Loading branch information
nikomatsakis authored Jul 22, 2024
2 parents 431fd14 + a85ac26 commit c8234e4
Show file tree
Hide file tree
Showing 5 changed files with 247 additions and 12 deletions.
26 changes: 18 additions & 8 deletions src/function/accumulated.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{accumulator, storage::DatabaseGen, Id};
use crate::{accumulator, hash::FxHashSet, storage::DatabaseGen, DatabaseKeyIndex, Id};

use super::{Configuration, IngredientImpl};

Expand All @@ -21,14 +21,24 @@ where
// First ensure the result is up to date
self.fetch(db, key);

let database_key_index = self.database_key_index(key);
accumulator.produced_by(runtime, database_key_index, &mut output);
let db_key = self.database_key_index(key);
let mut visited: FxHashSet<DatabaseKeyIndex> = FxHashSet::default();
let mut stack: Vec<DatabaseKeyIndex> = vec![db_key];

if let Some(origin) = self.origin(key) {
for input in origin.inputs() {
if let Ok(input) = input.try_into() {
accumulator.produced_by(runtime, input, &mut output);
}
while let Some(k) = stack.pop() {
if visited.insert(k) {
accumulator.produced_by(runtime, k, &mut output);

let origin = db.lookup_ingredient(k.ingredient_index).origin(k.key_index);
let inputs = origin.iter().flat_map(|origin| origin.inputs());
// Careful: we want to push in execution order, so reverse order to
// ensure the first child that was executed will be the first child popped
// from the stack.
stack.extend(
inputs
.flat_map(|input| TryInto::<DatabaseKeyIndex>::try_into(input).into_iter())
.rev(),
);
}
}

Expand Down
8 changes: 4 additions & 4 deletions src/runtime/local_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ pub enum QueryOrigin {

impl QueryOrigin {
/// Indices for queries *read* by this query
pub(crate) fn inputs(&self) -> impl Iterator<Item = DependencyIndex> + '_ {
pub(crate) fn inputs(&self) -> impl DoubleEndedIterator<Item = DependencyIndex> + '_ {
let opt_edges = match self {
QueryOrigin::Derived(edges) | QueryOrigin::DerivedUntracked(edges) => Some(edges),
QueryOrigin::Assigned(_) | QueryOrigin::BaseInput => None,
Expand All @@ -86,7 +86,7 @@ impl QueryOrigin {
}

/// Indices for queries *written* by this query (if any)
pub(crate) fn outputs(&self) -> impl Iterator<Item = DependencyIndex> + '_ {
pub(crate) fn outputs(&self) -> impl DoubleEndedIterator<Item = DependencyIndex> + '_ {
let opt_edges = match self {
QueryOrigin::Derived(edges) | QueryOrigin::DerivedUntracked(edges) => Some(edges),
QueryOrigin::Assigned(_) | QueryOrigin::BaseInput => None,
Expand Down Expand Up @@ -127,7 +127,7 @@ impl QueryEdges {
/// Returns the (tracked) inputs that were executed in computing this memoized value.
///
/// These will always be in execution order.
pub(crate) fn inputs(&self) -> impl Iterator<Item = DependencyIndex> + '_ {
pub(crate) fn inputs(&self) -> impl DoubleEndedIterator<Item = DependencyIndex> + '_ {
self.input_outputs
.iter()
.filter(|(edge_kind, _)| *edge_kind == EdgeKind::Input)
Expand All @@ -137,7 +137,7 @@ impl QueryEdges {
/// Returns the (tracked) outputs that were executed in computing this memoized value.
///
/// These will always be in execution order.
pub(crate) fn outputs(&self) -> impl Iterator<Item = DependencyIndex> + '_ {
pub(crate) fn outputs(&self) -> impl DoubleEndedIterator<Item = DependencyIndex> + '_ {
self.input_outputs
.iter()
.filter(|(edge_kind, _)| *edge_kind == EdgeKind::Output)
Expand Down
57 changes: 57 additions & 0 deletions tests/accumulate-chain.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
//! Test that, with nested tracked functions,
//! no accumulated values are dropped.
mod common;

use expect_test::expect;
use salsa::{Accumulator, Database};
use test_log::test;

/// Accumulator collecting string log entries emitted by the tracked functions below.
#[salsa::accumulator]
struct Log(#[allow(dead_code)] String);

/// Entry point queried by the test; accumulates nothing itself and only
/// forwards into the chain via `push_a_logs`.
#[salsa::tracked]
fn push_logs(db: &dyn Database) {
    push_a_logs(db);
}

/// Accumulates "log a", then continues the chain through `push_b_logs`.
#[salsa::tracked]
fn push_a_logs(db: &dyn Database) {
    let entry = Log(String::from("log a"));
    entry.accumulate(db);
    push_b_logs(db);
}

#[salsa::tracked]
fn push_b_logs(db: &dyn Database) {
    // Accumulates nothing itself; only forwards to the next link of the chain.
    push_c_logs(db);
}

#[salsa::tracked]
fn push_c_logs(db: &dyn Database) {
    // Accumulates nothing itself; only forwards to the next link of the chain.
    push_d_logs(db);
}

/// Leaf of the chain: accumulates "log d".
#[salsa::tracked]
fn push_d_logs(db: &dyn Database) {
    Log("log d".to_owned()).accumulate(db);
}

#[test]
fn accumulate_chain() {
    salsa::default_database().attach(|db| {
        let logs = push_logs::accumulated::<Log>(db);
        // Logs from the deepest links of the chain must survive, not just the
        // direct children of `push_logs`.
        let actual = format!("{:#?}", logs);
        expect![[r#"
            [
                Log(
                    "log a",
                ),
                Log(
                    "log d",
                ),
            ]"#]]
        .assert_eq(&actual);
    })
}
64 changes: 64 additions & 0 deletions tests/accumulate-execution-order.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
//! Demonstrates that accumulation is done in the order
//! in which things were originally executed.
mod common;

use expect_test::expect;
use salsa::{Accumulator, Database};
use test_log::test;

/// Accumulator collecting string log entries emitted by the tracked functions below.
#[salsa::accumulator]
struct Log(#[allow(dead_code)] String);

/// Entry point queried by the test; accumulates nothing itself and only
/// forwards into `push_a_logs`.
#[salsa::tracked]
fn push_logs(db: &dyn Database) {
    push_a_logs(db);
}

/// Accumulates "log a", then runs b, c, d.
#[salsa::tracked]
fn push_a_logs(db: &dyn Database) {
    let entry = Log(String::from("log a"));
    entry.accumulate(db);
    // Call order is load-bearing: the test asserts logs come back in the
    // order these children executed.
    push_b_logs(db);
    push_c_logs(db);
    push_d_logs(db);
}

/// Accumulates "log b" and then executes `push_d_logs` as a nested child.
#[salsa::tracked]
fn push_b_logs(db: &dyn Database) {
    let entry = Log(String::from("log b"));
    entry.accumulate(db);
    push_d_logs(db);
}

/// Accumulates "log c" only.
#[salsa::tracked]
fn push_c_logs(db: &dyn Database) {
    Log("log c".to_owned()).accumulate(db);
}

/// Accumulates "log d" only.
#[salsa::tracked]
fn push_d_logs(db: &dyn Database) {
    Log("log d".to_owned()).accumulate(db);
}

#[test]
fn accumulate_execution_order() {
    salsa::default_database().attach(|db| {
        let logs = push_logs::accumulated::<Log>(db);
        // "log d" follows "log b" (b executed d first); c's log comes last
        // even though c was invoked before d's memoized result was reused.
        let actual = format!("{:#?}", logs);
        expect![[r#"
            [
                Log(
                    "log a",
                ),
                Log(
                    "log b",
                ),
                Log(
                    "log d",
                ),
                Log(
                    "log c",
                ),
            ]"#]]
        .assert_eq(&actual);
    })
}
104 changes: 104 additions & 0 deletions tests/accumulate-no-duplicates.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
//! Test that we don't get duplicate accumulated values
mod common;

use expect_test::expect;
use salsa::{Accumulator, Database};
use test_log::test;

// A(1) {
// B
// B
// C {
// D {
// A(2) {
// B
// }
// B
// }
// E
// }
// B
// }

/// Accumulator collecting string log entries emitted by the functions below.
#[salsa::accumulator]
struct Log(#[allow(dead_code)] String);

/// Input whose `n` field selects which branch `push_a_logs` takes, so the
/// same tracked function can appear twice in the graph with distinct keys.
#[salsa::input]
struct MyInput {
    n: u32,
}

/// Entry point queried by the test; kicks off the graph with `n == 1`.
#[salsa::tracked]
fn push_logs(db: &dyn Database) {
    let input = MyInput::new(db, 1);
    push_a_logs(db, input);
}

/// Accumulates "log a", then fans out differently depending on `input.n`.
#[salsa::tracked]
fn push_a_logs(db: &dyn Database, input: MyInput) {
    let entry = Log(String::from("log a"));
    entry.accumulate(db);
    match input.n(db) {
        1 => {
            // `push_b_logs` is deliberately invoked several times: its logs
            // must appear only once in the accumulated output.
            push_b_logs(db);
            push_b_logs(db);
            push_c_logs(db);
            push_b_logs(db);
        }
        _ => push_b_logs(db),
    }
}

/// Accumulates "log b" only.
#[salsa::tracked]
fn push_b_logs(db: &dyn Database) {
    Log("log b".to_owned()).accumulate(db);
}

/// Accumulates "log c", then runs the untracked `push_d_logs` followed by
/// the tracked `push_e_logs`.
#[salsa::tracked]
fn push_c_logs(db: &dyn Database) {
    let entry = Log(String::from("log c"));
    entry.accumulate(db);
    push_d_logs(db);
    push_e_logs(db);
}

// Deliberately NOT `#[salsa::tracked]`: everything this function accumulates
// is attributed to its tracked caller (`push_c_logs`).
fn push_d_logs(db: &dyn Database) {
    let entry = Log(String::from("log d"));
    entry.accumulate(db);
    // A second `push_a_logs` key (`n == 2`), so "log a" legitimately
    // appears twice in the output.
    push_a_logs(db, MyInput::new(db, 2));
    push_b_logs(db);
}

/// Accumulates "log e" only.
#[salsa::tracked]
fn push_e_logs(db: &dyn Database) {
    Log("log e".to_owned()).accumulate(db);
}

#[test]
fn accumulate_no_duplicates() {
    salsa::default_database().attach(|db| {
        let logs = push_logs::accumulated::<Log>(db);
        // "log b" shows up once despite several calls; "log a" appears twice
        // because the two `push_a_logs` executions have different inputs.
        let actual = format!("{:#?}", logs);
        expect![[r#"
            [
                Log(
                    "log a",
                ),
                Log(
                    "log b",
                ),
                Log(
                    "log c",
                ),
                Log(
                    "log d",
                ),
                Log(
                    "log a",
                ),
                Log(
                    "log e",
                ),
            ]"#]]
        .assert_eq(&actual);
    })
}

0 comments on commit c8234e4

Please sign in to comment.