From 2b1b74f8460b3e69bdcca84841f09f241004c9ec Mon Sep 17 00:00:00 2001 From: Mihir Wadekar Date: Sun, 15 Dec 2024 22:37:35 -0800 Subject: [PATCH] fix(r1cs): We add a range check to the chunk --- jolt-core/src/r1cs/builder.rs | 2 +- jolt-core/src/r1cs/constraints.rs | 40 +++++++++++++++++++++++- jolt-core/src/r1cs/inputs.rs | 51 +++++++++++++++++++++++++------ 3 files changed, 82 insertions(+), 11 deletions(-) diff --git a/jolt-core/src/r1cs/builder.rs b/jolt-core/src/r1cs/builder.rs index 85be551c5..da187d5f3 100644 --- a/jolt-core/src/r1cs/builder.rs +++ b/jolt-core/src/r1cs/builder.rs @@ -228,7 +228,7 @@ impl R1CSBuilder { } } - fn allocate_aux( + pub(crate) fn allocate_aux( &mut self, aux_symbol: I, symbolic_inputs: Vec, diff --git a/jolt-core/src/r1cs/constraints.rs b/jolt-core/src/r1cs/constraints.rs index 415f3587d..d88dae294 100644 --- a/jolt-core/src/r1cs/constraints.rs +++ b/jolt-core/src/r1cs/constraints.rs @@ -18,7 +18,7 @@ use crate::{ use super::{ builder::{CombinedUniformBuilder, OffsetEqConstraint, R1CSBuilder}, inputs::{AuxVariable, ConstraintInput, JoltR1CSInputs}, - ops::Variable, + ops::{Term, Variable, LC}, }; pub const PC_START_ADDRESS: i64 = 0x80000000; @@ -170,11 +170,20 @@ impl R1CSConstraints for JoltRV32IMConstrain y, ); + // Range check all the chunks + for i in 0..C { + range_check_single_8bit(cs, x_chunks[i]); + range_check_single_8bit(cs, y_chunks[i]); + range_check_single_8bit(cs, query_chunks[i]); + } + // if is_shift ? chunks_query[i] == zip(chunks_x[i], chunks_y[C-1]) : chunks_query[i] == zip(chunks_x[i], chunks_y[i]) let is_shift = JoltR1CSInputs::InstructionFlags(SLLInstruction::default().into()) + JoltR1CSInputs::InstructionFlags(SRLInstruction::default().into()) + JoltR1CSInputs::InstructionFlags(SRAInstruction::default().into()); for i in 0..C { + // Range check x_chunks[i], y_chunks[i], query_chunks[i] here + let relevant_chunk_y = cs.allocate_if_else( JoltR1CSInputs::Aux(AuxVariable::RelevantYChunk(i)), is_shift.clone(), @@ -262,3 +271,32 @@ impl R1CSConstraints for JoltRV32IMConstrain vec![pc_constraint, virtual_sequence_constraint] } } + +fn range_check_single_8bit( + cs: &mut R1CSBuilder, + chunk: Variable, +) { + let mut sum_expr = LC::zero(); + + for i in 0..8 { + let bit_var = cs.allocate_aux( + JoltR1CSInputs::Aux(AuxVariable::BitDecomposition(i)), + vec![], + Box::new(|_values| F::zero()), + ); + + // Constrain the bit to be binary. + cs.constrain_binary(bit_var); + + // Create a Term for bit_var * 2^i: Term(bit_var, (1 << i)) + let bit_term = Term(bit_var, (1 << i) as i64); + + // Add this term to the sum_expr. Term implements Into, so this works. + sum_expr = sum_expr + bit_term; + } + + // Finally, constrain chunk == sum_expr + cs.constrain_eq(chunk, sum_expr); +} + + diff --git a/jolt-core/src/r1cs/inputs.rs b/jolt-core/src/r1cs/inputs.rs index f73deab4d..331d6c9a7 100644 --- a/jolt-core/src/r1cs/inputs.rs +++ b/jolt-core/src/r1cs/inputs.rs @@ -238,7 +238,7 @@ impl() -> Vec { JoltR1CSInputs::iter() @@ -337,14 +340,26 @@ impl ConstraintInput for JoltR1CSInputs { Self::ChunksY(_) => (0..C).map(Self::ChunksY).collect(), Self::OpFlags(_) => CircuitFlags::iter().map(Self::OpFlags).collect(), Self::InstructionFlags(_) => RV32I::iter().map(Self::InstructionFlags).collect(), - Self::Aux(_) => AuxVariable::iter() - .flat_map(|aux| match aux { + Self::Aux(aux) => { + match aux { AuxVariable::RelevantYChunk(_) => (0..C) .map(|i| Self::Aux(AuxVariable::RelevantYChunk(i))) .collect(), - _ => vec![Self::Aux(aux)], - }) - .collect(), + AuxVariable::BitDecomposition(_) => { + // For bit decomposition, assume 8 bits + (0..8).map(|i| Self::Aux(AuxVariable::BitDecomposition(i))).collect() + } + // For all other AuxVariable variants, just one element + AuxVariable::LeftLookupOperand + | AuxVariable::RightLookupOperand + | AuxVariable::Product + | AuxVariable::WriteLookupOutputToRD + | AuxVariable::WritePCtoRD + | AuxVariable::NextPCJump + | AuxVariable::ShouldBranch + | AuxVariable::NextPC => vec![Self::Aux(aux)], + } + } _ => vec![variant], }) .collect() @@ -388,6 +403,13 @@ impl ConstraintInput for JoltR1CSInputs { AuxVariable::NextPCJump => &aux_polynomials.next_pc_jump, AuxVariable::ShouldBranch => &aux_polynomials.should_branch, AuxVariable::NextPC => &aux_polynomials.next_pc, + AuxVariable::BitDecomposition(_i) => { + // At this stage, we don't have separate storage for each bit. + // In practice, you'd store these as separate aux polynomials as well. + // For now, you might just panic or handle them after you've added + // the necessary aux polynomials. + panic!("BitDecomposition aux variables must be handled in the circuit builder") + } }, } } @@ -403,13 +425,15 @@ impl ConstraintInput for JoltR1CSInputs { AuxVariable::RightLookupOperand => &mut aux_polynomials.right_lookup_operand, AuxVariable::Product => &mut aux_polynomials.product, AuxVariable::RelevantYChunk(i) => &mut aux_polynomials.relevant_y_chunks[*i], - AuxVariable::WriteLookupOutputToRD => { - &mut aux_polynomials.write_lookup_output_to_rd - } + AuxVariable::WriteLookupOutputToRD => &mut aux_polynomials.write_lookup_output_to_rd, AuxVariable::WritePCtoRD => &mut aux_polynomials.write_pc_to_rd, AuxVariable::NextPCJump => &mut aux_polynomials.next_pc_jump, AuxVariable::ShouldBranch => &mut aux_polynomials.should_branch, AuxVariable::NextPC => &mut aux_polynomials.next_pc, + AuxVariable::BitDecomposition(_i) => { + // Similarly to get_ref, handle this once you have a place to store these bits. + panic!("BitDecomposition aux variables must be handled in the circuit builder") + } }, _ => panic!("get_ref_mut should only be invoked when computing aux polynomials"), } @@ -449,8 +473,17 @@ mod tests { .into_iter() .map(|i| JoltR1CSInputs::Aux(AuxVariable::RelevantYChunk(i))) .collect(), + AuxVariable::BitDecomposition(_) => (0..8) + .into_iter() + .map(|i| JoltR1CSInputs::Aux(AuxVariable::BitDecomposition(i))) + .collect(), _ => vec![JoltR1CSInputs::Aux(aux)], }) { + // For BitDecomposition we panic above, so skip them here + if let JoltR1CSInputs::Aux(AuxVariable::BitDecomposition(_)) = aux { + continue; + } + let ref_ptr = aux.get_ref(&jolt_polys) as *const DensePolynomial; let ref_mut_ptr = aux.get_ref_mut(&mut jolt_polys) as *const DensePolynomial; assert_eq!(ref_ptr, ref_mut_ptr, "Pointer mismatch for {:?}", aux);