Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Constant-time square root and division #277

Closed
wants to merge 8 commits into from
Closed
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 69 additions & 38 deletions src/uint/sqrt.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,36 @@
//! [`Uint`] square root operations.

use super::Uint;
use crate::{Limb, Word};
use subtle::{ConstantTimeEq, CtOption};

impl<const LIMBS: usize> Uint<LIMBS> {
/// See [`Self::sqrt_vartime`].
#[deprecated(
since = "0.5.3",
note = "This functionality will be moved to `sqrt_vartime` in a future release."
)]
pub const fn sqrt(&self) -> Self {
self.sqrt_vartime()
/// Computes √(`self`) in constant time.
/// Based on Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13
///
/// Callers can check if `self` is a square by squaring the result
pub fn sqrt(&self) -> Self {
let max_bits = (self.bits() + 1) >> 1;
let cap = Self::ONE.shl(max_bits);
let mut guess = cap; // ≥ √(`self`)
let mut xn = {
let q = self.wrapping_div(&guess);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think wrapping_div() is currently constant-time (although it could be made so).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My mistake, I didn't see that wrapping_div is only constant-time for a fixed rhs. Actually, I think there might be a documentation issue here—I can't find anywhere in the public documentation that says this, only the documentation for ct_div_rem (which doesn't appear in the public docs since it's pub(crate)).

What do you think would be the better approach here: making a new function that's like ct_div_rem but constant-time with respect to both inputs, or modifying ct_div_rem to have this stronger constant-time guarantee and moving the "constant-time only for fixed rhs" behavior to a new function?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that the name is misleading - it should have had a vartime suffix (and the whole div.rs is kind of a mess in terms of naming - see #268). So the proper way to proceed I think would be to rename the current one to _vartime, and implement a constant-time one in its place.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It appears this is a blocker on merging this PR. I guess we can go ahead and flip over to the v0.6 series per #268 and try to land this PR afterward.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Checking on this: wrapping_div calls const_div_rem, which claims:

    /// This function is constant-time with respect to both `self` and `rhs`.
    pub(crate) const fn const_div_rem(&self, rhs: &Self) -> (Self, Self, CtChoice) {

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which commit are you looking at? The current master has no const_div_rem(), and Uint::wrapping_div() uses ct_div_rem(), which is not constant-time in rhs.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR changes wrapping_div to call const_div_rem, which is also added by this PR

let t = guess.wrapping_add(&q);
t.shr_vartime(1)
HastD marked this conversation as resolved.
Show resolved Hide resolved
};

// Repeat enough times to guarantee result has stabilized.
// See Hast, "Note on computation of integer square roots" for a proof of this bound.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When your paper has a more permanent link (e.g. on arxiv), please make a PR referencing it here

for _ in 0..usize::BITS - Self::BITS.leading_zeros() {
guess = xn;
xn = {
let q = self.checked_div(&guess).unwrap_or(Self::ZERO);
let t = guess.wrapping_add(&q);
t.shr_vartime(1)
};
}

// at least one of `guess` and `xn` is now equal to √(`self`), so return the minimum
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So at this point guess == x_n, and xn == x_{n+1}, where n = floor(log2(Self::BITS)). But in the paper it says that it should be n = floor(log2(Self::BITS)) + 1 - am I missing something?

Self::ct_select(&guess, &xn, Uint::ct_gt(&guess, &xn))
}

/// Computes √(`self`)
Expand All @@ -27,21 +46,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {
let t = guess.wrapping_add(&q);
t.shr_vartime(1)
};

// If guess increased, the initial guess was low.
// Repeat until reverse course.
while Uint::ct_lt(&guess, &xn).is_true_vartime() {
// Sometimes an increase is too far, especially with large
// powers, and then takes a long time to walk back. The upper
// bound is based on bit size, so saturate on that.
let le = Limb::ct_le(Limb(xn.bits_vartime() as Word), Limb(max_bits as Word));
guess = Self::ct_select(&cap, &xn, le);
xn = {
let q = self.wrapping_div(&guess);
let t = guess.wrapping_add(&q);
t.shr_vartime(1)
};
}
// Note, xn <= guess at this point.

// Repeat while guess decreases.
while Uint::ct_gt(&guess, &xn).is_true_vartime() && xn.ct_is_nonzero().is_true_vartime() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need to use ct_gt() and ct_is_nonzero() here, those are constant-time

Expand All @@ -56,29 +61,26 @@ impl<const LIMBS: usize> Uint<LIMBS> {
Self::ct_select(&Self::ZERO, &guess, self.ct_is_nonzero())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly here, we don't need constant-timeness.

}

/// See [`Self::wrapping_sqrt_vartime`].
#[deprecated(
since = "0.5.3",
note = "This functionality will be moved to `wrapping_sqrt_vartime` in a future release."
)]
pub const fn wrapping_sqrt(&self) -> Self {
self.wrapping_sqrt_vartime()
/// Wrapped sqrt is just normal √(`self`)
/// There’s no way wrapping could ever happen.
/// This function exists so that all operations are accounted for in the wrapping operations.
pub fn wrapping_sqrt(&self) -> Self {
self.sqrt()
}

/// Wrapped sqrt is just normal √(`self`)
/// There’s no way wrapping could ever happen.
/// This function exists, so that all operations are accounted for in the wrapping operations.
/// This function exists so that all operations are accounted for in the wrapping operations.
pub const fn wrapping_sqrt_vartime(&self) -> Self {
self.sqrt_vartime()
}

/// See [`Self::checked_sqrt_vartime`].
#[deprecated(
since = "0.5.3",
note = "This functionality will be moved to `checked_sqrt_vartime` in a future release."
)]
/// Perform checked sqrt, returning a [`CtOption`] which `is_some`
/// only if the √(`self`)² == self
pub fn checked_sqrt(&self) -> CtOption<Self> {
self.checked_sqrt_vartime()
let r = self.sqrt();
let s = r.wrapping_mul(&r);
CtOption::new(r, ConstantTimeEq::ct_eq(self, &s))
}

