Skip to content

Commit

Permalink
mvpoly: introduce optional param upper bound to compute nested loops
Browse files Browse the repository at this point in the history
using filter_map speeds up the computation by avoiding creating a bigger map
than expected if a low upper bound is given with large number of indices.
  • Loading branch information
dannywillems committed Dec 11, 2024
1 parent 2d497c5 commit 597a664
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 15 deletions.
9 changes: 6 additions & 3 deletions mvpoly/src/monomials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,8 @@ impl<const N: usize, const D: usize, F: PrimeField> MVPoly<F, N, D> for Sparse<F
unsafe fn random<RNG: RngCore>(rng: &mut RNG, max_degree: Option<usize>) -> Self {
let degree = max_degree.unwrap_or(D);
// Generating all monomials with degree <= degree^N
let nested_loops_indices: Vec<Vec<usize>> = compute_indices_nested_loop(vec![degree; N]);
let nested_loops_indices: Vec<Vec<usize>> =
compute_indices_nested_loop(vec![degree; N], max_degree);
// Filtering the monomials with degree <= degree
let exponents: Vec<Vec<usize>> = nested_loops_indices
.into_iter()
Expand Down Expand Up @@ -409,8 +410,10 @@ impl<const N: usize, const D: usize, F: PrimeField> MVPoly<F, N, D> for Sparse<F
// Will be used to compute the nested sums
// It returns all the indices i_1, ..., i_k for the sums:
// Σ_{i_1 = 0}^{n_1} Σ_{i_2 = 0}^{n_2} ... Σ_{i_k = 0}^{n_k}
let indices =
compute_indices_nested_loop(non_zero_exponents.iter().map(|d| *d + 1).collect());
let indices = compute_indices_nested_loop(
non_zero_exponents.iter().map(|d| *d + 1).collect(),
None,
);
for i in 0..=u_degree {
// Add the binomial from the homogeneisation
// i.e (u_degree choose i)
Expand Down
20 changes: 16 additions & 4 deletions mvpoly/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,8 @@ pub fn compute_all_two_factors_decomposition(
}
}

/// Compute the list of indices to perform N nested loops of different size each.
/// Compute the list of indices to perform N nested loops of different size
/// each, whose sum is less than or equal to an optional upper bound.
/// In other words, if we have to perform the 3 nested loops:
/// ```rust
/// let n1 = 3;
Expand Down Expand Up @@ -254,10 +255,13 @@ pub fn compute_all_two_factors_decomposition(
///
/// In the case of an empty loop (i.e. one value in the input list is 0), the
/// expected output is the empty list.
pub fn compute_indices_nested_loop(nested_loop_sizes: Vec<usize>) -> Vec<Vec<usize>> {
pub fn compute_indices_nested_loop(
nested_loop_sizes: Vec<usize>,
upper_bound: Option<usize>,
) -> Vec<Vec<usize>> {
let n = nested_loop_sizes.iter().product();
(0..n)
.map(|i| {
.filter_map(|i| {
let mut div = 1;
// Compute indices for the loop, step i
let indices: Vec<usize> = nested_loop_sizes
Expand All @@ -268,7 +272,15 @@ pub fn compute_indices_nested_loop(nested_loop_sizes: Vec<usize>) -> Vec<Vec<usi
k
})
.collect();
indices
if let Some(upper_bound) = upper_bound {
if indices.iter().sum::<usize>() <= upper_bound {
Some(indices)
} else {
None
}
} else {
Some(indices)
}
})
.collect()
}
29 changes: 21 additions & 8 deletions mvpoly/tests/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ pub fn test_compute_indices_nested_loop() {
// sorting to get the same order
let mut exp_indices = vec![vec![0, 0], vec![0, 1], vec![1, 0], vec![1, 1]];
exp_indices.sort();
let mut comp_indices = compute_indices_nested_loop(nested_loops);
let mut comp_indices = compute_indices_nested_loop(nested_loops, None);
comp_indices.sort();
assert_eq!(exp_indices, comp_indices);

Expand All @@ -247,7 +247,7 @@ pub fn test_compute_indices_nested_loop() {
vec![2, 1],
];
exp_indices.sort();
let mut comp_indices = compute_indices_nested_loop(nested_loops);
let mut comp_indices = compute_indices_nested_loop(nested_loops, None);
comp_indices.sort();
assert_eq!(exp_indices, comp_indices);

Expand Down Expand Up @@ -292,38 +292,51 @@ pub fn test_compute_indices_nested_loop() {
vec![2, 2, 1, 1],
];
exp_indices.sort();
let mut comp_indices = compute_indices_nested_loop(nested_loops);
let mut comp_indices = compute_indices_nested_loop(nested_loops, None);
comp_indices.sort();
assert_eq!(exp_indices, comp_indices);

// Simple and single loop
let nested_loops = vec![3];
let exp_indices = vec![vec![0], vec![1], vec![2]];
let mut comp_indices = compute_indices_nested_loop(nested_loops);
let mut comp_indices = compute_indices_nested_loop(nested_loops, None);
comp_indices.sort();
assert_eq!(exp_indices, comp_indices);

// relatively large loops
let nested_loops = vec![10, 10];
let comp_indices = compute_indices_nested_loop(nested_loops);
let comp_indices = compute_indices_nested_loop(nested_loops, None);
// Only checking the length as it would take too long to unroll the result
assert_eq!(comp_indices.len(), 100);

// Non-uniform loop sizes, relatively large
let nested_loops = vec![5, 7, 3];
let comp_indices = compute_indices_nested_loop(nested_loops);
let comp_indices = compute_indices_nested_loop(nested_loops, None);
assert_eq!(comp_indices.len(), 5 * 7 * 3);
}

#[test]
fn test_compute_indices_nested_loop_edge_cases() {
let nested_loops = vec![];
let comp_indices: Vec<Vec<usize>> = compute_indices_nested_loop(nested_loops);
let comp_indices: Vec<Vec<usize>> = compute_indices_nested_loop(nested_loops, None);
let exp_output: Vec<Vec<usize>> = vec![vec![]];
assert_eq!(comp_indices, exp_output);

// With one empty loop. Should match the documentation
let nested_loops = vec![3, 0, 2];
let comp_indices = compute_indices_nested_loop(nested_loops);
let comp_indices = compute_indices_nested_loop(nested_loops, None);
assert_eq!(comp_indices.len(), 0);
}

#[test]
fn test_compute_indices_nested_loops_upper_bound() {
let nested_loops = vec![3, 3];
let comp_indices = compute_indices_nested_loop(nested_loops.clone(), Some(0));
assert_eq!(comp_indices.len(), 1);

let comp_indices = compute_indices_nested_loop(nested_loops.clone(), Some(1));
assert_eq!(comp_indices.len(), 3);

let comp_indices = compute_indices_nested_loop(nested_loops, Some(2));
assert_eq!(comp_indices.len(), 6);
}

0 comments on commit 597a664

Please sign in to comment.