Skip to content

Commit

Permalink
Don't pessimize scalar accumulation of trailing elements in fold
Browse files Browse the repository at this point in the history
  • Loading branch information
HadrienG2 committed Oct 23, 2023
1 parent cf33af1 commit 5a70f3a
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ pub trait IteratorILP: Iterator + Sized + TrustedLowerBound {

// Set up accumulators
let mut accumulators: [Option<Acc>; STREAMS] = core::array::from_fn(|_| Some(neutral()));
let mut accumulate = |accumulator: &mut Option<Acc>, item| {
let mut accumulate_opt = |accumulator: &mut Option<Acc>, item| {
if let Some(prev_acc) = accumulator.take() {
*accumulator = Some(accumulate(prev_acc, item));
}
Expand All @@ -441,15 +441,10 @@ pub trait IteratorILP: Iterator + Sized + TrustedLowerBound {
// the lower bound returned by size_hint is correct, and
// the above loop will not iterate more than that
// number of times, so this is trusted to be safe.
accumulate(acc, unsafe { self.next().unwrap_unchecked() });
accumulate_opt(acc, unsafe { self.next().unwrap_unchecked() });
}
}

// Accumulate irregular elements at the end
for (idx, item) in self.enumerate() {
accumulate(&mut accumulators[idx % STREAMS], item);
}

// Merge the accumulators
let mut stride = STREAMS;
while stride > 1 {
Expand All @@ -461,7 +456,11 @@ pub trait IteratorILP: Iterator + Sized + TrustedLowerBound {
));
}
}
accumulators[0].take().unwrap()
let ilp_result = accumulators[0].take().unwrap();

// Accumulate remaining irregular elements using standard iterator fold,
// then merge (doing it like this improves floating-point accuracy)
merge(ilp_result, self.fold(neutral(), accumulate))
}

/// Like [`Iterator::reduce()`], but with multiple ILP streams
Expand Down

0 comments on commit 5a70f3a

Please sign in to comment.