diff --git a/mvpoly/src/monomials.rs b/mvpoly/src/monomials.rs index 3a11fe4ef5..275d322a5a 100644 --- a/mvpoly/src/monomials.rs +++ b/mvpoly/src/monomials.rs @@ -293,7 +293,8 @@ impl MVPoly for Sparse(rng: &mut RNG, max_degree: Option) -> Self { let degree = max_degree.unwrap_or(D); // Generating all monomials with degree <= degree^N - let nested_loops_indices: Vec> = compute_indices_nested_loop(vec![degree; N]); + let nested_loops_indices: Vec> = + compute_indices_nested_loop(vec![degree; N], max_degree); // Filtering the monomials with degree <= degree let exponents: Vec> = nested_loops_indices .into_iter() @@ -409,8 +410,10 @@ impl MVPoly for Sparse) -> Vec> { +pub fn compute_indices_nested_loop( + nested_loop_sizes: Vec, + upper_bound: Option, +) -> Vec> { let n = nested_loop_sizes.iter().product(); (0..n) - .map(|i| { + .filter_map(|i| { let mut div = 1; // Compute indices for the loop, step i let indices: Vec = nested_loop_sizes @@ -268,7 +272,15 @@ pub fn compute_indices_nested_loop(nested_loop_sizes: Vec) -> Vec() <= upper_bound { + Some(indices) + } else { + None + } + } else { + Some(indices) + } }) .collect() } diff --git a/mvpoly/tests/utils.rs b/mvpoly/tests/utils.rs index dad2abe4d3..032f36ec77 100644 --- a/mvpoly/tests/utils.rs +++ b/mvpoly/tests/utils.rs @@ -232,7 +232,7 @@ pub fn test_compute_indices_nested_loop() { // sorting to get the same order let mut exp_indices = vec![vec![0, 0], vec![0, 1], vec![1, 0], vec![1, 1]]; exp_indices.sort(); - let mut comp_indices = compute_indices_nested_loop(nested_loops); + let mut comp_indices = compute_indices_nested_loop(nested_loops, None); comp_indices.sort(); assert_eq!(exp_indices, comp_indices); @@ -247,7 +247,7 @@ pub fn test_compute_indices_nested_loop() { vec![2, 1], ]; exp_indices.sort(); - let mut comp_indices = compute_indices_nested_loop(nested_loops); + let mut comp_indices = compute_indices_nested_loop(nested_loops, None); comp_indices.sort(); assert_eq!(exp_indices, comp_indices); @@ -292,38 +292,51 @@ pub fn test_compute_indices_nested_loop() { vec![2, 2, 1, 1], ]; exp_indices.sort(); - let mut comp_indices = compute_indices_nested_loop(nested_loops); + let mut comp_indices = compute_indices_nested_loop(nested_loops, None); comp_indices.sort(); assert_eq!(exp_indices, comp_indices); // Simple and single loop let nested_loops = vec![3]; let exp_indices = vec![vec![0], vec![1], vec![2]]; - let mut comp_indices = compute_indices_nested_loop(nested_loops); + let mut comp_indices = compute_indices_nested_loop(nested_loops, None); comp_indices.sort(); assert_eq!(exp_indices, comp_indices); // relatively large loops let nested_loops = vec![10, 10]; - let comp_indices = compute_indices_nested_loop(nested_loops); + let comp_indices = compute_indices_nested_loop(nested_loops, None); // Only checking the length as it would take too long to unroll the result assert_eq!(comp_indices.len(), 100); // Non-uniform loop sizes, relatively large let nested_loops = vec![5, 7, 3]; - let comp_indices = compute_indices_nested_loop(nested_loops); + let comp_indices = compute_indices_nested_loop(nested_loops, None); assert_eq!(comp_indices.len(), 5 * 7 * 3); } #[test] fn test_compute_indices_nested_loop_edge_cases() { let nested_loops = vec![]; - let comp_indices: Vec> = compute_indices_nested_loop(nested_loops); + let comp_indices: Vec> = compute_indices_nested_loop(nested_loops, None); let exp_output: Vec> = vec![vec![]]; assert_eq!(comp_indices, exp_output); // With one empty loop. Should match the documentation let nested_loops = vec![3, 0, 2]; - let comp_indices = compute_indices_nested_loop(nested_loops); + let comp_indices = compute_indices_nested_loop(nested_loops, None); assert_eq!(comp_indices.len(), 0); } + +#[test] +fn test_compute_indices_nested_loops_upper_bound() { + let nested_loops = vec![3, 3]; + let comp_indices = compute_indices_nested_loop(nested_loops.clone(), Some(0)); + assert_eq!(comp_indices.len(), 1); + + let comp_indices = compute_indices_nested_loop(nested_loops.clone(), Some(1)); + assert_eq!(comp_indices.len(), 3); + + let comp_indices = compute_indices_nested_loop(nested_loops, Some(2)); + assert_eq!(comp_indices.len(), 6); +}