Skip to content

Commit

Permalink
refactor to share code between the existing quarks impl and sparse qu…
Browse files Browse the repository at this point in the history
…arks
  • Loading branch information
sagar-a16z committed Nov 11, 2024
1 parent ff11a77 commit 634ac6d
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 111 deletions.
2 changes: 1 addition & 1 deletion jolt-core/benches/grand_product.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ fn benchmark_verify<PCS, F, G, ProofTranscript>(
transcript = ProofTranscript::new(b"test_transcript");
let mut verifier_accumulator: VerifierOpeningAccumulator<F, PCS, ProofTranscript> =
VerifierOpeningAccumulator::new();
let (_, r_verifier) = QuarkGrandProduct::verify_grand_product(
let (_, r_verifier) = QuarkGrandProduct::verify_quark_grand_product(
&proof,
&known_products,
Some(&mut verifier_accumulator),
Expand Down
4 changes: 4 additions & 0 deletions jolt-core/src/subprotocols/grand_product.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,10 @@ where

Self::verify_layers(&proof.layers, claim, transcript, r)
}

fn quark_poly(&self) -> Option<&[F]> {
None
}
}

pub trait BatchedGrandProductLayer<F, ProofTranscript>:
Expand Down
115 changes: 78 additions & 37 deletions jolt-core/src/subprotocols/grand_product_quarks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ pub struct QuarkGrandProductProof<

pub struct QuarkGrandProduct<F: JoltField, ProofTranscript: Transcript> {
batch_size: usize,
quark_poly: Vec<F>,
quark_poly: Option<Vec<F>>,
base_layers: Vec<DenseInterleavedPolynomial<F>>,
_marker: PhantomData<ProofTranscript>,
}
Expand Down Expand Up @@ -93,7 +93,7 @@ where
if tree_depth <= num_layers {
return Self {
batch_size,
quark_poly: Vec::new(),
quark_poly: None,
base_layers: layers,
_marker: PhantomData,
};
Expand All @@ -104,70 +104,113 @@ where
let quark_poly = layers.pop().unwrap().coeffs;
Self {
batch_size,
quark_poly,
quark_poly: Some(quark_poly),
base_layers: layers,
_marker: PhantomData,
}
}
/// The number of layers in the grand product, in this case it is the log of the quark layer size plus the gkr layer depth.
fn num_layers(&self) -> usize {
todo!()
// self.quark_poly[0].len().log_2()
self.base_layers.len()
}

/// The claimed outputs of the grand products.
fn claimed_outputs(&self) -> Vec<F> {
if self.quark_poly.is_empty() {
if let Some(quark_poly) = &self.quark_poly {
let chunk_size = quark_poly.len() / self.batch_size;
quark_poly
.par_chunks(chunk_size)
.map(|chunk| chunk.iter().product())
.collect()
} else {
let top_layer = &self.base_layers[self.base_layers.len() - 1];
return top_layer
top_layer
.par_chunks(2)
.map(|chunk| chunk[0] * chunk[1])
.collect();
.collect()
}

let chunk_size = self.quark_poly.len() / self.batch_size;
self.quark_poly
.par_chunks(chunk_size)
.map(|chunk| chunk.iter().product())
.collect()
}

/// Returns an iterator over the layers of this batched grand product circuit.
/// Each layer is mutable so that its polynomials can be bound over the course
/// of proving.
#[allow(unreachable_code)]
fn layers(
&'_ mut self,
) -> impl Iterator<Item = &'_ mut dyn BatchedGrandProductLayer<F, ProofTranscript>> {
panic!("We don't use the default prover and so we don't need the generic iterator");
std::iter::empty()
self.base_layers
.iter_mut()
.map(|layer| layer as &mut dyn BatchedGrandProductLayer<F, ProofTranscript>)
.rev()
}

fn quark_poly(&self) -> Option<&[F]> {
self.quark_poly.as_deref()
}

/// Computes a batched grand product proof, layer by layer.
#[tracing::instrument(skip_all, name = "BatchedGrandProduct::prove_grand_product")]
fn prove_grand_product(
&mut self,
opening_accumulator: Option<&mut ProverOpeningAccumulator<F, ProofTranscript>>,
transcript: &mut ProofTranscript,
setup: Option<&PCS::Setup>,
) -> (BatchedGrandProductProof<PCS, ProofTranscript>, Vec<F>) {
let mut proof_layers = Vec::with_capacity(self.base_layers.len());
QuarkGrandProductBase::prove_quark_grand_product(
self,
opening_accumulator,
transcript,
setup,
)
}

#[tracing::instrument(skip_all, name = "BatchedGrandProduct::verify_grand_product")]
fn verify_grand_product(
proof: &BatchedGrandProductProof<PCS, ProofTranscript>,
claimed_outputs: &[F],
opening_accumulator: Option<&mut VerifierOpeningAccumulator<F, PCS, ProofTranscript>>,
transcript: &mut ProofTranscript,
_setup: Option<&PCS::Setup>,
) -> (F, Vec<F>) {
QuarkGrandProductBase::verify_quark_grand_product::<Self, PCS>(
proof,
claimed_outputs,
opening_accumulator,
transcript,
)
}
}

