Skip to content

Commit

Permalink
Make inv_mod2k(_vartime) return a CtChoice (#416)
Browse files Browse the repository at this point in the history
* Make `inv_mod2k(_vartime)` return a CtChoice indicating if the inverse exists
* Make `inv_odd_mod()` return a falsy CtChoice if the given modulus is even
* Make `BoxedUint::inv_mod2k(_vartime)` return a Choice indicating if the inverse exists
* Make `BoxedUint::inv_odd_mod()` return a falsy CtChoice if the given modulus is even
  • Loading branch information
fjarri authored Dec 12, 2023
1 parent 90e472a commit 5ee582b
Show file tree
Hide file tree
Showing 7 changed files with 143 additions and 67 deletions.
6 changes: 6 additions & 0 deletions src/ct_choice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,12 @@ impl From<CtChoice> for bool {
}
}

impl PartialEq for CtChoice {
fn eq(&self, other: &Self) -> bool {
self.0 == other.0
}
}

#[cfg(test)]
mod tests {
use super::CtChoice;
Expand Down
11 changes: 4 additions & 7 deletions src/modular/boxed_residue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,9 @@ impl BoxedResidueParams {

/// Common functionality of `new` and `new_vartime`.
fn new_inner(modulus: BoxedUint, r: BoxedUint, r2: BoxedUint) -> CtOption<Self> {
let is_odd = modulus.is_odd();

// Since we are calculating the inverse modulo (Word::MAX+1),
// we can take the modulo right away and calculate the inverse of the first limb only.
let modulus_lo = BoxedUint::from(modulus.limbs.get(0).copied().unwrap_or_default());
let mod_neg_inv = Limb(Word::MIN.wrapping_sub(modulus_lo.inv_mod2k(Word::BITS).limbs[0].0));
// If the inverse exists, it means the modulus is odd.
let (inv_mod_limb, modulus_is_odd) = modulus.inv_mod2k(Word::BITS);
let mod_neg_inv = Limb(Word::MIN.wrapping_sub(inv_mod_limb.limbs[0].0));
let r3 = montgomery_reduction_boxed(&mut r2.square(), &modulus, mod_neg_inv);

let params = Self {
Expand All @@ -119,7 +116,7 @@ impl BoxedResidueParams {
mod_neg_inv,
};

CtOption::new(params, is_odd)
CtOption::new(params, modulus_is_odd)
}

/// Modulus value.
Expand Down
12 changes: 5 additions & 7 deletions src/modular/dyn_residue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use super::{
residue::{Residue, ResidueParams},
Retrieve,
};
use crate::{Integer, Limb, Uint, Word};
use crate::{Limb, Uint, Word};
use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};

/// Parameters to efficiently go to/from the Montgomery form for an odd modulus provided at runtime.
Expand All @@ -40,11 +40,9 @@ impl<const LIMBS: usize> DynResidueParams<LIMBS> {
let r = Uint::MAX.const_rem(modulus).0.wrapping_add(&Uint::ONE);
let r2 = Uint::const_rem_wide(r.square_wide(), modulus).0;

// Since we are calculating the inverse modulo (Word::MAX+1),
// we can take the modulo right away and calculate the inverse of the first limb only.
let modulus_lo = Uint::<1>::from_words([modulus.limbs[0].0]);
let mod_neg_inv =
Limb(Word::MIN.wrapping_sub(modulus_lo.inv_mod2k_vartime(Word::BITS).limbs[0].0));
// If the inverse does not exist, it means the modulus is odd.
let (inv_mod_limb, modulus_is_odd) = modulus.inv_mod2k_vartime(Word::BITS);
let mod_neg_inv = Limb(Word::MIN.wrapping_sub(inv_mod_limb.limbs[0].0));

let r3 = montgomery_reduction(&r2.square_wide(), modulus, mod_neg_inv);

Expand All @@ -56,7 +54,7 @@ impl<const LIMBS: usize> DynResidueParams<LIMBS> {
mod_neg_inv,
};

CtOption::new(params, modulus.is_odd())
CtOption::new(params, modulus_is_odd.into())
}

/// Returns the modulus which was used to initialize these parameters.
Expand Down
1 change: 1 addition & 0 deletions src/modular/residue/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ macro_rules! impl_modulus {
$crate::Word::MIN.wrapping_sub(
Self::MODULUS
.inv_mod2k_vartime($crate::Word::BITS)
.0
.as_limbs()[0]
.0,
),
Expand Down
52 changes: 28 additions & 24 deletions src/uint/boxed/inv_mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,14 @@ impl BoxedUint {
// Decompose `self` into RNS with moduli `2^k` and `s` and calculate the inverses.
// Using the fact that `(z^{-1} mod (m1 * m2)) mod m1 == z^{-1} mod m1`
let (a, a_is_some) = self.inv_odd_mod(&s);
let b = self.inv_mod2k(k);
// inverse modulo 2^k exists either if `k` is 0 or if `self` is odd.
let b_is_some = k.ct_eq(&0) | self.is_odd();
let (b, b_is_some) = self.inv_mod2k(k);

// Restore from RNS:
// self^{-1} = a mod s = b mod 2^k
// => self^{-1} = a + s * ((b - a) * s^(-1) mod 2^k)
// (essentially one step of the Garner's algorithm for recovery from RNS).

let m_odd_inv = s.inv_mod2k(k); // `s` is odd, so this always exists
let (m_odd_inv, _is_some) = s.inv_mod2k(k); // `s` is odd, so this always exists

// This part is mod 2^k
let mask = Self::one().shl(k).wrapping_sub(&Self::one());
Expand All @@ -39,14 +37,16 @@ impl BoxedUint {

/// Computes 1/`self` mod `2^k`.
///
/// Conditions: `self` < 2^k and `self` must be odd
pub(crate) fn inv_mod2k(&self, k: u32) -> Self {
// This is the same algorithm as in `inv_mod2k_vartime()`,
// but made constant-time w.r.t `k` as well.

/// If the inverse does not exist (`k > 0` and `self` is even),
/// returns `CtChoice::FALSE` as the second element of the tuple,
/// otherwise returns `CtChoice::TRUE`.
pub(crate) fn inv_mod2k(&self, k: u32) -> (Self, Choice) {
let mut x = Self::zero_with_precision(self.bits_precision()); // keeps `x` during iterations
let mut b = Self::one_with_precision(self.bits_precision()); // keeps `b_i` during iterations

// The inverse exists either if `k` is 0 or if `self` is odd.
let is_some = k.ct_eq(&0) | self.is_odd();

for i in 0..self.bits_precision() {
// Only iterations for i = 0..k need to change `x`,
// the rest are dummy ones performed for the sake of constant-timeness.
Expand All @@ -64,7 +64,7 @@ impl BoxedUint {
x.set_bit(i, x_i_choice);
}

x
(x, is_some)
}

/// Computes the multiplicative inverse of `self` mod `modulus`, where `modulus` is odd.
Expand All @@ -80,8 +80,8 @@ impl BoxedUint {
/// of `self` and `modulus`, respectively.
///
/// (the inversion speed will be proportional to `bits + modulus_bits`).
/// The second element of the tuple is the truthy value if an inverse exists,
/// otherwise it is a falsy value.
/// The second element of the tuple is the truthy value
/// if `modulus` is odd and an inverse exists, otherwise it is a falsy value.
///
/// **Note:** variable time in `bits` and `modulus_bits`.
///
Expand All @@ -90,7 +90,6 @@ impl BoxedUint {
debug_assert_eq!(self.bits_precision(), modulus.bits_precision());

let bits_precision = self.bits_precision();
debug_assert!(bool::from(modulus.is_odd()));

let mut a = self.clone();
let mut u = Self::one_with_precision(bits_precision);
Expand All @@ -100,13 +99,16 @@ impl BoxedUint {
// `bit_size` can be anything >= `self.bits()` + `modulus.bits()`, setting to the minimum.
let bit_size = bits + modulus_bits;

let mut m1hp = modulus.clone();
let (m1hp_new, carry) = m1hp.shr1_with_overflow();
debug_assert!(bool::from(carry));
m1hp = m1hp_new.wrapping_add(&Self::one_with_precision(bits_precision));
let m1hp = modulus
.shr1()
.wrapping_add(&Self::one_with_precision(bits_precision));

let modulus_is_odd = modulus.is_odd();

for _ in 0..bit_size {
debug_assert!(bool::from(b.is_odd()));
// A sanity check that `b` stays odd. Only matters if `modulus` was odd to begin with,
// otherwise this whole thing produces nonsense anyway.
debug_assert!(bool::from(!modulus_is_odd | b.is_odd()));

let self_odd = a.is_odd();

Expand All @@ -125,18 +127,18 @@ impl BoxedUint {
debug_assert!(bool::from(cy.ct_eq(&cyy)));

let (new_a, overflow) = a.shr1_with_overflow();
debug_assert!(!bool::from(overflow));
debug_assert!(bool::from(!modulus_is_odd | !overflow));
let (mut new_u, cy) = new_u.shr1_with_overflow();
let cy = new_u.conditional_adc_assign(&m1hp, cy);
debug_assert!(!bool::from(cy));
debug_assert!(bool::from(!modulus_is_odd | !cy));

a = new_a;
u = new_u;
v = new_v;
}

debug_assert!(bool::from(a.is_zero()));
(v, b.is_one())
debug_assert!(bool::from(!modulus_is_odd | a.is_zero()));
(v, b.is_one() & modulus_is_odd)
}
}

Expand All @@ -157,8 +159,9 @@ mod tests {
256,
)
.unwrap();
let a = v.inv_mod2k(256);
let (a, is_some) = v.inv_mod2k(256);
assert_eq!(e, a);
assert!(bool::from(is_some));

let v = BoxedUint::from_be_slice(
&hex!("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141"),
Expand All @@ -170,7 +173,8 @@ mod tests {
256,
)
.unwrap();
let a = v.inv_mod2k(256);
let (a, is_some) = v.inv_mod2k(256);
assert_eq!(e, a);
assert!(bool::from(is_some));
}
}
Loading

0 comments on commit 5ee582b

Please sign in to comment.