Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add convenience impls for common types #137

Merged
merged 6 commits into from
Jan 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 52 additions & 10 deletions src/alloc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,7 @@ impl AllocationMode {

/// Specifies how variables of type `Self` should be allocated in a
/// `ConstraintSystem`.
pub trait AllocVar<V, F: Field>
where
Self: Sized,
V: ?Sized,
{
pub trait AllocVar<V: ?Sized, F: Field>: Sized {
/// Allocates a new variable of type `Self` in the `ConstraintSystem` `cs`.
/// The mode of allocation is decided by `mode`.
fn new_variable<T: Borrow<V>>(
Expand Down Expand Up @@ -92,10 +88,56 @@ impl<I, F: Field, A: AllocVar<I, F>> AllocVar<[I], F> for Vec<A> {
) -> Result<Self, SynthesisError> {
let ns = cs.into();
let cs = ns.cs();
let mut vec = Vec::new();
for value in f()?.borrow().iter() {
vec.push(A::new_variable(cs.clone(), || Ok(value), mode)?);
}
Ok(vec)
f().and_then(|v| {
v.borrow()
.iter()
.map(|e| A::new_variable(cs.clone(), || Ok(e), mode))
.collect()
})
}
}

/// Dummy impl for `()`.
impl<F: Field> AllocVar<(), F> for () {
fn new_variable<T: Borrow<()>>(
_cs: impl Into<Namespace<F>>,
_f: impl FnOnce() -> Result<T, SynthesisError>,
_mode: AllocationMode,
) -> Result<Self, SynthesisError> {
Ok(())
}
}

/// This blanket implementation just allocates variables in `Self`
/// element by element.
impl<I, F: Field, A: AllocVar<I, F>, const N: usize> AllocVar<[I; N], F> for [A; N] {
fn new_variable<T: Borrow<[I; N]>>(
cs: impl Into<Namespace<F>>,
f: impl FnOnce() -> Result<T, SynthesisError>,
mode: AllocationMode,
) -> Result<Self, SynthesisError> {
let ns = cs.into();
let cs = ns.cs();
f().map(|v| {
let v = v.borrow();
core::array::from_fn(|i| A::new_variable(cs.clone(), || Ok(&v[i]), mode).unwrap())
})
}
}

/// This blanket implementation just allocates variables in `Self`
/// element by element.
impl<I, F: Field, A: AllocVar<I, F>, const N: usize> AllocVar<[I], F> for [A; N] {
fn new_variable<T: Borrow<[I]>>(
cs: impl Into<Namespace<F>>,
f: impl FnOnce() -> Result<T, SynthesisError>,
mode: AllocationMode,
) -> Result<Self, SynthesisError> {
let ns = cs.into();
let cs = ns.cs();
f().map(|v| {
let v = v.borrow();
core::array::from_fn(|i| A::new_variable(cs.clone(), || Ok(&v[i]), mode).unwrap())
})
}
}
6 changes: 3 additions & 3 deletions src/boolean/cmp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ impl<F: PrimeField> Boolean<F> {
let mut bits_iter = bits.iter().rev(); // Iterate in big-endian

// Runs of ones in r
let mut last_run = Boolean::constant(true);
let mut last_run = Boolean::TRUE;
let mut current_run = vec![];

let mut element_num_bits = 0;
Expand All @@ -57,12 +57,12 @@ impl<F: PrimeField> Boolean<F> {
}

if bits.len() > element_num_bits {
let mut or_result = Boolean::constant(false);
let mut or_result = Boolean::FALSE;
for should_be_zero in &bits[element_num_bits..] {
or_result |= should_be_zero;
let _ = bits_iter.next().unwrap();
}
or_result.enforce_equal(&Boolean::constant(false))?;
or_result.enforce_equal(&Boolean::FALSE)?;
}

for (b, a) in BitIteratorBE::without_leading_zeros(b).zip(bits_iter.by_ref()) {
Expand Down
2 changes: 1 addition & 1 deletion src/boolean/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ impl<F: Field> Boolean<F> {
/// let true_var = Boolean::<Fr>::TRUE;
/// let false_var = Boolean::<Fr>::FALSE;
///
/// true_var.enforce_equal(&Boolean::constant(true))?;
/// true_var.enforce_equal(&Boolean::TRUE)?;
/// false_var.enforce_equal(&Boolean::constant(false))?;
/// # Ok(())
/// # }
Expand Down
38 changes: 35 additions & 3 deletions src/cmp.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use ark_ff::Field;
use ark_ff::{Field, PrimeField};
use ark_relations::r1cs::SynthesisError;

use crate::{boolean::Boolean, R1CSVar};
use crate::{boolean::Boolean, eq::EqGadget, R1CSVar};

/// Specifies how to generate constraints for comparing two variables.
pub trait CmpGadget<F: Field>: R1CSVar<F> {
pub trait CmpGadget<F: Field>: R1CSVar<F> + EqGadget<F> {
fn is_gt(&self, other: &Self) -> Result<Boolean<F>, SynthesisError> {
other.is_lt(self)
}
Expand All @@ -19,3 +19,35 @@ pub trait CmpGadget<F: Field>: R1CSVar<F> {
other.is_ge(self)
}
}

/// Mimics the behavior of `std::cmp::PartialOrd` for `()`.
impl<F: Field> CmpGadget<F> for () {
fn is_gt(&self, _other: &Self) -> Result<Boolean<F>, SynthesisError> {
Ok(Boolean::FALSE)
}

fn is_ge(&self, _other: &Self) -> Result<Boolean<F>, SynthesisError> {
Ok(Boolean::TRUE)
}

fn is_lt(&self, _other: &Self) -> Result<Boolean<F>, SynthesisError> {
Ok(Boolean::FALSE)
}

fn is_le(&self, _other: &Self) -> Result<Boolean<F>, SynthesisError> {
Ok(Boolean::TRUE)
}
}

/// Mimics the lexicographic comparison behavior of `std::cmp::PartialOrd` for `[T]`.
impl<T: CmpGadget<F>, F: PrimeField> CmpGadget<F> for [T] {
fn is_ge(&self, other: &Self) -> Result<Boolean<F>, SynthesisError> {
let mut result = Boolean::TRUE;
let mut all_equal_so_far = Boolean::TRUE;
for (a, b) in self.iter().zip(other) {
all_equal_so_far &= a.is_eq(b)?;
result &= a.is_gt(b)? | &all_equal_so_far;
}
Ok(result)
}
}
54 changes: 54 additions & 0 deletions src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,60 @@ impl<'a, F: Field, T: 'a + ToBytesGadget<F>> ToBytesGadget<F> for &'a T {
}
}

impl<T: ToBytesGadget<F>, F: Field> ToBytesGadget<F> for [T] {
fn to_bytes_le(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
let mut bytes = Vec::new();
for elem in self {
let elem = elem.to_bytes_le()?;
bytes.extend_from_slice(&elem);
// Make sure that there's enough capacity to avoid reallocations.
bytes.reserve(elem.len() * (self.len() - 1));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we do this at each iteration of the loop in case the elements have different len?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This actually wasn't the smartest logic, but it kind of works out: if the elem.len() is less than or equal to that for previous elems, then don't change the capacity, resulting in a no-op. If elem.len() is larger than that for previous elems, it will reserve more space.

}
Ok(bytes)
}

fn to_non_unique_bytes_le(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
let mut bytes = Vec::new();
for elem in self {
let elem = elem.to_non_unique_bytes_le()?;
bytes.extend_from_slice(&elem);
// Make sure that there's enough capacity to avoid reallocations.
bytes.reserve(elem.len() * (self.len() - 1));
}
Ok(bytes)
}
}

impl<T: ToBytesGadget<F>, F: Field> ToBytesGadget<F> for Vec<T> {
fn to_bytes_le(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
self.as_slice().to_bytes_le()
}

fn to_non_unique_bytes_le(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
self.as_slice().to_non_unique_bytes_le()
}
}

impl<T: ToBytesGadget<F>, F: Field, const N: usize> ToBytesGadget<F> for [T; N] {
fn to_bytes_le(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
self.as_slice().to_bytes_le()
}

fn to_non_unique_bytes_le(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
self.as_slice().to_non_unique_bytes_le()
}
}

impl<F: Field> ToBytesGadget<F> for () {
fn to_bytes_le(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
Ok(Vec::new())
}

fn to_non_unique_bytes_le(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
Ok(Vec::new())
}
}

/// Specifies how to convert a variable of type `Self` to variables of
/// type `FpVar<ConstraintF>`
pub trait ToConstraintFieldGadget<ConstraintF: ark_ff::PrimeField> {
Expand Down
109 changes: 100 additions & 9 deletions src/eq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ pub trait EqGadget<F: Field> {
should_enforce: &Boolean<F>,
) -> Result<(), SynthesisError> {
self.is_eq(&other)?
.conditional_enforce_equal(&Boolean::constant(true), should_enforce)
.conditional_enforce_equal(&Boolean::TRUE, should_enforce)
}

/// Enforce that `self` and `other` are equal.
Expand All @@ -46,7 +46,7 @@ pub trait EqGadget<F: Field> {
/// are encouraged to carefully analyze the efficiency and safety of these.
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn enforce_equal(&self, other: &Self) -> Result<(), SynthesisError> {
self.conditional_enforce_equal(other, &Boolean::constant(true))
self.conditional_enforce_equal(other, &Boolean::TRUE)
}

/// If `should_enforce == true`, enforce that `self` and `other` are *not*
Expand All @@ -65,7 +65,7 @@ pub trait EqGadget<F: Field> {
should_enforce: &Boolean<F>,
) -> Result<(), SynthesisError> {
self.is_neq(&other)?
.conditional_enforce_equal(&Boolean::constant(true), should_enforce)
.conditional_enforce_equal(&Boolean::TRUE, should_enforce)
}

/// Enforce that `self` and `other` are *not* equal.
Expand All @@ -78,20 +78,23 @@ pub trait EqGadget<F: Field> {
/// are encouraged to carefully analyze the efficiency and safety of these.
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn enforce_not_equal(&self, other: &Self) -> Result<(), SynthesisError> {
self.conditional_enforce_not_equal(other, &Boolean::constant(true))
self.conditional_enforce_not_equal(other, &Boolean::TRUE)
}
}

impl<T: EqGadget<F> + R1CSVar<F>, F: PrimeField> EqGadget<F> for [T] {
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn is_eq(&self, other: &Self) -> Result<Boolean<F>, SynthesisError> {
assert_eq!(self.len(), other.len());
assert!(!self.is_empty());
let mut results = Vec::with_capacity(self.len());
for (a, b) in self.iter().zip(other) {
results.push(a.is_eq(b)?);
if self.is_empty() & other.is_empty() {
Ok(Boolean::TRUE)
} else {
let mut results = Vec::with_capacity(self.len());
for (a, b) in self.iter().zip(other) {
results.push(a.is_eq(b)?);
}
Boolean::kary_and(&results)
}
Boolean::kary_and(&results)
}

#[tracing::instrument(target = "r1cs", skip(self, other))]
Expand Down Expand Up @@ -128,3 +131,91 @@ impl<T: EqGadget<F> + R1CSVar<F>, F: PrimeField> EqGadget<F> for [T] {
}
}
}

/// This blanket implementation just forwards to the impl on [`[T]`].
impl<T: EqGadget<F> + R1CSVar<F>, F: PrimeField> EqGadget<F> for Vec<T> {
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn is_eq(&self, other: &Self) -> Result<Boolean<F>, SynthesisError> {
self.as_slice().is_eq(other.as_slice())
}

#[tracing::instrument(target = "r1cs", skip(self, other))]
fn conditional_enforce_equal(
&self,
other: &Self,
condition: &Boolean<F>,
) -> Result<(), SynthesisError> {
self.as_slice()
.conditional_enforce_equal(other.as_slice(), condition)
}

#[tracing::instrument(target = "r1cs", skip(self, other))]
fn conditional_enforce_not_equal(
&self,
other: &Self,
should_enforce: &Boolean<F>,
) -> Result<(), SynthesisError> {
self.as_slice()
.conditional_enforce_not_equal(other.as_slice(), should_enforce)
}
}

/// Dummy impl for `()`.
impl<F: Field> EqGadget<F> for () {
/// Output a `Boolean` value representing whether `self.value() ==
/// other.value()`.
#[inline]
fn is_eq(&self, _other: &Self) -> Result<Boolean<F>, SynthesisError> {
Ok(Boolean::TRUE)
}

/// If `should_enforce == true`, enforce that `self` and `other` are equal;
/// else, enforce a vacuously true statement.
///
/// This is a no-op as `self.is_eq(other)?` is always `true`.
#[tracing::instrument(target = "r1cs", skip(self, _other))]
fn conditional_enforce_equal(
&self,
_other: &Self,
_should_enforce: &Boolean<F>,
) -> Result<(), SynthesisError> {
Ok(())
}

/// Enforce that `self` and `other` are equal.
///
/// This does not generate any constraints as `self.is_eq(other)?` is always
/// `true`.
#[tracing::instrument(target = "r1cs", skip(self, _other))]
fn enforce_equal(&self, _other: &Self) -> Result<(), SynthesisError> {
Ok(())
}
}

/// This blanket implementation just forwards to the impl on [`[T]`].
impl<T: EqGadget<F> + R1CSVar<F>, F: PrimeField, const N: usize> EqGadget<F> for [T; N] {
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn is_eq(&self, other: &Self) -> Result<Boolean<F>, SynthesisError> {
self.as_slice().is_eq(other.as_slice())
}

#[tracing::instrument(target = "r1cs", skip(self, other))]
fn conditional_enforce_equal(
&self,
other: &Self,
condition: &Boolean<F>,
) -> Result<(), SynthesisError> {
self.as_slice()
.conditional_enforce_equal(other.as_slice(), condition)
}

#[tracing::instrument(target = "r1cs", skip(self, other))]
fn conditional_enforce_not_equal(
&self,
other: &Self,
should_enforce: &Boolean<F>,
) -> Result<(), SynthesisError> {
self.as_slice()
.conditional_enforce_not_equal(other.as_slice(), should_enforce)
}
}
2 changes: 1 addition & 1 deletion src/fields/emulated_fp/allocated_field_var.rs
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,7 @@ impl<TargetF: PrimeField, BaseF: PrimeField> ToBytesGadget<BaseF>

let num_bits = TargetF::BigInt::NUM_LIMBS * 64;
assert!(bits.len() <= num_bits);
bits.resize_with(num_bits, || Boolean::constant(false));
bits.resize_with(num_bits, || Boolean::FALSE);

let bytes = bits.chunks(8).map(UInt8::from_bits_le).collect();
Ok(bytes)
Expand Down
4 changes: 2 additions & 2 deletions src/fields/fp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ impl<F: PrimeField> ToBytesGadget<F> for AllocatedFp<F> {
fn to_bytes_le(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
let num_bits = F::BigInt::NUM_LIMBS * 64;
let mut bits = self.to_bits_le()?;
let remainder = core::iter::repeat(Boolean::constant(false)).take(num_bits - bits.len());
let remainder = core::iter::repeat(Boolean::FALSE).take(num_bits - bits.len());
bits.extend(remainder);
let bytes = bits
.chunks(8)
Expand All @@ -568,7 +568,7 @@ impl<F: PrimeField> ToBytesGadget<F> for AllocatedFp<F> {
fn to_non_unique_bytes_le(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
let num_bits = F::BigInt::NUM_LIMBS * 64;
let mut bits = self.to_non_unique_bits_le()?;
let remainder = core::iter::repeat(Boolean::constant(false)).take(num_bits - bits.len());
let remainder = core::iter::repeat(Boolean::FALSE).take(num_bits - bits.len());
bits.extend(remainder);
let bytes = bits
.chunks(8)
Expand Down
Loading
Loading