Skip to content

Commit

Permalink
Merge pull request #2432 from o1-labs/volhovm/2427-runtime-tables-sup…
Browse files Browse the repository at this point in the history
…port

Add support for runtime tables
  • Loading branch information
volhovm authored Jul 19, 2024
2 parents b45b551 + 36a554f commit 7c6af10
Show file tree
Hide file tree
Showing 14 changed files with 315 additions and 108 deletions.
4 changes: 2 additions & 2 deletions ivc/src/ivc/lookups.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl<Ff: PrimeField> LookupTableID for IVCLookupTable<Ff> {
}

/// Converts a value to its index in the fixed table.
fn ix_by_value<F: PrimeField>(&self, value: &[F]) -> usize {
fn ix_by_value<F: PrimeField>(&self, value: &[F]) -> Option<usize> {
match self {
Self::SerLookupTable(lt) => lt.ix_by_value(value),
}
Expand All @@ -54,7 +54,7 @@ impl<Ff: PrimeField> LookupTableID for IVCLookupTable<Ff> {

impl<Ff: PrimeField> IVCLookupTable<Ff> {
/// Provides a full list of entries for the given table.
pub fn entries<F: PrimeField>(&self, domain_d1_size: u64) -> Vec<F> {
pub fn entries<F: PrimeField>(&self, domain_d1_size: u64) -> Option<Vec<F>> {
match self {
Self::SerLookupTable(lt) => lt.entries(domain_d1_size),
}
Expand Down
92 changes: 84 additions & 8 deletions msm/src/circuit_design/witness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,23 @@ pub struct WitnessBuilderEnv<
/// Lookup multiplicities, a vector of values `m_i` per lookup
/// table, where `m_i` is how many times the lookup value number
/// `i` was looked up.
pub lookup_multiplicities: BTreeMap<LT, Vec<F>>,
pub lookup_multiplicities: BTreeMap<LT, Vec<u64>>,

// @volhovm: It seems much more ineffective to store
// Vec<BTreeMap<LT, Vec<Logup>>> than BTreeMap<LT, Vec<Vec<Logup>>>.
/// Lookup requests. Each vector element represents one row, and
/// each row is a map from lookup type to a vector of concrete
/// lookups requested.
pub lookups: Vec<BTreeMap<LT, Vec<Logup<F, LT>>>>,

/// Values of the runtime tables. Each element (value) in the map
/// is a on-the-fly built column.
pub runtime_tables: BTreeMap<LT, Vec<Vec<F>>>,

/// A runtime table resolver; the inverse of `runtime_tables`:
/// maps runtime lookup table to its `usize` position in `runtime_tables`.
pub runtime_tables_resolver: BTreeMap<LT, BTreeMap<Vec<F>, usize>>,

/// Fixed values for selector columns. `fixed_selectors[i][j]` is the
/// value for row #j of the selector #i.
pub fixed_selectors: Vec<Vec<F>>,
Expand Down Expand Up @@ -184,8 +194,40 @@ impl<
> LookupCap<F, CIx, LT> for WitnessBuilderEnv<F, CIx, N_WIT, N_REL, N_DSEL, N_FSEL, LT>
{
fn lookup(&mut self, table_id: LT, value: Vec<<Self as ColAccessCap<F, CIx>>::Variable>) {
let value_ix = table_id.ix_by_value(&value);
self.lookup_multiplicities.get_mut(&table_id).unwrap()[value_ix] += F::one();
let value_ix = if table_id.is_fixed() {
table_id
.ix_by_value(&value)
.expect("Could not resolve lookup for a fixed table")
} else {
// For runtime tables, we check if this value was already
// present in a table; if yes, we return its index (row),
// otherwise we add this value to the runtime table and
// return its new index (runtime_table.len() == cur_height).
let cur_height = self.runtime_tables.get_mut(&table_id).unwrap().len();
let resolver = self.runtime_tables_resolver.get_mut(&table_id).unwrap();
if let Some(prev_index) = resolver.get_mut(&value) {
*prev_index
} else {
(*resolver).insert(value.clone(), cur_height);
self.runtime_tables
.get_mut(&table_id)
.unwrap()
.push(value.clone());
cur_height
}
};
{
let multiplicities = self.lookup_multiplicities.get_mut(&table_id).unwrap();
// Since we allow multiple lookups per row, runtime tables
// can in theory grow bigger than the domain size. We
// still collect multiplicities as if runtime table vector
// is not height-bounded, but we will split it into chunks
// later.
if !table_id.is_fixed() && value_ix > multiplicities.len() {
multiplicities.resize(value_ix, 0u64);
}
multiplicities[value_ix] += 1;
}
self.lookups
.last_mut()
.unwrap()
Expand Down Expand Up @@ -264,7 +306,11 @@ impl<
/// Getting multiplicities for range check tables less or equal than 15 bits.
pub fn get_lookup_multiplicities(&self, domain_size: usize, table_id: LT) -> Vec<F> {
let mut m = Vec::with_capacity(domain_size);
m.extend(self.lookup_multiplicities[&table_id].to_vec());
m.extend(
self.lookup_multiplicities[&table_id]
.iter()
.map(|x| F::from(*x)),
);
if table_id.length() < domain_size {
let n_repeated_dummy_value: usize = domain_size - table_id.length() - 1;
let repeated_dummy_value: Vec<F> = iter::repeat(-F::one())
Expand Down Expand Up @@ -292,10 +338,16 @@ impl<
pub fn create() -> Self {
let mut lookups_row = BTreeMap::new();
let mut lookup_multiplicities = BTreeMap::new();
let mut runtime_tables_resolver = BTreeMap::new();
let mut runtime_tables = BTreeMap::new();
let fixed_selectors = vec![vec![]; N_FSEL];
for table_id in LT::all_variants().into_iter() {
lookups_row.insert(table_id, Vec::new());
lookup_multiplicities.insert(table_id, vec![F::zero(); table_id.length()]);
lookup_multiplicities.insert(table_id, vec![0u64; table_id.length()]);
if !table_id.is_fixed() {
runtime_tables_resolver.insert(table_id, BTreeMap::new());
runtime_tables.insert(table_id, vec![]);
}
}

Self {
Expand All @@ -305,6 +357,8 @@ impl<

lookup_multiplicities,
lookups: vec![lookups_row],
runtime_tables_resolver,
runtime_tables,
fixed_selectors,
phantom_cix: PhantomData,
assert_mapper: Box::new(|x| x),
Expand Down Expand Up @@ -367,10 +421,23 @@ impl<
*witness
}

/// Return all runtime tables collected so far, padded to the domain size.
pub fn get_runtime_tables(&self, domain_size: usize) -> BTreeMap<LT, Vec<Vec<F>>> {
let mut runtime_tables: BTreeMap<_, _> = self.runtime_tables.clone();
for (_table_id, content) in runtime_tables.iter_mut() {
// We pad the runtime table with dummies if it's too small.
if content.len() < domain_size {
let dummy_value = content[0].clone(); // we assume runtime tables are never empty
content.append(&mut vec![dummy_value; domain_size - content.len()]);
}
}
runtime_tables
}

pub fn get_logup_witness(
&self,
domain_size: usize,
lookup_tables_data: BTreeMap<LT, Vec<F>>,
lookup_tables_data: BTreeMap<LT, Vec<Vec<F>>>,
) -> Vec<LogupWitness<F, LT>> {
// Building lookup values
let mut lookup_tables: BTreeMap<LT, Vec<Vec<Logup<F, LT>>>> = BTreeMap::new();
Expand Down Expand Up @@ -403,6 +470,15 @@ impl<
}

let mut lookup_multiplicities: BTreeMap<LT, Vec<F>> = BTreeMap::new();
for (table_id, m) in lookup_multiplicities.iter() {
if !table_id.is_fixed() {
// Temporary assertion; to be removed when we support bigger
// runtime table/RAMlookups functionality.
assert!(m.len() <= domain_size,
"We do not _yet_ support wrapping runtime tables that are bigger than domain size.");
}
}

// Counting multiplicities & adding fixed column into the last column of every table.
for (table_id, table) in lookup_tables.iter_mut() {
let lookup_m = self.get_lookup_multiplicities(domain_size, *table_id);
Expand All @@ -413,7 +489,7 @@ impl<
.map(|(i, v)| Logup {
table_id: *table_id,
numerator: -lookup_m[i],
value: vec![*v],
value: v.clone(),
});
*(table.last_mut().unwrap()) = lookup_t.collect();
}
Expand All @@ -439,7 +515,7 @@ impl<
pub fn get_proof_inputs(
&self,
domain_size: usize,
lookup_tables_data: BTreeMap<LT, Vec<F>>,
lookup_tables_data: BTreeMap<LT, Vec<Vec<F>>>,
) -> ProofInputs<N_WIT, F, LT> {
let evaluations = self.get_relation_witness(domain_size);
let logups = self.get_logup_witness(domain_size, lookup_tables_data);
Expand Down
6 changes: 3 additions & 3 deletions msm/src/fec/lookups.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ impl<Ff: PrimeField> LookupTableID for LookupTable<Ff> {
}

/// Converts a value to its index in the fixed table.
fn ix_by_value<F: PrimeField>(&self, value: &[F]) -> usize {
fn ix_by_value<F: PrimeField>(&self, value: &[F]) -> Option<usize> {
let value = value[0];
assert!(self.is_member(value));
match self {
Some(match self {
Self::RangeCheck15 => TryFrom::try_from(value.to_biguint()).unwrap(),
Self::RangeCheck14Abs => {
if value < F::from(1u64 << 14) {
Expand All @@ -77,7 +77,7 @@ impl<Ff: PrimeField> LookupTableID for LookupTable<Ff> {
}
}
Self::RangeCheckFfHighest(_) => TryFrom::try_from(value.to_biguint()).unwrap(),
}
})
}

fn all_variants() -> Vec<Self> {
Expand Down
9 changes: 8 additions & 1 deletion msm/src/fec/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,14 @@ mod tests {
// Fixed tables can be generated inside lookup_tables_data. Runtime should be generated here.
let mut lookup_tables_data = BTreeMap::new();
for table_id in LookupTable::<Ff1>::all_variants().into_iter() {
lookup_tables_data.insert(table_id, table_id.entries(domain_size as u64));
lookup_tables_data.insert(
table_id,
table_id
.entries(domain_size as u64)
.into_iter()
.map(|x| vec![x])
.collect(),
);
}
let proof_inputs = witness_env.get_proof_inputs(domain_size, lookup_tables_data);

Expand Down
6 changes: 3 additions & 3 deletions msm/src/ffa/lookups.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ impl LookupTableID for LookupTable {
}

/// Converts a value to its index in the fixed table.
fn ix_by_value<F: PrimeField>(&self, value: &[F]) -> usize {
fn ix_by_value<F: PrimeField>(&self, value: &[F]) -> Option<usize> {
let value = value[0];
match self {
Some(match self {
Self::RangeCheck15 => TryFrom::try_from(value.to_biguint()).unwrap(),
Self::RangeCheck1BitSigned => {
if value == F::zero() {
Expand All @@ -57,7 +57,7 @@ impl LookupTableID for LookupTable {
panic!("Invalid value for rangecheck1abs")
}
}
}
})
}

fn all_variants() -> Vec<Self> {
Expand Down
9 changes: 8 additions & 1 deletion msm/src/ffa/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,14 @@ mod tests {
// Fixed tables can be generated inside lookup_tables_data. Runtime should be generated here.
let mut lookup_tables_data = BTreeMap::new();
for table_id in LookupTable::all_variants().into_iter() {
lookup_tables_data.insert(table_id, table_id.entries(domain_size as u64));
lookup_tables_data.insert(
table_id,
table_id
.entries(domain_size as u64)
.into_iter()
.map(|x| vec![x])
.collect(),
);
}
let proof_inputs = witness_env.get_proof_inputs(domain_size, lookup_tables_data);

Expand Down
49 changes: 29 additions & 20 deletions msm/src/logup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,9 @@ where
}

/// Trait for lookup table variants
pub trait LookupTableID: Send + Sync + Copy + Hash + Eq + PartialEq + Ord + PartialOrd {
pub trait LookupTableID:
Send + Sync + Copy + Hash + Eq + PartialEq + Ord + PartialOrd + core::fmt::Debug
{
/// Assign a unique ID, as a u32 value
fn to_u32(&self) -> u32;

Expand All @@ -207,8 +209,9 @@ pub trait LookupTableID: Send + Sync + Copy + Hash + Eq + PartialEq + Ord + Part
/// Returns the length of each table.
fn length(&self) -> usize;

/// Given a value, returns an index of this value in the table.
fn ix_by_value<F: PrimeField>(&self, value: &[F]) -> usize;
/// Returns None if the table is runtime (and thus mapping value
/// -> ix is not known at compile time.
fn ix_by_value<F: PrimeField>(&self, value: &[F]) -> Option<usize>;

fn all_variants() -> Vec<Self>;
}
Expand Down Expand Up @@ -297,8 +300,8 @@ impl<'lt, G, ID: LookupTableID> IntoIterator for &'lt LookupProof<G, ID> {
/// |------------------------------------------|
/// | denominators |
/// | /--------------\ |
/// column * (\prod_{i = 1}^{N} (β + f_{i}(X))) =
/// \sum_{i = 1}^{N} m_{i} * \prod_{j = 1, j \neq i}^{N} (β + f_{j}(X))
/// column * (\prod_{i = 0}^{N} (β + f_{i}(X))) =
/// \sum_{i = 0}^{N} m_{i} * \prod_{j = 1, j \neq i}^{N} (β + f_{j}(X))
/// | |--------------------------------------------------|
/// | Inner part of rhs |
/// | |
Expand All @@ -310,13 +313,17 @@ impl<'lt, G, ID: LookupTableID> IntoIterator for &'lt LookupProof<G, ID> {
/// ```
/// It is because h(X) (column) is defined as:
/// ```text
/// n m_i(X)
/// h(X) ∑ ----------
/// i=1 β + f_i(X)
/// n m_i(X) n 1 m_0(ω^j)
/// h(X) =---------- = ∑ ------------ - -----------
/// i=0 β + f_i(X) i=1 β + f_i(ω^j) β + t(ω^j)
///```
/// For instance, if i = 2, we have
/// The first form is generic, the second is concrete with f_0 = t; m_0 = m; m_i = 1 for ≠ 1.
/// We will be thinking in the generic form.
///
/// For instance, if N = 2, we have
/// ```text
/// h(X) = m_1(X) / (β + f_1(X)) + m_2(X) / (β + f_{2}(X))
///
/// m_1(X) * (β + f_2(X)) + m_2(X) * (β + f_{1}(X))
/// = ----------------------------------------------
/// (β + f_2(X)) * (β + f_1(X))
Expand Down Expand Up @@ -362,10 +369,13 @@ pub fn combine_lookups<F: PrimeField, ID: LookupTableID>(
beta.clone() + combined_value + x.table_id.to_constraint()
})
.collect::<Vec<_>>();

// Compute `column * (\prod_{i = 1}^{N} (β + f_{i}(X)))`
let lhs = denominators
.iter()
.fold(curr_cell(column), |acc, x| acc * x.clone());

// Compute `\sum_{i = 0}^{N} m_{i} * \prod_{j = 1, j \neq i}^{N} (β + f_{j}(X))`
let rhs = lookups
.into_iter()
.enumerate()
Expand All @@ -386,6 +396,7 @@ pub fn combine_lookups<F: PrimeField, ID: LookupTableID>(
// Individual sums
.reduce(|x, y| x + y)
.unwrap_or(E::zero());

lhs - rhs
}

Expand Down Expand Up @@ -521,12 +532,6 @@ pub mod prover {
> = {
(&lookups)
.into_par_iter()
.filter(|lookup| {
// FIXME: this is ugly.
// Does not handle RAMLookup
let table_id = lookup.f[0][0].table_id;
table_id.is_fixed()
})
.map(|lookup| {
let table_id = lookup.f[0][0].table_id;
(
Expand Down Expand Up @@ -610,21 +615,25 @@ pub mod prover {
table_id,
value,
} = &f_i[j as usize];
// Compute r * x_{1} + r^2 x_{2} + ... r^{N} x_{N}
let combined_value: G::ScalarField =
// Compute x_{1} + r x_{2} + ... r^{N-1} x_{N}
// This is what we actually put into the `fixed_lookup_tables`.
let combined_value_pow0: G::ScalarField =
value.iter().rev().fold(G::ScalarField::zero(), |acc, y| {
acc * vector_lookup_combiner + y
}) * vector_lookup_combiner;
});
// Compute r * x_{1} + r^2 x_{2} + ... r^{N} x_{N}
let combined_value: G::ScalarField =
combined_value_pow0 * vector_lookup_combiner;
// add table id
let combined_value = combined_value + table_id.to_field::<G::ScalarField>();

// If last element and fixed lookup tables, we keep
// the *combined* value of the table.
if i == (n - 1) && table_id.is_fixed() {
if i == (n - 1) {
fixed_lookup_tables
.entry(*table_id)
.or_insert_with(Vec::new)
.push(value[0]);
.push(combined_value_pow0);
}

// β + a_{i}
Expand Down
Loading

0 comments on commit 7c6af10

Please sign in to comment.