diff --git a/jolt-core/src/poly/dense_mlpoly.rs b/jolt-core/src/poly/dense_mlpoly.rs index 550840005..aec80cd73 100644 --- a/jolt-core/src/poly/dense_mlpoly.rs +++ b/jolt-core/src/poly/dense_mlpoly.rs @@ -240,7 +240,7 @@ impl DensePolynomial { } pub fn evals(&self) -> Vec { - self.Z.clone() + self.Z[..self.len].to_owned() } pub fn evals_ref(&self) -> &[F] { diff --git a/jolt-core/src/r1cs/builder.rs b/jolt-core/src/r1cs/builder.rs index 85be551c5..747e3fe7b 100644 --- a/jolt-core/src/r1cs/builder.rs +++ b/jolt-core/src/r1cs/builder.rs @@ -165,6 +165,16 @@ impl AuxComputation { }); }); + /* Hack(arasuarun): Set all variables in the last step to 0. + Needed for the crush-second-sumcheck optimization. + There should be a better way to do this instead of iterating over all witness segments. + */ + let mut last_index = batch_size - 1; + while last_index < aux_poly.len() { + aux_poly[last_index] = F::zero(); + last_index += batch_size; + } + DensePolynomial::new(aux_poly) } @@ -671,7 +681,9 @@ impl CombinedUniformBuilder CombinedUniformBuilder CombinedUniformBuilder UniformSpartanKey usize { + self.uniform_r1cs.num_vars + } + /// Evaluates A(r_x, y) + r_rlc * B(r_x, y) + r_rlc^2 * C(r_x, y) where r_x = r_constr || r_step for all y. #[tracing::instrument(skip_all, name = "UniformSpartanKey::evaluate_r1cs_mle_rlc")] pub fn evaluate_r1cs_mle_rlc(&self, r_constr: &[F], r_step: &[F], r_rlc: F) -> Vec { @@ -180,10 +185,9 @@ impl UniformSpartanKey UniformSpartanKey, non_uni_constants: Option>| -> Vec { - // +1 for constant - let mut evals = unsafe_allocate_zero_vec(self.uniform_r1cs.num_vars + 1); + // evals: [inputs, aux ... 1, ...] where ... indicates padding to next power of 2 + let mut evals = + unsafe_allocate_zero_vec(self.uniform_r1cs.num_vars.next_power_of_two() * 4); // *4 to accommodate cross-step constraints for (row, col, val) in constraints.vars.iter() { evals[*col] += mul_0_1_optimized(val, &eq_rx_constr[*row]); } @@ -204,6 +209,8 @@ impl UniformSpartanKey UniformSpartanKey>(); - let mut rlc = unsafe_allocate_zero_vec(self.num_cols_total()); - - { - let span = tracing::span!(tracing::Level::INFO, "big_rlc_computation"); - let _guard = span.enter(); - rlc.par_chunks_mut(self.num_steps) - .take(self.uniform_r1cs.num_vars) - .enumerate() - .for_each(|(var_index, var_chunk)| { - if !sm_rlc[var_index].is_zero() { - for (step_index, item) in var_chunk.iter_mut().enumerate() { - *item = mul_0_1_optimized(&eq_rx_step[step_index], &sm_rlc[var_index]); - } - } - }); - } - - rlc[self.num_vars_total()] = sm_rlc[self.uniform_r1cs.num_vars]; // constant - // Handle non-uniform constraints let update_non_uni = |rlc: &mut Vec, offset: &SparseEqualityItem, @@ -251,19 +239,12 @@ impl UniformSpartanKey UniformSpartanKey F { + pub fn evaluate_z_mle(&self, segment_evals: &[F], r: &[F], with_const: bool) -> F { assert_eq!(self.uniform_r1cs.num_vars, segment_evals.len()); - assert_eq!(r.len(), self.full_z_len().log_2()); + assert_eq!(r.len(), self.full_z_len().log_2()); // Z can be computed in two halves, [Variables, (constant) 1, 0 , ...] indexed by the first bit. let r_const = r[0]; let r_rest = &r[1..]; - assert_eq!(r_rest.len(), self.num_vars_total().log_2()); // Don't need the last log2(num_steps) bits, they've been evaluated already. 
let var_bits = self.uniform_r1cs.num_vars.next_power_of_two().log_2(); let r_var = &r_rest[..var_bits]; + let r_x_step = &r_rest[var_bits..]; + + let eq_last_step = EqPolynomial::new(r_x_step.to_vec()).evaluate(&vec![F::one(); r_x_step.len()]); let r_var_eq = EqPolynomial::evals(r_var); let eval_variables: F = (0..self.uniform_r1cs.num_vars) .map(|var_index| r_var_eq[var_index] * segment_evals[var_index]) .sum(); - let const_poly = SparsePolynomial::new(self.num_vars_total().log_2(), vec![(F::one(), 0)]); - let eval_const = const_poly.evaluate(r_rest); - (F::one() - r_const) * eval_variables + r_const * eval_const + // If r_const = 1, only the constant position (with all other index bits are 0) has a non-zero value + let var_and_const_bits: usize = var_bits + 1; + let eq_consts = EqPolynomial::new(r[..var_and_const_bits].to_vec()); + let mut eq_const = eq_consts.evaluate(&index_to_field_bitvector( + 1 << (var_and_const_bits - 1), + var_and_const_bits, + )); + let const_coeff = if with_const { F::one() } else { F::zero() }; + + ((F::one() - r_const) * eval_variables) + eq_const * const_coeff * (F::one() - eq_last_step) } /// Evaluates A(r), B(r), C(r) efficiently using their small uniform representations. #[tracing::instrument(skip_all, name = "UniformSpartanKey::evaluate_r1cs_matrix_mles")] - pub fn evaluate_r1cs_matrix_mles(&self, r: &[F]) -> (F, F, F) { + pub fn evaluate_r1cs_matrix_mles(&self, r: &[F], r_choice: &F) -> (F, F, F) { let total_rows_bits = self.num_rows_total().log_2(); let total_cols_bits = self.num_cols_total().log_2(); let steps_bits: usize = self.num_steps.log_2(); @@ -321,12 +311,15 @@ impl UniformSpartanKey| -> F { let mut full_mle_evaluation: F = constraints @@ -334,7 +327,7 @@ impl UniformSpartanKey() - * eq_rx_ry_step; + ; full_mle_evaluation += constraints .consts @@ -350,31 +343,23 @@ impl UniformSpartanKey| -> F { - let mut non_uni_mle = non_uni - .offset_vars - .iter() - .map(|(col, offset, coeff)| { - if !offset { - *coeff * eq_ry_var[*col] * eq_rx_ry_step - } else { - *coeff * eq_ry_var[*col] * eq_step_offset_1 - } - }) - .sum::(); - - non_uni_mle += non_uni.constant * col_eq_constant; - - non_uni_mle + let mut non_uni_a_mle = F::zero(); + let mut non_uni_b_mle = F::zero(); + + let compute_non_uniform = |uni_mle: &mut F, non_uni_mle: &mut F, non_uni: &SparseEqualityItem, eq_rx: F| { + for (col, offset, coeff) in &non_uni.offset_vars { + if !offset { + *uni_mle += *coeff * eq_ry_var[*col] * eq_rx; + } else { + *non_uni_mle += *coeff * eq_ry_var[*col] * eq_rx; + } + } }; for (i, constraint) in self.offset_eq_r1cs.constraints.iter().enumerate() { - let non_uni_a = compute_non_uniform(&constraint.eq); - let non_uni_b = compute_non_uniform(&constraint.condition); let non_uni_constraint_index = index_to_field_bitvector(self.uniform_r1cs.num_rows + i, constraint_rows_bits); @@ -386,10 +371,33 @@ impl UniformSpartanKey>| { + if let Some(non_uni_constants) = non_uni_constants { + for (i, non_uni_constant) in non_uni_constants.iter().enumerate() { + // The matrix values are present even in the last step. + // It's the role of the evaluation of the z mle to ignore the last step. 
+ let first_non_uniform_row = self.uniform_r1cs.num_rows; + *uni_mle += eq_rx_constr[first_non_uniform_row + i] * non_uni_constant * col_eq_constant; + } + } + }; + + let (eq_constants, condition_constants) = self.offset_eq_r1cs.constants(); + compute_non_uni_constants(&mut a_mle, Some(eq_constants)); + compute_non_uni_constants(&mut b_mle, Some(condition_constants)); + + a_mle = (F::one() - r_choice) * a_mle + + *r_choice * non_uni_a_mle; + b_mle = (F::one() - r_choice) * b_mle + + *r_choice * non_uni_b_mle; + c_mle = (F::one() - r_choice) * c_mle; + (a_mle, b_mle, c_mle) } @@ -486,39 +494,39 @@ impl UniformSpartanKey = [r_row_constr, r_row_step].concat(); -// for i in 0..key.num_cols_total() { -// let col_coordinate = index_to_field_bitvector(i, col_coordinate_len); - -// let coordinate: Vec = [row_coordinate.clone(), col_coordinate].concat(); -// let expected_rlc = a.evaluate(&coordinate) -// + r_rlc * b.evaluate(&coordinate) -// + r_rlc * r_rlc * c.evaluate(&coordinate); - -// assert_eq!(expected_rlc, rlc[i], "Failed at {i}"); -// } -// } + // #[test] + // fn evaluate_r1cs_mle_rlc() { + // let (_builder, key) = simp_test_builder_key(); + // let (a, b, c) = simp_test_big_matrices(); + // let a = DensePolynomial::new(a); + // let b = DensePolynomial::new(b); + // let c = DensePolynomial::new(c); + + // let r_row_constr_len = (key.uniform_r1cs.num_rows + 1).next_power_of_two().log_2(); + // let r_col_step_len = key.num_steps.log_2(); + + // let r_row_constr = vec![Fr::from(100), Fr::from(200)]; + // let r_row_step = vec![Fr::from(100), Fr::from(200)]; + // assert_eq!(r_row_constr.len(), r_row_constr_len); + // assert_eq!(r_row_step.len(), r_col_step_len); + // let r_rlc = Fr::from(1000); + + // let rlc = key.evaluate_r1cs_mle_rlc(&r_row_constr, &r_row_step, r_rlc); + + // // let row_coordinate_len = key.num_rows_total().log_2(); + // let col_coordinate_len = key.num_cols_total().log_2(); + // let row_coordinate: Vec = [r_row_constr, r_row_step].concat(); + // for i in 0..key.num_cols_total() { + // let col_coordinate = index_to_field_bitvector(i, col_coordinate_len); + + // let coordinate: Vec = [row_coordinate.clone(), col_coordinate].concat(); + // let expected_rlc = a.evaluate(&coordinate) + // + r_rlc * b.evaluate(&coordinate) + // + r_rlc * r_rlc * c.evaluate(&coordinate); + + // assert_eq!(expected_rlc, rlc[i], "Failed at {i}"); + // } + // } // #[test] // fn r1cs_matrix_mles_offset_constraints() { diff --git a/jolt-core/src/r1cs/spartan.rs b/jolt-core/src/r1cs/spartan.rs index 1be9ee2ed..1e5abc594 100644 --- a/jolt-core/src/r1cs/spartan.rs +++ b/jolt-core/src/r1cs/spartan.rs @@ -1,5 +1,6 @@ #![allow(clippy::len_without_is_empty)] +// use core::range; use std::marker::PhantomData; use crate::field::JoltField; @@ -23,6 +24,7 @@ use thiserror::Error; use crate::{ poly::{dense_mlpoly::DensePolynomial, eq_poly::EqPolynomial}, subprotocols::sumcheck::SumcheckInstanceProof, + r1cs::special_polys::eq_plus_one, }; use super::builder::CombinedUniformBuilder; @@ -77,7 +79,10 @@ pub struct UniformSpartanProof< pub(crate) outer_sumcheck_proof: SumcheckInstanceProof, pub(crate) outer_sumcheck_claims: (F, F, F), pub(crate) inner_sumcheck_proof: SumcheckInstanceProof, + pub(crate) shift_sumcheck_proof: SumcheckInstanceProof, + pub(crate) shift_sumcheck_claim: F, pub(crate) claimed_witness_evals: Vec, + pub(crate) claimed_witness_evals_shift_sumcheck: Vec, _marker: PhantomData, } @@ -155,31 +160,137 @@ where + r_inner_sumcheck_RLC * claim_Bz + r_inner_sumcheck_RLC * r_inner_sumcheck_RLC * claim_Cz; 
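The prover and verifier changes below rely on eq_plus_one from r1cs::special_polys, which is imported in this diff but not defined here. As a point of reference, the following is a minimal standalone sketch of the identity it is expected to compute: the multilinear extension of the predicate "y = x + 1" over the boolean hypercube. The MSB-first bit ordering and the use of ark_bn254::Fr are assumptions made purely for illustration; this is not the crate's actual implementation.

// Sketch only: multilinear extension of "int(y) == int(x) + 1" for two bit
// vectors of equal length, MSB first. The real eq_plus_one in special_polys
// may use a different convention; this is illustrative, not the library code.
use ark_bn254::Fr;
use ark_ff::One;

fn eq_plus_one_sketch(x: &[Fr], y: &[Fr]) -> Fr {
    assert_eq!(x.len(), y.len());
    let eq = |a: Fr, b: Fr| a * b + (Fr::one() - a) * (Fr::one() - b);
    // Sum over the bit position i where the +1 carry chain stops: bits above i
    // agree, bit i flips 0 -> 1, and every bit below i flips 1 -> 0. On boolean
    // inputs at most one term of the sum is non-zero.
    (0..x.len())
        .map(|i| {
            let high: Fr = (0..i).map(|j| eq(x[j], y[j])).product();
            let flip = (Fr::one() - x[i]) * y[i];
            let low: Fr = (i + 1..x.len()).map(|j| x[j] * (Fr::one() - y[j])).product();
            high * flip * low
        })
        .sum()
}

Under this reading, summing z(t) * eq_plus_one(r_x_step, t) over all steps t, as the prover code below does, evaluates the witness at the step following r_x_step, which is what the shift sumcheck needs.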
+ let num_steps_padded = constraint_builder.uniform_repeat().next_power_of_two(); + let num_steps_bits = num_steps_padded.ilog2() as usize; + let num_constraints_bits = key.num_cons_total.log_2() - num_steps_bits; + + let r_x_step = &outer_sumcheck_r[num_constraints_bits..]; + + // Binding 1: evaluating z on r_x_step + let is_last_step = EqPolynomial::new(r_x_step.to_vec()).evaluate(&vec![F::one(); r_x_step.len()]); + let eq_rx_step = EqPolynomial::evals(r_x_step); + + let mut evals: Vec = flattened_polys + .par_iter() + .map(|poly| { + poly.Z + .par_iter() + .enumerate() + .map(|(t, &val)| { + if t == num_steps_padded - 1 { // ignore last step + F::zero() + } else { + val * eq_rx_step[t] + } + }) + .sum() + }) + .collect(); + evals.resize(evals.len().next_power_of_two(), F::zero()); + evals.push(F::one() - is_last_step); // Constant, ignores the last step. + evals.resize(evals.len().next_power_of_two(), F::zero()); + + let eq_plus_one_rx_step: Vec = (0..num_steps_padded) + .map(|t| eq_plus_one(r_x_step, &crate::utils::index_to_field_bitvector(t, num_steps_bits), num_steps_bits)) + .collect(); + + let mut evals_shifted: Vec = flattened_polys + .par_iter() + .map(|poly| { + poly.Z + .par_iter() + .enumerate() + .map(|(t, &val)| { + if t == num_steps_padded - 1 { // ignore last step + F::zero() + } else { + val * eq_plus_one_rx_step[t] + } + }) + .sum() + }) + .collect(); + evals_shifted.resize(evals.len(), F::zero()); + + let poly_z = DensePolynomial::new(evals.into_iter().chain(evals_shifted.into_iter()).collect()); + // this is the polynomial extended from the vector r_A * A(r_x, y) + r_B * B(r_x, y) + r_C * C(r_x, y) for all y - let num_steps_bits = constraint_builder - .uniform_repeat() - .next_power_of_two() - .ilog2(); let (rx_con, rx_ts) = outer_sumcheck_r.split_at(outer_sumcheck_r.len() - num_steps_bits as usize); - let mut poly_ABC = + let poly_ABC = DensePolynomial::new(key.evaluate_r1cs_mle_rlc(rx_con, rx_ts, r_inner_sumcheck_RLC)); - + assert_eq!(poly_z.len(), poly_ABC.len()); + assert_eq!(poly_ABC.len(), key.num_vars_uniform().next_power_of_two() * 4); // *4 to support cross_step constraints + + let num_rounds = poly_ABC.len().log_2(); + let mut polys = vec![poly_ABC, poly_z]; + let comb_func = |poly_evals: &[F]| -> F { + assert_eq!(poly_evals.len(), 2); + poly_evals[0] * poly_evals[1] + }; let (inner_sumcheck_proof, inner_sumcheck_r, _claims_inner) = - SumcheckInstanceProof::prove_spartan_quadratic( - &claim_inner_joint, // r_A * v_A + r_B * v_B + r_C * v_C - num_rounds_y, - &mut poly_ABC, // r_A * A(r_x, y) + r_B * B(r_x, y) + r_C * C(r_x, y) for all y - &flattened_polys, - transcript, - ); - drop_in_background_thread(poly_ABC); - - // Requires 'r_col_segment_bits' to index the (const, segment). 
Within that segment we index the step using 'r_col_step' - let r_col_segment_bits = key.uniform_r1cs.num_vars.next_power_of_two().log_2() + 1; - let r_col_step = &inner_sumcheck_r[r_col_segment_bits..]; - - let chi = EqPolynomial::evals(r_col_step); + SumcheckInstanceProof::prove_arbitrary( + &claim_inner_joint, + num_rounds, + &mut polys, + comb_func, + 2, + transcript); + + drop_in_background_thread(polys); + + let r_y_var = inner_sumcheck_r[1..].to_vec(); + assert_eq!(r_y_var.len(), key.num_vars_uniform().next_power_of_two().log_2() + 1); + + let eq_ry_var = EqPolynomial::evals(&r_y_var); + + /* Sumcheck 3: the shift sumcheck */ + + /* Binding 2: evaluating z on r_y_var + + TODO(arasuarun): this might lead to inefficient memory paging + as we access each poly in flattened_poly num_steps_padded-many times. + */ + let mut evals_z_r_y_var: Vec = (0..constraint_builder.uniform_repeat()) + .map(|t| { + flattened_polys + .par_iter() + .enumerate() + .map(|(i, poly)| { + if t < poly.Z.len() { + poly.Z[t] * eq_ry_var[i] + } else { + F::zero() + } + }) + .sum() + }) + .collect(); + evals_z_r_y_var.resize(num_steps_padded, F::zero()); + + let num_rounds_shift_sumcheck = num_steps_bits; + let mut shift_sumcheck_polys = vec![DensePolynomial::new(evals_z_r_y_var), DensePolynomial::new(eq_plus_one_rx_step.clone())]; + + let shift_sumcheck_claim = (0..((1 << num_rounds_shift_sumcheck) - 1)) + .into_par_iter() + .map(|i| { + let params: Vec = shift_sumcheck_polys.iter().map(|poly| poly[i]).collect(); + comb_func(¶ms) + }) + .reduce(|| F::zero(), |acc, x| acc + x); + + let (shift_sumcheck_proof, shift_sumcheck_r, shift_sumcheck_claims) = + SumcheckInstanceProof::prove_arbitrary( + &shift_sumcheck_claim, + num_rounds_shift_sumcheck, + &mut shift_sumcheck_polys, + comb_func, + 2, + transcript); + drop_in_background_thread(shift_sumcheck_polys); + + // Polynomial evals for inner sumcheck + let chi = EqPolynomial::evals(&r_x_step); let claimed_witness_evals: Vec<_> = flattened_polys .par_iter() .map(|poly| poly.evaluate_at_chi_low_optimized(&chi)) @@ -188,12 +299,27 @@ where opening_accumulator.append( &flattened_polys, DensePolynomial::new(chi), - r_col_step.to_vec(), + r_x_step.to_vec(), &claimed_witness_evals.iter().collect::>(), transcript, ); - // Outer sumcheck claims: [A(r_x), B(r_x), C(r_x)] + // Polynomial evals for shift sumcheck + let chi2 = EqPolynomial::evals(&shift_sumcheck_r); + let claimed_witness_evals_shift_sumcheck: Vec<_> = flattened_polys + .par_iter() + .map(|poly| poly.evaluate_at_chi_low_optimized(&chi2)) + .collect(); + + opening_accumulator.append( + &flattened_polys, + DensePolynomial::new(chi2), + shift_sumcheck_r.to_vec(), + &claimed_witness_evals_shift_sumcheck.iter().collect::>(), + transcript, + ); + + // Outer sumcheck claims: [eq(r_x), A(r_x), B(r_x), C(r_x)] let outer_sumcheck_claims = ( outer_sumcheck_claims[0], outer_sumcheck_claims[1], @@ -204,7 +330,10 @@ where outer_sumcheck_proof, outer_sumcheck_claims, inner_sumcheck_proof, + shift_sumcheck_proof, + shift_sumcheck_claim, claimed_witness_evals, + claimed_witness_evals_shift_sumcheck, _marker: PhantomData, }) } @@ -260,34 +389,58 @@ where + r_inner_sumcheck_RLC * self.outer_sumcheck_claims.1 + r_inner_sumcheck_RLC * r_inner_sumcheck_RLC * self.outer_sumcheck_claims.2; + let num_rounds = (key.num_vars_uniform() * 2).next_power_of_two().log_2() + 1; // +1 for cross-step let (claim_inner_final, inner_sumcheck_r) = self .inner_sumcheck_proof - .verify(claim_inner_joint, num_rounds_y, 2, transcript) + 
.verify(claim_inner_joint, num_rounds, 2, transcript) .map_err(|_| SpartanError::InvalidInnerSumcheckProof)?; - // n_prefix = n_segments + 1 - let n_prefix = key.uniform_r1cs.num_vars.next_power_of_two().log_2() + 1; + let n_constraint_bits_uniform = key.uniform_r1cs.num_rows.next_power_of_two().log_2(); + let outer_sumcheck_r_step = &r_x[n_constraint_bits_uniform..]; - let eval_Z = key.evaluate_z_mle(&self.claimed_witness_evals, &inner_sumcheck_r); + let r_choice = inner_sumcheck_r[0]; + let r_y_var = inner_sumcheck_r[1..].to_vec(); + let y_prime = [r_y_var.clone(), outer_sumcheck_r_step.to_owned()].concat(); + let eval_z = key.evaluate_z_mle(&self.claimed_witness_evals, &y_prime, true); - let r_y = inner_sumcheck_r.clone(); - let r = [r_x, r_y].concat(); - let (eval_a, eval_b, eval_c) = key.evaluate_r1cs_matrix_mles(&r); + let r = [r_x.clone(), y_prime].concat(); + let (eval_a, eval_b, eval_c) = key.evaluate_r1cs_matrix_mles(&r, &r_choice); let left_expected = eval_a + r_inner_sumcheck_RLC * eval_b + r_inner_sumcheck_RLC * r_inner_sumcheck_RLC * eval_c; - let right_expected = eval_Z; + let right_expected = + (F::one() - r_choice) * eval_z + + r_choice * self.shift_sumcheck_claim; let claim_inner_final_expected = left_expected * right_expected; + + assert_eq!(claim_inner_final, claim_inner_final_expected); if claim_inner_final != claim_inner_final_expected { return Err(SpartanError::InvalidInnerSumcheckClaim); } + let num_steps_bits = outer_sumcheck_r_step.len(); + let num_rounds_shift_sumcheck = num_steps_bits; + let (claim_shift_final, shift_sumcheck_r) = self + .shift_sumcheck_proof + .verify(self.shift_sumcheck_claim, num_rounds_shift_sumcheck, 2, transcript) + .map_err(|_| SpartanError::InvalidInnerSumcheckProof)?; + + let y_prime_shift_sumcheck = [r_y_var, shift_sumcheck_r.to_owned()].concat(); + let eval_z_shift_sumcheck = key.evaluate_z_mle(&self.claimed_witness_evals_shift_sumcheck, &y_prime_shift_sumcheck, false); + let eq_plus_one_shift_sumcheck = eq_plus_one(&outer_sumcheck_r_step, &shift_sumcheck_r, num_steps_bits); + let claim_shift_sumcheck_expected = eval_z_shift_sumcheck * eq_plus_one_shift_sumcheck; + assert_eq!(claim_shift_final, claim_shift_sumcheck_expected); + if claim_shift_final != claim_shift_sumcheck_expected { + return Err(SpartanError::InvalidInnerSumcheckClaim); + } + let flattened_commitments: Vec<_> = I::flatten::() .iter() .map(|var| var.get_ref(commitments)) .collect(); - let r_y_point = &inner_sumcheck_r[n_prefix..]; + + let r_y_point = &r_x[n_constraint_bits_uniform..]; opening_accumulator.append( &flattened_commitments, r_y_point.to_vec(), @@ -295,6 +448,13 @@ where transcript, ); + opening_accumulator.append( + &flattened_commitments, + shift_sumcheck_r.to_vec(), + &self.claimed_witness_evals_shift_sumcheck.iter().collect::>(), + transcript, + ); + Ok(()) } } @@ -353,3 +513,93 @@ where // .expect("Spartan verifier failed"); // } // } + +#[cfg(test)] +mod tests { + use super::*; + use rand::Rng; + use rand_core::{RngCore, CryptoRng}; + use ark_bn254::Fr as F; // Add this line to import the field type + use ark_ff::{Zero, One}; // Import the Zero trait + + #[test] + fn test_shifted_polynomial_evaluations() { + // Generate a vector z of random field elements of length 128 + let mut rng = rand::thread_rng(); + let z: Vec = (0..128).map(|_| F::from(rng.gen::())).collect(); + + // Resize z to the next power of two + let mut z_resized = z.clone(); + z_resized.resize(z.len().next_power_of_two() * 2, F::zero()); + + println!("z_resized.len(): {:?}", 
z_resized.len()); + + let r_x_step: Vec = vec![F::zero(), F::zero(), F::one(), F::zero()]; + + // Create the polynomial from z + let mut poly_z = DensePolynomial::new(z_resized.clone()); + for r_s in r_x_step.iter().rev() { + poly_z.bound_poly_var_bot(r_s); + } + let evals = poly_z.evals(); + + // Create the shifted polynomial from z + let mut z_shifted: Vec = z[1..].to_vec(); + z_shifted.resize(z.len().next_power_of_two(), F::zero()); + + let mut poly_z_shifted = DensePolynomial::new(z_shifted.clone()); + for r_s in r_x_step.iter().rev() { + poly_z_shifted.bound_poly_var_bot(r_s); + } + let evals_shifted = poly_z_shifted.evals(); + + // // print the first 10 lines of evals and evals_shifted + // for i in 0..4 { + // println!("z: {:?}", z); + // println!("evals_shifted: {:?}", evals_shifted); + // // println!("evals[{}]: {:?}, evals_shifted[{}]: {:?}", i, evals[i], i, evals_shifted[i]); + // println!("z[{}]: {:?}, evals_shifted[{}]: {:?}", i+1, z[i+1], i, evals_shifted[i]); + // // println!("z[{}]: {:?}", i+1, z[i+1]); + + // } + + // print each element of z preceded by index: + for i in 0..z.len() { + println!("z[{}]: {:?}", i, z[i]); + } + println!("evals_shifted: {:?}", evals_shifted); + + + + // // Evaluate the polynomials at a random point k + // let k: F = F::random(&mut rng); + // let eval_at_k = poly_z.evaluate(&k); + // let eval_shifted_at_k_minus_1 = poly_z_shifted.evaluate(&(k - F::one())); + + // // Check if the evaluations are correct + // assert_eq!(eval_at_k, eval_shifted_at_k_minus_1); + } + #[test] + fn test_eq_polynomial_evals() { + // Generate a random vector of length 8 + let mut rng = rand::thread_rng(); + let random_vector: Vec = (0..8).map(|_| F::from(rng.gen::())).collect(); + + // generate all 1s vector of lenght 8 + let all_ones_vector: Vec = (0..8).map(|_| F::one()).collect(); + + // Run EqPolynomial::evals on the random vector + let eq_evals = EqPolynomial::evals(&random_vector); + let all_ones_evals = EqPolynomial::evals(&all_ones_vector); + + // // Print the random vector and its evaluations + // for i in 0..random_vector.len() { + // println!("random_vector[{}]: {:?}", i, random_vector[i]); + // } + // println!("eq_evals: {:?}", eq_evals); + println!("all_ones_evals.last(): {:?}", all_ones_evals[2]); + + // // Check if the evaluations are correct (this is a placeholder, you should replace it with actual checks) + // assert_eq!(eq_evals.len(), random_vector.len().next_power_of_two()); + } +} \ No newline at end of file diff --git a/jolt-core/src/subprotocols/sumcheck.rs b/jolt-core/src/subprotocols/sumcheck.rs index 008f892be..95553caa4 100644 --- a/jolt-core/src/subprotocols/sumcheck.rs +++ b/jolt-core/src/subprotocols/sumcheck.rs @@ -105,6 +105,17 @@ impl SumcheckInstanceProof = Vec::new(); let mut compressed_polys: Vec> = Vec::new(); + #[cfg(test)] + { + let total_evals = 1 << num_rounds; + let mut sum = F::zero(); + for i in 0..total_evals { + let params: Vec = polys.iter().map(|poly| poly[i]).collect(); + sum += comb_func(¶ms); + } + assert_eq!(&sum, _claim, "Sumcheck claim is wrong"); + } + for _round in 0..num_rounds { // Vector storing evaluations of combined polynomials g(x) = P_0(x) * ... P_{num_polys} (x) // for points {0, ..., |g(x)|}
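For context on the #[cfg(test)] block added to prove_arbitrary above: it recomputes the claimed sum by brute force over all 2^num_rounds hypercube points before any rounds are played, and asserts it matches the claim passed in. A self-contained sketch of the same check, using ark_bn254::Fr and a toy product comb_func as illustrative stand-ins rather than the crate's API, could look like this.

// Sketch only: brute-force re-derivation of a sumcheck claim, mirroring the
// #[cfg(test)] guard added to prove_arbitrary. Fr and the product comb_func
// are illustrative assumptions, not the crate's API.
use ark_bn254::Fr;
use ark_ff::Zero;

fn brute_force_claim(polys: &[Vec<Fr>], comb_func: impl Fn(&[Fr]) -> Fr) -> Fr {
    // Assumes every polynomial is already padded to 1 << num_rounds evaluations.
    let total_evals = polys[0].len();
    let mut sum = Fr::zero();
    for i in 0..total_evals {
        let params: Vec<Fr> = polys.iter().map(|poly| poly[i]).collect();
        sum += comb_func(&params);
    }
    sum
}

fn main() {
    // Two toy multilinear polynomials over a 2-variable hypercube (4 evaluations each).
    let p0: Vec<Fr> = (1u64..=4).map(|v| Fr::from(v)).collect();
    let p1: Vec<Fr> = (5u64..=8).map(|v| Fr::from(v)).collect();
    let claim = brute_force_claim(&[p0, p1], |evals| evals[0] * evals[1]);
    assert_eq!(claim, Fr::from(70u64)); // 1*5 + 2*6 + 3*7 + 4*8
}

Because this check costs O(2^num_rounds) field operations, gating it behind #[cfg(test)] as the diff does keeps it out of production proving while still catching mis-stated claims in tests.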