Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bring the overflow behavior in bit shifts in sync with std #395

Merged
merged 15 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ name = "boxed_residue"
harness = false
required-features = ["alloc"]

[[bench]]
name = "boxed_uint"
harness = false
required-features = ["alloc"]

[[bench]]
name = "dyn_residue"
harness = false
Expand Down
48 changes: 48 additions & 0 deletions benches/boxed_uint.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
use criterion::{black_box, criterion_group, criterion_main, BatchSize, Criterion};
use crypto_bigint::BoxedUint;
use rand_core::OsRng;

/// Size of `BoxedUint` to use in benchmark.
const UINT_BITS: u32 = 4096;

/// Benchmarks the constant-time and variable-time left/right shifts of `BoxedUint`.
///
/// Every measurement shifts a freshly generated random `UINT_BITS`-bit value by
/// `UINT_BITS / 2 + 10` bits.
fn bench_shifts(c: &mut Criterion) {
    /// Shift amount shared by all four measurements.
    const SHIFT: u32 = UINT_BITS / 2 + 10;

    // Per-iteration setup: a new random operand for each measured call.
    let random_input = || BoxedUint::random(&mut OsRng, UINT_BITS);

    let mut shifts = c.benchmark_group("bit shifts");

    shifts.bench_function("shl_vartime", |bencher| {
        bencher.iter_batched(
            random_input,
            |value| black_box(value.shl_vartime(SHIFT)),
            BatchSize::SmallInput,
        )
    });

    shifts.bench_function("shl", |bencher| {
        bencher.iter_batched(
            random_input,
            |value| value.shl(SHIFT),
            BatchSize::SmallInput,
        )
    });

    shifts.bench_function("shr_vartime", |bencher| {
        bencher.iter_batched(
            random_input,
            |value| black_box(value.shr_vartime(SHIFT)),
            BatchSize::SmallInput,
        )
    });

    shifts.bench_function("shr", |bencher| {
        bencher.iter_batched(
            random_input,
            |value| value.shr(SHIFT),
            BatchSize::SmallInput,
        )
    });

    shifts.finish();
}

// Collect the `BoxedUint` shift benchmarks into the `benches` group.
criterion_group!(benches, bench_shifts);

// Hand the `benches` group to criterion's entry-point macro.
criterion_main!(benches);
70 changes: 50 additions & 20 deletions benches/uint.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use criterion::{
black_box, criterion_group, criterion_main, measurement::Measurement, BatchSize,
BenchmarkGroup, Criterion,
};
use crypto_bigint::{Limb, NonZero, Random, Reciprocal, U128, U2048, U256};
use criterion::{black_box, criterion_group, criterion_main, BatchSize, Criterion};
use crypto_bigint::{Limb, NonZero, Random, Reciprocal, Uint, U128, U2048, U256};
use rand_core::OsRng;