/// Perform checked sqrt, returning a [`CtOption`] which `is_some`
Expand All @@ -103,13 +105,24 @@ mod tests {

#[test]
fn edge() {
assert_eq!(U256::ZERO.sqrt(), U256::ZERO);
assert_eq!(U256::ONE.sqrt(), U256::ONE);
let mut half = U256::ZERO;
for i in 0..half.limbs.len() / 2 {
half.limbs[i] = Limb::MAX;
}
assert_eq!(U256::MAX.sqrt(), half);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An idea for the edge case test: a number that actually needs the maximum amount of iterations to converge. According to my tests, the ones before 10000 are 80, 99, 4224, 4355, 4488, 4623, 4760, 4899; but please check independently. (Also an interesting mathematical question - is there some rule for their distribution)

}

#[test]
fn edge_vartime() {
assert_eq!(U256::ZERO.sqrt_vartime(), U256::ZERO);
assert_eq!(U256::ONE.sqrt_vartime(), U256::ONE);
let mut half = U256::ZERO;
for i in 0..half.limbs.len() / 2 {
half.limbs[i] = Limb::MAX;
}
assert_eq!(U256::MAX.sqrt_vartime(), half,);
assert_eq!(U256::MAX.sqrt_vartime(), half);
}

#[test]
Expand All @@ -131,13 +144,28 @@ mod tests {
for (a, e) in &tests {
let l = U256::from(*a);
let r = U256::from(*e);
assert_eq!(l.sqrt(), r);
assert_eq!(l.sqrt_vartime(), r);
assert_eq!(l.checked_sqrt().is_some().unwrap_u8(), 1u8);
assert_eq!(l.checked_sqrt_vartime().is_some().unwrap_u8(), 1u8);
}
}

#[test]
fn nonsquares() {
assert_eq!(U256::from(2u8).sqrt(), U256::from(1u8));
assert_eq!(U256::from(2u8).checked_sqrt().is_some().unwrap_u8(), 0);
assert_eq!(U256::from(3u8).sqrt(), U256::from(1u8));
assert_eq!(U256::from(3u8).checked_sqrt().is_some().unwrap_u8(), 0);
assert_eq!(U256::from(5u8).sqrt(), U256::from(2u8));
assert_eq!(U256::from(6u8).sqrt(), U256::from(2u8));
assert_eq!(U256::from(7u8).sqrt(), U256::from(2u8));
assert_eq!(U256::from(8u8).sqrt(), U256::from(2u8));
assert_eq!(U256::from(10u8).sqrt(), U256::from(3u8));
}

#[test]
fn nonsquares_vartime() {
assert_eq!(U256::from(2u8).sqrt_vartime(), U256::from(1u8));
assert_eq!(
U256::from(2u8).checked_sqrt_vartime().is_some().unwrap_u8(),
Expand All @@ -163,14 +191,17 @@ mod tests {
let t = rng.next_u32() as u64;
let s = U256::from(t);
let s2 = s.checked_mul(&s).unwrap();
assert_eq!(s2.sqrt(), s);
assert_eq!(s2.sqrt_vartime(), s);
assert_eq!(s2.checked_sqrt().is_some().unwrap_u8(), 1);
assert_eq!(s2.checked_sqrt_vartime().is_some().unwrap_u8(), 1);
}

for _ in 0..50 {
let s = U256::random(&mut rng);
let mut s2 = U512::ZERO;
s2.limbs[..s.limbs.len()].copy_from_slice(&s.limbs);
assert_eq!(s.square().sqrt(), s2);
assert_eq!(s.square().sqrt_vartime(), s2);
}
}
Expand Down