diff --git a/Cargo.toml b/Cargo.toml index 6216ed2..749c838 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,8 @@ rand_core = { version = "0.6.4", default-features = false } openssl = { version = "0.10.39", optional = true, features = ["vendored"] } rug = { version = "1.26", default-features = false, features = ["integer"], optional = true } glass_pumpkin = { version = "1", optional = true } +rayon = { version = "1", optional = true } +num_cpus = { version = "1", optional = true } [dev-dependencies] # need `crypto-bigint` with `alloc` to test `BoxedUint` @@ -35,6 +37,7 @@ tests-gmp = ["rug/std"] tests-glass-pumpkin = ["glass_pumpkin"] tests-exhaustive = [] tests-all = ["tests-openssl", "tests-gmp", "tests-exhaustive", "tests-glass-pumpkin"] +rayon = ["dep:rayon", "num_cpus"] [package.metadata.docs.rs] features = ["default"] diff --git a/src/presets.rs b/src/presets.rs index 7ead461..b978239 100644 --- a/src/presets.rs +++ b/src/presets.rs @@ -6,6 +6,9 @@ use rand_core::CryptoRngCore; #[cfg(feature = "default-rng")] use rand_core::OsRng; +#[cfg(feature = "rayon")] +use rayon::iter::{ParallelBridge, ParallelIterator}; + use crate::hazmat::{lucas_test, random_odd_integer, AStarBase, LucasCheck, MillerRabin, Primality, Sieve}; /// Returns a random prime of size `bit_length` using [`OsRng`] as the RNG. @@ -90,6 +93,45 @@ pub fn generate_safe_prime_with_rng( } } +#[cfg(feature = "rayon")] +/// Use [`rayon`] to parallelize the prime search. +/// +/// Returns a random prime of size `bit_length` using the provided RNG. +/// +/// Panics if `bit_length` is less than 2, or greater than the bit size of the target `Uint`. +/// +/// Panics if the platform is unable to spawn threads. +pub fn par_generate_prime_with_rng(rng: &mut (impl CryptoRngCore + Send + Sync + Clone), bit_length: u32) -> T +where + T: Integer + RandomBits + RandomMod, +{ + if bit_length < 2 { + panic!("`bit_length` must be 2 or greater."); + } + let bit_length = NonZeroU32::new(bit_length).expect("`bit_length` should be non-zero"); + + // TODO(dp): decide how to set the threadcount. + let threadcount = core::cmp::max(2, num_cpus::get() / 2); + let threadpool = rayon::ThreadPoolBuilder::new() + .num_threads(threadcount) + .build() + .expect("If the platform can spawn threads, then this call will work."); + let start = random_odd_integer::(rng, bit_length).get(); + let sieve = Sieve::new(start, bit_length, false); + + threadpool.install(|| { + match sieve.par_bridge().find_any(|c| { + //TODO(dp): This clone feels dumb. OTOH it is fairly cheap to clone a rng. + let mut rng = rng.clone(); + is_prime_with_rng(&mut rng, c) + }) { + // TODO(dp): This clone is very annoying. + Some(p) => return p.clone(), + None => par_generate_prime_with_rng(rng, bit_length.get()), + } + }) +} + /// Probabilistically checks if the given number is prime using the provided RNG. /// /// Performed checks: @@ -367,6 +409,31 @@ mod tests { } } +#[cfg(all(test, feature = "rayon"))] +mod tests_rayon { + use super::{is_prime, par_generate_prime_with_rng}; + use crypto_bigint::{nlimbs, BoxedUint, U128}; + + use super::*; + #[test] + fn parallel_prime_generation() { + for bit_length in (28..=128).step_by(10) { + let p: U128 = par_generate_prime_with_rng(&mut OsRng, bit_length); + assert!(p.bits_vartime() == bit_length); + assert!(is_prime(&p)); + } + } + + #[test] + fn parallel_prime_generation_boxed() { + for bit_length in (28..=128).step_by(10) { + let p: BoxedUint = par_generate_prime_with_rng(&mut OsRng, bit_length); + assert!(p.bits_vartime() == bit_length); + assert!(p.to_words().len() == nlimbs!(bit_length)); + assert!(is_prime(&p)); + } + } +} #[cfg(test)] #[cfg(feature = "tests-openssl")] mod tests_openssl {