fn bench_division<M: Measurement>(group: &mut BenchmarkGroup<'_, M>) {
fn bench_division(c: &mut Criterion) {
let mut group = c.benchmark_group("wrapping ops");

group.bench_function("div/rem, U256/U128, full size", |b| {
b.iter_batched(
|| {
Expand Down Expand Up @@ -69,9 +68,13 @@ fn bench_division<M: Measurement>(group: &mut BenchmarkGroup<'_, M>) {
BatchSize::SmallInput,
)
});

group.finish();
}

fn bench_shifts<M: Measurement>(group: &mut BenchmarkGroup<'_, M>) {
fn bench_shl(c: &mut Criterion) {
let mut group = c.benchmark_group("left shift");

group.bench_function("shl_vartime, small, U2048", |b| {
b.iter_batched(|| U2048::ONE, |x| x.shl_vartime(10), BatchSize::SmallInput)
});
Expand All @@ -84,16 +87,54 @@ fn bench_shifts<M: Measurement>(group: &mut BenchmarkGroup<'_, M>) {
)
});

group.bench_function("shl_vartime_wide, large, U2048", |b| {
b.iter_batched(
|| (U2048::ONE, U2048::ONE),
|x| Uint::shl_vartime_wide(x, 1024 + 10),
BatchSize::SmallInput,
)
});

group.bench_function("shl, U2048", |b| {
b.iter_batched(|| U2048::ONE, |x| x.shl(1024 + 10), BatchSize::SmallInput)
});

group.finish();
}

/// Benchmarks the right-shift operations of `U2048` in the "right shift" group.
///
/// Covers `shr_vartime` with a small (10) and a large (1024 + 10) shift,
/// `shr_vartime_wide` on a pair of values, and `shr` with the large shift.
fn bench_shr(c: &mut Criterion) {
    let mut shr_group = c.benchmark_group("right shift");

    shr_group.bench_function("shr_vartime, small, U2048", |bencher| {
        bencher.iter_batched(
            || U2048::ONE,
            |value| value.shr_vartime(10),
            BatchSize::SmallInput,
        )
    });

    shr_group.bench_function("shr_vartime, large, U2048", |bencher| {
        bencher.iter_batched(
            || U2048::ONE,
            |value| value.shr_vartime(1024 + 10),
            BatchSize::SmallInput,
        )
    });

    shr_group.bench_function("shr_vartime_wide, large, U2048", |bencher| {
        bencher.iter_batched(
            || (U2048::ONE, U2048::ONE),
            |pair| Uint::shr_vartime_wide(pair, 1024 + 10),
            BatchSize::SmallInput,
        )
    });

    shr_group.bench_function("shr, U2048", |bencher| {
        bencher.iter_batched(
            || U2048::ONE,
            |value| value.shr(1024 + 10),
            BatchSize::SmallInput,
        )
    });

    shr_group.finish();
}

fn bench_inv_mod<M: Measurement>(group: &mut BenchmarkGroup<'_, M>) {
fn bench_inv_mod(c: &mut Criterion) {
let mut group = c.benchmark_group("modular ops");

group.bench_function("inv_odd_mod, U256", |b| {
b.iter_batched(
|| {
Expand Down Expand Up @@ -144,21 +185,10 @@ fn bench_inv_mod<M: Measurement>(group: &mut BenchmarkGroup<'_, M>) {
BatchSize::SmallInput,
)
});
}

fn bench_wrapping_ops(c: &mut Criterion) {
let mut group = c.benchmark_group("wrapping ops");
bench_division(&mut group);
group.finish();
}

fn bench_modular_ops(c: &mut Criterion) {
let mut group = c.benchmark_group("modular ops");
bench_shifts(&mut group);
bench_inv_mod(&mut group);
group.finish();
}

criterion_group!(benches, bench_wrapping_ops, bench_modular_ops);
criterion_group!(benches, bench_shl, bench_shr, bench_division, bench_inv_mod);

criterion_main!(benches);
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
//! U256::from_be_hex("ffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551");
//!
//! // Compute `MODULUS` shifted right by 1 at compile time
//! pub const MODULUS_SHR1: U256 = MODULUS.shr(1);
//! pub const MODULUS_SHR1: U256 = MODULUS.shr(1).0;
//! ```
//!
//! Note that large constant computations may accidentally trigger the `const_eval_limit` of the compiler.
Expand Down
1 change: 1 addition & 0 deletions src/limb/bit_not.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use core::ops::Not;

impl Limb {
/// Calculates `!a`.
#[inline(always)]
pub const fn not(self) -> Self {
Limb(!self.0)
}
Expand Down
1 change: 1 addition & 0 deletions src/limb/bit_or.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use core::ops::{BitOr, BitOrAssign};

impl Limb {
/// Calculates `a | b`.
#[inline(always)]
pub const fn bitor(self, rhs: Self) -> Self {
Limb(self.0 | rhs.0)
}
Expand Down
1 change: 1 addition & 0 deletions src/limb/bit_xor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use core::ops::BitXor;

impl Limb {
/// Calculates `a ^ b`.
#[inline(always)]
pub const fn bitxor(self, rhs: Self) -> Self {
Limb(self.0 ^ rhs.0)
}
Expand Down
4 changes: 4 additions & 0 deletions src/limb/bits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,25 @@ use super::Limb;

impl Limb {
/// Calculate the number of bits needed to represent this number.
#[inline(always)]
pub const fn bits(self) -> u32 {
Limb::BITS - self.0.leading_zeros()
}

/// Returns the number of leading zero bits in the binary representation of this limb.
#[inline(always)]
pub const fn leading_zeros(self) -> u32 {
    let word = self.0;
    word.leading_zeros()
}

/// Returns the number of trailing zero bits in the binary representation of this limb.
#[inline(always)]
pub const fn trailing_zeros(self) -> u32 {
    let word = self.0;
    word.trailing_zeros()
}

/// Returns the number of trailing one bits in the binary representation of this limb.
#[inline(always)]
pub const fn trailing_ones(self) -> u32 {
    let word = self.0;
    word.trailing_ones()
}
Expand Down
2 changes: 1 addition & 1 deletion src/limb/mul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ impl Limb {
}

/// Perform saturating multiplication.
#[inline]
#[inline(always)]
pub const fn saturating_mul(&self, rhs: Self) -> Self {
Limb(self.0.saturating_mul(rhs.0))
}
Expand Down
6 changes: 6 additions & 0 deletions src/limb/shl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ impl Limb {
pub const fn shl(self, shift: u32) -> Self {
Limb(self.0 << shift)
}

/// Computes `self << 1` and returns `(result, carry)`, where the carry (0 or 1)
/// is the bit shifted out, obtained as `self >> HI_BIT`.
#[inline(always)]
pub(crate) const fn shl1(self) -> (Self, Self) {
    let shifted = self.0 << 1;
    let carry = self.0 >> Self::HI_BIT;
    (Self(shifted), Self(carry))
}
}

impl Shl<u32> for Limb {
Expand Down
6 changes: 6 additions & 0 deletions src/limb/shr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ impl Limb {
pub const fn shr(self, shift: u32) -> Self {
Limb(self.0 >> shift)
}

/// Computes `self >> 1` and returns `(result, carry)`, where the carry
/// (0 or `1 << HI_BIT`) is the bit shifted out, obtained as `self << HI_BIT`.
#[inline(always)]
pub(crate) const fn shr1(self) -> (Self, Self) {
    let shifted = self.0 >> 1;
    let carry = self.0 << Self::HI_BIT;
    (Self(shifted), Self(carry))
}
}

impl Shr<u32> for Limb {
Expand Down
2 changes: 1 addition & 1 deletion src/modular/div_by_2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pub(crate) fn div_by_2<const LIMBS: usize>(a: &Uint<LIMBS>, modulus: &Uint<LIMBS
// ("+1" because both `a` and `modulus` are odd, we lose 0.5 in each integer division).
// This will not overflow, so we can just use wrapping operations.

let (half, is_odd) = a.shr1_with_overflow();
let (half, is_odd) = a.shr1_with_carry();
let half_modulus = modulus.shr1();

let if_even = half;
Expand Down
5 changes: 5 additions & 0 deletions src/uint/boxed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,11 @@ impl BoxedUint {

limbs.into()
}

/// Overwrites every limb of `self` with `Limb::ZERO`, in place.
pub(crate) fn set_to_zero(&mut self) {
    for limb in self.limbs.as_mut().iter_mut() {
        *limb = Limb::ZERO;
    }
}
}

impl NonZero<BoxedUint> {
Expand Down
2 changes: 1 addition & 1 deletion src/uint/boxed/bits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ mod tests {
fn uint_with_bits_at(positions: &[u32]) -> BoxedUint {
let mut result = BoxedUint::zero_with_precision(256);
for &pos in positions {
result |= BoxedUint::one_with_precision(256).shl_vartime(pos);
result |= BoxedUint::one_with_precision(256).shl_vartime(pos).unwrap();
}
result
}
Expand Down
8 changes: 5 additions & 3 deletions src/uint/boxed/div.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ impl BoxedUint {
let mb = rhs.bits();
let mut bd = self.bits_precision() - mb;
let mut rem = self.clone();
let mut c = rhs.shl_vartime(bd);
// Will not overflow since `bd < bits_precision`
let mut c = rhs.shl_vartime(bd).expect("shift within range");

loop {
let (r, borrow) = rem.sbb(&c, Limb::ZERO);
Expand Down Expand Up @@ -77,7 +78,7 @@ impl BoxedUint {
let bits_precision = self.bits_precision();
let mut rem = self.clone();
let mut quo = Self::zero_with_precision(bits_precision);
let mut c = rhs.shl(bits_precision - mb);
let (mut c, _overflow) = rhs.shl(bits_precision - mb);
let mut i = bits_precision;
let mut done = Choice::from(0u8);

Expand Down Expand Up @@ -110,7 +111,8 @@ impl BoxedUint {
let mut bd = self.bits_precision() - mb;
let mut remainder = self.clone();
let mut quotient = Self::zero_with_precision(self.bits_precision());
let mut c = rhs.shl_vartime(bd);
// Will not overflow since `bd < bits_precision`
let mut c = rhs.shl_vartime(bd).expect("shift within range");

loop {
let (mut r, borrow) = remainder.sbb(&c, Limb::ZERO);
Expand Down
10 changes: 5 additions & 5 deletions src/uint/boxed/inv_mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ impl BoxedUint {

// Decompose `modulus = s * 2^k` where `s` is odd
let k = modulus.trailing_zeros();
let s = modulus.shr(k);
let s = modulus >> k;

// Decompose `self` into RNS with moduli `2^k` and `s` and calculate the inverses.
// Using the fact that `(z^{-1} mod (m1 * m2)) mod m1 == z^{-1} mod m1`
Expand All @@ -26,7 +26,7 @@ impl BoxedUint {
let (m_odd_inv, _is_some) = s.inv_mod2k(k); // `s` is odd, so this always exists

// This part is mod 2^k
let mask = Self::one().shl(k).wrapping_sub(&Self::one());
let mask = (Self::one() << k).wrapping_sub(&Self::one());
let t = (b.wrapping_sub(&a).wrapping_mul(&m_odd_inv)).bitand(&mask);

// Will not overflow since `a <= s - 1`, `t <= 2^k - 1`,
Expand Down Expand Up @@ -126,9 +126,9 @@ impl BoxedUint {
let cyy = new_u.conditional_adc_assign(modulus, cy);
debug_assert!(bool::from(cy.ct_eq(&cyy)));

let (new_a, overflow) = a.shr1_with_overflow();
debug_assert!(bool::from(!modulus_is_odd | !overflow));
let (mut new_u, cy) = new_u.shr1_with_overflow();
let (new_a, carry) = a.shr1_with_carry();
debug_assert!(bool::from(!modulus_is_odd | !carry));
let (mut new_u, cy) = new_u.shr1_with_carry();
let cy = new_u.conditional_adc_assign(&m1hp, cy);
debug_assert!(bool::from(!modulus_is_odd | !cy));

Expand Down
Loading