pub struct QuarkGrandProductBase<F: JoltField, ProofTranscript: Transcript> {
_marker: PhantomData<(F, ProofTranscript)>,
}

impl<F, ProofTranscript> QuarkGrandProductBase<F, ProofTranscript>
where
F: JoltField,
ProofTranscript: Transcript,
{
/// Computes a batched grand product proof, layer by layer.
#[tracing::instrument(skip_all, name = "QuarkGrandProduct::prove_grand_product")]
pub fn prove_quark_grand_product<PCS: CommitmentScheme<ProofTranscript, Field = F>>(
grand_product: &mut impl BatchedGrandProduct<F, PCS, ProofTranscript>,
opening_accumulator: Option<&mut ProverOpeningAccumulator<F, ProofTranscript>>,
transcript: &mut ProofTranscript,
setup: Option<&PCS::Setup>,
) -> (BatchedGrandProductProof<PCS, ProofTranscript>, Vec<F>) {
let mut proof_layers = Vec::with_capacity(grand_product.num_layers());

let outputs: Vec<F> =
<Self as BatchedGrandProduct<F, PCS, ProofTranscript>>::claimed_outputs(self);
let outputs: Vec<F> = grand_product.claimed_outputs();
transcript.append_scalars(&outputs);
let output_mle = DensePolynomial::new_padded(outputs);
let r_outputs: Vec<F> = transcript.challenge_vector(output_mle.get_num_vars());
let claim = output_mle.evaluate(&r_outputs);

// For proofs of polynomials of size less than 16 we support these with no quark proof
let (quark_option, mut random, mut claim) = if !self.quark_poly.is_empty() {
let (quark_option, mut random, mut claim) = if grand_product.quark_poly().is_some() {
// When doing the quark hybrid proof, we first prove the grand product of a layer of a polynomial which is N layers deep in the tree
// of a standard layered sumcheck grand product, then we use the sumcheck layers to prove via gkr layers that the random point opened
// by the quark proof is in fact the folded result of the base layer.
let (quark, random, quark_claim) =
QuarkGrandProductProof::<PCS, ProofTranscript>::prove(
&self.quark_poly,
grand_product.quark_poly().unwrap(),
r_outputs,
claim,
opening_accumulator.unwrap(),
Expand All @@ -179,7 +222,7 @@ where
(None, r_outputs, claim)
};

for layer in self.base_layers.iter_mut().rev() {
for layer in grand_product.layers() {
proof_layers.push(layer.prove_layer(&mut claim, &mut random, transcript));
}

Expand All @@ -193,14 +236,17 @@ where
}

/// Verifies the given grand product proof.
#[tracing::instrument(skip_all, name = "BatchedGrandProduct::verify_grand_product")]
fn verify_grand_product(
#[tracing::instrument(skip_all, name = "QuarkGrandProduct::verify_grand_product")]
pub fn verify_quark_grand_product<G, PCS>(
proof: &BatchedGrandProductProof<PCS, ProofTranscript>,
claimed_outputs: &[F],
opening_accumulator: Option<&mut VerifierOpeningAccumulator<F, PCS, ProofTranscript>>,
transcript: &mut ProofTranscript,
_setup: Option<&PCS::Setup>,
) -> (F, Vec<F>) {
) -> (F, Vec<F>)
where
PCS: CommitmentScheme<ProofTranscript, Field = F>,
G: BatchedGrandProduct<F, PCS, ProofTranscript>,
{
transcript.append_scalars(claimed_outputs);
let r_outputs: Vec<F> =
transcript.challenge_vector(claimed_outputs.len().next_power_of_two().log_2());
Expand All @@ -227,13 +273,8 @@ where
}
};

let (grand_product_claim, grand_product_r) = <Self as BatchedGrandProduct<
F,
PCS,
ProofTranscript,
>>::verify_layers(
&proof.layers, claim, transcript, rand
);
let (grand_product_claim, grand_product_r) =
G::verify_layers(&proof.layers, claim, transcript, rand);

(grand_product_claim, grand_product_r)
}
Expand Down Expand Up @@ -636,7 +677,7 @@ mod quark_grand_product_tests {
&known_products,
Some(&mut verifier_accumulator),
&mut verifier_transcript,
Some(&setup),
None,
);
assert!(verifier_accumulator
.reduce_and_verify(&setup, &batched_proof, &mut verifier_transcript)
Expand Down
91 changes: 18 additions & 73 deletions jolt-core/src/subprotocols/sparse_grand_product.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@ use super::grand_product::{
use super::sumcheck::{BatchedCubicSumcheck, Bindable};
use crate::field::{JoltField, OptimizedMul};
use crate::poly::commitment::commitment_scheme::CommitmentScheme;
use crate::poly::dense_mlpoly::DensePolynomial;
use crate::poly::opening_proof::{ProverOpeningAccumulator, VerifierOpeningAccumulator};
use crate::poly::sparse_interleaved_poly::SparseInterleavedPolynomial;
use crate::poly::split_eq_poly::SplitEqPolynomial;
use crate::poly::unipoly::UniPoly;
use crate::subprotocols::grand_product_quarks::QuarkGrandProductProof;
use crate::subprotocols::grand_product_quarks::QuarkGrandProductBase;
use crate::subprotocols::QuarkHybridLayerDepth;
use crate::utils::math::Math;
use crate::utils::thread::drop_in_background_thread;
use crate::utils::transcript::Transcript;
use rayon::prelude::*;
#[cfg(test)]
use crate::poly::dense_mlpoly::DensePolynomial;

/// A special bottom layer of a grand product, where boolean flags are used to
/// toggle the other inputs (fingerprints) going into the rest of the tree.
Expand Down Expand Up @@ -978,6 +979,10 @@ where
.rev()
}

fn quark_poly(&self) -> Option<&[F]> {
self.quark_poly.as_deref()
}

/// Computes a batched grand product proof, layer by layer.
#[tracing::instrument(skip_all, name = "BatchedGrandProduct::prove_grand_product")]
fn prove_grand_product(
Expand All @@ -986,47 +991,11 @@ where
transcript: &mut ProofTranscript,
setup: Option<&PCS::Setup>,
) -> (BatchedGrandProductProof<PCS, ProofTranscript>, Vec<F>) {
let mut proof_layers =
Vec::with_capacity(
<Self as BatchedGrandProduct<F, PCS, ProofTranscript>>::num_layers(self),
);

let outputs: Vec<F> =
<Self as BatchedGrandProduct<F, PCS, ProofTranscript>>::claimed_outputs(self);
transcript.append_scalars(&outputs);
let output_mle = DensePolynomial::new_padded(outputs);
let r_outputs: Vec<F> = transcript.challenge_vector(output_mle.get_num_vars());
let claim = output_mle.evaluate(&r_outputs);

let (quark_option, mut random, mut claim, _uses_quarks) = if self.quark_poly.is_some() {
// When doing the quark hybrid proof, we first prove the grand product of a layer of a polynomial which is N layers deep in the tree
// of a standard layered sumcheck grand product, then we use the sumcheck layers to prove via gkr layers that the random point opened
// by the quark proof is in fact the folded result of the base layer.
let (quark, random, quark_claim) =
QuarkGrandProductProof::<PCS, ProofTranscript>::prove(
self.quark_poly.as_ref().unwrap(),
r_outputs,
claim,
opening_accumulator.unwrap(),
transcript,
setup.unwrap(),
);
(Some(quark), random, quark_claim, true)
} else {
(None, r_outputs, claim, false)
};

let layers_iter = <Self as BatchedGrandProduct<F, PCS, ProofTranscript>>::layers(self);
for layer in layers_iter {
proof_layers.push(layer.prove_layer(&mut claim, &mut random, transcript));
}

(
BatchedGrandProductProof {
layers: proof_layers,
quark_proof: quark_option,
},
random,
QuarkGrandProductBase::prove_quark_grand_product(
self,
opening_accumulator,
transcript,
setup,
)
}

Expand All @@ -1039,36 +1008,12 @@ where
transcript: &mut ProofTranscript,
_setup: Option<&PCS::Setup>,
) -> (F, Vec<F>) {
transcript.append_scalars(claimed_outputs);
let r_outputs: Vec<F> =
transcript.challenge_vector(claimed_outputs.len().next_power_of_two().log_2());
let claim = DensePolynomial::new_padded(claimed_outputs.to_vec()).evaluate(&r_outputs);

let (claim, rand) = match proof.quark_proof.as_ref() {
Some(quark) => {
let v_len = quark.num_vars;
quark
.verify(
r_outputs,
claim,
opening_accumulator.unwrap(),
transcript,
v_len,
)
.unwrap_or_else(|e| panic!("quark verify error: {:?}", e))
}
None => (claim, r_outputs),
};

let (grand_product_claim, grand_product_r) = <Self as BatchedGrandProduct<
F,
PCS,
ProofTranscript,
>>::verify_layers(
&proof.layers, claim, transcript, rand
);

(grand_product_claim, grand_product_r)
QuarkGrandProductBase::verify_quark_grand_product::<Self, PCS>(
proof,
claimed_outputs,
opening_accumulator,
transcript,
)
}

fn verify_sumcheck_claim(
Expand Down

0 comments on commit 634ac6d

Please sign in to comment.