From f9499256287b3a801ef2886488376b8f70cede3e Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Fri, 8 Dec 2023 15:44:24 -0800 Subject: [PATCH 01/15] Add more benchmarks for shifts --- benches/uint.rs | 70 +++++++++++++++++++++++++++++++++++-------------- 1 file changed, 50 insertions(+), 20 deletions(-) diff --git a/benches/uint.rs b/benches/uint.rs index 944d8c24..3688a2d9 100644 --- a/benches/uint.rs +++ b/benches/uint.rs @@ -1,11 +1,10 @@ -use criterion::{ - black_box, criterion_group, criterion_main, measurement::Measurement, BatchSize, - BenchmarkGroup, Criterion, -}; -use crypto_bigint::{Limb, NonZero, Random, Reciprocal, U128, U2048, U256}; +use criterion::{black_box, criterion_group, criterion_main, BatchSize, Criterion}; +use crypto_bigint::{Limb, NonZero, Random, Reciprocal, Uint, U128, U2048, U256}; use rand_core::OsRng; -fn bench_division(group: &mut BenchmarkGroup<'_, M>) { +fn bench_division(c: &mut Criterion) { + let mut group = c.benchmark_group("wrapping ops"); + group.bench_function("div/rem, U256/U128, full size", |b| { b.iter_batched( || { @@ -69,9 +68,13 @@ fn bench_division(group: &mut BenchmarkGroup<'_, M>) { BatchSize::SmallInput, ) }); + + group.finish(); } -fn bench_shifts(group: &mut BenchmarkGroup<'_, M>) { +fn bench_shl(c: &mut Criterion) { + let mut group = c.benchmark_group("left shift"); + group.bench_function("shl_vartime, small, U2048", |b| { b.iter_batched(|| U2048::ONE, |x| x.shl_vartime(10), BatchSize::SmallInput) }); @@ -84,16 +87,54 @@ fn bench_shifts(group: &mut BenchmarkGroup<'_, M>) { ) }); + group.bench_function("shl_vartime_wide, large, U2048", |b| { + b.iter_batched( + || (U2048::ONE, U2048::ONE), + |x| Uint::shl_vartime_wide(x, 1024 + 10), + BatchSize::SmallInput, + ) + }); + group.bench_function("shl, U2048", |b| { b.iter_batched(|| U2048::ONE, |x| x.shl(1024 + 10), BatchSize::SmallInput) }); + group.finish(); +} + +fn bench_shr(c: &mut Criterion) { + let mut group = c.benchmark_group("right shift"); + + group.bench_function("shr_vartime, small, U2048", |b| { + b.iter_batched(|| U2048::ONE, |x| x.shr_vartime(10), BatchSize::SmallInput) + }); + + group.bench_function("shr_vartime, large, U2048", |b| { + b.iter_batched( + || U2048::ONE, + |x| x.shr_vartime(1024 + 10), + BatchSize::SmallInput, + ) + }); + + group.bench_function("shr_vartime_wide, large, U2048", |b| { + b.iter_batched( + || (U2048::ONE, U2048::ONE), + |x| Uint::shr_vartime_wide(x, 1024 + 10), + BatchSize::SmallInput, + ) + }); + group.bench_function("shr, U2048", |b| { b.iter_batched(|| U2048::ONE, |x| x.shr(1024 + 10), BatchSize::SmallInput) }); + + group.finish(); } -fn bench_inv_mod(group: &mut BenchmarkGroup<'_, M>) { +fn bench_inv_mod(c: &mut Criterion) { + let mut group = c.benchmark_group("modular ops"); + group.bench_function("inv_odd_mod, U256", |b| { b.iter_batched( || { @@ -144,21 +185,10 @@ fn bench_inv_mod(group: &mut BenchmarkGroup<'_, M>) { BatchSize::SmallInput, ) }); -} -fn bench_wrapping_ops(c: &mut Criterion) { - let mut group = c.benchmark_group("wrapping ops"); - bench_division(&mut group); - group.finish(); -} - -fn bench_modular_ops(c: &mut Criterion) { - let mut group = c.benchmark_group("modular ops"); - bench_shifts(&mut group); - bench_inv_mod(&mut group); group.finish(); } -criterion_group!(benches, bench_wrapping_ops, bench_modular_ops); +criterion_group!(benches, bench_shl, bench_shr, bench_division, bench_inv_mod); criterion_main!(benches); From 8257838dab03a64910d8229aba28237ee5714915 Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Thu, 7 Dec 2023 14:01:10 -0800 Subject: [PATCH 02/15] Inline some low-level ops --- src/limb/bit_not.rs | 1 + src/limb/bit_or.rs | 1 + src/limb/bit_xor.rs | 1 + src/limb/bits.rs | 4 ++++ src/limb/mul.rs | 2 +- src/limb/shl.rs | 6 ++++++ src/limb/shr.rs | 6 ++++++ 7 files changed, 20 insertions(+), 1 deletion(-) diff --git a/src/limb/bit_not.rs b/src/limb/bit_not.rs index 26676d59..6d728d45 100644 --- a/src/limb/bit_not.rs +++ b/src/limb/bit_not.rs @@ -5,6 +5,7 @@ use core::ops::Not; impl Limb { /// Calculates `!a`. + #[inline(always)] pub const fn not(self) -> Self { Limb(!self.0) } diff --git a/src/limb/bit_or.rs b/src/limb/bit_or.rs index f863ac0d..340f4f76 100644 --- a/src/limb/bit_or.rs +++ b/src/limb/bit_or.rs @@ -5,6 +5,7 @@ use core::ops::{BitOr, BitOrAssign}; impl Limb { /// Calculates `a | b`. + #[inline(always)] pub const fn bitor(self, rhs: Self) -> Self { Limb(self.0 | rhs.0) } diff --git a/src/limb/bit_xor.rs b/src/limb/bit_xor.rs index a5078229..7c04e7b7 100644 --- a/src/limb/bit_xor.rs +++ b/src/limb/bit_xor.rs @@ -5,6 +5,7 @@ use core::ops::BitXor; impl Limb { /// Calculates `a ^ b`. + #[inline(always)] pub const fn bitxor(self, rhs: Self) -> Self { Limb(self.0 ^ rhs.0) } diff --git a/src/limb/bits.rs b/src/limb/bits.rs index 1c7674f4..4553137b 100644 --- a/src/limb/bits.rs +++ b/src/limb/bits.rs @@ -2,21 +2,25 @@ use super::Limb; impl Limb { /// Calculate the number of bits needed to represent this number. + #[inline(always)] pub const fn bits(self) -> u32 { Limb::BITS - self.0.leading_zeros() } /// Calculate the number of leading zeros in the binary representation of this number. + #[inline(always)] pub const fn leading_zeros(self) -> u32 { self.0.leading_zeros() } /// Calculate the number of trailing zeros in the binary representation of this number. + #[inline(always)] pub const fn trailing_zeros(self) -> u32 { self.0.trailing_zeros() } /// Calculate the number of trailing ones the binary representation of this number. + #[inline(always)] pub const fn trailing_ones(self) -> u32 { self.0.trailing_ones() } diff --git a/src/limb/mul.rs b/src/limb/mul.rs index 7f8b0845..1ea73b4e 100644 --- a/src/limb/mul.rs +++ b/src/limb/mul.rs @@ -17,7 +17,7 @@ impl Limb { } /// Perform saturating multiplication. - #[inline] + #[inline(always)] pub const fn saturating_mul(&self, rhs: Self) -> Self { Limb(self.0.saturating_mul(rhs.0)) } diff --git a/src/limb/shl.rs b/src/limb/shl.rs index 03e4a103..0e655387 100644 --- a/src/limb/shl.rs +++ b/src/limb/shl.rs @@ -10,6 +10,12 @@ impl Limb { pub const fn shl(self, shift: u32) -> Self { Limb(self.0 << shift) } + + /// Computes `self << 1` and return the result and the carry (0 or 1). + #[inline(always)] + pub(crate) const fn shl1(self) -> (Self, Self) { + (Self(self.0 << 1), Self(self.0 >> Self::HI_BIT)) + } } impl Shl for Limb { diff --git a/src/limb/shr.rs b/src/limb/shr.rs index a91c65d5..10fca6dc 100644 --- a/src/limb/shr.rs +++ b/src/limb/shr.rs @@ -10,6 +10,12 @@ impl Limb { pub const fn shr(self, shift: u32) -> Self { Limb(self.0 >> shift) } + + /// Computes `self >> 1` and return the result and the carry (0 or `1 << HI_BIT`). + #[inline(always)] + pub(crate) const fn shr1(self) -> (Self, Self) { + (Self(self.0 >> 1), Self(self.0 << Self::HI_BIT)) + } } impl Shr for Limb { From 4dbb16cafdb92fe2359fac675e7b325f12c8bd27 Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Tue, 5 Dec 2023 13:40:43 -0800 Subject: [PATCH 03/15] Return the overflow status in const fn bit shifts, panic if overflow occurred --- src/lib.rs | 2 +- src/uint/div.rs | 24 +++++++------ src/uint/inv_mod.rs | 9 +++-- src/uint/mul.rs | 2 +- src/uint/shl.rs | 78 ++++++++++++++++++++++++++++------------- src/uint/shr.rs | 76 +++++++++++++++++++++++++-------------- src/uint/sqrt.rs | 6 ++-- tests/uint_proptests.rs | 15 ++++++-- 8 files changed, 141 insertions(+), 71 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 5b1b1fc4..ecf524cd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -45,7 +45,7 @@ //! U256::from_be_hex("ffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551"); //! //! // Compute `MODULUS` shifted right by 1 at compile time -//! pub const MODULUS_SHR1: U256 = MODULUS.shr(1); +//! pub const MODULUS_SHR1: U256 = MODULUS.shr(1).0; //! ``` //! //! Note that large constant computations may accidentally trigger a the `const_eval_limit` of the compiler. diff --git a/src/uint/div.rs b/src/uint/div.rs index ac50a206..b634a9c9 100644 --- a/src/uint/div.rs +++ b/src/uint/div.rs @@ -53,7 +53,9 @@ impl Uint { let mb = rhs.bits(); let mut rem = *self; let mut quo = Self::ZERO; - let mut c = rhs.shl(Self::BITS - mb); + // If there is overflow, it means `mb == 0`, so `rhs == 0`. + let (mut c, overflow) = rhs.shl(Self::BITS - mb); + let is_some = overflow.not(); let mut i = Self::BITS; let mut done = CtChoice::FALSE; @@ -73,7 +75,6 @@ impl Uint { quo = Self::ct_select(&quo.shl1(), &quo, done); } - let is_some = Limb(mb as Word).ct_is_nonzero(); quo = Self::ct_select(&Self::ZERO, &quo, is_some); (quo, rem, is_some) } @@ -93,7 +94,8 @@ impl Uint { let mut bd = Self::BITS - mb; let mut rem = *self; let mut quo = Self::ZERO; - let mut c = rhs.shl_vartime(bd); + let (mut c, overflow) = rhs.shl_vartime(bd); + let is_some = overflow.not(); loop { let (mut r, borrow) = rem.sbb(&c, Limb::ZERO); @@ -108,7 +110,6 @@ impl Uint { quo = quo.shl1(); } - let is_some = CtChoice::from_u32_nonzero(mb); quo = Self::ct_select(&Self::ZERO, &quo, is_some); (quo, rem, is_some) } @@ -125,7 +126,8 @@ impl Uint { let mb = rhs.bits_vartime(); let mut bd = Self::BITS - mb; let mut rem = *self; - let mut c = rhs.shl_vartime(bd); + let (mut c, overflow) = rhs.shl_vartime(bd); + let is_some = overflow.not(); loop { let (r, borrow) = rem.sbb(&c, Limb::ZERO); @@ -137,7 +139,6 @@ impl Uint { c = c.shr1(); } - let is_some = CtChoice::from_u32_nonzero(mb); (rem, is_some) } @@ -158,7 +159,7 @@ impl Uint { let (mut lower, mut upper) = lower_upper; // Factor of the modulus, split into two halves - let mut c = Self::shl_vartime_wide((*rhs, Uint::ZERO), bd); + let (mut c, _overflow) = Self::shl_vartime_wide((*rhs, Uint::ZERO), bd); loop { let (lower_sub, borrow) = lower.sbb(&c.0, Limb::ZERO); @@ -170,7 +171,8 @@ impl Uint { break; } bd -= 1; - c = Self::shr_vartime_wide(c, 1); + let (new_c, _overflow) = Self::shr_vartime_wide(c, 1); + c = new_c; } let is_some = CtChoice::from_u32_nonzero(mb); @@ -696,8 +698,8 @@ mod tests { fn div() { let mut rng = ChaChaRng::from_seed([7u8; 32]); for _ in 0..25 { - let num = U256::random(&mut rng).shr_vartime(128); - let den = U256::random(&mut rng).shr_vartime(128); + let (num, _) = U256::random(&mut rng).shr_vartime(128); + let (den, _) = U256::random(&mut rng).shr_vartime(128); let n = num.checked_mul(&den); if n.is_some().into() { let (q, _, is_some) = n.unwrap().const_div_rem(&den); @@ -808,7 +810,7 @@ mod tests { for _ in 0..25 { let num = U256::random(&mut rng); let k = rng.next_u32() % 256; - let den = U256::ONE.shl_vartime(k); + let (den, _) = U256::ONE.shl_vartime(k); let a = num.rem2k(k); let e = num.wrapping_rem(&den); diff --git a/src/uint/inv_mod.rs b/src/uint/inv_mod.rs index 28694720..bc8beb82 100644 --- a/src/uint/inv_mod.rs +++ b/src/uint/inv_mod.rs @@ -30,7 +30,8 @@ impl Uint { // b_{i+1} = (b_i - a * X_i) / 2 b = Self::ct_select(&b, &b.wrapping_sub(self), x_i_choice).shr1(); // Store the X_i bit in the result (x = x | (1 << X_i)) - x = x.bitor(&Uint::from_word(x_i).shl_vartime(i)); + let (shifted, _overflow) = Uint::from_word(x_i).shl_vartime(i); + x = x.bitor(&shifted); i += 1; } @@ -161,7 +162,7 @@ impl Uint { pub const fn inv_mod(&self, modulus: &Self) -> (Self, CtChoice) { // Decompose `modulus = s * 2^k` where `s` is odd let k = modulus.trailing_zeros(); - let s = modulus.shr(k); + let (s, _overflow) = modulus.shr(k); // Decompose `self` into RNS with moduli `2^k` and `s` and calculate the inverses. // Using the fact that `(z^{-1} mod (m1 * m2)) mod m1 == z^{-1} mod m1` @@ -176,7 +177,9 @@ impl Uint { let (m_odd_inv, _is_some) = s.inv_mod2k(k); // `s` is odd, so this always exists // This part is mod 2^k - let mask = Uint::ONE.shl(k).wrapping_sub(&Uint::ONE); + // Will not overflow since `modulus` is nonzero, and therefore `k < BITS`. + let (shifted, _overflow) = Uint::ONE.shl(k); + let mask = shifted.wrapping_sub(&Uint::ONE); let t = (b.wrapping_sub(&a).wrapping_mul(&m_odd_inv)).bitand(&mask); // Will not overflow since `a <= s - 1`, `t <= 2^k - 1`, diff --git a/src/uint/mul.rs b/src/uint/mul.rs index a668f2ce..41e9e0eb 100644 --- a/src/uint/mul.rs +++ b/src/uint/mul.rs @@ -135,7 +135,7 @@ impl Uint { // Double the current result, this accounts for the other half of the multiplication grid. // TODO: The top word is empty so we can also use a special purpose shl. - (lo, hi) = Self::shl_vartime_wide((lo, hi), 1); + (lo, hi) = Self::shl_vartime_wide((lo, hi), 1).0; // Handle the diagonal of the multiplication grid, which finishes the multiplication grid. let mut carry = Limb::ZERO; diff --git a/src/uint/shl.rs b/src/uint/shl.rs index ce4b9047..3e8d4db3 100644 --- a/src/uint/shl.rs +++ b/src/uint/shl.rs @@ -5,33 +5,36 @@ use core::ops::{Shl, ShlAssign}; impl Uint { /// Computes `self << shift`. - /// Returns zero if `shift >= Self::BITS`. - pub const fn shl(&self, shift: u32) -> Self { + /// If `shift >= Self::BITS`, returns zero as the first tuple element, + /// and `CtChoice::TRUE` as the second element. + pub const fn shl(&self, shift: u32) -> (Self, CtChoice) { let overflow = CtChoice::from_u32_lt(shift, Self::BITS).not(); let shift = shift % Self::BITS; let mut result = *self; let mut i = 0; while i < Self::LOG2_BITS + 1 { let bit = CtChoice::from_u32_lsb((shift >> i) & 1); - result = Uint::ct_select(&result, &result.shl_vartime(1 << i), bit); + result = Uint::ct_select(&result, &result.shl_vartime(1 << i).0, bit); i += 1; } - Uint::ct_select(&result, &Self::ZERO, overflow) + (Uint::ct_select(&result, &Self::ZERO, overflow), overflow) } /// Computes `self << shift`. + /// If `shift >= Self::BITS`, returns zero as the first tuple element, + /// and `CtChoice::TRUE` as the second element. /// /// NOTE: this operation is variable time with respect to `shift` *ONLY*. /// /// When used with a fixed `shift`, this function is constant-time with respect /// to `self`. #[inline(always)] - pub const fn shl_vartime(&self, shift: u32) -> Self { + pub const fn shl_vartime(&self, shift: u32) -> (Self, CtChoice) { let mut limbs = [Limb::ZERO; LIMBS]; if shift >= Self::BITS { - return Self { limbs }; + return (Self::ZERO, CtChoice::TRUE); } let shift_num = (shift / Limb::BITS) as usize; @@ -44,27 +47,34 @@ impl Uint { } let (new_lower, _carry) = (Self { limbs }).shl_limb(rem); - new_lower + (new_lower, CtChoice::FALSE) } /// Computes a left shift on a wide input as `(lo, hi)`. + /// If `shift >= Self::BITS`, returns a tuple of zeros as the first element, + /// and `CtChoice::TRUE` as the second element. /// /// NOTE: this operation is variable time with respect to `shift` *ONLY*. /// /// When used with a fixed `shift`, this function is constant-time with respect /// to `self`. #[inline(always)] - pub const fn shl_vartime_wide(lower_upper: (Self, Self), shift: u32) -> (Self, Self) { - let (lower, mut upper) = lower_upper; - let new_lower = lower.shl_vartime(shift); - upper = upper.shl_vartime(shift); - if shift >= Self::BITS { - upper = upper.bitor(&lower.shl_vartime(shift - Self::BITS)); + pub const fn shl_vartime_wide( + lower_upper: (Self, Self), + shift: u32, + ) -> ((Self, Self), CtChoice) { + let (lower, upper) = lower_upper; + if shift >= 2 * Self::BITS { + ((Self::ZERO, Self::ZERO), CtChoice::TRUE) + } else if shift >= Self::BITS { + let (upper, _) = lower.shl_vartime(shift - Self::BITS); + ((Self::ZERO, upper), CtChoice::FALSE) } else { - upper = upper.bitor(&lower.shr_vartime(Self::BITS - shift)); + let (new_lower, _) = lower.shl_vartime(shift); + let (upper_lo, _) = lower.shr_vartime(Self::BITS - shift); + let (upper_hi, _) = upper.shl_vartime(shift); + ((new_lower, upper_lo.bitor(&upper_hi)), CtChoice::FALSE) } - - (new_lower, upper) } /// Computes `self << shift` where `0 <= shift < Limb::BITS`, @@ -94,7 +104,7 @@ impl Uint { /// Computes `self >> 1` in constant-time. pub(crate) const fn shl1(&self) -> Self { // TODO(tarcieri): optimized implementation - self.shl_vartime(1) + self.shl_vartime(1).0 } } @@ -102,7 +112,7 @@ impl Shl for Uint { type Output = Uint; fn shl(self, shift: u32) -> Uint { - Uint::::shl(&self, shift) + <&Uint as Shl>::shl(&self, shift) } } @@ -110,7 +120,12 @@ impl Shl for &Uint { type Output = Uint; fn shl(self, shift: u32) -> Uint { - self.shl(shift) + let (result, overflow) = Uint::::shl(self, shift); + assert!( + !overflow.is_true_vartime(), + "attempt to shift left with overflow" + ); + result } } @@ -122,7 +137,7 @@ impl ShlAssign for Uint { #[cfg(test)] mod tests { - use crate::{Limb, Uint, U128, U256}; + use crate::{CtChoice, Limb, Uint, U128, U256}; const N: U256 = U256::from_be_hex("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141"); @@ -153,6 +168,7 @@ mod tests { #[test] fn shl1() { assert_eq!(N << 1, TWO_N); + assert_eq!(N.shl1(), TWO_N); } #[test] @@ -171,8 +187,15 @@ mod tests { } #[test] + fn shl256_const() { + assert_eq!(N.shl(256), (U256::ZERO, CtChoice::TRUE)); + assert_eq!(N.shl_vartime(256), (U256::ZERO, CtChoice::TRUE)); + } + + #[test] + #[should_panic(expected = "attempt to shift left with overflow")] fn shl256() { - assert_eq!(N << 256, U256::default()); + let _ = N << 256; } #[test] @@ -184,7 +207,11 @@ mod tests { fn shl_wide_1_1_128() { assert_eq!( Uint::shl_vartime_wide((U128::ONE, U128::ONE), 128), - (U128::ZERO, U128::ONE) + ((U128::ZERO, U128::ONE), CtChoice::FALSE) + ); + assert_eq!( + Uint::shl_vartime_wide((U128::ONE, U128::ONE), 128), + ((U128::ZERO, U128::ONE), CtChoice::FALSE) ); } @@ -192,7 +219,10 @@ mod tests { fn shl_wide_max_0_1() { assert_eq!( Uint::shl_vartime_wide((U128::MAX, U128::ZERO), 1), - (U128::MAX.sbb(&U128::ONE, Limb::ZERO).0, U128::ONE) + ( + (U128::MAX.sbb(&U128::ONE, Limb::ZERO).0, U128::ONE), + CtChoice::FALSE + ) ); } @@ -200,7 +230,7 @@ mod tests { fn shl_wide_max_max_256() { assert_eq!( Uint::shl_vartime_wide((U128::MAX, U128::MAX), 256), - (U128::ZERO, U128::ZERO) + ((U128::ZERO, U128::ZERO), CtChoice::TRUE) ); } } diff --git a/src/uint/shr.rs b/src/uint/shr.rs index 8f2b0b69..b4476a04 100644 --- a/src/uint/shr.rs +++ b/src/uint/shr.rs @@ -1,40 +1,42 @@ //! [`Uint`] bitwise right shift operations. -use super::Uint; -use crate::{CtChoice, Limb}; +use crate::{CtChoice, Limb, Uint}; use core::ops::{Shr, ShrAssign}; impl Uint { - /// Computes `self << shift`. - /// Returns zero if `shift >= Self::BITS`. - pub const fn shr(&self, shift: u32) -> Self { + /// Computes `self >> shift`. + /// If `shift >= Self::BITS`, returns zero as the first tuple element, + /// and `CtChoice::TRUE` as the second element. + pub const fn shr(&self, shift: u32) -> (Self, CtChoice) { let overflow = CtChoice::from_u32_lt(shift, Self::BITS).not(); let shift = shift % Self::BITS; let mut result = *self; let mut i = 0; while i < Self::LOG2_BITS + 1 { let bit = CtChoice::from_u32_lsb((shift >> i) & 1); - result = Uint::ct_select(&result, &result.shr_vartime(1 << i), bit); + result = Uint::ct_select(&result, &result.shr_vartime(1 << i).0, bit); i += 1; } - Uint::ct_select(&result, &Self::ZERO, overflow) + (Uint::ct_select(&result, &Self::ZERO, overflow), overflow) } /// Computes `self >> shift`. + /// If `shift >= Self::BITS`, returns zero as the first tuple element, + /// and `CtChoice::TRUE` as the second element. /// /// NOTE: this operation is variable time with respect to `shift` *ONLY*. /// /// When used with a fixed `shift`, this function is constant-time with respect /// to `self`. #[inline(always)] - pub const fn shr_vartime(&self, shift: u32) -> Self { + pub const fn shr_vartime(&self, shift: u32) -> (Self, CtChoice) { let full_shifts = (shift / Limb::BITS) as usize; let small_shift = shift & (Limb::BITS - 1); let mut limbs = [Limb::ZERO; LIMBS]; - if shift > Self::BITS { - return Self { limbs }; + if shift >= Self::BITS { + return (Self::ZERO, CtChoice::TRUE); } let shift = LIMBS - full_shifts; @@ -58,7 +60,7 @@ impl Uint { } } - Self { limbs } + (Self { limbs }, CtChoice::FALSE) } /// Computes a right shift on a wide input as `(lo, hi)`. @@ -68,17 +70,22 @@ impl Uint { /// When used with a fixed `shift`, this function is constant-time with respect /// to `self`. #[inline(always)] - pub const fn shr_vartime_wide(lower_upper: (Self, Self), shift: u32) -> (Self, Self) { - let (mut lower, upper) = lower_upper; - let new_upper = upper.shr_vartime(shift); - lower = lower.shr_vartime(shift); - if shift >= Self::BITS { - lower = lower.bitor(&upper.shr_vartime(shift - Self::BITS)); + pub const fn shr_vartime_wide( + lower_upper: (Self, Self), + shift: u32, + ) -> ((Self, Self), CtChoice) { + let (lower, upper) = lower_upper; + if shift >= 2 * Self::BITS { + ((Self::ZERO, Self::ZERO), CtChoice::TRUE) + } else if shift >= Self::BITS { + let (lower, _) = upper.shr_vartime(shift - Self::BITS); + ((lower, Self::ZERO), CtChoice::FALSE) } else { - lower = lower.bitor(&upper.shl_vartime(Self::BITS - shift)); + let (new_upper, _) = upper.shr_vartime(shift); + let (lower_hi, _) = upper.shl_vartime(Self::BITS - shift); + let (lower_lo, _) = lower.shr_vartime(shift); + ((lower_lo.bitor(&lower_hi), new_upper), CtChoice::FALSE) } - - (lower, new_upper) } /// Computes `self >> 1` in constant-time, returning [`CtChoice::TRUE`] if the overflowing bit @@ -109,7 +116,7 @@ impl Shr for Uint { type Output = Uint; fn shr(self, shift: u32) -> Uint { - Uint::::shr(&self, shift) + <&Uint as Shr>::shr(&self, shift) } } @@ -117,7 +124,12 @@ impl Shr for &Uint { type Output = Uint; fn shr(self, shift: u32) -> Uint { - self.shr(shift) + let (result, overflow) = Uint::::shr(self, shift); + assert!( + !overflow.is_true_vartime(), + "attempt to shift right with overflow" + ); + result } } @@ -129,7 +141,7 @@ impl ShrAssign for Uint { #[cfg(test)] mod tests { - use crate::{Uint, U128, U256}; + use crate::{CtChoice, Uint, U128, U256}; const N: U256 = U256::from_be_hex("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141"); @@ -142,11 +154,23 @@ mod tests { assert_eq!(N >> 1, N_2); } + #[test] + fn shr256_const() { + assert_eq!(N.shr(256), (U256::ZERO, CtChoice::TRUE)); + assert_eq!(N.shr_vartime(256), (U256::ZERO, CtChoice::TRUE)); + } + + #[test] + #[should_panic(expected = "attempt to shift right with overflow")] + fn shr256() { + let _ = N >> 256; + } + #[test] fn shr_wide_1_1_128() { assert_eq!( Uint::shr_vartime_wide((U128::ONE, U128::ONE), 128), - (U128::ONE, U128::ZERO) + ((U128::ONE, U128::ZERO), CtChoice::FALSE) ); } @@ -154,7 +178,7 @@ mod tests { fn shr_wide_0_max_1() { assert_eq!( Uint::shr_vartime_wide((U128::ZERO, U128::MAX), 1), - (U128::ONE << 127, U128::MAX >> 1) + ((U128::ONE << 127, U128::MAX >> 1), CtChoice::FALSE) ); } @@ -162,7 +186,7 @@ mod tests { fn shr_wide_max_max_256() { assert_eq!( Uint::shr_vartime_wide((U128::MAX, U128::MAX), 256), - (U128::ZERO, U128::ZERO) + ((U128::ZERO, U128::ZERO), CtChoice::TRUE) ); } } diff --git a/src/uint/sqrt.rs b/src/uint/sqrt.rs index eed95826..17394c0f 100644 --- a/src/uint/sqrt.rs +++ b/src/uint/sqrt.rs @@ -15,7 +15,8 @@ impl Uint { // https://github.com/RustCrypto/crypto-bigint/files/12600669/ct_sqrt.pdf // The initial guess: `x_0 = 2^ceil(b/2)`, where `2^(b-1) <= self < b`. - let mut x = Self::ONE.shl((self.bits() + 1) >> 1); // ≥ √(`self`) + // Will not overflow since `b <= BITS`. + let (mut x, _overflow) = Self::ONE.shl((self.bits() + 1) >> 1); // ≥ √(`self`) // Repeat enough times to guarantee result has stabilized. let mut i = 0; @@ -49,7 +50,8 @@ impl Uint { // Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13 // The initial guess: `x_0 = 2^ceil(b/2)`, where `2^(b-1) <= self < b`. - let mut x = Self::ONE.shl((self.bits() + 1) >> 1); // ≥ √(`self`) + // Will not overflow since `b <= BITS`. + let (mut x, _overflow) = Self::ONE.shl((self.bits() + 1) >> 1); // ≥ √(`self`) // Stop right away if `x` is zero to avoid divizion by zero. while !x.cmp_vartime(&Self::ZERO).is_eq() { diff --git a/tests/uint_proptests.rs b/tests/uint_proptests.rs index 9c884e25..2a6c3c87 100644 --- a/tests/uint_proptests.rs +++ b/tests/uint_proptests.rs @@ -61,9 +61,10 @@ proptest! { let a_bi = to_biguint(&a); let expected = to_uint(a_bi << shift.into()); - let actual = a.shl_vartime(shift.into()); + let (actual, overflow) = a.shl_vartime(shift.into()); assert_eq!(expected, actual); + assert_eq!(overflow, CtChoice::FALSE); } #[test] @@ -74,9 +75,13 @@ proptest! { let shift = u32::from(shift) % (U256::BITS * 2); let expected = to_uint((a_bi << shift as usize) & ((BigUint::one() << U256::BITS as usize) - BigUint::one())); - let actual = a.shl(shift); + let (actual, overflow) = a.shl(shift); assert_eq!(expected, actual); + if shift >= U256::BITS { + assert_eq!(actual, U256::ZERO); + assert_eq!(overflow, CtChoice::TRUE); + } } #[test] @@ -87,9 +92,13 @@ proptest! { let shift = u32::from(shift) % (U256::BITS * 2); let expected = to_uint(a_bi >> shift as usize); - let actual = a.shr(shift); + let (actual, overflow) = a.shr(shift); assert_eq!(expected, actual); + if shift >= U256::BITS { + assert_eq!(actual, U256::ZERO); + assert_eq!(overflow, CtChoice::TRUE); + } } #[test] From 5f9fd151321246b58fa102debcf68a60c35da9a5 Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Sat, 9 Dec 2023 11:00:30 -0800 Subject: [PATCH 04/15] Make uint/shl.rs and shr.rs uniform --- src/uint/shl.rs | 43 ++++++++++++++++++++--- src/uint/shr.rs | 77 ++++++++++++++++++++++++++--------------- tests/uint_proptests.rs | 27 +++++++++++++-- 3 files changed, 112 insertions(+), 35 deletions(-) diff --git a/src/uint/shl.rs b/src/uint/shl.rs index 3e8d4db3..5a59d257 100644 --- a/src/uint/shl.rs +++ b/src/uint/shl.rs @@ -8,11 +8,14 @@ impl Uint { /// If `shift >= Self::BITS`, returns zero as the first tuple element, /// and `CtChoice::TRUE` as the second element. pub const fn shl(&self, shift: u32) -> (Self, CtChoice) { + // `floor(log2(BITS - 1))` is the number of bits in the representation of `shift` + // (which lies in range `0 <= shift < BITS`). + let shift_bits = u32::BITS - (Self::BITS - 1).leading_zeros(); let overflow = CtChoice::from_u32_lt(shift, Self::BITS).not(); let shift = shift % Self::BITS; let mut result = *self; let mut i = 0; - while i < Self::LOG2_BITS + 1 { + while i < shift_bits { let bit = CtChoice::from_u32_lsb((shift >> i) & 1); result = Uint::ct_select(&result, &result.shl_vartime(1 << i).0, bit); i += 1; @@ -46,8 +49,21 @@ impl Uint { limbs[i] = self.limbs[i - shift_num]; } - let (new_lower, _carry) = (Self { limbs }).shl_limb(rem); - (new_lower, CtChoice::FALSE) + if rem == 0 { + return (Self { limbs }, CtChoice::FALSE); + } + + let mut carry = Limb::ZERO; + + while i < LIMBS { + let shifted = limbs[i].shl(rem); + let new_carry = limbs[i].shr(Limb::BITS - rem); + limbs[i] = shifted.bitor(carry); + carry = new_carry; + i += 1; + } + + (Self { limbs }, CtChoice::FALSE) } /// Computes a left shift on a wide input as `(lo, hi)`. @@ -101,10 +117,27 @@ impl Uint { (Uint::::new(limbs), Limb(carry)) } - /// Computes `self >> 1` in constant-time. + /// Computes `self << 1` in constant-time, returning [`CtChoice::TRUE`] if the overflowing bit + /// was set, and [`CtChoice::FALSE`] otherwise. + #[inline(always)] + pub(crate) const fn shl1_with_overflow(&self) -> (Self, CtChoice) { + let mut ret = Self::ZERO; + let mut i = 0; + let mut carry = Limb::ZERO; + while i < LIMBS { + let (shifted, new_carry) = self.limbs[i].shl1(); + ret.limbs[i] = shifted.bitor(carry); + carry = new_carry; + i += 1; + } + + (ret, CtChoice::from_word_lsb(carry.0)) + } + + /// Computes `self << 1` in constant-time. pub(crate) const fn shl1(&self) -> Self { // TODO(tarcieri): optimized implementation - self.shl_vartime(1).0 + self.shl1_with_overflow().0 } } diff --git a/src/uint/shr.rs b/src/uint/shr.rs index b4476a04..3a4cd044 100644 --- a/src/uint/shr.rs +++ b/src/uint/shr.rs @@ -8,11 +8,14 @@ impl Uint { /// If `shift >= Self::BITS`, returns zero as the first tuple element, /// and `CtChoice::TRUE` as the second element. pub const fn shr(&self, shift: u32) -> (Self, CtChoice) { + // `floor(log2(BITS - 1))` is the number of bits in the representation of `shift` + // (which lies in range `0 <= shift < BITS`). + let shift_bits = u32::BITS - (Self::BITS - 1).leading_zeros(); let overflow = CtChoice::from_u32_lt(shift, Self::BITS).not(); let shift = shift % Self::BITS; let mut result = *self; let mut i = 0; - while i < Self::LOG2_BITS + 1 { + while i < shift_bits { let bit = CtChoice::from_u32_lsb((shift >> i) & 1); result = Uint::ct_select(&result, &result.shr_vartime(1 << i).0, bit); i += 1; @@ -31,39 +34,41 @@ impl Uint { /// to `self`. #[inline(always)] pub const fn shr_vartime(&self, shift: u32) -> (Self, CtChoice) { - let full_shifts = (shift / Limb::BITS) as usize; - let small_shift = shift & (Limb::BITS - 1); let mut limbs = [Limb::ZERO; LIMBS]; if shift >= Self::BITS { return (Self::ZERO, CtChoice::TRUE); } - let shift = LIMBS - full_shifts; + let shift_num = (shift / Limb::BITS) as usize; + let rem = shift % Limb::BITS; + let mut i = 0; + while i < LIMBS - shift_num { + limbs[i] = self.limbs[i + shift_num]; + i += 1; + } - if small_shift == 0 { - while i < shift { - limbs[i] = Limb(self.limbs[i + full_shifts].0); - i += 1; - } - } else { - while i < shift { - let mut lo = self.limbs[i + full_shifts].0 >> small_shift; + if rem == 0 { + return (Self { limbs }, CtChoice::FALSE); + } - if i < (LIMBS - 1) - full_shifts { - lo |= self.limbs[i + full_shifts + 1].0 << (Limb::BITS - small_shift); - } + let mut carry = Limb::ZERO; - limbs[i] = Limb(lo); - i += 1; - } + while i > 0 { + i -= 1; + let shifted = limbs[i].shr(rem); + let new_carry = limbs[i].shl(Limb::BITS - rem); + limbs[i] = shifted.bitor(carry); + carry = new_carry; } (Self { limbs }, CtChoice::FALSE) } /// Computes a right shift on a wide input as `(lo, hi)`. + /// If `shift >= Self::BITS`, returns a tuple of zeros as the first element, + /// and `CtChoice::TRUE` as the second element. /// /// NOTE: this operation is variable time with respect to `shift` *ONLY*. /// @@ -90,24 +95,24 @@ impl Uint { /// Computes `self >> 1` in constant-time, returning [`CtChoice::TRUE`] if the overflowing bit /// was set, and [`CtChoice::FALSE`] otherwise. + #[inline(always)] pub(crate) const fn shr1_with_overflow(&self) -> (Self, CtChoice) { - let carry = CtChoice::from_word_lsb(self.limbs[0].0 & 1); let mut ret = Self::ZERO; - ret.limbs[0] = self.limbs[0].shr(1); - - let mut i = 1; - while i < LIMBS { - // set carry bit - ret.limbs[i - 1].0 |= (self.limbs[i].0 & 1) << Limb::HI_BIT; - ret.limbs[i] = self.limbs[i].shr(1); - i += 1; + let mut i = LIMBS; + let mut carry = Limb::ZERO; + while i > 0 { + i -= 1; + let (shifted, new_carry) = self.limbs[i].shr1(); + ret.limbs[i] = shifted.bitor(carry); + carry = new_carry; } - (ret, carry) + (ret, CtChoice::from_word_lsb(carry.0 >> Limb::HI_BIT)) } /// Computes `self >> 1` in constant-time. pub(crate) const fn shr1(&self) -> Self { + // TODO(tarcieri): optimized implementation self.shr1_with_overflow().0 } } @@ -151,6 +156,7 @@ mod tests { #[test] fn shr1() { + assert_eq!(N.shr1(), N_2); assert_eq!(N >> 1, N_2); } @@ -189,4 +195,19 @@ mod tests { ((U128::ZERO, U128::ZERO), CtChoice::TRUE) ); } + + /* + #[test] + fn shr_limb() { + let x = U128::from_be_hex("00112233445566778899aabbccddeeff"); + assert_eq!(x.shr_limb(0), (x, Limb::ZERO)); + assert_eq!( + x.shr_limb(8), + ( + U128::from_be_hex("0000112233445566778899aabbccddee"), + Limb(0xff << (Limb::BITS - 8)) + ) + ); + } + */ } diff --git a/tests/uint_proptests.rs b/tests/uint_proptests.rs index 2a6c3c87..e1b52f3a 100644 --- a/tests/uint_proptests.rs +++ b/tests/uint_proptests.rs @@ -60,11 +60,17 @@ proptest! { fn shl_vartime(a in uint(), shift in any::()) { let a_bi = to_biguint(&a); - let expected = to_uint(a_bi << shift.into()); + // Add a 50% probability of overflow. + let shift = u32::from(shift) % (U256::BITS * 2); + + let expected = to_uint((a_bi << shift as usize) & ((BigUint::one() << U256::BITS as usize) - BigUint::one())); let (actual, overflow) = a.shl_vartime(shift.into()); assert_eq!(expected, actual); - assert_eq!(overflow, CtChoice::FALSE); + if shift >= U256::BITS { + assert_eq!(actual, U256::ZERO); + assert_eq!(overflow, CtChoice::TRUE); + } } #[test] @@ -84,6 +90,23 @@ proptest! { } } + #[test] + fn shr_vartime(a in uint(), shift in any::()) { + let a_bi = to_biguint(&a); + + // Add a 50% probability of overflow. + let shift = u32::from(shift) % (U256::BITS * 2); + + let expected = to_uint(a_bi >> shift as usize); + let (actual, overflow) = a.shr_vartime(shift); + + assert_eq!(expected, actual); + if shift >= U256::BITS { + assert_eq!(actual, U256::ZERO); + assert_eq!(overflow, CtChoice::TRUE); + } + } + #[test] fn shr(a in uint(), shift in any::()) { let a_bi = to_biguint(&a); From 90a21a89d290436f3e2b1ef568a07dfd58d63f89 Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Fri, 8 Dec 2023 18:18:25 -0800 Subject: [PATCH 05/15] Adjust BoxedUint API --- src/uint/boxed/bits.rs | 2 +- src/uint/boxed/div.rs | 8 ++-- src/uint/boxed/inv_mod.rs | 4 +- src/uint/boxed/shl.rs | 87 ++++++++++++++++++------------------ src/uint/boxed/shr.rs | 94 +++++++++++++++++++++++---------------- 5 files changed, 107 insertions(+), 88 deletions(-) diff --git a/src/uint/boxed/bits.rs b/src/uint/boxed/bits.rs index fa681d22..e60ebae0 100644 --- a/src/uint/boxed/bits.rs +++ b/src/uint/boxed/bits.rs @@ -84,7 +84,7 @@ mod tests { fn uint_with_bits_at(positions: &[u32]) -> BoxedUint { let mut result = BoxedUint::zero_with_precision(256); for &pos in positions { - result |= BoxedUint::one_with_precision(256).shl_vartime(pos); + result |= BoxedUint::one_with_precision(256).shl_vartime(pos).unwrap(); } result } diff --git a/src/uint/boxed/div.rs b/src/uint/boxed/div.rs index c8a66213..d82bb982 100644 --- a/src/uint/boxed/div.rs +++ b/src/uint/boxed/div.rs @@ -37,7 +37,8 @@ impl BoxedUint { let mb = rhs.bits(); let mut bd = self.bits_precision() - mb; let mut rem = self.clone(); - let mut c = rhs.shl_vartime(bd); + // Will not overflow since `bd < bits_precision` + let mut c = rhs.shl_vartime(bd).expect("shift within range"); loop { let (r, borrow) = rem.sbb(&c, Limb::ZERO); @@ -77,7 +78,7 @@ impl BoxedUint { let bits_precision = self.bits_precision(); let mut rem = self.clone(); let mut quo = Self::zero_with_precision(bits_precision); - let mut c = rhs.shl(bits_precision - mb); + let (mut c, _overflow) = rhs.shl(bits_precision - mb); let mut i = bits_precision; let mut done = Choice::from(0u8); @@ -110,7 +111,8 @@ impl BoxedUint { let mut bd = self.bits_precision() - mb; let mut remainder = self.clone(); let mut quotient = Self::zero_with_precision(self.bits_precision()); - let mut c = rhs.shl_vartime(bd); + // Will not overflow since `bd < bits_precision` + let mut c = rhs.shl_vartime(bd).expect("shift within range"); loop { let (mut r, borrow) = remainder.sbb(&c, Limb::ZERO); diff --git a/src/uint/boxed/inv_mod.rs b/src/uint/boxed/inv_mod.rs index d4c063c5..ab058137 100644 --- a/src/uint/boxed/inv_mod.rs +++ b/src/uint/boxed/inv_mod.rs @@ -11,7 +11,7 @@ impl BoxedUint { // Decompose `modulus = s * 2^k` where `s` is odd let k = modulus.trailing_zeros(); - let s = modulus.shr(k); + let s = modulus >> k; // Decompose `self` into RNS with moduli `2^k` and `s` and calculate the inverses. // Using the fact that `(z^{-1} mod (m1 * m2)) mod m1 == z^{-1} mod m1` @@ -26,7 +26,7 @@ impl BoxedUint { let (m_odd_inv, _is_some) = s.inv_mod2k(k); // `s` is odd, so this always exists // This part is mod 2^k - let mask = Self::one().shl(k).wrapping_sub(&Self::one()); + let mask = (Self::one() << k).wrapping_sub(&Self::one()); let t = (b.wrapping_sub(&a).wrapping_mul(&m_odd_inv)).bitand(&mask); // Will not overflow since `a <= s - 1`, `t <= 2^k - 1`, diff --git a/src/uint/boxed/shl.rs b/src/uint/boxed/shl.rs index 7fa9d4af..ee8b4b86 100644 --- a/src/uint/boxed/shl.rs +++ b/src/uint/boxed/shl.rs @@ -1,44 +1,55 @@ //! [`BoxedUint`] bitwise left shift operations. -use crate::{BoxedUint, CtChoice, Limb, Word}; +use crate::{BoxedUint, Limb}; use core::ops::{Shl, ShlAssign}; use subtle::{Choice, ConstantTimeLess}; impl BoxedUint { /// Computes `self << shift`. - /// Returns zero if `shift >= Self::BITS`. - pub fn shl(&self, shift: u32) -> Self { + /// Returns `None` if `shift >= Self::BITS`. + pub fn shl(&self, shift: u32) -> (Self, Choice) { + // `floor(log2(bits_precision - 1))` is the number of bits in the representation of `shift` + // (which lies in range `0 <= shift < bits_precision`). + let shift_bits = u32::BITS - (self.bits_precision() - 1).leading_zeros(); let overflow = !shift.ct_lt(&self.bits_precision()); let shift = shift % self.bits_precision(); - let log2_bits = u32::BITS - self.bits_precision().leading_zeros(); let mut result = self.clone(); - for i in 0..log2_bits { + for i in 0..shift_bits { let bit = Choice::from(((shift >> i) & 1) as u8); - result = Self::conditional_select(&result, &result.shl_vartime(1 << i), bit); + result = Self::conditional_select( + &result, + // Will not overflow by construction + &result.shl_vartime(1 << i).expect("shift within range"), + bit, + ); } - Self::conditional_select( - &result, - &Self::zero_with_precision(self.bits_precision()), + ( + Self::conditional_select( + &result, + &Self::zero_with_precision(self.bits_precision()), + overflow, + ), overflow, ) } /// Computes `self << shift`. + /// Returns `None` if `shift >= Self::BITS`. /// /// NOTE: this operation is variable time with respect to `shift` *ONLY*. /// /// When used with a fixed `shift`, this function is constant-time with respect to `self`. #[inline(always)] - pub fn shl_vartime(&self, shift: u32) -> Self { - let nlimbs = self.nlimbs(); - let mut limbs = vec![Limb::ZERO; nlimbs].into_boxed_slice(); - + pub fn shl_vartime(&self, shift: u32) -> Option { if shift >= self.bits_precision() { - return Self { limbs }; + return None; } + let nlimbs = self.nlimbs(); + let mut limbs = vec![Limb::ZERO; nlimbs].into_boxed_slice(); + let shift_num = (shift / Limb::BITS) as usize; let rem = shift % Limb::BITS; @@ -48,39 +59,27 @@ impl BoxedUint { limbs[i] = self.limbs[i - shift_num]; } - let (new_lower, _carry) = (Self { limbs }).shl_limb(rem); - new_lower - } + if rem == 0 { + return Some(Self { limbs }); + } - /// Computes `self << shift` where `0 <= shift < Limb::BITS`, - /// returning the result and the carry. - #[inline(always)] - pub(crate) fn shl_limb(&self, shift: u32) -> (Self, Limb) { - let nlimbs = self.nlimbs(); - let mut limbs = vec![Limb::ZERO; nlimbs].into_boxed_slice(); + let mut carry = Limb::ZERO; - let nz = CtChoice::from_u32_nonzero(shift); - let lshift = shift; - let rshift = nz.if_true_u32(Limb::BITS - shift); - let carry = nz.if_true_word(self.limbs[nlimbs - 1].0.wrapping_shr(Word::BITS - shift)); - - let mut i = nlimbs - 1; - while i > 0 { - let mut limb = self.limbs[i].0 << lshift; - let hi = self.limbs[i - 1].0 >> rshift; - limb |= nz.if_true_word(hi); - limbs[i] = Limb(limb); - i -= 1 + while i < nlimbs { + let shifted = limbs[i].shl(rem); + let new_carry = limbs[i].shr(Limb::BITS - rem); + limbs[i] = shifted.bitor(carry); + carry = new_carry; + i += 1; } - limbs[0] = Limb(self.limbs[0].0 << lshift); - (Self { limbs }, Limb(carry)) + Some(Self { limbs }) } /// Computes `self >> 1` in constant-time. pub(crate) fn shl1(&self) -> Self { // TODO(tarcieri): optimized implementation - self.shl_vartime(1) + self.shl_vartime(1).expect("shift within range") } /// Computes `self >> 1` in-place in constant-time. @@ -94,7 +93,7 @@ impl Shl for BoxedUint { type Output = BoxedUint; fn shl(self, shift: u32) -> BoxedUint { - Self::shl(&self, shift) + <&BoxedUint as Shl>::shl(&self, shift) } } @@ -102,7 +101,9 @@ impl Shl for &BoxedUint { type Output = BoxedUint; fn shl(self, shift: u32) -> BoxedUint { - self.shl(shift) + let (result, overflow) = self.shl(shift); + assert!(!bool::from(overflow), "attempt to shift left with overflow"); + result } } @@ -121,11 +122,11 @@ mod tests { fn shl_vartime() { let one = BoxedUint::one_with_precision(128); - assert_eq!(BoxedUint::from(2u8), one.shl_vartime(1)); - assert_eq!(BoxedUint::from(4u8), one.shl_vartime(2)); + assert_eq!(BoxedUint::from(2u8), one.shl_vartime(1).unwrap()); + assert_eq!(BoxedUint::from(4u8), one.shl_vartime(2).unwrap()); assert_eq!( BoxedUint::from(0x80000000000000000u128), - one.shl_vartime(67) + one.shl_vartime(67).unwrap() ); } } diff --git a/src/uint/boxed/shr.rs b/src/uint/boxed/shr.rs index ba5a3487..a2f5bc61 100644 --- a/src/uint/boxed/shr.rs +++ b/src/uint/boxed/shr.rs @@ -5,64 +5,75 @@ use core::ops::{Shr, ShrAssign}; use subtle::{Choice, ConstantTimeLess}; impl BoxedUint { - /// Computes `self << shift`. - /// Returns zero if `shift >= Self::BITS`. - pub fn shr(&self, shift: u32) -> Self { + /// Computes `self >> shift`. + /// Returns `None` if `shift >= Self::BITS`. + pub fn shr(&self, shift: u32) -> (Self, Choice) { + // `floor(log2(bits_precision - 1))` is the number of bits in the representation of `shift` + // (which lies in range `0 <= shift < bits_precision`). + let shift_bits = u32::BITS - (self.bits_precision() - 1).leading_zeros(); let overflow = !shift.ct_lt(&self.bits_precision()); let shift = shift % self.bits_precision(); - let log2_bits = u32::BITS - self.bits_precision().leading_zeros(); let mut result = self.clone(); - for i in 0..log2_bits { + for i in 0..shift_bits { let bit = Choice::from(((shift >> i) & 1) as u8); - result = Self::conditional_select(&result, &result.shr_vartime(1 << i), bit); + result = Self::conditional_select( + &result, + // Will not overflow by construction + &result.shr_vartime(1 << i).expect("shift within range"), + bit, + ); } - Self::conditional_select( - &result, - &Self::zero_with_precision(self.bits_precision()), + ( + Self::conditional_select( + &result, + &Self::zero_with_precision(self.bits_precision()), + overflow, + ), overflow, ) } /// Computes `self >> shift`. + /// Returns `None` if `shift >= Self::BITS`. /// /// NOTE: this operation is variable time with respect to `shift` *ONLY*. /// /// When used with a fixed `shift`, this function is constant-time with respect to `self`. #[inline(always)] - pub fn shr_vartime(&self, shift: u32) -> Self { + pub fn shr_vartime(&self, shift: u32) -> Option { + if shift >= self.bits_precision() { + return None; + } + let nlimbs = self.nlimbs(); - let full_shifts = (shift / Limb::BITS) as usize; - let small_shift = shift & (Limb::BITS - 1); let mut limbs = vec![Limb::ZERO; nlimbs].into_boxed_slice(); - if shift > self.bits_precision() { - return Self { limbs }; - } + let shift_num = (shift / Limb::BITS) as usize; + let rem = shift % Limb::BITS; - let n = nlimbs - full_shifts; let mut i = 0; + while i < nlimbs - shift_num { + limbs[i] = self.limbs[i + shift_num]; + i += 1; + } + + if rem == 0 { + return Some(Self { limbs }); + } + + let mut carry = Limb::ZERO; - if small_shift == 0 { - while i < n { - limbs[i] = Limb(self.limbs[i + full_shifts].0); - i += 1; - } - } else { - while i < n { - let mut lo = self.limbs[i + full_shifts].0 >> small_shift; - - if i < (nlimbs - 1) - full_shifts { - lo |= self.limbs[i + full_shifts + 1].0 << (Limb::BITS - small_shift); - } - - limbs[i] = Limb(lo); - i += 1; - } + while i > 0 { + i -= 1; + let shifted = limbs[i].shr(rem); + let new_carry = limbs[i].shl(Limb::BITS - rem); + limbs[i] = shifted.bitor(carry); + carry = new_carry; } - Self { limbs } + Some(Self { limbs }) } /// Computes `self >> 1` in constant-time, returning a true [`Choice`] if the overflowing bit @@ -95,7 +106,7 @@ impl Shr for BoxedUint { type Output = BoxedUint; fn shr(self, shift: u32) -> BoxedUint { - Self::shr(&self, shift) + <&BoxedUint as Shr>::shr(&self, shift) } } @@ -103,7 +114,12 @@ impl Shr for &BoxedUint { type Output = BoxedUint; fn shr(self, shift: u32) -> BoxedUint { - self.shr(shift) + let (result, overflow) = self.shr(shift); + assert!( + !bool::from(overflow), + "attempt to shift right with overflow" + ); + result } } @@ -129,9 +145,9 @@ mod tests { #[test] fn shr_vartime() { let n = BoxedUint::from(0x80000000000000000u128); - assert_eq!(BoxedUint::zero(), n.shr_vartime(68)); - assert_eq!(BoxedUint::one(), n.shr_vartime(67)); - assert_eq!(BoxedUint::from(2u8), n.shr_vartime(66)); - assert_eq!(BoxedUint::from(4u8), n.shr_vartime(65)); + assert_eq!(BoxedUint::zero(), n.shr_vartime(68).unwrap()); + assert_eq!(BoxedUint::one(), n.shr_vartime(67).unwrap()); + assert_eq!(BoxedUint::from(2u8), n.shr_vartime(66).unwrap()); + assert_eq!(BoxedUint::from(4u8), n.shr_vartime(65).unwrap()); } } From 36175baa2546c27271cbaa9b13fec234c7c5f1a6 Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Sat, 9 Dec 2023 13:09:20 -0800 Subject: [PATCH 06/15] Add shift benchmarks for BoxedUint --- Cargo.toml | 5 +++++ benches/boxed_uint.rs | 48 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+) create mode 100644 benches/boxed_uint.rs diff --git a/Cargo.toml b/Cargo.toml index 89dd4e40..91cb0618 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -56,6 +56,11 @@ name = "boxed_residue" harness = false required-features = ["alloc"] +[[bench]] +name = "boxed_uint" +harness = false +required-features = ["alloc"] + [[bench]] name = "dyn_residue" harness = false diff --git a/benches/boxed_uint.rs b/benches/boxed_uint.rs new file mode 100644 index 00000000..b34eaa33 --- /dev/null +++ b/benches/boxed_uint.rs @@ -0,0 +1,48 @@ +use criterion::{black_box, criterion_group, criterion_main, BatchSize, Criterion}; +use crypto_bigint::BoxedUint; +use rand_core::OsRng; + +/// Size of `BoxedUint` to use in benchmark. +const UINT_BITS: u32 = 4096; + +fn bench_shifts(c: &mut Criterion) { + let mut group = c.benchmark_group("bit shifts"); + + group.bench_function("shl_vartime", |b| { + b.iter_batched( + || BoxedUint::random(&mut OsRng, UINT_BITS), + |x| black_box(x.shl_vartime(UINT_BITS / 2 + 10)), + BatchSize::SmallInput, + ) + }); + + group.bench_function("shl", |b| { + b.iter_batched( + || BoxedUint::random(&mut OsRng, UINT_BITS), + |x| x.shl(UINT_BITS / 2 + 10), + BatchSize::SmallInput, + ) + }); + + group.bench_function("shr_vartime", |b| { + b.iter_batched( + || BoxedUint::random(&mut OsRng, UINT_BITS), + |x| black_box(x.shr_vartime(UINT_BITS / 2 + 10)), + BatchSize::SmallInput, + ) + }); + + group.bench_function("shr", |b| { + b.iter_batched( + || BoxedUint::random(&mut OsRng, UINT_BITS), + |x| x.shr(UINT_BITS / 2 + 10), + BatchSize::SmallInput, + ) + }); + + group.finish(); +} + +criterion_group!(benches, bench_shifts); + +criterion_main!(benches); From a17ef3b74492a61cc696ac143d8ff5e45d883f38 Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Sat, 9 Dec 2023 13:29:49 -0800 Subject: [PATCH 07/15] Improve BoxedUint::shl/shr performance by performing operations inplace when possible --- src/uint/boxed.rs | 5 +++ src/uint/boxed/shl.rs | 65 +++++++++++++++++++++---------- src/uint/boxed/shr.rs | 64 ++++++++++++++++++++----------- tests/boxed_uint_proptests.rs | 72 +++++++++++++++++++++++++++++++++++ 4 files changed, 162 insertions(+), 44 deletions(-) diff --git a/src/uint/boxed.rs b/src/uint/boxed.rs index 25cdce1d..f797204b 100644 --- a/src/uint/boxed.rs +++ b/src/uint/boxed.rs @@ -253,6 +253,11 @@ impl BoxedUint { limbs.into() } + + /// Set the value of `self` to zero in-place. + pub(crate) fn set_to_zero(&mut self) { + self.limbs.as_mut().fill(Limb::ZERO) + } } impl NonZero { diff --git a/src/uint/boxed/shl.rs b/src/uint/boxed/shl.rs index ee8b4b86..f91ce3e9 100644 --- a/src/uint/boxed/shl.rs +++ b/src/uint/boxed/shl.rs @@ -14,15 +14,16 @@ impl BoxedUint { let overflow = !shift.ct_lt(&self.bits_precision()); let shift = shift % self.bits_precision(); let mut result = self.clone(); + let mut temp = self.clone(); for i in 0..shift_bits { let bit = Choice::from(((shift >> i) & 1) as u8); - result = Self::conditional_select( - &result, - // Will not overflow by construction - &result.shl_vartime(1 << i).expect("shift within range"), - bit, - ); + temp.set_to_zero(); + // Will not overflow by construction + result + .shl_vartime_into(&mut temp, 1 << i) + .expect("shift within range"); + result.conditional_assign(&temp, bit); } ( @@ -35,45 +36,55 @@ impl BoxedUint { ) } - /// Computes `self << shift`. + /// Computes `self << shift` and writes the result into `dest`. /// Returns `None` if `shift >= Self::BITS`. /// + /// WARNING: for performance reasons, `dest` is assumed to be pre-zeroized. + /// /// NOTE: this operation is variable time with respect to `shift` *ONLY*. /// /// When used with a fixed `shift`, this function is constant-time with respect to `self`. #[inline(always)] - pub fn shl_vartime(&self, shift: u32) -> Option { + fn shl_vartime_into(&self, dest: &mut Self, shift: u32) -> Option<()> { if shift >= self.bits_precision() { return None; } let nlimbs = self.nlimbs(); - let mut limbs = vec![Limb::ZERO; nlimbs].into_boxed_slice(); - let shift_num = (shift / Limb::BITS) as usize; let rem = shift % Limb::BITS; - let mut i = nlimbs; - while i > shift_num { - i -= 1; - limbs[i] = self.limbs[i - shift_num]; + for i in shift_num..nlimbs { + dest.limbs[i] = self.limbs[i - shift_num]; } if rem == 0 { - return Some(Self { limbs }); + return Some(()); } let mut carry = Limb::ZERO; - while i < nlimbs { - let shifted = limbs[i].shl(rem); - let new_carry = limbs[i].shr(Limb::BITS - rem); - limbs[i] = shifted.bitor(carry); + for i in shift_num..nlimbs { + let shifted = dest.limbs[i].shl(rem); + let new_carry = dest.limbs[i].shr(Limb::BITS - rem); + dest.limbs[i] = shifted.bitor(carry); carry = new_carry; - i += 1; } - Some(Self { limbs }) + Some(()) + } + + /// Computes `self << shift`. + /// Returns `None` if `shift >= Self::BITS`. + /// + /// NOTE: this operation is variable time with respect to `shift` *ONLY*. + /// + /// When used with a fixed `shift`, this function is constant-time with respect to `self`. + #[inline(always)] + pub fn shl_vartime(&self, shift: u32) -> Option { + let mut result = Self::zero_with_precision(self.bits_precision()); + let success = self.shl_vartime_into(&mut result, shift); + success.map(|_| result) } /// Computes `self >> 1` in constant-time. @@ -118,6 +129,18 @@ impl ShlAssign for BoxedUint { mod tests { use super::BoxedUint; + #[test] + fn shl() { + let one = BoxedUint::one_with_precision(128); + + assert_eq!(BoxedUint::from(2u8), one.shl(1).0); + assert_eq!(BoxedUint::from(4u8), one.shl(2).0); + assert_eq!( + BoxedUint::from(0x80000000000000000u128), + one.shl_vartime(67).unwrap() + ); + } + #[test] fn shl_vartime() { let one = BoxedUint::one_with_precision(128); diff --git a/src/uint/boxed/shr.rs b/src/uint/boxed/shr.rs index a2f5bc61..006c0fea 100644 --- a/src/uint/boxed/shr.rs +++ b/src/uint/boxed/shr.rs @@ -14,15 +14,16 @@ impl BoxedUint { let overflow = !shift.ct_lt(&self.bits_precision()); let shift = shift % self.bits_precision(); let mut result = self.clone(); + let mut temp = self.clone(); for i in 0..shift_bits { let bit = Choice::from(((shift >> i) & 1) as u8); - result = Self::conditional_select( - &result, - // Will not overflow by construction - &result.shr_vartime(1 << i).expect("shift within range"), - bit, - ); + temp.set_to_zero(); + // Will not overflow by construction + result + .shr_vartime_into(&mut temp, 1 << i) + .expect("shift within range"); + result.conditional_assign(&temp, bit); } ( @@ -38,42 +39,50 @@ impl BoxedUint { /// Computes `self >> shift`. /// Returns `None` if `shift >= Self::BITS`. /// + /// WARNING: for performance reasons, `dest` is assumed to be pre-zeroized. + /// /// NOTE: this operation is variable time with respect to `shift` *ONLY*. /// /// When used with a fixed `shift`, this function is constant-time with respect to `self`. #[inline(always)] - pub fn shr_vartime(&self, shift: u32) -> Option { + fn shr_vartime_into(&self, dest: &mut Self, shift: u32) -> Option<()> { if shift >= self.bits_precision() { return None; } let nlimbs = self.nlimbs(); - let mut limbs = vec![Limb::ZERO; nlimbs].into_boxed_slice(); - let shift_num = (shift / Limb::BITS) as usize; let rem = shift % Limb::BITS; - let mut i = 0; - while i < nlimbs - shift_num { - limbs[i] = self.limbs[i + shift_num]; - i += 1; + for i in 0..nlimbs - shift_num { + dest.limbs[i] = self.limbs[i + shift_num]; } if rem == 0 { - return Some(Self { limbs }); + return Some(()); } - let mut carry = Limb::ZERO; - - while i > 0 { - i -= 1; - let shifted = limbs[i].shr(rem); - let new_carry = limbs[i].shl(Limb::BITS - rem); - limbs[i] = shifted.bitor(carry); - carry = new_carry; + for i in 0..nlimbs - shift_num - 1 { + let shifted = dest.limbs[i].shr(rem); + let carry = dest.limbs[i + 1].shl(Limb::BITS - rem); + dest.limbs[i] = shifted.bitor(carry); } + dest.limbs[nlimbs - shift_num - 1] = dest.limbs[nlimbs - shift_num - 1].shr(rem); - Some(Self { limbs }) + Some(()) + } + + /// Computes `self >> shift`. + /// Returns `None` if `shift >= Self::BITS`. + /// + /// NOTE: this operation is variable time with respect to `shift` *ONLY*. + /// + /// When used with a fixed `shift`, this function is constant-time with respect to `self`. + #[inline(always)] + pub fn shr_vartime(&self, shift: u32) -> Option { + let mut result = Self::zero_with_precision(self.bits_precision()); + let success = self.shr_vartime_into(&mut result, shift); + success.map(|_| result) } /// Computes `self >> 1` in constant-time, returning a true [`Choice`] if the overflowing bit @@ -142,6 +151,15 @@ mod tests { assert_eq!(n, n_shr1); } + #[test] + fn shr() { + let n = BoxedUint::from(0x80000000000000000u128); + assert_eq!(BoxedUint::zero(), n.shr(68).0); + assert_eq!(BoxedUint::one(), n.shr(67).0); + assert_eq!(BoxedUint::from(2u8), n.shr(66).0); + assert_eq!(BoxedUint::from(4u8), n.shr(65).0); + } + #[test] fn shr_vartime() { let n = BoxedUint::from(0x80000000000000000u128); diff --git a/tests/boxed_uint_proptests.rs b/tests/boxed_uint_proptests.rs index 4fcb99d6..424d4b68 100644 --- a/tests/boxed_uint_proptests.rs +++ b/tests/boxed_uint_proptests.rs @@ -5,6 +5,7 @@ use core::cmp::Ordering; use crypto_bigint::{BoxedUint, CheckedAdd, Limb, NonZero}; use num_bigint::{BigUint, ModInverse}; +use num_traits::identities::One; use proptest::prelude::*; fn to_biguint(uint: &BoxedUint) -> BigUint { @@ -212,4 +213,75 @@ proptest! { prop_assert_eq!(expected, to_biguint(&actual)); } } + + #[test] + fn shl(a in uint(), shift in any::()) { + let a_bi = to_biguint(&a); + + // Add a 50% probability of overflow. + let shift = u32::from(shift) % (a.bits_precision() * 2); + + let expected = to_uint((a_bi << shift as usize) & ((BigUint::one() << a.bits_precision() as usize) - BigUint::one())); + let (actual, overflow) = a.shl(shift); + + assert_eq!(expected, actual); + if shift >= a.bits_precision() { + assert_eq!(actual, BoxedUint::zero()); + assert!(bool::from(overflow)); + } + } + + #[test] + fn shl_vartime(a in uint(), shift in any::()) { + let a_bi = to_biguint(&a); + + // Add a 50% probability of overflow. + let shift = u32::from(shift) % (a.bits_precision() * 2); + + let expected = to_uint((a_bi << shift as usize) & ((BigUint::one() << a.bits_precision() as usize) - BigUint::one())); + let actual = a.shl_vartime(shift); + + if shift >= a.bits_precision() { + assert!(actual.is_none()); + } + else { + assert_eq!(expected, actual.unwrap()); + } + } + + #[test] + fn shr(a in uint(), shift in any::()) { + let a_bi = to_biguint(&a); + + // Add a 50% probability of overflow. + let shift = u32::from(shift) % (a.bits_precision() * 2); + + let expected = to_uint(a_bi >> shift as usize); + let (actual, overflow) = a.shr(shift); + + assert_eq!(expected, actual); + if shift >= a.bits_precision() { + assert_eq!(actual, BoxedUint::zero()); + assert!(bool::from(overflow)); + } + } + + + #[test] + fn shr_vartime(a in uint(), shift in any::()) { + let a_bi = to_biguint(&a); + + // Add a 50% probability of overflow. + let shift = u32::from(shift) % (a.bits_precision() * 2); + + let expected = to_uint(a_bi >> shift as usize); + let actual = a.shr_vartime(shift); + + if shift >= a.bits_precision() { + assert!(actual.is_none()); + } + else { + assert_eq!(expected, actual.unwrap()); + } + } } From 972f54609fdd3fe9812419afe5717cc56354e1c3 Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Sat, 9 Dec 2023 13:02:33 -0800 Subject: [PATCH 08/15] Add a specialized BoxedUint::shl1 implementation --- src/uint/boxed/shl.rs | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/src/uint/boxed/shl.rs b/src/uint/boxed/shl.rs index f91ce3e9..909a883c 100644 --- a/src/uint/boxed/shl.rs +++ b/src/uint/boxed/shl.rs @@ -89,14 +89,21 @@ impl BoxedUint { /// Computes `self >> 1` in constant-time. pub(crate) fn shl1(&self) -> Self { - // TODO(tarcieri): optimized implementation - self.shl_vartime(1).expect("shift within range") + let mut ret = self.clone(); + ret.shl1_assign(); + ret } /// Computes `self >> 1` in-place in constant-time. pub(crate) fn shl1_assign(&mut self) { - // TODO(tarcieri): optimized implementation - *self = self.shl1(); + let mut carry = self.limbs[0].0 >> Limb::HI_BIT; + self.limbs[0].shl_assign(1); + for i in 1..self.limbs.len() { + let new_carry = self.limbs[i].0 >> Limb::HI_BIT; + self.limbs[i].shl_assign(1); + self.limbs[i].0 |= carry; + carry = new_carry + } } } @@ -129,6 +136,14 @@ impl ShlAssign for BoxedUint { mod tests { use super::BoxedUint; + #[test] + fn shl1_assign() { + let mut n = BoxedUint::from(0x3c442b21f19185fe433f0a65af902b8fu128); + let n_shl1 = BoxedUint::from(0x78885643e3230bfc867e14cb5f20571eu128); + n.shl1_assign(); + assert_eq!(n, n_shl1); + } + #[test] fn shl() { let one = BoxedUint::one_with_precision(128); From 17145af7152e1deeec50028d319fbea7f15f1acf Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Sat, 9 Dec 2023 13:44:46 -0800 Subject: [PATCH 09/15] Iterate in forward direction to help with cache --- src/uint/shl.rs | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/uint/shl.rs b/src/uint/shl.rs index 5a59d257..632fe4e7 100644 --- a/src/uint/shl.rs +++ b/src/uint/shl.rs @@ -43,10 +43,10 @@ impl Uint { let shift_num = (shift / Limb::BITS) as usize; let rem = shift % Limb::BITS; - let mut i = LIMBS; - while i > shift_num { - i -= 1; + let mut i = shift_num; + while i < LIMBS { limbs[i] = self.limbs[i - shift_num]; + i += 1; } if rem == 0 { @@ -55,6 +55,7 @@ impl Uint { let mut carry = Limb::ZERO; + let mut i = shift_num; while i < LIMBS { let shifted = limbs[i].shl(rem); let new_carry = limbs[i].shr(Limb::BITS - rem); @@ -104,15 +105,15 @@ impl Uint { let rshift = nz.if_true_u32(Limb::BITS - shift); let carry = nz.if_true_word(self.limbs[LIMBS - 1].0.wrapping_shr(Word::BITS - shift)); - let mut i = LIMBS - 1; - while i > 0 { + limbs[0] = Limb(self.limbs[0].0 << lshift); + let mut i = 1; + while i < LIMBS { let mut limb = self.limbs[i].0 << lshift; let hi = self.limbs[i - 1].0 >> rshift; limb |= nz.if_true_word(hi); limbs[i] = Limb(limb); - i -= 1 + i += 1 } - limbs[0] = Limb(self.limbs[0].0 << lshift); (Uint::::new(limbs), Limb(carry)) } From 1640e79f5c124f1d4b868810849ab5a0c60ee159 Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Mon, 11 Dec 2023 14:36:04 -0800 Subject: [PATCH 10/15] Rename `sh(r/l)1_with_overflow` to `*_with_carry` --- src/modular/div_by_2.rs | 2 +- src/uint/boxed/inv_mod.rs | 6 +++--- src/uint/boxed/shr.rs | 6 +++--- src/uint/inv_mod.rs | 6 +++--- src/uint/shl.rs | 8 ++++---- src/uint/shr.rs | 8 ++++---- 6 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/modular/div_by_2.rs b/src/modular/div_by_2.rs index 278f3dda..12d82ed7 100644 --- a/src/modular/div_by_2.rs +++ b/src/modular/div_by_2.rs @@ -18,7 +18,7 @@ pub(crate) fn div_by_2(a: &Uint, modulus: &Uint> 1` in constant-time, returning a true [`Choice`] if the overflowing bit - /// was set, and a false [`Choice::FALSE`] otherwise. - pub(crate) fn shr1_with_overflow(&self) -> (Self, Choice) { + /// Computes `self >> 1` in constant-time, returning a true [`Choice`] + /// if the least significant bit was set, and a false [`Choice::FALSE`] otherwise. + pub(crate) fn shr1_with_carry(&self) -> (Self, Choice) { let carry = self.limbs[0].0 & 1; (self.shr1(), Choice::from(carry as u8)) } diff --git a/src/uint/inv_mod.rs b/src/uint/inv_mod.rs index bc8beb82..236f0ac1 100644 --- a/src/uint/inv_mod.rs +++ b/src/uint/inv_mod.rs @@ -128,9 +128,9 @@ impl Uint { let (new_u, cyy) = new_u.conditional_wrapping_add(modulus, cy); debug_assert!(cy.is_true_vartime() == cyy.is_true_vartime()); - let (new_a, overflow) = a.shr1_with_overflow(); - debug_assert!(modulus_is_odd.not().or(overflow.not()).is_true_vartime()); - let (new_u, cy) = new_u.shr1_with_overflow(); + let (new_a, carry) = a.shr1_with_carry(); + debug_assert!(modulus_is_odd.not().or(carry.not()).is_true_vartime()); + let (new_u, cy) = new_u.shr1_with_carry(); let (new_u, cy) = new_u.conditional_wrapping_add(&m1hp, cy); debug_assert!(modulus_is_odd.not().or(cy.not()).is_true_vartime()); diff --git a/src/uint/shl.rs b/src/uint/shl.rs index 632fe4e7..2b885d8d 100644 --- a/src/uint/shl.rs +++ b/src/uint/shl.rs @@ -118,10 +118,10 @@ impl Uint { (Uint::::new(limbs), Limb(carry)) } - /// Computes `self << 1` in constant-time, returning [`CtChoice::TRUE`] if the overflowing bit - /// was set, and [`CtChoice::FALSE`] otherwise. + /// Computes `self << 1` in constant-time, returning [`CtChoice::TRUE`] + /// if the most significant bit was set, and [`CtChoice::FALSE`] otherwise. #[inline(always)] - pub(crate) const fn shl1_with_overflow(&self) -> (Self, CtChoice) { + pub(crate) const fn shl1_with_carry(&self) -> (Self, CtChoice) { let mut ret = Self::ZERO; let mut i = 0; let mut carry = Limb::ZERO; @@ -138,7 +138,7 @@ impl Uint { /// Computes `self << 1` in constant-time. pub(crate) const fn shl1(&self) -> Self { // TODO(tarcieri): optimized implementation - self.shl1_with_overflow().0 + self.shl1_with_carry().0 } } diff --git a/src/uint/shr.rs b/src/uint/shr.rs index 3a4cd044..ffe984a9 100644 --- a/src/uint/shr.rs +++ b/src/uint/shr.rs @@ -93,10 +93,10 @@ impl Uint { } } - /// Computes `self >> 1` in constant-time, returning [`CtChoice::TRUE`] if the overflowing bit - /// was set, and [`CtChoice::FALSE`] otherwise. + /// Computes `self >> 1` in constant-time, returning [`CtChoice::TRUE`] + /// if the least significant bit was set, and [`CtChoice::FALSE`] otherwise. #[inline(always)] - pub(crate) const fn shr1_with_overflow(&self) -> (Self, CtChoice) { + pub(crate) const fn shr1_with_carry(&self) -> (Self, CtChoice) { let mut ret = Self::ZERO; let mut i = LIMBS; let mut carry = Limb::ZERO; @@ -113,7 +113,7 @@ impl Uint { /// Computes `self >> 1` in constant-time. pub(crate) const fn shr1(&self) -> Self { // TODO(tarcieri): optimized implementation - self.shr1_with_overflow().0 + self.shr1_with_carry().0 } } From 0ff9c43c1cb681854e23ebfa52c59d761f824756 Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Wed, 13 Dec 2023 11:10:11 -0800 Subject: [PATCH 11/15] Fix BoxedUint shift docstrings --- src/uint/boxed/shl.rs | 10 +++++----- src/uint/boxed/shr.rs | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/uint/boxed/shl.rs b/src/uint/boxed/shl.rs index 909a883c..65ad5948 100644 --- a/src/uint/boxed/shl.rs +++ b/src/uint/boxed/shl.rs @@ -6,7 +6,7 @@ use subtle::{Choice, ConstantTimeLess}; impl BoxedUint { /// Computes `self << shift`. - /// Returns `None` if `shift >= Self::BITS`. + /// Returns `None` if `shift >= self.bits_precision()`. pub fn shl(&self, shift: u32) -> (Self, Choice) { // `floor(log2(bits_precision - 1))` is the number of bits in the representation of `shift` // (which lies in range `0 <= shift < bits_precision`). @@ -37,7 +37,7 @@ impl BoxedUint { } /// Computes `self << shift` and writes the result into `dest`. - /// Returns `None` if `shift >= Self::BITS`. + /// Returns `None` if `shift >= self.bits_precision()`. /// /// WARNING: for performance reasons, `dest` is assumed to be pre-zeroized. /// @@ -75,7 +75,7 @@ impl BoxedUint { } /// Computes `self << shift`. - /// Returns `None` if `shift >= Self::BITS`. + /// Returns `None` if `shift >= self.bits_precision()`. /// /// NOTE: this operation is variable time with respect to `shift` *ONLY*. /// @@ -87,14 +87,14 @@ impl BoxedUint { success.map(|_| result) } - /// Computes `self >> 1` in constant-time. + /// Computes `self << 1` in constant-time. pub(crate) fn shl1(&self) -> Self { let mut ret = self.clone(); ret.shl1_assign(); ret } - /// Computes `self >> 1` in-place in constant-time. + /// Computes `self << 1` in-place in constant-time. pub(crate) fn shl1_assign(&mut self) { let mut carry = self.limbs[0].0 >> Limb::HI_BIT; self.limbs[0].shl_assign(1); diff --git a/src/uint/boxed/shr.rs b/src/uint/boxed/shr.rs index 56755e97..2be717c6 100644 --- a/src/uint/boxed/shr.rs +++ b/src/uint/boxed/shr.rs @@ -6,7 +6,7 @@ use subtle::{Choice, ConstantTimeLess}; impl BoxedUint { /// Computes `self >> shift`. - /// Returns `None` if `shift >= Self::BITS`. + /// Returns `None` if `shift >= self.bits_precision()`. pub fn shr(&self, shift: u32) -> (Self, Choice) { // `floor(log2(bits_precision - 1))` is the number of bits in the representation of `shift` // (which lies in range `0 <= shift < bits_precision`). @@ -37,7 +37,7 @@ impl BoxedUint { } /// Computes `self >> shift`. - /// Returns `None` if `shift >= Self::BITS`. + /// Returns `None` if `shift >= self.bits_precision()`. /// /// WARNING: for performance reasons, `dest` is assumed to be pre-zeroized. /// @@ -73,7 +73,7 @@ impl BoxedUint { } /// Computes `self >> shift`. - /// Returns `None` if `shift >= Self::BITS`. + /// Returns `None` if `shift >= self.bits_precision()`. /// /// NOTE: this operation is variable time with respect to `shift` *ONLY*. /// From 3558f9a843db513e5963bc2e1916907ebc93ccb6 Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Wed, 13 Dec 2023 11:10:30 -0800 Subject: [PATCH 12/15] Add BoxedUint::conditional_set_to_zero() --- src/uint/boxed.rs | 11 ++++++++++- src/uint/boxed/shl.rs | 11 +++-------- src/uint/boxed/shr.rs | 11 +++-------- 3 files changed, 16 insertions(+), 17 deletions(-) diff --git a/src/uint/boxed.rs b/src/uint/boxed.rs index f797204b..4f3633bd 100644 --- a/src/uint/boxed.rs +++ b/src/uint/boxed.rs @@ -27,7 +27,7 @@ mod rand; use crate::{Integer, Limb, NonZero, Uint, Word, Zero, U128, U64}; use alloc::{boxed::Box, vec, vec::Vec}; use core::{fmt, mem}; -use subtle::{Choice, ConstantTimeEq}; +use subtle::{Choice, ConditionallySelectable, ConstantTimeEq}; #[cfg(feature = "zeroize")] use zeroize::Zeroize; @@ -258,6 +258,15 @@ impl BoxedUint { pub(crate) fn set_to_zero(&mut self) { self.limbs.as_mut().fill(Limb::ZERO) } + + /// Set the value of `self` to zero in-place if `choice` is truthy. + pub(crate) fn conditional_set_to_zero(&mut self, choice: Choice) { + let nlimbs = self.nlimbs(); + let limbs = self.limbs.as_mut(); + for i in 0..nlimbs { + limbs[i] = Limb::conditional_select(&limbs[i], &Limb::ZERO, choice); + } + } } impl NonZero { diff --git a/src/uint/boxed/shl.rs b/src/uint/boxed/shl.rs index 65ad5948..83a1973a 100644 --- a/src/uint/boxed/shl.rs +++ b/src/uint/boxed/shl.rs @@ -26,14 +26,9 @@ impl BoxedUint { result.conditional_assign(&temp, bit); } - ( - Self::conditional_select( - &result, - &Self::zero_with_precision(self.bits_precision()), - overflow, - ), - overflow, - ) + result.conditional_set_to_zero(overflow); + + (result, overflow) } /// Computes `self << shift` and writes the result into `dest`. diff --git a/src/uint/boxed/shr.rs b/src/uint/boxed/shr.rs index 2be717c6..9a10e6f4 100644 --- a/src/uint/boxed/shr.rs +++ b/src/uint/boxed/shr.rs @@ -26,14 +26,9 @@ impl BoxedUint { result.conditional_assign(&temp, bit); } - ( - Self::conditional_select( - &result, - &Self::zero_with_precision(self.bits_precision()), - overflow, - ), - overflow, - ) + result.conditional_set_to_zero(overflow); + + (result, overflow) } /// Computes `self >> shift`. From 43af4e2d006cee6a1303a69bd8afc8e1a2f2d352 Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Wed, 13 Dec 2023 11:45:53 -0800 Subject: [PATCH 13/15] Document returning zero on overflow for BoxedUint shifts --- src/uint/boxed/shl.rs | 4 +++- src/uint/boxed/shr.rs | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/uint/boxed/shl.rs b/src/uint/boxed/shl.rs index 83a1973a..733dcf34 100644 --- a/src/uint/boxed/shl.rs +++ b/src/uint/boxed/shl.rs @@ -6,7 +6,9 @@ use subtle::{Choice, ConstantTimeLess}; impl BoxedUint { /// Computes `self << shift`. - /// Returns `None` if `shift >= self.bits_precision()`. + /// + /// Returns a zero and a falsy `Choice` if `shift >= self.bits_precision()`, + /// or the result and a truthy `Choice` otherwise. pub fn shl(&self, shift: u32) -> (Self, Choice) { // `floor(log2(bits_precision - 1))` is the number of bits in the representation of `shift` // (which lies in range `0 <= shift < bits_precision`). diff --git a/src/uint/boxed/shr.rs b/src/uint/boxed/shr.rs index 9a10e6f4..09ee3905 100644 --- a/src/uint/boxed/shr.rs +++ b/src/uint/boxed/shr.rs @@ -6,7 +6,9 @@ use subtle::{Choice, ConstantTimeLess}; impl BoxedUint { /// Computes `self >> shift`. - /// Returns `None` if `shift >= self.bits_precision()`. + /// + /// Returns a zero and a falsy `Choice` if `shift >= self.bits_precision()`, + /// or the result and a truthy `Choice` otherwise. pub fn shr(&self, shift: u32) -> (Self, Choice) { // `floor(log2(bits_precision - 1))` is the number of bits in the representation of `shift` // (which lies in range `0 <= shift < bits_precision`). From 18179e0fec20c9c3452fdb5f6639a71d44fa285f Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Wed, 13 Dec 2023 12:16:48 -0800 Subject: [PATCH 14/15] Remove an obsolete test --- src/uint/shr.rs | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/src/uint/shr.rs b/src/uint/shr.rs index ffe984a9..5bb8093b 100644 --- a/src/uint/shr.rs +++ b/src/uint/shr.rs @@ -195,19 +195,4 @@ mod tests { ((U128::ZERO, U128::ZERO), CtChoice::TRUE) ); } - - /* - #[test] - fn shr_limb() { - let x = U128::from_be_hex("00112233445566778899aabbccddeeff"); - assert_eq!(x.shr_limb(0), (x, Limb::ZERO)); - assert_eq!( - x.shr_limb(8), - ( - U128::from_be_hex("0000112233445566778899aabbccddee"), - Limb(0xff << (Limb::BITS - 8)) - ) - ); - } - */ } From 6ce585a68e957b45f9a902d14d2ad0db027d2d8f Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Wed, 13 Dec 2023 12:22:12 -0800 Subject: [PATCH 15/15] Typo fix in shift docstrings --- src/uint/boxed/shl.rs | 4 ++-- src/uint/boxed/shr.rs | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/uint/boxed/shl.rs b/src/uint/boxed/shl.rs index 733dcf34..8daee882 100644 --- a/src/uint/boxed/shl.rs +++ b/src/uint/boxed/shl.rs @@ -7,8 +7,8 @@ use subtle::{Choice, ConstantTimeLess}; impl BoxedUint { /// Computes `self << shift`. /// - /// Returns a zero and a falsy `Choice` if `shift >= self.bits_precision()`, - /// or the result and a truthy `Choice` otherwise. + /// Returns a zero and a truthy `Choice` if `shift >= self.bits_precision()`, + /// or the result and a falsy `Choice` otherwise. pub fn shl(&self, shift: u32) -> (Self, Choice) { // `floor(log2(bits_precision - 1))` is the number of bits in the representation of `shift` // (which lies in range `0 <= shift < bits_precision`). diff --git a/src/uint/boxed/shr.rs b/src/uint/boxed/shr.rs index 09ee3905..ac1acee5 100644 --- a/src/uint/boxed/shr.rs +++ b/src/uint/boxed/shr.rs @@ -7,8 +7,8 @@ use subtle::{Choice, ConstantTimeLess}; impl BoxedUint { /// Computes `self >> shift`. /// - /// Returns a zero and a falsy `Choice` if `shift >= self.bits_precision()`, - /// or the result and a truthy `Choice` otherwise. + /// Returns a zero and a truthy `Choice` if `shift >= self.bits_precision()`, + /// or the result and a falsy `Choice` otherwise. pub fn shr(&self, shift: u32) -> (Self, Choice) { // `floor(log2(bits_precision - 1))` is the number of bits in the representation of `shift` // (which lies in range `0 <= shift < bits_precision`).